Procházet zdrojové kódy

Use std::unique_ptr for StreamFilter instead of std::shared_ptr

Tatsuhiro Tsujikawa před 12 roky
rodič
revize
57f1902ee1

+ 7 - 6
src/ChunkedDecodingStreamFilter.cc

@@ -65,12 +65,13 @@ enum {
 } // namespace
 
 ChunkedDecodingStreamFilter::ChunkedDecodingStreamFilter
-(const std::shared_ptr<StreamFilter>& delegate):
-  StreamFilter(delegate),
-  state_(PREV_CHUNK_SIZE),
-  chunkSize_(0),
-  chunkRemaining_(0),
-  bytesProcessed_(0) {}
+(std::unique_ptr<StreamFilter> delegate)
+  : StreamFilter{std::move(delegate)},
+    state_{PREV_CHUNK_SIZE},
+    chunkSize_{0},
+    chunkRemaining_{0},
+    bytesProcessed_{0}
+{}
 
 ChunkedDecodingStreamFilter::~ChunkedDecodingStreamFilter() {}
 

+ 1 - 1
src/ChunkedDecodingStreamFilter.h

@@ -47,7 +47,7 @@ private:
   size_t bytesProcessed_;
 public:
   ChunkedDecodingStreamFilter
-  (const std::shared_ptr<StreamFilter>& delegate = std::shared_ptr<StreamFilter>());
+  (std::unique_ptr<StreamFilter> delegate = std::unique_ptr<StreamFilter>{});
 
   virtual ~ChunkedDecodingStreamFilter();
 

+ 5 - 6
src/DownloadCommand.cc

@@ -108,9 +108,8 @@ DownloadCommand::DownloadCommand
   peerStat_->downloadStart();
   getSegmentMan()->registerPeerStat(peerStat_);
 
-  WrDiskCache* wrDiskCache = getPieceStorage()->getWrDiskCache();
-  streamFilter_.reset(new SinkStreamFilter(wrDiskCache,
-                                           pieceHashValidationEnabled_));
+  streamFilter_ = make_unique<SinkStreamFilter>
+    (getPieceStorage()->getWrDiskCache(), pieceHashValidationEnabled_);
   streamFilter_->init();
   sinkFilterOnly_ = true;
   checkSocketRecvBuffer();
@@ -410,13 +409,13 @@ void DownloadCommand::completeSegment(cuid_t cuid,
 }
 
 void DownloadCommand::installStreamFilter
-(const std::shared_ptr<StreamFilter>& streamFilter)
+(std::unique_ptr<StreamFilter> streamFilter)
 {
   if(!streamFilter) {
     return;
   }
-  streamFilter->installDelegate(streamFilter_);
-  streamFilter_ = streamFilter;
+  streamFilter->installDelegate(std::move(streamFilter_));
+  streamFilter_ = std::move(streamFilter);
   const std::string& name = streamFilter_->getName();
   sinkFilterOnly_ = util::endsWith(name, SinkStreamFilter::NAME);
 }

+ 3 - 3
src/DownloadCommand.h

@@ -69,7 +69,7 @@ private:
 
   void completeSegment(cuid_t cuid, const std::shared_ptr<Segment>& segment);
 
-  std::shared_ptr<StreamFilter> streamFilter_;
+  std::unique_ptr<StreamFilter> streamFilter_;
 
   bool sinkFilterOnly_;
 protected:
@@ -89,12 +89,12 @@ public:
                   const std::shared_ptr<SocketRecvBuffer>& socketRecvBuffer);
   virtual ~DownloadCommand();
 
-  const std::shared_ptr<StreamFilter>& getStreamFilter() const
+  const std::unique_ptr<StreamFilter>& getStreamFilter() const
   {
     return streamFilter_;
   }
 
