1 /*
2 This file is part of Telegram Desktop,
3 the official desktop application for the Telegram messaging service.
4 
5 For license and copyright information please follow this link:
6 https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
7 */
8 #pragma once
9 
10 #include "base/variant.h"
11 #include "mtproto/mtproto_response.h"
12 #include "mtproto/mtp_instance.h"
13 #include "mtproto/facade.h"
14 
15 namespace MTP {
16 
17 class Sender {
18 	class RequestBuilder {
19 	public:
20 		RequestBuilder(const RequestBuilder &other) = delete;
21 		RequestBuilder &operator=(const RequestBuilder &other) = delete;
22 		RequestBuilder &operator=(RequestBuilder &&other) = delete;
23 
24 	protected:
25 		enum class FailSkipPolicy {
26 			Simple,
27 			HandleFlood,
28 			HandleAll,
29 		};
30 		using FailPlainHandler = Fn<void()>;
31 		using FailErrorHandler = Fn<void(const Error&)>;
32 		using FailRequestIdHandler = Fn<void(const Error&, mtpRequestId)>;
33 		using FailFullHandler = Fn<void(const Error&, const Response&)>;
34 
35 		template <typename ...Args>
36 		static constexpr bool IsCallable
37 			= rpl::details::is_callable_plain_v<Args...>;
38 
39 		template <typename Result, typename Handler>
MakeDoneHandler(not_null<Sender * > sender,Handler && handler)40 		[[nodiscard]] DoneHandler MakeDoneHandler(
41 				not_null<Sender*> sender,
42 				Handler &&handler) {
43 			return [sender, handler = std::forward<Handler>(handler)](
44 					const Response &response) mutable {
45 				auto onstack = std::move(handler);
46 				sender->senderRequestHandled(response.requestId);
47 
48 				auto result = Result();
49 				auto from = response.reply.constData();
50 				if (!result.read(from, from + response.reply.size())) {
51 					return false;
52 				} else if (!onstack) {
53 					return true;
54 				} else if constexpr (IsCallable<
55 						Handler,
56 						const Result&,
57 						const Response&>) {
58 					onstack(result, response);
59 				} else if constexpr (IsCallable<
60 						Handler,
61 						const Result&,
62 						mtpRequestId>) {
63 					onstack(result, response.requestId);
64 				} else if constexpr (IsCallable<
65 						Handler,
66 						const Result&>) {
67 					onstack(result);
68 				} else if constexpr (IsCallable<Handler>) {
69 					onstack();
70 				} else {
71 					static_assert(false_t(Handler{}), "Bad done handler.");
72 				}
73 				return true;
74 			};
75 		}
76 
77 		template <typename Handler>
MakeFailHandler(not_null<Sender * > sender,Handler && handler,FailSkipPolicy skipPolicy)78 		[[nodiscard]] FailHandler MakeFailHandler(
79 				not_null<Sender*> sender,
80 				Handler &&handler,
81 				FailSkipPolicy skipPolicy) {
82 			return [
83 				sender,
84 				handler = std::forward<Handler>(handler),
85 				skipPolicy
86 			](const Error &error, const Response &response) {
87 				if (skipPolicy == FailSkipPolicy::Simple) {
88 					if (IsDefaultHandledError(error)) {
89 						return false;
90 					}
91 				} else if (skipPolicy == FailSkipPolicy::HandleFlood) {
92 					if (IsDefaultHandledError(error) && !IsFloodError(error)) {
93 						return false;
94 					}
95 				}
96 
97 				auto onstack = handler;
98 				sender->senderRequestHandled(response.requestId);
99 
100 				if (!onstack) {
101 					return true;
102 				} else if constexpr (IsCallable<
103 						Handler,
104 						const Error&,
105 						const Response&>) {
106 					onstack(error, response);
107 				} else if constexpr (IsCallable<
108 						Handler,
109 						const Error&,
110 						mtpRequestId>) {
111 					onstack(error, response.requestId);
112 				} else if constexpr (IsCallable<
113 						Handler,
114 						const Error&>) {
115 					onstack(error);
116 				} else if constexpr (IsCallable<Handler>) {
117 					onstack();
118 				} else {
119 					static_assert(false_t(Handler{}), "Bad fail handler.");
120 				}
121 				return true;
122 			};
123 		}
124 
RequestBuilder(not_null<Sender * > sender)125 		explicit RequestBuilder(not_null<Sender*> sender) noexcept
126 		: _sender(sender) {
127 		}
128 		RequestBuilder(RequestBuilder &&other) = default;
129 
setToDC(ShiftedDcId dcId)130 		void setToDC(ShiftedDcId dcId) noexcept {
131 			_dcId = dcId;
132 		}
setCanWait(crl::time ms)133 		void setCanWait(crl::time ms) noexcept {
134 			_canWait = ms;
135 		}
setDoneHandler(DoneHandler && handler)136 		void setDoneHandler(DoneHandler &&handler) noexcept {
137 			_done = std::move(handler);
138 		}
139 		template <typename Handler>
setFailHandler(Handler && handler)140 		void setFailHandler(Handler &&handler) noexcept {
141 			_fail = std::forward<Handler>(handler);
142 		}
setFailSkipPolicy(FailSkipPolicy policy)143 		void setFailSkipPolicy(FailSkipPolicy policy) noexcept {
144 			_failSkipPolicy = policy;
145 		}
setAfter(mtpRequestId requestId)146 		void setAfter(mtpRequestId requestId) noexcept {
147 			_afterRequestId = requestId;
148 		}
149 
takeDcId()150 		ShiftedDcId takeDcId() const noexcept {
151 			return _dcId;
152 		}
takeCanWait()153 		crl::time takeCanWait() const noexcept {
154 			return _canWait;
155 		}
takeOnDone()156 		DoneHandler takeOnDone() noexcept {
157 			return std::move(_done);
158 		}
takeOnFail()159 		FailHandler takeOnFail() {
160 			return v::match(_fail, [&](auto &value) {
161 				return MakeFailHandler(
162 					_sender,
163 					std::move(value),
164 					_failSkipPolicy);
165 			});
166 		}
takeAfter()167 		mtpRequestId takeAfter() const noexcept {
168 			return _afterRequestId;
169 		}
170 
sender()171 		not_null<Sender*> sender() const noexcept {
172 			return _sender;
173 		}
registerRequest(mtpRequestId requestId)174 		void registerRequest(mtpRequestId requestId) {
175 			_sender->senderRequestRegister(requestId);
176 		}
177 
178 	private:
179 		not_null<Sender*> _sender;
180 		ShiftedDcId _dcId = 0;
181 		crl::time _canWait = 0;
182 		DoneHandler _done;
183 		std::variant<
184 			FailPlainHandler,
185 			FailErrorHandler,
186 			FailRequestIdHandler,
187 			FailFullHandler> _fail;
188 		FailSkipPolicy _failSkipPolicy = FailSkipPolicy::Simple;
189 		mtpRequestId _afterRequestId = 0;
190 
191 	};
192 
193 public:
Sender(not_null<Instance * > instance)194 	explicit Sender(not_null<Instance*> instance) noexcept
195 	: _instance(instance) {
196 	}
197 
instance()198 	[[nodiscard]] Instance &instance() const {
199 		return *_instance;
200 	}
201 
202 	template <typename Request>
203 	class SpecificRequestBuilder : public RequestBuilder {
204 	private:
205 		friend class Sender;
SpecificRequestBuilder(not_null<Sender * > sender,Request && request)206 		SpecificRequestBuilder(not_null<Sender*> sender, Request &&request) noexcept
207 		: RequestBuilder(sender)
208 		, _request(std::move(request)) {
209 		}
210 		SpecificRequestBuilder(SpecificRequestBuilder &&other) = default;
211 
212 	public:
toDC(ShiftedDcId dcId)213 		[[nodiscard]] SpecificRequestBuilder &toDC(ShiftedDcId dcId) noexcept {
214 			setToDC(dcId);
215 			return *this;
216 		}
afterDelay(crl::time ms)217 		[[nodiscard]] SpecificRequestBuilder &afterDelay(crl::time ms) noexcept {
218 			setCanWait(ms);
219 			return *this;
220 		}
221 
222 		using Result = typename Request::ResponseType;
done(FnMut<void (const Result & result,mtpRequestId requestId)> callback)223 		[[nodiscard]] SpecificRequestBuilder &done(
224 			FnMut<void(
225 				const Result &result,
226 				mtpRequestId requestId)> callback) {
227 			setDoneHandler(
228 				MakeDoneHandler<Result>(sender(), std::move(callback)));
229 			return *this;
230 		}
done(FnMut<void (const Result & result,const Response & response)> callback)231 		[[nodiscard]] SpecificRequestBuilder &done(
232 			FnMut<void(
233 				const Result &result,
234 				const Response &response)> callback) {
235 			setDoneHandler(
236 				MakeDoneHandler<Result>(sender(), std::move(callback)));
237 			return *this;
238 		}
done(FnMut<void ()> callback)239 		[[nodiscard]] SpecificRequestBuilder &done(
240 				FnMut<void()> callback) {
241 			setDoneHandler(
242 				MakeDoneHandler<Result>(sender(), std::move(callback)));
243 			return *this;
244 		}
done(FnMut<void (const typename Request::ResponseType & result)> callback)245 		[[nodiscard]] SpecificRequestBuilder &done(
246 			FnMut<void(
247 				const typename Request::ResponseType &result)> callback) {
248 			setDoneHandler(
249 				MakeDoneHandler<Result>(sender(), std::move(callback)));
250 			return *this;
251 		}
252 
fail(Fn<void (const Error & error,mtpRequestId requestId)> callback)253 		[[nodiscard]] SpecificRequestBuilder &fail(
254 			Fn<void(
255 				const Error &error,
256 				mtpRequestId requestId)> callback) noexcept {
257 			setFailHandler(std::move(callback));
258 			return *this;
259 		}
fail(Fn<void (const Error & error,const Response & response)> callback)260 		[[nodiscard]] SpecificRequestBuilder &fail(
261 			Fn<void(
262 				const Error &error,
263 				const Response &response)> callback) noexcept {
264 			setFailHandler(std::move(callback));
265 			return *this;
266 		}
fail(Fn<void ()> callback)267 		[[nodiscard]] SpecificRequestBuilder &fail(
268 				Fn<void()> callback) noexcept {
269 			setFailHandler(std::move(callback));
270 			return *this;
271 		}
fail(Fn<void (const Error & error)> callback)272 		[[nodiscard]] SpecificRequestBuilder &fail(
273 				Fn<void(const Error &error)> callback) noexcept {
274 			setFailHandler(std::move(callback));
275 			return *this;
276 		}
277 
handleFloodErrors()278 		[[nodiscard]] SpecificRequestBuilder &handleFloodErrors() noexcept {
279 			setFailSkipPolicy(FailSkipPolicy::HandleFlood);
280 			return *this;
281 		}
handleAllErrors()282 		[[nodiscard]] SpecificRequestBuilder &handleAllErrors() noexcept {
283 			setFailSkipPolicy(FailSkipPolicy::HandleAll);
284 			return *this;
285 		}
afterRequest(mtpRequestId requestId)286 		[[nodiscard]] SpecificRequestBuilder &afterRequest(mtpRequestId requestId) noexcept {
287 			setAfter(requestId);
288 			return *this;
289 		}
290 
send()291 		mtpRequestId send() {
292 			const auto id = sender()->_instance->send(
293 				_request,
294 				takeOnDone(),
295 				takeOnFail(),
296 				takeDcId(),
297 				takeCanWait(),
298 				takeAfter());
299 			registerRequest(id);
300 			return id;
301 		}
302 
303 	private:
304 		Request _request;
305 
306 	};
307 
308 	class SentRequestWrap {
309 	private:
310 		friend class Sender;
SentRequestWrap(not_null<Sender * > sender,mtpRequestId requestId)311 		SentRequestWrap(not_null<Sender*> sender, mtpRequestId requestId) : _sender(sender), _requestId(requestId) {
312 		}
313 
314 	public:
cancel()315 		void cancel() {
316 			if (_requestId) {
317 				_sender->senderRequestCancel(_requestId);
318 			}
319 		}
320 
321 	private:
322 		not_null<Sender*> _sender;
323 		mtpRequestId _requestId = 0;
324 
325 	};
326 
327 	template <
328 		typename Request,
329 		typename = std::enable_if_t<!std::is_reference_v<Request>>,
330 		typename = typename Request::Unboxed>
331 	[[nodiscard]] SpecificRequestBuilder<Request> request(Request &&request) noexcept;
332 
333 	[[nodiscard]] SentRequestWrap request(mtpRequestId requestId) noexcept;
334 
requestCanceller()335 	[[nodiscard]] auto requestCanceller() noexcept {
336 		return [this](mtpRequestId requestId) {
337 			request(requestId).cancel();
338 		};
339 	}
340 
requestSendDelayed()341 	void requestSendDelayed() {
342 		_instance->sendAnything();
343 	}
requestCancellingDiscard()344 	void requestCancellingDiscard() {
345 		for (auto &request : base::take(_requests)) {
346 			request.handled();
347 		}
348 	}
349 
350 private:
351 	class RequestWrap {
352 	public:
RequestWrap(not_null<Instance * > instance,mtpRequestId requestId)353 		RequestWrap(
354 			not_null<Instance*> instance,
355 			mtpRequestId requestId) noexcept
356 		: _instance(instance)
357 		, _id(requestId) {
358 		}
359 
360 		RequestWrap(const RequestWrap &other) = delete;
361 		RequestWrap &operator=(const RequestWrap &other) = delete;
RequestWrap(RequestWrap && other)362 		RequestWrap(RequestWrap &&other)
363 		: _instance(other._instance)
364 		, _id(base::take(other._id)) {
365 		}
366 		RequestWrap &operator=(RequestWrap &&other) {
367 			Expects(_instance == other._instance);
368 
369 			if (_id != other._id) {
370 				cancelRequest();
371 				_id = base::take(other._id);
372 			}
373 			return *this;
374 		}
375 
id()376 		mtpRequestId id() const noexcept {
377 			return _id;
378 		}
handled()379 		void handled() const noexcept {
380 			_id = 0;
381 		}
382 
~RequestWrap()383 		~RequestWrap() {
384 			cancelRequest();
385 		}
386 
387 	private:
cancelRequest()388 		void cancelRequest() {
389 			if (_id) {
390 				_instance->cancel(_id);
391 			}
392 		}
393 		const not_null<Instance*> _instance;
394 		mutable mtpRequestId _id = 0;
395 
396 	};
397 
398 	struct RequestWrapComparator {
399 		using is_transparent = std::true_type;
400 
401 		struct helper {
402 			mtpRequestId requestId = 0;
403 
404 			helper() = default;
405 			helper(const helper &other) = default;
helperRequestWrapComparator::helper406 			helper(mtpRequestId requestId) noexcept : requestId(requestId) {
407 			}
helperRequestWrapComparator::helper408 			helper(const RequestWrap &request) noexcept : requestId(request.id()) {
409 			}
410 			bool operator<(helper other) const {
411 				return requestId < other.requestId;
412 			}
413 		};
operatorRequestWrapComparator414 		bool operator()(const helper &&lhs, const helper &&rhs) const {
415 			return lhs < rhs;
416 		}
417 
418 	};
419 
420 	template <typename Request>
421 	friend class SpecificRequestBuilder;
422 	friend class RequestBuilder;
423 	friend class RequestWrap;
424 	friend class SentRequestWrap;
425 
senderRequestRegister(mtpRequestId requestId)426 	void senderRequestRegister(mtpRequestId requestId) {
427 		_requests.emplace(_instance, requestId);
428 	}
senderRequestHandled(mtpRequestId requestId)429 	void senderRequestHandled(mtpRequestId requestId) {
430 		auto it = _requests.find(requestId);
431 		if (it != _requests.cend()) {
432 			it->handled();
433 			_requests.erase(it);
434 		}
435 	}
senderRequestCancel(mtpRequestId requestId)436 	void senderRequestCancel(mtpRequestId requestId) {
437 		auto it = _requests.find(requestId);
438 		if (it != _requests.cend()) {
439 			_requests.erase(it);
440 		}
441 	}
442 
443 	const not_null<Instance*> _instance;
444 	base::flat_set<RequestWrap, RequestWrapComparator> _requests;
445 
446 };
447 
448 template <typename Request, typename, typename>
request(Request && request)449 Sender::SpecificRequestBuilder<Request> Sender::request(Request &&request) noexcept {
450 	return SpecificRequestBuilder<Request>(this, std::move(request));
451 }
452 
request(mtpRequestId requestId)453 inline Sender::SentRequestWrap Sender::request(mtpRequestId requestId) noexcept {
454 	return SentRequestWrap(this, requestId);
455 }
456 
457 } // namespace MTP
458