DHTMessageFactoryImpl.cc 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542
  1. /* <!-- copyright */
  2. /*
  3. * aria2 - The high speed download utility
  4. *
  5. * Copyright (C) 2006 Tatsuhiro Tsujikawa
  6. *
  7. * This program is free software; you can redistribute it and/or modify
  8. * it under the terms of the GNU General Public License as published by
  9. * the Free Software Foundation; either version 2 of the License, or
  10. * (at your option) any later version.
  11. *
  12. * This program is distributed in the hope that it will be useful,
  13. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  14. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  15. * GNU General Public License for more details.
  16. *
  17. * You should have received a copy of the GNU General Public License
  18. * along with this program; if not, write to the Free Software
  19. * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
  20. *
  21. * In addition, as a special exception, the copyright holders give
  22. * permission to link the code of portions of this program with the
  23. * OpenSSL library under certain conditions as described in each
  24. * individual source file, and distribute linked combinations
  25. * including the two.
  26. * You must obey the GNU General Public License in all respects
  27. * for all of the code used other than OpenSSL. If you modify
  28. * file(s) with this exception, you may extend this exception to your
  29. * version of the file(s), but you are not obligated to do so. If you
  30. * do not wish to do so, delete this exception statement from your
  31. * version. If you delete this exception statement from all source
  32. * files in the program, then also delete it here.
  33. */
  34. /* copyright --> */
  35. #include "DHTMessageFactoryImpl.h"
  36. #include <cstring>
  37. #include <utility>
  38. #include "LogFactory.h"
  39. #include "DlAbortEx.h"
  40. #include "DHTNode.h"
  41. #include "DHTRoutingTable.h"
  42. #include "DHTPingMessage.h"
  43. #include "DHTPingReplyMessage.h"
  44. #include "DHTFindNodeMessage.h"
  45. #include "DHTFindNodeReplyMessage.h"
  46. #include "DHTGetPeersMessage.h"
  47. #include "DHTGetPeersReplyMessage.h"
  48. #include "DHTAnnouncePeerMessage.h"
  49. #include "DHTAnnouncePeerReplyMessage.h"
  50. #include "DHTUnknownMessage.h"
  51. #include "DHTConnection.h"
  52. #include "DHTMessageDispatcher.h"
  53. #include "DHTPeerAnnounceStorage.h"
  54. #include "DHTTokenTracker.h"
  55. #include "DHTMessageCallback.h"
  56. #include "bittorrent_helper.h"
  57. #include "BtRuntime.h"
  58. #include "util.h"
  59. #include "Peer.h"
  60. #include "Logger.h"
  61. #include "fmt.h"
  62. namespace aria2 {
  63. DHTMessageFactoryImpl::DHTMessageFactoryImpl(int family)
  64. : family_(family),
  65. connection_(0),
  66. dispatcher_(0),
  67. routingTable_(0),
  68. peerAnnounceStorage_(0),
  69. tokenTracker_(0)
  70. {}
  71. DHTMessageFactoryImpl::~DHTMessageFactoryImpl() {}
  72. SharedHandle<DHTNode>
  73. DHTMessageFactoryImpl::getRemoteNode
  74. (const unsigned char* id, const std::string& ipaddr, uint16_t port) const
  75. {
  76. SharedHandle<DHTNode> node = routingTable_->getNode(id, ipaddr, port);
  77. if(!node) {
  78. node.reset(new DHTNode(id));
  79. node->setIPAddress(ipaddr);
  80. node->setPort(port);
  81. }
  82. return node;
  83. }
  84. namespace {
  85. const Dict* getDictionary(const Dict* dict, const std::string& key)
  86. {
  87. const Dict* d = asDict(dict->get(key));
  88. if(d) {
  89. return d;
  90. } else {
  91. throw DL_ABORT_EX
  92. (fmt("Malformed DHT message. Missing %s", key.c_str()));
  93. }
  94. }
  95. } // namespace
  96. namespace {
  97. const String* getString(const Dict* dict, const std::string& key)
  98. {
  99. const String* c = asString(dict->get(key));
  100. if(c) {
  101. return c;
  102. } else {
  103. throw DL_ABORT_EX
  104. (fmt("Malformed DHT message. Missing %s", key.c_str()));
  105. }
  106. }
  107. } // namespace
  108. namespace {
  109. const Integer* getInteger(const Dict* dict, const std::string& key)
  110. {
  111. const Integer* c = asInteger(dict->get(key));
  112. if(c) {
  113. return c;
  114. } else {
  115. throw DL_ABORT_EX
  116. (fmt("Malformed DHT message. Missing %s", key.c_str()));
  117. }
  118. }
  119. } // namespace
  120. namespace {
  121. const String* getString(const List* list, size_t index)
  122. {
  123. const String* c = asString(list->get(index));
  124. if(c) {
  125. return c;
  126. } else {
  127. throw DL_ABORT_EX
  128. (fmt("Malformed DHT message. element[%lu] is not String.",
  129. static_cast<unsigned long>(index)));
  130. }
  131. }
  132. } // namespace
  133. namespace {
  134. const Integer* getInteger(const List* list, size_t index)
  135. {
  136. const Integer* c = asInteger(list->get(index));
  137. if(c) {
  138. return c;
  139. } else {
  140. throw DL_ABORT_EX
  141. (fmt("Malformed DHT message. element[%lu] is not Integer.",
  142. static_cast<unsigned long>(index)));
  143. }
  144. }
  145. } // namespace
  146. namespace {
  147. const List* getList(const Dict* dict, const std::string& key)
  148. {
  149. const List* l = asList(dict->get(key));
  150. if(l) {
  151. return l;
  152. } else {
  153. throw DL_ABORT_EX
  154. (fmt("Malformed DHT message. Missing %s", key.c_str()));
  155. }
  156. }
  157. } // namespace
  158. void DHTMessageFactoryImpl::validateID(const String* id) const
  159. {
  160. if(id->s().size() != DHT_ID_LENGTH) {
  161. throw DL_ABORT_EX
  162. (fmt("Malformed DHT message. Invalid ID length."
  163. " Expected:%lu, Actual:%lu",
  164. static_cast<unsigned long>(DHT_ID_LENGTH),
  165. static_cast<unsigned long>(id->s().size())));
  166. }
  167. }
  168. void DHTMessageFactoryImpl::validatePort(const Integer* port) const
  169. {
  170. if(!(0 < port->i() && port->i() < UINT16_MAX)) {
  171. throw DL_ABORT_EX
  172. (fmt("Malformed DHT message. Invalid port=%s",
  173. util::itos(port->i()).c_str()));
  174. }
  175. }
  176. namespace {
  177. void setVersion(const SharedHandle<DHTMessage>& msg, const Dict* dict)
  178. {
  179. const String* v = asString(dict->get(DHTMessage::V));
  180. if(v) {
  181. msg->setVersion(v->s());
  182. } else {
  183. msg->setVersion(A2STR::NIL);
  184. }
  185. }
  186. } // namespace
  187. SharedHandle<DHTQueryMessage> DHTMessageFactoryImpl::createQueryMessage
  188. (const Dict* dict, const std::string& ipaddr, uint16_t port)
  189. {
  190. const String* messageType = getString(dict, DHTQueryMessage::Q);
  191. const String* transactionID = getString(dict, DHTMessage::T);
  192. const String* y = getString(dict, DHTMessage::Y);
  193. const Dict* aDict = getDictionary(dict, DHTQueryMessage::A);
  194. if(y->s() != DHTQueryMessage::Q) {
  195. throw DL_ABORT_EX("Malformed DHT message. y != q");
  196. }
  197. const String* id = getString(aDict, DHTMessage::ID);
  198. validateID(id);
  199. SharedHandle<DHTNode> remoteNode = getRemoteNode(id->uc(), ipaddr, port);
  200. SharedHandle<DHTQueryMessage> msg;
  201. if(messageType->s() == DHTPingMessage::PING) {
  202. msg = createPingMessage(remoteNode, transactionID->s());
  203. } else if(messageType->s() == DHTFindNodeMessage::FIND_NODE) {
  204. const String* targetNodeID =
  205. getString(aDict, DHTFindNodeMessage::TARGET_NODE);
  206. validateID(targetNodeID);
  207. msg = createFindNodeMessage(remoteNode, targetNodeID->uc(),
  208. transactionID->s());
  209. } else if(messageType->s() == DHTGetPeersMessage::GET_PEERS) {
  210. const String* infoHash = getString(aDict, DHTGetPeersMessage::INFO_HASH);
  211. validateID(infoHash);
  212. msg = createGetPeersMessage(remoteNode, infoHash->uc(), transactionID->s());
  213. } else if(messageType->s() == DHTAnnouncePeerMessage::ANNOUNCE_PEER) {
  214. const String* infoHash = getString(aDict,DHTAnnouncePeerMessage::INFO_HASH);
  215. validateID(infoHash);
  216. const Integer* port = getInteger(aDict, DHTAnnouncePeerMessage::PORT);
  217. validatePort(port);
  218. const String* token = getString(aDict, DHTAnnouncePeerMessage::TOKEN);
  219. msg = createAnnouncePeerMessage(remoteNode, infoHash->uc(),
  220. static_cast<uint16_t>(port->i()),
  221. token->s(), transactionID->s());
  222. } else {
  223. throw DL_ABORT_EX(fmt("Unsupported message type: %s",
  224. messageType->s().c_str()));
  225. }
  226. setVersion(msg, dict);
  227. return msg;
  228. }
  229. SharedHandle<DHTResponseMessage>
  230. DHTMessageFactoryImpl::createResponseMessage
  231. (const std::string& messageType,
  232. const Dict* dict,
  233. const std::string& ipaddr,
  234. uint16_t port)
  235. {
  236. const String* transactionID = getString(dict, DHTMessage::T);
  237. const String* y = getString(dict, DHTMessage::Y);
  238. if(y->s() == DHTUnknownMessage::E) {
  239. // for now, just report error message arrived and throw exception.
  240. const List* e = getList(dict, DHTUnknownMessage::E);
  241. if(e->size() == 2) {
  242. A2_LOG_INFO(fmt("Received Error DHT message. code=%s, msg=%s",
  243. util::itos(getInteger(e, 0)->i()).c_str(),
  244. util::percentEncode(getString(e, 1)->s()).c_str()));
  245. } else {
  246. A2_LOG_DEBUG("e doesn't have 2 elements.");
  247. }
  248. throw DL_ABORT_EX("Received Error DHT message.");
  249. } else if(y->s() != DHTResponseMessage::R) {
  250. throw DL_ABORT_EX
  251. (fmt("Malformed DHT message. y != r: y=%s",
  252. util::percentEncode(y->s()).c_str()));
  253. }
  254. const Dict* rDict = getDictionary(dict, DHTResponseMessage::R);
  255. const String* id = getString(rDict, DHTMessage::ID);
  256. validateID(id);
  257. SharedHandle<DHTNode> remoteNode = getRemoteNode(id->uc(), ipaddr, port);
  258. SharedHandle<DHTResponseMessage> msg;
  259. if(messageType == DHTPingReplyMessage::PING) {
  260. msg = createPingReplyMessage(remoteNode, id->uc(), transactionID->s());
  261. } else if(messageType == DHTFindNodeReplyMessage::FIND_NODE) {
  262. msg = createFindNodeReplyMessage(remoteNode, dict, transactionID->s());
  263. } else if(messageType == DHTGetPeersReplyMessage::GET_PEERS) {
  264. msg = createGetPeersReplyMessage(remoteNode, dict, transactionID->s());
  265. } else if(messageType == DHTAnnouncePeerReplyMessage::ANNOUNCE_PEER) {
  266. msg = createAnnouncePeerReplyMessage(remoteNode, transactionID->s());
  267. } else {
  268. throw DL_ABORT_EX
  269. (fmt("Unsupported message type: %s", messageType.c_str()));
  270. }
  271. setVersion(msg, dict);
  272. return msg;
  273. }
  274. namespace {
  275. const std::string& getDefaultVersion()
  276. {
  277. static std::string version;
  278. if(version.empty()) {
  279. uint16_t vnum16 = htons(DHT_VERSION);
  280. unsigned char buf[] = { 'A' , '2', 0, 0 };
  281. char* vnump = reinterpret_cast<char*>(&vnum16);
  282. memcpy(buf+2, vnump, 2);
  283. version.assign(&buf[0], &buf[4]);
  284. }
  285. return version;
  286. }
  287. } // namespace
  288. void DHTMessageFactoryImpl::setCommonProperty
  289. (const SharedHandle<DHTAbstractMessage>& m)
  290. {
  291. m->setConnection(connection_);
  292. m->setMessageDispatcher(dispatcher_);
  293. m->setRoutingTable(routingTable_);
  294. m->setMessageFactory(this);
  295. m->setVersion(getDefaultVersion());
  296. }
  297. SharedHandle<DHTQueryMessage> DHTMessageFactoryImpl::createPingMessage
  298. (const SharedHandle<DHTNode>& remoteNode, const std::string& transactionID)
  299. {
  300. SharedHandle<DHTPingMessage> m
  301. (new DHTPingMessage(localNode_, remoteNode, transactionID));
  302. setCommonProperty(m);
  303. return m;
  304. }
  305. SharedHandle<DHTResponseMessage> DHTMessageFactoryImpl::createPingReplyMessage
  306. (const SharedHandle<DHTNode>& remoteNode,
  307. const unsigned char* id,
  308. const std::string& transactionID)
  309. {
  310. SharedHandle<DHTPingReplyMessage> m
  311. (new DHTPingReplyMessage(localNode_, remoteNode, id, transactionID));
  312. setCommonProperty(m);
  313. return m;
  314. }
  315. SharedHandle<DHTQueryMessage> DHTMessageFactoryImpl::createFindNodeMessage
  316. (const SharedHandle<DHTNode>& remoteNode,
  317. const unsigned char* targetNodeID,
  318. const std::string& transactionID)
  319. {
  320. SharedHandle<DHTFindNodeMessage> m
  321. (new DHTFindNodeMessage
  322. (localNode_, remoteNode, targetNodeID, transactionID));
  323. setCommonProperty(m);
  324. return m;
  325. }
  326. SharedHandle<DHTResponseMessage>
  327. DHTMessageFactoryImpl::createFindNodeReplyMessage
  328. (const SharedHandle<DHTNode>& remoteNode,
  329. const std::vector<SharedHandle<DHTNode> >& closestKNodes,
  330. const std::string& transactionID)
  331. {
  332. SharedHandle<DHTFindNodeReplyMessage> m
  333. (new DHTFindNodeReplyMessage
  334. (family_, localNode_, remoteNode, transactionID));
  335. m->setClosestKNodes(closestKNodes);
  336. setCommonProperty(m);
  337. return m;
  338. }
  339. void DHTMessageFactoryImpl::extractNodes
  340. (std::vector<SharedHandle<DHTNode> >& nodes,
  341. const unsigned char* src, size_t length)
  342. {
  343. int unit = bittorrent::getCompactLength(family_)+20;
  344. if(length%unit != 0) {
  345. throw DL_ABORT_EX
  346. (fmt("Nodes length is not multiple of %d", unit));
  347. }
  348. for(size_t offset = 0; offset < length; offset += unit) {
  349. SharedHandle<DHTNode> node(new DHTNode(src+offset));
  350. std::pair<std::string, uint16_t> addr =
  351. bittorrent::unpackcompact(src+offset+DHT_ID_LENGTH, family_);
  352. if(addr.first.empty()) {
  353. continue;
  354. }
  355. node->setIPAddress(addr.first);
  356. node->setPort(addr.second);
  357. nodes.push_back(node);
  358. }
  359. }
  360. SharedHandle<DHTResponseMessage>
  361. DHTMessageFactoryImpl::createFindNodeReplyMessage
  362. (const SharedHandle<DHTNode>& remoteNode,
  363. const Dict* dict,
  364. const std::string& transactionID)
  365. {
  366. const String* nodesData =
  367. asString(getDictionary(dict, DHTResponseMessage::R)->
  368. get(family_ == AF_INET?DHTFindNodeReplyMessage::NODES:
  369. DHTFindNodeReplyMessage::NODES6));
  370. std::vector<SharedHandle<DHTNode> > nodes;
  371. if(nodesData) {
  372. extractNodes(nodes, nodesData->uc(), nodesData->s().size());
  373. }
  374. return createFindNodeReplyMessage(remoteNode, nodes, transactionID);
  375. }
  376. SharedHandle<DHTQueryMessage>
  377. DHTMessageFactoryImpl::createGetPeersMessage
  378. (const SharedHandle<DHTNode>& remoteNode,
  379. const unsigned char* infoHash,
  380. const std::string& transactionID)
  381. {
  382. SharedHandle<DHTGetPeersMessage> m
  383. (new DHTGetPeersMessage(localNode_, remoteNode, infoHash, transactionID));
  384. m->setPeerAnnounceStorage(peerAnnounceStorage_);
  385. m->setTokenTracker(tokenTracker_);
  386. setCommonProperty(m);
  387. return m;
  388. }
  389. SharedHandle<DHTResponseMessage>
  390. DHTMessageFactoryImpl::createGetPeersReplyMessage
  391. (const SharedHandle<DHTNode>& remoteNode,
  392. const Dict* dict,
  393. const std::string& transactionID)
  394. {
  395. const Dict* rDict = getDictionary(dict, DHTResponseMessage::R);
  396. const String* nodesData =
  397. asString(rDict->get(family_ == AF_INET?DHTGetPeersReplyMessage::NODES:
  398. DHTGetPeersReplyMessage::NODES6));
  399. std::vector<SharedHandle<DHTNode> > nodes;
  400. if(nodesData) {
  401. extractNodes(nodes, nodesData->uc(), nodesData->s().size());
  402. }
  403. const List* valuesList =
  404. asList(rDict->get(DHTGetPeersReplyMessage::VALUES));
  405. std::vector<SharedHandle<Peer> > peers;
  406. size_t clen = bittorrent::getCompactLength(family_);
  407. if(valuesList) {
  408. for(List::ValueType::const_iterator i = valuesList->begin(),
  409. eoi = valuesList->end(); i != eoi; ++i) {
  410. const String* data = asString(*i);
  411. if(data && data->s().size() == clen) {
  412. std::pair<std::string, uint16_t> addr =
  413. bittorrent::unpackcompact(data->uc(), family_);
  414. if(addr.first.empty()) {
  415. continue;
  416. }
  417. SharedHandle<Peer> peer(new Peer(addr.first, addr.second));
  418. peers.push_back(peer);
  419. }
  420. }
  421. }
  422. const String* token = getString(rDict, DHTGetPeersReplyMessage::TOKEN);
  423. return createGetPeersReplyMessage
  424. (remoteNode, nodes, peers, token->s(), transactionID);
  425. }
  426. SharedHandle<DHTResponseMessage>
  427. DHTMessageFactoryImpl::createGetPeersReplyMessage
  428. (const SharedHandle<DHTNode>& remoteNode,
  429. const std::vector<SharedHandle<DHTNode> >& closestKNodes,
  430. const std::vector<SharedHandle<Peer> >& values,
  431. const std::string& token,
  432. const std::string& transactionID)
  433. {
  434. SharedHandle<DHTGetPeersReplyMessage> m
  435. (new DHTGetPeersReplyMessage
  436. (family_, localNode_, remoteNode, token, transactionID));
  437. m->setClosestKNodes(closestKNodes);
  438. m->setValues(values);
  439. setCommonProperty(m);
  440. return m;
  441. }
  442. SharedHandle<DHTQueryMessage>
  443. DHTMessageFactoryImpl::createAnnouncePeerMessage
  444. (const SharedHandle<DHTNode>& remoteNode,
  445. const unsigned char* infoHash,
  446. uint16_t tcpPort,
  447. const std::string& token,
  448. const std::string& transactionID)
  449. {
  450. SharedHandle<DHTAnnouncePeerMessage> m
  451. (new DHTAnnouncePeerMessage
  452. (localNode_, remoteNode, infoHash, tcpPort, token, transactionID));
  453. m->setPeerAnnounceStorage(peerAnnounceStorage_);
  454. m->setTokenTracker(tokenTracker_);
  455. setCommonProperty(m);
  456. return m;
  457. }
  458. SharedHandle<DHTResponseMessage>
  459. DHTMessageFactoryImpl::createAnnouncePeerReplyMessage
  460. (const SharedHandle<DHTNode>& remoteNode, const std::string& transactionID)
  461. {
  462. SharedHandle<DHTAnnouncePeerReplyMessage> m
  463. (new DHTAnnouncePeerReplyMessage(localNode_, remoteNode, transactionID));
  464. setCommonProperty(m);
  465. return m;
  466. }
  467. SharedHandle<DHTMessage>
  468. DHTMessageFactoryImpl::createUnknownMessage
  469. (const unsigned char* data, size_t length,
  470. const std::string& ipaddr, uint16_t port)
  471. {
  472. SharedHandle<DHTUnknownMessage> m
  473. (new DHTUnknownMessage(localNode_, data, length, ipaddr, port));
  474. return m;
  475. }
  476. void DHTMessageFactoryImpl::setRoutingTable(DHTRoutingTable* routingTable)
  477. {
  478. routingTable_ = routingTable;
  479. }
  480. void DHTMessageFactoryImpl::setConnection(DHTConnection* connection)
  481. {
  482. connection_ = connection;
  483. }
  484. void DHTMessageFactoryImpl::setMessageDispatcher
  485. (DHTMessageDispatcher* dispatcher)
  486. {
  487. dispatcher_ = dispatcher;
  488. }
  489. void DHTMessageFactoryImpl::setPeerAnnounceStorage
  490. (DHTPeerAnnounceStorage* storage)
  491. {
  492. peerAnnounceStorage_ = storage;
  493. }
  494. void DHTMessageFactoryImpl::setTokenTracker(DHTTokenTracker* tokenTracker)
  495. {
  496. tokenTracker_ = tokenTracker;
  497. }
  498. void DHTMessageFactoryImpl::setLocalNode
  499. (const SharedHandle<DHTNode>& localNode)
  500. {
  501. localNode_ = localNode;
  502. }
  503. } // namespace aria2