Просмотр исходного кода

Abstract TLS session implementation

Now TLS session object is abstracted as TLSSession class. Currently,
we have GNUTLS and OpenSSL implementations.
Tatsuhiro Tsujikawa 12 лет назад
Родитель
Сommit
8580c98bce
9 измененных файлов с 944 добавлено и 437 удалено
  1. 266 0
      src/LibgnutlsTLSSession.cc
  2. 73 0
      src/LibgnutlsTLSSession.h
  3. 299 0
      src/LibsslTLSSession.cc
  4. 74 0
      src/LibsslTLSSession.h
  5. 6 2
      src/Makefile.am
  6. 59 406
      src/SocketCore.cc
  7. 8 29
      src/SocketCore.h
  8. 104 0
      src/TLSSession.h
  9. 55 0
      src/TLSSessionConst.h

+ 266 - 0
src/LibgnutlsTLSSession.cc

@@ -0,0 +1,266 @@
+/* <!-- copyright */
+/*
+ * aria2 - The high speed download utility
+ *
+ * Copyright (C) 2013 Tatsuhiro Tsujikawa
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 2 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ *
+ * In addition, as a special exception, the copyright holders give
+ * permission to link the code of portions of this program with the
+ * OpenSSL library under certain conditions as described in each
+ * individual source file, and distribute linked combinations
+ * including the two.
+ * You must obey the GNU General Public License in all respects
+ * for all of the code used other than OpenSSL.  If you modify
+ * file(s) with this exception, you may extend this exception to your
+ * version of the file(s), but you are not obligated to do so.  If you
+ * do not wish to do so, delete this exception statement from your
+ * version.  If you delete this exception statement from all source
+ * files in the program, then also delete it here.
+ */
+/* copyright --> */
+#include "LibgnutlsTLSSession.h"
+
+#include <gnutls/x509.h>
+
+#include "TLSContext.h"
+#include "util.h"
+#include "SocketCore.h"
+
+namespace aria2 {
+
+TLSSession::TLSSession(TLSContext* tlsContext)
+  : sslSession_(0),
+    tlsContext_(tlsContext),
+    rv_(0)
+{}
+
+TLSSession::~TLSSession()
+{
+  if(sslSession_) {
+    gnutls_deinit(sslSession_);
+  }
+}
+
+int TLSSession::init(sock_t sockfd)
+{
+  rv_ = gnutls_init(&sslSession_,
+                    tlsContext_->getSide() == TLS_CLIENT ?
+                    GNUTLS_CLIENT : GNUTLS_SERVER);
+  if(rv_ != GNUTLS_E_SUCCESS) {
+    return TLS_ERR_ERROR;
+  }
+  // It seems err is not error message, but the argument string
+  // which causes syntax error.
+  const char* err;
+  // For client side, disables TLS1.1 here because there are servers
+  // that don't understand TLS1.1.  TODO Is this still necessary?
+  rv_ = gnutls_priority_set_direct(sslSession_,
+                                   tlsContext_->getSide() == TLS_CLIENT ?
+                                   "NORMAL:-VERS-TLS1.1" :
+                                   "NORMAL",
+                                   &err);
+  if(rv_ != GNUTLS_E_SUCCESS) {
+    return TLS_ERR_ERROR;
+  }
+  // put the x509 credentials to the current session
+  rv_ = gnutls_credentials_set(sslSession_, GNUTLS_CRD_CERTIFICATE,
+                               tlsContext_->getCertCred());
+  if(rv_ != GNUTLS_E_SUCCESS) {
+    return TLS_ERR_ERROR;
+  }
+  // TODO Consider to use gnutls_transport_set_int() for GNUTLS 3.1.9
+  // or later
+  gnutls_transport_set_ptr(sslSession_,
+                           (gnutls_transport_ptr_t)(ptrdiff_t)sockfd);
+  return TLS_ERR_OK;
+}
+
+int TLSSession::setSNIHostname(const std::string& hostname)
+{
+  // TLS extensions: SNI
+  rv_ = gnutls_server_name_set(sslSession_, GNUTLS_NAME_DNS,
+                               hostname.c_str(), hostname.size());
+  if(rv_ != GNUTLS_E_SUCCESS) {
+    return TLS_ERR_ERROR;
+  }
+  return TLS_ERR_OK;
+}
+
+int TLSSession::closeConnection()
+{
+  rv_ = gnutls_bye(sslSession_, GNUTLS_SHUT_WR);
+  if(rv_ == GNUTLS_E_SUCCESS) {
+    return TLS_ERR_OK;
+  } else if(rv_ == GNUTLS_E_AGAIN) {
+    return TLS_ERR_WOULDBLOCK;
+  } else {
+    return TLS_ERR_ERROR;
+  }
+}
+
+int TLSSession::checkDirection()
+{
+  int direction = gnutls_record_get_direction(sslSession_);
+  return direction == 0 ? TLS_WANT_READ : TLS_WANT_WRITE;
+}
+
+ssize_t TLSSession::writeData(const void* data, size_t len)
+{
+  while((rv_ = gnutls_record_send(sslSession_, data, len)) ==
+        GNUTLS_E_INTERRUPTED);
+  if(rv_ >= 0) {
+    ssize_t ret = rv_;
+    rv_ = 0;
+    return ret;
+  } else if(rv_ == GNUTLS_E_AGAIN) {
+    return TLS_ERR_WOULDBLOCK;
+  } else {
+    return TLS_ERR_ERROR;
+  }
+}
+
+ssize_t TLSSession::readData(void* data, size_t len)
+{
+  while((rv_ = gnutls_record_recv(sslSession_, data, len)) ==
+        GNUTLS_E_INTERRUPTED);
+  if(rv_ >= 0) {
+    ssize_t ret = rv_;
+    rv_ = 0;
+    return ret;
+  } else if(rv_ == GNUTLS_E_AGAIN) {
+    return TLS_ERR_WOULDBLOCK;
+  } else {
+    return TLS_ERR_ERROR;
+  }
+}
+
+int TLSSession::tlsConnect(const std::string& hostname,
+                           std::string& handshakeErr)
+{
+  handshakeErr = "";
+  rv_ = gnutls_handshake(sslSession_);
+  if(rv_ < 0) {
+    if(rv_ == GNUTLS_E_AGAIN) {
+      return TLS_ERR_WOULDBLOCK;
+    } else {
+      return TLS_ERR_ERROR;
+    }
+  }
+  if(tlsContext_->peerVerificationEnabled()) {
+    // verify peer
+    unsigned int status;
+    rv_ = gnutls_certificate_verify_peers2(sslSession_, &status);
+    if(rv_ != GNUTLS_E_SUCCESS) {
+      return TLS_ERR_ERROR;
+    }
+    if(status) {
+      handshakeErr = "";
+      if(status & GNUTLS_CERT_INVALID) {
+        handshakeErr += " `not signed by known authorities or invalid'";
+      }
+      if(status & GNUTLS_CERT_REVOKED) {
+        handshakeErr += " `revoked by its CA'";
+      }
+      if(status & GNUTLS_CERT_SIGNER_NOT_FOUND) {
+        handshakeErr += " `issuer is not known'";
+      }
+      // TODO should check GNUTLS_CERT_SIGNER_NOT_CA ?
+      if(status & GNUTLS_CERT_INSECURE_ALGORITHM) {
+        handshakeErr += " `insecure algorithm'";
+      }
+      if(status & GNUTLS_CERT_NOT_ACTIVATED) {
+        handshakeErr += " `not activated yet'";
+      }
+      if(status & GNUTLS_CERT_EXPIRED) {
+        handshakeErr += " `expired'";
+      }
+      // TODO Add GNUTLS_CERT_SIGNATURE_FAILURE here
+      if(!handshakeErr.empty()) {
+        return TLS_ERR_ERROR;
+      }
+    }
+    // certificate type: only X509 is allowed.
+    if(gnutls_certificate_type_get(sslSession_) != GNUTLS_CRT_X509) {
+      handshakeErr = "certificate type must be X509";
+      return TLS_ERR_ERROR;
+    }
+    unsigned int peerCertsLength;
+    const gnutls_datum_t* peerCerts;
+    peerCerts = gnutls_certificate_get_peers(sslSession_, &peerCertsLength);
+    if(!peerCerts || peerCertsLength == 0 ) {
+      handshakeErr = "certificate not found";
+      return TLS_ERR_ERROR;
+    }
+    gnutls_x509_crt_t cert;
+    rv_ = gnutls_x509_crt_init(&cert);
+    if(rv_ != GNUTLS_E_SUCCESS) {
+      return TLS_ERR_ERROR;
+    }
+    auto_delete<gnutls_x509_crt_t> certDeleter(cert, gnutls_x509_crt_deinit);
+    rv_ = gnutls_x509_crt_import(cert, &peerCerts[0], GNUTLS_X509_FMT_DER);
+    if(rv_ != GNUTLS_E_SUCCESS) {
+      return TLS_ERR_ERROR;
+    }
+    std::string commonName;
+    std::vector<std::string> dnsNames;
+    std::vector<std::string> ipAddrs;
+    int ret = 0;
+    char altName[256];
+    size_t altNameLen;
+    for(int i = 0; !(ret < 0); ++i) {
+      altNameLen = sizeof(altName);
+      ret = gnutls_x509_crt_get_subject_alt_name(cert, i, altName,
+                                                 &altNameLen, 0);
+      if(ret == GNUTLS_SAN_DNSNAME) {
+        dnsNames.push_back(std::string(altName, altNameLen));
+      } else if(ret == GNUTLS_SAN_IPADDRESS) {
+        ipAddrs.push_back(std::string(altName, altNameLen));
+      }
+    }
+    altNameLen = sizeof(altName);
+    ret = gnutls_x509_crt_get_dn_by_oid(cert,
+                                        GNUTLS_OID_X520_COMMON_NAME, 0, 0,
+                                        altName, &altNameLen);
+    if(ret == 0) {
+      commonName.assign(altName, altNameLen);
+    }
+    if(!net::verifyHostname(hostname, dnsNames, ipAddrs, commonName)) {
+      handshakeErr = "hostname does not match";
+      return TLS_ERR_ERROR;
+    }
+  }
+  return TLS_ERR_OK;
+}
+
+int TLSSession::tlsAccept()
+{
+  rv_ = gnutls_handshake(sslSession_);
+  if(rv_ == GNUTLS_E_SUCCESS) {
+    return TLS_ERR_OK;
+  } else if(rv_ == GNUTLS_E_AGAIN) {
+    return TLS_ERR_WOULDBLOCK;
+  } else {
+    return TLS_ERR_ERROR;
+  }
+}
+
+std::string TLSSession::getLastErrorString()
+{
+  return gnutls_strerror(rv_);
+}
+
+} // namespace aria2

