Просмотр исходного кода

UDP tracker: Use unsigned integer for transaction ID and connection ID

Tatsuhiro Tsujikawa 9 лет назад
Родитель
Сommit
dd678b7c65
4 измененных файлов с 43 добавлено и 45 удалено
  1. 19 16
      src/UDPTrackerClient.cc
  2. 1 1
      src/UDPTrackerClient.h
  3. 3 3
      src/UDPTrackerRequest.h
  4. 20 25
      test/UDPTrackerClientTest.cc

+ 19 - 16
src/UDPTrackerClient.cc

@@ -57,9 +57,12 @@ void failRequest(InputIterator first, InputIterator last, int error)
 } // namespace
 
 namespace {
-int32_t generateTransactionId()
+uint32_t generateTransactionId()
 {
-  return SimpleRandomizer::getInstance()->getRandomNumber(INT32_MAX);
+  uint32_t res;
+  SimpleRandomizer::getInstance()->getRandomBytes(
+      reinterpret_cast<unsigned char*>(&res), sizeof(res));
+  return res;
 }
 } // namespace
 
@@ -76,9 +79,9 @@ void logInvalidLength(const std::string& remoteAddr, uint16_t remotePort,
 
 namespace {
 void logInvalidTransaction(const std::string& remoteAddr, uint16_t remotePort,
-                           int action, int32_t transactionId)
+                           int action, uint32_t transactionId)
 {
-  A2_LOG_INFO(fmt("UDPT received %s reply from %s:%u invalid transaction_id=%d",
+  A2_LOG_INFO(fmt("UDPT received %s reply from %s:%u invalid transaction_id=%u",
                   getUDPTrackerActionStr(action), remoteAddr.c_str(),
                   remotePort, transactionId));
 }
@@ -139,7 +142,7 @@ int UDPTrackerClient::receiveReply(const unsigned char* data, size_t length,
       logInvalidLength(remoteAddr, remotePort, action, 16, length);
       return -1;
     }
-    int32_t transactionId = bittorrent::getIntParam(data, 4);
+    auto transactionId = bittorrent::getIntParam(data, 4);
     std::shared_ptr<UDPTrackerRequest> req =
         findInflightRequest(remoteAddr, remotePort, transactionId, true);
     if (!req) {
@@ -148,9 +151,9 @@ int UDPTrackerClient::receiveReply(const unsigned char* data, size_t length,
     }
     req->state = UDPT_STA_COMPLETE;
 
-    int64_t connectionId = bittorrent::getLLIntParam(data, 8);
+    auto connectionId = bittorrent::getLLIntParam(data, 8);
     A2_LOG_INFO(fmt("UDPT received CONNECT reply from %s:%u transaction_id=%u,"
-                    "connection_id=%" PRId64,
+                    "connection_id=%" PRIu64,
                     remoteAddr.c_str(), remotePort, transactionId,
                     connectionId));
     UDPTrackerConnection c(UDPT_CST_CONNECTED, connectionId, now);
@@ -170,7 +173,7 @@ int UDPTrackerClient::receiveReply(const unsigned char* data, size_t length,
       logTooShortLength(remoteAddr, remotePort, action, 20, length);
       return -1;
     }
-    int32_t transactionId = bittorrent::getIntParam(data, 4);
+    auto transactionId = bittorrent::getIntParam(data, 4);
     std::shared_ptr<UDPTrackerRequest> req =
         findInflightRequest(remoteAddr, remotePort, transactionId, true);
     if (!req) {
@@ -197,7 +200,7 @@ int UDPTrackerClient::receiveReply(const unsigned char* data, size_t length,
     }
 
     A2_LOG_INFO(fmt("UDPT received ANNOUNCE reply from %s:%u transaction_id=%u,"
-                    "connection_id=%" PRId64 ", event=%s, infohash=%s, "
+                    "connection_id=%" PRIu64 ", event=%s, infohash=%s, "
                     "interval=%d, leechers=%d, "
                     "seeders=%d, num_peers=%d",
                     remoteAddr.c_str(), remotePort, transactionId,
@@ -211,7 +214,7 @@ int UDPTrackerClient::receiveReply(const unsigned char* data, size_t length,
       logTooShortLength(remoteAddr, remotePort, action, 8, length);
       return -1;
     }
-    int32_t transactionId = bittorrent::getIntParam(data, 4);
+    auto transactionId = bittorrent::getIntParam(data, 4);
     std::shared_ptr<UDPTrackerRequest> req =
         findInflightRequest(remoteAddr, remotePort, transactionId, true);
     if (!req) {
@@ -225,7 +228,7 @@ int UDPTrackerClient::receiveReply(const unsigned char* data, size_t length,
     req->error = UDPT_ERR_TRACKER;
 
     A2_LOG_INFO(fmt("UDPT received ERROR reply from %s:%u transaction_id=%u,"
-                    "connection_id=%" PRId64 ", action=%d, error_string=%s",
+                    "connection_id=%" PRIu64 ", action=%d, error_string=%s",
                     remoteAddr.c_str(), remotePort, transactionId,
                     req->connectionId, action, errorString.c_str()));
     if (req->action == UDPT_ACT_CONNECT) {
@@ -301,7 +304,7 @@ void UDPTrackerClient::requestSent(const Timer& now)
     break;
   case UDPT_ACT_ANNOUNCE:
     A2_LOG_INFO(fmt("UDPT sent ANNOUNCE to %s:%u transaction_id=%u, "
-                    "connection_id=%" PRId64 ", event=%s, infohash=%s",
+                    "connection_id=%" PRIu64 ", event=%s, infohash=%s",
                     req->remoteAddr.c_str(), req->remotePort,
                     req->transactionId, req->connectionId,
                     getUDPTrackerEventStr(req->event),
@@ -339,7 +342,7 @@ void UDPTrackerClient::requestFail(int error)
     break;
   case UDPT_ACT_ANNOUNCE:
     A2_LOG_INFO(fmt("UDPT fail ANNOUNCE to %s:%u transaction_id=%u, "
-                    "connection_id=%" PRId64 ", event=%s, infohash=%s",
+                    "connection_id=%" PRIu64 ", event=%s, infohash=%s",
                     req->remoteAddr.c_str(), req->remotePort,
                     req->transactionId, req->connectionId,
                     getUDPTrackerEventStr(req->event),
@@ -376,7 +379,7 @@ struct TimeoutCheck {
           break;
         case UDPT_ACT_ANNOUNCE:
           A2_LOG_INFO(fmt("UDPT resend ANNOUNCE to %s:%u transaction_id=%u, "
-                          "connection_id=%" PRId64 ", event=%s, infohash=%s",
+                          "connection_id=%" PRIu64 ", event=%s, infohash=%s",
                           req->remoteAddr.c_str(), req->remotePort,
                           req->transactionId, req->connectionId,
                           getUDPTrackerEventStr(req->event),
@@ -406,7 +409,7 @@ struct TimeoutCheck {
           break;
         case UDPT_ACT_ANNOUNCE:
           A2_LOG_INFO(fmt("UDPT timeout ANNOUNCE to %s:%u transaction_id=%u, "
-                          "connection_id=%" PRId64 ", event=%s, infohash=%s",
+                          "connection_id=%" PRIu64 ", event=%s, infohash=%s",
                           req->remoteAddr.c_str(), req->remotePort,
                           req->transactionId, req->connectionId,
                           getUDPTrackerEventStr(req->event),
@@ -450,7 +453,7 @@ void UDPTrackerClient::handleTimeout(const Timer& now)
 std::shared_ptr<UDPTrackerRequest>
 UDPTrackerClient::findInflightRequest(const std::string& remoteAddr,
                                       uint16_t remotePort,
-                                      int32_t transactionId, bool remove)
+                                      uint32_t transactionId, bool remove)
 {
   std::shared_ptr<UDPTrackerRequest> res;
   for (auto i = inflightRequests_.begin(), eoi = inflightRequests_.end();

+ 1 - 1
src/UDPTrackerClient.h

@@ -136,7 +136,7 @@ public:
 private:
   std::shared_ptr<UDPTrackerRequest>
   findInflightRequest(const std::string& remoteAddr, uint16_t remotePort,
-                      int32_t transactionId, bool remove);
+                      uint32_t transactionId, bool remove);
 
   UDPTrackerConnection* getConnectionId(const std::string& remoteAddr,
                                         uint16_t remotePort, const Timer& now);

+ 3 - 3
src/UDPTrackerRequest.h

@@ -71,7 +71,7 @@ enum UDPTrackerEvent {
 
 struct UDPTrackerReply {
   int32_t action;
-  int32_t transactionId;
+  uint32_t transactionId;
   int32_t interval;
   int32_t leechers;
   int32_t seeders;
@@ -82,9 +82,9 @@ struct UDPTrackerReply {
 struct UDPTrackerRequest {
   std::string remoteAddr;
   uint16_t remotePort;
-  int64_t connectionId;
+  uint64_t connectionId;
   int32_t action;
-  int32_t transactionId;
+  uint32_t transactionId;
   std::string infohash;
   std::string peerId;
   int64_t downloaded;

+ 20 - 25
test/UDPTrackerClientTest.cc

@@ -36,10 +36,10 @@ CPPUNIT_TEST_SUITE_REGISTRATION(UDPTrackerClientTest);
 namespace {
 std::shared_ptr<UDPTrackerRequest> createAnnounce(const std::string& remoteAddr,
                                                   uint16_t remotePort,
-                                                  int32_t transactionId)
+                                                  uint32_t transactionId)
 {
   std::shared_ptr<UDPTrackerRequest> req(new UDPTrackerRequest());
-  req->connectionId = INT64_MAX;
+  req->connectionId = std::numeric_limits<uint64_t>::max();
   req->action = UDPT_ACT_ANNOUNCE;
   req->remoteAddr = remoteAddr;
   req->remotePort = remotePort;
@@ -60,8 +60,8 @@ std::shared_ptr<UDPTrackerRequest> createAnnounce(const std::string& remoteAddr,
 } // namespace
 
 namespace {
-ssize_t createErrorReply(unsigned char* data, size_t len, int32_t transactionId,
-                         const std::string& errorString)
+ssize_t createErrorReply(unsigned char* data, size_t len,
+                         uint32_t transactionId, const std::string& errorString)
 {
   bittorrent::setIntParam(data, UDPT_ACT_ERROR);
   bittorrent::setIntParam(data + 4, transactionId);
@@ -72,7 +72,7 @@ ssize_t createErrorReply(unsigned char* data, size_t len, int32_t transactionId,
 
 namespace {
 ssize_t createConnectReply(unsigned char* data, size_t len,
-                           uint64_t connectionId, int32_t transactionId)
+                           uint64_t connectionId, uint32_t transactionId)
 {
   bittorrent::setIntParam(data, UDPT_ACT_CONNECT);
   bittorrent::setIntParam(data + 4, transactionId);
@@ -83,7 +83,7 @@ ssize_t createConnectReply(unsigned char* data, size_t len,
 
 namespace {
 ssize_t createAnnounceReply(unsigned char* data, size_t len,
-                            int32_t transactionId, int numPeers = 0)
+                            uint32_t transactionId, int numPeers = 0)
 {
   bittorrent::setIntParam(data, UDPT_ACT_ANNOUNCE);
   bittorrent::setIntParam(data + 4, transactionId);
@@ -116,8 +116,7 @@ void UDPTrackerClientTest::testCreateUDPTrackerConnect()
   CPPUNIT_ASSERT_EQUAL((int64_t)UDPT_INITIAL_CONNECTION_ID,
                        (int64_t)bittorrent::getLLIntParam(data, 0));
   CPPUNIT_ASSERT_EQUAL((int)req->action, (int)bittorrent::getIntParam(data, 8));
-  CPPUNIT_ASSERT_EQUAL(req->transactionId,
-                       (int32_t)bittorrent::getIntParam(data, 12));
+  CPPUNIT_ASSERT_EQUAL(req->transactionId, bittorrent::getIntParam(data, 12));
 }
 
 void UDPTrackerClientTest::testCreateUDPTrackerAnnounce()
@@ -130,11 +129,9 @@ void UDPTrackerClientTest::testCreateUDPTrackerAnnounce()
   ssize_t rv =
       createUDPTrackerAnnounce(data, sizeof(data), remoteAddr, remotePort, req);
   CPPUNIT_ASSERT_EQUAL((ssize_t)100, rv);
-  CPPUNIT_ASSERT_EQUAL(req->connectionId,
-                       (int64_t)bittorrent::getLLIntParam(data, 0));
+  CPPUNIT_ASSERT_EQUAL(req->connectionId, bittorrent::getLLIntParam(data, 0));
   CPPUNIT_ASSERT_EQUAL((int)req->action, (int)bittorrent::getIntParam(data, 8));
-  CPPUNIT_ASSERT_EQUAL(req->transactionId,
-                       (int32_t)bittorrent::getIntParam(data, 12));
+  CPPUNIT_ASSERT_EQUAL(req->transactionId, bittorrent::getIntParam(data, 12));
   CPPUNIT_ASSERT_EQUAL(req->infohash, std::string(&data[16], &data[36]));
   CPPUNIT_ASSERT_EQUAL(req->peerId, std::string(&data[36], &data[56]));
   CPPUNIT_ASSERT_EQUAL(req->downloaded,
@@ -177,7 +174,7 @@ void UDPTrackerClientTest::testConnectFollowedByAnnounce()
   CPPUNIT_ASSERT_EQUAL(req1->remotePort, remotePort);
   CPPUNIT_ASSERT_EQUAL((int64_t)UDPT_INITIAL_CONNECTION_ID,
                        (int64_t)bittorrent::getLLIntParam(data, 0));
-  int32_t transactionId = bittorrent::getIntParam(data, 12);
+  uint32_t transactionId = bittorrent::getIntParam(data, 12);
   rv = tr.createRequest(data, sizeof(data), remoteAddr, remotePort, now);
   // Duplicate CONNECT request was not inserted
   CPPUNIT_ASSERT_EQUAL((size_t)3, tr.getPendingRequests().size());
@@ -191,7 +188,7 @@ void UDPTrackerClientTest::testConnectFollowedByAnnounce()
   CPPUNIT_ASSERT_EQUAL((ssize_t)-1, rv);
   CPPUNIT_ASSERT(tr.getPendingRequests().empty());
 
-  int64_t connectionId = 12345;
+  uint64_t connectionId = 12345;
   rv = createConnectReply(data, sizeof(data), connectionId, transactionId);
   rv = tr.receiveReply(data, rv, req1->remoteAddr, req1->remotePort, now);
   CPPUNIT_ASSERT_EQUAL(0, (int)rv);
@@ -202,8 +199,7 @@ void UDPTrackerClientTest::testConnectFollowedByAnnounce()
   // Creates announce for req1
   CPPUNIT_ASSERT_EQUAL((ssize_t)100, rv);
   CPPUNIT_ASSERT_EQUAL((size_t)2, tr.getPendingRequests().size());
-  CPPUNIT_ASSERT_EQUAL(connectionId,
-                       (int64_t)bittorrent::getLLIntParam(data, 0));
+  CPPUNIT_ASSERT_EQUAL(connectionId, bittorrent::getLLIntParam(data, 0));
   CPPUNIT_ASSERT_EQUAL((int)UDPT_ACT_ANNOUNCE,
                        (int)bittorrent::getIntParam(data, 8));
   CPPUNIT_ASSERT_EQUAL(req1->infohash, std::string(&data[16], &data[36]));
@@ -212,18 +208,17 @@ void UDPTrackerClientTest::testConnectFollowedByAnnounce()
   // Don't duplicate same request data
   CPPUNIT_ASSERT_EQUAL((ssize_t)100, rv);
   CPPUNIT_ASSERT_EQUAL((size_t)2, tr.getPendingRequests().size());
-  int32_t transactionId1 = bittorrent::getIntParam(data, 12);
+  uint32_t transactionId1 = bittorrent::getIntParam(data, 12);
 
   tr.requestSent(now);
   CPPUNIT_ASSERT_EQUAL((size_t)1, tr.getPendingRequests().size());
 
   rv = tr.createRequest(data, sizeof(data), remoteAddr, remotePort, now);
-  int32_t transactionId2 = bittorrent::getIntParam(data, 12);
+  uint32_t transactionId2 = bittorrent::getIntParam(data, 12);
   // Creates announce for req2
   CPPUNIT_ASSERT_EQUAL((ssize_t)100, rv);
   CPPUNIT_ASSERT_EQUAL((size_t)1, tr.getPendingRequests().size());
-  CPPUNIT_ASSERT_EQUAL(connectionId,
-                       (int64_t)bittorrent::getLLIntParam(data, 0));
+  CPPUNIT_ASSERT_EQUAL(connectionId, bittorrent::getLLIntParam(data, 0));
   CPPUNIT_ASSERT_EQUAL((int)UDPT_ACT_ANNOUNCE,
                        (int)bittorrent::getIntParam(data, 8));
   CPPUNIT_ASSERT_EQUAL(req2->infohash, std::string(&data[16], &data[36]));
@@ -316,7 +311,7 @@ void UDPTrackerClientTest::testRequestFailure()
     rv = tr.createRequest(data, sizeof(data), remoteAddr, remotePort, now);
     CPPUNIT_ASSERT_EQUAL((int)UDPT_ACT_CONNECT,
                          (int)bittorrent::getIntParam(data, 8));
-    int32_t transactionId = bittorrent::getIntParam(data, 12);
+    uint32_t transactionId = bittorrent::getIntParam(data, 12);
     tr.requestSent(now);
 
     rv = createErrorReply(data, sizeof(data), transactionId, "error");
@@ -338,10 +333,10 @@ void UDPTrackerClientTest::testRequestFailure()
     CPPUNIT_ASSERT_EQUAL((ssize_t)16, rv);
     CPPUNIT_ASSERT_EQUAL((int)UDPT_ACT_CONNECT,
                          (int)bittorrent::getIntParam(data, 8));
-    int32_t transactionId = bittorrent::getIntParam(data, 12);
+    uint32_t transactionId = bittorrent::getIntParam(data, 12);
     tr.requestSent(now);
 
-    int64_t connectionId = 12345;
+    uint64_t connectionId = 12345;
     rv = createConnectReply(data, sizeof(data), connectionId, transactionId);
     rv = tr.receiveReply(data, rv, req1->remoteAddr, req1->remotePort, now);
     CPPUNIT_ASSERT_EQUAL(0, (int)rv);
@@ -414,10 +409,10 @@ void UDPTrackerClientTest::testTimeout()
     CPPUNIT_ASSERT_EQUAL((ssize_t)16, rv);
     CPPUNIT_ASSERT_EQUAL((int)UDPT_ACT_CONNECT,
                          (int)bittorrent::getIntParam(data, 8));
-    int32_t transactionId = bittorrent::getIntParam(data, 12);
+    uint32_t transactionId = bittorrent::getIntParam(data, 12);
     tr.requestSent(now);
 
-    int64_t connectionId = 12345;
+    uint64_t connectionId = 12345;
     rv = createConnectReply(data, sizeof(data), connectionId, transactionId);
     rv = tr.receiveReply(data, rv, req1->remoteAddr, req1->remotePort, now);
     CPPUNIT_ASSERT_EQUAL(0, (int)rv);