1 /*-------------------------------------------------------------------------
2  *
3  * numeric.c
4  *	  An exact numeric data type for the Postgres database system
5  *
6  * Original coding 1998, Jan Wieck.  Heavily revised 2003, Tom Lane.
7  *
8  * Many of the algorithmic ideas are borrowed from David M. Smith's "FM"
9  * multiple-precision math library, most recently published as Algorithm
10  * 786: Multiple-Precision Complex Arithmetic and Functions, ACM
11  * Transactions on Mathematical Software, Vol. 24, No. 4, December 1998,
12  * pages 359-367.
13  *
14  * Copyright (c) 1998-2021, PostgreSQL Global Development Group
15  *
16  * IDENTIFICATION
17  *	  src/backend/utils/adt/numeric.c
18  *
19  *-------------------------------------------------------------------------
20  */
21 
22 #include "postgres.h"
23 
24 #include <ctype.h>
25 #include <float.h>
26 #include <limits.h>
27 #include <math.h>
28 
29 #include "catalog/pg_type.h"
30 #include "common/hashfn.h"
31 #include "common/int.h"
32 #include "funcapi.h"
33 #include "lib/hyperloglog.h"
34 #include "libpq/pqformat.h"
35 #include "miscadmin.h"
36 #include "nodes/nodeFuncs.h"
37 #include "nodes/supportnodes.h"
38 #include "utils/array.h"
39 #include "utils/builtins.h"
40 #include "utils/float.h"
41 #include "utils/guc.h"
42 #include "utils/int8.h"
43 #include "utils/numeric.h"
44 #include "utils/pg_lsn.h"
45 #include "utils/sortsupport.h"
46 
47 /* ----------
48  * Uncomment the following to enable compilation of dump_numeric()
49  * and dump_var() and to get a dump of any result produced by make_result().
50  * ----------
51 #define NUMERIC_DEBUG
52  */
53 
54 
55 /* ----------
56  * Local data types
57  *
58  * Numeric values are represented in a base-NBASE floating point format.
59  * Each "digit" ranges from 0 to NBASE-1.  The type NumericDigit is signed
60  * and wide enough to store a digit.  We assume that NBASE*NBASE can fit in
61  * an int.  Although the purely calculational routines could handle any even
62  * NBASE that's less than sqrt(INT_MAX), in practice we are only interested
63  * in NBASE a power of ten, so that I/O conversions and decimal rounding
64  * are easy.  Also, it's actually more efficient if NBASE is rather less than
65  * sqrt(INT_MAX), so that there is "headroom" for mul_var and div_var_fast to
66  * postpone processing carries.
67  *
68  * Values of NBASE other than 10000 are considered of historical interest only
69  * and are no longer supported in any sense; no mechanism exists for the client
70  * to discover the base, so every client supporting binary mode expects the
71  * base-10000 format.  If you plan to change this, also note the numeric
72  * abbreviation code, which assumes NBASE=10000.
73  * ----------
74  */
75 
/* Historical NBASE = 10 configuration; disabled and no longer supported. */
#if 0
#define NBASE		10
#define HALF_NBASE	5
#define DEC_DIGITS	1			/* decimal digits per NBASE digit */
#define MUL_GUARD_DIGITS	4	/* these are measured in NBASE digits */
#define DIV_GUARD_DIGITS	8

typedef signed char NumericDigit;
#endif

/* Historical NBASE = 100 configuration; disabled and no longer supported. */
#if 0
#define NBASE		100
#define HALF_NBASE	50
#define DEC_DIGITS	2			/* decimal digits per NBASE digit */
#define MUL_GUARD_DIGITS	3	/* these are measured in NBASE digits */
#define DIV_GUARD_DIGITS	6

typedef signed char NumericDigit;
#endif

/*
 * The active configuration: NBASE = 10000, so each NumericDigit holds four
 * decimal digits (a value in 0..9999) and is stored as a signed 16-bit
 * integer.  HALF_NBASE is NBASE / 2.  The guard-digit counts are extra NBASE
 * digits used by the multiplication and division code (presumably to bound
 * rounding error -- confirm against mul_var()/div_var_fast()).
 */
#if 1
#define NBASE		10000
#define HALF_NBASE	5000
#define DEC_DIGITS	4			/* decimal digits per NBASE digit */
#define MUL_GUARD_DIGITS	2	/* these are measured in NBASE digits */
#define DIV_GUARD_DIGITS	4

typedef int16 NumericDigit;
#endif
105 
106 /*
107  * The Numeric type as stored on disk.
108  *
109  * If the high bits of the first word of a NumericChoice (n_header, or
110  * n_short.n_header, or n_long.n_sign_dscale) are NUMERIC_SHORT, then the
111  * numeric follows the NumericShort format; if they are NUMERIC_POS or
112  * NUMERIC_NEG, it follows the NumericLong format. If they are NUMERIC_SPECIAL,
113  * the value is a NaN or Infinity.  We currently always store SPECIAL values
114  * using just two bytes (i.e. only n_header), but previous releases used only
115  * the NumericLong format, so we might find 4-byte NaNs (though not infinities)
116  * on disk if a database has been migrated using pg_upgrade.  In either case,
117  * the low-order bits of a special value's header are reserved and currently
118  * should always be set to zero.
119  *
 * In the NumericShort format, the remaining 14 bits of the header word
 * (n_short.n_header) are allocated as follows: 1 for sign (positive or
 * negative), 6 for display scale, and 7 for weight (a weight sign bit plus
 * six magnitude bits).  In practice, most commonly-encountered values can be
 * represented this way.
124  *
125  * In the NumericLong format, the remaining 14 bits of the header word
126  * (n_long.n_sign_dscale) represent the display scale; and the weight is
127  * stored separately in n_weight.
128  *
129  * NOTE: by convention, values in the packed form have been stripped of
130  * all leading and trailing zero digits (where a "digit" is of base NBASE).
131  * In particular, if the value is zero, there will be no digits at all!
132  * The weight is arbitrary in that case, but we normally set it to zero.
133  */
134 
/*
 * Short on-disk format: one uint16 header word packs the sign, display
 * scale and weight; the base-NBASE digits follow immediately.
 */
struct NumericShort
{
	uint16		n_header;		/* Sign + display scale + weight */
	NumericDigit n_data[FLEXIBLE_ARRAY_MEMBER]; /* Digits */
};

/*
 * Long on-disk format: the first word carries the sign and display scale,
 * and the weight gets its own int16 field before the digits.
 */
struct NumericLong
{
	uint16		n_sign_dscale;	/* Sign + display scale */
	int16		n_weight;		/* Weight of 1st digit	*/
	NumericDigit n_data[FLEXIBLE_ARRAY_MEMBER]; /* Digits */
};

/*
 * All members begin with the same uint16 word, so n_header can always be
 * read first to decide which of the two layouts is actually present.
 */
union NumericChoice
{
	uint16		n_header;		/* Header word */
	struct NumericLong n_long;	/* Long form (4-byte header) */
	struct NumericShort n_short;	/* Short form (2-byte header) */
};

/* A complete numeric datum: varlena length word followed by the payload. */
struct NumericData
{
	int32		vl_len_;		/* varlena header (do not touch directly!) */
	union NumericChoice choice; /* choice of format */
};
160 
161 
/*
 * Interpretation of high bits.
 *
 * The two most significant bits of the first header word select the format:
 * long positive, long negative, short, or special (NaN/Inf).
 */

#define NUMERIC_SIGN_MASK	0xC000
#define NUMERIC_POS			0x0000
#define NUMERIC_NEG			0x4000
#define NUMERIC_SHORT		0x8000
#define NUMERIC_SPECIAL		0xC000

/* Extract the two format bits, and test for the short / special formats. */
#define NUMERIC_FLAGBITS(n) ((n)->choice.n_header & NUMERIC_SIGN_MASK)
#define NUMERIC_IS_SHORT(n)		(NUMERIC_FLAGBITS(n) == NUMERIC_SHORT)
#define NUMERIC_IS_SPECIAL(n)	(NUMERIC_FLAGBITS(n) == NUMERIC_SPECIAL)

/*
 * Total header sizes, including the varlena length word: the long format
 * adds n_weight (int16) after the uint16 header word; the short format
 * has only the uint16 header word.
 */
#define NUMERIC_HDRSZ	(VARHDRSZ + sizeof(uint16) + sizeof(int16))
#define NUMERIC_HDRSZ_SHORT (VARHDRSZ + sizeof(uint16))

/*
 * If the flag bits are NUMERIC_SHORT or NUMERIC_SPECIAL, we want the short
 * header; otherwise, we want the long one.  Instead of testing against each
 * value, we can just look at the high bit, for a slight efficiency gain.
 */
#define NUMERIC_HEADER_IS_SHORT(n)	(((n)->choice.n_header & 0x8000) != 0)
#define NUMERIC_HEADER_SIZE(n) \
	(VARHDRSZ + sizeof(uint16) + \
	 (NUMERIC_HEADER_IS_SHORT(n) ? 0 : sizeof(int16)))
188 
/*
 * Definitions for special values (NaN, positive infinity, negative infinity).
 *
 * The two bits after the NUMERIC_SPECIAL bits are 00 for NaN, 01 for positive
 * infinity, 11 for negative infinity.  (This makes the sign bit match where
 * it is in a short-format value, though we make no use of that at present.)
 * We could mask off the remaining bits before testing the active bits, but
 * currently those bits must be zeroes, so masking would just add cycles.
 */
#define NUMERIC_EXT_SIGN_MASK	0xF000	/* high bits plus NaN/Inf flag bits */
#define NUMERIC_NAN				0xC000
#define NUMERIC_PINF			0xD000
#define NUMERIC_NINF			0xF000
#define NUMERIC_INF_SIGN_MASK	0x2000

/* The top four header bits, i.e. format bits plus NaN/Inf flag bits. */
#define NUMERIC_EXT_FLAGBITS(n)	((n)->choice.n_header & NUMERIC_EXT_SIGN_MASK)
#define NUMERIC_IS_NAN(n)		((n)->choice.n_header == NUMERIC_NAN)
#define NUMERIC_IS_PINF(n)		((n)->choice.n_header == NUMERIC_PINF)
#define NUMERIC_IS_NINF(n)		((n)->choice.n_header == NUMERIC_NINF)
/*
 * Clearing the infinity sign bit maps NUMERIC_NINF (0xF000) onto
 * NUMERIC_PINF (0xD000), so this matches either infinity.
 */
#define NUMERIC_IS_INF(n) \
	(((n)->choice.n_header & ~NUMERIC_INF_SIGN_MASK) == NUMERIC_PINF)
210 
/*
 * Short format definitions.
 *
 * Layout of n_short.n_header below the two format bits:
 *	 0x2000  sign bit (set means negative)
 *	 0x1F80  display scale, 6 bits, so 0..63
 *	 0x0040  weight sign bit
 *	 0x003F  weight magnitude; together with the sign bit the weight is a
 *			 7-bit two's-complement value in -64..63
 */

#define NUMERIC_SHORT_SIGN_MASK			0x2000
#define NUMERIC_SHORT_DSCALE_MASK		0x1F80
#define NUMERIC_SHORT_DSCALE_SHIFT		7
#define NUMERIC_SHORT_DSCALE_MAX		\
	(NUMERIC_SHORT_DSCALE_MASK >> NUMERIC_SHORT_DSCALE_SHIFT)
#define NUMERIC_SHORT_WEIGHT_SIGN_MASK	0x0040
#define NUMERIC_SHORT_WEIGHT_MASK		0x003F
#define NUMERIC_SHORT_WEIGHT_MAX		NUMERIC_SHORT_WEIGHT_MASK
#define NUMERIC_SHORT_WEIGHT_MIN		(-(NUMERIC_SHORT_WEIGHT_MASK+1))
224 
/*
 * Extract sign, display scale, weight.  These macros extract field values
 * suitable for the NumericVar format from the Numeric (on-disk) format.
 *
 * Note that we don't trouble to ensure that dscale and weight read as zero
 * for an infinity; however, that doesn't matter since we never convert
 * "special" numerics to NumericVar form.  Only the constants defined below
 * (const_nan, etc) ever represent a non-finite value as a NumericVar.
 */

/* Mask and maximum for the 14-bit display scale of the long format. */
#define NUMERIC_DSCALE_MASK			0x3FFF
#define NUMERIC_DSCALE_MAX			NUMERIC_DSCALE_MASK

/*
 * NUMERIC_SIGN yields NUMERIC_POS or NUMERIC_NEG for finite values, and the
 * full NaN/PINF/NINF flag bits for special values.
 */
#define NUMERIC_SIGN(n) \
	(NUMERIC_IS_SHORT(n) ? \
		(((n)->choice.n_short.n_header & NUMERIC_SHORT_SIGN_MASK) ? \
		 NUMERIC_NEG : NUMERIC_POS) : \
		(NUMERIC_IS_SPECIAL(n) ? \
		 NUMERIC_EXT_FLAGBITS(n) : NUMERIC_FLAGBITS(n)))
/* NUMERIC_DSCALE reads the display scale from whichever header is present. */
#define NUMERIC_DSCALE(n)	(NUMERIC_HEADER_IS_SHORT((n)) ? \
	((n)->choice.n_short.n_header & NUMERIC_SHORT_DSCALE_MASK) \
		>> NUMERIC_SHORT_DSCALE_SHIFT \
	: ((n)->choice.n_long.n_sign_dscale & NUMERIC_DSCALE_MASK))
/*
 * NUMERIC_WEIGHT reads the weight of the first digit.  For the short format
 * the 7-bit field is sign-extended by OR-ing in ~NUMERIC_SHORT_WEIGHT_MASK
 * when the weight sign bit is set.
 */
#define NUMERIC_WEIGHT(n)	(NUMERIC_HEADER_IS_SHORT((n)) ? \
	(((n)->choice.n_short.n_header & NUMERIC_SHORT_WEIGHT_SIGN_MASK ? \
		~NUMERIC_SHORT_WEIGHT_MASK : 0) \
	 | ((n)->choice.n_short.n_header & NUMERIC_SHORT_WEIGHT_MASK)) \
	: ((n)->choice.n_long.n_weight))
253 
/* ----------
 * NumericVar is the format we use for arithmetic.  The digit-array part
 * is the same as the NumericData storage format, but the header is more
 * complex.
 *
 * The value represented by a NumericVar is determined by the sign, weight,
 * ndigits, and digits[] array.  If it is a "special" value (NaN or Inf)
 * then only the sign field matters; ndigits should be zero, and the weight
 * and dscale fields are ignored.
 *
 * Note: the first digit of a NumericVar's value is assumed to be multiplied
 * by NBASE ** weight.  Another way to say it is that there are weight+1
 * digits before the decimal point.  It is possible to have weight < 0.
 *
 * buf points at the physical start of the palloc'd digit buffer for the
 * NumericVar.  digits points at the first digit in actual use (the one
 * with the specified weight).  We normally leave an unused digit or two
 * (preset to zeroes) between buf and digits, so that there is room to store
 * a carry out of the top digit without reallocating space.  We just need to
 * decrement digits (and increment weight) to make room for the carry digit.
 * (There is no such extra space in a numeric value stored in the database,
 * only in a NumericVar in memory.)
 *
 * If buf is NULL then the digit buffer isn't actually palloc'd and should
 * not be freed --- see the constants below for an example.
 *
 * dscale, or display scale, is the nominal precision expressed as number
 * of digits after the decimal point (it must always be >= 0 at present).
 * dscale may be more than the number of physically stored fractional digits,
 * implying that we have suppressed storage of significant trailing zeroes.
 * It should never be less than the number of stored digits, since that would
 * imply hiding digits that are present.  NOTE that dscale is always expressed
 * in *decimal* digits, and so it may correspond to a fractional number of
 * base-NBASE digits --- divide by DEC_DIGITS to convert to NBASE digits.
 *
 * rscale, or result scale, is the target precision for a computation.
 * Like dscale it is expressed as number of *decimal* digits after the decimal
 * point, and is always >= 0 at present.
 * Note that rscale is not stored in variables --- it's figured on-the-fly
 * from the dscales of the inputs.
 *
 * While we consistently use "weight" to refer to the base-NBASE weight of
 * a numeric value, it is convenient in some scale-related calculations to
 * make use of the base-10 weight (ie, the approximate log10 of the value).
 * To avoid confusion, such a decimal-units weight is called a "dweight".
 *
 * NB: All the variable-level functions are written in a style that makes it
 * possible to give one and the same variable as argument and destination.
 * This is feasible because the digit buffer is separate from the variable.
 * ----------
 */
typedef struct NumericVar
{
	int			ndigits;		/* # of digits in digits[] - can be 0! */
	int			weight;			/* weight of first digit */
	int			sign;			/* NUMERIC_POS, _NEG, _NAN, _PINF, or _NINF */
	int			dscale;			/* display scale */
	NumericDigit *buf;			/* start of palloc'd space for digits[], or
								 * NULL if the digits are not palloc'd */
	NumericDigit *digits;		/* base-NBASE digits */
} NumericVar;
314 
315 
/* ----------
 * Data for generate_series
 *
 * State carried between calls of the numeric generate_series function: as
 * the field names indicate, the running value, the end point and the
 * increment, all kept in NumericVar form.
 * ----------
 */
typedef struct
{
	NumericVar	current;		/* running value of the series */
	NumericVar	stop;			/* end point */
	NumericVar	step;			/* increment */
} generate_series_numeric_fctx;
326 
327 
/* ----------
 * Sort support.
 *
 * Private state for numeric sort support and abbreviated keys.  The
 * cardinality estimator is presumably consulted by numeric_abbrev_abort()
 * to decide whether abbreviation is paying off -- confirm at that function.
 * ----------
 */
typedef struct
{
	void	   *buf;			/* buffer for short varlenas */
	int64		input_count;	/* number of non-null values seen */
	bool		estimating;		/* true if estimating cardinality */

	hyperLogLogState abbr_card; /* cardinality estimator */
} NumericSortSupport;
340 
341 
/* ----------
 * Fast sum accumulator.
 *
 * NumericSumAccum is used to implement SUM(), and other standard aggregates
 * that track the sum of input values.  It uses 32-bit integers to store the
 * digits, instead of the normal 16-bit integers (with NBASE=10000).  This
 * way, we can safely accumulate up to NBASE - 1 values without propagating
 * carry, before risking overflow of any of the digits.  'num_uncarried'
 * tracks how many values have been accumulated without propagating carry.
 *
 * Positive and negative values are accumulated separately, in 'pos_digits'
 * and 'neg_digits'.  This is simpler and faster than deciding whether to add
 * or subtract from the current value, for each new value (see sub_var() for
 * the logic we avoid by doing this).  Both buffers are of same size, and
 * have the same weight and scale.  In accum_sum_final(), the positive and
 * negative sums are added together to produce the final result.
 *
 * When a new value has a larger ndigits or weight than the accumulator
 * currently does, the accumulator is enlarged to accommodate the new value.
 * We normally have one zero digit reserved for carry propagation, and that
 * is indicated by the 'have_carry_space' flag.  When accum_sum_carry() uses
 * up the reserved digit, it clears the 'have_carry_space' flag.  The next
 * call to accum_sum_add() will enlarge the buffer, to make room for the
 * extra digit, and set the flag again.
 *
 * To initialize a new accumulator, simply reset all fields to zeros.
 *
 * The accumulator does not handle NaNs or infinities: it has no field that
 * could represent a special value, so callers must deal with those
 * separately.
 * ----------
 */
typedef struct NumericSumAccum
{
	int			ndigits;		/* # of digits in each of the two buffers */
	int			weight;			/* weight of the first digit in each buffer */
	int			dscale;			/* display scale of the sum */
	int			num_uncarried;	/* values added since carry was propagated */
	bool		have_carry_space;	/* true if a spare top digit is reserved */
	int32	   *pos_digits;		/* accumulated positive inputs */
	int32	   *neg_digits;		/* accumulated negative inputs */
} NumericSumAccum;
382 
383 
/*
 * We define our own macros for packing and unpacking abbreviated-key
 * representations for numeric values in order to avoid depending on
 * USE_FLOAT8_BYVAL.  The type of abbreviation we use is based only on
 * the size of a datum, not the argument-passing convention for float8.
 *
 * The range of abbreviations for finite values is from +PG_INT64/32_MAX
 * to -PG_INT64/32_MAX.  NaN has the abbreviation PG_INT64/32_MIN, and we
 * define the sort ordering to make that work out properly (see further
 * comments below).  PINF and NINF share the abbreviations of the largest
 * and smallest finite abbreviation classes.
 *
 * Note that +Infinity maps to the most negative finite abbreviation and
 * -Infinity to the most positive one, so the abbreviated ordering is
 * presumably inverted relative to numeric order -- see numeric_cmp_abbrev()
 * to confirm.
 */
#define NUMERIC_ABBREV_BITS (SIZEOF_DATUM * BITS_PER_BYTE)
#if SIZEOF_DATUM == 8
#define NumericAbbrevGetDatum(X) ((Datum) (X))
#define DatumGetNumericAbbrev(X) ((int64) (X))
#define NUMERIC_ABBREV_NAN		 NumericAbbrevGetDatum(PG_INT64_MIN)
#define NUMERIC_ABBREV_PINF		 NumericAbbrevGetDatum(-PG_INT64_MAX)
#define NUMERIC_ABBREV_NINF		 NumericAbbrevGetDatum(PG_INT64_MAX)
#else
#define NumericAbbrevGetDatum(X) ((Datum) (X))
#define DatumGetNumericAbbrev(X) ((int32) (X))
#define NUMERIC_ABBREV_NAN		 NumericAbbrevGetDatum(PG_INT32_MIN)
#define NUMERIC_ABBREV_PINF		 NumericAbbrevGetDatum(-PG_INT32_MAX)
#define NUMERIC_ABBREV_NINF		 NumericAbbrevGetDatum(PG_INT32_MAX)
#endif
410 
411 
/* ----------
 * Some preinitialized constants
 *
 * These are statically allocated: their buf field is NULL, so their digit
 * arrays are never palloc'd and must never be freed.
 * ----------
 */
/* 0: no digits at all (ndigits = 0), but digits still points at valid data */
static const NumericDigit const_zero_data[1] = {0};
static const NumericVar const_zero =
{0, 0, NUMERIC_POS, 0, NULL, (NumericDigit *) const_zero_data};

/* 1 */
static const NumericDigit const_one_data[1] = {1};
static const NumericVar const_one =
{1, 0, NUMERIC_POS, 0, NULL, (NumericDigit *) const_one_data};

/* -1: shares the digit array of const_one */
static const NumericVar const_minus_one =
{1, 0, NUMERIC_NEG, 0, NULL, (NumericDigit *) const_one_data};

/* 2 */
static const NumericDigit const_two_data[1] = {2};
static const NumericVar const_two =
{1, 0, NUMERIC_POS, 0, NULL, (NumericDigit *) const_two_data};

/* 0.9: a single digit at weight -1, display scale 1 */
#if DEC_DIGITS == 4
static const NumericDigit const_zero_point_nine_data[1] = {9000};
#elif DEC_DIGITS == 2
static const NumericDigit const_zero_point_nine_data[1] = {90};
#elif DEC_DIGITS == 1
static const NumericDigit const_zero_point_nine_data[1] = {9};
#endif
static const NumericVar const_zero_point_nine =
{1, -1, NUMERIC_POS, 1, NULL, (NumericDigit *) const_zero_point_nine_data};

/* 1.1: integer digit 1 followed by one fractional digit, display scale 1 */
#if DEC_DIGITS == 4
static const NumericDigit const_one_point_one_data[2] = {1, 1000};
#elif DEC_DIGITS == 2
static const NumericDigit const_one_point_one_data[2] = {1, 10};
#elif DEC_DIGITS == 1
static const NumericDigit const_one_point_one_data[2] = {1, 1};
#endif
static const NumericVar const_one_point_one =
{2, 0, NUMERIC_POS, 1, NULL, (NumericDigit *) const_one_point_one_data};

/* Special values: only the sign field is meaningful, and there are no digits */
static const NumericVar const_nan =
{0, 0, NUMERIC_NAN, 0, NULL, NULL};

static const NumericVar const_pinf =
{0, 0, NUMERIC_PINF, 0, NULL, NULL};

static const NumericVar const_ninf =
{0, 0, NUMERIC_NINF, 0, NULL, NULL};

/*
 * Powers of ten below NBASE.  Presumably indexed by the number of decimal
 * digits kept within one NBASE digit when rounding/truncating -- confirm at
 * the use sites.
 */
#if DEC_DIGITS == 4
static const int round_powers[4] = {0, 1000, 100, 10};
#endif
463 
464 
/* ----------
 * Local functions
 * ----------
 */

/*
 * Debug dumpers exist only when NUMERIC_DEBUG is defined; otherwise calls to
 * them expand to nothing.
 */
#ifdef NUMERIC_DEBUG
static void dump_numeric(const char *str, Numeric num);
static void dump_var(const char *str, NumericVar *var);
#else
#define dump_numeric(s,n)
#define dump_var(s,v)
#endif

/*
 * Allocate / release a digit buffer.  digitbuf_free() ignores NULL, which is
 * what a NumericVar with a non-palloc'd digit array has in its buf field.
 */
#define digitbuf_alloc(ndigits)  \
	((NumericDigit *) palloc((ndigits) * sizeof(NumericDigit)))
#define digitbuf_free(buf)	\
	do { \
		 if ((buf) != NULL) \
			 pfree(buf); \
	} while (0)

/* Zero every field of a NumericVar, leaving buf and digits NULL. */
#define init_var(v)		memset(v, 0, sizeof(NumericVar))

/*
 * Access the digit array of an on-disk Numeric, and count its digits from
 * the varlena size minus whichever header (short or long) is present.
 */
#define NUMERIC_DIGITS(num) (NUMERIC_HEADER_IS_SHORT(num) ? \
	(num)->choice.n_short.n_data : (num)->choice.n_long.n_data)
#define NUMERIC_NDIGITS(num) \
	((VARSIZE(num) - NUMERIC_HEADER_SIZE(num)) / sizeof(NumericDigit))
/* True if scale and weight both fit in the short-format header fields. */
#define NUMERIC_CAN_BE_SHORT(scale,weight) \
	((scale) <= NUMERIC_SHORT_DSCALE_MAX && \
	(weight) <= NUMERIC_SHORT_WEIGHT_MAX && \
	(weight) >= NUMERIC_SHORT_WEIGHT_MIN)
496 
/* NumericVar digit-buffer management */
static void alloc_var(NumericVar *var, int ndigits);
static void free_var(NumericVar *var);
static void zero_var(NumericVar *var);

/* Conversions between text, on-disk Numeric and NumericVar */
static const char *set_var_from_str(const char *str, const char *cp,
									NumericVar *dest);
static void set_var_from_num(Numeric value, NumericVar *dest);
static void init_var_from_num(Numeric num, NumericVar *dest);
static void set_var_from_var(const NumericVar *value, NumericVar *dest);
static char *get_str_from_var(const NumericVar *var);
static char *get_str_from_var_sci(const NumericVar *var, int rscale);

/* Building on-disk Numeric results */
static Numeric duplicate_numeric(Numeric num);
static Numeric make_result(const NumericVar *var);
static Numeric make_result_opt_error(const NumericVar *var, bool *error);

/* Typmod (precision/scale) enforcement */
static void apply_typmod(NumericVar *var, int32 typmod);
static void apply_typmod_special(Numeric num, int32 typmod);

/* Conversions to and from native integer / floating-point types */
static bool numericvar_to_int32(const NumericVar *var, int32 *result);
static bool numericvar_to_int64(const NumericVar *var, int64 *result);
static void int64_to_numericvar(int64 val, NumericVar *var);
static bool numericvar_to_uint64(const NumericVar *var, uint64 *result);
#ifdef HAVE_INT128
static bool numericvar_to_int128(const NumericVar *var, int128 *result);
static void int128_to_numericvar(int128 val, NumericVar *var);
#endif
static double numericvar_to_double_no_overflow(const NumericVar *var);

/* Sort support and abbreviated keys */
static Datum numeric_abbrev_convert(Datum original_datum, SortSupport ssup);
static bool numeric_abbrev_abort(int memtupcount, SortSupport ssup);
static int	numeric_fast_cmp(Datum x, Datum y, SortSupport ssup);
static int	numeric_cmp_abbrev(Datum x, Datum y, SortSupport ssup);

static Datum numeric_abbrev_convert_var(const NumericVar *var,
										NumericSortSupport *nss);

/* Comparison and basic arithmetic */
static int	cmp_numerics(Numeric num1, Numeric num2);
static int	cmp_var(const NumericVar *var1, const NumericVar *var2);
static int	cmp_var_common(const NumericDigit *var1digits, int var1ndigits,
						   int var1weight, int var1sign,
						   const NumericDigit *var2digits, int var2ndigits,
						   int var2weight, int var2sign);
static void add_var(const NumericVar *var1, const NumericVar *var2,
					NumericVar *result);
static void sub_var(const NumericVar *var1, const NumericVar *var2,
					NumericVar *result);
static void mul_var(const NumericVar *var1, const NumericVar *var2,
					NumericVar *result,
					int rscale);
static void div_var(const NumericVar *var1, const NumericVar *var2,
					NumericVar *result,
					int rscale, bool round);
static void div_var_fast(const NumericVar *var1, const NumericVar *var2,
						 NumericVar *result, int rscale, bool round);
static int	select_div_scale(const NumericVar *var1, const NumericVar *var2);
static void mod_var(const NumericVar *var1, const NumericVar *var2,
					NumericVar *result);
static void div_mod_var(const NumericVar *var1, const NumericVar *var2,
						NumericVar *quot, NumericVar *rem);
static void ceil_var(const NumericVar *var, NumericVar *result);
static void floor_var(const NumericVar *var, NumericVar *result);

/* Higher-level functions: gcd, roots, exponentials, logarithms, powers */
static void gcd_var(const NumericVar *var1, const NumericVar *var2,
					NumericVar *result);
static void sqrt_var(const NumericVar *arg, NumericVar *result, int rscale);
static void exp_var(const NumericVar *arg, NumericVar *result, int rscale);
static int	estimate_ln_dweight(const NumericVar *var);
static void ln_var(const NumericVar *arg, NumericVar *result, int rscale);
static void log_var(const NumericVar *base, const NumericVar *num,
					NumericVar *result);
static void power_var(const NumericVar *base, const NumericVar *exp,
					  NumericVar *result);
static void power_var_int(const NumericVar *base, int exp, NumericVar *result,
						  int rscale);
static void power_ten_int(int exp, NumericVar *result);

/* Absolute-value helpers, rounding/truncation, and bucket computation */
static int	cmp_abs(const NumericVar *var1, const NumericVar *var2);
static int	cmp_abs_common(const NumericDigit *var1digits, int var1ndigits,
						   int var1weight,
						   const NumericDigit *var2digits, int var2ndigits,
						   int var2weight);
static void add_abs(const NumericVar *var1, const NumericVar *var2,
					NumericVar *result);
static void sub_abs(const NumericVar *var1, const NumericVar *var2,
					NumericVar *result);
static void round_var(NumericVar *var, int rscale);
static void trunc_var(NumericVar *var, int rscale);
static void strip_var(NumericVar *var);
static void compute_bucket(Numeric operand, Numeric bound1, Numeric bound2,
						   const NumericVar *count_var, bool reversed_bounds,
						   NumericVar *result_var);

/* Fast sum accumulator (NumericSumAccum) */
static void accum_sum_add(NumericSumAccum *accum, const NumericVar *var1);
static void accum_sum_rescale(NumericSumAccum *accum, const NumericVar *val);
static void accum_sum_carry(NumericSumAccum *accum);
static void accum_sum_reset(NumericSumAccum *accum);
static void accum_sum_final(NumericSumAccum *accum, NumericVar *result);
static void accum_sum_copy(NumericSumAccum *dst, NumericSumAccum *src);
static void accum_sum_combine(NumericSumAccum *accum, NumericSumAccum *accum2);
597 
598 
599 /* ----------------------------------------------------------------------
600  *
601  * Input-, output- and rounding-functions
602  *
603  * ----------------------------------------------------------------------
604  */
605 
606 
607 /*
608  * numeric_in() -
609  *
610  *	Input function for numeric data type
611  */
612 Datum
numeric_in(PG_FUNCTION_ARGS)613 numeric_in(PG_FUNCTION_ARGS)
614 {
615 	char	   *str = PG_GETARG_CSTRING(0);
616 
617 #ifdef NOT_USED
618 	Oid			typelem = PG_GETARG_OID(1);
619 #endif
620 	int32		typmod = PG_GETARG_INT32(2);
621 	Numeric		res;
622 	const char *cp;
623 
624 	/* Skip leading spaces */
625 	cp = str;
626 	while (*cp)
627 	{
628 		if (!isspace((unsigned char) *cp))
629 			break;
630 		cp++;
631 	}
632 
633 	/*
634 	 * Check for NaN and infinities.  We recognize the same strings allowed by
635 	 * float8in().
636 	 */
637 	if (pg_strncasecmp(cp, "NaN", 3) == 0)
638 	{
639 		res = make_result(&const_nan);
640 		cp += 3;
641 	}
642 	else if (pg_strncasecmp(cp, "Infinity", 8) == 0)
643 	{
644 		res = make_result(&const_pinf);
645 		cp += 8;
646 	}
647 	else if (pg_strncasecmp(cp, "+Infinity", 9) == 0)
648 	{
649 		res = make_result(&const_pinf);
650 		cp += 9;
651 	}
652 	else if (pg_strncasecmp(cp, "-Infinity", 9) == 0)
653 	{
654 		res = make_result(&const_ninf);
655 		cp += 9;
656 	}
657 	else if (pg_strncasecmp(cp, "inf", 3) == 0)
658 	{
659 		res = make_result(&const_pinf);
660 		cp += 3;
661 	}
662 	else if (pg_strncasecmp(cp, "+inf", 4) == 0)
663 	{
664 		res = make_result(&const_pinf);
665 		cp += 4;
666 	}
667 	else if (pg_strncasecmp(cp, "-inf", 4) == 0)
668 	{
669 		res = make_result(&const_ninf);
670 		cp += 4;
671 	}
672 	else
673 	{
674 		/*
675 		 * Use set_var_from_str() to parse a normal numeric value
676 		 */
677 		NumericVar	value;
678 
679 		init_var(&value);
680 
681 		cp = set_var_from_str(str, cp, &value);
682 
683 		/*
684 		 * We duplicate a few lines of code here because we would like to
685 		 * throw any trailing-junk syntax error before any semantic error
686 		 * resulting from apply_typmod.  We can't easily fold the two cases
687 		 * together because we mustn't apply apply_typmod to a NaN/Inf.
688 		 */
689 		while (*cp)
690 		{
691 			if (!isspace((unsigned char) *cp))
692 				ereport(ERROR,
693 						(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
694 						 errmsg("invalid input syntax for type %s: \"%s\"",
695 								"numeric", str)));
696 			cp++;
697 		}
698 
699 		apply_typmod(&value, typmod);
700 
701 		res = make_result(&value);
702 		free_var(&value);
703 
704 		PG_RETURN_NUMERIC(res);
705 	}
706 
707 	/* Should be nothing left but spaces */
708 	while (*cp)
709 	{
710 		if (!isspace((unsigned char) *cp))
711 			ereport(ERROR,
712 					(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
713 					 errmsg("invalid input syntax for type %s: \"%s\"",
714 							"numeric", str)));
715 		cp++;
716 	}
717 
718 	/* As above, throw any typmod error after finishing syntax check */
719 	apply_typmod_special(res, typmod);
720 
721 	PG_RETURN_NUMERIC(res);
722 }
723 
724 
725 /*
726  * numeric_out() -
727  *
728  *	Output function for numeric data type
729  */
730 Datum
numeric_out(PG_FUNCTION_ARGS)731 numeric_out(PG_FUNCTION_ARGS)
732 {
733 	Numeric		num = PG_GETARG_NUMERIC(0);
734 	NumericVar	x;
735 	char	   *str;
736 
737 	/*
738 	 * Handle NaN and infinities
739 	 */
740 	if (NUMERIC_IS_SPECIAL(num))
741 	{
742 		if (NUMERIC_IS_PINF(num))
743 			PG_RETURN_CSTRING(pstrdup("Infinity"));
744 		else if (NUMERIC_IS_NINF(num))
745 			PG_RETURN_CSTRING(pstrdup("-Infinity"));
746 		else
747 			PG_RETURN_CSTRING(pstrdup("NaN"));
748 	}
749 
750 	/*
751 	 * Get the number in the variable format.
752 	 */
753 	init_var_from_num(num, &x);
754 
755 	str = get_str_from_var(&x);
756 
757 	PG_RETURN_CSTRING(str);
758 }
759 
760 /*
761  * numeric_is_nan() -
762  *
763  *	Is Numeric value a NaN?
764  */
765 bool
numeric_is_nan(Numeric num)766 numeric_is_nan(Numeric num)
767 {
768 	return NUMERIC_IS_NAN(num);
769 }
770 
771 /*
772  * numeric_is_inf() -
773  *
774  *	Is Numeric value an infinity?
775  */
776 bool
numeric_is_inf(Numeric num)777 numeric_is_inf(Numeric num)
778 {
779 	return NUMERIC_IS_INF(num);
780 }
781 
782 /*
783  * numeric_is_integral() -
784  *
785  *	Is Numeric value integral?
786  */
787 static bool
numeric_is_integral(Numeric num)788 numeric_is_integral(Numeric num)
789 {
790 	NumericVar	arg;
791 
792 	/* Reject NaN, but infinities are considered integral */
793 	if (NUMERIC_IS_SPECIAL(num))
794 	{
795 		if (NUMERIC_IS_NAN(num))
796 			return false;
797 		return true;
798 	}
799 
800 	/* Integral if there are no digits to the right of the decimal point */
801 	init_var_from_num(num, &arg);
802 
803 	return (arg.ndigits == 0 || arg.ndigits <= arg.weight + 1);
804 }
805 
806 /*
807  * numeric_maximum_size() -
808  *
809  *	Maximum size of a numeric with given typmod, or -1 if unlimited/unknown.
810  */
811 int32
numeric_maximum_size(int32 typmod)812 numeric_maximum_size(int32 typmod)
813 {
814 	int			precision;
815 	int			numeric_digits;
816 
817 	if (typmod < (int32) (VARHDRSZ))
818 		return -1;
819 
820 	/* precision (ie, max # of digits) is in upper bits of typmod */
821 	precision = ((typmod - VARHDRSZ) >> 16) & 0xffff;
822 
823 	/*
824 	 * This formula computes the maximum number of NumericDigits we could need
825 	 * in order to store the specified number of decimal digits. Because the
826 	 * weight is stored as a number of NumericDigits rather than a number of
827 	 * decimal digits, it's possible that the first NumericDigit will contain
828 	 * only a single decimal digit.  Thus, the first two decimal digits can
829 	 * require two NumericDigits to store, but it isn't until we reach
830 	 * DEC_DIGITS + 2 decimal digits that we potentially need a third
831 	 * NumericDigit.
832 	 */
833 	numeric_digits = (precision + 2 * (DEC_DIGITS - 1)) / DEC_DIGITS;
834 
835 	/*
836 	 * In most cases, the size of a numeric will be smaller than the value
837 	 * computed below, because the varlena header will typically get toasted
838 	 * down to a single byte before being stored on disk, and it may also be
839 	 * possible to use a short numeric header.  But our job here is to compute
840 	 * the worst case.
841 	 */
842 	return NUMERIC_HDRSZ + (numeric_digits * sizeof(NumericDigit));
843 }
844 
845 /*
846  * numeric_out_sci() -
847  *
848  *	Output function for numeric data type in scientific notation.
849  */
850 char *
numeric_out_sci(Numeric num,int scale)851 numeric_out_sci(Numeric num, int scale)
852 {
853 	NumericVar	x;
854 	char	   *str;
855 
856 	/*
857 	 * Handle NaN and infinities
858 	 */
859 	if (NUMERIC_IS_SPECIAL(num))
860 	{
861 		if (NUMERIC_IS_PINF(num))
862 			return pstrdup("Infinity");
863 		else if (NUMERIC_IS_NINF(num))
864 			return pstrdup("-Infinity");
865 		else
866 			return pstrdup("NaN");
867 	}
868 
869 	init_var_from_num(num, &x);
870 
871 	str = get_str_from_var_sci(&x, scale);
872 
873 	return str;
874 }
875 
876 /*
877  * numeric_normalize() -
878  *
879  *	Output function for numeric data type, suppressing insignificant trailing
880  *	zeroes and then any trailing decimal point.  The intent of this is to
881  *	produce strings that are equal if and only if the input numeric values
882  *	compare equal.
883  */
884 char *
numeric_normalize(Numeric num)885 numeric_normalize(Numeric num)
886 {
887 	NumericVar	x;
888 	char	   *str;
889 	int			last;
890 
891 	/*
892 	 * Handle NaN and infinities
893 	 */
894 	if (NUMERIC_IS_SPECIAL(num))
895 	{
896 		if (NUMERIC_IS_PINF(num))
897 			return pstrdup("Infinity");
898 		else if (NUMERIC_IS_NINF(num))
899 			return pstrdup("-Infinity");
900 		else
901 			return pstrdup("NaN");
902 	}
903 
904 	init_var_from_num(num, &x);
905 
906 	str = get_str_from_var(&x);
907 
908 	/* If there's no decimal point, there's certainly nothing to remove. */
909 	if (strchr(str, '.') != NULL)
910 	{
911 		/*
912 		 * Back up over trailing fractional zeroes.  Since there is a decimal
913 		 * point, this loop will terminate safely.
914 		 */
915 		last = strlen(str) - 1;
916 		while (str[last] == '0')
917 			last--;
918 
919 		/* We want to get rid of the decimal point too, if it's now last. */
920 		if (str[last] == '.')
921 			last--;
922 
923 		/* Delete whatever we backed up over. */
924 		str[last + 1] = '\0';
925 	}
926 
927 	return str;
928 }
929 
930 /*
931  *		numeric_recv			- converts external binary format to numeric
932  *
933  * External format is a sequence of int16's:
934  * ndigits, weight, sign, dscale, NumericDigits.
935  */
936 Datum
numeric_recv(PG_FUNCTION_ARGS)937 numeric_recv(PG_FUNCTION_ARGS)
938 {
939 	StringInfo	buf = (StringInfo) PG_GETARG_POINTER(0);
940 
941 #ifdef NOT_USED
942 	Oid			typelem = PG_GETARG_OID(1);
943 #endif
944 	int32		typmod = PG_GETARG_INT32(2);
945 	NumericVar	value;
946 	Numeric		res;
947 	int			len,
948 				i;
949 
950 	init_var(&value);
951 
952 	len = (uint16) pq_getmsgint(buf, sizeof(uint16));
953 
954 	alloc_var(&value, len);
955 
956 	value.weight = (int16) pq_getmsgint(buf, sizeof(int16));
957 	/* we allow any int16 for weight --- OK? */
958 
959 	value.sign = (uint16) pq_getmsgint(buf, sizeof(uint16));
960 	if (!(value.sign == NUMERIC_POS ||
961 		  value.sign == NUMERIC_NEG ||
962 		  value.sign == NUMERIC_NAN ||
963 		  value.sign == NUMERIC_PINF ||
964 		  value.sign == NUMERIC_NINF))
965 		ereport(ERROR,
966 				(errcode(ERRCODE_INVALID_BINARY_REPRESENTATION),
967 				 errmsg("invalid sign in external \"numeric\" value")));
968 
969 	value.dscale = (uint16) pq_getmsgint(buf, sizeof(uint16));
970 	if ((value.dscale & NUMERIC_DSCALE_MASK) != value.dscale)
971 		ereport(ERROR,
972 				(errcode(ERRCODE_INVALID_BINARY_REPRESENTATION),
973 				 errmsg("invalid scale in external \"numeric\" value")));
974 
975 	for (i = 0; i < len; i++)
976 	{
977 		NumericDigit d = pq_getmsgint(buf, sizeof(NumericDigit));
978 
979 		if (d < 0 || d >= NBASE)
980 			ereport(ERROR,
981 					(errcode(ERRCODE_INVALID_BINARY_REPRESENTATION),
982 					 errmsg("invalid digit in external \"numeric\" value")));
983 		value.digits[i] = d;
984 	}
985 
986 	/*
987 	 * If the given dscale would hide any digits, truncate those digits away.
988 	 * We could alternatively throw an error, but that would take a bunch of
989 	 * extra code (about as much as trunc_var involves), and it might cause
990 	 * client compatibility issues.  Be careful not to apply trunc_var to
991 	 * special values, as it could do the wrong thing; we don't need it
992 	 * anyway, since make_result will ignore all but the sign field.
993 	 *
994 	 * After doing that, be sure to check the typmod restriction.
995 	 */
996 	if (value.sign == NUMERIC_POS ||
997 		value.sign == NUMERIC_NEG)
998 	{
999 		trunc_var(&value, value.dscale);
1000 
1001 		apply_typmod(&value, typmod);
1002 
1003 		res = make_result(&value);
1004 	}
1005 	else
1006 	{
1007 		/* apply_typmod_special wants us to make the Numeric first */
1008 		res = make_result(&value);
1009 
1010 		apply_typmod_special(res, typmod);
1011 	}
1012 
1013 	free_var(&value);
1014 
1015 	PG_RETURN_NUMERIC(res);
1016 }
1017 
1018 /*
1019  *		numeric_send			- converts numeric to binary format
1020  */
1021 Datum
numeric_send(PG_FUNCTION_ARGS)1022 numeric_send(PG_FUNCTION_ARGS)
1023 {
1024 	Numeric		num = PG_GETARG_NUMERIC(0);
1025 	NumericVar	x;
1026 	StringInfoData buf;
1027 	int			i;
1028 
1029 	init_var_from_num(num, &x);
1030 
1031 	pq_begintypsend(&buf);
1032 
1033 	pq_sendint16(&buf, x.ndigits);
1034 	pq_sendint16(&buf, x.weight);
1035 	pq_sendint16(&buf, x.sign);
1036 	pq_sendint16(&buf, x.dscale);
1037 	for (i = 0; i < x.ndigits; i++)
1038 		pq_sendint16(&buf, x.digits[i]);
1039 
1040 	PG_RETURN_BYTEA_P(pq_endtypsend(&buf));
1041 }
1042 
1043 
1044 /*
1045  * numeric_support()
1046  *
1047  * Planner support function for the numeric() length coercion function.
1048  *
1049  * Flatten calls that solely represent increases in allowable precision.
1050  * Scale changes mutate every datum, so they are unoptimizable.  Some values,
1051  * e.g. 1E-1001, can only fit into an unconstrained numeric, so a change from
1052  * an unconstrained numeric to any constrained numeric is also unoptimizable.
1053  */
1054 Datum
numeric_support(PG_FUNCTION_ARGS)1055 numeric_support(PG_FUNCTION_ARGS)
1056 {
1057 	Node	   *rawreq = (Node *) PG_GETARG_POINTER(0);
1058 	Node	   *ret = NULL;
1059 
1060 	if (IsA(rawreq, SupportRequestSimplify))
1061 	{
1062 		SupportRequestSimplify *req = (SupportRequestSimplify *) rawreq;
1063 		FuncExpr   *expr = req->fcall;
1064 		Node	   *typmod;
1065 
1066 		Assert(list_length(expr->args) >= 2);
1067 
1068 		typmod = (Node *) lsecond(expr->args);
1069 
1070 		if (IsA(typmod, Const) && !((Const *) typmod)->constisnull)
1071 		{
1072 			Node	   *source = (Node *) linitial(expr->args);
1073 			int32		old_typmod = exprTypmod(source);
1074 			int32		new_typmod = DatumGetInt32(((Const *) typmod)->constvalue);
1075 			int32		old_scale = (old_typmod - VARHDRSZ) & 0xffff;
1076 			int32		new_scale = (new_typmod - VARHDRSZ) & 0xffff;
1077 			int32		old_precision = (old_typmod - VARHDRSZ) >> 16 & 0xffff;
1078 			int32		new_precision = (new_typmod - VARHDRSZ) >> 16 & 0xffff;
1079 
1080 			/*
1081 			 * If new_typmod < VARHDRSZ, the destination is unconstrained;
1082 			 * that's always OK.  If old_typmod >= VARHDRSZ, the source is
1083 			 * constrained, and we're OK if the scale is unchanged and the
1084 			 * precision is not decreasing.  See further notes in function
1085 			 * header comment.
1086 			 */
1087 			if (new_typmod < (int32) VARHDRSZ ||
1088 				(old_typmod >= (int32) VARHDRSZ &&
1089 				 new_scale == old_scale && new_precision >= old_precision))
1090 				ret = relabel_to_typmod(source, new_typmod);
1091 		}
1092 	}
1093 
1094 	PG_RETURN_POINTER(ret);
1095 }
1096 
1097 /*
1098  * numeric() -
1099  *
1100  *	This is a special function called by the Postgres database system
1101  *	before a value is stored in a tuple's attribute. The precision and
1102  *	scale of the attribute have to be applied on the value.
1103  */
1104 Datum
numeric(PG_FUNCTION_ARGS)1105 numeric		(PG_FUNCTION_ARGS)
1106 {
1107 	Numeric		num = PG_GETARG_NUMERIC(0);
1108 	int32		typmod = PG_GETARG_INT32(1);
1109 	Numeric		new;
1110 	int32		tmp_typmod;
1111 	int			precision;
1112 	int			scale;
1113 	int			ddigits;
1114 	int			maxdigits;
1115 	NumericVar	var;
1116 
1117 	/*
1118 	 * Handle NaN and infinities: if apply_typmod_special doesn't complain,
1119 	 * just return a copy of the input.
1120 	 */
1121 	if (NUMERIC_IS_SPECIAL(num))
1122 	{
1123 		apply_typmod_special(num, typmod);
1124 		PG_RETURN_NUMERIC(duplicate_numeric(num));
1125 	}
1126 
1127 	/*
1128 	 * If the value isn't a valid type modifier, simply return a copy of the
1129 	 * input value
1130 	 */
1131 	if (typmod < (int32) (VARHDRSZ))
1132 		PG_RETURN_NUMERIC(duplicate_numeric(num));
1133 
1134 	/*
1135 	 * Get the precision and scale out of the typmod value
1136 	 */
1137 	tmp_typmod = typmod - VARHDRSZ;
1138 	precision = (tmp_typmod >> 16) & 0xffff;
1139 	scale = tmp_typmod & 0xffff;
1140 	maxdigits = precision - scale;
1141 
1142 	/*
1143 	 * If the number is certainly in bounds and due to the target scale no
1144 	 * rounding could be necessary, just make a copy of the input and modify
1145 	 * its scale fields, unless the larger scale forces us to abandon the
1146 	 * short representation.  (Note we assume the existing dscale is
1147 	 * honest...)
1148 	 */
1149 	ddigits = (NUMERIC_WEIGHT(num) + 1) * DEC_DIGITS;
1150 	if (ddigits <= maxdigits && scale >= NUMERIC_DSCALE(num)
1151 		&& (NUMERIC_CAN_BE_SHORT(scale, NUMERIC_WEIGHT(num))
1152 			|| !NUMERIC_IS_SHORT(num)))
1153 	{
1154 		new = duplicate_numeric(num);
1155 		if (NUMERIC_IS_SHORT(num))
1156 			new->choice.n_short.n_header =
1157 				(num->choice.n_short.n_header & ~NUMERIC_SHORT_DSCALE_MASK)
1158 				| (scale << NUMERIC_SHORT_DSCALE_SHIFT);
1159 		else
1160 			new->choice.n_long.n_sign_dscale = NUMERIC_SIGN(new) |
1161 				((uint16) scale & NUMERIC_DSCALE_MASK);
1162 		PG_RETURN_NUMERIC(new);
1163 	}
1164 
1165 	/*
1166 	 * We really need to fiddle with things - unpack the number into a
1167 	 * variable and let apply_typmod() do it.
1168 	 */
1169 	init_var(&var);
1170 
1171 	set_var_from_num(num, &var);
1172 	apply_typmod(&var, typmod);
1173 	new = make_result(&var);
1174 
1175 	free_var(&var);
1176 
1177 	PG_RETURN_NUMERIC(new);
1178 }
1179 
1180 Datum
numerictypmodin(PG_FUNCTION_ARGS)1181 numerictypmodin(PG_FUNCTION_ARGS)
1182 {
1183 	ArrayType  *ta = PG_GETARG_ARRAYTYPE_P(0);
1184 	int32	   *tl;
1185 	int			n;
1186 	int32		typmod;
1187 
1188 	tl = ArrayGetIntegerTypmods(ta, &n);
1189 
1190 	if (n == 2)
1191 	{
1192 		if (tl[0] < 1 || tl[0] > NUMERIC_MAX_PRECISION)
1193 			ereport(ERROR,
1194 					(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
1195 					 errmsg("NUMERIC precision %d must be between 1 and %d",
1196 							tl[0], NUMERIC_MAX_PRECISION)));
1197 		if (tl[1] < 0 || tl[1] > tl[0])
1198 			ereport(ERROR,
1199 					(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
1200 					 errmsg("NUMERIC scale %d must be between 0 and precision %d",
1201 							tl[1], tl[0])));
1202 		typmod = ((tl[0] << 16) | tl[1]) + VARHDRSZ;
1203 	}
1204 	else if (n == 1)
1205 	{
1206 		if (tl[0] < 1 || tl[0] > NUMERIC_MAX_PRECISION)
1207 			ereport(ERROR,
1208 					(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
1209 					 errmsg("NUMERIC precision %d must be between 1 and %d",
1210 							tl[0], NUMERIC_MAX_PRECISION)));
1211 		/* scale defaults to zero */
1212 		typmod = (tl[0] << 16) + VARHDRSZ;
1213 	}
1214 	else
1215 	{
1216 		ereport(ERROR,
1217 				(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
1218 				 errmsg("invalid NUMERIC type modifier")));
1219 		typmod = 0;				/* keep compiler quiet */
1220 	}
1221 
1222 	PG_RETURN_INT32(typmod);
1223 }
1224 
1225 Datum
numerictypmodout(PG_FUNCTION_ARGS)1226 numerictypmodout(PG_FUNCTION_ARGS)
1227 {
1228 	int32		typmod = PG_GETARG_INT32(0);
1229 	char	   *res = (char *) palloc(64);
1230 
1231 	if (typmod >= 0)
1232 		snprintf(res, 64, "(%d,%d)",
1233 				 ((typmod - VARHDRSZ) >> 16) & 0xffff,
1234 				 (typmod - VARHDRSZ) & 0xffff);
1235 	else
1236 		*res = '\0';
1237 
1238 	PG_RETURN_CSTRING(res);
1239 }
1240 
1241 
1242 /* ----------------------------------------------------------------------
1243  *
1244  * Sign manipulation, rounding and the like
1245  *
1246  * ----------------------------------------------------------------------
1247  */
1248 
1249 Datum
numeric_abs(PG_FUNCTION_ARGS)1250 numeric_abs(PG_FUNCTION_ARGS)
1251 {
1252 	Numeric		num = PG_GETARG_NUMERIC(0);
1253 	Numeric		res;
1254 
1255 	/*
1256 	 * Do it the easy way directly on the packed format
1257 	 */
1258 	res = duplicate_numeric(num);
1259 
1260 	if (NUMERIC_IS_SHORT(num))
1261 		res->choice.n_short.n_header =
1262 			num->choice.n_short.n_header & ~NUMERIC_SHORT_SIGN_MASK;
1263 	else if (NUMERIC_IS_SPECIAL(num))
1264 	{
1265 		/* This changes -Inf to Inf, and doesn't affect NaN */
1266 		res->choice.n_short.n_header =
1267 			num->choice.n_short.n_header & ~NUMERIC_INF_SIGN_MASK;
1268 	}
1269 	else
1270 		res->choice.n_long.n_sign_dscale = NUMERIC_POS | NUMERIC_DSCALE(num);
1271 
1272 	PG_RETURN_NUMERIC(res);
1273 }
1274 
1275 
1276 Datum
numeric_uminus(PG_FUNCTION_ARGS)1277 numeric_uminus(PG_FUNCTION_ARGS)
1278 {
1279 	Numeric		num = PG_GETARG_NUMERIC(0);
1280 	Numeric		res;
1281 
1282 	/*
1283 	 * Do it the easy way directly on the packed format
1284 	 */
1285 	res = duplicate_numeric(num);
1286 
1287 	if (NUMERIC_IS_SPECIAL(num))
1288 	{
1289 		/* Flip the sign, if it's Inf or -Inf */
1290 		if (!NUMERIC_IS_NAN(num))
1291 			res->choice.n_short.n_header =
1292 				num->choice.n_short.n_header ^ NUMERIC_INF_SIGN_MASK;
1293 	}
1294 
1295 	/*
1296 	 * The packed format is known to be totally zero digit trimmed always. So
1297 	 * once we've eliminated specials, we can identify a zero by the fact that
1298 	 * there are no digits at all. Do nothing to a zero.
1299 	 */
1300 	else if (NUMERIC_NDIGITS(num) != 0)
1301 	{
1302 		/* Else, flip the sign */
1303 		if (NUMERIC_IS_SHORT(num))
1304 			res->choice.n_short.n_header =
1305 				num->choice.n_short.n_header ^ NUMERIC_SHORT_SIGN_MASK;
1306 		else if (NUMERIC_SIGN(num) == NUMERIC_POS)
1307 			res->choice.n_long.n_sign_dscale =
1308 				NUMERIC_NEG | NUMERIC_DSCALE(num);
1309 		else
1310 			res->choice.n_long.n_sign_dscale =
1311 				NUMERIC_POS | NUMERIC_DSCALE(num);
1312 	}
1313 
1314 	PG_RETURN_NUMERIC(res);
1315 }
1316 
1317 
1318 Datum
numeric_uplus(PG_FUNCTION_ARGS)1319 numeric_uplus(PG_FUNCTION_ARGS)
1320 {
1321 	Numeric		num = PG_GETARG_NUMERIC(0);
1322 
1323 	PG_RETURN_NUMERIC(duplicate_numeric(num));
1324 }
1325 
1326 
1327 /*
1328  * numeric_sign_internal() -
1329  *
1330  * Returns -1 if the argument is less than 0, 0 if the argument is equal
1331  * to 0, and 1 if the argument is greater than zero.  Caller must have
1332  * taken care of the NaN case, but we can handle infinities here.
1333  */
1334 static int
numeric_sign_internal(Numeric num)1335 numeric_sign_internal(Numeric num)
1336 {
1337 	if (NUMERIC_IS_SPECIAL(num))
1338 	{
1339 		Assert(!NUMERIC_IS_NAN(num));
1340 		/* Must be Inf or -Inf */
1341 		if (NUMERIC_IS_PINF(num))
1342 			return 1;
1343 		else
1344 			return -1;
1345 	}
1346 
1347 	/*
1348 	 * The packed format is known to be totally zero digit trimmed always. So
1349 	 * once we've eliminated specials, we can identify a zero by the fact that
1350 	 * there are no digits at all.
1351 	 */
1352 	else if (NUMERIC_NDIGITS(num) == 0)
1353 		return 0;
1354 	else if (NUMERIC_SIGN(num) == NUMERIC_NEG)
1355 		return -1;
1356 	else
1357 		return 1;
1358 }
1359 
1360 /*
1361  * numeric_sign() -
1362  *
1363  * returns -1 if the argument is less than 0, 0 if the argument is equal
1364  * to 0, and 1 if the argument is greater than zero.
1365  */
1366 Datum
numeric_sign(PG_FUNCTION_ARGS)1367 numeric_sign(PG_FUNCTION_ARGS)
1368 {
1369 	Numeric		num = PG_GETARG_NUMERIC(0);
1370 
1371 	/*
1372 	 * Handle NaN (infinities can be handled normally)
1373 	 */
1374 	if (NUMERIC_IS_NAN(num))
1375 		PG_RETURN_NUMERIC(make_result(&const_nan));
1376 
1377 	switch (numeric_sign_internal(num))
1378 	{
1379 		case 0:
1380 			PG_RETURN_NUMERIC(make_result(&const_zero));
1381 		case 1:
1382 			PG_RETURN_NUMERIC(make_result(&const_one));
1383 		case -1:
1384 			PG_RETURN_NUMERIC(make_result(&const_minus_one));
1385 	}
1386 
1387 	Assert(false);
1388 	return (Datum) 0;
1389 }
1390 
1391 
1392 /*
1393  * numeric_round() -
1394  *
1395  *	Round a value to have 'scale' digits after the decimal point.
1396  *	We allow negative 'scale', implying rounding before the decimal
1397  *	point --- Oracle interprets rounding that way.
1398  */
1399 Datum
numeric_round(PG_FUNCTION_ARGS)1400 numeric_round(PG_FUNCTION_ARGS)
1401 {
1402 	Numeric		num = PG_GETARG_NUMERIC(0);
1403 	int32		scale = PG_GETARG_INT32(1);
1404 	Numeric		res;
1405 	NumericVar	arg;
1406 
1407 	/*
1408 	 * Handle NaN and infinities
1409 	 */
1410 	if (NUMERIC_IS_SPECIAL(num))
1411 		PG_RETURN_NUMERIC(duplicate_numeric(num));
1412 
1413 	/*
1414 	 * Limit the scale value to avoid possible overflow in calculations
1415 	 */
1416 	scale = Max(scale, -NUMERIC_MAX_RESULT_SCALE);
1417 	scale = Min(scale, NUMERIC_MAX_RESULT_SCALE);
1418 
1419 	/*
1420 	 * Unpack the argument and round it at the proper digit position
1421 	 */
1422 	init_var(&arg);
1423 	set_var_from_num(num, &arg);
1424 
1425 	round_var(&arg, scale);
1426 
1427 	/* We don't allow negative output dscale */
1428 	if (scale < 0)
1429 		arg.dscale = 0;
1430 
1431 	/*
1432 	 * Return the rounded result
1433 	 */
1434 	res = make_result(&arg);
1435 
1436 	free_var(&arg);
1437 	PG_RETURN_NUMERIC(res);
1438 }
1439 
1440 
1441 /*
1442  * numeric_trunc() -
1443  *
1444  *	Truncate a value to have 'scale' digits after the decimal point.
1445  *	We allow negative 'scale', implying a truncation before the decimal
1446  *	point --- Oracle interprets truncation that way.
1447  */
1448 Datum
numeric_trunc(PG_FUNCTION_ARGS)1449 numeric_trunc(PG_FUNCTION_ARGS)
1450 {
1451 	Numeric		num = PG_GETARG_NUMERIC(0);
1452 	int32		scale = PG_GETARG_INT32(1);
1453 	Numeric		res;
1454 	NumericVar	arg;
1455 
1456 	/*
1457 	 * Handle NaN and infinities
1458 	 */
1459 	if (NUMERIC_IS_SPECIAL(num))
1460 		PG_RETURN_NUMERIC(duplicate_numeric(num));
1461 
1462 	/*
1463 	 * Limit the scale value to avoid possible overflow in calculations
1464 	 */
1465 	scale = Max(scale, -NUMERIC_MAX_RESULT_SCALE);
1466 	scale = Min(scale, NUMERIC_MAX_RESULT_SCALE);
1467 
1468 	/*
1469 	 * Unpack the argument and truncate it at the proper digit position
1470 	 */
1471 	init_var(&arg);
1472 	set_var_from_num(num, &arg);
1473 
1474 	trunc_var(&arg, scale);
1475 
1476 	/* We don't allow negative output dscale */
1477 	if (scale < 0)
1478 		arg.dscale = 0;
1479 
1480 	/*
1481 	 * Return the truncated result
1482 	 */
1483 	res = make_result(&arg);
1484 
1485 	free_var(&arg);
1486 	PG_RETURN_NUMERIC(res);
1487 }
1488 
1489 
1490 /*
1491  * numeric_ceil() -
1492  *
1493  *	Return the smallest integer greater than or equal to the argument
1494  */
1495 Datum
numeric_ceil(PG_FUNCTION_ARGS)1496 numeric_ceil(PG_FUNCTION_ARGS)
1497 {
1498 	Numeric		num = PG_GETARG_NUMERIC(0);
1499 	Numeric		res;
1500 	NumericVar	result;
1501 
1502 	/*
1503 	 * Handle NaN and infinities
1504 	 */
1505 	if (NUMERIC_IS_SPECIAL(num))
1506 		PG_RETURN_NUMERIC(duplicate_numeric(num));
1507 
1508 	init_var_from_num(num, &result);
1509 	ceil_var(&result, &result);
1510 
1511 	res = make_result(&result);
1512 	free_var(&result);
1513 
1514 	PG_RETURN_NUMERIC(res);
1515 }
1516 
1517 
1518 /*
1519  * numeric_floor() -
1520  *
1521  *	Return the largest integer equal to or less than the argument
1522  */
1523 Datum
numeric_floor(PG_FUNCTION_ARGS)1524 numeric_floor(PG_FUNCTION_ARGS)
1525 {
1526 	Numeric		num = PG_GETARG_NUMERIC(0);
1527 	Numeric		res;
1528 	NumericVar	result;
1529 
1530 	/*
1531 	 * Handle NaN and infinities
1532 	 */
1533 	if (NUMERIC_IS_SPECIAL(num))
1534 		PG_RETURN_NUMERIC(duplicate_numeric(num));
1535 
1536 	init_var_from_num(num, &result);
1537 	floor_var(&result, &result);
1538 
1539 	res = make_result(&result);
1540 	free_var(&result);
1541 
1542 	PG_RETURN_NUMERIC(res);
1543 }
1544 
1545 
1546 /*
1547  * generate_series_numeric() -
1548  *
1549  *	Generate series of numeric.
1550  */
1551 Datum
generate_series_numeric(PG_FUNCTION_ARGS)1552 generate_series_numeric(PG_FUNCTION_ARGS)
1553 {
1554 	return generate_series_step_numeric(fcinfo);
1555 }
1556 
/*
 * generate_series_step_numeric() -
 *
 *	Set-returning function that emits start, start + step, start + 2*step,
 *	... for as long as the current value has not passed stop in the
 *	direction of step.  With only two arguments the step is one.  start,
 *	stop and step must all be finite, and step must be nonzero.
 */
Datum
generate_series_step_numeric(PG_FUNCTION_ARGS)
{
	generate_series_numeric_fctx *fctx;
	FuncCallContext *funcctx;
	MemoryContext oldcontext;

	if (SRF_IS_FIRSTCALL())
	{
		Numeric		start_num = PG_GETARG_NUMERIC(0);
		Numeric		stop_num = PG_GETARG_NUMERIC(1);
		NumericVar	steploc = const_one;

		/* Reject NaN and infinities in start and stop values */
		if (NUMERIC_IS_SPECIAL(start_num))
		{
			if (NUMERIC_IS_NAN(start_num))
				ereport(ERROR,
						(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
						 errmsg("start value cannot be NaN")));
			else
				ereport(ERROR,
						(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
						 errmsg("start value cannot be infinity")));
		}
		if (NUMERIC_IS_SPECIAL(stop_num))
		{
			if (NUMERIC_IS_NAN(stop_num))
				ereport(ERROR,
						(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
						 errmsg("stop value cannot be NaN")));
			else
				ereport(ERROR,
						(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
						 errmsg("stop value cannot be infinity")));
		}

		/* see if we were given an explicit step size */
		if (PG_NARGS() == 3)
		{
			Numeric		step_num = PG_GETARG_NUMERIC(2);

			if (NUMERIC_IS_SPECIAL(step_num))
			{
				if (NUMERIC_IS_NAN(step_num))
					ereport(ERROR,
							(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
							 errmsg("step size cannot be NaN")));
				else
					ereport(ERROR,
							(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
							 errmsg("step size cannot be infinity")));
			}

			/*
			 * steploc only points into step_num here; it is deep-copied
			 * into fctx->step below.
			 */
			init_var_from_num(step_num, &steploc);

			/* A zero step would never reach stop */
			if (cmp_var(&steploc, &const_zero) == 0)
				ereport(ERROR,
						(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
						 errmsg("step size cannot equal zero")));
		}

		/* create a function context for cross-call persistence */
		funcctx = SRF_FIRSTCALL_INIT();

		/*
		 * Switch to memory context appropriate for multiple function calls.
		 */
		oldcontext = MemoryContextSwitchTo(funcctx->multi_call_memory_ctx);

		/* allocate memory for user context */
		fctx = (generate_series_numeric_fctx *)
			palloc(sizeof(generate_series_numeric_fctx));

		/*
		 * Use fctx to keep state from call to call. Seed current with the
		 * original start value. We must copy the start_num and stop_num
		 * values rather than pointing to them, since we may have detoasted
		 * them in the per-call context.
		 */
		init_var(&fctx->current);
		init_var(&fctx->stop);
		init_var(&fctx->step);

		set_var_from_num(start_num, &fctx->current);
		set_var_from_num(stop_num, &fctx->stop);
		set_var_from_var(&steploc, &fctx->step);

		funcctx->user_fctx = fctx;
		MemoryContextSwitchTo(oldcontext);
	}

	/* stuff done on every call of the function */
	funcctx = SRF_PERCALL_SETUP();

	/*
	 * Get the saved state and use current state as the result of this
	 * iteration.
	 */
	fctx = funcctx->user_fctx;

	/*
	 * Keep going while current has not passed stop: for a positive step
	 * that means current <= stop, for a negative step current >= stop.
	 */
	if ((fctx->step.sign == NUMERIC_POS &&
		 cmp_var(&fctx->current, &fctx->stop) <= 0) ||
		(fctx->step.sign == NUMERIC_NEG &&
		 cmp_var(&fctx->current, &fctx->stop) >= 0))
	{
		/* The returned Numeric is built in the caller's (per-call) context */
		Numeric		result = make_result(&fctx->current);

		/* switch to memory context appropriate for iteration calculation */
		oldcontext = MemoryContextSwitchTo(funcctx->multi_call_memory_ctx);

		/* increment current in preparation for next iteration */
		add_var(&fctx->current, &fctx->step, &fctx->current);
		MemoryContextSwitchTo(oldcontext);

		/* do when there is more left to send */
		SRF_RETURN_NEXT(funcctx, NumericGetDatum(result));
	}
	else
		/* do when there is no more left */
		SRF_RETURN_DONE(funcctx);
}
1679 
1680 
1681 /*
1682  * Implements the numeric version of the width_bucket() function
1683  * defined by SQL2003. See also width_bucket_float8().
1684  *
1685  * 'bound1' and 'bound2' are the lower and upper bounds of the
1686  * histogram's range, respectively. 'count' is the number of buckets
1687  * in the histogram. width_bucket() returns an integer indicating the
1688  * bucket number that 'operand' belongs to in an equiwidth histogram
1689  * with the specified characteristics. An operand smaller than the
1690  * lower bound is assigned to bucket 0. An operand greater than the
1691  * upper bound is assigned to an additional bucket (with number
1692  * count+1). We don't allow "NaN" for any of the numeric arguments.
1693  */
1694 Datum
width_bucket_numeric(PG_FUNCTION_ARGS)1695 width_bucket_numeric(PG_FUNCTION_ARGS)
1696 {
1697 	Numeric		operand = PG_GETARG_NUMERIC(0);
1698 	Numeric		bound1 = PG_GETARG_NUMERIC(1);
1699 	Numeric		bound2 = PG_GETARG_NUMERIC(2);
1700 	int32		count = PG_GETARG_INT32(3);
1701 	NumericVar	count_var;
1702 	NumericVar	result_var;
1703 	int32		result;
1704 
1705 	if (count <= 0)
1706 		ereport(ERROR,
1707 				(errcode(ERRCODE_INVALID_ARGUMENT_FOR_WIDTH_BUCKET_FUNCTION),
1708 				 errmsg("count must be greater than zero")));
1709 
1710 	if (NUMERIC_IS_SPECIAL(operand) ||
1711 		NUMERIC_IS_SPECIAL(bound1) ||
1712 		NUMERIC_IS_SPECIAL(bound2))
1713 	{
1714 		if (NUMERIC_IS_NAN(operand) ||
1715 			NUMERIC_IS_NAN(bound1) ||
1716 			NUMERIC_IS_NAN(bound2))
1717 			ereport(ERROR,
1718 					(errcode(ERRCODE_INVALID_ARGUMENT_FOR_WIDTH_BUCKET_FUNCTION),
1719 					 errmsg("operand, lower bound, and upper bound cannot be NaN")));
1720 		/* We allow "operand" to be infinite; cmp_numerics will cope */
1721 		if (NUMERIC_IS_INF(bound1) || NUMERIC_IS_INF(bound2))
1722 			ereport(ERROR,
1723 					(errcode(ERRCODE_INVALID_ARGUMENT_FOR_WIDTH_BUCKET_FUNCTION),
1724 					 errmsg("lower and upper bounds must be finite")));
1725 	}
1726 
1727 	init_var(&result_var);
1728 	init_var(&count_var);
1729 
1730 	/* Convert 'count' to a numeric, for ease of use later */
1731 	int64_to_numericvar((int64) count, &count_var);
1732 
1733 	switch (cmp_numerics(bound1, bound2))
1734 	{
1735 		case 0:
1736 			ereport(ERROR,
1737 					(errcode(ERRCODE_INVALID_ARGUMENT_FOR_WIDTH_BUCKET_FUNCTION),
1738 					 errmsg("lower bound cannot equal upper bound")));
1739 			break;
1740 
1741 			/* bound1 < bound2 */
1742 		case -1:
1743 			if (cmp_numerics(operand, bound1) < 0)
1744 				set_var_from_var(&const_zero, &result_var);
1745 			else if (cmp_numerics(operand, bound2) >= 0)
1746 				add_var(&count_var, &const_one, &result_var);
1747 			else
1748 				compute_bucket(operand, bound1, bound2, &count_var, false,
1749 							   &result_var);
1750 			break;
1751 
1752 			/* bound1 > bound2 */
1753 		case 1:
1754 			if (cmp_numerics(operand, bound1) > 0)
1755 				set_var_from_var(&const_zero, &result_var);
1756 			else if (cmp_numerics(operand, bound2) <= 0)
1757 				add_var(&count_var, &const_one, &result_var);
1758 			else
1759 				compute_bucket(operand, bound1, bound2, &count_var, true,
1760 							   &result_var);
1761 			break;
1762 	}
1763 
1764 	/* if result exceeds the range of a legal int4, we ereport here */
1765 	if (!numericvar_to_int32(&result_var, &result))
1766 		ereport(ERROR,
1767 				(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
1768 				 errmsg("integer out of range")));
1769 
1770 	free_var(&count_var);
1771 	free_var(&result_var);
1772 
1773 	PG_RETURN_INT32(result);
1774 }
1775 
1776 /*
1777  * If 'operand' is not outside the bucket range, determine the correct
1778  * bucket for it to go. The calculations performed by this function
1779  * are derived directly from the SQL2003 spec. Note however that we
1780  * multiply by count before dividing, to avoid unnecessary roundoff error.
1781  */
1782 static void
compute_bucket(Numeric operand,Numeric bound1,Numeric bound2,const NumericVar * count_var,bool reversed_bounds,NumericVar * result_var)1783 compute_bucket(Numeric operand, Numeric bound1, Numeric bound2,
1784 			   const NumericVar *count_var, bool reversed_bounds,
1785 			   NumericVar *result_var)
1786 {
1787 	NumericVar	bound1_var;
1788 	NumericVar	bound2_var;
1789 	NumericVar	operand_var;
1790 
1791 	init_var_from_num(bound1, &bound1_var);
1792 	init_var_from_num(bound2, &bound2_var);
1793 	init_var_from_num(operand, &operand_var);
1794 
1795 	if (!reversed_bounds)
1796 	{
1797 		sub_var(&operand_var, &bound1_var, &operand_var);
1798 		sub_var(&bound2_var, &bound1_var, &bound2_var);
1799 	}
1800 	else
1801 	{
1802 		sub_var(&bound1_var, &operand_var, &operand_var);
1803 		sub_var(&bound1_var, &bound2_var, &bound2_var);
1804 	}
1805 
1806 	mul_var(&operand_var, count_var, &operand_var,
1807 			operand_var.dscale + count_var->dscale);
1808 	div_var(&operand_var, &bound2_var, result_var,
1809 			select_div_scale(&operand_var, &bound2_var), true);
1810 	add_var(result_var, &const_one, result_var);
1811 	floor_var(result_var, result_var);
1812 
1813 	free_var(&bound1_var);
1814 	free_var(&bound2_var);
1815 	free_var(&operand_var);
1816 }
1817 
1818 /* ----------------------------------------------------------------------
1819  *
1820  * Comparison functions
1821  *
1822  * Note: btree indexes need these routines not to leak memory; therefore,
1823  * be careful to free working copies of toasted datums.  Most places don't
1824  * need to be so careful.
1825  *
1826  * Sort support:
1827  *
1828  * We implement the sortsupport strategy routine in order to get the benefit of
1829  * abbreviation. The ordinary numeric comparison can be quite slow as a result
1830  * of palloc/pfree cycles (due to detoasting packed values for alignment);
1831  * while this could be worked on itself, the abbreviation strategy gives more
1832  * speedup in many common cases.
1833  *
1834  * Two different representations are used for the abbreviated form, one in
1835  * int32 and one in int64, whichever fits into a by-value Datum.  In both cases
1836  * the representation is negated relative to the original value, because we use
1837  * the largest negative value for NaN, which sorts higher than other values. We
1838  * convert the absolute value of the numeric to a 31-bit or 63-bit positive
1839  * value, and then negate it if the original number was positive.
1840  *
1841  * We abort the abbreviation process if the abbreviation cardinality is below
1842  * 0.01% of the row count (1 per 10k non-null rows).  The actual break-even
1843  * point is somewhat below that, perhaps 1 per 30k (at 1 per 100k there's a
1844  * very small penalty), but we don't want to build up too many abbreviated
1845  * values before first testing for abort, so we take the slightly pessimistic
1846  * number.  We make no attempt to estimate the cardinality of the real values,
1847  * since it plays no part in the cost model here (if the abbreviation is equal,
1848  * the cost of comparing equal and unequal underlying values is comparable).
1849  * We discontinue even checking for abort (saving us the hashing overhead) if
1850  * the estimated cardinality gets to 100k; that would be enough to support many
1851  * billions of rows while doing no worse than breaking even.
1852  *
1853  * ----------------------------------------------------------------------
1854  */
1855 
1856 /*
1857  * Sort support strategy routine.
1858  */
1859 Datum
numeric_sortsupport(PG_FUNCTION_ARGS)1860 numeric_sortsupport(PG_FUNCTION_ARGS)
1861 {
1862 	SortSupport ssup = (SortSupport) PG_GETARG_POINTER(0);
1863 
1864 	ssup->comparator = numeric_fast_cmp;
1865 
1866 	if (ssup->abbreviate)
1867 	{
1868 		NumericSortSupport *nss;
1869 		MemoryContext oldcontext = MemoryContextSwitchTo(ssup->ssup_cxt);
1870 
1871 		nss = palloc(sizeof(NumericSortSupport));
1872 
1873 		/*
1874 		 * palloc a buffer for handling unaligned packed values in addition to
1875 		 * the support struct
1876 		 */
1877 		nss->buf = palloc(VARATT_SHORT_MAX + VARHDRSZ + 1);
1878 
1879 		nss->input_count = 0;
1880 		nss->estimating = true;
1881 		initHyperLogLog(&nss->abbr_card, 10);
1882 
1883 		ssup->ssup_extra = nss;
1884 
1885 		ssup->abbrev_full_comparator = ssup->comparator;
1886 		ssup->comparator = numeric_cmp_abbrev;
1887 		ssup->abbrev_converter = numeric_abbrev_convert;
1888 		ssup->abbrev_abort = numeric_abbrev_abort;
1889 
1890 		MemoryContextSwitchTo(oldcontext);
1891 	}
1892 
1893 	PG_RETURN_VOID();
1894 }
1895 
1896 /*
1897  * Abbreviate a numeric datum, handling NaNs and detoasting
1898  * (must not leak memory!)
1899  */
1900 static Datum
numeric_abbrev_convert(Datum original_datum,SortSupport ssup)1901 numeric_abbrev_convert(Datum original_datum, SortSupport ssup)
1902 {
1903 	NumericSortSupport *nss = ssup->ssup_extra;
1904 	void	   *original_varatt = PG_DETOAST_DATUM_PACKED(original_datum);
1905 	Numeric		value;
1906 	Datum		result;
1907 
1908 	nss->input_count += 1;
1909 
1910 	/*
1911 	 * This is to handle packed datums without needing a palloc/pfree cycle;
1912 	 * we keep and reuse a buffer large enough to handle any short datum.
1913 	 */
1914 	if (VARATT_IS_SHORT(original_varatt))
1915 	{
1916 		void	   *buf = nss->buf;
1917 		Size		sz = VARSIZE_SHORT(original_varatt) - VARHDRSZ_SHORT;
1918 
1919 		Assert(sz <= VARATT_SHORT_MAX - VARHDRSZ_SHORT);
1920 
1921 		SET_VARSIZE(buf, VARHDRSZ + sz);
1922 		memcpy(VARDATA(buf), VARDATA_SHORT(original_varatt), sz);
1923 
1924 		value = (Numeric) buf;
1925 	}
1926 	else
1927 		value = (Numeric) original_varatt;
1928 
1929 	if (NUMERIC_IS_SPECIAL(value))
1930 	{
1931 		if (NUMERIC_IS_PINF(value))
1932 			result = NUMERIC_ABBREV_PINF;
1933 		else if (NUMERIC_IS_NINF(value))
1934 			result = NUMERIC_ABBREV_NINF;
1935 		else
1936 			result = NUMERIC_ABBREV_NAN;
1937 	}
1938 	else
1939 	{
1940 		NumericVar	var;
1941 
1942 		init_var_from_num(value, &var);
1943 
1944 		result = numeric_abbrev_convert_var(&var, nss);
1945 	}
1946 
1947 	/* should happen only for external/compressed toasts */
1948 	if ((Pointer) original_varatt != DatumGetPointer(original_datum))
1949 		pfree(original_varatt);
1950 
1951 	return result;
1952 }
1953 
1954 /*
1955  * Consider whether to abort abbreviation.
1956  *
1957  * We pay no attention to the cardinality of the non-abbreviated data. There is
1958  * no reason to do so: unlike text, we have no fast check for equal values, so
1959  * we pay the full overhead whenever the abbreviations are equal regardless of
1960  * whether the underlying values are also equal.
1961  */
1962 static bool
numeric_abbrev_abort(int memtupcount,SortSupport ssup)1963 numeric_abbrev_abort(int memtupcount, SortSupport ssup)
1964 {
1965 	NumericSortSupport *nss = ssup->ssup_extra;
1966 	double		abbr_card;
1967 
1968 	if (memtupcount < 10000 || nss->input_count < 10000 || !nss->estimating)
1969 		return false;
1970 
1971 	abbr_card = estimateHyperLogLog(&nss->abbr_card);
1972 
1973 	/*
1974 	 * If we have >100k distinct values, then even if we were sorting many
1975 	 * billion rows we'd likely still break even, and the penalty of undoing
1976 	 * that many rows of abbrevs would probably not be worth it. Stop even
1977 	 * counting at that point.
1978 	 */
1979 	if (abbr_card > 100000.0)
1980 	{
1981 #ifdef TRACE_SORT
1982 		if (trace_sort)
1983 			elog(LOG,
1984 				 "numeric_abbrev: estimation ends at cardinality %f"
1985 				 " after " INT64_FORMAT " values (%d rows)",
1986 				 abbr_card, nss->input_count, memtupcount);
1987 #endif
1988 		nss->estimating = false;
1989 		return false;
1990 	}
1991 
1992 	/*
1993 	 * Target minimum cardinality is 1 per ~10k of non-null inputs.  (The
1994 	 * break even point is somewhere between one per 100k rows, where
1995 	 * abbreviation has a very slight penalty, and 1 per 10k where it wins by
1996 	 * a measurable percentage.)  We use the relatively pessimistic 10k
1997 	 * threshold, and add a 0.5 row fudge factor, because it allows us to
1998 	 * abort earlier on genuinely pathological data where we've had exactly
1999 	 * one abbreviated value in the first 10k (non-null) rows.
2000 	 */
2001 	if (abbr_card < nss->input_count / 10000.0 + 0.5)
2002 	{
2003 #ifdef TRACE_SORT
2004 		if (trace_sort)
2005 			elog(LOG,
2006 				 "numeric_abbrev: aborting abbreviation at cardinality %f"
2007 				 " below threshold %f after " INT64_FORMAT " values (%d rows)",
2008 				 abbr_card, nss->input_count / 10000.0 + 0.5,
2009 				 nss->input_count, memtupcount);
2010 #endif
2011 		return true;
2012 	}
2013 
2014 #ifdef TRACE_SORT
2015 	if (trace_sort)
2016 		elog(LOG,
2017 			 "numeric_abbrev: cardinality %f"
2018 			 " after " INT64_FORMAT " values (%d rows)",
2019 			 abbr_card, nss->input_count, memtupcount);
2020 #endif
2021 
2022 	return false;
2023 }
2024 
2025 /*
2026  * Non-fmgr interface to the comparison routine to allow sortsupport to elide
2027  * the fmgr call.  The saving here is small given how slow numeric comparisons
2028  * are, but it is a required part of the sort support API when abbreviations
2029  * are performed.
2030  *
2031  * Two palloc/pfree cycles could be saved here by using persistent buffers for
2032  * aligning short-varlena inputs, but this has not so far been considered to
2033  * be worth the effort.
2034  */
2035 static int
numeric_fast_cmp(Datum x,Datum y,SortSupport ssup)2036 numeric_fast_cmp(Datum x, Datum y, SortSupport ssup)
2037 {
2038 	Numeric		nx = DatumGetNumeric(x);
2039 	Numeric		ny = DatumGetNumeric(y);
2040 	int			result;
2041 
2042 	result = cmp_numerics(nx, ny);
2043 
2044 	if ((Pointer) nx != DatumGetPointer(x))
2045 		pfree(nx);
2046 	if ((Pointer) ny != DatumGetPointer(y))
2047 		pfree(ny);
2048 
2049 	return result;
2050 }
2051 
2052 /*
2053  * Compare abbreviations of values. (Abbreviations may be equal where the true
2054  * values differ, but if the abbreviations differ, they must reflect the
2055  * ordering of the true values.)
2056  */
2057 static int
numeric_cmp_abbrev(Datum x,Datum y,SortSupport ssup)2058 numeric_cmp_abbrev(Datum x, Datum y, SortSupport ssup)
2059 {
2060 	/*
2061 	 * NOTE WELL: this is intentionally backwards, because the abbreviation is
2062 	 * negated relative to the original value, to handle NaN/infinity cases.
2063 	 */
2064 	if (DatumGetNumericAbbrev(x) < DatumGetNumericAbbrev(y))
2065 		return 1;
2066 	if (DatumGetNumericAbbrev(x) > DatumGetNumericAbbrev(y))
2067 		return -1;
2068 	return 0;
2069 }
2070 
2071 /*
2072  * Abbreviate a NumericVar according to the available bit size.
2073  *
2074  * The 31-bit value is constructed as:
2075  *
2076  *	0 + 7bits digit weight + 24 bits digit value
2077  *
2078  * where the digit weight is in single decimal digits, not digit words, and
2079  * stored in excess-44 representation[1]. The 24-bit digit value is the 7 most
2080  * significant decimal digits of the value converted to binary. Values whose
2081  * weights would fall outside the representable range are rounded off to zero
2082  * (which is also used to represent actual zeros) or to 0x7FFFFFFF (which
2083  * otherwise cannot occur). Abbreviation therefore fails to gain any advantage
2084  * where values are outside the range 10^-44 to 10^83, which is not considered
2085  * to be a serious limitation, or when values are of the same magnitude and
2086  * equal in the first 7 decimal digits, which is considered to be an
2087  * unavoidable limitation given the available bits. (Stealing three more bits
2088  * to compare another digit would narrow the range of representable weights by
2089  * a factor of 8, which starts to look like a real limiting factor.)
2090  *
2091  * (The value 44 for the excess is essentially arbitrary)
2092  *
2093  * The 63-bit value is constructed as:
2094  *
2095  *	0 + 7bits weight + 4 x 14-bit packed digit words
2096  *
2097  * The weight in this case is again stored in excess-44, but this time it is
2098  * the original weight in digit words (i.e. powers of 10000). The first four
2099  * digit words of the value (if present; trailing zeros are assumed as needed)
2100  * are packed into 14 bits each to form the rest of the value. Again,
2101  * out-of-range values are rounded off to 0 or 0x7FFFFFFFFFFFFFFF. The
2102  * representable range in this case is 10^-176 to 10^332, which is considered
2103  * to be good enough for all practical purposes, and comparison of 4 words
2104  * means that at least 13 decimal digits are compared, which is considered to
2105  * be a reasonable compromise between effectiveness and efficiency in computing
2106  * the abbreviation.
2107  *
2108  * (The value 44 for the excess is even more arbitrary here, it was chosen just
2109  * to match the value used in the 31-bit case)
2110  *
2111  * [1] - Excess-k representation means that the value is offset by adding 'k'
2112  * and then treated as unsigned, so the smallest representable value is stored
2113  * with all bits zero. This allows simple comparisons to work on the composite
2114  * value.
2115  */
2116 
2117 #if NUMERIC_ABBREV_BITS == 64
2118 
/*
 * 64-bit abbreviation of a finite NumericVar.
 *
 * Layout (before negation): bit 63 clear, bits 56-62 hold weight + 44
 * (weight in NBASE digit words), and bits 0-55 hold the first four digit
 * words in 14-bit fields, most significant first.  Values too small to
 * represent (or zero) become 0; values too large saturate at PG_INT64_MAX.
 * For positive inputs the result is then negated, which is why
 * numeric_cmp_abbrev() compares in reverse.  While cardinality estimation
 * is active, the result is also hashed into nss->abbr_card.
 */
static Datum
numeric_abbrev_convert_var(const NumericVar *var, NumericSortSupport *nss)
{
	int			ndigits = var->ndigits;
	int			weight = var->weight;
	int64		result;

	if (ndigits == 0 || weight < -44)
	{
		/* zero, or too small for the 7-bit excess-44 weight: round to 0 */
		result = 0;
	}
	else if (weight > 83)
	{
		/* too large for the 7-bit weight field: saturate */
		result = PG_INT64_MAX;
	}
	else
	{
		/* weight + 44 is in 0..127, so it fits in bits 56-62 */
		result = ((int64) (weight + 44) << 56);

		/*
		 * Pack up to four leading digit words (each < NBASE, so 14 bits
		 * suffice) at bit offsets 42, 28, 14 and 0.  Missing trailing words
		 * stay zero.  Every case intentionally falls through to the next.
		 */
		switch (ndigits)
		{
			default:
				result |= ((int64) var->digits[3]);
				/* FALLTHROUGH */
			case 3:
				result |= ((int64) var->digits[2]) << 14;
				/* FALLTHROUGH */
			case 2:
				result |= ((int64) var->digits[1]) << 28;
				/* FALLTHROUGH */
			case 1:
				result |= ((int64) var->digits[0]) << 42;
				break;
		}
	}

	/* the abbrev is negated relative to the original */
	if (var->sign == NUMERIC_POS)
		result = -result;

	if (nss->estimating)
	{
		/* fold the two 32-bit halves together before hashing */
		uint32		tmp = ((uint32) result
						   ^ (uint32) ((uint64) result >> 32));

		addHyperLogLog(&nss->abbr_card, DatumGetUInt32(hash_uint32(tmp)));
	}

	return NumericAbbrevGetDatum(result);
}
2169 
2170 #endif							/* NUMERIC_ABBREV_BITS == 64 */
2171 
2172 #if NUMERIC_ABBREV_BITS == 32
2173 
/*
 * 32-bit abbreviation of a finite NumericVar.
 *
 * Layout (before negation): bit 31 clear, bits 24-30 hold the decimal
 * exponent of the most significant decimal digit in excess-44 form, and
 * bits 0-23 hold the first 7 significant decimal digits as a binary
 * integer (at most 9999999 < 2^24).  Values too small to represent (or
 * zero) become 0; values too large saturate at PG_INT32_MAX.  For positive
 * inputs the result is then negated, which is why numeric_cmp_abbrev()
 * compares in reverse.  While cardinality estimation is active, the result
 * is also hashed into nss->abbr_card.
 */
static Datum
numeric_abbrev_convert_var(const NumericVar *var, NumericSortSupport *nss)
{
	int			ndigits = var->ndigits;
	int			weight = var->weight;
	int32		result;

	if (ndigits == 0 || weight < -11)
	{
		/* zero, or below 10^-44: round to 0 */
		result = 0;
	}
	else if (weight > 20)
	{
		/* beyond the representable decimal exponent range: saturate */
		result = PG_INT32_MAX;
	}
	else
	{
		NumericDigit nxt1 = (ndigits > 1) ? var->digits[1] : 0;

		/*
		 * Convert the digit-word weight to excess-44 decimal digits.  The
		 * per-case adjustments below add (number of decimal digits in the
		 * first word - 1), giving a final value in 0..127.
		 */
		weight = (weight + 11) * 4;

		result = var->digits[0];

		/*
		 * "result" now has 1 to 4 nonzero decimal digits. We pack in more
		 * digits to make 7 in total (largest we can fit in 24 bits)
		 */

		if (result > 999)
		{
			/* already have 4 digits, add 3 more */
			result = (result * 1000) + (nxt1 / 10);
			weight += 3;
		}
		else if (result > 99)
		{
			/* already have 3 digits, add 4 more */
			result = (result * 10000) + nxt1;
			weight += 2;
		}
		else if (result > 9)
		{
			NumericDigit nxt2 = (ndigits > 2) ? var->digits[2] : 0;

			/* already have 2 digits, add 5 more */
			result = (result * 100000) + (nxt1 * 10) + (nxt2 / 1000);
			weight += 1;
		}
		else
		{
			NumericDigit nxt2 = (ndigits > 2) ? var->digits[2] : 0;

			/* already have 1 digit, add 6 more */
			result = (result * 1000000) + (nxt1 * 100) + (nxt2 / 100);
		}

		/* weight <= 127 here, so bit 31 (the sign bit) stays clear */
		result = result | (weight << 24);
	}

	/* the abbrev is negated relative to the original */
	if (var->sign == NUMERIC_POS)
		result = -result;

	if (nss->estimating)
	{
		uint32		tmp = (uint32) result;

		addHyperLogLog(&nss->abbr_card, DatumGetUInt32(hash_uint32(tmp)));
	}

	return NumericAbbrevGetDatum(result);
}
2246 
2247 #endif							/* NUMERIC_ABBREV_BITS == 32 */
2248 
2249 /*
2250  * Ordinary (non-sortsupport) comparisons follow.
2251  */
2252 
2253 Datum
numeric_cmp(PG_FUNCTION_ARGS)2254 numeric_cmp(PG_FUNCTION_ARGS)
2255 {
2256 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2257 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2258 	int			result;
2259 
2260 	result = cmp_numerics(num1, num2);
2261 
2262 	PG_FREE_IF_COPY(num1, 0);
2263 	PG_FREE_IF_COPY(num2, 1);
2264 
2265 	PG_RETURN_INT32(result);
2266 }
2267 
2268 
2269 Datum
numeric_eq(PG_FUNCTION_ARGS)2270 numeric_eq(PG_FUNCTION_ARGS)
2271 {
2272 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2273 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2274 	bool		result;
2275 
2276 	result = cmp_numerics(num1, num2) == 0;
2277 
2278 	PG_FREE_IF_COPY(num1, 0);
2279 	PG_FREE_IF_COPY(num2, 1);
2280 
2281 	PG_RETURN_BOOL(result);
2282 }
2283 
2284 Datum
numeric_ne(PG_FUNCTION_ARGS)2285 numeric_ne(PG_FUNCTION_ARGS)
2286 {
2287 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2288 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2289 	bool		result;
2290 
2291 	result = cmp_numerics(num1, num2) != 0;
2292 
2293 	PG_FREE_IF_COPY(num1, 0);
2294 	PG_FREE_IF_COPY(num2, 1);
2295 
2296 	PG_RETURN_BOOL(result);
2297 }
2298 
2299 Datum
numeric_gt(PG_FUNCTION_ARGS)2300 numeric_gt(PG_FUNCTION_ARGS)
2301 {
2302 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2303 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2304 	bool		result;
2305 
2306 	result = cmp_numerics(num1, num2) > 0;
2307 
2308 	PG_FREE_IF_COPY(num1, 0);
2309 	PG_FREE_IF_COPY(num2, 1);
2310 
2311 	PG_RETURN_BOOL(result);
2312 }
2313 
2314 Datum
numeric_ge(PG_FUNCTION_ARGS)2315 numeric_ge(PG_FUNCTION_ARGS)
2316 {
2317 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2318 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2319 	bool		result;
2320 
2321 	result = cmp_numerics(num1, num2) >= 0;
2322 
2323 	PG_FREE_IF_COPY(num1, 0);
2324 	PG_FREE_IF_COPY(num2, 1);
2325 
2326 	PG_RETURN_BOOL(result);
2327 }
2328 
2329 Datum
numeric_lt(PG_FUNCTION_ARGS)2330 numeric_lt(PG_FUNCTION_ARGS)
2331 {
2332 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2333 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2334 	bool		result;
2335 
2336 	result = cmp_numerics(num1, num2) < 0;
2337 
2338 	PG_FREE_IF_COPY(num1, 0);
2339 	PG_FREE_IF_COPY(num2, 1);
2340 
2341 	PG_RETURN_BOOL(result);
2342 }
2343 
2344 Datum
numeric_le(PG_FUNCTION_ARGS)2345 numeric_le(PG_FUNCTION_ARGS)
2346 {
2347 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2348 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2349 	bool		result;
2350 
2351 	result = cmp_numerics(num1, num2) <= 0;
2352 
2353 	PG_FREE_IF_COPY(num1, 0);
2354 	PG_FREE_IF_COPY(num2, 1);
2355 
2356 	PG_RETURN_BOOL(result);
2357 }
2358 
2359 static int
cmp_numerics(Numeric num1,Numeric num2)2360 cmp_numerics(Numeric num1, Numeric num2)
2361 {
2362 	int			result;
2363 
2364 	/*
2365 	 * We consider all NANs to be equal and larger than any non-NAN (including
2366 	 * Infinity).  This is somewhat arbitrary; the important thing is to have
2367 	 * a consistent sort order.
2368 	 */
2369 	if (NUMERIC_IS_SPECIAL(num1))
2370 	{
2371 		if (NUMERIC_IS_NAN(num1))
2372 		{
2373 			if (NUMERIC_IS_NAN(num2))
2374 				result = 0;		/* NAN = NAN */
2375 			else
2376 				result = 1;		/* NAN > non-NAN */
2377 		}
2378 		else if (NUMERIC_IS_PINF(num1))
2379 		{
2380 			if (NUMERIC_IS_NAN(num2))
2381 				result = -1;	/* PINF < NAN */
2382 			else if (NUMERIC_IS_PINF(num2))
2383 				result = 0;		/* PINF = PINF */
2384 			else
2385 				result = 1;		/* PINF > anything else */
2386 		}
2387 		else					/* num1 must be NINF */
2388 		{
2389 			if (NUMERIC_IS_NINF(num2))
2390 				result = 0;		/* NINF = NINF */
2391 			else
2392 				result = -1;	/* NINF < anything else */
2393 		}
2394 	}
2395 	else if (NUMERIC_IS_SPECIAL(num2))
2396 	{
2397 		if (NUMERIC_IS_NINF(num2))
2398 			result = 1;			/* normal > NINF */
2399 		else
2400 			result = -1;		/* normal < NAN or PINF */
2401 	}
2402 	else
2403 	{
2404 		result = cmp_var_common(NUMERIC_DIGITS(num1), NUMERIC_NDIGITS(num1),
2405 								NUMERIC_WEIGHT(num1), NUMERIC_SIGN(num1),
2406 								NUMERIC_DIGITS(num2), NUMERIC_NDIGITS(num2),
2407 								NUMERIC_WEIGHT(num2), NUMERIC_SIGN(num2));
2408 	}
2409 
2410 	return result;
2411 }
2412 
2413 /*
2414  * in_range support function for numeric.
2415  */
2416 Datum
in_range_numeric_numeric(PG_FUNCTION_ARGS)2417 in_range_numeric_numeric(PG_FUNCTION_ARGS)
2418 {
2419 	Numeric		val = PG_GETARG_NUMERIC(0);
2420 	Numeric		base = PG_GETARG_NUMERIC(1);
2421 	Numeric		offset = PG_GETARG_NUMERIC(2);
2422 	bool		sub = PG_GETARG_BOOL(3);
2423 	bool		less = PG_GETARG_BOOL(4);
2424 	bool		result;
2425 
2426 	/*
2427 	 * Reject negative (including -Inf) or NaN offset.  Negative is per spec,
2428 	 * and NaN is because appropriate semantics for that seem non-obvious.
2429 	 */
2430 	if (NUMERIC_IS_NAN(offset) ||
2431 		NUMERIC_IS_NINF(offset) ||
2432 		NUMERIC_SIGN(offset) == NUMERIC_NEG)
2433 		ereport(ERROR,
2434 				(errcode(ERRCODE_INVALID_PRECEDING_OR_FOLLOWING_SIZE),
2435 				 errmsg("invalid preceding or following size in window function")));
2436 
2437 	/*
2438 	 * Deal with cases where val and/or base is NaN, following the rule that
2439 	 * NaN sorts after non-NaN (cf cmp_numerics).  The offset cannot affect
2440 	 * the conclusion.
2441 	 */
2442 	if (NUMERIC_IS_NAN(val))
2443 	{
2444 		if (NUMERIC_IS_NAN(base))
2445 			result = true;		/* NAN = NAN */
2446 		else
2447 			result = !less;		/* NAN > non-NAN */
2448 	}
2449 	else if (NUMERIC_IS_NAN(base))
2450 	{
2451 		result = less;			/* non-NAN < NAN */
2452 	}
2453 
2454 	/*
2455 	 * Deal with infinite offset (necessarily +Inf, at this point).
2456 	 */
2457 	else if (NUMERIC_IS_SPECIAL(offset))
2458 	{
2459 		Assert(NUMERIC_IS_PINF(offset));
2460 		if (sub ? NUMERIC_IS_PINF(base) : NUMERIC_IS_NINF(base))
2461 		{
2462 			/*
2463 			 * base +/- offset would produce NaN, so return true for any val
2464 			 * (see in_range_float8_float8() for reasoning).
2465 			 */
2466 			result = true;
2467 		}
2468 		else if (sub)
2469 		{
2470 			/* base - offset must be -inf */
2471 			if (less)
2472 				result = NUMERIC_IS_NINF(val);	/* only -inf is <= sum */
2473 			else
2474 				result = true;	/* any val is >= sum */
2475 		}
2476 		else
2477 		{
2478 			/* base + offset must be +inf */
2479 			if (less)
2480 				result = true;	/* any val is <= sum */
2481 			else
2482 				result = NUMERIC_IS_PINF(val);	/* only +inf is >= sum */
2483 		}
2484 	}
2485 
2486 	/*
2487 	 * Deal with cases where val and/or base is infinite.  The offset, being
2488 	 * now known finite, cannot affect the conclusion.
2489 	 */
2490 	else if (NUMERIC_IS_SPECIAL(val))
2491 	{
2492 		if (NUMERIC_IS_PINF(val))
2493 		{
2494 			if (NUMERIC_IS_PINF(base))
2495 				result = true;	/* PINF = PINF */
2496 			else
2497 				result = !less; /* PINF > any other non-NAN */
2498 		}
2499 		else					/* val must be NINF */
2500 		{
2501 			if (NUMERIC_IS_NINF(base))
2502 				result = true;	/* NINF = NINF */
2503 			else
2504 				result = less;	/* NINF < anything else */
2505 		}
2506 	}
2507 	else if (NUMERIC_IS_SPECIAL(base))
2508 	{
2509 		if (NUMERIC_IS_NINF(base))
2510 			result = !less;		/* normal > NINF */
2511 		else
2512 			result = less;		/* normal < PINF */
2513 	}
2514 	else
2515 	{
2516 		/*
2517 		 * Otherwise go ahead and compute base +/- offset.  While it's
2518 		 * possible for this to overflow the numeric format, it's unlikely
2519 		 * enough that we don't take measures to prevent it.
2520 		 */
2521 		NumericVar	valv;
2522 		NumericVar	basev;
2523 		NumericVar	offsetv;
2524 		NumericVar	sum;
2525 
2526 		init_var_from_num(val, &valv);
2527 		init_var_from_num(base, &basev);
2528 		init_var_from_num(offset, &offsetv);
2529 		init_var(&sum);
2530 
2531 		if (sub)
2532 			sub_var(&basev, &offsetv, &sum);
2533 		else
2534 			add_var(&basev, &offsetv, &sum);
2535 
2536 		if (less)
2537 			result = (cmp_var(&valv, &sum) <= 0);
2538 		else
2539 			result = (cmp_var(&valv, &sum) >= 0);
2540 
2541 		free_var(&sum);
2542 	}
2543 
2544 	PG_FREE_IF_COPY(val, 0);
2545 	PG_FREE_IF_COPY(base, 1);
2546 	PG_FREE_IF_COPY(offset, 2);
2547 
2548 	PG_RETURN_BOOL(result);
2549 }
2550 
/*
 * hash_numeric() -
 *
 *	32-bit hash of a numeric that is consistent with numeric equality.
 *
 *	Values that compare equal but differ in display scale, or in redundant
 *	leading or trailing zero digit words, hash identically.  That is why
 *	only the significant digit words and the adjusted weight are hashed.
 *	Special cases:
 *	  - NaN and +/-Inf all hash to 0;
 *	  - zero (no nonzero digit words) hashes to (uint32) -1.
 *
 *	NOTE(review): presumably these hash values may be persisted (e.g. by hash
 *	indexes or hash partitioning), so the algorithm should not be changed
 *	without considering on-disk compatibility -- confirm before editing.
 */
Datum
hash_numeric(PG_FUNCTION_ARGS)
{
	Numeric		key = PG_GETARG_NUMERIC(0);
	Datum		digit_hash;
	Datum		result;
	int			weight;
	int			start_offset;
	int			end_offset;
	int			i;
	int			hash_len;
	NumericDigit *digits;

	/* If it's NaN or infinity, don't try to hash the rest of the fields */
	if (NUMERIC_IS_SPECIAL(key))
		PG_RETURN_UINT32(0);

	weight = NUMERIC_WEIGHT(key);
	start_offset = 0;
	end_offset = 0;

	/*
	 * Omit any leading or trailing zeros from the input to the hash. The
	 * numeric implementation *should* guarantee that leading and trailing
	 * zeros are suppressed, but we're paranoid. Note that we measure the
	 * starting and ending offsets in units of NumericDigits, not bytes.
	 */
	digits = NUMERIC_DIGITS(key);
	for (i = 0; i < NUMERIC_NDIGITS(key); i++)
	{
		if (digits[i] != (NumericDigit) 0)
			break;

		start_offset++;

		/*
		 * The weight is effectively the # of digits before the decimal point,
		 * so decrement it for each leading zero we skip.
		 */
		weight--;
	}

	/*
	 * If there are no non-zero digits, then the value of the number is zero,
	 * regardless of any other fields.
	 */
	if (NUMERIC_NDIGITS(key) == start_offset)
		PG_RETURN_UINT32(-1);

	/* Count trailing zero digit words; the loop stops at a nonzero word. */
	for (i = NUMERIC_NDIGITS(key) - 1; i >= 0; i--)
	{
		if (digits[i] != (NumericDigit) 0)
			break;

		end_offset++;
	}

	/* If we get here, there should be at least one non-zero digit */
	Assert(start_offset + end_offset < NUMERIC_NDIGITS(key));

	/*
	 * Note that we don't hash on the Numeric's scale, since two numerics can
	 * compare equal but have different scales. We also don't hash on the
	 * sign, although we could: since a sign difference implies inequality,
	 * this shouldn't affect correctness.
	 */
	hash_len = NUMERIC_NDIGITS(key) - start_offset - end_offset;
	digit_hash = hash_any((unsigned char *) (NUMERIC_DIGITS(key) + start_offset),
						  hash_len * sizeof(NumericDigit));

	/* Mix in the weight, via XOR */
	result = digit_hash ^ weight;

	PG_RETURN_DATUM(result);
}
2626 
2627 /*
2628  * Returns 64-bit value by hashing a value to a 64-bit value, with a seed.
2629  * Otherwise, similar to hash_numeric.
2630  */
2631 Datum
hash_numeric_extended(PG_FUNCTION_ARGS)2632 hash_numeric_extended(PG_FUNCTION_ARGS)
2633 {
2634 	Numeric		key = PG_GETARG_NUMERIC(0);
2635 	uint64		seed = PG_GETARG_INT64(1);
2636 	Datum		digit_hash;
2637 	Datum		result;
2638 	int			weight;
2639 	int			start_offset;
2640 	int			end_offset;
2641 	int			i;
2642 	int			hash_len;
2643 	NumericDigit *digits;
2644 
2645 	/* If it's NaN or infinity, don't try to hash the rest of the fields */
2646 	if (NUMERIC_IS_SPECIAL(key))
2647 		PG_RETURN_UINT64(seed);
2648 
2649 	weight = NUMERIC_WEIGHT(key);
2650 	start_offset = 0;
2651 	end_offset = 0;
2652 
2653 	digits = NUMERIC_DIGITS(key);
2654 	for (i = 0; i < NUMERIC_NDIGITS(key); i++)
2655 	{
2656 		if (digits[i] != (NumericDigit) 0)
2657 			break;
2658 
2659 		start_offset++;
2660 
2661 		weight--;
2662 	}
2663 
2664 	if (NUMERIC_NDIGITS(key) == start_offset)
2665 		PG_RETURN_UINT64(seed - 1);
2666 
2667 	for (i = NUMERIC_NDIGITS(key) - 1; i >= 0; i--)
2668 	{
2669 		if (digits[i] != (NumericDigit) 0)
2670 			break;
2671 
2672 		end_offset++;
2673 	}
2674 
2675 	Assert(start_offset + end_offset < NUMERIC_NDIGITS(key));
2676 
2677 	hash_len = NUMERIC_NDIGITS(key) - start_offset - end_offset;
2678 	digit_hash = hash_any_extended((unsigned char *) (NUMERIC_DIGITS(key)
2679 													  + start_offset),
2680 								   hash_len * sizeof(NumericDigit),
2681 								   seed);
2682 
2683 	result = UInt64GetDatum(DatumGetUInt64(digit_hash) ^ weight);
2684 
2685 	PG_RETURN_DATUM(result);
2686 }
2687 
2688 
2689 /* ----------------------------------------------------------------------
2690  *
2691  * Basic arithmetic functions
2692  *
2693  * ----------------------------------------------------------------------
2694  */
2695 
2696 
2697 /*
2698  * numeric_add() -
2699  *
2700  *	Add two numerics
2701  */
2702 Datum
numeric_add(PG_FUNCTION_ARGS)2703 numeric_add(PG_FUNCTION_ARGS)
2704 {
2705 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2706 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2707 	Numeric		res;
2708 
2709 	res = numeric_add_opt_error(num1, num2, NULL);
2710 
2711 	PG_RETURN_NUMERIC(res);
2712 }
2713 
2714 /*
2715  * numeric_add_opt_error() -
2716  *
2717  *	Internal version of numeric_add().  If "*have_error" flag is provided,
2718  *	on error it's set to true, NULL returned.  This is helpful when caller
2719  *	need to handle errors by itself.
2720  */
2721 Numeric
numeric_add_opt_error(Numeric num1,Numeric num2,bool * have_error)2722 numeric_add_opt_error(Numeric num1, Numeric num2, bool *have_error)
2723 {
2724 	NumericVar	arg1;
2725 	NumericVar	arg2;
2726 	NumericVar	result;
2727 	Numeric		res;
2728 
2729 	/*
2730 	 * Handle NaN and infinities
2731 	 */
2732 	if (NUMERIC_IS_SPECIAL(num1) || NUMERIC_IS_SPECIAL(num2))
2733 	{
2734 		if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
2735 			return make_result(&const_nan);
2736 		if (NUMERIC_IS_PINF(num1))
2737 		{
2738 			if (NUMERIC_IS_NINF(num2))
2739 				return make_result(&const_nan); /* Inf + -Inf */
2740 			else
2741 				return make_result(&const_pinf);
2742 		}
2743 		if (NUMERIC_IS_NINF(num1))
2744 		{
2745 			if (NUMERIC_IS_PINF(num2))
2746 				return make_result(&const_nan); /* -Inf + Inf */
2747 			else
2748 				return make_result(&const_ninf);
2749 		}
2750 		/* by here, num1 must be finite, so num2 is not */
2751 		if (NUMERIC_IS_PINF(num2))
2752 			return make_result(&const_pinf);
2753 		Assert(NUMERIC_IS_NINF(num2));
2754 		return make_result(&const_ninf);
2755 	}
2756 
2757 	/*
2758 	 * Unpack the values, let add_var() compute the result and return it.
2759 	 */
2760 	init_var_from_num(num1, &arg1);
2761 	init_var_from_num(num2, &arg2);
2762 
2763 	init_var(&result);
2764 	add_var(&arg1, &arg2, &result);
2765 
2766 	res = make_result_opt_error(&result, have_error);
2767 
2768 	free_var(&result);
2769 
2770 	return res;
2771 }
2772 
2773 
2774 /*
2775  * numeric_sub() -
2776  *
2777  *	Subtract one numeric from another
2778  */
2779 Datum
numeric_sub(PG_FUNCTION_ARGS)2780 numeric_sub(PG_FUNCTION_ARGS)
2781 {
2782 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2783 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2784 	Numeric		res;
2785 
2786 	res = numeric_sub_opt_error(num1, num2, NULL);
2787 
2788 	PG_RETURN_NUMERIC(res);
2789 }
2790 
2791 
2792 /*
2793  * numeric_sub_opt_error() -
2794  *
2795  *	Internal version of numeric_sub().  If "*have_error" flag is provided,
2796  *	on error it's set to true, NULL returned.  This is helpful when caller
2797  *	need to handle errors by itself.
2798  */
2799 Numeric
numeric_sub_opt_error(Numeric num1,Numeric num2,bool * have_error)2800 numeric_sub_opt_error(Numeric num1, Numeric num2, bool *have_error)
2801 {
2802 	NumericVar	arg1;
2803 	NumericVar	arg2;
2804 	NumericVar	result;
2805 	Numeric		res;
2806 
2807 	/*
2808 	 * Handle NaN and infinities
2809 	 */
2810 	if (NUMERIC_IS_SPECIAL(num1) || NUMERIC_IS_SPECIAL(num2))
2811 	{
2812 		if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
2813 			return make_result(&const_nan);
2814 		if (NUMERIC_IS_PINF(num1))
2815 		{
2816 			if (NUMERIC_IS_PINF(num2))
2817 				return make_result(&const_nan); /* Inf - Inf */
2818 			else
2819 				return make_result(&const_pinf);
2820 		}
2821 		if (NUMERIC_IS_NINF(num1))
2822 		{
2823 			if (NUMERIC_IS_NINF(num2))
2824 				return make_result(&const_nan); /* -Inf - -Inf */
2825 			else
2826 				return make_result(&const_ninf);
2827 		}
2828 		/* by here, num1 must be finite, so num2 is not */
2829 		if (NUMERIC_IS_PINF(num2))
2830 			return make_result(&const_ninf);
2831 		Assert(NUMERIC_IS_NINF(num2));
2832 		return make_result(&const_pinf);
2833 	}
2834 
2835 	/*
2836 	 * Unpack the values, let sub_var() compute the result and return it.
2837 	 */
2838 	init_var_from_num(num1, &arg1);
2839 	init_var_from_num(num2, &arg2);
2840 
2841 	init_var(&result);
2842 	sub_var(&arg1, &arg2, &result);
2843 
2844 	res = make_result_opt_error(&result, have_error);
2845 
2846 	free_var(&result);
2847 
2848 	return res;
2849 }
2850 
2851 
2852 /*
2853  * numeric_mul() -
2854  *
2855  *	Calculate the product of two numerics
2856  */
2857 Datum
numeric_mul(PG_FUNCTION_ARGS)2858 numeric_mul(PG_FUNCTION_ARGS)
2859 {
2860 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2861 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2862 	Numeric		res;
2863 
2864 	res = numeric_mul_opt_error(num1, num2, NULL);
2865 
2866 	PG_RETURN_NUMERIC(res);
2867 }
2868 
2869 
2870 /*
2871  * numeric_mul_opt_error() -
2872  *
2873  *	Internal version of numeric_mul().  If "*have_error" flag is provided,
2874  *	on error it's set to true, NULL returned.  This is helpful when caller
2875  *	need to handle errors by itself.
2876  */
2877 Numeric
numeric_mul_opt_error(Numeric num1,Numeric num2,bool * have_error)2878 numeric_mul_opt_error(Numeric num1, Numeric num2, bool *have_error)
2879 {
2880 	NumericVar	arg1;
2881 	NumericVar	arg2;
2882 	NumericVar	result;
2883 	Numeric		res;
2884 
2885 	/*
2886 	 * Handle NaN and infinities
2887 	 */
2888 	if (NUMERIC_IS_SPECIAL(num1) || NUMERIC_IS_SPECIAL(num2))
2889 	{
2890 		if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
2891 			return make_result(&const_nan);
2892 		if (NUMERIC_IS_PINF(num1))
2893 		{
2894 			switch (numeric_sign_internal(num2))
2895 			{
2896 				case 0:
2897 					return make_result(&const_nan); /* Inf * 0 */
2898 				case 1:
2899 					return make_result(&const_pinf);
2900 				case -1:
2901 					return make_result(&const_ninf);
2902 			}
2903 			Assert(false);
2904 		}
2905 		if (NUMERIC_IS_NINF(num1))
2906 		{
2907 			switch (numeric_sign_internal(num2))
2908 			{
2909 				case 0:
2910 					return make_result(&const_nan); /* -Inf * 0 */
2911 				case 1:
2912 					return make_result(&const_ninf);
2913 				case -1:
2914 					return make_result(&const_pinf);
2915 			}
2916 			Assert(false);
2917 		}
2918 		/* by here, num1 must be finite, so num2 is not */
2919 		if (NUMERIC_IS_PINF(num2))
2920 		{
2921 			switch (numeric_sign_internal(num1))
2922 			{
2923 				case 0:
2924 					return make_result(&const_nan); /* 0 * Inf */
2925 				case 1:
2926 					return make_result(&const_pinf);
2927 				case -1:
2928 					return make_result(&const_ninf);
2929 			}
2930 			Assert(false);
2931 		}
2932 		Assert(NUMERIC_IS_NINF(num2));
2933 		switch (numeric_sign_internal(num1))
2934 		{
2935 			case 0:
2936 				return make_result(&const_nan); /* 0 * -Inf */
2937 			case 1:
2938 				return make_result(&const_ninf);
2939 			case -1:
2940 				return make_result(&const_pinf);
2941 		}
2942 		Assert(false);
2943 	}
2944 
2945 	/*
2946 	 * Unpack the values, let mul_var() compute the result and return it.
2947 	 * Unlike add_var() and sub_var(), mul_var() will round its result. In the
2948 	 * case of numeric_mul(), which is invoked for the * operator on numerics,
2949 	 * we request exact representation for the product (rscale = sum(dscale of
2950 	 * arg1, dscale of arg2)).  If the exact result has more digits after the
2951 	 * decimal point than can be stored in a numeric, we round it.  Rounding
2952 	 * after computing the exact result ensures that the final result is
2953 	 * correctly rounded (rounding in mul_var() using a truncated product
2954 	 * would not guarantee this).
2955 	 */
2956 	init_var_from_num(num1, &arg1);
2957 	init_var_from_num(num2, &arg2);
2958 
2959 	init_var(&result);
2960 	mul_var(&arg1, &arg2, &result, arg1.dscale + arg2.dscale);
2961 
2962 	if (result.dscale > NUMERIC_DSCALE_MAX)
2963 		round_var(&result, NUMERIC_DSCALE_MAX);
2964 
2965 	res = make_result_opt_error(&result, have_error);
2966 
2967 	free_var(&result);
2968 
2969 	return res;
2970 }
2971 
2972 
2973 /*
2974  * numeric_div() -
2975  *
2976  *	Divide one numeric into another
2977  */
2978 Datum
numeric_div(PG_FUNCTION_ARGS)2979 numeric_div(PG_FUNCTION_ARGS)
2980 {
2981 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2982 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2983 	Numeric		res;
2984 
2985 	res = numeric_div_opt_error(num1, num2, NULL);
2986 
2987 	PG_RETURN_NUMERIC(res);
2988 }
2989 
2990 
2991 /*
2992  * numeric_div_opt_error() -
2993  *
2994  *	Internal version of numeric_div().  If "*have_error" flag is provided,
2995  *	on error it's set to true, NULL returned.  This is helpful when caller
2996  *	need to handle errors by itself.
2997  */
2998 Numeric
numeric_div_opt_error(Numeric num1,Numeric num2,bool * have_error)2999 numeric_div_opt_error(Numeric num1, Numeric num2, bool *have_error)
3000 {
3001 	NumericVar	arg1;
3002 	NumericVar	arg2;
3003 	NumericVar	result;
3004 	Numeric		res;
3005 	int			rscale;
3006 
3007 	if (have_error)
3008 		*have_error = false;
3009 
3010 	/*
3011 	 * Handle NaN and infinities
3012 	 */
3013 	if (NUMERIC_IS_SPECIAL(num1) || NUMERIC_IS_SPECIAL(num2))
3014 	{
3015 		if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
3016 			return make_result(&const_nan);
3017 		if (NUMERIC_IS_PINF(num1))
3018 		{
3019 			if (NUMERIC_IS_SPECIAL(num2))
3020 				return make_result(&const_nan); /* Inf / [-]Inf */
3021 			switch (numeric_sign_internal(num2))
3022 			{
3023 				case 0:
3024 					if (have_error)
3025 					{
3026 						*have_error = true;
3027 						return NULL;
3028 					}
3029 					ereport(ERROR,
3030 							(errcode(ERRCODE_DIVISION_BY_ZERO),
3031 							 errmsg("division by zero")));
3032 					break;
3033 				case 1:
3034 					return make_result(&const_pinf);
3035 				case -1:
3036 					return make_result(&const_ninf);
3037 			}
3038 			Assert(false);
3039 		}
3040 		if (NUMERIC_IS_NINF(num1))
3041 		{
3042 			if (NUMERIC_IS_SPECIAL(num2))
3043 				return make_result(&const_nan); /* -Inf / [-]Inf */
3044 			switch (numeric_sign_internal(num2))
3045 			{
3046 				case 0:
3047 					if (have_error)
3048 					{
3049 						*have_error = true;
3050 						return NULL;
3051 					}
3052 					ereport(ERROR,
3053 							(errcode(ERRCODE_DIVISION_BY_ZERO),
3054 							 errmsg("division by zero")));
3055 					break;
3056 				case 1:
3057 					return make_result(&const_ninf);
3058 				case -1:
3059 					return make_result(&const_pinf);
3060 			}
3061 			Assert(false);
3062 		}
3063 		/* by here, num1 must be finite, so num2 is not */
3064 
3065 		/*
3066 		 * POSIX would have us return zero or minus zero if num1 is zero, and
3067 		 * otherwise throw an underflow error.  But the numeric type doesn't
3068 		 * really do underflow, so let's just return zero.
3069 		 */
3070 		return make_result(&const_zero);
3071 	}
3072 
3073 	/*
3074 	 * Unpack the arguments
3075 	 */
3076 	init_var_from_num(num1, &arg1);
3077 	init_var_from_num(num2, &arg2);
3078 
3079 	init_var(&result);
3080 
3081 	/*
3082 	 * Select scale for division result
3083 	 */
3084 	rscale = select_div_scale(&arg1, &arg2);
3085 
3086 	/*
3087 	 * If "have_error" is provided, check for division by zero here
3088 	 */
3089 	if (have_error && (arg2.ndigits == 0 || arg2.digits[0] == 0))
3090 	{
3091 		*have_error = true;
3092 		return NULL;
3093 	}
3094 
3095 	/*
3096 	 * Do the divide and return the result
3097 	 */
3098 	div_var(&arg1, &arg2, &result, rscale, true);
3099 
3100 	res = make_result_opt_error(&result, have_error);
3101 
3102 	free_var(&result);
3103 
3104 	return res;
3105 }
3106 
3107 
3108 /*
3109  * numeric_div_trunc() -
3110  *
3111  *	Divide one numeric into another, truncating the result to an integer
3112  */
3113 Datum
numeric_div_trunc(PG_FUNCTION_ARGS)3114 numeric_div_trunc(PG_FUNCTION_ARGS)
3115 {
3116 	Numeric		num1 = PG_GETARG_NUMERIC(0);
3117 	Numeric		num2 = PG_GETARG_NUMERIC(1);
3118 	NumericVar	arg1;
3119 	NumericVar	arg2;
3120 	NumericVar	result;
3121 	Numeric		res;
3122 
3123 	/*
3124 	 * Handle NaN and infinities
3125 	 */
3126 	if (NUMERIC_IS_SPECIAL(num1) || NUMERIC_IS_SPECIAL(num2))
3127 	{
3128 		if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
3129 			PG_RETURN_NUMERIC(make_result(&const_nan));
3130 		if (NUMERIC_IS_PINF(num1))
3131 		{
3132 			if (NUMERIC_IS_SPECIAL(num2))
3133 				PG_RETURN_NUMERIC(make_result(&const_nan)); /* Inf / [-]Inf */
3134 			switch (numeric_sign_internal(num2))
3135 			{
3136 				case 0:
3137 					ereport(ERROR,
3138 							(errcode(ERRCODE_DIVISION_BY_ZERO),
3139 							 errmsg("division by zero")));
3140 					break;
3141 				case 1:
3142 					PG_RETURN_NUMERIC(make_result(&const_pinf));
3143 				case -1:
3144 					PG_RETURN_NUMERIC(make_result(&const_ninf));
3145 			}
3146 			Assert(false);
3147 		}
3148 		if (NUMERIC_IS_NINF(num1))
3149 		{
3150 			if (NUMERIC_IS_SPECIAL(num2))
3151 				PG_RETURN_NUMERIC(make_result(&const_nan)); /* -Inf / [-]Inf */
3152 			switch (numeric_sign_internal(num2))
3153 			{
3154 				case 0:
3155 					ereport(ERROR,
3156 							(errcode(ERRCODE_DIVISION_BY_ZERO),
3157 							 errmsg("division by zero")));
3158 					break;
3159 				case 1:
3160 					PG_RETURN_NUMERIC(make_result(&const_ninf));
3161 				case -1:
3162 					PG_RETURN_NUMERIC(make_result(&const_pinf));
3163 			}
3164 			Assert(false);
3165 		}
3166 		/* by here, num1 must be finite, so num2 is not */
3167 
3168 		/*
3169 		 * POSIX would have us return zero or minus zero if num1 is zero, and
3170 		 * otherwise throw an underflow error.  But the numeric type doesn't
3171 		 * really do underflow, so let's just return zero.
3172 		 */
3173 		PG_RETURN_NUMERIC(make_result(&const_zero));
3174 	}
3175 
3176 	/*
3177 	 * Unpack the arguments
3178 	 */
3179 	init_var_from_num(num1, &arg1);
3180 	init_var_from_num(num2, &arg2);
3181 
3182 	init_var(&result);
3183 
3184 	/*
3185 	 * Do the divide and return the result
3186 	 */
3187 	div_var(&arg1, &arg2, &result, 0, false);
3188 
3189 	res = make_result(&result);
3190 
3191 	free_var(&result);
3192 
3193 	PG_RETURN_NUMERIC(res);
3194 }
3195 
3196 
3197 /*
3198  * numeric_mod() -
3199  *
3200  *	Calculate the modulo of two numerics
3201  */
3202 Datum
numeric_mod(PG_FUNCTION_ARGS)3203 numeric_mod(PG_FUNCTION_ARGS)
3204 {
3205 	Numeric		num1 = PG_GETARG_NUMERIC(0);
3206 	Numeric		num2 = PG_GETARG_NUMERIC(1);
3207 	Numeric		res;
3208 
3209 	res = numeric_mod_opt_error(num1, num2, NULL);
3210 
3211 	PG_RETURN_NUMERIC(res);
3212 }
3213 
3214 
3215 /*
3216  * numeric_mod_opt_error() -
3217  *
3218  *	Internal version of numeric_mod().  If "*have_error" flag is provided,
3219  *	on error it's set to true, NULL returned.  This is helpful when caller
3220  *	need to handle errors by itself.
3221  */
3222 Numeric
numeric_mod_opt_error(Numeric num1,Numeric num2,bool * have_error)3223 numeric_mod_opt_error(Numeric num1, Numeric num2, bool *have_error)
3224 {
3225 	Numeric		res;
3226 	NumericVar	arg1;
3227 	NumericVar	arg2;
3228 	NumericVar	result;
3229 
3230 	if (have_error)
3231 		*have_error = false;
3232 
3233 	/*
3234 	 * Handle NaN and infinities.  We follow POSIX fmod() on this, except that
3235 	 * POSIX treats x-is-infinite and y-is-zero identically, raising EDOM and
3236 	 * returning NaN.  We choose to throw error only for y-is-zero.
3237 	 */
3238 	if (NUMERIC_IS_SPECIAL(num1) || NUMERIC_IS_SPECIAL(num2))
3239 	{
3240 		if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
3241 			return make_result(&const_nan);
3242 		if (NUMERIC_IS_INF(num1))
3243 		{
3244 			if (numeric_sign_internal(num2) == 0)
3245 			{
3246 				if (have_error)
3247 				{
3248 					*have_error = true;
3249 					return NULL;
3250 				}
3251 				ereport(ERROR,
3252 						(errcode(ERRCODE_DIVISION_BY_ZERO),
3253 						 errmsg("division by zero")));
3254 			}
3255 			/* Inf % any nonzero = NaN */
3256 			return make_result(&const_nan);
3257 		}
3258 		/* num2 must be [-]Inf; result is num1 regardless of sign of num2 */
3259 		return duplicate_numeric(num1);
3260 	}
3261 
3262 	init_var_from_num(num1, &arg1);
3263 	init_var_from_num(num2, &arg2);
3264 
3265 	init_var(&result);
3266 
3267 	/*
3268 	 * If "have_error" is provided, check for division by zero here
3269 	 */
3270 	if (have_error && (arg2.ndigits == 0 || arg2.digits[0] == 0))
3271 	{
3272 		*have_error = true;
3273 		return NULL;
3274 	}
3275 
3276 	mod_var(&arg1, &arg2, &result);
3277 
3278 	res = make_result_opt_error(&result, NULL);
3279 
3280 	free_var(&result);
3281 
3282 	return res;
3283 }
3284 
3285 
3286 /*
3287  * numeric_inc() -
3288  *
3289  *	Increment a number by one
3290  */
3291 Datum
numeric_inc(PG_FUNCTION_ARGS)3292 numeric_inc(PG_FUNCTION_ARGS)
3293 {
3294 	Numeric		num = PG_GETARG_NUMERIC(0);
3295 	NumericVar	arg;
3296 	Numeric		res;
3297 
3298 	/*
3299 	 * Handle NaN and infinities
3300 	 */
3301 	if (NUMERIC_IS_SPECIAL(num))
3302 		PG_RETURN_NUMERIC(duplicate_numeric(num));
3303 
3304 	/*
3305 	 * Compute the result and return it
3306 	 */
3307 	init_var_from_num(num, &arg);
3308 
3309 	add_var(&arg, &const_one, &arg);
3310 
3311 	res = make_result(&arg);
3312 
3313 	free_var(&arg);
3314 
3315 	PG_RETURN_NUMERIC(res);
3316 }
3317 
3318 
3319 /*
3320  * numeric_smaller() -
3321  *
3322  *	Return the smaller of two numbers
3323  */
3324 Datum
numeric_smaller(PG_FUNCTION_ARGS)3325 numeric_smaller(PG_FUNCTION_ARGS)
3326 {
3327 	Numeric		num1 = PG_GETARG_NUMERIC(0);
3328 	Numeric		num2 = PG_GETARG_NUMERIC(1);
3329 
3330 	/*
3331 	 * Use cmp_numerics so that this will agree with the comparison operators,
3332 	 * particularly as regards comparisons involving NaN.
3333 	 */
3334 	if (cmp_numerics(num1, num2) < 0)
3335 		PG_RETURN_NUMERIC(num1);
3336 	else
3337 		PG_RETURN_NUMERIC(num2);
3338 }
3339 
3340 
3341 /*
3342  * numeric_larger() -
3343  *
3344  *	Return the larger of two numbers
3345  */
3346 Datum
numeric_larger(PG_FUNCTION_ARGS)3347 numeric_larger(PG_FUNCTION_ARGS)
3348 {
3349 	Numeric		num1 = PG_GETARG_NUMERIC(0);
3350 	Numeric		num2 = PG_GETARG_NUMERIC(1);
3351 
3352 	/*
3353 	 * Use cmp_numerics so that this will agree with the comparison operators,
3354 	 * particularly as regards comparisons involving NaN.
3355 	 */
3356 	if (cmp_numerics(num1, num2) > 0)
3357 		PG_RETURN_NUMERIC(num1);
3358 	else
3359 		PG_RETURN_NUMERIC(num2);
3360 }
3361 
3362 
3363 /* ----------------------------------------------------------------------
3364  *
3365  * Advanced math functions
3366  *
3367  * ----------------------------------------------------------------------
3368  */
3369 
3370 /*
3371  * numeric_gcd() -
3372  *
3373  *	Calculate the greatest common divisor of two numerics
3374  */
3375 Datum
numeric_gcd(PG_FUNCTION_ARGS)3376 numeric_gcd(PG_FUNCTION_ARGS)
3377 {
3378 	Numeric		num1 = PG_GETARG_NUMERIC(0);
3379 	Numeric		num2 = PG_GETARG_NUMERIC(1);
3380 	NumericVar	arg1;
3381 	NumericVar	arg2;
3382 	NumericVar	result;
3383 	Numeric		res;
3384 
3385 	/*
3386 	 * Handle NaN and infinities: we consider the result to be NaN in all such
3387 	 * cases.
3388 	 */
3389 	if (NUMERIC_IS_SPECIAL(num1) || NUMERIC_IS_SPECIAL(num2))
3390 		PG_RETURN_NUMERIC(make_result(&const_nan));
3391 
3392 	/*
3393 	 * Unpack the arguments
3394 	 */
3395 	init_var_from_num(num1, &arg1);
3396 	init_var_from_num(num2, &arg2);
3397 
3398 	init_var(&result);
3399 
3400 	/*
3401 	 * Find the GCD and return the result
3402 	 */
3403 	gcd_var(&arg1, &arg2, &result);
3404 
3405 	res = make_result(&result);
3406 
3407 	free_var(&result);
3408 
3409 	PG_RETURN_NUMERIC(res);
3410 }
3411 
3412 
3413 /*
3414  * numeric_lcm() -
3415  *
3416  *	Calculate the least common multiple of two numerics
3417  */
3418 Datum
numeric_lcm(PG_FUNCTION_ARGS)3419 numeric_lcm(PG_FUNCTION_ARGS)
3420 {
3421 	Numeric		num1 = PG_GETARG_NUMERIC(0);
3422 	Numeric		num2 = PG_GETARG_NUMERIC(1);
3423 	NumericVar	arg1;
3424 	NumericVar	arg2;
3425 	NumericVar	result;
3426 	Numeric		res;
3427 
3428 	/*
3429 	 * Handle NaN and infinities: we consider the result to be NaN in all such
3430 	 * cases.
3431 	 */
3432 	if (NUMERIC_IS_SPECIAL(num1) || NUMERIC_IS_SPECIAL(num2))
3433 		PG_RETURN_NUMERIC(make_result(&const_nan));
3434 
3435 	/*
3436 	 * Unpack the arguments
3437 	 */
3438 	init_var_from_num(num1, &arg1);
3439 	init_var_from_num(num2, &arg2);
3440 
3441 	init_var(&result);
3442 
3443 	/*
3444 	 * Compute the result using lcm(x, y) = abs(x / gcd(x, y) * y), returning
3445 	 * zero if either input is zero.
3446 	 *
3447 	 * Note that the division is guaranteed to be exact, returning an integer
3448 	 * result, so the LCM is an integral multiple of both x and y.  A display
3449 	 * scale of Min(x.dscale, y.dscale) would be sufficient to represent it,
3450 	 * but as with other numeric functions, we choose to return a result whose
3451 	 * display scale is no smaller than either input.
3452 	 */
3453 	if (arg1.ndigits == 0 || arg2.ndigits == 0)
3454 		set_var_from_var(&const_zero, &result);
3455 	else
3456 	{
3457 		gcd_var(&arg1, &arg2, &result);
3458 		div_var(&arg1, &result, &result, 0, false);
3459 		mul_var(&arg2, &result, &result, arg2.dscale);
3460 		result.sign = NUMERIC_POS;
3461 	}
3462 
3463 	result.dscale = Max(arg1.dscale, arg2.dscale);
3464 
3465 	res = make_result(&result);
3466 
3467 	free_var(&result);
3468 
3469 	PG_RETURN_NUMERIC(res);
3470 }
3471 
3472 
3473 /*
3474  * numeric_fac()
3475  *
3476  * Compute factorial
3477  */
3478 Datum
numeric_fac(PG_FUNCTION_ARGS)3479 numeric_fac(PG_FUNCTION_ARGS)
3480 {
3481 	int64		num = PG_GETARG_INT64(0);
3482 	Numeric		res;
3483 	NumericVar	fact;
3484 	NumericVar	result;
3485 
3486 	if (num < 0)
3487 		ereport(ERROR,
3488 				(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
3489 				 errmsg("factorial of a negative number is undefined")));
3490 	if (num <= 1)
3491 	{
3492 		res = make_result(&const_one);
3493 		PG_RETURN_NUMERIC(res);
3494 	}
3495 	/* Fail immediately if the result would overflow */
3496 	if (num > 32177)
3497 		ereport(ERROR,
3498 				(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
3499 				 errmsg("value overflows numeric format")));
3500 
3501 	init_var(&fact);
3502 	init_var(&result);
3503 
3504 	int64_to_numericvar(num, &result);
3505 
3506 	for (num = num - 1; num > 1; num--)
3507 	{
3508 		/* this loop can take awhile, so allow it to be interrupted */
3509 		CHECK_FOR_INTERRUPTS();
3510 
3511 		int64_to_numericvar(num, &fact);
3512 
3513 		mul_var(&result, &fact, &result, 0);
3514 	}
3515 
3516 	res = make_result(&result);
3517 
3518 	free_var(&fact);
3519 	free_var(&result);
3520 
3521 	PG_RETURN_NUMERIC(res);
3522 }
3523 
3524 
3525 /*
3526  * numeric_sqrt() -
3527  *
3528  *	Compute the square root of a numeric.
3529  */
3530 Datum
numeric_sqrt(PG_FUNCTION_ARGS)3531 numeric_sqrt(PG_FUNCTION_ARGS)
3532 {
3533 	Numeric		num = PG_GETARG_NUMERIC(0);
3534 	Numeric		res;
3535 	NumericVar	arg;
3536 	NumericVar	result;
3537 	int			sweight;
3538 	int			rscale;
3539 
3540 	/*
3541 	 * Handle NaN and infinities
3542 	 */
3543 	if (NUMERIC_IS_SPECIAL(num))
3544 	{
3545 		/* error should match that in sqrt_var() */
3546 		if (NUMERIC_IS_NINF(num))
3547 			ereport(ERROR,
3548 					(errcode(ERRCODE_INVALID_ARGUMENT_FOR_POWER_FUNCTION),
3549 					 errmsg("cannot take square root of a negative number")));
3550 		/* For NAN or PINF, just duplicate the input */
3551 		PG_RETURN_NUMERIC(duplicate_numeric(num));
3552 	}
3553 
3554 	/*
3555 	 * Unpack the argument and determine the result scale.  We choose a scale
3556 	 * to give at least NUMERIC_MIN_SIG_DIGITS significant digits; but in any
3557 	 * case not less than the input's dscale.
3558 	 */
3559 	init_var_from_num(num, &arg);
3560 
3561 	init_var(&result);
3562 
3563 	/* Assume the input was normalized, so arg.weight is accurate */
3564 	sweight = (arg.weight + 1) * DEC_DIGITS / 2 - 1;
3565 
3566 	rscale = NUMERIC_MIN_SIG_DIGITS - sweight;
3567 	rscale = Max(rscale, arg.dscale);
3568 	rscale = Max(rscale, NUMERIC_MIN_DISPLAY_SCALE);
3569 	rscale = Min(rscale, NUMERIC_MAX_DISPLAY_SCALE);
3570 
3571 	/*
3572 	 * Let sqrt_var() do the calculation and return the result.
3573 	 */
3574 	sqrt_var(&arg, &result, rscale);
3575 
3576 	res = make_result(&result);
3577 
3578 	free_var(&result);
3579 
3580 	PG_RETURN_NUMERIC(res);
3581 }
3582 
3583 
3584 /*
3585  * numeric_exp() -
3586  *
3587  *	Raise e to the power of x
3588  */
3589 Datum
numeric_exp(PG_FUNCTION_ARGS)3590 numeric_exp(PG_FUNCTION_ARGS)
3591 {
3592 	Numeric		num = PG_GETARG_NUMERIC(0);
3593 	Numeric		res;
3594 	NumericVar	arg;
3595 	NumericVar	result;
3596 	int			rscale;
3597 	double		val;
3598 
3599 	/*
3600 	 * Handle NaN and infinities
3601 	 */
3602 	if (NUMERIC_IS_SPECIAL(num))
3603 	{
3604 		/* Per POSIX, exp(-Inf) is zero */
3605 		if (NUMERIC_IS_NINF(num))
3606 			PG_RETURN_NUMERIC(make_result(&const_zero));
3607 		/* For NAN or PINF, just duplicate the input */
3608 		PG_RETURN_NUMERIC(duplicate_numeric(num));
3609 	}
3610 
3611 	/*
3612 	 * Unpack the argument and determine the result scale.  We choose a scale
3613 	 * to give at least NUMERIC_MIN_SIG_DIGITS significant digits; but in any
3614 	 * case not less than the input's dscale.
3615 	 */
3616 	init_var_from_num(num, &arg);
3617 
3618 	init_var(&result);
3619 
3620 	/* convert input to float8, ignoring overflow */
3621 	val = numericvar_to_double_no_overflow(&arg);
3622 
3623 	/*
3624 	 * log10(result) = num * log10(e), so this is approximately the decimal
3625 	 * weight of the result:
3626 	 */
3627 	val *= 0.434294481903252;
3628 
3629 	/* limit to something that won't cause integer overflow */
3630 	val = Max(val, -NUMERIC_MAX_RESULT_SCALE);
3631 	val = Min(val, NUMERIC_MAX_RESULT_SCALE);
3632 
3633 	rscale = NUMERIC_MIN_SIG_DIGITS - (int) val;
3634 	rscale = Max(rscale, arg.dscale);
3635 	rscale = Max(rscale, NUMERIC_MIN_DISPLAY_SCALE);
3636 	rscale = Min(rscale, NUMERIC_MAX_DISPLAY_SCALE);
3637 
3638 	/*
3639 	 * Let exp_var() do the calculation and return the result.
3640 	 */
3641 	exp_var(&arg, &result, rscale);
3642 
3643 	res = make_result(&result);
3644 
3645 	free_var(&result);
3646 
3647 	PG_RETURN_NUMERIC(res);
3648 }
3649 
3650 
3651 /*
3652  * numeric_ln() -
3653  *
3654  *	Compute the natural logarithm of x
3655  */
3656 Datum
numeric_ln(PG_FUNCTION_ARGS)3657 numeric_ln(PG_FUNCTION_ARGS)
3658 {
3659 	Numeric		num = PG_GETARG_NUMERIC(0);
3660 	Numeric		res;
3661 	NumericVar	arg;
3662 	NumericVar	result;
3663 	int			ln_dweight;
3664 	int			rscale;
3665 
3666 	/*
3667 	 * Handle NaN and infinities
3668 	 */
3669 	if (NUMERIC_IS_SPECIAL(num))
3670 	{
3671 		if (NUMERIC_IS_NINF(num))
3672 			ereport(ERROR,
3673 					(errcode(ERRCODE_INVALID_ARGUMENT_FOR_LOG),
3674 					 errmsg("cannot take logarithm of a negative number")));
3675 		/* For NAN or PINF, just duplicate the input */
3676 		PG_RETURN_NUMERIC(duplicate_numeric(num));
3677 	}
3678 
3679 	init_var_from_num(num, &arg);
3680 	init_var(&result);
3681 
3682 	/* Estimated dweight of logarithm */
3683 	ln_dweight = estimate_ln_dweight(&arg);
3684 
3685 	rscale = NUMERIC_MIN_SIG_DIGITS - ln_dweight;
3686 	rscale = Max(rscale, arg.dscale);
3687 	rscale = Max(rscale, NUMERIC_MIN_DISPLAY_SCALE);
3688 	rscale = Min(rscale, NUMERIC_MAX_DISPLAY_SCALE);
3689 
3690 	ln_var(&arg, &result, rscale);
3691 
3692 	res = make_result(&result);
3693 
3694 	free_var(&result);
3695 
3696 	PG_RETURN_NUMERIC(res);
3697 }
3698 
3699 
3700 /*
3701  * numeric_log() -
3702  *
3703  *	Compute the logarithm of x in a given base
3704  */
3705 Datum
numeric_log(PG_FUNCTION_ARGS)3706 numeric_log(PG_FUNCTION_ARGS)
3707 {
3708 	Numeric		num1 = PG_GETARG_NUMERIC(0);
3709 	Numeric		num2 = PG_GETARG_NUMERIC(1);
3710 	Numeric		res;
3711 	NumericVar	arg1;
3712 	NumericVar	arg2;
3713 	NumericVar	result;
3714 
3715 	/*
3716 	 * Handle NaN and infinities
3717 	 */
3718 	if (NUMERIC_IS_SPECIAL(num1) || NUMERIC_IS_SPECIAL(num2))
3719 	{
3720 		int			sign1,
3721 					sign2;
3722 
3723 		if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
3724 			PG_RETURN_NUMERIC(make_result(&const_nan));
3725 		/* fail on negative inputs including -Inf, as log_var would */
3726 		sign1 = numeric_sign_internal(num1);
3727 		sign2 = numeric_sign_internal(num2);
3728 		if (sign1 < 0 || sign2 < 0)
3729 			ereport(ERROR,
3730 					(errcode(ERRCODE_INVALID_ARGUMENT_FOR_LOG),
3731 					 errmsg("cannot take logarithm of a negative number")));
3732 		/* fail on zero inputs, as log_var would */
3733 		if (sign1 == 0 || sign2 == 0)
3734 			ereport(ERROR,
3735 					(errcode(ERRCODE_INVALID_ARGUMENT_FOR_LOG),
3736 					 errmsg("cannot take logarithm of zero")));
3737 		if (NUMERIC_IS_PINF(num1))
3738 		{
3739 			/* log(Inf, Inf) reduces to Inf/Inf, so it's NaN */
3740 			if (NUMERIC_IS_PINF(num2))
3741 				PG_RETURN_NUMERIC(make_result(&const_nan));
3742 			/* log(Inf, finite-positive) is zero (we don't throw underflow) */
3743 			PG_RETURN_NUMERIC(make_result(&const_zero));
3744 		}
3745 		Assert(NUMERIC_IS_PINF(num2));
3746 		/* log(finite-positive, Inf) is Inf */
3747 		PG_RETURN_NUMERIC(make_result(&const_pinf));
3748 	}
3749 
3750 	/*
3751 	 * Initialize things
3752 	 */
3753 	init_var_from_num(num1, &arg1);
3754 	init_var_from_num(num2, &arg2);
3755 	init_var(&result);
3756 
3757 	/*
3758 	 * Call log_var() to compute and return the result; note it handles scale
3759 	 * selection itself.
3760 	 */
3761 	log_var(&arg1, &arg2, &result);
3762 
3763 	res = make_result(&result);
3764 
3765 	free_var(&result);
3766 
3767 	PG_RETURN_NUMERIC(res);
3768 }
3769 
3770 
3771 /*
3772  * numeric_power() -
3773  *
3774  *	Raise x to the power of y
3775  */
3776 Datum
numeric_power(PG_FUNCTION_ARGS)3777 numeric_power(PG_FUNCTION_ARGS)
3778 {
3779 	Numeric		num1 = PG_GETARG_NUMERIC(0);
3780 	Numeric		num2 = PG_GETARG_NUMERIC(1);
3781 	Numeric		res;
3782 	NumericVar	arg1;
3783 	NumericVar	arg2;
3784 	NumericVar	result;
3785 	int			sign1,
3786 				sign2;
3787 
3788 	/*
3789 	 * Handle NaN and infinities
3790 	 */
3791 	if (NUMERIC_IS_SPECIAL(num1) || NUMERIC_IS_SPECIAL(num2))
3792 	{
3793 		/*
3794 		 * We follow the POSIX spec for pow(3), which says that NaN ^ 0 = 1,
3795 		 * and 1 ^ NaN = 1, while all other cases with NaN inputs yield NaN
3796 		 * (with no error).
3797 		 */
3798 		if (NUMERIC_IS_NAN(num1))
3799 		{
3800 			if (!NUMERIC_IS_SPECIAL(num2))
3801 			{
3802 				init_var_from_num(num2, &arg2);
3803 				if (cmp_var(&arg2, &const_zero) == 0)
3804 					PG_RETURN_NUMERIC(make_result(&const_one));
3805 			}
3806 			PG_RETURN_NUMERIC(make_result(&const_nan));
3807 		}
3808 		if (NUMERIC_IS_NAN(num2))
3809 		{
3810 			if (!NUMERIC_IS_SPECIAL(num1))
3811 			{
3812 				init_var_from_num(num1, &arg1);
3813 				if (cmp_var(&arg1, &const_one) == 0)
3814 					PG_RETURN_NUMERIC(make_result(&const_one));
3815 			}
3816 			PG_RETURN_NUMERIC(make_result(&const_nan));
3817 		}
3818 		/* At least one input is infinite, but error rules still apply */
3819 		sign1 = numeric_sign_internal(num1);
3820 		sign2 = numeric_sign_internal(num2);
3821 		if (sign1 == 0 && sign2 < 0)
3822 			ereport(ERROR,
3823 					(errcode(ERRCODE_INVALID_ARGUMENT_FOR_POWER_FUNCTION),
3824 					 errmsg("zero raised to a negative power is undefined")));
3825 		if (sign1 < 0 && !numeric_is_integral(num2))
3826 			ereport(ERROR,
3827 					(errcode(ERRCODE_INVALID_ARGUMENT_FOR_POWER_FUNCTION),
3828 					 errmsg("a negative number raised to a non-integer power yields a complex result")));
3829 
3830 		/*
3831 		 * POSIX gives this series of rules for pow(3) with infinite inputs:
3832 		 *
3833 		 * For any value of y, if x is +1, 1.0 shall be returned.
3834 		 */
3835 		if (!NUMERIC_IS_SPECIAL(num1))
3836 		{
3837 			init_var_from_num(num1, &arg1);
3838 			if (cmp_var(&arg1, &const_one) == 0)
3839 				PG_RETURN_NUMERIC(make_result(&const_one));
3840 		}
3841 
3842 		/*
3843 		 * For any value of x, if y is [-]0, 1.0 shall be returned.
3844 		 */
3845 		if (sign2 == 0)
3846 			PG_RETURN_NUMERIC(make_result(&const_one));
3847 
3848 		/*
3849 		 * For any odd integer value of y > 0, if x is [-]0, [-]0 shall be
3850 		 * returned.  For y > 0 and not an odd integer, if x is [-]0, +0 shall
3851 		 * be returned.  (Since we don't deal in minus zero, we need not
3852 		 * distinguish these two cases.)
3853 		 */
3854 		if (sign1 == 0 && sign2 > 0)
3855 			PG_RETURN_NUMERIC(make_result(&const_zero));
3856 
3857 		/*
3858 		 * If x is -1, and y is [-]Inf, 1.0 shall be returned.
3859 		 *
3860 		 * For |x| < 1, if y is -Inf, +Inf shall be returned.
3861 		 *
3862 		 * For |x| > 1, if y is -Inf, +0 shall be returned.
3863 		 *
3864 		 * For |x| < 1, if y is +Inf, +0 shall be returned.
3865 		 *
3866 		 * For |x| > 1, if y is +Inf, +Inf shall be returned.
3867 		 */
3868 		if (NUMERIC_IS_INF(num2))
3869 		{
3870 			bool		abs_x_gt_one;
3871 
3872 			if (NUMERIC_IS_SPECIAL(num1))
3873 				abs_x_gt_one = true;	/* x is either Inf or -Inf */
3874 			else
3875 			{
3876 				init_var_from_num(num1, &arg1);
3877 				if (cmp_var(&arg1, &const_minus_one) == 0)
3878 					PG_RETURN_NUMERIC(make_result(&const_one));
3879 				arg1.sign = NUMERIC_POS;	/* now arg1 = abs(x) */
3880 				abs_x_gt_one = (cmp_var(&arg1, &const_one) > 0);
3881 			}
3882 			if (abs_x_gt_one == (sign2 > 0))
3883 				PG_RETURN_NUMERIC(make_result(&const_pinf));
3884 			else
3885 				PG_RETURN_NUMERIC(make_result(&const_zero));
3886 		}
3887 
3888 		/*
3889 		 * For y < 0, if x is +Inf, +0 shall be returned.
3890 		 *
3891 		 * For y > 0, if x is +Inf, +Inf shall be returned.
3892 		 */
3893 		if (NUMERIC_IS_PINF(num1))
3894 		{
3895 			if (sign2 > 0)
3896 				PG_RETURN_NUMERIC(make_result(&const_pinf));
3897 			else
3898 				PG_RETURN_NUMERIC(make_result(&const_zero));
3899 		}
3900 
3901 		Assert(NUMERIC_IS_NINF(num1));
3902 
3903 		/*
3904 		 * For y an odd integer < 0, if x is -Inf, -0 shall be returned.  For
3905 		 * y < 0 and not an odd integer, if x is -Inf, +0 shall be returned.
3906 		 * (Again, we need not distinguish these two cases.)
3907 		 */
3908 		if (sign2 < 0)
3909 			PG_RETURN_NUMERIC(make_result(&const_zero));
3910 
3911 		/*
3912 		 * For y an odd integer > 0, if x is -Inf, -Inf shall be returned. For
3913 		 * y > 0 and not an odd integer, if x is -Inf, +Inf shall be returned.
3914 		 */
3915 		init_var_from_num(num2, &arg2);
3916 		if (arg2.ndigits > 0 && arg2.ndigits == arg2.weight + 1 &&
3917 			(arg2.digits[arg2.ndigits - 1] & 1))
3918 			PG_RETURN_NUMERIC(make_result(&const_ninf));
3919 		else
3920 			PG_RETURN_NUMERIC(make_result(&const_pinf));
3921 	}
3922 
3923 	/*
3924 	 * The SQL spec requires that we emit a particular SQLSTATE error code for
3925 	 * certain error conditions.  Specifically, we don't return a
3926 	 * divide-by-zero error code for 0 ^ -1.  Raising a negative number to a
3927 	 * non-integer power must produce the same error code, but that case is
3928 	 * handled in power_var().
3929 	 */
3930 	sign1 = numeric_sign_internal(num1);
3931 	sign2 = numeric_sign_internal(num2);
3932 
3933 	if (sign1 == 0 && sign2 < 0)
3934 		ereport(ERROR,
3935 				(errcode(ERRCODE_INVALID_ARGUMENT_FOR_POWER_FUNCTION),
3936 				 errmsg("zero raised to a negative power is undefined")));
3937 
3938 	/*
3939 	 * Initialize things
3940 	 */
3941 	init_var(&result);
3942 	init_var_from_num(num1, &arg1);
3943 	init_var_from_num(num2, &arg2);
3944 
3945 	/*
3946 	 * Call power_var() to compute and return the result; note it handles
3947 	 * scale selection itself.
3948 	 */
3949 	power_var(&arg1, &arg2, &result);
3950 
3951 	res = make_result(&result);
3952 
3953 	free_var(&result);
3954 
3955 	PG_RETURN_NUMERIC(res);
3956 }
3957 
3958 /*
3959  * numeric_scale() -
3960  *
3961  *	Returns the scale, i.e. the count of decimal digits in the fractional part
3962  */
3963 Datum
numeric_scale(PG_FUNCTION_ARGS)3964 numeric_scale(PG_FUNCTION_ARGS)
3965 {
3966 	Numeric		num = PG_GETARG_NUMERIC(0);
3967 
3968 	if (NUMERIC_IS_SPECIAL(num))
3969 		PG_RETURN_NULL();
3970 
3971 	PG_RETURN_INT32(NUMERIC_DSCALE(num));
3972 }
3973 
3974 /*
3975  * Calculate minimum scale for value.
3976  */
3977 static int
get_min_scale(NumericVar * var)3978 get_min_scale(NumericVar *var)
3979 {
3980 	int			min_scale;
3981 	int			last_digit_pos;
3982 
3983 	/*
3984 	 * Ordinarily, the input value will be "stripped" so that the last
3985 	 * NumericDigit is nonzero.  But we don't want to get into an infinite
3986 	 * loop if it isn't, so explicitly find the last nonzero digit.
3987 	 */
3988 	last_digit_pos = var->ndigits - 1;
3989 	while (last_digit_pos >= 0 &&
3990 		   var->digits[last_digit_pos] == 0)
3991 		last_digit_pos--;
3992 
3993 	if (last_digit_pos >= 0)
3994 	{
3995 		/* compute min_scale assuming that last ndigit has no zeroes */
3996 		min_scale = (last_digit_pos - var->weight) * DEC_DIGITS;
3997 
3998 		/*
3999 		 * We could get a negative result if there are no digits after the
4000 		 * decimal point.  In this case the min_scale must be zero.
4001 		 */
4002 		if (min_scale > 0)
4003 		{
4004 			/*
4005 			 * Reduce min_scale if trailing digit(s) in last NumericDigit are
4006 			 * zero.
4007 			 */
4008 			NumericDigit last_digit = var->digits[last_digit_pos];
4009 
4010 			while (last_digit % 10 == 0)
4011 			{
4012 				min_scale--;
4013 				last_digit /= 10;
4014 			}
4015 		}
4016 		else
4017 			min_scale = 0;
4018 	}
4019 	else
4020 		min_scale = 0;			/* result if input is zero */
4021 
4022 	return min_scale;
4023 }
4024 
4025 /*
4026  * Returns minimum scale required to represent supplied value without loss.
4027  */
4028 Datum
numeric_min_scale(PG_FUNCTION_ARGS)4029 numeric_min_scale(PG_FUNCTION_ARGS)
4030 {
4031 	Numeric		num = PG_GETARG_NUMERIC(0);
4032 	NumericVar	arg;
4033 	int			min_scale;
4034 
4035 	if (NUMERIC_IS_SPECIAL(num))
4036 		PG_RETURN_NULL();
4037 
4038 	init_var_from_num(num, &arg);
4039 	min_scale = get_min_scale(&arg);
4040 	free_var(&arg);
4041 
4042 	PG_RETURN_INT32(min_scale);
4043 }
4044 
4045 /*
4046  * Reduce scale of numeric value to represent supplied value without loss.
4047  */
4048 Datum
numeric_trim_scale(PG_FUNCTION_ARGS)4049 numeric_trim_scale(PG_FUNCTION_ARGS)
4050 {
4051 	Numeric		num = PG_GETARG_NUMERIC(0);
4052 	Numeric		res;
4053 	NumericVar	result;
4054 
4055 	if (NUMERIC_IS_SPECIAL(num))
4056 		PG_RETURN_NUMERIC(duplicate_numeric(num));
4057 
4058 	init_var_from_num(num, &result);
4059 	result.dscale = get_min_scale(&result);
4060 	res = make_result(&result);
4061 	free_var(&result);
4062 
4063 	PG_RETURN_NUMERIC(res);
4064 }
4065 
4066 
4067 /* ----------------------------------------------------------------------
4068  *
4069  * Type conversion functions
4070  *
4071  * ----------------------------------------------------------------------
4072  */
4073 
4074 Numeric
int64_to_numeric(int64 val)4075 int64_to_numeric(int64 val)
4076 {
4077 	Numeric		res;
4078 	NumericVar	result;
4079 
4080 	init_var(&result);
4081 
4082 	int64_to_numericvar(val, &result);
4083 
4084 	res = make_result(&result);
4085 
4086 	free_var(&result);
4087 
4088 	return res;
4089 }
4090 
4091 /*
4092  * Convert val1/(10**val2) to numeric.  This is much faster than normal
4093  * numeric division.
4094  */
4095 Numeric
int64_div_fast_to_numeric(int64 val1,int log10val2)4096 int64_div_fast_to_numeric(int64 val1, int log10val2)
4097 {
4098 	Numeric		res;
4099 	NumericVar	result;
4100 	int64		saved_val1 = val1;
4101 	int			w;
4102 	int			m;
4103 
4104 	/* how much to decrease the weight by */
4105 	w = log10val2 / DEC_DIGITS;
4106 	/* how much is left */
4107 	m = log10val2 % DEC_DIGITS;
4108 
4109 	/*
4110 	 * If there is anything left, multiply the dividend by what's left, then
4111 	 * shift the weight by one more.
4112 	 */
4113 	if (m > 0)
4114 	{
4115 		static int	pow10[] = {1, 10, 100, 1000};
4116 
4117 		StaticAssertStmt(lengthof(pow10) == DEC_DIGITS, "mismatch with DEC_DIGITS");
4118 		if (unlikely(pg_mul_s64_overflow(val1, pow10[DEC_DIGITS - m], &val1)))
4119 		{
4120 			/*
4121 			 * If it doesn't fit, do the whole computation in numeric the slow
4122 			 * way.  Note that va1l may have been overwritten, so use
4123 			 * saved_val1 instead.
4124 			 */
4125 			int			val2 = 1;
4126 
4127 			for (int i = 0; i < log10val2; i++)
4128 				val2 *= 10;
4129 			res = numeric_div_opt_error(int64_to_numeric(saved_val1), int64_to_numeric(val2), NULL);
4130 			res = DatumGetNumeric(DirectFunctionCall2(numeric_round,
4131 													  NumericGetDatum(res),
4132 													  Int32GetDatum(log10val2)));
4133 			return res;
4134 		}
4135 		w++;
4136 	}
4137 
4138 	init_var(&result);
4139 
4140 	int64_to_numericvar(val1, &result);
4141 
4142 	result.weight -= w;
4143 	result.dscale += w * DEC_DIGITS - (DEC_DIGITS - m);
4144 
4145 	res = make_result(&result);
4146 
4147 	free_var(&result);
4148 
4149 	return res;
4150 }
4151 
4152 Datum
int4_numeric(PG_FUNCTION_ARGS)4153 int4_numeric(PG_FUNCTION_ARGS)
4154 {
4155 	int32		val = PG_GETARG_INT32(0);
4156 
4157 	PG_RETURN_NUMERIC(int64_to_numeric(val));
4158 }
4159 
4160 int32
numeric_int4_opt_error(Numeric num,bool * have_error)4161 numeric_int4_opt_error(Numeric num, bool *have_error)
4162 {
4163 	NumericVar	x;
4164 	int32		result;
4165 
4166 	if (have_error)
4167 		*have_error = false;
4168 
4169 	if (NUMERIC_IS_SPECIAL(num))
4170 	{
4171 		if (have_error)
4172 		{
4173 			*have_error = true;
4174 			return 0;
4175 		}
4176 		else
4177 		{
4178 			if (NUMERIC_IS_NAN(num))
4179 				ereport(ERROR,
4180 						(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
4181 						 errmsg("cannot convert NaN to %s", "integer")));
4182 			else
4183 				ereport(ERROR,
4184 						(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
4185 						 errmsg("cannot convert infinity to %s", "integer")));
4186 		}
4187 	}
4188 
4189 	/* Convert to variable format, then convert to int4 */
4190 	init_var_from_num(num, &x);
4191 
4192 	if (!numericvar_to_int32(&x, &result))
4193 	{
4194 		if (have_error)
4195 		{
4196 			*have_error = true;
4197 			return 0;
4198 		}
4199 		else
4200 		{
4201 			ereport(ERROR,
4202 					(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
4203 					 errmsg("integer out of range")));
4204 		}
4205 	}
4206 
4207 	return result;
4208 }
4209 
4210 Datum
numeric_int4(PG_FUNCTION_ARGS)4211 numeric_int4(PG_FUNCTION_ARGS)
4212 {
4213 	Numeric		num = PG_GETARG_NUMERIC(0);
4214 
4215 	PG_RETURN_INT32(numeric_int4_opt_error(num, NULL));
4216 }
4217 
4218 /*
4219  * Given a NumericVar, convert it to an int32. If the NumericVar
4220  * exceeds the range of an int32, false is returned, otherwise true is returned.
4221  * The input NumericVar is *not* free'd.
4222  */
4223 static bool
numericvar_to_int32(const NumericVar * var,int32 * result)4224 numericvar_to_int32(const NumericVar *var, int32 *result)
4225 {
4226 	int64		val;
4227 
4228 	if (!numericvar_to_int64(var, &val))
4229 		return false;
4230 
4231 	if (unlikely(val < PG_INT32_MIN) || unlikely(val > PG_INT32_MAX))
4232 		return false;
4233 
4234 	/* Down-convert to int4 */
4235 	*result = (int32) val;
4236 
4237 	return true;
4238 }
4239 
4240 Datum
int8_numeric(PG_FUNCTION_ARGS)4241 int8_numeric(PG_FUNCTION_ARGS)
4242 {
4243 	int64		val = PG_GETARG_INT64(0);
4244 
4245 	PG_RETURN_NUMERIC(int64_to_numeric(val));
4246 }
4247 
4248 
4249 Datum
numeric_int8(PG_FUNCTION_ARGS)4250 numeric_int8(PG_FUNCTION_ARGS)
4251 {
4252 	Numeric		num = PG_GETARG_NUMERIC(0);
4253 	NumericVar	x;
4254 	int64		result;
4255 
4256 	if (NUMERIC_IS_SPECIAL(num))
4257 	{
4258 		if (NUMERIC_IS_NAN(num))
4259 			ereport(ERROR,
4260 					(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
4261 					 errmsg("cannot convert NaN to %s", "bigint")));
4262 		else
4263 			ereport(ERROR,
4264 					(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
4265 					 errmsg("cannot convert infinity to %s", "bigint")));
4266 	}
4267 
4268 	/* Convert to variable format and thence to int8 */
4269 	init_var_from_num(num, &x);
4270 
4271 	if (!numericvar_to_int64(&x, &result))
4272 		ereport(ERROR,
4273 				(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
4274 				 errmsg("bigint out of range")));
4275 
4276 	PG_RETURN_INT64(result);
4277 }
4278 
4279 
4280 Datum
int2_numeric(PG_FUNCTION_ARGS)4281 int2_numeric(PG_FUNCTION_ARGS)
4282 {
4283 	int16		val = PG_GETARG_INT16(0);
4284 
4285 	PG_RETURN_NUMERIC(int64_to_numeric(val));
4286 }
4287 
4288 
4289 Datum
numeric_int2(PG_FUNCTION_ARGS)4290 numeric_int2(PG_FUNCTION_ARGS)
4291 {
4292 	Numeric		num = PG_GETARG_NUMERIC(0);
4293 	NumericVar	x;
4294 	int64		val;
4295 	int16		result;
4296 
4297 	if (NUMERIC_IS_SPECIAL(num))
4298 	{
4299 		if (NUMERIC_IS_NAN(num))
4300 			ereport(ERROR,
4301 					(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
4302 					 errmsg("cannot convert NaN to %s", "smallint")));
4303 		else
4304 			ereport(ERROR,
4305 					(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
4306 					 errmsg("cannot convert infinity to %s", "smallint")));
4307 	}
4308 
4309 	/* Convert to variable format and thence to int8 */
4310 	init_var_from_num(num, &x);
4311 
4312 	if (!numericvar_to_int64(&x, &val))
4313 		ereport(ERROR,
4314 				(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
4315 				 errmsg("smallint out of range")));
4316 
4317 	if (unlikely(val < PG_INT16_MIN) || unlikely(val > PG_INT16_MAX))
4318 		ereport(ERROR,
4319 				(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
4320 				 errmsg("smallint out of range")));
4321 
4322 	/* Down-convert to int2 */
4323 	result = (int16) val;
4324 
4325 	PG_RETURN_INT16(result);
4326 }
4327 
4328 
4329 Datum
float8_numeric(PG_FUNCTION_ARGS)4330 float8_numeric(PG_FUNCTION_ARGS)
4331 {
4332 	float8		val = PG_GETARG_FLOAT8(0);
4333 	Numeric		res;
4334 	NumericVar	result;
4335 	char		buf[DBL_DIG + 100];
4336 
4337 	if (isnan(val))
4338 		PG_RETURN_NUMERIC(make_result(&const_nan));
4339 
4340 	if (isinf(val))
4341 	{
4342 		if (val < 0)
4343 			PG_RETURN_NUMERIC(make_result(&const_ninf));
4344 		else
4345 			PG_RETURN_NUMERIC(make_result(&const_pinf));
4346 	}
4347 
4348 	snprintf(buf, sizeof(buf), "%.*g", DBL_DIG, val);
4349 
4350 	init_var(&result);
4351 
4352 	/* Assume we need not worry about leading/trailing spaces */
4353 	(void) set_var_from_str(buf, buf, &result);
4354 
4355 	res = make_result(&result);
4356 
4357 	free_var(&result);
4358 
4359 	PG_RETURN_NUMERIC(res);
4360 }
4361 
4362 
4363 Datum
numeric_float8(PG_FUNCTION_ARGS)4364 numeric_float8(PG_FUNCTION_ARGS)
4365 {
4366 	Numeric		num = PG_GETARG_NUMERIC(0);
4367 	char	   *tmp;
4368 	Datum		result;
4369 
4370 	if (NUMERIC_IS_SPECIAL(num))
4371 	{
4372 		if (NUMERIC_IS_PINF(num))
4373 			PG_RETURN_FLOAT8(get_float8_infinity());
4374 		else if (NUMERIC_IS_NINF(num))
4375 			PG_RETURN_FLOAT8(-get_float8_infinity());
4376 		else
4377 			PG_RETURN_FLOAT8(get_float8_nan());
4378 	}
4379 
4380 	tmp = DatumGetCString(DirectFunctionCall1(numeric_out,
4381 											  NumericGetDatum(num)));
4382 
4383 	result = DirectFunctionCall1(float8in, CStringGetDatum(tmp));
4384 
4385 	pfree(tmp);
4386 
4387 	PG_RETURN_DATUM(result);
4388 }
4389 
4390 
4391 /*
4392  * Convert numeric to float8; if out of range, return +/- HUGE_VAL
4393  *
4394  * (internal helper function, not directly callable from SQL)
4395  */
4396 Datum
numeric_float8_no_overflow(PG_FUNCTION_ARGS)4397 numeric_float8_no_overflow(PG_FUNCTION_ARGS)
4398 {
4399 	Numeric		num = PG_GETARG_NUMERIC(0);
4400 	double		val;
4401 
4402 	if (NUMERIC_IS_SPECIAL(num))
4403 	{
4404 		if (NUMERIC_IS_PINF(num))
4405 			val = HUGE_VAL;
4406 		else if (NUMERIC_IS_NINF(num))
4407 			val = -HUGE_VAL;
4408 		else
4409 			val = get_float8_nan();
4410 	}
4411 	else
4412 	{
4413 		NumericVar	x;
4414 
4415 		init_var_from_num(num, &x);
4416 		val = numericvar_to_double_no_overflow(&x);
4417 	}
4418 
4419 	PG_RETURN_FLOAT8(val);
4420 }
4421 
4422 Datum
float4_numeric(PG_FUNCTION_ARGS)4423 float4_numeric(PG_FUNCTION_ARGS)
4424 {
4425 	float4		val = PG_GETARG_FLOAT4(0);
4426 	Numeric		res;
4427 	NumericVar	result;
4428 	char		buf[FLT_DIG + 100];
4429 
4430 	if (isnan(val))
4431 		PG_RETURN_NUMERIC(make_result(&const_nan));
4432 
4433 	if (isinf(val))
4434 	{
4435 		if (val < 0)
4436 			PG_RETURN_NUMERIC(make_result(&const_ninf));
4437 		else
4438 			PG_RETURN_NUMERIC(make_result(&const_pinf));
4439 	}
4440 
4441 	snprintf(buf, sizeof(buf), "%.*g", FLT_DIG, val);
4442 
4443 	init_var(&result);
4444 
4445 	/* Assume we need not worry about leading/trailing spaces */
4446 	(void) set_var_from_str(buf, buf, &result);
4447 
4448 	res = make_result(&result);
4449 
4450 	free_var(&result);
4451 
4452 	PG_RETURN_NUMERIC(res);
4453 }
4454 
4455 
4456 Datum
numeric_float4(PG_FUNCTION_ARGS)4457 numeric_float4(PG_FUNCTION_ARGS)
4458 {
4459 	Numeric		num = PG_GETARG_NUMERIC(0);
4460 	char	   *tmp;
4461 	Datum		result;
4462 
4463 	if (NUMERIC_IS_SPECIAL(num))
4464 	{
4465 		if (NUMERIC_IS_PINF(num))
4466 			PG_RETURN_FLOAT4(get_float4_infinity());
4467 		else if (NUMERIC_IS_NINF(num))
4468 			PG_RETURN_FLOAT4(-get_float4_infinity());
4469 		else
4470 			PG_RETURN_FLOAT4(get_float4_nan());
4471 	}
4472 
4473 	tmp = DatumGetCString(DirectFunctionCall1(numeric_out,
4474 											  NumericGetDatum(num)));
4475 
4476 	result = DirectFunctionCall1(float4in, CStringGetDatum(tmp));
4477 
4478 	pfree(tmp);
4479 
4480 	PG_RETURN_DATUM(result);
4481 }
4482 
4483 
4484 Datum
numeric_pg_lsn(PG_FUNCTION_ARGS)4485 numeric_pg_lsn(PG_FUNCTION_ARGS)
4486 {
4487 	Numeric		num = PG_GETARG_NUMERIC(0);
4488 	NumericVar	x;
4489 	XLogRecPtr	result;
4490 
4491 	if (NUMERIC_IS_SPECIAL(num))
4492 	{
4493 		if (NUMERIC_IS_NAN(num))
4494 			ereport(ERROR,
4495 					(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
4496 					 errmsg("cannot convert NaN to %s", "pg_lsn")));
4497 		else
4498 			ereport(ERROR,
4499 					(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
4500 					 errmsg("cannot convert infinity to %s", "pg_lsn")));
4501 	}
4502 
4503 	/* Convert to variable format and thence to pg_lsn */
4504 	init_var_from_num(num, &x);
4505 
4506 	if (!numericvar_to_uint64(&x, (uint64 *) &result))
4507 		ereport(ERROR,
4508 				(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
4509 				 errmsg("pg_lsn out of range")));
4510 
4511 	PG_RETURN_LSN(result);
4512 }
4513 
4514 
4515 /* ----------------------------------------------------------------------
4516  *
4517  * Aggregate functions
4518  *
4519  * The transition datatype for all these aggregates is declared as INTERNAL.
4520  * Actually, it's a pointer to a NumericAggState allocated in the aggregate
4521  * context.  The digit buffers for the NumericVars will be there too.
4522  *
4523  * On platforms which support 128-bit integers some aggregates instead use a
4524  * 128-bit integer based transition datatype to speed up calculations.
4525  *
4526  * ----------------------------------------------------------------------
4527  */
4528 
/*
 * NumericAggState -
 *
 *	Transition state shared by the numeric aggregates in this section
 *	(numeric_accum, numeric_avg_accum and their combine, serialize,
 *	deserialize and inverse-transition functions).  Finite inputs are counted
 *	in N and added into sumX.  When calcSumX2 is set, their squares are also
 *	added into sumX2.  NaN and infinite inputs never reach the sums; they are
 *	only tallied in the separate counters at the end of the struct.
 */
typedef struct NumericAggState
{
	bool		calcSumX2;		/* if true, calculate sumX2 */
	MemoryContext agg_context;	/* context we're calculating in */
	int64		N;				/* count of processed finite numbers */
	NumericSumAccum sumX;		/* sum of processed numbers */
	NumericSumAccum sumX2;		/* sum of squares of processed numbers */
	int			maxScale;		/* maximum dscale of finite inputs seen so
								 * far; do_numeric_discard relies on it */
	int64		maxScaleCount;	/* number of values seen with maximum scale */
	/* These counts are *not* included in N!  Use NA_TOTAL_COUNT() as needed */
	int64		NaNcount;		/* count of NaN values */
	int64		pInfcount;		/* count of +Inf values */
	int64		nInfcount;		/* count of -Inf values */
} NumericAggState;

/*
 * Total number of non-null inputs accumulated into a NumericAggState: the
 * finite inputs counted in N plus the NaN and infinity inputs that N omits.
 */
#define NA_TOTAL_COUNT(na) \
	((na)->N + (na)->NaNcount + (na)->pInfcount + (na)->nInfcount)
4546 
4547 /*
4548  * Prepare state data for a numeric aggregate function that needs to compute
4549  * sum, count and optionally sum of squares of the input.
4550  */
4551 static NumericAggState *
makeNumericAggState(FunctionCallInfo fcinfo,bool calcSumX2)4552 makeNumericAggState(FunctionCallInfo fcinfo, bool calcSumX2)
4553 {
4554 	NumericAggState *state;
4555 	MemoryContext agg_context;
4556 	MemoryContext old_context;
4557 
4558 	if (!AggCheckCallContext(fcinfo, &agg_context))
4559 		elog(ERROR, "aggregate function called in non-aggregate context");
4560 
4561 	old_context = MemoryContextSwitchTo(agg_context);
4562 
4563 	state = (NumericAggState *) palloc0(sizeof(NumericAggState));
4564 	state->calcSumX2 = calcSumX2;
4565 	state->agg_context = agg_context;
4566 
4567 	MemoryContextSwitchTo(old_context);
4568 
4569 	return state;
4570 }
4571 
4572 /*
4573  * Like makeNumericAggState(), but allocate the state in the current memory
4574  * context.
4575  */
4576 static NumericAggState *
makeNumericAggStateCurrentContext(bool calcSumX2)4577 makeNumericAggStateCurrentContext(bool calcSumX2)
4578 {
4579 	NumericAggState *state;
4580 
4581 	state = (NumericAggState *) palloc0(sizeof(NumericAggState));
4582 	state->calcSumX2 = calcSumX2;
4583 	state->agg_context = CurrentMemoryContext;
4584 
4585 	return state;
4586 }
4587 
4588 /*
4589  * Accumulate a new input value for numeric aggregate functions.
4590  */
4591 static void
do_numeric_accum(NumericAggState * state,Numeric newval)4592 do_numeric_accum(NumericAggState *state, Numeric newval)
4593 {
4594 	NumericVar	X;
4595 	NumericVar	X2;
4596 	MemoryContext old_context;
4597 
4598 	/* Count NaN/infinity inputs separately from all else */
4599 	if (NUMERIC_IS_SPECIAL(newval))
4600 	{
4601 		if (NUMERIC_IS_PINF(newval))
4602 			state->pInfcount++;
4603 		else if (NUMERIC_IS_NINF(newval))
4604 			state->nInfcount++;
4605 		else
4606 			state->NaNcount++;
4607 		return;
4608 	}
4609 
4610 	/* load processed number in short-lived context */
4611 	init_var_from_num(newval, &X);
4612 
4613 	/*
4614 	 * Track the highest input dscale that we've seen, to support inverse
4615 	 * transitions (see do_numeric_discard).
4616 	 */
4617 	if (X.dscale > state->maxScale)
4618 	{
4619 		state->maxScale = X.dscale;
4620 		state->maxScaleCount = 1;
4621 	}
4622 	else if (X.dscale == state->maxScale)
4623 		state->maxScaleCount++;
4624 
4625 	/* if we need X^2, calculate that in short-lived context */
4626 	if (state->calcSumX2)
4627 	{
4628 		init_var(&X2);
4629 		mul_var(&X, &X, &X2, X.dscale * 2);
4630 	}
4631 
4632 	/* The rest of this needs to work in the aggregate context */
4633 	old_context = MemoryContextSwitchTo(state->agg_context);
4634 
4635 	state->N++;
4636 
4637 	/* Accumulate sums */
4638 	accum_sum_add(&(state->sumX), &X);
4639 
4640 	if (state->calcSumX2)
4641 		accum_sum_add(&(state->sumX2), &X2);
4642 
4643 	MemoryContextSwitchTo(old_context);
4644 }
4645 
4646 /*
4647  * Attempt to remove an input value from the aggregated state.
4648  *
4649  * If the value cannot be removed then the function will return false; the
4650  * possible reasons for failing are described below.
4651  *
4652  * If we aggregate the values 1.01 and 2 then the result will be 3.01.
4653  * If we are then asked to un-aggregate the 1.01 then we must fail as we
4654  * won't be able to tell what the new aggregated value's dscale should be.
4655  * We don't want to return 2.00 (dscale = 2), since the sum's dscale would
4656  * have been zero if we'd really aggregated only 2.
4657  *
4658  * Note: alternatively, we could count the number of inputs with each possible
4659  * dscale (up to some sane limit).  Not yet clear if it's worth the trouble.
4660  */
4661 static bool
do_numeric_discard(NumericAggState * state,Numeric newval)4662 do_numeric_discard(NumericAggState *state, Numeric newval)
4663 {
4664 	NumericVar	X;
4665 	NumericVar	X2;
4666 	MemoryContext old_context;
4667 
4668 	/* Count NaN/infinity inputs separately from all else */
4669 	if (NUMERIC_IS_SPECIAL(newval))
4670 	{
4671 		if (NUMERIC_IS_PINF(newval))
4672 			state->pInfcount--;
4673 		else if (NUMERIC_IS_NINF(newval))
4674 			state->nInfcount--;
4675 		else
4676 			state->NaNcount--;
4677 		return true;
4678 	}
4679 
4680 	/* load processed number in short-lived context */
4681 	init_var_from_num(newval, &X);
4682 
4683 	/*
4684 	 * state->sumX's dscale is the maximum dscale of any of the inputs.
4685 	 * Removing the last input with that dscale would require us to recompute
4686 	 * the maximum dscale of the *remaining* inputs, which we cannot do unless
4687 	 * no more non-NaN inputs remain at all.  So we report a failure instead,
4688 	 * and force the aggregation to be redone from scratch.
4689 	 */
4690 	if (X.dscale == state->maxScale)
4691 	{
4692 		if (state->maxScaleCount > 1 || state->maxScale == 0)
4693 		{
4694 			/*
4695 			 * Some remaining inputs have same dscale, or dscale hasn't gotten
4696 			 * above zero anyway
4697 			 */
4698 			state->maxScaleCount--;
4699 		}
4700 		else if (state->N == 1)
4701 		{
4702 			/* No remaining non-NaN inputs at all, so reset maxScale */
4703 			state->maxScale = 0;
4704 			state->maxScaleCount = 0;
4705 		}
4706 		else
4707 		{
4708 			/* Correct new maxScale is uncertain, must fail */
4709 			return false;
4710 		}
4711 	}
4712 
4713 	/* if we need X^2, calculate that in short-lived context */
4714 	if (state->calcSumX2)
4715 	{
4716 		init_var(&X2);
4717 		mul_var(&X, &X, &X2, X.dscale * 2);
4718 	}
4719 
4720 	/* The rest of this needs to work in the aggregate context */
4721 	old_context = MemoryContextSwitchTo(state->agg_context);
4722 
4723 	if (state->N-- > 1)
4724 	{
4725 		/* Negate X, to subtract it from the sum */
4726 		X.sign = (X.sign == NUMERIC_POS ? NUMERIC_NEG : NUMERIC_POS);
4727 		accum_sum_add(&(state->sumX), &X);
4728 
4729 		if (state->calcSumX2)
4730 		{
4731 			/* Negate X^2. X^2 is always positive */
4732 			X2.sign = NUMERIC_NEG;
4733 			accum_sum_add(&(state->sumX2), &X2);
4734 		}
4735 	}
4736 	else
4737 	{
4738 		/* Zero the sums */
4739 		Assert(state->N == 0);
4740 
4741 		accum_sum_reset(&state->sumX);
4742 		if (state->calcSumX2)
4743 			accum_sum_reset(&state->sumX2);
4744 	}
4745 
4746 	MemoryContextSwitchTo(old_context);
4747 
4748 	return true;
4749 }
4750 
4751 /*
4752  * Generic transition function for numeric aggregates that require sumX2.
4753  */
4754 Datum
numeric_accum(PG_FUNCTION_ARGS)4755 numeric_accum(PG_FUNCTION_ARGS)
4756 {
4757 	NumericAggState *state;
4758 
4759 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
4760 
4761 	/* Create the state data on the first call */
4762 	if (state == NULL)
4763 		state = makeNumericAggState(fcinfo, true);
4764 
4765 	if (!PG_ARGISNULL(1))
4766 		do_numeric_accum(state, PG_GETARG_NUMERIC(1));
4767 
4768 	PG_RETURN_POINTER(state);
4769 }
4770 
4771 /*
4772  * Generic combine function for numeric aggregates which require sumX2
4773  */
4774 Datum
numeric_combine(PG_FUNCTION_ARGS)4775 numeric_combine(PG_FUNCTION_ARGS)
4776 {
4777 	NumericAggState *state1;
4778 	NumericAggState *state2;
4779 	MemoryContext agg_context;
4780 	MemoryContext old_context;
4781 
4782 	if (!AggCheckCallContext(fcinfo, &agg_context))
4783 		elog(ERROR, "aggregate function called in non-aggregate context");
4784 
4785 	state1 = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
4786 	state2 = PG_ARGISNULL(1) ? NULL : (NumericAggState *) PG_GETARG_POINTER(1);
4787 
4788 	if (state2 == NULL)
4789 		PG_RETURN_POINTER(state1);
4790 
4791 	/* manually copy all fields from state2 to state1 */
4792 	if (state1 == NULL)
4793 	{
4794 		old_context = MemoryContextSwitchTo(agg_context);
4795 
4796 		state1 = makeNumericAggStateCurrentContext(true);
4797 		state1->N = state2->N;
4798 		state1->NaNcount = state2->NaNcount;
4799 		state1->pInfcount = state2->pInfcount;
4800 		state1->nInfcount = state2->nInfcount;
4801 		state1->maxScale = state2->maxScale;
4802 		state1->maxScaleCount = state2->maxScaleCount;
4803 
4804 		accum_sum_copy(&state1->sumX, &state2->sumX);
4805 		accum_sum_copy(&state1->sumX2, &state2->sumX2);
4806 
4807 		MemoryContextSwitchTo(old_context);
4808 
4809 		PG_RETURN_POINTER(state1);
4810 	}
4811 
4812 	state1->N += state2->N;
4813 	state1->NaNcount += state2->NaNcount;
4814 	state1->pInfcount += state2->pInfcount;
4815 	state1->nInfcount += state2->nInfcount;
4816 
4817 	if (state2->N > 0)
4818 	{
4819 		/*
4820 		 * These are currently only needed for moving aggregates, but let's do
4821 		 * the right thing anyway...
4822 		 */
4823 		if (state2->maxScale > state1->maxScale)
4824 		{
4825 			state1->maxScale = state2->maxScale;
4826 			state1->maxScaleCount = state2->maxScaleCount;
4827 		}
4828 		else if (state2->maxScale == state1->maxScale)
4829 			state1->maxScaleCount += state2->maxScaleCount;
4830 
4831 		/* The rest of this needs to work in the aggregate context */
4832 		old_context = MemoryContextSwitchTo(agg_context);
4833 
4834 		/* Accumulate sums */
4835 		accum_sum_combine(&state1->sumX, &state2->sumX);
4836 		accum_sum_combine(&state1->sumX2, &state2->sumX2);
4837 
4838 		MemoryContextSwitchTo(old_context);
4839 	}
4840 	PG_RETURN_POINTER(state1);
4841 }
4842 
4843 /*
4844  * Generic transition function for numeric aggregates that don't require sumX2.
4845  */
4846 Datum
numeric_avg_accum(PG_FUNCTION_ARGS)4847 numeric_avg_accum(PG_FUNCTION_ARGS)
4848 {
4849 	NumericAggState *state;
4850 
4851 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
4852 
4853 	/* Create the state data on the first call */
4854 	if (state == NULL)
4855 		state = makeNumericAggState(fcinfo, false);
4856 
4857 	if (!PG_ARGISNULL(1))
4858 		do_numeric_accum(state, PG_GETARG_NUMERIC(1));
4859 
4860 	PG_RETURN_POINTER(state);
4861 }
4862 
4863 /*
4864  * Combine function for numeric aggregates which don't require sumX2
4865  */
4866 Datum
numeric_avg_combine(PG_FUNCTION_ARGS)4867 numeric_avg_combine(PG_FUNCTION_ARGS)
4868 {
4869 	NumericAggState *state1;
4870 	NumericAggState *state2;
4871 	MemoryContext agg_context;
4872 	MemoryContext old_context;
4873 
4874 	if (!AggCheckCallContext(fcinfo, &agg_context))
4875 		elog(ERROR, "aggregate function called in non-aggregate context");
4876 
4877 	state1 = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
4878 	state2 = PG_ARGISNULL(1) ? NULL : (NumericAggState *) PG_GETARG_POINTER(1);
4879 
4880 	if (state2 == NULL)
4881 		PG_RETURN_POINTER(state1);
4882 
4883 	/* manually copy all fields from state2 to state1 */
4884 	if (state1 == NULL)
4885 	{
4886 		old_context = MemoryContextSwitchTo(agg_context);
4887 
4888 		state1 = makeNumericAggStateCurrentContext(false);
4889 		state1->N = state2->N;
4890 		state1->NaNcount = state2->NaNcount;
4891 		state1->pInfcount = state2->pInfcount;
4892 		state1->nInfcount = state2->nInfcount;
4893 		state1->maxScale = state2->maxScale;
4894 		state1->maxScaleCount = state2->maxScaleCount;
4895 
4896 		accum_sum_copy(&state1->sumX, &state2->sumX);
4897 
4898 		MemoryContextSwitchTo(old_context);
4899 
4900 		PG_RETURN_POINTER(state1);
4901 	}
4902 
4903 	state1->N += state2->N;
4904 	state1->NaNcount += state2->NaNcount;
4905 	state1->pInfcount += state2->pInfcount;
4906 	state1->nInfcount += state2->nInfcount;
4907 
4908 	if (state2->N > 0)
4909 	{
4910 		/*
4911 		 * These are currently only needed for moving aggregates, but let's do
4912 		 * the right thing anyway...
4913 		 */
4914 		if (state2->maxScale > state1->maxScale)
4915 		{
4916 			state1->maxScale = state2->maxScale;
4917 			state1->maxScaleCount = state2->maxScaleCount;
4918 		}
4919 		else if (state2->maxScale == state1->maxScale)
4920 			state1->maxScaleCount += state2->maxScaleCount;
4921 
4922 		/* The rest of this needs to work in the aggregate context */
4923 		old_context = MemoryContextSwitchTo(agg_context);
4924 
4925 		/* Accumulate sums */
4926 		accum_sum_combine(&state1->sumX, &state2->sumX);
4927 
4928 		MemoryContextSwitchTo(old_context);
4929 	}
4930 	PG_RETURN_POINTER(state1);
4931 }
4932 
4933 /*
4934  * numeric_avg_serialize
4935  *		Serialize NumericAggState for numeric aggregates that don't require
4936  *		sumX2.
4937  */
4938 Datum
numeric_avg_serialize(PG_FUNCTION_ARGS)4939 numeric_avg_serialize(PG_FUNCTION_ARGS)
4940 {
4941 	NumericAggState *state;
4942 	StringInfoData buf;
4943 	Datum		temp;
4944 	bytea	   *sumX;
4945 	bytea	   *result;
4946 	NumericVar	tmp_var;
4947 
4948 	/* Ensure we disallow calling when not in aggregate context */
4949 	if (!AggCheckCallContext(fcinfo, NULL))
4950 		elog(ERROR, "aggregate function called in non-aggregate context");
4951 
4952 	state = (NumericAggState *) PG_GETARG_POINTER(0);
4953 
4954 	/*
4955 	 * This is a little wasteful since make_result converts the NumericVar
4956 	 * into a Numeric and numeric_send converts it back again. Is it worth
4957 	 * splitting the tasks in numeric_send into separate functions to stop
4958 	 * this? Doing so would also remove the fmgr call overhead.
4959 	 */
4960 	init_var(&tmp_var);
4961 	accum_sum_final(&state->sumX, &tmp_var);
4962 
4963 	temp = DirectFunctionCall1(numeric_send,
4964 							   NumericGetDatum(make_result(&tmp_var)));
4965 	sumX = DatumGetByteaPP(temp);
4966 	free_var(&tmp_var);
4967 
4968 	pq_begintypsend(&buf);
4969 
4970 	/* N */
4971 	pq_sendint64(&buf, state->N);
4972 
4973 	/* sumX */
4974 	pq_sendbytes(&buf, VARDATA_ANY(sumX), VARSIZE_ANY_EXHDR(sumX));
4975 
4976 	/* maxScale */
4977 	pq_sendint32(&buf, state->maxScale);
4978 
4979 	/* maxScaleCount */
4980 	pq_sendint64(&buf, state->maxScaleCount);
4981 
4982 	/* NaNcount */
4983 	pq_sendint64(&buf, state->NaNcount);
4984 
4985 	/* pInfcount */
4986 	pq_sendint64(&buf, state->pInfcount);
4987 
4988 	/* nInfcount */
4989 	pq_sendint64(&buf, state->nInfcount);
4990 
4991 	result = pq_endtypsend(&buf);
4992 
4993 	PG_RETURN_BYTEA_P(result);
4994 }
4995 
4996 /*
4997  * numeric_avg_deserialize
4998  *		Deserialize bytea into NumericAggState for numeric aggregates that
4999  *		don't require sumX2.
5000  */
5001 Datum
numeric_avg_deserialize(PG_FUNCTION_ARGS)5002 numeric_avg_deserialize(PG_FUNCTION_ARGS)
5003 {
5004 	bytea	   *sstate;
5005 	NumericAggState *result;
5006 	Datum		temp;
5007 	NumericVar	tmp_var;
5008 	StringInfoData buf;
5009 
5010 	if (!AggCheckCallContext(fcinfo, NULL))
5011 		elog(ERROR, "aggregate function called in non-aggregate context");
5012 
5013 	sstate = PG_GETARG_BYTEA_PP(0);
5014 
5015 	/*
5016 	 * Copy the bytea into a StringInfo so that we can "receive" it using the
5017 	 * standard recv-function infrastructure.
5018 	 */
5019 	initStringInfo(&buf);
5020 	appendBinaryStringInfo(&buf,
5021 						   VARDATA_ANY(sstate), VARSIZE_ANY_EXHDR(sstate));
5022 
5023 	result = makeNumericAggStateCurrentContext(false);
5024 
5025 	/* N */
5026 	result->N = pq_getmsgint64(&buf);
5027 
5028 	/* sumX */
5029 	temp = DirectFunctionCall3(numeric_recv,
5030 							   PointerGetDatum(&buf),
5031 							   ObjectIdGetDatum(InvalidOid),
5032 							   Int32GetDatum(-1));
5033 	init_var_from_num(DatumGetNumeric(temp), &tmp_var);
5034 	accum_sum_add(&(result->sumX), &tmp_var);
5035 
5036 	/* maxScale */
5037 	result->maxScale = pq_getmsgint(&buf, 4);
5038 
5039 	/* maxScaleCount */
5040 	result->maxScaleCount = pq_getmsgint64(&buf);
5041 
5042 	/* NaNcount */
5043 	result->NaNcount = pq_getmsgint64(&buf);
5044 
5045 	/* pInfcount */
5046 	result->pInfcount = pq_getmsgint64(&buf);
5047 
5048 	/* nInfcount */
5049 	result->nInfcount = pq_getmsgint64(&buf);
5050 
5051 	pq_getmsgend(&buf);
5052 	pfree(buf.data);
5053 
5054 	PG_RETURN_POINTER(result);
5055 }
5056 
5057 /*
5058  * numeric_serialize
5059  *		Serialization function for NumericAggState for numeric aggregates that
5060  *		require sumX2.
5061  */
5062 Datum
numeric_serialize(PG_FUNCTION_ARGS)5063 numeric_serialize(PG_FUNCTION_ARGS)
5064 {
5065 	NumericAggState *state;
5066 	StringInfoData buf;
5067 	Datum		temp;
5068 	bytea	   *sumX;
5069 	NumericVar	tmp_var;
5070 	bytea	   *sumX2;
5071 	bytea	   *result;
5072 
5073 	/* Ensure we disallow calling when not in aggregate context */
5074 	if (!AggCheckCallContext(fcinfo, NULL))
5075 		elog(ERROR, "aggregate function called in non-aggregate context");
5076 
5077 	state = (NumericAggState *) PG_GETARG_POINTER(0);
5078 
5079 	/*
5080 	 * This is a little wasteful since make_result converts the NumericVar
5081 	 * into a Numeric and numeric_send converts it back again. Is it worth
5082 	 * splitting the tasks in numeric_send into separate functions to stop
5083 	 * this? Doing so would also remove the fmgr call overhead.
5084 	 */
5085 	init_var(&tmp_var);
5086 
5087 	accum_sum_final(&state->sumX, &tmp_var);
5088 	temp = DirectFunctionCall1(numeric_send,
5089 							   NumericGetDatum(make_result(&tmp_var)));
5090 	sumX = DatumGetByteaPP(temp);
5091 
5092 	accum_sum_final(&state->sumX2, &tmp_var);
5093 	temp = DirectFunctionCall1(numeric_send,
5094 							   NumericGetDatum(make_result(&tmp_var)));
5095 	sumX2 = DatumGetByteaPP(temp);
5096 
5097 	free_var(&tmp_var);
5098 
5099 	pq_begintypsend(&buf);
5100 
5101 	/* N */
5102 	pq_sendint64(&buf, state->N);
5103 
5104 	/* sumX */
5105 	pq_sendbytes(&buf, VARDATA_ANY(sumX), VARSIZE_ANY_EXHDR(sumX));
5106 
5107 	/* sumX2 */
5108 	pq_sendbytes(&buf, VARDATA_ANY(sumX2), VARSIZE_ANY_EXHDR(sumX2));
5109 
5110 	/* maxScale */
5111 	pq_sendint32(&buf, state->maxScale);
5112 
5113 	/* maxScaleCount */
5114 	pq_sendint64(&buf, state->maxScaleCount);
5115 
5116 	/* NaNcount */
5117 	pq_sendint64(&buf, state->NaNcount);
5118 
5119 	/* pInfcount */
5120 	pq_sendint64(&buf, state->pInfcount);
5121 
5122 	/* nInfcount */
5123 	pq_sendint64(&buf, state->nInfcount);
5124 
5125 	result = pq_endtypsend(&buf);
5126 
5127 	PG_RETURN_BYTEA_P(result);
5128 }
5129 
5130 /*
5131  * numeric_deserialize
5132  *		Deserialization function for NumericAggState for numeric aggregates that
5133  *		require sumX2.
5134  */
5135 Datum
numeric_deserialize(PG_FUNCTION_ARGS)5136 numeric_deserialize(PG_FUNCTION_ARGS)
5137 {
5138 	bytea	   *sstate;
5139 	NumericAggState *result;
5140 	Datum		temp;
5141 	NumericVar	sumX_var;
5142 	NumericVar	sumX2_var;
5143 	StringInfoData buf;
5144 
5145 	if (!AggCheckCallContext(fcinfo, NULL))
5146 		elog(ERROR, "aggregate function called in non-aggregate context");
5147 
5148 	sstate = PG_GETARG_BYTEA_PP(0);
5149 
5150 	/*
5151 	 * Copy the bytea into a StringInfo so that we can "receive" it using the
5152 	 * standard recv-function infrastructure.
5153 	 */
5154 	initStringInfo(&buf);
5155 	appendBinaryStringInfo(&buf,
5156 						   VARDATA_ANY(sstate), VARSIZE_ANY_EXHDR(sstate));
5157 
5158 	result = makeNumericAggStateCurrentContext(false);
5159 
5160 	/* N */
5161 	result->N = pq_getmsgint64(&buf);
5162 
5163 	/* sumX */
5164 	temp = DirectFunctionCall3(numeric_recv,
5165 							   PointerGetDatum(&buf),
5166 							   ObjectIdGetDatum(InvalidOid),
5167 							   Int32GetDatum(-1));
5168 	init_var_from_num(DatumGetNumeric(temp), &sumX_var);
5169 	accum_sum_add(&(result->sumX), &sumX_var);
5170 
5171 	/* sumX2 */
5172 	temp = DirectFunctionCall3(numeric_recv,
5173 							   PointerGetDatum(&buf),
5174 							   ObjectIdGetDatum(InvalidOid),
5175 							   Int32GetDatum(-1));
5176 	init_var_from_num(DatumGetNumeric(temp), &sumX2_var);
5177 	accum_sum_add(&(result->sumX2), &sumX2_var);
5178 
5179 	/* maxScale */
5180 	result->maxScale = pq_getmsgint(&buf, 4);
5181 
5182 	/* maxScaleCount */
5183 	result->maxScaleCount = pq_getmsgint64(&buf);
5184 
5185 	/* NaNcount */
5186 	result->NaNcount = pq_getmsgint64(&buf);
5187 
5188 	/* pInfcount */
5189 	result->pInfcount = pq_getmsgint64(&buf);
5190 
5191 	/* nInfcount */
5192 	result->nInfcount = pq_getmsgint64(&buf);
5193 
5194 	pq_getmsgend(&buf);
5195 	pfree(buf.data);
5196 
5197 	PG_RETURN_POINTER(result);
5198 }
5199 
5200 /*
5201  * Generic inverse transition function for numeric aggregates
5202  * (with or without requirement for X^2).
5203  */
5204 Datum
numeric_accum_inv(PG_FUNCTION_ARGS)5205 numeric_accum_inv(PG_FUNCTION_ARGS)
5206 {
5207 	NumericAggState *state;
5208 
5209 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
5210 
5211 	/* Should not get here with no state */
5212 	if (state == NULL)
5213 		elog(ERROR, "numeric_accum_inv called with NULL state");
5214 
5215 	if (!PG_ARGISNULL(1))
5216 	{
5217 		/* If we fail to perform the inverse transition, return NULL */
5218 		if (!do_numeric_discard(state, PG_GETARG_NUMERIC(1)))
5219 			PG_RETURN_NULL();
5220 	}
5221 
5222 	PG_RETURN_POINTER(state);
5223 }
5224 
5225 
5226 /*
5227  * Integer data types in general use Numeric accumulators to share code
5228  * and avoid risk of overflow.
5229  *
5230  * However for performance reasons optimized special-purpose accumulator
5231  * routines are used when possible.
5232  *
 * On platforms with 128-bit integer support, the 128-bit routines will be
 * used when sum(X) or sum(X*X) fits into 128 bits.
5235  *
5236  * For 16 and 32 bit inputs, the N and sum(X) fit into 64-bit so the 64-bit
5237  * accumulators will be used for SUM and AVG of these data types.
5238  */
5239 
#ifdef HAVE_INT128
/*
 * Transition state used by the integer aggregates when the platform has a
 * native 128-bit integer type.  sumX2 is only maintained when calcSumX2 is
 * set (see do_int128_accum / do_int128_discard).
 */
typedef struct Int128AggState
{
	bool		calcSumX2;		/* if true, calculate sumX2 */
	int64		N;				/* count of processed numbers */
	int128		sumX;			/* sum of processed numbers */
	int128		sumX2;			/* sum of squares of processed numbers */
} Int128AggState;
5248 
5249 /*
5250  * Prepare state data for a 128-bit aggregate function that needs to compute
5251  * sum, count and optionally sum of squares of the input.
5252  */
5253 static Int128AggState *
makeInt128AggState(FunctionCallInfo fcinfo,bool calcSumX2)5254 makeInt128AggState(FunctionCallInfo fcinfo, bool calcSumX2)
5255 {
5256 	Int128AggState *state;
5257 	MemoryContext agg_context;
5258 	MemoryContext old_context;
5259 
5260 	if (!AggCheckCallContext(fcinfo, &agg_context))
5261 		elog(ERROR, "aggregate function called in non-aggregate context");
5262 
5263 	old_context = MemoryContextSwitchTo(agg_context);
5264 
5265 	state = (Int128AggState *) palloc0(sizeof(Int128AggState));
5266 	state->calcSumX2 = calcSumX2;
5267 
5268 	MemoryContextSwitchTo(old_context);
5269 
5270 	return state;
5271 }
5272 
5273 /*
5274  * Like makeInt128AggState(), but allocate the state in the current memory
5275  * context.
5276  */
5277 static Int128AggState *
makeInt128AggStateCurrentContext(bool calcSumX2)5278 makeInt128AggStateCurrentContext(bool calcSumX2)
5279 {
5280 	Int128AggState *state;
5281 
5282 	state = (Int128AggState *) palloc0(sizeof(Int128AggState));
5283 	state->calcSumX2 = calcSumX2;
5284 
5285 	return state;
5286 }
5287 
5288 /*
5289  * Accumulate a new input value for 128-bit aggregate functions.
5290  */
5291 static void
do_int128_accum(Int128AggState * state,int128 newval)5292 do_int128_accum(Int128AggState *state, int128 newval)
5293 {
5294 	if (state->calcSumX2)
5295 		state->sumX2 += newval * newval;
5296 
5297 	state->sumX += newval;
5298 	state->N++;
5299 }
5300 
5301 /*
5302  * Remove an input value from the aggregated state.
5303  */
5304 static void
do_int128_discard(Int128AggState * state,int128 newval)5305 do_int128_discard(Int128AggState *state, int128 newval)
5306 {
5307 	if (state->calcSumX2)
5308 		state->sumX2 -= newval * newval;
5309 
5310 	state->sumX -= newval;
5311 	state->N--;
5312 }
5313 
/*
 * PolyNumAggState is the transition state type of the "poly" integer
 * aggregates: Int128AggState when int128 is available, otherwise
 * NumericAggState.  makePolyNumAggState and makePolyNumAggStateCurrentContext
 * map to the constructors for whichever type is selected.
 */
typedef Int128AggState PolyNumAggState;
#define makePolyNumAggState makeInt128AggState
#define makePolyNumAggStateCurrentContext makeInt128AggStateCurrentContext
#else
typedef NumericAggState PolyNumAggState;
#define makePolyNumAggState makeNumericAggState
#define makePolyNumAggStateCurrentContext makeNumericAggStateCurrentContext
#endif
5322 
5323 Datum
int2_accum(PG_FUNCTION_ARGS)5324 int2_accum(PG_FUNCTION_ARGS)
5325 {
5326 	PolyNumAggState *state;
5327 
5328 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
5329 
5330 	/* Create the state data on the first call */
5331 	if (state == NULL)
5332 		state = makePolyNumAggState(fcinfo, true);
5333 
5334 	if (!PG_ARGISNULL(1))
5335 	{
5336 #ifdef HAVE_INT128
5337 		do_int128_accum(state, (int128) PG_GETARG_INT16(1));
5338 #else
5339 		do_numeric_accum(state, int64_to_numeric(PG_GETARG_INT16(1)));
5340 #endif
5341 	}
5342 
5343 	PG_RETURN_POINTER(state);
5344 }
5345 
5346 Datum
int4_accum(PG_FUNCTION_ARGS)5347 int4_accum(PG_FUNCTION_ARGS)
5348 {
5349 	PolyNumAggState *state;
5350 
5351 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
5352 
5353 	/* Create the state data on the first call */
5354 	if (state == NULL)
5355 		state = makePolyNumAggState(fcinfo, true);
5356 
5357 	if (!PG_ARGISNULL(1))
5358 	{
5359 #ifdef HAVE_INT128
5360 		do_int128_accum(state, (int128) PG_GETARG_INT32(1));
5361 #else
5362 		do_numeric_accum(state, int64_to_numeric(PG_GETARG_INT32(1)));
5363 #endif
5364 	}
5365 
5366 	PG_RETURN_POINTER(state);
5367 }
5368 
5369 Datum
int8_accum(PG_FUNCTION_ARGS)5370 int8_accum(PG_FUNCTION_ARGS)
5371 {
5372 	NumericAggState *state;
5373 
5374 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
5375 
5376 	/* Create the state data on the first call */
5377 	if (state == NULL)
5378 		state = makeNumericAggState(fcinfo, true);
5379 
5380 	if (!PG_ARGISNULL(1))
5381 		do_numeric_accum(state, int64_to_numeric(PG_GETARG_INT64(1)));
5382 
5383 	PG_RETURN_POINTER(state);
5384 }
5385 
5386 /*
5387  * Combine function for numeric aggregates which require sumX2
5388  */
5389 Datum
numeric_poly_combine(PG_FUNCTION_ARGS)5390 numeric_poly_combine(PG_FUNCTION_ARGS)
5391 {
5392 	PolyNumAggState *state1;
5393 	PolyNumAggState *state2;
5394 	MemoryContext agg_context;
5395 	MemoryContext old_context;
5396 
5397 	if (!AggCheckCallContext(fcinfo, &agg_context))
5398 		elog(ERROR, "aggregate function called in non-aggregate context");
5399 
5400 	state1 = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
5401 	state2 = PG_ARGISNULL(1) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(1);
5402 
5403 	if (state2 == NULL)
5404 		PG_RETURN_POINTER(state1);
5405 
5406 	/* manually copy all fields from state2 to state1 */
5407 	if (state1 == NULL)
5408 	{
5409 		old_context = MemoryContextSwitchTo(agg_context);
5410 
5411 		state1 = makePolyNumAggState(fcinfo, true);
5412 		state1->N = state2->N;
5413 
5414 #ifdef HAVE_INT128
5415 		state1->sumX = state2->sumX;
5416 		state1->sumX2 = state2->sumX2;
5417 #else
5418 		accum_sum_copy(&state1->sumX, &state2->sumX);
5419 		accum_sum_copy(&state1->sumX2, &state2->sumX2);
5420 #endif
5421 
5422 		MemoryContextSwitchTo(old_context);
5423 
5424 		PG_RETURN_POINTER(state1);
5425 	}
5426 
5427 	if (state2->N > 0)
5428 	{
5429 		state1->N += state2->N;
5430 
5431 #ifdef HAVE_INT128
5432 		state1->sumX += state2->sumX;
5433 		state1->sumX2 += state2->sumX2;
5434 #else
5435 		/* The rest of this needs to work in the aggregate context */
5436 		old_context = MemoryContextSwitchTo(agg_context);
5437 
5438 		/* Accumulate sums */
5439 		accum_sum_combine(&state1->sumX, &state2->sumX);
5440 		accum_sum_combine(&state1->sumX2, &state2->sumX2);
5441 
5442 		MemoryContextSwitchTo(old_context);
5443 #endif
5444 
5445 	}
5446 	PG_RETURN_POINTER(state1);
5447 }
5448 
5449 /*
5450  * numeric_poly_serialize
5451  *		Serialize PolyNumAggState into bytea for aggregate functions which
5452  *		require sumX2.
5453  */
5454 Datum
numeric_poly_serialize(PG_FUNCTION_ARGS)5455 numeric_poly_serialize(PG_FUNCTION_ARGS)
5456 {
5457 	PolyNumAggState *state;
5458 	StringInfoData buf;
5459 	bytea	   *sumX;
5460 	bytea	   *sumX2;
5461 	bytea	   *result;
5462 
5463 	/* Ensure we disallow calling when not in aggregate context */
5464 	if (!AggCheckCallContext(fcinfo, NULL))
5465 		elog(ERROR, "aggregate function called in non-aggregate context");
5466 
5467 	state = (PolyNumAggState *) PG_GETARG_POINTER(0);
5468 
5469 	/*
5470 	 * If the platform supports int128 then sumX and sumX2 will be a 128 bit
5471 	 * integer type. Here we'll convert that into a numeric type so that the
5472 	 * combine state is in the same format for both int128 enabled machines
5473 	 * and machines which don't support that type. The logic here is that one
5474 	 * day we might like to send these over to another server for further
5475 	 * processing and we want a standard format to work with.
5476 	 */
5477 	{
5478 		Datum		temp;
5479 		NumericVar	num;
5480 
5481 		init_var(&num);
5482 
5483 #ifdef HAVE_INT128
5484 		int128_to_numericvar(state->sumX, &num);
5485 #else
5486 		accum_sum_final(&state->sumX, &num);
5487 #endif
5488 		temp = DirectFunctionCall1(numeric_send,
5489 								   NumericGetDatum(make_result(&num)));
5490 		sumX = DatumGetByteaPP(temp);
5491 
5492 #ifdef HAVE_INT128
5493 		int128_to_numericvar(state->sumX2, &num);
5494 #else
5495 		accum_sum_final(&state->sumX2, &num);
5496 #endif
5497 		temp = DirectFunctionCall1(numeric_send,
5498 								   NumericGetDatum(make_result(&num)));
5499 		sumX2 = DatumGetByteaPP(temp);
5500 
5501 		free_var(&num);
5502 	}
5503 
5504 	pq_begintypsend(&buf);
5505 
5506 	/* N */
5507 	pq_sendint64(&buf, state->N);
5508 
5509 	/* sumX */
5510 	pq_sendbytes(&buf, VARDATA_ANY(sumX), VARSIZE_ANY_EXHDR(sumX));
5511 
5512 	/* sumX2 */
5513 	pq_sendbytes(&buf, VARDATA_ANY(sumX2), VARSIZE_ANY_EXHDR(sumX2));
5514 
5515 	result = pq_endtypsend(&buf);
5516 
5517 	PG_RETURN_BYTEA_P(result);
5518 }
5519 
5520 /*
5521  * numeric_poly_deserialize
5522  *		Deserialize PolyNumAggState from bytea for aggregate functions which
5523  *		require sumX2.
5524  */
5525 Datum
numeric_poly_deserialize(PG_FUNCTION_ARGS)5526 numeric_poly_deserialize(PG_FUNCTION_ARGS)
5527 {
5528 	bytea	   *sstate;
5529 	PolyNumAggState *result;
5530 	Datum		sumX;
5531 	NumericVar	sumX_var;
5532 	Datum		sumX2;
5533 	NumericVar	sumX2_var;
5534 	StringInfoData buf;
5535 
5536 	if (!AggCheckCallContext(fcinfo, NULL))
5537 		elog(ERROR, "aggregate function called in non-aggregate context");
5538 
5539 	sstate = PG_GETARG_BYTEA_PP(0);
5540 
5541 	/*
5542 	 * Copy the bytea into a StringInfo so that we can "receive" it using the
5543 	 * standard recv-function infrastructure.
5544 	 */
5545 	initStringInfo(&buf);
5546 	appendBinaryStringInfo(&buf,
5547 						   VARDATA_ANY(sstate), VARSIZE_ANY_EXHDR(sstate));
5548 
5549 	result = makePolyNumAggStateCurrentContext(false);
5550 
5551 	/* N */
5552 	result->N = pq_getmsgint64(&buf);
5553 
5554 	/* sumX */
5555 	sumX = DirectFunctionCall3(numeric_recv,
5556 							   PointerGetDatum(&buf),
5557 							   ObjectIdGetDatum(InvalidOid),
5558 							   Int32GetDatum(-1));
5559 
5560 	/* sumX2 */
5561 	sumX2 = DirectFunctionCall3(numeric_recv,
5562 								PointerGetDatum(&buf),
5563 								ObjectIdGetDatum(InvalidOid),
5564 								Int32GetDatum(-1));
5565 
5566 	init_var_from_num(DatumGetNumeric(sumX), &sumX_var);
5567 #ifdef HAVE_INT128
5568 	numericvar_to_int128(&sumX_var, &result->sumX);
5569 #else
5570 	accum_sum_add(&result->sumX, &sumX_var);
5571 #endif
5572 
5573 	init_var_from_num(DatumGetNumeric(sumX2), &sumX2_var);
5574 #ifdef HAVE_INT128
5575 	numericvar_to_int128(&sumX2_var, &result->sumX2);
5576 #else
5577 	accum_sum_add(&result->sumX2, &sumX2_var);
5578 #endif
5579 
5580 	pq_getmsgend(&buf);
5581 	pfree(buf.data);
5582 
5583 	PG_RETURN_POINTER(result);
5584 }
5585 
5586 /*
5587  * Transition function for int8 input when we don't need sumX2.
5588  */
5589 Datum
int8_avg_accum(PG_FUNCTION_ARGS)5590 int8_avg_accum(PG_FUNCTION_ARGS)
5591 {
5592 	PolyNumAggState *state;
5593 
5594 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
5595 
5596 	/* Create the state data on the first call */
5597 	if (state == NULL)
5598 		state = makePolyNumAggState(fcinfo, false);
5599 
5600 	if (!PG_ARGISNULL(1))
5601 	{
5602 #ifdef HAVE_INT128
5603 		do_int128_accum(state, (int128) PG_GETARG_INT64(1));
5604 #else
5605 		do_numeric_accum(state, int64_to_numeric(PG_GETARG_INT64(1)));
5606 #endif
5607 	}
5608 
5609 	PG_RETURN_POINTER(state);
5610 }
5611 
5612 /*
5613  * Combine function for PolyNumAggState for aggregates which don't require
5614  * sumX2
5615  */
5616 Datum
int8_avg_combine(PG_FUNCTION_ARGS)5617 int8_avg_combine(PG_FUNCTION_ARGS)
5618 {
5619 	PolyNumAggState *state1;
5620 	PolyNumAggState *state2;
5621 	MemoryContext agg_context;
5622 	MemoryContext old_context;
5623 
5624 	if (!AggCheckCallContext(fcinfo, &agg_context))
5625 		elog(ERROR, "aggregate function called in non-aggregate context");
5626 
5627 	state1 = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
5628 	state2 = PG_ARGISNULL(1) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(1);
5629 
5630 	if (state2 == NULL)
5631 		PG_RETURN_POINTER(state1);
5632 
5633 	/* manually copy all fields from state2 to state1 */
5634 	if (state1 == NULL)
5635 	{
5636 		old_context = MemoryContextSwitchTo(agg_context);
5637 
5638 		state1 = makePolyNumAggState(fcinfo, false);
5639 		state1->N = state2->N;
5640 
5641 #ifdef HAVE_INT128
5642 		state1->sumX = state2->sumX;
5643 #else
5644 		accum_sum_copy(&state1->sumX, &state2->sumX);
5645 #endif
5646 		MemoryContextSwitchTo(old_context);
5647 
5648 		PG_RETURN_POINTER(state1);
5649 	}
5650 
5651 	if (state2->N > 0)
5652 	{
5653 		state1->N += state2->N;
5654 
5655 #ifdef HAVE_INT128
5656 		state1->sumX += state2->sumX;
5657 #else
5658 		/* The rest of this needs to work in the aggregate context */
5659 		old_context = MemoryContextSwitchTo(agg_context);
5660 
5661 		/* Accumulate sums */
5662 		accum_sum_combine(&state1->sumX, &state2->sumX);
5663 
5664 		MemoryContextSwitchTo(old_context);
5665 #endif
5666 
5667 	}
5668 	PG_RETURN_POINTER(state1);
5669 }
5670 
5671 /*
5672  * int8_avg_serialize
5673  *		Serialize PolyNumAggState into bytea using the standard
5674  *		recv-function infrastructure.
5675  */
5676 Datum
int8_avg_serialize(PG_FUNCTION_ARGS)5677 int8_avg_serialize(PG_FUNCTION_ARGS)
5678 {
5679 	PolyNumAggState *state;
5680 	StringInfoData buf;
5681 	bytea	   *sumX;
5682 	bytea	   *result;
5683 
5684 	/* Ensure we disallow calling when not in aggregate context */
5685 	if (!AggCheckCallContext(fcinfo, NULL))
5686 		elog(ERROR, "aggregate function called in non-aggregate context");
5687 
5688 	state = (PolyNumAggState *) PG_GETARG_POINTER(0);
5689 
5690 	/*
5691 	 * If the platform supports int128 then sumX will be a 128 integer type.
5692 	 * Here we'll convert that into a numeric type so that the combine state
5693 	 * is in the same format for both int128 enabled machines and machines
5694 	 * which don't support that type. The logic here is that one day we might
5695 	 * like to send these over to another server for further processing and we
5696 	 * want a standard format to work with.
5697 	 */
5698 	{
5699 		Datum		temp;
5700 		NumericVar	num;
5701 
5702 		init_var(&num);
5703 
5704 #ifdef HAVE_INT128
5705 		int128_to_numericvar(state->sumX, &num);
5706 #else
5707 		accum_sum_final(&state->sumX, &num);
5708 #endif
5709 		temp = DirectFunctionCall1(numeric_send,
5710 								   NumericGetDatum(make_result(&num)));
5711 		sumX = DatumGetByteaPP(temp);
5712 
5713 		free_var(&num);
5714 	}
5715 
5716 	pq_begintypsend(&buf);
5717 
5718 	/* N */
5719 	pq_sendint64(&buf, state->N);
5720 
5721 	/* sumX */
5722 	pq_sendbytes(&buf, VARDATA_ANY(sumX), VARSIZE_ANY_EXHDR(sumX));
5723 
5724 	result = pq_endtypsend(&buf);
5725 
5726 	PG_RETURN_BYTEA_P(result);
5727 }
5728 
5729 /*
5730  * int8_avg_deserialize
5731  *		Deserialize bytea back into PolyNumAggState.
5732  */
5733 Datum
int8_avg_deserialize(PG_FUNCTION_ARGS)5734 int8_avg_deserialize(PG_FUNCTION_ARGS)
5735 {
5736 	bytea	   *sstate;
5737 	PolyNumAggState *result;
5738 	StringInfoData buf;
5739 	Datum		temp;
5740 	NumericVar	num;
5741 
5742 	if (!AggCheckCallContext(fcinfo, NULL))
5743 		elog(ERROR, "aggregate function called in non-aggregate context");
5744 
5745 	sstate = PG_GETARG_BYTEA_PP(0);
5746 
5747 	/*
5748 	 * Copy the bytea into a StringInfo so that we can "receive" it using the
5749 	 * standard recv-function infrastructure.
5750 	 */
5751 	initStringInfo(&buf);
5752 	appendBinaryStringInfo(&buf,
5753 						   VARDATA_ANY(sstate), VARSIZE_ANY_EXHDR(sstate));
5754 
5755 	result = makePolyNumAggStateCurrentContext(false);
5756 
5757 	/* N */
5758 	result->N = pq_getmsgint64(&buf);
5759 
5760 	/* sumX */
5761 	temp = DirectFunctionCall3(numeric_recv,
5762 							   PointerGetDatum(&buf),
5763 							   ObjectIdGetDatum(InvalidOid),
5764 							   Int32GetDatum(-1));
5765 	init_var_from_num(DatumGetNumeric(temp), &num);
5766 #ifdef HAVE_INT128
5767 	numericvar_to_int128(&num, &result->sumX);
5768 #else
5769 	accum_sum_add(&result->sumX, &num);
5770 #endif
5771 
5772 	pq_getmsgend(&buf);
5773 	pfree(buf.data);
5774 
5775 	PG_RETURN_POINTER(result);
5776 }
5777 
5778 /*
5779  * Inverse transition functions to go with the above.
5780  */
5781 
5782 Datum
int2_accum_inv(PG_FUNCTION_ARGS)5783 int2_accum_inv(PG_FUNCTION_ARGS)
5784 {
5785 	PolyNumAggState *state;
5786 
5787 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
5788 
5789 	/* Should not get here with no state */
5790 	if (state == NULL)
5791 		elog(ERROR, "int2_accum_inv called with NULL state");
5792 
5793 	if (!PG_ARGISNULL(1))
5794 	{
5795 #ifdef HAVE_INT128
5796 		do_int128_discard(state, (int128) PG_GETARG_INT16(1));
5797 #else
5798 		/* Should never fail, all inputs have dscale 0 */
5799 		if (!do_numeric_discard(state, int64_to_numeric(PG_GETARG_INT16(1))))
5800 			elog(ERROR, "do_numeric_discard failed unexpectedly");
5801 #endif
5802 	}
5803 
5804 	PG_RETURN_POINTER(state);
5805 }
5806 
5807 Datum
int4_accum_inv(PG_FUNCTION_ARGS)5808 int4_accum_inv(PG_FUNCTION_ARGS)
5809 {
5810 	PolyNumAggState *state;
5811 
5812 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
5813 
5814 	/* Should not get here with no state */
5815 	if (state == NULL)
5816 		elog(ERROR, "int4_accum_inv called with NULL state");
5817 
5818 	if (!PG_ARGISNULL(1))
5819 	{
5820 #ifdef HAVE_INT128
5821 		do_int128_discard(state, (int128) PG_GETARG_INT32(1));
5822 #else
5823 		/* Should never fail, all inputs have dscale 0 */
5824 		if (!do_numeric_discard(state, int64_to_numeric(PG_GETARG_INT32(1))))
5825 			elog(ERROR, "do_numeric_discard failed unexpectedly");
5826 #endif
5827 	}
5828 
5829 	PG_RETURN_POINTER(state);
5830 }
5831 
5832 Datum
int8_accum_inv(PG_FUNCTION_ARGS)5833 int8_accum_inv(PG_FUNCTION_ARGS)
5834 {
5835 	NumericAggState *state;
5836 
5837 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
5838 
5839 	/* Should not get here with no state */
5840 	if (state == NULL)
5841 		elog(ERROR, "int8_accum_inv called with NULL state");
5842 
5843 	if (!PG_ARGISNULL(1))
5844 	{
5845 		/* Should never fail, all inputs have dscale 0 */
5846 		if (!do_numeric_discard(state, int64_to_numeric(PG_GETARG_INT64(1))))
5847 			elog(ERROR, "do_numeric_discard failed unexpectedly");
5848 	}
5849 
5850 	PG_RETURN_POINTER(state);
5851 }
5852 
5853 Datum
int8_avg_accum_inv(PG_FUNCTION_ARGS)5854 int8_avg_accum_inv(PG_FUNCTION_ARGS)
5855 {
5856 	PolyNumAggState *state;
5857 
5858 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
5859 
5860 	/* Should not get here with no state */
5861 	if (state == NULL)
5862 		elog(ERROR, "int8_avg_accum_inv called with NULL state");
5863 
5864 	if (!PG_ARGISNULL(1))
5865 	{
5866 #ifdef HAVE_INT128
5867 		do_int128_discard(state, (int128) PG_GETARG_INT64(1));
5868 #else
5869 		/* Should never fail, all inputs have dscale 0 */
5870 		if (!do_numeric_discard(state, int64_to_numeric(PG_GETARG_INT64(1))))
5871 			elog(ERROR, "do_numeric_discard failed unexpectedly");
5872 #endif
5873 	}
5874 
5875 	PG_RETURN_POINTER(state);
5876 }
5877 
5878 Datum
numeric_poly_sum(PG_FUNCTION_ARGS)5879 numeric_poly_sum(PG_FUNCTION_ARGS)
5880 {
5881 #ifdef HAVE_INT128
5882 	PolyNumAggState *state;
5883 	Numeric		res;
5884 	NumericVar	result;
5885 
5886 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
5887 
5888 	/* If there were no non-null inputs, return NULL */
5889 	if (state == NULL || state->N == 0)
5890 		PG_RETURN_NULL();
5891 
5892 	init_var(&result);
5893 
5894 	int128_to_numericvar(state->sumX, &result);
5895 
5896 	res = make_result(&result);
5897 
5898 	free_var(&result);
5899 
5900 	PG_RETURN_NUMERIC(res);
5901 #else
5902 	return numeric_sum(fcinfo);
5903 #endif
5904 }
5905 
5906 Datum
numeric_poly_avg(PG_FUNCTION_ARGS)5907 numeric_poly_avg(PG_FUNCTION_ARGS)
5908 {
5909 #ifdef HAVE_INT128
5910 	PolyNumAggState *state;
5911 	NumericVar	result;
5912 	Datum		countd,
5913 				sumd;
5914 
5915 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
5916 
5917 	/* If there were no non-null inputs, return NULL */
5918 	if (state == NULL || state->N == 0)
5919 		PG_RETURN_NULL();
5920 
5921 	init_var(&result);
5922 
5923 	int128_to_numericvar(state->sumX, &result);
5924 
5925 	countd = NumericGetDatum(int64_to_numeric(state->N));
5926 	sumd = NumericGetDatum(make_result(&result));
5927 
5928 	free_var(&result);
5929 
5930 	PG_RETURN_DATUM(DirectFunctionCall2(numeric_div, sumd, countd));
5931 #else
5932 	return numeric_avg(fcinfo);
5933 #endif
5934 }
5935 
5936 Datum
numeric_avg(PG_FUNCTION_ARGS)5937 numeric_avg(PG_FUNCTION_ARGS)
5938 {
5939 	NumericAggState *state;
5940 	Datum		N_datum;
5941 	Datum		sumX_datum;
5942 	NumericVar	sumX_var;
5943 
5944 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
5945 
5946 	/* If there were no non-null inputs, return NULL */
5947 	if (state == NULL || NA_TOTAL_COUNT(state) == 0)
5948 		PG_RETURN_NULL();
5949 
5950 	if (state->NaNcount > 0)	/* there was at least one NaN input */
5951 		PG_RETURN_NUMERIC(make_result(&const_nan));
5952 
5953 	/* adding plus and minus infinities gives NaN */
5954 	if (state->pInfcount > 0 && state->nInfcount > 0)
5955 		PG_RETURN_NUMERIC(make_result(&const_nan));
5956 	if (state->pInfcount > 0)
5957 		PG_RETURN_NUMERIC(make_result(&const_pinf));
5958 	if (state->nInfcount > 0)
5959 		PG_RETURN_NUMERIC(make_result(&const_ninf));
5960 
5961 	N_datum = NumericGetDatum(int64_to_numeric(state->N));
5962 
5963 	init_var(&sumX_var);
5964 	accum_sum_final(&state->sumX, &sumX_var);
5965 	sumX_datum = NumericGetDatum(make_result(&sumX_var));
5966 	free_var(&sumX_var);
5967 
5968 	PG_RETURN_DATUM(DirectFunctionCall2(numeric_div, sumX_datum, N_datum));
5969 }
5970 
5971 Datum
numeric_sum(PG_FUNCTION_ARGS)5972 numeric_sum(PG_FUNCTION_ARGS)
5973 {
5974 	NumericAggState *state;
5975 	NumericVar	sumX_var;
5976 	Numeric		result;
5977 
5978 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
5979 
5980 	/* If there were no non-null inputs, return NULL */
5981 	if (state == NULL || NA_TOTAL_COUNT(state) == 0)
5982 		PG_RETURN_NULL();
5983 
5984 	if (state->NaNcount > 0)	/* there was at least one NaN input */
5985 		PG_RETURN_NUMERIC(make_result(&const_nan));
5986 
5987 	/* adding plus and minus infinities gives NaN */
5988 	if (state->pInfcount > 0 && state->nInfcount > 0)
5989 		PG_RETURN_NUMERIC(make_result(&const_nan));
5990 	if (state->pInfcount > 0)
5991 		PG_RETURN_NUMERIC(make_result(&const_pinf));
5992 	if (state->nInfcount > 0)
5993 		PG_RETURN_NUMERIC(make_result(&const_ninf));
5994 
5995 	init_var(&sumX_var);
5996 	accum_sum_final(&state->sumX, &sumX_var);
5997 	result = make_result(&sumX_var);
5998 	free_var(&sumX_var);
5999 
6000 	PG_RETURN_NUMERIC(result);
6001 }
6002 
6003 /*
6004  * Workhorse routine for the standard deviance and variance
6005  * aggregates. 'state' is aggregate's transition state.
6006  * 'variance' specifies whether we should calculate the
6007  * variance or the standard deviation. 'sample' indicates whether the
6008  * caller is interested in the sample or the population
6009  * variance/stddev.
6010  *
6011  * If appropriate variance statistic is undefined for the input,
6012  * *is_null is set to true and NULL is returned.
6013  */
6014 static Numeric
numeric_stddev_internal(NumericAggState * state,bool variance,bool sample,bool * is_null)6015 numeric_stddev_internal(NumericAggState *state,
6016 						bool variance, bool sample,
6017 						bool *is_null)
6018 {
6019 	Numeric		res;
6020 	NumericVar	vN,
6021 				vsumX,
6022 				vsumX2,
6023 				vNminus1;
6024 	int64		totCount;
6025 	int			rscale;
6026 
6027 	/*
6028 	 * Sample stddev and variance are undefined when N <= 1; population stddev
6029 	 * is undefined when N == 0.  Return NULL in either case (note that NaNs
6030 	 * and infinities count as normal inputs for this purpose).
6031 	 */
6032 	if (state == NULL || (totCount = NA_TOTAL_COUNT(state)) == 0)
6033 	{
6034 		*is_null = true;
6035 		return NULL;
6036 	}
6037 
6038 	if (sample && totCount <= 1)
6039 	{
6040 		*is_null = true;
6041 		return NULL;
6042 	}
6043 
6044 	*is_null = false;
6045 
6046 	/*
6047 	 * Deal with NaN and infinity cases.  By analogy to the behavior of the
6048 	 * float8 functions, any infinity input produces NaN output.
6049 	 */
6050 	if (state->NaNcount > 0 || state->pInfcount > 0 || state->nInfcount > 0)
6051 		return make_result(&const_nan);
6052 
6053 	/* OK, normal calculation applies */
6054 	init_var(&vN);
6055 	init_var(&vsumX);
6056 	init_var(&vsumX2);
6057 
6058 	int64_to_numericvar(state->N, &vN);
6059 	accum_sum_final(&(state->sumX), &vsumX);
6060 	accum_sum_final(&(state->sumX2), &vsumX2);
6061 
6062 	init_var(&vNminus1);
6063 	sub_var(&vN, &const_one, &vNminus1);
6064 
6065 	/* compute rscale for mul_var calls */
6066 	rscale = vsumX.dscale * 2;
6067 
6068 	mul_var(&vsumX, &vsumX, &vsumX, rscale);	/* vsumX = sumX * sumX */
6069 	mul_var(&vN, &vsumX2, &vsumX2, rscale); /* vsumX2 = N * sumX2 */
6070 	sub_var(&vsumX2, &vsumX, &vsumX2);	/* N * sumX2 - sumX * sumX */
6071 
6072 	if (cmp_var(&vsumX2, &const_zero) <= 0)
6073 	{
6074 		/* Watch out for roundoff error producing a negative numerator */
6075 		res = make_result(&const_zero);
6076 	}
6077 	else
6078 	{
6079 		if (sample)
6080 			mul_var(&vN, &vNminus1, &vNminus1, 0);	/* N * (N - 1) */
6081 		else
6082 			mul_var(&vN, &vN, &vNminus1, 0);	/* N * N */
6083 		rscale = select_div_scale(&vsumX2, &vNminus1);
6084 		div_var(&vsumX2, &vNminus1, &vsumX, rscale, true);	/* variance */
6085 		if (!variance)
6086 			sqrt_var(&vsumX, &vsumX, rscale);	/* stddev */
6087 
6088 		res = make_result(&vsumX);
6089 	}
6090 
6091 	free_var(&vNminus1);
6092 	free_var(&vsumX);
6093 	free_var(&vsumX2);
6094 
6095 	return res;
6096 }
6097 
6098 Datum
numeric_var_samp(PG_FUNCTION_ARGS)6099 numeric_var_samp(PG_FUNCTION_ARGS)
6100 {
6101 	NumericAggState *state;
6102 	Numeric		res;
6103 	bool		is_null;
6104 
6105 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
6106 
6107 	res = numeric_stddev_internal(state, true, true, &is_null);
6108 
6109 	if (is_null)
6110 		PG_RETURN_NULL();
6111 	else
6112 		PG_RETURN_NUMERIC(res);
6113 }
6114 
6115 Datum
numeric_stddev_samp(PG_FUNCTION_ARGS)6116 numeric_stddev_samp(PG_FUNCTION_ARGS)
6117 {
6118 	NumericAggState *state;
6119 	Numeric		res;
6120 	bool		is_null;
6121 
6122 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
6123 
6124 	res = numeric_stddev_internal(state, false, true, &is_null);
6125 
6126 	if (is_null)
6127 		PG_RETURN_NULL();
6128 	else
6129 		PG_RETURN_NUMERIC(res);
6130 }
6131 
6132 Datum
numeric_var_pop(PG_FUNCTION_ARGS)6133 numeric_var_pop(PG_FUNCTION_ARGS)
6134 {
6135 	NumericAggState *state;
6136 	Numeric		res;
6137 	bool		is_null;
6138 
6139 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
6140 
6141 	res = numeric_stddev_internal(state, true, false, &is_null);
6142 
6143 	if (is_null)
6144 		PG_RETURN_NULL();
6145 	else
6146 		PG_RETURN_NUMERIC(res);
6147 }
6148 
6149 Datum
numeric_stddev_pop(PG_FUNCTION_ARGS)6150 numeric_stddev_pop(PG_FUNCTION_ARGS)
6151 {
6152 	NumericAggState *state;
6153 	Numeric		res;
6154 	bool		is_null;
6155 
6156 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
6157 
6158 	res = numeric_stddev_internal(state, false, false, &is_null);
6159 
6160 	if (is_null)
6161 		PG_RETURN_NULL();
6162 	else
6163 		PG_RETURN_NUMERIC(res);
6164 }
6165 
#ifdef HAVE_INT128
/*
 * numeric_poly_stddev_internal
 *		Int128AggState front end for numeric_stddev_internal().
 *
 * Copies N, sumX and sumX2 into a temporary NumericAggState (left empty
 * when state is NULL), computes the requested statistic, and then releases
 * the temporary accumulators' digit arrays.
 */
static Numeric
numeric_poly_stddev_internal(Int128AggState *state,
							 bool variance, bool sample,
							 bool *is_null)
{
	NumericAggState numstate;
	Numeric		res;

	/* An all-zero NumericAggState is an empty state */
	memset(&numstate, 0, sizeof(numstate));

	if (state != NULL)
	{
		NumericVar	scratch;

		numstate.N = state->N;

		init_var(&scratch);
		int128_to_numericvar(state->sumX, &scratch);
		accum_sum_add(&numstate.sumX, &scratch);
		int128_to_numericvar(state->sumX2, &scratch);
		accum_sum_add(&numstate.sumX2, &scratch);
		free_var(&scratch);
	}

	res = numeric_stddev_internal(&numstate, variance, sample, is_null);

	/* Free whatever digit arrays the temporary accumulators acquired */
	if (numstate.sumX.ndigits > 0)
	{
		pfree(numstate.sumX.pos_digits);
		pfree(numstate.sumX.neg_digits);
	}
	if (numstate.sumX2.ndigits > 0)
	{
		pfree(numstate.sumX2.pos_digits);
		pfree(numstate.sumX2.neg_digits);
	}

	return res;
}
#endif
6211 
6212 Datum
numeric_poly_var_samp(PG_FUNCTION_ARGS)6213 numeric_poly_var_samp(PG_FUNCTION_ARGS)
6214 {
6215 #ifdef HAVE_INT128
6216 	PolyNumAggState *state;
6217 	Numeric		res;
6218 	bool		is_null;
6219 
6220 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
6221 
6222 	res = numeric_poly_stddev_internal(state, true, true, &is_null);
6223 
6224 	if (is_null)
6225 		PG_RETURN_NULL();
6226 	else
6227 		PG_RETURN_NUMERIC(res);
6228 #else
6229 	return numeric_var_samp(fcinfo);
6230 #endif
6231 }
6232 
6233 Datum
numeric_poly_stddev_samp(PG_FUNCTION_ARGS)6234 numeric_poly_stddev_samp(PG_FUNCTION_ARGS)
6235 {
6236 #ifdef HAVE_INT128
6237 	PolyNumAggState *state;
6238 	Numeric		res;
6239 	bool		is_null;
6240 
6241 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
6242 
6243 	res = numeric_poly_stddev_internal(state, false, true, &is_null);
6244 
6245 	if (is_null)
6246 		PG_RETURN_NULL();
6247 	else
6248 		PG_RETURN_NUMERIC(res);
6249 #else
6250 	return numeric_stddev_samp(fcinfo);
6251 #endif
6252 }
6253 
6254 Datum
numeric_poly_var_pop(PG_FUNCTION_ARGS)6255 numeric_poly_var_pop(PG_FUNCTION_ARGS)
6256 {
6257 #ifdef HAVE_INT128
6258 	PolyNumAggState *state;
6259 	Numeric		res;
6260 	bool		is_null;
6261 
6262 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
6263 
6264 	res = numeric_poly_stddev_internal(state, true, false, &is_null);
6265 
6266 	if (is_null)
6267 		PG_RETURN_NULL();
6268 	else
6269 		PG_RETURN_NUMERIC(res);
6270 #else
6271 	return numeric_var_pop(fcinfo);
6272 #endif
6273 }
6274 
6275 Datum
numeric_poly_stddev_pop(PG_FUNCTION_ARGS)6276 numeric_poly_stddev_pop(PG_FUNCTION_ARGS)
6277 {
6278 #ifdef HAVE_INT128
6279 	PolyNumAggState *state;
6280 	Numeric		res;
6281 	bool		is_null;
6282 
6283 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
6284 
6285 	res = numeric_poly_stddev_internal(state, false, false, &is_null);
6286 
6287 	if (is_null)
6288 		PG_RETURN_NULL();
6289 	else
6290 		PG_RETURN_NUMERIC(res);
6291 #else
6292 	return numeric_stddev_pop(fcinfo);
6293 #endif
6294 }
6295 
6296 /*
6297  * SUM transition functions for integer datatypes.
6298  *
6299  * To avoid overflow, we use accumulators wider than the input datatype.
6300  * A Numeric accumulator is needed for int8 input; for int4 and int2
6301  * inputs, we use int8 accumulators which should be sufficient for practical
6302  * purposes.  (The latter two therefore don't really belong in this file,
6303  * but we keep them here anyway.)
6304  *
6305  * Because SQL defines the SUM() of no values to be NULL, not zero,
6306  * the initial condition of the transition data value needs to be NULL. This
6307  * means we can't rely on ExecAgg to automatically insert the first non-null
6308  * data value into the transition data: it doesn't know how to do the type
6309  * conversion.  The upshot is that these routines have to be marked non-strict
6310  * and handle substitution of the first non-null input themselves.
6311  *
6312  * Note: these functions are used only in plain aggregation mode.
6313  * In moving-aggregate mode, we use intX_avg_accum and intX_avg_accum_inv.
6314  */
6315 
6316 Datum
int2_sum(PG_FUNCTION_ARGS)6317 int2_sum(PG_FUNCTION_ARGS)
6318 {
6319 	int64		newval;
6320 
6321 	if (PG_ARGISNULL(0))
6322 	{
6323 		/* No non-null input seen so far... */
6324 		if (PG_ARGISNULL(1))
6325 			PG_RETURN_NULL();	/* still no non-null */
6326 		/* This is the first non-null input. */
6327 		newval = (int64) PG_GETARG_INT16(1);
6328 		PG_RETURN_INT64(newval);
6329 	}
6330 
6331 	/*
6332 	 * If we're invoked as an aggregate, we can cheat and modify our first
6333 	 * parameter in-place to avoid palloc overhead. If not, we need to return
6334 	 * the new value of the transition variable. (If int8 is pass-by-value,
6335 	 * then of course this is useless as well as incorrect, so just ifdef it
6336 	 * out.)
6337 	 */
6338 #ifndef USE_FLOAT8_BYVAL		/* controls int8 too */
6339 	if (AggCheckCallContext(fcinfo, NULL))
6340 	{
6341 		int64	   *oldsum = (int64 *) PG_GETARG_POINTER(0);
6342 
6343 		/* Leave the running sum unchanged in the new input is null */
6344 		if (!PG_ARGISNULL(1))
6345 			*oldsum = *oldsum + (int64) PG_GETARG_INT16(1);
6346 
6347 		PG_RETURN_POINTER(oldsum);
6348 	}
6349 	else
6350 #endif
6351 	{
6352 		int64		oldsum = PG_GETARG_INT64(0);
6353 
6354 		/* Leave sum unchanged if new input is null. */
6355 		if (PG_ARGISNULL(1))
6356 			PG_RETURN_INT64(oldsum);
6357 
6358 		/* OK to do the addition. */
6359 		newval = oldsum + (int64) PG_GETARG_INT16(1);
6360 
6361 		PG_RETURN_INT64(newval);
6362 	}
6363 }
6364 
6365 Datum
int4_sum(PG_FUNCTION_ARGS)6366 int4_sum(PG_FUNCTION_ARGS)
6367 {
6368 	int64		newval;
6369 
6370 	if (PG_ARGISNULL(0))
6371 	{
6372 		/* No non-null input seen so far... */
6373 		if (PG_ARGISNULL(1))
6374 			PG_RETURN_NULL();	/* still no non-null */
6375 		/* This is the first non-null input. */
6376 		newval = (int64) PG_GETARG_INT32(1);
6377 		PG_RETURN_INT64(newval);
6378 	}
6379 
6380 	/*
6381 	 * If we're invoked as an aggregate, we can cheat and modify our first
6382 	 * parameter in-place to avoid palloc overhead. If not, we need to return
6383 	 * the new value of the transition variable. (If int8 is pass-by-value,
6384 	 * then of course this is useless as well as incorrect, so just ifdef it
6385 	 * out.)
6386 	 */
6387 #ifndef USE_FLOAT8_BYVAL		/* controls int8 too */
6388 	if (AggCheckCallContext(fcinfo, NULL))
6389 	{
6390 		int64	   *oldsum = (int64 *) PG_GETARG_POINTER(0);
6391 
6392 		/* Leave the running sum unchanged in the new input is null */
6393 		if (!PG_ARGISNULL(1))
6394 			*oldsum = *oldsum + (int64) PG_GETARG_INT32(1);
6395 
6396 		PG_RETURN_POINTER(oldsum);
6397 	}
6398 	else
6399 #endif
6400 	{
6401 		int64		oldsum = PG_GETARG_INT64(0);
6402 
6403 		/* Leave sum unchanged if new input is null. */
6404 		if (PG_ARGISNULL(1))
6405 			PG_RETURN_INT64(oldsum);
6406 
6407 		/* OK to do the addition. */
6408 		newval = oldsum + (int64) PG_GETARG_INT32(1);
6409 
6410 		PG_RETURN_INT64(newval);
6411 	}
6412 }
6413 
6414 /*
6415  * Note: this function is obsolete, it's no longer used for SUM(int8).
6416  */
6417 Datum
int8_sum(PG_FUNCTION_ARGS)6418 int8_sum(PG_FUNCTION_ARGS)
6419 {
6420 	Numeric		oldsum;
6421 
6422 	if (PG_ARGISNULL(0))
6423 	{
6424 		/* No non-null input seen so far... */
6425 		if (PG_ARGISNULL(1))
6426 			PG_RETURN_NULL();	/* still no non-null */
6427 		/* This is the first non-null input. */
6428 		PG_RETURN_NUMERIC(int64_to_numeric(PG_GETARG_INT64(1)));
6429 	}
6430 
6431 	/*
6432 	 * Note that we cannot special-case the aggregate case here, as we do for
6433 	 * int2_sum and int4_sum: numeric is of variable size, so we cannot modify
6434 	 * our first parameter in-place.
6435 	 */
6436 
6437 	oldsum = PG_GETARG_NUMERIC(0);
6438 
6439 	/* Leave sum unchanged if new input is null. */
6440 	if (PG_ARGISNULL(1))
6441 		PG_RETURN_NUMERIC(oldsum);
6442 
6443 	/* OK to do the addition. */
6444 	PG_RETURN_DATUM(DirectFunctionCall2(numeric_add,
6445 										NumericGetDatum(oldsum),
6446 										NumericGetDatum(int64_to_numeric(PG_GETARG_INT64(1)))));
6447 }
6448 
6449 
6450 /*
6451  * Routines for avg(int2) and avg(int4).  The transition datatype
6452  * is a two-element int8 array, holding count and sum.
6453  *
6454  * These functions are also used for sum(int2) and sum(int4) when
6455  * operating in moving-aggregate mode, since for correct inverse transitions
6456  * we need to count the inputs.
6457  */
6458 
6459 typedef struct Int8TransTypeData
6460 {
6461 	int64		count;
6462 	int64		sum;
6463 } Int8TransTypeData;
6464 
6465 Datum
int2_avg_accum(PG_FUNCTION_ARGS)6466 int2_avg_accum(PG_FUNCTION_ARGS)
6467 {
6468 	ArrayType  *transarray;
6469 	int16		newval = PG_GETARG_INT16(1);
6470 	Int8TransTypeData *transdata;
6471 
6472 	/*
6473 	 * If we're invoked as an aggregate, we can cheat and modify our first
6474 	 * parameter in-place to reduce palloc overhead. Otherwise we need to make
6475 	 * a copy of it before scribbling on it.
6476 	 */
6477 	if (AggCheckCallContext(fcinfo, NULL))
6478 		transarray = PG_GETARG_ARRAYTYPE_P(0);
6479 	else
6480 		transarray = PG_GETARG_ARRAYTYPE_P_COPY(0);
6481 
6482 	if (ARR_HASNULL(transarray) ||
6483 		ARR_SIZE(transarray) != ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData))
6484 		elog(ERROR, "expected 2-element int8 array");
6485 
6486 	transdata = (Int8TransTypeData *) ARR_DATA_PTR(transarray);
6487 	transdata->count++;
6488 	transdata->sum += newval;
6489 
6490 	PG_RETURN_ARRAYTYPE_P(transarray);
6491 }
6492 
6493 Datum
int4_avg_accum(PG_FUNCTION_ARGS)6494 int4_avg_accum(PG_FUNCTION_ARGS)
6495 {
6496 	ArrayType  *transarray;
6497 	int32		newval = PG_GETARG_INT32(1);
6498 	Int8TransTypeData *transdata;
6499 
6500 	/*
6501 	 * If we're invoked as an aggregate, we can cheat and modify our first
6502 	 * parameter in-place to reduce palloc overhead. Otherwise we need to make
6503 	 * a copy of it before scribbling on it.
6504 	 */
6505 	if (AggCheckCallContext(fcinfo, NULL))
6506 		transarray = PG_GETARG_ARRAYTYPE_P(0);
6507 	else
6508 		transarray = PG_GETARG_ARRAYTYPE_P_COPY(0);
6509 
6510 	if (ARR_HASNULL(transarray) ||
6511 		ARR_SIZE(transarray) != ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData))
6512 		elog(ERROR, "expected 2-element int8 array");
6513 
6514 	transdata = (Int8TransTypeData *) ARR_DATA_PTR(transarray);
6515 	transdata->count++;
6516 	transdata->sum += newval;
6517 
6518 	PG_RETURN_ARRAYTYPE_P(transarray);
6519 }
6520 
6521 Datum
int4_avg_combine(PG_FUNCTION_ARGS)6522 int4_avg_combine(PG_FUNCTION_ARGS)
6523 {
6524 	ArrayType  *transarray1;
6525 	ArrayType  *transarray2;
6526 	Int8TransTypeData *state1;
6527 	Int8TransTypeData *state2;
6528 
6529 	if (!AggCheckCallContext(fcinfo, NULL))
6530 		elog(ERROR, "aggregate function called in non-aggregate context");
6531 
6532 	transarray1 = PG_GETARG_ARRAYTYPE_P(0);
6533 	transarray2 = PG_GETARG_ARRAYTYPE_P(1);
6534 
6535 	if (ARR_HASNULL(transarray1) ||
6536 		ARR_SIZE(transarray1) != ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData))
6537 		elog(ERROR, "expected 2-element int8 array");
6538 
6539 	if (ARR_HASNULL(transarray2) ||
6540 		ARR_SIZE(transarray2) != ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData))
6541 		elog(ERROR, "expected 2-element int8 array");
6542 
6543 	state1 = (Int8TransTypeData *) ARR_DATA_PTR(transarray1);
6544 	state2 = (Int8TransTypeData *) ARR_DATA_PTR(transarray2);
6545 
6546 	state1->count += state2->count;
6547 	state1->sum += state2->sum;
6548 
6549 	PG_RETURN_ARRAYTYPE_P(transarray1);
6550 }
6551 
6552 Datum
int2_avg_accum_inv(PG_FUNCTION_ARGS)6553 int2_avg_accum_inv(PG_FUNCTION_ARGS)
6554 {
6555 	ArrayType  *transarray;
6556 	int16		newval = PG_GETARG_INT16(1);
6557 	Int8TransTypeData *transdata;
6558 
6559 	/*
6560 	 * If we're invoked as an aggregate, we can cheat and modify our first
6561 	 * parameter in-place to reduce palloc overhead. Otherwise we need to make
6562 	 * a copy of it before scribbling on it.
6563 	 */
6564 	if (AggCheckCallContext(fcinfo, NULL))
6565 		transarray = PG_GETARG_ARRAYTYPE_P(0);
6566 	else
6567 		transarray = PG_GETARG_ARRAYTYPE_P_COPY(0);
6568 
6569 	if (ARR_HASNULL(transarray) ||
6570 		ARR_SIZE(transarray) != ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData))
6571 		elog(ERROR, "expected 2-element int8 array");
6572 
6573 	transdata = (Int8TransTypeData *) ARR_DATA_PTR(transarray);
6574 	transdata->count--;
6575 	transdata->sum -= newval;
6576 
6577 	PG_RETURN_ARRAYTYPE_P(transarray);
6578 }
6579 
6580 Datum
int4_avg_accum_inv(PG_FUNCTION_ARGS)6581 int4_avg_accum_inv(PG_FUNCTION_ARGS)
6582 {
6583 	ArrayType  *transarray;
6584 	int32		newval = PG_GETARG_INT32(1);
6585 	Int8TransTypeData *transdata;
6586 
6587 	/*
6588 	 * If we're invoked as an aggregate, we can cheat and modify our first
6589 	 * parameter in-place to reduce palloc overhead. Otherwise we need to make
6590 	 * a copy of it before scribbling on it.
6591 	 */
6592 	if (AggCheckCallContext(fcinfo, NULL))
6593 		transarray = PG_GETARG_ARRAYTYPE_P(0);
6594 	else
6595 		transarray = PG_GETARG_ARRAYTYPE_P_COPY(0);
6596 
6597 	if (ARR_HASNULL(transarray) ||
6598 		ARR_SIZE(transarray) != ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData))
6599 		elog(ERROR, "expected 2-element int8 array");
6600 
6601 	transdata = (Int8TransTypeData *) ARR_DATA_PTR(transarray);
6602 	transdata->count--;
6603 	transdata->sum -= newval;
6604 
6605 	PG_RETURN_ARRAYTYPE_P(transarray);
6606 }
6607 
6608 Datum
int8_avg(PG_FUNCTION_ARGS)6609 int8_avg(PG_FUNCTION_ARGS)
6610 {
6611 	ArrayType  *transarray = PG_GETARG_ARRAYTYPE_P(0);
6612 	Int8TransTypeData *transdata;
6613 	Datum		countd,
6614 				sumd;
6615 
6616 	if (ARR_HASNULL(transarray) ||
6617 		ARR_SIZE(transarray) != ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData))
6618 		elog(ERROR, "expected 2-element int8 array");
6619 	transdata = (Int8TransTypeData *) ARR_DATA_PTR(transarray);
6620 
6621 	/* SQL defines AVG of no values to be NULL */
6622 	if (transdata->count == 0)
6623 		PG_RETURN_NULL();
6624 
6625 	countd = NumericGetDatum(int64_to_numeric(transdata->count));
6626 	sumd = NumericGetDatum(int64_to_numeric(transdata->sum));
6627 
6628 	PG_RETURN_DATUM(DirectFunctionCall2(numeric_div, sumd, countd));
6629 }
6630 
6631 /*
6632  * SUM(int2) and SUM(int4) both return int8, so we can use this
6633  * final function for both.
6634  */
6635 Datum
int2int4_sum(PG_FUNCTION_ARGS)6636 int2int4_sum(PG_FUNCTION_ARGS)
6637 {
6638 	ArrayType  *transarray = PG_GETARG_ARRAYTYPE_P(0);
6639 	Int8TransTypeData *transdata;
6640 
6641 	if (ARR_HASNULL(transarray) ||
6642 		ARR_SIZE(transarray) != ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData))
6643 		elog(ERROR, "expected 2-element int8 array");
6644 	transdata = (Int8TransTypeData *) ARR_DATA_PTR(transarray);
6645 
6646 	/* SQL defines SUM of no values to be NULL */
6647 	if (transdata->count == 0)
6648 		PG_RETURN_NULL();
6649 
6650 	PG_RETURN_DATUM(Int64GetDatumFast(transdata->sum));
6651 }
6652 
6653 
6654 /* ----------------------------------------------------------------------
6655  *
6656  * Debug support
6657  *
6658  * ----------------------------------------------------------------------
6659  */
6660 
6661 #ifdef NUMERIC_DEBUG
6662 
6663 /*
6664  * dump_numeric() - Dump a value in the db storage format for debugging
6665  */
6666 static void
dump_numeric(const char * str,Numeric num)6667 dump_numeric(const char *str, Numeric num)
6668 {
6669 	NumericDigit *digits = NUMERIC_DIGITS(num);
6670 	int			ndigits;
6671 	int			i;
6672 
6673 	ndigits = NUMERIC_NDIGITS(num);
6674 
6675 	printf("%s: NUMERIC w=%d d=%d ", str,
6676 		   NUMERIC_WEIGHT(num), NUMERIC_DSCALE(num));
6677 	switch (NUMERIC_SIGN(num))
6678 	{
6679 		case NUMERIC_POS:
6680 			printf("POS");
6681 			break;
6682 		case NUMERIC_NEG:
6683 			printf("NEG");
6684 			break;
6685 		case NUMERIC_NAN:
6686 			printf("NaN");
6687 			break;
6688 		case NUMERIC_PINF:
6689 			printf("Infinity");
6690 			break;
6691 		case NUMERIC_NINF:
6692 			printf("-Infinity");
6693 			break;
6694 		default:
6695 			printf("SIGN=0x%x", NUMERIC_SIGN(num));
6696 			break;
6697 	}
6698 
6699 	for (i = 0; i < ndigits; i++)
6700 		printf(" %0*d", DEC_DIGITS, digits[i]);
6701 	printf("\n");
6702 }
6703 
6704 
6705 /*
6706  * dump_var() - Dump a value in the variable format for debugging
6707  */
6708 static void
dump_var(const char * str,NumericVar * var)6709 dump_var(const char *str, NumericVar *var)
6710 {
6711 	int			i;
6712 
6713 	printf("%s: VAR w=%d d=%d ", str, var->weight, var->dscale);
6714 	switch (var->sign)
6715 	{
6716 		case NUMERIC_POS:
6717 			printf("POS");
6718 			break;
6719 		case NUMERIC_NEG:
6720 			printf("NEG");
6721 			break;
6722 		case NUMERIC_NAN:
6723 			printf("NaN");
6724 			break;
6725 		case NUMERIC_PINF:
6726 			printf("Infinity");
6727 			break;
6728 		case NUMERIC_NINF:
6729 			printf("-Infinity");
6730 			break;
6731 		default:
6732 			printf("SIGN=0x%x", var->sign);
6733 			break;
6734 	}
6735 
6736 	for (i = 0; i < var->ndigits; i++)
6737 		printf(" %0*d", DEC_DIGITS, var->digits[i]);
6738 
6739 	printf("\n");
6740 }
6741 #endif							/* NUMERIC_DEBUG */
6742 
6743 
6744 /* ----------------------------------------------------------------------
6745  *
6746  * Local functions follow
6747  *
6748  * In general, these do not support "special" (NaN or infinity) inputs;
6749  * callers should handle those possibilities first.
6750  * (There are one or two exceptions, noted in their header comments.)
6751  *
6752  * ----------------------------------------------------------------------
6753  */
6754 
6755 
6756 /*
6757  * alloc_var() -
6758  *
6759  *	Allocate a digit buffer of ndigits digits (plus a spare digit for rounding)
6760  */
6761 static void
alloc_var(NumericVar * var,int ndigits)6762 alloc_var(NumericVar *var, int ndigits)
6763 {
6764 	digitbuf_free(var->buf);
6765 	var->buf = digitbuf_alloc(ndigits + 1);
6766 	var->buf[0] = 0;			/* spare digit for rounding */
6767 	var->digits = var->buf + 1;
6768 	var->ndigits = ndigits;
6769 }
6770 
6771 
6772 /*
6773  * free_var() -
6774  *
6775  *	Return the digit buffer of a variable to the free pool
6776  */
6777 static void
free_var(NumericVar * var)6778 free_var(NumericVar *var)
6779 {
6780 	digitbuf_free(var->buf);
6781 	var->buf = NULL;
6782 	var->digits = NULL;
6783 	var->sign = NUMERIC_NAN;
6784 }
6785 
6786 
6787 /*
6788  * zero_var() -
6789  *
6790  *	Set a variable to ZERO.
6791  *	Note: its dscale is not touched.
6792  */
6793 static void
zero_var(NumericVar * var)6794 zero_var(NumericVar *var)
6795 {
6796 	digitbuf_free(var->buf);
6797 	var->buf = NULL;
6798 	var->digits = NULL;
6799 	var->ndigits = 0;
6800 	var->weight = 0;			/* by convention; doesn't really matter */
6801 	var->sign = NUMERIC_POS;	/* anything but NAN... */
6802 }
6803 
6804 
6805 /*
6806  * set_var_from_str()
6807  *
6808  *	Parse a string and put the number into a variable
6809  *
6810  * This function does not handle leading or trailing spaces.  It returns
6811  * the end+1 position parsed, so that caller can check for trailing
6812  * spaces/garbage if deemed necessary.
6813  *
6814  * cp is the place to actually start parsing; str is what to use in error
6815  * reports.  (Typically cp would be the same except advanced over spaces.)
6816  */
6817 static const char *
set_var_from_str(const char * str,const char * cp,NumericVar * dest)6818 set_var_from_str(const char *str, const char *cp, NumericVar *dest)
6819 {
6820 	bool		have_dp = false;
6821 	int			i;
6822 	unsigned char *decdigits;
6823 	int			sign = NUMERIC_POS;
6824 	int			dweight = -1;
6825 	int			ddigits;
6826 	int			dscale = 0;
6827 	int			weight;
6828 	int			ndigits;
6829 	int			offset;
6830 	NumericDigit *digits;
6831 
6832 	/*
6833 	 * We first parse the string to extract decimal digits and determine the
6834 	 * correct decimal weight.  Then convert to NBASE representation.
6835 	 */
6836 	switch (*cp)
6837 	{
6838 		case '+':
6839 			sign = NUMERIC_POS;
6840 			cp++;
6841 			break;
6842 
6843 		case '-':
6844 			sign = NUMERIC_NEG;
6845 			cp++;
6846 			break;
6847 	}
6848 
6849 	if (*cp == '.')
6850 	{
6851 		have_dp = true;
6852 		cp++;
6853 	}
6854 
6855 	if (!isdigit((unsigned char) *cp))
6856 		ereport(ERROR,
6857 				(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
6858 				 errmsg("invalid input syntax for type %s: \"%s\"",
6859 						"numeric", str)));
6860 
6861 	decdigits = (unsigned char *) palloc(strlen(cp) + DEC_DIGITS * 2);
6862 
6863 	/* leading padding for digit alignment later */
6864 	memset(decdigits, 0, DEC_DIGITS);
6865 	i = DEC_DIGITS;
6866 
6867 	while (*cp)
6868 	{
6869 		if (isdigit((unsigned char) *cp))
6870 		{
6871 			decdigits[i++] = *cp++ - '0';
6872 			if (!have_dp)
6873 				dweight++;
6874 			else
6875 				dscale++;
6876 		}
6877 		else if (*cp == '.')
6878 		{
6879 			if (have_dp)
6880 				ereport(ERROR,
6881 						(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
6882 						 errmsg("invalid input syntax for type %s: \"%s\"",
6883 								"numeric", str)));
6884 			have_dp = true;
6885 			cp++;
6886 		}
6887 		else
6888 			break;
6889 	}
6890 
6891 	ddigits = i - DEC_DIGITS;
6892 	/* trailing padding for digit alignment later */
6893 	memset(decdigits + i, 0, DEC_DIGITS - 1);
6894 
6895 	/* Handle exponent, if any */
6896 	if (*cp == 'e' || *cp == 'E')
6897 	{
6898 		long		exponent;
6899 		char	   *endptr;
6900 
6901 		cp++;
6902 		exponent = strtol(cp, &endptr, 10);
6903 		if (endptr == cp)
6904 			ereport(ERROR,
6905 					(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
6906 					 errmsg("invalid input syntax for type %s: \"%s\"",
6907 							"numeric", str)));
6908 		cp = endptr;
6909 
6910 		/*
6911 		 * At this point, dweight and dscale can't be more than about
6912 		 * INT_MAX/2 due to the MaxAllocSize limit on string length, so
6913 		 * constraining the exponent similarly should be enough to prevent
6914 		 * integer overflow in this function.  If the value is too large to
6915 		 * fit in storage format, make_result() will complain about it later;
6916 		 * for consistency use the same ereport errcode/text as make_result().
6917 		 */
6918 		if (exponent >= INT_MAX / 2 || exponent <= -(INT_MAX / 2))
6919 			ereport(ERROR,
6920 					(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
6921 					 errmsg("value overflows numeric format")));
6922 		dweight += (int) exponent;
6923 		dscale -= (int) exponent;
6924 		if (dscale < 0)
6925 			dscale = 0;
6926 	}
6927 
6928 	/*
6929 	 * Okay, convert pure-decimal representation to base NBASE.  First we need
6930 	 * to determine the converted weight and ndigits.  offset is the number of
6931 	 * decimal zeroes to insert before the first given digit to have a
6932 	 * correctly aligned first NBASE digit.
6933 	 */
6934 	if (dweight >= 0)
6935 		weight = (dweight + 1 + DEC_DIGITS - 1) / DEC_DIGITS - 1;
6936 	else
6937 		weight = -((-dweight - 1) / DEC_DIGITS + 1);
6938 	offset = (weight + 1) * DEC_DIGITS - (dweight + 1);
6939 	ndigits = (ddigits + offset + DEC_DIGITS - 1) / DEC_DIGITS;
6940 
6941 	alloc_var(dest, ndigits);
6942 	dest->sign = sign;
6943 	dest->weight = weight;
6944 	dest->dscale = dscale;
6945 
6946 	i = DEC_DIGITS - offset;
6947 	digits = dest->digits;
6948 
6949 	while (ndigits-- > 0)
6950 	{
6951 #if DEC_DIGITS == 4
6952 		*digits++ = ((decdigits[i] * 10 + decdigits[i + 1]) * 10 +
6953 					 decdigits[i + 2]) * 10 + decdigits[i + 3];
6954 #elif DEC_DIGITS == 2
6955 		*digits++ = decdigits[i] * 10 + decdigits[i + 1];
6956 #elif DEC_DIGITS == 1
6957 		*digits++ = decdigits[i];
6958 #else
6959 #error unsupported NBASE
6960 #endif
6961 		i += DEC_DIGITS;
6962 	}
6963 
6964 	pfree(decdigits);
6965 
6966 	/* Strip any leading/trailing zeroes, and normalize weight if zero */
6967 	strip_var(dest);
6968 
6969 	/* Return end+1 position for caller */
6970 	return cp;
6971 }
6972 
6973 
6974 /*
6975  * set_var_from_num() -
6976  *
6977  *	Convert the packed db format into a variable
6978  */
6979 static void
set_var_from_num(Numeric num,NumericVar * dest)6980 set_var_from_num(Numeric num, NumericVar *dest)
6981 {
6982 	int			ndigits;
6983 
6984 	ndigits = NUMERIC_NDIGITS(num);
6985 
6986 	alloc_var(dest, ndigits);
6987 
6988 	dest->weight = NUMERIC_WEIGHT(num);
6989 	dest->sign = NUMERIC_SIGN(num);
6990 	dest->dscale = NUMERIC_DSCALE(num);
6991 
6992 	memcpy(dest->digits, NUMERIC_DIGITS(num), ndigits * sizeof(NumericDigit));
6993 }
6994 
6995 
6996 /*
6997  * init_var_from_num() -
6998  *
6999  *	Initialize a variable from packed db format. The digits array is not
7000  *	copied, which saves some cycles when the resulting var is not modified.
7001  *	Also, there's no need to call free_var(), as long as you don't assign any
7002  *	other value to it (with set_var_* functions, or by using the var as the
7003  *	destination of a function like add_var())
7004  *
7005  *	CAUTION: Do not modify the digits buffer of a var initialized with this
7006  *	function, e.g by calling round_var() or trunc_var(), as the changes will
7007  *	propagate to the original Numeric! It's OK to use it as the destination
7008  *	argument of one of the calculational functions, though.
7009  */
7010 static void
init_var_from_num(Numeric num,NumericVar * dest)7011 init_var_from_num(Numeric num, NumericVar *dest)
7012 {
7013 	dest->ndigits = NUMERIC_NDIGITS(num);
7014 	dest->weight = NUMERIC_WEIGHT(num);
7015 	dest->sign = NUMERIC_SIGN(num);
7016 	dest->dscale = NUMERIC_DSCALE(num);
7017 	dest->digits = NUMERIC_DIGITS(num);
7018 	dest->buf = NULL;			/* digits array is not palloc'd */
7019 }
7020 
7021 
7022 /*
7023  * set_var_from_var() -
7024  *
7025  *	Copy one variable into another
7026  */
7027 static void
set_var_from_var(const NumericVar * value,NumericVar * dest)7028 set_var_from_var(const NumericVar *value, NumericVar *dest)
7029 {
7030 	NumericDigit *newbuf;
7031 
7032 	newbuf = digitbuf_alloc(value->ndigits + 1);
7033 	newbuf[0] = 0;				/* spare digit for rounding */
7034 	if (value->ndigits > 0)		/* else value->digits might be null */
7035 		memcpy(newbuf + 1, value->digits,
7036 			   value->ndigits * sizeof(NumericDigit));
7037 
7038 	digitbuf_free(dest->buf);
7039 
7040 	memmove(dest, value, sizeof(NumericVar));
7041 	dest->buf = newbuf;
7042 	dest->digits = newbuf + 1;
7043 }
7044 
7045 
7046 /*
7047  * get_str_from_var() -
7048  *
7049  *	Convert a var to text representation (guts of numeric_out).
7050  *	The var is displayed to the number of digits indicated by its dscale.
7051  *	Returns a palloc'd string.
7052  */
7053 static char *
get_str_from_var(const NumericVar * var)7054 get_str_from_var(const NumericVar *var)
7055 {
7056 	int			dscale;
7057 	char	   *str;
7058 	char	   *cp;
7059 	char	   *endcp;
7060 	int			i;
7061 	int			d;
7062 	NumericDigit dig;
7063 
7064 #if DEC_DIGITS > 1
7065 	NumericDigit d1;
7066 #endif
7067 
7068 	dscale = var->dscale;
7069 
7070 	/*
7071 	 * Allocate space for the result.
7072 	 *
7073 	 * i is set to the # of decimal digits before decimal point. dscale is the
7074 	 * # of decimal digits we will print after decimal point. We may generate
7075 	 * as many as DEC_DIGITS-1 excess digits at the end, and in addition we
7076 	 * need room for sign, decimal point, null terminator.
7077 	 */
7078 	i = (var->weight + 1) * DEC_DIGITS;
7079 	if (i <= 0)
7080 		i = 1;
7081 
7082 	str = palloc(i + dscale + DEC_DIGITS + 2);
7083 	cp = str;
7084 
7085 	/*
7086 	 * Output a dash for negative values
7087 	 */
7088 	if (var->sign == NUMERIC_NEG)
7089 		*cp++ = '-';
7090 
7091 	/*
7092 	 * Output all digits before the decimal point
7093 	 */
7094 	if (var->weight < 0)
7095 	{
7096 		d = var->weight + 1;
7097 		*cp++ = '0';
7098 	}
7099 	else
7100 	{
7101 		for (d = 0; d <= var->weight; d++)
7102 		{
7103 			dig = (d < var->ndigits) ? var->digits[d] : 0;
7104 			/* In the first digit, suppress extra leading decimal zeroes */
7105 #if DEC_DIGITS == 4
7106 			{
7107 				bool		putit = (d > 0);
7108 
7109 				d1 = dig / 1000;
7110 				dig -= d1 * 1000;
7111 				putit |= (d1 > 0);
7112 				if (putit)
7113 					*cp++ = d1 + '0';
7114 				d1 = dig / 100;
7115 				dig -= d1 * 100;
7116 				putit |= (d1 > 0);
7117 				if (putit)
7118 					*cp++ = d1 + '0';
7119 				d1 = dig / 10;
7120 				dig -= d1 * 10;
7121 				putit |= (d1 > 0);
7122 				if (putit)
7123 					*cp++ = d1 + '0';
7124 				*cp++ = dig + '0';
7125 			}
7126 #elif DEC_DIGITS == 2
7127 			d1 = dig / 10;
7128 			dig -= d1 * 10;
7129 			if (d1 > 0 || d > 0)
7130 				*cp++ = d1 + '0';
7131 			*cp++ = dig + '0';
7132 #elif DEC_DIGITS == 1
7133 			*cp++ = dig + '0';
7134 #else
7135 #error unsupported NBASE
7136 #endif
7137 		}
7138 	}
7139 
7140 	/*
7141 	 * If requested, output a decimal point and all the digits that follow it.
7142 	 * We initially put out a multiple of DEC_DIGITS digits, then truncate if
7143 	 * needed.
7144 	 */
7145 	if (dscale > 0)
7146 	{
7147 		*cp++ = '.';
7148 		endcp = cp + dscale;
7149 		for (i = 0; i < dscale; d++, i += DEC_DIGITS)
7150 		{
7151 			dig = (d >= 0 && d < var->ndigits) ? var->digits[d] : 0;
7152 #if DEC_DIGITS == 4
7153 			d1 = dig / 1000;
7154 			dig -= d1 * 1000;
7155 			*cp++ = d1 + '0';
7156 			d1 = dig / 100;
7157 			dig -= d1 * 100;
7158 			*cp++ = d1 + '0';
7159 			d1 = dig / 10;
7160 			dig -= d1 * 10;
7161 			*cp++ = d1 + '0';
7162 			*cp++ = dig + '0';
7163 #elif DEC_DIGITS == 2
7164 			d1 = dig / 10;
7165 			dig -= d1 * 10;
7166 			*cp++ = d1 + '0';
7167 			*cp++ = dig + '0';
7168 #elif DEC_DIGITS == 1
7169 			*cp++ = dig + '0';
7170 #else
7171 #error unsupported NBASE
7172 #endif
7173 		}
7174 		cp = endcp;
7175 	}
7176 
7177 	/*
7178 	 * terminate the string and return it
7179 	 */
7180 	*cp = '\0';
7181 	return str;
7182 }
7183 
7184 /*
7185  * get_str_from_var_sci() -
7186  *
7187  *	Convert a var to a normalised scientific notation text representation.
7188  *	This function does the heavy lifting for numeric_out_sci().
7189  *
7190  *	This notation has the general form a * 10^b, where a is known as the
7191  *	"significand" and b is known as the "exponent".
7192  *
7193  *	Because we can't do superscript in ASCII (and because we want to copy
7194  *	printf's behaviour) we display the exponent using E notation, with a
7195  *	minimum of two exponent digits.
7196  *
7197  *	For example, the value 1234 could be output as 1.2e+03.
7198  *
7199  *	We assume that the exponent can fit into an int32.
7200  *
7201  *	rscale is the number of decimal digits desired after the decimal point in
7202  *	the output, negative values will be treated as meaning zero.
7203  *
7204  *	Returns a palloc'd string.
7205  */
7206 static char *
get_str_from_var_sci(const NumericVar * var,int rscale)7207 get_str_from_var_sci(const NumericVar *var, int rscale)
7208 {
7209 	int32		exponent;
7210 	NumericVar	tmp_var;
7211 	size_t		len;
7212 	char	   *str;
7213 	char	   *sig_out;
7214 
7215 	if (rscale < 0)
7216 		rscale = 0;
7217 
7218 	/*
7219 	 * Determine the exponent of this number in normalised form.
7220 	 *
7221 	 * This is the exponent required to represent the number with only one
7222 	 * significant digit before the decimal place.
7223 	 */
7224 	if (var->ndigits > 0)
7225 	{
7226 		exponent = (var->weight + 1) * DEC_DIGITS;
7227 
7228 		/*
7229 		 * Compensate for leading decimal zeroes in the first numeric digit by
7230 		 * decrementing the exponent.
7231 		 */
7232 		exponent -= DEC_DIGITS - (int) log10(var->digits[0]);
7233 	}
7234 	else
7235 	{
7236 		/*
7237 		 * If var has no digits, then it must be zero.
7238 		 *
7239 		 * Zero doesn't technically have a meaningful exponent in normalised
7240 		 * notation, but we just display the exponent as zero for consistency
7241 		 * of output.
7242 		 */
7243 		exponent = 0;
7244 	}
7245 
7246 	/*
7247 	 * Divide var by 10^exponent to get the significand, rounding to rscale
7248 	 * decimal digits in the process.
7249 	 */
7250 	init_var(&tmp_var);
7251 
7252 	power_ten_int(exponent, &tmp_var);
7253 	div_var(var, &tmp_var, &tmp_var, rscale, true);
7254 	sig_out = get_str_from_var(&tmp_var);
7255 
7256 	free_var(&tmp_var);
7257 
7258 	/*
7259 	 * Allocate space for the result.
7260 	 *
7261 	 * In addition to the significand, we need room for the exponent
7262 	 * decoration ("e"), the sign of the exponent, up to 10 digits for the
7263 	 * exponent itself, and of course the null terminator.
7264 	 */
7265 	len = strlen(sig_out) + 13;
7266 	str = palloc(len);
7267 	snprintf(str, len, "%se%+03d", sig_out, exponent);
7268 
7269 	pfree(sig_out);
7270 
7271 	return str;
7272 }
7273 
7274 
7275 /*
7276  * duplicate_numeric() - copy a packed-format Numeric
7277  *
7278  * This will handle NaN and Infinity cases.
7279  */
7280 static Numeric
duplicate_numeric(Numeric num)7281 duplicate_numeric(Numeric num)
7282 {
7283 	Numeric		res;
7284 
7285 	res = (Numeric) palloc(VARSIZE(num));
7286 	memcpy(res, num, VARSIZE(num));
7287 	return res;
7288 }
7289 
7290 /*
7291  * make_result_opt_error() -
7292  *
7293  *	Create the packed db numeric format in palloc()'d memory from
7294  *	a variable.  This will handle NaN and Infinity cases.
7295  *
7296  *	If "have_error" isn't NULL, on overflow *have_error is set to true and
7297  *	NULL is returned.  This is helpful when caller needs to handle errors.
7298  */
7299 static Numeric
make_result_opt_error(const NumericVar * var,bool * have_error)7300 make_result_opt_error(const NumericVar *var, bool *have_error)
7301 {
7302 	Numeric		result;
7303 	NumericDigit *digits = var->digits;
7304 	int			weight = var->weight;
7305 	int			sign = var->sign;
7306 	int			n;
7307 	Size		len;
7308 
7309 	if (have_error)
7310 		*have_error = false;
7311 
7312 	if ((sign & NUMERIC_SIGN_MASK) == NUMERIC_SPECIAL)
7313 	{
7314 		/*
7315 		 * Verify valid special value.  This could be just an Assert, perhaps,
7316 		 * but it seems worthwhile to expend a few cycles to ensure that we
7317 		 * never write any nonzero reserved bits to disk.
7318 		 */
7319 		if (!(sign == NUMERIC_NAN ||
7320 			  sign == NUMERIC_PINF ||
7321 			  sign == NUMERIC_NINF))
7322 			elog(ERROR, "invalid numeric sign value 0x%x", sign);
7323 
7324 		result = (Numeric) palloc(NUMERIC_HDRSZ_SHORT);
7325 
7326 		SET_VARSIZE(result, NUMERIC_HDRSZ_SHORT);
7327 		result->choice.n_header = sign;
7328 		/* the header word is all we need */
7329 
7330 		dump_numeric("make_result()", result);
7331 		return result;
7332 	}
7333 
7334 	n = var->ndigits;
7335 
7336 	/* truncate leading zeroes */
7337 	while (n > 0 && *digits == 0)
7338 	{
7339 		digits++;
7340 		weight--;
7341 		n--;
7342 	}
7343 	/* truncate trailing zeroes */
7344 	while (n > 0 && digits[n - 1] == 0)
7345 		n--;
7346 
7347 	/* If zero result, force to weight=0 and positive sign */
7348 	if (n == 0)
7349 	{
7350 		weight = 0;
7351 		sign = NUMERIC_POS;
7352 	}
7353 
7354 	/* Build the result */
7355 	if (NUMERIC_CAN_BE_SHORT(var->dscale, weight))
7356 	{
7357 		len = NUMERIC_HDRSZ_SHORT + n * sizeof(NumericDigit);
7358 		result = (Numeric) palloc(len);
7359 		SET_VARSIZE(result, len);
7360 		result->choice.n_short.n_header =
7361 			(sign == NUMERIC_NEG ? (NUMERIC_SHORT | NUMERIC_SHORT_SIGN_MASK)
7362 			 : NUMERIC_SHORT)
7363 			| (var->dscale << NUMERIC_SHORT_DSCALE_SHIFT)
7364 			| (weight < 0 ? NUMERIC_SHORT_WEIGHT_SIGN_MASK : 0)
7365 			| (weight & NUMERIC_SHORT_WEIGHT_MASK);
7366 	}
7367 	else
7368 	{
7369 		len = NUMERIC_HDRSZ + n * sizeof(NumericDigit);
7370 		result = (Numeric) palloc(len);
7371 		SET_VARSIZE(result, len);
7372 		result->choice.n_long.n_sign_dscale =
7373 			sign | (var->dscale & NUMERIC_DSCALE_MASK);
7374 		result->choice.n_long.n_weight = weight;
7375 	}
7376 
7377 	Assert(NUMERIC_NDIGITS(result) == n);
7378 	if (n > 0)
7379 		memcpy(NUMERIC_DIGITS(result), digits, n * sizeof(NumericDigit));
7380 
7381 	/* Check for overflow of int16 fields */
7382 	if (NUMERIC_WEIGHT(result) != weight ||
7383 		NUMERIC_DSCALE(result) != var->dscale)
7384 	{
7385 		if (have_error)
7386 		{
7387 			*have_error = true;
7388 			return NULL;
7389 		}
7390 		else
7391 		{
7392 			ereport(ERROR,
7393 					(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
7394 					 errmsg("value overflows numeric format")));
7395 		}
7396 	}
7397 
7398 	dump_numeric("make_result()", result);
7399 	return result;
7400 }
7401 
7402 
7403 /*
7404  * make_result() -
7405  *
7406  *	An interface to make_result_opt_error() without "have_error" argument.
7407  */
7408 static Numeric
make_result(const NumericVar * var)7409 make_result(const NumericVar *var)
7410 {
7411 	return make_result_opt_error(var, NULL);
7412 }
7413 
7414 
7415 /*
7416  * apply_typmod() -
7417  *
7418  *	Do bounds checking and rounding according to the specified typmod.
7419  *	Note that this is only applied to normal finite values.
7420  */
7421 static void
apply_typmod(NumericVar * var,int32 typmod)7422 apply_typmod(NumericVar *var, int32 typmod)
7423 {
7424 	int			precision;
7425 	int			scale;
7426 	int			maxdigits;
7427 	int			ddigits;
7428 	int			i;
7429 
7430 	/* Do nothing if we have a default typmod (-1) */
7431 	if (typmod < (int32) (VARHDRSZ))
7432 		return;
7433 
7434 	typmod -= VARHDRSZ;
7435 	precision = (typmod >> 16) & 0xffff;
7436 	scale = typmod & 0xffff;
7437 	maxdigits = precision - scale;
7438 
7439 	/* Round to target scale (and set var->dscale) */
7440 	round_var(var, scale);
7441 
7442 	/*
7443 	 * Check for overflow - note we can't do this before rounding, because
7444 	 * rounding could raise the weight.  Also note that the var's weight could
7445 	 * be inflated by leading zeroes, which will be stripped before storage
7446 	 * but perhaps might not have been yet. In any case, we must recognize a
7447 	 * true zero, whose weight doesn't mean anything.
7448 	 */
7449 	ddigits = (var->weight + 1) * DEC_DIGITS;
7450 	if (ddigits > maxdigits)
7451 	{
7452 		/* Determine true weight; and check for all-zero result */
7453 		for (i = 0; i < var->ndigits; i++)
7454 		{
7455 			NumericDigit dig = var->digits[i];
7456 
7457 			if (dig)
7458 			{
7459 				/* Adjust for any high-order decimal zero digits */
7460 #if DEC_DIGITS == 4
7461 				if (dig < 10)
7462 					ddigits -= 3;
7463 				else if (dig < 100)
7464 					ddigits -= 2;
7465 				else if (dig < 1000)
7466 					ddigits -= 1;
7467 #elif DEC_DIGITS == 2
7468 				if (dig < 10)
7469 					ddigits -= 1;
7470 #elif DEC_DIGITS == 1
7471 				/* no adjustment */
7472 #else
7473 #error unsupported NBASE
7474 #endif
7475 				if (ddigits > maxdigits)
7476 					ereport(ERROR,
7477 							(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
7478 							 errmsg("numeric field overflow"),
7479 							 errdetail("A field with precision %d, scale %d must round to an absolute value less than %s%d.",
7480 									   precision, scale,
7481 					/* Display 10^0 as 1 */
7482 									   maxdigits ? "10^" : "",
7483 									   maxdigits ? maxdigits : 1
7484 									   )));
7485 				break;
7486 			}
7487 			ddigits -= DEC_DIGITS;
7488 		}
7489 	}
7490 }
7491 
7492 /*
7493  * apply_typmod_special() -
7494  *
7495  *	Do bounds checking according to the specified typmod, for an Inf or NaN.
7496  *	For convenience of most callers, the value is presented in packed form.
7497  */
7498 static void
apply_typmod_special(Numeric num,int32 typmod)7499 apply_typmod_special(Numeric num, int32 typmod)
7500 {
7501 	int			precision;
7502 	int			scale;
7503 
7504 	Assert(NUMERIC_IS_SPECIAL(num));	/* caller error if not */
7505 
7506 	/*
7507 	 * NaN is allowed regardless of the typmod; that's rather dubious perhaps,
7508 	 * but it's a longstanding behavior.  Inf is rejected if we have any
7509 	 * typmod restriction, since an infinity shouldn't be claimed to fit in
7510 	 * any finite number of digits.
7511 	 */
7512 	if (NUMERIC_IS_NAN(num))
7513 		return;
7514 
7515 	/* Do nothing if we have a default typmod (-1) */
7516 	if (typmod < (int32) (VARHDRSZ))
7517 		return;
7518 
7519 	typmod -= VARHDRSZ;
7520 	precision = (typmod >> 16) & 0xffff;
7521 	scale = typmod & 0xffff;
7522 
7523 	ereport(ERROR,
7524 			(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
7525 			 errmsg("numeric field overflow"),
7526 			 errdetail("A field with precision %d, scale %d cannot hold an infinite value.",
7527 					   precision, scale)));
7528 }
7529 
7530 
7531 /*
7532  * Convert numeric to int8, rounding if needed.
7533  *
7534  * If overflow, return false (no error is raised).  Return true if okay.
7535  */
7536 static bool
numericvar_to_int64(const NumericVar * var,int64 * result)7537 numericvar_to_int64(const NumericVar *var, int64 *result)
7538 {
7539 	NumericDigit *digits;
7540 	int			ndigits;
7541 	int			weight;
7542 	int			i;
7543 	int64		val;
7544 	bool		neg;
7545 	NumericVar	rounded;
7546 
7547 	/* Round to nearest integer */
7548 	init_var(&rounded);
7549 	set_var_from_var(var, &rounded);
7550 	round_var(&rounded, 0);
7551 
7552 	/* Check for zero input */
7553 	strip_var(&rounded);
7554 	ndigits = rounded.ndigits;
7555 	if (ndigits == 0)
7556 	{
7557 		*result = 0;
7558 		free_var(&rounded);
7559 		return true;
7560 	}
7561 
7562 	/*
7563 	 * For input like 10000000000, we must treat stripped digits as real. So
7564 	 * the loop assumes there are weight+1 digits before the decimal point.
7565 	 */
7566 	weight = rounded.weight;
7567 	Assert(weight >= 0 && ndigits <= weight + 1);
7568 
7569 	/*
7570 	 * Construct the result. To avoid issues with converting a value
7571 	 * corresponding to INT64_MIN (which can't be represented as a positive 64
7572 	 * bit two's complement integer), accumulate value as a negative number.
7573 	 */
7574 	digits = rounded.digits;
7575 	neg = (rounded.sign == NUMERIC_NEG);
7576 	val = -digits[0];
7577 	for (i = 1; i <= weight; i++)
7578 	{
7579 		if (unlikely(pg_mul_s64_overflow(val, NBASE, &val)))
7580 		{
7581 			free_var(&rounded);
7582 			return false;
7583 		}
7584 
7585 		if (i < ndigits)
7586 		{
7587 			if (unlikely(pg_sub_s64_overflow(val, digits[i], &val)))
7588 			{
7589 				free_var(&rounded);
7590 				return false;
7591 			}
7592 		}
7593 	}
7594 
7595 	free_var(&rounded);
7596 
7597 	if (!neg)
7598 	{
7599 		if (unlikely(val == PG_INT64_MIN))
7600 			return false;
7601 		val = -val;
7602 	}
7603 	*result = val;
7604 
7605 	return true;
7606 }
7607 
7608 /*
7609  * Convert int8 value to numeric.
7610  */
7611 static void
int64_to_numericvar(int64 val,NumericVar * var)7612 int64_to_numericvar(int64 val, NumericVar *var)
7613 {
7614 	uint64		uval,
7615 				newuval;
7616 	NumericDigit *ptr;
7617 	int			ndigits;
7618 
7619 	/* int64 can require at most 19 decimal digits; add one for safety */
7620 	alloc_var(var, 20 / DEC_DIGITS);
7621 	if (val < 0)
7622 	{
7623 		var->sign = NUMERIC_NEG;
7624 		uval = -val;
7625 	}
7626 	else
7627 	{
7628 		var->sign = NUMERIC_POS;
7629 		uval = val;
7630 	}
7631 	var->dscale = 0;
7632 	if (val == 0)
7633 	{
7634 		var->ndigits = 0;
7635 		var->weight = 0;
7636 		return;
7637 	}
7638 	ptr = var->digits + var->ndigits;
7639 	ndigits = 0;
7640 	do
7641 	{
7642 		ptr--;
7643 		ndigits++;
7644 		newuval = uval / NBASE;
7645 		*ptr = uval - newuval * NBASE;
7646 		uval = newuval;
7647 	} while (uval);
7648 	var->digits = ptr;
7649 	var->ndigits = ndigits;
7650 	var->weight = ndigits - 1;
7651 }
7652 
7653 /*
7654  * Convert numeric to uint64, rounding if needed.
7655  *
7656  * If overflow, return false (no error is raised).  Return true if okay.
7657  */
7658 static bool
numericvar_to_uint64(const NumericVar * var,uint64 * result)7659 numericvar_to_uint64(const NumericVar *var, uint64 *result)
7660 {
7661 	NumericDigit *digits;
7662 	int			ndigits;
7663 	int			weight;
7664 	int			i;
7665 	uint64		val;
7666 	NumericVar	rounded;
7667 
7668 	/* Round to nearest integer */
7669 	init_var(&rounded);
7670 	set_var_from_var(var, &rounded);
7671 	round_var(&rounded, 0);
7672 
7673 	/* Check for zero input */
7674 	strip_var(&rounded);
7675 	ndigits = rounded.ndigits;
7676 	if (ndigits == 0)
7677 	{
7678 		*result = 0;
7679 		free_var(&rounded);
7680 		return true;
7681 	}
7682 
7683 	/* Check for negative input */
7684 	if (rounded.sign == NUMERIC_NEG)
7685 	{
7686 		free_var(&rounded);
7687 		return false;
7688 	}
7689 
7690 	/*
7691 	 * For input like 10000000000, we must treat stripped digits as real. So
7692 	 * the loop assumes there are weight+1 digits before the decimal point.
7693 	 */
7694 	weight = rounded.weight;
7695 	Assert(weight >= 0 && ndigits <= weight + 1);
7696 
7697 	/* Construct the result */
7698 	digits = rounded.digits;
7699 	val = digits[0];
7700 	for (i = 1; i <= weight; i++)
7701 	{
7702 		if (unlikely(pg_mul_u64_overflow(val, NBASE, &val)))
7703 		{
7704 			free_var(&rounded);
7705 			return false;
7706 		}
7707 
7708 		if (i < ndigits)
7709 		{
7710 			if (unlikely(pg_add_u64_overflow(val, digits[i], &val)))
7711 			{
7712 				free_var(&rounded);
7713 				return false;
7714 			}
7715 		}
7716 	}
7717 
7718 	free_var(&rounded);
7719 
7720 	*result = val;
7721 
7722 	return true;
7723 }
7724 
7725 #ifdef HAVE_INT128
7726 /*
7727  * Convert numeric to int128, rounding if needed.
7728  *
7729  * If overflow, return false (no error is raised).  Return true if okay.
7730  */
7731 static bool
numericvar_to_int128(const NumericVar * var,int128 * result)7732 numericvar_to_int128(const NumericVar *var, int128 *result)
7733 {
7734 	NumericDigit *digits;
7735 	int			ndigits;
7736 	int			weight;
7737 	int			i;
7738 	int128		val,
7739 				oldval;
7740 	bool		neg;
7741 	NumericVar	rounded;
7742 
7743 	/* Round to nearest integer */
7744 	init_var(&rounded);
7745 	set_var_from_var(var, &rounded);
7746 	round_var(&rounded, 0);
7747 
7748 	/* Check for zero input */
7749 	strip_var(&rounded);
7750 	ndigits = rounded.ndigits;
7751 	if (ndigits == 0)
7752 	{
7753 		*result = 0;
7754 		free_var(&rounded);
7755 		return true;
7756 	}
7757 
7758 	/*
7759 	 * For input like 10000000000, we must treat stripped digits as real. So
7760 	 * the loop assumes there are weight+1 digits before the decimal point.
7761 	 */
7762 	weight = rounded.weight;
7763 	Assert(weight >= 0 && ndigits <= weight + 1);
7764 
7765 	/* Construct the result */
7766 	digits = rounded.digits;
7767 	neg = (rounded.sign == NUMERIC_NEG);
7768 	val = digits[0];
7769 	for (i = 1; i <= weight; i++)
7770 	{
7771 		oldval = val;
7772 		val *= NBASE;
7773 		if (i < ndigits)
7774 			val += digits[i];
7775 
7776 		/*
7777 		 * The overflow check is a bit tricky because we want to accept
7778 		 * INT128_MIN, which will overflow the positive accumulator.  We can
7779 		 * detect this case easily though because INT128_MIN is the only
7780 		 * nonzero value for which -val == val (on a two's complement machine,
7781 		 * anyway).
7782 		 */
7783 		if ((val / NBASE) != oldval)	/* possible overflow? */
7784 		{
7785 			if (!neg || (-val) != val || val == 0 || oldval < 0)
7786 			{
7787 				free_var(&rounded);
7788 				return false;
7789 			}
7790 		}
7791 	}
7792 
7793 	free_var(&rounded);
7794 
7795 	*result = neg ? -val : val;
7796 	return true;
7797 }
7798 
7799 /*
7800  * Convert 128 bit integer to numeric.
7801  */
7802 static void
int128_to_numericvar(int128 val,NumericVar * var)7803 int128_to_numericvar(int128 val, NumericVar *var)
7804 {
7805 	uint128		uval,
7806 				newuval;
7807 	NumericDigit *ptr;
7808 	int			ndigits;
7809 
7810 	/* int128 can require at most 39 decimal digits; add one for safety */
7811 	alloc_var(var, 40 / DEC_DIGITS);
7812 	if (val < 0)
7813 	{
7814 		var->sign = NUMERIC_NEG;
7815 		uval = -val;
7816 	}
7817 	else
7818 	{
7819 		var->sign = NUMERIC_POS;
7820 		uval = val;
7821 	}
7822 	var->dscale = 0;
7823 	if (val == 0)
7824 	{
7825 		var->ndigits = 0;
7826 		var->weight = 0;
7827 		return;
7828 	}
7829 	ptr = var->digits + var->ndigits;
7830 	ndigits = 0;
7831 	do
7832 	{
7833 		ptr--;
7834 		ndigits++;
7835 		newuval = uval / NBASE;
7836 		*ptr = uval - newuval * NBASE;
7837 		uval = newuval;
7838 	} while (uval);
7839 	var->digits = ptr;
7840 	var->ndigits = ndigits;
7841 	var->weight = ndigits - 1;
7842 }
7843 #endif
7844 
7845 /*
7846  * Convert a NumericVar to float8; if out of range, return +/- HUGE_VAL
7847  */
7848 static double
numericvar_to_double_no_overflow(const NumericVar * var)7849 numericvar_to_double_no_overflow(const NumericVar *var)
7850 {
7851 	char	   *tmp;
7852 	double		val;
7853 	char	   *endptr;
7854 
7855 	tmp = get_str_from_var(var);
7856 
7857 	/* unlike float8in, we ignore ERANGE from strtod */
7858 	val = strtod(tmp, &endptr);
7859 	if (*endptr != '\0')
7860 	{
7861 		/* shouldn't happen ... */
7862 		ereport(ERROR,
7863 				(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
7864 				 errmsg("invalid input syntax for type %s: \"%s\"",
7865 						"double precision", tmp)));
7866 	}
7867 
7868 	pfree(tmp);
7869 
7870 	return val;
7871 }
7872 
7873 
7874 /*
7875  * cmp_var() -
7876  *
7877  *	Compare two values on variable level.  We assume zeroes have been
7878  *	truncated to no digits.
7879  */
7880 static int
cmp_var(const NumericVar * var1,const NumericVar * var2)7881 cmp_var(const NumericVar *var1, const NumericVar *var2)
7882 {
7883 	return cmp_var_common(var1->digits, var1->ndigits,
7884 						  var1->weight, var1->sign,
7885 						  var2->digits, var2->ndigits,
7886 						  var2->weight, var2->sign);
7887 }
7888 
7889 /*
7890  * cmp_var_common() -
7891  *
7892  *	Main routine of cmp_var(). This function can be used by both
7893  *	NumericVar and Numeric.
7894  */
7895 static int
cmp_var_common(const NumericDigit * var1digits,int var1ndigits,int var1weight,int var1sign,const NumericDigit * var2digits,int var2ndigits,int var2weight,int var2sign)7896 cmp_var_common(const NumericDigit *var1digits, int var1ndigits,
7897 			   int var1weight, int var1sign,
7898 			   const NumericDigit *var2digits, int var2ndigits,
7899 			   int var2weight, int var2sign)
7900 {
7901 	if (var1ndigits == 0)
7902 	{
7903 		if (var2ndigits == 0)
7904 			return 0;
7905 		if (var2sign == NUMERIC_NEG)
7906 			return 1;
7907 		return -1;
7908 	}
7909 	if (var2ndigits == 0)
7910 	{
7911 		if (var1sign == NUMERIC_POS)
7912 			return 1;
7913 		return -1;
7914 	}
7915 
7916 	if (var1sign == NUMERIC_POS)
7917 	{
7918 		if (var2sign == NUMERIC_NEG)
7919 			return 1;
7920 		return cmp_abs_common(var1digits, var1ndigits, var1weight,
7921 							  var2digits, var2ndigits, var2weight);
7922 	}
7923 
7924 	if (var2sign == NUMERIC_POS)
7925 		return -1;
7926 
7927 	return cmp_abs_common(var2digits, var2ndigits, var2weight,
7928 						  var1digits, var1ndigits, var1weight);
7929 }
7930 
7931 
7932 /*
7933  * add_var() -
7934  *
7935  *	Full version of add functionality on variable level (handling signs).
7936  *	result might point to one of the operands too without danger.
7937  */
7938 static void
add_var(const NumericVar * var1,const NumericVar * var2,NumericVar * result)7939 add_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result)
7940 {
7941 	/*
7942 	 * Decide on the signs of the two variables what to do
7943 	 */
7944 	if (var1->sign == NUMERIC_POS)
7945 	{
7946 		if (var2->sign == NUMERIC_POS)
7947 		{
7948 			/*
7949 			 * Both are positive result = +(ABS(var1) + ABS(var2))
7950 			 */
7951 			add_abs(var1, var2, result);
7952 			result->sign = NUMERIC_POS;
7953 		}
7954 		else
7955 		{
7956 			/*
7957 			 * var1 is positive, var2 is negative Must compare absolute values
7958 			 */
7959 			switch (cmp_abs(var1, var2))
7960 			{
7961 				case 0:
7962 					/* ----------
7963 					 * ABS(var1) == ABS(var2)
7964 					 * result = ZERO
7965 					 * ----------
7966 					 */
7967 					zero_var(result);
7968 					result->dscale = Max(var1->dscale, var2->dscale);
7969 					break;
7970 
7971 				case 1:
7972 					/* ----------
7973 					 * ABS(var1) > ABS(var2)
7974 					 * result = +(ABS(var1) - ABS(var2))
7975 					 * ----------
7976 					 */
7977 					sub_abs(var1, var2, result);
7978 					result->sign = NUMERIC_POS;
7979 					break;
7980 
7981 				case -1:
7982 					/* ----------
7983 					 * ABS(var1) < ABS(var2)
7984 					 * result = -(ABS(var2) - ABS(var1))
7985 					 * ----------
7986 					 */
7987 					sub_abs(var2, var1, result);
7988 					result->sign = NUMERIC_NEG;
7989 					break;
7990 			}
7991 		}
7992 	}
7993 	else
7994 	{
7995 		if (var2->sign == NUMERIC_POS)
7996 		{
7997 			/* ----------
7998 			 * var1 is negative, var2 is positive
7999 			 * Must compare absolute values
8000 			 * ----------
8001 			 */
8002 			switch (cmp_abs(var1, var2))
8003 			{
8004 				case 0:
8005 					/* ----------
8006 					 * ABS(var1) == ABS(var2)
8007 					 * result = ZERO
8008 					 * ----------
8009 					 */
8010 					zero_var(result);
8011 					result->dscale = Max(var1->dscale, var2->dscale);
8012 					break;
8013 
8014 				case 1:
8015 					/* ----------
8016 					 * ABS(var1) > ABS(var2)
8017 					 * result = -(ABS(var1) - ABS(var2))
8018 					 * ----------
8019 					 */
8020 					sub_abs(var1, var2, result);
8021 					result->sign = NUMERIC_NEG;
8022 					break;
8023 
8024 				case -1:
8025 					/* ----------
8026 					 * ABS(var1) < ABS(var2)
8027 					 * result = +(ABS(var2) - ABS(var1))
8028 					 * ----------
8029 					 */
8030 					sub_abs(var2, var1, result);
8031 					result->sign = NUMERIC_POS;
8032 					break;
8033 			}
8034 		}
8035 		else
8036 		{
8037 			/* ----------
8038 			 * Both are negative
8039 			 * result = -(ABS(var1) + ABS(var2))
8040 			 * ----------
8041 			 */
8042 			add_abs(var1, var2, result);
8043 			result->sign = NUMERIC_NEG;
8044 		}
8045 	}
8046 }
8047 
8048 
8049 /*
8050  * sub_var() -
8051  *
8052  *	Full version of sub functionality on variable level (handling signs).
8053  *	result might point to one of the operands too without danger.
8054  */
8055 static void
sub_var(const NumericVar * var1,const NumericVar * var2,NumericVar * result)8056 sub_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result)
8057 {
8058 	/*
8059 	 * Decide on the signs of the two variables what to do
8060 	 */
8061 	if (var1->sign == NUMERIC_POS)
8062 	{
8063 		if (var2->sign == NUMERIC_NEG)
8064 		{
8065 			/* ----------
8066 			 * var1 is positive, var2 is negative
8067 			 * result = +(ABS(var1) + ABS(var2))
8068 			 * ----------
8069 			 */
8070 			add_abs(var1, var2, result);
8071 			result->sign = NUMERIC_POS;
8072 		}
8073 		else
8074 		{
8075 			/* ----------
8076 			 * Both are positive
8077 			 * Must compare absolute values
8078 			 * ----------
8079 			 */
8080 			switch (cmp_abs(var1, var2))
8081 			{
8082 				case 0:
8083 					/* ----------
8084 					 * ABS(var1) == ABS(var2)
8085 					 * result = ZERO
8086 					 * ----------
8087 					 */
8088 					zero_var(result);
8089 					result->dscale = Max(var1->dscale, var2->dscale);
8090 					break;
8091 
8092 				case 1:
8093 					/* ----------
8094 					 * ABS(var1) > ABS(var2)
8095 					 * result = +(ABS(var1) - ABS(var2))
8096 					 * ----------
8097 					 */
8098 					sub_abs(var1, var2, result);
8099 					result->sign = NUMERIC_POS;
8100 					break;
8101 
8102 				case -1:
8103 					/* ----------
8104 					 * ABS(var1) < ABS(var2)
8105 					 * result = -(ABS(var2) - ABS(var1))
8106 					 * ----------
8107 					 */
8108 					sub_abs(var2, var1, result);
8109 					result->sign = NUMERIC_NEG;
8110 					break;
8111 			}
8112 		}
8113 	}
8114 	else
8115 	{
8116 		if (var2->sign == NUMERIC_NEG)
8117 		{
8118 			/* ----------
8119 			 * Both are negative
8120 			 * Must compare absolute values
8121 			 * ----------
8122 			 */
8123 			switch (cmp_abs(var1, var2))
8124 			{
8125 				case 0:
8126 					/* ----------
8127 					 * ABS(var1) == ABS(var2)
8128 					 * result = ZERO
8129 					 * ----------
8130 					 */
8131 					zero_var(result);
8132 					result->dscale = Max(var1->dscale, var2->dscale);
8133 					break;
8134 
8135 				case 1:
8136 					/* ----------
8137 					 * ABS(var1) > ABS(var2)
8138 					 * result = -(ABS(var1) - ABS(var2))
8139 					 * ----------
8140 					 */
8141 					sub_abs(var1, var2, result);
8142 					result->sign = NUMERIC_NEG;
8143 					break;
8144 
8145 				case -1:
8146 					/* ----------
8147 					 * ABS(var1) < ABS(var2)
8148 					 * result = +(ABS(var2) - ABS(var1))
8149 					 * ----------
8150 					 */
8151 					sub_abs(var2, var1, result);
8152 					result->sign = NUMERIC_POS;
8153 					break;
8154 			}
8155 		}
8156 		else
8157 		{
8158 			/* ----------
8159 			 * var1 is negative, var2 is positive
8160 			 * result = -(ABS(var1) + ABS(var2))
8161 			 * ----------
8162 			 */
8163 			add_abs(var1, var2, result);
8164 			result->sign = NUMERIC_NEG;
8165 		}
8166 	}
8167 }
8168 
8169 
/*
 * mul_var() -
 *
 *	Multiplication on variable level. Product of var1 * var2 is stored
 *	in result.  Result is rounded to no more than rscale fractional digits.
 *
 *	Method: schoolbook multiplication into an int accumulator array "dig[]",
 *	with carry propagation deferred until overflow threatens.  Input digits
 *	that could only affect positions beyond rscale plus MUL_GUARD_DIGITS
 *	guard digits are skipped, so only the needed precision is computed.
 *	The product is then rounded with round_var() and stripped with
 *	strip_var().
 *
 *	If either input has no digits, or every input digit would be ignored,
 *	result is set to zero with dscale = rscale.
 */
static void
mul_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result,
		int rscale)
{
	int			res_ndigits;
	int			res_sign;
	int			res_weight;
	int			maxdigits;
	int		   *dig;
	int			carry;
	int			maxdig;
	int			newdig;
	int			var1ndigits;
	int			var2ndigits;
	NumericDigit *var1digits;
	NumericDigit *var2digits;
	NumericDigit *res_digits;
	int			i,
				i1,
				i2;

	/*
	 * Arrange for var1 to be the shorter of the two numbers.  This improves
	 * performance because the inner multiplication loop is much simpler than
	 * the outer loop, so it's better to have a smaller number of iterations
	 * of the outer loop.  This also reduces the number of times that the
	 * accumulator array needs to be normalized.
	 */
	if (var1->ndigits > var2->ndigits)
	{
		const NumericVar *tmp = var1;

		var1 = var2;
		var2 = tmp;
	}

	/* copy these values into local vars for speed in inner loop */
	var1ndigits = var1->ndigits;
	var2ndigits = var2->ndigits;
	var1digits = var1->digits;
	var2digits = var2->digits;

	if (var1ndigits == 0 || var2ndigits == 0)
	{
		/* one or both inputs is zero; so is result */
		zero_var(result);
		result->dscale = rscale;
		return;
	}

	/*
	 * Determine result sign and (maximum possible) weight.  The "+ 2" matches
	 * the accumulator layout used below: the product of var1 digit i1 and
	 * var2 digit i2 is added into dig[i1 + i2 + 2], so dig[0] and dig[1]
	 * are positions above the highest digit product, available for carries.
	 */
	if (var1->sign == var2->sign)
		res_sign = NUMERIC_POS;
	else
		res_sign = NUMERIC_NEG;
	res_weight = var1->weight + var2->weight + 2;

	/*
	 * Determine the number of result digits to compute.  If the exact result
	 * would have more than rscale fractional digits, truncate the computation
	 * with MUL_GUARD_DIGITS guard digits, i.e., ignore input digits that
	 * would only contribute to the right of that.  (This will give the exact
	 * rounded-to-rscale answer unless carries out of the ignored positions
	 * would have propagated through more than MUL_GUARD_DIGITS digits.)
	 *
	 * Note: an exact computation could not produce more than var1ndigits +
	 * var2ndigits digits, but we allocate one extra output digit in case
	 * rscale-driven rounding produces a carry out of the highest exact digit.
	 */
	res_ndigits = var1ndigits + var2ndigits + 1;
	maxdigits = res_weight + 1 + (rscale + DEC_DIGITS - 1) / DEC_DIGITS +
		MUL_GUARD_DIGITS;
	res_ndigits = Min(res_ndigits, maxdigits);

	if (res_ndigits < 3)
	{
		/* All input digits will be ignored; so result is zero */
		zero_var(result);
		result->dscale = rscale;
		return;
	}

	/*
	 * We do the arithmetic in an array "dig[]" of signed ints.  Since
	 * INT_MAX is noticeably larger than NBASE*NBASE, this gives us headroom
	 * to avoid normalizing carries immediately.
	 *
	 * maxdig tracks the maximum possible value of any dig[] entry; when this
	 * threatens to exceed INT_MAX, we take the time to propagate carries.
	 * Furthermore, we need to ensure that overflow doesn't occur during the
	 * carry propagation passes either.  The carry values could be as much as
	 * INT_MAX/NBASE, so really we must normalize when digits threaten to
	 * exceed INT_MAX - INT_MAX/NBASE.
	 *
	 * To avoid overflow in maxdig itself, it actually represents the max
	 * possible value divided by NBASE-1, ie, at the top of the loop it is
	 * known that no dig[] entry exceeds maxdig * (NBASE-1).
	 */
	dig = (int *) palloc0(res_ndigits * sizeof(int));
	maxdig = 0;

	/*
	 * The least significant digits of var1 should be ignored if they don't
	 * contribute directly to the first res_ndigits digits of the result that
	 * we are computing.
	 *
	 * Digit i1 of var1 and digit i2 of var2 are multiplied and added to digit
	 * i1+i2+2 of the accumulator array, so we need only consider digits of
	 * var1 for which i1 <= res_ndigits - 3.
	 */
	for (i1 = Min(var1ndigits - 1, res_ndigits - 3); i1 >= 0; i1--)
	{
		int			var1digit = var1digits[i1];

		/* a zero digit adds nothing and doesn't raise the worst case */
		if (var1digit == 0)
			continue;

		/* Time to normalize? */
		maxdig += var1digit;
		if (maxdig > (INT_MAX - INT_MAX / NBASE) / (NBASE - 1))
		{
			/* Yes, do it */
			carry = 0;
			for (i = res_ndigits - 1; i >= 0; i--)
			{
				newdig = dig[i] + carry;
				if (newdig >= NBASE)
				{
					carry = newdig / NBASE;
					newdig -= carry * NBASE;
				}
				else
					carry = 0;
				dig[i] = newdig;
			}
			Assert(carry == 0);
			/*
			 * Reset maxdig to indicate new worst-case: after normalization
			 * every entry is at most NBASE-1 (i.e. 1 unit of maxdig), plus the
			 * contribution of the var1digit we are about to add.
			 */
			maxdig = 1 + var1digit;
		}

		/*
		 * Add the appropriate multiple of var2 into the accumulator.
		 *
		 * As above, digits of var2 can be ignored if they don't contribute,
		 * so we only include digits for which i1+i2+2 < res_ndigits.
		 *
		 * This inner loop is the performance bottleneck for multiplication,
		 * so we want to keep it simple enough so that it can be
		 * auto-vectorized.  Accordingly, process the digits left-to-right
		 * even though schoolbook multiplication would suggest right-to-left.
		 * Since we aren't propagating carries in this loop, the order does
		 * not matter.
		 */
		{
			int			i2limit = Min(var2ndigits, res_ndigits - i1 - 2);
			int		   *dig_i1_2 = &dig[i1 + 2];

			for (i2 = 0; i2 < i2limit; i2++)
				dig_i1_2[i2] += var1digit * var2digits[i2];
		}
	}

	/*
	 * Now we do a final carry propagation pass to normalize the result, which
	 * we combine with storing the result digits into the output. Note that
	 * this is still done at full precision w/guard digits.
	 */
	alloc_var(result, res_ndigits);
	res_digits = result->digits;
	carry = 0;
	for (i = res_ndigits - 1; i >= 0; i--)
	{
		newdig = dig[i] + carry;
		if (newdig >= NBASE)
		{
			carry = newdig / NBASE;
			newdig -= carry * NBASE;
		}
		else
			carry = 0;
		res_digits[i] = newdig;
	}
	Assert(carry == 0);

	pfree(dig);

	/*
	 * Finally, round the result to the requested precision.
	 */
	result->weight = res_weight;
	result->sign = res_sign;

	/* Round to target rscale (and set result->dscale) */
	round_var(result, rscale);

	/* Strip leading and trailing zeroes */
	strip_var(result);
}
8374 
8375 
8376 /*
8377  * div_var() -
8378  *
8379  *	Division on variable level. Quotient of var1 / var2 is stored in result.
8380  *	The quotient is figured to exactly rscale fractional digits.
8381  *	If round is true, it is rounded at the rscale'th digit; if false, it
8382  *	is truncated (towards zero) at that digit.
8383  */
8384 static void
div_var(const NumericVar * var1,const NumericVar * var2,NumericVar * result,int rscale,bool round)8385 div_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result,
8386 		int rscale, bool round)
8387 {
8388 	int			div_ndigits;
8389 	int			res_ndigits;
8390 	int			res_sign;
8391 	int			res_weight;
8392 	int			carry;
8393 	int			borrow;
8394 	int			divisor1;
8395 	int			divisor2;
8396 	NumericDigit *dividend;
8397 	NumericDigit *divisor;
8398 	NumericDigit *res_digits;
8399 	int			i;
8400 	int			j;
8401 
8402 	/* copy these values into local vars for speed in inner loop */
8403 	int			var1ndigits = var1->ndigits;
8404 	int			var2ndigits = var2->ndigits;
8405 
8406 	/*
8407 	 * First of all division by zero check; we must not be handed an
8408 	 * unnormalized divisor.
8409 	 */
8410 	if (var2ndigits == 0 || var2->digits[0] == 0)
8411 		ereport(ERROR,
8412 				(errcode(ERRCODE_DIVISION_BY_ZERO),
8413 				 errmsg("division by zero")));
8414 
8415 	/*
8416 	 * Now result zero check
8417 	 */
8418 	if (var1ndigits == 0)
8419 	{
8420 		zero_var(result);
8421 		result->dscale = rscale;
8422 		return;
8423 	}
8424 
8425 	/*
8426 	 * Determine the result sign, weight and number of digits to calculate.
8427 	 * The weight figured here is correct if the emitted quotient has no
8428 	 * leading zero digits; otherwise strip_var() will fix things up.
8429 	 */
8430 	if (var1->sign == var2->sign)
8431 		res_sign = NUMERIC_POS;
8432 	else
8433 		res_sign = NUMERIC_NEG;
8434 	res_weight = var1->weight - var2->weight;
8435 	/* The number of accurate result digits we need to produce: */
8436 	res_ndigits = res_weight + 1 + (rscale + DEC_DIGITS - 1) / DEC_DIGITS;
8437 	/* ... but always at least 1 */
8438 	res_ndigits = Max(res_ndigits, 1);
8439 	/* If rounding needed, figure one more digit to ensure correct result */
8440 	if (round)
8441 		res_ndigits++;
8442 
8443 	/*
8444 	 * The working dividend normally requires res_ndigits + var2ndigits
8445 	 * digits, but make it at least var1ndigits so we can load all of var1
8446 	 * into it.  (There will be an additional digit dividend[0] in the
8447 	 * dividend space, but for consistency with Knuth's notation we don't
8448 	 * count that in div_ndigits.)
8449 	 */
8450 	div_ndigits = res_ndigits + var2ndigits;
8451 	div_ndigits = Max(div_ndigits, var1ndigits);
8452 
8453 	/*
8454 	 * We need a workspace with room for the working dividend (div_ndigits+1
8455 	 * digits) plus room for the possibly-normalized divisor (var2ndigits
8456 	 * digits).  It is convenient also to have a zero at divisor[0] with the
8457 	 * actual divisor data in divisor[1 .. var2ndigits].  Transferring the
8458 	 * digits into the workspace also allows us to realloc the result (which
8459 	 * might be the same as either input var) before we begin the main loop.
8460 	 * Note that we use palloc0 to ensure that divisor[0], dividend[0], and
8461 	 * any additional dividend positions beyond var1ndigits, start out 0.
8462 	 */
8463 	dividend = (NumericDigit *)
8464 		palloc0((div_ndigits + var2ndigits + 2) * sizeof(NumericDigit));
8465 	divisor = dividend + (div_ndigits + 1);
8466 	memcpy(dividend + 1, var1->digits, var1ndigits * sizeof(NumericDigit));
8467 	memcpy(divisor + 1, var2->digits, var2ndigits * sizeof(NumericDigit));
8468 
8469 	/*
8470 	 * Now we can realloc the result to hold the generated quotient digits.
8471 	 */
8472 	alloc_var(result, res_ndigits);
8473 	res_digits = result->digits;
8474 
8475 	if (var2ndigits == 1)
8476 	{
8477 		/*
8478 		 * If there's only a single divisor digit, we can use a fast path (cf.
8479 		 * Knuth section 4.3.1 exercise 16).
8480 		 */
8481 		divisor1 = divisor[1];
8482 		carry = 0;
8483 		for (i = 0; i < res_ndigits; i++)
8484 		{
8485 			carry = carry * NBASE + dividend[i + 1];
8486 			res_digits[i] = carry / divisor1;
8487 			carry = carry % divisor1;
8488 		}
8489 	}
8490 	else
8491 	{
8492 		/*
8493 		 * The full multiple-place algorithm is taken from Knuth volume 2,
8494 		 * Algorithm 4.3.1D.
8495 		 *
8496 		 * We need the first divisor digit to be >= NBASE/2.  If it isn't,
8497 		 * make it so by scaling up both the divisor and dividend by the
8498 		 * factor "d".  (The reason for allocating dividend[0] above is to
8499 		 * leave room for possible carry here.)
8500 		 */
8501 		if (divisor[1] < HALF_NBASE)
8502 		{
8503 			int			d = NBASE / (divisor[1] + 1);
8504 
8505 			carry = 0;
8506 			for (i = var2ndigits; i > 0; i--)
8507 			{
8508 				carry += divisor[i] * d;
8509 				divisor[i] = carry % NBASE;
8510 				carry = carry / NBASE;
8511 			}
8512 			Assert(carry == 0);
8513 			carry = 0;
8514 			/* at this point only var1ndigits of dividend can be nonzero */
8515 			for (i = var1ndigits; i >= 0; i--)
8516 			{
8517 				carry += dividend[i] * d;
8518 				dividend[i] = carry % NBASE;
8519 				carry = carry / NBASE;
8520 			}
8521 			Assert(carry == 0);
8522 			Assert(divisor[1] >= HALF_NBASE);
8523 		}
8524 		/* First 2 divisor digits are used repeatedly in main loop */
8525 		divisor1 = divisor[1];
8526 		divisor2 = divisor[2];
8527 
8528 		/*
8529 		 * Begin the main loop.  Each iteration of this loop produces the j'th
8530 		 * quotient digit by dividing dividend[j .. j + var2ndigits] by the
8531 		 * divisor; this is essentially the same as the common manual
8532 		 * procedure for long division.
8533 		 */
8534 		for (j = 0; j < res_ndigits; j++)
8535 		{
8536 			/* Estimate quotient digit from the first two dividend digits */
8537 			int			next2digits = dividend[j] * NBASE + dividend[j + 1];
8538 			int			qhat;
8539 
8540 			/*
8541 			 * If next2digits are 0, then quotient digit must be 0 and there's
8542 			 * no need to adjust the working dividend.  It's worth testing
8543 			 * here to fall out ASAP when processing trailing zeroes in a
8544 			 * dividend.
8545 			 */
8546 			if (next2digits == 0)
8547 			{
8548 				res_digits[j] = 0;
8549 				continue;
8550 			}
8551 
8552 			if (dividend[j] == divisor1)
8553 				qhat = NBASE - 1;
8554 			else
8555 				qhat = next2digits / divisor1;
8556 
8557 			/*
8558 			 * Adjust quotient digit if it's too large.  Knuth proves that
8559 			 * after this step, the quotient digit will be either correct or
8560 			 * just one too large.  (Note: it's OK to use dividend[j+2] here
8561 			 * because we know the divisor length is at least 2.)
8562 			 */
8563 			while (divisor2 * qhat >
8564 				   (next2digits - qhat * divisor1) * NBASE + dividend[j + 2])
8565 				qhat--;
8566 
8567 			/* As above, need do nothing more when quotient digit is 0 */
8568 			if (qhat > 0)
8569 			{
8570 				/*
8571 				 * Multiply the divisor by qhat, and subtract that from the
8572 				 * working dividend.  "carry" tracks the multiplication,
8573 				 * "borrow" the subtraction (could we fold these together?)
8574 				 */
8575 				carry = 0;
8576 				borrow = 0;
8577 				for (i = var2ndigits; i >= 0; i--)
8578 				{
8579 					carry += divisor[i] * qhat;
8580 					borrow -= carry % NBASE;
8581 					carry = carry / NBASE;
8582 					borrow += dividend[j + i];
8583 					if (borrow < 0)
8584 					{
8585 						dividend[j + i] = borrow + NBASE;
8586 						borrow = -1;
8587 					}
8588 					else
8589 					{
8590 						dividend[j + i] = borrow;
8591 						borrow = 0;
8592 					}
8593 				}
8594 				Assert(carry == 0);
8595 
8596 				/*
8597 				 * If we got a borrow out of the top dividend digit, then
8598 				 * indeed qhat was one too large.  Fix it, and add back the
8599 				 * divisor to correct the working dividend.  (Knuth proves
8600 				 * that this will occur only about 3/NBASE of the time; hence,
8601 				 * it's a good idea to test this code with small NBASE to be
8602 				 * sure this section gets exercised.)
8603 				 */
8604 				if (borrow)
8605 				{
8606 					qhat--;
8607 					carry = 0;
8608 					for (i = var2ndigits; i >= 0; i--)
8609 					{
8610 						carry += dividend[j + i] + divisor[i];
8611 						if (carry >= NBASE)
8612 						{
8613 							dividend[j + i] = carry - NBASE;
8614 							carry = 1;
8615 						}
8616 						else
8617 						{
8618 							dividend[j + i] = carry;
8619 							carry = 0;
8620 						}
8621 					}
8622 					/* A carry should occur here to cancel the borrow above */
8623 					Assert(carry == 1);
8624 				}
8625 			}
8626 
8627 			/* And we're done with this quotient digit */
8628 			res_digits[j] = qhat;
8629 		}
8630 	}
8631 
8632 	pfree(dividend);
8633 
8634 	/*
8635 	 * Finally, round or truncate the result to the requested precision.
8636 	 */
8637 	result->weight = res_weight;
8638 	result->sign = res_sign;
8639 
8640 	/* Round or truncate to target rscale (and set result->dscale) */
8641 	if (round)
8642 		round_var(result, rscale);
8643 	else
8644 		trunc_var(result, rscale);
8645 
8646 	/* Strip leading and trailing zeroes */
8647 	strip_var(result);
8648 }
8649 
8650 
8651 /*
8652  * div_var_fast() -
8653  *
8654  *	This has the same API as div_var, but is implemented using the division
8655  *	algorithm from the "FM" library, rather than Knuth's schoolbook-division
8656  *	approach.  This is significantly faster but can produce inaccurate
8657  *	results, because it sometimes has to propagate rounding to the left,
8658  *	and so we can never be entirely sure that we know the requested digits
8659  *	exactly.  We compute DIV_GUARD_DIGITS extra digits, but there is
8660  *	no certainty that that's enough.  We use this only in the transcendental
8661  *	function calculation routines, where everything is approximate anyway.
8662  *
8663  *	Although we provide a "round" argument for consistency with div_var,
8664  *	it is unwise to use this function with round=false.  In truncation mode
8665  *	it is possible to get a result with no significant digits, for example
8666  *	with rscale=0 we might compute 0.99999... and truncate that to 0 when
8667  *	the correct answer is 1.
8668  */
8669 static void
div_var_fast(const NumericVar * var1,const NumericVar * var2,NumericVar * result,int rscale,bool round)8670 div_var_fast(const NumericVar *var1, const NumericVar *var2,
8671 			 NumericVar *result, int rscale, bool round)
8672 {
8673 	int			div_ndigits;
8674 	int			load_ndigits;
8675 	int			res_sign;
8676 	int			res_weight;
8677 	int		   *div;
8678 	int			qdigit;
8679 	int			carry;
8680 	int			maxdiv;
8681 	int			newdig;
8682 	NumericDigit *res_digits;
8683 	double		fdividend,
8684 				fdivisor,
8685 				fdivisorinverse,
8686 				fquotient;
8687 	int			qi;
8688 	int			i;
8689 
8690 	/* copy these values into local vars for speed in inner loop */
8691 	int			var1ndigits = var1->ndigits;
8692 	int			var2ndigits = var2->ndigits;
8693 	NumericDigit *var1digits = var1->digits;
8694 	NumericDigit *var2digits = var2->digits;
8695 
8696 	/*
8697 	 * First of all division by zero check; we must not be handed an
8698 	 * unnormalized divisor.
8699 	 */
8700 	if (var2ndigits == 0 || var2digits[0] == 0)
8701 		ereport(ERROR,
8702 				(errcode(ERRCODE_DIVISION_BY_ZERO),
8703 				 errmsg("division by zero")));
8704 
8705 	/*
8706 	 * Now result zero check
8707 	 */
8708 	if (var1ndigits == 0)
8709 	{
8710 		zero_var(result);
8711 		result->dscale = rscale;
8712 		return;
8713 	}
8714 
8715 	/*
8716 	 * Determine the result sign, weight and number of digits to calculate
8717 	 */
8718 	if (var1->sign == var2->sign)
8719 		res_sign = NUMERIC_POS;
8720 	else
8721 		res_sign = NUMERIC_NEG;
8722 	res_weight = var1->weight - var2->weight + 1;
8723 	/* The number of accurate result digits we need to produce: */
8724 	div_ndigits = res_weight + 1 + (rscale + DEC_DIGITS - 1) / DEC_DIGITS;
8725 	/* Add guard digits for roundoff error */
8726 	div_ndigits += DIV_GUARD_DIGITS;
8727 	if (div_ndigits < DIV_GUARD_DIGITS)
8728 		div_ndigits = DIV_GUARD_DIGITS;
8729 
8730 	/*
8731 	 * We do the arithmetic in an array "div[]" of signed int's.  Since
8732 	 * INT_MAX is noticeably larger than NBASE*NBASE, this gives us headroom
8733 	 * to avoid normalizing carries immediately.
8734 	 *
8735 	 * We start with div[] containing one zero digit followed by the
8736 	 * dividend's digits (plus appended zeroes to reach the desired precision
8737 	 * including guard digits).  Each step of the main loop computes an
8738 	 * (approximate) quotient digit and stores it into div[], removing one
8739 	 * position of dividend space.  A final pass of carry propagation takes
8740 	 * care of any mistaken quotient digits.
8741 	 *
8742 	 * Note that div[] doesn't necessarily contain all of the digits from the
8743 	 * dividend --- the desired precision plus guard digits might be less than
8744 	 * the dividend's precision.  This happens, for example, in the square
8745 	 * root algorithm, where we typically divide a 2N-digit number by an
8746 	 * N-digit number, and only require a result with N digits of precision.
8747 	 */
8748 	div = (int *) palloc0((div_ndigits + 1) * sizeof(int));
8749 	load_ndigits = Min(div_ndigits, var1ndigits);
8750 	for (i = 0; i < load_ndigits; i++)
8751 		div[i + 1] = var1digits[i];
8752 
8753 	/*
8754 	 * We estimate each quotient digit using floating-point arithmetic, taking
8755 	 * the first four digits of the (current) dividend and divisor.  This must
8756 	 * be float to avoid overflow.  The quotient digits will generally be off
8757 	 * by no more than one from the exact answer.
8758 	 */
8759 	fdivisor = (double) var2digits[0];
8760 	for (i = 1; i < 4; i++)
8761 	{
8762 		fdivisor *= NBASE;
8763 		if (i < var2ndigits)
8764 			fdivisor += (double) var2digits[i];
8765 	}
8766 	fdivisorinverse = 1.0 / fdivisor;
8767 
8768 	/*
8769 	 * maxdiv tracks the maximum possible absolute value of any div[] entry;
8770 	 * when this threatens to exceed INT_MAX, we take the time to propagate
8771 	 * carries.  Furthermore, we need to ensure that overflow doesn't occur
8772 	 * during the carry propagation passes either.  The carry values may have
8773 	 * an absolute value as high as INT_MAX/NBASE + 1, so really we must
8774 	 * normalize when digits threaten to exceed INT_MAX - INT_MAX/NBASE - 1.
8775 	 *
8776 	 * To avoid overflow in maxdiv itself, it represents the max absolute
8777 	 * value divided by NBASE-1, ie, at the top of the loop it is known that
8778 	 * no div[] entry has an absolute value exceeding maxdiv * (NBASE-1).
8779 	 *
8780 	 * Actually, though, that holds good only for div[] entries after div[qi];
8781 	 * the adjustment done at the bottom of the loop may cause div[qi + 1] to
8782 	 * exceed the maxdiv limit, so that div[qi] in the next iteration is
8783 	 * beyond the limit.  This does not cause problems, as explained below.
8784 	 */
8785 	maxdiv = 1;
8786 
8787 	/*
8788 	 * Outer loop computes next quotient digit, which will go into div[qi]
8789 	 */
8790 	for (qi = 0; qi < div_ndigits; qi++)
8791 	{
8792 		/* Approximate the current dividend value */
8793 		fdividend = (double) div[qi];
8794 		for (i = 1; i < 4; i++)
8795 		{
8796 			fdividend *= NBASE;
8797 			if (qi + i <= div_ndigits)
8798 				fdividend += (double) div[qi + i];
8799 		}
8800 		/* Compute the (approximate) quotient digit */
8801 		fquotient = fdividend * fdivisorinverse;
8802 		qdigit = (fquotient >= 0.0) ? ((int) fquotient) :
8803 			(((int) fquotient) - 1);	/* truncate towards -infinity */
8804 
8805 		if (qdigit != 0)
8806 		{
8807 			/* Do we need to normalize now? */
8808 			maxdiv += Abs(qdigit);
8809 			if (maxdiv > (INT_MAX - INT_MAX / NBASE - 1) / (NBASE - 1))
8810 			{
8811 				/*
8812 				 * Yes, do it.  Note that if var2ndigits is much smaller than
8813 				 * div_ndigits, we can save a significant amount of effort
8814 				 * here by noting that we only need to normalise those div[]
8815 				 * entries touched where prior iterations subtracted multiples
8816 				 * of the divisor.
8817 				 */
8818 				carry = 0;
8819 				for (i = Min(qi + var2ndigits - 2, div_ndigits); i > qi; i--)
8820 				{
8821 					newdig = div[i] + carry;
8822 					if (newdig < 0)
8823 					{
8824 						carry = -((-newdig - 1) / NBASE) - 1;
8825 						newdig -= carry * NBASE;
8826 					}
8827 					else if (newdig >= NBASE)
8828 					{
8829 						carry = newdig / NBASE;
8830 						newdig -= carry * NBASE;
8831 					}
8832 					else
8833 						carry = 0;
8834 					div[i] = newdig;
8835 				}
8836 				newdig = div[qi] + carry;
8837 				div[qi] = newdig;
8838 
8839 				/*
8840 				 * All the div[] digits except possibly div[qi] are now in the
8841 				 * range 0..NBASE-1.  We do not need to consider div[qi] in
8842 				 * the maxdiv value anymore, so we can reset maxdiv to 1.
8843 				 */
8844 				maxdiv = 1;
8845 
8846 				/*
8847 				 * Recompute the quotient digit since new info may have
8848 				 * propagated into the top four dividend digits
8849 				 */
8850 				fdividend = (double) div[qi];
8851 				for (i = 1; i < 4; i++)
8852 				{
8853 					fdividend *= NBASE;
8854 					if (qi + i <= div_ndigits)
8855 						fdividend += (double) div[qi + i];
8856 				}
8857 				/* Compute the (approximate) quotient digit */
8858 				fquotient = fdividend * fdivisorinverse;
8859 				qdigit = (fquotient >= 0.0) ? ((int) fquotient) :
8860 					(((int) fquotient) - 1);	/* truncate towards -infinity */
8861 				maxdiv += Abs(qdigit);
8862 			}
8863 
8864 			/*
8865 			 * Subtract off the appropriate multiple of the divisor.
8866 			 *
8867 			 * The digits beyond div[qi] cannot overflow, because we know they
8868 			 * will fall within the maxdiv limit.  As for div[qi] itself, note
8869 			 * that qdigit is approximately trunc(div[qi] / vardigits[0]),
8870 			 * which would make the new value simply div[qi] mod vardigits[0].
8871 			 * The lower-order terms in qdigit can change this result by not
8872 			 * more than about twice INT_MAX/NBASE, so overflow is impossible.
8873 			 */
8874 			if (qdigit != 0)
8875 			{
8876 				int			istop = Min(var2ndigits, div_ndigits - qi + 1);
8877 
8878 				for (i = 0; i < istop; i++)
8879 					div[qi + i] -= qdigit * var2digits[i];
8880 			}
8881 		}
8882 
8883 		/*
8884 		 * The dividend digit we are about to replace might still be nonzero.
8885 		 * Fold it into the next digit position.
8886 		 *
8887 		 * There is no risk of overflow here, although proving that requires
8888 		 * some care.  Much as with the argument for div[qi] not overflowing,
8889 		 * if we consider the first two terms in the numerator and denominator
8890 		 * of qdigit, we can see that the final value of div[qi + 1] will be
8891 		 * approximately a remainder mod (vardigits[0]*NBASE + vardigits[1]).
8892 		 * Accounting for the lower-order terms is a bit complicated but ends
8893 		 * up adding not much more than INT_MAX/NBASE to the possible range.
8894 		 * Thus, div[qi + 1] cannot overflow here, and in its role as div[qi]
8895 		 * in the next loop iteration, it can't be large enough to cause
8896 		 * overflow in the carry propagation step (if any), either.
8897 		 *
8898 		 * But having said that: div[qi] can be more than INT_MAX/NBASE, as
8899 		 * noted above, which means that the product div[qi] * NBASE *can*
8900 		 * overflow.  When that happens, adding it to div[qi + 1] will always
8901 		 * cause a canceling overflow so that the end result is correct.  We
8902 		 * could avoid the intermediate overflow by doing the multiplication
8903 		 * and addition in int64 arithmetic, but so far there appears no need.
8904 		 */
8905 		div[qi + 1] += div[qi] * NBASE;
8906 
8907 		div[qi] = qdigit;
8908 	}
8909 
8910 	/*
8911 	 * Approximate and store the last quotient digit (div[div_ndigits])
8912 	 */
8913 	fdividend = (double) div[qi];
8914 	for (i = 1; i < 4; i++)
8915 		fdividend *= NBASE;
8916 	fquotient = fdividend * fdivisorinverse;
8917 	qdigit = (fquotient >= 0.0) ? ((int) fquotient) :
8918 		(((int) fquotient) - 1);	/* truncate towards -infinity */
8919 	div[qi] = qdigit;
8920 
8921 	/*
8922 	 * Because the quotient digits might be off by one, some of them might be
8923 	 * -1 or NBASE at this point.  The represented value is correct in a
8924 	 * mathematical sense, but it doesn't look right.  We do a final carry
8925 	 * propagation pass to normalize the digits, which we combine with storing
8926 	 * the result digits into the output.  Note that this is still done at
8927 	 * full precision w/guard digits.
8928 	 */
8929 	alloc_var(result, div_ndigits + 1);
8930 	res_digits = result->digits;
8931 	carry = 0;
8932 	for (i = div_ndigits; i >= 0; i--)
8933 	{
8934 		newdig = div[i] + carry;
8935 		if (newdig < 0)
8936 		{
8937 			carry = -((-newdig - 1) / NBASE) - 1;
8938 			newdig -= carry * NBASE;
8939 		}
8940 		else if (newdig >= NBASE)
8941 		{
8942 			carry = newdig / NBASE;
8943 			newdig -= carry * NBASE;
8944 		}
8945 		else
8946 			carry = 0;
8947 		res_digits[i] = newdig;
8948 	}
8949 	Assert(carry == 0);
8950 
8951 	pfree(div);
8952 
8953 	/*
8954 	 * Finally, round the result to the requested precision.
8955 	 */
8956 	result->weight = res_weight;
8957 	result->sign = res_sign;
8958 
8959 	/* Round to target rscale (and set result->dscale) */
8960 	if (round)
8961 		round_var(result, rscale);
8962 	else
8963 		trunc_var(result, rscale);
8964 
8965 	/* Strip leading and trailing zeroes */
8966 	strip_var(result);
8967 }
8968 
8969 
8970 /*
8971  * Default scale selection for division
8972  *
8973  * Returns the appropriate result scale for the division result.
8974  */
8975 static int
select_div_scale(const NumericVar * var1,const NumericVar * var2)8976 select_div_scale(const NumericVar *var1, const NumericVar *var2)
8977 {
8978 	int			weight1,
8979 				weight2,
8980 				qweight,
8981 				i;
8982 	NumericDigit firstdigit1,
8983 				firstdigit2;
8984 	int			rscale;
8985 
8986 	/*
8987 	 * The result scale of a division isn't specified in any SQL standard. For
8988 	 * PostgreSQL we select a result scale that will give at least
8989 	 * NUMERIC_MIN_SIG_DIGITS significant digits, so that numeric gives a
8990 	 * result no less accurate than float8; but use a scale not less than
8991 	 * either input's display scale.
8992 	 */
8993 
8994 	/* Get the actual (normalized) weight and first digit of each input */
8995 
8996 	weight1 = 0;				/* values to use if var1 is zero */
8997 	firstdigit1 = 0;
8998 	for (i = 0; i < var1->ndigits; i++)
8999 	{
9000 		firstdigit1 = var1->digits[i];
9001 		if (firstdigit1 != 0)
9002 		{
9003 			weight1 = var1->weight - i;
9004 			break;
9005 		}
9006 	}
9007 
9008 	weight2 = 0;				/* values to use if var2 is zero */
9009 	firstdigit2 = 0;
9010 	for (i = 0; i < var2->ndigits; i++)
9011 	{
9012 		firstdigit2 = var2->digits[i];
9013 		if (firstdigit2 != 0)
9014 		{
9015 			weight2 = var2->weight - i;
9016 			break;
9017 		}
9018 	}
9019 
9020 	/*
9021 	 * Estimate weight of quotient.  If the two first digits are equal, we
9022 	 * can't be sure, but assume that var1 is less than var2.
9023 	 */
9024 	qweight = weight1 - weight2;
9025 	if (firstdigit1 <= firstdigit2)
9026 		qweight--;
9027 
9028 	/* Select result scale */
9029 	rscale = NUMERIC_MIN_SIG_DIGITS - qweight * DEC_DIGITS;
9030 	rscale = Max(rscale, var1->dscale);
9031 	rscale = Max(rscale, var2->dscale);
9032 	rscale = Max(rscale, NUMERIC_MIN_DISPLAY_SCALE);
9033 	rscale = Min(rscale, NUMERIC_MAX_DISPLAY_SCALE);
9034 
9035 	return rscale;
9036 }
9037 
9038 
9039 /*
9040  * mod_var() -
9041  *
9042  *	Calculate the modulo of two numerics at variable level
9043  */
9044 static void
mod_var(const NumericVar * var1,const NumericVar * var2,NumericVar * result)9045 mod_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result)
9046 {
9047 	NumericVar	tmp;
9048 
9049 	init_var(&tmp);
9050 
9051 	/* ---------
9052 	 * We do this using the equation
9053 	 *		mod(x,y) = x - trunc(x/y)*y
9054 	 * div_var can be persuaded to give us trunc(x/y) directly.
9055 	 * ----------
9056 	 */
9057 	div_var(var1, var2, &tmp, 0, false);
9058 
9059 	mul_var(var2, &tmp, &tmp, var2->dscale);
9060 
9061 	sub_var(var1, &tmp, result);
9062 
9063 	free_var(&tmp);
9064 }
9065 
9066 
9067 /*
9068  * div_mod_var() -
9069  *
9070  *	Calculate the truncated integer quotient and numeric remainder of two
9071  *	numeric variables.  The remainder is precise to var2's dscale.
9072  */
9073 static void
div_mod_var(const NumericVar * var1,const NumericVar * var2,NumericVar * quot,NumericVar * rem)9074 div_mod_var(const NumericVar *var1, const NumericVar *var2,
9075 			NumericVar *quot, NumericVar *rem)
9076 {
9077 	NumericVar	q;
9078 	NumericVar	r;
9079 
9080 	init_var(&q);
9081 	init_var(&r);
9082 
9083 	/*
9084 	 * Use div_var_fast() to get an initial estimate for the integer quotient.
9085 	 * This might be inaccurate (per the warning in div_var_fast's comments),
9086 	 * but we can correct it below.
9087 	 */
9088 	div_var_fast(var1, var2, &q, 0, false);
9089 
9090 	/* Compute initial estimate of remainder using the quotient estimate. */
9091 	mul_var(var2, &q, &r, var2->dscale);
9092 	sub_var(var1, &r, &r);
9093 
9094 	/*
9095 	 * Adjust the results if necessary --- the remainder should have the same
9096 	 * sign as var1, and its absolute value should be less than the absolute
9097 	 * value of var2.
9098 	 */
9099 	while (r.ndigits != 0 && r.sign != var1->sign)
9100 	{
9101 		/* The absolute value of the quotient is too large */
9102 		if (var1->sign == var2->sign)
9103 		{
9104 			sub_var(&q, &const_one, &q);
9105 			add_var(&r, var2, &r);
9106 		}
9107 		else
9108 		{
9109 			add_var(&q, &const_one, &q);
9110 			sub_var(&r, var2, &r);
9111 		}
9112 	}
9113 
9114 	while (cmp_abs(&r, var2) >= 0)
9115 	{
9116 		/* The absolute value of the quotient is too small */
9117 		if (var1->sign == var2->sign)
9118 		{
9119 			add_var(&q, &const_one, &q);
9120 			sub_var(&r, var2, &r);
9121 		}
9122 		else
9123 		{
9124 			sub_var(&q, &const_one, &q);
9125 			add_var(&r, var2, &r);
9126 		}
9127 	}
9128 
9129 	set_var_from_var(&q, quot);
9130 	set_var_from_var(&r, rem);
9131 
9132 	free_var(&q);
9133 	free_var(&r);
9134 }
9135 
9136 
9137 /*
9138  * ceil_var() -
9139  *
9140  *	Return the smallest integer greater than or equal to the argument
9141  *	on variable level
9142  */
9143 static void
ceil_var(const NumericVar * var,NumericVar * result)9144 ceil_var(const NumericVar *var, NumericVar *result)
9145 {
9146 	NumericVar	tmp;
9147 
9148 	init_var(&tmp);
9149 	set_var_from_var(var, &tmp);
9150 
9151 	trunc_var(&tmp, 0);
9152 
9153 	if (var->sign == NUMERIC_POS && cmp_var(var, &tmp) != 0)
9154 		add_var(&tmp, &const_one, &tmp);
9155 
9156 	set_var_from_var(&tmp, result);
9157 	free_var(&tmp);
9158 }
9159 
9160 
9161 /*
9162  * floor_var() -
9163  *
9164  *	Return the largest integer equal to or less than the argument
9165  *	on variable level
9166  */
9167 static void
floor_var(const NumericVar * var,NumericVar * result)9168 floor_var(const NumericVar *var, NumericVar *result)
9169 {
9170 	NumericVar	tmp;
9171 
9172 	init_var(&tmp);
9173 	set_var_from_var(var, &tmp);
9174 
9175 	trunc_var(&tmp, 0);
9176 
9177 	if (var->sign == NUMERIC_NEG && cmp_var(var, &tmp) != 0)
9178 		sub_var(&tmp, &const_one, &tmp);
9179 
9180 	set_var_from_var(&tmp, result);
9181 	free_var(&tmp);
9182 }
9183 
9184 
9185 /*
9186  * gcd_var() -
9187  *
9188  *	Calculate the greatest common divisor of two numerics at variable level
9189  */
9190 static void
gcd_var(const NumericVar * var1,const NumericVar * var2,NumericVar * result)9191 gcd_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result)
9192 {
9193 	int			res_dscale;
9194 	int			cmp;
9195 	NumericVar	tmp_arg;
9196 	NumericVar	mod;
9197 
9198 	res_dscale = Max(var1->dscale, var2->dscale);
9199 
9200 	/*
9201 	 * Arrange for var1 to be the number with the greater absolute value.
9202 	 *
9203 	 * This would happen automatically in the loop below, but avoids an
9204 	 * expensive modulo operation.
9205 	 */
9206 	cmp = cmp_abs(var1, var2);
9207 	if (cmp < 0)
9208 	{
9209 		const NumericVar *tmp = var1;
9210 
9211 		var1 = var2;
9212 		var2 = tmp;
9213 	}
9214 
9215 	/*
9216 	 * Also avoid the taking the modulo if the inputs have the same absolute
9217 	 * value, or if the smaller input is zero.
9218 	 */
9219 	if (cmp == 0 || var2->ndigits == 0)
9220 	{
9221 		set_var_from_var(var1, result);
9222 		result->sign = NUMERIC_POS;
9223 		result->dscale = res_dscale;
9224 		return;
9225 	}
9226 
9227 	init_var(&tmp_arg);
9228 	init_var(&mod);
9229 
9230 	/* Use the Euclidean algorithm to find the GCD */
9231 	set_var_from_var(var1, &tmp_arg);
9232 	set_var_from_var(var2, result);
9233 
9234 	for (;;)
9235 	{
9236 		/* this loop can take a while, so allow it to be interrupted */
9237 		CHECK_FOR_INTERRUPTS();
9238 
9239 		mod_var(&tmp_arg, result, &mod);
9240 		if (mod.ndigits == 0)
9241 			break;
9242 		set_var_from_var(result, &tmp_arg);
9243 		set_var_from_var(&mod, result);
9244 	}
9245 	result->sign = NUMERIC_POS;
9246 	result->dscale = res_dscale;
9247 
9248 	free_var(&tmp_arg);
9249 	free_var(&mod);
9250 }
9251 
9252 
9253 /*
9254  * sqrt_var() -
9255  *
9256  *	Compute the square root of x using the Karatsuba Square Root algorithm.
9257  *	NOTE: we allow rscale < 0 here, implying rounding before the decimal
9258  *	point.
9259  */
9260 static void
sqrt_var(const NumericVar * arg,NumericVar * result,int rscale)9261 sqrt_var(const NumericVar *arg, NumericVar *result, int rscale)
9262 {
9263 	int			stat;
9264 	int			res_weight;
9265 	int			res_ndigits;
9266 	int			src_ndigits;
9267 	int			step;
9268 	int			ndigits[32];
9269 	int			blen;
9270 	int64		arg_int64;
9271 	int			src_idx;
9272 	int64		s_int64;
9273 	int64		r_int64;
9274 	NumericVar	s_var;
9275 	NumericVar	r_var;
9276 	NumericVar	a0_var;
9277 	NumericVar	a1_var;
9278 	NumericVar	q_var;
9279 	NumericVar	u_var;
9280 
9281 	stat = cmp_var(arg, &const_zero);
9282 	if (stat == 0)
9283 	{
9284 		zero_var(result);
9285 		result->dscale = rscale;
9286 		return;
9287 	}
9288 
9289 	/*
9290 	 * SQL2003 defines sqrt() in terms of power, so we need to emit the right
9291 	 * SQLSTATE error code if the operand is negative.
9292 	 */
9293 	if (stat < 0)
9294 		ereport(ERROR,
9295 				(errcode(ERRCODE_INVALID_ARGUMENT_FOR_POWER_FUNCTION),
9296 				 errmsg("cannot take square root of a negative number")));
9297 
9298 	init_var(&s_var);
9299 	init_var(&r_var);
9300 	init_var(&a0_var);
9301 	init_var(&a1_var);
9302 	init_var(&q_var);
9303 	init_var(&u_var);
9304 
9305 	/*
9306 	 * The result weight is half the input weight, rounded towards minus
9307 	 * infinity --- res_weight = floor(arg->weight / 2).
9308 	 */
9309 	if (arg->weight >= 0)
9310 		res_weight = arg->weight / 2;
9311 	else
9312 		res_weight = -((-arg->weight - 1) / 2 + 1);
9313 
9314 	/*
9315 	 * Number of NBASE digits to compute.  To ensure correct rounding, compute
9316 	 * at least 1 extra decimal digit.  We explicitly allow rscale to be
9317 	 * negative here, but must always compute at least 1 NBASE digit.  Thus
9318 	 * res_ndigits = res_weight + 1 + ceil((rscale + 1) / DEC_DIGITS) or 1.
9319 	 */
9320 	if (rscale + 1 >= 0)
9321 		res_ndigits = res_weight + 1 + (rscale + DEC_DIGITS) / DEC_DIGITS;
9322 	else
9323 		res_ndigits = res_weight + 1 - (-rscale - 1) / DEC_DIGITS;
9324 	res_ndigits = Max(res_ndigits, 1);
9325 
9326 	/*
9327 	 * Number of source NBASE digits logically required to produce a result
9328 	 * with this precision --- every digit before the decimal point, plus 2
9329 	 * for each result digit after the decimal point (or minus 2 for each
9330 	 * result digit we round before the decimal point).
9331 	 */
9332 	src_ndigits = arg->weight + 1 + (res_ndigits - res_weight - 1) * 2;
9333 	src_ndigits = Max(src_ndigits, 1);
9334 
9335 	/* ----------
9336 	 * From this point on, we treat the input and the result as integers and
9337 	 * compute the integer square root and remainder using the Karatsuba
9338 	 * Square Root algorithm, which may be written recursively as follows:
9339 	 *
9340 	 *	SqrtRem(n = a3*b^3 + a2*b^2 + a1*b + a0):
9341 	 *		[ for some base b, and coefficients a0,a1,a2,a3 chosen so that
9342 	 *		  0 <= a0,a1,a2 < b and a3 >= b/4 ]
9343 	 *		Let (s,r) = SqrtRem(a3*b + a2)
9344 	 *		Let (q,u) = DivRem(r*b + a1, 2*s)
9345 	 *		Let s = s*b + q
9346 	 *		Let r = u*b + a0 - q^2
9347 	 *		If r < 0 Then
9348 	 *			Let r = r + s
9349 	 *			Let s = s - 1
9350 	 *			Let r = r + s
9351 	 *		Return (s,r)
9352 	 *
9353 	 * See "Karatsuba Square Root", Paul Zimmermann, INRIA Research Report
9354 	 * RR-3805, November 1999.  At the time of writing this was available
9355 	 * on the net at <https://hal.inria.fr/inria-00072854>.
9356 	 *
9357 	 * The way to read the assumption "n = a3*b^3 + a2*b^2 + a1*b + a0" is
9358 	 * "choose a base b such that n requires at least four base-b digits to
9359 	 * express; then those digits are a3,a2,a1,a0, with a3 possibly larger
9360 	 * than b".  For optimal performance, b should have approximately a
9361 	 * quarter the number of digits in the input, so that the outer square
9362 	 * root computes roughly twice as many digits as the inner one.  For
9363 	 * simplicity, we choose b = NBASE^blen, an integer power of NBASE.
9364 	 *
9365 	 * We implement the algorithm iteratively rather than recursively, to
9366 	 * allow the working variables to be reused.  With this approach, each
9367 	 * digit of the input is read precisely once --- src_idx tracks the number
9368 	 * of input digits used so far.
9369 	 *
9370 	 * The array ndigits[] holds the number of NBASE digits of the input that
9371 	 * will have been used at the end of each iteration, which roughly doubles
9372 	 * each time.  Note that the array elements are stored in reverse order,
9373 	 * so if the final iteration requires src_ndigits = 37 input digits, the
9374 	 * array will contain [37,19,11,7,5,3], and we would start by computing
9375 	 * the square root of the 3 most significant NBASE digits.
9376 	 *
9377 	 * In each iteration, we choose blen to be the largest integer for which
9378 	 * the input number has a3 >= b/4, when written in the form above.  In
9379 	 * general, this means blen = src_ndigits / 4 (truncated), but if
9380 	 * src_ndigits is a multiple of 4, that might lead to the coefficient a3
9381 	 * being less than b/4 (if the first input digit is less than NBASE/4), in
9382 	 * which case we choose blen = src_ndigits / 4 - 1.  The number of digits
9383 	 * in the inner square root is then src_ndigits - 2*blen.  So, for
9384 	 * example, if we have src_ndigits = 26 initially, the array ndigits[]
9385 	 * will be either [26,14,8,4] or [26,14,8,6,4], depending on the size of
9386 	 * the first input digit.
9387 	 *
9388 	 * Additionally, we can put an upper bound on the number of steps required
9389 	 * as follows --- suppose that the number of source digits is an n-bit
9390 	 * number in the range [2^(n-1), 2^n-1], then blen will be in the range
9391 	 * [2^(n-3)-1, 2^(n-2)-1] and the number of digits in the inner square
9392 	 * root will be in the range [2^(n-2), 2^(n-1)+1].  In the next step, blen
9393 	 * will be in the range [2^(n-4)-1, 2^(n-3)] and the number of digits in
9394 	 * the next inner square root will be in the range [2^(n-3), 2^(n-2)+1].
9395 	 * This pattern repeats, and in the worst case the array ndigits[] will
9396 	 * contain [2^n-1, 2^(n-1)+1, 2^(n-2)+1, ... 9, 5, 3], and the computation
9397 	 * will require n steps.  Therefore, since all digit array sizes are
9398 	 * signed 32-bit integers, the number of steps required is guaranteed to
9399 	 * be less than 32.
9400 	 * ----------
9401 	 */
9402 	step = 0;
9403 	while ((ndigits[step] = src_ndigits) > 4)
9404 	{
9405 		/* Choose b so that a3 >= b/4, as described above */
9406 		blen = src_ndigits / 4;
9407 		if (blen * 4 == src_ndigits && arg->digits[0] < NBASE / 4)
9408 			blen--;
9409 
9410 		/* Number of digits in the next step (inner square root) */
9411 		src_ndigits -= 2 * blen;
9412 		step++;
9413 	}
9414 
9415 	/*
9416 	 * First iteration (innermost square root and remainder):
9417 	 *
9418 	 * Here src_ndigits <= 4, and the input fits in an int64.  Its square root
9419 	 * has at most 9 decimal digits, so estimate it using double precision
9420 	 * arithmetic, which will in fact almost certainly return the correct
9421 	 * result with no further correction required.
9422 	 */
9423 	arg_int64 = arg->digits[0];
9424 	for (src_idx = 1; src_idx < src_ndigits; src_idx++)
9425 	{
9426 		arg_int64 *= NBASE;
9427 		if (src_idx < arg->ndigits)
9428 			arg_int64 += arg->digits[src_idx];
9429 	}
9430 
9431 	s_int64 = (int64) sqrt((double) arg_int64);
9432 	r_int64 = arg_int64 - s_int64 * s_int64;
9433 
9434 	/*
9435 	 * Use Newton's method to correct the result, if necessary.
9436 	 *
9437 	 * This uses integer division with truncation to compute the truncated
9438 	 * integer square root by iterating using the formula x -> (x + n/x) / 2.
9439 	 * This is known to converge to isqrt(n), unless n+1 is a perfect square.
9440 	 * If n+1 is a perfect square, the sequence will oscillate between the two
9441 	 * values isqrt(n) and isqrt(n)+1, so we can be assured of convergence by
9442 	 * checking the remainder.
9443 	 */
9444 	while (r_int64 < 0 || r_int64 > 2 * s_int64)
9445 	{
9446 		s_int64 = (s_int64 + arg_int64 / s_int64) / 2;
9447 		r_int64 = arg_int64 - s_int64 * s_int64;
9448 	}
9449 
9450 	/*
9451 	 * Iterations with src_ndigits <= 8:
9452 	 *
9453 	 * The next 1 or 2 iterations compute larger (outer) square roots with
9454 	 * src_ndigits <= 8, so the result still fits in an int64 (even though the
9455 	 * input no longer does) and we can continue to compute using int64
9456 	 * variables to avoid more expensive numeric computations.
9457 	 *
9458 	 * It is fairly easy to see that there is no risk of the intermediate
9459 	 * values below overflowing 64-bit integers.  In the worst case, the
9460 	 * previous iteration will have computed a 3-digit square root (of a
9461 	 * 6-digit input less than NBASE^6 / 4), so at the start of this
9462 	 * iteration, s will be less than NBASE^3 / 2 = 10^12 / 2, and r will be
9463 	 * less than 10^12.  In this case, blen will be 1, so numer will be less
9464 	 * than 10^17, and denom will be less than 10^12 (and hence u will also be
9465 	 * less than 10^12).  Finally, since q^2 = u*b + a0 - r, we can also be
9466 	 * sure that q^2 < 10^17.  Therefore all these quantities fit comfortably
9467 	 * in 64-bit integers.
9468 	 */
9469 	step--;
9470 	while (step >= 0 && (src_ndigits = ndigits[step]) <= 8)
9471 	{
9472 		int			b;
9473 		int			a0;
9474 		int			a1;
9475 		int			i;
9476 		int64		numer;
9477 		int64		denom;
9478 		int64		q;
9479 		int64		u;
9480 
9481 		blen = (src_ndigits - src_idx) / 2;
9482 
9483 		/* Extract a1 and a0, and compute b */
9484 		a0 = 0;
9485 		a1 = 0;
9486 		b = 1;
9487 
9488 		for (i = 0; i < blen; i++, src_idx++)
9489 		{
9490 			b *= NBASE;
9491 			a1 *= NBASE;
9492 			if (src_idx < arg->ndigits)
9493 				a1 += arg->digits[src_idx];
9494 		}
9495 
9496 		for (i = 0; i < blen; i++, src_idx++)
9497 		{
9498 			a0 *= NBASE;
9499 			if (src_idx < arg->ndigits)
9500 				a0 += arg->digits[src_idx];
9501 		}
9502 
9503 		/* Compute (q,u) = DivRem(r*b + a1, 2*s) */
9504 		numer = r_int64 * b + a1;
9505 		denom = 2 * s_int64;
9506 		q = numer / denom;
9507 		u = numer - q * denom;
9508 
9509 		/* Compute s = s*b + q and r = u*b + a0 - q^2 */
9510 		s_int64 = s_int64 * b + q;
9511 		r_int64 = u * b + a0 - q * q;
9512 
9513 		if (r_int64 < 0)
9514 		{
9515 			/* s is too large by 1; set r += s, s--, r += s */
9516 			r_int64 += s_int64;
9517 			s_int64--;
9518 			r_int64 += s_int64;
9519 		}
9520 
9521 		Assert(src_idx == src_ndigits); /* All input digits consumed */
9522 		step--;
9523 	}
9524 
9525 	/*
9526 	 * On platforms with 128-bit integer support, we can further delay the
9527 	 * need to use numeric variables.
9528 	 */
9529 #ifdef HAVE_INT128
9530 	if (step >= 0)
9531 	{
9532 		int128		s_int128;
9533 		int128		r_int128;
9534 
9535 		s_int128 = s_int64;
9536 		r_int128 = r_int64;
9537 
9538 		/*
9539 		 * Iterations with src_ndigits <= 16:
9540 		 *
9541 		 * The result fits in an int128 (even though the input doesn't) so we
9542 		 * use int128 variables to avoid more expensive numeric computations.
9543 		 */
9544 		while (step >= 0 && (src_ndigits = ndigits[step]) <= 16)
9545 		{
9546 			int64		b;
9547 			int64		a0;
9548 			int64		a1;
9549 			int64		i;
9550 			int128		numer;
9551 			int128		denom;
9552 			int128		q;
9553 			int128		u;
9554 
9555 			blen = (src_ndigits - src_idx) / 2;
9556 
9557 			/* Extract a1 and a0, and compute b */
9558 			a0 = 0;
9559 			a1 = 0;
9560 			b = 1;
9561 
9562 			for (i = 0; i < blen; i++, src_idx++)
9563 			{
9564 				b *= NBASE;
9565 				a1 *= NBASE;
9566 				if (src_idx < arg->ndigits)
9567 					a1 += arg->digits[src_idx];
9568 			}
9569 
9570 			for (i = 0; i < blen; i++, src_idx++)
9571 			{
9572 				a0 *= NBASE;
9573 				if (src_idx < arg->ndigits)
9574 					a0 += arg->digits[src_idx];
9575 			}
9576 
9577 			/* Compute (q,u) = DivRem(r*b + a1, 2*s) */
9578 			numer = r_int128 * b + a1;
9579 			denom = 2 * s_int128;
9580 			q = numer / denom;
9581 			u = numer - q * denom;
9582 
9583 			/* Compute s = s*b + q and r = u*b + a0 - q^2 */
9584 			s_int128 = s_int128 * b + q;
9585 			r_int128 = u * b + a0 - q * q;
9586 
9587 			if (r_int128 < 0)
9588 			{
9589 				/* s is too large by 1; set r += s, s--, r += s */
9590 				r_int128 += s_int128;
9591 				s_int128--;
9592 				r_int128 += s_int128;
9593 			}
9594 
9595 			Assert(src_idx == src_ndigits); /* All input digits consumed */
9596 			step--;
9597 		}
9598 
9599 		/*
9600 		 * All remaining iterations require numeric variables.  Convert the
9601 		 * integer values to NumericVar and continue.  Note that in the final
9602 		 * iteration we don't need the remainder, so we can save a few cycles
9603 		 * there by not fully computing it.
9604 		 */
9605 		int128_to_numericvar(s_int128, &s_var);
9606 		if (step >= 0)
9607 			int128_to_numericvar(r_int128, &r_var);
9608 	}
9609 	else
9610 	{
9611 		int64_to_numericvar(s_int64, &s_var);
9612 		/* step < 0, so we certainly don't need r */
9613 	}
9614 #else							/* !HAVE_INT128 */
9615 	int64_to_numericvar(s_int64, &s_var);
9616 	if (step >= 0)
9617 		int64_to_numericvar(r_int64, &r_var);
9618 #endif							/* HAVE_INT128 */
9619 
9620 	/*
9621 	 * The remaining iterations with src_ndigits > 8 (or 16, if have int128)
9622 	 * use numeric variables.
9623 	 */
9624 	while (step >= 0)
9625 	{
9626 		int			tmp_len;
9627 
9628 		src_ndigits = ndigits[step];
9629 		blen = (src_ndigits - src_idx) / 2;
9630 
9631 		/* Extract a1 and a0 */
9632 		if (src_idx < arg->ndigits)
9633 		{
9634 			tmp_len = Min(blen, arg->ndigits - src_idx);
9635 			alloc_var(&a1_var, tmp_len);
9636 			memcpy(a1_var.digits, arg->digits + src_idx,
9637 				   tmp_len * sizeof(NumericDigit));
9638 			a1_var.weight = blen - 1;
9639 			a1_var.sign = NUMERIC_POS;
9640 			a1_var.dscale = 0;
9641 			strip_var(&a1_var);
9642 		}
9643 		else
9644 		{
9645 			zero_var(&a1_var);
9646 			a1_var.dscale = 0;
9647 		}
9648 		src_idx += blen;
9649 
9650 		if (src_idx < arg->ndigits)
9651 		{
9652 			tmp_len = Min(blen, arg->ndigits - src_idx);
9653 			alloc_var(&a0_var, tmp_len);
9654 			memcpy(a0_var.digits, arg->digits + src_idx,
9655 				   tmp_len * sizeof(NumericDigit));
9656 			a0_var.weight = blen - 1;
9657 			a0_var.sign = NUMERIC_POS;
9658 			a0_var.dscale = 0;
9659 			strip_var(&a0_var);
9660 		}
9661 		else
9662 		{
9663 			zero_var(&a0_var);
9664 			a0_var.dscale = 0;
9665 		}
9666 		src_idx += blen;
9667 
9668 		/* Compute (q,u) = DivRem(r*b + a1, 2*s) */
9669 		set_var_from_var(&r_var, &q_var);
9670 		q_var.weight += blen;
9671 		add_var(&q_var, &a1_var, &q_var);
9672 		add_var(&s_var, &s_var, &u_var);
9673 		div_mod_var(&q_var, &u_var, &q_var, &u_var);
9674 
9675 		/* Compute s = s*b + q */
9676 		s_var.weight += blen;
9677 		add_var(&s_var, &q_var, &s_var);
9678 
9679 		/*
9680 		 * Compute r = u*b + a0 - q^2.
9681 		 *
9682 		 * In the final iteration, we don't actually need r; we just need to
9683 		 * know whether it is negative, so that we know whether to adjust s.
9684 		 * So instead of the final subtraction we can just compare.
9685 		 */
9686 		u_var.weight += blen;
9687 		add_var(&u_var, &a0_var, &u_var);
9688 		mul_var(&q_var, &q_var, &q_var, 0);
9689 
9690 		if (step > 0)
9691 		{
9692 			/* Need r for later iterations */
9693 			sub_var(&u_var, &q_var, &r_var);
9694 			if (r_var.sign == NUMERIC_NEG)
9695 			{
9696 				/* s is too large by 1; set r += s, s--, r += s */
9697 				add_var(&r_var, &s_var, &r_var);
9698 				sub_var(&s_var, &const_one, &s_var);
9699 				add_var(&r_var, &s_var, &r_var);
9700 			}
9701 		}
9702 		else
9703 		{
9704 			/* Don't need r anymore, except to test if s is too large by 1 */
9705 			if (cmp_var(&u_var, &q_var) < 0)
9706 				sub_var(&s_var, &const_one, &s_var);
9707 		}
9708 
9709 		Assert(src_idx == src_ndigits); /* All input digits consumed */
9710 		step--;
9711 	}
9712 
9713 	/*
9714 	 * Construct the final result, rounding it to the requested precision.
9715 	 */
9716 	set_var_from_var(&s_var, result);
9717 	result->weight = res_weight;
9718 	result->sign = NUMERIC_POS;
9719 
9720 	/* Round to target rscale (and set result->dscale) */
9721 	round_var(result, rscale);
9722 
9723 	/* Strip leading and trailing zeroes */
9724 	strip_var(result);
9725 
9726 	free_var(&s_var);
9727 	free_var(&r_var);
9728 	free_var(&a0_var);
9729 	free_var(&a1_var);
9730 	free_var(&q_var);
9731 	free_var(&u_var);
9732 }
9733 
9734 
9735 /*
9736  * exp_var() -
9737  *
9738  *	Raise e to the power of x, computed to rscale fractional digits
9739  */
9740 static void
exp_var(const NumericVar * arg,NumericVar * result,int rscale)9741 exp_var(const NumericVar *arg, NumericVar *result, int rscale)
9742 {
9743 	NumericVar	x;
9744 	NumericVar	elem;
9745 	NumericVar	ni;
9746 	double		val;
9747 	int			dweight;
9748 	int			ndiv2;
9749 	int			sig_digits;
9750 	int			local_rscale;
9751 
9752 	init_var(&x);
9753 	init_var(&elem);
9754 	init_var(&ni);
9755 
9756 	set_var_from_var(arg, &x);
9757 
9758 	/*
9759 	 * Estimate the dweight of the result using floating point arithmetic, so
9760 	 * that we can choose an appropriate local rscale for the calculation.
9761 	 */
9762 	val = numericvar_to_double_no_overflow(&x);
9763 
9764 	/* Guard against overflow/underflow */
9765 	/* If you change this limit, see also power_var()'s limit */
9766 	if (Abs(val) >= NUMERIC_MAX_RESULT_SCALE * 3)
9767 	{
9768 		if (val > 0)
9769 			ereport(ERROR,
9770 					(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
9771 					 errmsg("value overflows numeric format")));
9772 		zero_var(result);
9773 		result->dscale = rscale;
9774 		return;
9775 	}
9776 
9777 	/* decimal weight = log10(e^x) = x * log10(e) */
9778 	dweight = (int) (val * 0.434294481903252);
9779 
9780 	/*
9781 	 * Reduce x to the range -0.01 <= x <= 0.01 (approximately) by dividing by
9782 	 * 2^n, to improve the convergence rate of the Taylor series.
9783 	 */
9784 	if (Abs(val) > 0.01)
9785 	{
9786 		NumericVar	tmp;
9787 
9788 		init_var(&tmp);
9789 		set_var_from_var(&const_two, &tmp);
9790 
9791 		ndiv2 = 1;
9792 		val /= 2;
9793 
9794 		while (Abs(val) > 0.01)
9795 		{
9796 			ndiv2++;
9797 			val /= 2;
9798 			add_var(&tmp, &tmp, &tmp);
9799 		}
9800 
9801 		local_rscale = x.dscale + ndiv2;
9802 		div_var_fast(&x, &tmp, &x, local_rscale, true);
9803 
9804 		free_var(&tmp);
9805 	}
9806 	else
9807 		ndiv2 = 0;
9808 
9809 	/*
9810 	 * Set the scale for the Taylor series expansion.  The final result has
9811 	 * (dweight + rscale + 1) significant digits.  In addition, we have to
9812 	 * raise the Taylor series result to the power 2^ndiv2, which introduces
9813 	 * an error of up to around log10(2^ndiv2) digits, so work with this many
9814 	 * extra digits of precision (plus a few more for good measure).
9815 	 */
9816 	sig_digits = 1 + dweight + rscale + (int) (ndiv2 * 0.301029995663981);
9817 	sig_digits = Max(sig_digits, 0) + 8;
9818 
9819 	local_rscale = sig_digits - 1;
9820 
9821 	/*
9822 	 * Use the Taylor series
9823 	 *
9824 	 * exp(x) = 1 + x + x^2/2! + x^3/3! + ...
9825 	 *
9826 	 * Given the limited range of x, this should converge reasonably quickly.
9827 	 * We run the series until the terms fall below the local_rscale limit.
9828 	 */
9829 	add_var(&const_one, &x, result);
9830 
9831 	mul_var(&x, &x, &elem, local_rscale);
9832 	set_var_from_var(&const_two, &ni);
9833 	div_var_fast(&elem, &ni, &elem, local_rscale, true);
9834 
9835 	while (elem.ndigits != 0)
9836 	{
9837 		add_var(result, &elem, result);
9838 
9839 		mul_var(&elem, &x, &elem, local_rscale);
9840 		add_var(&ni, &const_one, &ni);
9841 		div_var_fast(&elem, &ni, &elem, local_rscale, true);
9842 	}
9843 
9844 	/*
9845 	 * Compensate for the argument range reduction.  Since the weight of the
9846 	 * result doubles with each multiplication, we can reduce the local rscale
9847 	 * as we proceed.
9848 	 */
9849 	while (ndiv2-- > 0)
9850 	{
9851 		local_rscale = sig_digits - result->weight * 2 * DEC_DIGITS;
9852 		local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
9853 		mul_var(result, result, result, local_rscale);
9854 	}
9855 
9856 	/* Round to requested rscale */
9857 	round_var(result, rscale);
9858 
9859 	free_var(&x);
9860 	free_var(&elem);
9861 	free_var(&ni);
9862 }
9863 
9864 
9865 /*
9866  * Estimate the dweight of the most significant decimal digit of the natural
9867  * logarithm of a number.
9868  *
9869  * Essentially, we're approximating log10(abs(ln(var))).  This is used to
9870  * determine the appropriate rscale when computing natural logarithms.
9871  */
9872 static int
estimate_ln_dweight(const NumericVar * var)9873 estimate_ln_dweight(const NumericVar *var)
9874 {
9875 	int			ln_dweight;
9876 
9877 	if (cmp_var(var, &const_zero_point_nine) >= 0 &&
9878 		cmp_var(var, &const_one_point_one) <= 0)
9879 	{
9880 		/*
9881 		 * 0.9 <= var <= 1.1
9882 		 *
9883 		 * ln(var) has a negative weight (possibly very large).  To get a
9884 		 * reasonably accurate result, estimate it using ln(1+x) ~= x.
9885 		 */
9886 		NumericVar	x;
9887 
9888 		init_var(&x);
9889 		sub_var(var, &const_one, &x);
9890 
9891 		if (x.ndigits > 0)
9892 		{
9893 			/* Use weight of most significant decimal digit of x */
9894 			ln_dweight = x.weight * DEC_DIGITS + (int) log10(x.digits[0]);
9895 		}
9896 		else
9897 		{
9898 			/* x = 0.  Since ln(1) = 0 exactly, we don't need extra digits */
9899 			ln_dweight = 0;
9900 		}
9901 
9902 		free_var(&x);
9903 	}
9904 	else
9905 	{
9906 		/*
9907 		 * Estimate the logarithm using the first couple of digits from the
9908 		 * input number.  This will give an accurate result whenever the input
9909 		 * is not too close to 1.
9910 		 */
9911 		if (var->ndigits > 0)
9912 		{
9913 			int			digits;
9914 			int			dweight;
9915 			double		ln_var;
9916 
9917 			digits = var->digits[0];
9918 			dweight = var->weight * DEC_DIGITS;
9919 
9920 			if (var->ndigits > 1)
9921 			{
9922 				digits = digits * NBASE + var->digits[1];
9923 				dweight -= DEC_DIGITS;
9924 			}
9925 
9926 			/*----------
9927 			 * We have var ~= digits * 10^dweight
9928 			 * so ln(var) ~= ln(digits) + dweight * ln(10)
9929 			 *----------
9930 			 */
9931 			ln_var = log((double) digits) + dweight * 2.302585092994046;
9932 			ln_dweight = (int) log10(Abs(ln_var));
9933 		}
9934 		else
9935 		{
9936 			/* Caller should fail on ln(0), but for the moment return zero */
9937 			ln_dweight = 0;
9938 		}
9939 	}
9940 
9941 	return ln_dweight;
9942 }
9943 
9944 
9945 /*
9946  * ln_var() -
9947  *
9948  *	Compute the natural log of x
9949  */
9950 static void
ln_var(const NumericVar * arg,NumericVar * result,int rscale)9951 ln_var(const NumericVar *arg, NumericVar *result, int rscale)
9952 {
9953 	NumericVar	x;
9954 	NumericVar	xx;
9955 	NumericVar	ni;
9956 	NumericVar	elem;
9957 	NumericVar	fact;
9958 	int			nsqrt;
9959 	int			local_rscale;
9960 	int			cmp;
9961 
9962 	cmp = cmp_var(arg, &const_zero);
9963 	if (cmp == 0)
9964 		ereport(ERROR,
9965 				(errcode(ERRCODE_INVALID_ARGUMENT_FOR_LOG),
9966 				 errmsg("cannot take logarithm of zero")));
9967 	else if (cmp < 0)
9968 		ereport(ERROR,
9969 				(errcode(ERRCODE_INVALID_ARGUMENT_FOR_LOG),
9970 				 errmsg("cannot take logarithm of a negative number")));
9971 
9972 	init_var(&x);
9973 	init_var(&xx);
9974 	init_var(&ni);
9975 	init_var(&elem);
9976 	init_var(&fact);
9977 
9978 	set_var_from_var(arg, &x);
9979 	set_var_from_var(&const_two, &fact);
9980 
9981 	/*
9982 	 * Reduce input into range 0.9 < x < 1.1 with repeated sqrt() operations.
9983 	 *
9984 	 * The final logarithm will have up to around rscale+6 significant digits.
9985 	 * Each sqrt() will roughly halve the weight of x, so adjust the local
9986 	 * rscale as we work so that we keep this many significant digits at each
9987 	 * step (plus a few more for good measure).
9988 	 *
9989 	 * Note that we allow local_rscale < 0 during this input reduction
9990 	 * process, which implies rounding before the decimal point.  sqrt_var()
9991 	 * explicitly supports this, and it significantly reduces the work
9992 	 * required to reduce very large inputs to the required range.  Once the
9993 	 * input reduction is complete, x.weight will be 0 and its display scale
9994 	 * will be non-negative again.
9995 	 */
9996 	nsqrt = 0;
9997 	while (cmp_var(&x, &const_zero_point_nine) <= 0)
9998 	{
9999 		local_rscale = rscale - x.weight * DEC_DIGITS / 2 + 8;
10000 		sqrt_var(&x, &x, local_rscale);
10001 		mul_var(&fact, &const_two, &fact, 0);
10002 		nsqrt++;
10003 	}
10004 	while (cmp_var(&x, &const_one_point_one) >= 0)
10005 	{
10006 		local_rscale = rscale - x.weight * DEC_DIGITS / 2 + 8;
10007 		sqrt_var(&x, &x, local_rscale);
10008 		mul_var(&fact, &const_two, &fact, 0);
10009 		nsqrt++;
10010 	}
10011 
10012 	/*
10013 	 * We use the Taylor series for 0.5 * ln((1+z)/(1-z)),
10014 	 *
10015 	 * z + z^3/3 + z^5/5 + ...
10016 	 *
10017 	 * where z = (x-1)/(x+1) is in the range (approximately) -0.053 .. 0.048
10018 	 * due to the above range-reduction of x.
10019 	 *
10020 	 * The convergence of this is not as fast as one would like, but is
10021 	 * tolerable given that z is small.
10022 	 *
10023 	 * The Taylor series result will be multiplied by 2^(nsqrt+1), which has a
10024 	 * decimal weight of (nsqrt+1) * log10(2), so work with this many extra
10025 	 * digits of precision (plus a few more for good measure).
10026 	 */
10027 	local_rscale = rscale + (int) ((nsqrt + 1) * 0.301029995663981) + 8;
10028 
10029 	sub_var(&x, &const_one, result);
10030 	add_var(&x, &const_one, &elem);
10031 	div_var_fast(result, &elem, result, local_rscale, true);
10032 	set_var_from_var(result, &xx);
10033 	mul_var(result, result, &x, local_rscale);
10034 
10035 	set_var_from_var(&const_one, &ni);
10036 
10037 	for (;;)
10038 	{
10039 		add_var(&ni, &const_two, &ni);
10040 		mul_var(&xx, &x, &xx, local_rscale);
10041 		div_var_fast(&xx, &ni, &elem, local_rscale, true);
10042 
10043 		if (elem.ndigits == 0)
10044 			break;
10045 
10046 		add_var(result, &elem, result);
10047 
10048 		if (elem.weight < (result->weight - local_rscale * 2 / DEC_DIGITS))
10049 			break;
10050 	}
10051 
10052 	/* Compensate for argument range reduction, round to requested rscale */
10053 	mul_var(result, &fact, result, rscale);
10054 
10055 	free_var(&x);
10056 	free_var(&xx);
10057 	free_var(&ni);
10058 	free_var(&elem);
10059 	free_var(&fact);
10060 }
10061 
10062 
10063 /*
10064  * log_var() -
10065  *
10066  *	Compute the logarithm of num in a given base.
10067  *
10068  *	Note: this routine chooses dscale of the result.
10069  */
10070 static void
log_var(const NumericVar * base,const NumericVar * num,NumericVar * result)10071 log_var(const NumericVar *base, const NumericVar *num, NumericVar *result)
10072 {
10073 	NumericVar	ln_base;
10074 	NumericVar	ln_num;
10075 	int			ln_base_dweight;
10076 	int			ln_num_dweight;
10077 	int			result_dweight;
10078 	int			rscale;
10079 	int			ln_base_rscale;
10080 	int			ln_num_rscale;
10081 
10082 	init_var(&ln_base);
10083 	init_var(&ln_num);
10084 
10085 	/* Estimated dweights of ln(base), ln(num) and the final result */
10086 	ln_base_dweight = estimate_ln_dweight(base);
10087 	ln_num_dweight = estimate_ln_dweight(num);
10088 	result_dweight = ln_num_dweight - ln_base_dweight;
10089 
10090 	/*
10091 	 * Select the scale of the result so that it will have at least
10092 	 * NUMERIC_MIN_SIG_DIGITS significant digits and is not less than either
10093 	 * input's display scale.
10094 	 */
10095 	rscale = NUMERIC_MIN_SIG_DIGITS - result_dweight;
10096 	rscale = Max(rscale, base->dscale);
10097 	rscale = Max(rscale, num->dscale);
10098 	rscale = Max(rscale, NUMERIC_MIN_DISPLAY_SCALE);
10099 	rscale = Min(rscale, NUMERIC_MAX_DISPLAY_SCALE);
10100 
10101 	/*
10102 	 * Set the scales for ln(base) and ln(num) so that they each have more
10103 	 * significant digits than the final result.
10104 	 */
10105 	ln_base_rscale = rscale + result_dweight - ln_base_dweight + 8;
10106 	ln_base_rscale = Max(ln_base_rscale, NUMERIC_MIN_DISPLAY_SCALE);
10107 
10108 	ln_num_rscale = rscale + result_dweight - ln_num_dweight + 8;
10109 	ln_num_rscale = Max(ln_num_rscale, NUMERIC_MIN_DISPLAY_SCALE);
10110 
10111 	/* Form natural logarithms */
10112 	ln_var(base, &ln_base, ln_base_rscale);
10113 	ln_var(num, &ln_num, ln_num_rscale);
10114 
10115 	/* Divide and round to the required scale */
10116 	div_var_fast(&ln_num, &ln_base, result, rscale, true);
10117 
10118 	free_var(&ln_num);
10119 	free_var(&ln_base);
10120 }
10121 
10122 
10123 /*
10124  * power_var() -
10125  *
10126  *	Raise base to the power of exp
10127  *
10128  *	Note: this routine chooses dscale of the result.
10129  */
10130 static void
power_var(const NumericVar * base,const NumericVar * exp,NumericVar * result)10131 power_var(const NumericVar *base, const NumericVar *exp, NumericVar *result)
10132 {
10133 	int			res_sign;
10134 	NumericVar	abs_base;
10135 	NumericVar	ln_base;
10136 	NumericVar	ln_num;
10137 	int			ln_dweight;
10138 	int			rscale;
10139 	int			sig_digits;
10140 	int			local_rscale;
10141 	double		val;
10142 
10143 	/* If exp can be represented as an integer, use power_var_int */
10144 	if (exp->ndigits == 0 || exp->ndigits <= exp->weight + 1)
10145 	{
10146 		/* exact integer, but does it fit in int? */
10147 		int64		expval64;
10148 
10149 		if (numericvar_to_int64(exp, &expval64))
10150 		{
10151 			if (expval64 >= PG_INT32_MIN && expval64 <= PG_INT32_MAX)
10152 			{
10153 				/* Okay, select rscale */
10154 				rscale = NUMERIC_MIN_SIG_DIGITS;
10155 				rscale = Max(rscale, base->dscale);
10156 				rscale = Max(rscale, NUMERIC_MIN_DISPLAY_SCALE);
10157 				rscale = Min(rscale, NUMERIC_MAX_DISPLAY_SCALE);
10158 
10159 				power_var_int(base, (int) expval64, result, rscale);
10160 				return;
10161 			}
10162 		}
10163 	}
10164 
10165 	/*
10166 	 * This avoids log(0) for cases of 0 raised to a non-integer.  0 ^ 0 is
10167 	 * handled by power_var_int().
10168 	 */
10169 	if (cmp_var(base, &const_zero) == 0)
10170 	{
10171 		set_var_from_var(&const_zero, result);
10172 		result->dscale = NUMERIC_MIN_SIG_DIGITS;	/* no need to round */
10173 		return;
10174 	}
10175 
10176 	init_var(&abs_base);
10177 	init_var(&ln_base);
10178 	init_var(&ln_num);
10179 
10180 	/*
10181 	 * If base is negative, insist that exp be an integer.  The result is then
10182 	 * positive if exp is even and negative if exp is odd.
10183 	 */
10184 	if (base->sign == NUMERIC_NEG)
10185 	{
10186 		/*
10187 		 * Check that exp is an integer.  This error code is defined by the
10188 		 * SQL standard, and matches other errors in numeric_power().
10189 		 */
10190 		if (exp->ndigits > 0 && exp->ndigits > exp->weight + 1)
10191 			ereport(ERROR,
10192 					(errcode(ERRCODE_INVALID_ARGUMENT_FOR_POWER_FUNCTION),
10193 					 errmsg("a negative number raised to a non-integer power yields a complex result")));
10194 
10195 		/* Test if exp is odd or even */
10196 		if (exp->ndigits > 0 && exp->ndigits == exp->weight + 1 &&
10197 			(exp->digits[exp->ndigits - 1] & 1))
10198 			res_sign = NUMERIC_NEG;
10199 		else
10200 			res_sign = NUMERIC_POS;
10201 
10202 		/* Then work with abs(base) below */
10203 		set_var_from_var(base, &abs_base);
10204 		abs_base.sign = NUMERIC_POS;
10205 		base = &abs_base;
10206 	}
10207 	else
10208 		res_sign = NUMERIC_POS;
10209 
10210 	/*----------
10211 	 * Decide on the scale for the ln() calculation.  For this we need an
10212 	 * estimate of the weight of the result, which we obtain by doing an
10213 	 * initial low-precision calculation of exp * ln(base).
10214 	 *
10215 	 * We want result = e ^ (exp * ln(base))
10216 	 * so result dweight = log10(result) = exp * ln(base) * log10(e)
10217 	 *
10218 	 * We also perform a crude overflow test here so that we can exit early if
10219 	 * the full-precision result is sure to overflow, and to guard against
10220 	 * integer overflow when determining the scale for the real calculation.
10221 	 * exp_var() supports inputs up to NUMERIC_MAX_RESULT_SCALE * 3, so the
10222 	 * result will overflow if exp * ln(base) >= NUMERIC_MAX_RESULT_SCALE * 3.
10223 	 * Since the values here are only approximations, we apply a small fuzz
10224 	 * factor to this overflow test and let exp_var() determine the exact
10225 	 * overflow threshold so that it is consistent for all inputs.
10226 	 *----------
10227 	 */
10228 	ln_dweight = estimate_ln_dweight(base);
10229 
10230 	/*
10231 	 * Set the scale for the low-precision calculation, computing ln(base) to
10232 	 * around 8 significant digits.  Note that ln_dweight may be as small as
10233 	 * -SHRT_MAX, so the scale may exceed NUMERIC_MAX_DISPLAY_SCALE here.
10234 	 */
10235 	local_rscale = 8 - ln_dweight;
10236 	local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
10237 
10238 	ln_var(base, &ln_base, local_rscale);
10239 
10240 	mul_var(&ln_base, exp, &ln_num, local_rscale);
10241 
10242 	val = numericvar_to_double_no_overflow(&ln_num);
10243 
10244 	/* initial overflow/underflow test with fuzz factor */
10245 	if (Abs(val) > NUMERIC_MAX_RESULT_SCALE * 3.01)
10246 	{
10247 		if (val > 0)
10248 			ereport(ERROR,
10249 					(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
10250 					 errmsg("value overflows numeric format")));
10251 		zero_var(result);
10252 		result->dscale = NUMERIC_MAX_DISPLAY_SCALE;
10253 		return;
10254 	}
10255 
10256 	val *= 0.434294481903252;	/* approximate decimal result weight */
10257 
10258 	/* choose the result scale */
10259 	rscale = NUMERIC_MIN_SIG_DIGITS - (int) val;
10260 	rscale = Max(rscale, base->dscale);
10261 	rscale = Max(rscale, exp->dscale);
10262 	rscale = Max(rscale, NUMERIC_MIN_DISPLAY_SCALE);
10263 	rscale = Min(rscale, NUMERIC_MAX_DISPLAY_SCALE);
10264 
10265 	/* significant digits required in the result */
10266 	sig_digits = rscale + (int) val;
10267 	sig_digits = Max(sig_digits, 0);
10268 
10269 	/* set the scale for the real exp * ln(base) calculation */
10270 	local_rscale = sig_digits - ln_dweight + 8;
10271 	local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
10272 
10273 	/* and do the real calculation */
10274 
10275 	ln_var(base, &ln_base, local_rscale);
10276 
10277 	mul_var(&ln_base, exp, &ln_num, local_rscale);
10278 
10279 	exp_var(&ln_num, result, rscale);
10280 
10281 	if (res_sign == NUMERIC_NEG && result->ndigits > 0)
10282 		result->sign = NUMERIC_NEG;
10283 
10284 	free_var(&ln_num);
10285 	free_var(&ln_base);
10286 	free_var(&abs_base);
10287 }
10288 
10289 /*
10290  * power_var_int() -
10291  *
10292  *	Raise base to the power of exp, where exp is an integer.
10293  */
10294 static void
power_var_int(const NumericVar * base,int exp,NumericVar * result,int rscale)10295 power_var_int(const NumericVar *base, int exp, NumericVar *result, int rscale)
10296 {
10297 	double		f;
10298 	int			p;
10299 	int			i;
10300 	int			sig_digits;
10301 	unsigned int mask;
10302 	bool		neg;
10303 	NumericVar	base_prod;
10304 	int			local_rscale;
10305 
10306 	/* Handle some common special cases, as well as corner cases */
10307 	switch (exp)
10308 	{
10309 		case 0:
10310 
10311 			/*
10312 			 * While 0 ^ 0 can be either 1 or indeterminate (error), we treat
10313 			 * it as 1 because most programming languages do this. SQL:2003
10314 			 * also requires a return value of 1.
10315 			 * https://en.wikipedia.org/wiki/Exponentiation#Zero_to_the_zero_power
10316 			 */
10317 			set_var_from_var(&const_one, result);
10318 			result->dscale = rscale;	/* no need to round */
10319 			return;
10320 		case 1:
10321 			set_var_from_var(base, result);
10322 			round_var(result, rscale);
10323 			return;
10324 		case -1:
10325 			div_var(&const_one, base, result, rscale, true);
10326 			return;
10327 		case 2:
10328 			mul_var(base, base, result, rscale);
10329 			return;
10330 		default:
10331 			break;
10332 	}
10333 
10334 	/* Handle the special case where the base is zero */
10335 	if (base->ndigits == 0)
10336 	{
10337 		if (exp < 0)
10338 			ereport(ERROR,
10339 					(errcode(ERRCODE_DIVISION_BY_ZERO),
10340 					 errmsg("division by zero")));
10341 		zero_var(result);
10342 		result->dscale = rscale;
10343 		return;
10344 	}
10345 
10346 	/*
10347 	 * The general case repeatedly multiplies base according to the bit
10348 	 * pattern of exp.
10349 	 *
10350 	 * First we need to estimate the weight of the result so that we know how
10351 	 * many significant digits are needed.
10352 	 */
10353 	f = base->digits[0];
10354 	p = base->weight * DEC_DIGITS;
10355 
10356 	for (i = 1; i < base->ndigits && i * DEC_DIGITS < 16; i++)
10357 	{
10358 		f = f * NBASE + base->digits[i];
10359 		p -= DEC_DIGITS;
10360 	}
10361 
10362 	/*----------
10363 	 * We have base ~= f * 10^p
10364 	 * so log10(result) = log10(base^exp) ~= exp * (log10(f) + p)
10365 	 *----------
10366 	 */
10367 	f = exp * (log10(f) + p);
10368 
10369 	/*
10370 	 * Apply crude overflow/underflow tests so we can exit early if the result
10371 	 * certainly will overflow/underflow.
10372 	 */
10373 	if (f > 3 * SHRT_MAX * DEC_DIGITS)
10374 		ereport(ERROR,
10375 				(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
10376 				 errmsg("value overflows numeric format")));
10377 	if (f + 1 < -rscale || f + 1 < -NUMERIC_MAX_DISPLAY_SCALE)
10378 	{
10379 		zero_var(result);
10380 		result->dscale = rscale;
10381 		return;
10382 	}
10383 
10384 	/*
10385 	 * Approximate number of significant digits in the result.  Note that the
10386 	 * underflow test above means that this is necessarily >= 0.
10387 	 */
10388 	sig_digits = 1 + rscale + (int) f;
10389 
10390 	/*
10391 	 * The multiplications to produce the result may introduce an error of up
10392 	 * to around log10(abs(exp)) digits, so work with this many extra digits
10393 	 * of precision (plus a few more for good measure).
10394 	 */
10395 	sig_digits += (int) log(fabs((double) exp)) + 8;
10396 
10397 	/*
10398 	 * Now we can proceed with the multiplications.
10399 	 */
10400 	neg = (exp < 0);
10401 	mask = Abs(exp);
10402 
10403 	init_var(&base_prod);
10404 	set_var_from_var(base, &base_prod);
10405 
10406 	if (mask & 1)
10407 		set_var_from_var(base, result);
10408 	else
10409 		set_var_from_var(&const_one, result);
10410 
10411 	while ((mask >>= 1) > 0)
10412 	{
10413 		/*
10414 		 * Do the multiplications using rscales large enough to hold the
10415 		 * results to the required number of significant digits, but don't
10416 		 * waste time by exceeding the scales of the numbers themselves.
10417 		 */
10418 		local_rscale = sig_digits - 2 * base_prod.weight * DEC_DIGITS;
10419 		local_rscale = Min(local_rscale, 2 * base_prod.dscale);
10420 		local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
10421 
10422 		mul_var(&base_prod, &base_prod, &base_prod, local_rscale);
10423 
10424 		if (mask & 1)
10425 		{
10426 			local_rscale = sig_digits -
10427 				(base_prod.weight + result->weight) * DEC_DIGITS;
10428 			local_rscale = Min(local_rscale,
10429 							   base_prod.dscale + result->dscale);
10430 			local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
10431 
10432 			mul_var(&base_prod, result, result, local_rscale);
10433 		}
10434 
10435 		/*
10436 		 * When abs(base) > 1, the number of digits to the left of the decimal
10437 		 * point in base_prod doubles at each iteration, so if exp is large we
10438 		 * could easily spend large amounts of time and memory space doing the
10439 		 * multiplications.  But once the weight exceeds what will fit in
10440 		 * int16, the final result is guaranteed to overflow (or underflow, if
10441 		 * exp < 0), so we can give up before wasting too many cycles.
10442 		 */
10443 		if (base_prod.weight > SHRT_MAX || result->weight > SHRT_MAX)
10444 		{
10445 			/* overflow, unless neg, in which case result should be 0 */
10446 			if (!neg)
10447 				ereport(ERROR,
10448 						(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
10449 						 errmsg("value overflows numeric format")));
10450 			zero_var(result);
10451 			neg = false;
10452 			break;
10453 		}
10454 	}
10455 
10456 	free_var(&base_prod);
10457 
10458 	/* Compensate for input sign, and round to requested rscale */
10459 	if (neg)
10460 		div_var_fast(&const_one, result, result, rscale, true);
10461 	else
10462 		round_var(result, rscale);
10463 }
10464 
10465 /*
10466  * power_ten_int() -
10467  *
10468  *	Raise ten to the power of exp, where exp is an integer.  Note that unlike
10469  *	power_var_int(), this does no overflow/underflow checking or rounding.
10470  */
10471 static void
power_ten_int(int exp,NumericVar * result)10472 power_ten_int(int exp, NumericVar *result)
10473 {
10474 	/* Construct the result directly, starting from 10^0 = 1 */
10475 	set_var_from_var(&const_one, result);
10476 
10477 	/* Scale needed to represent the result exactly */
10478 	result->dscale = exp < 0 ? -exp : 0;
10479 
10480 	/* Base-NBASE weight of result and remaining exponent */
10481 	if (exp >= 0)
10482 		result->weight = exp / DEC_DIGITS;
10483 	else
10484 		result->weight = (exp + 1) / DEC_DIGITS - 1;
10485 
10486 	exp -= result->weight * DEC_DIGITS;
10487 
10488 	/* Final adjustment of the result's single NBASE digit */
10489 	while (exp-- > 0)
10490 		result->digits[0] *= 10;
10491 }
10492 
10493 
10494 /* ----------------------------------------------------------------------
10495  *
10496  * Following are the lowest level functions that operate unsigned
10497  * on the variable level
10498  *
10499  * ----------------------------------------------------------------------
10500  */
10501 
10502 
10503 /* ----------
10504  * cmp_abs() -
10505  *
10506  *	Compare the absolute values of var1 and var2
10507  *	Returns:	-1 for ABS(var1) < ABS(var2)
10508  *				0  for ABS(var1) == ABS(var2)
10509  *				1  for ABS(var1) > ABS(var2)
10510  * ----------
10511  */
10512 static int
cmp_abs(const NumericVar * var1,const NumericVar * var2)10513 cmp_abs(const NumericVar *var1, const NumericVar *var2)
10514 {
10515 	return cmp_abs_common(var1->digits, var1->ndigits, var1->weight,
10516 						  var2->digits, var2->ndigits, var2->weight);
10517 }
10518 
10519 /* ----------
10520  * cmp_abs_common() -
10521  *
10522  *	Main routine of cmp_abs(). This function can be used by both
10523  *	NumericVar and Numeric.
10524  * ----------
10525  */
10526 static int
cmp_abs_common(const NumericDigit * var1digits,int var1ndigits,int var1weight,const NumericDigit * var2digits,int var2ndigits,int var2weight)10527 cmp_abs_common(const NumericDigit *var1digits, int var1ndigits, int var1weight,
10528 			   const NumericDigit *var2digits, int var2ndigits, int var2weight)
10529 {
10530 	int			i1 = 0;
10531 	int			i2 = 0;
10532 
10533 	/* Check any digits before the first common digit */
10534 
10535 	while (var1weight > var2weight && i1 < var1ndigits)
10536 	{
10537 		if (var1digits[i1++] != 0)
10538 			return 1;
10539 		var1weight--;
10540 	}
10541 	while (var2weight > var1weight && i2 < var2ndigits)
10542 	{
10543 		if (var2digits[i2++] != 0)
10544 			return -1;
10545 		var2weight--;
10546 	}
10547 
10548 	/* At this point, either w1 == w2 or we've run out of digits */
10549 
10550 	if (var1weight == var2weight)
10551 	{
10552 		while (i1 < var1ndigits && i2 < var2ndigits)
10553 		{
10554 			int			stat = var1digits[i1++] - var2digits[i2++];
10555 
10556 			if (stat)
10557 			{
10558 				if (stat > 0)
10559 					return 1;
10560 				return -1;
10561 			}
10562 		}
10563 	}
10564 
10565 	/*
10566 	 * At this point, we've run out of digits on one side or the other; so any
10567 	 * remaining nonzero digits imply that side is larger
10568 	 */
10569 	while (i1 < var1ndigits)
10570 	{
10571 		if (var1digits[i1++] != 0)
10572 			return 1;
10573 	}
10574 	while (i2 < var2ndigits)
10575 	{
10576 		if (var2digits[i2++] != 0)
10577 			return -1;
10578 	}
10579 
10580 	return 0;
10581 }
10582 
10583 
10584 /*
10585  * add_abs() -
10586  *
10587  *	Add the absolute values of two variables into result.
10588  *	result might point to one of the operands without danger.
10589  */
10590 static void
add_abs(const NumericVar * var1,const NumericVar * var2,NumericVar * result)10591 add_abs(const NumericVar *var1, const NumericVar *var2, NumericVar *result)
10592 {
10593 	NumericDigit *res_buf;
10594 	NumericDigit *res_digits;
10595 	int			res_ndigits;
10596 	int			res_weight;
10597 	int			res_rscale,
10598 				rscale1,
10599 				rscale2;
10600 	int			res_dscale;
10601 	int			i,
10602 				i1,
10603 				i2;
10604 	int			carry = 0;
10605 
10606 	/* copy these values into local vars for speed in inner loop */
10607 	int			var1ndigits = var1->ndigits;
10608 	int			var2ndigits = var2->ndigits;
10609 	NumericDigit *var1digits = var1->digits;
10610 	NumericDigit *var2digits = var2->digits;
10611 
10612 	res_weight = Max(var1->weight, var2->weight) + 1;
10613 
10614 	res_dscale = Max(var1->dscale, var2->dscale);
10615 
10616 	/* Note: here we are figuring rscale in base-NBASE digits */
10617 	rscale1 = var1->ndigits - var1->weight - 1;
10618 	rscale2 = var2->ndigits - var2->weight - 1;
10619 	res_rscale = Max(rscale1, rscale2);
10620 
10621 	res_ndigits = res_rscale + res_weight + 1;
10622 	if (res_ndigits <= 0)
10623 		res_ndigits = 1;
10624 
10625 	res_buf = digitbuf_alloc(res_ndigits + 1);
10626 	res_buf[0] = 0;				/* spare digit for later rounding */
10627 	res_digits = res_buf + 1;
10628 
10629 	i1 = res_rscale + var1->weight + 1;
10630 	i2 = res_rscale + var2->weight + 1;
10631 	for (i = res_ndigits - 1; i >= 0; i--)
10632 	{
10633 		i1--;
10634 		i2--;
10635 		if (i1 >= 0 && i1 < var1ndigits)
10636 			carry += var1digits[i1];
10637 		if (i2 >= 0 && i2 < var2ndigits)
10638 			carry += var2digits[i2];
10639 
10640 		if (carry >= NBASE)
10641 		{
10642 			res_digits[i] = carry - NBASE;
10643 			carry = 1;
10644 		}
10645 		else
10646 		{
10647 			res_digits[i] = carry;
10648 			carry = 0;
10649 		}
10650 	}
10651 
10652 	Assert(carry == 0);			/* else we failed to allow for carry out */
10653 
10654 	digitbuf_free(result->buf);
10655 	result->ndigits = res_ndigits;
10656 	result->buf = res_buf;
10657 	result->digits = res_digits;
10658 	result->weight = res_weight;
10659 	result->dscale = res_dscale;
10660 
10661 	/* Remove leading/trailing zeroes */
10662 	strip_var(result);
10663 }
10664 
10665 
10666 /*
10667  * sub_abs()
10668  *
10669  *	Subtract the absolute value of var2 from the absolute value of var1
10670  *	and store in result. result might point to one of the operands
10671  *	without danger.
10672  *
10673  *	ABS(var1) MUST BE GREATER OR EQUAL ABS(var2) !!!
10674  */
10675 static void
sub_abs(const NumericVar * var1,const NumericVar * var2,NumericVar * result)10676 sub_abs(const NumericVar *var1, const NumericVar *var2, NumericVar *result)
10677 {
10678 	NumericDigit *res_buf;
10679 	NumericDigit *res_digits;
10680 	int			res_ndigits;
10681 	int			res_weight;
10682 	int			res_rscale,
10683 				rscale1,
10684 				rscale2;
10685 	int			res_dscale;
10686 	int			i,
10687 				i1,
10688 				i2;
10689 	int			borrow = 0;
10690 
10691 	/* copy these values into local vars for speed in inner loop */
10692 	int			var1ndigits = var1->ndigits;
10693 	int			var2ndigits = var2->ndigits;
10694 	NumericDigit *var1digits = var1->digits;
10695 	NumericDigit *var2digits = var2->digits;
10696 
10697 	res_weight = var1->weight;
10698 
10699 	res_dscale = Max(var1->dscale, var2->dscale);
10700 
10701 	/* Note: here we are figuring rscale in base-NBASE digits */
10702 	rscale1 = var1->ndigits - var1->weight - 1;
10703 	rscale2 = var2->ndigits - var2->weight - 1;
10704 	res_rscale = Max(rscale1, rscale2);
10705 
10706 	res_ndigits = res_rscale + res_weight + 1;
10707 	if (res_ndigits <= 0)
10708 		res_ndigits = 1;
10709 
10710 	res_buf = digitbuf_alloc(res_ndigits + 1);
10711 	res_buf[0] = 0;				/* spare digit for later rounding */
10712 	res_digits = res_buf + 1;
10713 
10714 	i1 = res_rscale + var1->weight + 1;
10715 	i2 = res_rscale + var2->weight + 1;
10716 	for (i = res_ndigits - 1; i >= 0; i--)
10717 	{
10718 		i1--;
10719 		i2--;
10720 		if (i1 >= 0 && i1 < var1ndigits)
10721 			borrow += var1digits[i1];
10722 		if (i2 >= 0 && i2 < var2ndigits)
10723 			borrow -= var2digits[i2];
10724 
10725 		if (borrow < 0)
10726 		{
10727 			res_digits[i] = borrow + NBASE;
10728 			borrow = -1;
10729 		}
10730 		else
10731 		{
10732 			res_digits[i] = borrow;
10733 			borrow = 0;
10734 		}
10735 	}
10736 
10737 	Assert(borrow == 0);		/* else caller gave us var1 < var2 */
10738 
10739 	digitbuf_free(result->buf);
10740 	result->ndigits = res_ndigits;
10741 	result->buf = res_buf;
10742 	result->digits = res_digits;
10743 	result->weight = res_weight;
10744 	result->dscale = res_dscale;
10745 
10746 	/* Remove leading/trailing zeroes */
10747 	strip_var(result);
10748 }
10749 
10750 /*
10751  * round_var
10752  *
10753  * Round the value of a variable to no more than rscale decimal digits
10754  * after the decimal point.  NOTE: we allow rscale < 0 here, implying
10755  * rounding before the decimal point.
10756  */
10757 static void
round_var(NumericVar * var,int rscale)10758 round_var(NumericVar *var, int rscale)
10759 {
10760 	NumericDigit *digits = var->digits;
10761 	int			di;
10762 	int			ndigits;
10763 	int			carry;
10764 
10765 	var->dscale = rscale;
10766 
10767 	/* decimal digits wanted */
10768 	di = (var->weight + 1) * DEC_DIGITS + rscale;
10769 
10770 	/*
10771 	 * If di = 0, the value loses all digits, but could round up to 1 if its
10772 	 * first extra digit is >= 5.  If di < 0 the result must be 0.
10773 	 */
10774 	if (di < 0)
10775 	{
10776 		var->ndigits = 0;
10777 		var->weight = 0;
10778 		var->sign = NUMERIC_POS;
10779 	}
10780 	else
10781 	{
10782 		/* NBASE digits wanted */
10783 		ndigits = (di + DEC_DIGITS - 1) / DEC_DIGITS;
10784 
10785 		/* 0, or number of decimal digits to keep in last NBASE digit */
10786 		di %= DEC_DIGITS;
10787 
10788 		if (ndigits < var->ndigits ||
10789 			(ndigits == var->ndigits && di > 0))
10790 		{
10791 			var->ndigits = ndigits;
10792 
10793 #if DEC_DIGITS == 1
10794 			/* di must be zero */
10795 			carry = (digits[ndigits] >= HALF_NBASE) ? 1 : 0;
10796 #else
10797 			if (di == 0)
10798 				carry = (digits[ndigits] >= HALF_NBASE) ? 1 : 0;
10799 			else
10800 			{
10801 				/* Must round within last NBASE digit */
10802 				int			extra,
10803 							pow10;
10804 
10805 #if DEC_DIGITS == 4
10806 				pow10 = round_powers[di];
10807 #elif DEC_DIGITS == 2
10808 				pow10 = 10;
10809 #else
10810 #error unsupported NBASE
10811 #endif
10812 				extra = digits[--ndigits] % pow10;
10813 				digits[ndigits] -= extra;
10814 				carry = 0;
10815 				if (extra >= pow10 / 2)
10816 				{
10817 					pow10 += digits[ndigits];
10818 					if (pow10 >= NBASE)
10819 					{
10820 						pow10 -= NBASE;
10821 						carry = 1;
10822 					}
10823 					digits[ndigits] = pow10;
10824 				}
10825 			}
10826 #endif
10827 
10828 			/* Propagate carry if needed */
10829 			while (carry)
10830 			{
10831 				carry += digits[--ndigits];
10832 				if (carry >= NBASE)
10833 				{
10834 					digits[ndigits] = carry - NBASE;
10835 					carry = 1;
10836 				}
10837 				else
10838 				{
10839 					digits[ndigits] = carry;
10840 					carry = 0;
10841 				}
10842 			}
10843 
10844 			if (ndigits < 0)
10845 			{
10846 				Assert(ndigits == -1);	/* better not have added > 1 digit */
10847 				Assert(var->digits > var->buf);
10848 				var->digits--;
10849 				var->ndigits++;
10850 				var->weight++;
10851 			}
10852 		}
10853 	}
10854 }
10855 
10856 /*
10857  * trunc_var
10858  *
10859  * Truncate (towards zero) the value of a variable at rscale decimal digits
10860  * after the decimal point.  NOTE: we allow rscale < 0 here, implying
10861  * truncation before the decimal point.
10862  */
10863 static void
trunc_var(NumericVar * var,int rscale)10864 trunc_var(NumericVar *var, int rscale)
10865 {
10866 	int			di;
10867 	int			ndigits;
10868 
10869 	var->dscale = rscale;
10870 
10871 	/* decimal digits wanted */
10872 	di = (var->weight + 1) * DEC_DIGITS + rscale;
10873 
10874 	/*
10875 	 * If di <= 0, the value loses all digits.
10876 	 */
10877 	if (di <= 0)
10878 	{
10879 		var->ndigits = 0;
10880 		var->weight = 0;
10881 		var->sign = NUMERIC_POS;
10882 	}
10883 	else
10884 	{
10885 		/* NBASE digits wanted */
10886 		ndigits = (di + DEC_DIGITS - 1) / DEC_DIGITS;
10887 
10888 		if (ndigits <= var->ndigits)
10889 		{
10890 			var->ndigits = ndigits;
10891 
10892 #if DEC_DIGITS == 1
10893 			/* no within-digit stuff to worry about */
10894 #else
10895 			/* 0, or number of decimal digits to keep in last NBASE digit */
10896 			di %= DEC_DIGITS;
10897 
10898 			if (di > 0)
10899 			{
10900 				/* Must truncate within last NBASE digit */
10901 				NumericDigit *digits = var->digits;
10902 				int			extra,
10903 							pow10;
10904 
10905 #if DEC_DIGITS == 4
10906 				pow10 = round_powers[di];
10907 #elif DEC_DIGITS == 2
10908 				pow10 = 10;
10909 #else
10910 #error unsupported NBASE
10911 #endif
10912 				extra = digits[--ndigits] % pow10;
10913 				digits[ndigits] -= extra;
10914 			}
10915 #endif
10916 		}
10917 	}
10918 }
10919 
10920 /*
10921  * strip_var
10922  *
10923  * Strip any leading and trailing zeroes from a numeric variable
10924  */
10925 static void
strip_var(NumericVar * var)10926 strip_var(NumericVar *var)
10927 {
10928 	NumericDigit *digits = var->digits;
10929 	int			ndigits = var->ndigits;
10930 
10931 	/* Strip leading zeroes */
10932 	while (ndigits > 0 && *digits == 0)
10933 	{
10934 		digits++;
10935 		var->weight--;
10936 		ndigits--;
10937 	}
10938 
10939 	/* Strip trailing zeroes */
10940 	while (ndigits > 0 && digits[ndigits - 1] == 0)
10941 		ndigits--;
10942 
10943 	/* If it's zero, normalize the sign and weight */
10944 	if (ndigits == 0)
10945 	{
10946 		var->sign = NUMERIC_POS;
10947 		var->weight = 0;
10948 	}
10949 
10950 	var->digits = digits;
10951 	var->ndigits = ndigits;
10952 }
10953 
10954 
10955 /* ----------------------------------------------------------------------
10956  *
10957  * Fast sum accumulator functions
10958  *
10959  * ----------------------------------------------------------------------
10960  */
10961 
10962 /*
10963  * Reset the accumulator's value to zero.  The buffers to hold the digits
10964  * are not free'd.
10965  */
10966 static void
accum_sum_reset(NumericSumAccum * accum)10967 accum_sum_reset(NumericSumAccum *accum)
10968 {
10969 	int			i;
10970 
10971 	accum->dscale = 0;
10972 	for (i = 0; i < accum->ndigits; i++)
10973 	{
10974 		accum->pos_digits[i] = 0;
10975 		accum->neg_digits[i] = 0;
10976 	}
10977 }
10978 
10979 /*
10980  * Accumulate a new value.
10981  */
10982 static void
accum_sum_add(NumericSumAccum * accum,const NumericVar * val)10983 accum_sum_add(NumericSumAccum *accum, const NumericVar *val)
10984 {
10985 	int32	   *accum_digits;
10986 	int			i,
10987 				val_i;
10988 	int			val_ndigits;
10989 	NumericDigit *val_digits;
10990 
10991 	/*
10992 	 * If we have accumulated too many values since the last carry
10993 	 * propagation, do it now, to avoid overflowing.  (We could allow more
10994 	 * than NBASE - 1, if we reserved two extra digits, rather than one, for
10995 	 * carry propagation.  But even with NBASE - 1, this needs to be done so
10996 	 * seldom, that the performance difference is negligible.)
10997 	 */
10998 	if (accum->num_uncarried == NBASE - 1)
10999 		accum_sum_carry(accum);
11000 
11001 	/*
11002 	 * Adjust the weight or scale of the old value, so that it can accommodate
11003 	 * the new value.
11004 	 */
11005 	accum_sum_rescale(accum, val);
11006 
11007 	/* */
11008 	if (val->sign == NUMERIC_POS)
11009 		accum_digits = accum->pos_digits;
11010 	else
11011 		accum_digits = accum->neg_digits;
11012 
11013 	/* copy these values into local vars for speed in loop */
11014 	val_ndigits = val->ndigits;
11015 	val_digits = val->digits;
11016 
11017 	i = accum->weight - val->weight;
11018 	for (val_i = 0; val_i < val_ndigits; val_i++)
11019 	{
11020 		accum_digits[i] += (int32) val_digits[val_i];
11021 		i++;
11022 	}
11023 
11024 	accum->num_uncarried++;
11025 }
11026 
11027 /*
11028  * Propagate carries.
11029  */
11030 static void
accum_sum_carry(NumericSumAccum * accum)11031 accum_sum_carry(NumericSumAccum *accum)
11032 {
11033 	int			i;
11034 	int			ndigits;
11035 	int32	   *dig;
11036 	int32		carry;
11037 	int32		newdig = 0;
11038 
11039 	/*
11040 	 * If no new values have been added since last carry propagation, nothing
11041 	 * to do.
11042 	 */
11043 	if (accum->num_uncarried == 0)
11044 		return;
11045 
11046 	/*
11047 	 * We maintain that the weight of the accumulator is always one larger
11048 	 * than needed to hold the current value, before carrying, to make sure
11049 	 * there is enough space for the possible extra digit when carry is
11050 	 * propagated.  We cannot expand the buffer here, unless we require
11051 	 * callers of accum_sum_final() to switch to the right memory context.
11052 	 */
11053 	Assert(accum->pos_digits[0] == 0 && accum->neg_digits[0] == 0);
11054 
11055 	ndigits = accum->ndigits;
11056 
11057 	/* Propagate carry in the positive sum */
11058 	dig = accum->pos_digits;
11059 	carry = 0;
11060 	for (i = ndigits - 1; i >= 0; i--)
11061 	{
11062 		newdig = dig[i] + carry;
11063 		if (newdig >= NBASE)
11064 		{
11065 			carry = newdig / NBASE;
11066 			newdig -= carry * NBASE;
11067 		}
11068 		else
11069 			carry = 0;
11070 		dig[i] = newdig;
11071 	}
11072 	/* Did we use up the digit reserved for carry propagation? */
11073 	if (newdig > 0)
11074 		accum->have_carry_space = false;
11075 
11076 	/* And the same for the negative sum */
11077 	dig = accum->neg_digits;
11078 	carry = 0;
11079 	for (i = ndigits - 1; i >= 0; i--)
11080 	{
11081 		newdig = dig[i] + carry;
11082 		if (newdig >= NBASE)
11083 		{
11084 			carry = newdig / NBASE;
11085 			newdig -= carry * NBASE;
11086 		}
11087 		else
11088 			carry = 0;
11089 		dig[i] = newdig;
11090 	}
11091 	if (newdig > 0)
11092 		accum->have_carry_space = false;
11093 
11094 	accum->num_uncarried = 0;
11095 }
11096 
11097 /*
11098  * Re-scale accumulator to accommodate new value.
11099  *
11100  * If the new value has more digits than the current digit buffers in the
11101  * accumulator, enlarge the buffers.
11102  */
11103 static void
accum_sum_rescale(NumericSumAccum * accum,const NumericVar * val)11104 accum_sum_rescale(NumericSumAccum *accum, const NumericVar *val)
11105 {
11106 	int			old_weight = accum->weight;
11107 	int			old_ndigits = accum->ndigits;
11108 	int			accum_ndigits;
11109 	int			accum_weight;
11110 	int			accum_rscale;
11111 	int			val_rscale;
11112 
11113 	accum_weight = old_weight;
11114 	accum_ndigits = old_ndigits;
11115 
11116 	/*
11117 	 * Does the new value have a larger weight? If so, enlarge the buffers,
11118 	 * and shift the existing value to the new weight, by adding leading
11119 	 * zeros.
11120 	 *
11121 	 * We enforce that the accumulator always has a weight one larger than
11122 	 * needed for the inputs, so that we have space for an extra digit at the
11123 	 * final carry-propagation phase, if necessary.
11124 	 */
11125 	if (val->weight >= accum_weight)
11126 	{
11127 		accum_weight = val->weight + 1;
11128 		accum_ndigits = accum_ndigits + (accum_weight - old_weight);
11129 	}
11130 
11131 	/*
11132 	 * Even though the new value is small, we might've used up the space
11133 	 * reserved for the carry digit in the last call to accum_sum_carry().  If
11134 	 * so, enlarge to make room for another one.
11135 	 */
11136 	else if (!accum->have_carry_space)
11137 	{
11138 		accum_weight++;
11139 		accum_ndigits++;
11140 	}
11141 
11142 	/* Is the new value wider on the right side? */
11143 	accum_rscale = accum_ndigits - accum_weight - 1;
11144 	val_rscale = val->ndigits - val->weight - 1;
11145 	if (val_rscale > accum_rscale)
11146 		accum_ndigits = accum_ndigits + (val_rscale - accum_rscale);
11147 
11148 	if (accum_ndigits != old_ndigits ||
11149 		accum_weight != old_weight)
11150 	{
11151 		int32	   *new_pos_digits;
11152 		int32	   *new_neg_digits;
11153 		int			weightdiff;
11154 
11155 		weightdiff = accum_weight - old_weight;
11156 
11157 		new_pos_digits = palloc0(accum_ndigits * sizeof(int32));
11158 		new_neg_digits = palloc0(accum_ndigits * sizeof(int32));
11159 
11160 		if (accum->pos_digits)
11161 		{
11162 			memcpy(&new_pos_digits[weightdiff], accum->pos_digits,
11163 				   old_ndigits * sizeof(int32));
11164 			pfree(accum->pos_digits);
11165 
11166 			memcpy(&new_neg_digits[weightdiff], accum->neg_digits,
11167 				   old_ndigits * sizeof(int32));
11168 			pfree(accum->neg_digits);
11169 		}
11170 
11171 		accum->pos_digits = new_pos_digits;
11172 		accum->neg_digits = new_neg_digits;
11173 
11174 		accum->weight = accum_weight;
11175 		accum->ndigits = accum_ndigits;
11176 
11177 		Assert(accum->pos_digits[0] == 0 && accum->neg_digits[0] == 0);
11178 		accum->have_carry_space = true;
11179 	}
11180 
11181 	if (val->dscale > accum->dscale)
11182 		accum->dscale = val->dscale;
11183 }
11184 
11185 /*
11186  * Return the current value of the accumulator.  This perform final carry
11187  * propagation, and adds together the positive and negative sums.
11188  *
11189  * Unlike all the other routines, the caller is not required to switch to
11190  * the memory context that holds the accumulator.
11191  */
11192 static void
accum_sum_final(NumericSumAccum * accum,NumericVar * result)11193 accum_sum_final(NumericSumAccum *accum, NumericVar *result)
11194 {
11195 	int			i;
11196 	NumericVar	pos_var;
11197 	NumericVar	neg_var;
11198 
11199 	if (accum->ndigits == 0)
11200 	{
11201 		set_var_from_var(&const_zero, result);
11202 		return;
11203 	}
11204 
11205 	/* Perform final carry */
11206 	accum_sum_carry(accum);
11207 
11208 	/* Create NumericVars representing the positive and negative sums */
11209 	init_var(&pos_var);
11210 	init_var(&neg_var);
11211 
11212 	pos_var.ndigits = neg_var.ndigits = accum->ndigits;
11213 	pos_var.weight = neg_var.weight = accum->weight;
11214 	pos_var.dscale = neg_var.dscale = accum->dscale;
11215 	pos_var.sign = NUMERIC_POS;
11216 	neg_var.sign = NUMERIC_NEG;
11217 
11218 	pos_var.buf = pos_var.digits = digitbuf_alloc(accum->ndigits);
11219 	neg_var.buf = neg_var.digits = digitbuf_alloc(accum->ndigits);
11220 
11221 	for (i = 0; i < accum->ndigits; i++)
11222 	{
11223 		Assert(accum->pos_digits[i] < NBASE);
11224 		pos_var.digits[i] = (int16) accum->pos_digits[i];
11225 
11226 		Assert(accum->neg_digits[i] < NBASE);
11227 		neg_var.digits[i] = (int16) accum->neg_digits[i];
11228 	}
11229 
11230 	/* And add them together */
11231 	add_var(&pos_var, &neg_var, result);
11232 
11233 	/* Remove leading/trailing zeroes */
11234 	strip_var(result);
11235 }
11236 
11237 /*
11238  * Copy an accumulator's state.
11239  *
11240  * 'dst' is assumed to be uninitialized beforehand.  No attempt is made at
11241  * freeing old values.
11242  */
11243 static void
accum_sum_copy(NumericSumAccum * dst,NumericSumAccum * src)11244 accum_sum_copy(NumericSumAccum *dst, NumericSumAccum *src)
11245 {
11246 	dst->pos_digits = palloc(src->ndigits * sizeof(int32));
11247 	dst->neg_digits = palloc(src->ndigits * sizeof(int32));
11248 
11249 	memcpy(dst->pos_digits, src->pos_digits, src->ndigits * sizeof(int32));
11250 	memcpy(dst->neg_digits, src->neg_digits, src->ndigits * sizeof(int32));
11251 	dst->num_uncarried = src->num_uncarried;
11252 	dst->ndigits = src->ndigits;
11253 	dst->weight = src->weight;
11254 	dst->dscale = src->dscale;
11255 }
11256 
11257 /*
11258  * Add the current value of 'accum2' into 'accum'.
11259  */
11260 static void
accum_sum_combine(NumericSumAccum * accum,NumericSumAccum * accum2)11261 accum_sum_combine(NumericSumAccum *accum, NumericSumAccum *accum2)
11262 {
11263 	NumericVar	tmp_var;
11264 
11265 	init_var(&tmp_var);
11266 
11267 	accum_sum_final(accum2, &tmp_var);
11268 	accum_sum_add(accum, &tmp_var);
11269 
11270 	free_var(&tmp_var);
11271 }
11272