ソースを参照

2008-10-05 Tatsuhiro Tsujikawa <tujikawa at rednoah dot com>

	Made socket for dht connections non-blocking
	* src/DHTAbstractMessage.cc
	* src/DHTAbstractMessage.h
	* src/DHTConnection.h
	* src/DHTConnectionImpl.cc
	* src/DHTConnectionImpl.h
	* src/DHTMessage.h
	* src/DHTMessageDispatcherImpl.cc
	* src/DHTMessageDispatcherImpl.h
	* src/DHTUnknownMessage.cc
	* src/DHTUnknownMessage.h
	* src/SocketCore.cc
	* src/SocketCore.h
	* test/MockDHTMessage.h
Tatsuhiro Tsujikawa 17 年 前
コミット
389f770008

+ 17 - 0
ChangeLog

@@ -1,3 +1,20 @@
+2008-10-05  Tatsuhiro Tsujikawa  <tujikawa at rednoah dot com>
+
+	Made socket for dht connections non-blocking
+	* src/DHTAbstractMessage.cc
+	* src/DHTAbstractMessage.h
+	* src/DHTConnection.h
+	* src/DHTConnectionImpl.cc
+	* src/DHTConnectionImpl.h
+	* src/DHTMessage.h
+	* src/DHTMessageDispatcherImpl.cc
+	* src/DHTMessageDispatcherImpl.h
+	* src/DHTUnknownMessage.cc
+	* src/DHTUnknownMessage.h
+	* src/SocketCore.cc
+	* src/SocketCore.h
+	* test/MockDHTMessage.h
+
 2008-10-05  Tatsuhiro Tsujikawa  <tujikawa at rednoah dot com>
 
 	Changed the type of offset to int.

+ 11 - 5
src/DHTAbstractMessage.cc

@@ -33,6 +33,9 @@
  */
 /* copyright --> */
 #include "DHTAbstractMessage.h"
+
+#include <cassert>
+
 #include "DHTNode.h"
 #include "BencodeVisitor.h"
 #include "DHTConnection.h"
@@ -64,13 +67,16 @@ std::string DHTAbstractMessage::getBencodedMessage()
   return v.getBencodedData();
 }
 
-void DHTAbstractMessage::send()
+bool DHTAbstractMessage::send()
 {
   std::string message = getBencodedMessage();
-  _connection->sendMessage(reinterpret_cast<const unsigned char*>(message.c_str()),
-			   message.size(),
-			   _remoteNode->getIPAddress(),
-			   _remoteNode->getPort());
+  ssize_t r = _connection->sendMessage
+    (reinterpret_cast<const unsigned char*>(message.c_str()),
+     message.size(),
+     _remoteNode->getIPAddress(),
+     _remoteNode->getPort());
+  assert(r >= 0);
+  return r == static_cast<ssize_t>(message.size());
 }
 
 void DHTAbstractMessage::setConnection(const WeakHandle<DHTConnection>& connection)

+ 1 - 1
src/DHTAbstractMessage.h

@@ -62,7 +62,7 @@ public:
 
   virtual ~DHTAbstractMessage();
 
-  virtual void send();
+  virtual bool send();
 
   virtual std::string getType() const = 0;
 

+ 4 - 2
src/DHTConnection.h

@@ -44,9 +44,11 @@ class DHTConnection {
 public:
   virtual ~DHTConnection() {}
 
-  virtual ssize_t receiveMessage(unsigned char* data, size_t len, std::string& host, uint16_t& port) = 0;
+  virtual ssize_t receiveMessage(unsigned char* data, size_t len,
+				 std::string& host, uint16_t& port) = 0;
 
-  virtual void sendMessage(const unsigned char* data, size_t len, const std::string& host, uint16_t port) = 0;
+  virtual ssize_t sendMessage(const unsigned char* data, size_t len,
+			      const std::string& host, uint16_t port) = 0;
 };
 
 } // namespace aria2

+ 14 - 9
src/DHTConnectionImpl.cc

@@ -33,12 +33,14 @@
  */
 /* copyright --> */
 #include "DHTConnectionImpl.h"
+
+#include <utility>
+
 #include "LogFactory.h"
 #include "Logger.h"
 #include "RecoverableException.h"
 #include "Util.h"
 #include "Socket.h"