-  void installStreamFilter(const std::shared_ptr<StreamFilter>& streamFilter);
+  void installStreamFilter(std::unique_ptr<StreamFilter> streamFilter);
 
   void setStartupIdleTime(time_t startupIdleTime)
   {

+ 6 - 2
src/GZipDecodingStreamFilter.cc

@@ -44,8 +44,12 @@ namespace aria2 {
 const std::string GZipDecodingStreamFilter::NAME("GZipDecodingStreamFilter");
 
 GZipDecodingStreamFilter::GZipDecodingStreamFilter
-(const std::shared_ptr<StreamFilter>& delegate):
-  StreamFilter(delegate), strm_(0), finished_(false), bytesProcessed_(0) {}
+(std::unique_ptr<StreamFilter> delegate)
+  : StreamFilter{std::move(delegate)},
+    strm_{nullptr},
+    finished_{false},
+    bytesProcessed_{0}
+{}
 
 GZipDecodingStreamFilter::~GZipDecodingStreamFilter()
 {

+ 1 - 1
src/GZipDecodingStreamFilter.h

@@ -52,7 +52,7 @@ private:
   static const size_t OUTBUF_LENGTH = 16*1024;
 public:
   GZipDecodingStreamFilter
-  (const std::shared_ptr<StreamFilter>& delegate = std::shared_ptr<StreamFilter>());
+  (std::unique_ptr<StreamFilter> delegate = std::unique_ptr<StreamFilter>{});
 
   virtual ~GZipDecodingStreamFilter();
 

+ 8 - 8
src/HttpResponse.cc

@@ -189,17 +189,17 @@ const std::string& HttpResponse::getTransferEncoding() const
   return httpHeader_->find(HttpHeader::TRANSFER_ENCODING);
 }
 
-std::shared_ptr<StreamFilter> HttpResponse::getTransferEncodingStreamFilter() const
+std::unique_ptr<StreamFilter>
+HttpResponse::getTransferEncodingStreamFilter() const
 {
-  std::shared_ptr<StreamFilter> filter;
   // TODO Transfer-Encoding header field can contains multiple tokens. We should
   // parse the field and retrieve each token.
   if(isTransferEncodingSpecified()) {
     if(util::strieq(getTransferEncoding(), "chunked")) {
-      filter.reset(new ChunkedDecodingStreamFilter());
+      return make_unique<ChunkedDecodingStreamFilter>();
     }
   }
-  return filter;
+  return std::unique_ptr<StreamFilter>{};
 }
 
 bool HttpResponse::isContentEncodingSpecified() const
@@ -212,16 +212,16 @@ const std::string& HttpResponse::getContentEncoding() const
   return httpHeader_->find(HttpHeader::CONTENT_ENCODING);
 }
 
-std::shared_ptr<StreamFilter> HttpResponse::getContentEncodingStreamFilter() const
+std::unique_ptr<StreamFilter>
+HttpResponse::getContentEncodingStreamFilter() const
 {
-  std::shared_ptr<StreamFilter> filter;
 #ifdef HAVE_ZLIB
   if(util::strieq(getContentEncoding(), "gzip") ||
      util::strieq(getContentEncoding(), "deflate")) {
-    filter.reset(new GZipDecodingStreamFilter());
+    return make_unique<GZipDecodingStreamFilter>();
   }
 #endif // HAVE_ZLIB
-  return filter;
+  return std::unique_ptr<StreamFilter>{};
 }
 
 int64_t HttpResponse::getContentLength() const

+ 2 - 2
src/HttpResponse.h

@@ -86,13 +86,13 @@ public:
 
   const std::string& getTransferEncoding() const;
 
-  std::shared_ptr<StreamFilter> getTransferEncodingStreamFilter() const;
+  std::unique_ptr<StreamFilter> getTransferEncodingStreamFilter() const;
 
   bool isContentEncodingSpecified() const;
 
   const std::string& getContentEncoding() const;
 
-  std::shared_ptr<StreamFilter> getContentEncodingStreamFilter() const;
+  std::unique_ptr<StreamFilter> getContentEncodingStreamFilter() const;
 
   int64_t getContentLength() const;
 

+ 27 - 31
src/HttpResponseCommand.cc

@@ -84,51 +84,44 @@
 namespace aria2 {
 
 namespace {
-std::shared_ptr<StreamFilter> getTransferEncodingStreamFilter
+std::unique_ptr<StreamFilter> getTransferEncodingStreamFilter
 (HttpResponse* httpResponse,
- const std::shared_ptr<StreamFilter>& delegate = std::shared_ptr<StreamFilter>())
+ std::unique_ptr<StreamFilter> delegate = std::unique_ptr<StreamFilter>{})
 {
-  std::shared_ptr<StreamFilter> filter;
   if(httpResponse->isTransferEncodingSpecified()) {
-    filter = httpResponse->getTransferEncodingStreamFilter();
+    auto filter = httpResponse->getTransferEncodingStreamFilter();
     if(!filter) {
       throw DL_ABORT_EX
         (fmt(EX_TRANSFER_ENCODING_NOT_SUPPORTED,
              httpResponse->getTransferEncoding().c_str()));
     }
     filter->init();
-    filter->installDelegate(delegate);
+    filter->installDelegate(std::move(delegate));
+    return filter;
   }
-  if(!filter) {
-    filter = delegate;
-  }
-  return filter;
+  return delegate;
 }
 } // namespace
 
 namespace {
-std::shared_ptr<StreamFilter> getContentEncodingStreamFilter
+std::unique_ptr<StreamFilter> getContentEncodingStreamFilter
 (HttpResponse* httpResponse,
- const std::shared_ptr<StreamFilter>& delegate = std::shared_ptr<StreamFilter>())
+ std::unique_ptr<StreamFilter> delegate = std::unique_ptr<StreamFilter>{})
 {
-  std::shared_ptr<StreamFilter> filter;
   if(httpResponse->isContentEncodingSpecified()) {
-    filter = httpResponse->getContentEncodingStreamFilter();
+    auto filter = httpResponse->getContentEncodingStreamFilter();
     if(!filter) {
       A2_LOG_INFO
         (fmt("Content-Encoding %s is specified, but the current implementation"
              "doesn't support it. The decoding process is skipped and the"
              "downloaded content will be still encoded.",
              httpResponse->getContentEncoding().c_str()));
-    } else {
-      filter->init();
-      filter->installDelegate(delegate);
     }
+    filter->init();
+    filter->installDelegate(std::move(delegate));
+    return filter;
   }
-  if(!filter) {
-    filter = delegate;
-  }
-  return filter;
+  return delegate;
 }
 } // namespace
 
