xref: /reactos/dll/win32/msi/streams.c (revision fb5d5ecd)
1 /*
2  * Implementation of the Microsoft Installer (msi.dll)
3  *
4  * Copyright 2007 James Hawkins
5  * Copyright 2015 Hans Leidekker for CodeWeavers
6  *
7  * This library is free software; you can redistribute it and/or
8  * modify it under the terms of the GNU Lesser General Public
9  * License as published by the Free Software Foundation; either
10  * version 2.1 of the License, or (at your option) any later version.
11  *
12  * This library is distributed in the hope that it will be useful,
13  * but WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
15  * Lesser General Public License for more details.
16  *
17  * You should have received a copy of the GNU Lesser General Public
18  * License along with this library; if not, write to the Free Software
19  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
20  */
21 
22 #include <stdarg.h>
23 
24 #define COBJMACROS
25 
26 #include "windef.h"
27 #include "winbase.h"
28 #include "winerror.h"
29 #include "msi.h"
30 #include "msiquery.h"
31 #include "objbase.h"
32 #include "msipriv.h"
33 #include "query.h"
34 
35 #include "wine/debug.h"
36 #include "wine/unicode.h"
37 
38 WINE_DEFAULT_DEBUG_CHANNEL(msidb);
39 
/* Number of columns exposed by the _Streams view: Name (1) and Data (2). */
#define NUM_STREAMS_COLS    2

/* View object for the _Streams table.  The rows themselves live in the
 * database (db->streams / db->num_streams); this object only carries the
 * per-view state. */
