1 /* <!-- copyright */
2 /*
3 * aria2 - The high speed download utility
4 *
5 * Copyright (C) 2006 Tatsuhiro Tsujikawa
6 *
7 * This program is free software; you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation; either version 2 of the License, or
10 * (at your option) any later version.
11 *
12 * This program is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with this program; if not, write to the Free Software
19 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
20 *
21 * In addition, as a special exception, the copyright holders give
22 * permission to link the code of portions of this program with the
23 * OpenSSL library under certain conditions as described in each
24 * individual source file, and distribute linked combinations
25 * including the two.
26 * You must obey the GNU General Public License in all respects
27 * for all of the code used other than OpenSSL. If you modify
28 * file(s) with this exception, you may extend this exception to your
29 * version of the file(s), but you are not obligated to do so. If you
30 * do not wish to do so, delete this exception statement from your
31 * version. If you delete this exception statement from all source
32 * files in the program, then also delete it here.
33 */
34 /* copyright --> */
35 #include "DHTMessageReceiver.h"
36
37 #include <cstring>
38 #include <utility>
39
40 #include "DHTMessageTracker.h"
41 #include "DHTMessage.h"
42 #include "DHTQueryMessage.h"
43 #include "DHTResponseMessage.h"
44 #include "DHTUnknownMessage.h"
45 #include "DHTMessageFactory.h"
46 #include "DHTRoutingTable.h"
47 #include "DHTNode.h"
48 #include "DHTMessageCallback.h"
49 #include "DlAbortEx.h"
50 #include "LogFactory.h"
51 #include "Logger.h"
52 #include "util.h"
53 #include "bencode2.h"
54 #include "fmt.h"
55
56 namespace aria2 {
57
DHTMessageReceiver(const std::shared_ptr<DHTMessageTracker> & tracker)58 DHTMessageReceiver::DHTMessageReceiver(
59 const std::shared_ptr<DHTMessageTracker>& tracker)
60 : tracker_{tracker}, factory_{nullptr}, routingTable_{nullptr}
61 {
62 }
63
64 std::unique_ptr<DHTMessage>
receiveMessage(const std::string & remoteAddr,uint16_t remotePort,unsigned char * data,size_t length)65 DHTMessageReceiver::receiveMessage(const std::string& remoteAddr,
66 uint16_t remotePort, unsigned char* data,
67 size_t length)
68 {
69 try {
70 bool isReply = false;
71 auto decoded = bencode2::decode(data, length);
72 const Dict* dict = downcast<Dict>(decoded);
73 if (dict) {
74 const String* y = downcast<String>(dict->get(DHTMessage::Y));
75 if (y) {
76 if (y->s() == DHTResponseMessage::R || y->s() == DHTUnknownMessage::E) {
77 isReply = true;
78 }
79 }
80 else {
81 A2_LOG_INFO(fmt("Malformed DHT message. Missing 'y' key. From:%s:%u",
82 remoteAddr.c_str(), remotePort));
83 return handleUnknownMessage(data, length, remoteAddr, remotePort);
84 }
85 }
86 else {
87 A2_LOG_INFO(fmt("Malformed DHT message. This is not a bencoded directory."
88 " From:%s:%u",
89 remoteAddr.c_str(), remotePort));
90 return handleUnknownMessage(data, length, remoteAddr, remotePort);
91 }
92 if (isReply) {
93 auto p = tracker_->messageArrived(dict, remoteAddr, remotePort);
94 if (!p.first) {
95 // timeout or malicious? message
96 return handleUnknownMessage(data, length, remoteAddr, remotePort);
97 }
98 onMessageReceived(p.first.get());
99 if (p.second) {
100 p.second->onReceived(p.first.get());
101 }
102 return std::move(p.first);
103 }
104 else {
105 auto message = factory_->createQueryMessage(dict, remoteAddr, remotePort);
106 if (*message->getLocalNode() == *message->getRemoteNode()) {
107 // drop message from localnode
108 A2_LOG_INFO("Received DHT message from localnode.");
109 return handleUnknownMessage(data, length, remoteAddr, remotePort);
110 }
111 onMessageReceived(message.get());
112 return std::move(message);
113 }
114 }
115 catch (RecoverableException& e) {
116 A2_LOG_INFO_EX("Exception thrown while receiving DHT message.", e);
117 return handleUnknownMessage(data, length, remoteAddr, remotePort);
118 }
119 }
120
onMessageReceived(DHTMessage * message)121 void DHTMessageReceiver::onMessageReceived(DHTMessage* message)
122 {
123 A2_LOG_INFO(fmt("Message received: %s", message->toString().c_str()));
124 message->validate();
125 message->doReceivedAction();
126 message->getRemoteNode()->markGood();
127 message->getRemoteNode()->updateLastContact();
128 routingTable_->addGoodNode(message->getRemoteNode());
129 }
130
handleTimeout()131 void DHTMessageReceiver::handleTimeout() { tracker_->handleTimeout(); }
132
handleUnknownMessage(const unsigned char * data,size_t length,const std::string & remoteAddr,uint16_t remotePort)133 std::unique_ptr<DHTUnknownMessage> DHTMessageReceiver::handleUnknownMessage(
134 const unsigned char* data, size_t length, const std::string& remoteAddr,
135 uint16_t remotePort)
136 {
137 auto m = factory_->createUnknownMessage(data, length, remoteAddr, remotePort);
138 A2_LOG_INFO(fmt("Message received: %s", m->toString().c_str()));
139 return m;
140 }
141
setMessageFactory(DHTMessageFactory * factory)142 void DHTMessageReceiver::setMessageFactory(DHTMessageFactory* factory)
143 {
144 factory_ = factory;
145 }
146
setRoutingTable(DHTRoutingTable * routingTable)147 void DHTMessageReceiver::setRoutingTable(DHTRoutingTable* routingTable)
148 {
149 routingTable_ = routingTable;
150 }
151
152 } // namespace aria2
153