浏览代码

Implement WinTLS

Nils Maier 12 年之前
父节点
当前提交
00dd83b461
共有 7 个文件被更改,包括 1354 次插入13 次删除
  1. 30 12
      configure.ac
  2. 5 0
      src/Makefile.am
  3. 4 1
      src/SocketCore.cc
  4. 189 0
      src/WinTLSContext.cc
  5. 116 0
      src/WinTLSContext.h
  6. 816 0
      src/WinTLSSession.cc
  7. 194 0
      src/WinTLSSession.h

+ 30 - 12
configure.ac

@@ -40,7 +40,7 @@ AC_DEFINE_UNQUOTED([TARGET], ["$target"], [Define target-type])
 # Checks for arguments.
 ARIA2_ARG_WITHOUT([libuv])
 ARIA2_ARG_WITHOUT([appletls])
-ARIA2_ARG_WITHOUT([wintls])
+ARIA2_ARG_WITH([wintls])
 ARIA2_ARG_WITHOUT([gnutls])
 ARIA2_ARG_WITHOUT([libnettle])
 ARIA2_ARG_WITHOUT([libgmp])
@@ -337,23 +337,39 @@ if test "x$with_appletls" = "xyes"; then
 fi
 
 if test "x$with_wintls" = "xyes"; then
-  AC_SEARCH_LIBS([CryptAcquireContextW], [advapi32], [
-                  AC_CHECK_HEADER([wincrypt.h], [have_wincrypt=yes], [have_wincrypt=no],
-                      [[
+  AC_HAVE_LIBRARY([crypt32],[have_wintls_libs=yes],[have_wintls_libs=no])
+  AC_HAVE_LIBRARY([secur32],[have_wintls_libs=$have_wintls_libs],[have_wintls_libs=no])
+  AC_HAVE_LIBRARY([advapi32],[have_wintls_libs=$have_wintls_libs],[have_wintls_libs=no])
+  AC_CHECK_HEADER([wincrypt.h], [have_wintls_headers=yes], [have_wintls_headers=no], [[
 #ifdef HAVE_WINDOWS_H
 # include <windows.h>
 #endif
-                      ]])
-                  break;
-                  ], [have_wincrypt=no])
-  if test "x$have_wincrypt" != "xyes"; then
+  ]])
+  AC_CHECK_HEADER([security.h], [have_wintls_headers=$have_wintls_headers], [have_wintls_headers=no], [[
+#ifdef HAVE_WINDOWS_H
+# include <windows.h>
+#endif
+#ifndef SECURITY_WIN32
+#define SECURITY_WIN32 1
+#endif
+  ]])
+
+  if test "x$have_wintls_libs" = "xyes" &&
+     test "x$have_wintls_headers" = "xyes"; then
+    AC_DEFINE([SECURITY_WIN32], [1], [Use security.h in WIN32 mode])
+    LIBS="$LIBS -lcrypt32 -lsecur32 -ladvapi32"
+    have_wintls=yes
+  else
+    have_wintls=no
+  fi
+  if test "x$have_wintls" != "xyes"; then
     if test "x$with_wintls_requested" = "xyes"; then
       ARIA2_DEP_NOT_MET([wintls])
     fi
   fi
 fi
 
