DefaultBtRequestFactory.cc 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. /* <!-- copyright */
  2. /*
  3. * aria2 - The high speed download utility
  4. *
  5. * Copyright (C) 2006 Tatsuhiro Tsujikawa
  6. *
  7. * This program is free software; you can redistribute it and/or modify
  8. * it under the terms of the GNU General Public License as published by
  9. * the Free Software Foundation; either version 2 of the License, or
  10. * (at your option) any later version.
  11. *
  12. * This program is distributed in the hope that it will be useful,
  13. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  14. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  15. * GNU General Public License for more details.
  16. *
  17. * You should have received a copy of the GNU General Public License
  18. * along with this program; if not, write to the Free Software
  19. * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
  20. *
  21. * In addition, as a special exception, the copyright holders give
  22. * permission to link the code of portions of this program with the
  23. * OpenSSL library under certain conditions as described in each
  24. * individual source file, and distribute linked combinations
  25. * including the two.
  26. * You must obey the GNU General Public License in all respects
  27. * for all of the code used other than OpenSSL. If you modify
  28. * file(s) with this exception, you may extend this exception to your
  29. * version of the file(s), but you are not obligated to do so. If you
  30. * do not wish to do so, delete this exception statement from your
  31. * version. If you delete this exception statement from all source
  32. * files in the program, then also delete it here.
  33. */
  34. /* copyright --> */
  35. #include "DefaultBtRequestFactory.h"
  36. #include <algorithm>
  37. #include "LogFactory.h"
  38. #include "Logger.h"
  39. #include "Piece.h"
  40. #include "Peer.h"
  41. #include "PieceStorage.h"
  42. #include "BtMessageDispatcher.h"
  43. #include "BtMessageFactory.h"
  44. #include "BtMessage.h"
  45. #include "a2functional.h"
  46. #include "SimpleRandomizer.h"
  47. #include "array_fun.h"
  48. #include "fmt.h"
  49. #include "BtRequestMessage.h"
  50. namespace aria2 {
  51. DefaultBtRequestFactory::DefaultBtRequestFactory()
  52. : pieceStorage_(nullptr),
  53. dispatcher_(nullptr),
  54. messageFactory_(nullptr),
  55. cuid_(0)
  56. {
  57. }
  58. DefaultBtRequestFactory::~DefaultBtRequestFactory() = default;
  59. void DefaultBtRequestFactory::addTargetPiece(
  60. const std::shared_ptr<Piece>& piece)
  61. {
  62. pieces_.push_back(piece);
  63. }
  64. namespace {
  65. class AbortCompletedPieceRequest {
  66. private:
  67. BtMessageDispatcher* dispatcher_;
  68. public:
  69. AbortCompletedPieceRequest(BtMessageDispatcher* dispatcher)
  70. : dispatcher_(dispatcher)
  71. {
  72. }
  73. void operator()(const std::shared_ptr<Piece>& piece)
  74. {
  75. if (piece->pieceComplete()) {
  76. dispatcher_->doAbortOutstandingRequestAction(piece);
  77. }
  78. }
  79. };
  80. } // namespace
  81. void DefaultBtRequestFactory::removeCompletedPiece()
  82. {
  83. std::for_each(pieces_.begin(), pieces_.end(),
  84. AbortCompletedPieceRequest(dispatcher_));
  85. pieces_.erase(std::remove_if(pieces_.begin(), pieces_.end(),
  86. std::mem_fn(&Piece::pieceComplete)),
  87. pieces_.end());
  88. }
  89. void DefaultBtRequestFactory::removeTargetPiece(
  90. const std::shared_ptr<Piece>& piece)
  91. {
  92. pieces_.erase(
  93. std::remove_if(pieces_.begin(), pieces_.end(), derefEqual(piece)),
  94. pieces_.end());
  95. dispatcher_->doAbortOutstandingRequestAction(piece);
  96. pieceStorage_->cancelPiece(piece, cuid_);
  97. }
  98. namespace {
  99. class ProcessChokedPiece {
  100. private:
  101. std::shared_ptr<Peer> peer_;
  102. PieceStorage* pieceStorage_;
  103. cuid_t cuid_;
  104. public:
  105. ProcessChokedPiece(std::shared_ptr<Peer> peer, PieceStorage* pieceStorage,
  106. cuid_t cuid)
  107. : peer_(std::move(peer)), pieceStorage_(pieceStorage), cuid_(cuid)
  108. {
  109. }
  110. void operator()(const std::shared_ptr<Piece>& piece)
  111. {
  112. if (!peer_->isInPeerAllowedIndexSet(piece->getIndex())) {
  113. pieceStorage_->cancelPiece(piece, cuid_);
  114. }
  115. }
  116. };
  117. } // namespace
  118. namespace {
  119. class FindChokedPiece {
  120. private:
  121. std::shared_ptr<Peer> peer_;
  122. public:
  123. FindChokedPiece(std::shared_ptr<Peer> peer) : peer_(std::move(peer)) {}
  124. bool operator()(const std::shared_ptr<Piece>& piece)
  125. {
  126. return !peer_->isInPeerAllowedIndexSet(piece->getIndex());
  127. }
  128. };
  129. } // namespace
  130. void DefaultBtRequestFactory::doChokedAction()
  131. {
  132. std::for_each(pieces_.begin(), pieces_.end(),
  133. ProcessChokedPiece(peer_, pieceStorage_, cuid_));
  134. pieces_.erase(
  135. std::remove_if(pieces_.begin(), pieces_.end(), FindChokedPiece(peer_)),
  136. pieces_.end());
  137. }
  138. void DefaultBtRequestFactory::removeAllTargetPiece()
  139. {
  140. for (auto& elem : pieces_) {
  141. dispatcher_->doAbortOutstandingRequestAction(elem);
  142. pieceStorage_->cancelPiece(elem, cuid_);
  143. }
  144. pieces_.clear();
  145. }
  146. std::vector<std::unique_ptr<BtRequestMessage>>
  147. DefaultBtRequestFactory::createRequestMessages(size_t max, bool endGame)
  148. {
  149. if (endGame) {
  150. return createRequestMessagesOnEndGame(max);
  151. }
  152. auto requests = std::vector<std::unique_ptr<BtRequestMessage>>{};
  153. size_t getnum = max - requests.size();
  154. auto blockIndexes = std::vector<size_t>{};
  155. blockIndexes.reserve(getnum);
  156. for (auto itr = std::begin(pieces_), eoi = std::end(pieces_);
  157. itr != eoi && getnum; ++itr) {
  158. auto& piece = *itr;
  159. if (piece->getMissingUnusedBlockIndex(blockIndexes, getnum)) {
  160. getnum -= blockIndexes.size();
  161. for (auto i = std::begin(blockIndexes), eoi2 = std::end(blockIndexes);
  162. i != eoi2; ++i) {
  163. A2_LOG_DEBUG(
  164. fmt("Creating RequestMessage index=%lu, begin=%u,"
  165. " blockIndex=%lu",
  166. static_cast<unsigned long>(piece->getIndex()),
  167. static_cast<unsigned int>((*i) * piece->getBlockLength()),
  168. static_cast<unsigned long>(*i)));
  169. requests.push_back(messageFactory_->createRequestMessage(piece, *i));
  170. }
  171. blockIndexes.clear();
  172. }
  173. }
  174. return requests;
  175. }
  176. std::vector<std::unique_ptr<BtRequestMessage>>
  177. DefaultBtRequestFactory::createRequestMessagesOnEndGame(size_t max)
  178. {
  179. auto requests = std::vector<std::unique_ptr<BtRequestMessage>>{};
  180. for (auto itr = std::begin(pieces_), eoi = std::end(pieces_);
  181. itr != eoi && requests.size() < max; ++itr) {
  182. auto& piece = *itr;
  183. const size_t mislen = piece->getBitfieldLength();
  184. auto misbitfield = make_unique<unsigned char[]>(mislen);
  185. piece->getAllMissingBlockIndexes(misbitfield.get(), mislen);
  186. auto missingBlockIndexes = std::vector<size_t>{};
  187. size_t blockIndex = 0;
  188. for (size_t i = 0; i < mislen; ++i) {
  189. unsigned char bits = misbitfield[i];
  190. unsigned char mask = 128;
  191. for (size_t bi = 0; bi < 8; ++bi, mask >>= 1, ++blockIndex) {
  192. if (bits & mask) {
  193. missingBlockIndexes.push_back(blockIndex);
  194. }
  195. }
  196. }
  197. std::shuffle(std::begin(missingBlockIndexes), std::end(missingBlockIndexes),
  198. *SimpleRandomizer::getInstance());
  199. for (auto bitr = std::begin(missingBlockIndexes),
  200. eoi2 = std::end(missingBlockIndexes);
  201. bitr != eoi2 && requests.size() < max; ++bitr) {
  202. size_t blockIndex = *bitr;
  203. if (!dispatcher_->isOutstandingRequest(piece->getIndex(), blockIndex)) {
  204. A2_LOG_DEBUG(
  205. fmt("Creating RequestMessage index=%lu, begin=%u,"
  206. " blockIndex=%lu",
  207. static_cast<unsigned long>(piece->getIndex()),
  208. static_cast<unsigned int>(blockIndex * piece->getBlockLength()),
  209. static_cast<unsigned long>(blockIndex)));
  210. requests.push_back(
  211. messageFactory_->createRequestMessage(piece, blockIndex));
  212. }
  213. }
  214. }
  215. return requests;
  216. }
  217. namespace {
  218. class CountMissingBlock {
  219. private:
  220. size_t numMissingBlock_;
  221. public:
  222. CountMissingBlock() : numMissingBlock_(0) {}
  223. size_t getNumMissingBlock() { return numMissingBlock_; }
  224. void operator()(const std::shared_ptr<Piece>& piece)
  225. {
  226. numMissingBlock_ += piece->countMissingBlock();
  227. }
  228. };
  229. } // namespace
  230. size_t DefaultBtRequestFactory::countMissingBlock()
  231. {
  232. return std::for_each(pieces_.begin(), pieces_.end(), CountMissingBlock())
  233. .getNumMissingBlock();
  234. }
  235. std::vector<size_t> DefaultBtRequestFactory::getTargetPieceIndexes() const
  236. {
  237. auto res = std::vector<size_t>{};
  238. res.reserve(pieces_.size());
  239. std::transform(std::begin(pieces_), std::end(pieces_),
  240. std::back_inserter(res), std::mem_fn(&Piece::getIndex));
  241. return res;
  242. }
  243. void DefaultBtRequestFactory::setPieceStorage(PieceStorage* pieceStorage)
  244. {
  245. pieceStorage_ = pieceStorage;
  246. }
  247. void DefaultBtRequestFactory::setPeer(const std::shared_ptr<Peer>& peer)
  248. {
  249. peer_ = peer;
  250. }
  251. void DefaultBtRequestFactory::setBtMessageDispatcher(
  252. BtMessageDispatcher* dispatcher)
  253. {
  254. dispatcher_ = dispatcher;
  255. }
  256. void DefaultBtRequestFactory::setBtMessageFactory(BtMessageFactory* factory)
  257. {
  258. messageFactory_ = factory;
  259. }
  260. } // namespace aria2