浏览代码

Parse WebSocket RPC request on the fly without buffering

Tatsuhiro Tsujikawa 13 年之前
父节点
当前提交
f56743b083
共有 2 个文件被更改,包括 83 次插入6 次删除
  1. 60 6
      src/WebSocketSession.cc
  2. 23 0
      src/WebSocketSession.h

+ 60 - 6
src/WebSocketSession.cc

@@ -45,6 +45,8 @@
 #include "rpc_helper.h"
 #include "RpcResponse.h"
 #include "json.h"
+#include "prefs.h"
+#include "Option.h"
 
 namespace aria2 {
 
@@ -122,6 +124,32 @@ void addResponse(WebSocketSession* wsSession,
 }
 } // namespace
 
+namespace {
+void onFrameRecvStartCallback
+(wslay_event_context_ptr wsctx,
+ const struct wslay_event_on_frame_recv_start_arg* arg,
+ void* userData)
+{
+  WebSocketSession* wsSession = reinterpret_cast<WebSocketSession*>(userData);
+  wsSession->setIgnorePayload(wslay_is_ctrl_frame(arg->opcode));
+}
+} // namespace
+
+namespace {
+void onFrameRecvChunkCallback
+(wslay_event_context_ptr wsctx,
+ const struct wslay_event_on_frame_recv_chunk_arg* arg,
+ void* userData)
+{
+  WebSocketSession* wsSession = reinterpret_cast<WebSocketSession*>(userData);
+  if(!wsSession->getIgnorePayload()) {
+    // The return value is ignored here. It will be evaluated in
+    // onMsgRecvCallback.
+    wsSession->parseUpdate(arg->data, arg->data_length);
+  }
+}
+} // namespace
+
 namespace {
 void onMsgRecvCallback(wslay_event_context_ptr wsctx,
                        const struct wslay_event_on_msg_recv_arg* arg,
@@ -130,11 +158,10 @@ void onMsgRecvCallback(wslay_event_context_ptr wsctx,
   WebSocketSession* wsSession = reinterpret_cast<WebSocketSession*>(userData);
   if(!wslay_is_ctrl_frame(arg->opcode)) {
     // TODO Only process text frame
-    SharedHandle<ValueBase> json;
-    try {
-      json = json::decode(arg->msg, arg->msg_length);
-    } catch(RecoverableException& e) {
-      A2_LOG_INFO_EX("Failed to parse JSON-RPC request", e);
+    ssize_t error = 0;
+    SharedHandle<ValueBase> json = wsSession->parseFinal(0, 0, error);
+    if(error < 0) {
+      A2_LOG_INFO("Failed to parse JSON-RPC request");
       RpcResponse res
         (createJsonRpcErrorResponse(-32700, "Parse error.", Null::g()));
       addResponse(wsSession, res);
@@ -177,15 +204,21 @@ void onMsgRecvCallback(wslay_event_context_ptr wsctx,
 WebSocketSession::WebSocketSession(const SharedHandle<SocketCore>& socket,
                                    DownloadEngine* e)
   : socket_(socket),
-    e_(e)
+    e_(e),
+    ignorePayload_(false),
+    receivedLength_(0)
 {
   wslay_event_callbacks callbacks;
   memset(&callbacks, 0, sizeof(wslay_event_callbacks));
   callbacks.recv_callback = recvCallback;
   callbacks.send_callback = sendCallback;
   callbacks.on_msg_recv_callback = onMsgRecvCallback;
+  callbacks.on_frame_recv_start_callback = onFrameRecvStartCallback;
+  callbacks.on_frame_recv_chunk_callback = onFrameRecvChunkCallback;
+
   int r = wslay_event_context_server_init(&wsctx_, &callbacks, this);
   assert(r == 0);
+  wslay_event_config_set_no_buffering(wsctx_, 1);
 }
     
 WebSocketSession::~WebSocketSession()
@@ -246,6 +279,27 @@ bool WebSocketSession::closeSent()
   return wslay_event_get_close_sent(wsctx_);
 }
 
+ssize_t WebSocketSession::parseUpdate(const uint8_t* data, size_t len)
+{
+  // Cap the number of bytes to feed the parser
+  size_t maxlen = e_->getOption()->getAsInt(PREF_RPC_MAX_REQUEST_SIZE);
+  if(receivedLength_ + len <= maxlen) {
+    receivedLength_ += len;
+  } else {
+    len = 0;
+  }
+  return parser_.parseUpdate(reinterpret_cast<const char*>(data), len);
+}
+
+SharedHandle<ValueBase> WebSocketSession::parseFinal
+(const uint8_t* data, size_t len, ssize_t& error)
+{
+  SharedHandle<ValueBase> res =
+    parser_.parseFinal(reinterpret_cast<const char*>(data), len, error);
+  receivedLength_ = 0;
+  return res;
+}
+
 } // namespace rpc
 
 } // namespace aria2

+ 23 - 0
src/WebSocketSession.h

@@ -40,6 +40,7 @@
 #include <wslay/wslay.h>
 
 #include "SharedHandle.h"
+#include "ValueBaseJsonParser.h"
 
 namespace aria2 {
 
@@ -77,6 +78,15 @@ public:
   bool closeReceived();
   // Returns true if the close frame is sent.
   bool closeSent();
+  // Parses parital request body. This function returns the number of
+  // bytes processed if it succeeds, or negative error code.
+  ssize_t parseUpdate(const uint8_t* data, size_t len);
+  // Parses final part of request body and returns result.  The
+  // |error| will be the number of bytes processed if this function
+  // succeeds, or negative error code. Whether success or failure,
+  // this function resets parser state and receivedLength_.
+  SharedHandle<ValueBase> parseFinal(const uint8_t* data, size_t len,
+                                     ssize_t& error);
 
   const SharedHandle<SocketCore>& getSocket() const
   {
@@ -97,10 +107,23 @@ public:
   {
     command_ = command;
   }
+
+  bool getIgnorePayload() const
+  {
+    return ignorePayload_;
+  }
+
+  void setIgnorePayload(bool flag)
+  {
+    ignorePayload_ = flag;
+  }
 private:
   SharedHandle<SocketCore> socket_;
   DownloadEngine* e_;
   wslay_event_context_ptr wsctx_;
+  bool ignorePayload_;
+  int32_t receivedLength_;
+  json::ValueBaseJsonParser parser_;
   WebSocketInteractionCommand* command_;
 };