/* */
#include "DHTBucketTree.h"

#include <algorithm>
#include <cstring>
#include <memory>
#include <utility>
#include <vector>

#include "DHTBucket.h"
#include "DHTNode.h"
#include "a2functional.h"

namespace aria2 {

// Constructs an internal node that takes ownership of the two subtrees.
// resetRelation() makes this node their parent and derives this node's
// ID range from left's minimum ID and right's maximum ID.
DHTBucketTreeNode::DHTBucketTreeNode(std::unique_ptr<DHTBucketTreeNode> left,
                                     std::unique_ptr<DHTBucketTreeNode> right)
    : parent_(nullptr), left_(std::move(left)), right_(std::move(right))
{
  resetRelation();
}

// Constructs a leaf node holding bucket. The node's ID range is copied
// from the bucket's min/max IDs; bucket is dereferenced here and so must
// not be null.
DHTBucketTreeNode::DHTBucketTreeNode(std::shared_ptr<DHTBucket> bucket)
    : parent_(nullptr), bucket_(std::move(bucket))
{
  std::memcpy(minId_, bucket_->getMinID(), DHT_ID_LENGTH);
  std::memcpy(maxId_, bucket_->getMaxID(), DHT_ID_LENGTH);
}

DHTBucketTreeNode::~DHTBucketTreeNode() = default;

// Re-parents both children to this node and recomputes this node's ID
// range as [left_->getMinId(), right_->getMaxId()]. Both children must
// be non-null.
void DHTBucketTreeNode::resetRelation()
{
  left_->setParent(this);
  right_->setParent(this);
  std::memcpy(minId_, left_->getMinId(), DHT_ID_LENGTH);
  std::memcpy(maxId_, right_->getMaxId(), DHT_ID_LENGTH);
}

// Returns the child whose ID range contains key, or nullptr when this
// node is a leaf. Only the left child's range is tested; any key outside
// it is routed to the right child.
DHTBucketTreeNode* DHTBucketTreeNode::dig(const unsigned char* key)
{
  if (leaf()) {
    return nullptr;
  }
  if (left_->isInRange(key)) {
    return left_.get();
  }
  return right_.get();
}

// Returns true if minId_ <= key <= maxId_ (both bounds inclusive), where
// IDs are DHT_ID_LENGTH-byte arrays compared lexicographically as
// unsigned bytes.
bool DHTBucketTreeNode::isInRange(const unsigned char* key) const
{
  return !std::lexicographical_compare(&key[0], &key[DHT_ID_LENGTH],
                                       &minId_[0], &minId_[DHT_ID_LENGTH]) &&
         !std::lexicographical_compare(&maxId_[0], &maxId_[DHT_ID_LENGTH],
                                       &key[0], &key[DHT_ID_LENGTH]);
}

// Turns this leaf into an internal node: the bucket returned by
// DHTBucket::split() becomes the left child, the original bucket becomes
// the right child, and this node releases its own bucket reference.
void DHTBucketTreeNode::split()
{
  left_ = make_unique<DHTBucketTreeNode>(bucket_->split());
  right_ = make_unique<DHTBucketTreeNode>(bucket_);
  bucket_.reset();
  resetRelation();
}

namespace dht {

// Descends from root via dig() until reaching the leaf whose range
// contains key, and returns that leaf.
DHTBucketTreeNode* findTreeNodeFor(DHTBucketTreeNode* root,
                                   const unsigned char* key)
{
  DHTBucketTreeNode* node = root;
  while (!node->leaf()) {
    node = node->dig(key);
  }
  return node;
}

// Returns the bucket held by the leaf that findTreeNodeFor() selects for
// key.
std::shared_ptr<DHTBucket> findBucketFor(DHTBucketTreeNode* root,
                                         const unsigned char* key)
{
  DHTBucketTreeNode* leaf = findTreeNodeFor(root, key);
  return leaf->getBucket();
}

namespace {
// Appends the good nodes reported by bucket->getGoodNodes() to nodes.
// A temporary vector is used so that existing entries in nodes are left
// untouched regardless of what getGoodNodes does to its argument.
void collectNodes(std::vector<std::shared_ptr<DHTNode>>& nodes,
                  const std::shared_ptr<DHTBucket>& bucket)
{
  std::vector<std::shared_ptr<DHTNode>> goodNodes;
  bucket->getGoodNodes(goodNodes);
  nodes.insert(nodes.end(), goodNodes.begin(), goodNodes.end());
}
} // namespace

namespace {
// Collects nodes from every leaf under tnode, visiting the left subtree
// first. The right subtree is skipped once nodes holds DHTBucket::K or
// more entries.
void collectDownwardLeftFirst(std::vector<std::shared_ptr<DHTNode>>& nodes,
                              DHTBucketTreeNode* tnode)
{
  if (tnode->leaf()) {
    collectNodes(nodes, tnode->getBucket());
    return;
  }
  collectDownwardLeftFirst(nodes, tnode->getLeft());
  if (nodes.size() < DHTBucket::K) {
    collectDownwardLeftFirst(nodes, tnode->getRight());
  }
}
} // namespace

namespace {
// Mirror of collectDownwardLeftFirst: visits the right subtree first and
// skips the left subtree once nodes holds DHTBucket::K or more entries.
void collectDownwardRightFirst(std::vector<std::shared_ptr<DHTNode>>& nodes,
                               DHTBucketTreeNode* tnode)
{
  if (tnode->leaf()) {
    collectNodes(nodes, tnode->getBucket());
    return;
  }
  collectDownwardRightFirst(nodes, tnode->getRight());
  if (nodes.size() < DHTBucket::K) {
    collectDownwardRightFirst(nodes, tnode->getLeft());
  }
}
} // namespace

namespace {
// Walks from `from` toward the root. At each ancestor it collects nodes
// from the sibling subtree (the child it did not arrive from), stopping
// at the root or once nodes holds DHTBucket::K or more entries.
// The sibling is descended starting from its child nearer to `from`.
// For a leaf sibling this is exactly collectNodes(sibling->getBucket()).
// An internal sibling is walked instead of dereferencing its null
// bucket.
void collectUpward(std::vector<std::shared_ptr<DHTNode>>& nodes,
                   DHTBucketTreeNode* from)
{
  while (true) {
    DHTBucketTreeNode* parent = from->getParent();
    if (!parent) {
      break;
    }
    if (parent->getLeft() == from) {
      collectDownwardLeftFirst(nodes, parent->getRight());
    }
    else {
      collectDownwardRightFirst(nodes, parent->getLeft());
    }
    from = parent;
    if (DHTBucket::K <= nodes.size()) {
      break;
    }
  }
}
} // namespace

// Gathers up to DHTBucket::K good nodes for key into nodes. Collection
// proceeds in this order:
//   1. the leaf whose range contains key;
//   2. the rest of that leaf's parent subtree;
//   3. the sibling subtrees of successive ancestors.
// Entries beyond the first DHTBucket::K are then erased. Returns
// immediately if nodes already holds DHTBucket::K or more entries. Nodes
// are not sorted by distance to key here.
void findClosestKNodes(std::vector<std::shared_ptr<DHTNode>>& nodes,
                       DHTBucketTreeNode* root, const unsigned char* key)
{
  if (DHTBucket::K <= nodes.size()) {
    return;
  }
  DHTBucketTreeNode* leaf = findTreeNodeFor(root, key);
  if (leaf == root) {
    collectNodes(nodes, leaf->getBucket());
  }
  else {
    DHTBucketTreeNode* parent = leaf->getParent();
    // Start with the side that contains key so its bucket is taken first.
    if (parent->getLeft() == leaf) {
      collectDownwardLeftFirst(nodes, parent);
    }
    else {
      collectDownwardRightFirst(nodes, parent);
    }
    if (nodes.size() < DHTBucket::K) {
      collectUpward(nodes, parent);
    }
  }
  if (DHTBucket::K < nodes.size()) {
    nodes.erase(nodes.begin() + DHTBucket::K, nodes.end());
  }
}

// Appends every leaf bucket under root to buckets, in left-to-right
// (in-order) leaf order.
void enumerateBucket(std::vector<std::shared_ptr<DHTBucket>>& buckets,
                     DHTBucketTreeNode* root)
{
  if (root->leaf()) {
    buckets.push_back(root->getBucket());
  }
  else {
    enumerateBucket(buckets, root->getLeft());
    enumerateBucket(buckets, root->getRight());
  }
}

} // namespace dht

} // namespace aria2