Browse Source

Rewritten util::inSameCidrBlock() to support IPv6 address.

We also introduced union sockaddr_union in a2netcompat.h.
Tatsuhiro Tsujikawa 14 years ago
parent
commit
74e5aa0ace
7 changed files with 101 additions and 57 deletions
  1. 29 0
      src/SocketCore.cc
  2. 12 0
      src/SocketCore.h
  3. 7 0
      src/a2netcompat.h
  4. 17 26
      src/util.cc
  5. 7 6
      src/util.h
  6. 19 0
      test/SocketCoreTest.cc
  7. 10 25
      test/UtilTest.cc

+ 29 - 0
src/SocketCore.cc

@@ -1260,4 +1260,33 @@ int inetNtop(int af, const void* src, char* dst, socklen_t size)
   return s;
 }
 
+namespace net {
+
+size_t getBinAddr(unsigned char* dest, const std::string& ip)
+{
+  size_t len = 0;
+  addrinfo* res;
+  if(callGetaddrinfo(&res, ip.c_str(), 0, AF_UNSPEC,
+                     0, AI_NUMERICHOST, 0) != 0) {
+    return len;
+  }
+  WSAAPI_AUTO_DELETE<addrinfo*> resDeleter(res, freeaddrinfo);
+  for(addrinfo* rp = res; rp; rp = rp->ai_next) {
+    if(rp->ai_family == AF_INET) {
+      sockaddr_in* addr = &reinterpret_cast<sockaddr_union*>(rp->ai_addr)->in;
+      len = 4;
+      memcpy(dest, &(addr->sin_addr), len);
+      return len;
+    } else if(rp->ai_family == AF_INET6) {
+      sockaddr_in6* addr = &reinterpret_cast<sockaddr_union*>(rp->ai_addr)->in6;
+      len = 16;
+      memcpy(dest, &(addr->sin6_addr), len);
+      return len;
+    }
+  }
+  return len;
+}
+
+} // namespace net
+
 } // namespace aria2

+ 12 - 0
src/SocketCore.h

@@ -374,6 +374,18 @@ void getInterfaceAddress
 // message using gai_strerror(3).
 int inetNtop(int af, const void* src, char* dst, socklen_t size);
 
+namespace net {
+
+// Stores binary representation of IP address ip which is represented
+// in text.  ip must be numeric IPv4 or IPv6 address. dest must be
+// allocated by caller before the call. For IPv4 address, dest must be
+// at least 4. For IPv6 address, dest must be at least 16. Returns the
+// number of bytes written in dest, that is 4 for IPv4 and 16 for
+// IPv6. Return 0 if error occurred.
+size_t getBinAddr(unsigned char* dest, const std::string& ip);
+
+} // namespace net
+
 } // namespace aria2
 
 #endif // D_SOCKET_CORE_H

+ 7 - 0
src/a2netcompat.h

@@ -132,4 +132,11 @@ public:
 # define WSAAPI_AUTO_DELETE auto_delete
 #endif // !__MINGW32__
 
+union sockaddr_union {
+  sockaddr sa;
+  sockaddr_storage storage;
+  sockaddr_in6 in6;
+  sockaddr_in in;
+};
+
 #endif // D_A2NETCOMPAT_H

+ 17 - 26
src/util.cc

@@ -70,6 +70,7 @@
 #include "Option.h"
 #include "DownloadContext.h"
 #include "BufferedFile.h"
+#include "SocketCore.h"
 
 #ifdef ENABLE_MESSAGE_DIGEST
 # include "MessageDigest.h"
@@ -1577,37 +1578,27 @@ std::string escapePath(const std::string& s)
   return d;
 }
 
