1 /*-------------------------------------------------------------------------
2  *
3  * numeric.c
4  *	  An exact numeric data type for the Postgres database system
5  *
6  * Original coding 1998, Jan Wieck.  Heavily revised 2003, Tom Lane.
7  *
8  * Many of the algorithmic ideas are borrowed from David M. Smith's "FM"
9  * multiple-precision math library, most recently published as Algorithm
10  * 786: Multiple-Precision Complex Arithmetic and Functions, ACM
11  * Transactions on Mathematical Software, Vol. 24, No. 4, December 1998,
12  * pages 359-367.
13  *
14  * Copyright (c) 1998-2020, PostgreSQL Global Development Group
15  *
16  * IDENTIFICATION
17  *	  src/backend/utils/adt/numeric.c
18  *
19  *-------------------------------------------------------------------------
20  */
21 
22 #include "postgres.h"
23 
24 #include <ctype.h>
25 #include <float.h>
26 #include <limits.h>
27 #include <math.h>
28 
29 #include "catalog/pg_type.h"
30 #include "common/hashfn.h"
31 #include "common/int.h"
32 #include "funcapi.h"
33 #include "lib/hyperloglog.h"
34 #include "libpq/pqformat.h"
35 #include "miscadmin.h"
36 #include "nodes/nodeFuncs.h"
37 #include "nodes/supportnodes.h"
38 #include "utils/array.h"
39 #include "utils/builtins.h"
40 #include "utils/float.h"
41 #include "utils/guc.h"
42 #include "utils/int8.h"
43 #include "utils/numeric.h"
44 #include "utils/sortsupport.h"
45 
46 /* ----------
47  * Uncomment the following to enable compilation of dump_numeric()
48  * and dump_var() and to get a dump of any result produced by make_result().
49  * ----------
50 #define NUMERIC_DEBUG
51  */
52 
53 
54 /* ----------
55  * Local data types
56  *
57  * Numeric values are represented in a base-NBASE floating point format.
58  * Each "digit" ranges from 0 to NBASE-1.  The type NumericDigit is signed
59  * and wide enough to store a digit.  We assume that NBASE*NBASE can fit in
60  * an int.  Although the purely calculational routines could handle any even
61  * NBASE that's less than sqrt(INT_MAX), in practice we are only interested
62  * in NBASE a power of ten, so that I/O conversions and decimal rounding
63  * are easy.  Also, it's actually more efficient if NBASE is rather less than
64  * sqrt(INT_MAX), so that there is "headroom" for mul_var and div_var_fast to
65  * postpone processing carries.
66  *
67  * Values of NBASE other than 10000 are considered of historical interest only
68  * and are no longer supported in any sense; no mechanism exists for the client
69  * to discover the base, so every client supporting binary mode expects the
70  * base-10000 format.  If you plan to change this, also note the numeric
71  * abbreviation code, which assumes NBASE=10000.
72  * ----------
73  */
74 
75 #if 0
76 #define NBASE		10
77 #define HALF_NBASE	5
78 #define DEC_DIGITS	1			/* decimal digits per NBASE digit */
79 #define MUL_GUARD_DIGITS	4	/* these are measured in NBASE digits */
80 #define DIV_GUARD_DIGITS	8
81 
82 typedef signed char NumericDigit;
83 #endif
84 
85 #if 0
86 #define NBASE		100
87 #define HALF_NBASE	50
88 #define DEC_DIGITS	2			/* decimal digits per NBASE digit */
89 #define MUL_GUARD_DIGITS	3	/* these are measured in NBASE digits */
90 #define DIV_GUARD_DIGITS	6
91 
92 typedef signed char NumericDigit;
93 #endif
94 
95 #if 1
96 #define NBASE		10000
97 #define HALF_NBASE	5000
98 #define DEC_DIGITS	4			/* decimal digits per NBASE digit */
99 #define MUL_GUARD_DIGITS	2	/* these are measured in NBASE digits */
100 #define DIV_GUARD_DIGITS	4
101 
102 typedef int16 NumericDigit;
103 #endif
104 
105 /*
106  * The Numeric type as stored on disk.
107  *
108  * If the high bits of the first word of a NumericChoice (n_header, or
109  * n_short.n_header, or n_long.n_sign_dscale) are NUMERIC_SHORT, then the
110  * numeric follows the NumericShort format; if they are NUMERIC_POS or
111  * NUMERIC_NEG, it follows the NumericLong format.  If they are NUMERIC_NAN,
112  * it is a NaN.  We currently always store a NaN using just two bytes (i.e.
113  * only n_header), but previous releases used only the NumericLong format,
114  * so we might find 4-byte NaNs on disk if a database has been migrated using
115  * pg_upgrade.  In either case, when the high bits indicate a NaN, the
116  * remaining bits are never examined.  Currently, we always initialize these
117  * to zero, but it might be possible to use them for some other purpose in
118  * the future.
119  *
 * In the NumericShort format, the remaining 14 bits of the header word
 * (n_short.n_header) are allocated as follows: 1 for sign (positive or
 * negative), 6 for display scale, and 7 for weight.  In practice, most
 * commonly-encountered values can be represented this way.
124  *
125  * In the NumericLong format, the remaining 14 bits of the header word
126  * (n_long.n_sign_dscale) represent the display scale; and the weight is
127  * stored separately in n_weight.
128  *
129  * NOTE: by convention, values in the packed form have been stripped of
130  * all leading and trailing zero digits (where a "digit" is of base NBASE).
131  * In particular, if the value is zero, there will be no digits at all!
132  * The weight is arbitrary in that case, but we normally set it to zero.
133  */
134 
135 struct NumericShort
136 {
137 	uint16		n_header;		/* Sign + display scale + weight */
138 	NumericDigit n_data[FLEXIBLE_ARRAY_MEMBER]; /* Digits */
139 };
140 
141 struct NumericLong
142 {
143 	uint16		n_sign_dscale;	/* Sign + display scale */
144 	int16		n_weight;		/* Weight of 1st digit	*/
145 	NumericDigit n_data[FLEXIBLE_ARRAY_MEMBER]; /* Digits */
146 };
147 
148 union NumericChoice
149 {
150 	uint16		n_header;		/* Header word */
151 	struct NumericLong n_long;	/* Long form (4-byte header) */
152 	struct NumericShort n_short;	/* Short form (2-byte header) */
153 };
154 
155 struct NumericData
156 {
157 	int32		vl_len_;		/* varlena header (do not touch directly!) */
158 	union NumericChoice choice; /* choice of format */
159 };
160 
161 
/*
 * Interpretation of high bits.
 *
 * The top two bits of the first header word select one of four cases:
 * positive long format, negative long format, short format, or NaN.
 */

#define NUMERIC_SIGN_MASK	0xC000
#define NUMERIC_POS			0x0000
#define NUMERIC_NEG			0x4000
#define NUMERIC_SHORT		0x8000
#define NUMERIC_NAN			0xC000

/* The two flag bits of a packed numeric, still in their header position */
#define NUMERIC_FLAGBITS(n) ((n)->choice.n_header & NUMERIC_SIGN_MASK)
#define NUMERIC_IS_NAN(n)		(NUMERIC_FLAGBITS(n) == NUMERIC_NAN)
#define NUMERIC_IS_SHORT(n)		(NUMERIC_FLAGBITS(n) == NUMERIC_SHORT)

/* Total header sizes, varlena length word included */
#define NUMERIC_HDRSZ	(VARHDRSZ + sizeof(uint16) + sizeof(int16))
#define NUMERIC_HDRSZ_SHORT (VARHDRSZ + sizeof(uint16))

/*
 * If the flag bits are NUMERIC_SHORT or NUMERIC_NAN, we want the short header;
 * otherwise, we want the long one.  Both of those flag values have the high
 * bit set and neither long-format value does, so testing that single bit is
 * enough.
 */
#define NUMERIC_HEADER_IS_SHORT(n)	(((n)->choice.n_header & 0x8000) != 0)
#define NUMERIC_HEADER_SIZE(n) \
	(VARHDRSZ + sizeof(uint16) + \
	 (NUMERIC_HEADER_IS_SHORT(n) ? 0 : sizeof(int16)))

/*
 * Short format definitions.
 *
 * Below the two flag bits, the header word holds one sign bit (0x2000), a
 * 6-bit display scale (0x1F80) and a 7-bit two's-complement weight made of
 * a sign bit (0x0040) plus six magnitude bits (0x003F).
 */

#define NUMERIC_SHORT_SIGN_MASK			0x2000
#define NUMERIC_SHORT_DSCALE_MASK		0x1F80
#define NUMERIC_SHORT_DSCALE_SHIFT		7
#define NUMERIC_SHORT_DSCALE_MAX		\
	(NUMERIC_SHORT_DSCALE_MASK >> NUMERIC_SHORT_DSCALE_SHIFT)
#define NUMERIC_SHORT_WEIGHT_SIGN_MASK	0x0040
#define NUMERIC_SHORT_WEIGHT_MASK		0x003F
#define NUMERIC_SHORT_WEIGHT_MAX		NUMERIC_SHORT_WEIGHT_MASK
#define NUMERIC_SHORT_WEIGHT_MIN		(-(NUMERIC_SHORT_WEIGHT_MASK+1))

/*
 * Extract sign, display scale, weight.
 *
 * In the long format the display scale is the low 14 bits of the first
 * header word.
 */

#define NUMERIC_DSCALE_MASK			0x3FFF
#define NUMERIC_DSCALE_MAX			NUMERIC_DSCALE_MASK

/*
 * NUMERIC_SIGN yields NUMERIC_POS or NUMERIC_NEG for a short-format value;
 * for anything else it yields the raw flag bits, which is NUMERIC_NAN for a
 * NaN.
 */
#define NUMERIC_SIGN(n) \
	(NUMERIC_IS_SHORT(n) ? \
		(((n)->choice.n_short.n_header & NUMERIC_SHORT_SIGN_MASK) ? \
		NUMERIC_NEG : NUMERIC_POS) : NUMERIC_FLAGBITS(n))
#define NUMERIC_DSCALE(n)	(NUMERIC_HEADER_IS_SHORT((n)) ? \
	((n)->choice.n_short.n_header & NUMERIC_SHORT_DSCALE_MASK) \
		>> NUMERIC_SHORT_DSCALE_SHIFT \
	: ((n)->choice.n_long.n_sign_dscale & NUMERIC_DSCALE_MASK))
/*
 * For the short format the 7-bit weight is sign-extended by OR-ing in all
 * the bits above the magnitude field whenever the weight sign bit is set.
 */
#define NUMERIC_WEIGHT(n)	(NUMERIC_HEADER_IS_SHORT((n)) ? \
	(((n)->choice.n_short.n_header & NUMERIC_SHORT_WEIGHT_SIGN_MASK ? \
		~NUMERIC_SHORT_WEIGHT_MASK : 0) \
	 | ((n)->choice.n_short.n_header & NUMERIC_SHORT_WEIGHT_MASK)) \
	: ((n)->choice.n_long.n_weight))
223 
224 /* ----------
225  * NumericVar is the format we use for arithmetic.  The digit-array part
226  * is the same as the NumericData storage format, but the header is more
227  * complex.
228  *
229  * The value represented by a NumericVar is determined by the sign, weight,
230  * ndigits, and digits[] array.
231  *
232  * Note: the first digit of a NumericVar's value is assumed to be multiplied
233  * by NBASE ** weight.  Another way to say it is that there are weight+1
234  * digits before the decimal point.  It is possible to have weight < 0.
235  *
236  * buf points at the physical start of the palloc'd digit buffer for the
237  * NumericVar.  digits points at the first digit in actual use (the one
238  * with the specified weight).  We normally leave an unused digit or two
239  * (preset to zeroes) between buf and digits, so that there is room to store
240  * a carry out of the top digit without reallocating space.  We just need to
241  * decrement digits (and increment weight) to make room for the carry digit.
242  * (There is no such extra space in a numeric value stored in the database,
243  * only in a NumericVar in memory.)
244  *
245  * If buf is NULL then the digit buffer isn't actually palloc'd and should
246  * not be freed --- see the constants below for an example.
247  *
248  * dscale, or display scale, is the nominal precision expressed as number
249  * of digits after the decimal point (it must always be >= 0 at present).
250  * dscale may be more than the number of physically stored fractional digits,
251  * implying that we have suppressed storage of significant trailing zeroes.
252  * It should never be less than the number of stored digits, since that would
253  * imply hiding digits that are present.  NOTE that dscale is always expressed
254  * in *decimal* digits, and so it may correspond to a fractional number of
255  * base-NBASE digits --- divide by DEC_DIGITS to convert to NBASE digits.
256  *
257  * rscale, or result scale, is the target precision for a computation.
258  * Like dscale it is expressed as number of *decimal* digits after the decimal
259  * point, and is always >= 0 at present.
260  * Note that rscale is not stored in variables --- it's figured on-the-fly
261  * from the dscales of the inputs.
262  *
263  * While we consistently use "weight" to refer to the base-NBASE weight of
264  * a numeric value, it is convenient in some scale-related calculations to
265  * make use of the base-10 weight (ie, the approximate log10 of the value).
266  * To avoid confusion, such a decimal-units weight is called a "dweight".
267  *
268  * NB: All the variable-level functions are written in a style that makes it
269  * possible to give one and the same variable as argument and destination.
270  * This is feasible because the digit buffer is separate from the variable.
271  * ----------
272  */
273 typedef struct NumericVar
274 {
275 	int			ndigits;		/* # of digits in digits[] - can be 0! */
276 	int			weight;			/* weight of first digit */
277 	int			sign;			/* NUMERIC_POS, NUMERIC_NEG, or NUMERIC_NAN */
278 	int			dscale;			/* display scale */
279 	NumericDigit *buf;			/* start of palloc'd space for digits[] */
280 	NumericDigit *digits;		/* base-NBASE digits */
281 } NumericVar;
282 
283 
284 /* ----------
285  * Data for generate_series
286  * ----------
287  */
288 typedef struct
289 {
290 	NumericVar	current;
291 	NumericVar	stop;
292 	NumericVar	step;
293 } generate_series_numeric_fctx;
294 
295 
296 /* ----------
297  * Sort support.
298  * ----------
299  */
300 typedef struct
301 {
302 	void	   *buf;			/* buffer for short varlenas */
303 	int64		input_count;	/* number of non-null values seen */
304 	bool		estimating;		/* true if estimating cardinality */
305 
306 	hyperLogLogState abbr_card; /* cardinality estimator */
307 } NumericSortSupport;
308 
309 
310 /* ----------
311  * Fast sum accumulator.
312  *
313  * NumericSumAccum is used to implement SUM(), and other standard aggregates
314  * that track the sum of input values.  It uses 32-bit integers to store the
315  * digits, instead of the normal 16-bit integers (with NBASE=10000).  This
316  * way, we can safely accumulate up to NBASE - 1 values without propagating
317  * carry, before risking overflow of any of the digits.  'num_uncarried'
318  * tracks how many values have been accumulated without propagating carry.
319  *
320  * Positive and negative values are accumulated separately, in 'pos_digits'
321  * and 'neg_digits'.  This is simpler and faster than deciding whether to add
322  * or subtract from the current value, for each new value (see sub_var() for
323  * the logic we avoid by doing this).  Both buffers are of same size, and
324  * have the same weight and scale.  In accum_sum_final(), the positive and
325  * negative sums are added together to produce the final result.
326  *
327  * When a new value has a larger ndigits or weight than the accumulator
328  * currently does, the accumulator is enlarged to accommodate the new value.
329  * We normally have one zero digit reserved for carry propagation, and that
330  * is indicated by the 'have_carry_space' flag.  When accum_sum_carry() uses
331  * up the reserved digit, it clears the 'have_carry_space' flag.  The next
332  * call to accum_sum_add() will enlarge the buffer, to make room for the
333  * extra digit, and set the flag again.
334  *
335  * To initialize a new accumulator, simply reset all fields to zeros.
336  *
337  * The accumulator does not handle NaNs.
338  * ----------
339  */
340 typedef struct NumericSumAccum
341 {
342 	int			ndigits;
343 	int			weight;
344 	int			dscale;
345 	int			num_uncarried;
346 	bool		have_carry_space;
347 	int32	   *pos_digits;
348 	int32	   *neg_digits;
349 } NumericSumAccum;
350 
351 
/*
 * We define our own macros for packing and unpacking abbreviated-key
 * representations for numeric values in order to avoid depending on
 * USE_FLOAT8_BYVAL.  The type of abbreviation we use is based only on
 * the size of a datum, not the argument-passing convention for float8.
 *
 * With an 8-byte Datum the abbreviation is read back as an int64 and NaN is
 * encoded as PG_INT64_MIN; otherwise it is an int32 and NaN is PG_INT32_MIN.
 */
#define NUMERIC_ABBREV_BITS (SIZEOF_DATUM * BITS_PER_BYTE)
#if SIZEOF_DATUM == 8
#define NumericAbbrevGetDatum(X) ((Datum) (X))
#define DatumGetNumericAbbrev(X) ((int64) (X))
#define NUMERIC_ABBREV_NAN		 NumericAbbrevGetDatum(PG_INT64_MIN)
#else
#define NumericAbbrevGetDatum(X) ((Datum) (X))
#define DatumGetNumericAbbrev(X) ((int32) (X))
#define NUMERIC_ABBREV_NAN		 NumericAbbrevGetDatum(PG_INT32_MIN)
#endif
368 
369 
370 /* ----------
371  * Some preinitialized constants
372  * ----------
373  */
374 static const NumericDigit const_zero_data[1] = {0};
375 static const NumericVar const_zero =
376 {0, 0, NUMERIC_POS, 0, NULL, (NumericDigit *) const_zero_data};
377 
378 static const NumericDigit const_one_data[1] = {1};
379 static const NumericVar const_one =
380 {1, 0, NUMERIC_POS, 0, NULL, (NumericDigit *) const_one_data};
381 
382 static const NumericDigit const_two_data[1] = {2};
383 static const NumericVar const_two =
384 {1, 0, NUMERIC_POS, 0, NULL, (NumericDigit *) const_two_data};
385 
386 #if DEC_DIGITS == 4
387 static const NumericDigit const_zero_point_nine_data[1] = {9000};
388 #elif DEC_DIGITS == 2
389 static const NumericDigit const_zero_point_nine_data[1] = {90};
390 #elif DEC_DIGITS == 1
391 static const NumericDigit const_zero_point_nine_data[1] = {9};
392 #endif
393 static const NumericVar const_zero_point_nine =
394 {1, -1, NUMERIC_POS, 1, NULL, (NumericDigit *) const_zero_point_nine_data};
395 
396 #if DEC_DIGITS == 4
397 static const NumericDigit const_one_point_one_data[2] = {1, 1000};
398 #elif DEC_DIGITS == 2
399 static const NumericDigit const_one_point_one_data[2] = {1, 10};
400 #elif DEC_DIGITS == 1
401 static const NumericDigit const_one_point_one_data[2] = {1, 1};
402 #endif
403 static const NumericVar const_one_point_one =
404 {2, 0, NUMERIC_POS, 1, NULL, (NumericDigit *) const_one_point_one_data};
405 
406 static const NumericVar const_nan =
407 {0, 0, NUMERIC_NAN, 0, NULL, NULL};
408 
409 #if DEC_DIGITS == 4
410 static const int round_powers[4] = {0, 1000, 100, 10};
411 #endif
412 
413 
414 /* ----------
415  * Local functions
416  * ----------
417  */
418 
/*
 * The debugging dumpers are only declared when NUMERIC_DEBUG is defined;
 * otherwise calls to them expand to nothing.
 */
#ifdef NUMERIC_DEBUG
static void dump_numeric(const char *str, Numeric num);
static void dump_var(const char *str, NumericVar *var);
#else
#define dump_numeric(s,n)
#define dump_var(s,v)
#endif

/* palloc room for ndigits NumericDigits (the product is not range-checked) */
#define digitbuf_alloc(ndigits)  \
	((NumericDigit *) palloc((ndigits) * sizeof(NumericDigit)))
/* pfree a digit buffer, tolerating NULL */
#define digitbuf_free(buf)	\
	do { \
		 if ((buf) != NULL) \
			 pfree(buf); \
	} while (0)

/* Reset every field of a NumericVar to zero/NULL */
#define init_var(v)		MemSetAligned(v, 0, sizeof(NumericVar))

/* First digit of a packed numeric, whichever header form it uses */
#define NUMERIC_DIGITS(num) (NUMERIC_HEADER_IS_SHORT(num) ? \
	(num)->choice.n_short.n_data : (num)->choice.n_long.n_data)
/* Digit count, derived from the varlena size minus the header size */
#define NUMERIC_NDIGITS(num) \
	((VARSIZE(num) - NUMERIC_HEADER_SIZE(num)) / sizeof(NumericDigit))
/* True when both scale and weight fit in the short-format header fields */
#define NUMERIC_CAN_BE_SHORT(scale,weight) \
	((scale) <= NUMERIC_SHORT_DSCALE_MAX && \
	(weight) <= NUMERIC_SHORT_WEIGHT_MAX && \
	(weight) >= NUMERIC_SHORT_WEIGHT_MIN)
445 
446 static void alloc_var(NumericVar *var, int ndigits);
447 static void free_var(NumericVar *var);
448 static void zero_var(NumericVar *var);
449 
450 static const char *set_var_from_str(const char *str, const char *cp,
451 									NumericVar *dest);
452 static void set_var_from_num(Numeric value, NumericVar *dest);
453 static void init_var_from_num(Numeric num, NumericVar *dest);
454 static void set_var_from_var(const NumericVar *value, NumericVar *dest);
455 static char *get_str_from_var(const NumericVar *var);
456 static char *get_str_from_var_sci(const NumericVar *var, int rscale);
457 
458 static Numeric make_result(const NumericVar *var);
459 static Numeric make_result_opt_error(const NumericVar *var, bool *error);
460 
461 static void apply_typmod(NumericVar *var, int32 typmod);
462 
463 static bool numericvar_to_int32(const NumericVar *var, int32 *result);
464 static bool numericvar_to_int64(const NumericVar *var, int64 *result);
465 static void int64_to_numericvar(int64 val, NumericVar *var);
466 #ifdef HAVE_INT128
467 static bool numericvar_to_int128(const NumericVar *var, int128 *result);
468 static void int128_to_numericvar(int128 val, NumericVar *var);
469 #endif
470 static double numeric_to_double_no_overflow(Numeric num);
471 static double numericvar_to_double_no_overflow(const NumericVar *var);
472 
473 static Datum numeric_abbrev_convert(Datum original_datum, SortSupport ssup);
474 static bool numeric_abbrev_abort(int memtupcount, SortSupport ssup);
475 static int	numeric_fast_cmp(Datum x, Datum y, SortSupport ssup);
476 static int	numeric_cmp_abbrev(Datum x, Datum y, SortSupport ssup);
477 
478 static Datum numeric_abbrev_convert_var(const NumericVar *var,
479 										NumericSortSupport *nss);
480 
481 static int	cmp_numerics(Numeric num1, Numeric num2);
482 static int	cmp_var(const NumericVar *var1, const NumericVar *var2);
483 static int	cmp_var_common(const NumericDigit *var1digits, int var1ndigits,
484 						   int var1weight, int var1sign,
485 						   const NumericDigit *var2digits, int var2ndigits,
486 						   int var2weight, int var2sign);
487 static void add_var(const NumericVar *var1, const NumericVar *var2,
488 					NumericVar *result);
489 static void sub_var(const NumericVar *var1, const NumericVar *var2,
490 					NumericVar *result);
491 static void mul_var(const NumericVar *var1, const NumericVar *var2,
492 					NumericVar *result,
493 					int rscale);
494 static void div_var(const NumericVar *var1, const NumericVar *var2,
495 					NumericVar *result,
496 					int rscale, bool round);
497 static void div_var_fast(const NumericVar *var1, const NumericVar *var2,
498 						 NumericVar *result, int rscale, bool round);
499 static int	select_div_scale(const NumericVar *var1, const NumericVar *var2);
500 static void mod_var(const NumericVar *var1, const NumericVar *var2,
501 					NumericVar *result);
502 static void div_mod_var(const NumericVar *var1, const NumericVar *var2,
503 						NumericVar *quot, NumericVar *rem);
504 static void ceil_var(const NumericVar *var, NumericVar *result);
505 static void floor_var(const NumericVar *var, NumericVar *result);
506 
507 static void gcd_var(const NumericVar *var1, const NumericVar *var2,
508 					NumericVar *result);
509 static void sqrt_var(const NumericVar *arg, NumericVar *result, int rscale);
510 static void exp_var(const NumericVar *arg, NumericVar *result, int rscale);
511 static int	estimate_ln_dweight(const NumericVar *var);
512 static void ln_var(const NumericVar *arg, NumericVar *result, int rscale);
513 static void log_var(const NumericVar *base, const NumericVar *num,
514 					NumericVar *result);
515 static void power_var(const NumericVar *base, const NumericVar *exp,
516 					  NumericVar *result);
517 static void power_var_int(const NumericVar *base, int exp, NumericVar *result,
518 						  int rscale);
519 static void power_ten_int(int exp, NumericVar *result);
520 
521 static int	cmp_abs(const NumericVar *var1, const NumericVar *var2);
522 static int	cmp_abs_common(const NumericDigit *var1digits, int var1ndigits,
523 						   int var1weight,
524 						   const NumericDigit *var2digits, int var2ndigits,
525 						   int var2weight);
526 static void add_abs(const NumericVar *var1, const NumericVar *var2,
527 					NumericVar *result);
528 static void sub_abs(const NumericVar *var1, const NumericVar *var2,
529 					NumericVar *result);
530 static void round_var(NumericVar *var, int rscale);
531 static void trunc_var(NumericVar *var, int rscale);
532 static void strip_var(NumericVar *var);
533 static void compute_bucket(Numeric operand, Numeric bound1, Numeric bound2,
534 						   const NumericVar *count_var, NumericVar *result_var);
535 
536 static void accum_sum_add(NumericSumAccum *accum, const NumericVar *var1);
537 static void accum_sum_rescale(NumericSumAccum *accum, const NumericVar *val);
538 static void accum_sum_carry(NumericSumAccum *accum);
539 static void accum_sum_reset(NumericSumAccum *accum);
540 static void accum_sum_final(NumericSumAccum *accum, NumericVar *result);
541 static void accum_sum_copy(NumericSumAccum *dst, NumericSumAccum *src);
542 static void accum_sum_combine(NumericSumAccum *accum, NumericSumAccum *accum2);
543 
544 
545 /* ----------------------------------------------------------------------
546  *
547  * Input-, output- and rounding-functions
548  *
549  * ----------------------------------------------------------------------
550  */
551 
552 
553 /*
554  * numeric_in() -
555  *
556  *	Input function for numeric data type
557  */
558 Datum
numeric_in(PG_FUNCTION_ARGS)559 numeric_in(PG_FUNCTION_ARGS)
560 {
561 	char	   *str = PG_GETARG_CSTRING(0);
562 
563 #ifdef NOT_USED
564 	Oid			typelem = PG_GETARG_OID(1);
565 #endif
566 	int32		typmod = PG_GETARG_INT32(2);
567 	Numeric		res;
568 	const char *cp;
569 
570 	/* Skip leading spaces */
571 	cp = str;
572 	while (*cp)
573 	{
574 		if (!isspace((unsigned char) *cp))
575 			break;
576 		cp++;
577 	}
578 
579 	/*
580 	 * Check for NaN
581 	 */
582 	if (pg_strncasecmp(cp, "NaN", 3) == 0)
583 	{
584 		res = make_result(&const_nan);
585 
586 		/* Should be nothing left but spaces */
587 		cp += 3;
588 		while (*cp)
589 		{
590 			if (!isspace((unsigned char) *cp))
591 				ereport(ERROR,
592 						(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
593 						 errmsg("invalid input syntax for type %s: \"%s\"",
594 								"numeric", str)));
595 			cp++;
596 		}
597 	}
598 	else
599 	{
600 		/*
601 		 * Use set_var_from_str() to parse a normal numeric value
602 		 */
603 		NumericVar	value;
604 
605 		init_var(&value);
606 
607 		cp = set_var_from_str(str, cp, &value);
608 
609 		/*
610 		 * We duplicate a few lines of code here because we would like to
611 		 * throw any trailing-junk syntax error before any semantic error
612 		 * resulting from apply_typmod.  We can't easily fold the two cases
613 		 * together because we mustn't apply apply_typmod to a NaN.
614 		 */
615 		while (*cp)
616 		{
617 			if (!isspace((unsigned char) *cp))
618 				ereport(ERROR,
619 						(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
620 						 errmsg("invalid input syntax for type %s: \"%s\"",
621 								"numeric", str)));
622 			cp++;
623 		}
624 
625 		apply_typmod(&value, typmod);
626 
627 		res = make_result(&value);
628 		free_var(&value);
629 	}
630 
631 	PG_RETURN_NUMERIC(res);
632 }
633 
634 
635 /*
636  * numeric_out() -
637  *
638  *	Output function for numeric data type
639  */
640 Datum
numeric_out(PG_FUNCTION_ARGS)641 numeric_out(PG_FUNCTION_ARGS)
642 {
643 	Numeric		num = PG_GETARG_NUMERIC(0);
644 	NumericVar	x;
645 	char	   *str;
646 
647 	/*
648 	 * Handle NaN
649 	 */
650 	if (NUMERIC_IS_NAN(num))
651 		PG_RETURN_CSTRING(pstrdup("NaN"));
652 
653 	/*
654 	 * Get the number in the variable format.
655 	 */
656 	init_var_from_num(num, &x);
657 
658 	str = get_str_from_var(&x);
659 
660 	PG_RETURN_CSTRING(str);
661 }
662 
663 /*
664  * numeric_is_nan() -
665  *
666  *	Is Numeric value a NaN?
667  */
668 bool
numeric_is_nan(Numeric num)669 numeric_is_nan(Numeric num)
670 {
671 	return NUMERIC_IS_NAN(num);
672 }
673 
674 /*
675  * numeric_maximum_size() -
676  *
677  *	Maximum size of a numeric with given typmod, or -1 if unlimited/unknown.
678  */
679 int32
numeric_maximum_size(int32 typmod)680 numeric_maximum_size(int32 typmod)
681 {
682 	int			precision;
683 	int			numeric_digits;
684 
685 	if (typmod < (int32) (VARHDRSZ))
686 		return -1;
687 
688 	/* precision (ie, max # of digits) is in upper bits of typmod */
689 	precision = ((typmod - VARHDRSZ) >> 16) & 0xffff;
690 
691 	/*
692 	 * This formula computes the maximum number of NumericDigits we could need
693 	 * in order to store the specified number of decimal digits. Because the
694 	 * weight is stored as a number of NumericDigits rather than a number of
695 	 * decimal digits, it's possible that the first NumericDigit will contain
696 	 * only a single decimal digit.  Thus, the first two decimal digits can
697 	 * require two NumericDigits to store, but it isn't until we reach
698 	 * DEC_DIGITS + 2 decimal digits that we potentially need a third
699 	 * NumericDigit.
700 	 */
701 	numeric_digits = (precision + 2 * (DEC_DIGITS - 1)) / DEC_DIGITS;
702 
703 	/*
704 	 * In most cases, the size of a numeric will be smaller than the value
705 	 * computed below, because the varlena header will typically get toasted
706 	 * down to a single byte before being stored on disk, and it may also be
707 	 * possible to use a short numeric header.  But our job here is to compute
708 	 * the worst case.
709 	 */
710 	return NUMERIC_HDRSZ + (numeric_digits * sizeof(NumericDigit));
711 }
712 
713 /*
714  * numeric_out_sci() -
715  *
716  *	Output function for numeric data type in scientific notation.
717  */
718 char *
numeric_out_sci(Numeric num,int scale)719 numeric_out_sci(Numeric num, int scale)
720 {
721 	NumericVar	x;
722 	char	   *str;
723 
724 	/*
725 	 * Handle NaN
726 	 */
727 	if (NUMERIC_IS_NAN(num))
728 		return pstrdup("NaN");
729 
730 	init_var_from_num(num, &x);
731 
732 	str = get_str_from_var_sci(&x, scale);
733 
734 	return str;
735 }
736 
737 /*
738  * numeric_normalize() -
739  *
740  *	Output function for numeric data type, suppressing insignificant trailing
741  *	zeroes and then any trailing decimal point.  The intent of this is to
742  *	produce strings that are equal if and only if the input numeric values
743  *	compare equal.
744  */
745 char *
numeric_normalize(Numeric num)746 numeric_normalize(Numeric num)
747 {
748 	NumericVar	x;
749 	char	   *str;
750 	int			last;
751 
752 	/*
753 	 * Handle NaN
754 	 */
755 	if (NUMERIC_IS_NAN(num))
756 		return pstrdup("NaN");
757 
758 	init_var_from_num(num, &x);
759 
760 	str = get_str_from_var(&x);
761 
762 	/* If there's no decimal point, there's certainly nothing to remove. */
763 	if (strchr(str, '.') != NULL)
764 	{
765 		/*
766 		 * Back up over trailing fractional zeroes.  Since there is a decimal
767 		 * point, this loop will terminate safely.
768 		 */
769 		last = strlen(str) - 1;
770 		while (str[last] == '0')
771 			last--;
772 
773 		/* We want to get rid of the decimal point too, if it's now last. */
774 		if (str[last] == '.')
775 			last--;
776 
777 		/* Delete whatever we backed up over. */
778 		str[last + 1] = '\0';
779 	}
780 
781 	return str;
782 }
783 
784 /*
785  *		numeric_recv			- converts external binary format to numeric
786  *
787  * External format is a sequence of int16's:
788  * ndigits, weight, sign, dscale, NumericDigits.
789  */
790 Datum
numeric_recv(PG_FUNCTION_ARGS)791 numeric_recv(PG_FUNCTION_ARGS)
792 {
793 	StringInfo	buf = (StringInfo) PG_GETARG_POINTER(0);
794 
795 #ifdef NOT_USED
796 	Oid			typelem = PG_GETARG_OID(1);
797 #endif
798 	int32		typmod = PG_GETARG_INT32(2);
799 	NumericVar	value;
800 	Numeric		res;
801 	int			len,
802 				i;
803 
804 	init_var(&value);
805 
806 	len = (uint16) pq_getmsgint(buf, sizeof(uint16));
807 
808 	alloc_var(&value, len);
809 
810 	value.weight = (int16) pq_getmsgint(buf, sizeof(int16));
811 	/* we allow any int16 for weight --- OK? */
812 
813 	value.sign = (uint16) pq_getmsgint(buf, sizeof(uint16));
814 	if (!(value.sign == NUMERIC_POS ||
815 		  value.sign == NUMERIC_NEG ||
816 		  value.sign == NUMERIC_NAN))
817 		ereport(ERROR,
818 				(errcode(ERRCODE_INVALID_BINARY_REPRESENTATION),
819 				 errmsg("invalid sign in external \"numeric\" value")));
820 
821 	value.dscale = (uint16) pq_getmsgint(buf, sizeof(uint16));
822 	if ((value.dscale & NUMERIC_DSCALE_MASK) != value.dscale)
823 		ereport(ERROR,
824 				(errcode(ERRCODE_INVALID_BINARY_REPRESENTATION),
825 				 errmsg("invalid scale in external \"numeric\" value")));
826 
827 	for (i = 0; i < len; i++)
828 	{
829 		NumericDigit d = pq_getmsgint(buf, sizeof(NumericDigit));
830 
831 		if (d < 0 || d >= NBASE)
832 			ereport(ERROR,
833 					(errcode(ERRCODE_INVALID_BINARY_REPRESENTATION),
834 					 errmsg("invalid digit in external \"numeric\" value")));
835 		value.digits[i] = d;
836 	}
837 
838 	/*
839 	 * If the given dscale would hide any digits, truncate those digits away.
840 	 * We could alternatively throw an error, but that would take a bunch of
841 	 * extra code (about as much as trunc_var involves), and it might cause
842 	 * client compatibility issues.
843 	 */
844 	trunc_var(&value, value.dscale);
845 
846 	apply_typmod(&value, typmod);
847 
848 	res = make_result(&value);
849 	free_var(&value);
850 
851 	PG_RETURN_NUMERIC(res);
852 }
853 
854 /*
855  *		numeric_send			- converts numeric to binary format
856  */
857 Datum
numeric_send(PG_FUNCTION_ARGS)858 numeric_send(PG_FUNCTION_ARGS)
859 {
860 	Numeric		num = PG_GETARG_NUMERIC(0);
861 	NumericVar	x;
862 	StringInfoData buf;
863 	int			i;
864 
865 	init_var_from_num(num, &x);
866 
867 	pq_begintypsend(&buf);
868 
869 	pq_sendint16(&buf, x.ndigits);
870 	pq_sendint16(&buf, x.weight);
871 	pq_sendint16(&buf, x.sign);
872 	pq_sendint16(&buf, x.dscale);
873 	for (i = 0; i < x.ndigits; i++)
874 		pq_sendint16(&buf, x.digits[i]);
875 
876 	PG_RETURN_BYTEA_P(pq_endtypsend(&buf));
877 }
878 
879 
880 /*
881  * numeric_support()
882  *
883  * Planner support function for the numeric() length coercion function.
884  *
885  * Flatten calls that solely represent increases in allowable precision.
886  * Scale changes mutate every datum, so they are unoptimizable.  Some values,
887  * e.g. 1E-1001, can only fit into an unconstrained numeric, so a change from
888  * an unconstrained numeric to any constrained numeric is also unoptimizable.
889  */
890 Datum
numeric_support(PG_FUNCTION_ARGS)891 numeric_support(PG_FUNCTION_ARGS)
892 {
893 	Node	   *rawreq = (Node *) PG_GETARG_POINTER(0);
894 	Node	   *ret = NULL;
895 
896 	if (IsA(rawreq, SupportRequestSimplify))
897 	{
898 		SupportRequestSimplify *req = (SupportRequestSimplify *) rawreq;
899 		FuncExpr   *expr = req->fcall;
900 		Node	   *typmod;
901 
902 		Assert(list_length(expr->args) >= 2);
903 
904 		typmod = (Node *) lsecond(expr->args);
905 
906 		if (IsA(typmod, Const) && !((Const *) typmod)->constisnull)
907 		{
908 			Node	   *source = (Node *) linitial(expr->args);
909 			int32		old_typmod = exprTypmod(source);
910 			int32		new_typmod = DatumGetInt32(((Const *) typmod)->constvalue);
911 			int32		old_scale = (old_typmod - VARHDRSZ) & 0xffff;
912 			int32		new_scale = (new_typmod - VARHDRSZ) & 0xffff;
913 			int32		old_precision = (old_typmod - VARHDRSZ) >> 16 & 0xffff;
914 			int32		new_precision = (new_typmod - VARHDRSZ) >> 16 & 0xffff;
915 
916 			/*
917 			 * If new_typmod < VARHDRSZ, the destination is unconstrained;
918 			 * that's always OK.  If old_typmod >= VARHDRSZ, the source is
919 			 * constrained, and we're OK if the scale is unchanged and the
920 			 * precision is not decreasing.  See further notes in function
921 			 * header comment.
922 			 */
923 			if (new_typmod < (int32) VARHDRSZ ||
924 				(old_typmod >= (int32) VARHDRSZ &&
925 				 new_scale == old_scale && new_precision >= old_precision))
926 				ret = relabel_to_typmod(source, new_typmod);
927 		}
928 	}
929 
930 	PG_RETURN_POINTER(ret);
931 }
932 
933 /*
934  * numeric() -
935  *
936  *	This is a special function called by the Postgres database system
937  *	before a value is stored in a tuple's attribute. The precision and
938  *	scale of the attribute have to be applied on the value.
939  */
940 Datum
numeric(PG_FUNCTION_ARGS)941 numeric		(PG_FUNCTION_ARGS)
942 {
943 	Numeric		num = PG_GETARG_NUMERIC(0);
944 	int32		typmod = PG_GETARG_INT32(1);
945 	Numeric		new;
946 	int32		tmp_typmod;
947 	int			precision;
948 	int			scale;
949 	int			ddigits;
950 	int			maxdigits;
951 	NumericVar	var;
952 
953 	/*
954 	 * Handle NaN
955 	 */
956 	if (NUMERIC_IS_NAN(num))
957 		PG_RETURN_NUMERIC(make_result(&const_nan));
958 
959 	/*
960 	 * If the value isn't a valid type modifier, simply return a copy of the
961 	 * input value
962 	 */
963 	if (typmod < (int32) (VARHDRSZ))
964 	{
965 		new = (Numeric) palloc(VARSIZE(num));
966 		memcpy(new, num, VARSIZE(num));
967 		PG_RETURN_NUMERIC(new);
968 	}
969 
970 	/*
971 	 * Get the precision and scale out of the typmod value
972 	 */
973 	tmp_typmod = typmod - VARHDRSZ;
974 	precision = (tmp_typmod >> 16) & 0xffff;
975 	scale = tmp_typmod & 0xffff;
976 	maxdigits = precision - scale;
977 
978 	/*
979 	 * If the number is certainly in bounds and due to the target scale no
980 	 * rounding could be necessary, just make a copy of the input and modify
981 	 * its scale fields, unless the larger scale forces us to abandon the
982 	 * short representation.  (Note we assume the existing dscale is
983 	 * honest...)
984 	 */
985 	ddigits = (NUMERIC_WEIGHT(num) + 1) * DEC_DIGITS;
986 	if (ddigits <= maxdigits && scale >= NUMERIC_DSCALE(num)
987 		&& (NUMERIC_CAN_BE_SHORT(scale, NUMERIC_WEIGHT(num))
988 			|| !NUMERIC_IS_SHORT(num)))
989 	{
990 		new = (Numeric) palloc(VARSIZE(num));
991 		memcpy(new, num, VARSIZE(num));
992 		if (NUMERIC_IS_SHORT(num))
993 			new->choice.n_short.n_header =
994 				(num->choice.n_short.n_header & ~NUMERIC_SHORT_DSCALE_MASK)
995 				| (scale << NUMERIC_SHORT_DSCALE_SHIFT);
996 		else
997 			new->choice.n_long.n_sign_dscale = NUMERIC_SIGN(new) |
998 				((uint16) scale & NUMERIC_DSCALE_MASK);
999 		PG_RETURN_NUMERIC(new);
1000 	}
1001 
1002 	/*
1003 	 * We really need to fiddle with things - unpack the number into a
1004 	 * variable and let apply_typmod() do it.
1005 	 */
1006 	init_var(&var);
1007 
1008 	set_var_from_num(num, &var);
1009 	apply_typmod(&var, typmod);
1010 	new = make_result(&var);
1011 
1012 	free_var(&var);
1013 
1014 	PG_RETURN_NUMERIC(new);
1015 }
1016 
1017 Datum
numerictypmodin(PG_FUNCTION_ARGS)1018 numerictypmodin(PG_FUNCTION_ARGS)
1019 {
1020 	ArrayType  *ta = PG_GETARG_ARRAYTYPE_P(0);
1021 	int32	   *tl;
1022 	int			n;
1023 	int32		typmod;
1024 
1025 	tl = ArrayGetIntegerTypmods(ta, &n);
1026 
1027 	if (n == 2)
1028 	{
1029 		if (tl[0] < 1 || tl[0] > NUMERIC_MAX_PRECISION)
1030 			ereport(ERROR,
1031 					(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
1032 					 errmsg("NUMERIC precision %d must be between 1 and %d",
1033 							tl[0], NUMERIC_MAX_PRECISION)));
1034 		if (tl[1] < 0 || tl[1] > tl[0])
1035 			ereport(ERROR,
1036 					(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
1037 					 errmsg("NUMERIC scale %d must be between 0 and precision %d",
1038 							tl[1], tl[0])));
1039 		typmod = ((tl[0] << 16) | tl[1]) + VARHDRSZ;
1040 	}
1041 	else if (n == 1)
1042 	{
1043 		if (tl[0] < 1 || tl[0] > NUMERIC_MAX_PRECISION)
1044 			ereport(ERROR,
1045 					(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
1046 					 errmsg("NUMERIC precision %d must be between 1 and %d",
1047 							tl[0], NUMERIC_MAX_PRECISION)));
1048 		/* scale defaults to zero */
1049 		typmod = (tl[0] << 16) + VARHDRSZ;
1050 	}
1051 	else
1052 	{
1053 		ereport(ERROR,
1054 				(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
1055 				 errmsg("invalid NUMERIC type modifier")));
1056 		typmod = 0;				/* keep compiler quiet */
1057 	}
1058 
1059 	PG_RETURN_INT32(typmod);
1060 }
1061 
1062 Datum
numerictypmodout(PG_FUNCTION_ARGS)1063 numerictypmodout(PG_FUNCTION_ARGS)
1064 {
1065 	int32		typmod = PG_GETARG_INT32(0);
1066 	char	   *res = (char *) palloc(64);
1067 
1068 	if (typmod >= 0)
1069 		snprintf(res, 64, "(%d,%d)",
1070 				 ((typmod - VARHDRSZ) >> 16) & 0xffff,
1071 				 (typmod - VARHDRSZ) & 0xffff);
1072 	else
1073 		*res = '\0';
1074 
1075 	PG_RETURN_CSTRING(res);
1076 }
1077 
1078 
1079 /* ----------------------------------------------------------------------
1080  *
1081  * Sign manipulation, rounding and the like
1082  *
1083  * ----------------------------------------------------------------------
1084  */
1085 
1086 Datum
numeric_abs(PG_FUNCTION_ARGS)1087 numeric_abs(PG_FUNCTION_ARGS)
1088 {
1089 	Numeric		num = PG_GETARG_NUMERIC(0);
1090 	Numeric		res;
1091 
1092 	/*
1093 	 * Handle NaN
1094 	 */
1095 	if (NUMERIC_IS_NAN(num))
1096 		PG_RETURN_NUMERIC(make_result(&const_nan));
1097 
1098 	/*
1099 	 * Do it the easy way directly on the packed format
1100 	 */
1101 	res = (Numeric) palloc(VARSIZE(num));
1102 	memcpy(res, num, VARSIZE(num));
1103 
1104 	if (NUMERIC_IS_SHORT(num))
1105 		res->choice.n_short.n_header =
1106 			num->choice.n_short.n_header & ~NUMERIC_SHORT_SIGN_MASK;
1107 	else
1108 		res->choice.n_long.n_sign_dscale = NUMERIC_POS | NUMERIC_DSCALE(num);
1109 
1110 	PG_RETURN_NUMERIC(res);
1111 }
1112 
1113 
1114 Datum
numeric_uminus(PG_FUNCTION_ARGS)1115 numeric_uminus(PG_FUNCTION_ARGS)
1116 {
1117 	Numeric		num = PG_GETARG_NUMERIC(0);
1118 	Numeric		res;
1119 
1120 	/*
1121 	 * Handle NaN
1122 	 */
1123 	if (NUMERIC_IS_NAN(num))
1124 		PG_RETURN_NUMERIC(make_result(&const_nan));
1125 
1126 	/*
1127 	 * Do it the easy way directly on the packed format
1128 	 */
1129 	res = (Numeric) palloc(VARSIZE(num));
1130 	memcpy(res, num, VARSIZE(num));
1131 
1132 	/*
1133 	 * The packed format is known to be totally zero digit trimmed always. So
1134 	 * we can identify a ZERO by the fact that there are no digits at all.  Do
1135 	 * nothing to a zero.
1136 	 */
1137 	if (NUMERIC_NDIGITS(num) != 0)
1138 	{
1139 		/* Else, flip the sign */
1140 		if (NUMERIC_IS_SHORT(num))
1141 			res->choice.n_short.n_header =
1142 				num->choice.n_short.n_header ^ NUMERIC_SHORT_SIGN_MASK;
1143 		else if (NUMERIC_SIGN(num) == NUMERIC_POS)
1144 			res->choice.n_long.n_sign_dscale =
1145 				NUMERIC_NEG | NUMERIC_DSCALE(num);
1146 		else
1147 			res->choice.n_long.n_sign_dscale =
1148 				NUMERIC_POS | NUMERIC_DSCALE(num);
1149 	}
1150 
1151 	PG_RETURN_NUMERIC(res);
1152 }
1153 
1154 
1155 Datum
numeric_uplus(PG_FUNCTION_ARGS)1156 numeric_uplus(PG_FUNCTION_ARGS)
1157 {
1158 	Numeric		num = PG_GETARG_NUMERIC(0);
1159 	Numeric		res;
1160 
1161 	res = (Numeric) palloc(VARSIZE(num));
1162 	memcpy(res, num, VARSIZE(num));
1163 
1164 	PG_RETURN_NUMERIC(res);
1165 }
1166 
1167 /*
1168  * numeric_sign() -
1169  *
1170  * returns -1 if the argument is less than 0, 0 if the argument is equal
1171  * to 0, and 1 if the argument is greater than zero.
1172  */
1173 Datum
numeric_sign(PG_FUNCTION_ARGS)1174 numeric_sign(PG_FUNCTION_ARGS)
1175 {
1176 	Numeric		num = PG_GETARG_NUMERIC(0);
1177 	Numeric		res;
1178 	NumericVar	result;
1179 
1180 	/*
1181 	 * Handle NaN
1182 	 */
1183 	if (NUMERIC_IS_NAN(num))
1184 		PG_RETURN_NUMERIC(make_result(&const_nan));
1185 
1186 	init_var(&result);
1187 
1188 	/*
1189 	 * The packed format is known to be totally zero digit trimmed always. So
1190 	 * we can identify a ZERO by the fact that there are no digits at all.
1191 	 */
1192 	if (NUMERIC_NDIGITS(num) == 0)
1193 		set_var_from_var(&const_zero, &result);
1194 	else
1195 	{
1196 		/*
1197 		 * And if there are some, we return a copy of ONE with the sign of our
1198 		 * argument
1199 		 */
1200 		set_var_from_var(&const_one, &result);
1201 		result.sign = NUMERIC_SIGN(num);
1202 	}
1203 
1204 	res = make_result(&result);
1205 	free_var(&result);
1206 
1207 	PG_RETURN_NUMERIC(res);
1208 }
1209 
1210 
1211 /*
1212  * numeric_round() -
1213  *
1214  *	Round a value to have 'scale' digits after the decimal point.
1215  *	We allow negative 'scale', implying rounding before the decimal
1216  *	point --- Oracle interprets rounding that way.
1217  */
1218 Datum
numeric_round(PG_FUNCTION_ARGS)1219 numeric_round(PG_FUNCTION_ARGS)
1220 {
1221 	Numeric		num = PG_GETARG_NUMERIC(0);
1222 	int32		scale = PG_GETARG_INT32(1);
1223 	Numeric		res;
1224 	NumericVar	arg;
1225 
1226 	/*
1227 	 * Handle NaN
1228 	 */
1229 	if (NUMERIC_IS_NAN(num))
1230 		PG_RETURN_NUMERIC(make_result(&const_nan));
1231 
1232 	/*
1233 	 * Limit the scale value to avoid possible overflow in calculations
1234 	 */
1235 	scale = Max(scale, -NUMERIC_MAX_RESULT_SCALE);
1236 	scale = Min(scale, NUMERIC_MAX_RESULT_SCALE);
1237 
1238 	/*
1239 	 * Unpack the argument and round it at the proper digit position
1240 	 */
1241 	init_var(&arg);
1242 	set_var_from_num(num, &arg);
1243 
1244 	round_var(&arg, scale);
1245 
1246 	/* We don't allow negative output dscale */
1247 	if (scale < 0)
1248 		arg.dscale = 0;
1249 
1250 	/*
1251 	 * Return the rounded result
1252 	 */
1253 	res = make_result(&arg);
1254 
1255 	free_var(&arg);
1256 	PG_RETURN_NUMERIC(res);
1257 }
1258 
1259 
1260 /*
1261  * numeric_trunc() -
1262  *
1263  *	Truncate a value to have 'scale' digits after the decimal point.
1264  *	We allow negative 'scale', implying a truncation before the decimal
1265  *	point --- Oracle interprets truncation that way.
1266  */
1267 Datum
numeric_trunc(PG_FUNCTION_ARGS)1268 numeric_trunc(PG_FUNCTION_ARGS)
1269 {
1270 	Numeric		num = PG_GETARG_NUMERIC(0);
1271 	int32		scale = PG_GETARG_INT32(1);
1272 	Numeric		res;
1273 	NumericVar	arg;
1274 
1275 	/*
1276 	 * Handle NaN
1277 	 */
1278 	if (NUMERIC_IS_NAN(num))
1279 		PG_RETURN_NUMERIC(make_result(&const_nan));
1280 
1281 	/*
1282 	 * Limit the scale value to avoid possible overflow in calculations
1283 	 */
1284 	scale = Max(scale, -NUMERIC_MAX_RESULT_SCALE);
1285 	scale = Min(scale, NUMERIC_MAX_RESULT_SCALE);
1286 
1287 	/*
1288 	 * Unpack the argument and truncate it at the proper digit position
1289 	 */
1290 	init_var(&arg);
1291 	set_var_from_num(num, &arg);
1292 
1293 	trunc_var(&arg, scale);
1294 
1295 	/* We don't allow negative output dscale */
1296 	if (scale < 0)
1297 		arg.dscale = 0;
1298 
1299 	/*
1300 	 * Return the truncated result
1301 	 */
1302 	res = make_result(&arg);
1303 
1304 	free_var(&arg);
1305 	PG_RETURN_NUMERIC(res);
1306 }
1307 
1308 
1309 /*
1310  * numeric_ceil() -
1311  *
1312  *	Return the smallest integer greater than or equal to the argument
1313  */
1314 Datum
numeric_ceil(PG_FUNCTION_ARGS)1315 numeric_ceil(PG_FUNCTION_ARGS)
1316 {
1317 	Numeric		num = PG_GETARG_NUMERIC(0);
1318 	Numeric		res;
1319 	NumericVar	result;
1320 
1321 	if (NUMERIC_IS_NAN(num))
1322 		PG_RETURN_NUMERIC(make_result(&const_nan));
1323 
1324 	init_var_from_num(num, &result);
1325 	ceil_var(&result, &result);
1326 
1327 	res = make_result(&result);
1328 	free_var(&result);
1329 
1330 	PG_RETURN_NUMERIC(res);
1331 }
1332 
1333 
1334 /*
1335  * numeric_floor() -
1336  *
1337  *	Return the largest integer equal to or less than the argument
1338  */
1339 Datum
numeric_floor(PG_FUNCTION_ARGS)1340 numeric_floor(PG_FUNCTION_ARGS)
1341 {
1342 	Numeric		num = PG_GETARG_NUMERIC(0);
1343 	Numeric		res;
1344 	NumericVar	result;
1345 
1346 	if (NUMERIC_IS_NAN(num))
1347 		PG_RETURN_NUMERIC(make_result(&const_nan));
1348 
1349 	init_var_from_num(num, &result);
1350 	floor_var(&result, &result);
1351 
1352 	res = make_result(&result);
1353 	free_var(&result);
1354 
1355 	PG_RETURN_NUMERIC(res);
1356 }
1357 
1358 
1359 /*
1360  * generate_series_numeric() -
1361  *
1362  *	Generate series of numeric.
1363  */
1364 Datum
generate_series_numeric(PG_FUNCTION_ARGS)1365 generate_series_numeric(PG_FUNCTION_ARGS)
1366 {
1367 	return generate_series_step_numeric(fcinfo);
1368 }
1369 
1370 Datum
generate_series_step_numeric(PG_FUNCTION_ARGS)1371 generate_series_step_numeric(PG_FUNCTION_ARGS)
1372 {
1373 	generate_series_numeric_fctx *fctx;
1374 	FuncCallContext *funcctx;
1375 	MemoryContext oldcontext;
1376 
1377 	if (SRF_IS_FIRSTCALL())
1378 	{
1379 		Numeric		start_num = PG_GETARG_NUMERIC(0);
1380 		Numeric		stop_num = PG_GETARG_NUMERIC(1);
1381 		NumericVar	steploc = const_one;
1382 
1383 		/* handle NaN in start and stop values */
1384 		if (NUMERIC_IS_NAN(start_num))
1385 			ereport(ERROR,
1386 					(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
1387 					 errmsg("start value cannot be NaN")));
1388 
1389 		if (NUMERIC_IS_NAN(stop_num))
1390 			ereport(ERROR,
1391 					(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
1392 					 errmsg("stop value cannot be NaN")));
1393 
1394 		/* see if we were given an explicit step size */
1395 		if (PG_NARGS() == 3)
1396 		{
1397 			Numeric		step_num = PG_GETARG_NUMERIC(2);
1398 
1399 			if (NUMERIC_IS_NAN(step_num))
1400 				ereport(ERROR,
1401 						(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
1402 						 errmsg("step size cannot be NaN")));
1403 
1404 			init_var_from_num(step_num, &steploc);
1405 
1406 			if (cmp_var(&steploc, &const_zero) == 0)
1407 				ereport(ERROR,
1408 						(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
1409 						 errmsg("step size cannot equal zero")));
1410 		}
1411 
1412 		/* create a function context for cross-call persistence */
1413 		funcctx = SRF_FIRSTCALL_INIT();
1414 
1415 		/*
1416 		 * Switch to memory context appropriate for multiple function calls.
1417 		 */
1418 		oldcontext = MemoryContextSwitchTo(funcctx->multi_call_memory_ctx);
1419 
1420 		/* allocate memory for user context */
1421 		fctx = (generate_series_numeric_fctx *)
1422 			palloc(sizeof(generate_series_numeric_fctx));
1423 
1424 		/*
1425 		 * Use fctx to keep state from call to call. Seed current with the
1426 		 * original start value. We must copy the start_num and stop_num
1427 		 * values rather than pointing to them, since we may have detoasted
1428 		 * them in the per-call context.
1429 		 */
1430 		init_var(&fctx->current);
1431 		init_var(&fctx->stop);
1432 		init_var(&fctx->step);
1433 
1434 		set_var_from_num(start_num, &fctx->current);
1435 		set_var_from_num(stop_num, &fctx->stop);
1436 		set_var_from_var(&steploc, &fctx->step);
1437 
1438 		funcctx->user_fctx = fctx;
1439 		MemoryContextSwitchTo(oldcontext);
1440 	}
1441 
1442 	/* stuff done on every call of the function */
1443 	funcctx = SRF_PERCALL_SETUP();
1444 
1445 	/*
1446 	 * Get the saved state and use current state as the result of this
1447 	 * iteration.
1448 	 */
1449 	fctx = funcctx->user_fctx;
1450 
1451 	if ((fctx->step.sign == NUMERIC_POS &&
1452 		 cmp_var(&fctx->current, &fctx->stop) <= 0) ||
1453 		(fctx->step.sign == NUMERIC_NEG &&
1454 		 cmp_var(&fctx->current, &fctx->stop) >= 0))
1455 	{
1456 		Numeric		result = make_result(&fctx->current);
1457 
1458 		/* switch to memory context appropriate for iteration calculation */
1459 		oldcontext = MemoryContextSwitchTo(funcctx->multi_call_memory_ctx);
1460 
1461 		/* increment current in preparation for next iteration */
1462 		add_var(&fctx->current, &fctx->step, &fctx->current);
1463 		MemoryContextSwitchTo(oldcontext);
1464 
1465 		/* do when there is more left to send */
1466 		SRF_RETURN_NEXT(funcctx, NumericGetDatum(result));
1467 	}
1468 	else
1469 		/* do when there is no more left */
1470 		SRF_RETURN_DONE(funcctx);
1471 }
1472 
1473 
1474 /*
1475  * Implements the numeric version of the width_bucket() function
1476  * defined by SQL2003. See also width_bucket_float8().
1477  *
1478  * 'bound1' and 'bound2' are the lower and upper bounds of the
1479  * histogram's range, respectively. 'count' is the number of buckets
1480  * in the histogram. width_bucket() returns an integer indicating the
1481  * bucket number that 'operand' belongs to in an equiwidth histogram
1482  * with the specified characteristics. An operand smaller than the
1483  * lower bound is assigned to bucket 0. An operand greater than the
1484  * upper bound is assigned to an additional bucket (with number
1485  * count+1). We don't allow "NaN" for any of the numeric arguments.
1486  */
1487 Datum
width_bucket_numeric(PG_FUNCTION_ARGS)1488 width_bucket_numeric(PG_FUNCTION_ARGS)
1489 {
1490 	Numeric		operand = PG_GETARG_NUMERIC(0);
1491 	Numeric		bound1 = PG_GETARG_NUMERIC(1);
1492 	Numeric		bound2 = PG_GETARG_NUMERIC(2);
1493 	int32		count = PG_GETARG_INT32(3);
1494 	NumericVar	count_var;
1495 	NumericVar	result_var;
1496 	int32		result;
1497 
1498 	if (count <= 0)
1499 		ereport(ERROR,
1500 				(errcode(ERRCODE_INVALID_ARGUMENT_FOR_WIDTH_BUCKET_FUNCTION),
1501 				 errmsg("count must be greater than zero")));
1502 
1503 	if (NUMERIC_IS_NAN(operand) ||
1504 		NUMERIC_IS_NAN(bound1) ||
1505 		NUMERIC_IS_NAN(bound2))
1506 		ereport(ERROR,
1507 				(errcode(ERRCODE_INVALID_ARGUMENT_FOR_WIDTH_BUCKET_FUNCTION),
1508 				 errmsg("operand, lower bound, and upper bound cannot be NaN")));
1509 
1510 	init_var(&result_var);
1511 	init_var(&count_var);
1512 
1513 	/* Convert 'count' to a numeric, for ease of use later */
1514 	int64_to_numericvar((int64) count, &count_var);
1515 
1516 	switch (cmp_numerics(bound1, bound2))
1517 	{
1518 		case 0:
1519 			ereport(ERROR,
1520 					(errcode(ERRCODE_INVALID_ARGUMENT_FOR_WIDTH_BUCKET_FUNCTION),
1521 					 errmsg("lower bound cannot equal upper bound")));
1522 			break;
1523 
1524 			/* bound1 < bound2 */
1525 		case -1:
1526 			if (cmp_numerics(operand, bound1) < 0)
1527 				set_var_from_var(&const_zero, &result_var);
1528 			else if (cmp_numerics(operand, bound2) >= 0)
1529 				add_var(&count_var, &const_one, &result_var);
1530 			else
1531 				compute_bucket(operand, bound1, bound2,
1532 							   &count_var, &result_var);
1533 			break;
1534 
1535 			/* bound1 > bound2 */
1536 		case 1:
1537 			if (cmp_numerics(operand, bound1) > 0)
1538 				set_var_from_var(&const_zero, &result_var);
1539 			else if (cmp_numerics(operand, bound2) <= 0)
1540 				add_var(&count_var, &const_one, &result_var);
1541 			else
1542 				compute_bucket(operand, bound1, bound2,
1543 							   &count_var, &result_var);
1544 			break;
1545 	}
1546 
1547 	/* if result exceeds the range of a legal int4, we ereport here */
1548 	if (!numericvar_to_int32(&result_var, &result))
1549 		ereport(ERROR,
1550 				(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
1551 				 errmsg("integer out of range")));
1552 
1553 	free_var(&count_var);
1554 	free_var(&result_var);
1555 
1556 	PG_RETURN_INT32(result);
1557 }
1558 
1559 /*
1560  * If 'operand' is not outside the bucket range, determine the correct
1561  * bucket for it to go. The calculations performed by this function
1562  * are derived directly from the SQL2003 spec.
1563  */
1564 static void
compute_bucket(Numeric operand,Numeric bound1,Numeric bound2,const NumericVar * count_var,NumericVar * result_var)1565 compute_bucket(Numeric operand, Numeric bound1, Numeric bound2,
1566 			   const NumericVar *count_var, NumericVar *result_var)
1567 {
1568 	NumericVar	bound1_var;
1569 	NumericVar	bound2_var;
1570 	NumericVar	operand_var;
1571 
1572 	init_var_from_num(bound1, &bound1_var);
1573 	init_var_from_num(bound2, &bound2_var);
1574 	init_var_from_num(operand, &operand_var);
1575 
1576 	if (cmp_var(&bound1_var, &bound2_var) < 0)
1577 	{
1578 		sub_var(&operand_var, &bound1_var, &operand_var);
1579 		sub_var(&bound2_var, &bound1_var, &bound2_var);
1580 		div_var(&operand_var, &bound2_var, result_var,
1581 				select_div_scale(&operand_var, &bound2_var), true);
1582 	}
1583 	else
1584 	{
1585 		sub_var(&bound1_var, &operand_var, &operand_var);
1586 		sub_var(&bound1_var, &bound2_var, &bound1_var);
1587 		div_var(&operand_var, &bound1_var, result_var,
1588 				select_div_scale(&operand_var, &bound1_var), true);
1589 	}
1590 
1591 	mul_var(result_var, count_var, result_var,
1592 			result_var->dscale + count_var->dscale);
1593 	add_var(result_var, &const_one, result_var);
1594 	floor_var(result_var, result_var);
1595 
1596 	free_var(&bound1_var);
1597 	free_var(&bound2_var);
1598 	free_var(&operand_var);
1599 }
1600 
1601 /* ----------------------------------------------------------------------
1602  *
1603  * Comparison functions
1604  *
1605  * Note: btree indexes need these routines not to leak memory; therefore,
1606  * be careful to free working copies of toasted datums.  Most places don't
1607  * need to be so careful.
1608  *
1609  * Sort support:
1610  *
1611  * We implement the sortsupport strategy routine in order to get the benefit of
1612  * abbreviation. The ordinary numeric comparison can be quite slow as a result
1613  * of palloc/pfree cycles (due to detoasting packed values for alignment);
1614  * while this could be worked on itself, the abbreviation strategy gives more
1615  * speedup in many common cases.
1616  *
1617  * Two different representations are used for the abbreviated form, one in
1618  * int32 and one in int64, whichever fits into a by-value Datum.  In both cases
1619  * the representation is negated relative to the original value, because we use
1620  * the largest negative value for NaN, which sorts higher than other values. We
1621  * convert the absolute value of the numeric to a 31-bit or 63-bit positive
1622  * value, and then negate it if the original number was positive.
1623  *
1624  * We abort the abbreviation process if the abbreviation cardinality is below
1625  * 0.01% of the row count (1 per 10k non-null rows).  The actual break-even
1626  * point is somewhat below that, perhaps 1 per 30k (at 1 per 100k there's a
1627  * very small penalty), but we don't want to build up too many abbreviated
1628  * values before first testing for abort, so we take the slightly pessimistic
1629  * number.  We make no attempt to estimate the cardinality of the real values,
1630  * since it plays no part in the cost model here (if the abbreviation is equal,
1631  * the cost of comparing equal and unequal underlying values is comparable).
1632  * We discontinue even checking for abort (saving us the hashing overhead) if
1633  * the estimated cardinality gets to 100k; that would be enough to support many
1634  * billions of rows while doing no worse than breaking even.
1635  *
1636  * ----------------------------------------------------------------------
1637  */
1638 
1639 /*
1640  * Sort support strategy routine.
1641  */
1642 Datum
numeric_sortsupport(PG_FUNCTION_ARGS)1643 numeric_sortsupport(PG_FUNCTION_ARGS)
1644 {
1645 	SortSupport ssup = (SortSupport) PG_GETARG_POINTER(0);
1646 
1647 	ssup->comparator = numeric_fast_cmp;
1648 
1649 	if (ssup->abbreviate)
1650 	{
1651 		NumericSortSupport *nss;
1652 		MemoryContext oldcontext = MemoryContextSwitchTo(ssup->ssup_cxt);
1653 
1654 		nss = palloc(sizeof(NumericSortSupport));
1655 
1656 		/*
1657 		 * palloc a buffer for handling unaligned packed values in addition to
1658 		 * the support struct
1659 		 */
1660 		nss->buf = palloc(VARATT_SHORT_MAX + VARHDRSZ + 1);
1661 
1662 		nss->input_count = 0;
1663 		nss->estimating = true;
1664 		initHyperLogLog(&nss->abbr_card, 10);
1665 
1666 		ssup->ssup_extra = nss;
1667 
1668 		ssup->abbrev_full_comparator = ssup->comparator;
1669 		ssup->comparator = numeric_cmp_abbrev;
1670 		ssup->abbrev_converter = numeric_abbrev_convert;
1671 		ssup->abbrev_abort = numeric_abbrev_abort;
1672 
1673 		MemoryContextSwitchTo(oldcontext);
1674 	}
1675 
1676 	PG_RETURN_VOID();
1677 }
1678 
1679 /*
1680  * Abbreviate a numeric datum, handling NaNs and detoasting
1681  * (must not leak memory!)
1682  */
1683 static Datum
numeric_abbrev_convert(Datum original_datum,SortSupport ssup)1684 numeric_abbrev_convert(Datum original_datum, SortSupport ssup)
1685 {
1686 	NumericSortSupport *nss = ssup->ssup_extra;
1687 	void	   *original_varatt = PG_DETOAST_DATUM_PACKED(original_datum);
1688 	Numeric		value;
1689 	Datum		result;
1690 
1691 	nss->input_count += 1;
1692 
1693 	/*
1694 	 * This is to handle packed datums without needing a palloc/pfree cycle;
1695 	 * we keep and reuse a buffer large enough to handle any short datum.
1696 	 */
1697 	if (VARATT_IS_SHORT(original_varatt))
1698 	{
1699 		void	   *buf = nss->buf;
1700 		Size		sz = VARSIZE_SHORT(original_varatt) - VARHDRSZ_SHORT;
1701 
1702 		Assert(sz <= VARATT_SHORT_MAX - VARHDRSZ_SHORT);
1703 
1704 		SET_VARSIZE(buf, VARHDRSZ + sz);
1705 		memcpy(VARDATA(buf), VARDATA_SHORT(original_varatt), sz);
1706 
1707 		value = (Numeric) buf;
1708 	}
1709 	else
1710 		value = (Numeric) original_varatt;
1711 
1712 	if (NUMERIC_IS_NAN(value))
1713 	{
1714 		result = NUMERIC_ABBREV_NAN;
1715 	}
1716 	else
1717 	{
1718 		NumericVar	var;
1719 
1720 		init_var_from_num(value, &var);
1721 
1722 		result = numeric_abbrev_convert_var(&var, nss);
1723 	}
1724 
1725 	/* should happen only for external/compressed toasts */
1726 	if ((Pointer) original_varatt != DatumGetPointer(original_datum))
1727 		pfree(original_varatt);
1728 
1729 	return result;
1730 }
1731 
1732 /*
1733  * Consider whether to abort abbreviation.
1734  *
1735  * We pay no attention to the cardinality of the non-abbreviated data. There is
1736  * no reason to do so: unlike text, we have no fast check for equal values, so
1737  * we pay the full overhead whenever the abbreviations are equal regardless of
1738  * whether the underlying values are also equal.
1739  */
1740 static bool
numeric_abbrev_abort(int memtupcount,SortSupport ssup)1741 numeric_abbrev_abort(int memtupcount, SortSupport ssup)
1742 {
1743 	NumericSortSupport *nss = ssup->ssup_extra;
1744 	double		abbr_card;
1745 
1746 	if (memtupcount < 10000 || nss->input_count < 10000 || !nss->estimating)
1747 		return false;
1748 
1749 	abbr_card = estimateHyperLogLog(&nss->abbr_card);
1750 
1751 	/*
1752 	 * If we have >100k distinct values, then even if we were sorting many
1753 	 * billion rows we'd likely still break even, and the penalty of undoing
1754 	 * that many rows of abbrevs would probably not be worth it. Stop even
1755 	 * counting at that point.
1756 	 */
1757 	if (abbr_card > 100000.0)
1758 	{
1759 #ifdef TRACE_SORT
1760 		if (trace_sort)
1761 			elog(LOG,
1762 				 "numeric_abbrev: estimation ends at cardinality %f"
1763 				 " after " INT64_FORMAT " values (%d rows)",
1764 				 abbr_card, nss->input_count, memtupcount);
1765 #endif
1766 		nss->estimating = false;
1767 		return false;
1768 	}
1769 
1770 	/*
1771 	 * Target minimum cardinality is 1 per ~10k of non-null inputs.  (The
1772 	 * break even point is somewhere between one per 100k rows, where
1773 	 * abbreviation has a very slight penalty, and 1 per 10k where it wins by
1774 	 * a measurable percentage.)  We use the relatively pessimistic 10k
1775 	 * threshold, and add a 0.5 row fudge factor, because it allows us to
1776 	 * abort earlier on genuinely pathological data where we've had exactly
1777 	 * one abbreviated value in the first 10k (non-null) rows.
1778 	 */
1779 	if (abbr_card < nss->input_count / 10000.0 + 0.5)
1780 	{
1781 #ifdef TRACE_SORT
1782 		if (trace_sort)
1783 			elog(LOG,
1784 				 "numeric_abbrev: aborting abbreviation at cardinality %f"
1785 				 " below threshold %f after " INT64_FORMAT " values (%d rows)",
1786 				 abbr_card, nss->input_count / 10000.0 + 0.5,
1787 				 nss->input_count, memtupcount);
1788 #endif
1789 		return true;
1790 	}
1791 
1792 #ifdef TRACE_SORT
1793 	if (trace_sort)
1794 		elog(LOG,
1795 			 "numeric_abbrev: cardinality %f"
1796 			 " after " INT64_FORMAT " values (%d rows)",
1797 			 abbr_card, nss->input_count, memtupcount);
1798 #endif
1799 
1800 	return false;
1801 }
1802 
1803 /*
1804  * Non-fmgr interface to the comparison routine to allow sortsupport to elide
1805  * the fmgr call.  The saving here is small given how slow numeric comparisons
1806  * are, but it is a required part of the sort support API when abbreviations
1807  * are performed.
1808  *
1809  * Two palloc/pfree cycles could be saved here by using persistent buffers for
1810  * aligning short-varlena inputs, but this has not so far been considered to
1811  * be worth the effort.
1812  */
1813 static int
numeric_fast_cmp(Datum x,Datum y,SortSupport ssup)1814 numeric_fast_cmp(Datum x, Datum y, SortSupport ssup)
1815 {
1816 	Numeric		nx = DatumGetNumeric(x);
1817 	Numeric		ny = DatumGetNumeric(y);
1818 	int			result;
1819 
1820 	result = cmp_numerics(nx, ny);
1821 
1822 	if ((Pointer) nx != DatumGetPointer(x))
1823 		pfree(nx);
1824 	if ((Pointer) ny != DatumGetPointer(y))
1825 		pfree(ny);
1826 
1827 	return result;
1828 }
1829 
1830 /*
1831  * Compare abbreviations of values. (Abbreviations may be equal where the true
1832  * values differ, but if the abbreviations differ, they must reflect the
1833  * ordering of the true values.)
1834  */
1835 static int
numeric_cmp_abbrev(Datum x,Datum y,SortSupport ssup)1836 numeric_cmp_abbrev(Datum x, Datum y, SortSupport ssup)
1837 {
1838 	/*
1839 	 * NOTE WELL: this is intentionally backwards, because the abbreviation is
1840 	 * negated relative to the original value, to handle NaN.
1841 	 */
1842 	if (DatumGetNumericAbbrev(x) < DatumGetNumericAbbrev(y))
1843 		return 1;
1844 	if (DatumGetNumericAbbrev(x) > DatumGetNumericAbbrev(y))
1845 		return -1;
1846 	return 0;
1847 }
1848 
1849 /*
1850  * Abbreviate a NumericVar according to the available bit size.
1851  *
1852  * The 31-bit value is constructed as:
1853  *
1854  *	0 + 7bits digit weight + 24 bits digit value
1855  *
1856  * where the digit weight is in single decimal digits, not digit words, and
1857  * stored in excess-44 representation[1]. The 24-bit digit value is the 7 most
1858  * significant decimal digits of the value converted to binary. Values whose
1859  * weights would fall outside the representable range are rounded off to zero
1860  * (which is also used to represent actual zeros) or to 0x7FFFFFFF (which
1861  * otherwise cannot occur). Abbreviation therefore fails to gain any advantage
1862  * where values are outside the range 10^-44 to 10^83, which is not considered
1863  * to be a serious limitation, or when values are of the same magnitude and
1864  * equal in the first 7 decimal digits, which is considered to be an
1865  * unavoidable limitation given the available bits. (Stealing three more bits
1866  * to compare another digit would narrow the range of representable weights by
1867  * a factor of 8, which starts to look like a real limiting factor.)
1868  *
1869  * (The value 44 for the excess is essentially arbitrary)
1870  *
1871  * The 63-bit value is constructed as:
1872  *
1873  *	0 + 7bits weight + 4 x 14-bit packed digit words
1874  *
1875  * The weight in this case is again stored in excess-44, but this time it is
1876  * the original weight in digit words (i.e. powers of 10000). The first four
1877  * digit words of the value (if present; trailing zeros are assumed as needed)
1878  * are packed into 14 bits each to form the rest of the value. Again,
1879  * out-of-range values are rounded off to 0 or 0x7FFFFFFFFFFFFFFF. The
1880  * representable range in this case is 10^-176 to 10^332, which is considered
1881  * to be good enough for all practical purposes, and comparison of 4 words
1882  * means that at least 13 decimal digits are compared, which is considered to
1883  * be a reasonable compromise between effectiveness and efficiency in computing
1884  * the abbreviation.
1885  *
 * (The value 44 for the excess is even more arbitrary here; it was chosen just
 * to match the value used in the 31-bit case.)
1888  *
1889  * [1] - Excess-k representation means that the value is offset by adding 'k'
1890  * and then treated as unsigned, so the smallest representable value is stored
1891  * with all bits zero. This allows simple comparisons to work on the composite
1892  * value.
1893  */
1894 
1895 #if NUMERIC_ABBREV_BITS == 64
1896 
1897 static Datum
numeric_abbrev_convert_var(const NumericVar * var,NumericSortSupport * nss)1898 numeric_abbrev_convert_var(const NumericVar *var, NumericSortSupport *nss)
1899 {
1900 	int			ndigits = var->ndigits;
1901 	int			weight = var->weight;
1902 	int64		result;
1903 
1904 	if (ndigits == 0 || weight < -44)
1905 	{
1906 		result = 0;
1907 	}
1908 	else if (weight > 83)
1909 	{
1910 		result = PG_INT64_MAX;
1911 	}
1912 	else
1913 	{
1914 		result = ((int64) (weight + 44) << 56);
1915 
1916 		switch (ndigits)
1917 		{
1918 			default:
1919 				result |= ((int64) var->digits[3]);
1920 				/* FALLTHROUGH */
1921 			case 3:
1922 				result |= ((int64) var->digits[2]) << 14;
1923 				/* FALLTHROUGH */
1924 			case 2:
1925 				result |= ((int64) var->digits[1]) << 28;
1926 				/* FALLTHROUGH */
1927 			case 1:
1928 				result |= ((int64) var->digits[0]) << 42;
1929 				break;
1930 		}
1931 	}
1932 
1933 	/* the abbrev is negated relative to the original */
1934 	if (var->sign == NUMERIC_POS)
1935 		result = -result;
1936 
1937 	if (nss->estimating)
1938 	{
1939 		uint32		tmp = ((uint32) result
1940 						   ^ (uint32) ((uint64) result >> 32));
1941 
1942 		addHyperLogLog(&nss->abbr_card, DatumGetUInt32(hash_uint32(tmp)));
1943 	}
1944 
1945 	return NumericAbbrevGetDatum(result);
1946 }
1947 
1948 #endif							/* NUMERIC_ABBREV_BITS == 64 */
1949 
1950 #if NUMERIC_ABBREV_BITS == 32
1951 
1952 static Datum
numeric_abbrev_convert_var(const NumericVar * var,NumericSortSupport * nss)1953 numeric_abbrev_convert_var(const NumericVar *var, NumericSortSupport *nss)
1954 {
1955 	int			ndigits = var->ndigits;
1956 	int			weight = var->weight;
1957 	int32		result;
1958 
1959 	if (ndigits == 0 || weight < -11)
1960 	{
1961 		result = 0;
1962 	}
1963 	else if (weight > 20)
1964 	{
1965 		result = PG_INT32_MAX;
1966 	}
1967 	else
1968 	{
1969 		NumericDigit nxt1 = (ndigits > 1) ? var->digits[1] : 0;
1970 
1971 		weight = (weight + 11) * 4;
1972 
1973 		result = var->digits[0];
1974 
1975 		/*
1976 		 * "result" now has 1 to 4 nonzero decimal digits. We pack in more
1977 		 * digits to make 7 in total (largest we can fit in 24 bits)
1978 		 */
1979 
1980 		if (result > 999)
1981 		{
1982 			/* already have 4 digits, add 3 more */
1983 			result = (result * 1000) + (nxt1 / 10);
1984 			weight += 3;
1985 		}
1986 		else if (result > 99)
1987 		{
1988 			/* already have 3 digits, add 4 more */
1989 			result = (result * 10000) + nxt1;
1990 			weight += 2;
1991 		}
1992 		else if (result > 9)
1993 		{
1994 			NumericDigit nxt2 = (ndigits > 2) ? var->digits[2] : 0;
1995 
1996 			/* already have 2 digits, add 5 more */
1997 			result = (result * 100000) + (nxt1 * 10) + (nxt2 / 1000);
1998 			weight += 1;
1999 		}
2000 		else
2001 		{
2002 			NumericDigit nxt2 = (ndigits > 2) ? var->digits[2] : 0;
2003 
2004 			/* already have 1 digit, add 6 more */
2005 			result = (result * 1000000) + (nxt1 * 100) + (nxt2 / 100);
2006 		}
2007 
2008 		result = result | (weight << 24);
2009 	}
2010 
2011 	/* the abbrev is negated relative to the original */
2012 	if (var->sign == NUMERIC_POS)
2013 		result = -result;
2014 
2015 	if (nss->estimating)
2016 	{
2017 		uint32		tmp = (uint32) result;
2018 
2019 		addHyperLogLog(&nss->abbr_card, DatumGetUInt32(hash_uint32(tmp)));
2020 	}
2021 
2022 	return NumericAbbrevGetDatum(result);
2023 }
2024 
2025 #endif							/* NUMERIC_ABBREV_BITS == 32 */
2026 
2027 /*
2028  * Ordinary (non-sortsupport) comparisons follow.
2029  */
2030 
2031 Datum
numeric_cmp(PG_FUNCTION_ARGS)2032 numeric_cmp(PG_FUNCTION_ARGS)
2033 {
2034 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2035 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2036 	int			result;
2037 
2038 	result = cmp_numerics(num1, num2);
2039 
2040 	PG_FREE_IF_COPY(num1, 0);
2041 	PG_FREE_IF_COPY(num2, 1);
2042 
2043 	PG_RETURN_INT32(result);
2044 }
2045 
2046 
2047 Datum
numeric_eq(PG_FUNCTION_ARGS)2048 numeric_eq(PG_FUNCTION_ARGS)
2049 {
2050 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2051 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2052 	bool		result;
2053 
2054 	result = cmp_numerics(num1, num2) == 0;
2055 
2056 	PG_FREE_IF_COPY(num1, 0);
2057 	PG_FREE_IF_COPY(num2, 1);
2058 
2059 	PG_RETURN_BOOL(result);
2060 }
2061 
2062 Datum
numeric_ne(PG_FUNCTION_ARGS)2063 numeric_ne(PG_FUNCTION_ARGS)
2064 {
2065 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2066 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2067 	bool		result;
2068 
2069 	result = cmp_numerics(num1, num2) != 0;
2070 
2071 	PG_FREE_IF_COPY(num1, 0);
2072 	PG_FREE_IF_COPY(num2, 1);
2073 
2074 	PG_RETURN_BOOL(result);
2075 }
2076 
2077 Datum
numeric_gt(PG_FUNCTION_ARGS)2078 numeric_gt(PG_FUNCTION_ARGS)
2079 {
2080 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2081 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2082 	bool		result;
2083 
2084 	result = cmp_numerics(num1, num2) > 0;
2085 
2086 	PG_FREE_IF_COPY(num1, 0);
2087 	PG_FREE_IF_COPY(num2, 1);
2088 
2089 	PG_RETURN_BOOL(result);
2090 }
2091 
2092 Datum
numeric_ge(PG_FUNCTION_ARGS)2093 numeric_ge(PG_FUNCTION_ARGS)
2094 {
2095 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2096 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2097 	bool		result;
2098 
2099 	result = cmp_numerics(num1, num2) >= 0;
2100 
2101 	PG_FREE_IF_COPY(num1, 0);
2102 	PG_FREE_IF_COPY(num2, 1);
2103 
2104 	PG_RETURN_BOOL(result);
2105 }
2106 
2107 Datum
numeric_lt(PG_FUNCTION_ARGS)2108 numeric_lt(PG_FUNCTION_ARGS)
2109 {
2110 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2111 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2112 	bool		result;
2113 
2114 	result = cmp_numerics(num1, num2) < 0;
2115 
2116 	PG_FREE_IF_COPY(num1, 0);
2117 	PG_FREE_IF_COPY(num2, 1);
2118 
2119 	PG_RETURN_BOOL(result);
2120 }
2121 
2122 Datum
numeric_le(PG_FUNCTION_ARGS)2123 numeric_le(PG_FUNCTION_ARGS)
2124 {
2125 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2126 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2127 	bool		result;
2128 
2129 	result = cmp_numerics(num1, num2) <= 0;
2130 
2131 	PG_FREE_IF_COPY(num1, 0);
2132 	PG_FREE_IF_COPY(num2, 1);
2133 
2134 	PG_RETURN_BOOL(result);
2135 }
2136 
2137 static int
cmp_numerics(Numeric num1,Numeric num2)2138 cmp_numerics(Numeric num1, Numeric num2)
2139 {
2140 	int			result;
2141 
2142 	/*
2143 	 * We consider all NANs to be equal and larger than any non-NAN. This is
2144 	 * somewhat arbitrary; the important thing is to have a consistent sort
2145 	 * order.
2146 	 */
2147 	if (NUMERIC_IS_NAN(num1))
2148 	{
2149 		if (NUMERIC_IS_NAN(num2))
2150 			result = 0;			/* NAN = NAN */
2151 		else
2152 			result = 1;			/* NAN > non-NAN */
2153 	}
2154 	else if (NUMERIC_IS_NAN(num2))
2155 	{
2156 		result = -1;			/* non-NAN < NAN */
2157 	}
2158 	else
2159 	{
2160 		result = cmp_var_common(NUMERIC_DIGITS(num1), NUMERIC_NDIGITS(num1),
2161 								NUMERIC_WEIGHT(num1), NUMERIC_SIGN(num1),
2162 								NUMERIC_DIGITS(num2), NUMERIC_NDIGITS(num2),
2163 								NUMERIC_WEIGHT(num2), NUMERIC_SIGN(num2));
2164 	}
2165 
2166 	return result;
2167 }
2168 
2169 /*
2170  * in_range support function for numeric.
2171  */
2172 Datum
in_range_numeric_numeric(PG_FUNCTION_ARGS)2173 in_range_numeric_numeric(PG_FUNCTION_ARGS)
2174 {
2175 	Numeric		val = PG_GETARG_NUMERIC(0);
2176 	Numeric		base = PG_GETARG_NUMERIC(1);
2177 	Numeric		offset = PG_GETARG_NUMERIC(2);
2178 	bool		sub = PG_GETARG_BOOL(3);
2179 	bool		less = PG_GETARG_BOOL(4);
2180 	bool		result;
2181 
2182 	/*
2183 	 * Reject negative or NaN offset.  Negative is per spec, and NaN is
2184 	 * because appropriate semantics for that seem non-obvious.
2185 	 */
2186 	if (NUMERIC_IS_NAN(offset) || NUMERIC_SIGN(offset) == NUMERIC_NEG)
2187 		ereport(ERROR,
2188 				(errcode(ERRCODE_INVALID_PRECEDING_OR_FOLLOWING_SIZE),
2189 				 errmsg("invalid preceding or following size in window function")));
2190 
2191 	/*
2192 	 * Deal with cases where val and/or base is NaN, following the rule that
2193 	 * NaN sorts after non-NaN (cf cmp_numerics).  The offset cannot affect
2194 	 * the conclusion.
2195 	 */
2196 	if (NUMERIC_IS_NAN(val))
2197 	{
2198 		if (NUMERIC_IS_NAN(base))
2199 			result = true;		/* NAN = NAN */
2200 		else
2201 			result = !less;		/* NAN > non-NAN */
2202 	}
2203 	else if (NUMERIC_IS_NAN(base))
2204 	{
2205 		result = less;			/* non-NAN < NAN */
2206 	}
2207 	else
2208 	{
2209 		/*
2210 		 * Otherwise go ahead and compute base +/- offset.  While it's
2211 		 * possible for this to overflow the numeric format, it's unlikely
2212 		 * enough that we don't take measures to prevent it.
2213 		 */
2214 		NumericVar	valv;
2215 		NumericVar	basev;
2216 		NumericVar	offsetv;
2217 		NumericVar	sum;
2218 
2219 		init_var_from_num(val, &valv);
2220 		init_var_from_num(base, &basev);
2221 		init_var_from_num(offset, &offsetv);
2222 		init_var(&sum);
2223 
2224 		if (sub)
2225 			sub_var(&basev, &offsetv, &sum);
2226 		else
2227 			add_var(&basev, &offsetv, &sum);
2228 
2229 		if (less)
2230 			result = (cmp_var(&valv, &sum) <= 0);
2231 		else
2232 			result = (cmp_var(&valv, &sum) >= 0);
2233 
2234 		free_var(&sum);
2235 	}
2236 
2237 	PG_FREE_IF_COPY(val, 0);
2238 	PG_FREE_IF_COPY(base, 1);
2239 	PG_FREE_IF_COPY(offset, 2);
2240 
2241 	PG_RETURN_BOOL(result);
2242 }
2243 
2244 Datum
hash_numeric(PG_FUNCTION_ARGS)2245 hash_numeric(PG_FUNCTION_ARGS)
2246 {
2247 	Numeric		key = PG_GETARG_NUMERIC(0);
2248 	Datum		digit_hash;
2249 	Datum		result;
2250 	int			weight;
2251 	int			start_offset;
2252 	int			end_offset;
2253 	int			i;
2254 	int			hash_len;
2255 	NumericDigit *digits;
2256 
2257 	/* If it's NaN, don't try to hash the rest of the fields */
2258 	if (NUMERIC_IS_NAN(key))
2259 		PG_RETURN_UINT32(0);
2260 
2261 	weight = NUMERIC_WEIGHT(key);
2262 	start_offset = 0;
2263 	end_offset = 0;
2264 
2265 	/*
2266 	 * Omit any leading or trailing zeros from the input to the hash. The
2267 	 * numeric implementation *should* guarantee that leading and trailing
2268 	 * zeros are suppressed, but we're paranoid. Note that we measure the
2269 	 * starting and ending offsets in units of NumericDigits, not bytes.
2270 	 */
2271 	digits = NUMERIC_DIGITS(key);
2272 	for (i = 0; i < NUMERIC_NDIGITS(key); i++)
2273 	{
2274 		if (digits[i] != (NumericDigit) 0)
2275 			break;
2276 
2277 		start_offset++;
2278 
2279 		/*
2280 		 * The weight is effectively the # of digits before the decimal point,
2281 		 * so decrement it for each leading zero we skip.
2282 		 */
2283 		weight--;
2284 	}
2285 
2286 	/*
2287 	 * If there are no non-zero digits, then the value of the number is zero,
2288 	 * regardless of any other fields.
2289 	 */
2290 	if (NUMERIC_NDIGITS(key) == start_offset)
2291 		PG_RETURN_UINT32(-1);
2292 
2293 	for (i = NUMERIC_NDIGITS(key) - 1; i >= 0; i--)
2294 	{
2295 		if (digits[i] != (NumericDigit) 0)
2296 			break;
2297 
2298 		end_offset++;
2299 	}
2300 
2301 	/* If we get here, there should be at least one non-zero digit */
2302 	Assert(start_offset + end_offset < NUMERIC_NDIGITS(key));
2303 
2304 	/*
2305 	 * Note that we don't hash on the Numeric's scale, since two numerics can
2306 	 * compare equal but have different scales. We also don't hash on the
2307 	 * sign, although we could: since a sign difference implies inequality,
2308 	 * this shouldn't affect correctness.
2309 	 */
2310 	hash_len = NUMERIC_NDIGITS(key) - start_offset - end_offset;
2311 	digit_hash = hash_any((unsigned char *) (NUMERIC_DIGITS(key) + start_offset),
2312 						  hash_len * sizeof(NumericDigit));
2313 
2314 	/* Mix in the weight, via XOR */
2315 	result = digit_hash ^ weight;
2316 
2317 	PG_RETURN_DATUM(result);
2318 }
2319 
2320 /*
2321  * Returns 64-bit value by hashing a value to a 64-bit value, with a seed.
2322  * Otherwise, similar to hash_numeric.
2323  */
2324 Datum
hash_numeric_extended(PG_FUNCTION_ARGS)2325 hash_numeric_extended(PG_FUNCTION_ARGS)
2326 {
2327 	Numeric		key = PG_GETARG_NUMERIC(0);
2328 	uint64		seed = PG_GETARG_INT64(1);
2329 	Datum		digit_hash;
2330 	Datum		result;
2331 	int			weight;
2332 	int			start_offset;
2333 	int			end_offset;
2334 	int			i;
2335 	int			hash_len;
2336 	NumericDigit *digits;
2337 
2338 	if (NUMERIC_IS_NAN(key))
2339 		PG_RETURN_UINT64(seed);
2340 
2341 	weight = NUMERIC_WEIGHT(key);
2342 	start_offset = 0;
2343 	end_offset = 0;
2344 
2345 	digits = NUMERIC_DIGITS(key);
2346 	for (i = 0; i < NUMERIC_NDIGITS(key); i++)
2347 	{
2348 		if (digits[i] != (NumericDigit) 0)
2349 			break;
2350 
2351 		start_offset++;
2352 
2353 		weight--;
2354 	}
2355 
2356 	if (NUMERIC_NDIGITS(key) == start_offset)
2357 		PG_RETURN_UINT64(seed - 1);
2358 
2359 	for (i = NUMERIC_NDIGITS(key) - 1; i >= 0; i--)
2360 	{
2361 		if (digits[i] != (NumericDigit) 0)
2362 			break;
2363 
2364 		end_offset++;
2365 	}
2366 
2367 	Assert(start_offset + end_offset < NUMERIC_NDIGITS(key));
2368 
2369 	hash_len = NUMERIC_NDIGITS(key) - start_offset - end_offset;
2370 	digit_hash = hash_any_extended((unsigned char *) (NUMERIC_DIGITS(key)
2371 													  + start_offset),
2372 								   hash_len * sizeof(NumericDigit),
2373 								   seed);
2374 
2375 	result = UInt64GetDatum(DatumGetUInt64(digit_hash) ^ weight);
2376 
2377 	PG_RETURN_DATUM(result);
2378 }
2379 
2380 
2381 /* ----------------------------------------------------------------------
2382  *
2383  * Basic arithmetic functions
2384  *
2385  * ----------------------------------------------------------------------
2386  */
2387 
2388 
2389 /*
2390  * numeric_add() -
2391  *
2392  *	Add two numerics
2393  */
2394 Datum
numeric_add(PG_FUNCTION_ARGS)2395 numeric_add(PG_FUNCTION_ARGS)
2396 {
2397 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2398 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2399 	Numeric		res;
2400 
2401 	res = numeric_add_opt_error(num1, num2, NULL);
2402 
2403 	PG_RETURN_NUMERIC(res);
2404 }
2405 
2406 /*
2407  * numeric_add_opt_error() -
2408  *
2409  *	Internal version of numeric_add().  If "*have_error" flag is provided,
2410  *	on error it's set to true, NULL returned.  This is helpful when caller
2411  *	need to handle errors by itself.
2412  */
2413 Numeric
numeric_add_opt_error(Numeric num1,Numeric num2,bool * have_error)2414 numeric_add_opt_error(Numeric num1, Numeric num2, bool *have_error)
2415 {
2416 	NumericVar	arg1;
2417 	NumericVar	arg2;
2418 	NumericVar	result;
2419 	Numeric		res;
2420 
2421 	/*
2422 	 * Handle NaN
2423 	 */
2424 	if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
2425 		return make_result(&const_nan);
2426 
2427 	/*
2428 	 * Unpack the values, let add_var() compute the result and return it.
2429 	 */
2430 	init_var_from_num(num1, &arg1);
2431 	init_var_from_num(num2, &arg2);
2432 
2433 	init_var(&result);
2434 	add_var(&arg1, &arg2, &result);
2435 
2436 	res = make_result_opt_error(&result, have_error);
2437 
2438 	free_var(&result);
2439 
2440 	return res;
2441 }
2442 
2443 
2444 /*
2445  * numeric_sub() -
2446  *
2447  *	Subtract one numeric from another
2448  */
2449 Datum
numeric_sub(PG_FUNCTION_ARGS)2450 numeric_sub(PG_FUNCTION_ARGS)
2451 {
2452 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2453 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2454 	Numeric		res;
2455 
2456 	res = numeric_sub_opt_error(num1, num2, NULL);
2457 
2458 	PG_RETURN_NUMERIC(res);
2459 }
2460 
2461 
2462 /*
2463  * numeric_sub_opt_error() -
2464  *
2465  *	Internal version of numeric_sub().  If "*have_error" flag is provided,
2466  *	on error it's set to true, NULL returned.  This is helpful when caller
2467  *	need to handle errors by itself.
2468  */
2469 Numeric
numeric_sub_opt_error(Numeric num1,Numeric num2,bool * have_error)2470 numeric_sub_opt_error(Numeric num1, Numeric num2, bool *have_error)
2471 {
2472 	NumericVar	arg1;
2473 	NumericVar	arg2;
2474 	NumericVar	result;
2475 	Numeric		res;
2476 
2477 	/*
2478 	 * Handle NaN
2479 	 */
2480 	if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
2481 		return make_result(&const_nan);
2482 
2483 	/*
2484 	 * Unpack the values, let sub_var() compute the result and return it.
2485 	 */
2486 	init_var_from_num(num1, &arg1);
2487 	init_var_from_num(num2, &arg2);
2488 
2489 	init_var(&result);
2490 	sub_var(&arg1, &arg2, &result);
2491 
2492 	res = make_result_opt_error(&result, have_error);
2493 
2494 	free_var(&result);
2495 
2496 	return res;
2497 }
2498 
2499 
2500 /*
2501  * numeric_mul() -
2502  *
2503  *	Calculate the product of two numerics
2504  */
2505 Datum
numeric_mul(PG_FUNCTION_ARGS)2506 numeric_mul(PG_FUNCTION_ARGS)
2507 {
2508 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2509 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2510 	Numeric		res;
2511 
2512 	res = numeric_mul_opt_error(num1, num2, NULL);
2513 
2514 	PG_RETURN_NUMERIC(res);
2515 }
2516 
2517 
2518 /*
2519  * numeric_mul_opt_error() -
2520  *
2521  *	Internal version of numeric_mul().  If "*have_error" flag is provided,
2522  *	on error it's set to true, NULL returned.  This is helpful when caller
2523  *	need to handle errors by itself.
2524  */
2525 Numeric
numeric_mul_opt_error(Numeric num1,Numeric num2,bool * have_error)2526 numeric_mul_opt_error(Numeric num1, Numeric num2, bool *have_error)
2527 {
2528 	NumericVar	arg1;
2529 	NumericVar	arg2;
2530 	NumericVar	result;
2531 	Numeric		res;
2532 
2533 	/*
2534 	 * Handle NaN
2535 	 */
2536 	if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
2537 		return make_result(&const_nan);
2538 
2539 	/*
2540 	 * Unpack the values, let mul_var() compute the result and return it.
2541 	 * Unlike add_var() and sub_var(), mul_var() will round its result. In the
2542 	 * case of numeric_mul(), which is invoked for the * operator on numerics,
2543 	 * we request exact representation for the product (rscale = sum(dscale of
2544 	 * arg1, dscale of arg2)).  If the exact result has more digits after the
2545 	 * decimal point than can be stored in a numeric, we round it.  Rounding
2546 	 * after computing the exact result ensures that the final result is
2547 	 * correctly rounded (rounding in mul_var() using a truncated product
2548 	 * would not guarantee this).
2549 	 */
2550 	init_var_from_num(num1, &arg1);
2551 	init_var_from_num(num2, &arg2);
2552 
2553 	init_var(&result);
2554 	mul_var(&arg1, &arg2, &result, arg1.dscale + arg2.dscale);
2555 
2556 	if (result.dscale > NUMERIC_DSCALE_MAX)
2557 		round_var(&result, NUMERIC_DSCALE_MAX);
2558 
2559 	res = make_result_opt_error(&result, have_error);
2560 
2561 	free_var(&result);
2562 
2563 	return res;
2564 }
2565 
2566 
2567 /*
2568  * numeric_div() -
2569  *
2570  *	Divide one numeric into another
2571  */
2572 Datum
numeric_div(PG_FUNCTION_ARGS)2573 numeric_div(PG_FUNCTION_ARGS)
2574 {
2575 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2576 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2577 	Numeric		res;
2578 
2579 	res = numeric_div_opt_error(num1, num2, NULL);
2580 
2581 	PG_RETURN_NUMERIC(res);
2582 }
2583 
2584 
2585 /*
2586  * numeric_div_opt_error() -
2587  *
2588  *	Internal version of numeric_div().  If "*have_error" flag is provided,
2589  *	on error it's set to true, NULL returned.  This is helpful when caller
2590  *	need to handle errors by itself.
2591  */
2592 Numeric
numeric_div_opt_error(Numeric num1,Numeric num2,bool * have_error)2593 numeric_div_opt_error(Numeric num1, Numeric num2, bool *have_error)
2594 {
2595 	NumericVar	arg1;
2596 	NumericVar	arg2;
2597 	NumericVar	result;
2598 	Numeric		res;
2599 	int			rscale;
2600 
2601 	if (have_error)
2602 		*have_error = false;
2603 
2604 	/*
2605 	 * Handle NaN
2606 	 */
2607 	if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
2608 		return make_result(&const_nan);
2609 
2610 	/*
2611 	 * Unpack the arguments
2612 	 */
2613 	init_var_from_num(num1, &arg1);
2614 	init_var_from_num(num2, &arg2);
2615 
2616 	init_var(&result);
2617 
2618 	/*
2619 	 * Select scale for division result
2620 	 */
2621 	rscale = select_div_scale(&arg1, &arg2);
2622 
2623 	/*
2624 	 * If "have_error" is provided, check for division by zero here
2625 	 */
2626 	if (have_error && (arg2.ndigits == 0 || arg2.digits[0] == 0))
2627 	{
2628 		*have_error = true;
2629 		return NULL;
2630 	}
2631 
2632 	/*
2633 	 * Do the divide and return the result
2634 	 */
2635 	div_var(&arg1, &arg2, &result, rscale, true);
2636 
2637 	res = make_result_opt_error(&result, have_error);
2638 
2639 	free_var(&result);
2640 
2641 	return res;
2642 }
2643 
2644 
2645 /*
2646  * numeric_div_trunc() -
2647  *
2648  *	Divide one numeric into another, truncating the result to an integer
2649  */
2650 Datum
numeric_div_trunc(PG_FUNCTION_ARGS)2651 numeric_div_trunc(PG_FUNCTION_ARGS)
2652 {
2653 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2654 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2655 	NumericVar	arg1;
2656 	NumericVar	arg2;
2657 	NumericVar	result;
2658 	Numeric		res;
2659 
2660 	/*
2661 	 * Handle NaN
2662 	 */
2663 	if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
2664 		PG_RETURN_NUMERIC(make_result(&const_nan));
2665 
2666 	/*
2667 	 * Unpack the arguments
2668 	 */
2669 	init_var_from_num(num1, &arg1);
2670 	init_var_from_num(num2, &arg2);
2671 
2672 	init_var(&result);
2673 
2674 	/*
2675 	 * Do the divide and return the result
2676 	 */
2677 	div_var(&arg1, &arg2, &result, 0, false);
2678 
2679 	res = make_result(&result);
2680 
2681 	free_var(&result);
2682 
2683 	PG_RETURN_NUMERIC(res);
2684 }
2685 
2686 
2687 /*
2688  * numeric_mod() -
2689  *
2690  *	Calculate the modulo of two numerics
2691  */
2692 Datum
numeric_mod(PG_FUNCTION_ARGS)2693 numeric_mod(PG_FUNCTION_ARGS)
2694 {
2695 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2696 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2697 	Numeric		res;
2698 
2699 	res = numeric_mod_opt_error(num1, num2, NULL);
2700 
2701 	PG_RETURN_NUMERIC(res);
2702 }
2703 
2704 
2705 /*
2706  * numeric_mod_opt_error() -
2707  *
2708  *	Internal version of numeric_mod().  If "*have_error" flag is provided,
2709  *	on error it's set to true, NULL returned.  This is helpful when caller
2710  *	need to handle errors by itself.
2711  */
2712 Numeric
numeric_mod_opt_error(Numeric num1,Numeric num2,bool * have_error)2713 numeric_mod_opt_error(Numeric num1, Numeric num2, bool *have_error)
2714 {
2715 	Numeric		res;
2716 	NumericVar	arg1;
2717 	NumericVar	arg2;
2718 	NumericVar	result;
2719 
2720 	if (have_error)
2721 		*have_error = false;
2722 
2723 	if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
2724 		return make_result(&const_nan);
2725 
2726 	init_var_from_num(num1, &arg1);
2727 	init_var_from_num(num2, &arg2);
2728 
2729 	init_var(&result);
2730 
2731 	/*
2732 	 * If "have_error" is provided, check for division by zero here
2733 	 */
2734 	if (have_error && (arg2.ndigits == 0 || arg2.digits[0] == 0))
2735 	{
2736 		*have_error = true;
2737 		return NULL;
2738 	}
2739 
2740 	mod_var(&arg1, &arg2, &result);
2741 
2742 	res = make_result_opt_error(&result, NULL);
2743 
2744 	free_var(&result);
2745 
2746 	return res;
2747 }
2748 
2749 
2750 /*
2751  * numeric_inc() -
2752  *
2753  *	Increment a number by one
2754  */
2755 Datum
numeric_inc(PG_FUNCTION_ARGS)2756 numeric_inc(PG_FUNCTION_ARGS)
2757 {
2758 	Numeric		num = PG_GETARG_NUMERIC(0);
2759 	NumericVar	arg;
2760 	Numeric		res;
2761 
2762 	/*
2763 	 * Handle NaN
2764 	 */
2765 	if (NUMERIC_IS_NAN(num))
2766 		PG_RETURN_NUMERIC(make_result(&const_nan));
2767 
2768 	/*
2769 	 * Compute the result and return it
2770 	 */
2771 	init_var_from_num(num, &arg);
2772 
2773 	add_var(&arg, &const_one, &arg);
2774 
2775 	res = make_result(&arg);
2776 
2777 	free_var(&arg);
2778 
2779 	PG_RETURN_NUMERIC(res);
2780 }
2781 
2782 
2783 /*
2784  * numeric_smaller() -
2785  *
2786  *	Return the smaller of two numbers
2787  */
2788 Datum
numeric_smaller(PG_FUNCTION_ARGS)2789 numeric_smaller(PG_FUNCTION_ARGS)
2790 {
2791 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2792 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2793 
2794 	/*
2795 	 * Use cmp_numerics so that this will agree with the comparison operators,
2796 	 * particularly as regards comparisons involving NaN.
2797 	 */
2798 	if (cmp_numerics(num1, num2) < 0)
2799 		PG_RETURN_NUMERIC(num1);
2800 	else
2801 		PG_RETURN_NUMERIC(num2);
2802 }
2803 
2804 
2805 /*
2806  * numeric_larger() -
2807  *
2808  *	Return the larger of two numbers
2809  */
2810 Datum
numeric_larger(PG_FUNCTION_ARGS)2811 numeric_larger(PG_FUNCTION_ARGS)
2812 {
2813 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2814 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2815 
2816 	/*
2817 	 * Use cmp_numerics so that this will agree with the comparison operators,
2818 	 * particularly as regards comparisons involving NaN.
2819 	 */
2820 	if (cmp_numerics(num1, num2) > 0)
2821 		PG_RETURN_NUMERIC(num1);
2822 	else
2823 		PG_RETURN_NUMERIC(num2);
2824 }
2825 
2826 
2827 /* ----------------------------------------------------------------------
2828  *
2829  * Advanced math functions
2830  *
2831  * ----------------------------------------------------------------------
2832  */
2833 
2834 /*
2835  * numeric_gcd() -
2836  *
2837  *	Calculate the greatest common divisor of two numerics
2838  */
2839 Datum
numeric_gcd(PG_FUNCTION_ARGS)2840 numeric_gcd(PG_FUNCTION_ARGS)
2841 {
2842 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2843 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2844 	NumericVar	arg1;
2845 	NumericVar	arg2;
2846 	NumericVar	result;
2847 	Numeric		res;
2848 
2849 	/*
2850 	 * Handle NaN
2851 	 */
2852 	if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
2853 		PG_RETURN_NUMERIC(make_result(&const_nan));
2854 
2855 	/*
2856 	 * Unpack the arguments
2857 	 */
2858 	init_var_from_num(num1, &arg1);
2859 	init_var_from_num(num2, &arg2);
2860 
2861 	init_var(&result);
2862 
2863 	/*
2864 	 * Find the GCD and return the result
2865 	 */
2866 	gcd_var(&arg1, &arg2, &result);
2867 
2868 	res = make_result(&result);
2869 
2870 	free_var(&result);
2871 
2872 	PG_RETURN_NUMERIC(res);
2873 }
2874 
2875 
2876 /*
2877  * numeric_lcm() -
2878  *
2879  *	Calculate the least common multiple of two numerics
2880  */
2881 Datum
numeric_lcm(PG_FUNCTION_ARGS)2882 numeric_lcm(PG_FUNCTION_ARGS)
2883 {
2884 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2885 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2886 	NumericVar	arg1;
2887 	NumericVar	arg2;
2888 	NumericVar	result;
2889 	Numeric		res;
2890 
2891 	/*
2892 	 * Handle NaN
2893 	 */
2894 	if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
2895 		PG_RETURN_NUMERIC(make_result(&const_nan));
2896 
2897 	/*
2898 	 * Unpack the arguments
2899 	 */
2900 	init_var_from_num(num1, &arg1);
2901 	init_var_from_num(num2, &arg2);
2902 
2903 	init_var(&result);
2904 
2905 	/*
2906 	 * Compute the result using lcm(x, y) = abs(x / gcd(x, y) * y), returning
2907 	 * zero if either input is zero.
2908 	 *
2909 	 * Note that the division is guaranteed to be exact, returning an integer
2910 	 * result, so the LCM is an integral multiple of both x and y.  A display
2911 	 * scale of Min(x.dscale, y.dscale) would be sufficient to represent it,
2912 	 * but as with other numeric functions, we choose to return a result whose
2913 	 * display scale is no smaller than either input.
2914 	 */
2915 	if (arg1.ndigits == 0 || arg2.ndigits == 0)
2916 		set_var_from_var(&const_zero, &result);
2917 	else
2918 	{
2919 		gcd_var(&arg1, &arg2, &result);
2920 		div_var(&arg1, &result, &result, 0, false);
2921 		mul_var(&arg2, &result, &result, arg2.dscale);
2922 		result.sign = NUMERIC_POS;
2923 	}
2924 
2925 	result.dscale = Max(arg1.dscale, arg2.dscale);
2926 
2927 	res = make_result(&result);
2928 
2929 	free_var(&result);
2930 
2931 	PG_RETURN_NUMERIC(res);
2932 }
2933 
2934 
2935 /*
2936  * numeric_fac()
2937  *
2938  * Compute factorial
2939  */
2940 Datum
numeric_fac(PG_FUNCTION_ARGS)2941 numeric_fac(PG_FUNCTION_ARGS)
2942 {
2943 	int64		num = PG_GETARG_INT64(0);
2944 	Numeric		res;
2945 	NumericVar	fact;
2946 	NumericVar	result;
2947 
2948 	if (num <= 1)
2949 	{
2950 		res = make_result(&const_one);
2951 		PG_RETURN_NUMERIC(res);
2952 	}
2953 	/* Fail immediately if the result would overflow */
2954 	if (num > 32177)
2955 		ereport(ERROR,
2956 				(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
2957 				 errmsg("value overflows numeric format")));
2958 
2959 	init_var(&fact);
2960 	init_var(&result);
2961 
2962 	int64_to_numericvar(num, &result);
2963 
2964 	for (num = num - 1; num > 1; num--)
2965 	{
2966 		/* this loop can take awhile, so allow it to be interrupted */
2967 		CHECK_FOR_INTERRUPTS();
2968 
2969 		int64_to_numericvar(num, &fact);
2970 
2971 		mul_var(&result, &fact, &result, 0);
2972 	}
2973 
2974 	res = make_result(&result);
2975 
2976 	free_var(&fact);
2977 	free_var(&result);
2978 
2979 	PG_RETURN_NUMERIC(res);
2980 }
2981 
2982 
2983 /*
2984  * numeric_sqrt() -
2985  *
2986  *	Compute the square root of a numeric.
2987  */
2988 Datum
numeric_sqrt(PG_FUNCTION_ARGS)2989 numeric_sqrt(PG_FUNCTION_ARGS)
2990 {
2991 	Numeric		num = PG_GETARG_NUMERIC(0);
2992 	Numeric		res;
2993 	NumericVar	arg;
2994 	NumericVar	result;
2995 	int			sweight;
2996 	int			rscale;
2997 
2998 	/*
2999 	 * Handle NaN
3000 	 */
3001 	if (NUMERIC_IS_NAN(num))
3002 		PG_RETURN_NUMERIC(make_result(&const_nan));
3003 
3004 	/*
3005 	 * Unpack the argument and determine the result scale.  We choose a scale
3006 	 * to give at least NUMERIC_MIN_SIG_DIGITS significant digits; but in any
3007 	 * case not less than the input's dscale.
3008 	 */
3009 	init_var_from_num(num, &arg);
3010 
3011 	init_var(&result);
3012 
3013 	/* Assume the input was normalized, so arg.weight is accurate */
3014 	sweight = (arg.weight + 1) * DEC_DIGITS / 2 - 1;
3015 
3016 	rscale = NUMERIC_MIN_SIG_DIGITS - sweight;
3017 	rscale = Max(rscale, arg.dscale);
3018 	rscale = Max(rscale, NUMERIC_MIN_DISPLAY_SCALE);
3019 	rscale = Min(rscale, NUMERIC_MAX_DISPLAY_SCALE);
3020 
3021 	/*
3022 	 * Let sqrt_var() do the calculation and return the result.
3023 	 */
3024 	sqrt_var(&arg, &result, rscale);
3025 
3026 	res = make_result(&result);
3027 
3028 	free_var(&result);
3029 
3030 	PG_RETURN_NUMERIC(res);
3031 }
3032 
3033 
3034 /*
3035  * numeric_exp() -
3036  *
3037  *	Raise e to the power of x
3038  */
3039 Datum
numeric_exp(PG_FUNCTION_ARGS)3040 numeric_exp(PG_FUNCTION_ARGS)
3041 {
3042 	Numeric		num = PG_GETARG_NUMERIC(0);
3043 	Numeric		res;
3044 	NumericVar	arg;
3045 	NumericVar	result;
3046 	int			rscale;
3047 	double		val;
3048 
3049 	/*
3050 	 * Handle NaN
3051 	 */
3052 	if (NUMERIC_IS_NAN(num))
3053 		PG_RETURN_NUMERIC(make_result(&const_nan));
3054 
3055 	/*
3056 	 * Unpack the argument and determine the result scale.  We choose a scale
3057 	 * to give at least NUMERIC_MIN_SIG_DIGITS significant digits; but in any
3058 	 * case not less than the input's dscale.
3059 	 */
3060 	init_var_from_num(num, &arg);
3061 
3062 	init_var(&result);
3063 
3064 	/* convert input to float8, ignoring overflow */
3065 	val = numericvar_to_double_no_overflow(&arg);
3066 
3067 	/*
3068 	 * log10(result) = num * log10(e), so this is approximately the decimal
3069 	 * weight of the result:
3070 	 */
3071 	val *= 0.434294481903252;
3072 
3073 	/* limit to something that won't cause integer overflow */
3074 	val = Max(val, -NUMERIC_MAX_RESULT_SCALE);
3075 	val = Min(val, NUMERIC_MAX_RESULT_SCALE);
3076 
3077 	rscale = NUMERIC_MIN_SIG_DIGITS - (int) val;
3078 	rscale = Max(rscale, arg.dscale);
3079 	rscale = Max(rscale, NUMERIC_MIN_DISPLAY_SCALE);
3080 	rscale = Min(rscale, NUMERIC_MAX_DISPLAY_SCALE);
3081 
3082 	/*
3083 	 * Let exp_var() do the calculation and return the result.
3084 	 */
3085 	exp_var(&arg, &result, rscale);
3086 
3087 	res = make_result(&result);
3088 
3089 	free_var(&result);
3090 
3091 	PG_RETURN_NUMERIC(res);
3092 }
3093 
3094 
3095 /*
3096  * numeric_ln() -
3097  *
3098  *	Compute the natural logarithm of x
3099  */
3100 Datum
numeric_ln(PG_FUNCTION_ARGS)3101 numeric_ln(PG_FUNCTION_ARGS)
3102 {
3103 	Numeric		num = PG_GETARG_NUMERIC(0);
3104 	Numeric		res;
3105 	NumericVar	arg;
3106 	NumericVar	result;
3107 	int			ln_dweight;
3108 	int			rscale;
3109 
3110 	/*
3111 	 * Handle NaN
3112 	 */
3113 	if (NUMERIC_IS_NAN(num))
3114 		PG_RETURN_NUMERIC(make_result(&const_nan));
3115 
3116 	init_var_from_num(num, &arg);
3117 	init_var(&result);
3118 
3119 	/* Estimated dweight of logarithm */
3120 	ln_dweight = estimate_ln_dweight(&arg);
3121 
3122 	rscale = NUMERIC_MIN_SIG_DIGITS - ln_dweight;
3123 	rscale = Max(rscale, arg.dscale);
3124 	rscale = Max(rscale, NUMERIC_MIN_DISPLAY_SCALE);
3125 	rscale = Min(rscale, NUMERIC_MAX_DISPLAY_SCALE);
3126 
3127 	ln_var(&arg, &result, rscale);
3128 
3129 	res = make_result(&result);
3130 
3131 	free_var(&result);
3132 
3133 	PG_RETURN_NUMERIC(res);
3134 }
3135 
3136 
3137 /*
3138  * numeric_log() -
3139  *
3140  *	Compute the logarithm of x in a given base
3141  */
3142 Datum
numeric_log(PG_FUNCTION_ARGS)3143 numeric_log(PG_FUNCTION_ARGS)
3144 {
3145 	Numeric		num1 = PG_GETARG_NUMERIC(0);
3146 	Numeric		num2 = PG_GETARG_NUMERIC(1);
3147 	Numeric		res;
3148 	NumericVar	arg1;
3149 	NumericVar	arg2;
3150 	NumericVar	result;
3151 
3152 	/*
3153 	 * Handle NaN
3154 	 */
3155 	if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
3156 		PG_RETURN_NUMERIC(make_result(&const_nan));
3157 
3158 	/*
3159 	 * Initialize things
3160 	 */
3161 	init_var_from_num(num1, &arg1);
3162 	init_var_from_num(num2, &arg2);
3163 	init_var(&result);
3164 
3165 	/*
3166 	 * Call log_var() to compute and return the result; note it handles scale
3167 	 * selection itself.
3168 	 */
3169 	log_var(&arg1, &arg2, &result);
3170 
3171 	res = make_result(&result);
3172 
3173 	free_var(&result);
3174 
3175 	PG_RETURN_NUMERIC(res);
3176 }
3177 
3178 
3179 /*
3180  * numeric_power() -
3181  *
3182  *	Raise b to the power of x
3183  */
3184 Datum
numeric_power(PG_FUNCTION_ARGS)3185 numeric_power(PG_FUNCTION_ARGS)
3186 {
3187 	Numeric		num1 = PG_GETARG_NUMERIC(0);
3188 	Numeric		num2 = PG_GETARG_NUMERIC(1);
3189 	Numeric		res;
3190 	NumericVar	arg1;
3191 	NumericVar	arg2;
3192 	NumericVar	arg2_trunc;
3193 	NumericVar	result;
3194 
3195 	/*
3196 	 * Handle NaN cases.  We follow the POSIX spec for pow(3), which says that
3197 	 * NaN ^ 0 = 1, and 1 ^ NaN = 1, while all other cases with NaN inputs
3198 	 * yield NaN (with no error).
3199 	 */
3200 	if (NUMERIC_IS_NAN(num1))
3201 	{
3202 		if (!NUMERIC_IS_NAN(num2))
3203 		{
3204 			init_var_from_num(num2, &arg2);
3205 			if (cmp_var(&arg2, &const_zero) == 0)
3206 				PG_RETURN_NUMERIC(make_result(&const_one));
3207 		}
3208 		PG_RETURN_NUMERIC(make_result(&const_nan));
3209 	}
3210 	if (NUMERIC_IS_NAN(num2))
3211 	{
3212 		init_var_from_num(num1, &arg1);
3213 		if (cmp_var(&arg1, &const_one) == 0)
3214 			PG_RETURN_NUMERIC(make_result(&const_one));
3215 		PG_RETURN_NUMERIC(make_result(&const_nan));
3216 	}
3217 
3218 	/*
3219 	 * Initialize things
3220 	 */
3221 	init_var(&arg2_trunc);
3222 	init_var(&result);
3223 	init_var_from_num(num1, &arg1);
3224 	init_var_from_num(num2, &arg2);
3225 
3226 	set_var_from_var(&arg2, &arg2_trunc);
3227 	trunc_var(&arg2_trunc, 0);
3228 
3229 	/*
3230 	 * The SQL spec requires that we emit a particular SQLSTATE error code for
3231 	 * certain error conditions.  Specifically, we don't return a
3232 	 * divide-by-zero error code for 0 ^ -1.  Raising a negative number to a
3233 	 * non-integer power must produce the same error code, but that case is
3234 	 * handled in power_var().
3235 	 */
3236 	if (cmp_var(&arg1, &const_zero) == 0 &&
3237 		cmp_var(&arg2, &const_zero) < 0)
3238 		ereport(ERROR,
3239 				(errcode(ERRCODE_INVALID_ARGUMENT_FOR_POWER_FUNCTION),
3240 				 errmsg("zero raised to a negative power is undefined")));
3241 
3242 	/*
3243 	 * Call power_var() to compute and return the result; note it handles
3244 	 * scale selection itself.
3245 	 */
3246 	power_var(&arg1, &arg2, &result);
3247 
3248 	res = make_result(&result);
3249 
3250 	free_var(&result);
3251 	free_var(&arg2_trunc);
3252 
3253 	PG_RETURN_NUMERIC(res);
3254 }
3255 
3256 /*
3257  * numeric_scale() -
3258  *
3259  *	Returns the scale, i.e. the count of decimal digits in the fractional part
3260  */
3261 Datum
numeric_scale(PG_FUNCTION_ARGS)3262 numeric_scale(PG_FUNCTION_ARGS)
3263 {
3264 	Numeric		num = PG_GETARG_NUMERIC(0);
3265 
3266 	if (NUMERIC_IS_NAN(num))
3267 		PG_RETURN_NULL();
3268 
3269 	PG_RETURN_INT32(NUMERIC_DSCALE(num));
3270 }
3271 
3272 /*
3273  * Calculate minimum scale for value.
3274  */
3275 static int
get_min_scale(NumericVar * var)3276 get_min_scale(NumericVar *var)
3277 {
3278 	int			min_scale;
3279 	int			last_digit_pos;
3280 
3281 	/*
3282 	 * Ordinarily, the input value will be "stripped" so that the last
3283 	 * NumericDigit is nonzero.  But we don't want to get into an infinite
3284 	 * loop if it isn't, so explicitly find the last nonzero digit.
3285 	 */
3286 	last_digit_pos = var->ndigits - 1;
3287 	while (last_digit_pos >= 0 &&
3288 		   var->digits[last_digit_pos] == 0)
3289 		last_digit_pos--;
3290 
3291 	if (last_digit_pos >= 0)
3292 	{
3293 		/* compute min_scale assuming that last ndigit has no zeroes */
3294 		min_scale = (last_digit_pos - var->weight) * DEC_DIGITS;
3295 
3296 		/*
3297 		 * We could get a negative result if there are no digits after the
3298 		 * decimal point.  In this case the min_scale must be zero.
3299 		 */
3300 		if (min_scale > 0)
3301 		{
3302 			/*
3303 			 * Reduce min_scale if trailing digit(s) in last NumericDigit are
3304 			 * zero.
3305 			 */
3306 			NumericDigit last_digit = var->digits[last_digit_pos];
3307 
3308 			while (last_digit % 10 == 0)
3309 			{
3310 				min_scale--;
3311 				last_digit /= 10;
3312 			}
3313 		}
3314 		else
3315 			min_scale = 0;
3316 	}
3317 	else
3318 		min_scale = 0;			/* result if input is zero */
3319 
3320 	return min_scale;
3321 }
3322 
3323 /*
3324  * Returns minimum scale required to represent supplied value without loss.
3325  */
3326 Datum
numeric_min_scale(PG_FUNCTION_ARGS)3327 numeric_min_scale(PG_FUNCTION_ARGS)
3328 {
3329 	Numeric		num = PG_GETARG_NUMERIC(0);
3330 	NumericVar	arg;
3331 	int			min_scale;
3332 
3333 	if (NUMERIC_IS_NAN(num))
3334 		PG_RETURN_NULL();
3335 
3336 	init_var_from_num(num, &arg);
3337 	min_scale = get_min_scale(&arg);
3338 	free_var(&arg);
3339 
3340 	PG_RETURN_INT32(min_scale);
3341 }
3342 
3343 /*
3344  * Reduce scale of numeric value to represent supplied value without loss.
3345  */
3346 Datum
numeric_trim_scale(PG_FUNCTION_ARGS)3347 numeric_trim_scale(PG_FUNCTION_ARGS)
3348 {
3349 	Numeric		num = PG_GETARG_NUMERIC(0);
3350 	Numeric		res;
3351 	NumericVar	result;
3352 
3353 	if (NUMERIC_IS_NAN(num))
3354 		PG_RETURN_NUMERIC(make_result(&const_nan));
3355 
3356 	init_var_from_num(num, &result);
3357 	result.dscale = get_min_scale(&result);
3358 	res = make_result(&result);
3359 	free_var(&result);
3360 
3361 	PG_RETURN_NUMERIC(res);
3362 }
3363 
3364 
3365 /* ----------------------------------------------------------------------
3366  *
3367  * Type conversion functions
3368  *
3369  * ----------------------------------------------------------------------
3370  */
3371 
3372 
3373 Datum
int4_numeric(PG_FUNCTION_ARGS)3374 int4_numeric(PG_FUNCTION_ARGS)
3375 {
3376 	int32		val = PG_GETARG_INT32(0);
3377 	Numeric		res;
3378 	NumericVar	result;
3379 
3380 	init_var(&result);
3381 
3382 	int64_to_numericvar((int64) val, &result);
3383 
3384 	res = make_result(&result);
3385 
3386 	free_var(&result);
3387 
3388 	PG_RETURN_NUMERIC(res);
3389 }
3390 
3391 int32
numeric_int4_opt_error(Numeric num,bool * have_error)3392 numeric_int4_opt_error(Numeric num, bool *have_error)
3393 {
3394 	NumericVar	x;
3395 	int32		result;
3396 
3397 	if (have_error)
3398 		*have_error = false;
3399 
3400 	/* XXX would it be better to return NULL? */
3401 	if (NUMERIC_IS_NAN(num))
3402 	{
3403 		if (have_error)
3404 		{
3405 			*have_error = true;
3406 			return 0;
3407 		}
3408 		else
3409 		{
3410 			ereport(ERROR,
3411 					(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
3412 					 errmsg("cannot convert NaN to integer")));
3413 		}
3414 	}
3415 
3416 	/* Convert to variable format, then convert to int4 */
3417 	init_var_from_num(num, &x);
3418 
3419 	if (!numericvar_to_int32(&x, &result))
3420 	{
3421 		if (have_error)
3422 		{
3423 			*have_error = true;
3424 			return 0;
3425 		}
3426 		else
3427 		{
3428 			ereport(ERROR,
3429 					(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
3430 					 errmsg("integer out of range")));
3431 		}
3432 	}
3433 
3434 	return result;
3435 }
3436 
3437 Datum
numeric_int4(PG_FUNCTION_ARGS)3438 numeric_int4(PG_FUNCTION_ARGS)
3439 {
3440 	Numeric		num = PG_GETARG_NUMERIC(0);
3441 
3442 	PG_RETURN_INT32(numeric_int4_opt_error(num, NULL));
3443 }
3444 
3445 /*
3446  * Given a NumericVar, convert it to an int32. If the NumericVar
3447  * exceeds the range of an int32, false is returned, otherwise true is returned.
3448  * The input NumericVar is *not* free'd.
3449  */
3450 static bool
numericvar_to_int32(const NumericVar * var,int32 * result)3451 numericvar_to_int32(const NumericVar *var, int32 *result)
3452 {
3453 	int64		val;
3454 
3455 	if (!numericvar_to_int64(var, &val))
3456 		return false;
3457 
3458 	if (unlikely(val < PG_INT32_MIN) || unlikely(val > PG_INT32_MAX))
3459 		return false;
3460 
3461 	/* Down-convert to int4 */
3462 	*result = (int32) val;
3463 
3464 	return true;
3465 }
3466 
3467 Datum
int8_numeric(PG_FUNCTION_ARGS)3468 int8_numeric(PG_FUNCTION_ARGS)
3469 {
3470 	int64		val = PG_GETARG_INT64(0);
3471 	Numeric		res;
3472 	NumericVar	result;
3473 
3474 	init_var(&result);
3475 
3476 	int64_to_numericvar(val, &result);
3477 
3478 	res = make_result(&result);
3479 
3480 	free_var(&result);
3481 
3482 	PG_RETURN_NUMERIC(res);
3483 }
3484 
3485 
3486 Datum
numeric_int8(PG_FUNCTION_ARGS)3487 numeric_int8(PG_FUNCTION_ARGS)
3488 {
3489 	Numeric		num = PG_GETARG_NUMERIC(0);
3490 	NumericVar	x;
3491 	int64		result;
3492 
3493 	/* XXX would it be better to return NULL? */
3494 	if (NUMERIC_IS_NAN(num))
3495 		ereport(ERROR,
3496 				(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
3497 				 errmsg("cannot convert NaN to bigint")));
3498 
3499 	/* Convert to variable format and thence to int8 */
3500 	init_var_from_num(num, &x);
3501 
3502 	if (!numericvar_to_int64(&x, &result))
3503 		ereport(ERROR,
3504 				(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
3505 				 errmsg("bigint out of range")));
3506 
3507 	PG_RETURN_INT64(result);
3508 }
3509 
3510 
3511 Datum
int2_numeric(PG_FUNCTION_ARGS)3512 int2_numeric(PG_FUNCTION_ARGS)
3513 {
3514 	int16		val = PG_GETARG_INT16(0);
3515 	Numeric		res;
3516 	NumericVar	result;
3517 
3518 	init_var(&result);
3519 
3520 	int64_to_numericvar((int64) val, &result);
3521 
3522 	res = make_result(&result);
3523 
3524 	free_var(&result);
3525 
3526 	PG_RETURN_NUMERIC(res);
3527 }
3528 
3529 
3530 Datum
numeric_int2(PG_FUNCTION_ARGS)3531 numeric_int2(PG_FUNCTION_ARGS)
3532 {
3533 	Numeric		num = PG_GETARG_NUMERIC(0);
3534 	NumericVar	x;
3535 	int64		val;
3536 	int16		result;
3537 
3538 	/* XXX would it be better to return NULL? */
3539 	if (NUMERIC_IS_NAN(num))
3540 		ereport(ERROR,
3541 				(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
3542 				 errmsg("cannot convert NaN to smallint")));
3543 
3544 	/* Convert to variable format and thence to int8 */
3545 	init_var_from_num(num, &x);
3546 
3547 	if (!numericvar_to_int64(&x, &val))
3548 		ereport(ERROR,
3549 				(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
3550 				 errmsg("smallint out of range")));
3551 
3552 	if (unlikely(val < PG_INT16_MIN) || unlikely(val > PG_INT16_MAX))
3553 		ereport(ERROR,
3554 				(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
3555 				 errmsg("smallint out of range")));
3556 
3557 	/* Down-convert to int2 */
3558 	result = (int16) val;
3559 
3560 	PG_RETURN_INT16(result);
3561 }
3562 
3563 
3564 Datum
float8_numeric(PG_FUNCTION_ARGS)3565 float8_numeric(PG_FUNCTION_ARGS)
3566 {
3567 	float8		val = PG_GETARG_FLOAT8(0);
3568 	Numeric		res;
3569 	NumericVar	result;
3570 	char		buf[DBL_DIG + 100];
3571 
3572 	if (isnan(val))
3573 		PG_RETURN_NUMERIC(make_result(&const_nan));
3574 
3575 	if (isinf(val))
3576 		ereport(ERROR,
3577 				(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
3578 				 errmsg("cannot convert infinity to numeric")));
3579 
3580 	snprintf(buf, sizeof(buf), "%.*g", DBL_DIG, val);
3581 
3582 	init_var(&result);
3583 
3584 	/* Assume we need not worry about leading/trailing spaces */
3585 	(void) set_var_from_str(buf, buf, &result);
3586 
3587 	res = make_result(&result);
3588 
3589 	free_var(&result);
3590 
3591 	PG_RETURN_NUMERIC(res);
3592 }
3593 
3594 
3595 Datum
numeric_float8(PG_FUNCTION_ARGS)3596 numeric_float8(PG_FUNCTION_ARGS)
3597 {
3598 	Numeric		num = PG_GETARG_NUMERIC(0);
3599 	char	   *tmp;
3600 	Datum		result;
3601 
3602 	if (NUMERIC_IS_NAN(num))
3603 		PG_RETURN_FLOAT8(get_float8_nan());
3604 
3605 	tmp = DatumGetCString(DirectFunctionCall1(numeric_out,
3606 											  NumericGetDatum(num)));
3607 
3608 	result = DirectFunctionCall1(float8in, CStringGetDatum(tmp));
3609 
3610 	pfree(tmp);
3611 
3612 	PG_RETURN_DATUM(result);
3613 }
3614 
3615 
3616 /*
3617  * Convert numeric to float8; if out of range, return +/- HUGE_VAL
3618  *
3619  * (internal helper function, not directly callable from SQL)
3620  */
3621 Datum
numeric_float8_no_overflow(PG_FUNCTION_ARGS)3622 numeric_float8_no_overflow(PG_FUNCTION_ARGS)
3623 {
3624 	Numeric		num = PG_GETARG_NUMERIC(0);
3625 	double		val;
3626 
3627 	if (NUMERIC_IS_NAN(num))
3628 		PG_RETURN_FLOAT8(get_float8_nan());
3629 
3630 	val = numeric_to_double_no_overflow(num);
3631 
3632 	PG_RETURN_FLOAT8(val);
3633 }
3634 
3635 Datum
float4_numeric(PG_FUNCTION_ARGS)3636 float4_numeric(PG_FUNCTION_ARGS)
3637 {
3638 	float4		val = PG_GETARG_FLOAT4(0);
3639 	Numeric		res;
3640 	NumericVar	result;
3641 	char		buf[FLT_DIG + 100];
3642 
3643 	if (isnan(val))
3644 		PG_RETURN_NUMERIC(make_result(&const_nan));
3645 
3646 	if (isinf(val))
3647 		ereport(ERROR,
3648 				(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
3649 				 errmsg("cannot convert infinity to numeric")));
3650 
3651 	snprintf(buf, sizeof(buf), "%.*g", FLT_DIG, val);
3652 
3653 	init_var(&result);
3654 
3655 	/* Assume we need not worry about leading/trailing spaces */
3656 	(void) set_var_from_str(buf, buf, &result);
3657 
3658 	res = make_result(&result);
3659 
3660 	free_var(&result);
3661 
3662 	PG_RETURN_NUMERIC(res);
3663 }
3664 
3665 
3666 Datum
numeric_float4(PG_FUNCTION_ARGS)3667 numeric_float4(PG_FUNCTION_ARGS)
3668 {
3669 	Numeric		num = PG_GETARG_NUMERIC(0);
3670 	char	   *tmp;
3671 	Datum		result;
3672 
3673 	if (NUMERIC_IS_NAN(num))
3674 		PG_RETURN_FLOAT4(get_float4_nan());
3675 
3676 	tmp = DatumGetCString(DirectFunctionCall1(numeric_out,
3677 											  NumericGetDatum(num)));
3678 
3679 	result = DirectFunctionCall1(float4in, CStringGetDatum(tmp));
3680 
3681 	pfree(tmp);
3682 
3683 	PG_RETURN_DATUM(result);
3684 }
3685 
3686 
3687 /* ----------------------------------------------------------------------
3688  *
3689  * Aggregate functions
3690  *
3691  * The transition datatype for all these aggregates is declared as INTERNAL.
3692  * Actually, it's a pointer to a NumericAggState allocated in the aggregate
3693  * context.  The digit buffers for the NumericVars will be there too.
3694  *
3695  * On platforms which support 128-bit integers some aggregates instead use a
3696  * 128-bit integer based transition datatype to speed up calculations.
3697  *
3698  * ----------------------------------------------------------------------
3699  */
3700 
3701 typedef struct NumericAggState
3702 {
3703 	bool		calcSumX2;		/* if true, calculate sumX2 */
3704 	MemoryContext agg_context;	/* context we're calculating in */
3705 	int64		N;				/* count of processed numbers */
3706 	NumericSumAccum sumX;		/* sum of processed numbers */
3707 	NumericSumAccum sumX2;		/* sum of squares of processed numbers */
3708 	int			maxScale;		/* maximum scale seen so far */
3709 	int64		maxScaleCount;	/* number of values seen with maximum scale */
3710 	int64		NaNcount;		/* count of NaN values (not included in N!) */
3711 } NumericAggState;
3712 
3713 /*
3714  * Prepare state data for a numeric aggregate function that needs to compute
3715  * sum, count and optionally sum of squares of the input.
3716  */
3717 static NumericAggState *
makeNumericAggState(FunctionCallInfo fcinfo,bool calcSumX2)3718 makeNumericAggState(FunctionCallInfo fcinfo, bool calcSumX2)
3719 {
3720 	NumericAggState *state;
3721 	MemoryContext agg_context;
3722 	MemoryContext old_context;
3723 
3724 	if (!AggCheckCallContext(fcinfo, &agg_context))
3725 		elog(ERROR, "aggregate function called in non-aggregate context");
3726 
3727 	old_context = MemoryContextSwitchTo(agg_context);
3728 
3729 	state = (NumericAggState *) palloc0(sizeof(NumericAggState));
3730 	state->calcSumX2 = calcSumX2;
3731 	state->agg_context = agg_context;
3732 
3733 	MemoryContextSwitchTo(old_context);
3734 
3735 	return state;
3736 }
3737 
3738 /*
3739  * Like makeNumericAggState(), but allocate the state in the current memory
3740  * context.
3741  */
3742 static NumericAggState *
makeNumericAggStateCurrentContext(bool calcSumX2)3743 makeNumericAggStateCurrentContext(bool calcSumX2)
3744 {
3745 	NumericAggState *state;
3746 
3747 	state = (NumericAggState *) palloc0(sizeof(NumericAggState));
3748 	state->calcSumX2 = calcSumX2;
3749 	state->agg_context = CurrentMemoryContext;
3750 
3751 	return state;
3752 }
3753 
3754 /*
3755  * Accumulate a new input value for numeric aggregate functions.
3756  */
3757 static void
do_numeric_accum(NumericAggState * state,Numeric newval)3758 do_numeric_accum(NumericAggState *state, Numeric newval)
3759 {
3760 	NumericVar	X;
3761 	NumericVar	X2;
3762 	MemoryContext old_context;
3763 
3764 	/* Count NaN inputs separately from all else */
3765 	if (NUMERIC_IS_NAN(newval))
3766 	{
3767 		state->NaNcount++;
3768 		return;
3769 	}
3770 
3771 	/* load processed number in short-lived context */
3772 	init_var_from_num(newval, &X);
3773 
3774 	/*
3775 	 * Track the highest input dscale that we've seen, to support inverse
3776 	 * transitions (see do_numeric_discard).
3777 	 */
3778 	if (X.dscale > state->maxScale)
3779 	{
3780 		state->maxScale = X.dscale;
3781 		state->maxScaleCount = 1;
3782 	}
3783 	else if (X.dscale == state->maxScale)
3784 		state->maxScaleCount++;
3785 
3786 	/* if we need X^2, calculate that in short-lived context */
3787 	if (state->calcSumX2)
3788 	{
3789 		init_var(&X2);
3790 		mul_var(&X, &X, &X2, X.dscale * 2);
3791 	}
3792 
3793 	/* The rest of this needs to work in the aggregate context */
3794 	old_context = MemoryContextSwitchTo(state->agg_context);
3795 
3796 	state->N++;
3797 
3798 	/* Accumulate sums */
3799 	accum_sum_add(&(state->sumX), &X);
3800 
3801 	if (state->calcSumX2)
3802 		accum_sum_add(&(state->sumX2), &X2);
3803 
3804 	MemoryContextSwitchTo(old_context);
3805 }
3806 
3807 /*
3808  * Attempt to remove an input value from the aggregated state.
3809  *
3810  * If the value cannot be removed then the function will return false; the
3811  * possible reasons for failing are described below.
3812  *
3813  * If we aggregate the values 1.01 and 2 then the result will be 3.01.
3814  * If we are then asked to un-aggregate the 1.01 then we must fail as we
3815  * won't be able to tell what the new aggregated value's dscale should be.
3816  * We don't want to return 2.00 (dscale = 2), since the sum's dscale would
3817  * have been zero if we'd really aggregated only 2.
3818  *
3819  * Note: alternatively, we could count the number of inputs with each possible
3820  * dscale (up to some sane limit).  Not yet clear if it's worth the trouble.
3821  */
3822 static bool
do_numeric_discard(NumericAggState * state,Numeric newval)3823 do_numeric_discard(NumericAggState *state, Numeric newval)
3824 {
3825 	NumericVar	X;
3826 	NumericVar	X2;
3827 	MemoryContext old_context;
3828 
3829 	/* Count NaN inputs separately from all else */
3830 	if (NUMERIC_IS_NAN(newval))
3831 	{
3832 		state->NaNcount--;
3833 		return true;
3834 	}
3835 
3836 	/* load processed number in short-lived context */
3837 	init_var_from_num(newval, &X);
3838 
3839 	/*
3840 	 * state->sumX's dscale is the maximum dscale of any of the inputs.
3841 	 * Removing the last input with that dscale would require us to recompute
3842 	 * the maximum dscale of the *remaining* inputs, which we cannot do unless
3843 	 * no more non-NaN inputs remain at all.  So we report a failure instead,
3844 	 * and force the aggregation to be redone from scratch.
3845 	 */
3846 	if (X.dscale == state->maxScale)
3847 	{
3848 		if (state->maxScaleCount > 1 || state->maxScale == 0)
3849 		{
3850 			/*
3851 			 * Some remaining inputs have same dscale, or dscale hasn't gotten
3852 			 * above zero anyway
3853 			 */
3854 			state->maxScaleCount--;
3855 		}
3856 		else if (state->N == 1)
3857 		{
3858 			/* No remaining non-NaN inputs at all, so reset maxScale */
3859 			state->maxScale = 0;
3860 			state->maxScaleCount = 0;
3861 		}
3862 		else
3863 		{
3864 			/* Correct new maxScale is uncertain, must fail */
3865 			return false;
3866 		}
3867 	}
3868 
3869 	/* if we need X^2, calculate that in short-lived context */
3870 	if (state->calcSumX2)
3871 	{
3872 		init_var(&X2);
3873 		mul_var(&X, &X, &X2, X.dscale * 2);
3874 	}
3875 
3876 	/* The rest of this needs to work in the aggregate context */
3877 	old_context = MemoryContextSwitchTo(state->agg_context);
3878 
3879 	if (state->N-- > 1)
3880 	{
3881 		/* Negate X, to subtract it from the sum */
3882 		X.sign = (X.sign == NUMERIC_POS ? NUMERIC_NEG : NUMERIC_POS);
3883 		accum_sum_add(&(state->sumX), &X);
3884 
3885 		if (state->calcSumX2)
3886 		{
3887 			/* Negate X^2. X^2 is always positive */
3888 			X2.sign = NUMERIC_NEG;
3889 			accum_sum_add(&(state->sumX2), &X2);
3890 		}
3891 	}
3892 	else
3893 	{
3894 		/* Zero the sums */
3895 		Assert(state->N == 0);
3896 
3897 		accum_sum_reset(&state->sumX);
3898 		if (state->calcSumX2)
3899 			accum_sum_reset(&state->sumX2);
3900 	}
3901 
3902 	MemoryContextSwitchTo(old_context);
3903 
3904 	return true;
3905 }
3906 
3907 /*
3908  * Generic transition function for numeric aggregates that require sumX2.
3909  */
3910 Datum
numeric_accum(PG_FUNCTION_ARGS)3911 numeric_accum(PG_FUNCTION_ARGS)
3912 {
3913 	NumericAggState *state;
3914 
3915 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
3916 
3917 	/* Create the state data on the first call */
3918 	if (state == NULL)
3919 		state = makeNumericAggState(fcinfo, true);
3920 
3921 	if (!PG_ARGISNULL(1))
3922 		do_numeric_accum(state, PG_GETARG_NUMERIC(1));
3923 
3924 	PG_RETURN_POINTER(state);
3925 }
3926 
3927 /*
3928  * Generic combine function for numeric aggregates which require sumX2
3929  */
3930 Datum
numeric_combine(PG_FUNCTION_ARGS)3931 numeric_combine(PG_FUNCTION_ARGS)
3932 {
3933 	NumericAggState *state1;
3934 	NumericAggState *state2;
3935 	MemoryContext agg_context;
3936 	MemoryContext old_context;
3937 
3938 	if (!AggCheckCallContext(fcinfo, &agg_context))
3939 		elog(ERROR, "aggregate function called in non-aggregate context");
3940 
3941 	state1 = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
3942 	state2 = PG_ARGISNULL(1) ? NULL : (NumericAggState *) PG_GETARG_POINTER(1);
3943 
3944 	if (state2 == NULL)
3945 		PG_RETURN_POINTER(state1);
3946 
3947 	/* manually copy all fields from state2 to state1 */
3948 	if (state1 == NULL)
3949 	{
3950 		old_context = MemoryContextSwitchTo(agg_context);
3951 
3952 		state1 = makeNumericAggStateCurrentContext(true);
3953 		state1->N = state2->N;
3954 		state1->NaNcount = state2->NaNcount;
3955 		state1->maxScale = state2->maxScale;
3956 		state1->maxScaleCount = state2->maxScaleCount;
3957 
3958 		accum_sum_copy(&state1->sumX, &state2->sumX);
3959 		accum_sum_copy(&state1->sumX2, &state2->sumX2);
3960 
3961 		MemoryContextSwitchTo(old_context);
3962 
3963 		PG_RETURN_POINTER(state1);
3964 	}
3965 
3966 	state1->N += state2->N;
3967 	state1->NaNcount += state2->NaNcount;
3968 
3969 	if (state2->N > 0)
3970 	{
3971 		/*
3972 		 * These are currently only needed for moving aggregates, but let's do
3973 		 * the right thing anyway...
3974 		 */
3975 		if (state2->maxScale > state1->maxScale)
3976 		{
3977 			state1->maxScale = state2->maxScale;
3978 			state1->maxScaleCount = state2->maxScaleCount;
3979 		}
3980 		else if (state2->maxScale == state1->maxScale)
3981 			state1->maxScaleCount += state2->maxScaleCount;
3982 
3983 		/* The rest of this needs to work in the aggregate context */
3984 		old_context = MemoryContextSwitchTo(agg_context);
3985 
3986 		/* Accumulate sums */
3987 		accum_sum_combine(&state1->sumX, &state2->sumX);
3988 		accum_sum_combine(&state1->sumX2, &state2->sumX2);
3989 
3990 		MemoryContextSwitchTo(old_context);
3991 	}
3992 	PG_RETURN_POINTER(state1);
3993 }
3994 
3995 /*
3996  * Generic transition function for numeric aggregates that don't require sumX2.
3997  */
3998 Datum
numeric_avg_accum(PG_FUNCTION_ARGS)3999 numeric_avg_accum(PG_FUNCTION_ARGS)
4000 {
4001 	NumericAggState *state;
4002 
4003 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
4004 
4005 	/* Create the state data on the first call */
4006 	if (state == NULL)
4007 		state = makeNumericAggState(fcinfo, false);
4008 
4009 	if (!PG_ARGISNULL(1))
4010 		do_numeric_accum(state, PG_GETARG_NUMERIC(1));
4011 
4012 	PG_RETURN_POINTER(state);
4013 }
4014 
4015 /*
4016  * Combine function for numeric aggregates which don't require sumX2
4017  */
4018 Datum
numeric_avg_combine(PG_FUNCTION_ARGS)4019 numeric_avg_combine(PG_FUNCTION_ARGS)
4020 {
4021 	NumericAggState *state1;
4022 	NumericAggState *state2;
4023 	MemoryContext agg_context;
4024 	MemoryContext old_context;
4025 
4026 	if (!AggCheckCallContext(fcinfo, &agg_context))
4027 		elog(ERROR, "aggregate function called in non-aggregate context");
4028 
4029 	state1 = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
4030 	state2 = PG_ARGISNULL(1) ? NULL : (NumericAggState *) PG_GETARG_POINTER(1);
4031 
4032 	if (state2 == NULL)
4033 		PG_RETURN_POINTER(state1);
4034 
4035 	/* manually copy all fields from state2 to state1 */
4036 	if (state1 == NULL)
4037 	{
4038 		old_context = MemoryContextSwitchTo(agg_context);
4039 
4040 		state1 = makeNumericAggStateCurrentContext(false);
4041 		state1->N = state2->N;
4042 		state1->NaNcount = state2->NaNcount;
4043 		state1->maxScale = state2->maxScale;
4044 		state1->maxScaleCount = state2->maxScaleCount;
4045 
4046 		accum_sum_copy(&state1->sumX, &state2->sumX);
4047 
4048 		MemoryContextSwitchTo(old_context);
4049 
4050 		PG_RETURN_POINTER(state1);
4051 	}
4052 
4053 	state1->N += state2->N;
4054 	state1->NaNcount += state2->NaNcount;
4055 
4056 	if (state2->N > 0)
4057 	{
4058 		/*
4059 		 * These are currently only needed for moving aggregates, but let's do
4060 		 * the right thing anyway...
4061 		 */
4062 		if (state2->maxScale > state1->maxScale)
4063 		{
4064 			state1->maxScale = state2->maxScale;
4065 			state1->maxScaleCount = state2->maxScaleCount;
4066 		}
4067 		else if (state2->maxScale == state1->maxScale)
4068 			state1->maxScaleCount += state2->maxScaleCount;
4069 
4070 		/* The rest of this needs to work in the aggregate context */
4071 		old_context = MemoryContextSwitchTo(agg_context);
4072 
4073 		/* Accumulate sums */
4074 		accum_sum_combine(&state1->sumX, &state2->sumX);
4075 
4076 		MemoryContextSwitchTo(old_context);
4077 	}
4078 	PG_RETURN_POINTER(state1);
4079 }
4080 
4081 /*
4082  * numeric_avg_serialize
4083  *		Serialize NumericAggState for numeric aggregates that don't require
4084  *		sumX2.
4085  */
4086 Datum
numeric_avg_serialize(PG_FUNCTION_ARGS)4087 numeric_avg_serialize(PG_FUNCTION_ARGS)
4088 {
4089 	NumericAggState *state;
4090 	StringInfoData buf;
4091 	Datum		temp;
4092 	bytea	   *sumX;
4093 	bytea	   *result;
4094 	NumericVar	tmp_var;
4095 
4096 	/* Ensure we disallow calling when not in aggregate context */
4097 	if (!AggCheckCallContext(fcinfo, NULL))
4098 		elog(ERROR, "aggregate function called in non-aggregate context");
4099 
4100 	state = (NumericAggState *) PG_GETARG_POINTER(0);
4101 
4102 	/*
4103 	 * This is a little wasteful since make_result converts the NumericVar
4104 	 * into a Numeric and numeric_send converts it back again. Is it worth
4105 	 * splitting the tasks in numeric_send into separate functions to stop
4106 	 * this? Doing so would also remove the fmgr call overhead.
4107 	 */
4108 	init_var(&tmp_var);
4109 	accum_sum_final(&state->sumX, &tmp_var);
4110 
4111 	temp = DirectFunctionCall1(numeric_send,
4112 							   NumericGetDatum(make_result(&tmp_var)));
4113 	sumX = DatumGetByteaPP(temp);
4114 	free_var(&tmp_var);
4115 
4116 	pq_begintypsend(&buf);
4117 
4118 	/* N */
4119 	pq_sendint64(&buf, state->N);
4120 
4121 	/* sumX */
4122 	pq_sendbytes(&buf, VARDATA_ANY(sumX), VARSIZE_ANY_EXHDR(sumX));
4123 
4124 	/* maxScale */
4125 	pq_sendint32(&buf, state->maxScale);
4126 
4127 	/* maxScaleCount */
4128 	pq_sendint64(&buf, state->maxScaleCount);
4129 
4130 	/* NaNcount */
4131 	pq_sendint64(&buf, state->NaNcount);
4132 
4133 	result = pq_endtypsend(&buf);
4134 
4135 	PG_RETURN_BYTEA_P(result);
4136 }
4137 
4138 /*
4139  * numeric_avg_deserialize
4140  *		Deserialize bytea into NumericAggState for numeric aggregates that
4141  *		don't require sumX2.
4142  */
4143 Datum
numeric_avg_deserialize(PG_FUNCTION_ARGS)4144 numeric_avg_deserialize(PG_FUNCTION_ARGS)
4145 {
4146 	bytea	   *sstate;
4147 	NumericAggState *result;
4148 	Datum		temp;
4149 	NumericVar	tmp_var;
4150 	StringInfoData buf;
4151 
4152 	if (!AggCheckCallContext(fcinfo, NULL))
4153 		elog(ERROR, "aggregate function called in non-aggregate context");
4154 
4155 	sstate = PG_GETARG_BYTEA_PP(0);
4156 
4157 	/*
4158 	 * Copy the bytea into a StringInfo so that we can "receive" it using the
4159 	 * standard recv-function infrastructure.
4160 	 */
4161 	initStringInfo(&buf);
4162 	appendBinaryStringInfo(&buf,
4163 						   VARDATA_ANY(sstate), VARSIZE_ANY_EXHDR(sstate));
4164 
4165 	result = makeNumericAggStateCurrentContext(false);
4166 
4167 	/* N */
4168 	result->N = pq_getmsgint64(&buf);
4169 
4170 	/* sumX */
4171 	temp = DirectFunctionCall3(numeric_recv,
4172 							   PointerGetDatum(&buf),
4173 							   ObjectIdGetDatum(InvalidOid),
4174 							   Int32GetDatum(-1));
4175 	init_var_from_num(DatumGetNumeric(temp), &tmp_var);
4176 	accum_sum_add(&(result->sumX), &tmp_var);
4177 
4178 	/* maxScale */
4179 	result->maxScale = pq_getmsgint(&buf, 4);
4180 
4181 	/* maxScaleCount */
4182 	result->maxScaleCount = pq_getmsgint64(&buf);
4183 
4184 	/* NaNcount */
4185 	result->NaNcount = pq_getmsgint64(&buf);
4186 
4187 	pq_getmsgend(&buf);
4188 	pfree(buf.data);
4189 
4190 	PG_RETURN_POINTER(result);
4191 }
4192 
4193 /*
4194  * numeric_serialize
4195  *		Serialization function for NumericAggState for numeric aggregates that
4196  *		require sumX2.
4197  */
4198 Datum
numeric_serialize(PG_FUNCTION_ARGS)4199 numeric_serialize(PG_FUNCTION_ARGS)
4200 {
4201 	NumericAggState *state;
4202 	StringInfoData buf;
4203 	Datum		temp;
4204 	bytea	   *sumX;
4205 	NumericVar	tmp_var;
4206 	bytea	   *sumX2;
4207 	bytea	   *result;
4208 
4209 	/* Ensure we disallow calling when not in aggregate context */
4210 	if (!AggCheckCallContext(fcinfo, NULL))
4211 		elog(ERROR, "aggregate function called in non-aggregate context");
4212 
4213 	state = (NumericAggState *) PG_GETARG_POINTER(0);
4214 
4215 	/*
4216 	 * This is a little wasteful since make_result converts the NumericVar
4217 	 * into a Numeric and numeric_send converts it back again. Is it worth
4218 	 * splitting the tasks in numeric_send into separate functions to stop
4219 	 * this? Doing so would also remove the fmgr call overhead.
4220 	 */
4221 	init_var(&tmp_var);
4222 
4223 	accum_sum_final(&state->sumX, &tmp_var);
4224 	temp = DirectFunctionCall1(numeric_send,
4225 							   NumericGetDatum(make_result(&tmp_var)));
4226 	sumX = DatumGetByteaPP(temp);
4227 
4228 	accum_sum_final(&state->sumX2, &tmp_var);
4229 	temp = DirectFunctionCall1(numeric_send,
4230 							   NumericGetDatum(make_result(&tmp_var)));
4231 	sumX2 = DatumGetByteaPP(temp);
4232 
4233 	free_var(&tmp_var);
4234 
4235 	pq_begintypsend(&buf);
4236 
4237 	/* N */
4238 	pq_sendint64(&buf, state->N);
4239 
4240 	/* sumX */
4241 	pq_sendbytes(&buf, VARDATA_ANY(sumX), VARSIZE_ANY_EXHDR(sumX));
4242 
4243 	/* sumX2 */
4244 	pq_sendbytes(&buf, VARDATA_ANY(sumX2), VARSIZE_ANY_EXHDR(sumX2));
4245 
4246 	/* maxScale */
4247 	pq_sendint32(&buf, state->maxScale);
4248 
4249 	/* maxScaleCount */
4250 	pq_sendint64(&buf, state->maxScaleCount);
4251 
4252 	/* NaNcount */
4253 	pq_sendint64(&buf, state->NaNcount);
4254 
4255 	result = pq_endtypsend(&buf);
4256 
4257 	PG_RETURN_BYTEA_P(result);
4258 }
4259 
4260 /*
4261  * numeric_deserialize
4262  *		Deserialization function for NumericAggState for numeric aggregates that
4263  *		require sumX2.
4264  */
4265 Datum
numeric_deserialize(PG_FUNCTION_ARGS)4266 numeric_deserialize(PG_FUNCTION_ARGS)
4267 {
4268 	bytea	   *sstate;
4269 	NumericAggState *result;
4270 	Datum		temp;
4271 	NumericVar	sumX_var;
4272 	NumericVar	sumX2_var;
4273 	StringInfoData buf;
4274 
4275 	if (!AggCheckCallContext(fcinfo, NULL))
4276 		elog(ERROR, "aggregate function called in non-aggregate context");
4277 
4278 	sstate = PG_GETARG_BYTEA_PP(0);
4279 
4280 	/*
4281 	 * Copy the bytea into a StringInfo so that we can "receive" it using the
4282 	 * standard recv-function infrastructure.
4283 	 */
4284 	initStringInfo(&buf);
4285 	appendBinaryStringInfo(&buf,
4286 						   VARDATA_ANY(sstate), VARSIZE_ANY_EXHDR(sstate));
4287 
4288 	result = makeNumericAggStateCurrentContext(false);
4289 
4290 	/* N */
4291 	result->N = pq_getmsgint64(&buf);
4292 
4293 	/* sumX */
4294 	temp = DirectFunctionCall3(numeric_recv,
4295 							   PointerGetDatum(&buf),
4296 							   ObjectIdGetDatum(InvalidOid),
4297 							   Int32GetDatum(-1));
4298 	init_var_from_num(DatumGetNumeric(temp), &sumX_var);
4299 	accum_sum_add(&(result->sumX), &sumX_var);
4300 
4301 	/* sumX2 */
4302 	temp = DirectFunctionCall3(numeric_recv,
4303 							   PointerGetDatum(&buf),
4304 							   ObjectIdGetDatum(InvalidOid),
4305 							   Int32GetDatum(-1));
4306 	init_var_from_num(DatumGetNumeric(temp), &sumX2_var);
4307 	accum_sum_add(&(result->sumX2), &sumX2_var);
4308 
4309 	/* maxScale */
4310 	result->maxScale = pq_getmsgint(&buf, 4);
4311 
4312 	/* maxScaleCount */
4313 	result->maxScaleCount = pq_getmsgint64(&buf);
4314 
4315 	/* NaNcount */
4316 	result->NaNcount = pq_getmsgint64(&buf);
4317 
4318 	pq_getmsgend(&buf);
4319 	pfree(buf.data);
4320 
4321 	PG_RETURN_POINTER(result);
4322 }
4323 
4324 /*
4325  * Generic inverse transition function for numeric aggregates
4326  * (with or without requirement for X^2).
4327  */
4328 Datum
numeric_accum_inv(PG_FUNCTION_ARGS)4329 numeric_accum_inv(PG_FUNCTION_ARGS)
4330 {
4331 	NumericAggState *state;
4332 
4333 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
4334 
4335 	/* Should not get here with no state */
4336 	if (state == NULL)
4337 		elog(ERROR, "numeric_accum_inv called with NULL state");
4338 
4339 	if (!PG_ARGISNULL(1))
4340 	{
4341 		/* If we fail to perform the inverse transition, return NULL */
4342 		if (!do_numeric_discard(state, PG_GETARG_NUMERIC(1)))
4343 			PG_RETURN_NULL();
4344 	}
4345 
4346 	PG_RETURN_POINTER(state);
4347 }
4348 
4349 
4350 /*
4351  * Integer data types in general use Numeric accumulators to share code
4352  * and avoid risk of overflow.
4353  *
4354  * However for performance reasons optimized special-purpose accumulator
4355  * routines are used when possible.
4356  *
4357  * On platforms with 128-bit integer support, the 128-bit routines will be
4358  * used when sum(X) or sum(X*X) fit into 128-bit.
4359  *
4360  * For 16 and 32 bit inputs, the N and sum(X) fit into 64-bit so the 64-bit
4361  * accumulators will be used for SUM and AVG of these data types.
4362  */
4363 
4364 #ifdef HAVE_INT128
/*
 * Transition state for the integer aggregates that accumulate in 128-bit
 * arithmetic.  Only compiled when HAVE_INT128 is defined.
 */
typedef struct Int128AggState
{
	bool		calcSumX2;		/* if true, calculate sumX2 */
	int64		N;				/* count of processed numbers */
	int128		sumX;			/* sum of processed numbers */
	int128		sumX2;			/* sum of squares of processed numbers; only
								 * maintained when calcSumX2 is true */
} Int128AggState;
4372 
4373 /*
4374  * Prepare state data for a 128-bit aggregate function that needs to compute
4375  * sum, count and optionally sum of squares of the input.
4376  */
4377 static Int128AggState *
makeInt128AggState(FunctionCallInfo fcinfo,bool calcSumX2)4378 makeInt128AggState(FunctionCallInfo fcinfo, bool calcSumX2)
4379 {
4380 	Int128AggState *state;
4381 	MemoryContext agg_context;
4382 	MemoryContext old_context;
4383 
4384 	if (!AggCheckCallContext(fcinfo, &agg_context))
4385 		elog(ERROR, "aggregate function called in non-aggregate context");
4386 
4387 	old_context = MemoryContextSwitchTo(agg_context);
4388 
4389 	state = (Int128AggState *) palloc0(sizeof(Int128AggState));
4390 	state->calcSumX2 = calcSumX2;
4391 
4392 	MemoryContextSwitchTo(old_context);
4393 
4394 	return state;
4395 }
4396 
4397 /*
4398  * Like makeInt128AggState(), but allocate the state in the current memory
4399  * context.
4400  */
4401 static Int128AggState *
makeInt128AggStateCurrentContext(bool calcSumX2)4402 makeInt128AggStateCurrentContext(bool calcSumX2)
4403 {
4404 	Int128AggState *state;
4405 
4406 	state = (Int128AggState *) palloc0(sizeof(Int128AggState));
4407 	state->calcSumX2 = calcSumX2;
4408 
4409 	return state;
4410 }
4411 
4412 /*
4413  * Accumulate a new input value for 128-bit aggregate functions.
4414  */
4415 static void
do_int128_accum(Int128AggState * state,int128 newval)4416 do_int128_accum(Int128AggState *state, int128 newval)
4417 {
4418 	if (state->calcSumX2)
4419 		state->sumX2 += newval * newval;
4420 
4421 	state->sumX += newval;
4422 	state->N++;
4423 }
4424 
4425 /*
4426  * Remove an input value from the aggregated state.
4427  */
4428 static void
do_int128_discard(Int128AggState * state,int128 newval)4429 do_int128_discard(Int128AggState *state, int128 newval)
4430 {
4431 	if (state->calcSumX2)
4432 		state->sumX2 -= newval * newval;
4433 
4434 	state->sumX -= newval;
4435 	state->N--;
4436 }
4437 
/*
 * PolyNumAggState is the transition state type used by the "poly" routines
 * below: Int128AggState when HAVE_INT128 is defined, NumericAggState
 * otherwise.  The makePolyNumAggState* macros map to the matching
 * constructors.
 */
typedef Int128AggState PolyNumAggState;
#define makePolyNumAggState makeInt128AggState
#define makePolyNumAggStateCurrentContext makeInt128AggStateCurrentContext
#else
/* No 128-bit integer type available: use the Numeric-based state instead */
typedef NumericAggState PolyNumAggState;
#define makePolyNumAggState makeNumericAggState
#define makePolyNumAggStateCurrentContext makeNumericAggStateCurrentContext
#endif
4446 
4447 Datum
int2_accum(PG_FUNCTION_ARGS)4448 int2_accum(PG_FUNCTION_ARGS)
4449 {
4450 	PolyNumAggState *state;
4451 
4452 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
4453 
4454 	/* Create the state data on the first call */
4455 	if (state == NULL)
4456 		state = makePolyNumAggState(fcinfo, true);
4457 
4458 	if (!PG_ARGISNULL(1))
4459 	{
4460 #ifdef HAVE_INT128
4461 		do_int128_accum(state, (int128) PG_GETARG_INT16(1));
4462 #else
4463 		Numeric		newval;
4464 
4465 		newval = DatumGetNumeric(DirectFunctionCall1(int2_numeric,
4466 													 PG_GETARG_DATUM(1)));
4467 		do_numeric_accum(state, newval);
4468 #endif
4469 	}
4470 
4471 	PG_RETURN_POINTER(state);
4472 }
4473 
4474 Datum
int4_accum(PG_FUNCTION_ARGS)4475 int4_accum(PG_FUNCTION_ARGS)
4476 {
4477 	PolyNumAggState *state;
4478 
4479 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
4480 
4481 	/* Create the state data on the first call */
4482 	if (state == NULL)
4483 		state = makePolyNumAggState(fcinfo, true);
4484 
4485 	if (!PG_ARGISNULL(1))
4486 	{
4487 #ifdef HAVE_INT128
4488 		do_int128_accum(state, (int128) PG_GETARG_INT32(1));
4489 #else
4490 		Numeric		newval;
4491 
4492 		newval = DatumGetNumeric(DirectFunctionCall1(int4_numeric,
4493 													 PG_GETARG_DATUM(1)));
4494 		do_numeric_accum(state, newval);
4495 #endif
4496 	}
4497 
4498 	PG_RETURN_POINTER(state);
4499 }
4500 
4501 Datum
int8_accum(PG_FUNCTION_ARGS)4502 int8_accum(PG_FUNCTION_ARGS)
4503 {
4504 	NumericAggState *state;
4505 
4506 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
4507 
4508 	/* Create the state data on the first call */
4509 	if (state == NULL)
4510 		state = makeNumericAggState(fcinfo, true);
4511 
4512 	if (!PG_ARGISNULL(1))
4513 	{
4514 		Numeric		newval;
4515 
4516 		newval = DatumGetNumeric(DirectFunctionCall1(int8_numeric,
4517 													 PG_GETARG_DATUM(1)));
4518 		do_numeric_accum(state, newval);
4519 	}
4520 
4521 	PG_RETURN_POINTER(state);
4522 }
4523 
4524 /*
4525  * Combine function for numeric aggregates which require sumX2
4526  */
4527 Datum
numeric_poly_combine(PG_FUNCTION_ARGS)4528 numeric_poly_combine(PG_FUNCTION_ARGS)
4529 {
4530 	PolyNumAggState *state1;
4531 	PolyNumAggState *state2;
4532 	MemoryContext agg_context;
4533 	MemoryContext old_context;
4534 
4535 	if (!AggCheckCallContext(fcinfo, &agg_context))
4536 		elog(ERROR, "aggregate function called in non-aggregate context");
4537 
4538 	state1 = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
4539 	state2 = PG_ARGISNULL(1) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(1);
4540 
4541 	if (state2 == NULL)
4542 		PG_RETURN_POINTER(state1);
4543 
4544 	/* manually copy all fields from state2 to state1 */
4545 	if (state1 == NULL)
4546 	{
4547 		old_context = MemoryContextSwitchTo(agg_context);
4548 
4549 		state1 = makePolyNumAggState(fcinfo, true);
4550 		state1->N = state2->N;
4551 
4552 #ifdef HAVE_INT128
4553 		state1->sumX = state2->sumX;
4554 		state1->sumX2 = state2->sumX2;
4555 #else
4556 		accum_sum_copy(&state1->sumX, &state2->sumX);
4557 		accum_sum_copy(&state1->sumX2, &state2->sumX2);
4558 #endif
4559 
4560 		MemoryContextSwitchTo(old_context);
4561 
4562 		PG_RETURN_POINTER(state1);
4563 	}
4564 
4565 	if (state2->N > 0)
4566 	{
4567 		state1->N += state2->N;
4568 
4569 #ifdef HAVE_INT128
4570 		state1->sumX += state2->sumX;
4571 		state1->sumX2 += state2->sumX2;
4572 #else
4573 		/* The rest of this needs to work in the aggregate context */
4574 		old_context = MemoryContextSwitchTo(agg_context);
4575 
4576 		/* Accumulate sums */
4577 		accum_sum_combine(&state1->sumX, &state2->sumX);
4578 		accum_sum_combine(&state1->sumX2, &state2->sumX2);
4579 
4580 		MemoryContextSwitchTo(old_context);
4581 #endif
4582 
4583 	}
4584 	PG_RETURN_POINTER(state1);
4585 }
4586 
4587 /*
4588  * numeric_poly_serialize
4589  *		Serialize PolyNumAggState into bytea for aggregate functions which
4590  *		require sumX2.
4591  */
4592 Datum
numeric_poly_serialize(PG_FUNCTION_ARGS)4593 numeric_poly_serialize(PG_FUNCTION_ARGS)
4594 {
4595 	PolyNumAggState *state;
4596 	StringInfoData buf;
4597 	bytea	   *sumX;
4598 	bytea	   *sumX2;
4599 	bytea	   *result;
4600 
4601 	/* Ensure we disallow calling when not in aggregate context */
4602 	if (!AggCheckCallContext(fcinfo, NULL))
4603 		elog(ERROR, "aggregate function called in non-aggregate context");
4604 
4605 	state = (PolyNumAggState *) PG_GETARG_POINTER(0);
4606 
4607 	/*
4608 	 * If the platform supports int128 then sumX and sumX2 will be a 128 bit
4609 	 * integer type. Here we'll convert that into a numeric type so that the
4610 	 * combine state is in the same format for both int128 enabled machines
4611 	 * and machines which don't support that type. The logic here is that one
4612 	 * day we might like to send these over to another server for further
4613 	 * processing and we want a standard format to work with.
4614 	 */
4615 	{
4616 		Datum		temp;
4617 		NumericVar	num;
4618 
4619 		init_var(&num);
4620 
4621 #ifdef HAVE_INT128
4622 		int128_to_numericvar(state->sumX, &num);
4623 #else
4624 		accum_sum_final(&state->sumX, &num);
4625 #endif
4626 		temp = DirectFunctionCall1(numeric_send,
4627 								   NumericGetDatum(make_result(&num)));
4628 		sumX = DatumGetByteaPP(temp);
4629 
4630 #ifdef HAVE_INT128
4631 		int128_to_numericvar(state->sumX2, &num);
4632 #else
4633 		accum_sum_final(&state->sumX2, &num);
4634 #endif
4635 		temp = DirectFunctionCall1(numeric_send,
4636 								   NumericGetDatum(make_result(&num)));
4637 		sumX2 = DatumGetByteaPP(temp);
4638 
4639 		free_var(&num);
4640 	}
4641 
4642 	pq_begintypsend(&buf);
4643 
4644 	/* N */
4645 	pq_sendint64(&buf, state->N);
4646 
4647 	/* sumX */
4648 	pq_sendbytes(&buf, VARDATA_ANY(sumX), VARSIZE_ANY_EXHDR(sumX));
4649 
4650 	/* sumX2 */
4651 	pq_sendbytes(&buf, VARDATA_ANY(sumX2), VARSIZE_ANY_EXHDR(sumX2));
4652 
4653 	result = pq_endtypsend(&buf);
4654 
4655 	PG_RETURN_BYTEA_P(result);
4656 }
4657 
4658 /*
4659  * numeric_poly_deserialize
4660  *		Deserialize PolyNumAggState from bytea for aggregate functions which
4661  *		require sumX2.
4662  */
4663 Datum
numeric_poly_deserialize(PG_FUNCTION_ARGS)4664 numeric_poly_deserialize(PG_FUNCTION_ARGS)
4665 {
4666 	bytea	   *sstate;
4667 	PolyNumAggState *result;
4668 	Datum		sumX;
4669 	NumericVar	sumX_var;
4670 	Datum		sumX2;
4671 	NumericVar	sumX2_var;
4672 	StringInfoData buf;
4673 
4674 	if (!AggCheckCallContext(fcinfo, NULL))
4675 		elog(ERROR, "aggregate function called in non-aggregate context");
4676 
4677 	sstate = PG_GETARG_BYTEA_PP(0);
4678 
4679 	/*
4680 	 * Copy the bytea into a StringInfo so that we can "receive" it using the
4681 	 * standard recv-function infrastructure.
4682 	 */
4683 	initStringInfo(&buf);
4684 	appendBinaryStringInfo(&buf,
4685 						   VARDATA_ANY(sstate), VARSIZE_ANY_EXHDR(sstate));
4686 
4687 	result = makePolyNumAggStateCurrentContext(false);
4688 
4689 	/* N */
4690 	result->N = pq_getmsgint64(&buf);
4691 
4692 	/* sumX */
4693 	sumX = DirectFunctionCall3(numeric_recv,
4694 							   PointerGetDatum(&buf),
4695 							   ObjectIdGetDatum(InvalidOid),
4696 							   Int32GetDatum(-1));
4697 
4698 	/* sumX2 */
4699 	sumX2 = DirectFunctionCall3(numeric_recv,
4700 								PointerGetDatum(&buf),
4701 								ObjectIdGetDatum(InvalidOid),
4702 								Int32GetDatum(-1));
4703 
4704 	init_var_from_num(DatumGetNumeric(sumX), &sumX_var);
4705 #ifdef HAVE_INT128
4706 	numericvar_to_int128(&sumX_var, &result->sumX);
4707 #else
4708 	accum_sum_add(&result->sumX, &sumX_var);
4709 #endif
4710 
4711 	init_var_from_num(DatumGetNumeric(sumX2), &sumX2_var);
4712 #ifdef HAVE_INT128
4713 	numericvar_to_int128(&sumX2_var, &result->sumX2);
4714 #else
4715 	accum_sum_add(&result->sumX2, &sumX2_var);
4716 #endif
4717 
4718 	pq_getmsgend(&buf);
4719 	pfree(buf.data);
4720 
4721 	PG_RETURN_POINTER(result);
4722 }
4723 
4724 /*
4725  * Transition function for int8 input when we don't need sumX2.
4726  */
4727 Datum
int8_avg_accum(PG_FUNCTION_ARGS)4728 int8_avg_accum(PG_FUNCTION_ARGS)
4729 {
4730 	PolyNumAggState *state;
4731 
4732 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
4733 
4734 	/* Create the state data on the first call */
4735 	if (state == NULL)
4736 		state = makePolyNumAggState(fcinfo, false);
4737 
4738 	if (!PG_ARGISNULL(1))
4739 	{
4740 #ifdef HAVE_INT128
4741 		do_int128_accum(state, (int128) PG_GETARG_INT64(1));
4742 #else
4743 		Numeric		newval;
4744 
4745 		newval = DatumGetNumeric(DirectFunctionCall1(int8_numeric,
4746 													 PG_GETARG_DATUM(1)));
4747 		do_numeric_accum(state, newval);
4748 #endif
4749 	}
4750 
4751 	PG_RETURN_POINTER(state);
4752 }
4753 
4754 /*
4755  * Combine function for PolyNumAggState for aggregates which don't require
4756  * sumX2
4757  */
4758 Datum
int8_avg_combine(PG_FUNCTION_ARGS)4759 int8_avg_combine(PG_FUNCTION_ARGS)
4760 {
4761 	PolyNumAggState *state1;
4762 	PolyNumAggState *state2;
4763 	MemoryContext agg_context;
4764 	MemoryContext old_context;
4765 
4766 	if (!AggCheckCallContext(fcinfo, &agg_context))
4767 		elog(ERROR, "aggregate function called in non-aggregate context");
4768 
4769 	state1 = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
4770 	state2 = PG_ARGISNULL(1) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(1);
4771 
4772 	if (state2 == NULL)
4773 		PG_RETURN_POINTER(state1);
4774 
4775 	/* manually copy all fields from state2 to state1 */
4776 	if (state1 == NULL)
4777 	{
4778 		old_context = MemoryContextSwitchTo(agg_context);
4779 
4780 		state1 = makePolyNumAggState(fcinfo, false);
4781 		state1->N = state2->N;
4782 
4783 #ifdef HAVE_INT128
4784 		state1->sumX = state2->sumX;
4785 #else
4786 		accum_sum_copy(&state1->sumX, &state2->sumX);
4787 #endif
4788 		MemoryContextSwitchTo(old_context);
4789 
4790 		PG_RETURN_POINTER(state1);
4791 	}
4792 
4793 	if (state2->N > 0)
4794 	{
4795 		state1->N += state2->N;
4796 
4797 #ifdef HAVE_INT128
4798 		state1->sumX += state2->sumX;
4799 #else
4800 		/* The rest of this needs to work in the aggregate context */
4801 		old_context = MemoryContextSwitchTo(agg_context);
4802 
4803 		/* Accumulate sums */
4804 		accum_sum_combine(&state1->sumX, &state2->sumX);
4805 
4806 		MemoryContextSwitchTo(old_context);
4807 #endif
4808 
4809 	}
4810 	PG_RETURN_POINTER(state1);
4811 }
4812 
4813 /*
4814  * int8_avg_serialize
4815  *		Serialize PolyNumAggState into bytea using the standard
4816  *		recv-function infrastructure.
4817  */
4818 Datum
int8_avg_serialize(PG_FUNCTION_ARGS)4819 int8_avg_serialize(PG_FUNCTION_ARGS)
4820 {
4821 	PolyNumAggState *state;
4822 	StringInfoData buf;
4823 	bytea	   *sumX;
4824 	bytea	   *result;
4825 
4826 	/* Ensure we disallow calling when not in aggregate context */
4827 	if (!AggCheckCallContext(fcinfo, NULL))
4828 		elog(ERROR, "aggregate function called in non-aggregate context");
4829 
4830 	state = (PolyNumAggState *) PG_GETARG_POINTER(0);
4831 
4832 	/*
4833 	 * If the platform supports int128 then sumX will be a 128 integer type.
4834 	 * Here we'll convert that into a numeric type so that the combine state
4835 	 * is in the same format for both int128 enabled machines and machines
4836 	 * which don't support that type. The logic here is that one day we might
4837 	 * like to send these over to another server for further processing and we
4838 	 * want a standard format to work with.
4839 	 */
4840 	{
4841 		Datum		temp;
4842 		NumericVar	num;
4843 
4844 		init_var(&num);
4845 
4846 #ifdef HAVE_INT128
4847 		int128_to_numericvar(state->sumX, &num);
4848 #else
4849 		accum_sum_final(&state->sumX, &num);
4850 #endif
4851 		temp = DirectFunctionCall1(numeric_send,
4852 								   NumericGetDatum(make_result(&num)));
4853 		sumX = DatumGetByteaPP(temp);
4854 
4855 		free_var(&num);
4856 	}
4857 
4858 	pq_begintypsend(&buf);
4859 
4860 	/* N */
4861 	pq_sendint64(&buf, state->N);
4862 
4863 	/* sumX */
4864 	pq_sendbytes(&buf, VARDATA_ANY(sumX), VARSIZE_ANY_EXHDR(sumX));
4865 
4866 	result = pq_endtypsend(&buf);
4867 
4868 	PG_RETURN_BYTEA_P(result);
4869 }
4870 
4871 /*
4872  * int8_avg_deserialize
4873  *		Deserialize bytea back into PolyNumAggState.
4874  */
4875 Datum
int8_avg_deserialize(PG_FUNCTION_ARGS)4876 int8_avg_deserialize(PG_FUNCTION_ARGS)
4877 {
4878 	bytea	   *sstate;
4879 	PolyNumAggState *result;
4880 	StringInfoData buf;
4881 	Datum		temp;
4882 	NumericVar	num;
4883 
4884 	if (!AggCheckCallContext(fcinfo, NULL))
4885 		elog(ERROR, "aggregate function called in non-aggregate context");
4886 
4887 	sstate = PG_GETARG_BYTEA_PP(0);
4888 
4889 	/*
4890 	 * Copy the bytea into a StringInfo so that we can "receive" it using the
4891 	 * standard recv-function infrastructure.
4892 	 */
4893 	initStringInfo(&buf);
4894 	appendBinaryStringInfo(&buf,
4895 						   VARDATA_ANY(sstate), VARSIZE_ANY_EXHDR(sstate));
4896 
4897 	result = makePolyNumAggStateCurrentContext(false);
4898 
4899 	/* N */
4900 	result->N = pq_getmsgint64(&buf);
4901 
4902 	/* sumX */
4903 	temp = DirectFunctionCall3(numeric_recv,
4904 							   PointerGetDatum(&buf),
4905 							   ObjectIdGetDatum(InvalidOid),
4906 							   Int32GetDatum(-1));
4907 	init_var_from_num(DatumGetNumeric(temp), &num);
4908 #ifdef HAVE_INT128
4909 	numericvar_to_int128(&num, &result->sumX);
4910 #else
4911 	accum_sum_add(&result->sumX, &num);
4912 #endif
4913 
4914 	pq_getmsgend(&buf);
4915 	pfree(buf.data);
4916 
4917 	PG_RETURN_POINTER(result);
4918 }
4919 
4920 /*
4921  * Inverse transition functions to go with the above.
4922  */
4923 
4924 Datum
int2_accum_inv(PG_FUNCTION_ARGS)4925 int2_accum_inv(PG_FUNCTION_ARGS)
4926 {
4927 	PolyNumAggState *state;
4928 
4929 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
4930 
4931 	/* Should not get here with no state */
4932 	if (state == NULL)
4933 		elog(ERROR, "int2_accum_inv called with NULL state");
4934 
4935 	if (!PG_ARGISNULL(1))
4936 	{
4937 #ifdef HAVE_INT128
4938 		do_int128_discard(state, (int128) PG_GETARG_INT16(1));
4939 #else
4940 		Numeric		newval;
4941 
4942 		newval = DatumGetNumeric(DirectFunctionCall1(int2_numeric,
4943 													 PG_GETARG_DATUM(1)));
4944 
4945 		/* Should never fail, all inputs have dscale 0 */
4946 		if (!do_numeric_discard(state, newval))
4947 			elog(ERROR, "do_numeric_discard failed unexpectedly");
4948 #endif
4949 	}
4950 
4951 	PG_RETURN_POINTER(state);
4952 }
4953 
4954 Datum
int4_accum_inv(PG_FUNCTION_ARGS)4955 int4_accum_inv(PG_FUNCTION_ARGS)
4956 {
4957 	PolyNumAggState *state;
4958 
4959 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
4960 
4961 	/* Should not get here with no state */
4962 	if (state == NULL)
4963 		elog(ERROR, "int4_accum_inv called with NULL state");
4964 
4965 	if (!PG_ARGISNULL(1))
4966 	{
4967 #ifdef HAVE_INT128
4968 		do_int128_discard(state, (int128) PG_GETARG_INT32(1));
4969 #else
4970 		Numeric		newval;
4971 
4972 		newval = DatumGetNumeric(DirectFunctionCall1(int4_numeric,
4973 													 PG_GETARG_DATUM(1)));
4974 
4975 		/* Should never fail, all inputs have dscale 0 */
4976 		if (!do_numeric_discard(state, newval))
4977 			elog(ERROR, "do_numeric_discard failed unexpectedly");
4978 #endif
4979 	}
4980 
4981 	PG_RETURN_POINTER(state);
4982 }
4983 
4984 Datum
int8_accum_inv(PG_FUNCTION_ARGS)4985 int8_accum_inv(PG_FUNCTION_ARGS)
4986 {
4987 	NumericAggState *state;
4988 
4989 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
4990 
4991 	/* Should not get here with no state */
4992 	if (state == NULL)
4993 		elog(ERROR, "int8_accum_inv called with NULL state");
4994 
4995 	if (!PG_ARGISNULL(1))
4996 	{
4997 		Numeric		newval;
4998 
4999 		newval = DatumGetNumeric(DirectFunctionCall1(int8_numeric,
5000 													 PG_GETARG_DATUM(1)));
5001 
5002 		/* Should never fail, all inputs have dscale 0 */
5003 		if (!do_numeric_discard(state, newval))
5004 			elog(ERROR, "do_numeric_discard failed unexpectedly");
5005 	}
5006 
5007 	PG_RETURN_POINTER(state);
5008 }
5009 
5010 Datum
int8_avg_accum_inv(PG_FUNCTION_ARGS)5011 int8_avg_accum_inv(PG_FUNCTION_ARGS)
5012 {
5013 	PolyNumAggState *state;
5014 
5015 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
5016 
5017 	/* Should not get here with no state */
5018 	if (state == NULL)
5019 		elog(ERROR, "int8_avg_accum_inv called with NULL state");
5020 
5021 	if (!PG_ARGISNULL(1))
5022 	{
5023 #ifdef HAVE_INT128
5024 		do_int128_discard(state, (int128) PG_GETARG_INT64(1));
5025 #else
5026 		Numeric		newval;
5027 
5028 		newval = DatumGetNumeric(DirectFunctionCall1(int8_numeric,
5029 													 PG_GETARG_DATUM(1)));
5030 
5031 		/* Should never fail, all inputs have dscale 0 */
5032 		if (!do_numeric_discard(state, newval))
5033 			elog(ERROR, "do_numeric_discard failed unexpectedly");
5034 #endif
5035 	}
5036 
5037 	PG_RETURN_POINTER(state);
5038 }
5039 
5040 Datum
numeric_poly_sum(PG_FUNCTION_ARGS)5041 numeric_poly_sum(PG_FUNCTION_ARGS)
5042 {
5043 #ifdef HAVE_INT128
5044 	PolyNumAggState *state;
5045 	Numeric		res;
5046 	NumericVar	result;
5047 
5048 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
5049 
5050 	/* If there were no non-null inputs, return NULL */
5051 	if (state == NULL || state->N == 0)
5052 		PG_RETURN_NULL();
5053 
5054 	init_var(&result);
5055 
5056 	int128_to_numericvar(state->sumX, &result);
5057 
5058 	res = make_result(&result);
5059 
5060 	free_var(&result);
5061 
5062 	PG_RETURN_NUMERIC(res);
5063 #else
5064 	return numeric_sum(fcinfo);
5065 #endif
5066 }
5067 
5068 Datum
numeric_poly_avg(PG_FUNCTION_ARGS)5069 numeric_poly_avg(PG_FUNCTION_ARGS)
5070 {
5071 #ifdef HAVE_INT128
5072 	PolyNumAggState *state;
5073 	NumericVar	result;
5074 	Datum		countd,
5075 				sumd;
5076 
5077 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
5078 
5079 	/* If there were no non-null inputs, return NULL */
5080 	if (state == NULL || state->N == 0)
5081 		PG_RETURN_NULL();
5082 
5083 	init_var(&result);
5084 
5085 	int128_to_numericvar(state->sumX, &result);
5086 
5087 	countd = DirectFunctionCall1(int8_numeric,
5088 								 Int64GetDatumFast(state->N));
5089 	sumd = NumericGetDatum(make_result(&result));
5090 
5091 	free_var(&result);
5092 
5093 	PG_RETURN_DATUM(DirectFunctionCall2(numeric_div, sumd, countd));
5094 #else
5095 	return numeric_avg(fcinfo);
5096 #endif
5097 }
5098 
5099 Datum
numeric_avg(PG_FUNCTION_ARGS)5100 numeric_avg(PG_FUNCTION_ARGS)
5101 {
5102 	NumericAggState *state;
5103 	Datum		N_datum;
5104 	Datum		sumX_datum;
5105 	NumericVar	sumX_var;
5106 
5107 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
5108 
5109 	/* If there were no non-null inputs, return NULL */
5110 	if (state == NULL || (state->N + state->NaNcount) == 0)
5111 		PG_RETURN_NULL();
5112 
5113 	if (state->NaNcount > 0)	/* there was at least one NaN input */
5114 		PG_RETURN_NUMERIC(make_result(&const_nan));
5115 
5116 	N_datum = DirectFunctionCall1(int8_numeric, Int64GetDatum(state->N));
5117 
5118 	init_var(&sumX_var);
5119 	accum_sum_final(&state->sumX, &sumX_var);
5120 	sumX_datum = NumericGetDatum(make_result(&sumX_var));
5121 	free_var(&sumX_var);
5122 
5123 	PG_RETURN_DATUM(DirectFunctionCall2(numeric_div, sumX_datum, N_datum));
5124 }
5125 
5126 Datum
numeric_sum(PG_FUNCTION_ARGS)5127 numeric_sum(PG_FUNCTION_ARGS)
5128 {
5129 	NumericAggState *state;
5130 	NumericVar	sumX_var;
5131 	Numeric		result;
5132 
5133 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
5134 
5135 	/* If there were no non-null inputs, return NULL */
5136 	if (state == NULL || (state->N + state->NaNcount) == 0)
5137 		PG_RETURN_NULL();
5138 
5139 	if (state->NaNcount > 0)	/* there was at least one NaN input */
5140 		PG_RETURN_NUMERIC(make_result(&const_nan));
5141 
5142 	init_var(&sumX_var);
5143 	accum_sum_final(&state->sumX, &sumX_var);
5144 	result = make_result(&sumX_var);
5145 	free_var(&sumX_var);
5146 
5147 	PG_RETURN_NUMERIC(result);
5148 }
5149 
5150 /*
5151  * Workhorse routine for the standard deviance and variance
5152  * aggregates. 'state' is aggregate's transition state.
5153  * 'variance' specifies whether we should calculate the
5154  * variance or the standard deviation. 'sample' indicates whether the
5155  * caller is interested in the sample or the population
5156  * variance/stddev.
5157  *
5158  * If appropriate variance statistic is undefined for the input,
5159  * *is_null is set to true and NULL is returned.
5160  */
5161 static Numeric
numeric_stddev_internal(NumericAggState * state,bool variance,bool sample,bool * is_null)5162 numeric_stddev_internal(NumericAggState *state,
5163 						bool variance, bool sample,
5164 						bool *is_null)
5165 {
5166 	Numeric		res;
5167 	NumericVar	vN,
5168 				vsumX,
5169 				vsumX2,
5170 				vNminus1;
5171 	const NumericVar *comp;
5172 	int			rscale;
5173 
5174 	/* Deal with empty input and NaN-input cases */
5175 	if (state == NULL || (state->N + state->NaNcount) == 0)
5176 	{
5177 		*is_null = true;
5178 		return NULL;
5179 	}
5180 
5181 	*is_null = false;
5182 
5183 	if (state->NaNcount > 0)
5184 		return make_result(&const_nan);
5185 
5186 	init_var(&vN);
5187 	init_var(&vsumX);
5188 	init_var(&vsumX2);
5189 
5190 	int64_to_numericvar(state->N, &vN);
5191 	accum_sum_final(&(state->sumX), &vsumX);
5192 	accum_sum_final(&(state->sumX2), &vsumX2);
5193 
5194 	/*
5195 	 * Sample stddev and variance are undefined when N <= 1; population stddev
5196 	 * is undefined when N == 0. Return NULL in either case.
5197 	 */
5198 	if (sample)
5199 		comp = &const_one;
5200 	else
5201 		comp = &const_zero;
5202 
5203 	if (cmp_var(&vN, comp) <= 0)
5204 	{
5205 		*is_null = true;
5206 		return NULL;
5207 	}
5208 
5209 	init_var(&vNminus1);
5210 	sub_var(&vN, &const_one, &vNminus1);
5211 
5212 	/* compute rscale for mul_var calls */
5213 	rscale = vsumX.dscale * 2;
5214 
5215 	mul_var(&vsumX, &vsumX, &vsumX, rscale);	/* vsumX = sumX * sumX */
5216 	mul_var(&vN, &vsumX2, &vsumX2, rscale); /* vsumX2 = N * sumX2 */
5217 	sub_var(&vsumX2, &vsumX, &vsumX2);	/* N * sumX2 - sumX * sumX */
5218 
5219 	if (cmp_var(&vsumX2, &const_zero) <= 0)
5220 	{
5221 		/* Watch out for roundoff error producing a negative numerator */
5222 		res = make_result(&const_zero);
5223 	}
5224 	else
5225 	{
5226 		if (sample)
5227 			mul_var(&vN, &vNminus1, &vNminus1, 0);	/* N * (N - 1) */
5228 		else
5229 			mul_var(&vN, &vN, &vNminus1, 0);	/* N * N */
5230 		rscale = select_div_scale(&vsumX2, &vNminus1);
5231 		div_var(&vsumX2, &vNminus1, &vsumX, rscale, true);	/* variance */
5232 		if (!variance)
5233 			sqrt_var(&vsumX, &vsumX, rscale);	/* stddev */
5234 
5235 		res = make_result(&vsumX);
5236 	}
5237 
5238 	free_var(&vNminus1);
5239 	free_var(&vsumX);
5240 	free_var(&vsumX2);
5241 
5242 	return res;
5243 }
5244 
5245 Datum
numeric_var_samp(PG_FUNCTION_ARGS)5246 numeric_var_samp(PG_FUNCTION_ARGS)
5247 {
5248 	NumericAggState *state;
5249 	Numeric		res;
5250 	bool		is_null;
5251 
5252 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
5253 
5254 	res = numeric_stddev_internal(state, true, true, &is_null);
5255 
5256 	if (is_null)
5257 		PG_RETURN_NULL();
5258 	else
5259 		PG_RETURN_NUMERIC(res);
5260 }
5261 
5262 Datum
numeric_stddev_samp(PG_FUNCTION_ARGS)5263 numeric_stddev_samp(PG_FUNCTION_ARGS)
5264 {
5265 	NumericAggState *state;
5266 	Numeric		res;
5267 	bool		is_null;
5268 
5269 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
5270 
5271 	res = numeric_stddev_internal(state, false, true, &is_null);
5272 
5273 	if (is_null)
5274 		PG_RETURN_NULL();
5275 	else
5276 		PG_RETURN_NUMERIC(res);
5277 }
5278 
5279 Datum
numeric_var_pop(PG_FUNCTION_ARGS)5280 numeric_var_pop(PG_FUNCTION_ARGS)
5281 {
5282 	NumericAggState *state;
5283 	Numeric		res;
5284 	bool		is_null;
5285 
5286 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
5287 
5288 	res = numeric_stddev_internal(state, true, false, &is_null);
5289 
5290 	if (is_null)
5291 		PG_RETURN_NULL();
5292 	else
5293 		PG_RETURN_NUMERIC(res);
5294 }
5295 
5296 Datum
numeric_stddev_pop(PG_FUNCTION_ARGS)5297 numeric_stddev_pop(PG_FUNCTION_ARGS)
5298 {
5299 	NumericAggState *state;
5300 	Numeric		res;
5301 	bool		is_null;
5302 
5303 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
5304 
5305 	res = numeric_stddev_internal(state, false, false, &is_null);
5306 
5307 	if (is_null)
5308 		PG_RETURN_NULL();
5309 	else
5310 		PG_RETURN_NUMERIC(res);
5311 }
5312 
#ifdef HAVE_INT128
/*
 * numeric_poly_stddev_internal() -
 *
 *	Variance/stddev worker for the int128-based aggregate state.  The
 *	int128 sums are converted into a temporary, zero-initialized
 *	NumericAggState, which is then handed to numeric_stddev_internal()
 *	together with the caller's variance/sample/is_null arguments.
 *
 *	A NULL "state" leaves the temporary state all-zero.
 */
static Numeric
numeric_poly_stddev_internal(Int128AggState *state,
							 bool variance, bool sample,
							 bool *is_null)
{
	NumericAggState numstate;
	Numeric		result;

	/* Begin with an all-zero numeric aggregate state */
	memset(&numstate, 0, sizeof(NumericAggState));

	if (state != NULL)
	{
		NumericVar	scratch;

		init_var(&scratch);

		numstate.N = state->N;

		/* sum(X) */
		int128_to_numericvar(state->sumX, &scratch);
		accum_sum_add(&numstate.sumX, &scratch);

		/* sum(X*X) */
		int128_to_numericvar(state->sumX2, &scratch);
		accum_sum_add(&numstate.sumX2, &scratch);

		free_var(&scratch);
	}

	result = numeric_stddev_internal(&numstate, variance, sample, is_null);

	/* Release the accumulators' digit arrays if they hold any digits */
	if (numstate.sumX.ndigits > 0)
	{
		pfree(numstate.sumX.pos_digits);
		pfree(numstate.sumX.neg_digits);
	}
	if (numstate.sumX2.ndigits > 0)
	{
		pfree(numstate.sumX2.pos_digits);
		pfree(numstate.sumX2.neg_digits);
	}

	return result;
}
#endif
5358 
5359 Datum
numeric_poly_var_samp(PG_FUNCTION_ARGS)5360 numeric_poly_var_samp(PG_FUNCTION_ARGS)
5361 {
5362 #ifdef HAVE_INT128
5363 	PolyNumAggState *state;
5364 	Numeric		res;
5365 	bool		is_null;
5366 
5367 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
5368 
5369 	res = numeric_poly_stddev_internal(state, true, true, &is_null);
5370 
5371 	if (is_null)
5372 		PG_RETURN_NULL();
5373 	else
5374 		PG_RETURN_NUMERIC(res);
5375 #else
5376 	return numeric_var_samp(fcinfo);
5377 #endif
5378 }
5379 
5380 Datum
numeric_poly_stddev_samp(PG_FUNCTION_ARGS)5381 numeric_poly_stddev_samp(PG_FUNCTION_ARGS)
5382 {
5383 #ifdef HAVE_INT128
5384 	PolyNumAggState *state;
5385 	Numeric		res;
5386 	bool		is_null;
5387 
5388 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
5389 
5390 	res = numeric_poly_stddev_internal(state, false, true, &is_null);
5391 
5392 	if (is_null)
5393 		PG_RETURN_NULL();
5394 	else
5395 		PG_RETURN_NUMERIC(res);
5396 #else
5397 	return numeric_stddev_samp(fcinfo);
5398 #endif
5399 }
5400 
5401 Datum
numeric_poly_var_pop(PG_FUNCTION_ARGS)5402 numeric_poly_var_pop(PG_FUNCTION_ARGS)
5403 {
5404 #ifdef HAVE_INT128
5405 	PolyNumAggState *state;
5406 	Numeric		res;
5407 	bool		is_null;
5408 
5409 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
5410 
5411 	res = numeric_poly_stddev_internal(state, true, false, &is_null);
5412 
5413 	if (is_null)
5414 		PG_RETURN_NULL();
5415 	else
5416 		PG_RETURN_NUMERIC(res);
5417 #else
5418 	return numeric_var_pop(fcinfo);
5419 #endif
5420 }
5421 
5422 Datum
numeric_poly_stddev_pop(PG_FUNCTION_ARGS)5423 numeric_poly_stddev_pop(PG_FUNCTION_ARGS)
5424 {
5425 #ifdef HAVE_INT128
5426 	PolyNumAggState *state;
5427 	Numeric		res;
5428 	bool		is_null;
5429 
5430 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
5431 
5432 	res = numeric_poly_stddev_internal(state, false, false, &is_null);
5433 
5434 	if (is_null)
5435 		PG_RETURN_NULL();
5436 	else
5437 		PG_RETURN_NUMERIC(res);
5438 #else
5439 	return numeric_stddev_pop(fcinfo);
5440 #endif
5441 }
5442 
5443 /*
5444  * SUM transition functions for integer datatypes.
5445  *
5446  * To avoid overflow, we use accumulators wider than the input datatype.
5447  * A Numeric accumulator is needed for int8 input; for int4 and int2
5448  * inputs, we use int8 accumulators which should be sufficient for practical
5449  * purposes.  (The latter two therefore don't really belong in this file,
5450  * but we keep them here anyway.)
5451  *
5452  * Because SQL defines the SUM() of no values to be NULL, not zero,
5453  * the initial condition of the transition data value needs to be NULL. This
5454  * means we can't rely on ExecAgg to automatically insert the first non-null
5455  * data value into the transition data: it doesn't know how to do the type
5456  * conversion.  The upshot is that these routines have to be marked non-strict
5457  * and handle substitution of the first non-null input themselves.
5458  *
5459  * Note: these functions are used only in plain aggregation mode.
5460  * In moving-aggregate mode, we use intX_avg_accum and intX_avg_accum_inv.
5461  */
5462 
5463 Datum
int2_sum(PG_FUNCTION_ARGS)5464 int2_sum(PG_FUNCTION_ARGS)
5465 {
5466 	int64		newval;
5467 
5468 	if (PG_ARGISNULL(0))
5469 	{
5470 		/* No non-null input seen so far... */
5471 		if (PG_ARGISNULL(1))
5472 			PG_RETURN_NULL();	/* still no non-null */
5473 		/* This is the first non-null input. */
5474 		newval = (int64) PG_GETARG_INT16(1);
5475 		PG_RETURN_INT64(newval);
5476 	}
5477 
5478 	/*
5479 	 * If we're invoked as an aggregate, we can cheat and modify our first
5480 	 * parameter in-place to avoid palloc overhead. If not, we need to return
5481 	 * the new value of the transition variable. (If int8 is pass-by-value,
5482 	 * then of course this is useless as well as incorrect, so just ifdef it
5483 	 * out.)
5484 	 */
5485 #ifndef USE_FLOAT8_BYVAL		/* controls int8 too */
5486 	if (AggCheckCallContext(fcinfo, NULL))
5487 	{
5488 		int64	   *oldsum = (int64 *) PG_GETARG_POINTER(0);
5489 
5490 		/* Leave the running sum unchanged in the new input is null */
5491 		if (!PG_ARGISNULL(1))
5492 			*oldsum = *oldsum + (int64) PG_GETARG_INT16(1);
5493 
5494 		PG_RETURN_POINTER(oldsum);
5495 	}
5496 	else
5497 #endif
5498 	{
5499 		int64		oldsum = PG_GETARG_INT64(0);
5500 
5501 		/* Leave sum unchanged if new input is null. */
5502 		if (PG_ARGISNULL(1))
5503 			PG_RETURN_INT64(oldsum);
5504 
5505 		/* OK to do the addition. */
5506 		newval = oldsum + (int64) PG_GETARG_INT16(1);
5507 
5508 		PG_RETURN_INT64(newval);
5509 	}
5510 }
5511 
5512 Datum
int4_sum(PG_FUNCTION_ARGS)5513 int4_sum(PG_FUNCTION_ARGS)
5514 {
5515 	int64		newval;
5516 
5517 	if (PG_ARGISNULL(0))
5518 	{
5519 		/* No non-null input seen so far... */
5520 		if (PG_ARGISNULL(1))
5521 			PG_RETURN_NULL();	/* still no non-null */
5522 		/* This is the first non-null input. */
5523 		newval = (int64) PG_GETARG_INT32(1);
5524 		PG_RETURN_INT64(newval);
5525 	}
5526 
5527 	/*
5528 	 * If we're invoked as an aggregate, we can cheat and modify our first
5529 	 * parameter in-place to avoid palloc overhead. If not, we need to return
5530 	 * the new value of the transition variable. (If int8 is pass-by-value,
5531 	 * then of course this is useless as well as incorrect, so just ifdef it
5532 	 * out.)
5533 	 */
5534 #ifndef USE_FLOAT8_BYVAL		/* controls int8 too */
5535 	if (AggCheckCallContext(fcinfo, NULL))
5536 	{
5537 		int64	   *oldsum = (int64 *) PG_GETARG_POINTER(0);
5538 
5539 		/* Leave the running sum unchanged in the new input is null */
5540 		if (!PG_ARGISNULL(1))
5541 			*oldsum = *oldsum + (int64) PG_GETARG_INT32(1);
5542 
5543 		PG_RETURN_POINTER(oldsum);
5544 	}
5545 	else
5546 #endif
5547 	{
5548 		int64		oldsum = PG_GETARG_INT64(0);
5549 
5550 		/* Leave sum unchanged if new input is null. */
5551 		if (PG_ARGISNULL(1))
5552 			PG_RETURN_INT64(oldsum);
5553 
5554 		/* OK to do the addition. */
5555 		newval = oldsum + (int64) PG_GETARG_INT32(1);
5556 
5557 		PG_RETURN_INT64(newval);
5558 	}
5559 }
5560 
5561 /*
5562  * Note: this function is obsolete, it's no longer used for SUM(int8).
5563  */
5564 Datum
int8_sum(PG_FUNCTION_ARGS)5565 int8_sum(PG_FUNCTION_ARGS)
5566 {
5567 	Numeric		oldsum;
5568 	Datum		newval;
5569 
5570 	if (PG_ARGISNULL(0))
5571 	{
5572 		/* No non-null input seen so far... */
5573 		if (PG_ARGISNULL(1))
5574 			PG_RETURN_NULL();	/* still no non-null */
5575 		/* This is the first non-null input. */
5576 		newval = DirectFunctionCall1(int8_numeric, PG_GETARG_DATUM(1));
5577 		PG_RETURN_DATUM(newval);
5578 	}
5579 
5580 	/*
5581 	 * Note that we cannot special-case the aggregate case here, as we do for
5582 	 * int2_sum and int4_sum: numeric is of variable size, so we cannot modify
5583 	 * our first parameter in-place.
5584 	 */
5585 
5586 	oldsum = PG_GETARG_NUMERIC(0);
5587 
5588 	/* Leave sum unchanged if new input is null. */
5589 	if (PG_ARGISNULL(1))
5590 		PG_RETURN_NUMERIC(oldsum);
5591 
5592 	/* OK to do the addition. */
5593 	newval = DirectFunctionCall1(int8_numeric, PG_GETARG_DATUM(1));
5594 
5595 	PG_RETURN_DATUM(DirectFunctionCall2(numeric_add,
5596 										NumericGetDatum(oldsum), newval));
5597 }
5598 
5599 
/*
 * Routines for avg(int2) and avg(int4).  The transition datatype
 * is a two-element int8 array, holding count and sum.
 *
 * These functions are also used for sum(int2) and sum(int4) when
 * operating in moving-aggregate mode, since for correct inverse transitions
 * we need to count the inputs.
 *
 * Int8TransTypeData is the view the routines below lay over the array's
 * data area (via ARR_DATA_PTR); they verify beforehand that the array has
 * no nulls and is exactly ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData)
 * bytes long.
 */

typedef struct Int8TransTypeData
{
	int64		count;			/* inputs accumulated: +1 per forward
								 * transition, -1 per inverse transition */
	int64		sum;			/* running sum of the accumulated inputs */
} Int8TransTypeData;
5614 
5615 Datum
int2_avg_accum(PG_FUNCTION_ARGS)5616 int2_avg_accum(PG_FUNCTION_ARGS)
5617 {
5618 	ArrayType  *transarray;
5619 	int16		newval = PG_GETARG_INT16(1);
5620 	Int8TransTypeData *transdata;
5621 
5622 	/*
5623 	 * If we're invoked as an aggregate, we can cheat and modify our first
5624 	 * parameter in-place to reduce palloc overhead. Otherwise we need to make
5625 	 * a copy of it before scribbling on it.
5626 	 */
5627 	if (AggCheckCallContext(fcinfo, NULL))
5628 		transarray = PG_GETARG_ARRAYTYPE_P(0);
5629 	else
5630 		transarray = PG_GETARG_ARRAYTYPE_P_COPY(0);
5631 
5632 	if (ARR_HASNULL(transarray) ||
5633 		ARR_SIZE(transarray) != ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData))
5634 		elog(ERROR, "expected 2-element int8 array");
5635 
5636 	transdata = (Int8TransTypeData *) ARR_DATA_PTR(transarray);
5637 	transdata->count++;
5638 	transdata->sum += newval;
5639 
5640 	PG_RETURN_ARRAYTYPE_P(transarray);
5641 }
5642 
5643 Datum
int4_avg_accum(PG_FUNCTION_ARGS)5644 int4_avg_accum(PG_FUNCTION_ARGS)
5645 {
5646 	ArrayType  *transarray;
5647 	int32		newval = PG_GETARG_INT32(1);
5648 	Int8TransTypeData *transdata;
5649 
5650 	/*
5651 	 * If we're invoked as an aggregate, we can cheat and modify our first
5652 	 * parameter in-place to reduce palloc overhead. Otherwise we need to make
5653 	 * a copy of it before scribbling on it.
5654 	 */
5655 	if (AggCheckCallContext(fcinfo, NULL))
5656 		transarray = PG_GETARG_ARRAYTYPE_P(0);
5657 	else
5658 		transarray = PG_GETARG_ARRAYTYPE_P_COPY(0);
5659 
5660 	if (ARR_HASNULL(transarray) ||
5661 		ARR_SIZE(transarray) != ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData))
5662 		elog(ERROR, "expected 2-element int8 array");
5663 
5664 	transdata = (Int8TransTypeData *) ARR_DATA_PTR(transarray);
5665 	transdata->count++;
5666 	transdata->sum += newval;
5667 
5668 	PG_RETURN_ARRAYTYPE_P(transarray);
5669 }
5670 
5671 Datum
int4_avg_combine(PG_FUNCTION_ARGS)5672 int4_avg_combine(PG_FUNCTION_ARGS)
5673 {
5674 	ArrayType  *transarray1;
5675 	ArrayType  *transarray2;
5676 	Int8TransTypeData *state1;
5677 	Int8TransTypeData *state2;
5678 
5679 	if (!AggCheckCallContext(fcinfo, NULL))
5680 		elog(ERROR, "aggregate function called in non-aggregate context");
5681 
5682 	transarray1 = PG_GETARG_ARRAYTYPE_P(0);
5683 	transarray2 = PG_GETARG_ARRAYTYPE_P(1);
5684 
5685 	if (ARR_HASNULL(transarray1) ||
5686 		ARR_SIZE(transarray1) != ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData))
5687 		elog(ERROR, "expected 2-element int8 array");
5688 
5689 	if (ARR_HASNULL(transarray2) ||
5690 		ARR_SIZE(transarray2) != ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData))
5691 		elog(ERROR, "expected 2-element int8 array");
5692 
5693 	state1 = (Int8TransTypeData *) ARR_DATA_PTR(transarray1);
5694 	state2 = (Int8TransTypeData *) ARR_DATA_PTR(transarray2);
5695 
5696 	state1->count += state2->count;
5697 	state1->sum += state2->sum;
5698 
5699 	PG_RETURN_ARRAYTYPE_P(transarray1);
5700 }
5701 
5702 Datum
int2_avg_accum_inv(PG_FUNCTION_ARGS)5703 int2_avg_accum_inv(PG_FUNCTION_ARGS)
5704 {
5705 	ArrayType  *transarray;
5706 	int16		newval = PG_GETARG_INT16(1);
5707 	Int8TransTypeData *transdata;
5708 
5709 	/*
5710 	 * If we're invoked as an aggregate, we can cheat and modify our first
5711 	 * parameter in-place to reduce palloc overhead. Otherwise we need to make
5712 	 * a copy of it before scribbling on it.
5713 	 */
5714 	if (AggCheckCallContext(fcinfo, NULL))
5715 		transarray = PG_GETARG_ARRAYTYPE_P(0);
5716 	else
5717 		transarray = PG_GETARG_ARRAYTYPE_P_COPY(0);
5718 
5719 	if (ARR_HASNULL(transarray) ||
5720 		ARR_SIZE(transarray) != ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData))
5721 		elog(ERROR, "expected 2-element int8 array");
5722 
5723 	transdata = (Int8TransTypeData *) ARR_DATA_PTR(transarray);
5724 	transdata->count--;
5725 	transdata->sum -= newval;
5726 
5727 	PG_RETURN_ARRAYTYPE_P(transarray);
5728 }
5729 
5730 Datum
int4_avg_accum_inv(PG_FUNCTION_ARGS)5731 int4_avg_accum_inv(PG_FUNCTION_ARGS)
5732 {
5733 	ArrayType  *transarray;
5734 	int32		newval = PG_GETARG_INT32(1);
5735 	Int8TransTypeData *transdata;
5736 
5737 	/*
5738 	 * If we're invoked as an aggregate, we can cheat and modify our first
5739 	 * parameter in-place to reduce palloc overhead. Otherwise we need to make
5740 	 * a copy of it before scribbling on it.
5741 	 */
5742 	if (AggCheckCallContext(fcinfo, NULL))
5743 		transarray = PG_GETARG_ARRAYTYPE_P(0);
5744 	else
5745 		transarray = PG_GETARG_ARRAYTYPE_P_COPY(0);
5746 
5747 	if (ARR_HASNULL(transarray) ||
5748 		ARR_SIZE(transarray) != ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData))
5749 		elog(ERROR, "expected 2-element int8 array");
5750 
5751 	transdata = (Int8TransTypeData *) ARR_DATA_PTR(transarray);
5752 	transdata->count--;
5753 	transdata->sum -= newval;
5754 
5755 	PG_RETURN_ARRAYTYPE_P(transarray);
5756 }
5757 
5758 Datum
int8_avg(PG_FUNCTION_ARGS)5759 int8_avg(PG_FUNCTION_ARGS)
5760 {
5761 	ArrayType  *transarray = PG_GETARG_ARRAYTYPE_P(0);
5762 	Int8TransTypeData *transdata;
5763 	Datum		countd,
5764 				sumd;
5765 
5766 	if (ARR_HASNULL(transarray) ||
5767 		ARR_SIZE(transarray) != ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData))
5768 		elog(ERROR, "expected 2-element int8 array");
5769 	transdata = (Int8TransTypeData *) ARR_DATA_PTR(transarray);
5770 
5771 	/* SQL defines AVG of no values to be NULL */
5772 	if (transdata->count == 0)
5773 		PG_RETURN_NULL();
5774 
5775 	countd = DirectFunctionCall1(int8_numeric,
5776 								 Int64GetDatumFast(transdata->count));
5777 	sumd = DirectFunctionCall1(int8_numeric,
5778 							   Int64GetDatumFast(transdata->sum));
5779 
5780 	PG_RETURN_DATUM(DirectFunctionCall2(numeric_div, sumd, countd));
5781 }
5782 
5783 /*
5784  * SUM(int2) and SUM(int4) both return int8, so we can use this
5785  * final function for both.
5786  */
5787 Datum
int2int4_sum(PG_FUNCTION_ARGS)5788 int2int4_sum(PG_FUNCTION_ARGS)
5789 {
5790 	ArrayType  *transarray = PG_GETARG_ARRAYTYPE_P(0);
5791 	Int8TransTypeData *transdata;
5792 
5793 	if (ARR_HASNULL(transarray) ||
5794 		ARR_SIZE(transarray) != ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData))
5795 		elog(ERROR, "expected 2-element int8 array");
5796 	transdata = (Int8TransTypeData *) ARR_DATA_PTR(transarray);
5797 
5798 	/* SQL defines SUM of no values to be NULL */
5799 	if (transdata->count == 0)
5800 		PG_RETURN_NULL();
5801 
5802 	PG_RETURN_DATUM(Int64GetDatumFast(transdata->sum));
5803 }
5804 
5805 
5806 /* ----------------------------------------------------------------------
5807  *
5808  * Debug support
5809  *
5810  * ----------------------------------------------------------------------
5811  */
5812 
5813 #ifdef NUMERIC_DEBUG
5814 
5815 /*
5816  * dump_numeric() - Dump a value in the db storage format for debugging
5817  */
5818 static void
dump_numeric(const char * str,Numeric num)5819 dump_numeric(const char *str, Numeric num)
5820 {
5821 	NumericDigit *digits = NUMERIC_DIGITS(num);
5822 	int			ndigits;
5823 	int			i;
5824 
5825 	ndigits = NUMERIC_NDIGITS(num);
5826 
5827 	printf("%s: NUMERIC w=%d d=%d ", str,
5828 		   NUMERIC_WEIGHT(num), NUMERIC_DSCALE(num));
5829 	switch (NUMERIC_SIGN(num))
5830 	{
5831 		case NUMERIC_POS:
5832 			printf("POS");
5833 			break;
5834 		case NUMERIC_NEG:
5835 			printf("NEG");
5836 			break;
5837 		case NUMERIC_NAN:
5838 			printf("NaN");
5839 			break;
5840 		default:
5841 			printf("SIGN=0x%x", NUMERIC_SIGN(num));
5842 			break;
5843 	}
5844 
5845 	for (i = 0; i < ndigits; i++)
5846 		printf(" %0*d", DEC_DIGITS, digits[i]);
5847 	printf("\n");
5848 }
5849 
5850 
5851 /*
5852  * dump_var() - Dump a value in the variable format for debugging
5853  */
5854 static void
dump_var(const char * str,NumericVar * var)5855 dump_var(const char *str, NumericVar *var)
5856 {
5857 	int			i;
5858 
5859 	printf("%s: VAR w=%d d=%d ", str, var->weight, var->dscale);
5860 	switch (var->sign)
5861 	{
5862 		case NUMERIC_POS:
5863 			printf("POS");
5864 			break;
5865 		case NUMERIC_NEG:
5866 			printf("NEG");
5867 			break;
5868 		case NUMERIC_NAN:
5869 			printf("NaN");
5870 			break;
5871 		default:
5872 			printf("SIGN=0x%x", var->sign);
5873 			break;
5874 	}
5875 
5876 	for (i = 0; i < var->ndigits; i++)
5877 		printf(" %0*d", DEC_DIGITS, var->digits[i]);
5878 
5879 	printf("\n");
5880 }
5881 #endif							/* NUMERIC_DEBUG */
5882 
5883 
5884 /* ----------------------------------------------------------------------
5885  *
5886  * Local functions follow
5887  *
5888  * In general, these do not support NaNs --- callers must eliminate
5889  * the possibility of NaN first.  (make_result() is an exception.)
5890  *
5891  * ----------------------------------------------------------------------
5892  */
5893 
5894 
5895 /*
5896  * alloc_var() -
5897  *
5898  *	Allocate a digit buffer of ndigits digits (plus a spare digit for rounding)
5899  */
5900 static void
alloc_var(NumericVar * var,int ndigits)5901 alloc_var(NumericVar *var, int ndigits)
5902 {
5903 	digitbuf_free(var->buf);
5904 	var->buf = digitbuf_alloc(ndigits + 1);
5905 	var->buf[0] = 0;			/* spare digit for rounding */
5906 	var->digits = var->buf + 1;
5907 	var->ndigits = ndigits;
5908 }
5909 
5910 
5911 /*
5912  * free_var() -
5913  *
5914  *	Return the digit buffer of a variable to the free pool
5915  */
5916 static void
free_var(NumericVar * var)5917 free_var(NumericVar *var)
5918 {
5919 	digitbuf_free(var->buf);
5920 	var->buf = NULL;
5921 	var->digits = NULL;
5922 	var->sign = NUMERIC_NAN;
5923 }
5924 
5925 
5926 /*
5927  * zero_var() -
5928  *
5929  *	Set a variable to ZERO.
5930  *	Note: its dscale is not touched.
5931  */
5932 static void
zero_var(NumericVar * var)5933 zero_var(NumericVar *var)
5934 {
5935 	digitbuf_free(var->buf);
5936 	var->buf = NULL;
5937 	var->digits = NULL;
5938 	var->ndigits = 0;
5939 	var->weight = 0;			/* by convention; doesn't really matter */
5940 	var->sign = NUMERIC_POS;	/* anything but NAN... */
5941 }
5942 
5943 
/*
 * set_var_from_str()
 *
 *	Parse a string and put the number into a variable
 *
 * This function does not handle leading or trailing spaces, and it doesn't
 * accept "NaN" either.  It returns the end+1 position so that caller can
 * check for trailing spaces/garbage if deemed necessary.
 *
 * cp is the place to actually start parsing; str is what to use in error
 * reports.  (Typically cp would be the same except advanced over spaces.)
 *
 * Accepted form: an optional '+' or '-', decimal digits with at most one
 * '.', and an optional exponent introduced by 'e' or 'E'.  At least one
 * digit must directly follow the optional sign and optional leading '.'.
 *
 * Errors raised:
 *	ERRCODE_INVALID_TEXT_REPRESENTATION - no digit where one is required,
 *		a second decimal point, or an 'e'/'E' not followed by a number.
 *	ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE - exponent magnitude of INT_MAX/2
 *		or more.
 *
 * dest's previous digit buffer is replaced (via alloc_var), and the result
 * is passed through strip_var() before returning.
 */
static const char *
set_var_from_str(const char *str, const char *cp, NumericVar *dest)
{
	bool		have_dp = false;
	int			i;
	unsigned char *decdigits;
	int			sign = NUMERIC_POS;
	int			dweight = -1;
	int			ddigits;
	int			dscale = 0;
	int			weight;
	int			ndigits;
	int			offset;
	NumericDigit *digits;

	/*
	 * We first parse the string to extract decimal digits and determine the
	 * correct decimal weight.  Then convert to NBASE representation.
	 */
	switch (*cp)
	{
		case '+':
			sign = NUMERIC_POS;
			cp++;
			break;

		case '-':
			sign = NUMERIC_NEG;
			cp++;
			break;
	}

	/* A decimal point may come before the first digit, as in ".5" */
	if (*cp == '.')
	{
		have_dp = true;
		cp++;
	}

	/* Whatever came before, a digit is required at this position */
	if (!isdigit((unsigned char) *cp))
		ereport(ERROR,
				(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
				 errmsg("invalid input syntax for type %s: \"%s\"",
						"numeric", str)));

	/*
	 * decdigits holds one decimal digit value (0-9) per byte.  The extra
	 * DEC_DIGITS * 2 bytes cover the DEC_DIGITS bytes of leading padding
	 * plus the DEC_DIGITS - 1 bytes of trailing padding written below.
	 */
	decdigits = (unsigned char *) palloc(strlen(cp) + DEC_DIGITS * 2);

	/* leading padding for digit alignment later */
	memset(decdigits, 0, DEC_DIGITS);
	i = DEC_DIGITS;

	/*
	 * Collect digits.  Each digit seen before the decimal point raises
	 * dweight (which started at -1), each one after it raises dscale.  The
	 * loop stops at the first character that is neither a digit nor '.'.
	 */
	while (*cp)
	{
		if (isdigit((unsigned char) *cp))
		{
			decdigits[i++] = *cp++ - '0';
			if (!have_dp)
				dweight++;
			else
				dscale++;
		}
		else if (*cp == '.')
		{
			/* only one decimal point is allowed */
			if (have_dp)
				ereport(ERROR,
						(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
						 errmsg("invalid input syntax for type %s: \"%s\"",
								"numeric", str)));
			have_dp = true;
			cp++;
		}
		else
			break;
	}

	/* number of decimal digits actually stored, padding excluded */
	ddigits = i - DEC_DIGITS;
	/* trailing padding for digit alignment later */
	memset(decdigits + i, 0, DEC_DIGITS - 1);

	/* Handle exponent, if any */
	if (*cp == 'e' || *cp == 'E')
	{
		long		exponent;
		char	   *endptr;

		cp++;
		exponent = strtol(cp, &endptr, 10);
		/* strtol consumed nothing: there was no number after the 'e' */
		if (endptr == cp)
			ereport(ERROR,
					(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
					 errmsg("invalid input syntax for type %s: \"%s\"",
							"numeric", str)));
		cp = endptr;

		/*
		 * At this point, dweight and dscale can't be more than about
		 * INT_MAX/2 due to the MaxAllocSize limit on string length, so
		 * constraining the exponent similarly should be enough to prevent
		 * integer overflow in this function.  If the value is too large to
		 * fit in storage format, make_result() will complain about it later;
		 * for consistency use the same ereport errcode/text as make_result().
		 */
		if (exponent >= INT_MAX / 2 || exponent <= -(INT_MAX / 2))
			ereport(ERROR,
					(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
					 errmsg("value overflows numeric format")));
		/* The exponent shifts the decimal point: weight up, scale down */
		dweight += (int) exponent;
		dscale -= (int) exponent;
		if (dscale < 0)
			dscale = 0;
	}

	/*
	 * Okay, convert pure-decimal representation to base NBASE.  First we need
	 * to determine the converted weight and ndigits.  offset is the number of
	 * decimal zeroes to insert before the first given digit to have a
	 * correctly aligned first NBASE digit.
	 */
	if (dweight >= 0)
		weight = (dweight + 1 + DEC_DIGITS - 1) / DEC_DIGITS - 1;
	else
		weight = -((-dweight - 1) / DEC_DIGITS + 1);
	offset = (weight + 1) * DEC_DIGITS - (dweight + 1);
	ndigits = (ddigits + offset + DEC_DIGITS - 1) / DEC_DIGITS;

	alloc_var(dest, ndigits);
	dest->sign = sign;
	dest->weight = weight;
	dest->dscale = dscale;

	/*
	 * Start reading "offset" bytes inside the leading padding, so the first
	 * group of DEC_DIGITS bytes begins with that many zeroes.
	 */
	i = DEC_DIGITS - offset;
	digits = dest->digits;

	/* Pack each group of DEC_DIGITS decimal digits into one NBASE digit */
	while (ndigits-- > 0)
	{
#if DEC_DIGITS == 4
		*digits++ = ((decdigits[i] * 10 + decdigits[i + 1]) * 10 +
					 decdigits[i + 2]) * 10 + decdigits[i + 3];
#elif DEC_DIGITS == 2
		*digits++ = decdigits[i] * 10 + decdigits[i + 1];
#elif DEC_DIGITS == 1
		*digits++ = decdigits[i];
#else
#error unsupported NBASE
#endif
		i += DEC_DIGITS;
	}

	pfree(decdigits);

	/* Strip any leading/trailing zeroes, and normalize weight if zero */
	strip_var(dest);

	/* Return end+1 position for caller */
	return cp;
}
6111 
6112 
6113 /*
6114  * set_var_from_num() -
6115  *
6116  *	Convert the packed db format into a variable
6117  */
6118 static void
set_var_from_num(Numeric num,NumericVar * dest)6119 set_var_from_num(Numeric num, NumericVar *dest)
6120 {
6121 	int			ndigits;
6122 
6123 	ndigits = NUMERIC_NDIGITS(num);
6124 
6125 	alloc_var(dest, ndigits);
6126 
6127 	dest->weight = NUMERIC_WEIGHT(num);
6128 	dest->sign = NUMERIC_SIGN(num);
6129 	dest->dscale = NUMERIC_DSCALE(num);
6130 
6131 	memcpy(dest->digits, NUMERIC_DIGITS(num), ndigits * sizeof(NumericDigit));
6132 }
6133 
6134 
6135 /*
6136  * init_var_from_num() -
6137  *
6138  *	Initialize a variable from packed db format. The digits array is not
6139  *	copied, which saves some cycles when the resulting var is not modified.
6140  *	Also, there's no need to call free_var(), as long as you don't assign any
6141  *	other value to it (with set_var_* functions, or by using the var as the
6142  *	destination of a function like add_var())
6143  *
6144  *	CAUTION: Do not modify the digits buffer of a var initialized with this
6145  *	function, e.g by calling round_var() or trunc_var(), as the changes will
6146  *	propagate to the original Numeric! It's OK to use it as the destination
6147  *	argument of one of the calculational functions, though.
6148  */
6149 static void
init_var_from_num(Numeric num,NumericVar * dest)6150 init_var_from_num(Numeric num, NumericVar *dest)
6151 {
6152 	dest->ndigits = NUMERIC_NDIGITS(num);
6153 	dest->weight = NUMERIC_WEIGHT(num);
6154 	dest->sign = NUMERIC_SIGN(num);
6155 	dest->dscale = NUMERIC_DSCALE(num);
6156 	dest->digits = NUMERIC_DIGITS(num);
6157 	dest->buf = NULL;			/* digits array is not palloc'd */
6158 }
6159 
6160 
6161 /*
6162  * set_var_from_var() -
6163  *
6164  *	Copy one variable into another
6165  */
6166 static void
set_var_from_var(const NumericVar * value,NumericVar * dest)6167 set_var_from_var(const NumericVar *value, NumericVar *dest)
6168 {
6169 	NumericDigit *newbuf;
6170 
6171 	newbuf = digitbuf_alloc(value->ndigits + 1);
6172 	newbuf[0] = 0;				/* spare digit for rounding */
6173 	if (value->ndigits > 0)		/* else value->digits might be null */
6174 		memcpy(newbuf + 1, value->digits,
6175 			   value->ndigits * sizeof(NumericDigit));
6176 
6177 	digitbuf_free(dest->buf);
6178 
6179 	memmove(dest, value, sizeof(NumericVar));
6180 	dest->buf = newbuf;
6181 	dest->digits = newbuf + 1;
6182 }
6183 
6184 
/*
 * get_str_from_var() -
 *
 *	Convert a var to text representation (guts of numeric_out).
 *	The var is displayed to the number of digits indicated by its dscale.
 *	Returns a palloc'd string.
 *
 *	Output layout: an optional '-', the integer part (a single '0' when
 *	the weight is negative), and, only if dscale > 0, a '.' followed by
 *	exactly dscale fractional digits.  Digit positions outside the stored
 *	digits array are printed as zeroes.  No rounding is done here; the
 *	fractional part is simply cut off after dscale digits.
 */
static char *
get_str_from_var(const NumericVar *var)
{
	int			dscale;
	char	   *str;
	char	   *cp;
	char	   *endcp;
	int			i;
	int			d;
	NumericDigit dig;

#if DEC_DIGITS > 1
	NumericDigit d1;
#endif

	dscale = var->dscale;

	/*
	 * Allocate space for the result.
	 *
	 * i is set to the # of decimal digits before decimal point. dscale is the
	 * # of decimal digits we will print after decimal point. We may generate
	 * as many as DEC_DIGITS-1 excess digits at the end, and in addition we
	 * need room for sign, decimal point, null terminator.
	 */
	i = (var->weight + 1) * DEC_DIGITS;
	if (i <= 0)
		i = 1;

	str = palloc(i + dscale + DEC_DIGITS + 2);
	cp = str;

	/*
	 * Output a dash for negative values
	 */
	if (var->sign == NUMERIC_NEG)
		*cp++ = '-';

	/*
	 * Output all digits before the decimal point
	 */
	if (var->weight < 0)
	{
		/*
		 * No integer part: print a lone zero.  d starts at or below zero
		 * here, so the fractional loop below emits zero groups for the
		 * positions between the decimal point and the first stored digit.
		 */
		d = var->weight + 1;
		*cp++ = '0';
	}
	else
	{
		for (d = 0; d <= var->weight; d++)
		{
			/* positions past the stored digits count as zero */
			dig = (d < var->ndigits) ? var->digits[d] : 0;
			/* In the first digit, suppress extra leading decimal zeroes */
#if DEC_DIGITS == 4
			{
				/* becomes true once any decimal digit has been emitted */
				bool		putit = (d > 0);

				d1 = dig / 1000;
				dig -= d1 * 1000;
				putit |= (d1 > 0);
				if (putit)
					*cp++ = d1 + '0';
				d1 = dig / 100;
				dig -= d1 * 100;
				putit |= (d1 > 0);
				if (putit)
					*cp++ = d1 + '0';
				d1 = dig / 10;
				dig -= d1 * 10;
				putit |= (d1 > 0);
				if (putit)
					*cp++ = d1 + '0';
				/* the units digit is always printed */
				*cp++ = dig + '0';
			}
#elif DEC_DIGITS == 2
			d1 = dig / 10;
			dig -= d1 * 10;
			if (d1 > 0 || d > 0)
				*cp++ = d1 + '0';
			*cp++ = dig + '0';
#elif DEC_DIGITS == 1
			*cp++ = dig + '0';
#else
#error unsupported NBASE
#endif
		}
	}

	/*
	 * If requested, output a decimal point and all the digits that follow it.
	 * We initially put out a multiple of DEC_DIGITS digits, then truncate if
	 * needed.
	 */
	if (dscale > 0)
	{
		*cp++ = '.';
		/* where the string must end once the excess digits are dropped */
		endcp = cp + dscale;
		for (i = 0; i < dscale; d++, i += DEC_DIGITS)
		{
			/* d may be negative or past the end; both print as zero */
			dig = (d >= 0 && d < var->ndigits) ? var->digits[d] : 0;
#if DEC_DIGITS == 4
			d1 = dig / 1000;
			dig -= d1 * 1000;
			*cp++ = d1 + '0';
			d1 = dig / 100;
			dig -= d1 * 100;
			*cp++ = d1 + '0';
			d1 = dig / 10;
			dig -= d1 * 10;
			*cp++ = d1 + '0';
			*cp++ = dig + '0';
#elif DEC_DIGITS == 2
			d1 = dig / 10;
			dig -= d1 * 10;
			*cp++ = d1 + '0';
			*cp++ = dig + '0';
#elif DEC_DIGITS == 1
			*cp++ = dig + '0';
#else
#error unsupported NBASE
#endif
		}
		/* back up over any digits written beyond dscale */
		cp = endcp;
	}

	/*
	 * terminate the string and return it
	 */
	*cp = '\0';
	return str;
}
6322 
6323 /*
6324  * get_str_from_var_sci() -
6325  *
6326  *	Convert a var to a normalised scientific notation text representation.
6327  *	This function does the heavy lifting for numeric_out_sci().
6328  *
6329  *	This notation has the general form a * 10^b, where a is known as the
6330  *	"significand" and b is known as the "exponent".
6331  *
6332  *	Because we can't do superscript in ASCII (and because we want to copy
6333  *	printf's behaviour) we display the exponent using E notation, with a
6334  *	minimum of two exponent digits.
6335  *
6336  *	For example, the value 1234 could be output as 1.2e+03.
6337  *
6338  *	We assume that the exponent can fit into an int32.
6339  *
6340  *	rscale is the number of decimal digits desired after the decimal point in
6341  *	the output, negative values will be treated as meaning zero.
6342  *
6343  *	Returns a palloc'd string.
6344  */
6345 static char *
get_str_from_var_sci(const NumericVar * var,int rscale)6346 get_str_from_var_sci(const NumericVar *var, int rscale)
6347 {
6348 	int32		exponent;
6349 	NumericVar	tmp_var;
6350 	size_t		len;
6351 	char	   *str;
6352 	char	   *sig_out;
6353 
6354 	if (rscale < 0)
6355 		rscale = 0;
6356 
6357 	/*
6358 	 * Determine the exponent of this number in normalised form.
6359 	 *
6360 	 * This is the exponent required to represent the number with only one
6361 	 * significant digit before the decimal place.
6362 	 */
6363 	if (var->ndigits > 0)
6364 	{
6365 		exponent = (var->weight + 1) * DEC_DIGITS;
6366 
6367 		/*
6368 		 * Compensate for leading decimal zeroes in the first numeric digit by
6369 		 * decrementing the exponent.
6370 		 */
6371 		exponent -= DEC_DIGITS - (int) log10(var->digits[0]);
6372 	}
6373 	else
6374 	{
6375 		/*
6376 		 * If var has no digits, then it must be zero.
6377 		 *
6378 		 * Zero doesn't technically have a meaningful exponent in normalised
6379 		 * notation, but we just display the exponent as zero for consistency
6380 		 * of output.
6381 		 */
6382 		exponent = 0;
6383 	}
6384 
6385 	/*
6386 	 * Divide var by 10^exponent to get the significand, rounding to rscale
6387 	 * decimal digits in the process.
6388 	 */
6389 	init_var(&tmp_var);
6390 
6391 	power_ten_int(exponent, &tmp_var);
6392 	div_var(var, &tmp_var, &tmp_var, rscale, true);
6393 	sig_out = get_str_from_var(&tmp_var);
6394 
6395 	free_var(&tmp_var);
6396 
6397 	/*
6398 	 * Allocate space for the result.
6399 	 *
6400 	 * In addition to the significand, we need room for the exponent
6401 	 * decoration ("e"), the sign of the exponent, up to 10 digits for the
6402 	 * exponent itself, and of course the null terminator.
6403 	 */
6404 	len = strlen(sig_out) + 13;
6405 	str = palloc(len);
6406 	snprintf(str, len, "%se%+03d", sig_out, exponent);
6407 
6408 	pfree(sig_out);
6409 
6410 	return str;
6411 }
6412 
6413 
6414 /*
6415  * make_result_opt_error() -
6416  *
6417  *	Create the packed db numeric format in palloc()'d memory from
6418  *	a variable.  If "*have_error" flag is provided, on error it's set to
6419  *	true, NULL returned.  This is helpful when caller need to handle errors
6420  *	by itself.
6421  */
6422 static Numeric
make_result_opt_error(const NumericVar * var,bool * have_error)6423 make_result_opt_error(const NumericVar *var, bool *have_error)
6424 {
6425 	Numeric		result;
6426 	NumericDigit *digits = var->digits;
6427 	int			weight = var->weight;
6428 	int			sign = var->sign;
6429 	int			n;
6430 	Size		len;
6431 
6432 	if (have_error)
6433 		*have_error = false;
6434 
6435 	if (sign == NUMERIC_NAN)
6436 	{
6437 		result = (Numeric) palloc(NUMERIC_HDRSZ_SHORT);
6438 
6439 		SET_VARSIZE(result, NUMERIC_HDRSZ_SHORT);
6440 		result->choice.n_header = NUMERIC_NAN;
6441 		/* the header word is all we need */
6442 
6443 		dump_numeric("make_result()", result);
6444 		return result;
6445 	}
6446 
6447 	n = var->ndigits;
6448 
6449 	/* truncate leading zeroes */
6450 	while (n > 0 && *digits == 0)
6451 	{
6452 		digits++;
6453 		weight--;
6454 		n--;
6455 	}
6456 	/* truncate trailing zeroes */
6457 	while (n > 0 && digits[n - 1] == 0)
6458 		n--;
6459 
6460 	/* If zero result, force to weight=0 and positive sign */
6461 	if (n == 0)
6462 	{
6463 		weight = 0;
6464 		sign = NUMERIC_POS;
6465 	}
6466 
6467 	/* Build the result */
6468 	if (NUMERIC_CAN_BE_SHORT(var->dscale, weight))
6469 	{
6470 		len = NUMERIC_HDRSZ_SHORT + n * sizeof(NumericDigit);
6471 		result = (Numeric) palloc(len);
6472 		SET_VARSIZE(result, len);
6473 		result->choice.n_short.n_header =
6474 			(sign == NUMERIC_NEG ? (NUMERIC_SHORT | NUMERIC_SHORT_SIGN_MASK)
6475 			 : NUMERIC_SHORT)
6476 			| (var->dscale << NUMERIC_SHORT_DSCALE_SHIFT)
6477 			| (weight < 0 ? NUMERIC_SHORT_WEIGHT_SIGN_MASK : 0)
6478 			| (weight & NUMERIC_SHORT_WEIGHT_MASK);
6479 	}
6480 	else
6481 	{
6482 		len = NUMERIC_HDRSZ + n * sizeof(NumericDigit);
6483 		result = (Numeric) palloc(len);
6484 		SET_VARSIZE(result, len);
6485 		result->choice.n_long.n_sign_dscale =
6486 			sign | (var->dscale & NUMERIC_DSCALE_MASK);
6487 		result->choice.n_long.n_weight = weight;
6488 	}
6489 
6490 	Assert(NUMERIC_NDIGITS(result) == n);
6491 	if (n > 0)
6492 		memcpy(NUMERIC_DIGITS(result), digits, n * sizeof(NumericDigit));
6493 
6494 	/* Check for overflow of int16 fields */
6495 	if (NUMERIC_WEIGHT(result) != weight ||
6496 		NUMERIC_DSCALE(result) != var->dscale)
6497 	{
6498 		if (have_error)
6499 		{
6500 			*have_error = true;
6501 			return NULL;
6502 		}
6503 		else
6504 		{
6505 			ereport(ERROR,
6506 					(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
6507 					 errmsg("value overflows numeric format")));
6508 		}
6509 	}
6510 
6511 	dump_numeric("make_result()", result);
6512 	return result;
6513 }
6514 
6515 
6516 /*
6517  * make_result() -
6518  *
6519  *	An interface to make_result_opt_error() without "have_error" argument.
6520  */
6521 static Numeric
make_result(const NumericVar * var)6522 make_result(const NumericVar *var)
6523 {
6524 	return make_result_opt_error(var, NULL);
6525 }
6526 
6527 
6528 /*
6529  * apply_typmod() -
6530  *
6531  *	Do bounds checking and rounding according to the attributes
6532  *	typmod field.
6533  */
6534 static void
apply_typmod(NumericVar * var,int32 typmod)6535 apply_typmod(NumericVar *var, int32 typmod)
6536 {
6537 	int			precision;
6538 	int			scale;
6539 	int			maxdigits;
6540 	int			ddigits;
6541 	int			i;
6542 
6543 	/* Do nothing if we have a default typmod (-1) */
6544 	if (typmod < (int32) (VARHDRSZ))
6545 		return;
6546 
6547 	typmod -= VARHDRSZ;
6548 	precision = (typmod >> 16) & 0xffff;
6549 	scale = typmod & 0xffff;
6550 	maxdigits = precision - scale;
6551 
6552 	/* Round to target scale (and set var->dscale) */
6553 	round_var(var, scale);
6554 
6555 	/*
6556 	 * Check for overflow - note we can't do this before rounding, because
6557 	 * rounding could raise the weight.  Also note that the var's weight could
6558 	 * be inflated by leading zeroes, which will be stripped before storage
6559 	 * but perhaps might not have been yet. In any case, we must recognize a
6560 	 * true zero, whose weight doesn't mean anything.
6561 	 */
6562 	ddigits = (var->weight + 1) * DEC_DIGITS;
6563 	if (ddigits > maxdigits)
6564 	{
6565 		/* Determine true weight; and check for all-zero result */
6566 		for (i = 0; i < var->ndigits; i++)
6567 		{
6568 			NumericDigit dig = var->digits[i];
6569 
6570 			if (dig)
6571 			{
6572 				/* Adjust for any high-order decimal zero digits */
6573 #if DEC_DIGITS == 4
6574 				if (dig < 10)
6575 					ddigits -= 3;
6576 				else if (dig < 100)
6577 					ddigits -= 2;
6578 				else if (dig < 1000)
6579 					ddigits -= 1;
6580 #elif DEC_DIGITS == 2
6581 				if (dig < 10)
6582 					ddigits -= 1;
6583 #elif DEC_DIGITS == 1
6584 				/* no adjustment */
6585 #else
6586 #error unsupported NBASE
6587 #endif
6588 				if (ddigits > maxdigits)
6589 					ereport(ERROR,
6590 							(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
6591 							 errmsg("numeric field overflow"),
6592 							 errdetail("A field with precision %d, scale %d must round to an absolute value less than %s%d.",
6593 									   precision, scale,
6594 					/* Display 10^0 as 1 */
6595 									   maxdigits ? "10^" : "",
6596 									   maxdigits ? maxdigits : 1
6597 									   )));
6598 				break;
6599 			}
6600 			ddigits -= DEC_DIGITS;
6601 		}
6602 	}
6603 }
6604 
6605 /*
6606  * Convert numeric to int8, rounding if needed.
6607  *
6608  * If overflow, return false (no error is raised).  Return true if okay.
6609  */
6610 static bool
numericvar_to_int64(const NumericVar * var,int64 * result)6611 numericvar_to_int64(const NumericVar *var, int64 *result)
6612 {
6613 	NumericDigit *digits;
6614 	int			ndigits;
6615 	int			weight;
6616 	int			i;
6617 	int64		val;
6618 	bool		neg;
6619 	NumericVar	rounded;
6620 
6621 	/* Round to nearest integer */
6622 	init_var(&rounded);
6623 	set_var_from_var(var, &rounded);
6624 	round_var(&rounded, 0);
6625 
6626 	/* Check for zero input */
6627 	strip_var(&rounded);
6628 	ndigits = rounded.ndigits;
6629 	if (ndigits == 0)
6630 	{
6631 		*result = 0;
6632 		free_var(&rounded);
6633 		return true;
6634 	}
6635 
6636 	/*
6637 	 * For input like 10000000000, we must treat stripped digits as real. So
6638 	 * the loop assumes there are weight+1 digits before the decimal point.
6639 	 */
6640 	weight = rounded.weight;
6641 	Assert(weight >= 0 && ndigits <= weight + 1);
6642 
6643 	/*
6644 	 * Construct the result. To avoid issues with converting a value
6645 	 * corresponding to INT64_MIN (which can't be represented as a positive 64
6646 	 * bit two's complement integer), accumulate value as a negative number.
6647 	 */
6648 	digits = rounded.digits;
6649 	neg = (rounded.sign == NUMERIC_NEG);
6650 	val = -digits[0];
6651 	for (i = 1; i <= weight; i++)
6652 	{
6653 		if (unlikely(pg_mul_s64_overflow(val, NBASE, &val)))
6654 		{
6655 			free_var(&rounded);
6656 			return false;
6657 		}
6658 
6659 		if (i < ndigits)
6660 		{
6661 			if (unlikely(pg_sub_s64_overflow(val, digits[i], &val)))
6662 			{
6663 				free_var(&rounded);
6664 				return false;
6665 			}
6666 		}
6667 	}
6668 
6669 	free_var(&rounded);
6670 
6671 	if (!neg)
6672 	{
6673 		if (unlikely(val == PG_INT64_MIN))
6674 			return false;
6675 		val = -val;
6676 	}
6677 	*result = val;
6678 
6679 	return true;
6680 }
6681 
6682 /*
6683  * Convert int8 value to numeric.
6684  */
6685 static void
int64_to_numericvar(int64 val,NumericVar * var)6686 int64_to_numericvar(int64 val, NumericVar *var)
6687 {
6688 	uint64		uval,
6689 				newuval;
6690 	NumericDigit *ptr;
6691 	int			ndigits;
6692 
6693 	/* int64 can require at most 19 decimal digits; add one for safety */
6694 	alloc_var(var, 20 / DEC_DIGITS);
6695 	if (val < 0)
6696 	{
6697 		var->sign = NUMERIC_NEG;
6698 		uval = -val;
6699 	}
6700 	else
6701 	{
6702 		var->sign = NUMERIC_POS;
6703 		uval = val;
6704 	}
6705 	var->dscale = 0;
6706 	if (val == 0)
6707 	{
6708 		var->ndigits = 0;
6709 		var->weight = 0;
6710 		return;
6711 	}
6712 	ptr = var->digits + var->ndigits;
6713 	ndigits = 0;
6714 	do
6715 	{
6716 		ptr--;
6717 		ndigits++;
6718 		newuval = uval / NBASE;
6719 		*ptr = uval - newuval * NBASE;
6720 		uval = newuval;
6721 	} while (uval);
6722 	var->digits = ptr;
6723 	var->ndigits = ndigits;
6724 	var->weight = ndigits - 1;
6725 }
6726 
6727 #ifdef HAVE_INT128
6728 /*
6729  * Convert numeric to int128, rounding if needed.
6730  *
6731  * If overflow, return false (no error is raised).  Return true if okay.
6732  */
6733 static bool
numericvar_to_int128(const NumericVar * var,int128 * result)6734 numericvar_to_int128(const NumericVar *var, int128 *result)
6735 {
6736 	NumericDigit *digits;
6737 	int			ndigits;
6738 	int			weight;
6739 	int			i;
6740 	int128		val,
6741 				oldval;
6742 	bool		neg;
6743 	NumericVar	rounded;
6744 
6745 	/* Round to nearest integer */
6746 	init_var(&rounded);
6747 	set_var_from_var(var, &rounded);
6748 	round_var(&rounded, 0);
6749 
6750 	/* Check for zero input */
6751 	strip_var(&rounded);
6752 	ndigits = rounded.ndigits;
6753 	if (ndigits == 0)
6754 	{
6755 		*result = 0;
6756 		free_var(&rounded);
6757 		return true;
6758 	}
6759 
6760 	/*
6761 	 * For input like 10000000000, we must treat stripped digits as real. So
6762 	 * the loop assumes there are weight+1 digits before the decimal point.
6763 	 */
6764 	weight = rounded.weight;
6765 	Assert(weight >= 0 && ndigits <= weight + 1);
6766 
6767 	/* Construct the result */
6768 	digits = rounded.digits;
6769 	neg = (rounded.sign == NUMERIC_NEG);
6770 	val = digits[0];
6771 	for (i = 1; i <= weight; i++)
6772 	{
6773 		oldval = val;
6774 		val *= NBASE;
6775 		if (i < ndigits)
6776 			val += digits[i];
6777 
6778 		/*
6779 		 * The overflow check is a bit tricky because we want to accept
6780 		 * INT128_MIN, which will overflow the positive accumulator.  We can
6781 		 * detect this case easily though because INT128_MIN is the only
6782 		 * nonzero value for which -val == val (on a two's complement machine,
6783 		 * anyway).
6784 		 */
6785 		if ((val / NBASE) != oldval)	/* possible overflow? */
6786 		{
6787 			if (!neg || (-val) != val || val == 0 || oldval < 0)
6788 			{
6789 				free_var(&rounded);
6790 				return false;
6791 			}
6792 		}
6793 	}
6794 
6795 	free_var(&rounded);
6796 
6797 	*result = neg ? -val : val;
6798 	return true;
6799 }
6800 
6801 /*
6802  * Convert 128 bit integer to numeric.
6803  */
6804 static void
int128_to_numericvar(int128 val,NumericVar * var)6805 int128_to_numericvar(int128 val, NumericVar *var)
6806 {
6807 	uint128		uval,
6808 				newuval;
6809 	NumericDigit *ptr;
6810 	int			ndigits;
6811 
6812 	/* int128 can require at most 39 decimal digits; add one for safety */
6813 	alloc_var(var, 40 / DEC_DIGITS);
6814 	if (val < 0)
6815 	{
6816 		var->sign = NUMERIC_NEG;
6817 		uval = -val;
6818 	}
6819 	else
6820 	{
6821 		var->sign = NUMERIC_POS;
6822 		uval = val;
6823 	}
6824 	var->dscale = 0;
6825 	if (val == 0)
6826 	{
6827 		var->ndigits = 0;
6828 		var->weight = 0;
6829 		return;
6830 	}
6831 	ptr = var->digits + var->ndigits;
6832 	ndigits = 0;
6833 	do
6834 	{
6835 		ptr--;
6836 		ndigits++;
6837 		newuval = uval / NBASE;
6838 		*ptr = uval - newuval * NBASE;
6839 		uval = newuval;
6840 	} while (uval);
6841 	var->digits = ptr;
6842 	var->ndigits = ndigits;
6843 	var->weight = ndigits - 1;
6844 }
6845 #endif
6846 
6847 /*
6848  * Convert numeric to float8; if out of range, return +/- HUGE_VAL
6849  */
6850 static double
numeric_to_double_no_overflow(Numeric num)6851 numeric_to_double_no_overflow(Numeric num)
6852 {
6853 	char	   *tmp;
6854 	double		val;
6855 	char	   *endptr;
6856 
6857 	tmp = DatumGetCString(DirectFunctionCall1(numeric_out,
6858 											  NumericGetDatum(num)));
6859 
6860 	/* unlike float8in, we ignore ERANGE from strtod */
6861 	val = strtod(tmp, &endptr);
6862 	if (*endptr != '\0')
6863 	{
6864 		/* shouldn't happen ... */
6865 		ereport(ERROR,
6866 				(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
6867 				 errmsg("invalid input syntax for type %s: \"%s\"",
6868 						"double precision", tmp)));
6869 	}
6870 
6871 	pfree(tmp);
6872 
6873 	return val;
6874 }
6875 
6876 /* As above, but work from a NumericVar */
6877 static double
numericvar_to_double_no_overflow(const NumericVar * var)6878 numericvar_to_double_no_overflow(const NumericVar *var)
6879 {
6880 	char	   *tmp;
6881 	double		val;
6882 	char	   *endptr;
6883 
6884 	tmp = get_str_from_var(var);
6885 
6886 	/* unlike float8in, we ignore ERANGE from strtod */
6887 	val = strtod(tmp, &endptr);
6888 	if (*endptr != '\0')
6889 	{
6890 		/* shouldn't happen ... */
6891 		ereport(ERROR,
6892 				(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
6893 				 errmsg("invalid input syntax for type %s: \"%s\"",
6894 						"double precision", tmp)));
6895 	}
6896 
6897 	pfree(tmp);
6898 
6899 	return val;
6900 }
6901 
6902 
6903 /*
6904  * cmp_var() -
6905  *
6906  *	Compare two values on variable level.  We assume zeroes have been
6907  *	truncated to no digits.
6908  */
6909 static int
cmp_var(const NumericVar * var1,const NumericVar * var2)6910 cmp_var(const NumericVar *var1, const NumericVar *var2)
6911 {
6912 	return cmp_var_common(var1->digits, var1->ndigits,
6913 						  var1->weight, var1->sign,
6914 						  var2->digits, var2->ndigits,
6915 						  var2->weight, var2->sign);
6916 }
6917 
6918 /*
6919  * cmp_var_common() -
6920  *
6921  *	Main routine of cmp_var(). This function can be used by both
6922  *	NumericVar and Numeric.
6923  */
6924 static int
cmp_var_common(const NumericDigit * var1digits,int var1ndigits,int var1weight,int var1sign,const NumericDigit * var2digits,int var2ndigits,int var2weight,int var2sign)6925 cmp_var_common(const NumericDigit *var1digits, int var1ndigits,
6926 			   int var1weight, int var1sign,
6927 			   const NumericDigit *var2digits, int var2ndigits,
6928 			   int var2weight, int var2sign)
6929 {
6930 	if (var1ndigits == 0)
6931 	{
6932 		if (var2ndigits == 0)
6933 			return 0;
6934 		if (var2sign == NUMERIC_NEG)
6935 			return 1;
6936 		return -1;
6937 	}
6938 	if (var2ndigits == 0)
6939 	{
6940 		if (var1sign == NUMERIC_POS)
6941 			return 1;
6942 		return -1;
6943 	}
6944 
6945 	if (var1sign == NUMERIC_POS)
6946 	{
6947 		if (var2sign == NUMERIC_NEG)
6948 			return 1;
6949 		return cmp_abs_common(var1digits, var1ndigits, var1weight,
6950 							  var2digits, var2ndigits, var2weight);
6951 	}
6952 
6953 	if (var2sign == NUMERIC_POS)
6954 		return -1;
6955 
6956 	return cmp_abs_common(var2digits, var2ndigits, var2weight,
6957 						  var1digits, var1ndigits, var1weight);
6958 }
6959 
6960 
6961 /*
6962  * add_var() -
6963  *
6964  *	Full version of add functionality on variable level (handling signs).
6965  *	result might point to one of the operands too without danger.
6966  */
6967 static void
add_var(const NumericVar * var1,const NumericVar * var2,NumericVar * result)6968 add_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result)
6969 {
6970 	/*
6971 	 * Decide on the signs of the two variables what to do
6972 	 */
6973 	if (var1->sign == NUMERIC_POS)
6974 	{
6975 		if (var2->sign == NUMERIC_POS)
6976 		{
6977 			/*
6978 			 * Both are positive result = +(ABS(var1) + ABS(var2))
6979 			 */
6980 			add_abs(var1, var2, result);
6981 			result->sign = NUMERIC_POS;
6982 		}
6983 		else
6984 		{
6985 			/*
6986 			 * var1 is positive, var2 is negative Must compare absolute values
6987 			 */
6988 			switch (cmp_abs(var1, var2))
6989 			{
6990 				case 0:
6991 					/* ----------
6992 					 * ABS(var1) == ABS(var2)
6993 					 * result = ZERO
6994 					 * ----------
6995 					 */
6996 					zero_var(result);
6997 					result->dscale = Max(var1->dscale, var2->dscale);
6998 					break;
6999 
7000 				case 1:
7001 					/* ----------
7002 					 * ABS(var1) > ABS(var2)
7003 					 * result = +(ABS(var1) - ABS(var2))
7004 					 * ----------
7005 					 */
7006 					sub_abs(var1, var2, result);
7007 					result->sign = NUMERIC_POS;
7008 					break;
7009 
7010 				case -1:
7011 					/* ----------
7012 					 * ABS(var1) < ABS(var2)
7013 					 * result = -(ABS(var2) - ABS(var1))
7014 					 * ----------
7015 					 */
7016 					sub_abs(var2, var1, result);
7017 					result->sign = NUMERIC_NEG;
7018 					break;
7019 			}
7020 		}
7021 	}
7022 	else
7023 	{
7024 		if (var2->sign == NUMERIC_POS)
7025 		{
7026 			/* ----------
7027 			 * var1 is negative, var2 is positive
7028 			 * Must compare absolute values
7029 			 * ----------
7030 			 */
7031 			switch (cmp_abs(var1, var2))
7032 			{
7033 				case 0:
7034 					/* ----------
7035 					 * ABS(var1) == ABS(var2)
7036 					 * result = ZERO
7037 					 * ----------
7038 					 */
7039 					zero_var(result);
7040 					result->dscale = Max(var1->dscale, var2->dscale);
7041 					break;
7042 
7043 				case 1:
7044 					/* ----------
7045 					 * ABS(var1) > ABS(var2)
7046 					 * result = -(ABS(var1) - ABS(var2))
7047 					 * ----------
7048 					 */
7049 					sub_abs(var1, var2, result);
7050 					result->sign = NUMERIC_NEG;
7051 					break;
7052 
7053 				case -1:
7054 					/* ----------
7055 					 * ABS(var1) < ABS(var2)
7056 					 * result = +(ABS(var2) - ABS(var1))
7057 					 * ----------
7058 					 */
7059 					sub_abs(var2, var1, result);
7060 					result->sign = NUMERIC_POS;
7061 					break;
7062 			}
7063 		}
7064 		else
7065 		{
7066 			/* ----------
7067 			 * Both are negative
7068 			 * result = -(ABS(var1) + ABS(var2))
7069 			 * ----------
7070 			 */
7071 			add_abs(var1, var2, result);
7072 			result->sign = NUMERIC_NEG;
7073 		}
7074 	}
7075 }
7076 
7077 
7078 /*
7079  * sub_var() -
7080  *
7081  *	Full version of sub functionality on variable level (handling signs).
7082  *	result might point to one of the operands too without danger.
7083  */
7084 static void
sub_var(const NumericVar * var1,const NumericVar * var2,NumericVar * result)7085 sub_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result)
7086 {
7087 	/*
7088 	 * Decide on the signs of the two variables what to do
7089 	 */
7090 	if (var1->sign == NUMERIC_POS)
7091 	{
7092 		if (var2->sign == NUMERIC_NEG)
7093 		{
7094 			/* ----------
7095 			 * var1 is positive, var2 is negative
7096 			 * result = +(ABS(var1) + ABS(var2))
7097 			 * ----------
7098 			 */
7099 			add_abs(var1, var2, result);
7100 			result->sign = NUMERIC_POS;
7101 		}
7102 		else
7103 		{
7104 			/* ----------
7105 			 * Both are positive
7106 			 * Must compare absolute values
7107 			 * ----------
7108 			 */
7109 			switch (cmp_abs(var1, var2))
7110 			{
7111 				case 0:
7112 					/* ----------
7113 					 * ABS(var1) == ABS(var2)
7114 					 * result = ZERO
7115 					 * ----------
7116 					 */
7117 					zero_var(result);
7118 					result->dscale = Max(var1->dscale, var2->dscale);
7119 					break;
7120 
7121 				case 1:
7122 					/* ----------
7123 					 * ABS(var1) > ABS(var2)
7124 					 * result = +(ABS(var1) - ABS(var2))
7125 					 * ----------
7126 					 */
7127 					sub_abs(var1, var2, result);
7128 					result->sign = NUMERIC_POS;
7129 					break;
7130 
7131 				case -1:
7132 					/* ----------
7133 					 * ABS(var1) < ABS(var2)
7134 					 * result = -(ABS(var2) - ABS(var1))
7135 					 * ----------
7136 					 */
7137 					sub_abs(var2, var1, result);
7138 					result->sign = NUMERIC_NEG;
7139 					break;
7140 			}
7141 		}
7142 	}
7143 	else
7144 	{
7145 		if (var2->sign == NUMERIC_NEG)
7146 		{
7147 			/* ----------
7148 			 * Both are negative
7149 			 * Must compare absolute values
7150 			 * ----------
7151 			 */
7152 			switch (cmp_abs(var1, var2))
7153 			{
7154 				case 0:
7155 					/* ----------
7156 					 * ABS(var1) == ABS(var2)
7157 					 * result = ZERO
7158 					 * ----------
7159 					 */
7160 					zero_var(result);
7161 					result->dscale = Max(var1->dscale, var2->dscale);
7162 					break;
7163 
7164 				case 1:
7165 					/* ----------
7166 					 * ABS(var1) > ABS(var2)
7167 					 * result = -(ABS(var1) - ABS(var2))
7168 					 * ----------
7169 					 */
7170 					sub_abs(var1, var2, result);
7171 					result->sign = NUMERIC_NEG;
7172 					break;
7173 
7174 				case -1:
7175 					/* ----------
7176 					 * ABS(var1) < ABS(var2)
7177 					 * result = +(ABS(var2) - ABS(var1))
7178 					 * ----------
7179 					 */
7180 					sub_abs(var2, var1, result);
7181 					result->sign = NUMERIC_POS;
7182 					break;
7183 			}
7184 		}
7185 		else
7186 		{
7187 			/* ----------
7188 			 * var1 is negative, var2 is positive
7189 			 * result = -(ABS(var1) + ABS(var2))
7190 			 * ----------
7191 			 */
7192 			add_abs(var1, var2, result);
7193 			result->sign = NUMERIC_NEG;
7194 		}
7195 	}
7196 }
7197 
7198 
/*
 * mul_var() -
 *
 *	Multiplication on variable level. Product of var1 * var2 is stored
 *	in result.  Result is rounded to no more than rscale fractional digits.
 *
 *	var1, var2: the factors.  They may be swapped internally so that the one
 *	with fewer digits drives the outer loop; this does not affect the value
 *	computed.
 *	result: receives the product; its dscale ends up as rscale (set directly
 *	on the zero paths, and via round_var() otherwise).
 *	rscale: number of fractional decimal digits to round the product to.
 *
 *	NOTE(review): all reads of var1/var2 happen before alloc_var(result) or
 *	zero_var(result) is called, so passing result equal to an input looks
 *	safe -- confirm against alloc_var()/zero_var() before relying on it.
 */
static void
mul_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result,
		int rscale)
{
	int			res_ndigits;
	int			res_sign;
	int			res_weight;
	int			maxdigits;
	int		   *dig;
	int			carry;
	int			maxdig;
	int			newdig;
	int			var1ndigits;
	int			var2ndigits;
	NumericDigit *var1digits;
	NumericDigit *var2digits;
	NumericDigit *res_digits;
	int			i,
				i1,
				i2;

	/*
	 * Arrange for var1 to be the shorter of the two numbers.  This improves
	 * performance because the inner multiplication loop is much simpler than
	 * the outer loop, so it's better to have a smaller number of iterations
	 * of the outer loop.  This also reduces the number of times that the
	 * accumulator array needs to be normalized.
	 */
	if (var1->ndigits > var2->ndigits)
	{
		const NumericVar *tmp = var1;

		var1 = var2;
		var2 = tmp;
	}

	/* copy these values into local vars for speed in inner loop */
	var1ndigits = var1->ndigits;
	var2ndigits = var2->ndigits;
	var1digits = var1->digits;
	var2digits = var2->digits;

	if (var1ndigits == 0 || var2ndigits == 0)
	{
		/* one or both inputs is zero; so is result */
		zero_var(result);
		result->dscale = rscale;
		return;
	}

	/* Determine result sign and (maximum possible) weight */
	if (var1->sign == var2->sign)
		res_sign = NUMERIC_POS;
	else
		res_sign = NUMERIC_NEG;
	/* +2 matches the accumulator layout below (products land at i1+i2+2) */
	res_weight = var1->weight + var2->weight + 2;

	/*
	 * Determine the number of result digits to compute.  If the exact result
	 * would have more than rscale fractional digits, truncate the computation
	 * with MUL_GUARD_DIGITS guard digits, i.e., ignore input digits that
	 * would only contribute to the right of that.  (This will give the exact
	 * rounded-to-rscale answer unless carries out of the ignored positions
	 * would have propagated through more than MUL_GUARD_DIGITS digits.)
	 *
	 * Note: an exact computation could not produce more than var1ndigits +
	 * var2ndigits digits, but we allocate one extra output digit in case
	 * rscale-driven rounding produces a carry out of the highest exact digit.
	 */
	res_ndigits = var1ndigits + var2ndigits + 1;
	/* digits before the point, plus NBASE digits covering rscale, plus guards */
	maxdigits = res_weight + 1 + (rscale + DEC_DIGITS - 1) / DEC_DIGITS +
		MUL_GUARD_DIGITS;
	res_ndigits = Min(res_ndigits, maxdigits);

	/*
	 * The smallest accumulator index any product can land on is 0+0+2, so
	 * with fewer than 3 result digits nothing would be accumulated at all.
	 */
	if (res_ndigits < 3)
	{
		/* All input digits will be ignored; so result is zero */
		zero_var(result);
		result->dscale = rscale;
		return;
	}

	/*
	 * We do the arithmetic in an array "dig[]" of signed int's.  Since
	 * INT_MAX is noticeably larger than NBASE*NBASE, this gives us headroom
	 * to avoid normalizing carries immediately.
	 *
	 * maxdig tracks the maximum possible value of any dig[] entry; when this
	 * threatens to exceed INT_MAX, we take the time to propagate carries.
	 * Furthermore, we need to ensure that overflow doesn't occur during the
	 * carry propagation passes either.  The carry values could be as much as
	 * INT_MAX/NBASE, so really we must normalize when digits threaten to
	 * exceed INT_MAX - INT_MAX/NBASE.
	 *
	 * To avoid overflow in maxdig itself, it actually represents the max
	 * possible value divided by NBASE-1, ie, at the top of the loop it is
	 * known that no dig[] entry exceeds maxdig * (NBASE-1).
	 */
	dig = (int *) palloc0(res_ndigits * sizeof(int));
	maxdig = 0;

	/*
	 * The least significant digits of var1 should be ignored if they don't
	 * contribute directly to the first res_ndigits digits of the result that
	 * we are computing.
	 *
	 * Digit i1 of var1 and digit i2 of var2 are multiplied and added to digit
	 * i1+i2+2 of the accumulator array, so we need only consider digits of
	 * var1 for which i1 <= res_ndigits - 3.
	 */
	for (i1 = Min(var1ndigits - 1, res_ndigits - 3); i1 >= 0; i1--)
	{
		int			var1digit = var1digits[i1];

		/* a zero digit adds nothing to the accumulator */
		if (var1digit == 0)
			continue;

		/* Time to normalize? */
		maxdig += var1digit;
		if (maxdig > (INT_MAX - INT_MAX / NBASE) / (NBASE - 1))
		{
			/* Yes, do it */
			carry = 0;
			/* walk from least to most significant, pushing carries upward */
			for (i = res_ndigits - 1; i >= 0; i--)
			{
				newdig = dig[i] + carry;
				if (newdig >= NBASE)
				{
					carry = newdig / NBASE;
					newdig -= carry * NBASE;
				}
				else
					carry = 0;
				dig[i] = newdig;
			}
			Assert(carry == 0);
			/* Reset maxdig to indicate new worst-case */
			maxdig = 1 + var1digit;
		}

		/*
		 * Add the appropriate multiple of var2 into the accumulator.
		 *
		 * As above, digits of var2 can be ignored if they don't contribute,
		 * so we only include digits for which i1+i2+2 <= res_ndigits - 1.
		 */
		for (i2 = Min(var2ndigits - 1, res_ndigits - i1 - 3), i = i1 + i2 + 2;
			 i2 >= 0; i2--)
			dig[i--] += var1digit * var2digits[i2];
	}

	/*
	 * Now we do a final carry propagation pass to normalize the result, which
	 * we combine with storing the result digits into the output. Note that
	 * this is still done at full precision w/guard digits.
	 */
	alloc_var(result, res_ndigits);
	res_digits = result->digits;
	carry = 0;
	for (i = res_ndigits - 1; i >= 0; i--)
	{
		newdig = dig[i] + carry;
		if (newdig >= NBASE)
		{
			carry = newdig / NBASE;
			newdig -= carry * NBASE;
		}
		else
			carry = 0;
		res_digits[i] = newdig;
	}
	Assert(carry == 0);

	/* the int accumulator is no longer needed once digits are stored */
	pfree(dig);

	/*
	 * Finally, round the result to the requested precision.
	 */
	result->weight = res_weight;
	result->sign = res_sign;

	/* Round to target rscale (and set result->dscale) */
	round_var(result, rscale);

	/* Strip leading and trailing zeroes */
	strip_var(result);
}
7392 
7393 
7394 /*
7395  * div_var() -
7396  *
7397  *	Division on variable level. Quotient of var1 / var2 is stored in result.
7398  *	The quotient is figured to exactly rscale fractional digits.
7399  *	If round is true, it is rounded at the rscale'th digit; if false, it
7400  *	is truncated (towards zero) at that digit.
7401  */
7402 static void
div_var(const NumericVar * var1,const NumericVar * var2,NumericVar * result,int rscale,bool round)7403 div_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result,
7404 		int rscale, bool round)
7405 {
7406 	int			div_ndigits;
7407 	int			res_ndigits;
7408 	int			res_sign;
7409 	int			res_weight;
7410 	int			carry;
7411 	int			borrow;
7412 	int			divisor1;
7413 	int			divisor2;
7414 	NumericDigit *dividend;
7415 	NumericDigit *divisor;
7416 	NumericDigit *res_digits;
7417 	int			i;
7418 	int			j;
7419 
7420 	/* copy these values into local vars for speed in inner loop */
7421 	int			var1ndigits = var1->ndigits;
7422 	int			var2ndigits = var2->ndigits;
7423 
7424 	/*
7425 	 * First of all division by zero check; we must not be handed an
7426 	 * unnormalized divisor.
7427 	 */
7428 	if (var2ndigits == 0 || var2->digits[0] == 0)
7429 		ereport(ERROR,
7430 				(errcode(ERRCODE_DIVISION_BY_ZERO),
7431 				 errmsg("division by zero")));
7432 
7433 	/*
7434 	 * Now result zero check
7435 	 */
7436 	if (var1ndigits == 0)
7437 	{
7438 		zero_var(result);
7439 		result->dscale = rscale;
7440 		return;
7441 	}
7442 
7443 	/*
7444 	 * Determine the result sign, weight and number of digits to calculate.
7445 	 * The weight figured here is correct if the emitted quotient has no
7446 	 * leading zero digits; otherwise strip_var() will fix things up.
7447 	 */
7448 	if (var1->sign == var2->sign)
7449 		res_sign = NUMERIC_POS;
7450 	else
7451 		res_sign = NUMERIC_NEG;
7452 	res_weight = var1->weight - var2->weight;
7453 	/* The number of accurate result digits we need to produce: */
7454 	res_ndigits = res_weight + 1 + (rscale + DEC_DIGITS - 1) / DEC_DIGITS;
7455 	/* ... but always at least 1 */
7456 	res_ndigits = Max(res_ndigits, 1);
7457 	/* If rounding needed, figure one more digit to ensure correct result */
7458 	if (round)
7459 		res_ndigits++;
7460 
7461 	/*
7462 	 * The working dividend normally requires res_ndigits + var2ndigits
7463 	 * digits, but make it at least var1ndigits so we can load all of var1
7464 	 * into it.  (There will be an additional digit dividend[0] in the
7465 	 * dividend space, but for consistency with Knuth's notation we don't
7466 	 * count that in div_ndigits.)
7467 	 */
7468 	div_ndigits = res_ndigits + var2ndigits;
7469 	div_ndigits = Max(div_ndigits, var1ndigits);
7470 
7471 	/*
7472 	 * We need a workspace with room for the working dividend (div_ndigits+1
7473 	 * digits) plus room for the possibly-normalized divisor (var2ndigits
7474 	 * digits).  It is convenient also to have a zero at divisor[0] with the
7475 	 * actual divisor data in divisor[1 .. var2ndigits].  Transferring the
7476 	 * digits into the workspace also allows us to realloc the result (which
7477 	 * might be the same as either input var) before we begin the main loop.
7478 	 * Note that we use palloc0 to ensure that divisor[0], dividend[0], and
7479 	 * any additional dividend positions beyond var1ndigits, start out 0.
7480 	 */
7481 	dividend = (NumericDigit *)
7482 		palloc0((div_ndigits + var2ndigits + 2) * sizeof(NumericDigit));
7483 	divisor = dividend + (div_ndigits + 1);
7484 	memcpy(dividend + 1, var1->digits, var1ndigits * sizeof(NumericDigit));
7485 	memcpy(divisor + 1, var2->digits, var2ndigits * sizeof(NumericDigit));
7486 
7487 	/*
7488 	 * Now we can realloc the result to hold the generated quotient digits.
7489 	 */
7490 	alloc_var(result, res_ndigits);
7491 	res_digits = result->digits;
7492 
7493 	if (var2ndigits == 1)
7494 	{
7495 		/*
7496 		 * If there's only a single divisor digit, we can use a fast path (cf.
7497 		 * Knuth section 4.3.1 exercise 16).
7498 		 */
7499 		divisor1 = divisor[1];
7500 		carry = 0;
7501 		for (i = 0; i < res_ndigits; i++)
7502 		{
7503 			carry = carry * NBASE + dividend[i + 1];
7504 			res_digits[i] = carry / divisor1;
7505 			carry = carry % divisor1;
7506 		}
7507 	}
7508 	else
7509 	{
7510 		/*
7511 		 * The full multiple-place algorithm is taken from Knuth volume 2,
7512 		 * Algorithm 4.3.1D.
7513 		 *
7514 		 * We need the first divisor digit to be >= NBASE/2.  If it isn't,
7515 		 * make it so by scaling up both the divisor and dividend by the
7516 		 * factor "d".  (The reason for allocating dividend[0] above is to
7517 		 * leave room for possible carry here.)
7518 		 */
7519 		if (divisor[1] < HALF_NBASE)
7520 		{
7521 			int			d = NBASE / (divisor[1] + 1);
7522 
7523 			carry = 0;
7524 			for (i = var2ndigits; i > 0; i--)
7525 			{
7526 				carry += divisor[i] * d;
7527 				divisor[i] = carry % NBASE;
7528 				carry = carry / NBASE;
7529 			}
7530 			Assert(carry == 0);
7531 			carry = 0;
7532 			/* at this point only var1ndigits of dividend can be nonzero */
7533 			for (i = var1ndigits; i >= 0; i--)
7534 			{
7535 				carry += dividend[i] * d;
7536 				dividend[i] = carry % NBASE;
7537 				carry = carry / NBASE;
7538 			}
7539 			Assert(carry == 0);
7540 			Assert(divisor[1] >= HALF_NBASE);
7541 		}
7542 		/* First 2 divisor digits are used repeatedly in main loop */
7543 		divisor1 = divisor[1];
7544 		divisor2 = divisor[2];
7545 
7546 		/*
7547 		 * Begin the main loop.  Each iteration of this loop produces the j'th
7548 		 * quotient digit by dividing dividend[j .. j + var2ndigits] by the
7549 		 * divisor; this is essentially the same as the common manual
7550 		 * procedure for long division.
7551 		 */
7552 		for (j = 0; j < res_ndigits; j++)
7553 		{
7554 			/* Estimate quotient digit from the first two dividend digits */
7555 			int			next2digits = dividend[j] * NBASE + dividend[j + 1];
7556 			int			qhat;
7557 
7558 			/*
7559 			 * If next2digits are 0, then quotient digit must be 0 and there's
7560 			 * no need to adjust the working dividend.  It's worth testing
7561 			 * here to fall out ASAP when processing trailing zeroes in a
7562 			 * dividend.
7563 			 */
7564 			if (next2digits == 0)
7565 			{
7566 				res_digits[j] = 0;
7567 				continue;
7568 			}
7569 
7570 			if (dividend[j] == divisor1)
7571 				qhat = NBASE - 1;
7572 			else
7573 				qhat = next2digits / divisor1;
7574 
7575 			/*
7576 			 * Adjust quotient digit if it's too large.  Knuth proves that
7577 			 * after this step, the quotient digit will be either correct or
7578 			 * just one too large.  (Note: it's OK to use dividend[j+2] here
7579 			 * because we know the divisor length is at least 2.)
7580 			 */
7581 			while (divisor2 * qhat >
7582 				   (next2digits - qhat * divisor1) * NBASE + dividend[j + 2])
7583 				qhat--;
7584 
7585 			/* As above, need do nothing more when quotient digit is 0 */
7586 			if (qhat > 0)
7587 			{
7588 				/*
7589 				 * Multiply the divisor by qhat, and subtract that from the
7590 				 * working dividend.  "carry" tracks the multiplication,
7591 				 * "borrow" the subtraction (could we fold these together?)
7592 				 */
7593 				carry = 0;
7594 				borrow = 0;
7595 				for (i = var2ndigits; i >= 0; i--)
7596 				{
7597 					carry += divisor[i] * qhat;
7598 					borrow -= carry % NBASE;
7599 					carry = carry / NBASE;
7600 					borrow += dividend[j + i];
7601 					if (borrow < 0)
7602 					{
7603 						dividend[j + i] = borrow + NBASE;
7604 						borrow = -1;
7605 					}
7606 					else
7607 					{
7608 						dividend[j + i] = borrow;
7609 						borrow = 0;
7610 					}
7611 				}
7612 				Assert(carry == 0);
7613 
7614 				/*
7615 				 * If we got a borrow out of the top dividend digit, then
7616 				 * indeed qhat was one too large.  Fix it, and add back the
7617 				 * divisor to correct the working dividend.  (Knuth proves
7618 				 * that this will occur only about 3/NBASE of the time; hence,
7619 				 * it's a good idea to test this code with small NBASE to be
7620 				 * sure this section gets exercised.)
7621 				 */
7622 				if (borrow)
7623 				{
7624 					qhat--;
7625 					carry = 0;
7626 					for (i = var2ndigits; i >= 0; i--)
7627 					{
7628 						carry += dividend[j + i] + divisor[i];
7629 						if (carry >= NBASE)
7630 						{
7631 							dividend[j + i] = carry - NBASE;
7632 							carry = 1;
7633 						}
7634 						else
7635 						{
7636 							dividend[j + i] = carry;
7637 							carry = 0;
7638 						}
7639 					}
7640 					/* A carry should occur here to cancel the borrow above */
7641 					Assert(carry == 1);
7642 				}
7643 			}
7644 
7645 			/* And we're done with this quotient digit */
7646 			res_digits[j] = qhat;
7647 		}
7648 	}
7649 
7650 	pfree(dividend);
7651 
7652 	/*
7653 	 * Finally, round or truncate the result to the requested precision.
7654 	 */
7655 	result->weight = res_weight;
7656 	result->sign = res_sign;
7657 
7658 	/* Round or truncate to target rscale (and set result->dscale) */
7659 	if (round)
7660 		round_var(result, rscale);
7661 	else
7662 		trunc_var(result, rscale);
7663 
7664 	/* Strip leading and trailing zeroes */
7665 	strip_var(result);
7666 }
7667 
7668 
7669 /*
7670  * div_var_fast() -
7671  *
7672  *	This has the same API as div_var, but is implemented using the division
7673  *	algorithm from the "FM" library, rather than Knuth's schoolbook-division
7674  *	approach.  This is significantly faster but can produce inaccurate
7675  *	results, because it sometimes has to propagate rounding to the left,
7676  *	and so we can never be entirely sure that we know the requested digits
7677  *	exactly.  We compute DIV_GUARD_DIGITS extra digits, but there is
7678  *	no certainty that that's enough.  We use this only in the transcendental
7679  *	function calculation routines, where everything is approximate anyway.
7680  *
7681  *	Although we provide a "round" argument for consistency with div_var,
7682  *	it is unwise to use this function with round=false.  In truncation mode
7683  *	it is possible to get a result with no significant digits, for example
7684  *	with rscale=0 we might compute 0.99999... and truncate that to 0 when
7685  *	the correct answer is 1.
7686  */
7687 static void
div_var_fast(const NumericVar * var1,const NumericVar * var2,NumericVar * result,int rscale,bool round)7688 div_var_fast(const NumericVar *var1, const NumericVar *var2,
7689 			 NumericVar *result, int rscale, bool round)
7690 {
7691 	int			div_ndigits;
7692 	int			load_ndigits;
7693 	int			res_sign;
7694 	int			res_weight;
7695 	int		   *div;
7696 	int			qdigit;
7697 	int			carry;
7698 	int			maxdiv;
7699 	int			newdig;
7700 	NumericDigit *res_digits;
7701 	double		fdividend,
7702 				fdivisor,
7703 				fdivisorinverse,
7704 				fquotient;
7705 	int			qi;
7706 	int			i;
7707 
7708 	/* copy these values into local vars for speed in inner loop */
7709 	int			var1ndigits = var1->ndigits;
7710 	int			var2ndigits = var2->ndigits;
7711 	NumericDigit *var1digits = var1->digits;
7712 	NumericDigit *var2digits = var2->digits;
7713 
7714 	/*
7715 	 * First of all division by zero check; we must not be handed an
7716 	 * unnormalized divisor.
7717 	 */
7718 	if (var2ndigits == 0 || var2digits[0] == 0)
7719 		ereport(ERROR,
7720 				(errcode(ERRCODE_DIVISION_BY_ZERO),
7721 				 errmsg("division by zero")));
7722 
7723 	/*
7724 	 * Now result zero check
7725 	 */
7726 	if (var1ndigits == 0)
7727 	{
7728 		zero_var(result);
7729 		result->dscale = rscale;
7730 		return;
7731 	}
7732 
7733 	/*
7734 	 * Determine the result sign, weight and number of digits to calculate
7735 	 */
7736 	if (var1->sign == var2->sign)
7737 		res_sign = NUMERIC_POS;
7738 	else
7739 		res_sign = NUMERIC_NEG;
7740 	res_weight = var1->weight - var2->weight + 1;
7741 	/* The number of accurate result digits we need to produce: */
7742 	div_ndigits = res_weight + 1 + (rscale + DEC_DIGITS - 1) / DEC_DIGITS;
7743 	/* Add guard digits for roundoff error */
7744 	div_ndigits += DIV_GUARD_DIGITS;
7745 	if (div_ndigits < DIV_GUARD_DIGITS)
7746 		div_ndigits = DIV_GUARD_DIGITS;
7747 
7748 	/*
7749 	 * We do the arithmetic in an array "div[]" of signed int's.  Since
7750 	 * INT_MAX is noticeably larger than NBASE*NBASE, this gives us headroom
7751 	 * to avoid normalizing carries immediately.
7752 	 *
7753 	 * We start with div[] containing one zero digit followed by the
7754 	 * dividend's digits (plus appended zeroes to reach the desired precision
7755 	 * including guard digits).  Each step of the main loop computes an
7756 	 * (approximate) quotient digit and stores it into div[], removing one
7757 	 * position of dividend space.  A final pass of carry propagation takes
7758 	 * care of any mistaken quotient digits.
7759 	 *
7760 	 * Note that div[] doesn't necessarily contain all of the digits from the
7761 	 * dividend --- the desired precision plus guard digits might be less than
7762 	 * the dividend's precision.  This happens, for example, in the square
7763 	 * root algorithm, where we typically divide a 2N-digit number by an
7764 	 * N-digit number, and only require a result with N digits of precision.
7765 	 */
7766 	div = (int *) palloc0((div_ndigits + 1) * sizeof(int));
7767 	load_ndigits = Min(div_ndigits, var1ndigits);
7768 	for (i = 0; i < load_ndigits; i++)
7769 		div[i + 1] = var1digits[i];
7770 
7771 	/*
7772 	 * We estimate each quotient digit using floating-point arithmetic, taking
7773 	 * the first four digits of the (current) dividend and divisor.  This must
7774 	 * be float to avoid overflow.  The quotient digits will generally be off
7775 	 * by no more than one from the exact answer.
7776 	 */
7777 	fdivisor = (double) var2digits[0];
7778 	for (i = 1; i < 4; i++)
7779 	{
7780 		fdivisor *= NBASE;
7781 		if (i < var2ndigits)
7782 			fdivisor += (double) var2digits[i];
7783 	}
7784 	fdivisorinverse = 1.0 / fdivisor;
7785 
7786 	/*
7787 	 * maxdiv tracks the maximum possible absolute value of any div[] entry;
7788 	 * when this threatens to exceed INT_MAX, we take the time to propagate
7789 	 * carries.  Furthermore, we need to ensure that overflow doesn't occur
7790 	 * during the carry propagation passes either.  The carry values may have
7791 	 * an absolute value as high as INT_MAX/NBASE + 1, so really we must
7792 	 * normalize when digits threaten to exceed INT_MAX - INT_MAX/NBASE - 1.
7793 	 *
7794 	 * To avoid overflow in maxdiv itself, it represents the max absolute
7795 	 * value divided by NBASE-1, ie, at the top of the loop it is known that
7796 	 * no div[] entry has an absolute value exceeding maxdiv * (NBASE-1).
7797 	 *
7798 	 * Actually, though, that holds good only for div[] entries after div[qi];
7799 	 * the adjustment done at the bottom of the loop may cause div[qi + 1] to
7800 	 * exceed the maxdiv limit, so that div[qi] in the next iteration is
7801 	 * beyond the limit.  This does not cause problems, as explained below.
7802 	 */
7803 	maxdiv = 1;
7804 
7805 	/*
7806 	 * Outer loop computes next quotient digit, which will go into div[qi]
7807 	 */
7808 	for (qi = 0; qi < div_ndigits; qi++)
7809 	{
7810 		/* Approximate the current dividend value */
7811 		fdividend = (double) div[qi];
7812 		for (i = 1; i < 4; i++)
7813 		{
7814 			fdividend *= NBASE;
7815 			if (qi + i <= div_ndigits)
7816 				fdividend += (double) div[qi + i];
7817 		}
7818 		/* Compute the (approximate) quotient digit */
7819 		fquotient = fdividend * fdivisorinverse;
7820 		qdigit = (fquotient >= 0.0) ? ((int) fquotient) :
7821 			(((int) fquotient) - 1);	/* truncate towards -infinity */
7822 
7823 		if (qdigit != 0)
7824 		{
7825 			/* Do we need to normalize now? */
7826 			maxdiv += Abs(qdigit);
7827 			if (maxdiv > (INT_MAX - INT_MAX / NBASE - 1) / (NBASE - 1))
7828 			{
7829 				/*
7830 				 * Yes, do it.  Note that if var2ndigits is much smaller than
7831 				 * div_ndigits, we can save a significant amount of effort
7832 				 * here by noting that we only need to normalise those div[]
7833 				 * entries touched where prior iterations subtracted multiples
7834 				 * of the divisor.
7835 				 */
7836 				carry = 0;
7837 				for (i = Min(qi + var2ndigits - 2, div_ndigits); i > qi; i--)
7838 				{
7839 					newdig = div[i] + carry;
7840 					if (newdig < 0)
7841 					{
7842 						carry = -((-newdig - 1) / NBASE) - 1;
7843 						newdig -= carry * NBASE;
7844 					}
7845 					else if (newdig >= NBASE)
7846 					{
7847 						carry = newdig / NBASE;
7848 						newdig -= carry * NBASE;
7849 					}
7850 					else
7851 						carry = 0;
7852 					div[i] = newdig;
7853 				}
7854 				newdig = div[qi] + carry;
7855 				div[qi] = newdig;
7856 
7857 				/*
7858 				 * All the div[] digits except possibly div[qi] are now in the
7859 				 * range 0..NBASE-1.  We do not need to consider div[qi] in
7860 				 * the maxdiv value anymore, so we can reset maxdiv to 1.
7861 				 */
7862 				maxdiv = 1;
7863 
7864 				/*
7865 				 * Recompute the quotient digit since new info may have
7866 				 * propagated into the top four dividend digits
7867 				 */
7868 				fdividend = (double) div[qi];
7869 				for (i = 1; i < 4; i++)
7870 				{
7871 					fdividend *= NBASE;
7872 					if (qi + i <= div_ndigits)
7873 						fdividend += (double) div[qi + i];
7874 				}
7875 				/* Compute the (approximate) quotient digit */
7876 				fquotient = fdividend * fdivisorinverse;
7877 				qdigit = (fquotient >= 0.0) ? ((int) fquotient) :
7878 					(((int) fquotient) - 1);	/* truncate towards -infinity */
7879 				maxdiv += Abs(qdigit);
7880 			}
7881 
7882 			/*
7883 			 * Subtract off the appropriate multiple of the divisor.
7884 			 *
7885 			 * The digits beyond div[qi] cannot overflow, because we know they
7886 			 * will fall within the maxdiv limit.  As for div[qi] itself, note
7887 			 * that qdigit is approximately trunc(div[qi] / vardigits[0]),
7888 			 * which would make the new value simply div[qi] mod vardigits[0].
7889 			 * The lower-order terms in qdigit can change this result by not
7890 			 * more than about twice INT_MAX/NBASE, so overflow is impossible.
7891 			 */
7892 			if (qdigit != 0)
7893 			{
7894 				int			istop = Min(var2ndigits, div_ndigits - qi + 1);
7895 
7896 				for (i = 0; i < istop; i++)
7897 					div[qi + i] -= qdigit * var2digits[i];
7898 			}
7899 		}
7900 
7901 		/*
7902 		 * The dividend digit we are about to replace might still be nonzero.
7903 		 * Fold it into the next digit position.
7904 		 *
7905 		 * There is no risk of overflow here, although proving that requires
7906 		 * some care.  Much as with the argument for div[qi] not overflowing,
7907 		 * if we consider the first two terms in the numerator and denominator
7908 		 * of qdigit, we can see that the final value of div[qi + 1] will be
7909 		 * approximately a remainder mod (vardigits[0]*NBASE + vardigits[1]).
7910 		 * Accounting for the lower-order terms is a bit complicated but ends
7911 		 * up adding not much more than INT_MAX/NBASE to the possible range.
7912 		 * Thus, div[qi + 1] cannot overflow here, and in its role as div[qi]
7913 		 * in the next loop iteration, it can't be large enough to cause
7914 		 * overflow in the carry propagation step (if any), either.
7915 		 *
7916 		 * But having said that: div[qi] can be more than INT_MAX/NBASE, as
7917 		 * noted above, which means that the product div[qi] * NBASE *can*
7918 		 * overflow.  When that happens, adding it to div[qi + 1] will always
7919 		 * cause a canceling overflow so that the end result is correct.  We
7920 		 * could avoid the intermediate overflow by doing the multiplication
7921 		 * and addition in int64 arithmetic, but so far there appears no need.
7922 		 */
7923 		div[qi + 1] += div[qi] * NBASE;
7924 
7925 		div[qi] = qdigit;
7926 	}
7927 
7928 	/*
7929 	 * Approximate and store the last quotient digit (div[div_ndigits])
7930 	 */
7931 	fdividend = (double) div[qi];
7932 	for (i = 1; i < 4; i++)
7933 		fdividend *= NBASE;
7934 	fquotient = fdividend * fdivisorinverse;
7935 	qdigit = (fquotient >= 0.0) ? ((int) fquotient) :
7936 		(((int) fquotient) - 1);	/* truncate towards -infinity */
7937 	div[qi] = qdigit;
7938 
7939 	/*
7940 	 * Because the quotient digits might be off by one, some of them might be
7941 	 * -1 or NBASE at this point.  The represented value is correct in a
7942 	 * mathematical sense, but it doesn't look right.  We do a final carry
7943 	 * propagation pass to normalize the digits, which we combine with storing
7944 	 * the result digits into the output.  Note that this is still done at
7945 	 * full precision w/guard digits.
7946 	 */
7947 	alloc_var(result, div_ndigits + 1);
7948 	res_digits = result->digits;
7949 	carry = 0;
7950 	for (i = div_ndigits; i >= 0; i--)
7951 	{
7952 		newdig = div[i] + carry;
7953 		if (newdig < 0)
7954 		{
7955 			carry = -((-newdig - 1) / NBASE) - 1;
7956 			newdig -= carry * NBASE;
7957 		}
7958 		else if (newdig >= NBASE)
7959 		{
7960 			carry = newdig / NBASE;
7961 			newdig -= carry * NBASE;
7962 		}
7963 		else
7964 			carry = 0;
7965 		res_digits[i] = newdig;
7966 	}
7967 	Assert(carry == 0);
7968 
7969 	pfree(div);
7970 
7971 	/*
7972 	 * Finally, round the result to the requested precision.
7973 	 */
7974 	result->weight = res_weight;
7975 	result->sign = res_sign;
7976 
7977 	/* Round to target rscale (and set result->dscale) */
7978 	if (round)
7979 		round_var(result, rscale);
7980 	else
7981 		trunc_var(result, rscale);
7982 
7983 	/* Strip leading and trailing zeroes */
7984 	strip_var(result);
7985 }
7986 
7987 
7988 /*
7989  * Default scale selection for division
7990  *
7991  * Returns the appropriate result scale for the division result.
7992  */
7993 static int
select_div_scale(const NumericVar * var1,const NumericVar * var2)7994 select_div_scale(const NumericVar *var1, const NumericVar *var2)
7995 {
7996 	int			weight1,
7997 				weight2,
7998 				qweight,
7999 				i;
8000 	NumericDigit firstdigit1,
8001 				firstdigit2;
8002 	int			rscale;
8003 
8004 	/*
8005 	 * The result scale of a division isn't specified in any SQL standard. For
8006 	 * PostgreSQL we select a result scale that will give at least
8007 	 * NUMERIC_MIN_SIG_DIGITS significant digits, so that numeric gives a
8008 	 * result no less accurate than float8; but use a scale not less than
8009 	 * either input's display scale.
8010 	 */
8011 
8012 	/* Get the actual (normalized) weight and first digit of each input */
8013 
8014 	weight1 = 0;				/* values to use if var1 is zero */
8015 	firstdigit1 = 0;
8016 	for (i = 0; i < var1->ndigits; i++)
8017 	{
8018 		firstdigit1 = var1->digits[i];
8019 		if (firstdigit1 != 0)
8020 		{
8021 			weight1 = var1->weight - i;
8022 			break;
8023 		}
8024 	}
8025 
8026 	weight2 = 0;				/* values to use if var2 is zero */
8027 	firstdigit2 = 0;
8028 	for (i = 0; i < var2->ndigits; i++)
8029 	{
8030 		firstdigit2 = var2->digits[i];
8031 		if (firstdigit2 != 0)
8032 		{
8033 			weight2 = var2->weight - i;
8034 			break;
8035 		}
8036 	}
8037 
8038 	/*
8039 	 * Estimate weight of quotient.  If the two first digits are equal, we
8040 	 * can't be sure, but assume that var1 is less than var2.
8041 	 */
8042 	qweight = weight1 - weight2;
8043 	if (firstdigit1 <= firstdigit2)
8044 		qweight--;
8045 
8046 	/* Select result scale */
8047 	rscale = NUMERIC_MIN_SIG_DIGITS - qweight * DEC_DIGITS;
8048 	rscale = Max(rscale, var1->dscale);
8049 	rscale = Max(rscale, var2->dscale);
8050 	rscale = Max(rscale, NUMERIC_MIN_DISPLAY_SCALE);
8051 	rscale = Min(rscale, NUMERIC_MAX_DISPLAY_SCALE);
8052 
8053 	return rscale;
8054 }
8055 
8056 
8057 /*
8058  * mod_var() -
8059  *
8060  *	Calculate the modulo of two numerics at variable level
8061  */
8062 static void
mod_var(const NumericVar * var1,const NumericVar * var2,NumericVar * result)8063 mod_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result)
8064 {
8065 	NumericVar	tmp;
8066 
8067 	init_var(&tmp);
8068 
8069 	/* ---------
8070 	 * We do this using the equation
8071 	 *		mod(x,y) = x - trunc(x/y)*y
8072 	 * div_var can be persuaded to give us trunc(x/y) directly.
8073 	 * ----------
8074 	 */
8075 	div_var(var1, var2, &tmp, 0, false);
8076 
8077 	mul_var(var2, &tmp, &tmp, var2->dscale);
8078 
8079 	sub_var(var1, &tmp, result);
8080 
8081 	free_var(&tmp);
8082 }
8083 
8084 
8085 /*
8086  * div_mod_var() -
8087  *
8088  *	Calculate the truncated integer quotient and numeric remainder of two
8089  *	numeric variables.  The remainder is precise to var2's dscale.
8090  */
8091 static void
div_mod_var(const NumericVar * var1,const NumericVar * var2,NumericVar * quot,NumericVar * rem)8092 div_mod_var(const NumericVar *var1, const NumericVar *var2,
8093 			NumericVar *quot, NumericVar *rem)
8094 {
8095 	NumericVar	q;
8096 	NumericVar	r;
8097 
8098 	init_var(&q);
8099 	init_var(&r);
8100 
8101 	/*
8102 	 * Use div_var_fast() to get an initial estimate for the integer quotient.
8103 	 * This might be inaccurate (per the warning in div_var_fast's comments),
8104 	 * but we can correct it below.
8105 	 */
8106 	div_var_fast(var1, var2, &q, 0, false);
8107 
8108 	/* Compute initial estimate of remainder using the quotient estimate. */
8109 	mul_var(var2, &q, &r, var2->dscale);
8110 	sub_var(var1, &r, &r);
8111 
8112 	/*
8113 	 * Adjust the results if necessary --- the remainder should have the same
8114 	 * sign as var1, and its absolute value should be less than the absolute
8115 	 * value of var2.
8116 	 */
8117 	while (r.ndigits != 0 && r.sign != var1->sign)
8118 	{
8119 		/* The absolute value of the quotient is too large */
8120 		if (var1->sign == var2->sign)
8121 		{
8122 			sub_var(&q, &const_one, &q);
8123 			add_var(&r, var2, &r);
8124 		}
8125 		else
8126 		{
8127 			add_var(&q, &const_one, &q);
8128 			sub_var(&r, var2, &r);
8129 		}
8130 	}
8131 
8132 	while (cmp_abs(&r, var2) >= 0)
8133 	{
8134 		/* The absolute value of the quotient is too small */
8135 		if (var1->sign == var2->sign)
8136 		{
8137 			add_var(&q, &const_one, &q);
8138 			sub_var(&r, var2, &r);
8139 		}
8140 		else
8141 		{
8142 			sub_var(&q, &const_one, &q);
8143 			add_var(&r, var2, &r);
8144 		}
8145 	}
8146 
8147 	set_var_from_var(&q, quot);
8148 	set_var_from_var(&r, rem);
8149 
8150 	free_var(&q);
8151 	free_var(&r);
8152 }
8153 
8154 
8155 /*
8156  * ceil_var() -
8157  *
8158  *	Return the smallest integer greater than or equal to the argument
8159  *	on variable level
8160  */
8161 static void
ceil_var(const NumericVar * var,NumericVar * result)8162 ceil_var(const NumericVar *var, NumericVar *result)
8163 {
8164 	NumericVar	tmp;
8165 
8166 	init_var(&tmp);
8167 	set_var_from_var(var, &tmp);
8168 
8169 	trunc_var(&tmp, 0);
8170 
8171 	if (var->sign == NUMERIC_POS && cmp_var(var, &tmp) != 0)
8172 		add_var(&tmp, &const_one, &tmp);
8173 
8174 	set_var_from_var(&tmp, result);
8175 	free_var(&tmp);
8176 }
8177 
8178 
8179 /*
8180  * floor_var() -
8181  *
8182  *	Return the largest integer equal to or less than the argument
8183  *	on variable level
8184  */
8185 static void
floor_var(const NumericVar * var,NumericVar * result)8186 floor_var(const NumericVar *var, NumericVar *result)
8187 {
8188 	NumericVar	tmp;
8189 
8190 	init_var(&tmp);
8191 	set_var_from_var(var, &tmp);
8192 
8193 	trunc_var(&tmp, 0);
8194 
8195 	if (var->sign == NUMERIC_NEG && cmp_var(var, &tmp) != 0)
8196 		sub_var(&tmp, &const_one, &tmp);
8197 
8198 	set_var_from_var(&tmp, result);
8199 	free_var(&tmp);
8200 }
8201 
8202 
8203 /*
8204  * gcd_var() -
8205  *
8206  *	Calculate the greatest common divisor of two numerics at variable level
8207  */
8208 static void
gcd_var(const NumericVar * var1,const NumericVar * var2,NumericVar * result)8209 gcd_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result)
8210 {
8211 	int			res_dscale;
8212 	int			cmp;
8213 	NumericVar	tmp_arg;
8214 	NumericVar	mod;
8215 
8216 	res_dscale = Max(var1->dscale, var2->dscale);
8217 
8218 	/*
8219 	 * Arrange for var1 to be the number with the greater absolute value.
8220 	 *
8221 	 * This would happen automatically in the loop below, but avoids an
8222 	 * expensive modulo operation.
8223 	 */
8224 	cmp = cmp_abs(var1, var2);
8225 	if (cmp < 0)
8226 	{
8227 		const NumericVar *tmp = var1;
8228 
8229 		var1 = var2;
8230 		var2 = tmp;
8231 	}
8232 
8233 	/*
8234 	 * Also avoid the taking the modulo if the inputs have the same absolute
8235 	 * value, or if the smaller input is zero.
8236 	 */
8237 	if (cmp == 0 || var2->ndigits == 0)
8238 	{
8239 		set_var_from_var(var1, result);
8240 		result->sign = NUMERIC_POS;
8241 		result->dscale = res_dscale;
8242 		return;
8243 	}
8244 
8245 	init_var(&tmp_arg);
8246 	init_var(&mod);
8247 
8248 	/* Use the Euclidean algorithm to find the GCD */
8249 	set_var_from_var(var1, &tmp_arg);
8250 	set_var_from_var(var2, result);
8251 
8252 	for (;;)
8253 	{
8254 		/* this loop can take a while, so allow it to be interrupted */
8255 		CHECK_FOR_INTERRUPTS();
8256 
8257 		mod_var(&tmp_arg, result, &mod);
8258 		if (mod.ndigits == 0)
8259 			break;
8260 		set_var_from_var(result, &tmp_arg);
8261 		set_var_from_var(&mod, result);
8262 	}
8263 	result->sign = NUMERIC_POS;
8264 	result->dscale = res_dscale;
8265 
8266 	free_var(&tmp_arg);
8267 	free_var(&mod);
8268 }
8269 
8270 
8271 /*
8272  * sqrt_var() -
8273  *
8274  *	Compute the square root of x using the Karatsuba Square Root algorithm.
8275  *	NOTE: we allow rscale < 0 here, implying rounding before the decimal
8276  *	point.
8277  */
8278 static void
sqrt_var(const NumericVar * arg,NumericVar * result,int rscale)8279 sqrt_var(const NumericVar *arg, NumericVar *result, int rscale)
8280 {
8281 	int			stat;
8282 	int			res_weight;
8283 	int			res_ndigits;
8284 	int			src_ndigits;
8285 	int			step;
8286 	int			ndigits[32];
8287 	int			blen;
8288 	int64		arg_int64;
8289 	int			src_idx;
8290 	int64		s_int64;
8291 	int64		r_int64;
8292 	NumericVar	s_var;
8293 	NumericVar	r_var;
8294 	NumericVar	a0_var;
8295 	NumericVar	a1_var;
8296 	NumericVar	q_var;
8297 	NumericVar	u_var;
8298 
8299 	stat = cmp_var(arg, &const_zero);
8300 	if (stat == 0)
8301 	{
8302 		zero_var(result);
8303 		result->dscale = rscale;
8304 		return;
8305 	}
8306 
8307 	/*
8308 	 * SQL2003 defines sqrt() in terms of power, so we need to emit the right
8309 	 * SQLSTATE error code if the operand is negative.
8310 	 */
8311 	if (stat < 0)
8312 		ereport(ERROR,
8313 				(errcode(ERRCODE_INVALID_ARGUMENT_FOR_POWER_FUNCTION),
8314 				 errmsg("cannot take square root of a negative number")));
8315 
8316 	init_var(&s_var);
8317 	init_var(&r_var);
8318 	init_var(&a0_var);
8319 	init_var(&a1_var);
8320 	init_var(&q_var);
8321 	init_var(&u_var);
8322 
8323 	/*
8324 	 * The result weight is half the input weight, rounded towards minus
8325 	 * infinity --- res_weight = floor(arg->weight / 2).
8326 	 */
8327 	if (arg->weight >= 0)
8328 		res_weight = arg->weight / 2;
8329 	else
8330 		res_weight = -((-arg->weight - 1) / 2 + 1);
8331 
8332 	/*
8333 	 * Number of NBASE digits to compute.  To ensure correct rounding, compute
8334 	 * at least 1 extra decimal digit.  We explicitly allow rscale to be
8335 	 * negative here, but must always compute at least 1 NBASE digit.  Thus
8336 	 * res_ndigits = res_weight + 1 + ceil((rscale + 1) / DEC_DIGITS) or 1.
8337 	 */
8338 	if (rscale + 1 >= 0)
8339 		res_ndigits = res_weight + 1 + (rscale + DEC_DIGITS) / DEC_DIGITS;
8340 	else
8341 		res_ndigits = res_weight + 1 - (-rscale - 1) / DEC_DIGITS;
8342 	res_ndigits = Max(res_ndigits, 1);
8343 
8344 	/*
8345 	 * Number of source NBASE digits logically required to produce a result
8346 	 * with this precision --- every digit before the decimal point, plus 2
8347 	 * for each result digit after the decimal point (or minus 2 for each
8348 	 * result digit we round before the decimal point).
8349 	 */
8350 	src_ndigits = arg->weight + 1 + (res_ndigits - res_weight - 1) * 2;
8351 	src_ndigits = Max(src_ndigits, 1);
8352 
8353 	/* ----------
8354 	 * From this point on, we treat the input and the result as integers and
8355 	 * compute the integer square root and remainder using the Karatsuba
8356 	 * Square Root algorithm, which may be written recursively as follows:
8357 	 *
8358 	 *	SqrtRem(n = a3*b^3 + a2*b^2 + a1*b + a0):
8359 	 *		[ for some base b, and coefficients a0,a1,a2,a3 chosen so that
8360 	 *		  0 <= a0,a1,a2 < b and a3 >= b/4 ]
8361 	 *		Let (s,r) = SqrtRem(a3*b + a2)
8362 	 *		Let (q,u) = DivRem(r*b + a1, 2*s)
8363 	 *		Let s = s*b + q
8364 	 *		Let r = u*b + a0 - q^2
8365 	 *		If r < 0 Then
8366 	 *			Let r = r + s
8367 	 *			Let s = s - 1
8368 	 *			Let r = r + s
8369 	 *		Return (s,r)
8370 	 *
8371 	 * See "Karatsuba Square Root", Paul Zimmermann, INRIA Research Report
8372 	 * RR-3805, November 1999.  At the time of writing this was available
8373 	 * on the net at <https://hal.inria.fr/inria-00072854>.
8374 	 *
8375 	 * The way to read the assumption "n = a3*b^3 + a2*b^2 + a1*b + a0" is
8376 	 * "choose a base b such that n requires at least four base-b digits to
8377 	 * express; then those digits are a3,a2,a1,a0, with a3 possibly larger
8378 	 * than b".  For optimal performance, b should have approximately a
8379 	 * quarter the number of digits in the input, so that the outer square
8380 	 * root computes roughly twice as many digits as the inner one.  For
8381 	 * simplicity, we choose b = NBASE^blen, an integer power of NBASE.
8382 	 *
8383 	 * We implement the algorithm iteratively rather than recursively, to
8384 	 * allow the working variables to be reused.  With this approach, each
8385 	 * digit of the input is read precisely once --- src_idx tracks the number
8386 	 * of input digits used so far.
8387 	 *
8388 	 * The array ndigits[] holds the number of NBASE digits of the input that
8389 	 * will have been used at the end of each iteration, which roughly doubles
8390 	 * each time.  Note that the array elements are stored in reverse order,
8391 	 * so if the final iteration requires src_ndigits = 37 input digits, the
8392 	 * array will contain [37,19,11,7,5,3], and we would start by computing
8393 	 * the square root of the 3 most significant NBASE digits.
8394 	 *
8395 	 * In each iteration, we choose blen to be the largest integer for which
8396 	 * the input number has a3 >= b/4, when written in the form above.  In
8397 	 * general, this means blen = src_ndigits / 4 (truncated), but if
8398 	 * src_ndigits is a multiple of 4, that might lead to the coefficient a3
8399 	 * being less than b/4 (if the first input digit is less than NBASE/4), in
8400 	 * which case we choose blen = src_ndigits / 4 - 1.  The number of digits
8401 	 * in the inner square root is then src_ndigits - 2*blen.  So, for
8402 	 * example, if we have src_ndigits = 26 initially, the array ndigits[]
8403 	 * will be either [26,14,8,4] or [26,14,8,6,4], depending on the size of
8404 	 * the first input digit.
8405 	 *
8406 	 * Additionally, we can put an upper bound on the number of steps required
8407 	 * as follows --- suppose that the number of source digits is an n-bit
8408 	 * number in the range [2^(n-1), 2^n-1], then blen will be in the range
8409 	 * [2^(n-3)-1, 2^(n-2)-1] and the number of digits in the inner square
8410 	 * root will be in the range [2^(n-2), 2^(n-1)+1].  In the next step, blen
8411 	 * will be in the range [2^(n-4)-1, 2^(n-3)] and the number of digits in
8412 	 * the next inner square root will be in the range [2^(n-3), 2^(n-2)+1].
8413 	 * This pattern repeats, and in the worst case the array ndigits[] will
8414 	 * contain [2^n-1, 2^(n-1)+1, 2^(n-2)+1, ... 9, 5, 3], and the computation
8415 	 * will require n steps.  Therefore, since all digit array sizes are
8416 	 * signed 32-bit integers, the number of steps required is guaranteed to
8417 	 * be less than 32.
8418 	 * ----------
8419 	 */
8420 	step = 0;
8421 	while ((ndigits[step] = src_ndigits) > 4)
8422 	{
8423 		/* Choose b so that a3 >= b/4, as described above */
8424 		blen = src_ndigits / 4;
8425 		if (blen * 4 == src_ndigits && arg->digits[0] < NBASE / 4)
8426 			blen--;
8427 
8428 		/* Number of digits in the next step (inner square root) */
8429 		src_ndigits -= 2 * blen;
8430 		step++;
8431 	}
8432 
8433 	/*
8434 	 * First iteration (innermost square root and remainder):
8435 	 *
8436 	 * Here src_ndigits <= 4, and the input fits in an int64.  Its square root
8437 	 * has at most 9 decimal digits, so estimate it using double precision
8438 	 * arithmetic, which will in fact almost certainly return the correct
8439 	 * result with no further correction required.
8440 	 */
8441 	arg_int64 = arg->digits[0];
8442 	for (src_idx = 1; src_idx < src_ndigits; src_idx++)
8443 	{
8444 		arg_int64 *= NBASE;
8445 		if (src_idx < arg->ndigits)
8446 			arg_int64 += arg->digits[src_idx];
8447 	}
8448 
8449 	s_int64 = (int64) sqrt((double) arg_int64);
8450 	r_int64 = arg_int64 - s_int64 * s_int64;
8451 
8452 	/*
8453 	 * Use Newton's method to correct the result, if necessary.
8454 	 *
8455 	 * This uses integer division with truncation to compute the truncated
8456 	 * integer square root by iterating using the formula x -> (x + n/x) / 2.
8457 	 * This is known to converge to isqrt(n), unless n+1 is a perfect square.
8458 	 * If n+1 is a perfect square, the sequence will oscillate between the two
8459 	 * values isqrt(n) and isqrt(n)+1, so we can be assured of convergence by
8460 	 * checking the remainder.
8461 	 */
8462 	while (r_int64 < 0 || r_int64 > 2 * s_int64)
8463 	{
8464 		s_int64 = (s_int64 + arg_int64 / s_int64) / 2;
8465 		r_int64 = arg_int64 - s_int64 * s_int64;
8466 	}
8467 
8468 	/*
8469 	 * Iterations with src_ndigits <= 8:
8470 	 *
8471 	 * The next 1 or 2 iterations compute larger (outer) square roots with
8472 	 * src_ndigits <= 8, so the result still fits in an int64 (even though the
8473 	 * input no longer does) and we can continue to compute using int64
8474 	 * variables to avoid more expensive numeric computations.
8475 	 *
8476 	 * It is fairly easy to see that there is no risk of the intermediate
8477 	 * values below overflowing 64-bit integers.  In the worst case, the
8478 	 * previous iteration will have computed a 3-digit square root (of a
8479 	 * 6-digit input less than NBASE^6 / 4), so at the start of this
8480 	 * iteration, s will be less than NBASE^3 / 2 = 10^12 / 2, and r will be
8481 	 * less than 10^12.  In this case, blen will be 1, so numer will be less
8482 	 * than 10^17, and denom will be less than 10^12 (and hence u will also be
8483 	 * less than 10^12).  Finally, since q^2 = u*b + a0 - r, we can also be
8484 	 * sure that q^2 < 10^17.  Therefore all these quantities fit comfortably
8485 	 * in 64-bit integers.
8486 	 */
8487 	step--;
8488 	while (step >= 0 && (src_ndigits = ndigits[step]) <= 8)
8489 	{
8490 		int			b;
8491 		int			a0;
8492 		int			a1;
8493 		int			i;
8494 		int64		numer;
8495 		int64		denom;
8496 		int64		q;
8497 		int64		u;
8498 
8499 		blen = (src_ndigits - src_idx) / 2;
8500 
8501 		/* Extract a1 and a0, and compute b */
8502 		a0 = 0;
8503 		a1 = 0;
8504 		b = 1;
8505 
8506 		for (i = 0; i < blen; i++, src_idx++)
8507 		{
8508 			b *= NBASE;
8509 			a1 *= NBASE;
8510 			if (src_idx < arg->ndigits)
8511 				a1 += arg->digits[src_idx];
8512 		}
8513 
8514 		for (i = 0; i < blen; i++, src_idx++)
8515 		{
8516 			a0 *= NBASE;
8517 			if (src_idx < arg->ndigits)
8518 				a0 += arg->digits[src_idx];
8519 		}
8520 
8521 		/* Compute (q,u) = DivRem(r*b + a1, 2*s) */
8522 		numer = r_int64 * b + a1;
8523 		denom = 2 * s_int64;
8524 		q = numer / denom;
8525 		u = numer - q * denom;
8526 
8527 		/* Compute s = s*b + q and r = u*b + a0 - q^2 */
8528 		s_int64 = s_int64 * b + q;
8529 		r_int64 = u * b + a0 - q * q;
8530 
8531 		if (r_int64 < 0)
8532 		{
8533 			/* s is too large by 1; set r += s, s--, r += s */
8534 			r_int64 += s_int64;
8535 			s_int64--;
8536 			r_int64 += s_int64;
8537 		}
8538 
8539 		Assert(src_idx == src_ndigits); /* All input digits consumed */
8540 		step--;
8541 	}
8542 
8543 	/*
8544 	 * On platforms with 128-bit integer support, we can further delay the
8545 	 * need to use numeric variables.
8546 	 */
8547 #ifdef HAVE_INT128
8548 	if (step >= 0)
8549 	{
8550 		int128		s_int128;
8551 		int128		r_int128;
8552 
8553 		s_int128 = s_int64;
8554 		r_int128 = r_int64;
8555 
8556 		/*
8557 		 * Iterations with src_ndigits <= 16:
8558 		 *
8559 		 * The result fits in an int128 (even though the input doesn't) so we
8560 		 * use int128 variables to avoid more expensive numeric computations.
8561 		 */
8562 		while (step >= 0 && (src_ndigits = ndigits[step]) <= 16)
8563 		{
8564 			int64		b;
8565 			int64		a0;
8566 			int64		a1;
8567 			int64		i;
8568 			int128		numer;
8569 			int128		denom;
8570 			int128		q;
8571 			int128		u;
8572 
8573 			blen = (src_ndigits - src_idx) / 2;
8574 
8575 			/* Extract a1 and a0, and compute b */
8576 			a0 = 0;
8577 			a1 = 0;
8578 			b = 1;
8579 
8580 			for (i = 0; i < blen; i++, src_idx++)
8581 			{
8582 				b *= NBASE;
8583 				a1 *= NBASE;
8584 				if (src_idx < arg->ndigits)
8585 					a1 += arg->digits[src_idx];
8586 			}
8587 
8588 			for (i = 0; i < blen; i++, src_idx++)
8589 			{
8590 				a0 *= NBASE;
8591 				if (src_idx < arg->ndigits)
8592 					a0 += arg->digits[src_idx];
8593 			}
8594 
8595 			/* Compute (q,u) = DivRem(r*b + a1, 2*s) */
8596 			numer = r_int128 * b + a1;
8597 			denom = 2 * s_int128;
8598 			q = numer / denom;
8599 			u = numer - q * denom;
8600 
8601 			/* Compute s = s*b + q and r = u*b + a0 - q^2 */
8602 			s_int128 = s_int128 * b + q;
8603 			r_int128 = u * b + a0 - q * q;
8604 
8605 			if (r_int128 < 0)
8606 			{
8607 				/* s is too large by 1; set r += s, s--, r += s */
8608 				r_int128 += s_int128;
8609 				s_int128--;
8610 				r_int128 += s_int128;
8611 			}
8612 
8613 			Assert(src_idx == src_ndigits); /* All input digits consumed */
8614 			step--;
8615 		}
8616 
8617 		/*
8618 		 * All remaining iterations require numeric variables.  Convert the
8619 		 * integer values to NumericVar and continue.  Note that in the final
8620 		 * iteration we don't need the remainder, so we can save a few cycles
8621 		 * there by not fully computing it.
8622 		 */
8623 		int128_to_numericvar(s_int128, &s_var);
8624 		if (step >= 0)
8625 			int128_to_numericvar(r_int128, &r_var);
8626 	}
8627 	else
8628 	{
8629 		int64_to_numericvar(s_int64, &s_var);
8630 		/* step < 0, so we certainly don't need r */
8631 	}
8632 #else							/* !HAVE_INT128 */
8633 	int64_to_numericvar(s_int64, &s_var);
8634 	if (step >= 0)
8635 		int64_to_numericvar(r_int64, &r_var);
8636 #endif							/* HAVE_INT128 */
8637 
8638 	/*
8639 	 * The remaining iterations with src_ndigits > 8 (or 16, if have int128)
8640 	 * use numeric variables.
8641 	 */
8642 	while (step >= 0)
8643 	{
8644 		int			tmp_len;
8645 
8646 		src_ndigits = ndigits[step];
8647 		blen = (src_ndigits - src_idx) / 2;
8648 
8649 		/* Extract a1 and a0 */
8650 		if (src_idx < arg->ndigits)
8651 		{
8652 			tmp_len = Min(blen, arg->ndigits - src_idx);
8653 			alloc_var(&a1_var, tmp_len);
8654 			memcpy(a1_var.digits, arg->digits + src_idx,
8655 				   tmp_len * sizeof(NumericDigit));
8656 			a1_var.weight = blen - 1;
8657 			a1_var.sign = NUMERIC_POS;
8658 			a1_var.dscale = 0;
8659 			strip_var(&a1_var);
8660 		}
8661 		else
8662 		{
8663 			zero_var(&a1_var);
8664 			a1_var.dscale = 0;
8665 		}
8666 		src_idx += blen;
8667 
8668 		if (src_idx < arg->ndigits)
8669 		{
8670 			tmp_len = Min(blen, arg->ndigits - src_idx);
8671 			alloc_var(&a0_var, tmp_len);
8672 			memcpy(a0_var.digits, arg->digits + src_idx,
8673 				   tmp_len * sizeof(NumericDigit));
8674 			a0_var.weight = blen - 1;
8675 			a0_var.sign = NUMERIC_POS;
8676 			a0_var.dscale = 0;
8677 			strip_var(&a0_var);
8678 		}
8679 		else
8680 		{
8681 			zero_var(&a0_var);
8682 			a0_var.dscale = 0;
8683 		}
8684 		src_idx += blen;
8685 
8686 		/* Compute (q,u) = DivRem(r*b + a1, 2*s) */
8687 		set_var_from_var(&r_var, &q_var);
8688 		q_var.weight += blen;
8689 		add_var(&q_var, &a1_var, &q_var);
8690 		add_var(&s_var, &s_var, &u_var);
8691 		div_mod_var(&q_var, &u_var, &q_var, &u_var);
8692 
8693 		/* Compute s = s*b + q */
8694 		s_var.weight += blen;
8695 		add_var(&s_var, &q_var, &s_var);
8696 
8697 		/*
8698 		 * Compute r = u*b + a0 - q^2.
8699 		 *
8700 		 * In the final iteration, we don't actually need r; we just need to
8701 		 * know whether it is negative, so that we know whether to adjust s.
8702 		 * So instead of the final subtraction we can just compare.
8703 		 */
8704 		u_var.weight += blen;
8705 		add_var(&u_var, &a0_var, &u_var);
8706 		mul_var(&q_var, &q_var, &q_var, 0);
8707 
8708 		if (step > 0)
8709 		{
8710 			/* Need r for later iterations */
8711 			sub_var(&u_var, &q_var, &r_var);
8712 			if (r_var.sign == NUMERIC_NEG)
8713 			{
8714 				/* s is too large by 1; set r += s, s--, r += s */
8715 				add_var(&r_var, &s_var, &r_var);
8716 				sub_var(&s_var, &const_one, &s_var);
8717 				add_var(&r_var, &s_var, &r_var);
8718 			}
8719 		}
8720 		else
8721 		{
8722 			/* Don't need r anymore, except to test if s is too large by 1 */
8723 			if (cmp_var(&u_var, &q_var) < 0)
8724 				sub_var(&s_var, &const_one, &s_var);
8725 		}
8726 
8727 		Assert(src_idx == src_ndigits); /* All input digits consumed */
8728 		step--;
8729 	}
8730 
8731 	/*
8732 	 * Construct the final result, rounding it to the requested precision.
8733 	 */
8734 	set_var_from_var(&s_var, result);
8735 	result->weight = res_weight;
8736 	result->sign = NUMERIC_POS;
8737 
8738 	/* Round to target rscale (and set result->dscale) */
8739 	round_var(result, rscale);
8740 
8741 	/* Strip leading and trailing zeroes */
8742 	strip_var(result);
8743 
8744 	free_var(&s_var);
8745 	free_var(&r_var);
8746 	free_var(&a0_var);
8747 	free_var(&a1_var);
8748 	free_var(&q_var);
8749 	free_var(&u_var);
8750 }
8751 
8752 
8753 /*
8754  * exp_var() -
8755  *
8756  *	Raise e to the power of x, computed to rscale fractional digits
8757  */
8758 static void
exp_var(const NumericVar * arg,NumericVar * result,int rscale)8759 exp_var(const NumericVar *arg, NumericVar *result, int rscale)
8760 {
8761 	NumericVar	x;
8762 	NumericVar	elem;
8763 	NumericVar	ni;
8764 	double		val;
8765 	int			dweight;
8766 	int			ndiv2;
8767 	int			sig_digits;
8768 	int			local_rscale;
8769 
8770 	init_var(&x);
8771 	init_var(&elem);
8772 	init_var(&ni);
8773 
8774 	set_var_from_var(arg, &x);
8775 
8776 	/*
8777 	 * Estimate the dweight of the result using floating point arithmetic, so
8778 	 * that we can choose an appropriate local rscale for the calculation.
8779 	 */
8780 	val = numericvar_to_double_no_overflow(&x);
8781 
8782 	/* Guard against overflow/underflow */
8783 	/* If you change this limit, see also power_var()'s limit */
8784 	if (Abs(val) >= NUMERIC_MAX_RESULT_SCALE * 3)
8785 	{
8786 		if (val > 0)
8787 			ereport(ERROR,
8788 					(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
8789 					 errmsg("value overflows numeric format")));
8790 		zero_var(result);
8791 		result->dscale = rscale;
8792 		return;
8793 	}
8794 
8795 	/* decimal weight = log10(e^x) = x * log10(e) */
8796 	dweight = (int) (val * 0.434294481903252);
8797 
8798 	/*
8799 	 * Reduce x to the range -0.01 <= x <= 0.01 (approximately) by dividing by
8800 	 * 2^n, to improve the convergence rate of the Taylor series.
8801 	 */
8802 	if (Abs(val) > 0.01)
8803 	{
8804 		NumericVar	tmp;
8805 
8806 		init_var(&tmp);
8807 		set_var_from_var(&const_two, &tmp);
8808 
8809 		ndiv2 = 1;
8810 		val /= 2;
8811 
8812 		while (Abs(val) > 0.01)
8813 		{
8814 			ndiv2++;
8815 			val /= 2;
8816 			add_var(&tmp, &tmp, &tmp);
8817 		}
8818 
8819 		local_rscale = x.dscale + ndiv2;
8820 		div_var_fast(&x, &tmp, &x, local_rscale, true);
8821 
8822 		free_var(&tmp);
8823 	}
8824 	else
8825 		ndiv2 = 0;
8826 
8827 	/*
8828 	 * Set the scale for the Taylor series expansion.  The final result has
8829 	 * (dweight + rscale + 1) significant digits.  In addition, we have to
8830 	 * raise the Taylor series result to the power 2^ndiv2, which introduces
8831 	 * an error of up to around log10(2^ndiv2) digits, so work with this many
8832 	 * extra digits of precision (plus a few more for good measure).
8833 	 */
8834 	sig_digits = 1 + dweight + rscale + (int) (ndiv2 * 0.301029995663981);
8835 	sig_digits = Max(sig_digits, 0) + 8;
8836 
8837 	local_rscale = sig_digits - 1;
8838 
8839 	/*
8840 	 * Use the Taylor series
8841 	 *
8842 	 * exp(x) = 1 + x + x^2/2! + x^3/3! + ...
8843 	 *
8844 	 * Given the limited range of x, this should converge reasonably quickly.
8845 	 * We run the series until the terms fall below the local_rscale limit.
8846 	 */
8847 	add_var(&const_one, &x, result);
8848 
8849 	mul_var(&x, &x, &elem, local_rscale);
8850 	set_var_from_var(&const_two, &ni);
8851 	div_var_fast(&elem, &ni, &elem, local_rscale, true);
8852 
8853 	while (elem.ndigits != 0)
8854 	{
8855 		add_var(result, &elem, result);
8856 
8857 		mul_var(&elem, &x, &elem, local_rscale);
8858 		add_var(&ni, &const_one, &ni);
8859 		div_var_fast(&elem, &ni, &elem, local_rscale, true);
8860 	}
8861 
8862 	/*
8863 	 * Compensate for the argument range reduction.  Since the weight of the
8864 	 * result doubles with each multiplication, we can reduce the local rscale
8865 	 * as we proceed.
8866 	 */
8867 	while (ndiv2-- > 0)
8868 	{
8869 		local_rscale = sig_digits - result->weight * 2 * DEC_DIGITS;
8870 		local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
8871 		mul_var(result, result, result, local_rscale);
8872 	}
8873 
8874 	/* Round to requested rscale */
8875 	round_var(result, rscale);
8876 
8877 	free_var(&x);
8878 	free_var(&elem);
8879 	free_var(&ni);
8880 }
8881 
8882 
8883 /*
8884  * Estimate the dweight of the most significant decimal digit of the natural
8885  * logarithm of a number.
8886  *
8887  * Essentially, we're approximating log10(abs(ln(var))).  This is used to
8888  * determine the appropriate rscale when computing natural logarithms.
8889  */
8890 static int
estimate_ln_dweight(const NumericVar * var)8891 estimate_ln_dweight(const NumericVar *var)
8892 {
8893 	int			ln_dweight;
8894 
8895 	if (cmp_var(var, &const_zero_point_nine) >= 0 &&
8896 		cmp_var(var, &const_one_point_one) <= 0)
8897 	{
8898 		/*
8899 		 * 0.9 <= var <= 1.1
8900 		 *
8901 		 * ln(var) has a negative weight (possibly very large).  To get a
8902 		 * reasonably accurate result, estimate it using ln(1+x) ~= x.
8903 		 */
8904 		NumericVar	x;
8905 
8906 		init_var(&x);
8907 		sub_var(var, &const_one, &x);
8908 
8909 		if (x.ndigits > 0)
8910 		{
8911 			/* Use weight of most significant decimal digit of x */
8912 			ln_dweight = x.weight * DEC_DIGITS + (int) log10(x.digits[0]);
8913 		}
8914 		else
8915 		{
8916 			/* x = 0.  Since ln(1) = 0 exactly, we don't need extra digits */
8917 			ln_dweight = 0;
8918 		}
8919 
8920 		free_var(&x);
8921 	}
8922 	else
8923 	{
8924 		/*
8925 		 * Estimate the logarithm using the first couple of digits from the
8926 		 * input number.  This will give an accurate result whenever the input
8927 		 * is not too close to 1.
8928 		 */
8929 		if (var->ndigits > 0)
8930 		{
8931 			int			digits;
8932 			int			dweight;
8933 			double		ln_var;
8934 
8935 			digits = var->digits[0];
8936 			dweight = var->weight * DEC_DIGITS;
8937 
8938 			if (var->ndigits > 1)
8939 			{
8940 				digits = digits * NBASE + var->digits[1];
8941 				dweight -= DEC_DIGITS;
8942 			}
8943 
8944 			/*----------
8945 			 * We have var ~= digits * 10^dweight
8946 			 * so ln(var) ~= ln(digits) + dweight * ln(10)
8947 			 *----------
8948 			 */
8949 			ln_var = log((double) digits) + dweight * 2.302585092994046;
8950 			ln_dweight = (int) log10(Abs(ln_var));
8951 		}
8952 		else
8953 		{
8954 			/* Caller should fail on ln(0), but for the moment return zero */
8955 			ln_dweight = 0;
8956 		}
8957 	}
8958 
8959 	return ln_dweight;
8960 }
8961 
8962 
8963 /*
8964  * ln_var() -
8965  *
8966  *	Compute the natural log of x
8967  */
8968 static void
ln_var(const NumericVar * arg,NumericVar * result,int rscale)8969 ln_var(const NumericVar *arg, NumericVar *result, int rscale)
8970 {
8971 	NumericVar	x;
8972 	NumericVar	xx;
8973 	NumericVar	ni;
8974 	NumericVar	elem;
8975 	NumericVar	fact;
8976 	int			nsqrt;
8977 	int			local_rscale;
8978 	int			cmp;
8979 
8980 	cmp = cmp_var(arg, &const_zero);
8981 	if (cmp == 0)
8982 		ereport(ERROR,
8983 				(errcode(ERRCODE_INVALID_ARGUMENT_FOR_LOG),
8984 				 errmsg("cannot take logarithm of zero")));
8985 	else if (cmp < 0)
8986 		ereport(ERROR,
8987 				(errcode(ERRCODE_INVALID_ARGUMENT_FOR_LOG),
8988 				 errmsg("cannot take logarithm of a negative number")));
8989 
8990 	init_var(&x);
8991 	init_var(&xx);
8992 	init_var(&ni);
8993 	init_var(&elem);
8994 	init_var(&fact);
8995 
8996 	set_var_from_var(arg, &x);
8997 	set_var_from_var(&const_two, &fact);
8998 
8999 	/*
9000 	 * Reduce input into range 0.9 < x < 1.1 with repeated sqrt() operations.
9001 	 *
9002 	 * The final logarithm will have up to around rscale+6 significant digits.
9003 	 * Each sqrt() will roughly halve the weight of x, so adjust the local
9004 	 * rscale as we work so that we keep this many significant digits at each
9005 	 * step (plus a few more for good measure).
9006 	 *
9007 	 * Note that we allow local_rscale < 0 during this input reduction
9008 	 * process, which implies rounding before the decimal point.  sqrt_var()
9009 	 * explicitly supports this, and it significantly reduces the work
9010 	 * required to reduce very large inputs to the required range.  Once the
9011 	 * input reduction is complete, x.weight will be 0 and its display scale
9012 	 * will be non-negative again.
9013 	 */
9014 	nsqrt = 0;
9015 	while (cmp_var(&x, &const_zero_point_nine) <= 0)
9016 	{
9017 		local_rscale = rscale - x.weight * DEC_DIGITS / 2 + 8;
9018 		sqrt_var(&x, &x, local_rscale);
9019 		mul_var(&fact, &const_two, &fact, 0);
9020 		nsqrt++;
9021 	}
9022 	while (cmp_var(&x, &const_one_point_one) >= 0)
9023 	{
9024 		local_rscale = rscale - x.weight * DEC_DIGITS / 2 + 8;
9025 		sqrt_var(&x, &x, local_rscale);
9026 		mul_var(&fact, &const_two, &fact, 0);
9027 		nsqrt++;
9028 	}
9029 
9030 	/*
9031 	 * We use the Taylor series for 0.5 * ln((1+z)/(1-z)),
9032 	 *
9033 	 * z + z^3/3 + z^5/5 + ...
9034 	 *
9035 	 * where z = (x-1)/(x+1) is in the range (approximately) -0.053 .. 0.048
9036 	 * due to the above range-reduction of x.
9037 	 *
9038 	 * The convergence of this is not as fast as one would like, but is
9039 	 * tolerable given that z is small.
9040 	 *
9041 	 * The Taylor series result will be multiplied by 2^(nsqrt+1), which has a
9042 	 * decimal weight of (nsqrt+1) * log10(2), so work with this many extra
9043 	 * digits of precision (plus a few more for good measure).
9044 	 */
9045 	local_rscale = rscale + (int) ((nsqrt + 1) * 0.301029995663981) + 8;
9046 
9047 	sub_var(&x, &const_one, result);
9048 	add_var(&x, &const_one, &elem);
9049 	div_var_fast(result, &elem, result, local_rscale, true);
9050 	set_var_from_var(result, &xx);
9051 	mul_var(result, result, &x, local_rscale);
9052 
9053 	set_var_from_var(&const_one, &ni);
9054 
9055 	for (;;)
9056 	{
9057 		add_var(&ni, &const_two, &ni);
9058 		mul_var(&xx, &x, &xx, local_rscale);
9059 		div_var_fast(&xx, &ni, &elem, local_rscale, true);
9060 
9061 		if (elem.ndigits == 0)
9062 			break;
9063 
9064 		add_var(result, &elem, result);
9065 
9066 		if (elem.weight < (result->weight - local_rscale * 2 / DEC_DIGITS))
9067 			break;
9068 	}
9069 
9070 	/* Compensate for argument range reduction, round to requested rscale */
9071 	mul_var(result, &fact, result, rscale);
9072 
9073 	free_var(&x);
9074 	free_var(&xx);
9075 	free_var(&ni);
9076 	free_var(&elem);
9077 	free_var(&fact);
9078 }
9079 
9080 
9081 /*
9082  * log_var() -
9083  *
9084  *	Compute the logarithm of num in a given base.
9085  *
9086  *	Note: this routine chooses dscale of the result.
9087  */
9088 static void
log_var(const NumericVar * base,const NumericVar * num,NumericVar * result)9089 log_var(const NumericVar *base, const NumericVar *num, NumericVar *result)
9090 {
9091 	NumericVar	ln_base;
9092 	NumericVar	ln_num;
9093 	int			ln_base_dweight;
9094 	int			ln_num_dweight;
9095 	int			result_dweight;
9096 	int			rscale;
9097 	int			ln_base_rscale;
9098 	int			ln_num_rscale;
9099 
9100 	init_var(&ln_base);
9101 	init_var(&ln_num);
9102 
9103 	/* Estimated dweights of ln(base), ln(num) and the final result */
9104 	ln_base_dweight = estimate_ln_dweight(base);
9105 	ln_num_dweight = estimate_ln_dweight(num);
9106 	result_dweight = ln_num_dweight - ln_base_dweight;
9107 
9108 	/*
9109 	 * Select the scale of the result so that it will have at least
9110 	 * NUMERIC_MIN_SIG_DIGITS significant digits and is not less than either
9111 	 * input's display scale.
9112 	 */
9113 	rscale = NUMERIC_MIN_SIG_DIGITS - result_dweight;
9114 	rscale = Max(rscale, base->dscale);
9115 	rscale = Max(rscale, num->dscale);
9116 	rscale = Max(rscale, NUMERIC_MIN_DISPLAY_SCALE);
9117 	rscale = Min(rscale, NUMERIC_MAX_DISPLAY_SCALE);
9118 
9119 	/*
9120 	 * Set the scales for ln(base) and ln(num) so that they each have more
9121 	 * significant digits than the final result.
9122 	 */
9123 	ln_base_rscale = rscale + result_dweight - ln_base_dweight + 8;
9124 	ln_base_rscale = Max(ln_base_rscale, NUMERIC_MIN_DISPLAY_SCALE);
9125 
9126 	ln_num_rscale = rscale + result_dweight - ln_num_dweight + 8;
9127 	ln_num_rscale = Max(ln_num_rscale, NUMERIC_MIN_DISPLAY_SCALE);
9128 
9129 	/* Form natural logarithms */
9130 	ln_var(base, &ln_base, ln_base_rscale);
9131 	ln_var(num, &ln_num, ln_num_rscale);
9132 
9133 	/* Divide and round to the required scale */
9134 	div_var_fast(&ln_num, &ln_base, result, rscale, true);
9135 
9136 	free_var(&ln_num);
9137 	free_var(&ln_base);
9138 }
9139 
9140 
9141 /*
9142  * power_var() -
9143  *
9144  *	Raise base to the power of exp
9145  *
9146  *	Note: this routine chooses dscale of the result.
9147  */
9148 static void
power_var(const NumericVar * base,const NumericVar * exp,NumericVar * result)9149 power_var(const NumericVar *base, const NumericVar *exp, NumericVar *result)
9150 {
9151 	int			res_sign;
9152 	NumericVar	abs_base;
9153 	NumericVar	ln_base;
9154 	NumericVar	ln_num;
9155 	int			ln_dweight;
9156 	int			rscale;
9157 	int			sig_digits;
9158 	int			local_rscale;
9159 	double		val;
9160 
9161 	/* If exp can be represented as an integer, use power_var_int */
9162 	if (exp->ndigits == 0 || exp->ndigits <= exp->weight + 1)
9163 	{
9164 		/* exact integer, but does it fit in int? */
9165 		int64		expval64;
9166 
9167 		if (numericvar_to_int64(exp, &expval64))
9168 		{
9169 			if (expval64 >= PG_INT32_MIN && expval64 <= PG_INT32_MAX)
9170 			{
9171 				/* Okay, select rscale */
9172 				rscale = NUMERIC_MIN_SIG_DIGITS;
9173 				rscale = Max(rscale, base->dscale);
9174 				rscale = Max(rscale, NUMERIC_MIN_DISPLAY_SCALE);
9175 				rscale = Min(rscale, NUMERIC_MAX_DISPLAY_SCALE);
9176 
9177 				power_var_int(base, (int) expval64, result, rscale);
9178 				return;
9179 			}
9180 		}
9181 	}
9182 
9183 	/*
9184 	 * This avoids log(0) for cases of 0 raised to a non-integer.  0 ^ 0 is
9185 	 * handled by power_var_int().
9186 	 */
9187 	if (cmp_var(base, &const_zero) == 0)
9188 	{
9189 		set_var_from_var(&const_zero, result);
9190 		result->dscale = NUMERIC_MIN_SIG_DIGITS;	/* no need to round */
9191 		return;
9192 	}
9193 
9194 	init_var(&abs_base);
9195 	init_var(&ln_base);
9196 	init_var(&ln_num);
9197 
9198 	/*
9199 	 * If base is negative, insist that exp be an integer.  The result is then
9200 	 * positive if exp is even and negative if exp is odd.
9201 	 */
9202 	if (base->sign == NUMERIC_NEG)
9203 	{
9204 		/*
9205 		 * Check that exp is an integer.  This error code is defined by the
9206 		 * SQL standard, and matches other errors in numeric_power().
9207 		 */
9208 		if (exp->ndigits > 0 && exp->ndigits > exp->weight + 1)
9209 			ereport(ERROR,
9210 					(errcode(ERRCODE_INVALID_ARGUMENT_FOR_POWER_FUNCTION),
9211 					 errmsg("a negative number raised to a non-integer power yields a complex result")));
9212 
9213 		/* Test if exp is odd or even */
9214 		if (exp->ndigits > 0 && exp->ndigits == exp->weight + 1 &&
9215 			(exp->digits[exp->ndigits - 1] & 1))
9216 			res_sign = NUMERIC_NEG;
9217 		else
9218 			res_sign = NUMERIC_POS;
9219 
9220 		/* Then work with abs(base) below */
9221 		set_var_from_var(base, &abs_base);
9222 		abs_base.sign = NUMERIC_POS;
9223 		base = &abs_base;
9224 	}
9225 	else
9226 		res_sign = NUMERIC_POS;
9227 
9228 	/*----------
9229 	 * Decide on the scale for the ln() calculation.  For this we need an
9230 	 * estimate of the weight of the result, which we obtain by doing an
9231 	 * initial low-precision calculation of exp * ln(base).
9232 	 *
9233 	 * We want result = e ^ (exp * ln(base))
9234 	 * so result dweight = log10(result) = exp * ln(base) * log10(e)
9235 	 *
9236 	 * We also perform a crude overflow test here so that we can exit early if
9237 	 * the full-precision result is sure to overflow, and to guard against
9238 	 * integer overflow when determining the scale for the real calculation.
9239 	 * exp_var() supports inputs up to NUMERIC_MAX_RESULT_SCALE * 3, so the
9240 	 * result will overflow if exp * ln(base) >= NUMERIC_MAX_RESULT_SCALE * 3.
9241 	 * Since the values here are only approximations, we apply a small fuzz
9242 	 * factor to this overflow test and let exp_var() determine the exact
9243 	 * overflow threshold so that it is consistent for all inputs.
9244 	 *----------
9245 	 */
9246 	ln_dweight = estimate_ln_dweight(base);
9247 
9248 	/*
9249 	 * Set the scale for the low-precision calculation, computing ln(base) to
9250 	 * around 8 significant digits.  Note that ln_dweight may be as small as
9251 	 * -SHRT_MAX, so the scale may exceed NUMERIC_MAX_DISPLAY_SCALE here.
9252 	 */
9253 	local_rscale = 8 - ln_dweight;
9254 	local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
9255 
9256 	ln_var(base, &ln_base, local_rscale);
9257 
9258 	mul_var(&ln_base, exp, &ln_num, local_rscale);
9259 
9260 	val = numericvar_to_double_no_overflow(&ln_num);
9261 
9262 	/* initial overflow/underflow test with fuzz factor */
9263 	if (Abs(val) > NUMERIC_MAX_RESULT_SCALE * 3.01)
9264 	{
9265 		if (val > 0)
9266 			ereport(ERROR,
9267 					(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
9268 					 errmsg("value overflows numeric format")));
9269 		zero_var(result);
9270 		result->dscale = NUMERIC_MAX_DISPLAY_SCALE;
9271 		return;
9272 	}
9273 
9274 	val *= 0.434294481903252;	/* approximate decimal result weight */
9275 
9276 	/* choose the result scale */
9277 	rscale = NUMERIC_MIN_SIG_DIGITS - (int) val;
9278 	rscale = Max(rscale, base->dscale);
9279 	rscale = Max(rscale, exp->dscale);
9280 	rscale = Max(rscale, NUMERIC_MIN_DISPLAY_SCALE);
9281 	rscale = Min(rscale, NUMERIC_MAX_DISPLAY_SCALE);
9282 
9283 	/* significant digits required in the result */
9284 	sig_digits = rscale + (int) val;
9285 	sig_digits = Max(sig_digits, 0);
9286 
9287 	/* set the scale for the real exp * ln(base) calculation */
9288 	local_rscale = sig_digits - ln_dweight + 8;
9289 	local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
9290 
9291 	/* and do the real calculation */
9292 
9293 	ln_var(base, &ln_base, local_rscale);
9294 
9295 	mul_var(&ln_base, exp, &ln_num, local_rscale);
9296 
9297 	exp_var(&ln_num, result, rscale);
9298 
9299 	if (res_sign == NUMERIC_NEG && result->ndigits > 0)
9300 		result->sign = NUMERIC_NEG;
9301 
9302 	free_var(&ln_num);
9303 	free_var(&ln_base);
9304 	free_var(&abs_base);
9305 }
9306 
9307 /*
9308  * power_var_int() -
9309  *
9310  *	Raise base to the power of exp, where exp is an integer.
9311  */
9312 static void
power_var_int(const NumericVar * base,int exp,NumericVar * result,int rscale)9313 power_var_int(const NumericVar *base, int exp, NumericVar *result, int rscale)
9314 {
9315 	double		f;
9316 	int			p;
9317 	int			i;
9318 	int			sig_digits;
9319 	unsigned int mask;
9320 	bool		neg;
9321 	NumericVar	base_prod;
9322 	int			local_rscale;
9323 
9324 	/* Handle some common special cases, as well as corner cases */
9325 	switch (exp)
9326 	{
9327 		case 0:
9328 
9329 			/*
9330 			 * While 0 ^ 0 can be either 1 or indeterminate (error), we treat
9331 			 * it as 1 because most programming languages do this. SQL:2003
9332 			 * also requires a return value of 1.
9333 			 * https://en.wikipedia.org/wiki/Exponentiation#Zero_to_the_zero_power
9334 			 */
9335 			set_var_from_var(&const_one, result);
9336 			result->dscale = rscale;	/* no need to round */
9337 			return;
9338 		case 1:
9339 			set_var_from_var(base, result);
9340 			round_var(result, rscale);
9341 			return;
9342 		case -1:
9343 			div_var(&const_one, base, result, rscale, true);
9344 			return;
9345 		case 2:
9346 			mul_var(base, base, result, rscale);
9347 			return;
9348 		default:
9349 			break;
9350 	}
9351 
9352 	/* Handle the special case where the base is zero */
9353 	if (base->ndigits == 0)
9354 	{
9355 		if (exp < 0)
9356 			ereport(ERROR,
9357 					(errcode(ERRCODE_DIVISION_BY_ZERO),
9358 					 errmsg("division by zero")));
9359 		zero_var(result);
9360 		result->dscale = rscale;
9361 		return;
9362 	}
9363 
9364 	/*
9365 	 * The general case repeatedly multiplies base according to the bit
9366 	 * pattern of exp.
9367 	 *
9368 	 * First we need to estimate the weight of the result so that we know how
9369 	 * many significant digits are needed.
9370 	 */
9371 	f = base->digits[0];
9372 	p = base->weight * DEC_DIGITS;
9373 
9374 	for (i = 1; i < base->ndigits && i * DEC_DIGITS < 16; i++)
9375 	{
9376 		f = f * NBASE + base->digits[i];
9377 		p -= DEC_DIGITS;
9378 	}
9379 
9380 	/*----------
9381 	 * We have base ~= f * 10^p
9382 	 * so log10(result) = log10(base^exp) ~= exp * (log10(f) + p)
9383 	 *----------
9384 	 */
9385 	f = exp * (log10(f) + p);
9386 
9387 	/*
9388 	 * Apply crude overflow/underflow tests so we can exit early if the result
9389 	 * certainly will overflow/underflow.
9390 	 */
9391 	if (f > 3 * SHRT_MAX * DEC_DIGITS)
9392 		ereport(ERROR,
9393 				(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
9394 				 errmsg("value overflows numeric format")));
9395 	if (f + 1 < -rscale || f + 1 < -NUMERIC_MAX_DISPLAY_SCALE)
9396 	{
9397 		zero_var(result);
9398 		result->dscale = rscale;
9399 		return;
9400 	}
9401 
9402 	/*
9403 	 * Approximate number of significant digits in the result.  Note that the
9404 	 * underflow test above means that this is necessarily >= 0.
9405 	 */
9406 	sig_digits = 1 + rscale + (int) f;
9407 
9408 	/*
9409 	 * The multiplications to produce the result may introduce an error of up
9410 	 * to around log10(abs(exp)) digits, so work with this many extra digits
9411 	 * of precision (plus a few more for good measure).
9412 	 */
9413 	sig_digits += (int) log(fabs((double) exp)) + 8;
9414 
9415 	/*
9416 	 * Now we can proceed with the multiplications.
9417 	 */
9418 	neg = (exp < 0);
9419 	mask = Abs(exp);
9420 
9421 	init_var(&base_prod);
9422 	set_var_from_var(base, &base_prod);
9423 
9424 	if (mask & 1)
9425 		set_var_from_var(base, result);
9426 	else
9427 		set_var_from_var(&const_one, result);
9428 
9429 	while ((mask >>= 1) > 0)
9430 	{
9431 		/*
9432 		 * Do the multiplications using rscales large enough to hold the
9433 		 * results to the required number of significant digits, but don't
9434 		 * waste time by exceeding the scales of the numbers themselves.
9435 		 */
9436 		local_rscale = sig_digits - 2 * base_prod.weight * DEC_DIGITS;
9437 		local_rscale = Min(local_rscale, 2 * base_prod.dscale);
9438 		local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
9439 
9440 		mul_var(&base_prod, &base_prod, &base_prod, local_rscale);
9441 
9442 		if (mask & 1)
9443 		{
9444 			local_rscale = sig_digits -
9445 				(base_prod.weight + result->weight) * DEC_DIGITS;
9446 			local_rscale = Min(local_rscale,
9447 							   base_prod.dscale + result->dscale);
9448 			local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
9449 
9450 			mul_var(&base_prod, result, result, local_rscale);
9451 		}
9452 
9453 		/*
9454 		 * When abs(base) > 1, the number of digits to the left of the decimal
9455 		 * point in base_prod doubles at each iteration, so if exp is large we
9456 		 * could easily spend large amounts of time and memory space doing the
9457 		 * multiplications.  But once the weight exceeds what will fit in
9458 		 * int16, the final result is guaranteed to overflow (or underflow, if
9459 		 * exp < 0), so we can give up before wasting too many cycles.
9460 		 */
9461 		if (base_prod.weight > SHRT_MAX || result->weight > SHRT_MAX)
9462 		{
9463 			/* overflow, unless neg, in which case result should be 0 */
9464 			if (!neg)
9465 				ereport(ERROR,
9466 						(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
9467 						 errmsg("value overflows numeric format")));
9468 			zero_var(result);
9469 			neg = false;
9470 			break;
9471 		}
9472 	}
9473 
9474 	free_var(&base_prod);
9475 
9476 	/* Compensate for input sign, and round to requested rscale */
9477 	if (neg)
9478 		div_var_fast(&const_one, result, result, rscale, true);
9479 	else
9480 		round_var(result, rscale);
9481 }
9482 
9483 /*
9484  * power_ten_int() -
9485  *
9486  *	Raise ten to the power of exp, where exp is an integer.  Note that unlike
9487  *	power_var_int(), this does no overflow/underflow checking or rounding.
9488  */
9489 static void
power_ten_int(int exp,NumericVar * result)9490 power_ten_int(int exp, NumericVar *result)
9491 {
9492 	/* Construct the result directly, starting from 10^0 = 1 */
9493 	set_var_from_var(&const_one, result);
9494 
9495 	/* Scale needed to represent the result exactly */
9496 	result->dscale = exp < 0 ? -exp : 0;
9497 
9498 	/* Base-NBASE weight of result and remaining exponent */
9499 	if (exp >= 0)
9500 		result->weight = exp / DEC_DIGITS;
9501 	else
9502 		result->weight = (exp + 1) / DEC_DIGITS - 1;
9503 
9504 	exp -= result->weight * DEC_DIGITS;
9505 
9506 	/* Final adjustment of the result's single NBASE digit */
9507 	while (exp-- > 0)
9508 		result->digits[0] *= 10;
9509 }
9510 
9511 
9512 /* ----------------------------------------------------------------------
9513  *
 * The following are the lowest-level functions, which operate on the
 * unsigned (absolute) values of variables
9516  *
9517  * ----------------------------------------------------------------------
9518  */
9519 
9520 
9521 /* ----------
9522  * cmp_abs() -
9523  *
9524  *	Compare the absolute values of var1 and var2
9525  *	Returns:	-1 for ABS(var1) < ABS(var2)
9526  *				0  for ABS(var1) == ABS(var2)
9527  *				1  for ABS(var1) > ABS(var2)
9528  * ----------
9529  */
9530 static int
cmp_abs(const NumericVar * var1,const NumericVar * var2)9531 cmp_abs(const NumericVar *var1, const NumericVar *var2)
9532 {
9533 	return cmp_abs_common(var1->digits, var1->ndigits, var1->weight,
9534 						  var2->digits, var2->ndigits, var2->weight);
9535 }
9536 
9537 /* ----------
9538  * cmp_abs_common() -
9539  *
9540  *	Main routine of cmp_abs(). This function can be used by both
9541  *	NumericVar and Numeric.
9542  * ----------
9543  */
9544 static int
cmp_abs_common(const NumericDigit * var1digits,int var1ndigits,int var1weight,const NumericDigit * var2digits,int var2ndigits,int var2weight)9545 cmp_abs_common(const NumericDigit *var1digits, int var1ndigits, int var1weight,
9546 			   const NumericDigit *var2digits, int var2ndigits, int var2weight)
9547 {
9548 	int			i1 = 0;
9549 	int			i2 = 0;
9550 
9551 	/* Check any digits before the first common digit */
9552 
9553 	while (var1weight > var2weight && i1 < var1ndigits)
9554 	{
9555 		if (var1digits[i1++] != 0)
9556 			return 1;
9557 		var1weight--;
9558 	}
9559 	while (var2weight > var1weight && i2 < var2ndigits)
9560 	{
9561 		if (var2digits[i2++] != 0)
9562 			return -1;
9563 		var2weight--;
9564 	}
9565 
9566 	/* At this point, either w1 == w2 or we've run out of digits */
9567 
9568 	if (var1weight == var2weight)
9569 	{
9570 		while (i1 < var1ndigits && i2 < var2ndigits)
9571 		{
9572 			int			stat = var1digits[i1++] - var2digits[i2++];
9573 
9574 			if (stat)
9575 			{
9576 				if (stat > 0)
9577 					return 1;
9578 				return -1;
9579 			}
9580 		}
9581 	}
9582 
9583 	/*
9584 	 * At this point, we've run out of digits on one side or the other; so any
9585 	 * remaining nonzero digits imply that side is larger
9586 	 */
9587 	while (i1 < var1ndigits)
9588 	{
9589 		if (var1digits[i1++] != 0)
9590 			return 1;
9591 	}
9592 	while (i2 < var2ndigits)
9593 	{
9594 		if (var2digits[i2++] != 0)
9595 			return -1;
9596 	}
9597 
9598 	return 0;
9599 }
9600 
9601 
9602 /*
9603  * add_abs() -
9604  *
9605  *	Add the absolute values of two variables into result.
9606  *	result might point to one of the operands without danger.
9607  */
9608 static void
add_abs(const NumericVar * var1,const NumericVar * var2,NumericVar * result)9609 add_abs(const NumericVar *var1, const NumericVar *var2, NumericVar *result)
9610 {
9611 	NumericDigit *res_buf;
9612 	NumericDigit *res_digits;
9613 	int			res_ndigits;
9614 	int			res_weight;
9615 	int			res_rscale,
9616 				rscale1,
9617 				rscale2;
9618 	int			res_dscale;
9619 	int			i,
9620 				i1,
9621 				i2;
9622 	int			carry = 0;
9623 
9624 	/* copy these values into local vars for speed in inner loop */
9625 	int			var1ndigits = var1->ndigits;
9626 	int			var2ndigits = var2->ndigits;
9627 	NumericDigit *var1digits = var1->digits;
9628 	NumericDigit *var2digits = var2->digits;
9629 
9630 	res_weight = Max(var1->weight, var2->weight) + 1;
9631 
9632 	res_dscale = Max(var1->dscale, var2->dscale);
9633 
9634 	/* Note: here we are figuring rscale in base-NBASE digits */
9635 	rscale1 = var1->ndigits - var1->weight - 1;
9636 	rscale2 = var2->ndigits - var2->weight - 1;
9637 	res_rscale = Max(rscale1, rscale2);
9638 
9639 	res_ndigits = res_rscale + res_weight + 1;
9640 	if (res_ndigits <= 0)
9641 		res_ndigits = 1;
9642 
9643 	res_buf = digitbuf_alloc(res_ndigits + 1);
9644 	res_buf[0] = 0;				/* spare digit for later rounding */
9645 	res_digits = res_buf + 1;
9646 
9647 	i1 = res_rscale + var1->weight + 1;
9648 	i2 = res_rscale + var2->weight + 1;
9649 	for (i = res_ndigits - 1; i >= 0; i--)
9650 	{
9651 		i1--;
9652 		i2--;
9653 		if (i1 >= 0 && i1 < var1ndigits)
9654 			carry += var1digits[i1];
9655 		if (i2 >= 0 && i2 < var2ndigits)
9656 			carry += var2digits[i2];
9657 
9658 		if (carry >= NBASE)
9659 		{
9660 			res_digits[i] = carry - NBASE;
9661 			carry = 1;
9662 		}
9663 		else
9664 		{
9665 			res_digits[i] = carry;
9666 			carry = 0;
9667 		}
9668 	}
9669 
9670 	Assert(carry == 0);			/* else we failed to allow for carry out */
9671 
9672 	digitbuf_free(result->buf);
9673 	result->ndigits = res_ndigits;
9674 	result->buf = res_buf;
9675 	result->digits = res_digits;
9676 	result->weight = res_weight;
9677 	result->dscale = res_dscale;
9678 
9679 	/* Remove leading/trailing zeroes */
9680 	strip_var(result);
9681 }
9682 
9683 
9684 /*
9685  * sub_abs()
9686  *
9687  *	Subtract the absolute value of var2 from the absolute value of var1
9688  *	and store in result. result might point to one of the operands
9689  *	without danger.
9690  *
9691  *	ABS(var1) MUST BE GREATER OR EQUAL ABS(var2) !!!
9692  */
9693 static void
sub_abs(const NumericVar * var1,const NumericVar * var2,NumericVar * result)9694 sub_abs(const NumericVar *var1, const NumericVar *var2, NumericVar *result)
9695 {
9696 	NumericDigit *res_buf;
9697 	NumericDigit *res_digits;
9698 	int			res_ndigits;
9699 	int			res_weight;
9700 	int			res_rscale,
9701 				rscale1,
9702 				rscale2;
9703 	int			res_dscale;
9704 	int			i,
9705 				i1,
9706 				i2;
9707 	int			borrow = 0;
9708 
9709 	/* copy these values into local vars for speed in inner loop */
9710 	int			var1ndigits = var1->ndigits;
9711 	int			var2ndigits = var2->ndigits;
9712 	NumericDigit *var1digits = var1->digits;
9713 	NumericDigit *var2digits = var2->digits;
9714 
9715 	res_weight = var1->weight;
9716 
9717 	res_dscale = Max(var1->dscale, var2->dscale);
9718 
9719 	/* Note: here we are figuring rscale in base-NBASE digits */
9720 	rscale1 = var1->ndigits - var1->weight - 1;
9721 	rscale2 = var2->ndigits - var2->weight - 1;
9722 	res_rscale = Max(rscale1, rscale2);
9723 
9724 	res_ndigits = res_rscale + res_weight + 1;
9725 	if (res_ndigits <= 0)
9726 		res_ndigits = 1;
9727 
9728 	res_buf = digitbuf_alloc(res_ndigits + 1);
9729 	res_buf[0] = 0;				/* spare digit for later rounding */
9730 	res_digits = res_buf + 1;
9731 
9732 	i1 = res_rscale + var1->weight + 1;
9733 	i2 = res_rscale + var2->weight + 1;
9734 	for (i = res_ndigits - 1; i >= 0; i--)
9735 	{
9736 		i1--;
9737 		i2--;
9738 		if (i1 >= 0 && i1 < var1ndigits)
9739 			borrow += var1digits[i1];
9740 		if (i2 >= 0 && i2 < var2ndigits)
9741 			borrow -= var2digits[i2];
9742 
9743 		if (borrow < 0)
9744 		{
9745 			res_digits[i] = borrow + NBASE;
9746 			borrow = -1;
9747 		}
9748 		else
9749 		{
9750 			res_digits[i] = borrow;
9751 			borrow = 0;
9752 		}
9753 	}
9754 
9755 	Assert(borrow == 0);		/* else caller gave us var1 < var2 */
9756 
9757 	digitbuf_free(result->buf);
9758 	result->ndigits = res_ndigits;
9759 	result->buf = res_buf;
9760 	result->digits = res_digits;
9761 	result->weight = res_weight;
9762 	result->dscale = res_dscale;
9763 
9764 	/* Remove leading/trailing zeroes */
9765 	strip_var(result);
9766 }
9767 
9768 /*
9769  * round_var
9770  *
9771  * Round the value of a variable to no more than rscale decimal digits
9772  * after the decimal point.  NOTE: we allow rscale < 0 here, implying
9773  * rounding before the decimal point.
9774  */
9775 static void
round_var(NumericVar * var,int rscale)9776 round_var(NumericVar *var, int rscale)
9777 {
9778 	NumericDigit *digits = var->digits;
9779 	int			di;
9780 	int			ndigits;
9781 	int			carry;
9782 
9783 	var->dscale = rscale;
9784 
9785 	/* decimal digits wanted */
9786 	di = (var->weight + 1) * DEC_DIGITS + rscale;
9787 
9788 	/*
9789 	 * If di = 0, the value loses all digits, but could round up to 1 if its
9790 	 * first extra digit is >= 5.  If di < 0 the result must be 0.
9791 	 */
9792 	if (di < 0)
9793 	{
9794 		var->ndigits = 0;
9795 		var->weight = 0;
9796 		var->sign = NUMERIC_POS;
9797 	}
9798 	else
9799 	{
9800 		/* NBASE digits wanted */
9801 		ndigits = (di + DEC_DIGITS - 1) / DEC_DIGITS;
9802 
9803 		/* 0, or number of decimal digits to keep in last NBASE digit */
9804 		di %= DEC_DIGITS;
9805 
9806 		if (ndigits < var->ndigits ||
9807 			(ndigits == var->ndigits && di > 0))
9808 		{
9809 			var->ndigits = ndigits;
9810 
9811 #if DEC_DIGITS == 1
9812 			/* di must be zero */
9813 			carry = (digits[ndigits] >= HALF_NBASE) ? 1 : 0;
9814 #else
9815 			if (di == 0)
9816 				carry = (digits[ndigits] >= HALF_NBASE) ? 1 : 0;
9817 			else
9818 			{
9819 				/* Must round within last NBASE digit */
9820 				int			extra,
9821 							pow10;
9822 
9823 #if DEC_DIGITS == 4
9824 				pow10 = round_powers[di];
9825 #elif DEC_DIGITS == 2
9826 				pow10 = 10;
9827 #else
9828 #error unsupported NBASE
9829 #endif
9830 				extra = digits[--ndigits] % pow10;
9831 				digits[ndigits] -= extra;
9832 				carry = 0;
9833 				if (extra >= pow10 / 2)
9834 				{
9835 					pow10 += digits[ndigits];
9836 					if (pow10 >= NBASE)
9837 					{
9838 						pow10 -= NBASE;
9839 						carry = 1;
9840 					}
9841 					digits[ndigits] = pow10;
9842 				}
9843 			}
9844 #endif
9845 
9846 			/* Propagate carry if needed */
9847 			while (carry)
9848 			{
9849 				carry += digits[--ndigits];
9850 				if (carry >= NBASE)
9851 				{
9852 					digits[ndigits] = carry - NBASE;
9853 					carry = 1;
9854 				}
9855 				else
9856 				{
9857 					digits[ndigits] = carry;
9858 					carry = 0;
9859 				}
9860 			}
9861 
9862 			if (ndigits < 0)
9863 			{
9864 				Assert(ndigits == -1);	/* better not have added > 1 digit */
9865 				Assert(var->digits > var->buf);
9866 				var->digits--;
9867 				var->ndigits++;
9868 				var->weight++;
9869 			}
9870 		}
9871 	}
9872 }
9873 
9874 /*
9875  * trunc_var
9876  *
9877  * Truncate (towards zero) the value of a variable at rscale decimal digits
9878  * after the decimal point.  NOTE: we allow rscale < 0 here, implying
9879  * truncation before the decimal point.
9880  */
9881 static void
trunc_var(NumericVar * var,int rscale)9882 trunc_var(NumericVar *var, int rscale)
9883 {
9884 	int			di;
9885 	int			ndigits;
9886 
9887 	var->dscale = rscale;
9888 
9889 	/* decimal digits wanted */
9890 	di = (var->weight + 1) * DEC_DIGITS + rscale;
9891 
9892 	/*
9893 	 * If di <= 0, the value loses all digits.
9894 	 */
9895 	if (di <= 0)
9896 	{
9897 		var->ndigits = 0;
9898 		var->weight = 0;
9899 		var->sign = NUMERIC_POS;
9900 	}
9901 	else
9902 	{
9903 		/* NBASE digits wanted */
9904 		ndigits = (di + DEC_DIGITS - 1) / DEC_DIGITS;
9905 
9906 		if (ndigits <= var->ndigits)
9907 		{
9908 			var->ndigits = ndigits;
9909 
9910 #if DEC_DIGITS == 1
9911 			/* no within-digit stuff to worry about */
9912 #else
9913 			/* 0, or number of decimal digits to keep in last NBASE digit */
9914 			di %= DEC_DIGITS;
9915 
9916 			if (di > 0)
9917 			{
9918 				/* Must truncate within last NBASE digit */
9919 				NumericDigit *digits = var->digits;
9920 				int			extra,
9921 							pow10;
9922 
9923 #if DEC_DIGITS == 4
9924 				pow10 = round_powers[di];
9925 #elif DEC_DIGITS == 2
9926 				pow10 = 10;
9927 #else
9928 #error unsupported NBASE
9929 #endif
9930 				extra = digits[--ndigits] % pow10;
9931 				digits[ndigits] -= extra;
9932 			}
9933 #endif
9934 		}
9935 	}
9936 }
9937 
9938 /*
9939  * strip_var
9940  *
9941  * Strip any leading and trailing zeroes from a numeric variable
9942  */
9943 static void
strip_var(NumericVar * var)9944 strip_var(NumericVar *var)
9945 {
9946 	NumericDigit *digits = var->digits;
9947 	int			ndigits = var->ndigits;
9948 
9949 	/* Strip leading zeroes */
9950 	while (ndigits > 0 && *digits == 0)
9951 	{
9952 		digits++;
9953 		var->weight--;
9954 		ndigits--;
9955 	}
9956 
9957 	/* Strip trailing zeroes */
9958 	while (ndigits > 0 && digits[ndigits - 1] == 0)
9959 		ndigits--;
9960 
9961 	/* If it's zero, normalize the sign and weight */
9962 	if (ndigits == 0)
9963 	{
9964 		var->sign = NUMERIC_POS;
9965 		var->weight = 0;
9966 	}
9967 
9968 	var->digits = digits;
9969 	var->ndigits = ndigits;
9970 }
9971 
9972 
9973 /* ----------------------------------------------------------------------
9974  *
9975  * Fast sum accumulator functions
9976  *
9977  * ----------------------------------------------------------------------
9978  */
9979 
9980 /*
9981  * Reset the accumulator's value to zero.  The buffers to hold the digits
9982  * are not free'd.
9983  */
9984 static void
accum_sum_reset(NumericSumAccum * accum)9985 accum_sum_reset(NumericSumAccum *accum)
9986 {
9987 	int			i;
9988 
9989 	accum->dscale = 0;
9990 	for (i = 0; i < accum->ndigits; i++)
9991 	{
9992 		accum->pos_digits[i] = 0;
9993 		accum->neg_digits[i] = 0;
9994 	}
9995 }
9996 
9997 /*
9998  * Accumulate a new value.
9999  */
10000 static void
accum_sum_add(NumericSumAccum * accum,const NumericVar * val)10001 accum_sum_add(NumericSumAccum *accum, const NumericVar *val)
10002 {
10003 	int32	   *accum_digits;
10004 	int			i,
10005 				val_i;
10006 	int			val_ndigits;
10007 	NumericDigit *val_digits;
10008 
10009 	/*
10010 	 * If we have accumulated too many values since the last carry
10011 	 * propagation, do it now, to avoid overflowing.  (We could allow more
10012 	 * than NBASE - 1, if we reserved two extra digits, rather than one, for
10013 	 * carry propagation.  But even with NBASE - 1, this needs to be done so
10014 	 * seldom, that the performance difference is negligible.)
10015 	 */
10016 	if (accum->num_uncarried == NBASE - 1)
10017 		accum_sum_carry(accum);
10018 
10019 	/*
10020 	 * Adjust the weight or scale of the old value, so that it can accommodate
10021 	 * the new value.
10022 	 */
10023 	accum_sum_rescale(accum, val);
10024 
10025 	/* */
10026 	if (val->sign == NUMERIC_POS)
10027 		accum_digits = accum->pos_digits;
10028 	else
10029 		accum_digits = accum->neg_digits;
10030 
10031 	/* copy these values into local vars for speed in loop */
10032 	val_ndigits = val->ndigits;
10033 	val_digits = val->digits;
10034 
10035 	i = accum->weight - val->weight;
10036 	for (val_i = 0; val_i < val_ndigits; val_i++)
10037 	{
10038 		accum_digits[i] += (int32) val_digits[val_i];
10039 		i++;
10040 	}
10041 
10042 	accum->num_uncarried++;
10043 }
10044 
10045 /*
10046  * Propagate carries.
10047  */
10048 static void
accum_sum_carry(NumericSumAccum * accum)10049 accum_sum_carry(NumericSumAccum *accum)
10050 {
10051 	int			i;
10052 	int			ndigits;
10053 	int32	   *dig;
10054 	int32		carry;
10055 	int32		newdig = 0;
10056 
10057 	/*
10058 	 * If no new values have been added since last carry propagation, nothing
10059 	 * to do.
10060 	 */
10061 	if (accum->num_uncarried == 0)
10062 		return;
10063 
10064 	/*
10065 	 * We maintain that the weight of the accumulator is always one larger
10066 	 * than needed to hold the current value, before carrying, to make sure
10067 	 * there is enough space for the possible extra digit when carry is
10068 	 * propagated.  We cannot expand the buffer here, unless we require
10069 	 * callers of accum_sum_final() to switch to the right memory context.
10070 	 */
10071 	Assert(accum->pos_digits[0] == 0 && accum->neg_digits[0] == 0);
10072 
10073 	ndigits = accum->ndigits;
10074 
10075 	/* Propagate carry in the positive sum */
10076 	dig = accum->pos_digits;
10077 	carry = 0;
10078 	for (i = ndigits - 1; i >= 0; i--)
10079 	{
10080 		newdig = dig[i] + carry;
10081 		if (newdig >= NBASE)
10082 		{
10083 			carry = newdig / NBASE;
10084 			newdig -= carry * NBASE;
10085 		}
10086 		else
10087 			carry = 0;
10088 		dig[i] = newdig;
10089 	}
10090 	/* Did we use up the digit reserved for carry propagation? */
10091 	if (newdig > 0)
10092 		accum->have_carry_space = false;
10093 
10094 	/* And the same for the negative sum */
10095 	dig = accum->neg_digits;
10096 	carry = 0;
10097 	for (i = ndigits - 1; i >= 0; i--)
10098 	{
10099 		newdig = dig[i] + carry;
10100 		if (newdig >= NBASE)
10101 		{
10102 			carry = newdig / NBASE;
10103 			newdig -= carry * NBASE;
10104 		}
10105 		else
10106 			carry = 0;
10107 		dig[i] = newdig;
10108 	}
10109 	if (newdig > 0)
10110 		accum->have_carry_space = false;
10111 
10112 	accum->num_uncarried = 0;
10113 }
10114 
10115 /*
10116  * Re-scale accumulator to accommodate new value.
10117  *
10118  * If the new value has more digits than the current digit buffers in the
10119  * accumulator, enlarge the buffers.
10120  */
10121 static void
accum_sum_rescale(NumericSumAccum * accum,const NumericVar * val)10122 accum_sum_rescale(NumericSumAccum *accum, const NumericVar *val)
10123 {
10124 	int			old_weight = accum->weight;
10125 	int			old_ndigits = accum->ndigits;
10126 	int			accum_ndigits;
10127 	int			accum_weight;
10128 	int			accum_rscale;
10129 	int			val_rscale;
10130 
10131 	accum_weight = old_weight;
10132 	accum_ndigits = old_ndigits;
10133 
10134 	/*
10135 	 * Does the new value have a larger weight? If so, enlarge the buffers,
10136 	 * and shift the existing value to the new weight, by adding leading
10137 	 * zeros.
10138 	 *
10139 	 * We enforce that the accumulator always has a weight one larger than
10140 	 * needed for the inputs, so that we have space for an extra digit at the
10141 	 * final carry-propagation phase, if necessary.
10142 	 */
10143 	if (val->weight >= accum_weight)
10144 	{
10145 		accum_weight = val->weight + 1;
10146 		accum_ndigits = accum_ndigits + (accum_weight - old_weight);
10147 	}
10148 
10149 	/*
10150 	 * Even though the new value is small, we might've used up the space
10151 	 * reserved for the carry digit in the last call to accum_sum_carry().  If
10152 	 * so, enlarge to make room for another one.
10153 	 */
10154 	else if (!accum->have_carry_space)
10155 	{
10156 		accum_weight++;
10157 		accum_ndigits++;
10158 	}
10159 
10160 	/* Is the new value wider on the right side? */
10161 	accum_rscale = accum_ndigits - accum_weight - 1;
10162 	val_rscale = val->ndigits - val->weight - 1;
10163 	if (val_rscale > accum_rscale)
10164 		accum_ndigits = accum_ndigits + (val_rscale - accum_rscale);
10165 
10166 	if (accum_ndigits != old_ndigits ||
10167 		accum_weight != old_weight)
10168 	{
10169 		int32	   *new_pos_digits;
10170 		int32	   *new_neg_digits;
10171 		int			weightdiff;
10172 
10173 		weightdiff = accum_weight - old_weight;
10174 
10175 		new_pos_digits = palloc0(accum_ndigits * sizeof(int32));
10176 		new_neg_digits = palloc0(accum_ndigits * sizeof(int32));
10177 
10178 		if (accum->pos_digits)
10179 		{
10180 			memcpy(&new_pos_digits[weightdiff], accum->pos_digits,
10181 				   old_ndigits * sizeof(int32));
10182 			pfree(accum->pos_digits);
10183 
10184 			memcpy(&new_neg_digits[weightdiff], accum->neg_digits,
10185 				   old_ndigits * sizeof(int32));
10186 			pfree(accum->neg_digits);
10187 		}
10188 
10189 		accum->pos_digits = new_pos_digits;
10190 		accum->neg_digits = new_neg_digits;
10191 
10192 		accum->weight = accum_weight;
10193 		accum->ndigits = accum_ndigits;
10194 
10195 		Assert(accum->pos_digits[0] == 0 && accum->neg_digits[0] == 0);
10196 		accum->have_carry_space = true;
10197 	}
10198 
10199 	if (val->dscale > accum->dscale)
10200 		accum->dscale = val->dscale;
10201 }
10202 
10203 /*
10204  * Return the current value of the accumulator.  This perform final carry
10205  * propagation, and adds together the positive and negative sums.
10206  *
10207  * Unlike all the other routines, the caller is not required to switch to
10208  * the memory context that holds the accumulator.
10209  */
10210 static void
accum_sum_final(NumericSumAccum * accum,NumericVar * result)10211 accum_sum_final(NumericSumAccum *accum, NumericVar *result)
10212 {
10213 	int			i;
10214 	NumericVar	pos_var;
10215 	NumericVar	neg_var;
10216 
10217 	if (accum->ndigits == 0)
10218 	{
10219 		set_var_from_var(&const_zero, result);
10220 		return;
10221 	}
10222 
10223 	/* Perform final carry */
10224 	accum_sum_carry(accum);
10225 
10226 	/* Create NumericVars representing the positive and negative sums */
10227 	init_var(&pos_var);
10228 	init_var(&neg_var);
10229 
10230 	pos_var.ndigits = neg_var.ndigits = accum->ndigits;
10231 	pos_var.weight = neg_var.weight = accum->weight;
10232 	pos_var.dscale = neg_var.dscale = accum->dscale;
10233 	pos_var.sign = NUMERIC_POS;
10234 	neg_var.sign = NUMERIC_NEG;
10235 
10236 	pos_var.buf = pos_var.digits = digitbuf_alloc(accum->ndigits);
10237 	neg_var.buf = neg_var.digits = digitbuf_alloc(accum->ndigits);
10238 
10239 	for (i = 0; i < accum->ndigits; i++)
10240 	{
10241 		Assert(accum->pos_digits[i] < NBASE);
10242 		pos_var.digits[i] = (int16) accum->pos_digits[i];
10243 
10244 		Assert(accum->neg_digits[i] < NBASE);
10245 		neg_var.digits[i] = (int16) accum->neg_digits[i];
10246 	}
10247 
10248 	/* And add them together */
10249 	add_var(&pos_var, &neg_var, result);
10250 
10251 	/* Remove leading/trailing zeroes */
10252 	strip_var(result);
10253 }
10254 
10255 /*
10256  * Copy an accumulator's state.
10257  *
10258  * 'dst' is assumed to be uninitialized beforehand.  No attempt is made at
10259  * freeing old values.
10260  */
10261 static void
accum_sum_copy(NumericSumAccum * dst,NumericSumAccum * src)10262 accum_sum_copy(NumericSumAccum *dst, NumericSumAccum *src)
10263 {
10264 	dst->pos_digits = palloc(src->ndigits * sizeof(int32));
10265 	dst->neg_digits = palloc(src->ndigits * sizeof(int32));
10266 
10267 	memcpy(dst->pos_digits, src->pos_digits, src->ndigits * sizeof(int32));
10268 	memcpy(dst->neg_digits, src->neg_digits, src->ndigits * sizeof(int32));
10269 	dst->num_uncarried = src->num_uncarried;
10270 	dst->ndigits = src->ndigits;
10271 	dst->weight = src->weight;
10272 	dst->dscale = src->dscale;
10273 }
10274 
10275 /*
10276  * Add the current value of 'accum2' into 'accum'.
10277  */
10278 static void
accum_sum_combine(NumericSumAccum * accum,NumericSumAccum * accum2)10279 accum_sum_combine(NumericSumAccum *accum, NumericSumAccum *accum2)
10280 {
10281 	NumericVar	tmp_var;
10282 
10283 	init_var(&tmp_var);
10284 
10285 	accum_sum_final(accum2, &tmp_var);
10286 	accum_sum_add(accum, &tmp_var);
10287 
10288 	free_var(&tmp_var);
10289 }
10290