Explorar el Código

2008-09-27 Tatsuhiro Tsujikawa <tujikawa at rednoah dot com>

	Fixed the bug that HTTPS download fails.
	* src/AbstractCommand.cc
	* src/AbstractCommand.h
	* src/DownloadCommand.cc
	* src/FtpConnection.cc
	* src/HttpConnection.cc
	* src/HttpRequestCommand.cc
	* src/HttpResponseCommand.cc
	* src/HttpSkipResponseCommand.cc
	* src/MSEHandshake.cc
	* src/PeerConnection.cc
	* src/SocketCore.cc
	* src/SocketCore.h

	Fixed the bug that aria2 doesn't download whole content body and 
cannot
	reuse connection if chunked transfer encoding and gzip content 
encoding
	are set.
	* src/DownloadCommand.cc
	* src/HttpSkipResponseCommand.cc
Tatsuhiro Tsujikawa hace 17 años
padre
commit
e9e215dc1f

+ 22 - 0
ChangeLog

@@ -1,3 +1,25 @@
+2008-09-27  Tatsuhiro Tsujikawa  <tujikawa at rednoah dot com>
+
+	Fixed the bug that HTTPS download fails.
+	* src/AbstractCommand.cc
+	* src/AbstractCommand.h
+	* src/DownloadCommand.cc
+	* src/FtpConnection.cc
+	* src/HttpConnection.cc
+	* src/HttpRequestCommand.cc
+	* src/HttpResponseCommand.cc
+	* src/HttpSkipResponseCommand.cc
+	* src/MSEHandshake.cc
+	* src/PeerConnection.cc
+	* src/SocketCore.cc
+	* src/SocketCore.h
+
+	Fixed the bug that aria2 doesn't download whole content body and cannot
+	reuse connection if chunked transfer encoding and gzip content encoding
+	are set.
+	* src/DownloadCommand.cc
+	* src/HttpSkipResponseCommand.cc
+	
 2008-09-27  Tatsuhiro Tsujikawa  <tujikawa at rednoah dot com>
 
 	Updated man page.

+ 20 - 0
src/AbstractCommand.cc

@@ -245,6 +245,16 @@ void AbstractCommand::setReadCheckSocket(const SocketHandle& socket) {
   }
 }
 