@@ -311,11 +304,13 @@ bool HttpResponseCommand::executeInternal()
         (httpResponse.get(),
          getContentEncodingStreamFilter(httpResponse.get()));
       getDownloadEngine()->addCommand
-        (createHttpDownloadCommand(std::move(httpResponse), teFilter));
+        (createHttpDownloadCommand(std::move(httpResponse),
+                                   std::move(teFilter)));
     } else {
       auto teFilter = getTransferEncodingStreamFilter(httpResponse.get());
       getDownloadEngine()->addCommand
-        (createHttpDownloadCommand(std::move(httpResponse), teFilter));
+        (createHttpDownloadCommand(std::move(httpResponse),
+                                   std::move(teFilter)));
     }
     return true;
   }
@@ -375,7 +370,8 @@ bool HttpResponseCommand::handleDefaultEncoding
      !getRequest()->isPipeliningEnabled()) {
     auto teFilter = getTransferEncodingStreamFilter(httpResponse.get());
     checkEntry->pushNextCommand
-      (createHttpDownloadCommand(std::move(httpResponse), teFilter));
+      (createHttpDownloadCommand(std::move(httpResponse),
+                                 std::move(teFilter)));
   } else {
     getSegmentMan()->cancelSegment(getCuid());
     getFileEntry()->poolRequest(getRequest());
@@ -477,7 +473,8 @@ bool HttpResponseCommand::handleOtherEncoding
   getSegmentMan()->getSegmentWithIndex(getCuid(), 0);
 
   getDownloadEngine()->addCommand
-    (createHttpDownloadCommand(std::move(httpResponse), streamFilter));
+    (createHttpDownloadCommand(std::move(httpResponse),
+                               std::move(streamFilter)));
   return true;
 }
 
@@ -492,7 +489,7 @@ bool HttpResponseCommand::skipResponseBody
     (getCuid(), getRequest(), getFileEntry(), getRequestGroup(),
      httpConnection_, std::move(httpResponse),
      getDownloadEngine(), getSocket());
-  command->installStreamFilter(filter);
+  command->installStreamFilter(std::move(filter));
 
   // If request method is HEAD or the response body is zero-length,
   // set command's status to real time so that avoid read check blocking
