1 /*
2 This file is part of Telegram Desktop,
3 the official desktop application for the Telegram messaging service.
4 
5 For license and copyright information please follow this link:
6 https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
7 */
8 #pragma once
9 
#include "data/data_file_origin.h"
#include "base/timer.h"
#include "base/weak_ptr.h"

#include <utility>
13 
// Forward declarations: this header only uses these types through
// pointers/references (not_null<ApiWrap*>, ApiWrap &, const MTP::Error &),
// so their full definitions are not required here.
class ApiWrap;

namespace MTP {
class Error;
} // namespace MTP
19 
20 namespace Storage {
21 
// Fixed size of every requested download part.
//
// Using different part sizes is not supported for now: a download starts
// with some part size and may then receive a CDN redirect, and the CDN path
// supports only a fixed part size because of hash checking.
//
// `inline constexpr` gives the constant a single definition shared by every
// translation unit that includes this header.
inline constexpr auto kDownloadPartSize = 128 * 1024;
27 
28 class DownloadMtprotoTask;
29 
30 class DownloadManagerMtproto final : public base::has_weak_ptr {
31 public:
32 	using Task = DownloadMtprotoTask;
33 
34 	explicit DownloadManagerMtproto(not_null<ApiWrap*> api);
35 	~DownloadManagerMtproto();
36 
api()37 	[[nodiscard]] ApiWrap &api() const {
38 		return *_api;
39 	}
40 
41 	void enqueue(not_null<Task*> task, int priority);
42 	void remove(not_null<Task*> task);
43 
notifyTaskFinished()44 	void notifyTaskFinished() {
45 		_taskFinished.fire({});
46 	}
taskFinished()47 	[[nodiscard]] rpl::producer<> taskFinished() const {
48 		return _taskFinished.events();
49 	}
50 
51 	int changeRequestedAmount(MTP::DcId dcId, int index, int delta);
52 	void requestSucceeded(
53 		MTP::DcId dcId,
54 		int index,
55 		int amountAtRequestStart,
56 		crl::time timeAtRequestStart);
57 	void checkSendNextAfterSuccess(MTP::DcId dcId);
58 	[[nodiscard]] int chooseSessionIndex(MTP::DcId dcId) const;
59 
60 private:
61 	class Queue final {
62 	public:
63 		void enqueue(not_null<Task*> task, int priority);
64 		void remove(not_null<Task*> task);
65 		void resetGeneration();
66 		[[nodiscard]] bool empty() const;
67 		[[nodiscard]] Task *nextTask(bool onlyHighestPriority) const;
68 		void removeSession(int index);
69 
70 	private:
71 		struct Enqueued {
72 			not_null<Task*> task;
73 			int priority = 0;
74 		};
75 		std::vector<Enqueued> _tasks;
76 
77 	};
78 	struct DcSessionBalanceData {
79 		DcSessionBalanceData();
80 
81 		int requested = 0;
82 		int successes = 0; // Since last timeout in this dc in any session.
83 		int maxWaitedAmount = 0;
84 	};
85 	struct DcBalanceData {
86 		DcBalanceData();
87 
88 		std::vector<DcSessionBalanceData> sessions;
89 		crl::time lastSessionRemove = 0;
90 		int sessionRemoveIndex = 0;
91 		int sessionRemoveTimes = 0;
92 		int timeouts = 0; // Since all sessions had successes >= required.
93 		int totalRequested = 0;
94 	};
95 
96 	void checkSendNext();
97 	void checkSendNext(MTP::DcId dcId, Queue &queue);
98 	bool trySendNextPart(MTP::DcId dcId, Queue &queue);
99 
100 	void killSessionsSchedule(MTP::DcId dcId);
101 	void killSessionsCancel(MTP::DcId dcId);
102 	void killSessions();
103 	void killSessions(MTP::DcId dcId);
104 
105 	void resetGeneration();
106 	void sessionTimedOut(MTP::DcId dcId, int index);
107 	void removeSession(MTP::DcId dcId);
108 
109 	const not_null<ApiWrap*> _api;
110 
111 	rpl::event_stream<> _taskFinished;
112 
113 	base::flat_map<MTP::DcId, DcBalanceData> _balanceData;
114 	base::Timer _resetGenerationTimer;
115 
116 	base::flat_map<MTP::DcId, crl::time> _killSessionsWhen;
117 	base::Timer _killSessionsTimer;
118 
119 	base::flat_map<MTP::DcId, Queue> _queues;
120 	rpl::lifetime _lifetime;
121 
122 };
123 
124 class DownloadMtprotoTask : public base::has_weak_ptr {
125 public:
126 	struct Location {
127 		std::variant<
128 			StorageFileLocation,
129 			WebFileLocation,
130 			GeoPointLocation> data;
131 	};
132 
133 	DownloadMtprotoTask(
134 		not_null<DownloadManagerMtproto*> owner,
135 		const StorageFileLocation &location,
136 		Data::FileOrigin origin);
137 	DownloadMtprotoTask(
138 		not_null<DownloadManagerMtproto*> owner,
139 		MTP::DcId dcId,
140 		const Location &location);
141 	virtual ~DownloadMtprotoTask();
142 
143 	[[nodiscard]] MTP::DcId dcId() const;
144 	[[nodiscard]] Data::FileOrigin fileOrigin() const;
145 	[[nodiscard]] uint64 objectId() const;
146 	[[nodiscard]] const Location &location() const;
147 
148 	[[nodiscard]] virtual bool readyToRequest() const = 0;
149 	void loadPart(int sessionIndex);
150 	void removeSession(int sessionIndex);
151 
152 	void refreshFileReferenceFrom(
153 		const Data::UpdatedFileReferences &updates,
154 		int requestId,
155 		const QByteArray &current);
156 
157 protected:
158 	[[nodiscard]] bool haveSentRequests() const;
159 	[[nodiscard]] bool haveSentRequestForOffset(int offset) const;
160 	void cancelAllRequests();
161 	void cancelRequestForOffset(int offset);
162 
163 	void addToQueue(int priority = 0);
164 	void removeFromQueue();
165 
api()166 	[[nodiscard]] ApiWrap &api() const {
167 		return _owner->api();
168 	}
169 
170 private:
171 	struct RequestData {
172 		int offset = 0;
173 		mutable int sessionIndex = 0;
174 		int requestedInSession = 0;
175 		crl::time sent = 0;
176 
177 		inline bool operator<(const RequestData &other) const {
178 			return offset < other.offset;
179 		}
180 	};
181 	struct CdnFileHash {
CdnFileHashCdnFileHash182 		CdnFileHash(int limit, QByteArray hash) : limit(limit), hash(hash) {
183 		}
184 		int limit = 0;
185 		QByteArray hash;
186 	};
187 	enum class CheckCdnHashResult {
188 		NoHash,
189 		Invalid,
190 		Good,
191 	};
192 	enum class FinishRequestReason {
193 		Success,
194 		Redirect,
195 		Cancel,
196 	};
197 
198 	// Called only if readyToRequest() == true.
199 	[[nodiscard]] virtual int takeNextRequestOffset() = 0;
200 	virtual bool feedPart(int offset, const QByteArray &bytes) = 0;
201 	virtual bool setWebFileSizeHook(int size);
202 	virtual void cancelOnFail() = 0;
203 
204 	void cancelRequest(mtpRequestId requestId);
205 	void makeRequest(const RequestData &requestData);
206 	void normalPartLoaded(
207 		const MTPupload_File &result,
208 		mtpRequestId requestId);
209 	void webPartLoaded(
210 		const MTPupload_WebFile &result,
211 		mtpRequestId requestId);
212 	void cdnPartLoaded(
213 		const MTPupload_CdnFile &result,
214 		mtpRequestId requestId);
215 	void reuploadDone(
216 		const MTPVector<MTPFileHash> &result,
217 		mtpRequestId requestId);
218 	void requestMoreCdnFileHashes();
219 	void getCdnFileHashesDone(
220 		const MTPVector<MTPFileHash> &result,
221 		mtpRequestId requestId);
222 
223 	void partLoaded(int offset, const QByteArray &bytes);
224 
225 	bool partFailed(const MTP::Error &error, mtpRequestId requestId);
226 	bool normalPartFailed(
227 		QByteArray fileReference,
228 		const MTP::Error &error,
229 		mtpRequestId requestId);
230 	bool cdnPartFailed(const MTP::Error &error, mtpRequestId requestId);
231 
232 	[[nodiscard]] mtpRequestId sendRequest(const RequestData &requestData);
233 	void placeSentRequest(
234 		mtpRequestId requestId,
235 		const RequestData &requestData);
236 	[[nodiscard]] RequestData finishSentRequest(
237 		mtpRequestId requestId,
238 		FinishRequestReason reason);
239 	void switchToCDN(
240 		const RequestData &requestData,
241 		const MTPDupload_fileCdnRedirect &redirect);
242 	void addCdnHashes(const QVector<MTPFileHash> &hashes);
243 	void changeCDNParams(
244 		const RequestData &requestData,
245 		MTP::DcId dcId,
246 		const QByteArray &token,
247 		const QByteArray &encryptionKey,
248 		const QByteArray &encryptionIV,
249 		const QVector<MTPFileHash> &hashes);
250 
251 	[[nodiscard]] CheckCdnHashResult checkCdnFileHash(
252 		int offset,
253 		bytes::const_span buffer);
254 
255 	const not_null<DownloadManagerMtproto*> _owner;
256 	const MTP::DcId _dcId = 0;
257 
258 	// _location can be changed with an updated file_reference.
259 	Location _location;
260 	const Data::FileOrigin _origin;
261 
262 	base::flat_map<mtpRequestId, RequestData> _sentRequests;
263 	base::flat_map<int, mtpRequestId> _requestByOffset;
264 
265 	MTP::DcId _cdnDcId = 0;
266 	QByteArray _cdnToken;
267 	QByteArray _cdnEncryptionKey;
268 	QByteArray _cdnEncryptionIV;
269 	base::flat_map<int, CdnFileHash> _cdnFileHashes;
270 	base::flat_map<RequestData, QByteArray> _cdnUncheckedParts;
271 	mtpRequestId _cdnHashesRequestId = 0;
272 
273 };
274 
275 } // namespace Storage
276