Explorar o código

Authorize RPC multicalls only once.

Cache the auth status afterwards and just assume the token still matches
(within the same request, of course).
Nils Maier %!s(int64=11) %!d(string=hai) anos
pai
achega
855dfa0e2f

+ 10 - 3
src/HttpServerBodyCommand.cc

@@ -243,6 +243,7 @@ bool HttpServerBodyCommand::execute()
         case RPC_TYPE_JSONP: {
           std::string callback;
           std::unique_ptr<ValueBase> json;
+          auto preauthorized = rpc::RpcRequest::MUST_AUTHORIZE;
           ssize_t error = 0;
           if(httpServer_->getRequestType() == RPC_TYPE_JSONP) {
             json::JsonGetParam param = json::decodeGetParams(query);
@@ -273,7 +274,8 @@ bool HttpServerBodyCommand::execute()
           }
           Dict* jsondict = downcast<Dict>(json);
           if(jsondict) {
-            rpc::RpcResponse res = rpc::processJsonRpcRequest(jsondict, e_);
+            rpc::RpcResponse res =
+              rpc::processJsonRpcRequest(jsondict, e_, preauthorized);
             sendJsonRpcResponse(res, callback);
           } else {
             List* jsonlist = downcast<List>(json);
@@ -283,8 +285,13 @@ bool HttpServerBodyCommand::execute()
               for(List::ValueType::const_iterator i = jsonlist->begin(),
                     eoi = jsonlist->end(); i != eoi; ++i) {
                 Dict* jsondict = downcast<Dict>(*i);
-                if(jsondict) {
-                  results.push_back(rpc::processJsonRpcRequest(jsondict, e_));
+                if (jsondict) {
+                  auto resp =
+                      rpc::processJsonRpcRequest(jsondict, e_, preauthorized);
+                  if (resp.code == 0) {
+                    preauthorized = rpc::RpcRequest::PREAUTHORIZED;
+                  }
+                  results.push_back(std::move(resp));
                 }
               }
               sendJsonRpcBatchResponse(results, callback);

+ 3 - 1
src/RpcMethod.cc

@@ -85,9 +85,11 @@ void RpcMethod::authorize(RpcRequest& req, DownloadEngine* e)
       }
     }
   }
-  if (!e || !e->validateToken(token)) {
+  if (!e || (req.authorization != RpcRequest::PREAUTHORIZED &&
+        !e->validateToken(token))) {
     throw DL_ABORT_EX("Unauthorized");
   }
+  req.authorization = RpcRequest::PREAUTHORIZED;
 }
 
 RpcResponse RpcMethod::execute(RpcRequest req, DownloadEngine* e)

+ 10 - 2
src/RpcMethodImpl.cc

