Ver Fonte

Use SegList<int> instead of IntSequence in DownloadContext::setFileFilter()

Tatsuhiro Tsujikawa há 14 anos atrás
pai
commit
5749647ae5

+ 24 - 15
src/DownloadContext.cc

@@ -128,22 +128,31 @@ void DownloadContext::setFilePathWithIndex
   }
 }
 
-void DownloadContext::setFileFilter(IntSequence seq)
+void DownloadContext::setFileFilter(SegList<int>& sgl)
 {
-  std::vector<int32_t> fileIndexes = seq.flush();
-  std::sort(fileIndexes.begin(), fileIndexes.end());
-  fileIndexes.erase(std::unique(fileIndexes.begin(), fileIndexes.end()),
-                    fileIndexes.end());
-
-  bool selectAll = fileIndexes.empty() || fileEntries_.size() == 1;
-    
-  int32_t index = 1;
-  for(std::vector<SharedHandle<FileEntry> >::const_iterator i =
-        fileEntries_.begin(), eoi = fileEntries_.end();
-      i != eoi; ++i, ++index) {
-    (*i)->setRequested
-      (selectAll ||
-       std::binary_search(fileIndexes.begin(), fileIndexes.end(), index));
+  sgl.normalize();
+  if(!sgl.hasNext() || fileEntries_.size() == 1) {
+    std::for_each(fileEntries_.begin(), fileEntries_.end(),
+                  std::bind2nd(mem_fun_sh(&FileEntry::setRequested), true));
+    return;
+  }
+  assert(sgl.peek() >= 1);
+  size_t i = 0;
+  while(i < fileEntries_.size() && sgl.hasNext()) {
+    size_t idx = sgl.peek()-1;
+    if(i == idx) {
+      fileEntries_[i]->setRequested(true);
+      ++i;
+      sgl.next();
+    } else if(i < idx) {
+      fileEntries_[i]->setRequested(false);
+      ++i;
+    } else {
+      sgl.next();
+    }
+  }
+  for(; i < fileEntries_.size(); ++i) {
+    fileEntries_[i]->setRequested(false);
   }
 }
 

+ 2 - 2
src/DownloadContext.h

@@ -45,7 +45,7 @@
 #include "TimerA2.h"
 #include "A2STR.h"
 #include "ValueBase.h"
-#include "IntSequence.h"
+#include "SegList.h"
 
 namespace aria2 {
 
@@ -179,7 +179,7 @@ public:
     ownerRequestGroup_ = owner;
   }
 
-  void setFileFilter(IntSequence seq);
+  void setFileFilter(SegList<int>& sgl);
 
   // Sets file path for specified index. index starts from 1. The
   // index is the same used in setFileFilter(). path is not escaped by

+ 20 - 0
src/SegList.h

@@ -49,6 +49,13 @@ public:
     : index_(0), val_(std::numeric_limits<T>::min())
   {}
 
+  void clear()
+  {
+    seg_.clear();
+    index_ = 0;
+    val_ = std::numeric_limits<T>::min();
+  }
+
   // Transforms list of segments so that they are sorted ascending
   // order of starting value and intersecting and touching segments
   // are all merged into one. This function resets current position.
@@ -107,6 +114,19 @@ public:
     }
     return res;
   }
+
+  // Returns next value. Current position is not advanced.  If
+  // this fuction is called when hasNext() returns false, returns 0.
+  T peek() const
+  {
+    T res;
+    if(index_ < seg_.size()) {
+      res = val_;
+    } else {
+      res = 0;
+    }
+    return res;
+  }
 private:
   std::vector<std::pair<T, T> > seg_;
   size_t index_;

+ 4 - 1
src/download_helper.cc

@@ -62,6 +62,7 @@
 #include "ByteArrayDiskWriterFactory.h"
 #include "MetadataInfo.h"
 #include "OptionParser.h"
+#include "SegList.h"
 #ifdef ENABLE_BITTORRENT
 # include "bittorrent_helper.h"
 # include "BtConstants.h"
@@ -183,7 +184,9 @@ createBtRequestGroup(const std::string& torrentFilePath,
   if(adjustAnnounceUri) {
     bittorrent::adjustAnnounceUri(bittorrent::getTorrentAttrs(dctx), option);
   }
-  dctx->setFileFilter(util::parseIntRange(option->get(PREF_SELECT_FILE)));
+  SegList<int> sgl;
+  util::parseIntSegments(sgl, option->get(PREF_SELECT_FILE));
+  dctx->setFileFilter(sgl);
   std::istringstream indexOutIn(option->get(PREF_INDEX_OUT));
   std::map<size_t, std::string> indexPathMap =
     util::createIndexPathMap(indexOutIn);

+ 32 - 0
src/util.cc

@@ -774,6 +774,38 @@ IntSequence parseIntRange(const std::string& src)
   return values;
 }
 
