1 /*-------------------------------------------------------------------------
2  *
3  * aggregate_utils.c
4  *
5  * Implementation of UDFs distributing execution of aggregates across workers.
6  *
 * When an aggregate has a combinefunc, we use worker_partial_agg to skip
 * calling finalfunc on workers, instead passing the transition state to the
 * coordinator, where coord_combine_agg merges states with combinefunc and
 * applies finalfunc only at the end.
10  *
11  * Copyright Citus Data, Inc.
12  *
13  *-------------------------------------------------------------------------
14  */
15 
16 
17 #include "postgres.h"
18 
19 #include "access/htup_details.h"
20 #include "catalog/pg_aggregate.h"
21 #include "catalog/pg_proc.h"
22 #include "catalog/pg_type.h"
23 #include "distributed/version_compat.h"
24 #include "nodes/nodeFuncs.h"
25 #include "utils/acl.h"
26 #include "utils/builtins.h"
27 #include "utils/datum.h"
28 #include "utils/lsyscache.h"
29 #include "utils/syscache.h"
30 #include "utils/typcache.h"
31 #include "fmgr.h"
32 #include "miscadmin.h"
33 #include "pg_config_manual.h"
34 
/*
 * Version-1 fmgr declarations for the entry points defined below.
 * The worker-side partial aggregate and the coordinator-side combining
 * aggregate each consist of a state transition function and a final function.
 */

/* worker side: advance the wrapped aggregate and emit its serialized state */
PG_FUNCTION_INFO_V1(worker_partial_agg_sfunc);
PG_FUNCTION_INFO_V1(worker_partial_agg_ffunc);

