xref: /reactos/dll/win32/msi/where.c (revision f4be6dc3)
1 /*
2  * Implementation of the Microsoft Installer (msi.dll)
3  *
4  * Copyright 2002 Mike McCormack for CodeWeavers
5  * Copyright 2011 Bernhard Loos
6  *
7  * This library is free software; you can redistribute it and/or
8  * modify it under the terms of the GNU Lesser General Public
9  * License as published by the Free Software Foundation; either
10  * version 2.1 of the License, or (at your option) any later version.
11  *
12  * This library is distributed in the hope that it will be useful,
13  * but WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
15  * Lesser General Public License for more details.
16  *
17  * You should have received a copy of the GNU Lesser General Public
18  * License along with this library; if not, write to the Free Software
19  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
20  */
21 
22 #include <stdarg.h>
23 #include <assert.h>
24 
25 #include "windef.h"
26 #include "winbase.h"
27 #include "winerror.h"
28 #include "wine/debug.h"
29 #include "msi.h"
30 #include "msiquery.h"
31 #include "objbase.h"
32 #include "objidl.h"
33 #include "msipriv.h"
34 #include "winnls.h"
35 
36 #include "query.h"
37 
38 WINE_DEFAULT_DEBUG_CHANNEL(msidb);
39 
40 /* below is the query interface to a table */
41 struct row_entry
42 {
43     struct tagMSIWHEREVIEW *wv; /* used during sorting */
44     UINT values[1];
45 };
46 
47 struct join_table
48 {
49     struct join_table *next;
50     MSIVIEW *view;
51     UINT col_count;
52     UINT row_count;
53     UINT table_index;
54 };
55 
56 typedef struct tagMSIORDERINFO
57 {
58     UINT col_count;
59     UINT error;
60     union ext_column columns[1];
61 } MSIORDERINFO;
62 
63 typedef struct tagMSIWHEREVIEW
64 {
65     MSIVIEW        view;
66     MSIDATABASE   *db;
67     struct join_table *tables;
68     UINT           row_count;
69     UINT           col_count;
70     UINT           table_count;
71     struct row_entry **reorder;
72     UINT           reorder_size; /* number of entries available in reorder */
73     struct expr   *cond;
74     UINT           rec_index;
75     MSIORDERINFO  *order_info;
76 } MSIWHEREVIEW;
77 
78 static UINT WHERE_evaluate( MSIWHEREVIEW *wv, const UINT rows[],
79                             struct expr *cond, INT *val, MSIRECORD *record );
80 
81 #define INITIAL_REORDER_SIZE 16
82 
83 #define INVALID_ROW_INDEX (-1)
84 
/*
 * Release every stored result row and the reorder array itself.  Afterwards
 * the view holds no rows and no array.  Does nothing when no array exists.
 */
static void free_reorder(MSIWHEREVIEW *wv)
{
    UINT idx;

    if (wv->reorder)
    {
        for (idx = wv->row_count; idx > 0; idx--)
            free(wv->reorder[idx - 1]);

        free(wv->reorder);
        wv->reorder = NULL;
        wv->reorder_size = 0;
        wv->row_count = 0;
    }
}
100 
init_reorder(MSIWHEREVIEW * wv)101 static UINT init_reorder(MSIWHEREVIEW *wv)
102 {
103     struct row_entry **new = calloc(INITIAL_REORDER_SIZE, sizeof(*new));
104     if (!new)
105         return ERROR_OUTOFMEMORY;
106 
107     free_reorder(wv);
108 
109     wv->reorder = new;
110     wv->reorder_size = INITIAL_REORDER_SIZE;
111 
112     return ERROR_SUCCESS;
113 }
114 
/*
 * Look up result row `row`.  On success *values points at that row's
 * per-table row numbers; ERROR_NO_MORE_ITEMS when `row` is out of range.
 */
static inline UINT find_row(MSIWHEREVIEW *wv, UINT row, UINT *(values[]))
{
    if (row < wv->row_count)
    {
        *values = wv->reorder[row]->values;
        return ERROR_SUCCESS;
    }
    return ERROR_NO_MORE_ITEMS;
}
124 
/*
 * Append a copy of `vals` (one row number per joined table, i.e.
 * wv->table_count entries) to the result set.  The reorder array grows
 * geometrically as needed.
 *
 * Returns ERROR_SUCCESS, or ERROR_OUTOFMEMORY when the array cannot grow or
 * the entry cannot be allocated; the existing rows are kept in that case.
 */
static UINT add_row(MSIWHEREVIEW *wv, UINT vals[])
{
    struct row_entry *new;

    if (wv->reorder_size <= wv->row_count)
    {
        struct row_entry **new_reorder;
        /* Doubling a zero-sized (never initialised) array would "grow" it to
         * zero slots and the store below would write out of bounds, so start
         * from INITIAL_REORDER_SIZE in that case. */
        UINT newsize = wv->reorder_size ? wv->reorder_size * 2 : INITIAL_REORDER_SIZE;

        /* refuse sizes whose doubling wrapped or whose byte count overflows */
        if (newsize <= wv->reorder_size ||
            newsize > (size_t)-1 / sizeof(*new_reorder))
            return ERROR_OUTOFMEMORY;

        new_reorder = realloc(wv->reorder, newsize * sizeof(*new_reorder));
        if (!new_reorder)
            return ERROR_OUTOFMEMORY;
        memset(new_reorder + wv->reorder_size, 0,
               (newsize - wv->reorder_size) * sizeof(*new_reorder));

        wv->reorder = new_reorder;
        wv->reorder_size = newsize;
    }

    /* row_entry ends in a one-element array; allocate room for table_count values */
    new = malloc(offsetof(struct row_entry, values[wv->table_count]));
    if (!new)
        return ERROR_OUTOFMEMORY;

    memcpy(new->values, vals, wv->table_count * sizeof(UINT));
    new->wv = wv;

    wv->reorder[wv->row_count++] = new;

    return ERROR_SUCCESS;
}
155 
/*
 * Map a 1-based column number of the whole view to the joined table that
 * owns it.  On success *table_col receives the 1-based column number inside
 * that table.  Returns NULL when `col` is 0 or exceeds wv->col_count.
 */
static struct join_table *find_table(MSIWHEREVIEW *wv, UINT col, UINT *table_col)
{
    struct join_table *owner;

