Browse Source

Rewritten TLS hostname check based on RFC 6125.

Tatsuhiro Tsujikawa 13 years ago
parent
commit
0a9abd89c6
5 changed files with 139 additions and 111 deletions
  1. 75 48
      src/SocketCore.cc
  2. 7 0
      src/SocketCore.h
  3. 39 29
      src/util.cc
  4. 1 16
      src/util.h
  5. 17 18
      test/UtilTest.cc

+ 75 - 48
src/SocketCore.cc

@@ -887,71 +887,61 @@ bool SocketCore::initiateSecureConnection(const std::string& hostname)
           (fmt(MSG_CERT_VERIFICATION_FAILED,
                X509_verify_cert_error_string(verifyResult)));
       }
-      int hostnameOK = -1;
       GENERAL_NAMES* altNames;
+      std::string commonName;
+      std::vector<std::string> dnsNames;
+      std::vector<std::string> ipAddrs;
       altNames = reinterpret_cast<GENERAL_NAMES*>
         (X509_get_ext_d2i(peerCert, NID_subject_alt_name, NULL, NULL));
       if(altNames) {
-        int addrType;
-        if(util::isNumericHost(hostname)) {
-          addrType = GEN_IPADD;
-        } else {
-          addrType = GEN_DNS;
-        }
+        auto_delete<GENERAL_NAMES*> altNamesDeleter
+          (altNames, GENERAL_NAMES_free);
         size_t n = sk_GENERAL_NAME_num(altNames);
         for(size_t i = 0; i < n; ++i) {
           const GENERAL_NAME* altName = sk_GENERAL_NAME_value(altNames, i);
-          if(altName->type == addrType) {
+          if(altName->type == GEN_DNS) {
             const char* name =
               reinterpret_cast<char*>(ASN1_STRING_data(altName->d.ia5));
+            if(!name) {
+              continue;
+            }
             size_t len = ASN1_STRING_length(altName->d.ia5);
-            if(addrType == GEN_DNS) {
-              if(util::tlsHostnameMatch(std::string(name, len), hostname)) {
-                hostnameOK = 1;
-                break;
-              } else {
-                hostnameOK = 0;
-              }
-            } else if(addrType == GEN_IPADD) {
-              if(hostname == std::string(name, len)) {
-                hostnameOK = 1;
-                break;
-              } else {
-                hostnameOK = 0;
-              }
+            dnsNames.push_back(std::string(name, len));
+          } else if(altName->type == GEN_IPADD) {
+            const unsigned char* ipAddr = altName->d.iPAddress->data;
+            if(!ipAddr) {
+              continue;
             }
+            size_t len = altName->d.iPAddress->length;
+            ipAddrs.push_back(std::string(reinterpret_cast<const char*>(ipAddr),
+                                          len));
           }
         }
-        GENERAL_NAMES_free(altNames);
       }
-      if(hostnameOK == -1) {
-        X509_NAME* name = X509_get_subject_name(peerCert);
-        if(!name) {
-          throw DL_ABORT_EX
-            ("Could not get X509 name object from the certificate.");
+      X509_NAME* subjectName = X509_get_subject_name(peerCert);
+      if(!subjectName) {
+        throw DL_ABORT_EX
+          ("Could not get X509 name object from the certificate.");
+      }
+      int lastpos = -1;
+      while(1) {
+        lastpos = X509_NAME_get_index_by_NID(subjectName, NID_commonName,
+                                             lastpos);
+        if(lastpos == -1) {
+          break;
         }
-        int lastpos = -1;
-        while(true) {
-          lastpos = X509_NAME_get_index_by_NID(name, NID_commonName, lastpos);
-          if(lastpos == -1) {
-            break;
-          }
-          X509_NAME_ENTRY* entry = X509_NAME_get_entry(name, lastpos);
-          unsigned char* out;
-          int outlen = ASN1_STRING_to_UTF8(&out,
-                                           X509_NAME_ENTRY_get_data(entry));
-          if(outlen < 0) {
-            continue;
-          }
-          std::string commonName(&out[0], &out[outlen]);
-          OPENSSL_free(out);
-          if(commonName == hostname) {
-            hostnameOK = 1;
-            break;
-          }
+        X509_NAME_ENTRY* entry = X509_NAME_get_entry(subjectName, lastpos);
+        unsigned char* out;
+        int outlen = ASN1_STRING_to_UTF8(&out,
+                                         X509_NAME_ENTRY_get_data(entry));
+        if(outlen < 0) {
+          continue;
         }
+        commonName.assign(&out[0], &out[outlen]);
+        OPENSSL_free(out);
+        break;
       }
-      if(hostnameOK != 1) {
+      if(!net::verifyHostname(hostname, dnsNames, ipAddrs, commonName)) {
         throw DL_ABORT_EX(MSG_HOSTNAME_NOT_MATCH);
       }
     }
@@ -1311,6 +1301,43 @@ size_t getBinAddr(unsigned char* dest, const std::string& ip)
   return len;
 }
 
+bool verifyHostname(const std::string& hostname,
+                    const std::vector<std::string>& dnsNames,
+                    const std::vector<std::string>& ipAddrs,
+                    const std::string& commonName)
+{
+  if(util::isNumericHost(hostname)) {
+    // We need max 16 bytes to store IPv6 address.
+    unsigned char binAddr[16];
+    size_t addrLen = getBinAddr(binAddr, hostname);
+    if(addrLen == 0) {
+      return false;
+    }
+    if(ipAddrs.empty()) {
+      return addrLen == commonName.size() &&
+        memcmp(binAddr, commonName.c_str(), addrLen) == 0;
+    }
+    for(std::vector<std::string>::const_iterator i = ipAddrs.begin(),
+          eoi = ipAddrs.end(); i != eoi; ++i) {
+      if(addrLen == (*i).size() &&
+         memcmp(binAddr, (*i).c_str(), addrLen) == 0) {
+        return true;
+      }
+    }
+  } else {
+    if(dnsNames.empty()) {
+      return util::tlsHostnameMatch(commonName, hostname);
+    }
+    for(std::vector<std::string>::const_iterator i = dnsNames.begin(),
+          eoi = dnsNames.end(); i != eoi; ++i) {
+      if(util::tlsHostnameMatch(*i, hostname)) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
 } // namespace net
 
 } // namespace aria2

+ 7 - 0
src/SocketCore.h

@@ -391,6 +391,13 @@ namespace net {
 // IPv6. Return 0 if error occurred.
 size_t getBinAddr(unsigned char* dest, const std::string& ip);
 
+// Verifies hostname against presented identifiers in the certificate.
+// The implementation is based on the procedure described in RFC 6125.
+bool verifyHostname(const std::string& hostname,
+                    const std::vector<std::string>& dnsNames,
+                    const std::vector<std::string>& ipAddrs,
+                    const std::string& commonName);
+
 } // namespace net
 
 } // namespace aria2

+ 39 - 29
src/util.cc

@@ -1620,39 +1620,49 @@ bool noProxyDomainMatch
 
 bool tlsHostnameMatch(const std::string& pattern, const std::string& hostname)
 {
-  int wildcardpos;
-  {
-    std::string::size_type pos = pattern.find('*');
-    if(pos == std::string::npos) {
-      return pattern == hostname;
-    } else if(pos > hostname.size()) {
-      return false;
-    } else {
-      wildcardpos = pos;
-    }
-  }
-  int i, j;
-  for(i = 0; i < wildcardpos; ++i) {
-    if(pattern[i] != hostname[i]) {
-      return false;
-    }
+  // Do case-insensitive match. At least 2 dots are required to enable
+  // wildcard match.
+  std::string::const_iterator ptLeftLabelEnd = std::find(pattern.begin(),
+                                                         pattern.end(),
+                                                         '.');
+  bool wildcardEnabled = true;
+  if(ptLeftLabelEnd == pattern.end() ||
+     std::find(ptLeftLabelEnd+1, pattern.end(), '.') == pattern.end()) {
+    wildcardEnabled = false;
+  }
+  if(!wildcardEnabled) {
+    return strieq(pattern.begin(), pattern.end(),
+                  hostname.begin(), hostname.end());
+  }
+  std::string::const_iterator ptWildcard = std::find(pattern.begin(),
+                                                     ptLeftLabelEnd,
+                                                     '*');
+  if(ptWildcard == ptLeftLabelEnd) {
+    return strieq(pattern.begin(), pattern.end(),
+                  hostname.begin(), hostname.end());
+  }
+  std::string::const_iterator hnLeftLabelEnd = std::find(hostname.begin(),
+                                                         hostname.end(),
+                                                         '.');
+  if(!strieq(ptLeftLabelEnd, pattern.end(), hnLeftLabelEnd, hostname.end())) {
+    return false;
   }
-  for(i = static_cast<int>(pattern.size())-1,
-        j = static_cast<int>(hostname.size())-1;
-      i > wildcardpos && j >= wildcardpos; --i, --j) {
-    if(pattern[i] != hostname[j]) {
-      return false;
-    }
+  // Don't attempt to match a presented identifier where the wildcard
+  // character is embedded within an A-label.
+  if(istartsWith(pattern, "xn--")) {
+    return strieq(pattern.begin(), ptLeftLabelEnd,
+                  hostname.begin(), hnLeftLabelEnd);
   }
-  if(i != wildcardpos) {
+  // Perform wildcard match. Here '*' must match at least one
+  // character.
+  if(hnLeftLabelEnd - hostname.begin() < ptLeftLabelEnd - pattern.begin()) {
     return false;
   }
-  for(i = wildcardpos; i <= j; ++i) {
-    if(hostname[i] == '.') {
-      return false;
-    }
-  }
-  return true;
+  return
+    istartsWith(hostname.begin(), hnLeftLabelEnd,
+                pattern.begin(), ptWildcard) &&
+    iendsWith(hostname.begin(), hnLeftLabelEnd,
+              ptWildcard+1, ptLeftLabelEnd);
 }
 
 bool startsWith(const std::string& a, const char* b)

+ 1 - 16
src/util.h

@@ -852,22 +852,7 @@ SharedHandle<T> copy(const SharedHandle<T>& a)
 // * noProxyDomainMatch("sf.net", ".sf.net") returns false.
 bool noProxyDomainMatch(const std::string& hostname, const std::string& domain);
 
-// Checks hostname matches pattern as described in RFC 2818.
-//
-// Quoted from RFC 2818 section 3.1. Server Identity:
-//
-// Matching is performed using the matching rules specified by
-// [RFC2459].  If more than one identity of a given type is present in
-// the certificate (e.g., more than one dNSName name, a match in any
-// one of the set is considered acceptable.) Names may contain the
-// wildcard character * which is considered to match any single domain
-// name component or component fragment. E.g., *.a.com matches
-// foo.a.com but not bar.foo.a.com. f*.com matches foo.com but not
-// bar.com.
-//
-// If pattern contains multiple '*', this function considers left most
-// '*' as a wildcard character and other '*'s are considered just
-// character literals.
+// Checks hostname matches pattern as described in RFC 6125.
 bool tlsHostnameMatch(const std::string& pattern, const std::string& hostname);
 
 } // namespace util

+ 17 - 18
test/UtilTest.cc

@@ -1847,27 +1847,26 @@ void UtilTest::testSecfmt()
 
 void UtilTest::testTlsHostnameMatch()
 {
-  CPPUNIT_ASSERT(util::tlsHostnameMatch("foo.com", "foo.com"));
+  CPPUNIT_ASSERT(util::tlsHostnameMatch("Foo.com", "foo.com"));
   CPPUNIT_ASSERT(util::tlsHostnameMatch("*.a.com", "foo.a.com"));
   CPPUNIT_ASSERT(!util::tlsHostnameMatch("*.a.com", "bar.foo.a.com"));
-  CPPUNIT_ASSERT(util::tlsHostnameMatch("f*.com", "foo.com"));
-  CPPUNIT_ASSERT(!util::tlsHostnameMatch("f*.com", "bar.com"));
-  CPPUNIT_ASSERT(util::tlsHostnameMatch("foo.*", "foo.com"));
-  CPPUNIT_ASSERT(!util::tlsHostnameMatch("foo.*", "bar.com"));
-  CPPUNIT_ASSERT(util::tlsHostnameMatch("foo.*m", "foo.com"));
-  CPPUNIT_ASSERT(util::tlsHostnameMatch("foo.c*", "foo.com"));
-  CPPUNIT_ASSERT(util::tlsHostnameMatch("foo.com*", "foo.com"));
-  CPPUNIT_ASSERT(util::tlsHostnameMatch("*foo.com", "foo.com"));
-  CPPUNIT_ASSERT(util::tlsHostnameMatch("foo.b*z.com", "foo.baz.com"));
-  CPPUNIT_ASSERT(!util::tlsHostnameMatch("foo.b*z.com", "foo.bar.baz.com"));
-  CPPUNIT_ASSERT(util::tlsHostnameMatch("*", "foo"));
-  CPPUNIT_ASSERT(!util::tlsHostnameMatch("*", "foo.com"));
-  CPPUNIT_ASSERT(!util::tlsHostnameMatch("*.co*", "foo.com"));
-  CPPUNIT_ASSERT(!util::tlsHostnameMatch("fooo*.com", "foo.com"));
-  CPPUNIT_ASSERT(!util::tlsHostnameMatch("foo*foo.com", "foo.com"));
-  CPPUNIT_ASSERT(!util::tlsHostnameMatch("", "foo.com"));
-  CPPUNIT_ASSERT(util::tlsHostnameMatch("*", ""));
+  CPPUNIT_ASSERT(!util::tlsHostnameMatch("f*.com", "foo.com"));
+  CPPUNIT_ASSERT(!util::tlsHostnameMatch("*.com", "bar.com"));
+  CPPUNIT_ASSERT(util::tlsHostnameMatch("com", "com"));
+  CPPUNIT_ASSERT(!util::tlsHostnameMatch("foo.*", "foo.com"));
+  CPPUNIT_ASSERT(util::tlsHostnameMatch("a.foo.com", "A.foo.com"));
+  CPPUNIT_ASSERT(!util::tlsHostnameMatch("a.foo.com", "b.foo.com"));
+  CPPUNIT_ASSERT(!util::tlsHostnameMatch("*a.foo.com", "a.foo.com"));
+  CPPUNIT_ASSERT(util::tlsHostnameMatch("*a.foo.com", "ba.foo.com"));
+  CPPUNIT_ASSERT(!util::tlsHostnameMatch("a*.foo.com", "a.foo.com"));
+  CPPUNIT_ASSERT(util::tlsHostnameMatch("a*.foo.com", "ab.foo.com"));
+  CPPUNIT_ASSERT(!util::tlsHostnameMatch("foo.b*z.foo.com", "foo.baz.foo.com"));
+  CPPUNIT_ASSERT(util::tlsHostnameMatch("B*z.foo.com", "bAZ.Foo.com"));
+  CPPUNIT_ASSERT(!util::tlsHostnameMatch("b*z.foo.com", "bz.foo.com"));
+  CPPUNIT_ASSERT(!util::tlsHostnameMatch("*", "foo"));
+  CPPUNIT_ASSERT(!util::tlsHostnameMatch("*", ""));
   CPPUNIT_ASSERT(util::tlsHostnameMatch("", ""));
+  CPPUNIT_ASSERT(!util::tlsHostnameMatch("xn--*.a.b", "xn--c.a.b"));
 }
 
 } // namespace aria2