Преглед на файлове

Rewritten PeerConnection::receiveMessage()

The old implementation calls at least 2 read(2) (4bytes length and
payload) to receive the message. This change will read as many bytes
as possible in one read(2) call. BtPieceMessage::data_ is now just a
const pointer to the internal buffer of PeerConnection.
Tatsuhiro Tsujikawa преди 13 години
родител
ревизия
e816c5eee4
променени са 7 файла, в които са добавени 141 реда и са изтрити 95 реда
  1. 2 5
      src/BtPieceMessage.cc
  2. 4 5
      src/BtPieceMessage.h
  3. 3 2
      src/DefaultBtMessageReceiver.cc
  4. 109 71
      src/PeerConnection.cc
  5. 19 10
      src/PeerConnection.h
  6. 2 0
      src/SocketRecvBuffer.h
  7. 2 2
      test/PeerConnectionTest.cc

+ 2 - 5
src/BtPieceMessage.cc

@@ -73,13 +73,10 @@ BtPieceMessage::BtPieceMessage
 }
 
 BtPieceMessage::~BtPieceMessage()
-{
-  delete [] data_;
-}
+{}
 
-void BtPieceMessage::setRawMessage(unsigned char* data)
+void BtPieceMessage::setMsgPayload(const unsigned char* data)
 {
-  delete [] data_;
   data_ = data;
 }
 

+ 4 - 5
src/BtPieceMessage.h

@@ -51,7 +51,7 @@ private:
   size_t index_;
   int32_t begin_;
   int32_t blockLength_;
-  unsigned char* data_;
+  const unsigned char* data_;
   SharedHandle<DownloadContext> downloadContext_;
   SharedHandle<PeerStorage> peerStorage_;
 
@@ -87,10 +87,9 @@ public:
 
   int32_t getBlockLength() const { return blockLength_; }
 
-  // Stores raw message data. After this function call, this object
-  // has ownership of data. Caller must not be free or alter data.
-  // Member block is pointed to block starting position in data.
-  void setRawMessage(unsigned char* data);
+  // Sets message payload data. Caller must not change or free data
+  // before doReceivedAction().
+  void setMsgPayload(const unsigned char* data);
 
   void setBlockLength(int32_t blockLength) { blockLength_ = blockLength; }
 

+ 3 - 2
src/DefaultBtMessageReceiver.cc

@@ -121,12 +121,13 @@ BtMessageHandle DefaultBtMessageReceiver::receiveMessage() {
     return SharedHandle<BtMessage>();
   }
   BtMessageHandle msg =
-    messageFactory_->createBtMessage(peerConnection_->getBuffer(), dataLength);
+    messageFactory_->createBtMessage(peerConnection_->getMsgPayloadBuffer(),
+                                     dataLength);
   msg->validate();
   if(msg->getId() == BtPieceMessage::ID) {
     SharedHandle<BtPieceMessage> piecemsg =
       static_pointer_cast<BtPieceMessage>(msg);
-    piecemsg->setRawMessage(peerConnection_->detachBuffer());
+    piecemsg->setMsgPayload(peerConnection_->getMsgPayloadBuffer());
   }
   return msg;
 }

+ 109 - 71
src/PeerConnection.cc

@@ -52,16 +52,29 @@
 
 namespace aria2 {
 
+namespace {
+enum {
+  // Before reading first byte of message length
+  BT_MSG_PREV_READ_LENGTH,
+  // Reading 4 bytes message length
+  BT_MSG_READ_LENGTH,
+  // Reading message payload following message length
+  BT_MSG_READ_PAYLOAD
+};
+} // namespace
+
 PeerConnection::PeerConnection
 (cuid_t cuid, const SharedHandle<Peer>& peer, const SocketHandle& socket)
   : cuid_(cuid),
     peer_(peer),
     socket_(socket),
-    maxPayloadLength_(MAX_PAYLOAD_LEN),
-    resbuf_(new unsigned char[maxPayloadLength_]),
+    msgState_(BT_MSG_PREV_READ_LENGTH),
+    bufferCapacity_(MAX_BUFFER_CAPACITY),
+    resbuf_(new unsigned char[bufferCapacity_]),
     resbufLength_(0),
     currentPayloadLength_(0),
-    lenbufLength_(0),
+    resbufOffset_(0),
+    msgOffset_(0),
     socketBuffer_(socket),
     encryptionEnabled_(false),
     prevPeek_(false)
@@ -80,71 +93,98 @@ void PeerConnection::pushBytes(unsigned char* data, size_t len)
   socketBuffer_.pushBytes(data, len);
 }
 
