1 #include "pg_query.h"
2 #include "pg_query_internal.h"
3 
4 #include "parser/parser.h"
5 #include "parser/scanner.h"
6 #include "parser/scansup.h"
7 #include "mb/pg_wchar.h"
8 #include "nodes/nodeFuncs.h"
9 
/*
 * Struct for tracking locations/lengths of constants during normalization
 *
 * Entries are created by RecordConstLocation() with length = -1.
 * fill_in_constant_lengths() later stores the real textual length for each
 * distinct location, and generate_normalized_query() skips every entry whose
 * length is still negative.
 */
typedef struct pgssLocationLen
{
	int			location;		/* start offset in query text */
	int			length;			/* length in bytes, or -1 to ignore */
} pgssLocationLen;
18 
/*
 * Working state for constant tree walker
 *
 * Populated by const_record_walker() while it traverses the raw parse tree
 * and then consumed by generate_normalized_query().
 */
typedef struct pgssConstLocations
{
	/* Array of locations of constants that should be removed */
	pgssLocationLen *clocations;

	/* Allocated length of clocations array (in elements, not bytes) */
	int			clocations_buf_size;

	/* Current number of valid entries in clocations array */
	int			clocations_count;

	/*
	 * Highest ParamRef number we've seen; the $n symbols generated during
	 * normalization are numbered above this value.
	 */
	int			highest_extern_param_id;
} pgssConstLocations;
36 
37 /*
38  * comp_location: comparator for qsorting pgssLocationLen structs by location
39  */
40 static int
comp_location(const void * a,const void * b)41 comp_location(const void *a, const void *b)
42 {
43 	int			l = ((const pgssLocationLen *) a)->location;
44 	int			r = ((const pgssLocationLen *) b)->location;
45 
46 	if (l < r)
47 		return -1;
48 	else if (l > r)
49 		return +1;
50 	else
51 		return 0;
52 }
53 
/*
 * Given a valid SQL string and an array of constant-location records,
 * fill in the textual lengths of those constants.
 *
 * The constants may use any allowed constant syntax, such as float literals,
 * bit-strings, single-quoted strings and dollar-quoted strings.  This is
 * accomplished by using the public API for the core scanner.
 *
 * It is the caller's job to ensure that the string is a valid SQL statement
 * with constants at the indicated locations.  Since in practice the string
 * has already been parsed, and the locations that the caller provides will
 * have originated from within the authoritative parser, this should not be
 * a problem.
 *
 * Duplicate constant pointers are possible, and will have their lengths
 * marked as '-1', so that they are later ignored.  (Actually, we assume the
 * lengths were initialized as -1 to start with, and don't change them here.)
 *
 * N.B. There is an assumption that a '-' character at a Const location begins
 * a negative numeric constant.  This precludes there ever being another
 * reason for a constant to start with a '-'.
 *
 * Side effect: jstate->clocations is sorted in place by location.
 */