    if (col < 1 || col > wv->col_count)
        return NULL;

    /* skip whole tables until `col` falls inside the current one */
    for (owner = wv->tables; col > owner->col_count; owner = owner->next)
    {
        col -= owner->col_count;
        assert(owner->next);
    }

    *table_col = col;
    return owner;
}
173 
/*
 * Resolve an unparsed column reference (optional table name plus column name)
 * against the joined tables, searching them in list order.
 *
 * On success the union is rewritten in its parsed form (1-based column index
 * and owning join_table), and *column_type holds that column's type.  Returns
 * the error from get_column_info if a lookup fails, or ERROR_BAD_QUERY_SYNTAX
 * if no table has a matching column.  Expects wv->tables to be non-NULL.
 */
static UINT parse_column(MSIWHEREVIEW *wv, union ext_column *column,
                         UINT *column_type)
{
    const WCHAR *wanted_table = column->unparsed.table;
    const WCHAR *wanted_column = column->unparsed.column;
    struct join_table *candidate = wv->tables;

    for (;;)
    {
        BOOL table_matches = TRUE;
        UINT idx, res;

        /* a qualified reference only considers the table with that name */
        if (wanted_table)
        {
            LPCWSTR name_of_table;

            res = candidate->view->ops->get_column_info(candidate->view, 1, NULL, NULL,
                                                        NULL, &name_of_table);
            if (res != ERROR_SUCCESS)
                return res;
            table_matches = !wcscmp(name_of_table, wanted_table);
        }

        if (table_matches)
        {
            for (idx = 1; idx <= candidate->col_count; idx++)
            {
                LPCWSTR name_of_column;

                res = candidate->view->ops->get_column_info(candidate->view, idx, &name_of_column,
                                                            column_type, NULL, NULL);
                if (res != ERROR_SUCCESS)
                    return res;

                if (!wcscmp(name_of_column, wanted_column))
                {
                    column->parsed.column = idx;
                    column->parsed.table = candidate;
                    return ERROR_SUCCESS;
                }
            }
        }

        candidate = candidate->next;
        if (!candidate)
            break;
    }

    WARN("Couldn't find column %s.%s\n", debugstr_w( wanted_table ), debugstr_w( wanted_column ) );
    return ERROR_BAD_QUERY_SYNTAX;
}
215 
/*
 * Fetch the integer in view column `col` (1-based across all joined tables)
 * of result row `row`, delegating to the table that owns the column.
 */
static UINT WHERE_fetch_int( struct tagMSIVIEW *view, UINT row, UINT col, UINT *val )
{
    MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
    struct join_table *owner;
    UINT *row_numbers;
    UINT res;

    TRACE("%p %d %d %p\n", wv, row, col, val );

    if (wv->tables == NULL)
        return ERROR_FUNCTION_FAILED;

    res = find_row(wv, row, &row_numbers);
    if (res == ERROR_SUCCESS)
    {
        owner = find_table(wv, col, &col);
        res = owner ? owner->view->ops->fetch_int(owner->view, row_numbers[owner->table_index], col, val)
                    : ERROR_FUNCTION_FAILED;
    }
    return res;
}
238 
/*
 * Fetch the stream in view column `col` (1-based across all joined tables)
 * of result row `row`, delegating to the table that owns the column.
 */
static UINT WHERE_fetch_stream( struct tagMSIVIEW *view, UINT row, UINT col, IStream **stm )
{
    MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
    struct join_table *owner;
    UINT *row_numbers;
    UINT res;

    TRACE("%p %d %d %p\n", wv, row, col, stm );

    if (wv->tables == NULL)
        return ERROR_FUNCTION_FAILED;

    res = find_row(wv, row, &row_numbers);
    if (res == ERROR_SUCCESS)
    {
        owner = find_table(wv, col, &col);
        res = owner ? owner->view->ops->fetch_stream( owner->view, row_numbers[owner->table_index], col, stm )
                    : ERROR_FUNCTION_FAILED;
    }
    return res;
}
261 
/*
 * Store integer `val` into view column `col` (1-based across all joined
 * tables) of result row `row`, delegating to the table that owns the column.
 */
static UINT WHERE_set_int(struct tagMSIVIEW *view, UINT row, UINT col, int val)
{
    MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
    struct join_table *table;
    UINT *rows;
    UINT r;

    TRACE("view %p, row %u, col %u, val %d.\n", wv, row, col, val );

    /* same guard as the fetch accessors: find_table() walks wv->tables */
    if (!wv->tables)
        return ERROR_FUNCTION_FAILED;

    r = find_row(wv, row, &rows);
    if (r != ERROR_SUCCESS)
        return r;

    table = find_table(wv, col, &col);
    if (!table)
        return ERROR_FUNCTION_FAILED;

    return table->view->ops->set_int(table->view, rows[table->table_index], col, val);
}
281 
/*
 * Store string `val` (`len` characters) into view column `col` (1-based
 * across all joined tables) of result row `row`.  The write is delegated to
 * the table that owns the column.
 */
static UINT WHERE_set_string(struct tagMSIVIEW *view, UINT row, UINT col, const WCHAR *val, int len)
{
    MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
    struct join_table *table;
    UINT *rows;
    UINT r;

    TRACE("view %p, row %u, col %u, val %s.\n", wv, row, col, debugstr_wn(val, len));

    /* same guard as the fetch accessors: find_table() walks wv->tables */
    if (!wv->tables)
        return ERROR_FUNCTION_FAILED;

    r = find_row(wv, row, &rows);
    if (r != ERROR_SUCCESS)
        return r;

    table = find_table(wv, col, &col);
    if (!table)
        return ERROR_FUNCTION_FAILED;

    return table->view->ops->set_string(table->view, rows[table->table_index], col, val, len);
}
301 
/*
 * Store `stream` into view column `col` (1-based across all joined tables)
 * of result row `row`, delegating to the table that owns the column.
 */
