Browse Source

Return ExtensionMessage subclass create return raw pointer

Tatsuhiro Tsujikawa 13 years ago
parent
commit
4b94ede268

+ 10 - 11
src/DefaultExtensionMessageFactory.cc

@@ -78,11 +78,11 @@ DefaultExtensionMessageFactory::createMessage(const unsigned char* data, size_t
   uint8_t extensionMessageID = *data;
   if(extensionMessageID == 0) {
     // handshake
-    SharedHandle<HandshakeExtensionMessage> m =
+    HandshakeExtensionMessage* m =
       HandshakeExtensionMessage::create(data, length);
     m->setPeer(peer_);
     m->setDownloadContext(dctx_);
-    return m;
+    return SharedHandle<ExtensionMessage>(m);
   } else {
     const char* extensionName = registry_->getExtensionName(extensionMessageID);
     if(!extensionName) {
@@ -92,10 +92,9 @@ DefaultExtensionMessageFactory::createMessage(const unsigned char* data, size_t
     }
     if(strcmp(extensionName, "ut_pex") == 0) {
       // uTorrent compatible Peer-Exchange
-      SharedHandle<UTPexExtensionMessage> m =
-        UTPexExtensionMessage::create(data, length);
+      UTPexExtensionMessage* m = UTPexExtensionMessage::create(data, length);
       m->setPeerStorage(peerStorage_);
-      return m;
+      return SharedHandle<ExtensionMessage>(m);
     } else if(strcmp(extensionName, "ut_metadata") == 0) {
       if(length == 0) {
         throw DL_ABORT_EX
@@ -120,14 +119,14 @@ DefaultExtensionMessageFactory::createMessage(const unsigned char* data, size_t
       }
       switch(msgType->i()) {
       case 0: {
-        SharedHandle<UTMetadataRequestExtensionMessage> m
+        UTMetadataRequestExtensionMessage* m
           (new UTMetadataRequestExtensionMessage(extensionMessageID));
         m->setIndex(index->i());
         m->setDownloadContext(dctx_);
         m->setPeer(peer_);
         m->setBtMessageFactory(messageFactory_);
         m->setBtMessageDispatcher(dispatcher_);
-        return m;
+        return SharedHandle<ExtensionMessage>(m);
       }
       case 1: {
         if(end == length) {
@@ -137,7 +136,7 @@ DefaultExtensionMessageFactory::createMessage(const unsigned char* data, size_t
         if(!totalSize) {
           throw DL_ABORT_EX("Bad ut_metadata data: total_size not found");
         }
-        SharedHandle<UTMetadataDataExtensionMessage> m
+        UTMetadataDataExtensionMessage* m
           (new UTMetadataDataExtensionMessage(extensionMessageID));
         m->setIndex(index->i());
         m->setTotalSize(totalSize->i());
@@ -145,14 +144,14 @@ DefaultExtensionMessageFactory::createMessage(const unsigned char* data, size_t
         m->setUTMetadataRequestTracker(tracker_);
         m->setPieceStorage(dctx_->getOwnerRequestGroup()->getPieceStorage());
         m->setDownloadContext(dctx_);
-        return m;
+        return SharedHandle<ExtensionMessage>(m);
       }
       case 2: {
-        SharedHandle<UTMetadataRejectExtensionMessage> m
+        UTMetadataRejectExtensionMessage* m
           (new UTMetadataRejectExtensionMessage(extensionMessageID));
         m->setIndex(index->i());
         // No need to inject tracker because peer will be disconnected.
-        return m;
+        return SharedHandle<ExtensionMessage>(m);
       }
       default:
         throw DL_ABORT_EX

+ 2 - 2
src/HandshakeExtensionMessage.cc

@@ -164,7 +164,7 @@ uint8_t HandshakeExtensionMessage::getExtensionMessageID(int key) const
   return extreg_.getExtensionMessageID(key);
 }
 
-SharedHandle<HandshakeExtensionMessage>
+HandshakeExtensionMessage*
 HandshakeExtensionMessage::create(const unsigned char* data, size_t length)
 {
   if(length < 1) {
@@ -172,7 +172,6 @@ HandshakeExtensionMessage::create(const unsigned char* data, size_t length)
       (fmt(MSG_TOO_SMALL_PAYLOAD_SIZE,
            EXTENSION_NAME, static_cast<unsigned long>(length)));
   }
-  SharedHandle<HandshakeExtensionMessage> msg(new HandshakeExtensionMessage());
   A2_LOG_DEBUG(fmt("Creating HandshakeExtensionMessage from %s",
                    util::percentEncode(data, length).c_str()));
   SharedHandle<ValueBase> decoded = bencode2::decode(data+1, length - 1);
@@ -181,6 +180,7 @@ HandshakeExtensionMessage::create(const unsigned char* data, size_t length)
     throw DL_ABORT_EX
       ("Unexpected payload format for extended message handshake");
   }
+  HandshakeExtensionMessage* msg(new HandshakeExtensionMessage());
   const Integer* port = downcast<Integer>(dict->get("p"));
   if(port && 0 < port->i() && port->i() < 65536) {
     msg->tcpPort_ = port->i();

+ 1 - 1
src/HandshakeExtensionMessage.h

@@ -125,7 +125,7 @@ public:
 
   void setPeer(const SharedHandle<Peer>& peer);
 
-  static SharedHandle<HandshakeExtensionMessage>
+  static HandshakeExtensionMessage*
   create(const unsigned char* data, size_t dataLength);
 };
 

+ 2 - 2
src/UTPexExtensionMessage.cc

@@ -181,7 +181,7 @@ void UTPexExtensionMessage::setPeerStorage
   peerStorage_ = peerStorage;
 }
 
-SharedHandle<UTPexExtensionMessage>
+UTPexExtensionMessage*
 UTPexExtensionMessage::create(const unsigned char* data, size_t len)
 {
   if(len < 1) {
@@ -189,7 +189,7 @@ UTPexExtensionMessage::create(const unsigned char* data, size_t len)
                           EXTENSION_NAME,
                           static_cast<unsigned long>(len)));
   }
-  SharedHandle<UTPexExtensionMessage> msg(new UTPexExtensionMessage(*data));
+  UTPexExtensionMessage* msg(new UTPexExtensionMessage(*data));
 
   SharedHandle<ValueBase> decoded = bencode2::decode(data+1, len - 1);
   const Dict* dict = downcast<Dict>(decoded);

+ 1 - 1
src/UTPexExtensionMessage.h

@@ -111,7 +111,7 @@ public:
 
   void setPeerStorage(const SharedHandle<PeerStorage>& peerStorage);
 
-  static SharedHandle<UTPexExtensionMessage>
+  static UTPexExtensionMessage*
   create(const unsigned char* data, size_t len);
 
   void setMaxFreshPeer(size_t maxFreshPeer);

+ 10 - 9
test/HandshakeExtensionMessageTest.cc

@@ -136,11 +136,12 @@ void HandshakeExtensionMessageTest::testDoReceivedAction()
 
 void HandshakeExtensionMessageTest::testCreate()
 {
-  std::string in = 
+  std::string in =
     "0d1:pi6881e1:v5:aria21:md5:a2dhti2e6:ut_pexi1ee13:metadata_sizei1024ee";
-  SharedHandle<HandshakeExtensionMessage> m =
-    HandshakeExtensionMessage::create(reinterpret_cast<const unsigned char*>(in.c_str()),
-                                      in.size());
+  SharedHandle<HandshakeExtensionMessage> m
+    (HandshakeExtensionMessage::create
+     (reinterpret_cast<const unsigned char*>(in.c_str()),
+      in.size()));
   CPPUNIT_ASSERT_EQUAL(std::string("aria2"), m->getClientVersion());
   CPPUNIT_ASSERT_EQUAL((uint16_t)6881, m->getTCPPort());
   CPPUNIT_ASSERT_EQUAL((uint8_t)1,
@@ -173,16 +174,16 @@ void HandshakeExtensionMessageTest::testCreate()
     CPPUNIT_FAIL("exception must be thrown.");
   } catch(Exception& e) {
     std::cerr << e.stackTrace() << std::endl;
-  }    
+  }
 }
 
 void HandshakeExtensionMessageTest::testCreate_stringnum()
 {
   std::string in = "0d1:p4:68811:v5:aria21:md6:ut_pex1:1ee";
-  SharedHandle<HandshakeExtensionMessage> m =
-    HandshakeExtensionMessage::create
-    (reinterpret_cast<const unsigned char*>(in.c_str()),
-     in.size());
+  SharedHandle<HandshakeExtensionMessage> m
+    (HandshakeExtensionMessage::create
+     (reinterpret_cast<const unsigned char*>(in.c_str()),
+      in.size()));
   CPPUNIT_ASSERT_EQUAL(std::string("aria2"), m->getClientVersion());
   // port number in string is not allowed
   CPPUNIT_ASSERT_EQUAL((uint16_t)0, m->getTCPPort());

+ 5 - 5
test/UTPexExtensionMessageTest.cc

@@ -202,10 +202,10 @@ void UTPexExtensionMessageTest::testCreate()
     "7:dropped12:"+std::string(&c3[0], &c3[6])+std::string(&c4[0], &c4[6])+
     "8:dropped618:"+std::string(&c6[0], &c6[COMPACT_LEN_IPV6])+
     "e";
-  
-  SharedHandle<UTPexExtensionMessage> msg =
-    UTPexExtensionMessage::create
-    (reinterpret_cast<const unsigned char*>(data.c_str()), data.size());
+
+  SharedHandle<UTPexExtensionMessage> msg
+    (UTPexExtensionMessage::create
+     (reinterpret_cast<const unsigned char*>(data.c_str()), data.size()));
   CPPUNIT_ASSERT_EQUAL((uint8_t)1, msg->getExtensionMessageID());
   CPPUNIT_ASSERT_EQUAL((size_t)3, msg->getFreshPeers().size());
   CPPUNIT_ASSERT_EQUAL(std::string("192.168.0.1"),
@@ -238,7 +238,7 @@ void UTPexExtensionMessageTest::testCreate()
     CPPUNIT_FAIL("exception must be thrown.");
   } catch(Exception& e) {
     std::cerr << e.stackTrace() << std::endl;
-  }    
+  }
 }
 
 void UTPexExtensionMessageTest::testAddFreshPeer()