Ver Fonte

RPC over SSL/TLS transport

To enable RPC over SSL/TLS, specify server certificate and private key
using --rpc-certificate and --rpc-private-key options and enable
--rpc-secure option.  After the encryption is enabled, use https and
wss scheme to access RPC server.
Tatsuhiro Tsujikawa há 13 anos atrás
pai
commit
90515dfa50

+ 41 - 8
src/AbstractHttpServerResponseCommand.cc

@@ -55,15 +55,44 @@ AbstractHttpServerResponseCommand::AbstractHttpServerResponseCommand
  : Command(cuid),
    e_(e),
    socket_(socket),
-   httpServer_(httpServer)
+   httpServer_(httpServer),
+   readCheck_(false),
+   writeCheck_(true)
 {
-  setStatus(Command::STATUS_ONESHOT_REALTIME); 
+  setStatus(Command::STATUS_ONESHOT_REALTIME);
   e_->addSocketForWriteCheck(socket_, this);
 }
 
 AbstractHttpServerResponseCommand::~AbstractHttpServerResponseCommand()
 {
-  e_->deleteSocketForWriteCheck(socket_, this);
+  if(readCheck_) {
+    e_->deleteSocketForReadCheck(socket_, this);
+  }
+  if(writeCheck_) {
+    e_->deleteSocketForWriteCheck(socket_, this);
+  }
+}
+
+void AbstractHttpServerResponseCommand::updateReadWriteCheck()
+{
+  if(httpServer_->wantRead()) {
+    if(!readCheck_) {
+      readCheck_ = true;
+      e_->addSocketForReadCheck(socket_, this);
+    }
+  } else if(readCheck_) {
+    readCheck_ = false;
+    e_->deleteSocketForReadCheck(socket_, this);
+  }
+  if(httpServer_->wantWrite()) {
+    if(!writeCheck_) {
+      writeCheck_ = true;
+      e_->addSocketForWriteCheck(socket_, this);
+    }
+  } else if(writeCheck_) {
+    writeCheck_ = false;
+    e_->deleteSocketForWriteCheck(socket_, this);
+  }
 }
 
 bool AbstractHttpServerResponseCommand::execute()