-bool getCidrPrefix(struct in_addr& in, const std::string& ip, int bits)
+bool inSameCidrBlock
+(const std::string& ip1, const std::string& ip2, size_t bits)
 {
-  struct in_addr t;
-  if(inet_aton(ip.c_str(), &t) == 0) {
+  unsigned char s1[16], s2[16];
+  size_t len1, len2;
+  if((len1 = net::getBinAddr(s1, ip1)) == 0 ||
+     (len2 = net::getBinAddr(s2, ip2)) == 0 ||
+     len1 != len2) {
     return false;
   }
-  int lastindex = bits/8;
-  if(lastindex < 4) {
-    char* p = reinterpret_cast<char*>(&t.s_addr);
-    const char* last = p+4;
-    p += lastindex;    
-    if(bits%8 != 0) {
-      *p &= bitfield::lastByteMask(bits);
-      ++p;
-    }
-    for(; p != last; ++p) {
-      *p &= 0;
-    }
+  if(bits > 8*len1) {
+    bits = 8*len1;
   }
-  in = t;
-  return true;
-}
-
-bool inSameCidrBlock(const std::string& ip1, const std::string& ip2, int bits)
-{
-  struct in_addr in1;
-  struct in_addr in2;
-  if(!getCidrPrefix(in1, ip1, bits) || !getCidrPrefix(in2, ip2, bits)) {
-    return false;
+  int last = (bits-1)/8;
+  for(int i = 0; i < last; ++i) {
+    if(s1[i] != s2[i]) {
+      return false;
+    }
   }
-  return in1.s_addr == in2.s_addr;
+  unsigned char mask = bitfield::lastByteMask(bits);
+  return (s1[last] & mask) == (s2[last] & mask);
 }
 
 void removeMetalinkContentTypes(const SharedHandle<RequestGroup>& group)

+ 7 - 6
src/util.h

@@ -427,12 +427,13 @@ bool detectDirTraversal(const std::string& s);
 // '_': '"', '*', ':', '<', '>', '?', '\', '|'.
 std::string escapePath(const std::string& s);
 
-// Stores network address of numeric IPv4 address ip using CIDR bits
-// into in.  On success, returns true. Otherwise returns false.
-bool getCidrPrefix(struct in_addr& in, const std::string& ip, int bits);
-
-// Returns true if ip1 and ip2 are in the same CIDR block.
-bool inSameCidrBlock(const std::string& ip1, const std::string& ip2, int bits);
+// Returns true if ip1 and ip2 are in the same CIDR block.  ip1 and
+// ip2 must be numeric IPv4 or IPv6 address. If either of them or both
+// of them is not valid numeric address, then returns false. bits is
+// prefix bits. If bits is out of range, then bits is set to the
+// length of binary representation of the address*8.
+bool inSameCidrBlock
+(const std::string& ip1, const std::string& ip2, size_t bits);
 
 void removeMetalinkContentTypes(const SharedHandle<RequestGroup>& group);
 void removeMetalinkContentTypes(RequestGroup* group);

+ 19 - 0
test/SocketCoreTest.cc

@@ -15,6 +15,7 @@ class SocketCoreTest:public CppUnit::TestFixture {
   CPPUNIT_TEST(testWriteAndReadDatagram);
   CPPUNIT_TEST(testGetSocketError);
   CPPUNIT_TEST(testInetNtop);
+  CPPUNIT_TEST(testGetBinAddr);
   CPPUNIT_TEST_SUITE_END();
 public:
   void setUp() {}
@@ -24,6 +25,7 @@ public:
   void testWriteAndReadDatagram();
   void testGetSocketError();
   void testInetNtop();
+  void testGetBinAddr();
 };
 
 
@@ -104,4 +106,21 @@ void SocketCoreTest::testInetNtop()
   }
 }
 
+void SocketCoreTest::testGetBinAddr()
+{
+  unsigned char dest[16];
+  unsigned char ans1[] = { 192, 168, 0, 1 };
+  CPPUNIT_ASSERT_EQUAL((size_t)4, net::getBinAddr(dest, "192.168.0.1"));
+  CPPUNIT_ASSERT(std::equal(&dest[0], &dest[4], &ans1[0]));
+
+  unsigned char ans2[] = { 0x20u, 0x01u, 0x0du, 0xb8u,
+                           0x00u, 0x00u, 0x00u, 0x00u,
+                           0x00u, 0x00u, 0x00u, 0x00u,
+                           0x00u, 0x02u, 0x00u, 0x01u };
+  CPPUNIT_ASSERT_EQUAL((size_t)16, net::getBinAddr(dest, "2001:db8::2:1"));
+  CPPUNIT_ASSERT(std::equal(&dest[0], &dest[16], &ans2[0]));
+
+  CPPUNIT_ASSERT_EQUAL((size_t)0, net::getBinAddr(dest, "localhost"));
+}
+
 } // namespace aria2

