瀏覽代碼

Rewritten ExtensionMessageRegistry

Tatsuhiro Tsujikawa 13 年之前
父節點
當前提交
c13dc166de

+ 1 - 4
src/BtConstants.h

@@ -36,10 +36,7 @@
 #define D_BT_CONSTANTS_H
 
 #include "common.h"
-#include <map>
-#include <string>
-
-typedef std::map<std::string, uint8_t> Extensions;
+#include <vector>
 
 #define INFO_HASH_LENGTH 20
 

+ 7 - 4
src/DefaultBtInteractive.cc

@@ -140,7 +140,8 @@ BtMessageHandle DefaultBtInteractive::receiveHandshake(bool quickReply) {
   if(message->isExtendedMessagingEnabled()) {
     peer_->setExtendedMessagingEnabled(true);
     if(!utPexEnabled_) {
-      extensionMessageRegistry_->removeExtension("ut_pex");
+      extensionMessageRegistry_->removeExtension
+        (ExtensionMessageRegistry::UT_PEX);
     }
     A2_LOG_INFO(fmt(MSG_EXTENDED_MESSAGING_ENABLED, cuid_));
   }
@@ -472,7 +473,8 @@ void DefaultBtInteractive::addPeerExchangeMessage()
   if(pexTimer_.
      difference(global::wallclock()) >= UTPexExtensionMessage::DEFAULT_INTERVAL) {
     UTPexExtensionMessageHandle m
-      (new UTPexExtensionMessage(peer_->getExtensionMessageID("ut_pex")));
+      (new UTPexExtensionMessage(peer_->getExtensionMessageID
+                                 (ExtensionMessageRegistry::UT_PEX)));
 
     std::vector<SharedHandle<Peer> > activePeers;
     peerStorage_->getActivePeers(activePeers);
@@ -508,7 +510,7 @@ void DefaultBtInteractive::doInteractionProcessing() {
     // HandshakeExtensionMessage::doReceivedAction().
     pieceStorage_ =
       downloadContext_->getOwnerRequestGroup()->getPieceStorage();
-    if(peer_->getExtensionMessageID("ut_metadata") &&
+    if(peer_->getExtensionMessageID(ExtensionMessageRegistry::UT_METADATA) &&
        downloadContext_->getTotalLength() > 0) {
       size_t num = utMetadataRequestTracker_->avail();
       if(num > 0) {
@@ -549,7 +551,8 @@ void DefaultBtInteractive::doInteractionProcessing() {
       addRequests();
     }
   }
-  if(peer_->getExtensionMessageID("ut_pex") && utPexEnabled_) {
+  if(peer_->getExtensionMessageID(ExtensionMessageRegistry::UT_PEX) &&
+     utPexEnabled_) {
     addPeerExchangeMessage();
   }
 

+ 8 - 5
src/DefaultExtensionMessageFactory.cc

@@ -33,6 +33,9 @@
  */
 /* copyright --> */
 #include "DefaultExtensionMessageFactory.h"
+
+#include <cstring>
+
 #include "Peer.h"
 #include "DlAbortEx.h"
 #include "HandshakeExtensionMessage.h"
@@ -81,19 +84,19 @@ DefaultExtensionMessageFactory::createMessage(const unsigned char* data, size_t
     m->setDownloadContext(dctx_);
     return m;
   } else {
-    std::string extensionName = registry_->getExtensionName(extensionMessageID);
-    if(extensionName.empty()) {
+    const char* extensionName = registry_->getExtensionName(extensionMessageID);
+    if(!extensionName) {
       throw DL_ABORT_EX
         (fmt("No extension registered for extended message ID %u",
              extensionMessageID));
     }
-    if(extensionName == "ut_pex") {
+    if(strcmp(extensionName, "ut_pex") == 0) {
       // uTorrent compatible Peer-Exchange
       UTPexExtensionMessageHandle m =
         UTPexExtensionMessage::create(data, length);
       m->setPeerStorage(peerStorage_);
       return m;
-    } else if(extensionName == "ut_metadata") {
+    } else if(strcmp(extensionName, "ut_metadata") == 0) {
       if(length == 0) {
         throw DL_ABORT_EX
           (fmt(MSG_TOO_SMALL_PAYLOAD_SIZE,
@@ -160,7 +163,7 @@ DefaultExtensionMessageFactory::createMessage(const unsigned char* data, size_t
       throw DL_ABORT_EX
         (fmt("Unsupported extension message received."
              " extensionMessageID=%u, extensionName=%s",
-             extensionMessageID, extensionName.c_str()));
+             extensionMessageID, extensionName));
     }
   }
 }

+ 59 - 24
src/ExtensionMessageRegistry.cc

@@ -2,7 +2,7 @@
 /*
  * aria2 - The high speed download utility
  *
- * Copyright (C) 2010 Tatsuhiro Tsujikawa
+ * Copyright (C) 2012 Tatsuhiro Tsujikawa
  *
  * This program is free software; you can redistribute it and/or modify
  * it under the terms of the GNU General Public License as published by
@@ -33,46 +33,81 @@
  */
 /* copyright --> */
 #include "ExtensionMessageRegistry.h"
-#include "BtConstants.h"
-#include "A2STR.h"
+
+#include <cstring>
+#include <cassert>
 
 namespace aria2 {
 
 ExtensionMessageRegistry::ExtensionMessageRegistry()
-{
-  extensions_["ut_pex"] = 8;
-  // http://www.bittorrent.org/beps/bep_0009.html
-  extensions_["ut_metadata"] = 9;
-}
+  : extensions_(MAX_EXTENSION)
+{}
 
 ExtensionMessageRegistry::~ExtensionMessageRegistry() {}
 
-uint8_t ExtensionMessageRegistry::getExtensionMessageID
-(const std::string& name) const
+namespace {
+const char* EXTENSION_NAMES[] = {
+  "ut_metadata",
+  "ut_pex",
+  0
+};
+} // namespace
+
+uint8_t ExtensionMessageRegistry::getExtensionMessageID(int key) const
+{
+  assert(key < MAX_EXTENSION);
+  return extensions_[key];
+}
+
+const char* ExtensionMessageRegistry::getExtensionName(uint8_t id) const
 {
-  Extensions::const_iterator itr = extensions_.find(name);
-  if(itr == extensions_.end()) {
+  int i;
+  if(id == 0) {
     return 0;
-  } else {
-    return (*itr).second;
   }
+  for(i = 0; i < MAX_EXTENSION; ++i) {
+    if(extensions_[i] == id) {
+      break;
+    }
+  }
+  return EXTENSION_NAMES[i];
 }
 
-const std::string& ExtensionMessageRegistry::getExtensionName(uint8_t id) const
+void ExtensionMessageRegistry::setExtensionMessageID(int key, uint8_t id)
 {
-  for(Extensions::const_iterator itr = extensions_.begin(),
-        eoi = extensions_.end(); itr != eoi; ++itr) {
-    const Extensions::value_type& p = *itr;
-    if(p.second == id) {
-      return p.first;
-    }
+  assert(key < MAX_EXTENSION);
+  extensions_[key] = id;
+}
+
+void ExtensionMessageRegistry::removeExtension(int key)
+{
+  assert(key < MAX_EXTENSION);
+  extensions_[key] = 0;
+}
+
+void ExtensionMessageRegistry::setExtensions(const Extensions& extensions)
+{
+  extensions_ = extensions;
+}
+
+const char* strBtExtension(int key)
+{
+  if(key >= ExtensionMessageRegistry::MAX_EXTENSION) {
+    return 0;
+  } else {
+    return EXTENSION_NAMES[key];
   }
-  return A2STR::NIL;
 }
 
-void ExtensionMessageRegistry::removeExtension(const std::string& name)
+int keyBtExtension(const char* name)
 {
-  extensions_.erase(name);
+  int i;
+  for(i = 0; i < ExtensionMessageRegistry::MAX_EXTENSION; ++i) {
+    if(strcmp(EXTENSION_NAMES[i], name) == 0) {
+      break;
+    }
+  }
+  return i;
 }
 
 } // namespace aria2

+ 45 - 9
src/ExtensionMessageRegistry.h

@@ -2,7 +2,7 @@
 /*
  * aria2 - The high speed download utility
  *
- * Copyright (C) 2009 Tatsuhiro Tsujikawa
+ * Copyright (C) 2012 Tatsuhiro Tsujikawa
  *
  * This program is free software; you can redistribute it and/or modify
  * it under the terms of the GNU General Public License as published by
@@ -37,16 +37,27 @@
 
 #include "common.h"
 
-#include <string>
-
-#include "BtConstants.h"
+#include <vector>
 
 namespace aria2 {
 
+typedef std::vector<int> Extensions;
+
+// This class stores mapping between BitTorrent entension name and its
+// ID. The BitTorrent Extension Protocol is specified in BEP10.  This
+// class is defined to only stores extensions aria2 supports. See
+// InterestingExtension for supported extensions.
+//
+// See also http://bittorrent.org/beps/bep_0010.html
 class ExtensionMessageRegistry {
-private:
-  Extensions extensions_;
 public:
+  enum InterestingExtension {
+    UT_METADATA,
+    UT_PEX,
+    // The number of extensions.
+    MAX_EXTENSION
+  };
+
   ExtensionMessageRegistry();
 
   ~ExtensionMessageRegistry();
@@ -56,13 +67,38 @@ public:
     return extensions_;
   }
 
-  uint8_t getExtensionMessageID(const std::string& name) const;
+  void setExtensions(const Extensions& extensions);
+
+  // Returns message ID corresponding the given |key|.  The |key| must
+  // be one of InterestingExtension other than MAX_EXTENSION. If
+  // message ID is not defined, returns 0.
+  uint8_t getExtensionMessageID(int key) const;
 
-  const std::string& getExtensionName(uint8_t id) const;
+  // Returns extension name corresponding to the given |id|. If no
+  // extension is defined for the given |id|, returns NULL.
+  const char* getExtensionName(uint8_t id) const;
 
-  void removeExtension(const std::string& name);
+  // Sets association of the |key| and |id|. The |key| must be one of
+  // InterestingExtension other than MAX_EXTENSION.
+  void setExtensionMessageID(int key, uint8_t id);
+
+  // Removes association of the |key|. The |key| must be one of
+  // InterestingExtension other than MAX_EXTENSION. After this call,
+  // getExtensionMessageID(key) returns 0.
+  void removeExtension(int key);
+private:
+  Extensions extensions_;
 };
 
+// Returns the extension name corresponding to the given |key|. The
+// |key| must be one of InterestingExtension other than MAX_EXTENSION.
+const char* strBtExtension(int key);
+
+// Returns extension key corresponding to the given extension |name|.
+// If no such key exists, returns
+// ExtensionMessageRegistry::MAX_EXTENSION.
+int keyBtExtension(const char* name);
+
 } // namespace aria2
 
 #endif // D_EXTENSION_MESSAGE_REGISTRY_H

+ 35 - 21
src/HandshakeExtensionMessage.cc

@@ -68,10 +68,11 @@ std::string HandshakeExtensionMessage::getPayload()
     dict.put("p", Integer::g(tcpPort_));
   }
   SharedHandle<Dict> extDict = Dict::g();
-  for(std::map<std::string, uint8_t>::const_iterator itr = extensions_.begin(),
-        eoi = extensions_.end(); itr != eoi; ++itr) {
-    const std::map<std::string, uint8_t>::value_type& vt = *itr;
-    extDict->put(vt.first, Integer::g(vt.second));
+  for(int i = 0; i < ExtensionMessageRegistry::MAX_EXTENSION; ++i) {
+    int id = extreg_.getExtensionMessageID(i);
+    if(id) {
+      extDict->put(strBtExtension(i), Integer::g(id));
+    }
   }
   dict.put("m", extDict);
   if(metadataSize_) {
@@ -87,10 +88,11 @@ std::string HandshakeExtensionMessage::toString() const
                     util::percentEncode(clientVersion_).c_str(),
                     tcpPort_,
                     static_cast<unsigned long>(metadataSize_)));
-  for(std::map<std::string, uint8_t>::const_iterator itr = extensions_.begin(),
-        eoi = extensions_.end(); itr != eoi; ++itr) {
-    const std::map<std::string, uint8_t>::value_type& vt = *itr;
-    s += fmt(", %s=%u", vt.first.c_str(), vt.second);
+  for(int i = 0; i < ExtensionMessageRegistry::MAX_EXTENSION; ++i) {
+    int id = extreg_.getExtensionMessageID(i);
+    if(id) {
+      s += fmt(", %s=%u", strBtExtension(i), id);
+    }
   }
   return s;
 }
@@ -101,14 +103,15 @@ void HandshakeExtensionMessage::doReceivedAction()
     peer_->setPort(tcpPort_);
     peer_->setIncomingPeer(false);
   }
-  for(std::map<std::string, uint8_t>::const_iterator itr = extensions_.begin(),
-        eoi = extensions_.end(); itr != eoi; ++itr) {
-    const std::map<std::string, uint8_t>::value_type& vt = *itr;
-    peer_->setExtension(vt.first, vt.second);
+  for(int i = 0; i < ExtensionMessageRegistry::MAX_EXTENSION; ++i) {
+    int id = extreg_.getExtensionMessageID(i);
+    if(id) {
+      peer_->setExtension(i, id);
+    }
   }
   SharedHandle<TorrentAttribute> attrs = bittorrent::getTorrentAttrs(dctx_);
   if(attrs->metadata.empty()) {
-    if(!peer_->getExtensionMessageID("ut_metadata")) {
+    if(!peer_->getExtensionMessageID(ExtensionMessageRegistry::UT_METADATA)) {
       // TODO In metadataGetMode, if peer don't support metadata
       // transfer, should we drop connection? There is a possibility
       // that peer can still tell us peers using PEX.
@@ -146,14 +149,19 @@ void HandshakeExtensionMessage::setPeer(const SharedHandle<Peer>& peer)
   peer_ = peer;
 }
 
-uint8_t HandshakeExtensionMessage::getExtensionMessageID(const std::string& name) const
+void HandshakeExtensionMessage::setExtension(int key, uint8_t id)
 {
-  std::map<std::string, uint8_t>::const_iterator i = extensions_.find(name);
-  if(i == extensions_.end()) {
-    return 0;
-  } else {
-    return (*i).second;
-  }
+  extreg_.setExtensionMessageID(key, id);
+}
+
+void HandshakeExtensionMessage::setExtensions(const Extensions& extensions)
+{
+  extreg_.setExtensions(extensions);
+}
+
+uint8_t HandshakeExtensionMessage::getExtensionMessageID(int key) const
+{
+  return extreg_.getExtensionMessageID(key);
 }
 
 HandshakeExtensionMessageHandle
@@ -187,7 +195,13 @@ HandshakeExtensionMessage::create(const unsigned char* data, size_t length)
           eoi = extDict->end(); i != eoi; ++i) {
       const Integer* extId = downcast<Integer>((*i).second);
       if(extId) {
-        msg->extensions_[(*i).first] = extId->i();
+        int key = keyBtExtension((*i).first.c_str());
+        if(key == ExtensionMessageRegistry::MAX_EXTENSION) {
+          A2_LOG_DEBUG(fmt("Unsupported BitTorrent extension %s=%" PRId64,
+                           (*i).first.c_str(), extId->i()));
+        } else {
+          msg->setExtension(key, extId->i());
+        }
       }
     }
   }

+ 5 - 12
src/HandshakeExtensionMessage.h

@@ -37,9 +37,8 @@
 
 #include "ExtensionMessage.h"
 
-#include <map>
-
 #include "BtConstants.h"
+#include "ExtensionMessageRegistry.h"
 
 namespace aria2 {
 
@@ -54,7 +53,7 @@ private:
 
   size_t metadataSize_;
 
-  std::map<std::string, uint8_t> extensions_;
+  ExtensionMessageRegistry extreg_;
 
   SharedHandle<DownloadContext> dctx_;
 
@@ -117,17 +116,11 @@ public:
     dctx_ = dctx;
   }
 
-  void setExtension(const std::string& name, uint8_t id)
-  {
-    extensions_[name] = id;
-  }
+  void setExtension(int key, uint8_t id);
 
-  void setExtensions(const Extensions& extensions)
-  {
-    extensions_ = extensions;
-  }
+  void setExtensions(const Extensions& extensions);
 
-  uint8_t getExtensionMessageID(const std::string& name) const;
+  uint8_t getExtensionMessageID(int key) const;
 
   void setPeer(const SharedHandle<Peer>& peer);
 

+ 5 - 5
src/Peer.cc

@@ -334,22 +334,22 @@ bool Peer::isGood() const
     difference(global::wallclock()) >= BAD_CONDITION_INTERVAL;
 }
 
-uint8_t Peer::getExtensionMessageID(const std::string& name) const
+uint8_t Peer::getExtensionMessageID(int key) const
 {
   assert(res_);
-  return res_->getExtensionMessageID(name);
+  return res_->getExtensionMessageID(key);
 }
 
-std::string Peer::getExtensionName(uint8_t id) const
+const char* Peer::getExtensionName(uint8_t id) const
 {
   assert(res_);
   return res_->getExtensionName(id);
 }
 
-void Peer::setExtension(const std::string& name, uint8_t id)
+void Peer::setExtension(int key, uint8_t id)
 {
   assert(res_);
-  res_->addExtension(name, id);
+  res_->addExtension(key, id);
 }
 
 void Peer::setExtendedMessagingEnabled(bool enabled)

+ 3 - 3
src/Peer.h

@@ -283,11 +283,11 @@ public:
 
   bool hasPiece(size_t index) const;
 
-  uint8_t getExtensionMessageID(const std::string& name) const;
+  uint8_t getExtensionMessageID(int key) const;
 
-  std::string getExtensionName(uint8_t id) const;
+  const char* getExtensionName(uint8_t id) const;
 
-  void setExtension(const std::string& name, uint8_t id);
+  void setExtension(int key, uint8_t id);
 
   const Timer& getLastDownloadUpdate() const;
 

+ 4 - 0
src/PeerInteractionCommand.cc

@@ -120,6 +120,10 @@ PeerInteractionCommand::PeerInteractionCommand
 
   SharedHandle<ExtensionMessageRegistry> exMsgRegistry
     (new ExtensionMessageRegistry());
+  exMsgRegistry->setExtensionMessageID(ExtensionMessageRegistry::UT_PEX, 8);
+  // http://www.bittorrent.org/beps/bep_0009.html
+  exMsgRegistry->setExtensionMessageID(ExtensionMessageRegistry::UT_METADATA,
+                                       9);
 
   SharedHandle<UTMetadataRequestFactory> utMetadataRequestFactory;
   SharedHandle<UTMetadataRequestTracker> utMetadataRequestTracker;

+ 7 - 20
src/PeerSessionResource.cc

@@ -192,32 +192,19 @@ void PeerSessionResource::extendedMessagingEnabled(bool b)
   extendedMessagingEnabled_ = b;
 }
 
-uint8_t
-PeerSessionResource::getExtensionMessageID(const std::string& name) const
-{
-  Extensions::const_iterator itr = extensions_.find(name);
-  if(itr == extensions_.end()) {
-    return 0;
-  } else {
-    return (*itr).second;
-  }
+uint8_t PeerSessionResource::getExtensionMessageID(int key) const
+{
+  return extreg_.getExtensionMessageID(key);
 }
 
-std::string PeerSessionResource::getExtensionName(uint8_t id) const
+const char* PeerSessionResource::getExtensionName(uint8_t id) const
 {
-  for(Extensions::const_iterator itr = extensions_.begin(),
-        eoi = extensions_.end(); itr != eoi; ++itr) {
-    const Extensions::value_type& p = *itr;
-    if(p.second == id) {
-      return p.first;
-    }
-  }
-  return A2STR::NIL;
+  return extreg_.getExtensionName(id);
 }
 
-void PeerSessionResource::addExtension(const std::string& name, uint8_t id)
+void PeerSessionResource::addExtension(int key, uint8_t id)
 {
-  extensions_[name] = id;
+  extreg_.setExtensionMessageID(key, id);
 }
 
 void PeerSessionResource::dhtEnabled(bool b)

+ 5 - 4
src/PeerSessionResource.h

@@ -43,6 +43,7 @@
 #include "BtConstants.h"
 #include "PeerStat.h"
 #include "TimerA2.h"
+#include "ExtensionMessageRegistry.h"
 
 namespace aria2 {
 
@@ -73,7 +74,7 @@ private:
   // fast index set which localhost has sent to a peer.
   std::set<size_t> amAllowedIndexSet_;
   bool extendedMessagingEnabled_;
-  Extensions extensions_;
+  ExtensionMessageRegistry extreg_;
   bool dhtEnabled_;
   PeerStat peerStat_;
 
@@ -192,11 +193,11 @@ public:
 
   void extendedMessagingEnabled(bool b);
 
-  uint8_t getExtensionMessageID(const std::string& name) const;
+  uint8_t getExtensionMessageID(int key) const;
 
-  std::string getExtensionName(uint8_t id) const;
+  const char* getExtensionName(uint8_t id) const;
 
-  void addExtension(const std::string& name, uint8_t id);
+  void addExtension(int key, uint8_t id);
 
   bool dhtEnabled() const
   {

+ 3 - 1
src/UTMetadataRequestExtensionMessage.cc

@@ -48,6 +48,7 @@
 #include "DownloadContext.h"
 #include "BtMessage.h"
 #include "PieceStorage.h"
+#include "ExtensionMessageRegistry.h"
 
 namespace aria2 {
 
@@ -76,7 +77,8 @@ std::string UTMetadataRequestExtensionMessage::toString() const
 void UTMetadataRequestExtensionMessage::doReceivedAction()
 {
   SharedHandle<TorrentAttribute> attrs = bittorrent::getTorrentAttrs(dctx_);
-  uint8_t id = peer_->getExtensionMessageID("ut_metadata");
+  uint8_t id = peer_->getExtensionMessageID
+    (ExtensionMessageRegistry::UT_METADATA);
   if(attrs->metadata.empty()) {
     SharedHandle<UTMetadataRejectExtensionMessage> m
       (new UTMetadataRejectExtensionMessage(id));

+ 2 - 1
src/UTMetadataRequestFactory.cc

@@ -44,6 +44,7 @@
 #include "Logger.h"
 #include "LogFactory.h"
 #include "fmt.h"
+#include "ExtensionMessageRegistry.h"
 
 namespace aria2 {
 
@@ -71,7 +72,7 @@ void UTMetadataRequestFactory::create
                      static_cast<unsigned long>(p->getIndex())));
     SharedHandle<UTMetadataRequestExtensionMessage> m
       (new UTMetadataRequestExtensionMessage
-       (peer_->getExtensionMessageID("ut_metadata")));
+       (peer_->getExtensionMessageID(ExtensionMessageRegistry::UT_METADATA)));
     m->setIndex(p->getIndex());
     m->setDownloadContext(dctx_);
     m->setBtMessageDispatcher(dispatcher_);

+ 22 - 9
test/DefaultExtensionMessageFactoryTest.cc

@@ -53,7 +53,7 @@ public:
 
     peer_.reset(new Peer("192.168.0.1", 6969));
     peer_->allocateSessionResource(1024, 1024*1024);
-    peer_->setExtension("ut_pex", 1);
+    peer_->setExtension(ExtensionMessageRegistry::UT_PEX, 1);
 
     registry_.reset(new ExtensionMessageRegistry());
 
@@ -76,9 +76,9 @@ public:
     factory_->setDownloadContext(dctx_);
   }
 
-  std::string getExtensionMessageID(const std::string& name)
+  std::string getExtensionMessageID(int key)
   {
-    unsigned char id[1] = { registry_->getExtensionMessageID(name) };
+    unsigned char id[1] = { registry_->getExtensionMessageID(key) };
     return std::string(&id[0], &id[1]);
   }
 
@@ -103,7 +103,7 @@ CPPUNIT_TEST_SUITE_REGISTRATION(DefaultExtensionMessageFactoryTest);
 
 void DefaultExtensionMessageFactoryTest::testCreateMessage_unknown()
 {
-  peer_->setExtension("foo", 255);
+  peer_->setExtension(ExtensionMessageRegistry::UT_PEX, 255);
 
   unsigned char id[1] = { 255 };
 
@@ -139,7 +139,10 @@ void DefaultExtensionMessageFactoryTest::testCreateMessage_UTPex()
   bittorrent::packcompact(c3, "192.168.0.2", 6882);
   bittorrent::packcompact(c4, "10.1.1.3",10000);
 
-  std::string data = getExtensionMessageID("ut_pex")+"d5:added12:"+
+  registry_->setExtensionMessageID(ExtensionMessageRegistry::UT_PEX, 1);
+
+  std::string data = getExtensionMessageID(ExtensionMessageRegistry::UT_PEX)
+    +"d5:added12:"+
     std::string(&c1[0], &c1[6])+std::string(&c2[0], &c2[6])+
     "7:added.f2:207:dropped12:"+
     std::string(&c3[0], &c3[6])+std::string(&c4[0], &c4[6])+
@@ -147,13 +150,17 @@ void DefaultExtensionMessageFactoryTest::testCreateMessage_UTPex()
 
   SharedHandle<UTPexExtensionMessage> m =
     createMessage<UTPexExtensionMessage>(data);
-  CPPUNIT_ASSERT_EQUAL(registry_->getExtensionMessageID("ut_pex"),
+  CPPUNIT_ASSERT_EQUAL(registry_->getExtensionMessageID
+                       (ExtensionMessageRegistry::UT_PEX),
                        m->getExtensionMessageID());
 }
 
 void DefaultExtensionMessageFactoryTest::testCreateMessage_UTMetadataRequest()
 {
-  std::string data = getExtensionMessageID("ut_metadata")+
+  registry_->setExtensionMessageID(ExtensionMessageRegistry::UT_METADATA, 1);
+
+  std::string data = getExtensionMessageID
+    (ExtensionMessageRegistry::UT_METADATA)+
     "d8:msg_typei0e5:piecei1ee";
   SharedHandle<UTMetadataRequestExtensionMessage> m =
     createMessage<UTMetadataRequestExtensionMessage>(data);
@@ -162,7 +169,10 @@ void DefaultExtensionMessageFactoryTest::testCreateMessage_UTMetadataRequest()
 
 void DefaultExtensionMessageFactoryTest::testCreateMessage_UTMetadataData()
 {
-  std::string data = getExtensionMessageID("ut_metadata")+
+  registry_->setExtensionMessageID(ExtensionMessageRegistry::UT_METADATA, 1);
+
+  std::string data = getExtensionMessageID
+    (ExtensionMessageRegistry::UT_METADATA)+
     "d8:msg_typei1e5:piecei1e10:total_sizei300ee0000000000";
   SharedHandle<UTMetadataDataExtensionMessage> m =
     createMessage<UTMetadataDataExtensionMessage>(data);
@@ -173,7 +183,10 @@ void DefaultExtensionMessageFactoryTest::testCreateMessage_UTMetadataData()
 
 void DefaultExtensionMessageFactoryTest::testCreateMessage_UTMetadataReject()
 {
-  std::string data = getExtensionMessageID("ut_metadata")+
+  registry_->setExtensionMessageID(ExtensionMessageRegistry::UT_METADATA, 1);
+
+  std::string data = getExtensionMessageID
+    (ExtensionMessageRegistry::UT_METADATA)+
     "d8:msg_typei2e5:piecei1ee";
   SharedHandle<UTMetadataRejectExtensionMessage> m =
     createMessage<UTMetadataRejectExtensionMessage>(data);

+ 75 - 0
test/ExtensionMessageRegistryTest.cc

@@ -0,0 +1,75 @@
+#include "ExtensionMessageRegistry.h"
+
+#include <cppunit/extensions/HelperMacros.h>
+
+namespace aria2 {
+
+class ExtensionMessageRegistryTest:public CppUnit::TestFixture {
+
+  CPPUNIT_TEST_SUITE(ExtensionMessageRegistryTest);
+  CPPUNIT_TEST(testStrBtExtension);
+  CPPUNIT_TEST(testKeyBtExtension);
+  CPPUNIT_TEST(testGetExtensionMessageID);
+  CPPUNIT_TEST_SUITE_END();
+public:
+  void testStrBtExtension();
+  void testKeyBtExtension();
+  void testGetExtensionMessageID();
+};
+
+CPPUNIT_TEST_SUITE_REGISTRATION( ExtensionMessageRegistryTest );
+
+void ExtensionMessageRegistryTest::testStrBtExtension()
+{
+  CPPUNIT_ASSERT_EQUAL(std::string("ut_pex"),
+                       std::string(strBtExtension
+                                   (ExtensionMessageRegistry::UT_PEX)));
+  CPPUNIT_ASSERT_EQUAL(std::string("ut_metadata"),
+                       std::string(strBtExtension
+                                   (ExtensionMessageRegistry::UT_METADATA)));
+  CPPUNIT_ASSERT(!strBtExtension(100));
+}
+
+void ExtensionMessageRegistryTest::testKeyBtExtension()
+{
+  CPPUNIT_ASSERT_EQUAL((int)ExtensionMessageRegistry::UT_PEX,
+                       keyBtExtension("ut_pex"));
+  CPPUNIT_ASSERT_EQUAL((int)ExtensionMessageRegistry::UT_METADATA,
+                       keyBtExtension("ut_metadata"));
+  CPPUNIT_ASSERT_EQUAL((int)ExtensionMessageRegistry::MAX_EXTENSION,
+                       keyBtExtension("unknown"));
+}
+
+void ExtensionMessageRegistryTest::testGetExtensionMessageID()
+{
+  ExtensionMessageRegistry reg;
+  CPPUNIT_ASSERT_EQUAL((uint8_t)0, reg.getExtensionMessageID
+                       (ExtensionMessageRegistry::UT_PEX));
+  CPPUNIT_ASSERT(!reg.getExtensionName(0));
+  CPPUNIT_ASSERT(!reg.getExtensionName(1));
+  CPPUNIT_ASSERT(!reg.getExtensionName(100));
+
+  reg.setExtensionMessageID(ExtensionMessageRegistry::UT_PEX, 1);
+
+  CPPUNIT_ASSERT_EQUAL(std::string("ut_pex"),
+                       std::string(reg.getExtensionName(1)));
+  CPPUNIT_ASSERT_EQUAL((uint8_t)1, reg.getExtensionMessageID
+                       (ExtensionMessageRegistry::UT_PEX));
+
+  reg.setExtensionMessageID(ExtensionMessageRegistry::UT_METADATA, 127);
+
+  CPPUNIT_ASSERT_EQUAL((uint8_t)127, reg.getExtensionMessageID
+                       (ExtensionMessageRegistry::UT_METADATA));
+  CPPUNIT_ASSERT_EQUAL((uint8_t)1, reg.getExtensionMessageID
+                       (ExtensionMessageRegistry::UT_PEX));
+
+  reg.removeExtension(ExtensionMessageRegistry::UT_PEX);
+
+  CPPUNIT_ASSERT_EQUAL((uint8_t)127, reg.getExtensionMessageID
+                       (ExtensionMessageRegistry::UT_METADATA));
+  CPPUNIT_ASSERT_EQUAL((uint8_t)0, reg.getExtensionMessageID
+                       (ExtensionMessageRegistry::UT_PEX));
+  CPPUNIT_ASSERT(!reg.getExtensionName(1));
+}
+
+} // namespace aria2

+ 21 - 14
test/HandshakeExtensionMessageTest.cc

@@ -60,12 +60,12 @@ void HandshakeExtensionMessageTest::testGetBencodedData()
   HandshakeExtensionMessage msg;
   msg.setClientVersion("aria2");
   msg.setTCPPort(6889);
-  msg.setExtension("ut_pex", 1);
-  msg.setExtension("a2_dht", 2);
+  msg.setExtension(ExtensionMessageRegistry::UT_PEX, 1);
+  msg.setExtension(ExtensionMessageRegistry::UT_METADATA, 2);
   msg.setMetadataSize(1024);
   CPPUNIT_ASSERT_EQUAL
     (std::string("d"
-                 "1:md6:a2_dhti2e6:ut_pexi1ee"
+                 "1:md11:ut_metadatai2e6:ut_pexi1ee"
                  "13:metadata_sizei1024e"
                  "1:pi6889e"
                  "1:v5:aria2"
@@ -81,12 +81,12 @@ void HandshakeExtensionMessageTest::testToString()
   HandshakeExtensionMessage msg;
   msg.setClientVersion("aria2");
   msg.setTCPPort(6889);
-  msg.setExtension("ut_pex", 1);
-  msg.setExtension("a2_dht", 2);
+  msg.setExtension(ExtensionMessageRegistry::UT_PEX, 1);
+  msg.setExtension(ExtensionMessageRegistry::UT_METADATA, 2);
   msg.setMetadataSize(1024);
   CPPUNIT_ASSERT_EQUAL
     (std::string("handshake client=aria2, tcpPort=6889, metadataSize=1024,"
-                 " a2_dht=2, ut_pex=1"), msg.toString());
+                 " ut_metadata=2, ut_pex=1"), msg.toString());
 }
 
 void HandshakeExtensionMessageTest::testDoReceivedAction()
@@ -106,9 +106,8 @@ void HandshakeExtensionMessageTest::testDoReceivedAction()
   HandshakeExtensionMessage msg;
   msg.setClientVersion("aria2");
   msg.setTCPPort(6889);
-  msg.setExtension("ut_pex", 1);
-  msg.setExtension("a2_dht", 2);
-  msg.setExtension("ut_metadata", 3);
+  msg.setExtension(ExtensionMessageRegistry::UT_PEX, 1);
+  msg.setExtension(ExtensionMessageRegistry::UT_METADATA, 3);
   msg.setMetadataSize(1024);
   msg.setPeer(peer);
   msg.setDownloadContext(dctx);
@@ -116,8 +115,12 @@ void HandshakeExtensionMessageTest::testDoReceivedAction()
   msg.doReceivedAction();
 
   CPPUNIT_ASSERT_EQUAL((uint16_t)6889, peer->getPort());
-  CPPUNIT_ASSERT_EQUAL((uint8_t)1, peer->getExtensionMessageID("ut_pex"));
-  CPPUNIT_ASSERT_EQUAL((uint8_t)2, peer->getExtensionMessageID("a2_dht"));
+  CPPUNIT_ASSERT_EQUAL((uint8_t)1,
+                       peer->getExtensionMessageID
+                       (ExtensionMessageRegistry::UT_PEX));
+  CPPUNIT_ASSERT_EQUAL((uint8_t)3,
+                       peer->getExtensionMessageID
+                       (ExtensionMessageRegistry::UT_METADATA));
   CPPUNIT_ASSERT(peer->isSeeder());
   CPPUNIT_ASSERT_EQUAL((size_t)1024, attrs->metadataSize);
   CPPUNIT_ASSERT_EQUAL((int64_t)1024, dctx->getTotalLength());
@@ -134,13 +137,15 @@ void HandshakeExtensionMessageTest::testDoReceivedAction()
 void HandshakeExtensionMessageTest::testCreate()
 {
   std::string in = 
-    "0d1:pi6881e1:v5:aria21:md6:ut_pexi1ee13:metadata_sizei1024ee";
+    "0d1:pi6881e1:v5:aria21:md5:a2dhti2e6:ut_pexi1ee13:metadata_sizei1024ee";
   SharedHandle<HandshakeExtensionMessage> m =
     HandshakeExtensionMessage::create(reinterpret_cast<const unsigned char*>(in.c_str()),
                                       in.size());
   CPPUNIT_ASSERT_EQUAL(std::string("aria2"), m->getClientVersion());
   CPPUNIT_ASSERT_EQUAL((uint16_t)6881, m->getTCPPort());
-  CPPUNIT_ASSERT_EQUAL((uint8_t)1, m->getExtensionMessageID("ut_pex"));
+  CPPUNIT_ASSERT_EQUAL((uint8_t)1,
+                       m->getExtensionMessageID
+                       (ExtensionMessageRegistry::UT_PEX));
   CPPUNIT_ASSERT_EQUAL((size_t)1024, m->getMetadataSize());
   try {
     // bad payload format
@@ -182,7 +187,9 @@ void HandshakeExtensionMessageTest::testCreate_stringnum()
   // port number in string is not allowed
   CPPUNIT_ASSERT_EQUAL((uint16_t)0, m->getTCPPort());
   // extension ID in string is not allowed
-  CPPUNIT_ASSERT_EQUAL((uint8_t)0, m->getExtensionMessageID("ut_pex"));
+  CPPUNIT_ASSERT_EQUAL((uint8_t)0,
+                       m->getExtensionMessageID
+                       (ExtensionMessageRegistry::UT_PEX));
 }
 
 } // namespace aria2

+ 2 - 1
test/Makefile.am

@@ -208,7 +208,8 @@ aria2c_SOURCES += BtAllowedFastMessageTest.cc\
 	LpdMessageReceiverTest.cc\
 	Bencode2Test.cc\
 	PeerConnectionTest.cc\
-	ValueBaseBencodeParserTest.cc
+	ValueBaseBencodeParserTest.cc\
+	ExtensionMessageRegistryTest.cc
 endif # ENABLE_BITTORRENT
 
 if ENABLE_METALINK

+ 11 - 6
test/PeerSessionResourceTest.cc

@@ -137,12 +137,17 @@ void PeerSessionResourceTest::testGetExtensionMessageID()
 {
   PeerSessionResource res(1024, 1024*1024);
 
-  res.addExtension("a2", 9);
-  CPPUNIT_ASSERT_EQUAL((uint8_t)9, res.getExtensionMessageID("a2"));
-  CPPUNIT_ASSERT_EQUAL((uint8_t)0, res.getExtensionMessageID("non"));
-
-  CPPUNIT_ASSERT_EQUAL(std::string("a2"), res.getExtensionName(9));
-  CPPUNIT_ASSERT_EQUAL(std::string(""), res.getExtensionName(10));
+  res.addExtension(ExtensionMessageRegistry::UT_PEX, 9);
+  CPPUNIT_ASSERT_EQUAL((uint8_t)9,
+                       res.getExtensionMessageID
+                       (ExtensionMessageRegistry::UT_PEX));
+  CPPUNIT_ASSERT_EQUAL((uint8_t)0,
+                       res.getExtensionMessageID
+                       (ExtensionMessageRegistry::UT_METADATA));
+
+  CPPUNIT_ASSERT_EQUAL(std::string("ut_pex"),
+                       std::string(res.getExtensionName(9)));
+  CPPUNIT_ASSERT(!res.getExtensionName(10));
 }
 
 void PeerSessionResourceTest::testFastExtensionEnabled()

+ 2 - 1
test/UTMetadataRequestExtensionMessageTest.cc

@@ -16,6 +16,7 @@
 #include "PieceStorage.h"
 #include "extension_message_test_helper.h"
 #include "DlAbortEx.h"
+#include "ExtensionMessageRegistry.h"
 
 namespace aria2 {
 
@@ -44,7 +45,7 @@ public:
     dctx_->setAttribute(CTX_ATTR_BT, attrs);
     peer_.reset(new Peer("host", 6880));
     peer_->allocateSessionResource(0, 0);
-    peer_->setExtension("ut_metadata", 1);
+    peer_->setExtension(ExtensionMessageRegistry::UT_METADATA, 1);
   }
 
   template<typename T>