1 /*-------------------------------------------------------------------------
2  *
3  * parse_agg.c
4  *	  handle aggregates and window functions in parser
5  *
6  * Portions Copyright (c) 1996-2018, PostgreSQL Global Development Group
7  * Portions Copyright (c) 1994, Regents of the University of California
8  *
9  *
10  * IDENTIFICATION
11  *	  src/backend/parser/parse_agg.c
12  *
13  *-------------------------------------------------------------------------
14  */
15 #include "postgres.h"
16 
17 #include "catalog/pg_aggregate.h"
18 #include "catalog/pg_constraint.h"
19 #include "catalog/pg_type.h"
20 #include "nodes/makefuncs.h"
21 #include "nodes/nodeFuncs.h"
22 #include "optimizer/tlist.h"
23 #include "optimizer/var.h"
24 #include "parser/parse_agg.h"
25 #include "parser/parse_clause.h"
26 #include "parser/parse_coerce.h"
27 #include "parser/parse_expr.h"
28 #include "parser/parsetree.h"
29 #include "rewrite/rewriteManip.h"
30 #include "utils/builtins.h"
31 #include "utils/lsyscache.h"
32 
33 
34 typedef struct
35 {
36 	ParseState *pstate;
37 	int			min_varlevel;
38 	int			min_agglevel;
39 	int			sublevels_up;
40 } check_agg_arguments_context;
41 
42 typedef struct
43 {
44 	ParseState *pstate;
45 	Query	   *qry;
46 	PlannerInfo *root;
47 	List	   *groupClauses;
48 	List	   *groupClauseCommonVars;
49 	bool		have_non_var_grouping;
50 	List	  **func_grouped_rels;
51 	int			sublevels_up;
52 	bool		in_agg_direct_args;
53 } check_ungrouped_columns_context;
54 
55 static int check_agg_arguments(ParseState *pstate,
56 					List *directargs,
57 					List *args,
58 					Expr *filter);
59 static bool check_agg_arguments_walker(Node *node,
60 						   check_agg_arguments_context *context);
61 static void check_ungrouped_columns(Node *node, ParseState *pstate, Query *qry,
62 						List *groupClauses, List *groupClauseVars,
63 						bool have_non_var_grouping,
64 						List **func_grouped_rels);
65 static bool check_ungrouped_columns_walker(Node *node,
66 							   check_ungrouped_columns_context *context);
67 static void finalize_grouping_exprs(Node *node, ParseState *pstate, Query *qry,
68 						List *groupClauses, PlannerInfo *root,
69 						bool have_non_var_grouping);
70 static bool finalize_grouping_exprs_walker(Node *node,
71 							   check_ungrouped_columns_context *context);
72 static void check_agglevels_and_constraints(ParseState *pstate, Node *expr);
73 static List *expand_groupingset_node(GroupingSet *gs);
74 static Node *make_agg_arg(Oid argtype, Oid argcollation);
75 
76 
77 /*
78  * transformAggregateCall -
79  *		Finish initial transformation of an aggregate call
80  *
81  * parse_func.c has recognized the function as an aggregate, and has set up
82  * all the fields of the Aggref except aggargtypes, aggdirectargs, args,
83  * aggorder, aggdistinct and agglevelsup.  The passed-in args list has been
84  * through standard expression transformation and type coercion to match the
85  * agg's declared arg types, while the passed-in aggorder list hasn't been
86  * transformed at all.
87  *
88  * Here we separate the args list into direct and aggregated args, storing the
89  * former in agg->aggdirectargs and the latter in agg->args.  The regular
90  * args, but not the direct args, are converted into a targetlist by inserting
91  * TargetEntry nodes.  We then transform the aggorder and agg_distinct
92  * specifications to produce lists of SortGroupClause nodes for agg->aggorder
93  * and agg->aggdistinct.  (For a regular aggregate, this might result in
94  * adding resjunk expressions to the targetlist; but for ordered-set
95  * aggregates the aggorder list will always be one-to-one with the aggregated
96  * args.)
97  *
98  * We must also determine which query level the aggregate actually belongs to,
99  * set agglevelsup accordingly, and mark p_hasAggs true in the corresponding
100  * pstate level.
101  */
102 void
transformAggregateCall(ParseState * pstate,Aggref * agg,List * args,List * aggorder,bool agg_distinct)103 transformAggregateCall(ParseState *pstate, Aggref *agg,
104 					   List *args, List *aggorder, bool agg_distinct)
105 {
106 	List	   *argtypes = NIL;
107 	List	   *tlist = NIL;
108 	List	   *torder = NIL;
109 	List	   *tdistinct = NIL;
110 	AttrNumber	attno = 1;
111 	int			save_next_resno;
112 	ListCell   *lc;
113 
114 	/*
115 	 * Before separating the args into direct and aggregated args, make a list
116 	 * of their data type OIDs for use later.
117 	 */
118 	foreach(lc, args)
119 	{
120 		Expr	   *arg = (Expr *) lfirst(lc);
121 
122 		argtypes = lappend_oid(argtypes, exprType((Node *) arg));
123 	}
124 	agg->aggargtypes = argtypes;
125 
126 	if (AGGKIND_IS_ORDERED_SET(agg->aggkind))
127 	{
128 		/*
129 		 * For an ordered-set agg, the args list includes direct args and
130 		 * aggregated args; we must split them apart.
131 		 */
132 		int			numDirectArgs = list_length(args) - list_length(aggorder);
133 		List	   *aargs;
134 		ListCell   *lc2;
135 
136 		Assert(numDirectArgs >= 0);
137 
138 		aargs = list_copy_tail(args, numDirectArgs);
139 		agg->aggdirectargs = list_truncate(args, numDirectArgs);
140 
141 		/*
142 		 * Build a tlist from the aggregated args, and make a sortlist entry
143 		 * for each one.  Note that the expressions in the SortBy nodes are
144 		 * ignored (they are the raw versions of the transformed args); we are
145 		 * just looking at the sort information in the SortBy nodes.
146 		 */
147 		forboth(lc, aargs, lc2, aggorder)
148 		{
149 			Expr	   *arg = (Expr *) lfirst(lc);
150 			SortBy	   *sortby = (SortBy *) lfirst(lc2);
151 			TargetEntry *tle;
152 
153 			/* We don't bother to assign column names to the entries */
154 			tle = makeTargetEntry(arg, attno++, NULL, false);
155 			tlist = lappend(tlist, tle);
156 
157 			torder = addTargetToSortList(pstate, tle,
158 										 torder, tlist, sortby);
159 		}
160 
161 		/* Never any DISTINCT in an ordered-set agg */
162 		Assert(!agg_distinct);
163 	}
164 	else
165 	{
166 		/* Regular aggregate, so it has no direct args */
167 		agg->aggdirectargs = NIL;
168 
169 		/*
170 		 * Transform the plain list of Exprs into a targetlist.
171 		 */
172 		foreach(lc, args)
173 		{
174 			Expr	   *arg = (Expr *) lfirst(lc);
175 			TargetEntry *tle;
176 
177 			/* We don't bother to assign column names to the entries */
178 			tle = makeTargetEntry(arg, attno++, NULL, false);
179 			tlist = lappend(tlist, tle);
180 		}
181 
182 		/*
183 		 * If we have an ORDER BY, transform it.  This will add columns to the
184 		 * tlist if they appear in ORDER BY but weren't already in the arg
185 		 * list.  They will be marked resjunk = true so we can tell them apart
186 		 * from regular aggregate arguments later.
187 		 *
188 		 * We need to mess with p_next_resno since it will be used to number
189 		 * any new targetlist entries.
190 		 */
191 		save_next_resno = pstate->p_next_resno;
192 		pstate->p_next_resno = attno;
193 
194 		torder = transformSortClause(pstate,
195 									 aggorder,
196 									 &tlist,
197 									 EXPR_KIND_ORDER_BY,
198 									 true /* force SQL99 rules */ );
199 
200 		/*
201 		 * If we have DISTINCT, transform that to produce a distinctList.
202 		 */
203 		if (agg_distinct)
204 		{
205 			tdistinct = transformDistinctClause(pstate, &tlist, torder, true);
206 
207 			/*
208 			 * Remove this check if executor support for hashed distinct for
209 			 * aggregates is ever added.
210 			 */
211 			foreach(lc, tdistinct)
212 			{
213 				SortGroupClause *sortcl = (SortGroupClause *) lfirst(lc);
214 
215 				if (!OidIsValid(sortcl->sortop))
216 				{
217 					Node	   *expr = get_sortgroupclause_expr(sortcl, tlist);
218 
219 					ereport(ERROR,
220 							(errcode(ERRCODE_UNDEFINED_FUNCTION),
221 							 errmsg("could not identify an ordering operator for type %s",
222 									format_type_be(exprType(expr))),
223 							 errdetail("Aggregates with DISTINCT must be able to sort their inputs."),
224 							 parser_errposition(pstate, exprLocation(expr))));
225 				}
226 			}
227 		}
228 
229 		pstate->p_next_resno = save_next_resno;
230 	}
231 
232 	/* Update the Aggref with the transformation results */
233 	agg->args = tlist;
234 	agg->aggorder = torder;
235 	agg->aggdistinct = tdistinct;
236 
237 	check_agglevels_and_constraints(pstate, (Node *) agg);
238 }
239 
240 /*
241  * transformGroupingFunc
242  *		Transform a GROUPING expression
243  *
244  * GROUPING() behaves very like an aggregate.  Processing of levels and nesting
245  * is done as for aggregates.  We set p_hasAggs for these expressions too.
246  */
247 Node *
transformGroupingFunc(ParseState * pstate,GroupingFunc * p)248 transformGroupingFunc(ParseState *pstate, GroupingFunc *p)
249 {
250 	ListCell   *lc;
251 	List	   *args = p->args;
252 	List	   *result_list = NIL;
253 	GroupingFunc *result = makeNode(GroupingFunc);
254 
255 	if (list_length(args) > 31)
256 		ereport(ERROR,
257 				(errcode(ERRCODE_TOO_MANY_ARGUMENTS),
258 				 errmsg("GROUPING must have fewer than 32 arguments"),
259 				 parser_errposition(pstate, p->location)));
260 
261 	foreach(lc, args)
262 	{
263 		Node	   *current_result;
264 
265 		current_result = transformExpr(pstate, (Node *) lfirst(lc), pstate->p_expr_kind);
266 
267 		/* acceptability of expressions is checked later */
268 
269 		result_list = lappend(result_list, current_result);
270 	}
271 
272 	result->args = result_list;
273 	result->location = p->location;
274 
275 	check_agglevels_and_constraints(pstate, (Node *) result);
276 
277 	return (Node *) result;
278 }
279 
280 /*
281  * Aggregate functions and grouping operations (which are combined in the spec
282  * as <set function specification>) are very similar with regard to level and
283  * nesting restrictions (though we allow a lot more things than the spec does).
284  * Centralise those restrictions here.
285  */
286 static void
check_agglevels_and_constraints(ParseState * pstate,Node * expr)287 check_agglevels_and_constraints(ParseState *pstate, Node *expr)
288 {
289 	List	   *directargs = NIL;
290 	List	   *args = NIL;
291 	Expr	   *filter = NULL;
292 	int			min_varlevel;
293 	int			location = -1;
294 	Index	   *p_levelsup;
295 	const char *err;
296 	bool		errkind;
297 	bool		isAgg = IsA(expr, Aggref);
298 
299 	if (isAgg)
300 	{
301 		Aggref	   *agg = (Aggref *) expr;
302 
303 		directargs = agg->aggdirectargs;
304 		args = agg->args;
305 		filter = agg->aggfilter;
306 		location = agg->location;
307 		p_levelsup = &agg->agglevelsup;
308 	}
309 	else
310 	{
311 		GroupingFunc *grp = (GroupingFunc *) expr;
312 
313 		args = grp->args;
314 		location = grp->location;
315 		p_levelsup = &grp->agglevelsup;
316 	}
317 
318 	/*
319 	 * Check the arguments to compute the aggregate's level and detect
320 	 * improper nesting.
321 	 */
322 	min_varlevel = check_agg_arguments(pstate,
323 									   directargs,
324 									   args,
325 									   filter);
326 
327 	*p_levelsup = min_varlevel;
328 
329 	/* Mark the correct pstate level as having aggregates */
330 	while (min_varlevel-- > 0)
331 		pstate = pstate->parentParseState;
332 	pstate->p_hasAggs = true;
333 
334 	/*
335 	 * Check to see if the aggregate function is in an invalid place within
336 	 * its aggregation query.
337 	 *
338 	 * For brevity we support two schemes for reporting an error here: set
339 	 * "err" to a custom message, or set "errkind" true if the error context
340 	 * is sufficiently identified by what ParseExprKindName will return, *and*
341 	 * what it will return is just a SQL keyword.  (Otherwise, use a custom
342 	 * message to avoid creating translation problems.)
343 	 */
344 	err = NULL;
345 	errkind = false;
346 	switch (pstate->p_expr_kind)
347 	{
348 		case EXPR_KIND_NONE:
349 			Assert(false);		/* can't happen */
350 			break;
351 		case EXPR_KIND_OTHER:
352 
353 			/*
354 			 * Accept aggregate/grouping here; caller must throw error if
355 			 * wanted
356 			 */
357 			break;
358 		case EXPR_KIND_JOIN_ON:
359 		case EXPR_KIND_JOIN_USING:
360 			if (isAgg)
361 				err = _("aggregate functions are not allowed in JOIN conditions");
362 			else
363 				err = _("grouping operations are not allowed in JOIN conditions");
364 
365 			break;
366 		case EXPR_KIND_FROM_SUBSELECT:
367 			/* Should only be possible in a LATERAL subquery */
368 			Assert(pstate->p_lateral_active);
369 
370 			/*
371 			 * Aggregate/grouping scope rules make it worth being explicit
372 			 * here
373 			 */
374 			if (isAgg)
375 				err = _("aggregate functions are not allowed in FROM clause of their own query level");
376 			else
377 				err = _("grouping operations are not allowed in FROM clause of their own query level");
378 
379 			break;
380 		case EXPR_KIND_FROM_FUNCTION:
381 			if (isAgg)
382 				err = _("aggregate functions are not allowed in functions in FROM");
383 			else
384 				err = _("grouping operations are not allowed in functions in FROM");
385 
386 			break;
387 		case EXPR_KIND_WHERE:
388 			errkind = true;
389 			break;
390 		case EXPR_KIND_POLICY:
391 			if (isAgg)
392 				err = _("aggregate functions are not allowed in policy expressions");
393 			else
394 				err = _("grouping operations are not allowed in policy expressions");
395 
396 			break;
397 		case EXPR_KIND_HAVING:
398 			/* okay */
399 			break;
400 		case EXPR_KIND_FILTER:
401 			errkind = true;
402 			break;
403 		case EXPR_KIND_WINDOW_PARTITION:
404 			/* okay */
405 			break;
406 		case EXPR_KIND_WINDOW_ORDER:
407 			/* okay */
408 			break;
409 		case EXPR_KIND_WINDOW_FRAME_RANGE:
410 			if (isAgg)
411 				err = _("aggregate functions are not allowed in window RANGE");
412 			else
413 				err = _("grouping operations are not allowed in window RANGE");
414 
415 			break;
416 		case EXPR_KIND_WINDOW_FRAME_ROWS:
417 			if (isAgg)
418 				err = _("aggregate functions are not allowed in window ROWS");
419 			else
420 				err = _("grouping operations are not allowed in window ROWS");
421 
422 			break;
423 		case EXPR_KIND_WINDOW_FRAME_GROUPS:
424 			if (isAgg)
425 				err = _("aggregate functions are not allowed in window GROUPS");
426 			else
427 				err = _("grouping operations are not allowed in window GROUPS");
428 
429 			break;
430 		case EXPR_KIND_SELECT_TARGET:
431 			/* okay */
432 			break;
433 		case EXPR_KIND_INSERT_TARGET:
434 		case EXPR_KIND_UPDATE_SOURCE:
435 		case EXPR_KIND_UPDATE_TARGET:
436 			errkind = true;
437 			break;
438 		case EXPR_KIND_GROUP_BY:
439 			errkind = true;
440 			break;
441 		case EXPR_KIND_ORDER_BY:
442 			/* okay */
443 			break;
444 		case EXPR_KIND_DISTINCT_ON:
445 			/* okay */
446 			break;
447 		case EXPR_KIND_LIMIT:
448 		case EXPR_KIND_OFFSET:
449 			errkind = true;
450 			break;
451 		case EXPR_KIND_RETURNING:
452 			errkind = true;
453 			break;
454 		case EXPR_KIND_VALUES:
455 		case EXPR_KIND_VALUES_SINGLE:
456 			errkind = true;
457 			break;
458 		case EXPR_KIND_CHECK_CONSTRAINT:
459 		case EXPR_KIND_DOMAIN_CHECK:
460 			if (isAgg)
461 				err = _("aggregate functions are not allowed in check constraints");
462 			else
463 				err = _("grouping operations are not allowed in check constraints");
464 
465 			break;
466 		case EXPR_KIND_COLUMN_DEFAULT:
467 		case EXPR_KIND_FUNCTION_DEFAULT:
468 
469 			if (isAgg)
470 				err = _("aggregate functions are not allowed in DEFAULT expressions");
471 			else
472 				err = _("grouping operations are not allowed in DEFAULT expressions");
473 
474 			break;
475 		case EXPR_KIND_INDEX_EXPRESSION:
476 			if (isAgg)
477 				err = _("aggregate functions are not allowed in index expressions");
478 			else
479 				err = _("grouping operations are not allowed in index expressions");
480 
481 			break;
482 		case EXPR_KIND_INDEX_PREDICATE:
483 			if (isAgg)
484 				err = _("aggregate functions are not allowed in index predicates");
485 			else
486 				err = _("grouping operations are not allowed in index predicates");
487 
488 			break;
489 		case EXPR_KIND_ALTER_COL_TRANSFORM:
490 			if (isAgg)
491 				err = _("aggregate functions are not allowed in transform expressions");
492 			else
493 				err = _("grouping operations are not allowed in transform expressions");
494 
495 			break;
496 		case EXPR_KIND_EXECUTE_PARAMETER:
497 			if (isAgg)
498 				err = _("aggregate functions are not allowed in EXECUTE parameters");
499 			else
500 				err = _("grouping operations are not allowed in EXECUTE parameters");
501 
502 			break;
503 		case EXPR_KIND_TRIGGER_WHEN:
504 			if (isAgg)
505 				err = _("aggregate functions are not allowed in trigger WHEN conditions");
506 			else
507 				err = _("grouping operations are not allowed in trigger WHEN conditions");
508 
509 			break;
510 		case EXPR_KIND_PARTITION_EXPRESSION:
511 			if (isAgg)
512 				err = _("aggregate functions are not allowed in partition key expressions");
513 			else
514 				err = _("grouping operations are not allowed in partition key expressions");
515 
516 			break;
517 
518 		case EXPR_KIND_CALL_ARGUMENT:
519 			if (isAgg)
520 				err = _("aggregate functions are not allowed in CALL arguments");
521 			else
522 				err = _("grouping operations are not allowed in CALL arguments");
523 
524 			break;
525 
526 			/*
527 			 * There is intentionally no default: case here, so that the
528 			 * compiler will warn if we add a new ParseExprKind without
529 			 * extending this switch.  If we do see an unrecognized value at
530 			 * runtime, the behavior will be the same as for EXPR_KIND_OTHER,
531 			 * which is sane anyway.
532 			 */
533 	}
534 
535 	if (err)
536 		ereport(ERROR,
537 				(errcode(ERRCODE_GROUPING_ERROR),
538 				 errmsg_internal("%s", err),
539 				 parser_errposition(pstate, location)));
540 
541 	if (errkind)
542 	{
543 		if (isAgg)
544 			/* translator: %s is name of a SQL construct, eg GROUP BY */
545 			err = _("aggregate functions are not allowed in %s");
546 		else
547 			/* translator: %s is name of a SQL construct, eg GROUP BY */
548 			err = _("grouping operations are not allowed in %s");
549 
550 		ereport(ERROR,
551 				(errcode(ERRCODE_GROUPING_ERROR),
552 				 errmsg_internal(err,
553 								 ParseExprKindName(pstate->p_expr_kind)),
554 				 parser_errposition(pstate, location)));
555 	}
556 }
557 
558 /*
559  * check_agg_arguments
560  *	  Scan the arguments of an aggregate function to determine the
561  *	  aggregate's semantic level (zero is the current select's level,
562  *	  one is its parent, etc).
563  *
564  * The aggregate's level is the same as the level of the lowest-level variable
565  * or aggregate in its aggregated arguments (including any ORDER BY columns)
566  * or filter expression; or if it contains no variables at all, we presume it
567  * to be local.
568  *
569  * Vars/Aggs in direct arguments are *not* counted towards determining the
570  * agg's level, as those arguments aren't evaluated per-row but only
571  * per-group, and so in some sense aren't really agg arguments.  However,
572  * this can mean that we decide an agg is upper-level even when its direct
573  * args contain lower-level Vars/Aggs, and that case has to be disallowed.
574  * (This is a little strange, but the SQL standard seems pretty definite that
575  * direct args are not to be considered when setting the agg's level.)
576  *
577  * We also take this opportunity to detect any aggregates or window functions
578  * nested within the arguments.  We can throw error immediately if we find
579  * a window function.  Aggregates are a bit trickier because it's only an
580  * error if the inner aggregate is of the same semantic level as the outer,
581  * which we can't know until we finish scanning the arguments.
582  */
583 static int
check_agg_arguments(ParseState * pstate,List * directargs,List * args,Expr * filter)584 check_agg_arguments(ParseState *pstate,
585 					List *directargs,
586 					List *args,
587 					Expr *filter)
588 {
589 	int			agglevel;
590 	check_agg_arguments_context context;
591 
592 	context.pstate = pstate;
593 	context.min_varlevel = -1;	/* signifies nothing found yet */
594 	context.min_agglevel = -1;
595 	context.sublevels_up = 0;
596 
597 	(void) check_agg_arguments_walker((Node *) args, &context);
598 	(void) check_agg_arguments_walker((Node *) filter, &context);
599 
600 	/*
601 	 * If we found no vars nor aggs at all, it's a level-zero aggregate;
602 	 * otherwise, its level is the minimum of vars or aggs.
603 	 */
604 	if (context.min_varlevel < 0)
605 	{
606 		if (context.min_agglevel < 0)
607 			agglevel = 0;
608 		else
609 			agglevel = context.min_agglevel;
610 	}
611 	else if (context.min_agglevel < 0)
612 		agglevel = context.min_varlevel;
613 	else
614 		agglevel = Min(context.min_varlevel, context.min_agglevel);
615 
616 	/*
617 	 * If there's a nested aggregate of the same semantic level, complain.
618 	 */
619 	if (agglevel == context.min_agglevel)
620 	{
621 		int			aggloc;
622 
623 		aggloc = locate_agg_of_level((Node *) args, agglevel);
624 		if (aggloc < 0)
625 			aggloc = locate_agg_of_level((Node *) filter, agglevel);
626 		ereport(ERROR,
627 				(errcode(ERRCODE_GROUPING_ERROR),
628 				 errmsg("aggregate function calls cannot be nested"),
629 				 parser_errposition(pstate, aggloc)));
630 	}
631 
632 	/*
633 	 * Now check for vars/aggs in the direct arguments, and throw error if
634 	 * needed.  Note that we allow a Var of the agg's semantic level, but not
635 	 * an Agg of that level.  In principle such Aggs could probably be
636 	 * supported, but it would create an ordering dependency among the
637 	 * aggregates at execution time.  Since the case appears neither to be
638 	 * required by spec nor particularly useful, we just treat it as a
639 	 * nested-aggregate situation.
640 	 */
641 	if (directargs)
642 	{
643 		context.min_varlevel = -1;
644 		context.min_agglevel = -1;
645 		(void) check_agg_arguments_walker((Node *) directargs, &context);
646 		if (context.min_varlevel >= 0 && context.min_varlevel < agglevel)
647 			ereport(ERROR,
648 					(errcode(ERRCODE_GROUPING_ERROR),
649 					 errmsg("outer-level aggregate cannot contain a lower-level variable in its direct arguments"),
650 					 parser_errposition(pstate,
651 										locate_var_of_level((Node *) directargs,
652 															context.min_varlevel))));
653 		if (context.min_agglevel >= 0 && context.min_agglevel <= agglevel)
654 			ereport(ERROR,
655 					(errcode(ERRCODE_GROUPING_ERROR),
656 					 errmsg("aggregate function calls cannot be nested"),
657 					 parser_errposition(pstate,
658 										locate_agg_of_level((Node *) directargs,
659 															context.min_agglevel))));
660 	}
661 	return agglevel;
662 }
663 
664 static bool
check_agg_arguments_walker(Node * node,check_agg_arguments_context * context)665 check_agg_arguments_walker(Node *node,
666 						   check_agg_arguments_context *context)
667 {
668 	if (node == NULL)
669 		return false;
670 	if (IsA(node, Var))
671 	{
672 		int			varlevelsup = ((Var *) node)->varlevelsup;
673 
674 		/* convert levelsup to frame of reference of original query */
675 		varlevelsup -= context->sublevels_up;
676 		/* ignore local vars of subqueries */
677 		if (varlevelsup >= 0)
678 		{
679 			if (context->min_varlevel < 0 ||
680 				context->min_varlevel > varlevelsup)
681 				context->min_varlevel = varlevelsup;
682 		}
683 		return false;
684 	}
685 	if (IsA(node, Aggref))
686 	{
687 		int			agglevelsup = ((Aggref *) node)->agglevelsup;
688 
689 		/* convert levelsup to frame of reference of original query */
690 		agglevelsup -= context->sublevels_up;
691 		/* ignore local aggs of subqueries */
692 		if (agglevelsup >= 0)
693 		{
694 			if (context->min_agglevel < 0 ||
695 				context->min_agglevel > agglevelsup)
696 				context->min_agglevel = agglevelsup;
697 		}
698 		/* no need to examine args of the inner aggregate */
699 		return false;
700 	}
701 	if (IsA(node, GroupingFunc))
702 	{
703 		int			agglevelsup = ((GroupingFunc *) node)->agglevelsup;
704 
705 		/* convert levelsup to frame of reference of original query */
706 		agglevelsup -= context->sublevels_up;
707 		/* ignore local aggs of subqueries */
708 		if (agglevelsup >= 0)
709 		{
710 			if (context->min_agglevel < 0 ||
711 				context->min_agglevel > agglevelsup)
712 				context->min_agglevel = agglevelsup;
713 		}
714 		/* Continue and descend into subtree */
715 	}
716 
717 	/*
718 	 * SRFs and window functions can be rejected immediately, unless we are
719 	 * within a sub-select within the aggregate's arguments; in that case
720 	 * they're OK.
721 	 */
722 	if (context->sublevels_up == 0)
723 	{
724 		if ((IsA(node, FuncExpr) &&((FuncExpr *) node)->funcretset) ||
725 			(IsA(node, OpExpr) &&((OpExpr *) node)->opretset))
726 			ereport(ERROR,
727 					(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
728 					 errmsg("aggregate function calls cannot contain set-returning function calls"),
729 					 errhint("You might be able to move the set-returning function into a LATERAL FROM item."),
730 					 parser_errposition(context->pstate, exprLocation(node))));
731 		if (IsA(node, WindowFunc))
732 			ereport(ERROR,
733 					(errcode(ERRCODE_GROUPING_ERROR),
734 					 errmsg("aggregate function calls cannot contain window function calls"),
735 					 parser_errposition(context->pstate,
736 										((WindowFunc *) node)->location)));
737 	}
738 	if (IsA(node, Query))
739 	{
740 		/* Recurse into subselects */
741 		bool		result;
742 
743 		context->sublevels_up++;
744 		result = query_tree_walker((Query *) node,
745 								   check_agg_arguments_walker,
746 								   (void *) context,
747 								   0);
748 		context->sublevels_up--;
749 		return result;
750 	}
751 
752 	return expression_tree_walker(node,
753 								  check_agg_arguments_walker,
754 								  (void *) context);
755 }
756 
757 /*
758  * transformWindowFuncCall -
759  *		Finish initial transformation of a window function call
760  *
761  * parse_func.c has recognized the function as a window function, and has set
762  * up all the fields of the WindowFunc except winref.  Here we must (1) add
763  * the WindowDef to the pstate (if not a duplicate of one already present) and
764  * set winref to link to it; and (2) mark p_hasWindowFuncs true in the pstate.
765  * Unlike aggregates, only the most closely nested pstate level need be
766  * considered --- there are no "outer window functions" per SQL spec.
767  */
768 void
transformWindowFuncCall(ParseState * pstate,WindowFunc * wfunc,WindowDef * windef)769 transformWindowFuncCall(ParseState *pstate, WindowFunc *wfunc,
770 						WindowDef *windef)
771 {
772 	const char *err;
773 	bool		errkind;
774 
775 	/*
776 	 * A window function call can't contain another one (but aggs are OK). XXX
777 	 * is this required by spec, or just an unimplemented feature?
778 	 *
779 	 * Note: we don't need to check the filter expression here, because the
780 	 * context checks done below and in transformAggregateCall would have
781 	 * already rejected any window funcs or aggs within the filter.
782 	 */
783 	if (pstate->p_hasWindowFuncs &&
784 		contain_windowfuncs((Node *) wfunc->args))
785 		ereport(ERROR,
786 				(errcode(ERRCODE_WINDOWING_ERROR),
787 				 errmsg("window function calls cannot be nested"),
788 				 parser_errposition(pstate,
789 									locate_windowfunc((Node *) wfunc->args))));
790 
791 	/*
792 	 * Check to see if the window function is in an invalid place within the
793 	 * query.
794 	 *
795 	 * For brevity we support two schemes for reporting an error here: set
796 	 * "err" to a custom message, or set "errkind" true if the error context
797 	 * is sufficiently identified by what ParseExprKindName will return, *and*
798 	 * what it will return is just a SQL keyword.  (Otherwise, use a custom
799 	 * message to avoid creating translation problems.)
800 	 */
801 	err = NULL;
802 	errkind = false;
803 	switch (pstate->p_expr_kind)
804 	{
805 		case EXPR_KIND_NONE:
806 			Assert(false);		/* can't happen */
807 			break;
808 		case EXPR_KIND_OTHER:
809 			/* Accept window func here; caller must throw error if wanted */
810 			break;
811 		case EXPR_KIND_JOIN_ON:
812 		case EXPR_KIND_JOIN_USING:
813 			err = _("window functions are not allowed in JOIN conditions");
814 			break;
815 		case EXPR_KIND_FROM_SUBSELECT:
816 			/* can't get here, but just in case, throw an error */
817 			errkind = true;
818 			break;
819 		case EXPR_KIND_FROM_FUNCTION:
820 			err = _("window functions are not allowed in functions in FROM");
821 			break;
822 		case EXPR_KIND_WHERE:
823 			errkind = true;
824 			break;
825 		case EXPR_KIND_POLICY:
826 			err = _("window functions are not allowed in policy expressions");
827 			break;
828 		case EXPR_KIND_HAVING:
829 			errkind = true;
830 			break;
831 		case EXPR_KIND_FILTER:
832 			errkind = true;
833 			break;
834 		case EXPR_KIND_WINDOW_PARTITION:
835 		case EXPR_KIND_WINDOW_ORDER:
836 		case EXPR_KIND_WINDOW_FRAME_RANGE:
837 		case EXPR_KIND_WINDOW_FRAME_ROWS:
838 		case EXPR_KIND_WINDOW_FRAME_GROUPS:
839 			err = _("window functions are not allowed in window definitions");
840 			break;
841 		case EXPR_KIND_SELECT_TARGET:
842 			/* okay */
843 			break;
844 		case EXPR_KIND_INSERT_TARGET:
845 		case EXPR_KIND_UPDATE_SOURCE:
846 		case EXPR_KIND_UPDATE_TARGET:
847 			errkind = true;
848 			break;
849 		case EXPR_KIND_GROUP_BY:
850 			errkind = true;
851 			break;
852 		case EXPR_KIND_ORDER_BY:
853 			/* okay */
854 			break;
855 		case EXPR_KIND_DISTINCT_ON:
856 			/* okay */
857 			break;
858 		case EXPR_KIND_LIMIT:
859 		case EXPR_KIND_OFFSET:
860 			errkind = true;
861 			break;
862 		case EXPR_KIND_RETURNING:
863 			errkind = true;
864 			break;
865 		case EXPR_KIND_VALUES:
866 		case EXPR_KIND_VALUES_SINGLE:
867 			errkind = true;
868 			break;
869 		case EXPR_KIND_CHECK_CONSTRAINT:
870 		case EXPR_KIND_DOMAIN_CHECK:
871 			err = _("window functions are not allowed in check constraints");
872 			break;
873 		case EXPR_KIND_COLUMN_DEFAULT:
874 		case EXPR_KIND_FUNCTION_DEFAULT:
875 			err = _("window functions are not allowed in DEFAULT expressions");
876 			break;
877 		case EXPR_KIND_INDEX_EXPRESSION:
878 			err = _("window functions are not allowed in index expressions");
879 			break;
880 		case EXPR_KIND_INDEX_PREDICATE:
881 			err = _("window functions are not allowed in index predicates");
882 			break;
883 		case EXPR_KIND_ALTER_COL_TRANSFORM:
884 			err = _("window functions are not allowed in transform expressions");
885 			break;
886 		case EXPR_KIND_EXECUTE_PARAMETER:
887 			err = _("window functions are not allowed in EXECUTE parameters");
888 			break;
889 		case EXPR_KIND_TRIGGER_WHEN:
890 			err = _("window functions are not allowed in trigger WHEN conditions");
891 			break;
892 		case EXPR_KIND_PARTITION_EXPRESSION:
893 			err = _("window functions are not allowed in partition key expressions");
894 			break;
895 		case EXPR_KIND_CALL_ARGUMENT:
896 			err = _("window functions are not allowed in CALL arguments");
897 			break;
898 
899 			/*
900 			 * There is intentionally no default: case here, so that the
901 			 * compiler will warn if we add a new ParseExprKind without
902 			 * extending this switch.  If we do see an unrecognized value at
903 			 * runtime, the behavior will be the same as for EXPR_KIND_OTHER,
904 			 * which is sane anyway.
905 			 */
906 	}
907 	if (err)
908 		ereport(ERROR,
909 				(errcode(ERRCODE_WINDOWING_ERROR),
910 				 errmsg_internal("%s", err),
911 				 parser_errposition(pstate, wfunc->location)));
912 	if (errkind)
913 		ereport(ERROR,
914 				(errcode(ERRCODE_WINDOWING_ERROR),
915 		/* translator: %s is name of a SQL construct, eg GROUP BY */
916 				 errmsg("window functions are not allowed in %s",
917 						ParseExprKindName(pstate->p_expr_kind)),
918 				 parser_errposition(pstate, wfunc->location)));
919 
920 	/*
921 	 * If the OVER clause just specifies a window name, find that WINDOW
922 	 * clause (which had better be present).  Otherwise, try to match all the
923 	 * properties of the OVER clause, and make a new entry in the p_windowdefs
924 	 * list if no luck.
925 	 */
926 	if (windef->name)
927 	{
928 		Index		winref = 0;
929 		ListCell   *lc;
930 
931 		Assert(windef->refname == NULL &&
932 			   windef->partitionClause == NIL &&
933 			   windef->orderClause == NIL &&
934 			   windef->frameOptions == FRAMEOPTION_DEFAULTS);
935 
936 		foreach(lc, pstate->p_windowdefs)
937 		{
938 			WindowDef  *refwin = (WindowDef *) lfirst(lc);
939 
940 			winref++;
941 			if (refwin->name && strcmp(refwin->name, windef->name) == 0)
942 			{
943 				wfunc->winref = winref;
944 				break;
945 			}
946 		}
947 		if (lc == NULL)			/* didn't find it? */
948 			ereport(ERROR,
949 					(errcode(ERRCODE_UNDEFINED_OBJECT),
950 					 errmsg("window \"%s\" does not exist", windef->name),
951 					 parser_errposition(pstate, windef->location)));
952 	}
953 	else
954 	{
955 		Index		winref = 0;
956 		ListCell   *lc;
957 
958 		foreach(lc, pstate->p_windowdefs)
959 		{
960 			WindowDef  *refwin = (WindowDef *) lfirst(lc);
961 
962 			winref++;
963 			if (refwin->refname && windef->refname &&
964 				strcmp(refwin->refname, windef->refname) == 0)
965 				 /* matched on refname */ ;
966 			else if (!refwin->refname && !windef->refname)
967 				 /* matched, no refname */ ;
968 			else
969 				continue;
970 			if (equal(refwin->partitionClause, windef->partitionClause) &&
971 				equal(refwin->orderClause, windef->orderClause) &&
972 				refwin->frameOptions == windef->frameOptions &&
973 				equal(refwin->startOffset, windef->startOffset) &&
974 				equal(refwin->endOffset, windef->endOffset))
975 			{
976 				/* found a duplicate window specification */
977 				wfunc->winref = winref;
978 				break;
979 			}
980 		}
981 		if (lc == NULL)			/* didn't find it? */
982 		{
983 			pstate->p_windowdefs = lappend(pstate->p_windowdefs, windef);
984 			wfunc->winref = list_length(pstate->p_windowdefs);
985 		}
986 	}
987 
988 	pstate->p_hasWindowFuncs = true;
989 }
990 
991 /*
992  * parseCheckAggregates
993  *	Check for aggregates where they shouldn't be and improper grouping.
994  *	This function should be called after the target list and qualifications
995  *	are finalized.
996  *
997  *	Misplaced aggregates are now mostly detected in transformAggregateCall,
998  *	but it seems more robust to check for aggregates in recursive queries
999  *	only after everything is finalized.  In any case it's hard to detect
1000  *	improper grouping on-the-fly, so we have to make another pass over the
1001  *	query for that.
1002  */
1003 void
parseCheckAggregates(ParseState * pstate,Query * qry)1004 parseCheckAggregates(ParseState *pstate, Query *qry)
1005 {
1006 	List	   *gset_common = NIL;
1007 	List	   *groupClauses = NIL;
1008 	List	   *groupClauseCommonVars = NIL;
1009 	bool		have_non_var_grouping;
1010 	List	   *func_grouped_rels = NIL;
1011 	ListCell   *l;
1012 	bool		hasJoinRTEs;
1013 	bool		hasSelfRefRTEs;
1014 	PlannerInfo *root = NULL;
1015 	Node	   *clause;
1016 
1017 	/* This should only be called if we found aggregates or grouping */
1018 	Assert(pstate->p_hasAggs || qry->groupClause || qry->havingQual || qry->groupingSets);
1019 
1020 	/*
1021 	 * If we have grouping sets, expand them and find the intersection of all
1022 	 * sets.
1023 	 */
1024 	if (qry->groupingSets)
1025 	{
1026 		/*
1027 		 * The limit of 4096 is arbitrary and exists simply to avoid resource
1028 		 * issues from pathological constructs.
1029 		 */
1030 		List	   *gsets = expand_grouping_sets(qry->groupingSets, 4096);
1031 
1032 		if (!gsets)
1033 			ereport(ERROR,
1034 					(errcode(ERRCODE_STATEMENT_TOO_COMPLEX),
1035 					 errmsg("too many grouping sets present (maximum 4096)"),
1036 					 parser_errposition(pstate,
1037 										qry->groupClause
1038 										? exprLocation((Node *) qry->groupClause)
1039 										: exprLocation((Node *) qry->groupingSets))));
1040 
1041 		/*
1042 		 * The intersection will often be empty, so help things along by
1043 		 * seeding the intersect with the smallest set.
1044 		 */
1045 		gset_common = linitial(gsets);
1046 
1047 		if (gset_common)
1048 		{
1049 			for_each_cell(l, lnext(list_head(gsets)))
1050 			{
1051 				gset_common = list_intersection_int(gset_common, lfirst(l));
1052 				if (!gset_common)
1053 					break;
1054 			}
1055 		}
1056 
1057 		/*
1058 		 * If there was only one grouping set in the expansion, AND if the
1059 		 * groupClause is non-empty (meaning that the grouping set is not
1060 		 * empty either), then we can ditch the grouping set and pretend we
1061 		 * just had a normal GROUP BY.
1062 		 */
1063 		if (list_length(gsets) == 1 && qry->groupClause)
1064 			qry->groupingSets = NIL;
1065 	}
1066 
1067 	/*
1068 	 * Scan the range table to see if there are JOIN or self-reference CTE
1069 	 * entries.  We'll need this info below.
1070 	 */
1071 	hasJoinRTEs = hasSelfRefRTEs = false;
1072 	foreach(l, pstate->p_rtable)
1073 	{
1074 		RangeTblEntry *rte = (RangeTblEntry *) lfirst(l);
1075 
1076 		if (rte->rtekind == RTE_JOIN)
1077 			hasJoinRTEs = true;
1078 		else if (rte->rtekind == RTE_CTE && rte->self_reference)
1079 			hasSelfRefRTEs = true;
1080 	}
1081 
1082 	/*
1083 	 * Build a list of the acceptable GROUP BY expressions for use by
1084 	 * check_ungrouped_columns().
1085 	 *
1086 	 * We get the TLE, not just the expr, because GROUPING wants to know the
1087 	 * sortgroupref.
1088 	 */
1089 	foreach(l, qry->groupClause)
1090 	{
1091 		SortGroupClause *grpcl = (SortGroupClause *) lfirst(l);
1092 		TargetEntry *expr;
1093 
1094 		expr = get_sortgroupclause_tle(grpcl, qry->targetList);
1095 		if (expr == NULL)
1096 			continue;			/* probably cannot happen */
1097 
1098 		groupClauses = lcons(expr, groupClauses);
1099 	}
1100 
1101 	/*
1102 	 * If there are join alias vars involved, we have to flatten them to the
1103 	 * underlying vars, so that aliased and unaliased vars will be correctly
1104 	 * taken as equal.  We can skip the expense of doing this if no rangetable
1105 	 * entries are RTE_JOIN kind. We use the planner's flatten_join_alias_vars
1106 	 * routine to do the flattening; it wants a PlannerInfo root node, which
1107 	 * fortunately can be mostly dummy.
1108 	 */
1109 	if (hasJoinRTEs)
1110 	{
1111 		root = makeNode(PlannerInfo);
1112 		root->parse = qry;
1113 		root->planner_cxt = CurrentMemoryContext;
1114 		root->hasJoinRTEs = true;
1115 
1116 		groupClauses = (List *) flatten_join_alias_vars(root,
1117 														(Node *) groupClauses);
1118 	}
1119 
1120 	/*
1121 	 * Detect whether any of the grouping expressions aren't simple Vars; if
1122 	 * they're all Vars then we don't have to work so hard in the recursive
1123 	 * scans.  (Note we have to flatten aliases before this.)
1124 	 *
1125 	 * Track Vars that are included in all grouping sets separately in
1126 	 * groupClauseCommonVars, since these are the only ones we can use to
1127 	 * check for functional dependencies.
1128 	 */
1129 	have_non_var_grouping = false;
1130 	foreach(l, groupClauses)
1131 	{
1132 		TargetEntry *tle = lfirst(l);
1133 
1134 		if (!IsA(tle->expr, Var))
1135 		{
1136 			have_non_var_grouping = true;
1137 		}
1138 		else if (!qry->groupingSets ||
1139 				 list_member_int(gset_common, tle->ressortgroupref))
1140 		{
1141 			groupClauseCommonVars = lappend(groupClauseCommonVars, tle->expr);
1142 		}
1143 	}
1144 
1145 	/*
1146 	 * Check the targetlist and HAVING clause for ungrouped variables.
1147 	 *
1148 	 * Note: because we check resjunk tlist elements as well as regular ones,
1149 	 * this will also find ungrouped variables that came from ORDER BY and
1150 	 * WINDOW clauses.  For that matter, it's also going to examine the
1151 	 * grouping expressions themselves --- but they'll all pass the test ...
1152 	 *
1153 	 * We also finalize GROUPING expressions, but for that we need to traverse
1154 	 * the original (unflattened) clause in order to modify nodes.
1155 	 */
1156 	clause = (Node *) qry->targetList;
1157 	finalize_grouping_exprs(clause, pstate, qry,
1158 							groupClauses, root,
1159 							have_non_var_grouping);
1160 	if (hasJoinRTEs)
1161 		clause = flatten_join_alias_vars(root, clause);
1162 	check_ungrouped_columns(clause, pstate, qry,
1163 							groupClauses, groupClauseCommonVars,
1164 							have_non_var_grouping,
1165 							&func_grouped_rels);
1166 
1167 	clause = (Node *) qry->havingQual;
1168 	finalize_grouping_exprs(clause, pstate, qry,
1169 							groupClauses, root,
1170 							have_non_var_grouping);
1171 	if (hasJoinRTEs)
1172 		clause = flatten_join_alias_vars(root, clause);
1173 	check_ungrouped_columns(clause, pstate, qry,
1174 							groupClauses, groupClauseCommonVars,
1175 							have_non_var_grouping,
1176 							&func_grouped_rels);
1177 
1178 	/*
1179 	 * Per spec, aggregates can't appear in a recursive term.
1180 	 */
1181 	if (pstate->p_hasAggs && hasSelfRefRTEs)
1182 		ereport(ERROR,
1183 				(errcode(ERRCODE_INVALID_RECURSION),
1184 				 errmsg("aggregate functions are not allowed in a recursive query's recursive term"),
1185 				 parser_errposition(pstate,
1186 									locate_agg_of_level((Node *) qry, 0))));
1187 }
1188 
1189 /*
1190  * check_ungrouped_columns -
1191  *	  Scan the given expression tree for ungrouped variables (variables
1192  *	  that are not listed in the groupClauses list and are not within
1193  *	  the arguments of aggregate functions).  Emit a suitable error message
1194  *	  if any are found.
1195  *
1196  * NOTE: we assume that the given clause has been transformed suitably for
1197  * parser output.  This means we can use expression_tree_walker.
1198  *
1199  * NOTE: we recognize grouping expressions in the main query, but only
1200  * grouping Vars in subqueries.  For example, this will be rejected,
1201  * although it could be allowed:
1202  *		SELECT
1203  *			(SELECT x FROM bar where y = (foo.a + foo.b))
1204  *		FROM foo
1205  *		GROUP BY a + b;
1206  * The difficulty is the need to account for different sublevels_up.
1207  * This appears to require a whole custom version of equal(), which is
1208  * way more pain than the feature seems worth.
1209  */
1210 static void
check_ungrouped_columns(Node * node,ParseState * pstate,Query * qry,List * groupClauses,List * groupClauseCommonVars,bool have_non_var_grouping,List ** func_grouped_rels)1211 check_ungrouped_columns(Node *node, ParseState *pstate, Query *qry,
1212 						List *groupClauses, List *groupClauseCommonVars,
1213 						bool have_non_var_grouping,
1214 						List **func_grouped_rels)
1215 {
1216 	check_ungrouped_columns_context context;
1217 
1218 	context.pstate = pstate;
1219 	context.qry = qry;
1220 	context.root = NULL;
1221 	context.groupClauses = groupClauses;
1222 	context.groupClauseCommonVars = groupClauseCommonVars;
1223 	context.have_non_var_grouping = have_non_var_grouping;
1224 	context.func_grouped_rels = func_grouped_rels;
1225 	context.sublevels_up = 0;
1226 	context.in_agg_direct_args = false;
1227 	check_ungrouped_columns_walker(node, &context);
1228 }
1229 
/*
 * check_ungrouped_columns_walker -
 *	  Tree walker doing the work of check_ungrouped_columns.
 *
 * Raises ERROR at the first ungrouped Var belonging to the query being
 * checked (a Var whose varlevelsup equals context->sublevels_up).  It never
 * returns true itself, so a normal return means nothing objectionable was
 * found.  context->sublevels_up is incremented while descending into
 * sub-Queries, and context->in_agg_direct_args is set while scanning the
 * direct arguments of an aggregate of the checked level; that flag only
 * affects the error detail.  Relations proven functionally grouped are
 * appended to *context->func_grouped_rels.
 */
