Browse Source

WinTLS: Fix hang because of buffered received data

WinTLSSession buffers received decrypted data into its own buffer.  If
read is requested, it copies the data from its buffer.  But if
requested buffer size is less than decrypted buffer, some of the data
is left in the buffer.  Previously, we had no facility to check the
existence of this pending data.  If this data is the last requested
data from remote server, we may end up waiting for read event even if
we have already data in our buffer, which may cause hang.  This commit
fixes this issue by introducing function to return the buffered length
in TLSSession.  SocketCore also provides the same function, which
delegates to TLSSession object.
Tatsuhiro Tsujikawa 9 years ago
parent
commit
cf2fa33fe0

+ 6 - 1
src/AbstractCommand.cc

@@ -148,6 +148,10 @@ bool AbstractCommand::shouldProcess() const
     if (socketRecvBuffer_ && !socketRecvBuffer_->bufferEmpty()) {
     if (socketRecvBuffer_ && !socketRecvBuffer_->bufferEmpty()) {
       return true;
       return true;
     }
     }
+
+    if (socket_ && socket_->getRecvBufferedLength()) {
+      return true;
+    }
   }
   }
 
 
   if (checkSocketIsWritable_ && writeEventEnabled()) {
   if (checkSocketIsWritable_ && writeEventEnabled()) {
@@ -916,7 +920,8 @@ const std::shared_ptr<PieceStorage>& AbstractCommand::getPieceStorage() const
 
 
 void AbstractCommand::checkSocketRecvBuffer()
 void AbstractCommand::checkSocketRecvBuffer()
 {
 {
-  if (socketRecvBuffer_->bufferEmpty()) {
+  if (socketRecvBuffer_->bufferEmpty() &&
+      socket_->getRecvBufferedLength() == 0) {
     return;
     return;
   }
   }
 
 

+ 2 - 0
src/AppleTLSSession.h

@@ -105,6 +105,8 @@ public:
   // Returns last error string
   // Returns last error string
   virtual std::string getLastErrorString() CXX11_OVERRIDE;
   virtual std::string getLastErrorString() CXX11_OVERRIDE;
 
 
+  virtual size_t getRecvBufferedLength() CXX11_OVERRIDE { return 0; }
+
 private:
 private:
   static OSStatus SocketWrite(SSLConnectionRef conn, const void* data,
   static OSStatus SocketWrite(SSLConnectionRef conn, const void* data,
                               size_t* len)
                               size_t* len)

+ 3 - 1
src/HttpServerBodyCommand.cc

@@ -79,7 +79,8 @@ HttpServerBodyCommand::HttpServerBodyCommand(
   // To handle Content-Length == 0 case
   // To handle Content-Length == 0 case
   setStatus(Command::STATUS_ONESHOT_REALTIME);
   setStatus(Command::STATUS_ONESHOT_REALTIME);
   e_->addSocketForReadCheck(socket_, this);
   e_->addSocketForReadCheck(socket_, this);
-  if (!httpServer_->getSocketRecvBuffer()->bufferEmpty()) {
+  if (!httpServer_->getSocketRecvBuffer()->bufferEmpty() ||
+      socket_->getRecvBufferedLength()) {
     e_->setNoWait(true);
     e_->setNoWait(true);
   }
   }
 }
 }
@@ -178,6 +179,7 @@ bool HttpServerBodyCommand::execute()
   }
   }
   try {
   try {
     if (socket_->isReadable(0) || (writeCheck_ && socket_->isWritable(0)) ||
     if (socket_->isReadable(0) || (writeCheck_ && socket_->isWritable(0)) ||
+        socket_->getRecvBufferedLength() ||
         !httpServer_->getSocketRecvBuffer()->bufferEmpty() ||
         !httpServer_->getSocketRecvBuffer()->bufferEmpty() ||
         httpServer_->getContentLength() == 0) {
         httpServer_->getContentLength() == 0) {
       timeoutTimer_ = global::wallclock();
       timeoutTimer_ = global::wallclock();

+ 7 - 3
src/HttpServerCommand.cc

@@ -107,10 +107,13 @@ HttpServerCommand::~HttpServerCommand()
 
 
 void HttpServerCommand::checkSocketRecvBuffer()
 void HttpServerCommand::checkSocketRecvBuffer()
 {
 {
-  if (!httpServer_->getSocketRecvBuffer()->bufferEmpty()) {
-    setStatus(Command::STATUS_ONESHOT_REALTIME);
-    e_->setNoWait(true);
+  if (httpServer_->getSocketRecvBuffer()->bufferEmpty() &&
+      socket_->getRecvBufferedLength() == 0) {
+    return;
   }
   }
+
+  setStatus(Command::STATUS_ONESHOT_REALTIME);
+  e_->setNoWait(true);
 }
 }
 
 
 #ifdef ENABLE_WEBSOCKET
 #ifdef ENABLE_WEBSOCKET
@@ -172,6 +175,7 @@ bool HttpServerCommand::execute()
   }
   }
   try {
   try {
     if (socket_->isReadable(0) || (writeCheck_ && socket_->isWritable(0)) ||
     if (socket_->isReadable(0) || (writeCheck_ && socket_->isWritable(0)) ||
+        socket_->getRecvBufferedLength() ||
         !httpServer_->getSocketRecvBuffer()->bufferEmpty()) {
         !httpServer_->getSocketRecvBuffer()->bufferEmpty()) {
       timeoutTimer_ = global::wallclock();
       timeoutTimer_ = global::wallclock();
 
 

+ 1 - 0
src/LibgnutlsTLSSession.h

@@ -59,6 +59,7 @@ public:
                          std::string& handshakeErr) CXX11_OVERRIDE;
                          std::string& handshakeErr) CXX11_OVERRIDE;
   virtual int tlsAccept(TLSVersion& version) CXX11_OVERRIDE;
   virtual int tlsAccept(TLSVersion& version) CXX11_OVERRIDE;
   virtual std::string getLastErrorString() CXX11_OVERRIDE;
   virtual std::string getLastErrorString() CXX11_OVERRIDE;
+  virtual size_t getRecvBufferedLength() CXX11_OVERRIDE { return 0; }
 
 
 private:
 private:
   gnutls_session_t sslSession_;
   gnutls_session_t sslSession_;

+ 1 - 0
src/LibsslTLSSession.h

@@ -59,6 +59,7 @@ public:
                          std::string& handshakeErr) CXX11_OVERRIDE;
                          std::string& handshakeErr) CXX11_OVERRIDE;
   virtual int tlsAccept(TLSVersion& version) CXX11_OVERRIDE;
   virtual int tlsAccept(TLSVersion& version) CXX11_OVERRIDE;
   virtual std::string getLastErrorString() CXX11_OVERRIDE;
   virtual std::string getLastErrorString() CXX11_OVERRIDE;
+  virtual size_t getRecvBufferedLength() CXX11_OVERRIDE { return 0; }
 
 
 private:
 private:
   int handshake(TLSVersion& version);
   int handshake(TLSVersion& version);

+ 9 - 0
src/SocketCore.cc

@@ -1324,6 +1324,15 @@ void SocketCore::setSocketRecvBufferSize(int size)
 
 
 int SocketCore::getSocketRecvBufferSize() { return socketRecvBufferSize_; }
 int SocketCore::getSocketRecvBufferSize() { return socketRecvBufferSize_; }
 
 
+size_t SocketCore::getRecvBufferedLength() const
+{
+  if (!tlsSession_) {
+    return 0;
+  }
+
+  return tlsSession_->getRecvBufferedLength();
+}
+
 std::vector<SockAddr> SocketCore::getInterfaceAddress(const std::string& iface,
 std::vector<SockAddr> SocketCore::getInterfaceAddress(const std::string& iface,
                                                       int family, int aiFlags)
                                                       int family, int aiFlags)
 {
 {

+ 5 - 0
src/SocketCore.h

@@ -337,6 +337,11 @@ public:
    */
    */
   bool wantWrite() const;
   bool wantWrite() const;
 
 
+  // Returns buffered data which are already received.  This data was
+  // already read from socket, and ready to read without reading
+  // socket.
+  size_t getRecvBufferedLength() const;
+
 #ifdef ENABLE_SSL
 #ifdef ENABLE_SSL
   static void
   static void
   setClientTLSContext(const std::shared_ptr<TLSContext>& tlsContext);
   setClientTLSContext(const std::shared_ptr<TLSContext>& tlsContext);

+ 4 - 0
src/TLSSession.h

@@ -107,6 +107,10 @@ public:
   // Returns last error string
   // Returns last error string
   virtual std::string getLastErrorString() = 0;
   virtual std::string getLastErrorString() = 0;
 
 
+  // Returns buffered length, which can be read immediately without
+  // contacting network.
+  virtual size_t getRecvBufferedLength() = 0;
+
 protected:
 protected:
   TLSSession() {}
   TLSSession() {}
 
 

+ 2 - 0
src/WinTLSSession.cc

@@ -832,4 +832,6 @@ std::string WinTLSSession::getLastErrorString()
   return ss.str();
   return ss.str();
 }
 }
 
 
+size_t WinTLSSession::getRecvBufferedLength() { return decBuf_.size(); }
+
 } // namespace aria2
 } // namespace aria2

+ 2 - 0
src/WinTLSSession.h

@@ -169,6 +169,8 @@ public:
   // Returns last error string
   // Returns last error string
   virtual std::string getLastErrorString() CXX11_OVERRIDE;
   virtual std::string getLastErrorString() CXX11_OVERRIDE;
 
 
+  virtual size_t getRecvBufferedLength() CXX11_OVERRIDE;
+
 private:
 private:
   std::string hostname_;
   std::string hostname_;
   sock_t sockfd_;
   sock_t sockfd_;