+ 73 - 0
src/LibgnutlsTLSSession.h

@@ -0,0 +1,73 @@
+/* <!-- copyright */
+/*
+ * aria2 - The high speed download utility
+ *
+ * Copyright (C) 2013 Tatsuhiro Tsujikawa
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 2 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ *
+ * In addition, as a special exception, the copyright holders give
+ * permission to link the code of portions of this program with the
+ * OpenSSL library under certain conditions as described in each
+ * individual source file, and distribute linked combinations
+ * including the two.
+ * You must obey the GNU General Public License in all respects
+ * for all of the code used other than OpenSSL.  If you modify
+ * file(s) with this exception, you may extend this exception to your
+ * version of the file(s), but you are not obligated to do so.  If you
+ * do not wish to do so, delete this exception statement from your
+ * version.  If you delete this exception statement from all source
+ * files in the program, then also delete it here.
+ */
+/* copyright --> */
+#ifndef LIBGNUTLS_TLS_SESSION_H
+#define LIBGNUTLS_TLS_SESSION_H
+
+#include "common.h"
+
+#include <gnutls/gnutls.h>
+
+#include <string>
+
+#include "TLSSessionConst.h"
+#include "a2netcompat.h"
+
+namespace aria2 {
+
+class TLSContext;
+
+class TLSSession {
+public:
+  TLSSession(TLSContext* tlsContext);
+  ~TLSSession();
+  int init(sock_t sockfd);
+  int setSNIHostname(const std::string& hostname);
+  int closeConnection();
+  int checkDirection();
+  ssize_t writeData(const void* data, size_t len);
+  ssize_t readData(void* data, size_t len);
+  int tlsConnect(const std::string& hostname, std::string& handshakeErr);
+  int tlsAccept();
+  std::string getLastErrorString();
+private:
+  gnutls_session_t sslSession_;
+  TLSContext* tlsContext_;
+  // Last error code from gnutls library functions
+  int rv_;
+};
+
+} // namespace aria2
+
+#endif // LIBGNUTLS_TLS_SESSION_H

+ 299 - 0
src/LibsslTLSSession.cc

