Forráskód Böngészése

Move warn logic into SocketCore

Also fiddle a bit with the WinTLS implementation, forcing "strong"
crypto only for > SSLv3.
Nils Maier 10 éve
szülő
commit
3c8704178a

+ 21 - 10
src/AppleTLSSession.cc

@@ -43,7 +43,6 @@
 #include "LogFactory.h"
 #include "a2functional.h"
 #include "fmt.h"
-#include "message.h"
 
 #define ioErr -36
 #define paramErr -50
@@ -380,6 +379,8 @@ AppleTLSSession::AppleTLSSession(AppleTLSContext* ctx)
   case TLS_PROTO_TLS12:
     (void)SSLSetProtocolVersionMin(sslCtx_, kTLSProtocol12);
     break;
+  default:
+    break;
   }
 #else
   (void)SSLSetProtocolVersionEnabled(sslCtx_, kSSLProtocolAll, false);
@@ -395,6 +396,8 @@ AppleTLSSession::AppleTLSSession(AppleTLSContext* ctx)
     // fall through
   case TLS_PROTO_TLS12:
     (void)SSLSetProtocolVersionEnabled(sslCtx_, kTLSProtocol12, true);
+  default:
+    break;
   }
 #endif
 
@@ -696,6 +699,7 @@ OSStatus AppleTLSSession::sockRead(void* data, size_t* len)
 }
 
 int AppleTLSSession::tlsConnect(const std::string& hostname,
+                                TLSVersion& version,
                                 std::string& handshakeErr)
 {
   if (state_ != st_initialized) {
@@ -714,7 +718,7 @@ int AppleTLSSession::tlsConnect(const std::string& hostname,
     return TLS_ERR_WOULDBLOCK;
 
   case errSSLServerAuthCompleted:
-    return tlsConnect(hostname, handshakeErr);
+    return tlsConnect(hostname, version, handshakeErr);
 
   default:
     handshakeErr = getLastErrorString();
@@ -732,25 +736,32 @@ int AppleTLSSession::tlsConnect(const std::string& hostname,
                   hostname.c_str(),
                   protoToString(proto),
                   suiteToString(suite).c_str()));
+
   switch (proto) {
-    case kSSLProtocol2:
-    case kSSLProtocol3: {
-      std::string protoAndSuite = protoToString(proto);
-      protoAndSuite += " " + suiteToString(suite);
-      A2_LOG_WARN(fmt(MSG_WARN_OLD_TLS_CONNECTION, protoAndSuite.c_str()));
+    case kSSLProtocol3:
+      version = TLS_PROTO_SSL3;
+      break;
+    case kTLSProtocol1:
+      version = TLS_PROTO_TLS10;
+      break;
+    case kTLSProtocol11:
+      version = TLS_PROTO_TLS11;
+      break;
+    case kTLSProtocol12:
+      version = TLS_PROTO_TLS12;
       break;
-    }
     default:
+      version = TLS_PROTO_NONE;
       break;
   }
 
   return TLS_ERR_OK;
 }
 
-int AppleTLSSession::tlsAccept()
+int AppleTLSSession::tlsAccept(TLSVersion& version)
 {
   std::string hostname, err;
-  return tlsConnect(hostname, err);
+  return tlsConnect(hostname, version, err);
 }
 
 std::string AppleTLSSession::getLastErrorString()

+ 2 - 1
src/AppleTLSSession.h

@@ -96,12 +96,13 @@ public:
   // When returning TLS_ERR_ERROR, provide certificate validation error
   // in |handshakeErr|.
   virtual int tlsConnect(const std::string& hostname,
+                         TLSVersion& version,
                          std::string& handshakeErr) CXX11_OVERRIDE;
 
   // Performs server side handshake. This function returns TLS_ERR_OK
   // if it succeeds, or TLS_ERR_WOULDBLOCK if the underlying transport
   // blocks, or TLS_ERR_ERROR.
-  virtual int tlsAccept() CXX11_OVERRIDE;
+  virtual int tlsAccept(TLSVersion& version) CXX11_OVERRIDE;
 
   // Returns last error string
   virtual std::string getLastErrorString() CXX11_OVERRIDE;

+ 38 - 21
src/LibgnutlsTLSSession.cc

@@ -39,9 +39,39 @@
 #include "TLSContext.h"
 #include "util.h"
 #include "SocketCore.h"
-#include "LogFactory.h"
-#include "fmt.h"
-#include "message.h"
+
+namespace {
+using namespace aria2;
+
+TLSVersion getProtocolFromSession(gnutls_session_t& session) {
+  auto proto = gnutls_protocol_get_version(session);
+  switch(proto) {
+    case GNUTLS_SSL3:
+      return TLS_PROTO_SSL3;
+
+#ifdef GNUTLS_TLS1_0
+    case GNUTLS_TLS1_0:
+      return TLS_PROTO_TLS10;
+#endif // GNUTLS_TLS1_0
+
+#ifdef GNUTLS_TLS1_1
+    case GNUTLS_TLS1_1:
+      return TLS_PROTO_TLS11;
+      break;
+#endif // GNUTLS_TLS1_1
+
+#ifdef GNUTLS_TLS1_2
+    case GNUTLS_TLS1_2:
+      return TLS_PROTO_TLS12;
+      break;
+#endif // GNUTLS_TLS1_2
+
+    default:
+      return TLS_PROTO_NONE;
+      break;
+  }
+}
+} // namespace
 
 namespace aria2 {
 
@@ -200,7 +230,8 @@ ssize_t GnuTLSSession::readData(void* data, size_t len)
 }
 
 int GnuTLSSession::tlsConnect(const std::string& hostname,
-                           std::string& handshakeErr)
+                              TLSVersion& version,
+                              std::string& handshakeErr)
 {
   handshakeErr = "";
   for(;;) {
@@ -300,32 +331,18 @@ int GnuTLSSession::tlsConnect(const std::string& hostname,
       return TLS_ERR_ERROR;
     }
   }
-  auto proto = gnutls_protocol_get_version(sslSession_);
-  switch(proto) {
-    case GNUTLS_SSL3: {
-      std::string protoAndSuite = gnutls_protocol_get_name(proto);
-      protoAndSuite += " ";
-      protoAndSuite += gnutls_cipher_suite_get_name(
-          gnutls_kx_get(sslSession_),
-          gnutls_cipher_get(sslSession_),
-          gnutls_mac_get(sslSession_)
-          );
-      A2_LOG_WARN(fmt(MSG_WARN_OLD_TLS_CONNECTION, protoAndSuite.c_str()));
-      break;
-    }
 
-    default:
-      break;
-  }
+  version = getProtocolFromSession(sslSession_);
 
   return TLS_ERR_OK;
 }
 