typedef struct tagMSISTREAMSVIEW
{
    MSIVIEW view;       /* must stay first: MSIVIEW pointers are cast to MSISTREAMSVIEW */
    MSIDATABASE *db;    /* database that owns the stream rows */
    UINT num_cols;      /* set to NUM_STREAMS_COLS by STREAMS_CreateView */
} MSISTREAMSVIEW;
48 
49 static BOOL streams_resize_table( MSIDATABASE *db, UINT size )
50 {
51     if (!db->num_streams_allocated)
52     {
53         if (!(db->streams = msi_alloc_zero( size * sizeof(MSISTREAM) ))) return FALSE;
54         db->num_streams_allocated = size;
55         return TRUE;
56     }
57     while (size >= db->num_streams_allocated)
58     {
59         MSISTREAM *tmp;
60         UINT new_size = db->num_streams_allocated * 2;
61         if (!(tmp = msi_realloc_zero( db->streams, new_size * sizeof(MSISTREAM) ))) return FALSE;
62         db->streams = tmp;
63         db->num_streams_allocated = new_size;
64     }
65     return TRUE;
66 }
67 
68 static UINT STREAMS_fetch_int(struct tagMSIVIEW *view, UINT row, UINT col, UINT *val)
69 {
70     MSISTREAMSVIEW *sv = (MSISTREAMSVIEW *)view;
71 
72     TRACE("(%p, %d, %d, %p)\n", view, row, col, val);
73 
74     if (col != 1)
75         return ERROR_INVALID_PARAMETER;
76 
77     if (row >= sv->db->num_streams)
78         return ERROR_NO_MORE_ITEMS;
79 
80     *val = sv->db->streams[row].str_index;
81 
82     return ERROR_SUCCESS;
83 }
84 
85 static UINT STREAMS_fetch_stream(struct tagMSIVIEW *view, UINT row, UINT col, IStream **stm)
86 {
87     MSISTREAMSVIEW *sv = (MSISTREAMSVIEW *)view;
88     LARGE_INTEGER pos;
89     HRESULT hr;
90 
91     TRACE("(%p, %d, %d, %p)\n", view, row, col, stm);
92 
93     if (row >= sv->db->num_streams)
94         return ERROR_FUNCTION_FAILED;
95 
96     pos.QuadPart = 0;
97     hr = IStream_Seek( sv->db->streams[row].stream, pos, STREAM_SEEK_SET, NULL );
98     if (FAILED( hr ))
99         return ERROR_FUNCTION_FAILED;
100 
101     *stm = sv->db->streams[row].stream;
102     IStream_AddRef( *stm );
103 
104     return ERROR_SUCCESS;
105 }
106 
107 static UINT STREAMS_get_row( struct tagMSIVIEW *view, UINT row, MSIRECORD **rec )
108 {
109     MSISTREAMSVIEW *sv = (MSISTREAMSVIEW *)view;
110 
111     TRACE("%p %d %p\n", sv, row, rec);
112 
113     return msi_view_get_row( sv->db, view, row, rec );
114 }
115 
116 static UINT STREAMS_set_row(struct tagMSIVIEW *view, UINT row, MSIRECORD *rec, UINT mask)
117 {
118     MSISTREAMSVIEW *sv = (MSISTREAMSVIEW *)view;
119 
120     TRACE("(%p, %d, %p, %08x)\n", view, row, rec, mask);
121 
122     if (row > sv->db->num_streams || mask >= (1 << sv->num_cols))
123         return ERROR_INVALID_PARAMETER;
124 
125     if (mask & 1)
126     {
127         const WCHAR *name = MSI_RecordGetString( rec, 1 );
128 
129         if (!name) return ERROR_INVALID_PARAMETER;
130         sv->db->streams[row].str_index = msi_add_string( sv->db->strings, name, -1, StringNonPersistent );
131     }
132     if (mask & 2)
133     {
134         IStream *old, *new;
135         HRESULT hr;
136         UINT r;
137 
138         r = MSI_RecordGetIStream( rec, 2, &new );
139         if (r != ERROR_SUCCESS)
140             return r;
141 
142         old = sv->db->streams[row].stream;
143         hr = IStream_QueryInterface( new, &IID_IStream, (void **)&sv->db->streams[row].stream );
144         if (FAILED( hr ))
145         {
146             IStream_Release( new );
147             return ERROR_FUNCTION_FAILED;
148         }
149         if (old) IStream_Release( old );
150     }
151 
152     return ERROR_SUCCESS;
153 }
154 
155 static UINT streams_find_row( MSISTREAMSVIEW *sv, MSIRECORD *rec, UINT *row )
156 {
157     const WCHAR *str;
158     UINT r, i, id, val;
159 
160     str = MSI_RecordGetString( rec, 1 );
161     r = msi_string2id( sv->db->strings, str, -1, &id );
162     if (r != ERROR_SUCCESS)
163         return r;
164 
165     for (i = 0; i < sv->db->num_streams; i++)
166     {
167         STREAMS_fetch_int( &sv->view, i, 1, &val );
168 
169         if (val == id)
170         {
171             if (row) *row = i;
172             return ERROR_SUCCESS;
173         }
174     }
175 
176     return ERROR_FUNCTION_FAILED;
177 }
178 
179 static UINT STREAMS_insert_row(struct tagMSIVIEW *view, MSIRECORD *rec, UINT row, BOOL temporary)
180 {
181     MSISTREAMSVIEW *sv = (MSISTREAMSVIEW *)view;
182     UINT i, r, num_rows = sv->db->num_streams + 1;
183 
184     TRACE("(%p, %p, %d, %d)\n", view, rec, row, temporary);
185 
186     r = streams_find_row( sv, rec, NULL );
187     if (r == ERROR_SUCCESS)
188         return ERROR_FUNCTION_FAILED;
189 
190     if (!streams_resize_table( sv->db, num_rows ))
191         return ERROR_FUNCTION_FAILED;
192 
193     if (row == -1)
194         row = num_rows - 1;
195 
196     /* shift the rows to make room for the new row */
197     for (i = num_rows - 1; i > row; i--)
198     {
199         sv->db->streams[i] = sv->db->streams[i - 1];
200     }
201 
202     r = STREAMS_set_row( view, row, rec, (1 << sv->num_cols) - 1 );
203     if (r == ERROR_SUCCESS)
204         sv->db->num_streams = num_rows;
205 
206     return r;
207 }
208 
209 static UINT STREAMS_delete_row(struct tagMSIVIEW *view, UINT row)
210 {
211     MSIDATABASE *db = ((MSISTREAMSVIEW *)view)->db;
212     UINT i, num_rows = db->num_streams - 1;
213     const WCHAR *name;
214     WCHAR *encname;
215     HRESULT hr;
216 
217     TRACE("(%p %d)!\n", view, row);
218 
219     name = msi_string_lookup( db->strings, db->streams[row].str_index, NULL );
220     if (!(encname = encode_streamname( FALSE, name ))) return ERROR_OUTOFMEMORY;
221     hr = IStorage_DestroyElement( db->storage, encname );
222     msi_free( encname );
223     if (FAILED( hr ))
224         return ERROR_FUNCTION_FAILED;
225     hr = IStream_Release( db->streams[row].stream );
226     if (FAILED( hr ))
227         return ERROR_FUNCTION_FAILED;
228 
229     for (i = row; i < num_rows; i++)
230         db->streams[i] = db->streams[i + 1];
231     db->num_streams = num_rows;
232 
233     return ERROR_SUCCESS;
234 }
235 
236 static UINT STREAMS_execute(struct tagMSIVIEW *view, MSIRECORD *record)
237 {
238     TRACE("(%p, %p)\n", view, record);
239     return ERROR_SUCCESS;
240 }
241 
242 static UINT STREAMS_close(struct tagMSIVIEW *view)
243 {
244     TRACE("(%p)\n", view);
245     return ERROR_SUCCESS;
246 }
247 
248 static UINT STREAMS_get_dimensions(struct tagMSIVIEW *view, UINT *rows, UINT *cols)
249 {
250     MSISTREAMSVIEW *sv = (MSISTREAMSVIEW *)view;
251 
252     TRACE("(%p, %p, %p)\n", view, rows, cols);
253 
254     if (cols) *cols = sv->num_cols;
255     if (rows) *rows = sv->db->num_streams;
256 
257     return ERROR_SUCCESS;
258 }
259 
260 static UINT STREAMS_get_column_info( struct tagMSIVIEW *view, UINT n, LPCWSTR *name,
261                                      UINT *type, BOOL *temporary, LPCWSTR *table_name )
262 {
263     MSISTREAMSVIEW *sv = (MSISTREAMSVIEW *)view;
264 
265     TRACE("(%p, %d, %p, %p, %p, %p)\n", view, n, name, type, temporary, table_name);
266 
267     if (!n || n > sv->num_cols)
268         return ERROR_INVALID_PARAMETER;
269 
270     switch (n)
271     {
272     case 1:
273         if (name) *name = szName;
274         if (type) *type = MSITYPE_STRING | MSITYPE_VALID | MAX_STREAM_NAME_LEN;
275         break;
276 
277     case 2:
278         if (name) *name = szData;
279         if (type) *type = MSITYPE_STRING | MSITYPE_VALID | MSITYPE_NULLABLE;
280         break;
281     }
282     if (table_name) *table_name = szStreams;
283     if (temporary) *temporary = FALSE;
284     return ERROR_SUCCESS;
285 }
286 
287 static UINT streams_modify_update(struct tagMSIVIEW *view, MSIRECORD *rec)
288 {
289     MSISTREAMSVIEW *sv = (MSISTREAMSVIEW *)view;
290     UINT r, row;
291 
292     r = streams_find_row(sv, rec, &row);
293     if (r != ERROR_SUCCESS)
294         return ERROR_FUNCTION_FAILED;
295 
296     return STREAMS_set_row( view, row, rec, (1 << sv->num_cols) - 1 );
297 }
298 
299 static UINT streams_modify_assign(struct tagMSIVIEW *view, MSIRECORD *rec)
300 {
301     MSISTREAMSVIEW *sv = (MSISTREAMSVIEW *)view;
302     UINT r;
303 
304     r = streams_find_row( sv, rec, NULL );
305     if (r == ERROR_SUCCESS)
306         return streams_modify_update(view, rec);
307 
308     return STREAMS_insert_row(view, rec, -1, FALSE);
309 }
310 
311 static UINT STREAMS_modify(struct tagMSIVIEW *view, MSIMODIFY eModifyMode, MSIRECORD *rec, UINT row)
312 {
313     UINT r;
314 
315     TRACE("%p %d %p\n", view, eModifyMode, rec);
316 
317     switch (eModifyMode)
318     {
319     case MSIMODIFY_ASSIGN:
320         r = streams_modify_assign(view, rec);
321         break;
322 
323     case MSIMODIFY_INSERT:
324         r = STREAMS_insert_row(view, rec, -1, FALSE);
325         break;
326 
327     case MSIMODIFY_UPDATE:
328         r = streams_modify_update(view, rec);
329         break;
330 
331     case MSIMODIFY_DELETE:
332         r = STREAMS_delete_row(view, row - 1);
333         break;
334 
335     case MSIMODIFY_VALIDATE_NEW:
336     case MSIMODIFY_INSERT_TEMPORARY:
337     case MSIMODIFY_REFRESH:
338     case MSIMODIFY_REPLACE:
339     case MSIMODIFY_MERGE:
340     case MSIMODIFY_VALIDATE:
341     case MSIMODIFY_VALIDATE_FIELD:
342     case MSIMODIFY_VALIDATE_DELETE:
343         FIXME("%p %d %p - mode not implemented\n", view, eModifyMode, rec );
344         r = ERROR_CALL_NOT_IMPLEMENTED;
345         break;
346 
347     default:
348         r = ERROR_INVALID_DATA;
349     }
350 
351     return r;
352 }
353 
354 static UINT STREAMS_delete(struct tagMSIVIEW *view)
355 {
356     MSISTREAMSVIEW *sv = (MSISTREAMSVIEW *)view;
357 
358     TRACE("(%p)\n", view);
359 
360     msi_free(sv);
361     return ERROR_SUCCESS;
362 }
363 
364 static UINT STREAMS_find_matching_rows(struct tagMSIVIEW *view, UINT col,
365                                        UINT val, UINT *row, MSIITERHANDLE *handle)
366 {
367     MSISTREAMSVIEW *sv = (MSISTREAMSVIEW *)view;
368     UINT index = PtrToUlong(*handle);
369 
370     TRACE("(%p, %d, %d, %p, %p)\n", view, col, val, row, handle);
371 
372     if (!col || col > sv->num_cols)
373         return ERROR_INVALID_PARAMETER;
374 
375     while (index < sv->db->num_streams)
376     {
377         if (sv->db->streams[index].str_index == val)
378         {
379             *row = index;
380             break;
381         }
382         index++;
383     }
384 
385     *handle = UlongToPtr(++index);
386 
387     if (index > sv->db->num_streams)
388         return ERROR_NO_MORE_ITEMS;
389 
390     return ERROR_SUCCESS;
391 }
392 
/* Operation table for the _Streams view.  Entries are positional and must
 * follow the member order of MSIVIEWOPS (presumably declared in query.h;
 * not visible here).  The trailing operations that this view does not
 * implement are left NULL. */