static UINT WHERE_set_stream(MSIVIEW *view, UINT row, UINT col, IStream *stream)
{
    MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
    struct join_table *table;
    UINT *rows;
    UINT r;

    TRACE("view %p, row %u, col %u, stream %p.\n", wv, row, col, stream);

    /* same guard as the fetch accessors: find_table() walks wv->tables */
    if (!wv->tables)
        return ERROR_FUNCTION_FAILED;

    r = find_row(wv, row, &rows);
    if (r != ERROR_SUCCESS)
        return r;

    table = find_table(wv, col, &col);
    if (!table)
        return ERROR_FUNCTION_FAILED;

    return table->view->ops->set_stream(table->view, rows[table->table_index], col, stream);
}
321 
/*
 * Return the `count` bits of `mask` starting at bit `offset`.  Bit positions
 * at or beyond the width of UINT read as zero, so no shift is ever by the
 * full type width or more (undefined behaviour in C).
 */
static inline UINT where_mask_slice(UINT mask, UINT offset, UINT count)
{
    const UINT width = sizeof(UINT) * 8;

    if (offset >= width)
        return 0;
    mask >>= offset;
    if (count < width)
        mask &= (1u << count) - 1;
    return mask;
}

/*
 * Update the fields of result row `row` selected by `mask` from record `rec`.
 * Bit i-1 of `mask` selects view column i.  The mask is split per joined
 * table, and each table receives a reduced record holding only its own
 * columns.
 *
 * Returns ERROR_INVALID_PARAMETER if the mask names columns the view does not
 * have, and ERROR_FUNCTION_FAILED if it selects a primary-key column (nothing
 * is written in either case).  Otherwise stops at, and returns, the first
 * error from copying fields or from a table's set_row.
 */
static UINT WHERE_set_row( struct tagMSIVIEW *view, UINT row, MSIRECORD *rec, UINT mask )
{
    MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
    struct join_table *table;
    UINT i, r, offset;
    UINT *rows;

    TRACE("%p %d %p %08x\n", wv, row, rec, mask );

    if( !wv->tables )
         return ERROR_FUNCTION_FAILED;

    r = find_row(wv, row, &rows);
    if (r != ERROR_SUCCESS)
        return r;

    /* bits above the last column are invalid; with 32+ columns every bit is valid */
    if (wv->col_count < 32 && mask >= (1u << wv->col_count))
        return ERROR_INVALID_PARAMETER;

    /* first pass: refuse the whole update if any selected column is a key */
    for (table = wv->tables, offset = 0; table; offset += table->col_count, table = table->next)
    {
        UINT table_mask = where_mask_slice(mask, offset, table->col_count);

        for (i = 0; table_mask && i < table->col_count && i < 32; i++)
        {
            UINT type;

            if (!(table_mask & (1u << i)))
                continue;
            r = table->view->ops->get_column_info(table->view, i + 1, NULL,
                                                  &type, NULL, NULL );
            if (r != ERROR_SUCCESS)
                return r;
            if (type & MSITYPE_KEY)
                return ERROR_FUNCTION_FAILED;
        }
    }

    /* second pass: hand each table the slice of the record and mask it owns */
    r = ERROR_SUCCESS;
    for (table = wv->tables, offset = 0; table; offset += table->col_count, table = table->next)
    {
        const UINT col_count = table->col_count;
        UINT reduced_mask = where_mask_slice(mask, offset, col_count);
        MSIRECORD *reduced;

        if (!reduced_mask)
            continue;

        reduced = MSI_CreateRecord(col_count);
        if (!reduced)
            return ERROR_FUNCTION_FAILED;

        for (i = 1; i <= col_count; i++)
        {
            r = MSI_RecordCopyField(rec, i + offset, reduced, i);
            if (r != ERROR_SUCCESS)
                break;
        }

        if (r == ERROR_SUCCESS)
            r = table->view->ops->set_row(table->view, rows[table->table_index], reduced, reduced_mask);

        msiobj_release(&reduced->hdr);

        /* do not let a later table's success overwrite this failure */
        if (r != ERROR_SUCCESS)
            break;
    }
    return r;
}
396 
/*
 * Delete result row `row` from the underlying table.  Only supported when the
 * view covers a single table; joins return ERROR_CALL_NOT_IMPLEMENTED.
 */
static UINT WHERE_delete_row(struct tagMSIVIEW *view, UINT row)
{
    MSIWHEREVIEW *wv = (MSIWHEREVIEW *)view;
    struct join_table *only = wv->tables;
    UINT *row_numbers;
    UINT res;

    TRACE("(%p %d)\n", view, row);

    if (!only)
        return ERROR_FUNCTION_FAILED;

    res = find_row(wv, row, &row_numbers);
    if (res != ERROR_SUCCESS)
        return res;

    if (wv->table_count > 1)
        return ERROR_CALL_NOT_IMPLEMENTED;

    return only->view->ops->delete_row(only->view, row_numbers[0]);
}
417 
/*
 * Evaluate a binary integer / logical expression (EXPR_COMPLEX) into *val.
 *
 * Either operand may report ERROR_CONTINUE, meaning it refers to a joined
 * table whose row has not been chosen yet.  The result is then decided early
 * only where the known operand already fixes it: FALSE AND x is FALSE, and
 * TRUE OR x is TRUE.  Otherwise *val is set to TRUE and ERROR_CONTINUE is
 * returned so the caller keeps descending into the join.
 *
 * Returns a UINT status like the other evaluators: ERROR_SUCCESS,
 * ERROR_CONTINUE, the first hard error from an operand, or
 * ERROR_FUNCTION_FAILED for an unknown operator.
 */