@@ -0,0 +1,299 @@
+/* <!-- copyright */
+/*
+ * aria2 - The high speed download utility
+ *
+ * Copyright (C) 2013 Tatsuhiro Tsujikawa
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 2 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ *
+ * In addition, as a special exception, the copyright holders give
+ * permission to link the code of portions of this program with the
+ * OpenSSL library under certain conditions as described in each
+ * individual source file, and distribute linked combinations
+ * including the two.
+ * You must obey the GNU General Public License in all respects
+ * for all of the code used other than OpenSSL.  If you modify
+ * file(s) with this exception, you may extend this exception to your
+ * version of the file(s), but you are not obligated to do so.  If you
+ * do not wish to do so, delete this exception statement from your
+ * version.  If you delete this exception statement from all source
+ * files in the program, then also delete it here.
+ */
+/* copyright --> */
+#include "LibsslTLSSession.h"
+
+#include <openssl/err.h>
+#include <openssl/x509.h>
+#include <openssl/x509v3.h>
+
+#include "TLSContext.h"
+#include "util.h"
+#include "SocketCore.h"
+
+namespace aria2 {
+
+TLSSession::TLSSession(TLSContext* tlsContext)
+  : ssl_(0),
+    tlsContext_(tlsContext),
+    rv_(1)
+{}
+
+TLSSession::~TLSSession()
+{
+  if(ssl_) {
+    SSL_shutdown(ssl_);
+  }
+}
+
+int TLSSession::init(sock_t sockfd)
+{
+  ERR_clear_error();
+  ssl_ = SSL_new(tlsContext_->getSSLCtx());
+  if(!ssl_) {
+    return TLS_ERR_ERROR;
+  }
+  rv_ = SSL_set_fd(ssl_, sockfd);
+  if(rv_ == 0) {
+    return TLS_ERR_ERROR;
+  }
+  return TLS_ERR_OK;
+}
+
+int TLSSession::setSNIHostname(const std::string& hostname)
+{
+#ifdef SSL_CTRL_SET_TLSEXT_HOSTNAME
+  ERR_clear_error();
+  // TLS extensions: SNI.  There is not documentation about the
+  // return code for this function (actually this is macro
+  // wrapping SSL_ctrl at the time of this writing).
+  SSL_set_tlsext_host_name(ssl_, hostname.c_str());
+#endif // SSL_CTRL_SET_TLSEXT_HOSTNAME
+  return TLS_ERR_OK;
+}
+
+int TLSSession::closeConnection()
+{
+  ERR_clear_error();
+  SSL_shutdown(ssl_);
+  // TODO handle return value
+  return TLS_ERR_OK;
+}
+
+int TLSSession::checkDirection()
+{
+  int error = SSL_get_error(ssl_, rv_);
+  if(error == SSL_ERROR_WANT_WRITE) {
+    return TLS_WANT_WRITE;
+  } else {
+    // TODO We ignore error other than SSL_ERR_WANT_READ here for now
+    return TLS_WANT_READ;
+  }
+}
+
+namespace {
+bool wouldblock(SSL* ssl, int rv)
+{
+  int error = SSL_get_error(ssl, rv);
+  return error == SSL_ERROR_WANT_READ || error == SSL_ERROR_WANT_WRITE;
+}
+} // namespace
+
+ssize_t TLSSession::writeData(const void* data, size_t len)
+{
+  ERR_clear_error();
+  rv_ = SSL_write(ssl_, data, len);
+  if(rv_ <= 0) {
+    if(wouldblock(ssl_, rv_)) {
+      return TLS_ERR_WOULDBLOCK;
+    } else {
+      return TLS_ERR_ERROR;
+    }
+  } else {
+    ssize_t ret = rv_;
+    rv_ = 1;
+    return ret;
+  }
+}
+
+ssize_t TLSSession::readData(void* data, size_t len)
+{
+  ERR_clear_error();
+  rv_ = SSL_read(ssl_, data, len);
+  if(rv_ <= 0) {
+    if(wouldblock(ssl_, rv_)) {
+      return TLS_ERR_WOULDBLOCK;
+    } else {
+      return TLS_ERR_ERROR;
+    }
+  } else {
+    ssize_t ret = rv_;
+    rv_ = 1;
+    return ret;
+  }
+}
+
+int TLSSession::handshake()
+{
+  ERR_clear_error();
+  if(tlsContext_->getSide() == TLS_CLIENT) {
+    rv_ = SSL_connect(ssl_);
+  } else {
+    rv_ = SSL_accept(ssl_);
+  }
+  if(rv_ <= 0) {
+    int sslError = SSL_get_error(ssl_, rv_);
+    switch(sslError) {
+    case SSL_ERROR_NONE:
+    case SSL_ERROR_WANT_X509_LOOKUP:
+    case SSL_ERROR_ZERO_RETURN:
+      // TODO Now assume we are doing non-blocking. Then above 2
+      // errors are OK.
+      break;
+    case SSL_ERROR_WANT_READ:
+    case SSL_ERROR_WANT_WRITE:
+      return TLS_ERR_WOULDBLOCK;
+    default:
+      return TLS_ERR_ERROR;
+    }
+  }
+  return TLS_ERR_OK;
+}
+
+int TLSSession::tlsConnect(const std::string& hostname,
+                           std::string& handshakeErr)
+{
+  handshakeErr = "";
+  int ret;
+  ret = handshake();
+  if(ret != TLS_ERR_OK) {
+    return ret;
+  }
+  if(tlsContext_->getSide() == TLS_CLIENT &&
+     tlsContext_->peerVerificationEnabled()) {
+    // verify peer
+    X509* peerCert = SSL_get_peer_certificate(ssl_);
+    if(!peerCert) {
+      handshakeErr = "certificate not found";
+      return TLS_ERR_ERROR;
+    }
+    auto_delete<X509*> certDeleter(peerCert, X509_free);
+    long verifyResult = SSL_get_verify_result(ssl_);
+    if(verifyResult != X509_V_OK) {
+      handshakeErr = X509_verify_cert_error_string(verifyResult);
+      return TLS_ERR_ERROR;
+    }
+    std::string commonName;
+    std::vector<std::string> dnsNames;
+    std::vector<std::string> ipAddrs;
+    GENERAL_NAMES* altNames;
+    altNames = reinterpret_cast<GENERAL_NAMES*>
+      (X509_get_ext_d2i(peerCert, NID_subject_alt_name, NULL, NULL));
+    if(altNames) {
+      auto_delete<GENERAL_NAMES*> altNamesDeleter
+        (altNames, GENERAL_NAMES_free);
+      size_t n = sk_GENERAL_NAME_num(altNames);
+      for(size_t i = 0; i < n; ++i) {
+        const GENERAL_NAME* altName = sk_GENERAL_NAME_value(altNames, i);
+        if(altName->type == GEN_DNS) {
+          const char* name =
+            reinterpret_cast<char*>(ASN1_STRING_data(altName->d.ia5));
+          if(!name) {
+            continue;
+          }
+          size_t len = ASN1_STRING_length(altName->d.ia5);
+          dnsNames.push_back(std::string(name, len));
+        } else if(altName->type == GEN_IPADD) {
+          const unsigned char* ipAddr = altName->d.iPAddress->data;
+          if(!ipAddr) {
+            continue;
+          }
+          size_t len = altName->d.iPAddress->length;
+          ipAddrs.push_back(std::string(reinterpret_cast<const char*>(ipAddr),
+                                        len));
+        }
+      }
+    }
+    X509_NAME* subjectName = X509_get_subject_name(peerCert);
+    if(!subjectName) {
+      handshakeErr = "could not get X509 name object from the certificate.";
+      return TLS_ERR_ERROR;
+    }
+    int lastpos = -1;
+    while(1) {
+      lastpos = X509_NAME_get_index_by_NID(subjectName, NID_commonName,
+                                           lastpos);
+      if(lastpos == -1) {
+        break;
+      }
+      X509_NAME_ENTRY* entry = X509_NAME_get_entry(subjectName, lastpos);
+      unsigned char* out;
+      int outlen = ASN1_STRING_to_UTF8(&out,
+                                       X509_NAME_ENTRY_get_data(entry));
+      if(outlen < 0) {
+        continue;
+      }
+      commonName.assign(&out[0], &out[outlen]);
+      OPENSSL_free(out);
+      break;
+    }
+    if(!net::verifyHostname(hostname, dnsNames, ipAddrs, commonName)) {
+      handshakeErr = "hostname does not match";
+      return TLS_ERR_ERROR;
+    }
+  }
+  return TLS_ERR_OK;
+}
+
+int TLSSession::tlsAccept()
+{
+  return handshake();
+}
+
+std::string TLSSession::getLastErrorString()
+{
+  if(rv_ <= 0) {
+    int sslError = SSL_get_error(ssl_, rv_);
+    switch(sslError) {
+    case SSL_ERROR_NONE:
+    case SSL_ERROR_WANT_READ:
+    case SSL_ERROR_WANT_WRITE:
+    case SSL_ERROR_WANT_X509_LOOKUP:
+    case SSL_ERROR_ZERO_RETURN:
+      return "";
+    case SSL_ERROR_SYSCALL: {
+      int err = ERR_get_error();
+      if(err == 0) {
+        if(rv_ == 0) {
+          return "EOF was received";
+        } else if(rv_ == -1) {
+          return "SSL I/O error";
+        } else {
+          return "unknown syscall error";
+        }
+      } else {
+        return ERR_error_string(err, 0);
+      }
+    }
+    case SSL_ERROR_SSL:
+      return "protocol error";
+    default:
+      return "unknown error";
+    }
+  } else {
+    return "";
+  }
+}
+
+} // namespace aria2