static void
fill_in_constant_lengths(pgssConstLocations *jstate, const char *query)
{
	pgssLocationLen *locs;
	core_yyscan_t yyscanner;
	core_yy_extra_type yyextra;
	core_YYSTYPE yylval;
	YYLTYPE		yylloc;
	int			last_loc = -1;
	int			i;

	/*
	 * Sort the records by location so that we can process them in order while
	 * scanning the query text.
	 */
	if (jstate->clocations_count > 1)
		qsort(jstate->clocations, jstate->clocations_count,
			  sizeof(pgssLocationLen), comp_location);
	locs = jstate->clocations;

	/* initialize the flex scanner --- should match raw_parser() */
	yyscanner = scanner_init(query,
							 &yyextra,
							 ScanKeywords,
							 NumScanKeywords);

	/* Search for each constant, in sequence */
	for (i = 0; i < jstate->clocations_count; i++)
	{
		int			loc = locs[i].location;
		int			tok;

		Assert(loc >= 0);

		/* The array is sorted, so a repeated location is always adjacent */
		if (loc <= last_loc)
			continue;			/* Duplicate constant, ignore */

		/* Lex tokens until we find the desired constant */
		for (;;)
		{
			tok = core_yylex(&yylval, &yylloc, yyscanner);

			/* We should not hit end-of-string, but if we do, behave sanely */
			if (tok == 0)
				break;			/* out of inner for-loop */

			/*
			 * We should find the token position exactly, but if we somehow
			 * run past it, work with that.
			 */
			if (yylloc >= loc)
			{
				if (query[loc] == '-')
				{
					/*
					 * It's a negative value - this is the one and only case
					 * where we replace more than a single token.
					 *
					 * Do not compensate for the core system's special-case
					 * adjustment of location to that of the leading '-'
					 * operator in the event of a negative constant.  It is
					 * also useful for our purposes to start from the minus
					 * symbol.  In this way, queries like "select * from foo
					 * where bar = 1" and "select * from foo where bar = -2"
					 * will have identical normalized query strings.
					 */
					tok = core_yylex(&yylval, &yylloc, yyscanner);
					if (tok == 0)
						break;	/* out of inner for-loop */
				}

				/*
				 * We now rely on the assumption that flex has placed a zero
				 * byte after the text of the current token in scanbuf.
				 */
				locs[i].length = (int) strlen(yyextra.scanbuf + loc);

				/*
				 * Quoted string with Unicode escapes
				 *
				 * The lexer consumes trailing whitespace in order to find
				 * UESCAPE, but if there is no UESCAPE it has still consumed
				 * it - don't include it in constant length.
				 *
				 * The length test guarantees that the three bytes inspected
				 * below lie inside the token text.
				 */
				if (locs[i].length > 4 && /* U&'' */
					(yyextra.scanbuf[loc] == 'u' || yyextra.scanbuf[loc] == 'U') &&
					 yyextra.scanbuf[loc + 1] == '&' && yyextra.scanbuf[loc + 2] == '\'')
				{
					/* Start at the last byte of the token text (before the \0) */
					int j = locs[i].length - 1;
					/* Walk backwards over any trailing whitespace */
					for (; j >= 0 && scanner_isspace(yyextra.scanbuf[loc + j]); j--) {}
					/* j indexes the last non-space byte, so the length is j + 1 */
					locs[i].length = j + 1;
				}

				break;			/* out of inner for-loop */
			}
		}

		/* If we hit end-of-string, give up, leaving remaining lengths -1 */
		if (tok == 0)
			break;

		last_loc = loc;
	}

	scanner_finish(yyscanner);
}
180 
181 /*
182  * Generate a normalized version of the query string that will be used to
183  * represent all similar queries.
184  *
185  * Note that the normalized representation may well vary depending on
186  * just which "equivalent" query is used to create the hashtable entry.
187  * We assume this is OK.
188  *
189  * *query_len_p contains the input string length, and is updated with
190  * the result string length (which cannot be longer) on exit.
191  *
192  * Returns a palloc'd string.
193  */
194 static char *
generate_normalized_query(pgssConstLocations * jstate,const char * query,int query_loc,int * query_len_p,int encoding)195 generate_normalized_query(pgssConstLocations *jstate, const char *query,
196 						  int query_loc, int *query_len_p, int encoding)
197 {
198 	char	   *norm_query;
199 	int			query_len = *query_len_p;
200 	int			i,
201 				norm_query_buflen,		/* Space allowed for norm_query */
202 				len_to_wrt,		/* Length (in bytes) to write */
203 				quer_loc = 0,	/* Source query byte location */
204 				n_quer_loc = 0, /* Normalized query byte location */
205 				last_off = 0,	/* Offset from start for previous tok */
206 				last_tok_len = 0;		/* Length (in bytes) of that tok */
207 
208 	/*
209 	 * Get constants' lengths (core system only gives us locations).  Note
210 	 * this also ensures the items are sorted by location.
211 	 */
212 	fill_in_constant_lengths(jstate, query);
213 
214 	/*
215 	 * Allow for $n symbols to be longer than the constants they replace.
216 	 * Constants must take at least one byte in text form, while a $n symbol
217 	 * certainly isn't more than 11 bytes, even if n reaches INT_MAX.  We
218 	 * could refine that limit based on the max value of n for the current
219 	 * query, but it hardly seems worth any extra effort to do so.
220 	 */
221 	norm_query_buflen = query_len + jstate->clocations_count * 10;
222 
223 	/* Allocate result buffer */
224 	norm_query = palloc(norm_query_buflen + 1);
225 
226 	for (i = 0; i < jstate->clocations_count; i++)
227 	{
228 		int			off,		/* Offset from start for cur tok */
229 					tok_len;	/* Length (in bytes) of that tok */
230 
231 		off = jstate->clocations[i].location;
232 		/* Adjust recorded location if we're dealing with partial string */
233 		off -= query_loc;
234 
235 		tok_len = jstate->clocations[i].length;
236 
237 		if (tok_len < 0)
238 			continue;			/* ignore any duplicates */
239 
240 		/* Copy next chunk (what precedes the next constant) */
241 		len_to_wrt = off - last_off;
242 		len_to_wrt -= last_tok_len;
243 
244 		Assert(len_to_wrt >= 0);
245 		memcpy(norm_query + n_quer_loc, query + quer_loc, len_to_wrt);
246 		n_quer_loc += len_to_wrt;
247 
248 		/* And insert a param symbol in place of the constant token */
249 		n_quer_loc += sprintf(norm_query + n_quer_loc, "$%d",
250 							  i + 1 + jstate->highest_extern_param_id);
251 
252 		quer_loc = off + tok_len;
253 		last_off = off;
254 		last_tok_len = tok_len;
255 	}
256 
257 	/*
258 	 * We've copied up until the last ignorable constant.  Copy over the
259 	 * remaining bytes of the original query string.
260 	 */
261 	len_to_wrt = query_len - quer_loc;
262 
263 	Assert(len_to_wrt >= 0);
264 	memcpy(norm_query + n_quer_loc, query + quer_loc, len_to_wrt);
265 	n_quer_loc += len_to_wrt;
266 
267 	Assert(n_quer_loc <= norm_query_buflen);
268 	norm_query[n_quer_loc] = '\0';
269 
270 	*query_len_p = n_quer_loc;
271 	return norm_query;
272 }
273 
/*
 * Append one constant location to jstate->clocations, growing the array
 * when it is full.  Negative locations are silently ignored.
 */