static UINT INT_evaluate_binary( MSIWHEREVIEW *wv, const UINT rows[],
                                 const struct complex_expr *expr, INT *val, MSIRECORD *record )
{
    UINT rl, rr;
    INT lval, rval;

    rl = WHERE_evaluate(wv, rows, expr->left, &lval, record);
    if (rl != ERROR_SUCCESS && rl != ERROR_CONTINUE)
        return rl;
    rr = WHERE_evaluate(wv, rows, expr->right, &rval, record);
    if (rr != ERROR_SUCCESS && rr != ERROR_CONTINUE)
        return rr;

    if (rl == ERROR_CONTINUE || rr == ERROR_CONTINUE)
    {
        if (rl == rr)
        {
            *val = TRUE;
            return ERROR_CONTINUE;
        }

        /* exactly one side is unknown here; only the other side's value is read */
        if (expr->op == OP_AND)
        {
            if ((rl == ERROR_CONTINUE && !rval) || (rr == ERROR_CONTINUE && !lval))
            {
                *val = FALSE;
                return ERROR_SUCCESS;
            }
        }
        else if (expr->op == OP_OR)
        {
            if ((rl == ERROR_CONTINUE && rval) || (rr == ERROR_CONTINUE && lval))
            {
                *val = TRUE;
                return ERROR_SUCCESS;
            }
        }

        *val = TRUE;
        return ERROR_CONTINUE;
    }

    switch( expr->op )
    {
    case OP_EQ:
        *val = ( lval == rval );
        break;
    case OP_AND:
        *val = ( lval && rval );
        break;
    case OP_OR:
        *val = ( lval || rval );
        break;
    case OP_GT:
        *val = ( lval > rval );
        break;
    case OP_LT:
        *val = ( lval < rval );
        break;
    case OP_LE:
        *val = ( lval <= rval );
        break;
    case OP_GE:
        *val = ( lval >= rval );
        break;
    case OP_NE:
        *val = ( lval != rval );
        break;
    default:
        ERR("Unknown operator %d\n", expr->op );
        return ERROR_FUNCTION_FAILED;
    }

    return ERROR_SUCCESS;
}
493 
/*
 * Read the raw value of parsed column `expr` for the current row combination
 * `rows`.  When the owning table has no row chosen yet (INVALID_ROW_INDEX),
 * sets *val to 1 and returns ERROR_CONTINUE without touching the table.
 */
static inline UINT expr_fetch_value(const union ext_column *expr, const UINT rows[], UINT *val)
{
    const struct join_table *owner = expr->parsed.table;
    UINT current = rows[owner->table_index];

    if (current != INVALID_ROW_INDEX)
        return owner->view->ops->fetch_int(owner->view, current,
                                           expr->parsed.column, val);

    *val = 1;
    return ERROR_CONTINUE;
}
506 
507 
/*
 * Evaluate an IS NULL / IS NOT NULL test on the left operand's column.
 * A raw value of 0 counts as NULL.  Status codes from fetching the column
 * (including ERROR_CONTINUE) are returned without setting *val.
 */
static UINT INT_evaluate_unary( MSIWHEREVIEW *wv, const UINT rows[],
                                const struct complex_expr *expr, INT *val, MSIRECORD *record )
{
    UINT operand;
    UINT res = expr_fetch_value(&expr->left->u.column, rows, &operand);

    if (res != ERROR_SUCCESS)
        return res;

    if (expr->op == OP_ISNULL)
    {
        *val = !operand;
    }
    else if (expr->op == OP_NOTNULL)
    {
        *val = operand;
    }
    else
    {
        ERR("Unknown operator %d\n", expr->op );
        return ERROR_FUNCTION_FAILED;
    }
    return ERROR_SUCCESS;
}
532 
/*
 * Produce the string value of a string-typed operand into *str.  The operand
 * may be a string column (looked up in the database string table), a literal,
 * or a wildcard (taken from the next field of `record`).  *str is NULL when a
 * column value could not be fetched or the expression type is invalid.
 */
static UINT STRING_evaluate( MSIWHEREVIEW *wv, const UINT rows[],
                                     const struct expr *expr,
                                     const MSIRECORD *record,
                                     const WCHAR **str )
{
    UINT id = 0, res;

    if (expr->type == EXPR_COL_NUMBER_STRING)
    {
        res = expr_fetch_value(&expr->u.column, rows, &id);
        *str = (res == ERROR_SUCCESS) ? msi_string_lookup(wv->db->strings, id, NULL) : NULL;
        return res;
    }

    if (expr->type == EXPR_SVAL)
    {
        *str = expr->u.sval;
        return ERROR_SUCCESS;
    }

    if (expr->type == EXPR_WILDCARD)
    {
        *str = MSI_RecordGetString(record, ++wv->rec_index);
        return ERROR_SUCCESS;
    }

    ERR("Invalid expression type\n");
    *str = NULL;
    return ERROR_FUNCTION_FAILED;
}
566 
/*
 * Evaluate a string comparison (EXPR_STRCMP, operators OP_EQ / OP_NE) into
 * *val.  NULL and empty strings compare equal to each other.  *val is preset
 * to TRUE, so an ERROR_CONTINUE result (an operand from a table with no row
 * chosen yet) lets the caller keep descending into the join.
 *
 * Returns ERROR_SUCCESS, ERROR_CONTINUE, or the error reported while
 * evaluating an operand.  Operand errors are propagated like in the other
 * evaluators rather than being compared as a NULL string.
 */
static UINT STRCMP_Evaluate( MSIWHEREVIEW *wv, const UINT rows[], const struct complex_expr *expr,
                             INT *val, const MSIRECORD *record )
{
    int sr;
    const WCHAR *l_str, *r_str;
    UINT r;

    *val = TRUE;
    r = STRING_evaluate(wv, rows, expr->left, record, &l_str);
    if (r != ERROR_SUCCESS)
        return r;
    r = STRING_evaluate(wv, rows, expr->right, record, &r_str);
    if (r != ERROR_SUCCESS)
        return r;

    if( l_str == r_str ||
        ((!l_str || !*l_str) && (!r_str || !*r_str)) )
        sr = 0;
    else if( l_str && ! r_str )
        sr = 1;
    else if( r_str && ! l_str )
        sr = -1;
    else
        sr = wcscmp( l_str, r_str );

    *val = ( expr->op == OP_EQ && ( sr == 0 ) ) ||
           ( expr->op == OP_NE && ( sr != 0 ) );

    return ERROR_SUCCESS;
}
597 
/*
 * Evaluate condition `cond` for the row combination `rows` into *val.
 * A NULL condition matches everything.  Integer columns are converted by
 * subtracting 0x8000 (EXPR_COL_NUMBER) or 0x80000000 (EXPR_COL_NUMBER32).
 * Wildcards consume the next field of `record`.
 *
 * Returns ERROR_SUCCESS, ERROR_CONTINUE (an operand refers to a table with no
 * row chosen yet), or an error.
 */
static UINT WHERE_evaluate( MSIWHEREVIEW *wv, const UINT rows[],
                            struct expr *cond, INT *val, MSIRECORD *record )
{
    UINT raw, res;

    if (cond == NULL)
    {
        *val = TRUE;
        return ERROR_SUCCESS;
    }