+ 74 - 0
src/LibsslTLSSession.h

@@ -0,0 +1,74 @@
+/* <!-- copyright */
+/*
+ * aria2 - The high speed download utility
+ *
+ * Copyright (C) 2013 Tatsuhiro Tsujikawa
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 2 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ *
+ * In addition, as a special exception, the copyright holders give
+ * permission to link the code of portions of this program with the
+ * OpenSSL library under certain conditions as described in each
+ * individual source file, and distribute linked combinations
+ * including the two.
+ * You must obey the GNU General Public License in all respects
+ * for all of the code used other than OpenSSL.  If you modify
+ * file(s) with this exception, you may extend this exception to your
+ * version of the file(s), but you are not obligated to do so.  If you
+ * do not wish to do so, delete this exception statement from your
+ * version.  If you delete this exception statement from all source
+ * files in the program, then also delete it here.
+ */
+/* copyright --> */
+#ifndef LIBSSL_TLS_SESSION_H
+#define LIBSSL_TLS_SESSION_H
+
+#include "common.h"
+
+#include <openssl/ssl.h>
+
+#include <string>
+
+#include "TLSSessionConst.h"
+#include "a2netcompat.h"
+
+namespace aria2 {
+
+class TLSContext;
+
+class TLSSession {
+public:
+  TLSSession(TLSContext* tlsContext);
+  ~TLSSession();
+  int init(sock_t sockfd);
+  int setSNIHostname(const std::string& hostname);
+  int closeConnection();
+  int checkDirection();
+  ssize_t writeData(const void* data, size_t len);
+  ssize_t readData(void* data, size_t len);
+  int tlsConnect(const std::string& hostname, std::string& handshakeErr);
+  int tlsAccept();
+  std::string getLastErrorString();
+private:
+  int handshake();
+  SSL* ssl_;
+  TLSContext* tlsContext_;
+  // Last error code from openSSL library functions
+  int rv_;
+};
+
+} // namespace aria2
+
+#endif // LIBSSL_TLS_SESSION_H

+ 6 - 2
src/Makefile.am

@@ -299,11 +299,14 @@ SRCS += EpollEventPoll.cc EpollEventPoll.h
 endif # HAVE_EPOLL
 
 if ENABLE_SSL
-SRCS += TLSContext.h
+SRCS += TLSContext.h\
+	TLSSession.h\
+	TLSSessionConst.h
 endif # ENABLE_SSL
 
 if HAVE_LIBGNUTLS
-SRCS += LibgnutlsTLSContext.cc LibgnutlsTLSContext.h
+SRCS += LibgnutlsTLSContext.cc LibgnutlsTLSContext.h\
+	LibgnutlsTLSSession.cc LibgnutlsTLSSession.h
 endif # HAVE_LIBGNUTLS
 
 if HAVE_LIBGCRYPT
@@ -324,6 +327,7 @@ endif # HAVE_LIBGMP
 
 if HAVE_OPENSSL
 SRCS += LibsslTLSContext.cc LibsslTLSContext.h\
+	LibsslTLSSession.cc LibsslTLSSession.h\
 	LibsslMessageDigestImpl.cc LibsslMessageDigestImpl.h\
 	LibsslARC4Encryptor.cc LibsslARC4Encryptor.h\
 	LibsslDHKeyExchange.cc LibsslDHKeyExchange.h

+ 59 - 406
src/SocketCore.cc

@@ -46,15 +46,6 @@
 #include <cerrno>
 #include <cstring>
 
-#ifdef HAVE_OPENSSL
-# include <openssl/x509.h>
-# include <openssl/x509v3.h>
-#endif // HAVE_OPENSSL
-
-#ifdef HAVE_LIBGNUTLS
-# include <gnutls/x509.h>
-#endif // HAVE_LIBGNUTLS
-
 #include "message.h"
 #include "DlRetryEx.h"
 #include "DlAbortEx.h"
@@ -66,6 +57,7 @@
 #include "A2STR.h"
 #ifdef ENABLE_SSL
 # include "TLSContext.h"