+ 10 - 25
test/UtilTest.cc

@@ -15,6 +15,7 @@
 #include "array_fun.h"
 #include "BufferedFile.h"
 #include "TestUtil.h"
+#include "SocketCore.h"
 
 namespace aria2 {
 
@@ -70,7 +71,6 @@ class UtilTest:public CppUnit::TestFixture {
   CPPUNIT_TEST(testIsNumericHost);
   CPPUNIT_TEST(testDetectDirTraversal);
   CPPUNIT_TEST(testEscapePath);
-  CPPUNIT_TEST(testGetCidrPrefix);
   CPPUNIT_TEST(testInSameCidrBlock);
   CPPUNIT_TEST(testIsUtf8String);
   CPPUNIT_TEST(testNextParam);
@@ -130,7 +130,6 @@ public:
   void testIsNumericHost();
   void testDetectDirTraversal();
   void testEscapePath();
-  void testGetCidrPrefix();
   void testInSameCidrBlock();
   void testIsUtf8String();
   void testNextParam();
@@ -1192,34 +1191,20 @@ void UtilTest::testEscapePath()
 #endif // !__MINGW32__
 }
 
-void UtilTest::testGetCidrPrefix()
+void UtilTest::testInSameCidrBlock()
 {
-  struct in_addr in;
-  CPPUNIT_ASSERT(util::getCidrPrefix(in, "192.168.0.1", 16));
-  CPPUNIT_ASSERT_EQUAL(std::string("192.168.0.0"), std::string(inet_ntoa(in)));
-
-  CPPUNIT_ASSERT(util::getCidrPrefix(in, "192.168.255.255", 17));
-  CPPUNIT_ASSERT_EQUAL(std::string("192.168.128.0"),std::string(inet_ntoa(in)));
-
-  CPPUNIT_ASSERT(util::getCidrPrefix(in, "192.168.128.1", 16));
-  CPPUNIT_ASSERT_EQUAL(std::string("192.168.0.0"), std::string(inet_ntoa(in)));
+  CPPUNIT_ASSERT(util::inSameCidrBlock("192.168.128.1", "192.168.0.1", 16));
+  CPPUNIT_ASSERT(!util::inSameCidrBlock("192.168.128.1", "192.168.0.1", 17));
 
-  CPPUNIT_ASSERT(util::getCidrPrefix(in, "192.168.0.1", 32));
-  CPPUNIT_ASSERT_EQUAL(std::string("192.168.0.1"), std::string(inet_ntoa(in)));
+  CPPUNIT_ASSERT(util::inSameCidrBlock("192.168.0.1", "192.168.0.1", 32));
+  CPPUNIT_ASSERT(!util::inSameCidrBlock("192.168.0.1", "192.168.0.0", 32));
 
-  CPPUNIT_ASSERT(util::getCidrPrefix(in, "192.168.0.1", 0));
-  CPPUNIT_ASSERT_EQUAL(std::string("0.0.0.0"), std::string(inet_ntoa(in)));
+  CPPUNIT_ASSERT(util::inSameCidrBlock("192.168.0.1", "10.0.0.1", 0));
 
-  CPPUNIT_ASSERT(util::getCidrPrefix(in, "10.10.1.44", 27));
-  CPPUNIT_ASSERT_EQUAL(std::string("10.10.1.32"), std::string(inet_ntoa(in)));
+  CPPUNIT_ASSERT(util::inSameCidrBlock("2001:db8::2:1", "2001:db0::2:2", 28));
+  CPPUNIT_ASSERT(!util::inSameCidrBlock("2001:db8::2:1", "2001:db0::2:2", 29));
 
-  CPPUNIT_ASSERT(!util::getCidrPrefix(in, "::1", 32));
-}
-
-void UtilTest::testInSameCidrBlock()
-{
-  CPPUNIT_ASSERT(util::inSameCidrBlock("192.168.128.1", "192.168.0.1", 16));
-  CPPUNIT_ASSERT(!util::inSameCidrBlock("192.168.128.1", "192.168.0.1", 17));
+  CPPUNIT_ASSERT(!util::inSameCidrBlock("2001:db8::2:1", "192.168.0.1", 8));
 }
 
 void UtilTest::testIsUtf8String()