-int GnuTLSSession::tlsAccept()
+int GnuTLSSession::tlsAccept(TLSVersion& version)
 {
   for(;;) {
     rv_ = gnutls_handshake(sslSession_);
     if(rv_ == GNUTLS_E_SUCCESS) {
+      version = getProtocolFromSession(sslSession_);
       return TLS_ERR_OK;
     }
     if(rv_ == GNUTLS_E_AGAIN || rv_ == GNUTLS_E_INTERRUPTED) {

+ 3 - 2
src/LibgnutlsTLSSession.h

@@ -56,8 +56,9 @@ public:
   virtual ssize_t writeData(const void* data, size_t len) CXX11_OVERRIDE;
   virtual ssize_t readData(void* data, size_t len) CXX11_OVERRIDE;
   virtual int tlsConnect
-  (const std::string& hostname, std::string& handshakeErr) CXX11_OVERRIDE;
-  virtual int tlsAccept() CXX11_OVERRIDE;
+  (const std::string& hostname, TLSVersion& version, std::string& handshakeErr)
+  CXX11_OVERRIDE;
+  virtual int tlsAccept(TLSVersion& version) CXX11_OVERRIDE;
   virtual std::string getLastErrorString() CXX11_OVERRIDE;
 private:
   gnutls_session_t sslSession_;

+ 35 - 24
src/LibsslTLSSession.cc

@@ -157,7 +157,7 @@ ssize_t OpenSSLTLSSession::readData(void* data, size_t len)
   return ret;
 }
 
-int OpenSSLTLSSession::handshake()
+int OpenSSLTLSSession::handshake(TLSVersion& version)
 {
   ERR_clear_error();
   if(tlsContext_->getSide() == TLS_CLIENT) {
@@ -181,15 +181,45 @@ int OpenSSLTLSSession::handshake()
       return TLS_ERR_ERROR;
     }
   }
+
+  switch(SSL_version(ssl_)) {
+    case SSL3_VERSION:
+      version = TLS_PROTO_SSL3;
+      break;
+
+#ifdef TLS1_VERSION
+    case TLS1_VERSION:
+      version = TLS_PROTO_TLS10;
+      break;
+#endif // TLS1_VERSION
+
+#ifdef TLS1_1_VERSION
+    case TLS1_1_VERSION:
+      version = TLS_PROTO_TLS11;
+      break;
+#endif // TLS1_1_VERSION
+
+#ifdef TLS1_2_VERSION
+    case TLS1_2_VERSION:
+      version = TLS_PROTO_TLS12;
+      break;
+#endif // TLS1_2_VERSION
+
+    default:
+      version = TLS_PROTO_NONE;
+      break;
+  }
+
   return TLS_ERR_OK;
 }
 
 int OpenSSLTLSSession::tlsConnect(const std::string& hostname,
-                           std::string& handshakeErr)
+                                  TLSVersion& version,
+                                  std::string& handshakeErr)
 {
   handshakeErr = "";
   int ret;
-  ret = handshake();
+  ret = handshake(version);
   if(ret != TLS_ERR_OK) {
     return ret;
   }
@@ -268,31 +298,12 @@ int OpenSSLTLSSession::tlsConnect(const std::string& hostname,
     }
   }
 
