/* */
#include "DefaultExtensionMessageFactory.h"

#include <cstring>

#include "Peer.h"
#include "DlAbortEx.h"
#include "HandshakeExtensionMessage.h"
#include "UTPexExtensionMessage.h"
#include "fmt.h"
#include "PeerStorage.h"
#include "ExtensionMessageRegistry.h"
#include "DownloadContext.h"
#include "BtMessageDispatcher.h"
#include "BtMessageFactory.h"
#include "UTMetadataRequestExtensionMessage.h"
#include "UTMetadataDataExtensionMessage.h"
#include "UTMetadataRejectExtensionMessage.h"
#include "message.h"
#include "PieceStorage.h"
#include "UTMetadataRequestTracker.h"
#include "RequestGroup.h"
#include "bencode2.h"

namespace aria2 {

// i686-w64-mingw32-g++ 4.6 does not support constructor delegate, so both
// constructors spell out the full member initializer list.
DefaultExtensionMessageFactory::DefaultExtensionMessageFactory()
    : peerStorage_{nullptr},
      registry_{nullptr},
      dctx_{nullptr},
      messageFactory_{nullptr},
      dispatcher_{nullptr},
      tracker_{nullptr}
{
}

// Creates a factory bound to the given peer and extension registry. The
// remaining collaborators are null until injected through the setters.
DefaultExtensionMessageFactory::DefaultExtensionMessageFactory(
    const std::shared_ptr<Peer>& peer, ExtensionMessageRegistry* registry)
    : peerStorage_{nullptr},
      peer_{peer},
      registry_{registry},
      dctx_{nullptr},
      messageFactory_{nullptr},
      dispatcher_{nullptr},
      tracker_{nullptr}
{
}

// Builds an ExtensionMessage from an extended-message payload.
//
// data[0] is the extension message ID. The following length - 1 bytes are
// the extension-specific body.
//   - ID 0 is the extension handshake.
//   - Any other ID is resolved to an extension name through registry_.
//   - "ut_pex" and "ut_metadata" are supported.
//
// For ut_metadata, the body starts with a bencoded dictionary holding
// "msg_type" and "piece", plus "total_size" for data messages. For data
// messages, the bytes after the dictionary are the piece data.
//
// Throws DlAbortEx in these cases:
//   - the payload is empty;
//   - the ID is not registered;
//   - the extension is unsupported;
//   - a ut_metadata body is malformed.
std::unique_ptr<ExtensionMessage>
DefaultExtensionMessageFactory::createMessage(const unsigned char* data,
                                              size_t length)
{
  // The ID byte must exist before it can be read. Checking only inside the
  // ut_metadata branch, as before, happened after *data had already been
  // dereferenced.
  if (length == 0) {
    throw DL_ABORT_EX(fmt(MSG_TOO_SMALL_PAYLOAD_SIZE, "extension message",
                          static_cast<unsigned long>(length)));
  }
  uint8_t extensionMessageID = *data;
  if (extensionMessageID == 0) {
    // handshake
    auto m = HandshakeExtensionMessage::create(data, length);
    m->setPeer(peer_);
    m->setDownloadContext(dctx_);
    // std::move is kept: converting unique_ptr<Derived> to
    // unique_ptr<ExtensionMessage> on return is not an implicit move on
    // older (pre-C++14) compilers.
    return std::move(m);
  }
  const char* extensionName = registry_->getExtensionName(extensionMessageID);
  if (!extensionName) {
    throw DL_ABORT_EX(
        fmt("No extension registered for extended message ID %u",
            extensionMessageID));
  }
  if (strcmp(extensionName, "ut_pex") == 0) {
    // uTorrent compatible Peer-Exchange
    auto m = UTPexExtensionMessage::create(data, length);
    m->setPeerStorage(peerStorage_);
    return std::move(m);
  }
  if (strcmp(extensionName, "ut_metadata") == 0) {
    // length >= 1 is guaranteed by the check at the top. The dictionary is
    // decoded from the length - 1 bytes after the ID byte. 'end' receives
    // its extent relative to data + 1, as the &data[1 + end] usage below
    // relies on.
    size_t end;
    auto decoded = bencode2::decode(data + 1, length - 1, end);
    const Dict* dict = downcast<Dict>(decoded);
    if (!dict) {
      throw DL_ABORT_EX("Bad ut_metadata: dictionary not found");
    }
    const Integer* msgType = downcast<Integer>(dict->get("msg_type"));
    if (!msgType) {
      throw DL_ABORT_EX("Bad ut_metadata: msg_type not found");
    }
    const Integer* index = downcast<Integer>(dict->get("piece"));
    if (!index || index->i() < 0) {
      throw DL_ABORT_EX("Bad ut_metadata: piece not found");
    }
    switch (msgType->i()) {
    case 0: {
      auto m =
          make_unique<UTMetadataRequestExtensionMessage>(extensionMessageID);
      m->setIndex(index->i());
      m->setDownloadContext(dctx_);
      m->setPeer(peer_);
      m->setBtMessageFactory(messageFactory_);
      m->setBtMessageDispatcher(dispatcher_);
      return std::move(m);
    }
    case 1: {
      // The dictionary occupies data[1, 1 + end), so piece data exists only
      // when 1 + end < length. The former test 'end == length' could never
      // be true, because end is at most length - 1. As a result, an empty
      // data payload was accepted silently.
      if (1 + end >= length) {
        throw DL_ABORT_EX("Bad ut_metadata data: data not found");
      }
      const Integer* totalSize = downcast<Integer>(dict->get("total_size"));
      if (!totalSize || totalSize->i() < 0) {
        throw DL_ABORT_EX("Bad ut_metadata data: total_size not found");
      }
      auto m = make_unique<UTMetadataDataExtensionMessage>(extensionMessageID);
      m->setIndex(index->i());
      m->setTotalSize(totalSize->i());
      m->setData(&data[1 + end], &data[length]);
      m->setUTMetadataRequestTracker(tracker_);
      m->setPieceStorage(
          dctx_->getOwnerRequestGroup()->getPieceStorage().get());
      m->setDownloadContext(dctx_);
      return std::move(m);
    }
    case 2: {
      auto m =
          make_unique<UTMetadataRejectExtensionMessage>(extensionMessageID);
      m->setIndex(index->i());
      // No need to inject tracker because peer will be disconnected.
      return std::move(m);
    }
    default:
      throw DL_ABORT_EX(
          fmt("Bad ut_metadata: unknown msg_type=%" PRId64, msgType->i()));
    }
  }
  throw DL_ABORT_EX(fmt("Unsupported extension message received."
                        " extensionMessageID=%u, extensionName=%s",
                        extensionMessageID, extensionName));
}

// The setters below inject collaborators after construction. Each one
// stores the given pointer or handle without taking ownership of it.

void DefaultExtensionMessageFactory::setPeerStorage(PeerStorage* peerStorage)
{
  peerStorage_ = peerStorage;
}

void DefaultExtensionMessageFactory::setPeer(const std::shared_ptr<Peer>& peer)
{
  peer_ = peer;
}

void DefaultExtensionMessageFactory::setExtensionMessageRegistry(
    ExtensionMessageRegistry* registry)
{
  registry_ = registry;
}

void DefaultExtensionMessageFactory::setDownloadContext(DownloadContext* dctx)
{
  dctx_ = dctx;
}

void DefaultExtensionMessageFactory::setBtMessageFactory(
    BtMessageFactory* factory)
{
  messageFactory_ = factory;
}

void DefaultExtensionMessageFactory::setBtMessageDispatcher(
    BtMessageDispatcher* disp)
{
  dispatcher_ = disp;
}

void DefaultExtensionMessageFactory::setUTMetadataRequestTracker(
    UTMetadataRequestTracker* tracker)
{
  tracker_ = tracker;
}

} // namespace aria2