+void parseIntSegments(SegList<int>& sgl, const std::string& src)
+{
+  for(std::string::const_iterator i = src.begin(), eoi = src.end(); i != eoi;) {
+    std::string::const_iterator j = i;
+    while(j != eoi && *j != ',') {
+      ++j;
+    }
+    if(j == i) {
+      ++i;
+      continue;
+    }
+    std::string::const_iterator p = i;
+    while(p != j && *p != '-') {
+      ++p;
+    }
+    if(p == j) {
+      int a = parseInt(std::string(i, j));
+      sgl.add(a, a+1);
+    } else if(p == i || p+1 == j) {
+      throw DL_ABORT_EX(fmt(MSG_INCOMPLETE_RANGE, std::string(i, j).c_str()));
+    } else {
+      int a = parseInt(std::string(i, p));
+      int b = parseInt(std::string(p+1, j));
+      sgl.add(a, b+1);
+    }
+    if(j == eoi) {
+      break;
+    }
+    i = j+1;
+  }
+}
+
 namespace {
 void computeHeadPieces
 (std::vector<size_t>& indexes,

+ 3 - 0
src/util.h

@@ -55,6 +55,7 @@
 #include "a2time.h"
 #include "a2netcompat.h"
 #include "a2functional.h"
+#include "SegList.h"
 
 namespace aria2 {
 
@@ -220,6 +221,8 @@ uint64_t parseULLInt(const std::string& s, int base = 10);
 
 IntSequence parseIntRange(const std::string& src);
 
+void parseIntSegments(SegList<int>& sgl, const std::string& src);
+
 // Parses string which specifies the range of piece index for higher
 // priority and appends those indexes into result.  The input string
 // src can contain 2 keywords "head" and "tail".  To include both

+ 18 - 8
test/BittorrentHelperTest.cc

@@ -645,16 +645,20 @@ void BittorrentHelperTest::testSetFileFilter_single()
   load(A2_TEST_DIR"/single.torrent", dctx, option_);
 
   CPPUNIT_ASSERT(dctx->getFirstFileEntry()->isRequested());
-
-  dctx->setFileFilter(util::parseIntRange(""));
+  SegList<int> sgl;
+  dctx->setFileFilter(sgl);
   CPPUNIT_ASSERT(dctx->getFirstFileEntry()->isRequested());
 
-  dctx->setFileFilter(util::parseIntRange("1"));
+  sgl.clear();
+  sgl.add(1, 2);
+  dctx->setFileFilter(sgl);
   CPPUNIT_ASSERT(dctx->getFirstFileEntry()->isRequested());
 
   // For single file torrent, file is always selected whatever range
   // is passed.
-  dctx->setFileFilter(util::parseIntRange("2"));
+  sgl.clear();
+  sgl.add(2, 3);
+  dctx->setFileFilter(sgl);
   CPPUNIT_ASSERT(dctx->getFirstFileEntry()->isRequested());
 }
 
@@ -666,19 +670,25 @@ void BittorrentHelperTest::testSetFileFilter_multi()
   CPPUNIT_ASSERT(dctx->getFileEntries()[0]->isRequested());
   CPPUNIT_ASSERT(dctx->getFileEntries()[1]->isRequested());
 
-  dctx->setFileFilter(util::parseIntRange(""));
+  SegList<int> sgl;
+  dctx->setFileFilter(sgl);
   CPPUNIT_ASSERT(dctx->getFileEntries()[0]->isRequested());
   CPPUNIT_ASSERT(dctx->getFileEntries()[1]->isRequested());
 
-  dctx->setFileFilter(util::parseIntRange("2"));
+  sgl.add(2, 3);
+  dctx->setFileFilter(sgl);
   CPPUNIT_ASSERT(!dctx->getFileEntries()[0]->isRequested());
   CPPUNIT_ASSERT(dctx->getFileEntries()[1]->isRequested());
 
-  dctx->setFileFilter(util::parseIntRange("3"));
+  sgl.clear();
+  sgl.add(3, 4);
+  dctx->setFileFilter(sgl);
   CPPUNIT_ASSERT(!dctx->getFileEntries()[0]->isRequested());
   CPPUNIT_ASSERT(!dctx->getFileEntries()[1]->isRequested());
 
-  dctx->setFileFilter(util::parseIntRange("1,2"));
+  sgl.clear();
+  util::parseIntSegments(sgl, "1,2");
+  dctx->setFileFilter(sgl);
   CPPUNIT_ASSERT(dctx->getFileEntries()[0]->isRequested());
   CPPUNIT_ASSERT(dctx->getFileEntries()[1]->isRequested());
 }

+ 26 - 0
test/DownloadContextTest.cc

@@ -14,12 +14,14 @@ class DownloadContextTest:public CppUnit::TestFixture {
   CPPUNIT_TEST(testGetPieceHash);
   CPPUNIT_TEST(testGetNumPieces);
   CPPUNIT_TEST(testGetBasePath);
+  CPPUNIT_TEST(testSetFileFilter);
   CPPUNIT_TEST_SUITE_END();
 public:
   void testFindFileEntryByOffset();
   void testGetPieceHash();
   void testGetNumPieces();
   void testGetBasePath();
+  void testSetFileFilter();
 };
 
 
@@ -73,4 +75,28 @@ void DownloadContextTest::testGetBasePath()
   CPPUNIT_ASSERT_EQUAL(std::string("aria2.tar.bz2"), ctx.getBasePath());
 }
 