    switch (cond->type)
    {
    case EXPR_COL_NUMBER:
    case EXPR_COL_NUMBER32:
        res = expr_fetch_value(&cond->u.column, rows, &raw);
        if (res == ERROR_SUCCESS)
        {
            if (cond->type == EXPR_COL_NUMBER)
                *val = raw - 0x8000;
            else
                *val = raw - 0x80000000;
        }
        return res;

    case EXPR_UVAL:
        *val = cond->u.uval;
        return ERROR_SUCCESS;

    case EXPR_COMPLEX:
        return INT_evaluate_binary( wv, rows, &cond->u.expr, val, record );

    case EXPR_UNARY:
        return INT_evaluate_unary( wv, rows, &cond->u.expr, val, record );

    case EXPR_STRCMP:
        return STRCMP_Evaluate( wv, rows, &cond->u.expr, val, record );

    case EXPR_WILDCARD:
        *val = MSI_RecordGetInteger( record, ++wv->rec_index );
        return ERROR_SUCCESS;

    default:
        ERR("Invalid expression type\n");
        return ERROR_FUNCTION_FAILED;
    }
}
649 
/*
 * Recursively enumerate row combinations of the joined tables in the order
 * given by the NULL-terminated array `tables`.  Every combination for which
 * wv->cond evaluates to true is added to the result set.
 *
 * table_rows[] holds the row currently chosen for each table (indexed by
 * table_index).  Tables not reached yet hold INVALID_ROW_INDEX, so operands
 * referring to them evaluate to ERROR_CONTINUE.  On return this level's slot
 * is reset to INVALID_ROW_INDEX.
 *
 * Returns the status of the last evaluation at this level.  Otherwise it
 * returns the first failure from evaluation, a deeper level, or add_row
 * (e.g. ERROR_OUTOFMEMORY, which previously dropped rows silently).
 */
static UINT check_condition( MSIWHEREVIEW *wv, MSIRECORD *record, struct join_table **tables,
                             UINT table_rows[] )
{
    UINT r = ERROR_FUNCTION_FAILED;
    INT val;

    for (table_rows[(*tables)->table_index] = 0;
         table_rows[(*tables)->table_index] < (*tables)->row_count;
         table_rows[(*tables)->table_index]++)
    {
        val = 0;
        wv->rec_index = 0;
        r = WHERE_evaluate( wv, table_rows, wv->cond, &val, record );
        if (r != ERROR_SUCCESS && r != ERROR_CONTINUE)
            break;
        if (!val)
            continue;

        if (tables[1])
            r = check_condition(wv, record, tables + 1, table_rows);
        else if (r == ERROR_SUCCESS)
            r = add_row(wv, table_rows);

        if (r != ERROR_SUCCESS)
            break;
    }
    table_rows[(*tables)->table_index] = INVALID_ROW_INDEX;
    return r;
}
684 
compare_entry(const void * left,const void * right)685 static int __cdecl compare_entry( const void *left, const void *right )
686 {
687     const struct row_entry *le = *(const struct row_entry **)left;
688     const struct row_entry *re = *(const struct row_entry **)right;
689     const MSIWHEREVIEW *wv = le->wv;
690     MSIORDERINFO *order = wv->order_info;
691     UINT i, j, r, l_val, r_val;
692 
693     assert(le->wv == re->wv);
694 
695     if (order)
696     {
697         for (i = 0; i < order->col_count; i++)
698         {
699             const union ext_column *column = &order->columns[i];
700 
701             r = column->parsed.table->view->ops->fetch_int(column->parsed.table->view,
702                           le->values[column->parsed.table->table_index],
703                           column->parsed.column, &l_val);
704             if (r != ERROR_SUCCESS)
705             {
706                 order->error = r;
707                 return 0;
708             }
709 
710             r = column->parsed.table->view->ops->fetch_int(column->parsed.table->view,
711                           re->values[column->parsed.table->table_index],
712                           column->parsed.column, &r_val);
713             if (r != ERROR_SUCCESS)
714             {
715                 order->error = r;
716                 return 0;
717             }
718 
719             if (l_val != r_val)
720                 return l_val < r_val ? -1 : 1;
721         }
722     }
723 
724     for (j = 0; j < wv->table_count; j++)
725     {
726         if (le->values[j] != re->values[j])
727             return le->values[j] < re->values[j] ? -1 : 1;
728     }
729     return 0;
730 }
731 
/*
 * Append `elem` to the NULL-terminated `array` unless it is already present.
 * The caller must provide a free slot before the terminating NULL.
 */
