1 /*
2 This file is part of Telegram Desktop,
3 the official desktop application for the Telegram messaging service.
4
5 For license and copyright information please follow this link:
6 https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
7 */
8 #pragma once
9
10 #include "base/variant.h"
11 #include "mtproto/mtproto_response.h"
12 #include "mtproto/mtp_instance.h"
13 #include "mtproto/facade.h"
14
15 namespace MTP {
16
17 class Sender {
18 class RequestBuilder {
19 public:
20 RequestBuilder(const RequestBuilder &other) = delete;
21 RequestBuilder &operator=(const RequestBuilder &other) = delete;
22 RequestBuilder &operator=(RequestBuilder &&other) = delete;
23
24 protected:
25 enum class FailSkipPolicy {
26 Simple,
27 HandleFlood,
28 HandleAll,
29 };
30 using FailPlainHandler = Fn<void()>;
31 using FailErrorHandler = Fn<void(const Error&)>;
32 using FailRequestIdHandler = Fn<void(const Error&, mtpRequestId)>;
33 using FailFullHandler = Fn<void(const Error&, const Response&)>;
34
35 template <typename ...Args>
36 static constexpr bool IsCallable
37 = rpl::details::is_callable_plain_v<Args...>;
38
39 template <typename Result, typename Handler>
MakeDoneHandler(not_null<Sender * > sender,Handler && handler)40 [[nodiscard]] DoneHandler MakeDoneHandler(
41 not_null<Sender*> sender,
42 Handler &&handler) {
43 return [sender, handler = std::forward<Handler>(handler)](
44 const Response &response) mutable {
45 auto onstack = std::move(handler);
46 sender->senderRequestHandled(response.requestId);
47
48 auto result = Result();
49 auto from = response.reply.constData();
50 if (!result.read(from, from + response.reply.size())) {
51 return false;
52 } else if (!onstack) {
53 return true;
54 } else if constexpr (IsCallable<
55 Handler,
56 const Result&,
57 const Response&>) {
58 onstack(result, response);
59 } else if constexpr (IsCallable<
60 Handler,
61 const Result&,
62 mtpRequestId>) {
63 onstack(result, response.requestId);
64 } else if constexpr (IsCallable<
65 Handler,
66 const Result&>) {
67 onstack(result);
68 } else if constexpr (IsCallable<Handler>) {
69 onstack();
70 } else {
71 static_assert(false_t(Handler{}), "Bad done handler.");
72 }
73 return true;
74 };
75 }
76
77 template <typename Handler>
MakeFailHandler(not_null<Sender * > sender,Handler && handler,FailSkipPolicy skipPolicy)78 [[nodiscard]] FailHandler MakeFailHandler(
79 not_null<Sender*> sender,
80 Handler &&handler,
81 FailSkipPolicy skipPolicy) {
82 return [
83 sender,
84 handler = std::forward<Handler>(handler),
85 skipPolicy
86 ](const Error &error, const Response &response) {
87 if (skipPolicy == FailSkipPolicy::Simple) {
88 if (IsDefaultHandledError(error)) {
89 return false;
90 }
91 } else if (skipPolicy == FailSkipPolicy::HandleFlood) {
92 if (IsDefaultHandledError(error) && !IsFloodError(error)) {
93 return false;
94 }
95 }
96
97 auto onstack = handler;
98 sender->senderRequestHandled(response.requestId);
99
100 if (!onstack) {
101 return true;
102 } else if constexpr (IsCallable<
103 Handler,
104 const Error&,
105 const Response&>) {
106 onstack(error, response);
107 } else if constexpr (IsCallable<
108 Handler,
109 const Error&,
110 mtpRequestId>) {
111 onstack(error, response.requestId);
112 } else if constexpr (IsCallable<
113 Handler,
114 const Error&>) {
115 onstack(error);
116 } else if constexpr (IsCallable<Handler>) {
117 onstack();
118 } else {
119 static_assert(false_t(Handler{}), "Bad fail handler.");
120 }
121 return true;
122 };
123 }
124
RequestBuilder(not_null<Sender * > sender)125 explicit RequestBuilder(not_null<Sender*> sender) noexcept
126 : _sender(sender) {
127 }
128 RequestBuilder(RequestBuilder &&other) = default;
129
setToDC(ShiftedDcId dcId)130 void setToDC(ShiftedDcId dcId) noexcept {
131 _dcId = dcId;
132 }
setCanWait(crl::time ms)133 void setCanWait(crl::time ms) noexcept {
134 _canWait = ms;
135 }
setDoneHandler(DoneHandler && handler)136 void setDoneHandler(DoneHandler &&handler) noexcept {
137 _done = std::move(handler);
138 }
139 template <typename Handler>
setFailHandler(Handler && handler)140 void setFailHandler(Handler &&handler) noexcept {
141 _fail = std::forward<Handler>(handler);
142 }
setFailSkipPolicy(FailSkipPolicy policy)143 void setFailSkipPolicy(FailSkipPolicy policy) noexcept {
144 _failSkipPolicy = policy;
145 }
setAfter(mtpRequestId requestId)146 void setAfter(mtpRequestId requestId) noexcept {
147 _afterRequestId = requestId;
148 }
149
takeDcId()150 ShiftedDcId takeDcId() const noexcept {
151 return _dcId;
152 }
takeCanWait()153 crl::time takeCanWait() const noexcept {
154 return _canWait;
155 }
takeOnDone()156 DoneHandler takeOnDone() noexcept {
157 return std::move(_done);
158 }
takeOnFail()159 FailHandler takeOnFail() {
160 return v::match(_fail, [&](auto &value) {
161 return MakeFailHandler(
162 _sender,
163 std::move(value),
164 _failSkipPolicy);
165 });
166 }
takeAfter()167 mtpRequestId takeAfter() const noexcept {
168 return _afterRequestId;
169 }
170
sender()171 not_null<Sender*> sender() const noexcept {
172 return _sender;
173 }
registerRequest(mtpRequestId requestId)174 void registerRequest(mtpRequestId requestId) {
175 _sender->senderRequestRegister(requestId);
176 }
177
178 private:
179 not_null<Sender*> _sender;
180 ShiftedDcId _dcId = 0;
181 crl::time _canWait = 0;
182 DoneHandler _done;
183 std::variant<
184 FailPlainHandler,
185 FailErrorHandler,
186 FailRequestIdHandler,
187 FailFullHandler> _fail;
188 FailSkipPolicy _failSkipPolicy = FailSkipPolicy::Simple;
189 mtpRequestId _afterRequestId = 0;
190
191 };
192
193 public:
Sender(not_null<Instance * > instance)194 explicit Sender(not_null<Instance*> instance) noexcept
195 : _instance(instance) {
196 }
197
instance()198 [[nodiscard]] Instance &instance() const {
199 return *_instance;
200 }
201
202 template <typename Request>
203 class SpecificRequestBuilder : public RequestBuilder {
204 private:
205 friend class Sender;
SpecificRequestBuilder(not_null<Sender * > sender,Request && request)206 SpecificRequestBuilder(not_null<Sender*> sender, Request &&request) noexcept
207 : RequestBuilder(sender)
208 , _request(std::move(request)) {
209 }
210 SpecificRequestBuilder(SpecificRequestBuilder &&other) = default;
211
212 public:
toDC(ShiftedDcId dcId)213 [[nodiscard]] SpecificRequestBuilder &toDC(ShiftedDcId dcId) noexcept {
214 setToDC(dcId);
215 return *this;
216 }
afterDelay(crl::time ms)217 [[nodiscard]] SpecificRequestBuilder &afterDelay(crl::time ms) noexcept {
218 setCanWait(ms);
219 return *this;
220 }
221
222 using Result = typename Request::ResponseType;
done(FnMut<void (const Result & result,mtpRequestId requestId)> callback)223 [[nodiscard]] SpecificRequestBuilder &done(
224 FnMut<void(
225 const Result &result,
226 mtpRequestId requestId)> callback) {
227 setDoneHandler(
228 MakeDoneHandler<Result>(sender(), std::move(callback)));
229 return *this;
230 }
done(FnMut<void (const Result & result,const Response & response)> callback)231 [[nodiscard]] SpecificRequestBuilder &done(
232 FnMut<void(
233 const Result &result,
234 const Response &response)> callback) {
235 setDoneHandler(
236 MakeDoneHandler<Result>(sender(), std::move(callback)));
237 return *this;
238 }
done(FnMut<void ()> callback)239 [[nodiscard]] SpecificRequestBuilder &done(
240 FnMut<void()> callback) {
241 setDoneHandler(
242 MakeDoneHandler<Result>(sender(), std::move(callback)));
243 return *this;
244 }
done(FnMut<void (const typename Request::ResponseType & result)> callback)245 [[nodiscard]] SpecificRequestBuilder &done(
246 FnMut<void(
247 const typename Request::ResponseType &result)> callback) {
248 setDoneHandler(
249 MakeDoneHandler<Result>(sender(), std::move(callback)));
250 return *this;
251 }
252
fail(Fn<void (const Error & error,mtpRequestId requestId)> callback)253 [[nodiscard]] SpecificRequestBuilder &fail(
254 Fn<void(
255 const Error &error,
256 mtpRequestId requestId)> callback) noexcept {
257 setFailHandler(std::move(callback));
258 return *this;
259 }
fail(Fn<void (const Error & error,const Response & response)> callback)260 [[nodiscard]] SpecificRequestBuilder &fail(
261 Fn<void(
262 const Error &error,
263 const Response &response)> callback) noexcept {
264 setFailHandler(std::move(callback));
265 return *this;
266 }
fail(Fn<void ()> callback)267 [[nodiscard]] SpecificRequestBuilder &fail(
268 Fn<void()> callback) noexcept {
269 setFailHandler(std::move(callback));
270 return *this;
271 }
fail(Fn<void (const Error & error)> callback)272 [[nodiscard]] SpecificRequestBuilder &fail(
273 Fn<void(const Error &error)> callback) noexcept {
274 setFailHandler(std::move(callback));
275 return *this;
276 }
277
handleFloodErrors()278 [[nodiscard]] SpecificRequestBuilder &handleFloodErrors() noexcept {
279 setFailSkipPolicy(FailSkipPolicy::HandleFlood);
280 return *this;
281 }
handleAllErrors()282 [[nodiscard]] SpecificRequestBuilder &handleAllErrors() noexcept {
283 setFailSkipPolicy(FailSkipPolicy::HandleAll);
284 return *this;
285 }
afterRequest(mtpRequestId requestId)286 [[nodiscard]] SpecificRequestBuilder &afterRequest(mtpRequestId requestId) noexcept {
287 setAfter(requestId);
288 return *this;
289 }
290
send()291 mtpRequestId send() {
292 const auto id = sender()->_instance->send(
293 _request,
294 takeOnDone(),
295 takeOnFail(),
296 takeDcId(),
297 takeCanWait(),
298 takeAfter());
299 registerRequest(id);
300 return id;
301 }
302
303 private:
304 Request _request;
305
306 };
307
308 class SentRequestWrap {
309 private:
310 friend class Sender;
SentRequestWrap(not_null<Sender * > sender,mtpRequestId requestId)311 SentRequestWrap(not_null<Sender*> sender, mtpRequestId requestId) : _sender(sender), _requestId(requestId) {
312 }
313
314 public:
cancel()315 void cancel() {
316 if (_requestId) {
317 _sender->senderRequestCancel(_requestId);
318 }
319 }
320
321 private:
322 not_null<Sender*> _sender;
323 mtpRequestId _requestId = 0;
324
325 };
326
327 template <
328 typename Request,
329 typename = std::enable_if_t<!std::is_reference_v<Request>>,
330 typename = typename Request::Unboxed>
331 [[nodiscard]] SpecificRequestBuilder<Request> request(Request &&request) noexcept;
332
333 [[nodiscard]] SentRequestWrap request(mtpRequestId requestId) noexcept;
334
requestCanceller()335 [[nodiscard]] auto requestCanceller() noexcept {
336 return [this](mtpRequestId requestId) {
337 request(requestId).cancel();
338 };
339 }
340
requestSendDelayed()341 void requestSendDelayed() {
342 _instance->sendAnything();
343 }
requestCancellingDiscard()344 void requestCancellingDiscard() {
345 for (auto &request : base::take(_requests)) {
346 request.handled();
347 }
348 }
349
350 private:
351 class RequestWrap {
352 public:
RequestWrap(not_null<Instance * > instance,mtpRequestId requestId)353 RequestWrap(
354 not_null<Instance*> instance,
355 mtpRequestId requestId) noexcept
356 : _instance(instance)
357 , _id(requestId) {
358 }
359
360 RequestWrap(const RequestWrap &other) = delete;
361 RequestWrap &operator=(const RequestWrap &other) = delete;
RequestWrap(RequestWrap && other)362 RequestWrap(RequestWrap &&other)
363 : _instance(other._instance)
364 , _id(base::take(other._id)) {
365 }
366 RequestWrap &operator=(RequestWrap &&other) {
367 Expects(_instance == other._instance);
368
369 if (_id != other._id) {
370 cancelRequest();
371 _id = base::take(other._id);
372 }
373 return *this;
374 }
375
id()376 mtpRequestId id() const noexcept {
377 return _id;
378 }
handled()379 void handled() const noexcept {
380 _id = 0;
381 }
382
~RequestWrap()383 ~RequestWrap() {
384 cancelRequest();
385 }
386
387 private:
cancelRequest()388 void cancelRequest() {
389 if (_id) {
390 _instance->cancel(_id);
391 }
392 }
393 const not_null<Instance*> _instance;
394 mutable mtpRequestId _id = 0;
395
396 };
397
398 struct RequestWrapComparator {
399 using is_transparent = std::true_type;
400
401 struct helper {
402 mtpRequestId requestId = 0;
403
404 helper() = default;
405 helper(const helper &other) = default;
helperRequestWrapComparator::helper406 helper(mtpRequestId requestId) noexcept : requestId(requestId) {
407 }
helperRequestWrapComparator::helper408 helper(const RequestWrap &request) noexcept : requestId(request.id()) {
409 }
410 bool operator<(helper other) const {
411 return requestId < other.requestId;
412 }
413 };
operatorRequestWrapComparator414 bool operator()(const helper &&lhs, const helper &&rhs) const {
415 return lhs < rhs;
416 }
417
418 };
419
420 template <typename Request>
421 friend class SpecificRequestBuilder;
422 friend class RequestBuilder;
423 friend class RequestWrap;
424 friend class SentRequestWrap;
425
senderRequestRegister(mtpRequestId requestId)426 void senderRequestRegister(mtpRequestId requestId) {
427 _requests.emplace(_instance, requestId);
428 }
senderRequestHandled(mtpRequestId requestId)429 void senderRequestHandled(mtpRequestId requestId) {
430 auto it = _requests.find(requestId);
431 if (it != _requests.cend()) {
432 it->handled();
433 _requests.erase(it);
434 }
435 }
senderRequestCancel(mtpRequestId requestId)436 void senderRequestCancel(mtpRequestId requestId) {
437 auto it = _requests.find(requestId);
438 if (it != _requests.cend()) {
439 _requests.erase(it);
440 }
441 }
442
443 const not_null<Instance*> _instance;
444 base::flat_set<RequestWrap, RequestWrapComparator> _requests;
445
446 };
447
448 template <typename Request, typename, typename>
request(Request && request)449 Sender::SpecificRequestBuilder<Request> Sender::request(Request &&request) noexcept {
450 return SpecificRequestBuilder<Request>(this, std::move(request));
451 }
452
request(mtpRequestId requestId)453 inline Sender::SentRequestWrap Sender::request(mtpRequestId requestId) noexcept {
454 return SentRequestWrap(this, requestId);
455 }
456
457 } // namespace MTP
458