+void AbstractCommand::setReadCheckSocketIf
+(const SharedHandle<SocketCore>& socket, bool pred)
+{
+  if(pred) {
+    setReadCheckSocket(socket);
+  } else {
+    disableReadCheckSocket();
+  }
+}
+
 void AbstractCommand::disableWriteCheckSocket() {
   if(checkSocketIsWritable) {
     e->deleteSocketForWriteCheck(writeCheckTarget, this);
@@ -271,6 +281,16 @@ void AbstractCommand::setWriteCheckSocket(const SocketHandle& socket) {
   }
 }
 
+void AbstractCommand::setWriteCheckSocketIf
+(const SharedHandle<SocketCore>& socket, bool pred)
+{
+  if(pred) {
+    setWriteCheckSocket(socket);
+  } else {
+    disableWriteCheckSocket();
+  }
+}
+
 static bool isProxyGETRequest(const std::string& protocol, const Option* option)
 {
   return

+ 11 - 0
src/AbstractCommand.h

@@ -84,6 +84,17 @@ protected:
   void disableReadCheckSocket();
   void disableWriteCheckSocket();
 
+  /**
+   * If pred == true, calls setReadCheckSocket(socket). Otherwise, calls
+   * disableReadCheckSocket().
+   */
+  void setReadCheckSocketIf(const SharedHandle<SocketCore>& socket, bool pred);
+  /**
+   * If pred == true, calls setWriteCheckSocket(socket). Otherwise, calls
+   * disableWriteCheckSocket().
+   */
+  void setWriteCheckSocketIf(const SharedHandle<SocketCore>& socket, bool pred);
+
   void setTimeout(time_t timeout) { this->timeout = timeout; }
 
   void prepareForNextAction(Command* nextCommand = 0);

+ 29 - 14
src/DownloadCommand.cc

@@ -159,24 +159,38 @@ bool DownloadCommand::executeInternal() {
   
   peerStat->updateDownloadLength(bufSize);
 
-  if(_requestGroup->getTotalLength() != 0 && bufSize == 0) {
+  if(_requestGroup->getTotalLength() != 0 && bufSize == 0 &&
+     !socket->wantRead() && !socket->wantWrite()) {
     throw DlRetryEx(EX_GOT_EOF);
   }
-  if((!_transferEncodingDecoder.isNull() &&
-      _transferEncodingDecoder->finished())
-     || (_transferEncodingDecoder.isNull() && segment->complete())
-     || (!_contentEncodingDecoder.isNull() &&
-	 _contentEncodingDecoder->finished())
-     || bufSize == 0) {
-    logger->info(MSG_SEGMENT_DOWNLOAD_COMPLETED, cuid);
-
-    if(!_contentEncodingDecoder.isNull() &&
-       !_contentEncodingDecoder->finished()) {
-      logger->warn("CUID#%d - Transfer was completed, but inflate operation"
-		   " have not finished. Maybe the file is broken in the server"
-		   " side.", cuid);
+  bool segmentComplete = false;
+  if(_transferEncodingDecoder.isNull() && _contentEncodingDecoder.isNull()) {
+    if(segment->complete()) {
+      segmentComplete = true;
+    } else if(segment->getLength() == 0 && bufSize == 0 &&
+	      !socket->wantRead() && !socket->wantWrite()) {
+      segmentComplete = true;
+    }
+  } else if(!_transferEncodingDecoder.isNull() &&
+	    !_contentEncodingDecoder.isNull()) {
+    if(_transferEncodingDecoder->finished() &&
+       _contentEncodingDecoder->finished()) {
+      segmentComplete = true;
+    }
+  } else if(!_transferEncodingDecoder.isNull() &&
+	    _contentEncodingDecoder.isNull()) {
+    if(_transferEncodingDecoder->finished()) {
+      segmentComplete = true;
     }
+  } else if(_transferEncodingDecoder.isNull() &&
+	    !_contentEncodingDecoder.isNull()) {
+    if(_contentEncodingDecoder->finished()) {
+      segmentComplete = true;
+    }
+  }
 
+  if(segmentComplete) {
+    logger->info(MSG_SEGMENT_DOWNLOAD_COMPLETED, cuid);
 #ifdef ENABLE_MESSAGE_DIGEST
 
     {
@@ -211,6 +225,7 @@ bool DownloadCommand::executeInternal() {
     return prepareForNextSegment();
   } else {
     checkLowestDownloadSpeed();
+    setWriteCheckSocketIf(socket, socket->wantWrite());
     e->commands.push_back(this);
     return false;
   }

+ 3 - 0
src/FtpConnection.cc

@@ -285,6 +285,9 @@ bool FtpConnection::bulkReceiveResponse(std::pair<unsigned int, std::string>& re
     size_t size = sizeof(buf);
     socket->readData(buf, size);
     if(size == 0) {
+      if(socket->wantRead() || socket->wantWrite()) {
+	return false;
+      }
       throw DlRetryEx(EX_GOT_EOF);
     }
     if(strbuf.size()+size > MAX_RECV_BUFFER) {

+ 5 - 1
src/HttpConnection.cc

@@ -126,7 +126,11 @@ HttpResponseHandle HttpConnection::receiveResponse()
   size_t size = sizeof(buf);
   socket->peekData(buf, size);
   if(size == 0) {
-    throw DlRetryEx(EX_INVALID_RESPONSE);
+    if(socket->wantRead() || socket->wantWrite()) {
+      return SharedHandle<HttpResponse>();
+    } else {
+      throw DlRetryEx(EX_INVALID_RESPONSE);
+    }
   }
   proc->update(buf, size);
   if(!proc->eoh()) {

+ 11 - 4
src/HttpRequestCommand.cc

@@ -99,11 +99,17 @@ createHttpRequest(const SharedHandle<Request>& req,
 
 bool HttpRequestCommand::executeInternal() {
   //socket->setBlockingMode();
+  if(req->getProtocol() == Request::PROTO_HTTPS) {
+    socket->prepareSecureConnection();
+    if(!socket->initiateSecureConnection()) {
+      setReadCheckSocketIf(socket, socket->wantRead());
+      setWriteCheckSocketIf(socket, socket->wantWrite());
+      e->commands.push_back(this);
+      return false;
+    }
+  }
   if(_httpConnection->sendBufferIsEmpty()) {
     checkIfConnectionEstablished(socket);
-    if(req->getProtocol() == Request::PROTO_HTTPS) {
-      socket->initiateSecureConnection();
-    }
 
     if(_segments.empty()) {
       HttpRequestHandle httpRequest
@@ -134,7 +140,8 @@ bool HttpRequestCommand::executeInternal() {
     e->commands.push_back(command);
     return true;
   } else {
-    setWriteCheckSocket(socket);
+    setReadCheckSocketIf(socket, socket->wantRead());
+    setWriteCheckSocketIf(socket, socket->wantWrite());
     e->commands.push_back(this);
     return false;
   }

+ 3 - 0
src/HttpResponseCommand.cc

@@ -83,6 +83,9 @@ bool HttpResponseCommand::executeInternal()
   HttpResponseHandle httpResponse = httpConnection->receiveResponse();
   if(httpResponse.isNull()) {
     // The server has not responded to our request yet.
+    // For socket->wantRead() == true, setReadCheckSocket(socket) is already
+    // done in the constructor.
+    setWriteCheckSocketIf(socket, socket->wantWrite());
     e->commands.push_back(this);
     return false;
   }

+ 16 - 10
src/HttpSkipResponseCommand.cc

@@ -96,7 +96,8 @@ bool HttpSkipResponseCommand::executeInternal()
       // The return value is safely ignored here.
       _transferEncodingDecoder->decode(buf, bufSize);
     }
-    if(_totalLength != 0 && bufSize == 0) {
+    if(_totalLength != 0 && bufSize == 0 &&
+       !socket->wantRead() && !socket->wantWrite()) {
       throw DlRetryEx(EX_GOT_EOF);
     }
   } catch(RecoverableException& e) {
@@ -104,15 +105,19 @@ bool HttpSkipResponseCommand::executeInternal()
     return processResponse();
   }
 
-  if(bufSize == 0) {
-    // Since this method is called by DownloadEngine only when the socket is
-    // readable, bufSize == 0 means server shutdown the connection.
-    // So socket cannot be reused in this case.
-    return prepareForRetry(0);
-  } else if((!_transferEncodingDecoder.isNull() &&
-	     _transferEncodingDecoder->finished())
-	    || (_transferEncodingDecoder.isNull() &&
-		_totalLength == _receivedBytes)) {
+  bool finished = false;
+  if(_transferEncodingDecoder.isNull()) {
+    if(bufSize == 0) {
+      if(!socket->wantRead() && !socket->wantWrite()) {
+	return processResponse();
+      }
+    } else {
+      finished = (_totalLength == _receivedBytes);
+    }
+  } else {
+    finished = _transferEncodingDecoder->finished();
+  }
+  if(finished) {
     if(!e->option->getAsBool(PREF_HTTP_PROXY_ENABLED) &&
        req->supportsPersistentConnection()) {
       std::pair<std::string, uint16_t> peerInfo;
@@ -121,6 +126,7 @@ bool HttpSkipResponseCommand::executeInternal()
     }
     return processResponse();
   } else {
+    setWriteCheckSocketIf(socket, socket->wantWrite());
     e->commands.push_back(this);
     return false;
   }

+ 8 - 2
src/MSEHandshake.cc

@@ -93,7 +93,7 @@ MSEHandshake::HANDSHAKE_TYPE MSEHandshake::identifyHandshakeType()
   }
   size_t r = 20-_rbufLength;
   _socket->readData(_rbuf+_rbufLength, r);
-  if(r == 0) {
+  if(r == 0 && !_socket->wantRead() && !_socket->wantWrite()) {
     throw DlAbortEx(EX_EOF_FROM_PEER);
   }
   _rbufLength += r;
@@ -301,6 +301,9 @@ bool MSEHandshake::findInitiatorVCMarker()
   }
   _socket->peekData(_rbuf+_rbufLength, r);
   if(r == 0) {
+    if(_socket->wantRead() || _socket->wantWrite()) {
+      return false;
+    }
     throw DlAbortEx(EX_EOF_FROM_PEER);
   }
   // find vc
@@ -388,6 +391,9 @@ bool MSEHandshake::findReceiverHashMarker()
   }
   _socket->peekData(_rbuf+_rbufLength, r);
   if(r == 0) {
+    if(_socket->wantRead() || _socket->wantWrite()) {
+      return false;
+    }
     throw DlAbortEx(EX_EOF_FROM_PEER);
   }
   // find hash('req1', S), S is _secret.
@@ -575,7 +581,7 @@ size_t MSEHandshake::receiveNBytes(size_t bytes)
       return 0;
     }
     _socket->readData(_rbuf+_rbufLength, r);
-    if(r == 0) {
+    if(r == 0 && !_socket->wantRead() && !_socket->wantWrite()) {
       throw DlAbortEx(EX_EOF_FROM_PEER);
     }
     _rbufLength += r;

+ 9 - 0
src/PeerConnection.cc

@@ -86,6 +86,9 @@ bool PeerConnection::receiveMessage(unsigned char* data, size_t& dataLength) {
     size_t temp = remaining;
     readData(lenbuf+lenbufLength, remaining, _encryptionEnabled);
     if(remaining == 0) {
+      if(socket->wantRead() || socket->wantWrite()) {
+	return false;
+      }
       // we got EOF
       logger->debug("CUID#%d - In PeerConnection::receiveMessage(), remain=%zu",
 		    cuid, temp);
@@ -111,6 +114,9 @@ bool PeerConnection::receiveMessage(unsigned char* data, size_t& dataLength) {
   if(remaining > 0) {
     readData(resbuf+resbufLength, remaining, _encryptionEnabled);
     if(remaining == 0) {
+      if(socket->wantRead() || socket->wantWrite()) {
+	return false;
+      }
       // we got EOF
       logger->debug("CUID#%d - In PeerConnection::receiveMessage(), payloadlen=%zu, remaining=%zu",
 		    cuid, currentPayloadLength, temp);
@@ -154,6 +160,9 @@ bool PeerConnection::receiveHandshake(unsigned char* data, size_t& dataLength,
       size_t temp = remaining;
       readData(resbuf+resbufLength, remaining, _encryptionEnabled);
       if(remaining == 0) {
+	if(socket->wantRead() || socket->wantWrite()) {
+	  return false;
+	}
 	// we got EOF
 	logger->debug("CUID#%d - In PeerConnection::receiveHandshake(), remain=%zu",
 		      cuid, temp);

+ 178 - 55
src/SocketCore.cc

@@ -59,7 +59,7 @@
 #else
 # define CLOSE(X) while(close(X) == -1 && errno == EINTR)
 #endif // __MINGW32__
-
+#include "LogFactory.h"
 namespace aria2 {
 
 SocketCore::SocketCore(int sockType):_sockType(sockType), sockfd(-1)  {
@@ -80,7 +80,11 @@ void SocketCore::init()
 #endif // HAVE_EPOLL
 
   blocking = true;
-  secure = false;
+  secure = 0;
+
+  _wantRead = false;
+  _wantWrite = false;
+
 #ifdef HAVE_LIBSSL
   // for SSL
   sslCtx = NULL;
@@ -440,42 +444,74 @@ bool SocketCore::isReadable(time_t timeout)
 
 }
 
+#ifdef HAVE_LIBSSL
+int SocketCore::sslHandleEAGAIN(int ret)
+{
+  int error = SSL_get_error(ssl, ret);
+  if(error == SSL_ERROR_WANT_READ || error == SSL_ERROR_WANT_WRITE) {
+    ret = 0;
+    if(error == SSL_ERROR_WANT_READ) {
+      _wantRead = true;
+    } else {
+      _wantWrite = true;
+    }
+  }
+  return ret;
+}
+#endif // HAVE_LIBSSL
+
+#ifdef HAVE_LIBGNUTLS
+void SocketCore::gnutlsRecordCheckDirection()
+{
+  int direction = gnutls_record_get_direction(sslSession);
+  if(direction == 0) {
+    _wantRead = true;
+  } else { // if(direction == 1) {
+    _wantWrite = true;
+  }
+}
+#endif // HAVE_LIBGNUTLS
+
 ssize_t SocketCore::writeData(const char* data, size_t len)
 {
   ssize_t ret = 0;
+  _wantRead = false;
+  _wantWrite = false;
 
   if(!secure) {
     while((ret = send(sockfd, data, len, 0)) == -1 && errno == EINTR);
-    if(ret == -1 && errno == EAGAIN) {
-      ret = 0;
-    }
     if(ret == -1) {
-      throw DlRetryEx(StringFormat(EX_SOCKET_SEND, errorMsg()).str());
+      if(errno == EAGAIN) {
+	_wantWrite = true;
+	ret = 0;
+      } else {
+	throw DlRetryEx(StringFormat(EX_SOCKET_SEND, errorMsg()).str());
+      }
     }
   } else {
 #ifdef HAVE_LIBSSL
-     // for SSL
-     // TODO handling len == 0 case required
     ret = SSL_write(ssl, data, len);
+    if(ret == 0) {
+      throw DlRetryEx
+	(StringFormat
+	 (EX_SOCKET_SEND, ERR_error_string(SSL_get_error(ssl, ret), 0)).str());
+    }
     if(ret < 0) {
-      switch(SSL_get_error(ssl, ret)) {
-      case SSL_ERROR_WANT_READ:
-      case SSL_ERROR_WANT_WRITE:
-	ret = 0;
-      }
+      ret = sslHandleEAGAIN(ret);
     }
-    if(ret <= 0) {
-      throw DlRetryEx(StringFormat(EX_SOCKET_SEND,
-				   ERR_error_string(ERR_get_error(), 0)).str());
+    if(ret < 0) {
+      throw DlRetryEx
+	(StringFormat
+	 (EX_SOCKET_SEND, ERR_error_string(SSL_get_error(ssl, ret), 0)).str());
     }
 #endif // HAVE_LIBSSL
 #ifdef HAVE_LIBGNUTLS
     while((ret = gnutls_record_send(sslSession, data, len)) ==
 	  GNUTLS_E_INTERRUPTED);
     if(ret == GNUTLS_E_AGAIN) {
+      gnutlsRecordCheckDirection();
       ret = 0;
-    }
-    if(ret < 0) {
+    } else if(ret < 0) {
       throw DlRetryEx(StringFormat(EX_SOCKET_SEND, gnutls_strerror(ret)).str());
     }
 #endif // HAVE_LIBGNUTLS
@@ -487,24 +523,45 @@ ssize_t SocketCore::writeData(const char* data, size_t len)
 void SocketCore::readData(char* data, size_t& len)
 {
   ssize_t ret = 0;
+  _wantRead = false;
+  _wantWrite = false;
 
   if(!secure) {    
     while((ret = recv(sockfd, data, len, 0)) == -1 && errno == EINTR);
+    
     if(ret == -1) {
-      throw DlRetryEx(StringFormat(EX_SOCKET_RECV, errorMsg()).str());
+      if(errno == EAGAIN) {
+	_wantRead = true;
+	ret = 0;
+      } else {
+	throw DlRetryEx(StringFormat(EX_SOCKET_RECV, errorMsg()).str());
+      }
     }
   } else {
 #ifdef HAVE_LIBSSL
      // for SSL
      // TODO handling len == 0 case required
-    if ((ret = SSL_read(ssl, data, len)) <= 0) {
+    ret = SSL_read(ssl, data, len);
+    if(ret == 0) {
       throw DlRetryEx
-	(StringFormat(EX_SOCKET_RECV,
-		      ERR_error_string(ERR_get_error(), 0)).str());
+	(StringFormat
+	 (EX_SOCKET_RECV, ERR_error_string(SSL_get_error(ssl, ret), 0)).str());
+    }
+    if(ret < 0) {
+      ret = sslHandleEAGAIN(ret);
+    }
+    if(ret < 0) {
+      throw DlRetryEx
+	(StringFormat
+	 (EX_SOCKET_RECV, ERR_error_string(SSL_get_error(ssl, ret), 0)).str());
     }
 #endif // HAVE_LIBSSL
 #ifdef HAVE_LIBGNUTLS
-    if ((ret = gnutlsRecv(data, len)) < 0) {
+    ret = gnutlsRecv(data, len);
+    if(ret == GNUTLS_E_AGAIN) {
+      gnutlsRecordCheckDirection();
+      ret = 0;
+    } else if(ret < 0) {
       throw DlRetryEx
 	(StringFormat(EX_SOCKET_RECV, gnutls_strerror(ret)).str());
     }
@@ -517,24 +574,45 @@ void SocketCore::readData(char* data, size_t& len)
 void SocketCore::peekData(char* data, size_t& len)
 {
   ssize_t ret = 0;
+  _wantRead = false;
+  _wantWrite = false;
 
   if(!secure) {
     while((ret = recv(sockfd, data, len, MSG_PEEK)) == -1 && errno == EINTR);
     if(ret == -1) {
-      throw DlRetryEx(StringFormat(EX_SOCKET_PEEK, errorMsg()).str());
+      if(errno == EAGAIN) {
+	_wantRead = true;
+	ret = 0;
+      } else {
+	throw DlRetryEx(StringFormat(EX_SOCKET_PEEK, errorMsg()).str());
+      }
     }
   } else {
 #ifdef HAVE_LIBSSL
      // for SSL
      // TODO handling len == 0 case required
-    if ((ret = SSL_peek(ssl, data, len)) < 0) {
+    ret = SSL_peek(ssl, data, len);
+    LogFactory::getInstance()->debug("len = %d", ret);
+    if(ret == 0) {
       throw DlRetryEx
 	(StringFormat(EX_SOCKET_PEEK,
-		      ERR_error_string(ERR_get_error(), 0)).str());
+		      ERR_error_string(SSL_get_error(ssl, ret), 0)).str());
+    }
+    if(ret < 0) {
+      ret = sslHandleEAGAIN(ret);
+    }
+    if(ret < 0) {
+      throw DlRetryEx
+	(StringFormat(EX_SOCKET_PEEK,
+		      ERR_error_string(SSL_get_error(ssl, ret), 0)).str());
     }
 #endif // HAVE_LIBSSL
 #ifdef HAVE_LIBGNUTLS
-    if ((ret = gnutlsPeek(data, len)) < 0) {
+    ret = gnutlsPeek(data, len);
+    if(ret == GNUTLS_E_AGAIN) {
+      gnutlsRecordCheckDirection();
+      ret = 0;
+    } else if(ret < 0) {
       throw DlRetryEx(StringFormat(EX_SOCKET_PEEK,
 				   gnutls_strerror(ret)).str());
     }
@@ -577,13 +655,27 @@ void SocketCore::addPeekData(char* data, size_t len)
   peekBufLength += len;
 }
 
+static ssize_t GNUTLS_RECORD_RECV_NO_INTERRUPT
+(gnutls_session_t sslSession, char* data, size_t len)
+{
+  int ret;
+  while((ret = gnutls_record_recv(sslSession, data, len)) ==
+	GNUTLS_E_INTERRUPTED);
+  if(ret < 0 && ret != GNUTLS_E_AGAIN) {
+    throw DlRetryEx
+      (StringFormat(EX_SOCKET_RECV, gnutls_strerror(ret)).str());
+  }
+  return ret;
+}
+
 ssize_t SocketCore::gnutlsRecv(char* data, size_t len)
 {
   size_t plen = shiftPeekData(data, len);
   if(plen < len) {
-    ssize_t ret = gnutls_record_recv(sslSession, data+plen, len-plen);
-    if(ret < 0) {
-      throw DlRetryEx(StringFormat(EX_SOCKET_RECV, gnutls_strerror(ret)).str());
+    ssize_t ret = GNUTLS_RECORD_RECV_NO_INTERRUPT
+      (sslSession, data+plen, len-plen);
+    if(ret == GNUTLS_E_AGAIN) {
+      return GNUTLS_E_AGAIN;
     }
     return plen+ret;
   } else {
@@ -598,9 +690,10 @@ ssize_t SocketCore::gnutlsPeek(char* data, size_t len)
     return len;
   } else {
     memcpy(data, peekBuf, peekBufLength);
-    ssize_t ret = gnutls_record_recv(sslSession, data+peekBufLength, len-peekBufLength);
-    if(ret < 0) {
-      throw DlRetryEx(StringFormat(EX_SOCKET_PEEK, gnutls_strerror(ret)).str());
+    ssize_t ret = GNUTLS_RECORD_RECV_NO_INTERRUPT
+      (sslSession, data+peekBufLength, len-peekBufLength);
+    if(ret == GNUTLS_E_AGAIN) {
+      return GNUTLS_E_AGAIN;
     }
     addPeekData(data+peekBufLength, ret);
     return peekBufLength;
@@ -608,11 +701,11 @@ ssize_t SocketCore::gnutlsPeek(char* data, size_t len)
 }
 #endif // HAVE_LIBGNUTLS
 
-void SocketCore::initiateSecureConnection()
+void SocketCore::prepareSecureConnection()
 {
+  if(!secure) {
 #ifdef HAVE_LIBSSL
   // for SSL
-  if(!secure) {
     sslCtx = SSL_CTX_new(SSLv23_client_method());
     if(sslCtx == NULL) {
       throw DlAbortEx
@@ -631,7 +724,31 @@ void SocketCore::initiateSecureConnection()
 	(StringFormat(EX_SSL_INIT_FAILURE,
 		      ERR_error_string(ERR_get_error(), 0)).str());
     }
-     // TODO handling return value == 0 case required
+#endif // HAVE_LIBSSL
+#ifdef HAVE_LIBGNUTLS
+    const int cert_type_priority[3] = { GNUTLS_CRT_X509,
+					GNUTLS_CRT_OPENPGP, 0
+    };
+    // while we do not support X509 certificate, most web servers require
+    // X509 stuff.
+    gnutls_certificate_allocate_credentials (&sslXcred);
+    gnutls_init(&sslSession, GNUTLS_CLIENT);
+    gnutls_set_default_priority(sslSession);
+    gnutls_kx_set_priority(sslSession, cert_type_priority);
+    // put the x509 credentials to the current session
+    gnutls_credentials_set(sslSession, GNUTLS_CRD_CERTIFICATE, sslXcred);
+    gnutls_transport_set_ptr(sslSession, (gnutls_transport_ptr_t)sockfd);
+#endif // HAVE_LIBGNUTLS
+    secure = 1;
+  }
+}
+
+bool SocketCore::initiateSecureConnection()
+{
+  if(secure == 1) {
+    _wantRead = false;
+    _wantWrite = false;
+#ifdef HAVE_LIBSSL
     int e = SSL_connect(ssl);
 
     if (e <= 0) {
@@ -641,7 +758,11 @@ void SocketCore::initiateSecureConnection()
           break;
 
         case SSL_ERROR_WANT_READ:
+	  _wantRead = true;
+	  return false;
         case SSL_ERROR_WANT_WRITE:
+	  _wantWrite = true;
+	  return false;
         case SSL_ERROR_WANT_X509_LOOKUP:
         case SSL_ERROR_ZERO_RETURN:
           if (blocking) {
@@ -661,32 +782,24 @@ void SocketCore::initiateSecureConnection()
 	    (StringFormat(EX_SSL_UNKNOWN_ERROR, ssl_error).str());
       }
     }
-  }
 #endif // HAVE_LIBSSL
 #ifdef HAVE_LIBGNUTLS
-  if(!secure) {
-    const int cert_type_priority[3] = { GNUTLS_CRT_X509,
-					GNUTLS_CRT_OPENPGP, 0
-    };
-    // while we do not support X509 certificate, most web servers require
-    // X509 stuff.
-    gnutls_certificate_allocate_credentials (&sslXcred);
-    gnutls_init(&sslSession, GNUTLS_CLIENT);
-    gnutls_set_default_priority(sslSession);
-    gnutls_kx_set_priority(sslSession, cert_type_priority);
-    // put the x509 credentials to the current session
-    gnutls_credentials_set(sslSession, GNUTLS_CRD_CERTIFICATE, sslXcred);
-    gnutls_transport_set_ptr(sslSession, (gnutls_transport_ptr_t)sockfd);
     int ret = gnutls_handshake(sslSession);
-    if(ret < 0) {
+    if(ret == GNUTLS_E_AGAIN) {
+      gnutlsRecordCheckDirection();
+      return false;
+    } else if(ret < 0) {
       throw DlAbortEx
 	(StringFormat(EX_SSL_INIT_FAILURE, gnutls_strerror(ret)).str());
+    } else {
+      peekBuf = new char[peekBufMax];
     }
-    peekBuf = new char[peekBufMax];
-  }
 #endif // HAVE_LIBGNUTLS
-
-  secure = true;
+    secure = 2;
+    return true;
+  } else {
+    return true;
+  }
 }
 
 /* static */ int SocketCore::error()
@@ -783,4 +896,14 @@ std::string SocketCore::getSocketError() const
   }
 }
 
+bool SocketCore::wantRead() const
+{
+  return _wantRead;
+}
+
+bool SocketCore::wantWrite() const
+{
+  return _wantWrite;
+}
+
 } // namespace aria2

+ 39 - 7
src/SocketCore.h

@@ -77,11 +77,17 @@ private:
 #endif // HAVE_EPOLL
 
   bool blocking;
-  bool secure;
+  int secure;
+
+  bool _wantRead;
+  bool _wantWrite;
+
 #ifdef HAVE_LIBSSL
   // for SSL
   SSL_CTX* sslCtx;
   SSL* ssl;
+
+  int sslHandleEAGAIN(int ret);
 #endif // HAVE_LIBSSL
 #ifdef HAVE_LIBGNUTLS
   gnutls_session_t sslSession;
@@ -94,6 +100,8 @@ private:
   void addPeekData(char* data, size_t len);
   ssize_t gnutlsRecv(char* data, size_t len);
   ssize_t gnutlsPeek(char* data, size_t len);
+
+  void gnutlsRecordCheckDirection();
 #endif // HAVE_LIBGNUTLS
 
   void init();
@@ -105,6 +113,7 @@ private:
 #endif // HAVE_EPOLL
 
   SocketCore(sock_t sockfd, int sockType);
+
   static int error();
   static const char *errorMsg();
   static const char *errorMsg(const int err);
@@ -189,10 +198,14 @@ public:
   bool isReadable(time_t timeout);
 
   /**
-   * Writes characters into this socket. data is a pointer pointing the first
+   * Writes data into this socket. data is a pointer pointing the first
    * byte of the data and len is the length of data.
-   * This method internally calls isWritable(). The parmeter timeout is used
-   * for this method call.
+   * If the underlying socket is in blocking mode, this method may block until
+   * all data is sent.
+   * If the underlying socket is in non-blocking mode, this method may return
+   * even if all data is sent. The size of written data is returned. If
+   * underlying socket gets EAGAIN, _wantRead or _wantWrite is set accordingly.
+   * This method sets _wantRead and _wantWrite to false before do anything else.
    * @param data data to write
    * @param len length of data
    */
@@ -220,8 +233,12 @@ public:
    * byte of the data, which must be allocated before this method is called.
    * len is the size of the allocated memory. When this method returns
    * successfully, len is replaced by the size of the read data.
-   * This method internally calls isReadable(). The parameter timeout is used
-   * for this method call.
+   * If the underlying socket is in blocking mode, this method may block until
+   * at least 1byte is received.
+   * If the underlying socket is in non-blocking mode, this method may return
+   * even if no single byte is received. If the underlying socket gets EAGAIN,
+   * _wantRead or _wantWrite is set accordingly.
+   * This method sets _wantRead and _wantWrite to false before do anything else.
    * @param data holder to store data.
    * @param len the maximum size data can store. This method assigns
    * the number of bytes read to len.
@@ -265,7 +282,9 @@ public:
    * If the system has not OpenSSL, then this method do nothing.
    * connection must be established  before calling this method.
    */
-  void initiateSecureConnection();
+  bool initiateSecureConnection();
+
+  void prepareSecureConnection();
 
   bool operator==(const SocketCore& s) {
     return sockfd == s.sockfd;
@@ -280,6 +299,19 @@ public:
   }
 
   std::string getSocketError() const;
+
+  /**
+   * Returns true if the underlying socket gets EAGAIN in the previous
+   * readData() or writeData() and the socket needs more incoming data to
+   * continue the operation.
+   */
+  bool wantRead() const;
+
+  /**
+   * Returns true if the underlying socket gets EAGAIN in the previous
+   * readData() or writeData() and the socket needs to write more data.
+   */
+  bool wantWrite() const;
 };
 
 } // namespace aria2