1 /*-------------------------------------------------------------------------
2 *
3 * aggregate_utils.c
4 *
5 * Implementation of UDFs distributing execution of aggregates across workers.
6 *
7 * When an aggregate has a combinefunc, we use worker_partial_agg to skip
8 * calling finalfunc on workers, instead passing state to coordinator where
9 * it uses combinefunc in coord_combine_agg & applying finalfunc only at end.
10 *
11 * Copyright Citus Data, Inc.
12 *
13 *-------------------------------------------------------------------------
14 */
15
16
17 #include "postgres.h"
18
19 #include "access/htup_details.h"
20 #include "catalog/pg_aggregate.h"
21 #include "catalog/pg_proc.h"
22 #include "catalog/pg_type.h"
23 #include "distributed/version_compat.h"
24 #include "nodes/nodeFuncs.h"
25 #include "utils/acl.h"
26 #include "utils/builtins.h"
27 #include "utils/datum.h"
28 #include "utils/lsyscache.h"
29 #include "utils/syscache.h"
30 #include "utils/typcache.h"
31 #include "fmgr.h"
32 #include "miscadmin.h"
33 #include "pg_config_manual.h"
34
/*
 * fmgr "version 1" info records for the externally callable support
 * functions defined in this file: the worker-side partial aggregation
 * transition/final functions and the coordinator-side combining
 * transition/final functions.
 */
PG_FUNCTION_INFO_V1(worker_partial_agg_sfunc);
PG_FUNCTION_INFO_V1(worker_partial_agg_ffunc);
PG_FUNCTION_INFO_V1(coord_combine_agg_sfunc);
PG_FUNCTION_INFO_V1(coord_combine_agg_ffunc);
39
/*
 * Holds information describing the structure of aggregation arguments
 * and helps to efficiently handle both a single argument and multiple
 * arguments wrapped in a tuple/record. It exploits the fact that
 * aggregation argument types do not change between subsequent
 * calls to SFUNC.
 *
 * It is built by CreateAggregationArgumentContext when worker_partial_agg_sfunc
 * creates a new transition state (all of its memory comes from the aggregate
 * memory context). ExtractAggregationValues refills it for every input row.
 */
typedef struct AggregationArgumentContext
{
	/* immutable fields */

	/* number of aggregate arguments: 1, or the record's attribute count */
	int argumentCount;

	/* true when the arguments arrive wrapped in a single RECORD value */
	bool isTuple;

	/* aggregate-context copy of the record's tuple descriptor; NULL if !isTuple */
	TupleDesc tupleDesc;

	/* mutable fields */

	/*
	 * reusable HeapTupleData shell; t_data/t_len are repointed at each incoming
	 * record before it is deformed. NULL if !isTuple
	 */
	HeapTuple tuple;

	/* per-row argument values and null flags, argumentCount entries each */
	Datum *values;
	bool *nulls;
} AggregationArgumentContext;
59
/*
 * internal type for support aggregates to pass transition state alongside
 * aggregation bookkeeping
 */
typedef struct StypeBox
{
	/* current transition state of the wrapped aggregate; check valueNull first */
	Datum value;

	/* oid of the wrapped aggregate, used to look up its pg_aggregate entry */
	Oid agg;

	/* transition type of the wrapped aggregate and its storage properties */
	Oid transtype;
	int16_t transtypeLen;
	bool transtypeByVal;

	/* whether the transition state is currently NULL */
	bool valueNull;

	/*
	 * whether a transition state has been established, either from the
	 * aggregate's initial value or from the first all-non-null input row
	 * seen by a strict transition/combine function
	 */
	bool valueInit;

	/*
	 * argument layout used by worker_partial_agg_sfunc; boxes created by
	 * coord_combine_agg_sfunc or TryCreateStypeBoxFromFcinfoAggref leave it NULL
	 */
	AggregationArgumentContext *aggregationArgumentContext;
} StypeBox;
75
/* forward declarations of the file-local helpers defined below */
static HeapTuple GetAggregateForm(Oid oid, Form_pg_aggregate *form);
static HeapTuple GetProcForm(Oid oid, Form_pg_proc *form);
static HeapTuple GetTypeForm(Oid oid, Form_pg_type *form);
static void * pallocInAggContext(FunctionCallInfo fcinfo, size_t size);
static void aclcheckAggregate(ObjectType objectType, Oid userOid, Oid funcOid);
static Datum GetAggInitVal(Datum textInitVal, Oid transtype);
static void InitializeStypeBox(FunctionCallInfo fcinfo, StypeBox *box,
							   HeapTuple aggTuple, Oid transtype,
							   AggregationArgumentContext *aggregationArgumentContext);
static StypeBox * TryCreateStypeBoxFromFcinfoAggref(FunctionCallInfo fcinfo);
static AggregationArgumentContext * CreateAggregationArgumentContext(FunctionCallInfo
																	 fcinfo,
																	 int argumentIndex);
