1 //
2 // Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2021
3 //
4 // Distributed under the Boost Software License, Version 1.0. (See accompanying
5 // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
6 //
7 #pragma once
8
9 #include "td/telegram/net/DcId.h"
10 #include "td/telegram/net/NetQueryCounter.h"
11 #include "td/telegram/net/NetQueryStats.h"
12
13 #include "td/actor/actor.h"
14 #include "td/actor/PromiseFuture.h"
15 #include "td/actor/SignalSlot.h"
16
17 #include "td/utils/buffer.h"
18 #include "td/utils/common.h"
19 #include "td/utils/format.h"
20 #include "td/utils/logging.h"
21 #include "td/utils/ObjectPool.h"
22 #include "td/utils/Slice.h"
23 #include "td/utils/Status.h"
24 #include "td/utils/StringBuilder.h"
25 #include "td/utils/Time.h"
26 #include "td/utils/tl_parsers.h"
27 #include "td/utils/TsList.h"
28
29 #include <atomic>
30 #include <utility>
31
32 namespace td {
33
// Verbosity level controlling the VLOG(net_query) statements in this header.
extern int VERBOSITY_NAME(net_query);

class NetQuery;
// Owning handle to a NetQuery allocated from an ObjectPool.
using NetQueryPtr = ObjectPool<NetQuery>::OwnerPtr;
// Non-owning (weak) handle to a pooled NetQuery; cancel_query() uses its generation as the expected
// cancellation token.
using NetQueryRef = ObjectPool<NetQuery>::WeakPtr;
39
// Actor that receives finished NetQuery objects.
//
// The default implementations of the two virtual methods (defined further down in this header) forward to each
// other:
// - on_result() calls on_result_resendable() with a default-constructed promise.
// - on_result_resendable() drops its promise and calls on_result().
// A subclass must therefore override at least one of them; otherwise delivering a result recurses without end.
class NetQueryCallback : public Actor {
 public:
  // Receives a finished query; ownership of the query is transferred to the callback.
  virtual void on_result(NetQueryPtr query);
  // Receives a finished query together with a promise of a NetQueryPtr.
  // NOTE(review): presumably the callback can hand the query back through `promise` to get it resent —
  // confirm against the overriding implementations.
  virtual void on_result_resendable(NetQueryPtr query, Promise<NetQueryPtr> promise);
};
45
// One network query together with its outcome.
//
// Owned data:
// - the serialized request (query_);
// - once finished, either the raw answer (answer_, state OK) or an error Status (status_, state Error);
// - routing attributes: DC id, query type, auth and gzip flags.
// resend() may change the DC id. Instances live in an ObjectPool and are handled through NetQueryPtr (owning) and
// NetQueryRef (weak).
//
// Per-query debug bookkeeping (resend and send-failure counters, a textual state, timestamps, owner id) is stored
// in a NetQueryDebug record reached through get_data_unsafe(). Outside the constructor it is only modified while
// holding lock().
// NOTE(review): lock(), get_data_unsafe() and remove() are not declared in this class; presumably they are
// inherited from TsListNode — confirm.
class NetQuery final : public TsListNode<NetQueryDebug> {
 public:
  NetQuery() = default;

  // Empty is the default-constructed/cleared state, Query means pending, OK and Error are finished states.
  enum class State : int8 { Empty, Query, OK, Error };
  enum class Type : int8 { Common, Upload, Download, DownloadSmall };
  enum class AuthFlag : int8 { Off, On };
  enum class GzipFlag : int8 { Off, On };
  // Internal pseudo-error codes. They are produced by set_error_resend(), set_error_canceled() and
  // set_error_resend_invoke_after() respectively.
  enum Error : int32 { Resend = 202, Canceled = 203, ResendInvokeAfter = 204 };

  uint64 id() const {
    return id_;
  }

  DcId dc_id() const {
    return dc_id_;
  }

  Type type() const {
    return type_;
  }

  GzipFlag gzip_flag() const {
    return gzip_flag_;
  }

  AuthFlag auth_flag() const {
    return auth_flag_;
  }

  // TL constructor id of the request, as passed to the constructor.
  int32 tl_constructor() const {
    return tl_constructor_;
  }

  // Makes the query pending again and targets `new_dc_id`:
  // - increments the debug resend counter (under the lock);
  // - resets status_ to OK;
  // - moves the state back to Query.
  // answer_ is left untouched.
  void resend(DcId new_dc_id) {
    VLOG(net_query) << "Resend" << *this;
    {
      auto guard = lock();
      get_data_unsafe().resend_count_++;
    }
    dc_id_ = new_dc_id;
    status_ = Status::OK();
    state_ = State::Query;
  }

  // Resends to the same DC.
  void resend() {
    resend(dc_id_);
  }

  // Mutable access to the serialized request.
  BufferSlice &query() {
    return query_;
  }

  // Answer buffer; CHECK-fails unless the query is in the OK state.
  BufferSlice &ok() {
    CHECK(state_ == State::OK);
    return answer_;
  }

  const BufferSlice &ok() const {
    CHECK(state_ == State::OK);
    return answer_;
  }

  // Error status; CHECK-fails unless the query is in the Error state.
  Status &error() {
    CHECK(state_ == State::Error);
    return status_;
  }

  const Status &error() const {
    CHECK(state_ == State::Error);
    return status_;
  }

  // Takes the answer buffer and then resets the whole query via clear(). The state is not checked here; clear()
  // logs an error if the query is still pending.
  BufferSlice move_as_ok() {
    auto ok = std::move(answer_);
    clear();
    return ok;
  }
  // Takes the error status and then resets the whole query via clear(); the result must not be ignored.
  Status move_as_error() TD_WARN_UNUSED_RESULT {
    auto status = std::move(status_);
    clear();
    return status;
  }

  // Stores a successful answer. The query must be pending (state Query); this is CHECKed.
  void set_ok(BufferSlice slice) {
    VLOG(net_query) << "Got answer " << *this;
    CHECK(state_ == State::Query);
    answer_ = std::move(slice);
    state_ = State::OK;
  }

  // Defined out of line; called with the number of bytes written to / read from the network for this query.
  void on_net_write(size_t size);
  void on_net_read(size_t size);

  // Defined out of line. `source` is an optional description of where the error came from.
  void set_error(Status status, string source = string());

  // Marks the query as failed with the internal Resend pseudo-error.
  void set_error_resend() {
    set_error_impl(Status::Error<Error::Resend>());
  }

  // Marks the query as failed with the internal Canceled pseudo-error.
  void set_error_canceled() {
    set_error_impl(Status::Error<Error::Canceled>());
  }

  // Marks the query as failed with the internal ResendInvokeAfter pseudo-error.
  void set_error_resend_invoke_after() {
    set_error_impl(Status::Error<Error::ResendInvokeAfter>());
  }

  // Returns whether the query has finished (is not in the Query state).
  // If the query is still pending but has been canceled, it is first converted into a Canceled error and true is
  // returned. "Canceled" means either:
  // - its cancellation token was reset to 0 by cancel(), or
  // - cancel_slot_ has been signaled.
  bool update_is_ready() {
    if (state_ == State::Query) {
      if (cancellation_token_.load(std::memory_order_relaxed) == 0 || cancel_slot_.was_signal()) {
        set_error_canceled();
        return true;
      }
      return false;
    }
    return true;
  }

  // True for every state except Query (including Empty).
  bool is_ready() const {
    return state_ != State::Query;
  }

  bool is_error() const {
    return state_ == State::Error;
  }

  bool is_ok() const {
    return state_ == State::OK;
  }

  // TL constructor id read from the answer buffer (no state check is performed here).
  int32 ok_tl_constructor() const {
    return tl_magic(answer_);
  }

  // Forwards to Status::ignore() on the stored status.
  void ignore() const {
    status_.ignore();
  }

  // The session id is stored in an atomic with relaxed ordering.
  // NOTE(review): presumably because it is read from other threads — confirm.
  uint64 session_id() const {
    return session_id_.load(std::memory_order_relaxed);
  }
  void set_session_id(uint64 session_id) {
    session_id_.store(session_id, std::memory_order_relaxed);
  }

  uint64 message_id() const {
    return message_id_;
  }
  void set_message_id(uint64 message_id) {
    message_id_ = message_id;
  }

  // Weak reference to another query that this one is meant to be invoked after.
  NetQueryRef invoke_after() const {
    return invoke_after_;
  }
  void set_invoke_after(NetQueryRef ref) {
    invoke_after_ = ref;
  }
  void set_session_rand(uint32 session_rand) {
    session_rand_ = session_rand;
  }
  uint32 session_rand() const {
    return session_rand_;
  }

  // Cancels the query only if its current cancellation token equals `cancellation_token`: a compare-and-swap sets
  // the token to 0, and update_is_ready() treats a token of 0 as canceled. A mismatching token leaves the query
  // untouched.
  // NOTE(review): presumably this stops a stale NetQueryRef from canceling a reused pooled object — confirm.
  void cancel(int32 cancellation_token) {
    cancellation_token_.compare_exchange_strong(cancellation_token, 0, std::memory_order_relaxed);
  }
  void set_cancellation_token(int32 cancellation_token) {
    cancellation_token_.store(cancellation_token, std::memory_order_relaxed);
  }

  // Resets the object to a default-constructed NetQuery (state Empty).
  // A still-pending query is reported with LOG(ERROR) first, and cancel_slot_ is closed before the reset.
  void clear() {
    if (!is_ready()) {
      auto guard = lock();
      LOG(ERROR) << "Destroy not ready query " << *this << " " << tag("state", get_data_unsafe().state_);
    }
    // TODO: CHECK if net_query is lost here
    cancel_slot_.close();
    *this = NetQuery();
  }
  // True in any of these cases:
  // - the query is Empty;
  // - it has no NetQueryCounter (never registered with stats, or after stop_track());
  // - debug() marked it as possibly lost.
  bool empty() const {
    return state_ == State::Empty || !nq_counter_ || may_be_lost_;
  }

  // Drops the NetQueryCounter and calls remove().
  // NOTE(review): presumably this unlinks the query from the list of tracked queries, after which empty() is
  // true — confirm NetQueryCounter's default value is falsy.
  void stop_track() {
    nq_counter_ = NetQueryCounter();
    remove();
  }

  // Increments the debug send-failure counter under the lock.
  void debug_send_failed() {
    auto guard = lock();
    get_data_unsafe().send_failed_count_++;
  }

  // Records a textual debug state:
  // - logs it, then stores it together with the current time and an incremented change counter (under the lock);
  // - `may_be_lost` is stored as is and, when true, makes empty() return true.
  void debug(string state, bool may_be_lost = false) {
    may_be_lost_ = may_be_lost;
    VLOG(net_query) << *this << " " << tag("state", state);
    {
      auto guard = lock();
      auto &data = get_data_unsafe();
      data.state_ = std::move(state);
      data.state_timestamp_ = Time::now();
      data.state_change_count_++;
    }
  }

  void set_callback(ActorShared<NetQueryCallback> callback) {
    callback_ = std::move(callback);
  }

  // Moves the callback out, leaving callback_ in a moved-from state.
  ActorShared<NetQueryCallback> move_callback() {
    return std::move(callback_);
  }

  // Migrates cancel_slot_ to scheduler `sched_id`. The block-scope using-declaration makes the unqualified call
  // resolve to the namespace-level td::start_migrate overload for the slot instead of recursing into this member.
  void start_migrate(int32 sched_id) {
    using ::td::start_migrate;
    start_migrate(cancel_slot_, sched_id);
  }
  // Counterpart of start_migrate(); uses the same using-declaration trick.
  void finish_migrate() {
    using ::td::finish_migrate;
    finish_migrate(cancel_slot_);
  }

  int8 priority() const {
    return priority_;
  }
  void set_priority(int8 priority) {
    priority_ = priority;
  }

 private:
  // NOTE: the constructor's mem-initializer list follows this declaration order.
  State state_ = State::Empty;
  Type type_ = Type::Common;
  AuthFlag auth_flag_ = AuthFlag::Off;
  GzipFlag gzip_flag_ = GzipFlag::Off;
  DcId dc_id_;

  NetQueryCounter nq_counter_;
  Status status_;
  uint64 id_ = 0;
  BufferSlice query_;
  BufferSlice answer_;
  int32 tl_constructor_ = 0;

  NetQueryRef invoke_after_;
  uint32 session_rand_ = 0;

  bool may_be_lost_ = false;
  int8 priority_{0};

  // std::atomic is neither copyable nor movable. This wrapper adds move operations so that NetQuery itself stays
  // move-assignable, which clear() relies on (`*this = NetQuery()`).
  // A move is a relaxed load from the source followed by a relaxed store; it is not atomic as a whole.
  template <class T>
  struct movable_atomic final : public std::atomic<T> {
    movable_atomic() = default;
    movable_atomic(T &&x) : std::atomic<T>(std::forward<T>(x)) {
    }
    movable_atomic(movable_atomic &&other) noexcept {
      this->store(other.load(std::memory_order_relaxed), std::memory_order_relaxed);
    }
    movable_atomic &operator=(movable_atomic &&other) noexcept {
      this->store(other.load(std::memory_order_relaxed), std::memory_order_relaxed);
      return *this;
    }
    movable_atomic(const movable_atomic &other) = delete;
    movable_atomic &operator=(const movable_atomic &other) = delete;
    ~movable_atomic() = default;
  };

  movable_atomic<uint64> session_id_{0};
  uint64 message_id_{0};

  movable_atomic<int32> cancellation_token_{-1};  // == 0 if query is canceled
  ActorShared<NetQueryCallback> callback_;

  // Common error path: logs, stores the status and the optional source, and switches the state to Error.
  // The current state is not checked.
  void set_error_impl(Status status, string source = string()) {
    VLOG(net_query) << "Got error " << *this << " " << status;
    status_ = std::move(status);
    state_ = State::Error;
    source_ = std::move(source);
  }

  // Defined out of line; its value is recorded as the owner id in the debug data.
  static int64 get_my_id();

  // Defined out of line; extracts the TL constructor id from a serialized buffer.
  static int32 tl_magic(const BufferSlice &buffer_slice);

 public:
  double next_timeout_ = 1;           // for NetQueryDelayer
  double total_timeout_ = 0;          // for NetQueryDelayer/SequenceDispatcher
  double total_timeout_limit_ = 60;   // for NetQueryDelayer/SequenceDispatcher and to be set by caller
  double last_timeout_ = 0;           // for NetQueryDelayer/SequenceDispatcher
  string source_;                     // for NetQueryDelayer/SequenceDispatcher
  bool need_resend_on_503_ = true;    // for NetQueryDispatcher and to be set by caller
  int32 dispatch_ttl_ = -1;           // for NetQueryDispatcher and to be set by caller
  Slot cancel_slot_;                  // for Session and to be set by caller
  Promise<> quick_ack_promise_;       // for Session and to be set by caller
  int32 file_type_ = -1;              // to be set by caller

  // Creates a query in the given `state`.
  // - `id` must be non-zero; this is CHECKed.
  // - Fills the debug data: owner id, start time and state time.
  // - Registers with `stats` when it is non-null, which gives the query a NetQueryCounter (see empty()).
  // The debug data is written without lock().
  // NOTE(review): presumably safe because the object is not yet shared — confirm.
  NetQuery(State state, uint64 id, BufferSlice &&query, BufferSlice &&answer, DcId dc_id, Type type, AuthFlag auth_flag,
           GzipFlag gzip_flag, int32 tl_constructor, double total_timeout_limit, NetQueryStats *stats)
      : state_(state)
      , type_(type)
      , auth_flag_(auth_flag)
      , gzip_flag_(gzip_flag)
      , dc_id_(dc_id)
      , status_()
      , id_(id)
      , query_(std::move(query))
      , answer_(std::move(answer))
      , tl_constructor_(tl_constructor)
      , total_timeout_limit_(total_timeout_limit) {
    CHECK(id_ != 0);
    auto &data = get_data_unsafe();
    data.my_id_ = get_my_id();
    data.start_timestamp_ = data.state_timestamp_ = Time::now();
    LOG(INFO) << *this;
    if (stats) {
      nq_counter_ = stats->register_query(this);
    }
  }
};
367
368 inline StringBuilder &operator<<(StringBuilder &stream, const NetQuery &net_query) {
369 stream << "[Query:";
370 stream << tag("id", net_query.id());
371 stream << tag("tl", format::as_hex(net_query.tl_constructor()));
372 if (!net_query.is_ready()) {
373 stream << tag("state", "Query");
374 } else if (net_query.is_error()) {
375 stream << tag("state", "Error");
376 stream << net_query.error();
377 } else if (net_query.is_ok()) {
378 stream << tag("state", "Result");
379 stream << tag("tl", format::as_hex(net_query.ok_tl_constructor()));
380 }
381 stream << "]";
382 return stream;
383 }
384
385 inline StringBuilder &operator<<(StringBuilder &stream, const NetQueryPtr &net_query_ptr) {
386 return stream << *net_query_ptr;
387 }
388
// Defined out of line.
// NOTE(review): judging by the name, it reports the currently pending network queries for debugging — confirm
// against the implementation.
void dump_pending_network_queries();
390
cancel_query(NetQueryRef & ref)391 inline void cancel_query(NetQueryRef &ref) {
392 if (ref.empty()) {
393 return;
394 }
395 ref->cancel(ref.generation());
396 }
397
398 template <class T>
fetch_result(const BufferSlice & message)399 Result<typename T::ReturnType> fetch_result(const BufferSlice &message) {
400 TlBufferParser parser(&message);
401 auto result = T::fetch_result(parser);
402 parser.fetch_end();
403
404 const char *error = parser.get_error();
405 if (error != nullptr) {
406 LOG(ERROR) << "Can't parse: " << format::as_hex_dump<4>(message.as_slice());
407 return Status::Error(500, Slice(error));
408 }
409
410 return std::move(result);
411 }
412
413 template <class T>
fetch_result(NetQueryPtr query)414 Result<typename T::ReturnType> fetch_result(NetQueryPtr query) {
415 CHECK(!query.empty());
416 if (query->is_error()) {
417 return query->move_as_error();
418 }
419 auto buffer = query->move_as_ok();
420 return fetch_result<T>(buffer);
421 }
422
423 template <class T>
fetch_result(Result<NetQueryPtr> r_query)424 Result<typename T::ReturnType> fetch_result(Result<NetQueryPtr> r_query) {
425 TRY_RESULT(query, std::move(r_query));
426 return fetch_result<T>(std::move(query));
427 }
428
on_result(NetQueryPtr query)429 inline void NetQueryCallback::on_result(NetQueryPtr query) {
430 on_result_resendable(std::move(query), Auto());
431 }
432
on_result_resendable(NetQueryPtr query,Promise<NetQueryPtr> promise)433 inline void NetQueryCallback::on_result_resendable(NetQueryPtr query, Promise<NetQueryPtr> promise) {
434 on_result(std::move(query));
435 }
436
start_migrate(NetQueryPtr & net_query,int32 sched_id)437 inline void start_migrate(NetQueryPtr &net_query, int32 sched_id) {
438 net_query->start_migrate(sched_id);
439 }
440
finish_migrate(NetQueryPtr & net_query)441 inline void finish_migrate(NetQueryPtr &net_query) {
442 net_query->finish_migrate();
443 }
444
445 } // namespace td
446