1 //
2 // Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2021
3 //
4 // Distributed under the Boost Software License, Version 1.0. (See accompanying
5 // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
6 //
7 #include "td/telegram/net/NetQueryDispatcher.h"
8 
9 #include "td/telegram/ConfigShared.h"
10 #include "td/telegram/Global.h"
11 #include "td/telegram/net/AuthDataShared.h"
12 #include "td/telegram/net/DcAuthManager.h"
13 #include "td/telegram/net/NetQuery.h"
14 #include "td/telegram/net/NetQueryDelayer.h"
15 #include "td/telegram/net/PublicRsaKeyShared.h"
16 #include "td/telegram/net/PublicRsaKeyWatchdog.h"
17 #include "td/telegram/net/SessionMultiProxy.h"
18 #include "td/telegram/Td.h"
19 #include "td/telegram/TdDb.h"
20 #include "td/telegram/telegram_api.h"
21 
22 #include "td/utils/common.h"
23 #include "td/utils/format.h"
24 #include "td/utils/logging.h"
25 #include "td/utils/misc.h"
26 #include "td/utils/port/thread.h"
27 #include "td/utils/Slice.h"
28 #include "td/utils/SliceBuilder.h"
29 
30 namespace td {
31 
complete_net_query(NetQueryPtr net_query)32 void NetQueryDispatcher::complete_net_query(NetQueryPtr net_query) {
33   auto callback = net_query->move_callback();
34   if (callback.empty()) {
35     net_query->debug("sent to td (no callback)");
36     send_closure_later(G()->td(), &Td::on_result, std::move(net_query));
37   } else {
38     net_query->debug("sent to callback", true);
39     send_closure_later(std::move(callback), &NetQueryCallback::on_result, std::move(net_query));
40   }
41 }
42 
dispatch(NetQueryPtr net_query)43 void NetQueryDispatcher::dispatch(NetQueryPtr net_query) {
44   // net_query->debug("dispatch");
45   if (stop_flag_.load(std::memory_order_relaxed)) {
46     net_query->set_error(Global::request_aborted_error());
47     return complete_net_query(std::move(net_query));
48   }
49   if (G()->shared_config().get_option_boolean("test_flood_wait")) {
50     net_query->set_error(Status::Error(429, "Too Many Requests: retry after 10"));
51     return complete_net_query(std::move(net_query));
52   }
53   if (net_query->tl_constructor() == telegram_api::account_getPassword::ID && false) {
54     net_query->set_error(Status::Error(429, "Too Many Requests: retry after 10"));
55     return complete_net_query(std::move(net_query));
56   }
57 
58   if (net_query->is_ready()) {
59     if (net_query->is_error()) {
60       auto code = net_query->error().code();
61       if (code == 303) {
62         try_fix_migrate(net_query);
63       } else if (code == NetQuery::Resend) {
64         net_query->resend();
65       } else if (code < 0 || code == 500 || code == 420) {
66         net_query->debug("sent to NetQueryDelayer");
67         return send_closure(delayer_, &NetQueryDelayer::delay, std::move(net_query));
68       }
69     }
70   }
71 
72   if (!net_query->is_ready()) {
73     if (net_query->dispatch_ttl_ == 0) {
74       net_query->set_error(Status::Error("DispatchTtlError"));
75     }
76   }
77 
78   auto dest_dc_id = net_query->dc_id();
79   if (dest_dc_id.is_main()) {
80     dest_dc_id = DcId::internal(main_dc_id_.load(std::memory_order_relaxed));
81   }
82   if (!net_query->is_ready() && wait_dc_init(dest_dc_id, true).is_error()) {
83     net_query->set_error(Status::Error(PSLICE() << "No such dc " << dest_dc_id));
84   }
85 
86   if (net_query->is_ready()) {
87     return complete_net_query(std::move(net_query));
88   }
89 
90   if (net_query->dispatch_ttl_ > 0) {
91     net_query->dispatch_ttl_--;
92   }
93 
94   auto dc_pos = static_cast<size_t>(dest_dc_id.get_raw_id() - 1);
95   CHECK(dc_pos < dcs_.size());
96   switch (net_query->type()) {
97     case NetQuery::Type::Common:
98       net_query->debug(PSTRING() << "sent to main session multi proxy " << dest_dc_id);
99       send_closure_later(dcs_[dc_pos].main_session_, &SessionMultiProxy::send, std::move(net_query));
100       break;
101     case NetQuery::Type::Upload:
102       net_query->debug(PSTRING() << "sent to upload session multi proxy " << dest_dc_id);
103       send_closure_later(dcs_[dc_pos].upload_session_, &SessionMultiProxy::send, std::move(net_query));
104       break;
105     case NetQuery::Type::Download:
106       net_query->debug(PSTRING() << "sent to download session multi proxy " << dest_dc_id);
107       send_closure_later(dcs_[dc_pos].download_session_, &SessionMultiProxy::send, std::move(net_query));
108       break;
109     case NetQuery::Type::DownloadSmall:
110       net_query->debug(PSTRING() << "sent to download small session multi proxy " << dest_dc_id);
111       send_closure_later(dcs_[dc_pos].download_small_session_, &SessionMultiProxy::send, std::move(net_query));
112       break;
113   }
114 }
115 
// Ensures the session proxies of an exact DC exist, creating them on first use.
//
// Returns an error for a non-exact DC id, for a raw id beyond dcs_, for a not-yet-valid DC when |force| is
// false, and with "Closing" when the dispatcher is stopping (or, for the initializing caller, when auth keys
// are being destroyed).
// When |force| is true and the DC is not valid yet, exactly one caller wins the compare-and-swap on
// is_valid_ and performs the initialization under main_dc_id_mutex_. Every other caller busy-waits
// (yielding the thread where threads are supported) until is_inited_ becomes true or stop_flag_ is set.
//
// NOTE(review): if the initializing caller returns "Closing" because need_destroy_auth_key_ is set while
// stop_flag_ is still false, is_valid_ stays true and is_inited_ never becomes true. Later callers for the
// same DC then spin in the wait loop below until stop_flag_ is set.
Status NetQueryDispatcher::wait_dc_init(DcId dc_id, bool force) {
  // TODO: optimize
  if (!dc_id.is_exact()) {
    return Status::Error("Not exact DC");
  }
  auto pos = static_cast<size_t>(dc_id.get_raw_id() - 1);
  if (pos >= dcs_.size()) {
    return Status::Error("Too big DC ID");
  }
  auto &dc = dcs_[pos];

  bool should_init = false;
  if (!dc.is_valid_) {
    if (!force) {
      return Status::Error("Invalid DC");
    }
    // Only the caller that flips is_valid_ from false to true becomes responsible for the initialization.
    bool expected = false;
    should_init =
        dc.is_valid_.compare_exchange_strong(expected, true, std::memory_order_seq_cst, std::memory_order_seq_cst);
  }

  if (should_init) {
    std::lock_guard<std::mutex> guard(main_dc_id_mutex_);
    if (stop_flag_.load(std::memory_order_relaxed) || need_destroy_auth_key_) {
      return Status::Error("Closing");
    }
    // init dc
    dc.id_ = dc_id;
    decltype(common_public_rsa_key_) public_rsa_key;
    bool is_cdn = false;
    bool need_destroy_key = false;
    if (dc_id.is_internal()) {
      // Internal DCs share the common RSA key holder created in the constructor.
      public_rsa_key = common_public_rsa_key_;
    } else {
      // Non-internal DCs get their own RSA key holder, registered with the watchdog, and are treated as CDN.
      public_rsa_key = std::make_shared<PublicRsaKeyShared>(dc_id, G()->is_test_dc());
      send_closure_later(public_rsa_key_watchdog_, &PublicRsaKeyWatchdog::add_public_rsa_key, public_rsa_key);
      is_cdn = true;
    }
    auto auth_data = AuthDataShared::create(dc_id, std::move(public_rsa_key), td_guard_);
    int32 session_count = get_session_count();
    bool use_pfs = get_use_pfs();

    int32 slow_net_scheduler_id = G()->get_slow_net_scheduler_id();

    // Session counts: main uses the "session_count" option, upload uses 8 (4 for DCs 2 and 4),
    // and download / download_small use 2 each.
    auto raw_dc_id = dc_id.get_raw_id();
    int32 upload_session_count = raw_dc_id != 2 && raw_dc_id != 4 ? 8 : 4;
    int32 download_session_count = 2;
    int32 download_small_session_count = 2;
    // The main proxy is created on the current scheduler; the other three go to the slow-net scheduler.
    // The meaning of the positional bool arguments is defined by SessionMultiProxy (not visible here).
    dc.main_session_ = create_actor<SessionMultiProxy>(PSLICE() << "SessionMultiProxy:" << raw_dc_id << ":main",
                                                       session_count, auth_data, raw_dc_id == main_dc_id_, use_pfs,
                                                       false, false, is_cdn, need_destroy_key);
    dc.upload_session_ = create_actor_on_scheduler<SessionMultiProxy>(
        PSLICE() << "SessionMultiProxy:" << raw_dc_id << ":upload", slow_net_scheduler_id, upload_session_count,
        auth_data, false, use_pfs, false, true, is_cdn, need_destroy_key);
    dc.download_session_ = create_actor_on_scheduler<SessionMultiProxy>(
        PSLICE() << "SessionMultiProxy:" << raw_dc_id << ":download", slow_net_scheduler_id, download_session_count,
        auth_data, false, use_pfs, true, true, is_cdn, need_destroy_key);
    dc.download_small_session_ = create_actor_on_scheduler<SessionMultiProxy>(
        PSLICE() << "SessionMultiProxy:" << raw_dc_id << ":download_small", slow_net_scheduler_id,
        download_small_session_count, auth_data, false, use_pfs, true, true, is_cdn, need_destroy_key);
    // Releases the callers waiting in the loop below.
    dc.is_inited_ = true;
    if (dc_id.is_internal()) {
      send_closure_later(dc_auth_manager_, &DcAuthManager::add_dc, std::move(auth_data));
    }
  } else {
    // Another caller is (or was) initializing this DC: wait for it to finish, or give up when stopping.
    while (!dc.is_inited_) {
      if (stop_flag_.load(std::memory_order_relaxed)) {
        return Status::Error("Closing");
      }
#if !TD_THREAD_UNSUPPORTED
      td::this_thread::yield();
#endif
    }
  }
  return Status::OK();
}
192 
dispatch_with_callback(NetQueryPtr net_query,ActorShared<NetQueryCallback> callback)193 void NetQueryDispatcher::dispatch_with_callback(NetQueryPtr net_query, ActorShared<NetQueryCallback> callback) {
194   net_query->set_callback(std::move(callback));
195   dispatch(std::move(net_query));
196 }
197 
stop()198 void NetQueryDispatcher::stop() {
199   std::lock_guard<std::mutex> guard(main_dc_id_mutex_);
200   td_guard_.reset();
201   stop_flag_ = true;
202   delayer_.hangup();
203   for (const auto &dc : dcs_) {
204     dc.main_session_.hangup();
205     dc.upload_session_.hangup();
206     dc.download_session_.hangup();
207     dc.download_small_session_.hangup();
208   }
209   public_rsa_key_watchdog_.reset();
210   dc_auth_manager_.reset();
211 }
212 
update_session_count()213 void NetQueryDispatcher::update_session_count() {
214   std::lock_guard<std::mutex> guard(main_dc_id_mutex_);
215   int32 session_count = get_session_count();
216   bool use_pfs = get_use_pfs();
217   for (size_t i = 1; i < MAX_DC_COUNT; i++) {
218     if (is_dc_inited(narrow_cast<int32>(i))) {
219       send_closure_later(dcs_[i - 1].main_session_, &SessionMultiProxy::update_options, session_count, use_pfs);
220       send_closure_later(dcs_[i - 1].upload_session_, &SessionMultiProxy::update_use_pfs, use_pfs);
221       send_closure_later(dcs_[i - 1].download_session_, &SessionMultiProxy::update_use_pfs, use_pfs);
222       send_closure_later(dcs_[i - 1].download_small_session_, &SessionMultiProxy::update_use_pfs, use_pfs);
223     }
224   }
225 }
destroy_auth_keys(Promise<> promise)226 void NetQueryDispatcher::destroy_auth_keys(Promise<> promise) {
227   std::lock_guard<std::mutex> guard(main_dc_id_mutex_);
228   LOG(INFO) << "Destroy auth keys";
229   need_destroy_auth_key_ = true;
230   for (size_t i = 1; i < MAX_DC_COUNT; i++) {
231     if (is_dc_inited(narrow_cast<int32>(i)) && dcs_[i - 1].id_.is_internal()) {
232       send_closure_later(dcs_[i - 1].main_session_, &SessionMultiProxy::update_destroy_auth_key,
233                          need_destroy_auth_key_);
234     }
235   }
236   send_closure_later(dc_auth_manager_, &DcAuthManager::destroy, std::move(promise));
237 }
238 
update_use_pfs()239 void NetQueryDispatcher::update_use_pfs() {
240   std::lock_guard<std::mutex> guard(main_dc_id_mutex_);
241   bool use_pfs = get_use_pfs();
242   for (size_t i = 1; i < MAX_DC_COUNT; i++) {
243     if (is_dc_inited(narrow_cast<int32>(i))) {
244       send_closure_later(dcs_[i - 1].main_session_, &SessionMultiProxy::update_use_pfs, use_pfs);
245       send_closure_later(dcs_[i - 1].upload_session_, &SessionMultiProxy::update_use_pfs, use_pfs);
246       send_closure_later(dcs_[i - 1].download_session_, &SessionMultiProxy::update_use_pfs, use_pfs);
247       send_closure_later(dcs_[i - 1].download_small_session_, &SessionMultiProxy::update_use_pfs, use_pfs);
248     }
249   }
250 }
251 
update_mtproto_header()252 void NetQueryDispatcher::update_mtproto_header() {
253   std::lock_guard<std::mutex> guard(main_dc_id_mutex_);
254   for (size_t i = 1; i < MAX_DC_COUNT; i++) {
255     if (is_dc_inited(narrow_cast<int32>(i))) {
256       send_closure_later(dcs_[i - 1].main_session_, &SessionMultiProxy::update_mtproto_header);
257       send_closure_later(dcs_[i - 1].upload_session_, &SessionMultiProxy::update_mtproto_header);
258       send_closure_later(dcs_[i - 1].download_session_, &SessionMultiProxy::update_mtproto_header);
259       send_closure_later(dcs_[i - 1].download_small_session_, &SessionMultiProxy::update_mtproto_header);
260     }
261   }
262 }
263 
update_valid_dc(DcId dc_id)264 void NetQueryDispatcher::update_valid_dc(DcId dc_id) {
265   wait_dc_init(dc_id, true).ignore();
266 }
267 
is_dc_inited(int32 raw_dc_id)268 bool NetQueryDispatcher::is_dc_inited(int32 raw_dc_id) {
269   return dcs_[raw_dc_id - 1].is_valid_.load(std::memory_order_relaxed);
270 }
get_session_count()271 int32 NetQueryDispatcher::get_session_count() {
272   return max(narrow_cast<int32>(G()->shared_config().get_option_integer("session_count")), 1);
273 }
274 
get_use_pfs()275 bool NetQueryDispatcher::get_use_pfs() {
276   return G()->shared_config().get_option_boolean("use_pfs") || get_session_count() > 1;
277 }
278 
NetQueryDispatcher(const std::function<ActorShared<> ()> & create_reference)279 NetQueryDispatcher::NetQueryDispatcher(const std::function<ActorShared<>()> &create_reference) {
280   auto s_main_dc_id = G()->td_db()->get_binlog_pmc()->get("main_dc_id");
281   if (!s_main_dc_id.empty()) {
282     main_dc_id_ = to_integer<int32>(s_main_dc_id);
283   }
284   LOG(INFO) << tag("main_dc_id", main_dc_id_.load(std::memory_order_relaxed));
285   delayer_ = create_actor<NetQueryDelayer>("NetQueryDelayer", create_reference());
286   dc_auth_manager_ = create_actor<DcAuthManager>("DcAuthManager", create_reference());
287   common_public_rsa_key_ = std::make_shared<PublicRsaKeyShared>(DcId::empty(), G()->is_test_dc());
288   public_rsa_key_watchdog_ = create_actor<PublicRsaKeyWatchdog>("PublicRsaKeyWatchdog", create_reference());
289 
290   td_guard_ = create_shared_lambda_guard([actor = create_reference()] {});
291 }
292 
// Out-of-line defaulted constructor and destructor. The default constructor creates no helper actors;
// the constructor taking create_reference does that.
NetQueryDispatcher::NetQueryDispatcher() = default;
NetQueryDispatcher::~NetQueryDispatcher() = default;
295 
try_fix_migrate(NetQueryPtr & net_query)296 void NetQueryDispatcher::try_fix_migrate(NetQueryPtr &net_query) {
297   auto error_message = net_query->error().message();
298   static constexpr CSlice prefixes[] = {"PHONE_MIGRATE_", "NETWORK_MIGRATE_", "USER_MIGRATE_"};
299   for (auto &prefix : prefixes) {
300     if (error_message.substr(0, prefix.size()) == prefix) {
301       auto new_main_dc_id = to_integer<int32>(error_message.substr(prefix.size()));
302       set_main_dc_id(new_main_dc_id);
303 
304       if (!net_query->dc_id().is_main()) {
305         LOG(ERROR) << "Receive " << error_message << " for query to non-main DC" << net_query->dc_id();
306         net_query->resend(DcId::internal(new_main_dc_id));
307       } else {
308         net_query->resend();
309       }
310       break;
311     }
312   }
313 }
314 
set_main_dc_id(int32 new_main_dc_id)315 void NetQueryDispatcher::set_main_dc_id(int32 new_main_dc_id) {
316   if (!DcId::is_valid(new_main_dc_id)) {
317     LOG(ERROR) << "Receive wrong DC " << new_main_dc_id;
318     return;
319   }
320   if (new_main_dc_id == main_dc_id_.load(std::memory_order_relaxed)) {
321     return;
322   }
323 
324   // Very rare event. Mutex is ok.
325   std::lock_guard<std::mutex> guard(main_dc_id_mutex_);
326   if (new_main_dc_id == main_dc_id_) {
327     return;
328   }
329 
330   LOG(INFO) << "Update main DcId from " << main_dc_id_.load(std::memory_order_relaxed) << " to " << new_main_dc_id;
331   if (is_dc_inited(main_dc_id_.load(std::memory_order_relaxed))) {
332     send_closure_later(dcs_[main_dc_id_ - 1].main_session_, &SessionMultiProxy::update_main_flag, false);
333   }
334   main_dc_id_ = new_main_dc_id;
335   if (is_dc_inited(main_dc_id_.load(std::memory_order_relaxed))) {
336     send_closure_later(dcs_[main_dc_id_ - 1].main_session_, &SessionMultiProxy::update_main_flag, true);
337   }
338   send_closure_later(dc_auth_manager_, &DcAuthManager::update_main_dc,
339                      DcId::internal(main_dc_id_.load(std::memory_order_relaxed)));
340   G()->td_db()->get_binlog_pmc()->set("main_dc_id", to_string(main_dc_id_.load(std::memory_order_relaxed)));
341 }
342 
343 }  // namespace td
344