فهرست منبع

2008-10-26 Tatsuhiro Tsujikawa <tujikawa at rednoah dot com>

	Changed signature of DHTMessageFactory::createResponseMessage().
	Removed unused validateIDMatch.
	* src/DHTMessageFactory.h
	* src/DHTMessageFactoryImpl.cc
	* src/DHTMessageFactoryImpl.h
	* src/DHTMessageTracker.cc
	* test/DHTMessageFactoryImplTest.cc
	* test/MockDHTMessageFactory.h

	Dropped DHT message coming from same ID of localhost.
	* src/DHTMessageReceiver.cc

	Rejected adding node whose ID is the same as localhost's.
	* src/DHTRoutingTable.cc
	* test/BtPortMessageTest.cc
	* test/DHTRoutingTableTest.cc
Tatsuhiro Tsujikawa 17 سال پیش
والد
کامیت
8a920ba5e3

+ 19 - 0
ChangeLog

@@ -1,3 +1,22 @@
+2008-10-26  Tatsuhiro Tsujikawa  <tujikawa at rednoah dot com>
+
+	Changed signature of DHTMessageFactory::createResponseMessage().
+	Removed unused validateIDMatch.
+	* src/DHTMessageFactory.h
+	* src/DHTMessageFactoryImpl.cc
+	* src/DHTMessageFactoryImpl.h
+	* src/DHTMessageTracker.cc
+	* test/DHTMessageFactoryImplTest.cc
+	* test/MockDHTMessageFactory.h
+
+	Dropped DHT message coming from same ID of localhost.
+	* src/DHTMessageReceiver.cc
+
+	Rejected adding node whose ID is the same as localhost's.
+	* src/DHTRoutingTable.cc
+	* test/BtPortMessageTest.cc
+	* test/DHTRoutingTableTest.cc
+	
 2008-10-23  Tatsuhiro Tsujikawa  <tujikawa at rednoah dot com>
 
 	Pool connection when redirection occurs with Content-Length = 0.

+ 5 - 3
src/DHTMessageFactory.h

@@ -36,11 +36,13 @@
 #define _D_DHT_MESSAGE_FACTORY_H_
 
 #include "common.h"
-#include "SharedHandle.h"
-#include "A2STR.h"
+
 #include <string>
 #include <deque>
 
+#include "SharedHandle.h"
+#include "A2STR.h"
+
 namespace aria2 {
 
 class DHTMessage;
@@ -59,7 +61,7 @@ public:
   virtual SharedHandle<DHTMessage>
   createResponseMessage(const std::string& messageType,
 			const Dictionary* d,
-			const SharedHandle<DHTNode>& remoteNode) = 0;
+			const std::string& ipaddr, uint16_t port) = 0;
 
   virtual SharedHandle<DHTMessage>
   createPingMessage(const SharedHandle<DHTNode>& remoteNode,

+ 8 - 11
src/DHTMessageFactoryImpl.cc

@@ -33,6 +33,10 @@
  */
 /* copyright --> */
 #include "DHTMessageFactoryImpl.h"
+
+#include <cstring>
+#include <utility>
+
 #include "LogFactory.h"
 #include "DlAbortEx.h"
 #include "Data.h"
@@ -60,8 +64,6 @@
 #include "Peer.h"
 #include "Logger.h"
 #include "StringFormat.h"
-#include <cstring>
-#include <utility>
 
 namespace aria2 {
 
@@ -135,13 +137,6 @@ void DHTMessageFactoryImpl::validateID(const Data* id) const
   }
 }
 
