/* */
#include "DefaultBtMessageReceiver.h"

#include <cstring>

#include "BtHandshakeMessage.h"
#include "message.h"
#include "DownloadContext.h"
#include "Peer.h"
#include "PeerConnection.h"
#include "BtMessageDispatcher.h"
#include "BtMessageFactory.h"
#include "Logger.h"
#include "LogFactory.h"
#include "bittorrent_helper.h"
#include "BtPieceMessage.h"
#include "util.h"
#include "fmt.h"
#include "DlAbortEx.h"

namespace aria2 {

namespace {

// Byte offset of the info_hash field inside a BitTorrent handshake
// (1-byte pstrlen + 19-byte "BitTorrent protocol" + 8 reserved bytes, BEP 3).
// The info_hash occupies [HANDSHAKE_INFO_HASH_OFFSET,
// HANDSHAKE_INFO_HASH_OFFSET + INFO_HASH_LENGTH).
constexpr size_t HANDSHAKE_INFO_HASH_OFFSET = 28;

// Pulls a handshake out of peerConnection and, if one was received, builds it
// through factory and validates it (validate() may throw). Returns nullptr when
// PeerConnection::receiveHandshake() reports that no handshake was read.
std::unique_ptr<BtHandshakeMessage>
receiveValidatedHandshake(PeerConnection* peerConnection,
                          BtMessageFactory* factory)
{
  unsigned char data[BtHandshakeMessage::MESSAGE_LENGTH];
  size_t dataLength = BtHandshakeMessage::MESSAGE_LENGTH;
  if (!peerConnection->receiveHandshake(data, dataLength)) {
    return nullptr;
  }
  auto msg = factory->createHandshakeMessage(data, dataLength);
  msg->validate();
  return msg;
}

} // namespace

// Creates a receiver with no collaborators attached. The receive/send methods
// dereference these pointers without null checks, so the setters at the end of
// this file must be called before use.
DefaultBtMessageReceiver::DefaultBtMessageReceiver()
    : handshakeSent_{false},
      downloadContext_{nullptr},
      peerConnection_{nullptr},
      dispatcher_{nullptr},
      messageFactory_{nullptr}
{
}

// Tries to receive the peer's handshake.
//
// quickReply: when true, we have not sent our handshake yet, and at least the
// info_hash bytes are already buffered, our handshake is sent immediately
// (tracker NAT-check support) before the peer's handshake is fully received.
//
// Returns the validated handshake message, or nullptr if it is not available
// yet. Throws DL_ABORT_EX if the buffered info_hash does not match this
// download's info_hash.
std::unique_ptr<BtHandshakeMessage>
DefaultBtMessageReceiver::receiveHandshake(bool quickReply)
{
  A2_LOG_DEBUG(
      fmt("Receiving handshake bufferLength=%lu",
          static_cast<unsigned long>(peerConnection_->getBufferLength())));
  // The quick-reply path inspects the info_hash in the buffer, so it needs at
  // least HANDSHAKE_INFO_HASH_OFFSET + INFO_HASH_LENGTH bytes buffered.
  if (handshakeSent_ || !quickReply ||
      peerConnection_->getBufferLength() <
          HANDSHAKE_INFO_HASH_OFFSET + INFO_HASH_LENGTH) {
    return receiveValidatedHandshake(peerConnection_, messageFactory_);
  }

  // Handle tracker's NAT-checking feature
  handshakeSent_ = true;
  // check info_hash
  auto peerInfoHash = peerConnection_->getBuffer() + HANDSHAKE_INFO_HASH_OFFSET;
  if (std::memcmp(bittorrent::getInfoHash(downloadContext_), peerInfoHash,
                  INFO_HASH_LENGTH) != 0) {
    throw DL_ABORT_EX(
        fmt("Bad Info Hash %s",
            util::toHex(peerInfoHash, INFO_HASH_LENGTH).c_str()));
  }
  sendHandshake();
  // Only consume the peer's handshake now if it has fully arrived; otherwise a
  // later call (with handshakeSent_ set) will take the regular path above.
  if (peerConnection_->getBufferLength() ==
      BtHandshakeMessage::MESSAGE_LENGTH) {
    return receiveValidatedHandshake(peerConnection_, messageFactory_);
  }
  return nullptr;
}

// Equivalent to receiveHandshake(true): allows the quick handshake reply.
std::unique_ptr<BtHandshakeMessage>
DefaultBtMessageReceiver::receiveAndSendHandshake()
{
  return receiveHandshake(true);
}

// Queues our handshake (this download's info_hash plus the static peer ID) on
// the dispatcher and flushes the dispatcher's send queue.
void DefaultBtMessageReceiver::sendHandshake()
{
  dispatcher_->addMessageToQueue(messageFactory_->createHandshakeMessage(
      bittorrent::getInfoHash(downloadContext_), bittorrent::getStaticPeerId()));
  dispatcher_->sendMessages();
}

// Receives one BitTorrent message from the peer connection.
// Returns the validated message, or nullptr if no complete message is
// available. validate() may throw on a malformed message.
std::unique_ptr<BtMessage> DefaultBtMessageReceiver::receiveMessage()
{
  size_t dataLength = 0;
  // Pass nullptr to PeerConnection::receiveMessage() to prevent memcpy; the
  // payload is read in place via getMsgPayloadBuffer() below.
  if (!peerConnection_->receiveMessage(nullptr, dataLength)) {
    return nullptr;
  }
  auto msg = messageFactory_->createBtMessage(
      peerConnection_->getMsgPayloadBuffer(), dataLength);
  msg->validate();
  if (msg->getId() == BtPieceMessage::ID) {
    // NOTE(review): presumably lets the piece message reference the block data
    // in the connection's payload buffer instead of copying it — confirm.
    auto piecemsg = static_cast<BtPieceMessage*>(msg.get());
    piecemsg->setMsgPayload(peerConnection_->getMsgPayloadBuffer());
  }
  return msg;
}

// Setters for the collaborators. Each stores the raw pointer as-is; ownership
// presumably stays with the caller (confirm against the header).
void DefaultBtMessageReceiver::setDownloadContext(
    DownloadContext* downloadContext)
{
  downloadContext_ = downloadContext;
}

void DefaultBtMessageReceiver::setPeerConnection(PeerConnection* peerConnection)
{
  peerConnection_ = peerConnection;
}

void DefaultBtMessageReceiver::setDispatcher(BtMessageDispatcher* dispatcher)
{
  dispatcher_ = dispatcher;
}

void DefaultBtMessageReceiver::setBtMessageFactory(BtMessageFactory* factory)
{
  messageFactory_ = factory;
}

} // namespace aria2