Ver código fonte

Rewritten DHTRoutingTableDeserializer using stdio instead of stream.

Tatsuhiro Tsujikawa 14 anos atrás
pai
commit
f141cd4228

+ 32 - 48
src/DHTRoutingTableDeserializer.cc

@@ -36,7 +36,7 @@
 
 #include <cstring>
 #include <cassert>
-#include <istream>
+#include <cstdio>
 #include <utility>
 
 #include "DHTNode.h"
@@ -56,27 +56,27 @@ DHTRoutingTableDeserializer::DHTRoutingTableDeserializer(int family):
 
 DHTRoutingTableDeserializer::~DHTRoutingTableDeserializer() {}
 
+#define FREAD_CHECK(ptr, count, fp)                                     \
+  if(fread((ptr), 1, (count), (fp)) != (count)) {                       \
+    throw DL_ABORT_EX("Failed to load DHT routing table.");             \
+  }
+
 namespace {
 void readBytes(unsigned char* buf, size_t buflen,
-               std::istream& in, size_t readlen)
+               FILE* fp, size_t readlen)
 {
   assert(readlen <= buflen);
-  in.read(reinterpret_cast<char*>(buf), readlen);
+  FREAD_CHECK(buf, readlen, fp);
 }
 } // namespace
 
-#define CHECK_STREAM(in, length)                                        \
-  if(in.gcount() != length) {                                           \
-    throw DL_ABORT_EX                                                   \
-      (fmt("Failed to load DHT routing table. cause:%s",                \
-           "Unexpected EOF"));                                          \
-  }                                                                     \
-  if(!in) {                                                             \
-    throw DL_ABORT_EX("Failed to load DHT routing table.");             \
-  }
-
-void DHTRoutingTableDeserializer::deserialize(std::istream& in)
+void DHTRoutingTableDeserializer::deserialize(const std::string& filename)
 {
+  FILE* fp = a2fopen(utf8ToWChar(filename).c_str(), "rb");
+  if(!fp) {
+    throw DL_ABORT_EX("Failed to load DHT routing table.");
+  }
+  auto_delete_r<FILE*, int> deleter(fp, fclose);
   char header[8];
   memset(header, 0, sizeof(header));
   // magic
@@ -109,8 +109,7 @@ void DHTRoutingTableDeserializer::deserialize(std::istream& in)
   array_wrapper<unsigned char, 255> buf;
 
   // header
-  readBytes(buf, buf.size(), in, 8);
-  CHECK_STREAM(in, 8);
+  readBytes(buf, buf.size(), fp, 8);
   if(memcmp(header, buf, 8) == 0) {
     version = 3;
   } else if(memcmp(headerCompat, buf, 8) == 0) {
@@ -125,37 +124,29 @@ void DHTRoutingTableDeserializer::deserialize(std::istream& in)
   uint64_t temp64;
   // time
   if(version == 2) {
-    in.read(reinterpret_cast<char*>(&temp32), sizeof(temp32));
-    CHECK_STREAM(in, sizeof(temp32));
+    FREAD_CHECK(&temp32, sizeof(temp32), fp);
     serializedTime_.setTimeInSec(ntohl(temp32));
     // 4bytes reserved
-    readBytes(buf, buf.size(), in, 4);
-    CHECK_STREAM(in, 4);
+    readBytes(buf, buf.size(), fp, 4);
   } else {
-    in.read(reinterpret_cast<char*>(&temp64), sizeof(temp64));
-    CHECK_STREAM(in, sizeof(temp64));
+    FREAD_CHECK(&temp64, sizeof(temp64), fp);
     serializedTime_.setTimeInSec(ntoh64(temp64));
   }
   
   // localnode
   // 8bytes reserved
-  readBytes(buf, buf.size(), in, 8);
-  CHECK_STREAM(in, 8);
+  readBytes(buf, buf.size(), fp, 8);
   // localnode ID
-  readBytes(buf, buf.size(), in, DHT_ID_LENGTH);
-  CHECK_STREAM(in, DHT_ID_LENGTH);
+  readBytes(buf, buf.size(), fp, DHT_ID_LENGTH);
   SharedHandle<DHTNode> localNode(new DHTNode(buf));
   // 4bytes reserved
-  readBytes(buf, buf.size(), in, 4);
-  CHECK_STREAM(in, 4);
+  readBytes(buf, buf.size(), fp, 4);
 
   // number of nodes
-  in.read(reinterpret_cast<char*>(&temp32), sizeof(temp32));
-  CHECK_STREAM(in, sizeof(temp32));
+  FREAD_CHECK(&temp32, sizeof(temp32), fp);
   uint32_t numNodes = ntohl(temp32);
   // 4bytes reserved
-  readBytes(buf, buf.size(), in, 4);
-  CHECK_STREAM(in, 4);
+  readBytes(buf, buf.size(), fp, 4);
 
   std::vector<SharedHandle<DHTNode> > nodes;
   // nodes
@@ -163,45 +154,38 @@ void DHTRoutingTableDeserializer::deserialize(std::istream& in)
   for(size_t i = 0; i < numNodes; ++i) {
     // 1byte compact peer info length
     uint8_t peerInfoLen;
-    in >> peerInfoLen;
+    FREAD_CHECK(&peerInfoLen, sizeof(peerInfoLen), fp);
     if(peerInfoLen != compactlen) {
       // skip this entry
-      readBytes(buf, buf.size(), in, 7+48);
-      CHECK_STREAM(in, 7+48);
+      readBytes(buf, buf.size(), fp, 7+48);
       continue;
     }
     // 7bytes reserved
-    readBytes(buf, buf.size(), in, 7);
-    CHECK_STREAM(in, 7);
+    readBytes(buf, buf.size(), fp, 7);
     // compactlen bytes compact peer info
-    readBytes(buf, buf.size(), in, compactlen);
-    CHECK_STREAM(in, compactlen);
+    readBytes(buf, buf.size(), fp, compactlen);
     if(memcmp(zero, buf, compactlen) == 0) {
       // skip this entry
-      readBytes(buf, buf.size(), in, 48-compactlen);
-      CHECK_STREAM(in, 48-compactlen);
+      readBytes(buf, buf.size(), fp, 48-compactlen);
       continue;
     }
     std::pair<std::string, uint16_t> peer =
       bittorrent::unpackcompact(buf, family_);
     if(peer.first.empty()) {
       // skip this entry
-      readBytes(buf, buf.size(), in, 48-compactlen);
-      CHECK_STREAM(in, 48-compactlen);
+      readBytes(buf, buf.size(), fp, 48-compactlen);
       continue;
     }
     // 24-compactlen bytes reserved
-    readBytes(buf, buf.size(), in, 24-compactlen);
+    readBytes(buf, buf.size(), fp, 24-compactlen);
     // node ID
-    readBytes(buf, buf.size(), in, DHT_ID_LENGTH);
-    CHECK_STREAM(in, DHT_ID_LENGTH);
+    readBytes(buf, buf.size(), fp, DHT_ID_LENGTH);
 
     SharedHandle<DHTNode> node(new DHTNode(buf));
     node->setIPAddress(peer.first);
     node->setPort(peer.second);
     // 4bytes reserved
-    readBytes(buf, buf.size(), in, 4);
-    CHECK_STREAM(in, 4);
+    readBytes(buf, buf.size(), fp, 4);
 
     nodes.push_back(node);
   }

+ 2 - 2
src/DHTRoutingTableDeserializer.h

@@ -38,7 +38,7 @@
 #include "common.h"
 
 #include <vector>
-#include <iosfwd>
+#include <string>
 
 #include "SharedHandle.h"
 #include "TimeA2.h"
@@ -76,7 +76,7 @@ public:
     return serializedTime_;
   }
 
-  void deserialize(std::istream& in);
+  void deserialize(const std::string& filename);
 };
 
 } // namespace aria2

