1 /*-------------------------------------------------------------------------
2  *
3  * parse_agg.c
4  *	  handle aggregates and window functions in parser
5  *
6  * Portions Copyright (c) 1996-2017, PostgreSQL Global Development Group
7  * Portions Copyright (c) 1994, Regents of the University of California
8  *
9  *
10  * IDENTIFICATION
11  *	  src/backend/parser/parse_agg.c
12  *
13  *-------------------------------------------------------------------------
14  */
15 #include "postgres.h"
16 
17 #include "catalog/pg_aggregate.h"
18 #include "catalog/pg_constraint_fn.h"
19 #include "catalog/pg_type.h"
20 #include "nodes/makefuncs.h"
21 #include "nodes/nodeFuncs.h"
22 #include "optimizer/tlist.h"
23 #include "optimizer/var.h"
24 #include "parser/parse_agg.h"
25 #include "parser/parse_clause.h"
26 #include "parser/parse_coerce.h"
27 #include "parser/parse_expr.h"
28 #include "parser/parsetree.h"
29 #include "rewrite/rewriteManip.h"
30 #include "utils/builtins.h"
31 #include "utils/lsyscache.h"
32 
33 
34 typedef struct
35 {
36 	ParseState *pstate;
37 	int			min_varlevel;
38 	int			min_agglevel;
39 	int			sublevels_up;
40 } check_agg_arguments_context;
41 
42 typedef struct
43 {
44 	ParseState *pstate;
45 	Query	   *qry;
46 	PlannerInfo *root;
47 	List	   *groupClauses;
48 	List	   *groupClauseCommonVars;
49 	bool		have_non_var_grouping;
50 	List	  **func_grouped_rels;
51 	int			sublevels_up;
52 	bool		in_agg_direct_args;
53 } check_ungrouped_columns_context;
54 
55 static int check_agg_arguments(ParseState *pstate,
56 					List *directargs,
57 					List *args,
58 					Expr *filter);
59 static bool check_agg_arguments_walker(Node *node,
60 						   check_agg_arguments_context *context);
61 static void check_ungrouped_columns(Node *node, ParseState *pstate, Query *qry,
62 						List *groupClauses, List *groupClauseVars,
63 						bool have_non_var_grouping,
64 						List **func_grouped_rels);
65 static bool check_ungrouped_columns_walker(Node *node,
66 							   check_ungrouped_columns_context *context);
67 static void finalize_grouping_exprs(Node *node, ParseState *pstate, Query *qry,
68 						List *groupClauses, PlannerInfo *root,
69 						bool have_non_var_grouping);
70 static bool finalize_grouping_exprs_walker(Node *node,
71 							   check_ungrouped_columns_context *context);
72 static void check_agglevels_and_constraints(ParseState *pstate, Node *expr);
73 static List *expand_groupingset_node(GroupingSet *gs);
74 static Node *make_agg_arg(Oid argtype, Oid argcollation);
75 
76 
77 /*
78  * transformAggregateCall -
79  *		Finish initial transformation of an aggregate call
80  *
81  * parse_func.c has recognized the function as an aggregate, and has set up
82  * all the fields of the Aggref except aggargtypes, aggdirectargs, args,
83  * aggorder, aggdistinct and agglevelsup.  The passed-in args list has been
84  * through standard expression transformation and type coercion to match the
85  * agg's declared arg types, while the passed-in aggorder list hasn't been
86  * transformed at all.
87  *
88  * Here we separate the args list into direct and aggregated args, storing the
89  * former in agg->aggdirectargs and the latter in agg->args.  The regular
90  * args, but not the direct args, are converted into a targetlist by inserting
91  * TargetEntry nodes.  We then transform the aggorder and agg_distinct
92  * specifications to produce lists of SortGroupClause nodes for agg->aggorder
93  * and agg->aggdistinct.  (For a regular aggregate, this might result in
94  * adding resjunk expressions to the targetlist; but for ordered-set
95  * aggregates the aggorder list will always be one-to-one with the aggregated
96  * args.)
97  *
98  * We must also determine which query level the aggregate actually belongs to,
99  * set agglevelsup accordingly, and mark p_hasAggs true in the corresponding
100  * pstate level.
101  */
102 void
transformAggregateCall(ParseState * pstate,Aggref * agg,List * args,List * aggorder,bool agg_distinct)103 transformAggregateCall(ParseState *pstate, Aggref *agg,
104 					   List *args, List *aggorder, bool agg_distinct)
105 {
106 	List	   *argtypes = NIL;
107 	List	   *tlist = NIL;
108 	List	   *torder = NIL;
109 	List	   *tdistinct = NIL;
110 	AttrNumber	attno = 1;
111 	int			save_next_resno;
112 	ListCell   *lc;
113 
114 	/*
115 	 * Before separating the args into direct and aggregated args, make a list
116 	 * of their data type OIDs for use later.
117 	 */
118 	foreach(lc, args)
119 	{
120 		Expr	   *arg = (Expr *) lfirst(lc);
121 
122 		argtypes = lappend_oid(argtypes, exprType((Node *) arg));
123 	}
124 	agg->aggargtypes = argtypes;
125 
126 	if (AGGKIND_IS_ORDERED_SET(agg->aggkind))
127 	{
128 		/*
129 		 * For an ordered-set agg, the args list includes direct args and
130 		 * aggregated args; we must split them apart.
131 		 */
132 		int			numDirectArgs = list_length(args) - list_length(aggorder);
133 		List	   *aargs;
134 		ListCell   *lc2;
135 
136 		Assert(numDirectArgs >= 0);
137 
138 		aargs = list_copy_tail(args, numDirectArgs);
139 		agg->aggdirectargs = list_truncate(args, numDirectArgs);
140 
141 		/*
142 		 * Build a tlist from the aggregated args, and make a sortlist entry
143 		 * for each one.  Note that the expressions in the SortBy nodes are
144 		 * ignored (they are the raw versions of the transformed args); we are
145 		 * just looking at the sort information in the SortBy nodes.
146 		 */
147 		forboth(lc, aargs, lc2, aggorder)
148 		{
149 			Expr	   *arg = (Expr *) lfirst(lc);
150 			SortBy	   *sortby = (SortBy *) lfirst(lc2);
151 			TargetEntry *tle;
152 
153 			/* We don't bother to assign column names to the entries */
154 			tle = makeTargetEntry(arg, attno++, NULL, false);
155 			tlist = lappend(tlist, tle);
156 
157 			torder = addTargetToSortList(pstate, tle,
158 										 torder, tlist, sortby);
159 		}
160 
161 		/* Never any DISTINCT in an ordered-set agg */
162 		Assert(!agg_distinct);
163 	}
164 	else
165 	{
166 		/* Regular aggregate, so it has no direct args */
167 		agg->aggdirectargs = NIL;
168 
169 		/*
170 		 * Transform the plain list of Exprs into a targetlist.
171 		 */
172 		foreach(lc, args)
173 		{
174 			Expr	   *arg = (Expr *) lfirst(lc);
175 			TargetEntry *tle;
176 
177 			/* We don't bother to assign column names to the entries */
178 			tle = makeTargetEntry(arg, attno++, NULL, false);
179 			tlist = lappend(tlist, tle);
180 		}
181 
182 		/*
183 		 * If we have an ORDER BY, transform it.  This will add columns to the
184 		 * tlist if they appear in ORDER BY but weren't already in the arg
185 		 * list.  They will be marked resjunk = true so we can tell them apart
186 		 * from regular aggregate arguments later.
187 		 *
188 		 * We need to mess with p_next_resno since it will be used to number
189 		 * any new targetlist entries.
190 		 */
191 		save_next_resno = pstate->p_next_resno;
192 		pstate->p_next_resno = attno;
193 
194 		torder = transformSortClause(pstate,
195 									 aggorder,
196 									 &tlist,
197 									 EXPR_KIND_ORDER_BY,
198 									 true /* force SQL99 rules */ );
199 
200 		/*
201 		 * If we have DISTINCT, transform that to produce a distinctList.
202 		 */
203 		if (agg_distinct)
204 		{
205 			tdistinct = transformDistinctClause(pstate, &tlist, torder, true);
206 
207 			/*
208 			 * Remove this check if executor support for hashed distinct for
209 			 * aggregates is ever added.
210 			 */
211 			foreach(lc, tdistinct)
212 			{
213 				SortGroupClause *sortcl = (SortGroupClause *) lfirst(lc);
214 
215 				if (!OidIsValid(sortcl->sortop))
216 				{
217 					Node	   *expr = get_sortgroupclause_expr(sortcl, tlist);
218 
219 					ereport(ERROR,
220 							(errcode(ERRCODE_UNDEFINED_FUNCTION),
221 							 errmsg("could not identify an ordering operator for type %s",
222 									format_type_be(exprType(expr))),
223 							 errdetail("Aggregates with DISTINCT must be able to sort their inputs."),
224 							 parser_errposition(pstate, exprLocation(expr))));
225 				}
226 			}
227 		}
228 
229 		pstate->p_next_resno = save_next_resno;
230 	}
231 
232 	/* Update the Aggref with the transformation results */
233 	agg->args = tlist;
234 	agg->aggorder = torder;
235 	agg->aggdistinct = tdistinct;
236 
237 	check_agglevels_and_constraints(pstate, (Node *) agg);
238 }
239 
240 /*
241  * transformGroupingFunc
242  *		Transform a GROUPING expression
243  *
244  * GROUPING() behaves very like an aggregate.  Processing of levels and nesting
245  * is done as for aggregates.  We set p_hasAggs for these expressions too.
246  */
247 Node *
transformGroupingFunc(ParseState * pstate,GroupingFunc * p)248 transformGroupingFunc(ParseState *pstate, GroupingFunc *p)
249 {
250 	ListCell   *lc;
251 	List	   *args = p->args;
252 	List	   *result_list = NIL;
253 	GroupingFunc *result = makeNode(GroupingFunc);
254 
255 	if (list_length(args) > 31)
256 		ereport(ERROR,
257 				(errcode(ERRCODE_TOO_MANY_ARGUMENTS),
258 				 errmsg("GROUPING must have fewer than 32 arguments"),
259 				 parser_errposition(pstate, p->location)));
260 
261 	foreach(lc, args)
262 	{
263 		Node	   *current_result;
264 
265 		current_result = transformExpr(pstate, (Node *) lfirst(lc), pstate->p_expr_kind);
266 
267 		/* acceptability of expressions is checked later */
268 
269 		result_list = lappend(result_list, current_result);
270 	}
271 
272 	result->args = result_list;
273 	result->location = p->location;
274 
275 	check_agglevels_and_constraints(pstate, (Node *) result);
276 
277 	return (Node *) result;
278 }
279 
280 /*
281  * Aggregate functions and grouping operations (which are combined in the spec
282  * as <set function specification>) are very similar with regard to level and
283  * nesting restrictions (though we allow a lot more things than the spec does).
284  * Centralise those restrictions here.
285  */
286 static void
check_agglevels_and_constraints(ParseState * pstate,Node * expr)287 check_agglevels_and_constraints(ParseState *pstate, Node *expr)
288 {
289 	List	   *directargs = NIL;
290 	List	   *args = NIL;
291 	Expr	   *filter = NULL;
292 	int			min_varlevel;
293 	int			location = -1;
294 	Index	   *p_levelsup;
295 	const char *err;
296 	bool		errkind;
297 	bool		isAgg = IsA(expr, Aggref);
298 
299 	if (isAgg)
300 	{
301 		Aggref	   *agg = (Aggref *) expr;
302 
303 		directargs = agg->aggdirectargs;
304 		args = agg->args;
305 		filter = agg->aggfilter;
306 		location = agg->location;
307 		p_levelsup = &agg->agglevelsup;
308 	}
309 	else
310 	{
311 		GroupingFunc *grp = (GroupingFunc *) expr;
312 
313 		args = grp->args;
314 		location = grp->location;
315 		p_levelsup = &grp->agglevelsup;
316 	}
317 
318 	/*
319 	 * Check the arguments to compute the aggregate's level and detect
320 	 * improper nesting.
321 	 */
322 	min_varlevel = check_agg_arguments(pstate,
323 									   directargs,
324 									   args,
325 									   filter);
326 
327 	*p_levelsup = min_varlevel;
328 
329 	/* Mark the correct pstate level as having aggregates */
330 	while (min_varlevel-- > 0)
331 		pstate = pstate->parentParseState;
332 	pstate->p_hasAggs = true;
333 
334 	/*
335 	 * Check to see if the aggregate function is in an invalid place within
336 	 * its aggregation query.
337 	 *
338 	 * For brevity we support two schemes for reporting an error here: set
339 	 * "err" to a custom message, or set "errkind" true if the error context
340 	 * is sufficiently identified by what ParseExprKindName will return, *and*
341 	 * what it will return is just a SQL keyword.  (Otherwise, use a custom
342 	 * message to avoid creating translation problems.)
343 	 */
344 	err = NULL;
345 	errkind = false;
346 	switch (pstate->p_expr_kind)
347 	{
348 		case EXPR_KIND_NONE:
349 			Assert(false);		/* can't happen */
350 			break;
351 		case EXPR_KIND_OTHER:
352 
353 			/*
354 			 * Accept aggregate/grouping here; caller must throw error if
355 			 * wanted
356 			 */
357 			break;
358 		case EXPR_KIND_JOIN_ON:
359 		case EXPR_KIND_JOIN_USING:
360 			if (isAgg)
361 				err = _("aggregate functions are not allowed in JOIN conditions");
362 			else
363 				err = _("grouping operations are not allowed in JOIN conditions");
364 
365 			break;
366 		case EXPR_KIND_FROM_SUBSELECT:
367 			/* Should only be possible in a LATERAL subquery */
368 			Assert(pstate->p_lateral_active);
369 
370 			/*
371 			 * Aggregate/grouping scope rules make it worth being explicit
372 			 * here
373 			 */
374 			if (isAgg)
375 				err = _("aggregate functions are not allowed in FROM clause of their own query level");
376 			else
377 				err = _("grouping operations are not allowed in FROM clause of their own query level");
378 
379 			break;
380 		case EXPR_KIND_FROM_FUNCTION:
381 			if (isAgg)
382 				err = _("aggregate functions are not allowed in functions in FROM");
383 			else
384 				err = _("grouping operations are not allowed in functions in FROM");
385 
386 			break;
387 		case EXPR_KIND_WHERE:
388 			errkind = true;
389 			break;
390 		case EXPR_KIND_POLICY:
391 			if (isAgg)
392 				err = _("aggregate functions are not allowed in policy expressions");
393 			else
394 				err = _("grouping operations are not allowed in policy expressions");
395 
396 			break;
397 		case EXPR_KIND_HAVING:
398 			/* okay */
399 			break;
400 		case EXPR_KIND_FILTER:
401 			errkind = true;
402 			break;
403 		case EXPR_KIND_WINDOW_PARTITION:
404 			/* okay */
405 			break;
406 		case EXPR_KIND_WINDOW_ORDER:
407 			/* okay */
408 			break;
409 		case EXPR_KIND_WINDOW_FRAME_RANGE:
410 			if (isAgg)
411 				err = _("aggregate functions are not allowed in window RANGE");
412 			else
413 				err = _("grouping operations are not allowed in window RANGE");
414 
415 			break;
416 		case EXPR_KIND_WINDOW_FRAME_ROWS:
417 			if (isAgg)
418 				err = _("aggregate functions are not allowed in window ROWS");
419 			else
420 				err = _("grouping operations are not allowed in window ROWS");
421 
422 			break;
423 		case EXPR_KIND_SELECT_TARGET:
424 			/* okay */
425 			break;
426 		case EXPR_KIND_INSERT_TARGET:
427 		case EXPR_KIND_UPDATE_SOURCE:
428 		case EXPR_KIND_UPDATE_TARGET:
429 			errkind = true;
430 			break;
431 		case EXPR_KIND_GROUP_BY:
432 			errkind = true;
433 			break;
434 		case EXPR_KIND_ORDER_BY:
435 			/* okay */
436 			break;
437 		case EXPR_KIND_DISTINCT_ON:
438 			/* okay */
439 			break;
440 		case EXPR_KIND_LIMIT:
441 		case EXPR_KIND_OFFSET:
442 			errkind = true;
443 			break;
444 		case EXPR_KIND_RETURNING:
445 			errkind = true;
446 			break;
447 		case EXPR_KIND_VALUES:
448 		case EXPR_KIND_VALUES_SINGLE:
449 			errkind = true;
450 			break;
451 		case EXPR_KIND_CHECK_CONSTRAINT:
452 		case EXPR_KIND_DOMAIN_CHECK:
453 			if (isAgg)
454 				err = _("aggregate functions are not allowed in check constraints");
455 			else
456 				err = _("grouping operations are not allowed in check constraints");
457 
458 			break;
459 		case EXPR_KIND_COLUMN_DEFAULT:
460 		case EXPR_KIND_FUNCTION_DEFAULT:
461 
462 			if (isAgg)
463 				err = _("aggregate functions are not allowed in DEFAULT expressions");
464 			else
465 				err = _("grouping operations are not allowed in DEFAULT expressions");
466 
467 			break;
468 		case EXPR_KIND_INDEX_EXPRESSION:
469 			if (isAgg)
470 				err = _("aggregate functions are not allowed in index expressions");
471 			else
472 				err = _("grouping operations are not allowed in index expressions");
473 
474 			break;
475 		case EXPR_KIND_INDEX_PREDICATE:
476 			if (isAgg)
477 				err = _("aggregate functions are not allowed in index predicates");
478 			else
479 				err = _("grouping operations are not allowed in index predicates");
480 
481 			break;
482 		case EXPR_KIND_ALTER_COL_TRANSFORM:
483 			if (isAgg)
484 				err = _("aggregate functions are not allowed in transform expressions");
485 			else
486 				err = _("grouping operations are not allowed in transform expressions");
487 
488 			break;
489 		case EXPR_KIND_EXECUTE_PARAMETER:
490 			if (isAgg)
491 				err = _("aggregate functions are not allowed in EXECUTE parameters");
492 			else
493 				err = _("grouping operations are not allowed in EXECUTE parameters");
494 
495 			break;
496 		case EXPR_KIND_TRIGGER_WHEN:
497 			if (isAgg)
498 				err = _("aggregate functions are not allowed in trigger WHEN conditions");
499 			else
500 				err = _("grouping operations are not allowed in trigger WHEN conditions");
501 
502 			break;
503 		case EXPR_KIND_PARTITION_EXPRESSION:
504 			if (isAgg)
505 				err = _("aggregate functions are not allowed in partition key expression");
506 			else
507 				err = _("grouping operations are not allowed in partition key expression");
508 
509 			break;
510 
511 			/*
512 			 * There is intentionally no default: case here, so that the
513 			 * compiler will warn if we add a new ParseExprKind without
514 			 * extending this switch.  If we do see an unrecognized value at
515 			 * runtime, the behavior will be the same as for EXPR_KIND_OTHER,
516 			 * which is sane anyway.
517 			 */
518 	}
519 
520 	if (err)
521 		ereport(ERROR,
522 				(errcode(ERRCODE_GROUPING_ERROR),
523 				 errmsg_internal("%s", err),
524 				 parser_errposition(pstate, location)));
525 
526 	if (errkind)
527 	{
528 		if (isAgg)
529 			/* translator: %s is name of a SQL construct, eg GROUP BY */
530 			err = _("aggregate functions are not allowed in %s");
531 		else
532 			/* translator: %s is name of a SQL construct, eg GROUP BY */
533 			err = _("grouping operations are not allowed in %s");
534 
535 		ereport(ERROR,
536 				(errcode(ERRCODE_GROUPING_ERROR),
537 				 errmsg_internal(err,
538 								 ParseExprKindName(pstate->p_expr_kind)),
539 				 parser_errposition(pstate, location)));
540 	}
541 }
542 
543 /*
544  * check_agg_arguments
545  *	  Scan the arguments of an aggregate function to determine the
546  *	  aggregate's semantic level (zero is the current select's level,
547  *	  one is its parent, etc).
548  *
549  * The aggregate's level is the same as the level of the lowest-level variable
550  * or aggregate in its aggregated arguments (including any ORDER BY columns)
551  * or filter expression; or if it contains no variables at all, we presume it
552  * to be local.
553  *
554  * Vars/Aggs in direct arguments are *not* counted towards determining the
555  * agg's level, as those arguments aren't evaluated per-row but only
556  * per-group, and so in some sense aren't really agg arguments.  However,
557  * this can mean that we decide an agg is upper-level even when its direct
558  * args contain lower-level Vars/Aggs, and that case has to be disallowed.
559  * (This is a little strange, but the SQL standard seems pretty definite that
560  * direct args are not to be considered when setting the agg's level.)
561  *
562  * We also take this opportunity to detect any aggregates or window functions
563  * nested within the arguments.  We can throw error immediately if we find
564  * a window function.  Aggregates are a bit trickier because it's only an
565  * error if the inner aggregate is of the same semantic level as the outer,
566  * which we can't know until we finish scanning the arguments.
567  */
568 static int
check_agg_arguments(ParseState * pstate,List * directargs,List * args,Expr * filter)569 check_agg_arguments(ParseState *pstate,
570 					List *directargs,
571 					List *args,
572 					Expr *filter)
573 {
574 	int			agglevel;
575 	check_agg_arguments_context context;
576 
577 	context.pstate = pstate;
578 	context.min_varlevel = -1;	/* signifies nothing found yet */
579 	context.min_agglevel = -1;
580 	context.sublevels_up = 0;
581 
582 	(void) check_agg_arguments_walker((Node *) args, &context);
583 	(void) check_agg_arguments_walker((Node *) filter, &context);
584 
585 	/*
586 	 * If we found no vars nor aggs at all, it's a level-zero aggregate;
587 	 * otherwise, its level is the minimum of vars or aggs.
588 	 */
589 	if (context.min_varlevel < 0)
590 	{
591 		if (context.min_agglevel < 0)
592 			agglevel = 0;
593 		else
594 			agglevel = context.min_agglevel;
595 	}
596 	else if (context.min_agglevel < 0)
597 		agglevel = context.min_varlevel;
598 	else
599 		agglevel = Min(context.min_varlevel, context.min_agglevel);
600 
601 	/*
602 	 * If there's a nested aggregate of the same semantic level, complain.
603 	 */
604 	if (agglevel == context.min_agglevel)
605 	{
606 		int			aggloc;
607 
608 		aggloc = locate_agg_of_level((Node *) args, agglevel);
609 		if (aggloc < 0)
610 			aggloc = locate_agg_of_level((Node *) filter, agglevel);
611 		ereport(ERROR,
612 				(errcode(ERRCODE_GROUPING_ERROR),
613 				 errmsg("aggregate function calls cannot be nested"),
614 				 parser_errposition(pstate, aggloc)));
615 	}
616 
617 	/*
618 	 * Now check for vars/aggs in the direct arguments, and throw error if
619 	 * needed.  Note that we allow a Var of the agg's semantic level, but not
620 	 * an Agg of that level.  In principle such Aggs could probably be
621 	 * supported, but it would create an ordering dependency among the
622 	 * aggregates at execution time.  Since the case appears neither to be
623 	 * required by spec nor particularly useful, we just treat it as a
624 	 * nested-aggregate situation.
625 	 */
626 	if (directargs)
627 	{
628 		context.min_varlevel = -1;
629 		context.min_agglevel = -1;
630 		(void) check_agg_arguments_walker((Node *) directargs, &context);
631 		if (context.min_varlevel >= 0 && context.min_varlevel < agglevel)
632 			ereport(ERROR,
633 					(errcode(ERRCODE_GROUPING_ERROR),
634 					 errmsg("outer-level aggregate cannot contain a lower-level variable in its direct arguments"),
635 					 parser_errposition(pstate,
636 										locate_var_of_level((Node *) directargs,
637 															context.min_varlevel))));
638 		if (context.min_agglevel >= 0 && context.min_agglevel <= agglevel)
639 			ereport(ERROR,
640 					(errcode(ERRCODE_GROUPING_ERROR),
641 					 errmsg("aggregate function calls cannot be nested"),
642 					 parser_errposition(pstate,
643 										locate_agg_of_level((Node *) directargs,
644 															context.min_agglevel))));
645 	}
646 	return agglevel;
647 }
648 
649 static bool
check_agg_arguments_walker(Node * node,check_agg_arguments_context * context)650 check_agg_arguments_walker(Node *node,
651 						   check_agg_arguments_context *context)
652 {
653 	if (node == NULL)
654 		return false;
655 	if (IsA(node, Var))
656 	{
657 		int			varlevelsup = ((Var *) node)->varlevelsup;
658 
659 		/* convert levelsup to frame of reference of original query */
660 		varlevelsup -= context->sublevels_up;
661 		/* ignore local vars of subqueries */
662 		if (varlevelsup >= 0)
663 		{
664 			if (context->min_varlevel < 0 ||
665 				context->min_varlevel > varlevelsup)
666 				context->min_varlevel = varlevelsup;
667 		}
668 		return false;
669 	}
670 	if (IsA(node, Aggref))
671 	{
672 		int			agglevelsup = ((Aggref *) node)->agglevelsup;
673 
674 		/* convert levelsup to frame of reference of original query */
675 		agglevelsup -= context->sublevels_up;
676 		/* ignore local aggs of subqueries */
677 		if (agglevelsup >= 0)
678 		{
679 			if (context->min_agglevel < 0 ||
680 				context->min_agglevel > agglevelsup)
681 				context->min_agglevel = agglevelsup;
682 		}
683 		/* no need to examine args of the inner aggregate */
684 		return false;
685 	}
686 	if (IsA(node, GroupingFunc))
687 	{
688 		int			agglevelsup = ((GroupingFunc *) node)->agglevelsup;
689 
690 		/* convert levelsup to frame of reference of original query */
691 		agglevelsup -= context->sublevels_up;
692 		/* ignore local aggs of subqueries */
693 		if (agglevelsup >= 0)
694 		{
695 			if (context->min_agglevel < 0 ||
696 				context->min_agglevel > agglevelsup)
697 				context->min_agglevel = agglevelsup;
698 		}
699 		/* Continue and descend into subtree */
700 	}
701 
702 	/*
703 	 * SRFs and window functions can be rejected immediately, unless we are
704 	 * within a sub-select within the aggregate's arguments; in that case
705 	 * they're OK.
706 	 */
707 	if (context->sublevels_up == 0)
708 	{
709 		if ((IsA(node, FuncExpr) &&((FuncExpr *) node)->funcretset) ||
710 			(IsA(node, OpExpr) &&((OpExpr *) node)->opretset))
711 			ereport(ERROR,
712 					(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
713 					 errmsg("aggregate function calls cannot contain set-returning function calls"),
714 					 errhint("You might be able to move the set-returning function into a LATERAL FROM item."),
715 					 parser_errposition(context->pstate, exprLocation(node))));
716 		if (IsA(node, WindowFunc))
717 			ereport(ERROR,
718 					(errcode(ERRCODE_GROUPING_ERROR),
719 					 errmsg("aggregate function calls cannot contain window function calls"),
720 					 parser_errposition(context->pstate,
721 										((WindowFunc *) node)->location)));
722 	}
723 	if (IsA(node, Query))
724 	{
725 		/* Recurse into subselects */
726 		bool		result;
727 
728 		context->sublevels_up++;
729 		result = query_tree_walker((Query *) node,
730 								   check_agg_arguments_walker,
731 								   (void *) context,
732 								   0);
733 		context->sublevels_up--;
734 		return result;
735 	}
736 
737 	return expression_tree_walker(node,
738 								  check_agg_arguments_walker,
739 								  (void *) context);
740 }
741 
742 /*
743  * transformWindowFuncCall -
744  *		Finish initial transformation of a window function call
745  *
746  * parse_func.c has recognized the function as a window function, and has set
747  * up all the fields of the WindowFunc except winref.  Here we must (1) add
748  * the WindowDef to the pstate (if not a duplicate of one already present) and
749  * set winref to link to it; and (2) mark p_hasWindowFuncs true in the pstate.
750  * Unlike aggregates, only the most closely nested pstate level need be
751  * considered --- there are no "outer window functions" per SQL spec.
752  */
753 void
transformWindowFuncCall(ParseState * pstate,WindowFunc * wfunc,WindowDef * windef)754 transformWindowFuncCall(ParseState *pstate, WindowFunc *wfunc,
755 						WindowDef *windef)
756 {
757 	const char *err;
758 	bool		errkind;
759 
760 	/*
761 	 * A window function call can't contain another one (but aggs are OK). XXX
762 	 * is this required by spec, or just an unimplemented feature?
763 	 *
764 	 * Note: we don't need to check the filter expression here, because the
765 	 * context checks done below and in transformAggregateCall would have
766 	 * already rejected any window funcs or aggs within the filter.
767 	 */
768 	if (pstate->p_hasWindowFuncs &&
769 		contain_windowfuncs((Node *) wfunc->args))
770 		ereport(ERROR,
771 				(errcode(ERRCODE_WINDOWING_ERROR),
772 				 errmsg("window function calls cannot be nested"),
773 				 parser_errposition(pstate,
774 									locate_windowfunc((Node *) wfunc->args))));
775 
776 	/*
777 	 * Check to see if the window function is in an invalid place within the
778 	 * query.
779 	 *
780 	 * For brevity we support two schemes for reporting an error here: set
781 	 * "err" to a custom message, or set "errkind" true if the error context
782 	 * is sufficiently identified by what ParseExprKindName will return, *and*
783 	 * what it will return is just a SQL keyword.  (Otherwise, use a custom
784 	 * message to avoid creating translation problems.)
785 	 */
786 	err = NULL;
787 	errkind = false;
788 	switch (pstate->p_expr_kind)
789 	{
790 		case EXPR_KIND_NONE:
791 			Assert(false);		/* can't happen */
792 			break;
793 		case EXPR_KIND_OTHER:
794 			/* Accept window func here; caller must throw error if wanted */
795 			break;
796 		case EXPR_KIND_JOIN_ON:
797 		case EXPR_KIND_JOIN_USING:
798 			err = _("window functions are not allowed in JOIN conditions");
799 			break;
800 		case EXPR_KIND_FROM_SUBSELECT:
801 			/* can't get here, but just in case, throw an error */
802 			errkind = true;
803 			break;
804 		case EXPR_KIND_FROM_FUNCTION:
805 			err = _("window functions are not allowed in functions in FROM");
806 			break;
807 		case EXPR_KIND_WHERE:
808 			errkind = true;
809 			break;
810 		case EXPR_KIND_POLICY:
811 			err = _("window functions are not allowed in policy expressions");
812 			break;
813 		case EXPR_KIND_HAVING:
814 			errkind = true;
815 			break;
816 		case EXPR_KIND_FILTER:
817 			errkind = true;
818 			break;
819 		case EXPR_KIND_WINDOW_PARTITION:
820 		case EXPR_KIND_WINDOW_ORDER:
821 		case EXPR_KIND_WINDOW_FRAME_RANGE:
822 		case EXPR_KIND_WINDOW_FRAME_ROWS:
823 			err = _("window functions are not allowed in window definitions");
824 			break;
825 		case EXPR_KIND_SELECT_TARGET:
826 			/* okay */
827 			break;
828 		case EXPR_KIND_INSERT_TARGET:
829 		case EXPR_KIND_UPDATE_SOURCE:
830 		case EXPR_KIND_UPDATE_TARGET:
831 			errkind = true;
832 			break;
833 		case EXPR_KIND_GROUP_BY:
834 			errkind = true;
835 			break;
836 		case EXPR_KIND_ORDER_BY:
837 			/* okay */
838 			break;
839 		case EXPR_KIND_DISTINCT_ON:
840 			/* okay */
841 			break;
842 		case EXPR_KIND_LIMIT:
843 		case EXPR_KIND_OFFSET:
844 			errkind = true;
845 			break;
846 		case EXPR_KIND_RETURNING:
847 			errkind = true;
848 			break;
849 		case EXPR_KIND_VALUES:
850 		case EXPR_KIND_VALUES_SINGLE:
851 			errkind = true;
852 			break;
853 		case EXPR_KIND_CHECK_CONSTRAINT:
854 		case EXPR_KIND_DOMAIN_CHECK:
855 			err = _("window functions are not allowed in check constraints");
856 			break;
857 		case EXPR_KIND_COLUMN_DEFAULT:
858 		case EXPR_KIND_FUNCTION_DEFAULT:
859 			err = _("window functions are not allowed in DEFAULT expressions");
860 			break;
861 		case EXPR_KIND_INDEX_EXPRESSION:
862 			err = _("window functions are not allowed in index expressions");
863 			break;
864 		case EXPR_KIND_INDEX_PREDICATE:
865 			err = _("window functions are not allowed in index predicates");
866 			break;
867 		case EXPR_KIND_ALTER_COL_TRANSFORM:
868 			err = _("window functions are not allowed in transform expressions");
869 			break;
870 		case EXPR_KIND_EXECUTE_PARAMETER:
871 			err = _("window functions are not allowed in EXECUTE parameters");
872 			break;
873 		case EXPR_KIND_TRIGGER_WHEN:
874 			err = _("window functions are not allowed in trigger WHEN conditions");
875 			break;
876 		case EXPR_KIND_PARTITION_EXPRESSION:
877 			err = _("window functions are not allowed in partition key expression");
878 			break;
879 
880 			/*
881 			 * There is intentionally no default: case here, so that the
882 			 * compiler will warn if we add a new ParseExprKind without
883 			 * extending this switch.  If we do see an unrecognized value at
884 			 * runtime, the behavior will be the same as for EXPR_KIND_OTHER,
885 			 * which is sane anyway.
886 			 */
887 	}
888 	if (err)
889 		ereport(ERROR,
890 				(errcode(ERRCODE_WINDOWING_ERROR),
891 				 errmsg_internal("%s", err),
892 				 parser_errposition(pstate, wfunc->location)));
893 	if (errkind)
894 		ereport(ERROR,
895 				(errcode(ERRCODE_WINDOWING_ERROR),
896 		/* translator: %s is name of a SQL construct, eg GROUP BY */
897 				 errmsg("window functions are not allowed in %s",
898 						ParseExprKindName(pstate->p_expr_kind)),
899 				 parser_errposition(pstate, wfunc->location)));
900 
901 	/*
902 	 * If the OVER clause just specifies a window name, find that WINDOW
903 	 * clause (which had better be present).  Otherwise, try to match all the
904 	 * properties of the OVER clause, and make a new entry in the p_windowdefs
905 	 * list if no luck.
906 	 */
907 	if (windef->name)
908 	{
909 		Index		winref = 0;
910 		ListCell   *lc;
911 
912 		Assert(windef->refname == NULL &&
913 			   windef->partitionClause == NIL &&
914 			   windef->orderClause == NIL &&
915 			   windef->frameOptions == FRAMEOPTION_DEFAULTS);
916 
917 		foreach(lc, pstate->p_windowdefs)
918 		{
919 			WindowDef  *refwin = (WindowDef *) lfirst(lc);
920 
921 			winref++;
922 			if (refwin->name && strcmp(refwin->name, windef->name) == 0)
923 			{
924 				wfunc->winref = winref;
925 				break;
926 			}
927 		}
928 		if (lc == NULL)			/* didn't find it? */
929 			ereport(ERROR,
930 					(errcode(ERRCODE_UNDEFINED_OBJECT),
931 					 errmsg("window \"%s\" does not exist", windef->name),
932 					 parser_errposition(pstate, windef->location)));
933 	}
934 	else
935 	{
936 		Index		winref = 0;
937 		ListCell   *lc;
938 
939 		foreach(lc, pstate->p_windowdefs)
940 		{
941 			WindowDef  *refwin = (WindowDef *) lfirst(lc);
942 
943 			winref++;
944 			if (refwin->refname && windef->refname &&
945 				strcmp(refwin->refname, windef->refname) == 0)
946 				 /* matched on refname */ ;
947 			else if (!refwin->refname && !windef->refname)
948 				 /* matched, no refname */ ;
949 			else
950 				continue;
951 			if (equal(refwin->partitionClause, windef->partitionClause) &&
952 				equal(refwin->orderClause, windef->orderClause) &&
953 				refwin->frameOptions == windef->frameOptions &&
954 				equal(refwin->startOffset, windef->startOffset) &&
955 				equal(refwin->endOffset, windef->endOffset))
956 			{
957 				/* found a duplicate window specification */
958 				wfunc->winref = winref;
959 				break;
960 			}
961 		}
962 		if (lc == NULL)			/* didn't find it? */
963 		{
964 			pstate->p_windowdefs = lappend(pstate->p_windowdefs, windef);
965 			wfunc->winref = list_length(pstate->p_windowdefs);
966 		}
967 	}
968 
969 	pstate->p_hasWindowFuncs = true;
970 }
971 
972 /*
973  * parseCheckAggregates
974  *	Check for aggregates where they shouldn't be and improper grouping.
975  *	This function should be called after the target list and qualifications
976  *	are finalized.
977  *
978  *	Misplaced aggregates are now mostly detected in transformAggregateCall,
979  *	but it seems more robust to check for aggregates in recursive queries
980  *	only after everything is finalized.  In any case it's hard to detect
981  *	improper grouping on-the-fly, so we have to make another pass over the
982  *	query for that.
983  */
984 void
parseCheckAggregates(ParseState * pstate,Query * qry)985 parseCheckAggregates(ParseState *pstate, Query *qry)
986 {
987 	List	   *gset_common = NIL;
988 	List	   *groupClauses = NIL;
989 	List	   *groupClauseCommonVars = NIL;
990 	bool		have_non_var_grouping;
991 	List	   *func_grouped_rels = NIL;
992 	ListCell   *l;
993 	bool		hasJoinRTEs;
994 	bool		hasSelfRefRTEs;
995 	PlannerInfo *root = NULL;
996 	Node	   *clause;
997 
998 	/* This should only be called if we found aggregates or grouping */
999 	Assert(pstate->p_hasAggs || qry->groupClause || qry->havingQual || qry->groupingSets);
1000 
1001 	/*
1002 	 * If we have grouping sets, expand them and find the intersection of all
1003 	 * sets.
1004 	 */
1005 	if (qry->groupingSets)
1006 	{
1007 		/*
1008 		 * The limit of 4096 is arbitrary and exists simply to avoid resource
1009 		 * issues from pathological constructs.
1010 		 */
1011 		List	   *gsets = expand_grouping_sets(qry->groupingSets, 4096);
1012 
1013 		if (!gsets)
1014 			ereport(ERROR,
1015 					(errcode(ERRCODE_STATEMENT_TOO_COMPLEX),
1016 					 errmsg("too many grouping sets present (maximum 4096)"),
1017 					 parser_errposition(pstate,
1018 										qry->groupClause
1019 										? exprLocation((Node *) qry->groupClause)
1020 										: exprLocation((Node *) qry->groupingSets))));
1021 
1022 		/*
1023 		 * The intersection will often be empty, so help things along by
1024 		 * seeding the intersect with the smallest set.
1025 		 */
1026 		gset_common = linitial(gsets);
1027 
1028 		if (gset_common)
1029 		{
1030 			for_each_cell(l, lnext(list_head(gsets)))
1031 			{
1032 				gset_common = list_intersection_int(gset_common, lfirst(l));
1033 				if (!gset_common)
1034 					break;
1035 			}
1036 		}
1037 
1038 		/*
1039 		 * If there was only one grouping set in the expansion, AND if the
1040 		 * groupClause is non-empty (meaning that the grouping set is not
1041 		 * empty either), then we can ditch the grouping set and pretend we
1042 		 * just had a normal GROUP BY.
1043 		 */
1044 		if (list_length(gsets) == 1 && qry->groupClause)
1045 			qry->groupingSets = NIL;
1046 	}
1047 
1048 	/*
1049 	 * Scan the range table to see if there are JOIN or self-reference CTE
1050 	 * entries.  We'll need this info below.
1051 	 */
1052 	hasJoinRTEs = hasSelfRefRTEs = false;
1053 	foreach(l, pstate->p_rtable)
1054 	{
1055 		RangeTblEntry *rte = (RangeTblEntry *) lfirst(l);
1056 
1057 		if (rte->rtekind == RTE_JOIN)
1058 			hasJoinRTEs = true;
1059 		else if (rte->rtekind == RTE_CTE && rte->self_reference)
1060 			hasSelfRefRTEs = true;
1061 	}
1062 
1063 	/*
1064 	 * Build a list of the acceptable GROUP BY expressions for use by
1065 	 * check_ungrouped_columns().
1066 	 *
1067 	 * We get the TLE, not just the expr, because GROUPING wants to know the
1068 	 * sortgroupref.
1069 	 */
1070 	foreach(l, qry->groupClause)
1071 	{
1072 		SortGroupClause *grpcl = (SortGroupClause *) lfirst(l);
1073 		TargetEntry *expr;
1074 
1075 		expr = get_sortgroupclause_tle(grpcl, qry->targetList);
1076 		if (expr == NULL)
1077 			continue;			/* probably cannot happen */
1078 
1079 		groupClauses = lcons(expr, groupClauses);
1080 	}
1081 
1082 	/*
1083 	 * If there are join alias vars involved, we have to flatten them to the
1084 	 * underlying vars, so that aliased and unaliased vars will be correctly
1085 	 * taken as equal.  We can skip the expense of doing this if no rangetable
1086 	 * entries are RTE_JOIN kind. We use the planner's flatten_join_alias_vars
1087 	 * routine to do the flattening; it wants a PlannerInfo root node, which
1088 	 * fortunately can be mostly dummy.
1089 	 */
1090 	if (hasJoinRTEs)
1091 	{
1092 		root = makeNode(PlannerInfo);
1093 		root->parse = qry;
1094 		root->planner_cxt = CurrentMemoryContext;
1095 		root->hasJoinRTEs = true;
1096 
1097 		groupClauses = (List *) flatten_join_alias_vars(root,
1098 														(Node *) groupClauses);
1099 	}
1100 
1101 	/*
1102 	 * Detect whether any of the grouping expressions aren't simple Vars; if
1103 	 * they're all Vars then we don't have to work so hard in the recursive
1104 	 * scans.  (Note we have to flatten aliases before this.)
1105 	 *
1106 	 * Track Vars that are included in all grouping sets separately in
1107 	 * groupClauseCommonVars, since these are the only ones we can use to
1108 	 * check for functional dependencies.
1109 	 */
1110 	have_non_var_grouping = false;
1111 	foreach(l, groupClauses)
1112 	{
1113 		TargetEntry *tle = lfirst(l);
1114 
1115 		if (!IsA(tle->expr, Var))
1116 		{
1117 			have_non_var_grouping = true;
1118 		}
1119 		else if (!qry->groupingSets ||
1120 				 list_member_int(gset_common, tle->ressortgroupref))
1121 		{
1122 			groupClauseCommonVars = lappend(groupClauseCommonVars, tle->expr);
1123 		}
1124 	}
1125 
1126 	/*
1127 	 * Check the targetlist and HAVING clause for ungrouped variables.
1128 	 *
1129 	 * Note: because we check resjunk tlist elements as well as regular ones,
1130 	 * this will also find ungrouped variables that came from ORDER BY and
1131 	 * WINDOW clauses.  For that matter, it's also going to examine the
1132 	 * grouping expressions themselves --- but they'll all pass the test ...
1133 	 *
1134 	 * We also finalize GROUPING expressions, but for that we need to traverse
1135 	 * the original (unflattened) clause in order to modify nodes.
1136 	 */
1137 	clause = (Node *) qry->targetList;
1138 	finalize_grouping_exprs(clause, pstate, qry,
1139 							groupClauses, root,
1140 							have_non_var_grouping);
1141 	if (hasJoinRTEs)
1142 		clause = flatten_join_alias_vars(root, clause);
1143 	check_ungrouped_columns(clause, pstate, qry,
1144 							groupClauses, groupClauseCommonVars,
1145 							have_non_var_grouping,
1146 							&func_grouped_rels);
1147 
1148 	clause = (Node *) qry->havingQual;
1149 	finalize_grouping_exprs(clause, pstate, qry,
1150 							groupClauses, root,
1151 							have_non_var_grouping);
1152 	if (hasJoinRTEs)
1153 		clause = flatten_join_alias_vars(root, clause);
1154 	check_ungrouped_columns(clause, pstate, qry,
1155 							groupClauses, groupClauseCommonVars,
1156 							have_non_var_grouping,
1157 							&func_grouped_rels);
1158 
1159 	/*
1160 	 * Per spec, aggregates can't appear in a recursive term.
1161 	 */
1162 	if (pstate->p_hasAggs && hasSelfRefRTEs)
1163 		ereport(ERROR,
1164 				(errcode(ERRCODE_INVALID_RECURSION),
1165 				 errmsg("aggregate functions are not allowed in a recursive query's recursive term"),
1166 				 parser_errposition(pstate,
1167 									locate_agg_of_level((Node *) qry, 0))));
1168 }
1169 
1170 /*
1171  * check_ungrouped_columns -
1172  *	  Scan the given expression tree for ungrouped variables (variables
1173  *	  that are not listed in the groupClauses list and are not within
1174  *	  the arguments of aggregate functions).  Emit a suitable error message
1175  *	  if any are found.
1176  *
1177  * NOTE: we assume that the given clause has been transformed suitably for
1178  * parser output.  This means we can use expression_tree_walker.
1179  *
1180  * NOTE: we recognize grouping expressions in the main query, but only
1181  * grouping Vars in subqueries.  For example, this will be rejected,
1182  * although it could be allowed:
1183  *		SELECT
1184  *			(SELECT x FROM bar where y = (foo.a + foo.b))
1185  *		FROM foo
1186  *		GROUP BY a + b;
1187  * The difficulty is the need to account for different sublevels_up.
1188  * This appears to require a whole custom version of equal(), which is
1189  * way more pain than the feature seems worth.
1190  */
1191 static void
check_ungrouped_columns(Node * node,ParseState * pstate,Query * qry,List * groupClauses,List * groupClauseCommonVars,bool have_non_var_grouping,List ** func_grouped_rels)1192 check_ungrouped_columns(Node *node, ParseState *pstate, Query *qry,
1193 						List *groupClauses, List *groupClauseCommonVars,
1194 						bool have_non_var_grouping,
1195 						List **func_grouped_rels)
1196 {
1197 	check_ungrouped_columns_context context;
1198 
1199 	context.pstate = pstate;
1200 	context.qry = qry;
1201 	context.root = NULL;
1202 	context.groupClauses = groupClauses;
1203 	context.groupClauseCommonVars = groupClauseCommonVars;
1204 	context.have_non_var_grouping = have_non_var_grouping;
1205 	context.func_grouped_rels = func_grouped_rels;
1206 	context.sublevels_up = 0;
1207 	context.in_agg_direct_args = false;
1208 	check_ungrouped_columns_walker(node, &context);
1209 }
1210 
1211 static bool
check_ungrouped_columns_walker(Node * node,check_ungrouped_columns_context * context)1212 check_ungrouped_columns_walker(Node *node,
1213 							   check_ungrouped_columns_context *context)
1214 {
1215 	ListCell   *gl;
1216 
1217 	if (node == NULL)
1218 		return false;
1219 	if (IsA(node, Const) ||
1220 		IsA(node, Param))
1221 		return false;			/* constants are always acceptable */
1222 
1223 	if (IsA(node, Aggref))
1224 	{
1225 		Aggref	   *agg = (Aggref *) node;
1226 
1227 		if ((int) agg->agglevelsup == context->sublevels_up)
1228 		{
1229 			/*
1230 			 * If we find an aggregate call of the original level, do not
1231 			 * recurse into its normal arguments, ORDER BY arguments, or
1232 			 * filter; ungrouped vars there are not an error.  But we should
1233 			 * check direct arguments as though they weren't in an aggregate.
1234 			 * We set a special flag in the context to help produce a useful
1235 			 * error message for ungrouped vars in direct arguments.
1236 			 */
1237 			bool		result;
1238 
1239 			Assert(!context->in_agg_direct_args);
1240 			context->in_agg_direct_args = true;
1241 			result = check_ungrouped_columns_walker((Node *) agg->aggdirectargs,
1242 													context);
1243 			context->in_agg_direct_args = false;
1244 			return result;
1245 		}
1246 
1247 		/*
1248 		 * We can skip recursing into aggregates of higher levels altogether,
1249 		 * since they could not possibly contain Vars of concern to us (see
1250 		 * transformAggregateCall).  We do need to look at aggregates of lower
1251 		 * levels, however.
1252 		 */
1253 		if ((int) agg->agglevelsup > context->sublevels_up)
1254 			return false;
1255 	}
1256 
1257 	if (IsA(node, GroupingFunc))
1258 	{
1259 		GroupingFunc *grp = (GroupingFunc *) node;
1260 
1261 		/* handled GroupingFunc separately, no need to recheck at this level */
1262 
1263 		if ((int) grp->agglevelsup >= context->sublevels_up)
1264 			return false;
1265 	}
1266 
1267 	/*
1268 	 * If we have any GROUP BY items that are not simple Vars, check to see if
1269 	 * subexpression as a whole matches any GROUP BY item. We need to do this
1270 	 * at every recursion level so that we recognize GROUPed-BY expressions
1271 	 * before reaching variables within them. But this only works at the outer
1272 	 * query level, as noted above.
1273 	 */
1274 	if (context->have_non_var_grouping && context->sublevels_up == 0)
1275 	{
1276 		foreach(gl, context->groupClauses)
1277 		{
1278 			TargetEntry *tle = lfirst(gl);
1279 
1280 			if (equal(node, tle->expr))
1281 				return false;	/* acceptable, do not descend more */
1282 		}
1283 	}
1284 
1285 	/*
1286 	 * If we have an ungrouped Var of the original query level, we have a
1287 	 * failure.  Vars below the original query level are not a problem, and
1288 	 * neither are Vars from above it.  (If such Vars are ungrouped as far as
1289 	 * their own query level is concerned, that's someone else's problem...)
1290 	 */
1291 	if (IsA(node, Var))
1292 	{
1293 		Var		   *var = (Var *) node;
1294 		RangeTblEntry *rte;
1295 		char	   *attname;
1296 
1297 		if (var->varlevelsup != context->sublevels_up)
1298 			return false;		/* it's not local to my query, ignore */
1299 
1300 		/*
1301 		 * Check for a match, if we didn't do it above.
1302 		 */
1303 		if (!context->have_non_var_grouping || context->sublevels_up != 0)
1304 		{
1305 			foreach(gl, context->groupClauses)
1306 			{
1307 				Var		   *gvar = (Var *) ((TargetEntry *) lfirst(gl))->expr;
1308 
1309 				if (IsA(gvar, Var) &&
1310 					gvar->varno == var->varno &&
1311 					gvar->varattno == var->varattno &&
1312 					gvar->varlevelsup == 0)
1313 					return false;	/* acceptable, we're okay */
1314 			}
1315 		}
1316 
1317 		/*
1318 		 * Check whether the Var is known functionally dependent on the GROUP
1319 		 * BY columns.  If so, we can allow the Var to be used, because the
1320 		 * grouping is really a no-op for this table.  However, this deduction
1321 		 * depends on one or more constraints of the table, so we have to add
1322 		 * those constraints to the query's constraintDeps list, because it's
1323 		 * not semantically valid anymore if the constraint(s) get dropped.
1324 		 * (Therefore, this check must be the last-ditch effort before raising
1325 		 * error: we don't want to add dependencies unnecessarily.)
1326 		 *
1327 		 * Because this is a pretty expensive check, and will have the same
1328 		 * outcome for all columns of a table, we remember which RTEs we've
1329 		 * already proven functional dependency for in the func_grouped_rels
1330 		 * list.  This test also prevents us from adding duplicate entries to
1331 		 * the constraintDeps list.
1332 		 */
1333 		if (list_member_int(*context->func_grouped_rels, var->varno))
1334 			return false;		/* previously proven acceptable */
1335 
1336 		Assert(var->varno > 0 &&
1337 			   (int) var->varno <= list_length(context->pstate->p_rtable));
1338 		rte = rt_fetch(var->varno, context->pstate->p_rtable);
1339 		if (rte->rtekind == RTE_RELATION)
1340 		{
1341 			if (check_functional_grouping(rte->relid,
1342 										  var->varno,
1343 										  0,
1344 										  context->groupClauseCommonVars,
1345 										  &context->qry->constraintDeps))
1346 			{
1347 				*context->func_grouped_rels =
1348 					lappend_int(*context->func_grouped_rels, var->varno);
1349 				return false;	/* acceptable */
1350 			}
1351 		}
1352 
1353 		/* Found an ungrouped local variable; generate error message */
1354 		attname = get_rte_attribute_name(rte, var->varattno);
1355 		if (context->sublevels_up == 0)
1356 			ereport(ERROR,
1357 					(errcode(ERRCODE_GROUPING_ERROR),
1358 					 errmsg("column \"%s.%s\" must appear in the GROUP BY clause or be used in an aggregate function",
1359 							rte->eref->aliasname, attname),
1360 					 context->in_agg_direct_args ?
1361 					 errdetail("Direct arguments of an ordered-set aggregate must use only grouped columns.") : 0,
1362 					 parser_errposition(context->pstate, var->location)));
1363 		else
1364 			ereport(ERROR,
1365 					(errcode(ERRCODE_GROUPING_ERROR),
1366 					 errmsg("subquery uses ungrouped column \"%s.%s\" from outer query",
1367 							rte->eref->aliasname, attname),
1368 					 parser_errposition(context->pstate, var->location)));
1369 	}
1370 
1371 	if (IsA(node, Query))
1372 	{
1373 		/* Recurse into subselects */
1374 		bool		result;
1375 
1376 		context->sublevels_up++;
1377 		result = query_tree_walker((Query *) node,
1378 								   check_ungrouped_columns_walker,
1379 								   (void *) context,
1380 								   0);
1381 		context->sublevels_up--;
1382 		return result;
1383 	}
1384 	return expression_tree_walker(node, check_ungrouped_columns_walker,
1385 								  (void *) context);
1386 }
1387 
1388 /*
1389  * finalize_grouping_exprs -
1390  *	  Scan the given expression tree for GROUPING() and related calls,
1391  *	  and validate and process their arguments.
1392  *
1393  * This is split out from check_ungrouped_columns above because it needs
1394  * to modify the nodes (which it does in-place, not via a mutator) while
1395  * check_ungrouped_columns may see only a copy of the original thanks to
1396  * flattening of join alias vars. So here, we flatten each individual
1397  * GROUPING argument as we see it before comparing it.
1398  */
1399 static void
finalize_grouping_exprs(Node * node,ParseState * pstate,Query * qry,List * groupClauses,PlannerInfo * root,bool have_non_var_grouping)1400 finalize_grouping_exprs(Node *node, ParseState *pstate, Query *qry,
1401 						List *groupClauses, PlannerInfo *root,
1402 						bool have_non_var_grouping)
1403 {
1404 	check_ungrouped_columns_context context;
1405 
1406 	context.pstate = pstate;
1407 	context.qry = qry;
1408 	context.root = root;
1409 	context.groupClauses = groupClauses;
1410 	context.groupClauseCommonVars = NIL;
1411 	context.have_non_var_grouping = have_non_var_grouping;
1412 	context.func_grouped_rels = NULL;
1413 	context.sublevels_up = 0;
1414 	context.in_agg_direct_args = false;
1415 	finalize_grouping_exprs_walker(node, &context);
1416 }
1417 
1418 static bool
finalize_grouping_exprs_walker(Node * node,check_ungrouped_columns_context * context)1419 finalize_grouping_exprs_walker(Node *node,
1420 							   check_ungrouped_columns_context *context)
1421 {
1422 	ListCell   *gl;
1423 
1424 	if (node == NULL)
1425 		return false;
1426 	if (IsA(node, Const) ||
1427 		IsA(node, Param))
1428 		return false;			/* constants are always acceptable */
1429 
1430 	if (IsA(node, Aggref))
1431 	{
1432 		Aggref	   *agg = (Aggref *) node;
1433 
1434 		if ((int) agg->agglevelsup == context->sublevels_up)
1435 		{
1436 			/*
1437 			 * If we find an aggregate call of the original level, do not
1438 			 * recurse into its normal arguments, ORDER BY arguments, or
1439 			 * filter; GROUPING exprs of this level are not allowed there. But
1440 			 * check direct arguments as though they weren't in an aggregate.
1441 			 */
1442 			bool		result;
1443 
1444 			Assert(!context->in_agg_direct_args);
1445 			context->in_agg_direct_args = true;
1446 			result = finalize_grouping_exprs_walker((Node *) agg->aggdirectargs,
1447 													context);
1448 			context->in_agg_direct_args = false;
1449 			return result;
1450 		}
1451 
1452 		/*
1453 		 * We can skip recursing into aggregates of higher levels altogether,
1454 		 * since they could not possibly contain exprs of concern to us (see
1455 		 * transformAggregateCall).  We do need to look at aggregates of lower
1456 		 * levels, however.
1457 		 */
1458 		if ((int) agg->agglevelsup > context->sublevels_up)
1459 			return false;
1460 	}
1461 
1462 	if (IsA(node, GroupingFunc))
1463 	{
1464 		GroupingFunc *grp = (GroupingFunc *) node;
1465 
1466 		/*
1467 		 * We only need to check GroupingFunc nodes at the exact level to
1468 		 * which they belong, since they cannot mix levels in arguments.
1469 		 */
1470 
1471 		if ((int) grp->agglevelsup == context->sublevels_up)
1472 		{
1473 			ListCell   *lc;
1474 			List	   *ref_list = NIL;
1475 
1476 			foreach(lc, grp->args)
1477 			{
1478 				Node	   *expr = lfirst(lc);
1479 				Index		ref = 0;
1480 
1481 				if (context->root)
1482 					expr = flatten_join_alias_vars(context->root, expr);
1483 
1484 				/*
1485 				 * Each expression must match a grouping entry at the current
1486 				 * query level. Unlike the general expression case, we don't
1487 				 * allow functional dependencies or outer references.
1488 				 */
1489 
1490 				if (IsA(expr, Var))
1491 				{
1492 					Var		   *var = (Var *) expr;
1493 
1494 					if (var->varlevelsup == context->sublevels_up)
1495 					{
1496 						foreach(gl, context->groupClauses)
1497 						{
1498 							TargetEntry *tle = lfirst(gl);
1499 							Var		   *gvar = (Var *) tle->expr;
1500 
1501 							if (IsA(gvar, Var) &&
1502 								gvar->varno == var->varno &&
1503 								gvar->varattno == var->varattno &&
1504 								gvar->varlevelsup == 0)
1505 							{
1506 								ref = tle->ressortgroupref;
1507 								break;
1508 							}
1509 						}
1510 					}
1511 				}
1512 				else if (context->have_non_var_grouping &&
1513 						 context->sublevels_up == 0)
1514 				{
1515 					foreach(gl, context->groupClauses)
1516 					{
1517 						TargetEntry *tle = lfirst(gl);
1518 
1519 						if (equal(expr, tle->expr))
1520 						{
1521 							ref = tle->ressortgroupref;
1522 							break;
1523 						}
1524 					}
1525 				}
1526 
1527 				if (ref == 0)
1528 					ereport(ERROR,
1529 							(errcode(ERRCODE_GROUPING_ERROR),
1530 							 errmsg("arguments to GROUPING must be grouping expressions of the associated query level"),
1531 							 parser_errposition(context->pstate,
1532 												exprLocation(expr))));
1533 
1534 				ref_list = lappend_int(ref_list, ref);
1535 			}
1536 
1537 			grp->refs = ref_list;
1538 		}
1539 
1540 		if ((int) grp->agglevelsup > context->sublevels_up)
1541 			return false;
1542 	}
1543 
1544 	if (IsA(node, Query))
1545 	{
1546 		/* Recurse into subselects */
1547 		bool		result;
1548 
1549 		context->sublevels_up++;
1550 		result = query_tree_walker((Query *) node,
1551 								   finalize_grouping_exprs_walker,
1552 								   (void *) context,
1553 								   0);
1554 		context->sublevels_up--;
1555 		return result;
1556 	}
1557 	return expression_tree_walker(node, finalize_grouping_exprs_walker,
1558 								  (void *) context);
1559 }
1560 
1561 
1562 /*
1563  * Given a GroupingSet node, expand it and return a list of lists.
1564  *
1565  * For EMPTY nodes, return a list of one empty list.
1566  *
1567  * For SIMPLE nodes, return a list of one list, which is the node content.
1568  *
1569  * For CUBE and ROLLUP nodes, return a list of the expansions.
1570  *
1571  * For SET nodes, recursively expand contained CUBE and ROLLUP.
1572  */
1573 static List *
expand_groupingset_node(GroupingSet * gs)1574 expand_groupingset_node(GroupingSet *gs)
1575 {
1576 	List	   *result = NIL;
1577 
1578 	switch (gs->kind)
1579 	{
1580 		case GROUPING_SET_EMPTY:
1581 			result = list_make1(NIL);
1582 			break;
1583 
1584 		case GROUPING_SET_SIMPLE:
1585 			result = list_make1(gs->content);
1586 			break;
1587 
1588 		case GROUPING_SET_ROLLUP:
1589 			{
1590 				List	   *rollup_val = gs->content;
1591 				ListCell   *lc;
1592 				int			curgroup_size = list_length(gs->content);
1593 
1594 				while (curgroup_size > 0)
1595 				{
1596 					List	   *current_result = NIL;
1597 					int			i = curgroup_size;
1598 
1599 					foreach(lc, rollup_val)
1600 					{
1601 						GroupingSet *gs_current = (GroupingSet *) lfirst(lc);
1602 
1603 						Assert(gs_current->kind == GROUPING_SET_SIMPLE);
1604 
1605 						current_result
1606 							= list_concat(current_result,
1607 										  list_copy(gs_current->content));
1608 
1609 						/* If we are done with making the current group, break */
1610 						if (--i == 0)
1611 							break;
1612 					}
1613 
1614 					result = lappend(result, current_result);
1615 					--curgroup_size;
1616 				}
1617 
1618 				result = lappend(result, NIL);
1619 			}
1620 			break;
1621 
1622 		case GROUPING_SET_CUBE:
1623 			{
1624 				List	   *cube_list = gs->content;
1625 				int			number_bits = list_length(cube_list);
1626 				uint32		num_sets;
1627 				uint32		i;
1628 
1629 				/* parser should cap this much lower */
1630 				Assert(number_bits < 31);
1631 
1632 				num_sets = (1U << number_bits);
1633 
1634 				for (i = 0; i < num_sets; i++)
1635 				{
1636 					List	   *current_result = NIL;
1637 					ListCell   *lc;
1638 					uint32		mask = 1U;
1639 
1640 					foreach(lc, cube_list)
1641 					{
1642 						GroupingSet *gs_current = (GroupingSet *) lfirst(lc);
1643 
1644 						Assert(gs_current->kind == GROUPING_SET_SIMPLE);
1645 
1646 						if (mask & i)
1647 						{
1648 							current_result
1649 								= list_concat(current_result,
1650 											  list_copy(gs_current->content));
1651 						}
1652 
1653 						mask <<= 1;
1654 					}
1655 
1656 					result = lappend(result, current_result);
1657 				}
1658 			}
1659 			break;
1660 
1661 		case GROUPING_SET_SETS:
1662 			{
1663 				ListCell   *lc;
1664 
1665 				foreach(lc, gs->content)
1666 				{
1667 					List	   *current_result = expand_groupingset_node(lfirst(lc));
1668 
1669 					result = list_concat(result, current_result);
1670 				}
1671 			}
1672 			break;
1673 	}
1674 
1675 	return result;
1676 }
1677 
1678 static int
cmp_list_len_asc(const void * a,const void * b)1679 cmp_list_len_asc(const void *a, const void *b)
1680 {
1681 	int			la = list_length(*(List *const *) a);
1682 	int			lb = list_length(*(List *const *) b);
1683 
1684 	return (la > lb) ? 1 : (la == lb) ? 0 : -1;
1685 }
1686 
1687 /*
1688  * Expand a groupingSets clause to a flat list of grouping sets.
1689  * The returned list is sorted by length, shortest sets first.
1690  *
1691  * This is mainly for the planner, but we use it here too to do
1692  * some consistency checks.
1693  */
1694 List *
expand_grouping_sets(List * groupingSets,int limit)1695 expand_grouping_sets(List *groupingSets, int limit)
1696 {
1697 	List	   *expanded_groups = NIL;
1698 	List	   *result = NIL;
1699 	double		numsets = 1;
1700 	ListCell   *lc;
1701 
1702 	if (groupingSets == NIL)
1703 		return NIL;
1704 
1705 	foreach(lc, groupingSets)
1706 	{
1707 		List	   *current_result = NIL;
1708 		GroupingSet *gs = lfirst(lc);
1709 
1710 		current_result = expand_groupingset_node(gs);
1711 
1712 		Assert(current_result != NIL);
1713 
1714 		numsets *= list_length(current_result);
1715 
1716 		if (limit >= 0 && numsets > limit)
1717 			return NIL;
1718 
1719 		expanded_groups = lappend(expanded_groups, current_result);
1720 	}
1721 
1722 	/*
1723 	 * Do cartesian product between sublists of expanded_groups. While at it,
1724 	 * remove any duplicate elements from individual grouping sets (we must
1725 	 * NOT change the number of sets though)
1726 	 */
1727 
1728 	foreach(lc, (List *) linitial(expanded_groups))
1729 	{
1730 		result = lappend(result, list_union_int(NIL, (List *) lfirst(lc)));
1731 	}
1732 
1733 	for_each_cell(lc, lnext(list_head(expanded_groups)))
1734 	{
1735 		List	   *p = lfirst(lc);
1736 		List	   *new_result = NIL;
1737 		ListCell   *lc2;
1738 
1739 		foreach(lc2, result)
1740 		{
1741 			List	   *q = lfirst(lc2);
1742 			ListCell   *lc3;
1743 
1744 			foreach(lc3, p)
1745 			{
1746 				new_result = lappend(new_result,
1747 									 list_union_int(q, (List *) lfirst(lc3)));
1748 			}
1749 		}
1750 		result = new_result;
1751 	}
1752 
1753 	if (list_length(result) > 1)
1754 	{
1755 		int			result_len = list_length(result);
1756 		List	  **buf = palloc(sizeof(List *) * result_len);
1757 		List	  **ptr = buf;
1758 
1759 		foreach(lc, result)
1760 		{
1761 			*ptr++ = lfirst(lc);
1762 		}
1763 
1764 		qsort(buf, result_len, sizeof(List *), cmp_list_len_asc);
1765 
1766 		result = NIL;
1767 		ptr = buf;
1768 
1769 		while (result_len-- > 0)
1770 			result = lappend(result, *ptr++);
1771 
1772 		pfree(buf);
1773 	}
1774 
1775 	return result;
1776 }
1777 
1778 /*
1779  * get_aggregate_argtypes
1780  *	Identify the specific datatypes passed to an aggregate call.
1781  *
1782  * Given an Aggref, extract the actual datatypes of the input arguments.
1783  * The input datatypes are reported in a way that matches up with the
1784  * aggregate's declaration, ie, any ORDER BY columns attached to a plain
1785  * aggregate are ignored, but we report both direct and aggregated args of
1786  * an ordered-set aggregate.
1787  *
1788  * Datatypes are returned into inputTypes[], which must reference an array
1789  * of length FUNC_MAX_ARGS.
1790  *
1791  * The function result is the number of actual arguments.
1792  */
1793 int
get_aggregate_argtypes(Aggref * aggref,Oid * inputTypes)1794 get_aggregate_argtypes(Aggref *aggref, Oid *inputTypes)
1795 {
1796 	int			numArguments = 0;
1797 	ListCell   *lc;
1798 
1799 	Assert(list_length(aggref->aggargtypes) <= FUNC_MAX_ARGS);
1800 
1801 	foreach(lc, aggref->aggargtypes)
1802 	{
1803 		inputTypes[numArguments++] = lfirst_oid(lc);
1804 	}
1805 
1806 	return numArguments;
1807 }
1808 
1809 /*
1810  * resolve_aggregate_transtype
1811  *	Identify the transition state value's datatype for an aggregate call.
1812  *
1813  * This function resolves a polymorphic aggregate's state datatype.
1814  * It must be passed the aggtranstype from the aggregate's catalog entry,
1815  * as well as the actual argument types extracted by get_aggregate_argtypes.
1816  * (We could fetch pg_aggregate.aggtranstype internally, but all existing
1817  * callers already have the value at hand, so we make them pass it.)
1818  */
1819 Oid
resolve_aggregate_transtype(Oid aggfuncid,Oid aggtranstype,Oid * inputTypes,int numArguments)1820 resolve_aggregate_transtype(Oid aggfuncid,
1821 							Oid aggtranstype,
1822 							Oid *inputTypes,
1823 							int numArguments)
1824 {
1825 	/* resolve actual type of transition state, if polymorphic */
1826 	if (IsPolymorphicType(aggtranstype))
1827 	{
1828 		/* have to fetch the agg's declared input types... */
1829 		Oid		   *declaredArgTypes;
1830 		int			agg_nargs;
1831 
1832 		(void) get_func_signature(aggfuncid, &declaredArgTypes, &agg_nargs);
1833 
1834 		/*
1835 		 * VARIADIC ANY aggs could have more actual than declared args, but
1836 		 * such extra args can't affect polymorphic type resolution.
1837 		 */
1838 		Assert(agg_nargs <= numArguments);
1839 
1840 		aggtranstype = enforce_generic_type_consistency(inputTypes,
1841 														declaredArgTypes,
1842 														agg_nargs,
1843 														aggtranstype,
1844 														false);
1845 		pfree(declaredArgTypes);
1846 	}
1847 	return aggtranstype;
1848 }
1849 
1850 /*
1851  * Create an expression tree for the transition function of an aggregate.
1852  * This is needed so that polymorphic functions can be used within an
1853  * aggregate --- without the expression tree, such functions would not know
1854  * the datatypes they are supposed to use.  (The trees will never actually
1855  * be executed, however, so we can skimp a bit on correctness.)
1856  *
 * agg_input_types and agg_state_type identify the input types and the state
 * type of the aggregate.  These should be resolved to actual types (ie, none
 * should ever be ANYELEMENT etc).
1860  * agg_input_collation is the aggregate function's input collation.
1861  *
1862  * For an ordered-set aggregate, remember that agg_input_types describes
1863  * the direct arguments followed by the aggregated arguments.
1864  *
1865  * transfn_oid and invtransfn_oid identify the funcs to be called; the
1866  * latter may be InvalidOid, however if invtransfn_oid is set then
1867  * transfn_oid must also be set.
1868  *
1869  * Pointers to the constructed trees are returned into *transfnexpr,
1870  * *invtransfnexpr. If there is no invtransfn, the respective pointer is set
1871  * to NULL.  Since use of the invtransfn is optional, NULL may be passed for
1872  * invtransfnexpr.
1873  */
1874 void
build_aggregate_transfn_expr(Oid * agg_input_types,int agg_num_inputs,int agg_num_direct_inputs,bool agg_variadic,Oid agg_state_type,Oid agg_input_collation,Oid transfn_oid,Oid invtransfn_oid,Expr ** transfnexpr,Expr ** invtransfnexpr)1875 build_aggregate_transfn_expr(Oid *agg_input_types,
1876 							 int agg_num_inputs,
1877 							 int agg_num_direct_inputs,
1878 							 bool agg_variadic,
1879 							 Oid agg_state_type,
1880 							 Oid agg_input_collation,
1881 							 Oid transfn_oid,
1882 							 Oid invtransfn_oid,
1883 							 Expr **transfnexpr,
1884 							 Expr **invtransfnexpr)
1885 {
1886 	List	   *args;
1887 	FuncExpr   *fexpr;
1888 	int			i;
1889 
1890 	/*
1891 	 * Build arg list to use in the transfn FuncExpr node.
1892 	 */
1893 	args = list_make1(make_agg_arg(agg_state_type, agg_input_collation));
1894 
1895 	for (i = agg_num_direct_inputs; i < agg_num_inputs; i++)
1896 	{
1897 		args = lappend(args,
1898 					   make_agg_arg(agg_input_types[i], agg_input_collation));
1899 	}
1900 
1901 	fexpr = makeFuncExpr(transfn_oid,
1902 						 agg_state_type,
1903 						 args,
1904 						 InvalidOid,
1905 						 agg_input_collation,
1906 						 COERCE_EXPLICIT_CALL);
1907 	fexpr->funcvariadic = agg_variadic;
1908 	*transfnexpr = (Expr *) fexpr;
1909 
1910 	/*
1911 	 * Build invtransfn expression if requested, with same args as transfn
1912 	 */
1913 	if (invtransfnexpr != NULL)
1914 	{
1915 		if (OidIsValid(invtransfn_oid))
1916 		{
1917 			fexpr = makeFuncExpr(invtransfn_oid,
1918 								 agg_state_type,
1919 								 args,
1920 								 InvalidOid,
1921 								 agg_input_collation,
1922 								 COERCE_EXPLICIT_CALL);
1923 			fexpr->funcvariadic = agg_variadic;
1924 			*invtransfnexpr = (Expr *) fexpr;
1925 		}
1926 		else
1927 			*invtransfnexpr = NULL;
1928 	}
1929 }
1930 
1931 /*
1932  * Like build_aggregate_transfn_expr, but creates an expression tree for the
1933  * combine function of an aggregate, rather than the transition function.
1934  */
1935 void
build_aggregate_combinefn_expr(Oid agg_state_type,Oid agg_input_collation,Oid combinefn_oid,Expr ** combinefnexpr)1936 build_aggregate_combinefn_expr(Oid agg_state_type,
1937 							   Oid agg_input_collation,
1938 							   Oid combinefn_oid,
1939 							   Expr **combinefnexpr)
1940 {
1941 	Node	   *argp;
1942 	List	   *args;
1943 	FuncExpr   *fexpr;
1944 
1945 	/* combinefn takes two arguments of the aggregate state type */
1946 	argp = make_agg_arg(agg_state_type, agg_input_collation);
1947 
1948 	args = list_make2(argp, argp);
1949 
1950 	fexpr = makeFuncExpr(combinefn_oid,
1951 						 agg_state_type,
1952 						 args,
1953 						 InvalidOid,
1954 						 agg_input_collation,
1955 						 COERCE_EXPLICIT_CALL);
1956 	/* combinefn is currently never treated as variadic */
1957 	*combinefnexpr = (Expr *) fexpr;
1958 }
1959 
1960 /*
1961  * Like build_aggregate_transfn_expr, but creates an expression tree for the
1962  * serialization function of an aggregate.
1963  */
1964 void
build_aggregate_serialfn_expr(Oid serialfn_oid,Expr ** serialfnexpr)1965 build_aggregate_serialfn_expr(Oid serialfn_oid,
1966 							  Expr **serialfnexpr)
1967 {
1968 	List	   *args;
1969 	FuncExpr   *fexpr;
1970 
1971 	/* serialfn always takes INTERNAL and returns BYTEA */
1972 	args = list_make1(make_agg_arg(INTERNALOID, InvalidOid));
1973 
1974 	fexpr = makeFuncExpr(serialfn_oid,
1975 						 BYTEAOID,
1976 						 args,
1977 						 InvalidOid,
1978 						 InvalidOid,
1979 						 COERCE_EXPLICIT_CALL);
1980 	*serialfnexpr = (Expr *) fexpr;
1981 }
1982 
1983 /*
1984  * Like build_aggregate_transfn_expr, but creates an expression tree for the
1985  * deserialization function of an aggregate.
1986  */
1987 void
build_aggregate_deserialfn_expr(Oid deserialfn_oid,Expr ** deserialfnexpr)1988 build_aggregate_deserialfn_expr(Oid deserialfn_oid,
1989 								Expr **deserialfnexpr)
1990 {
1991 	List	   *args;
1992 	FuncExpr   *fexpr;
1993 
1994 	/* deserialfn always takes BYTEA, INTERNAL and returns INTERNAL */
1995 	args = list_make2(make_agg_arg(BYTEAOID, InvalidOid),
1996 					  make_agg_arg(INTERNALOID, InvalidOid));
1997 
1998 	fexpr = makeFuncExpr(deserialfn_oid,
1999 						 INTERNALOID,
2000 						 args,
2001 						 InvalidOid,
2002 						 InvalidOid,
2003 						 COERCE_EXPLICIT_CALL);
2004 	*deserialfnexpr = (Expr *) fexpr;
2005 }
2006 
2007 /*
2008  * Like build_aggregate_transfn_expr, but creates an expression tree for the
2009  * final function of an aggregate, rather than the transition function.
2010  */
2011 void
build_aggregate_finalfn_expr(Oid * agg_input_types,int num_finalfn_inputs,Oid agg_state_type,Oid agg_result_type,Oid agg_input_collation,Oid finalfn_oid,Expr ** finalfnexpr)2012 build_aggregate_finalfn_expr(Oid *agg_input_types,
2013 							 int num_finalfn_inputs,
2014 							 Oid agg_state_type,
2015 							 Oid agg_result_type,
2016 							 Oid agg_input_collation,
2017 							 Oid finalfn_oid,
2018 							 Expr **finalfnexpr)
2019 {
2020 	List	   *args;
2021 	int			i;
2022 
2023 	/*
2024 	 * Build expr tree for final function
2025 	 */
2026 	args = list_make1(make_agg_arg(agg_state_type, agg_input_collation));
2027 
2028 	/* finalfn may take additional args, which match agg's input types */
2029 	for (i = 0; i < num_finalfn_inputs - 1; i++)
2030 	{
2031 		args = lappend(args,
2032 					   make_agg_arg(agg_input_types[i], agg_input_collation));
2033 	}
2034 
2035 	*finalfnexpr = (Expr *) makeFuncExpr(finalfn_oid,
2036 										 agg_result_type,
2037 										 args,
2038 										 InvalidOid,
2039 										 agg_input_collation,
2040 										 COERCE_EXPLICIT_CALL);
2041 	/* finalfn is currently never treated as variadic */
2042 }
2043 
2044 /*
2045  * Convenience function to build dummy argument expressions for aggregates.
2046  *
2047  * We really only care that an aggregate support function can discover its
2048  * actual argument types at runtime using get_fn_expr_argtype(), so it's okay
2049  * to use Param nodes that don't correspond to any real Param.
2050  */
2051 static Node *
make_agg_arg(Oid argtype,Oid argcollation)2052 make_agg_arg(Oid argtype, Oid argcollation)
2053 {
2054 	Param	   *argp = makeNode(Param);
2055 
2056 	argp->paramkind = PARAM_EXEC;
2057 	argp->paramid = -1;
2058 	argp->paramtype = argtype;
2059 	argp->paramtypmod = -1;
2060 	argp->paramcollid = argcollation;
2061 	argp->location = -1;
2062 	return (Node *) argp;
2063 }
2064