Przeglądaj źródła

Merge pull request #772 from aria2/refactor-wintls-write

WinTLS: Rewrite writeData
Tatsuhiro Tsujikawa 9 lat temu
rodzic
commit
9df50804d4
2 zmienionych plików z 182 dodań i 125 usunięć
  1. 142 121
      src/WinTLSSession.cc
  2. 40 4
      src/WinTLSSession.h

+ 142 - 121
src/WinTLSSession.cc

@@ -86,16 +86,6 @@ static const ULONG kReqAFlags =
     ASC_REQ_SEQUENCE_DETECT | ASC_REQ_REPLAY_DETECT | ASC_REQ_CONFIDENTIALITY |
     ASC_REQ_EXTENDED_ERROR | ASC_REQ_ALLOCATE_MEMORY | ASC_REQ_STREAM;
 
-class TLSBuffer : public ::SecBuffer {
-public:
-  explicit TLSBuffer(ULONG type, ULONG size, void* data)
-  {
-    cbBuffer = size;
-    BufferType = type;
-    pvBuffer = data;
-  }
-};
-
 class TLSBufferDesc : public ::SecBufferDesc {
 public:
   explicit TLSBufferDesc(SecBuffer* arr, ULONG buffers)
@@ -142,7 +132,8 @@ WinTLSSession::WinTLSSession(WinTLSContext* ctx)
       cred_(ctx->getCredHandle()),
       writeBuffered_(0),
       state_(st_constructed),
-      status_(SEC_E_OK)
+      status_(SEC_E_OK),
+      recordBytesSent_(0)
 {
   memset(&handle_, 0, sizeof(handle_));
 }
@@ -213,7 +204,8 @@ int WinTLSSession::closeConnection()
       status_ = ::AcceptSecurityContext(cred_, &handle_, nullptr, kReqAFlags, 0,
                                         &handle_, &desc, &flags, nullptr);
     }
-    if (status_ == SEC_E_OK || status_ == SEC_I_CONTEXT_EXPIRED) {
+    if ((status_ == SEC_E_OK || status_ == SEC_I_CONTEXT_EXPIRED) &&
+        getLeftTLSRecordSize() == 0) {
       size_t len = ctx.cbBuffer;
       ssize_t rv = writeData(ctx.pvBuffer, ctx.cbBuffer);
       ::FreeContextBuffer(ctx.pvBuffer);
@@ -228,17 +220,6 @@ int WinTLSSession::closeConnection()
     }
   }
 
-  // Send remaining data.
-  while (writeBuf_.size()) {
-    int rv = writeData(nullptr, 0);
-    if (rv == 0) {
-      break;
-    }
-    if (rv < 0) {
-      return rv;
-    }
-  }
-
   A2_LOG_DEBUG("WinTLS: Closed Connection");
   state_ = st_closed;
   return TLS_ERR_OK;
@@ -255,12 +236,82 @@ int WinTLSSession::checkDirection()
   if (readBuf_.size() || decBuf_.size()) {
     return TLS_WANT_READ;
   }
-  if (writeBuf_.size()) {
+  if (getLeftTLSRecordSize() || writeBuf_.size()) {
     return TLS_WANT_WRITE;
   }
   return TLS_WANT_READ;
 }
 
