1 /*-------------------------------------------------------------------------
2  *
3  * numeric.c
4  *	  An exact numeric data type for the Postgres database system
5  *
6  * Original coding 1998, Jan Wieck.  Heavily revised 2003, Tom Lane.
7  *
8  * Many of the algorithmic ideas are borrowed from David M. Smith's "FM"
9  * multiple-precision math library, most recently published as Algorithm
10  * 786: Multiple-Precision Complex Arithmetic and Functions, ACM
11  * Transactions on Mathematical Software, Vol. 24, No. 4, December 1998,
12  * pages 359-367.
13  *
14  * Copyright (c) 1998-2018, PostgreSQL Global Development Group
15  *
16  * IDENTIFICATION
17  *	  src/backend/utils/adt/numeric.c
18  *
19  *-------------------------------------------------------------------------
20  */
21 
22 #include "postgres.h"
23 
24 #include <ctype.h>
25 #include <float.h>
26 #include <limits.h>
27 #include <math.h>
28 
29 #include "access/hash.h"
30 #include "catalog/pg_type.h"
31 #include "common/int.h"
32 #include "funcapi.h"
33 #include "lib/hyperloglog.h"
34 #include "libpq/pqformat.h"
35 #include "miscadmin.h"
36 #include "nodes/nodeFuncs.h"
37 #include "utils/array.h"
38 #include "utils/builtins.h"
39 #include "utils/guc.h"
40 #include "utils/int8.h"
41 #include "utils/numeric.h"
42 #include "utils/sortsupport.h"
43 
44 /* ----------
45  * Uncomment the following to enable compilation of dump_numeric()
46  * and dump_var() and to get a dump of any result produced by make_result().
47  * ----------
48 #define NUMERIC_DEBUG
49  */
50 
51 
52 /* ----------
53  * Local data types
54  *
55  * Numeric values are represented in a base-NBASE floating point format.
56  * Each "digit" ranges from 0 to NBASE-1.  The type NumericDigit is signed
57  * and wide enough to store a digit.  We assume that NBASE*NBASE can fit in
58  * an int.  Although the purely calculational routines could handle any even
59  * NBASE that's less than sqrt(INT_MAX), in practice we are only interested
60  * in NBASE a power of ten, so that I/O conversions and decimal rounding
61  * are easy.  Also, it's actually more efficient if NBASE is rather less than
62  * sqrt(INT_MAX), so that there is "headroom" for mul_var and div_var_fast to
63  * postpone processing carries.
64  *
65  * Values of NBASE other than 10000 are considered of historical interest only
66  * and are no longer supported in any sense; no mechanism exists for the client
67  * to discover the base, so every client supporting binary mode expects the
68  * base-10000 format.  If you plan to change this, also note the numeric
69  * abbreviation code, which assumes NBASE=10000.
70  * ----------
71  */
72 
/*
 * Historical NBASE = 10 configuration.  Disabled; only NBASE = 10000 is
 * supported (see the comment above).
 */
#if 0
#define NBASE		10
#define HALF_NBASE	5
#define DEC_DIGITS	1			/* decimal digits per NBASE digit */
#define MUL_GUARD_DIGITS	4	/* these are measured in NBASE digits */
#define DIV_GUARD_DIGITS	8

typedef signed char NumericDigit;
#endif

/* Historical NBASE = 100 configuration.  Disabled; no longer supported. */
#if 0
#define NBASE		100
#define HALF_NBASE	50
#define DEC_DIGITS	2			/* decimal digits per NBASE digit */
#define MUL_GUARD_DIGITS	3	/* these are measured in NBASE digits */
#define DIV_GUARD_DIGITS	6

typedef signed char NumericDigit;
#endif

/*
 * The active configuration: NBASE = 10000.  Each NumericDigit holds four
 * decimal digits (0..9999), which requires a 16-bit signed type.
 */
#if 1
#define NBASE		10000
#define HALF_NBASE	5000
#define DEC_DIGITS	4			/* decimal digits per NBASE digit */
#define MUL_GUARD_DIGITS	2	/* these are measured in NBASE digits */
#define DIV_GUARD_DIGITS	4

typedef int16 NumericDigit;
#endif
102 
103 /*
104  * The Numeric type as stored on disk.
105  *
106  * If the high bits of the first word of a NumericChoice (n_header, or
107  * n_short.n_header, or n_long.n_sign_dscale) are NUMERIC_SHORT, then the
108  * numeric follows the NumericShort format; if they are NUMERIC_POS or
109  * NUMERIC_NEG, it follows the NumericLong format.  If they are NUMERIC_NAN,
110  * it is a NaN.  We currently always store a NaN using just two bytes (i.e.
111  * only n_header), but previous releases used only the NumericLong format,
112  * so we might find 4-byte NaNs on disk if a database has been migrated using
113  * pg_upgrade.  In either case, when the high bits indicate a NaN, the
114  * remaining bits are never examined.  Currently, we always initialize these
115  * to zero, but it might be possible to use them for some other purpose in
116  * the future.
117  *
118  * In the NumericShort format, the remaining 14 bits of the header word
119  * (n_short.n_header) are allocated as follows: 1 for sign (positive or
120  * negative), 6 for dynamic scale, and 7 for weight.  In practice, most
121  * commonly-encountered values can be represented this way.
122  *
123  * In the NumericLong format, the remaining 14 bits of the header word
124  * (n_long.n_sign_dscale) represent the display scale; and the weight is
125  * stored separately in n_weight.
126  *
127  * NOTE: by convention, values in the packed form have been stripped of
128  * all leading and trailing zero digits (where a "digit" is of base NBASE).
129  * In particular, if the value is zero, there will be no digits at all!
130  * The weight is arbitrary in that case, but we normally set it to zero.
131  */
132 
/* Short format: a single 2-byte header word packs sign, dscale and weight. */
struct NumericShort
{
	uint16		n_header;		/* Sign + display scale + weight */
	NumericDigit n_data[FLEXIBLE_ARRAY_MEMBER]; /* Digits */
};

/*
 * Long format: a sign/dscale word followed by a separate int16 weight
 * (4-byte header).
 */
struct NumericLong
{
	uint16		n_sign_dscale;	/* Sign + display scale */
	int16		n_weight;		/* Weight of 1st digit	*/
	NumericDigit n_data[FLEXIBLE_ARRAY_MEMBER]; /* Digits */
};

/*
 * Every format starts with a uint16, so n_header can always be read first.
 * Its high bits (see NUMERIC_FLAGBITS below) then decide which of the other
 * members describes the rest of the datum.
 */
union NumericChoice
{
	uint16		n_header;		/* Header word */
	struct NumericLong n_long;	/* Long form (4-byte header) */
	struct NumericShort n_short;	/* Short form (2-byte header) */
};

/* A complete numeric datum: varlena length word followed by the payload. */
struct NumericData
{
	int32		vl_len_;		/* varlena header (do not touch directly!) */
	union NumericChoice choice; /* choice of format */
};
158 
159 
160 /*
161  * Interpretation of high bits.
162  */
163 
#define NUMERIC_SIGN_MASK	0xC000
#define NUMERIC_POS			0x0000
#define NUMERIC_NEG			0x4000
#define NUMERIC_SHORT		0x8000
#define NUMERIC_NAN			0xC000

/* Classify a datum by the top two bits of its first header word. */
#define NUMERIC_FLAGBITS(n) ((n)->choice.n_header & NUMERIC_SIGN_MASK)
#define NUMERIC_IS_NAN(n)		(NUMERIC_FLAGBITS(n) == NUMERIC_NAN)
#define NUMERIC_IS_SHORT(n)		(NUMERIC_FLAGBITS(n) == NUMERIC_SHORT)

/*
 * Header sizes including the varlena length word.  The long format adds a
 * uint16 sign/dscale word and an int16 weight; the short format has only one
 * uint16 header word.
 */
#define NUMERIC_HDRSZ	(VARHDRSZ + sizeof(uint16) + sizeof(int16))
#define NUMERIC_HDRSZ_SHORT (VARHDRSZ + sizeof(uint16))

/*
 * If the flag bits are NUMERIC_SHORT or NUMERIC_NAN, we want the short header;
 * otherwise, we want the long one.  Instead of testing against each value, we
 * can just look at the high bit, for a slight efficiency gain.
 */
#define NUMERIC_HEADER_IS_SHORT(n)	(((n)->choice.n_header & 0x8000) != 0)
#define NUMERIC_HEADER_SIZE(n) \
	(VARHDRSZ + sizeof(uint16) + \
	 (NUMERIC_HEADER_IS_SHORT(n) ? 0 : sizeof(int16)))

/*
 * Short format definitions.
 *
 * Within the 14 low bits: bit 0x2000 is the sign, bits 0x1F80 the dscale
 * (max 63), bit 0x0040 the weight's sign and bits 0x003F the weight's
 * magnitude, giving a weight range of -64 .. 63.
 */

#define NUMERIC_SHORT_SIGN_MASK			0x2000
#define NUMERIC_SHORT_DSCALE_MASK		0x1F80
#define NUMERIC_SHORT_DSCALE_SHIFT		7
#define NUMERIC_SHORT_DSCALE_MAX		\
	(NUMERIC_SHORT_DSCALE_MASK >> NUMERIC_SHORT_DSCALE_SHIFT)
#define NUMERIC_SHORT_WEIGHT_SIGN_MASK	0x0040
#define NUMERIC_SHORT_WEIGHT_MASK		0x003F
#define NUMERIC_SHORT_WEIGHT_MAX		NUMERIC_SHORT_WEIGHT_MASK
#define NUMERIC_SHORT_WEIGHT_MIN		(-(NUMERIC_SHORT_WEIGHT_MASK+1))

/*
 * Extract sign, display scale, weight.
 */

/* The long format's dscale occupies the 14 bits below the flag bits. */
#define NUMERIC_DSCALE_MASK			0x3FFF
#define NUMERIC_DSCALE_MAX			NUMERIC_DSCALE_MASK

/*
 * NUMERIC_SIGN: for the short format, map the short sign bit to NUMERIC_NEG
 * or NUMERIC_POS.  Otherwise, return the flag bits themselves, so a NaN
 * yields NUMERIC_NAN.
 *
 * NUMERIC_DSCALE / NUMERIC_WEIGHT: these test NUMERIC_HEADER_IS_SHORT, which
 * is also true for NaN headers.  In the short case the weight is
 * sign-extended: when the weight sign bit is set, every bit above
 * NUMERIC_SHORT_WEIGHT_MASK is set via ~NUMERIC_SHORT_WEIGHT_MASK.
 */
#define NUMERIC_SIGN(n) \
	(NUMERIC_IS_SHORT(n) ? \
		(((n)->choice.n_short.n_header & NUMERIC_SHORT_SIGN_MASK) ? \
		NUMERIC_NEG : NUMERIC_POS) : NUMERIC_FLAGBITS(n))
#define NUMERIC_DSCALE(n)	(NUMERIC_HEADER_IS_SHORT((n)) ? \
	((n)->choice.n_short.n_header & NUMERIC_SHORT_DSCALE_MASK) \
		>> NUMERIC_SHORT_DSCALE_SHIFT \
	: ((n)->choice.n_long.n_sign_dscale & NUMERIC_DSCALE_MASK))
#define NUMERIC_WEIGHT(n)	(NUMERIC_HEADER_IS_SHORT((n)) ? \
	(((n)->choice.n_short.n_header & NUMERIC_SHORT_WEIGHT_SIGN_MASK ? \
		~NUMERIC_SHORT_WEIGHT_MASK : 0) \
	 | ((n)->choice.n_short.n_header & NUMERIC_SHORT_WEIGHT_MASK)) \
	: ((n)->choice.n_long.n_weight))
221 
222 /* ----------
223  * NumericVar is the format we use for arithmetic.  The digit-array part
224  * is the same as the NumericData storage format, but the header is more
225  * complex.
226  *
227  * The value represented by a NumericVar is determined by the sign, weight,
228  * ndigits, and digits[] array.
229  *
230  * Note: the first digit of a NumericVar's value is assumed to be multiplied
231  * by NBASE ** weight.  Another way to say it is that there are weight+1
232  * digits before the decimal point.  It is possible to have weight < 0.
233  *
234  * buf points at the physical start of the palloc'd digit buffer for the
235  * NumericVar.  digits points at the first digit in actual use (the one
236  * with the specified weight).  We normally leave an unused digit or two
237  * (preset to zeroes) between buf and digits, so that there is room to store
238  * a carry out of the top digit without reallocating space.  We just need to
239  * decrement digits (and increment weight) to make room for the carry digit.
240  * (There is no such extra space in a numeric value stored in the database,
241  * only in a NumericVar in memory.)
242  *
243  * If buf is NULL then the digit buffer isn't actually palloc'd and should
244  * not be freed --- see the constants below for an example.
245  *
246  * dscale, or display scale, is the nominal precision expressed as number
247  * of digits after the decimal point (it must always be >= 0 at present).
248  * dscale may be more than the number of physically stored fractional digits,
249  * implying that we have suppressed storage of significant trailing zeroes.
250  * It should never be less than the number of stored digits, since that would
251  * imply hiding digits that are present.  NOTE that dscale is always expressed
252  * in *decimal* digits, and so it may correspond to a fractional number of
253  * base-NBASE digits --- divide by DEC_DIGITS to convert to NBASE digits.
254  *
255  * rscale, or result scale, is the target precision for a computation.
256  * Like dscale it is expressed as number of *decimal* digits after the decimal
257  * point, and is always >= 0 at present.
258  * Note that rscale is not stored in variables --- it's figured on-the-fly
259  * from the dscales of the inputs.
260  *
261  * While we consistently use "weight" to refer to the base-NBASE weight of
262  * a numeric value, it is convenient in some scale-related calculations to
263  * make use of the base-10 weight (ie, the approximate log10 of the value).
264  * To avoid confusion, such a decimal-units weight is called a "dweight".
265  *
266  * NB: All the variable-level functions are written in a style that makes it
267  * possible to give one and the same variable as argument and destination.
268  * This is feasible because the digit buffer is separate from the variable.
269  * ----------
270  */
typedef struct NumericVar
{
	int			ndigits;		/* # of digits in digits[] - can be 0! */
	int			weight;			/* weight of first digit */
	int			sign;			/* NUMERIC_POS, NUMERIC_NEG, or NUMERIC_NAN */
	int			dscale;			/* display scale, in decimal digits */
	NumericDigit *buf;			/* start of palloc'd space for digits[], or
								 * NULL if the digits are not palloc'd */
	NumericDigit *digits;		/* base-NBASE digits */
} NumericVar;
280 
281 
282 /* ----------
283  * Data for generate_series
284  * ----------
285  */
/*
 * NOTE(review): field roles below are inferred from their names; confirm
 * against the generate_series implementation.
 */
typedef struct
{
	NumericVar	current;		/* presumably the next value to return */
	NumericVar	stop;			/* presumably the bound ending the series */
	NumericVar	step;			/* presumably the increment per step */
} generate_series_numeric_fctx;


/* ----------
 * Sort support.
 * ----------
 */
typedef struct
{
	void	   *buf;			/* buffer for short varlenas */
	int64		input_count;	/* number of non-null values seen */
	bool		estimating;		/* true if estimating cardinality */

	hyperLogLogState abbr_card; /* cardinality estimator */
	/*
	 * NOTE(review): presumably consulted by numeric_abbrev_abort() to decide
	 * whether abbreviation pays off; confirm.
	 */
} NumericSortSupport;
306 
307 
308 /* ----------
309  * Fast sum accumulator.
310  *
311  * NumericSumAccum is used to implement SUM(), and other standard aggregates
312  * that track the sum of input values.  It uses 32-bit integers to store the
313  * digits, instead of the normal 16-bit integers (with NBASE=10000).  This
314  * way, we can safely accumulate up to NBASE - 1 values without propagating
315  * carry, before risking overflow of any of the digits.  'num_uncarried'
316  * tracks how many values have been accumulated without propagating carry.
317  *
318  * Positive and negative values are accumulated separately, in 'pos_digits'
319  * and 'neg_digits'.  This is simpler and faster than deciding whether to add
320  * or subtract from the current value, for each new value (see sub_var() for
 * the logic we avoid by doing this).  Both buffers are of the same size, and
322  * have the same weight and scale.  In accum_sum_final(), the positive and
323  * negative sums are added together to produce the final result.
324  *
325  * When a new value has a larger ndigits or weight than the accumulator
326  * currently does, the accumulator is enlarged to accommodate the new value.
327  * We normally have one zero digit reserved for carry propagation, and that
328  * is indicated by the 'have_carry_space' flag.  When accum_sum_carry() uses
329  * up the reserved digit, it clears the 'have_carry_space' flag.  The next
330  * call to accum_sum_add() will enlarge the buffer, to make room for the
331  * extra digit, and set the flag again.
332  *
333  * To initialize a new accumulator, simply reset all fields to zeros.
334  *
335  * The accumulator does not handle NaNs.
336  * ----------
337  */
typedef struct NumericSumAccum
{
	int			ndigits;		/* digits in each buffer (both are the same
								 * size) */
	int			weight;			/* weight shared by both buffers */
	int			dscale;			/* display scale shared by both buffers */
	int			num_uncarried;	/* values accumulated since the last carry
								 * propagation */
	bool		have_carry_space;	/* true while one zero digit is reserved
									 * for carry propagation */
	int32	   *pos_digits;		/* digits accumulated from positive values */
	int32	   *neg_digits;		/* digits accumulated from negative values */
} NumericSumAccum;
348 
349 
350 /*
351  * We define our own macros for packing and unpacking abbreviated-key
352  * representations for numeric values in order to avoid depending on
353  * USE_FLOAT8_BYVAL.  The type of abbreviation we use is based only on
354  * the size of a datum, not the argument-passing convention for float8.
355  */
#define NUMERIC_ABBREV_BITS (SIZEOF_DATUM * BITS_PER_BYTE)
#if SIZEOF_DATUM == 8
/* 8-byte Datum: abbreviated keys are int64; NaN is the most negative value. */
#define NumericAbbrevGetDatum(X) ((Datum) (X))
#define DatumGetNumericAbbrev(X) ((int64) (X))
#define NUMERIC_ABBREV_NAN		 NumericAbbrevGetDatum(PG_INT64_MIN)
#else
/* Other Datum sizes: keys are int32; NaN is the most negative value. */
#define NumericAbbrevGetDatum(X) ((Datum) (X))
#define DatumGetNumericAbbrev(X) ((int32) (X))
#define NUMERIC_ABBREV_NAN		 NumericAbbrevGetDatum(PG_INT32_MIN)
#endif
366 
367 
368 /* ----------
369  * Some preinitialized constants
370  * ----------
371  */
/*
 * Initializers follow the NumericVar field order:
 * {ndigits, weight, sign, dscale, buf, digits}.
 * Every constant has buf == NULL, so its digits live in static storage that
 * must never be freed (see the NumericVar comments above).
 */

/* 0: no digits at all (ndigits = 0). */
static const NumericDigit const_zero_data[1] = {0};
static const NumericVar const_zero =
{0, 0, NUMERIC_POS, 0, NULL, (NumericDigit *) const_zero_data};

/* 1 */
static const NumericDigit const_one_data[1] = {1};
static const NumericVar const_one =
{1, 0, NUMERIC_POS, 0, NULL, (NumericDigit *) const_one_data};

/* 2 */
static const NumericDigit const_two_data[1] = {2};
static const NumericVar const_two =
{1, 0, NUMERIC_POS, 0, NULL, (NumericDigit *) const_two_data};

/* 0.5: one digit of weight -1 equal to NBASE/2, dscale 1. */
#if DEC_DIGITS == 4
static const NumericDigit const_zero_point_five_data[1] = {5000};
#elif DEC_DIGITS == 2
static const NumericDigit const_zero_point_five_data[1] = {50};
#elif DEC_DIGITS == 1
static const NumericDigit const_zero_point_five_data[1] = {5};
#endif
static const NumericVar const_zero_point_five =
{1, -1, NUMERIC_POS, 1, NULL, (NumericDigit *) const_zero_point_five_data};

/* 0.9: one digit of weight -1, dscale 1. */
#if DEC_DIGITS == 4
static const NumericDigit const_zero_point_nine_data[1] = {9000};
#elif DEC_DIGITS == 2
static const NumericDigit const_zero_point_nine_data[1] = {90};
#elif DEC_DIGITS == 1
static const NumericDigit const_zero_point_nine_data[1] = {9};
#endif
static const NumericVar const_zero_point_nine =
{1, -1, NUMERIC_POS, 1, NULL, (NumericDigit *) const_zero_point_nine_data};

/* 1.1: digits "1" (weight 0) and "0.1" (weight -1), dscale 1. */
#if DEC_DIGITS == 4
static const NumericDigit const_one_point_one_data[2] = {1, 1000};
#elif DEC_DIGITS == 2
static const NumericDigit const_one_point_one_data[2] = {1, 10};
#elif DEC_DIGITS == 1
static const NumericDigit const_one_point_one_data[2] = {1, 1};
#endif
static const NumericVar const_one_point_one =
{2, 0, NUMERIC_POS, 1, NULL, (NumericDigit *) const_one_point_one_data};

/* NaN: no digits and no digit storage. */
static const NumericVar const_nan =
{0, 0, NUMERIC_NAN, 0, NULL, NULL};

/*
 * NOTE(review): presumably the divisors used when rounding/truncating at a
 * decimal position inside an NBASE digit (index = decimal digits kept within
 * that digit); confirm against round_var()/trunc_var().
 */
#if DEC_DIGITS == 4
static const int round_powers[4] = {0, 1000, 100, 10};
#endif
420 
421 
422 /* ----------
423  * Local functions
424  * ----------
425  */
426 
/* Without NUMERIC_DEBUG, the dump calls compile away to nothing. */
#ifdef NUMERIC_DEBUG
static void dump_numeric(const char *str, Numeric num);
static void dump_var(const char *str, NumericVar *var);
#else
#define dump_numeric(s,n)
#define dump_var(s,v)
#endif

/*
 * Allocate / release a digit buffer.  The free macro skips NULL because
 * NumericVars whose buf is NULL (such as the constants above) do not own
 * their digits.
 */
#define digitbuf_alloc(ndigits)  \
	((NumericDigit *) palloc((ndigits) * sizeof(NumericDigit)))
#define digitbuf_free(buf)	\
	do { \
		 if ((buf) != NULL) \
			 pfree(buf); \
	} while (0)

/*
 * Zero every field of a NumericVar: no digits, weight 0, sign NUMERIC_POS
 * (which is 0), dscale 0, and zeroed buf/digits pointers.
 */
#define init_var(v)		MemSetAligned(v, 0, sizeof(NumericVar))

/* Address of the first stored digit, after whichever header is in use. */
#define NUMERIC_DIGITS(num) (NUMERIC_HEADER_IS_SHORT(num) ? \
	(num)->choice.n_short.n_data : (num)->choice.n_long.n_data)
/* Number of stored digits: varlena size minus header size, in digits. */
#define NUMERIC_NDIGITS(num) \
	((VARSIZE(num) - NUMERIC_HEADER_SIZE(num)) / sizeof(NumericDigit))
/* True if both dscale and weight fit in the NumericShort header fields. */
#define NUMERIC_CAN_BE_SHORT(scale,weight) \
	((scale) <= NUMERIC_SHORT_DSCALE_MAX && \
	(weight) <= NUMERIC_SHORT_WEIGHT_MAX && \
	(weight) >= NUMERIC_SHORT_WEIGHT_MIN)
/* NumericVar storage management */
static void alloc_var(NumericVar *var, int ndigits);
static void free_var(NumericVar *var);
static void zero_var(NumericVar *var);

/* Conversions between text, Numeric and NumericVar */
static const char *set_var_from_str(const char *str, const char *cp,
				 NumericVar *dest);
static void set_var_from_num(Numeric value, NumericVar *dest);
static void init_var_from_num(Numeric num, NumericVar *dest);
static void set_var_from_var(const NumericVar *value, NumericVar *dest);
static char *get_str_from_var(const NumericVar *var);
static char *get_str_from_var_sci(const NumericVar *var, int rscale);

static Numeric make_result(const NumericVar *var);

static void apply_typmod(NumericVar *var, int32 typmod);

/* Conversions between NumericVar and machine integer / floating types */
static int32 numericvar_to_int32(const NumericVar *var);
static bool numericvar_to_int64(const NumericVar *var, int64 *result);
static void int64_to_numericvar(int64 val, NumericVar *var);
#ifdef HAVE_INT128
static bool numericvar_to_int128(const NumericVar *var, int128 *result);
static void int128_to_numericvar(int128 val, NumericVar *var);
#endif
static double numeric_to_double_no_overflow(Numeric num);
static double numericvar_to_double_no_overflow(const NumericVar *var);

/* Sort support and abbreviated keys */
static Datum numeric_abbrev_convert(Datum original_datum, SortSupport ssup);
static bool numeric_abbrev_abort(int memtupcount, SortSupport ssup);
static int	numeric_fast_cmp(Datum x, Datum y, SortSupport ssup);
static int	numeric_cmp_abbrev(Datum x, Datum y, SortSupport ssup);

static Datum numeric_abbrev_convert_var(const NumericVar *var,
						   NumericSortSupport *nss);

/* Comparison and basic arithmetic on NumericVars */
static int	cmp_numerics(Numeric num1, Numeric num2);
static int	cmp_var(const NumericVar *var1, const NumericVar *var2);
static int cmp_var_common(const NumericDigit *var1digits, int var1ndigits,
			   int var1weight, int var1sign,
			   const NumericDigit *var2digits, int var2ndigits,
			   int var2weight, int var2sign);
static void add_var(const NumericVar *var1, const NumericVar *var2,
		NumericVar *result);
static void sub_var(const NumericVar *var1, const NumericVar *var2,
		NumericVar *result);
static void mul_var(const NumericVar *var1, const NumericVar *var2,
		NumericVar *result,
		int rscale);
static void div_var(const NumericVar *var1, const NumericVar *var2,
		NumericVar *result,
		int rscale, bool round);
static void div_var_fast(const NumericVar *var1, const NumericVar *var2,
			 NumericVar *result, int rscale, bool round);
static int	select_div_scale(const NumericVar *var1, const NumericVar *var2);
static void mod_var(const NumericVar *var1, const NumericVar *var2,
		NumericVar *result);
static void ceil_var(const NumericVar *var, NumericVar *result);
static void floor_var(const NumericVar *var, NumericVar *result);

/* Roots, exponentials, logarithms and powers */
static void sqrt_var(const NumericVar *arg, NumericVar *result, int rscale);
static void exp_var(const NumericVar *arg, NumericVar *result, int rscale);
static int	estimate_ln_dweight(const NumericVar *var);
static void ln_var(const NumericVar *arg, NumericVar *result, int rscale);
static void log_var(const NumericVar *base, const NumericVar *num,
		NumericVar *result);
static void power_var(const NumericVar *base, const NumericVar *exp,
		  NumericVar *result);
static void power_var_int(const NumericVar *base, int exp, NumericVar *result,
			  int rscale);
static void power_ten_int(int exp, NumericVar *result);

/* Magnitude-only helpers, rounding/truncation, and width_bucket support */
static int	cmp_abs(const NumericVar *var1, const NumericVar *var2);
static int cmp_abs_common(const NumericDigit *var1digits, int var1ndigits,
			   int var1weight,
			   const NumericDigit *var2digits, int var2ndigits,
			   int var2weight);
static void add_abs(const NumericVar *var1, const NumericVar *var2,
		NumericVar *result);
static void sub_abs(const NumericVar *var1, const NumericVar *var2,
		NumericVar *result);
static void round_var(NumericVar *var, int rscale);
static void trunc_var(NumericVar *var, int rscale);
static void strip_var(NumericVar *var);
static void compute_bucket(Numeric operand, Numeric bound1, Numeric bound2,
			   const NumericVar *count_var, NumericVar *result_var);

/* NumericSumAccum fast-sum accumulator operations */
static void accum_sum_add(NumericSumAccum *accum, const NumericVar *var1);
static void accum_sum_rescale(NumericSumAccum *accum, const NumericVar *val);
static void accum_sum_carry(NumericSumAccum *accum);
static void accum_sum_reset(NumericSumAccum *accum);
static void accum_sum_final(NumericSumAccum *accum, NumericVar *result);
static void accum_sum_copy(NumericSumAccum *dst, NumericSumAccum *src);
static void accum_sum_combine(NumericSumAccum *accum, NumericSumAccum *accum2);
546 
547 
548 /* ----------------------------------------------------------------------
549  *
550  * Input-, output- and rounding-functions
551  *
552  * ----------------------------------------------------------------------
553  */
554 
555 
556 /*
557  * numeric_in() -
558  *
559  *	Input function for numeric data type
560  */
561 Datum
numeric_in(PG_FUNCTION_ARGS)562 numeric_in(PG_FUNCTION_ARGS)
563 {
564 	char	   *str = PG_GETARG_CSTRING(0);
565 
566 #ifdef NOT_USED
567 	Oid			typelem = PG_GETARG_OID(1);
568 #endif
569 	int32		typmod = PG_GETARG_INT32(2);
570 	Numeric		res;
571 	const char *cp;
572 
573 	/* Skip leading spaces */
574 	cp = str;
575 	while (*cp)
576 	{
577 		if (!isspace((unsigned char) *cp))
578 			break;
579 		cp++;
580 	}
581 
582 	/*
583 	 * Check for NaN
584 	 */
585 	if (pg_strncasecmp(cp, "NaN", 3) == 0)
586 	{
587 		res = make_result(&const_nan);
588 
589 		/* Should be nothing left but spaces */
590 		cp += 3;
591 		while (*cp)
592 		{
593 			if (!isspace((unsigned char) *cp))
594 				ereport(ERROR,
595 						(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
596 						 errmsg("invalid input syntax for type %s: \"%s\"",
597 								"numeric", str)));
598 			cp++;
599 		}
600 	}
601 	else
602 	{
603 		/*
604 		 * Use set_var_from_str() to parse a normal numeric value
605 		 */
606 		NumericVar	value;
607 
608 		init_var(&value);
609 
610 		cp = set_var_from_str(str, cp, &value);
611 
612 		/*
613 		 * We duplicate a few lines of code here because we would like to
614 		 * throw any trailing-junk syntax error before any semantic error
615 		 * resulting from apply_typmod.  We can't easily fold the two cases
616 		 * together because we mustn't apply apply_typmod to a NaN.
617 		 */
618 		while (*cp)
619 		{
620 			if (!isspace((unsigned char) *cp))
621 				ereport(ERROR,
622 						(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
623 						 errmsg("invalid input syntax for type %s: \"%s\"",
624 								"numeric", str)));
625 			cp++;
626 		}
627 
628 		apply_typmod(&value, typmod);
629 
630 		res = make_result(&value);
631 		free_var(&value);
632 	}
633 
634 	PG_RETURN_NUMERIC(res);
635 }
636 
637 
638 /*
639  * numeric_out() -
640  *
641  *	Output function for numeric data type
642  */
643 Datum
numeric_out(PG_FUNCTION_ARGS)644 numeric_out(PG_FUNCTION_ARGS)
645 {
646 	Numeric		num = PG_GETARG_NUMERIC(0);
647 	NumericVar	x;
648 	char	   *str;
649 
650 	/*
651 	 * Handle NaN
652 	 */
653 	if (NUMERIC_IS_NAN(num))
654 		PG_RETURN_CSTRING(pstrdup("NaN"));
655 
656 	/*
657 	 * Get the number in the variable format.
658 	 */
659 	init_var_from_num(num, &x);
660 
661 	str = get_str_from_var(&x);
662 
663 	PG_RETURN_CSTRING(str);
664 }
665 
666 /*
667  * numeric_is_nan() -
668  *
669  *	Is Numeric value a NaN?
670  */
671 bool
numeric_is_nan(Numeric num)672 numeric_is_nan(Numeric num)
673 {
674 	return NUMERIC_IS_NAN(num);
675 }
676 
677 /*
678  * numeric_maximum_size() -
679  *
680  *	Maximum size of a numeric with given typmod, or -1 if unlimited/unknown.
681  */
682 int32
numeric_maximum_size(int32 typmod)683 numeric_maximum_size(int32 typmod)
684 {
685 	int			precision;
686 	int			numeric_digits;
687 
688 	if (typmod < (int32) (VARHDRSZ))
689 		return -1;
690 
691 	/* precision (ie, max # of digits) is in upper bits of typmod */
692 	precision = ((typmod - VARHDRSZ) >> 16) & 0xffff;
693 
694 	/*
695 	 * This formula computes the maximum number of NumericDigits we could need
696 	 * in order to store the specified number of decimal digits. Because the
697 	 * weight is stored as a number of NumericDigits rather than a number of
698 	 * decimal digits, it's possible that the first NumericDigit will contain
699 	 * only a single decimal digit.  Thus, the first two decimal digits can
700 	 * require two NumericDigits to store, but it isn't until we reach
701 	 * DEC_DIGITS + 2 decimal digits that we potentially need a third
702 	 * NumericDigit.
703 	 */
704 	numeric_digits = (precision + 2 * (DEC_DIGITS - 1)) / DEC_DIGITS;
705 
706 	/*
707 	 * In most cases, the size of a numeric will be smaller than the value
708 	 * computed below, because the varlena header will typically get toasted
709 	 * down to a single byte before being stored on disk, and it may also be
710 	 * possible to use a short numeric header.  But our job here is to compute
711 	 * the worst case.
712 	 */
713 	return NUMERIC_HDRSZ + (numeric_digits * sizeof(NumericDigit));
714 }
715 
716 /*
717  * numeric_out_sci() -
718  *
719  *	Output function for numeric data type in scientific notation.
720  */
721 char *
numeric_out_sci(Numeric num,int scale)722 numeric_out_sci(Numeric num, int scale)
723 {
724 	NumericVar	x;
725 	char	   *str;
726 
727 	/*
728 	 * Handle NaN
729 	 */
730 	if (NUMERIC_IS_NAN(num))
731 		return pstrdup("NaN");
732 
733 	init_var_from_num(num, &x);
734 
735 	str = get_str_from_var_sci(&x, scale);
736 
737 	return str;
738 }
739 
740 /*
741  * numeric_normalize() -
742  *
743  *	Output function for numeric data type, suppressing insignificant trailing
744  *	zeroes and then any trailing decimal point.  The intent of this is to
745  *	produce strings that are equal if and only if the input numeric values
746  *	compare equal.
747  */
748 char *
numeric_normalize(Numeric num)749 numeric_normalize(Numeric num)
750 {
751 	NumericVar	x;
752 	char	   *str;
753 	int			last;
754 
755 	/*
756 	 * Handle NaN
757 	 */
758 	if (NUMERIC_IS_NAN(num))
759 		return pstrdup("NaN");
760 
761 	init_var_from_num(num, &x);
762 
763 	str = get_str_from_var(&x);
764 
765 	/* If there's no decimal point, there's certainly nothing to remove. */
766 	if (strchr(str, '.') != NULL)
767 	{
768 		/*
769 		 * Back up over trailing fractional zeroes.  Since there is a decimal
770 		 * point, this loop will terminate safely.
771 		 */
772 		last = strlen(str) - 1;
773 		while (str[last] == '0')
774 			last--;
775 
776 		/* We want to get rid of the decimal point too, if it's now last. */
777 		if (str[last] == '.')
778 			last--;
779 
780 		/* Delete whatever we backed up over. */
781 		str[last + 1] = '\0';
782 	}
783 
784 	return str;
785 }
786 
787 /*
788  *		numeric_recv			- converts external binary format to numeric
789  *
790  * External format is a sequence of int16's:
791  * ndigits, weight, sign, dscale, NumericDigits.
792  */
793 Datum
numeric_recv(PG_FUNCTION_ARGS)794 numeric_recv(PG_FUNCTION_ARGS)
795 {
796 	StringInfo	buf = (StringInfo) PG_GETARG_POINTER(0);
797 
798 #ifdef NOT_USED
799 	Oid			typelem = PG_GETARG_OID(1);
800 #endif
801 	int32		typmod = PG_GETARG_INT32(2);
802 	NumericVar	value;
803 	Numeric		res;
804 	int			len,
805 				i;
806 
807 	init_var(&value);
808 
809 	len = (uint16) pq_getmsgint(buf, sizeof(uint16));
810 
811 	alloc_var(&value, len);
812 
813 	value.weight = (int16) pq_getmsgint(buf, sizeof(int16));
814 	/* we allow any int16 for weight --- OK? */
815 
816 	value.sign = (uint16) pq_getmsgint(buf, sizeof(uint16));
817 	if (!(value.sign == NUMERIC_POS ||
818 		  value.sign == NUMERIC_NEG ||
819 		  value.sign == NUMERIC_NAN))
820 		ereport(ERROR,
821 				(errcode(ERRCODE_INVALID_BINARY_REPRESENTATION),
822 				 errmsg("invalid sign in external \"numeric\" value")));
823 
824 	value.dscale = (uint16) pq_getmsgint(buf, sizeof(uint16));
825 	if ((value.dscale & NUMERIC_DSCALE_MASK) != value.dscale)
826 		ereport(ERROR,
827 				(errcode(ERRCODE_INVALID_BINARY_REPRESENTATION),
828 				 errmsg("invalid scale in external \"numeric\" value")));
829 
830 	for (i = 0; i < len; i++)
831 	{
832 		NumericDigit d = pq_getmsgint(buf, sizeof(NumericDigit));
833 
834 		if (d < 0 || d >= NBASE)
835 			ereport(ERROR,
836 					(errcode(ERRCODE_INVALID_BINARY_REPRESENTATION),
837 					 errmsg("invalid digit in external \"numeric\" value")));
838 		value.digits[i] = d;
839 	}
840 
841 	/*
842 	 * If the given dscale would hide any digits, truncate those digits away.
843 	 * We could alternatively throw an error, but that would take a bunch of
844 	 * extra code (about as much as trunc_var involves), and it might cause
845 	 * client compatibility issues.
846 	 */
847 	trunc_var(&value, value.dscale);
848 
849 	apply_typmod(&value, typmod);
850 
851 	res = make_result(&value);
852 	free_var(&value);
853 
854 	PG_RETURN_NUMERIC(res);
855 }
856 
857 /*
858  *		numeric_send			- converts numeric to binary format
859  */
860 Datum
numeric_send(PG_FUNCTION_ARGS)861 numeric_send(PG_FUNCTION_ARGS)
862 {
863 	Numeric		num = PG_GETARG_NUMERIC(0);
864 	NumericVar	x;
865 	StringInfoData buf;
866 	int			i;
867 
868 	init_var_from_num(num, &x);
869 
870 	pq_begintypsend(&buf);
871 
872 	pq_sendint16(&buf, x.ndigits);
873 	pq_sendint16(&buf, x.weight);
874 	pq_sendint16(&buf, x.sign);
875 	pq_sendint16(&buf, x.dscale);
876 	for (i = 0; i < x.ndigits; i++)
877 		pq_sendint16(&buf, x.digits[i]);
878 
879 	PG_RETURN_BYTEA_P(pq_endtypsend(&buf));
880 }
881 
882 
883 /*
884  * numeric_transform() -
885  *
886  * Flatten calls to numeric's length coercion function that solely represent
887  * increases in allowable precision.  Scale changes mutate every datum, so
888  * they are unoptimizable.  Some values, e.g. 1E-1001, can only fit into an
889  * unconstrained numeric, so a change from an unconstrained numeric to any
890  * constrained numeric is also unoptimizable.
891  */
892 Datum
numeric_transform(PG_FUNCTION_ARGS)893 numeric_transform(PG_FUNCTION_ARGS)
894 {
895 	FuncExpr   *expr = castNode(FuncExpr, PG_GETARG_POINTER(0));
896 	Node	   *ret = NULL;
897 	Node	   *typmod;
898 
899 	Assert(list_length(expr->args) >= 2);
900 
901 	typmod = (Node *) lsecond(expr->args);
902 
903 	if (IsA(typmod, Const) &&!((Const *) typmod)->constisnull)
904 	{
905 		Node	   *source = (Node *) linitial(expr->args);
906 		int32		old_typmod = exprTypmod(source);
907 		int32		new_typmod = DatumGetInt32(((Const *) typmod)->constvalue);
908 		int32		old_scale = (old_typmod - VARHDRSZ) & 0xffff;
909 		int32		new_scale = (new_typmod - VARHDRSZ) & 0xffff;
910 		int32		old_precision = (old_typmod - VARHDRSZ) >> 16 & 0xffff;
911 		int32		new_precision = (new_typmod - VARHDRSZ) >> 16 & 0xffff;
912 
913 		/*
914 		 * If new_typmod < VARHDRSZ, the destination is unconstrained; that's
915 		 * always OK.  If old_typmod >= VARHDRSZ, the source is constrained,
916 		 * and we're OK if the scale is unchanged and the precision is not
917 		 * decreasing.  See further notes in function header comment.
918 		 */
919 		if (new_typmod < (int32) VARHDRSZ ||
920 			(old_typmod >= (int32) VARHDRSZ &&
921 			 new_scale == old_scale && new_precision >= old_precision))
922 			ret = relabel_to_typmod(source, new_typmod);
923 	}
924 
925 	PG_RETURN_POINTER(ret);
926 }
927 
928 /*
929  * numeric() -
930  *
931  *	This is a special function called by the Postgres database system
932  *	before a value is stored in a tuple's attribute. The precision and
933  *	scale of the attribute have to be applied on the value.
934  */
935 Datum
numeric(PG_FUNCTION_ARGS)936 numeric		(PG_FUNCTION_ARGS)
937 {
938 	Numeric		num = PG_GETARG_NUMERIC(0);
939 	int32		typmod = PG_GETARG_INT32(1);
940 	Numeric		new;
941 	int32		tmp_typmod;
942 	int			precision;
943 	int			scale;
944 	int			ddigits;
945 	int			maxdigits;
946 	NumericVar	var;
947 
948 	/*
949 	 * Handle NaN
950 	 */
951 	if (NUMERIC_IS_NAN(num))
952 		PG_RETURN_NUMERIC(make_result(&const_nan));
953 
954 	/*
955 	 * If the value isn't a valid type modifier, simply return a copy of the
956 	 * input value
957 	 */
958 	if (typmod < (int32) (VARHDRSZ))
959 	{
960 		new = (Numeric) palloc(VARSIZE(num));
961 		memcpy(new, num, VARSIZE(num));
962 		PG_RETURN_NUMERIC(new);
963 	}
964 
965 	/*
966 	 * Get the precision and scale out of the typmod value
967 	 */
968 	tmp_typmod = typmod - VARHDRSZ;
969 	precision = (tmp_typmod >> 16) & 0xffff;
970 	scale = tmp_typmod & 0xffff;
971 	maxdigits = precision - scale;
972 
973 	/*
974 	 * If the number is certainly in bounds and due to the target scale no
975 	 * rounding could be necessary, just make a copy of the input and modify
976 	 * its scale fields, unless the larger scale forces us to abandon the
977 	 * short representation.  (Note we assume the existing dscale is
978 	 * honest...)
979 	 */
980 	ddigits = (NUMERIC_WEIGHT(num) + 1) * DEC_DIGITS;
981 	if (ddigits <= maxdigits && scale >= NUMERIC_DSCALE(num)
982 		&& (NUMERIC_CAN_BE_SHORT(scale, NUMERIC_WEIGHT(num))
983 			|| !NUMERIC_IS_SHORT(num)))
984 	{
985 		new = (Numeric) palloc(VARSIZE(num));
986 		memcpy(new, num, VARSIZE(num));
987 		if (NUMERIC_IS_SHORT(num))
988 			new->choice.n_short.n_header =
989 				(num->choice.n_short.n_header & ~NUMERIC_SHORT_DSCALE_MASK)
990 				| (scale << NUMERIC_SHORT_DSCALE_SHIFT);
991 		else
992 			new->choice.n_long.n_sign_dscale = NUMERIC_SIGN(new) |
993 				((uint16) scale & NUMERIC_DSCALE_MASK);
994 		PG_RETURN_NUMERIC(new);
995 	}
996 
997 	/*
998 	 * We really need to fiddle with things - unpack the number into a
999 	 * variable and let apply_typmod() do it.
1000 	 */
1001 	init_var(&var);
1002 
1003 	set_var_from_num(num, &var);
1004 	apply_typmod(&var, typmod);
1005 	new = make_result(&var);
1006 
1007 	free_var(&var);
1008 
1009 	PG_RETURN_NUMERIC(new);
1010 }
1011 
1012 Datum
numerictypmodin(PG_FUNCTION_ARGS)1013 numerictypmodin(PG_FUNCTION_ARGS)
1014 {
1015 	ArrayType  *ta = PG_GETARG_ARRAYTYPE_P(0);
1016 	int32	   *tl;
1017 	int			n;
1018 	int32		typmod;
1019 
1020 	tl = ArrayGetIntegerTypmods(ta, &n);
1021 
1022 	if (n == 2)
1023 	{
1024 		if (tl[0] < 1 || tl[0] > NUMERIC_MAX_PRECISION)
1025 			ereport(ERROR,
1026 					(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
1027 					 errmsg("NUMERIC precision %d must be between 1 and %d",
1028 							tl[0], NUMERIC_MAX_PRECISION)));
1029 		if (tl[1] < 0 || tl[1] > tl[0])
1030 			ereport(ERROR,
1031 					(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
1032 					 errmsg("NUMERIC scale %d must be between 0 and precision %d",
1033 							tl[1], tl[0])));
1034 		typmod = ((tl[0] << 16) | tl[1]) + VARHDRSZ;
1035 	}
1036 	else if (n == 1)
1037 	{
1038 		if (tl[0] < 1 || tl[0] > NUMERIC_MAX_PRECISION)
1039 			ereport(ERROR,
1040 					(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
1041 					 errmsg("NUMERIC precision %d must be between 1 and %d",
1042 							tl[0], NUMERIC_MAX_PRECISION)));
1043 		/* scale defaults to zero */
1044 		typmod = (tl[0] << 16) + VARHDRSZ;
1045 	}
1046 	else
1047 	{
1048 		ereport(ERROR,
1049 				(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
1050 				 errmsg("invalid NUMERIC type modifier")));
1051 		typmod = 0;				/* keep compiler quiet */
1052 	}
1053 
1054 	PG_RETURN_INT32(typmod);
1055 }
1056 
1057 Datum
numerictypmodout(PG_FUNCTION_ARGS)1058 numerictypmodout(PG_FUNCTION_ARGS)
1059 {
1060 	int32		typmod = PG_GETARG_INT32(0);
1061 	char	   *res = (char *) palloc(64);
1062 
1063 	if (typmod >= 0)
1064 		snprintf(res, 64, "(%d,%d)",
1065 				 ((typmod - VARHDRSZ) >> 16) & 0xffff,
1066 				 (typmod - VARHDRSZ) & 0xffff);
1067 	else
1068 		*res = '\0';
1069 
1070 	PG_RETURN_CSTRING(res);
1071 }
1072 
1073 
1074 /* ----------------------------------------------------------------------
1075  *
1076  * Sign manipulation, rounding and the like
1077  *
1078  * ----------------------------------------------------------------------
1079  */
1080 
1081 Datum
numeric_abs(PG_FUNCTION_ARGS)1082 numeric_abs(PG_FUNCTION_ARGS)
1083 {
1084 	Numeric		num = PG_GETARG_NUMERIC(0);
1085 	Numeric		res;
1086 
1087 	/*
1088 	 * Handle NaN
1089 	 */
1090 	if (NUMERIC_IS_NAN(num))
1091 		PG_RETURN_NUMERIC(make_result(&const_nan));
1092 
1093 	/*
1094 	 * Do it the easy way directly on the packed format
1095 	 */
1096 	res = (Numeric) palloc(VARSIZE(num));
1097 	memcpy(res, num, VARSIZE(num));
1098 
1099 	if (NUMERIC_IS_SHORT(num))
1100 		res->choice.n_short.n_header =
1101 			num->choice.n_short.n_header & ~NUMERIC_SHORT_SIGN_MASK;
1102 	else
1103 		res->choice.n_long.n_sign_dscale = NUMERIC_POS | NUMERIC_DSCALE(num);
1104 
1105 	PG_RETURN_NUMERIC(res);
1106 }
1107 
1108 
1109 Datum
numeric_uminus(PG_FUNCTION_ARGS)1110 numeric_uminus(PG_FUNCTION_ARGS)
1111 {
1112 	Numeric		num = PG_GETARG_NUMERIC(0);
1113 	Numeric		res;
1114 
1115 	/*
1116 	 * Handle NaN
1117 	 */
1118 	if (NUMERIC_IS_NAN(num))
1119 		PG_RETURN_NUMERIC(make_result(&const_nan));
1120 
1121 	/*
1122 	 * Do it the easy way directly on the packed format
1123 	 */
1124 	res = (Numeric) palloc(VARSIZE(num));
1125 	memcpy(res, num, VARSIZE(num));
1126 
1127 	/*
1128 	 * The packed format is known to be totally zero digit trimmed always. So
1129 	 * we can identify a ZERO by the fact that there are no digits at all.  Do
1130 	 * nothing to a zero.
1131 	 */
1132 	if (NUMERIC_NDIGITS(num) != 0)
1133 	{
1134 		/* Else, flip the sign */
1135 		if (NUMERIC_IS_SHORT(num))
1136 			res->choice.n_short.n_header =
1137 				num->choice.n_short.n_header ^ NUMERIC_SHORT_SIGN_MASK;
1138 		else if (NUMERIC_SIGN(num) == NUMERIC_POS)
1139 			res->choice.n_long.n_sign_dscale =
1140 				NUMERIC_NEG | NUMERIC_DSCALE(num);
1141 		else
1142 			res->choice.n_long.n_sign_dscale =
1143 				NUMERIC_POS | NUMERIC_DSCALE(num);
1144 	}
1145 
1146 	PG_RETURN_NUMERIC(res);
1147 }
1148 
1149 
1150 Datum
numeric_uplus(PG_FUNCTION_ARGS)1151 numeric_uplus(PG_FUNCTION_ARGS)
1152 {
1153 	Numeric		num = PG_GETARG_NUMERIC(0);
1154 	Numeric		res;
1155 
1156 	res = (Numeric) palloc(VARSIZE(num));
1157 	memcpy(res, num, VARSIZE(num));
1158 
1159 	PG_RETURN_NUMERIC(res);
1160 }
1161 
1162 /*
1163  * numeric_sign() -
1164  *
1165  * returns -1 if the argument is less than 0, 0 if the argument is equal
1166  * to 0, and 1 if the argument is greater than zero.
1167  */
1168 Datum
numeric_sign(PG_FUNCTION_ARGS)1169 numeric_sign(PG_FUNCTION_ARGS)
1170 {
1171 	Numeric		num = PG_GETARG_NUMERIC(0);
1172 	Numeric		res;
1173 	NumericVar	result;
1174 
1175 	/*
1176 	 * Handle NaN
1177 	 */
1178 	if (NUMERIC_IS_NAN(num))
1179 		PG_RETURN_NUMERIC(make_result(&const_nan));
1180 
1181 	init_var(&result);
1182 
1183 	/*
1184 	 * The packed format is known to be totally zero digit trimmed always. So
1185 	 * we can identify a ZERO by the fact that there are no digits at all.
1186 	 */
1187 	if (NUMERIC_NDIGITS(num) == 0)
1188 		set_var_from_var(&const_zero, &result);
1189 	else
1190 	{
1191 		/*
1192 		 * And if there are some, we return a copy of ONE with the sign of our
1193 		 * argument
1194 		 */
1195 		set_var_from_var(&const_one, &result);
1196 		result.sign = NUMERIC_SIGN(num);
1197 	}
1198 
1199 	res = make_result(&result);
1200 	free_var(&result);
1201 
1202 	PG_RETURN_NUMERIC(res);
1203 }
1204 
1205 
1206 /*
1207  * numeric_round() -
1208  *
1209  *	Round a value to have 'scale' digits after the decimal point.
1210  *	We allow negative 'scale', implying rounding before the decimal
1211  *	point --- Oracle interprets rounding that way.
1212  */
1213 Datum
numeric_round(PG_FUNCTION_ARGS)1214 numeric_round(PG_FUNCTION_ARGS)
1215 {
1216 	Numeric		num = PG_GETARG_NUMERIC(0);
1217 	int32		scale = PG_GETARG_INT32(1);
1218 	Numeric		res;
1219 	NumericVar	arg;
1220 
1221 	/*
1222 	 * Handle NaN
1223 	 */
1224 	if (NUMERIC_IS_NAN(num))
1225 		PG_RETURN_NUMERIC(make_result(&const_nan));
1226 
1227 	/*
1228 	 * Limit the scale value to avoid possible overflow in calculations
1229 	 */
1230 	scale = Max(scale, -NUMERIC_MAX_RESULT_SCALE);
1231 	scale = Min(scale, NUMERIC_MAX_RESULT_SCALE);
1232 
1233 	/*
1234 	 * Unpack the argument and round it at the proper digit position
1235 	 */
1236 	init_var(&arg);
1237 	set_var_from_num(num, &arg);
1238 
1239 	round_var(&arg, scale);
1240 
1241 	/* We don't allow negative output dscale */
1242 	if (scale < 0)
1243 		arg.dscale = 0;
1244 
1245 	/*
1246 	 * Return the rounded result
1247 	 */
1248 	res = make_result(&arg);
1249 
1250 	free_var(&arg);
1251 	PG_RETURN_NUMERIC(res);
1252 }
1253 
1254 
1255 /*
1256  * numeric_trunc() -
1257  *
1258  *	Truncate a value to have 'scale' digits after the decimal point.
1259  *	We allow negative 'scale', implying a truncation before the decimal
1260  *	point --- Oracle interprets truncation that way.
1261  */
1262 Datum
numeric_trunc(PG_FUNCTION_ARGS)1263 numeric_trunc(PG_FUNCTION_ARGS)
1264 {
1265 	Numeric		num = PG_GETARG_NUMERIC(0);
1266 	int32		scale = PG_GETARG_INT32(1);
1267 	Numeric		res;
1268 	NumericVar	arg;
1269 
1270 	/*
1271 	 * Handle NaN
1272 	 */
1273 	if (NUMERIC_IS_NAN(num))
1274 		PG_RETURN_NUMERIC(make_result(&const_nan));
1275 
1276 	/*
1277 	 * Limit the scale value to avoid possible overflow in calculations
1278 	 */
1279 	scale = Max(scale, -NUMERIC_MAX_RESULT_SCALE);
1280 	scale = Min(scale, NUMERIC_MAX_RESULT_SCALE);
1281 
1282 	/*
1283 	 * Unpack the argument and truncate it at the proper digit position
1284 	 */
1285 	init_var(&arg);
1286 	set_var_from_num(num, &arg);
1287 
1288 	trunc_var(&arg, scale);
1289 
1290 	/* We don't allow negative output dscale */
1291 	if (scale < 0)
1292 		arg.dscale = 0;
1293 
1294 	/*
1295 	 * Return the truncated result
1296 	 */
1297 	res = make_result(&arg);
1298 
1299 	free_var(&arg);
1300 	PG_RETURN_NUMERIC(res);
1301 }
1302 
1303 
1304 /*
1305  * numeric_ceil() -
1306  *
1307  *	Return the smallest integer greater than or equal to the argument
1308  */
1309 Datum
numeric_ceil(PG_FUNCTION_ARGS)1310 numeric_ceil(PG_FUNCTION_ARGS)
1311 {
1312 	Numeric		num = PG_GETARG_NUMERIC(0);
1313 	Numeric		res;
1314 	NumericVar	result;
1315 
1316 	if (NUMERIC_IS_NAN(num))
1317 		PG_RETURN_NUMERIC(make_result(&const_nan));
1318 
1319 	init_var_from_num(num, &result);
1320 	ceil_var(&result, &result);
1321 
1322 	res = make_result(&result);
1323 	free_var(&result);
1324 
1325 	PG_RETURN_NUMERIC(res);
1326 }
1327 
1328 
1329 /*
1330  * numeric_floor() -
1331  *
1332  *	Return the largest integer equal to or less than the argument
1333  */
1334 Datum
numeric_floor(PG_FUNCTION_ARGS)1335 numeric_floor(PG_FUNCTION_ARGS)
1336 {
1337 	Numeric		num = PG_GETARG_NUMERIC(0);
1338 	Numeric		res;
1339 	NumericVar	result;
1340 
1341 	if (NUMERIC_IS_NAN(num))
1342 		PG_RETURN_NUMERIC(make_result(&const_nan));
1343 
1344 	init_var_from_num(num, &result);
1345 	floor_var(&result, &result);
1346 
1347 	res = make_result(&result);
1348 	free_var(&result);
1349 
1350 	PG_RETURN_NUMERIC(res);
1351 }
1352 
1353 
1354 /*
1355  * generate_series_numeric() -
1356  *
1357  *	Generate series of numeric.
1358  */
1359 Datum
generate_series_numeric(PG_FUNCTION_ARGS)1360 generate_series_numeric(PG_FUNCTION_ARGS)
1361 {
1362 	return generate_series_step_numeric(fcinfo);
1363 }
1364 
1365 Datum
generate_series_step_numeric(PG_FUNCTION_ARGS)1366 generate_series_step_numeric(PG_FUNCTION_ARGS)
1367 {
1368 	generate_series_numeric_fctx *fctx;
1369 	FuncCallContext *funcctx;
1370 	MemoryContext oldcontext;
1371 
1372 	if (SRF_IS_FIRSTCALL())
1373 	{
1374 		Numeric		start_num = PG_GETARG_NUMERIC(0);
1375 		Numeric		stop_num = PG_GETARG_NUMERIC(1);
1376 		NumericVar	steploc = const_one;
1377 
1378 		/* handle NaN in start and stop values */
1379 		if (NUMERIC_IS_NAN(start_num))
1380 			ereport(ERROR,
1381 					(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
1382 					 errmsg("start value cannot be NaN")));
1383 
1384 		if (NUMERIC_IS_NAN(stop_num))
1385 			ereport(ERROR,
1386 					(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
1387 					 errmsg("stop value cannot be NaN")));
1388 
1389 		/* see if we were given an explicit step size */
1390 		if (PG_NARGS() == 3)
1391 		{
1392 			Numeric		step_num = PG_GETARG_NUMERIC(2);
1393 
1394 			if (NUMERIC_IS_NAN(step_num))
1395 				ereport(ERROR,
1396 						(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
1397 						 errmsg("step size cannot be NaN")));
1398 
1399 			init_var_from_num(step_num, &steploc);
1400 
1401 			if (cmp_var(&steploc, &const_zero) == 0)
1402 				ereport(ERROR,
1403 						(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
1404 						 errmsg("step size cannot equal zero")));
1405 		}
1406 
1407 		/* create a function context for cross-call persistence */
1408 		funcctx = SRF_FIRSTCALL_INIT();
1409 
1410 		/*
1411 		 * Switch to memory context appropriate for multiple function calls.
1412 		 */
1413 		oldcontext = MemoryContextSwitchTo(funcctx->multi_call_memory_ctx);
1414 
1415 		/* allocate memory for user context */
1416 		fctx = (generate_series_numeric_fctx *)
1417 			palloc(sizeof(generate_series_numeric_fctx));
1418 
1419 		/*
1420 		 * Use fctx to keep state from call to call. Seed current with the
1421 		 * original start value. We must copy the start_num and stop_num
1422 		 * values rather than pointing to them, since we may have detoasted
1423 		 * them in the per-call context.
1424 		 */
1425 		init_var(&fctx->current);
1426 		init_var(&fctx->stop);
1427 		init_var(&fctx->step);
1428 
1429 		set_var_from_num(start_num, &fctx->current);
1430 		set_var_from_num(stop_num, &fctx->stop);
1431 		set_var_from_var(&steploc, &fctx->step);
1432 
1433 		funcctx->user_fctx = fctx;
1434 		MemoryContextSwitchTo(oldcontext);
1435 	}
1436 
1437 	/* stuff done on every call of the function */
1438 	funcctx = SRF_PERCALL_SETUP();
1439 
1440 	/*
1441 	 * Get the saved state and use current state as the result of this
1442 	 * iteration.
1443 	 */
1444 	fctx = funcctx->user_fctx;
1445 
1446 	if ((fctx->step.sign == NUMERIC_POS &&
1447 		 cmp_var(&fctx->current, &fctx->stop) <= 0) ||
1448 		(fctx->step.sign == NUMERIC_NEG &&
1449 		 cmp_var(&fctx->current, &fctx->stop) >= 0))
1450 	{
1451 		Numeric		result = make_result(&fctx->current);
1452 
1453 		/* switch to memory context appropriate for iteration calculation */
1454 		oldcontext = MemoryContextSwitchTo(funcctx->multi_call_memory_ctx);
1455 
1456 		/* increment current in preparation for next iteration */
1457 		add_var(&fctx->current, &fctx->step, &fctx->current);
1458 		MemoryContextSwitchTo(oldcontext);
1459 
1460 		/* do when there is more left to send */
1461 		SRF_RETURN_NEXT(funcctx, NumericGetDatum(result));
1462 	}
1463 	else
1464 		/* do when there is no more left */
1465 		SRF_RETURN_DONE(funcctx);
1466 }
1467 
1468 
1469 /*
1470  * Implements the numeric version of the width_bucket() function
1471  * defined by SQL2003. See also width_bucket_float8().
1472  *
1473  * 'bound1' and 'bound2' are the lower and upper bounds of the
1474  * histogram's range, respectively. 'count' is the number of buckets
1475  * in the histogram. width_bucket() returns an integer indicating the
1476  * bucket number that 'operand' belongs to in an equiwidth histogram
1477  * with the specified characteristics. An operand smaller than the
1478  * lower bound is assigned to bucket 0. An operand greater than the
1479  * upper bound is assigned to an additional bucket (with number
1480  * count+1). We don't allow "NaN" for any of the numeric arguments.
1481  */
1482 Datum
width_bucket_numeric(PG_FUNCTION_ARGS)1483 width_bucket_numeric(PG_FUNCTION_ARGS)
1484 {
1485 	Numeric		operand = PG_GETARG_NUMERIC(0);
1486 	Numeric		bound1 = PG_GETARG_NUMERIC(1);
1487 	Numeric		bound2 = PG_GETARG_NUMERIC(2);
1488 	int32		count = PG_GETARG_INT32(3);
1489 	NumericVar	count_var;
1490 	NumericVar	result_var;
1491 	int32		result;
1492 
1493 	if (count <= 0)
1494 		ereport(ERROR,
1495 				(errcode(ERRCODE_INVALID_ARGUMENT_FOR_WIDTH_BUCKET_FUNCTION),
1496 				 errmsg("count must be greater than zero")));
1497 
1498 	if (NUMERIC_IS_NAN(operand) ||
1499 		NUMERIC_IS_NAN(bound1) ||
1500 		NUMERIC_IS_NAN(bound2))
1501 		ereport(ERROR,
1502 				(errcode(ERRCODE_INVALID_ARGUMENT_FOR_WIDTH_BUCKET_FUNCTION),
1503 				 errmsg("operand, lower bound, and upper bound cannot be NaN")));
1504 
1505 	init_var(&result_var);
1506 	init_var(&count_var);
1507 
1508 	/* Convert 'count' to a numeric, for ease of use later */
1509 	int64_to_numericvar((int64) count, &count_var);
1510 
1511 	switch (cmp_numerics(bound1, bound2))
1512 	{
1513 		case 0:
1514 			ereport(ERROR,
1515 					(errcode(ERRCODE_INVALID_ARGUMENT_FOR_WIDTH_BUCKET_FUNCTION),
1516 					 errmsg("lower bound cannot equal upper bound")));
1517 			break;
1518 
1519 			/* bound1 < bound2 */
1520 		case -1:
1521 			if (cmp_numerics(operand, bound1) < 0)
1522 				set_var_from_var(&const_zero, &result_var);
1523 			else if (cmp_numerics(operand, bound2) >= 0)
1524 				add_var(&count_var, &const_one, &result_var);
1525 			else
1526 				compute_bucket(operand, bound1, bound2,
1527 							   &count_var, &result_var);
1528 			break;
1529 
1530 			/* bound1 > bound2 */
1531 		case 1:
1532 			if (cmp_numerics(operand, bound1) > 0)
1533 				set_var_from_var(&const_zero, &result_var);
1534 			else if (cmp_numerics(operand, bound2) <= 0)
1535 				add_var(&count_var, &const_one, &result_var);
1536 			else
1537 				compute_bucket(operand, bound1, bound2,
1538 							   &count_var, &result_var);
1539 			break;
1540 	}
1541 
1542 	/* if result exceeds the range of a legal int4, we ereport here */
1543 	result = numericvar_to_int32(&result_var);
1544 
1545 	free_var(&count_var);
1546 	free_var(&result_var);
1547 
1548 	PG_RETURN_INT32(result);
1549 }
1550 
1551 /*
1552  * If 'operand' is not outside the bucket range, determine the correct
1553  * bucket for it to go. The calculations performed by this function
1554  * are derived directly from the SQL2003 spec.
1555  */
1556 static void
compute_bucket(Numeric operand,Numeric bound1,Numeric bound2,const NumericVar * count_var,NumericVar * result_var)1557 compute_bucket(Numeric operand, Numeric bound1, Numeric bound2,
1558 			   const NumericVar *count_var, NumericVar *result_var)
1559 {
1560 	NumericVar	bound1_var;
1561 	NumericVar	bound2_var;
1562 	NumericVar	operand_var;
1563 
1564 	init_var_from_num(bound1, &bound1_var);
1565 	init_var_from_num(bound2, &bound2_var);
1566 	init_var_from_num(operand, &operand_var);
1567 
1568 	if (cmp_var(&bound1_var, &bound2_var) < 0)
1569 	{
1570 		sub_var(&operand_var, &bound1_var, &operand_var);
1571 		sub_var(&bound2_var, &bound1_var, &bound2_var);
1572 		div_var(&operand_var, &bound2_var, result_var,
1573 				select_div_scale(&operand_var, &bound2_var), true);
1574 	}
1575 	else
1576 	{
1577 		sub_var(&bound1_var, &operand_var, &operand_var);
1578 		sub_var(&bound1_var, &bound2_var, &bound1_var);
1579 		div_var(&operand_var, &bound1_var, result_var,
1580 				select_div_scale(&operand_var, &bound1_var), true);
1581 	}
1582 
1583 	mul_var(result_var, count_var, result_var,
1584 			result_var->dscale + count_var->dscale);
1585 	add_var(result_var, &const_one, result_var);
1586 	floor_var(result_var, result_var);
1587 
1588 	free_var(&bound1_var);
1589 	free_var(&bound2_var);
1590 	free_var(&operand_var);
1591 }
1592 
1593 /* ----------------------------------------------------------------------
1594  *
1595  * Comparison functions
1596  *
1597  * Note: btree indexes need these routines not to leak memory; therefore,
1598  * be careful to free working copies of toasted datums.  Most places don't
1599  * need to be so careful.
1600  *
1601  * Sort support:
1602  *
1603  * We implement the sortsupport strategy routine in order to get the benefit of
1604  * abbreviation. The ordinary numeric comparison can be quite slow as a result
1605  * of palloc/pfree cycles (due to detoasting packed values for alignment);
1606  * while this could be worked on itself, the abbreviation strategy gives more
1607  * speedup in many common cases.
1608  *
1609  * Two different representations are used for the abbreviated form, one in
1610  * int32 and one in int64, whichever fits into a by-value Datum.  In both cases
1611  * the representation is negated relative to the original value, because we use
1612  * the largest negative value for NaN, which sorts higher than other values. We
1613  * convert the absolute value of the numeric to a 31-bit or 63-bit positive
1614  * value, and then negate it if the original number was positive.
1615  *
1616  * We abort the abbreviation process if the abbreviation cardinality is below
1617  * 0.01% of the row count (1 per 10k non-null rows).  The actual break-even
1618  * point is somewhat below that, perhaps 1 per 30k (at 1 per 100k there's a
1619  * very small penalty), but we don't want to build up too many abbreviated
1620  * values before first testing for abort, so we take the slightly pessimistic
1621  * number.  We make no attempt to estimate the cardinality of the real values,
1622  * since it plays no part in the cost model here (if the abbreviation is equal,
1623  * the cost of comparing equal and unequal underlying values is comparable).
1624  * We discontinue even checking for abort (saving us the hashing overhead) if
1625  * the estimated cardinality gets to 100k; that would be enough to support many
1626  * billions of rows while doing no worse than breaking even.
1627  *
1628  * ----------------------------------------------------------------------
1629  */
1630 
1631 /*
1632  * Sort support strategy routine.
1633  */
1634 Datum
numeric_sortsupport(PG_FUNCTION_ARGS)1635 numeric_sortsupport(PG_FUNCTION_ARGS)
1636 {
1637 	SortSupport ssup = (SortSupport) PG_GETARG_POINTER(0);
1638 
1639 	ssup->comparator = numeric_fast_cmp;
1640 
1641 	if (ssup->abbreviate)
1642 	{
1643 		NumericSortSupport *nss;
1644 		MemoryContext oldcontext = MemoryContextSwitchTo(ssup->ssup_cxt);
1645 
1646 		nss = palloc(sizeof(NumericSortSupport));
1647 
1648 		/*
1649 		 * palloc a buffer for handling unaligned packed values in addition to
1650 		 * the support struct
1651 		 */
1652 		nss->buf = palloc(VARATT_SHORT_MAX + VARHDRSZ + 1);
1653 
1654 		nss->input_count = 0;
1655 		nss->estimating = true;
1656 		initHyperLogLog(&nss->abbr_card, 10);
1657 
1658 		ssup->ssup_extra = nss;
1659 
1660 		ssup->abbrev_full_comparator = ssup->comparator;
1661 		ssup->comparator = numeric_cmp_abbrev;
1662 		ssup->abbrev_converter = numeric_abbrev_convert;
1663 		ssup->abbrev_abort = numeric_abbrev_abort;
1664 
1665 		MemoryContextSwitchTo(oldcontext);
1666 	}
1667 
1668 	PG_RETURN_VOID();
1669 }
1670 
1671 /*
1672  * Abbreviate a numeric datum, handling NaNs and detoasting
1673  * (must not leak memory!)
1674  */
1675 static Datum
numeric_abbrev_convert(Datum original_datum,SortSupport ssup)1676 numeric_abbrev_convert(Datum original_datum, SortSupport ssup)
1677 {
1678 	NumericSortSupport *nss = ssup->ssup_extra;
1679 	void	   *original_varatt = PG_DETOAST_DATUM_PACKED(original_datum);
1680 	Numeric		value;
1681 	Datum		result;
1682 
1683 	nss->input_count += 1;
1684 
1685 	/*
1686 	 * This is to handle packed datums without needing a palloc/pfree cycle;
1687 	 * we keep and reuse a buffer large enough to handle any short datum.
1688 	 */
1689 	if (VARATT_IS_SHORT(original_varatt))
1690 	{
1691 		void	   *buf = nss->buf;
1692 		Size		sz = VARSIZE_SHORT(original_varatt) - VARHDRSZ_SHORT;
1693 
1694 		Assert(sz <= VARATT_SHORT_MAX - VARHDRSZ_SHORT);
1695 
1696 		SET_VARSIZE(buf, VARHDRSZ + sz);
1697 		memcpy(VARDATA(buf), VARDATA_SHORT(original_varatt), sz);
1698 
1699 		value = (Numeric) buf;
1700 	}
1701 	else
1702 		value = (Numeric) original_varatt;
1703 
1704 	if (NUMERIC_IS_NAN(value))
1705 	{
1706 		result = NUMERIC_ABBREV_NAN;
1707 	}
1708 	else
1709 	{
1710 		NumericVar	var;
1711 
1712 		init_var_from_num(value, &var);
1713 
1714 		result = numeric_abbrev_convert_var(&var, nss);
1715 	}
1716 
1717 	/* should happen only for external/compressed toasts */
1718 	if ((Pointer) original_varatt != DatumGetPointer(original_datum))
1719 		pfree(original_varatt);
1720 
1721 	return result;
1722 }
1723 
1724 /*
1725  * Consider whether to abort abbreviation.
1726  *
1727  * We pay no attention to the cardinality of the non-abbreviated data. There is
1728  * no reason to do so: unlike text, we have no fast check for equal values, so
1729  * we pay the full overhead whenever the abbreviations are equal regardless of
1730  * whether the underlying values are also equal.
1731  */
1732 static bool
numeric_abbrev_abort(int memtupcount,SortSupport ssup)1733 numeric_abbrev_abort(int memtupcount, SortSupport ssup)
1734 {
1735 	NumericSortSupport *nss = ssup->ssup_extra;
1736 	double		abbr_card;
1737 
1738 	if (memtupcount < 10000 || nss->input_count < 10000 || !nss->estimating)
1739 		return false;
1740 
1741 	abbr_card = estimateHyperLogLog(&nss->abbr_card);
1742 
1743 	/*
1744 	 * If we have >100k distinct values, then even if we were sorting many
1745 	 * billion rows we'd likely still break even, and the penalty of undoing
1746 	 * that many rows of abbrevs would probably not be worth it. Stop even
1747 	 * counting at that point.
1748 	 */
1749 	if (abbr_card > 100000.0)
1750 	{
1751 #ifdef TRACE_SORT
1752 		if (trace_sort)
1753 			elog(LOG,
1754 				 "numeric_abbrev: estimation ends at cardinality %f"
1755 				 " after " INT64_FORMAT " values (%d rows)",
1756 				 abbr_card, nss->input_count, memtupcount);
1757 #endif
1758 		nss->estimating = false;
1759 		return false;
1760 	}
1761 
1762 	/*
1763 	 * Target minimum cardinality is 1 per ~10k of non-null inputs.  (The
1764 	 * break even point is somewhere between one per 100k rows, where
1765 	 * abbreviation has a very slight penalty, and 1 per 10k where it wins by
1766 	 * a measurable percentage.)  We use the relatively pessimistic 10k
1767 	 * threshold, and add a 0.5 row fudge factor, because it allows us to
1768 	 * abort earlier on genuinely pathological data where we've had exactly
1769 	 * one abbreviated value in the first 10k (non-null) rows.
1770 	 */
1771 	if (abbr_card < nss->input_count / 10000.0 + 0.5)
1772 	{
1773 #ifdef TRACE_SORT
1774 		if (trace_sort)
1775 			elog(LOG,
1776 				 "numeric_abbrev: aborting abbreviation at cardinality %f"
1777 				 " below threshold %f after " INT64_FORMAT " values (%d rows)",
1778 				 abbr_card, nss->input_count / 10000.0 + 0.5,
1779 				 nss->input_count, memtupcount);
1780 #endif
1781 		return true;
1782 	}
1783 
1784 #ifdef TRACE_SORT
1785 	if (trace_sort)
1786 		elog(LOG,
1787 			 "numeric_abbrev: cardinality %f"
1788 			 " after " INT64_FORMAT " values (%d rows)",
1789 			 abbr_card, nss->input_count, memtupcount);
1790 #endif
1791 
1792 	return false;
1793 }
1794 
1795 /*
1796  * Non-fmgr interface to the comparison routine to allow sortsupport to elide
1797  * the fmgr call.  The saving here is small given how slow numeric comparisons
1798  * are, but it is a required part of the sort support API when abbreviations
1799  * are performed.
1800  *
1801  * Two palloc/pfree cycles could be saved here by using persistent buffers for
1802  * aligning short-varlena inputs, but this has not so far been considered to
1803  * be worth the effort.
1804  */
1805 static int
numeric_fast_cmp(Datum x,Datum y,SortSupport ssup)1806 numeric_fast_cmp(Datum x, Datum y, SortSupport ssup)
1807 {
1808 	Numeric		nx = DatumGetNumeric(x);
1809 	Numeric		ny = DatumGetNumeric(y);
1810 	int			result;
1811 
1812 	result = cmp_numerics(nx, ny);
1813 
1814 	if ((Pointer) nx != DatumGetPointer(x))
1815 		pfree(nx);
1816 	if ((Pointer) ny != DatumGetPointer(y))
1817 		pfree(ny);
1818 
1819 	return result;
1820 }
1821 
1822 /*
1823  * Compare abbreviations of values. (Abbreviations may be equal where the true
1824  * values differ, but if the abbreviations differ, they must reflect the
1825  * ordering of the true values.)
1826  */
1827 static int
numeric_cmp_abbrev(Datum x,Datum y,SortSupport ssup)1828 numeric_cmp_abbrev(Datum x, Datum y, SortSupport ssup)
1829 {
1830 	/*
1831 	 * NOTE WELL: this is intentionally backwards, because the abbreviation is
1832 	 * negated relative to the original value, to handle NaN.
1833 	 */
1834 	if (DatumGetNumericAbbrev(x) < DatumGetNumericAbbrev(y))
1835 		return 1;
1836 	if (DatumGetNumericAbbrev(x) > DatumGetNumericAbbrev(y))
1837 		return -1;
1838 	return 0;
1839 }
1840 
1841 /*
1842  * Abbreviate a NumericVar according to the available bit size.
1843  *
1844  * The 31-bit value is constructed as:
1845  *
1846  *	0 + 7bits digit weight + 24 bits digit value
1847  *
1848  * where the digit weight is in single decimal digits, not digit words, and
1849  * stored in excess-44 representation[1]. The 24-bit digit value is the 7 most
1850  * significant decimal digits of the value converted to binary. Values whose
1851  * weights would fall outside the representable range are rounded off to zero
1852  * (which is also used to represent actual zeros) or to 0x7FFFFFFF (which
1853  * otherwise cannot occur). Abbreviation therefore fails to gain any advantage
1854  * where values are outside the range 10^-44 to 10^83, which is not considered
1855  * to be a serious limitation, or when values are of the same magnitude and
1856  * equal in the first 7 decimal digits, which is considered to be an
1857  * unavoidable limitation given the available bits. (Stealing three more bits
1858  * to compare another digit would narrow the range of representable weights by
1859  * a factor of 8, which starts to look like a real limiting factor.)
1860  *
1861  * (The value 44 for the excess is essentially arbitrary)
1862  *
1863  * The 63-bit value is constructed as:
1864  *
1865  *	0 + 7bits weight + 4 x 14-bit packed digit words
1866  *
1867  * The weight in this case is again stored in excess-44, but this time it is
1868  * the original weight in digit words (i.e. powers of 10000). The first four
1869  * digit words of the value (if present; trailing zeros are assumed as needed)
1870  * are packed into 14 bits each to form the rest of the value. Again,
1871  * out-of-range values are rounded off to 0 or 0x7FFFFFFFFFFFFFFF. The
1872  * representable range in this case is 10^-176 to 10^332, which is considered
1873  * to be good enough for all practical purposes, and comparison of 4 words
1874  * means that at least 13 decimal digits are compared, which is considered to
1875  * be a reasonable compromise between effectiveness and efficiency in computing
1876  * the abbreviation.
1877  *
1878  * (The value 44 for the excess is even more arbitrary here, it was chosen just
1879  * to match the value used in the 31-bit case)
1880  *
1881  * [1] - Excess-k representation means that the value is offset by adding 'k'
1882  * and then treated as unsigned, so the smallest representable value is stored
1883  * with all bits zero. This allows simple comparisons to work on the composite
1884  * value.
1885  */
1886 
1887 #if NUMERIC_ABBREV_BITS == 64
1888 
1889 static Datum
numeric_abbrev_convert_var(const NumericVar * var,NumericSortSupport * nss)1890 numeric_abbrev_convert_var(const NumericVar *var, NumericSortSupport *nss)
1891 {
1892 	int			ndigits = var->ndigits;
1893 	int			weight = var->weight;
1894 	int64		result;
1895 
1896 	if (ndigits == 0 || weight < -44)
1897 	{
1898 		result = 0;
1899 	}
1900 	else if (weight > 83)
1901 	{
1902 		result = PG_INT64_MAX;
1903 	}
1904 	else
1905 	{
1906 		result = ((int64) (weight + 44) << 56);
1907 
1908 		switch (ndigits)
1909 		{
1910 			default:
1911 				result |= ((int64) var->digits[3]);
1912 				/* FALLTHROUGH */
1913 			case 3:
1914 				result |= ((int64) var->digits[2]) << 14;
1915 				/* FALLTHROUGH */
1916 			case 2:
1917 				result |= ((int64) var->digits[1]) << 28;
1918 				/* FALLTHROUGH */
1919 			case 1:
1920 				result |= ((int64) var->digits[0]) << 42;
1921 				break;
1922 		}
1923 	}
1924 
1925 	/* the abbrev is negated relative to the original */
1926 	if (var->sign == NUMERIC_POS)
1927 		result = -result;
1928 
1929 	if (nss->estimating)
1930 	{
1931 		uint32		tmp = ((uint32) result
1932 						   ^ (uint32) ((uint64) result >> 32));
1933 
1934 		addHyperLogLog(&nss->abbr_card, DatumGetUInt32(hash_uint32(tmp)));
1935 	}
1936 
1937 	return NumericAbbrevGetDatum(result);
1938 }
1939 
1940 #endif							/* NUMERIC_ABBREV_BITS == 64 */
1941 
1942 #if NUMERIC_ABBREV_BITS == 32
1943 
1944 static Datum
numeric_abbrev_convert_var(const NumericVar * var,NumericSortSupport * nss)1945 numeric_abbrev_convert_var(const NumericVar *var, NumericSortSupport *nss)
1946 {
1947 	int			ndigits = var->ndigits;
1948 	int			weight = var->weight;
1949 	int32		result;
1950 
1951 	if (ndigits == 0 || weight < -11)
1952 	{
1953 		result = 0;
1954 	}
1955 	else if (weight > 20)
1956 	{
1957 		result = PG_INT32_MAX;
1958 	}
1959 	else
1960 	{
1961 		NumericDigit nxt1 = (ndigits > 1) ? var->digits[1] : 0;
1962 
1963 		weight = (weight + 11) * 4;
1964 
1965 		result = var->digits[0];
1966 
1967 		/*
1968 		 * "result" now has 1 to 4 nonzero decimal digits. We pack in more
1969 		 * digits to make 7 in total (largest we can fit in 24 bits)
1970 		 */
1971 
1972 		if (result > 999)
1973 		{
1974 			/* already have 4 digits, add 3 more */
1975 			result = (result * 1000) + (nxt1 / 10);
1976 			weight += 3;
1977 		}
1978 		else if (result > 99)
1979 		{
1980 			/* already have 3 digits, add 4 more */
1981 			result = (result * 10000) + nxt1;
1982 			weight += 2;
1983 		}
1984 		else if (result > 9)
1985 		{
1986 			NumericDigit nxt2 = (ndigits > 2) ? var->digits[2] : 0;
1987 
1988 			/* already have 2 digits, add 5 more */
1989 			result = (result * 100000) + (nxt1 * 10) + (nxt2 / 1000);
1990 			weight += 1;
1991 		}
1992 		else
1993 		{
1994 			NumericDigit nxt2 = (ndigits > 2) ? var->digits[2] : 0;
1995 
1996 			/* already have 1 digit, add 6 more */
1997 			result = (result * 1000000) + (nxt1 * 100) + (nxt2 / 100);
1998 		}
1999 
2000 		result = result | (weight << 24);
2001 	}
2002 
2003 	/* the abbrev is negated relative to the original */
2004 	if (var->sign == NUMERIC_POS)
2005 		result = -result;
2006 
2007 	if (nss->estimating)
2008 	{
2009 		uint32		tmp = (uint32) result;
2010 
2011 		addHyperLogLog(&nss->abbr_card, DatumGetUInt32(hash_uint32(tmp)));
2012 	}
2013 
2014 	return NumericAbbrevGetDatum(result);
2015 }
2016 
2017 #endif							/* NUMERIC_ABBREV_BITS == 32 */
2018 
2019 /*
2020  * Ordinary (non-sortsupport) comparisons follow.
2021  */
2022 
2023 Datum
numeric_cmp(PG_FUNCTION_ARGS)2024 numeric_cmp(PG_FUNCTION_ARGS)
2025 {
2026 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2027 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2028 	int			result;
2029 
2030 	result = cmp_numerics(num1, num2);
2031 
2032 	PG_FREE_IF_COPY(num1, 0);
2033 	PG_FREE_IF_COPY(num2, 1);
2034 
2035 	PG_RETURN_INT32(result);
2036 }
2037 
2038 
2039 Datum
numeric_eq(PG_FUNCTION_ARGS)2040 numeric_eq(PG_FUNCTION_ARGS)
2041 {
2042 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2043 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2044 	bool		result;
2045 
2046 	result = cmp_numerics(num1, num2) == 0;
2047 
2048 	PG_FREE_IF_COPY(num1, 0);
2049 	PG_FREE_IF_COPY(num2, 1);
2050 
2051 	PG_RETURN_BOOL(result);
2052 }
2053 
2054 Datum
numeric_ne(PG_FUNCTION_ARGS)2055 numeric_ne(PG_FUNCTION_ARGS)
2056 {
2057 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2058 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2059 	bool		result;
2060 
2061 	result = cmp_numerics(num1, num2) != 0;
2062 
2063 	PG_FREE_IF_COPY(num1, 0);
2064 	PG_FREE_IF_COPY(num2, 1);
2065 
2066 	PG_RETURN_BOOL(result);
2067 }
2068 
2069 Datum
numeric_gt(PG_FUNCTION_ARGS)2070 numeric_gt(PG_FUNCTION_ARGS)
2071 {
2072 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2073 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2074 	bool		result;
2075 
2076 	result = cmp_numerics(num1, num2) > 0;
2077 
2078 	PG_FREE_IF_COPY(num1, 0);
2079 	PG_FREE_IF_COPY(num2, 1);
2080 
2081 	PG_RETURN_BOOL(result);
2082 }
2083 
2084 Datum
numeric_ge(PG_FUNCTION_ARGS)2085 numeric_ge(PG_FUNCTION_ARGS)
2086 {
2087 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2088 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2089 	bool		result;
2090 
2091 	result = cmp_numerics(num1, num2) >= 0;
2092 
2093 	PG_FREE_IF_COPY(num1, 0);
2094 	PG_FREE_IF_COPY(num2, 1);
2095 
2096 	PG_RETURN_BOOL(result);
2097 }
2098 
2099 Datum
numeric_lt(PG_FUNCTION_ARGS)2100 numeric_lt(PG_FUNCTION_ARGS)
2101 {
2102 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2103 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2104 	bool		result;
2105 
2106 	result = cmp_numerics(num1, num2) < 0;
2107 
2108 	PG_FREE_IF_COPY(num1, 0);
2109 	PG_FREE_IF_COPY(num2, 1);
2110 
2111 	PG_RETURN_BOOL(result);
2112 }
2113 
2114 Datum
numeric_le(PG_FUNCTION_ARGS)2115 numeric_le(PG_FUNCTION_ARGS)
2116 {
2117 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2118 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2119 	bool		result;
2120 
2121 	result = cmp_numerics(num1, num2) <= 0;
2122 
2123 	PG_FREE_IF_COPY(num1, 0);
2124 	PG_FREE_IF_COPY(num2, 1);
2125 
2126 	PG_RETURN_BOOL(result);
2127 }
2128 
2129 static int
cmp_numerics(Numeric num1,Numeric num2)2130 cmp_numerics(Numeric num1, Numeric num2)
2131 {
2132 	int			result;
2133 
2134 	/*
2135 	 * We consider all NANs to be equal and larger than any non-NAN. This is
2136 	 * somewhat arbitrary; the important thing is to have a consistent sort
2137 	 * order.
2138 	 */
2139 	if (NUMERIC_IS_NAN(num1))
2140 	{
2141 		if (NUMERIC_IS_NAN(num2))
2142 			result = 0;			/* NAN = NAN */
2143 		else
2144 			result = 1;			/* NAN > non-NAN */
2145 	}
2146 	else if (NUMERIC_IS_NAN(num2))
2147 	{
2148 		result = -1;			/* non-NAN < NAN */
2149 	}
2150 	else
2151 	{
2152 		result = cmp_var_common(NUMERIC_DIGITS(num1), NUMERIC_NDIGITS(num1),
2153 								NUMERIC_WEIGHT(num1), NUMERIC_SIGN(num1),
2154 								NUMERIC_DIGITS(num2), NUMERIC_NDIGITS(num2),
2155 								NUMERIC_WEIGHT(num2), NUMERIC_SIGN(num2));
2156 	}
2157 
2158 	return result;
2159 }
2160 
2161 /*
2162  * in_range support function for numeric.
2163  */
2164 Datum
in_range_numeric_numeric(PG_FUNCTION_ARGS)2165 in_range_numeric_numeric(PG_FUNCTION_ARGS)
2166 {
2167 	Numeric		val = PG_GETARG_NUMERIC(0);
2168 	Numeric		base = PG_GETARG_NUMERIC(1);
2169 	Numeric		offset = PG_GETARG_NUMERIC(2);
2170 	bool		sub = PG_GETARG_BOOL(3);
2171 	bool		less = PG_GETARG_BOOL(4);
2172 	bool		result;
2173 
2174 	/*
2175 	 * Reject negative or NaN offset.  Negative is per spec, and NaN is
2176 	 * because appropriate semantics for that seem non-obvious.
2177 	 */
2178 	if (NUMERIC_IS_NAN(offset) || NUMERIC_SIGN(offset) == NUMERIC_NEG)
2179 		ereport(ERROR,
2180 				(errcode(ERRCODE_INVALID_PRECEDING_OR_FOLLOWING_SIZE),
2181 				 errmsg("invalid preceding or following size in window function")));
2182 
2183 	/*
2184 	 * Deal with cases where val and/or base is NaN, following the rule that
2185 	 * NaN sorts after non-NaN (cf cmp_numerics).  The offset cannot affect
2186 	 * the conclusion.
2187 	 */
2188 	if (NUMERIC_IS_NAN(val))
2189 	{
2190 		if (NUMERIC_IS_NAN(base))
2191 			result = true;		/* NAN = NAN */
2192 		else
2193 			result = !less;		/* NAN > non-NAN */
2194 	}
2195 	else if (NUMERIC_IS_NAN(base))
2196 	{
2197 		result = less;			/* non-NAN < NAN */
2198 	}
2199 	else
2200 	{
2201 		/*
2202 		 * Otherwise go ahead and compute base +/- offset.  While it's
2203 		 * possible for this to overflow the numeric format, it's unlikely
2204 		 * enough that we don't take measures to prevent it.
2205 		 */
2206 		NumericVar	valv;
2207 		NumericVar	basev;
2208 		NumericVar	offsetv;
2209 		NumericVar	sum;
2210 
2211 		init_var_from_num(val, &valv);
2212 		init_var_from_num(base, &basev);
2213 		init_var_from_num(offset, &offsetv);
2214 		init_var(&sum);
2215 
2216 		if (sub)
2217 			sub_var(&basev, &offsetv, &sum);
2218 		else
2219 			add_var(&basev, &offsetv, &sum);
2220 
2221 		if (less)
2222 			result = (cmp_var(&valv, &sum) <= 0);
2223 		else
2224 			result = (cmp_var(&valv, &sum) >= 0);
2225 
2226 		free_var(&sum);
2227 	}
2228 
2229 	PG_FREE_IF_COPY(val, 0);
2230 	PG_FREE_IF_COPY(base, 1);
2231 	PG_FREE_IF_COPY(offset, 2);
2232 
2233 	PG_RETURN_BOOL(result);
2234 }
2235 
2236 Datum
hash_numeric(PG_FUNCTION_ARGS)2237 hash_numeric(PG_FUNCTION_ARGS)
2238 {
2239 	Numeric		key = PG_GETARG_NUMERIC(0);
2240 	Datum		digit_hash;
2241 	Datum		result;
2242 	int			weight;
2243 	int			start_offset;
2244 	int			end_offset;
2245 	int			i;
2246 	int			hash_len;
2247 	NumericDigit *digits;
2248 
2249 	/* If it's NaN, don't try to hash the rest of the fields */
2250 	if (NUMERIC_IS_NAN(key))
2251 		PG_RETURN_UINT32(0);
2252 
2253 	weight = NUMERIC_WEIGHT(key);
2254 	start_offset = 0;
2255 	end_offset = 0;
2256 
2257 	/*
2258 	 * Omit any leading or trailing zeros from the input to the hash. The
2259 	 * numeric implementation *should* guarantee that leading and trailing
2260 	 * zeros are suppressed, but we're paranoid. Note that we measure the
2261 	 * starting and ending offsets in units of NumericDigits, not bytes.
2262 	 */
2263 	digits = NUMERIC_DIGITS(key);
2264 	for (i = 0; i < NUMERIC_NDIGITS(key); i++)
2265 	{
2266 		if (digits[i] != (NumericDigit) 0)
2267 			break;
2268 
2269 		start_offset++;
2270 
2271 		/*
2272 		 * The weight is effectively the # of digits before the decimal point,
2273 		 * so decrement it for each leading zero we skip.
2274 		 */
2275 		weight--;
2276 	}
2277 
2278 	/*
2279 	 * If there are no non-zero digits, then the value of the number is zero,
2280 	 * regardless of any other fields.
2281 	 */
2282 	if (NUMERIC_NDIGITS(key) == start_offset)
2283 		PG_RETURN_UINT32(-1);
2284 
2285 	for (i = NUMERIC_NDIGITS(key) - 1; i >= 0; i--)
2286 	{
2287 		if (digits[i] != (NumericDigit) 0)
2288 			break;
2289 
2290 		end_offset++;
2291 	}
2292 
2293 	/* If we get here, there should be at least one non-zero digit */
2294 	Assert(start_offset + end_offset < NUMERIC_NDIGITS(key));
2295 
2296 	/*
2297 	 * Note that we don't hash on the Numeric's scale, since two numerics can
2298 	 * compare equal but have different scales. We also don't hash on the
2299 	 * sign, although we could: since a sign difference implies inequality,
2300 	 * this shouldn't affect correctness.
2301 	 */
2302 	hash_len = NUMERIC_NDIGITS(key) - start_offset - end_offset;
2303 	digit_hash = hash_any((unsigned char *) (NUMERIC_DIGITS(key) + start_offset),
2304 						  hash_len * sizeof(NumericDigit));
2305 
2306 	/* Mix in the weight, via XOR */
2307 	result = digit_hash ^ weight;
2308 
2309 	PG_RETURN_DATUM(result);
2310 }
2311 
2312 /*
2313  * Returns 64-bit value by hashing a value to a 64-bit value, with a seed.
2314  * Otherwise, similar to hash_numeric.
2315  */
2316 Datum
hash_numeric_extended(PG_FUNCTION_ARGS)2317 hash_numeric_extended(PG_FUNCTION_ARGS)
2318 {
2319 	Numeric		key = PG_GETARG_NUMERIC(0);
2320 	uint64		seed = PG_GETARG_INT64(1);
2321 	Datum		digit_hash;
2322 	Datum		result;
2323 	int			weight;
2324 	int			start_offset;
2325 	int			end_offset;
2326 	int			i;
2327 	int			hash_len;
2328 	NumericDigit *digits;
2329 
2330 	if (NUMERIC_IS_NAN(key))
2331 		PG_RETURN_UINT64(seed);
2332 
2333 	weight = NUMERIC_WEIGHT(key);
2334 	start_offset = 0;
2335 	end_offset = 0;
2336 
2337 	digits = NUMERIC_DIGITS(key);
2338 	for (i = 0; i < NUMERIC_NDIGITS(key); i++)
2339 	{
2340 		if (digits[i] != (NumericDigit) 0)
2341 			break;
2342 
2343 		start_offset++;
2344 
2345 		weight--;
2346 	}
2347 
2348 	if (NUMERIC_NDIGITS(key) == start_offset)
2349 		PG_RETURN_UINT64(seed - 1);
2350 
2351 	for (i = NUMERIC_NDIGITS(key) - 1; i >= 0; i--)
2352 	{
2353 		if (digits[i] != (NumericDigit) 0)
2354 			break;
2355 
2356 		end_offset++;
2357 	}
2358 
2359 	Assert(start_offset + end_offset < NUMERIC_NDIGITS(key));
2360 
2361 	hash_len = NUMERIC_NDIGITS(key) - start_offset - end_offset;
2362 	digit_hash = hash_any_extended((unsigned char *) (NUMERIC_DIGITS(key)
2363 													  + start_offset),
2364 								   hash_len * sizeof(NumericDigit),
2365 								   seed);
2366 
2367 	result = UInt64GetDatum(DatumGetUInt64(digit_hash) ^ weight);
2368 
2369 	PG_RETURN_DATUM(result);
2370 }
2371 
2372 
2373 /* ----------------------------------------------------------------------
2374  *
2375  * Basic arithmetic functions
2376  *
2377  * ----------------------------------------------------------------------
2378  */
2379 
2380 
2381 /*
2382  * numeric_add() -
2383  *
2384  *	Add two numerics
2385  */
2386 Datum
numeric_add(PG_FUNCTION_ARGS)2387 numeric_add(PG_FUNCTION_ARGS)
2388 {
2389 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2390 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2391 	NumericVar	arg1;
2392 	NumericVar	arg2;
2393 	NumericVar	result;
2394 	Numeric		res;
2395 
2396 	/*
2397 	 * Handle NaN
2398 	 */
2399 	if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
2400 		PG_RETURN_NUMERIC(make_result(&const_nan));
2401 
2402 	/*
2403 	 * Unpack the values, let add_var() compute the result and return it.
2404 	 */
2405 	init_var_from_num(num1, &arg1);
2406 	init_var_from_num(num2, &arg2);
2407 
2408 	init_var(&result);
2409 	add_var(&arg1, &arg2, &result);
2410 
2411 	res = make_result(&result);
2412 
2413 	free_var(&result);
2414 
2415 	PG_RETURN_NUMERIC(res);
2416 }
2417 
2418 
2419 /*
2420  * numeric_sub() -
2421  *
2422  *	Subtract one numeric from another
2423  */
2424 Datum
numeric_sub(PG_FUNCTION_ARGS)2425 numeric_sub(PG_FUNCTION_ARGS)
2426 {
2427 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2428 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2429 	NumericVar	arg1;
2430 	NumericVar	arg2;
2431 	NumericVar	result;
2432 	Numeric		res;
2433 
2434 	/*
2435 	 * Handle NaN
2436 	 */
2437 	if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
2438 		PG_RETURN_NUMERIC(make_result(&const_nan));
2439 
2440 	/*
2441 	 * Unpack the values, let sub_var() compute the result and return it.
2442 	 */
2443 	init_var_from_num(num1, &arg1);
2444 	init_var_from_num(num2, &arg2);
2445 
2446 	init_var(&result);
2447 	sub_var(&arg1, &arg2, &result);
2448 
2449 	res = make_result(&result);
2450 
2451 	free_var(&result);
2452 
2453 	PG_RETURN_NUMERIC(res);
2454 }
2455 
2456 
2457 /*
2458  * numeric_mul() -
2459  *
2460  *	Calculate the product of two numerics
2461  */
2462 Datum
numeric_mul(PG_FUNCTION_ARGS)2463 numeric_mul(PG_FUNCTION_ARGS)
2464 {
2465 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2466 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2467 	NumericVar	arg1;
2468 	NumericVar	arg2;
2469 	NumericVar	result;
2470 	Numeric		res;
2471 
2472 	/*
2473 	 * Handle NaN
2474 	 */
2475 	if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
2476 		PG_RETURN_NUMERIC(make_result(&const_nan));
2477 
2478 	/*
2479 	 * Unpack the values, let mul_var() compute the result and return it.
2480 	 * Unlike add_var() and sub_var(), mul_var() will round its result. In the
2481 	 * case of numeric_mul(), which is invoked for the * operator on numerics,
2482 	 * we request exact representation for the product (rscale = sum(dscale of
2483 	 * arg1, dscale of arg2)).  If the exact result has more digits after the
2484 	 * decimal point than can be stored in a numeric, we round it.  Rounding
2485 	 * after computing the exact result ensures that the final result is
2486 	 * correctly rounded (rounding in mul_var() using a truncated product
2487 	 * would not guarantee this).
2488 	 */
2489 	init_var_from_num(num1, &arg1);
2490 	init_var_from_num(num2, &arg2);
2491 
2492 	init_var(&result);
2493 	mul_var(&arg1, &arg2, &result, arg1.dscale + arg2.dscale);
2494 
2495 	if (result.dscale > NUMERIC_DSCALE_MAX)
2496 		round_var(&result, NUMERIC_DSCALE_MAX);
2497 
2498 	res = make_result(&result);
2499 
2500 	free_var(&result);
2501 
2502 	PG_RETURN_NUMERIC(res);
2503 }
2504 
2505 
2506 /*
2507  * numeric_div() -
2508  *
2509  *	Divide one numeric into another
2510  */
2511 Datum
numeric_div(PG_FUNCTION_ARGS)2512 numeric_div(PG_FUNCTION_ARGS)
2513 {
2514 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2515 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2516 	NumericVar	arg1;
2517 	NumericVar	arg2;
2518 	NumericVar	result;
2519 	Numeric		res;
2520 	int			rscale;
2521 
2522 	/*
2523 	 * Handle NaN
2524 	 */
2525 	if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
2526 		PG_RETURN_NUMERIC(make_result(&const_nan));
2527 
2528 	/*
2529 	 * Unpack the arguments
2530 	 */
2531 	init_var_from_num(num1, &arg1);
2532 	init_var_from_num(num2, &arg2);
2533 
2534 	init_var(&result);
2535 
2536 	/*
2537 	 * Select scale for division result
2538 	 */
2539 	rscale = select_div_scale(&arg1, &arg2);
2540 
2541 	/*
2542 	 * Do the divide and return the result
2543 	 */
2544 	div_var(&arg1, &arg2, &result, rscale, true);
2545 
2546 	res = make_result(&result);
2547 
2548 	free_var(&result);
2549 
2550 	PG_RETURN_NUMERIC(res);
2551 }
2552 
2553 
2554 /*
2555  * numeric_div_trunc() -
2556  *
2557  *	Divide one numeric into another, truncating the result to an integer
2558  */
2559 Datum
numeric_div_trunc(PG_FUNCTION_ARGS)2560 numeric_div_trunc(PG_FUNCTION_ARGS)
2561 {
2562 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2563 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2564 	NumericVar	arg1;
2565 	NumericVar	arg2;
2566 	NumericVar	result;
2567 	Numeric		res;
2568 
2569 	/*
2570 	 * Handle NaN
2571 	 */
2572 	if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
2573 		PG_RETURN_NUMERIC(make_result(&const_nan));
2574 
2575 	/*
2576 	 * Unpack the arguments
2577 	 */
2578 	init_var_from_num(num1, &arg1);
2579 	init_var_from_num(num2, &arg2);
2580 
2581 	init_var(&result);
2582 
2583 	/*
2584 	 * Do the divide and return the result
2585 	 */
2586 	div_var(&arg1, &arg2, &result, 0, false);
2587 
2588 	res = make_result(&result);
2589 
2590 	free_var(&result);
2591 
2592 	PG_RETURN_NUMERIC(res);
2593 }
2594 
2595 
2596 /*
2597  * numeric_mod() -
2598  *
2599  *	Calculate the modulo of two numerics
2600  */
2601 Datum
numeric_mod(PG_FUNCTION_ARGS)2602 numeric_mod(PG_FUNCTION_ARGS)
2603 {
2604 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2605 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2606 	Numeric		res;
2607 	NumericVar	arg1;
2608 	NumericVar	arg2;
2609 	NumericVar	result;
2610 
2611 	if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
2612 		PG_RETURN_NUMERIC(make_result(&const_nan));
2613 
2614 	init_var_from_num(num1, &arg1);
2615 	init_var_from_num(num2, &arg2);
2616 
2617 	init_var(&result);
2618 
2619 	mod_var(&arg1, &arg2, &result);
2620 
2621 	res = make_result(&result);
2622 
2623 	free_var(&result);
2624 
2625 	PG_RETURN_NUMERIC(res);
2626 }
2627 
2628 
2629 /*
2630  * numeric_inc() -
2631  *
2632  *	Increment a number by one
2633  */
2634 Datum
numeric_inc(PG_FUNCTION_ARGS)2635 numeric_inc(PG_FUNCTION_ARGS)
2636 {
2637 	Numeric		num = PG_GETARG_NUMERIC(0);
2638 	NumericVar	arg;
2639 	Numeric		res;
2640 
2641 	/*
2642 	 * Handle NaN
2643 	 */
2644 	if (NUMERIC_IS_NAN(num))
2645 		PG_RETURN_NUMERIC(make_result(&const_nan));
2646 
2647 	/*
2648 	 * Compute the result and return it
2649 	 */
2650 	init_var_from_num(num, &arg);
2651 
2652 	add_var(&arg, &const_one, &arg);
2653 
2654 	res = make_result(&arg);
2655 
2656 	free_var(&arg);
2657 
2658 	PG_RETURN_NUMERIC(res);
2659 }
2660 
2661 
2662 /*
2663  * numeric_smaller() -
2664  *
2665  *	Return the smaller of two numbers
2666  */
2667 Datum
numeric_smaller(PG_FUNCTION_ARGS)2668 numeric_smaller(PG_FUNCTION_ARGS)
2669 {
2670 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2671 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2672 
2673 	/*
2674 	 * Use cmp_numerics so that this will agree with the comparison operators,
2675 	 * particularly as regards comparisons involving NaN.
2676 	 */
2677 	if (cmp_numerics(num1, num2) < 0)
2678 		PG_RETURN_NUMERIC(num1);
2679 	else
2680 		PG_RETURN_NUMERIC(num2);
2681 }
2682 
2683 
2684 /*
2685  * numeric_larger() -
2686  *
2687  *	Return the larger of two numbers
2688  */
2689 Datum
numeric_larger(PG_FUNCTION_ARGS)2690 numeric_larger(PG_FUNCTION_ARGS)
2691 {
2692 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2693 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2694 
2695 	/*
2696 	 * Use cmp_numerics so that this will agree with the comparison operators,
2697 	 * particularly as regards comparisons involving NaN.
2698 	 */
2699 	if (cmp_numerics(num1, num2) > 0)
2700 		PG_RETURN_NUMERIC(num1);
2701 	else
2702 		PG_RETURN_NUMERIC(num2);
2703 }
2704 
2705 
2706 /* ----------------------------------------------------------------------
2707  *
2708  * Advanced math functions
2709  *
2710  * ----------------------------------------------------------------------
2711  */
2712 
2713 /*
2714  * numeric_fac()
2715  *
2716  * Compute factorial
2717  */
2718 Datum
numeric_fac(PG_FUNCTION_ARGS)2719 numeric_fac(PG_FUNCTION_ARGS)
2720 {
2721 	int64		num = PG_GETARG_INT64(0);
2722 	Numeric		res;
2723 	NumericVar	fact;
2724 	NumericVar	result;
2725 
2726 	if (num <= 1)
2727 	{
2728 		res = make_result(&const_one);
2729 		PG_RETURN_NUMERIC(res);
2730 	}
2731 	/* Fail immediately if the result would overflow */
2732 	if (num > 32177)
2733 		ereport(ERROR,
2734 				(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
2735 				 errmsg("value overflows numeric format")));
2736 
2737 	init_var(&fact);
2738 	init_var(&result);
2739 
2740 	int64_to_numericvar(num, &result);
2741 
2742 	for (num = num - 1; num > 1; num--)
2743 	{
2744 		/* this loop can take awhile, so allow it to be interrupted */
2745 		CHECK_FOR_INTERRUPTS();
2746 
2747 		int64_to_numericvar(num, &fact);
2748 
2749 		mul_var(&result, &fact, &result, 0);
2750 	}
2751 
2752 	res = make_result(&result);
2753 
2754 	free_var(&fact);
2755 	free_var(&result);
2756 
2757 	PG_RETURN_NUMERIC(res);
2758 }
2759 
2760 
2761 /*
2762  * numeric_sqrt() -
2763  *
2764  *	Compute the square root of a numeric.
2765  */
2766 Datum
numeric_sqrt(PG_FUNCTION_ARGS)2767 numeric_sqrt(PG_FUNCTION_ARGS)
2768 {
2769 	Numeric		num = PG_GETARG_NUMERIC(0);
2770 	Numeric		res;
2771 	NumericVar	arg;
2772 	NumericVar	result;
2773 	int			sweight;
2774 	int			rscale;
2775 
2776 	/*
2777 	 * Handle NaN
2778 	 */
2779 	if (NUMERIC_IS_NAN(num))
2780 		PG_RETURN_NUMERIC(make_result(&const_nan));
2781 
2782 	/*
2783 	 * Unpack the argument and determine the result scale.  We choose a scale
2784 	 * to give at least NUMERIC_MIN_SIG_DIGITS significant digits; but in any
2785 	 * case not less than the input's dscale.
2786 	 */
2787 	init_var_from_num(num, &arg);
2788 
2789 	init_var(&result);
2790 
2791 	/* Assume the input was normalized, so arg.weight is accurate */
2792 	sweight = (arg.weight + 1) * DEC_DIGITS / 2 - 1;
2793 
2794 	rscale = NUMERIC_MIN_SIG_DIGITS - sweight;
2795 	rscale = Max(rscale, arg.dscale);
2796 	rscale = Max(rscale, NUMERIC_MIN_DISPLAY_SCALE);
2797 	rscale = Min(rscale, NUMERIC_MAX_DISPLAY_SCALE);
2798 
2799 	/*
2800 	 * Let sqrt_var() do the calculation and return the result.
2801 	 */
2802 	sqrt_var(&arg, &result, rscale);
2803 
2804 	res = make_result(&result);
2805 
2806 	free_var(&result);
2807 
2808 	PG_RETURN_NUMERIC(res);
2809 }
2810 
2811 
2812 /*
2813  * numeric_exp() -
2814  *
2815  *	Raise e to the power of x
2816  */
2817 Datum
numeric_exp(PG_FUNCTION_ARGS)2818 numeric_exp(PG_FUNCTION_ARGS)
2819 {
2820 	Numeric		num = PG_GETARG_NUMERIC(0);
2821 	Numeric		res;
2822 	NumericVar	arg;
2823 	NumericVar	result;
2824 	int			rscale;
2825 	double		val;
2826 
2827 	/*
2828 	 * Handle NaN
2829 	 */
2830 	if (NUMERIC_IS_NAN(num))
2831 		PG_RETURN_NUMERIC(make_result(&const_nan));
2832 
2833 	/*
2834 	 * Unpack the argument and determine the result scale.  We choose a scale
2835 	 * to give at least NUMERIC_MIN_SIG_DIGITS significant digits; but in any
2836 	 * case not less than the input's dscale.
2837 	 */
2838 	init_var_from_num(num, &arg);
2839 
2840 	init_var(&result);
2841 
2842 	/* convert input to float8, ignoring overflow */
2843 	val = numericvar_to_double_no_overflow(&arg);
2844 
2845 	/*
2846 	 * log10(result) = num * log10(e), so this is approximately the decimal
2847 	 * weight of the result:
2848 	 */
2849 	val *= 0.434294481903252;
2850 
2851 	/* limit to something that won't cause integer overflow */
2852 	val = Max(val, -NUMERIC_MAX_RESULT_SCALE);
2853 	val = Min(val, NUMERIC_MAX_RESULT_SCALE);
2854 
2855 	rscale = NUMERIC_MIN_SIG_DIGITS - (int) val;
2856 	rscale = Max(rscale, arg.dscale);
2857 	rscale = Max(rscale, NUMERIC_MIN_DISPLAY_SCALE);
2858 	rscale = Min(rscale, NUMERIC_MAX_DISPLAY_SCALE);
2859 
2860 	/*
2861 	 * Let exp_var() do the calculation and return the result.
2862 	 */
2863 	exp_var(&arg, &result, rscale);
2864 
2865 	res = make_result(&result);
2866 
2867 	free_var(&result);
2868 
2869 	PG_RETURN_NUMERIC(res);
2870 }
2871 
2872 
2873 /*
2874  * numeric_ln() -
2875  *
2876  *	Compute the natural logarithm of x
2877  */
2878 Datum
numeric_ln(PG_FUNCTION_ARGS)2879 numeric_ln(PG_FUNCTION_ARGS)
2880 {
2881 	Numeric		num = PG_GETARG_NUMERIC(0);
2882 	Numeric		res;
2883 	NumericVar	arg;
2884 	NumericVar	result;
2885 	int			ln_dweight;
2886 	int			rscale;
2887 
2888 	/*
2889 	 * Handle NaN
2890 	 */
2891 	if (NUMERIC_IS_NAN(num))
2892 		PG_RETURN_NUMERIC(make_result(&const_nan));
2893 
2894 	init_var_from_num(num, &arg);
2895 	init_var(&result);
2896 
2897 	/* Estimated dweight of logarithm */
2898 	ln_dweight = estimate_ln_dweight(&arg);
2899 
2900 	rscale = NUMERIC_MIN_SIG_DIGITS - ln_dweight;
2901 	rscale = Max(rscale, arg.dscale);
2902 	rscale = Max(rscale, NUMERIC_MIN_DISPLAY_SCALE);
2903 	rscale = Min(rscale, NUMERIC_MAX_DISPLAY_SCALE);
2904 
2905 	ln_var(&arg, &result, rscale);
2906 
2907 	res = make_result(&result);
2908 
2909 	free_var(&result);
2910 
2911 	PG_RETURN_NUMERIC(res);
2912 }
2913 
2914 
2915 /*
2916  * numeric_log() -
2917  *
2918  *	Compute the logarithm of x in a given base
2919  */
2920 Datum
numeric_log(PG_FUNCTION_ARGS)2921 numeric_log(PG_FUNCTION_ARGS)
2922 {
2923 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2924 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2925 	Numeric		res;
2926 	NumericVar	arg1;
2927 	NumericVar	arg2;
2928 	NumericVar	result;
2929 
2930 	/*
2931 	 * Handle NaN
2932 	 */
2933 	if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
2934 		PG_RETURN_NUMERIC(make_result(&const_nan));
2935 
2936 	/*
2937 	 * Initialize things
2938 	 */
2939 	init_var_from_num(num1, &arg1);
2940 	init_var_from_num(num2, &arg2);
2941 	init_var(&result);
2942 
2943 	/*
2944 	 * Call log_var() to compute and return the result; note it handles scale
2945 	 * selection itself.
2946 	 */
2947 	log_var(&arg1, &arg2, &result);
2948 
2949 	res = make_result(&result);
2950 
2951 	free_var(&result);
2952 
2953 	PG_RETURN_NUMERIC(res);
2954 }
2955 
2956 
2957 /*
2958  * numeric_power() -
2959  *
2960  *	Raise b to the power of x
2961  */
2962 Datum
numeric_power(PG_FUNCTION_ARGS)2963 numeric_power(PG_FUNCTION_ARGS)
2964 {
2965 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2966 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2967 	Numeric		res;
2968 	NumericVar	arg1;
2969 	NumericVar	arg2;
2970 	NumericVar	arg2_trunc;
2971 	NumericVar	result;
2972 
2973 	/*
2974 	 * Handle NaN cases.  We follow the POSIX spec for pow(3), which says that
2975 	 * NaN ^ 0 = 1, and 1 ^ NaN = 1, while all other cases with NaN inputs
2976 	 * yield NaN (with no error).
2977 	 */
2978 	if (NUMERIC_IS_NAN(num1))
2979 	{
2980 		if (!NUMERIC_IS_NAN(num2))
2981 		{
2982 			init_var_from_num(num2, &arg2);
2983 			if (cmp_var(&arg2, &const_zero) == 0)
2984 				PG_RETURN_NUMERIC(make_result(&const_one));
2985 		}
2986 		PG_RETURN_NUMERIC(make_result(&const_nan));
2987 	}
2988 	if (NUMERIC_IS_NAN(num2))
2989 	{
2990 		init_var_from_num(num1, &arg1);
2991 		if (cmp_var(&arg1, &const_one) == 0)
2992 			PG_RETURN_NUMERIC(make_result(&const_one));
2993 		PG_RETURN_NUMERIC(make_result(&const_nan));
2994 	}
2995 
2996 	/*
2997 	 * Initialize things
2998 	 */
2999 	init_var(&arg2_trunc);
3000 	init_var(&result);
3001 	init_var_from_num(num1, &arg1);
3002 	init_var_from_num(num2, &arg2);
3003 
3004 	set_var_from_var(&arg2, &arg2_trunc);
3005 	trunc_var(&arg2_trunc, 0);
3006 
3007 	/*
3008 	 * The SQL spec requires that we emit a particular SQLSTATE error code for
3009 	 * certain error conditions.  Specifically, we don't return a
3010 	 * divide-by-zero error code for 0 ^ -1.  Raising a negative number to a
3011 	 * non-integer power must produce the same error code, but that case is
3012 	 * handled in power_var().
3013 	 */
3014 	if (cmp_var(&arg1, &const_zero) == 0 &&
3015 		cmp_var(&arg2, &const_zero) < 0)
3016 		ereport(ERROR,
3017 				(errcode(ERRCODE_INVALID_ARGUMENT_FOR_POWER_FUNCTION),
3018 				 errmsg("zero raised to a negative power is undefined")));
3019 
3020 	/*
3021 	 * Call power_var() to compute and return the result; note it handles
3022 	 * scale selection itself.
3023 	 */
3024 	power_var(&arg1, &arg2, &result);
3025 
3026 	res = make_result(&result);
3027 
3028 	free_var(&result);
3029 	free_var(&arg2_trunc);
3030 
3031 	PG_RETURN_NUMERIC(res);
3032 }
3033 
3034 /*
3035  * numeric_scale() -
3036  *
3037  *	Returns the scale, i.e. the count of decimal digits in the fractional part
3038  */
3039 Datum
numeric_scale(PG_FUNCTION_ARGS)3040 numeric_scale(PG_FUNCTION_ARGS)
3041 {
3042 	Numeric		num = PG_GETARG_NUMERIC(0);
3043 
3044 	if (NUMERIC_IS_NAN(num))
3045 		PG_RETURN_NULL();
3046 
3047 	PG_RETURN_INT32(NUMERIC_DSCALE(num));
3048 }
3049 
3050 
3051 
3052 /* ----------------------------------------------------------------------
3053  *
3054  * Type conversion functions
3055  *
3056  * ----------------------------------------------------------------------
3057  */
3058 
3059 
3060 Datum
int4_numeric(PG_FUNCTION_ARGS)3061 int4_numeric(PG_FUNCTION_ARGS)
3062 {
3063 	int32		val = PG_GETARG_INT32(0);
3064 	Numeric		res;
3065 	NumericVar	result;
3066 
3067 	init_var(&result);
3068 
3069 	int64_to_numericvar((int64) val, &result);
3070 
3071 	res = make_result(&result);
3072 
3073 	free_var(&result);
3074 
3075 	PG_RETURN_NUMERIC(res);
3076 }
3077 
3078 
3079 Datum
numeric_int4(PG_FUNCTION_ARGS)3080 numeric_int4(PG_FUNCTION_ARGS)
3081 {
3082 	Numeric		num = PG_GETARG_NUMERIC(0);
3083 	NumericVar	x;
3084 	int32		result;
3085 
3086 	/* XXX would it be better to return NULL? */
3087 	if (NUMERIC_IS_NAN(num))
3088 		ereport(ERROR,
3089 				(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
3090 				 errmsg("cannot convert NaN to integer")));
3091 
3092 	/* Convert to variable format, then convert to int4 */
3093 	init_var_from_num(num, &x);
3094 	result = numericvar_to_int32(&x);
3095 	PG_RETURN_INT32(result);
3096 }
3097 
3098 /*
3099  * Given a NumericVar, convert it to an int32. If the NumericVar
3100  * exceeds the range of an int32, raise the appropriate error via
3101  * ereport(). The input NumericVar is *not* free'd.
3102  */
3103 static int32
numericvar_to_int32(const NumericVar * var)3104 numericvar_to_int32(const NumericVar *var)
3105 {
3106 	int64		val;
3107 
3108 	if (!numericvar_to_int64(var, &val))
3109 		ereport(ERROR,
3110 				(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
3111 				 errmsg("integer out of range")));
3112 
3113 	if (unlikely(val < PG_INT32_MIN) || unlikely(val > PG_INT32_MAX))
3114 		ereport(ERROR,
3115 				(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
3116 				 errmsg("integer out of range")));
3117 
3118 	/* Down-convert to int4 */
3119 	return (int32) val;
3120 }
3121 
3122 Datum
int8_numeric(PG_FUNCTION_ARGS)3123 int8_numeric(PG_FUNCTION_ARGS)
3124 {
3125 	int64		val = PG_GETARG_INT64(0);
3126 	Numeric		res;
3127 	NumericVar	result;
3128 
3129 	init_var(&result);
3130 
3131 	int64_to_numericvar(val, &result);
3132 
3133 	res = make_result(&result);
3134 
3135 	free_var(&result);
3136 
3137 	PG_RETURN_NUMERIC(res);
3138 }
3139 
3140 
3141 Datum
numeric_int8(PG_FUNCTION_ARGS)3142 numeric_int8(PG_FUNCTION_ARGS)
3143 {
3144 	Numeric		num = PG_GETARG_NUMERIC(0);
3145 	NumericVar	x;
3146 	int64		result;
3147 
3148 	/* XXX would it be better to return NULL? */
3149 	if (NUMERIC_IS_NAN(num))
3150 		ereport(ERROR,
3151 				(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
3152 				 errmsg("cannot convert NaN to bigint")));
3153 
3154 	/* Convert to variable format and thence to int8 */
3155 	init_var_from_num(num, &x);
3156 
3157 	if (!numericvar_to_int64(&x, &result))
3158 		ereport(ERROR,
3159 				(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
3160 				 errmsg("bigint out of range")));
3161 
3162 	PG_RETURN_INT64(result);
3163 }
3164 
3165 
3166 Datum
int2_numeric(PG_FUNCTION_ARGS)3167 int2_numeric(PG_FUNCTION_ARGS)
3168 {
3169 	int16		val = PG_GETARG_INT16(0);
3170 	Numeric		res;
3171 	NumericVar	result;
3172 
3173 	init_var(&result);
3174 
3175 	int64_to_numericvar((int64) val, &result);
3176 
3177 	res = make_result(&result);
3178 
3179 	free_var(&result);
3180 
3181 	PG_RETURN_NUMERIC(res);
3182 }
3183 
3184 
3185 Datum
numeric_int2(PG_FUNCTION_ARGS)3186 numeric_int2(PG_FUNCTION_ARGS)
3187 {
3188 	Numeric		num = PG_GETARG_NUMERIC(0);
3189 	NumericVar	x;
3190 	int64		val;
3191 	int16		result;
3192 
3193 	/* XXX would it be better to return NULL? */
3194 	if (NUMERIC_IS_NAN(num))
3195 		ereport(ERROR,
3196 				(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
3197 				 errmsg("cannot convert NaN to smallint")));
3198 
3199 	/* Convert to variable format and thence to int8 */
3200 	init_var_from_num(num, &x);
3201 
3202 	if (!numericvar_to_int64(&x, &val))
3203 		ereport(ERROR,
3204 				(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
3205 				 errmsg("smallint out of range")));
3206 
3207 	if (unlikely(val < PG_INT16_MIN) || unlikely(val > PG_INT16_MAX))
3208 		ereport(ERROR,
3209 				(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
3210 				 errmsg("smallint out of range")));
3211 
3212 	/* Down-convert to int2 */
3213 	result = (int16) val;
3214 
3215 	PG_RETURN_INT16(result);
3216 }
3217 
3218 
3219 Datum
float8_numeric(PG_FUNCTION_ARGS)3220 float8_numeric(PG_FUNCTION_ARGS)
3221 {
3222 	float8		val = PG_GETARG_FLOAT8(0);
3223 	Numeric		res;
3224 	NumericVar	result;
3225 	char		buf[DBL_DIG + 100];
3226 
3227 	if (isnan(val))
3228 		PG_RETURN_NUMERIC(make_result(&const_nan));
3229 
3230 	if (isinf(val))
3231 		ereport(ERROR,
3232 				(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
3233 				 errmsg("cannot convert infinity to numeric")));
3234 
3235 	snprintf(buf, sizeof(buf), "%.*g", DBL_DIG, val);
3236 
3237 	init_var(&result);
3238 
3239 	/* Assume we need not worry about leading/trailing spaces */
3240 	(void) set_var_from_str(buf, buf, &result);
3241 
3242 	res = make_result(&result);
3243 
3244 	free_var(&result);
3245 
3246 	PG_RETURN_NUMERIC(res);
3247 }
3248 
3249 
3250 Datum
numeric_float8(PG_FUNCTION_ARGS)3251 numeric_float8(PG_FUNCTION_ARGS)
3252 {
3253 	Numeric		num = PG_GETARG_NUMERIC(0);
3254 	char	   *tmp;
3255 	Datum		result;
3256 
3257 	if (NUMERIC_IS_NAN(num))
3258 		PG_RETURN_FLOAT8(get_float8_nan());
3259 
3260 	tmp = DatumGetCString(DirectFunctionCall1(numeric_out,
3261 											  NumericGetDatum(num)));
3262 
3263 	result = DirectFunctionCall1(float8in, CStringGetDatum(tmp));
3264 
3265 	pfree(tmp);
3266 
3267 	PG_RETURN_DATUM(result);
3268 }
3269 
3270 
3271 /*
3272  * Convert numeric to float8; if out of range, return +/- HUGE_VAL
3273  *
3274  * (internal helper function, not directly callable from SQL)
3275  */
3276 Datum
numeric_float8_no_overflow(PG_FUNCTION_ARGS)3277 numeric_float8_no_overflow(PG_FUNCTION_ARGS)
3278 {
3279 	Numeric		num = PG_GETARG_NUMERIC(0);
3280 	double		val;
3281 
3282 	if (NUMERIC_IS_NAN(num))
3283 		PG_RETURN_FLOAT8(get_float8_nan());
3284 
3285 	val = numeric_to_double_no_overflow(num);
3286 
3287 	PG_RETURN_FLOAT8(val);
3288 }
3289 
3290 Datum
float4_numeric(PG_FUNCTION_ARGS)3291 float4_numeric(PG_FUNCTION_ARGS)
3292 {
3293 	float4		val = PG_GETARG_FLOAT4(0);
3294 	Numeric		res;
3295 	NumericVar	result;
3296 	char		buf[FLT_DIG + 100];
3297 
3298 	if (isnan(val))
3299 		PG_RETURN_NUMERIC(make_result(&const_nan));
3300 
3301 	if (isinf(val))
3302 		ereport(ERROR,
3303 				(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
3304 				 errmsg("cannot convert infinity to numeric")));
3305 
3306 	snprintf(buf, sizeof(buf), "%.*g", FLT_DIG, val);
3307 
3308 	init_var(&result);
3309 
3310 	/* Assume we need not worry about leading/trailing spaces */
3311 	(void) set_var_from_str(buf, buf, &result);
3312 
3313 	res = make_result(&result);
3314 
3315 	free_var(&result);
3316 
3317 	PG_RETURN_NUMERIC(res);
3318 }
3319 
3320 
3321 Datum
numeric_float4(PG_FUNCTION_ARGS)3322 numeric_float4(PG_FUNCTION_ARGS)
3323 {
3324 	Numeric		num = PG_GETARG_NUMERIC(0);
3325 	char	   *tmp;
3326 	Datum		result;
3327 
3328 	if (NUMERIC_IS_NAN(num))
3329 		PG_RETURN_FLOAT4(get_float4_nan());
3330 
3331 	tmp = DatumGetCString(DirectFunctionCall1(numeric_out,
3332 											  NumericGetDatum(num)));
3333 
3334 	result = DirectFunctionCall1(float4in, CStringGetDatum(tmp));
3335 
3336 	pfree(tmp);
3337 
3338 	PG_RETURN_DATUM(result);
3339 }
3340 
3341 
3342 /* ----------------------------------------------------------------------
3343  *
3344  * Aggregate functions
3345  *
3346  * The transition datatype for all these aggregates is declared as INTERNAL.
3347  * Actually, it's a pointer to a NumericAggState allocated in the aggregate
3348  * context.  The digit buffers for the NumericVars will be there too.
3349  *
3350  * On platforms which support 128-bit integers some aggregates instead use a
3351  * 128-bit integer based transition datatype to speed up calculations.
3352  *
3353  * ----------------------------------------------------------------------
3354  */
3355 
/*
 * Transition state for the numeric aggregates below.
 *
 * sumX2 is only updated by do_numeric_accum()/do_numeric_discard() when
 * calcSumX2 is true.  maxScale/maxScaleCount let do_numeric_discard() tell
 * whether removing an input would change the dscale of the sum.  NaN inputs
 * are tallied only in NaNcount and do not contribute to N or the sums.
 */
typedef struct NumericAggState
{
	bool		calcSumX2;		/* if true, calculate sumX2 */
	MemoryContext agg_context;	/* context we're calculating in */
	int64		N;				/* count of processed numbers */
	NumericSumAccum sumX;		/* sum of processed numbers */
	NumericSumAccum sumX2;		/* sum of squares of processed numbers */
	int			maxScale;		/* maximum scale seen so far */
	int64		maxScaleCount;	/* number of values seen with maximum scale */
	int64		NaNcount;		/* count of NaN values (not included in N!) */
} NumericAggState;
3367 
3368 /*
3369  * Prepare state data for a numeric aggregate function that needs to compute
3370  * sum, count and optionally sum of squares of the input.
3371  */
3372 static NumericAggState *
makeNumericAggState(FunctionCallInfo fcinfo,bool calcSumX2)3373 makeNumericAggState(FunctionCallInfo fcinfo, bool calcSumX2)
3374 {
3375 	NumericAggState *state;
3376 	MemoryContext agg_context;
3377 	MemoryContext old_context;
3378 
3379 	if (!AggCheckCallContext(fcinfo, &agg_context))
3380 		elog(ERROR, "aggregate function called in non-aggregate context");
3381 
3382 	old_context = MemoryContextSwitchTo(agg_context);
3383 
3384 	state = (NumericAggState *) palloc0(sizeof(NumericAggState));
3385 	state->calcSumX2 = calcSumX2;
3386 	state->agg_context = agg_context;
3387 
3388 	MemoryContextSwitchTo(old_context);
3389 
3390 	return state;
3391 }
3392 
3393 /*
3394  * Like makeNumericAggState(), but allocate the state in the current memory
3395  * context.
3396  */
3397 static NumericAggState *
makeNumericAggStateCurrentContext(bool calcSumX2)3398 makeNumericAggStateCurrentContext(bool calcSumX2)
3399 {
3400 	NumericAggState *state;
3401 
3402 	state = (NumericAggState *) palloc0(sizeof(NumericAggState));
3403 	state->calcSumX2 = calcSumX2;
3404 	state->agg_context = CurrentMemoryContext;
3405 
3406 	return state;
3407 }
3408 
3409 /*
3410  * Accumulate a new input value for numeric aggregate functions.
3411  */
3412 static void
do_numeric_accum(NumericAggState * state,Numeric newval)3413 do_numeric_accum(NumericAggState *state, Numeric newval)
3414 {
3415 	NumericVar	X;
3416 	NumericVar	X2;
3417 	MemoryContext old_context;
3418 
3419 	/* Count NaN inputs separately from all else */
3420 	if (NUMERIC_IS_NAN(newval))
3421 	{
3422 		state->NaNcount++;
3423 		return;
3424 	}
3425 
3426 	/* load processed number in short-lived context */
3427 	init_var_from_num(newval, &X);
3428 
3429 	/*
3430 	 * Track the highest input dscale that we've seen, to support inverse
3431 	 * transitions (see do_numeric_discard).
3432 	 */
3433 	if (X.dscale > state->maxScale)
3434 	{
3435 		state->maxScale = X.dscale;
3436 		state->maxScaleCount = 1;
3437 	}
3438 	else if (X.dscale == state->maxScale)
3439 		state->maxScaleCount++;
3440 
3441 	/* if we need X^2, calculate that in short-lived context */
3442 	if (state->calcSumX2)
3443 	{
3444 		init_var(&X2);
3445 		mul_var(&X, &X, &X2, X.dscale * 2);
3446 	}
3447 
3448 	/* The rest of this needs to work in the aggregate context */
3449 	old_context = MemoryContextSwitchTo(state->agg_context);
3450 
3451 	state->N++;
3452 
3453 	/* Accumulate sums */
3454 	accum_sum_add(&(state->sumX), &X);
3455 
3456 	if (state->calcSumX2)
3457 		accum_sum_add(&(state->sumX2), &X2);
3458 
3459 	MemoryContextSwitchTo(old_context);
3460 }
3461 
3462 /*
3463  * Attempt to remove an input value from the aggregated state.
3464  *
3465  * If the value cannot be removed then the function will return false; the
3466  * possible reasons for failing are described below.
3467  *
3468  * If we aggregate the values 1.01 and 2 then the result will be 3.01.
3469  * If we are then asked to un-aggregate the 1.01 then we must fail as we
3470  * won't be able to tell what the new aggregated value's dscale should be.
3471  * We don't want to return 2.00 (dscale = 2), since the sum's dscale would
3472  * have been zero if we'd really aggregated only 2.
3473  *
3474  * Note: alternatively, we could count the number of inputs with each possible
3475  * dscale (up to some sane limit).  Not yet clear if it's worth the trouble.
3476  */
3477 static bool
do_numeric_discard(NumericAggState * state,Numeric newval)3478 do_numeric_discard(NumericAggState *state, Numeric newval)
3479 {
3480 	NumericVar	X;
3481 	NumericVar	X2;
3482 	MemoryContext old_context;
3483 
3484 	/* Count NaN inputs separately from all else */
3485 	if (NUMERIC_IS_NAN(newval))
3486 	{
3487 		state->NaNcount--;
3488 		return true;
3489 	}
3490 
3491 	/* load processed number in short-lived context */
3492 	init_var_from_num(newval, &X);
3493 
3494 	/*
3495 	 * state->sumX's dscale is the maximum dscale of any of the inputs.
3496 	 * Removing the last input with that dscale would require us to recompute
3497 	 * the maximum dscale of the *remaining* inputs, which we cannot do unless
3498 	 * no more non-NaN inputs remain at all.  So we report a failure instead,
3499 	 * and force the aggregation to be redone from scratch.
3500 	 */
3501 	if (X.dscale == state->maxScale)
3502 	{
3503 		if (state->maxScaleCount > 1 || state->maxScale == 0)
3504 		{
3505 			/*
3506 			 * Some remaining inputs have same dscale, or dscale hasn't gotten
3507 			 * above zero anyway
3508 			 */
3509 			state->maxScaleCount--;
3510 		}
3511 		else if (state->N == 1)
3512 		{
3513 			/* No remaining non-NaN inputs at all, so reset maxScale */
3514 			state->maxScale = 0;
3515 			state->maxScaleCount = 0;
3516 		}
3517 		else
3518 		{
3519 			/* Correct new maxScale is uncertain, must fail */
3520 			return false;
3521 		}
3522 	}
3523 
3524 	/* if we need X^2, calculate that in short-lived context */
3525 	if (state->calcSumX2)
3526 	{
3527 		init_var(&X2);
3528 		mul_var(&X, &X, &X2, X.dscale * 2);
3529 	}
3530 
3531 	/* The rest of this needs to work in the aggregate context */
3532 	old_context = MemoryContextSwitchTo(state->agg_context);
3533 
3534 	if (state->N-- > 1)
3535 	{
3536 		/* Negate X, to subtract it from the sum */
3537 		X.sign = (X.sign == NUMERIC_POS ? NUMERIC_NEG : NUMERIC_POS);
3538 		accum_sum_add(&(state->sumX), &X);
3539 
3540 		if (state->calcSumX2)
3541 		{
3542 			/* Negate X^2. X^2 is always positive */
3543 			X2.sign = NUMERIC_NEG;
3544 			accum_sum_add(&(state->sumX2), &X2);
3545 		}
3546 	}
3547 	else
3548 	{
3549 		/* Zero the sums */
3550 		Assert(state->N == 0);
3551 
3552 		accum_sum_reset(&state->sumX);
3553 		if (state->calcSumX2)
3554 			accum_sum_reset(&state->sumX2);
3555 	}
3556 
3557 	MemoryContextSwitchTo(old_context);
3558 
3559 	return true;
3560 }
3561 
3562 /*
3563  * Generic transition function for numeric aggregates that require sumX2.
3564  */
3565 Datum
numeric_accum(PG_FUNCTION_ARGS)3566 numeric_accum(PG_FUNCTION_ARGS)
3567 {
3568 	NumericAggState *state;
3569 
3570 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
3571 
3572 	/* Create the state data on the first call */
3573 	if (state == NULL)
3574 		state = makeNumericAggState(fcinfo, true);
3575 
3576 	if (!PG_ARGISNULL(1))
3577 		do_numeric_accum(state, PG_GETARG_NUMERIC(1));
3578 
3579 	PG_RETURN_POINTER(state);
3580 }
3581 
3582 /*
3583  * Generic combine function for numeric aggregates which require sumX2
3584  */
3585 Datum
numeric_combine(PG_FUNCTION_ARGS)3586 numeric_combine(PG_FUNCTION_ARGS)
3587 {
3588 	NumericAggState *state1;
3589 	NumericAggState *state2;
3590 	MemoryContext agg_context;
3591 	MemoryContext old_context;
3592 
3593 	if (!AggCheckCallContext(fcinfo, &agg_context))
3594 		elog(ERROR, "aggregate function called in non-aggregate context");
3595 
3596 	state1 = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
3597 	state2 = PG_ARGISNULL(1) ? NULL : (NumericAggState *) PG_GETARG_POINTER(1);
3598 
3599 	if (state2 == NULL)
3600 		PG_RETURN_POINTER(state1);
3601 
3602 	/* manually copy all fields from state2 to state1 */
3603 	if (state1 == NULL)
3604 	{
3605 		old_context = MemoryContextSwitchTo(agg_context);
3606 
3607 		state1 = makeNumericAggStateCurrentContext(true);
3608 		state1->N = state2->N;
3609 		state1->NaNcount = state2->NaNcount;
3610 		state1->maxScale = state2->maxScale;
3611 		state1->maxScaleCount = state2->maxScaleCount;
3612 
3613 		accum_sum_copy(&state1->sumX, &state2->sumX);
3614 		accum_sum_copy(&state1->sumX2, &state2->sumX2);
3615 
3616 		MemoryContextSwitchTo(old_context);
3617 
3618 		PG_RETURN_POINTER(state1);
3619 	}
3620 
3621 	state1->N += state2->N;
3622 	state1->NaNcount += state2->NaNcount;
3623 
3624 	if (state2->N > 0)
3625 	{
3626 		/*
3627 		 * These are currently only needed for moving aggregates, but let's do
3628 		 * the right thing anyway...
3629 		 */
3630 		if (state2->maxScale > state1->maxScale)
3631 		{
3632 			state1->maxScale = state2->maxScale;
3633 			state1->maxScaleCount = state2->maxScaleCount;
3634 		}
3635 		else if (state2->maxScale == state1->maxScale)
3636 			state1->maxScaleCount += state2->maxScaleCount;
3637 
3638 		/* The rest of this needs to work in the aggregate context */
3639 		old_context = MemoryContextSwitchTo(agg_context);
3640 
3641 		/* Accumulate sums */
3642 		accum_sum_combine(&state1->sumX, &state2->sumX);
3643 		accum_sum_combine(&state1->sumX2, &state2->sumX2);
3644 
3645 		MemoryContextSwitchTo(old_context);
3646 	}
3647 	PG_RETURN_POINTER(state1);
3648 }
3649 
3650 /*
3651  * Generic transition function for numeric aggregates that don't require sumX2.
3652  */
3653 Datum
numeric_avg_accum(PG_FUNCTION_ARGS)3654 numeric_avg_accum(PG_FUNCTION_ARGS)
3655 {
3656 	NumericAggState *state;
3657 
3658 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
3659 
3660 	/* Create the state data on the first call */
3661 	if (state == NULL)
3662 		state = makeNumericAggState(fcinfo, false);
3663 
3664 	if (!PG_ARGISNULL(1))
3665 		do_numeric_accum(state, PG_GETARG_NUMERIC(1));
3666 
3667 	PG_RETURN_POINTER(state);
3668 }
3669 
3670 /*
3671  * Combine function for numeric aggregates which don't require sumX2
3672  */
3673 Datum
numeric_avg_combine(PG_FUNCTION_ARGS)3674 numeric_avg_combine(PG_FUNCTION_ARGS)
3675 {
3676 	NumericAggState *state1;
3677 	NumericAggState *state2;
3678 	MemoryContext agg_context;
3679 	MemoryContext old_context;
3680 
3681 	if (!AggCheckCallContext(fcinfo, &agg_context))
3682 		elog(ERROR, "aggregate function called in non-aggregate context");
3683 
3684 	state1 = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
3685 	state2 = PG_ARGISNULL(1) ? NULL : (NumericAggState *) PG_GETARG_POINTER(1);
3686 
3687 	if (state2 == NULL)
3688 		PG_RETURN_POINTER(state1);
3689 
3690 	/* manually copy all fields from state2 to state1 */
3691 	if (state1 == NULL)
3692 	{
3693 		old_context = MemoryContextSwitchTo(agg_context);
3694 
3695 		state1 = makeNumericAggStateCurrentContext(false);
3696 		state1->N = state2->N;
3697 		state1->NaNcount = state2->NaNcount;
3698 		state1->maxScale = state2->maxScale;
3699 		state1->maxScaleCount = state2->maxScaleCount;
3700 
3701 		accum_sum_copy(&state1->sumX, &state2->sumX);
3702 
3703 		MemoryContextSwitchTo(old_context);
3704 
3705 		PG_RETURN_POINTER(state1);
3706 	}
3707 
3708 	state1->N += state2->N;
3709 	state1->NaNcount += state2->NaNcount;
3710 
3711 	if (state2->N > 0)
3712 	{
3713 		/*
3714 		 * These are currently only needed for moving aggregates, but let's do
3715 		 * the right thing anyway...
3716 		 */
3717 		if (state2->maxScale > state1->maxScale)
3718 		{
3719 			state1->maxScale = state2->maxScale;
3720 			state1->maxScaleCount = state2->maxScaleCount;
3721 		}
3722 		else if (state2->maxScale == state1->maxScale)
3723 			state1->maxScaleCount += state2->maxScaleCount;
3724 
3725 		/* The rest of this needs to work in the aggregate context */
3726 		old_context = MemoryContextSwitchTo(agg_context);
3727 
3728 		/* Accumulate sums */
3729 		accum_sum_combine(&state1->sumX, &state2->sumX);
3730 
3731 		MemoryContextSwitchTo(old_context);
3732 	}
3733 	PG_RETURN_POINTER(state1);
3734 }
3735 
3736 /*
3737  * numeric_avg_serialize
3738  *		Serialize NumericAggState for numeric aggregates that don't require
3739  *		sumX2.
3740  */
3741 Datum
numeric_avg_serialize(PG_FUNCTION_ARGS)3742 numeric_avg_serialize(PG_FUNCTION_ARGS)
3743 {
3744 	NumericAggState *state;
3745 	StringInfoData buf;
3746 	Datum		temp;
3747 	bytea	   *sumX;
3748 	bytea	   *result;
3749 	NumericVar	tmp_var;
3750 
3751 	/* Ensure we disallow calling when not in aggregate context */
3752 	if (!AggCheckCallContext(fcinfo, NULL))
3753 		elog(ERROR, "aggregate function called in non-aggregate context");
3754 
3755 	state = (NumericAggState *) PG_GETARG_POINTER(0);
3756 
3757 	/*
3758 	 * This is a little wasteful since make_result converts the NumericVar
3759 	 * into a Numeric and numeric_send converts it back again. Is it worth
3760 	 * splitting the tasks in numeric_send into separate functions to stop
3761 	 * this? Doing so would also remove the fmgr call overhead.
3762 	 */
3763 	init_var(&tmp_var);
3764 	accum_sum_final(&state->sumX, &tmp_var);
3765 
3766 	temp = DirectFunctionCall1(numeric_send,
3767 							   NumericGetDatum(make_result(&tmp_var)));
3768 	sumX = DatumGetByteaPP(temp);
3769 	free_var(&tmp_var);
3770 
3771 	pq_begintypsend(&buf);
3772 
3773 	/* N */
3774 	pq_sendint64(&buf, state->N);
3775 
3776 	/* sumX */
3777 	pq_sendbytes(&buf, VARDATA_ANY(sumX), VARSIZE_ANY_EXHDR(sumX));
3778 
3779 	/* maxScale */
3780 	pq_sendint32(&buf, state->maxScale);
3781 
3782 	/* maxScaleCount */
3783 	pq_sendint64(&buf, state->maxScaleCount);
3784 
3785 	/* NaNcount */
3786 	pq_sendint64(&buf, state->NaNcount);
3787 
3788 	result = pq_endtypsend(&buf);
3789 
3790 	PG_RETURN_BYTEA_P(result);
3791 }
3792 
3793 /*
3794  * numeric_avg_deserialize
3795  *		Deserialize bytea into NumericAggState for numeric aggregates that
3796  *		don't require sumX2.
3797  */
3798 Datum
numeric_avg_deserialize(PG_FUNCTION_ARGS)3799 numeric_avg_deserialize(PG_FUNCTION_ARGS)
3800 {
3801 	bytea	   *sstate;
3802 	NumericAggState *result;
3803 	Datum		temp;
3804 	NumericVar	tmp_var;
3805 	StringInfoData buf;
3806 
3807 	if (!AggCheckCallContext(fcinfo, NULL))
3808 		elog(ERROR, "aggregate function called in non-aggregate context");
3809 
3810 	sstate = PG_GETARG_BYTEA_PP(0);
3811 
3812 	/*
3813 	 * Copy the bytea into a StringInfo so that we can "receive" it using the
3814 	 * standard recv-function infrastructure.
3815 	 */
3816 	initStringInfo(&buf);
3817 	appendBinaryStringInfo(&buf,
3818 						   VARDATA_ANY(sstate), VARSIZE_ANY_EXHDR(sstate));
3819 
3820 	result = makeNumericAggStateCurrentContext(false);
3821 
3822 	/* N */
3823 	result->N = pq_getmsgint64(&buf);
3824 
3825 	/* sumX */
3826 	temp = DirectFunctionCall3(numeric_recv,
3827 							   PointerGetDatum(&buf),
3828 							   ObjectIdGetDatum(InvalidOid),
3829 							   Int32GetDatum(-1));
3830 	init_var_from_num(DatumGetNumeric(temp), &tmp_var);
3831 	accum_sum_add(&(result->sumX), &tmp_var);
3832 
3833 	/* maxScale */
3834 	result->maxScale = pq_getmsgint(&buf, 4);
3835 
3836 	/* maxScaleCount */
3837 	result->maxScaleCount = pq_getmsgint64(&buf);
3838 
3839 	/* NaNcount */
3840 	result->NaNcount = pq_getmsgint64(&buf);
3841 
3842 	pq_getmsgend(&buf);
3843 	pfree(buf.data);
3844 
3845 	PG_RETURN_POINTER(result);
3846 }
3847 
3848 /*
3849  * numeric_serialize
3850  *		Serialization function for NumericAggState for numeric aggregates that
3851  *		require sumX2.
3852  */
3853 Datum
numeric_serialize(PG_FUNCTION_ARGS)3854 numeric_serialize(PG_FUNCTION_ARGS)
3855 {
3856 	NumericAggState *state;
3857 	StringInfoData buf;
3858 	Datum		temp;
3859 	bytea	   *sumX;
3860 	NumericVar	tmp_var;
3861 	bytea	   *sumX2;
3862 	bytea	   *result;
3863 
3864 	/* Ensure we disallow calling when not in aggregate context */
3865 	if (!AggCheckCallContext(fcinfo, NULL))
3866 		elog(ERROR, "aggregate function called in non-aggregate context");
3867 
3868 	state = (NumericAggState *) PG_GETARG_POINTER(0);
3869 
3870 	/*
3871 	 * This is a little wasteful since make_result converts the NumericVar
3872 	 * into a Numeric and numeric_send converts it back again. Is it worth
3873 	 * splitting the tasks in numeric_send into separate functions to stop
3874 	 * this? Doing so would also remove the fmgr call overhead.
3875 	 */
3876 	init_var(&tmp_var);
3877 
3878 	accum_sum_final(&state->sumX, &tmp_var);
3879 	temp = DirectFunctionCall1(numeric_send,
3880 							   NumericGetDatum(make_result(&tmp_var)));
3881 	sumX = DatumGetByteaPP(temp);
3882 
3883 	accum_sum_final(&state->sumX2, &tmp_var);
3884 	temp = DirectFunctionCall1(numeric_send,
3885 							   NumericGetDatum(make_result(&tmp_var)));
3886 	sumX2 = DatumGetByteaPP(temp);
3887 
3888 	free_var(&tmp_var);
3889 
3890 	pq_begintypsend(&buf);
3891 
3892 	/* N */
3893 	pq_sendint64(&buf, state->N);
3894 
3895 	/* sumX */
3896 	pq_sendbytes(&buf, VARDATA_ANY(sumX), VARSIZE_ANY_EXHDR(sumX));
3897 
3898 	/* sumX2 */
3899 	pq_sendbytes(&buf, VARDATA_ANY(sumX2), VARSIZE_ANY_EXHDR(sumX2));
3900 
3901 	/* maxScale */
3902 	pq_sendint32(&buf, state->maxScale);
3903 
3904 	/* maxScaleCount */
3905 	pq_sendint64(&buf, state->maxScaleCount);
3906 
3907 	/* NaNcount */
3908 	pq_sendint64(&buf, state->NaNcount);
3909 
3910 	result = pq_endtypsend(&buf);
3911 
3912 	PG_RETURN_BYTEA_P(result);
3913 }
3914 
3915 /*
3916  * numeric_deserialize
3917  *		Deserialization function for NumericAggState for numeric aggregates that
3918  *		require sumX2.
3919  */
3920 Datum
numeric_deserialize(PG_FUNCTION_ARGS)3921 numeric_deserialize(PG_FUNCTION_ARGS)
3922 {
3923 	bytea	   *sstate;
3924 	NumericAggState *result;
3925 	Datum		temp;
3926 	NumericVar	sumX_var;
3927 	NumericVar	sumX2_var;
3928 	StringInfoData buf;
3929 
3930 	if (!AggCheckCallContext(fcinfo, NULL))
3931 		elog(ERROR, "aggregate function called in non-aggregate context");
3932 
3933 	sstate = PG_GETARG_BYTEA_PP(0);
3934 
3935 	/*
3936 	 * Copy the bytea into a StringInfo so that we can "receive" it using the
3937 	 * standard recv-function infrastructure.
3938 	 */
3939 	initStringInfo(&buf);
3940 	appendBinaryStringInfo(&buf,
3941 						   VARDATA_ANY(sstate), VARSIZE_ANY_EXHDR(sstate));
3942 
3943 	result = makeNumericAggStateCurrentContext(false);
3944 
3945 	/* N */
3946 	result->N = pq_getmsgint64(&buf);
3947 
3948 	/* sumX */
3949 	temp = DirectFunctionCall3(numeric_recv,
3950 							   PointerGetDatum(&buf),
3951 							   ObjectIdGetDatum(InvalidOid),
3952 							   Int32GetDatum(-1));
3953 	init_var_from_num(DatumGetNumeric(temp), &sumX_var);
3954 	accum_sum_add(&(result->sumX), &sumX_var);
3955 
3956 	/* sumX2 */
3957 	temp = DirectFunctionCall3(numeric_recv,
3958 							   PointerGetDatum(&buf),
3959 							   ObjectIdGetDatum(InvalidOid),
3960 							   Int32GetDatum(-1));
3961 	init_var_from_num(DatumGetNumeric(temp), &sumX2_var);
3962 	accum_sum_add(&(result->sumX2), &sumX2_var);
3963 
3964 	/* maxScale */
3965 	result->maxScale = pq_getmsgint(&buf, 4);
3966 
3967 	/* maxScaleCount */
3968 	result->maxScaleCount = pq_getmsgint64(&buf);
3969 
3970 	/* NaNcount */
3971 	result->NaNcount = pq_getmsgint64(&buf);
3972 
3973 	pq_getmsgend(&buf);
3974 	pfree(buf.data);
3975 
3976 	PG_RETURN_POINTER(result);
3977 }
3978 
3979 /*
3980  * Generic inverse transition function for numeric aggregates
3981  * (with or without requirement for X^2).
3982  */
3983 Datum
numeric_accum_inv(PG_FUNCTION_ARGS)3984 numeric_accum_inv(PG_FUNCTION_ARGS)
3985 {
3986 	NumericAggState *state;
3987 
3988 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
3989 
3990 	/* Should not get here with no state */
3991 	if (state == NULL)
3992 		elog(ERROR, "numeric_accum_inv called with NULL state");
3993 
3994 	if (!PG_ARGISNULL(1))
3995 	{
3996 		/* If we fail to perform the inverse transition, return NULL */
3997 		if (!do_numeric_discard(state, PG_GETARG_NUMERIC(1)))
3998 			PG_RETURN_NULL();
3999 	}
4000 
4001 	PG_RETURN_POINTER(state);
4002 }
4003 
4004 
4005 /*
4006  * Integer data types in general use Numeric accumulators to share code
4007  * and avoid risk of overflow.
4008  *
4009  * However for performance reasons optimized special-purpose accumulator
4010  * routines are used when possible.
4011  *
 * On platforms with 128-bit integer support, the 128-bit routines are used
 * when the sums being computed (sum(X), and sum(X*X) if needed) fit in
 * 128 bits.
4014  *
 * For 16- and 32-bit inputs, N and sum(X) fit in 64 bits, so the 64-bit
 * accumulators are used for SUM and AVG of these data types.
4017  */
4018 
#ifdef HAVE_INT128
/*
 * Transition state used by the 128-bit accumulator routines: a count plus a
 * 128-bit sum and (when calcSumX2) sum of squares.  Unlike NumericAggState,
 * it tracks no NaN count and no dscale information.
 */
typedef struct Int128AggState
{
	bool		calcSumX2;		/* if true, calculate sumX2 */
	int64		N;				/* count of processed numbers */
	int128		sumX;			/* sum of processed numbers */
	int128		sumX2;			/* sum of squares of processed numbers */
} Int128AggState;
4027 
4028 /*
4029  * Prepare state data for a 128-bit aggregate function that needs to compute
4030  * sum, count and optionally sum of squares of the input.
4031  */
4032 static Int128AggState *
makeInt128AggState(FunctionCallInfo fcinfo,bool calcSumX2)4033 makeInt128AggState(FunctionCallInfo fcinfo, bool calcSumX2)
4034 {
4035 	Int128AggState *state;
4036 	MemoryContext agg_context;
4037 	MemoryContext old_context;
4038 
4039 	if (!AggCheckCallContext(fcinfo, &agg_context))
4040 		elog(ERROR, "aggregate function called in non-aggregate context");
4041 
4042 	old_context = MemoryContextSwitchTo(agg_context);
4043 
4044 	state = (Int128AggState *) palloc0(sizeof(Int128AggState));
4045 	state->calcSumX2 = calcSumX2;
4046 
4047 	MemoryContextSwitchTo(old_context);
4048 
4049 	return state;
4050 }
4051 
4052 /*
4053  * Like makeInt128AggState(), but allocate the state in the current memory
4054  * context.
4055  */
4056 static Int128AggState *
makeInt128AggStateCurrentContext(bool calcSumX2)4057 makeInt128AggStateCurrentContext(bool calcSumX2)
4058 {
4059 	Int128AggState *state;
4060 
4061 	state = (Int128AggState *) palloc0(sizeof(Int128AggState));
4062 	state->calcSumX2 = calcSumX2;
4063 
4064 	return state;
4065 }
4066 
4067 /*
4068  * Accumulate a new input value for 128-bit aggregate functions.
4069  */
4070 static void
do_int128_accum(Int128AggState * state,int128 newval)4071 do_int128_accum(Int128AggState *state, int128 newval)
4072 {
4073 	if (state->calcSumX2)
4074 		state->sumX2 += newval * newval;
4075 
4076 	state->sumX += newval;
4077 	state->N++;
4078 }
4079 
4080 /*
4081  * Remove an input value from the aggregated state.
4082  */
4083 static void
do_int128_discard(Int128AggState * state,int128 newval)4084 do_int128_discard(Int128AggState *state, int128 newval)
4085 {
4086 	if (state->calcSumX2)
4087 		state->sumX2 -= newval * newval;
4088 
4089 	state->sumX -= newval;
4090 	state->N--;
4091 }
4092 
/*
 * PolyNumAggState is the transition state for aggregates that can use
 * native 128-bit arithmetic when it is available.  With HAVE_INT128 it is
 * Int128AggState; otherwise it falls back to the generic NumericAggState.
 * The makePolyNumAggState* macros select the matching constructors.
 */
typedef Int128AggState PolyNumAggState;
#define makePolyNumAggState makeInt128AggState
#define makePolyNumAggStateCurrentContext makeInt128AggStateCurrentContext
#else
/* No 128-bit integers: use the Numeric-based state and constructors */
typedef NumericAggState PolyNumAggState;
#define makePolyNumAggState makeNumericAggState
#define makePolyNumAggStateCurrentContext makeNumericAggStateCurrentContext
#endif
4101 
4102 Datum
int2_accum(PG_FUNCTION_ARGS)4103 int2_accum(PG_FUNCTION_ARGS)
4104 {
4105 	PolyNumAggState *state;
4106 
4107 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
4108 
4109 	/* Create the state data on the first call */
4110 	if (state == NULL)
4111 		state = makePolyNumAggState(fcinfo, true);
4112 
4113 	if (!PG_ARGISNULL(1))
4114 	{
4115 #ifdef HAVE_INT128
4116 		do_int128_accum(state, (int128) PG_GETARG_INT16(1));
4117 #else
4118 		Numeric		newval;
4119 
4120 		newval = DatumGetNumeric(DirectFunctionCall1(int2_numeric,
4121 													 PG_GETARG_DATUM(1)));
4122 		do_numeric_accum(state, newval);
4123 #endif
4124 	}
4125 
4126 	PG_RETURN_POINTER(state);
4127 }
4128 
4129 Datum
int4_accum(PG_FUNCTION_ARGS)4130 int4_accum(PG_FUNCTION_ARGS)
4131 {
4132 	PolyNumAggState *state;
4133 
4134 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
4135 
4136 	/* Create the state data on the first call */
4137 	if (state == NULL)
4138 		state = makePolyNumAggState(fcinfo, true);
4139 
4140 	if (!PG_ARGISNULL(1))
4141 	{
4142 #ifdef HAVE_INT128
4143 		do_int128_accum(state, (int128) PG_GETARG_INT32(1));
4144 #else
4145 		Numeric		newval;
4146 
4147 		newval = DatumGetNumeric(DirectFunctionCall1(int4_numeric,
4148 													 PG_GETARG_DATUM(1)));
4149 		do_numeric_accum(state, newval);
4150 #endif
4151 	}
4152 
4153 	PG_RETURN_POINTER(state);
4154 }
4155 
4156 Datum
int8_accum(PG_FUNCTION_ARGS)4157 int8_accum(PG_FUNCTION_ARGS)
4158 {
4159 	NumericAggState *state;
4160 
4161 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
4162 
4163 	/* Create the state data on the first call */
4164 	if (state == NULL)
4165 		state = makeNumericAggState(fcinfo, true);
4166 
4167 	if (!PG_ARGISNULL(1))
4168 	{
4169 		Numeric		newval;
4170 
4171 		newval = DatumGetNumeric(DirectFunctionCall1(int8_numeric,
4172 													 PG_GETARG_DATUM(1)));
4173 		do_numeric_accum(state, newval);
4174 	}
4175 
4176 	PG_RETURN_POINTER(state);
4177 }
4178 
4179 /*
4180  * Combine function for numeric aggregates which require sumX2
4181  */
4182 Datum
numeric_poly_combine(PG_FUNCTION_ARGS)4183 numeric_poly_combine(PG_FUNCTION_ARGS)
4184 {
4185 	PolyNumAggState *state1;
4186 	PolyNumAggState *state2;
4187 	MemoryContext agg_context;
4188 	MemoryContext old_context;
4189 
4190 	if (!AggCheckCallContext(fcinfo, &agg_context))
4191 		elog(ERROR, "aggregate function called in non-aggregate context");
4192 
4193 	state1 = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
4194 	state2 = PG_ARGISNULL(1) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(1);
4195 
4196 	if (state2 == NULL)
4197 		PG_RETURN_POINTER(state1);
4198 
4199 	/* manually copy all fields from state2 to state1 */
4200 	if (state1 == NULL)
4201 	{
4202 		old_context = MemoryContextSwitchTo(agg_context);
4203 
4204 		state1 = makePolyNumAggState(fcinfo, true);
4205 		state1->N = state2->N;
4206 
4207 #ifdef HAVE_INT128
4208 		state1->sumX = state2->sumX;
4209 		state1->sumX2 = state2->sumX2;
4210 #else
4211 		accum_sum_copy(&state1->sumX, &state2->sumX);
4212 		accum_sum_copy(&state1->sumX2, &state2->sumX2);
4213 #endif
4214 
4215 		MemoryContextSwitchTo(old_context);
4216 
4217 		PG_RETURN_POINTER(state1);
4218 	}
4219 
4220 	if (state2->N > 0)
4221 	{
4222 		state1->N += state2->N;
4223 
4224 #ifdef HAVE_INT128
4225 		state1->sumX += state2->sumX;
4226 		state1->sumX2 += state2->sumX2;
4227 #else
4228 		/* The rest of this needs to work in the aggregate context */
4229 		old_context = MemoryContextSwitchTo(agg_context);
4230 
4231 		/* Accumulate sums */
4232 		accum_sum_combine(&state1->sumX, &state2->sumX);
4233 		accum_sum_combine(&state1->sumX2, &state2->sumX2);
4234 
4235 		MemoryContextSwitchTo(old_context);
4236 #endif
4237 
4238 	}
4239 	PG_RETURN_POINTER(state1);
4240 }
4241 
4242 /*
4243  * numeric_poly_serialize
4244  *		Serialize PolyNumAggState into bytea for aggregate functions which
4245  *		require sumX2.
4246  */
4247 Datum
numeric_poly_serialize(PG_FUNCTION_ARGS)4248 numeric_poly_serialize(PG_FUNCTION_ARGS)
4249 {
4250 	PolyNumAggState *state;
4251 	StringInfoData buf;
4252 	bytea	   *sumX;
4253 	bytea	   *sumX2;
4254 	bytea	   *result;
4255 
4256 	/* Ensure we disallow calling when not in aggregate context */
4257 	if (!AggCheckCallContext(fcinfo, NULL))
4258 		elog(ERROR, "aggregate function called in non-aggregate context");
4259 
4260 	state = (PolyNumAggState *) PG_GETARG_POINTER(0);
4261 
4262 	/*
4263 	 * If the platform supports int128 then sumX and sumX2 will be a 128 bit
4264 	 * integer type. Here we'll convert that into a numeric type so that the
4265 	 * combine state is in the same format for both int128 enabled machines
4266 	 * and machines which don't support that type. The logic here is that one
4267 	 * day we might like to send these over to another server for further
4268 	 * processing and we want a standard format to work with.
4269 	 */
4270 	{
4271 		Datum		temp;
4272 		NumericVar	num;
4273 
4274 		init_var(&num);
4275 
4276 #ifdef HAVE_INT128
4277 		int128_to_numericvar(state->sumX, &num);
4278 #else
4279 		accum_sum_final(&state->sumX, &num);
4280 #endif
4281 		temp = DirectFunctionCall1(numeric_send,
4282 								   NumericGetDatum(make_result(&num)));
4283 		sumX = DatumGetByteaPP(temp);
4284 
4285 #ifdef HAVE_INT128
4286 		int128_to_numericvar(state->sumX2, &num);
4287 #else
4288 		accum_sum_final(&state->sumX2, &num);
4289 #endif
4290 		temp = DirectFunctionCall1(numeric_send,
4291 								   NumericGetDatum(make_result(&num)));
4292 		sumX2 = DatumGetByteaPP(temp);
4293 
4294 		free_var(&num);
4295 	}
4296 
4297 	pq_begintypsend(&buf);
4298 
4299 	/* N */
4300 	pq_sendint64(&buf, state->N);
4301 
4302 	/* sumX */
4303 	pq_sendbytes(&buf, VARDATA_ANY(sumX), VARSIZE_ANY_EXHDR(sumX));
4304 
4305 	/* sumX2 */
4306 	pq_sendbytes(&buf, VARDATA_ANY(sumX2), VARSIZE_ANY_EXHDR(sumX2));
4307 
4308 	result = pq_endtypsend(&buf);
4309 
4310 	PG_RETURN_BYTEA_P(result);
4311 }
4312 
4313 /*
4314  * numeric_poly_deserialize
4315  *		Deserialize PolyNumAggState from bytea for aggregate functions which
4316  *		require sumX2.
4317  */
4318 Datum
numeric_poly_deserialize(PG_FUNCTION_ARGS)4319 numeric_poly_deserialize(PG_FUNCTION_ARGS)
4320 {
4321 	bytea	   *sstate;
4322 	PolyNumAggState *result;
4323 	Datum		sumX;
4324 	NumericVar	sumX_var;
4325 	Datum		sumX2;
4326 	NumericVar	sumX2_var;
4327 	StringInfoData buf;
4328 
4329 	if (!AggCheckCallContext(fcinfo, NULL))
4330 		elog(ERROR, "aggregate function called in non-aggregate context");
4331 
4332 	sstate = PG_GETARG_BYTEA_PP(0);
4333 
4334 	/*
4335 	 * Copy the bytea into a StringInfo so that we can "receive" it using the
4336 	 * standard recv-function infrastructure.
4337 	 */
4338 	initStringInfo(&buf);
4339 	appendBinaryStringInfo(&buf,
4340 						   VARDATA_ANY(sstate), VARSIZE_ANY_EXHDR(sstate));
4341 
4342 	result = makePolyNumAggStateCurrentContext(false);
4343 
4344 	/* N */
4345 	result->N = pq_getmsgint64(&buf);
4346 
4347 	/* sumX */
4348 	sumX = DirectFunctionCall3(numeric_recv,
4349 							   PointerGetDatum(&buf),
4350 							   ObjectIdGetDatum(InvalidOid),
4351 							   Int32GetDatum(-1));
4352 
4353 	/* sumX2 */
4354 	sumX2 = DirectFunctionCall3(numeric_recv,
4355 								PointerGetDatum(&buf),
4356 								ObjectIdGetDatum(InvalidOid),
4357 								Int32GetDatum(-1));
4358 
4359 	init_var_from_num(DatumGetNumeric(sumX), &sumX_var);
4360 #ifdef HAVE_INT128
4361 	numericvar_to_int128(&sumX_var, &result->sumX);
4362 #else
4363 	accum_sum_add(&result->sumX, &sumX_var);
4364 #endif
4365 
4366 	init_var_from_num(DatumGetNumeric(sumX2), &sumX2_var);
4367 #ifdef HAVE_INT128
4368 	numericvar_to_int128(&sumX2_var, &result->sumX2);
4369 #else
4370 	accum_sum_add(&result->sumX2, &sumX2_var);
4371 #endif
4372 
4373 	pq_getmsgend(&buf);
4374 	pfree(buf.data);
4375 
4376 	PG_RETURN_POINTER(result);
4377 }
4378 
4379 /*
4380  * Transition function for int8 input when we don't need sumX2.
4381  */
4382 Datum
int8_avg_accum(PG_FUNCTION_ARGS)4383 int8_avg_accum(PG_FUNCTION_ARGS)
4384 {
4385 	PolyNumAggState *state;
4386 
4387 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
4388 
4389 	/* Create the state data on the first call */
4390 	if (state == NULL)
4391 		state = makePolyNumAggState(fcinfo, false);
4392 
4393 	if (!PG_ARGISNULL(1))
4394 	{
4395 #ifdef HAVE_INT128
4396 		do_int128_accum(state, (int128) PG_GETARG_INT64(1));
4397 #else
4398 		Numeric		newval;
4399 
4400 		newval = DatumGetNumeric(DirectFunctionCall1(int8_numeric,
4401 													 PG_GETARG_DATUM(1)));
4402 		do_numeric_accum(state, newval);
4403 #endif
4404 	}
4405 
4406 	PG_RETURN_POINTER(state);
4407 }
4408 
4409 /*
4410  * Combine function for PolyNumAggState for aggregates which don't require
4411  * sumX2
4412  */
4413 Datum
int8_avg_combine(PG_FUNCTION_ARGS)4414 int8_avg_combine(PG_FUNCTION_ARGS)
4415 {
4416 	PolyNumAggState *state1;
4417 	PolyNumAggState *state2;
4418 	MemoryContext agg_context;
4419 	MemoryContext old_context;
4420 
4421 	if (!AggCheckCallContext(fcinfo, &agg_context))
4422 		elog(ERROR, "aggregate function called in non-aggregate context");
4423 
4424 	state1 = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
4425 	state2 = PG_ARGISNULL(1) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(1);
4426 
4427 	if (state2 == NULL)
4428 		PG_RETURN_POINTER(state1);
4429 
4430 	/* manually copy all fields from state2 to state1 */
4431 	if (state1 == NULL)
4432 	{
4433 		old_context = MemoryContextSwitchTo(agg_context);
4434 
4435 		state1 = makePolyNumAggState(fcinfo, false);
4436 		state1->N = state2->N;
4437 
4438 #ifdef HAVE_INT128
4439 		state1->sumX = state2->sumX;
4440 #else
4441 		accum_sum_copy(&state1->sumX, &state2->sumX);
4442 #endif
4443 		MemoryContextSwitchTo(old_context);
4444 
4445 		PG_RETURN_POINTER(state1);
4446 	}
4447 
4448 	if (state2->N > 0)
4449 	{
4450 		state1->N += state2->N;
4451 
4452 #ifdef HAVE_INT128
4453 		state1->sumX += state2->sumX;
4454 #else
4455 		/* The rest of this needs to work in the aggregate context */
4456 		old_context = MemoryContextSwitchTo(agg_context);
4457 
4458 		/* Accumulate sums */
4459 		accum_sum_combine(&state1->sumX, &state2->sumX);
4460 
4461 		MemoryContextSwitchTo(old_context);
4462 #endif
4463 
4464 	}
4465 	PG_RETURN_POINTER(state1);
4466 }
4467 
4468 /*
4469  * int8_avg_serialize
4470  *		Serialize PolyNumAggState into bytea using the standard
4471  *		recv-function infrastructure.
4472  */
4473 Datum
int8_avg_serialize(PG_FUNCTION_ARGS)4474 int8_avg_serialize(PG_FUNCTION_ARGS)
4475 {
4476 	PolyNumAggState *state;
4477 	StringInfoData buf;
4478 	bytea	   *sumX;
4479 	bytea	   *result;
4480 
4481 	/* Ensure we disallow calling when not in aggregate context */
4482 	if (!AggCheckCallContext(fcinfo, NULL))
4483 		elog(ERROR, "aggregate function called in non-aggregate context");
4484 
4485 	state = (PolyNumAggState *) PG_GETARG_POINTER(0);
4486 
4487 	/*
4488 	 * If the platform supports int128 then sumX will be a 128 integer type.
4489 	 * Here we'll convert that into a numeric type so that the combine state
4490 	 * is in the same format for both int128 enabled machines and machines
4491 	 * which don't support that type. The logic here is that one day we might
4492 	 * like to send these over to another server for further processing and we
4493 	 * want a standard format to work with.
4494 	 */
4495 	{
4496 		Datum		temp;
4497 		NumericVar	num;
4498 
4499 		init_var(&num);
4500 
4501 #ifdef HAVE_INT128
4502 		int128_to_numericvar(state->sumX, &num);
4503 #else
4504 		accum_sum_final(&state->sumX, &num);
4505 #endif
4506 		temp = DirectFunctionCall1(numeric_send,
4507 								   NumericGetDatum(make_result(&num)));
4508 		sumX = DatumGetByteaPP(temp);
4509 
4510 		free_var(&num);
4511 	}
4512 
4513 	pq_begintypsend(&buf);
4514 
4515 	/* N */
4516 	pq_sendint64(&buf, state->N);
4517 
4518 	/* sumX */
4519 	pq_sendbytes(&buf, VARDATA_ANY(sumX), VARSIZE_ANY_EXHDR(sumX));
4520 
4521 	result = pq_endtypsend(&buf);
4522 
4523 	PG_RETURN_BYTEA_P(result);
4524 }
4525 
4526 /*
4527  * int8_avg_deserialize
4528  *		Deserialize bytea back into PolyNumAggState.
4529  */
4530 Datum
int8_avg_deserialize(PG_FUNCTION_ARGS)4531 int8_avg_deserialize(PG_FUNCTION_ARGS)
4532 {
4533 	bytea	   *sstate;
4534 	PolyNumAggState *result;
4535 	StringInfoData buf;
4536 	Datum		temp;
4537 	NumericVar	num;
4538 
4539 	if (!AggCheckCallContext(fcinfo, NULL))
4540 		elog(ERROR, "aggregate function called in non-aggregate context");
4541 
4542 	sstate = PG_GETARG_BYTEA_PP(0);
4543 
4544 	/*
4545 	 * Copy the bytea into a StringInfo so that we can "receive" it using the
4546 	 * standard recv-function infrastructure.
4547 	 */
4548 	initStringInfo(&buf);
4549 	appendBinaryStringInfo(&buf,
4550 						   VARDATA_ANY(sstate), VARSIZE_ANY_EXHDR(sstate));
4551 
4552 	result = makePolyNumAggStateCurrentContext(false);
4553 
4554 	/* N */
4555 	result->N = pq_getmsgint64(&buf);
4556 
4557 	/* sumX */
4558 	temp = DirectFunctionCall3(numeric_recv,
4559 							   PointerGetDatum(&buf),
4560 							   ObjectIdGetDatum(InvalidOid),
4561 							   Int32GetDatum(-1));
4562 	init_var_from_num(DatumGetNumeric(temp), &num);
4563 #ifdef HAVE_INT128
4564 	numericvar_to_int128(&num, &result->sumX);
4565 #else
4566 	accum_sum_add(&result->sumX, &num);
4567 #endif
4568 
4569 	pq_getmsgend(&buf);
4570 	pfree(buf.data);
4571 
4572 	PG_RETURN_POINTER(result);
4573 }
4574 
4575 /*
4576  * Inverse transition functions to go with the above.
4577  */
4578 
4579 Datum
int2_accum_inv(PG_FUNCTION_ARGS)4580 int2_accum_inv(PG_FUNCTION_ARGS)
4581 {
4582 	PolyNumAggState *state;
4583 
4584 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
4585 
4586 	/* Should not get here with no state */
4587 	if (state == NULL)
4588 		elog(ERROR, "int2_accum_inv called with NULL state");
4589 
4590 	if (!PG_ARGISNULL(1))
4591 	{
4592 #ifdef HAVE_INT128
4593 		do_int128_discard(state, (int128) PG_GETARG_INT16(1));
4594 #else
4595 		Numeric		newval;
4596 
4597 		newval = DatumGetNumeric(DirectFunctionCall1(int2_numeric,
4598 													 PG_GETARG_DATUM(1)));
4599 
4600 		/* Should never fail, all inputs have dscale 0 */
4601 		if (!do_numeric_discard(state, newval))
4602 			elog(ERROR, "do_numeric_discard failed unexpectedly");
4603 #endif
4604 	}
4605 
4606 	PG_RETURN_POINTER(state);
4607 }
4608 
4609 Datum
int4_accum_inv(PG_FUNCTION_ARGS)4610 int4_accum_inv(PG_FUNCTION_ARGS)
4611 {
4612 	PolyNumAggState *state;
4613 
4614 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
4615 
4616 	/* Should not get here with no state */
4617 	if (state == NULL)
4618 		elog(ERROR, "int4_accum_inv called with NULL state");
4619 
4620 	if (!PG_ARGISNULL(1))
4621 	{
4622 #ifdef HAVE_INT128
4623 		do_int128_discard(state, (int128) PG_GETARG_INT32(1));
4624 #else
4625 		Numeric		newval;
4626 
4627 		newval = DatumGetNumeric(DirectFunctionCall1(int4_numeric,
4628 													 PG_GETARG_DATUM(1)));
4629 
4630 		/* Should never fail, all inputs have dscale 0 */
4631 		if (!do_numeric_discard(state, newval))
4632 			elog(ERROR, "do_numeric_discard failed unexpectedly");
4633 #endif
4634 	}
4635 
4636 	PG_RETURN_POINTER(state);
4637 }
4638 
4639 Datum
int8_accum_inv(PG_FUNCTION_ARGS)4640 int8_accum_inv(PG_FUNCTION_ARGS)
4641 {
4642 	NumericAggState *state;
4643 
4644 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
4645 
4646 	/* Should not get here with no state */
4647 	if (state == NULL)
4648 		elog(ERROR, "int8_accum_inv called with NULL state");
4649 
4650 	if (!PG_ARGISNULL(1))
4651 	{
4652 		Numeric		newval;
4653 
4654 		newval = DatumGetNumeric(DirectFunctionCall1(int8_numeric,
4655 													 PG_GETARG_DATUM(1)));
4656 
4657 		/* Should never fail, all inputs have dscale 0 */
4658 		if (!do_numeric_discard(state, newval))
4659 			elog(ERROR, "do_numeric_discard failed unexpectedly");
4660 	}
4661 
4662 	PG_RETURN_POINTER(state);
4663 }
4664 
4665 Datum
int8_avg_accum_inv(PG_FUNCTION_ARGS)4666 int8_avg_accum_inv(PG_FUNCTION_ARGS)
4667 {
4668 	PolyNumAggState *state;
4669 
4670 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
4671 
4672 	/* Should not get here with no state */
4673 	if (state == NULL)
4674 		elog(ERROR, "int8_avg_accum_inv called with NULL state");
4675 
4676 	if (!PG_ARGISNULL(1))
4677 	{
4678 #ifdef HAVE_INT128
4679 		do_int128_discard(state, (int128) PG_GETARG_INT64(1));
4680 #else
4681 		Numeric		newval;
4682 
4683 		newval = DatumGetNumeric(DirectFunctionCall1(int8_numeric,
4684 													 PG_GETARG_DATUM(1)));
4685 
4686 		/* Should never fail, all inputs have dscale 0 */
4687 		if (!do_numeric_discard(state, newval))
4688 			elog(ERROR, "do_numeric_discard failed unexpectedly");
4689 #endif
4690 	}
4691 
4692 	PG_RETURN_POINTER(state);
4693 }
4694 
4695 Datum
numeric_poly_sum(PG_FUNCTION_ARGS)4696 numeric_poly_sum(PG_FUNCTION_ARGS)
4697 {
4698 #ifdef HAVE_INT128
4699 	PolyNumAggState *state;
4700 	Numeric		res;
4701 	NumericVar	result;
4702 
4703 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
4704 
4705 	/* If there were no non-null inputs, return NULL */
4706 	if (state == NULL || state->N == 0)
4707 		PG_RETURN_NULL();
4708 
4709 	init_var(&result);
4710 
4711 	int128_to_numericvar(state->sumX, &result);
4712 
4713 	res = make_result(&result);
4714 
4715 	free_var(&result);
4716 
4717 	PG_RETURN_NUMERIC(res);
4718 #else
4719 	return numeric_sum(fcinfo);
4720 #endif
4721 }
4722 
4723 Datum
numeric_poly_avg(PG_FUNCTION_ARGS)4724 numeric_poly_avg(PG_FUNCTION_ARGS)
4725 {
4726 #ifdef HAVE_INT128
4727 	PolyNumAggState *state;
4728 	NumericVar	result;
4729 	Datum		countd,
4730 				sumd;
4731 
4732 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
4733 
4734 	/* If there were no non-null inputs, return NULL */
4735 	if (state == NULL || state->N == 0)
4736 		PG_RETURN_NULL();
4737 
4738 	init_var(&result);
4739 
4740 	int128_to_numericvar(state->sumX, &result);
4741 
4742 	countd = DirectFunctionCall1(int8_numeric,
4743 								 Int64GetDatumFast(state->N));
4744 	sumd = NumericGetDatum(make_result(&result));
4745 
4746 	free_var(&result);
4747 
4748 	PG_RETURN_DATUM(DirectFunctionCall2(numeric_div, sumd, countd));
4749 #else
4750 	return numeric_avg(fcinfo);
4751 #endif
4752 }
4753 
4754 Datum
numeric_avg(PG_FUNCTION_ARGS)4755 numeric_avg(PG_FUNCTION_ARGS)
4756 {
4757 	NumericAggState *state;
4758 	Datum		N_datum;
4759 	Datum		sumX_datum;
4760 	NumericVar	sumX_var;
4761 
4762 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
4763 
4764 	/* If there were no non-null inputs, return NULL */
4765 	if (state == NULL || (state->N + state->NaNcount) == 0)
4766 		PG_RETURN_NULL();
4767 
4768 	if (state->NaNcount > 0)	/* there was at least one NaN input */
4769 		PG_RETURN_NUMERIC(make_result(&const_nan));
4770 
4771 	N_datum = DirectFunctionCall1(int8_numeric, Int64GetDatum(state->N));
4772 
4773 	init_var(&sumX_var);
4774 	accum_sum_final(&state->sumX, &sumX_var);
4775 	sumX_datum = NumericGetDatum(make_result(&sumX_var));
4776 	free_var(&sumX_var);
4777 
4778 	PG_RETURN_DATUM(DirectFunctionCall2(numeric_div, sumX_datum, N_datum));
4779 }
4780 
4781 Datum
numeric_sum(PG_FUNCTION_ARGS)4782 numeric_sum(PG_FUNCTION_ARGS)
4783 {
4784 	NumericAggState *state;
4785 	NumericVar	sumX_var;
4786 	Numeric		result;
4787 
4788 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
4789 
4790 	/* If there were no non-null inputs, return NULL */
4791 	if (state == NULL || (state->N + state->NaNcount) == 0)
4792 		PG_RETURN_NULL();
4793 
4794 	if (state->NaNcount > 0)	/* there was at least one NaN input */
4795 		PG_RETURN_NUMERIC(make_result(&const_nan));
4796 
4797 	init_var(&sumX_var);
4798 	accum_sum_final(&state->sumX, &sumX_var);
4799 	result = make_result(&sumX_var);
4800 	free_var(&sumX_var);
4801 
4802 	PG_RETURN_NUMERIC(result);
4803 }
4804 
4805 /*
4806  * Workhorse routine for the standard deviance and variance
4807  * aggregates. 'state' is aggregate's transition state.
4808  * 'variance' specifies whether we should calculate the
4809  * variance or the standard deviation. 'sample' indicates whether the
4810  * caller is interested in the sample or the population
4811  * variance/stddev.
4812  *
4813  * If appropriate variance statistic is undefined for the input,
4814  * *is_null is set to true and NULL is returned.
4815  */
4816 static Numeric
numeric_stddev_internal(NumericAggState * state,bool variance,bool sample,bool * is_null)4817 numeric_stddev_internal(NumericAggState *state,
4818 						bool variance, bool sample,
4819 						bool *is_null)
4820 {
4821 	Numeric		res;
4822 	NumericVar	vN,
4823 				vsumX,
4824 				vsumX2,
4825 				vNminus1;
4826 	const NumericVar *comp;
4827 	int			rscale;
4828 
4829 	/* Deal with empty input and NaN-input cases */
4830 	if (state == NULL || (state->N + state->NaNcount) == 0)
4831 	{
4832 		*is_null = true;
4833 		return NULL;
4834 	}
4835 
4836 	*is_null = false;
4837 
4838 	if (state->NaNcount > 0)
4839 		return make_result(&const_nan);
4840 
4841 	init_var(&vN);
4842 	init_var(&vsumX);
4843 	init_var(&vsumX2);
4844 
4845 	int64_to_numericvar(state->N, &vN);
4846 	accum_sum_final(&(state->sumX), &vsumX);
4847 	accum_sum_final(&(state->sumX2), &vsumX2);
4848 
4849 	/*
4850 	 * Sample stddev and variance are undefined when N <= 1; population stddev
4851 	 * is undefined when N == 0. Return NULL in either case.
4852 	 */
4853 	if (sample)
4854 		comp = &const_one;
4855 	else
4856 		comp = &const_zero;
4857 
4858 	if (cmp_var(&vN, comp) <= 0)
4859 	{
4860 		*is_null = true;
4861 		return NULL;
4862 	}
4863 
4864 	init_var(&vNminus1);
4865 	sub_var(&vN, &const_one, &vNminus1);
4866 
4867 	/* compute rscale for mul_var calls */
4868 	rscale = vsumX.dscale * 2;
4869 
4870 	mul_var(&vsumX, &vsumX, &vsumX, rscale);	/* vsumX = sumX * sumX */
4871 	mul_var(&vN, &vsumX2, &vsumX2, rscale); /* vsumX2 = N * sumX2 */
4872 	sub_var(&vsumX2, &vsumX, &vsumX2);	/* N * sumX2 - sumX * sumX */
4873 
4874 	if (cmp_var(&vsumX2, &const_zero) <= 0)
4875 	{
4876 		/* Watch out for roundoff error producing a negative numerator */
4877 		res = make_result(&const_zero);
4878 	}
4879 	else
4880 	{
4881 		if (sample)
4882 			mul_var(&vN, &vNminus1, &vNminus1, 0);	/* N * (N - 1) */
4883 		else
4884 			mul_var(&vN, &vN, &vNminus1, 0);	/* N * N */
4885 		rscale = select_div_scale(&vsumX2, &vNminus1);
4886 		div_var(&vsumX2, &vNminus1, &vsumX, rscale, true);	/* variance */
4887 		if (!variance)
4888 			sqrt_var(&vsumX, &vsumX, rscale);	/* stddev */
4889 
4890 		res = make_result(&vsumX);
4891 	}
4892 
4893 	free_var(&vNminus1);
4894 	free_var(&vsumX);
4895 	free_var(&vsumX2);
4896 
4897 	return res;
4898 }
4899 
4900 Datum
numeric_var_samp(PG_FUNCTION_ARGS)4901 numeric_var_samp(PG_FUNCTION_ARGS)
4902 {
4903 	NumericAggState *state;
4904 	Numeric		res;
4905 	bool		is_null;
4906 
4907 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
4908 
4909 	res = numeric_stddev_internal(state, true, true, &is_null);
4910 
4911 	if (is_null)
4912 		PG_RETURN_NULL();
4913 	else
4914 		PG_RETURN_NUMERIC(res);
4915 }
4916 
4917 Datum
numeric_stddev_samp(PG_FUNCTION_ARGS)4918 numeric_stddev_samp(PG_FUNCTION_ARGS)
4919 {
4920 	NumericAggState *state;
4921 	Numeric		res;
4922 	bool		is_null;
4923 
4924 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
4925 
4926 	res = numeric_stddev_internal(state, false, true, &is_null);
4927 
4928 	if (is_null)
4929 		PG_RETURN_NULL();
4930 	else
4931 		PG_RETURN_NUMERIC(res);
4932 }
4933 
4934 Datum
numeric_var_pop(PG_FUNCTION_ARGS)4935 numeric_var_pop(PG_FUNCTION_ARGS)
4936 {
4937 	NumericAggState *state;
4938 	Numeric		res;
4939 	bool		is_null;
4940 
4941 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
4942 
4943 	res = numeric_stddev_internal(state, true, false, &is_null);
4944 
4945 	if (is_null)
4946 		PG_RETURN_NULL();
4947 	else
4948 		PG_RETURN_NUMERIC(res);
4949 }
4950 
4951 Datum
numeric_stddev_pop(PG_FUNCTION_ARGS)4952 numeric_stddev_pop(PG_FUNCTION_ARGS)
4953 {
4954 	NumericAggState *state;
4955 	Numeric		res;
4956 	bool		is_null;
4957 
4958 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
4959 
4960 	res = numeric_stddev_internal(state, false, false, &is_null);
4961 
4962 	if (is_null)
4963 		PG_RETURN_NULL();
4964 	else
4965 		PG_RETURN_NUMERIC(res);
4966 }
4967 
#ifdef HAVE_INT128
/*
 * numeric_poly_stddev_internal
 *		int128 front end for numeric_stddev_internal.
 *
 * Copies N, sumX and sumX2 from the Int128AggState into a temporary
 * NumericAggState, which starts out zeroed (and stays so when 'state' is
 * NULL).  It then delegates, and afterwards frees the accumulator digit
 * arrays that accum_sum_add allocated.
 */
static Numeric
numeric_poly_stddev_internal(Int128AggState *state,
							 bool variance, bool sample,
							 bool *is_null)
{
	NumericAggState numstate;
	Numeric		res;

	/* begin from an all-zero, i.e. empty, aggregate state */
	memset(&numstate, 0, sizeof(NumericAggState));

	if (state != NULL)
	{
		NumericVar	conv;

		init_var(&conv);
		numstate.N = state->N;

		/* widen each int128 sum to numeric and load it into its accumulator */
		int128_to_numericvar(state->sumX, &conv);
		accum_sum_add(&numstate.sumX, &conv);
		int128_to_numericvar(state->sumX2, &conv);
		accum_sum_add(&numstate.sumX2, &conv);

		free_var(&conv);
	}

	res = numeric_stddev_internal(&numstate, variance, sample, is_null);

	/* accumulators with no digits never allocated anything */
	if (numstate.sumX.ndigits > 0)
	{
		pfree(numstate.sumX.pos_digits);
		pfree(numstate.sumX.neg_digits);
	}
	if (numstate.sumX2.ndigits > 0)
	{
		pfree(numstate.sumX2.pos_digits);
		pfree(numstate.sumX2.neg_digits);
	}

	return res;
}
#endif
5013 
5014 Datum
numeric_poly_var_samp(PG_FUNCTION_ARGS)5015 numeric_poly_var_samp(PG_FUNCTION_ARGS)
5016 {
5017 #ifdef HAVE_INT128
5018 	PolyNumAggState *state;
5019 	Numeric		res;
5020 	bool		is_null;
5021 
5022 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
5023 
5024 	res = numeric_poly_stddev_internal(state, true, true, &is_null);
5025 
5026 	if (is_null)
5027 		PG_RETURN_NULL();
5028 	else
5029 		PG_RETURN_NUMERIC(res);
5030 #else
5031 	return numeric_var_samp(fcinfo);
5032 #endif
5033 }
5034 
5035 Datum
numeric_poly_stddev_samp(PG_FUNCTION_ARGS)5036 numeric_poly_stddev_samp(PG_FUNCTION_ARGS)
5037 {
5038 #ifdef HAVE_INT128
5039 	PolyNumAggState *state;
5040 	Numeric		res;
5041 	bool		is_null;
5042 
5043 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
5044 
5045 	res = numeric_poly_stddev_internal(state, false, true, &is_null);
5046 
5047 	if (is_null)
5048 		PG_RETURN_NULL();
5049 	else
5050 		PG_RETURN_NUMERIC(res);
5051 #else
5052 	return numeric_stddev_samp(fcinfo);
5053 #endif
5054 }
5055 
5056 Datum
numeric_poly_var_pop(PG_FUNCTION_ARGS)5057 numeric_poly_var_pop(PG_FUNCTION_ARGS)
5058 {
5059 #ifdef HAVE_INT128
5060 	PolyNumAggState *state;
5061 	Numeric		res;
5062 	bool		is_null;
5063 
5064 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
5065 
5066 	res = numeric_poly_stddev_internal(state, true, false, &is_null);
5067 
5068 	if (is_null)
5069 		PG_RETURN_NULL();
5070 	else
5071 		PG_RETURN_NUMERIC(res);
5072 #else
5073 	return numeric_var_pop(fcinfo);
5074 #endif
5075 }
5076 
5077 Datum
numeric_poly_stddev_pop(PG_FUNCTION_ARGS)5078 numeric_poly_stddev_pop(PG_FUNCTION_ARGS)
5079 {
5080 #ifdef HAVE_INT128
5081 	PolyNumAggState *state;
5082 	Numeric		res;
5083 	bool		is_null;
5084 
5085 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
5086 
5087 	res = numeric_poly_stddev_internal(state, false, false, &is_null);
5088 
5089 	if (is_null)
5090 		PG_RETURN_NULL();
5091 	else
5092 		PG_RETURN_NUMERIC(res);
5093 #else
5094 	return numeric_stddev_pop(fcinfo);
5095 #endif
5096 }
5097 
5098 /*
5099  * SUM transition functions for integer datatypes.
5100  *
5101  * To avoid overflow, we use accumulators wider than the input datatype.
5102  * A Numeric accumulator is needed for int8 input; for int4 and int2
5103  * inputs, we use int8 accumulators which should be sufficient for practical
5104  * purposes.  (The latter two therefore don't really belong in this file,
5105  * but we keep them here anyway.)
5106  *
5107  * Because SQL defines the SUM() of no values to be NULL, not zero,
5108  * the initial condition of the transition data value needs to be NULL. This
5109  * means we can't rely on ExecAgg to automatically insert the first non-null
5110  * data value into the transition data: it doesn't know how to do the type
5111  * conversion.  The upshot is that these routines have to be marked non-strict
5112  * and handle substitution of the first non-null input themselves.
5113  *
5114  * Note: these functions are used only in plain aggregation mode.
5115  * In moving-aggregate mode, we use intX_avg_accum and intX_avg_accum_inv.
5116  */
5117 
/*
 * int2_sum
 *		SUM(int2) transition function, accumulating into an int8 (int64).
 *		Stays NULL until the first non-null input; thereafter NULL inputs
 *		leave the running sum unchanged.
 */
Datum
int2_sum(PG_FUNCTION_ARGS)
{
	int64		newval;

	if (PG_ARGISNULL(0))
	{
		/* No non-null input seen so far... */
		if (PG_ARGISNULL(1))
			PG_RETURN_NULL();	/* still no non-null */
		/* This is the first non-null input. */
		newval = (int64) PG_GETARG_INT16(1);
		PG_RETURN_INT64(newval);
	}

	/*
	 * If we're invoked as an aggregate, we can cheat and modify our first
	 * parameter in-place to avoid palloc overhead. If not, we need to return
	 * the new value of the transition variable. (If int8 is pass-by-value,
	 * then of course this is useless as well as incorrect, so just ifdef it
	 * out.)
	 */
#ifndef USE_FLOAT8_BYVAL		/* controls int8 too */
	if (AggCheckCallContext(fcinfo, NULL))
	{
		int64	   *oldsum = (int64 *) PG_GETARG_POINTER(0);

		/* Leave the running sum unchanged if the new input is null */
		if (!PG_ARGISNULL(1))
			*oldsum = *oldsum + (int64) PG_GETARG_INT16(1);

		PG_RETURN_POINTER(oldsum);
	}
	else
#endif
	{
		int64		oldsum = PG_GETARG_INT64(0);

		/* Leave sum unchanged if new input is null. */
		if (PG_ARGISNULL(1))
			PG_RETURN_INT64(oldsum);

		/* OK to do the addition. */
		newval = oldsum + (int64) PG_GETARG_INT16(1);

		PG_RETURN_INT64(newval);
	}
}
5166 
/*
 * int4_sum
 *		SUM(int4) transition function, accumulating into an int8 (int64).
 *		Stays NULL until the first non-null input; thereafter NULL inputs
 *		leave the running sum unchanged.
 */
Datum
int4_sum(PG_FUNCTION_ARGS)
{
	int64		newval;

	if (PG_ARGISNULL(0))
	{
		/* No non-null input seen so far... */
		if (PG_ARGISNULL(1))
			PG_RETURN_NULL();	/* still no non-null */
		/* This is the first non-null input. */
		newval = (int64) PG_GETARG_INT32(1);
		PG_RETURN_INT64(newval);
	}

	/*
	 * If we're invoked as an aggregate, we can cheat and modify our first
	 * parameter in-place to avoid palloc overhead. If not, we need to return
	 * the new value of the transition variable. (If int8 is pass-by-value,
	 * then of course this is useless as well as incorrect, so just ifdef it
	 * out.)
	 */
#ifndef USE_FLOAT8_BYVAL		/* controls int8 too */
	if (AggCheckCallContext(fcinfo, NULL))
	{
		int64	   *oldsum = (int64 *) PG_GETARG_POINTER(0);

		/* Leave the running sum unchanged if the new input is null */
		if (!PG_ARGISNULL(1))
			*oldsum = *oldsum + (int64) PG_GETARG_INT32(1);

		PG_RETURN_POINTER(oldsum);
	}
	else
#endif
	{
		int64		oldsum = PG_GETARG_INT64(0);

		/* Leave sum unchanged if new input is null. */
		if (PG_ARGISNULL(1))
			PG_RETURN_INT64(oldsum);

		/* OK to do the addition. */
		newval = oldsum + (int64) PG_GETARG_INT32(1);

		PG_RETURN_INT64(newval);
	}
}
5215 
5216 /*
5217  * Note: this function is obsolete, it's no longer used for SUM(int8).
5218  */
5219 Datum
int8_sum(PG_FUNCTION_ARGS)5220 int8_sum(PG_FUNCTION_ARGS)
5221 {
5222 	Numeric		oldsum;
5223 	Datum		newval;
5224 
5225 	if (PG_ARGISNULL(0))
5226 	{
5227 		/* No non-null input seen so far... */
5228 		if (PG_ARGISNULL(1))
5229 			PG_RETURN_NULL();	/* still no non-null */
5230 		/* This is the first non-null input. */
5231 		newval = DirectFunctionCall1(int8_numeric, PG_GETARG_DATUM(1));
5232 		PG_RETURN_DATUM(newval);
5233 	}
5234 
5235 	/*
5236 	 * Note that we cannot special-case the aggregate case here, as we do for
5237 	 * int2_sum and int4_sum: numeric is of variable size, so we cannot modify
5238 	 * our first parameter in-place.
5239 	 */
5240 
5241 	oldsum = PG_GETARG_NUMERIC(0);
5242 
5243 	/* Leave sum unchanged if new input is null. */
5244 	if (PG_ARGISNULL(1))
5245 		PG_RETURN_NUMERIC(oldsum);
5246 
5247 	/* OK to do the addition. */
5248 	newval = DirectFunctionCall1(int8_numeric, PG_GETARG_DATUM(1));
5249 
5250 	PG_RETURN_DATUM(DirectFunctionCall2(numeric_add,
5251 										NumericGetDatum(oldsum), newval));
5252 }
5253 
5254 
5255 /*
5256  * Routines for avg(int2) and avg(int4).  The transition datatype
5257  * is a two-element int8 array, holding count and sum.
5258  *
5259  * These functions are also used for sum(int2) and sum(int4) when
5260  * operating in moving-aggregate mode, since for correct inverse transitions
5261  * we need to count the inputs.
5262  */
5263 
5264 typedef struct Int8TransTypeData
5265 {
5266 	int64		count;
5267 	int64		sum;
5268 } Int8TransTypeData;
5269 
5270 Datum
int2_avg_accum(PG_FUNCTION_ARGS)5271 int2_avg_accum(PG_FUNCTION_ARGS)
5272 {
5273 	ArrayType  *transarray;
5274 	int16		newval = PG_GETARG_INT16(1);
5275 	Int8TransTypeData *transdata;
5276 
5277 	/*
5278 	 * If we're invoked as an aggregate, we can cheat and modify our first
5279 	 * parameter in-place to reduce palloc overhead. Otherwise we need to make
5280 	 * a copy of it before scribbling on it.
5281 	 */
5282 	if (AggCheckCallContext(fcinfo, NULL))
5283 		transarray = PG_GETARG_ARRAYTYPE_P(0);
5284 	else
5285 		transarray = PG_GETARG_ARRAYTYPE_P_COPY(0);
5286 
5287 	if (ARR_HASNULL(transarray) ||
5288 		ARR_SIZE(transarray) != ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData))
5289 		elog(ERROR, "expected 2-element int8 array");
5290 
5291 	transdata = (Int8TransTypeData *) ARR_DATA_PTR(transarray);
5292 	transdata->count++;
5293 	transdata->sum += newval;
5294 
5295 	PG_RETURN_ARRAYTYPE_P(transarray);
5296 }
5297 
5298 Datum
int4_avg_accum(PG_FUNCTION_ARGS)5299 int4_avg_accum(PG_FUNCTION_ARGS)
5300 {
5301 	ArrayType  *transarray;
5302 	int32		newval = PG_GETARG_INT32(1);
5303 	Int8TransTypeData *transdata;
5304 
5305 	/*
5306 	 * If we're invoked as an aggregate, we can cheat and modify our first
5307 	 * parameter in-place to reduce palloc overhead. Otherwise we need to make
5308 	 * a copy of it before scribbling on it.
5309 	 */
5310 	if (AggCheckCallContext(fcinfo, NULL))
5311 		transarray = PG_GETARG_ARRAYTYPE_P(0);
5312 	else
5313 		transarray = PG_GETARG_ARRAYTYPE_P_COPY(0);
5314 
5315 	if (ARR_HASNULL(transarray) ||
5316 		ARR_SIZE(transarray) != ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData))
5317 		elog(ERROR, "expected 2-element int8 array");
5318 
5319 	transdata = (Int8TransTypeData *) ARR_DATA_PTR(transarray);
5320 	transdata->count++;
5321 	transdata->sum += newval;
5322 
5323 	PG_RETURN_ARRAYTYPE_P(transarray);
5324 }
5325 
5326 Datum
int4_avg_combine(PG_FUNCTION_ARGS)5327 int4_avg_combine(PG_FUNCTION_ARGS)
5328 {
5329 	ArrayType  *transarray1;
5330 	ArrayType  *transarray2;
5331 	Int8TransTypeData *state1;
5332 	Int8TransTypeData *state2;
5333 
5334 	if (!AggCheckCallContext(fcinfo, NULL))
5335 		elog(ERROR, "aggregate function called in non-aggregate context");
5336 
5337 	transarray1 = PG_GETARG_ARRAYTYPE_P(0);
5338 	transarray2 = PG_GETARG_ARRAYTYPE_P(1);
5339 
5340 	if (ARR_HASNULL(transarray1) ||
5341 		ARR_SIZE(transarray1) != ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData))
5342 		elog(ERROR, "expected 2-element int8 array");
5343 
5344 	if (ARR_HASNULL(transarray2) ||
5345 		ARR_SIZE(transarray2) != ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData))
5346 		elog(ERROR, "expected 2-element int8 array");
5347 
5348 	state1 = (Int8TransTypeData *) ARR_DATA_PTR(transarray1);
5349 	state2 = (Int8TransTypeData *) ARR_DATA_PTR(transarray2);
5350 
5351 	state1->count += state2->count;
5352 	state1->sum += state2->sum;
5353 
5354 	PG_RETURN_ARRAYTYPE_P(transarray1);
5355 }
5356 
5357 Datum
int2_avg_accum_inv(PG_FUNCTION_ARGS)5358 int2_avg_accum_inv(PG_FUNCTION_ARGS)
5359 {
5360 	ArrayType  *transarray;
5361 	int16		newval = PG_GETARG_INT16(1);
5362 	Int8TransTypeData *transdata;
5363 
5364 	/*
5365 	 * If we're invoked as an aggregate, we can cheat and modify our first
5366 	 * parameter in-place to reduce palloc overhead. Otherwise we need to make
5367 	 * a copy of it before scribbling on it.
5368 	 */
5369 	if (AggCheckCallContext(fcinfo, NULL))
5370 		transarray = PG_GETARG_ARRAYTYPE_P(0);
5371 	else
5372 		transarray = PG_GETARG_ARRAYTYPE_P_COPY(0);
5373 
5374 	if (ARR_HASNULL(transarray) ||
5375 		ARR_SIZE(transarray) != ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData))
5376 		elog(ERROR, "expected 2-element int8 array");
5377 
5378 	transdata = (Int8TransTypeData *) ARR_DATA_PTR(transarray);
5379 	transdata->count--;
5380 	transdata->sum -= newval;
5381 
5382 	PG_RETURN_ARRAYTYPE_P(transarray);
5383 }
5384 
5385 Datum
int4_avg_accum_inv(PG_FUNCTION_ARGS)5386 int4_avg_accum_inv(PG_FUNCTION_ARGS)
5387 {
5388 	ArrayType  *transarray;
5389 	int32		newval = PG_GETARG_INT32(1);
5390 	Int8TransTypeData *transdata;
5391 
5392 	/*
5393 	 * If we're invoked as an aggregate, we can cheat and modify our first
5394 	 * parameter in-place to reduce palloc overhead. Otherwise we need to make
5395 	 * a copy of it before scribbling on it.
5396 	 */
5397 	if (AggCheckCallContext(fcinfo, NULL))
5398 		transarray = PG_GETARG_ARRAYTYPE_P(0);
5399 	else
5400 		transarray = PG_GETARG_ARRAYTYPE_P_COPY(0);
5401 
5402 	if (ARR_HASNULL(transarray) ||
5403 		ARR_SIZE(transarray) != ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData))
5404 		elog(ERROR, "expected 2-element int8 array");
5405 
5406 	transdata = (Int8TransTypeData *) ARR_DATA_PTR(transarray);
5407 	transdata->count--;
5408 	transdata->sum -= newval;
5409 
5410 	PG_RETURN_ARRAYTYPE_P(transarray);
5411 }
5412 
5413 Datum
int8_avg(PG_FUNCTION_ARGS)5414 int8_avg(PG_FUNCTION_ARGS)
5415 {
5416 	ArrayType  *transarray = PG_GETARG_ARRAYTYPE_P(0);
5417 	Int8TransTypeData *transdata;
5418 	Datum		countd,
5419 				sumd;
5420 
5421 	if (ARR_HASNULL(transarray) ||
5422 		ARR_SIZE(transarray) != ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData))
5423 		elog(ERROR, "expected 2-element int8 array");
5424 	transdata = (Int8TransTypeData *) ARR_DATA_PTR(transarray);
5425 
5426 	/* SQL defines AVG of no values to be NULL */
5427 	if (transdata->count == 0)
5428 		PG_RETURN_NULL();
5429 
5430 	countd = DirectFunctionCall1(int8_numeric,
5431 								 Int64GetDatumFast(transdata->count));
5432 	sumd = DirectFunctionCall1(int8_numeric,
5433 							   Int64GetDatumFast(transdata->sum));
5434 
5435 	PG_RETURN_DATUM(DirectFunctionCall2(numeric_div, sumd, countd));
5436 }
5437 
5438 /*
5439  * SUM(int2) and SUM(int4) both return int8, so we can use this
5440  * final function for both.
5441  */
5442 Datum
int2int4_sum(PG_FUNCTION_ARGS)5443 int2int4_sum(PG_FUNCTION_ARGS)
5444 {
5445 	ArrayType  *transarray = PG_GETARG_ARRAYTYPE_P(0);
5446 	Int8TransTypeData *transdata;
5447 
5448 	if (ARR_HASNULL(transarray) ||
5449 		ARR_SIZE(transarray) != ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData))
5450 		elog(ERROR, "expected 2-element int8 array");
5451 	transdata = (Int8TransTypeData *) ARR_DATA_PTR(transarray);
5452 
5453 	/* SQL defines SUM of no values to be NULL */
5454 	if (transdata->count == 0)
5455 		PG_RETURN_NULL();
5456 
5457 	PG_RETURN_DATUM(Int64GetDatumFast(transdata->sum));
5458 }
5459 
5460 
5461 /* ----------------------------------------------------------------------
5462  *
5463  * Debug support
5464  *
5465  * ----------------------------------------------------------------------
5466  */
5467 
5468 #ifdef NUMERIC_DEBUG
5469 
5470 /*
5471  * dump_numeric() - Dump a value in the db storage format for debugging
5472  */
5473 static void
dump_numeric(const char * str,Numeric num)5474 dump_numeric(const char *str, Numeric num)
5475 {
5476 	NumericDigit *digits = NUMERIC_DIGITS(num);
5477 	int			ndigits;
5478 	int			i;
5479 
5480 	ndigits = NUMERIC_NDIGITS(num);
5481 
5482 	printf("%s: NUMERIC w=%d d=%d ", str,
5483 		   NUMERIC_WEIGHT(num), NUMERIC_DSCALE(num));
5484 	switch (NUMERIC_SIGN(num))
5485 	{
5486 		case NUMERIC_POS:
5487 			printf("POS");
5488 			break;
5489 		case NUMERIC_NEG:
5490 			printf("NEG");
5491 			break;
5492 		case NUMERIC_NAN:
5493 			printf("NaN");
5494 			break;
5495 		default:
5496 			printf("SIGN=0x%x", NUMERIC_SIGN(num));
5497 			break;
5498 	}
5499 
5500 	for (i = 0; i < ndigits; i++)
5501 		printf(" %0*d", DEC_DIGITS, digits[i]);
5502 	printf("\n");
5503 }
5504 
5505 
5506 /*
5507  * dump_var() - Dump a value in the variable format for debugging
5508  */
5509 static void
dump_var(const char * str,NumericVar * var)5510 dump_var(const char *str, NumericVar *var)
5511 {
5512 	int			i;
5513 
5514 	printf("%s: VAR w=%d d=%d ", str, var->weight, var->dscale);
5515 	switch (var->sign)
5516 	{
5517 		case NUMERIC_POS:
5518 			printf("POS");
5519 			break;
5520 		case NUMERIC_NEG:
5521 			printf("NEG");
5522 			break;
5523 		case NUMERIC_NAN:
5524 			printf("NaN");
5525 			break;
5526 		default:
5527 			printf("SIGN=0x%x", var->sign);
5528 			break;
5529 	}
5530 
5531 	for (i = 0; i < var->ndigits; i++)
5532 		printf(" %0*d", DEC_DIGITS, var->digits[i]);
5533 
5534 	printf("\n");
5535 }
5536 #endif							/* NUMERIC_DEBUG */
5537 
5538 
5539 /* ----------------------------------------------------------------------
5540  *
5541  * Local functions follow
5542  *
5543  * In general, these do not support NaNs --- callers must eliminate
5544  * the possibility of NaN first.  (make_result() is an exception.)
5545  *
5546  * ----------------------------------------------------------------------
5547  */
5548 
5549 
5550 /*
5551  * alloc_var() -
5552  *
5553  *	Allocate a digit buffer of ndigits digits (plus a spare digit for rounding)
5554  */
5555 static void
alloc_var(NumericVar * var,int ndigits)5556 alloc_var(NumericVar *var, int ndigits)
5557 {
5558 	digitbuf_free(var->buf);
5559 	var->buf = digitbuf_alloc(ndigits + 1);
5560 	var->buf[0] = 0;			/* spare digit for rounding */
5561 	var->digits = var->buf + 1;
5562 	var->ndigits = ndigits;
5563 }
5564 
5565 
5566 /*
5567  * free_var() -
5568  *
5569  *	Return the digit buffer of a variable to the free pool
5570  */
5571 static void
free_var(NumericVar * var)5572 free_var(NumericVar *var)
5573 {
5574 	digitbuf_free(var->buf);
5575 	var->buf = NULL;
5576 	var->digits = NULL;
5577 	var->sign = NUMERIC_NAN;
5578 }
5579 
5580 
5581 /*
5582  * zero_var() -
5583  *
5584  *	Set a variable to ZERO.
5585  *	Note: its dscale is not touched.
5586  */
5587 static void
zero_var(NumericVar * var)5588 zero_var(NumericVar *var)
5589 {
5590 	digitbuf_free(var->buf);
5591 	var->buf = NULL;
5592 	var->digits = NULL;
5593 	var->ndigits = 0;
5594 	var->weight = 0;			/* by convention; doesn't really matter */
5595 	var->sign = NUMERIC_POS;	/* anything but NAN... */
5596 }
5597 
5598 
5599 /*
5600  * set_var_from_str()
5601  *
5602  *	Parse a string and put the number into a variable
5603  *
5604  * This function does not handle leading or trailing spaces, and it doesn't
5605  * accept "NaN" either.  It returns the end+1 position so that caller can
5606  * check for trailing spaces/garbage if deemed necessary.
5607  *
5608  * cp is the place to actually start parsing; str is what to use in error
5609  * reports.  (Typically cp would be the same except advanced over spaces.)
5610  */
5611 static const char *
set_var_from_str(const char * str,const char * cp,NumericVar * dest)5612 set_var_from_str(const char *str, const char *cp, NumericVar *dest)
5613 {
5614 	bool		have_dp = false;
5615 	int			i;
5616 	unsigned char *decdigits;
5617 	int			sign = NUMERIC_POS;
5618 	int			dweight = -1;
5619 	int			ddigits;
5620 	int			dscale = 0;
5621 	int			weight;
5622 	int			ndigits;
5623 	int			offset;
5624 	NumericDigit *digits;
5625 
5626 	/*
5627 	 * We first parse the string to extract decimal digits and determine the
5628 	 * correct decimal weight.  Then convert to NBASE representation.
5629 	 */
5630 	switch (*cp)
5631 	{
5632 		case '+':
5633 			sign = NUMERIC_POS;
5634 			cp++;
5635 			break;
5636 
5637 		case '-':
5638 			sign = NUMERIC_NEG;
5639 			cp++;
5640 			break;
5641 	}
5642 
5643 	if (*cp == '.')
5644 	{
5645 		have_dp = true;
5646 		cp++;
5647 	}
5648 
5649 	if (!isdigit((unsigned char) *cp))
5650 		ereport(ERROR,
5651 				(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
5652 				 errmsg("invalid input syntax for type %s: \"%s\"",
5653 						"numeric", str)));
5654 
5655 	decdigits = (unsigned char *) palloc(strlen(cp) + DEC_DIGITS * 2);
5656 
5657 	/* leading padding for digit alignment later */
5658 	memset(decdigits, 0, DEC_DIGITS);
5659 	i = DEC_DIGITS;
5660 
5661 	while (*cp)
5662 	{
5663 		if (isdigit((unsigned char) *cp))
5664 		{
5665 			decdigits[i++] = *cp++ - '0';
5666 			if (!have_dp)
5667 				dweight++;
5668 			else
5669 				dscale++;
5670 		}
5671 		else if (*cp == '.')
5672 		{
5673 			if (have_dp)
5674 				ereport(ERROR,
5675 						(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
5676 						 errmsg("invalid input syntax for type %s: \"%s\"",
5677 								"numeric", str)));
5678 			have_dp = true;
5679 			cp++;
5680 		}
5681 		else
5682 			break;
5683 	}
5684 
5685 	ddigits = i - DEC_DIGITS;
5686 	/* trailing padding for digit alignment later */
5687 	memset(decdigits + i, 0, DEC_DIGITS - 1);
5688 
5689 	/* Handle exponent, if any */
5690 	if (*cp == 'e' || *cp == 'E')
5691 	{
5692 		long		exponent;
5693 		char	   *endptr;
5694 
5695 		cp++;
5696 		exponent = strtol(cp, &endptr, 10);
5697 		if (endptr == cp)
5698 			ereport(ERROR,
5699 					(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
5700 					 errmsg("invalid input syntax for type %s: \"%s\"",
5701 							"numeric", str)));
5702 		cp = endptr;
5703 
5704 		/*
5705 		 * At this point, dweight and dscale can't be more than about
5706 		 * INT_MAX/2 due to the MaxAllocSize limit on string length, so
5707 		 * constraining the exponent similarly should be enough to prevent
5708 		 * integer overflow in this function.  If the value is too large to
5709 		 * fit in storage format, make_result() will complain about it later;
5710 		 * for consistency use the same ereport errcode/text as make_result().
5711 		 */
5712 		if (exponent >= INT_MAX / 2 || exponent <= -(INT_MAX / 2))
5713 			ereport(ERROR,
5714 					(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
5715 					 errmsg("value overflows numeric format")));
5716 		dweight += (int) exponent;
5717 		dscale -= (int) exponent;
5718 		if (dscale < 0)
5719 			dscale = 0;
5720 	}
5721 
5722 	/*
5723 	 * Okay, convert pure-decimal representation to base NBASE.  First we need
5724 	 * to determine the converted weight and ndigits.  offset is the number of
5725 	 * decimal zeroes to insert before the first given digit to have a
5726 	 * correctly aligned first NBASE digit.
5727 	 */
5728 	if (dweight >= 0)
5729 		weight = (dweight + 1 + DEC_DIGITS - 1) / DEC_DIGITS - 1;
5730 	else
5731 		weight = -((-dweight - 1) / DEC_DIGITS + 1);
5732 	offset = (weight + 1) * DEC_DIGITS - (dweight + 1);
5733 	ndigits = (ddigits + offset + DEC_DIGITS - 1) / DEC_DIGITS;
5734 
5735 	alloc_var(dest, ndigits);
5736 	dest->sign = sign;
5737 	dest->weight = weight;
5738 	dest->dscale = dscale;
5739 
5740 	i = DEC_DIGITS - offset;
5741 	digits = dest->digits;
5742 
5743 	while (ndigits-- > 0)
5744 	{
5745 #if DEC_DIGITS == 4
5746 		*digits++ = ((decdigits[i] * 10 + decdigits[i + 1]) * 10 +
5747 					 decdigits[i + 2]) * 10 + decdigits[i + 3];
5748 #elif DEC_DIGITS == 2
5749 		*digits++ = decdigits[i] * 10 + decdigits[i + 1];
5750 #elif DEC_DIGITS == 1
5751 		*digits++ = decdigits[i];
5752 #else
5753 #error unsupported NBASE
5754 #endif
5755 		i += DEC_DIGITS;
5756 	}
5757 
5758 	pfree(decdigits);
5759 
5760 	/* Strip any leading/trailing zeroes, and normalize weight if zero */
5761 	strip_var(dest);
5762 
5763 	/* Return end+1 position for caller */
5764 	return cp;
5765 }
5766 
5767 
5768 /*
5769  * set_var_from_num() -
5770  *
5771  *	Convert the packed db format into a variable
5772  */
5773 static void
set_var_from_num(Numeric num,NumericVar * dest)5774 set_var_from_num(Numeric num, NumericVar *dest)
5775 {
5776 	int			ndigits;
5777 
5778 	ndigits = NUMERIC_NDIGITS(num);
5779 
5780 	alloc_var(dest, ndigits);
5781 
5782 	dest->weight = NUMERIC_WEIGHT(num);
5783 	dest->sign = NUMERIC_SIGN(num);
5784 	dest->dscale = NUMERIC_DSCALE(num);
5785 
5786 	memcpy(dest->digits, NUMERIC_DIGITS(num), ndigits * sizeof(NumericDigit));
5787 }
5788 
5789 
5790 /*
5791  * init_var_from_num() -
5792  *
5793  *	Initialize a variable from packed db format. The digits array is not
5794  *	copied, which saves some cycles when the resulting var is not modified.
5795  *	Also, there's no need to call free_var(), as long as you don't assign any
5796  *	other value to it (with set_var_* functions, or by using the var as the
5797  *	destination of a function like add_var())
5798  *
5799  *	CAUTION: Do not modify the digits buffer of a var initialized with this
5800  *	function, e.g by calling round_var() or trunc_var(), as the changes will
5801  *	propagate to the original Numeric! It's OK to use it as the destination
5802  *	argument of one of the calculational functions, though.
5803  */
5804 static void
init_var_from_num(Numeric num,NumericVar * dest)5805 init_var_from_num(Numeric num, NumericVar *dest)
5806 {
5807 	dest->ndigits = NUMERIC_NDIGITS(num);
5808 	dest->weight = NUMERIC_WEIGHT(num);
5809 	dest->sign = NUMERIC_SIGN(num);
5810 	dest->dscale = NUMERIC_DSCALE(num);
5811 	dest->digits = NUMERIC_DIGITS(num);
5812 	dest->buf = NULL;			/* digits array is not palloc'd */
5813 }
5814 
5815 
5816 /*
5817  * set_var_from_var() -
5818  *
5819  *	Copy one variable into another
5820  */
5821 static void
set_var_from_var(const NumericVar * value,NumericVar * dest)5822 set_var_from_var(const NumericVar *value, NumericVar *dest)
5823 {
5824 	NumericDigit *newbuf;
5825 
5826 	newbuf = digitbuf_alloc(value->ndigits + 1);
5827 	newbuf[0] = 0;				/* spare digit for rounding */
5828 	if (value->ndigits > 0)		/* else value->digits might be null */
5829 		memcpy(newbuf + 1, value->digits,
5830 			   value->ndigits * sizeof(NumericDigit));
5831 
5832 	digitbuf_free(dest->buf);
5833 
5834 	memmove(dest, value, sizeof(NumericVar));
5835 	dest->buf = newbuf;
5836 	dest->digits = newbuf + 1;
5837 }
5838 
5839 
5840 /*
5841  * get_str_from_var() -
5842  *
5843  *	Convert a var to text representation (guts of numeric_out).
5844  *	The var is displayed to the number of digits indicated by its dscale.
5845  *	Returns a palloc'd string.
5846  */
5847 static char *
get_str_from_var(const NumericVar * var)5848 get_str_from_var(const NumericVar *var)
5849 {
5850 	int			dscale;
5851 	char	   *str;
5852 	char	   *cp;
5853 	char	   *endcp;
5854 	int			i;
5855 	int			d;
5856 	NumericDigit dig;
5857 
5858 #if DEC_DIGITS > 1
5859 	NumericDigit d1;
5860 #endif
5861 
5862 	dscale = var->dscale;
5863 
5864 	/*
5865 	 * Allocate space for the result.
5866 	 *
5867 	 * i is set to the # of decimal digits before decimal point. dscale is the
5868 	 * # of decimal digits we will print after decimal point. We may generate
5869 	 * as many as DEC_DIGITS-1 excess digits at the end, and in addition we
5870 	 * need room for sign, decimal point, null terminator.
5871 	 */
5872 	i = (var->weight + 1) * DEC_DIGITS;
5873 	if (i <= 0)
5874 		i = 1;
5875 
5876 	str = palloc(i + dscale + DEC_DIGITS + 2);
5877 	cp = str;
5878 
5879 	/*
5880 	 * Output a dash for negative values
5881 	 */
5882 	if (var->sign == NUMERIC_NEG)
5883 		*cp++ = '-';
5884 
5885 	/*
5886 	 * Output all digits before the decimal point
5887 	 */
5888 	if (var->weight < 0)
5889 	{
5890 		d = var->weight + 1;
5891 		*cp++ = '0';
5892 	}
5893 	else
5894 	{
5895 		for (d = 0; d <= var->weight; d++)
5896 		{
5897 			dig = (d < var->ndigits) ? var->digits[d] : 0;
5898 			/* In the first digit, suppress extra leading decimal zeroes */
5899 #if DEC_DIGITS == 4
5900 			{
5901 				bool		putit = (d > 0);
5902 
5903 				d1 = dig / 1000;
5904 				dig -= d1 * 1000;
5905 				putit |= (d1 > 0);
5906 				if (putit)
5907 					*cp++ = d1 + '0';
5908 				d1 = dig / 100;
5909 				dig -= d1 * 100;
5910 				putit |= (d1 > 0);
5911 				if (putit)
5912 					*cp++ = d1 + '0';
5913 				d1 = dig / 10;
5914 				dig -= d1 * 10;
5915 				putit |= (d1 > 0);
5916 				if (putit)
5917 					*cp++ = d1 + '0';
5918 				*cp++ = dig + '0';
5919 			}
5920 #elif DEC_DIGITS == 2
5921 			d1 = dig / 10;
5922 			dig -= d1 * 10;
5923 			if (d1 > 0 || d > 0)
5924 				*cp++ = d1 + '0';
5925 			*cp++ = dig + '0';
5926 #elif DEC_DIGITS == 1
5927 			*cp++ = dig + '0';
5928 #else
5929 #error unsupported NBASE
5930 #endif
5931 		}
5932 	}
5933 
5934 	/*
5935 	 * If requested, output a decimal point and all the digits that follow it.
5936 	 * We initially put out a multiple of DEC_DIGITS digits, then truncate if
5937 	 * needed.
5938 	 */
5939 	if (dscale > 0)
5940 	{
5941 		*cp++ = '.';
5942 		endcp = cp + dscale;
5943 		for (i = 0; i < dscale; d++, i += DEC_DIGITS)
5944 		{
5945 			dig = (d >= 0 && d < var->ndigits) ? var->digits[d] : 0;
5946 #if DEC_DIGITS == 4
5947 			d1 = dig / 1000;
5948 			dig -= d1 * 1000;
5949 			*cp++ = d1 + '0';
5950 			d1 = dig / 100;
5951 			dig -= d1 * 100;
5952 			*cp++ = d1 + '0';
5953 			d1 = dig / 10;
5954 			dig -= d1 * 10;
5955 			*cp++ = d1 + '0';
5956 			*cp++ = dig + '0';
5957 #elif DEC_DIGITS == 2
5958 			d1 = dig / 10;
5959 			dig -= d1 * 10;
5960 			*cp++ = d1 + '0';
5961 			*cp++ = dig + '0';
5962 #elif DEC_DIGITS == 1
5963 			*cp++ = dig + '0';
5964 #else
5965 #error unsupported NBASE
5966 #endif
5967 		}
5968 		cp = endcp;
5969 	}
5970 
5971 	/*
5972 	 * terminate the string and return it
5973 	 */
5974 	*cp = '\0';
5975 	return str;
5976 }
5977 
5978 /*
5979  * get_str_from_var_sci() -
5980  *
5981  *	Convert a var to a normalised scientific notation text representation.
5982  *	This function does the heavy lifting for numeric_out_sci().
5983  *
5984  *	This notation has the general form a * 10^b, where a is known as the
5985  *	"significand" and b is known as the "exponent".
5986  *
5987  *	Because we can't do superscript in ASCII (and because we want to copy
5988  *	printf's behaviour) we display the exponent using E notation, with a
5989  *	minimum of two exponent digits.
5990  *
5991  *	For example, the value 1234 could be output as 1.2e+03.
5992  *
5993  *	We assume that the exponent can fit into an int32.
5994  *
5995  *	rscale is the number of decimal digits desired after the decimal point in
5996  *	the output, negative values will be treated as meaning zero.
5997  *
5998  *	Returns a palloc'd string.
5999  */
6000 static char *
get_str_from_var_sci(const NumericVar * var,int rscale)6001 get_str_from_var_sci(const NumericVar *var, int rscale)
6002 {
6003 	int32		exponent;
6004 	NumericVar	tmp_var;
6005 	size_t		len;
6006 	char	   *str;
6007 	char	   *sig_out;
6008 
6009 	if (rscale < 0)
6010 		rscale = 0;
6011 
6012 	/*
6013 	 * Determine the exponent of this number in normalised form.
6014 	 *
6015 	 * This is the exponent required to represent the number with only one
6016 	 * significant digit before the decimal place.
6017 	 */
6018 	if (var->ndigits > 0)
6019 	{
6020 		exponent = (var->weight + 1) * DEC_DIGITS;
6021 
6022 		/*
6023 		 * Compensate for leading decimal zeroes in the first numeric digit by
6024 		 * decrementing the exponent.
6025 		 */
6026 		exponent -= DEC_DIGITS - (int) log10(var->digits[0]);
6027 	}
6028 	else
6029 	{
6030 		/*
6031 		 * If var has no digits, then it must be zero.
6032 		 *
6033 		 * Zero doesn't technically have a meaningful exponent in normalised
6034 		 * notation, but we just display the exponent as zero for consistency
6035 		 * of output.
6036 		 */
6037 		exponent = 0;
6038 	}
6039 
6040 	/*
6041 	 * Divide var by 10^exponent to get the significand, rounding to rscale
6042 	 * decimal digits in the process.
6043 	 */
6044 	init_var(&tmp_var);
6045 
6046 	power_ten_int(exponent, &tmp_var);
6047 	div_var(var, &tmp_var, &tmp_var, rscale, true);
6048 	sig_out = get_str_from_var(&tmp_var);
6049 
6050 	free_var(&tmp_var);
6051 
6052 	/*
6053 	 * Allocate space for the result.
6054 	 *
6055 	 * In addition to the significand, we need room for the exponent
6056 	 * decoration ("e"), the sign of the exponent, up to 10 digits for the
6057 	 * exponent itself, and of course the null terminator.
6058 	 */
6059 	len = strlen(sig_out) + 13;
6060 	str = palloc(len);
6061 	snprintf(str, len, "%se%+03d", sig_out, exponent);
6062 
6063 	pfree(sig_out);
6064 
6065 	return str;
6066 }
6067 
6068 
6069 /*
6070  * make_result() -
6071  *
6072  *	Create the packed db numeric format in palloc()'d memory from
6073  *	a variable.
6074  */
6075 static Numeric
make_result(const NumericVar * var)6076 make_result(const NumericVar *var)
6077 {
6078 	Numeric		result;
6079 	NumericDigit *digits = var->digits;
6080 	int			weight = var->weight;
6081 	int			sign = var->sign;
6082 	int			n;
6083 	Size		len;
6084 
6085 	if (sign == NUMERIC_NAN)
6086 	{
6087 		result = (Numeric) palloc(NUMERIC_HDRSZ_SHORT);
6088 
6089 		SET_VARSIZE(result, NUMERIC_HDRSZ_SHORT);
6090 		result->choice.n_header = NUMERIC_NAN;
6091 		/* the header word is all we need */
6092 
6093 		dump_numeric("make_result()", result);
6094 		return result;
6095 	}
6096 
6097 	n = var->ndigits;
6098 
6099 	/* truncate leading zeroes */
6100 	while (n > 0 && *digits == 0)
6101 	{
6102 		digits++;
6103 		weight--;
6104 		n--;
6105 	}
6106 	/* truncate trailing zeroes */
6107 	while (n > 0 && digits[n - 1] == 0)
6108 		n--;
6109 
6110 	/* If zero result, force to weight=0 and positive sign */
6111 	if (n == 0)
6112 	{
6113 		weight = 0;
6114 		sign = NUMERIC_POS;
6115 	}
6116 
6117 	/* Build the result */
6118 	if (NUMERIC_CAN_BE_SHORT(var->dscale, weight))
6119 	{
6120 		len = NUMERIC_HDRSZ_SHORT + n * sizeof(NumericDigit);
6121 		result = (Numeric) palloc(len);
6122 		SET_VARSIZE(result, len);
6123 		result->choice.n_short.n_header =
6124 			(sign == NUMERIC_NEG ? (NUMERIC_SHORT | NUMERIC_SHORT_SIGN_MASK)
6125 			 : NUMERIC_SHORT)
6126 			| (var->dscale << NUMERIC_SHORT_DSCALE_SHIFT)
6127 			| (weight < 0 ? NUMERIC_SHORT_WEIGHT_SIGN_MASK : 0)
6128 			| (weight & NUMERIC_SHORT_WEIGHT_MASK);
6129 	}
6130 	else
6131 	{
6132 		len = NUMERIC_HDRSZ + n * sizeof(NumericDigit);
6133 		result = (Numeric) palloc(len);
6134 		SET_VARSIZE(result, len);
6135 		result->choice.n_long.n_sign_dscale =
6136 			sign | (var->dscale & NUMERIC_DSCALE_MASK);
6137 		result->choice.n_long.n_weight = weight;
6138 	}
6139 
6140 	Assert(NUMERIC_NDIGITS(result) == n);
6141 	if (n > 0)
6142 		memcpy(NUMERIC_DIGITS(result), digits, n * sizeof(NumericDigit));
6143 
6144 	/* Check for overflow of int16 fields */
6145 	if (NUMERIC_WEIGHT(result) != weight ||
6146 		NUMERIC_DSCALE(result) != var->dscale)
6147 		ereport(ERROR,
6148 				(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
6149 				 errmsg("value overflows numeric format")));
6150 
6151 	dump_numeric("make_result()", result);
6152 	return result;
6153 }
6154 
6155 
6156 /*
6157  * apply_typmod() -
6158  *
6159  *	Do bounds checking and rounding according to the attributes
6160  *	typmod field.
6161  */
6162 static void
apply_typmod(NumericVar * var,int32 typmod)6163 apply_typmod(NumericVar *var, int32 typmod)
6164 {
6165 	int			precision;
6166 	int			scale;
6167 	int			maxdigits;
6168 	int			ddigits;
6169 	int			i;
6170 
6171 	/* Do nothing if we have a default typmod (-1) */
6172 	if (typmod < (int32) (VARHDRSZ))
6173 		return;
6174 
6175 	typmod -= VARHDRSZ;
6176 	precision = (typmod >> 16) & 0xffff;
6177 	scale = typmod & 0xffff;
6178 	maxdigits = precision - scale;
6179 
6180 	/* Round to target scale (and set var->dscale) */
6181 	round_var(var, scale);
6182 
6183 	/*
6184 	 * Check for overflow - note we can't do this before rounding, because
6185 	 * rounding could raise the weight.  Also note that the var's weight could
6186 	 * be inflated by leading zeroes, which will be stripped before storage
6187 	 * but perhaps might not have been yet. In any case, we must recognize a
6188 	 * true zero, whose weight doesn't mean anything.
6189 	 */
6190 	ddigits = (var->weight + 1) * DEC_DIGITS;
6191 	if (ddigits > maxdigits)
6192 	{
6193 		/* Determine true weight; and check for all-zero result */
6194 		for (i = 0; i < var->ndigits; i++)
6195 		{
6196 			NumericDigit dig = var->digits[i];
6197 
6198 			if (dig)
6199 			{
6200 				/* Adjust for any high-order decimal zero digits */
6201 #if DEC_DIGITS == 4
6202 				if (dig < 10)
6203 					ddigits -= 3;
6204 				else if (dig < 100)
6205 					ddigits -= 2;
6206 				else if (dig < 1000)
6207 					ddigits -= 1;
6208 #elif DEC_DIGITS == 2
6209 				if (dig < 10)
6210 					ddigits -= 1;
6211 #elif DEC_DIGITS == 1
6212 				/* no adjustment */
6213 #else
6214 #error unsupported NBASE
6215 #endif
6216 				if (ddigits > maxdigits)
6217 					ereport(ERROR,
6218 							(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
6219 							 errmsg("numeric field overflow"),
6220 							 errdetail("A field with precision %d, scale %d must round to an absolute value less than %s%d.",
6221 									   precision, scale,
6222 					/* Display 10^0 as 1 */
6223 									   maxdigits ? "10^" : "",
6224 									   maxdigits ? maxdigits : 1
6225 									   )));
6226 				break;
6227 			}
6228 			ddigits -= DEC_DIGITS;
6229 		}
6230 	}
6231 }
6232 
6233 /*
6234  * Convert numeric to int8, rounding if needed.
6235  *
6236  * If overflow, return false (no error is raised).  Return true if okay.
6237  */
6238 static bool
numericvar_to_int64(const NumericVar * var,int64 * result)6239 numericvar_to_int64(const NumericVar *var, int64 *result)
6240 {
6241 	NumericDigit *digits;
6242 	int			ndigits;
6243 	int			weight;
6244 	int			i;
6245 	int64		val;
6246 	bool		neg;
6247 	NumericVar	rounded;
6248 
6249 	/* Round to nearest integer */
6250 	init_var(&rounded);
6251 	set_var_from_var(var, &rounded);
6252 	round_var(&rounded, 0);
6253 
6254 	/* Check for zero input */
6255 	strip_var(&rounded);
6256 	ndigits = rounded.ndigits;
6257 	if (ndigits == 0)
6258 	{
6259 		*result = 0;
6260 		free_var(&rounded);
6261 		return true;
6262 	}
6263 
6264 	/*
6265 	 * For input like 10000000000, we must treat stripped digits as real. So
6266 	 * the loop assumes there are weight+1 digits before the decimal point.
6267 	 */
6268 	weight = rounded.weight;
6269 	Assert(weight >= 0 && ndigits <= weight + 1);
6270 
6271 	/*
6272 	 * Construct the result. To avoid issues with converting a value
6273 	 * corresponding to INT64_MIN (which can't be represented as a positive 64
6274 	 * bit two's complement integer), accumulate value as a negative number.
6275 	 */
6276 	digits = rounded.digits;
6277 	neg = (rounded.sign == NUMERIC_NEG);
6278 	val = -digits[0];
6279 	for (i = 1; i <= weight; i++)
6280 	{
6281 		if (unlikely(pg_mul_s64_overflow(val, NBASE, &val)))
6282 		{
6283 			free_var(&rounded);
6284 			return false;
6285 		}
6286 
6287 		if (i < ndigits)
6288 		{
6289 			if (unlikely(pg_sub_s64_overflow(val, digits[i], &val)))
6290 			{
6291 				free_var(&rounded);
6292 				return false;
6293 			}
6294 		}
6295 	}
6296 
6297 	free_var(&rounded);
6298 
6299 	if (!neg)
6300 	{
6301 		if (unlikely(val == PG_INT64_MIN))
6302 			return false;
6303 		val = -val;
6304 	}
6305 	*result = val;
6306 
6307 	return true;
6308 }
6309 
6310 /*
6311  * Convert int8 value to numeric.
6312  */
6313 static void
int64_to_numericvar(int64 val,NumericVar * var)6314 int64_to_numericvar(int64 val, NumericVar *var)
6315 {
6316 	uint64		uval,
6317 				newuval;
6318 	NumericDigit *ptr;
6319 	int			ndigits;
6320 
6321 	/* int64 can require at most 19 decimal digits; add one for safety */
6322 	alloc_var(var, 20 / DEC_DIGITS);
6323 	if (val < 0)
6324 	{
6325 		var->sign = NUMERIC_NEG;
6326 		uval = -val;
6327 	}
6328 	else
6329 	{
6330 		var->sign = NUMERIC_POS;
6331 		uval = val;
6332 	}
6333 	var->dscale = 0;
6334 	if (val == 0)
6335 	{
6336 		var->ndigits = 0;
6337 		var->weight = 0;
6338 		return;
6339 	}
6340 	ptr = var->digits + var->ndigits;
6341 	ndigits = 0;
6342 	do
6343 	{
6344 		ptr--;
6345 		ndigits++;
6346 		newuval = uval / NBASE;
6347 		*ptr = uval - newuval * NBASE;
6348 		uval = newuval;
6349 	} while (uval);
6350 	var->digits = ptr;
6351 	var->ndigits = ndigits;
6352 	var->weight = ndigits - 1;
6353 }
6354 
6355 #ifdef HAVE_INT128
6356 /*
6357  * Convert numeric to int128, rounding if needed.
6358  *
6359  * If overflow, return false (no error is raised).  Return true if okay.
6360  */
6361 static bool
numericvar_to_int128(const NumericVar * var,int128 * result)6362 numericvar_to_int128(const NumericVar *var, int128 *result)
6363 {
6364 	NumericDigit *digits;
6365 	int			ndigits;
6366 	int			weight;
6367 	int			i;
6368 	int128		val,
6369 				oldval;
6370 	bool		neg;
6371 	NumericVar	rounded;
6372 
6373 	/* Round to nearest integer */
6374 	init_var(&rounded);
6375 	set_var_from_var(var, &rounded);
6376 	round_var(&rounded, 0);
6377 
6378 	/* Check for zero input */
6379 	strip_var(&rounded);
6380 	ndigits = rounded.ndigits;
6381 	if (ndigits == 0)
6382 	{
6383 		*result = 0;
6384 		free_var(&rounded);
6385 		return true;
6386 	}
6387 
6388 	/*
6389 	 * For input like 10000000000, we must treat stripped digits as real. So
6390 	 * the loop assumes there are weight+1 digits before the decimal point.
6391 	 */
6392 	weight = rounded.weight;
6393 	Assert(weight >= 0 && ndigits <= weight + 1);
6394 
6395 	/* Construct the result */
6396 	digits = rounded.digits;
6397 	neg = (rounded.sign == NUMERIC_NEG);
6398 	val = digits[0];
6399 	for (i = 1; i <= weight; i++)
6400 	{
6401 		oldval = val;
6402 		val *= NBASE;
6403 		if (i < ndigits)
6404 			val += digits[i];
6405 
6406 		/*
6407 		 * The overflow check is a bit tricky because we want to accept
6408 		 * INT128_MIN, which will overflow the positive accumulator.  We can
6409 		 * detect this case easily though because INT128_MIN is the only
6410 		 * nonzero value for which -val == val (on a two's complement machine,
6411 		 * anyway).
6412 		 */
6413 		if ((val / NBASE) != oldval)	/* possible overflow? */
6414 		{
6415 			if (!neg || (-val) != val || val == 0 || oldval < 0)
6416 			{
6417 				free_var(&rounded);
6418 				return false;
6419 			}
6420 		}
6421 	}
6422 
6423 	free_var(&rounded);
6424 
6425 	*result = neg ? -val : val;
6426 	return true;
6427 }
6428 
6429 /*
6430  * Convert 128 bit integer to numeric.
6431  */
6432 static void
int128_to_numericvar(int128 val,NumericVar * var)6433 int128_to_numericvar(int128 val, NumericVar *var)
6434 {
6435 	uint128		uval,
6436 				newuval;
6437 	NumericDigit *ptr;
6438 	int			ndigits;
6439 
6440 	/* int128 can require at most 39 decimal digits; add one for safety */
6441 	alloc_var(var, 40 / DEC_DIGITS);
6442 	if (val < 0)
6443 	{
6444 		var->sign = NUMERIC_NEG;
6445 		uval = -val;
6446 	}
6447 	else
6448 	{
6449 		var->sign = NUMERIC_POS;
6450 		uval = val;
6451 	}
6452 	var->dscale = 0;
6453 	if (val == 0)
6454 	{
6455 		var->ndigits = 0;
6456 		var->weight = 0;
6457 		return;
6458 	}
6459 	ptr = var->digits + var->ndigits;
6460 	ndigits = 0;
6461 	do
6462 	{
6463 		ptr--;
6464 		ndigits++;
6465 		newuval = uval / NBASE;
6466 		*ptr = uval - newuval * NBASE;
6467 		uval = newuval;
6468 	} while (uval);
6469 	var->digits = ptr;
6470 	var->ndigits = ndigits;
6471 	var->weight = ndigits - 1;
6472 }
6473 #endif
6474 
6475 /*
6476  * Convert numeric to float8; if out of range, return +/- HUGE_VAL
6477  */
6478 static double
numeric_to_double_no_overflow(Numeric num)6479 numeric_to_double_no_overflow(Numeric num)
6480 {
6481 	char	   *tmp;
6482 	double		val;
6483 	char	   *endptr;
6484 
6485 	tmp = DatumGetCString(DirectFunctionCall1(numeric_out,
6486 											  NumericGetDatum(num)));
6487 
6488 	/* unlike float8in, we ignore ERANGE from strtod */
6489 	val = strtod(tmp, &endptr);
6490 	if (*endptr != '\0')
6491 	{
6492 		/* shouldn't happen ... */
6493 		ereport(ERROR,
6494 				(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
6495 				 errmsg("invalid input syntax for type %s: \"%s\"",
6496 						"double precision", tmp)));
6497 	}
6498 
6499 	pfree(tmp);
6500 
6501 	return val;
6502 }
6503 
6504 /* As above, but work from a NumericVar */
6505 static double
numericvar_to_double_no_overflow(const NumericVar * var)6506 numericvar_to_double_no_overflow(const NumericVar *var)
6507 {
6508 	char	   *tmp;
6509 	double		val;
6510 	char	   *endptr;
6511 
6512 	tmp = get_str_from_var(var);
6513 
6514 	/* unlike float8in, we ignore ERANGE from strtod */
6515 	val = strtod(tmp, &endptr);
6516 	if (*endptr != '\0')
6517 	{
6518 		/* shouldn't happen ... */
6519 		ereport(ERROR,
6520 				(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
6521 				 errmsg("invalid input syntax for type %s: \"%s\"",
6522 						"double precision", tmp)));
6523 	}
6524 
6525 	pfree(tmp);
6526 
6527 	return val;
6528 }
6529 
6530 
6531 /*
6532  * cmp_var() -
6533  *
6534  *	Compare two values on variable level.  We assume zeroes have been
6535  *	truncated to no digits.
6536  */
6537 static int
cmp_var(const NumericVar * var1,const NumericVar * var2)6538 cmp_var(const NumericVar *var1, const NumericVar *var2)
6539 {
6540 	return cmp_var_common(var1->digits, var1->ndigits,
6541 						  var1->weight, var1->sign,
6542 						  var2->digits, var2->ndigits,
6543 						  var2->weight, var2->sign);
6544 }
6545 
6546 /*
6547  * cmp_var_common() -
6548  *
6549  *	Main routine of cmp_var(). This function can be used by both
6550  *	NumericVar and Numeric.
6551  */
6552 static int
cmp_var_common(const NumericDigit * var1digits,int var1ndigits,int var1weight,int var1sign,const NumericDigit * var2digits,int var2ndigits,int var2weight,int var2sign)6553 cmp_var_common(const NumericDigit *var1digits, int var1ndigits,
6554 			   int var1weight, int var1sign,
6555 			   const NumericDigit *var2digits, int var2ndigits,
6556 			   int var2weight, int var2sign)
6557 {
6558 	if (var1ndigits == 0)
6559 	{
6560 		if (var2ndigits == 0)
6561 			return 0;
6562 		if (var2sign == NUMERIC_NEG)
6563 			return 1;
6564 		return -1;
6565 	}
6566 	if (var2ndigits == 0)
6567 	{
6568 		if (var1sign == NUMERIC_POS)
6569 			return 1;
6570 		return -1;
6571 	}
6572 
6573 	if (var1sign == NUMERIC_POS)
6574 	{
6575 		if (var2sign == NUMERIC_NEG)
6576 			return 1;
6577 		return cmp_abs_common(var1digits, var1ndigits, var1weight,
6578 							  var2digits, var2ndigits, var2weight);
6579 	}
6580 
6581 	if (var2sign == NUMERIC_POS)
6582 		return -1;
6583 
6584 	return cmp_abs_common(var2digits, var2ndigits, var2weight,
6585 						  var1digits, var1ndigits, var1weight);
6586 }
6587 
6588 
6589 /*
6590  * add_var() -
6591  *
6592  *	Full version of add functionality on variable level (handling signs).
6593  *	result might point to one of the operands too without danger.
6594  */
6595 static void
add_var(const NumericVar * var1,const NumericVar * var2,NumericVar * result)6596 add_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result)
6597 {
6598 	/*
6599 	 * Decide on the signs of the two variables what to do
6600 	 */
6601 	if (var1->sign == NUMERIC_POS)
6602 	{
6603 		if (var2->sign == NUMERIC_POS)
6604 		{
6605 			/*
6606 			 * Both are positive result = +(ABS(var1) + ABS(var2))
6607 			 */
6608 			add_abs(var1, var2, result);
6609 			result->sign = NUMERIC_POS;
6610 		}
6611 		else
6612 		{
6613 			/*
6614 			 * var1 is positive, var2 is negative Must compare absolute values
6615 			 */
6616 			switch (cmp_abs(var1, var2))
6617 			{
6618 				case 0:
6619 					/* ----------
6620 					 * ABS(var1) == ABS(var2)
6621 					 * result = ZERO
6622 					 * ----------
6623 					 */
6624 					zero_var(result);
6625 					result->dscale = Max(var1->dscale, var2->dscale);
6626 					break;
6627 
6628 				case 1:
6629 					/* ----------
6630 					 * ABS(var1) > ABS(var2)
6631 					 * result = +(ABS(var1) - ABS(var2))
6632 					 * ----------
6633 					 */
6634 					sub_abs(var1, var2, result);
6635 					result->sign = NUMERIC_POS;
6636 					break;
6637 
6638 				case -1:
6639 					/* ----------
6640 					 * ABS(var1) < ABS(var2)
6641 					 * result = -(ABS(var2) - ABS(var1))
6642 					 * ----------
6643 					 */
6644 					sub_abs(var2, var1, result);
6645 					result->sign = NUMERIC_NEG;
6646 					break;
6647 			}
6648 		}
6649 	}
6650 	else
6651 	{
6652 		if (var2->sign == NUMERIC_POS)
6653 		{
6654 			/* ----------
6655 			 * var1 is negative, var2 is positive
6656 			 * Must compare absolute values
6657 			 * ----------
6658 			 */
6659 			switch (cmp_abs(var1, var2))
6660 			{
6661 				case 0:
6662 					/* ----------
6663 					 * ABS(var1) == ABS(var2)
6664 					 * result = ZERO
6665 					 * ----------
6666 					 */
6667 					zero_var(result);
6668 					result->dscale = Max(var1->dscale, var2->dscale);
6669 					break;
6670 
6671 				case 1:
6672 					/* ----------
6673 					 * ABS(var1) > ABS(var2)
6674 					 * result = -(ABS(var1) - ABS(var2))
6675 					 * ----------
6676 					 */
6677 					sub_abs(var1, var2, result);
6678 					result->sign = NUMERIC_NEG;
6679 					break;
6680 
6681 				case -1:
6682 					/* ----------
6683 					 * ABS(var1) < ABS(var2)
6684 					 * result = +(ABS(var2) - ABS(var1))
6685 					 * ----------
6686 					 */
6687 					sub_abs(var2, var1, result);
6688 					result->sign = NUMERIC_POS;
6689 					break;
6690 			}
6691 		}
6692 		else
6693 		{
6694 			/* ----------
6695 			 * Both are negative
6696 			 * result = -(ABS(var1) + ABS(var2))
6697 			 * ----------
6698 			 */
6699 			add_abs(var1, var2, result);
6700 			result->sign = NUMERIC_NEG;
6701 		}
6702 	}
6703 }
6704 
6705 
6706 /*
6707  * sub_var() -
6708  *
6709  *	Full version of sub functionality on variable level (handling signs).
6710  *	result might point to one of the operands too without danger.
6711  */
6712 static void
sub_var(const NumericVar * var1,const NumericVar * var2,NumericVar * result)6713 sub_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result)
6714 {
6715 	/*
6716 	 * Decide on the signs of the two variables what to do
6717 	 */
6718 	if (var1->sign == NUMERIC_POS)
6719 	{
6720 		if (var2->sign == NUMERIC_NEG)
6721 		{
6722 			/* ----------
6723 			 * var1 is positive, var2 is negative
6724 			 * result = +(ABS(var1) + ABS(var2))
6725 			 * ----------
6726 			 */
6727 			add_abs(var1, var2, result);
6728 			result->sign = NUMERIC_POS;
6729 		}
6730 		else
6731 		{
6732 			/* ----------
6733 			 * Both are positive
6734 			 * Must compare absolute values
6735 			 * ----------
6736 			 */
6737 			switch (cmp_abs(var1, var2))
6738 			{
6739 				case 0:
6740 					/* ----------
6741 					 * ABS(var1) == ABS(var2)
6742 					 * result = ZERO
6743 					 * ----------
6744 					 */
6745 					zero_var(result);
6746 					result->dscale = Max(var1->dscale, var2->dscale);
6747 					break;
6748 
6749 				case 1:
6750 					/* ----------
6751 					 * ABS(var1) > ABS(var2)
6752 					 * result = +(ABS(var1) - ABS(var2))
6753 					 * ----------
6754 					 */
6755 					sub_abs(var1, var2, result);
6756 					result->sign = NUMERIC_POS;
6757 					break;
6758 
6759 				case -1:
6760 					/* ----------
6761 					 * ABS(var1) < ABS(var2)
6762 					 * result = -(ABS(var2) - ABS(var1))
6763 					 * ----------
6764 					 */
6765 					sub_abs(var2, var1, result);
6766 					result->sign = NUMERIC_NEG;
6767 					break;
6768 			}
6769 		}
6770 	}
6771 	else
6772 	{
6773 		if (var2->sign == NUMERIC_NEG)
6774 		{
6775 			/* ----------
6776 			 * Both are negative
6777 			 * Must compare absolute values
6778 			 * ----------
6779 			 */
6780 			switch (cmp_abs(var1, var2))
6781 			{
6782 				case 0:
6783 					/* ----------
6784 					 * ABS(var1) == ABS(var2)
6785 					 * result = ZERO
6786 					 * ----------
6787 					 */
6788 					zero_var(result);
6789 					result->dscale = Max(var1->dscale, var2->dscale);
6790 					break;
6791 
6792 				case 1:
6793 					/* ----------
6794 					 * ABS(var1) > ABS(var2)
6795 					 * result = -(ABS(var1) - ABS(var2))
6796 					 * ----------
6797 					 */
6798 					sub_abs(var1, var2, result);
6799 					result->sign = NUMERIC_NEG;
6800 					break;
6801 
6802 				case -1:
6803 					/* ----------
6804 					 * ABS(var1) < ABS(var2)
6805 					 * result = +(ABS(var2) - ABS(var1))
6806 					 * ----------
6807 					 */
6808 					sub_abs(var2, var1, result);
6809 					result->sign = NUMERIC_POS;
6810 					break;
6811 			}
6812 		}
6813 		else
6814 		{
6815 			/* ----------
6816 			 * var1 is negative, var2 is positive
6817 			 * result = -(ABS(var1) + ABS(var2))
6818 			 * ----------
6819 			 */
6820 			add_abs(var1, var2, result);
6821 			result->sign = NUMERIC_NEG;
6822 		}
6823 	}
6824 }
6825 
6826 
/*
 * mul_var() -
 *
 *	Multiplication on variable level. Product of var1 * var2 is stored
 *	in result.  Result is rounded to no more than rscale fractional digits.
 *
 *	var1 and var2 are read-only.  If either is zero (no digits), result is
 *	set to zero with dscale = rscale.  Input digits that could only affect
 *	positions more than MUL_GUARD_DIGITS beyond rscale are skipped.  The
 *	rounded result is therefore exact unless carries out of those skipped
 *	positions would have propagated through all of the guard digits.
 */
static void
mul_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result,
		int rscale)
{
	int			res_ndigits;
	int			res_sign;
	int			res_weight;
	int			maxdigits;
	int		   *dig;
	int			carry;
	int			maxdig;
	int			newdig;
	int			var1ndigits;
	int			var2ndigits;
	NumericDigit *var1digits;
	NumericDigit *var2digits;
	NumericDigit *res_digits;
	int			i,
				i1,
				i2;

	/*
	 * Arrange for var1 to be the shorter of the two numbers.  This improves
	 * performance because the inner multiplication loop is much simpler than
	 * the outer loop, so it's better to have a smaller number of iterations
	 * of the outer loop.  This also reduces the number of times that the
	 * accumulator array needs to be normalized.
	 */
	if (var1->ndigits > var2->ndigits)
	{
		const NumericVar *tmp = var1;

		var1 = var2;
		var2 = tmp;
	}

	/* copy these values into local vars for speed in inner loop */
	var1ndigits = var1->ndigits;
	var2ndigits = var2->ndigits;
	var1digits = var1->digits;
	var2digits = var2->digits;

	if (var1ndigits == 0 || var2ndigits == 0)
	{
		/* one or both inputs is zero; so is result */
		zero_var(result);
		result->dscale = rscale;
		return;
	}

	/* Determine result sign and (maximum possible) weight */
	if (var1->sign == var2->sign)
		res_sign = NUMERIC_POS;
	else
		res_sign = NUMERIC_NEG;

	/*
	 * Digit i1 of var1 (weight var1->weight - i1) times digit i2 of var2
	 * (weight var2->weight - i2) is accumulated at index i1 + i2 + 2, so
	 * index 0 of the accumulator has the weight computed here.  The two
	 * leading positions (indexes 0 and 1) absorb carries.
	 */
	res_weight = var1->weight + var2->weight + 2;

	/*
	 * Determine the number of result digits to compute.  If the exact result
	 * would have more than rscale fractional digits, truncate the computation
	 * with MUL_GUARD_DIGITS guard digits, i.e., ignore input digits that
	 * would only contribute to the right of that.  (This will give the exact
	 * rounded-to-rscale answer unless carries out of the ignored positions
	 * would have propagated through more than MUL_GUARD_DIGITS digits.)
	 *
	 * Note: an exact computation could not produce more than var1ndigits +
	 * var2ndigits digits, but we allocate one extra output digit in case
	 * rscale-driven rounding produces a carry out of the highest exact digit.
	 */
	res_ndigits = var1ndigits + var2ndigits + 1;
	maxdigits = res_weight + 1 + (rscale + DEC_DIGITS - 1) / DEC_DIGITS +
		MUL_GUARD_DIGITS;
	res_ndigits = Min(res_ndigits, maxdigits);

	if (res_ndigits < 3)
	{
		/* All input digits will be ignored; so result is zero */
		zero_var(result);
		result->dscale = rscale;
		return;
	}

	/*
	 * We do the arithmetic in an array "dig[]" of signed int's.  Since
	 * INT_MAX is noticeably larger than NBASE*NBASE, this gives us headroom
	 * to avoid normalizing carries immediately.
	 *
	 * maxdig tracks the maximum possible value of any dig[] entry; when this
	 * threatens to exceed INT_MAX, we take the time to propagate carries.
	 * Furthermore, we need to ensure that overflow doesn't occur during the
	 * carry propagation passes either.  The carry values could be as much as
	 * INT_MAX/NBASE, so really we must normalize when digits threaten to
	 * exceed INT_MAX - INT_MAX/NBASE.
	 *
	 * To avoid overflow in maxdig itself, it actually represents the max
	 * possible value divided by NBASE-1, ie, at the top of the loop it is
	 * known that no dig[] entry exceeds maxdig * (NBASE-1).
	 */
	dig = (int *) palloc0(res_ndigits * sizeof(int));
	maxdig = 0;

	/*
	 * The least significant digits of var1 should be ignored if they don't
	 * contribute directly to the first res_ndigits digits of the result that
	 * we are computing.
	 *
	 * Digit i1 of var1 and digit i2 of var2 are multiplied and added to digit
	 * i1+i2+2 of the accumulator array, so we need only consider digits of
	 * var1 for which i1 <= res_ndigits - 3.
	 */
	for (i1 = Min(var1ndigits - 1, res_ndigits - 3); i1 >= 0; i1--)
	{
		int			var1digit = var1digits[i1];

		/* A zero digit adds nothing, so it needn't bump maxdig either */
		if (var1digit == 0)
			continue;

		/* Time to normalize? */
		maxdig += var1digit;
		if (maxdig > (INT_MAX - INT_MAX / NBASE) / (NBASE - 1))
		{
			/* Yes, do it */
			carry = 0;
			for (i = res_ndigits - 1; i >= 0; i--)
			{
				newdig = dig[i] + carry;
				if (newdig >= NBASE)
				{
					carry = newdig / NBASE;
					newdig -= carry * NBASE;
				}
				else
					carry = 0;
				dig[i] = newdig;
			}
			Assert(carry == 0);
			/* Reset maxdig to indicate new worst-case */
			maxdig = 1 + var1digit;
		}

		/*
		 * Add the appropriate multiple of var2 into the accumulator.
		 *
		 * As above, digits of var2 can be ignored if they don't contribute,
		 * so we only include digits for which i1+i2+2 <= res_ndigits - 1.
		 * The index i is kept equal to i1 + i2 + 2 as i2 counts down.
		 */
		for (i2 = Min(var2ndigits - 1, res_ndigits - i1 - 3), i = i1 + i2 + 2;
			 i2 >= 0; i2--)
			dig[i--] += var1digit * var2digits[i2];
	}

	/*
	 * Now we do a final carry propagation pass to normalize the result, which
	 * we combine with storing the result digits into the output. Note that
	 * this is still done at full precision w/guard digits.
	 */
	alloc_var(result, res_ndigits);
	res_digits = result->digits;
	carry = 0;
	for (i = res_ndigits - 1; i >= 0; i--)
	{
		newdig = dig[i] + carry;
		if (newdig >= NBASE)
		{
			carry = newdig / NBASE;
			newdig -= carry * NBASE;
		}
		else
			carry = 0;
		res_digits[i] = newdig;
	}
	Assert(carry == 0);

	pfree(dig);

	/*
	 * Finally, round the result to the requested precision.
	 */
	result->weight = res_weight;
	result->sign = res_sign;

	/* Round to target rscale (and set result->dscale) */
	round_var(result, rscale);

	/* Strip leading and trailing zeroes */
	strip_var(result);
}
7020 
7021 
/*
 * div_var() -
 *
 *	Division on variable level. Quotient of var1 / var2 is stored in result.
 *	The quotient is figured to exactly rscale fractional digits.
 *	If round is true, it is rounded at the rscale'th digit; if false, it
 *	is truncated (towards zero) at that digit.
 *
 *	Raises a division_by_zero error if var2 is zero.  result may be the
 *	same variable as var1 or var2, because the input digits are copied into
 *	a private workspace before result is reallocated.
 */
static void
div_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result,
		int rscale, bool round)
{
	int			div_ndigits;
	int			res_ndigits;
	int			res_sign;
	int			res_weight;
	int			carry;
	int			borrow;
	int			divisor1;
	int			divisor2;
	NumericDigit *dividend;
	NumericDigit *divisor;
	NumericDigit *res_digits;
	int			i;
	int			j;

	/* copy these values into local vars for speed in inner loop */
	int			var1ndigits = var1->ndigits;
	int			var2ndigits = var2->ndigits;

	/*
	 * First of all division by zero check; we must not be handed an
	 * unnormalized divisor.
	 */
	if (var2ndigits == 0 || var2->digits[0] == 0)
		ereport(ERROR,
				(errcode(ERRCODE_DIVISION_BY_ZERO),
				 errmsg("division by zero")));

	/*
	 * Now result zero check
	 */
	if (var1ndigits == 0)
	{
		zero_var(result);
		result->dscale = rscale;
		return;
	}

	/*
	 * Determine the result sign, weight and number of digits to calculate.
	 * The weight figured here is correct if the emitted quotient has no
	 * leading zero digits; otherwise strip_var() will fix things up.
	 */
	if (var1->sign == var2->sign)
		res_sign = NUMERIC_POS;
	else
		res_sign = NUMERIC_NEG;
	res_weight = var1->weight - var2->weight;
	/* The number of accurate result digits we need to produce: */
	res_ndigits = res_weight + 1 + (rscale + DEC_DIGITS - 1) / DEC_DIGITS;
	/* ... but always at least 1 */
	res_ndigits = Max(res_ndigits, 1);
	/* If rounding needed, figure one more digit to ensure correct result */
	if (round)
		res_ndigits++;

	/*
	 * The working dividend normally requires res_ndigits + var2ndigits
	 * digits, but make it at least var1ndigits so we can load all of var1
	 * into it.  (There will be an additional digit dividend[0] in the
	 * dividend space, but for consistency with Knuth's notation we don't
	 * count that in div_ndigits.)
	 */
	div_ndigits = res_ndigits + var2ndigits;
	div_ndigits = Max(div_ndigits, var1ndigits);

	/*
	 * We need a workspace with room for the working dividend (div_ndigits+1
	 * digits) plus room for the possibly-normalized divisor (var2ndigits
	 * digits).  It is convenient also to have a zero at divisor[0] with the
	 * actual divisor data in divisor[1 .. var2ndigits].  Transferring the
	 * digits into the workspace also allows us to realloc the result (which
	 * might be the same as either input var) before we begin the main loop.
	 * Note that we use palloc0 to ensure that divisor[0], dividend[0], and
	 * any additional dividend positions beyond var1ndigits, start out 0.
	 */
	dividend = (NumericDigit *)
		palloc0((div_ndigits + var2ndigits + 2) * sizeof(NumericDigit));
	divisor = dividend + (div_ndigits + 1);
	memcpy(dividend + 1, var1->digits, var1ndigits * sizeof(NumericDigit));
	memcpy(divisor + 1, var2->digits, var2ndigits * sizeof(NumericDigit));

	/*
	 * Now we can realloc the result to hold the generated quotient digits.
	 */
	alloc_var(result, res_ndigits);
	res_digits = result->digits;

	if (var2ndigits == 1)
	{
		/*
		 * If there's only a single divisor digit, we can use a fast path (cf.
		 * Knuth section 4.3.1 exercise 16).
		 *
		 * carry holds the running remainder, which is always less than
		 * divisor1 (itself < NBASE), so carry * NBASE + dividend[i + 1]
		 * fits in an int.
		 */
		divisor1 = divisor[1];
		carry = 0;
		for (i = 0; i < res_ndigits; i++)
		{
			carry = carry * NBASE + dividend[i + 1];
			res_digits[i] = carry / divisor1;
			carry = carry % divisor1;
		}
	}
	else
	{
		/*
		 * The full multiple-place algorithm is taken from Knuth volume 2,
		 * Algorithm 4.3.1D.
		 *
		 * We need the first divisor digit to be >= NBASE/2.  If it isn't,
		 * make it so by scaling up both the divisor and dividend by the
		 * factor "d".  (The reason for allocating dividend[0] above is to
		 * leave room for possible carry here.)  This is Knuth's step D1;
		 * the first Assert below checks that scaling does not lengthen the
		 * divisor.
		 */
		if (divisor[1] < HALF_NBASE)
		{
			int			d = NBASE / (divisor[1] + 1);

			carry = 0;
			for (i = var2ndigits; i > 0; i--)
			{
				carry += divisor[i] * d;
				divisor[i] = carry % NBASE;
				carry = carry / NBASE;
			}
			Assert(carry == 0);
			carry = 0;
			/* at this point only var1ndigits of dividend can be nonzero */
			for (i = var1ndigits; i >= 0; i--)
			{
				carry += dividend[i] * d;
				dividend[i] = carry % NBASE;
				carry = carry / NBASE;
			}
			Assert(carry == 0);
			Assert(divisor[1] >= HALF_NBASE);
		}
		/* First 2 divisor digits are used repeatedly in main loop */
		divisor1 = divisor[1];
		divisor2 = divisor[2];

		/*
		 * Begin the main loop.  Each iteration of this loop produces the j'th
		 * quotient digit by dividing dividend[j .. j + var2ndigits] by the
		 * divisor; this is essentially the same as the common manual
		 * procedure for long division.
		 */
		for (j = 0; j < res_ndigits; j++)
		{
			/* Estimate quotient digit from the first two dividend digits */
			int			next2digits = dividend[j] * NBASE + dividend[j + 1];
			int			qhat;

			/*
			 * If next2digits are 0, then quotient digit must be 0 and there's
			 * no need to adjust the working dividend.  It's worth testing
			 * here to fall out ASAP when processing trailing zeroes in a
			 * dividend.
			 */
			if (next2digits == 0)
			{
				res_digits[j] = 0;
				continue;
			}

			/* qhat is capped at NBASE - 1, the largest possible digit */
			if (dividend[j] == divisor1)
				qhat = NBASE - 1;
			else
				qhat = next2digits / divisor1;

			/*
			 * Adjust quotient digit if it's too large.  Knuth proves that
			 * after this step, the quotient digit will be either correct or
			 * just one too large.  (Note: it's OK to use dividend[j+2] here
			 * because we know the divisor length is at least 2.)
			 */
			while (divisor2 * qhat >
				   (next2digits - qhat * divisor1) * NBASE + dividend[j + 2])
				qhat--;

			/* As above, need do nothing more when quotient digit is 0 */
			if (qhat > 0)
			{
				/*
				 * Multiply the divisor by qhat, and subtract that from the
				 * working dividend.  "carry" tracks the multiplication,
				 * "borrow" the subtraction (could we fold these together?)
				 */
				carry = 0;
				borrow = 0;
				for (i = var2ndigits; i >= 0; i--)
				{
					carry += divisor[i] * qhat;
					borrow -= carry % NBASE;
					carry = carry / NBASE;
					borrow += dividend[j + i];
					if (borrow < 0)
					{
						dividend[j + i] = borrow + NBASE;
						borrow = -1;
					}
					else
					{
						dividend[j + i] = borrow;
						borrow = 0;
					}
				}
				Assert(carry == 0);

				/*
				 * If we got a borrow out of the top dividend digit, then
				 * indeed qhat was one too large.  Fix it, and add back the
				 * divisor to correct the working dividend.  (Knuth proves
				 * that this will occur only about 3/NBASE of the time; hence,
				 * it's a good idea to test this code with small NBASE to be
				 * sure this section gets exercised.)
				 */
				if (borrow)
				{
					qhat--;
					carry = 0;
					for (i = var2ndigits; i >= 0; i--)
					{
						carry += dividend[j + i] + divisor[i];
						if (carry >= NBASE)
						{
							dividend[j + i] = carry - NBASE;
							carry = 1;
						}
						else
						{
							dividend[j + i] = carry;
							carry = 0;
						}
					}
					/* A carry should occur here to cancel the borrow above */
					Assert(carry == 1);
				}
			}

			/* And we're done with this quotient digit */
			res_digits[j] = qhat;
		}
	}

	pfree(dividend);

	/*
	 * Finally, round or truncate the result to the requested precision.
	 */
	result->weight = res_weight;
	result->sign = res_sign;

	/* Round or truncate to target rscale (and set result->dscale) */
	if (round)
		round_var(result, rscale);
	else
		trunc_var(result, rscale);

	/* Strip leading and trailing zeroes */
	strip_var(result);
}
7295 
7296 
7297 /*
7298  * div_var_fast() -
7299  *
7300  *	This has the same API as div_var, but is implemented using the division
7301  *	algorithm from the "FM" library, rather than Knuth's schoolbook-division
7302  *	approach.  This is significantly faster but can produce inaccurate
7303  *	results, because it sometimes has to propagate rounding to the left,
7304  *	and so we can never be entirely sure that we know the requested digits
7305  *	exactly.  We compute DIV_GUARD_DIGITS extra digits, but there is
7306  *	no certainty that that's enough.  We use this only in the transcendental
7307  *	function calculation routines, where everything is approximate anyway.
7308  *
7309  *	Although we provide a "round" argument for consistency with div_var,
7310  *	it is unwise to use this function with round=false.  In truncation mode
7311  *	it is possible to get a result with no significant digits, for example
7312  *	with rscale=0 we might compute 0.99999... and truncate that to 0 when
7313  *	the correct answer is 1.
7314  */
7315 static void
div_var_fast(const NumericVar * var1,const NumericVar * var2,NumericVar * result,int rscale,bool round)7316 div_var_fast(const NumericVar *var1, const NumericVar *var2,
7317 			 NumericVar *result, int rscale, bool round)
7318 {
7319 	int			div_ndigits;
7320 	int			res_sign;
7321 	int			res_weight;
7322 	int		   *div;
7323 	int			qdigit;
7324 	int			carry;
7325 	int			maxdiv;
7326 	int			newdig;
7327 	NumericDigit *res_digits;
7328 	double		fdividend,
7329 				fdivisor,
7330 				fdivisorinverse,
7331 				fquotient;
7332 	int			qi;
7333 	int			i;
7334 
7335 	/* copy these values into local vars for speed in inner loop */
7336 	int			var1ndigits = var1->ndigits;
7337 	int			var2ndigits = var2->ndigits;
7338 	NumericDigit *var1digits = var1->digits;
7339 	NumericDigit *var2digits = var2->digits;
7340 
7341 	/*
7342 	 * First of all division by zero check; we must not be handed an
7343 	 * unnormalized divisor.
7344 	 */
7345 	if (var2ndigits == 0 || var2digits[0] == 0)
7346 		ereport(ERROR,
7347 				(errcode(ERRCODE_DIVISION_BY_ZERO),
7348 				 errmsg("division by zero")));
7349 
7350 	/*
7351 	 * Now result zero check
7352 	 */
7353 	if (var1ndigits == 0)
7354 	{
7355 		zero_var(result);
7356 		result->dscale = rscale;
7357 		return;
7358 	}
7359 
7360 	/*
7361 	 * Determine the result sign, weight and number of digits to calculate
7362 	 */
7363 	if (var1->sign == var2->sign)
7364 		res_sign = NUMERIC_POS;
7365 	else
7366 		res_sign = NUMERIC_NEG;
7367 	res_weight = var1->weight - var2->weight + 1;
7368 	/* The number of accurate result digits we need to produce: */
7369 	div_ndigits = res_weight + 1 + (rscale + DEC_DIGITS - 1) / DEC_DIGITS;
7370 	/* Add guard digits for roundoff error */
7371 	div_ndigits += DIV_GUARD_DIGITS;
7372 	if (div_ndigits < DIV_GUARD_DIGITS)
7373 		div_ndigits = DIV_GUARD_DIGITS;
7374 	/* Must be at least var1ndigits, too, to simplify data-loading loop */
7375 	if (div_ndigits < var1ndigits)
7376 		div_ndigits = var1ndigits;
7377 
7378 	/*
7379 	 * We do the arithmetic in an array "div[]" of signed int's.  Since
7380 	 * INT_MAX is noticeably larger than NBASE*NBASE, this gives us headroom
7381 	 * to avoid normalizing carries immediately.
7382 	 *
7383 	 * We start with div[] containing one zero digit followed by the
7384 	 * dividend's digits (plus appended zeroes to reach the desired precision
7385 	 * including guard digits).  Each step of the main loop computes an
7386 	 * (approximate) quotient digit and stores it into div[], removing one
7387 	 * position of dividend space.  A final pass of carry propagation takes
7388 	 * care of any mistaken quotient digits.
7389 	 */
7390 	div = (int *) palloc0((div_ndigits + 1) * sizeof(int));
7391 	for (i = 0; i < var1ndigits; i++)
7392 		div[i + 1] = var1digits[i];
7393 
7394 	/*
7395 	 * We estimate each quotient digit using floating-point arithmetic, taking
7396 	 * the first four digits of the (current) dividend and divisor.  This must
7397 	 * be float to avoid overflow.  The quotient digits will generally be off
7398 	 * by no more than one from the exact answer.
7399 	 */
7400 	fdivisor = (double) var2digits[0];
7401 	for (i = 1; i < 4; i++)
7402 	{
7403 		fdivisor *= NBASE;
7404 		if (i < var2ndigits)
7405 			fdivisor += (double) var2digits[i];
7406 	}
7407 	fdivisorinverse = 1.0 / fdivisor;
7408 
7409 	/*
7410 	 * maxdiv tracks the maximum possible absolute value of any div[] entry;
7411 	 * when this threatens to exceed INT_MAX, we take the time to propagate
7412 	 * carries.  Furthermore, we need to ensure that overflow doesn't occur
7413 	 * during the carry propagation passes either.  The carry values may have
7414 	 * an absolute value as high as INT_MAX/NBASE + 1, so really we must
7415 	 * normalize when digits threaten to exceed INT_MAX - INT_MAX/NBASE - 1.
7416 	 *
7417 	 * To avoid overflow in maxdiv itself, it represents the max absolute
7418 	 * value divided by NBASE-1, ie, at the top of the loop it is known that
7419 	 * no div[] entry has an absolute value exceeding maxdiv * (NBASE-1).
7420 	 *
7421 	 * Actually, though, that holds good only for div[] entries after div[qi];
7422 	 * the adjustment done at the bottom of the loop may cause div[qi + 1] to
7423 	 * exceed the maxdiv limit, so that div[qi] in the next iteration is
7424 	 * beyond the limit.  This does not cause problems, as explained below.
7425 	 */
7426 	maxdiv = 1;
7427 
7428 	/*
7429 	 * Outer loop computes next quotient digit, which will go into div[qi]
7430 	 */
7431 	for (qi = 0; qi < div_ndigits; qi++)
7432 	{
7433 		/* Approximate the current dividend value */
7434 		fdividend = (double) div[qi];
7435 		for (i = 1; i < 4; i++)
7436 		{
7437 			fdividend *= NBASE;
7438 			if (qi + i <= div_ndigits)
7439 				fdividend += (double) div[qi + i];
7440 		}
7441 		/* Compute the (approximate) quotient digit */
7442 		fquotient = fdividend * fdivisorinverse;
7443 		qdigit = (fquotient >= 0.0) ? ((int) fquotient) :
7444 			(((int) fquotient) - 1);	/* truncate towards -infinity */
7445 
7446 		if (qdigit != 0)
7447 		{
7448 			/* Do we need to normalize now? */
7449 			maxdiv += Abs(qdigit);
7450 			if (maxdiv > (INT_MAX - INT_MAX / NBASE - 1) / (NBASE - 1))
7451 			{
7452 				/* Yes, do it */
7453 				carry = 0;
7454 				for (i = div_ndigits; i > qi; i--)
7455 				{
7456 					newdig = div[i] + carry;
7457 					if (newdig < 0)
7458 					{
7459 						carry = -((-newdig - 1) / NBASE) - 1;
7460 						newdig -= carry * NBASE;
7461 					}
7462 					else if (newdig >= NBASE)
7463 					{
7464 						carry = newdig / NBASE;
7465 						newdig -= carry * NBASE;
7466 					}
7467 					else
7468 						carry = 0;
7469 					div[i] = newdig;
7470 				}
7471 				newdig = div[qi] + carry;
7472 				div[qi] = newdig;
7473 
7474 				/*
7475 				 * All the div[] digits except possibly div[qi] are now in the
7476 				 * range 0..NBASE-1.  We do not need to consider div[qi] in
7477 				 * the maxdiv value anymore, so we can reset maxdiv to 1.
7478 				 */
7479 				maxdiv = 1;
7480 
7481 				/*
7482 				 * Recompute the quotient digit since new info may have
7483 				 * propagated into the top four dividend digits
7484 				 */
7485 				fdividend = (double) div[qi];
7486 				for (i = 1; i < 4; i++)
7487 				{
7488 					fdividend *= NBASE;
7489 					if (qi + i <= div_ndigits)
7490 						fdividend += (double) div[qi + i];
7491 				}
7492 				/* Compute the (approximate) quotient digit */
7493 				fquotient = fdividend * fdivisorinverse;
7494 				qdigit = (fquotient >= 0.0) ? ((int) fquotient) :
7495 					(((int) fquotient) - 1);	/* truncate towards -infinity */
7496 				maxdiv += Abs(qdigit);
7497 			}
7498 
7499 			/*
7500 			 * Subtract off the appropriate multiple of the divisor.
7501 			 *
7502 			 * The digits beyond div[qi] cannot overflow, because we know they
7503 			 * will fall within the maxdiv limit.  As for div[qi] itself, note
7504 			 * that qdigit is approximately trunc(div[qi] / vardigits[0]),
7505 			 * which would make the new value simply div[qi] mod vardigits[0].
7506 			 * The lower-order terms in qdigit can change this result by not
7507 			 * more than about twice INT_MAX/NBASE, so overflow is impossible.
7508 			 */
7509 			if (qdigit != 0)
7510 			{
7511 				int			istop = Min(var2ndigits, div_ndigits - qi + 1);
7512 
7513 				for (i = 0; i < istop; i++)
7514 					div[qi + i] -= qdigit * var2digits[i];
7515 			}
7516 		}
7517 
7518 		/*
7519 		 * The dividend digit we are about to replace might still be nonzero.
7520 		 * Fold it into the next digit position.
7521 		 *
7522 		 * There is no risk of overflow here, although proving that requires
7523 		 * some care.  Much as with the argument for div[qi] not overflowing,
7524 		 * if we consider the first two terms in the numerator and denominator
7525 		 * of qdigit, we can see that the final value of div[qi + 1] will be
7526 		 * approximately a remainder mod (vardigits[0]*NBASE + vardigits[1]).
7527 		 * Accounting for the lower-order terms is a bit complicated but ends
7528 		 * up adding not much more than INT_MAX/NBASE to the possible range.
7529 		 * Thus, div[qi + 1] cannot overflow here, and in its role as div[qi]
7530 		 * in the next loop iteration, it can't be large enough to cause
7531 		 * overflow in the carry propagation step (if any), either.
7532 		 *
7533 		 * But having said that: div[qi] can be more than INT_MAX/NBASE, as
7534 		 * noted above, which means that the product div[qi] * NBASE *can*
7535 		 * overflow.  When that happens, adding it to div[qi + 1] will always
7536 		 * cause a canceling overflow so that the end result is correct.  We
7537 		 * could avoid the intermediate overflow by doing the multiplication
7538 		 * and addition in int64 arithmetic, but so far there appears no need.
7539 		 */
7540 		div[qi + 1] += div[qi] * NBASE;
7541 
7542 		div[qi] = qdigit;
7543 	}
7544 
7545 	/*
7546 	 * Approximate and store the last quotient digit (div[div_ndigits])
7547 	 */
7548 	fdividend = (double) div[qi];
7549 	for (i = 1; i < 4; i++)
7550 		fdividend *= NBASE;
7551 	fquotient = fdividend * fdivisorinverse;
7552 	qdigit = (fquotient >= 0.0) ? ((int) fquotient) :
7553 		(((int) fquotient) - 1);	/* truncate towards -infinity */
7554 	div[qi] = qdigit;
7555 
7556 	/*
7557 	 * Because the quotient digits might be off by one, some of them might be
7558 	 * -1 or NBASE at this point.  The represented value is correct in a
7559 	 * mathematical sense, but it doesn't look right.  We do a final carry
7560 	 * propagation pass to normalize the digits, which we combine with storing
7561 	 * the result digits into the output.  Note that this is still done at
7562 	 * full precision w/guard digits.
7563 	 */
7564 	alloc_var(result, div_ndigits + 1);
7565 	res_digits = result->digits;
7566 	carry = 0;
7567 	for (i = div_ndigits; i >= 0; i--)
7568 	{
7569 		newdig = div[i] + carry;
7570 		if (newdig < 0)
7571 		{
7572 			carry = -((-newdig - 1) / NBASE) - 1;
7573 			newdig -= carry * NBASE;
7574 		}
7575 		else if (newdig >= NBASE)
7576 		{
7577 			carry = newdig / NBASE;
7578 			newdig -= carry * NBASE;
7579 		}
7580 		else
7581 			carry = 0;
7582 		res_digits[i] = newdig;
7583 	}
7584 	Assert(carry == 0);
7585 
7586 	pfree(div);
7587 
7588 	/*
7589 	 * Finally, round the result to the requested precision.
7590 	 */
7591 	result->weight = res_weight;
7592 	result->sign = res_sign;
7593 
7594 	/* Round to target rscale (and set result->dscale) */
7595 	if (round)
7596 		round_var(result, rscale);
7597 	else
7598 		trunc_var(result, rscale);
7599 
7600 	/* Strip leading and trailing zeroes */
7601 	strip_var(result);
7602 }
7603 
7604 
7605 /*
7606  * Default scale selection for division
7607  *
7608  * Returns the appropriate result scale for the division result.
7609  */
7610 static int
select_div_scale(const NumericVar * var1,const NumericVar * var2)7611 select_div_scale(const NumericVar *var1, const NumericVar *var2)
7612 {
7613 	int			weight1,
7614 				weight2,
7615 				qweight,
7616 				i;
7617 	NumericDigit firstdigit1,
7618 				firstdigit2;
7619 	int			rscale;
7620 
7621 	/*
7622 	 * The result scale of a division isn't specified in any SQL standard. For
7623 	 * PostgreSQL we select a result scale that will give at least
7624 	 * NUMERIC_MIN_SIG_DIGITS significant digits, so that numeric gives a
7625 	 * result no less accurate than float8; but use a scale not less than
7626 	 * either input's display scale.
7627 	 */
7628 
7629 	/* Get the actual (normalized) weight and first digit of each input */
7630 
7631 	weight1 = 0;				/* values to use if var1 is zero */
7632 	firstdigit1 = 0;
7633 	for (i = 0; i < var1->ndigits; i++)
7634 	{
7635 		firstdigit1 = var1->digits[i];
7636 		if (firstdigit1 != 0)
7637 		{
7638 			weight1 = var1->weight - i;
7639 			break;
7640 		}
7641 	}
7642 
7643 	weight2 = 0;				/* values to use if var2 is zero */
7644 	firstdigit2 = 0;
7645 	for (i = 0; i < var2->ndigits; i++)
7646 	{
7647 		firstdigit2 = var2->digits[i];
7648 		if (firstdigit2 != 0)
7649 		{
7650 			weight2 = var2->weight - i;
7651 			break;
7652 		}
7653 	}
7654 
7655 	/*
7656 	 * Estimate weight of quotient.  If the two first digits are equal, we
7657 	 * can't be sure, but assume that var1 is less than var2.
7658 	 */
7659 	qweight = weight1 - weight2;
7660 	if (firstdigit1 <= firstdigit2)
7661 		qweight--;
7662 
7663 	/* Select result scale */
7664 	rscale = NUMERIC_MIN_SIG_DIGITS - qweight * DEC_DIGITS;
7665 	rscale = Max(rscale, var1->dscale);
7666 	rscale = Max(rscale, var2->dscale);
7667 	rscale = Max(rscale, NUMERIC_MIN_DISPLAY_SCALE);
7668 	rscale = Min(rscale, NUMERIC_MAX_DISPLAY_SCALE);
7669 
7670 	return rscale;
7671 }
7672 
7673 
7674 /*
7675  * mod_var() -
7676  *
7677  *	Calculate the modulo of two numerics at variable level
7678  */
7679 static void
mod_var(const NumericVar * var1,const NumericVar * var2,NumericVar * result)7680 mod_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result)
7681 {
7682 	NumericVar	tmp;
7683 
7684 	init_var(&tmp);
7685 
7686 	/* ---------
7687 	 * We do this using the equation
7688 	 *		mod(x,y) = x - trunc(x/y)*y
7689 	 * div_var can be persuaded to give us trunc(x/y) directly.
7690 	 * ----------
7691 	 */
7692 	div_var(var1, var2, &tmp, 0, false);
7693 
7694 	mul_var(var2, &tmp, &tmp, var2->dscale);
7695 
7696 	sub_var(var1, &tmp, result);
7697 
7698 	free_var(&tmp);
7699 }
7700 
7701 
7702 /*
7703  * ceil_var() -
7704  *
7705  *	Return the smallest integer greater than or equal to the argument
7706  *	on variable level
7707  */
7708 static void
ceil_var(const NumericVar * var,NumericVar * result)7709 ceil_var(const NumericVar *var, NumericVar *result)
7710 {
7711 	NumericVar	tmp;
7712 
7713 	init_var(&tmp);
7714 	set_var_from_var(var, &tmp);
7715 
7716 	trunc_var(&tmp, 0);
7717 
7718 	if (var->sign == NUMERIC_POS && cmp_var(var, &tmp) != 0)
7719 		add_var(&tmp, &const_one, &tmp);
7720 
7721 	set_var_from_var(&tmp, result);
7722 	free_var(&tmp);
7723 }
7724 
7725 
7726 /*
7727  * floor_var() -
7728  *
7729  *	Return the largest integer equal to or less than the argument
7730  *	on variable level
7731  */
7732 static void
floor_var(const NumericVar * var,NumericVar * result)7733 floor_var(const NumericVar *var, NumericVar *result)
7734 {
7735 	NumericVar	tmp;
7736 
7737 	init_var(&tmp);
7738 	set_var_from_var(var, &tmp);
7739 
7740 	trunc_var(&tmp, 0);
7741 
7742 	if (var->sign == NUMERIC_NEG && cmp_var(var, &tmp) != 0)
7743 		sub_var(&tmp, &const_one, &tmp);
7744 
7745 	set_var_from_var(&tmp, result);
7746 	free_var(&tmp);
7747 }
7748 
7749 
7750 /*
7751  * sqrt_var() -
7752  *
7753  *	Compute the square root of x using Newton's algorithm
7754  */
7755 static void
sqrt_var(const NumericVar * arg,NumericVar * result,int rscale)7756 sqrt_var(const NumericVar *arg, NumericVar *result, int rscale)
7757 {
7758 	NumericVar	tmp_arg;
7759 	NumericVar	tmp_val;
7760 	NumericVar	last_val;
7761 	int			local_rscale;
7762 	int			stat;
7763 
7764 	local_rscale = rscale + 8;
7765 
7766 	stat = cmp_var(arg, &const_zero);
7767 	if (stat == 0)
7768 	{
7769 		zero_var(result);
7770 		result->dscale = rscale;
7771 		return;
7772 	}
7773 
7774 	/*
7775 	 * SQL2003 defines sqrt() in terms of power, so we need to emit the right
7776 	 * SQLSTATE error code if the operand is negative.
7777 	 */
7778 	if (stat < 0)
7779 		ereport(ERROR,
7780 				(errcode(ERRCODE_INVALID_ARGUMENT_FOR_POWER_FUNCTION),
7781 				 errmsg("cannot take square root of a negative number")));
7782 
7783 	init_var(&tmp_arg);
7784 	init_var(&tmp_val);
7785 	init_var(&last_val);
7786 
7787 	/* Copy arg in case it is the same var as result */
7788 	set_var_from_var(arg, &tmp_arg);
7789 
7790 	/*
7791 	 * Initialize the result to the first guess
7792 	 */
7793 	alloc_var(result, 1);
7794 	result->digits[0] = tmp_arg.digits[0] / 2;
7795 	if (result->digits[0] == 0)
7796 		result->digits[0] = 1;
7797 	result->weight = tmp_arg.weight / 2;
7798 	result->sign = NUMERIC_POS;
7799 
7800 	set_var_from_var(result, &last_val);
7801 
7802 	for (;;)
7803 	{
7804 		div_var_fast(&tmp_arg, result, &tmp_val, local_rscale, true);
7805 
7806 		add_var(result, &tmp_val, result);
7807 		mul_var(result, &const_zero_point_five, result, local_rscale);
7808 
7809 		if (cmp_var(&last_val, result) == 0)
7810 			break;
7811 		set_var_from_var(result, &last_val);
7812 	}
7813 
7814 	free_var(&last_val);
7815 	free_var(&tmp_val);
7816 	free_var(&tmp_arg);
7817 
7818 	/* Round to requested precision */
7819 	round_var(result, rscale);
7820 }
7821 
7822 
7823 /*
7824  * exp_var() -
7825  *
7826  *	Raise e to the power of x, computed to rscale fractional digits
7827  */
7828 static void
exp_var(const NumericVar * arg,NumericVar * result,int rscale)7829 exp_var(const NumericVar *arg, NumericVar *result, int rscale)
7830 {
7831 	NumericVar	x;
7832 	NumericVar	elem;
7833 	NumericVar	ni;
7834 	double		val;
7835 	int			dweight;
7836 	int			ndiv2;
7837 	int			sig_digits;
7838 	int			local_rscale;
7839 
7840 	init_var(&x);
7841 	init_var(&elem);
7842 	init_var(&ni);
7843 
7844 	set_var_from_var(arg, &x);
7845 
7846 	/*
7847 	 * Estimate the dweight of the result using floating point arithmetic, so
7848 	 * that we can choose an appropriate local rscale for the calculation.
7849 	 */
7850 	val = numericvar_to_double_no_overflow(&x);
7851 
7852 	/* Guard against overflow/underflow */
7853 	/* If you change this limit, see also power_var()'s limit */
7854 	if (Abs(val) >= NUMERIC_MAX_RESULT_SCALE * 3)
7855 	{
7856 		if (val > 0)
7857 			ereport(ERROR,
7858 					(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
7859 					 errmsg("value overflows numeric format")));
7860 		zero_var(result);
7861 		result->dscale = rscale;
7862 		return;
7863 	}
7864 
7865 	/* decimal weight = log10(e^x) = x * log10(e) */
7866 	dweight = (int) (val * 0.434294481903252);
7867 
7868 	/*
7869 	 * Reduce x to the range -0.01 <= x <= 0.01 (approximately) by dividing by
7870 	 * 2^n, to improve the convergence rate of the Taylor series.
7871 	 */
7872 	if (Abs(val) > 0.01)
7873 	{
7874 		NumericVar	tmp;
7875 
7876 		init_var(&tmp);
7877 		set_var_from_var(&const_two, &tmp);
7878 
7879 		ndiv2 = 1;
7880 		val /= 2;
7881 
7882 		while (Abs(val) > 0.01)
7883 		{
7884 			ndiv2++;
7885 			val /= 2;
7886 			add_var(&tmp, &tmp, &tmp);
7887 		}
7888 
7889 		local_rscale = x.dscale + ndiv2;
7890 		div_var_fast(&x, &tmp, &x, local_rscale, true);
7891 
7892 		free_var(&tmp);
7893 	}
7894 	else
7895 		ndiv2 = 0;
7896 
7897 	/*
7898 	 * Set the scale for the Taylor series expansion.  The final result has
7899 	 * (dweight + rscale + 1) significant digits.  In addition, we have to
7900 	 * raise the Taylor series result to the power 2^ndiv2, which introduces
7901 	 * an error of up to around log10(2^ndiv2) digits, so work with this many
7902 	 * extra digits of precision (plus a few more for good measure).
7903 	 */
7904 	sig_digits = 1 + dweight + rscale + (int) (ndiv2 * 0.301029995663981);
7905 	sig_digits = Max(sig_digits, 0) + 8;
7906 
7907 	local_rscale = sig_digits - 1;
7908 
7909 	/*
7910 	 * Use the Taylor series
7911 	 *
7912 	 * exp(x) = 1 + x + x^2/2! + x^3/3! + ...
7913 	 *
7914 	 * Given the limited range of x, this should converge reasonably quickly.
7915 	 * We run the series until the terms fall below the local_rscale limit.
7916 	 */
7917 	add_var(&const_one, &x, result);
7918 
7919 	mul_var(&x, &x, &elem, local_rscale);
7920 	set_var_from_var(&const_two, &ni);
7921 	div_var_fast(&elem, &ni, &elem, local_rscale, true);
7922 
7923 	while (elem.ndigits != 0)
7924 	{
7925 		add_var(result, &elem, result);
7926 
7927 		mul_var(&elem, &x, &elem, local_rscale);
7928 		add_var(&ni, &const_one, &ni);
7929 		div_var_fast(&elem, &ni, &elem, local_rscale, true);
7930 	}
7931 
7932 	/*
7933 	 * Compensate for the argument range reduction.  Since the weight of the
7934 	 * result doubles with each multiplication, we can reduce the local rscale
7935 	 * as we proceed.
7936 	 */
7937 	while (ndiv2-- > 0)
7938 	{
7939 		local_rscale = sig_digits - result->weight * 2 * DEC_DIGITS;
7940 		local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
7941 		mul_var(result, result, result, local_rscale);
7942 	}
7943 
7944 	/* Round to requested rscale */
7945 	round_var(result, rscale);
7946 
7947 	free_var(&x);
7948 	free_var(&elem);
7949 	free_var(&ni);
7950 }
7951 
7952 
7953 /*
7954  * Estimate the dweight of the most significant decimal digit of the natural
7955  * logarithm of a number.
7956  *
7957  * Essentially, we're approximating log10(abs(ln(var))).  This is used to
7958  * determine the appropriate rscale when computing natural logarithms.
7959  */
7960 static int
estimate_ln_dweight(const NumericVar * var)7961 estimate_ln_dweight(const NumericVar *var)
7962 {
7963 	int			ln_dweight;
7964 
7965 	if (cmp_var(var, &const_zero_point_nine) >= 0 &&
7966 		cmp_var(var, &const_one_point_one) <= 0)
7967 	{
7968 		/*
7969 		 * 0.9 <= var <= 1.1
7970 		 *
7971 		 * ln(var) has a negative weight (possibly very large).  To get a
7972 		 * reasonably accurate result, estimate it using ln(1+x) ~= x.
7973 		 */
7974 		NumericVar	x;
7975 
7976 		init_var(&x);
7977 		sub_var(var, &const_one, &x);
7978 
7979 		if (x.ndigits > 0)
7980 		{
7981 			/* Use weight of most significant decimal digit of x */
7982 			ln_dweight = x.weight * DEC_DIGITS + (int) log10(x.digits[0]);
7983 		}
7984 		else
7985 		{
7986 			/* x = 0.  Since ln(1) = 0 exactly, we don't need extra digits */
7987 			ln_dweight = 0;
7988 		}
7989 
7990 		free_var(&x);
7991 	}
7992 	else
7993 	{
7994 		/*
7995 		 * Estimate the logarithm using the first couple of digits from the
7996 		 * input number.  This will give an accurate result whenever the input
7997 		 * is not too close to 1.
7998 		 */
7999 		if (var->ndigits > 0)
8000 		{
8001 			int			digits;
8002 			int			dweight;
8003 			double		ln_var;
8004 
8005 			digits = var->digits[0];
8006 			dweight = var->weight * DEC_DIGITS;
8007 
8008 			if (var->ndigits > 1)
8009 			{
8010 				digits = digits * NBASE + var->digits[1];
8011 				dweight -= DEC_DIGITS;
8012 			}
8013 
8014 			/*----------
8015 			 * We have var ~= digits * 10^dweight
8016 			 * so ln(var) ~= ln(digits) + dweight * ln(10)
8017 			 *----------
8018 			 */
8019 			ln_var = log((double) digits) + dweight * 2.302585092994046;
8020 			ln_dweight = (int) log10(Abs(ln_var));
8021 		}
8022 		else
8023 		{
8024 			/* Caller should fail on ln(0), but for the moment return zero */
8025 			ln_dweight = 0;
8026 		}
8027 	}
8028 
8029 	return ln_dweight;
8030 }
8031 
8032 
8033 /*
8034  * ln_var() -
8035  *
8036  *	Compute the natural log of x
8037  */
8038 static void
ln_var(const NumericVar * arg,NumericVar * result,int rscale)8039 ln_var(const NumericVar *arg, NumericVar *result, int rscale)
8040 {
8041 	NumericVar	x;
8042 	NumericVar	xx;
8043 	NumericVar	ni;
8044 	NumericVar	elem;
8045 	NumericVar	fact;
8046 	int			local_rscale;
8047 	int			cmp;
8048 
8049 	cmp = cmp_var(arg, &const_zero);
8050 	if (cmp == 0)
8051 		ereport(ERROR,
8052 				(errcode(ERRCODE_INVALID_ARGUMENT_FOR_LOG),
8053 				 errmsg("cannot take logarithm of zero")));
8054 	else if (cmp < 0)
8055 		ereport(ERROR,
8056 				(errcode(ERRCODE_INVALID_ARGUMENT_FOR_LOG),
8057 				 errmsg("cannot take logarithm of a negative number")));
8058 
8059 	init_var(&x);
8060 	init_var(&xx);
8061 	init_var(&ni);
8062 	init_var(&elem);
8063 	init_var(&fact);
8064 
8065 	set_var_from_var(arg, &x);
8066 	set_var_from_var(&const_two, &fact);
8067 
8068 	/*
8069 	 * Reduce input into range 0.9 < x < 1.1 with repeated sqrt() operations.
8070 	 *
8071 	 * The final logarithm will have up to around rscale+6 significant digits.
8072 	 * Each sqrt() will roughly halve the weight of x, so adjust the local
8073 	 * rscale as we work so that we keep this many significant digits at each
8074 	 * step (plus a few more for good measure).
8075 	 */
8076 	while (cmp_var(&x, &const_zero_point_nine) <= 0)
8077 	{
8078 		local_rscale = rscale - x.weight * DEC_DIGITS / 2 + 8;
8079 		local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
8080 		sqrt_var(&x, &x, local_rscale);
8081 		mul_var(&fact, &const_two, &fact, 0);
8082 	}
8083 	while (cmp_var(&x, &const_one_point_one) >= 0)
8084 	{
8085 		local_rscale = rscale - x.weight * DEC_DIGITS / 2 + 8;
8086 		local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
8087 		sqrt_var(&x, &x, local_rscale);
8088 		mul_var(&fact, &const_two, &fact, 0);
8089 	}
8090 
8091 	/*
8092 	 * We use the Taylor series for 0.5 * ln((1+z)/(1-z)),
8093 	 *
8094 	 * z + z^3/3 + z^5/5 + ...
8095 	 *
8096 	 * where z = (x-1)/(x+1) is in the range (approximately) -0.053 .. 0.048
8097 	 * due to the above range-reduction of x.
8098 	 *
8099 	 * The convergence of this is not as fast as one would like, but is
8100 	 * tolerable given that z is small.
8101 	 */
8102 	local_rscale = rscale + 8;
8103 
8104 	sub_var(&x, &const_one, result);
8105 	add_var(&x, &const_one, &elem);
8106 	div_var_fast(result, &elem, result, local_rscale, true);
8107 	set_var_from_var(result, &xx);
8108 	mul_var(result, result, &x, local_rscale);
8109 
8110 	set_var_from_var(&const_one, &ni);
8111 
8112 	for (;;)
8113 	{
8114 		add_var(&ni, &const_two, &ni);
8115 		mul_var(&xx, &x, &xx, local_rscale);
8116 		div_var_fast(&xx, &ni, &elem, local_rscale, true);
8117 
8118 		if (elem.ndigits == 0)
8119 			break;
8120 
8121 		add_var(result, &elem, result);
8122 
8123 		if (elem.weight < (result->weight - local_rscale * 2 / DEC_DIGITS))
8124 			break;
8125 	}
8126 
8127 	/* Compensate for argument range reduction, round to requested rscale */
8128 	mul_var(result, &fact, result, rscale);
8129 
8130 	free_var(&x);
8131 	free_var(&xx);
8132 	free_var(&ni);
8133 	free_var(&elem);
8134 	free_var(&fact);
8135 }
8136 
8137 
8138 /*
8139  * log_var() -
8140  *
8141  *	Compute the logarithm of num in a given base.
8142  *
8143  *	Note: this routine chooses dscale of the result.
8144  */
8145 static void
log_var(const NumericVar * base,const NumericVar * num,NumericVar * result)8146 log_var(const NumericVar *base, const NumericVar *num, NumericVar *result)
8147 {
8148 	NumericVar	ln_base;
8149 	NumericVar	ln_num;
8150 	int			ln_base_dweight;
8151 	int			ln_num_dweight;
8152 	int			result_dweight;
8153 	int			rscale;
8154 	int			ln_base_rscale;
8155 	int			ln_num_rscale;
8156 
8157 	init_var(&ln_base);
8158 	init_var(&ln_num);
8159 
8160 	/* Estimated dweights of ln(base), ln(num) and the final result */
8161 	ln_base_dweight = estimate_ln_dweight(base);
8162 	ln_num_dweight = estimate_ln_dweight(num);
8163 	result_dweight = ln_num_dweight - ln_base_dweight;
8164 
8165 	/*
8166 	 * Select the scale of the result so that it will have at least
8167 	 * NUMERIC_MIN_SIG_DIGITS significant digits and is not less than either
8168 	 * input's display scale.
8169 	 */
8170 	rscale = NUMERIC_MIN_SIG_DIGITS - result_dweight;
8171 	rscale = Max(rscale, base->dscale);
8172 	rscale = Max(rscale, num->dscale);
8173 	rscale = Max(rscale, NUMERIC_MIN_DISPLAY_SCALE);
8174 	rscale = Min(rscale, NUMERIC_MAX_DISPLAY_SCALE);
8175 
8176 	/*
8177 	 * Set the scales for ln(base) and ln(num) so that they each have more
8178 	 * significant digits than the final result.
8179 	 */
8180 	ln_base_rscale = rscale + result_dweight - ln_base_dweight + 8;
8181 	ln_base_rscale = Max(ln_base_rscale, NUMERIC_MIN_DISPLAY_SCALE);
8182 
8183 	ln_num_rscale = rscale + result_dweight - ln_num_dweight + 8;
8184 	ln_num_rscale = Max(ln_num_rscale, NUMERIC_MIN_DISPLAY_SCALE);
8185 
8186 	/* Form natural logarithms */
8187 	ln_var(base, &ln_base, ln_base_rscale);
8188 	ln_var(num, &ln_num, ln_num_rscale);
8189 
8190 	/* Divide and round to the required scale */
8191 	div_var_fast(&ln_num, &ln_base, result, rscale, true);
8192 
8193 	free_var(&ln_num);
8194 	free_var(&ln_base);
8195 }
8196 
8197 
8198 /*
8199  * power_var() -
8200  *
8201  *	Raise base to the power of exp
8202  *
8203  *	Note: this routine chooses dscale of the result.
8204  */
8205 static void
power_var(const NumericVar * base,const NumericVar * exp,NumericVar * result)8206 power_var(const NumericVar *base, const NumericVar *exp, NumericVar *result)
8207 {
8208 	int			res_sign;
8209 	NumericVar	abs_base;
8210 	NumericVar	ln_base;
8211 	NumericVar	ln_num;
8212 	int			ln_dweight;
8213 	int			rscale;
8214 	int			sig_digits;
8215 	int			local_rscale;
8216 	double		val;
8217 
8218 	/* If exp can be represented as an integer, use power_var_int */
8219 	if (exp->ndigits == 0 || exp->ndigits <= exp->weight + 1)
8220 	{
8221 		/* exact integer, but does it fit in int? */
8222 		int64		expval64;
8223 
8224 		if (numericvar_to_int64(exp, &expval64))
8225 		{
8226 			if (expval64 >= PG_INT32_MIN && expval64 <= PG_INT32_MAX)
8227 			{
8228 				/* Okay, select rscale */
8229 				rscale = NUMERIC_MIN_SIG_DIGITS;
8230 				rscale = Max(rscale, base->dscale);
8231 				rscale = Max(rscale, NUMERIC_MIN_DISPLAY_SCALE);
8232 				rscale = Min(rscale, NUMERIC_MAX_DISPLAY_SCALE);
8233 
8234 				power_var_int(base, (int) expval64, result, rscale);
8235 				return;
8236 			}
8237 		}
8238 	}
8239 
8240 	/*
8241 	 * This avoids log(0) for cases of 0 raised to a non-integer.  0 ^ 0 is
8242 	 * handled by power_var_int().
8243 	 */
8244 	if (cmp_var(base, &const_zero) == 0)
8245 	{
8246 		set_var_from_var(&const_zero, result);
8247 		result->dscale = NUMERIC_MIN_SIG_DIGITS;	/* no need to round */
8248 		return;
8249 	}
8250 
8251 	init_var(&abs_base);
8252 	init_var(&ln_base);
8253 	init_var(&ln_num);
8254 
8255 	/*
8256 	 * If base is negative, insist that exp be an integer.  The result is then
8257 	 * positive if exp is even and negative if exp is odd.
8258 	 */
8259 	if (base->sign == NUMERIC_NEG)
8260 	{
8261 		/*
8262 		 * Check that exp is an integer.  This error code is defined by the
8263 		 * SQL standard, and matches other errors in numeric_power().
8264 		 */
8265 		if (exp->ndigits > 0 && exp->ndigits > exp->weight + 1)
8266 			ereport(ERROR,
8267 					(errcode(ERRCODE_INVALID_ARGUMENT_FOR_POWER_FUNCTION),
8268 					 errmsg("a negative number raised to a non-integer power yields a complex result")));
8269 
8270 		/* Test if exp is odd or even */
8271 		if (exp->ndigits > 0 && exp->ndigits == exp->weight + 1 &&
8272 			(exp->digits[exp->ndigits - 1] & 1))
8273 			res_sign = NUMERIC_NEG;
8274 		else
8275 			res_sign = NUMERIC_POS;
8276 
8277 		/* Then work with abs(base) below */
8278 		set_var_from_var(base, &abs_base);
8279 		abs_base.sign = NUMERIC_POS;
8280 		base = &abs_base;
8281 	}
8282 	else
8283 		res_sign = NUMERIC_POS;
8284 
8285 	/*----------
8286 	 * Decide on the scale for the ln() calculation.  For this we need an
8287 	 * estimate of the weight of the result, which we obtain by doing an
8288 	 * initial low-precision calculation of exp * ln(base).
8289 	 *
8290 	 * We want result = e ^ (exp * ln(base))
8291 	 * so result dweight = log10(result) = exp * ln(base) * log10(e)
8292 	 *
8293 	 * We also perform a crude overflow test here so that we can exit early if
8294 	 * the full-precision result is sure to overflow, and to guard against
8295 	 * integer overflow when determining the scale for the real calculation.
8296 	 * exp_var() supports inputs up to NUMERIC_MAX_RESULT_SCALE * 3, so the
8297 	 * result will overflow if exp * ln(base) >= NUMERIC_MAX_RESULT_SCALE * 3.
8298 	 * Since the values here are only approximations, we apply a small fuzz
8299 	 * factor to this overflow test and let exp_var() determine the exact
8300 	 * overflow threshold so that it is consistent for all inputs.
8301 	 *----------
8302 	 */
8303 	ln_dweight = estimate_ln_dweight(base);
8304 
8305 	/*
8306 	 * Set the scale for the low-precision calculation, computing ln(base) to
8307 	 * around 8 significant digits.  Note that ln_dweight may be as small as
8308 	 * -SHRT_MAX, so the scale may exceed NUMERIC_MAX_DISPLAY_SCALE here.
8309 	 */
8310 	local_rscale = 8 - ln_dweight;
8311 	local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
8312 
8313 	ln_var(base, &ln_base, local_rscale);
8314 
8315 	mul_var(&ln_base, exp, &ln_num, local_rscale);
8316 
8317 	val = numericvar_to_double_no_overflow(&ln_num);
8318 
8319 	/* initial overflow/underflow test with fuzz factor */
8320 	if (Abs(val) > NUMERIC_MAX_RESULT_SCALE * 3.01)
8321 	{
8322 		if (val > 0)
8323 			ereport(ERROR,
8324 					(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
8325 					 errmsg("value overflows numeric format")));
8326 		zero_var(result);
8327 		result->dscale = NUMERIC_MAX_DISPLAY_SCALE;
8328 		return;
8329 	}
8330 
8331 	val *= 0.434294481903252;	/* approximate decimal result weight */
8332 
8333 	/* choose the result scale */
8334 	rscale = NUMERIC_MIN_SIG_DIGITS - (int) val;
8335 	rscale = Max(rscale, base->dscale);
8336 	rscale = Max(rscale, exp->dscale);
8337 	rscale = Max(rscale, NUMERIC_MIN_DISPLAY_SCALE);
8338 	rscale = Min(rscale, NUMERIC_MAX_DISPLAY_SCALE);
8339 
8340 	/* significant digits required in the result */
8341 	sig_digits = rscale + (int) val;
8342 	sig_digits = Max(sig_digits, 0);
8343 
8344 	/* set the scale for the real exp * ln(base) calculation */
8345 	local_rscale = sig_digits - ln_dweight + 8;
8346 	local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
8347 
8348 	/* and do the real calculation */
8349 
8350 	ln_var(base, &ln_base, local_rscale);
8351 
8352 	mul_var(&ln_base, exp, &ln_num, local_rscale);
8353 
8354 	exp_var(&ln_num, result, rscale);
8355 
8356 	if (res_sign == NUMERIC_NEG && result->ndigits > 0)
8357 		result->sign = NUMERIC_NEG;
8358 
8359 	free_var(&ln_num);
8360 	free_var(&ln_base);
8361 	free_var(&abs_base);
8362 }
8363 
8364 /*
8365  * power_var_int() -
8366  *
8367  *	Raise base to the power of exp, where exp is an integer.
8368  */
8369 static void
power_var_int(const NumericVar * base,int exp,NumericVar * result,int rscale)8370 power_var_int(const NumericVar *base, int exp, NumericVar *result, int rscale)
8371 {
8372 	double		f;
8373 	int			p;
8374 	int			i;
8375 	int			sig_digits;
8376 	unsigned int mask;
8377 	bool		neg;
8378 	NumericVar	base_prod;
8379 	int			local_rscale;
8380 
8381 	/* Handle some common special cases, as well as corner cases */
8382 	switch (exp)
8383 	{
8384 		case 0:
8385 
8386 			/*
8387 			 * While 0 ^ 0 can be either 1 or indeterminate (error), we treat
8388 			 * it as 1 because most programming languages do this. SQL:2003
8389 			 * also requires a return value of 1.
8390 			 * https://en.wikipedia.org/wiki/Exponentiation#Zero_to_the_zero_power
8391 			 */
8392 			set_var_from_var(&const_one, result);
8393 			result->dscale = rscale;	/* no need to round */
8394 			return;
8395 		case 1:
8396 			set_var_from_var(base, result);
8397 			round_var(result, rscale);
8398 			return;
8399 		case -1:
8400 			div_var(&const_one, base, result, rscale, true);
8401 			return;
8402 		case 2:
8403 			mul_var(base, base, result, rscale);
8404 			return;
8405 		default:
8406 			break;
8407 	}
8408 
8409 	/* Handle the special case where the base is zero */
8410 	if (base->ndigits == 0)
8411 	{
8412 		if (exp < 0)
8413 			ereport(ERROR,
8414 					(errcode(ERRCODE_DIVISION_BY_ZERO),
8415 					 errmsg("division by zero")));
8416 		zero_var(result);
8417 		result->dscale = rscale;
8418 		return;
8419 	}
8420 
8421 	/*
8422 	 * The general case repeatedly multiplies base according to the bit
8423 	 * pattern of exp.
8424 	 *
8425 	 * First we need to estimate the weight of the result so that we know how
8426 	 * many significant digits are needed.
8427 	 */
8428 	f = base->digits[0];
8429 	p = base->weight * DEC_DIGITS;
8430 
8431 	for (i = 1; i < base->ndigits && i * DEC_DIGITS < 16; i++)
8432 	{
8433 		f = f * NBASE + base->digits[i];
8434 		p -= DEC_DIGITS;
8435 	}
8436 
8437 	/*----------
8438 	 * We have base ~= f * 10^p
8439 	 * so log10(result) = log10(base^exp) ~= exp * (log10(f) + p)
8440 	 *----------
8441 	 */
8442 	f = exp * (log10(f) + p);
8443 
8444 	/*
8445 	 * Apply crude overflow/underflow tests so we can exit early if the result
8446 	 * certainly will overflow/underflow.
8447 	 */
8448 	if (f > 3 * SHRT_MAX * DEC_DIGITS)
8449 		ereport(ERROR,
8450 				(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
8451 				 errmsg("value overflows numeric format")));
8452 	if (f + 1 < -rscale || f + 1 < -NUMERIC_MAX_DISPLAY_SCALE)
8453 	{
8454 		zero_var(result);
8455 		result->dscale = rscale;
8456 		return;
8457 	}
8458 
8459 	/*
8460 	 * Approximate number of significant digits in the result.  Note that the
8461 	 * underflow test above means that this is necessarily >= 0.
8462 	 */
8463 	sig_digits = 1 + rscale + (int) f;
8464 
8465 	/*
8466 	 * The multiplications to produce the result may introduce an error of up
8467 	 * to around log10(abs(exp)) digits, so work with this many extra digits
8468 	 * of precision (plus a few more for good measure).
8469 	 */
8470 	sig_digits += (int) log(fabs((double) exp)) + 8;
8471 
8472 	/*
8473 	 * Now we can proceed with the multiplications.
8474 	 */
8475 	neg = (exp < 0);
8476 	mask = Abs(exp);
8477 
8478 	init_var(&base_prod);
8479 	set_var_from_var(base, &base_prod);
8480 
8481 	if (mask & 1)
8482 		set_var_from_var(base, result);
8483 	else
8484 		set_var_from_var(&const_one, result);
8485 
8486 	while ((mask >>= 1) > 0)
8487 	{
8488 		/*
8489 		 * Do the multiplications using rscales large enough to hold the
8490 		 * results to the required number of significant digits, but don't
8491 		 * waste time by exceeding the scales of the numbers themselves.
8492 		 */
8493 		local_rscale = sig_digits - 2 * base_prod.weight * DEC_DIGITS;
8494 		local_rscale = Min(local_rscale, 2 * base_prod.dscale);
8495 		local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
8496 
8497 		mul_var(&base_prod, &base_prod, &base_prod, local_rscale);
8498 
8499 		if (mask & 1)
8500 		{
8501 			local_rscale = sig_digits -
8502 				(base_prod.weight + result->weight) * DEC_DIGITS;
8503 			local_rscale = Min(local_rscale,
8504 							   base_prod.dscale + result->dscale);
8505 			local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
8506 
8507 			mul_var(&base_prod, result, result, local_rscale);
8508 		}
8509 
8510 		/*
8511 		 * When abs(base) > 1, the number of digits to the left of the decimal
8512 		 * point in base_prod doubles at each iteration, so if exp is large we
8513 		 * could easily spend large amounts of time and memory space doing the
8514 		 * multiplications.  But once the weight exceeds what will fit in
8515 		 * int16, the final result is guaranteed to overflow (or underflow, if
8516 		 * exp < 0), so we can give up before wasting too many cycles.
8517 		 */
8518 		if (base_prod.weight > SHRT_MAX || result->weight > SHRT_MAX)
8519 		{
8520 			/* overflow, unless neg, in which case result should be 0 */
8521 			if (!neg)
8522 				ereport(ERROR,
8523 						(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
8524 						 errmsg("value overflows numeric format")));
8525 			zero_var(result);
8526 			neg = false;
8527 			break;
8528 		}
8529 	}
8530 
8531 	free_var(&base_prod);
8532 
8533 	/* Compensate for input sign, and round to requested rscale */
8534 	if (neg)
8535 		div_var_fast(&const_one, result, result, rscale, true);
8536 	else
8537 		round_var(result, rscale);
8538 }
8539 
8540 /*
8541  * power_ten_int() -
8542  *
8543  *	Raise ten to the power of exp, where exp is an integer.  Note that unlike
8544  *	power_var_int(), this does no overflow/underflow checking or rounding.
8545  */
8546 static void
power_ten_int(int exp,NumericVar * result)8547 power_ten_int(int exp, NumericVar *result)
8548 {
8549 	/* Construct the result directly, starting from 10^0 = 1 */
8550 	set_var_from_var(&const_one, result);
8551 
8552 	/* Scale needed to represent the result exactly */
8553 	result->dscale = exp < 0 ? -exp : 0;
8554 
8555 	/* Base-NBASE weight of result and remaining exponent */
8556 	if (exp >= 0)
8557 		result->weight = exp / DEC_DIGITS;
8558 	else
8559 		result->weight = (exp + 1) / DEC_DIGITS - 1;
8560 
8561 	exp -= result->weight * DEC_DIGITS;
8562 
8563 	/* Final adjustment of the result's single NBASE digit */
8564 	while (exp-- > 0)
8565 		result->digits[0] *= 10;
8566 }
8567 
8568 
/* ----------------------------------------------------------------------
 *
 * The following are the lowest-level functions, which operate directly on
 * NumericVars, mostly on their absolute (unsigned) values
 *
 * ----------------------------------------------------------------------
 */
8576 
8577 
8578 /* ----------
8579  * cmp_abs() -
8580  *
8581  *	Compare the absolute values of var1 and var2
8582  *	Returns:	-1 for ABS(var1) < ABS(var2)
8583  *				0  for ABS(var1) == ABS(var2)
8584  *				1  for ABS(var1) > ABS(var2)
8585  * ----------
8586  */
8587 static int
cmp_abs(const NumericVar * var1,const NumericVar * var2)8588 cmp_abs(const NumericVar *var1, const NumericVar *var2)
8589 {
8590 	return cmp_abs_common(var1->digits, var1->ndigits, var1->weight,
8591 						  var2->digits, var2->ndigits, var2->weight);
8592 }
8593 
8594 /* ----------
8595  * cmp_abs_common() -
8596  *
8597  *	Main routine of cmp_abs(). This function can be used by both
8598  *	NumericVar and Numeric.
8599  * ----------
8600  */
8601 static int
cmp_abs_common(const NumericDigit * var1digits,int var1ndigits,int var1weight,const NumericDigit * var2digits,int var2ndigits,int var2weight)8602 cmp_abs_common(const NumericDigit *var1digits, int var1ndigits, int var1weight,
8603 			   const NumericDigit *var2digits, int var2ndigits, int var2weight)
8604 {
8605 	int			i1 = 0;
8606 	int			i2 = 0;
8607 
8608 	/* Check any digits before the first common digit */
8609 
8610 	while (var1weight > var2weight && i1 < var1ndigits)
8611 	{
8612 		if (var1digits[i1++] != 0)
8613 			return 1;
8614 		var1weight--;
8615 	}
8616 	while (var2weight > var1weight && i2 < var2ndigits)
8617 	{
8618 		if (var2digits[i2++] != 0)
8619 			return -1;
8620 		var2weight--;
8621 	}
8622 
8623 	/* At this point, either w1 == w2 or we've run out of digits */
8624 
8625 	if (var1weight == var2weight)
8626 	{
8627 		while (i1 < var1ndigits && i2 < var2ndigits)
8628 		{
8629 			int			stat = var1digits[i1++] - var2digits[i2++];
8630 
8631 			if (stat)
8632 			{
8633 				if (stat > 0)
8634 					return 1;
8635 				return -1;
8636 			}
8637 		}
8638 	}
8639 
8640 	/*
8641 	 * At this point, we've run out of digits on one side or the other; so any
8642 	 * remaining nonzero digits imply that side is larger
8643 	 */
8644 	while (i1 < var1ndigits)
8645 	{
8646 		if (var1digits[i1++] != 0)
8647 			return 1;
8648 	}
8649 	while (i2 < var2ndigits)
8650 	{
8651 		if (var2digits[i2++] != 0)
8652 			return -1;
8653 	}
8654 
8655 	return 0;
8656 }
8657 
8658 
8659 /*
8660  * add_abs() -
8661  *
8662  *	Add the absolute values of two variables into result.
8663  *	result might point to one of the operands without danger.
8664  */
8665 static void
add_abs(const NumericVar * var1,const NumericVar * var2,NumericVar * result)8666 add_abs(const NumericVar *var1, const NumericVar *var2, NumericVar *result)
8667 {
8668 	NumericDigit *res_buf;
8669 	NumericDigit *res_digits;
8670 	int			res_ndigits;
8671 	int			res_weight;
8672 	int			res_rscale,
8673 				rscale1,
8674 				rscale2;
8675 	int			res_dscale;
8676 	int			i,
8677 				i1,
8678 				i2;
8679 	int			carry = 0;
8680 
8681 	/* copy these values into local vars for speed in inner loop */
8682 	int			var1ndigits = var1->ndigits;
8683 	int			var2ndigits = var2->ndigits;
8684 	NumericDigit *var1digits = var1->digits;
8685 	NumericDigit *var2digits = var2->digits;
8686 
8687 	res_weight = Max(var1->weight, var2->weight) + 1;
8688 
8689 	res_dscale = Max(var1->dscale, var2->dscale);
8690 
8691 	/* Note: here we are figuring rscale in base-NBASE digits */
8692 	rscale1 = var1->ndigits - var1->weight - 1;
8693 	rscale2 = var2->ndigits - var2->weight - 1;
8694 	res_rscale = Max(rscale1, rscale2);
8695 
8696 	res_ndigits = res_rscale + res_weight + 1;
8697 	if (res_ndigits <= 0)
8698 		res_ndigits = 1;
8699 
8700 	res_buf = digitbuf_alloc(res_ndigits + 1);
8701 	res_buf[0] = 0;				/* spare digit for later rounding */
8702 	res_digits = res_buf + 1;
8703 
8704 	i1 = res_rscale + var1->weight + 1;
8705 	i2 = res_rscale + var2->weight + 1;
8706 	for (i = res_ndigits - 1; i >= 0; i--)
8707 	{
8708 		i1--;
8709 		i2--;
8710 		if (i1 >= 0 && i1 < var1ndigits)
8711 			carry += var1digits[i1];
8712 		if (i2 >= 0 && i2 < var2ndigits)
8713 			carry += var2digits[i2];
8714 
8715 		if (carry >= NBASE)
8716 		{
8717 			res_digits[i] = carry - NBASE;
8718 			carry = 1;
8719 		}
8720 		else
8721 		{
8722 			res_digits[i] = carry;
8723 			carry = 0;
8724 		}
8725 	}
8726 
8727 	Assert(carry == 0);			/* else we failed to allow for carry out */
8728 
8729 	digitbuf_free(result->buf);
8730 	result->ndigits = res_ndigits;
8731 	result->buf = res_buf;
8732 	result->digits = res_digits;
8733 	result->weight = res_weight;
8734 	result->dscale = res_dscale;
8735 
8736 	/* Remove leading/trailing zeroes */
8737 	strip_var(result);
8738 }
8739 
8740 
8741 /*
8742  * sub_abs()
8743  *
8744  *	Subtract the absolute value of var2 from the absolute value of var1
8745  *	and store in result. result might point to one of the operands
8746  *	without danger.
8747  *
8748  *	ABS(var1) MUST BE GREATER OR EQUAL ABS(var2) !!!
8749  */
8750 static void
sub_abs(const NumericVar * var1,const NumericVar * var2,NumericVar * result)8751 sub_abs(const NumericVar *var1, const NumericVar *var2, NumericVar *result)
8752 {
8753 	NumericDigit *res_buf;
8754 	NumericDigit *res_digits;
8755 	int			res_ndigits;
8756 	int			res_weight;
8757 	int			res_rscale,
8758 				rscale1,
8759 				rscale2;
8760 	int			res_dscale;
8761 	int			i,
8762 				i1,
8763 				i2;
8764 	int			borrow = 0;
8765 
8766 	/* copy these values into local vars for speed in inner loop */
8767 	int			var1ndigits = var1->ndigits;
8768 	int			var2ndigits = var2->ndigits;
8769 	NumericDigit *var1digits = var1->digits;
8770 	NumericDigit *var2digits = var2->digits;
8771 
8772 	res_weight = var1->weight;
8773 
8774 	res_dscale = Max(var1->dscale, var2->dscale);
8775 
8776 	/* Note: here we are figuring rscale in base-NBASE digits */
8777 	rscale1 = var1->ndigits - var1->weight - 1;
8778 	rscale2 = var2->ndigits - var2->weight - 1;
8779 	res_rscale = Max(rscale1, rscale2);
8780 
8781 	res_ndigits = res_rscale + res_weight + 1;
8782 	if (res_ndigits <= 0)
8783 		res_ndigits = 1;
8784 
8785 	res_buf = digitbuf_alloc(res_ndigits + 1);
8786 	res_buf[0] = 0;				/* spare digit for later rounding */
8787 	res_digits = res_buf + 1;
8788 
8789 	i1 = res_rscale + var1->weight + 1;
8790 	i2 = res_rscale + var2->weight + 1;
8791 	for (i = res_ndigits - 1; i >= 0; i--)
8792 	{
8793 		i1--;
8794 		i2--;
8795 		if (i1 >= 0 && i1 < var1ndigits)
8796 			borrow += var1digits[i1];
8797 		if (i2 >= 0 && i2 < var2ndigits)
8798 			borrow -= var2digits[i2];
8799 
8800 		if (borrow < 0)
8801 		{
8802 			res_digits[i] = borrow + NBASE;
8803 			borrow = -1;
8804 		}
8805 		else
8806 		{
8807 			res_digits[i] = borrow;
8808 			borrow = 0;
8809 		}
8810 	}
8811 
8812 	Assert(borrow == 0);		/* else caller gave us var1 < var2 */
8813 
8814 	digitbuf_free(result->buf);
8815 	result->ndigits = res_ndigits;
8816 	result->buf = res_buf;
8817 	result->digits = res_digits;
8818 	result->weight = res_weight;
8819 	result->dscale = res_dscale;
8820 
8821 	/* Remove leading/trailing zeroes */
8822 	strip_var(result);
8823 }
8824 
/*
 * round_var
 *
 * Round the value of a variable to no more than rscale decimal digits
 * after the decimal point.  NOTE: we allow rscale < 0 here, implying
 * rounding before the decimal point.
 *
 * Rounding acts on the magnitude, and halfway cases round the magnitude up
 * (i.e. half away from zero).  var->dscale is always set to rscale.  The
 * digit array is modified in place.  A carry out of the most significant
 * digit is written into the spare digit that must precede var->digits
 * within var->buf (see the Asserts at the end).  Trailing zero digits that
 * rounding may leave behind are not stripped here.
 */
static void
round_var(NumericVar *var, int rscale)
{
	NumericDigit *digits = var->digits;
	int			di;
	int			ndigits;
	int			carry;

	var->dscale = rscale;

	/* decimal digits wanted */
	di = (var->weight + 1) * DEC_DIGITS + rscale;

	/*
	 * If di = 0, the value loses all digits, but could round up to 1 if its
	 * first extra digit is >= 5.  If di < 0 the result must be 0.
	 */
	if (di < 0)
	{
		var->ndigits = 0;
		var->weight = 0;
		var->sign = NUMERIC_POS;
	}
	else
	{
		/* NBASE digits wanted */
		ndigits = (di + DEC_DIGITS - 1) / DEC_DIGITS;

		/* 0, or number of decimal digits to keep in last NBASE digit */
		di %= DEC_DIGITS;

		/*
		 * Act only if something is actually dropped: whole trailing NBASE
		 * digits, or some decimal places within the last NBASE digit kept.
		 */
		if (ndigits < var->ndigits ||
			(ndigits == var->ndigits && di > 0))
		{
			var->ndigits = ndigits;

#if DEC_DIGITS == 1
			/* di must be zero */
			carry = (digits[ndigits] >= HALF_NBASE) ? 1 : 0;
#else
			/*
			 * With di == 0 the cut falls on an NBASE digit boundary, so round
			 * on the first dropped digit, digits[ndigits].
			 */
			if (di == 0)
				carry = (digits[ndigits] >= HALF_NBASE) ? 1 : 0;
			else
			{
				/* Must round within last NBASE digit */
				int			extra,
							pow10;

#if DEC_DIGITS == 4
				pow10 = round_powers[di];
#elif DEC_DIGITS == 2
				pow10 = 10;
#else
#error unsupported NBASE
#endif
				/* Clear the unwanted low-order decimal places ... */
				extra = digits[--ndigits] % pow10;
				digits[ndigits] -= extra;
				carry = 0;
				/* ... and round that digit up if they were at least half */
				if (extra >= pow10 / 2)
				{
					pow10 += digits[ndigits];
					if (pow10 >= NBASE)
					{
						pow10 -= NBASE;
						carry = 1;
					}
					digits[ndigits] = pow10;
				}
			}
#endif

			/*
			 * Propagate carry if needed.  This can step to index -1, i.e.
			 * into the spare digit in front of var->digits.
			 */
			while (carry)
			{
				carry += digits[--ndigits];
				if (carry >= NBASE)
				{
					digits[ndigits] = carry - NBASE;
					carry = 1;
				}
				else
				{
					digits[ndigits] = carry;
					carry = 0;
				}
			}

			/* Carry created a new leading digit: adopt the spare digit */
			if (ndigits < 0)
			{
				Assert(ndigits == -1);	/* better not have added > 1 digit */
				Assert(var->digits > var->buf);
				var->digits--;
				var->ndigits++;
				var->weight++;
			}
		}
	}
}
8930 
8931 /*
8932  * trunc_var
8933  *
8934  * Truncate (towards zero) the value of a variable at rscale decimal digits
8935  * after the decimal point.  NOTE: we allow rscale < 0 here, implying
8936  * truncation before the decimal point.
8937  */
8938 static void
trunc_var(NumericVar * var,int rscale)8939 trunc_var(NumericVar *var, int rscale)
8940 {
8941 	int			di;
8942 	int			ndigits;
8943 
8944 	var->dscale = rscale;
8945 
8946 	/* decimal digits wanted */
8947 	di = (var->weight + 1) * DEC_DIGITS + rscale;
8948 
8949 	/*
8950 	 * If di <= 0, the value loses all digits.
8951 	 */
8952 	if (di <= 0)
8953 	{
8954 		var->ndigits = 0;
8955 		var->weight = 0;
8956 		var->sign = NUMERIC_POS;
8957 	}
8958 	else
8959 	{
8960 		/* NBASE digits wanted */
8961 		ndigits = (di + DEC_DIGITS - 1) / DEC_DIGITS;
8962 
8963 		if (ndigits <= var->ndigits)
8964 		{
8965 			var->ndigits = ndigits;
8966 
8967 #if DEC_DIGITS == 1
8968 			/* no within-digit stuff to worry about */
8969 #else
8970 			/* 0, or number of decimal digits to keep in last NBASE digit */
8971 			di %= DEC_DIGITS;
8972 
8973 			if (di > 0)
8974 			{
8975 				/* Must truncate within last NBASE digit */
8976 				NumericDigit *digits = var->digits;
8977 				int			extra,
8978 							pow10;
8979 
8980 #if DEC_DIGITS == 4
8981 				pow10 = round_powers[di];
8982 #elif DEC_DIGITS == 2
8983 				pow10 = 10;
8984 #else
8985 #error unsupported NBASE
8986 #endif
8987 				extra = digits[--ndigits] % pow10;
8988 				digits[ndigits] -= extra;
8989 			}
8990 #endif
8991 		}
8992 	}
8993 }
8994 
8995 /*
8996  * strip_var
8997  *
8998  * Strip any leading and trailing zeroes from a numeric variable
8999  */
9000 static void
strip_var(NumericVar * var)9001 strip_var(NumericVar *var)
9002 {
9003 	NumericDigit *digits = var->digits;
9004 	int			ndigits = var->ndigits;
9005 
9006 	/* Strip leading zeroes */
9007 	while (ndigits > 0 && *digits == 0)
9008 	{
9009 		digits++;
9010 		var->weight--;
9011 		ndigits--;
9012 	}
9013 
9014 	/* Strip trailing zeroes */
9015 	while (ndigits > 0 && digits[ndigits - 1] == 0)
9016 		ndigits--;
9017 
9018 	/* If it's zero, normalize the sign and weight */
9019 	if (ndigits == 0)
9020 	{
9021 		var->sign = NUMERIC_POS;
9022 		var->weight = 0;
9023 	}
9024 
9025 	var->digits = digits;
9026 	var->ndigits = ndigits;
9027 }
9028 
9029 
9030 /* ----------------------------------------------------------------------
9031  *
9032  * Fast sum accumulator functions
9033  *
9034  * ----------------------------------------------------------------------
9035  */
9036 
9037 /*
9038  * Reset the accumulator's value to zero.  The buffers to hold the digits
9039  * are not free'd.
9040  */
9041 static void
accum_sum_reset(NumericSumAccum * accum)9042 accum_sum_reset(NumericSumAccum *accum)
9043 {
9044 	int			i;
9045 
9046 	accum->dscale = 0;
9047 	for (i = 0; i < accum->ndigits; i++)
9048 	{
9049 		accum->pos_digits[i] = 0;
9050 		accum->neg_digits[i] = 0;
9051 	}
9052 }
9053 
9054 /*
9055  * Accumulate a new value.
9056  */
9057 static void
accum_sum_add(NumericSumAccum * accum,const NumericVar * val)9058 accum_sum_add(NumericSumAccum *accum, const NumericVar *val)
9059 {
9060 	int32	   *accum_digits;
9061 	int			i,
9062 				val_i;
9063 	int			val_ndigits;
9064 	NumericDigit *val_digits;
9065 
9066 	/*
9067 	 * If we have accumulated too many values since the last carry
9068 	 * propagation, do it now, to avoid overflowing.  (We could allow more
9069 	 * than NBASE - 1, if we reserved two extra digits, rather than one, for
9070 	 * carry propagation.  But even with NBASE - 1, this needs to be done so
9071 	 * seldom, that the performance difference is negligible.)
9072 	 */
9073 	if (accum->num_uncarried == NBASE - 1)
9074 		accum_sum_carry(accum);
9075 
9076 	/*
9077 	 * Adjust the weight or scale of the old value, so that it can accommodate
9078 	 * the new value.
9079 	 */
9080 	accum_sum_rescale(accum, val);
9081 
9082 	/* */
9083 	if (val->sign == NUMERIC_POS)
9084 		accum_digits = accum->pos_digits;
9085 	else
9086 		accum_digits = accum->neg_digits;
9087 
9088 	/* copy these values into local vars for speed in loop */
9089 	val_ndigits = val->ndigits;
9090 	val_digits = val->digits;
9091 
9092 	i = accum->weight - val->weight;
9093 	for (val_i = 0; val_i < val_ndigits; val_i++)
9094 	{
9095 		accum_digits[i] += (int32) val_digits[val_i];
9096 		i++;
9097 	}
9098 
9099 	accum->num_uncarried++;
9100 }
9101 
9102 /*
9103  * Propagate carries.
9104  */
9105 static void
accum_sum_carry(NumericSumAccum * accum)9106 accum_sum_carry(NumericSumAccum *accum)
9107 {
9108 	int			i;
9109 	int			ndigits;
9110 	int32	   *dig;
9111 	int32		carry;
9112 	int32		newdig = 0;
9113 
9114 	/*
9115 	 * If no new values have been added since last carry propagation, nothing
9116 	 * to do.
9117 	 */
9118 	if (accum->num_uncarried == 0)
9119 		return;
9120 
9121 	/*
9122 	 * We maintain that the weight of the accumulator is always one larger
9123 	 * than needed to hold the current value, before carrying, to make sure
9124 	 * there is enough space for the possible extra digit when carry is
9125 	 * propagated.  We cannot expand the buffer here, unless we require
9126 	 * callers of accum_sum_final() to switch to the right memory context.
9127 	 */
9128 	Assert(accum->pos_digits[0] == 0 && accum->neg_digits[0] == 0);
9129 
9130 	ndigits = accum->ndigits;
9131 
9132 	/* Propagate carry in the positive sum */
9133 	dig = accum->pos_digits;
9134 	carry = 0;
9135 	for (i = ndigits - 1; i >= 0; i--)
9136 	{
9137 		newdig = dig[i] + carry;
9138 		if (newdig >= NBASE)
9139 		{
9140 			carry = newdig / NBASE;
9141 			newdig -= carry * NBASE;
9142 		}
9143 		else
9144 			carry = 0;
9145 		dig[i] = newdig;
9146 	}
9147 	/* Did we use up the digit reserved for carry propagation? */
9148 	if (newdig > 0)
9149 		accum->have_carry_space = false;
9150 
9151 	/* And the same for the negative sum */
9152 	dig = accum->neg_digits;
9153 	carry = 0;
9154 	for (i = ndigits - 1; i >= 0; i--)
9155 	{
9156 		newdig = dig[i] + carry;
9157 		if (newdig >= NBASE)
9158 		{
9159 			carry = newdig / NBASE;
9160 			newdig -= carry * NBASE;
9161 		}
9162 		else
9163 			carry = 0;
9164 		dig[i] = newdig;
9165 	}
9166 	if (newdig > 0)
9167 		accum->have_carry_space = false;
9168 
9169 	accum->num_uncarried = 0;
9170 }
9171 
/*
 * Re-scale accumulator to accommodate new value.
 *
 * If the new value has more digits than the current digit buffers in the
 * accumulator, enlarge the buffers.
 *
 * On return the digit arrays cover every NBASE digit position of val, with
 * at least one extra leading digit reserved for carry propagation, so that
 * have_carry_space is true.  accum->dscale is raised to val->dscale if
 * smaller.  When the layout changes, new zero-filled arrays are allocated
 * with palloc0(), the existing digits are copied in at their shifted
 * position, and the old arrays (if any) are pfree'd.
 * NOTE(review): palloc0() allocates in the current memory context, so
 * callers presumably must be switched into the accumulator's context --
 * confirm at call sites.
 */
static void
accum_sum_rescale(NumericSumAccum *accum, const NumericVar *val)
{
	int			old_weight = accum->weight;
	int			old_ndigits = accum->ndigits;
	int			accum_ndigits;
	int			accum_weight;
	int			accum_rscale;
	int			val_rscale;

	accum_weight = old_weight;
	accum_ndigits = old_ndigits;

	/*
	 * Does the new value have a larger weight? If so, enlarge the buffers,
	 * and shift the existing value to the new weight, by adding leading
	 * zeros.
	 *
	 * We enforce that the accumulator always has a weight one larger than
	 * needed for the inputs, so that we have space for an extra digit at the
	 * final carry-propagation phase, if necessary.
	 */
	if (val->weight >= accum_weight)
	{
		accum_weight = val->weight + 1;
		accum_ndigits = accum_ndigits + (accum_weight - old_weight);
	}

	/*
	 * Even though the new value is small, we might've used up the space
	 * reserved for the carry digit in the last call to accum_sum_carry().  If
	 * so, enlarge to make room for another one.
	 */
	else if (!accum->have_carry_space)
	{
		accum_weight++;
		accum_ndigits++;
	}

	/*
	 * Is the new value wider on the right side?  (rscale here counts NBASE
	 * digits after the decimal point, not decimal digits.)
	 */
	accum_rscale = accum_ndigits - accum_weight - 1;
	val_rscale = val->ndigits - val->weight - 1;
	if (val_rscale > accum_rscale)
		accum_ndigits = accum_ndigits + (val_rscale - accum_rscale);

	if (accum_ndigits != old_ndigits ||
		accum_weight != old_weight)
	{
		int32	   *new_pos_digits;
		int32	   *new_neg_digits;
		int			weightdiff;

		/* Number of new leading digits the old contents must shift past */
		weightdiff = accum_weight - old_weight;

		new_pos_digits = palloc0(accum_ndigits * sizeof(int32));
		new_neg_digits = palloc0(accum_ndigits * sizeof(int32));

		/* Arrays are NULL until the first allocation; nothing to copy then */
		if (accum->pos_digits)
		{
			memcpy(&new_pos_digits[weightdiff], accum->pos_digits,
				   old_ndigits * sizeof(int32));
			pfree(accum->pos_digits);

			memcpy(&new_neg_digits[weightdiff], accum->neg_digits,
				   old_ndigits * sizeof(int32));
			pfree(accum->neg_digits);
		}

		accum->pos_digits = new_pos_digits;
		accum->neg_digits = new_neg_digits;

		accum->weight = accum_weight;
		accum->ndigits = accum_ndigits;

		/* The weight always grew if the layout changed, so digit 0 is free */
		Assert(accum->pos_digits[0] == 0 && accum->neg_digits[0] == 0);
		accum->have_carry_space = true;
	}

	if (val->dscale > accum->dscale)
		accum->dscale = val->dscale;
}
9259 
9260 /*
9261  * Return the current value of the accumulator.  This perform final carry
9262  * propagation, and adds together the positive and negative sums.
9263  *
9264  * Unlike all the other routines, the caller is not required to switch to
9265  * the memory context that holds the accumulator.
9266  */
9267 static void
accum_sum_final(NumericSumAccum * accum,NumericVar * result)9268 accum_sum_final(NumericSumAccum *accum, NumericVar *result)
9269 {
9270 	int			i;
9271 	NumericVar	pos_var;
9272 	NumericVar	neg_var;
9273 
9274 	if (accum->ndigits == 0)
9275 	{
9276 		set_var_from_var(&const_zero, result);
9277 		return;
9278 	}
9279 
9280 	/* Perform final carry */
9281 	accum_sum_carry(accum);
9282 
9283 	/* Create NumericVars representing the positive and negative sums */
9284 	init_var(&pos_var);
9285 	init_var(&neg_var);
9286 
9287 	pos_var.ndigits = neg_var.ndigits = accum->ndigits;
9288 	pos_var.weight = neg_var.weight = accum->weight;
9289 	pos_var.dscale = neg_var.dscale = accum->dscale;
9290 	pos_var.sign = NUMERIC_POS;
9291 	neg_var.sign = NUMERIC_NEG;
9292 
9293 	pos_var.buf = pos_var.digits = digitbuf_alloc(accum->ndigits);
9294 	neg_var.buf = neg_var.digits = digitbuf_alloc(accum->ndigits);
9295 
9296 	for (i = 0; i < accum->ndigits; i++)
9297 	{
9298 		Assert(accum->pos_digits[i] < NBASE);
9299 		pos_var.digits[i] = (int16) accum->pos_digits[i];
9300 
9301 		Assert(accum->neg_digits[i] < NBASE);
9302 		neg_var.digits[i] = (int16) accum->neg_digits[i];
9303 	}
9304 
9305 	/* And add them together */
9306 	add_var(&pos_var, &neg_var, result);
9307 
9308 	/* Remove leading/trailing zeroes */
9309 	strip_var(result);
9310 }
9311 
9312 /*
9313  * Copy an accumulator's state.
9314  *
9315  * 'dst' is assumed to be uninitialized beforehand.  No attempt is made at
9316  * freeing old values.
9317  */
9318 static void
accum_sum_copy(NumericSumAccum * dst,NumericSumAccum * src)9319 accum_sum_copy(NumericSumAccum *dst, NumericSumAccum *src)
9320 {
9321 	dst->pos_digits = palloc(src->ndigits * sizeof(int32));
9322 	dst->neg_digits = palloc(src->ndigits * sizeof(int32));
9323 
9324 	memcpy(dst->pos_digits, src->pos_digits, src->ndigits * sizeof(int32));
9325 	memcpy(dst->neg_digits, src->neg_digits, src->ndigits * sizeof(int32));
9326 	dst->num_uncarried = src->num_uncarried;
9327 	dst->ndigits = src->ndigits;
9328 	dst->weight = src->weight;
9329 	dst->dscale = src->dscale;
9330 }
9331 
9332 /*
9333  * Add the current value of 'accum2' into 'accum'.
9334  */
9335 static void
accum_sum_combine(NumericSumAccum * accum,NumericSumAccum * accum2)9336 accum_sum_combine(NumericSumAccum *accum, NumericSumAccum *accum2)
9337 {
9338 	NumericVar	tmp_var;
9339 
9340 	init_var(&tmp_var);
9341 
9342 	accum_sum_final(accum2, &tmp_var);
9343 	accum_sum_add(accum, &tmp_var);
9344 
9345 	free_var(&tmp_var);
9346 }
9347