+# include "TLSSession.h"
 #endif // ENABLE_SSL
 
 namespace aria2 {
@@ -179,14 +171,6 @@ void SocketCore::init()
 
   wantRead_ = false;
   wantWrite_ = false;
-
-#ifdef HAVE_OPENSSL
-  // for SSL
-  ssl = NULL;
-#endif // HAVE_OPENSSL
-#ifdef HAVE_LIBGNUTLS
-  sslSession_ = 0;
-#endif //HAVE_LIBGNUTLS
 }
 
 SocketCore::~SocketCore() {
@@ -586,33 +570,15 @@ void SocketCore::setBlockingMode()
 
 void SocketCore::closeConnection()
 {
-#ifdef HAVE_OPENSSL
-  // for SSL
-  if(secure_) {
-    SSL_shutdown(ssl);
-  }
-#endif // HAVE_OPENSSL
-#ifdef HAVE_LIBGNUTLS
-  if(secure_) {
-    gnutls_bye(sslSession_, GNUTLS_SHUT_WR);
+  if(tlsSession_) {
+    tlsSession_->closeConnection();
+    tlsSession_.reset();
   }
-#endif // HAVE_LIBGNUTLS
   if(sockfd_ != (sock_t) -1) {
     shutdown(sockfd_, SHUT_WR);
     CLOSE(sockfd_);
     sockfd_ = -1;
   }
-#ifdef HAVE_OPENSSL
-  // for SSL
-  if(secure_) {
-    SSL_free(ssl);
-  }
-#endif // HAVE_OPENSSL
-#ifdef HAVE_LIBGNUTLS
-  if(secure_) {
-    gnutls_deinit(sslSession_);
-  }
-#endif // HAVE_LIBGNUTLS
 }
 
 #ifndef __MINGW32__
@@ -716,34 +682,6 @@ bool SocketCore::isReadable(time_t timeout)
 #endif // !HAVE_POLL
 }
 
-#ifdef HAVE_OPENSSL
-int SocketCore::sslHandleEAGAIN(int ret)
-{
-  int error = SSL_get_error(ssl, ret);
-  if(error == SSL_ERROR_WANT_READ || error == SSL_ERROR_WANT_WRITE) {
-    ret = 0;
-    if(error == SSL_ERROR_WANT_READ) {
-      wantRead_ = true;
-    } else {
-      wantWrite_ = true;
-    }
-  }
-  return ret;
-}
-#endif // HAVE_OPENSSL
-
-#ifdef HAVE_LIBGNUTLS
-void SocketCore::gnutlsRecordCheckDirection()
-{
-  int direction = gnutls_record_get_direction(sslSession_);
-  if(direction == 0) {
-    wantRead_ = true;
-  } else { // if(direction == 1) {
-    wantWrite_ = true;
-  }
-}
-#endif // HAVE_LIBGNUTLS
-
 ssize_t SocketCore::writeVector(a2iovec *iov, size_t iovcnt)
 {
   ssize_t ret = 0;
@@ -805,29 +743,21 @@ ssize_t SocketCore::writeData(const void* data, size_t len)
       }
     }
   } else {
-#ifdef HAVE_OPENSSL
-    ERR_clear_error();
-    ret = SSL_write(ssl, data, len);
+    ret = tlsSession_->writeData(data, len);
     if(ret < 0) {
-      ret = sslHandleEAGAIN(ret);
-    }
-    if(ret < 0) {
-      throw DL_RETRY_EX
-        (fmt(EX_SOCKET_SEND, ERR_error_string(ERR_get_error(), 0)));
-    }
-#endif // HAVE_OPENSSL
-#ifdef HAVE_LIBGNUTLS
-    while((ret = gnutls_record_send(sslSession_, data, len)) ==
-          GNUTLS_E_INTERRUPTED);
-    if(ret == GNUTLS_E_AGAIN) {
-      gnutlsRecordCheckDirection();
-      ret = 0;
-    } else if(ret < 0) {
-      throw DL_RETRY_EX(fmt(EX_SOCKET_SEND, gnutls_strerror(ret)));
+      if(ret == TLS_ERR_WOULDBLOCK) {
+        if(tlsSession_->checkDirection() == TLS_WANT_READ) {
+          wantRead_ = true;
+        } else {
+          wantWrite_ = true;
+        }
+        ret = 0;
+      } else {
+        throw DL_RETRY_EX(fmt(EX_SOCKET_SEND,
+                              tlsSession_->getLastErrorString().c_str()));
+      }
     }
-#endif // HAVE_LIBGNUTLS
   }
-
   return ret;
 }
 
@@ -851,31 +781,21 @@ void SocketCore::readData(void* data, size_t& len)
       }
     }
   } else {
-#ifdef HAVE_OPENSSL
-    // for SSL
-    // TODO handling len == 0 case required
-    ERR_clear_error();
-    ret = SSL_read(ssl, data, len);
+    ret = tlsSession_->readData(data, len);
     if(ret < 0) {
-      ret = sslHandleEAGAIN(ret);
-    }
-    if(ret < 0) {
-      throw DL_RETRY_EX
-        (fmt(EX_SOCKET_RECV, ERR_error_string(ERR_get_error(), 0)));
-    }
-#endif // HAVE_OPENSSL
-#ifdef HAVE_LIBGNUTLS
-    while((ret = gnutls_record_recv(sslSession_, data, len)) ==
-          GNUTLS_E_INTERRUPTED);
-    if(ret == GNUTLS_E_AGAIN) {
-      gnutlsRecordCheckDirection();
-      ret = 0;
-    } else if(ret < 0) {
-      throw DL_RETRY_EX(fmt(EX_SOCKET_RECV, gnutls_strerror(ret)));
+      if(ret == TLS_ERR_WOULDBLOCK) {
+        if(tlsSession_->checkDirection() == TLS_WANT_READ) {
+          wantRead_ = true;
+        } else {
+          wantWrite_ = true;
+        }
+        ret = 0;
+      } else {
+        throw DL_RETRY_EX(fmt(EX_SOCKET_SEND,
+                              tlsSession_->getLastErrorString().c_str()));
+      }
     }
-#endif // HAVE_LIBGNUTLS
   }
-
   len = ret;
 }
 