/* coordinator side: combine worker states and apply the final function */
PG_FUNCTION_INFO_V1(coord_combine_agg_sfunc);
PG_FUNCTION_INFO_V1(coord_combine_agg_ffunc);
39 
40 /*
41  * Holds information describing the structure of aggregation arguments
42  * and helps to efficiently handle both a single argument and multiple
43  * arguments wrapped in a tuple/record. It exploits the fact that
44  * aggregation argument types do not change between subsequent
45  * calls to SFUNC.
46  */
47 typedef struct AggregationArgumentContext
48 {
49 	/* immutable fields */
50 	int argumentCount;
51 	bool isTuple;
52 	TupleDesc tupleDesc;
53 
54 	/* mutable fields */
55 	HeapTuple tuple;
56 	Datum *values;
57 	bool *nulls;
58 } AggregationArgumentContext;
59 
60 /*
61  * internal type for support aggregates to pass transition state alongside
62  * aggregation bookkeeping
63  */
64 typedef struct StypeBox
65 {
66 	Datum value;
67 	Oid agg;
68 	Oid transtype;
69 	int16_t transtypeLen;
70 	bool transtypeByVal;
71 	bool valueNull;
72 	bool valueInit;
73 	AggregationArgumentContext *aggregationArgumentContext;
74 } StypeBox;
75 
76 static HeapTuple GetAggregateForm(Oid oid, Form_pg_aggregate *form);
77 static HeapTuple GetProcForm(Oid oid, Form_pg_proc *form);
78 static HeapTuple GetTypeForm(Oid oid, Form_pg_type *form);
79 static void * pallocInAggContext(FunctionCallInfo fcinfo, size_t size);
80 static void aclcheckAggregate(ObjectType objectType, Oid userOid, Oid funcOid);
81 static Datum GetAggInitVal(Datum textInitVal, Oid transtype);
82 static void InitializeStypeBox(FunctionCallInfo fcinfo, StypeBox *box,
83 							   HeapTuple aggTuple, Oid transtype,
84 							   AggregationArgumentContext *aggregationArgumentContext);
85 static StypeBox * TryCreateStypeBoxFromFcinfoAggref(FunctionCallInfo fcinfo);
86 static AggregationArgumentContext * CreateAggregationArgumentContext(FunctionCallInfo
87 																	 fcinfo,
88 																	 int argumentIndex);
89 static void ExtractAggregationValues(FunctionCallInfo fcinfo, int argumentIndex,
90 									 AggregationArgumentContext
91 									 *aggregationArgumentContext);
92 static void HandleTransition(StypeBox *box, FunctionCallInfo fcinfo,
93 							 FunctionCallInfo innerFcinfo);
94 static void HandleStrictUninit(StypeBox *box, FunctionCallInfo fcinfo, Datum value);
95 static bool TypecheckWorkerPartialAggArgType(FunctionCallInfo fcinfo, StypeBox *box);
96 static bool TypecheckCoordCombineAggReturnType(FunctionCallInfo fcinfo, Oid ffunc,
97 											   StypeBox *box);
98 
99 /*
100  * GetAggregateForm loads corresponding tuple & Form_pg_aggregate for oid
101  */
102 static HeapTuple
GetAggregateForm(Oid oid,Form_pg_aggregate * form)103 GetAggregateForm(Oid oid, Form_pg_aggregate *form)
104 {
105 	HeapTuple tuple = SearchSysCache1(AGGFNOID, ObjectIdGetDatum(oid));
106 	if (!HeapTupleIsValid(tuple))
107 	{
108 		elog(ERROR, "citus cache lookup failed for aggregate %u", oid);
109 	}
110 	*form = (Form_pg_aggregate) GETSTRUCT(tuple);
111 	return tuple;
112 }
113 
114 
115 /*
116  * GetProcForm loads corresponding tuple & Form_pg_proc for oid
117  */
118 static HeapTuple
GetProcForm(Oid oid,Form_pg_proc * form)119 GetProcForm(Oid oid, Form_pg_proc *form)
120 {
121 	HeapTuple tuple = SearchSysCache1(PROCOID, ObjectIdGetDatum(oid));
122 	if (!HeapTupleIsValid(tuple))
123 	{
124 		elog(ERROR, "citus cache lookup failed for function %u", oid);
125 	}
126 	*form = (Form_pg_proc) GETSTRUCT(tuple);
127 	return tuple;
128 }
129 
130 
131 /*
132  * GetTypeForm loads corresponding tuple & Form_pg_type for oid
133  */
134 static HeapTuple
GetTypeForm(Oid oid,Form_pg_type * form)135 GetTypeForm(Oid oid, Form_pg_type *form)
136 {
137 	HeapTuple tuple = SearchSysCache1(TYPEOID, ObjectIdGetDatum(oid));
138 	if (!HeapTupleIsValid(tuple))
139 	{
140 		elog(ERROR, "citus cache lookup failed for type %u", oid);
141 	}
142 	*form = (Form_pg_type) GETSTRUCT(tuple);
143 	return tuple;
144 }
145 
146 
147 /*
148  * pallocInAggContext calls palloc in fcinfo's aggregate context
149  */
150 static void *
pallocInAggContext(FunctionCallInfo fcinfo,size_t size)151 pallocInAggContext(FunctionCallInfo fcinfo, size_t size)
152 {
153 	MemoryContext aggregateContext;
154 	if (!AggCheckCallContext(fcinfo, &aggregateContext))
155 	{
156 		elog(ERROR, "Aggregate function called without an aggregate context");
157 	}
158 	return MemoryContextAlloc(aggregateContext, size);
159 }
160 
161 
162 /*
163  * aclcheckAggregate verifies that the given user has ACL_EXECUTE to the given proc
164  */
165 static void
aclcheckAggregate(ObjectType objectType,Oid userOid,Oid funcOid)166 aclcheckAggregate(ObjectType objectType, Oid userOid, Oid funcOid)
167 {
168 	AclResult aclresult;
169 	if (funcOid != InvalidOid)
170 	{
171 		aclresult = pg_proc_aclcheck(funcOid, userOid, ACL_EXECUTE);
172 		if (aclresult != ACLCHECK_OK)
173 		{
174 			aclcheck_error(aclresult, objectType, get_func_name(funcOid));
175 		}
176 	}
177 }
178 
179 
180 /* Copied from nodeAgg.c */
181 static Datum
GetAggInitVal(Datum textInitVal,Oid transtype)182 GetAggInitVal(Datum textInitVal, Oid transtype)
183 {
184 	/* *INDENT-OFF* */
185 	Oid			typinput,
186 				typioparam;
187 	char	   *strInitVal;
188 	Datum		initVal;
189 
190 	getTypeInputInfo(transtype, &typinput, &typioparam);
191 	strInitVal = TextDatumGetCString(textInitVal);
192 	initVal = OidInputFunctionCall(typinput, strInitVal,
193 								   typioparam, -1);
194 	pfree(strInitVal);
195 	return initVal;
196 	/* *INDENT-ON* */
197 }
198 
199 
200 /*
201  * InitializeStypeBox fills in the rest of an StypeBox's fields besides agg,
202  * handling both permission checking & setting up the initial transition state.
203  */
204 static void
InitializeStypeBox(FunctionCallInfo fcinfo,StypeBox * box,HeapTuple aggTuple,Oid transtype,AggregationArgumentContext * aggregationArgumentContext)205 InitializeStypeBox(FunctionCallInfo fcinfo, StypeBox *box, HeapTuple aggTuple, Oid
206 				   transtype, AggregationArgumentContext *aggregationArgumentContext)
207 {
208 	Form_pg_aggregate aggform = (Form_pg_aggregate) GETSTRUCT(aggTuple);
209 	Oid userId = GetUserId();
210 
211 	/* First we make ACL_EXECUTE checks as would be done in nodeAgg.c */
212 	aclcheckAggregate(OBJECT_AGGREGATE, userId, aggform->aggfnoid);
213 	aclcheckAggregate(OBJECT_FUNCTION, userId, aggform->aggfinalfn);
214 	aclcheckAggregate(OBJECT_FUNCTION, userId, aggform->aggtransfn);
215 	aclcheckAggregate(OBJECT_FUNCTION, userId, aggform->aggdeserialfn);
216 	aclcheckAggregate(OBJECT_FUNCTION, userId, aggform->aggserialfn);
217 	aclcheckAggregate(OBJECT_FUNCTION, userId, aggform->aggcombinefn);
218 
219 	Datum textInitVal = SysCacheGetAttr(AGGFNOID, aggTuple,
220 										Anum_pg_aggregate_agginitval,
221 										&box->valueNull);
222 	box->transtype = transtype;
223 	box->valueInit = !box->valueNull;
224 	box->aggregationArgumentContext = aggregationArgumentContext;
225 	if (box->valueNull)
226 	{
227 		box->value = (Datum) 0;
228 	}
229 	else
230 	{
231 		MemoryContext aggregateContext;
232 		if (!AggCheckCallContext(fcinfo, &aggregateContext))
233 		{
234 			elog(ERROR, "InitializeStypeBox called from non aggregate context");
235 		}
236 		MemoryContext oldContext = MemoryContextSwitchTo(aggregateContext);
237 
238 		box->value = GetAggInitVal(textInitVal, transtype);
239 
240 		MemoryContextSwitchTo(oldContext);
241 	}
242 }
243 
244 
245 /*
246  * TryCreateStypeBoxFromFcinfoAggref attempts to initialize an StypeBox through
247  * introspection of the fcinfo's Aggref from AggGetAggref. This is required
248  * when we receive no intermediate rows.
249  *
250  * Returns NULL if the Aggref isn't our expected shape.
251  */
252 static StypeBox *
TryCreateStypeBoxFromFcinfoAggref(FunctionCallInfo fcinfo)253 TryCreateStypeBoxFromFcinfoAggref(FunctionCallInfo fcinfo)
254 {
255 	Aggref *aggref = AggGetAggref(fcinfo);
256 	if (aggref == NULL || aggref->args == NIL)
257 	{
258 		return NULL;
259 	}
260 
261 	TargetEntry *aggArg = linitial(aggref->args);
262 	if (!IsA(aggArg->expr, Const))
263 	{
264 		return NULL;
265 	}
266 
267 	Const *aggConst = (Const *) aggArg->expr;
268 	if (aggConst->consttype != OIDOID && aggConst->consttype != REGPROCEDUREOID)
269 	{
270 		return NULL;
271 	}
272 
273 	Form_pg_aggregate aggform;
274 	StypeBox *box = pallocInAggContext(fcinfo, sizeof(StypeBox));
275 	box->agg = DatumGetObjectId(aggConst->constvalue);
276 	HeapTuple aggTuple = GetAggregateForm(box->agg, &aggform);
277 	InitializeStypeBox(fcinfo, box, aggTuple, aggform->aggtranstype, NULL);
278 	ReleaseSysCache(aggTuple);
279 
280 	return box;
281 }
282 
283 
284 /*
285  * CreateAggregationArgumentContext creates an AggregationArgumentContext tailored
286  * to handling the aggregation of input arguments identical to type at
287  * 'argumentIndex' in 'fcinfo'.
288  */
289 static AggregationArgumentContext *
CreateAggregationArgumentContext(FunctionCallInfo fcinfo,int argumentIndex)290 CreateAggregationArgumentContext(FunctionCallInfo fcinfo, int argumentIndex)
291 {
292 	AggregationArgumentContext *aggregationArgumentContext =
293 		pallocInAggContext(fcinfo, sizeof(AggregationArgumentContext));
294 
295 	/* check if input comes combined into tuple/record */
296 	if (RECORDOID == get_fn_expr_argtype(fcinfo->flinfo, argumentIndex))
297 	{
298 		/* initialize context to handle aggregation argument combined into tuple */
299 		if (fcGetArgNull(fcinfo, argumentIndex))
300 		{
301 			ereport(ERROR, (errmsg("worker_partial_agg_sfunc: null record input"),
302 							errhint("Elements of record may be null")));
303 		}
304 
305 		/* retrieve tuple header */
306 		HeapTupleHeader tupleHeader = PG_GETARG_HEAPTUPLEHEADER(argumentIndex);
307 
308 		/* extract type info from the tuple */
309 		TupleDesc tupleDesc =
310 			lookup_rowtype_tupdesc(HeapTupleHeaderGetTypeId(tupleHeader),
311 								   HeapTupleHeaderGetTypMod(tupleHeader));
312 
313 		/* create a copy we can keep */
314 		TupleDesc tupleDescCopy = pallocInAggContext(fcinfo, TupleDescSize(tupleDesc));
315 		TupleDescCopy(tupleDescCopy, tupleDesc);
316 		ReleaseTupleDesc(tupleDesc);
317 
318 		/* build a HeapTuple control structure */
319 		HeapTuple tuple = pallocInAggContext(fcinfo, sizeof(HeapTupleData));
320 		ItemPointerSetInvalid(&(tuple->t_self));
321 		tuple->t_tableOid = InvalidOid;
322 
323 		/* initialize context to handle multiple aggregation arguments */
324 		aggregationArgumentContext->argumentCount = tupleDescCopy->natts;
325 
326 		aggregationArgumentContext->values =
327 			pallocInAggContext(fcinfo, tupleDescCopy->natts * sizeof(Datum));
328 
329 		aggregationArgumentContext->nulls =
330 			pallocInAggContext(fcinfo, tupleDescCopy->natts * sizeof(bool));
331 
332 		aggregationArgumentContext->isTuple = true;
333 		aggregationArgumentContext->tupleDesc = tupleDescCopy;
334 		aggregationArgumentContext->tuple = tuple;
335 	}
336 	else
337 	{
338 		/* initialize context to handle single aggregation argument */
339 		aggregationArgumentContext->argumentCount = 1;
340 		aggregationArgumentContext->values = pallocInAggContext(fcinfo, sizeof(Datum));
341 		aggregationArgumentContext->nulls = pallocInAggContext(fcinfo, sizeof(bool));
342 		aggregationArgumentContext->isTuple = false;
343 		aggregationArgumentContext->tupleDesc = NULL;
344 		aggregationArgumentContext->tuple = NULL;
345 	}
346 
347 	return aggregationArgumentContext;
348 }
349 
350 
351 /*
352  * ExtractAggregationValues extracts aggregation argument values and stores them in
353  * the mutable fields of AggregationArgumentContext.
354  */
355 static void
ExtractAggregationValues(FunctionCallInfo fcinfo,int argumentIndex,AggregationArgumentContext * aggregationArgumentContext)356 ExtractAggregationValues(FunctionCallInfo fcinfo, int argumentIndex,
357 						 AggregationArgumentContext *aggregationArgumentContext)
358 {
359 	if (aggregationArgumentContext->isTuple)
360 	{
361 		if (fcGetArgNull(fcinfo, argumentIndex))
362 		{
363 			/* handle null record input */
364 			for (int i = 0; i < aggregationArgumentContext->argumentCount; i++)
365 			{
366 				aggregationArgumentContext->values[i] = 0;
367 				aggregationArgumentContext->nulls[i] = true;
368 			}
369 		}
370 		else
371 		{
372 			/* handle tuple/record input */
373 			HeapTupleHeader tupleHeader =
374 				DatumGetHeapTupleHeader(fcGetArgValue(fcinfo, argumentIndex));
375 
376 			if (HeapTupleHeaderGetNatts(tupleHeader) !=
377 				aggregationArgumentContext->argumentCount ||
378 				HeapTupleHeaderGetTypeId(tupleHeader) !=
379 				aggregationArgumentContext->tupleDesc->tdtypeid ||
380 				HeapTupleHeaderGetTypMod(tupleHeader) !=
381 				aggregationArgumentContext->tupleDesc->tdtypmod)
382 			{
383 				ereport(ERROR, (errmsg("worker_partial_agg_sfunc received "
384 									   "incompatible record")));
385 			}
386 
387 			aggregationArgumentContext->tuple->t_len =
388 				HeapTupleHeaderGetDatumLength(tupleHeader);
389 
390 			aggregationArgumentContext->tuple->t_data = tupleHeader;
391 
392 			/* break down the tuple into fields */
393 			heap_deform_tuple(
394 				aggregationArgumentContext->tuple,
395 				aggregationArgumentContext->tupleDesc,
396 				aggregationArgumentContext->values,
397 				aggregationArgumentContext->nulls);
398 		}
399 	}
400 	else
401 	{
402 		/* extract single argument value */
403 		aggregationArgumentContext->values[0] = fcGetArgValue(fcinfo, argumentIndex);
404 		aggregationArgumentContext->nulls[0] = fcGetArgNull(fcinfo, argumentIndex);
405 	}
406 }
407 
408 
409 /*
410  * HandleTransition copies logic used in nodeAgg's advance_transition_function
411  * for handling result of transition function.
412  */
413 static void
HandleTransition(StypeBox * box,FunctionCallInfo fcinfo,FunctionCallInfo innerFcinfo)414 HandleTransition(StypeBox *box, FunctionCallInfo fcinfo, FunctionCallInfo innerFcinfo)
415 {
416 	Datum newVal = FunctionCallInvoke(innerFcinfo);
417 	bool newValIsNull = innerFcinfo->isnull;
418 
419 	if (!box->transtypeByVal &&
420 		DatumGetPointer(newVal) != DatumGetPointer(box->value))
421 	{
422 		if (!newValIsNull)
423 		{
424 			MemoryContext aggregateContext;
425 
426 			if (!AggCheckCallContext(fcinfo, &aggregateContext))
427 			{
428 				elog(ERROR,
429 					 "HandleTransition called from non aggregate context");
430 			}
431 
432 			MemoryContext oldContext = MemoryContextSwitchTo(aggregateContext);
433 			if (!(DatumIsReadWriteExpandedObject(newVal,
434 												 false, box->transtypeLen) &&
435 				  MemoryContextGetParent(DatumGetEOHP(newVal)->eoh_context) ==
436 				  CurrentMemoryContext))
437 			{
438 				newVal = datumCopy(newVal, box->transtypeByVal, box->transtypeLen);
439 			}
440 			MemoryContextSwitchTo(oldContext);
441 		}
442 
443 		if (!box->valueNull)
444 		{
445 			if (DatumIsReadWriteExpandedObject(box->value,
446 											   false, box->transtypeLen))
447 			{
448 				DeleteExpandedObject(box->value);
449 			}
450 			else
451 			{
452 				pfree(DatumGetPointer(box->value));
453 			}
454 		}
455 	}
456 
457 	box->value = newVal;
458 	box->valueNull = newValIsNull;
459 }
460 
461 
462 /*
463  * HandleStrictUninit handles initialization of state for when
464  * transition function is strict & state has not yet been initialized.
465  */
466 static void
HandleStrictUninit(StypeBox * box,FunctionCallInfo fcinfo,Datum value)467 HandleStrictUninit(StypeBox *box, FunctionCallInfo fcinfo, Datum value)
468 {
469 	MemoryContext aggregateContext;
470 
471 	if (!AggCheckCallContext(fcinfo, &aggregateContext))
472 	{
473 		elog(ERROR, "HandleStrictUninit called from non aggregate context");
474 	}
475 
476 	MemoryContext oldContext = MemoryContextSwitchTo(aggregateContext);
477 	box->value = datumCopy(value, box->transtypeByVal, box->transtypeLen);
478 	MemoryContextSwitchTo(oldContext);
479 
480 	box->valueNull = false;
481 	box->valueInit = true;
482 }
483 
484 
485 /*
486  * worker_partial_agg_sfunc advances transition state,
487  * essentially implementing the following pseudocode:
488  *
489  * (box, agg, ...) -> box
490  * box.agg = agg;
491  * box.value = agg.sfunc(box.value, ...);
492  * return box
493  */
494 Datum
worker_partial_agg_sfunc(PG_FUNCTION_ARGS)495 worker_partial_agg_sfunc(PG_FUNCTION_ARGS)
496 {
497 	StypeBox *box = NULL;
498 	Form_pg_aggregate aggform;
499 	LOCAL_FCINFO(innerFcinfo, FUNC_MAX_ARGS);
500 	FmgrInfo info;
501 	int argumentIndex = 0;
502 	bool initialCall = PG_ARGISNULL(0);
503 
504 	if (initialCall)
505 	{
506 		if (PG_ARGISNULL(1))
507 		{
508 			ereport(ERROR, (errmsg("worker_partial_agg_sfunc received invalid null "
509 								   "input for second argument")));
510 		}
511 		box = pallocInAggContext(fcinfo, sizeof(StypeBox));
512 		box->agg = PG_GETARG_OID(1);
513 		box->aggregationArgumentContext = CreateAggregationArgumentContext(fcinfo, 2);
514 
515 		if (!TypecheckWorkerPartialAggArgType(fcinfo, box))
516 		{
517 			ereport(ERROR, (errmsg("worker_partial_agg_sfunc could not confirm type "
518 								   "correctness")));
519 		}
520 	}
521 	else
522 	{
523 		box = (StypeBox *) PG_GETARG_POINTER(0);
524 		Assert(box->agg == PG_GETARG_OID(1));
525 	}
526 
527 	HeapTuple aggtuple = GetAggregateForm(box->agg, &aggform);
528 	Oid aggsfunc = aggform->aggtransfn;
529 
530 	if (initialCall)
531 	{
532 		InitializeStypeBox(fcinfo, box, aggtuple, aggform->aggtranstype,
533 						   box->aggregationArgumentContext);
534 	}
535 	ReleaseSysCache(aggtuple);
536 	if (initialCall)
537 	{
538 		get_typlenbyval(box->transtype,
539 						&box->transtypeLen,
540 						&box->transtypeByVal);
541 	}
542 
543 	/*
544 	 * Get aggregation values, which may be either wrapped in a
545 	 * tuple (multi-argument case) or a singular, unwrapped value.
546 	 */
547 	ExtractAggregationValues(fcinfo, 2, box->aggregationArgumentContext);
548 
549 	fmgr_info(aggsfunc, &info);
550 	if (info.fn_strict)
551 	{
552 		for (argumentIndex = 0;
553 			 argumentIndex < box->aggregationArgumentContext->argumentCount;
554 			 argumentIndex++)
555 		{
556 			if (box->aggregationArgumentContext->nulls[argumentIndex])
557 			{
558 				PG_RETURN_POINTER(box);
559 			}
560 		}
561 
562 		if (!box->valueInit)
563 		{
564 			/* For 'strict' transition functions, if the initial state value is null
565 			 * then the first argument value of the first row with all-nonnull input
566 			 * values replaces the state value.
567 			 */
568 			Datum stateValue = box->aggregationArgumentContext->values[0];
569 			HandleStrictUninit(box, fcinfo, stateValue);
570 
571 			PG_RETURN_POINTER(box);
572 		}
573 
574 		if (box->valueNull)
575 		{
576 			PG_RETURN_POINTER(box);
577 		}
578 	}
579 
580 	/* if aggregate function has N parameters, corresponding SFUNC has N+1 */
581 	InitFunctionCallInfoData(*innerFcinfo, &info,
582 							 box->aggregationArgumentContext->argumentCount + 1,
583 							 fcinfo->fncollation,
584 							 fcinfo->context, fcinfo->resultinfo);
585 	fcSetArgExt(innerFcinfo, 0, box->value, box->valueNull);
586 
587 
588 	for (argumentIndex = 0;
589 		 argumentIndex < box->aggregationArgumentContext->argumentCount;
590 		 argumentIndex++)
591 	{
592 		fcSetArgExt(innerFcinfo, argumentIndex + 1,
593 					box->aggregationArgumentContext->values[argumentIndex],
594 					box->aggregationArgumentContext->nulls[argumentIndex]);
595 	}
596 
597 	HandleTransition(box, fcinfo, innerFcinfo);
598 
599 	PG_RETURN_POINTER(box);
600 }
601 
602 
603 /*
604  * worker_partial_agg_ffunc serializes transition state,
605  * essentially implementing the following pseudocode:
606  *
607  * (box) -> text
608  * return box.agg.stype.output(box.value)
609  */
610 Datum
worker_partial_agg_ffunc(PG_FUNCTION_ARGS)611 worker_partial_agg_ffunc(PG_FUNCTION_ARGS)
612 {
613 	LOCAL_FCINFO(innerFcinfo, 1);
614 	FmgrInfo info;
615 	StypeBox *box = (StypeBox *) (PG_ARGISNULL(0) ? NULL : PG_GETARG_POINTER(0));
616 	Form_pg_aggregate aggform;
617 	Oid typoutput = InvalidOid;
618 	bool typIsVarlena = false;
619 
620 	if (box == NULL)
621 	{
622 		box = TryCreateStypeBoxFromFcinfoAggref(fcinfo);
623 	}
624 
625 	if (box == NULL || box->valueNull)
626 	{
627 		PG_RETURN_NULL();
628 	}
629 
630 	HeapTuple aggtuple = GetAggregateForm(box->agg, &aggform);
631 
632 	if (aggform->aggcombinefn == InvalidOid)
633 	{
634 		ereport(ERROR, (errmsg(
635 							"worker_partial_agg_ffunc expects an aggregate with COMBINEFUNC")));
636 	}
637 
638 	if (aggform->aggtranstype == INTERNALOID)
639 	{
640 		ereport(ERROR,
641 				(errmsg(
642 					 "worker_partial_agg_ffunc does not support aggregates with INTERNAL transition state")));
643 	}
644 
645 	Oid transtype = aggform->aggtranstype;
646 	ReleaseSysCache(aggtuple);
647 
648 	getTypeOutputInfo(transtype, &typoutput, &typIsVarlena);
649 
650 	fmgr_info(typoutput, &info);
651 
652 	InitFunctionCallInfoData(*innerFcinfo, &info, 1, fcinfo->fncollation,
653 							 fcinfo->context, fcinfo->resultinfo);
654 	fcSetArgExt(innerFcinfo, 0, box->value, box->valueNull);
655 
656 	Datum result = FunctionCallInvoke(innerFcinfo);
657 
658 	if (innerFcinfo->isnull)
659 	{
660 		PG_RETURN_NULL();
661 	}
662 	PG_RETURN_DATUM(result);
663 }
664 
665 
666 /*
667  * coord_combine_agg_sfunc deserializes transition state from worker
668  * & advances transition state using combinefunc,
669  * essentially implementing the following pseudocode:
670  *
671  * (box, agg, text) -> box
672  * box.agg = agg
673  * box.value = agg.combine(box.value, agg.stype.input(text))
674  * return box
675  */
676 Datum
coord_combine_agg_sfunc(PG_FUNCTION_ARGS)677 coord_combine_agg_sfunc(PG_FUNCTION_ARGS)
678 {
679 	LOCAL_FCINFO(innerFcinfo, 3);
680 	FmgrInfo info;
681 	Form_pg_aggregate aggform;
682 	Form_pg_type transtypeform;
683 	Datum value;
684 	StypeBox *box = NULL;
685 
686 	if (PG_ARGISNULL(0))
687 	{
688 		box = pallocInAggContext(fcinfo, sizeof(StypeBox));
689 		box->agg = PG_GETARG_OID(1);
690 	}
691 	else
692 	{
693 		box = (StypeBox *) PG_GETARG_POINTER(0);
694 		Assert(box->agg == PG_GETARG_OID(1));
695 	}
696 
697 	HeapTuple aggtuple = GetAggregateForm(box->agg, &aggform);
698 
699 	if (aggform->aggcombinefn == InvalidOid)
700 	{
701 		ereport(ERROR, (errmsg(
702 							"coord_combine_agg_sfunc expects an aggregate with COMBINEFUNC")));
703 	}
704 
705 	if (aggform->aggtranstype == INTERNALOID)
706 	{
707 		ereport(ERROR,
708 				(errmsg(
709 					 "coord_combine_agg_sfunc does not support aggregates with INTERNAL transition state")));
710 	}
711 
712 	Oid combine = aggform->aggcombinefn;
713 
714 	if (PG_ARGISNULL(0))
715 	{
716 		InitializeStypeBox(fcinfo, box, aggtuple, aggform->aggtranstype, NULL);
717 	}
718 
719 	ReleaseSysCache(aggtuple);
720 
721 	if (PG_ARGISNULL(0))
722 	{
723 		get_typlenbyval(box->transtype,
724 						&box->transtypeLen,
725 						&box->transtypeByVal);
726 	}
727 
728 	bool valueNull = PG_ARGISNULL(2);
729 	HeapTuple transtypetuple = GetTypeForm(box->transtype, &transtypeform);
730 	Oid ioparam = getTypeIOParam(transtypetuple);
731 	Oid deserial = transtypeform->typinput;
732 	ReleaseSysCache(transtypetuple);
733 
734 	fmgr_info(deserial, &info);
735 	if (valueNull && info.fn_strict)
736 	{
737 		value = (Datum) 0;
738 	}
739 	else
740 	{
741 		InitFunctionCallInfoData(*innerFcinfo, &info, 3, fcinfo->fncollation,
742 								 fcinfo->context, fcinfo->resultinfo);
743 		fcSetArgExt(innerFcinfo, 0, PG_GETARG_DATUM(2), valueNull);
744 		fcSetArg(innerFcinfo, 1, ObjectIdGetDatum(ioparam));
745 		fcSetArg(innerFcinfo, 2, Int32GetDatum(-1)); /* typmod */
746 
747 		value = FunctionCallInvoke(innerFcinfo);
748 		valueNull = innerFcinfo->isnull;
749 	}
750 
751 	fmgr_info(combine, &info);
752 
753 	if (info.fn_strict)
754 	{
755 		if (valueNull)
756 		{
757 			PG_RETURN_POINTER(box);
758 		}
759 
760 		if (!box->valueInit)
761 		{
762 			HandleStrictUninit(box, fcinfo, value);
763 			PG_RETURN_POINTER(box);
764 		}
765 
766 		if (box->valueNull)
767 		{
768 			PG_RETURN_POINTER(box);
769 		}
770 	}
771 
772 	InitFunctionCallInfoData(*innerFcinfo, &info, 2, fcinfo->fncollation,
773 							 fcinfo->context, fcinfo->resultinfo);
774 	fcSetArgExt(innerFcinfo, 0, box->value, box->valueNull);
775 	fcSetArgExt(innerFcinfo, 1, value, valueNull);
776 
777 	HandleTransition(box, fcinfo, innerFcinfo);
778 
779 	PG_RETURN_POINTER(box);
780 }
781 
782 
783 /*
784  * coord_combine_agg_ffunc applies finalfunc of aggregate to state,
785  * essentially implementing the following pseudocode:
786  *
787  * (box, ...) -> fval
788  * return box.agg.ffunc(box.value)
789  */
790 Datum
coord_combine_agg_ffunc(PG_FUNCTION_ARGS)791 coord_combine_agg_ffunc(PG_FUNCTION_ARGS)
792 {
793 	StypeBox *box = (StypeBox *) (PG_ARGISNULL(0) ? NULL : PG_GETARG_POINTER(0));
794 	LOCAL_FCINFO(innerFcinfo, FUNC_MAX_ARGS);
795 	FmgrInfo info;
796 	int innerNargs = 0;
797 	Form_pg_aggregate aggform;
798 	Form_pg_proc ffuncform;
799 
800 	if (box == NULL)
801 	{
802 		box = TryCreateStypeBoxFromFcinfoAggref(fcinfo);
803 
804 		if (box == NULL)
805 		{
806 			PG_RETURN_NULL();
807 		}
808 	}
809 
810 	HeapTuple aggtuple = GetAggregateForm(box->agg, &aggform);
811 	Oid ffunc = aggform->aggfinalfn;
812 	bool fextra = aggform->aggfinalextra;
813 	ReleaseSysCache(aggtuple);
814 
815 	if (!TypecheckCoordCombineAggReturnType(fcinfo, ffunc, box))
816 	{
817 		ereport(ERROR, (errmsg(
818 							"coord_combine_agg_ffunc could not confirm type correctness")));
819 	}
820 
821 	if (ffunc == InvalidOid)
822 	{
823 		if (box->valueNull)
824 		{
825 			PG_RETURN_NULL();
826 		}
827 		PG_RETURN_DATUM(box->value);
828 	}
829 
830 	HeapTuple ffunctuple = GetProcForm(ffunc, &ffuncform);
831 	bool finalStrict = ffuncform->proisstrict;
832 	ReleaseSysCache(ffunctuple);
833 
834 	if (finalStrict && box->valueNull)
835 	{
836 		PG_RETURN_NULL();
837 	}
838 
839 	if (fextra)
840 	{
841 		innerNargs = fcinfo->nargs;
842 	}
843 	else
844 	{
845 		innerNargs = 1;
846 	}
847 	fmgr_info(ffunc, &info);
848 	InitFunctionCallInfoData(*innerFcinfo, &info, innerNargs, fcinfo->fncollation,
849 							 fcinfo->context, fcinfo->resultinfo);
850 	fcSetArgExt(innerFcinfo, 0, box->value, box->valueNull);
851 	for (int argumentIndex = 1; argumentIndex < innerNargs; argumentIndex++)
852 	{
853 		fcSetArgNull(innerFcinfo, argumentIndex);
854 	}
855 
856 	Datum result = FunctionCallInvoke(innerFcinfo);
857 	fcinfo->isnull = innerFcinfo->isnull;
858 	return result;
859 }
860 
861 
862 /*
863  * TypecheckWorkerPartialAggArgType returns whether the arguments being passed to
864  * worker_partial_agg match the arguments expected by the aggregate being distributed.
865  */
866 static bool
TypecheckWorkerPartialAggArgType(FunctionCallInfo fcinfo,StypeBox * box)867 TypecheckWorkerPartialAggArgType(FunctionCallInfo fcinfo, StypeBox *box)
868 {
869 	Aggref *aggref = AggGetAggref(fcinfo);
870 	if (aggref == NULL)
871 	{
872 		return false;
873 	}
874 
875 	Assert(list_length(aggref->args) == 2);
876 	TargetEntry *aggarg = list_nth(aggref->args, 1);
877 
878 	bool argtypesNull;
879 	HeapTuple proctuple = SearchSysCache1(PROCOID, ObjectIdGetDatum(box->agg));
880 	if (!HeapTupleIsValid(proctuple))
881 	{
882 		return false;
883 	}
884 
885 	Datum argtypes = SysCacheGetAttr(PROCOID, proctuple,
886 									 Anum_pg_proc_proargtypes,
887 									 &argtypesNull);
888 	Assert(!argtypesNull);
889 	ReleaseSysCache(proctuple);
890 
891 	if (ARR_NDIM(DatumGetArrayTypeP(argtypes)) != 1)
892 	{
893 		elog(ERROR, "worker_partial_agg_sfunc cannot type check aggregates "
894 					"taking multi-dimensional arguments");
895 	}
896 
897 	int aggregateArgCount = ARR_DIMS(DatumGetArrayTypeP(argtypes))[0];
898 
899 	/* we expect aggregate function to have at least a single parameter */
900 	if (box->aggregationArgumentContext->argumentCount != aggregateArgCount)
901 	{
902 		return false;
903 	}
904 
905 	int aggregateArgIndex = 0;
906 	Datum argType;
907 
908 	if (box->aggregationArgumentContext->isTuple)
909 	{
910 		/* check if record element types match aggregate input parameters */
911 		for (aggregateArgIndex = 0; aggregateArgIndex < aggregateArgCount;
912 			 aggregateArgIndex++)
913 		{
914 			argType = array_get_element(argtypes, 1, &aggregateArgIndex, -1, sizeof(Oid),
915 										true, 'i', &argtypesNull);
916 			Assert(!argtypesNull);
917 			TupleDesc tupleDesc = box->aggregationArgumentContext->tupleDesc;
918 			if (argType != tupleDesc->attrs[aggregateArgIndex].atttypid)
919 			{
920 				return false;
921 			}
922 		}
923 
924 		return true;
925 	}
926 	else
927 	{
928 		argType = array_get_element(argtypes, 1, &aggregateArgIndex, -1, sizeof(Oid),
929 									true, 'i', &argtypesNull);
930 		Assert(!argtypesNull);
931 
932 		return exprType((Node *) aggarg->expr) == DatumGetObjectId(argType);
933 	}
934 }
935 
936 
937 /*
938  * TypecheckCoordCombineAggReturnType returns whether the return type of the aggregate
939  * being distributed by coord_combine_agg matches the null constant used to inform postgres
940  * what the aggregate's expected return type is.
941  */
942 static bool
TypecheckCoordCombineAggReturnType(FunctionCallInfo fcinfo,Oid ffunc,StypeBox * box)943 TypecheckCoordCombineAggReturnType(FunctionCallInfo fcinfo, Oid ffunc, StypeBox *box)
944 {
945 	Aggref *aggref = AggGetAggref(fcinfo);
946 	if (aggref == NULL)
947 	{
948 		return false;
949 	}
950 
951 	Oid finalType = ffunc == InvalidOid ?
952 					box->transtype : get_func_rettype(ffunc);
953 
954 	Assert(list_length(aggref->args) == 3);
955 	TargetEntry *nulltag = list_nth(aggref->args, 2);
956 
957 	return nulltag != NULL && IsA(nulltag->expr, Const) &&
958 		   ((Const *) nulltag->expr)->consttype == finalType;
959 }
960