1 /*-------------------------------------------------------------------------
2  *
3  * parse_agg.c
4  *	  handle aggregates and window functions in parser
5  *
6  * Portions Copyright (c) 1996-2019, PostgreSQL Global Development Group
7  * Portions Copyright (c) 1994, Regents of the University of California
8  *
9  *
10  * IDENTIFICATION
11  *	  src/backend/parser/parse_agg.c
12  *
13  *-------------------------------------------------------------------------
14  */
15 #include "postgres.h"
16 
17 #include "catalog/pg_aggregate.h"
18 #include "catalog/pg_constraint.h"
19 #include "catalog/pg_type.h"
20 #include "nodes/makefuncs.h"
21 #include "nodes/nodeFuncs.h"
22 #include "optimizer/optimizer.h"
23 #include "parser/parse_agg.h"
24 #include "parser/parse_clause.h"
25 #include "parser/parse_coerce.h"
26 #include "parser/parse_expr.h"
27 #include "parser/parsetree.h"
28 #include "rewrite/rewriteManip.h"
29 #include "utils/builtins.h"
30 #include "utils/lsyscache.h"
31 
32 
33 typedef struct
34 {
35 	ParseState *pstate;
36 	int			min_varlevel;
37 	int			min_agglevel;
38 	int			sublevels_up;
39 } check_agg_arguments_context;
40 
41 typedef struct
42 {
43 	ParseState *pstate;
44 	Query	   *qry;
45 	bool		hasJoinRTEs;
46 	List	   *groupClauses;
47 	List	   *groupClauseCommonVars;
48 	bool		have_non_var_grouping;
49 	List	  **func_grouped_rels;
50 	int			sublevels_up;
51 	bool		in_agg_direct_args;
52 } check_ungrouped_columns_context;
53 
54 static int	check_agg_arguments(ParseState *pstate,
55 								List *directargs,
56 								List *args,
57 								Expr *filter);
58 static bool check_agg_arguments_walker(Node *node,
59 									   check_agg_arguments_context *context);
60 static void check_ungrouped_columns(Node *node, ParseState *pstate, Query *qry,
61 									List *groupClauses, List *groupClauseVars,
62 									bool have_non_var_grouping,
63 									List **func_grouped_rels);
64 static bool check_ungrouped_columns_walker(Node *node,
65 										   check_ungrouped_columns_context *context);
66 static void finalize_grouping_exprs(Node *node, ParseState *pstate, Query *qry,
67 									List *groupClauses, bool hasJoinRTEs,
68 									bool have_non_var_grouping);
69 static bool finalize_grouping_exprs_walker(Node *node,
70 										   check_ungrouped_columns_context *context);
71 static void check_agglevels_and_constraints(ParseState *pstate, Node *expr);
72 static List *expand_groupingset_node(GroupingSet *gs);
73 static Node *make_agg_arg(Oid argtype, Oid argcollation);
74 
75 
76 /*
77  * transformAggregateCall -
78  *		Finish initial transformation of an aggregate call
79  *
80  * parse_func.c has recognized the function as an aggregate, and has set up
81  * all the fields of the Aggref except aggargtypes, aggdirectargs, args,
82  * aggorder, aggdistinct and agglevelsup.  The passed-in args list has been
83  * through standard expression transformation and type coercion to match the
84  * agg's declared arg types, while the passed-in aggorder list hasn't been
85  * transformed at all.
86  *
87  * Here we separate the args list into direct and aggregated args, storing the
88  * former in agg->aggdirectargs and the latter in agg->args.  The regular
89  * args, but not the direct args, are converted into a targetlist by inserting
90  * TargetEntry nodes.  We then transform the aggorder and agg_distinct
91  * specifications to produce lists of SortGroupClause nodes for agg->aggorder
92  * and agg->aggdistinct.  (For a regular aggregate, this might result in
93  * adding resjunk expressions to the targetlist; but for ordered-set
94  * aggregates the aggorder list will always be one-to-one with the aggregated
95  * args.)
96  *
97  * We must also determine which query level the aggregate actually belongs to,
98  * set agglevelsup accordingly, and mark p_hasAggs true in the corresponding
99  * pstate level.
100  */
101 void
transformAggregateCall(ParseState * pstate,Aggref * agg,List * args,List * aggorder,bool agg_distinct)102 transformAggregateCall(ParseState *pstate, Aggref *agg,
103 					   List *args, List *aggorder, bool agg_distinct)
104 {
105 	List	   *argtypes = NIL;
106 	List	   *tlist = NIL;
107 	List	   *torder = NIL;
108 	List	   *tdistinct = NIL;
109 	AttrNumber	attno = 1;
110 	int			save_next_resno;
111 	ListCell   *lc;
112 
113 	/*
114 	 * Before separating the args into direct and aggregated args, make a list
115 	 * of their data type OIDs for use later.
116 	 */
117 	foreach(lc, args)
118 	{
119 		Expr	   *arg = (Expr *) lfirst(lc);
120 
121 		argtypes = lappend_oid(argtypes, exprType((Node *) arg));
122 	}
123 	agg->aggargtypes = argtypes;
124 
125 	if (AGGKIND_IS_ORDERED_SET(agg->aggkind))
126 	{
127 		/*
128 		 * For an ordered-set agg, the args list includes direct args and
129 		 * aggregated args; we must split them apart.
130 		 */
131 		int			numDirectArgs = list_length(args) - list_length(aggorder);
132 		List	   *aargs;
133 		ListCell   *lc2;
134 
135 		Assert(numDirectArgs >= 0);
136 
137 		aargs = list_copy_tail(args, numDirectArgs);
138 		agg->aggdirectargs = list_truncate(args, numDirectArgs);
139 
140 		/*
141 		 * Build a tlist from the aggregated args, and make a sortlist entry
142 		 * for each one.  Note that the expressions in the SortBy nodes are
143 		 * ignored (they are the raw versions of the transformed args); we are
144 		 * just looking at the sort information in the SortBy nodes.
145 		 */
146 		forboth(lc, aargs, lc2, aggorder)
147 		{
148 			Expr	   *arg = (Expr *) lfirst(lc);
149 			SortBy	   *sortby = (SortBy *) lfirst(lc2);
150 			TargetEntry *tle;
151 
152 			/* We don't bother to assign column names to the entries */
153 			tle = makeTargetEntry(arg, attno++, NULL, false);
154 			tlist = lappend(tlist, tle);
155 
156 			torder = addTargetToSortList(pstate, tle,
157 										 torder, tlist, sortby);
158 		}
159 
160 		/* Never any DISTINCT in an ordered-set agg */
161 		Assert(!agg_distinct);
162 	}
163 	else
164 	{
165 		/* Regular aggregate, so it has no direct args */
166 		agg->aggdirectargs = NIL;
167 
168 		/*
169 		 * Transform the plain list of Exprs into a targetlist.
170 		 */
171 		foreach(lc, args)
172 		{
173 			Expr	   *arg = (Expr *) lfirst(lc);
174 			TargetEntry *tle;
175 
176 			/* We don't bother to assign column names to the entries */
177 			tle = makeTargetEntry(arg, attno++, NULL, false);
178 			tlist = lappend(tlist, tle);
179 		}
180 
181 		/*
182 		 * If we have an ORDER BY, transform it.  This will add columns to the
183 		 * tlist if they appear in ORDER BY but weren't already in the arg
184 		 * list.  They will be marked resjunk = true so we can tell them apart
185 		 * from regular aggregate arguments later.
186 		 *
187 		 * We need to mess with p_next_resno since it will be used to number
188 		 * any new targetlist entries.
189 		 */
190 		save_next_resno = pstate->p_next_resno;
191 		pstate->p_next_resno = attno;
192 
193 		torder = transformSortClause(pstate,
194 									 aggorder,
195 									 &tlist,
196 									 EXPR_KIND_ORDER_BY,
197 									 true /* force SQL99 rules */ );
198 
199 		/*
200 		 * If we have DISTINCT, transform that to produce a distinctList.
201 		 */
202 		if (agg_distinct)
203 		{
204 			tdistinct = transformDistinctClause(pstate, &tlist, torder, true);
205 
206 			/*
207 			 * Remove this check if executor support for hashed distinct for
208 			 * aggregates is ever added.
209 			 */
210 			foreach(lc, tdistinct)
211 			{
212 				SortGroupClause *sortcl = (SortGroupClause *) lfirst(lc);
213 
214 				if (!OidIsValid(sortcl->sortop))
215 				{
216 					Node	   *expr = get_sortgroupclause_expr(sortcl, tlist);
217 
218 					ereport(ERROR,
219 							(errcode(ERRCODE_UNDEFINED_FUNCTION),
220 							 errmsg("could not identify an ordering operator for type %s",
221 									format_type_be(exprType(expr))),
222 							 errdetail("Aggregates with DISTINCT must be able to sort their inputs."),
223 							 parser_errposition(pstate, exprLocation(expr))));
224 				}
225 			}
226 		}
227 
228 		pstate->p_next_resno = save_next_resno;
229 	}
230 
231 	/* Update the Aggref with the transformation results */
232 	agg->args = tlist;
233 	agg->aggorder = torder;
234 	agg->aggdistinct = tdistinct;
235 
236 	check_agglevels_and_constraints(pstate, (Node *) agg);
237 }
238 
239 /*
240  * transformGroupingFunc
241  *		Transform a GROUPING expression
242  *
243  * GROUPING() behaves very like an aggregate.  Processing of levels and nesting
244  * is done as for aggregates.  We set p_hasAggs for these expressions too.
245  */
246 Node *
transformGroupingFunc(ParseState * pstate,GroupingFunc * p)247 transformGroupingFunc(ParseState *pstate, GroupingFunc *p)
248 {
249 	ListCell   *lc;
250 	List	   *args = p->args;
251 	List	   *result_list = NIL;
252 	GroupingFunc *result = makeNode(GroupingFunc);
253 
254 	if (list_length(args) > 31)
255 		ereport(ERROR,
256 				(errcode(ERRCODE_TOO_MANY_ARGUMENTS),
257 				 errmsg("GROUPING must have fewer than 32 arguments"),
258 				 parser_errposition(pstate, p->location)));
259 
260 	foreach(lc, args)
261 	{
262 		Node	   *current_result;
263 
264 		current_result = transformExpr(pstate, (Node *) lfirst(lc), pstate->p_expr_kind);
265 
266 		/* acceptability of expressions is checked later */
267 
268 		result_list = lappend(result_list, current_result);
269 	}
270 
271 	result->args = result_list;
272 	result->location = p->location;
273 
274 	check_agglevels_and_constraints(pstate, (Node *) result);
275 
276 	return (Node *) result;
277 }
278 
279 /*
280  * Aggregate functions and grouping operations (which are combined in the spec
281  * as <set function specification>) are very similar with regard to level and
282  * nesting restrictions (though we allow a lot more things than the spec does).
283  * Centralise those restrictions here.
284  */
285 static void
check_agglevels_and_constraints(ParseState * pstate,Node * expr)286 check_agglevels_and_constraints(ParseState *pstate, Node *expr)
287 {
288 	List	   *directargs = NIL;
289 	List	   *args = NIL;
290 	Expr	   *filter = NULL;
291 	int			min_varlevel;
292 	int			location = -1;
293 	Index	   *p_levelsup;
294 	const char *err;
295 	bool		errkind;
296 	bool		isAgg = IsA(expr, Aggref);
297 
298 	if (isAgg)
299 	{
300 		Aggref	   *agg = (Aggref *) expr;
301 
302 		directargs = agg->aggdirectargs;
303 		args = agg->args;
304 		filter = agg->aggfilter;
305 		location = agg->location;
306 		p_levelsup = &agg->agglevelsup;
307 	}
308 	else
309 	{
310 		GroupingFunc *grp = (GroupingFunc *) expr;
311 
312 		args = grp->args;
313 		location = grp->location;
314 		p_levelsup = &grp->agglevelsup;
315 	}
316 
317 	/*
318 	 * Check the arguments to compute the aggregate's level and detect
319 	 * improper nesting.
320 	 */
321 	min_varlevel = check_agg_arguments(pstate,
322 									   directargs,
323 									   args,
324 									   filter);
325 
326 	*p_levelsup = min_varlevel;
327 
328 	/* Mark the correct pstate level as having aggregates */
329 	while (min_varlevel-- > 0)
330 		pstate = pstate->parentParseState;
331 	pstate->p_hasAggs = true;
332 
333 	/*
334 	 * Check to see if the aggregate function is in an invalid place within
335 	 * its aggregation query.
336 	 *
337 	 * For brevity we support two schemes for reporting an error here: set
338 	 * "err" to a custom message, or set "errkind" true if the error context
339 	 * is sufficiently identified by what ParseExprKindName will return, *and*
340 	 * what it will return is just a SQL keyword.  (Otherwise, use a custom
341 	 * message to avoid creating translation problems.)
342 	 */
343 	err = NULL;
344 	errkind = false;
345 	switch (pstate->p_expr_kind)
346 	{
347 		case EXPR_KIND_NONE:
348 			Assert(false);		/* can't happen */
349 			break;
350 		case EXPR_KIND_OTHER:
351 
352 			/*
353 			 * Accept aggregate/grouping here; caller must throw error if
354 			 * wanted
355 			 */
356 			break;
357 		case EXPR_KIND_JOIN_ON:
358 		case EXPR_KIND_JOIN_USING:
359 			if (isAgg)
360 				err = _("aggregate functions are not allowed in JOIN conditions");
361 			else
362 				err = _("grouping operations are not allowed in JOIN conditions");
363 
364 			break;
365 		case EXPR_KIND_FROM_SUBSELECT:
366 			/* Should only be possible in a LATERAL subquery */
367 			Assert(pstate->p_lateral_active);
368 
369 			/*
370 			 * Aggregate/grouping scope rules make it worth being explicit
371 			 * here
372 			 */
373 			if (isAgg)
374 				err = _("aggregate functions are not allowed in FROM clause of their own query level");
375 			else
376 				err = _("grouping operations are not allowed in FROM clause of their own query level");
377 
378 			break;
379 		case EXPR_KIND_FROM_FUNCTION:
380 			if (isAgg)
381 				err = _("aggregate functions are not allowed in functions in FROM");
382 			else
383 				err = _("grouping operations are not allowed in functions in FROM");
384 
385 			break;
386 		case EXPR_KIND_WHERE:
387 			errkind = true;
388 			break;
389 		case EXPR_KIND_POLICY:
390 			if (isAgg)
391 				err = _("aggregate functions are not allowed in policy expressions");
392 			else
393 				err = _("grouping operations are not allowed in policy expressions");
394 
395 			break;
396 		case EXPR_KIND_HAVING:
397 			/* okay */
398 			break;
399 		case EXPR_KIND_FILTER:
400 			errkind = true;
401 			break;
402 		case EXPR_KIND_WINDOW_PARTITION:
403 			/* okay */
404 			break;
405 		case EXPR_KIND_WINDOW_ORDER:
406 			/* okay */
407 			break;
408 		case EXPR_KIND_WINDOW_FRAME_RANGE:
409 			if (isAgg)
410 				err = _("aggregate functions are not allowed in window RANGE");
411 			else
412 				err = _("grouping operations are not allowed in window RANGE");
413 
414 			break;
415 		case EXPR_KIND_WINDOW_FRAME_ROWS:
416 			if (isAgg)
417 				err = _("aggregate functions are not allowed in window ROWS");
418 			else
419 				err = _("grouping operations are not allowed in window ROWS");
420 
421 			break;
422 		case EXPR_KIND_WINDOW_FRAME_GROUPS:
423 			if (isAgg)
424 				err = _("aggregate functions are not allowed in window GROUPS");
425 			else
426 				err = _("grouping operations are not allowed in window GROUPS");
427 
428 			break;
429 		case EXPR_KIND_SELECT_TARGET:
430 			/* okay */
431 			break;
432 		case EXPR_KIND_INSERT_TARGET:
433 		case EXPR_KIND_UPDATE_SOURCE:
434 		case EXPR_KIND_UPDATE_TARGET:
435 			errkind = true;
436 			break;
437 		case EXPR_KIND_GROUP_BY:
438 			errkind = true;
439 			break;
440 		case EXPR_KIND_ORDER_BY:
441 			/* okay */
442 			break;
443 		case EXPR_KIND_DISTINCT_ON:
444 			/* okay */
445 			break;
446 		case EXPR_KIND_LIMIT:
447 		case EXPR_KIND_OFFSET:
448 			errkind = true;
449 			break;
450 		case EXPR_KIND_RETURNING:
451 			errkind = true;
452 			break;
453 		case EXPR_KIND_VALUES:
454 		case EXPR_KIND_VALUES_SINGLE:
455 			errkind = true;
456 			break;
457 		case EXPR_KIND_CHECK_CONSTRAINT:
458 		case EXPR_KIND_DOMAIN_CHECK:
459 			if (isAgg)
460 				err = _("aggregate functions are not allowed in check constraints");
461 			else
462 				err = _("grouping operations are not allowed in check constraints");
463 
464 			break;
465 		case EXPR_KIND_COLUMN_DEFAULT:
466 		case EXPR_KIND_FUNCTION_DEFAULT:
467 
468 			if (isAgg)
469 				err = _("aggregate functions are not allowed in DEFAULT expressions");
470 			else
471 				err = _("grouping operations are not allowed in DEFAULT expressions");
472 
473 			break;
474 		case EXPR_KIND_INDEX_EXPRESSION:
475 			if (isAgg)
476 				err = _("aggregate functions are not allowed in index expressions");
477 			else
478 				err = _("grouping operations are not allowed in index expressions");
479 
480 			break;
481 		case EXPR_KIND_INDEX_PREDICATE:
482 			if (isAgg)
483 				err = _("aggregate functions are not allowed in index predicates");
484 			else
485 				err = _("grouping operations are not allowed in index predicates");
486 
487 			break;
488 		case EXPR_KIND_ALTER_COL_TRANSFORM:
489 			if (isAgg)
490 				err = _("aggregate functions are not allowed in transform expressions");
491 			else
492 				err = _("grouping operations are not allowed in transform expressions");
493 
494 			break;
495 		case EXPR_KIND_EXECUTE_PARAMETER:
496 			if (isAgg)
497 				err = _("aggregate functions are not allowed in EXECUTE parameters");
498 			else
499 				err = _("grouping operations are not allowed in EXECUTE parameters");
500 
501 			break;
502 		case EXPR_KIND_TRIGGER_WHEN:
503 			if (isAgg)
504 				err = _("aggregate functions are not allowed in trigger WHEN conditions");
505 			else
506 				err = _("grouping operations are not allowed in trigger WHEN conditions");
507 
508 			break;
509 		case EXPR_KIND_PARTITION_BOUND:
510 			if (isAgg)
511 				err = _("aggregate functions are not allowed in partition bound");
512 			else
513 				err = _("grouping operations are not allowed in partition bound");
514 
515 			break;
516 		case EXPR_KIND_PARTITION_EXPRESSION:
517 			if (isAgg)
518 				err = _("aggregate functions are not allowed in partition key expressions");
519 			else
520 				err = _("grouping operations are not allowed in partition key expressions");
521 
522 			break;
523 		case EXPR_KIND_GENERATED_COLUMN:
524 
525 			if (isAgg)
526 				err = _("aggregate functions are not allowed in column generation expressions");
527 			else
528 				err = _("grouping operations are not allowed in column generation expressions");
529 
530 			break;
531 
532 		case EXPR_KIND_CALL_ARGUMENT:
533 			if (isAgg)
534 				err = _("aggregate functions are not allowed in CALL arguments");
535 			else
536 				err = _("grouping operations are not allowed in CALL arguments");
537 
538 			break;
539 
540 		case EXPR_KIND_COPY_WHERE:
541 			if (isAgg)
542 				err = _("aggregate functions are not allowed in COPY FROM WHERE conditions");
543 			else
544 				err = _("grouping operations are not allowed in COPY FROM WHERE conditions");
545 
546 			break;
547 
548 			/*
549 			 * There is intentionally no default: case here, so that the
550 			 * compiler will warn if we add a new ParseExprKind without
551 			 * extending this switch.  If we do see an unrecognized value at
552 			 * runtime, the behavior will be the same as for EXPR_KIND_OTHER,
553 			 * which is sane anyway.
554 			 */
555 	}
556 
557 	if (err)
558 		ereport(ERROR,
559 				(errcode(ERRCODE_GROUPING_ERROR),
560 				 errmsg_internal("%s", err),
561 				 parser_errposition(pstate, location)));
562 
563 	if (errkind)
564 	{
565 		if (isAgg)
566 			/* translator: %s is name of a SQL construct, eg GROUP BY */
567 			err = _("aggregate functions are not allowed in %s");
568 		else
569 			/* translator: %s is name of a SQL construct, eg GROUP BY */
570 			err = _("grouping operations are not allowed in %s");
571 
572 		ereport(ERROR,
573 				(errcode(ERRCODE_GROUPING_ERROR),
574 				 errmsg_internal(err,
575 								 ParseExprKindName(pstate->p_expr_kind)),
576 				 parser_errposition(pstate, location)));
577 	}
578 }
579 
580 /*
581  * check_agg_arguments
582  *	  Scan the arguments of an aggregate function to determine the
583  *	  aggregate's semantic level (zero is the current select's level,
584  *	  one is its parent, etc).
585  *
586  * The aggregate's level is the same as the level of the lowest-level variable
587  * or aggregate in its aggregated arguments (including any ORDER BY columns)
588  * or filter expression; or if it contains no variables at all, we presume it
589  * to be local.
590  *
591  * Vars/Aggs in direct arguments are *not* counted towards determining the
592  * agg's level, as those arguments aren't evaluated per-row but only
593  * per-group, and so in some sense aren't really agg arguments.  However,
594  * this can mean that we decide an agg is upper-level even when its direct
595  * args contain lower-level Vars/Aggs, and that case has to be disallowed.
596  * (This is a little strange, but the SQL standard seems pretty definite that
597  * direct args are not to be considered when setting the agg's level.)
598  *
599  * We also take this opportunity to detect any aggregates or window functions
600  * nested within the arguments.  We can throw error immediately if we find
601  * a window function.  Aggregates are a bit trickier because it's only an
602  * error if the inner aggregate is of the same semantic level as the outer,
603  * which we can't know until we finish scanning the arguments.
604  */
605 static int
check_agg_arguments(ParseState * pstate,List * directargs,List * args,Expr * filter)606 check_agg_arguments(ParseState *pstate,
607 					List *directargs,
608 					List *args,
609 					Expr *filter)
610 {
611 	int			agglevel;
612 	check_agg_arguments_context context;
613 
614 	context.pstate = pstate;
615 	context.min_varlevel = -1;	/* signifies nothing found yet */
616 	context.min_agglevel = -1;
617 	context.sublevels_up = 0;
618 
619 	(void) check_agg_arguments_walker((Node *) args, &context);
620 	(void) check_agg_arguments_walker((Node *) filter, &context);
621 
622 	/*
623 	 * If we found no vars nor aggs at all, it's a level-zero aggregate;
624 	 * otherwise, its level is the minimum of vars or aggs.
625 	 */
626 	if (context.min_varlevel < 0)
627 	{
628 		if (context.min_agglevel < 0)
629 			agglevel = 0;
630 		else
631 			agglevel = context.min_agglevel;
632 	}
633 	else if (context.min_agglevel < 0)
634 		agglevel = context.min_varlevel;
635 	else
636 		agglevel = Min(context.min_varlevel, context.min_agglevel);
637 
638 	/*
639 	 * If there's a nested aggregate of the same semantic level, complain.
640 	 */
641 	if (agglevel == context.min_agglevel)
642 	{
643 		int			aggloc;
644 
645 		aggloc = locate_agg_of_level((Node *) args, agglevel);
646 		if (aggloc < 0)
647 			aggloc = locate_agg_of_level((Node *) filter, agglevel);
648 		ereport(ERROR,
649 				(errcode(ERRCODE_GROUPING_ERROR),
650 				 errmsg("aggregate function calls cannot be nested"),
651 				 parser_errposition(pstate, aggloc)));
652 	}
653 
654 	/*
655 	 * Now check for vars/aggs in the direct arguments, and throw error if
656 	 * needed.  Note that we allow a Var of the agg's semantic level, but not
657 	 * an Agg of that level.  In principle such Aggs could probably be
658 	 * supported, but it would create an ordering dependency among the
659 	 * aggregates at execution time.  Since the case appears neither to be
660 	 * required by spec nor particularly useful, we just treat it as a
661 	 * nested-aggregate situation.
662 	 */
663 	if (directargs)
664 	{
665 		context.min_varlevel = -1;
666 		context.min_agglevel = -1;
667 		(void) check_agg_arguments_walker((Node *) directargs, &context);
668 		if (context.min_varlevel >= 0 && context.min_varlevel < agglevel)
669 			ereport(ERROR,
670 					(errcode(ERRCODE_GROUPING_ERROR),
671 					 errmsg("outer-level aggregate cannot contain a lower-level variable in its direct arguments"),
672 					 parser_errposition(pstate,
673 										locate_var_of_level((Node *) directargs,
674 															context.min_varlevel))));
675 		if (context.min_agglevel >= 0 && context.min_agglevel <= agglevel)
676 			ereport(ERROR,
677 					(errcode(ERRCODE_GROUPING_ERROR),
678 					 errmsg("aggregate function calls cannot be nested"),
679 					 parser_errposition(pstate,
680 										locate_agg_of_level((Node *) directargs,
681 															context.min_agglevel))));
682 	}
683 	return agglevel;
684 }
685 
686 static bool
check_agg_arguments_walker(Node * node,check_agg_arguments_context * context)687 check_agg_arguments_walker(Node *node,
688 						   check_agg_arguments_context *context)
689 {
690 	if (node == NULL)
691 		return false;
692 	if (IsA(node, Var))
693 	{
694 		int			varlevelsup = ((Var *) node)->varlevelsup;
695 
696 		/* convert levelsup to frame of reference of original query */
697 		varlevelsup -= context->sublevels_up;
698 		/* ignore local vars of subqueries */
699 		if (varlevelsup >= 0)
700 		{
701 			if (context->min_varlevel < 0 ||
702 				context->min_varlevel > varlevelsup)
703 				context->min_varlevel = varlevelsup;
704 		}
705 		return false;
706 	}
707 	if (IsA(node, Aggref))
708 	{
709 		int			agglevelsup = ((Aggref *) node)->agglevelsup;
710 
711 		/* convert levelsup to frame of reference of original query */
712 		agglevelsup -= context->sublevels_up;
713 		/* ignore local aggs of subqueries */
714 		if (agglevelsup >= 0)
715 		{
716 			if (context->min_agglevel < 0 ||
717 				context->min_agglevel > agglevelsup)
718 				context->min_agglevel = agglevelsup;
719 		}
720 		/* no need to examine args of the inner aggregate */
721 		return false;
722 	}
723 	if (IsA(node, GroupingFunc))
724 	{
725 		int			agglevelsup = ((GroupingFunc *) node)->agglevelsup;
726 
727 		/* convert levelsup to frame of reference of original query */
728 		agglevelsup -= context->sublevels_up;
729 		/* ignore local aggs of subqueries */
730 		if (agglevelsup >= 0)
731 		{
732 			if (context->min_agglevel < 0 ||
733 				context->min_agglevel > agglevelsup)
734 				context->min_agglevel = agglevelsup;
735 		}
736 		/* Continue and descend into subtree */
737 	}
738 
739 	/*
740 	 * SRFs and window functions can be rejected immediately, unless we are
741 	 * within a sub-select within the aggregate's arguments; in that case
742 	 * they're OK.
743 	 */
744 	if (context->sublevels_up == 0)
745 	{
746 		if ((IsA(node, FuncExpr) &&((FuncExpr *) node)->funcretset) ||
747 			(IsA(node, OpExpr) &&((OpExpr *) node)->opretset))
748 			ereport(ERROR,
749 					(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
750 					 errmsg("aggregate function calls cannot contain set-returning function calls"),
751 					 errhint("You might be able to move the set-returning function into a LATERAL FROM item."),
752 					 parser_errposition(context->pstate, exprLocation(node))));
753 		if (IsA(node, WindowFunc))
754 			ereport(ERROR,
755 					(errcode(ERRCODE_GROUPING_ERROR),
756 					 errmsg("aggregate function calls cannot contain window function calls"),
757 					 parser_errposition(context->pstate,
758 										((WindowFunc *) node)->location)));
759 	}
760 	if (IsA(node, Query))
761 	{
762 		/* Recurse into subselects */
763 		bool		result;
764 
765 		context->sublevels_up++;
766 		result = query_tree_walker((Query *) node,
767 								   check_agg_arguments_walker,
768 								   (void *) context,
769 								   0);
770 		context->sublevels_up--;
771 		return result;
772 	}
773 
774 	return expression_tree_walker(node,
775 								  check_agg_arguments_walker,
776 								  (void *) context);
777 }
778 
779 /*
780  * transformWindowFuncCall -
781  *		Finish initial transformation of a window function call
782  *
783  * parse_func.c has recognized the function as a window function, and has set
784  * up all the fields of the WindowFunc except winref.  Here we must (1) add
785  * the WindowDef to the pstate (if not a duplicate of one already present) and
786  * set winref to link to it; and (2) mark p_hasWindowFuncs true in the pstate.
787  * Unlike aggregates, only the most closely nested pstate level need be
788  * considered --- there are no "outer window functions" per SQL spec.
789  */
790 void
transformWindowFuncCall(ParseState * pstate,WindowFunc * wfunc,WindowDef * windef)791 transformWindowFuncCall(ParseState *pstate, WindowFunc *wfunc,
792 						WindowDef *windef)
793 {
794 	const char *err;
795 	bool		errkind;
796 
797 	/*
798 	 * A window function call can't contain another one (but aggs are OK). XXX
799 	 * is this required by spec, or just an unimplemented feature?
800 	 *
801 	 * Note: we don't need to check the filter expression here, because the
802 	 * context checks done below and in transformAggregateCall would have
803 	 * already rejected any window funcs or aggs within the filter.
804 	 */
805 	if (pstate->p_hasWindowFuncs &&
806 		contain_windowfuncs((Node *) wfunc->args))
807 		ereport(ERROR,
808 				(errcode(ERRCODE_WINDOWING_ERROR),
809 				 errmsg("window function calls cannot be nested"),
810 				 parser_errposition(pstate,
811 									locate_windowfunc((Node *) wfunc->args))));
812 
813 	/*
814 	 * Check to see if the window function is in an invalid place within the
815 	 * query.
816 	 *
817 	 * For brevity we support two schemes for reporting an error here: set
818 	 * "err" to a custom message, or set "errkind" true if the error context
819 	 * is sufficiently identified by what ParseExprKindName will return, *and*
820 	 * what it will return is just a SQL keyword.  (Otherwise, use a custom
821 	 * message to avoid creating translation problems.)
822 	 */
823 	err = NULL;
824 	errkind = false;
825 	switch (pstate->p_expr_kind)
826 	{
827 		case EXPR_KIND_NONE:
828 			Assert(false);		/* can't happen */
829 			break;
830 		case EXPR_KIND_OTHER:
831 			/* Accept window func here; caller must throw error if wanted */
832 			break;
833 		case EXPR_KIND_JOIN_ON:
834 		case EXPR_KIND_JOIN_USING:
835 			err = _("window functions are not allowed in JOIN conditions");
836 			break;
837 		case EXPR_KIND_FROM_SUBSELECT:
838 			/* can't get here, but just in case, throw an error */
839 			errkind = true;
840 			break;
841 		case EXPR_KIND_FROM_FUNCTION:
842 			err = _("window functions are not allowed in functions in FROM");
843 			break;
844 		case EXPR_KIND_WHERE:
845 			errkind = true;
846 			break;
847 		case EXPR_KIND_POLICY:
848 			err = _("window functions are not allowed in policy expressions");
849 			break;
850 		case EXPR_KIND_HAVING:
851 			errkind = true;
852 			break;
853 		case EXPR_KIND_FILTER:
854 			errkind = true;
855 			break;
856 		case EXPR_KIND_WINDOW_PARTITION:
857 		case EXPR_KIND_WINDOW_ORDER:
858 		case EXPR_KIND_WINDOW_FRAME_RANGE:
859 		case EXPR_KIND_WINDOW_FRAME_ROWS:
860 		case EXPR_KIND_WINDOW_FRAME_GROUPS:
861 			err = _("window functions are not allowed in window definitions");
862 			break;
863 		case EXPR_KIND_SELECT_TARGET:
864 			/* okay */
865 			break;
866 		case EXPR_KIND_INSERT_TARGET:
867 		case EXPR_KIND_UPDATE_SOURCE:
868 		case EXPR_KIND_UPDATE_TARGET:
869 			errkind = true;
870 			break;
871 		case EXPR_KIND_GROUP_BY:
872 			errkind = true;
873 			break;
874 		case EXPR_KIND_ORDER_BY:
875 			/* okay */
876 			break;
877 		case EXPR_KIND_DISTINCT_ON:
878 			/* okay */
879 			break;
880 		case EXPR_KIND_LIMIT:
881 		case EXPR_KIND_OFFSET:
882 			errkind = true;
883 			break;
884 		case EXPR_KIND_RETURNING:
885 			errkind = true;
886 			break;
887 		case EXPR_KIND_VALUES:
888 		case EXPR_KIND_VALUES_SINGLE:
889 			errkind = true;
890 			break;
891 		case EXPR_KIND_CHECK_CONSTRAINT:
892 		case EXPR_KIND_DOMAIN_CHECK:
893 			err = _("window functions are not allowed in check constraints");
894 			break;
895 		case EXPR_KIND_COLUMN_DEFAULT:
896 		case EXPR_KIND_FUNCTION_DEFAULT:
897 			err = _("window functions are not allowed in DEFAULT expressions");
898 			break;
899 		case EXPR_KIND_INDEX_EXPRESSION:
900 			err = _("window functions are not allowed in index expressions");
901 			break;
902 		case EXPR_KIND_INDEX_PREDICATE:
903 			err = _("window functions are not allowed in index predicates");
904 			break;
905 		case EXPR_KIND_ALTER_COL_TRANSFORM:
906 			err = _("window functions are not allowed in transform expressions");
907 			break;
908 		case EXPR_KIND_EXECUTE_PARAMETER:
909 			err = _("window functions are not allowed in EXECUTE parameters");
910 			break;
911 		case EXPR_KIND_TRIGGER_WHEN:
912 			err = _("window functions are not allowed in trigger WHEN conditions");
913 			break;
914 		case EXPR_KIND_PARTITION_BOUND:
915 			err = _("window functions are not allowed in partition bound");
916 			break;
917 		case EXPR_KIND_PARTITION_EXPRESSION:
918 			err = _("window functions are not allowed in partition key expressions");
919 			break;
920 		case EXPR_KIND_CALL_ARGUMENT:
921 			err = _("window functions are not allowed in CALL arguments");
922 			break;
923 		case EXPR_KIND_COPY_WHERE:
924 			err = _("window functions are not allowed in COPY FROM WHERE conditions");
925 			break;
926 		case EXPR_KIND_GENERATED_COLUMN:
927 			err = _("window functions are not allowed in column generation expressions");
928 			break;
929 
930 			/*
931 			 * There is intentionally no default: case here, so that the
932 			 * compiler will warn if we add a new ParseExprKind without
933 			 * extending this switch.  If we do see an unrecognized value at
934 			 * runtime, the behavior will be the same as for EXPR_KIND_OTHER,
935 			 * which is sane anyway.
936 			 */
937 	}
938 	if (err)
939 		ereport(ERROR,
940 				(errcode(ERRCODE_WINDOWING_ERROR),
941 				 errmsg_internal("%s", err),
942 				 parser_errposition(pstate, wfunc->location)));
943 	if (errkind)
944 		ereport(ERROR,
945 				(errcode(ERRCODE_WINDOWING_ERROR),
946 		/* translator: %s is name of a SQL construct, eg GROUP BY */
947 				 errmsg("window functions are not allowed in %s",
948 						ParseExprKindName(pstate->p_expr_kind)),
949 				 parser_errposition(pstate, wfunc->location)));
950 
951 	/*
952 	 * If the OVER clause just specifies a window name, find that WINDOW
953 	 * clause (which had better be present).  Otherwise, try to match all the
954 	 * properties of the OVER clause, and make a new entry in the p_windowdefs
955 	 * list if no luck.
956 	 */
957 	if (windef->name)
958 	{
959 		Index		winref = 0;
960 		ListCell   *lc;
961 
962 		Assert(windef->refname == NULL &&
963 			   windef->partitionClause == NIL &&
964 			   windef->orderClause == NIL &&
965 			   windef->frameOptions == FRAMEOPTION_DEFAULTS);
966 
967 		foreach(lc, pstate->p_windowdefs)
968 		{
969 			WindowDef  *refwin = (WindowDef *) lfirst(lc);
970 
971 			winref++;
972 			if (refwin->name && strcmp(refwin->name, windef->name) == 0)
973 			{
974 				wfunc->winref = winref;
975 				break;
976 			}
977 		}
978 		if (lc == NULL)			/* didn't find it? */
979 			ereport(ERROR,
980 					(errcode(ERRCODE_UNDEFINED_OBJECT),
981 					 errmsg("window \"%s\" does not exist", windef->name),
982 					 parser_errposition(pstate, windef->location)));
983 	}
984 	else
985 	{
986 		Index		winref = 0;
987 		ListCell   *lc;
988 
989 		foreach(lc, pstate->p_windowdefs)
990 		{
991 			WindowDef  *refwin = (WindowDef *) lfirst(lc);
992 
993 			winref++;
994 			if (refwin->refname && windef->refname &&
995 				strcmp(refwin->refname, windef->refname) == 0)
996 				 /* matched on refname */ ;
997 			else if (!refwin->refname && !windef->refname)
998 				 /* matched, no refname */ ;
999 			else
1000 				continue;
1001 			if (equal(refwin->partitionClause, windef->partitionClause) &&
1002 				equal(refwin->orderClause, windef->orderClause) &&
1003 				refwin->frameOptions == windef->frameOptions &&
1004 				equal(refwin->startOffset, windef->startOffset) &&
1005 				equal(refwin->endOffset, windef->endOffset))
1006 			{
1007 				/* found a duplicate window specification */
1008 				wfunc->winref = winref;
1009 				break;
1010 			}
1011 		}
1012 		if (lc == NULL)			/* didn't find it? */
1013 		{
1014 			pstate->p_windowdefs = lappend(pstate->p_windowdefs, windef);
1015 			wfunc->winref = list_length(pstate->p_windowdefs);
1016 		}
1017 	}
1018 
1019 	pstate->p_hasWindowFuncs = true;
1020 }
1021 
1022 /*
1023  * parseCheckAggregates
1024  *	Check for aggregates where they shouldn't be and improper grouping.
1025  *	This function should be called after the target list and qualifications
1026  *	are finalized.
1027  *
1028  *	Misplaced aggregates are now mostly detected in transformAggregateCall,
1029  *	but it seems more robust to check for aggregates in recursive queries
1030  *	only after everything is finalized.  In any case it's hard to detect
1031  *	improper grouping on-the-fly, so we have to make another pass over the
1032  *	query for that.
1033  */
1034 void
parseCheckAggregates(ParseState * pstate,Query * qry)1035 parseCheckAggregates(ParseState *pstate, Query *qry)
1036 {
1037 	List	   *gset_common = NIL;
1038 	List	   *groupClauses = NIL;
1039 	List	   *groupClauseCommonVars = NIL;
1040 	bool		have_non_var_grouping;
1041 	List	   *func_grouped_rels = NIL;
1042 	ListCell   *l;
1043 	bool		hasJoinRTEs;
1044 	bool		hasSelfRefRTEs;
1045 	Node	   *clause;
1046 
1047 	/* This should only be called if we found aggregates or grouping */
1048 	Assert(pstate->p_hasAggs || qry->groupClause || qry->havingQual || qry->groupingSets);
1049 
1050 	/*
1051 	 * If we have grouping sets, expand them and find the intersection of all
1052 	 * sets.
1053 	 */
1054 	if (qry->groupingSets)
1055 	{
1056 		/*
1057 		 * The limit of 4096 is arbitrary and exists simply to avoid resource
1058 		 * issues from pathological constructs.
1059 		 */
1060 		List	   *gsets = expand_grouping_sets(qry->groupingSets, 4096);
1061 
1062 		if (!gsets)
1063 			ereport(ERROR,
1064 					(errcode(ERRCODE_STATEMENT_TOO_COMPLEX),
1065 					 errmsg("too many grouping sets present (maximum 4096)"),
1066 					 parser_errposition(pstate,
1067 										qry->groupClause
1068 										? exprLocation((Node *) qry->groupClause)
1069 										: exprLocation((Node *) qry->groupingSets))));
1070 
1071 		/*
1072 		 * The intersection will often be empty, so help things along by
1073 		 * seeding the intersect with the smallest set.
1074 		 */
1075 		gset_common = linitial(gsets);
1076 
1077 		if (gset_common)
1078 		{
1079 			for_each_cell(l, lnext(list_head(gsets)))
1080 			{
1081 				gset_common = list_intersection_int(gset_common, lfirst(l));
1082 				if (!gset_common)
1083 					break;
1084 			}
1085 		}
1086 
1087 		/*
1088 		 * If there was only one grouping set in the expansion, AND if the
1089 		 * groupClause is non-empty (meaning that the grouping set is not
1090 		 * empty either), then we can ditch the grouping set and pretend we
1091 		 * just had a normal GROUP BY.
1092 		 */
1093 		if (list_length(gsets) == 1 && qry->groupClause)
1094 			qry->groupingSets = NIL;
1095 	}
1096 
1097 	/*
1098 	 * Scan the range table to see if there are JOIN or self-reference CTE
1099 	 * entries.  We'll need this info below.
1100 	 */
1101 	hasJoinRTEs = hasSelfRefRTEs = false;
1102 	foreach(l, pstate->p_rtable)
1103 	{
1104 		RangeTblEntry *rte = (RangeTblEntry *) lfirst(l);
1105 
1106 		if (rte->rtekind == RTE_JOIN)
1107 			hasJoinRTEs = true;
1108 		else if (rte->rtekind == RTE_CTE && rte->self_reference)
1109 			hasSelfRefRTEs = true;
1110 	}
1111 
1112 	/*
1113 	 * Build a list of the acceptable GROUP BY expressions for use by
1114 	 * check_ungrouped_columns().
1115 	 *
1116 	 * We get the TLE, not just the expr, because GROUPING wants to know the
1117 	 * sortgroupref.
1118 	 */
1119 	foreach(l, qry->groupClause)
1120 	{
1121 		SortGroupClause *grpcl = (SortGroupClause *) lfirst(l);
1122 		TargetEntry *expr;
1123 
1124 		expr = get_sortgroupclause_tle(grpcl, qry->targetList);
1125 		if (expr == NULL)
1126 			continue;			/* probably cannot happen */
1127 
1128 		groupClauses = lcons(expr, groupClauses);
1129 	}
1130 
1131 	/*
1132 	 * If there are join alias vars involved, we have to flatten them to the
1133 	 * underlying vars, so that aliased and unaliased vars will be correctly
1134 	 * taken as equal.  We can skip the expense of doing this if no rangetable
1135 	 * entries are RTE_JOIN kind.
1136 	 */
1137 	if (hasJoinRTEs)
1138 		groupClauses = (List *) flatten_join_alias_vars(qry,
1139 														(Node *) groupClauses);
1140 
1141 	/*
1142 	 * Detect whether any of the grouping expressions aren't simple Vars; if
1143 	 * they're all Vars then we don't have to work so hard in the recursive
1144 	 * scans.  (Note we have to flatten aliases before this.)
1145 	 *
1146 	 * Track Vars that are included in all grouping sets separately in
1147 	 * groupClauseCommonVars, since these are the only ones we can use to
1148 	 * check for functional dependencies.
1149 	 */
1150 	have_non_var_grouping = false;
1151 	foreach(l, groupClauses)
1152 	{
1153 		TargetEntry *tle = lfirst(l);
1154 
1155 		if (!IsA(tle->expr, Var))
1156 		{
1157 			have_non_var_grouping = true;
1158 		}
1159 		else if (!qry->groupingSets ||
1160 				 list_member_int(gset_common, tle->ressortgroupref))
1161 		{
1162 			groupClauseCommonVars = lappend(groupClauseCommonVars, tle->expr);
1163 		}
1164 	}
1165 
1166 	/*
1167 	 * Check the targetlist and HAVING clause for ungrouped variables.
1168 	 *
1169 	 * Note: because we check resjunk tlist elements as well as regular ones,
1170 	 * this will also find ungrouped variables that came from ORDER BY and
1171 	 * WINDOW clauses.  For that matter, it's also going to examine the
1172 	 * grouping expressions themselves --- but they'll all pass the test ...
1173 	 *
1174 	 * We also finalize GROUPING expressions, but for that we need to traverse
1175 	 * the original (unflattened) clause in order to modify nodes.
1176 	 */
1177 	clause = (Node *) qry->targetList;
1178 	finalize_grouping_exprs(clause, pstate, qry,
1179 							groupClauses, hasJoinRTEs,
1180 							have_non_var_grouping);
1181 	if (hasJoinRTEs)
1182 		clause = flatten_join_alias_vars(qry, clause);
1183 	check_ungrouped_columns(clause, pstate, qry,
1184 							groupClauses, groupClauseCommonVars,
1185 							have_non_var_grouping,
1186 							&func_grouped_rels);
1187 
1188 	clause = (Node *) qry->havingQual;
1189 	finalize_grouping_exprs(clause, pstate, qry,
1190 							groupClauses, hasJoinRTEs,
1191 							have_non_var_grouping);
1192 	if (hasJoinRTEs)
1193 		clause = flatten_join_alias_vars(qry, clause);
1194 	check_ungrouped_columns(clause, pstate, qry,
1195 							groupClauses, groupClauseCommonVars,
1196 							have_non_var_grouping,
1197 							&func_grouped_rels);
1198 
1199 	/*
1200 	 * Per spec, aggregates can't appear in a recursive term.
1201 	 */
1202 	if (pstate->p_hasAggs && hasSelfRefRTEs)
1203 		ereport(ERROR,
1204 				(errcode(ERRCODE_INVALID_RECURSION),
1205 				 errmsg("aggregate functions are not allowed in a recursive query's recursive term"),
1206 				 parser_errposition(pstate,
1207 									locate_agg_of_level((Node *) qry, 0))));
1208 }
1209 
1210 /*
1211  * check_ungrouped_columns -
1212  *	  Scan the given expression tree for ungrouped variables (variables
1213  *	  that are not listed in the groupClauses list and are not within
1214  *	  the arguments of aggregate functions).  Emit a suitable error message
1215  *	  if any are found.
1216  *
1217  * NOTE: we assume that the given clause has been transformed suitably for
1218  * parser output.  This means we can use expression_tree_walker.
1219  *
1220  * NOTE: we recognize grouping expressions in the main query, but only
1221  * grouping Vars in subqueries.  For example, this will be rejected,
1222  * although it could be allowed:
1223  *		SELECT
1224  *			(SELECT x FROM bar where y = (foo.a + foo.b))
1225  *		FROM foo
1226  *		GROUP BY a + b;
1227  * The difficulty is the need to account for different sublevels_up.
1228  * This appears to require a whole custom version of equal(), which is
1229  * way more pain than the feature seems worth.
1230  */
1231 static void
check_ungrouped_columns(Node * node,ParseState * pstate,Query * qry,List * groupClauses,List * groupClauseCommonVars,bool have_non_var_grouping,List ** func_grouped_rels)1232 check_ungrouped_columns(Node *node, ParseState *pstate, Query *qry,
1233 						List *groupClauses, List *groupClauseCommonVars,
1234 						bool have_non_var_grouping,
1235 						List **func_grouped_rels)
1236 {
1237 	check_ungrouped_columns_context context;
1238 
1239 	context.pstate = pstate;
1240 	context.qry = qry;
1241 	context.hasJoinRTEs = false;	/* assume caller flattened join Vars */
1242 	context.groupClauses = groupClauses;
1243 	context.groupClauseCommonVars = groupClauseCommonVars;
1244 	context.have_non_var_grouping = have_non_var_grouping;
1245 	context.func_grouped_rels = func_grouped_rels;
1246 	context.sublevels_up = 0;
1247 	context.in_agg_direct_args = false;
1248 	check_ungrouped_columns_walker(node, &context);
1249 }
1250 
/*
 * check_ungrouped_columns_walker -
 *	  Recursive worker for check_ungrouped_columns.
 *
 * Walks the expression tree, descending into sub-Queries via
 * query_tree_walker.  It raises ERRCODE_GROUPING_ERROR for any Var of the
 * original query level that is none of the following:
 *	- a grouping column;
 *	- part of a matched grouping expression;
 *	- inside the aggregated arguments of a same-level aggregate;
 *	- shown to be functionally dependent on the grouping columns.
 *
 * The walker never returns true itself, so the walk runs to completion
 * unless an error is thrown.
 *
 * context->sublevels_up counts how many Query nodes we are currently nested
 * inside.  Vars, Aggrefs and GroupingFuncs are compared against it to decide
 * whether they belong to the query level being checked.
 */
