1 //
2 // Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2021
3 //
4 // Distributed under the Boost Software License, Version 1.0. (See accompanying
5 // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
6 //
7 #pragma once
8 
9 #include "td/mtproto/AuthKey.h"
10 
11 #include "td/utils/common.h"
12 #include "td/utils/Slice.h"
13 #include "td/utils/Status.h"
14 
15 #include <array>
16 
17 namespace td {
18 namespace mtproto {
19 
20 struct ServerSalt {
21   int64 salt;
22   double valid_since;
23   double valid_until;
24 };
25 
26 template <class StorerT>
store(const ServerSalt & salt,StorerT & storer)27 void store(const ServerSalt &salt, StorerT &storer) {
28   storer.template store_binary<int64>(salt.salt);
29   storer.template store_binary<double>(salt.valid_since);
30   storer.template store_binary<double>(salt.valid_until);
31 }
32 
33 template <class ParserT>
parse(ServerSalt & salt,ParserT & parser)34 void parse(ServerSalt &salt, ParserT &parser) {
35   salt.salt = parser.fetch_long();
36   salt.valid_since = parser.fetch_double();
37   salt.valid_until = parser.fetch_double();
38 }
39 
40 Status check_message_id_duplicates(int64 *saved_message_ids, size_t max_size, size_t &end_pos, int64 message_id);
41 
42 template <size_t max_size>
43 class MessageIdDuplicateChecker {
44  public:
check(int64 message_id)45   Status check(int64 message_id) {
46     return check_message_id_duplicates(&saved_message_ids_[0], max_size, end_pos_, message_id);
47   }
48 
49  private:
50   std::array<int64, 2 * max_size> saved_message_ids_;
51   size_t end_pos_ = 0;
52 };
53 
54 class AuthData {
55  public:
56   AuthData();
57   AuthData(const AuthData &) = default;
58   AuthData &operator=(const AuthData &) = delete;
59   AuthData(AuthData &&) = delete;
60   AuthData &operator=(AuthData &&) = delete;
61   ~AuthData() = default;
62 
63   bool is_ready(double now);
64 
set_main_auth_key(AuthKey auth_key)65   void set_main_auth_key(AuthKey auth_key) {
66     main_auth_key_ = std::move(auth_key);
67   }
break_main_auth_key()68   void break_main_auth_key() {
69     main_auth_key_.break_key();
70   }
get_main_auth_key()71   const AuthKey &get_main_auth_key() const {
72     // CHECK(has_main_auth_key());
73     return main_auth_key_;
74   }
has_main_auth_key()75   bool has_main_auth_key() const {
76     return !main_auth_key_.empty();
77   }
need_main_auth_key()78   bool need_main_auth_key() const {
79     return !has_main_auth_key();
80   }
81 
set_tmp_auth_key(AuthKey auth_key)82   void set_tmp_auth_key(AuthKey auth_key) {
83     CHECK(!auth_key.empty());
84     tmp_auth_key_ = std::move(auth_key);
85   }
get_tmp_auth_key()86   const AuthKey &get_tmp_auth_key() const {
87     // CHECK(has_tmp_auth_key());
88     return tmp_auth_key_;
89   }
was_tmp_auth_key()90   bool was_tmp_auth_key() const {
91     return use_pfs() && !tmp_auth_key_.empty();
92   }
need_tmp_auth_key(double now)93   bool need_tmp_auth_key(double now) const {
94     if (!use_pfs()) {
95       return false;
96     }
97     if (tmp_auth_key_.empty()) {
98       return true;
99     }
100     if (now > tmp_auth_key_.expires_at() - 60 * 60 * 2 /*2 hours*/) {
101       return true;
102     }
103     if (!has_tmp_auth_key(now)) {
104       return true;
105     }
106     return false;
107   }
drop_main_auth_key()108   void drop_main_auth_key() {
109     main_auth_key_ = AuthKey();
110   }
drop_tmp_auth_key()111   void drop_tmp_auth_key() {
112     tmp_auth_key_ = AuthKey();
113   }
has_tmp_auth_key(double now)114   bool has_tmp_auth_key(double now) const {
115     if (!use_pfs()) {
116       return false;
117     }
118     if (tmp_auth_key_.empty()) {
119       return false;
120     }
121     if (now > tmp_auth_key_.expires_at() - 60 * 60 /*1 hour*/) {
122       return false;
123     }
124     return true;
125   }
126 
get_auth_key()127   const AuthKey &get_auth_key() const {
128     if (use_pfs()) {
129       return get_tmp_auth_key();
130     }
131     return get_main_auth_key();
132   }
has_auth_key(double now)133   bool has_auth_key(double now) const {
134     if (use_pfs()) {
135       return has_tmp_auth_key(now);
136     }
137     return has_main_auth_key();
138   }
139 
get_auth_flag()140   bool get_auth_flag() const {
141     return main_auth_key_.auth_flag();
142   }
set_auth_flag(bool auth_flag)143   void set_auth_flag(bool auth_flag) {
144     main_auth_key_.set_auth_flag(auth_flag);
145     if (!auth_flag) {
146       drop_tmp_auth_key();
147     }
148   }
149 
get_bind_flag()150   bool get_bind_flag() const {
151     return !use_pfs() || tmp_auth_key_.auth_flag();
152   }
on_bind()153   void on_bind() {
154     CHECK(use_pfs());
155     tmp_auth_key_.set_auth_flag(true);
156   }
157 
get_header()158   Slice get_header() const {
159     if (use_pfs()) {
160       return tmp_auth_key_.need_header() ? Slice(header_) : Slice();
161     } else {
162       return main_auth_key_.need_header() ? Slice(header_) : Slice();
163     }
164   }
set_header(std::string header)165   void set_header(std::string header) {
166     header_ = std::move(header);
167   }
on_api_response()168   void on_api_response() {
169     if (use_pfs()) {
170       if (tmp_auth_key_.auth_flag()) {
171         tmp_auth_key_.set_need_header(false);
172       }
173     } else {
174       if (main_auth_key_.auth_flag()) {
175         main_auth_key_.set_need_header(false);
176       }
177     }
178   }
179 
set_session_id(uint64 session_id)180   void set_session_id(uint64 session_id) {
181     session_id_ = session_id;
182   }
get_session_id()183   uint64 get_session_id() const {
184     CHECK(session_id_ != 0);
185     return session_id_;
186   }
187 
get_server_time(double now)188   double get_server_time(double now) const {
189     return server_time_difference_ + now;
190   }
191 
get_server_time_difference()192   double get_server_time_difference() const {
193     return server_time_difference_;
194   }
195 
196   // diff == msg_id / 2^32 - now == old_server_now - now <= server_now - now
197   // server_time_difference >= max{diff}
198   bool update_server_time_difference(double diff);
199 
set_server_time_difference(double diff)200   void set_server_time_difference(double diff) {
201     server_time_difference_was_updated_ = false;
202     server_time_difference_ = diff;
203   }
204 
get_server_salt(double now)205   uint64 get_server_salt(double now) {
206     update_salt(now);
207     return server_salt_.salt;
208   }
209 
set_server_salt(uint64 salt,double now)210   void set_server_salt(uint64 salt, double now) {
211     server_salt_.salt = salt;
212     double server_time = get_server_time(now);
213     server_salt_.valid_since = server_time;
214     server_salt_.valid_until = server_time + 60 * 10;
215     future_salts_.clear();
216   }
217 
is_server_salt_valid(double now)218   bool is_server_salt_valid(double now) const {
219     return server_salt_.valid_until > get_server_time(now) + 60;
220   }
221 
has_salt(double now)222   bool has_salt(double now) {
223     update_salt(now);
224     return is_server_salt_valid(now);
225   }
226 
need_future_salts(double now)227   bool need_future_salts(double now) {
228     update_salt(now);
229     return future_salts_.empty() || !is_server_salt_valid(now);
230   }
231 
232   void set_future_salts(const std::vector<ServerSalt> &salts, double now);
233 
234   std::vector<ServerSalt> get_future_salts() const;
235 
236   int64 next_message_id(double now);
237 
238   bool is_valid_outbound_msg_id(int64 id, double now) const;
239 
240   bool is_valid_inbound_msg_id(int64 id, double now) const;
241 
242   Status check_packet(int64 session_id, int64 message_id, double now, bool &time_difference_was_updated);
243 
check_update(int64 message_id)244   Status check_update(int64 message_id) {
245     return updates_duplicate_checker_.check(message_id);
246   }
247 
recheck_update(int64 message_id)248   Status recheck_update(int64 message_id) {
249     return updates_duplicate_rechecker_.check(message_id);
250   }
251 
next_seq_no(bool is_content_related)252   int32 next_seq_no(bool is_content_related) {
253     int32 res = seq_no_;
254     if (is_content_related) {
255       res |= 1;
256       seq_no_ += 2;
257     }
258     return res;
259   }
260 
clear_seq_no()261   void clear_seq_no() {
262     seq_no_ = 0;
263   }
264 
set_use_pfs(bool use_pfs)265   void set_use_pfs(bool use_pfs) {
266     use_pfs_ = use_pfs;
267   }
use_pfs()268   bool use_pfs() const {
269     return use_pfs_;
270   }
271 
272  private:
273   bool use_pfs_ = true;
274   AuthKey main_auth_key_;
275   AuthKey tmp_auth_key_;
276   bool server_time_difference_was_updated_ = false;
277   double server_time_difference_ = 0;
278   ServerSalt server_salt_;
279   int64 last_message_id_ = 0;
280   int32 seq_no_ = 0;
281   std::string header_;
282   uint64 session_id_ = 0;
283 
284   std::vector<ServerSalt> future_salts_;
285 
286   MessageIdDuplicateChecker<1000> duplicate_checker_;
287   MessageIdDuplicateChecker<1000> updates_duplicate_checker_;
288   MessageIdDuplicateChecker<100> updates_duplicate_rechecker_;
289 
290   void update_salt(double now);
291 };
292 
293 }  // namespace mtproto
294 }  // namespace td
295