static const MSIVIEWOPS streams_ops =
{
    STREAMS_fetch_int,
    STREAMS_fetch_stream,
    STREAMS_get_row,
    STREAMS_set_row,
    STREAMS_insert_row,
    STREAMS_delete_row,
    STREAMS_execute,
    STREAMS_close,
    STREAMS_get_dimensions,
    STREAMS_get_column_info,
    STREAMS_modify,
    STREAMS_delete,
    STREAMS_find_matching_rows,
    NULL,
    NULL,
    NULL,
    NULL,
    NULL,
    NULL,
};
415 
416 static HRESULT open_stream( MSIDATABASE *db, const WCHAR *name, IStream **stream )
417 {
418     HRESULT hr;
419 
420     hr = IStorage_OpenStream( db->storage, name, NULL, STGM_READ|STGM_SHARE_EXCLUSIVE, 0, stream );
421     if (FAILED( hr ))
422     {
423         MSITRANSFORM *transform;
424 
425         LIST_FOR_EACH_ENTRY( transform, &db->transforms, MSITRANSFORM, entry )
426         {
427             hr = IStorage_OpenStream( transform->stg, name, NULL, STGM_READ|STGM_SHARE_EXCLUSIVE, 0, stream );
428             if (SUCCEEDED( hr ))
429                 break;
430         }
431     }
432     return hr;
433 }
434 
435 static MSISTREAM *find_stream( MSIDATABASE *db, const WCHAR *name )
436 {
437     UINT r, id, i;
438 
439     r = msi_string2id( db->strings, name, -1, &id );
440     if (r != ERROR_SUCCESS)
441         return NULL;
442 
443     for (i = 0; i < db->num_streams; i++)
444     {
445         if (db->streams[i].str_index == id) return &db->streams[i];
446     }
447     return NULL;
448 }
449 
450 static UINT append_stream( MSIDATABASE *db, const WCHAR *name, IStream *stream )
451 {
452     UINT i = db->num_streams;
453 
454     if (!streams_resize_table( db, db->num_streams + 1 ))
455         return ERROR_OUTOFMEMORY;
456 
457     db->streams[i].str_index = msi_add_string( db->strings, name, -1, StringNonPersistent );
458     db->streams[i].stream = stream;
459     db->num_streams++;
460 
461     TRACE("added %s\n", debugstr_w( name ));
462     return ERROR_SUCCESS;
463 }
464 
465 static UINT load_streams( MSIDATABASE *db )
466 {
467     WCHAR decoded[MAX_STREAM_NAME_LEN + 1];
468     IEnumSTATSTG *stgenum;
469     STATSTG stat;
470     HRESULT hr;
471     UINT count, r = ERROR_SUCCESS;
472     IStream *stream;
473 
474     hr = IStorage_EnumElements( db->storage, 0, NULL, 0, &stgenum );
475     if (FAILED( hr ))
476         return ERROR_FUNCTION_FAILED;
477 
478     for (;;)
479     {
480         count = 0;
481         hr = IEnumSTATSTG_Next( stgenum, 1, &stat, &count );
482         if (FAILED( hr ) || !count)
483             break;
484 
485         /* table streams are not in the _Streams table */
486         if (stat.type != STGTY_STREAM || *stat.pwcsName == 0x4840)
487         {
488             CoTaskMemFree( stat.pwcsName );
489             continue;
490         }
491         decode_streamname( stat.pwcsName, decoded );
492         if (find_stream( db, decoded ))
493         {
494             CoTaskMemFree( stat.pwcsName );
495             continue;
496         }
497         TRACE("found new stream %s\n", debugstr_w( decoded ));
498 
499         hr = open_stream( db, stat.pwcsName, &stream );
500         CoTaskMemFree( stat.pwcsName );
501         if (FAILED( hr ))
502         {
503             ERR("unable to open stream %08x\n", hr);
504             r = ERROR_FUNCTION_FAILED;
505             break;
506         }
507 
508         r = append_stream( db, decoded, stream );
509         if (r != ERROR_SUCCESS)
510             break;
511     }
512 
513     TRACE("loaded %u streams\n", db->num_streams);
514     IEnumSTATSTG_Release( stgenum );
515     return r;
516 }
517 
518 UINT msi_get_stream( MSIDATABASE *db, const WCHAR *name, IStream **ret )
519 {
520     MSISTREAM *stream;
521     WCHAR *encname;
522     HRESULT hr;
523     UINT r;
524 
525     if ((stream = find_stream( db, name )))
526     {
527         LARGE_INTEGER pos;
528 
529         pos.QuadPart = 0;
530         hr = IStream_Seek( stream->stream, pos, STREAM_SEEK_SET, NULL );
531         if (FAILED( hr ))
532             return ERROR_FUNCTION_FAILED;
533 
534         *ret = stream->stream;
535         IStream_AddRef( *ret );
536         return ERROR_SUCCESS;
537     }
538 
539     if (!(encname = encode_streamname( FALSE, name )))
540         return ERROR_OUTOFMEMORY;
541 
542     hr = open_stream( db, encname, ret );
543     msi_free( encname );
544     if (FAILED( hr ))
545         return ERROR_FUNCTION_FAILED;
546 
547     r = append_stream( db, name, *ret );
548     if (r != ERROR_SUCCESS)
549     {
550         IStream_Release( *ret );
551         return r;
552     }
553 
554     IStream_AddRef( *ret );
555     return ERROR_SUCCESS;
556 }
557 
558 UINT STREAMS_CreateView(MSIDATABASE *db, MSIVIEW **view)
559 {
560     MSISTREAMSVIEW *sv;
561     UINT r;
562 
563     TRACE("(%p, %p)\n", db, view);
564 
565     r = load_streams( db );
566     if (r != ERROR_SUCCESS)
567         return r;
568 
569     if (!(sv = msi_alloc_zero( sizeof(MSISTREAMSVIEW) )))
570         return ERROR_OUTOFMEMORY;
571 
572     sv->view.ops = &streams_ops;
573     sv->num_cols = NUM_STREAMS_COLS;
574     sv->db = db;
575 
576     *view = (MSIVIEW *)sv;
577 
578     return ERROR_SUCCESS;
579 }
580 
581 static HRESULT write_stream( IStream *dst, IStream *src )
582 {
583     HRESULT hr;
584     char buf[4096];
585     STATSTG stat;
586     LARGE_INTEGER pos;
587     UINT count, size;
588 
589     hr = IStream_Stat( src, &stat, STATFLAG_NONAME );
590     if (FAILED( hr )) return hr;
591 
592     hr = IStream_SetSize( dst, stat.cbSize );
593     if (FAILED( hr )) return hr;
594 
595     pos.QuadPart = 0;
596     hr = IStream_Seek( dst, pos, STREAM_SEEK_SET, NULL );
597     if (FAILED( hr )) return hr;
598 
599     for (;;)
600     {
601         size = min( sizeof(buf), stat.cbSize.QuadPart );
602         hr = IStream_Read( src, buf, size, &count );
603         if (FAILED( hr ) || count != size)
604         {
605             WARN("failed to read stream: %08x\n", hr);
606             return E_INVALIDARG;
607         }
608         stat.cbSize.QuadPart -= count;
609         if (count)
610         {
611             size = count;
612             hr = IStream_Write( dst, buf, size, &count );
613             if (FAILED( hr ) || count != size)
614             {
615                 WARN("failed to write stream: %08x\n", hr);
616                 return E_INVALIDARG;
617             }
618         }
619         if (!stat.cbSize.QuadPart) break;
620     }
621 
622     return S_OK;
623 }
624 
625 UINT msi_commit_streams( MSIDATABASE *db )
626 {
627     UINT i;
628     const WCHAR *name;
629     WCHAR *encname;
630     IStream *stream;
631     HRESULT hr;
632 
633     TRACE("got %u streams\n", db->num_streams);
634 
635     for (i = 0; i < db->num_streams; i++)
636     {
637         name = msi_string_lookup( db->strings, db->streams[i].str_index, NULL );
638         if (!(encname = encode_streamname( FALSE, name ))) return ERROR_OUTOFMEMORY;
639 
640         hr = IStorage_CreateStream( db->storage, encname, STGM_WRITE|STGM_SHARE_EXCLUSIVE, 0, 0, &stream );
641         if (SUCCEEDED( hr ))
642         {
643             hr = write_stream( stream, db->streams[i].stream );
644             if (FAILED( hr ))
645             {
646                 ERR("failed to write stream %s (hr = %08x)\n", debugstr_w(encname), hr);
647                 msi_free( encname );
648                 IStream_Release( stream );
649                 return ERROR_FUNCTION_FAILED;
650             }
651             hr = IStream_Commit( stream, 0 );
652             IStream_Release( stream );
653             if (FAILED( hr ))
654             {
655                 ERR("failed to commit stream %s (hr = %08x)\n", debugstr_w(encname), hr);
656                 msi_free( encname );
657                 return ERROR_FUNCTION_FAILED;
658             }
659         }
660         else if (hr != STG_E_FILEALREADYEXISTS)
661         {
662             ERR("failed to create stream %s (hr = %08x)\n", debugstr_w(encname), hr);
663             msi_free( encname );
664             return ERROR_FUNCTION_FAILED;
665         }
666         msi_free( encname );
667     }
668 
669     return ERROR_SUCCESS;
670 }
671