瀏覽代碼

Refactor SocketCore::getPeerInfo, getAddrInfo to return Endpoint object

Tatsuhiro Tsujikawa 9 年之前
父節點
當前提交
ea4d99ea08

+ 8 - 10
src/DHTConnectionImpl.cc

@@ -77,9 +77,8 @@ bool DHTConnectionImpl::bind(uint16_t& port, const std::string& addr)
   try {
     socket_->bind(addr.c_str(), port, family_);
     socket_->setNonBlockingMode();
-    std::pair<std::string, uint16_t> svaddr;
-    socket_->getAddrInfo(svaddr);
-    port = svaddr.second;
+    auto endpoint = socket_->getAddrInfo();
+    port = endpoint.port;
     A2_LOG_NOTICE(fmt(_("IPv%d DHT: listening on UDP port %u"), ipv, port));
     return true;
   }
@@ -92,16 +91,15 @@ bool DHTConnectionImpl::bind(uint16_t& port, const std::string& addr)
 ssize_t DHTConnectionImpl::receiveMessage(unsigned char* data, size_t len,
                                           std::string& host, uint16_t& port)
 {
-  std::pair<std::string, uint16_t> remoteHost;
-  ssize_t length = socket_->readDataFrom(data, len, remoteHost);
+  Endpoint remoteEndpoint;
+  ssize_t length = socket_->readDataFrom(data, len, remoteEndpoint);
   if (length == 0) {
     return length;
   }
-  else {
-    host = remoteHost.first;
-    port = remoteHost.second;
-    return length;
-  }
+
+  host = remoteEndpoint.addr;
+  port = remoteEndpoint.port;
+  return length;
 }
 
 ssize_t DHTConnectionImpl::sendMessage(const unsigned char* data, size_t len,

+ 6 - 7
src/DownloadEngine.cc

@@ -371,11 +371,10 @@ void DownloadEngine::poolSocket(const std::string& ipaddr, uint16_t port,
 }
 
 namespace {
-bool getPeerInfo(std::pair<std::string, uint16_t>& res,
-                 const std::shared_ptr<SocketCore>& socket)
+bool getPeerInfo(Endpoint& res, const std::shared_ptr<SocketCore>& socket)
 {
   try {
-    socket->getPeerInfo(res);
+    res = socket->getPeerInfo();
     return true;
   }
   catch (RecoverableException& e) {
@@ -399,9 +398,9 @@ void DownloadEngine::poolSocket(const std::shared_ptr<Request>& request,
     return;
   }
 
-  std::pair<std::string, uint16_t> peerInfo;
+  Endpoint peerInfo;
   if (getPeerInfo(peerInfo, socket)) {
-    poolSocket(peerInfo.first, peerInfo.second, A2STR::NIL, 0, socket,
+    poolSocket(peerInfo.addr, peerInfo.port, A2STR::NIL, 0, socket,
                std::move(timeout));
   }
 }
@@ -421,9 +420,9 @@ void DownloadEngine::poolSocket(const std::shared_ptr<Request>& request,
     return;
   }
 
-  std::pair<std::string, uint16_t> peerInfo;
+  Endpoint peerInfo;
   if (getPeerInfo(peerInfo, socket)) {
-    poolSocket(peerInfo.first, peerInfo.second, username, A2STR::NIL, 0, socket,
+    poolSocket(peerInfo.addr, peerInfo.port, username, A2STR::NIL, 0, socket,
                options, std::move(timeout));
   }
 }

+ 11 - 17
src/FtpConnection.cc

@@ -193,10 +193,9 @@ bool FtpConnection::sendPasv()
 
 std::shared_ptr<SocketCore> FtpConnection::createServerSocket()
 {
-  std::pair<std::string, uint16_t> addrinfo;
-  socket_->getAddrInfo(addrinfo);
+  auto endpoint = socket_->getAddrInfo();
   auto serverSocket = std::make_shared<SocketCore>();
-  serverSocket->bind(addrinfo.first.c_str(), 0, AF_UNSPEC);
+  serverSocket->bind(endpoint.addr.c_str(), 0, AF_UNSPEC);
   serverSocket->beginListen();
   return serverSocket;
 }
@@ -204,14 +203,10 @@ std::shared_ptr<SocketCore> FtpConnection::createServerSocket()
 bool FtpConnection::sendEprt(const std::shared_ptr<SocketCore>& serverSocket)
 {
   if (socketBuffer_.sendBufferIsEmpty()) {
-    sockaddr_union sockaddr;
-    socklen_t len = sizeof(sockaddr);
-    serverSocket->getAddrInfo(sockaddr, len);
-    std::pair<std::string, uint16_t> addrinfo =
-        util::getNumericNameInfo(&sockaddr.sa, len);
-    std::string request = fmt("EPRT |%d|%s|%u|\r\n",
-                              sockaddr.storage.ss_family == AF_INET ? 1 : 2,
-                              addrinfo.first.c_str(), addrinfo.second);
+    auto endpoint = serverSocket->getAddrInfo();
+    auto request =
+        fmt("EPRT |%d|%s|%u|\r\n", endpoint.family == AF_INET ? 1 : 2,
+            endpoint.addr.c_str(), endpoint.port);
     A2_LOG_INFO(fmt(MSG_SENDING_REQUEST, cuid_, request.c_str()));
     socketBuffer_.pushStr(std::move(request));
   }
@@ -222,15 +217,14 @@ bool FtpConnection::sendEprt(const std::shared_ptr<SocketCore>& serverSocket)
 bool FtpConnection::sendPort(const std::shared_ptr<SocketCore>& serverSocket)
 {
   if (socketBuffer_.sendBufferIsEmpty()) {
-    std::pair<std::string, uint16_t> addrinfo;
-    socket_->getAddrInfo(addrinfo);
+    auto endpoint = socket_->getAddrInfo();
     int ipaddr[4];
-    sscanf(addrinfo.first.c_str(), "%d.%d.%d.%d", &ipaddr[0], &ipaddr[1],
+    sscanf(endpoint.addr.c_str(), "%d.%d.%d.%d", &ipaddr[0], &ipaddr[1],
            &ipaddr[2], &ipaddr[3]);
-    serverSocket->getAddrInfo(addrinfo);
-    std::string request =
+    auto svEndpoint = serverSocket->getAddrInfo();
+    auto request =
         fmt("PORT %d,%d,%d,%d,%d,%d\r\n", ipaddr[0], ipaddr[1], ipaddr[2],
-            ipaddr[3], addrinfo.second / 256, addrinfo.second % 256);
+            ipaddr[3], svEndpoint.port / 256, svEndpoint.port % 256);
     A2_LOG_INFO(fmt(MSG_SENDING_REQUEST, cuid_, request.c_str()));
     socketBuffer_.pushStr(std::move(request));
   }

+ 3 - 4
src/FtpNegotiationCommand.cc

@@ -708,13 +708,12 @@ bool FtpNegotiationCommand::preparePasvConnect()
     return true;
   }
   else {
-    std::pair<std::string, uint16_t> dataAddr;
-    getSocket()->getPeerInfo(dataAddr);
+    auto endpoint = getSocket()->getPeerInfo();
     // make a data connection to the server.
-    A2_LOG_INFO(fmt(MSG_CONNECTING_TO_SERVER, getCuid(), dataAddr.first.c_str(),
+    A2_LOG_INFO(fmt(MSG_CONNECTING_TO_SERVER, getCuid(), endpoint.addr.c_str(),
                     pasvPort_));
     dataSocket_ = std::make_shared<SocketCore>();
-    dataSocket_->establishConnection(dataAddr.first, pasvPort_, false);
+    dataSocket_->establishConnection(endpoint.addr, pasvPort_, false);
     disableReadCheckSocket();
     setWriteCheckSocket(dataSocket_);
     sequence_ = SEQ_SEND_REST_PASV;

+ 2 - 3
src/HttpListenCommand.cc

@@ -72,11 +72,10 @@ bool HttpListenCommand::execute()
     if (serverSocket_->isReadable(0)) {
       std::shared_ptr<SocketCore> socket(serverSocket_->acceptConnection());
       socket->setTcpNodelay(true);
-      std::pair<std::string, uint16_t> peerInfo;
-      socket->getPeerInfo(peerInfo);
+      auto endpoint = socket->getPeerInfo();
 
       A2_LOG_INFO(fmt("RPC: Accepted the connection from %s:%u.",
-                      peerInfo.first.c_str(), peerInfo.second));
+                      endpoint.addr.c_str(), endpoint.port));
 
       e_->setNoWait(true);
       e_->addCommand(

+ 2 - 3
src/InitiateConnectionCommand.cc

@@ -124,9 +124,8 @@ void InitiateConnectionCommand::setConnectedAddrInfo(
     const std::shared_ptr<Request>& req, const std::string& hostname,
     const std::shared_ptr<SocketCore>& socket)
 {
-  std::pair<std::string, uint16_t> peerAddr;
-  socket->getPeerInfo(peerAddr);
-  req->setConnectedAddrInfo(hostname, peerAddr.first, peerAddr.second);
+  auto endpoint = socket->getPeerInfo();
+  req->setConnectedAddrInfo(hostname, endpoint.addr, endpoint.port);
 }
 
 std::shared_ptr<BackupConnectInfo>

+ 5 - 5
src/LpdMessageReceiver.cc

@@ -82,10 +82,10 @@ std::unique_ptr<LpdMessage> LpdMessageReceiver::receiveMessage()
 {
   while (1) {
     unsigned char buf[200];
-    std::pair<std::string, uint16_t> peerAddr;
+    Endpoint remoteEndpoint;
     ssize_t length;
     try {
-      length = socket_->readDataFrom(buf, sizeof(buf), peerAddr);
+      length = socket_->readDataFrom(buf, sizeof(buf), remoteEndpoint);
       if (length == 0) {
         return nullptr;
       }
@@ -114,7 +114,7 @@ std::unique_ptr<LpdMessage> LpdMessageReceiver::receiveMessage()
       continue;
     }
     A2_LOG_INFO(fmt("LPD message received infohash=%s, port=%u from %s",
-                    infoHashString.c_str(), port, peerAddr.first.c_str()));
+                    infoHashString.c_str(), port, remoteEndpoint.addr.c_str()));
     std::string infoHash;
     if (infoHashString.size() != 40 ||
         (infoHash = util::fromHex(infoHashString.begin(), infoHashString.end()))
@@ -122,8 +122,8 @@ std::unique_ptr<LpdMessage> LpdMessageReceiver::receiveMessage()
       A2_LOG_INFO(fmt("LPD bad request. infohash=%s", infoHashString.c_str()));
       continue;
     }
-    auto peer = std::make_shared<Peer>(peerAddr.first, port, false);
-    if (util::inPrivateAddress(peerAddr.first)) {
+    auto peer = std::make_shared<Peer>(remoteEndpoint.addr, port, false);
+    if (util::inPrivateAddress(remoteEndpoint.addr)) {
       peer->setLocalPeer(true);
     }
     return make_unique<LpdMessage>(peer, infoHash);

+ 2 - 3
src/NameResolver.cc

@@ -63,9 +63,8 @@ void NameResolver::resolve(std::vector<std::string>& resolvedAddresses,
                                                                 freeaddrinfo);
   struct addrinfo* rp;
   for (rp = res; rp; rp = rp->ai_next) {
-    std::pair<std::string, uint16_t> addressPort =
-        util::getNumericNameInfo(rp->ai_addr, rp->ai_addrlen);
-    resolvedAddresses.push_back(addressPort.first);
+    auto endpoint = util::getNumericNameInfo(rp->ai_addr, rp->ai_addrlen);
+    resolvedAddresses.push_back(endpoint.addr);
   }
 }
 

+ 4 - 8
src/PeerListenCommand.cc

@@ -93,11 +93,8 @@ uint16_t PeerListenCommand::getPort() const
   if (!socket_) {
     return 0;
   }
-  else {
-    std::pair<std::string, uint16_t> addr;
-    socket_->getAddrInfo(addr);
-    return addr.second;
-  }
+
+  return socket_->getAddrInfo().port;
 }
 
 bool PeerListenCommand::execute()
@@ -110,10 +107,9 @@ bool PeerListenCommand::execute()
     try {
       peerSocket = socket_->acceptConnection();
       peerSocket->applyIpDscp();
-      std::pair<std::string, uint16_t> peerInfo;
-      peerSocket->getPeerInfo(peerInfo);
+      auto endpoint = peerSocket->getPeerInfo();
 
-      auto peer = std::make_shared<Peer>(peerInfo.first, peerInfo.second, true);
+      auto peer = std::make_shared<Peer>(endpoint.addr, endpoint.port, true);
       cuid_t cuid = e_->newCUID();
       e_->addCommand(
           make_unique<ReceiverMSEHandshakeCommand>(cuid, peer, e_, peerSocket));

+ 7 - 12
src/SocketCore.cc

@@ -387,13 +387,12 @@ std::shared_ptr<SocketCore> SocketCore::acceptConnection() const
   return sock;
 }
 
-int SocketCore::getAddrInfo(std::pair<std::string, uint16_t>& addrinfo) const
+Endpoint SocketCore::getAddrInfo() const
 {
   sockaddr_union sockaddr;
   socklen_t len = sizeof(sockaddr);
   getAddrInfo(sockaddr, len);
-  addrinfo = util::getNumericNameInfo(&sockaddr.sa, len);
-  return sockaddr.storage.ss_family;
+  return util::getNumericNameInfo(&sockaddr.sa, len);
 }
 
 void SocketCore::getAddrInfo(sockaddr_union& sockaddr, socklen_t& len) const
@@ -412,7 +411,7 @@ int SocketCore::getAddressFamily() const
   return sockaddr.storage.ss_family;
 }
 
-int SocketCore::getPeerInfo(std::pair<std::string, uint16_t>& peerinfo) const
+Endpoint SocketCore::getPeerInfo() const
 {
   sockaddr_union sockaddr;
   socklen_t len = sizeof(sockaddr);
@@ -420,8 +419,7 @@ int SocketCore::getPeerInfo(std::pair<std::string, uint16_t>& peerinfo) const
     int errNum = SOCKET_ERRNO;
     throw DL_ABORT_EX(fmt(EX_SOCKET_GET_NAME, errorMsg(errNum).c_str()));
   }
-  peerinfo = util::getNumericNameInfo(&sockaddr.sa, len);
-  return sockaddr.storage.ss_family;
+  return util::getNumericNameInfo(&sockaddr.sa, len);
 }
 
 void SocketCore::establishConnection(const std::string& host, uint16_t port,
@@ -965,9 +963,8 @@ bool SocketCore::tlsHandshake(TLSContext* tlsctx, const std::string& hostname)
       if (!hostname.empty()) {
         ss << hostname << " (";
       }
-      std::pair<std::string, uint16_t> peer;
-      getPeerInfo(peer);
-      ss << peer.first << ":" << peer.second;
+      auto peerEndpoint = getPeerInfo();
+      ss << peerEndpoint.addr << ":" << peerEndpoint.port;
       if (!hostname.empty()) {
         ss << ")";
       }
@@ -1219,9 +1216,7 @@ ssize_t SocketCore::writeData(const void* data, size_t len,
   return r;
 }
 
-ssize_t SocketCore::readDataFrom(
-    void* data, size_t len,
-    std::pair<std::string /* numerichost */, uint16_t /* port */>& sender)
+ssize_t SocketCore::readDataFrom(void* data, size_t len, Endpoint& sender)
 {
   wantRead_ = false;
   wantWrite_ = false;

+ 6 - 13
src/SocketCore.h

@@ -169,12 +169,9 @@ public:
   void beginListen();
 
   /**
-   * Stores host address and port of this socket to addrinfo and
-   * returns address family.
-   *
-   * @param addrinfo placeholder to store host address and port.
+   * Returns host address, family and port of this socket.
    */
-  int getAddrInfo(std::pair<std::string, uint16_t>& addrinfo) const;
+  Endpoint getAddrInfo() const;
 
   /**
    * Stores address of this socket to sockaddr.  len must be
@@ -191,12 +188,9 @@ public:
   int getAddressFamily() const;
 
   /**
-   * Stores peer's address and port to peerinfo and returns address
-   * family.
-   *
-   * @param peerinfo placeholder to store peer's address and port.
+   * Returns peer's address, family and port.
    */
-  int getPeerInfo(std::pair<std::string, uint16_t>& peerinfo) const;
+  Endpoint getPeerInfo() const;
 
   /**
    * Accepts incoming connection on this socket.
@@ -288,9 +282,8 @@ public:
    */
   void readData(void* data, size_t& len);
 
-  ssize_t readDataFrom(
-      void* data, size_t len,
-      std::pair<std::string /* numerichost */, uint16_t /* port */>& sender);
+  // sender.addr will be numerihost assigned.
+  ssize_t readDataFrom(void* data, size_t len, Endpoint& sender);
 
 #ifdef ENABLE_SSL
   // Performs TLS server side handshake. If handshake is completed,

+ 12 - 0
src/a2netcompat.h

@@ -90,6 +90,8 @@
 #include "gai_strerror.h"
 #endif // HAVE_GAI_STRERROR
 
+#include <string>
+
 #ifdef HAVE_WINSOCK2_H
 #define sock_t SOCKET
 #else
@@ -120,6 +122,16 @@ struct SockAddr {
   socklen_t suLength;
 };
 
+// Human readable address, family and port.  In other words, addr is
+// text name, usually obtained from getnameinfo(3).  The family field
+// is the protocol family if it is known when generating this object.
+// If it is unknown, this is AF_UNSPEC.
+struct Endpoint {
+  std::string addr;
+  int family;
+  uint16_t port;
+};
+
 #define A2_DEFAULT_IOV_MAX 128
 
 #if defined(IOV_MAX) && IOV_MAX < A2_DEFAULT_IOV_MAX

+ 3 - 3
src/util.cc

@@ -1595,8 +1595,7 @@ void* allocateAlignedMemory(size_t alignment, size_t size)
 }
 #endif // HAVE_POSIX_MEMALIGN
 
-std::pair<std::string, uint16_t>
-getNumericNameInfo(const struct sockaddr* sockaddr, socklen_t len)
+Endpoint getNumericNameInfo(const struct sockaddr* sockaddr, socklen_t len)
 {
   char host[NI_MAXHOST];
   char service[NI_MAXSERV];
@@ -1606,7 +1605,8 @@ getNumericNameInfo(const struct sockaddr* sockaddr, socklen_t len)
     throw DL_ABORT_EX(
         fmt("Failed to get hostname and port. cause: %s", gai_strerror(s)));
   }
-  return std::pair<std::string, uint16_t>(host, atoi(service)); // TODO
+  return {host, sockaddr->sa_family,
+          static_cast<uint16_t>(strtoul(service, nullptr, 10))};
 }
 
 std::string htmlEscape(const std::string& src)

+ 1 - 2
src/util.h

@@ -444,8 +444,7 @@ std::string toString(const std::shared_ptr<BinaryStream>& binaryStream);
 void* allocateAlignedMemory(size_t alignment, size_t size);
 #endif // HAVE_POSIX_MEMALIGN
 
-std::pair<std::string, uint16_t>
-getNumericNameInfo(const struct sockaddr* sockaddr, socklen_t len);
+Endpoint getNumericNameInfo(const struct sockaddr* sockaddr, socklen_t len);
 
 std::string htmlEscape(const std::string& src);
 

+ 1 - 3
test/FtpConnectionTest.cc

@@ -55,9 +55,7 @@ public:
     listenSocket->bind(0);
     listenSocket->beginListen();
     listenSocket->setBlockingMode();
-    std::pair<std::string, uint16_t> addrinfo;
-    listenSocket->getAddrInfo(addrinfo);
-    listenPort_ = addrinfo.second;
+    listenPort_ = listenSocket->getAddrInfo().port;
 
     req_.reset(new Request());
     req_->setUri("ftp://localhost/dir%20sp/hello%20world.img");

+ 2 - 3
test/HttpServerTest.cc

@@ -22,11 +22,10 @@ namespace {
 std::unique_ptr<HttpServer> performHttpRequest(SocketCore& server,
                                                std::string request)
 {
-  std::pair<std::string, uint16_t> addr;
-  server.getAddrInfo(addr);
+  auto endpoint = server.getAddrInfo();
 
   SocketCore client;
-  client.establishConnection("localhost", addr.second);
+  client.establishConnection("localhost", endpoint.port);
   while (!client.isWritable(0)) {
   }
 

+ 2 - 2
test/LpdMessageDispatcherTest.cc

@@ -60,11 +60,11 @@ void LpdMessageDispatcherTest::testSendMessage()
 
   unsigned char buf[200];
 
-  std::pair<std::string, uint16_t> peer;
+  Endpoint remoteEndpoint;
   ssize_t nbytes;
   int trycnt;
   for (trycnt = 0; trycnt < 5; ++trycnt) {
-    nbytes = recvsock->readDataFrom(buf, sizeof(buf), peer);
+    nbytes = recvsock->readDataFrom(buf, sizeof(buf), remoteEndpoint);
     if (nbytes == 0) {
       util::sleep(1);
     }

+ 2 - 3
test/MSEHandshakeTest.cc

@@ -57,9 +57,8 @@ createSocketPair()
   receiverServerSock.beginListen();
   receiverServerSock.setBlockingMode();
 
-  std::pair<std::string, uint16_t> receiverAddrInfo;
-  receiverServerSock.getAddrInfo(receiverAddrInfo);
-  initiatorSock->establishConnection("localhost", receiverAddrInfo.second);
+  auto endpoint = receiverServerSock.getAddrInfo();
+  initiatorSock->establishConnection("localhost", endpoint.port);
   initiatorSock->setBlockingMode();
 
   std::shared_ptr<SocketCore> receiverSock(

+ 12 - 8
test/SocketCoreTest.cc

@@ -43,26 +43,30 @@ void SocketCoreTest::testWriteAndReadDatagram()
     SocketCore c(SOCK_DGRAM);
     c.bind(0);
 
-    std::pair<std::string, uint16_t> svaddr;
-    s.getAddrInfo(svaddr);
+    auto remoteEndpoint = s.getAddrInfo();
 
     std::string message1 = "hello world.";
-    c.writeData(message1.c_str(), message1.size(), "localhost", svaddr.second);
+    c.writeData(message1.c_str(), message1.size(), "localhost",
+                remoteEndpoint.port);
     std::string message2 = "chocolate coated pie";
-    c.writeData(message2.c_str(), message2.size(), "localhost", svaddr.second);
+    c.writeData(message2.c_str(), message2.size(), "localhost",
+                remoteEndpoint.port);
 
     char readbuffer[100];
-    std::pair<std::string, uint16_t> peer;
+
     {
-      ssize_t rlength = s.readDataFrom(readbuffer, sizeof(readbuffer), peer);
+      ssize_t rlength =
+          s.readDataFrom(readbuffer, sizeof(readbuffer), remoteEndpoint);
       // commented out because ip address may vary
-      // CPPUNIT_ASSERT_EQUAL(std::std::string("127.0.0.1"), peer.first);
+      // CPPUNIT_ASSERT_EQUAL(std::std::string("127.0.0.1"),
+      //                      remoteEndpoint.addr);
       CPPUNIT_ASSERT_EQUAL((ssize_t)message1.size(), rlength);
       readbuffer[rlength] = '\0';
       CPPUNIT_ASSERT_EQUAL(message1, std::string(readbuffer));
     }
     {
-      ssize_t rlength = s.readDataFrom(readbuffer, sizeof(readbuffer), peer);
+      ssize_t rlength =
+          s.readDataFrom(readbuffer, sizeof(readbuffer), remoteEndpoint);
       CPPUNIT_ASSERT_EQUAL((ssize_t)message2.size(), rlength);
       readbuffer[rlength] = '\0';
       CPPUNIT_ASSERT_EQUAL(message2, std::string(readbuffer));