Explorar o código

WinTLS: Fix hang because of buffered received data

WinTLSSession buffers received decrypted data into its own buffer.  If
read is requested, it copies the data from its buffer.  But if
requested buffer size is less than decrypted buffer, some of the data
is left in the buffer.  Previously, we had no facility to check the
existence of this pending data.  If this data is the last requested
data from remote server, we may end up waiting for read event even if
we have already data in our buffer, which may cause hang.  This commit
fixes this issue by introducing function to return the buffered length
in TLSSession.  SocketCore also provides the same function, which
delegates to TLSSession object.
Tatsuhiro Tsujikawa %!s(int64=9) %!d(string=hai) anos
pai
achega
cf2fa33fe0

+ 6 - 1
src/AbstractCommand.cc

@@ -148,6 +148,10 @@ bool AbstractCommand::shouldProcess() const
     if (socketRecvBuffer_ && !socketRecvBuffer_->bufferEmpty()) {
       return true;
     }
+
+    if (socket_ && socket_->getRecvBufferedLength()) {
+      return true;
+    }
   }
 
   if (checkSocketIsWritable_ && writeEventEnabled()) {
@@ -916,7 +920,8 @@ const std::shared_ptr<PieceStorage>& AbstractCommand::getPieceStorage() const
 
 void AbstractCommand::checkSocketRecvBuffer()
 {
-  if (socketRecvBuffer_->bufferEmpty()) {
+  if (socketRecvBuffer_->bufferEmpty() &&
+      socket_->getRecvBufferedLength() == 0) {
     return;
   }
 

+ 2 - 0
src/AppleTLSSession.h

@@ -105,6 +105,8 @@ public:
   // Returns last error string
   virtual std::string getLastErrorString() CXX11_OVERRIDE;
 
+  virtual size_t getRecvBufferedLength() CXX11_OVERRIDE { return 0; }
+
 private:
   static OSStatus SocketWrite(SSLConnectionRef conn, const void* data,
                               size_t* len)

+ 3 - 1
src/HttpServerBodyCommand.cc

@@ -79,7 +79,8 @@ HttpServerBodyCommand::HttpServerBodyCommand(
   // To handle Content-Length == 0 case
   setStatus(Command::STATUS_ONESHOT_REALTIME);
   e_->addSocketForReadCheck(socket_, this);
-  if (!httpServer_->getSocketRecvBuffer()->bufferEmpty()) {
+  if (!httpServer_->getSocketRecvBuffer()->bufferEmpty() ||
+      socket_->getRecvBufferedLength()) {
     e_->setNoWait(true);
   }
 }
@@ -178,6 +179,7 @@ bool HttpServerBodyCommand::execute()
   }
   try {
     if (socket_->isReadable(0) || (writeCheck_ && socket_->isWritable(0)) ||
+        socket_->getRecvBufferedLength() ||
         !httpServer_->getSocketRecvBuffer()->bufferEmpty() ||
         httpServer_->getContentLength() == 0) {
       timeoutTimer_ = global::wallclock();

+ 7 - 3
src/HttpServerCommand.cc

@@ -107,10 +107,13 @@ HttpServerCommand::~HttpServerCommand()
 
 void HttpServerCommand::checkSocketRecvBuffer()
 {
-  if (!httpServer_->getSocketRecvBuffer()->bufferEmpty()) {
-    setStatus(Command::STATUS_ONESHOT_REALTIME);
-    e_->setNoWait(true);
+  if (httpServer_->getSocketRecvBuffer()->bufferEmpty() &&
+      socket_->getRecvBufferedLength() == 0) {
+    return;
   }
+
+  setStatus(Command::STATUS_ONESHOT_REALTIME);
+  e_->setNoWait(true);
 }
 
 #ifdef ENABLE_WEBSOCKET
@@ -172,6 +175,7 @@ bool HttpServerCommand::execute()
   }
   try {
     if (socket_->isReadable(0) || (writeCheck_ && socket_->isWritable(0)) ||
+        socket_->getRecvBufferedLength() ||
         !httpServer_->getSocketRecvBuffer()->bufferEmpty()) {
       timeoutTimer_ = global::wallclock();
 

+ 1 - 0
src/LibgnutlsTLSSession.h

@@ -59,6 +59,7 @@ public:
                          std::string& handshakeErr) CXX11_OVERRIDE;
   virtual int tlsAccept(TLSVersion& version) CXX11_OVERRIDE;
   virtual std::string getLastErrorString() CXX11_OVERRIDE;
+  virtual size_t getRecvBufferedLength() CXX11_OVERRIDE { return 0; }
 
 private:
   gnutls_session_t sslSession_;

+ 1 - 0
src/LibsslTLSSession.h

@@ -59,6 +59,7 @@ public:
                          std::string& handshakeErr) CXX11_OVERRIDE;
   virtual int tlsAccept(TLSVersion& version) CXX11_OVERRIDE;
   virtual std::string getLastErrorString() CXX11_OVERRIDE;
+  virtual size_t getRecvBufferedLength() CXX11_OVERRIDE { return 0; }
 
 private:
   int handshake(TLSVersion& version);

+ 9 - 0
src/SocketCore.cc

@@ -1324,6 +1324,15 @@ void SocketCore::setSocketRecvBufferSize(int size)
 
 int SocketCore::getSocketRecvBufferSize() { return socketRecvBufferSize_; }
 
+size_t SocketCore::getRecvBufferedLength() const
+{
+  if (!tlsSession_) {
+    return 0;
+  }
+
+  return tlsSession_->getRecvBufferedLength();
+}
+
 std::vector<SockAddr> SocketCore::getInterfaceAddress(const std::string& iface,
                                                       int family, int aiFlags)
 {

+ 5 - 0
src/SocketCore.h

@@ -337,6 +337,11 @@ public:
    */
   bool wantWrite() const;
 
+  // Returns buffered data which are already received.  This data was
+  // already read from socket, and ready to read without reading
+  // socket.
+  size_t getRecvBufferedLength() const;
+
 #ifdef ENABLE_SSL
   static void
   setClientTLSContext(const std::shared_ptr<TLSContext>& tlsContext);

+ 4 - 0
src/TLSSession.h

@@ -107,6 +107,10 @@ public:
   // Returns last error string
   virtual std::string getLastErrorString() = 0;
 
+  // Returns buffered length, which can be read immediately without
+  // contacting network.
+  virtual size_t getRecvBufferedLength() = 0;
+
 protected:
   TLSSession() {}
 

+ 2 - 0
src/WinTLSSession.cc

@@ -832,4 +832,6 @@ std::string WinTLSSession::getLastErrorString()
   return ss.str();
 }
 
+size_t WinTLSSession::getRecvBufferedLength() { return decBuf_.size(); }
+
 } // namespace aria2

+ 2 - 0
src/WinTLSSession.h

@@ -169,6 +169,8 @@ public:
   // Returns last error string
   virtual std::string getLastErrorString() CXX11_OVERRIDE;
 
+  virtual size_t getRecvBufferedLength() CXX11_OVERRIDE;
+
 private:
   std::string hostname_;
   sock_t sockfd_;