Bläddra i källkod

Rewritten NAT check handling.

We simplified PeerConnection::receiveHandshake().
DefaultBtMessageReceiver and PeerReceiverHandshakeCommand look
PeerConnection's buffer and do NAT check handling themselves.
Tatsuhiro Tsujikawa 14 år sedan
förälder
incheckning
3e67079087
3 ändrade filer med 57 tillägg och 47 borttagningar
  1. 29 10
      src/DefaultBtMessageReceiver.cc
  2. 16 30
      src/PeerConnection.cc
  3. 12 7
      src/PeerReceiveHandshakeCommand.cc

+ 29 - 10
src/DefaultBtMessageReceiver.cc

@@ -47,6 +47,9 @@
 #include "LogFactory.h"
 #include "bittorrent_helper.h"
 #include "BtPieceMessage.h"
+#include "util.h"
+#include "fmt.h"
+#include "DlAbortEx.h"
 
 namespace aria2 {
 
@@ -60,24 +63,40 @@ DefaultBtMessageReceiver::DefaultBtMessageReceiver():
 SharedHandle<BtHandshakeMessage>
 DefaultBtMessageReceiver::receiveHandshake(bool quickReply)
 {
+  A2_LOG_DEBUG
+    (fmt("Receiving handshake bufferLength=%lu",
+         static_cast<unsigned long>(peerConnection_->getBufferLength())));
   unsigned char data[BtHandshakeMessage::MESSAGE_LENGTH];
   size_t dataLength = BtHandshakeMessage::MESSAGE_LENGTH;
-  bool retval = peerConnection_->receiveHandshake(data, dataLength);
-  // To handle tracker's NAT-checking feature
-  if(!handshakeSent_ && quickReply && dataLength >= 48) {
+  SharedHandle<BtHandshakeMessage> msg;
+  if(handshakeSent_ || !quickReply || peerConnection_->getBufferLength() < 48) {
+    if(peerConnection_->receiveHandshake(data, dataLength)) {
+      msg = messageFactory_->createHandshakeMessage(data, dataLength);
+      msg->validate();
+    }
+  }
+  // Handle tracker's NAT-checking feature
+  if(!handshakeSent_ && quickReply && peerConnection_->getBufferLength() >= 48){
     handshakeSent_ = true;
     // check info_hash
-    if(memcmp(bittorrent::getInfoHash(downloadContext_), &data[28],
+    if(memcmp(bittorrent::getInfoHash(downloadContext_),
+              peerConnection_->getBuffer()+28,
               INFO_HASH_LENGTH) == 0) {
       sendHandshake();
+    } else {
+      throw DL_ABORT_EX
+        (fmt("Bad Info Hash %s",
+             util::toHex(peerConnection_->getBuffer()+28,
+                         INFO_HASH_LENGTH).c_str()));
+    }
+    if(!msg &&
+       peerConnection_->getBufferLength() ==
+       BtHandshakeMessage::MESSAGE_LENGTH &&
+       peerConnection_->receiveHandshake(data, dataLength)) {
+      msg = messageFactory_->createHandshakeMessage(data, dataLength);
+      msg->validate();
     }
   }
-  if(!retval) {
-    return SharedHandle<BtHandshakeMessage>();
-  }
-  SharedHandle<BtHandshakeMessage> msg =
-    messageFactory_->createHandshakeMessage(data, dataLength);
-  msg->validate();
   return msg;
 }
 

+ 16 - 30
src/PeerConnection.cc

@@ -182,36 +182,22 @@ bool PeerConnection::receiveHandshake(unsigned char* data, size_t& dataLength,
       ("More than BtHandshakeMessage::MESSAGE_LENGTH bytes are buffered.");
   }
   bool retval = true;
-  if(((!prevPeek_ && peek) || (prevPeek_ && !peek)) && resbufLength_) {
-    // We have data in previous peek.
-    // There is a chance that socket is readable because of EOF, for example,
-    // official bttrack shutdowns socket after sending first 48 bytes of
-    // handshake in its NAT checking.
-    // So if there are data in resbuf_, return it without checking socket
-    // status.
-    //
-    // (!prevPeek_ && peek) effectively returns preset buffer.
-    prevPeek_ = false;
-    retval = BtHandshakeMessage::MESSAGE_LENGTH <= resbufLength_;
-  } else {
-    prevPeek_ = peek;
-    size_t remaining = BtHandshakeMessage::MESSAGE_LENGTH-resbufLength_;
-    if(remaining > 0) {
-      size_t temp = remaining;
-      readData(resbuf_+resbufLength_, remaining, encryptionEnabled_);
-      if(remaining == 0 && !socket_->wantRead() && !socket_->wantWrite()) {
-        // we got EOF
-        A2_LOG_DEBUG
-          (fmt("CUID#%lld - In PeerConnection::receiveHandshake(), remain=%lu",
-               cuid_,
-               static_cast<unsigned long>(temp)));
-        peer_->setDisconnectedGracefully(true);
-        throw DL_ABORT_EX(EX_EOF_FROM_PEER);
-      }
-      resbufLength_ += remaining;
-      if(BtHandshakeMessage::MESSAGE_LENGTH > resbufLength_) {
-        retval = false;
-      }
+  size_t remaining = BtHandshakeMessage::MESSAGE_LENGTH-resbufLength_;
+  if(remaining > 0) {
+    size_t temp = remaining;
+    readData(resbuf_+resbufLength_, remaining, encryptionEnabled_);
+    if(remaining == 0 && !socket_->wantRead() && !socket_->wantWrite()) {
+      // we got EOF
+      A2_LOG_DEBUG
+        (fmt("CUID#%lld - In PeerConnection::receiveHandshake(), remain=%lu",
+             cuid_,
+             static_cast<unsigned long>(temp)));
+      peer_->setDisconnectedGracefully(true);
+      throw DL_ABORT_EX(EX_EOF_FROM_PEER);
+    }
+    resbufLength_ += remaining;
+    if(BtHandshakeMessage::MESSAGE_LENGTH > resbufLength_) {
+      retval = false;
     }
   }
   size_t writeLength = std::min(resbufLength_, dataLength);

+ 12 - 7
src/PeerReceiveHandshakeCommand.cc

@@ -33,6 +33,9 @@
  */
 /* copyright --> */
 #include "PeerReceiveHandshakeCommand.h"
+
+#include <cstring>
+
 #include "PeerConnection.h"
 #include "DownloadEngine.h"
 #include "BtHandshakeMessage.h"
@@ -87,13 +90,15 @@ bool PeerReceiveHandshakeCommand::exitBeforeExecute()
 
 bool PeerReceiveHandshakeCommand::executeInternal()
 {
-  unsigned char data[BtHandshakeMessage::MESSAGE_LENGTH];
-  size_t dataLength = BtHandshakeMessage::MESSAGE_LENGTH;
-  // ignore return value. The received data is kept in PeerConnection object
-  // because of peek = true.
-  peerConnection_->receiveHandshake(data, dataLength, true);
-  // To handle tracker's NAT-checking feature
-  if(dataLength >= 48) {
+  // Handle tracker's NAT-checking feature
+  if(peerConnection_->getBufferLength() < 48) {
+    size_t dataLength = 0;
+    // Ignore return value. The received data is kept in
+    // PeerConnection object because of peek = true.
+    peerConnection_->receiveHandshake(0, dataLength, true);
+  }
+  if(peerConnection_->getBufferLength() >= 48) {
+    const unsigned char* data = peerConnection_->getBuffer();
     // check info_hash
     std::string infoHash = std::string(&data[28], &data[28+INFO_HASH_LENGTH]);