static void ExtractAggregationValues(FunctionCallInfo fcinfo, int argumentIndex,
									 AggregationArgumentContext
									 *aggregationArgumentContext);
static void HandleTransition(StypeBox *box, FunctionCallInfo fcinfo,
							 FunctionCallInfo innerFcinfo);
static void HandleStrictUninit(StypeBox *box, FunctionCallInfo fcinfo, Datum value);
static bool TypecheckWorkerPartialAggArgType(FunctionCallInfo fcinfo, StypeBox *box);
static bool TypecheckCoordCombineAggReturnType(FunctionCallInfo fcinfo, Oid ffunc,
											   StypeBox *box);
98
99 /*
100 * GetAggregateForm loads corresponding tuple & Form_pg_aggregate for oid
101 */
102 static HeapTuple
GetAggregateForm(Oid oid,Form_pg_aggregate * form)103 GetAggregateForm(Oid oid, Form_pg_aggregate *form)
104 {
105 HeapTuple tuple = SearchSysCache1(AGGFNOID, ObjectIdGetDatum(oid));
106 if (!HeapTupleIsValid(tuple))
107 {
108 elog(ERROR, "citus cache lookup failed for aggregate %u", oid);
109 }
110 *form = (Form_pg_aggregate) GETSTRUCT(tuple);
111 return tuple;
112 }
113
114
115 /*
116 * GetProcForm loads corresponding tuple & Form_pg_proc for oid
117 */
118 static HeapTuple
GetProcForm(Oid oid,Form_pg_proc * form)119 GetProcForm(Oid oid, Form_pg_proc *form)
120 {
121 HeapTuple tuple = SearchSysCache1(PROCOID, ObjectIdGetDatum(oid));
122 if (!HeapTupleIsValid(tuple))
123 {
124 elog(ERROR, "citus cache lookup failed for function %u", oid);
125 }
126 *form = (Form_pg_proc) GETSTRUCT(tuple);
127 return tuple;
128 }
129
130
131 /*
132 * GetTypeForm loads corresponding tuple & Form_pg_type for oid
133 */
134 static HeapTuple
GetTypeForm(Oid oid,Form_pg_type * form)135 GetTypeForm(Oid oid, Form_pg_type *form)
136 {
137 HeapTuple tuple = SearchSysCache1(TYPEOID, ObjectIdGetDatum(oid));
138 if (!HeapTupleIsValid(tuple))
139 {
140 elog(ERROR, "citus cache lookup failed for type %u", oid);
141 }
142 *form = (Form_pg_type) GETSTRUCT(tuple);
143 return tuple;
144 }
145
146
147 /*
148 * pallocInAggContext calls palloc in fcinfo's aggregate context
149 */
150 static void *
pallocInAggContext(FunctionCallInfo fcinfo,size_t size)151 pallocInAggContext(FunctionCallInfo fcinfo, size_t size)
152 {
153 MemoryContext aggregateContext;
154 if (!AggCheckCallContext(fcinfo, &aggregateContext))
155 {
156 elog(ERROR, "Aggregate function called without an aggregate context");
157 }
158 return MemoryContextAlloc(aggregateContext, size);
159 }
160
161
162 /*
163 * aclcheckAggregate verifies that the given user has ACL_EXECUTE to the given proc
164 */
165 static void
aclcheckAggregate(ObjectType objectType,Oid userOid,Oid funcOid)166 aclcheckAggregate(ObjectType objectType, Oid userOid, Oid funcOid)
167 {
168 AclResult aclresult;
169 if (funcOid != InvalidOid)
170 {
171 aclresult = pg_proc_aclcheck(funcOid, userOid, ACL_EXECUTE);
172 if (aclresult != ACLCHECK_OK)
173 {
174 aclcheck_error(aclresult, objectType, get_func_name(funcOid));
175 }
176 }
177 }
178
179
180 /* Copied from nodeAgg.c */
181 static Datum
GetAggInitVal(Datum textInitVal,Oid transtype)182 GetAggInitVal(Datum textInitVal, Oid transtype)
183 {
184 /* *INDENT-OFF* */
185 Oid typinput,
186 typioparam;
187 char *strInitVal;
188 Datum initVal;
189
190 getTypeInputInfo(transtype, &typinput, &typioparam);
191 strInitVal = TextDatumGetCString(textInitVal);
192 initVal = OidInputFunctionCall(typinput, strInitVal,
193 typioparam, -1);
194 pfree(strInitVal);
195 return initVal;
196 /* *INDENT-ON* */
197 }
198
199
200 /*
201 * InitializeStypeBox fills in the rest of an StypeBox's fields besides agg,
202 * handling both permission checking & setting up the initial transition state.
203 */
204 static void
InitializeStypeBox(FunctionCallInfo fcinfo,StypeBox * box,HeapTuple aggTuple,Oid transtype,AggregationArgumentContext * aggregationArgumentContext)205 InitializeStypeBox(FunctionCallInfo fcinfo, StypeBox *box, HeapTuple aggTuple, Oid
206 transtype, AggregationArgumentContext *aggregationArgumentContext)
207 {
208 Form_pg_aggregate aggform = (Form_pg_aggregate) GETSTRUCT(aggTuple);
209 Oid userId = GetUserId();
210
211 /* First we make ACL_EXECUTE checks as would be done in nodeAgg.c */
212 aclcheckAggregate(OBJECT_AGGREGATE, userId, aggform->aggfnoid);
213 aclcheckAggregate(OBJECT_FUNCTION, userId, aggform->aggfinalfn);
214 aclcheckAggregate(OBJECT_FUNCTION, userId, aggform->aggtransfn);
215 aclcheckAggregate(OBJECT_FUNCTION, userId, aggform->aggdeserialfn);
216 aclcheckAggregate(OBJECT_FUNCTION, userId, aggform->aggserialfn);
217 aclcheckAggregate(OBJECT_FUNCTION, userId, aggform->aggcombinefn);
218
219 Datum textInitVal = SysCacheGetAttr(AGGFNOID, aggTuple,
220 Anum_pg_aggregate_agginitval,
221 &box->valueNull);
222 box->transtype = transtype;
223 box->valueInit = !box->valueNull;
224 box->aggregationArgumentContext = aggregationArgumentContext;
225 if (box->valueNull)
226 {
227 box->value = (Datum) 0;
228 }
229 else
230 {
231 MemoryContext aggregateContext;
232 if (!AggCheckCallContext(fcinfo, &aggregateContext))
233 {
234 elog(ERROR, "InitializeStypeBox called from non aggregate context");
235 }
236 MemoryContext oldContext = MemoryContextSwitchTo(aggregateContext);
237
238 box->value = GetAggInitVal(textInitVal, transtype);
239
240 MemoryContextSwitchTo(oldContext);
241 }
242 }
243
244
245 /*
246 * TryCreateStypeBoxFromFcinfoAggref attempts to initialize an StypeBox through
247 * introspection of the fcinfo's Aggref from AggGetAggref. This is required
248 * when we receive no intermediate rows.
249 *
250 * Returns NULL if the Aggref isn't our expected shape.
251 */
252 static StypeBox *
TryCreateStypeBoxFromFcinfoAggref(FunctionCallInfo fcinfo)253 TryCreateStypeBoxFromFcinfoAggref(FunctionCallInfo fcinfo)
254 {
255 Aggref *aggref = AggGetAggref(fcinfo);
256 if (aggref == NULL || aggref->args == NIL)
257 {
258 return NULL;
259 }
260
261 TargetEntry *aggArg = linitial(aggref->args);
262 if (!IsA(aggArg->expr, Const))
263 {
264 return NULL;
265 }
266
267 Const *aggConst = (Const *) aggArg->expr;
268 if (aggConst->consttype != OIDOID && aggConst->consttype != REGPROCEDUREOID)
269 {
270 return NULL;
271 }
272
273 Form_pg_aggregate aggform;
274 StypeBox *box = pallocInAggContext(fcinfo, sizeof(StypeBox));
275 box->agg = DatumGetObjectId(aggConst->constvalue);
276 HeapTuple aggTuple = GetAggregateForm(box->agg, &aggform);
277 InitializeStypeBox(fcinfo, box, aggTuple, aggform->aggtranstype, NULL);
278 ReleaseSysCache(aggTuple);
279
280 return box;
281 }
282
283
284 /*
285 * CreateAggregationArgumentContext creates an AggregationArgumentContext tailored
286 * to handling the aggregation of input arguments identical to type at
287 * 'argumentIndex' in 'fcinfo'.
288 */
289 static AggregationArgumentContext *
CreateAggregationArgumentContext(FunctionCallInfo fcinfo,int argumentIndex)290 CreateAggregationArgumentContext(FunctionCallInfo fcinfo, int argumentIndex)
291 {
292 AggregationArgumentContext *aggregationArgumentContext =
293 pallocInAggContext(fcinfo, sizeof(AggregationArgumentContext));
294
295 /* check if input comes combined into tuple/record */
296 if (RECORDOID == get_fn_expr_argtype(fcinfo->flinfo, argumentIndex))
297 {
298 /* initialize context to handle aggregation argument combined into tuple */
299 if (fcGetArgNull(fcinfo, argumentIndex))
300 {
301 ereport(ERROR, (errmsg("worker_partial_agg_sfunc: null record input"),
302 errhint("Elements of record may be null")));
303 }
304
305 /* retrieve tuple header */
306 HeapTupleHeader tupleHeader = PG_GETARG_HEAPTUPLEHEADER(argumentIndex);
307
308 /* extract type info from the tuple */
309 TupleDesc tupleDesc =
310 lookup_rowtype_tupdesc(HeapTupleHeaderGetTypeId(tupleHeader),
311 HeapTupleHeaderGetTypMod(tupleHeader));
312
313 /* create a copy we can keep */
314 TupleDesc tupleDescCopy = pallocInAggContext(fcinfo, TupleDescSize(tupleDesc));
315 TupleDescCopy(tupleDescCopy, tupleDesc);
316 ReleaseTupleDesc(tupleDesc);
317
318 /* build a HeapTuple control structure */
319 HeapTuple tuple = pallocInAggContext(fcinfo, sizeof(HeapTupleData));
320 ItemPointerSetInvalid(&(tuple->t_self));
321 tuple->t_tableOid = InvalidOid;
322
323 /* initialize context to handle multiple aggregation arguments */
324 aggregationArgumentContext->argumentCount = tupleDescCopy->natts;
325
326 aggregationArgumentContext->values =
327 pallocInAggContext(fcinfo, tupleDescCopy->natts * sizeof(Datum));
328
329 aggregationArgumentContext->nulls =
330 pallocInAggContext(fcinfo, tupleDescCopy->natts * sizeof(bool));
331
332 aggregationArgumentContext->isTuple = true;
333 aggregationArgumentContext->tupleDesc = tupleDescCopy;
334 aggregationArgumentContext->tuple = tuple;
335 }
336 else
337 {
338 /* initialize context to handle single aggregation argument */
339 aggregationArgumentContext->argumentCount = 1;
340 aggregationArgumentContext->values = pallocInAggContext(fcinfo, sizeof(Datum));
341 aggregationArgumentContext->nulls = pallocInAggContext(fcinfo, sizeof(bool));
342 aggregationArgumentContext->isTuple = false;
343 aggregationArgumentContext->tupleDesc = NULL;
344 aggregationArgumentContext->tuple = NULL;
345 }
346
347 return aggregationArgumentContext;
348 }
349
350
351 /*
352 * ExtractAggregationValues extracts aggregation argument values and stores them in
353 * the mutable fields of AggregationArgumentContext.
354 */
355 static void
ExtractAggregationValues(FunctionCallInfo fcinfo,int argumentIndex,AggregationArgumentContext * aggregationArgumentContext)356 ExtractAggregationValues(FunctionCallInfo fcinfo, int argumentIndex,
357 AggregationArgumentContext *aggregationArgumentContext)
358 {
359 if (aggregationArgumentContext->isTuple)
360 {
361 if (fcGetArgNull(fcinfo, argumentIndex))
362 {
363 /* handle null record input */
364 for (int i = 0; i < aggregationArgumentContext->argumentCount; i++)
365 {
366 aggregationArgumentContext->values[i] = 0;
367 aggregationArgumentContext->nulls[i] = true;
368 }
369 }
370 else
371 {
372 /* handle tuple/record input */
373 HeapTupleHeader tupleHeader =
374 DatumGetHeapTupleHeader(fcGetArgValue(fcinfo, argumentIndex));
375
376 if (HeapTupleHeaderGetNatts(tupleHeader) !=
377 aggregationArgumentContext->argumentCount ||
378 HeapTupleHeaderGetTypeId(tupleHeader) !=
379 aggregationArgumentContext->tupleDesc->tdtypeid ||
380 HeapTupleHeaderGetTypMod(tupleHeader) !=
381 aggregationArgumentContext->tupleDesc->tdtypmod)
382 {
383 ereport(ERROR, (errmsg("worker_partial_agg_sfunc received "
384 "incompatible record")));
385 }
386
387 aggregationArgumentContext->tuple->t_len =
388 HeapTupleHeaderGetDatumLength(tupleHeader);
389
390 aggregationArgumentContext->tuple->t_data = tupleHeader;
391
392 /* break down the tuple into fields */
393 heap_deform_tuple(
394 aggregationArgumentContext->tuple,
395 aggregationArgumentContext->tupleDesc,
396 aggregationArgumentContext->values,
397 aggregationArgumentContext->nulls);
398 }
399 }
400 else
401 {
402 /* extract single argument value */
403 aggregationArgumentContext->values[0] = fcGetArgValue(fcinfo, argumentIndex);
404 aggregationArgumentContext->nulls[0] = fcGetArgNull(fcinfo, argumentIndex);
405 }
406 }
407
408
409 /*
410 * HandleTransition copies logic used in nodeAgg's advance_transition_function
411 * for handling result of transition function.
412 */
413 static void
HandleTransition(StypeBox * box,FunctionCallInfo fcinfo,FunctionCallInfo innerFcinfo)414 HandleTransition(StypeBox *box, FunctionCallInfo fcinfo, FunctionCallInfo innerFcinfo)
415 {
416 Datum newVal = FunctionCallInvoke(innerFcinfo);
417 bool newValIsNull = innerFcinfo->isnull;
418
419 if (!box->transtypeByVal &&
420 DatumGetPointer(newVal) != DatumGetPointer(box->value))
421 {
422 if (!newValIsNull)
423 {
424 MemoryContext aggregateContext;
425
426 if (!AggCheckCallContext(fcinfo, &aggregateContext))
427 {
428 elog(ERROR,
429 "HandleTransition called from non aggregate context");
430 }
431
432 MemoryContext oldContext = MemoryContextSwitchTo(aggregateContext);
433 if (!(DatumIsReadWriteExpandedObject(newVal,
434 false, box->transtypeLen) &&
435 MemoryContextGetParent(DatumGetEOHP(newVal)->eoh_context) ==
436 CurrentMemoryContext))
437 {
438 newVal = datumCopy(newVal, box->transtypeByVal, box->transtypeLen);
439 }
440 MemoryContextSwitchTo(oldContext);
441 }
442
443 if (!box->valueNull)
444 {
445 if (DatumIsReadWriteExpandedObject(box->value,
446 false, box->transtypeLen))
447 {
448 DeleteExpandedObject(box->value);
449 }
450 else
451 {
452 pfree(DatumGetPointer(box->value));
453 }
454 }
455 }
456
457 box->value = newVal;
458 box->valueNull = newValIsNull;
459 }
460
461
462 /*
463 * HandleStrictUninit handles initialization of state for when
464 * transition function is strict & state has not yet been initialized.
465 */
466 static void
HandleStrictUninit(StypeBox * box,FunctionCallInfo fcinfo,Datum value)467 HandleStrictUninit(StypeBox *box, FunctionCallInfo fcinfo, Datum value)
468 {
469 MemoryContext aggregateContext;
470
471 if (!AggCheckCallContext(fcinfo, &aggregateContext))
472 {
473 elog(ERROR, "HandleStrictUninit called from non aggregate context");
474 }
475
476 MemoryContext oldContext = MemoryContextSwitchTo(aggregateContext);
477 box->value = datumCopy(value, box->transtypeByVal, box->transtypeLen);
478 MemoryContextSwitchTo(oldContext);
479
480 box->valueNull = false;
481 box->valueInit = true;
482 }
483
484
485 /*
486 * worker_partial_agg_sfunc advances transition state,
487 * essentially implementing the following pseudocode:
488 *
489 * (box, agg, ...) -> box
490 * box.agg = agg;
491 * box.value = agg.sfunc(box.value, ...);
492 * return box
493 */
494 Datum
worker_partial_agg_sfunc(PG_FUNCTION_ARGS)495 worker_partial_agg_sfunc(PG_FUNCTION_ARGS)
496 {
497 StypeBox *box = NULL;
498 Form_pg_aggregate aggform;
499 LOCAL_FCINFO(innerFcinfo, FUNC_MAX_ARGS);
500 FmgrInfo info;
501 int argumentIndex = 0;
502 bool initialCall = PG_ARGISNULL(0);
503
504 if (initialCall)
505 {
506 if (PG_ARGISNULL(1))
507 {
508 ereport(ERROR, (errmsg("worker_partial_agg_sfunc received invalid null "
509 "input for second argument")));
510 }
511 box = pallocInAggContext(fcinfo, sizeof(StypeBox));
512 box->agg = PG_GETARG_OID(1);
513 box->aggregationArgumentContext = CreateAggregationArgumentContext(fcinfo, 2);
514
515 if (!TypecheckWorkerPartialAggArgType(fcinfo, box))
516 {
517 ereport(ERROR, (errmsg("worker_partial_agg_sfunc could not confirm type "
518 "correctness")));
519 }
520 }
521 else
522 {
523 box = (StypeBox *) PG_GETARG_POINTER(0);
524 Assert(box->agg == PG_GETARG_OID(1));
525 }
526
527 HeapTuple aggtuple = GetAggregateForm(box->agg, &aggform);
528 Oid aggsfunc = aggform->aggtransfn;
529
530 if (initialCall)
531 {
532 InitializeStypeBox(fcinfo, box, aggtuple, aggform->aggtranstype,
533 box->aggregationArgumentContext);
534 }
535 ReleaseSysCache(aggtuple);
536 if (initialCall)
537 {
538 get_typlenbyval(box->transtype,
539 &box->transtypeLen,
540 &box->transtypeByVal);
541 }
542
543 /*
544 * Get aggregation values, which may be either wrapped in a
545 * tuple (multi-argument case) or a singular, unwrapped value.
546 */
547 ExtractAggregationValues(fcinfo, 2, box->aggregationArgumentContext);
548
549 fmgr_info(aggsfunc, &info);
550 if (info.fn_strict)
551 {
552 for (argumentIndex = 0;
553 argumentIndex < box->aggregationArgumentContext->argumentCount;
554 argumentIndex++)
555 {
556 if (box->aggregationArgumentContext->nulls[argumentIndex])
557 {
558 PG_RETURN_POINTER(box);
559 }
560 }
561
562 if (!box->valueInit)
563 {
564 /* For 'strict' transition functions, if the initial state value is null
565 * then the first argument value of the first row with all-nonnull input
566 * values replaces the state value.
567 */
568 Datum stateValue = box->aggregationArgumentContext->values[0];
569 HandleStrictUninit(box, fcinfo, stateValue);
570
571 PG_RETURN_POINTER(box);
572 }
573
574 if (box->valueNull)
575 {
576 PG_RETURN_POINTER(box);
577 }
578 }
579
580 /* if aggregate function has N parameters, corresponding SFUNC has N+1 */
581 InitFunctionCallInfoData(*innerFcinfo, &info,
582 box->aggregationArgumentContext->argumentCount + 1,
583 fcinfo->fncollation,
584 fcinfo->context, fcinfo->resultinfo);
585 fcSetArgExt(innerFcinfo, 0, box->value, box->valueNull);
586
587
588 for (argumentIndex = 0;
589 argumentIndex < box->aggregationArgumentContext->argumentCount;
590 argumentIndex++)
591 {
592 fcSetArgExt(innerFcinfo, argumentIndex + 1,
593 box->aggregationArgumentContext->values[argumentIndex],
594 box->aggregationArgumentContext->nulls[argumentIndex]);
595 }
596
597 HandleTransition(box, fcinfo, innerFcinfo);
598
599 PG_RETURN_POINTER(box);
600 }
601
602
603 /*
604 * worker_partial_agg_ffunc serializes transition state,
605 * essentially implementing the following pseudocode:
606 *
607 * (box) -> text
608 * return box.agg.stype.output(box.value)
609 */
610 Datum
worker_partial_agg_ffunc(PG_FUNCTION_ARGS)611 worker_partial_agg_ffunc(PG_FUNCTION_ARGS)
612 {
613 LOCAL_FCINFO(innerFcinfo, 1);
614 FmgrInfo info;
615 StypeBox *box = (StypeBox *) (PG_ARGISNULL(0) ? NULL : PG_GETARG_POINTER(0));
616 Form_pg_aggregate aggform;
617 Oid typoutput = InvalidOid;
618 bool typIsVarlena = false;
619
620 if (box == NULL)
621 {
622 box = TryCreateStypeBoxFromFcinfoAggref(fcinfo);
623 }
624
625 if (box == NULL || box->valueNull)
626 {
627 PG_RETURN_NULL();
628 }
629
630 HeapTuple aggtuple = GetAggregateForm(box->agg, &aggform);
631
632 if (aggform->aggcombinefn == InvalidOid)
633 {
634 ereport(ERROR, (errmsg(
635 "worker_partial_agg_ffunc expects an aggregate with COMBINEFUNC")));
636 }
637
638 if (aggform->aggtranstype == INTERNALOID)
639 {
640 ereport(ERROR,
641 (errmsg(
642 "worker_partial_agg_ffunc does not support aggregates with INTERNAL transition state")));
643 }
644
645 Oid transtype = aggform->aggtranstype;
646 ReleaseSysCache(aggtuple);
647
648 getTypeOutputInfo(transtype, &typoutput, &typIsVarlena);
649
650 fmgr_info(typoutput, &info);
651
652 InitFunctionCallInfoData(*innerFcinfo, &info, 1, fcinfo->fncollation,
653 fcinfo->context, fcinfo->resultinfo);
654 fcSetArgExt(innerFcinfo, 0, box->value, box->valueNull);
655
656 Datum result = FunctionCallInvoke(innerFcinfo);
657
658 if (innerFcinfo->isnull)
659 {
660 PG_RETURN_NULL();
661 }
662 PG_RETURN_DATUM(result);
663 }
664
665
666 /*
667 * coord_combine_agg_sfunc deserializes transition state from worker
668 * & advances transition state using combinefunc,
669 * essentially implementing the following pseudocode:
670 *
671 * (box, agg, text) -> box
672 * box.agg = agg
673 * box.value = agg.combine(box.value, agg.stype.input(text))
674 * return box
675 */
676 Datum
coord_combine_agg_sfunc(PG_FUNCTION_ARGS)677 coord_combine_agg_sfunc(PG_FUNCTION_ARGS)
678 {
679 LOCAL_FCINFO(innerFcinfo, 3);
680 FmgrInfo info;
681 Form_pg_aggregate aggform;
682 Form_pg_type transtypeform;
683 Datum value;
684 StypeBox *box = NULL;
685
686 if (PG_ARGISNULL(0))
687 {
688 box = pallocInAggContext(fcinfo, sizeof(StypeBox));
689 box->agg = PG_GETARG_OID(1);
690 }
691 else
692 {
693 box = (StypeBox *) PG_GETARG_POINTER(0);
694 Assert(box->agg == PG_GETARG_OID(1));
695 }
696
697 HeapTuple aggtuple = GetAggregateForm(box->agg, &aggform);
698
699 if (aggform->aggcombinefn == InvalidOid)
700 {
701 ereport(ERROR, (errmsg(
702 "coord_combine_agg_sfunc expects an aggregate with COMBINEFUNC")));
703 }
704
705 if (aggform->aggtranstype == INTERNALOID)
706 {
707 ereport(ERROR,
708 (errmsg(
709 "coord_combine_agg_sfunc does not support aggregates with INTERNAL transition state")));
710 }
711
712 Oid combine = aggform->aggcombinefn;
713
714 if (PG_ARGISNULL(0))
715 {
716 InitializeStypeBox(fcinfo, box, aggtuple, aggform->aggtranstype, NULL);
717 }
718
719 ReleaseSysCache(aggtuple);
720
721 if (PG_ARGISNULL(0))
722 {
723 get_typlenbyval(box->transtype,
724 &box->transtypeLen,
725 &box->transtypeByVal);
726 }
727
728 bool valueNull = PG_ARGISNULL(2);
729 HeapTuple transtypetuple = GetTypeForm(box->transtype, &transtypeform);
730 Oid ioparam = getTypeIOParam(transtypetuple);
731 Oid deserial = transtypeform->typinput;
732 ReleaseSysCache(transtypetuple);
733
734 fmgr_info(deserial, &info);
735 if (valueNull && info.fn_strict)
736 {
737 value = (Datum) 0;
738 }
739 else
740 {
741 InitFunctionCallInfoData(*innerFcinfo, &info, 3, fcinfo->fncollation,
742 fcinfo->context, fcinfo->resultinfo);
743 fcSetArgExt(innerFcinfo, 0, PG_GETARG_DATUM(2), valueNull);
744 fcSetArg(innerFcinfo, 1, ObjectIdGetDatum(ioparam));
745 fcSetArg(innerFcinfo, 2, Int32GetDatum(-1)); /* typmod */
746
747 value = FunctionCallInvoke(innerFcinfo);
748 valueNull = innerFcinfo->isnull;
749 }
750
751 fmgr_info(combine, &info);
752
753 if (info.fn_strict)
754 {
755 if (valueNull)
756 {
757 PG_RETURN_POINTER(box);
758 }
759
760 if (!box->valueInit)
761 {
762 HandleStrictUninit(box, fcinfo, value);
763 PG_RETURN_POINTER(box);
764 }
765
766 if (box->valueNull)
767 {
768 PG_RETURN_POINTER(box);
769 }
770 }
771
772 InitFunctionCallInfoData(*innerFcinfo, &info, 2, fcinfo->fncollation,
773 fcinfo->context, fcinfo->resultinfo);
774 fcSetArgExt(innerFcinfo, 0, box->value, box->valueNull);
775 fcSetArgExt(innerFcinfo, 1, value, valueNull);
776
777 HandleTransition(box, fcinfo, innerFcinfo);
778
779 PG_RETURN_POINTER(box);
780 }
781
782
783 /*
784 * coord_combine_agg_ffunc applies finalfunc of aggregate to state,
785 * essentially implementing the following pseudocode:
786 *
787 * (box, ...) -> fval
788 * return box.agg.ffunc(box.value)
789 */
790 Datum
coord_combine_agg_ffunc(PG_FUNCTION_ARGS)791 coord_combine_agg_ffunc(PG_FUNCTION_ARGS)
792 {
793 StypeBox *box = (StypeBox *) (PG_ARGISNULL(0) ? NULL : PG_GETARG_POINTER(0));
794 LOCAL_FCINFO(innerFcinfo, FUNC_MAX_ARGS);
795 FmgrInfo info;
796 int innerNargs = 0;
797 Form_pg_aggregate aggform;
798 Form_pg_proc ffuncform;
799
800 if (box == NULL)
801 {
802 box = TryCreateStypeBoxFromFcinfoAggref(fcinfo);
803
804 if (box == NULL)
805 {
806 PG_RETURN_NULL();
807 }
808 }
809
810 HeapTuple aggtuple = GetAggregateForm(box->agg, &aggform);
811 Oid ffunc = aggform->aggfinalfn;
812 bool fextra = aggform->aggfinalextra;
813 ReleaseSysCache(aggtuple);
814
815 if (!TypecheckCoordCombineAggReturnType(fcinfo, ffunc, box))
816 {
817 ereport(ERROR, (errmsg(
818 "coord_combine_agg_ffunc could not confirm type correctness")));
819 }
820
821 if (ffunc == InvalidOid)
822 {
823 if (box->valueNull)
824 {
825 PG_RETURN_NULL();
826 }
827 PG_RETURN_DATUM(box->value);
828 }
829
830 HeapTuple ffunctuple = GetProcForm(ffunc, &ffuncform);
831 bool finalStrict = ffuncform->proisstrict;
832 ReleaseSysCache(ffunctuple);
833
834 if (finalStrict && box->valueNull)
835 {
836 PG_RETURN_NULL();
837 }
838
839 if (fextra)
840 {
841 innerNargs = fcinfo->nargs;
842 }
843 else
844 {
845 innerNargs = 1;
846 }
847 fmgr_info(ffunc, &info);
848 InitFunctionCallInfoData(*innerFcinfo, &info, innerNargs, fcinfo->fncollation,
849 fcinfo->context, fcinfo->resultinfo);
850 fcSetArgExt(innerFcinfo, 0, box->value, box->valueNull);
851 for (int argumentIndex = 1; argumentIndex < innerNargs; argumentIndex++)
852 {
853 fcSetArgNull(innerFcinfo, argumentIndex);
854 }
855
856 Datum result = FunctionCallInvoke(innerFcinfo);
857 fcinfo->isnull = innerFcinfo->isnull;
858 return result;
859 }
860
861
862 /*
863 * TypecheckWorkerPartialAggArgType returns whether the arguments being passed to
864 * worker_partial_agg match the arguments expected by the aggregate being distributed.
865 */
866 static bool
TypecheckWorkerPartialAggArgType(FunctionCallInfo fcinfo,StypeBox * box)867 TypecheckWorkerPartialAggArgType(FunctionCallInfo fcinfo, StypeBox *box)
868 {
869 Aggref *aggref = AggGetAggref(fcinfo);
870 if (aggref == NULL)
871 {
872 return false;
873 }
874
875 Assert(list_length(aggref->args) == 2);
876 TargetEntry *aggarg = list_nth(aggref->args, 1);
877
878 bool argtypesNull;
879 HeapTuple proctuple = SearchSysCache1(PROCOID, ObjectIdGetDatum(box->agg));
880 if (!HeapTupleIsValid(proctuple))
881 {
882 return false;
883 }
884
885 Datum argtypes = SysCacheGetAttr(PROCOID, proctuple,
886 Anum_pg_proc_proargtypes,
887 &argtypesNull);
888 Assert(!argtypesNull);
889 ReleaseSysCache(proctuple);
890
891 if (ARR_NDIM(DatumGetArrayTypeP(argtypes)) != 1)
892 {
893 elog(ERROR, "worker_partial_agg_sfunc cannot type check aggregates "
894 "taking multi-dimensional arguments");
895 }
896
897 int aggregateArgCount = ARR_DIMS(DatumGetArrayTypeP(argtypes))[0];
898
899 /* we expect aggregate function to have at least a single parameter */
900 if (box->aggregationArgumentContext->argumentCount != aggregateArgCount)
901 {
902 return false;
903 }
904
905 int aggregateArgIndex = 0;
906 Datum argType;
907
908 if (box->aggregationArgumentContext->isTuple)
909 {
910 /* check if record element types match aggregate input parameters */
911 for (aggregateArgIndex = 0; aggregateArgIndex < aggregateArgCount;
912 aggregateArgIndex++)
913 {
914 argType = array_get_element(argtypes, 1, &aggregateArgIndex, -1, sizeof(Oid),
915 true, 'i', &argtypesNull);
916 Assert(!argtypesNull);
917 TupleDesc tupleDesc = box->aggregationArgumentContext->tupleDesc;
918 if (argType != tupleDesc->attrs[aggregateArgIndex].atttypid)
919 {
920 return false;
921 }
922 }
923
924 return true;
925 }
926 else
927 {
928 argType = array_get_element(argtypes, 1, &aggregateArgIndex, -1, sizeof(Oid),
929 true, 'i', &argtypesNull);
930 Assert(!argtypesNull);
931
932 return exprType((Node *) aggarg->expr) == DatumGetObjectId(argType);
933 }
934 }
935
936
937 /*
938 * TypecheckCoordCombineAggReturnType returns whether the return type of the aggregate
939 * being distributed by coord_combine_agg matches the null constant used to inform postgres
940 * what the aggregate's expected return type is.
941 */
942 static bool
TypecheckCoordCombineAggReturnType(FunctionCallInfo fcinfo,Oid ffunc,StypeBox * box)943 TypecheckCoordCombineAggReturnType(FunctionCallInfo fcinfo, Oid ffunc, StypeBox *box)
944 {
945 Aggref *aggref = AggGetAggref(fcinfo);
946 if (aggref == NULL)
947 {
948 return false;
949 }
950
951 Oid finalType = ffunc == InvalidOid ?
952 box->transtype : get_func_rettype(ffunc);
953
954 Assert(list_length(aggref->args) == 3);
955 TargetEntry *nulltag = list_nth(aggref->args, 2);
956
957 return nulltag != NULL && IsA(nulltag->expr, Const) &&
958 ((Const *) nulltag->expr)->consttype == finalType;
959 }
960