static void add_to_array( struct join_table **array, struct join_table *elem )
{
    struct join_table **slot;

    for (slot = array; *slot; slot++)
    {
        if (*slot == elem)
            return;
    }
    *slot = elem;
}
739 
in_array(struct join_table ** array,struct join_table * elem)740 static BOOL in_array( struct join_table **array, struct join_table *elem )
741 {
742     while (*array && *array != elem)
743         array++;
744     return *array != NULL;
745 }
746 
747 #define CONST_EXPR 1 /* comparison to a constant value */
748 #define JOIN_TO_CONST_EXPR 0x10000 /* comparison to a table involved with
749                                       a CONST_EXPR comaprison */
750 
reorder_check(const struct expr * expr,struct join_table ** ordered_tables,BOOL process_joins,struct join_table ** lastused)751 static UINT reorder_check( const struct expr *expr, struct join_table **ordered_tables,
752                            BOOL process_joins, struct join_table **lastused )
753 {
754     UINT res = 0;
755 
756     switch (expr->type)
757     {
758         case EXPR_WILDCARD:
759         case EXPR_SVAL:
760         case EXPR_UVAL:
761             return 0;
762         case EXPR_COL_NUMBER:
763         case EXPR_COL_NUMBER32:
764         case EXPR_COL_NUMBER_STRING:
765             if (in_array(ordered_tables, expr->u.column.parsed.table))
766                 return JOIN_TO_CONST_EXPR;
767             *lastused = expr->u.column.parsed.table;
768             return CONST_EXPR;
769         case EXPR_STRCMP:
770         case EXPR_COMPLEX:
771             res = reorder_check(expr->u.expr.right, ordered_tables, process_joins, lastused);
772             /* fall through */
773         case EXPR_UNARY:
774             res += reorder_check(expr->u.expr.left, ordered_tables, process_joins, lastused);
775             if (res == 0)
776                 return 0;
777             if (res == CONST_EXPR)
778                 add_to_array(ordered_tables, *lastused);
779             if (process_joins && res == JOIN_TO_CONST_EXPR + CONST_EXPR)
780                 add_to_array(ordered_tables, *lastused);
781             return res;
782         default:
783             ERR("Unknown expr type: %i\n", expr->type);
784             assert(0);
785             return 0x1000000;
786     }
787 }
788 
789 /* reorders the tablelist in a way to evaluate the condition as fast as possible */
ordertables(MSIWHEREVIEW * wv)790 static struct join_table **ordertables( MSIWHEREVIEW *wv )
791 {
792     struct join_table *table, **tables;
793 
794     tables = calloc(wv->table_count + 1, sizeof(*tables));
795 
796     if (wv->cond)
797     {
798         table = NULL;
799         reorder_check(wv->cond, tables, FALSE, &table);
800         table = NULL;
801         reorder_check(wv->cond, tables, TRUE, &table);
802     }
803 
804     table = wv->tables;
805     while (table)
806     {
807         add_to_array(tables, table);
808         table = table->next;
809     }
810     return tables;
811 }
812 
WHERE_execute(struct tagMSIVIEW * view,MSIRECORD * record)813 static UINT WHERE_execute( struct tagMSIVIEW *view, MSIRECORD *record )
814 {
815     MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
816     UINT r;
817     struct join_table *table = wv->tables;
818     UINT *rows;
819     struct join_table **ordered_tables;
820     UINT i = 0;
821 
822     TRACE("%p %p\n", wv, record);
823 
824     if( !table )
825          return ERROR_FUNCTION_FAILED;
826 
827     r = init_reorder(wv);
828     if (r != ERROR_SUCCESS)
829         return r;
830 
831     do
832     {
833         table->view->ops->execute(table->view, NULL);
834 
835         r = table->view->ops->get_dimensions(table->view, &table->row_count, NULL);
836         if (r != ERROR_SUCCESS)
837         {
838             ERR("failed to get table dimensions\n");
839             return r;
840         }
841 
842         /* each table must have at least one row */
843         if (table->row_count == 0)
844             return ERROR_SUCCESS;
845     }
846     while ((table = table->next));
847 
848     ordered_tables = ordertables( wv );
849 
850     rows = malloc(wv->table_count * sizeof(*rows));
851     for (i = 0; i < wv->table_count; i++)
852         rows[i] = INVALID_ROW_INDEX;
853 
854     r =  check_condition(wv, record, ordered_tables, rows);
855 
856     if (wv->order_info)
857         wv->order_info->error = ERROR_SUCCESS;
858 
859     qsort(wv->reorder, wv->row_count, sizeof(struct row_entry *), compare_entry);
860 
861     if (wv->order_info)
862         r = wv->order_info->error;
863 
864     free(rows);
865     free(ordered_tables);
866     return r;
867 }
868 
WHERE_close(struct tagMSIVIEW * view)869 static UINT WHERE_close( struct tagMSIVIEW *view )
870 {
871     MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
872     struct join_table *table = wv->tables;
873 
874     TRACE("%p\n", wv );
875 
876     if (!table)
877         return ERROR_FUNCTION_FAILED;
878 
879     do
880         table->view->ops->close(table->view);
881     while ((table = table->next));
882 
883     return ERROR_SUCCESS;
884 }
885 
/*
 * Report the number of result rows and/or view columns.  A row count can
 * only be reported once the reorder array exists (it is created by
 * WHERE_execute); asking for it earlier fails without writing *cols.
 */
static UINT WHERE_get_dimensions( struct tagMSIVIEW *view, UINT *rows, UINT *cols )
{
    MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;

    TRACE("%p %p %p\n", wv, rows, cols );

    if (wv->tables == NULL)
         return ERROR_FUNCTION_FAILED;

    if (rows != NULL && wv->reorder == NULL)
        return ERROR_FUNCTION_FAILED;

    if (rows != NULL)
        *rows = wv->row_count;
    if (cols != NULL)
        *cols = wv->col_count;

    return ERROR_SUCCESS;
}
907 
/*
 * Describe view column `n` (1-based across all joined tables) by asking the
 * owning table's view about its corresponding local column.
 */
static UINT WHERE_get_column_info( struct tagMSIVIEW *view, UINT n, LPCWSTR *name,
                                   UINT *type, BOOL *temporary, LPCWSTR *table_name )
{
    MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
    struct join_table *owner;
    UINT local_col;

    TRACE("%p %d %p %p %p %p\n", wv, n, name, type, temporary, table_name );

    if (wv->tables == NULL)
         return ERROR_FUNCTION_FAILED;

    owner = find_table(wv, n, &local_col);
    if (owner == NULL)
        return ERROR_FUNCTION_FAILED;

    return owner->view->ops->get_column_info(owner->view, local_col, name,
                                             type, temporary, table_name);
}
926 
/*
 * Find the result row whose view column 1 holds the string id of field 1 of
 * `rec`, storing its index in *row.  Returns the msi_string2id error if the
 * string is not in the string table, or ERROR_FUNCTION_FAILED if no row
 * matches.  Rows whose column 1 cannot be fetched are skipped; previously the
 * uninitialised value was compared.
 */
static UINT join_find_row( MSIWHEREVIEW *wv, MSIRECORD *rec, UINT *row )
{
    LPCWSTR str;
    UINT r, i, id, data;

    str = MSI_RecordGetString( rec, 1 );
    r = msi_string2id( wv->db->strings, str, -1, &id );
    if (r != ERROR_SUCCESS)
        return r;

    for (i = 0; i < wv->row_count; i++)
    {
        /* `data` is only meaningful when the fetch succeeded */
        if (WHERE_fetch_int( &wv->view, i, 1, &data ) != ERROR_SUCCESS)
            continue;

        if (data == id)
        {
            *row = i;
            return ERROR_SUCCESS;
        }
    }

    return ERROR_FUNCTION_FAILED;
}
950 
join_modify_update(struct tagMSIVIEW * view,MSIRECORD * rec)951 static UINT join_modify_update( struct tagMSIVIEW *view, MSIRECORD *rec )
952 {
953     MSIWHEREVIEW *wv = (MSIWHEREVIEW *)view;
954     UINT r, row, i, mask = 0;
955     MSIRECORD *current;
956 
957 
958     r = join_find_row( wv, rec, &row );
959     if (r != ERROR_SUCCESS)
960         return r;
961 
962     r = msi_view_get_row( wv->db, view, row, &current );
963     if (r != ERROR_SUCCESS)
964         return r;
965 
966     assert(MSI_RecordGetFieldCount(rec) == MSI_RecordGetFieldCount(current));
967 
968     for (i = MSI_RecordGetFieldCount(rec); i > 0; i--)
969     {
970         if (!MSI_RecordsAreFieldsEqual(rec, current, i))
971             mask |= 1 << (i - 1);
972     }
973      msiobj_release(&current->hdr);
974 
975     return WHERE_set_row( view, row, rec, mask );
976 }
977 
/*
 * MSIViewModify entry point.  A single-table view forwards every mode to the
 * table's own view, after translating the result row to the table's row
 * number (or -1 when `row` is out of range).  Joined views only support
 * MSIMODIFY_UPDATE; REFRESH reports ERROR_CALL_NOT_IMPLEMENTED, the other
 * known modes ERROR_FUNCTION_FAILED, and unknown modes
 * ERROR_INVALID_PARAMETER.
 */