@@ -510,11 +507,10 @@ bool HttpResponseCommand::skipResponseBody
 }
 
 namespace {
-bool decideFileAllocation
-(const std::shared_ptr<StreamFilter>& filter)
+bool decideFileAllocation(StreamFilter* filter)
 {
 #ifdef HAVE_ZLIB
-  for(std::shared_ptr<StreamFilter> f = filter; f; f = f->getDelegate()){
+  for(StreamFilter* f = filter; f; f = f->getDelegate().get()){
     // Since the compressed file's length are returned in the response header
     // and the decompressed file size is unknown at this point, disable file
     // allocation here.
@@ -530,7 +526,7 @@ bool decideFileAllocation
 std::unique_ptr<HttpDownloadCommand>
 HttpResponseCommand::createHttpDownloadCommand
 (std::unique_ptr<HttpResponse> httpResponse,
- const std::shared_ptr<StreamFilter>& filter)
+ std::unique_ptr<StreamFilter> filter)
 {
 
   auto command = make_unique<HttpDownloadCommand>
@@ -541,11 +537,11 @@ HttpResponseCommand::createHttpDownloadCommand
   command->setStartupIdleTime(getOption()->getAsInt(PREF_STARTUP_IDLE_TIME));
   command->setLowestDownloadSpeedLimit
     (getOption()->getAsInt(PREF_LOWEST_SPEED_LIMIT));
-  command->installStreamFilter(filter);
   if(getRequestGroup()->isFileAllocationEnabled() &&
-     !decideFileAllocation(filter)) {
+     !decideFileAllocation(filter.get())) {
     getRequestGroup()->setFileAllocationEnabled(false);
   }
+  command->installStreamFilter(std::move(filter));
   getRequestGroup()->getURISelector()->tuneDownloadCommand
     (getFileEntry()->getRemainingUris(), command.get());
 

+ 1 - 1
src/HttpResponseCommand.h

@@ -70,7 +70,7 @@ private:
   std::unique_ptr<HttpDownloadCommand>
   createHttpDownloadCommand
   (std::unique_ptr<HttpResponse> httpResponse,
-   const std::shared_ptr<StreamFilter>& streamFilter);
+   std::unique_ptr<StreamFilter> streamFilter);
 
   void updateLastModifiedTime(const Time& lastModified);
 

+ 3 - 3
src/HttpSkipResponseCommand.cc

@@ -87,13 +87,13 @@ HttpSkipResponseCommand::HttpSkipResponseCommand
 HttpSkipResponseCommand::~HttpSkipResponseCommand() {}
 
 void HttpSkipResponseCommand::installStreamFilter
-(const std::shared_ptr<StreamFilter>& streamFilter)
+(std::unique_ptr<StreamFilter> streamFilter)
 {
   if(!streamFilter) {
     return;
   }
-  streamFilter->installDelegate(streamFilter_);
-  streamFilter_ = streamFilter;
+  streamFilter->installDelegate(std::move(streamFilter_));
+  streamFilter_ = std::move(streamFilter);
   const std::string& name = streamFilter_->getName();
   sinkFilterOnly_ = util::endsWith(name, SinkStreamFilter::NAME);
 }

+ 2 - 2
src/HttpSkipResponseCommand.h

@@ -49,7 +49,7 @@ private:
 
   std::unique_ptr<HttpResponse> httpResponse_;
 
-  std::shared_ptr<StreamFilter> streamFilter_;
+  std::unique_ptr<StreamFilter> streamFilter_;
 
   bool sinkFilterOnly_;
 
@@ -75,7 +75,7 @@ public:
 
   virtual ~HttpSkipResponseCommand();
 
-  void installStreamFilter(const std::shared_ptr<StreamFilter>& streamFilter);
+  void installStreamFilter(std::unique_ptr<StreamFilter> streamFilter);
 
   void disableSocketCheck();
 };

+ 6 - 6
src/StreamFilter.cc

@@ -36,19 +36,19 @@
 
 namespace aria2 {
 
-StreamFilter::StreamFilter
-(const std::shared_ptr<StreamFilter>& delegate):
-  delegate_(delegate) {}
+StreamFilter::StreamFilter(std::unique_ptr<StreamFilter> delegate)
+  : delegate_(std::move(delegate))
+{}
 
 StreamFilter::~StreamFilter() {}
 
-bool StreamFilter::installDelegate(const std::shared_ptr<StreamFilter>& filter)
+bool StreamFilter::installDelegate(std::unique_ptr<StreamFilter> filter)
 {
   if(!delegate_) {
-    delegate_ = filter;
+    delegate_ = std::move(filter);
     return true;
   } else {
-    return delegate_->installDelegate(filter);
+    return delegate_->installDelegate(std::move(filter));
   }
 }
 

+ 4 - 4
src/StreamFilter.h

@@ -48,10 +48,10 @@ class Segment;
 // Interface for basic decoding functionality.
 class StreamFilter {
 private:
-  std::shared_ptr<StreamFilter> delegate_;
+  std::unique_ptr<StreamFilter> delegate_;
 public:
   StreamFilter
-  (const std::shared_ptr<StreamFilter>& delegate = std::shared_ptr<StreamFilter>());
+  (std::unique_ptr<StreamFilter> delegate = std::unique_ptr<StreamFilter>{});
 
   virtual ~StreamFilter();
 
@@ -75,9 +75,9 @@ public:
   // tranfrom() invocation.
   virtual size_t getBytesProcessed() const = 0;
 
-  virtual bool installDelegate(const std::shared_ptr<StreamFilter>& filter);
+  virtual bool installDelegate(std::unique_ptr<StreamFilter> filter);
 
-  std::shared_ptr<StreamFilter> getDelegate() const
+  const std::unique_ptr<StreamFilter>& getDelegate() const
   {
     return delegate_;
   }

+ 7 - 7
test/ChunkedDecodingStreamFilterTest.cc

@@ -9,6 +9,7 @@
 #include "ByteArrayDiskWriter.h"
 #include "SinkStreamFilter.h"
 #include "MockSegment.h"
+#include "a2functional.h"
 
 namespace aria2 {
 
@@ -24,8 +25,7 @@ class ChunkedDecodingStreamFilterTest:public CppUnit::TestFixture {
   CPPUNIT_TEST(testGetName);
   CPPUNIT_TEST_SUITE_END();
 
-  std::shared_ptr<ChunkedDecodingStreamFilter> filter_;
-  std::shared_ptr<SinkStreamFilter> sinkFilter_;
+  std::unique_ptr<ChunkedDecodingStreamFilter> filter_;
   std::shared_ptr<ByteArrayDiskWriter> writer_;
   std::shared_ptr<Segment> segment_;
 
@@ -36,12 +36,12 @@ class ChunkedDecodingStreamFilterTest:public CppUnit::TestFixture {
 public:
   void setUp()
   {
-    writer_.reset(new ByteArrayDiskWriter());
-    sinkFilter_.reset(new SinkStreamFilter());
-    filter_.reset(new ChunkedDecodingStreamFilter(sinkFilter_));
-    sinkFilter_->init();
+    writer_ = std::make_shared<ByteArrayDiskWriter>();
+    auto sinkFilter = make_unique<SinkStreamFilter>();
+    sinkFilter->init();
+    filter_ = make_unique<ChunkedDecodingStreamFilter>(std::move(sinkFilter));
     filter_->init();
-    segment_.reset(new MockSegment());
+    segment_ = std::make_shared<MockSegment>();
   }
 
   void testTransform();

+ 8 - 9
test/GZipDecodingStreamFilterTest.cc

@@ -30,30 +30,29 @@ class GZipDecodingStreamFilterTest:public CppUnit::TestFixture {
   public:
     MockSegment2():positionToWrite_(0) {}
 
-    virtual void updateWrittenLength(int32_t bytes)
+    virtual void updateWrittenLength(int32_t bytes) override
     {
       positionToWrite_ += bytes;
     }
 
-    virtual int64_t getPositionToWrite() const
+    virtual int64_t getPositionToWrite() const override
     {
       return positionToWrite_;
     }
   };
 
-  std::shared_ptr<GZipDecodingStreamFilter> filter_;
-  std::shared_ptr<SinkStreamFilter> sinkFilter_;
+  std::unique_ptr<GZipDecodingStreamFilter> filter_;
   std::shared_ptr<ByteArrayDiskWriter> writer_;
   std::shared_ptr<MockSegment2> segment_;
 public:
   void setUp()
   {
-    writer_.reset(new ByteArrayDiskWriter());
-    sinkFilter_.reset(new SinkStreamFilter());
-    filter_.reset(new GZipDecodingStreamFilter(sinkFilter_));
-    sinkFilter_->init();
+    writer_ = std::make_shared<ByteArrayDiskWriter>();
+    auto sinkFilter = make_unique<SinkStreamFilter>();
+    sinkFilter->init();
+    filter_ = make_unique<GZipDecodingStreamFilter>(std::move(sinkFilter));
     filter_->init();
-    segment_.reset(new MockSegment2());
+    segment_ = std::make_shared<MockSegment2>();
   }
 
   void testTransform();