+namespace {
+// Fills |iov| of length |len| to send remaining data in |buffers|.
+// We have already sent |offset| bytes.  This function returns the
+// number of |iov| filled.  It assumes the array |buffers| is at least
+// |len| elements.
+size_t fillSendIOV(a2iovec* iov, size_t len, TLSBuffer* buffers, size_t offset)
+{
+  size_t iovcnt = 0;
+  for (size_t i = 0; i < len; ++i) {
+    if (offset < buffers[i].cbBuffer) {
+      iov[iovcnt].A2IOVEC_BASE =
+          static_cast<char*>(buffers[i].pvBuffer) + offset;
+      iov[iovcnt].A2IOVEC_LEN = buffers[i].cbBuffer - offset;
+      ++iovcnt;
+      offset = 0;
+    }
+    else {
+      offset -= buffers[i].cbBuffer;
+    }
+  }
+  return iovcnt;
+}
+} // namespace
+
+size_t WinTLSSession::getLeftTLSRecordSize() const
+{
+  return sendRecordBuffers_[0].cbBuffer + sendRecordBuffers_[1].cbBuffer +
+         sendRecordBuffers_[2].cbBuffer - recordBytesSent_;
+}
+
+int WinTLSSession::sendTLSRecord()
+{
+  A2_LOG_DEBUG(fmt("WinTLS: TLS record %" PRIu64 " bytes left",
+                   static_cast<uint64_t>(getLeftTLSRecordSize())));
+
+  while (getLeftTLSRecordSize()) {
+    std::array<a2iovec, 3> iov;
+    auto iovcnt = fillSendIOV(iov.data(), iov.size(), sendRecordBuffers_.data(),
+                              recordBytesSent_);
+
+    DWORD nwrite;
+    auto rv =
+        WSASend(sockfd_, iov.data(), iovcnt, &nwrite, 0, nullptr, nullptr);
+    if (rv != 0) {
+      auto errnum = ::WSAGetLastError();
+      if (errnum == WSAEINTR) {
+        continue;
+      }
+
+      if (errnum == WSAEWOULDBLOCK) {
+        return TLS_ERR_WOULDBLOCK;
+      }
+
+      A2_LOG_ERROR("WinTLS: Connection error while writing");
+      status_ = SEC_E_INCOMPLETE_MESSAGE;
+      state_ = st_error;
+      return TLS_ERR_ERROR;
+    }
+
+    recordBytesSent_ += nwrite;
+  }
+
+  recordBytesSent_ = 0;
+  sendRecordBuffers_[0].cbBuffer = 0;
+  sendRecordBuffers_[1].cbBuffer = 0;
+  sendRecordBuffers_[2].cbBuffer = 0;
+
+  return 0;
+}
+
 ssize_t WinTLSSession::writeData(const void* data, size_t len)
 {
   if (state_ == st_handshake_write || state_ == st_handshake_write_last ||
@@ -281,45 +332,15 @@ ssize_t WinTLSSession::writeData(const void* data, size_t len)
   }
 
   A2_LOG_DEBUG(fmt("WinTLS: Write request: %" PRIu64 " buffered: %" PRIu64,
-                   (uint64_t)len, (uint64_t)writeBuf_.size()));
-
-  // Write remaining buffered data, if any.
-  while (writeBuf_.size()) {
-    auto written = ::send(sockfd_, writeBuf_.data(), writeBuf_.size(), 0);
-    errno = ::WSAGetLastError();
-    if (written < 0 && errno == WSAEINTR) {
-      continue;
-    }
-    if (written < 0 && errno == WSAEWOULDBLOCK) {
-      return TLS_ERR_WOULDBLOCK;
-    }
-    if (written == 0) {
-      return written;
-    }
-    if (written < 0) {
-      status_ = SEC_E_INVALID_HANDLE;
-      state_ = st_error;
-      return TLS_ERR_ERROR;
-    }
-    writeBuf_.eat(written);
-  }
-
-  if (len == 0) {
-    return 0;
-  }
+                   (uint64_t)len, (uint64_t)recordBytesSent_));
 
-  if (!streamSizes_) {
-    streamSizes_.reset(new SecPkgContext_StreamSizes());
-    status_ = ::QueryContextAttributes(&handle_, SECPKG_ATTR_STREAM_SIZES,
-                                       streamSizes_.get());
-    if (status_ != SEC_E_OK || !streamSizes_->cbMaximumMessage) {
-      state_ = st_error;
-      return TLS_ERR_ERROR;
-    }
+  auto rv = sendTLSRecord();
+  if (rv != 0) {
+    return rv;
   }
 
-  size_t process = len;
-  auto bytes = reinterpret_cast<const char*>(data);
+  auto left = len;
+  auto bytes = static_cast<const char*>(data);
   if (writeBuffered_) {
     // There was buffered data, hence we need to "remove" that data from the
     // incoming buffer to avoid writing it again
@@ -331,10 +352,10 @@ ssize_t WinTLSSession::writeData(const void* data, size_t len)
     }
     // just advance the buffer by writeBuffered_ bytes
     bytes += writeBuffered_;
-    process -= writeBuffered_;
+    left -= writeBuffered_;
     writeBuffered_ = 0;
   }
