1 //
2 // Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2021
3 //
4 // Distributed under the Boost Software License, Version 1.0. (See accompanying
5 // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
6 //
7 #pragma once
8
9 #include "td/mtproto/AuthKey.h"
10
11 #include "td/utils/common.h"
12 #include "td/utils/Slice.h"
13 #include "td/utils/Status.h"
14
15 #include <array>
16
17 namespace td {
18 namespace mtproto {
19
// A server salt together with the window of server time during which it is valid.
// AuthData::set_server_salt() fills valid_since/valid_until from AuthData::get_server_time(),
// so both bounds are expressed in server time (presumably seconds, like the other time
// constants in this header).
// NOTE: field order matters: it is the aggregate-initialization order and the order used by
// store()/parse().
struct ServerSalt {
  int64 salt;          // the salt value
  double valid_since;  // start of the validity window (server time)
  double valid_until;  // end of the validity window (server time)
};
25
26 template <class StorerT>
store(const ServerSalt & salt,StorerT & storer)27 void store(const ServerSalt &salt, StorerT &storer) {
28 storer.template store_binary<int64>(salt.salt);
29 storer.template store_binary<double>(salt.valid_since);
30 storer.template store_binary<double>(salt.valid_until);
31 }
32
33 template <class ParserT>
parse(ServerSalt & salt,ParserT & parser)34 void parse(ServerSalt &salt, ParserT &parser) {
35 salt.salt = parser.fetch_long();
36 salt.valid_since = parser.fetch_double();
37 salt.valid_until = parser.fetch_double();
38 }
39
// Out-of-line duplicate check for incoming message ids (implementation not in this header).
// MessageIdDuplicateChecker<max_size> calls it with a buffer of 2 * max_size int64 slots and its
// own end_pos_ cursor, which this function may update through the reference.
// NOTE(review): presumably returns an error Status when message_id was already seen (or is too
// old to be checked) and Status::OK() otherwise — confirm against the .cpp implementation.
Status check_message_id_duplicates(int64 *saved_message_ids, size_t max_size, size_t &end_pos, int64 message_id);
41
42 template <size_t max_size>
43 class MessageIdDuplicateChecker {
44 public:
check(int64 message_id)45 Status check(int64 message_id) {
46 return check_message_id_duplicates(&saved_message_ids_[0], max_size, end_pos_, message_id);
47 }
48
49 private:
50 std::array<int64, 2 * max_size> saved_message_ids_;
51 size_t end_pos_ = 0;
52 };
53
54 class AuthData {
55 public:
56 AuthData();
57 AuthData(const AuthData &) = default;
58 AuthData &operator=(const AuthData &) = delete;
59 AuthData(AuthData &&) = delete;
60 AuthData &operator=(AuthData &&) = delete;
61 ~AuthData() = default;
62
63 bool is_ready(double now);
64
set_main_auth_key(AuthKey auth_key)65 void set_main_auth_key(AuthKey auth_key) {
66 main_auth_key_ = std::move(auth_key);
67 }
break_main_auth_key()68 void break_main_auth_key() {
69 main_auth_key_.break_key();
70 }
get_main_auth_key()71 const AuthKey &get_main_auth_key() const {
72 // CHECK(has_main_auth_key());
73 return main_auth_key_;
74 }
has_main_auth_key()75 bool has_main_auth_key() const {
76 return !main_auth_key_.empty();
77 }
need_main_auth_key()78 bool need_main_auth_key() const {
79 return !has_main_auth_key();
80 }
81
set_tmp_auth_key(AuthKey auth_key)82 void set_tmp_auth_key(AuthKey auth_key) {
83 CHECK(!auth_key.empty());
84 tmp_auth_key_ = std::move(auth_key);
85 }
get_tmp_auth_key()86 const AuthKey &get_tmp_auth_key() const {
87 // CHECK(has_tmp_auth_key());
88 return tmp_auth_key_;
89 }
was_tmp_auth_key()90 bool was_tmp_auth_key() const {
91 return use_pfs() && !tmp_auth_key_.empty();
92 }
need_tmp_auth_key(double now)93 bool need_tmp_auth_key(double now) const {
94 if (!use_pfs()) {
95 return false;
96 }
97 if (tmp_auth_key_.empty()) {
98 return true;
99 }
100 if (now > tmp_auth_key_.expires_at() - 60 * 60 * 2 /*2 hours*/) {
101 return true;
102 }
103 if (!has_tmp_auth_key(now)) {
104 return true;
105 }
106 return false;
107 }
drop_main_auth_key()108 void drop_main_auth_key() {
109 main_auth_key_ = AuthKey();
110 }
drop_tmp_auth_key()111 void drop_tmp_auth_key() {
112 tmp_auth_key_ = AuthKey();
113 }
has_tmp_auth_key(double now)114 bool has_tmp_auth_key(double now) const {
115 if (!use_pfs()) {
116 return false;
117 }
118 if (tmp_auth_key_.empty()) {
119 return false;
120 }
121 if (now > tmp_auth_key_.expires_at() - 60 * 60 /*1 hour*/) {
122 return false;
123 }
124 return true;
125 }
126
get_auth_key()127 const AuthKey &get_auth_key() const {
128 if (use_pfs()) {
129 return get_tmp_auth_key();
130 }
131 return get_main_auth_key();
132 }
has_auth_key(double now)133 bool has_auth_key(double now) const {
134 if (use_pfs()) {
135 return has_tmp_auth_key(now);
136 }
137 return has_main_auth_key();
138 }
139
get_auth_flag()140 bool get_auth_flag() const {
141 return main_auth_key_.auth_flag();
142 }
set_auth_flag(bool auth_flag)143 void set_auth_flag(bool auth_flag) {
144 main_auth_key_.set_auth_flag(auth_flag);
145 if (!auth_flag) {
146 drop_tmp_auth_key();
147 }
148 }
149
get_bind_flag()150 bool get_bind_flag() const {
151 return !use_pfs() || tmp_auth_key_.auth_flag();
152 }
on_bind()153 void on_bind() {
154 CHECK(use_pfs());
155 tmp_auth_key_.set_auth_flag(true);
156 }
157
get_header()158 Slice get_header() const {
159 if (use_pfs()) {
160 return tmp_auth_key_.need_header() ? Slice(header_) : Slice();
161 } else {
162 return main_auth_key_.need_header() ? Slice(header_) : Slice();
163 }
164 }
set_header(std::string header)165 void set_header(std::string header) {
166 header_ = std::move(header);
167 }
on_api_response()168 void on_api_response() {
169 if (use_pfs()) {
170 if (tmp_auth_key_.auth_flag()) {
171 tmp_auth_key_.set_need_header(false);
172 }
173 } else {
174 if (main_auth_key_.auth_flag()) {
175 main_auth_key_.set_need_header(false);
176 }
177 }
178 }
179
set_session_id(uint64 session_id)180 void set_session_id(uint64 session_id) {
181 session_id_ = session_id;
182 }
get_session_id()183 uint64 get_session_id() const {
184 CHECK(session_id_ != 0);
185 return session_id_;
186 }
187
get_server_time(double now)188 double get_server_time(double now) const {
189 return server_time_difference_ + now;
190 }
191
get_server_time_difference()192 double get_server_time_difference() const {
193 return server_time_difference_;
194 }
195
196 // diff == msg_id / 2^32 - now == old_server_now - now <= server_now - now
197 // server_time_difference >= max{diff}
198 bool update_server_time_difference(double diff);
199
set_server_time_difference(double diff)200 void set_server_time_difference(double diff) {
201 server_time_difference_was_updated_ = false;
202 server_time_difference_ = diff;
203 }
204
get_server_salt(double now)205 uint64 get_server_salt(double now) {
206 update_salt(now);
207 return server_salt_.salt;
208 }
209
set_server_salt(uint64 salt,double now)210 void set_server_salt(uint64 salt, double now) {
211 server_salt_.salt = salt;
212 double server_time = get_server_time(now);
213 server_salt_.valid_since = server_time;
214 server_salt_.valid_until = server_time + 60 * 10;
215 future_salts_.clear();
216 }
217
is_server_salt_valid(double now)218 bool is_server_salt_valid(double now) const {
219 return server_salt_.valid_until > get_server_time(now) + 60;
220 }
221
has_salt(double now)222 bool has_salt(double now) {
223 update_salt(now);
224 return is_server_salt_valid(now);
225 }
226
need_future_salts(double now)227 bool need_future_salts(double now) {
228 update_salt(now);
229 return future_salts_.empty() || !is_server_salt_valid(now);
230 }
231
232 void set_future_salts(const std::vector<ServerSalt> &salts, double now);
233
234 std::vector<ServerSalt> get_future_salts() const;
235
236 int64 next_message_id(double now);
237
238 bool is_valid_outbound_msg_id(int64 id, double now) const;
239
240 bool is_valid_inbound_msg_id(int64 id, double now) const;
241
242 Status check_packet(int64 session_id, int64 message_id, double now, bool &time_difference_was_updated);
243
check_update(int64 message_id)244 Status check_update(int64 message_id) {
245 return updates_duplicate_checker_.check(message_id);
246 }
247
recheck_update(int64 message_id)248 Status recheck_update(int64 message_id) {
249 return updates_duplicate_rechecker_.check(message_id);
250 }
251
next_seq_no(bool is_content_related)252 int32 next_seq_no(bool is_content_related) {
253 int32 res = seq_no_;
254 if (is_content_related) {
255 res |= 1;
256 seq_no_ += 2;
257 }
258 return res;
259 }
260
clear_seq_no()261 void clear_seq_no() {
262 seq_no_ = 0;
263 }
264
set_use_pfs(bool use_pfs)265 void set_use_pfs(bool use_pfs) {
266 use_pfs_ = use_pfs;
267 }
use_pfs()268 bool use_pfs() const {
269 return use_pfs_;
270 }
271
272 private:
273 bool use_pfs_ = true;
274 AuthKey main_auth_key_;
275 AuthKey tmp_auth_key_;
276 bool server_time_difference_was_updated_ = false;
277 double server_time_difference_ = 0;
278 ServerSalt server_salt_;
279 int64 last_message_id_ = 0;
280 int32 seq_no_ = 0;
281 std::string header_;
282 uint64 session_id_ = 0;
283
284 std::vector<ServerSalt> future_salts_;
285
286 MessageIdDuplicateChecker<1000> duplicate_checker_;
287 MessageIdDuplicateChecker<1000> updates_duplicate_checker_;
288 MessageIdDuplicateChecker<100> updates_duplicate_rechecker_;
289
290 void update_salt(double now);
291 };
292
293 } // namespace mtproto
294 } // namespace td
295