-  switch(SSL_version(ssl_)) {
-    case SSL3_VERSION:
-    case SSL2_VERSION: {
-      std::string protoAndSuite = "Unknown";
-      auto cipher = SSL_get_current_cipher(ssl_);
-      if(cipher) {
-        auto buf = make_unique<char[]>(256);
-        auto cipherstr = SSL_CIPHER_description(cipher, buf.get(), 256);
-        if(cipherstr) {
-          protoAndSuite = cipherstr;
-        }
-      }
-      A2_LOG_WARN(fmt(MSG_WARN_OLD_TLS_CONNECTION, protoAndSuite.c_str()));
-      break;
-    }
-    default:
-      break;
-  }
-
   return TLS_ERR_OK;
 }
 
-int OpenSSLTLSSession::tlsAccept()
+int OpenSSLTLSSession::tlsAccept(TLSVersion& version)
 {
-  return handshake();
+  return handshake(version);
 }
 
 std::string OpenSSLTLSSession::getLastErrorString()

+ 4 - 3
src/LibsslTLSSession.h

@@ -56,11 +56,12 @@ public:
   virtual ssize_t writeData(const void* data, size_t len) CXX11_OVERRIDE;
   virtual ssize_t readData(void* data, size_t len) CXX11_OVERRIDE;
   virtual int tlsConnect
-  (const std::string& hostname, std::string& handshakeErr) CXX11_OVERRIDE;
-  virtual int tlsAccept() CXX11_OVERRIDE;
+  (const std::string& hostname, TLSVersion& version, std::string& handshakeErr)
+  CXX11_OVERRIDE;
+  virtual int tlsAccept(TLSVersion& version) CXX11_OVERRIDE;
   virtual std::string getLastErrorString() CXX11_OVERRIDE;
 private:
-  int handshake();
+  int handshake(TLSVersion& version);
   SSL* ssl_;
   OpenSSLTLSContext* tlsContext_;
   // Last error code from openSSL library functions

+ 15 - 2
src/SocketCore.cc

