Просмотр исходного кода

Eliminated SocketCore::peekData from MSEHandshake.

Tatsuhiro Tsujikawa 14 лет назад
Родитель
Сommit
ce2d401dce

+ 91 - 64
src/InitiatorMSEHandshakeCommand.cc

@@ -56,6 +56,7 @@
 #include "bittorrent_helper.h"
 #include "util.h"
 #include "fmt.h"
+#include "array_fun.h"
 
 namespace aria2 {
 
@@ -89,82 +90,108 @@ InitiatorMSEHandshakeCommand::~InitiatorMSEHandshakeCommand()
 }
 
 bool InitiatorMSEHandshakeCommand::executeInternal() {
-  switch(sequence_) {
-  case INITIATOR_SEND_KEY: {
-    if(!getSocket()->isWritable(0)) {
-      break;
-    }
-    disableWriteCheckSocket();
-    setReadCheckSocket(getSocket());
-    //socket->setBlockingMode();
-    setTimeout(getOption()->getAsInt(PREF_BT_TIMEOUT));
-    mseHandshake_->initEncryptionFacility(true);
-    if(mseHandshake_->sendPublicKey()) {
-      sequence_ = INITIATOR_WAIT_KEY;
-    } else {
-      setWriteCheckSocket(getSocket());
+  if(mseHandshake_->getWantRead()) {
+    mseHandshake_->read();
+  }
+  bool done = false;
+  while(!done) {
+    switch(sequence_) {
+    case INITIATOR_SEND_KEY: {
+      if(!getSocket()->isWritable(0)) {
+        getDownloadEngine()->addCommand(this);
+        return false;
+      }
+      setTimeout(getOption()->getAsInt(PREF_BT_TIMEOUT));
+      mseHandshake_->initEncryptionFacility(true);
+      mseHandshake_->sendPublicKey();
       sequence_ = INITIATOR_SEND_KEY_PENDING;
+      break;
     }
-    break;
-  }
-  case INITIATOR_SEND_KEY_PENDING:
-    if(mseHandshake_->sendPublicKey()) {
-      disableWriteCheckSocket();
-      sequence_ = INITIATOR_WAIT_KEY;
+    case INITIATOR_SEND_KEY_PENDING:
+      if(mseHandshake_->send()) {
+        sequence_ = INITIATOR_WAIT_KEY;
+      } else {
+        done = true;
+      }
+      break;
+    case INITIATOR_WAIT_KEY: {
+      if(mseHandshake_->receivePublicKey()) {
+        mseHandshake_->initCipher
+          (bittorrent::getInfoHash(requestGroup_->getDownloadContext()));;
+        mseHandshake_->sendInitiatorStep2();
+        sequence_ = INITIATOR_SEND_STEP2_PENDING;
+      } else {
+        done = true;
+      }
+      break;
     }
-    break;
-  case INITIATOR_WAIT_KEY: {
-    if(mseHandshake_->receivePublicKey()) {
-      mseHandshake_->initCipher
-        (bittorrent::getInfoHash(requestGroup_->getDownloadContext()));;
-      if(mseHandshake_->sendInitiatorStep2()) {
+    case INITIATOR_SEND_STEP2_PENDING:
+      if(mseHandshake_->send()) {
         sequence_ = INITIATOR_FIND_VC_MARKER;
       } else {
-        setWriteCheckSocket(getSocket());
-        sequence_ = INITIATOR_SEND_STEP2_PENDING;
+        done = true;
+      }
+      break;
+    case INITIATOR_FIND_VC_MARKER: {
+      if(mseHandshake_->findInitiatorVCMarker()) {
+        sequence_ = INITIATOR_RECEIVE_PAD_D_LENGTH;
+      } else {
+        done = true;
       }
+      break;
     }
-    break;
-  }
-  case INITIATOR_SEND_STEP2_PENDING:
-    if(mseHandshake_->sendInitiatorStep2()) {
-      disableWriteCheckSocket();
-      sequence_ = INITIATOR_FIND_VC_MARKER;
+    case INITIATOR_RECEIVE_PAD_D_LENGTH: {
+      if(mseHandshake_->receiveInitiatorCryptoSelectAndPadDLength()) {
+        sequence_ = INITIATOR_RECEIVE_PAD_D;
+      } else {
+        done = true;
+      }
+      break;
     }
-    break;
-  case INITIATOR_FIND_VC_MARKER: {
-    if(mseHandshake_->findInitiatorVCMarker()) {
-      sequence_ = INITIATOR_RECEIVE_PAD_D_LENGTH;
+    case INITIATOR_RECEIVE_PAD_D: {
+      if(mseHandshake_->receivePad()) {
+        SharedHandle<PeerConnection> peerConnection
+          (new PeerConnection(getCuid(), getPeer(), getSocket()));
+        if(mseHandshake_->getNegotiatedCryptoType() ==
+           MSEHandshake::CRYPTO_ARC4){
+          peerConnection->enableEncryption(mseHandshake_->getEncryptor(),
+                                           mseHandshake_->getDecryptor());
+          size_t buflen = mseHandshake_->getBufferLength();
+          array_ptr<unsigned char> buffer(new unsigned char[buflen]);
+          mseHandshake_->getDecryptor()->decrypt(buffer, buflen,
+                                                 mseHandshake_->getBuffer(),
+                                                 buflen);
+          peerConnection->presetBuffer(buffer, buflen);
+        } else {
+          peerConnection->presetBuffer(mseHandshake_->getBuffer(),
+                                       mseHandshake_->getBufferLength());
+        }
+        PeerInteractionCommand* c =
+          new PeerInteractionCommand
+          (getCuid(), requestGroup_, getPeer(), getDownloadEngine(), btRuntime_,
+           pieceStorage_,
+           peerStorage_,
+           getSocket(),
+           PeerInteractionCommand::INITIATOR_SEND_HANDSHAKE,
+           peerConnection);
+        getDownloadEngine()->addCommand(c);
+        return true;
+      } else {
+        done = true;
+      }
+      break;
     }
-    break;
-  }
-  case INITIATOR_RECEIVE_PAD_D_LENGTH: {
-    if(mseHandshake_->receiveInitiatorCryptoSelectAndPadDLength()) {
-      sequence_ = INITIATOR_RECEIVE_PAD_D;
     }
-    break;
   }
-  case INITIATOR_RECEIVE_PAD_D: {
-    if(mseHandshake_->receivePad()) {
-      SharedHandle<PeerConnection> peerConnection
-        (new PeerConnection(getCuid(), getPeer(), getSocket()));
-      if(mseHandshake_->getNegotiatedCryptoType() == MSEHandshake::CRYPTO_ARC4){
-        peerConnection->enableEncryption(mseHandshake_->getEncryptor(),
-                                         mseHandshake_->getDecryptor());
-      }
-      PeerInteractionCommand* c =
-        new PeerInteractionCommand
-        (getCuid(), requestGroup_, getPeer(), getDownloadEngine(), btRuntime_,
-         pieceStorage_,
-         peerStorage_,
-         getSocket(),
-         PeerInteractionCommand::INITIATOR_SEND_HANDSHAKE,
-         peerConnection);
-      getDownloadEngine()->addCommand(c);
-      return true;
-    }
-    break;
+  if(mseHandshake_->getWantRead()) {
+    setReadCheckSocket(getSocket());
+  } else {
+    disableReadCheckSocket();
   }
+  if(mseHandshake_->getWantWrite()) {
+    setWriteCheckSocket(getSocket());
+  } else {
+    disableWriteCheckSocket();
   }
   getDownloadEngine()->addCommand(this);
   return false;

+ 159 - 202
src/MSEHandshake.cc

@@ -71,6 +71,7 @@ MSEHandshake::MSEHandshake
  const Option* op)
   : cuid_(cuid),
     socket_(socket),
+    wantRead_(false),
     option_(op),
     rbufLength_(0),
     socketBuffer_(socket),
@@ -92,16 +93,8 @@ MSEHandshake::~MSEHandshake()
 
 MSEHandshake::HANDSHAKE_TYPE MSEHandshake::identifyHandshakeType()
 {
-  if(!socket_->isReadable(0)) {
-    return HANDSHAKE_NOT_YET;
-  }
-  size_t r = 20-rbufLength_;
-  socket_->readData(rbuf_+rbufLength_, r);
-  if(r == 0 && !socket_->wantRead() && !socket_->wantWrite()) {
-    throw DL_ABORT_EX(EX_EOF_FROM_PEER);
-  }
-  rbufLength_ += r;
   if(rbufLength_ < 20) {
+    wantRead_ = true;
     return HANDSHAKE_NOT_YET;
   }
   if(rbuf_[0] == BtHandshakeMessage::PSTR_LENGTH &&
@@ -126,35 +119,59 @@ void MSEHandshake::initEncryptionFacility(bool initiator)
   initiator_ = initiator;
 }
 
-bool MSEHandshake::sendPublicKey()
+void MSEHandshake::sendPublicKey()
 {
-  if(socketBuffer_.sendBufferIsEmpty()) {
-    A2_LOG_DEBUG(fmt("CUID#%lld - Sending public key.",
-                     cuid_));
-    unsigned char buffer[KEY_LENGTH+MAX_PAD_LENGTH];
-    dh_->getPublicKey(buffer, KEY_LENGTH);
+  A2_LOG_DEBUG(fmt("CUID#%lld - Sending public key.",
+                   cuid_));
+  unsigned char buffer[KEY_LENGTH+MAX_PAD_LENGTH];
+  dh_->getPublicKey(buffer, KEY_LENGTH);
+
+  size_t padLength =
+    SimpleRandomizer::getInstance()->getRandomNumber(MAX_PAD_LENGTH+1);
+  dh_->generateNonce(buffer+KEY_LENGTH, padLength);
+  socketBuffer_.pushStr(std::string(&buffer[0],
+                                    &buffer[KEY_LENGTH+padLength]));
+}
 
-    size_t padLength = SimpleRandomizer::getInstance()->getRandomNumber(MAX_PAD_LENGTH+1);
-    dh_->generateNonce(buffer+KEY_LENGTH, padLength);
-    socketBuffer_.pushStr(std::string(&buffer[0],
-                                      &buffer[KEY_LENGTH+padLength]));
+void MSEHandshake::read()
+{
+  if(rbufLength_ >= MAX_BUFFER_LENGTH) {
+    assert(!wantRead_);
+    return;
+  }
+  size_t len = MAX_BUFFER_LENGTH-rbufLength_;
+  socket_->readData(rbuf_+rbufLength_, len);
+  if(len == 0  && !socket_->wantRead() && !socket_->wantWrite()) {
+    // TODO Should we set graceful in peer?
+    throw DL_ABORT_EX(EX_EOF_FROM_PEER);
   }
+  rbufLength_ += len;
+  wantRead_ = false;
+}
+
+bool MSEHandshake::send()
+{
   socketBuffer_.send();
   return socketBuffer_.sendBufferIsEmpty();
 }
 
+void MSEHandshake::shiftBuffer(size_t offset)
+{
+  memmove(rbuf_, rbuf_+offset, rbufLength_-offset);
+  rbufLength_ -= offset;
+}
+
 bool MSEHandshake::receivePublicKey()
 {
-  size_t r = KEY_LENGTH-rbufLength_;
-  if(r > receiveNBytes(r)) {
+  if(rbufLength_ < KEY_LENGTH) {
+    wantRead_ = true;
     return false;
   }
-  A2_LOG_DEBUG(fmt("CUID#%lld - public key received.",
-                   cuid_));
+  A2_LOG_DEBUG(fmt("CUID#%lld - public key received.", cuid_));
   // TODO handle exception. in catch, resbufLength = 0;
-  dh_->computeSecret(secret_, sizeof(secret_), rbuf_, rbufLength_);
-  // reset rbufLength_
-  rbufLength_ = 0;
+  dh_->computeSecret(secret_, sizeof(secret_), rbuf_, KEY_LENGTH);
+  // shift buffer
+  shiftBuffer(KEY_LENGTH);
   return true;
 }
 
@@ -251,109 +268,83 @@ uint16_t MSEHandshake::decodeLength16(const unsigned char* buffer)
   return ntohs(be);
 }
 
-bool MSEHandshake::sendInitiatorStep2()
+void MSEHandshake::sendInitiatorStep2()
 {
-  if(socketBuffer_.sendBufferIsEmpty()) {
-    A2_LOG_DEBUG(fmt("CUID#%lld - Sending negotiation step2.",
-                     cuid_));
-    unsigned char md[20];
-    createReq1Hash(md);
-    socketBuffer_.pushStr(std::string(&md[0], &md[sizeof(md)]));
-
-    createReq23Hash(md, infoHash_);
-    socketBuffer_.pushStr(std::string(&md[0], &md[sizeof(md)]));
-
-    {
-      // buffer is filled in this order:
-      //   VC(VC_LENGTH bytes),
-      //   crypto_provide(CRYPTO_BITFIELD_LENGTH bytes),
-      //   len(padC)(2bytes),
-      //   padC(len(padC)bytes <= MAX_PAD_LENGTH),
-      //   len(IA)(2bytes)
-      unsigned char buffer[VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2+MAX_PAD_LENGTH+2];
-
-      // VC
-      memcpy(buffer, VC, sizeof(VC));
-      // crypto_provide
-      unsigned char cryptoProvide[CRYPTO_BITFIELD_LENGTH];
-      memset(cryptoProvide, 0, sizeof(cryptoProvide));
-      if(option_->get(PREF_BT_MIN_CRYPTO_LEVEL) == V_PLAIN) {
-        cryptoProvide[3] = CRYPTO_PLAIN_TEXT;
-      }
-      cryptoProvide[3] |= CRYPTO_ARC4;
-      memcpy(buffer+VC_LENGTH, cryptoProvide, sizeof(cryptoProvide));
-
-      // len(padC)
-      uint16_t padCLength = SimpleRandomizer::getInstance()->getRandomNumber(MAX_PAD_LENGTH+1);
-      {
-        uint16_t padCLengthBE = htons(padCLength);
-        memcpy(buffer+VC_LENGTH+CRYPTO_BITFIELD_LENGTH, &padCLengthBE,
-               sizeof(padCLengthBE));
-      }
-      // padC
-      memset(buffer+VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2, 0, padCLength);
-      // len(IA)
-      // currently, IA is zero-length.
-      uint16_t iaLength = 0;
-      {
-        uint16_t iaLengthBE = htons(iaLength);
-        memcpy(buffer+VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2+padCLength,
-               &iaLengthBE,sizeof(iaLengthBE));
-      }
-      encryptAndSendData(buffer,
-                         VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2+padCLength+2);
-    }
+  A2_LOG_DEBUG(fmt("CUID#%lld - Sending negotiation step2.", cuid_));
+  unsigned char md[20];
+  createReq1Hash(md);
+  socketBuffer_.pushStr(std::string(&md[0], &md[sizeof(md)]));
+  createReq23Hash(md, infoHash_);
+  socketBuffer_.pushStr(std::string(&md[0], &md[sizeof(md)]));
+
+  // buffer is filled in this order:
+  //   VC(VC_LENGTH bytes),
+  //   crypto_provide(CRYPTO_BITFIELD_LENGTH bytes),
+  //   len(padC)(2bytes),
+  //   padC(len(padC)bytes <= MAX_PAD_LENGTH),
+  //   len(IA)(2bytes)
+  unsigned char buffer[VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2+MAX_PAD_LENGTH+2];
+
+  // VC
+  memcpy(buffer, VC, sizeof(VC));
+  // crypto_provide
+  unsigned char cryptoProvide[CRYPTO_BITFIELD_LENGTH];
+  memset(cryptoProvide, 0, sizeof(cryptoProvide));
+  if(option_->get(PREF_BT_MIN_CRYPTO_LEVEL) == V_PLAIN) {
+    cryptoProvide[3] = CRYPTO_PLAIN_TEXT;
+  }
+  cryptoProvide[3] |= CRYPTO_ARC4;
+  memcpy(buffer+VC_LENGTH, cryptoProvide, sizeof(cryptoProvide));
+
+  // len(padC)
+  uint16_t padCLength = SimpleRandomizer::getInstance()->getRandomNumber(MAX_PAD_LENGTH+1);
+  {
+    uint16_t padCLengthBE = htons(padCLength);
+    memcpy(buffer+VC_LENGTH+CRYPTO_BITFIELD_LENGTH, &padCLengthBE,
+           sizeof(padCLengthBE));
+  }
+  // padC
+  memset(buffer+VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2, 0, padCLength);
+  // len(IA)
+  // currently, IA is zero-length.
+  uint16_t iaLength = 0;
+  {
+    uint16_t iaLengthBE = htons(iaLength);
+    memcpy(buffer+VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2+padCLength,
+           &iaLengthBE,sizeof(iaLengthBE));
   }
-  socketBuffer_.send();
-  return socketBuffer_.sendBufferIsEmpty();
+  encryptAndSendData(buffer,
+                     VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2+padCLength+2);
 }
 
 // This function reads exactly until the end of VC marker is reached.
 bool MSEHandshake::findInitiatorVCMarker()
 {
   // 616 is synchronization point of initiator
-  size_t r = 616-KEY_LENGTH-rbufLength_;
-  if(!socket_->isReadable(0)) {
-    return false;
-  }
-  socket_->peekData(rbuf_+rbufLength_, r);
-  if(r == 0) {
-    if(socket_->wantRead() || socket_->wantWrite()) {
-      return false;
-    }
-    throw DL_ABORT_EX(EX_EOF_FROM_PEER);
-  }
   // find vc
-  {
-    std::string buf(&rbuf_[0], &rbuf_[rbufLength_+r]);
-    std::string vc(&initiatorVCMarker_[0], &initiatorVCMarker_[VC_LENGTH]);
-    if((markerIndex_ = buf.find(vc)) == std::string::npos) {
-      if(616-KEY_LENGTH <= rbufLength_+r) {
-        throw DL_ABORT_EX("Failed to find VC marker.");
-      } else {
-        socket_->readData(rbuf_+rbufLength_, r);
-        rbufLength_ += r;
-        return false;
-      }
+  std::string buf(&rbuf_[0], &rbuf_[rbufLength_]);
+  std::string vc(&initiatorVCMarker_[0], &initiatorVCMarker_[VC_LENGTH]);
+  if((markerIndex_ = buf.find(vc)) == std::string::npos) {
+    if(616-KEY_LENGTH <= rbufLength_) {
+      throw DL_ABORT_EX("Failed to find VC marker.");
+    } else {
+      wantRead_ = true;
+      return false;
     }
   }
-  assert(markerIndex_+VC_LENGTH-rbufLength_ <= r);
-  size_t toRead = markerIndex_+VC_LENGTH-rbufLength_;
-  socket_->readData(rbuf_+rbufLength_, toRead);
-  rbufLength_ += toRead;
   A2_LOG_DEBUG(fmt("CUID#%lld - VC marker found at %lu",
                    cuid_,
                    static_cast<unsigned long>(markerIndex_)));
   verifyVC(rbuf_+markerIndex_);
-  // reset rbufLength_
-  rbufLength_ = 0;
+  // shift rbuf
+  shiftBuffer(markerIndex_+VC_LENGTH);
   return true;
 }
 
 bool MSEHandshake::receiveInitiatorCryptoSelectAndPadDLength()
 {
-  size_t r = CRYPTO_BITFIELD_LENGTH+2/* PadD length*/-rbufLength_;
-  if(r > receiveNBytes(r)) {
+  if(CRYPTO_BITFIELD_LENGTH+2/* PadD length*/ > rbufLength_) {
+    wantRead_ = true;
     return false;
   }
   //verifyCryptoSelect
@@ -382,75 +373,57 @@ bool MSEHandshake::receiveInitiatorCryptoSelectAndPadDLength()
   // padD length
   rbufptr += CRYPTO_BITFIELD_LENGTH;
   padLength_ = verifyPadLength(rbufptr, "PadD");
-  // reset rbufLength_
-  rbufLength_ = 0;
+  // shift rbuf
+  shiftBuffer(CRYPTO_BITFIELD_LENGTH+2/* PadD length*/);
   return true;
 }
 
 bool MSEHandshake::receivePad()
 {
+  if(padLength_ > rbufLength_) {
+    wantRead_ = true;
+    return false;
+  }
   if(padLength_ == 0) {
     return true;
   }
-  size_t r = padLength_-rbufLength_;
-  if(r > receiveNBytes(r)) {
-    return false;
-  }
   unsigned char temp[MAX_PAD_LENGTH];
   decryptor_->decrypt(temp, padLength_, rbuf_, padLength_);
-  // reset rbufLength_
-  rbufLength_ = 0;
+  // shift rbuf_
+  shiftBuffer(padLength_);
   return true;
 }
 
 bool MSEHandshake::findReceiverHashMarker()
 {
   // 628 is synchronization limit of receiver.
-  size_t r = 628-KEY_LENGTH-rbufLength_;
-  if(!socket_->isReadable(0)) {
-    return false;
-  }
-  socket_->peekData(rbuf_+rbufLength_, r);
-  if(r == 0) {
-    if(socket_->wantRead() || socket_->wantWrite()) {
-      return false;
-    }
-    throw DL_ABORT_EX(EX_EOF_FROM_PEER);
-  }
   // find hash('req1', S), S is secret_.
-  {
-    std::string buf(&rbuf_[0], &rbuf_[rbufLength_+r]);
-    unsigned char md[20];
-    createReq1Hash(md);
-    std::string req1(&md[0], &md[sizeof(md)]);
-    if((markerIndex_ = buf.find(req1)) == std::string::npos) {
-      if(628-KEY_LENGTH <= rbufLength_+r) {
-        throw DL_ABORT_EX("Failed to find hash marker.");
-      } else {
-        socket_->readData(rbuf_+rbufLength_, r);
-        rbufLength_ += r;
-        return false;
-      }
+  std::string buf(&rbuf_[0], &rbuf_[rbufLength_]);
+  unsigned char md[20];
+  createReq1Hash(md);
+  std::string req1(&md[0], &md[sizeof(md)]);
+  if((markerIndex_ = buf.find(req1)) == std::string::npos) {
+    if(628-KEY_LENGTH <= rbufLength_) {
+      throw DL_ABORT_EX("Failed to find hash marker.");
+    } else {
+      wantRead_ = true;
+      return false;
     }
   }
-  assert(markerIndex_+20-rbufLength_ <= r);
-  size_t toRead = markerIndex_+20-rbufLength_;
-  socket_->readData(rbuf_+rbufLength_, toRead);
-  rbufLength_ += toRead;
   A2_LOG_DEBUG(fmt("CUID#%lld - Hash marker found at %lu.",
                    cuid_,
                    static_cast<unsigned long>(markerIndex_)));
   verifyReq1Hash(rbuf_+markerIndex_);
-  // reset rbufLength_
-  rbufLength_ = 0;
+  // shift rbuf_
+  shiftBuffer(markerIndex_+20);
   return true;
 }
 
 bool MSEHandshake::receiveReceiverHashAndPadCLength
 (const std::vector<SharedHandle<DownloadContext> >& downloadContexts)
 {
-  size_t r = 20+VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2/*PadC length*/-rbufLength_;
-  if(r > receiveNBytes(r)) {
+  if(20+VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2/*PadC length*/ > rbufLength_) {
+    wantRead_ = true;
     return false;
   }
   // resolve info hash
@@ -505,23 +478,22 @@ bool MSEHandshake::receiveReceiverHashAndPadCLength
   // decrypt PadC length
   rbufptr += CRYPTO_BITFIELD_LENGTH;
   padLength_ = verifyPadLength(rbufptr, "PadC");
-  // reset rbufLength_
-  rbufLength_ = 0;
+  // shift rbuf_
+  shiftBuffer(20+VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2/*PadC length*/);
   return true;
 }
 
 bool MSEHandshake::receiveReceiverIALength()
 {
-  size_t r = 2-rbufLength_;
-  assert(r > 0);
-  if(r > receiveNBytes(r)) {
+  if(2 > rbufLength_) {
+    wantRead_ = true;
     return false;
   }
   iaLength_ = decodeLength16(rbuf_);
-  A2_LOG_DEBUG(fmt("CUID#%lld - len(IA)=%u.",
-                   cuid_, iaLength_));
-  // reset rbufLength_
-  rbufLength_ = 0;
+  // TODO limit iaLength \19...+handshake
+  A2_LOG_DEBUG(fmt("CUID#%lld - len(IA)=%u.", cuid_, iaLength_));
+  // shift rbuf_
+  shiftBuffer(2);
   return true;
 }
 
@@ -530,48 +502,44 @@ bool MSEHandshake::receiveReceiverIA()
   if(iaLength_ == 0) {
     return true;
   }
-  size_t r = iaLength_-rbufLength_;
-  if(r > receiveNBytes(r)) {
+  if(iaLength_ > rbufLength_) {
+    wantRead_ = true;
     return false;
   }
   delete [] ia_;
   ia_ = new unsigned char[iaLength_];
   decryptor_->decrypt(ia_, iaLength_, rbuf_, iaLength_);
   A2_LOG_DEBUG(fmt("CUID#%lld - IA received.", cuid_));
-  // reset rbufLength_
-  rbufLength_ = 0;
+  // shift rbuf_
+  shiftBuffer(iaLength_);
   return true;
 }
 
-bool MSEHandshake::sendReceiverStep2()
+void MSEHandshake::sendReceiverStep2()
 {
-  if(socketBuffer_.sendBufferIsEmpty()) {
-    // buffer is filled in this order:
-    //   VC(VC_LENGTH bytes),
-    //   cryptoSelect(CRYPTO_BITFIELD_LENGTH bytes),
-    //   len(padD)(2bytes),
-    //   padD(len(padD)bytes <= MAX_PAD_LENGTH)
-    unsigned char buffer[VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2+MAX_PAD_LENGTH];
-    // VC
-    memcpy(buffer, VC, sizeof(VC));
-    // crypto_select
-    unsigned char cryptoSelect[CRYPTO_BITFIELD_LENGTH];
-    memset(cryptoSelect, 0, sizeof(cryptoSelect));
-    cryptoSelect[3] = negotiatedCryptoType_;
-    memcpy(buffer+VC_LENGTH, cryptoSelect, sizeof(cryptoSelect));
-    // len(padD)
-    uint16_t padDLength = SimpleRandomizer::getInstance()->getRandomNumber(MAX_PAD_LENGTH+1);
-    {
-      uint16_t padDLengthBE = htons(padDLength);
-      memcpy(buffer+VC_LENGTH+CRYPTO_BITFIELD_LENGTH, &padDLengthBE,
-             sizeof(padDLengthBE));
-    }
-    // padD, all zeroed
-    memset(buffer+VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2, 0, padDLength);
-    encryptAndSendData(buffer, VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2+padDLength);
+  // buffer is filled in this order:
+  //   VC(VC_LENGTH bytes),
+  //   cryptoSelect(CRYPTO_BITFIELD_LENGTH bytes),
+  //   len(padD)(2bytes),
+  //   padD(len(padD)bytes <= MAX_PAD_LENGTH)
+  unsigned char buffer[VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2+MAX_PAD_LENGTH];
+  // VC
+  memcpy(buffer, VC, sizeof(VC));
+  // crypto_select
+  unsigned char cryptoSelect[CRYPTO_BITFIELD_LENGTH];
+  memset(cryptoSelect, 0, sizeof(cryptoSelect));
+  cryptoSelect[3] = negotiatedCryptoType_;
+  memcpy(buffer+VC_LENGTH, cryptoSelect, sizeof(cryptoSelect));
+  // len(padD)
+  uint16_t padDLength = SimpleRandomizer::getInstance()->getRandomNumber(MAX_PAD_LENGTH+1);
+  {
+    uint16_t padDLengthBE = htons(padDLength);
+    memcpy(buffer+VC_LENGTH+CRYPTO_BITFIELD_LENGTH, &padDLengthBE,
+           sizeof(padDLengthBE));
   }
-  socketBuffer_.send();
-  return socketBuffer_.sendBufferIsEmpty();
+  // padD, all zeroed
+  memset(buffer+VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2, 0, padDLength);
+  encryptAndSendData(buffer, VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2+padDLength);
 }
 
 uint16_t MSEHandshake::verifyPadLength(const unsigned char* padlenbuf, const char* padName)
@@ -609,20 +577,9 @@ void MSEHandshake::verifyReq1Hash(const unsigned char* req1buf)
   }
 }
 
-size_t MSEHandshake::receiveNBytes(size_t bytes)
+bool MSEHandshake::getWantWrite() const
 {
-  size_t r = bytes;
-  if(r > 0) {
-    if(!socket_->isReadable(0)) {
-      return 0;
-    }
-    socket_->readData(rbuf_+rbufLength_, r);
-    if(r == 0 && !socket_->wantRead() && !socket_->wantWrite()) {
-      throw DL_ABORT_EX(EX_EOF_FROM_PEER);
-    }
-    rbufLength_ += r;
-  }
-  return r;
+  return !socketBuffer_.sendBufferIsEmpty();
 }
 
 } // namespace aria2

+ 25 - 6
src/MSEHandshake.h

@@ -83,6 +83,7 @@ private:
 
   cuid_t cuid_;
   SharedHandle<SocketCore> socket_;
+  bool wantRead_;
   const Option* option_;
 
   unsigned char rbuf_[MAX_BUFFER_LENGTH];
@@ -130,8 +131,7 @@ private:
 
   void verifyReq1Hash(const unsigned char* req1buf);
 
-  size_t receiveNBytes(size_t bytes);
-
+  void shiftBuffer(size_t offset);
 public:
   MSEHandshake(cuid_t cuid, const SharedHandle<SocketCore>& socket,
                const Option* op);
@@ -142,13 +142,33 @@ public:
 
   void initEncryptionFacility(bool initiator);
 
-  bool sendPublicKey();
+  // Reads data from Socket. If EOF is reached, throws
+  // RecoverableException.
+  void read();
+
+  // Sends pending data in the send buffer. Returns true if all data
+  // is sent. Otherwise returns false.
+  bool send();
+
+  bool getWantRead() const
+  {
+    return wantRead_;
+  }
+
+  void setWantRead(bool wantRead)
+  {
+    wantRead_ = wantRead;
+  }
+
+  bool getWantWrite() const;
+
+  void sendPublicKey();
 
   bool receivePublicKey();
 
   void initCipher(const unsigned char* infoHash);
 
-  bool sendInitiatorStep2();
+  void sendInitiatorStep2();
 
   bool findInitiatorVCMarker();
 
@@ -165,7 +185,7 @@ public:
 
   bool receiveReceiverIA();
 
-  bool sendReceiverStep2();
+  void sendReceiverStep2();
 
   // returns plain text IA
   const unsigned char* getIA() const
@@ -207,7 +227,6 @@ public:
   {
     return rbufLength_;
   }
-
 };
 
 } // namespace aria2

+ 5 - 12
src/PeerConnection.cc

@@ -110,9 +110,6 @@ void PeerConnection::pushBytes(unsigned char* data, size_t len)
 
 bool PeerConnection::receiveMessage(unsigned char* data, size_t& dataLength) {
   if(resbufLength_ == 0 && 4 > lenbufLength_) {
-    if(!socket_->isReadable(0)) {
-      return false;
-    }
     // read payload size, 32bit unsigned integer
     size_t remaining = 4-lenbufLength_;
     size_t temp = remaining;
@@ -182,7 +179,7 @@ bool PeerConnection::receiveHandshake(unsigned char* data, size_t& dataLength,
                                       bool peek) {
   assert(BtHandshakeMessage::MESSAGE_LENGTH >= resbufLength_);
   bool retval = true;
-  if(prevPeek_ && !peek && resbufLength_) {
+  if(prevPeek_ && resbufLength_) {
     // We have data in previous peek.
     // There is a chance that socket is readable because of EOF, for example,
     // official bttrack shutdowns socket after sending first 48 bytes of
@@ -194,17 +191,10 @@ bool PeerConnection::receiveHandshake(unsigned char* data, size_t& dataLength,
   } else {
     prevPeek_ = peek;
     size_t remaining = BtHandshakeMessage::MESSAGE_LENGTH-resbufLength_;
-    if(remaining > 0 && !socket_->isReadable(0)) {
-      dataLength = 0;
-      return false;
-    }
     if(remaining > 0) {
       size_t temp = remaining;
       readData(resbuf_+resbufLength_, remaining, encryptionEnabled_);
-      if(remaining == 0) {
-        if(socket_->wantRead() || socket_->wantWrite()) {
-          return false;
-        }
+      if(remaining == 0 && !socket_->wantRead() && !socket_->wantWrite()) {
         // we got EOF
         A2_LOG_DEBUG
           (fmt("CUID#%lld - In PeerConnection::receiveHandshake(), remain=%lu",
@@ -256,6 +246,9 @@ void PeerConnection::presetBuffer(const unsigned char* data, size_t length)
   size_t nwrite = std::min((size_t)MAX_PAYLOAD_LEN, length);
   memcpy(resbuf_, data, nwrite);
   resbufLength_ = length;
+  if(resbufLength_ > 0) {
+    prevPeek_ = true;
+  }
 }
 
 bool PeerConnection::sendBufferIsEmpty() const

+ 5 - 0
src/PeerConnection.h

@@ -117,6 +117,11 @@ public:
     return resbuf_;
   }
 
+  size_t getBufferLength() const
+  {
+    return resbufLength_;
+  }
+
   unsigned char* detachBuffer();
 };
 

+ 66 - 52
src/PeerInteractionCommand.cc

@@ -163,6 +163,11 @@ PeerInteractionCommand::PeerInteractionCommand
     peerConnection.reset(new PeerConnection(cuid, getPeer(), getSocket()));
   } else {
     peerConnection = passedPeerConnection;
+    if(sequence_ == RECEIVER_WAIT_HANDSHAKE &&
+       peerConnection->getBufferLength() > 0) {
+      setStatus(Command::STATUS_ONESHOT_REALTIME);
+      getDownloadEngine()->setNoWait(true);
+    }
   }
 
   SharedHandle<DefaultBtMessageDispatcher> dispatcher
@@ -274,71 +279,80 @@ PeerInteractionCommand::~PeerInteractionCommand() {
 
 bool PeerInteractionCommand::executeInternal() {
   setNoCheck(false);
-  switch(sequence_) {
-  case INITIATOR_SEND_HANDSHAKE:
-    if(!getSocket()->isWritable(0)) {
+  bool done = false;
+  while(!done) {
+    switch(sequence_) {
+    case INITIATOR_SEND_HANDSHAKE:
+      if(!getSocket()->isWritable(0)) {
+        done = true;
+        break;
+      }
+      disableWriteCheckSocket();
+      setReadCheckSocket(getSocket());
+      //socket->setBlockingMode();
+      setTimeout(getOption()->getAsInt(PREF_BT_TIMEOUT));
+      btInteractive_->initiateHandshake();
+      sequence_ = INITIATOR_WAIT_HANDSHAKE;
       break;
-    }
-    disableWriteCheckSocket();
-    setReadCheckSocket(getSocket());
-    //socket->setBlockingMode();
-    setTimeout(getOption()->getAsInt(PREF_BT_TIMEOUT));
-    btInteractive_->initiateHandshake();
-    sequence_ = INITIATOR_WAIT_HANDSHAKE;
-    break;
-  case INITIATOR_WAIT_HANDSHAKE: {
-    if(btInteractive_->countPendingMessage() > 0) {
-      btInteractive_->sendPendingMessage();
+    case INITIATOR_WAIT_HANDSHAKE: {
       if(btInteractive_->countPendingMessage() > 0) {
+        btInteractive_->sendPendingMessage();
+        if(btInteractive_->countPendingMessage() > 0) {
+          done = true;
+          break;
+        }
+      }
+      BtMessageHandle handshakeMessage = btInteractive_->receiveHandshake();
+      if(!handshakeMessage) {
+        done = true;
         break;
       }
-    }
-    BtMessageHandle handshakeMessage = btInteractive_->receiveHandshake();
-    if(!handshakeMessage) {
+      btInteractive_->doPostHandshakeProcessing();
+      sequence_ = WIRED;
       break;
     }
-    btInteractive_->doPostHandshakeProcessing();
-    sequence_ = WIRED;
-    break;
-  }
-  case RECEIVER_WAIT_HANDSHAKE: {
-    BtMessageHandle handshakeMessage =btInteractive_->receiveAndSendHandshake();
-    if(!handshakeMessage) {
+    case RECEIVER_WAIT_HANDSHAKE: {
+      BtMessageHandle handshakeMessage =
+        btInteractive_->receiveAndSendHandshake();
+      if(!handshakeMessage) {
+        done = true;
+        break;
+      }
+      btInteractive_->doPostHandshakeProcessing();
+      sequence_ = WIRED;
       break;
     }
-    btInteractive_->doPostHandshakeProcessing();
-    sequence_ = WIRED;    
-    break;
-  }
-  case WIRED:
-    // See the comment for writable check below.
-    disableWriteCheckSocket();
-
-    btInteractive_->doInteractionProcessing();
-    if(btInteractive_->countReceivedMessageInIteration() > 0) {
-      updateKeepAlive();
-    }
-    if((getPeer()->amInterested() && !getPeer()->peerChoking()) ||
-       btInteractive_->countOutstandingRequest() ||
-       (getPeer()->peerInterested() && !getPeer()->amChoking())) {
+    case WIRED:
+      // See the comment for writable check below.
+      disableWriteCheckSocket();
 
-      // Writable check to avoid slow seeding
-      if(btInteractive_->isSendingMessageInProgress()) {
-        setWriteCheckSocket(getSocket());
+      btInteractive_->doInteractionProcessing();
+      if(btInteractive_->countReceivedMessageInIteration() > 0) {
+        updateKeepAlive();
       }
-
-      if(getDownloadEngine()->getRequestGroupMan()->
-         doesOverallDownloadSpeedExceed() ||
-         requestGroup_->doesDownloadSpeedExceed()) {
-        disableReadCheckSocket();
-        setNoCheck(true);
+      if((getPeer()->amInterested() && !getPeer()->peerChoking()) ||
+         btInteractive_->countOutstandingRequest() ||
+         (getPeer()->peerInterested() && !getPeer()->amChoking())) {
+
+        // Writable check to avoid slow seeding
+        if(btInteractive_->isSendingMessageInProgress()) {
+          setWriteCheckSocket(getSocket());
+        }
+
+        if(getDownloadEngine()->getRequestGroupMan()->
+           doesOverallDownloadSpeedExceed() ||
+           requestGroup_->doesDownloadSpeedExceed()) {
+          disableReadCheckSocket();
+          setNoCheck(true);
+        } else {
+          setReadCheckSocket(getSocket());
+        }
       } else {
-        setReadCheckSocket(getSocket());
+        disableReadCheckSocket();
       }
-    } else {
-      disableReadCheckSocket();
+      done = true;
+      break;
     }
-    break;
   }
   if(btInteractive_->countPendingMessage() > 0) {
     setNoCheck(true);

+ 6 - 1
src/PeerReceiveHandshakeCommand.cc

@@ -67,7 +67,12 @@ PeerReceiveHandshakeCommand::PeerReceiveHandshakeCommand
   : PeerAbstractCommand(cuid, peer, e, s),
     peerConnection_(peerConnection)
 {
-  if(!peerConnection_) {
+  if(peerConnection_) {
+    if(peerConnection_->getBufferLength() > 0) {
+      setStatus(Command::STATUS_ONESHOT_REALTIME);
+      getDownloadEngine()->setNoWait(true);
+    }
+  } else {
     peerConnection_.reset(new PeerConnection(cuid, getPeer(), getSocket()));
   }
 }

+ 116 - 82
src/ReceiverMSEHandshakeCommand.cc

@@ -50,6 +50,7 @@
 #include "RequestGroupMan.h"
 #include "BtRegistry.h"
 #include "DownloadContext.h"
+#include "array_fun.h"
 
 namespace aria2 {
 
@@ -64,6 +65,7 @@ ReceiverMSEHandshakeCommand::ReceiverMSEHandshakeCommand
   mseHandshake_(new MSEHandshake(cuid, s, e->getOption()))
 {
   setTimeout(e->getOption()->getAsInt(PREF_PEER_CONNECTION_TIMEOUT));
+  mseHandshake_->setWantRead(true);
 }
 
 ReceiverMSEHandshakeCommand::~ReceiverMSEHandshakeCommand()
@@ -79,102 +81,125 @@ bool ReceiverMSEHandshakeCommand::exitBeforeExecute()
 
 bool ReceiverMSEHandshakeCommand::executeInternal()
 {
-  switch(sequence_) {
-  case RECEIVER_IDENTIFY_HANDSHAKE: {
-    MSEHandshake::HANDSHAKE_TYPE type = mseHandshake_->identifyHandshakeType();
-    switch(type) {
-    case MSEHandshake::HANDSHAKE_NOT_YET:
-      break;
-    case MSEHandshake::HANDSHAKE_ENCRYPTED:
-      mseHandshake_->initEncryptionFacility(false);
-      sequence_ = RECEIVER_WAIT_KEY;
-      break;
-    case MSEHandshake::HANDSHAKE_LEGACY: {
-      if(getDownloadEngine()->getOption()->getAsBool(PREF_BT_REQUIRE_CRYPTO)) {
-        throw DL_ABORT_EX
-          ("The legacy BitTorrent handshake is not acceptable by the"
-           " preference.");
+  if(mseHandshake_->getWantRead()) {
+    mseHandshake_->read();
+  }
+  bool done = false;
+  while(!done) {
+    switch(sequence_) {
+    case RECEIVER_IDENTIFY_HANDSHAKE: {
+      MSEHandshake::HANDSHAKE_TYPE type =
+        mseHandshake_->identifyHandshakeType();
+      switch(type) {
+      case MSEHandshake::HANDSHAKE_NOT_YET:
+        done = true;
+        break;
+      case MSEHandshake::HANDSHAKE_ENCRYPTED:
+        mseHandshake_->initEncryptionFacility(false);
+        sequence_ = RECEIVER_WAIT_KEY;
+        break;
+      case MSEHandshake::HANDSHAKE_LEGACY: {
+        if(getDownloadEngine()->getOption()->getAsBool(PREF_BT_REQUIRE_CRYPTO)){
+          throw DL_ABORT_EX
+            ("The legacy BitTorrent handshake is not acceptable by the"
+             " preference.");
+        }
+        SharedHandle<PeerConnection> peerConnection
+          (new PeerConnection(getCuid(), getPeer(), getSocket()));
+        peerConnection->presetBuffer(mseHandshake_->getBuffer(),
+                                     mseHandshake_->getBufferLength());
+        Command* c = new PeerReceiveHandshakeCommand(getCuid(),
+                                                     getPeer(),
+                                                     getDownloadEngine(),
+                                                     getSocket(),
+                                                     peerConnection);
+        getDownloadEngine()->addCommand(c);
+        return true;
+      }
+      default:
+        throw DL_ABORT_EX("Not supported handshake type.");
       }
-      SharedHandle<PeerConnection> peerConnection
-        (new PeerConnection(getCuid(), getPeer(), getSocket()));
-      peerConnection->presetBuffer(mseHandshake_->getBuffer(),
-                                   mseHandshake_->getBufferLength());
-      Command* c = new PeerReceiveHandshakeCommand(getCuid(),
-                                                   getPeer(),
-                                                   getDownloadEngine(),
-                                                   getSocket(),
-                                                   peerConnection);
-      getDownloadEngine()->addCommand(c);
-      return true;
+      break;
     }
-    default:
-      throw DL_ABORT_EX("Not supported handshake type.");
+    case RECEIVER_WAIT_KEY: {
+      if(mseHandshake_->receivePublicKey()) {
+        mseHandshake_->sendPublicKey();
+        sequence_ = RECEIVER_SEND_KEY_PENDING;
+      } else {
+        done = true;
+      }
+      break;
     }
-    break;
-  }
-  case RECEIVER_WAIT_KEY: {
-    if(mseHandshake_->receivePublicKey()) {
-      if(mseHandshake_->sendPublicKey()) {
+    case RECEIVER_SEND_KEY_PENDING:
+      if(mseHandshake_->send()) {
         sequence_ = RECEIVER_FIND_HASH_MARKER;
       } else {
-        setWriteCheckSocket(getSocket());
-        sequence_ = RECEIVER_SEND_KEY_PENDING;
+        done = true;
       }
+      break;
+    case RECEIVER_FIND_HASH_MARKER: {
+      if(mseHandshake_->findReceiverHashMarker()) {
+        sequence_ = RECEIVER_RECEIVE_PAD_C_LENGTH;
+      } else {
+        done = true;
+      }
+      break;
     }
-    break;
-  }
-  case RECEIVER_SEND_KEY_PENDING:
-    if(mseHandshake_->sendPublicKey()) {
-      disableWriteCheckSocket();
-      sequence_ = RECEIVER_FIND_HASH_MARKER;
-    }
-    break;
-  case RECEIVER_FIND_HASH_MARKER: {
-    if(mseHandshake_->findReceiverHashMarker()) {
-      sequence_ = RECEIVER_RECEIVE_PAD_C_LENGTH;
+    case RECEIVER_RECEIVE_PAD_C_LENGTH: {
+      std::vector<SharedHandle<DownloadContext> > downloadContexts;
+      getDownloadEngine()->getBtRegistry()->getAllDownloadContext
+        (std::back_inserter(downloadContexts));
+      if(mseHandshake_->receiveReceiverHashAndPadCLength(downloadContexts)) {
+        sequence_ = RECEIVER_RECEIVE_PAD_C;
+      } else {
+        done = true;
+      }
+      break;
     }
-    break;
-  }
-  case RECEIVER_RECEIVE_PAD_C_LENGTH: {
-    std::vector<SharedHandle<DownloadContext> > downloadContexts;
-    getDownloadEngine()->getBtRegistry()->getAllDownloadContext
-      (std::back_inserter(downloadContexts));
-    if(mseHandshake_->receiveReceiverHashAndPadCLength(downloadContexts)) {
-      sequence_ = RECEIVER_RECEIVE_PAD_C;
+    case RECEIVER_RECEIVE_PAD_C: {
+      if(mseHandshake_->receivePad()) {
+        sequence_ = RECEIVER_RECEIVE_IA_LENGTH;
+      } else {
+        done = true;
+      }
+      break;
     }
-    break;
-  }
-  case RECEIVER_RECEIVE_PAD_C: {
-    if(mseHandshake_->receivePad()) {
-      sequence_ = RECEIVER_RECEIVE_IA_LENGTH;
+    case RECEIVER_RECEIVE_IA_LENGTH: {
+      if(mseHandshake_->receiveReceiverIALength()) {
+        sequence_ = RECEIVER_RECEIVE_IA;
+      } else {
+        done = true;
+      }
+      break;
     }
-    break;
-  }
-  case RECEIVER_RECEIVE_IA_LENGTH: {
-    if(mseHandshake_->receiveReceiverIALength()) {
-      sequence_ = RECEIVER_RECEIVE_IA;
+    case RECEIVER_RECEIVE_IA: {
+      if(mseHandshake_->receiveReceiverIA()) {
+        mseHandshake_->sendReceiverStep2();
+        sequence_ = RECEIVER_SEND_STEP2_PENDING;
+      } else {
+        done = true;
+      }
+      break;
     }
-    break;
-  }
-  case RECEIVER_RECEIVE_IA: {
-    if(mseHandshake_->receiveReceiverIA()) {
-      if(mseHandshake_->sendReceiverStep2()) {
+    case RECEIVER_SEND_STEP2_PENDING:
+      if(mseHandshake_->send()) {
         createCommand();
         return true;
       } else {
-        setWriteCheckSocket(getSocket());
-        sequence_ = RECEIVER_SEND_STEP2_PENDING;
+        done = true;
       }
+      break;
     }
-    break;
   }
-  case RECEIVER_SEND_STEP2_PENDING:
-    if(mseHandshake_->sendReceiverStep2()) {
-      disableWriteCheckSocket();
-      createCommand();
-      return true;
-    }
-    break;
+  if(mseHandshake_->getWantRead()) {
+    setReadCheckSocket(getSocket());
+  } else {
+    disableReadCheckSocket();
+  }
+  if(mseHandshake_->getWantWrite()) {
+    setWriteCheckSocket(getSocket());
+  } else {
+    disableWriteCheckSocket();
   }
   getDownloadEngine()->addCommand(this);
   return false;
@@ -188,10 +213,19 @@ void ReceiverMSEHandshakeCommand::createCommand()
     peerConnection->enableEncryption(mseHandshake_->getEncryptor(),
                                      mseHandshake_->getDecryptor());
   }
-  if(mseHandshake_->getIALength() > 0) {
-    peerConnection->presetBuffer(mseHandshake_->getIA(),
-                                 mseHandshake_->getIALength());
+  size_t buflen = mseHandshake_->getIALength()+mseHandshake_->getBufferLength();
+  array_ptr<unsigned char> buffer(new unsigned char[buflen]);
+  memcpy(buffer, mseHandshake_->getIA(), mseHandshake_->getIALength());
+  if(mseHandshake_->getNegotiatedCryptoType() == MSEHandshake::CRYPTO_ARC4) {
+    mseHandshake_->getDecryptor()->decrypt(buffer+mseHandshake_->getIALength(),
+                                           mseHandshake_->getBufferLength(),
+                                           mseHandshake_->getBuffer(),
+                                           mseHandshake_->getBufferLength());
+  } else {
+    memcpy(buffer+mseHandshake_->getIALength(),
+           mseHandshake_->getBuffer(), mseHandshake_->getBufferLength());
   }
+  peerConnection->presetBuffer(buffer, buflen);
   // TODO add mseHandshake_->getInfoHash() to PeerReceiveHandshakeCommand
   // as a hint. If this info hash and one in BitTorrent Handshake does not
   // match, then drop connection.

+ 42 - 11
test/MSEHandshakeTest.cc

@@ -70,26 +70,57 @@ createSocketPair()
 void MSEHandshakeTest::doHandshake(const SharedHandle<MSEHandshake>& initiator, const SharedHandle<MSEHandshake>& receiver)
 {
   initiator->sendPublicKey();
-
-  while(!receiver->receivePublicKey());
+  while(initiator->getWantWrite()) {
+    initiator->send();
+  }
+  while(!receiver->receivePublicKey()) {
+    receiver->read();
+  }
   receiver->sendPublicKey();
+  while(receiver->getWantWrite()) {
+    receiver->send();
+  }
 
-  while(!initiator->receivePublicKey());
+  while(!initiator->receivePublicKey()) {
+    initiator->read();
+  }
   initiator->initCipher(bittorrent::getInfoHash(dctx_));
   initiator->sendInitiatorStep2();
+  while(initiator->getWantWrite()) {
+    initiator->send();
+  }
 
-  while(!receiver->findReceiverHashMarker());
+  while(!receiver->findReceiverHashMarker()) {
+    receiver->read();
+  }
   std::vector<SharedHandle<DownloadContext> > contexts;
   contexts.push_back(dctx_);
-  while(!receiver->receiveReceiverHashAndPadCLength(contexts));
-  while(!receiver->receivePad());
-  while(!receiver->receiveReceiverIALength());
-  while(!receiver->receiveReceiverIA());
+  while(!receiver->receiveReceiverHashAndPadCLength(contexts)) {
+    receiver->read();
+  }
+  while(!receiver->receivePad()) {
+    receiver->read();
+  }
+  while(!receiver->receiveReceiverIALength()) {
+    receiver->read();
+  }
+  while(!receiver->receiveReceiverIA()) {
+    receiver->read();
+  }
   receiver->sendReceiverStep2();
+  while(receiver->getWantWrite()) {
+    receiver->send();
+  }
 
-  while(!initiator->findInitiatorVCMarker());
-  while(!initiator->receiveInitiatorCryptoSelectAndPadDLength());
-  while(!initiator->receivePad());
+  while(!initiator->findInitiatorVCMarker()) {
+    initiator->read();
+  }
+  while(!initiator->receiveInitiatorCryptoSelectAndPadDLength()) {
+    initiator->read();
+  }
+  while(!initiator->receivePad()) {
+    initiator->read();
+  }
 }
 
 namespace {