+ 1 - 5
src/DHTSetup.cc

@@ -99,11 +99,7 @@ void DHTSetup::setup
       e->getOption()->get(family == AF_INET?PREF_DHT_FILE_PATH:
                           PREF_DHT_FILE_PATH6);
     try {
-      std::ifstream in(dhtFile.c_str(), std::ios::binary);
-      if(!in) {
-        throw DL_ABORT_EX("Could not open file");
-      }
-      deserializer.deserialize(in);
+      deserializer.deserialize(dhtFile);
       localNode = deserializer.getLocalNode();
     } catch(RecoverableException& e) {
       A2_LOG_ERROR_EX

+ 10 - 6
test/DHTRoutingTableDeserializerTest.cc

@@ -54,11 +54,13 @@ void DHTRoutingTableDeserializerTest::testDeserialize()
   s.setLocalNode(localNode);
   s.setNodes(nodes);
 
-  std::stringstream ss;
-  s.serialize(ss);
+  std::string filename = A2_TEST_OUT_DIR"/aria2_DHTRoutingTableDeserializerTest_testDeserialize";
+  std::ofstream outfile(filename.c_str(), std::ios::binary);
+  s.serialize(outfile);
+  outfile.close();
 
   DHTRoutingTableDeserializer d(AF_INET);
-  d.deserialize(ss);
+  d.deserialize(filename);
 
   CPPUNIT_ASSERT(memcmp(localNode->getID(), d.getLocalNode()->getID(),
                         DHT_ID_LENGTH) == 0);
@@ -93,11 +95,13 @@ void DHTRoutingTableDeserializerTest::testDeserialize6()
   s.setLocalNode(localNode);
   s.setNodes(nodes);
 
-  std::stringstream ss;
-  s.serialize(ss);
+  std::string filename = A2_TEST_OUT_DIR"/aria2_DHTRoutingTableDeserializerTest_testDeserialize6";
+  std::ofstream outfile(filename.c_str(), std::ios::binary);
+  s.serialize(outfile);
+  outfile.close();
 
   DHTRoutingTableDeserializer d(AF_INET6);
-  d.deserialize(ss);
+  d.deserialize(filename);
 
   CPPUNIT_ASSERT(memcmp(localNode->getID(), d.getLocalNode()->getID(),
                         DHT_ID_LENGTH) == 0);