static bool
check_ungrouped_columns_walker(Node *node,
							   check_ungrouped_columns_context *context)
{
	ListCell   *gl;

	if (node == NULL)
		return false;
	if (IsA(node, Const) ||
		IsA(node, Param))
		return false;			/* constants are always acceptable */

	if (IsA(node, Aggref))
	{
		Aggref	   *agg = (Aggref *) node;

		if ((int) agg->agglevelsup == context->sublevels_up)
		{
			/*
			 * If we find an aggregate call of the original level, do not
			 * recurse into its normal arguments, ORDER BY arguments, or
			 * filter; ungrouped vars there are not an error.  But we should
			 * check direct arguments as though they weren't in an aggregate.
			 * We set a special flag in the context to help produce a useful
			 * error message for ungrouped vars in direct arguments.
			 */
			bool		result;

			Assert(!context->in_agg_direct_args);
			context->in_agg_direct_args = true;
			result = check_ungrouped_columns_walker((Node *) agg->aggdirectargs,
													context);
			context->in_agg_direct_args = false;
			return result;
		}

		/*
		 * We can skip recursing into aggregates of higher levels altogether,
		 * since they could not possibly contain Vars of concern to us (see
		 * transformAggregateCall).  We do need to look at aggregates of lower
		 * levels, however.
		 */
		if ((int) agg->agglevelsup > context->sublevels_up)
			return false;
	}

	if (IsA(node, GroupingFunc))
	{
		GroupingFunc *grp = (GroupingFunc *) node;

		/* handled GroupingFunc separately, no need to recheck at this level */

		/*
		 * GROUPING() arguments are validated by finalize_grouping_exprs, so
		 * skip GroupingFuncs of the current or any outer level here.
		 */
		if ((int) grp->agglevelsup >= context->sublevels_up)
			return false;
	}

	/*
	 * If we have any GROUP BY items that are not simple Vars, check to see if
	 * subexpression as a whole matches any GROUP BY item. We need to do this
	 * at every recursion level so that we recognize GROUPed-BY expressions
	 * before reaching variables within them. But this only works at the outer
	 * query level, as noted above.
	 */
	if (context->have_non_var_grouping && context->sublevels_up == 0)
	{
		foreach(gl, context->groupClauses)
		{
			TargetEntry *tle = lfirst(gl);

			if (equal(node, tle->expr))
				return false;	/* acceptable, do not descend more */
		}
	}

	/*
	 * If we have an ungrouped Var of the original query level, we have a
	 * failure.  Vars below the original query level are not a problem, and
	 * neither are Vars from above it.  (If such Vars are ungrouped as far as
	 * their own query level is concerned, that's someone else's problem...)
	 */
	if (IsA(node, Var))
	{
		Var		   *var = (Var *) node;
		RangeTblEntry *rte;
		char	   *attname;

		if (var->varlevelsup != context->sublevels_up)
			return false;		/* it's not local to my query, ignore */

		/*
		 * Check for a match, if we didn't do it above.  Only grouping items
		 * that are plain Vars can match here; they are compared by
		 * varno/varattno because this Var's varlevelsup differs from the
		 * grouping Var's when we are inside a sub-select.
		 */
		if (!context->have_non_var_grouping || context->sublevels_up != 0)
		{
			foreach(gl, context->groupClauses)
			{
				Var		   *gvar = (Var *) ((TargetEntry *) lfirst(gl))->expr;

				if (IsA(gvar, Var) &&
					gvar->varno == var->varno &&
					gvar->varattno == var->varattno &&
					gvar->varlevelsup == 0)
					return false;	/* acceptable, we're okay */
			}
		}

		/*
		 * Check whether the Var is known functionally dependent on the GROUP
		 * BY columns.  If so, we can allow the Var to be used, because the
		 * grouping is really a no-op for this table.  However, this deduction
		 * depends on one or more constraints of the table, so we have to add
		 * those constraints to the query's constraintDeps list, because it's
		 * not semantically valid anymore if the constraint(s) get dropped.
		 * (Therefore, this check must be the last-ditch effort before raising
		 * error: we don't want to add dependencies unnecessarily.)
		 *
		 * Because this is a pretty expensive check, and will have the same
		 * outcome for all columns of a table, we remember which RTEs we've
		 * already proven functional dependency for in the func_grouped_rels
		 * list.  This test also prevents us from adding duplicate entries to
		 * the constraintDeps list.
		 */
		if (list_member_int(*context->func_grouped_rels, var->varno))
			return false;		/* previously proven acceptable */

		/* varno is a 1-based index into the outer query's range table */
		Assert(var->varno > 0 &&
			   (int) var->varno <= list_length(context->pstate->p_rtable));
		rte = rt_fetch(var->varno, context->pstate->p_rtable);
		if (rte->rtekind == RTE_RELATION)
		{
			if (check_functional_grouping(rte->relid,
										  var->varno,
										  0,
										  context->groupClauseCommonVars,
										  &context->qry->constraintDeps))
			{
				*context->func_grouped_rels =
					lappend_int(*context->func_grouped_rels, var->varno);
				return false;	/* acceptable */
			}
		}

		/* Found an ungrouped local variable; generate error message */
		attname = get_rte_attribute_name(rte, var->varattno);
		if (context->sublevels_up == 0)
			ereport(ERROR,
					(errcode(ERRCODE_GROUPING_ERROR),
					 errmsg("column \"%s.%s\" must appear in the GROUP BY clause or be used in an aggregate function",
							rte->eref->aliasname, attname),
					 context->in_agg_direct_args ?
					 errdetail("Direct arguments of an ordered-set aggregate must use only grouped columns.") : 0,
					 parser_errposition(context->pstate, var->location)));
		else
			ereport(ERROR,
					(errcode(ERRCODE_GROUPING_ERROR),
					 errmsg("subquery uses ungrouped column \"%s.%s\" from outer query",
							rte->eref->aliasname, attname),
					 parser_errposition(context->pstate, var->location)));
	}

	if (IsA(node, Query))
	{
		/* Recurse into subselects */
		bool		result;

		/* one level deeper: outer-query Vars now have larger varlevelsup */
		context->sublevels_up++;
		result = query_tree_walker((Query *) node,
								   check_ungrouped_columns_walker,
								   (void *) context,
								   0);
		context->sublevels_up--;
		return result;
	}
	return expression_tree_walker(node, check_ungrouped_columns_walker,
								  (void *) context);
}
1427 
1428 /*
1429  * finalize_grouping_exprs -
1430  *	  Scan the given expression tree for GROUPING() and related calls,
1431  *	  and validate and process their arguments.
1432  *
1433  * This is split out from check_ungrouped_columns above because it needs
1434  * to modify the nodes (which it does in-place, not via a mutator) while
1435  * check_ungrouped_columns may see only a copy of the original thanks to
1436  * flattening of join alias vars. So here, we flatten each individual
1437  * GROUPING argument as we see it before comparing it.
1438  */
1439 static void
finalize_grouping_exprs(Node * node,ParseState * pstate,Query * qry,List * groupClauses,bool hasJoinRTEs,bool have_non_var_grouping)1440 finalize_grouping_exprs(Node *node, ParseState *pstate, Query *qry,
1441 						List *groupClauses, bool hasJoinRTEs,
1442 						bool have_non_var_grouping)
1443 {
1444 	check_ungrouped_columns_context context;
1445 
1446 	context.pstate = pstate;
1447 	context.qry = qry;
1448 	context.hasJoinRTEs = hasJoinRTEs;
1449 	context.groupClauses = groupClauses;
1450 	context.groupClauseCommonVars = NIL;
1451 	context.have_non_var_grouping = have_non_var_grouping;
1452 	context.func_grouped_rels = NULL;
1453 	context.sublevels_up = 0;
1454 	context.in_agg_direct_args = false;
1455 	finalize_grouping_exprs_walker(node, &context);
1456 }
1457 
/*
 * finalize_grouping_exprs_walker -
 *	  Recursive worker for finalize_grouping_exprs.
 *
 * Processes every GroupingFunc whose agglevelsup equals the current nesting
 * depth, i.e. one that belongs to the query being checked:
 *	- each argument (join-alias-flattened first when context->hasJoinRTEs)
 *	  is matched against the grouping TLEs in context->groupClauses;
 *	- the matching ressortgroupref values are stored, in argument order, in
 *	  grp->refs;
 *	- an argument with no match raises ERRCODE_GROUPING_ERROR.
 *
 * The walker modifies GroupingFunc nodes in place and never returns true
 * itself.
 */
