1 /*
2 This file is part of Telegram Desktop,
3 the official desktop application for the Telegram messaging service.
4
5 For license and copyright information please follow this link:
6 https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
7 */
8 #include "storage/download_manager_mtproto.h"
9
10 #include "mtproto/facade.h"
11 #include "mtproto/mtproto_auth_key.h"
12 #include "mtproto/mtproto_response.h"
13 #include "main/main_session.h"
14 #include "apiwrap.h"
15 #include "base/openssl_help.h"
16
17 namespace Storage {
18 namespace {
19
20 constexpr auto kKillSessionTimeout = 15 * crl::time(1000);
21 constexpr auto kStartWaitedInSession = 4 * kDownloadPartSize;
22 constexpr auto kMaxWaitedInSession = 16 * kDownloadPartSize;
23 constexpr auto kStartSessionsCount = 1;
24 constexpr auto kMaxSessionsCount = 8;
25 constexpr auto kMaxTrackedSessionRemoves = 64;
26 constexpr auto kRetryAddSessionTimeout = 8 * crl::time(1000);
27 constexpr auto kRetryAddSessionSuccesses = 3;
28 constexpr auto kMaxTrackedSuccesses = kRetryAddSessionSuccesses
29 * kMaxTrackedSessionRemoves;
30 constexpr auto kRemoveSessionAfterTimeouts = 4;
31 constexpr auto kResetDownloadPrioritiesTimeout = crl::time(200);
32 constexpr auto kBadRequestDurationThreshold = 8 * crl::time(1000);
33
34 // Each (session remove by timeouts) we wait for time:
35 // kRetryAddSessionTimeout * max(removesCount, kMaxTrackedSessionRemoves)
36 // and for successes in all remaining sessions:
37 // kRetryAddSessionSuccesses * max(removesCount, kMaxTrackedSessionRemoves)
38
39 } // namespace
40
enqueue(not_null<Task * > task,int priority)41 void DownloadManagerMtproto::Queue::enqueue(
42 not_null<Task*> task,
43 int priority) {
44 const auto position = ranges::find_if(_tasks, [&](const Enqueued &task) {
45 return task.priority <= priority;
46 }) - begin(_tasks);
47 const auto now = ranges::find(_tasks, task, &Enqueued::task);
48 const auto i = [&] {
49 if (now != end(_tasks)) {
50 (now->priority = priority);
51 return now;
52 }
53 _tasks.push_back({ task, priority });
54 return end(_tasks) - 1;
55 }();
56 const auto j = begin(_tasks) + position;
57 if (j < i) {
58 std::rotate(j, i, i + 1);
59 } else if (j > i + 1) {
60 std::rotate(i, i + 1, j);
61 }
62 }
63
remove(not_null<Task * > task)64 void DownloadManagerMtproto::Queue::remove(not_null<Task*> task) {
65 _tasks.erase(ranges::remove(_tasks, task, &Enqueued::task), end(_tasks));
66 }
67
resetGeneration()68 void DownloadManagerMtproto::Queue::resetGeneration() {
69 const auto from = ranges::find(_tasks, 0, &Enqueued::priority);
70 for (auto &task : ranges::make_subrange(from, end(_tasks))) {
71 if (task.priority) {
72 Assert(task.priority == -1);
73 break;
74 }
75 task.priority = -1;
76 }
77 }
78
empty() const79 bool DownloadManagerMtproto::Queue::empty() const {
80 return _tasks.empty();
81 }
82
nextTask(bool onlyHighestPriority) const83 auto DownloadManagerMtproto::Queue::nextTask(bool onlyHighestPriority) const
84 -> Task* {
85 if (_tasks.empty()) {
86 return nullptr;
87 }
88 const auto highestPriority = _tasks.front().priority;
89 const auto notHighestPriority = [&](const Enqueued &enqueued) {
90 return (enqueued.priority != highestPriority);
91 };
92 const auto till = (onlyHighestPriority && highestPriority > 0)
93 ? ranges::find_if(_tasks, notHighestPriority)
94 : end(_tasks);
95 const auto readyToRequest = [&](const Enqueued &enqueued) {
96 return enqueued.task->readyToRequest();
97 };
98 const auto first = ranges::find_if(
99 ranges::make_subrange(begin(_tasks), till),
100 readyToRequest);
101 return (first != till) ? first->task.get() : nullptr;
102 }
103
removeSession(int index)104 void DownloadManagerMtproto::Queue::removeSession(int index) {
105 for (const auto &enqueued : _tasks) {
106 enqueued.task->removeSession(index);
107 }
108 }
109
DcSessionBalanceData()110 DownloadManagerMtproto::DcSessionBalanceData::DcSessionBalanceData()
111 : maxWaitedAmount(kStartWaitedInSession) {
112 }
113
DcBalanceData()114 DownloadManagerMtproto::DcBalanceData::DcBalanceData()
115 : sessions(kStartSessionsCount) {
116 }
117
DownloadManagerMtproto(not_null<ApiWrap * > api)118 DownloadManagerMtproto::DownloadManagerMtproto(not_null<ApiWrap*> api)
119 : _api(api)
120 , _resetGenerationTimer([=] { resetGeneration(); })
__anon188a26980702null121 , _killSessionsTimer([=] { killSessions(); }) {
122 _api->instance().restartsByTimeout(
__anon188a26980802(MTP::ShiftedDcId shiftedDcId) 123 ) | rpl::filter([](MTP::ShiftedDcId shiftedDcId) {
124 return MTP::isDownloadDcId(shiftedDcId);
125 }) | rpl::start_with_next([=](MTP::ShiftedDcId shiftedDcId) {
126 sessionTimedOut(
127 MTP::BareDcId(shiftedDcId),
128 MTP::GetDcIdShift(shiftedDcId));
129 }, _lifetime);
130 }
131
~DownloadManagerMtproto()132 DownloadManagerMtproto::~DownloadManagerMtproto() {
133 killSessions();
134 }
135
enqueue(not_null<Task * > task,int priority)136 void DownloadManagerMtproto::enqueue(not_null<Task*> task, int priority) {
137 const auto dcId = task->dcId();
138 auto &queue = _queues[dcId];
139 queue.enqueue(task, priority);
140 if (!_resetGenerationTimer.isActive()) {
141 _resetGenerationTimer.callOnce(kResetDownloadPrioritiesTimeout);
142 }
143 checkSendNext(dcId, queue);
144 }
145
remove(not_null<Task * > task)146 void DownloadManagerMtproto::remove(not_null<Task*> task) {
147 const auto dcId = task->dcId();
148 auto &queue = _queues[dcId];
149 queue.remove(task);
150 checkSendNext(dcId, queue);
151 }
152
resetGeneration()153 void DownloadManagerMtproto::resetGeneration() {
154 _resetGenerationTimer.cancel();
155 for (auto &[dcId, queue] : _queues) {
156 queue.resetGeneration();
157 }
158 }
159
checkSendNext()160 void DownloadManagerMtproto::checkSendNext() {
161 for (auto &[dcId, queue] : _queues) {
162 if (queue.empty()) {
163 continue;
164 }
165 checkSendNext(dcId, queue);
166 }
167 }
168
checkSendNext(MTP::DcId dcId,Queue & queue)169 void DownloadManagerMtproto::checkSendNext(MTP::DcId dcId, Queue &queue) {
170 while (trySendNextPart(dcId, queue)) {
171 }
172 }
173
checkSendNextAfterSuccess(MTP::DcId dcId)174 void DownloadManagerMtproto::checkSendNextAfterSuccess(MTP::DcId dcId) {
175 checkSendNext(dcId, _queues[dcId]);
176 }
177
trySendNextPart(MTP::DcId dcId,Queue & queue)178 bool DownloadManagerMtproto::trySendNextPart(MTP::DcId dcId, Queue &queue) {
179 auto &balanceData = _balanceData[dcId];
180 const auto &sessions = balanceData.sessions;
181 const auto bestIndex = [&] {
182 const auto proj = [](const DcSessionBalanceData &data) {
183 return (data.requested < data.maxWaitedAmount)
184 ? data.requested
185 : kMaxWaitedInSession;
186 };
187 const auto j = ranges::min_element(sessions, ranges::less(), proj);
188 return (j->requested + kDownloadPartSize <= j->maxWaitedAmount)
189 ? (j - begin(sessions))
190 : -1;
191 }();
192 if (bestIndex < 0) {
193 return false;
194 }
195 const auto onlyHighestPriority = (balanceData.totalRequested > 0);
196 if (const auto task = queue.nextTask(onlyHighestPriority)) {
197 task->loadPart(bestIndex);
198 return true;
199 }
200 return false;
201 }
202
changeRequestedAmount(MTP::DcId dcId,int index,int delta)203 int DownloadManagerMtproto::changeRequestedAmount(
204 MTP::DcId dcId,
205 int index,
206 int delta) {
207 const auto i = _balanceData.find(dcId);
208 Assert(i != _balanceData.end());
209 Assert(index < i->second.sessions.size());
210 const auto result = (i->second.sessions[index].requested += delta);
211 i->second.totalRequested += delta;
212 const auto findNonEmptySession = [](const DcBalanceData &data) {
213 using namespace rpl::mappers;
214 return ranges::find_if(
215 data.sessions,
216 _1 > 0,
217 &DcSessionBalanceData::requested);
218 };
219 if (delta > 0) {
220 killSessionsCancel(dcId);
221 } else if (findNonEmptySession(i->second) == end(i->second.sessions)) {
222 killSessionsSchedule(dcId);
223 }
224 return result;
225 }
226
// Updates the load-balancing statistics of session `index` of DC `dcId`
// after one of its part requests finished successfully (in this file it is
// called by DownloadMtprotoTask::finishSentRequest() with the Success reason).
//
// amountAtRequestStart: bytes in flight in that session right after the
//   request was placed, the request itself included.
// timeAtRequestStart: crl::now() at the moment the request was placed.
//
// Depending on the outcome it may grow the session's in-flight limit,
// report a slow request as a timeout, or add one more session to the DC.
void DownloadManagerMtproto::requestSucceeded(
		MTP::DcId dcId,
		int index,
		int amountAtRequestStart,
		crl::time timeAtRequestStart) {
	using namespace rpl::mappers;

	const auto i = _balanceData.find(dcId);
	Assert(i != end(_balanceData));
	auto &dc = i->second;
	Assert(index < dc.sessions.size());
	auto &data = dc.sessions[index];
	// Requests that started before the last session removal, or while the
	// session was above its current limit, are only logged and otherwise
	// ignored by the statistics below.
	const auto overloaded = (timeAtRequestStart <= dc.lastSessionRemove)
		|| (amountAtRequestStart > data.maxWaitedAmount);
	const auto parts = amountAtRequestStart / kDownloadPartSize;
	const auto duration = (crl::now() - timeAtRequestStart);
	DEBUG_LOG(("Download (%1,%2) request done, duration: %3, parts: %4%5"
		).arg(dcId
		).arg(index
		).arg(duration
		).arg(parts
		).arg(overloaded ? " (overloaded)" : ""));
	if (overloaded) {
		return;
	}

	// A successful but very slow request is handled like a session timeout
	// (posted through crl::on_main rather than called directly).
	if (duration >= kBadRequestDurationThreshold) {
		DEBUG_LOG(("Duration too large, signaling time out."));
		crl::on_main(this, [=] {
			sessionTimedOut(dcId, index);
		});
		return;
	}
	// The session was filled up to its limit and still answered in time:
	// allow one more part in flight, up to kMaxWaitedInSession.
	if (amountAtRequestStart == data.maxWaitedAmount
		&& data.maxWaitedAmount < kMaxWaitedInSession) {
		data.maxWaitedAmount = std::min(
			data.maxWaitedAmount + kDownloadPartSize,
			kMaxWaitedInSession);
		DEBUG_LOG(("Download (%1,%2) increased max waited amount %3."
			).arg(dcId
			).arg(index
			).arg(data.maxWaitedAmount));
	}
	// Every session must collect enough successes before anything else
	// happens; the required count grows with the number of recent removals.
	data.successes = std::min(data.successes + 1, kMaxTrackedSuccesses);
	const auto notEnough = ranges::any_of(
		dc.sessions,
		_1 < (dc.sessionRemoveTimes + 1) * kRetryAddSessionSuccesses,
		&DcSessionBalanceData::successes);
	if (notEnough) {
		return;
	}
	// A full round of successes is consumed here...
	for (auto &session : dc.sessions) {
		session.successes = 0;
	}
	// ...first to pay off one recorded timeout, if there are any.
	if (dc.timeouts > 0) {
		--dc.timeouts;
		return;
	} else if (dc.sessions.size() == kMaxSessionsCount) {
		return;
	}
	// After a removal wait longer (scaled by the removal count) before
	// trying an extra session again.
	const auto now = crl::now();
	const auto delay = (dc.sessionRemoveTimes + 1) * kRetryAddSessionTimeout;
	if (dc.lastSessionRemove && now < dc.lastSessionRemove + delay) {
		return;
	}
	dc.sessions.emplace_back();
	DEBUG_LOG(("Download (%1,%2) adding, now sessions: %3"
		).arg(dcId
		).arg(dc.sessions.size() - 1
		).arg(dc.sessions.size()));
}
298
chooseSessionIndex(MTP::DcId dcId) const299 int DownloadManagerMtproto::chooseSessionIndex(MTP::DcId dcId) const {
300 const auto i = _balanceData.find(dcId);
301 Assert(i != end(_balanceData));
302 const auto &sessions = i->second.sessions;
303 const auto j = ranges::min_element(
304 sessions,
305 ranges::less(),
306 &DcSessionBalanceData::requested);
307 return (j - begin(sessions));
308 }
309
sessionTimedOut(MTP::DcId dcId,int index)310 void DownloadManagerMtproto::sessionTimedOut(MTP::DcId dcId, int index) {
311 const auto i = _balanceData.find(dcId);
312 if (i == end(_balanceData)) {
313 return;
314 }
315 auto &dc = i->second;
316 if (index >= dc.sessions.size()) {
317 return;
318 }
319 DEBUG_LOG(("Download (%1,%2) session timed-out.").arg(dcId).arg(index));
320 for (auto &session : dc.sessions) {
321 session.successes = 0;
322 }
323 if (dc.sessions.size() == kStartSessionsCount
324 || ++dc.timeouts < kRemoveSessionAfterTimeouts) {
325 return;
326 }
327 dc.timeouts = 0;
328 removeSession(dcId);
329 }
330
// Removes the highest-index download session of DC `dcId` (there must be
// more than kStartSessionsCount of them). It asks every task in the DC queue
// to move its requests off that session, then kills the MTP session and
// records which index was removed and when. requestSucceeded() uses that
// record to delay adding sessions back.
void DownloadManagerMtproto::removeSession(MTP::DcId dcId) {
	auto &dc = _balanceData[dcId];
	Assert(dc.sessions.size() > kStartSessionsCount);
	const auto index = int(dc.sessions.size() - 1);
	// After pop_back() below the session count equals `index`, which is why
	// it is logged both as the index and as the new session count.
	DEBUG_LOG(("Download (%1,%2) removing, now sessions: %3"
		).arg(dcId
		).arg(index
		).arg(index));
	auto &queue = _queues[dcId];
	// Count consecutive removals of the same index (capped); removing a
	// different index restarts the count.
	if (dc.sessionRemoveIndex == index) {
		dc.sessionRemoveTimes = std::min(
			dc.sessionRemoveTimes + 1,
			kMaxTrackedSessionRemoves);
	} else {
		dc.sessionRemoveIndex = index;
		dc.sessionRemoveTimes = 1;
	}
	auto &session = dc.sessions.back();

	// Make sure we don't send anything to that session while redirecting.
	// The inflated amount keeps the least-loaded selection done by
	// chooseSessionIndex() from picking this session.
	session.requested += kMaxWaitedInSession * kMaxSessionsCount;
	queue.removeSession(index);
	// All real requests of the session must have been moved elsewhere,
	// leaving only the artificial amount added above.
	Assert(session.requested == kMaxWaitedInSession * kMaxSessionsCount);

	dc.sessions.pop_back();
	api().instance().killSession(MTP::downloadDcId(dcId, index));

	dc.lastSessionRemove = crl::now();
}
360
killSessionsSchedule(MTP::DcId dcId)361 void DownloadManagerMtproto::killSessionsSchedule(MTP::DcId dcId) {
362 if (!_killSessionsWhen.contains(dcId)) {
363 _killSessionsWhen.emplace(dcId, crl::now() + kKillSessionTimeout);
364 }
365 if (!_killSessionsTimer.isActive()) {
366 _killSessionsTimer.callOnce(kKillSessionTimeout + 5);
367 }
368 }
369
killSessionsCancel(MTP::DcId dcId)370 void DownloadManagerMtproto::killSessionsCancel(MTP::DcId dcId) {
371 _killSessionsWhen.erase(dcId);
372 if (_killSessionsWhen.empty()) {
373 _killSessionsTimer.cancel();
374 }
375 }
376
killSessions()377 void DownloadManagerMtproto::killSessions() {
378 const auto now = crl::now();
379 auto left = kKillSessionTimeout;
380 for (auto i = begin(_killSessionsWhen); i != end(_killSessionsWhen); ) {
381 if (i->second <= now) {
382 killSessions(i->first);
383 i = _killSessionsWhen.erase(i);
384 } else {
385 if (i->second - now < left) {
386 left = i->second - now;
387 }
388 ++i;
389 }
390 }
391 if (!_killSessionsWhen.empty()) {
392 _killSessionsTimer.callOnce(left);
393 }
394 }
395
killSessions(MTP::DcId dcId)396 void DownloadManagerMtproto::killSessions(MTP::DcId dcId) {
397 const auto i = _balanceData.find(dcId);
398 if (i != end(_balanceData)) {
399 auto &dc = i->second;
400 Assert(dc.totalRequested == 0);
401 auto sessions = base::take(dc.sessions);
402 dc = DcBalanceData();
403 for (auto j = 0; j != int(sessions.size()); ++j) {
404 Assert(sessions[j].requested == 0);
405 sessions[j] = DcSessionBalanceData();
406 api().instance().stopSession(MTP::downloadDcId(dcId, j));
407 }
408 dc.sessions = base::take(sessions);
409 }
410 }
411
DownloadMtprotoTask(not_null<DownloadManagerMtproto * > owner,const StorageFileLocation & location,Data::FileOrigin origin)412 DownloadMtprotoTask::DownloadMtprotoTask(
413 not_null<DownloadManagerMtproto*> owner,
414 const StorageFileLocation &location,
415 Data::FileOrigin origin)
416 : _owner(owner)
417 , _dcId(location.dcId())
418 , _location({ location })
419 , _origin(origin) {
420 }
421
DownloadMtprotoTask(not_null<DownloadManagerMtproto * > owner,MTP::DcId dcId,const Location & location)422 DownloadMtprotoTask::DownloadMtprotoTask(
423 not_null<DownloadManagerMtproto*> owner,
424 MTP::DcId dcId,
425 const Location &location)
426 : _owner(owner)
427 , _dcId(dcId)
428 , _location(location) {
429 }
430
~DownloadMtprotoTask()431 DownloadMtprotoTask::~DownloadMtprotoTask() {
432 cancelAllRequests();
433 _owner->remove(this);
434 }
435
dcId() const436 MTP::DcId DownloadMtprotoTask::dcId() const {
437 return _dcId;
438 }
439
fileOrigin() const440 Data::FileOrigin DownloadMtprotoTask::fileOrigin() const {
441 return _origin;
442 }
443
objectId() const444 uint64 DownloadMtprotoTask::objectId() const {
445 if (const auto v = std::get_if<StorageFileLocation>(&_location.data)) {
446 return v->objectId();
447 }
448 return 0;
449 }
450
location() const451 const DownloadMtprotoTask::Location &DownloadMtprotoTask::location() const {
452 return _location;
453 }
454
refreshFileReferenceFrom(const Data::UpdatedFileReferences & updates,int requestId,const QByteArray & current)455 void DownloadMtprotoTask::refreshFileReferenceFrom(
456 const Data::UpdatedFileReferences &updates,
457 int requestId,
458 const QByteArray ¤t) {
459 if (const auto v = std::get_if<StorageFileLocation>(&_location.data)) {
460 v->refreshFileReference(updates);
461 if (v->fileReference() == current) {
462 cancelOnFail();
463 return;
464 }
465 } else {
466 cancelOnFail();
467 return;
468 }
469 if (_sentRequests.contains(requestId)) {
470 makeRequest(finishSentRequest(
471 requestId,
472 FinishRequestReason::Redirect));
473 }
474 }
475
loadPart(int sessionIndex)476 void DownloadMtprotoTask::loadPart(int sessionIndex) {
477 makeRequest({ takeNextRequestOffset(), sessionIndex });
478 }
479
removeSession(int sessionIndex)480 void DownloadMtprotoTask::removeSession(int sessionIndex) {
481 struct Redirect {
482 mtpRequestId requestId = 0;
483 int offset = 0;
484 };
485 auto redirect = std::vector<Redirect>();
486 for (const auto &[requestId, requestData] : _sentRequests) {
487 if (requestData.sessionIndex == sessionIndex) {
488 redirect.reserve(_sentRequests.size());
489 redirect.push_back({ requestId, requestData.offset });
490 }
491 }
492 for (auto &[requestData, bytes] : _cdnUncheckedParts) {
493 if (requestData.sessionIndex == sessionIndex) {
494 const auto newIndex = _owner->chooseSessionIndex(dcId());
495 Assert(newIndex < sessionIndex);
496 requestData.sessionIndex = newIndex;
497 }
498 }
499 for (const auto &[requestId, offset] : redirect) {
500 const auto needMakeRequest = (requestId != _cdnHashesRequestId);
501 cancelRequest(requestId);
502 if (needMakeRequest) {
503 const auto newIndex = _owner->chooseSessionIndex(dcId());
504 Assert(newIndex < sessionIndex);
505 makeRequest({ offset, newIndex });
506 }
507 }
508 }
509
// Sends the MTProto request for the part described by `requestData` and
// returns its id (makeRequest() then registers it via placeSentRequest()).
// Every part is requested with limit kDownloadPartSize through download
// session `requestData.sessionIndex`:
// - while a CDN is in use (_cdnDcId != 0): upload.getCdnFile on the CDN DC;
// - web locations: upload.getWebFile with inputWebFileLocation;
// - geo point locations: upload.getWebFile with
//   inputWebFileGeoPointLocation;
// - storage locations: upload.getFile with the cdn_supported flag.
mtpRequestId DownloadMtprotoTask::sendRequest(
		const RequestData &requestData) {
	const auto offset = requestData.offset;
	const auto limit = Storage::kDownloadPartSize;
	const auto shiftedDcId = MTP::downloadDcId(
		_cdnDcId ? _cdnDcId : dcId(),
		requestData.sessionIndex);
	if (_cdnDcId) {
		return api().request(MTPupload_GetCdnFile(
			MTP_bytes(_cdnToken),
			MTP_int(offset),
			MTP_int(limit)
		)).done([=](const MTPupload_CdnFile &result, mtpRequestId id) {
			cdnPartLoaded(result, id);
		}).fail([=](const MTP::Error &error, mtpRequestId id) {
			cdnPartFailed(error, id);
		}).toDC(shiftedDcId).send();
	}
	return v::match(_location.data, [&](const WebFileLocation &location) {
		return api().request(MTPupload_GetWebFile(
			MTP_inputWebFileLocation(
				MTP_bytes(location.url()),
				MTP_long(location.accessHash())),
			MTP_int(offset),
			MTP_int(limit)
		)).done([=](const MTPupload_WebFile &result, mtpRequestId id) {
			webPartLoaded(result, id);
		}).fail([=](const MTP::Error &error, mtpRequestId id) {
			partFailed(error, id);
		}).toDC(shiftedDcId).send();
	}, [&](const GeoPointLocation &location) {
		return api().request(MTPupload_GetWebFile(
			MTP_inputWebFileGeoPointLocation(
				MTP_inputGeoPoint(
					MTP_flags(0),
					MTP_double(location.lat),
					MTP_double(location.lon),
					MTP_int(0)), // accuracy_radius
				MTP_long(location.access),
				MTP_int(location.width),
				MTP_int(location.height),
				MTP_int(location.zoom),
				MTP_int(location.scale)),
			MTP_int(offset),
			MTP_int(limit)
		)).done([=](const MTPupload_WebFile &result, mtpRequestId id) {
			webPartLoaded(result, id);
		}).fail([=](const MTP::Error &error, mtpRequestId id) {
			partFailed(error, id);
		}).toDC(shiftedDcId).send();
	}, [&](const StorageFileLocation &location) {
		// The reference this request is sent with is captured so the failure
		// handler can pass it on to the file reference refresh.
		const auto reference = location.fileReference();
		return api().request(MTPupload_GetFile(
			MTP_flags(MTPupload_GetFile::Flag::f_cdn_supported),
			location.tl(api().session().userId()),
			MTP_int(offset),
			MTP_int(limit)
		)).done([=](const MTPupload_File &result, mtpRequestId id) {
			normalPartLoaded(result, id);
		}).fail([=](const MTP::Error &error, mtpRequestId id) {
			normalPartFailed(reference, error, id);
		}).toDC(shiftedDcId).send();
	});
}
574
setWebFileSizeHook(int size)575 bool DownloadMtprotoTask::setWebFileSizeHook(int size) {
576 return true;
577 }
578
makeRequest(const RequestData & requestData)579 void DownloadMtprotoTask::makeRequest(const RequestData &requestData) {
580 placeSentRequest(sendRequest(requestData), requestData);
581 }
582
requestMoreCdnFileHashes()583 void DownloadMtprotoTask::requestMoreCdnFileHashes() {
584 if (_cdnHashesRequestId || _cdnUncheckedParts.empty()) {
585 return;
586 }
587
588 const auto requestData = _cdnUncheckedParts.cbegin()->first;
589 const auto shiftedDcId = MTP::downloadDcId(
590 dcId(),
591 requestData.sessionIndex);
592 _cdnHashesRequestId = api().request(MTPupload_GetCdnFileHashes(
593 MTP_bytes(_cdnToken),
594 MTP_int(requestData.offset)
595 )).done([=](const MTPVector<MTPFileHash> &result, mtpRequestId id) {
596 getCdnFileHashesDone(result, id);
597 }).fail([=](const MTP::Error &error, mtpRequestId id) {
598 cdnPartFailed(error, id);
599 }).toDC(shiftedDcId).send();
600 placeSentRequest(_cdnHashesRequestId, requestData);
601 }
602
normalPartLoaded(const MTPupload_File & result,mtpRequestId requestId)603 void DownloadMtprotoTask::normalPartLoaded(
604 const MTPupload_File &result,
605 mtpRequestId requestId) {
606 const auto requestData = finishSentRequest(
607 requestId,
608 FinishRequestReason::Success);
609 const auto owner = _owner;
610 const auto dcId = this->dcId();
611 result.match([&](const MTPDupload_fileCdnRedirect &data) {
612 switchToCDN(requestData, data);
613 }, [&](const MTPDupload_file &data) {
614 partLoaded(requestData.offset, data.vbytes().v);
615 });
616
617 // 'this' may be deleted at this point.
618 owner->checkSendNextAfterSuccess(dcId);
619 }
620
webPartLoaded(const MTPupload_WebFile & result,mtpRequestId requestId)621 void DownloadMtprotoTask::webPartLoaded(
622 const MTPupload_WebFile &result,
623 mtpRequestId requestId) {
624 const auto requestData = finishSentRequest(
625 requestId,
626 FinishRequestReason::Success);
627 const auto owner = _owner;
628 const auto dcId = this->dcId();
629 result.match([&](const MTPDupload_webFile &data) {
630 if (setWebFileSizeHook(data.vsize().v)) {
631 partLoaded(requestData.offset, data.vbytes().v);
632 }
633 });
634
635 // 'this' may be deleted at this point.
636 owner->checkSendNextAfterSuccess(dcId);
637 }
638
cdnPartLoaded(const MTPupload_CdnFile & result,mtpRequestId requestId)639 void DownloadMtprotoTask::cdnPartLoaded(const MTPupload_CdnFile &result, mtpRequestId requestId) {
640 result.match([&](const MTPDupload_cdnFileReuploadNeeded &data) {
641 const auto requestData = finishSentRequest(
642 requestId,
643 FinishRequestReason::Redirect);
644 const auto shiftedDcId = MTP::downloadDcId(
645 dcId(),
646 requestData.sessionIndex);
647 const auto requestId = api().request(MTPupload_ReuploadCdnFile(
648 MTP_bytes(_cdnToken),
649 data.vrequest_token()
650 )).done([=](const MTPVector<MTPFileHash> &result, mtpRequestId id) {
651 reuploadDone(result, id);
652 }).fail([=](const MTP::Error &error, mtpRequestId id) {
653 cdnPartFailed(error, id);
654 }).toDC(shiftedDcId).send();
655 placeSentRequest(requestId, requestData);
656 }, [&](const MTPDupload_cdnFile &data) {
657 const auto requestData = finishSentRequest(
658 requestId,
659 FinishRequestReason::Success);
660 const auto owner = _owner;
661 const auto dcId = this->dcId();
662 const auto guard = gsl::finally([=] {
663 // 'this' may be deleted at this point.
664 owner->checkSendNextAfterSuccess(dcId);
665 });
666
667 auto key = bytes::make_span(_cdnEncryptionKey);
668 auto iv = bytes::make_span(_cdnEncryptionIV);
669 Expects(key.size() == MTP::CTRState::KeySize);
670 Expects(iv.size() == MTP::CTRState::IvecSize);
671
672 auto state = MTP::CTRState();
673 auto ivec = bytes::make_span(state.ivec);
674 std::copy(iv.begin(), iv.end(), ivec.begin());
675
676 auto counterOffset = static_cast<uint32>(requestData.offset) >> 4;
677 state.ivec[15] = static_cast<uchar>(counterOffset & 0xFF);
678 state.ivec[14] = static_cast<uchar>((counterOffset >> 8) & 0xFF);
679 state.ivec[13] = static_cast<uchar>((counterOffset >> 16) & 0xFF);
680 state.ivec[12] = static_cast<uchar>((counterOffset >> 24) & 0xFF);
681
682 auto decryptInPlace = data.vbytes().v;
683 auto buffer = bytes::make_detached_span(decryptInPlace);
684 MTP::aesCtrEncrypt(buffer, key.data(), &state);
685
686 switch (checkCdnFileHash(requestData.offset, buffer)) {
687 case CheckCdnHashResult::NoHash: {
688 _cdnUncheckedParts.emplace(requestData, decryptInPlace);
689 requestMoreCdnFileHashes();
690 } return;
691
692 case CheckCdnHashResult::Invalid: {
693 LOG(("API Error: Wrong cdnFileHash for offset %1."
694 ).arg(requestData.offset));
695 cancelOnFail();
696 } return;
697
698 case CheckCdnHashResult::Good: {
699 partLoaded(requestData.offset, decryptInPlace);
700 } return;
701 }
702 Unexpected("Result of checkCdnFileHash()");
703 });
704 }
705
checkCdnFileHash(int offset,bytes::const_span buffer)706 DownloadMtprotoTask::CheckCdnHashResult DownloadMtprotoTask::checkCdnFileHash(
707 int offset,
708 bytes::const_span buffer) {
709 const auto cdnFileHashIt = _cdnFileHashes.find(offset);
710 if (cdnFileHashIt == _cdnFileHashes.cend()) {
711 return CheckCdnHashResult::NoHash;
712 }
713 const auto realHash = openssl::Sha256(buffer);
714 const auto receivedHash = bytes::make_span(cdnFileHashIt->second.hash);
715 if (bytes::compare(realHash, receivedHash)) {
716 return CheckCdnHashResult::Invalid;
717 }
718 return CheckCdnHashResult::Good;
719 }
720
reuploadDone(const MTPVector<MTPFileHash> & result,mtpRequestId requestId)721 void DownloadMtprotoTask::reuploadDone(
722 const MTPVector<MTPFileHash> &result,
723 mtpRequestId requestId) {
724 const auto requestData = finishSentRequest(
725 requestId,
726 FinishRequestReason::Redirect);
727 addCdnHashes(result.v);
728 makeRequest(requestData);
729 }
730
// Handles an upload.getCdnFileHashes answer: stores the received hashes and
// re-checks every CDN part waiting for verification. Verified parts are fed
// via feedPart(). A mismatch cancels the download. If not a single waiting
// part could be verified, the download is cancelled as well. Otherwise more
// hashes are requested for the parts still waiting.
void DownloadMtprotoTask::getCdnFileHashesDone(
		const MTPVector<MTPFileHash> &result,
		mtpRequestId requestId) {
	Expects(_cdnHashesRequestId == requestId);

	const auto requestData = finishSentRequest(
		requestId,
		FinishRequestReason::Redirect);
	addCdnHashes(result.v);
	auto someMoreChecked = false;
	for (auto i = _cdnUncheckedParts.begin(); i != _cdnUncheckedParts.cend();) {
		const auto uncheckedData = i->first;
		const auto uncheckedBytes = bytes::make_span(i->second);

		switch (checkCdnFileHash(uncheckedData.offset, uncheckedBytes)) {
		case CheckCdnHashResult::NoHash: {
			++i;
		} break;

		case CheckCdnHashResult::Invalid: {
			LOG(("API Error: Wrong cdnFileHash for offset %1."
				).arg(uncheckedData.offset));
			cancelOnFail();
			return;
		} break;

		case CheckCdnHashResult::Good: {
			someMoreChecked = true;
			const auto goodOffset = uncheckedData.offset;
			const auto goodBytes = std::move(i->second);
			// feedPart() may destroy this task, so check `weak` afterwards
			// before touching any member again.
			const auto weak = base::make_weak(this);
			i = _cdnUncheckedParts.erase(i);
			if (!feedPart(goodOffset, goodBytes) || !weak) {
				return;
			}
		} break;

		default: Unexpected("Result of checkCdnFileHash()");
		}
	}
	// The hashes were requested for the first waiting part, so at least
	// that one is expected to become verifiable.
	if (!someMoreChecked) {
		LOG(("API Error: "
			"Could not find cdnFileHash for offset %1 "
			"after getCdnFileHashes request."
			).arg(requestData.offset));
		cancelOnFail();
		return;
	}
	requestMoreCdnFileHashes();
}
781
placeSentRequest(mtpRequestId requestId,const RequestData & requestData)782 void DownloadMtprotoTask::placeSentRequest(
783 mtpRequestId requestId,
784 const RequestData &requestData) {
785 const auto amount = _owner->changeRequestedAmount(
786 dcId(),
787 requestData.sessionIndex,
788 Storage::kDownloadPartSize);
789 const auto [i, ok1] = _sentRequests.emplace(requestId, requestData);
790 const auto [j, ok2] = _requestByOffset.emplace(
791 requestData.offset,
792 requestId);
793
794 i->second.requestedInSession = amount;
795 i->second.sent = crl::now();
796
797 Ensures(ok1 && ok2);
798 }
799
finishSentRequest(mtpRequestId requestId,FinishRequestReason reason)800 auto DownloadMtprotoTask::finishSentRequest(
801 mtpRequestId requestId,
802 FinishRequestReason reason)
803 -> RequestData {
804 auto it = _sentRequests.find(requestId);
805 Assert(it != _sentRequests.cend());
806
807 if (_cdnHashesRequestId == requestId) {
808 _cdnHashesRequestId = 0;
809 }
810 const auto result = it->second;
811 _owner->changeRequestedAmount(
812 dcId(),
813 result.sessionIndex,
814 -Storage::kDownloadPartSize);
815 _sentRequests.erase(it);
816 const auto ok = _requestByOffset.remove(result.offset);
817
818 if (reason == FinishRequestReason::Success) {
819 _owner->requestSucceeded(
820 dcId(),
821 result.sessionIndex,
822 result.requestedInSession,
823 result.sent);
824 }
825
826 Ensures(ok);
827 return result;
828 }
829
haveSentRequests() const830 bool DownloadMtprotoTask::haveSentRequests() const {
831 return !_sentRequests.empty() || !_cdnUncheckedParts.empty();
832 }
833
haveSentRequestForOffset(int offset) const834 bool DownloadMtprotoTask::haveSentRequestForOffset(int offset) const {
835 return _requestByOffset.contains(offset)
836 || _cdnUncheckedParts.contains({ offset, 0 });
837 }
838
cancelAllRequests()839 void DownloadMtprotoTask::cancelAllRequests() {
840 while (!_sentRequests.empty()) {
841 cancelRequest(_sentRequests.begin()->first);
842 }
843 _cdnUncheckedParts.clear();
844 }
845
cancelRequestForOffset(int offset)846 void DownloadMtprotoTask::cancelRequestForOffset(int offset) {
847 const auto i = _requestByOffset.find(offset);
848 if (i != end(_requestByOffset)) {
849 cancelRequest(i->second);
850 }
851 _cdnUncheckedParts.remove({ offset, 0 });
852 }
853
cancelRequest(mtpRequestId requestId)854 void DownloadMtprotoTask::cancelRequest(mtpRequestId requestId) {
855 const auto hashes = (_cdnHashesRequestId == requestId);
856 api().request(requestId).cancel();
857 [[maybe_unused]] const auto data = finishSentRequest(
858 requestId,
859 FinishRequestReason::Cancel);
860 if (hashes && !_cdnUncheckedParts.empty()) {
861 crl::on_main(this, [=] {
862 requestMoreCdnFileHashes();
863 });
864 }
865 }
866
addToQueue(int priority)867 void DownloadMtprotoTask::addToQueue(int priority) {
868 _owner->enqueue(this, priority);
869 }
870
removeFromQueue()871 void DownloadMtprotoTask::removeFromQueue() {
872 _owner->remove(this);
873 }
874
partLoaded(int offset,const QByteArray & bytes)875 void DownloadMtprotoTask::partLoaded(
876 int offset,
877 const QByteArray &bytes) {
878 feedPart(offset, bytes);
879 }
880
normalPartFailed(QByteArray fileReference,const MTP::Error & error,mtpRequestId requestId)881 bool DownloadMtprotoTask::normalPartFailed(
882 QByteArray fileReference,
883 const MTP::Error &error,
884 mtpRequestId requestId) {
885 if (MTP::IsDefaultHandledError(error)) {
886 return false;
887 }
888 if (error.code() == 400
889 && error.type().startsWith(qstr("FILE_REFERENCE_"))) {
890 api().refreshFileReference(
891 _origin,
892 this,
893 requestId,
894 fileReference);
895 return true;
896 }
897 return partFailed(error, requestId);
898 }
899
partFailed(const MTP::Error & error,mtpRequestId requestId)900 bool DownloadMtprotoTask::partFailed(
901 const MTP::Error &error,
902 mtpRequestId requestId) {
903 if (MTP::IsDefaultHandledError(error)) {
904 return false;
905 }
906 cancelOnFail();
907 return true;
908 }
909
cdnPartFailed(const MTP::Error & error,mtpRequestId requestId)910 bool DownloadMtprotoTask::cdnPartFailed(
911 const MTP::Error &error,
912 mtpRequestId requestId) {
913 if (MTP::IsDefaultHandledError(error)) {
914 return false;
915 }
916
917 if (error.type() == qstr("FILE_TOKEN_INVALID")
918 || error.type() == qstr("REQUEST_TOKEN_INVALID")) {
919 const auto requestData = finishSentRequest(
920 requestId,
921 FinishRequestReason::Redirect);
922 changeCDNParams(
923 requestData,
924 0,
925 QByteArray(),
926 QByteArray(),
927 QByteArray(),
928 QVector<MTPFileHash>());
929 return true;
930 }
931 return partFailed(error, requestId);
932 }
933
switchToCDN(const RequestData & requestData,const MTPDupload_fileCdnRedirect & redirect)934 void DownloadMtprotoTask::switchToCDN(
935 const RequestData &requestData,
936 const MTPDupload_fileCdnRedirect &redirect) {
937 changeCDNParams(
938 requestData,
939 redirect.vdc_id().v,
940 redirect.vfile_token().v,
941 redirect.vencryption_key().v,
942 redirect.vencryption_iv().v,
943 redirect.vfile_hashes().v);
944 }
945
addCdnHashes(const QVector<MTPFileHash> & hashes)946 void DownloadMtprotoTask::addCdnHashes(
947 const QVector<MTPFileHash> &hashes) {
948 for (const auto &hash : hashes) {
949 hash.match([&](const MTPDfileHash &data) {
950 _cdnFileHashes.emplace(
951 data.voffset().v,
952 CdnFileHash{ data.vlimit().v, data.vhash().v });
953 });
954 }
955 }
956
changeCDNParams(const RequestData & requestData,MTP::DcId dcId,const QByteArray & token,const QByteArray & encryptionKey,const QByteArray & encryptionIV,const QVector<MTPFileHash> & hashes)957 void DownloadMtprotoTask::changeCDNParams(
958 const RequestData &requestData,
959 MTP::DcId dcId,
960 const QByteArray &token,
961 const QByteArray &encryptionKey,
962 const QByteArray &encryptionIV,
963 const QVector<MTPFileHash> &hashes) {
964 if (dcId != 0
965 && (encryptionKey.size() != MTP::CTRState::KeySize
966 || encryptionIV.size() != MTP::CTRState::IvecSize)) {
967 LOG(("Message Error: Wrong key (%1) / iv (%2) size in CDN params"
968 ).arg(encryptionKey.size()
969 ).arg(encryptionIV.size()));
970 cancelOnFail();
971 return;
972 }
973
974 auto resendAllRequests = (_cdnDcId != dcId
975 || _cdnToken != token
976 || _cdnEncryptionKey != encryptionKey
977 || _cdnEncryptionIV != encryptionIV);
978 _cdnDcId = dcId;
979 _cdnToken = token;
980 _cdnEncryptionKey = encryptionKey;
981 _cdnEncryptionIV = encryptionIV;
982 addCdnHashes(hashes);
983
984 if (resendAllRequests && !_sentRequests.empty()) {
985 auto resendRequests = std::vector<RequestData>();
986 resendRequests.reserve(_sentRequests.size());
987 while (!_sentRequests.empty()) {
988 const auto requestId = _sentRequests.begin()->first;
989 api().request(requestId).cancel();
990 resendRequests.push_back(finishSentRequest(
991 requestId,
992 FinishRequestReason::Redirect));
993 }
994 for (const auto &requestData : resendRequests) {
995 makeRequest(requestData);
996 }
997 }
998 makeRequest(requestData);
999 }
1000
1001 } // namespace Storage
1002