1 //
2 // Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2021
3 //
4 // Distributed under the Boost Software License, Version 1.0. (See accompanying
5 // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
6 //
7 #include "td/telegram/net/NetQueryDispatcher.h"
8
9 #include "td/telegram/ConfigShared.h"
10 #include "td/telegram/Global.h"
11 #include "td/telegram/net/AuthDataShared.h"
12 #include "td/telegram/net/DcAuthManager.h"
13 #include "td/telegram/net/NetQuery.h"
14 #include "td/telegram/net/NetQueryDelayer.h"
15 #include "td/telegram/net/PublicRsaKeyShared.h"
16 #include "td/telegram/net/PublicRsaKeyWatchdog.h"
17 #include "td/telegram/net/SessionMultiProxy.h"
18 #include "td/telegram/Td.h"
19 #include "td/telegram/TdDb.h"
20 #include "td/telegram/telegram_api.h"
21
22 #include "td/utils/common.h"
23 #include "td/utils/format.h"
24 #include "td/utils/logging.h"
25 #include "td/utils/misc.h"
26 #include "td/utils/port/thread.h"
27 #include "td/utils/Slice.h"
28 #include "td/utils/SliceBuilder.h"
29
30 namespace td {
31
complete_net_query(NetQueryPtr net_query)32 void NetQueryDispatcher::complete_net_query(NetQueryPtr net_query) {
33 auto callback = net_query->move_callback();
34 if (callback.empty()) {
35 net_query->debug("sent to td (no callback)");
36 send_closure_later(G()->td(), &Td::on_result, std::move(net_query));
37 } else {
38 net_query->debug("sent to callback", true);
39 send_closure_later(std::move(callback), &NetQueryCallback::on_result, std::move(net_query));
40 }
41 }
42
dispatch(NetQueryPtr net_query)43 void NetQueryDispatcher::dispatch(NetQueryPtr net_query) {
44 // net_query->debug("dispatch");
45 if (stop_flag_.load(std::memory_order_relaxed)) {
46 net_query->set_error(Global::request_aborted_error());
47 return complete_net_query(std::move(net_query));
48 }
49 if (G()->shared_config().get_option_boolean("test_flood_wait")) {
50 net_query->set_error(Status::Error(429, "Too Many Requests: retry after 10"));
51 return complete_net_query(std::move(net_query));
52 }
53 if (net_query->tl_constructor() == telegram_api::account_getPassword::ID && false) {
54 net_query->set_error(Status::Error(429, "Too Many Requests: retry after 10"));
55 return complete_net_query(std::move(net_query));
56 }
57
58 if (net_query->is_ready()) {
59 if (net_query->is_error()) {
60 auto code = net_query->error().code();
61 if (code == 303) {
62 try_fix_migrate(net_query);
63 } else if (code == NetQuery::Resend) {
64 net_query->resend();
65 } else if (code < 0 || code == 500 || code == 420) {
66 net_query->debug("sent to NetQueryDelayer");
67 return send_closure(delayer_, &NetQueryDelayer::delay, std::move(net_query));
68 }
69 }
70 }
71
72 if (!net_query->is_ready()) {
73 if (net_query->dispatch_ttl_ == 0) {
74 net_query->set_error(Status::Error("DispatchTtlError"));
75 }
76 }
77
78 auto dest_dc_id = net_query->dc_id();
79 if (dest_dc_id.is_main()) {
80 dest_dc_id = DcId::internal(main_dc_id_.load(std::memory_order_relaxed));
81 }
82 if (!net_query->is_ready() && wait_dc_init(dest_dc_id, true).is_error()) {
83 net_query->set_error(Status::Error(PSLICE() << "No such dc " << dest_dc_id));
84 }
85
86 if (net_query->is_ready()) {
87 return complete_net_query(std::move(net_query));
88 }
89
90 if (net_query->dispatch_ttl_ > 0) {
91 net_query->dispatch_ttl_--;
92 }
93
94 auto dc_pos = static_cast<size_t>(dest_dc_id.get_raw_id() - 1);
95 CHECK(dc_pos < dcs_.size());
96 switch (net_query->type()) {
97 case NetQuery::Type::Common:
98 net_query->debug(PSTRING() << "sent to main session multi proxy " << dest_dc_id);
99 send_closure_later(dcs_[dc_pos].main_session_, &SessionMultiProxy::send, std::move(net_query));
100 break;
101 case NetQuery::Type::Upload:
102 net_query->debug(PSTRING() << "sent to upload session multi proxy " << dest_dc_id);
103 send_closure_later(dcs_[dc_pos].upload_session_, &SessionMultiProxy::send, std::move(net_query));
104 break;
105 case NetQuery::Type::Download:
106 net_query->debug(PSTRING() << "sent to download session multi proxy " << dest_dc_id);
107 send_closure_later(dcs_[dc_pos].download_session_, &SessionMultiProxy::send, std::move(net_query));
108 break;
109 case NetQuery::Type::DownloadSmall:
110 net_query->debug(PSTRING() << "sent to download small session multi proxy " << dest_dc_id);
111 send_closure_later(dcs_[dc_pos].download_small_session_, &SessionMultiProxy::send, std::move(net_query));
112 break;
113 }
114 }
115
wait_dc_init(DcId dc_id,bool force)116 Status NetQueryDispatcher::wait_dc_init(DcId dc_id, bool force) {
117 // TODO: optimize
118 if (!dc_id.is_exact()) {
119 return Status::Error("Not exact DC");
120 }
121 auto pos = static_cast<size_t>(dc_id.get_raw_id() - 1);
122 if (pos >= dcs_.size()) {
123 return Status::Error("Too big DC ID");
124 }
125 auto &dc = dcs_[pos];
126
127 bool should_init = false;
128 if (!dc.is_valid_) {
129 if (!force) {
130 return Status::Error("Invalid DC");
131 }
132 bool expected = false;
133 should_init =
134 dc.is_valid_.compare_exchange_strong(expected, true, std::memory_order_seq_cst, std::memory_order_seq_cst);
135 }
136
137 if (should_init) {
138 std::lock_guard<std::mutex> guard(main_dc_id_mutex_);
139 if (stop_flag_.load(std::memory_order_relaxed) || need_destroy_auth_key_) {
140 return Status::Error("Closing");
141 }
142 // init dc
143 dc.id_ = dc_id;
144 decltype(common_public_rsa_key_) public_rsa_key;
145 bool is_cdn = false;
146 bool need_destroy_key = false;
147 if (dc_id.is_internal()) {
148 public_rsa_key = common_public_rsa_key_;
149 } else {
150 public_rsa_key = std::make_shared<PublicRsaKeyShared>(dc_id, G()->is_test_dc());
151 send_closure_later(public_rsa_key_watchdog_, &PublicRsaKeyWatchdog::add_public_rsa_key, public_rsa_key);
152 is_cdn = true;
153 }
154 auto auth_data = AuthDataShared::create(dc_id, std::move(public_rsa_key), td_guard_);
155 int32 session_count = get_session_count();
156 bool use_pfs = get_use_pfs();
157
158 int32 slow_net_scheduler_id = G()->get_slow_net_scheduler_id();
159
160 auto raw_dc_id = dc_id.get_raw_id();
161 int32 upload_session_count = raw_dc_id != 2 && raw_dc_id != 4 ? 8 : 4;
162 int32 download_session_count = 2;
163 int32 download_small_session_count = 2;
164 dc.main_session_ = create_actor<SessionMultiProxy>(PSLICE() << "SessionMultiProxy:" << raw_dc_id << ":main",
165 session_count, auth_data, raw_dc_id == main_dc_id_, use_pfs,
166 false, false, is_cdn, need_destroy_key);
167 dc.upload_session_ = create_actor_on_scheduler<SessionMultiProxy>(
168 PSLICE() << "SessionMultiProxy:" << raw_dc_id << ":upload", slow_net_scheduler_id, upload_session_count,
169 auth_data, false, use_pfs, false, true, is_cdn, need_destroy_key);
170 dc.download_session_ = create_actor_on_scheduler<SessionMultiProxy>(
171 PSLICE() << "SessionMultiProxy:" << raw_dc_id << ":download", slow_net_scheduler_id, download_session_count,
172 auth_data, false, use_pfs, true, true, is_cdn, need_destroy_key);
173 dc.download_small_session_ = create_actor_on_scheduler<SessionMultiProxy>(
174 PSLICE() << "SessionMultiProxy:" << raw_dc_id << ":download_small", slow_net_scheduler_id,
175 download_small_session_count, auth_data, false, use_pfs, true, true, is_cdn, need_destroy_key);
176 dc.is_inited_ = true;
177 if (dc_id.is_internal()) {
178 send_closure_later(dc_auth_manager_, &DcAuthManager::add_dc, std::move(auth_data));
179 }
180 } else {
181 while (!dc.is_inited_) {
182 if (stop_flag_.load(std::memory_order_relaxed)) {
183 return Status::Error("Closing");
184 }
185 #if !TD_THREAD_UNSUPPORTED
186 td::this_thread::yield();
187 #endif
188 }
189 }
190 return Status::OK();
191 }
192
dispatch_with_callback(NetQueryPtr net_query,ActorShared<NetQueryCallback> callback)193 void NetQueryDispatcher::dispatch_with_callback(NetQueryPtr net_query, ActorShared<NetQueryCallback> callback) {
194 net_query->set_callback(std::move(callback));
195 dispatch(std::move(net_query));
196 }
197
stop()198 void NetQueryDispatcher::stop() {
199 std::lock_guard<std::mutex> guard(main_dc_id_mutex_);
200 td_guard_.reset();
201 stop_flag_ = true;
202 delayer_.hangup();
203 for (const auto &dc : dcs_) {
204 dc.main_session_.hangup();
205 dc.upload_session_.hangup();
206 dc.download_session_.hangup();
207 dc.download_small_session_.hangup();
208 }
209 public_rsa_key_watchdog_.reset();
210 dc_auth_manager_.reset();
211 }
212
update_session_count()213 void NetQueryDispatcher::update_session_count() {
214 std::lock_guard<std::mutex> guard(main_dc_id_mutex_);
215 int32 session_count = get_session_count();
216 bool use_pfs = get_use_pfs();
217 for (size_t i = 1; i < MAX_DC_COUNT; i++) {
218 if (is_dc_inited(narrow_cast<int32>(i))) {
219 send_closure_later(dcs_[i - 1].main_session_, &SessionMultiProxy::update_options, session_count, use_pfs);
220 send_closure_later(dcs_[i - 1].upload_session_, &SessionMultiProxy::update_use_pfs, use_pfs);
221 send_closure_later(dcs_[i - 1].download_session_, &SessionMultiProxy::update_use_pfs, use_pfs);
222 send_closure_later(dcs_[i - 1].download_small_session_, &SessionMultiProxy::update_use_pfs, use_pfs);
223 }
224 }
225 }
destroy_auth_keys(Promise<> promise)226 void NetQueryDispatcher::destroy_auth_keys(Promise<> promise) {
227 std::lock_guard<std::mutex> guard(main_dc_id_mutex_);
228 LOG(INFO) << "Destroy auth keys";
229 need_destroy_auth_key_ = true;
230 for (size_t i = 1; i < MAX_DC_COUNT; i++) {
231 if (is_dc_inited(narrow_cast<int32>(i)) && dcs_[i - 1].id_.is_internal()) {
232 send_closure_later(dcs_[i - 1].main_session_, &SessionMultiProxy::update_destroy_auth_key,
233 need_destroy_auth_key_);
234 }
235 }
236 send_closure_later(dc_auth_manager_, &DcAuthManager::destroy, std::move(promise));
237 }
238
update_use_pfs()239 void NetQueryDispatcher::update_use_pfs() {
240 std::lock_guard<std::mutex> guard(main_dc_id_mutex_);
241 bool use_pfs = get_use_pfs();
242 for (size_t i = 1; i < MAX_DC_COUNT; i++) {
243 if (is_dc_inited(narrow_cast<int32>(i))) {
244 send_closure_later(dcs_[i - 1].main_session_, &SessionMultiProxy::update_use_pfs, use_pfs);
245 send_closure_later(dcs_[i - 1].upload_session_, &SessionMultiProxy::update_use_pfs, use_pfs);
246 send_closure_later(dcs_[i - 1].download_session_, &SessionMultiProxy::update_use_pfs, use_pfs);
247 send_closure_later(dcs_[i - 1].download_small_session_, &SessionMultiProxy::update_use_pfs, use_pfs);
248 }
249 }
250 }
251
update_mtproto_header()252 void NetQueryDispatcher::update_mtproto_header() {
253 std::lock_guard<std::mutex> guard(main_dc_id_mutex_);
254 for (size_t i = 1; i < MAX_DC_COUNT; i++) {
255 if (is_dc_inited(narrow_cast<int32>(i))) {
256 send_closure_later(dcs_[i - 1].main_session_, &SessionMultiProxy::update_mtproto_header);
257 send_closure_later(dcs_[i - 1].upload_session_, &SessionMultiProxy::update_mtproto_header);
258 send_closure_later(dcs_[i - 1].download_session_, &SessionMultiProxy::update_mtproto_header);
259 send_closure_later(dcs_[i - 1].download_small_session_, &SessionMultiProxy::update_mtproto_header);
260 }
261 }
262 }
263
update_valid_dc(DcId dc_id)264 void NetQueryDispatcher::update_valid_dc(DcId dc_id) {
265 wait_dc_init(dc_id, true).ignore();
266 }
267
is_dc_inited(int32 raw_dc_id)268 bool NetQueryDispatcher::is_dc_inited(int32 raw_dc_id) {
269 return dcs_[raw_dc_id - 1].is_valid_.load(std::memory_order_relaxed);
270 }
get_session_count()271 int32 NetQueryDispatcher::get_session_count() {
272 return max(narrow_cast<int32>(G()->shared_config().get_option_integer("session_count")), 1);
273 }
274
get_use_pfs()275 bool NetQueryDispatcher::get_use_pfs() {
276 return G()->shared_config().get_option_boolean("use_pfs") || get_session_count() > 1;
277 }
278
NetQueryDispatcher(const std::function<ActorShared<> ()> & create_reference)279 NetQueryDispatcher::NetQueryDispatcher(const std::function<ActorShared<>()> &create_reference) {
280 auto s_main_dc_id = G()->td_db()->get_binlog_pmc()->get("main_dc_id");
281 if (!s_main_dc_id.empty()) {
282 main_dc_id_ = to_integer<int32>(s_main_dc_id);
283 }
284 LOG(INFO) << tag("main_dc_id", main_dc_id_.load(std::memory_order_relaxed));
285 delayer_ = create_actor<NetQueryDelayer>("NetQueryDelayer", create_reference());
286 dc_auth_manager_ = create_actor<DcAuthManager>("DcAuthManager", create_reference());
287 common_public_rsa_key_ = std::make_shared<PublicRsaKeyShared>(DcId::empty(), G()->is_test_dc());
288 public_rsa_key_watchdog_ = create_actor<PublicRsaKeyWatchdog>("PublicRsaKeyWatchdog", create_reference());
289
290 td_guard_ = create_shared_lambda_guard([actor = create_reference()] {});
291 }
292
// Default constructor and destructor, defaulted out of line in this file.
// NOTE(review): presumably kept here so that member types only forward-declared in the header are complete; confirm.
NetQueryDispatcher::NetQueryDispatcher() = default;
NetQueryDispatcher::~NetQueryDispatcher() = default;
295
try_fix_migrate(NetQueryPtr & net_query)296 void NetQueryDispatcher::try_fix_migrate(NetQueryPtr &net_query) {
297 auto error_message = net_query->error().message();
298 static constexpr CSlice prefixes[] = {"PHONE_MIGRATE_", "NETWORK_MIGRATE_", "USER_MIGRATE_"};
299 for (auto &prefix : prefixes) {
300 if (error_message.substr(0, prefix.size()) == prefix) {
301 auto new_main_dc_id = to_integer<int32>(error_message.substr(prefix.size()));
302 set_main_dc_id(new_main_dc_id);
303
304 if (!net_query->dc_id().is_main()) {
305 LOG(ERROR) << "Receive " << error_message << " for query to non-main DC" << net_query->dc_id();
306 net_query->resend(DcId::internal(new_main_dc_id));
307 } else {
308 net_query->resend();
309 }
310 break;
311 }
312 }
313 }
314
set_main_dc_id(int32 new_main_dc_id)315 void NetQueryDispatcher::set_main_dc_id(int32 new_main_dc_id) {
316 if (!DcId::is_valid(new_main_dc_id)) {
317 LOG(ERROR) << "Receive wrong DC " << new_main_dc_id;
318 return;
319 }
320 if (new_main_dc_id == main_dc_id_.load(std::memory_order_relaxed)) {
321 return;
322 }
323
324 // Very rare event. Mutex is ok.
325 std::lock_guard<std::mutex> guard(main_dc_id_mutex_);
326 if (new_main_dc_id == main_dc_id_) {
327 return;
328 }
329
330 LOG(INFO) << "Update main DcId from " << main_dc_id_.load(std::memory_order_relaxed) << " to " << new_main_dc_id;
331 if (is_dc_inited(main_dc_id_.load(std::memory_order_relaxed))) {
332 send_closure_later(dcs_[main_dc_id_ - 1].main_session_, &SessionMultiProxy::update_main_flag, false);
333 }
334 main_dc_id_ = new_main_dc_id;
335 if (is_dc_inited(main_dc_id_.load(std::memory_order_relaxed))) {
336 send_closure_later(dcs_[main_dc_id_ - 1].main_session_, &SessionMultiProxy::update_main_flag, true);
337 }
338 send_closure_later(dc_auth_manager_, &DcAuthManager::update_main_dc,
339 DcId::internal(main_dc_id_.load(std::memory_order_relaxed)));
340 G()->td_db()->get_binlog_pmc()->set("main_dc_id", to_string(main_dc_id_.load(std::memory_order_relaxed)));
341 }
342
343 } // namespace td
344