Explorar o código

Made WebSocket handshake more strict.
Refactored HttpServer as well.

Tatsuhiro Tsujikawa %!s(int64=13) %!d(string=hai) anos
pai
achega
06b6bef860
Modificáronse 4 ficheiros con 123 adicións e 48 borrados
  1. 69 21
      src/HttpServer.cc
  2. 8 4
      src/HttpServer.h
  3. 4 4
      src/HttpServerBodyCommand.cc
  4. 42 19
      src/HttpServerCommand.cc

+ 69 - 21
src/HttpServer.cc

@@ -69,6 +69,58 @@ HttpServer::HttpServer
 
 HttpServer::~HttpServer() {}
 
+namespace {
+const char* getStatusString(int status)
+{
+  switch(status) {
+  case 100: return "100 Continue";
+  case 101: return "101 Switching Protocols";
+  case 200: return "200 OK";
+  case 201: return "201 Created";
+  case 202: return "202 Accepted";
+  case 203: return "203 Non-Authoritative Information";
+  case 204: return "204 No Content";
+  case 205: return "205 Reset Content";
+  case 206: return "206 Partial Content";
+  case 300: return "300 Multiple Choices";
+  case 301: return "301 Moved Permanently";
+  case 302: return "302 Found";
+  case 303: return "303 See Other";
+  case 304: return "304 Not Modified";
+  case 305: return "305 Use Proxy";
+    // case 306: return "306 (Unused)";
+  case 307: return "307 Temporary Redirect";
+  case 400: return "400 Bad Request";
+  case 401: return "401 Unauthorized";
+  case 402: return "402 Payment Required";
+  case 403: return "403 Forbidden";
+  case 404: return "404 Not Found";
+  case 405: return "405 Method Not Allowed";
+  case 406: return "406 Not Acceptable";
+  case 407: return "407 Proxy Authentication Required";
+  case 408: return "408 Request Timeout";
+  case 409: return "409 Conflict";
+  case 410: return "410 Gone";
+  case 411: return "411 Length Required";
+  case 412: return "412 Precondition Failed";
+  case 413: return "413 Request Entity Too Large";
+  case 414: return "414 Request-URI Too Long";
+  case 415: return "415 Unsupported Media Type";
+  case 416: return "416 Requested Range Not Satisfiable";
+  case 417: return "417 Expectation Failed";
+    // RFC 2817 defines 426 status code.
+  case 426: return "426 Upgrade Required";
+  case 500: return "500 Internal Server Error";
+  case 501: return "501 Not Implemented";
+  case 502: return "502 Bad Gateway";
+  case 503: return "503 Service Unavailable";
+  case 504: return "504 Gateway Timeout";
+  case 505: return "505 HTTP Version Not Supported";
+  default: return "";
+  }
+}
+} // namespace
+
 SharedHandle<HttpHeader> HttpServer::receiveRequest()
 {
   if(socketRecvBuffer_->bufferEmpty()) {
@@ -167,45 +219,41 @@ const std::string& HttpServer::getRequestPath() const
 
 void HttpServer::feedResponse(std::string& text, const std::string& contentType)
 {
-  feedResponse("200 OK", "", text, contentType);
+  feedResponse(200, "", text, contentType);
 }
 
-void HttpServer::feedResponse(const std::string& status,
+void HttpServer::feedResponse(int status,
                               const std::string& headers,
-                              std::string& text,
+                              const std::string& text,
                               const std::string& contentType)
 {
   std::string httpDate = Time().toHTTPDate();
   std::string header= fmt("HTTP/1.1 %s\r\n"
                           "Date: %s\r\n"
-                          "Content-Type: %s\r\n"
                           "Content-Length: %lu\r\n"
                           "Expires: %s\r\n"
-                          "Cache-Control: no-cache\r\n"
-                          "%s%s",
-                          status.c_str(),
+                          "Cache-Control: no-cache\r\n",
+                          getStatusString(status),
                           httpDate.c_str(),
-                          contentType.c_str(),
                           static_cast<unsigned long>(text.size()),
-                          httpDate.c_str(),
-                          supportsGZip() ?
-                          "Content-Encoding: gzip\r\n" : "",
-                          !supportsPersistentConnection() ?
-                          "Connection: close\r\n" : "");
+                          httpDate.c_str());
+  if(!contentType.empty()) {
+    header += "Content-Type: ";
+    header += contentType;
+    header += "\r\n";
+  }
   if(!allowOrigin_.empty()) {
     header += "Access-Control-Allow-Origin: ";
     header += allowOrigin_;
     header += "\r\n";
   }
-  if(!headers.empty()) {
-    header += headers;
-    if(headers.size() < 2 ||
-       (headers[headers.size()-2] != '\r' &&
-        headers[headers.size()-1] != '\n')) {
-      header += "\r\n";
-    }
+  if(supportsGZip()) {
+    header += "Content-Encoding: gzip\r\n";
   }
-
+  if(!supportsPersistentConnection()) {
+    header += "Connection: close\r\n";
+  }
+  header += headers;
   header += "\r\n";
   A2_LOG_DEBUG(fmt("HTTP Server sends response:\n%s", header.c_str()));
   socketBuffer_.pushStr(header);

+ 8 - 4
src/HttpServer.h

@@ -86,10 +86,14 @@ public:
 
   void feedResponse(std::string& text, const std::string& contentType);
 
-  void feedResponse(const std::string& status,
-                    const std::string& headers,
-                    std::string& text,
-                    const std::string& contentType);
+  // Feeds HTTP response with the status code |status| (e.g.,
+  // 200). The |headers| is zero or more lines of HTTP header field
+  // and each line must end with "\r\n". The |text| is the response
+  // body. The |contentType" is the content-type of the response body.
+  void feedResponse(int status,
+                    const std::string& headers = "",
+                    const std::string& text = "",
+                    const std::string& contentType = "");
 
   // Feeds "101 Switching Protocols" response. The |protocol| will
   // appear in Upgrade header field. The |headers| is zero or more

+ 4 - 4
src/HttpServerBodyCommand.cc

@@ -104,16 +104,16 @@ void HttpServerBodyCommand::sendJsonRpcResponse
                               getJsonRpcContentType(!callback.empty()));
   } else {
     httpServer_->disableKeepAlive();
-    std::string httpCode;
+    int httpCode;
     switch(res.code) {
     case -32600:
-      httpCode = "400 Bad Request";
+      httpCode = 400;
       break;
     case -32601:
-      httpCode = "404 Not Found";
+      httpCode = 404;
       break;
     default:
-      httpCode = "500 Internal Server Error";
+      httpCode = 500;
     };
     httpServer_->feedResponse(httpCode, A2STR::NIL,
                               responseData,

+ 42 - 19
src/HttpServerCommand.cc

@@ -108,6 +108,7 @@ void HttpServerCommand::checkSocketRecvBuffer()
   }
 }
 
+namespace {
 // Creates server's WebSocket accept key which will be sent in
 // Sec-WebSocket-Accept header field. The |clientKey| is the value
 // found in Sec-WebSocket-Key header field in the request.
@@ -120,6 +121,23 @@ std::string createWebSocketServerKey(const std::string& clientKey)
                          src.c_str(), src.size());
   return base64::encode(&digest[0], &digest[sizeof(digest)]);
 }
+} // namespace
+
+namespace {
+int websocketHandshake(const SharedHandle<HttpHeader>& header)
+{
+  if(header->getMethod() != "GET" ||
+     header->find("sec-websocket-key").empty()) {
+    return 400;
+  } else if(header->find("sec-websocket-version") != "13") {
+    return 426;
+  } else if(header->getRequestPath() != "/jsonrpc") {
+    return 404;
+  } else {
+    return 101;
+  }
+}
+} // namespace
 
 bool HttpServerCommand::execute()
 {
@@ -140,10 +158,8 @@ bool HttpServerCommand::execute()
       }
       if(!httpServer_->authenticate()) {
         httpServer_->disableKeepAlive();
-        std::string text;
-        httpServer_->feedResponse("401 Unauthorized",
-                                  "WWW-Authenticate: Basic realm=\"aria2\"",
-                                  text,"text/html");
+        httpServer_->feedResponse
+          (401, "WWW-Authenticate: Basic realm=\"aria2\"\r\n");
         Command* command =
           new HttpServerResponseCommand(getCuid(), httpServer_, e_, socket_);
         e_->addCommand(command);
@@ -152,21 +168,28 @@ bool HttpServerCommand::execute()
       }
       const std::string& upgradeHd = header->find("upgrade");
       const std::string& connectionHd = header->find("connection");
