Преглед изворни кода

Use std::unique_ptr to receive BtMessage

Tatsuhiro Tsujikawa пре 12 година
родитељ
комит
c6a733378f

+ 5 - 3
src/BtInteractive.h

@@ -41,7 +41,8 @@
 
 namespace aria2 {
 
-class BtMessage;
+
+class BtHandshakeMessage;
 
 class BtInteractive {
 public:
@@ -49,9 +50,10 @@ public:
 
   virtual void initiateHandshake() = 0;
 
-  virtual std::shared_ptr<BtMessage> receiveHandshake(bool quickReply = false) = 0;
+  virtual std::unique_ptr<BtHandshakeMessage> receiveHandshake
+  (bool quickReply = false) = 0;
 
-  virtual std::shared_ptr<BtMessage> receiveAndSendHandshake() = 0;
+  virtual std::unique_ptr<BtHandshakeMessage> receiveAndSendHandshake() = 0;
 
   virtual void doPostHandshakeProcessing() = 0;
 

+ 4 - 3
src/BtMessageReceiver.h

@@ -48,11 +48,12 @@ class BtMessageReceiver {
 public:
   virtual ~BtMessageReceiver() {}
 
-  virtual std::shared_ptr<BtHandshakeMessage> receiveHandshake(bool quickReply = false) = 0;
+  virtual std::unique_ptr<BtHandshakeMessage> receiveHandshake
+  (bool quickReply = false) = 0;
 
-  virtual std::shared_ptr<BtHandshakeMessage> receiveAndSendHandshake() = 0;
+  virtual std::unique_ptr<BtHandshakeMessage> receiveAndSendHandshake() = 0;
 
-  virtual std::shared_ptr<BtMessage> receiveMessage() = 0;
+  virtual std::unique_ptr<BtMessage> receiveMessage() = 0;
 };
 
 } // namespace aria2

+ 10 - 10
src/DefaultBtInteractive.cc

@@ -119,9 +119,9 @@ void DefaultBtInteractive::initiateHandshake() {
   dispatcher_->sendMessages();
 }
 
-std::shared_ptr<BtMessage> DefaultBtInteractive::receiveHandshake(bool quickReply) {
-  std::shared_ptr<BtHandshakeMessage> message =
-    btMessageReceiver_->receiveHandshake(quickReply);
+std::unique_ptr<BtHandshakeMessage> DefaultBtInteractive::receiveHandshake
+(bool quickReply) {
+  auto message = btMessageReceiver_->receiveHandshake(quickReply);
   if(!message) {
     return nullptr;
   }
@@ -131,11 +131,9 @@ std::shared_ptr<BtMessage> DefaultBtInteractive::receiveHandshake(bool quickRepl
       (fmt("CUID#%" PRId64 " - Drop connection from the same Peer ID",
            cuid_));
   }
-  const PeerSet& usedPeers = peerStorage_->getUsedPeers();
-  for(PeerSet::const_iterator i = usedPeers.begin(), eoi = usedPeers.end();
-      i != eoi; ++i) {
-    if((*i)->isActive() &&
-       memcmp((*i)->getPeerId(), message->getPeerId(), PEER_ID_LENGTH) == 0) {
+  for(auto& peer : peerStorage_->getUsedPeers()) {
+    if(peer->isActive() &&
+       memcmp(peer->getPeerId(), message->getPeerId(), PEER_ID_LENGTH) == 0) {
       throw DL_ABORT_EX
         (fmt("CUID#%" PRId64 " - Same Peer ID has been already seen.",
              cuid_));
@@ -166,7 +164,9 @@ std::shared_ptr<BtMessage> DefaultBtInteractive::receiveHandshake(bool quickRepl
   return message;
 }
 
-std::shared_ptr<BtMessage> DefaultBtInteractive::receiveAndSendHandshake() {
+std::unique_ptr<BtHandshakeMessage>
+DefaultBtInteractive::receiveAndSendHandshake()
+{
   return receiveHandshake(true);
 }
 
@@ -297,7 +297,7 @@ size_t DefaultBtInteractive::receiveMessages() {
        downloadContext_->getOwnerRequestGroup()->doesDownloadSpeedExceed()) {
       break;
     }
-    std::shared_ptr<BtMessage> message = btMessageReceiver_->receiveMessage();
+    auto message = btMessageReceiver_->receiveMessage();
     if(!message) {
       break;
     }

+ 3 - 2
src/DefaultBtInteractive.h

@@ -173,9 +173,10 @@ public:
 
   virtual void initiateHandshake();
 
-  virtual std::shared_ptr<BtMessage> receiveHandshake(bool quickReply = false);
+  virtual std::unique_ptr<BtHandshakeMessage> receiveHandshake
+  (bool quickReply = false);
 
-  virtual std::shared_ptr<BtMessage> receiveAndSendHandshake();
+  virtual std::unique_ptr<BtHandshakeMessage> receiveAndSendHandshake();
 
   virtual void doPostHandshakeProcessing();
 

+ 23 - 21
src/DefaultBtMessageReceiver.cc

@@ -54,14 +54,14 @@
 namespace aria2 {
 
 DefaultBtMessageReceiver::DefaultBtMessageReceiver():
-  handshakeSent_(false),
-  downloadContext_{0},
-  peerConnection_(0),
-  dispatcher_(0),
-  messageFactory_(0)
+  handshakeSent_{false},
+  downloadContext_{nullptr},
+  peerConnection_{nullptr},
+  dispatcher_{nullptr},
+  messageFactory_{nullptr}
 {}
 
-std::shared_ptr<BtHandshakeMessage>
+std::unique_ptr<BtHandshakeMessage>
 DefaultBtMessageReceiver::receiveHandshake(bool quickReply)
 {
   A2_LOG_DEBUG
@@ -69,15 +69,14 @@ DefaultBtMessageReceiver::receiveHandshake(bool quickReply)
          static_cast<unsigned long>(peerConnection_->getBufferLength())));
   unsigned char data[BtHandshakeMessage::MESSAGE_LENGTH];
   size_t dataLength = BtHandshakeMessage::MESSAGE_LENGTH;
-  std::shared_ptr<BtHandshakeMessage> msg;
   if(handshakeSent_ || !quickReply || peerConnection_->getBufferLength() < 48) {
     if(peerConnection_->receiveHandshake(data, dataLength)) {
-      msg = messageFactory_->createHandshakeMessage(data, dataLength);
+      auto msg = messageFactory_->createHandshakeMessage(data, dataLength);
       msg->validate();
+      return msg;
     }
-  }
-  // Handle tracker's NAT-checking feature
-  if(!handshakeSent_ && quickReply && peerConnection_->getBufferLength() >= 48){
+  } else {
+    // Handle tracker's NAT-checking feature
     handshakeSent_ = true;
     // check info_hash
     if(memcmp(bittorrent::getInfoHash(downloadContext_),
@@ -90,24 +89,25 @@ DefaultBtMessageReceiver::receiveHandshake(bool quickReply)
              util::toHex(peerConnection_->getBuffer()+28,
                          INFO_HASH_LENGTH).c_str()));
     }
-    if(!msg &&
-       peerConnection_->getBufferLength() ==
+    if(peerConnection_->getBufferLength() ==
        BtHandshakeMessage::MESSAGE_LENGTH &&
        peerConnection_->receiveHandshake(data, dataLength)) {
-      msg = messageFactory_->createHandshakeMessage(data, dataLength);
+      auto msg = messageFactory_->createHandshakeMessage(data, dataLength);
       msg->validate();
+      return msg;
     }
   }
-  return msg;
+  return nullptr;
 }
 
-std::shared_ptr<BtHandshakeMessage>
+std::unique_ptr<BtHandshakeMessage>
 DefaultBtMessageReceiver::receiveAndSendHandshake()
 {
   return receiveHandshake(true);
 }
 
-void DefaultBtMessageReceiver::sendHandshake() {
+void DefaultBtMessageReceiver::sendHandshake()
+{
   dispatcher_->addMessageToQueue
     (messageFactory_->createHandshakeMessage
      (bittorrent::getInfoHash(downloadContext_),
@@ -115,18 +115,19 @@ void DefaultBtMessageReceiver::sendHandshake() {
   dispatcher_->sendMessages();
 }
 
-std::shared_ptr<BtMessage> DefaultBtMessageReceiver::receiveMessage() {
+std::unique_ptr<BtMessage> DefaultBtMessageReceiver::receiveMessage()
+{
   size_t dataLength = 0;
   // Give 0 to PeerConnection::receiveMessage() to prevent memcpy.
   if(!peerConnection_->receiveMessage(0, dataLength)) {
     return nullptr;
   }
-  std::shared_ptr<BtMessage> msg =
+  auto msg =
     messageFactory_->createBtMessage(peerConnection_->getMsgPayloadBuffer(),
                                      dataLength);
   msg->validate();
   if(msg->getId() == BtPieceMessage::ID) {
-    auto piecemsg = std::static_pointer_cast<BtPieceMessage>(msg);
+    auto piecemsg = static_cast<BtPieceMessage*>(msg.get());
     piecemsg->setMsgPayload(peerConnection_->getMsgPayloadBuffer());
   }
   return msg;
@@ -138,7 +139,8 @@ void DefaultBtMessageReceiver::setDownloadContext
   downloadContext_ = downloadContext;
 }
 
-void DefaultBtMessageReceiver::setPeerConnection(PeerConnection* peerConnection)
+void DefaultBtMessageReceiver::setPeerConnection
+(PeerConnection* peerConnection)
 {
   peerConnection_ = peerConnection;
 }

+ 3 - 3
src/DefaultBtMessageReceiver.h

@@ -58,12 +58,12 @@ private:
 public:
   DefaultBtMessageReceiver();
 
-  virtual std::shared_ptr<BtHandshakeMessage> receiveHandshake
+  virtual std::unique_ptr<BtHandshakeMessage> receiveHandshake
   (bool quickReply = false);
 
-  virtual std::shared_ptr<BtHandshakeMessage> receiveAndSendHandshake();
+  virtual std::unique_ptr<BtHandshakeMessage> receiveAndSendHandshake();
 
-  virtual std::shared_ptr<BtMessage> receiveMessage();
+  virtual std::unique_ptr<BtMessage> receiveMessage();
 
   void setDownloadContext(DownloadContext* downloadContext);
 

+ 3 - 4
src/PeerInteractionCommand.cc

@@ -47,6 +47,7 @@
 #include "DownloadContext.h"
 #include "Peer.h"
 #include "BtMessage.h"
+#include "BtHandshakeMessage.h"
 #include "BtRuntime.h"
 #include "PeerStorage.h"
 #include "DefaultBtMessageDispatcher.h"
@@ -327,8 +328,7 @@ bool PeerInteractionCommand::executeInternal() {
           break;
         }
       }
-      std::shared_ptr<BtMessage> handshakeMessage =
-        btInteractive_->receiveHandshake();
+      auto handshakeMessage = btInteractive_->receiveHandshake();
       if(!handshakeMessage) {
         done = true;
         break;
@@ -338,8 +338,7 @@ bool PeerInteractionCommand::executeInternal() {
       break;
     }
     case RECEIVER_WAIT_HANDSHAKE: {
-      std::shared_ptr<BtMessage> handshakeMessage =
-        btInteractive_->receiveAndSendHandshake();
+      auto handshakeMessage = btInteractive_->receiveAndSendHandshake();
       if(!handshakeMessage) {
         done = true;
         break;