@@ -830,6 +830,7 @@ bool SocketCore::tlsConnect(const std::string& hostname)
 
 bool SocketCore::tlsHandshake(TLSContext* tlsctx, const std::string& hostname)
 {
+  TLSVersion ver = TLS_PROTO_NONE;
   int rv = 0;
   std::string handshakeError;
   wantRead_ = false;
@@ -860,9 +861,9 @@ bool SocketCore::tlsHandshake(TLSContext* tlsctx, const std::string& hostname)
     // Fall through
   case A2_TLS_HANDSHAKING:
     if(tlsctx->getSide() == TLS_CLIENT) {
-      rv = tlsSession_->tlsConnect(hostname, handshakeError);
+      rv = tlsSession_->tlsConnect(hostname, ver, handshakeError);
     } else {
-      rv = tlsSession_->tlsAccept();
+      rv = tlsSession_->tlsAccept(ver);
     }
     if(rv == TLS_ERR_OK) {
       secure_ = A2_TLS_CONNECTED;
@@ -883,6 +884,18 @@ bool SocketCore::tlsHandshake(TLSContext* tlsctx, const std::string& hostname)
   default:
     break;
   }
+
+  switch(ver) {
+    case TLS_PROTO_NONE:
+      A2_LOG_WARN(MSG_WARN_UNKNOWN_TLS_CONNECTION);
+      break;
+    case TLS_PROTO_SSL3:
+      A2_LOG_WARN(fmt(MSG_WARN_OLD_TLS_CONNECTION, "SSLv3"));
+      break;
+    default:
+      break;
+  }
+
   return true;
 }
 

+ 1 - 0
src/TLSContext.h

@@ -47,6 +47,7 @@ enum TLSSessionSide {
 };
 
 enum TLSVersion {
+  TLS_PROTO_NONE,
   TLS_PROTO_SSL3,
   TLS_PROTO_TLS10,
   TLS_PROTO_TLS11,

+ 3 - 2
src/TLSSession.h

@@ -99,12 +99,13 @@ public:
   // if the underlying transport blocks, or TLS_ERR_ERROR.
   // When returning TLS_ERR_ERROR, provide certificate validation error
   // in |handshakeErr|.
-  virtual int tlsConnect(const std::string& hostname, std::string& handshakeErr) = 0;
+  virtual int tlsConnect(const std::string& hostname, TLSVersion& version,
+                         std::string& handshakeErr) = 0;
 
   // Performs server side handshake. This function returns TLS_ERR_OK
   // if it succeeds, or TLS_ERR_WOULDBLOCK if the underlying transport
   // blocks, or TLS_ERR_ERROR.
-  virtual int tlsAccept() = 0;
+  virtual int tlsAccept(TLSVersion& version) = 0;
 
   // Returns last error string
   virtual std::string getLastErrorString() = 0;

+ 41 - 10
src/WinTLSContext.cc

@@ -61,6 +61,9 @@
 #define SCH_USE_STRONG_CRYPTO 0x00400000
 #endif
 
+#define WEAK_CIPHER_BITS 56
+#define STRONG_CIPHER_BITS 128
+
 namespace aria2 {
 
 WinTLSContext::WinTLSContext(TLSSessionSide side, TLSVersion ver)
@@ -82,6 +85,9 @@ WinTLSContext::WinTLSContext(TLSSessionSide side, TLSVersion ver)
       // fall through
     case TLS_PROTO_TLS12:
       credentials_.grbitEnabledProtocols |= SP_PROT_TLS1_2_CLIENT;
+      // fall through
+    default:
+      break;
     }
   }
   else {
@@ -97,9 +103,23 @@ WinTLSContext::WinTLSContext(TLSSessionSide side, TLSVersion ver)
       // fall through
     case TLS_PROTO_TLS12:
       credentials_.grbitEnabledProtocols |= SP_PROT_TLS1_2_SERVER;
+      // fall through
+    default:
+      break;
     }
   }
