xref: /reactos/dll/win32/urlmon/protocol.c (revision 02e84521)
1 /*
2  * Copyright 2007 Misha Koshelev
3  * Copyright 2009 Jacek Caban for CodeWeavers
4  *
5  * This library is free software; you can redistribute it and/or
6  * modify it under the terms of the GNU Lesser General Public
7  * License as published by the Free Software Foundation; either
8  * version 2.1 of the License, or (at your option) any later version.
9  *
10  * This library is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13  * Lesser General Public License for more details.
14  *
15  * You should have received a copy of the GNU Lesser General Public
16  * License along with this library; if not, write to the Free Software
17  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
18  */
19 
20 #include "urlmon_main.h"
21 
22 #include "wine/debug.h"
23 
24 WINE_DEFAULT_DEBUG_CHANNEL(urlmon);
25 
26 static inline HRESULT report_progress(Protocol *protocol, ULONG status_code, LPCWSTR status_text)
27 {
28     return IInternetProtocolSink_ReportProgress(protocol->protocol_sink, status_code, status_text);
29 }
30 
31 static inline HRESULT report_result(Protocol *protocol, HRESULT hres)
32 {
33     if (!(protocol->flags & FLAG_RESULT_REPORTED) && protocol->protocol_sink) {
34         protocol->flags |= FLAG_RESULT_REPORTED;
35         IInternetProtocolSink_ReportResult(protocol->protocol_sink, hres, 0, NULL);
36     }
37 
38     return hres;
39 }
40 
41 static void report_data(Protocol *protocol)
42 {
43     DWORD bscf;
44 
45     if((protocol->flags & FLAG_LAST_DATA_REPORTED) || !protocol->protocol_sink)
46         return;
47 
48     if(protocol->flags & FLAG_FIRST_DATA_REPORTED) {
49         bscf = BSCF_INTERMEDIATEDATANOTIFICATION;
50     }else {
51         protocol->flags |= FLAG_FIRST_DATA_REPORTED;
52         bscf = BSCF_FIRSTDATANOTIFICATION;
53     }
54 
55     if(protocol->flags & FLAG_ALL_DATA_READ && !(protocol->flags & FLAG_LAST_DATA_REPORTED)) {
56         protocol->flags |= FLAG_LAST_DATA_REPORTED;
57         bscf |= BSCF_LASTDATANOTIFICATION;
58     }
59 
60     IInternetProtocolSink_ReportData(protocol->protocol_sink, bscf,
61             protocol->current_position+protocol->available_bytes,
62             protocol->content_length);
63 }
64 
65 static void all_data_read(Protocol *protocol)
66 {
67     protocol->flags |= FLAG_ALL_DATA_READ;
68 
69     report_data(protocol);
70     report_result(protocol, S_OK);
71 }
72 
73 static HRESULT start_downloading(Protocol *protocol)
74 {
75     HRESULT hres;
76 
77     hres = protocol->vtbl->start_downloading(protocol);
78     if(FAILED(hres)) {
79         if(hres == INET_E_REDIRECT_FAILED)
80             return S_OK;
81         protocol_close_connection(protocol);
82         report_result(protocol, hres);
83         return hres;
84     }
85 
86     if(protocol->bindf & BINDF_NEEDFILE) {
87         WCHAR cache_file[MAX_PATH];
88         DWORD buflen = sizeof(cache_file);
89 
90         if(InternetQueryOptionW(protocol->request, INTERNET_OPTION_DATAFILE_NAME, cache_file, &buflen)) {
91             report_progress(protocol, BINDSTATUS_CACHEFILENAMEAVAILABLE, cache_file);
92         }else {
93             FIXME("Could not get cache file\n");
94         }
95     }
96 
97     protocol->flags |= FLAG_FIRST_CONTINUE_COMPLETE;
98     return S_OK;
99 }
100 
101 HRESULT protocol_syncbinding(Protocol *protocol)
102 {
103     BOOL res;
104     HRESULT hres;
105 
106     protocol->flags |= FLAG_SYNC_READ;
107 
108     hres = start_downloading(protocol);
109     if(FAILED(hres))
110         return hres;
111 
112     res = InternetQueryDataAvailable(protocol->request, &protocol->query_available, 0, 0);
113     if(res)
114         protocol->available_bytes = protocol->query_available;
115     else
116         WARN("InternetQueryDataAvailable failed: %u\n", GetLastError());
117 
118     protocol->flags |= FLAG_FIRST_DATA_REPORTED|FLAG_LAST_DATA_REPORTED;
119     IInternetProtocolSink_ReportData(protocol->protocol_sink, BSCF_LASTDATANOTIFICATION|BSCF_DATAFULLYAVAILABLE,
120             protocol->available_bytes, protocol->content_length);
121     return S_OK;
122 }
123 
124 static void request_complete(Protocol *protocol, INTERNET_ASYNC_RESULT *ar)
125 {
126     PROTOCOLDATA data;
127 
128     TRACE("(%p)->(%p)\n", protocol, ar);
129 
130     /* PROTOCOLDATA same as native */
131     memset(&data, 0, sizeof(data));
132     data.dwState = 0xf1000000;
133 
134     if(ar->dwResult) {
135         protocol->flags |= FLAG_REQUEST_COMPLETE;
136 
137         if(!protocol->request) {
138             TRACE("setting request handle %p\n", (HINTERNET)ar->dwResult);
139             protocol->request = (HINTERNET)ar->dwResult;
140         }
141 
142         if(protocol->flags & FLAG_FIRST_CONTINUE_COMPLETE)
143             data.pData = UlongToPtr(BINDSTATUS_ENDDOWNLOADCOMPONENTS);
144         else
145             data.pData = UlongToPtr(BINDSTATUS_DOWNLOADINGDATA);
146 
147     }else {
148         protocol->flags |= FLAG_ERROR;
149         data.pData = UlongToPtr(ar->dwError);
150     }
151 
152     if (protocol->bindf & BINDF_FROMURLMON)
153         IInternetProtocolSink_Switch(protocol->protocol_sink, &data);
154     else
155         protocol_continue(protocol, &data);
156 }
157 
158 static void WINAPI internet_status_callback(HINTERNET internet, DWORD_PTR context,
159         DWORD internet_status, LPVOID status_info, DWORD status_info_len)
160 {
161     Protocol *protocol = (Protocol*)context;
162 
163     switch(internet_status) {
164     case INTERNET_STATUS_RESOLVING_NAME:
165         TRACE("%p INTERNET_STATUS_RESOLVING_NAME\n", protocol);
166         report_progress(protocol, BINDSTATUS_FINDINGRESOURCE, (LPWSTR)status_info);
167         break;
168 
169     case INTERNET_STATUS_CONNECTING_TO_SERVER: {
170         WCHAR *info;
171 
172         TRACE("%p INTERNET_STATUS_CONNECTING_TO_SERVER %s\n", protocol, (const char*)status_info);
173 
174         info = heap_strdupAtoW(status_info);
175         if(!info)
176             return;
177 
178         report_progress(protocol, BINDSTATUS_CONNECTING, info);
179         heap_free(info);
180         break;
181     }
182 
183     case INTERNET_STATUS_SENDING_REQUEST:
184         TRACE("%p INTERNET_STATUS_SENDING_REQUEST\n", protocol);
185         report_progress(protocol, BINDSTATUS_SENDINGREQUEST, (LPWSTR)status_info);
186         break;
187 
188     case INTERNET_STATUS_REDIRECT:
189         TRACE("%p INTERNET_STATUS_REDIRECT\n", protocol);
190         report_progress(protocol, BINDSTATUS_REDIRECTING, (LPWSTR)status_info);
191         break;
192 
193     case INTERNET_STATUS_REQUEST_COMPLETE:
194         request_complete(protocol, status_info);
195         break;
196 
197     case INTERNET_STATUS_HANDLE_CREATED:
198         TRACE("%p INTERNET_STATUS_HANDLE_CREATED\n", protocol);
199         IInternetProtocol_AddRef(protocol->protocol);
200         break;
201 
202     case INTERNET_STATUS_HANDLE_CLOSING:
203         TRACE("%p INTERNET_STATUS_HANDLE_CLOSING\n", protocol);
204 
205         if(*(HINTERNET *)status_info == protocol->request) {
206             protocol->request = NULL;
207             if(protocol->protocol_sink) {
208                 IInternetProtocolSink_Release(protocol->protocol_sink);
209                 protocol->protocol_sink = NULL;
210             }
211 
212             if(protocol->bind_info.cbSize) {
213                 ReleaseBindInfo(&protocol->bind_info);
214                 memset(&protocol->bind_info, 0, sizeof(protocol->bind_info));
215             }
216         }else if(*(HINTERNET *)status_info == protocol->connection) {
217             protocol->connection = NULL;
218         }
219 
220         IInternetProtocol_Release(protocol->protocol);
221         break;
222 
223     default:
224         WARN("Unhandled Internet status callback %d\n", internet_status);
225     }
226 }
227 
228 static HRESULT write_post_stream(Protocol *protocol)
229 {
230     BYTE buf[0x20000];
231     DWORD written;
232     ULONG size;
233     BOOL res;
234     HRESULT hres;
235 
236     protocol->flags &= ~FLAG_REQUEST_COMPLETE;
237 
238     while(1) {
239         size = 0;
240         hres = IStream_Read(protocol->post_stream, buf, sizeof(buf), &size);
241         if(FAILED(hres) || !size)
242             break;
243         res = InternetWriteFile(protocol->request, buf, size, &written);
244         if(!res) {
245             FIXME("InternetWriteFile failed: %u\n", GetLastError());
246             hres = E_FAIL;
247             break;
248         }
249     }
250 
251     if(SUCCEEDED(hres)) {
252         IStream_Release(protocol->post_stream);
253         protocol->post_stream = NULL;
254 
255         hres = protocol->vtbl->end_request(protocol);
256     }
257 
258     if(FAILED(hres))
259         return report_result(protocol, hres);
260 
261     return S_OK;
262 }
263 
264 static HINTERNET create_internet_session(IInternetBindInfo *bind_info)
265 {
266     LPWSTR global_user_agent = NULL;
267     LPOLESTR user_agent = NULL;
268     ULONG size = 0;
269     HINTERNET ret;
270     HRESULT hres;
271 
272     hres = IInternetBindInfo_GetBindString(bind_info, BINDSTRING_USER_AGENT, &user_agent, 1, &size);
273     if(hres != S_OK || !size)
274         global_user_agent = get_useragent();
275 
276     ret = InternetOpenW(user_agent ? user_agent : global_user_agent, 0, NULL, NULL, INTERNET_FLAG_ASYNC);
277     heap_free(global_user_agent);
278     CoTaskMemFree(user_agent);
279     if(!ret) {
280         WARN("InternetOpen failed: %d\n", GetLastError());
281         return NULL;
282     }
283 
284     InternetSetStatusCallbackW(ret, internet_status_callback);
285     return ret;
286 }
287 
288 static HINTERNET internet_session;
289 
290 HINTERNET get_internet_session(IInternetBindInfo *bind_info)
291 {
292     HINTERNET new_session;
293 
294     if(internet_session)
295         return internet_session;
296 
297     if(!bind_info)
298         return NULL;
299 
300     new_session = create_internet_session(bind_info);
301     if(new_session && InterlockedCompareExchangePointer((void**)&internet_session, new_session, NULL))
302         InternetCloseHandle(new_session);
303 
304     return internet_session;
305 }
306 
307 void update_user_agent(WCHAR *user_agent)
308 {
309     if(internet_session)
310         InternetSetOptionW(internet_session, INTERNET_OPTION_USER_AGENT, user_agent, strlenW(user_agent));
311 }
312 
313 HRESULT protocol_start(Protocol *protocol, IInternetProtocol *prot, IUri *uri,
314         IInternetProtocolSink *protocol_sink, IInternetBindInfo *bind_info)
315 {
316     DWORD request_flags;
317     HRESULT hres;
318 
319     protocol->protocol = prot;
320 
321     IInternetProtocolSink_AddRef(protocol_sink);
322     protocol->protocol_sink = protocol_sink;
323 
324     memset(&protocol->bind_info, 0, sizeof(protocol->bind_info));
325     protocol->bind_info.cbSize = sizeof(BINDINFO);
326     hres = IInternetBindInfo_GetBindInfo(bind_info, &protocol->bindf, &protocol->bind_info);
327     if(hres != S_OK) {
328         WARN("GetBindInfo failed: %08x\n", hres);
329         return report_result(protocol, hres);
330     }
331 
332     if(!(protocol->bindf & BINDF_FROMURLMON))
333         report_progress(protocol, BINDSTATUS_DIRECTBIND, NULL);
334 
335     if(!get_internet_session(bind_info))
336         return report_result(protocol, INET_E_NO_SESSION);
337 
338     request_flags = INTERNET_FLAG_KEEP_CONNECTION;
339     if(protocol->bindf & BINDF_NOWRITECACHE)
340         request_flags |= INTERNET_FLAG_NO_CACHE_WRITE;
341     if(protocol->bindf & BINDF_NEEDFILE)
342         request_flags |= INTERNET_FLAG_NEED_FILE;
343     if(protocol->bind_info.dwOptions & BINDINFO_OPTIONS_DISABLEAUTOREDIRECTS)
344         request_flags |= INTERNET_FLAG_NO_AUTO_REDIRECT;
345 
346     hres = protocol->vtbl->open_request(protocol, uri, request_flags, internet_session, bind_info);
347     if(FAILED(hres)) {
348         protocol_close_connection(protocol);
349         return report_result(protocol, hres);
350     }
351 
352     return S_OK;
353 }
354 
355 HRESULT protocol_continue(Protocol *protocol, PROTOCOLDATA *data)
356 {
357     BOOL is_start;
358     HRESULT hres;
359 
360     is_start = !data || data->pData == UlongToPtr(BINDSTATUS_DOWNLOADINGDATA);
361 
362     if(!protocol->request) {
363         WARN("Expected request to be non-NULL\n");
364         return S_OK;
365     }
366 
367     if(!protocol->protocol_sink) {
368         WARN("Expected IInternetProtocolSink pointer to be non-NULL\n");
369         return S_OK;
370     }
371 
372     if(protocol->flags & FLAG_ERROR) {
373         protocol->flags &= ~FLAG_ERROR;
374         protocol->vtbl->on_error(protocol, PtrToUlong(data->pData));
375         return S_OK;
376     }
377 
378     if(protocol->post_stream)
379         return write_post_stream(protocol);
380 
381     if(is_start) {
382         hres = start_downloading(protocol);
383         if(FAILED(hres))
384             return S_OK;
385     }
386 
387     if(!data || data->pData >= UlongToPtr(BINDSTATUS_DOWNLOADINGDATA)) {
388         if(!protocol->available_bytes) {
389             if(protocol->query_available) {
390                 protocol->available_bytes = protocol->query_available;
391             }else {
392                 BOOL res;
393 
394                 /* InternetQueryDataAvailable may immediately fork and perform its asynchronous
395                  * read, so clear the flag _before_ calling so it does not incorrectly get cleared
396                  * after the status callback is called */
397                 protocol->flags &= ~FLAG_REQUEST_COMPLETE;
398                 res = InternetQueryDataAvailable(protocol->request, &protocol->query_available, 0, 0);
399                 if(res) {
400                     TRACE("available %u bytes\n", protocol->query_available);
401                     if(!protocol->query_available) {
402                         all_data_read(protocol);
403                         return S_OK;
404                     }
405                     protocol->available_bytes = protocol->query_available;
406                 }else if(GetLastError() != ERROR_IO_PENDING) {
407                     protocol->flags |= FLAG_REQUEST_COMPLETE;
408                     WARN("InternetQueryDataAvailable failed: %d\n", GetLastError());
409                     report_result(protocol, INET_E_DATA_NOT_AVAILABLE);
410                     return S_OK;
411                 }
412             }
413 
414             protocol->flags |= FLAG_REQUEST_COMPLETE;
415         }
416 
417         report_data(protocol);
418     }
419 
420     return S_OK;
421 }
422 
423 HRESULT protocol_read(Protocol *protocol, void *buf, ULONG size, ULONG *read_ret)
424 {
425     ULONG read = 0;
426     BOOL res;
427     HRESULT hres = S_FALSE;
428 
429     if(protocol->flags & FLAG_ALL_DATA_READ) {
430         *read_ret = 0;
431         return S_FALSE;
432     }
433 
434     if(!(protocol->flags & FLAG_SYNC_READ) && (!(protocol->flags & FLAG_REQUEST_COMPLETE) || !protocol->available_bytes)) {
435         *read_ret = 0;
436         return E_PENDING;
437     }
438 
439     while(read < size && protocol->available_bytes) {
440         ULONG len;
441 
442         res = InternetReadFile(protocol->request, ((BYTE *)buf)+read,
443                 protocol->available_bytes > size-read ? size-read : protocol->available_bytes, &len);
444         if(!res) {
445             WARN("InternetReadFile failed: %d\n", GetLastError());
446             hres = INET_E_DOWNLOAD_FAILURE;
447             report_result(protocol, hres);
448             break;
449         }
450 
451         if(!len) {
452             all_data_read(protocol);
453             break;
454         }
455 
456         read += len;
457         protocol->current_position += len;
458         protocol->available_bytes -= len;
459 
460         TRACE("current_position %d, available_bytes %d\n", protocol->current_position, protocol->available_bytes);
461 
462         if(!protocol->available_bytes) {
463             /* InternetQueryDataAvailable may immediately fork and perform its asynchronous
464              * read, so clear the flag _before_ calling so it does not incorrectly get cleared
465              * after the status callback is called */
466             protocol->flags &= ~FLAG_REQUEST_COMPLETE;
467             res = InternetQueryDataAvailable(protocol->request, &protocol->query_available, 0, 0);
468             if(!res) {
469                 if (GetLastError() == ERROR_IO_PENDING) {
470                     hres = E_PENDING;
471                 }else {
472                     WARN("InternetQueryDataAvailable failed: %d\n", GetLastError());
473                     hres = INET_E_DATA_NOT_AVAILABLE;
474                     report_result(protocol, hres);
475                 }
476                 break;
477             }
478 
479             if(!protocol->query_available) {
480                 all_data_read(protocol);
481                 break;
482             }
483 
484             protocol->available_bytes = protocol->query_available;
485         }
486     }
487 
488     *read_ret = read;
489 
490     if (hres != E_PENDING)
491         protocol->flags |= FLAG_REQUEST_COMPLETE;
492     if(FAILED(hres))
493         return hres;
494 
495     return read ? S_OK : S_FALSE;
496 }
497 
498 HRESULT protocol_lock_request(Protocol *protocol)
499 {
500     if (!InternetLockRequestFile(protocol->request, &protocol->lock))
501         WARN("InternetLockRequest failed: %d\n", GetLastError());
502 
503     return S_OK;
504 }
505 
506 HRESULT protocol_unlock_request(Protocol *protocol)
507 {
508     if(!protocol->lock)
509         return S_OK;
510 
511     if(!InternetUnlockRequestFile(protocol->lock))
512         WARN("InternetUnlockRequest failed: %d\n", GetLastError());
513     protocol->lock = 0;
514 
515     return S_OK;
516 }
517 
518 HRESULT protocol_abort(Protocol *protocol, HRESULT reason)
519 {
520     if(!protocol->protocol_sink)
521         return S_OK;
522 
523     /* NOTE: IE10 returns S_OK here */
524     if(protocol->flags & FLAG_RESULT_REPORTED)
525         return INET_E_RESULT_DISPATCHED;
526 
527     report_result(protocol, reason);
528     return S_OK;
529 }
530 
531 void protocol_close_connection(Protocol *protocol)
532 {
533     protocol->vtbl->close_connection(protocol);
534 
535     if(protocol->request)
536         InternetCloseHandle(protocol->request);
537 
538     if(protocol->connection)
539         InternetCloseHandle(protocol->connection);
540 
541     if(protocol->post_stream) {
542         IStream_Release(protocol->post_stream);
543         protocol->post_stream = NULL;
544     }
545 
546     protocol->flags = 0;
547 }
548