Browse Source

Authorize RPC multicalls only once.

Cache the auth status afterwards and just assume the token still matches
(within the same request, of course).
Nils Maier 11 years ago
parent
commit
855dfa0e2f

+ 10 - 3
src/HttpServerBodyCommand.cc

@@ -243,6 +243,7 @@ bool HttpServerBodyCommand::execute()
         case RPC_TYPE_JSONP: {
         case RPC_TYPE_JSONP: {
           std::string callback;
           std::string callback;
           std::unique_ptr<ValueBase> json;
           std::unique_ptr<ValueBase> json;
+          auto preauthorized = rpc::RpcRequest::MUST_AUTHORIZE;
           ssize_t error = 0;
           ssize_t error = 0;
           if(httpServer_->getRequestType() == RPC_TYPE_JSONP) {
           if(httpServer_->getRequestType() == RPC_TYPE_JSONP) {
             json::JsonGetParam param = json::decodeGetParams(query);
             json::JsonGetParam param = json::decodeGetParams(query);
@@ -273,7 +274,8 @@ bool HttpServerBodyCommand::execute()
           }
           }
           Dict* jsondict = downcast<Dict>(json);
           Dict* jsondict = downcast<Dict>(json);
           if(jsondict) {
           if(jsondict) {
-            rpc::RpcResponse res = rpc::processJsonRpcRequest(jsondict, e_);
+            rpc::RpcResponse res =
+              rpc::processJsonRpcRequest(jsondict, e_, preauthorized);
             sendJsonRpcResponse(res, callback);
             sendJsonRpcResponse(res, callback);
           } else {
           } else {
             List* jsonlist = downcast<List>(json);
             List* jsonlist = downcast<List>(json);
@@ -283,8 +285,13 @@ bool HttpServerBodyCommand::execute()
               for(List::ValueType::const_iterator i = jsonlist->begin(),
               for(List::ValueType::const_iterator i = jsonlist->begin(),
                     eoi = jsonlist->end(); i != eoi; ++i) {
                     eoi = jsonlist->end(); i != eoi; ++i) {
                 Dict* jsondict = downcast<Dict>(*i);
                 Dict* jsondict = downcast<Dict>(*i);
-                if(jsondict) {
-                  results.push_back(rpc::processJsonRpcRequest(jsondict, e_));
+                if (jsondict) {
+                  auto resp =
+                      rpc::processJsonRpcRequest(jsondict, e_, preauthorized);
+                  if (resp.code == 0) {
+                    preauthorized = rpc::RpcRequest::PREAUTHORIZED;
+                  }
+                  results.push_back(std::move(resp));
                 }
                 }
               }
               }
               sendJsonRpcBatchResponse(results, callback);
               sendJsonRpcBatchResponse(results, callback);

+ 3 - 1
src/RpcMethod.cc

@@ -85,9 +85,11 @@ void RpcMethod::authorize(RpcRequest& req, DownloadEngine* e)
       }
       }
     }
     }
   }
   }
-  if (!e || !e->validateToken(token)) {
+  if (!e || (req.authorization != RpcRequest::PREAUTHORIZED &&
+        !e->validateToken(token))) {
     throw DL_ABORT_EX("Unauthorized");
     throw DL_ABORT_EX("Unauthorized");
   }
   }
+  req.authorization = RpcRequest::PREAUTHORIZED;
 }
 }
 
 
 RpcResponse RpcMethod::execute(RpcRequest req, DownloadEngine* e)
 RpcResponse RpcMethod::execute(RpcRequest req, DownloadEngine* e)

+ 10 - 2
src/RpcMethodImpl.cc