@@ -72,26 +101,30 @@ bool AbstractHttpServerResponseCommand::execute()
     return true;
   }
   try {
-    httpServer_->sendResponse();
+    ssize_t len = httpServer_->sendResponse();
+    if(len > 0) {
+      timeoutTimer_ = global::wallclock();
+    }
   } catch(RecoverableException& e) {
     A2_LOG_INFO_EX
-      (fmt("CUID#%" PRId64 " - Error occurred while transmitting response body.",
+      (fmt("CUID#%"PRId64" - Error occurred while transmitting response body.",
            getCuid()),
        e);
     return true;
   }
   if(httpServer_->sendBufferIsEmpty()) {
-    A2_LOG_INFO(fmt("CUID#%" PRId64 " - HttpServer: all response transmitted.",
+    A2_LOG_INFO(fmt("CUID#%"PRId64" - HttpServer: all response transmitted.",
                     getCuid()));
     afterSend(httpServer_, e_);
     return true;
   } else {
-    if(timeoutTimer_.difference(global::wallclock()) >= 10) {
-      A2_LOG_INFO(fmt("CUID#%" PRId64 " - HttpServer: Timeout while trasmitting"
+    if(timeoutTimer_.difference(global::wallclock()) >= 30) {
+      A2_LOG_INFO(fmt("CUID#%"PRId64" - HttpServer: Timeout while trasmitting"
                       " response.",
                       getCuid()));
       return true;
     } else {
+      updateReadWriteCheck();
       e_->addCommand(this);
       return false;
     }

+ 6 - 2
src/AbstractHttpServerResponseCommand.h

@@ -51,6 +51,10 @@ private:
   SharedHandle<SocketCore> socket_;
   SharedHandle<HttpServer> httpServer_;
   Timer timeoutTimer_;
+  bool readCheck_;
+  bool writeCheck_;
+
+  void updateReadWriteCheck();
 protected:
   DownloadEngine* getDownloadEngine()
   {
@@ -66,10 +70,10 @@ public:
                                     const SharedHandle<SocketCore>& socket);
 
   virtual ~AbstractHttpServerResponseCommand();
-  
+
   virtual bool execute();
 };
 
-} // namespace aria2 
+} // namespace aria2
 
 #endif // D_ABSTRACT_HTTP_SERVER_RESPONSE_COMMAND_H

+ 6 - 1
src/DownloadEngineFactory.cc

@@ -74,6 +74,7 @@
 #include "DlAbortEx.h"
 #include "FileAllocationEntry.h"
 #include "HttpListenCommand.h"
+#include "LogFactory.h"
 
 namespace aria2 {
 
@@ -170,11 +171,15 @@ DownloadEngineFactory::newDownloadEngine
   }
   if(op->getAsBool(PREF_ENABLE_RPC)) {
     bool ok = false;
+    bool secure = op->getAsBool(PREF_RPC_SECURE);
+    if(secure) {
+      A2_LOG_NOTICE("RPC transport will be encrypted.");
+    }
     static int families[] = { AF_INET, AF_INET6 };
     size_t familiesLength = op->getAsBool(PREF_DISABLE_IPV6)?1:2;
     for(size_t i = 0; i < familiesLength; ++i) {
       HttpListenCommand* httpListenCommand =
-        new HttpListenCommand(e->newCUID(), e.get(), families[i]);
+        new HttpListenCommand(e->newCUID(), e.get(), families[i], secure);
       if(httpListenCommand->bindPort(op->getAsInt(PREF_RPC_LISTEN_PORT))){
         e->addCommand(httpListenCommand);
         ok = true;

+ 5 - 3
src/HttpListenCommand.cc

@@ -50,10 +50,12 @@
 
 namespace aria2 {
 
-HttpListenCommand::HttpListenCommand(cuid_t cuid, DownloadEngine* e, int family)
+HttpListenCommand::HttpListenCommand(cuid_t cuid, DownloadEngine* e,
+                                     int family, bool secure)
   : Command(cuid),
     e_(e),
-    family_(family)
+    family_(family),
+    secure_(secure)
 {}
 
 HttpListenCommand::~HttpListenCommand()
@@ -80,7 +82,7 @@ bool HttpListenCommand::execute()
                       peerInfo.first.c_str(), peerInfo.second));
 
       HttpServerCommand* c =
-        new HttpServerCommand(e_->newCUID(), e_, socket);
+        new HttpServerCommand(e_->newCUID(), e_, socket, secure_);
       e_->setNoWait(true);
       e_->addCommand(c);
     }

+ 4 - 3
src/HttpListenCommand.h

@@ -48,16 +48,17 @@ private:
   DownloadEngine* e_;
   int family_;
   SharedHandle<SocketCore> serverSocket_;
+  bool secure_;
 public:
-  HttpListenCommand(cuid_t cuid, DownloadEngine* e, int family);
+  HttpListenCommand(cuid_t cuid, DownloadEngine* e, int family, bool secure);
 
   virtual ~HttpListenCommand();
-  
+
   virtual bool execute();
 
   bool bindPort(uint16_t port);
 };
 
-} // namespace aria2 
+} // namespace aria2
 
 #endif // D_HTTP_LISTEN_COMMAND_H

+ 1 - 2
src/HttpRequestCommand.cc

@@ -123,8 +123,7 @@ createHttpRequest(const SharedHandle<Request>& req,
 bool HttpRequestCommand::executeInternal() {
   //socket->setBlockingMode();
   if(getRequest()->getProtocol() == "https") {
-    getSocket()->prepareSecureConnection();
-    if(!getSocket()->initiateSecureConnection(getRequest()->getHost())) {
+    if(!getSocket()->tlsConnect(getRequest()->getHost())) {
       setReadCheckSocketIf(getSocket(), getSocket()->wantRead());
       setWriteCheckSocketIf(getSocket(), getSocket()->wantWrite());
       getDownloadEngine()->addCommand(this);

+ 20 - 7
src/HttpServer.cc

@@ -148,13 +148,16 @@ SharedHandle<HttpHeader> HttpServer::receiveRequest()
     if(setupResponseRecv() < 0) {
       A2_LOG_INFO("Request path is invaild. Ignore the request body.");
     }
-    if(!util::parseLLIntNoThrow(lastContentLength_,
-                                lastRequestHeader_->
-                                find(HttpHeader::CONTENT_LENGTH)) ||
-       lastContentLength_ < 0) {
-      throw DL_ABORT_EX(fmt("Invalid Content-Length=%s",
-                            lastRequestHeader_->
-                            find(HttpHeader::CONTENT_LENGTH).c_str()));
+    const std::string& contentLengthHdr = lastRequestHeader_->
+      find(HttpHeader::CONTENT_LENGTH);
+    if(!contentLengthHdr.empty()) {
+      if(!util::parseLLIntNoThrow(lastContentLength_, contentLengthHdr) ||
+         lastContentLength_ < 0) {
+        throw DL_ABORT_EX(fmt("Invalid Content-Length=%s",
+                              contentLengthHdr.c_str()));
+      }
+    } else {
+      lastContentLength_ = 0;
     }
     headerProcessor_->clear();
 
@@ -386,4 +389,14 @@ bool HttpServer::supportsPersistentConnection() const
     lastRequestHeader_ && lastRequestHeader_->isKeepAlive();
 }
 
+bool HttpServer::wantRead() const
+{
+  return socket_->wantRead();
+}
+
+bool HttpServer::wantWrite() const
+{
+  return socket_->wantWrite();
+}
+
 } // namespace aria2

+ 14 - 0
src/HttpServer.h

@@ -82,6 +82,7 @@ private:
   std::string password_;
   bool acceptsGZip_;
   std::string allowOrigin_;
+  bool secure_;
 public:
   HttpServer(const SharedHandle<SocketCore>& socket, DownloadEngine* e);
 
@@ -178,6 +179,19 @@ public:
   {
     return lastRequestHeader_;
   }
+
+  void setSecure(bool f)
+  {
+    secure_ = f;
+  }
+
+  bool getSecure() const
+  {
+    return secure_;
+  }
+
+  bool wantRead() const;
+  bool wantWrite() const;
 };
 
 } // namespace aria2

+ 21 - 2
src/HttpServerBodyCommand.cc

@@ -74,7 +74,8 @@ HttpServerBodyCommand::HttpServerBodyCommand
   : Command(cuid),
     e_(e),
     socket_(socket),
-    httpServer_(httpServer)
+    httpServer_(httpServer),
+    writeCheck_(false)
 {
   // To handle Content-Length == 0 case
   setStatus(Command::STATUS_ONESHOT_REALTIME);
@@ -87,6 +88,9 @@ HttpServerBodyCommand::HttpServerBodyCommand
 HttpServerBodyCommand::~HttpServerBodyCommand()
 {
   e_->deleteSocketForReadCheck(socket_, this);
+  if(writeCheck_) {
+    e_->deleteSocketForWriteCheck(socket_, this);
+  }
 }
 
 namespace {
@@ -144,6 +148,19 @@ void HttpServerBodyCommand::addHttpServerResponseCommand()
   e_->setNoWait(true);
 }
 
+void HttpServerBodyCommand::updateWriteCheck()
+{
+  if(httpServer_->wantWrite()) {
+    if(!writeCheck_) {
+      writeCheck_ = true;
+      e_->addSocketForWriteCheck(socket_, this);
+    }
+  } else if(writeCheck_) {
+    writeCheck_ = false;
+    e_->deleteSocketForWriteCheck(socket_, this);
+  }
+}
+
 bool HttpServerBodyCommand::execute()
 {
   if(e_->getRequestGroupMan()->downloadFinished() || e_->isHaltRequested()) {
@@ -151,6 +168,7 @@ bool HttpServerBodyCommand::execute()
   }
   try {
     if(socket_->isReadable(0) ||
+       (writeCheck_ && socket_->isWritable(0)) ||
        !httpServer_->getSocketRecvBuffer()->bufferEmpty() ||
        httpServer_->getContentLength() == 0) {
       timeoutTimer_ = global::wallclock();
@@ -290,9 +308,10 @@ bool HttpServerBodyCommand::execute()
           return true;
         }
       } else {
+        updateWriteCheck();
         e_->addCommand(this);
         return false;
-      } 
+      }
     } else {
       if(timeoutTimer_.difference(global::wallclock()) >= 30) {
         A2_LOG_INFO("HTTP request body timeout.");

+ 5 - 2
src/HttpServerBodyCommand.h

@@ -53,6 +53,8 @@ private:
   SharedHandle<SocketCore> socket_;
   SharedHandle<HttpServer> httpServer_;
   Timer timeoutTimer_;
+  bool writeCheck_;
+
   void sendJsonRpcErrorResponse
   (const std::string& httpStatus,
    int code,
@@ -66,6 +68,7 @@ private:
   (const std::vector<rpc::RpcResponse>& results,
    const std::string& callback);
   void addHttpServerResponseCommand();
+  void updateWriteCheck();
 public:
   HttpServerBodyCommand(cuid_t cuid,
                         const SharedHandle<HttpServer>& httpServer,
@@ -73,10 +76,10 @@ public:
                         const SharedHandle<SocketCore>& socket);
 
   virtual ~HttpServerBodyCommand();
-  
+
   virtual bool execute();
 };
 
-} // namespace aria2 
+} // namespace aria2
 
 #endif // D_HTTP_SERVER_BODY_COMMAND_H

+ 36 - 5
src/HttpServerCommand.cc

@@ -64,14 +64,17 @@ namespace aria2 {
 HttpServerCommand::HttpServerCommand
 (cuid_t cuid,
  DownloadEngine* e,
- const SharedHandle<SocketCore>& socket)
+ const SharedHandle<SocketCore>& socket,
+ bool secure)
   : Command(cuid),
     e_(e),
     socket_(socket),
-    httpServer_(new HttpServer(socket, e))
+    httpServer_(new HttpServer(socket, e)),
+    writeCheck_(false)
 {
   setStatus(Command::STATUS_ONESHOT_REALTIME);
   e_->addSocketForReadCheck(socket_, this);
+  httpServer_->setSecure(secure);
   httpServer_->setUsernamePassword(e_->getOption()->get(PREF_RPC_USER),
                                    e_->getOption()->get(PREF_RPC_PASSWD));
   if(e_->getOption()->getAsBool(PREF_RPC_ALLOW_ORIGIN_ALL)) {
@@ -93,7 +96,8 @@ HttpServerCommand::HttpServerCommand
   : Command(cuid),
     e_(e),
     socket_(socket),
-    httpServer_(httpServer)
+    httpServer_(httpServer),
+    writeCheck_(false)
 {
   e_->addSocketForReadCheck(socket_, this);
   checkSocketRecvBuffer();
@@ -102,6 +106,9 @@ HttpServerCommand::HttpServerCommand
 HttpServerCommand::~HttpServerCommand()
 {
   e_->deleteSocketForReadCheck(socket_, this);
+  if(writeCheck_) {
+    e_->deleteSocketForWriteCheck(socket_, this);
+  }
 }
 
 void HttpServerCommand::checkSocketRecvBuffer()
@@ -147,6 +154,19 @@ int websocketHandshake(const SharedHandle<HttpHeader>& header)
 
 #endif // ENABLE_WEBSOCKET
 
+void HttpServerCommand::updateWriteCheck()
+{
+  if(httpServer_->wantWrite()) {
+    if(!writeCheck_) {
+      writeCheck_ = true;
+      e_->addSocketForWriteCheck(socket_, this);
+    }
+  } else if(writeCheck_) {
+    writeCheck_ = false;
+    e_->deleteSocketForWriteCheck(socket_, this);
+  }
+}
+
 bool HttpServerCommand::execute()
 {
   if(e_->getRequestGroupMan()->downloadFinished() || e_->isHaltRequested()) {
@@ -154,13 +174,24 @@ bool HttpServerCommand::execute()
   }
   try {
     if(socket_->isReadable(0) ||
+       (writeCheck_ && socket_->isWritable(0)) ||
        !httpServer_->getSocketRecvBuffer()->bufferEmpty()) {
       timeoutTimer_ = global::wallclock();
-      SharedHandle<HttpHeader> header;
 
-      header = httpServer_->receiveRequest();
+      if(httpServer_->getSecure()) {
+        // tlsAccept() just returns true if handshake has already
+        // finished.
+        if(!socket_->tlsAccept()) {
+          updateWriteCheck();
+          e_->addCommand(this);
+          return false;
+        }
+      }
 
+      SharedHandle<HttpHeader> header;
+      header = httpServer_->receiveRequest();
       if(!header) {
+        updateWriteCheck();
         e_->addCommand(this);
         return false;
       }

+ 6 - 3
src/HttpServerCommand.h

@@ -51,11 +51,14 @@ private:
   SharedHandle<SocketCore> socket_;
   SharedHandle<HttpServer> httpServer_;
   Timer timeoutTimer_;
+  bool writeCheck_;
 
   void checkSocketRecvBuffer();
+  void updateWriteCheck();
 public:
   HttpServerCommand(cuid_t cuid, DownloadEngine* e,
-                    const SharedHandle<SocketCore>& socket);
+                    const SharedHandle<SocketCore>& socket,
+                    bool secure);
 
   HttpServerCommand(cuid_t cuid,
                     const SharedHandle<HttpServer>& httpServer,
@@ -63,10 +66,10 @@ public:
                     const SharedHandle<SocketCore>& socket);
 
   virtual ~HttpServerCommand();
-  
+
   virtual bool execute();
 };
 
-} // namespace aria2 
+} // namespace aria2
 
 #endif // D_HTTP_SERVER_COMMAND_H

+ 8 - 6
src/LibgnutlsTLSContext.cc

@@ -45,8 +45,9 @@
 
 namespace aria2 {
 
-TLSContext::TLSContext()
+TLSContext::TLSContext(TLSSessionSide side)
   : certCred_(0),
+    side_(side),
     peerVerificationEnabled_(false)
 {
   int r = gnutls_certificate_allocate_credentials(&certCred_);
@@ -79,19 +80,20 @@ bool TLSContext::bad() const
   return !good_;
 }
 
-bool TLSContext::addClientKeyFile(const std::string& certfile,
-                                  const std::string& keyfile)
+bool TLSContext::addCredentialFile(const std::string& certfile,
+                                   const std::string& keyfile)
 {
   int ret = gnutls_certificate_set_x509_key_file(certCred_,
                                                  certfile.c_str(),
                                                  keyfile.c_str(),
                                                  GNUTLS_X509_FMT_PEM);
   if(ret == GNUTLS_E_SUCCESS) {
-    A2_LOG_INFO(fmt("Client Key File(cert=%s, key=%s) were successfully added.",
-                    certfile.c_str(), keyfile.c_str()));
+    A2_LOG_INFO(fmt
+                ("Credential files(cert=%s, key=%s) were successfully added.",
+                 certfile.c_str(), keyfile.c_str()));
     return true;
   } else {
-    A2_LOG_ERROR(fmt("Failed to load client certificate from %s and"
+    A2_LOG_ERROR(fmt("Failed to load certificate from %s and"
                      " private key from %s. Cause: %s",
                      certfile.c_str(), keyfile.c_str(),
                      gnutls_strerror(ret)));

+ 11 - 3
src/LibgnutlsTLSContext.h

@@ -41,6 +41,7 @@
 
 #include <gnutls/gnutls.h>
 
+#include "TLSContext.h"
 #include "DlAbortEx.h"
 
 namespace aria2 {
@@ -49,17 +50,19 @@ class TLSContext {
 private:
   gnutls_certificate_credentials_t certCred_;
 
+  TLSSessionSide side_;
+
   bool good_;
 
   bool peerVerificationEnabled_;
 public:
-  TLSContext();
+  TLSContext(TLSSessionSide side);
 
   ~TLSContext();
 
   // private key `keyfile' must be decrypted.
-  bool addClientKeyFile(const std::string& certfile,
-                        const std::string& keyfile);
+  bool addCredentialFile(const std::string& certfile,
+                         const std::string& keyfile);
 
   bool addSystemTrustedCACerts();
 
@@ -72,6 +75,11 @@ public:
 
   gnutls_certificate_credentials_t getCertCred() const;
 
+  TLSSessionSide getSide() const
+  {
+    return side_;
+  }
+
   void enablePeerVerification();
 
   void disablePeerVerification();

+ 10 - 9
src/LibsslTLSContext.cc

@@ -43,11 +43,12 @@
 
 namespace aria2 {
 
-TLSContext::TLSContext()
+TLSContext::TLSContext(TLSSessionSide side)
   : sslCtx_(0),
+    side_(side),
     peerVerificationEnabled_(false)
 {
-  sslCtx_ = SSL_CTX_new(SSLv23_client_method());
+  sslCtx_ = SSL_CTX_new(SSLv23_method());
   if(sslCtx_) {
     good_ = true;
   } else {
@@ -55,15 +56,15 @@ TLSContext::TLSContext()
     A2_LOG_ERROR(fmt("SSL_CTX_new() failed. Cause: %s",
                      ERR_error_string(ERR_get_error(), 0)));
   }
-  /* Disable SSLv2 and enable all workarounds for buggy servers */
+  // Disable SSLv2 and enable all workarounds for buggy servers
   SSL_CTX_set_options(sslCtx_, SSL_OP_ALL|SSL_OP_NO_SSLv2|
                       SSL_OP_NO_COMPRESSION);
   SSL_CTX_set_mode(sslCtx_, SSL_MODE_AUTO_RETRY);
+  SSL_CTX_set_mode(sslCtx_, SSL_MODE_ENABLE_PARTIAL_WRITE);
   #ifdef SSL_MODE_RELEASE_BUFFERS
   /* keep memory usage low */
   SSL_CTX_set_mode(sslCtx_, SSL_MODE_RELEASE_BUFFERS);
   #endif
-  
 }
 
 TLSContext::~TLSContext()
@@ -81,23 +82,23 @@ bool TLSContext::bad() const
   return !good_;
 }
 
-bool TLSContext::addClientKeyFile(const std::string& certfile,
-                                  const std::string& keyfile)
+bool TLSContext::addCredentialFile(const std::string& certfile,
+                                   const std::string& keyfile)
 {
   if(SSL_CTX_use_PrivateKey_file(sslCtx_, keyfile.c_str(),
                                  SSL_FILETYPE_PEM) != 1) {
-    A2_LOG_ERROR(fmt("Failed to load client private key from %s. Cause: %s",
+    A2_LOG_ERROR(fmt("Failed to load private key from %s. Cause: %s",
                      keyfile.c_str(),
                      ERR_error_string(ERR_get_error(), 0)));
     return false;
   }
   if(SSL_CTX_use_certificate_chain_file(sslCtx_, certfile.c_str()) != 1) {
-    A2_LOG_ERROR(fmt("Failed to load client certificate from %s. Cause: %s",
+    A2_LOG_ERROR(fmt("Failed to load certificate from %s. Cause: %s",
                      certfile.c_str(),
                      ERR_error_string(ERR_get_error(), 0)));
     return false;
   }
-  A2_LOG_INFO(fmt("Client Key File(cert=%s, key=%s) were successfully added.",
+  A2_LOG_INFO(fmt("Credential files(cert=%s, key=%s) were successfully added.",
                   certfile.c_str(),
                   keyfile.c_str()));
   return true;

+ 12 - 4
src/LibsslTLSContext.h

@@ -41,6 +41,7 @@
 
 # include <openssl/ssl.h>
 
+#include "TLSContext.h"
 #include "DlAbortEx.h"
 
 namespace aria2 {
@@ -49,17 +50,19 @@ class TLSContext {
 private:
   SSL_CTX* sslCtx_;
 
+  TLSSessionSide side_;
+
   bool good_;
 
   bool peerVerificationEnabled_;
 public:
-  TLSContext();
+  TLSContext(TLSSessionSide side);
 
   ~TLSContext();
 
   // private key `keyfile' must be decrypted.
-  bool addClientKeyFile(const std::string& certfile,
-                        const std::string& keyfile);
+  bool addCredentialFile(const std::string& certfile,
+                         const std::string& keyfile);
 
   bool addSystemTrustedCACerts();
 
@@ -74,7 +77,12 @@ public:
   {
     return sslCtx_;
   }
-  
+
+  TLSSessionSide getSide() const
+  {
+    return side_;
+  }
+
   void enablePeerVerification();
 
   void disablePeerVerification();

+ 28 - 9
src/MultiUrlRequestInfo.cc

@@ -137,6 +137,24 @@ error_code::Value MultiUrlRequestInfo::execute()
     Notifier notifier(wsSessionMan);
     SingletonHolder<Notifier>::instance(&notifier);
 
+#ifdef ENABLE_SSL
+    if(option_->getAsBool(PREF_ENABLE_RPC) &&
+       option_->getAsBool(PREF_RPC_SECURE)) {
+      if(!option_->blank(PREF_RPC_CERTIFICATE) &&
+         !option_->blank(PREF_RPC_PRIVATE_KEY)) {
+        // We set server TLS context to the SocketCore before creating
+        // DownloadEngine instance.
+        SharedHandle<TLSContext> svTlsContext(new TLSContext(TLS_SERVER));
+        svTlsContext->addCredentialFile(option_->get(PREF_RPC_CERTIFICATE),
+                                        option_->get(PREF_RPC_PRIVATE_KEY));
+        SocketCore::setServerTLSContext(svTlsContext);
+      } else {
+        throw DL_ABORT_EX("Specify --rpc-certificate and --rpc-private-key "
+                          "options in order to use secure RPC.");
+      }
+    }
+#endif // ENABLE_SSL
+
     SharedHandle<DownloadEngine> e =
       DownloadEngineFactory().newDownloadEngine(option_.get(), requestGroups_);
 
@@ -173,26 +191,27 @@ error_code::Value MultiUrlRequestInfo::execute()
     e->setAuthConfigFactory(authConfigFactory);
 
 #ifdef ENABLE_SSL
-    SharedHandle<TLSContext> tlsContext(new TLSContext());
+    SharedHandle<TLSContext> clTlsContext(new TLSContext(TLS_CLIENT));
     if(!option_->blank(PREF_CERTIFICATE) &&
        !option_->blank(PREF_PRIVATE_KEY)) {
-      tlsContext->addClientKeyFile(option_->get(PREF_CERTIFICATE),
-                                   option_->get(PREF_PRIVATE_KEY));
+      clTlsContext->addCredentialFile(option_->get(PREF_CERTIFICATE),
+                                      option_->get(PREF_PRIVATE_KEY));
     }
 
     if(!option_->blank(PREF_CA_CERTIFICATE)) {
-      if(!tlsContext->addTrustedCACertFile(option_->get(PREF_CA_CERTIFICATE))) {
+      if(!clTlsContext->addTrustedCACertFile
+         (option_->get(PREF_CA_CERTIFICATE))) {
         A2_LOG_INFO(MSG_WARN_NO_CA_CERT);
       }
     } else if(option_->getAsBool(PREF_CHECK_CERTIFICATE)) {
-      if(!tlsContext->addSystemTrustedCACerts()) {
+      if(!clTlsContext->addSystemTrustedCACerts()) {
         A2_LOG_INFO(MSG_WARN_NO_CA_CERT);
       }
     }
     if(option_->getAsBool(PREF_CHECK_CERTIFICATE)) {
-      tlsContext->enablePeerVerification();
+      clTlsContext->enablePeerVerification();
     }
-    SocketCore::setTLSContext(tlsContext);
+    SocketCore::setClientTLSContext(clTlsContext);
 #endif
 #ifdef HAVE_ARES_ADDR_NODE
     ares_addr_node* asyncDNSServers =
@@ -219,9 +238,9 @@ error_code::Value MultiUrlRequestInfo::execute()
 #endif // SIGHUP
     util::setGlobalSignalHandler(SIGINT, handler, 0);
     util::setGlobalSignalHandler(SIGTERM, handler, 0);
-    
+
     e->run();
-    
+
     if(!option_->blank(PREF_SAVE_COOKIES)) {
       e->getCookieStorage()->saveNsFormat(option_->get(PREF_SAVE_COOKIES));
     }

+ 27 - 0
src/OptionHandlerFactory.cc

@@ -747,6 +747,15 @@ std::vector<OptionHandler*> OptionHandlerFactory::createOptionHandlers()
     op->addTag(TAG_RPC);
     handlers.push_back(op);
   }
+  {
+    OptionHandler* op(new LocalFilePathOptionHandler
+                      (PREF_RPC_CERTIFICATE,
+                       TEXT_RPC_CERTIFICATE,
+                       NO_DEFAULT_VALUE,
+                       false));
+    op->addTag(TAG_RPC);
+    handlers.push_back(op);
+  }
   {
     OptionHandler* op(new BooleanOptionHandler
                       (PREF_RPC_LISTEN_ALL,
@@ -774,6 +783,24 @@ std::vector<OptionHandler*> OptionHandlerFactory::createOptionHandlers()
     op->addTag(TAG_RPC);
     handlers.push_back(op);
   }
+  {
+    OptionHandler* op(new LocalFilePathOptionHandler
+                      (PREF_RPC_PRIVATE_KEY,
+                       TEXT_RPC_PRIVATE_KEY,
+                       NO_DEFAULT_VALUE,
+                       false));
+    op->addTag(TAG_RPC);
+    handlers.push_back(op);
+  }
+  {
+    OptionHandler* op(new BooleanOptionHandler
+                      (PREF_RPC_SECURE,
+                       TEXT_RPC_SECURE,
+                       A2_V_FALSE,
+                       OptionHandler::OPT_ARG));
+    op->addTag(TAG_RPC);
+    handlers.push_back(op);
+  }
   {
     OptionHandler* op(new DefaultOptionHandler
                       (PREF_RPC_USER,

+ 90 - 58
src/SocketCore.cc

@@ -125,8 +125,6 @@ namespace {
 enum TlsState {
   // TLS object is not initialized.
   A2_TLS_NONE = 0,
-  // TLS object is initialized. Ready for handshake.
-  A2_TLS_INITIALIZED = 1,
   // TLS object is now handshaking.
   A2_TLS_HANDSHAKING = 2,
   // TLS object is now connected.
@@ -140,11 +138,19 @@ std::vector<std::pair<sockaddr_union, socklen_t> >
 SocketCore::bindAddrs_;
 
 #ifdef ENABLE_SSL
-SharedHandle<TLSContext> SocketCore::tlsContext_;
+SharedHandle<TLSContext> SocketCore::clTlsContext_;
+SharedHandle<TLSContext> SocketCore::svTlsContext_;
 
-void SocketCore::setTLSContext(const SharedHandle<TLSContext>& tlsContext)
+void SocketCore::setClientTLSContext
+(const SharedHandle<TLSContext>& tlsContext)
 {
-  tlsContext_ = tlsContext;
+  clTlsContext_ = tlsContext;
+}
+
+void SocketCore::setServerTLSContext
+(const SharedHandle<TLSContext>& tlsContext)
+{
+  svTlsContext_ = tlsContext;
 }
 #endif // ENABLE_SSL
 
@@ -818,12 +824,24 @@ void SocketCore::readData(char* data, size_t& len)
   len = ret;
 }
 
-void SocketCore::prepareSecureConnection()
+bool SocketCore::tlsAccept()
 {
-  if(!secure_) {
+  return tlsHandshake(svTlsContext_.get(), A2STR::NIL);
+}
+
+bool SocketCore::tlsConnect(const std::string& hostname)
+{
+  return tlsHandshake(clTlsContext_.get(), hostname);
+}
+
+bool SocketCore::tlsHandshake(TLSContext* tlsctx, const std::string& hostname)
+{
+  wantRead_ = false;
+  wantWrite_ = false;
 #ifdef HAVE_OPENSSL
-    // for SSL
-    ssl = SSL_new(tlsContext_->getSSLCtx());
+  switch(secure_) {
+  case A2_TLS_NONE:
+    ssl = SSL_new(tlsctx->getSSLCtx());
     if(!ssl) {
       throw DL_ABORT_EX
         (fmt(EX_SSL_INIT_FAILURE, ERR_error_string(ERR_get_error(), 0)));
@@ -832,48 +850,25 @@ void SocketCore::prepareSecureConnection()
       throw DL_ABORT_EX
         (fmt(EX_SSL_INIT_FAILURE, ERR_error_string(ERR_get_error(), 0)));
     }
-#endif // HAVE_OPENSSL
-#ifdef HAVE_LIBGNUTLS
-    int r;
-    gnutls_init(&sslSession_, GNUTLS_CLIENT);
-    // It seems err is not error message, but the argument string
-    // which causes syntax error.
-    const char* err;
-    // Disables TLS1.1 here because there are servers that don't
-    // understand TLS1.1.
-    r = gnutls_priority_set_direct(sslSession_, "NORMAL:!VERS-TLS1.1", &err);
-    if(r != GNUTLS_E_SUCCESS) {
-      throw DL_ABORT_EX(fmt(EX_SSL_INIT_FAILURE, gnutls_strerror(r)));
-    }
-    // put the x509 credentials to the current session
-    gnutls_credentials_set(sslSession_, GNUTLS_CRD_CERTIFICATE,
-                           tlsContext_->getCertCred());
-    gnutls_transport_set_ptr(sslSession_, (gnutls_transport_ptr_t)sockfd_);
-#endif // HAVE_LIBGNUTLS
-    secure_ = A2_TLS_INITIALIZED;
-  }
-}
-
-bool SocketCore::initiateSecureConnection(const std::string& hostname)
-{
-  wantRead_ = false;
-  wantWrite_ = false;
-#ifdef HAVE_OPENSSL
-  switch(secure_) {
-  case A2_TLS_INITIALIZED:
-    secure_ = A2_TLS_HANDSHAKING;
+    // Fall through
 #ifdef SSL_CTRL_SET_TLSEXT_HOSTNAME
-    if(!util::isNumericHost(hostname)) {
+    if(tlsctx->getSide() == TLS_CLIENT && !util::isNumericHost(hostname)) {
       // TLS extensions: SNI.  There is not documentation about the
       // return code for this function (actually this is macro
       // wrapping SSL_ctrl at the time of this writing).
       SSL_set_tlsext_host_name(ssl, hostname.c_str());
     }
 #endif // SSL_CTRL_SET_TLSEXT_HOSTNAME
+    secure_ = A2_TLS_HANDSHAKING;
     // Fall through
   case A2_TLS_HANDSHAKING: {
     ERR_clear_error();
-    int e = SSL_connect(ssl);
+    int e;
+    if(tlsctx->getSide() == TLS_CLIENT) {
+      e = SSL_connect(ssl);
+    } else {
+      e = SSL_accept(ssl);
+    }
 
     if (e <= 0) {
       int ssl_error = SSL_get_error(ssl, e);
@@ -893,9 +888,21 @@ bool SocketCore::initiateSecureConnection(const std::string& hostname)
         }
         break;
 
-      case SSL_ERROR_SYSCALL:
-        throw DL_ABORT_EX(EX_SSL_IO_ERROR);
-
+      case SSL_ERROR_SYSCALL: {
+        int sslErr = ERR_get_error();
+        if(sslErr == 0) {
+          if(e == 0) {
+            throw DL_ABORT_EX("Got EOF in SSL handshake");
+          } else if(e == -1) {
+            throw DL_ABORT_EX(fmt("SSL I/O error: %s", strerror(errno)));
+          } else {
+            throw DL_ABORT_EX(EX_SSL_IO_ERROR);
+          }
+        } else {
+          throw DL_ABORT_EX(fmt("SSL I/O error: %s",
+                                ERR_error_string(sslErr, 0)));
+        }
+      }
       case SSL_ERROR_SSL:
         throw DL_ABORT_EX(EX_SSL_PROTOCOL_ERROR);
 
@@ -903,7 +910,8 @@ bool SocketCore::initiateSecureConnection(const std::string& hostname)
         throw DL_ABORT_EX(fmt(EX_SSL_UNKNOWN_ERROR, ssl_error));
       }
     }
-    if(tlsContext_->peerVerificationEnabled()) {
+    if(tlsctx->getSide() == TLS_CLIENT &&
+       tlsctx->peerVerificationEnabled()) {
       // verify peer
       X509* peerCert = SSL_get_peer_certificate(ssl);
       if(!peerCert) {
@@ -984,20 +992,44 @@ bool SocketCore::initiateSecureConnection(const std::string& hostname)
 #endif // HAVE_OPENSSL
 #ifdef HAVE_LIBGNUTLS
   switch(secure_) {
-  case A2_TLS_INITIALIZED:
-    secure_ = A2_TLS_HANDSHAKING;
-    // Check hostname is not numeric and it includes ".". Setting
-    // "localhost" will produce TLS alert.
-    if(!util::isNumericHost(hostname) &&
-       hostname.find(".") != std::string::npos) {
-      // TLS extensions: SNI
-      int ret = gnutls_server_name_set(sslSession_, GNUTLS_NAME_DNS,
-                                       hostname.c_str(), hostname.size());
-      if(ret < 0) {
-        A2_LOG_WARN(fmt("Setting hostname in SNI extension failed. Cause: %s",
-                        gnutls_strerror(ret)));
+  case A2_TLS_NONE:
+    int r;
+    gnutls_init(&sslSession_,
+                tlsctx->getSide() == TLS_CLIENT ?
+                GNUTLS_CLIENT : GNUTLS_SERVER);
+    // It seems err is not error message, but the argument string
+    // which causes syntax error.
+    const char* err;
+    // For client side, disables TLS1.1 here because there are servers
+    // that don't understand TLS1.1.  TODO Is this still necessary?
+    r = gnutls_priority_set_direct(sslSession_,
+                                   tlsctx->getSide() == TLS_CLIENT ?
+                                   "NORMAL:-VERS-TLS1.1" :
+                                   "NORMAL",
+                                   &err);
+    if(r != GNUTLS_E_SUCCESS) {
+      throw DL_ABORT_EX(fmt(EX_SSL_INIT_FAILURE, gnutls_strerror(r)));
+    }
+    // put the x509 credentials to the current session
+    gnutls_credentials_set(sslSession_, GNUTLS_CRD_CERTIFICATE,
+                           tlsctx->getCertCred());
+    gnutls_transport_set_ptr(sslSession_, (gnutls_transport_ptr_t)sockfd_);
+    if(tlsctx->getSide() == TLS_CLIENT) {
+      // Check hostname is not numeric and it includes ".". Setting
+      // "localhost" will produce TLS alert.
+      if(!util::isNumericHost(hostname) &&
+         hostname.find(".") != std::string::npos) {
+        // TLS extensions: SNI
+        int ret = gnutls_server_name_set(sslSession_, GNUTLS_NAME_DNS,
+                                         hostname.c_str(), hostname.size());
+        if(ret < 0) {
+          A2_LOG_WARN(fmt
+                      ("Setting hostname in SNI extension failed. Cause: %s",
+                       gnutls_strerror(ret)));
+        }
       }
     }
+    secure_ = A2_TLS_HANDSHAKING;
     // Fall through
   case A2_TLS_HANDSHAKING: {
     int ret = gnutls_handshake(sslSession_);
@@ -1008,7 +1040,7 @@ bool SocketCore::initiateSecureConnection(const std::string& hostname)
       throw DL_ABORT_EX(fmt(EX_SSL_INIT_FAILURE, gnutls_strerror(ret)));
     }
 
-    if(tlsContext_->peerVerificationEnabled()) {
+    if(tlsctx->getSide() == TLS_CLIENT && tlsctx->peerVerificationEnabled()) {
       // verify peer
       unsigned int status;
       ret = gnutls_certificate_verify_peers2(sslSession_, &status);

+ 25 - 12
src/SocketCore.h

@@ -85,7 +85,10 @@ private:
   bool wantWrite_;
 
 #if ENABLE_SSL
-  static SharedHandle<TLSContext> tlsContext_;
+  // TLS context for client side
+  static SharedHandle<TLSContext> clTlsContext_;
+  // TLS context for server side
+  static SharedHandle<TLSContext> svTlsContext_;
 #endif // ENABLE_SSL
 
 #ifdef HAVE_OPENSSL
@@ -106,6 +109,15 @@ private:
 
   void setSockOpt(int level, int optname, void* optval, socklen_t optlen);
 
+  /**
+   * Makes this socket secure.
+   * If the system has not OpenSSL, then this method do nothing.
+   * connection must be established  before calling this method.
+   *
+   * If you are going to verify peer's certificate, hostname must be supplied.
+   */
+  bool tlsHandshake(TLSContext* tlsctx, const std::string& hostname);
+
   SocketCore(sock_t sockfd, int sockType);
 public:
   SocketCore(int sockType = SOCK_STREAM);
@@ -124,7 +136,7 @@ public:
   void joinMulticastGroup
   (const std::string& multicastAddr, uint16_t multicastPort,
    const std::string& localAddr);
-  
+
   // Enables TCP_NODELAY socket option if f == true.
   void setTcpNodelay(bool f);
 
@@ -293,16 +305,16 @@ public:
     return readDataFrom(reinterpret_cast<char*>(data), len, sender);
   }
 
-  /**
-   * Makes this socket secure.
-   * If the system has not OpenSSL, then this method do nothing.
-   * connection must be established  before calling this method.
-   *
-   * If you are going to verify peer's certificate, hostname must be supplied.
-   */
-  bool initiateSecureConnection(const std::string& hostname="");
+  // Performs TLS server side handshake. If handshake is completed,
+  // returns true. If handshake has not been done yet, returns false.
+  bool tlsAccept();
 
-  void prepareSecureConnection();
+  // Performs TLS client side handshake. If handshake is completed,
+  // returns true. If handshake has not been done yet, returns false.
+  //
+  // If you are going to verify peer's certificate, hostname must be
+  // supplied.
+  bool tlsConnect(const std::string& hostname);
 
   bool operator==(const SocketCore& s) {
     return sockfd_ == s.sockfd_;
@@ -332,7 +344,8 @@ public:
   bool wantWrite() const;
 
 #ifdef ENABLE_SSL
-  static void setTLSContext(const SharedHandle<TLSContext>& tlsContext);
+  static void setClientTLSContext(const SharedHandle<TLSContext>& tlsContext);
+  static void setServerTLSContext(const SharedHandle<TLSContext>& tlsContext);
 #endif // ENABLE_SSL
 
   static void setProtocolFamily(int protocolFamily)

+ 9 - 0
src/TLSContext.h

@@ -37,6 +37,15 @@
 
 #include "common.h"
 
+namespace aria2 {
+
+enum TLSSessionSide {
+  TLS_CLIENT,
+  TLS_SERVER
+};
+
+} // namespace aria2
+
 #ifdef HAVE_OPENSSL
 # include "LibsslTLSContext.h"
 #elif HAVE_LIBGNUTLS

+ 4 - 3
src/WebSocketInteractionCommand.cc

@@ -73,7 +73,7 @@ WebSocketInteractionCommand::~WebSocketInteractionCommand()
 
 void WebSocketInteractionCommand::updateWriteCheck()
 {
-  if(wsSession_->wantWrite()) {
+  if(socket_->wantWrite() || wsSession_->wantWrite()) {
     if(!writeCheck_) {
       writeCheck_ = true;
       e_->addSocketForWriteCheck(socket_, this);
@@ -91,9 +91,10 @@ bool WebSocketInteractionCommand::execute()
   }
   if(wsSession_->onReadEvent() == -1 || wsSession_->onWriteEvent() == -1) {
     if(wsSession_->closeSent() || wsSession_->closeReceived()) {
-      A2_LOG_INFO(fmt("CUID#%" PRId64 " - WebSocket session terminated.", getCuid()));
+      A2_LOG_INFO(fmt("CUID#%"PRId64" - WebSocket session terminated.",
+                      getCuid()));
     } else {
-      A2_LOG_INFO(fmt("CUID#%" PRId64 " - WebSocket session terminated"
+      A2_LOG_INFO(fmt("CUID#%"PRId64" - WebSocket session terminated"
                       " (Possibly due to EOF).", getCuid()));
     }
     return true;

+ 6 - 0
src/prefs.cc

@@ -270,6 +270,12 @@ const Pref* PREF_RPC_MAX_REQUEST_SIZE = makePref("rpc-max-request-size");
 const Pref* PREF_RPC_LISTEN_ALL = makePref("rpc-listen-all");
 // value: true | false
 const Pref* PREF_RPC_ALLOW_ORIGIN_ALL = makePref("rpc-allow-origin-all");
+// value: string that your file system recognizes as a file name.
+const Pref* PREF_RPC_CERTIFICATE = makePref("rpc-certificate");
+// value: string that your file system recognizes as a file name.
+const Pref* PREF_RPC_PRIVATE_KEY = makePref("rpc-private-key");
+// value: true | false
+const Pref* PREF_RPC_SECURE = makePref("rpc-secure");
 // value: true | false
 const Pref* PREF_DRY_RUN = makePref("dry-run");
 // value: true | false

+ 6 - 0
src/prefs.h

@@ -213,6 +213,12 @@ extern const Pref* PREF_RPC_MAX_REQUEST_SIZE;
 extern const Pref* PREF_RPC_LISTEN_ALL;
 // value: true | false
 extern const Pref* PREF_RPC_ALLOW_ORIGIN_ALL;
+// value: string that your file system recognizes as a file name.
+extern const Pref* PREF_RPC_CERTIFICATE;
+// value: string that your file system recognizes as a file name.
+extern const Pref* PREF_RPC_PRIVATE_KEY;
+// value: true | false
+extern const Pref* PREF_RPC_SECURE;
 // value: true | false
 extern const Pref* PREF_DRY_RUN;
 // value: true | false

+ 18 - 0
src/usage_text.h

@@ -880,3 +880,21 @@
     "                              your disk.")
 #define TEXT_ENABLE_MMAP                        \
   _(" --enable-mmap[=true|false]   Map files into memory.")
+#define TEXT_RPC_CERTIFICATE                                            \
+  _(" --rpc-certificate=FILE       Use the certificate in FILE for RPC server.\n" \
+    "                              The certificate must be in PEM format.\n" \
+    "                              Use --rpc-private-key option to specify the\n" \
+    "                              private key. Use --rpc-secure option to enable\n" \
+    "                              encryption.")
+#define TEXT_RPC_PRIVATE_KEY                                            \
+  _(" --rpc-private-key=FILE       Use the private key in FILE for RPC server.\n" \
+    "                              The private key must be decrypted and in PEM\n" \
+    "                              format. Use --rpc-secure option to enable\n" \
+    "                              encryption. See also --rpc-certificate option.")
+#define TEXT_RPC_SECURE                         \
+  _(" --rpc-secure[=true|false]    RPC transport will be encrypted by SSL/TLS.\n" \
+    "                              The RPC clients must use https scheme to access\n" \
+    "                              the server. For WebSocket client, use wss\n" \
+    "                              scheme. Use --rpc-certificate and\n" \
+    "                              --rpc-private-key options to specify the\n" \
+    "                              server certificate and private key.")