1 #include "pg_query.h"
2 #include "pg_query_internal.h"
3
4 #include "parser/parser.h"
5 #include "parser/scanner.h"
6 #include "parser/scansup.h"
7 #include "mb/pg_wchar.h"
8 #include "nodes/nodeFuncs.h"
9
/*
 * pgssLocationLen: one constant found in the query text.
 *
 * "location" is the byte offset at which the constant starts; "length" is
 * its length in bytes, or -1 when the entry should be skipped during
 * normalization (duplicate location, or the constant could not be found).
 */
struct pgssLocationLen
{
	int			location;
	int			length;
};
typedef struct pgssLocationLen pgssLocationLen;
18
19 /*
20 * Working state for constant tree walker
21 */
22 typedef struct pgssConstLocations
23 {
24 /* Array of locations of constants that should be removed */
25 pgssLocationLen *clocations;
26
27 /* Allocated length of clocations array */
28 int clocations_buf_size;
29
30 /* Current number of valid entries in clocations array */
31 int clocations_count;
32
33 /* highest Param id we've seen, in order to start normalization correctly */
34 int highest_extern_param_id;
35 } pgssConstLocations;
36
37 /*
38 * comp_location: comparator for qsorting pgssLocationLen structs by location
39 */
40 static int
comp_location(const void * a,const void * b)41 comp_location(const void *a, const void *b)
42 {
43 int l = ((const pgssLocationLen *) a)->location;
44 int r = ((const pgssLocationLen *) b)->location;
45
46 if (l < r)
47 return -1;
48 else if (l > r)
49 return +1;
50 else
51 return 0;
52 }
53
54 /*
55 * Given a valid SQL string and an array of constant-location records,
56 * fill in the textual lengths of those constants.
57 *
58 * The constants may use any allowed constant syntax, such as float literals,
59 * bit-strings, single-quoted strings and dollar-quoted strings. This is
60 * accomplished by using the public API for the core scanner.
61 *
62 * It is the caller's job to ensure that the string is a valid SQL statement
63 * with constants at the indicated locations. Since in practice the string
64 * has already been parsed, and the locations that the caller provides will
65 * have originated from within the authoritative parser, this should not be
66 * a problem.
67 *
68 * Duplicate constant pointers are possible, and will have their lengths
69 * marked as '-1', so that they are later ignored. (Actually, we assume the
70 * lengths were initialized as -1 to start with, and don't change them here.)
71 *
72 * N.B. There is an assumption that a '-' character at a Const location begins
73 * a negative numeric constant. This precludes there ever being another
74 * reason for a constant to start with a '-'.
75 */
76 static void
fill_in_constant_lengths(pgssConstLocations * jstate,const char * query)77 fill_in_constant_lengths(pgssConstLocations *jstate, const char *query)
78 {
79 pgssLocationLen *locs;
80 core_yyscan_t yyscanner;
81 core_yy_extra_type yyextra;
82 core_YYSTYPE yylval;
83 YYLTYPE yylloc;
84 int last_loc = -1;
85 int i;
86
87 /*
88 * Sort the records by location so that we can process them in order while
89 * scanning the query text.
90 */
91 if (jstate->clocations_count > 1)
92 qsort(jstate->clocations, jstate->clocations_count,
93 sizeof(pgssLocationLen), comp_location);
94 locs = jstate->clocations;
95
96 /* initialize the flex scanner --- should match raw_parser() */
97 yyscanner = scanner_init(query,
98 &yyextra,
99 ScanKeywords,
100 NumScanKeywords);
101
102 /* Search for each constant, in sequence */
103 for (i = 0; i < jstate->clocations_count; i++)
104 {
105 int loc = locs[i].location;
106 int tok;
107
108 Assert(loc >= 0);
109
110 if (loc <= last_loc)
111 continue; /* Duplicate constant, ignore */
112
113 /* Lex tokens until we find the desired constant */
114 for (;;)
115 {
116 tok = core_yylex(&yylval, &yylloc, yyscanner);
117
118 /* We should not hit end-of-string, but if we do, behave sanely */
119 if (tok == 0)
120 break; /* out of inner for-loop */
121
122 /*
123 * We should find the token position exactly, but if we somehow
124 * run past it, work with that.
125 */
126 if (yylloc >= loc)
127 {
128 if (query[loc] == '-')
129 {
130 /*
131 * It's a negative value - this is the one and only case
132 * where we replace more than a single token.
133 *
134 * Do not compensate for the core system's special-case
135 * adjustment of location to that of the leading '-'
136 * operator in the event of a negative constant. It is
137 * also useful for our purposes to start from the minus
138 * symbol. In this way, queries like "select * from foo
139 * where bar = 1" and "select * from foo where bar = -2"
140 * will have identical normalized query strings.
141 */
142 tok = core_yylex(&yylval, &yylloc, yyscanner);
143 if (tok == 0)
144 break; /* out of inner for-loop */
145 }
146
147 /*
148 * We now rely on the assumption that flex has placed a zero
149 * byte after the text of the current token in scanbuf.
150 */
151 locs[i].length = (int) strlen(yyextra.scanbuf + loc);
152
153 /* Quoted string with Unicode escapes
154 *
155 * The lexer consumes trailing whitespace in order to find UESCAPE, but if there
156 * is no UESCAPE it has still consumed it - don't include it in constant length.
157 */
158 if (locs[i].length > 4 && /* U&'' */
159 (yyextra.scanbuf[loc] == 'u' || yyextra.scanbuf[loc] == 'U') &&
160 yyextra.scanbuf[loc + 1] == '&' && yyextra.scanbuf[loc + 2] == '\'')
161 {
162 int j = locs[i].length - 1; /* Skip the \0 */
163 for (; j >= 0 && scanner_isspace(yyextra.scanbuf[loc + j]); j--) {}
164 locs[i].length = j + 1; /* Count the \0 */
165 }
166
167 break; /* out of inner for-loop */
168 }
169 }
170
171 /* If we hit end-of-string, give up, leaving remaining lengths -1 */
172 if (tok == 0)
173 break;
174
175 last_loc = loc;
176 }
177
178 scanner_finish(yyscanner);
179 }
180
181 /*
182 * Generate a normalized version of the query string that will be used to
183 * represent all similar queries.
184 *
185 * Note that the normalized representation may well vary depending on
186 * just which "equivalent" query is used to create the hashtable entry.
187 * We assume this is OK.
188 *
189 * *query_len_p contains the input string length, and is updated with
190 * the result string length (which cannot be longer) on exit.
191 *
192 * Returns a palloc'd string.
193 */
194 static char *
generate_normalized_query(pgssConstLocations * jstate,const char * query,int query_loc,int * query_len_p,int encoding)195 generate_normalized_query(pgssConstLocations *jstate, const char *query,
196 int query_loc, int *query_len_p, int encoding)
197 {
198 char *norm_query;
199 int query_len = *query_len_p;
200 int i,
201 norm_query_buflen, /* Space allowed for norm_query */
202 len_to_wrt, /* Length (in bytes) to write */
203 quer_loc = 0, /* Source query byte location */
204 n_quer_loc = 0, /* Normalized query byte location */
205 last_off = 0, /* Offset from start for previous tok */
206 last_tok_len = 0; /* Length (in bytes) of that tok */
207
208 /*
209 * Get constants' lengths (core system only gives us locations). Note
210 * this also ensures the items are sorted by location.
211 */
212 fill_in_constant_lengths(jstate, query);
213
214 /*
215 * Allow for $n symbols to be longer than the constants they replace.
216 * Constants must take at least one byte in text form, while a $n symbol
217 * certainly isn't more than 11 bytes, even if n reaches INT_MAX. We
218 * could refine that limit based on the max value of n for the current
219 * query, but it hardly seems worth any extra effort to do so.
220 */
221 norm_query_buflen = query_len + jstate->clocations_count * 10;
222
223 /* Allocate result buffer */
224 norm_query = palloc(norm_query_buflen + 1);
225
226 for (i = 0; i < jstate->clocations_count; i++)
227 {
228 int off, /* Offset from start for cur tok */
229 tok_len; /* Length (in bytes) of that tok */
230
231 off = jstate->clocations[i].location;
232 /* Adjust recorded location if we're dealing with partial string */
233 off -= query_loc;
234
235 tok_len = jstate->clocations[i].length;
236
237 if (tok_len < 0)
238 continue; /* ignore any duplicates */
239
240 /* Copy next chunk (what precedes the next constant) */
241 len_to_wrt = off - last_off;
242 len_to_wrt -= last_tok_len;
243
244 Assert(len_to_wrt >= 0);
245 memcpy(norm_query + n_quer_loc, query + quer_loc, len_to_wrt);
246 n_quer_loc += len_to_wrt;
247
248 /* And insert a param symbol in place of the constant token */
249 n_quer_loc += sprintf(norm_query + n_quer_loc, "$%d",
250 i + 1 + jstate->highest_extern_param_id);
251
252 quer_loc = off + tok_len;
253 last_off = off;
254 last_tok_len = tok_len;
255 }
256
257 /*
258 * We've copied up until the last ignorable constant. Copy over the
259 * remaining bytes of the original query string.
260 */
261 len_to_wrt = query_len - quer_loc;
262
263 Assert(len_to_wrt >= 0);
264 memcpy(norm_query + n_quer_loc, query + quer_loc, len_to_wrt);
265 n_quer_loc += len_to_wrt;
266
267 Assert(n_quer_loc <= norm_query_buflen);
268 norm_query[n_quer_loc] = '\0';
269
270 *query_len_p = n_quer_loc;
271 return norm_query;
272 }
273
/*
 * RecordConstLocation: append a constant's start offset to jstate's
 * clocations array, growing the array by doubling when it is full.
 * A negative location (unknown) is ignored.  The new entry's length is set
 * to -1; fill_in_constant_lengths() replaces it with the real length.
 */