static bool
finalize_grouping_exprs_walker(Node *node,
							   check_ungrouped_columns_context *context)
{
	ListCell   *gl;

	if (node == NULL)
		return false;
	if (IsA(node, Const) ||
		IsA(node, Param))
		return false;			/* constants are always acceptable */

	if (IsA(node, Aggref))
	{
		Aggref	   *agg = (Aggref *) node;

		if ((int) agg->agglevelsup == context->sublevels_up)
		{
			/*
			 * If we find an aggregate call of the original level, do not
			 * recurse into its normal arguments, ORDER BY arguments, or
			 * filter; GROUPING exprs of this level are not allowed there. But
			 * check direct arguments as though they weren't in an aggregate.
			 */
			bool		result;

			Assert(!context->in_agg_direct_args);
			context->in_agg_direct_args = true;
			result = finalize_grouping_exprs_walker((Node *) agg->aggdirectargs,
													context);
			context->in_agg_direct_args = false;
			return result;
		}

		/*
		 * We can skip recursing into aggregates of higher levels altogether,
		 * since they could not possibly contain exprs of concern to us (see
		 * transformAggregateCall).  We do need to look at aggregates of lower
		 * levels, however.
		 */
		if ((int) agg->agglevelsup > context->sublevels_up)
			return false;
	}

	if (IsA(node, GroupingFunc))
	{
		GroupingFunc *grp = (GroupingFunc *) node;

		/*
		 * We only need to check GroupingFunc nodes at the exact level to
		 * which they belong, since they cannot mix levels in arguments.
		 */

		if ((int) grp->agglevelsup == context->sublevels_up)
		{
			ListCell   *lc;
			List	   *ref_list = NIL;

			foreach(lc, grp->args)
			{
				Node	   *expr = lfirst(lc);
				Index		ref = 0;	/* 0 means "no grouping item matched" */

				/* compare against groupClauses, which the caller flattened */
				if (context->hasJoinRTEs)
					expr = flatten_join_alias_vars(context->qry, expr);

				/*
				 * Each expression must match a grouping entry at the current
				 * query level. Unlike the general expression case, we don't
				 * allow functional dependencies or outer references.
				 */

				if (IsA(expr, Var))
				{
					Var		   *var = (Var *) expr;

					if (var->varlevelsup == context->sublevels_up)
					{
						foreach(gl, context->groupClauses)
						{
							TargetEntry *tle = lfirst(gl);
							Var		   *gvar = (Var *) tle->expr;

							if (IsA(gvar, Var) &&
								gvar->varno == var->varno &&
								gvar->varattno == var->varattno &&
								gvar->varlevelsup == 0)
							{
								ref = tle->ressortgroupref;
								break;
							}
						}
					}
				}
				else if (context->have_non_var_grouping &&
						 context->sublevels_up == 0)
				{
					/* whole-expression matching works only at the top level */
					foreach(gl, context->groupClauses)
					{
						TargetEntry *tle = lfirst(gl);

						if (equal(expr, tle->expr))
						{
							ref = tle->ressortgroupref;
							break;
						}
					}
				}

				if (ref == 0)
					ereport(ERROR,
							(errcode(ERRCODE_GROUPING_ERROR),
							 errmsg("arguments to GROUPING must be grouping expressions of the associated query level"),
							 parser_errposition(context->pstate,
												exprLocation(expr))));

				ref_list = lappend_int(ref_list, ref);
			}

			/* record the resolved references in the node itself */
			grp->refs = ref_list;
		}

		if ((int) grp->agglevelsup > context->sublevels_up)
			return false;
	}

	if (IsA(node, Query))
	{
		/* Recurse into subselects */
		bool		result;

		context->sublevels_up++;
		result = query_tree_walker((Query *) node,
								   finalize_grouping_exprs_walker,
								   (void *) context,
								   0);
		context->sublevels_up--;
		return result;
	}
	return expression_tree_walker(node, finalize_grouping_exprs_walker,
								  (void *) context);
}
1600 
1601 
1602 /*
1603  * Given a GroupingSet node, expand it and return a list of lists.
1604  *
1605  * For EMPTY nodes, return a list of one empty list.
1606  *
1607  * For SIMPLE nodes, return a list of one list, which is the node content.
1608  *
1609  * For CUBE and ROLLUP nodes, return a list of the expansions.
1610  *
1611  * For SET nodes, recursively expand contained CUBE and ROLLUP.
1612  */
1613 static List *
expand_groupingset_node(GroupingSet * gs)1614 expand_groupingset_node(GroupingSet *gs)
1615 {
1616 	List	   *result = NIL;
1617 
1618 	switch (gs->kind)
1619 	{
1620 		case GROUPING_SET_EMPTY:
1621 			result = list_make1(NIL);
1622 			break;
1623 
1624 		case GROUPING_SET_SIMPLE:
1625 			result = list_make1(gs->content);
1626 			break;
1627 
1628 		case GROUPING_SET_ROLLUP:
1629 			{
1630 				List	   *rollup_val = gs->content;
1631 				ListCell   *lc;
1632 				int			curgroup_size = list_length(gs->content);
1633 
1634 				while (curgroup_size > 0)
1635 				{
1636 					List	   *current_result = NIL;
1637 					int			i = curgroup_size;
1638 
1639 					foreach(lc, rollup_val)
1640 					{
1641 						GroupingSet *gs_current = (GroupingSet *) lfirst(lc);
1642 
1643 						Assert(gs_current->kind == GROUPING_SET_SIMPLE);
1644 
1645 						current_result
1646 							= list_concat(current_result,
1647 										  list_copy(gs_current->content));
1648 
1649 						/* If we are done with making the current group, break */
1650 						if (--i == 0)
1651 							break;
1652 					}
1653 
1654 					result = lappend(result, current_result);
1655 					--curgroup_size;
1656 				}
1657 
1658 				result = lappend(result, NIL);
1659 			}
1660 			break;
1661 
1662 		case GROUPING_SET_CUBE:
1663 			{
1664 				List	   *cube_list = gs->content;
1665 				int			number_bits = list_length(cube_list);
1666 				uint32		num_sets;
1667 				uint32		i;
1668 
1669 				/* parser should cap this much lower */
1670 				Assert(number_bits < 31);
1671 
1672 				num_sets = (1U << number_bits);
1673 
1674 				for (i = 0; i < num_sets; i++)
1675 				{
1676 					List	   *current_result = NIL;
1677 					ListCell   *lc;
1678 					uint32		mask = 1U;
1679 
1680 					foreach(lc, cube_list)
1681 					{
1682 						GroupingSet *gs_current = (GroupingSet *) lfirst(lc);
1683 
1684 						Assert(gs_current->kind == GROUPING_SET_SIMPLE);
1685 
1686 						if (mask & i)
1687 						{
1688 							current_result
1689 								= list_concat(current_result,
1690 											  list_copy(gs_current->content));
1691 						}
1692 
1693 						mask <<= 1;
1694 					}
1695 
1696 					result = lappend(result, current_result);
1697 				}
1698 			}
1699 			break;
1700 
1701 		case GROUPING_SET_SETS:
1702 			{
1703 				ListCell   *lc;
1704 
1705 				foreach(lc, gs->content)
1706 				{
1707 					List	   *current_result = expand_groupingset_node(lfirst(lc));
1708 
1709 					result = list_concat(result, current_result);
1710 				}
1711 			}
1712 			break;
1713 	}
1714 
1715 	return result;
1716 }
1717 
1718 static int
cmp_list_len_asc(const void * a,const void * b)1719 cmp_list_len_asc(const void *a, const void *b)
1720 {
1721 	int			la = list_length(*(List *const *) a);
1722 	int			lb = list_length(*(List *const *) b);
1723 
1724 	return (la > lb) ? 1 : (la == lb) ? 0 : -1;
1725 }
1726 
/*
 * Expand a groupingSets clause to a flat list of grouping sets.
 * The returned list is sorted by length, shortest sets first.
 *
 * This is mainly for the planner, but we use it here too to do
 * some consistency checks.
 *
 * How the expansion works:
 *	- each top-level element of groupingSets is expanded with
 *	  expand_groupingset_node;
 *	- the per-element results are combined as a cartesian product;
 *	- each resulting set is the union of one set chosen from every element,
 *	  with duplicate members removed.
 *
 * limit: when >= 0, the expansion is abandoned as soon as the running
 * product of per-element expansion counts exceeds it.  A negative limit
 * disables the check.
 *
 * Returns NIL if groupingSets is NIL or the limit was exceeded.
 */