-  if (!process) {
+  if (!left) {
     // The buffer contained the full remainder. At this point, the buffer has
     // been written, so the request is done in its entirety;
     return len;
@@ -342,23 +363,25 @@ ssize_t WinTLSSession::writeData(const void* data, size_t len)
 
   // Buffered data was already written ;)
   // If there was no buffered data, this will be len - len = 0.
-  len = len - process;
-  while (process) {
+  len -= left;
+  while (left) {
     // Set up an outgoing message, according to streamSizes_
-    writeBuffered_ = std::min(process, (size_t)streamSizes_->cbMaximumMessage);
-    size_t dl =
-        streamSizes_->cbHeader + writeBuffered_ + streamSizes_->cbTrailer;
-    auto buf = make_unique<char[]>(dl);
-    TLSBuffer buffers[] = {
-        TLSBuffer(SECBUFFER_STREAM_HEADER, streamSizes_->cbHeader, buf.get()),
+    writeBuffered_ =
+        std::min(left, static_cast<size_t>(streamSizes_.cbMaximumMessage));
+
+    sendRecordBuffers_ = {
+        TLSBuffer(SECBUFFER_STREAM_HEADER, streamSizes_.cbHeader,
+                  sendBuffer_.data()),
         TLSBuffer(SECBUFFER_DATA, writeBuffered_,
-                  buf.get() + streamSizes_->cbHeader),
-        TLSBuffer(SECBUFFER_STREAM_TRAILER, streamSizes_->cbTrailer,
-                  buf.get() + streamSizes_->cbHeader + writeBuffered_),
+                  sendBuffer_.data() + streamSizes_.cbHeader),
+        TLSBuffer(SECBUFFER_STREAM_TRAILER, streamSizes_.cbTrailer,
+                  sendBuffer_.data() + streamSizes_.cbHeader + writeBuffered_),
         TLSBuffer(SECBUFFER_EMPTY, 0, nullptr),
     };
-    TLSBufferDesc desc(buffers, 4);
-    memcpy(buffers[1].pvBuffer, bytes, writeBuffered_);
+
+    TLSBufferDesc desc(sendRecordBuffers_.data(), sendRecordBuffers_.size());
+    std::copy_n(bytes, writeBuffered_,
+                static_cast<char*>(sendRecordBuffers_[1].pvBuffer));
     status_ = ::EncryptMessage(&handle_, 0, &desc, 0);
     if (status_ != SEC_E_OK) {
       A2_LOG_ERROR(fmt("WinTLS: Failed to encrypt a message! %s",
@@ -367,61 +390,32 @@ ssize_t WinTLSSession::writeData(const void* data, size_t len)
       return TLS_ERR_ERROR;
     }
 
-    // EncryptMessage may have truncated the buffers.
-    // Should rarely happen, if ever, except for the trailer.
-    dl = buffers[0].cbBuffer;
-    if (dl < streamSizes_->cbHeader) {
-      // Move message.
-      memmove(buf.get() + dl, buffers[1].pvBuffer, buffers[1].cbBuffer);
-    }
-    dl += buffers[1].cbBuffer;
-    if (dl < streamSizes_->cbHeader + writeBuffered_) {
-      // Move trailer.
-      memmove(buf.get() + dl, buffers[2].pvBuffer, buffers[2].cbBuffer);
-    }
-    dl += buffers[2].cbBuffer;
+    A2_LOG_DEBUG(fmt("WinTLS: Write TLS record header: %" PRIu64
+                     " body: %" PRIu64 " trailer: %" PRIu64,
+                     static_cast<uint64_t>(sendRecordBuffers_[0].cbBuffer),
+                     static_cast<uint64_t>(sendRecordBuffers_[1].cbBuffer),
+                     static_cast<uint64_t>(sendRecordBuffers_[2].cbBuffer)));
 
-    // Write (or buffer) the message.
-    char* p = buf.get();
-    while (dl) {
-      auto written = ::send(sockfd_, p, dl, 0);
-      errno = ::WSAGetLastError();
-      if (written < 0 && errno == WSAEINTR) {
-        continue;
-      }
-      if (written < 0 && errno == WSAEWOULDBLOCK) {
-        // Buffer the rest of the message...
-        writeBuf_.write(p, dl);
-        // and return...
-        return len;
-      }
-      if (written == 0) {
-        A2_LOG_ERROR("WinTLS: Connection closed while writing");
-        status_ = SEC_E_INCOMPLETE_MESSAGE;
-        state_ = st_error;
-        return TLS_ERR_ERROR;
-      }
-      if (written < 0) {
-        A2_LOG_ERROR("WinTLS: Connection error while writing");
-        status_ = SEC_E_INCOMPLETE_MESSAGE;
-        state_ = st_error;
-        return TLS_ERR_ERROR;
+    auto rv = sendTLSRecord();
+    if (rv == TLS_ERR_WOULDBLOCK) {
+      if (len == 0) {
+        return TLS_ERR_WOULDBLOCK;
       }
-      dl -= written;
-      p += written;
+      return len;
+    }
+
+    if (rv != 0) {
+      return rv;
     }
 
     len += writeBuffered_;
     bytes += writeBuffered_;
-    process -= writeBuffered_;
+    left -= writeBuffered_;
     writeBuffered_ = 0;
   }
 
-  A2_LOG_DEBUG(fmt("WinTLS: Write result: %" PRIu64 " buffered: %" PRIu64,
-                   (uint64_t)len, (uint64_t)writeBuf_.size()));
-  if (!len) {
-    return TLS_ERR_WOULDBLOCK;
-  }
+  A2_LOG_DEBUG(fmt("WinTLS: Write result: %" PRIu64, (uint64_t)len));
+
   return len;
 }
 
@@ -777,6 +771,11 @@ restart:
   // Fall through
 
   case st_handshake_done:
+    if (obtainTLSRecordSizes() != 0) {
+      return TLS_ERR_ERROR;
+    }
+    ensureSendBuffer();
+
     // All ready now :D
     state_ = st_connected;
     A2_LOG_INFO(
@@ -833,4 +832,26 @@ std::string WinTLSSession::getLastErrorString()
 
 size_t WinTLSSession::getRecvBufferedLength() { return decBuf_.size(); }
 
+int WinTLSSession::obtainTLSRecordSizes()
+{
+  status_ = ::QueryContextAttributes(&handle_, SECPKG_ATTR_STREAM_SIZES,
+                                     &streamSizes_);
+  if (status_ != SEC_E_OK || !streamSizes_.cbMaximumMessage) {
+    A2_LOG_ERROR("WinTLS: Unable to obtain stream sizes");
+    state_ = st_error;
+    return -1;
+  }
+
+  return 0;
+}
+
+void WinTLSSession::ensureSendBuffer()
+{
+  auto sum = streamSizes_.cbHeader + streamSizes_.cbMaximumMessage +
+             streamSizes_.cbTrailer;
+  if (sendBuffer_.size() < sum) {
+    sendBuffer_.resize(sum);
+  }
+}
+
 } // namespace aria2

+ 40 - 4
src/WinTLSSession.h

@@ -100,6 +100,18 @@ public:
 };
 } // namespace wintls
 
+class TLSBuffer : public ::SecBuffer {
+public:
+  TLSBuffer() : ::SecBuffer{}{}
+
+  explicit TLSBuffer(ULONG type, ULONG size, void* data)
+  {
+    cbBuffer = size;
+    BufferType = type;
+    pvBuffer = data;
+  }
+};
+
 class WinTLSSession : public TLSSession {
   enum state_t {
     st_constructed,
@@ -172,16 +184,31 @@ public:
   virtual size_t getRecvBufferedLength() CXX11_OVERRIDE;
 
 private:
+  // Obtains TLS record size limits.  This function returns 0 if it
+  // succeeds, or -1.  status_ and state_ are updated according to the
+  // result.
+  int obtainTLSRecordSizes();
+  // Ensures the buffer size so that maximum TLS record can be sent.
+  void ensureSendBuffer();
+  // Sends TLS record specified in sendRecordBuffers_.  It uses
+  // recordBytesSent_ to track down how many bytes have been sent.
+  // This function returns 0 if it succeeds, or negative error codes.
+  int sendTLSRecord();
+  // Returns the number of bytes in the remaining TLS record size.
+  size_t getLeftTLSRecordSize() const;
+
   std::string hostname_;
   sock_t sockfd_;
   TLSSessionSide side_;
   CredHandle* cred_;
   CtxtHandle handle_;
 
-  // Buffer for already encrypted writes
+  // Buffer for already encrypted writes.  This is only used in
+  // handshake.
   wintls::Buffer writeBuf_;
-  // While the writeBuf_ holds encrypted messages, writeBuffered_ has the
-  // corresponding size of unencrypted data used to produce the messages.
+  // While the sendRecordBuffers_ holds encrypted messages,
+  // writeBuffered_ has the corresponding size of unencrypted data
+  // used to produce the messages.
   size_t writeBuffered_;
   // Buffer for still encrypted reads
   wintls::Buffer readBuf_;
@@ -191,7 +218,16 @@ private:
   state_t state_;
 
   SECURITY_STATUS status_;
-  std::unique_ptr<SecPkgContext_StreamSizes> streamSizes_;
+  // The number of maximum size for TLS record header, body, and
+  // trailer.
+  SecPkgContext_StreamSizes streamSizes_;
+  // Underlying buffer for outgoing TLS record.
+  std::vector<unsigned char> sendBuffer_;
+  // How many bytes has been sent for current TLS record held in
+  // sendRecordBuffers_.
+  size_t recordBytesSent_;
+  // This holds current outgoing TLS record.
+  std::array<TLSBuffer, 4> sendRecordBuffers_;
 };
 
 } // namespace aria2