static void RecordConstLocation(pgssConstLocations *jstate, int location)
{
	pgssLocationLen *slot;

	/* -1 indicates unknown or undefined location: nothing to record */
	if (location < 0)
		return;

	/* Out of room: double the allocation */
	if (jstate->clocations_count >= jstate->clocations_buf_size)
	{
		jstate->clocations_buf_size *= 2;
		jstate->clocations = (pgssLocationLen *)
			repalloc(jstate->clocations,
					 jstate->clocations_buf_size *
					 sizeof(pgssLocationLen));
	}

	slot = &jstate->clocations[jstate->clocations_count];
	slot->location = location;
	/* length starts out as -1; fill_in_constant_lengths relies on that */
	slot->length = -1;
	jstate->clocations_count++;
}
294 
/*
 * Tree walker that collects the location of every A_Const node and tracks
 * the highest ParamRef number found in a raw parse tree.
 *
 * A handful of node types are handled by recursing directly into one
 * specific field; everything else (including A_Const and ParamRef, after
 * they have been recorded) is passed to raw_expression_tree_walker().
 *
 * Returns whatever the nested walk returns.  An error raised by
 * raw_expression_tree_walker() is discarded and reported as false.
 */
static bool const_record_walker(Node *node, pgssConstLocations *jstate)
{
	bool		walk_result;

	if (node == NULL)
		return false;

	switch (nodeTag(node))
	{
		case T_A_Const:
			RecordConstLocation(jstate, castNode(A_Const, node)->location);
			break;

		case T_ParamRef:
			{
				/* Track the highest ParamRef number */
				int			number = castNode(ParamRef, node)->number;

				if (number > jstate->highest_extern_param_id)
					jstate->highest_extern_param_id = number;
			}
			break;

		/* Node types for which only one field is descended into */
		case T_DefElem:
			return const_record_walker((Node *) ((DefElem *) node)->arg, jstate);

		case T_RawStmt:
			return const_record_walker((Node *) ((RawStmt *) node)->stmt, jstate);

		case T_VariableSetStmt:
			return const_record_walker((Node *) ((VariableSetStmt *) node)->args, jstate);

		case T_CopyStmt:
			return const_record_walker((Node *) ((CopyStmt *) node)->query, jstate);

		case T_ExplainStmt:
			return const_record_walker((Node *) ((ExplainStmt *) node)->query, jstate);

		case T_AlterRoleStmt:
			return const_record_walker((Node *) ((AlterRoleStmt *) node)->options, jstate);

		case T_DeclareCursorStmt:
			return const_record_walker((Node *) ((DeclareCursorStmt *) node)->query, jstate);

		default:
			break;
	}

	/* Generic descent; a failure here is swallowed rather than propagated */
	PG_TRY();
	{
		walk_result = raw_expression_tree_walker(node, const_record_walker, (void*) jstate);
	}
	PG_CATCH();
	{
		FlushErrorState();
		walk_result = false;
	}
	PG_END_TRY();

	return walk_result;
}
353 
pg_query_normalize(const char * input)354 PgQueryNormalizeResult pg_query_normalize(const char* input)
355 {
356 	MemoryContext ctx = NULL;
357 	PgQueryNormalizeResult result = {0};
358 
359 	ctx = pg_query_enter_memory_context("pg_query_normalize");
360 
361 	PG_TRY();
362 	{
363 		List *tree;
364 		pgssConstLocations jstate;
365 		int query_len;
366 
367 		/* Parse query */
368 		tree = raw_parser(input);
369 
370 		/* Set up workspace for constant recording */
371 		jstate.clocations_buf_size = 32;
372 		jstate.clocations = (pgssLocationLen *)
373 			palloc(jstate.clocations_buf_size * sizeof(pgssLocationLen));
374 		jstate.clocations_count = 0;
375 		jstate.highest_extern_param_id = 0;
376 
377 		/* Walk tree and record const locations */
378 		const_record_walker((Node *) tree, &jstate);
379 
380 		/* Normalize query */
381 		query_len = (int) strlen(input);
382 		result.normalized_query = strdup(generate_normalized_query(&jstate, input, 0, &query_len, PG_UTF8));
383 	}
384 	PG_CATCH();
385 	{
386 		ErrorData* error_data;
387 		PgQueryError* error;
388 
389 		MemoryContextSwitchTo(ctx);
390 		error_data = CopyErrorData();
391 
392 		error = malloc(sizeof(PgQueryError));
393 		error->message   = strdup(error_data->message);
394 		error->filename  = strdup(error_data->filename);
395 		error->lineno    = error_data->lineno;
396 		error->cursorpos = error_data->cursorpos;
397 
398 		result.error = error;
399 		FlushErrorState();
400 	}
401 	PG_END_TRY();
402 
403 	pg_query_exit_memory_context(ctx);
404 
405 	return result;
406 }
407 
/*
 * Release everything owned by a PgQueryNormalizeResult: the error record
 * (with its message and filename strings), if any, and the normalized query
 * text.  The struct is passed by value, so the caller's copy still holds
 * the now-dangling pointers afterwards.
 */
void pg_query_free_normalize_result(PgQueryNormalizeResult result)
{
	PgQueryError *err = result.error;

	if (err != NULL)
	{
		free(err->message);
		free(err->filename);
		free(err);
	}

	/* free(NULL) is a no-op, so no guard is needed here */
	free(result.normalized_query);
}
418