-bool PeerConnection::receiveMessage(unsigned char* data, size_t& dataLength) {
-  if(resbufLength_ == 0 && 4 > lenbufLength_) {
-    // read payload size, 32bit unsigned integer
-    size_t remaining = 4-lenbufLength_;
-    size_t temp = remaining;
-    readData(lenbuf_+lenbufLength_, remaining, encryptionEnabled_);
-    if(remaining == 0) {
-      if(socket_->wantRead() || socket_->wantWrite()) {
-        return false;
+bool PeerConnection::receiveMessage(unsigned char* data, size_t& dataLength)
+{
+  while(1) {
+    bool done = false;
+    size_t i;
+    for(i = resbufOffset_; i < resbufLength_ && !done; ++i) {
+      unsigned char c = resbuf_[i];
+      switch(msgState_) {
+      case(BT_MSG_PREV_READ_LENGTH):
+        msgOffset_ = i;
+        currentPayloadLength_ = 0;
+        msgState_ = BT_MSG_READ_LENGTH;
+        // Fall through
+      case(BT_MSG_READ_LENGTH):
+        currentPayloadLength_ <<= 8;
+        currentPayloadLength_ += c;
+        // The message length is uint32_t
+        if(i - msgOffset_ == 3) {
+          if(currentPayloadLength_ + 4 > bufferCapacity_) {
+            throw DL_ABORT_EX(fmt(EX_TOO_LONG_PAYLOAD, currentPayloadLength_));
+          }
+          if(currentPayloadLength_ == 0) {
+            // Length == 0 means keep-alive message.
+            done = true;
+            msgState_ = BT_MSG_PREV_READ_LENGTH;
+          } else {
+            msgState_ = BT_MSG_READ_PAYLOAD;
+          }
+        }
+        break;
+      case(BT_MSG_READ_PAYLOAD):
+        // We chosen the bufferCapacity_ so that whole message,
+        // including 4 bytes length and payload, in it. So here we
+        // just make sure that it happens.
+        if(resbufLength_ - msgOffset_ >= 4 + currentPayloadLength_) {
+          i = msgOffset_ + 4 + currentPayloadLength_ - 1;
+          done = true;
+          msgState_ = BT_MSG_PREV_READ_LENGTH;
+        } else {
+          // We need another read.
+          i = resbufLength_-1;
+        }
+        break;
       }
-      // we got EOF
-      A2_LOG_DEBUG(fmt("CUID#%lld - In PeerConnection::receiveMessage(),"
-                       " remain=%lu",
-                       cuid_,
-                       static_cast<unsigned long>(temp)));
-      peer_->setDisconnectedGracefully(true);
-      throw DL_ABORT_EX(EX_EOF_FROM_PEER);
     }
-    lenbufLength_ += remaining;
-    if(4 > lenbufLength_) {
-      // still 4-lenbufLength_ bytes to go
-      return false;
-    }
-    uint32_t payloadLength;
-    memcpy(&payloadLength, lenbuf_, sizeof(payloadLength));
-    payloadLength = ntohl(payloadLength);
-    if(payloadLength > maxPayloadLength_) {
-      throw DL_ABORT_EX(fmt(EX_TOO_LONG_PAYLOAD, payloadLength));
-    }
-    currentPayloadLength_ = payloadLength;
-  }
-  if(!socket_->isReadable(0)) {
-    return false;
-  }
-  // we have currentPayloadLen-resbufLen bytes to read
-  size_t remaining = currentPayloadLength_-resbufLength_;
-  size_t temp = remaining;
-  if(remaining > 0) {
-    readData(resbuf_+resbufLength_, remaining, encryptionEnabled_);
-    if(remaining == 0) {
-      if(socket_->wantRead() || socket_->wantWrite()) {
-        return false;
+    resbufOffset_ = i;
+    if(done) {
+      if(data) {
+        memcpy(data, resbuf_ + msgOffset_ + 4, currentPayloadLength_);
+      }
+      dataLength = currentPayloadLength_;
+      return true;
+    } else {
+      assert(resbufOffset_ == resbufLength_);
+      if(resbufLength_ != 0) {
+        if(msgOffset_ == 0 && resbufLength_ == currentPayloadLength_ + 4) {
+          // All bytes in buffer have been processed, so clear it
+          // away.
+          resbufLength_ = 0;
+          resbufOffset_ = 0;
+          msgOffset_ = 0;
+        } else {
+          // Shift buffer so that resbuf_[msgOffset_] moves to
+          // rebuf_[0].
+          memmove(resbuf_, resbuf_ + msgOffset_, resbufLength_ - msgOffset_);
+          resbufLength_ -= msgOffset_;
+          resbufOffset_ = resbufLength_;
+          msgOffset_ = 0;
+        }
+      }
+      size_t nread;
+      // To reduce the amount of copy involved in buffer shift, large
+      // payload will be read exactly.
+      if(currentPayloadLength_ > 4096) {
+        nread = currentPayloadLength_ + 4 - resbufLength_;
+      } else {
+        nread = bufferCapacity_ - resbufLength_;
+      }
+      readData(resbuf_+resbufLength_, nread, encryptionEnabled_);
+      if(nread == 0) {
+        if(socket_->wantRead() || socket_->wantWrite()) {
+          break;
+        } else {
+          peer_->setDisconnectedGracefully(true);
+          throw DL_ABORT_EX(EX_EOF_FROM_PEER);
+        }
+      } else {
+        resbufLength_ += nread;
       }
-      // we got EOF
-      A2_LOG_DEBUG(fmt("CUID#%lld - In PeerConnection::receiveMessage(),"
-                       " payloadlen=%lu, remaining=%lu",
-                       cuid_,
-                       static_cast<unsigned long>(currentPayloadLength_),
-                       static_cast<unsigned long>(temp)));
-      peer_->setDisconnectedGracefully(true);
-      throw DL_ABORT_EX(EX_EOF_FROM_PEER);
-    }
-    resbufLength_ += remaining;
-    if(currentPayloadLength_ > resbufLength_) {
-      return false;
     }
   }
-  // we got whole payload.
-  resbufLength_ = 0;
-  lenbufLength_ = 0;
-  if(data) {
-    memcpy(data, resbuf_, currentPayloadLength_);
-  }
-  dataLength = currentPayloadLength_;
-  return true;
+  return false;
 }
 
 bool PeerConnection::receiveHandshake(unsigned char* data, size_t& dataLength,
@@ -202,7 +242,7 @@ void PeerConnection::enableEncryption
 
 void PeerConnection::presetBuffer(const unsigned char* data, size_t length)
 {
-  size_t nwrite = std::min(maxPayloadLength_, length);
+  size_t nwrite = std::min(bufferCapacity_, length);
   memcpy(resbuf_, data, nwrite);
   resbufLength_ = length;
 }
@@ -219,18 +259,16 @@ ssize_t PeerConnection::sendPendingData()
   return writtenLength;
 }
 
-unsigned char* PeerConnection::detachBuffer()
+const unsigned char* PeerConnection::getMsgPayloadBuffer() const
 {
-  unsigned char* detachbuf = resbuf_;
-  resbuf_ = new unsigned char[maxPayloadLength_];
-  return detachbuf;
+  return resbuf_ + msgOffset_ + 4;
 }
 
 void PeerConnection::reserveBuffer(size_t minSize)
 {
-  if(maxPayloadLength_ < minSize) {
-    maxPayloadLength_ = minSize;
-    unsigned char *buf = new unsigned char[maxPayloadLength_];
+  if(bufferCapacity_ < minSize) {
+    bufferCapacity_ = minSize;
+    unsigned char *buf = new unsigned char[bufferCapacity_];
     memcpy(buf, resbuf_, resbufLength_);
     delete [] resbuf_;
     resbuf_ = buf;

+ 19 - 10
src/PeerConnection.h

@@ -49,9 +49,10 @@ class Peer;
 class SocketCore;
 class ARC4Encryptor;
 
-// The maximum length of payload. Messages beyond that length are
+// The maximum length of buffer. If the message length (including 4
+// bytes length and payload length) is larger than this value, it is
 // dropped.
-#define MAX_PAYLOAD_LEN (16*1024+128)
+#define MAX_BUFFER_CAPACITY (16*1024+128)
 
 class PeerConnection {
 private:
@@ -59,13 +60,19 @@ private:
   SharedHandle<Peer> peer_;
   SharedHandle<SocketCore> socket_;
 
-  // Maximum payload length
-  size_t maxPayloadLength_;
+  int msgState_;
+  // The capacity of the buffer resbuf_
+  size_t bufferCapacity_;
+  // The internal buffer of incoming handshakes and messages
   unsigned char* resbuf_;
+  // The number of bytes written in resbuf_
   size_t resbufLength_;
-  size_t currentPayloadLength_;
-  unsigned char lenbuf_[4];
-  size_t lenbufLength_;
+  // The length of message (not handshake) currently receiving
+  uint32_t currentPayloadLength_;
+  // The number of bytes processed in resbuf_
+  size_t resbufOffset_;
+  // The offset in resbuf_ where the 4 bytes message length begins
+  size_t msgOffset_;
 
   SocketBuffer socketBuffer_;
 
@@ -123,15 +130,17 @@ public:
     return resbufLength_;
   }
 
-  unsigned char* detachBuffer();
+  // Returns the pointer to the message in wire format.  This method
+  // must be called after receiveMessage() returned true.
+  const unsigned char* getMsgPayloadBuffer() const;
 
   // Reserves buffer at least minSize. Reallocate memory if current
   // buffer length < minSize
   void reserveBuffer(size_t minSize);
 
-  size_t getMaxPayloadLength()
+  size_t getBufferCapacity()
   {
-    return maxPayloadLength_;
+    return bufferCapacity_;
   }
 };
 

+ 2 - 0
src/SocketRecvBuffer.h

@@ -79,6 +79,8 @@ public:
   {
     return bufLen_ == 0;
   }
+
+  void pushBuffer(const unsigned char* data, size_t len);
 private:
   SharedHandle<SocketCore> socket_;
   size_t capacity_;

+ 2 - 2
test/PeerConnectionTest.cc

@@ -23,13 +23,13 @@ CPPUNIT_TEST_SUITE_REGISTRATION(PeerConnectionTest);
 void PeerConnectionTest::testReserveBuffer() {
   PeerConnection con(1, SharedHandle<Peer>(), SharedHandle<SocketCore>());
   con.presetBuffer((unsigned char*)"foo", 3);
-  CPPUNIT_ASSERT_EQUAL((size_t)MAX_PAYLOAD_LEN, con.getMaxPayloadLength());
+  CPPUNIT_ASSERT_EQUAL((size_t)MAX_BUFFER_CAPACITY, con.getBufferCapacity());
   CPPUNIT_ASSERT_EQUAL((size_t)3, con.getBufferLength());
 
   size_t newLength = 32*1024;
   con.reserveBuffer(newLength);
 
-  CPPUNIT_ASSERT_EQUAL(newLength, con.getMaxPayloadLength());
+  CPPUNIT_ASSERT_EQUAL(newLength, con.getBufferCapacity());
   CPPUNIT_ASSERT_EQUAL((size_t)3, con.getBufferLength());
   CPPUNIT_ASSERT(memcmp("foo", con.getBuffer(), 3) == 0);
 }