/* */
#include "DefaultExtensionMessageFactory.h"
#include "Peer.h"
#include "DlAbortEx.h"
#include "HandshakeExtensionMessage.h"
#include "UTPexExtensionMessage.h"
#include "LogFactory.h"
#include "Logger.h"
#include "StringFormat.h"
#include "PeerStorage.h"
#include "ExtensionMessageRegistry.h"
#include "DownloadContext.h"
#include "BtMessageDispatcher.h"
#include "BtMessageFactory.h"
#include "UTMetadataRequestExtensionMessage.h"
#include "UTMetadataDataExtensionMessage.h"
#include "UTMetadataRejectExtensionMessage.h"
#include "message.h"
#include "bencode.h"
#include "PieceStorage.h"
#include "UTMetadataRequestTracker.h"
#include "RequestGroup.h"

namespace aria2 {

// Default constructor: only the logger is initialized here; the peer,
// registry and other collaborators are expected to be injected later.
DefaultExtensionMessageFactory::DefaultExtensionMessageFactory():
  _logger(LogFactory::getInstance()) {}

// Constructs a factory bound to `peer`, resolving non-handshake extended
// message IDs through `registry`.
DefaultExtensionMessageFactory::DefaultExtensionMessageFactory
(const PeerHandle& peer,
 const SharedHandle<ExtensionMessageRegistry>& registry):
  _peer(peer),
  _registry(registry),
  _logger(LogFactory::getInstance()) {}

DefaultExtensionMessageFactory::~DefaultExtensionMessageFactory() {}

// Builds an ExtensionMessage from the payload of a BitTorrent extended
// message.
//
// data[0] is the extended message ID; the remaining length-1 bytes are
// the message body.
//   - ID 0 is the extension handshake.
//   - Any other ID is mapped to an extension name via _registry.
//     "ut_pex" and "ut_metadata" are supported.
//
// For "ut_metadata" the body starts with a bencoded dictionary holding
// integer "msg_type" and "piece" entries. The msg_type values are:
//   0 - request
//   1 - data: also needs "total_size", and the metadata piece follows
//       the dictionary
//   2 - reject
//
// Throws DlAbortEx (via DL_ABORT_EX) when:
//   - the payload is empty,
//   - the ID is not registered,
//   - the extension is unsupported, or
//   - a ut_metadata body is malformed.
ExtensionMessageHandle
DefaultExtensionMessageFactory::createMessage(const unsigned char* data,
                                              size_t length)
{
  // The ID byte must exist before it can be read; this also guarantees
  // that `length-1` below cannot underflow.
  if(length == 0) {
    throw DL_ABORT_EX
      (StringFormat(MSG_TOO_SMALL_PAYLOAD_SIZE, "extended", length).str());
  }
  const uint8_t extensionMessageID = *data;
  if(extensionMessageID == 0) {
    // handshake
    HandshakeExtensionMessageHandle m =
      HandshakeExtensionMessage::create(data, length);
    m->setPeer(_peer);
    m->setDownloadContext(_dctx);
    return m;
  } else {
    const std::string extensionName =
      _registry->getExtensionName(extensionMessageID);
    if(extensionName.empty()) {
      throw DL_ABORT_EX
        (StringFormat("No extension registered for extended message ID %u",
                      extensionMessageID).str());
    }
    if(extensionName == "ut_pex") {
      // uTorrent compatible Peer-Exchange
      UTPexExtensionMessageHandle m =
        UTPexExtensionMessage::create(data, length);
      m->setPeerStorage(_peerStorage);
      return m;
    } else if(extensionName == "ut_metadata") {
      // `end` is filled in by the decoder and is relative to data+1 (the
      // start of the body), so bytes after the dictionary begin at
      // data[1+end].
      size_t end = 0;
      BDE dict = bencode::decode(data+1, length-1, end);
      if(!dict.isDict()) {
        throw DL_ABORT_EX("Bad ut_metadata: dictionary not found");
      }
      const BDE& msgType = dict["msg_type"];
      if(!msgType.isInteger()) {
        throw DL_ABORT_EX("Bad ut_metadata: msg_type not found");
      }
      const BDE& index = dict["piece"];
      if(!index.isInteger()) {
        throw DL_ABORT_EX("Bad ut_metadata: piece not found");
      }
      // NOTE(review): the piece index (and total_size below) are taken
      // from the peer without a sign/range check here; confirm that
      // setIndex()/setTotalSize() or the message handlers reject
      // negative or out-of-range values.
      switch(msgType.i()) {
      case 0: {
        SharedHandle<UTMetadataRequestExtensionMessage> m
          (new UTMetadataRequestExtensionMessage(extensionMessageID));
        m->setIndex(index.i());
        m->setDownloadContext(_dctx);
        m->setPeer(_peer);
        m->setBtMessageFactory(_messageFactory);
        m->setBtMessageDispatcher(_dispatcher);
        return m;
      }
      case 1: {
        // The metadata piece occupies data[1+end] .. data[length-1].
        // That range is empty when end == length-1, and end can never
        // equal length because it indexes the (length-1)-byte body.
        // Testing `end >= length-1` therefore catches a missing piece
        // as well as an out-of-range end.
        if(end >= length-1) {
          throw DL_ABORT_EX("Bad ut_metadata data: data not found");
        }
        const BDE& totalSize = dict["total_size"];
        if(!totalSize.isInteger()) {
          throw DL_ABORT_EX("Bad ut_metadata data: total_size not found");
        }
        SharedHandle<UTMetadataDataExtensionMessage> m
          (new UTMetadataDataExtensionMessage(extensionMessageID));
        m->setIndex(index.i());
        m->setTotalSize(totalSize.i());
        m->setData(std::string(data+1+end, data+length));
        m->setUTMetadataRequestTracker(_tracker);
        m->setPieceStorage(_dctx->getOwnerRequestGroup()->getPieceStorage());
        m->setDownloadContext(_dctx);
        return m;
      }
      case 2: {
        SharedHandle<UTMetadataRejectExtensionMessage> m
          (new UTMetadataRejectExtensionMessage(extensionMessageID));
        m->setIndex(index.i());
        // No need to inject tracker because peer will be disconnected.
        return m;
      }
      default:
        // msg_type is a bencoded integer of unspecified width; format it
        // as long long rather than %u to avoid a varargs type mismatch.
        throw DL_ABORT_EX
          (StringFormat("Bad ut_metadata: unknown msg_type=%lld",
                        static_cast<long long>(msgType.i())).str());
      }
    } else {
      throw DL_ABORT_EX
        (StringFormat("Unsupported extension message received."
                      " extensionMessageID=%u, extensionName=%s",
                      extensionMessageID, extensionName.c_str()).str());
    }
  }
}

// Sets the PeerStorage that createMessage() injects into "ut_pex"
// messages.
void DefaultExtensionMessageFactory::setPeerStorage
(const SharedHandle<PeerStorage>& peerStorage)
{
  _peerStorage = peerStorage;
}

// Sets the peer that createMessage() attaches to handshake and
// ut_metadata request messages.
void DefaultExtensionMessageFactory::setPeer(const PeerHandle& peer)
{
  _peer = peer;
}

} // namespace aria2