+void DownloadContextTest::testSetFileFilter()
+{
+  DownloadContext ctx;
+  std::vector<SharedHandle<FileEntry> > files;
+  for(int i = 0; i < 10; ++i) {
+    files.push_back(SharedHandle<FileEntry>(new FileEntry("file", 1, i)));
+  }
+  ctx.setFileEntries(files.begin(), files.end());
+  SegList<int> sgl;
+  util::parseIntSegments(sgl, "2-4,6-8");
+  ctx.setFileFilter(sgl);
+  const std::vector<SharedHandle<FileEntry> >& res = ctx.getFileEntries();
+  CPPUNIT_ASSERT(!res[0]->isRequested());
+  CPPUNIT_ASSERT(res[1]->isRequested());
+  CPPUNIT_ASSERT(res[2]->isRequested());
+  CPPUNIT_ASSERT(res[3]->isRequested());
+  CPPUNIT_ASSERT(!res[4]->isRequested());
+  CPPUNIT_ASSERT(res[5]->isRequested());
+  CPPUNIT_ASSERT(res[6]->isRequested());
+  CPPUNIT_ASSERT(res[7]->isRequested());
+  CPPUNIT_ASSERT(!res[8]->isRequested());
+  CPPUNIT_ASSERT(!res[9]->isRequested());
+}
+
 } // namespace aria2

+ 30 - 0
test/SegListTest.cc

@@ -8,10 +8,14 @@ class SegListTest:public CppUnit::TestFixture {
 
   CPPUNIT_TEST_SUITE(SegListTest);
   CPPUNIT_TEST(testNext);
+  CPPUNIT_TEST(testPeek);
+  CPPUNIT_TEST(testClear);
   CPPUNIT_TEST(testNormalize);
   CPPUNIT_TEST_SUITE_END();
 public:
   void testNext();
+  void testPeek();
+  void testClear();
   void testNormalize();
 };
 
@@ -40,6 +44,32 @@ void SegListTest::testNext()
   CPPUNIT_ASSERT_EQUAL(0, sgl.next());
 }
 
