Переглянути джерело

Define SockAddr and use it instead of raw std::pair

Tatsuhiro Tsujikawa 9 роки тому
батько
коміт
e899eba76f
4 змінених файлів з 45 додано та 52 видалено
  1. 4 4
      src/BtSetup.cc
  2. 25 31
      src/SocketCore.cc
  3. 11 17
      src/SocketCore.h
  4. 5 0
      src/a2netcompat.h

+ 4 - 4
src/BtSetup.cc

@@ -228,11 +228,11 @@ void BtSetup::setup(std::vector<std::unique_ptr<Command>>& commands,
         }
       }
       else {
-        std::vector<std::pair<sockaddr_union, socklen_t>> ifAddrs;
-        getInterfaceAddress(ifAddrs, lpdInterface, AF_INET, AI_NUMERICHOST);
-        for (const auto& i : ifAddrs) {
+        auto ifAddrs = SocketCore::getInterfaceAddress(lpdInterface, AF_INET,
+                                                       AI_NUMERICHOST);
+        for (const auto& soaddr : ifAddrs) {
           char host[NI_MAXHOST];
-          if (inetNtop(AF_INET, &i.first.in.sin_addr, host, sizeof(host)) ==
+          if (inetNtop(AF_INET, &soaddr.su.in.sin_addr, host, sizeof(host)) ==
                   0 &&
               receiver->init(host)) {
             initialized = true;

+ 25 - 31
src/SocketCore.cc

@@ -130,11 +130,9 @@ enum TlsState {
 int SocketCore::protocolFamily_ = AF_UNSPEC;
 int SocketCore::ipDscp_ = 0;
 
-std::vector<std::pair<sockaddr_union, socklen_t>> SocketCore::bindAddrs_;
-std::vector<std::vector<std::pair<sockaddr_union, socklen_t>>>
-    SocketCore::bindAddrsList_;
-std::vector<std::vector<std::pair<sockaddr_union, socklen_t>>>::iterator
-    SocketCore::bindAddrsListIt_;
+std::vector<SockAddr> SocketCore::bindAddrs_;
+std::vector<std::vector<SockAddr>> SocketCore::bindAddrsList_;
+std::vector<std::vector<SockAddr>>::iterator SocketCore::bindAddrsListIt_;
 
 int SocketCore::socketRecvBufferSize_ = 0;
 
@@ -318,10 +316,10 @@ void SocketCore::bind(const char* addr, uint16_t port, int family, int flags)
   std::array<char, NI_MAXHOST> host;
   for (const auto& bindAddrs : bindAddrsList_) {
     for (const auto& a : bindAddrs) {
-      if (family != AF_UNSPEC && family != a.first.storage.ss_family) {
+      if (family != AF_UNSPEC && family != a.su.storage.ss_family) {
         continue;
       }
-      auto s = getnameinfo(&a.first.sa, a.second, host.data(), NI_MAXHOST,
+      auto s = getnameinfo(&a.su.sa, a.suLength, host.data(), NI_MAXHOST,
                            nullptr, 0, NI_NUMERICHOST);
       if (s) {
         error = gai_strerror(s);
@@ -462,11 +460,8 @@ void SocketCore::establishConnection(const std::string& host, uint16_t port,
 
     if (!bindAddrs_.empty()) {
       bool bindSuccess = false;
-      for (std::vector<std::pair<sockaddr_union, socklen_t>>::const_iterator
-               i = bindAddrs_.begin(),
-               eoi = bindAddrs_.end();
-           i != eoi; ++i) {
-        if (::bind(fd, &(*i).first.sa, (*i).second) == -1) {
+      for (const auto& soaddr : bindAddrs_) {
+        if (::bind(fd, &soaddr.su.sa, soaddr.suLength) == -1) {
           errNum = SOCKET_ERRNO;
           error = errorMsg(errNum);
           A2_LOG_DEBUG(fmt(EX_SOCKET_BIND, error.c_str()));
@@ -1276,8 +1271,7 @@ bool SocketCore::wantWrite() const { return wantWrite_; }
 
 void SocketCore::bindAddress(const std::string& iface)
 {
-  std::vector<std::pair<sockaddr_union, socklen_t>> bindAddrs;
-  getInterfaceAddress(bindAddrs, iface, protocolFamily_);
+  auto bindAddrs = getInterfaceAddress(iface, protocolFamily_);
   if (bindAddrs.empty()) {
     throw DL_ABORT_EX(
         fmt(MSG_INTERFACE_NOT_FOUND, iface.c_str(), "not available"));
@@ -1286,7 +1280,7 @@ void SocketCore::bindAddress(const std::string& iface)
   for (const auto& a : bindAddrs_) {
     char host[NI_MAXHOST];
     int s;
-    s = getnameinfo(&a.first.sa, a.second, host, NI_MAXHOST, nullptr, 0,
+    s = getnameinfo(&a.su.sa, a.suLength, host, NI_MAXHOST, nullptr, 0,
                     NI_NUMERICHOST);
     if (s == 0) {
       A2_LOG_DEBUG(fmt("Sockets will bind to %s", host));
@@ -1298,7 +1292,7 @@ void SocketCore::bindAddress(const std::string& iface)
 
 void SocketCore::bindAllAddress(const std::string& ifaces)
 {
-  std::vector<std::vector<std::pair<sockaddr_union, socklen_t>>> bindAddrsList;
+  std::vector<std::vector<SockAddr>> bindAddrsList;
   std::vector<std::string> ifaceList;
   util::split(ifaces.begin(), ifaces.end(), std::back_inserter(ifaceList), ',',
               true);
@@ -1307,8 +1301,7 @@ void SocketCore::bindAllAddress(const std::string& ifaces)
         "List of interfaces is empty, one or more interfaces is required");
   }
   for (auto& iface : ifaceList) {
-    std::vector<std::pair<sockaddr_union, socklen_t>> bindAddrs;
-    getInterfaceAddress(bindAddrs, iface, protocolFamily_);
+    auto bindAddrs = getInterfaceAddress(iface, protocolFamily_);
     if (bindAddrs.empty()) {
       throw DL_ABORT_EX(
           fmt(MSG_INTERFACE_NOT_FOUND, iface.c_str(), "not available"));
@@ -1317,7 +1310,7 @@ void SocketCore::bindAllAddress(const std::string& ifaces)
     for (const auto& a : bindAddrs) {
       char host[NI_MAXHOST];
       int s;
-      s = getnameinfo(&a.first.sa, a.second, host, NI_MAXHOST, nullptr, 0,
+      s = getnameinfo(&a.su.sa, a.suLength, host, NI_MAXHOST, nullptr, 0,
                       NI_NUMERICHOST);
       if (s == 0) {
         A2_LOG_DEBUG(fmt("Sockets will bind to %s", host));
@@ -1336,11 +1329,11 @@ void SocketCore::setSocketRecvBufferSize(int size)
 
 int SocketCore::getSocketRecvBufferSize() { return socketRecvBufferSize_; }
 
-void getInterfaceAddress(
-    std::vector<std::pair<sockaddr_union, socklen_t>>& ifAddrs,
-    const std::string& iface, int family, int aiFlags)
+std::vector<SockAddr> SocketCore::getInterfaceAddress(const std::string& iface,
+                                                      int family, int aiFlags)
 {
   A2_LOG_DEBUG(fmt("Finding interface %s", iface.c_str()));
+  std::vector<SockAddr> ifAddrs;
 #ifdef HAVE_GETIFADDRS
   // First find interface in interface addresses
   struct ifaddrs* ifaddr = nullptr;
@@ -1376,12 +1369,11 @@ void getInterfaceAddress(
         continue;
       }
       if (strcmp(iface.c_str(), ifa->ifa_name) == 0) {
-        socklen_t bindAddrLen =
+        SockAddr soaddr;
+        soaddr.suLength =
             iffamily == AF_INET ? sizeof(sockaddr_in) : sizeof(sockaddr_in6);
-        sockaddr_union bindAddr;
-        memset(&bindAddr, 0, sizeof(bindAddr));
-        memcpy(&bindAddr.storage, ifa->ifa_addr, bindAddrLen);
-        ifAddrs.push_back(std::make_pair(bindAddr, bindAddrLen));
+        memcpy(&soaddr.su, ifa->ifa_addr, soaddr.suLength);
+        ifAddrs.push_back(soaddr);
       }
     }
   }
@@ -1404,10 +1396,10 @@ void getInterfaceAddress(
         try {
           SocketCore socket;
           socket.bind(rp->ai_addr, rp->ai_addrlen);
-          sockaddr_union bindAddr;
-          memset(&bindAddr, 0, sizeof(bindAddr));
-          memcpy(&bindAddr.storage, rp->ai_addr, rp->ai_addrlen);
-          ifAddrs.push_back(std::make_pair(bindAddr, rp->ai_addrlen));
+          SockAddr soaddr;
+          memcpy(&soaddr.su, rp->ai_addr, rp->ai_addrlen);
+          soaddr.suLength = rp->ai_addrlen;
+          ifAddrs.push_back(soaddr);
         }
         catch (RecoverableException& e) {
           continue;
@@ -1415,6 +1407,8 @@ void getInterfaceAddress(
       }
     }
   }
+
+  return ifAddrs;
 }
 
 namespace {

+ 11 - 17
src/SocketCore.h

@@ -73,11 +73,9 @@ private:
   static int protocolFamily_;
   static int ipDscp_;
 
-  static std::vector<std::pair<sockaddr_union, socklen_t>> bindAddrs_;
-  static std::vector<std::vector<std::pair<sockaddr_union, socklen_t>>>
-      bindAddrsList_;
-  static std::vector<std::vector<std::pair<sockaddr_union, socklen_t>>>::
-      iterator bindAddrsListIt_;
+  static std::vector<SockAddr> bindAddrs_;
+  static std::vector<std::vector<SockAddr>> bindAddrsList_;
+  static std::vector<std::vector<SockAddr>>::iterator bindAddrsListIt_;
 
   static int socketRecvBufferSize_;
 
@@ -372,9 +370,14 @@ public:
   static void bindAddress(const std::string& iface);
   static void bindAllAddress(const std::string& ifaces);
 
-  friend void getInterfaceAddress(
-      std::vector<std::pair<sockaddr_union, socklen_t>>& ifAddrs,
-      const std::string& iface, int family, int aiFlags);
+  // Collects IP addresses of given interface iface and stores in
+  // ifAddres. iface may be specified as a hostname, IP address or
+  // interface name like eth0. You can limit the family of IP
+  // addresses to collect using family argument. aiFlags is passed to
+  // getaddrinfo() as hints.ai_flags. No throw.
+  static std::vector<SockAddr> getInterfaceAddress(const std::string& iface,
+                                                   int family = AF_UNSPEC,
+                                                   int aiFlags = 0);
 };
 
 // Set default ai_flags. hints.ai_flags is initialized with this
@@ -389,15 +392,6 @@ int callGetaddrinfo(struct addrinfo** resPtr, const char* host,
                     const char* service, int family, int sockType, int flags,
                     int protocol);
 
-// Collects IP addresses of given interface iface and stores in
-// ifAddres. iface may be specified as a hostname, IP address or
-// interface name like eth0. You can limit the family of IP addresses
-// to collect using family argument. aiFlags is passed to
-// getaddrinfo() as hints.ai_flags. No throw.
-void getInterfaceAddress(
-    std::vector<std::pair<sockaddr_union, socklen_t>>& ifAddrs,
-    const std::string& iface, int family = AF_UNSPEC, int aiFlags = 0);
-
 // Provides functionality of inet_ntop using getnameinfo.  The return
 // value is the exact value of getnameinfo returns. You can get error
 // message using gai_strerror(3).

+ 5 - 0
src/a2netcompat.h

@@ -115,6 +115,11 @@ union sockaddr_union {
   sockaddr_in in;
 };
 
+struct SockAddr {
+  sockaddr_union su;
+  socklen_t suLength;
+};
+
 #define A2_DEFAULT_IOV_MAX 128
 
 #if defined(IOV_MAX) && IOV_MAX < A2_DEFAULT_IOV_MAX