-#include <utility>
 
 namespace aria2 {
 
@@ -66,6 +68,7 @@ bool DHTConnectionImpl::bind(uint16_t& port)
 {
   try {
     _socket->bind(port);
+    _socket->setNonBlockingMode();
     std::pair<std::string, uint16_t> svaddr;
     _socket->getAddrInfo(svaddr);
     port = svaddr.second;
@@ -77,22 +80,24 @@ bool DHTConnectionImpl::bind(uint16_t& port)
   return false;
 }
 
-ssize_t DHTConnectionImpl::receiveMessage(unsigned char* data, size_t len, std::string& host, uint16_t& port)
+ssize_t DHTConnectionImpl::receiveMessage(unsigned char* data, size_t len,
+					  std::string& host, uint16_t& port)
 {
-  if(_socket->isReadable(0)) {
-    std::pair<std::string, uint16_t> remoteHost;
-    ssize_t length = _socket->readDataFrom(data, len, remoteHost);
+  std::pair<std::string, uint16_t> remoteHost;
+  ssize_t length = _socket->readDataFrom(data, len, remoteHost);
+  if(length == 0) {
+    return length;
+  } else {
     host = remoteHost.first;
     port = remoteHost.second;
     return length;
-  } else {
-    return -1;
   }
 }
 
-void DHTConnectionImpl::sendMessage(const unsigned char* data, size_t len, const std::string& host, uint16_t port)
+ssize_t DHTConnectionImpl::sendMessage(const unsigned char* data, size_t len,
+				       const std::string& host, uint16_t port)
 {
-  _socket->writeData(data, len, host, port);
+  return _socket->writeData(data, len, host, port);
 }
 
 SharedHandle<SocketCore> DHTConnectionImpl::getSocket() const

+ 4 - 2
src/DHTConnectionImpl.h

@@ -71,9 +71,11 @@ public:
    */
   bool bind(uint16_t& port);
 
-  virtual ssize_t receiveMessage(unsigned char* data, size_t len, std::string& host, uint16_t& port);
+  virtual ssize_t receiveMessage(unsigned char* data, size_t len,
+				 std::string& host, uint16_t& port);
 
-  virtual void sendMessage(const unsigned char* data, size_t len, const std::string& host, uint16_t port);
+  virtual ssize_t sendMessage(const unsigned char* data, size_t len,
+			      const std::string& host, uint16_t port);
 
   SharedHandle<SocketCore> getSocket() const;
 };

+ 4 - 2
src/DHTMessage.h

@@ -36,9 +36,11 @@
 #define _D_DHT_MESSAGE_H_
 
 #include "common.h"
+
+#include <string>
+
 #include "SharedHandle.h"
 #include "A2STR.h"
-#include <string>
 
 namespace aria2 {
 
@@ -71,7 +73,7 @@ public:
 
   virtual void doReceivedAction() = 0;
 
-  virtual void send() = 0;
+  virtual bool send() = 0;
 
   virtual bool isReply() const = 0;
 

+ 18 - 8
src/DHTMessageDispatcherImpl.cc

@@ -67,28 +67,38 @@ DHTMessageDispatcherImpl::addMessageToQueue(const SharedHandle<DHTMessage>& mess
   addMessageToQueue(message, DHT_MESSAGE_TIMEOUT, callback);
 }
 
-void
+bool
 DHTMessageDispatcherImpl::sendMessage(const SharedHandle<DHTMessageEntry>& entry)
 {
   try {
-    entry->_message->send();
-    if(!entry->_message->isReply()) {
-      _tracker->addMessage(entry->_message, entry->_timeout, entry->_callback);
+    if(entry->_message->send()) {
+      if(!entry->_message->isReply()) {
+	_tracker->addMessage(entry->_message, entry->_timeout, entry->_callback);
+      }
+      _logger->info("Message sent: %s", entry->_message->toString().c_str());
+    } else {
+      return false;
     }
-    _logger->info("Message sent: %s", entry->_message->toString().c_str());
   } catch(RecoverableException& e) {
     _logger->error("Failed to send message: %s", e, entry->_message->toString().c_str());
   }
+  return true;
 }
 
 void DHTMessageDispatcherImpl::sendMessages()
 {
   // TODO I can't use bind1st and mem_fun here because bind1st cannot bind a
   // function which takes a reference as an argument..
-  for(std::deque<SharedHandle<DHTMessageEntry> >::iterator itr = _messageQueue.begin(); itr != _messageQueue.end(); ++itr) {
-    sendMessage(*itr);
+  std::deque<SharedHandle<DHTMessageEntry> >::iterator itr =
+    _messageQueue.begin();
+  for(; itr != _messageQueue.end(); ++itr) {
+    if(!sendMessage(*itr)) {
+      break;
+    }
   }
-  _messageQueue.clear();
+  _messageQueue.erase(_messageQueue.begin(), itr);
+  _logger->debug("%lu dht messages remaining in the queue.",
+		 static_cast<unsigned long>(_messageQueue.size()));
 }
 
 size_t DHTMessageDispatcherImpl::countMessageInQueue() const

+ 1 - 1
src/DHTMessageDispatcherImpl.h

@@ -52,7 +52,7 @@ private:
 
   Logger* _logger;
 
-  void sendMessage(const SharedHandle<DHTMessageEntry>& msg);
+  bool sendMessage(const SharedHandle<DHTMessageEntry>& msg);
 public:
   DHTMessageDispatcherImpl(const SharedHandle<DHTMessageTracker>& tracker);
 

+ 5 - 2
src/DHTUnknownMessage.cc

@@ -33,9 +33,12 @@
  */
 /* copyright --> */
 #include "DHTUnknownMessage.h"
+
+#include <cstring>
+#include <cstdlib>
+
 #include "DHTNode.h"
 #include "Util.h"
-#include <cstring>
 
 namespace aria2 {
 
@@ -66,7 +69,7 @@ DHTUnknownMessage::~DHTUnknownMessage()
 
 void DHTUnknownMessage::doReceivedAction() {}
 
-void DHTUnknownMessage::send() {}
+bool DHTUnknownMessage::send() { return true; }
 
 bool DHTUnknownMessage::isReply() const
 {

+ 1 - 1
src/DHTUnknownMessage.h

@@ -57,7 +57,7 @@ public:
   virtual void doReceivedAction();
 
   // do nothing; we don't use this message as outgoing message.
-  virtual void send();
+  virtual bool send();
 
   // always return false
   virtual bool isReply() const;

+ 20 - 3
src/SocketCore.cc

@@ -839,8 +839,11 @@ bool SocketCore::initiateSecureConnection()
 #endif // __MINGW32__
 }
 
-void SocketCore::writeData(const char* data, size_t len, const std::string& host, uint16_t port)
+ssize_t SocketCore::writeData(const char* data, size_t len,
+			      const std::string& host, uint16_t port)
 {
+  _wantRead = false;
+  _wantWrite = false;
 
   struct addrinfo hints;
   struct addrinfo* res;
@@ -861,17 +864,25 @@ void SocketCore::writeData(const char* data, size_t len, const std::string& host
     if(r == static_cast<ssize_t>(len)) {
       break;
     }
+    if(r == -1 && errno == EAGAIN) {
+      _wantWrite = true;
+      r = 0;
+      break;
+    }
   }
   freeaddrinfo(res);
   if(r == -1) {
     throw DlAbortEx(StringFormat(EX_SOCKET_SEND, errorMsg()).str());
   }
+  return r;
 }
 
 ssize_t SocketCore::readDataFrom(char* data, size_t len,
 				 std::pair<std::string /* numerichost */,
 				 uint16_t /* port */>& sender)
 {
+  _wantRead = false;
+  _wantWrite = false;
   struct sockaddr_storage sockaddr;
   socklen_t sockaddrlen = sizeof(struct sockaddr_storage);
   struct sockaddr* addrp = reinterpret_cast<struct sockaddr*>(&sockaddr);
@@ -879,9 +890,15 @@ ssize_t SocketCore::readDataFrom(char* data, size_t len,
   while((r = recvfrom(sockfd, data, len, 0, addrp, &sockaddrlen)) == -1 &&
 	EINTR == errno);
   if(r == -1) {
-    throw DlAbortEx(StringFormat(EX_SOCKET_RECV, errorMsg()).str());
+    if(errno == EAGAIN) {
+      _wantRead = true;
+      r = 0;
+    } else {
+      throw DlRetryEx(StringFormat(EX_SOCKET_RECV, errorMsg()).str());
+    }
+  } else {
+    sender = Util::getNumericNameInfo(addrp, sockaddrlen);
   }
-  sender = Util::getNumericNameInfo(addrp, sockaddrlen);
 
   return r;
 }

+ 6 - 4
src/SocketCore.h

@@ -222,12 +222,14 @@ public:
     return writeData(reinterpret_cast<const char*>(data), len);
   }
 
-  void writeData(const char* data, size_t len, const std::string& host, uint16_t port);
+  ssize_t writeData(const char* data, size_t len,
+		    const std::string& host, uint16_t port);
 
-  void writeData(const unsigned char* data, size_t len, const std::string& host,
-		 uint16_t port)
+  ssize_t writeData(const unsigned char* data, size_t len,
+		    const std::string& host,
+		    uint16_t port)
   {
-    writeData(reinterpret_cast<const char*>(data), len, host, port);
+    return writeData(reinterpret_cast<const char*>(data), len, host, port);
   }
 
   /**

+ 4 - 2
test/MockDHTMessage.h

@@ -2,9 +2,11 @@
 #define _D_MOCK_DHT_MESSAGE_H_
 
 #include "DHTMessage.h"
+
+#include <deque>
+
 #include "DHTNode.h"
 #include "Peer.h"
-#include <deque>
 
 namespace aria2 {
 
@@ -30,7 +32,7 @@ public:
 
   virtual void doReceivedAction() {}
 
-  virtual void send() {}
+  virtual bool send() { return true; }
 
   virtual bool isReply() const { return _isReply; }