-void DHTMessageFactoryImpl::validateIDMatch(const unsigned char* expected, const unsigned char* actual) const
-{
-  if(memcmp(expected, actual, DHT_ID_LENGTH) != 0) {
-    //throw DlAbortEx("Different ID received.");
-  }
-}
-
 void DHTMessageFactoryImpl::validatePort(const Data* i) const
 {
   if(!i->isNumber()) {
@@ -202,7 +197,8 @@ SharedHandle<DHTMessage> DHTMessageFactoryImpl::createQueryMessage(const Diction
 SharedHandle<DHTMessage>
 DHTMessageFactoryImpl::createResponseMessage(const std::string& messageType,
 					     const Dictionary* d,
-					     const SharedHandle<DHTNode>& remoteNode)
+					     const std::string& ipaddr,
+					     uint16_t port)
 {
   const Data* t = getData(d, DHTMessage::T);
   const Data* y = getData(d, DHTMessage::Y);
@@ -225,7 +221,8 @@ DHTMessageFactoryImpl::createResponseMessage(const std::string& messageType,
   const Dictionary* r = getDictionary(d, DHTResponseMessage::R);
   const Data* id = getData(r, DHTMessage::ID);
   validateID(id);
-  validateIDMatch(remoteNode->getID(), id->getData());
+  SharedHandle<DHTNode> remoteNode = getRemoteNode(id->getData(), ipaddr, port);
+
   std::string transactionID = t->toString();
   if(messageType == DHTPingReplyMessage::PING) {
     return createPingReplyMessage(remoteNode,

+ 1 - 3
src/DHTMessageFactoryImpl.h

@@ -71,8 +71,6 @@ private:
 
   void validateID(const Data* id) const;
 
-  void validateIDMatch(const unsigned char* expected, const unsigned char* actual) const;
-
   void validatePort(const Data* i) const;
 
   std::deque<SharedHandle<DHTNode> > extractNodes(const unsigned char* src, size_t length);
@@ -91,7 +89,7 @@ public:
   virtual SharedHandle<DHTMessage>
   createResponseMessage(const std::string& messageType,
 			const Dictionary* d,
-			const SharedHandle<DHTNode>& remoteNode);
+			const std::string& ipaddr, uint16_t port);
 
   virtual SharedHandle<DHTMessage>
   createPingMessage(const SharedHandle<DHTNode>& remoteNode,

+ 9 - 1
src/DHTMessageReceiver.cc

@@ -33,6 +33,10 @@
  */
 /* copyright --> */
 #include "DHTMessageReceiver.h"
+
+#include <cstring>
+#include <utility>
+
 #include "DHTMessageTracker.h"
 #include "DHTConnection.h"
 #include "DHTMessage.h"
@@ -49,7 +53,6 @@
 #include "LogFactory.h"
 #include "Logger.h"
 #include "Util.h"
-#include <utility>
 
 namespace aria2 {
 
@@ -102,6 +105,11 @@ SharedHandle<DHTMessage> DHTMessageReceiver::receiveMessage()
       }
     } else {
       message = _factory->createQueryMessage(d, remoteAddr, remotePort);
+      if(message->getLocalNode() == message->getRemoteNode()) {
+	// drop message from localnode
+	_logger->info("Recieved DHT message from localnode.");
+	return handleUnknownMessage(data, sizeof(data), remoteAddr, remotePort);
+      }
     }
     _logger->info("Message received: %s", message->toString().c_str());
     message->validate();

+ 9 - 4
src/DHTMessageTracker.cc

@@ -33,6 +33,9 @@
  */
 /* copyright --> */
 #include "DHTMessageTracker.h"
+
+#include <utility>
+
 #include "DHTMessage.h"
 #include "DHTMessageCallback.h"
 #include "DHTMessageTrackerEntry.h"
@@ -47,7 +50,6 @@
 #include "DlAbortEx.h"
 #include "DHTConstants.h"
 #include "StringFormat.h"
-#include <utility>
 
 namespace aria2 {
 
@@ -86,11 +88,14 @@ DHTMessageTracker::messageArrived(const Dictionary* d,
       _logger->debug("Tracker entry found.");
       SharedHandle<DHTNode> targetNode = entry->getTargetNode();
 
-      SharedHandle<DHTMessage> message = _factory->createResponseMessage(entry->getMessageType(),
-								 d, targetNode);
+      SharedHandle<DHTMessage> message =
+	_factory->createResponseMessage(entry->getMessageType(), d,
+					targetNode->getIPAddress(),
+					targetNode->getPort());
+
       int64_t rtt = entry->getElapsedMillis();
       _logger->debug("RTT is %s", Util::itos(rtt).c_str());
-      targetNode->updateRTT(rtt);
+      message->getRemoteNode()->updateRTT(rtt);
       SharedHandle<DHTMessageCallback> callback = entry->getCallback();
       return std::pair<SharedHandle<DHTMessage>, SharedHandle<DHTMessageCallback> >(message, callback);
     }

+ 8 - 0
src/DHTRoutingTable.cc

@@ -33,6 +33,9 @@
  */
 /* copyright --> */
 #include "DHTRoutingTable.h"
+
+#include <cstring>
+
 #include "DHTNode.h"
 #include "DHTBucket.h"
 #include "BNode.h"
@@ -72,6 +75,11 @@ bool DHTRoutingTable::addGoodNode(const SharedHandle<DHTNode>& node)
 bool DHTRoutingTable::addNode(const SharedHandle<DHTNode>& node, bool good)
 {
   _logger->debug("Trying to add node:%s", node->toString().c_str());
+  if(_localNode == node) {
+    _logger->debug("Adding node with the same ID with localnode is not"
+		   " allowed.");
+    return false;
+  }
   BNode* bnode = BNode::findBNodeFor(_root, node->getID());
   SharedHandle<DHTBucket> bucket = bnode->getBucket();
   while(1) {

+ 1 - 1
test/BtPortMessageTest.cc

@@ -104,7 +104,7 @@ void BtPortMessageTest::testDoReceivedAction()
   SharedHandle<DHTNode> nodes[9];
   for(size_t i = 0; i < arrayLength(nodes); ++i) {
     memset(nodeID, 0, DHT_ID_LENGTH);
-    nodeID[DHT_ID_LENGTH-1] = i;
+    nodeID[DHT_ID_LENGTH-1] = i+1;
     nodes[i].reset(new DHTNode(nodeID));
   }
 

+ 29 - 9
test/DHTMessageFactoryImplTest.cc

@@ -1,4 +1,10 @@
 #include "DHTMessageFactoryImpl.h"
+
+#include <cstring>
+#include <iostream>
+
+#include <cppunit/extensions/HelperMacros.h>
+
 #include "RecoverableException.h"
 #include "Util.h"
 #include "DHTNode.h"
@@ -17,9 +23,6 @@
 #include "DHTGetPeersReplyMessage.h"
 #include "DHTAnnouncePeerMessage.h"
 #include "DHTAnnouncePeerReplyMessage.h"
-#include <cstring>
-#include <iostream>
-#include <cppunit/extensions/HelperMacros.h>
 
 namespace aria2 {
 
@@ -112,7 +115,10 @@ void DHTMessageFactoryImplTest::testCreatePingReplyMessage()
   remoteNode->setPort(6881);
   
   SharedHandle<DHTPingReplyMessage> m
-    (dynamic_pointer_cast<DHTPingReplyMessage>(factory->createResponseMessage("ping", d.get(), remoteNode)));
+    (dynamic_pointer_cast<DHTPingReplyMessage>
+     (factory->createResponseMessage("ping", d.get(),
+				     remoteNode->getIPAddress(),
+				     remoteNode->getPort())));
 
   CPPUNIT_ASSERT(localNode == m->getLocalNode());
   CPPUNIT_ASSERT(remoteNode == m->getRemoteNode());
@@ -176,7 +182,10 @@ void DHTMessageFactoryImplTest::testCreateFindNodeReplyMessage()
     remoteNode->setPort(6881);
   
     SharedHandle<DHTFindNodeReplyMessage> m
-      (dynamic_pointer_cast<DHTFindNodeReplyMessage>(factory->createResponseMessage("find_node", d.get(), remoteNode)));
+      (dynamic_pointer_cast<DHTFindNodeReplyMessage>
+       (factory->createResponseMessage("find_node", d.get(),
+				       remoteNode->getIPAddress(),
+				       remoteNode->getPort())));
 
     CPPUNIT_ASSERT(localNode == m->getLocalNode());
     CPPUNIT_ASSERT(remoteNode == m->getRemoteNode());
@@ -247,7 +256,10 @@ void DHTMessageFactoryImplTest::testCreateGetPeersReplyMessage_nodes()
     remoteNode->setPort(6881);
   
     SharedHandle<DHTGetPeersReplyMessage> m
-      (dynamic_pointer_cast<DHTGetPeersReplyMessage>(factory->createResponseMessage("get_peers", d.get(), remoteNode)));
+      (dynamic_pointer_cast<DHTGetPeersReplyMessage>
+       (factory->createResponseMessage("get_peers", d.get(),
+				       remoteNode->getIPAddress(),
+				       remoteNode->getPort())));
 
     CPPUNIT_ASSERT(localNode == m->getLocalNode());
     CPPUNIT_ASSERT(remoteNode == m->getRemoteNode());
@@ -290,7 +302,10 @@ void DHTMessageFactoryImplTest::testCreateGetPeersReplyMessage_values()
     remoteNode->setPort(6881);
   
     SharedHandle<DHTGetPeersReplyMessage> m
-      (dynamic_pointer_cast<DHTGetPeersReplyMessage>(factory->createResponseMessage("get_peers", d.get(), remoteNode)));
+      (dynamic_pointer_cast<DHTGetPeersReplyMessage>
+       (factory->createResponseMessage("get_peers", d.get(),
+				       remoteNode->getIPAddress(),
+				       remoteNode->getPort())));
 
     CPPUNIT_ASSERT(localNode == m->getLocalNode());
     CPPUNIT_ASSERT(remoteNode == m->getRemoteNode());
@@ -356,7 +371,10 @@ void DHTMessageFactoryImplTest::testCreateAnnouncePeerReplyMessage()
   remoteNode->setPort(6881);
   
   SharedHandle<DHTAnnouncePeerReplyMessage> m
-    (dynamic_pointer_cast<DHTAnnouncePeerReplyMessage>(factory->createResponseMessage("announce_peer", d.get(), remoteNode)));
+    (dynamic_pointer_cast<DHTAnnouncePeerReplyMessage>
+     (factory->createResponseMessage("announce_peer", d.get(),
+				     remoteNode->getIPAddress(),
+				     remoteNode->getPort())));
 
   CPPUNIT_ASSERT(localNode == m->getLocalNode());
   CPPUNIT_ASSERT(remoteNode == m->getRemoteNode());
@@ -379,7 +397,9 @@ void DHTMessageFactoryImplTest::testReceivedErrorMessage()
   remoteNode->setPort(6881);
 
   try {
-    factory->createResponseMessage("announce_peer", d.get(), remoteNode);
+    factory->createResponseMessage("announce_peer", d.get(),
+				   remoteNode->getIPAddress(),
+				   remoteNode->getPort());
     CPPUNIT_FAIL("exception must be thrown.");
   } catch(RecoverableException& e) {
     std::cerr << e.stackTrace() << std::endl;

+ 19 - 2
test/DHTRoutingTableTest.cc

@@ -1,4 +1,8 @@
 #include "DHTRoutingTable.h"
+
+#include <cstring>
+#include <cppunit/extensions/HelperMacros.h>
+
 #include "Exception.h"
 #include "Util.h"
 #include "DHTNode.h"
@@ -6,8 +10,6 @@
 #include "MockDHTTaskQueue.h"
 #include "MockDHTTaskFactory.h"
 #include "DHTTask.h"
-#include <cstring>
-#include <cppunit/extensions/HelperMacros.h>
 
 namespace aria2 {
 
@@ -15,6 +17,7 @@ class DHTRoutingTableTest:public CppUnit::TestFixture {
 
   CPPUNIT_TEST_SUITE(DHTRoutingTableTest);
   CPPUNIT_TEST(testAddNode);
+  CPPUNIT_TEST(testAddNode_localNode);
   CPPUNIT_TEST(testGetClosestKNodes);
   CPPUNIT_TEST_SUITE_END();
 public:
@@ -23,6 +26,7 @@ public:
   void tearDown() {}
 
   void testAddNode();
+  void testAddNode_localNode();
   void testGetClosestKNodes();
 };
 
@@ -47,6 +51,19 @@ void DHTRoutingTableTest::testAddNode()
   table.showBuckets();
 }
 
+void DHTRoutingTableTest::testAddNode_localNode()
+{
+  SharedHandle<DHTNode> localNode(new DHTNode());
+  DHTRoutingTable table(localNode);
+  SharedHandle<MockDHTTaskFactory> taskFactory(new MockDHTTaskFactory());
+  table.setTaskFactory(taskFactory);
+  SharedHandle<MockDHTTaskQueue> taskQueue(new MockDHTTaskQueue());
+  table.setTaskQueue(taskQueue);
+
+  SharedHandle<DHTNode> newNode(new DHTNode(localNode->getID()));
+  CPPUNIT_ASSERT(!table.addNode(newNode));
+}
+
 static void createID(unsigned char* id, unsigned char firstChar, unsigned char lastChar)
 {
   memset(id, 0, DHT_ID_LENGTH);

+ 5 - 1
test/MockDHTMessageFactory.h

@@ -27,8 +27,12 @@ public:
   virtual SharedHandle<DHTMessage>
   createResponseMessage(const std::string& messageType,
 			const Dictionary* d,
-			const SharedHandle<DHTNode>& remoteNode)
+			const std::string& ipaddr, uint16_t port)
   {
+    SharedHandle<DHTNode> remoteNode(new DHTNode());
+    // TODO At this point, removeNode's ID is random.
+    remoteNode->setIPAddress(ipaddr);
+    remoteNode->setPort(port);
     SharedHandle<MockDHTMessage> m
       (new MockDHTMessage(_localNode, remoteNode,
 			  reinterpret_cast<const Data*>(d->get("t"))->toString()));