-      if(httpServer_->getRequestPath() == "/jsonrpc" &&
-         httpServer_->getMethod() == "GET" &&
-         util::strieq(upgradeHd.begin(), upgradeHd.end(), "websocket") &&
-         util::strieq(connectionHd.begin(), connectionHd.end(), "upgrade") &&
-         header->find("sec-websocket-version") == "13" &&
-         header->defined("sec-websocket-key")) {
-        std::string serverKey =
-          createWebSocketServerKey(header->find("sec-websocket-key"));
-        httpServer_->feedUpgradeResponse("websocket",
-                                         fmt("Sec-WebSocket-Accept: %s\r\n",
-                                             serverKey.c_str()));
-        httpServer_->getSocket()->setTcpNodelay(true);
-        Command* command =
-          new rpc::WebSocketResponseCommand(getCuid(), httpServer_, e_,
-                                            socket_);
+      if(util::strieq(upgradeHd.begin(), upgradeHd.end(), "websocket") &&
+         util::strieq(connectionHd.begin(), connectionHd.end(), "upgrade")) {
+        int status = websocketHandshake(header);
+        Command* command;
+        if(status == 101) {
+          std::string serverKey =
+            createWebSocketServerKey(header->find("sec-websocket-key"));
+          httpServer_->feedUpgradeResponse("websocket",
+                                           fmt("Sec-WebSocket-Accept: %s\r\n",
+                                               serverKey.c_str()));
+          httpServer_->getSocket()->setTcpNodelay(true);
+          command = new rpc::WebSocketResponseCommand(getCuid(), httpServer_,
+                                                      e_, socket_);
+        } else {
+          if(status == 426) {
+            httpServer_->feedResponse(426, "Sec-WebSocket-Version: 13\r\n");
+          } else {
+            httpServer_->feedResponse(status);
+          }
+          command = new HttpServerResponseCommand(getCuid(), httpServer_, e_,
+                                                  socket_);
+        }
         e_->addCommand(command);
         e_->setNoWait(true);
         return true;