static UINT WHERE_modify( struct tagMSIVIEW *view, MSIMODIFY eModifyMode,
                          MSIRECORD *rec, UINT row )
{
    MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
    struct join_table *first = wv->tables;

    TRACE("%p %d %p\n", wv, eModifyMode, rec);

    if (first == NULL)
        return ERROR_FUNCTION_FAILED;

    if (first->next == NULL)
    {
        UINT *row_numbers;
        UINT table_row = -1;

        if (find_row(wv, row, &row_numbers) == ERROR_SUCCESS)
            table_row = row_numbers[0];

        return first->view->ops->modify(first->view, eModifyMode, rec, table_row);
    }

    switch (eModifyMode)
    {
    case MSIMODIFY_UPDATE:
        return join_modify_update( view, rec );

    case MSIMODIFY_REFRESH:
        return ERROR_CALL_NOT_IMPLEMENTED;

    case MSIMODIFY_ASSIGN:
    case MSIMODIFY_DELETE:
    case MSIMODIFY_INSERT:
    case MSIMODIFY_INSERT_TEMPORARY:
    case MSIMODIFY_MERGE:
    case MSIMODIFY_REPLACE:
    case MSIMODIFY_SEEK:
    case MSIMODIFY_VALIDATE:
    case MSIMODIFY_VALIDATE_DELETE:
    case MSIMODIFY_VALIDATE_FIELD:
    case MSIMODIFY_VALIDATE_NEW:
        return ERROR_FUNCTION_FAILED;

    default:
        WARN("%p %d %p %u - unknown mode\n", view, eModifyMode, rec, row );
        return ERROR_INVALID_PARAMETER;
    }
}
1033 
WHERE_delete(struct tagMSIVIEW * view)1034 static UINT WHERE_delete( struct tagMSIVIEW *view )
1035 {
1036     MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
1037     struct join_table *table = wv->tables;
1038 
1039     TRACE("%p\n", wv );
1040 
1041     while(table)
1042     {
1043         struct join_table *next;
1044 
1045         table->view->ops->delete(table->view);
1046         table->view = NULL;
1047         next = table->next;
1048         free(table);
1049         table = next;
1050     }
1051     wv->tables = NULL;
1052     wv->table_count = 0;
1053 
1054     free_reorder(wv);
1055 
1056     free(wv->order_info);
1057     wv->order_info = NULL;
1058 
1059     msiobj_release( &wv->db->hdr );
1060     free(wv);
1061 
1062     return ERROR_SUCCESS;
1063 }
1064 
WHERE_sort(struct tagMSIVIEW * view,column_info * columns)1065 static UINT WHERE_sort(struct tagMSIVIEW *view, column_info *columns)
1066 {
1067     MSIWHEREVIEW *wv = (MSIWHEREVIEW *)view;
1068     struct join_table *table = wv->tables;
1069     column_info *column = columns;
1070     MSIORDERINFO *orderinfo;
1071     UINT r, count = 0;
1072     UINT i;
1073 
1074     TRACE("%p %p\n", view, columns);
1075 
1076     if (!table)
1077         return ERROR_FUNCTION_FAILED;
1078 
1079     while (column)
1080     {
1081         count++;
1082         column = column->next;
1083     }
1084 
1085     if (count == 0)
1086         return ERROR_SUCCESS;
1087 
1088     orderinfo = malloc(offsetof(MSIORDERINFO, columns[count]));
1089     if (!orderinfo)
1090         return ERROR_OUTOFMEMORY;
1091 
1092     orderinfo->col_count = count;
1093 
1094     column = columns;
1095 
1096     for (i = 0; i < count; i++)
1097     {
1098         orderinfo->columns[i].unparsed.column = column->column;
1099         orderinfo->columns[i].unparsed.table = column->table;
1100 
1101         r = parse_column(wv, &orderinfo->columns[i], NULL);
1102         if (r != ERROR_SUCCESS)
1103             goto error;
1104     }
1105 
1106     wv->order_info = orderinfo;
1107 
1108     return ERROR_SUCCESS;
1109 error:
1110     free(orderinfo);
1111     return r;
1112 }
1113 
/*
 * Operation table for WHERE views.
 *
 * The initializer is positional, so the entries must stay in exactly the order
 * in which the members of MSIVIEWOPS are declared (query.h). NULL entries are
 * operations this view type does not provide.
 * NOTE(review): the member names in the trailing comments are inferred from
 * the neighbouring WHERE_* entries and the usual MSIVIEWOPS layout; confirm
 * them against query.h.
 */
