瀏覽代碼

Rewritten LpdMessageReceiver::receiveMessage()

Tatsuhiro Tsujikawa 13 年之前
父節點
當前提交
b9f972665b
共有 3 個文件被更改,包括 29 次插入26 次删除
  1. 27 16
      src/LpdMessageReceiver.cc
  2. 0 4
      src/LpdReceiveMessageCommand.cc
  3. 2 6
      test/LpdMessageReceiverTest.cc

+ 27 - 16
src/LpdMessageReceiver.cc

@@ -80,21 +80,37 @@ bool LpdMessageReceiver::init(const std::string& localAddr)
 SharedHandle<LpdMessage> LpdMessageReceiver::receiveMessage()
 {
   SharedHandle<LpdMessage> msg;
-  try {
+  while(1) {
     unsigned char buf[200];
     std::pair<std::string, uint16_t> peerAddr;
-    ssize_t length = socket_->readDataFrom(buf, sizeof(buf), peerAddr);
-    if(length == 0) {
+    ssize_t length;
+    try {
+      length = socket_->readDataFrom(buf, sizeof(buf), peerAddr);
+      if(length == 0) {
+        return msg;
+      }
+    } catch(RecoverableException& e) {
+      A2_LOG_INFO_EX("Failed to receive LPD message.", e);
       return msg;
     }
     HttpHeaderProcessor proc(HttpHeaderProcessor::SERVER_PARSER);
-    if(!proc.parse(buf, length)) {
-      msg.reset(new LpdMessage());
-      return msg;
+    try {
+      if(!proc.parse(buf, length)) {
+        // UDP packet must contain whole HTTP header block.
+        continue;
+      }
+    } catch(RecoverableException& e) {
+      A2_LOG_INFO_EX("Failed to parse LPD message.", e);
+      continue;
     }
     const SharedHandle<HttpHeader>& header = proc.getResult();
     const std::string& infoHashString = header->find(HttpHeader::INFOHASH);
-    uint16_t port = header->findAsInt(HttpHeader::PORT);
+    uint32_t port = 0;
+    if(!util::parseUIntNoThrow(port, header->find(HttpHeader::PORT)) ||
+       port > UINT16_MAX || port == 0) {
+      A2_LOG_INFO(fmt("Bad LPD port=%u", port));
+      continue;
+    }
     A2_LOG_INFO(fmt("LPD message received infohash=%s, port=%u from %s",
                     infoHashString.c_str(),
                     port,
@@ -102,11 +118,10 @@ SharedHandle<LpdMessage> LpdMessageReceiver::receiveMessage()
     std::string infoHash;
     if(infoHashString.size() != 40 ||
        (infoHash = util::fromHex(infoHashString.begin(),
-                                 infoHashString.end())).empty() ||
-       port == 0) {
-      A2_LOG_INFO(fmt("LPD bad request. infohash=%s", infoHashString.c_str()));
-      msg.reset(new LpdMessage());
-      return msg;
+                                 infoHashString.end())).empty()) {
+      A2_LOG_INFO(fmt("LPD bad request. infohash=%s",
+                      infoHashString.c_str()));
+      continue;
     }
     SharedHandle<Peer> peer(new Peer(peerAddr.first, port, false));
     if(util::inPrivateAddress(peerAddr.first)) {
@@ -114,10 +129,6 @@ SharedHandle<LpdMessage> LpdMessageReceiver::receiveMessage()
     }
     msg.reset(new LpdMessage(peer, infoHash));
     return msg;
-  } catch(RecoverableException& e) {
-    A2_LOG_INFO_EX("Failed to receive LPD message.", e);
-    msg.reset(new LpdMessage());
-    return msg;
   }
 }
 

+ 0 - 4
src/LpdReceiveMessageCommand.cc

@@ -77,10 +77,6 @@ bool LpdReceiveMessageCommand::execute()
     if(!m) {
       break;
     }
-    if(!m->peer) {
-      // bad message
-      continue;
-    }
     SharedHandle<BtRegistry> reg = e_->getBtRegistry();
     SharedHandle<DownloadContext> dctx = reg->getDownloadContext(m->infoHash);
     if(!dctx) {

+ 2 - 6
test/LpdMessageReceiverTest.cc

@@ -67,9 +67,7 @@ void LpdMessageReceiverTest::testReceiveMessage()
 
   rcv.getSocket()->isReadable(5);
   msg = rcv.receiveMessage();
-  CPPUNIT_ASSERT(msg);
-  CPPUNIT_ASSERT(!msg->peer);
-  CPPUNIT_ASSERT(msg->infoHash.empty());
+  CPPUNIT_ASSERT(!msg);
 
   // Bad port
   request =
@@ -81,9 +79,7 @@ void LpdMessageReceiverTest::testReceiveMessage()
 
   rcv.getSocket()->isReadable(5);
   msg = rcv.receiveMessage();
-  CPPUNIT_ASSERT(msg);
-  CPPUNIT_ASSERT(!msg->peer);
-  CPPUNIT_ASSERT(msg->infoHash.empty());
+  CPPUNIT_ASSERT(!msg);
 
   // No data available
   msg = rcv.receiveMessage();