Quellcode durchsuchen

Authorize RPC multicalls only once.

Cache the auth status afterwards and just assume the token still matches
(within the same request, of course).
Nils Maier vor 11 Jahren
Ursprung
Commit
855dfa0e2f

+ 10 - 3
src/HttpServerBodyCommand.cc

@@ -243,6 +243,7 @@ bool HttpServerBodyCommand::execute()
         case RPC_TYPE_JSONP: {
           std::string callback;
           std::unique_ptr<ValueBase> json;
+          auto preauthorized = rpc::RpcRequest::MUST_AUTHORIZE;
           ssize_t error = 0;
           if(httpServer_->getRequestType() == RPC_TYPE_JSONP) {
             json::JsonGetParam param = json::decodeGetParams(query);
@@ -273,7 +274,8 @@ bool HttpServerBodyCommand::execute()
           }
           Dict* jsondict = downcast<Dict>(json);
           if(jsondict) {
-            rpc::RpcResponse res = rpc::processJsonRpcRequest(jsondict, e_);
+            rpc::RpcResponse res =
+              rpc::processJsonRpcRequest(jsondict, e_, preauthorized);
             sendJsonRpcResponse(res, callback);
           } else {
             List* jsonlist = downcast<List>(json);
@@ -283,8 +285,13 @@ bool HttpServerBodyCommand::execute()
               for(List::ValueType::const_iterator i = jsonlist->begin(),
                     eoi = jsonlist->end(); i != eoi; ++i) {
                 Dict* jsondict = downcast<Dict>(*i);
-                if(jsondict) {
-                  results.push_back(rpc::processJsonRpcRequest(jsondict, e_));
+                if (jsondict) {
+                  auto resp =
+                      rpc::processJsonRpcRequest(jsondict, e_, preauthorized);
+                  if (resp.code == 0) {
+                    preauthorized = rpc::RpcRequest::PREAUTHORIZED;
+                  }
+                  results.push_back(std::move(resp));
                 }
               }
               sendJsonRpcBatchResponse(results, callback);

+ 3 - 1
src/RpcMethod.cc

@@ -85,9 +85,11 @@ void RpcMethod::authorize(RpcRequest& req, DownloadEngine* e)
       }
     }
   }
-  if (!e || !e->validateToken(token)) {
+  if (!e || (req.authorization != RpcRequest::PREAUTHORIZED &&
+        !e->validateToken(token))) {
     throw DL_ABORT_EX("Unauthorized");
   }
+  req.authorization = RpcRequest::PREAUTHORIZED;
 }
 
 RpcResponse RpcMethod::execute(RpcRequest req, DownloadEngine* e)

+ 10 - 2
src/RpcMethodImpl.cc

