/* */ #ifndef D_DHT_ABSTRACT_NODE_LOOKUP_TASK_H #define D_DHT_ABSTRACT_NODE_LOOKUP_TASK_H #include "DHTAbstractTask.h" #include #include #include #include #include "DHTConstants.h" #include "DHTNodeLookupEntry.h" #include "DHTRoutingTable.h" #include "DHTMessageDispatcher.h" #include "DHTMessageFactory.h" #include "DHTMessage.h" #include "DHTNode.h" #include "DHTBucket.h" #include "LogFactory.h" #include "Logger.h" #include "util.h" #include "DHTIDCloser.h" #include "a2functional.h" #include "fmt.h" namespace aria2 { class DHTNode; class DHTMessage; template class DHTAbstractNodeLookupTask : public DHTAbstractTask { private: unsigned char targetID_[DHT_ID_LENGTH]; std::deque> entries_; size_t inFlightMessage_; template void toEntries(Container& entries, const std::vector>& nodes) const { for (auto& node : nodes) { entries.push_back(make_unique(node)); } } void sendMessage() { for (auto i = std::begin(entries_), eoi = std::end(entries_); i != eoi && inFlightMessage_ < ALPHA; ++i) { if ((*i)->used == false) { ++inFlightMessage_; (*i)->used = true; getMessageDispatcher()->addMessageToQueue(createMessage((*i)->node), createCallback()); } } } void sendMessageAndCheckFinish() { if (needsAdditionalOutgoingMessage()) { sendMessage(); } if (inFlightMessage_ == 0) { A2_LOG_DEBUG(fmt("Finished node_lookup for node ID %s", util::toHex(targetID_, DHT_ID_LENGTH).c_str())); onFinish(); updateBucket(); setFinished(true); } else { A2_LOG_DEBUG(fmt("%lu in flight message for node ID %s", static_cast(inFlightMessage_), util::toHex(targetID_, DHT_ID_LENGTH).c_str())); } } void updateBucket() {} protected: const unsigned char* getTargetID() const { return targetID_; } const std::deque>& getEntries() const { return entries_; } virtual void getNodesFromMessage(std::vector>& nodes, const ResponseMessage* message) = 0; virtual void onReceivedInternal(const ResponseMessage* message) {} virtual bool needsAdditionalOutgoingMessage() { return true; } virtual void onFinish() {} virtual std::unique_ptr createMessage(const std::shared_ptr& remoteNode) = 0; virtual std::unique_ptr createCallback() = 0; public: DHTAbstractNodeLookupTask(const unsigned char* targetID) : inFlightMessage_(0) { memcpy(targetID_, targetID, DHT_ID_LENGTH); } static const size_t ALPHA = 3; virtual void startup() CXX11_OVERRIDE { std::vector> nodes; getRoutingTable()->getClosestKNodes(nodes, targetID_); entries_.clear(); toEntries(entries_, nodes); if (entries_.empty()) { setFinished(true); } else { // TODO use RTT here inFlightMessage_ = 0; sendMessage(); if (inFlightMessage_ == 0) { A2_LOG_DEBUG("No message was sent in this lookup stage. Finished."); setFinished(true); } } } void onReceived(const ResponseMessage* message) { --inFlightMessage_; // Replace old Node ID with new Node ID. for (auto& entry : entries_) { if (entry->node->getIPAddress() == message->getRemoteNode()->getIPAddress() && entry->node->getPort() == message->getRemoteNode()->getPort()) { entry->node = message->getRemoteNode(); } } onReceivedInternal(message); std::vector> nodes; getNodesFromMessage(nodes, message); std::vector> newEntries; toEntries(newEntries, nodes); size_t count = 0; for (auto& ne : newEntries) { if (memcmp(getLocalNode()->getID(), ne->node->getID(), DHT_ID_LENGTH) != 0) { A2_LOG_DEBUG(fmt("Received nodes: id=%s, ip=%s", util::toHex(ne->node->getID(), DHT_ID_LENGTH).c_str(), ne->node->getIPAddress().c_str())); entries_.push_front(std::move(ne)); ++count; } } A2_LOG_DEBUG(fmt("%lu node lookup entries added.", static_cast(count))); std::stable_sort(std::begin(entries_), std::end(entries_), DHTIDCloser(targetID_)); entries_.erase( std::unique(std::begin(entries_), std::end(entries_), DerefEqualTo>{}), std::end(entries_)); A2_LOG_DEBUG(fmt("%lu node lookup entries are unique.", static_cast(entries_.size()))); if (entries_.size() > DHTBucket::K) { entries_.erase(std::begin(entries_) + DHTBucket::K, std::end(entries_)); } sendMessageAndCheckFinish(); } void onTimeout(const std::shared_ptr& node) { A2_LOG_DEBUG(fmt("node lookup message timeout for node ID=%s", util::toHex(node->getID(), DHT_ID_LENGTH).c_str())); --inFlightMessage_; for (auto i = std::begin(entries_), eoi = std::end(entries_); i != eoi; ++i) { if (*(*i)->node == *node) { entries_.erase(i); break; } } sendMessageAndCheckFinish(); } }; } // namespace aria2 #endif // D_DHT_ABSTRACT_NODE_LOOKUP_TASK_H