static void
RecordConstLocation(pgssConstLocations *jstate, int location)
{
	pgssLocationLen *slot;

	if (location < 0)
		return;

	if (jstate->clocations_count >= jstate->clocations_buf_size)
	{
		jstate->clocations_buf_size *= 2;
		jstate->clocations = (pgssLocationLen *)
			repalloc(jstate->clocations,
					 jstate->clocations_buf_size * sizeof(pgssLocationLen));
	}

	slot = &jstate->clocations[jstate->clocations_count++];
	slot->location = location;
	slot->length = -1;
}
294
const_record_walker(Node * node,pgssConstLocations * jstate)295 static bool const_record_walker(Node *node, pgssConstLocations *jstate)
296 {
297 bool result;
298
299 if (node == NULL) return false;
300
301 if (IsA(node, A_Const))
302 {
303 RecordConstLocation(jstate, castNode(A_Const, node)->location);
304 }
305 else if (IsA(node, ParamRef))
306 {
307 /* Track the highest ParamRef number */
308 if (((ParamRef *) node)->number > jstate->highest_extern_param_id)
309 jstate->highest_extern_param_id = castNode(ParamRef, node)->number;
310 }
311 else if (IsA(node, DefElem))
312 {
313 return const_record_walker((Node *) ((DefElem *) node)->arg, jstate);
314 }
315 else if (IsA(node, RawStmt))
316 {
317 return const_record_walker((Node *) ((RawStmt *) node)->stmt, jstate);
318 }
319 else if (IsA(node, VariableSetStmt))
320 {
321 return const_record_walker((Node *) ((VariableSetStmt *) node)->args, jstate);
322 }
323 else if (IsA(node, CopyStmt))
324 {
325 return const_record_walker((Node *) ((CopyStmt *) node)->query, jstate);
326 }
327 else if (IsA(node, ExplainStmt))
328 {
329 return const_record_walker((Node *) ((ExplainStmt *) node)->query, jstate);
330 }
331 else if (IsA(node, AlterRoleStmt))
332 {
333 return const_record_walker((Node *) ((AlterRoleStmt *) node)->options, jstate);
334 }
335 else if (IsA(node, DeclareCursorStmt))
336 {
337 return const_record_walker((Node *) ((DeclareCursorStmt *) node)->query, jstate);
338 }
339
340 PG_TRY();
341 {
342 result = raw_expression_tree_walker(node, const_record_walker, (void*) jstate);
343 }
344 PG_CATCH();
345 {
346 FlushErrorState();
347 result = false;
348 }
349 PG_END_TRY();
350
351 return result;
352 }
353
pg_query_normalize(const char * input)354 PgQueryNormalizeResult pg_query_normalize(const char* input)
355 {
356 MemoryContext ctx = NULL;
357 PgQueryNormalizeResult result = {0};
358
359 ctx = pg_query_enter_memory_context("pg_query_normalize");
360
361 PG_TRY();
362 {
363 List *tree;
364 pgssConstLocations jstate;
365 int query_len;
366
367 /* Parse query */
368 tree = raw_parser(input);
369
370 /* Set up workspace for constant recording */
371 jstate.clocations_buf_size = 32;
372 jstate.clocations = (pgssLocationLen *)
373 palloc(jstate.clocations_buf_size * sizeof(pgssLocationLen));
374 jstate.clocations_count = 0;
375 jstate.highest_extern_param_id = 0;
376
377 /* Walk tree and record const locations */
378 const_record_walker((Node *) tree, &jstate);
379
380 /* Normalize query */
381 query_len = (int) strlen(input);
382 result.normalized_query = strdup(generate_normalized_query(&jstate, input, 0, &query_len, PG_UTF8));
383 }
384 PG_CATCH();
385 {
386 ErrorData* error_data;
387 PgQueryError* error;
388
389 MemoryContextSwitchTo(ctx);
390 error_data = CopyErrorData();
391
392 error = malloc(sizeof(PgQueryError));
393 error->message = strdup(error_data->message);
394 error->filename = strdup(error_data->filename);
395 error->lineno = error_data->lineno;
396 error->cursorpos = error_data->cursorpos;
397
398 result.error = error;
399 FlushErrorState();
400 }
401 PG_END_TRY();
402
403 pg_query_exit_memory_context(ctx);
404
405 return result;
406 }
407
/*
 * pg_query_free_normalize_result: release everything pg_query_normalize()
 * allocated in `result`: the normalized query string and, if present, the
 * error record together with its message and filename strings.
 */
void
pg_query_free_normalize_result(PgQueryNormalizeResult result)
{
	PgQueryError *err = result.error;

	free(result.normalized_query);

	if (err == NULL)
		return;

	free(err->message);
	free(err->filename);
	free(err);
}
418