@@ -1361,6 +1361,7 @@ std::unique_ptr<ValueBase> SystemMulticallRpcMethod::process
 {
   const List* methodSpecs = checkRequiredParam<List>(req, 0);
   auto list = List::g();
+  auto auth = RpcRequest::MUST_AUTHORIZE;
   for(auto & methodSpec : *methodSpecs) {
     Dict* methodDict = downcast<Dict>(methodSpec);
     if(!methodDict) {
@@ -1388,12 +1389,19 @@ std::unique_ptr<ValueBase> SystemMulticallRpcMethod::process
     } else {
       paramsList = List::g();
     }
-    RpcResponse res = getMethod(methodName->s())->execute
-      ({methodName->s(), std::move(paramsList), nullptr, req.jsonRpc}, e);
+    RpcRequest r = {
+      methodName->s(),
+      std::move(paramsList),
+      nullptr,
+      auth,
+      req.jsonRpc
+    };
+    RpcResponse res = getMethod(methodName->s())->execute(std::move(r), e);
     if(res.code == 0) {
       auto l = List::g();
       l->append(std::move(res.param));
       list->append(std::move(l));
+      auth = RpcRequest::PREAUTHORIZED;
     } else {
       list->append(std::move(res.param));
     }

+ 4 - 3
src/RpcRequest.cc

@@ -39,21 +39,22 @@ namespace aria2 {
 namespace rpc {
 
 RpcRequest::RpcRequest()
-  : jsonRpc{false}
+  : authorization{RpcRequest::MUST_AUTHORIZE}, jsonRpc{false}
 {}
 
 RpcRequest::RpcRequest(std::string methodName,
                        std::unique_ptr<List> params)
   : methodName{std::move(methodName)}, params{std::move(params)},
-    jsonRpc{false}
+    authorization{RpcRequest::MUST_AUTHORIZE}, jsonRpc{false}
 {}
 
 RpcRequest::RpcRequest(std::string methodName,
                        std::unique_ptr<List> params,
                        std::unique_ptr<ValueBase> id,
+                       RpcRequest::authorization_t authorization,
                        bool jsonRpc)
   : methodName{std::move(methodName)}, params{std::move(params)},
-    id{std::move(id)}, jsonRpc{jsonRpc}
+    id{std::move(id)}, authorization{authorization}, jsonRpc{jsonRpc}
 {}
 
 } // namespace rpc

+ 7 - 0
src/RpcRequest.h

@@ -46,9 +46,15 @@ namespace aria2 {
 namespace rpc {
 
 struct RpcRequest {
+  enum authorization_t {
+    MUST_AUTHORIZE,
+    PREAUTHORIZED
+  };
+
   std::string methodName;
   std::unique_ptr<List> params;
   std::unique_ptr<ValueBase> id;
+  authorization_t authorization;
   bool jsonRpc;
 
   RpcRequest();
@@ -59,6 +65,7 @@ struct RpcRequest {
   RpcRequest(std::string methodName,
              std::unique_ptr<List> params,
              std::unique_ptr<ValueBase> id,
+             authorization_t authorization,
              bool jsonRpc = false);
 };
 

+ 10 - 5
src/WebSocketSession.cc

@@ -161,6 +161,7 @@ void onMsgRecvCallback(wslay_event_context_ptr wsctx,
     // TODO Only process text frame
     ssize_t error = 0;
     auto json = wsSession->parseFinal(nullptr, 0, error);
+    auto preauthorized = RpcRequest::MUST_AUTHORIZE;
     if(error < 0) {
       A2_LOG_INFO("Failed to parse JSON-RPC request");
       RpcResponse res
@@ -169,9 +170,10 @@ void onMsgRecvCallback(wslay_event_context_ptr wsctx,
       return;
     }
     Dict* jsondict = downcast<Dict>(json);
+    auto e = wsSession->getDownloadEngine();
     if(jsondict) {
-      RpcResponse res = processJsonRpcRequest(jsondict,
-                                              wsSession->getDownloadEngine());
+      RpcResponse res =
+        processJsonRpcRequest(jsondict, e, preauthorized);
       addResponse(wsSession, res);
     } else {
       List* jsonlist = downcast<List>(json);
@@ -181,9 +183,12 @@ void onMsgRecvCallback(wslay_event_context_ptr wsctx,
         for(List::ValueType::const_iterator i = jsonlist->begin(),
               eoi = jsonlist->end(); i != eoi; ++i) {
           Dict* jsondict = downcast<Dict>(*i);
-          if(jsondict) {
-            results.push_back(processJsonRpcRequest
-                              (jsondict, wsSession->getDownloadEngine()));
+          if (jsondict) {
+            auto resp = processJsonRpcRequest(jsondict, e, preauthorized);
+            if (resp.code == 0) {
+              preauthorized = RpcRequest::PREAUTHORIZED;
+            }
+            results.push_back(std::move(resp));
           }
         }
         addResponse(wsSession, results);

+ 5 - 3
src/rpc_helper.cc

@@ -76,7 +76,8 @@ RpcResponse createJsonRpcErrorResponse(int code,
   return rpc::RpcResponse{code, std::move(params), std::move(id)};
 }
 
-RpcResponse processJsonRpcRequest(Dict* jsondict, DownloadEngine* e)
+RpcResponse processJsonRpcRequest(Dict* jsondict, DownloadEngine* e,
+                                  RpcRequest::authorization_t authorization)
 {
   auto id = jsondict->popValue("id");
   if(!id) {
@@ -99,8 +100,9 @@ RpcResponse processJsonRpcRequest(Dict* jsondict, DownloadEngine* e)
                                       std::move(id));
   }
   A2_LOG_INFO(fmt("Executing RPC method %s", methodName->s().c_str()));
-  return getMethod(methodName->s())->execute
-    ({methodName->s(), std::move(params), std::move(id), true}, e);
+  RpcRequest req =
+    {methodName->s(), std::move(params), std::move(id), authorization, true};
+  return getMethod(methodName->s())->execute(std::move(req), e);
 }
 
 } // namespace rpc

+ 4 - 2
src/rpc_helper.h

@@ -41,6 +41,8 @@
 #include <string>
 #include <memory>
 
+#include "RpcRequest.h"
+
 namespace aria2 {
 
 class ValueBase;
@@ -49,7 +51,6 @@ class DownloadEngine;
 
 namespace rpc {
 
-struct RpcRequest;
 struct RpcResponse;
 
 #ifdef ENABLE_XML_RPC
@@ -63,7 +64,8 @@ RpcResponse createJsonRpcErrorResponse(int code,
                                        std::unique_ptr<ValueBase> id);
 
 // Processes JSON-RPC request |jsondict| and returns the result.
-RpcResponse processJsonRpcRequest(Dict* jsondict, DownloadEngine* e);
+RpcResponse processJsonRpcRequest(Dict* jsondict, DownloadEngine* e,
+                                  RpcRequest::authorization_t authorization);
 
 } // namespace rpc
 

+ 8 - 0
test/RpcMethodTest.cc

@@ -200,6 +200,14 @@ void RpcMethodTest::testAuthorize()
     auto res = m.execute(std::move(req), e_.get());
     CPPUNIT_ASSERT_EQUAL(1, res.code);
   }
+  // secret token set and bad token: prefixed parameter is given, but preauthorized
+  {
+    auto req = createReq(GetVersionRpcMethod::getMethodName());
+    req.authorization = RpcRequest::PREAUTHORIZED;
+    req.params->append("token:foo2");
+    auto res = m.execute(std::move(req), e_.get());
+    CPPUNIT_ASSERT_EQUAL(0, res.code);
+  }
 }
 
 void RpcMethodTest::testAddUri()