-if test "x$with_gnutls" = "xyes" && test "x$have_appletls" != "xyes"; then
+if test "x$with_gnutls" = "xyes" && test "x$have_appletls" != "xyes" && test "x$have_wintls" != "xyes"; then
   # gnutls >= 2.8 doesn't have libgnutls-config anymore. We require
   # 2.2.0 because we use gnutls_priority_set_direct()
   PKG_CHECK_MODULES([LIBGNUTLS], [gnutls >= 2.2.0],
@@ -371,7 +387,7 @@ if test "x$with_gnutls" = "xyes" && test "x$have_appletls" != "xyes"; then
   fi
 fi
 
-if test "x$with_openssl" = "xyes" && test "x$have_appletls" != "xyes" && test "x$have_libgnutls" != "xyes"; then
+if test "x$with_openssl" = "xyes" && test "x$have_appletls" != "xyes" && test "x$have_wintls" != "xyes" && test "x$have_libgnutls" != "xyes"; then
   PKG_CHECK_MODULES([OPENSSL], [openssl >= 0.9.8],
                     [have_openssl=yes], [have_openssl=no])
   if test "x$have_openssl" = "xyes"; then
@@ -448,7 +464,7 @@ if test "x$have_appletls" == "xyes"; then
   use_md="apple"
   AC_DEFINE([USE_APPLE_MD], [1], [What message digest implementation to use])
 else
-  if test "x$have_wincrypt" == "xyes"; then
+  if test "x$have_wintls" == "xyes"; then
     use_md="windows"
     AC_DEFINE([USE_WINDOWS_MD], [1], [What message digest implementation to use])
   else
@@ -473,7 +489,7 @@ else
 fi
 
 # Define variables based on the result of the checks for libraries.
-if test "x$have_appletls" = "xyes" || test "x$have_libgnutls" = "xyes" || test "x$have_openssl" = "xyes"; then
+if test "x$have_appletls" = "xyes" || test "x$have_wintls" == "xyes" || test "x$have_libgnutls" = "xyes" || test "x$have_openssl" = "xyes"; then
   have_ssl="yes"
   AC_DEFINE([ENABLE_SSL], [1], [Define to 1 if ssl support is enabled.])
   AM_CONDITIONAL([ENABLE_SSL], true)
@@ -485,6 +501,7 @@ fi
 
 AM_CONDITIONAL([HAVE_OSX], [ test "x$have_osx" = "xyes" ])
 AM_CONDITIONAL([HAVE_APPLETLS], [ test "x$have_appletls" = "xyes" ])
+AM_CONDITIONAL([HAVE_WINTLS], [ test "x$have_wintls" = "xyes" ])
 AM_CONDITIONAL([USE_APPLE_MD], [ test "x$use_md" = "xapple" ])
 AM_CONDITIONAL([USE_WINDOWS_MD], [ test "x$use_md" = "xwindows" ])
 AM_CONDITIONAL([HAVE_LIBGNUTLS], [ test "x$have_libgnutls" = "xyes" ])
@@ -985,6 +1002,7 @@ echo "LibUV:          $have_libuv"
 echo "SQLite3:        $have_sqlite3"
 echo "SSL Support:    $have_ssl"
 echo "AppleTLS:       $have_appletls"
+echo "WinTLS:         $have_wintls"
 echo "GnuTLS:         $have_libgnutls"
 echo "OpenSSL:        $have_openssl"
 echo "CA Bundle:      $ca_bundle"

+ 5 - 0
src/Makefile.am

@@ -333,6 +333,11 @@ if USE_WINDOWS_MD
 SRCS += WinMessageDigestImpl.cc
 endif # USE_WINDOWS_MD
 
+if HAVE_WINTLS
+SRCS += WinTLSContext.cc WinTLSContext.h \
+	WinTLSSession.cc WinTLSSession.h
+endif # HAVE_WINTLS
+
 if USE_INTERNAL_BIGNUM
 SRCS += InternalDHKeyExchange.cc InternalDHKeyExchange.h bignum.h
 endif

+ 4 - 1
src/SocketCore.cc

@@ -779,7 +779,7 @@ void SocketCore::readData(void* data, size_t& len)
     ret = tlsSession_->readData(data, len);
     if(ret < 0) {
       if(ret != TLS_ERR_WOULDBLOCK) {
-        throw DL_RETRY_EX(fmt(EX_SOCKET_SEND,
+        throw DL_RETRY_EX(fmt(EX_SOCKET_RECV,
                               tlsSession_->getLastErrorString().c_str()));
       }
       if(tlsSession_->checkDirection() == TLS_WANT_READ) {
@@ -814,6 +814,7 @@ bool SocketCore::tlsHandshake(TLSContext* tlsctx, const std::string& hostname)
   wantWrite_ = false;
   switch(secure_) {
   case A2_TLS_NONE:
+    A2_LOG_DEBUG("Creating TLS session");
     tlsSession_.reset(TLSSession::make(tlsctx));
     rv = tlsSession_->init(sockfd_);
     if(rv != TLS_ERR_OK) {
@@ -835,6 +836,7 @@ bool SocketCore::tlsHandshake(TLSContext* tlsctx, const std::string& hostname)
     secure_ = A2_TLS_HANDSHAKING;
     // Fall through
   case A2_TLS_HANDSHAKING:
+    A2_LOG_DEBUG("TLS Handshaking");
     if(tlsctx->getSide() == TLS_CLIENT) {
       rv = tlsSession_->tlsConnect(hostname, handshakeError);
     } else {
@@ -857,6 +859,7 @@ bool SocketCore::tlsHandshake(TLSContext* tlsctx, const std::string& hostname)
     }
     return false;
   default:
+    A2_LOG_DEBUG("TLS else");
     break;
   }
   return true;

+ 189 - 0
src/WinTLSContext.cc

@@ -0,0 +1,189 @@
+/* <!-- copyright */
+/*
+ * aria2 - The high speed download utility
+ *
+ * Copyright (C) 2013 Nils Maier
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 2 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ *
+ * In addition, as a special exception, the copyright holders give
+ * permission to link the code of portions of this program with the
+ * OpenSSL library under certain conditions as described in each
+ * individual source file, and distribute linked combinations
+ * including the two.
+ * You must obey the GNU General Public License in all respects
+ * for all of the code used other than OpenSSL.  If you modify
+ * file(s) with this exception, you may extend this exception to your
+ * version of the file(s), but you are not obligated to do so.  If you
+ * do not wish to do so, delete this exception statement from your
+ * version.  If you delete this exception statement from all source
+ * files in the program, then also delete it here.
+ */
+/* copyright --> */
+
+#include "WinTLSContext.h"
+
+#include <sstream>
+
+#include "BufferedFile.h"
+#include "LogFactory.h"
+#include "Logger.h"
+#include "fmt.h"
+#include "message.h"
+#include "util.h"
+
+namespace aria2 {
+
+WinTLSContext::WinTLSContext(TLSSessionSide side)
+  : side_(side), store_(0)
+{
+  memset(&credentials_, 0, sizeof(credentials_));
+  credentials_.dwVersion = SCHANNEL_CRED_VERSION;
+  if (side_ == TLS_CLIENT) {
+    credentials_.grbitEnabledProtocols =
+      SP_PROT_SSL3_CLIENT |
+      SP_PROT_TLS1_CLIENT |
+      SP_PROT_TLS1_1_CLIENT |
+      SP_PROT_TLS1_2_CLIENT;
+  }
+  else {
+    credentials_.grbitEnabledProtocols =
+      SP_PROT_SSL3_SERVER |
+      SP_PROT_TLS1_SERVER |
+      SP_PROT_TLS1_1_SERVER |
+      SP_PROT_TLS1_2_SERVER;
+  }
+  credentials_.dwMinimumCipherStrength = 128; // bit
+
+  setVerifyPeer(side_ == TLS_CLIENT);
+}
+
+TLSContext* TLSContext::make(TLSSessionSide side)
+{
+  return new WinTLSContext(side);
+}
+
+WinTLSContext::~WinTLSContext()
+{
+  if (store_) {
+    CertCloseStore(store_, 0);
+    store_ = 0;
+  }
+}
+
+bool WinTLSContext::getVerifyPeer() const
+{
+  return credentials_.dwFlags & SCH_CRED_AUTO_CRED_VALIDATION;
+}
+
+void WinTLSContext::setVerifyPeer(bool verify)
+{
+  if (side_ == TLS_CLIENT && verify) {
+    credentials_.dwFlags =
+      SCH_CRED_NO_DEFAULT_CREDS |
+      SCH_CRED_AUTO_CRED_VALIDATION |
+      SCH_CRED_REVOCATION_CHECK_CHAIN;
+  }
+  else {
+    credentials_.dwFlags =
+      SCH_CRED_NO_DEFAULT_CREDS |
+      SCH_CRED_MANUAL_CRED_VALIDATION |
+      SCH_CRED_IGNORE_NO_REVOCATION_CHECK |
+      SCH_CRED_IGNORE_REVOCATION_OFFLINE |
+      SCH_CRED_NO_SERVERNAME_CHECK;
+  }
+
+  // Need to initialize cred_ early, because later on it will segfault deep
+  // within AcquireCredentialsHandle for whatever reason.
+  cred_.reset();
+  getCredHandle();
+}
+
+CredHandle* WinTLSContext::getCredHandle()
+{
+  if (cred_) {
+    return cred_.get();
+  }
+
+  TimeStamp ts;
+  cred_.reset(new CredHandle());
+  SECURITY_STATUS status = ::AcquireCredentialsHandleW(
+      nullptr,
+      (SEC_WCHAR*)UNISP_NAME_W,
+      side_ == TLS_CLIENT ? SECPKG_CRED_OUTBOUND : SECPKG_CRED_INBOUND,
+      nullptr,
+      &credentials_,
+      nullptr,
+      nullptr,
+      cred_.get(),
+      &ts);
+  if (status != SEC_E_OK) {
+    cred_.reset();
+    throw DL_ABORT_EX("Failed to initialize WinTLS context handle");
+  }
+  return cred_.get();
+}
+
+bool WinTLSContext::addCredentialFile(const std::string& certfile,
+                                        const std::string& keyfile)
+{
+  std::stringstream ss;
+  BufferedFile(certfile.c_str(), "rb").transfer(ss);
+  auto data = ss.str();
+  CRYPT_DATA_BLOB blob = {
+    (DWORD)data.length(),
+    (BYTE*)data.c_str()
+  };
+  if (!PFXIsPFXBlob(&blob)) {
+    A2_LOG_ERROR("Not a valid PKCS12 file");
+    return false;
+  }
+  store_ = ::PFXImportCertStore(&blob, L"",
+                                CRYPT_EXPORTABLE | CRYPT_USER_KEYSET);
+  if (!store_) {
+    store_ = ::PFXImportCertStore(&blob, nullptr,
+                                  CRYPT_EXPORTABLE | CRYPT_USER_KEYSET);
+  }
+  if (!store_) {
+    A2_LOG_ERROR("Failed to import PKCS12 store");
+    return false;
+  }
+
+  const CERT_CONTEXT* ctx = ::CertEnumCertificatesInStore(store_, nullptr);
+  if (!ctx) {
+    A2_LOG_ERROR("Failed to read any certificates from the PKCS12 store");
+    return false;
+  }
+  credentials_.cCreds = 1;
+  credentials_.paCred = &ctx;
+
+  // Need to initialize cred_ early, because later on it will segfault deep
+  // within AcquireCredentialsHandle for whatever reason.
+  cred_.reset();
+  getCredHandle();
+
+  CertFreeCertificateContext(ctx);
+
+  return true;
+}
+
+bool WinTLSContext::addTrustedCACertFile(const std::string& certfile)
+{
+  A2_LOG_INFO("TLS CA bundle files are not supported. "
+              "The system trust store will be used.");
+  return false;
+}
+
+} // namespace aria2

+ 116 - 0
src/WinTLSContext.h

@@ -0,0 +1,116 @@
+/* <!-- copyright */
+/*
+ * aria2 - The high speed download utility
+ *
+ * Copyright (C) 2013 Nils Maier
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 2 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ *
+ * In addition, as a special exception, the copyright holders give
+ * permission to link the code of portions of this program with the
+ * OpenSSL library under certain conditions as described in each
+ * individual source file, and distribute linked combinations
+ * including the two.
+ * You must obey the GNU General Public License in all respects
+ * for all of the code used other than OpenSSL.  If you modify
+ * file(s) with this exception, you may extend this exception to your
+ * version of the file(s), but you are not obligated to do so.  If you
+ * do not wish to do so, delete this exception statement from your
+ * version.  If you delete this exception statement from all source
+ * files in the program, then also delete it here.
+ */
+/* copyright --> */
+
+#ifndef D_WIN_TLS_CONTEXT_H
+#define D_WIN_TLS_CONTEXT_H
+
+#include "common.h"
+#include "config.h"
+
+#include <string>
+
+#include <windows.h>
+#include <security.h>
+#include <schnlsp.h>
+
+#include "TLSContext.h"
+#include "DlAbortEx.h"
+
+#ifndef SP_PROT_TLS1_1_CLIENT
+#define SP_PROT_TLS1_1_CLIENT 0x00000200
+#endif
+#ifndef SP_PROT_TLS1_1_SERVER
+#define SP_PROT_TLS1_1_SERVER 0x00000100
+#endif
+#ifndef SP_PROT_TLS1_2_CLIENT
+#define SP_PROT_TLS1_2_CLIENT 0x00000800
+#endif
+#ifndef SP_PROT_TLS1_2_SERVER
+#define SP_PROT_TLS1_2_SERVER 0x00000400
+#endif
+
+namespace aria2 {
+
+namespace wintls {
+  struct cred_deleter{
+    void operator()(CredHandle* handle) {
+      if (handle) {
+        FreeCredentialsHandle(handle);
+        delete handle;
+      }
+    }
+  };
+  typedef std::unique_ptr<CredHandle, cred_deleter> CredPtr;
+} // namespace wintls
+
+class WinTLSContext : public TLSContext {
+public:
+  WinTLSContext(TLSSessionSide side);
+  virtual ~WinTLSContext();
+
+  // private key `keyfile' must be decrypted.
+  virtual bool addCredentialFile(const std::string& certfile,
+                                 const std::string& keyfile) CXX11_OVERRIDE;
+
+  virtual bool addSystemTrustedCACerts() CXX11_OVERRIDE {
+    return true;
+  }
+
+  // certfile can contain multiple certificates.
+  virtual bool addTrustedCACertFile(const std::string& certfile)
+    CXX11_OVERRIDE;
+
+  virtual bool good() const CXX11_OVERRIDE {
+    return true;
+  }
+  virtual TLSSessionSide getSide() const CXX11_OVERRIDE {
+    return side_;
+  }
+
+  virtual bool getVerifyPeer() const CXX11_OVERRIDE;
+  virtual void setVerifyPeer(bool verify) CXX11_OVERRIDE;
+
+  CredHandle* getCredHandle();
+
+private:
+  TLSSessionSide side_;
+  SCHANNEL_CRED credentials_;
+  HCERTSTORE store_;
+  wintls::CredPtr cred_;
+};
+
+} // namespace aria2
+
+#endif // D_LIBSSL_TLS_CONTEXT_H

+ 816 - 0
src/WinTLSSession.cc

@@ -0,0 +1,816 @@
+/* <!-- copyright */
+/*
+ * aria2 - The high speed download utility
+ *
+ * Copyright (C) 2013 Nils Maier
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 2 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ *
+ * In addition, as a special exception, the copyright holders give
+ * permission to link the code of portions of this program with the
+ * OpenSSL library under certain conditions as described in each
+ * individual source file, and distribute linked combinations
+ * including the two.
+ * You must obey the GNU General Public License in all respects
+ * for all of the code used other than OpenSSL.  If you modify
+ * file(s) with this exception, you may extend this exception to your
+ * version of the file(s), but you are not obligated to do so.  If you
+ * do not wish to do so, delete this exception statement from your
+ * version.  If you delete this exception statement from all source
+ * files in the program, then also delete it here.
+ */
+/* copyright --> */
+
+#include "WinTLSSession.h"
+
+#include <sstream>
+
+#include "LogFactory.h"
+#include "a2functional.h"
+#include "fmt.h"
+#include "util.h"
+
+#ifndef SECBUFFER_ALERT
+#define SECBUFFER_ALERT 17
+#endif
+
+#ifndef SZ_ALG_MAX_SIZE
+#define SZ_ALG_MAX_SIZE 64
+#endif
+#ifndef SECPKGCONTEXT_CIPHERINFO_V1
+#define SECPKGCONTEXT_CIPHERINFO_V1 1
+#endif
+#ifndef SECPKG_ATTR_CIPHER_INFO
+#define SECPKG_ATTR_CIPHER_INFO  0x64
+#endif
+
+namespace {
+  using namespace aria2;
+
+  struct WinSecPkgContext_CipherInfo {
+      DWORD dwVersion;
+      DWORD dwProtocol;
+      DWORD dwCipherSuite;
+      DWORD dwBaseCipherSuite;
+      WCHAR szCipherSuite[SZ_ALG_MAX_SIZE];
+      WCHAR szCipher[SZ_ALG_MAX_SIZE];
+      DWORD dwCipherLen;
+      DWORD dwCipherBlockLen;    // in bytes
+      WCHAR szHash[SZ_ALG_MAX_SIZE];
+      DWORD dwHashLen;
+      WCHAR szExchange[SZ_ALG_MAX_SIZE];
+      DWORD dwMinExchangeLen;
+      DWORD dwMaxExchangeLen;
+      WCHAR szCertificate[SZ_ALG_MAX_SIZE];
+      DWORD dwKeyType;
+  };
+
+  static const ULONG kReqFlags = ISC_REQ_SEQUENCE_DETECT |
+                                 ISC_REQ_REPLAY_DETECT |
+                                 ISC_REQ_CONFIDENTIALITY |
+                                 ISC_REQ_ALLOCATE_MEMORY |
+                                 ISC_REQ_STREAM;
+  static const ULONG kReqAFlags = ASC_REQ_SEQUENCE_DETECT |
+                                  ASC_REQ_REPLAY_DETECT |
+                                  ASC_REQ_CONFIDENTIALITY |
+                                  ASC_REQ_EXTENDED_ERROR |
+                                  ASC_REQ_ALLOCATE_MEMORY |
+                                  ASC_REQ_STREAM;
+
+  class TLSBuffer : public ::SecBuffer {
+  public:
+    explicit TLSBuffer(ULONG type, ULONG size, void *data)
+    {
+      cbBuffer = size;
+      BufferType = type;
+      pvBuffer = data;
+    }
+  };
+
+  class TLSBufferDesc: public ::SecBufferDesc {
+  public:
+    explicit TLSBufferDesc(SecBuffer *arr, ULONG buffers)
+    {
+      ulVersion = SECBUFFER_VERSION;
+      cBuffers = buffers;
+      pBuffers = arr;
+    }
+  };
+
+  inline static std::string getCipherSuite(CtxtHandle *handle)
+  {
+    WinSecPkgContext_CipherInfo info = { SECPKGCONTEXT_CIPHERINFO_V1 };
+    if (QueryContextAttributes(handle, SECPKG_ATTR_CIPHER_INFO, &info) ==
+        SEC_E_OK) {
+      return wCharToUtf8(info.szCipherSuite);
+    }
+    return "Unknown";
+  }
+}
+
+namespace aria2 {
+
+TLSSession* TLSSession::make(TLSContext* ctx)
+{
+  return new WinTLSSession(static_cast<WinTLSContext*>(ctx));
+}
+
+WinTLSSession::WinTLSSession(WinTLSContext* ctx)
+  : sockfd_(0),
+    side_(ctx->getSide()),
+    cred_(ctx->getCredHandle()),
+    writeBuffered_(0),
+    state_(st_constructed),
+    status_(SEC_E_OK)
+{
+  memset(&handle_, 0, sizeof(handle_));
+}
+
+WinTLSSession::~WinTLSSession()
+{
+  ::DeleteSecurityContext(&handle_);
+  state_ = st_error;
+}
+
+int WinTLSSession::init(sock_t sockfd)
+{
+  if (state_ != st_constructed) {
+    status_ = SEC_E_INVALID_HANDLE;
+    return TLS_ERR_ERROR;
+  }
+  sockfd_ = sockfd;
+  state_ = st_initialized;
+
+  return TLS_ERR_OK;
+}
+
+int WinTLSSession::setSNIHostname(const std::string& hostname)
+{
+  if (state_ != st_initialized) {
+    status_ = SEC_E_INVALID_HANDLE;
+    return TLS_ERR_ERROR;
+  }
+  hostname_ = hostname;
+  return TLS_ERR_OK;
+}
+
+int WinTLSSession::closeConnection()
+{
+  if (state_ != st_connected || state_ != st_closing) {
+    return TLS_ERR_ERROR;
+  }
+
+  if (state_ == st_connected) {
+    state_ = st_closing;
+
+    DWORD dwShut = SCHANNEL_SHUTDOWN;
+    TLSBuffer shut(SECBUFFER_TOKEN, sizeof(dwShut), &dwShut);
+    TLSBufferDesc shutDesc(&shut, 1);
+    status_ = ::ApplyControlToken(&handle_, &shutDesc);
+    if (status_ != SEC_E_OK) {
+      state_ = st_error;
+      return TLS_ERR_ERROR;
+    }
+    TLSBuffer ctx(SECBUFFER_EMPTY, 0, nullptr);
+    TLSBufferDesc desc(&ctx, 1);
+    ULONG flags = 0;
+    if (side_ == TLS_CLIENT) {
+      SEC_CHAR* host = hostname_.empty() ?
+        nullptr :
+        const_cast<SEC_CHAR*>(hostname_.c_str());
+      status_ = ::InitializeSecurityContext(
+          cred_,
+          &handle_,
+          host,
+          kReqFlags,
+          0,
+          0,
+          nullptr,
+          0,
+          &handle_,
+          &desc,
+          &flags,
+          nullptr);
+    }
+    else {
+      status_ = ::AcceptSecurityContext(
+          cred_,
+          &handle_,
+          nullptr,
+          kReqAFlags,
+          0,
+          &handle_,
+          &desc,
+          &flags,
+          nullptr);
+    }
+    if (status_ == SEC_E_OK || status_== SEC_I_CONTEXT_EXPIRED) {
+      size_t len = ctx.cbBuffer;
+      ssize_t rv = writeData(ctx.pvBuffer, ctx.cbBuffer);
+      ::FreeContextBuffer(ctx.pvBuffer);
+      if (rv == TLS_ERR_WOULDBLOCK) {
+        return rv;
+      }
+
+      // Alright data is sent or buffered
+      if (rv - len != 0) {
+        return TLS_ERR_WOULDBLOCK;
+      }
+    }
+  }
+
+  // Send remaining data.
+  while (writeBuf_.size()) {
+    int rv = writeData(nullptr, 0);
+    if (rv == TLS_ERR_WOULDBLOCK) {
+      return rv;
+    }
+  }
+
+  state_ = st_closed;
+  return TLS_ERR_OK;
+}
+
+int WinTLSSession::checkDirection()
+{
+  if (state_ == st_handshake_write || state_  == st_handshake_write_last) {
+    return TLS_WANT_WRITE;
+  }
+  if (state_ == st_handshake_read) {
+    return TLS_WANT_READ;
+  }
+  if (readBuf_.size() || decBuf_.size()) {
+    return TLS_WANT_READ;
+  }
+  if (writeBuf_.size()) {
+    return TLS_WANT_WRITE;
+  }
+  return TLS_WANT_READ;
+}
+
+ssize_t WinTLSSession::writeData(const void* data, size_t len)
+{
+  if (state_ == st_handshake_write || state_ == st_handshake_write_last ||
+      state_ == st_handshake_read) {
+    // Renegotiating
+    std::string hn, err;
+    auto connect = tlsConnect(hn, err);
+    if (connect != TLS_ERR_OK) {
+      return connect;
+    }
+    // Continue.
+  }
+
+  if (state_ != st_connected && state_ != st_closing) {
+    status_ = SEC_E_INVALID_HANDLE;
+    return TLS_ERR_ERROR;
+  }
+
+  A2_LOG_DEBUG(fmt("WinTLS: Write request: %" PRIu64 " buffered: %" PRIu64,
+                   (uint64_t)len, (uint64_t)writeBuf_.size()));
+
+  // Write remaining buffered data, if any.
+  size_t written = 0;
+  while (writeBuf_.size()) {
+    written = ::send(sockfd_, writeBuf_.data(), writeBuf_.size(), 0);
+    errno = ::WSAGetLastError();
+    if (written < 0 && errno == WSAEINTR) {
+      continue;
+    }
+    if (written < 0 && errno == WSAEWOULDBLOCK) {
+      return TLS_ERR_WOULDBLOCK;
+    }
+    if (written == 0) {
+      return written;
+    }
+    if (written < 0) {
+      status_ = SEC_E_INVALID_HANDLE;
+      state_ = st_error;
+      return TLS_ERR_ERROR;
+    }
+    writeBuf_.eat(written);
+  }
+
+  if (len == 0) {
+    return 0;
+  }
+
+  if (!streamSizes_) {
+    streamSizes_.reset(new SecPkgContext_StreamSizes());
+    status_ = ::QueryContextAttributes(&handle_, SECPKG_ATTR_STREAM_SIZES,
+                                       streamSizes_.get());
+    if (status_ != SEC_E_OK || !streamSizes_->cbMaximumMessage) {
+      state_ = st_error;
+      return TLS_ERR_ERROR;
+    }
+  }
+
+  size_t process = len;
+  auto bytes = reinterpret_cast<const char*>(data);
+  if (writeBuffered_) {
+    // There was buffered data, hence we need to "remove" that data from the
+    // incoming buffer to avoid writing it again
+    if (len < writeBuffered_) {
+      // We didn't get called with the same data again, obviously.
+      status_ = SEC_E_INVALID_HANDLE;
+      status_ = st_error;
+      return TLS_ERR_ERROR;
+    }
+    // just advance the buffer by writeBuffered_ bytes
+    bytes += writeBuffered_;
+    process -= writeBuffered_;
+    writeBuffered_ = 0;
+  }
+  if (!process) {
+    // The buffer contained the full remainder. At this point, the buffer has
+    // been written, so the request is done in its entirety;
+    return len;
+  }
+
+  // Buffered data was already written ;)
+  // If there was no buffered data, this will be len - len = 0.
+  len = len - process;
+  while (process) {
+    // Set up an outgoing message, according to streamSizes_
+    writeBuffered_ = std::min(process, (size_t)streamSizes_->cbMaximumMessage);
+    size_t dl = streamSizes_->cbHeader + writeBuffered_ +
+                streamSizes_->cbTrailer;
+    auto buf = make_unique<char[]>(dl);
+    TLSBuffer buffers[] = {
+      TLSBuffer(SECBUFFER_STREAM_HEADER, streamSizes_->cbHeader, buf.get()),
+      TLSBuffer(SECBUFFER_DATA, writeBuffered_,
+                buf.get() + streamSizes_->cbHeader),
+      TLSBuffer(SECBUFFER_STREAM_TRAILER, streamSizes_->cbTrailer,
+                buf.get() + streamSizes_->cbHeader + writeBuffered_),
+      TLSBuffer(SECBUFFER_EMPTY, 0, nullptr),
+    };
+    TLSBufferDesc desc(buffers, 4);
+    memcpy(buffers[1].pvBuffer, bytes, writeBuffered_);
+    status_ = ::EncryptMessage(&handle_, 0, &desc, 0);
+    if (status_ != SEC_E_OK) {
+      A2_LOG_ERROR(fmt("WinTLS: Failed to encrypt a message! %s",
+                       getLastErrorString().c_str()));
+      state_ = st_error;
+      return TLS_ERR_ERROR;
+    }
+
+    // EncryptMessage may have truncated the buffers.
+    // Should rarely happen, if ever, except for the trailer.
+    dl = buffers[0].cbBuffer;
+    if (dl < streamSizes_->cbHeader) {
+      // Move message.
+      memmove(buf.get() + dl, buffers[1].pvBuffer, buffers[1].cbBuffer);
+    }
+    dl += buffers[1].cbBuffer;
+    if (dl < streamSizes_->cbHeader + writeBuffered_) {
+      // Move tailer.
+      memmove(buf.get() + dl, buffers[2].pvBuffer, buffers[2].cbBuffer);
+    }
+    dl += buffers[2].cbBuffer;
+
+    // Write (or buffer) the message.
+    char* p = buf.get();
+    while (dl) {
+      written = ::send(sockfd_, p, dl, 0);
+      errno = ::WSAGetLastError();
+      if (written < 0 && errno == WSAEINTR) {
+        continue;
+      }
+      if (written < 0 && errno == WSAEWOULDBLOCK) {
+        // Buffer the rest of the message...
+        writeBuf_.write(p, dl);
+        // and return...
+        return len;
+      }
+      if (written == 0) {
+        A2_LOG_ERROR("WinTLS: Connection closed while writing");
+        status_ = SEC_E_INCOMPLETE_MESSAGE;
+        state_ = st_error;
+        return TLS_ERR_ERROR;
+      }
+      if (written < 0) {
+        A2_LOG_ERROR("WinTLS: Connection error while writing");
+        status_ = SEC_E_INCOMPLETE_MESSAGE;
+        state_ = st_error;
+        return TLS_ERR_ERROR;
+      }
+      dl -= written;
+      p += written;
+    }
+
+    len += writeBuffered_;
+    bytes += writeBuffered_;
+    process -= writeBuffered_;
+    writeBuffered_ = 0;
+  }
+
+  A2_LOG_DEBUG(fmt("WinTLS: Write result: %" PRIu64 " buffered: %" PRIu64,
+                   (uint64_t)len, (uint64_t)writeBuf_.size()));
+  if (!len) {
+    return TLS_ERR_WOULDBLOCK;
+  }
+  return len;
+}
+
+ssize_t WinTLSSession::readData(void* data, size_t len)
+{
+  A2_LOG_DEBUG(fmt("WinTLS: Read request: %" PRIu64 " buffered: %" PRIu64,
+                   (uint64_t)len, (uint64_t)readBuf_.size()));
+  if (len == 0) {
+    return 0;
+  }
+
+  // Can be filled from decBuffer entirely?
+  if (decBuf_.size() >= len) {
+    A2_LOG_DEBUG("WinTLS: Fullfilling req from buffer");
+    memcpy(data, decBuf_.data(), len);
+    decBuf_.eat(len);
+    return len;
+  }
+
+  if (state_ == st_handshake_write || state_ == st_handshake_write_last ||
+      state_ == st_handshake_read) {
+    // Renegotiating
+    std::string hn, err;
+    auto connect = tlsConnect(hn, err);
+    if (connect != TLS_ERR_OK) {
+      return connect;
+    }
+    // Continue.
+  }
+  if (state_ != st_connected) {
+    status_ = SEC_E_INVALID_HANDLE;
+    return TLS_ERR_ERROR;
+  }
+
+  // Read as many bytes as available from the connection, up to len + 4k.
+  readBuf_.resize(len + 4096);
+  while (readBuf_.free()) {
+    ssize_t read = ::recv(sockfd_, readBuf_.end(), readBuf_.free(), 0);
+    errno = ::WSAGetLastError();
+    if (read < 0 && errno == WSAEINTR) {
+      continue;
+    }
+    if (read < 0 && errno == WSAEWOULDBLOCK) {
+      break;
+    }
+    if (read == 0) {
+      break;
+    }
+    if (read < 0) {
+      status_ = SEC_E_INCOMPLETE_MESSAGE;
+      state_ = st_error;
+      return TLS_ERR_ERROR;
+    }
+    readBuf_.advance(read);
+  }
+
+  // Try to decrypt as many messages as possible from the readBuf_.
+  while (readBuf_.size()) {
+    TLSBuffer bufs[] = {
+      TLSBuffer(SECBUFFER_DATA, readBuf_.size(), readBuf_.data()),
+      TLSBuffer(SECBUFFER_EMPTY, 0, nullptr),
+      TLSBuffer(SECBUFFER_EMPTY, 0, nullptr),
+      TLSBuffer(SECBUFFER_EMPTY, 0, nullptr),
+    };
+    TLSBufferDesc desc(bufs, 4);
+    status_ = ::DecryptMessage(&handle_, &desc, 0, nullptr);
+    if (status_ == SEC_E_INCOMPLETE_MESSAGE) {
+      // Need to stop now, and wait for more bytes to arrive on the socket.
+      break;
+    }
+
+    if (status_ != SEC_E_OK && status_ != SEC_I_CONTEXT_EXPIRED &&
+        status_ != SEC_I_RENEGOTIATE) {
+      A2_LOG_ERROR(fmt("WinTLS: Failed to decrypt a message! %s",
+                       getLastErrorString().c_str()));
+      state_ = st_error;
+      return TLS_ERR_ERROR;
+    }
+
+    // Decrypted message successfully.
+    bool ate = false;
+    for (auto& buf : bufs) {
+      if (buf.BufferType == SECBUFFER_DATA && buf.cbBuffer > 0) {
+        decBuf_.write(buf.pvBuffer, buf.cbBuffer);
+      }
+      else if (buf.BufferType == SECBUFFER_EXTRA && buf.cbBuffer > 0) {
+        readBuf_.eat(readBuf_.size() - buf.cbBuffer);
+        ate = true;
+      }
+    }
+    if (!ate) {
+      readBuf_.clear();
+    }
+
+    if (status_ == SEC_I_RENEGOTIATE) {
+      // Renegotiation basically means performing another handshake
+      state_ = st_initialized;
+      A2_LOG_INFO("WinTLS: Renegotiate");
+      std::string hn, err;
+      auto connect = tlsConnect(hn, err);
+      if (connect == TLS_ERR_WOULDBLOCK) {
+        break;
+      }
+      if (connect == TLS_ERR_ERROR) {
+        return connect;
+      }
+      // Still good.
+    }
+    if (status_ == SEC_I_CONTEXT_EXPIRED) {
+      // Connection is gone now, but the buffered bytes are still valid.
+      A2_LOG_DEBUG("WinTLS: Connection closed!");
+      closeConnection();
+      break;
+    }
+  }
+
+  len = std::min(decBuf_.size(), len);
+  if (len == 0) {
+    return TLS_ERR_WOULDBLOCK;
+  }
+  memcpy(data, decBuf_.data(), len);
+  decBuf_.eat(len);
+  return len;
+}
+
+int WinTLSSession::tlsConnect(const std::string& hostname,
+                              std::string& handshakeErr)
+{
+  // Handshaking will require sending multiple read/write exchanges until the
+  // handshake is actually done. The client will first generate the initial
+  // handshake message, then write that to the server, read the response
+  // message, and write and/or read additional messages until the handshake is
+  // either complete and successful, or something went wrong.
+  // The server works analog to that.
+
+  A2_LOG_DEBUG("WinTLS: Starting/Resuming TLS Connect");
+  ULONG flags = 0;
+
+restart:
+
+  switch (state_) {
+    default:
+      A2_LOG_ERROR("WinTLS: Invalid state");
+      status_ = SEC_E_INVALID_HANDLE;
+      return TLS_ERR_ERROR;
+
+    case st_initialized: {
+      if (side_ == TLS_SERVER) {
+        goto read;
+      }
+
+      if (!hostname.empty()) {
+        setSNIHostname(hostname);
+      }
+      A2_LOG_DEBUG("WinTLS: Initializing handshake");
+      TLSBuffer buf(SECBUFFER_EMPTY, 0, nullptr);
+      TLSBufferDesc desc(&buf, 1);
+      SEC_CHAR* host = hostname_.empty() ?
+        nullptr :
+        const_cast<SEC_CHAR*>(hostname_.c_str());
+      status_ = ::InitializeSecurityContext(
+          cred_,
+          nullptr,
+          host,
+          kReqFlags,
+          0,
+          0,
+          nullptr,
+          0,
+          &handle_,
+          &desc,
+          &flags,
+          nullptr);
+      if (status_ != SEC_I_CONTINUE_NEEDED) {
+        // Has to be SEC_I_CONTINUE_NEEDED, as we did not actually send data
+        // at this point.
+        state_ = st_error;
+        return TLS_ERR_ERROR;
+      }
+
+      // Queue the initial message...
+      writeBuf_.write(buf.pvBuffer, buf.cbBuffer);
+      FreeContextBuffer(buf.pvBuffer);
+
+      // ... and start sending it
+      state_ = st_handshake_write;
+    }
+    // Fall through
+
+    case st_handshake_write_last:
+    case st_handshake_write: {
+      A2_LOG_DEBUG("WinTLS: Writing handshake");
+
+      // Write the currently queued handshake message until all data is sent.
+      while(writeBuf_.size()) {
+        ssize_t writ = ::send(sockfd_, writeBuf_.data(), writeBuf_.size(), 0);
+        errno = ::WSAGetLastError();
+        if (writ < 0 && errno == WSAEINTR) {
+          continue;
+        }
+        if (writ < 0 && errno == WSAEWOULDBLOCK) {
+          return TLS_ERR_WOULDBLOCK;
+        }
+        if (writ <= 0) {
+          status_ = SEC_E_INCOMPLETE_MESSAGE;
+          state_ = st_error;
+          return TLS_ERR_ERROR;
+        }
+        writeBuf_.eat(writ);
+      }
+
+      if (state_ == st_handshake_write_last) {
+        state_ = st_handshake_done;
+        goto restart;
+      }
+
+      // Have to read one or more response messages.
+      state_ = st_handshake_read;
+    }
+    // Fall through
+
+    case st_handshake_read: {
+read:
+      A2_LOG_DEBUG("WinTLS: Reading handshake...");
+
+      // All write buffered data is invalid at this point!
+      writeBuf_.clear();
+
+      // Read as many bytes as possible, up to 4k new bytes.
+      // We do not know how many bytes will arrive from the server at this
+      // point.
+      readBuf_.resize(readBuf_.size() + 4096);
+      while (readBuf_.free()) {
+        ssize_t read = ::recv(sockfd_, readBuf_.end(), readBuf_.free(), 0);
+        errno = ::WSAGetLastError();
+        if (read < 0 && errno == WSAEINTR) {
+          continue;
+        }
+        if (read < 0 && errno == WSAEWOULDBLOCK) {
+          break;
+        }
+        if (read <= 0) {
+          status_ = SEC_E_INCOMPLETE_MESSAGE;
+          state_ = st_error;
+          return TLS_ERR_ERROR;
+        }
+        readBuf_.advance(read);
+        break;
+      }
+      if (!readBuf_.size()) {
+        return TLS_ERR_WOULDBLOCK;
+      }
+
+      // Need to copy the data, as Schannel is free to mess with it. But we
+      // might later need unmodified data from the original read buffer.
+      auto bufcopy = make_unique<char[]>(readBuf_.size());
+      memcpy(bufcopy.get(), readBuf_.data(), readBuf_.size());
+
+      // Set up buffers. inbufs will be the raw bytes the library has to decode.
+      // outbufs will contain generated responses, if any.
+      TLSBuffer inbufs[] = {
+        TLSBuffer(SECBUFFER_TOKEN, readBuf_.size(), bufcopy.get()),
+        TLSBuffer(SECBUFFER_EMPTY, 0, nullptr),
+      };
+      TLSBufferDesc indesc(inbufs, 2);
+      TLSBuffer outbufs[] = {
+        TLSBuffer(SECBUFFER_TOKEN, 0, nullptr),
+        TLSBuffer(SECBUFFER_ALERT, 0, nullptr),
+      };
+      TLSBufferDesc outdesc(outbufs, 2);
+      if (side_ == TLS_CLIENT) {
+        SEC_CHAR* host = hostname_.empty() ?
+          nullptr :
+          const_cast<SEC_CHAR*>(hostname_.c_str());
+        status_ = ::InitializeSecurityContext(
+            cred_,
+            &handle_,
+            host,
+            kReqFlags,
+            0,
+            0,
+            &indesc,
+            0,
+            nullptr,
+            &outdesc,
+            &flags,
+            nullptr);
+      }
+      else {
+        status_ = ::AcceptSecurityContext(
+            cred_,
+            state_ == st_initialized ? nullptr : &handle_,
+            &indesc,
+            kReqAFlags,
+            0,
+            state_ == st_initialized ? &handle_ : nullptr,
+            &outdesc,
+            &flags,
+            nullptr);
+      }
+      if (status_ == SEC_E_INCOMPLETE_MESSAGE) {
+        // Not enough raw bytes read yet to decode a full message.
+        return TLS_ERR_WOULDBLOCK;
+      }
+      if (status_ != SEC_E_OK && status_ != SEC_I_CONTINUE_NEEDED) {
+        state_ = st_error;
+        return TLS_ERR_ERROR;
+      }
+
+      // Raw bytes where not entirely consumed, i.e. readBuf_ still contains
+      // unprocessed data from the next message?
+      if (inbufs[1].BufferType == SECBUFFER_EXTRA && inbufs[1].cbBuffer > 0) {
+        readBuf_.eat(readBuf_.size() - inbufs[1].cbBuffer);
+      }
+      else {
+        readBuf_.clear();
+      }
+
+      // Check if the library produced a new outgoing message and queue it.
+      for (auto& buf : outbufs) {
+        if (buf.BufferType == SECBUFFER_TOKEN && buf.cbBuffer > 0) {
+          writeBuf_.write(buf.pvBuffer, buf.cbBuffer);
+          FreeContextBuffer(buf.pvBuffer);
+          state_ = st_handshake_write;
+        }
+      }
+
+      // Need to read additional messages?
+      if (status_ == SEC_I_CONTINUE_NEEDED) {
+        A2_LOG_DEBUG("WinTLS: Continuing with handshake");
+        goto restart;
+      }
+
+      if (side_ == TLS_CLIENT && flags != kReqFlags) {
+        A2_LOG_ERROR(fmt("WinTLS: Channel setup failed. Schannel provider did "
+                         "not fulfill requested flags. "
+                         "Excepted: %lu Actual: %lu",
+                         kReqFlags, flags));
+        status_ = SEC_E_INTERNAL_ERROR;
+        state_ = st_error;
+        return TLS_ERR_ERROR;
+      }
+
+      if (state_ == st_handshake_write) {
+        A2_LOG_DEBUG("WinTLS: Continuing with handshake (last write)");
+        state_ = st_handshake_write_last;
+        goto restart;
+      }
+    }
+    // Fall through
+
+    case st_handshake_done:
+      // All ready now :D
+      state_ = st_connected;
+      A2_LOG_INFO(fmt("WinTLS: connected with: %s",
+                      getCipherSuite(&handle_).c_str()));
+      return TLS_ERR_OK;
+  }
+
+  A2_LOG_ERROR("WinTLS: Unreachable reached during tlsConnect! This is a bug!");
+  state_ = st_error;
+  return TLS_ERR_ERROR;
+}
+
+int WinTLSSession::tlsAccept()
+{
+  std::string host, err;
+  return tlsConnect(host, err);
+}
+
+std::string WinTLSSession::getLastErrorString()
+{
+  std::stringstream ss;
+  wchar_t* buf = nullptr;
+  if (FormatMessageW(FORMAT_MESSAGE_ALLOCATE_BUFFER |
+                    FORMAT_MESSAGE_FROM_SYSTEM |
+                    FORMAT_MESSAGE_IGNORE_INSERTS,
+                    nullptr,
+                    status_,
+                    MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),
+                    (LPWSTR)&buf,
+                    1024,
+                    nullptr) && buf) {
+    ss << "Error: " << wCharToUtf8(buf);
+    LocalFree(buf);
+  }
+  else {
+    ss << "Error: " << std::hex << status_;
+  }
+  return ss.str();
+}
+
+} // namespace aria2

+ 194 - 0
src/WinTLSSession.h

@@ -0,0 +1,194 @@
+/* <!-- copyright */
+/*
+ * aria2 - The high speed download utility
+ *
+ * Copyright (C) 2013 Nils Maier
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 2 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ *
+ * In addition, as a special exception, the copyright holders give
+ * permission to link the code of portions of this program with the
+ * OpenSSL library under certain conditions as described in each
+ * individual source file, and distribute linked combinations
+ * including the two.
+ * You must obey the GNU General Public License in all respects
+ * for all of the code used other than OpenSSL.  If you modify
+ * file(s) with this exception, you may extend this exception to your
+ * version of the file(s), but you are not obligated to do so.  If you
+ * do not wish to do so, delete this exception statement from your
+ * version.  If you delete this exception statement from all source
+ * files in the program, then also delete it here.
+ */
+/* copyright --> */
+
+#ifndef WIN_TLS_SESSION_H
+#define WIN_TLS_SESSION_H
+
+#include <vector>
+
+#include "common.h"
+#include "TLSSession.h"
+#include "WinTLSContext.h"
+
+namespace aria2 {
+
+namespace wintls {
+  struct Buffer {
+  private:
+    size_t off_, free_, cap_;
+    std::vector<char> buf_;
+
+  public:
+    inline Buffer() : off_(0), free_(0), cap_(0) {}
+
+    inline size_t size() const {
+      return off_;
+    }
+    inline size_t free() const {
+      return free_;
+    }
+    inline void resize(size_t len) {
+      if (cap_ >= len) {
+        return;
+      }
+      buf_.resize(len);
+      cap_ = buf_.size();
+      free_ = cap_ - off_;
+    }
+    inline char* data() {
+      return buf_.data();
+    }
+    inline char* end() {
+      return buf_.data() + off_;
+    }
+    inline void eat(size_t len) {
+      off_ -= len;
+      if (off_) {
+        memmove(buf_.data(), buf_.data() + len, off_);
+      }
+      free_ = cap_ - off_;
+    }
+    inline void clear() {
+      eat(off_);
+    }
+    inline void advance(size_t len) {
+      off_ += len;
+      free_ = cap_ - off_;
+    }
+    inline void write(const void* data, size_t len) {
+      if (!len) {
+        return;
+      }
+      resize(off_ + len);
+      memcpy(end(), data, len);
+      advance(len);
+    }
+  };
+} // namespace wintls
+
+class WinTLSSession : public TLSSession {
+  enum state_t {
+    st_constructed,
+    st_initialized,
+    st_handshake_write,
+    st_handshake_write_last,
+    st_handshake_read,
+    st_handshake_done,
+    st_connected,
+    st_closing,
+    st_closed,
+    st_error
+  };
+
+public:
+  WinTLSSession(WinTLSContext* ctx);
+
+  // MUST deallocate all resources
+  virtual ~WinTLSSession();
+
+  // Initializes SSL/TLS session. The |sockfd| is the underlying
+  // tranport socket. This function returns TLS_ERR_OK if it
+  // succeeds, or TLS_ERR_ERROR.
+  virtual int init(sock_t sockfd) CXX11_OVERRIDE;
+
+  // Sets |hostname| for TLS SNI extension. This is only meaningful for
+  // client side session. This function returns TLS_ERR_OK if it
+  // succeeds, or TLS_ERR_ERROR.
+  virtual int setSNIHostname(const std::string& hostname) CXX11_OVERRIDE;
+
+  // Closes the SSL/TLS session. Don't close underlying transport
+  // socket. This function returns TLS_ERR_OK if it succeeds, or
+  // TLS_ERR_ERROR.
+  virtual int closeConnection() CXX11_OVERRIDE;
+
+  // Returns TLS_WANT_READ if SSL/TLS session needs more data from
+  // remote endpoint to proceed, or TLS_WANT_WRITE if SSL/TLS session
+  // needs to write more data to proceed. If SSL/TLS session needs
+  // neither read nor write data at the moment, return value is
+  // undefined.
+  virtual int checkDirection() CXX11_OVERRIDE;
+
+  // Sends |data| with length |len|. This function returns the number
+  // of bytes sent if it succeeds, or TLS_ERR_WOULDBLOCK if the
+  // underlying tranport blocks, or TLS_ERR_ERROR.
+  virtual ssize_t writeData(const void* data, size_t len) CXX11_OVERRIDE;
+
+  // Receives data into |data| with length |len|. This function returns
+  // the number of bytes received if it succeeds, or TLS_ERR_WOULDBLOCK
+  // if the underlying tranport blocks, or TLS_ERR_ERROR.
+  virtual ssize_t readData(void* data, size_t len) CXX11_OVERRIDE;
+
+  // Performs client side handshake. The |hostname| is the hostname of
+  // the remote endpoint and is used to verify its certificate. This
+  // function returns TLS_ERR_OK if it succeeds, or TLS_ERR_WOULDBLOCK
+  // if the underlying transport blocks, or TLS_ERR_ERROR.
+  // When returning TLS_ERR_ERROR, provide certificate validation error
+  // in |handshakeErr|.
+  virtual int tlsConnect(const std::string& hostname, std::string& handshakeErr) CXX11_OVERRIDE;
+
+  // Performs server side handshake. This function returns TLS_ERR_OK
+  // if it succeeds, or TLS_ERR_WOULDBLOCK if the underlying transport
+  // blocks, or TLS_ERR_ERROR.
+  virtual int tlsAccept() CXX11_OVERRIDE;
+
+  // Returns last error string
+  virtual std::string getLastErrorString() CXX11_OVERRIDE;
+
+private:
+  std::string hostname_;
+  sock_t sockfd_;
+  TLSSessionSide side_;
+  CredHandle* cred_;
+  CtxtHandle handle_;
+
+  // Buffer for already encrypted writes
+  wintls::Buffer writeBuf_;
+  // While the writeBuf_ holds encrypted messages, writeBuffered_ has the
+  // corresponding size of unencrpted data used to procude the messages.
+  size_t writeBuffered_;
+  // Buffer for still encrypted reads
+  wintls::Buffer readBuf_;
+  // Buffer for already decrypted reads
+  wintls::Buffer decBuf_;
+
+  state_t state_;
+
+  SECURITY_STATUS status_;
+  std::unique_ptr<SecPkgContext_StreamSizes> streamSizes_;
+};
+
+} // namespace aria2
+
+#endif // TLS_SESSION_H