1 //
2 // Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2021
3 //
4 // Distributed under the Boost Software License, Version 1.0. (See accompanying
5 // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
6 //
7 #pragma once
8 
9 #include "td/telegram/net/DcId.h"
10 #include "td/telegram/net/NetQueryCounter.h"
11 #include "td/telegram/net/NetQueryStats.h"
12 
13 #include "td/actor/actor.h"
14 #include "td/actor/PromiseFuture.h"
15 #include "td/actor/SignalSlot.h"
16 
17 #include "td/utils/buffer.h"
18 #include "td/utils/common.h"
19 #include "td/utils/format.h"
20 #include "td/utils/logging.h"
21 #include "td/utils/ObjectPool.h"
22 #include "td/utils/Slice.h"
23 #include "td/utils/Status.h"
24 #include "td/utils/StringBuilder.h"
25 #include "td/utils/Time.h"
26 #include "td/utils/tl_parsers.h"
27 #include "td/utils/TsList.h"
28 
29 #include <atomic>
30 #include <utility>
31 
32 namespace td {
33 
// Verbosity level of the `net_query` log tag used by the VLOG(net_query) statements in this header;
// declared extern here, so it is defined in some source file.
extern int VERBOSITY_NAME(net_query);

class NetQuery;
// Owning handle to a NetQuery allocated from an ObjectPool.
using NetQueryPtr = ObjectPool<NetQuery>::OwnerPtr;
// Non-owning handle to a pooled NetQuery; it carries a generation number (see cancel_query() below).
using NetQueryRef = ObjectPool<NetQuery>::WeakPtr;
39 
40 class NetQueryCallback : public Actor {
41  public:
42   virtual void on_result(NetQueryPtr query);
43   virtual void on_result_resendable(NetQueryPtr query, Promise<NetQueryPtr> promise);
44 };
45 
46 class NetQuery final : public TsListNode<NetQueryDebug> {
47  public:
48   NetQuery() = default;
49 
50   enum class State : int8 { Empty, Query, OK, Error };
51   enum class Type : int8 { Common, Upload, Download, DownloadSmall };
52   enum class AuthFlag : int8 { Off, On };
53   enum class GzipFlag : int8 { Off, On };
54   enum Error : int32 { Resend = 202, Canceled = 203, ResendInvokeAfter = 204 };
55 
id()56   uint64 id() const {
57     return id_;
58   }
59 
dc_id()60   DcId dc_id() const {
61     return dc_id_;
62   }
63 
type()64   Type type() const {
65     return type_;
66   }
67 
gzip_flag()68   GzipFlag gzip_flag() const {
69     return gzip_flag_;
70   }
71 
auth_flag()72   AuthFlag auth_flag() const {
73     return auth_flag_;
74   }
75 
tl_constructor()76   int32 tl_constructor() const {
77     return tl_constructor_;
78   }
79 
resend(DcId new_dc_id)80   void resend(DcId new_dc_id) {
81     VLOG(net_query) << "Resend" << *this;
82     {
83       auto guard = lock();
84       get_data_unsafe().resend_count_++;
85     }
86     dc_id_ = new_dc_id;
87     status_ = Status::OK();
88     state_ = State::Query;
89   }
90 
resend()91   void resend() {
92     resend(dc_id_);
93   }
94 
query()95   BufferSlice &query() {
96     return query_;
97   }
98 
ok()99   BufferSlice &ok() {
100     CHECK(state_ == State::OK);
101     return answer_;
102   }
103 
ok()104   const BufferSlice &ok() const {
105     CHECK(state_ == State::OK);
106     return answer_;
107   }
108 
error()109   Status &error() {
110     CHECK(state_ == State::Error);
111     return status_;
112   }
113 
error()114   const Status &error() const {
115     CHECK(state_ == State::Error);
116     return status_;
117   }
118 
move_as_ok()119   BufferSlice move_as_ok() {
120     auto ok = std::move(answer_);
121     clear();
122     return ok;
123   }
move_as_error()124   Status move_as_error() TD_WARN_UNUSED_RESULT {
125     auto status = std::move(status_);
126     clear();
127     return status;
128   }
129 
set_ok(BufferSlice slice)130   void set_ok(BufferSlice slice) {
131     VLOG(net_query) << "Got answer " << *this;
132     CHECK(state_ == State::Query);
133     answer_ = std::move(slice);
134     state_ = State::OK;
135   }
136 
137   void on_net_write(size_t size);
138   void on_net_read(size_t size);
139 
140   void set_error(Status status, string source = string());
141 
set_error_resend()142   void set_error_resend() {
143     set_error_impl(Status::Error<Error::Resend>());
144   }
145 
set_error_canceled()146   void set_error_canceled() {
147     set_error_impl(Status::Error<Error::Canceled>());
148   }
149 
set_error_resend_invoke_after()150   void set_error_resend_invoke_after() {
151     set_error_impl(Status::Error<Error::ResendInvokeAfter>());
152   }
153 
update_is_ready()154   bool update_is_ready() {
155     if (state_ == State::Query) {
156       if (cancellation_token_.load(std::memory_order_relaxed) == 0 || cancel_slot_.was_signal()) {
157         set_error_canceled();
158         return true;
159       }
160       return false;
161     }
162     return true;
163   }
164 
is_ready()165   bool is_ready() const {
166     return state_ != State::Query;
167   }
168 
is_error()169   bool is_error() const {
170     return state_ == State::Error;
171   }
172 
is_ok()173   bool is_ok() const {
174     return state_ == State::OK;
175   }
176 
ok_tl_constructor()177   int32 ok_tl_constructor() const {
178     return tl_magic(answer_);
179   }
180 
ignore()181   void ignore() const {
182     status_.ignore();
183   }
184 
session_id()185   uint64 session_id() const {
186     return session_id_.load(std::memory_order_relaxed);
187   }
set_session_id(uint64 session_id)188   void set_session_id(uint64 session_id) {
189     session_id_.store(session_id, std::memory_order_relaxed);
190   }
191 
message_id()192   uint64 message_id() const {
193     return message_id_;
194   }
set_message_id(uint64 message_id)195   void set_message_id(uint64 message_id) {
196     message_id_ = message_id;
197   }
198 
invoke_after()199   NetQueryRef invoke_after() const {
200     return invoke_after_;
201   }
set_invoke_after(NetQueryRef ref)202   void set_invoke_after(NetQueryRef ref) {
203     invoke_after_ = ref;
204   }
set_session_rand(uint32 session_rand)205   void set_session_rand(uint32 session_rand) {
206     session_rand_ = session_rand;
207   }
session_rand()208   uint32 session_rand() const {
209     return session_rand_;
210   }
211 
cancel(int32 cancellation_token)212   void cancel(int32 cancellation_token) {
213     cancellation_token_.compare_exchange_strong(cancellation_token, 0, std::memory_order_relaxed);
214   }
set_cancellation_token(int32 cancellation_token)215   void set_cancellation_token(int32 cancellation_token) {
216     cancellation_token_.store(cancellation_token, std::memory_order_relaxed);
217   }
218 
clear()219   void clear() {
220     if (!is_ready()) {
221       auto guard = lock();
222       LOG(ERROR) << "Destroy not ready query " << *this << " " << tag("state", get_data_unsafe().state_);
223     }
224     // TODO: CHECK if net_query is lost here
225     cancel_slot_.close();
226     *this = NetQuery();
227   }
empty()228   bool empty() const {
229     return state_ == State::Empty || !nq_counter_ || may_be_lost_;
230   }
231 
stop_track()232   void stop_track() {
233     nq_counter_ = NetQueryCounter();
234     remove();
235   }
236 
debug_send_failed()237   void debug_send_failed() {
238     auto guard = lock();
239     get_data_unsafe().send_failed_count_++;
240   }
241 
242   void debug(string state, bool may_be_lost = false) {
243     may_be_lost_ = may_be_lost;
244     VLOG(net_query) << *this << " " << tag("state", state);
245     {
246       auto guard = lock();
247       auto &data = get_data_unsafe();
248       data.state_ = std::move(state);
249       data.state_timestamp_ = Time::now();
250       data.state_change_count_++;
251     }
252   }
253 
set_callback(ActorShared<NetQueryCallback> callback)254   void set_callback(ActorShared<NetQueryCallback> callback) {
255     callback_ = std::move(callback);
256   }
257 
move_callback()258   ActorShared<NetQueryCallback> move_callback() {
259     return std::move(callback_);
260   }
261 
start_migrate(int32 sched_id)262   void start_migrate(int32 sched_id) {
263     using ::td::start_migrate;
264     start_migrate(cancel_slot_, sched_id);
265   }
finish_migrate()266   void finish_migrate() {
267     using ::td::finish_migrate;
268     finish_migrate(cancel_slot_);
269   }
270 
priority()271   int8 priority() const {
272     return priority_;
273   }
set_priority(int8 priority)274   void set_priority(int8 priority) {
275     priority_ = priority;
276   }
277 
278  private:
279   State state_ = State::Empty;
280   Type type_ = Type::Common;
281   AuthFlag auth_flag_ = AuthFlag::Off;
282   GzipFlag gzip_flag_ = GzipFlag::Off;
283   DcId dc_id_;
284 
285   NetQueryCounter nq_counter_;
286   Status status_;
287   uint64 id_ = 0;
288   BufferSlice query_;
289   BufferSlice answer_;
290   int32 tl_constructor_ = 0;
291 
292   NetQueryRef invoke_after_;
293   uint32 session_rand_ = 0;
294 
295   bool may_be_lost_ = false;
296   int8 priority_{0};
297 
298   template <class T>
299   struct movable_atomic final : public std::atomic<T> {
300     movable_atomic() = default;
movable_atomicfinal301     movable_atomic(T &&x) : std::atomic<T>(std::forward<T>(x)) {
302     }
movable_atomicfinal303     movable_atomic(movable_atomic &&other) noexcept {
304       this->store(other.load(std::memory_order_relaxed), std::memory_order_relaxed);
305     }
306     movable_atomic &operator=(movable_atomic &&other) noexcept {
307       this->store(other.load(std::memory_order_relaxed), std::memory_order_relaxed);
308       return *this;
309     }
310     movable_atomic(const movable_atomic &other) = delete;
311     movable_atomic &operator=(const movable_atomic &other) = delete;
312     ~movable_atomic() = default;
313   };
314 
315   movable_atomic<uint64> session_id_{0};
316   uint64 message_id_{0};
317 
318   movable_atomic<int32> cancellation_token_{-1};  // == 0 if query is canceled
319   ActorShared<NetQueryCallback> callback_;
320 
321   void set_error_impl(Status status, string source = string()) {
322     VLOG(net_query) << "Got error " << *this << " " << status;
323     status_ = std::move(status);
324     state_ = State::Error;
325     source_ = std::move(source);
326   }
327 
328   static int64 get_my_id();
329 
330   static int32 tl_magic(const BufferSlice &buffer_slice);
331 
332  public:
333   double next_timeout_ = 1;          // for NetQueryDelayer
334   double total_timeout_ = 0;         // for NetQueryDelayer/SequenceDispatcher
335   double total_timeout_limit_ = 60;  // for NetQueryDelayer/SequenceDispatcher and to be set by caller
336   double last_timeout_ = 0;          // for NetQueryDelayer/SequenceDispatcher
337   string source_;                    // for NetQueryDelayer/SequenceDispatcher
338   bool need_resend_on_503_ = true;   // for NetQueryDispatcher and to be set by caller
339   int32 dispatch_ttl_ = -1;          // for NetQueryDispatcher and to be set by caller
340   Slot cancel_slot_;                 // for Session and to be set by caller
341   Promise<> quick_ack_promise_;      // for Session and to be set by caller
342   int32 file_type_ = -1;             // to be set by caller
343 
NetQuery(State state,uint64 id,BufferSlice && query,BufferSlice && answer,DcId dc_id,Type type,AuthFlag auth_flag,GzipFlag gzip_flag,int32 tl_constructor,double total_timeout_limit,NetQueryStats * stats)344   NetQuery(State state, uint64 id, BufferSlice &&query, BufferSlice &&answer, DcId dc_id, Type type, AuthFlag auth_flag,
345            GzipFlag gzip_flag, int32 tl_constructor, double total_timeout_limit, NetQueryStats *stats)
346       : state_(state)
347       , type_(type)
348       , auth_flag_(auth_flag)
349       , gzip_flag_(gzip_flag)
350       , dc_id_(dc_id)
351       , status_()
352       , id_(id)
353       , query_(std::move(query))
354       , answer_(std::move(answer))
355       , tl_constructor_(tl_constructor)
356       , total_timeout_limit_(total_timeout_limit) {
357     CHECK(id_ != 0);
358     auto &data = get_data_unsafe();
359     data.my_id_ = get_my_id();
360     data.start_timestamp_ = data.state_timestamp_ = Time::now();
361     LOG(INFO) << *this;
362     if (stats) {
363       nq_counter_ = stats->register_query(this);
364     }
365   }
366 };
367 
368 inline StringBuilder &operator<<(StringBuilder &stream, const NetQuery &net_query) {
369   stream << "[Query:";
370   stream << tag("id", net_query.id());
371   stream << tag("tl", format::as_hex(net_query.tl_constructor()));
372   if (!net_query.is_ready()) {
373     stream << tag("state", "Query");
374   } else if (net_query.is_error()) {
375     stream << tag("state", "Error");
376     stream << net_query.error();
377   } else if (net_query.is_ok()) {
378     stream << tag("state", "Result");
379     stream << tag("tl", format::as_hex(net_query.ok_tl_constructor()));
380   }
381   stream << "]";
382   return stream;
383 }
384 
385 inline StringBuilder &operator<<(StringBuilder &stream, const NetQueryPtr &net_query_ptr) {
386   return stream << *net_query_ptr;
387 }
388 
// Defined out of line. Judging by the name, it reports the network queries that are still pending — TODO confirm.
void dump_pending_network_queries();
390 
cancel_query(NetQueryRef & ref)391 inline void cancel_query(NetQueryRef &ref) {
392   if (ref.empty()) {
393     return;
394   }
395   ref->cancel(ref.generation());
396 }
397 
398 template <class T>
fetch_result(const BufferSlice & message)399 Result<typename T::ReturnType> fetch_result(const BufferSlice &message) {
400   TlBufferParser parser(&message);
401   auto result = T::fetch_result(parser);
402   parser.fetch_end();
403 
404   const char *error = parser.get_error();
405   if (error != nullptr) {
406     LOG(ERROR) << "Can't parse: " << format::as_hex_dump<4>(message.as_slice());
407     return Status::Error(500, Slice(error));
408   }
409 
410   return std::move(result);
411 }
412 
413 template <class T>
fetch_result(NetQueryPtr query)414 Result<typename T::ReturnType> fetch_result(NetQueryPtr query) {
415   CHECK(!query.empty());
416   if (query->is_error()) {
417     return query->move_as_error();
418   }
419   auto buffer = query->move_as_ok();
420   return fetch_result<T>(buffer);
421 }
422 
423 template <class T>
fetch_result(Result<NetQueryPtr> r_query)424 Result<typename T::ReturnType> fetch_result(Result<NetQueryPtr> r_query) {
425   TRY_RESULT(query, std::move(r_query));
426   return fetch_result<T>(std::move(query));
427 }
428 
on_result(NetQueryPtr query)429 inline void NetQueryCallback::on_result(NetQueryPtr query) {
430   on_result_resendable(std::move(query), Auto());
431 }
432 
on_result_resendable(NetQueryPtr query,Promise<NetQueryPtr> promise)433 inline void NetQueryCallback::on_result_resendable(NetQueryPtr query, Promise<NetQueryPtr> promise) {
434   on_result(std::move(query));
435 }
436 
// Free-function overload for NetQueryPtr: forwards to NetQuery::start_migrate(), which starts migrating
// the query's cancel_slot_ to scheduler sched_id.
inline void start_migrate(NetQueryPtr &net_query, int32 sched_id) {
  net_query->start_migrate(sched_id);
}
440 
// Free-function overload for NetQueryPtr: forwards to NetQuery::finish_migrate(), which completes the
// migration of the query's cancel_slot_ started by start_migrate().
inline void finish_migrate(NetQueryPtr &net_query) {
  net_query->finish_migrate();
}
444 
445 }  // namespace td
446