List *
expand_grouping_sets(List *groupingSets, int limit)
{
	List	   *expanded_groups = NIL;
	List	   *result = NIL;
	double		numsets = 1;	/* presumably double so the running product
								 * cannot overflow before the limit test */
	ListCell   *lc;

	if (groupingSets == NIL)
		return NIL;

	foreach(lc, groupingSets)
	{
		List	   *current_result = NIL;
		GroupingSet *gs = lfirst(lc);

		current_result = expand_groupingset_node(gs);

		Assert(current_result != NIL);

		/* the final set count is the product of the per-element counts */
		numsets *= list_length(current_result);

		if (limit >= 0 && numsets > limit)
			return NIL;

		expanded_groups = lappend(expanded_groups, current_result);
	}

	/*
	 * Do cartesian product between sublists of expanded_groups. While at it,
	 * remove any duplicate elements from individual grouping sets (we must
	 * NOT change the number of sets though)
	 */

	foreach(lc, (List *) linitial(expanded_groups))
	{
		result = lappend(result, list_union_int(NIL, (List *) lfirst(lc)));
	}

	for_each_cell(lc, lnext(list_head(expanded_groups)))
	{
		List	   *p = lfirst(lc);
		List	   *new_result = NIL;
		ListCell   *lc2;

		foreach(lc2, result)
		{
			List	   *q = lfirst(lc2);
			ListCell   *lc3;

			foreach(lc3, p)
			{
				new_result = lappend(new_result,
									 list_union_int(q, (List *) lfirst(lc3)));
			}
		}
		result = new_result;
	}

	/*
	 * Sort the sets by length: copy the List pointers into a temporary
	 * array, qsort it, and rebuild the list in sorted order.
	 */
	if (list_length(result) > 1)
	{
		int			result_len = list_length(result);
		List	  **buf = palloc(sizeof(List *) * result_len);
		List	  **ptr = buf;

		foreach(lc, result)
		{
			*ptr++ = lfirst(lc);
		}

		qsort(buf, result_len, sizeof(List *), cmp_list_len_asc);

		result = NIL;
		ptr = buf;

		while (result_len-- > 0)
			result = lappend(result, *ptr++);

		pfree(buf);
	}

	return result;
}
1817 
1818 /*
1819  * get_aggregate_argtypes
1820  *	Identify the specific datatypes passed to an aggregate call.
1821  *
1822  * Given an Aggref, extract the actual datatypes of the input arguments.
1823  * The input datatypes are reported in a way that matches up with the
1824  * aggregate's declaration, ie, any ORDER BY columns attached to a plain
1825  * aggregate are ignored, but we report both direct and aggregated args of
1826  * an ordered-set aggregate.
1827  *
1828  * Datatypes are returned into inputTypes[], which must reference an array
1829  * of length FUNC_MAX_ARGS.
1830  *
1831  * The function result is the number of actual arguments.
1832  */
1833 int
get_aggregate_argtypes(Aggref * aggref,Oid * inputTypes)1834 get_aggregate_argtypes(Aggref *aggref, Oid *inputTypes)
1835 {
1836 	int			numArguments = 0;
1837 	ListCell   *lc;
1838 
1839 	Assert(list_length(aggref->aggargtypes) <= FUNC_MAX_ARGS);
1840 
1841 	foreach(lc, aggref->aggargtypes)
1842 	{
1843 		inputTypes[numArguments++] = lfirst_oid(lc);
1844 	}
1845 
1846 	return numArguments;
1847 }
1848 
1849 /*
1850  * resolve_aggregate_transtype
1851  *	Identify the transition state value's datatype for an aggregate call.
1852  *
1853  * This function resolves a polymorphic aggregate's state datatype.
1854  * It must be passed the aggtranstype from the aggregate's catalog entry,
1855  * as well as the actual argument types extracted by get_aggregate_argtypes.
1856  * (We could fetch pg_aggregate.aggtranstype internally, but all existing
1857  * callers already have the value at hand, so we make them pass it.)
1858  */
1859 Oid
resolve_aggregate_transtype(Oid aggfuncid,Oid aggtranstype,Oid * inputTypes,int numArguments)1860 resolve_aggregate_transtype(Oid aggfuncid,
1861 							Oid aggtranstype,
1862 							Oid *inputTypes,
1863 							int numArguments)
1864 {
1865 	/* resolve actual type of transition state, if polymorphic */
1866 	if (IsPolymorphicType(aggtranstype))
1867 	{
1868 		/* have to fetch the agg's declared input types... */
1869 		Oid		   *declaredArgTypes;
1870 		int			agg_nargs;
1871 
1872 		(void) get_func_signature(aggfuncid, &declaredArgTypes, &agg_nargs);
1873 
1874 		/*
1875 		 * VARIADIC ANY aggs could have more actual than declared args, but
1876 		 * such extra args can't affect polymorphic type resolution.
1877 		 */
1878 		Assert(agg_nargs <= numArguments);
1879 
1880 		aggtranstype = enforce_generic_type_consistency(inputTypes,
1881 														declaredArgTypes,
1882 														agg_nargs,
1883 														aggtranstype,
1884 														false);
1885 		pfree(declaredArgTypes);
1886 	}
1887 	return aggtranstype;
1888 }
1889 
1890 /*
1891  * Create an expression tree for the transition function of an aggregate.
1892  * This is needed so that polymorphic functions can be used within an
1893  * aggregate --- without the expression tree, such functions would not know
1894  * the datatypes they are supposed to use.  (The trees will never actually
1895  * be executed, however, so we can skimp a bit on correctness.)
1896  *
1897  * agg_input_types and agg_state_type identifies the input types of the
1898  * aggregate.  These should be resolved to actual types (ie, none should
1899  * ever be ANYELEMENT etc).
1900  * agg_input_collation is the aggregate function's input collation.
1901  *
1902  * For an ordered-set aggregate, remember that agg_input_types describes
1903  * the direct arguments followed by the aggregated arguments.
1904  *
1905  * transfn_oid and invtransfn_oid identify the funcs to be called; the
1906  * latter may be InvalidOid, however if invtransfn_oid is set then
1907  * transfn_oid must also be set.
1908  *
1909  * Pointers to the constructed trees are returned into *transfnexpr,
1910  * *invtransfnexpr. If there is no invtransfn, the respective pointer is set
1911  * to NULL.  Since use of the invtransfn is optional, NULL may be passed for
1912  * invtransfnexpr.
1913  */
1914 void
build_aggregate_transfn_expr(Oid * agg_input_types,int agg_num_inputs,int agg_num_direct_inputs,bool agg_variadic,Oid agg_state_type,Oid agg_input_collation,Oid transfn_oid,Oid invtransfn_oid,Expr ** transfnexpr,Expr ** invtransfnexpr)1915 build_aggregate_transfn_expr(Oid *agg_input_types,
1916 							 int agg_num_inputs,
1917 							 int agg_num_direct_inputs,
1918 							 bool agg_variadic,
1919 							 Oid agg_state_type,
1920 							 Oid agg_input_collation,
1921 							 Oid transfn_oid,
1922 							 Oid invtransfn_oid,
1923 							 Expr **transfnexpr,
1924 							 Expr **invtransfnexpr)
1925 {
1926 	List	   *args;
1927 	FuncExpr   *fexpr;
1928 	int			i;
1929 
1930 	/*
1931 	 * Build arg list to use in the transfn FuncExpr node.
1932 	 */
1933 	args = list_make1(make_agg_arg(agg_state_type, agg_input_collation));
1934 
1935 	for (i = agg_num_direct_inputs; i < agg_num_inputs; i++)
1936 	{
1937 		args = lappend(args,
1938 					   make_agg_arg(agg_input_types[i], agg_input_collation));
1939 	}
1940 
1941 	fexpr = makeFuncExpr(transfn_oid,
1942 						 agg_state_type,
1943 						 args,
1944 						 InvalidOid,
1945 						 agg_input_collation,
1946 						 COERCE_EXPLICIT_CALL);
1947 	fexpr->funcvariadic = agg_variadic;
1948 	*transfnexpr = (Expr *) fexpr;
1949 
1950 	/*
1951 	 * Build invtransfn expression if requested, with same args as transfn
1952 	 */
1953 	if (invtransfnexpr != NULL)
1954 	{
1955 		if (OidIsValid(invtransfn_oid))
1956 		{
1957 			fexpr = makeFuncExpr(invtransfn_oid,
1958 								 agg_state_type,
1959 								 args,
1960 								 InvalidOid,
1961 								 agg_input_collation,
1962 								 COERCE_EXPLICIT_CALL);
1963 			fexpr->funcvariadic = agg_variadic;
1964 			*invtransfnexpr = (Expr *) fexpr;
1965 		}
1966 		else
1967 			*invtransfnexpr = NULL;
1968 	}
1969 }
1970 
1971 /*
1972  * Like build_aggregate_transfn_expr, but creates an expression tree for the
1973  * combine function of an aggregate, rather than the transition function.
1974  */
1975 void
build_aggregate_combinefn_expr(Oid agg_state_type,Oid agg_input_collation,Oid combinefn_oid,Expr ** combinefnexpr)1976 build_aggregate_combinefn_expr(Oid agg_state_type,
1977 							   Oid agg_input_collation,
1978 							   Oid combinefn_oid,
1979 							   Expr **combinefnexpr)
1980 {
1981 	Node	   *argp;
1982 	List	   *args;
1983 	FuncExpr   *fexpr;
1984 
1985 	/* combinefn takes two arguments of the aggregate state type */
1986 	argp = make_agg_arg(agg_state_type, agg_input_collation);
1987 
1988 	args = list_make2(argp, argp);
1989 
1990 	fexpr = makeFuncExpr(combinefn_oid,
1991 						 agg_state_type,
1992 						 args,
1993 						 InvalidOid,
1994 						 agg_input_collation,
1995 						 COERCE_EXPLICIT_CALL);
1996 	/* combinefn is currently never treated as variadic */
1997 	*combinefnexpr = (Expr *) fexpr;
1998 }
1999 
2000 /*
2001  * Like build_aggregate_transfn_expr, but creates an expression tree for the
2002  * serialization function of an aggregate.
2003  */
2004 void
build_aggregate_serialfn_expr(Oid serialfn_oid,Expr ** serialfnexpr)2005 build_aggregate_serialfn_expr(Oid serialfn_oid,
2006 							  Expr **serialfnexpr)
2007 {
2008 	List	   *args;
2009 	FuncExpr   *fexpr;
2010 
2011 	/* serialfn always takes INTERNAL and returns BYTEA */
2012 	args = list_make1(make_agg_arg(INTERNALOID, InvalidOid));
2013 
2014 	fexpr = makeFuncExpr(serialfn_oid,
2015 						 BYTEAOID,
2016 						 args,
2017 						 InvalidOid,
2018 						 InvalidOid,
2019 						 COERCE_EXPLICIT_CALL);
2020 	*serialfnexpr = (Expr *) fexpr;
2021 }
2022 
2023 /*
2024  * Like build_aggregate_transfn_expr, but creates an expression tree for the
2025  * deserialization function of an aggregate.
2026  */
2027 void
build_aggregate_deserialfn_expr(Oid deserialfn_oid,Expr ** deserialfnexpr)2028 build_aggregate_deserialfn_expr(Oid deserialfn_oid,
2029 								Expr **deserialfnexpr)
2030 {
2031 	List	   *args;
2032 	FuncExpr   *fexpr;
2033 
2034 	/* deserialfn always takes BYTEA, INTERNAL and returns INTERNAL */
2035 	args = list_make2(make_agg_arg(BYTEAOID, InvalidOid),
2036 					  make_agg_arg(INTERNALOID, InvalidOid));
2037 
2038 	fexpr = makeFuncExpr(deserialfn_oid,
2039 						 INTERNALOID,
2040 						 args,
2041 						 InvalidOid,
2042 						 InvalidOid,
2043 						 COERCE_EXPLICIT_CALL);
2044 	*deserialfnexpr = (Expr *) fexpr;
2045 }
2046 
2047 /*
2048  * Like build_aggregate_transfn_expr, but creates an expression tree for the
2049  * final function of an aggregate, rather than the transition function.
2050  */
2051 void
build_aggregate_finalfn_expr(Oid * agg_input_types,int num_finalfn_inputs,Oid agg_state_type,Oid agg_result_type,Oid agg_input_collation,Oid finalfn_oid,Expr ** finalfnexpr)2052 build_aggregate_finalfn_expr(Oid *agg_input_types,
2053 							 int num_finalfn_inputs,
2054 							 Oid agg_state_type,
2055 							 Oid agg_result_type,
2056 							 Oid agg_input_collation,
2057 							 Oid finalfn_oid,
2058 							 Expr **finalfnexpr)
2059 {
2060 	List	   *args;
2061 	int			i;
2062 
2063 	/*
2064 	 * Build expr tree for final function
2065 	 */
2066 	args = list_make1(make_agg_arg(agg_state_type, agg_input_collation));
2067 
2068 	/* finalfn may take additional args, which match agg's input types */
2069 	for (i = 0; i < num_finalfn_inputs - 1; i++)
2070 	{
2071 		args = lappend(args,
2072 					   make_agg_arg(agg_input_types[i], agg_input_collation));
2073 	}
2074 
2075 	*finalfnexpr = (Expr *) makeFuncExpr(finalfn_oid,
2076 										 agg_result_type,
2077 										 args,
2078 										 InvalidOid,
2079 										 agg_input_collation,
2080 										 COERCE_EXPLICIT_CALL);
2081 	/* finalfn is currently never treated as variadic */
2082 }
2083 
2084 /*
2085  * Convenience function to build dummy argument expressions for aggregates.
2086  *
2087  * We really only care that an aggregate support function can discover its
2088  * actual argument types at runtime using get_fn_expr_argtype(), so it's okay
2089  * to use Param nodes that don't correspond to any real Param.
2090  */
2091 static Node *
make_agg_arg(Oid argtype,Oid argcollation)2092 make_agg_arg(Oid argtype, Oid argcollation)
2093 {
2094 	Param	   *argp = makeNode(Param);
2095 
2096 	argp->paramkind = PARAM_EXEC;
2097 	argp->paramid = -1;
2098 	argp->paramtype = argtype;
2099 	argp->paramtypmod = -1;
2100 	argp->paramcollid = argcollation;
2101 	argp->location = -1;
2102 	return (Node *) argp;
2103 }
2104