|
@@ -86,16 +86,6 @@ static const ULONG kReqAFlags =
|
|
ASC_REQ_SEQUENCE_DETECT | ASC_REQ_REPLAY_DETECT | ASC_REQ_CONFIDENTIALITY |
|
|
ASC_REQ_SEQUENCE_DETECT | ASC_REQ_REPLAY_DETECT | ASC_REQ_CONFIDENTIALITY |
|
|
ASC_REQ_EXTENDED_ERROR | ASC_REQ_ALLOCATE_MEMORY | ASC_REQ_STREAM;
|
|
ASC_REQ_EXTENDED_ERROR | ASC_REQ_ALLOCATE_MEMORY | ASC_REQ_STREAM;
|
|
|
|
|
|
-class TLSBuffer : public ::SecBuffer {
|
|
|
|
-public:
|
|
|
|
- explicit TLSBuffer(ULONG type, ULONG size, void* data)
|
|
|
|
- {
|
|
|
|
- cbBuffer = size;
|
|
|
|
- BufferType = type;
|
|
|
|
- pvBuffer = data;
|
|
|
|
- }
|
|
|
|
-};
|
|
|
|
-
|
|
|
|
class TLSBufferDesc : public ::SecBufferDesc {
|
|
class TLSBufferDesc : public ::SecBufferDesc {
|
|
public:
|
|
public:
|
|
explicit TLSBufferDesc(SecBuffer* arr, ULONG buffers)
|
|
explicit TLSBufferDesc(SecBuffer* arr, ULONG buffers)
|
|
@@ -142,7 +132,8 @@ WinTLSSession::WinTLSSession(WinTLSContext* ctx)
|
|
cred_(ctx->getCredHandle()),
|
|
cred_(ctx->getCredHandle()),
|
|
writeBuffered_(0),
|
|
writeBuffered_(0),
|
|
state_(st_constructed),
|
|
state_(st_constructed),
|
|
- status_(SEC_E_OK)
|
|
|
|
|
|
+ status_(SEC_E_OK),
|
|
|
|
+ recordBytesSent_(0)
|
|
{
|
|
{
|
|
memset(&handle_, 0, sizeof(handle_));
|
|
memset(&handle_, 0, sizeof(handle_));
|
|
}
|
|
}
|
|
@@ -213,7 +204,8 @@ int WinTLSSession::closeConnection()
|
|
status_ = ::AcceptSecurityContext(cred_, &handle_, nullptr, kReqAFlags, 0,
|
|
status_ = ::AcceptSecurityContext(cred_, &handle_, nullptr, kReqAFlags, 0,
|
|
&handle_, &desc, &flags, nullptr);
|
|
&handle_, &desc, &flags, nullptr);
|
|
}
|
|
}
|
|
- if (status_ == SEC_E_OK || status_ == SEC_I_CONTEXT_EXPIRED) {
|
|
|
|
|
|
+ if ((status_ == SEC_E_OK || status_ == SEC_I_CONTEXT_EXPIRED) &&
|
|
|
|
+ getLeftTLSRecordSize() == 0) {
|
|
size_t len = ctx.cbBuffer;
|
|
size_t len = ctx.cbBuffer;
|
|
ssize_t rv = writeData(ctx.pvBuffer, ctx.cbBuffer);
|
|
ssize_t rv = writeData(ctx.pvBuffer, ctx.cbBuffer);
|
|
::FreeContextBuffer(ctx.pvBuffer);
|
|
::FreeContextBuffer(ctx.pvBuffer);
|
|
@@ -228,17 +220,6 @@ int WinTLSSession::closeConnection()
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
- // Send remaining data.
|
|
|
|
- while (writeBuf_.size()) {
|
|
|
|
- int rv = writeData(nullptr, 0);
|
|
|
|
- if (rv == 0) {
|
|
|
|
- break;
|
|
|
|
- }
|
|
|
|
- if (rv < 0) {
|
|
|
|
- return rv;
|
|
|
|
- }
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
A2_LOG_DEBUG("WinTLS: Closed Connection");
|
|
A2_LOG_DEBUG("WinTLS: Closed Connection");
|
|
state_ = st_closed;
|
|
state_ = st_closed;
|
|
return TLS_ERR_OK;
|
|
return TLS_ERR_OK;
|
|
@@ -255,12 +236,82 @@ int WinTLSSession::checkDirection()
|
|
if (readBuf_.size() || decBuf_.size()) {
|
|
if (readBuf_.size() || decBuf_.size()) {
|
|
return TLS_WANT_READ;
|
|
return TLS_WANT_READ;
|
|
}
|
|
}
|
|
- if (writeBuf_.size()) {
|
|
|
|
|
|
+ if (getLeftTLSRecordSize() || writeBuf_.size()) {
|
|
return TLS_WANT_WRITE;
|
|
return TLS_WANT_WRITE;
|
|
}
|
|
}
|
|
return TLS_WANT_READ;
|
|
return TLS_WANT_READ;
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+namespace {
|
|
|
|
+// Fills |iov| of length |len| to send remaining data in |buffers|.
|
|
|
|
+// We have already sent |offset| bytes. This function returns the
|
|
|
|
+// number of |iov| filled. It assumes the array |buffers| is at least
|
|
|
|
+// |len| elements.
|
|
|
|
+size_t fillSendIOV(a2iovec* iov, size_t len, TLSBuffer* buffers, size_t offset)
|
|
|
|
+{
|
|
|
|
+ size_t iovcnt = 0;
|
|
|
|
+ for (size_t i = 0; i < len; ++i) {
|
|
|
|
+ if (offset < buffers[i].cbBuffer) {
|
|
|
|
+ iov[iovcnt].A2IOVEC_BASE =
|
|
|
|
+ static_cast<char*>(buffers[i].pvBuffer) + offset;
|
|
|
|
+ iov[iovcnt].A2IOVEC_LEN = buffers[i].cbBuffer - offset;
|
|
|
|
+ ++iovcnt;
|
|
|
|
+ offset = 0;
|
|
|
|
+ }
|
|
|
|
+ else {
|
|
|
|
+ offset -= buffers[i].cbBuffer;
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ return iovcnt;
|
|
|
|
+}
|
|
|
|
+} // namespace
|
|
|
|
+
|
|
|
|
+size_t WinTLSSession::getLeftTLSRecordSize() const
|
|
|
|
+{
|
|
|
|
+ return sendRecordBuffers_[0].cbBuffer + sendRecordBuffers_[1].cbBuffer +
|
|
|
|
+ sendRecordBuffers_[2].cbBuffer - recordBytesSent_;
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+int WinTLSSession::sendTLSRecord()
|
|
|
|
+{
|
|
|
|
+ A2_LOG_DEBUG(fmt("WinTLS: TLS record %" PRIu64 " bytes left",
|
|
|
|
+ static_cast<uint64_t>(getLeftTLSRecordSize())));
|
|
|
|
+
|
|
|
|
+ while (getLeftTLSRecordSize()) {
|
|
|
|
+ std::array<a2iovec, 3> iov;
|
|
|
|
+ auto iovcnt = fillSendIOV(iov.data(), iov.size(), sendRecordBuffers_.data(),
|
|
|
|
+ recordBytesSent_);
|
|
|
|
+
|
|
|
|
+ DWORD nwrite;
|
|
|
|
+ auto rv =
|
|
|
|
+ WSASend(sockfd_, iov.data(), iovcnt, &nwrite, 0, nullptr, nullptr);
|
|
|
|
+ if (rv != 0) {
|
|
|
|
+ auto errnum = ::WSAGetLastError();
|
|
|
|
+ if (errnum == WSAEINTR) {
|
|
|
|
+ continue;
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ if (errnum == WSAEWOULDBLOCK) {
|
|
|
|
+ return TLS_ERR_WOULDBLOCK;
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ A2_LOG_ERROR("WinTLS: Connection error while writing");
|
|
|
|
+ status_ = SEC_E_INCOMPLETE_MESSAGE;
|
|
|
|
+ state_ = st_error;
|
|
|
|
+ return TLS_ERR_ERROR;
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ recordBytesSent_ += nwrite;
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ recordBytesSent_ = 0;
|
|
|
|
+ sendRecordBuffers_[0].cbBuffer = 0;
|
|
|
|
+ sendRecordBuffers_[1].cbBuffer = 0;
|
|
|
|
+ sendRecordBuffers_[2].cbBuffer = 0;
|
|
|
|
+
|
|
|
|
+ return 0;
|
|
|
|
+}
|
|
|
|
+
|
|
ssize_t WinTLSSession::writeData(const void* data, size_t len)
|
|
ssize_t WinTLSSession::writeData(const void* data, size_t len)
|
|
{
|
|
{
|
|
if (state_ == st_handshake_write || state_ == st_handshake_write_last ||
|
|
if (state_ == st_handshake_write || state_ == st_handshake_write_last ||
|
|
@@ -281,45 +332,15 @@ ssize_t WinTLSSession::writeData(const void* data, size_t len)
|
|
}
|
|
}
|
|
|
|
|
|
A2_LOG_DEBUG(fmt("WinTLS: Write request: %" PRIu64 " buffered: %" PRIu64,
|
|
A2_LOG_DEBUG(fmt("WinTLS: Write request: %" PRIu64 " buffered: %" PRIu64,
|
|
- (uint64_t)len, (uint64_t)writeBuf_.size()));
|
|
|
|
-
|
|
|
|
- // Write remaining buffered data, if any.
|
|
|
|
- while (writeBuf_.size()) {
|
|
|
|
- auto written = ::send(sockfd_, writeBuf_.data(), writeBuf_.size(), 0);
|
|
|
|
- errno = ::WSAGetLastError();
|
|
|
|
- if (written < 0 && errno == WSAEINTR) {
|
|
|
|
- continue;
|
|
|
|
- }
|
|
|
|
- if (written < 0 && errno == WSAEWOULDBLOCK) {
|
|
|
|
- return TLS_ERR_WOULDBLOCK;
|
|
|
|
- }
|
|
|
|
- if (written == 0) {
|
|
|
|
- return written;
|
|
|
|
- }
|
|
|
|
- if (written < 0) {
|
|
|
|
- status_ = SEC_E_INVALID_HANDLE;
|
|
|
|
- state_ = st_error;
|
|
|
|
- return TLS_ERR_ERROR;
|
|
|
|
- }
|
|
|
|
- writeBuf_.eat(written);
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- if (len == 0) {
|
|
|
|
- return 0;
|
|
|
|
- }
|
|
|
|
|
|
+ (uint64_t)len, (uint64_t)recordBytesSent_));
|
|
|
|
|
|
- if (!streamSizes_) {
|
|
|
|
- streamSizes_.reset(new SecPkgContext_StreamSizes());
|
|
|
|
- status_ = ::QueryContextAttributes(&handle_, SECPKG_ATTR_STREAM_SIZES,
|
|
|
|
- streamSizes_.get());
|
|
|
|
- if (status_ != SEC_E_OK || !streamSizes_->cbMaximumMessage) {
|
|
|
|
- state_ = st_error;
|
|
|
|
- return TLS_ERR_ERROR;
|
|
|
|
- }
|
|
|
|
|
|
+ auto rv = sendTLSRecord();
|
|
|
|
+ if (rv != 0) {
|
|
|
|
+ return rv;
|
|
}
|
|
}
|
|
|
|
|
|
- size_t process = len;
|
|
|
|
- auto bytes = reinterpret_cast<const char*>(data);
|
|
|
|
|
|
+ auto left = len;
|
|
|
|
+ auto bytes = static_cast<const char*>(data);
|
|
if (writeBuffered_) {
|
|
if (writeBuffered_) {
|
|
// There was buffered data, hence we need to "remove" that data from the
|
|
// There was buffered data, hence we need to "remove" that data from the
|
|
// incoming buffer to avoid writing it again
|
|
// incoming buffer to avoid writing it again
|
|
@@ -331,10 +352,10 @@ ssize_t WinTLSSession::writeData(const void* data, size_t len)
|
|
}
|
|
}
|
|
// just advance the buffer by writeBuffered_ bytes
|
|
// just advance the buffer by writeBuffered_ bytes
|
|
bytes += writeBuffered_;
|
|
bytes += writeBuffered_;
|
|
- process -= writeBuffered_;
|
|
|
|
|
|
+ left -= writeBuffered_;
|
|
writeBuffered_ = 0;
|
|
writeBuffered_ = 0;
|
|
}
|
|
}
|
|
- if (!process) {
|
|
|
|
|
|
+ if (!left) {
|
|
// The buffer contained the full remainder. At this point, the buffer has
|
|
// The buffer contained the full remainder. At this point, the buffer has
|
|
// been written, so the request is done in its entirety;
|
|
// been written, so the request is done in its entirety;
|
|
return len;
|
|
return len;
|
|
@@ -342,23 +363,25 @@ ssize_t WinTLSSession::writeData(const void* data, size_t len)
|
|
|
|
|
|
// Buffered data was already written ;)
|
|
// Buffered data was already written ;)
|
|
// If there was no buffered data, this will be len - len = 0.
|
|
// If there was no buffered data, this will be len - len = 0.
|
|
- len = len - process;
|
|
|
|
- while (process) {
|
|
|
|
|
|
+ len -= left;
|
|
|
|
+ while (left) {
|
|
// Set up an outgoing message, according to streamSizes_
|
|
// Set up an outgoing message, according to streamSizes_
|
|
- writeBuffered_ = std::min(process, (size_t)streamSizes_->cbMaximumMessage);
|
|
|
|
- size_t dl =
|
|
|
|
- streamSizes_->cbHeader + writeBuffered_ + streamSizes_->cbTrailer;
|
|
|
|
- auto buf = make_unique<char[]>(dl);
|
|
|
|
- TLSBuffer buffers[] = {
|
|
|
|
- TLSBuffer(SECBUFFER_STREAM_HEADER, streamSizes_->cbHeader, buf.get()),
|
|
|
|
|
|
+ writeBuffered_ =
|
|
|
|
+ std::min(left, static_cast<size_t>(streamSizes_.cbMaximumMessage));
|
|
|
|
+
|
|
|
|
+ sendRecordBuffers_ = {
|
|
|
|
+ TLSBuffer(SECBUFFER_STREAM_HEADER, streamSizes_.cbHeader,
|
|
|
|
+ sendBuffer_.data()),
|
|
TLSBuffer(SECBUFFER_DATA, writeBuffered_,
|
|
TLSBuffer(SECBUFFER_DATA, writeBuffered_,
|
|
- buf.get() + streamSizes_->cbHeader),
|
|
|
|
- TLSBuffer(SECBUFFER_STREAM_TRAILER, streamSizes_->cbTrailer,
|
|
|
|
- buf.get() + streamSizes_->cbHeader + writeBuffered_),
|
|
|
|
|
|
+ sendBuffer_.data() + streamSizes_.cbHeader),
|
|
|
|
+ TLSBuffer(SECBUFFER_STREAM_TRAILER, streamSizes_.cbTrailer,
|
|
|
|
+ sendBuffer_.data() + streamSizes_.cbHeader + writeBuffered_),
|
|
TLSBuffer(SECBUFFER_EMPTY, 0, nullptr),
|
|
TLSBuffer(SECBUFFER_EMPTY, 0, nullptr),
|
|
};
|
|
};
|
|
- TLSBufferDesc desc(buffers, 4);
|
|
|
|
- memcpy(buffers[1].pvBuffer, bytes, writeBuffered_);
|
|
|
|
|
|
+
|
|
|
|
+ TLSBufferDesc desc(sendRecordBuffers_.data(), sendRecordBuffers_.size());
|
|
|
|
+ std::copy_n(bytes, writeBuffered_,
|
|
|
|
+ static_cast<char*>(sendRecordBuffers_[1].pvBuffer));
|
|
status_ = ::EncryptMessage(&handle_, 0, &desc, 0);
|
|
status_ = ::EncryptMessage(&handle_, 0, &desc, 0);
|
|
if (status_ != SEC_E_OK) {
|
|
if (status_ != SEC_E_OK) {
|
|
A2_LOG_ERROR(fmt("WinTLS: Failed to encrypt a message! %s",
|
|
A2_LOG_ERROR(fmt("WinTLS: Failed to encrypt a message! %s",
|
|
@@ -367,61 +390,32 @@ ssize_t WinTLSSession::writeData(const void* data, size_t len)
|
|
return TLS_ERR_ERROR;
|
|
return TLS_ERR_ERROR;
|
|
}
|
|
}
|
|
|
|
|
|
- // EncryptMessage may have truncated the buffers.
|
|
|
|
- // Should rarely happen, if ever, except for the trailer.
|
|
|
|
- dl = buffers[0].cbBuffer;
|
|
|
|
- if (dl < streamSizes_->cbHeader) {
|
|
|
|
- // Move message.
|
|
|
|
- memmove(buf.get() + dl, buffers[1].pvBuffer, buffers[1].cbBuffer);
|
|
|
|
- }
|
|
|
|
- dl += buffers[1].cbBuffer;
|
|
|
|
- if (dl < streamSizes_->cbHeader + writeBuffered_) {
|
|
|
|
- // Move trailer.
|
|
|
|
- memmove(buf.get() + dl, buffers[2].pvBuffer, buffers[2].cbBuffer);
|
|
|
|
- }
|
|
|
|
- dl += buffers[2].cbBuffer;
|
|
|
|
|
|
+ A2_LOG_DEBUG(fmt("WinTLS: Write TLS record header: %" PRIu64
|
|
|
|
+ " body: %" PRIu64 " trailer: %" PRIu64,
|
|
|
|
+ static_cast<uint64_t>(sendRecordBuffers_[0].cbBuffer),
|
|
|
|
+ static_cast<uint64_t>(sendRecordBuffers_[1].cbBuffer),
|
|
|
|
+ static_cast<uint64_t>(sendRecordBuffers_[2].cbBuffer)));
|
|
|
|
|
|
- // Write (or buffer) the message.
|
|
|
|
- char* p = buf.get();
|
|
|
|
- while (dl) {
|
|
|
|
- auto written = ::send(sockfd_, p, dl, 0);
|
|
|
|
- errno = ::WSAGetLastError();
|
|
|
|
- if (written < 0 && errno == WSAEINTR) {
|
|
|
|
- continue;
|
|
|
|
- }
|
|
|
|
- if (written < 0 && errno == WSAEWOULDBLOCK) {
|
|
|
|
- // Buffer the rest of the message...
|
|
|
|
- writeBuf_.write(p, dl);
|
|
|
|
- // and return...
|
|
|
|
- return len;
|
|
|
|
- }
|
|
|
|
- if (written == 0) {
|
|
|
|
- A2_LOG_ERROR("WinTLS: Connection closed while writing");
|
|
|
|
- status_ = SEC_E_INCOMPLETE_MESSAGE;
|
|
|
|
- state_ = st_error;
|
|
|
|
- return TLS_ERR_ERROR;
|
|
|
|
- }
|
|
|
|
- if (written < 0) {
|
|
|
|
- A2_LOG_ERROR("WinTLS: Connection error while writing");
|
|
|
|
- status_ = SEC_E_INCOMPLETE_MESSAGE;
|
|
|
|
- state_ = st_error;
|
|
|
|
- return TLS_ERR_ERROR;
|
|
|
|
|
|
+ auto rv = sendTLSRecord();
|
|
|
|
+ if (rv == TLS_ERR_WOULDBLOCK) {
|
|
|
|
+ if (len == 0) {
|
|
|
|
+ return TLS_ERR_WOULDBLOCK;
|
|
}
|
|
}
|
|
- dl -= written;
|
|
|
|
- p += written;
|
|
|
|
|
|
+ return len;
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ if (rv != 0) {
|
|
|
|
+ return rv;
|
|
}
|
|
}
|
|
|
|
|
|
len += writeBuffered_;
|
|
len += writeBuffered_;
|
|
bytes += writeBuffered_;
|
|
bytes += writeBuffered_;
|
|
- process -= writeBuffered_;
|
|
|
|
|
|
+ left -= writeBuffered_;
|
|
writeBuffered_ = 0;
|
|
writeBuffered_ = 0;
|
|
}
|
|
}
|
|
|
|
|
|
- A2_LOG_DEBUG(fmt("WinTLS: Write result: %" PRIu64 " buffered: %" PRIu64,
|
|
|
|
- (uint64_t)len, (uint64_t)writeBuf_.size()));
|
|
|
|
- if (!len) {
|
|
|
|
- return TLS_ERR_WOULDBLOCK;
|
|
|
|
- }
|
|
|
|
|
|
+ A2_LOG_DEBUG(fmt("WinTLS: Write result: %" PRIu64, (uint64_t)len));
|
|
|
|
+
|
|
return len;
|
|
return len;
|
|
}
|
|
}
|
|
|
|
|
|
@@ -777,6 +771,11 @@ restart:
|
|
// Fall through
|
|
// Fall through
|
|
|
|
|
|
case st_handshake_done:
|
|
case st_handshake_done:
|
|
|
|
+ if (obtainTLSRecordSizes() != 0) {
|
|
|
|
+ return TLS_ERR_ERROR;
|
|
|
|
+ }
|
|
|
|
+ ensureSendBuffer();
|
|
|
|
+
|
|
// All ready now :D
|
|
// All ready now :D
|
|
state_ = st_connected;
|
|
state_ = st_connected;
|
|
A2_LOG_INFO(
|
|
A2_LOG_INFO(
|
|
@@ -833,4 +832,26 @@ std::string WinTLSSession::getLastErrorString()
|
|
|
|
|
|
size_t WinTLSSession::getRecvBufferedLength() { return decBuf_.size(); }
|
|
size_t WinTLSSession::getRecvBufferedLength() { return decBuf_.size(); }
|
|
|
|
|
|
|
|
+int WinTLSSession::obtainTLSRecordSizes()
|
|
|
|
+{
|
|
|
|
+ status_ = ::QueryContextAttributes(&handle_, SECPKG_ATTR_STREAM_SIZES,
|
|
|
|
+ &streamSizes_);
|
|
|
|
+ if (status_ != SEC_E_OK || !streamSizes_.cbMaximumMessage) {
|
|
|
|
+ A2_LOG_ERROR("WinTLS: Unable to obtain stream sizes");
|
|
|
|
+ state_ = st_error;
|
|
|
|
+ return -1;
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ return 0;
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+void WinTLSSession::ensureSendBuffer()
|
|
|
|
+{
|
|
|
|
+ auto sum = streamSizes_.cbHeader + streamSizes_.cbMaximumMessage +
|
|
|
|
+ streamSizes_.cbTrailer;
|
|
|
|
+ if (sendBuffer_.size() < sum) {
|
|
|
|
+ sendBuffer_.resize(sum);
|
|
|
|
+ }
|
|
|
|
+}
|
|
|
|
+
|
|
} // namespace aria2
|
|
} // namespace aria2
|