@@ -1361,6 +1361,7 @@ std::unique_ptr<ValueBase> SystemMulticallRpcMethod::process
 {
 {
   const List* methodSpecs = checkRequiredParam<List>(req, 0);
   const List* methodSpecs = checkRequiredParam<List>(req, 0);
   auto list = List::g();
   auto list = List::g();
+  auto auth = RpcRequest::MUST_AUTHORIZE;
   for(auto & methodSpec : *methodSpecs) {
   for(auto & methodSpec : *methodSpecs) {
     Dict* methodDict = downcast<Dict>(methodSpec);
     Dict* methodDict = downcast<Dict>(methodSpec);
     if(!methodDict) {
     if(!methodDict) {
@@ -1388,12 +1389,19 @@ std::unique_ptr<ValueBase> SystemMulticallRpcMethod::process
     } else {
     } else {
       paramsList = List::g();
       paramsList = List::g();
     }
     }
-    RpcResponse res = getMethod(methodName->s())->execute
-      ({methodName->s(), std::move(paramsList), nullptr, req.jsonRpc}, e);
+    RpcRequest r = {
+      methodName->s(),
+      std::move(paramsList),
+      nullptr,
+      auth,
+      req.jsonRpc
+    };
+    RpcResponse res = getMethod(methodName->s())->execute(std::move(r), e);
     if(res.code == 0) {
     if(res.code == 0) {
       auto l = List::g();
       auto l = List::g();
       l->append(std::move(res.param));
       l->append(std::move(res.param));
       list->append(std::move(l));
       list->append(std::move(l));
+      auth = RpcRequest::PREAUTHORIZED;
     } else {
     } else {
       list->append(std::move(res.param));
       list->append(std::move(res.param));
     }
     }

+ 4 - 3
src/RpcRequest.cc

@@ -39,21 +39,22 @@ namespace aria2 {
 namespace rpc {
 namespace rpc {
 
 
 RpcRequest::RpcRequest()
 RpcRequest::RpcRequest()
-  : jsonRpc{false}
+  : authorization{RpcRequest::MUST_AUTHORIZE}, jsonRpc{false}
 {}
 {}
 
 
 RpcRequest::RpcRequest(std::string methodName,
 RpcRequest::RpcRequest(std::string methodName,
                        std::unique_ptr<List> params)
                        std::unique_ptr<List> params)
   : methodName{std::move(methodName)}, params{std::move(params)},
   : methodName{std::move(methodName)}, params{std::move(params)},
-    jsonRpc{false}
+    authorization{RpcRequest::MUST_AUTHORIZE}, jsonRpc{false}
 {}
 {}
 
 
 RpcRequest::RpcRequest(std::string methodName,
 RpcRequest::RpcRequest(std::string methodName,
                        std::unique_ptr<List> params,
                        std::unique_ptr<List> params,
                        std::unique_ptr<ValueBase> id,
                        std::unique_ptr<ValueBase> id,
+                       RpcRequest::authorization_t authorization,
                        bool jsonRpc)
                        bool jsonRpc)
   : methodName{std::move(methodName)}, params{std::move(params)},
   : methodName{std::move(methodName)}, params{std::move(params)},
-    id{std::move(id)}, jsonRpc{jsonRpc}
+    id{std::move(id)}, authorization{authorization}, jsonRpc{jsonRpc}
 {}
 {}
 
 
 } // namespace rpc
 } // namespace rpc

+ 7 - 0
src/RpcRequest.h

@@ -46,9 +46,15 @@ namespace aria2 {
 namespace rpc {
 namespace rpc {
 
 
 struct RpcRequest {
 struct RpcRequest {
+  enum authorization_t {
+    MUST_AUTHORIZE,
+    PREAUTHORIZED
+  };
+
   std::string methodName;
   std::string methodName;
   std::unique_ptr<List> params;
   std::unique_ptr<List> params;
   std::unique_ptr<ValueBase> id;
   std::unique_ptr<ValueBase> id;
+  authorization_t authorization;
   bool jsonRpc;
   bool jsonRpc;
 
 
   RpcRequest();
   RpcRequest();
@@ -59,6 +65,7 @@ struct RpcRequest {
   RpcRequest(std::string methodName,
   RpcRequest(std::string methodName,
              std::unique_ptr<List> params,
              std::unique_ptr<List> params,
              std::unique_ptr<ValueBase> id,
              std::unique_ptr<ValueBase> id,
+             authorization_t authorization,
              bool jsonRpc = false);
              bool jsonRpc = false);
 };
 };
 
 

+ 10 - 5
src/WebSocketSession.cc

@@ -161,6 +161,7 @@ void onMsgRecvCallback(wslay_event_context_ptr wsctx,
     // TODO Only process text frame
     // TODO Only process text frame
     ssize_t error = 0;
     ssize_t error = 0;
     auto json = wsSession->parseFinal(nullptr, 0, error);
     auto json = wsSession->parseFinal(nullptr, 0, error);
+    auto preauthorized = RpcRequest::MUST_AUTHORIZE;
     if(error < 0) {
     if(error < 0) {
       A2_LOG_INFO("Failed to parse JSON-RPC request");
       A2_LOG_INFO("Failed to parse JSON-RPC request");
       RpcResponse res
       RpcResponse res
@@ -169,9 +170,10 @@ void onMsgRecvCallback(wslay_event_context_ptr wsctx,
       return;
       return;
     }
     }
     Dict* jsondict = downcast<Dict>(json);
     Dict* jsondict = downcast<Dict>(json);
+    auto e = wsSession->getDownloadEngine();
     if(jsondict) {
     if(jsondict) {
-      RpcResponse res = processJsonRpcRequest(jsondict,
-                                              wsSession->getDownloadEngine());
+      RpcResponse res =
+        processJsonRpcRequest(jsondict, e, preauthorized);
       addResponse(wsSession, res);
       addResponse(wsSession, res);
     } else {
     } else {
       List* jsonlist = downcast<List>(json);
       List* jsonlist = downcast<List>(json);
@@ -181,9 +183,12 @@ void onMsgRecvCallback(wslay_event_context_ptr wsctx,
         for(List::ValueType::const_iterator i = jsonlist->begin(),
         for(List::ValueType::const_iterator i = jsonlist->begin(),
               eoi = jsonlist->end(); i != eoi; ++i) {
               eoi = jsonlist->end(); i != eoi; ++i) {
           Dict* jsondict = downcast<Dict>(*i);
           Dict* jsondict = downcast<Dict>(*i);
-          if(jsondict) {
-            results.push_back(processJsonRpcRequest
-                              (jsondict, wsSession->getDownloadEngine()));
+          if (jsondict) {
+            auto resp = processJsonRpcRequest(jsondict, e, preauthorized);
+            if (resp.code == 0) {
+              preauthorized = RpcRequest::PREAUTHORIZED;
+            }
+            results.push_back(std::move(resp));
           }
           }
         }
         }
         addResponse(wsSession, results);
         addResponse(wsSession, results);

+ 5 - 3
src/rpc_helper.cc

@@ -76,7 +76,8 @@ RpcResponse createJsonRpcErrorResponse(int code,
   return rpc::RpcResponse{code, std::move(params), std::move(id)};
   return rpc::RpcResponse{code, std::move(params), std::move(id)};
 }
 }
 
 
-RpcResponse processJsonRpcRequest(Dict* jsondict, DownloadEngine* e)
+RpcResponse processJsonRpcRequest(Dict* jsondict, DownloadEngine* e,
+                                  RpcRequest::authorization_t authorization)
 {
 {
   auto id = jsondict->popValue("id");
   auto id = jsondict->popValue("id");
   if(!id) {
   if(!id) {
@@ -99,8 +100,9 @@ RpcResponse processJsonRpcRequest(Dict* jsondict, DownloadEngine* e)
                                       std::move(id));
                                       std::move(id));
   }
   }
   A2_LOG_INFO(fmt("Executing RPC method %s", methodName->s().c_str()));
   A2_LOG_INFO(fmt("Executing RPC method %s", methodName->s().c_str()));
-  return getMethod(methodName->s())->execute
-    ({methodName->s(), std::move(params), std::move(id), true}, e);
+  RpcRequest req =
+    {methodName->s(), std::move(params), std::move(id), authorization, true};
+  return getMethod(methodName->s())->execute(std::move(req), e);
 }
 }
 
 
 } // namespace rpc
 } // namespace rpc

+ 4 - 2
src/rpc_helper.h

@@ -41,6 +41,8 @@
 #include <string>
 #include <string>
 #include <memory>
 #include <memory>
 
 
+#include "RpcRequest.h"
+
 namespace aria2 {
 namespace aria2 {
 
 
 class ValueBase;
 class ValueBase;
@@ -49,7 +51,6 @@ class DownloadEngine;
 
 
 namespace rpc {
 namespace rpc {
 
 
-struct RpcRequest;
 struct RpcResponse;
 struct RpcResponse;
 
 
 #ifdef ENABLE_XML_RPC
 #ifdef ENABLE_XML_RPC
@@ -63,7 +64,8 @@ RpcResponse createJsonRpcErrorResponse(int code,
                                        std::unique_ptr<ValueBase> id);
                                        std::unique_ptr<ValueBase> id);
 
 
 // Processes JSON-RPC request |jsondict| and returns the result.
 // Processes JSON-RPC request |jsondict| and returns the result.
-RpcResponse processJsonRpcRequest(Dict* jsondict, DownloadEngine* e);
+RpcResponse processJsonRpcRequest(Dict* jsondict, DownloadEngine* e,
+                                  RpcRequest::authorization_t authorization);
 
 
 } // namespace rpc
 } // namespace rpc
 
 

+ 8 - 0
test/RpcMethodTest.cc

@@ -200,6 +200,14 @@ void RpcMethodTest::testAuthorize()
     auto res = m.execute(std::move(req), e_.get());
     auto res = m.execute(std::move(req), e_.get());
     CPPUNIT_ASSERT_EQUAL(1, res.code);
     CPPUNIT_ASSERT_EQUAL(1, res.code);
   }
   }
+  // secret token set and bad token: prefixed parameter is given, but preauthorized
+  {
+    auto req = createReq(GetVersionRpcMethod::getMethodName());
+    req.authorization = RpcRequest::PREAUTHORIZED;
+    req.params->append("token:foo2");
+    auto res = m.execute(std::move(req), e_.get());
+    CPPUNIT_ASSERT_EQUAL(0, res.code);
+  }
 }
 }
 
 
 void RpcMethodTest::testAddUri()
 void RpcMethodTest::testAddUri()