|
@@ -59,7 +59,7 @@
|
|
|
#else
|
|
|
# define CLOSE(X) while(close(X) == -1 && errno == EINTR)
|
|
|
#endif // __MINGW32__
|
|
|
-
|
|
|
+#include "LogFactory.h"
|
|
|
namespace aria2 {
|
|
|
|
|
|
SocketCore::SocketCore(int sockType):_sockType(sockType), sockfd(-1) {
|
|
@@ -80,7 +80,11 @@ void SocketCore::init()
|
|
|
#endif // HAVE_EPOLL
|
|
|
|
|
|
blocking = true;
|
|
|
- secure = false;
|
|
|
+ secure = 0;
|
|
|
+
|
|
|
+ _wantRead = false;
|
|
|
+ _wantWrite = false;
|
|
|
+
|
|
|
#ifdef HAVE_LIBSSL
|
|
|
// for SSL
|
|
|
sslCtx = NULL;
|
|
@@ -440,42 +444,74 @@ bool SocketCore::isReadable(time_t timeout)
|
|
|
|
|
|
}
|
|
|
|
|
|
+#ifdef HAVE_LIBSSL
|
|
|
+int SocketCore::sslHandleEAGAIN(int ret)
|
|
|
+{
|
|
|
+ int error = SSL_get_error(ssl, ret);
|
|
|
+ if(error == SSL_ERROR_WANT_READ || error == SSL_ERROR_WANT_WRITE) {
|
|
|
+ ret = 0;
|
|
|
+ if(error == SSL_ERROR_WANT_READ) {
|
|
|
+ _wantRead = true;
|
|
|
+ } else {
|
|
|
+ _wantWrite = true;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return ret;
|
|
|
+}
|
|
|
+#endif // HAVE_LIBSSL
|
|
|
+
|
|
|
+#ifdef HAVE_LIBGNUTLS
|
|
|
+void SocketCore::gnutlsRecordCheckDirection()
|
|
|
+{
|
|
|
+ int direction = gnutls_record_get_direction(sslSession);
|
|
|
+ if(direction == 0) {
|
|
|
+ _wantRead = true;
|
|
|
+ } else { // if(direction == 1) {
|
|
|
+ _wantWrite = true;
|
|
|
+ }
|
|
|
+}
|
|
|
+#endif // HAVE_LIBGNUTLS
|
|
|
+
|
|
|
ssize_t SocketCore::writeData(const char* data, size_t len)
|
|
|
{
|
|
|
ssize_t ret = 0;
|
|
|
+ _wantRead = false;
|
|
|
+ _wantWrite = false;
|
|
|
|
|
|
if(!secure) {
|
|
|
while((ret = send(sockfd, data, len, 0)) == -1 && errno == EINTR);
|
|
|
- if(ret == -1 && errno == EAGAIN) {
|
|
|
- ret = 0;
|
|
|
- }
|
|
|
if(ret == -1) {
|
|
|
- throw DlRetryEx(StringFormat(EX_SOCKET_SEND, errorMsg()).str());
|
|
|
+ if(errno == EAGAIN) {
|
|
|
+ _wantWrite = true;
|
|
|
+ ret = 0;
|
|
|
+ } else {
|
|
|
+ throw DlRetryEx(StringFormat(EX_SOCKET_SEND, errorMsg()).str());
|
|
|
+ }
|
|
|
}
|
|
|
} else {
|
|
|
#ifdef HAVE_LIBSSL
|
|
|
- // for SSL
|
|
|
- // TODO handling len == 0 case required
|
|
|
ret = SSL_write(ssl, data, len);
|
|
|
+ if(ret == 0) {
|
|
|
+ throw DlRetryEx
|
|
|
+ (StringFormat
|
|
|
+ (EX_SOCKET_SEND, ERR_error_string(SSL_get_error(ssl, ret), 0)).str());
|
|
|
+ }
|
|
|
if(ret < 0) {
|
|
|
- switch(SSL_get_error(ssl, ret)) {
|
|
|
- case SSL_ERROR_WANT_READ:
|
|
|
- case SSL_ERROR_WANT_WRITE:
|
|
|
- ret = 0;
|
|
|
- }
|
|
|
+ ret = sslHandleEAGAIN(ret);
|
|
|
}
|
|
|
- if(ret <= 0) {
|
|
|
- throw DlRetryEx(StringFormat(EX_SOCKET_SEND,
|
|
|
- ERR_error_string(ERR_get_error(), 0)).str());
|
|
|
+ if(ret < 0) {
|
|
|
+ throw DlRetryEx
|
|
|
+ (StringFormat
|
|
|
+ (EX_SOCKET_SEND, ERR_error_string(SSL_get_error(ssl, ret), 0)).str());
|
|
|
}
|
|
|
#endif // HAVE_LIBSSL
|
|
|
#ifdef HAVE_LIBGNUTLS
|
|
|
while((ret = gnutls_record_send(sslSession, data, len)) ==
|
|
|
GNUTLS_E_INTERRUPTED);
|
|
|
if(ret == GNUTLS_E_AGAIN) {
|
|
|
+ gnutlsRecordCheckDirection();
|
|
|
ret = 0;
|
|
|
- }
|
|
|
- if(ret < 0) {
|
|
|
+ } else if(ret < 0) {
|
|
|
throw DlRetryEx(StringFormat(EX_SOCKET_SEND, gnutls_strerror(ret)).str());
|
|
|
}
|
|
|
#endif // HAVE_LIBGNUTLS
|
|
@@ -487,24 +523,45 @@ ssize_t SocketCore::writeData(const char* data, size_t len)
|
|
|
void SocketCore::readData(char* data, size_t& len)
|
|
|
{
|
|
|
ssize_t ret = 0;
|
|
|
+ _wantRead = false;
|
|
|
+ _wantWrite = false;
|
|
|
|
|
|
if(!secure) {
|
|
|
while((ret = recv(sockfd, data, len, 0)) == -1 && errno == EINTR);
|
|
|
+
|
|
|
if(ret == -1) {
|
|
|
- throw DlRetryEx(StringFormat(EX_SOCKET_RECV, errorMsg()).str());
|
|
|
+ if(errno == EAGAIN) {
|
|
|
+ _wantRead = true;
|
|
|
+ ret = 0;
|
|
|
+ } else {
|
|
|
+ throw DlRetryEx(StringFormat(EX_SOCKET_RECV, errorMsg()).str());
|
|
|
+ }
|
|
|
}
|
|
|
} else {
|
|
|
#ifdef HAVE_LIBSSL
|
|
|
// for SSL
|
|
|
// TODO handling len == 0 case required
|
|
|
- if ((ret = SSL_read(ssl, data, len)) <= 0) {
|
|
|
+ ret = SSL_read(ssl, data, len);
|
|
|
+ if(ret == 0) {
|
|
|
throw DlRetryEx
|
|
|
- (StringFormat(EX_SOCKET_RECV,
|
|
|
- ERR_error_string(ERR_get_error(), 0)).str());
|
|
|
+ (StringFormat
|
|
|
+ (EX_SOCKET_RECV, ERR_error_string(SSL_get_error(ssl, ret), 0)).str());
|
|
|
+ }
|
|
|
+ if(ret < 0) {
|
|
|
+ ret = sslHandleEAGAIN(ret);
|
|
|
+ }
|
|
|
+ if(ret < 0) {
|
|
|
+ throw DlRetryEx
|
|
|
+ (StringFormat
|
|
|
+ (EX_SOCKET_RECV, ERR_error_string(SSL_get_error(ssl, ret), 0)).str());
|
|
|
}
|
|
|
#endif // HAVE_LIBSSL
|
|
|
#ifdef HAVE_LIBGNUTLS
|
|
|
- if ((ret = gnutlsRecv(data, len)) < 0) {
|
|
|
+ ret = gnutlsRecv(data, len);
|
|
|
+ if(ret == GNUTLS_E_AGAIN) {
|
|
|
+ gnutlsRecordCheckDirection();
|
|
|
+ ret = 0;
|
|
|
+ } else if(ret < 0) {
|
|
|
throw DlRetryEx
|
|
|
(StringFormat(EX_SOCKET_RECV, gnutls_strerror(ret)).str());
|
|
|
}
|
|
@@ -517,24 +574,45 @@ void SocketCore::readData(char* data, size_t& len)
|
|
|
void SocketCore::peekData(char* data, size_t& len)
|
|
|
{
|
|
|
ssize_t ret = 0;
|
|
|
+ _wantRead = false;
|
|
|
+ _wantWrite = false;
|
|
|
|
|
|
if(!secure) {
|
|
|
while((ret = recv(sockfd, data, len, MSG_PEEK)) == -1 && errno == EINTR);
|
|
|
if(ret == -1) {
|
|
|
- throw DlRetryEx(StringFormat(EX_SOCKET_PEEK, errorMsg()).str());
|
|
|
+ if(errno == EAGAIN) {
|
|
|
+ _wantRead = true;
|
|
|
+ ret = 0;
|
|
|
+ } else {
|
|
|
+ throw DlRetryEx(StringFormat(EX_SOCKET_PEEK, errorMsg()).str());
|
|
|
+ }
|
|
|
}
|
|
|
} else {
|
|
|
#ifdef HAVE_LIBSSL
|
|
|
// for SSL
|
|
|
// TODO handling len == 0 case required
|
|
|
- if ((ret = SSL_peek(ssl, data, len)) < 0) {
|
|
|
+ ret = SSL_peek(ssl, data, len);
|
|
|
+ LogFactory::getInstance()->debug("len = %d", ret);
|
|
|
+ if(ret == 0) {
|
|
|
throw DlRetryEx
|
|
|
(StringFormat(EX_SOCKET_PEEK,
|
|
|
- ERR_error_string(ERR_get_error(), 0)).str());
|
|
|
+ ERR_error_string(SSL_get_error(ssl, ret), 0)).str());
|
|
|
+ }
|
|
|
+ if(ret < 0) {
|
|
|
+ ret = sslHandleEAGAIN(ret);
|
|
|
+ }
|
|
|
+ if(ret < 0) {
|
|
|
+ throw DlRetryEx
|
|
|
+ (StringFormat(EX_SOCKET_PEEK,
|
|
|
+ ERR_error_string(SSL_get_error(ssl, ret), 0)).str());
|
|
|
}
|
|
|
#endif // HAVE_LIBSSL
|
|
|
#ifdef HAVE_LIBGNUTLS
|
|
|
- if ((ret = gnutlsPeek(data, len)) < 0) {
|
|
|
+ ret = gnutlsPeek(data, len);
|
|
|
+ if(ret == GNUTLS_E_AGAIN) {
|
|
|
+ gnutlsRecordCheckDirection();
|
|
|
+ ret = 0;
|
|
|
+ } else if(ret < 0) {
|
|
|
throw DlRetryEx(StringFormat(EX_SOCKET_PEEK,
|
|
|
gnutls_strerror(ret)).str());
|
|
|
}
|
|
@@ -577,13 +655,27 @@ void SocketCore::addPeekData(char* data, size_t len)
|
|
|
peekBufLength += len;
|
|
|
}
|
|
|
|
|
|
+static ssize_t GNUTLS_RECORD_RECV_NO_INTERRUPT
|
|
|
+(gnutls_session_t sslSession, char* data, size_t len)
|
|
|
+{
|
|
|
+ int ret;
|
|
|
+ while((ret = gnutls_record_recv(sslSession, data, len)) ==
|
|
|
+ GNUTLS_E_INTERRUPTED);
|
|
|
+ if(ret < 0 && ret != GNUTLS_E_AGAIN) {
|
|
|
+ throw DlRetryEx
|
|
|
+ (StringFormat(EX_SOCKET_RECV, gnutls_strerror(ret)).str());
|
|
|
+ }
|
|
|
+ return ret;
|
|
|
+}
|
|
|
+
|
|
|
ssize_t SocketCore::gnutlsRecv(char* data, size_t len)
|
|
|
{
|
|
|
size_t plen = shiftPeekData(data, len);
|
|
|
if(plen < len) {
|
|
|
- ssize_t ret = gnutls_record_recv(sslSession, data+plen, len-plen);
|
|
|
- if(ret < 0) {
|
|
|
- throw DlRetryEx(StringFormat(EX_SOCKET_RECV, gnutls_strerror(ret)).str());
|
|
|
+ ssize_t ret = GNUTLS_RECORD_RECV_NO_INTERRUPT
|
|
|
+ (sslSession, data+plen, len-plen);
|
|
|
+ if(ret == GNUTLS_E_AGAIN) {
|
|
|
+ return GNUTLS_E_AGAIN;
|
|
|
}
|
|
|
return plen+ret;
|
|
|
} else {
|
|
@@ -598,9 +690,10 @@ ssize_t SocketCore::gnutlsPeek(char* data, size_t len)
|
|
|
return len;
|
|
|
} else {
|
|
|
memcpy(data, peekBuf, peekBufLength);
|
|
|
- ssize_t ret = gnutls_record_recv(sslSession, data+peekBufLength, len-peekBufLength);
|
|
|
- if(ret < 0) {
|
|
|
- throw DlRetryEx(StringFormat(EX_SOCKET_PEEK, gnutls_strerror(ret)).str());
|
|
|
+ ssize_t ret = GNUTLS_RECORD_RECV_NO_INTERRUPT
|
|
|
+ (sslSession, data+peekBufLength, len-peekBufLength);
|
|
|
+ if(ret == GNUTLS_E_AGAIN) {
|
|
|
+ return GNUTLS_E_AGAIN;
|
|
|
}
|
|
|
addPeekData(data+peekBufLength, ret);
|
|
|
return peekBufLength;
|
|
@@ -608,11 +701,11 @@ ssize_t SocketCore::gnutlsPeek(char* data, size_t len)
|
|
|
}
|
|
|
#endif // HAVE_LIBGNUTLS
|
|
|
|
|
|
-void SocketCore::initiateSecureConnection()
|
|
|
+void SocketCore::prepareSecureConnection()
|
|
|
{
|
|
|
+ if(!secure) {
|
|
|
#ifdef HAVE_LIBSSL
|
|
|
// for SSL
|
|
|
- if(!secure) {
|
|
|
sslCtx = SSL_CTX_new(SSLv23_client_method());
|
|
|
if(sslCtx == NULL) {
|
|
|
throw DlAbortEx
|
|
@@ -631,7 +724,31 @@ void SocketCore::initiateSecureConnection()
|
|
|
(StringFormat(EX_SSL_INIT_FAILURE,
|
|
|
ERR_error_string(ERR_get_error(), 0)).str());
|
|
|
}
|
|
|
- // TODO handling return value == 0 case required
|
|
|
+#endif // HAVE_LIBSSL
|
|
|
+#ifdef HAVE_LIBGNUTLS
|
|
|
+ const int cert_type_priority[3] = { GNUTLS_CRT_X509,
|
|
|
+ GNUTLS_CRT_OPENPGP, 0
|
|
|
+ };
|
|
|
+ // while we do not support X509 certificate, most web servers require
|
|
|
+ // X509 stuff.
|
|
|
+ gnutls_certificate_allocate_credentials (&sslXcred);
|
|
|
+ gnutls_init(&sslSession, GNUTLS_CLIENT);
|
|
|
+ gnutls_set_default_priority(sslSession);
|
|
|
+ gnutls_kx_set_priority(sslSession, cert_type_priority);
|
|
|
+ // put the x509 credentials to the current session
|
|
|
+ gnutls_credentials_set(sslSession, GNUTLS_CRD_CERTIFICATE, sslXcred);
|
|
|
+ gnutls_transport_set_ptr(sslSession, (gnutls_transport_ptr_t)sockfd);
|
|
|
+#endif // HAVE_LIBGNUTLS
|
|
|
+ secure = 1;
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+bool SocketCore::initiateSecureConnection()
|
|
|
+{
|
|
|
+ if(secure == 1) {
|
|
|
+ _wantRead = false;
|
|
|
+ _wantWrite = false;
|
|
|
+#ifdef HAVE_LIBSSL
|
|
|
int e = SSL_connect(ssl);
|
|
|
|
|
|
if (e <= 0) {
|
|
@@ -641,7 +758,11 @@ void SocketCore::initiateSecureConnection()
|
|
|
break;
|
|
|
|
|
|
case SSL_ERROR_WANT_READ:
|
|
|
+ _wantRead = true;
|
|
|
+ return false;
|
|
|
case SSL_ERROR_WANT_WRITE:
|
|
|
+ _wantWrite = true;
|
|
|
+ return false;
|
|
|
case SSL_ERROR_WANT_X509_LOOKUP:
|
|
|
case SSL_ERROR_ZERO_RETURN:
|
|
|
if (blocking) {
|
|
@@ -661,32 +782,24 @@ void SocketCore::initiateSecureConnection()
|
|
|
(StringFormat(EX_SSL_UNKNOWN_ERROR, ssl_error).str());
|
|
|
}
|
|
|
}
|
|
|
- }
|
|
|
#endif // HAVE_LIBSSL
|
|
|
#ifdef HAVE_LIBGNUTLS
|
|
|
- if(!secure) {
|
|
|
- const int cert_type_priority[3] = { GNUTLS_CRT_X509,
|
|
|
- GNUTLS_CRT_OPENPGP, 0
|
|
|
- };
|
|
|
- // while we do not support X509 certificate, most web servers require
|
|
|
- // X509 stuff.
|
|
|
- gnutls_certificate_allocate_credentials (&sslXcred);
|
|
|
- gnutls_init(&sslSession, GNUTLS_CLIENT);
|
|
|
- gnutls_set_default_priority(sslSession);
|
|
|
- gnutls_kx_set_priority(sslSession, cert_type_priority);
|
|
|
- // put the x509 credentials to the current session
|
|
|
- gnutls_credentials_set(sslSession, GNUTLS_CRD_CERTIFICATE, sslXcred);
|
|
|
- gnutls_transport_set_ptr(sslSession, (gnutls_transport_ptr_t)sockfd);
|
|
|
int ret = gnutls_handshake(sslSession);
|
|
|
- if(ret < 0) {
|
|
|
+ if(ret == GNUTLS_E_AGAIN) {
|
|
|
+ gnutlsRecordCheckDirection();
|
|
|
+ return false;
|
|
|
+ } else if(ret < 0) {
|
|
|
throw DlAbortEx
|
|
|
(StringFormat(EX_SSL_INIT_FAILURE, gnutls_strerror(ret)).str());
|
|
|
+ } else {
|
|
|
+ peekBuf = new char[peekBufMax];
|
|
|
}
|
|
|
- peekBuf = new char[peekBufMax];
|
|
|
- }
|
|
|
#endif // HAVE_LIBGNUTLS
|
|
|
-
|
|
|
- secure = true;
|
|
|
+ secure = 2;
|
|
|
+ return true;
|
|
|
+ } else {
|
|
|
+ return true;
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
/* static */ int SocketCore::error()
|
|
@@ -783,4 +896,14 @@ std::string SocketCore::getSocketError() const
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+bool SocketCore::wantRead() const
|
|
|
+{
|
|
|
+ return _wantRead;
|
|
|
+}
|
|
|
+
|
|
|
+bool SocketCore::wantWrite() const
|
|
|
+{
|
|
|
+ return _wantWrite;
|
|
|
+}
|
|
|
+
|
|
|
} // namespace aria2
|