1 /*
2 This file is part of Telegram Desktop,
3 the official desktop application for the Telegram messaging service.
4 
5 For license and copyright information please follow this link:
6 https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
7 */
8 #include "storage/download_manager_mtproto.h"
9 
10 #include "mtproto/facade.h"
11 #include "mtproto/mtproto_auth_key.h"
12 #include "mtproto/mtproto_response.h"
13 #include "main/main_session.h"
14 #include "apiwrap.h"
15 #include "base/openssl_help.h"
16 
17 namespace Storage {
18 namespace {
19 
20 constexpr auto kKillSessionTimeout = 15 * crl::time(1000);
21 constexpr auto kStartWaitedInSession = 4 * kDownloadPartSize;
22 constexpr auto kMaxWaitedInSession = 16 * kDownloadPartSize;
23 constexpr auto kStartSessionsCount = 1;
24 constexpr auto kMaxSessionsCount = 8;
25 constexpr auto kMaxTrackedSessionRemoves = 64;
26 constexpr auto kRetryAddSessionTimeout = 8 * crl::time(1000);
27 constexpr auto kRetryAddSessionSuccesses = 3;
28 constexpr auto kMaxTrackedSuccesses = kRetryAddSessionSuccesses
29 	* kMaxTrackedSessionRemoves;
30 constexpr auto kRemoveSessionAfterTimeouts = 4;
31 constexpr auto kResetDownloadPrioritiesTimeout = crl::time(200);
32 constexpr auto kBadRequestDurationThreshold = 8 * crl::time(1000);
33 
34 // Each (session remove by timeouts) we wait for time:
35 // kRetryAddSessionTimeout * max(removesCount, kMaxTrackedSessionRemoves)
36 // and for successes in all remaining sessions:
37 // kRetryAddSessionSuccesses * max(removesCount, kMaxTrackedSessionRemoves)
38 
39 } // namespace
40 
enqueue(not_null<Task * > task,int priority)41 void DownloadManagerMtproto::Queue::enqueue(
42 		not_null<Task*> task,
43 		int priority) {
44 	const auto position = ranges::find_if(_tasks, [&](const Enqueued &task) {
45 		return task.priority <= priority;
46 	}) - begin(_tasks);
47 	const auto now = ranges::find(_tasks, task, &Enqueued::task);
48 	const auto i = [&] {
49 		if (now != end(_tasks)) {
50 			(now->priority = priority);
51 			return now;
52 		}
53 		_tasks.push_back({ task, priority });
54 		return end(_tasks) - 1;
55 	}();
56 	const auto j = begin(_tasks) + position;
57 	if (j < i) {
58 		std::rotate(j, i, i + 1);
59 	} else if (j > i + 1) {
60 		std::rotate(i, i + 1, j);
61 	}
62 }
63 
remove(not_null<Task * > task)64 void DownloadManagerMtproto::Queue::remove(not_null<Task*> task) {
65 	_tasks.erase(ranges::remove(_tasks, task, &Enqueued::task), end(_tasks));
66 }
67 
resetGeneration()68 void DownloadManagerMtproto::Queue::resetGeneration() {
69 	const auto from = ranges::find(_tasks, 0, &Enqueued::priority);
70 	for (auto &task : ranges::make_subrange(from, end(_tasks))) {
71 		if (task.priority) {
72 			Assert(task.priority == -1);
73 			break;
74 		}
75 		task.priority = -1;
76 	}
77 }
78 
empty() const79 bool DownloadManagerMtproto::Queue::empty() const {
80 	return _tasks.empty();
81 }
82 
nextTask(bool onlyHighestPriority) const83 auto DownloadManagerMtproto::Queue::nextTask(bool onlyHighestPriority) const
84 -> Task* {
85 	if (_tasks.empty()) {
86 		return nullptr;
87 	}
88 	const auto highestPriority = _tasks.front().priority;
89 	const auto notHighestPriority = [&](const Enqueued &enqueued) {
90 		return (enqueued.priority != highestPriority);
91 	};
92 	const auto till = (onlyHighestPriority && highestPriority > 0)
93 		? ranges::find_if(_tasks, notHighestPriority)
94 		: end(_tasks);
95 	const auto readyToRequest = [&](const Enqueued &enqueued) {
96 		return enqueued.task->readyToRequest();
97 	};
98 	const auto first = ranges::find_if(
99 		ranges::make_subrange(begin(_tasks), till),
100 		readyToRequest);
101 	return (first != till) ? first->task.get() : nullptr;
102 }
103 
removeSession(int index)104 void DownloadManagerMtproto::Queue::removeSession(int index) {
105 	for (const auto &enqueued : _tasks) {
106 		enqueued.task->removeSession(index);
107 	}
108 }
109 
DcSessionBalanceData()110 DownloadManagerMtproto::DcSessionBalanceData::DcSessionBalanceData()
111 : maxWaitedAmount(kStartWaitedInSession) {
112 }
113 
DcBalanceData()114 DownloadManagerMtproto::DcBalanceData::DcBalanceData()
115 : sessions(kStartSessionsCount) {
116 }
117 
DownloadManagerMtproto(not_null<ApiWrap * > api)118 DownloadManagerMtproto::DownloadManagerMtproto(not_null<ApiWrap*> api)
119 : _api(api)
120 , _resetGenerationTimer([=] { resetGeneration(); })
__anon188a26980702null121 , _killSessionsTimer([=] { killSessions(); }) {
122 	_api->instance().restartsByTimeout(
__anon188a26980802(MTP::ShiftedDcId shiftedDcId) 123 	) | rpl::filter([](MTP::ShiftedDcId shiftedDcId) {
124 		return MTP::isDownloadDcId(shiftedDcId);
125 	}) | rpl::start_with_next([=](MTP::ShiftedDcId shiftedDcId) {
126 		sessionTimedOut(
127 			MTP::BareDcId(shiftedDcId),
128 			MTP::GetDcIdShift(shiftedDcId));
129 	}, _lifetime);
130 }
131 
~DownloadManagerMtproto()132 DownloadManagerMtproto::~DownloadManagerMtproto() {
133 	killSessions();
134 }
135 
enqueue(not_null<Task * > task,int priority)136 void DownloadManagerMtproto::enqueue(not_null<Task*> task, int priority) {
137 	const auto dcId = task->dcId();
138 	auto &queue = _queues[dcId];
139 	queue.enqueue(task, priority);
140 	if (!_resetGenerationTimer.isActive()) {
141 		_resetGenerationTimer.callOnce(kResetDownloadPrioritiesTimeout);
142 	}
143 	checkSendNext(dcId, queue);
144 }
145 
remove(not_null<Task * > task)146 void DownloadManagerMtproto::remove(not_null<Task*> task) {
147 	const auto dcId = task->dcId();
148 	auto &queue = _queues[dcId];
149 	queue.remove(task);
150 	checkSendNext(dcId, queue);
151 }
152 
resetGeneration()153 void DownloadManagerMtproto::resetGeneration() {
154 	_resetGenerationTimer.cancel();
155 	for (auto &[dcId, queue] : _queues) {
156 		queue.resetGeneration();
157 	}
158 }
159 
checkSendNext()160 void DownloadManagerMtproto::checkSendNext() {
161 	for (auto &[dcId, queue] : _queues) {
162 		if (queue.empty()) {
163 			continue;
164 		}
165 		checkSendNext(dcId, queue);
166 	}
167 }
168 
checkSendNext(MTP::DcId dcId,Queue & queue)169 void DownloadManagerMtproto::checkSendNext(MTP::DcId dcId, Queue &queue) {
170 	while (trySendNextPart(dcId, queue)) {
171 	}
172 }
173 
checkSendNextAfterSuccess(MTP::DcId dcId)174 void DownloadManagerMtproto::checkSendNextAfterSuccess(MTP::DcId dcId) {
175 	checkSendNext(dcId, _queues[dcId]);
176 }
177 
trySendNextPart(MTP::DcId dcId,Queue & queue)178 bool DownloadManagerMtproto::trySendNextPart(MTP::DcId dcId, Queue &queue) {
179 	auto &balanceData = _balanceData[dcId];
180 	const auto &sessions = balanceData.sessions;
181 	const auto bestIndex = [&] {
182 		const auto proj = [](const DcSessionBalanceData &data) {
183 			return (data.requested < data.maxWaitedAmount)
184 				? data.requested
185 				: kMaxWaitedInSession;
186 		};
187 		const auto j = ranges::min_element(sessions, ranges::less(), proj);
188 		return (j->requested + kDownloadPartSize <= j->maxWaitedAmount)
189 			? (j - begin(sessions))
190 			: -1;
191 	}();
192 	if (bestIndex < 0) {
193 		return false;
194 	}
195 	const auto onlyHighestPriority = (balanceData.totalRequested > 0);
196 	if (const auto task = queue.nextTask(onlyHighestPriority)) {
197 		task->loadPart(bestIndex);
198 		return true;
199 	}
200 	return false;
201 }
202 
changeRequestedAmount(MTP::DcId dcId,int index,int delta)203 int DownloadManagerMtproto::changeRequestedAmount(
204 		MTP::DcId dcId,
205 		int index,
206 		int delta) {
207 	const auto i = _balanceData.find(dcId);
208 	Assert(i != _balanceData.end());
209 	Assert(index < i->second.sessions.size());
210 	const auto result = (i->second.sessions[index].requested += delta);
211 	i->second.totalRequested += delta;
212 	const auto findNonEmptySession = [](const DcBalanceData &data) {
213 		using namespace rpl::mappers;
214 		return ranges::find_if(
215 			data.sessions,
216 			_1 > 0,
217 			&DcSessionBalanceData::requested);
218 	};
219 	if (delta > 0) {
220 		killSessionsCancel(dcId);
221 	} else if (findNonEmptySession(i->second) == end(i->second.sessions)) {
222 		killSessionsSchedule(dcId);
223 	}
224 	return result;
225 }
226 
requestSucceeded(MTP::DcId dcId,int index,int amountAtRequestStart,crl::time timeAtRequestStart)227 void DownloadManagerMtproto::requestSucceeded(
228 		MTP::DcId dcId,
229 		int index,
230 		int amountAtRequestStart,
231 		crl::time timeAtRequestStart) {
232 	using namespace rpl::mappers;
233 
234 	const auto i = _balanceData.find(dcId);
235 	Assert(i != end(_balanceData));
236 	auto &dc = i->second;
237 	Assert(index < dc.sessions.size());
238 	auto &data = dc.sessions[index];
239 	const auto overloaded = (timeAtRequestStart <= dc.lastSessionRemove)
240 		|| (amountAtRequestStart > data.maxWaitedAmount);
241 	const auto parts = amountAtRequestStart / kDownloadPartSize;
242 	const auto duration = (crl::now() - timeAtRequestStart);
243 	DEBUG_LOG(("Download (%1,%2) request done, duration: %3, parts: %4%5"
244 		).arg(dcId
245 		).arg(index
246 		).arg(duration
247 		).arg(parts
248 		).arg(overloaded ? " (overloaded)" : ""));
249 	if (overloaded) {
250 		return;
251 	}
252 
253 	if (duration >= kBadRequestDurationThreshold) {
254 		DEBUG_LOG(("Duration too large, signaling time out."));
255 		crl::on_main(this, [=] {
256 			sessionTimedOut(dcId, index);
257 		});
258 		return;
259 	}
260 	if (amountAtRequestStart == data.maxWaitedAmount
261 		&& data.maxWaitedAmount < kMaxWaitedInSession) {
262 		data.maxWaitedAmount = std::min(
263 			data.maxWaitedAmount + kDownloadPartSize,
264 			kMaxWaitedInSession);
265 		DEBUG_LOG(("Download (%1,%2) increased max waited amount %3."
266 			).arg(dcId
267 			).arg(index
268 			).arg(data.maxWaitedAmount));
269 	}
270 	data.successes = std::min(data.successes + 1, kMaxTrackedSuccesses);
271 	const auto notEnough = ranges::any_of(
272 		dc.sessions,
273 		_1 < (dc.sessionRemoveTimes + 1) * kRetryAddSessionSuccesses,
274 		&DcSessionBalanceData::successes);
275 	if (notEnough) {
276 		return;
277 	}
278 	for (auto &session : dc.sessions) {
279 		session.successes = 0;
280 	}
281 	if (dc.timeouts > 0) {
282 		--dc.timeouts;
283 		return;
284 	} else if (dc.sessions.size() == kMaxSessionsCount) {
285 		return;
286 	}
287 	const auto now = crl::now();
288 	const auto delay = (dc.sessionRemoveTimes + 1) * kRetryAddSessionTimeout;
289 	if (dc.lastSessionRemove && now < dc.lastSessionRemove + delay) {
290 		return;
291 	}
292 	dc.sessions.emplace_back();
293 	DEBUG_LOG(("Download (%1,%2) adding, now sessions: %3"
294 		).arg(dcId
295 		).arg(dc.sessions.size() - 1
296 		).arg(dc.sessions.size()));
297 }
298 
chooseSessionIndex(MTP::DcId dcId) const299 int DownloadManagerMtproto::chooseSessionIndex(MTP::DcId dcId) const {
300 	const auto i = _balanceData.find(dcId);
301 	Assert(i != end(_balanceData));
302 	const auto &sessions = i->second.sessions;
303 	const auto j = ranges::min_element(
304 		sessions,
305 		ranges::less(),
306 		&DcSessionBalanceData::requested);
307 	return (j - begin(sessions));
308 }
309 
sessionTimedOut(MTP::DcId dcId,int index)310 void DownloadManagerMtproto::sessionTimedOut(MTP::DcId dcId, int index) {
311 	const auto i = _balanceData.find(dcId);
312 	if (i == end(_balanceData)) {
313 		return;
314 	}
315 	auto &dc = i->second;
316 	if (index >= dc.sessions.size()) {
317 		return;
318 	}
319 	DEBUG_LOG(("Download (%1,%2) session timed-out.").arg(dcId).arg(index));
320 	for (auto &session : dc.sessions) {
321 		session.successes = 0;
322 	}
323 	if (dc.sessions.size() == kStartSessionsCount
324 		|| ++dc.timeouts < kRemoveSessionAfterTimeouts) {
325 		return;
326 	}
327 	dc.timeouts = 0;
328 	removeSession(dcId);
329 }
330 
removeSession(MTP::DcId dcId)331 void DownloadManagerMtproto::removeSession(MTP::DcId dcId) {
332 	auto &dc = _balanceData[dcId];
333 	Assert(dc.sessions.size() > kStartSessionsCount);
334 	const auto index = int(dc.sessions.size() - 1);
335 	DEBUG_LOG(("Download (%1,%2) removing, now sessions: %3"
336 		).arg(dcId
337 		).arg(index
338 		).arg(index));
339 	auto &queue = _queues[dcId];
340 	if (dc.sessionRemoveIndex == index) {
341 		dc.sessionRemoveTimes = std::min(
342 			dc.sessionRemoveTimes + 1,
343 			kMaxTrackedSessionRemoves);
344 	} else {
345 		dc.sessionRemoveIndex = index;
346 		dc.sessionRemoveTimes = 1;
347 	}
348 	auto &session = dc.sessions.back();
349 
350 	// Make sure we don't send anything to that session while redirecting.
351 	session.requested += kMaxWaitedInSession * kMaxSessionsCount;
352 	queue.removeSession(index);
353 	Assert(session.requested == kMaxWaitedInSession * kMaxSessionsCount);
354 
355 	dc.sessions.pop_back();
356 	api().instance().killSession(MTP::downloadDcId(dcId, index));
357 
358 	dc.lastSessionRemove = crl::now();
359 }
360 
killSessionsSchedule(MTP::DcId dcId)361 void DownloadManagerMtproto::killSessionsSchedule(MTP::DcId dcId) {
362 	if (!_killSessionsWhen.contains(dcId)) {
363 		_killSessionsWhen.emplace(dcId, crl::now() + kKillSessionTimeout);
364 	}
365 	if (!_killSessionsTimer.isActive()) {
366 		_killSessionsTimer.callOnce(kKillSessionTimeout + 5);
367 	}
368 }
369 
killSessionsCancel(MTP::DcId dcId)370 void DownloadManagerMtproto::killSessionsCancel(MTP::DcId dcId) {
371 	_killSessionsWhen.erase(dcId);
372 	if (_killSessionsWhen.empty()) {
373 		_killSessionsTimer.cancel();
374 	}
375 }
376 
killSessions()377 void DownloadManagerMtproto::killSessions() {
378 	const auto now = crl::now();
379 	auto left = kKillSessionTimeout;
380 	for (auto i = begin(_killSessionsWhen); i != end(_killSessionsWhen); ) {
381 		if (i->second <= now) {
382 			killSessions(i->first);
383 			i = _killSessionsWhen.erase(i);
384 		} else {
385 			if (i->second - now < left) {
386 				left = i->second - now;
387 			}
388 			++i;
389 		}
390 	}
391 	if (!_killSessionsWhen.empty()) {
392 		_killSessionsTimer.callOnce(left);
393 	}
394 }
395 
killSessions(MTP::DcId dcId)396 void DownloadManagerMtproto::killSessions(MTP::DcId dcId) {
397 	const auto i = _balanceData.find(dcId);
398 	if (i != end(_balanceData)) {
399 		auto &dc = i->second;
400 		Assert(dc.totalRequested == 0);
401 		auto sessions = base::take(dc.sessions);
402 		dc = DcBalanceData();
403 		for (auto j = 0; j != int(sessions.size()); ++j) {
404 			Assert(sessions[j].requested == 0);
405 			sessions[j] = DcSessionBalanceData();
406 			api().instance().stopSession(MTP::downloadDcId(dcId, j));
407 		}
408 		dc.sessions = base::take(sessions);
409 	}
410 }
411 
DownloadMtprotoTask(not_null<DownloadManagerMtproto * > owner,const StorageFileLocation & location,Data::FileOrigin origin)412 DownloadMtprotoTask::DownloadMtprotoTask(
413 	not_null<DownloadManagerMtproto*> owner,
414 	const StorageFileLocation &location,
415 	Data::FileOrigin origin)
416 : _owner(owner)
417 , _dcId(location.dcId())
418 , _location({ location })
419 , _origin(origin) {
420 }
421 
DownloadMtprotoTask(not_null<DownloadManagerMtproto * > owner,MTP::DcId dcId,const Location & location)422 DownloadMtprotoTask::DownloadMtprotoTask(
423 	not_null<DownloadManagerMtproto*> owner,
424 	MTP::DcId dcId,
425 	const Location &location)
426 : _owner(owner)
427 , _dcId(dcId)
428 , _location(location) {
429 }
430 
~DownloadMtprotoTask()431 DownloadMtprotoTask::~DownloadMtprotoTask() {
432 	cancelAllRequests();
433 	_owner->remove(this);
434 }
435 
dcId() const436 MTP::DcId DownloadMtprotoTask::dcId() const {
437 	return _dcId;
438 }
439 
fileOrigin() const440 Data::FileOrigin DownloadMtprotoTask::fileOrigin() const {
441 	return _origin;
442 }
443 
objectId() const444 uint64 DownloadMtprotoTask::objectId() const {
445 	if (const auto v = std::get_if<StorageFileLocation>(&_location.data)) {
446 		return v->objectId();
447 	}
448 	return 0;
449 }
450 
location() const451 const DownloadMtprotoTask::Location &DownloadMtprotoTask::location() const {
452 	return _location;
453 }
454 
refreshFileReferenceFrom(const Data::UpdatedFileReferences & updates,int requestId,const QByteArray & current)455 void DownloadMtprotoTask::refreshFileReferenceFrom(
456 		const Data::UpdatedFileReferences &updates,
457 		int requestId,
458 		const QByteArray &current) {
459 	if (const auto v = std::get_if<StorageFileLocation>(&_location.data)) {
460 		v->refreshFileReference(updates);
461 		if (v->fileReference() == current) {
462 			cancelOnFail();
463 			return;
464 		}
465 	} else {
466 		cancelOnFail();
467 		return;
468 	}
469 	if (_sentRequests.contains(requestId)) {
470 		makeRequest(finishSentRequest(
471 			requestId,
472 			FinishRequestReason::Redirect));
473 	}
474 }
475 
loadPart(int sessionIndex)476 void DownloadMtprotoTask::loadPart(int sessionIndex) {
477 	makeRequest({ takeNextRequestOffset(), sessionIndex });
478 }
479 
removeSession(int sessionIndex)480 void DownloadMtprotoTask::removeSession(int sessionIndex) {
481 	struct Redirect {
482 		mtpRequestId requestId = 0;
483 		int offset = 0;
484 	};
485 	auto redirect = std::vector<Redirect>();
486 	for (const auto &[requestId, requestData] : _sentRequests) {
487 		if (requestData.sessionIndex == sessionIndex) {
488 			redirect.reserve(_sentRequests.size());
489 			redirect.push_back({ requestId, requestData.offset });
490 		}
491 	}
492 	for (auto &[requestData, bytes] : _cdnUncheckedParts) {
493 		if (requestData.sessionIndex == sessionIndex) {
494 			const auto newIndex = _owner->chooseSessionIndex(dcId());
495 			Assert(newIndex < sessionIndex);
496 			requestData.sessionIndex = newIndex;
497 		}
498 	}
499 	for (const auto &[requestId, offset] : redirect) {
500 		const auto needMakeRequest = (requestId != _cdnHashesRequestId);
501 		cancelRequest(requestId);
502 		if (needMakeRequest) {
503 			const auto newIndex = _owner->chooseSessionIndex(dcId());
504 			Assert(newIndex < sessionIndex);
505 			makeRequest({ offset, newIndex });
506 		}
507 	}
508 }
509 
sendRequest(const RequestData & requestData)510 mtpRequestId DownloadMtprotoTask::sendRequest(
511 		const RequestData &requestData) {
512 	const auto offset = requestData.offset;
513 	const auto limit = Storage::kDownloadPartSize;
514 	const auto shiftedDcId = MTP::downloadDcId(
515 		_cdnDcId ? _cdnDcId : dcId(),
516 		requestData.sessionIndex);
517 	if (_cdnDcId) {
518 		return api().request(MTPupload_GetCdnFile(
519 			MTP_bytes(_cdnToken),
520 			MTP_int(offset),
521 			MTP_int(limit)
522 		)).done([=](const MTPupload_CdnFile &result, mtpRequestId id) {
523 			cdnPartLoaded(result, id);
524 		}).fail([=](const MTP::Error &error, mtpRequestId id) {
525 			cdnPartFailed(error, id);
526 		}).toDC(shiftedDcId).send();
527 	}
528 	return v::match(_location.data, [&](const WebFileLocation &location) {
529 		return api().request(MTPupload_GetWebFile(
530 			MTP_inputWebFileLocation(
531 				MTP_bytes(location.url()),
532 				MTP_long(location.accessHash())),
533 			MTP_int(offset),
534 			MTP_int(limit)
535 		)).done([=](const MTPupload_WebFile &result, mtpRequestId id) {
536 			webPartLoaded(result, id);
537 		}).fail([=](const MTP::Error &error, mtpRequestId id) {
538 			partFailed(error, id);
539 		}).toDC(shiftedDcId).send();
540 	}, [&](const GeoPointLocation &location) {
541 		return api().request(MTPupload_GetWebFile(
542 			MTP_inputWebFileGeoPointLocation(
543 				MTP_inputGeoPoint(
544 					MTP_flags(0),
545 					MTP_double(location.lat),
546 					MTP_double(location.lon),
547 					MTP_int(0)), // accuracy_radius
548 				MTP_long(location.access),
549 				MTP_int(location.width),
550 				MTP_int(location.height),
551 				MTP_int(location.zoom),
552 				MTP_int(location.scale)),
553 			MTP_int(offset),
554 			MTP_int(limit)
555 		)).done([=](const MTPupload_WebFile &result, mtpRequestId id) {
556 			webPartLoaded(result, id);
557 		}).fail([=](const MTP::Error &error, mtpRequestId id) {
558 			partFailed(error, id);
559 		}).toDC(shiftedDcId).send();
560 	}, [&](const StorageFileLocation &location) {
561 		const auto reference = location.fileReference();
562 		return api().request(MTPupload_GetFile(
563 			MTP_flags(MTPupload_GetFile::Flag::f_cdn_supported),
564 			location.tl(api().session().userId()),
565 			MTP_int(offset),
566 			MTP_int(limit)
567 		)).done([=](const MTPupload_File &result, mtpRequestId id) {
568 			normalPartLoaded(result, id);
569 		}).fail([=](const MTP::Error &error, mtpRequestId id) {
570 			normalPartFailed(reference, error, id);
571 		}).toDC(shiftedDcId).send();
572 	});
573 }
574 
setWebFileSizeHook(int size)575 bool DownloadMtprotoTask::setWebFileSizeHook(int size) {
576 	return true;
577 }
578 
makeRequest(const RequestData & requestData)579 void DownloadMtprotoTask::makeRequest(const RequestData &requestData) {
580 	placeSentRequest(sendRequest(requestData), requestData);
581 }
582 
requestMoreCdnFileHashes()583 void DownloadMtprotoTask::requestMoreCdnFileHashes() {
584 	if (_cdnHashesRequestId || _cdnUncheckedParts.empty()) {
585 		return;
586 	}
587 
588 	const auto requestData = _cdnUncheckedParts.cbegin()->first;
589 	const auto shiftedDcId = MTP::downloadDcId(
590 		dcId(),
591 		requestData.sessionIndex);
592 	_cdnHashesRequestId = api().request(MTPupload_GetCdnFileHashes(
593 		MTP_bytes(_cdnToken),
594 		MTP_int(requestData.offset)
595 	)).done([=](const MTPVector<MTPFileHash> &result, mtpRequestId id) {
596 		getCdnFileHashesDone(result, id);
597 	}).fail([=](const MTP::Error &error, mtpRequestId id) {
598 		cdnPartFailed(error, id);
599 	}).toDC(shiftedDcId).send();
600 	placeSentRequest(_cdnHashesRequestId, requestData);
601 }
602 
normalPartLoaded(const MTPupload_File & result,mtpRequestId requestId)603 void DownloadMtprotoTask::normalPartLoaded(
604 		const MTPupload_File &result,
605 		mtpRequestId requestId) {
606 	const auto requestData = finishSentRequest(
607 		requestId,
608 		FinishRequestReason::Success);
609 	const auto owner = _owner;
610 	const auto dcId = this->dcId();
611 	result.match([&](const MTPDupload_fileCdnRedirect &data) {
612 		switchToCDN(requestData, data);
613 	}, [&](const MTPDupload_file &data) {
614 		partLoaded(requestData.offset, data.vbytes().v);
615 	});
616 
617 	// 'this' may be deleted at this point.
618 	owner->checkSendNextAfterSuccess(dcId);
619 }
620 
webPartLoaded(const MTPupload_WebFile & result,mtpRequestId requestId)621 void DownloadMtprotoTask::webPartLoaded(
622 		const MTPupload_WebFile &result,
623 		mtpRequestId requestId) {
624 	const auto requestData = finishSentRequest(
625 		requestId,
626 		FinishRequestReason::Success);
627 	const auto owner = _owner;
628 	const auto dcId = this->dcId();
629 	result.match([&](const MTPDupload_webFile &data) {
630 		if (setWebFileSizeHook(data.vsize().v)) {
631 			partLoaded(requestData.offset, data.vbytes().v);
632 		}
633 	});
634 
635 	// 'this' may be deleted at this point.
636 	owner->checkSendNextAfterSuccess(dcId);
637 }
638 
cdnPartLoaded(const MTPupload_CdnFile & result,mtpRequestId requestId)639 void DownloadMtprotoTask::cdnPartLoaded(const MTPupload_CdnFile &result, mtpRequestId requestId) {
640 	result.match([&](const MTPDupload_cdnFileReuploadNeeded &data) {
641 		const auto requestData = finishSentRequest(
642 			requestId,
643 			FinishRequestReason::Redirect);
644 		const auto shiftedDcId = MTP::downloadDcId(
645 			dcId(),
646 			requestData.sessionIndex);
647 		const auto requestId = api().request(MTPupload_ReuploadCdnFile(
648 			MTP_bytes(_cdnToken),
649 			data.vrequest_token()
650 		)).done([=](const MTPVector<MTPFileHash> &result, mtpRequestId id) {
651 			reuploadDone(result, id);
652 		}).fail([=](const MTP::Error &error, mtpRequestId id) {
653 			cdnPartFailed(error, id);
654 		}).toDC(shiftedDcId).send();
655 		placeSentRequest(requestId, requestData);
656 	}, [&](const MTPDupload_cdnFile &data) {
657 		const auto requestData = finishSentRequest(
658 			requestId,
659 			FinishRequestReason::Success);
660 		const auto owner = _owner;
661 		const auto dcId = this->dcId();
662 		const auto guard = gsl::finally([=] {
663 			// 'this' may be deleted at this point.
664 			owner->checkSendNextAfterSuccess(dcId);
665 		});
666 
667 		auto key = bytes::make_span(_cdnEncryptionKey);
668 		auto iv = bytes::make_span(_cdnEncryptionIV);
669 		Expects(key.size() == MTP::CTRState::KeySize);
670 		Expects(iv.size() == MTP::CTRState::IvecSize);
671 
672 		auto state = MTP::CTRState();
673 		auto ivec = bytes::make_span(state.ivec);
674 		std::copy(iv.begin(), iv.end(), ivec.begin());
675 
676 		auto counterOffset = static_cast<uint32>(requestData.offset) >> 4;
677 		state.ivec[15] = static_cast<uchar>(counterOffset & 0xFF);
678 		state.ivec[14] = static_cast<uchar>((counterOffset >> 8) & 0xFF);
679 		state.ivec[13] = static_cast<uchar>((counterOffset >> 16) & 0xFF);
680 		state.ivec[12] = static_cast<uchar>((counterOffset >> 24) & 0xFF);
681 
682 		auto decryptInPlace = data.vbytes().v;
683 		auto buffer = bytes::make_detached_span(decryptInPlace);
684 		MTP::aesCtrEncrypt(buffer, key.data(), &state);
685 
686 		switch (checkCdnFileHash(requestData.offset, buffer)) {
687 		case CheckCdnHashResult::NoHash: {
688 			_cdnUncheckedParts.emplace(requestData, decryptInPlace);
689 			requestMoreCdnFileHashes();
690 		} return;
691 
692 		case CheckCdnHashResult::Invalid: {
693 			LOG(("API Error: Wrong cdnFileHash for offset %1."
694 				).arg(requestData.offset));
695 			cancelOnFail();
696 		} return;
697 
698 		case CheckCdnHashResult::Good: {
699 			partLoaded(requestData.offset, decryptInPlace);
700 		} return;
701 		}
702 		Unexpected("Result of checkCdnFileHash()");
703 	});
704 }
705 
checkCdnFileHash(int offset,bytes::const_span buffer)706 DownloadMtprotoTask::CheckCdnHashResult DownloadMtprotoTask::checkCdnFileHash(
707 		int offset,
708 		bytes::const_span buffer) {
709 	const auto cdnFileHashIt = _cdnFileHashes.find(offset);
710 	if (cdnFileHashIt == _cdnFileHashes.cend()) {
711 		return CheckCdnHashResult::NoHash;
712 	}
713 	const auto realHash = openssl::Sha256(buffer);
714 	const auto receivedHash = bytes::make_span(cdnFileHashIt->second.hash);
715 	if (bytes::compare(realHash, receivedHash)) {
716 		return CheckCdnHashResult::Invalid;
717 	}
718 	return CheckCdnHashResult::Good;
719 }
720 
reuploadDone(const MTPVector<MTPFileHash> & result,mtpRequestId requestId)721 void DownloadMtprotoTask::reuploadDone(
722 		const MTPVector<MTPFileHash> &result,
723 		mtpRequestId requestId) {
724 	const auto requestData = finishSentRequest(
725 		requestId,
726 		FinishRequestReason::Redirect);
727 	addCdnHashes(result.v);
728 	makeRequest(requestData);
729 }
730 
getCdnFileHashesDone(const MTPVector<MTPFileHash> & result,mtpRequestId requestId)731 void DownloadMtprotoTask::getCdnFileHashesDone(
732 		const MTPVector<MTPFileHash> &result,
733 		mtpRequestId requestId) {
734 	Expects(_cdnHashesRequestId == requestId);
735 
736 	const auto requestData = finishSentRequest(
737 		requestId,
738 		FinishRequestReason::Redirect);
739 	addCdnHashes(result.v);
740 	auto someMoreChecked = false;
741 	for (auto i = _cdnUncheckedParts.begin(); i != _cdnUncheckedParts.cend();) {
742 		const auto uncheckedData = i->first;
743 		const auto uncheckedBytes = bytes::make_span(i->second);
744 
745 		switch (checkCdnFileHash(uncheckedData.offset, uncheckedBytes)) {
746 		case CheckCdnHashResult::NoHash: {
747 			++i;
748 		} break;
749 
750 		case CheckCdnHashResult::Invalid: {
751 			LOG(("API Error: Wrong cdnFileHash for offset %1."
752 				).arg(uncheckedData.offset));
753 			cancelOnFail();
754 			return;
755 		} break;
756 
757 		case CheckCdnHashResult::Good: {
758 			someMoreChecked = true;
759 			const auto goodOffset = uncheckedData.offset;
760 			const auto goodBytes = std::move(i->second);
761 			const auto weak = base::make_weak(this);
762 			i = _cdnUncheckedParts.erase(i);
763 			if (!feedPart(goodOffset, goodBytes) || !weak) {
764 				return;
765 			}
766 		} break;
767 
768 		default: Unexpected("Result of checkCdnFileHash()");
769 		}
770 	}
771 	if (!someMoreChecked) {
772 		LOG(("API Error: "
773 			"Could not find cdnFileHash for offset %1 "
774 			"after getCdnFileHashes request."
775 			).arg(requestData.offset));
776 		cancelOnFail();
777 		return;
778 	}
779 	requestMoreCdnFileHashes();
780 }
781 
placeSentRequest(mtpRequestId requestId,const RequestData & requestData)782 void DownloadMtprotoTask::placeSentRequest(
783 		mtpRequestId requestId,
784 		const RequestData &requestData) {
785 	const auto amount = _owner->changeRequestedAmount(
786 		dcId(),
787 		requestData.sessionIndex,
788 		Storage::kDownloadPartSize);
789 	const auto [i, ok1] = _sentRequests.emplace(requestId, requestData);
790 	const auto [j, ok2] = _requestByOffset.emplace(
791 		requestData.offset,
792 		requestId);
793 
794 	i->second.requestedInSession = amount;
795 	i->second.sent = crl::now();
796 
797 	Ensures(ok1 && ok2);
798 }
799 
finishSentRequest(mtpRequestId requestId,FinishRequestReason reason)800 auto DownloadMtprotoTask::finishSentRequest(
801 	mtpRequestId requestId,
802 	FinishRequestReason reason)
803 -> RequestData {
804 	auto it = _sentRequests.find(requestId);
805 	Assert(it != _sentRequests.cend());
806 
807 	if (_cdnHashesRequestId == requestId) {
808 		_cdnHashesRequestId = 0;
809 	}
810 	const auto result = it->second;
811 	_owner->changeRequestedAmount(
812 		dcId(),
813 		result.sessionIndex,
814 		-Storage::kDownloadPartSize);
815 	_sentRequests.erase(it);
816 	const auto ok = _requestByOffset.remove(result.offset);
817 
818 	if (reason == FinishRequestReason::Success) {
819 		_owner->requestSucceeded(
820 			dcId(),
821 			result.sessionIndex,
822 			result.requestedInSession,
823 			result.sent);
824 	}
825 
826 	Ensures(ok);
827 	return result;
828 }
829 
haveSentRequests() const830 bool DownloadMtprotoTask::haveSentRequests() const {
831 	return !_sentRequests.empty() || !_cdnUncheckedParts.empty();
832 }
833 
haveSentRequestForOffset(int offset) const834 bool DownloadMtprotoTask::haveSentRequestForOffset(int offset) const {
835 	return _requestByOffset.contains(offset)
836 		|| _cdnUncheckedParts.contains({ offset, 0 });
837 }
838 
cancelAllRequests()839 void DownloadMtprotoTask::cancelAllRequests() {
840 	while (!_sentRequests.empty()) {
841 		cancelRequest(_sentRequests.begin()->first);
842 	}
843 	_cdnUncheckedParts.clear();
844 }
845 
cancelRequestForOffset(int offset)846 void DownloadMtprotoTask::cancelRequestForOffset(int offset) {
847 	const auto i = _requestByOffset.find(offset);
848 	if (i != end(_requestByOffset)) {
849 		cancelRequest(i->second);
850 	}
851 	_cdnUncheckedParts.remove({ offset, 0 });
852 }
853 
cancelRequest(mtpRequestId requestId)854 void DownloadMtprotoTask::cancelRequest(mtpRequestId requestId) {
855 	const auto hashes = (_cdnHashesRequestId == requestId);
856 	api().request(requestId).cancel();
857 	[[maybe_unused]] const auto data = finishSentRequest(
858 		requestId,
859 		FinishRequestReason::Cancel);
860 	if (hashes && !_cdnUncheckedParts.empty()) {
861 		crl::on_main(this, [=] {
862 			requestMoreCdnFileHashes();
863 		});
864 	}
865 }
866 
addToQueue(int priority)867 void DownloadMtprotoTask::addToQueue(int priority) {
868 	_owner->enqueue(this, priority);
869 }
870 
removeFromQueue()871 void DownloadMtprotoTask::removeFromQueue() {
872 	_owner->remove(this);
873 }
874 
partLoaded(int offset,const QByteArray & bytes)875 void DownloadMtprotoTask::partLoaded(
876 		int offset,
877 		const QByteArray &bytes) {
878 	feedPart(offset, bytes);
879 }
880 
normalPartFailed(QByteArray fileReference,const MTP::Error & error,mtpRequestId requestId)881 bool DownloadMtprotoTask::normalPartFailed(
882 		QByteArray fileReference,
883 		const MTP::Error &error,
884 		mtpRequestId requestId) {
885 	if (MTP::IsDefaultHandledError(error)) {
886 		return false;
887 	}
888 	if (error.code() == 400
889 		&& error.type().startsWith(qstr("FILE_REFERENCE_"))) {
890 		api().refreshFileReference(
891 			_origin,
892 			this,
893 			requestId,
894 			fileReference);
895 		return true;
896 	}
897 	return partFailed(error, requestId);
898 }
899 
partFailed(const MTP::Error & error,mtpRequestId requestId)900 bool DownloadMtprotoTask::partFailed(
901 		const MTP::Error &error,
902 		mtpRequestId requestId) {
903 	if (MTP::IsDefaultHandledError(error)) {
904 		return false;
905 	}
906 	cancelOnFail();
907 	return true;
908 }
909 
cdnPartFailed(const MTP::Error & error,mtpRequestId requestId)910 bool DownloadMtprotoTask::cdnPartFailed(
911 		const MTP::Error &error,
912 		mtpRequestId requestId) {
913 	if (MTP::IsDefaultHandledError(error)) {
914 		return false;
915 	}
916 
917 	if (error.type() == qstr("FILE_TOKEN_INVALID")
918 		|| error.type() == qstr("REQUEST_TOKEN_INVALID")) {
919 		const auto requestData = finishSentRequest(
920 			requestId,
921 			FinishRequestReason::Redirect);
922 		changeCDNParams(
923 			requestData,
924 			0,
925 			QByteArray(),
926 			QByteArray(),
927 			QByteArray(),
928 			QVector<MTPFileHash>());
929 		return true;
930 	}
931 	return partFailed(error, requestId);
932 }
933 
switchToCDN(const RequestData & requestData,const MTPDupload_fileCdnRedirect & redirect)934 void DownloadMtprotoTask::switchToCDN(
935 		const RequestData &requestData,
936 		const MTPDupload_fileCdnRedirect &redirect) {
937 	changeCDNParams(
938 		requestData,
939 		redirect.vdc_id().v,
940 		redirect.vfile_token().v,
941 		redirect.vencryption_key().v,
942 		redirect.vencryption_iv().v,
943 		redirect.vfile_hashes().v);
944 }
945 
addCdnHashes(const QVector<MTPFileHash> & hashes)946 void DownloadMtprotoTask::addCdnHashes(
947 		const QVector<MTPFileHash> &hashes) {
948 	for (const auto &hash : hashes) {
949 		hash.match([&](const MTPDfileHash &data) {
950 			_cdnFileHashes.emplace(
951 				data.voffset().v,
952 				CdnFileHash{ data.vlimit().v, data.vhash().v });
953 		});
954 	}
955 }
956 
changeCDNParams(const RequestData & requestData,MTP::DcId dcId,const QByteArray & token,const QByteArray & encryptionKey,const QByteArray & encryptionIV,const QVector<MTPFileHash> & hashes)957 void DownloadMtprotoTask::changeCDNParams(
958 		const RequestData &requestData,
959 		MTP::DcId dcId,
960 		const QByteArray &token,
961 		const QByteArray &encryptionKey,
962 		const QByteArray &encryptionIV,
963 		const QVector<MTPFileHash> &hashes) {
964 	if (dcId != 0
965 		&& (encryptionKey.size() != MTP::CTRState::KeySize
966 			|| encryptionIV.size() != MTP::CTRState::IvecSize)) {
967 		LOG(("Message Error: Wrong key (%1) / iv (%2) size in CDN params"
968 			).arg(encryptionKey.size()
969 			).arg(encryptionIV.size()));
970 		cancelOnFail();
971 		return;
972 	}
973 
974 	auto resendAllRequests = (_cdnDcId != dcId
975 		|| _cdnToken != token
976 		|| _cdnEncryptionKey != encryptionKey
977 		|| _cdnEncryptionIV != encryptionIV);
978 	_cdnDcId = dcId;
979 	_cdnToken = token;
980 	_cdnEncryptionKey = encryptionKey;
981 	_cdnEncryptionIV = encryptionIV;
982 	addCdnHashes(hashes);
983 
984 	if (resendAllRequests && !_sentRequests.empty()) {
985 		auto resendRequests = std::vector<RequestData>();
986 		resendRequests.reserve(_sentRequests.size());
987 		while (!_sentRequests.empty()) {
988 			const auto requestId = _sentRequests.begin()->first;
989 			api().request(requestId).cancel();
990 			resendRequests.push_back(finishSentRequest(
991 				requestId,
992 				FinishRequestReason::Redirect));
993 		}
994 		for (const auto &requestData : resendRequests) {
995 			makeRequest(requestData);
996 		}
997 	}
998 	makeRequest(requestData);
999 }
1000 
1001 } // namespace Storage
1002