+void SegListTest::testPeek()
+{
+  SegList<int> sgl;
+  sgl.add(1, 3);
+  sgl.add(4, 5);
+  CPPUNIT_ASSERT_EQUAL(1, sgl.peek());
+  CPPUNIT_ASSERT_EQUAL(1, sgl.peek());
+  CPPUNIT_ASSERT_EQUAL(1, sgl.next());
+  CPPUNIT_ASSERT_EQUAL(2, sgl.peek());
+  CPPUNIT_ASSERT_EQUAL(2, sgl.next());
+  CPPUNIT_ASSERT_EQUAL(4, sgl.peek());
+  CPPUNIT_ASSERT_EQUAL(4, sgl.next());
+  CPPUNIT_ASSERT(!sgl.hasNext());
+}
+
+void SegListTest::testClear()
+{
+  SegList<int> sgl;
+  sgl.add(1, 3);
+  CPPUNIT_ASSERT_EQUAL(1, sgl.next());
+  sgl.clear();
+  CPPUNIT_ASSERT(!sgl.hasNext());
+  sgl.add(2, 3);
+  CPPUNIT_ASSERT_EQUAL(2, sgl.next());
+}
+
 void SegListTest::testNormalize()
 {
   SegList<int> sgl;

+ 75 - 0
test/UtilTest.cc

@@ -47,6 +47,8 @@ class UtilTest:public CppUnit::TestFixture {
   CPPUNIT_TEST(testConvertBitfield);
   CPPUNIT_TEST(testParseIntRange);
   CPPUNIT_TEST(testParseIntRange_invalidRange);
+  CPPUNIT_TEST(testParseIntSegments);
+  CPPUNIT_TEST(testParseIntSegments_invalidRange);
   CPPUNIT_TEST(testParseInt);
   CPPUNIT_TEST(testParseUInt);
   CPPUNIT_TEST(testParseLLInt);
@@ -107,6 +109,8 @@ public:
   void testConvertBitfield();
   void testParseIntRange();
   void testParseIntRange_invalidRange();
+  void testParseIntSegments();
+  void testParseIntSegments_invalidRange();
   void testParseInt();
   void testParseUInt();
   void testParseLLInt();
@@ -742,6 +746,77 @@ void UtilTest::testParseIntRange_invalidRange()
   }
 }
 
+void UtilTest::testParseIntSegments()
+{
+  SegList<int> sgl;
+  util::parseIntSegments(sgl, "1,3-8,10");
+
+  CPPUNIT_ASSERT(sgl.hasNext());
+  CPPUNIT_ASSERT_EQUAL(1, sgl.next());
+  CPPUNIT_ASSERT(sgl.hasNext());
+  CPPUNIT_ASSERT_EQUAL(3, sgl.next());
+  CPPUNIT_ASSERT(sgl.hasNext());
+  CPPUNIT_ASSERT_EQUAL(4, sgl.next());
+  CPPUNIT_ASSERT(sgl.hasNext());
+  CPPUNIT_ASSERT_EQUAL(5, sgl.next());
+  CPPUNIT_ASSERT(sgl.hasNext());
+  CPPUNIT_ASSERT_EQUAL(6, sgl.next());
+  CPPUNIT_ASSERT(sgl.hasNext());
+  CPPUNIT_ASSERT_EQUAL(7, sgl.next());
+  CPPUNIT_ASSERT(sgl.hasNext());
+  CPPUNIT_ASSERT_EQUAL(8, sgl.next());
+  CPPUNIT_ASSERT(sgl.hasNext());
+  CPPUNIT_ASSERT_EQUAL(10, sgl.next());
+  CPPUNIT_ASSERT(!sgl.hasNext());
+  CPPUNIT_ASSERT_EQUAL(0, sgl.next());
+
+  sgl.clear();
+  util::parseIntSegments(sgl, ",,,1,,,3,,,");
+  CPPUNIT_ASSERT_EQUAL(1, sgl.next());
+  CPPUNIT_ASSERT_EQUAL(3, sgl.next());
+  CPPUNIT_ASSERT(!sgl.hasNext());
+}
+
+void UtilTest::testParseIntSegments_invalidRange()
+{
+  try {
+    SegList<int> sgl;
+    util::parseIntSegments(sgl, "-1");
+    CPPUNIT_FAIL("exception must be thrown.");
+  } catch(Exception& e) {
+  }
+  try {
+    SegList<int> sgl;
+    util::parseIntSegments(sgl, "1-");
+    CPPUNIT_FAIL("exception must be thrown.");
+  } catch(Exception& e) {
+  }
+  try {
+    SegList<int> sgl;
+    util::parseIntSegments(sgl, "2147483648");
+    CPPUNIT_FAIL("exception must be thrown.");
+  } catch(Exception& e) {
+  }
+  try {
+    SegList<int> sgl;
+    util::parseIntSegments(sgl, "2147483647-2147483648");
+    CPPUNIT_FAIL("exception must be thrown.");
+  } catch(Exception& e) {
+  }
+  try {
+    SegList<int> sgl;
+    util::parseIntSegments(sgl, "1-2x");
+    CPPUNIT_FAIL("exception must be thrown.");
+  } catch(Exception& e) {
+  }
+  try {
+    SegList<int> sgl;
+    util::parseIntSegments(sgl, "3x-4");
+    CPPUNIT_FAIL("exception must be thrown.");
+  } catch(Exception& e) {
+  }
+}
+
 void UtilTest::testParseInt()
 {
   CPPUNIT_ASSERT_EQUAL(-1, util::parseInt(" -1 "));