-  credentials_.dwMinimumCipherStrength = 128; // bit
+
+  switch (ver) {
+  case TLS_PROTO_SSL3:
+    // User explicitly wanted SSLv3 and therefore weak ciphers.
+    credentials_.dwMinimumCipherStrength = WEAK_CIPHER_BITS;
+    break;
+
+  default:
+    // Strong protocol versions: Use a minimum strength, which might be later
+    // refined using SCH_USE_STRONG_CRYPTO in the flags.
+    credentials_.dwMinimumCipherStrength = STRONG_CIPHER_BITS;
+  }
 
   setVerifyPeer(side_ == TLS_CLIENT);
 }
@@ -126,19 +146,30 @@ void WinTLSContext::setVerifyPeer(bool verify)
 {
   cred_.reset();
 
+  // Never automatically push any client or server certs. We'll do cert setup
+  // ourselves.
+  credentials_.dwFlags = SCH_CRED_NO_DEFAULT_CREDS;
+
+  if (credentials_.dwMinimumCipherStrength > WEAK_CIPHER_BITS) {
+    // Enable strong crypto if we already set a minimum cipher streams.
+    // This might actually require evem stronger algorithms, which is a good
+    // thing.
+    credentials_.dwFlags |= SCH_USE_STRONG_CRYPTO;
+  }
+
   if (side_ != TLS_CLIENT || !verify) {
-    credentials_.dwFlags = SCH_CRED_NO_DEFAULT_CREDS |
-                           SCH_CRED_MANUAL_CRED_VALIDATION |
-                           SCH_CRED_IGNORE_NO_REVOCATION_CHECK |
-                           SCH_CRED_IGNORE_REVOCATION_OFFLINE |
-                           SCH_CRED_NO_SERVERNAME_CHECK | SCH_USE_STRONG_CRYPTO;
+    // No verfication for servers and if user explicitly requested it
+    credentials_.dwFlags |= SCH_CRED_MANUAL_CRED_VALIDATION |
+                            SCH_CRED_IGNORE_NO_REVOCATION_CHECK |
+                            SCH_CRED_IGNORE_REVOCATION_OFFLINE |
+                            SCH_CRED_NO_SERVERNAME_CHECK;
     return;
   }
 
-  credentials_.dwFlags =
-      SCH_CRED_NO_DEFAULT_CREDS | SCH_CRED_AUTO_CRED_VALIDATION |
-      SCH_CRED_REVOCATION_CHECK_CHAIN | SCH_CRED_IGNORE_NO_REVOCATION_CHECK |
-      SCH_USE_STRONG_CRYPTO;
+  // Verify other side's cert chain.
+  credentials_.dwFlags |= SCH_CRED_AUTO_CRED_VALIDATION |
+                          SCH_CRED_REVOCATION_CHECK_CHAIN |
+                          SCH_CRED_IGNORE_NO_REVOCATION_CHECK;
 }
 
 CredHandle* WinTLSContext::getCredHandle()

+ 22 - 17
src/WinTLSSession.cc

@@ -283,7 +283,8 @@ ssize_t WinTLSSession::writeData(const void* data, size_t len)
       state_ == st_handshake_read) {
     // Renegotiating
     std::string hn, err;
-    auto connect = tlsConnect(hn, err);
+    TLSVersion ver;
+    auto connect = tlsConnect(hn, ver, err);
     if (connect != TLS_ERR_OK) {
       return connect;
     }
@@ -479,7 +480,8 @@ ssize_t WinTLSSession::readData(void* data, size_t len)
       state_ == st_handshake_read) {
     // Renegotiating
     std::string hn, err;
-    auto connect = tlsConnect(hn, err);
+    TLSVersion ver;
+    auto connect = tlsConnect(hn, ver, err);
     if (connect != TLS_ERR_OK) {
       return connect;
     }
@@ -559,7 +561,8 @@ ssize_t WinTLSSession::readData(void* data, size_t len)
       state_ = st_initialized;
       A2_LOG_INFO("WinTLS: Renegotiate");
       std::string hn, err;
-      auto connect = tlsConnect(hn, err);
+      TLSVersion ver;
+      auto connect = tlsConnect(hn, ver, err);
       if (connect == TLS_ERR_WOULDBLOCK) {
         break;
       }
@@ -590,6 +593,7 @@ ssize_t WinTLSSession::readData(void* data, size_t len)
 }
 
 int WinTLSSession::tlsConnect(const std::string& hostname,
+                              TLSVersion& version,
                               std::string& handshakeErr)
 {
   // Handshaking will require sending multiple read/write exchanges until the
@@ -819,28 +823,29 @@ restart:
   }
   // Fall through
 