static bool
check_ungrouped_columns_walker(Node *node,
							   check_ungrouped_columns_context *context)
{
	ListCell   *gl;

	if (node == NULL)
		return false;
	if (IsA(node, Const) ||
		IsA(node, Param))
		return false;			/* constants are always acceptable */

	if (IsA(node, Aggref))
	{
		Aggref	   *agg = (Aggref *) node;

		if ((int) agg->agglevelsup == context->sublevels_up)
		{
			/*
			 * If we find an aggregate call of the original level, do not
			 * recurse into its normal arguments, ORDER BY arguments, or
			 * filter; ungrouped vars there are not an error.  But we should
			 * check direct arguments as though they weren't in an aggregate.
			 * We set a special flag in the context to help produce a useful
			 * error message for ungrouped vars in direct arguments.
			 */
			bool		result;

			Assert(!context->in_agg_direct_args);
			context->in_agg_direct_args = true;
			result = check_ungrouped_columns_walker((Node *) agg->aggdirectargs,
													context);
			context->in_agg_direct_args = false;
			return result;
		}

		/*
		 * We can skip recursing into aggregates of higher levels altogether,
		 * since they could not possibly contain Vars of concern to us (see
		 * transformAggregateCall).  We do need to look at aggregates of lower
		 * levels, however.
		 */
		if ((int) agg->agglevelsup > context->sublevels_up)
			return false;
	}

	if (IsA(node, GroupingFunc))
	{
		GroupingFunc *grp = (GroupingFunc *) node;

		/*
		 * GROUPING() arguments are validated separately, by
		 * finalize_grouping_exprs, so a GroupingFunc of this level or a
		 * higher one need not be descended into here.  Lower-level ones
		 * (belonging to a subquery) are still walked below.
		 */

		if ((int) grp->agglevelsup >= context->sublevels_up)
			return false;
	}

	/*
	 * If we have any GROUP BY items that are not simple Vars, check whether
	 * the current subexpression as a whole matches any GROUP BY item. We need
	 * to do this at every recursion level so that we recognize GROUPed-BY
	 * expressions before reaching variables within them. But this only works
	 * at the outer query level, as noted above.
	 */
	if (context->have_non_var_grouping && context->sublevels_up == 0)
	{
		foreach(gl, context->groupClauses)
		{
			TargetEntry *tle = lfirst(gl);

			if (equal(node, tle->expr))
				return false;	/* acceptable, do not descend more */
		}
	}

	/*
	 * If we have an ungrouped Var of the original query level, we have a
	 * failure.  Vars below the original query level are not a problem, and
	 * neither are Vars from above it.  (If such Vars are ungrouped as far as
	 * their own query level is concerned, that's someone else's problem...)
	 */
	if (IsA(node, Var))
	{
		Var		   *var = (Var *) node;
		RangeTblEntry *rte;
		char	   *attname;

		if (var->varlevelsup != context->sublevels_up)
			return false;		/* it's not local to my query, ignore */

		/*
		 * Check for a match, if we didn't do it above.  Grouping Vars live in
		 * the top query, so they are compared with varlevelsup == 0.
		 */
		if (!context->have_non_var_grouping || context->sublevels_up != 0)
		{
			foreach(gl, context->groupClauses)
			{
				Var		   *gvar = (Var *) ((TargetEntry *) lfirst(gl))->expr;

				if (IsA(gvar, Var) &&
					gvar->varno == var->varno &&
					gvar->varattno == var->varattno &&
					gvar->varlevelsup == 0)
					return false;	/* acceptable, we're okay */
			}
		}

		/*
		 * Check whether the Var is known functionally dependent on the GROUP
		 * BY columns.  If so, we can allow the Var to be used, because the
		 * grouping is really a no-op for this table.  However, this deduction
		 * depends on one or more constraints of the table, so we have to add
		 * those constraints to the query's constraintDeps list, because it's
		 * not semantically valid anymore if the constraint(s) get dropped.
		 * (Therefore, this check must be the last-ditch effort before raising
		 * error: we don't want to add dependencies unnecessarily.)
		 *
		 * Because this is a pretty expensive check, and will have the same
		 * outcome for all columns of a table, we remember which RTEs we've
		 * already proven functional dependency for in the func_grouped_rels
		 * list.  This test also prevents us from adding duplicate entries to
		 * the constraintDeps list.
		 */
		if (list_member_int(*context->func_grouped_rels, var->varno))
			return false;		/* previously proven acceptable */

		Assert(var->varno > 0 &&
			   (int) var->varno <= list_length(context->pstate->p_rtable));
		rte = rt_fetch(var->varno, context->pstate->p_rtable);
		/* only plain relations are eligible for the functional-dependency test */
		if (rte->rtekind == RTE_RELATION)
		{
			if (check_functional_grouping(rte->relid,
										  var->varno,
										  0,
										  context->groupClauseCommonVars,
										  &context->qry->constraintDeps))
			{
				*context->func_grouped_rels =
					lappend_int(*context->func_grouped_rels, var->varno);
				return false;	/* acceptable */
			}
		}

		/* Found an ungrouped local variable; generate error message */
		attname = get_rte_attribute_name(rte, var->varattno);
		if (context->sublevels_up == 0)
			ereport(ERROR,
					(errcode(ERRCODE_GROUPING_ERROR),
					 errmsg("column \"%s.%s\" must appear in the GROUP BY clause or be used in an aggregate function",
							rte->eref->aliasname, attname),
					 context->in_agg_direct_args ?
					 errdetail("Direct arguments of an ordered-set aggregate must use only grouped columns.") : 0,
					 parser_errposition(context->pstate, var->location)));
		else
			ereport(ERROR,
					(errcode(ERRCODE_GROUPING_ERROR),
					 errmsg("subquery uses ungrouped column \"%s.%s\" from outer query",
							rte->eref->aliasname, attname),
					 parser_errposition(context->pstate, var->location)));
	}

	if (IsA(node, Query))
	{
		/* Recurse into subselects */
		bool		result;

		context->sublevels_up++;
		result = query_tree_walker((Query *) node,
								   check_ungrouped_columns_walker,
								   (void *) context,
								   0);
		context->sublevels_up--;
		return result;
	}
	return expression_tree_walker(node, check_ungrouped_columns_walker,
								  (void *) context);
}
1406 
1407 /*
1408  * finalize_grouping_exprs -
1409  *	  Scan the given expression tree for GROUPING() and related calls,
1410  *	  and validate and process their arguments.
1411  *
1412  * This is split out from check_ungrouped_columns above because it needs
1413  * to modify the nodes (which it does in-place, not via a mutator) while
1414  * check_ungrouped_columns may see only a copy of the original thanks to
1415  * flattening of join alias vars. So here, we flatten each individual
1416  * GROUPING argument as we see it before comparing it.
1417  */
1418 static void
finalize_grouping_exprs(Node * node,ParseState * pstate,Query * qry,List * groupClauses,PlannerInfo * root,bool have_non_var_grouping)1419 finalize_grouping_exprs(Node *node, ParseState *pstate, Query *qry,
1420 						List *groupClauses, PlannerInfo *root,
1421 						bool have_non_var_grouping)
1422 {
1423 	check_ungrouped_columns_context context;
1424 
1425 	context.pstate = pstate;
1426 	context.qry = qry;
1427 	context.root = root;
1428 	context.groupClauses = groupClauses;
1429 	context.groupClauseCommonVars = NIL;
1430 	context.have_non_var_grouping = have_non_var_grouping;
1431 	context.func_grouped_rels = NULL;
1432 	context.sublevels_up = 0;
1433 	context.in_agg_direct_args = false;
1434 	finalize_grouping_exprs_walker(node, &context);
1435 }
1436 
1437 static bool
finalize_grouping_exprs_walker(Node * node,check_ungrouped_columns_context * context)1438 finalize_grouping_exprs_walker(Node *node,
1439 							   check_ungrouped_columns_context *context)
1440 {
1441 	ListCell   *gl;
1442 
1443 	if (node == NULL)
1444 		return false;
1445 	if (IsA(node, Const) ||
1446 		IsA(node, Param))
1447 		return false;			/* constants are always acceptable */
1448 
1449 	if (IsA(node, Aggref))
1450 	{
1451 		Aggref	   *agg = (Aggref *) node;
1452 
1453 		if ((int) agg->agglevelsup == context->sublevels_up)
1454 		{
1455 			/*
1456 			 * If we find an aggregate call of the original level, do not
1457 			 * recurse into its normal arguments, ORDER BY arguments, or
1458 			 * filter; GROUPING exprs of this level are not allowed there. But
1459 			 * check direct arguments as though they weren't in an aggregate.
1460 			 */
1461 			bool		result;
1462 
1463 			Assert(!context->in_agg_direct_args);
1464 			context->in_agg_direct_args = true;
1465 			result = finalize_grouping_exprs_walker((Node *) agg->aggdirectargs,
1466 													context);
1467 			context->in_agg_direct_args = false;
1468 			return result;
1469 		}
1470 
1471 		/*
1472 		 * We can skip recursing into aggregates of higher levels altogether,
1473 		 * since they could not possibly contain exprs of concern to us (see
1474 		 * transformAggregateCall).  We do need to look at aggregates of lower
1475 		 * levels, however.
1476 		 */
1477 		if ((int) agg->agglevelsup > context->sublevels_up)
1478 			return false;
1479 	}
1480 
1481 	if (IsA(node, GroupingFunc))
1482 	{
1483 		GroupingFunc *grp = (GroupingFunc *) node;
1484 
1485 		/*
1486 		 * We only need to check GroupingFunc nodes at the exact level to
1487 		 * which they belong, since they cannot mix levels in arguments.
1488 		 */
1489 
1490 		if ((int) grp->agglevelsup == context->sublevels_up)
1491 		{
1492 			ListCell   *lc;
1493 			List	   *ref_list = NIL;
1494 
1495 			foreach(lc, grp->args)
1496 			{
1497 				Node	   *expr = lfirst(lc);
1498 				Index		ref = 0;
1499 
1500 				if (context->root)
1501 					expr = flatten_join_alias_vars(context->root, expr);
1502 
1503 				/*
1504 				 * Each expression must match a grouping entry at the current
1505 				 * query level. Unlike the general expression case, we don't
1506 				 * allow functional dependencies or outer references.
1507 				 */
1508 
1509 				if (IsA(expr, Var))
1510 				{
1511 					Var		   *var = (Var *) expr;
1512 
1513 					if (var->varlevelsup == context->sublevels_up)
1514 					{
1515 						foreach(gl, context->groupClauses)
1516 						{
1517 							TargetEntry *tle = lfirst(gl);
1518 							Var		   *gvar = (Var *) tle->expr;
1519 
1520 							if (IsA(gvar, Var) &&
1521 								gvar->varno == var->varno &&
1522 								gvar->varattno == var->varattno &&
1523 								gvar->varlevelsup == 0)
1524 							{
1525 								ref = tle->ressortgroupref;
1526 								break;
1527 							}
1528 						}
1529 					}
1530 				}
1531 				else if (context->have_non_var_grouping &&
1532 						 context->sublevels_up == 0)
1533 				{
1534 					foreach(gl, context->groupClauses)
1535 					{
1536 						TargetEntry *tle = lfirst(gl);
1537 
1538 						if (equal(expr, tle->expr))
1539 						{
1540 							ref = tle->ressortgroupref;
1541 							break;
1542 						}
1543 					}
1544 				}
1545 
1546 				if (ref == 0)
1547 					ereport(ERROR,
1548 							(errcode(ERRCODE_GROUPING_ERROR),
1549 							 errmsg("arguments to GROUPING must be grouping expressions of the associated query level"),
1550 							 parser_errposition(context->pstate,
1551 												exprLocation(expr))));
1552 
1553 				ref_list = lappend_int(ref_list, ref);
1554 			}
1555 
1556 			grp->refs = ref_list;
1557 		}
1558 
1559 		if ((int) grp->agglevelsup > context->sublevels_up)
1560 			return false;
1561 	}
1562 
1563 	if (IsA(node, Query))
1564 	{
1565 		/* Recurse into subselects */
1566 		bool		result;
1567 
1568 		context->sublevels_up++;
1569 		result = query_tree_walker((Query *) node,
1570 								   finalize_grouping_exprs_walker,
1571 								   (void *) context,
1572 								   0);
1573 		context->sublevels_up--;
1574 		return result;
1575 	}
1576 	return expression_tree_walker(node, finalize_grouping_exprs_walker,
1577 								  (void *) context);
1578 }
1579 
1580 
1581 /*
1582  * Given a GroupingSet node, expand it and return a list of lists.
1583  *
1584  * For EMPTY nodes, return a list of one empty list.
1585  *
1586  * For SIMPLE nodes, return a list of one list, which is the node content.
1587  *
1588  * For CUBE and ROLLUP nodes, return a list of the expansions.
1589  *
1590  * For SET nodes, recursively expand contained CUBE and ROLLUP.
1591  */
1592 static List *
expand_groupingset_node(GroupingSet * gs)1593 expand_groupingset_node(GroupingSet *gs)
1594 {
1595 	List	   *result = NIL;
1596 
1597 	switch (gs->kind)
1598 	{
1599 		case GROUPING_SET_EMPTY:
1600 			result = list_make1(NIL);
1601 			break;
1602 
1603 		case GROUPING_SET_SIMPLE:
1604 			result = list_make1(gs->content);
1605 			break;
1606 
1607 		case GROUPING_SET_ROLLUP:
1608 			{
1609 				List	   *rollup_val = gs->content;
1610 				ListCell   *lc;
1611 				int			curgroup_size = list_length(gs->content);
1612 
1613 				while (curgroup_size > 0)
1614 				{
1615 					List	   *current_result = NIL;
1616 					int			i = curgroup_size;
1617 
1618 					foreach(lc, rollup_val)
1619 					{
1620 						GroupingSet *gs_current = (GroupingSet *) lfirst(lc);
1621 
1622 						Assert(gs_current->kind == GROUPING_SET_SIMPLE);
1623 
1624 						current_result
1625 							= list_concat(current_result,
1626 										  list_copy(gs_current->content));
1627 
1628 						/* If we are done with making the current group, break */
1629 						if (--i == 0)
1630 							break;
1631 					}
1632 
1633 					result = lappend(result, current_result);
1634 					--curgroup_size;
1635 				}
1636 
1637 				result = lappend(result, NIL);
1638 			}
1639 			break;
1640 
1641 		case GROUPING_SET_CUBE:
1642 			{
1643 				List	   *cube_list = gs->content;
1644 				int			number_bits = list_length(cube_list);
1645 				uint32		num_sets;
1646 				uint32		i;
1647 
1648 				/* parser should cap this much lower */
1649 				Assert(number_bits < 31);
1650 
1651 				num_sets = (1U << number_bits);
1652 
1653 				for (i = 0; i < num_sets; i++)
1654 				{
1655 					List	   *current_result = NIL;
1656 					ListCell   *lc;
1657 					uint32		mask = 1U;
1658 
1659 					foreach(lc, cube_list)
1660 					{
1661 						GroupingSet *gs_current = (GroupingSet *) lfirst(lc);
1662 
1663 						Assert(gs_current->kind == GROUPING_SET_SIMPLE);
1664 
1665 						if (mask & i)
1666 						{
1667 							current_result
1668 								= list_concat(current_result,
1669 											  list_copy(gs_current->content));
1670 						}
1671 
1672 						mask <<= 1;
1673 					}
1674 
1675 					result = lappend(result, current_result);
1676 				}
1677 			}
1678 			break;
1679 
1680 		case GROUPING_SET_SETS:
1681 			{
1682 				ListCell   *lc;
1683 
1684 				foreach(lc, gs->content)
1685 				{
1686 					List	   *current_result = expand_groupingset_node(lfirst(lc));
1687 
1688 					result = list_concat(result, current_result);
1689 				}
1690 			}
1691 			break;
1692 	}
1693 
1694 	return result;
1695 }
1696 
1697 static int
cmp_list_len_asc(const void * a,const void * b)1698 cmp_list_len_asc(const void *a, const void *b)
1699 {
1700 	int			la = list_length(*(List *const *) a);
1701 	int			lb = list_length(*(List *const *) b);
1702 
1703 	return (la > lb) ? 1 : (la == lb) ? 0 : -1;
1704 }
1705 
1706 /*
1707  * Expand a groupingSets clause to a flat list of grouping sets.
1708  * The returned list is sorted by length, shortest sets first.
1709  *
1710  * This is mainly for the planner, but we use it here too to do
1711  * some consistency checks.
1712  */
1713 List *
expand_grouping_sets(List * groupingSets,int limit)1714 expand_grouping_sets(List *groupingSets, int limit)
1715 {
1716 	List	   *expanded_groups = NIL;
1717 	List	   *result = NIL;
1718 	double		numsets = 1;
1719 	ListCell   *lc;
1720 
1721 	if (groupingSets == NIL)
1722 		return NIL;
1723 
1724 	foreach(lc, groupingSets)
1725 	{
1726 		List	   *current_result = NIL;
1727 		GroupingSet *gs = lfirst(lc);
1728 
1729 		current_result = expand_groupingset_node(gs);
1730 
1731 		Assert(current_result != NIL);
1732 
1733 		numsets *= list_length(current_result);
1734 
1735 		if (limit >= 0 && numsets > limit)
1736 			return NIL;
1737 
1738 		expanded_groups = lappend(expanded_groups, current_result);
1739 	}
1740 
1741 	/*
1742 	 * Do cartesian product between sublists of expanded_groups. While at it,
1743 	 * remove any duplicate elements from individual grouping sets (we must
1744 	 * NOT change the number of sets though)
1745 	 */
1746 
1747 	foreach(lc, (List *) linitial(expanded_groups))
1748 	{
1749 		result = lappend(result, list_union_int(NIL, (List *) lfirst(lc)));
1750 	}
1751 
1752 	for_each_cell(lc, lnext(list_head(expanded_groups)))
1753 	{
1754 		List	   *p = lfirst(lc);
1755 		List	   *new_result = NIL;
1756 		ListCell   *lc2;
1757 
1758 		foreach(lc2, result)
1759 		{
1760 			List	   *q = lfirst(lc2);
1761 			ListCell   *lc3;
1762 
1763 			foreach(lc3, p)
1764 			{
1765 				new_result = lappend(new_result,
1766 									 list_union_int(q, (List *) lfirst(lc3)));
1767 			}
1768 		}
1769 		result = new_result;
1770 	}
1771 
1772 	if (list_length(result) > 1)
1773 	{
1774 		int			result_len = list_length(result);
1775 		List	  **buf = palloc(sizeof(List *) * result_len);
1776 		List	  **ptr = buf;
1777 
1778 		foreach(lc, result)
1779 		{
1780 			*ptr++ = lfirst(lc);
1781 		}
1782 
1783 		qsort(buf, result_len, sizeof(List *), cmp_list_len_asc);
1784 
1785 		result = NIL;
1786 		ptr = buf;
1787 
1788 		while (result_len-- > 0)
1789 			result = lappend(result, *ptr++);
1790 
1791 		pfree(buf);
1792 	}
1793 
1794 	return result;
1795 }
1796 
1797 /*
1798  * get_aggregate_argtypes
1799  *	Identify the specific datatypes passed to an aggregate call.
1800  *
1801  * Given an Aggref, extract the actual datatypes of the input arguments.
1802  * The input datatypes are reported in a way that matches up with the
1803  * aggregate's declaration, ie, any ORDER BY columns attached to a plain
1804  * aggregate are ignored, but we report both direct and aggregated args of
1805  * an ordered-set aggregate.
1806  *
1807  * Datatypes are returned into inputTypes[], which must reference an array
1808  * of length FUNC_MAX_ARGS.
1809  *
1810  * The function result is the number of actual arguments.
1811  */
1812 int
get_aggregate_argtypes(Aggref * aggref,Oid * inputTypes)1813 get_aggregate_argtypes(Aggref *aggref, Oid *inputTypes)
1814 {
1815 	int			numArguments = 0;
1816 	ListCell   *lc;
1817 
1818 	Assert(list_length(aggref->aggargtypes) <= FUNC_MAX_ARGS);
1819 
1820 	foreach(lc, aggref->aggargtypes)
1821 	{
1822 		inputTypes[numArguments++] = lfirst_oid(lc);
1823 	}
1824 
1825 	return numArguments;
1826 }
1827 
1828 /*
1829  * resolve_aggregate_transtype
1830  *	Identify the transition state value's datatype for an aggregate call.
1831  *
1832  * This function resolves a polymorphic aggregate's state datatype.
1833  * It must be passed the aggtranstype from the aggregate's catalog entry,
1834  * as well as the actual argument types extracted by get_aggregate_argtypes.
1835  * (We could fetch pg_aggregate.aggtranstype internally, but all existing
1836  * callers already have the value at hand, so we make them pass it.)
1837  */
1838 Oid
resolve_aggregate_transtype(Oid aggfuncid,Oid aggtranstype,Oid * inputTypes,int numArguments)1839 resolve_aggregate_transtype(Oid aggfuncid,
1840 							Oid aggtranstype,
1841 							Oid *inputTypes,
1842 							int numArguments)
1843 {
1844 	/* resolve actual type of transition state, if polymorphic */
1845 	if (IsPolymorphicType(aggtranstype))
1846 	{
1847 		/* have to fetch the agg's declared input types... */
1848 		Oid		   *declaredArgTypes;
1849 		int			agg_nargs;
1850 
1851 		(void) get_func_signature(aggfuncid, &declaredArgTypes, &agg_nargs);
1852 
1853 		/*
1854 		 * VARIADIC ANY aggs could have more actual than declared args, but
1855 		 * such extra args can't affect polymorphic type resolution.
1856 		 */
1857 		Assert(agg_nargs <= numArguments);
1858 
1859 		aggtranstype = enforce_generic_type_consistency(inputTypes,
1860 														declaredArgTypes,
1861 														agg_nargs,
1862 														aggtranstype,
1863 														false);
1864 		pfree(declaredArgTypes);
1865 	}
1866 	return aggtranstype;
1867 }
1868 
1869 /*
1870  * Create an expression tree for the transition function of an aggregate.
1871  * This is needed so that polymorphic functions can be used within an
1872  * aggregate --- without the expression tree, such functions would not know
1873  * the datatypes they are supposed to use.  (The trees will never actually
1874  * be executed, however, so we can skimp a bit on correctness.)
1875  *
1876  * agg_input_types and agg_state_type identifies the input types of the
1877  * aggregate.  These should be resolved to actual types (ie, none should
1878  * ever be ANYELEMENT etc).
1879  * agg_input_collation is the aggregate function's input collation.
1880  *
1881  * For an ordered-set aggregate, remember that agg_input_types describes
1882  * the direct arguments followed by the aggregated arguments.
1883  *
1884  * transfn_oid and invtransfn_oid identify the funcs to be called; the
1885  * latter may be InvalidOid, however if invtransfn_oid is set then
1886  * transfn_oid must also be set.
1887  *
1888  * Pointers to the constructed trees are returned into *transfnexpr,
1889  * *invtransfnexpr. If there is no invtransfn, the respective pointer is set
1890  * to NULL.  Since use of the invtransfn is optional, NULL may be passed for
1891  * invtransfnexpr.
1892  */
1893 void
build_aggregate_transfn_expr(Oid * agg_input_types,int agg_num_inputs,int agg_num_direct_inputs,bool agg_variadic,Oid agg_state_type,Oid agg_input_collation,Oid transfn_oid,Oid invtransfn_oid,Expr ** transfnexpr,Expr ** invtransfnexpr)1894 build_aggregate_transfn_expr(Oid *agg_input_types,
1895 							 int agg_num_inputs,
1896 							 int agg_num_direct_inputs,
1897 							 bool agg_variadic,
1898 							 Oid agg_state_type,
1899 							 Oid agg_input_collation,
1900 							 Oid transfn_oid,
1901 							 Oid invtransfn_oid,
1902 							 Expr **transfnexpr,
1903 							 Expr **invtransfnexpr)
1904 {
1905 	List	   *args;
1906 	FuncExpr   *fexpr;
1907 	int			i;
1908 
1909 	/*
1910 	 * Build arg list to use in the transfn FuncExpr node.
1911 	 */
1912 	args = list_make1(make_agg_arg(agg_state_type, agg_input_collation));
1913 
1914 	for (i = agg_num_direct_inputs; i < agg_num_inputs; i++)
1915 	{
1916 		args = lappend(args,
1917 					   make_agg_arg(agg_input_types[i], agg_input_collation));
1918 	}
1919 
1920 	fexpr = makeFuncExpr(transfn_oid,
1921 						 agg_state_type,
1922 						 args,
1923 						 InvalidOid,
1924 						 agg_input_collation,
1925 						 COERCE_EXPLICIT_CALL);
1926 	fexpr->funcvariadic = agg_variadic;
1927 	*transfnexpr = (Expr *) fexpr;
1928 
1929 	/*
1930 	 * Build invtransfn expression if requested, with same args as transfn
1931 	 */
1932 	if (invtransfnexpr != NULL)
1933 	{
1934 		if (OidIsValid(invtransfn_oid))
1935 		{
1936 			fexpr = makeFuncExpr(invtransfn_oid,
1937 								 agg_state_type,
1938 								 args,
1939 								 InvalidOid,
1940 								 agg_input_collation,
1941 								 COERCE_EXPLICIT_CALL);
1942 			fexpr->funcvariadic = agg_variadic;
1943 			*invtransfnexpr = (Expr *) fexpr;
1944 		}
1945 		else
1946 			*invtransfnexpr = NULL;
1947 	}
1948 }
1949 
1950 /*
1951  * Like build_aggregate_transfn_expr, but creates an expression tree for the
1952  * combine function of an aggregate, rather than the transition function.
1953  */
1954 void
build_aggregate_combinefn_expr(Oid agg_state_type,Oid agg_input_collation,Oid combinefn_oid,Expr ** combinefnexpr)1955 build_aggregate_combinefn_expr(Oid agg_state_type,
1956 							   Oid agg_input_collation,
1957 							   Oid combinefn_oid,
1958 							   Expr **combinefnexpr)
1959 {
1960 	Node	   *argp;
1961 	List	   *args;
1962 	FuncExpr   *fexpr;
1963 
1964 	/* combinefn takes two arguments of the aggregate state type */
1965 	argp = make_agg_arg(agg_state_type, agg_input_collation);
1966 
1967 	args = list_make2(argp, argp);
1968 
1969 	fexpr = makeFuncExpr(combinefn_oid,
1970 						 agg_state_type,
1971 						 args,
1972 						 InvalidOid,
1973 						 agg_input_collation,
1974 						 COERCE_EXPLICIT_CALL);
1975 	/* combinefn is currently never treated as variadic */
1976 	*combinefnexpr = (Expr *) fexpr;
1977 }
1978 
1979 /*
1980  * Like build_aggregate_transfn_expr, but creates an expression tree for the
1981  * serialization function of an aggregate.
1982  */
1983 void
build_aggregate_serialfn_expr(Oid serialfn_oid,Expr ** serialfnexpr)1984 build_aggregate_serialfn_expr(Oid serialfn_oid,
1985 							  Expr **serialfnexpr)
1986 {
1987 	List	   *args;
1988 	FuncExpr   *fexpr;
1989 
1990 	/* serialfn always takes INTERNAL and returns BYTEA */
1991 	args = list_make1(make_agg_arg(INTERNALOID, InvalidOid));
1992 
1993 	fexpr = makeFuncExpr(serialfn_oid,
1994 						 BYTEAOID,
1995 						 args,
1996 						 InvalidOid,
1997 						 InvalidOid,
1998 						 COERCE_EXPLICIT_CALL);
1999 	*serialfnexpr = (Expr *) fexpr;
2000 }
2001 
2002 /*
2003  * Like build_aggregate_transfn_expr, but creates an expression tree for the
2004  * deserialization function of an aggregate.
2005  */
2006 void
build_aggregate_deserialfn_expr(Oid deserialfn_oid,Expr ** deserialfnexpr)2007 build_aggregate_deserialfn_expr(Oid deserialfn_oid,
2008 								Expr **deserialfnexpr)
2009 {
2010 	List	   *args;
2011 	FuncExpr   *fexpr;
2012 
2013 	/* deserialfn always takes BYTEA, INTERNAL and returns INTERNAL */
2014 	args = list_make2(make_agg_arg(BYTEAOID, InvalidOid),
2015 					  make_agg_arg(INTERNALOID, InvalidOid));
2016 
2017 	fexpr = makeFuncExpr(deserialfn_oid,
2018 						 INTERNALOID,
2019 						 args,
2020 						 InvalidOid,
2021 						 InvalidOid,
2022 						 COERCE_EXPLICIT_CALL);
2023 	*deserialfnexpr = (Expr *) fexpr;
2024 }
2025 
2026 /*
2027  * Like build_aggregate_transfn_expr, but creates an expression tree for the
2028  * final function of an aggregate, rather than the transition function.
2029  */
2030 void
build_aggregate_finalfn_expr(Oid * agg_input_types,int num_finalfn_inputs,Oid agg_state_type,Oid agg_result_type,Oid agg_input_collation,Oid finalfn_oid,Expr ** finalfnexpr)2031 build_aggregate_finalfn_expr(Oid *agg_input_types,
2032 							 int num_finalfn_inputs,
2033 							 Oid agg_state_type,
2034 							 Oid agg_result_type,
2035 							 Oid agg_input_collation,
2036 							 Oid finalfn_oid,
2037 							 Expr **finalfnexpr)
2038 {
2039 	List	   *args;
2040 	int			i;
2041 
2042 	/*
2043 	 * Build expr tree for final function
2044 	 */
2045 	args = list_make1(make_agg_arg(agg_state_type, agg_input_collation));
2046 
2047 	/* finalfn may take additional args, which match agg's input types */
2048 	for (i = 0; i < num_finalfn_inputs - 1; i++)
2049 	{
2050 		args = lappend(args,
2051 					   make_agg_arg(agg_input_types[i], agg_input_collation));
2052 	}
2053 
2054 	*finalfnexpr = (Expr *) makeFuncExpr(finalfn_oid,
2055 										 agg_result_type,
2056 										 args,
2057 										 InvalidOid,
2058 										 agg_input_collation,
2059 										 COERCE_EXPLICIT_CALL);
2060 	/* finalfn is currently never treated as variadic */
2061 }
2062 
2063 /*
2064  * Convenience function to build dummy argument expressions for aggregates.
2065  *
2066  * We really only care that an aggregate support function can discover its
2067  * actual argument types at runtime using get_fn_expr_argtype(), so it's okay
2068  * to use Param nodes that don't correspond to any real Param.
2069  */
2070 static Node *
make_agg_arg(Oid argtype,Oid argcollation)2071 make_agg_arg(Oid argtype, Oid argcollation)
2072 {
2073 	Param	   *argp = makeNode(Param);
2074 
2075 	argp->paramkind = PARAM_EXEC;
2076 	argp->paramid = -1;
2077 	argp->paramtype = argtype;
2078 	argp->paramtypmod = -1;
2079 	argp->paramcollid = argcollation;
2080 	argp->location = -1;
2081 	return (Node *) argp;
2082 }
2083