1 /* <!-- copyright */
2 /*
3  * aria2 - The high speed download utility
4  *
5  * Copyright (C) 2006 Tatsuhiro Tsujikawa
6  *
7  * This program is free software; you can redistribute it and/or modify
8  * it under the terms of the GNU General Public License as published by
9  * the Free Software Foundation; either version 2 of the License, or
10  * (at your option) any later version.
11  *
12  * This program is distributed in the hope that it will be useful,
13  * but WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15  * GNU General Public License for more details.
16  *
17  * You should have received a copy of the GNU General Public License
18  * along with this program; if not, write to the Free Software
19  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
20  *
21  * In addition, as a special exception, the copyright holders give
22  * permission to link the code of portions of this program with the
23  * OpenSSL library under certain conditions as described in each
24  * individual source file, and distribute linked combinations
25  * including the two.
26  * You must obey the GNU General Public License in all respects
27  * for all of the code used other than OpenSSL.  If you modify
28  * file(s) with this exception, you may extend this exception to your
29  * version of the file(s), but you are not obligated to do so.  If you
30  * do not wish to do so, delete this exception statement from your
31  * version.  If you delete this exception statement from all source
32  * files in the program, then also delete it here.
33  */
34 /* copyright --> */
35 #include "DHTMessageReceiver.h"
36 
37 #include <cstring>
38 #include <utility>
39 
40 #include "DHTMessageTracker.h"
41 #include "DHTMessage.h"
42 #include "DHTQueryMessage.h"
43 #include "DHTResponseMessage.h"
44 #include "DHTUnknownMessage.h"
45 #include "DHTMessageFactory.h"
46 #include "DHTRoutingTable.h"
47 #include "DHTNode.h"
48 #include "DHTMessageCallback.h"
49 #include "DlAbortEx.h"
50 #include "LogFactory.h"
51 #include "Logger.h"
52 #include "util.h"
53 #include "bencode2.h"
54 #include "fmt.h"
55 
56 namespace aria2 {
57 
DHTMessageReceiver(const std::shared_ptr<DHTMessageTracker> & tracker)58 DHTMessageReceiver::DHTMessageReceiver(
59     const std::shared_ptr<DHTMessageTracker>& tracker)
60     : tracker_{tracker}, factory_{nullptr}, routingTable_{nullptr}
61 {
62 }
63 
64 std::unique_ptr<DHTMessage>
receiveMessage(const std::string & remoteAddr,uint16_t remotePort,unsigned char * data,size_t length)65 DHTMessageReceiver::receiveMessage(const std::string& remoteAddr,
66                                    uint16_t remotePort, unsigned char* data,
67                                    size_t length)
68 {
69   try {
70     bool isReply = false;
71     auto decoded = bencode2::decode(data, length);
72     const Dict* dict = downcast<Dict>(decoded);
73     if (dict) {
74       const String* y = downcast<String>(dict->get(DHTMessage::Y));
75       if (y) {
76         if (y->s() == DHTResponseMessage::R || y->s() == DHTUnknownMessage::E) {
77           isReply = true;
78         }
79       }
80       else {
81         A2_LOG_INFO(fmt("Malformed DHT message. Missing 'y' key. From:%s:%u",
82                         remoteAddr.c_str(), remotePort));
83         return handleUnknownMessage(data, length, remoteAddr, remotePort);
84       }
85     }
86     else {
87       A2_LOG_INFO(fmt("Malformed DHT message. This is not a bencoded directory."
88                       " From:%s:%u",
89                       remoteAddr.c_str(), remotePort));
90       return handleUnknownMessage(data, length, remoteAddr, remotePort);
91     }
92     if (isReply) {
93       auto p = tracker_->messageArrived(dict, remoteAddr, remotePort);
94       if (!p.first) {
95         // timeout or malicious? message
96         return handleUnknownMessage(data, length, remoteAddr, remotePort);
97       }
98       onMessageReceived(p.first.get());
99       if (p.second) {
100         p.second->onReceived(p.first.get());
101       }
102       return std::move(p.first);
103     }
104     else {
105       auto message = factory_->createQueryMessage(dict, remoteAddr, remotePort);
106       if (*message->getLocalNode() == *message->getRemoteNode()) {
107         // drop message from localnode
108         A2_LOG_INFO("Received DHT message from localnode.");
109         return handleUnknownMessage(data, length, remoteAddr, remotePort);
110       }
111       onMessageReceived(message.get());
112       return std::move(message);
113     }
114   }
115   catch (RecoverableException& e) {
116     A2_LOG_INFO_EX("Exception thrown while receiving DHT message.", e);
117     return handleUnknownMessage(data, length, remoteAddr, remotePort);
118   }
119 }
120 
onMessageReceived(DHTMessage * message)121 void DHTMessageReceiver::onMessageReceived(DHTMessage* message)
122 {
123   A2_LOG_INFO(fmt("Message received: %s", message->toString().c_str()));
124   message->validate();
125   message->doReceivedAction();
126   message->getRemoteNode()->markGood();
127   message->getRemoteNode()->updateLastContact();
128   routingTable_->addGoodNode(message->getRemoteNode());
129 }
130 
handleTimeout()131 void DHTMessageReceiver::handleTimeout() { tracker_->handleTimeout(); }
132 
handleUnknownMessage(const unsigned char * data,size_t length,const std::string & remoteAddr,uint16_t remotePort)133 std::unique_ptr<DHTUnknownMessage> DHTMessageReceiver::handleUnknownMessage(
134     const unsigned char* data, size_t length, const std::string& remoteAddr,
135     uint16_t remotePort)
136 {
137   auto m = factory_->createUnknownMessage(data, length, remoteAddr, remotePort);
138   A2_LOG_INFO(fmt("Message received: %s", m->toString().c_str()));
139   return m;
140 }
141 
setMessageFactory(DHTMessageFactory * factory)142 void DHTMessageReceiver::setMessageFactory(DHTMessageFactory* factory)
143 {
144   factory_ = factory;
145 }
146 
setRoutingTable(DHTRoutingTable * routingTable)147 void DHTMessageReceiver::setRoutingTable(DHTRoutingTable* routingTable)
148 {
149   routingTable_ = routingTable;
150 }
151 
152 } // namespace aria2
153