-  case st_handshake_done: {
+  case st_handshake_done:
     // All ready now :D
     state_ = st_connected;
     A2_LOG_INFO(
         fmt("WinTLS: connected with: %s", getCipherSuite(&handle_).c_str()));
-    auto proto = getProtocolVersion(&handle_);
-    if (proto < 0x301) {
-      std::string protoAndSuite;
-      switch (proto) {
+    switch (getProtocolVersion(&handle_)) {
       case 0x300:
-        protoAndSuite = "SSLv3";
+        version = TLS_PROTO_SSL3;
+        break;
+      case 0x301:
+        version = TLS_PROTO_TLS10;
+        break;
+      case 0x302:
+        version = TLS_PROTO_TLS11;
+        break;
+      case 0x303:
+        version = TLS_PROTO_TLS12;
         break;
       default:
-        protoAndSuite = "Unknown";
+        version = TLS_PROTO_NONE;
         break;
-      }
-      protoAndSuite += " " + getCipherSuite(&handle_);
-      A2_LOG_WARN(fmt(MSG_WARN_OLD_TLS_CONNECTION, protoAndSuite.c_str()));
     }
-
     return TLS_ERR_OK;
-  }
 
   }
 
@@ -849,10 +854,10 @@ restart:
   return TLS_ERR_ERROR;
 }
 
-int WinTLSSession::tlsAccept()
+int WinTLSSession::tlsAccept(TLSVersion& version)
 {
   std::string host, err;
-  return tlsConnect(host, err);
+  return tlsConnect(host, version, err);
 }
 
 std::string WinTLSSession::getLastErrorString()

+ 2 - 1
src/WinTLSSession.h

@@ -176,12 +176,13 @@ public:
   // When returning TLS_ERR_ERROR, provide certificate validation error
   // in |handshakeErr|.
   virtual int tlsConnect(const std::string& hostname,
+                         TLSVersion& version,
                          std::string& handshakeErr) CXX11_OVERRIDE;
 
   // Performs server side handshake. This function returns TLS_ERR_OK
   // if it succeeds, or TLS_ERR_WOULDBLOCK if the underlying transport
   // blocks, or TLS_ERR_ERROR.
-  virtual int tlsAccept() CXX11_OVERRIDE;
+  virtual int tlsAccept(TLSVersion& version) CXX11_OVERRIDE;
 
   // Returns last error string
   virtual std::string getLastErrorString() CXX11_OVERRIDE;

+ 6 - 3
src/message.h

@@ -183,10 +183,13 @@
 #define MSG_WARN_NO_CA_CERT                                             \
   _("You may encounter the certificate verification error with HTTPS server." \
     " See --ca-certificate and --check-certificate option.")
+#define MSG_WARN_UNKNOWN_TLS_CONNECTION \
+  _("aria2c had to connect to the other side using an unknown TLS protocol. " \
+    "The integrity and confidentiality of the connection might be compromised.")
 #define MSG_WARN_OLD_TLS_CONNECTION \
-  _("aria2c had to connect to the server using an old and vulnerable cipher" \
-    " suite. The integrity and confidentiality of the connection might be" \
-    " compromised.\nProtocol and cipher suite: %s")
+  _("aria2c had to connect to the other side using an old and vulnerable TLS" \
+    " protocol. The integrity and confidentiality of the connection might be" \
+    " compromised.\nProtocol: %s")
 #define MSG_SHOW_FILES _("Printing the contents of file '%s'...")
 #define MSG_NOT_TORRENT_METALINK _("This file is neither Torrent nor Metalink" \
                                    " file. Skipping.")