@@ -893,324 +813,57 @@ bool SocketCore::tlsConnect(const std::string& hostname)
 
 bool SocketCore::tlsHandshake(TLSContext* tlsctx, const std::string& hostname)
 {
+  int rv = 0;
+  std::string handshakeError;
   wantRead_ = false;
   wantWrite_ = false;
-#ifdef HAVE_OPENSSL
   switch(secure_) {
   case A2_TLS_NONE:
-    ssl = SSL_new(tlsctx->getSSLCtx());
-    if(!ssl) {
-      throw DL_ABORT_EX
-        (fmt(EX_SSL_INIT_FAILURE, ERR_error_string(ERR_get_error(), 0)));
-    }
-    if(SSL_set_fd(ssl, sockfd_) == 0) {
-      throw DL_ABORT_EX
-        (fmt(EX_SSL_INIT_FAILURE, ERR_error_string(ERR_get_error(), 0)));
+    tlsSession_.reset(new TLSSession(tlsctx));
+    rv = tlsSession_->init(sockfd_);
+    if(rv != TLS_ERR_OK) {
+      std::string error = tlsSession_->getLastErrorString();
+      tlsSession_.reset();
+      throw DL_ABORT_EX(fmt(EX_SSL_INIT_FAILURE, error.c_str()));
     }
-    // Fall through
-#ifdef SSL_CTRL_SET_TLSEXT_HOSTNAME
-    if(tlsctx->getSide() == TLS_CLIENT && !util::isNumericHost(hostname)) {
-      // TLS extensions: SNI.  There is not documentation about the
-      // return code for this function (actually this is macro
-      // wrapping SSL_ctrl at the time of this writing).
-      SSL_set_tlsext_host_name(ssl, hostname.c_str());
+    // Check hostname is not numeric and it includes ".". Setting
+    // "localhost" will produce TLS alert with GNUTLS.
+    if(tlsctx->getSide() == TLS_CLIENT &&
+       !util::isNumericHost(hostname) &&
+       hostname.find(".") != std::string::npos) {
+      rv = tlsSession_->setSNIHostname(hostname);
+      if(rv != TLS_ERR_OK) {
+        throw DL_ABORT_EX(fmt(EX_SSL_INIT_FAILURE,
+                              tlsSession_->getLastErrorString().c_str()));
+      }
     }
-#endif // SSL_CTRL_SET_TLSEXT_HOSTNAME
     secure_ = A2_TLS_HANDSHAKING;
     // Fall through
-  case A2_TLS_HANDSHAKING: {
-    ERR_clear_error();
-    int e;
+  case A2_TLS_HANDSHAKING:
     if(tlsctx->getSide() == TLS_CLIENT) {
-      e = SSL_connect(ssl);
+      rv = tlsSession_->tlsConnect(hostname, handshakeError);
     } else {
-      e = SSL_accept(ssl);
+      rv = tlsSession_->tlsAccept();
     }
-
-    if (e <= 0) {
-      int ssl_error = SSL_get_error(ssl, e);
-      switch(ssl_error) {
-      case SSL_ERROR_NONE:
-        break;
-      case SSL_ERROR_WANT_READ:
+    if(rv == TLS_ERR_OK) {
+      secure_ = A2_TLS_CONNECTED;
+    } else if(rv == TLS_ERR_WOULDBLOCK) {
+      if(tlsSession_->checkDirection() == TLS_WANT_READ) {
         wantRead_ = true;
-        return false;
-      case SSL_ERROR_WANT_WRITE:
+      } else {
         wantWrite_ = true;
-        return false;
-      case SSL_ERROR_WANT_X509_LOOKUP:
-      case SSL_ERROR_ZERO_RETURN:
-        if (blocking_) {
-          throw DL_ABORT_EX(fmt(EX_SSL_CONNECT_ERROR, ssl_error));
-        }
-        break;
-
-      case SSL_ERROR_SYSCALL: {
-        int sslErr = ERR_get_error();
-        if(sslErr == 0) {
-          if(e == 0) {
-            throw DL_ABORT_EX("Got EOF in SSL handshake");
-          } else if(e == -1) {
-            throw DL_ABORT_EX(fmt("SSL I/O error: %s", strerror(errno)));
-          } else {
-            throw DL_ABORT_EX(EX_SSL_IO_ERROR);
-          }
-        } else {
-          throw DL_ABORT_EX(fmt("SSL I/O error: %s",
-                                ERR_error_string(sslErr, 0)));
-        }
       }
-      case SSL_ERROR_SSL:
-        throw DL_ABORT_EX(EX_SSL_PROTOCOL_ERROR);
-
-      default:
-        throw DL_ABORT_EX(fmt(EX_SSL_UNKNOWN_ERROR, ssl_error));
-      }
-    }
-    if(tlsctx->getSide() == TLS_CLIENT &&
-       tlsctx->peerVerificationEnabled()) {
-      // verify peer
-      X509* peerCert = SSL_get_peer_certificate(ssl);
-      if(!peerCert) {
-        throw DL_ABORT_EX(MSG_NO_CERT_FOUND);
-      }
-      auto_delete<X509*> certDeleter(peerCert, X509_free);
-
-      long verifyResult = SSL_get_verify_result(ssl);
-      if(verifyResult != X509_V_OK) {
-        throw DL_ABORT_EX
-          (fmt(MSG_CERT_VERIFICATION_FAILED,
-               X509_verify_cert_error_string(verifyResult)));
-      }
-      std::string commonName;
-      std::vector<std::string> dnsNames;
-      std::vector<std::string> ipAddrs;
-      GENERAL_NAMES* altNames;
-      altNames = reinterpret_cast<GENERAL_NAMES*>
-        (X509_get_ext_d2i(peerCert, NID_subject_alt_name, NULL, NULL));
-      if(altNames) {
-        auto_delete<GENERAL_NAMES*> altNamesDeleter
-          (altNames, GENERAL_NAMES_free);
-        size_t n = sk_GENERAL_NAME_num(altNames);
-        for(size_t i = 0; i < n; ++i) {
-          const GENERAL_NAME* altName = sk_GENERAL_NAME_value(altNames, i);
-          if(altName->type == GEN_DNS) {
-            const char* name =
-              reinterpret_cast<char*>(ASN1_STRING_data(altName->d.ia5));
-            if(!name) {
-              continue;
-            }
-            size_t len = ASN1_STRING_length(altName->d.ia5);
-            dnsNames.push_back(std::string(name, len));
-          } else if(altName->type == GEN_IPADD) {
-            const unsigned char* ipAddr = altName->d.iPAddress->data;
-            if(!ipAddr) {
-              continue;
-            }
-            size_t len = altName->d.iPAddress->length;
-            ipAddrs.push_back(std::string(reinterpret_cast<const char*>(ipAddr),
-                                          len));
-          }
-        }
-      }
-      X509_NAME* subjectName = X509_get_subject_name(peerCert);
-      if(!subjectName) {
-        throw DL_ABORT_EX
-          ("Could not get X509 name object from the certificate.");
-      }
-      int lastpos = -1;
-      while(1) {
-        lastpos = X509_NAME_get_index_by_NID(subjectName, NID_commonName,
-                                             lastpos);
-        if(lastpos == -1) {
-          break;
-        }
-        X509_NAME_ENTRY* entry = X509_NAME_get_entry(subjectName, lastpos);
-        unsigned char* out;
-        int outlen = ASN1_STRING_to_UTF8(&out,
-                                         X509_NAME_ENTRY_get_data(entry));
-        if(outlen < 0) {
-          continue;
-        }
-        commonName.assign(&out[0], &out[outlen]);
-        OPENSSL_free(out);
-        break;
-      }
-      if(!net::verifyHostname(hostname, dnsNames, ipAddrs, commonName)) {
-        throw DL_ABORT_EX(MSG_HOSTNAME_NOT_MATCH);
-      }
-    }
-    secure_ = A2_TLS_CONNECTED;
-    break;
-  }
-  default:
-    break;
-  }
-#endif // HAVE_OPENSSL
-#ifdef HAVE_LIBGNUTLS
-  switch(secure_) {
-  case A2_TLS_NONE:
-    int r;
-    gnutls_init(&sslSession_,
-                tlsctx->getSide() == TLS_CLIENT ?
-                GNUTLS_CLIENT : GNUTLS_SERVER);
-    // It seems err is not error message, but the argument string
-    // which causes syntax error.
-    const char* err;
-    // For client side, disables TLS1.1 here because there are servers
-    // that don't understand TLS1.1.  TODO Is this still necessary?
-    r = gnutls_priority_set_direct(sslSession_,
-                                   tlsctx->getSide() == TLS_CLIENT ?
-                                   "NORMAL:-VERS-TLS1.1" :
-                                   "NORMAL",
-                                   &err);
-    if(r != GNUTLS_E_SUCCESS) {
-      throw DL_ABORT_EX(fmt(EX_SSL_INIT_FAILURE, gnutls_strerror(r)));
-    }
-    // put the x509 credentials to the current session
-    gnutls_credentials_set(sslSession_, GNUTLS_CRD_CERTIFICATE,
-                           tlsctx->getCertCred());
-    gnutls_transport_set_ptr(sslSession_, (gnutls_transport_ptr_t)sockfd_);
-    if(tlsctx->getSide() == TLS_CLIENT) {
-      // Check hostname is not numeric and it includes ".". Setting
-      // "localhost" will produce TLS alert.
-      if(!util::isNumericHost(hostname) &&
-         hostname.find(".") != std::string::npos) {
-        // TLS extensions: SNI
-        int ret = gnutls_server_name_set(sslSession_, GNUTLS_NAME_DNS,
-                                         hostname.c_str(), hostname.size());
-        if(ret < 0) {
-          A2_LOG_WARN(fmt
-                      ("Setting hostname in SNI extension failed. Cause: %s",
-                       gnutls_strerror(ret)));
-        }
-      }
-    }
-    secure_ = A2_TLS_HANDSHAKING;
-    // Fall through
-  case A2_TLS_HANDSHAKING: {
-    int ret = gnutls_handshake(sslSession_);
-    if(ret == GNUTLS_E_AGAIN) {
-      gnutlsRecordCheckDirection();
       return false;
-    } else if(ret < 0) {
-      throw DL_ABORT_EX(fmt(EX_SSL_INIT_FAILURE, gnutls_strerror(ret)));
-    }
-
-    if(tlsctx->getSide() == TLS_CLIENT && tlsctx->peerVerificationEnabled()) {
-      // verify peer
-      unsigned int status;
-      ret = gnutls_certificate_verify_peers2(sslSession_, &status);
-      if(ret < 0) {
-        throw DL_ABORT_EX
-          (fmt("gnutls_certificate_verify_peer2() failed. Cause: %s",
-               gnutls_strerror(ret)));
-      }
-      if(status) {
-        std::string errors;
-        if(status & GNUTLS_CERT_INVALID) {
-          errors += " `not signed by known authorities or invalid'";
-        }
-        if(status & GNUTLS_CERT_REVOKED) {
-          errors += " `revoked by its CA'";
-        }
-        if(status & GNUTLS_CERT_SIGNER_NOT_FOUND) {
-          errors += " `issuer is not known'";
-        }
-        // TODO should check GNUTLS_CERT_SIGNER_NOT_CA ?
-        if(status & GNUTLS_CERT_INSECURE_ALGORITHM) {
-          errors += " `insecure algorithm'";
-        }
-        if(status & GNUTLS_CERT_NOT_ACTIVATED) {
-          errors += " `not activated yet'";
-        }
-        if(status & GNUTLS_CERT_EXPIRED) {
-          errors += " `expired'";
-        }
-        // TODO Add GNUTLS_CERT_SIGNATURE_FAILURE here
-        if(!errors.empty()) {
-          throw DL_ABORT_EX(fmt(MSG_CERT_VERIFICATION_FAILED, errors.c_str()));
-        }
-      }
-      // certificate type: only X509 is allowed.
-      if(gnutls_certificate_type_get(sslSession_) != GNUTLS_CRT_X509) {
-        throw DL_ABORT_EX("Certificate type is not X509.");
-      }
-
-      unsigned int peerCertsLength;
-      const gnutls_datum_t* peerCerts = gnutls_certificate_get_peers
-        (sslSession_, &peerCertsLength);
-      if(!peerCerts || peerCertsLength == 0 ) {
-        throw DL_ABORT_EX(MSG_NO_CERT_FOUND);
-      }
-      Time now;
-      for(unsigned int i = 0; i < peerCertsLength; ++i) {
-        gnutls_x509_crt_t cert;
-        ret = gnutls_x509_crt_init(&cert);
-        if(ret < 0) {
-          throw DL_ABORT_EX
-            (fmt("gnutls_x509_crt_init() failed. Cause: %s",
-                 gnutls_strerror(ret)));
-        }
-        auto_delete<gnutls_x509_crt_t> certDeleter
-          (cert, gnutls_x509_crt_deinit);
-        ret = gnutls_x509_crt_import(cert, &peerCerts[i], GNUTLS_X509_FMT_DER);
-        if(ret < 0) {
-          throw DL_ABORT_EX
-            (fmt("gnutls_x509_crt_import() failed. Cause: %s",
-                 gnutls_strerror(ret)));
-        }
-        if(i == 0) {
-          std::string commonName;
-          std::vector<std::string> dnsNames;
-          std::vector<std::string> ipAddrs;
-          int ret = 0;
-          char altName[256];
-          size_t altNameLen;
-          for(int j = 0; !(ret < 0); ++j) {
-            altNameLen = sizeof(altName);
-            ret = gnutls_x509_crt_get_subject_alt_name(cert, j, altName,
-                                                       &altNameLen, 0);
-            if(ret == GNUTLS_SAN_DNSNAME) {
-              dnsNames.push_back(std::string(altName, altNameLen));
-            } else if(ret == GNUTLS_SAN_IPADDRESS) {
-              ipAddrs.push_back(std::string(altName, altNameLen));
-            }
-          }
-          altNameLen = sizeof(altName);
-          ret = gnutls_x509_crt_get_dn_by_oid(cert,
-                                              GNUTLS_OID_X520_COMMON_NAME, 0, 0,
-                                              altName, &altNameLen);
-          if(ret == 0) {
-            commonName.assign(altName, altNameLen);
-          }
-          if(!net::verifyHostname(hostname, dnsNames, ipAddrs, commonName)) {
-            throw DL_ABORT_EX(MSG_HOSTNAME_NOT_MATCH);
-          }
-        }
-        time_t activationTime = gnutls_x509_crt_get_activation_time(cert);
-        if(activationTime == -1) {
-          throw DL_ABORT_EX("Could not get activation time from certificate.");
-        }
-        if(now.getTime() < activationTime) {
-          throw DL_ABORT_EX("Certificate is not activated yet.");
-        }
-        time_t expirationTime = gnutls_x509_crt_get_expiration_time(cert);
-        if(expirationTime == -1) {
-          throw DL_ABORT_EX("Could not get expiration time from certificate.");
-        }
-        if(expirationTime < now.getTime()) {
-          throw DL_ABORT_EX("Certificate has expired.");
-        }
-      }
+    } else {
+      throw DL_ABORT_EX(fmt("SSL/TLS handshake failure: %s",
+                            handshakeError.empty() ?
+                            tlsSession_->getLastErrorString().c_str() :
+                            handshakeError.c_str()));
     }
-    secure_ = A2_TLS_CONNECTED;
     break;
-  }
   default:
     break;
   }
-#endif // HAVE_LIBGNUTLS
   return true;
 }
 

+ 8 - 29
src/SocketCore.h

@@ -43,16 +43,6 @@
 #include <vector>
 
 #include "a2netcompat.h"
-
-#ifdef HAVE_OPENSSL
-// for SSL
-# include <openssl/ssl.h>
-# include <openssl/err.h>
-#endif // HAVE_OPENSSL
-#ifdef HAVE_LIBGNUTLS
-# include <gnutls/gnutls.h>
-#endif // HAVE_LIBGNUTLS
-
 #include "SharedHandle.h"
 #include "a2io.h"
 #include "a2netcompat.h"
@@ -62,6 +52,7 @@ namespace aria2 {
 
 #ifdef ENABLE_SSL
 class TLSContext;
+class TLSSession;
 #endif // ENABLE_SSL
 
 class SocketCore {
@@ -89,27 +80,9 @@ private:
   static SharedHandle<TLSContext> clTlsContext_;
   // TLS context for server side
   static SharedHandle<TLSContext> svTlsContext_;
-#endif // ENABLE_SSL
-
-#ifdef HAVE_OPENSSL
-  // for SSL
-  SSL* ssl;
-
-  int sslHandleEAGAIN(int ret);
-#endif // HAVE_OPENSSL
-#ifdef HAVE_LIBGNUTLS
-  gnutls_session_t sslSession_;
-
-  void gnutlsRecordCheckDirection();
-#endif // HAVE_LIBGNUTLS
-
-  void init();
-
-  void bind(const struct sockaddr* addr, socklen_t addrlen);
 
-  void setSockOpt(int level, int optname, void* optval, socklen_t optlen);
+  SharedHandle<TLSSession> tlsSession_;
 
-#ifdef ENABLE_SSL
   /**
    * Makes this socket secure. The connection must be established
    * before calling this method.
@@ -119,6 +92,12 @@ private:
   bool tlsHandshake(TLSContext* tlsctx, const std::string& hostname);
 #endif // ENABLE_SSL
 
+  void init();
+
+  void bind(const struct sockaddr* addr, socklen_t addrlen);
+
+  void setSockOpt(int level, int optname, void* optval, socklen_t optlen);
+
   SocketCore(sock_t sockfd, int sockType);
 public:
   SocketCore(int sockType = SOCK_STREAM);

+ 104 - 0
src/TLSSession.h

@@ -0,0 +1,104 @@
+/* <!-- copyright */
+/*
+ * aria2 - The high speed download utility
+ *
+ * Copyright (C) 2013 Tatsuhiro Tsujikawa
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 2 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ *
+ * In addition, as a special exception, the copyright holders give
+ * permission to link the code of portions of this program with the
+ * OpenSSL library under certain conditions as described in each
+ * individual source file, and distribute linked combinations
+ * including the two.
+ * You must obey the GNU General Public License in all respects
+ * for all of the code used other than OpenSSL.  If you modify
+ * file(s) with this exception, you may extend this exception to your
+ * version of the file(s), but you are not obligated to do so.  If you
+ * do not wish to do so, delete this exception statement from your
+ * version.  If you delete this exception statement from all source
+ * files in the program, then also delete it here.
+ */
+/* copyright --> */
+#ifndef TLS_SESSION_H
+#define TLS_SESSION_H
+
+#include "common.h"
+
+// To create another SSL/TLS backend, implement TLSSession class below.
+//
+// class TLSSession {
+// public:
+//   TLSSession(TLSContext* tlsContext);
+//
+//   // MUST deallocate all resources
+//   ~TLSSession();
+//
+//   // Initializes SSL/TLS session. The |sockfd| is the underlying
+//   // tranport socket. This function returns TLS_ERR_OK if it
+//   // succeeds, or TLS_ERR_ERROR.
+//   int init(sock_t sockfd);
+//
+//   // Sets |hostname| for TLS SNI extension. This is only meaningful for
+//   // client side session. This function returns TLS_ERR_OK if it
+//   // succeeds, or TLS_ERR_ERROR.
+//   int setSNIHostname(const std::string& hostname);
+//
+//   // Closes the SSL/TLS session. Don't close underlying transport
+//   // socket. This function returns TLS_ERR_OK if it succeeds, or
+//   // TLS_ERR_ERROR.
+//   int closeConnection();
+//
+//   // Returns TLS_WANT_READ if SSL/TLS session needs more data from
+//   // remote endpoint to proceed, or TLS_WANT_WRITE if SSL/TLS session
+//   // needs to write more data to proceed. If SSL/TLS session needs
+//   // neither read nor write data at the moment, return value is
+//   // undefined.
+//   int checkDirection();
+//
+//   // Sends |data| with length |len|. This function returns the number
+//   // of bytes sent if it succeeds, or TLS_ERR_WOULDBLOCK if the
+//   // underlying tranport blocks, or TLS_ERR_ERROR.
+//   ssize_t writeData(const void* data, size_t len);
+//
+//   // Receives data into |data| with length |len|. This function returns
+//   // the number of bytes received if it succeeds, or TLS_ERR_WOULDBLOCK
+//   // if the underlying tranport blocks, or TLS_ERR_ERROR.
+//   ssize_t readData(void* data, size_t len);
+//
+//   // Performs client side handshake. The |hostname| is the hostname of
+//   // the remote endpoint and is used to verify its certificate. This
+//   // function returns TLS_ERR_OK if it succeeds, or TLS_ERR_WOULDBLOCK
+//   // if the underlying transport blocks, or TLS_ERR_ERROR.
+//   // When returning TLS_ERR_ERROR, provide certificate validation error
+//   // in |handshakeErr|.
+//   int tlsConnect(const std::string& hostname, std::string& handshakeErr);
+//
+//   // Performs server side handshake. This function returns TLS_ERR_OK
+//   // if it succeeds, or TLS_ERR_WOULDBLOCK if the underlying transport
+//   // blocks, or TLS_ERR_ERROR.
+//   int tlsAccept();
+//
+//   // Returns last error string
+//   std::string getLastErrorString();
+// };
+
+#ifdef HAVE_OPENSSL
+# include "LibsslTLSSession.h"
+#elif defined HAVE_LIBGNUTLS
+# include "LibgnutlsTLSSession.h"
+#endif
+
+#endif // TLS_SESSION_H

+ 55 - 0
src/TLSSessionConst.h

@@ -0,0 +1,55 @@
+/* <!-- copyright */
+/*
+ * aria2 - The high speed download utility
+ *
+ * Copyright (C) 2013 Tatsuhiro Tsujikawa
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 2 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ *
+ * In addition, as a special exception, the copyright holders give
+ * permission to link the code of portions of this program with the
+ * OpenSSL library under certain conditions as described in each
+ * individual source file, and distribute linked combinations
+ * including the two.
+ * You must obey the GNU General Public License in all respects
+ * for all of the code used other than OpenSSL.  If you modify
+ * file(s) with this exception, you may extend this exception to your
+ * version of the file(s), but you are not obligated to do so.  If you
+ * do not wish to do so, delete this exception statement from your
+ * version.  If you delete this exception statement from all source
+ * files in the program, then also delete it here.
+ */
+/* copyright --> */
+#ifndef TLS_SESSION_CONST_H
+#define TLS_SESSION_CONST_H
+
+#include "common.h"
+
+namespace aria2 {
+
+enum TLSDirection {
+  TLS_WANT_READ = 1,
+  TLS_WANT_WRITE
+};
+
+enum TLSErrorCode {
+  TLS_ERR_OK = 0,
+  TLS_ERR_ERROR = -1,
+  TLS_ERR_WOULDBLOCK = -2
+};
+
+} // namespace aria2
+
+#endif // TLS_SESSION_CONST_H