@@ -1361,6 +1361,7 @@ std::unique_ptr<ValueBase> SystemMulticallRpcMethod::process
 {
   const List* methodSpecs = checkRequiredParam<List>(req, 0);
   auto list = List::g();
+  auto auth = RpcRequest::MUST_AUTHORIZE;
   for(auto & methodSpec : *methodSpecs) {
     Dict* methodDict = downcast<Dict>(methodSpec);
     if(!methodDict) {
@@ -1388,12 +1389,19 @@ std::unique_ptr<ValueBase> SystemMulticallRpcMethod::process
     } else {
       paramsList = List::g();
     }
-    RpcResponse res = getMethod(methodName->s())->execute
-      ({methodName->s(), std::move(paramsList), nullptr, req.jsonRpc}, e);
+    RpcRequest r = {
+      methodName->s(),
+      std::move(paramsList),
+      nullptr,
+      auth,
+      req.jsonRpc
+    };
+    RpcResponse res = getMethod(methodName->s())->execute(std::move(r), e);
     if(res.code == 0) {
       auto l = List::g();
       l->append(std::move(res.param));
       list->append(std::move(l));
+      auth = RpcRequest::PREAUTHORIZED;
     } else {
       list->append(std::move(res.param));
     }

+ 4 - 3
src/RpcRequest.cc

@@ -39,21 +39,22 @@ namespace aria2 {
 namespace rpc {
 
 RpcRequest::RpcRequest()
-  : jsonRpc{false}
+  : authorization{RpcRequest::MUST_AUTHORIZE}, jsonRpc{false}
 {}
 
 RpcRequest::RpcRequest(std::string methodName,
                        std::unique_ptr<List> params)
   : methodName{std::move(methodName)}, params{std::move(params)},
-    jsonRpc{false}
+    authorization{RpcRequest::MUST_AUTHORIZE}, jsonRpc{false}
 {}
 
 RpcRequest::RpcRequest(std::string methodName,
                        std::unique_ptr<List> params,
                        std::unique_ptr<ValueBase> id,
+                       RpcRequest::authorization_t authorization,
                        bool jsonRpc)
   : methodName{std::move(methodName)}, params{std::move(params)},
-    id{std::move(id)}, jsonRpc{jsonRpc}
+    id{std::move(id)}, authorization{authorization}, jsonRpc{jsonRpc}
 {}
 
 } // namespace rpc

+ 7 - 0
src/RpcRequest.h

@@ -46,9 +46,15 @@ namespace aria2 {
 namespace rpc {
 
 struct RpcRequest {
+  enum authorization_t {
+    MUST_AUTHORIZE,
+    PREAUTHORIZED
+  };
+
   std::string methodName;
   std::unique_ptr<List> params;
   std::unique_ptr<ValueBase> id;
+  authorization_t authorization;
   bool jsonRpc;
 
   RpcRequest();
@@ -59,6 +65,7 @@ struct RpcRequest {
   RpcRequest(std::string methodName,
              std::unique_ptr<List> params,
              std::unique_ptr<ValueBase> id,
+             authorization_t authorization,
              bool jsonRpc = false);
 };
 

+ 10 - 5
src/WebSocketSession.cc

@@ -161,6 +161,7 @@ void onMsgRecvCallback(wslay_event_context_ptr wsctx,
     // TODO Only process text frame
     ssize_t error = 0;
     auto json = wsSession->parseFinal(nullptr, 0, error);
+    auto preauthorized = RpcRequest::MUST_AUTHORIZE;
     if(error < 0) {
       A2_LOG_INFO("Failed to parse JSON-RPC request");
       RpcResponse res
@@ -169,9 +170,10 @@ void onMsgRecvCallback(wslay_event_context_ptr wsctx,
       return;
     }
     Dict* jsondict = downcast<Dict>(json);
+    auto e = wsSession->getDownloadEngine();
     if(jsondict) {
-      RpcResponse res = processJsonRpcRequest(jsondict,
-                                              wsSession->getDownloadEngine());
+      RpcResponse res =
+        processJsonRpcRequest(jsondict, e, preauthorized);
       addResponse(wsSession, res);
     } else {
       List* jsonlist = downcast<List>(json);
@@ -181,9 +183,12 @@ void onMsgRecvCallback(wslay_event_context_ptr wsctx,
         for(List::ValueType::const_iterator i = jsonlist->begin(),
               eoi = jsonlist->end(); i != eoi; ++i) {
           Dict* jsondict = downcast<Dict>(*i);
-          if(jsondict) {
-            results.push_back(processJsonRpcRequest
-                              (jsondict, wsSession->getDownloadEngine()));
+          if (jsondict) {
+            auto resp = processJsonRpcRequest(jsondict, e, preauthorized);
+            if (resp.code == 0) {
+              preauthorized = RpcRequest::PREAUTHORIZED;
+            }
+            results.push_back(std::move(resp));
           }
         }
         addResponse(wsSession, results);

+ 5 - 3
src/rpc_helper.cc

@@ -76,7 +76,8 @@ RpcResponse createJsonRpcErrorResponse(int code,
   return rpc::RpcResponse{code, std::move(params), std::move(id)};
 }
 
-RpcResponse processJsonRpcRequest(Dict* jsondict, DownloadEngine* e)
+RpcResponse processJsonRpcRequest(Dict* jsondict, DownloadEngine* e,
+                                  RpcRequest::authorization_t authorization)
 {
   auto id = jsondict->popValue("id");
   if(!id) {
@@ -99,8 +100,9 @@ RpcResponse processJsonRpcRequest(Dict* jsondict, DownloadEngine* e)
                                       std::move(id));
   }
   A2_LOG_INFO(fmt("Executing RPC method %s", methodName->s().c_str()));
-  return getMethod(methodName->s())->execute
-    ({methodName->s(), std::move(params), std::move(id), true}, e);
+  RpcRequest req =
+    {methodName->s(), std::move(params), std::move(id), authorization, true};
+  return getMethod(methodName->s())->execute(std::move(req), e);
 }
 
 } // namespace rpc

+ 4 - 2
src/rpc_helper.h

@@ -41,6 +41,8 @@
 #include <string>
 #include <memory>
 
+#include "RpcRequest.h"
+
 namespace aria2 {
 
 class ValueBase;
@@ -49,7 +51,6 @@ class DownloadEngine;
 
 namespace rpc {
 
-struct RpcRequest;
 struct RpcResponse;
 
 #ifdef ENABLE_XML_RPC
@@ -63,7 +64,8 @@ RpcResponse createJsonRpcErrorResponse(int code,
                                        std::unique_ptr<ValueBase> id);
 
 // Processes JSON-RPC request |jsondict| and returns the result.
-RpcResponse processJsonRpcRequest(Dict* jsondict, DownloadEngine* e);
+RpcResponse processJsonRpcRequest(Dict* jsondict, DownloadEngine* e,
+                                  RpcRequest::authorization_t authorization);
 
 } // namespace rpc
 

+ 8 - 0
test/RpcMethodTest.cc

@@ -200,6 +200,14 @@ void RpcMethodTest::testAuthorize()
     auto res = m.execute(std::move(req), e_.get());
     CPPUNIT_ASSERT_EQUAL(1, res.code);
   }
+  // secret token set and bad token: prefixed parameter is given, but preauthorized
+  {
+    auto req = createReq(GetVersionRpcMethod::getMethodName());
+    req.authorization = RpcRequest::PREAUTHORIZED;
+    req.params->append("token:foo2");
+    auto res = m.execute(std::move(req), e_.get());
+    CPPUNIT_ASSERT_EQUAL(0, res.code);
+  }
 }
 
 void RpcMethodTest::testAddUri()