static const MSIVIEWOPS where_ops =
{
    WHERE_fetch_int,
    WHERE_fetch_stream,
    WHERE_set_int,
    WHERE_set_string,
    WHERE_set_stream,
    WHERE_set_row,
    NULL,               /* presumably insert_row: not supported */
    WHERE_delete_row,
    WHERE_execute,
    WHERE_close,
    WHERE_get_dimensions,
    WHERE_get_column_info,
    WHERE_modify,
    WHERE_delete,
    NULL,               /* presumably add_ref */
    NULL,               /* presumably release */
    NULL,               /* presumably add_column */
    WHERE_sort,
    NULL,               /* presumably drop */
};
1136 
/*
 * Validate a WHERE condition tree and rewrite its nodes into their evaluable
 * forms.
 *
 * Rewrites performed:
 *  - EXPR_COLUMN nodes are resolved with parse_column and retyped by column
 *    type: EXPR_COL_NUMBER_STRING for string columns, EXPR_COL_NUMBER32 when
 *    (type & 0xff) == 4, EXPR_COL_NUMBER otherwise.
 *  - EXPR_IVAL nodes become EXPR_UVAL.
 *  - An EXPR_COMPLEX node with a string operand becomes EXPR_STRCMP. Only
 *    OP_EQ and OP_NE are accepted there.
 *
 * *valid receives TRUE when the whole tree is usable, FALSE otherwise.
 * Returns ERROR_INVALID_PARAMETER for an unsupported string operator or for a
 * unary expression whose operand is not a column. An unresolvable column
 * yields ERROR_SUCCESS with *valid == FALSE.
 */
static UINT WHERE_VerifyCondition( MSIWHEREVIEW *wv, struct expr *cond,
                                   UINT *valid )
{
    UINT r;

    switch( cond->type )
    {
    case EXPR_COLUMN:
    {
        UINT type;

        *valid = FALSE;

        /* A parse_column failure is reported only through *valid; the
         * function still returns ERROR_SUCCESS. */
        r = parse_column(wv, &cond->u.column, &type);
        if (r != ERROR_SUCCESS)
            break;

        if (type&MSITYPE_STRING)
            cond->type = EXPR_COL_NUMBER_STRING;
        else if ((type&0xff) == 4)
            cond->type = EXPR_COL_NUMBER32;
        else
            cond->type = EXPR_COL_NUMBER;

        *valid = TRUE;
        break;
    }
    case EXPR_COMPLEX:
        r = WHERE_VerifyCondition( wv, cond->u.expr.left, valid );
        if( r != ERROR_SUCCESS )
            return r;
        if( !*valid )
            return ERROR_SUCCESS;
        r = WHERE_VerifyCondition( wv, cond->u.expr.right, valid );
        if( r != ERROR_SUCCESS )
            return r;
        /* Mirror the left-hand check: an unresolved right operand still has
         * its raw node type, so the type inspection below must not run on
         * it (it could retype this node or report the wrong error). */
        if( !*valid )
            return ERROR_SUCCESS;

        /* check the type of the comparison */
        if( ( cond->u.expr.left->type == EXPR_SVAL ) ||
            ( cond->u.expr.left->type == EXPR_COL_NUMBER_STRING ) ||
            ( cond->u.expr.right->type == EXPR_SVAL ) ||
            ( cond->u.expr.right->type == EXPR_COL_NUMBER_STRING ) )
        {
            switch( cond->u.expr.op )
            {
            case OP_EQ:
            case OP_NE:
                break;
            default:
                *valid = FALSE;
                return ERROR_INVALID_PARAMETER;
            }

            /* FIXME: check we're comparing a string to a column */

            cond->type = EXPR_STRCMP;
        }

        break;
    case EXPR_UNARY:
        if ( cond->u.expr.left->type != EXPR_COLUMN )
        {
            *valid = FALSE;
            return ERROR_INVALID_PARAMETER;
        }
        r = WHERE_VerifyCondition( wv, cond->u.expr.left, valid );
        if( r != ERROR_SUCCESS )
            return r;
        break;
    case EXPR_IVAL:
        *valid = TRUE;
        cond->type = EXPR_UVAL;
        cond->u.uval = cond->u.ival;
        break;
    case EXPR_WILDCARD:
        *valid = TRUE;
        break;
    case EXPR_SVAL:
        *valid = TRUE;
        break;
    default:
        ERR("Invalid expression type\n");
        *valid = FALSE;
        break;
    }

    return ERROR_SUCCESS;
}
1225 
WHERE_CreateView(MSIDATABASE * db,MSIVIEW ** view,LPWSTR tables,struct expr * cond)1226 UINT WHERE_CreateView( MSIDATABASE *db, MSIVIEW **view, LPWSTR tables,
1227                        struct expr *cond )
1228 {
1229     MSIWHEREVIEW *wv = NULL;
1230     UINT r, valid = 0;
1231     WCHAR *ptr;
1232 
1233     TRACE("(%s)\n", debugstr_w(tables) );
1234 
1235     wv = calloc(1, sizeof *wv);
1236     if( !wv )
1237         return ERROR_FUNCTION_FAILED;
1238 
1239     /* fill the structure */
1240     wv->view.ops = &where_ops;
1241     msiobj_addref( &db->hdr );
1242     wv->db = db;
1243     wv->cond = cond;
1244 
1245     while (*tables)
1246     {
1247         struct join_table *table;
1248 
1249         if ((ptr = wcschr(tables, ' ')))
1250             *ptr = '\0';
1251 
1252         table = malloc(sizeof(*table));
1253         if (!table)
1254         {
1255             r = ERROR_OUTOFMEMORY;
1256             goto end;
1257         }
1258 
1259         r = TABLE_CreateView(db, tables, &table->view);
1260         if (r != ERROR_SUCCESS)
1261         {
1262             WARN("can't create table: %s\n", debugstr_w(tables));
1263             free(table);
1264             r = ERROR_BAD_QUERY_SYNTAX;
1265             goto end;
1266         }
1267 
1268         r = table->view->ops->get_dimensions(table->view, NULL,
1269                                              &table->col_count);
1270         if (r != ERROR_SUCCESS)
1271         {
1272             ERR("can't get table dimensions\n");
1273             table->view->ops->delete(table->view);
1274             free(table);
1275             goto end;
1276         }
1277 
1278         wv->col_count += table->col_count;
1279         table->table_index = wv->table_count++;
1280 
1281         table->next = wv->tables;
1282         wv->tables = table;
1283 
1284         if (!ptr)
1285             break;
1286 
1287         tables = ptr + 1;
1288     }
1289 
1290     if( cond )
1291     {
1292         r = WHERE_VerifyCondition( wv, cond, &valid );
1293         if( r != ERROR_SUCCESS )
1294             goto end;
1295         if( !valid ) {
1296             r = ERROR_FUNCTION_FAILED;
1297             goto end;
1298         }
1299     }
1300 
1301     *view = (MSIVIEW*) wv;
1302 
1303     return ERROR_SUCCESS;
1304 end:
1305     WHERE_delete(&wv->view);
1306 
1307     return r;
1308 }
1309