1 /*-------------------------------------------------------------------------
2  *
3  * numeric.c
4  *	  An exact numeric data type for the Postgres database system
5  *
6  * Original coding 1998, Jan Wieck.  Heavily revised 2003, Tom Lane.
7  *
8  * Many of the algorithmic ideas are borrowed from David M. Smith's "FM"
9  * multiple-precision math library, most recently published as Algorithm
10  * 786: Multiple-Precision Complex Arithmetic and Functions, ACM
11  * Transactions on Mathematical Software, Vol. 24, No. 4, December 1998,
12  * pages 359-367.
13  *
14  * Copyright (c) 1998-2019, PostgreSQL Global Development Group
15  *
16  * IDENTIFICATION
17  *	  src/backend/utils/adt/numeric.c
18  *
19  *-------------------------------------------------------------------------
20  */
21 
22 #include "postgres.h"
23 
24 #include <ctype.h>
25 #include <float.h>
26 #include <limits.h>
27 #include <math.h>
28 
29 #include "catalog/pg_type.h"
30 #include "common/int.h"
31 #include "funcapi.h"
32 #include "lib/hyperloglog.h"
33 #include "libpq/pqformat.h"
34 #include "miscadmin.h"
35 #include "nodes/nodeFuncs.h"
36 #include "nodes/supportnodes.h"
37 #include "utils/array.h"
38 #include "utils/builtins.h"
39 #include "utils/float.h"
40 #include "utils/guc.h"
41 #include "utils/hashutils.h"
42 #include "utils/int8.h"
43 #include "utils/numeric.h"
44 #include "utils/sortsupport.h"
45 
46 /* ----------
47  * Uncomment the following to enable compilation of dump_numeric()
48  * and dump_var() and to get a dump of any result produced by make_result().
49  * ----------
50 #define NUMERIC_DEBUG
51  */
52 
53 
54 /* ----------
55  * Local data types
56  *
57  * Numeric values are represented in a base-NBASE floating point format.
58  * Each "digit" ranges from 0 to NBASE-1.  The type NumericDigit is signed
59  * and wide enough to store a digit.  We assume that NBASE*NBASE can fit in
60  * an int.  Although the purely calculational routines could handle any even
61  * NBASE that's less than sqrt(INT_MAX), in practice we are only interested
62  * in NBASE a power of ten, so that I/O conversions and decimal rounding
63  * are easy.  Also, it's actually more efficient if NBASE is rather less than
64  * sqrt(INT_MAX), so that there is "headroom" for mul_var and div_var_fast to
65  * postpone processing carries.
66  *
67  * Values of NBASE other than 10000 are considered of historical interest only
68  * and are no longer supported in any sense; no mechanism exists for the client
69  * to discover the base, so every client supporting binary mode expects the
70  * base-10000 format.  If you plan to change this, also note the numeric
71  * abbreviation code, which assumes NBASE=10000.
72  * ----------
73  */
74 
/*
 * Compile-time selection of NBASE and its companion settings.  Exactly one
 * of the following sections is enabled (#if 1); the others are retained for
 * reference only, as explained in the comment above.
 *
 *	NBASE			- the radix of a NumericDigit
 *	HALF_NBASE		- NBASE / 2
 *	DEC_DIGITS		- decimal digits per NBASE digit (NBASE == 10^DEC_DIGITS)
 *	MUL_GUARD_DIGITS, DIV_GUARD_DIGITS
 *					- guard digits, measured in NBASE digits (not decimal
 *					  digits); presumably extra working precision for
 *					  multiplication/division -- NOTE(review): users not in view
 *	NumericDigit	- signed type wide enough to hold one NBASE digit
 */
#if 0
#define NBASE		10
#define HALF_NBASE	5
#define DEC_DIGITS	1			/* decimal digits per NBASE digit */
#define MUL_GUARD_DIGITS	4	/* these are measured in NBASE digits */
#define DIV_GUARD_DIGITS	8

typedef signed char NumericDigit;
#endif

#if 0
#define NBASE		100
#define HALF_NBASE	50
#define DEC_DIGITS	2			/* decimal digits per NBASE digit */
#define MUL_GUARD_DIGITS	3	/* these are measured in NBASE digits */
#define DIV_GUARD_DIGITS	6

typedef signed char NumericDigit;
#endif

/* The only supported configuration: NBASE 10000, four decimal digits each */
#if 1
#define NBASE		10000
#define HALF_NBASE	5000
#define DEC_DIGITS	4			/* decimal digits per NBASE digit */
#define MUL_GUARD_DIGITS	2	/* these are measured in NBASE digits */
#define DIV_GUARD_DIGITS	4

typedef int16 NumericDigit;
#endif
104 
105 /*
106  * The Numeric type as stored on disk.
107  *
108  * If the high bits of the first word of a NumericChoice (n_header, or
109  * n_short.n_header, or n_long.n_sign_dscale) are NUMERIC_SHORT, then the
110  * numeric follows the NumericShort format; if they are NUMERIC_POS or
111  * NUMERIC_NEG, it follows the NumericLong format.  If they are NUMERIC_NAN,
112  * it is a NaN.  We currently always store a NaN using just two bytes (i.e.
113  * only n_header), but previous releases used only the NumericLong format,
114  * so we might find 4-byte NaNs on disk if a database has been migrated using
115  * pg_upgrade.  In either case, when the high bits indicate a NaN, the
116  * remaining bits are never examined.  Currently, we always initialize these
117  * to zero, but it might be possible to use them for some other purpose in
118  * the future.
119  *
120  * In the NumericShort format, the remaining 14 bits of the header word
121  * (n_short.n_header) are allocated as follows: 1 for sign (positive or
122  * negative), 6 for dynamic scale, and 7 for weight.  In practice, most
123  * commonly-encountered values can be represented this way.
124  *
125  * In the NumericLong format, the remaining 14 bits of the header word
126  * (n_long.n_sign_dscale) represent the display scale; and the weight is
127  * stored separately in n_weight.
128  *
129  * NOTE: by convention, values in the packed form have been stripped of
130  * all leading and trailing zero digits (where a "digit" is of base NBASE).
131  * In particular, if the value is zero, there will be no digits at all!
132  * The weight is arbitrary in that case, but we normally set it to zero.
133  */
134 
/*
 * Short format: a single 2-byte header word packs the sign, display scale
 * and weight (bit layout given by the NUMERIC_SHORT_* macros below).
 */
struct NumericShort
{
	uint16		n_header;		/* Sign + display scale + weight */
	NumericDigit n_data[FLEXIBLE_ARRAY_MEMBER]; /* Digits */
};

/*
 * Long format: a 2-byte sign/display-scale word followed by a separate
 * 2-byte weight.
 */
struct NumericLong
{
	uint16		n_sign_dscale;	/* Sign + display scale */
	int16		n_weight;		/* Weight of 1st digit	*/
	NumericDigit n_data[FLEXIBLE_ARRAY_MEMBER]; /* Digits */
};

/*
 * Every variant begins with a uint16 whose high bits (NUMERIC_SIGN_MASK)
 * identify the variant, so n_header can be examined before the format is
 * known.
 */
union NumericChoice
{
	uint16		n_header;		/* Header word */
	struct NumericLong n_long;	/* Long form (4-byte header) */
	struct NumericShort n_short;	/* Short form (2-byte header) */
};

/*
 * The complete varlena datum.  The length word is accessed only through
 * the varlena macros (e.g. VARSIZE), never directly.
 */
struct NumericData
{
	int32		vl_len_;		/* varlena header (do not touch directly!) */
	union NumericChoice choice; /* choice of format */
};
160 
161 
/*
 * Interpretation of high bits.
 *
 * The two high bits of the first header word select positive long form,
 * negative long form, short form, or NaN.
 */

#define NUMERIC_SIGN_MASK	0xC000
#define NUMERIC_POS			0x0000
#define NUMERIC_NEG			0x4000
#define NUMERIC_SHORT		0x8000
#define NUMERIC_NAN			0xC000

#define NUMERIC_FLAGBITS(n) ((n)->choice.n_header & NUMERIC_SIGN_MASK)
#define NUMERIC_IS_NAN(n)		(NUMERIC_FLAGBITS(n) == NUMERIC_NAN)
#define NUMERIC_IS_SHORT(n)		(NUMERIC_FLAGBITS(n) == NUMERIC_SHORT)

/*
 * Total header sizes, including the varlena length word: the long form adds
 * n_sign_dscale and n_weight, the short form only n_header.
 */
#define NUMERIC_HDRSZ	(VARHDRSZ + sizeof(uint16) + sizeof(int16))
#define NUMERIC_HDRSZ_SHORT (VARHDRSZ + sizeof(uint16))

/*
 * If the flag bits are NUMERIC_SHORT or NUMERIC_NAN, we want the short header;
 * otherwise, we want the long one.  Instead of testing against each value, we
 * can just look at the high bit, for a slight efficiency gain.
 */
#define NUMERIC_HEADER_IS_SHORT(n)	(((n)->choice.n_header & 0x8000) != 0)
#define NUMERIC_HEADER_SIZE(n) \
	(VARHDRSZ + sizeof(uint16) + \
	 (NUMERIC_HEADER_IS_SHORT(n) ? 0 : sizeof(int16)))

/*
 * Short format definitions.
 *
 * Layout of n_short.n_header below the two flag bits:
 *	0x2000	sign (set = negative)
 *	0x1F80	display scale, 6 bits, so at most 63
 *	0x0040	weight sign bit
 *	0x003F	low weight bits
 * The weight is a 7-bit two's complement value, i.e. in [-64, 63].
 */

#define NUMERIC_SHORT_SIGN_MASK			0x2000
#define NUMERIC_SHORT_DSCALE_MASK		0x1F80
#define NUMERIC_SHORT_DSCALE_SHIFT		7
#define NUMERIC_SHORT_DSCALE_MAX		\
	(NUMERIC_SHORT_DSCALE_MASK >> NUMERIC_SHORT_DSCALE_SHIFT)
#define NUMERIC_SHORT_WEIGHT_SIGN_MASK	0x0040
#define NUMERIC_SHORT_WEIGHT_MASK		0x003F
#define NUMERIC_SHORT_WEIGHT_MAX		NUMERIC_SHORT_WEIGHT_MASK
#define NUMERIC_SHORT_WEIGHT_MIN		(-(NUMERIC_SHORT_WEIGHT_MASK+1))

/*
 * Extract sign, display scale, weight.
 *
 * NUMERIC_SIGN yields NUMERIC_POS or NUMERIC_NEG for both formats, and
 * NUMERIC_NAN for a NaN.  NUMERIC_DSCALE and NUMERIC_WEIGHT decode the short
 * header whenever the high bit is set; NUMERIC_WEIGHT sign-extends the short
 * form's weight by OR-ing in ~NUMERIC_SHORT_WEIGHT_MASK when the weight sign
 * bit is set.
 */

#define NUMERIC_DSCALE_MASK			0x3FFF
#define NUMERIC_DSCALE_MAX			NUMERIC_DSCALE_MASK

#define NUMERIC_SIGN(n) \
	(NUMERIC_IS_SHORT(n) ? \
		(((n)->choice.n_short.n_header & NUMERIC_SHORT_SIGN_MASK) ? \
		NUMERIC_NEG : NUMERIC_POS) : NUMERIC_FLAGBITS(n))
#define NUMERIC_DSCALE(n)	(NUMERIC_HEADER_IS_SHORT((n)) ? \
	((n)->choice.n_short.n_header & NUMERIC_SHORT_DSCALE_MASK) \
		>> NUMERIC_SHORT_DSCALE_SHIFT \
	: ((n)->choice.n_long.n_sign_dscale & NUMERIC_DSCALE_MASK))
#define NUMERIC_WEIGHT(n)	(NUMERIC_HEADER_IS_SHORT((n)) ? \
	(((n)->choice.n_short.n_header & NUMERIC_SHORT_WEIGHT_SIGN_MASK ? \
		~NUMERIC_SHORT_WEIGHT_MASK : 0) \
	 | ((n)->choice.n_short.n_header & NUMERIC_SHORT_WEIGHT_MASK)) \
	: ((n)->choice.n_long.n_weight))
223 
224 /* ----------
225  * NumericVar is the format we use for arithmetic.  The digit-array part
226  * is the same as the NumericData storage format, but the header is more
227  * complex.
228  *
229  * The value represented by a NumericVar is determined by the sign, weight,
230  * ndigits, and digits[] array.
231  *
232  * Note: the first digit of a NumericVar's value is assumed to be multiplied
233  * by NBASE ** weight.  Another way to say it is that there are weight+1
234  * digits before the decimal point.  It is possible to have weight < 0.
235  *
236  * buf points at the physical start of the palloc'd digit buffer for the
237  * NumericVar.  digits points at the first digit in actual use (the one
238  * with the specified weight).  We normally leave an unused digit or two
239  * (preset to zeroes) between buf and digits, so that there is room to store
240  * a carry out of the top digit without reallocating space.  We just need to
241  * decrement digits (and increment weight) to make room for the carry digit.
242  * (There is no such extra space in a numeric value stored in the database,
243  * only in a NumericVar in memory.)
244  *
245  * If buf is NULL then the digit buffer isn't actually palloc'd and should
246  * not be freed --- see the constants below for an example.
247  *
248  * dscale, or display scale, is the nominal precision expressed as number
249  * of digits after the decimal point (it must always be >= 0 at present).
250  * dscale may be more than the number of physically stored fractional digits,
251  * implying that we have suppressed storage of significant trailing zeroes.
252  * It should never be less than the number of stored digits, since that would
253  * imply hiding digits that are present.  NOTE that dscale is always expressed
254  * in *decimal* digits, and so it may correspond to a fractional number of
255  * base-NBASE digits --- divide by DEC_DIGITS to convert to NBASE digits.
256  *
257  * rscale, or result scale, is the target precision for a computation.
258  * Like dscale it is expressed as number of *decimal* digits after the decimal
259  * point, and is always >= 0 at present.
260  * Note that rscale is not stored in variables --- it's figured on-the-fly
261  * from the dscales of the inputs.
262  *
263  * While we consistently use "weight" to refer to the base-NBASE weight of
264  * a numeric value, it is convenient in some scale-related calculations to
265  * make use of the base-10 weight (ie, the approximate log10 of the value).
266  * To avoid confusion, such a decimal-units weight is called a "dweight".
267  *
268  * NB: All the variable-level functions are written in a style that makes it
269  * possible to give one and the same variable as argument and destination.
270  * This is feasible because the digit buffer is separate from the variable.
271  * ----------
272  */
/* In-memory working representation of a numeric value (see comment above) */
typedef struct NumericVar
{
	int			ndigits;		/* # of digits in digits[] - can be 0! */
	int			weight;			/* weight of first digit: digits[0] is
								 * scaled by NBASE^weight */
	int			sign;			/* NUMERIC_POS, NUMERIC_NEG, or NUMERIC_NAN */
	int			dscale;			/* display scale, in decimal digits */
	NumericDigit *buf;			/* start of palloc'd space for digits[], or
								 * NULL if the digits are not ours to free */
	NumericDigit *digits;		/* base-NBASE digits */
} NumericVar;
282 
283 
/* ----------
 * Data for generate_series
 *
 * Per-series state for the numeric generate_series() set-returning
 * function.  NOTE(review): the SRF that uses this is outside this view; the
 * field descriptions below are inferred from the names.
 * ----------
 */
typedef struct
{
	NumericVar	current;		/* presumably the next value to emit */
	NumericVar	stop;			/* presumably the series' end bound */
	NumericVar	step;			/* presumably the increment per row */
} generate_series_numeric_fctx;
294 
295 
/* ----------
 * Sort support.
 *
 * Private state for numeric abbreviated-key sorting (see the
 * numeric_abbrev_convert / numeric_abbrev_abort prototypes below).
 * ----------
 */
typedef struct
{
	void	   *buf;			/* buffer for short varlenas */
	int64		input_count;	/* number of non-null values seen */
	bool		estimating;		/* true if estimating cardinality */

	hyperLogLogState abbr_card; /* cardinality estimator (lib/hyperloglog.h) */
} NumericSortSupport;
308 
309 
310 /* ----------
311  * Fast sum accumulator.
312  *
313  * NumericSumAccum is used to implement SUM(), and other standard aggregates
314  * that track the sum of input values.  It uses 32-bit integers to store the
315  * digits, instead of the normal 16-bit integers (with NBASE=10000).  This
316  * way, we can safely accumulate up to NBASE - 1 values without propagating
317  * carry, before risking overflow of any of the digits.  'num_uncarried'
318  * tracks how many values have been accumulated without propagating carry.
319  *
320  * Positive and negative values are accumulated separately, in 'pos_digits'
321  * and 'neg_digits'.  This is simpler and faster than deciding whether to add
322  * or subtract from the current value, for each new value (see sub_var() for
323  * the logic we avoid by doing this).  Both buffers are of same size, and
324  * have the same weight and scale.  In accum_sum_final(), the positive and
325  * negative sums are added together to produce the final result.
326  *
327  * When a new value has a larger ndigits or weight than the accumulator
328  * currently does, the accumulator is enlarged to accommodate the new value.
329  * We normally have one zero digit reserved for carry propagation, and that
330  * is indicated by the 'have_carry_space' flag.  When accum_sum_carry() uses
331  * up the reserved digit, it clears the 'have_carry_space' flag.  The next
332  * call to accum_sum_add() will enlarge the buffer, to make room for the
333  * extra digit, and set the flag again.
334  *
335  * To initialize a new accumulator, simply reset all fields to zeros.
336  *
337  * The accumulator does not handle NaNs.
338  * ----------
339  */
typedef struct NumericSumAccum
{
	int			ndigits;		/* # of digits in each of the two buffers */
	int			weight;			/* weight shared by both buffers */
	int			dscale;			/* display scale shared by both buffers */
	int			num_uncarried;	/* # of values added since the last carry
								 * propagation */
	bool		have_carry_space;	/* true if a spare zero digit is reserved
									 * for carry propagation */
	int32	   *pos_digits;		/* accumulator for positive inputs */
	int32	   *neg_digits;		/* accumulator for negative inputs */
} NumericSumAccum;
350 
351 
/*
 * We define our own macros for packing and unpacking abbreviated-key
 * representations for numeric values in order to avoid depending on
 * USE_FLOAT8_BYVAL.  The type of abbreviation we use is based only on
 * the size of a datum, not the argument-passing convention for float8.
 *
 * An abbreviated key is a signed integer as wide as a Datum (int64 or
 * int32); NUMERIC_ABBREV_NAN is the minimum value of that integer type.
 */
#define NUMERIC_ABBREV_BITS (SIZEOF_DATUM * BITS_PER_BYTE)
#if SIZEOF_DATUM == 8
#define NumericAbbrevGetDatum(X) ((Datum) (X))
#define DatumGetNumericAbbrev(X) ((int64) (X))
#define NUMERIC_ABBREV_NAN		 NumericAbbrevGetDatum(PG_INT64_MIN)
#else
#define NumericAbbrevGetDatum(X) ((Datum) (X))
#define DatumGetNumericAbbrev(X) ((int32) (X))
#define NUMERIC_ABBREV_NAN		 NumericAbbrevGetDatum(PG_INT32_MIN)
#endif
368 
369 
370 /* ----------
371  * Some preinitialized constants
372  * ----------
373  */
374 static const NumericDigit const_zero_data[1] = {0};
375 static const NumericVar const_zero =
376 {0, 0, NUMERIC_POS, 0, NULL, (NumericDigit *) const_zero_data};
377 
378 static const NumericDigit const_one_data[1] = {1};
379 static const NumericVar const_one =
380 {1, 0, NUMERIC_POS, 0, NULL, (NumericDigit *) const_one_data};
381 
382 static const NumericDigit const_two_data[1] = {2};
383 static const NumericVar const_two =
384 {1, 0, NUMERIC_POS, 0, NULL, (NumericDigit *) const_two_data};
385 
386 #if DEC_DIGITS == 4
387 static const NumericDigit const_zero_point_five_data[1] = {5000};
388 #elif DEC_DIGITS == 2
389 static const NumericDigit const_zero_point_five_data[1] = {50};
390 #elif DEC_DIGITS == 1
391 static const NumericDigit const_zero_point_five_data[1] = {5};
392 #endif
393 static const NumericVar const_zero_point_five =
394 {1, -1, NUMERIC_POS, 1, NULL, (NumericDigit *) const_zero_point_five_data};
395 
396 #if DEC_DIGITS == 4
397 static const NumericDigit const_zero_point_nine_data[1] = {9000};
398 #elif DEC_DIGITS == 2
399 static const NumericDigit const_zero_point_nine_data[1] = {90};
400 #elif DEC_DIGITS == 1
401 static const NumericDigit const_zero_point_nine_data[1] = {9};
402 #endif
403 static const NumericVar const_zero_point_nine =
404 {1, -1, NUMERIC_POS, 1, NULL, (NumericDigit *) const_zero_point_nine_data};
405 
406 #if DEC_DIGITS == 4
407 static const NumericDigit const_one_point_one_data[2] = {1, 1000};
408 #elif DEC_DIGITS == 2
409 static const NumericDigit const_one_point_one_data[2] = {1, 10};
410 #elif DEC_DIGITS == 1
411 static const NumericDigit const_one_point_one_data[2] = {1, 1};
412 #endif
413 static const NumericVar const_one_point_one =
414 {2, 0, NUMERIC_POS, 1, NULL, (NumericDigit *) const_one_point_one_data};
415 
416 static const NumericVar const_nan =
417 {0, 0, NUMERIC_NAN, 0, NULL, NULL};
418 
419 #if DEC_DIGITS == 4
420 static const int round_powers[4] = {0, 1000, 100, 10};
421 #endif
422 
423 
/* ----------
 * Local functions
 * ----------
 */

/*
 * Debug dump helpers are compiled only when NUMERIC_DEBUG is defined.
 * Otherwise the calls expand to nothing, so their arguments are not even
 * evaluated.
 */
#ifdef NUMERIC_DEBUG
static void dump_numeric(const char *str, Numeric num);
static void dump_var(const char *str, NumericVar *var);
#else
#define dump_numeric(s,n)
#define dump_var(s,v)
#endif
436 
/* palloc an (uninitialized) buffer able to hold ndigits NumericDigits */
#define digitbuf_alloc(ndigits)  \
	((NumericDigit *) palloc((ndigits) * sizeof(NumericDigit)))
/*
 * Release a digit buffer.  A NULL buf (e.g. that of a preinitialized
 * constant) is skipped rather than passed to pfree().
 */
#define digitbuf_free(buf)	\
	do { \
		 if ((buf) != NULL) \
			 pfree(buf); \
	} while (0)

/* Zero every field of a NumericVar (no digits, no buffer) */
#define init_var(v)		MemSetAligned(v, 0, sizeof(NumericVar))

/* Pointer to the first stored digit, just past whichever header is in use */
#define NUMERIC_DIGITS(num) (NUMERIC_HEADER_IS_SHORT(num) ? \
	(num)->choice.n_short.n_data : (num)->choice.n_long.n_data)
/* Number of stored digits, derived from the varlena size minus the header */
#define NUMERIC_NDIGITS(num) \
	((VARSIZE(num) - NUMERIC_HEADER_SIZE(num)) / sizeof(NumericDigit))
/* True if the scale and weight both fit in the short-header bit fields */
#define NUMERIC_CAN_BE_SHORT(scale,weight) \
	((scale) <= NUMERIC_SHORT_DSCALE_MAX && \
	(weight) <= NUMERIC_SHORT_WEIGHT_MAX && \
	(weight) >= NUMERIC_SHORT_WEIGHT_MIN)
455 
/* NumericVar digit-buffer management */
static void alloc_var(NumericVar *var, int ndigits);
static void free_var(NumericVar *var);
static void zero_var(NumericVar *var);

/* Conversions between text, packed Numeric, and NumericVar */
static const char *set_var_from_str(const char *str, const char *cp,
									NumericVar *dest);
static void set_var_from_num(Numeric value, NumericVar *dest);
static void init_var_from_num(Numeric num, NumericVar *dest);
static void set_var_from_var(const NumericVar *value, NumericVar *dest);
static char *get_str_from_var(const NumericVar *var);
static char *get_str_from_var_sci(const NumericVar *var, int rscale);

/*
 * Packing a NumericVar into a Numeric datum.  The _opt_error variant takes an
 * error flag -- presumably reported there instead of thrown; NOTE(review):
 * confirm against its definition.
 */
static Numeric make_result(const NumericVar *var);
static Numeric make_result_opt_error(const NumericVar *var, bool *error);

/* Enforce a typmod (precision/scale constraint) on a value */
static void apply_typmod(NumericVar *var, int32 typmod);

/* Conversions to and from machine integer and floating-point types */
static bool numericvar_to_int32(const NumericVar *var, int32 *result);
static bool numericvar_to_int64(const NumericVar *var, int64 *result);
static void int64_to_numericvar(int64 val, NumericVar *var);
#ifdef HAVE_INT128
static bool numericvar_to_int128(const NumericVar *var, int128 *result);
static void int128_to_numericvar(int128 val, NumericVar *var);
#endif
static double numeric_to_double_no_overflow(Numeric num);
static double numericvar_to_double_no_overflow(const NumericVar *var);

/* Sort support, including abbreviated keys */
static Datum numeric_abbrev_convert(Datum original_datum, SortSupport ssup);
static bool numeric_abbrev_abort(int memtupcount, SortSupport ssup);
static int	numeric_fast_cmp(Datum x, Datum y, SortSupport ssup);
static int	numeric_cmp_abbrev(Datum x, Datum y, SortSupport ssup);

static Datum numeric_abbrev_convert_var(const NumericVar *var,
										NumericSortSupport *nss);

/* Comparison and basic arithmetic on NumericVars */
static int	cmp_numerics(Numeric num1, Numeric num2);
static int	cmp_var(const NumericVar *var1, const NumericVar *var2);
static int	cmp_var_common(const NumericDigit *var1digits, int var1ndigits,
						   int var1weight, int var1sign,
						   const NumericDigit *var2digits, int var2ndigits,
						   int var2weight, int var2sign);
static void add_var(const NumericVar *var1, const NumericVar *var2,
					NumericVar *result);
static void sub_var(const NumericVar *var1, const NumericVar *var2,
					NumericVar *result);
static void mul_var(const NumericVar *var1, const NumericVar *var2,
					NumericVar *result,
					int rscale);
static void div_var(const NumericVar *var1, const NumericVar *var2,
					NumericVar *result,
					int rscale, bool round);
static void div_var_fast(const NumericVar *var1, const NumericVar *var2,
						 NumericVar *result, int rscale, bool round);
static int	select_div_scale(const NumericVar *var1, const NumericVar *var2);
static void mod_var(const NumericVar *var1, const NumericVar *var2,
					NumericVar *result);
static void ceil_var(const NumericVar *var, NumericVar *result);
static void floor_var(const NumericVar *var, NumericVar *result);

/* Roots, exponentials, logarithms and powers */
static void sqrt_var(const NumericVar *arg, NumericVar *result, int rscale);
static void exp_var(const NumericVar *arg, NumericVar *result, int rscale);
static int	estimate_ln_dweight(const NumericVar *var);
static void ln_var(const NumericVar *arg, NumericVar *result, int rscale);
static void log_var(const NumericVar *base, const NumericVar *num,
					NumericVar *result);
static void power_var(const NumericVar *base, const NumericVar *exp,
					  NumericVar *result);
static void power_var_int(const NumericVar *base, int exp, NumericVar *result,
						  int rscale);
static void power_ten_int(int exp, NumericVar *result);

/* Magnitude-only helpers, rounding/truncation, and width_bucket support */
static int	cmp_abs(const NumericVar *var1, const NumericVar *var2);
static int	cmp_abs_common(const NumericDigit *var1digits, int var1ndigits,
						   int var1weight,
						   const NumericDigit *var2digits, int var2ndigits,
						   int var2weight);
static void add_abs(const NumericVar *var1, const NumericVar *var2,
					NumericVar *result);
static void sub_abs(const NumericVar *var1, const NumericVar *var2,
					NumericVar *result);
static void round_var(NumericVar *var, int rscale);
static void trunc_var(NumericVar *var, int rscale);
static void strip_var(NumericVar *var);
static void compute_bucket(Numeric operand, Numeric bound1, Numeric bound2,
						   const NumericVar *count_var, NumericVar *result_var);

/* NumericSumAccum (fast SUM accumulator) operations */
static void accum_sum_add(NumericSumAccum *accum, const NumericVar *var1);
static void accum_sum_rescale(NumericSumAccum *accum, const NumericVar *val);
static void accum_sum_carry(NumericSumAccum *accum);
static void accum_sum_reset(NumericSumAccum *accum);
static void accum_sum_final(NumericSumAccum *accum, NumericVar *result);
static void accum_sum_copy(NumericSumAccum *dst, NumericSumAccum *src);
static void accum_sum_combine(NumericSumAccum *accum, NumericSumAccum *accum2);
549 
550 
551 /* ----------------------------------------------------------------------
552  *
553  * Input-, output- and rounding-functions
554  *
555  * ----------------------------------------------------------------------
556  */
557 
558 
559 /*
560  * numeric_in() -
561  *
562  *	Input function for numeric data type
563  */
564 Datum
numeric_in(PG_FUNCTION_ARGS)565 numeric_in(PG_FUNCTION_ARGS)
566 {
567 	char	   *str = PG_GETARG_CSTRING(0);
568 
569 #ifdef NOT_USED
570 	Oid			typelem = PG_GETARG_OID(1);
571 #endif
572 	int32		typmod = PG_GETARG_INT32(2);
573 	Numeric		res;
574 	const char *cp;
575 
576 	/* Skip leading spaces */
577 	cp = str;
578 	while (*cp)
579 	{
580 		if (!isspace((unsigned char) *cp))
581 			break;
582 		cp++;
583 	}
584 
585 	/*
586 	 * Check for NaN
587 	 */
588 	if (pg_strncasecmp(cp, "NaN", 3) == 0)
589 	{
590 		res = make_result(&const_nan);
591 
592 		/* Should be nothing left but spaces */
593 		cp += 3;
594 		while (*cp)
595 		{
596 			if (!isspace((unsigned char) *cp))
597 				ereport(ERROR,
598 						(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
599 						 errmsg("invalid input syntax for type %s: \"%s\"",
600 								"numeric", str)));
601 			cp++;
602 		}
603 	}
604 	else
605 	{
606 		/*
607 		 * Use set_var_from_str() to parse a normal numeric value
608 		 */
609 		NumericVar	value;
610 
611 		init_var(&value);
612 
613 		cp = set_var_from_str(str, cp, &value);
614 
615 		/*
616 		 * We duplicate a few lines of code here because we would like to
617 		 * throw any trailing-junk syntax error before any semantic error
618 		 * resulting from apply_typmod.  We can't easily fold the two cases
619 		 * together because we mustn't apply apply_typmod to a NaN.
620 		 */
621 		while (*cp)
622 		{
623 			if (!isspace((unsigned char) *cp))
624 				ereport(ERROR,
625 						(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
626 						 errmsg("invalid input syntax for type %s: \"%s\"",
627 								"numeric", str)));
628 			cp++;
629 		}
630 
631 		apply_typmod(&value, typmod);
632 
633 		res = make_result(&value);
634 		free_var(&value);
635 	}
636 
637 	PG_RETURN_NUMERIC(res);
638 }
639 
640 
641 /*
642  * numeric_out() -
643  *
644  *	Output function for numeric data type
645  */
646 Datum
numeric_out(PG_FUNCTION_ARGS)647 numeric_out(PG_FUNCTION_ARGS)
648 {
649 	Numeric		num = PG_GETARG_NUMERIC(0);
650 	NumericVar	x;
651 	char	   *str;
652 
653 	/*
654 	 * Handle NaN
655 	 */
656 	if (NUMERIC_IS_NAN(num))
657 		PG_RETURN_CSTRING(pstrdup("NaN"));
658 
659 	/*
660 	 * Get the number in the variable format.
661 	 */
662 	init_var_from_num(num, &x);
663 
664 	str = get_str_from_var(&x);
665 
666 	PG_RETURN_CSTRING(str);
667 }
668 
669 /*
670  * numeric_is_nan() -
671  *
672  *	Is Numeric value a NaN?
673  */
674 bool
numeric_is_nan(Numeric num)675 numeric_is_nan(Numeric num)
676 {
677 	return NUMERIC_IS_NAN(num);
678 }
679 
680 /*
681  * numeric_maximum_size() -
682  *
683  *	Maximum size of a numeric with given typmod, or -1 if unlimited/unknown.
684  */
685 int32
numeric_maximum_size(int32 typmod)686 numeric_maximum_size(int32 typmod)
687 {
688 	int			precision;
689 	int			numeric_digits;
690 
691 	if (typmod < (int32) (VARHDRSZ))
692 		return -1;
693 
694 	/* precision (ie, max # of digits) is in upper bits of typmod */
695 	precision = ((typmod - VARHDRSZ) >> 16) & 0xffff;
696 
697 	/*
698 	 * This formula computes the maximum number of NumericDigits we could need
699 	 * in order to store the specified number of decimal digits. Because the
700 	 * weight is stored as a number of NumericDigits rather than a number of
701 	 * decimal digits, it's possible that the first NumericDigit will contain
702 	 * only a single decimal digit.  Thus, the first two decimal digits can
703 	 * require two NumericDigits to store, but it isn't until we reach
704 	 * DEC_DIGITS + 2 decimal digits that we potentially need a third
705 	 * NumericDigit.
706 	 */
707 	numeric_digits = (precision + 2 * (DEC_DIGITS - 1)) / DEC_DIGITS;
708 
709 	/*
710 	 * In most cases, the size of a numeric will be smaller than the value
711 	 * computed below, because the varlena header will typically get toasted
712 	 * down to a single byte before being stored on disk, and it may also be
713 	 * possible to use a short numeric header.  But our job here is to compute
714 	 * the worst case.
715 	 */
716 	return NUMERIC_HDRSZ + (numeric_digits * sizeof(NumericDigit));
717 }
718 
719 /*
720  * numeric_out_sci() -
721  *
722  *	Output function for numeric data type in scientific notation.
723  */
724 char *
numeric_out_sci(Numeric num,int scale)725 numeric_out_sci(Numeric num, int scale)
726 {
727 	NumericVar	x;
728 	char	   *str;
729 
730 	/*
731 	 * Handle NaN
732 	 */
733 	if (NUMERIC_IS_NAN(num))
734 		return pstrdup("NaN");
735 
736 	init_var_from_num(num, &x);
737 
738 	str = get_str_from_var_sci(&x, scale);
739 
740 	return str;
741 }
742 
743 /*
744  * numeric_normalize() -
745  *
746  *	Output function for numeric data type, suppressing insignificant trailing
747  *	zeroes and then any trailing decimal point.  The intent of this is to
748  *	produce strings that are equal if and only if the input numeric values
749  *	compare equal.
750  */
751 char *
numeric_normalize(Numeric num)752 numeric_normalize(Numeric num)
753 {
754 	NumericVar	x;
755 	char	   *str;
756 	int			last;
757 
758 	/*
759 	 * Handle NaN
760 	 */
761 	if (NUMERIC_IS_NAN(num))
762 		return pstrdup("NaN");
763 
764 	init_var_from_num(num, &x);
765 
766 	str = get_str_from_var(&x);
767 
768 	/* If there's no decimal point, there's certainly nothing to remove. */
769 	if (strchr(str, '.') != NULL)
770 	{
771 		/*
772 		 * Back up over trailing fractional zeroes.  Since there is a decimal
773 		 * point, this loop will terminate safely.
774 		 */
775 		last = strlen(str) - 1;
776 		while (str[last] == '0')
777 			last--;
778 
779 		/* We want to get rid of the decimal point too, if it's now last. */
780 		if (str[last] == '.')
781 			last--;
782 
783 		/* Delete whatever we backed up over. */
784 		str[last + 1] = '\0';
785 	}
786 
787 	return str;
788 }
789 
790 /*
791  *		numeric_recv			- converts external binary format to numeric
792  *
793  * External format is a sequence of int16's:
794  * ndigits, weight, sign, dscale, NumericDigits.
795  */
796 Datum
numeric_recv(PG_FUNCTION_ARGS)797 numeric_recv(PG_FUNCTION_ARGS)
798 {
799 	StringInfo	buf = (StringInfo) PG_GETARG_POINTER(0);
800 
801 #ifdef NOT_USED
802 	Oid			typelem = PG_GETARG_OID(1);
803 #endif
804 	int32		typmod = PG_GETARG_INT32(2);
805 	NumericVar	value;
806 	Numeric		res;
807 	int			len,
808 				i;
809 
810 	init_var(&value);
811 
812 	len = (uint16) pq_getmsgint(buf, sizeof(uint16));
813 
814 	alloc_var(&value, len);
815 
816 	value.weight = (int16) pq_getmsgint(buf, sizeof(int16));
817 	/* we allow any int16 for weight --- OK? */
818 
819 	value.sign = (uint16) pq_getmsgint(buf, sizeof(uint16));
820 	if (!(value.sign == NUMERIC_POS ||
821 		  value.sign == NUMERIC_NEG ||
822 		  value.sign == NUMERIC_NAN))
823 		ereport(ERROR,
824 				(errcode(ERRCODE_INVALID_BINARY_REPRESENTATION),
825 				 errmsg("invalid sign in external \"numeric\" value")));
826 
827 	value.dscale = (uint16) pq_getmsgint(buf, sizeof(uint16));
828 	if ((value.dscale & NUMERIC_DSCALE_MASK) != value.dscale)
829 		ereport(ERROR,
830 				(errcode(ERRCODE_INVALID_BINARY_REPRESENTATION),
831 				 errmsg("invalid scale in external \"numeric\" value")));
832 
833 	for (i = 0; i < len; i++)
834 	{
835 		NumericDigit d = pq_getmsgint(buf, sizeof(NumericDigit));
836 
837 		if (d < 0 || d >= NBASE)
838 			ereport(ERROR,
839 					(errcode(ERRCODE_INVALID_BINARY_REPRESENTATION),
840 					 errmsg("invalid digit in external \"numeric\" value")));
841 		value.digits[i] = d;
842 	}
843 
844 	/*
845 	 * If the given dscale would hide any digits, truncate those digits away.
846 	 * We could alternatively throw an error, but that would take a bunch of
847 	 * extra code (about as much as trunc_var involves), and it might cause
848 	 * client compatibility issues.
849 	 */
850 	trunc_var(&value, value.dscale);
851 
852 	apply_typmod(&value, typmod);
853 
854 	res = make_result(&value);
855 	free_var(&value);
856 
857 	PG_RETURN_NUMERIC(res);
858 }
859 
860 /*
861  *		numeric_send			- converts numeric to binary format
862  */
863 Datum
numeric_send(PG_FUNCTION_ARGS)864 numeric_send(PG_FUNCTION_ARGS)
865 {
866 	Numeric		num = PG_GETARG_NUMERIC(0);
867 	NumericVar	x;
868 	StringInfoData buf;
869 	int			i;
870 
871 	init_var_from_num(num, &x);
872 
873 	pq_begintypsend(&buf);
874 
875 	pq_sendint16(&buf, x.ndigits);
876 	pq_sendint16(&buf, x.weight);
877 	pq_sendint16(&buf, x.sign);
878 	pq_sendint16(&buf, x.dscale);
879 	for (i = 0; i < x.ndigits; i++)
880 		pq_sendint16(&buf, x.digits[i]);
881 
882 	PG_RETURN_BYTEA_P(pq_endtypsend(&buf));
883 }
884 
885 
886 /*
887  * numeric_support()
888  *
889  * Planner support function for the numeric() length coercion function.
890  *
891  * Flatten calls that solely represent increases in allowable precision.
892  * Scale changes mutate every datum, so they are unoptimizable.  Some values,
893  * e.g. 1E-1001, can only fit into an unconstrained numeric, so a change from
894  * an unconstrained numeric to any constrained numeric is also unoptimizable.
895  */
896 Datum
numeric_support(PG_FUNCTION_ARGS)897 numeric_support(PG_FUNCTION_ARGS)
898 {
899 	Node	   *rawreq = (Node *) PG_GETARG_POINTER(0);
900 	Node	   *ret = NULL;
901 
902 	if (IsA(rawreq, SupportRequestSimplify))
903 	{
904 		SupportRequestSimplify *req = (SupportRequestSimplify *) rawreq;
905 		FuncExpr   *expr = req->fcall;
906 		Node	   *typmod;
907 
908 		Assert(list_length(expr->args) >= 2);
909 
910 		typmod = (Node *) lsecond(expr->args);
911 
912 		if (IsA(typmod, Const) &&!((Const *) typmod)->constisnull)
913 		{
914 			Node	   *source = (Node *) linitial(expr->args);
915 			int32		old_typmod = exprTypmod(source);
916 			int32		new_typmod = DatumGetInt32(((Const *) typmod)->constvalue);
917 			int32		old_scale = (old_typmod - VARHDRSZ) & 0xffff;
918 			int32		new_scale = (new_typmod - VARHDRSZ) & 0xffff;
919 			int32		old_precision = (old_typmod - VARHDRSZ) >> 16 & 0xffff;
920 			int32		new_precision = (new_typmod - VARHDRSZ) >> 16 & 0xffff;
921 
922 			/*
923 			 * If new_typmod < VARHDRSZ, the destination is unconstrained;
924 			 * that's always OK.  If old_typmod >= VARHDRSZ, the source is
925 			 * constrained, and we're OK if the scale is unchanged and the
926 			 * precision is not decreasing.  See further notes in function
927 			 * header comment.
928 			 */
929 			if (new_typmod < (int32) VARHDRSZ ||
930 				(old_typmod >= (int32) VARHDRSZ &&
931 				 new_scale == old_scale && new_precision >= old_precision))
932 				ret = relabel_to_typmod(source, new_typmod);
933 		}
934 	}
935 
936 	PG_RETURN_POINTER(ret);
937 }
938 
939 /*
940  * numeric() -
941  *
942  *	This is a special function called by the Postgres database system
943  *	before a value is stored in a tuple's attribute. The precision and
944  *	scale of the attribute have to be applied on the value.
945  */
946 Datum
numeric(PG_FUNCTION_ARGS)947 numeric		(PG_FUNCTION_ARGS)
948 {
949 	Numeric		num = PG_GETARG_NUMERIC(0);
950 	int32		typmod = PG_GETARG_INT32(1);
951 	Numeric		new;
952 	int32		tmp_typmod;
953 	int			precision;
954 	int			scale;
955 	int			ddigits;
956 	int			maxdigits;
957 	NumericVar	var;
958 
959 	/*
960 	 * Handle NaN
961 	 */
962 	if (NUMERIC_IS_NAN(num))
963 		PG_RETURN_NUMERIC(make_result(&const_nan));
964 
965 	/*
966 	 * If the value isn't a valid type modifier, simply return a copy of the
967 	 * input value
968 	 */
969 	if (typmod < (int32) (VARHDRSZ))
970 	{
971 		new = (Numeric) palloc(VARSIZE(num));
972 		memcpy(new, num, VARSIZE(num));
973 		PG_RETURN_NUMERIC(new);
974 	}
975 
976 	/*
977 	 * Get the precision and scale out of the typmod value
978 	 */
979 	tmp_typmod = typmod - VARHDRSZ;
980 	precision = (tmp_typmod >> 16) & 0xffff;
981 	scale = tmp_typmod & 0xffff;
982 	maxdigits = precision - scale;
983 
984 	/*
985 	 * If the number is certainly in bounds and due to the target scale no
986 	 * rounding could be necessary, just make a copy of the input and modify
987 	 * its scale fields, unless the larger scale forces us to abandon the
988 	 * short representation.  (Note we assume the existing dscale is
989 	 * honest...)
990 	 */
991 	ddigits = (NUMERIC_WEIGHT(num) + 1) * DEC_DIGITS;
992 	if (ddigits <= maxdigits && scale >= NUMERIC_DSCALE(num)
993 		&& (NUMERIC_CAN_BE_SHORT(scale, NUMERIC_WEIGHT(num))
994 			|| !NUMERIC_IS_SHORT(num)))
995 	{
996 		new = (Numeric) palloc(VARSIZE(num));
997 		memcpy(new, num, VARSIZE(num));
998 		if (NUMERIC_IS_SHORT(num))
999 			new->choice.n_short.n_header =
1000 				(num->choice.n_short.n_header & ~NUMERIC_SHORT_DSCALE_MASK)
1001 				| (scale << NUMERIC_SHORT_DSCALE_SHIFT);
1002 		else
1003 			new->choice.n_long.n_sign_dscale = NUMERIC_SIGN(new) |
1004 				((uint16) scale & NUMERIC_DSCALE_MASK);
1005 		PG_RETURN_NUMERIC(new);
1006 	}
1007 
1008 	/*
1009 	 * We really need to fiddle with things - unpack the number into a
1010 	 * variable and let apply_typmod() do it.
1011 	 */
1012 	init_var(&var);
1013 
1014 	set_var_from_num(num, &var);
1015 	apply_typmod(&var, typmod);
1016 	new = make_result(&var);
1017 
1018 	free_var(&var);
1019 
1020 	PG_RETURN_NUMERIC(new);
1021 }
1022 
1023 Datum
numerictypmodin(PG_FUNCTION_ARGS)1024 numerictypmodin(PG_FUNCTION_ARGS)
1025 {
1026 	ArrayType  *ta = PG_GETARG_ARRAYTYPE_P(0);
1027 	int32	   *tl;
1028 	int			n;
1029 	int32		typmod;
1030 
1031 	tl = ArrayGetIntegerTypmods(ta, &n);
1032 
1033 	if (n == 2)
1034 	{
1035 		if (tl[0] < 1 || tl[0] > NUMERIC_MAX_PRECISION)
1036 			ereport(ERROR,
1037 					(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
1038 					 errmsg("NUMERIC precision %d must be between 1 and %d",
1039 							tl[0], NUMERIC_MAX_PRECISION)));
1040 		if (tl[1] < 0 || tl[1] > tl[0])
1041 			ereport(ERROR,
1042 					(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
1043 					 errmsg("NUMERIC scale %d must be between 0 and precision %d",
1044 							tl[1], tl[0])));
1045 		typmod = ((tl[0] << 16) | tl[1]) + VARHDRSZ;
1046 	}
1047 	else if (n == 1)
1048 	{
1049 		if (tl[0] < 1 || tl[0] > NUMERIC_MAX_PRECISION)
1050 			ereport(ERROR,
1051 					(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
1052 					 errmsg("NUMERIC precision %d must be between 1 and %d",
1053 							tl[0], NUMERIC_MAX_PRECISION)));
1054 		/* scale defaults to zero */
1055 		typmod = (tl[0] << 16) + VARHDRSZ;
1056 	}
1057 	else
1058 	{
1059 		ereport(ERROR,
1060 				(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
1061 				 errmsg("invalid NUMERIC type modifier")));
1062 		typmod = 0;				/* keep compiler quiet */
1063 	}
1064 
1065 	PG_RETURN_INT32(typmod);
1066 }
1067 
1068 Datum
numerictypmodout(PG_FUNCTION_ARGS)1069 numerictypmodout(PG_FUNCTION_ARGS)
1070 {
1071 	int32		typmod = PG_GETARG_INT32(0);
1072 	char	   *res = (char *) palloc(64);
1073 
1074 	if (typmod >= 0)
1075 		snprintf(res, 64, "(%d,%d)",
1076 				 ((typmod - VARHDRSZ) >> 16) & 0xffff,
1077 				 (typmod - VARHDRSZ) & 0xffff);
1078 	else
1079 		*res = '\0';
1080 
1081 	PG_RETURN_CSTRING(res);
1082 }
1083 
1084 
1085 /* ----------------------------------------------------------------------
1086  *
1087  * Sign manipulation, rounding and the like
1088  *
1089  * ----------------------------------------------------------------------
1090  */
1091 
1092 Datum
numeric_abs(PG_FUNCTION_ARGS)1093 numeric_abs(PG_FUNCTION_ARGS)
1094 {
1095 	Numeric		num = PG_GETARG_NUMERIC(0);
1096 	Numeric		res;
1097 
1098 	/*
1099 	 * Handle NaN
1100 	 */
1101 	if (NUMERIC_IS_NAN(num))
1102 		PG_RETURN_NUMERIC(make_result(&const_nan));
1103 
1104 	/*
1105 	 * Do it the easy way directly on the packed format
1106 	 */
1107 	res = (Numeric) palloc(VARSIZE(num));
1108 	memcpy(res, num, VARSIZE(num));
1109 
1110 	if (NUMERIC_IS_SHORT(num))
1111 		res->choice.n_short.n_header =
1112 			num->choice.n_short.n_header & ~NUMERIC_SHORT_SIGN_MASK;
1113 	else
1114 		res->choice.n_long.n_sign_dscale = NUMERIC_POS | NUMERIC_DSCALE(num);
1115 
1116 	PG_RETURN_NUMERIC(res);
1117 }
1118 
1119 
1120 Datum
numeric_uminus(PG_FUNCTION_ARGS)1121 numeric_uminus(PG_FUNCTION_ARGS)
1122 {
1123 	Numeric		num = PG_GETARG_NUMERIC(0);
1124 	Numeric		res;
1125 
1126 	/*
1127 	 * Handle NaN
1128 	 */
1129 	if (NUMERIC_IS_NAN(num))
1130 		PG_RETURN_NUMERIC(make_result(&const_nan));
1131 
1132 	/*
1133 	 * Do it the easy way directly on the packed format
1134 	 */
1135 	res = (Numeric) palloc(VARSIZE(num));
1136 	memcpy(res, num, VARSIZE(num));
1137 
1138 	/*
1139 	 * The packed format is known to be totally zero digit trimmed always. So
1140 	 * we can identify a ZERO by the fact that there are no digits at all.  Do
1141 	 * nothing to a zero.
1142 	 */
1143 	if (NUMERIC_NDIGITS(num) != 0)
1144 	{
1145 		/* Else, flip the sign */
1146 		if (NUMERIC_IS_SHORT(num))
1147 			res->choice.n_short.n_header =
1148 				num->choice.n_short.n_header ^ NUMERIC_SHORT_SIGN_MASK;
1149 		else if (NUMERIC_SIGN(num) == NUMERIC_POS)
1150 			res->choice.n_long.n_sign_dscale =
1151 				NUMERIC_NEG | NUMERIC_DSCALE(num);
1152 		else
1153 			res->choice.n_long.n_sign_dscale =
1154 				NUMERIC_POS | NUMERIC_DSCALE(num);
1155 	}
1156 
1157 	PG_RETURN_NUMERIC(res);
1158 }
1159 
1160 
1161 Datum
numeric_uplus(PG_FUNCTION_ARGS)1162 numeric_uplus(PG_FUNCTION_ARGS)
1163 {
1164 	Numeric		num = PG_GETARG_NUMERIC(0);
1165 	Numeric		res;
1166 
1167 	res = (Numeric) palloc(VARSIZE(num));
1168 	memcpy(res, num, VARSIZE(num));
1169 
1170 	PG_RETURN_NUMERIC(res);
1171 }
1172 
1173 /*
1174  * numeric_sign() -
1175  *
1176  * returns -1 if the argument is less than 0, 0 if the argument is equal
1177  * to 0, and 1 if the argument is greater than zero.
1178  */
1179 Datum
numeric_sign(PG_FUNCTION_ARGS)1180 numeric_sign(PG_FUNCTION_ARGS)
1181 {
1182 	Numeric		num = PG_GETARG_NUMERIC(0);
1183 	Numeric		res;
1184 	NumericVar	result;
1185 
1186 	/*
1187 	 * Handle NaN
1188 	 */
1189 	if (NUMERIC_IS_NAN(num))
1190 		PG_RETURN_NUMERIC(make_result(&const_nan));
1191 
1192 	init_var(&result);
1193 
1194 	/*
1195 	 * The packed format is known to be totally zero digit trimmed always. So
1196 	 * we can identify a ZERO by the fact that there are no digits at all.
1197 	 */
1198 	if (NUMERIC_NDIGITS(num) == 0)
1199 		set_var_from_var(&const_zero, &result);
1200 	else
1201 	{
1202 		/*
1203 		 * And if there are some, we return a copy of ONE with the sign of our
1204 		 * argument
1205 		 */
1206 		set_var_from_var(&const_one, &result);
1207 		result.sign = NUMERIC_SIGN(num);
1208 	}
1209 
1210 	res = make_result(&result);
1211 	free_var(&result);
1212 
1213 	PG_RETURN_NUMERIC(res);
1214 }
1215 
1216 
1217 /*
1218  * numeric_round() -
1219  *
1220  *	Round a value to have 'scale' digits after the decimal point.
1221  *	We allow negative 'scale', implying rounding before the decimal
1222  *	point --- Oracle interprets rounding that way.
1223  */
1224 Datum
numeric_round(PG_FUNCTION_ARGS)1225 numeric_round(PG_FUNCTION_ARGS)
1226 {
1227 	Numeric		num = PG_GETARG_NUMERIC(0);
1228 	int32		scale = PG_GETARG_INT32(1);
1229 	Numeric		res;
1230 	NumericVar	arg;
1231 
1232 	/*
1233 	 * Handle NaN
1234 	 */
1235 	if (NUMERIC_IS_NAN(num))
1236 		PG_RETURN_NUMERIC(make_result(&const_nan));
1237 
1238 	/*
1239 	 * Limit the scale value to avoid possible overflow in calculations
1240 	 */
1241 	scale = Max(scale, -NUMERIC_MAX_RESULT_SCALE);
1242 	scale = Min(scale, NUMERIC_MAX_RESULT_SCALE);
1243 
1244 	/*
1245 	 * Unpack the argument and round it at the proper digit position
1246 	 */
1247 	init_var(&arg);
1248 	set_var_from_num(num, &arg);
1249 
1250 	round_var(&arg, scale);
1251 
1252 	/* We don't allow negative output dscale */
1253 	if (scale < 0)
1254 		arg.dscale = 0;
1255 
1256 	/*
1257 	 * Return the rounded result
1258 	 */
1259 	res = make_result(&arg);
1260 
1261 	free_var(&arg);
1262 	PG_RETURN_NUMERIC(res);
1263 }
1264 
1265 
1266 /*
1267  * numeric_trunc() -
1268  *
1269  *	Truncate a value to have 'scale' digits after the decimal point.
1270  *	We allow negative 'scale', implying a truncation before the decimal
1271  *	point --- Oracle interprets truncation that way.
1272  */
1273 Datum
numeric_trunc(PG_FUNCTION_ARGS)1274 numeric_trunc(PG_FUNCTION_ARGS)
1275 {
1276 	Numeric		num = PG_GETARG_NUMERIC(0);
1277 	int32		scale = PG_GETARG_INT32(1);
1278 	Numeric		res;
1279 	NumericVar	arg;
1280 
1281 	/*
1282 	 * Handle NaN
1283 	 */
1284 	if (NUMERIC_IS_NAN(num))
1285 		PG_RETURN_NUMERIC(make_result(&const_nan));
1286 
1287 	/*
1288 	 * Limit the scale value to avoid possible overflow in calculations
1289 	 */
1290 	scale = Max(scale, -NUMERIC_MAX_RESULT_SCALE);
1291 	scale = Min(scale, NUMERIC_MAX_RESULT_SCALE);
1292 
1293 	/*
1294 	 * Unpack the argument and truncate it at the proper digit position
1295 	 */
1296 	init_var(&arg);
1297 	set_var_from_num(num, &arg);
1298 
1299 	trunc_var(&arg, scale);
1300 
1301 	/* We don't allow negative output dscale */
1302 	if (scale < 0)
1303 		arg.dscale = 0;
1304 
1305 	/*
1306 	 * Return the truncated result
1307 	 */
1308 	res = make_result(&arg);
1309 
1310 	free_var(&arg);
1311 	PG_RETURN_NUMERIC(res);
1312 }
1313 
1314 
1315 /*
1316  * numeric_ceil() -
1317  *
1318  *	Return the smallest integer greater than or equal to the argument
1319  */
1320 Datum
numeric_ceil(PG_FUNCTION_ARGS)1321 numeric_ceil(PG_FUNCTION_ARGS)
1322 {
1323 	Numeric		num = PG_GETARG_NUMERIC(0);
1324 	Numeric		res;
1325 	NumericVar	result;
1326 
1327 	if (NUMERIC_IS_NAN(num))
1328 		PG_RETURN_NUMERIC(make_result(&const_nan));
1329 
1330 	init_var_from_num(num, &result);
1331 	ceil_var(&result, &result);
1332 
1333 	res = make_result(&result);
1334 	free_var(&result);
1335 
1336 	PG_RETURN_NUMERIC(res);
1337 }
1338 
1339 
1340 /*
1341  * numeric_floor() -
1342  *
1343  *	Return the largest integer equal to or less than the argument
1344  */
1345 Datum
numeric_floor(PG_FUNCTION_ARGS)1346 numeric_floor(PG_FUNCTION_ARGS)
1347 {
1348 	Numeric		num = PG_GETARG_NUMERIC(0);
1349 	Numeric		res;
1350 	NumericVar	result;
1351 
1352 	if (NUMERIC_IS_NAN(num))
1353 		PG_RETURN_NUMERIC(make_result(&const_nan));
1354 
1355 	init_var_from_num(num, &result);
1356 	floor_var(&result, &result);
1357 
1358 	res = make_result(&result);
1359 	free_var(&result);
1360 
1361 	PG_RETURN_NUMERIC(res);
1362 }
1363 
1364 
1365 /*
1366  * generate_series_numeric() -
1367  *
1368  *	Generate series of numeric.
1369  */
1370 Datum
generate_series_numeric(PG_FUNCTION_ARGS)1371 generate_series_numeric(PG_FUNCTION_ARGS)
1372 {
1373 	return generate_series_step_numeric(fcinfo);
1374 }
1375 
1376 Datum
generate_series_step_numeric(PG_FUNCTION_ARGS)1377 generate_series_step_numeric(PG_FUNCTION_ARGS)
1378 {
1379 	generate_series_numeric_fctx *fctx;
1380 	FuncCallContext *funcctx;
1381 	MemoryContext oldcontext;
1382 
1383 	if (SRF_IS_FIRSTCALL())
1384 	{
1385 		Numeric		start_num = PG_GETARG_NUMERIC(0);
1386 		Numeric		stop_num = PG_GETARG_NUMERIC(1);
1387 		NumericVar	steploc = const_one;
1388 
1389 		/* handle NaN in start and stop values */
1390 		if (NUMERIC_IS_NAN(start_num))
1391 			ereport(ERROR,
1392 					(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
1393 					 errmsg("start value cannot be NaN")));
1394 
1395 		if (NUMERIC_IS_NAN(stop_num))
1396 			ereport(ERROR,
1397 					(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
1398 					 errmsg("stop value cannot be NaN")));
1399 
1400 		/* see if we were given an explicit step size */
1401 		if (PG_NARGS() == 3)
1402 		{
1403 			Numeric		step_num = PG_GETARG_NUMERIC(2);
1404 
1405 			if (NUMERIC_IS_NAN(step_num))
1406 				ereport(ERROR,
1407 						(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
1408 						 errmsg("step size cannot be NaN")));
1409 
1410 			init_var_from_num(step_num, &steploc);
1411 
1412 			if (cmp_var(&steploc, &const_zero) == 0)
1413 				ereport(ERROR,
1414 						(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
1415 						 errmsg("step size cannot equal zero")));
1416 		}
1417 
1418 		/* create a function context for cross-call persistence */
1419 		funcctx = SRF_FIRSTCALL_INIT();
1420 
1421 		/*
1422 		 * Switch to memory context appropriate for multiple function calls.
1423 		 */
1424 		oldcontext = MemoryContextSwitchTo(funcctx->multi_call_memory_ctx);
1425 
1426 		/* allocate memory for user context */
1427 		fctx = (generate_series_numeric_fctx *)
1428 			palloc(sizeof(generate_series_numeric_fctx));
1429 
1430 		/*
1431 		 * Use fctx to keep state from call to call. Seed current with the
1432 		 * original start value. We must copy the start_num and stop_num
1433 		 * values rather than pointing to them, since we may have detoasted
1434 		 * them in the per-call context.
1435 		 */
1436 		init_var(&fctx->current);
1437 		init_var(&fctx->stop);
1438 		init_var(&fctx->step);
1439 
1440 		set_var_from_num(start_num, &fctx->current);
1441 		set_var_from_num(stop_num, &fctx->stop);
1442 		set_var_from_var(&steploc, &fctx->step);
1443 
1444 		funcctx->user_fctx = fctx;
1445 		MemoryContextSwitchTo(oldcontext);
1446 	}
1447 
1448 	/* stuff done on every call of the function */
1449 	funcctx = SRF_PERCALL_SETUP();
1450 
1451 	/*
1452 	 * Get the saved state and use current state as the result of this
1453 	 * iteration.
1454 	 */
1455 	fctx = funcctx->user_fctx;
1456 
1457 	if ((fctx->step.sign == NUMERIC_POS &&
1458 		 cmp_var(&fctx->current, &fctx->stop) <= 0) ||
1459 		(fctx->step.sign == NUMERIC_NEG &&
1460 		 cmp_var(&fctx->current, &fctx->stop) >= 0))
1461 	{
1462 		Numeric		result = make_result(&fctx->current);
1463 
1464 		/* switch to memory context appropriate for iteration calculation */
1465 		oldcontext = MemoryContextSwitchTo(funcctx->multi_call_memory_ctx);
1466 
1467 		/* increment current in preparation for next iteration */
1468 		add_var(&fctx->current, &fctx->step, &fctx->current);
1469 		MemoryContextSwitchTo(oldcontext);
1470 
1471 		/* do when there is more left to send */
1472 		SRF_RETURN_NEXT(funcctx, NumericGetDatum(result));
1473 	}
1474 	else
1475 		/* do when there is no more left */
1476 		SRF_RETURN_DONE(funcctx);
1477 }
1478 
1479 
1480 /*
1481  * Implements the numeric version of the width_bucket() function
1482  * defined by SQL2003. See also width_bucket_float8().
1483  *
1484  * 'bound1' and 'bound2' are the lower and upper bounds of the
1485  * histogram's range, respectively. 'count' is the number of buckets
1486  * in the histogram. width_bucket() returns an integer indicating the
1487  * bucket number that 'operand' belongs to in an equiwidth histogram
1488  * with the specified characteristics. An operand smaller than the
1489  * lower bound is assigned to bucket 0. An operand greater than the
1490  * upper bound is assigned to an additional bucket (with number
1491  * count+1). We don't allow "NaN" for any of the numeric arguments.
1492  */
1493 Datum
width_bucket_numeric(PG_FUNCTION_ARGS)1494 width_bucket_numeric(PG_FUNCTION_ARGS)
1495 {
1496 	Numeric		operand = PG_GETARG_NUMERIC(0);
1497 	Numeric		bound1 = PG_GETARG_NUMERIC(1);
1498 	Numeric		bound2 = PG_GETARG_NUMERIC(2);
1499 	int32		count = PG_GETARG_INT32(3);
1500 	NumericVar	count_var;
1501 	NumericVar	result_var;
1502 	int32		result;
1503 
1504 	if (count <= 0)
1505 		ereport(ERROR,
1506 				(errcode(ERRCODE_INVALID_ARGUMENT_FOR_WIDTH_BUCKET_FUNCTION),
1507 				 errmsg("count must be greater than zero")));
1508 
1509 	if (NUMERIC_IS_NAN(operand) ||
1510 		NUMERIC_IS_NAN(bound1) ||
1511 		NUMERIC_IS_NAN(bound2))
1512 		ereport(ERROR,
1513 				(errcode(ERRCODE_INVALID_ARGUMENT_FOR_WIDTH_BUCKET_FUNCTION),
1514 				 errmsg("operand, lower bound, and upper bound cannot be NaN")));
1515 
1516 	init_var(&result_var);
1517 	init_var(&count_var);
1518 
1519 	/* Convert 'count' to a numeric, for ease of use later */
1520 	int64_to_numericvar((int64) count, &count_var);
1521 
1522 	switch (cmp_numerics(bound1, bound2))
1523 	{
1524 		case 0:
1525 			ereport(ERROR,
1526 					(errcode(ERRCODE_INVALID_ARGUMENT_FOR_WIDTH_BUCKET_FUNCTION),
1527 					 errmsg("lower bound cannot equal upper bound")));
1528 			break;
1529 
1530 			/* bound1 < bound2 */
1531 		case -1:
1532 			if (cmp_numerics(operand, bound1) < 0)
1533 				set_var_from_var(&const_zero, &result_var);
1534 			else if (cmp_numerics(operand, bound2) >= 0)
1535 				add_var(&count_var, &const_one, &result_var);
1536 			else
1537 				compute_bucket(operand, bound1, bound2,
1538 							   &count_var, &result_var);
1539 			break;
1540 
1541 			/* bound1 > bound2 */
1542 		case 1:
1543 			if (cmp_numerics(operand, bound1) > 0)
1544 				set_var_from_var(&const_zero, &result_var);
1545 			else if (cmp_numerics(operand, bound2) <= 0)
1546 				add_var(&count_var, &const_one, &result_var);
1547 			else
1548 				compute_bucket(operand, bound1, bound2,
1549 							   &count_var, &result_var);
1550 			break;
1551 	}
1552 
1553 	/* if result exceeds the range of a legal int4, we ereport here */
1554 	if (!numericvar_to_int32(&result_var, &result))
1555 		ereport(ERROR,
1556 				(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
1557 				 errmsg("integer out of range")));
1558 
1559 	free_var(&count_var);
1560 	free_var(&result_var);
1561 
1562 	PG_RETURN_INT32(result);
1563 }
1564 
1565 /*
1566  * If 'operand' is not outside the bucket range, determine the correct
1567  * bucket for it to go. The calculations performed by this function
1568  * are derived directly from the SQL2003 spec.
1569  */
1570 static void
compute_bucket(Numeric operand,Numeric bound1,Numeric bound2,const NumericVar * count_var,NumericVar * result_var)1571 compute_bucket(Numeric operand, Numeric bound1, Numeric bound2,
1572 			   const NumericVar *count_var, NumericVar *result_var)
1573 {
1574 	NumericVar	bound1_var;
1575 	NumericVar	bound2_var;
1576 	NumericVar	operand_var;
1577 
1578 	init_var_from_num(bound1, &bound1_var);
1579 	init_var_from_num(bound2, &bound2_var);
1580 	init_var_from_num(operand, &operand_var);
1581 
1582 	if (cmp_var(&bound1_var, &bound2_var) < 0)
1583 	{
1584 		sub_var(&operand_var, &bound1_var, &operand_var);
1585 		sub_var(&bound2_var, &bound1_var, &bound2_var);
1586 		div_var(&operand_var, &bound2_var, result_var,
1587 				select_div_scale(&operand_var, &bound2_var), true);
1588 	}
1589 	else
1590 	{
1591 		sub_var(&bound1_var, &operand_var, &operand_var);
1592 		sub_var(&bound1_var, &bound2_var, &bound1_var);
1593 		div_var(&operand_var, &bound1_var, result_var,
1594 				select_div_scale(&operand_var, &bound1_var), true);
1595 	}
1596 
1597 	mul_var(result_var, count_var, result_var,
1598 			result_var->dscale + count_var->dscale);
1599 	add_var(result_var, &const_one, result_var);
1600 	floor_var(result_var, result_var);
1601 
1602 	free_var(&bound1_var);
1603 	free_var(&bound2_var);
1604 	free_var(&operand_var);
1605 }
1606 
1607 /* ----------------------------------------------------------------------
1608  *
1609  * Comparison functions
1610  *
1611  * Note: btree indexes need these routines not to leak memory; therefore,
1612  * be careful to free working copies of toasted datums.  Most places don't
1613  * need to be so careful.
1614  *
1615  * Sort support:
1616  *
1617  * We implement the sortsupport strategy routine in order to get the benefit of
1618  * abbreviation. The ordinary numeric comparison can be quite slow as a result
1619  * of palloc/pfree cycles (due to detoasting packed values for alignment);
1620  * while this could be worked on itself, the abbreviation strategy gives more
1621  * speedup in many common cases.
1622  *
1623  * Two different representations are used for the abbreviated form, one in
1624  * int32 and one in int64, whichever fits into a by-value Datum.  In both cases
1625  * the representation is negated relative to the original value, because we use
1626  * the largest negative value for NaN, which sorts higher than other values. We
1627  * convert the absolute value of the numeric to a 31-bit or 63-bit positive
1628  * value, and then negate it if the original number was positive.
1629  *
1630  * We abort the abbreviation process if the abbreviation cardinality is below
1631  * 0.01% of the row count (1 per 10k non-null rows).  The actual break-even
1632  * point is somewhat below that, perhaps 1 per 30k (at 1 per 100k there's a
1633  * very small penalty), but we don't want to build up too many abbreviated
1634  * values before first testing for abort, so we take the slightly pessimistic
1635  * number.  We make no attempt to estimate the cardinality of the real values,
1636  * since it plays no part in the cost model here (if the abbreviation is equal,
1637  * the cost of comparing equal and unequal underlying values is comparable).
1638  * We discontinue even checking for abort (saving us the hashing overhead) if
1639  * the estimated cardinality gets to 100k; that would be enough to support many
1640  * billions of rows while doing no worse than breaking even.
1641  *
1642  * ----------------------------------------------------------------------
1643  */
1644 
1645 /*
1646  * Sort support strategy routine.
1647  */
1648 Datum
numeric_sortsupport(PG_FUNCTION_ARGS)1649 numeric_sortsupport(PG_FUNCTION_ARGS)
1650 {
1651 	SortSupport ssup = (SortSupport) PG_GETARG_POINTER(0);
1652 
1653 	ssup->comparator = numeric_fast_cmp;
1654 
1655 	if (ssup->abbreviate)
1656 	{
1657 		NumericSortSupport *nss;
1658 		MemoryContext oldcontext = MemoryContextSwitchTo(ssup->ssup_cxt);
1659 
1660 		nss = palloc(sizeof(NumericSortSupport));
1661 
1662 		/*
1663 		 * palloc a buffer for handling unaligned packed values in addition to
1664 		 * the support struct
1665 		 */
1666 		nss->buf = palloc(VARATT_SHORT_MAX + VARHDRSZ + 1);
1667 
1668 		nss->input_count = 0;
1669 		nss->estimating = true;
1670 		initHyperLogLog(&nss->abbr_card, 10);
1671 
1672 		ssup->ssup_extra = nss;
1673 
1674 		ssup->abbrev_full_comparator = ssup->comparator;
1675 		ssup->comparator = numeric_cmp_abbrev;
1676 		ssup->abbrev_converter = numeric_abbrev_convert;
1677 		ssup->abbrev_abort = numeric_abbrev_abort;
1678 
1679 		MemoryContextSwitchTo(oldcontext);
1680 	}
1681 
1682 	PG_RETURN_VOID();
1683 }
1684 
1685 /*
1686  * Abbreviate a numeric datum, handling NaNs and detoasting
1687  * (must not leak memory!)
1688  */
1689 static Datum
numeric_abbrev_convert(Datum original_datum,SortSupport ssup)1690 numeric_abbrev_convert(Datum original_datum, SortSupport ssup)
1691 {
1692 	NumericSortSupport *nss = ssup->ssup_extra;
1693 	void	   *original_varatt = PG_DETOAST_DATUM_PACKED(original_datum);
1694 	Numeric		value;
1695 	Datum		result;
1696 
1697 	nss->input_count += 1;
1698 
1699 	/*
1700 	 * This is to handle packed datums without needing a palloc/pfree cycle;
1701 	 * we keep and reuse a buffer large enough to handle any short datum.
1702 	 */
1703 	if (VARATT_IS_SHORT(original_varatt))
1704 	{
1705 		void	   *buf = nss->buf;
1706 		Size		sz = VARSIZE_SHORT(original_varatt) - VARHDRSZ_SHORT;
1707 
1708 		Assert(sz <= VARATT_SHORT_MAX - VARHDRSZ_SHORT);
1709 
1710 		SET_VARSIZE(buf, VARHDRSZ + sz);
1711 		memcpy(VARDATA(buf), VARDATA_SHORT(original_varatt), sz);
1712 
1713 		value = (Numeric) buf;
1714 	}
1715 	else
1716 		value = (Numeric) original_varatt;
1717 
1718 	if (NUMERIC_IS_NAN(value))
1719 	{
1720 		result = NUMERIC_ABBREV_NAN;
1721 	}
1722 	else
1723 	{
1724 		NumericVar	var;
1725 
1726 		init_var_from_num(value, &var);
1727 
1728 		result = numeric_abbrev_convert_var(&var, nss);
1729 	}
1730 
1731 	/* should happen only for external/compressed toasts */
1732 	if ((Pointer) original_varatt != DatumGetPointer(original_datum))
1733 		pfree(original_varatt);
1734 
1735 	return result;
1736 }
1737 
1738 /*
1739  * Consider whether to abort abbreviation.
1740  *
1741  * We pay no attention to the cardinality of the non-abbreviated data. There is
1742  * no reason to do so: unlike text, we have no fast check for equal values, so
1743  * we pay the full overhead whenever the abbreviations are equal regardless of
1744  * whether the underlying values are also equal.
1745  */
1746 static bool
numeric_abbrev_abort(int memtupcount,SortSupport ssup)1747 numeric_abbrev_abort(int memtupcount, SortSupport ssup)
1748 {
1749 	NumericSortSupport *nss = ssup->ssup_extra;
1750 	double		abbr_card;
1751 
1752 	if (memtupcount < 10000 || nss->input_count < 10000 || !nss->estimating)
1753 		return false;
1754 
1755 	abbr_card = estimateHyperLogLog(&nss->abbr_card);
1756 
1757 	/*
1758 	 * If we have >100k distinct values, then even if we were sorting many
1759 	 * billion rows we'd likely still break even, and the penalty of undoing
1760 	 * that many rows of abbrevs would probably not be worth it. Stop even
1761 	 * counting at that point.
1762 	 */
1763 	if (abbr_card > 100000.0)
1764 	{
1765 #ifdef TRACE_SORT
1766 		if (trace_sort)
1767 			elog(LOG,
1768 				 "numeric_abbrev: estimation ends at cardinality %f"
1769 				 " after " INT64_FORMAT " values (%d rows)",
1770 				 abbr_card, nss->input_count, memtupcount);
1771 #endif
1772 		nss->estimating = false;
1773 		return false;
1774 	}
1775 
1776 	/*
1777 	 * Target minimum cardinality is 1 per ~10k of non-null inputs.  (The
1778 	 * break even point is somewhere between one per 100k rows, where
1779 	 * abbreviation has a very slight penalty, and 1 per 10k where it wins by
1780 	 * a measurable percentage.)  We use the relatively pessimistic 10k
1781 	 * threshold, and add a 0.5 row fudge factor, because it allows us to
1782 	 * abort earlier on genuinely pathological data where we've had exactly
1783 	 * one abbreviated value in the first 10k (non-null) rows.
1784 	 */
1785 	if (abbr_card < nss->input_count / 10000.0 + 0.5)
1786 	{
1787 #ifdef TRACE_SORT
1788 		if (trace_sort)
1789 			elog(LOG,
1790 				 "numeric_abbrev: aborting abbreviation at cardinality %f"
1791 				 " below threshold %f after " INT64_FORMAT " values (%d rows)",
1792 				 abbr_card, nss->input_count / 10000.0 + 0.5,
1793 				 nss->input_count, memtupcount);
1794 #endif
1795 		return true;
1796 	}
1797 
1798 #ifdef TRACE_SORT
1799 	if (trace_sort)
1800 		elog(LOG,
1801 			 "numeric_abbrev: cardinality %f"
1802 			 " after " INT64_FORMAT " values (%d rows)",
1803 			 abbr_card, nss->input_count, memtupcount);
1804 #endif
1805 
1806 	return false;
1807 }
1808 
1809 /*
1810  * Non-fmgr interface to the comparison routine to allow sortsupport to elide
1811  * the fmgr call.  The saving here is small given how slow numeric comparisons
1812  * are, but it is a required part of the sort support API when abbreviations
1813  * are performed.
1814  *
1815  * Two palloc/pfree cycles could be saved here by using persistent buffers for
1816  * aligning short-varlena inputs, but this has not so far been considered to
1817  * be worth the effort.
1818  */
1819 static int
numeric_fast_cmp(Datum x,Datum y,SortSupport ssup)1820 numeric_fast_cmp(Datum x, Datum y, SortSupport ssup)
1821 {
1822 	Numeric		nx = DatumGetNumeric(x);
1823 	Numeric		ny = DatumGetNumeric(y);
1824 	int			result;
1825 
1826 	result = cmp_numerics(nx, ny);
1827 
1828 	if ((Pointer) nx != DatumGetPointer(x))
1829 		pfree(nx);
1830 	if ((Pointer) ny != DatumGetPointer(y))
1831 		pfree(ny);
1832 
1833 	return result;
1834 }
1835 
1836 /*
1837  * Compare abbreviations of values. (Abbreviations may be equal where the true
1838  * values differ, but if the abbreviations differ, they must reflect the
1839  * ordering of the true values.)
1840  */
1841 static int
numeric_cmp_abbrev(Datum x,Datum y,SortSupport ssup)1842 numeric_cmp_abbrev(Datum x, Datum y, SortSupport ssup)
1843 {
1844 	/*
1845 	 * NOTE WELL: this is intentionally backwards, because the abbreviation is
1846 	 * negated relative to the original value, to handle NaN.
1847 	 */
1848 	if (DatumGetNumericAbbrev(x) < DatumGetNumericAbbrev(y))
1849 		return 1;
1850 	if (DatumGetNumericAbbrev(x) > DatumGetNumericAbbrev(y))
1851 		return -1;
1852 	return 0;
1853 }
1854 
1855 /*
1856  * Abbreviate a NumericVar according to the available bit size.
1857  *
1858  * The 31-bit value is constructed as:
1859  *
1860  *	0 + 7bits digit weight + 24 bits digit value
1861  *
1862  * where the digit weight is in single decimal digits, not digit words, and
1863  * stored in excess-44 representation[1]. The 24-bit digit value is the 7 most
1864  * significant decimal digits of the value converted to binary. Values whose
1865  * weights would fall outside the representable range are rounded off to zero
1866  * (which is also used to represent actual zeros) or to 0x7FFFFFFF (which
1867  * otherwise cannot occur). Abbreviation therefore fails to gain any advantage
1868  * where values are outside the range 10^-44 to 10^83, which is not considered
1869  * to be a serious limitation, or when values are of the same magnitude and
1870  * equal in the first 7 decimal digits, which is considered to be an
1871  * unavoidable limitation given the available bits. (Stealing three more bits
1872  * to compare another digit would narrow the range of representable weights by
1873  * a factor of 8, which starts to look like a real limiting factor.)
1874  *
 * (The value 44 for the excess is essentially arbitrary.)
1876  *
1877  * The 63-bit value is constructed as:
1878  *
1879  *	0 + 7bits weight + 4 x 14-bit packed digit words
1880  *
1881  * The weight in this case is again stored in excess-44, but this time it is
1882  * the original weight in digit words (i.e. powers of 10000). The first four
1883  * digit words of the value (if present; trailing zeros are assumed as needed)
1884  * are packed into 14 bits each to form the rest of the value. Again,
1885  * out-of-range values are rounded off to 0 or 0x7FFFFFFFFFFFFFFF. The
1886  * representable range in this case is 10^-176 to 10^332, which is considered
1887  * to be good enough for all practical purposes, and comparison of 4 words
1888  * means that at least 13 decimal digits are compared, which is considered to
1889  * be a reasonable compromise between effectiveness and efficiency in computing
1890  * the abbreviation.
1891  *
 * (The value 44 for the excess is even more arbitrary here; it was chosen just
 * to match the value used in the 31-bit case.)
1894  *
1895  * [1] - Excess-k representation means that the value is offset by adding 'k'
1896  * and then treated as unsigned, so the smallest representable value is stored
1897  * with all bits zero. This allows simple comparisons to work on the composite
1898  * value.
1899  */
1900 
1901 #if NUMERIC_ABBREV_BITS == 64
1902 
1903 static Datum
numeric_abbrev_convert_var(const NumericVar * var,NumericSortSupport * nss)1904 numeric_abbrev_convert_var(const NumericVar *var, NumericSortSupport *nss)
1905 {
1906 	int			ndigits = var->ndigits;
1907 	int			weight = var->weight;
1908 	int64		result;
1909 
1910 	if (ndigits == 0 || weight < -44)
1911 	{
1912 		result = 0;
1913 	}
1914 	else if (weight > 83)
1915 	{
1916 		result = PG_INT64_MAX;
1917 	}
1918 	else
1919 	{
1920 		result = ((int64) (weight + 44) << 56);
1921 
1922 		switch (ndigits)
1923 		{
1924 			default:
1925 				result |= ((int64) var->digits[3]);
1926 				/* FALLTHROUGH */
1927 			case 3:
1928 				result |= ((int64) var->digits[2]) << 14;
1929 				/* FALLTHROUGH */
1930 			case 2:
1931 				result |= ((int64) var->digits[1]) << 28;
1932 				/* FALLTHROUGH */
1933 			case 1:
1934 				result |= ((int64) var->digits[0]) << 42;
1935 				break;
1936 		}
1937 	}
1938 
1939 	/* the abbrev is negated relative to the original */
1940 	if (var->sign == NUMERIC_POS)
1941 		result = -result;
1942 
1943 	if (nss->estimating)
1944 	{
1945 		uint32		tmp = ((uint32) result
1946 						   ^ (uint32) ((uint64) result >> 32));
1947 
1948 		addHyperLogLog(&nss->abbr_card, DatumGetUInt32(hash_uint32(tmp)));
1949 	}
1950 
1951 	return NumericAbbrevGetDatum(result);
1952 }
1953 
1954 #endif							/* NUMERIC_ABBREV_BITS == 64 */
1955 
1956 #if NUMERIC_ABBREV_BITS == 32
1957 
1958 static Datum
numeric_abbrev_convert_var(const NumericVar * var,NumericSortSupport * nss)1959 numeric_abbrev_convert_var(const NumericVar *var, NumericSortSupport *nss)
1960 {
1961 	int			ndigits = var->ndigits;
1962 	int			weight = var->weight;
1963 	int32		result;
1964 
1965 	if (ndigits == 0 || weight < -11)
1966 	{
1967 		result = 0;
1968 	}
1969 	else if (weight > 20)
1970 	{
1971 		result = PG_INT32_MAX;
1972 	}
1973 	else
1974 	{
1975 		NumericDigit nxt1 = (ndigits > 1) ? var->digits[1] : 0;
1976 
1977 		weight = (weight + 11) * 4;
1978 
1979 		result = var->digits[0];
1980 
1981 		/*
1982 		 * "result" now has 1 to 4 nonzero decimal digits. We pack in more
1983 		 * digits to make 7 in total (largest we can fit in 24 bits)
1984 		 */
1985 
1986 		if (result > 999)
1987 		{
1988 			/* already have 4 digits, add 3 more */
1989 			result = (result * 1000) + (nxt1 / 10);
1990 			weight += 3;
1991 		}
1992 		else if (result > 99)
1993 		{
1994 			/* already have 3 digits, add 4 more */
1995 			result = (result * 10000) + nxt1;
1996 			weight += 2;
1997 		}
1998 		else if (result > 9)
1999 		{
2000 			NumericDigit nxt2 = (ndigits > 2) ? var->digits[2] : 0;
2001 
2002 			/* already have 2 digits, add 5 more */
2003 			result = (result * 100000) + (nxt1 * 10) + (nxt2 / 1000);
2004 			weight += 1;
2005 		}
2006 		else
2007 		{
2008 			NumericDigit nxt2 = (ndigits > 2) ? var->digits[2] : 0;
2009 
2010 			/* already have 1 digit, add 6 more */
2011 			result = (result * 1000000) + (nxt1 * 100) + (nxt2 / 100);
2012 		}
2013 
2014 		result = result | (weight << 24);
2015 	}
2016 
2017 	/* the abbrev is negated relative to the original */
2018 	if (var->sign == NUMERIC_POS)
2019 		result = -result;
2020 
2021 	if (nss->estimating)
2022 	{
2023 		uint32		tmp = (uint32) result;
2024 
2025 		addHyperLogLog(&nss->abbr_card, DatumGetUInt32(hash_uint32(tmp)));
2026 	}
2027 
2028 	return NumericAbbrevGetDatum(result);
2029 }
2030 
2031 #endif							/* NUMERIC_ABBREV_BITS == 32 */
2032 
2033 /*
2034  * Ordinary (non-sortsupport) comparisons follow.
2035  */
2036 
2037 Datum
numeric_cmp(PG_FUNCTION_ARGS)2038 numeric_cmp(PG_FUNCTION_ARGS)
2039 {
2040 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2041 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2042 	int			result;
2043 
2044 	result = cmp_numerics(num1, num2);
2045 
2046 	PG_FREE_IF_COPY(num1, 0);
2047 	PG_FREE_IF_COPY(num2, 1);
2048 
2049 	PG_RETURN_INT32(result);
2050 }
2051 
2052 
2053 Datum
numeric_eq(PG_FUNCTION_ARGS)2054 numeric_eq(PG_FUNCTION_ARGS)
2055 {
2056 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2057 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2058 	bool		result;
2059 
2060 	result = cmp_numerics(num1, num2) == 0;
2061 
2062 	PG_FREE_IF_COPY(num1, 0);
2063 	PG_FREE_IF_COPY(num2, 1);
2064 
2065 	PG_RETURN_BOOL(result);
2066 }
2067 
2068 Datum
numeric_ne(PG_FUNCTION_ARGS)2069 numeric_ne(PG_FUNCTION_ARGS)
2070 {
2071 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2072 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2073 	bool		result;
2074 
2075 	result = cmp_numerics(num1, num2) != 0;
2076 
2077 	PG_FREE_IF_COPY(num1, 0);
2078 	PG_FREE_IF_COPY(num2, 1);
2079 
2080 	PG_RETURN_BOOL(result);
2081 }
2082 
2083 Datum
numeric_gt(PG_FUNCTION_ARGS)2084 numeric_gt(PG_FUNCTION_ARGS)
2085 {
2086 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2087 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2088 	bool		result;
2089 
2090 	result = cmp_numerics(num1, num2) > 0;
2091 
2092 	PG_FREE_IF_COPY(num1, 0);
2093 	PG_FREE_IF_COPY(num2, 1);
2094 
2095 	PG_RETURN_BOOL(result);
2096 }
2097 
2098 Datum
numeric_ge(PG_FUNCTION_ARGS)2099 numeric_ge(PG_FUNCTION_ARGS)
2100 {
2101 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2102 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2103 	bool		result;
2104 
2105 	result = cmp_numerics(num1, num2) >= 0;
2106 
2107 	PG_FREE_IF_COPY(num1, 0);
2108 	PG_FREE_IF_COPY(num2, 1);
2109 
2110 	PG_RETURN_BOOL(result);
2111 }
2112 
2113 Datum
numeric_lt(PG_FUNCTION_ARGS)2114 numeric_lt(PG_FUNCTION_ARGS)
2115 {
2116 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2117 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2118 	bool		result;
2119 
2120 	result = cmp_numerics(num1, num2) < 0;
2121 
2122 	PG_FREE_IF_COPY(num1, 0);
2123 	PG_FREE_IF_COPY(num2, 1);
2124 
2125 	PG_RETURN_BOOL(result);
2126 }
2127 
2128 Datum
numeric_le(PG_FUNCTION_ARGS)2129 numeric_le(PG_FUNCTION_ARGS)
2130 {
2131 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2132 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2133 	bool		result;
2134 
2135 	result = cmp_numerics(num1, num2) <= 0;
2136 
2137 	PG_FREE_IF_COPY(num1, 0);
2138 	PG_FREE_IF_COPY(num2, 1);
2139 
2140 	PG_RETURN_BOOL(result);
2141 }
2142 
2143 static int
cmp_numerics(Numeric num1,Numeric num2)2144 cmp_numerics(Numeric num1, Numeric num2)
2145 {
2146 	int			result;
2147 
2148 	/*
2149 	 * We consider all NANs to be equal and larger than any non-NAN. This is
2150 	 * somewhat arbitrary; the important thing is to have a consistent sort
2151 	 * order.
2152 	 */
2153 	if (NUMERIC_IS_NAN(num1))
2154 	{
2155 		if (NUMERIC_IS_NAN(num2))
2156 			result = 0;			/* NAN = NAN */
2157 		else
2158 			result = 1;			/* NAN > non-NAN */
2159 	}
2160 	else if (NUMERIC_IS_NAN(num2))
2161 	{
2162 		result = -1;			/* non-NAN < NAN */
2163 	}
2164 	else
2165 	{
2166 		result = cmp_var_common(NUMERIC_DIGITS(num1), NUMERIC_NDIGITS(num1),
2167 								NUMERIC_WEIGHT(num1), NUMERIC_SIGN(num1),
2168 								NUMERIC_DIGITS(num2), NUMERIC_NDIGITS(num2),
2169 								NUMERIC_WEIGHT(num2), NUMERIC_SIGN(num2));
2170 	}
2171 
2172 	return result;
2173 }
2174 
2175 /*
2176  * in_range support function for numeric.
2177  */
2178 Datum
in_range_numeric_numeric(PG_FUNCTION_ARGS)2179 in_range_numeric_numeric(PG_FUNCTION_ARGS)
2180 {
2181 	Numeric		val = PG_GETARG_NUMERIC(0);
2182 	Numeric		base = PG_GETARG_NUMERIC(1);
2183 	Numeric		offset = PG_GETARG_NUMERIC(2);
2184 	bool		sub = PG_GETARG_BOOL(3);
2185 	bool		less = PG_GETARG_BOOL(4);
2186 	bool		result;
2187 
2188 	/*
2189 	 * Reject negative or NaN offset.  Negative is per spec, and NaN is
2190 	 * because appropriate semantics for that seem non-obvious.
2191 	 */
2192 	if (NUMERIC_IS_NAN(offset) || NUMERIC_SIGN(offset) == NUMERIC_NEG)
2193 		ereport(ERROR,
2194 				(errcode(ERRCODE_INVALID_PRECEDING_OR_FOLLOWING_SIZE),
2195 				 errmsg("invalid preceding or following size in window function")));
2196 
2197 	/*
2198 	 * Deal with cases where val and/or base is NaN, following the rule that
2199 	 * NaN sorts after non-NaN (cf cmp_numerics).  The offset cannot affect
2200 	 * the conclusion.
2201 	 */
2202 	if (NUMERIC_IS_NAN(val))
2203 	{
2204 		if (NUMERIC_IS_NAN(base))
2205 			result = true;		/* NAN = NAN */
2206 		else
2207 			result = !less;		/* NAN > non-NAN */
2208 	}
2209 	else if (NUMERIC_IS_NAN(base))
2210 	{
2211 		result = less;			/* non-NAN < NAN */
2212 	}
2213 	else
2214 	{
2215 		/*
2216 		 * Otherwise go ahead and compute base +/- offset.  While it's
2217 		 * possible for this to overflow the numeric format, it's unlikely
2218 		 * enough that we don't take measures to prevent it.
2219 		 */
2220 		NumericVar	valv;
2221 		NumericVar	basev;
2222 		NumericVar	offsetv;
2223 		NumericVar	sum;
2224 
2225 		init_var_from_num(val, &valv);
2226 		init_var_from_num(base, &basev);
2227 		init_var_from_num(offset, &offsetv);
2228 		init_var(&sum);
2229 
2230 		if (sub)
2231 			sub_var(&basev, &offsetv, &sum);
2232 		else
2233 			add_var(&basev, &offsetv, &sum);
2234 
2235 		if (less)
2236 			result = (cmp_var(&valv, &sum) <= 0);
2237 		else
2238 			result = (cmp_var(&valv, &sum) >= 0);
2239 
2240 		free_var(&sum);
2241 	}
2242 
2243 	PG_FREE_IF_COPY(val, 0);
2244 	PG_FREE_IF_COPY(base, 1);
2245 	PG_FREE_IF_COPY(offset, 2);
2246 
2247 	PG_RETURN_BOOL(result);
2248 }
2249 
2250 Datum
hash_numeric(PG_FUNCTION_ARGS)2251 hash_numeric(PG_FUNCTION_ARGS)
2252 {
2253 	Numeric		key = PG_GETARG_NUMERIC(0);
2254 	Datum		digit_hash;
2255 	Datum		result;
2256 	int			weight;
2257 	int			start_offset;
2258 	int			end_offset;
2259 	int			i;
2260 	int			hash_len;
2261 	NumericDigit *digits;
2262 
2263 	/* If it's NaN, don't try to hash the rest of the fields */
2264 	if (NUMERIC_IS_NAN(key))
2265 		PG_RETURN_UINT32(0);
2266 
2267 	weight = NUMERIC_WEIGHT(key);
2268 	start_offset = 0;
2269 	end_offset = 0;
2270 
2271 	/*
2272 	 * Omit any leading or trailing zeros from the input to the hash. The
2273 	 * numeric implementation *should* guarantee that leading and trailing
2274 	 * zeros are suppressed, but we're paranoid. Note that we measure the
2275 	 * starting and ending offsets in units of NumericDigits, not bytes.
2276 	 */
2277 	digits = NUMERIC_DIGITS(key);
2278 	for (i = 0; i < NUMERIC_NDIGITS(key); i++)
2279 	{
2280 		if (digits[i] != (NumericDigit) 0)
2281 			break;
2282 
2283 		start_offset++;
2284 
2285 		/*
2286 		 * The weight is effectively the # of digits before the decimal point,
2287 		 * so decrement it for each leading zero we skip.
2288 		 */
2289 		weight--;
2290 	}
2291 
2292 	/*
2293 	 * If there are no non-zero digits, then the value of the number is zero,
2294 	 * regardless of any other fields.
2295 	 */
2296 	if (NUMERIC_NDIGITS(key) == start_offset)
2297 		PG_RETURN_UINT32(-1);
2298 
2299 	for (i = NUMERIC_NDIGITS(key) - 1; i >= 0; i--)
2300 	{
2301 		if (digits[i] != (NumericDigit) 0)
2302 			break;
2303 
2304 		end_offset++;
2305 	}
2306 
2307 	/* If we get here, there should be at least one non-zero digit */
2308 	Assert(start_offset + end_offset < NUMERIC_NDIGITS(key));
2309 
2310 	/*
2311 	 * Note that we don't hash on the Numeric's scale, since two numerics can
2312 	 * compare equal but have different scales. We also don't hash on the
2313 	 * sign, although we could: since a sign difference implies inequality,
2314 	 * this shouldn't affect correctness.
2315 	 */
2316 	hash_len = NUMERIC_NDIGITS(key) - start_offset - end_offset;
2317 	digit_hash = hash_any((unsigned char *) (NUMERIC_DIGITS(key) + start_offset),
2318 						  hash_len * sizeof(NumericDigit));
2319 
2320 	/* Mix in the weight, via XOR */
2321 	result = digit_hash ^ weight;
2322 
2323 	PG_RETURN_DATUM(result);
2324 }
2325 
2326 /*
2327  * Returns 64-bit value by hashing a value to a 64-bit value, with a seed.
2328  * Otherwise, similar to hash_numeric.
2329  */
2330 Datum
hash_numeric_extended(PG_FUNCTION_ARGS)2331 hash_numeric_extended(PG_FUNCTION_ARGS)
2332 {
2333 	Numeric		key = PG_GETARG_NUMERIC(0);
2334 	uint64		seed = PG_GETARG_INT64(1);
2335 	Datum		digit_hash;
2336 	Datum		result;
2337 	int			weight;
2338 	int			start_offset;
2339 	int			end_offset;
2340 	int			i;
2341 	int			hash_len;
2342 	NumericDigit *digits;
2343 
2344 	if (NUMERIC_IS_NAN(key))
2345 		PG_RETURN_UINT64(seed);
2346 
2347 	weight = NUMERIC_WEIGHT(key);
2348 	start_offset = 0;
2349 	end_offset = 0;
2350 
2351 	digits = NUMERIC_DIGITS(key);
2352 	for (i = 0; i < NUMERIC_NDIGITS(key); i++)
2353 	{
2354 		if (digits[i] != (NumericDigit) 0)
2355 			break;
2356 
2357 		start_offset++;
2358 
2359 		weight--;
2360 	}
2361 
2362 	if (NUMERIC_NDIGITS(key) == start_offset)
2363 		PG_RETURN_UINT64(seed - 1);
2364 
2365 	for (i = NUMERIC_NDIGITS(key) - 1; i >= 0; i--)
2366 	{
2367 		if (digits[i] != (NumericDigit) 0)
2368 			break;
2369 
2370 		end_offset++;
2371 	}
2372 
2373 	Assert(start_offset + end_offset < NUMERIC_NDIGITS(key));
2374 
2375 	hash_len = NUMERIC_NDIGITS(key) - start_offset - end_offset;
2376 	digit_hash = hash_any_extended((unsigned char *) (NUMERIC_DIGITS(key)
2377 													  + start_offset),
2378 								   hash_len * sizeof(NumericDigit),
2379 								   seed);
2380 
2381 	result = UInt64GetDatum(DatumGetUInt64(digit_hash) ^ weight);
2382 
2383 	PG_RETURN_DATUM(result);
2384 }
2385 
2386 
2387 /* ----------------------------------------------------------------------
2388  *
2389  * Basic arithmetic functions
2390  *
2391  * ----------------------------------------------------------------------
2392  */
2393 
2394 
2395 /*
2396  * numeric_add() -
2397  *
2398  *	Add two numerics
2399  */
2400 Datum
numeric_add(PG_FUNCTION_ARGS)2401 numeric_add(PG_FUNCTION_ARGS)
2402 {
2403 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2404 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2405 	Numeric		res;
2406 
2407 	res = numeric_add_opt_error(num1, num2, NULL);
2408 
2409 	PG_RETURN_NUMERIC(res);
2410 }
2411 
2412 /*
2413  * numeric_add_opt_error() -
2414  *
2415  *	Internal version of numeric_add().  If "*have_error" flag is provided,
2416  *	on error it's set to true, NULL returned.  This is helpful when caller
2417  *	need to handle errors by itself.
2418  */
2419 Numeric
numeric_add_opt_error(Numeric num1,Numeric num2,bool * have_error)2420 numeric_add_opt_error(Numeric num1, Numeric num2, bool *have_error)
2421 {
2422 	NumericVar	arg1;
2423 	NumericVar	arg2;
2424 	NumericVar	result;
2425 	Numeric		res;
2426 
2427 	/*
2428 	 * Handle NaN
2429 	 */
2430 	if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
2431 		return make_result(&const_nan);
2432 
2433 	/*
2434 	 * Unpack the values, let add_var() compute the result and return it.
2435 	 */
2436 	init_var_from_num(num1, &arg1);
2437 	init_var_from_num(num2, &arg2);
2438 
2439 	init_var(&result);
2440 	add_var(&arg1, &arg2, &result);
2441 
2442 	res = make_result_opt_error(&result, have_error);
2443 
2444 	free_var(&result);
2445 
2446 	return res;
2447 }
2448 
2449 
2450 /*
2451  * numeric_sub() -
2452  *
2453  *	Subtract one numeric from another
2454  */
2455 Datum
numeric_sub(PG_FUNCTION_ARGS)2456 numeric_sub(PG_FUNCTION_ARGS)
2457 {
2458 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2459 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2460 	Numeric		res;
2461 
2462 	res = numeric_sub_opt_error(num1, num2, NULL);
2463 
2464 	PG_RETURN_NUMERIC(res);
2465 }
2466 
2467 
2468 /*
2469  * numeric_sub_opt_error() -
2470  *
2471  *	Internal version of numeric_sub().  If "*have_error" flag is provided,
2472  *	on error it's set to true, NULL returned.  This is helpful when caller
2473  *	need to handle errors by itself.
2474  */
2475 Numeric
numeric_sub_opt_error(Numeric num1,Numeric num2,bool * have_error)2476 numeric_sub_opt_error(Numeric num1, Numeric num2, bool *have_error)
2477 {
2478 	NumericVar	arg1;
2479 	NumericVar	arg2;
2480 	NumericVar	result;
2481 	Numeric		res;
2482 
2483 	/*
2484 	 * Handle NaN
2485 	 */
2486 	if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
2487 		return make_result(&const_nan);
2488 
2489 	/*
2490 	 * Unpack the values, let sub_var() compute the result and return it.
2491 	 */
2492 	init_var_from_num(num1, &arg1);
2493 	init_var_from_num(num2, &arg2);
2494 
2495 	init_var(&result);
2496 	sub_var(&arg1, &arg2, &result);
2497 
2498 	res = make_result_opt_error(&result, have_error);
2499 
2500 	free_var(&result);
2501 
2502 	return res;
2503 }
2504 
2505 
2506 /*
2507  * numeric_mul() -
2508  *
2509  *	Calculate the product of two numerics
2510  */
2511 Datum
numeric_mul(PG_FUNCTION_ARGS)2512 numeric_mul(PG_FUNCTION_ARGS)
2513 {
2514 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2515 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2516 	Numeric		res;
2517 
2518 	res = numeric_mul_opt_error(num1, num2, NULL);
2519 
2520 	PG_RETURN_NUMERIC(res);
2521 }
2522 
2523 
2524 /*
2525  * numeric_mul_opt_error() -
2526  *
2527  *	Internal version of numeric_mul().  If "*have_error" flag is provided,
2528  *	on error it's set to true, NULL returned.  This is helpful when caller
2529  *	need to handle errors by itself.
2530  */
2531 Numeric
numeric_mul_opt_error(Numeric num1,Numeric num2,bool * have_error)2532 numeric_mul_opt_error(Numeric num1, Numeric num2, bool *have_error)
2533 {
2534 	NumericVar	arg1;
2535 	NumericVar	arg2;
2536 	NumericVar	result;
2537 	Numeric		res;
2538 
2539 	/*
2540 	 * Handle NaN
2541 	 */
2542 	if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
2543 		return make_result(&const_nan);
2544 
2545 	/*
2546 	 * Unpack the values, let mul_var() compute the result and return it.
2547 	 * Unlike add_var() and sub_var(), mul_var() will round its result. In the
2548 	 * case of numeric_mul(), which is invoked for the * operator on numerics,
2549 	 * we request exact representation for the product (rscale = sum(dscale of
2550 	 * arg1, dscale of arg2)).  If the exact result has more digits after the
2551 	 * decimal point than can be stored in a numeric, we round it.  Rounding
2552 	 * after computing the exact result ensures that the final result is
2553 	 * correctly rounded (rounding in mul_var() using a truncated product
2554 	 * would not guarantee this).
2555 	 */
2556 	init_var_from_num(num1, &arg1);
2557 	init_var_from_num(num2, &arg2);
2558 
2559 	init_var(&result);
2560 	mul_var(&arg1, &arg2, &result, arg1.dscale + arg2.dscale);
2561 
2562 	if (result.dscale > NUMERIC_DSCALE_MAX)
2563 		round_var(&result, NUMERIC_DSCALE_MAX);
2564 
2565 	res = make_result_opt_error(&result, have_error);
2566 
2567 	free_var(&result);
2568 
2569 	return res;
2570 }
2571 
2572 
2573 /*
2574  * numeric_div() -
2575  *
2576  *	Divide one numeric into another
2577  */
2578 Datum
numeric_div(PG_FUNCTION_ARGS)2579 numeric_div(PG_FUNCTION_ARGS)
2580 {
2581 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2582 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2583 	Numeric		res;
2584 
2585 	res = numeric_div_opt_error(num1, num2, NULL);
2586 
2587 	PG_RETURN_NUMERIC(res);
2588 }
2589 
2590 
2591 /*
2592  * numeric_div_opt_error() -
2593  *
2594  *	Internal version of numeric_div().  If "*have_error" flag is provided,
2595  *	on error it's set to true, NULL returned.  This is helpful when caller
2596  *	need to handle errors by itself.
2597  */
2598 Numeric
numeric_div_opt_error(Numeric num1,Numeric num2,bool * have_error)2599 numeric_div_opt_error(Numeric num1, Numeric num2, bool *have_error)
2600 {
2601 	NumericVar	arg1;
2602 	NumericVar	arg2;
2603 	NumericVar	result;
2604 	Numeric		res;
2605 	int			rscale;
2606 
2607 	/*
2608 	 * Handle NaN
2609 	 */
2610 	if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
2611 		return make_result(&const_nan);
2612 
2613 	/*
2614 	 * Unpack the arguments
2615 	 */
2616 	init_var_from_num(num1, &arg1);
2617 	init_var_from_num(num2, &arg2);
2618 
2619 	init_var(&result);
2620 
2621 	/*
2622 	 * Select scale for division result
2623 	 */
2624 	rscale = select_div_scale(&arg1, &arg2);
2625 
2626 	/*
2627 	 * If "have_error" is provided, check for division by zero here
2628 	 */
2629 	if (have_error && (arg2.ndigits == 0 || arg2.digits[0] == 0))
2630 	{
2631 		*have_error = true;
2632 		return NULL;
2633 	}
2634 
2635 	/*
2636 	 * Do the divide and return the result
2637 	 */
2638 	div_var(&arg1, &arg2, &result, rscale, true);
2639 
2640 	res = make_result_opt_error(&result, have_error);
2641 
2642 	free_var(&result);
2643 
2644 	return res;
2645 }
2646 
2647 
2648 /*
2649  * numeric_div_trunc() -
2650  *
2651  *	Divide one numeric into another, truncating the result to an integer
2652  */
2653 Datum
numeric_div_trunc(PG_FUNCTION_ARGS)2654 numeric_div_trunc(PG_FUNCTION_ARGS)
2655 {
2656 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2657 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2658 	NumericVar	arg1;
2659 	NumericVar	arg2;
2660 	NumericVar	result;
2661 	Numeric		res;
2662 
2663 	/*
2664 	 * Handle NaN
2665 	 */
2666 	if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
2667 		PG_RETURN_NUMERIC(make_result(&const_nan));
2668 
2669 	/*
2670 	 * Unpack the arguments
2671 	 */
2672 	init_var_from_num(num1, &arg1);
2673 	init_var_from_num(num2, &arg2);
2674 
2675 	init_var(&result);
2676 
2677 	/*
2678 	 * Do the divide and return the result
2679 	 */
2680 	div_var(&arg1, &arg2, &result, 0, false);
2681 
2682 	res = make_result(&result);
2683 
2684 	free_var(&result);
2685 
2686 	PG_RETURN_NUMERIC(res);
2687 }
2688 
2689 
2690 /*
2691  * numeric_mod() -
2692  *
2693  *	Calculate the modulo of two numerics
2694  */
2695 Datum
numeric_mod(PG_FUNCTION_ARGS)2696 numeric_mod(PG_FUNCTION_ARGS)
2697 {
2698 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2699 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2700 	Numeric		res;
2701 
2702 	res = numeric_mod_opt_error(num1, num2, NULL);
2703 
2704 	PG_RETURN_NUMERIC(res);
2705 }
2706 
2707 
2708 /*
2709  * numeric_mod_opt_error() -
2710  *
2711  *	Internal version of numeric_mod().  If "*have_error" flag is provided,
2712  *	on error it's set to true, NULL returned.  This is helpful when caller
2713  *	need to handle errors by itself.
2714  */
2715 Numeric
numeric_mod_opt_error(Numeric num1,Numeric num2,bool * have_error)2716 numeric_mod_opt_error(Numeric num1, Numeric num2, bool *have_error)
2717 {
2718 	Numeric		res;
2719 	NumericVar	arg1;
2720 	NumericVar	arg2;
2721 	NumericVar	result;
2722 
2723 	if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
2724 		return make_result(&const_nan);
2725 
2726 	init_var_from_num(num1, &arg1);
2727 	init_var_from_num(num2, &arg2);
2728 
2729 	init_var(&result);
2730 
2731 	/*
2732 	 * If "have_error" is provided, check for division by zero here
2733 	 */
2734 	if (have_error && (arg2.ndigits == 0 || arg2.digits[0] == 0))
2735 	{
2736 		*have_error = true;
2737 		return NULL;
2738 	}
2739 
2740 	mod_var(&arg1, &arg2, &result);
2741 
2742 	res = make_result_opt_error(&result, NULL);
2743 
2744 	free_var(&result);
2745 
2746 	return res;
2747 }
2748 
2749 
2750 /*
2751  * numeric_inc() -
2752  *
2753  *	Increment a number by one
2754  */
2755 Datum
numeric_inc(PG_FUNCTION_ARGS)2756 numeric_inc(PG_FUNCTION_ARGS)
2757 {
2758 	Numeric		num = PG_GETARG_NUMERIC(0);
2759 	NumericVar	arg;
2760 	Numeric		res;
2761 
2762 	/*
2763 	 * Handle NaN
2764 	 */
2765 	if (NUMERIC_IS_NAN(num))
2766 		PG_RETURN_NUMERIC(make_result(&const_nan));
2767 
2768 	/*
2769 	 * Compute the result and return it
2770 	 */
2771 	init_var_from_num(num, &arg);
2772 
2773 	add_var(&arg, &const_one, &arg);
2774 
2775 	res = make_result(&arg);
2776 
2777 	free_var(&arg);
2778 
2779 	PG_RETURN_NUMERIC(res);
2780 }
2781 
2782 
2783 /*
2784  * numeric_smaller() -
2785  *
2786  *	Return the smaller of two numbers
2787  */
2788 Datum
numeric_smaller(PG_FUNCTION_ARGS)2789 numeric_smaller(PG_FUNCTION_ARGS)
2790 {
2791 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2792 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2793 
2794 	/*
2795 	 * Use cmp_numerics so that this will agree with the comparison operators,
2796 	 * particularly as regards comparisons involving NaN.
2797 	 */
2798 	if (cmp_numerics(num1, num2) < 0)
2799 		PG_RETURN_NUMERIC(num1);
2800 	else
2801 		PG_RETURN_NUMERIC(num2);
2802 }
2803 
2804 
2805 /*
2806  * numeric_larger() -
2807  *
2808  *	Return the larger of two numbers
2809  */
2810 Datum
numeric_larger(PG_FUNCTION_ARGS)2811 numeric_larger(PG_FUNCTION_ARGS)
2812 {
2813 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2814 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2815 
2816 	/*
2817 	 * Use cmp_numerics so that this will agree with the comparison operators,
2818 	 * particularly as regards comparisons involving NaN.
2819 	 */
2820 	if (cmp_numerics(num1, num2) > 0)
2821 		PG_RETURN_NUMERIC(num1);
2822 	else
2823 		PG_RETURN_NUMERIC(num2);
2824 }
2825 
2826 
2827 /* ----------------------------------------------------------------------
2828  *
2829  * Advanced math functions
2830  *
2831  * ----------------------------------------------------------------------
2832  */
2833 
2834 /*
2835  * numeric_fac()
2836  *
2837  * Compute factorial
2838  */
2839 Datum
numeric_fac(PG_FUNCTION_ARGS)2840 numeric_fac(PG_FUNCTION_ARGS)
2841 {
2842 	int64		num = PG_GETARG_INT64(0);
2843 	Numeric		res;
2844 	NumericVar	fact;
2845 	NumericVar	result;
2846 
2847 	if (num <= 1)
2848 	{
2849 		res = make_result(&const_one);
2850 		PG_RETURN_NUMERIC(res);
2851 	}
2852 	/* Fail immediately if the result would overflow */
2853 	if (num > 32177)
2854 		ereport(ERROR,
2855 				(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
2856 				 errmsg("value overflows numeric format")));
2857 
2858 	init_var(&fact);
2859 	init_var(&result);
2860 
2861 	int64_to_numericvar(num, &result);
2862 
2863 	for (num = num - 1; num > 1; num--)
2864 	{
2865 		/* this loop can take awhile, so allow it to be interrupted */
2866 		CHECK_FOR_INTERRUPTS();
2867 
2868 		int64_to_numericvar(num, &fact);
2869 
2870 		mul_var(&result, &fact, &result, 0);
2871 	}
2872 
2873 	res = make_result(&result);
2874 
2875 	free_var(&fact);
2876 	free_var(&result);
2877 
2878 	PG_RETURN_NUMERIC(res);
2879 }
2880 
2881 
2882 /*
2883  * numeric_sqrt() -
2884  *
2885  *	Compute the square root of a numeric.
2886  */
2887 Datum
numeric_sqrt(PG_FUNCTION_ARGS)2888 numeric_sqrt(PG_FUNCTION_ARGS)
2889 {
2890 	Numeric		num = PG_GETARG_NUMERIC(0);
2891 	Numeric		res;
2892 	NumericVar	arg;
2893 	NumericVar	result;
2894 	int			sweight;
2895 	int			rscale;
2896 
2897 	/*
2898 	 * Handle NaN
2899 	 */
2900 	if (NUMERIC_IS_NAN(num))
2901 		PG_RETURN_NUMERIC(make_result(&const_nan));
2902 
2903 	/*
2904 	 * Unpack the argument and determine the result scale.  We choose a scale
2905 	 * to give at least NUMERIC_MIN_SIG_DIGITS significant digits; but in any
2906 	 * case not less than the input's dscale.
2907 	 */
2908 	init_var_from_num(num, &arg);
2909 
2910 	init_var(&result);
2911 
2912 	/* Assume the input was normalized, so arg.weight is accurate */
2913 	sweight = (arg.weight + 1) * DEC_DIGITS / 2 - 1;
2914 
2915 	rscale = NUMERIC_MIN_SIG_DIGITS - sweight;
2916 	rscale = Max(rscale, arg.dscale);
2917 	rscale = Max(rscale, NUMERIC_MIN_DISPLAY_SCALE);
2918 	rscale = Min(rscale, NUMERIC_MAX_DISPLAY_SCALE);
2919 
2920 	/*
2921 	 * Let sqrt_var() do the calculation and return the result.
2922 	 */
2923 	sqrt_var(&arg, &result, rscale);
2924 
2925 	res = make_result(&result);
2926 
2927 	free_var(&result);
2928 
2929 	PG_RETURN_NUMERIC(res);
2930 }
2931 
2932 
2933 /*
2934  * numeric_exp() -
2935  *
2936  *	Raise e to the power of x
2937  */
2938 Datum
numeric_exp(PG_FUNCTION_ARGS)2939 numeric_exp(PG_FUNCTION_ARGS)
2940 {
2941 	Numeric		num = PG_GETARG_NUMERIC(0);
2942 	Numeric		res;
2943 	NumericVar	arg;
2944 	NumericVar	result;
2945 	int			rscale;
2946 	double		val;
2947 
2948 	/*
2949 	 * Handle NaN
2950 	 */
2951 	if (NUMERIC_IS_NAN(num))
2952 		PG_RETURN_NUMERIC(make_result(&const_nan));
2953 
2954 	/*
2955 	 * Unpack the argument and determine the result scale.  We choose a scale
2956 	 * to give at least NUMERIC_MIN_SIG_DIGITS significant digits; but in any
2957 	 * case not less than the input's dscale.
2958 	 */
2959 	init_var_from_num(num, &arg);
2960 
2961 	init_var(&result);
2962 
2963 	/* convert input to float8, ignoring overflow */
2964 	val = numericvar_to_double_no_overflow(&arg);
2965 
2966 	/*
2967 	 * log10(result) = num * log10(e), so this is approximately the decimal
2968 	 * weight of the result:
2969 	 */
2970 	val *= 0.434294481903252;
2971 
2972 	/* limit to something that won't cause integer overflow */
2973 	val = Max(val, -NUMERIC_MAX_RESULT_SCALE);
2974 	val = Min(val, NUMERIC_MAX_RESULT_SCALE);
2975 
2976 	rscale = NUMERIC_MIN_SIG_DIGITS - (int) val;
2977 	rscale = Max(rscale, arg.dscale);
2978 	rscale = Max(rscale, NUMERIC_MIN_DISPLAY_SCALE);
2979 	rscale = Min(rscale, NUMERIC_MAX_DISPLAY_SCALE);
2980 
2981 	/*
2982 	 * Let exp_var() do the calculation and return the result.
2983 	 */
2984 	exp_var(&arg, &result, rscale);
2985 
2986 	res = make_result(&result);
2987 
2988 	free_var(&result);
2989 
2990 	PG_RETURN_NUMERIC(res);
2991 }
2992 
2993 
2994 /*
2995  * numeric_ln() -
2996  *
2997  *	Compute the natural logarithm of x
2998  */
2999 Datum
numeric_ln(PG_FUNCTION_ARGS)3000 numeric_ln(PG_FUNCTION_ARGS)
3001 {
3002 	Numeric		num = PG_GETARG_NUMERIC(0);
3003 	Numeric		res;
3004 	NumericVar	arg;
3005 	NumericVar	result;
3006 	int			ln_dweight;
3007 	int			rscale;
3008 
3009 	/*
3010 	 * Handle NaN
3011 	 */
3012 	if (NUMERIC_IS_NAN(num))
3013 		PG_RETURN_NUMERIC(make_result(&const_nan));
3014 
3015 	init_var_from_num(num, &arg);
3016 	init_var(&result);
3017 
3018 	/* Estimated dweight of logarithm */
3019 	ln_dweight = estimate_ln_dweight(&arg);
3020 
3021 	rscale = NUMERIC_MIN_SIG_DIGITS - ln_dweight;
3022 	rscale = Max(rscale, arg.dscale);
3023 	rscale = Max(rscale, NUMERIC_MIN_DISPLAY_SCALE);
3024 	rscale = Min(rscale, NUMERIC_MAX_DISPLAY_SCALE);
3025 
3026 	ln_var(&arg, &result, rscale);
3027 
3028 	res = make_result(&result);
3029 
3030 	free_var(&result);
3031 
3032 	PG_RETURN_NUMERIC(res);
3033 }
3034 
3035 
3036 /*
3037  * numeric_log() -
3038  *
3039  *	Compute the logarithm of x in a given base
3040  */
3041 Datum
numeric_log(PG_FUNCTION_ARGS)3042 numeric_log(PG_FUNCTION_ARGS)
3043 {
3044 	Numeric		num1 = PG_GETARG_NUMERIC(0);
3045 	Numeric		num2 = PG_GETARG_NUMERIC(1);
3046 	Numeric		res;
3047 	NumericVar	arg1;
3048 	NumericVar	arg2;
3049 	NumericVar	result;
3050 
3051 	/*
3052 	 * Handle NaN
3053 	 */
3054 	if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
3055 		PG_RETURN_NUMERIC(make_result(&const_nan));
3056 
3057 	/*
3058 	 * Initialize things
3059 	 */
3060 	init_var_from_num(num1, &arg1);
3061 	init_var_from_num(num2, &arg2);
3062 	init_var(&result);
3063 
3064 	/*
3065 	 * Call log_var() to compute and return the result; note it handles scale
3066 	 * selection itself.
3067 	 */
3068 	log_var(&arg1, &arg2, &result);
3069 
3070 	res = make_result(&result);
3071 
3072 	free_var(&result);
3073 
3074 	PG_RETURN_NUMERIC(res);
3075 }
3076 
3077 
3078 /*
3079  * numeric_power() -
3080  *
3081  *	Raise b to the power of x
3082  */
3083 Datum
numeric_power(PG_FUNCTION_ARGS)3084 numeric_power(PG_FUNCTION_ARGS)
3085 {
3086 	Numeric		num1 = PG_GETARG_NUMERIC(0);
3087 	Numeric		num2 = PG_GETARG_NUMERIC(1);
3088 	Numeric		res;
3089 	NumericVar	arg1;
3090 	NumericVar	arg2;
3091 	NumericVar	arg2_trunc;
3092 	NumericVar	result;
3093 
3094 	/*
3095 	 * Handle NaN cases.  We follow the POSIX spec for pow(3), which says that
3096 	 * NaN ^ 0 = 1, and 1 ^ NaN = 1, while all other cases with NaN inputs
3097 	 * yield NaN (with no error).
3098 	 */
3099 	if (NUMERIC_IS_NAN(num1))
3100 	{
3101 		if (!NUMERIC_IS_NAN(num2))
3102 		{
3103 			init_var_from_num(num2, &arg2);
3104 			if (cmp_var(&arg2, &const_zero) == 0)
3105 				PG_RETURN_NUMERIC(make_result(&const_one));
3106 		}
3107 		PG_RETURN_NUMERIC(make_result(&const_nan));
3108 	}
3109 	if (NUMERIC_IS_NAN(num2))
3110 	{
3111 		init_var_from_num(num1, &arg1);
3112 		if (cmp_var(&arg1, &const_one) == 0)
3113 			PG_RETURN_NUMERIC(make_result(&const_one));
3114 		PG_RETURN_NUMERIC(make_result(&const_nan));
3115 	}
3116 
3117 	/*
3118 	 * Initialize things
3119 	 */
3120 	init_var(&arg2_trunc);
3121 	init_var(&result);
3122 	init_var_from_num(num1, &arg1);
3123 	init_var_from_num(num2, &arg2);
3124 
3125 	set_var_from_var(&arg2, &arg2_trunc);
3126 	trunc_var(&arg2_trunc, 0);
3127 
3128 	/*
3129 	 * The SQL spec requires that we emit a particular SQLSTATE error code for
3130 	 * certain error conditions.  Specifically, we don't return a
3131 	 * divide-by-zero error code for 0 ^ -1.  Raising a negative number to a
3132 	 * non-integer power must produce the same error code, but that case is
3133 	 * handled in power_var().
3134 	 */
3135 	if (cmp_var(&arg1, &const_zero) == 0 &&
3136 		cmp_var(&arg2, &const_zero) < 0)
3137 		ereport(ERROR,
3138 				(errcode(ERRCODE_INVALID_ARGUMENT_FOR_POWER_FUNCTION),
3139 				 errmsg("zero raised to a negative power is undefined")));
3140 
3141 	/*
3142 	 * Call power_var() to compute and return the result; note it handles
3143 	 * scale selection itself.
3144 	 */
3145 	power_var(&arg1, &arg2, &result);
3146 
3147 	res = make_result(&result);
3148 
3149 	free_var(&result);
3150 	free_var(&arg2_trunc);
3151 
3152 	PG_RETURN_NUMERIC(res);
3153 }
3154 
3155 /*
3156  * numeric_scale() -
3157  *
3158  *	Returns the scale, i.e. the count of decimal digits in the fractional part
3159  */
3160 Datum
numeric_scale(PG_FUNCTION_ARGS)3161 numeric_scale(PG_FUNCTION_ARGS)
3162 {
3163 	Numeric		num = PG_GETARG_NUMERIC(0);
3164 
3165 	if (NUMERIC_IS_NAN(num))
3166 		PG_RETURN_NULL();
3167 
3168 	PG_RETURN_INT32(NUMERIC_DSCALE(num));
3169 }
3170 
3171 
3172 
3173 /* ----------------------------------------------------------------------
3174  *
3175  * Type conversion functions
3176  *
3177  * ----------------------------------------------------------------------
3178  */
3179 
3180 
3181 Datum
int4_numeric(PG_FUNCTION_ARGS)3182 int4_numeric(PG_FUNCTION_ARGS)
3183 {
3184 	int32		val = PG_GETARG_INT32(0);
3185 	Numeric		res;
3186 	NumericVar	result;
3187 
3188 	init_var(&result);
3189 
3190 	int64_to_numericvar((int64) val, &result);
3191 
3192 	res = make_result(&result);
3193 
3194 	free_var(&result);
3195 
3196 	PG_RETURN_NUMERIC(res);
3197 }
3198 
3199 int32
numeric_int4_opt_error(Numeric num,bool * have_error)3200 numeric_int4_opt_error(Numeric num, bool *have_error)
3201 {
3202 	NumericVar	x;
3203 	int32		result;
3204 
3205 	/* XXX would it be better to return NULL? */
3206 	if (NUMERIC_IS_NAN(num))
3207 	{
3208 		if (have_error)
3209 		{
3210 			*have_error = true;
3211 			return 0;
3212 		}
3213 		else
3214 		{
3215 			ereport(ERROR,
3216 					(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
3217 					 errmsg("cannot convert NaN to integer")));
3218 		}
3219 	}
3220 
3221 	/* Convert to variable format, then convert to int4 */
3222 	init_var_from_num(num, &x);
3223 
3224 	if (!numericvar_to_int32(&x, &result))
3225 	{
3226 		if (have_error)
3227 		{
3228 			*have_error = true;
3229 			return 0;
3230 		}
3231 		else
3232 		{
3233 			ereport(ERROR,
3234 					(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
3235 					 errmsg("integer out of range")));
3236 		}
3237 	}
3238 
3239 	return result;
3240 }
3241 
3242 Datum
numeric_int4(PG_FUNCTION_ARGS)3243 numeric_int4(PG_FUNCTION_ARGS)
3244 {
3245 	Numeric		num = PG_GETARG_NUMERIC(0);
3246 
3247 	PG_RETURN_INT32(numeric_int4_opt_error(num, NULL));
3248 }
3249 
3250 /*
3251  * Given a NumericVar, convert it to an int32. If the NumericVar
3252  * exceeds the range of an int32, false is returned, otherwise true is returned.
3253  * The input NumericVar is *not* free'd.
3254  */
3255 static bool
numericvar_to_int32(const NumericVar * var,int32 * result)3256 numericvar_to_int32(const NumericVar *var, int32 *result)
3257 {
3258 	int64		val;
3259 
3260 	if (!numericvar_to_int64(var, &val))
3261 		return false;
3262 
3263 	if (unlikely(val < PG_INT32_MIN) || unlikely(val > PG_INT32_MAX))
3264 		return false;
3265 
3266 	/* Down-convert to int4 */
3267 	*result = (int32) val;
3268 
3269 	return true;
3270 }
3271 
3272 Datum
int8_numeric(PG_FUNCTION_ARGS)3273 int8_numeric(PG_FUNCTION_ARGS)
3274 {
3275 	int64		val = PG_GETARG_INT64(0);
3276 	Numeric		res;
3277 	NumericVar	result;
3278 
3279 	init_var(&result);
3280 
3281 	int64_to_numericvar(val, &result);
3282 
3283 	res = make_result(&result);
3284 
3285 	free_var(&result);
3286 
3287 	PG_RETURN_NUMERIC(res);
3288 }
3289 
3290 
3291 Datum
numeric_int8(PG_FUNCTION_ARGS)3292 numeric_int8(PG_FUNCTION_ARGS)
3293 {
3294 	Numeric		num = PG_GETARG_NUMERIC(0);
3295 	NumericVar	x;
3296 	int64		result;
3297 
3298 	/* XXX would it be better to return NULL? */
3299 	if (NUMERIC_IS_NAN(num))
3300 		ereport(ERROR,
3301 				(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
3302 				 errmsg("cannot convert NaN to bigint")));
3303 
3304 	/* Convert to variable format and thence to int8 */
3305 	init_var_from_num(num, &x);
3306 
3307 	if (!numericvar_to_int64(&x, &result))
3308 		ereport(ERROR,
3309 				(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
3310 				 errmsg("bigint out of range")));
3311 
3312 	PG_RETURN_INT64(result);
3313 }
3314 
3315 
3316 Datum
int2_numeric(PG_FUNCTION_ARGS)3317 int2_numeric(PG_FUNCTION_ARGS)
3318 {
3319 	int16		val = PG_GETARG_INT16(0);
3320 	Numeric		res;
3321 	NumericVar	result;
3322 
3323 	init_var(&result);
3324 
3325 	int64_to_numericvar((int64) val, &result);
3326 
3327 	res = make_result(&result);
3328 
3329 	free_var(&result);
3330 
3331 	PG_RETURN_NUMERIC(res);
3332 }
3333 
3334 
3335 Datum
numeric_int2(PG_FUNCTION_ARGS)3336 numeric_int2(PG_FUNCTION_ARGS)
3337 {
3338 	Numeric		num = PG_GETARG_NUMERIC(0);
3339 	NumericVar	x;
3340 	int64		val;
3341 	int16		result;
3342 
3343 	/* XXX would it be better to return NULL? */
3344 	if (NUMERIC_IS_NAN(num))
3345 		ereport(ERROR,
3346 				(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
3347 				 errmsg("cannot convert NaN to smallint")));
3348 
3349 	/* Convert to variable format and thence to int8 */
3350 	init_var_from_num(num, &x);
3351 
3352 	if (!numericvar_to_int64(&x, &val))
3353 		ereport(ERROR,
3354 				(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
3355 				 errmsg("smallint out of range")));
3356 
3357 	if (unlikely(val < PG_INT16_MIN) || unlikely(val > PG_INT16_MAX))
3358 		ereport(ERROR,
3359 				(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
3360 				 errmsg("smallint out of range")));
3361 
3362 	/* Down-convert to int2 */
3363 	result = (int16) val;
3364 
3365 	PG_RETURN_INT16(result);
3366 }
3367 
3368 
3369 Datum
float8_numeric(PG_FUNCTION_ARGS)3370 float8_numeric(PG_FUNCTION_ARGS)
3371 {
3372 	float8		val = PG_GETARG_FLOAT8(0);
3373 	Numeric		res;
3374 	NumericVar	result;
3375 	char		buf[DBL_DIG + 100];
3376 
3377 	if (isnan(val))
3378 		PG_RETURN_NUMERIC(make_result(&const_nan));
3379 
3380 	if (isinf(val))
3381 		ereport(ERROR,
3382 				(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
3383 				 errmsg("cannot convert infinity to numeric")));
3384 
3385 	snprintf(buf, sizeof(buf), "%.*g", DBL_DIG, val);
3386 
3387 	init_var(&result);
3388 
3389 	/* Assume we need not worry about leading/trailing spaces */
3390 	(void) set_var_from_str(buf, buf, &result);
3391 
3392 	res = make_result(&result);
3393 
3394 	free_var(&result);
3395 
3396 	PG_RETURN_NUMERIC(res);
3397 }
3398 
3399 
3400 Datum
numeric_float8(PG_FUNCTION_ARGS)3401 numeric_float8(PG_FUNCTION_ARGS)
3402 {
3403 	Numeric		num = PG_GETARG_NUMERIC(0);
3404 	char	   *tmp;
3405 	Datum		result;
3406 
3407 	if (NUMERIC_IS_NAN(num))
3408 		PG_RETURN_FLOAT8(get_float8_nan());
3409 
3410 	tmp = DatumGetCString(DirectFunctionCall1(numeric_out,
3411 											  NumericGetDatum(num)));
3412 
3413 	result = DirectFunctionCall1(float8in, CStringGetDatum(tmp));
3414 
3415 	pfree(tmp);
3416 
3417 	PG_RETURN_DATUM(result);
3418 }
3419 
3420 
3421 /*
3422  * Convert numeric to float8; if out of range, return +/- HUGE_VAL
3423  *
3424  * (internal helper function, not directly callable from SQL)
3425  */
3426 Datum
numeric_float8_no_overflow(PG_FUNCTION_ARGS)3427 numeric_float8_no_overflow(PG_FUNCTION_ARGS)
3428 {
3429 	Numeric		num = PG_GETARG_NUMERIC(0);
3430 	double		val;
3431 
3432 	if (NUMERIC_IS_NAN(num))
3433 		PG_RETURN_FLOAT8(get_float8_nan());
3434 
3435 	val = numeric_to_double_no_overflow(num);
3436 
3437 	PG_RETURN_FLOAT8(val);
3438 }
3439 
3440 Datum
float4_numeric(PG_FUNCTION_ARGS)3441 float4_numeric(PG_FUNCTION_ARGS)
3442 {
3443 	float4		val = PG_GETARG_FLOAT4(0);
3444 	Numeric		res;
3445 	NumericVar	result;
3446 	char		buf[FLT_DIG + 100];
3447 
3448 	if (isnan(val))
3449 		PG_RETURN_NUMERIC(make_result(&const_nan));
3450 
3451 	if (isinf(val))
3452 		ereport(ERROR,
3453 				(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
3454 				 errmsg("cannot convert infinity to numeric")));
3455 
3456 	snprintf(buf, sizeof(buf), "%.*g", FLT_DIG, val);
3457 
3458 	init_var(&result);
3459 
3460 	/* Assume we need not worry about leading/trailing spaces */
3461 	(void) set_var_from_str(buf, buf, &result);
3462 
3463 	res = make_result(&result);
3464 
3465 	free_var(&result);
3466 
3467 	PG_RETURN_NUMERIC(res);
3468 }
3469 
3470 
3471 Datum
numeric_float4(PG_FUNCTION_ARGS)3472 numeric_float4(PG_FUNCTION_ARGS)
3473 {
3474 	Numeric		num = PG_GETARG_NUMERIC(0);
3475 	char	   *tmp;
3476 	Datum		result;
3477 
3478 	if (NUMERIC_IS_NAN(num))
3479 		PG_RETURN_FLOAT4(get_float4_nan());
3480 
3481 	tmp = DatumGetCString(DirectFunctionCall1(numeric_out,
3482 											  NumericGetDatum(num)));
3483 
3484 	result = DirectFunctionCall1(float4in, CStringGetDatum(tmp));
3485 
3486 	pfree(tmp);
3487 
3488 	PG_RETURN_DATUM(result);
3489 }
3490 
3491 
3492 /* ----------------------------------------------------------------------
3493  *
3494  * Aggregate functions
3495  *
3496  * The transition datatype for all these aggregates is declared as INTERNAL.
3497  * Actually, it's a pointer to a NumericAggState allocated in the aggregate
3498  * context.  The digit buffers for the NumericVars will be there too.
3499  *
3500  * On platforms which support 128-bit integers some aggregates instead use a
3501  * 128-bit integer based transition datatype to speed up calculations.
3502  *
3503  * ----------------------------------------------------------------------
3504  */
3505 
/*
 * Transition state for the numeric aggregates below.
 *
 * N, sumX and (when calcSumX2 is true) sumX2 cover only the non-NaN inputs;
 * NaN inputs are counted in NaNcount alone.  maxScale/maxScaleCount record
 * the largest input dscale seen and how many inputs had it, which is what
 * lets do_numeric_discard() decide whether an input can be removed again.
 */
typedef struct NumericAggState
{
	bool		calcSumX2;		/* if true, calculate sumX2 */
	MemoryContext agg_context;	/* context the sums are accumulated in */
	int64		N;				/* count of processed (non-NaN) numbers */
	NumericSumAccum sumX;		/* sum of processed numbers */
	NumericSumAccum sumX2;		/* sum of squares of processed numbers */
	int			maxScale;		/* maximum scale seen so far */
	int64		maxScaleCount;	/* number of values seen with maximum scale */
	int64		NaNcount;		/* count of NaN values (not included in N!) */
} NumericAggState;
3517 
3518 /*
3519  * Prepare state data for a numeric aggregate function that needs to compute
3520  * sum, count and optionally sum of squares of the input.
3521  */
3522 static NumericAggState *
makeNumericAggState(FunctionCallInfo fcinfo,bool calcSumX2)3523 makeNumericAggState(FunctionCallInfo fcinfo, bool calcSumX2)
3524 {
3525 	NumericAggState *state;
3526 	MemoryContext agg_context;
3527 	MemoryContext old_context;
3528 
3529 	if (!AggCheckCallContext(fcinfo, &agg_context))
3530 		elog(ERROR, "aggregate function called in non-aggregate context");
3531 
3532 	old_context = MemoryContextSwitchTo(agg_context);
3533 
3534 	state = (NumericAggState *) palloc0(sizeof(NumericAggState));
3535 	state->calcSumX2 = calcSumX2;
3536 	state->agg_context = agg_context;
3537 
3538 	MemoryContextSwitchTo(old_context);
3539 
3540 	return state;
3541 }
3542 
3543 /*
3544  * Like makeNumericAggState(), but allocate the state in the current memory
3545  * context.
3546  */
3547 static NumericAggState *
makeNumericAggStateCurrentContext(bool calcSumX2)3548 makeNumericAggStateCurrentContext(bool calcSumX2)
3549 {
3550 	NumericAggState *state;
3551 
3552 	state = (NumericAggState *) palloc0(sizeof(NumericAggState));
3553 	state->calcSumX2 = calcSumX2;
3554 	state->agg_context = CurrentMemoryContext;
3555 
3556 	return state;
3557 }
3558 
3559 /*
3560  * Accumulate a new input value for numeric aggregate functions.
3561  */
3562 static void
do_numeric_accum(NumericAggState * state,Numeric newval)3563 do_numeric_accum(NumericAggState *state, Numeric newval)
3564 {
3565 	NumericVar	X;
3566 	NumericVar	X2;
3567 	MemoryContext old_context;
3568 
3569 	/* Count NaN inputs separately from all else */
3570 	if (NUMERIC_IS_NAN(newval))
3571 	{
3572 		state->NaNcount++;
3573 		return;
3574 	}
3575 
3576 	/* load processed number in short-lived context */
3577 	init_var_from_num(newval, &X);
3578 
3579 	/*
3580 	 * Track the highest input dscale that we've seen, to support inverse
3581 	 * transitions (see do_numeric_discard).
3582 	 */
3583 	if (X.dscale > state->maxScale)
3584 	{
3585 		state->maxScale = X.dscale;
3586 		state->maxScaleCount = 1;
3587 	}
3588 	else if (X.dscale == state->maxScale)
3589 		state->maxScaleCount++;
3590 
3591 	/* if we need X^2, calculate that in short-lived context */
3592 	if (state->calcSumX2)
3593 	{
3594 		init_var(&X2);
3595 		mul_var(&X, &X, &X2, X.dscale * 2);
3596 	}
3597 
3598 	/* The rest of this needs to work in the aggregate context */
3599 	old_context = MemoryContextSwitchTo(state->agg_context);
3600 
3601 	state->N++;
3602 
3603 	/* Accumulate sums */
3604 	accum_sum_add(&(state->sumX), &X);
3605 
3606 	if (state->calcSumX2)
3607 		accum_sum_add(&(state->sumX2), &X2);
3608 
3609 	MemoryContextSwitchTo(old_context);
3610 }
3611 
/*
 * Attempt to remove an input value from the aggregated state.
 *
 * Returns true if newval was removed.  Returns false if it cannot be
 * removed; in that case the state has not been modified, and the caller
 * must have the aggregation redone from scratch.  The possible reasons for
 * failing are described below.
 *
 * If we aggregate the values 1.01 and 2 then the result will be 3.01.
 * If we are then asked to un-aggregate the 1.01 then we must fail as we
 * won't be able to tell what the new aggregated value's dscale should be.
 * We don't want to return 2.00 (dscale = 2), since the sum's dscale would
 * have been zero if we'd really aggregated only 2.
 *
 * Note: alternatively, we could count the number of inputs with each possible
 * dscale (up to some sane limit).  Not yet clear if it's worth the trouble.
 */
static bool
do_numeric_discard(NumericAggState *state, Numeric newval)
{
	NumericVar	X;
	NumericVar	X2;
	MemoryContext old_context;

	/* Count NaN inputs separately from all else */
	if (NUMERIC_IS_NAN(newval))
	{
		/* A NaN only ever touched NaNcount, so removing it always succeeds */
		state->NaNcount--;
		return true;
	}

	/* load processed number in short-lived context */
	init_var_from_num(newval, &X);

	/*
	 * state->sumX's dscale is the maximum dscale of any of the inputs.
	 * Removing the last input with that dscale would require us to recompute
	 * the maximum dscale of the *remaining* inputs, which we cannot do unless
	 * no more non-NaN inputs remain at all.  So we report a failure instead,
	 * and force the aggregation to be redone from scratch.
	 */
	if (X.dscale == state->maxScale)
	{
		if (state->maxScaleCount > 1 || state->maxScale == 0)
		{
			/*
			 * Some remaining inputs have same dscale, or dscale hasn't gotten
			 * above zero anyway
			 */
			state->maxScaleCount--;
		}
		else if (state->N == 1)
		{
			/* No remaining non-NaN inputs at all, so reset maxScale */
			state->maxScale = 0;
			state->maxScaleCount = 0;
		}
		else
		{
			/* Correct new maxScale is uncertain, must fail */
			return false;
		}
	}

	/* if we need X^2, calculate that in short-lived context */
	if (state->calcSumX2)
	{
		init_var(&X2);
		mul_var(&X, &X, &X2, X.dscale * 2);
	}

	/* The rest of this needs to work in the aggregate context */
	old_context = MemoryContextSwitchTo(state->agg_context);

	/*
	 * Post-decrement: the test sees N as it was *before* this value is
	 * removed.  If other non-NaN values remain, subtract X (and X^2) from the
	 * running sums.  Otherwise reset the sums outright, so that they hold an
	 * exact zero rather than the result of subtracting.
	 */
	if (state->N-- > 1)
	{
		/*
		 * Negate X, to subtract it from the sum.  X is a local NumericVar, so
		 * flipping its sign field does not touch newval's header.
		 */
		X.sign = (X.sign == NUMERIC_POS ? NUMERIC_NEG : NUMERIC_POS);
		accum_sum_add(&(state->sumX), &X);

		if (state->calcSumX2)
		{
			/* Negate X^2. X^2 is always positive */
			X2.sign = NUMERIC_NEG;
			accum_sum_add(&(state->sumX2), &X2);
		}
	}
	else
	{
		/* Zero the sums */
		Assert(state->N == 0);

		accum_sum_reset(&state->sumX);
		if (state->calcSumX2)
			accum_sum_reset(&state->sumX2);
	}

	MemoryContextSwitchTo(old_context);

	return true;
}
3711 
3712 /*
3713  * Generic transition function for numeric aggregates that require sumX2.
3714  */
3715 Datum
numeric_accum(PG_FUNCTION_ARGS)3716 numeric_accum(PG_FUNCTION_ARGS)
3717 {
3718 	NumericAggState *state;
3719 
3720 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
3721 
3722 	/* Create the state data on the first call */
3723 	if (state == NULL)
3724 		state = makeNumericAggState(fcinfo, true);
3725 
3726 	if (!PG_ARGISNULL(1))
3727 		do_numeric_accum(state, PG_GETARG_NUMERIC(1));
3728 
3729 	PG_RETURN_POINTER(state);
3730 }
3731 
3732 /*
3733  * Generic combine function for numeric aggregates which require sumX2
3734  */
3735 Datum
numeric_combine(PG_FUNCTION_ARGS)3736 numeric_combine(PG_FUNCTION_ARGS)
3737 {
3738 	NumericAggState *state1;
3739 	NumericAggState *state2;
3740 	MemoryContext agg_context;
3741 	MemoryContext old_context;
3742 
3743 	if (!AggCheckCallContext(fcinfo, &agg_context))
3744 		elog(ERROR, "aggregate function called in non-aggregate context");
3745 
3746 	state1 = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
3747 	state2 = PG_ARGISNULL(1) ? NULL : (NumericAggState *) PG_GETARG_POINTER(1);
3748 
3749 	if (state2 == NULL)
3750 		PG_RETURN_POINTER(state1);
3751 
3752 	/* manually copy all fields from state2 to state1 */
3753 	if (state1 == NULL)
3754 	{
3755 		old_context = MemoryContextSwitchTo(agg_context);
3756 
3757 		state1 = makeNumericAggStateCurrentContext(true);
3758 		state1->N = state2->N;
3759 		state1->NaNcount = state2->NaNcount;
3760 		state1->maxScale = state2->maxScale;
3761 		state1->maxScaleCount = state2->maxScaleCount;
3762 
3763 		accum_sum_copy(&state1->sumX, &state2->sumX);
3764 		accum_sum_copy(&state1->sumX2, &state2->sumX2);
3765 
3766 		MemoryContextSwitchTo(old_context);
3767 
3768 		PG_RETURN_POINTER(state1);
3769 	}
3770 
3771 	state1->N += state2->N;
3772 	state1->NaNcount += state2->NaNcount;
3773 
3774 	if (state2->N > 0)
3775 	{
3776 		/*
3777 		 * These are currently only needed for moving aggregates, but let's do
3778 		 * the right thing anyway...
3779 		 */
3780 		if (state2->maxScale > state1->maxScale)
3781 		{
3782 			state1->maxScale = state2->maxScale;
3783 			state1->maxScaleCount = state2->maxScaleCount;
3784 		}
3785 		else if (state2->maxScale == state1->maxScale)
3786 			state1->maxScaleCount += state2->maxScaleCount;
3787 
3788 		/* The rest of this needs to work in the aggregate context */
3789 		old_context = MemoryContextSwitchTo(agg_context);
3790 
3791 		/* Accumulate sums */
3792 		accum_sum_combine(&state1->sumX, &state2->sumX);
3793 		accum_sum_combine(&state1->sumX2, &state2->sumX2);
3794 
3795 		MemoryContextSwitchTo(old_context);
3796 	}
3797 	PG_RETURN_POINTER(state1);
3798 }
3799 
3800 /*
3801  * Generic transition function for numeric aggregates that don't require sumX2.
3802  */
3803 Datum
numeric_avg_accum(PG_FUNCTION_ARGS)3804 numeric_avg_accum(PG_FUNCTION_ARGS)
3805 {
3806 	NumericAggState *state;
3807 
3808 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
3809 
3810 	/* Create the state data on the first call */
3811 	if (state == NULL)
3812 		state = makeNumericAggState(fcinfo, false);
3813 
3814 	if (!PG_ARGISNULL(1))
3815 		do_numeric_accum(state, PG_GETARG_NUMERIC(1));
3816 
3817 	PG_RETURN_POINTER(state);
3818 }
3819 
3820 /*
3821  * Combine function for numeric aggregates which don't require sumX2
3822  */
3823 Datum
numeric_avg_combine(PG_FUNCTION_ARGS)3824 numeric_avg_combine(PG_FUNCTION_ARGS)
3825 {
3826 	NumericAggState *state1;
3827 	NumericAggState *state2;
3828 	MemoryContext agg_context;
3829 	MemoryContext old_context;
3830 
3831 	if (!AggCheckCallContext(fcinfo, &agg_context))
3832 		elog(ERROR, "aggregate function called in non-aggregate context");
3833 
3834 	state1 = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
3835 	state2 = PG_ARGISNULL(1) ? NULL : (NumericAggState *) PG_GETARG_POINTER(1);
3836 
3837 	if (state2 == NULL)
3838 		PG_RETURN_POINTER(state1);
3839 
3840 	/* manually copy all fields from state2 to state1 */
3841 	if (state1 == NULL)
3842 	{
3843 		old_context = MemoryContextSwitchTo(agg_context);
3844 
3845 		state1 = makeNumericAggStateCurrentContext(false);
3846 		state1->N = state2->N;
3847 		state1->NaNcount = state2->NaNcount;
3848 		state1->maxScale = state2->maxScale;
3849 		state1->maxScaleCount = state2->maxScaleCount;
3850 
3851 		accum_sum_copy(&state1->sumX, &state2->sumX);
3852 
3853 		MemoryContextSwitchTo(old_context);
3854 
3855 		PG_RETURN_POINTER(state1);
3856 	}
3857 
3858 	state1->N += state2->N;
3859 	state1->NaNcount += state2->NaNcount;
3860 
3861 	if (state2->N > 0)
3862 	{
3863 		/*
3864 		 * These are currently only needed for moving aggregates, but let's do
3865 		 * the right thing anyway...
3866 		 */
3867 		if (state2->maxScale > state1->maxScale)
3868 		{
3869 			state1->maxScale = state2->maxScale;
3870 			state1->maxScaleCount = state2->maxScaleCount;
3871 		}
3872 		else if (state2->maxScale == state1->maxScale)
3873 			state1->maxScaleCount += state2->maxScaleCount;
3874 
3875 		/* The rest of this needs to work in the aggregate context */
3876 		old_context = MemoryContextSwitchTo(agg_context);
3877 
3878 		/* Accumulate sums */
3879 		accum_sum_combine(&state1->sumX, &state2->sumX);
3880 
3881 		MemoryContextSwitchTo(old_context);
3882 	}
3883 	PG_RETURN_POINTER(state1);
3884 }
3885 
3886 /*
3887  * numeric_avg_serialize
3888  *		Serialize NumericAggState for numeric aggregates that don't require
3889  *		sumX2.
3890  */
3891 Datum
numeric_avg_serialize(PG_FUNCTION_ARGS)3892 numeric_avg_serialize(PG_FUNCTION_ARGS)
3893 {
3894 	NumericAggState *state;
3895 	StringInfoData buf;
3896 	Datum		temp;
3897 	bytea	   *sumX;
3898 	bytea	   *result;
3899 	NumericVar	tmp_var;
3900 
3901 	/* Ensure we disallow calling when not in aggregate context */
3902 	if (!AggCheckCallContext(fcinfo, NULL))
3903 		elog(ERROR, "aggregate function called in non-aggregate context");
3904 
3905 	state = (NumericAggState *) PG_GETARG_POINTER(0);
3906 
3907 	/*
3908 	 * This is a little wasteful since make_result converts the NumericVar
3909 	 * into a Numeric and numeric_send converts it back again. Is it worth
3910 	 * splitting the tasks in numeric_send into separate functions to stop
3911 	 * this? Doing so would also remove the fmgr call overhead.
3912 	 */
3913 	init_var(&tmp_var);
3914 	accum_sum_final(&state->sumX, &tmp_var);
3915 
3916 	temp = DirectFunctionCall1(numeric_send,
3917 							   NumericGetDatum(make_result(&tmp_var)));
3918 	sumX = DatumGetByteaPP(temp);
3919 	free_var(&tmp_var);
3920 
3921 	pq_begintypsend(&buf);
3922 
3923 	/* N */
3924 	pq_sendint64(&buf, state->N);
3925 
3926 	/* sumX */
3927 	pq_sendbytes(&buf, VARDATA_ANY(sumX), VARSIZE_ANY_EXHDR(sumX));
3928 
3929 	/* maxScale */
3930 	pq_sendint32(&buf, state->maxScale);
3931 
3932 	/* maxScaleCount */
3933 	pq_sendint64(&buf, state->maxScaleCount);
3934 
3935 	/* NaNcount */
3936 	pq_sendint64(&buf, state->NaNcount);
3937 
3938 	result = pq_endtypsend(&buf);
3939 
3940 	PG_RETURN_BYTEA_P(result);
3941 }
3942 
3943 /*
3944  * numeric_avg_deserialize
3945  *		Deserialize bytea into NumericAggState for numeric aggregates that
3946  *		don't require sumX2.
3947  */
3948 Datum
numeric_avg_deserialize(PG_FUNCTION_ARGS)3949 numeric_avg_deserialize(PG_FUNCTION_ARGS)
3950 {
3951 	bytea	   *sstate;
3952 	NumericAggState *result;
3953 	Datum		temp;
3954 	NumericVar	tmp_var;
3955 	StringInfoData buf;
3956 
3957 	if (!AggCheckCallContext(fcinfo, NULL))
3958 		elog(ERROR, "aggregate function called in non-aggregate context");
3959 
3960 	sstate = PG_GETARG_BYTEA_PP(0);
3961 
3962 	/*
3963 	 * Copy the bytea into a StringInfo so that we can "receive" it using the
3964 	 * standard recv-function infrastructure.
3965 	 */
3966 	initStringInfo(&buf);
3967 	appendBinaryStringInfo(&buf,
3968 						   VARDATA_ANY(sstate), VARSIZE_ANY_EXHDR(sstate));
3969 
3970 	result = makeNumericAggStateCurrentContext(false);
3971 
3972 	/* N */
3973 	result->N = pq_getmsgint64(&buf);
3974 
3975 	/* sumX */
3976 	temp = DirectFunctionCall3(numeric_recv,
3977 							   PointerGetDatum(&buf),
3978 							   ObjectIdGetDatum(InvalidOid),
3979 							   Int32GetDatum(-1));
3980 	init_var_from_num(DatumGetNumeric(temp), &tmp_var);
3981 	accum_sum_add(&(result->sumX), &tmp_var);
3982 
3983 	/* maxScale */
3984 	result->maxScale = pq_getmsgint(&buf, 4);
3985 
3986 	/* maxScaleCount */
3987 	result->maxScaleCount = pq_getmsgint64(&buf);
3988 
3989 	/* NaNcount */
3990 	result->NaNcount = pq_getmsgint64(&buf);
3991 
3992 	pq_getmsgend(&buf);
3993 	pfree(buf.data);
3994 
3995 	PG_RETURN_POINTER(result);
3996 }
3997 
3998 /*
3999  * numeric_serialize
4000  *		Serialization function for NumericAggState for numeric aggregates that
4001  *		require sumX2.
4002  */
4003 Datum
numeric_serialize(PG_FUNCTION_ARGS)4004 numeric_serialize(PG_FUNCTION_ARGS)
4005 {
4006 	NumericAggState *state;
4007 	StringInfoData buf;
4008 	Datum		temp;
4009 	bytea	   *sumX;
4010 	NumericVar	tmp_var;
4011 	bytea	   *sumX2;
4012 	bytea	   *result;
4013 
4014 	/* Ensure we disallow calling when not in aggregate context */
4015 	if (!AggCheckCallContext(fcinfo, NULL))
4016 		elog(ERROR, "aggregate function called in non-aggregate context");
4017 
4018 	state = (NumericAggState *) PG_GETARG_POINTER(0);
4019 
4020 	/*
4021 	 * This is a little wasteful since make_result converts the NumericVar
4022 	 * into a Numeric and numeric_send converts it back again. Is it worth
4023 	 * splitting the tasks in numeric_send into separate functions to stop
4024 	 * this? Doing so would also remove the fmgr call overhead.
4025 	 */
4026 	init_var(&tmp_var);
4027 
4028 	accum_sum_final(&state->sumX, &tmp_var);
4029 	temp = DirectFunctionCall1(numeric_send,
4030 							   NumericGetDatum(make_result(&tmp_var)));
4031 	sumX = DatumGetByteaPP(temp);
4032 
4033 	accum_sum_final(&state->sumX2, &tmp_var);
4034 	temp = DirectFunctionCall1(numeric_send,
4035 							   NumericGetDatum(make_result(&tmp_var)));
4036 	sumX2 = DatumGetByteaPP(temp);
4037 
4038 	free_var(&tmp_var);
4039 
4040 	pq_begintypsend(&buf);
4041 
4042 	/* N */
4043 	pq_sendint64(&buf, state->N);
4044 
4045 	/* sumX */
4046 	pq_sendbytes(&buf, VARDATA_ANY(sumX), VARSIZE_ANY_EXHDR(sumX));
4047 
4048 	/* sumX2 */
4049 	pq_sendbytes(&buf, VARDATA_ANY(sumX2), VARSIZE_ANY_EXHDR(sumX2));
4050 
4051 	/* maxScale */
4052 	pq_sendint32(&buf, state->maxScale);
4053 
4054 	/* maxScaleCount */
4055 	pq_sendint64(&buf, state->maxScaleCount);
4056 
4057 	/* NaNcount */
4058 	pq_sendint64(&buf, state->NaNcount);
4059 
4060 	result = pq_endtypsend(&buf);
4061 
4062 	PG_RETURN_BYTEA_P(result);
4063 }
4064 
4065 /*
4066  * numeric_deserialize
4067  *		Deserialization function for NumericAggState for numeric aggregates that
4068  *		require sumX2.
4069  */
4070 Datum
numeric_deserialize(PG_FUNCTION_ARGS)4071 numeric_deserialize(PG_FUNCTION_ARGS)
4072 {
4073 	bytea	   *sstate;
4074 	NumericAggState *result;
4075 	Datum		temp;
4076 	NumericVar	sumX_var;
4077 	NumericVar	sumX2_var;
4078 	StringInfoData buf;
4079 
4080 	if (!AggCheckCallContext(fcinfo, NULL))
4081 		elog(ERROR, "aggregate function called in non-aggregate context");
4082 
4083 	sstate = PG_GETARG_BYTEA_PP(0);
4084 
4085 	/*
4086 	 * Copy the bytea into a StringInfo so that we can "receive" it using the
4087 	 * standard recv-function infrastructure.
4088 	 */
4089 	initStringInfo(&buf);
4090 	appendBinaryStringInfo(&buf,
4091 						   VARDATA_ANY(sstate), VARSIZE_ANY_EXHDR(sstate));
4092 
4093 	result = makeNumericAggStateCurrentContext(false);
4094 
4095 	/* N */
4096 	result->N = pq_getmsgint64(&buf);
4097 
4098 	/* sumX */
4099 	temp = DirectFunctionCall3(numeric_recv,
4100 							   PointerGetDatum(&buf),
4101 							   ObjectIdGetDatum(InvalidOid),
4102 							   Int32GetDatum(-1));
4103 	init_var_from_num(DatumGetNumeric(temp), &sumX_var);
4104 	accum_sum_add(&(result->sumX), &sumX_var);
4105 
4106 	/* sumX2 */
4107 	temp = DirectFunctionCall3(numeric_recv,
4108 							   PointerGetDatum(&buf),
4109 							   ObjectIdGetDatum(InvalidOid),
4110 							   Int32GetDatum(-1));
4111 	init_var_from_num(DatumGetNumeric(temp), &sumX2_var);
4112 	accum_sum_add(&(result->sumX2), &sumX2_var);
4113 
4114 	/* maxScale */
4115 	result->maxScale = pq_getmsgint(&buf, 4);
4116 
4117 	/* maxScaleCount */
4118 	result->maxScaleCount = pq_getmsgint64(&buf);
4119 
4120 	/* NaNcount */
4121 	result->NaNcount = pq_getmsgint64(&buf);
4122 
4123 	pq_getmsgend(&buf);
4124 	pfree(buf.data);
4125 
4126 	PG_RETURN_POINTER(result);
4127 }
4128 
4129 /*
4130  * Generic inverse transition function for numeric aggregates
4131  * (with or without requirement for X^2).
4132  */
4133 Datum
numeric_accum_inv(PG_FUNCTION_ARGS)4134 numeric_accum_inv(PG_FUNCTION_ARGS)
4135 {
4136 	NumericAggState *state;
4137 
4138 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
4139 
4140 	/* Should not get here with no state */
4141 	if (state == NULL)
4142 		elog(ERROR, "numeric_accum_inv called with NULL state");
4143 
4144 	if (!PG_ARGISNULL(1))
4145 	{
4146 		/* If we fail to perform the inverse transition, return NULL */
4147 		if (!do_numeric_discard(state, PG_GETARG_NUMERIC(1)))
4148 			PG_RETURN_NULL();
4149 	}
4150 
4151 	PG_RETURN_POINTER(state);
4152 }
4153 
4154 
4155 /*
4156  * Integer data types in general use Numeric accumulators to share code
4157  * and avoid risk of overflow.
4158  *
4159  * However for performance reasons optimized special-purpose accumulator
4160  * routines are used when possible.
4161  *
 * On platforms with 128-bit integer support, the 128-bit routines are used
 * when sum(X) or sum(X*X) is guaranteed to fit in 128 bits.
 *
 * For 16- and 32-bit inputs, N and sum(X) fit in 64 bits, so 64-bit
 * accumulators are used for SUM and AVG of these data types.
4167  */
4168 
4169 #ifdef HAVE_INT128
/*
 * Transition state for the 128-bit integer accumulators: the number of
 * accumulated inputs, their running sum, and (only maintained when
 * calcSumX2 is set) the running sum of their squares.
 */
typedef struct Int128AggState
{
	bool		calcSumX2;		/* if true, calculate sumX2 */
	int64		N;				/* count of processed numbers */
	int128		sumX;			/* sum of processed numbers */
	int128		sumX2;			/* sum of squares of processed numbers */
} Int128AggState;
4177 
4178 /*
4179  * Prepare state data for a 128-bit aggregate function that needs to compute
4180  * sum, count and optionally sum of squares of the input.
4181  */
4182 static Int128AggState *
makeInt128AggState(FunctionCallInfo fcinfo,bool calcSumX2)4183 makeInt128AggState(FunctionCallInfo fcinfo, bool calcSumX2)
4184 {
4185 	Int128AggState *state;
4186 	MemoryContext agg_context;
4187 	MemoryContext old_context;
4188 
4189 	if (!AggCheckCallContext(fcinfo, &agg_context))
4190 		elog(ERROR, "aggregate function called in non-aggregate context");
4191 
4192 	old_context = MemoryContextSwitchTo(agg_context);
4193 
4194 	state = (Int128AggState *) palloc0(sizeof(Int128AggState));
4195 	state->calcSumX2 = calcSumX2;
4196 
4197 	MemoryContextSwitchTo(old_context);
4198 
4199 	return state;
4200 }
4201 
4202 /*
4203  * Like makeInt128AggState(), but allocate the state in the current memory
4204  * context.
4205  */
4206 static Int128AggState *
makeInt128AggStateCurrentContext(bool calcSumX2)4207 makeInt128AggStateCurrentContext(bool calcSumX2)
4208 {
4209 	Int128AggState *state;
4210 
4211 	state = (Int128AggState *) palloc0(sizeof(Int128AggState));
4212 	state->calcSumX2 = calcSumX2;
4213 
4214 	return state;
4215 }
4216 
4217 /*
4218  * Accumulate a new input value for 128-bit aggregate functions.
4219  */
4220 static void
do_int128_accum(Int128AggState * state,int128 newval)4221 do_int128_accum(Int128AggState *state, int128 newval)
4222 {
4223 	if (state->calcSumX2)
4224 		state->sumX2 += newval * newval;
4225 
4226 	state->sumX += newval;
4227 	state->N++;
4228 }
4229 
4230 /*
4231  * Remove an input value from the aggregated state.
4232  */
4233 static void
do_int128_discard(Int128AggState * state,int128 newval)4234 do_int128_discard(Int128AggState *state, int128 newval)
4235 {
4236 	if (state->calcSumX2)
4237 		state->sumX2 -= newval * newval;
4238 
4239 	state->sumX -= newval;
4240 	state->N--;
4241 }
4242 
/*
 * PolyNumAggState is the transition state used by the integer accumulators
 * below.  It is Int128AggState when the platform has int128, and the generic
 * NumericAggState otherwise.  The make* macros pick the matching
 * constructor.
 */
typedef Int128AggState PolyNumAggState;
#define makePolyNumAggState makeInt128AggState
#define makePolyNumAggStateCurrentContext makeInt128AggStateCurrentContext
#else
/* No int128: fall back to the numeric accumulator and its constructors */
typedef NumericAggState PolyNumAggState;
#define makePolyNumAggState makeNumericAggState
#define makePolyNumAggStateCurrentContext makeNumericAggStateCurrentContext
#endif
4251 
4252 Datum
int2_accum(PG_FUNCTION_ARGS)4253 int2_accum(PG_FUNCTION_ARGS)
4254 {
4255 	PolyNumAggState *state;
4256 
4257 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
4258 
4259 	/* Create the state data on the first call */
4260 	if (state == NULL)
4261 		state = makePolyNumAggState(fcinfo, true);
4262 
4263 	if (!PG_ARGISNULL(1))
4264 	{
4265 #ifdef HAVE_INT128
4266 		do_int128_accum(state, (int128) PG_GETARG_INT16(1));
4267 #else
4268 		Numeric		newval;
4269 
4270 		newval = DatumGetNumeric(DirectFunctionCall1(int2_numeric,
4271 													 PG_GETARG_DATUM(1)));
4272 		do_numeric_accum(state, newval);
4273 #endif
4274 	}
4275 
4276 	PG_RETURN_POINTER(state);
4277 }
4278 
4279 Datum
int4_accum(PG_FUNCTION_ARGS)4280 int4_accum(PG_FUNCTION_ARGS)
4281 {
4282 	PolyNumAggState *state;
4283 
4284 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
4285 
4286 	/* Create the state data on the first call */
4287 	if (state == NULL)
4288 		state = makePolyNumAggState(fcinfo, true);
4289 
4290 	if (!PG_ARGISNULL(1))
4291 	{
4292 #ifdef HAVE_INT128
4293 		do_int128_accum(state, (int128) PG_GETARG_INT32(1));
4294 #else
4295 		Numeric		newval;
4296 
4297 		newval = DatumGetNumeric(DirectFunctionCall1(int4_numeric,
4298 													 PG_GETARG_DATUM(1)));
4299 		do_numeric_accum(state, newval);
4300 #endif
4301 	}
4302 
4303 	PG_RETURN_POINTER(state);
4304 }
4305 
4306 Datum
int8_accum(PG_FUNCTION_ARGS)4307 int8_accum(PG_FUNCTION_ARGS)
4308 {
4309 	NumericAggState *state;
4310 
4311 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
4312 
4313 	/* Create the state data on the first call */
4314 	if (state == NULL)
4315 		state = makeNumericAggState(fcinfo, true);
4316 
4317 	if (!PG_ARGISNULL(1))
4318 	{
4319 		Numeric		newval;
4320 
4321 		newval = DatumGetNumeric(DirectFunctionCall1(int8_numeric,
4322 													 PG_GETARG_DATUM(1)));
4323 		do_numeric_accum(state, newval);
4324 	}
4325 
4326 	PG_RETURN_POINTER(state);
4327 }
4328 
4329 /*
4330  * Combine function for numeric aggregates which require sumX2
4331  */
4332 Datum
numeric_poly_combine(PG_FUNCTION_ARGS)4333 numeric_poly_combine(PG_FUNCTION_ARGS)
4334 {
4335 	PolyNumAggState *state1;
4336 	PolyNumAggState *state2;
4337 	MemoryContext agg_context;
4338 	MemoryContext old_context;
4339 
4340 	if (!AggCheckCallContext(fcinfo, &agg_context))
4341 		elog(ERROR, "aggregate function called in non-aggregate context");
4342 
4343 	state1 = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
4344 	state2 = PG_ARGISNULL(1) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(1);
4345 
4346 	if (state2 == NULL)
4347 		PG_RETURN_POINTER(state1);
4348 
4349 	/* manually copy all fields from state2 to state1 */
4350 	if (state1 == NULL)
4351 	{
4352 		old_context = MemoryContextSwitchTo(agg_context);
4353 
4354 		state1 = makePolyNumAggState(fcinfo, true);
4355 		state1->N = state2->N;
4356 
4357 #ifdef HAVE_INT128
4358 		state1->sumX = state2->sumX;
4359 		state1->sumX2 = state2->sumX2;
4360 #else
4361 		accum_sum_copy(&state1->sumX, &state2->sumX);
4362 		accum_sum_copy(&state1->sumX2, &state2->sumX2);
4363 #endif
4364 
4365 		MemoryContextSwitchTo(old_context);
4366 
4367 		PG_RETURN_POINTER(state1);
4368 	}
4369 
4370 	if (state2->N > 0)
4371 	{
4372 		state1->N += state2->N;
4373 
4374 #ifdef HAVE_INT128
4375 		state1->sumX += state2->sumX;
4376 		state1->sumX2 += state2->sumX2;
4377 #else
4378 		/* The rest of this needs to work in the aggregate context */
4379 		old_context = MemoryContextSwitchTo(agg_context);
4380 
4381 		/* Accumulate sums */
4382 		accum_sum_combine(&state1->sumX, &state2->sumX);
4383 		accum_sum_combine(&state1->sumX2, &state2->sumX2);
4384 
4385 		MemoryContextSwitchTo(old_context);
4386 #endif
4387 
4388 	}
4389 	PG_RETURN_POINTER(state1);
4390 }
4391 
4392 /*
4393  * numeric_poly_serialize
4394  *		Serialize PolyNumAggState into bytea for aggregate functions which
4395  *		require sumX2.
4396  */
4397 Datum
numeric_poly_serialize(PG_FUNCTION_ARGS)4398 numeric_poly_serialize(PG_FUNCTION_ARGS)
4399 {
4400 	PolyNumAggState *state;
4401 	StringInfoData buf;
4402 	bytea	   *sumX;
4403 	bytea	   *sumX2;
4404 	bytea	   *result;
4405 
4406 	/* Ensure we disallow calling when not in aggregate context */
4407 	if (!AggCheckCallContext(fcinfo, NULL))
4408 		elog(ERROR, "aggregate function called in non-aggregate context");
4409 
4410 	state = (PolyNumAggState *) PG_GETARG_POINTER(0);
4411 
4412 	/*
4413 	 * If the platform supports int128 then sumX and sumX2 will be a 128 bit
4414 	 * integer type. Here we'll convert that into a numeric type so that the
4415 	 * combine state is in the same format for both int128 enabled machines
4416 	 * and machines which don't support that type. The logic here is that one
4417 	 * day we might like to send these over to another server for further
4418 	 * processing and we want a standard format to work with.
4419 	 */
4420 	{
4421 		Datum		temp;
4422 		NumericVar	num;
4423 
4424 		init_var(&num);
4425 
4426 #ifdef HAVE_INT128
4427 		int128_to_numericvar(state->sumX, &num);
4428 #else
4429 		accum_sum_final(&state->sumX, &num);
4430 #endif
4431 		temp = DirectFunctionCall1(numeric_send,
4432 								   NumericGetDatum(make_result(&num)));
4433 		sumX = DatumGetByteaPP(temp);
4434 
4435 #ifdef HAVE_INT128
4436 		int128_to_numericvar(state->sumX2, &num);
4437 #else
4438 		accum_sum_final(&state->sumX2, &num);
4439 #endif
4440 		temp = DirectFunctionCall1(numeric_send,
4441 								   NumericGetDatum(make_result(&num)));
4442 		sumX2 = DatumGetByteaPP(temp);
4443 
4444 		free_var(&num);
4445 	}
4446 
4447 	pq_begintypsend(&buf);
4448 
4449 	/* N */
4450 	pq_sendint64(&buf, state->N);
4451 
4452 	/* sumX */
4453 	pq_sendbytes(&buf, VARDATA_ANY(sumX), VARSIZE_ANY_EXHDR(sumX));
4454 
4455 	/* sumX2 */
4456 	pq_sendbytes(&buf, VARDATA_ANY(sumX2), VARSIZE_ANY_EXHDR(sumX2));
4457 
4458 	result = pq_endtypsend(&buf);
4459 
4460 	PG_RETURN_BYTEA_P(result);
4461 }
4462 
4463 /*
4464  * numeric_poly_deserialize
4465  *		Deserialize PolyNumAggState from bytea for aggregate functions which
4466  *		require sumX2.
4467  */
4468 Datum
numeric_poly_deserialize(PG_FUNCTION_ARGS)4469 numeric_poly_deserialize(PG_FUNCTION_ARGS)
4470 {
4471 	bytea	   *sstate;
4472 	PolyNumAggState *result;
4473 	Datum		sumX;
4474 	NumericVar	sumX_var;
4475 	Datum		sumX2;
4476 	NumericVar	sumX2_var;
4477 	StringInfoData buf;
4478 
4479 	if (!AggCheckCallContext(fcinfo, NULL))
4480 		elog(ERROR, "aggregate function called in non-aggregate context");
4481 
4482 	sstate = PG_GETARG_BYTEA_PP(0);
4483 
4484 	/*
4485 	 * Copy the bytea into a StringInfo so that we can "receive" it using the
4486 	 * standard recv-function infrastructure.
4487 	 */
4488 	initStringInfo(&buf);
4489 	appendBinaryStringInfo(&buf,
4490 						   VARDATA_ANY(sstate), VARSIZE_ANY_EXHDR(sstate));
4491 
4492 	result = makePolyNumAggStateCurrentContext(false);
4493 
4494 	/* N */
4495 	result->N = pq_getmsgint64(&buf);
4496 
4497 	/* sumX */
4498 	sumX = DirectFunctionCall3(numeric_recv,
4499 							   PointerGetDatum(&buf),
4500 							   ObjectIdGetDatum(InvalidOid),
4501 							   Int32GetDatum(-1));
4502 
4503 	/* sumX2 */
4504 	sumX2 = DirectFunctionCall3(numeric_recv,
4505 								PointerGetDatum(&buf),
4506 								ObjectIdGetDatum(InvalidOid),
4507 								Int32GetDatum(-1));
4508 
4509 	init_var_from_num(DatumGetNumeric(sumX), &sumX_var);
4510 #ifdef HAVE_INT128
4511 	numericvar_to_int128(&sumX_var, &result->sumX);
4512 #else
4513 	accum_sum_add(&result->sumX, &sumX_var);
4514 #endif
4515 
4516 	init_var_from_num(DatumGetNumeric(sumX2), &sumX2_var);
4517 #ifdef HAVE_INT128
4518 	numericvar_to_int128(&sumX2_var, &result->sumX2);
4519 #else
4520 	accum_sum_add(&result->sumX2, &sumX2_var);
4521 #endif
4522 
4523 	pq_getmsgend(&buf);
4524 	pfree(buf.data);
4525 
4526 	PG_RETURN_POINTER(result);
4527 }
4528 
4529 /*
4530  * Transition function for int8 input when we don't need sumX2.
4531  */
4532 Datum
int8_avg_accum(PG_FUNCTION_ARGS)4533 int8_avg_accum(PG_FUNCTION_ARGS)
4534 {
4535 	PolyNumAggState *state;
4536 
4537 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
4538 
4539 	/* Create the state data on the first call */
4540 	if (state == NULL)
4541 		state = makePolyNumAggState(fcinfo, false);
4542 
4543 	if (!PG_ARGISNULL(1))
4544 	{
4545 #ifdef HAVE_INT128
4546 		do_int128_accum(state, (int128) PG_GETARG_INT64(1));
4547 #else
4548 		Numeric		newval;
4549 
4550 		newval = DatumGetNumeric(DirectFunctionCall1(int8_numeric,
4551 													 PG_GETARG_DATUM(1)));
4552 		do_numeric_accum(state, newval);
4553 #endif
4554 	}
4555 
4556 	PG_RETURN_POINTER(state);
4557 }
4558 
4559 /*
4560  * Combine function for PolyNumAggState for aggregates which don't require
4561  * sumX2
4562  */
4563 Datum
int8_avg_combine(PG_FUNCTION_ARGS)4564 int8_avg_combine(PG_FUNCTION_ARGS)
4565 {
4566 	PolyNumAggState *state1;
4567 	PolyNumAggState *state2;
4568 	MemoryContext agg_context;
4569 	MemoryContext old_context;
4570 
4571 	if (!AggCheckCallContext(fcinfo, &agg_context))
4572 		elog(ERROR, "aggregate function called in non-aggregate context");
4573 
4574 	state1 = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
4575 	state2 = PG_ARGISNULL(1) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(1);
4576 
4577 	if (state2 == NULL)
4578 		PG_RETURN_POINTER(state1);
4579 
4580 	/* manually copy all fields from state2 to state1 */
4581 	if (state1 == NULL)
4582 	{
4583 		old_context = MemoryContextSwitchTo(agg_context);
4584 
4585 		state1 = makePolyNumAggState(fcinfo, false);
4586 		state1->N = state2->N;
4587 
4588 #ifdef HAVE_INT128
4589 		state1->sumX = state2->sumX;
4590 #else
4591 		accum_sum_copy(&state1->sumX, &state2->sumX);
4592 #endif
4593 		MemoryContextSwitchTo(old_context);
4594 
4595 		PG_RETURN_POINTER(state1);
4596 	}
4597 
4598 	if (state2->N > 0)
4599 	{
4600 		state1->N += state2->N;
4601 
4602 #ifdef HAVE_INT128
4603 		state1->sumX += state2->sumX;
4604 #else
4605 		/* The rest of this needs to work in the aggregate context */
4606 		old_context = MemoryContextSwitchTo(agg_context);
4607 
4608 		/* Accumulate sums */
4609 		accum_sum_combine(&state1->sumX, &state2->sumX);
4610 
4611 		MemoryContextSwitchTo(old_context);
4612 #endif
4613 
4614 	}
4615 	PG_RETURN_POINTER(state1);
4616 }
4617 
4618 /*
4619  * int8_avg_serialize
4620  *		Serialize PolyNumAggState into bytea using the standard
4621  *		recv-function infrastructure.
4622  */
4623 Datum
int8_avg_serialize(PG_FUNCTION_ARGS)4624 int8_avg_serialize(PG_FUNCTION_ARGS)
4625 {
4626 	PolyNumAggState *state;
4627 	StringInfoData buf;
4628 	bytea	   *sumX;
4629 	bytea	   *result;
4630 
4631 	/* Ensure we disallow calling when not in aggregate context */
4632 	if (!AggCheckCallContext(fcinfo, NULL))
4633 		elog(ERROR, "aggregate function called in non-aggregate context");
4634 
4635 	state = (PolyNumAggState *) PG_GETARG_POINTER(0);
4636 
4637 	/*
4638 	 * If the platform supports int128 then sumX will be a 128 integer type.
4639 	 * Here we'll convert that into a numeric type so that the combine state
4640 	 * is in the same format for both int128 enabled machines and machines
4641 	 * which don't support that type. The logic here is that one day we might
4642 	 * like to send these over to another server for further processing and we
4643 	 * want a standard format to work with.
4644 	 */
4645 	{
4646 		Datum		temp;
4647 		NumericVar	num;
4648 
4649 		init_var(&num);
4650 
4651 #ifdef HAVE_INT128
4652 		int128_to_numericvar(state->sumX, &num);
4653 #else
4654 		accum_sum_final(&state->sumX, &num);
4655 #endif
4656 		temp = DirectFunctionCall1(numeric_send,
4657 								   NumericGetDatum(make_result(&num)));
4658 		sumX = DatumGetByteaPP(temp);
4659 
4660 		free_var(&num);
4661 	}
4662 
4663 	pq_begintypsend(&buf);
4664 
4665 	/* N */
4666 	pq_sendint64(&buf, state->N);
4667 
4668 	/* sumX */
4669 	pq_sendbytes(&buf, VARDATA_ANY(sumX), VARSIZE_ANY_EXHDR(sumX));
4670 
4671 	result = pq_endtypsend(&buf);
4672 
4673 	PG_RETURN_BYTEA_P(result);
4674 }
4675 
4676 /*
4677  * int8_avg_deserialize
4678  *		Deserialize bytea back into PolyNumAggState.
4679  */
4680 Datum
int8_avg_deserialize(PG_FUNCTION_ARGS)4681 int8_avg_deserialize(PG_FUNCTION_ARGS)
4682 {
4683 	bytea	   *sstate;
4684 	PolyNumAggState *result;
4685 	StringInfoData buf;
4686 	Datum		temp;
4687 	NumericVar	num;
4688 
4689 	if (!AggCheckCallContext(fcinfo, NULL))
4690 		elog(ERROR, "aggregate function called in non-aggregate context");
4691 
4692 	sstate = PG_GETARG_BYTEA_PP(0);
4693 
4694 	/*
4695 	 * Copy the bytea into a StringInfo so that we can "receive" it using the
4696 	 * standard recv-function infrastructure.
4697 	 */
4698 	initStringInfo(&buf);
4699 	appendBinaryStringInfo(&buf,
4700 						   VARDATA_ANY(sstate), VARSIZE_ANY_EXHDR(sstate));
4701 
4702 	result = makePolyNumAggStateCurrentContext(false);
4703 
4704 	/* N */
4705 	result->N = pq_getmsgint64(&buf);
4706 
4707 	/* sumX */
4708 	temp = DirectFunctionCall3(numeric_recv,
4709 							   PointerGetDatum(&buf),
4710 							   ObjectIdGetDatum(InvalidOid),
4711 							   Int32GetDatum(-1));
4712 	init_var_from_num(DatumGetNumeric(temp), &num);
4713 #ifdef HAVE_INT128
4714 	numericvar_to_int128(&num, &result->sumX);
4715 #else
4716 	accum_sum_add(&result->sumX, &num);
4717 #endif
4718 
4719 	pq_getmsgend(&buf);
4720 	pfree(buf.data);
4721 
4722 	PG_RETURN_POINTER(result);
4723 }
4724 
4725 /*
4726  * Inverse transition functions to go with the above.
4727  */
4728 
4729 Datum
int2_accum_inv(PG_FUNCTION_ARGS)4730 int2_accum_inv(PG_FUNCTION_ARGS)
4731 {
4732 	PolyNumAggState *state;
4733 
4734 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
4735 
4736 	/* Should not get here with no state */
4737 	if (state == NULL)
4738 		elog(ERROR, "int2_accum_inv called with NULL state");
4739 
4740 	if (!PG_ARGISNULL(1))
4741 	{
4742 #ifdef HAVE_INT128
4743 		do_int128_discard(state, (int128) PG_GETARG_INT16(1));
4744 #else
4745 		Numeric		newval;
4746 
4747 		newval = DatumGetNumeric(DirectFunctionCall1(int2_numeric,
4748 													 PG_GETARG_DATUM(1)));
4749 
4750 		/* Should never fail, all inputs have dscale 0 */
4751 		if (!do_numeric_discard(state, newval))
4752 			elog(ERROR, "do_numeric_discard failed unexpectedly");
4753 #endif
4754 	}
4755 
4756 	PG_RETURN_POINTER(state);
4757 }
4758 
4759 Datum
int4_accum_inv(PG_FUNCTION_ARGS)4760 int4_accum_inv(PG_FUNCTION_ARGS)
4761 {
4762 	PolyNumAggState *state;
4763 
4764 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
4765 
4766 	/* Should not get here with no state */
4767 	if (state == NULL)
4768 		elog(ERROR, "int4_accum_inv called with NULL state");
4769 
4770 	if (!PG_ARGISNULL(1))
4771 	{
4772 #ifdef HAVE_INT128
4773 		do_int128_discard(state, (int128) PG_GETARG_INT32(1));
4774 #else
4775 		Numeric		newval;
4776 
4777 		newval = DatumGetNumeric(DirectFunctionCall1(int4_numeric,
4778 													 PG_GETARG_DATUM(1)));
4779 
4780 		/* Should never fail, all inputs have dscale 0 */
4781 		if (!do_numeric_discard(state, newval))
4782 			elog(ERROR, "do_numeric_discard failed unexpectedly");
4783 #endif
4784 	}
4785 
4786 	PG_RETURN_POINTER(state);
4787 }
4788 
4789 Datum
int8_accum_inv(PG_FUNCTION_ARGS)4790 int8_accum_inv(PG_FUNCTION_ARGS)
4791 {
4792 	NumericAggState *state;
4793 
4794 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
4795 
4796 	/* Should not get here with no state */
4797 	if (state == NULL)
4798 		elog(ERROR, "int8_accum_inv called with NULL state");
4799 
4800 	if (!PG_ARGISNULL(1))
4801 	{
4802 		Numeric		newval;
4803 
4804 		newval = DatumGetNumeric(DirectFunctionCall1(int8_numeric,
4805 													 PG_GETARG_DATUM(1)));
4806 
4807 		/* Should never fail, all inputs have dscale 0 */
4808 		if (!do_numeric_discard(state, newval))
4809 			elog(ERROR, "do_numeric_discard failed unexpectedly");
4810 	}
4811 
4812 	PG_RETURN_POINTER(state);
4813 }
4814 
4815 Datum
int8_avg_accum_inv(PG_FUNCTION_ARGS)4816 int8_avg_accum_inv(PG_FUNCTION_ARGS)
4817 {
4818 	PolyNumAggState *state;
4819 
4820 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
4821 
4822 	/* Should not get here with no state */
4823 	if (state == NULL)
4824 		elog(ERROR, "int8_avg_accum_inv called with NULL state");
4825 
4826 	if (!PG_ARGISNULL(1))
4827 	{
4828 #ifdef HAVE_INT128
4829 		do_int128_discard(state, (int128) PG_GETARG_INT64(1));
4830 #else
4831 		Numeric		newval;
4832 
4833 		newval = DatumGetNumeric(DirectFunctionCall1(int8_numeric,
4834 													 PG_GETARG_DATUM(1)));
4835 
4836 		/* Should never fail, all inputs have dscale 0 */
4837 		if (!do_numeric_discard(state, newval))
4838 			elog(ERROR, "do_numeric_discard failed unexpectedly");
4839 #endif
4840 	}
4841 
4842 	PG_RETURN_POINTER(state);
4843 }
4844 
4845 Datum
numeric_poly_sum(PG_FUNCTION_ARGS)4846 numeric_poly_sum(PG_FUNCTION_ARGS)
4847 {
4848 #ifdef HAVE_INT128
4849 	PolyNumAggState *state;
4850 	Numeric		res;
4851 	NumericVar	result;
4852 
4853 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
4854 
4855 	/* If there were no non-null inputs, return NULL */
4856 	if (state == NULL || state->N == 0)
4857 		PG_RETURN_NULL();
4858 
4859 	init_var(&result);
4860 
4861 	int128_to_numericvar(state->sumX, &result);
4862 
4863 	res = make_result(&result);
4864 
4865 	free_var(&result);
4866 
4867 	PG_RETURN_NUMERIC(res);
4868 #else
4869 	return numeric_sum(fcinfo);
4870 #endif
4871 }
4872 
4873 Datum
numeric_poly_avg(PG_FUNCTION_ARGS)4874 numeric_poly_avg(PG_FUNCTION_ARGS)
4875 {
4876 #ifdef HAVE_INT128
4877 	PolyNumAggState *state;
4878 	NumericVar	result;
4879 	Datum		countd,
4880 				sumd;
4881 
4882 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
4883 
4884 	/* If there were no non-null inputs, return NULL */
4885 	if (state == NULL || state->N == 0)
4886 		PG_RETURN_NULL();
4887 
4888 	init_var(&result);
4889 
4890 	int128_to_numericvar(state->sumX, &result);
4891 
4892 	countd = DirectFunctionCall1(int8_numeric,
4893 								 Int64GetDatumFast(state->N));
4894 	sumd = NumericGetDatum(make_result(&result));
4895 
4896 	free_var(&result);
4897 
4898 	PG_RETURN_DATUM(DirectFunctionCall2(numeric_div, sumd, countd));
4899 #else
4900 	return numeric_avg(fcinfo);
4901 #endif
4902 }
4903 
4904 Datum
numeric_avg(PG_FUNCTION_ARGS)4905 numeric_avg(PG_FUNCTION_ARGS)
4906 {
4907 	NumericAggState *state;
4908 	Datum		N_datum;
4909 	Datum		sumX_datum;
4910 	NumericVar	sumX_var;
4911 
4912 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
4913 
4914 	/* If there were no non-null inputs, return NULL */
4915 	if (state == NULL || (state->N + state->NaNcount) == 0)
4916 		PG_RETURN_NULL();
4917 
4918 	if (state->NaNcount > 0)	/* there was at least one NaN input */
4919 		PG_RETURN_NUMERIC(make_result(&const_nan));
4920 
4921 	N_datum = DirectFunctionCall1(int8_numeric, Int64GetDatum(state->N));
4922 
4923 	init_var(&sumX_var);
4924 	accum_sum_final(&state->sumX, &sumX_var);
4925 	sumX_datum = NumericGetDatum(make_result(&sumX_var));
4926 	free_var(&sumX_var);
4927 
4928 	PG_RETURN_DATUM(DirectFunctionCall2(numeric_div, sumX_datum, N_datum));
4929 }
4930 
4931 Datum
numeric_sum(PG_FUNCTION_ARGS)4932 numeric_sum(PG_FUNCTION_ARGS)
4933 {
4934 	NumericAggState *state;
4935 	NumericVar	sumX_var;
4936 	Numeric		result;
4937 
4938 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
4939 
4940 	/* If there were no non-null inputs, return NULL */
4941 	if (state == NULL || (state->N + state->NaNcount) == 0)
4942 		PG_RETURN_NULL();
4943 
4944 	if (state->NaNcount > 0)	/* there was at least one NaN input */
4945 		PG_RETURN_NUMERIC(make_result(&const_nan));
4946 
4947 	init_var(&sumX_var);
4948 	accum_sum_final(&state->sumX, &sumX_var);
4949 	result = make_result(&sumX_var);
4950 	free_var(&sumX_var);
4951 
4952 	PG_RETURN_NUMERIC(result);
4953 }
4954 
4955 /*
4956  * Workhorse routine for the standard deviance and variance
4957  * aggregates. 'state' is aggregate's transition state.
4958  * 'variance' specifies whether we should calculate the
4959  * variance or the standard deviation. 'sample' indicates whether the
4960  * caller is interested in the sample or the population
4961  * variance/stddev.
4962  *
4963  * If appropriate variance statistic is undefined for the input,
4964  * *is_null is set to true and NULL is returned.
4965  */
4966 static Numeric
numeric_stddev_internal(NumericAggState * state,bool variance,bool sample,bool * is_null)4967 numeric_stddev_internal(NumericAggState *state,
4968 						bool variance, bool sample,
4969 						bool *is_null)
4970 {
4971 	Numeric		res;
4972 	NumericVar	vN,
4973 				vsumX,
4974 				vsumX2,
4975 				vNminus1;
4976 	const NumericVar *comp;
4977 	int			rscale;
4978 
4979 	/* Deal with empty input and NaN-input cases */
4980 	if (state == NULL || (state->N + state->NaNcount) == 0)
4981 	{
4982 		*is_null = true;
4983 		return NULL;
4984 	}
4985 
4986 	*is_null = false;
4987 
4988 	if (state->NaNcount > 0)
4989 		return make_result(&const_nan);
4990 
4991 	init_var(&vN);
4992 	init_var(&vsumX);
4993 	init_var(&vsumX2);
4994 
4995 	int64_to_numericvar(state->N, &vN);
4996 	accum_sum_final(&(state->sumX), &vsumX);
4997 	accum_sum_final(&(state->sumX2), &vsumX2);
4998 
4999 	/*
5000 	 * Sample stddev and variance are undefined when N <= 1; population stddev
5001 	 * is undefined when N == 0. Return NULL in either case.
5002 	 */
5003 	if (sample)
5004 		comp = &const_one;
5005 	else
5006 		comp = &const_zero;
5007 
5008 	if (cmp_var(&vN, comp) <= 0)
5009 	{
5010 		*is_null = true;
5011 		return NULL;
5012 	}
5013 
5014 	init_var(&vNminus1);
5015 	sub_var(&vN, &const_one, &vNminus1);
5016 
5017 	/* compute rscale for mul_var calls */
5018 	rscale = vsumX.dscale * 2;
5019 
5020 	mul_var(&vsumX, &vsumX, &vsumX, rscale);	/* vsumX = sumX * sumX */
5021 	mul_var(&vN, &vsumX2, &vsumX2, rscale); /* vsumX2 = N * sumX2 */
5022 	sub_var(&vsumX2, &vsumX, &vsumX2);	/* N * sumX2 - sumX * sumX */
5023 
5024 	if (cmp_var(&vsumX2, &const_zero) <= 0)
5025 	{
5026 		/* Watch out for roundoff error producing a negative numerator */
5027 		res = make_result(&const_zero);
5028 	}
5029 	else
5030 	{
5031 		if (sample)
5032 			mul_var(&vN, &vNminus1, &vNminus1, 0);	/* N * (N - 1) */
5033 		else
5034 			mul_var(&vN, &vN, &vNminus1, 0);	/* N * N */
5035 		rscale = select_div_scale(&vsumX2, &vNminus1);
5036 		div_var(&vsumX2, &vNminus1, &vsumX, rscale, true);	/* variance */
5037 		if (!variance)
5038 			sqrt_var(&vsumX, &vsumX, rscale);	/* stddev */
5039 
5040 		res = make_result(&vsumX);
5041 	}
5042 
5043 	free_var(&vNminus1);
5044 	free_var(&vsumX);
5045 	free_var(&vsumX2);
5046 
5047 	return res;
5048 }
5049 
5050 Datum
numeric_var_samp(PG_FUNCTION_ARGS)5051 numeric_var_samp(PG_FUNCTION_ARGS)
5052 {
5053 	NumericAggState *state;
5054 	Numeric		res;
5055 	bool		is_null;
5056 
5057 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
5058 
5059 	res = numeric_stddev_internal(state, true, true, &is_null);
5060 
5061 	if (is_null)
5062 		PG_RETURN_NULL();
5063 	else
5064 		PG_RETURN_NUMERIC(res);
5065 }
5066 
5067 Datum
numeric_stddev_samp(PG_FUNCTION_ARGS)5068 numeric_stddev_samp(PG_FUNCTION_ARGS)
5069 {
5070 	NumericAggState *state;
5071 	Numeric		res;
5072 	bool		is_null;
5073 
5074 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
5075 
5076 	res = numeric_stddev_internal(state, false, true, &is_null);
5077 
5078 	if (is_null)
5079 		PG_RETURN_NULL();
5080 	else
5081 		PG_RETURN_NUMERIC(res);
5082 }
5083 
5084 Datum
numeric_var_pop(PG_FUNCTION_ARGS)5085 numeric_var_pop(PG_FUNCTION_ARGS)
5086 {
5087 	NumericAggState *state;
5088 	Numeric		res;
5089 	bool		is_null;
5090 
5091 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
5092 
5093 	res = numeric_stddev_internal(state, true, false, &is_null);
5094 
5095 	if (is_null)
5096 		PG_RETURN_NULL();
5097 	else
5098 		PG_RETURN_NUMERIC(res);
5099 }
5100 
5101 Datum
numeric_stddev_pop(PG_FUNCTION_ARGS)5102 numeric_stddev_pop(PG_FUNCTION_ARGS)
5103 {
5104 	NumericAggState *state;
5105 	Numeric		res;
5106 	bool		is_null;
5107 
5108 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
5109 
5110 	res = numeric_stddev_internal(state, false, false, &is_null);
5111 
5112 	if (is_null)
5113 		PG_RETURN_NULL();
5114 	else
5115 		PG_RETURN_NUMERIC(res);
5116 }
5117 
5118 #ifdef HAVE_INT128
5119 static Numeric
numeric_poly_stddev_internal(Int128AggState * state,bool variance,bool sample,bool * is_null)5120 numeric_poly_stddev_internal(Int128AggState *state,
5121 							 bool variance, bool sample,
5122 							 bool *is_null)
5123 {
5124 	NumericAggState numstate;
5125 	Numeric		res;
5126 
5127 	/* Initialize an empty agg state */
5128 	memset(&numstate, 0, sizeof(NumericAggState));
5129 
5130 	if (state)
5131 	{
5132 		NumericVar	tmp_var;
5133 
5134 		numstate.N = state->N;
5135 
5136 		init_var(&tmp_var);
5137 
5138 		int128_to_numericvar(state->sumX, &tmp_var);
5139 		accum_sum_add(&numstate.sumX, &tmp_var);
5140 
5141 		int128_to_numericvar(state->sumX2, &tmp_var);
5142 		accum_sum_add(&numstate.sumX2, &tmp_var);
5143 
5144 		free_var(&tmp_var);
5145 	}
5146 
5147 	res = numeric_stddev_internal(&numstate, variance, sample, is_null);
5148 
5149 	if (numstate.sumX.ndigits > 0)
5150 	{
5151 		pfree(numstate.sumX.pos_digits);
5152 		pfree(numstate.sumX.neg_digits);
5153 	}
5154 	if (numstate.sumX2.ndigits > 0)
5155 	{
5156 		pfree(numstate.sumX2.pos_digits);
5157 		pfree(numstate.sumX2.neg_digits);
5158 	}
5159 
5160 	return res;
5161 }
5162 #endif
5163 
5164 Datum
numeric_poly_var_samp(PG_FUNCTION_ARGS)5165 numeric_poly_var_samp(PG_FUNCTION_ARGS)
5166 {
5167 #ifdef HAVE_INT128
5168 	PolyNumAggState *state;
5169 	Numeric		res;
5170 	bool		is_null;
5171 
5172 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
5173 
5174 	res = numeric_poly_stddev_internal(state, true, true, &is_null);
5175 
5176 	if (is_null)
5177 		PG_RETURN_NULL();
5178 	else
5179 		PG_RETURN_NUMERIC(res);
5180 #else
5181 	return numeric_var_samp(fcinfo);
5182 #endif
5183 }
5184 
5185 Datum
numeric_poly_stddev_samp(PG_FUNCTION_ARGS)5186 numeric_poly_stddev_samp(PG_FUNCTION_ARGS)
5187 {
5188 #ifdef HAVE_INT128
5189 	PolyNumAggState *state;
5190 	Numeric		res;
5191 	bool		is_null;
5192 
5193 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
5194 
5195 	res = numeric_poly_stddev_internal(state, false, true, &is_null);
5196 
5197 	if (is_null)
5198 		PG_RETURN_NULL();
5199 	else
5200 		PG_RETURN_NUMERIC(res);
5201 #else
5202 	return numeric_stddev_samp(fcinfo);
5203 #endif
5204 }
5205 
5206 Datum
numeric_poly_var_pop(PG_FUNCTION_ARGS)5207 numeric_poly_var_pop(PG_FUNCTION_ARGS)
5208 {
5209 #ifdef HAVE_INT128
5210 	PolyNumAggState *state;
5211 	Numeric		res;
5212 	bool		is_null;
5213 
5214 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
5215 
5216 	res = numeric_poly_stddev_internal(state, true, false, &is_null);
5217 
5218 	if (is_null)
5219 		PG_RETURN_NULL();
5220 	else
5221 		PG_RETURN_NUMERIC(res);
5222 #else
5223 	return numeric_var_pop(fcinfo);
5224 #endif
5225 }
5226 
5227 Datum
numeric_poly_stddev_pop(PG_FUNCTION_ARGS)5228 numeric_poly_stddev_pop(PG_FUNCTION_ARGS)
5229 {
5230 #ifdef HAVE_INT128
5231 	PolyNumAggState *state;
5232 	Numeric		res;
5233 	bool		is_null;
5234 
5235 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
5236 
5237 	res = numeric_poly_stddev_internal(state, false, false, &is_null);
5238 
5239 	if (is_null)
5240 		PG_RETURN_NULL();
5241 	else
5242 		PG_RETURN_NUMERIC(res);
5243 #else
5244 	return numeric_stddev_pop(fcinfo);
5245 #endif
5246 }
5247 
5248 /*
5249  * SUM transition functions for integer datatypes.
5250  *
5251  * To avoid overflow, we use accumulators wider than the input datatype.
5252  * A Numeric accumulator is needed for int8 input; for int4 and int2
5253  * inputs, we use int8 accumulators which should be sufficient for practical
5254  * purposes.  (The latter two therefore don't really belong in this file,
5255  * but we keep them here anyway.)
5256  *
5257  * Because SQL defines the SUM() of no values to be NULL, not zero,
5258  * the initial condition of the transition data value needs to be NULL. This
5259  * means we can't rely on ExecAgg to automatically insert the first non-null
5260  * data value into the transition data: it doesn't know how to do the type
5261  * conversion.  The upshot is that these routines have to be marked non-strict
5262  * and handle substitution of the first non-null input themselves.
5263  *
5264  * Note: these functions are used only in plain aggregation mode.
5265  * In moving-aggregate mode, we use intX_avg_accum and intX_avg_accum_inv.
5266  */
5267 
5268 Datum
int2_sum(PG_FUNCTION_ARGS)5269 int2_sum(PG_FUNCTION_ARGS)
5270 {
5271 	int64		newval;
5272 
5273 	if (PG_ARGISNULL(0))
5274 	{
5275 		/* No non-null input seen so far... */
5276 		if (PG_ARGISNULL(1))
5277 			PG_RETURN_NULL();	/* still no non-null */
5278 		/* This is the first non-null input. */
5279 		newval = (int64) PG_GETARG_INT16(1);
5280 		PG_RETURN_INT64(newval);
5281 	}
5282 
5283 	/*
5284 	 * If we're invoked as an aggregate, we can cheat and modify our first
5285 	 * parameter in-place to avoid palloc overhead. If not, we need to return
5286 	 * the new value of the transition variable. (If int8 is pass-by-value,
5287 	 * then of course this is useless as well as incorrect, so just ifdef it
5288 	 * out.)
5289 	 */
5290 #ifndef USE_FLOAT8_BYVAL		/* controls int8 too */
5291 	if (AggCheckCallContext(fcinfo, NULL))
5292 	{
5293 		int64	   *oldsum = (int64 *) PG_GETARG_POINTER(0);
5294 
5295 		/* Leave the running sum unchanged in the new input is null */
5296 		if (!PG_ARGISNULL(1))
5297 			*oldsum = *oldsum + (int64) PG_GETARG_INT16(1);
5298 
5299 		PG_RETURN_POINTER(oldsum);
5300 	}
5301 	else
5302 #endif
5303 	{
5304 		int64		oldsum = PG_GETARG_INT64(0);
5305 
5306 		/* Leave sum unchanged if new input is null. */
5307 		if (PG_ARGISNULL(1))
5308 			PG_RETURN_INT64(oldsum);
5309 
5310 		/* OK to do the addition. */
5311 		newval = oldsum + (int64) PG_GETARG_INT16(1);
5312 
5313 		PG_RETURN_INT64(newval);
5314 	}
5315 }
5316 
5317 Datum
int4_sum(PG_FUNCTION_ARGS)5318 int4_sum(PG_FUNCTION_ARGS)
5319 {
5320 	int64		newval;
5321 
5322 	if (PG_ARGISNULL(0))
5323 	{
5324 		/* No non-null input seen so far... */
5325 		if (PG_ARGISNULL(1))
5326 			PG_RETURN_NULL();	/* still no non-null */
5327 		/* This is the first non-null input. */
5328 		newval = (int64) PG_GETARG_INT32(1);
5329 		PG_RETURN_INT64(newval);
5330 	}
5331 
5332 	/*
5333 	 * If we're invoked as an aggregate, we can cheat and modify our first
5334 	 * parameter in-place to avoid palloc overhead. If not, we need to return
5335 	 * the new value of the transition variable. (If int8 is pass-by-value,
5336 	 * then of course this is useless as well as incorrect, so just ifdef it
5337 	 * out.)
5338 	 */
5339 #ifndef USE_FLOAT8_BYVAL		/* controls int8 too */
5340 	if (AggCheckCallContext(fcinfo, NULL))
5341 	{
5342 		int64	   *oldsum = (int64 *) PG_GETARG_POINTER(0);
5343 
5344 		/* Leave the running sum unchanged in the new input is null */
5345 		if (!PG_ARGISNULL(1))
5346 			*oldsum = *oldsum + (int64) PG_GETARG_INT32(1);
5347 
5348 		PG_RETURN_POINTER(oldsum);
5349 	}
5350 	else
5351 #endif
5352 	{
5353 		int64		oldsum = PG_GETARG_INT64(0);
5354 
5355 		/* Leave sum unchanged if new input is null. */
5356 		if (PG_ARGISNULL(1))
5357 			PG_RETURN_INT64(oldsum);
5358 
5359 		/* OK to do the addition. */
5360 		newval = oldsum + (int64) PG_GETARG_INT32(1);
5361 
5362 		PG_RETURN_INT64(newval);
5363 	}
5364 }
5365 
5366 /*
5367  * Note: this function is obsolete, it's no longer used for SUM(int8).
5368  */
5369 Datum
int8_sum(PG_FUNCTION_ARGS)5370 int8_sum(PG_FUNCTION_ARGS)
5371 {
5372 	Numeric		oldsum;
5373 	Datum		newval;
5374 
5375 	if (PG_ARGISNULL(0))
5376 	{
5377 		/* No non-null input seen so far... */
5378 		if (PG_ARGISNULL(1))
5379 			PG_RETURN_NULL();	/* still no non-null */
5380 		/* This is the first non-null input. */
5381 		newval = DirectFunctionCall1(int8_numeric, PG_GETARG_DATUM(1));
5382 		PG_RETURN_DATUM(newval);
5383 	}
5384 
5385 	/*
5386 	 * Note that we cannot special-case the aggregate case here, as we do for
5387 	 * int2_sum and int4_sum: numeric is of variable size, so we cannot modify
5388 	 * our first parameter in-place.
5389 	 */
5390 
5391 	oldsum = PG_GETARG_NUMERIC(0);
5392 
5393 	/* Leave sum unchanged if new input is null. */
5394 	if (PG_ARGISNULL(1))
5395 		PG_RETURN_NUMERIC(oldsum);
5396 
5397 	/* OK to do the addition. */
5398 	newval = DirectFunctionCall1(int8_numeric, PG_GETARG_DATUM(1));
5399 
5400 	PG_RETURN_DATUM(DirectFunctionCall2(numeric_add,
5401 										NumericGetDatum(oldsum), newval));
5402 }
5403 
5404 
5405 /*
5406  * Routines for avg(int2) and avg(int4).  The transition datatype
5407  * is a two-element int8 array, holding count and sum.
5408  *
5409  * These functions are also used for sum(int2) and sum(int4) when
5410  * operating in moving-aggregate mode, since for correct inverse transitions
5411  * we need to count the inputs.
5412  */
5413 
/*
 * Overlay for the data area of the 2-element int8 transition array used by
 * the avg/sum(int2/int4) functions.  Field order must match the array's
 * element order: count first, then sum.
 */
typedef struct Int8TransTypeData
{
	int64		count;			/* inputs accumulated (++ in *_avg_accum, --
								 * in *_avg_accum_inv) */
	int64		sum;			/* running total of those inputs */
} Int8TransTypeData;
5419 
5420 Datum
int2_avg_accum(PG_FUNCTION_ARGS)5421 int2_avg_accum(PG_FUNCTION_ARGS)
5422 {
5423 	ArrayType  *transarray;
5424 	int16		newval = PG_GETARG_INT16(1);
5425 	Int8TransTypeData *transdata;
5426 
5427 	/*
5428 	 * If we're invoked as an aggregate, we can cheat and modify our first
5429 	 * parameter in-place to reduce palloc overhead. Otherwise we need to make
5430 	 * a copy of it before scribbling on it.
5431 	 */
5432 	if (AggCheckCallContext(fcinfo, NULL))
5433 		transarray = PG_GETARG_ARRAYTYPE_P(0);
5434 	else
5435 		transarray = PG_GETARG_ARRAYTYPE_P_COPY(0);
5436 
5437 	if (ARR_HASNULL(transarray) ||
5438 		ARR_SIZE(transarray) != ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData))
5439 		elog(ERROR, "expected 2-element int8 array");
5440 
5441 	transdata = (Int8TransTypeData *) ARR_DATA_PTR(transarray);
5442 	transdata->count++;
5443 	transdata->sum += newval;
5444 
5445 	PG_RETURN_ARRAYTYPE_P(transarray);
5446 }
5447 
5448 Datum
int4_avg_accum(PG_FUNCTION_ARGS)5449 int4_avg_accum(PG_FUNCTION_ARGS)
5450 {
5451 	ArrayType  *transarray;
5452 	int32		newval = PG_GETARG_INT32(1);
5453 	Int8TransTypeData *transdata;
5454 
5455 	/*
5456 	 * If we're invoked as an aggregate, we can cheat and modify our first
5457 	 * parameter in-place to reduce palloc overhead. Otherwise we need to make
5458 	 * a copy of it before scribbling on it.
5459 	 */
5460 	if (AggCheckCallContext(fcinfo, NULL))
5461 		transarray = PG_GETARG_ARRAYTYPE_P(0);
5462 	else
5463 		transarray = PG_GETARG_ARRAYTYPE_P_COPY(0);
5464 
5465 	if (ARR_HASNULL(transarray) ||
5466 		ARR_SIZE(transarray) != ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData))
5467 		elog(ERROR, "expected 2-element int8 array");
5468 
5469 	transdata = (Int8TransTypeData *) ARR_DATA_PTR(transarray);
5470 	transdata->count++;
5471 	transdata->sum += newval;
5472 
5473 	PG_RETURN_ARRAYTYPE_P(transarray);
5474 }
5475 
5476 Datum
int4_avg_combine(PG_FUNCTION_ARGS)5477 int4_avg_combine(PG_FUNCTION_ARGS)
5478 {
5479 	ArrayType  *transarray1;
5480 	ArrayType  *transarray2;
5481 	Int8TransTypeData *state1;
5482 	Int8TransTypeData *state2;
5483 
5484 	if (!AggCheckCallContext(fcinfo, NULL))
5485 		elog(ERROR, "aggregate function called in non-aggregate context");
5486 
5487 	transarray1 = PG_GETARG_ARRAYTYPE_P(0);
5488 	transarray2 = PG_GETARG_ARRAYTYPE_P(1);
5489 
5490 	if (ARR_HASNULL(transarray1) ||
5491 		ARR_SIZE(transarray1) != ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData))
5492 		elog(ERROR, "expected 2-element int8 array");
5493 
5494 	if (ARR_HASNULL(transarray2) ||
5495 		ARR_SIZE(transarray2) != ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData))
5496 		elog(ERROR, "expected 2-element int8 array");
5497 
5498 	state1 = (Int8TransTypeData *) ARR_DATA_PTR(transarray1);
5499 	state2 = (Int8TransTypeData *) ARR_DATA_PTR(transarray2);
5500 
5501 	state1->count += state2->count;
5502 	state1->sum += state2->sum;
5503 
5504 	PG_RETURN_ARRAYTYPE_P(transarray1);
5505 }
5506 
5507 Datum
int2_avg_accum_inv(PG_FUNCTION_ARGS)5508 int2_avg_accum_inv(PG_FUNCTION_ARGS)
5509 {
5510 	ArrayType  *transarray;
5511 	int16		newval = PG_GETARG_INT16(1);
5512 	Int8TransTypeData *transdata;
5513 
5514 	/*
5515 	 * If we're invoked as an aggregate, we can cheat and modify our first
5516 	 * parameter in-place to reduce palloc overhead. Otherwise we need to make
5517 	 * a copy of it before scribbling on it.
5518 	 */
5519 	if (AggCheckCallContext(fcinfo, NULL))
5520 		transarray = PG_GETARG_ARRAYTYPE_P(0);
5521 	else
5522 		transarray = PG_GETARG_ARRAYTYPE_P_COPY(0);
5523 
5524 	if (ARR_HASNULL(transarray) ||
5525 		ARR_SIZE(transarray) != ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData))
5526 		elog(ERROR, "expected 2-element int8 array");
5527 
5528 	transdata = (Int8TransTypeData *) ARR_DATA_PTR(transarray);
5529 	transdata->count--;
5530 	transdata->sum -= newval;
5531 
5532 	PG_RETURN_ARRAYTYPE_P(transarray);
5533 }
5534 
5535 Datum
int4_avg_accum_inv(PG_FUNCTION_ARGS)5536 int4_avg_accum_inv(PG_FUNCTION_ARGS)
5537 {
5538 	ArrayType  *transarray;
5539 	int32		newval = PG_GETARG_INT32(1);
5540 	Int8TransTypeData *transdata;
5541 
5542 	/*
5543 	 * If we're invoked as an aggregate, we can cheat and modify our first
5544 	 * parameter in-place to reduce palloc overhead. Otherwise we need to make
5545 	 * a copy of it before scribbling on it.
5546 	 */
5547 	if (AggCheckCallContext(fcinfo, NULL))
5548 		transarray = PG_GETARG_ARRAYTYPE_P(0);
5549 	else
5550 		transarray = PG_GETARG_ARRAYTYPE_P_COPY(0);
5551 
5552 	if (ARR_HASNULL(transarray) ||
5553 		ARR_SIZE(transarray) != ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData))
5554 		elog(ERROR, "expected 2-element int8 array");
5555 
5556 	transdata = (Int8TransTypeData *) ARR_DATA_PTR(transarray);
5557 	transdata->count--;
5558 	transdata->sum -= newval;
5559 
5560 	PG_RETURN_ARRAYTYPE_P(transarray);
5561 }
5562 
5563 Datum
int8_avg(PG_FUNCTION_ARGS)5564 int8_avg(PG_FUNCTION_ARGS)
5565 {
5566 	ArrayType  *transarray = PG_GETARG_ARRAYTYPE_P(0);
5567 	Int8TransTypeData *transdata;
5568 	Datum		countd,
5569 				sumd;
5570 
5571 	if (ARR_HASNULL(transarray) ||
5572 		ARR_SIZE(transarray) != ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData))
5573 		elog(ERROR, "expected 2-element int8 array");
5574 	transdata = (Int8TransTypeData *) ARR_DATA_PTR(transarray);
5575 
5576 	/* SQL defines AVG of no values to be NULL */
5577 	if (transdata->count == 0)
5578 		PG_RETURN_NULL();
5579 
5580 	countd = DirectFunctionCall1(int8_numeric,
5581 								 Int64GetDatumFast(transdata->count));
5582 	sumd = DirectFunctionCall1(int8_numeric,
5583 							   Int64GetDatumFast(transdata->sum));
5584 
5585 	PG_RETURN_DATUM(DirectFunctionCall2(numeric_div, sumd, countd));
5586 }
5587 
5588 /*
5589  * SUM(int2) and SUM(int4) both return int8, so we can use this
5590  * final function for both.
5591  */
5592 Datum
int2int4_sum(PG_FUNCTION_ARGS)5593 int2int4_sum(PG_FUNCTION_ARGS)
5594 {
5595 	ArrayType  *transarray = PG_GETARG_ARRAYTYPE_P(0);
5596 	Int8TransTypeData *transdata;
5597 
5598 	if (ARR_HASNULL(transarray) ||
5599 		ARR_SIZE(transarray) != ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData))
5600 		elog(ERROR, "expected 2-element int8 array");
5601 	transdata = (Int8TransTypeData *) ARR_DATA_PTR(transarray);
5602 
5603 	/* SQL defines SUM of no values to be NULL */
5604 	if (transdata->count == 0)
5605 		PG_RETURN_NULL();
5606 
5607 	PG_RETURN_DATUM(Int64GetDatumFast(transdata->sum));
5608 }
5609 
5610 
5611 /* ----------------------------------------------------------------------
5612  *
5613  * Debug support
5614  *
5615  * ----------------------------------------------------------------------
5616  */
5617 
5618 #ifdef NUMERIC_DEBUG
5619 
5620 /*
5621  * dump_numeric() - Dump a value in the db storage format for debugging
5622  */
5623 static void
dump_numeric(const char * str,Numeric num)5624 dump_numeric(const char *str, Numeric num)
5625 {
5626 	NumericDigit *digits = NUMERIC_DIGITS(num);
5627 	int			ndigits;
5628 	int			i;
5629 
5630 	ndigits = NUMERIC_NDIGITS(num);
5631 
5632 	printf("%s: NUMERIC w=%d d=%d ", str,
5633 		   NUMERIC_WEIGHT(num), NUMERIC_DSCALE(num));
5634 	switch (NUMERIC_SIGN(num))
5635 	{
5636 		case NUMERIC_POS:
5637 			printf("POS");
5638 			break;
5639 		case NUMERIC_NEG:
5640 			printf("NEG");
5641 			break;
5642 		case NUMERIC_NAN:
5643 			printf("NaN");
5644 			break;
5645 		default:
5646 			printf("SIGN=0x%x", NUMERIC_SIGN(num));
5647 			break;
5648 	}
5649 
5650 	for (i = 0; i < ndigits; i++)
5651 		printf(" %0*d", DEC_DIGITS, digits[i]);
5652 	printf("\n");
5653 }
5654 
5655 
5656 /*
5657  * dump_var() - Dump a value in the variable format for debugging
5658  */
5659 static void
dump_var(const char * str,NumericVar * var)5660 dump_var(const char *str, NumericVar *var)
5661 {
5662 	int			i;
5663 
5664 	printf("%s: VAR w=%d d=%d ", str, var->weight, var->dscale);
5665 	switch (var->sign)
5666 	{
5667 		case NUMERIC_POS:
5668 			printf("POS");
5669 			break;
5670 		case NUMERIC_NEG:
5671 			printf("NEG");
5672 			break;
5673 		case NUMERIC_NAN:
5674 			printf("NaN");
5675 			break;
5676 		default:
5677 			printf("SIGN=0x%x", var->sign);
5678 			break;
5679 	}
5680 
5681 	for (i = 0; i < var->ndigits; i++)
5682 		printf(" %0*d", DEC_DIGITS, var->digits[i]);
5683 
5684 	printf("\n");
5685 }
5686 #endif							/* NUMERIC_DEBUG */
5687 
5688 
5689 /* ----------------------------------------------------------------------
5690  *
5691  * Local functions follow
5692  *
5693  * In general, these do not support NaNs --- callers must eliminate
5694  * the possibility of NaN first.  (make_result() is an exception.)
5695  *
5696  * ----------------------------------------------------------------------
5697  */
5698 
5699 
5700 /*
5701  * alloc_var() -
5702  *
5703  *	Allocate a digit buffer of ndigits digits (plus a spare digit for rounding)
5704  */
5705 static void
alloc_var(NumericVar * var,int ndigits)5706 alloc_var(NumericVar *var, int ndigits)
5707 {
5708 	digitbuf_free(var->buf);
5709 	var->buf = digitbuf_alloc(ndigits + 1);
5710 	var->buf[0] = 0;			/* spare digit for rounding */
5711 	var->digits = var->buf + 1;
5712 	var->ndigits = ndigits;
5713 }
5714 
5715 
5716 /*
5717  * free_var() -
5718  *
5719  *	Return the digit buffer of a variable to the free pool
5720  */
5721 static void
free_var(NumericVar * var)5722 free_var(NumericVar *var)
5723 {
5724 	digitbuf_free(var->buf);
5725 	var->buf = NULL;
5726 	var->digits = NULL;
5727 	var->sign = NUMERIC_NAN;
5728 }
5729 
5730 
5731 /*
5732  * zero_var() -
5733  *
5734  *	Set a variable to ZERO.
5735  *	Note: its dscale is not touched.
5736  */
5737 static void
zero_var(NumericVar * var)5738 zero_var(NumericVar *var)
5739 {
5740 	digitbuf_free(var->buf);
5741 	var->buf = NULL;
5742 	var->digits = NULL;
5743 	var->ndigits = 0;
5744 	var->weight = 0;			/* by convention; doesn't really matter */
5745 	var->sign = NUMERIC_POS;	/* anything but NAN... */
5746 }
5747 
5748 
5749 /*
5750  * set_var_from_str()
5751  *
5752  *	Parse a string and put the number into a variable
5753  *
5754  * This function does not handle leading or trailing spaces, and it doesn't
5755  * accept "NaN" either.  It returns the end+1 position so that caller can
5756  * check for trailing spaces/garbage if deemed necessary.
5757  *
5758  * cp is the place to actually start parsing; str is what to use in error
5759  * reports.  (Typically cp would be the same except advanced over spaces.)
5760  */
5761 static const char *
set_var_from_str(const char * str,const char * cp,NumericVar * dest)5762 set_var_from_str(const char *str, const char *cp, NumericVar *dest)
5763 {
5764 	bool		have_dp = false;
5765 	int			i;
5766 	unsigned char *decdigits;
5767 	int			sign = NUMERIC_POS;
5768 	int			dweight = -1;
5769 	int			ddigits;
5770 	int			dscale = 0;
5771 	int			weight;
5772 	int			ndigits;
5773 	int			offset;
5774 	NumericDigit *digits;
5775 
5776 	/*
5777 	 * We first parse the string to extract decimal digits and determine the
5778 	 * correct decimal weight.  Then convert to NBASE representation.
5779 	 */
5780 	switch (*cp)
5781 	{
5782 		case '+':
5783 			sign = NUMERIC_POS;
5784 			cp++;
5785 			break;
5786 
5787 		case '-':
5788 			sign = NUMERIC_NEG;
5789 			cp++;
5790 			break;
5791 	}
5792 
5793 	if (*cp == '.')
5794 	{
5795 		have_dp = true;
5796 		cp++;
5797 	}
5798 
5799 	if (!isdigit((unsigned char) *cp))
5800 		ereport(ERROR,
5801 				(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
5802 				 errmsg("invalid input syntax for type %s: \"%s\"",
5803 						"numeric", str)));
5804 
5805 	decdigits = (unsigned char *) palloc(strlen(cp) + DEC_DIGITS * 2);
5806 
5807 	/* leading padding for digit alignment later */
5808 	memset(decdigits, 0, DEC_DIGITS);
5809 	i = DEC_DIGITS;
5810 
5811 	while (*cp)
5812 	{
5813 		if (isdigit((unsigned char) *cp))
5814 		{
5815 			decdigits[i++] = *cp++ - '0';
5816 			if (!have_dp)
5817 				dweight++;
5818 			else
5819 				dscale++;
5820 		}
5821 		else if (*cp == '.')
5822 		{
5823 			if (have_dp)
5824 				ereport(ERROR,
5825 						(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
5826 						 errmsg("invalid input syntax for type %s: \"%s\"",
5827 								"numeric", str)));
5828 			have_dp = true;
5829 			cp++;
5830 		}
5831 		else
5832 			break;
5833 	}
5834 
5835 	ddigits = i - DEC_DIGITS;
5836 	/* trailing padding for digit alignment later */
5837 	memset(decdigits + i, 0, DEC_DIGITS - 1);
5838 
5839 	/* Handle exponent, if any */
5840 	if (*cp == 'e' || *cp == 'E')
5841 	{
5842 		long		exponent;
5843 		char	   *endptr;
5844 
5845 		cp++;
5846 		exponent = strtol(cp, &endptr, 10);
5847 		if (endptr == cp)
5848 			ereport(ERROR,
5849 					(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
5850 					 errmsg("invalid input syntax for type %s: \"%s\"",
5851 							"numeric", str)));
5852 		cp = endptr;
5853 
5854 		/*
5855 		 * At this point, dweight and dscale can't be more than about
5856 		 * INT_MAX/2 due to the MaxAllocSize limit on string length, so
5857 		 * constraining the exponent similarly should be enough to prevent
5858 		 * integer overflow in this function.  If the value is too large to
5859 		 * fit in storage format, make_result() will complain about it later;
5860 		 * for consistency use the same ereport errcode/text as make_result().
5861 		 */
5862 		if (exponent >= INT_MAX / 2 || exponent <= -(INT_MAX / 2))
5863 			ereport(ERROR,
5864 					(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
5865 					 errmsg("value overflows numeric format")));
5866 		dweight += (int) exponent;
5867 		dscale -= (int) exponent;
5868 		if (dscale < 0)
5869 			dscale = 0;
5870 	}
5871 
5872 	/*
5873 	 * Okay, convert pure-decimal representation to base NBASE.  First we need
5874 	 * to determine the converted weight and ndigits.  offset is the number of
5875 	 * decimal zeroes to insert before the first given digit to have a
5876 	 * correctly aligned first NBASE digit.
5877 	 */
5878 	if (dweight >= 0)
5879 		weight = (dweight + 1 + DEC_DIGITS - 1) / DEC_DIGITS - 1;
5880 	else
5881 		weight = -((-dweight - 1) / DEC_DIGITS + 1);
5882 	offset = (weight + 1) * DEC_DIGITS - (dweight + 1);
5883 	ndigits = (ddigits + offset + DEC_DIGITS - 1) / DEC_DIGITS;
5884 
5885 	alloc_var(dest, ndigits);
5886 	dest->sign = sign;
5887 	dest->weight = weight;
5888 	dest->dscale = dscale;
5889 
5890 	i = DEC_DIGITS - offset;
5891 	digits = dest->digits;
5892 
5893 	while (ndigits-- > 0)
5894 	{
5895 #if DEC_DIGITS == 4
5896 		*digits++ = ((decdigits[i] * 10 + decdigits[i + 1]) * 10 +
5897 					 decdigits[i + 2]) * 10 + decdigits[i + 3];
5898 #elif DEC_DIGITS == 2
5899 		*digits++ = decdigits[i] * 10 + decdigits[i + 1];
5900 #elif DEC_DIGITS == 1
5901 		*digits++ = decdigits[i];
5902 #else
5903 #error unsupported NBASE
5904 #endif
5905 		i += DEC_DIGITS;
5906 	}
5907 
5908 	pfree(decdigits);
5909 
5910 	/* Strip any leading/trailing zeroes, and normalize weight if zero */
5911 	strip_var(dest);
5912 
5913 	/* Return end+1 position for caller */
5914 	return cp;
5915 }
5916 
5917 
5918 /*
5919  * set_var_from_num() -
5920  *
5921  *	Convert the packed db format into a variable
5922  */
5923 static void
set_var_from_num(Numeric num,NumericVar * dest)5924 set_var_from_num(Numeric num, NumericVar *dest)
5925 {
5926 	int			ndigits;
5927 
5928 	ndigits = NUMERIC_NDIGITS(num);
5929 
5930 	alloc_var(dest, ndigits);
5931 
5932 	dest->weight = NUMERIC_WEIGHT(num);
5933 	dest->sign = NUMERIC_SIGN(num);
5934 	dest->dscale = NUMERIC_DSCALE(num);
5935 
5936 	memcpy(dest->digits, NUMERIC_DIGITS(num), ndigits * sizeof(NumericDigit));
5937 }
5938 
5939 
5940 /*
5941  * init_var_from_num() -
5942  *
5943  *	Initialize a variable from packed db format. The digits array is not
5944  *	copied, which saves some cycles when the resulting var is not modified.
5945  *	Also, there's no need to call free_var(), as long as you don't assign any
5946  *	other value to it (with set_var_* functions, or by using the var as the
5947  *	destination of a function like add_var())
5948  *
5949  *	CAUTION: Do not modify the digits buffer of a var initialized with this
5950  *	function, e.g by calling round_var() or trunc_var(), as the changes will
5951  *	propagate to the original Numeric! It's OK to use it as the destination
5952  *	argument of one of the calculational functions, though.
5953  */
5954 static void
init_var_from_num(Numeric num,NumericVar * dest)5955 init_var_from_num(Numeric num, NumericVar *dest)
5956 {
5957 	dest->ndigits = NUMERIC_NDIGITS(num);
5958 	dest->weight = NUMERIC_WEIGHT(num);
5959 	dest->sign = NUMERIC_SIGN(num);
5960 	dest->dscale = NUMERIC_DSCALE(num);
5961 	dest->digits = NUMERIC_DIGITS(num);
5962 	dest->buf = NULL;			/* digits array is not palloc'd */
5963 }
5964 
5965 
5966 /*
5967  * set_var_from_var() -
5968  *
5969  *	Copy one variable into another
5970  */
5971 static void
set_var_from_var(const NumericVar * value,NumericVar * dest)5972 set_var_from_var(const NumericVar *value, NumericVar *dest)
5973 {
5974 	NumericDigit *newbuf;
5975 
5976 	newbuf = digitbuf_alloc(value->ndigits + 1);
5977 	newbuf[0] = 0;				/* spare digit for rounding */
5978 	if (value->ndigits > 0)		/* else value->digits might be null */
5979 		memcpy(newbuf + 1, value->digits,
5980 			   value->ndigits * sizeof(NumericDigit));
5981 
5982 	digitbuf_free(dest->buf);
5983 
5984 	memmove(dest, value, sizeof(NumericVar));
5985 	dest->buf = newbuf;
5986 	dest->digits = newbuf + 1;
5987 }
5988 
5989 
5990 /*
5991  * get_str_from_var() -
5992  *
5993  *	Convert a var to text representation (guts of numeric_out).
5994  *	The var is displayed to the number of digits indicated by its dscale.
5995  *	Returns a palloc'd string.
5996  */
5997 static char *
get_str_from_var(const NumericVar * var)5998 get_str_from_var(const NumericVar *var)
5999 {
6000 	int			dscale;
6001 	char	   *str;
6002 	char	   *cp;
6003 	char	   *endcp;
6004 	int			i;
6005 	int			d;
6006 	NumericDigit dig;
6007 
6008 #if DEC_DIGITS > 1
6009 	NumericDigit d1;
6010 #endif
6011 
6012 	dscale = var->dscale;
6013 
6014 	/*
6015 	 * Allocate space for the result.
6016 	 *
6017 	 * i is set to the # of decimal digits before decimal point. dscale is the
6018 	 * # of decimal digits we will print after decimal point. We may generate
6019 	 * as many as DEC_DIGITS-1 excess digits at the end, and in addition we
6020 	 * need room for sign, decimal point, null terminator.
6021 	 */
6022 	i = (var->weight + 1) * DEC_DIGITS;
6023 	if (i <= 0)
6024 		i = 1;
6025 
6026 	str = palloc(i + dscale + DEC_DIGITS + 2);
6027 	cp = str;
6028 
6029 	/*
6030 	 * Output a dash for negative values
6031 	 */
6032 	if (var->sign == NUMERIC_NEG)
6033 		*cp++ = '-';
6034 
6035 	/*
6036 	 * Output all digits before the decimal point
6037 	 */
6038 	if (var->weight < 0)
6039 	{
6040 		d = var->weight + 1;
6041 		*cp++ = '0';
6042 	}
6043 	else
6044 	{
6045 		for (d = 0; d <= var->weight; d++)
6046 		{
6047 			dig = (d < var->ndigits) ? var->digits[d] : 0;
6048 			/* In the first digit, suppress extra leading decimal zeroes */
6049 #if DEC_DIGITS == 4
6050 			{
6051 				bool		putit = (d > 0);
6052 
6053 				d1 = dig / 1000;
6054 				dig -= d1 * 1000;
6055 				putit |= (d1 > 0);
6056 				if (putit)
6057 					*cp++ = d1 + '0';
6058 				d1 = dig / 100;
6059 				dig -= d1 * 100;
6060 				putit |= (d1 > 0);
6061 				if (putit)
6062 					*cp++ = d1 + '0';
6063 				d1 = dig / 10;
6064 				dig -= d1 * 10;
6065 				putit |= (d1 > 0);
6066 				if (putit)
6067 					*cp++ = d1 + '0';
6068 				*cp++ = dig + '0';
6069 			}
6070 #elif DEC_DIGITS == 2
6071 			d1 = dig / 10;
6072 			dig -= d1 * 10;
6073 			if (d1 > 0 || d > 0)
6074 				*cp++ = d1 + '0';
6075 			*cp++ = dig + '0';
6076 #elif DEC_DIGITS == 1
6077 			*cp++ = dig + '0';
6078 #else
6079 #error unsupported NBASE
6080 #endif
6081 		}
6082 	}
6083 
6084 	/*
6085 	 * If requested, output a decimal point and all the digits that follow it.
6086 	 * We initially put out a multiple of DEC_DIGITS digits, then truncate if
6087 	 * needed.
6088 	 */
6089 	if (dscale > 0)
6090 	{
6091 		*cp++ = '.';
6092 		endcp = cp + dscale;
6093 		for (i = 0; i < dscale; d++, i += DEC_DIGITS)
6094 		{
6095 			dig = (d >= 0 && d < var->ndigits) ? var->digits[d] : 0;
6096 #if DEC_DIGITS == 4
6097 			d1 = dig / 1000;
6098 			dig -= d1 * 1000;
6099 			*cp++ = d1 + '0';
6100 			d1 = dig / 100;
6101 			dig -= d1 * 100;
6102 			*cp++ = d1 + '0';
6103 			d1 = dig / 10;
6104 			dig -= d1 * 10;
6105 			*cp++ = d1 + '0';
6106 			*cp++ = dig + '0';
6107 #elif DEC_DIGITS == 2
6108 			d1 = dig / 10;
6109 			dig -= d1 * 10;
6110 			*cp++ = d1 + '0';
6111 			*cp++ = dig + '0';
6112 #elif DEC_DIGITS == 1
6113 			*cp++ = dig + '0';
6114 #else
6115 #error unsupported NBASE
6116 #endif
6117 		}
6118 		cp = endcp;
6119 	}
6120 
6121 	/*
6122 	 * terminate the string and return it
6123 	 */
6124 	*cp = '\0';
6125 	return str;
6126 }
6127 
6128 /*
6129  * get_str_from_var_sci() -
6130  *
6131  *	Convert a var to a normalised scientific notation text representation.
6132  *	This function does the heavy lifting for numeric_out_sci().
6133  *
6134  *	This notation has the general form a * 10^b, where a is known as the
6135  *	"significand" and b is known as the "exponent".
6136  *
6137  *	Because we can't do superscript in ASCII (and because we want to copy
6138  *	printf's behaviour) we display the exponent using E notation, with a
6139  *	minimum of two exponent digits.
6140  *
6141  *	For example, the value 1234 could be output as 1.2e+03.
6142  *
6143  *	We assume that the exponent can fit into an int32.
6144  *
6145  *	rscale is the number of decimal digits desired after the decimal point in
6146  *	the output, negative values will be treated as meaning zero.
6147  *
6148  *	Returns a palloc'd string.
6149  */
6150 static char *
get_str_from_var_sci(const NumericVar * var,int rscale)6151 get_str_from_var_sci(const NumericVar *var, int rscale)
6152 {
6153 	int32		exponent;
6154 	NumericVar	tmp_var;
6155 	size_t		len;
6156 	char	   *str;
6157 	char	   *sig_out;
6158 
6159 	if (rscale < 0)
6160 		rscale = 0;
6161 
6162 	/*
6163 	 * Determine the exponent of this number in normalised form.
6164 	 *
6165 	 * This is the exponent required to represent the number with only one
6166 	 * significant digit before the decimal place.
6167 	 */
6168 	if (var->ndigits > 0)
6169 	{
6170 		exponent = (var->weight + 1) * DEC_DIGITS;
6171 
6172 		/*
6173 		 * Compensate for leading decimal zeroes in the first numeric digit by
6174 		 * decrementing the exponent.
6175 		 */
6176 		exponent -= DEC_DIGITS - (int) log10(var->digits[0]);
6177 	}
6178 	else
6179 	{
6180 		/*
6181 		 * If var has no digits, then it must be zero.
6182 		 *
6183 		 * Zero doesn't technically have a meaningful exponent in normalised
6184 		 * notation, but we just display the exponent as zero for consistency
6185 		 * of output.
6186 		 */
6187 		exponent = 0;
6188 	}
6189 
6190 	/*
6191 	 * Divide var by 10^exponent to get the significand, rounding to rscale
6192 	 * decimal digits in the process.
6193 	 */
6194 	init_var(&tmp_var);
6195 
6196 	power_ten_int(exponent, &tmp_var);
6197 	div_var(var, &tmp_var, &tmp_var, rscale, true);
6198 	sig_out = get_str_from_var(&tmp_var);
6199 
6200 	free_var(&tmp_var);
6201 
6202 	/*
6203 	 * Allocate space for the result.
6204 	 *
6205 	 * In addition to the significand, we need room for the exponent
6206 	 * decoration ("e"), the sign of the exponent, up to 10 digits for the
6207 	 * exponent itself, and of course the null terminator.
6208 	 */
6209 	len = strlen(sig_out) + 13;
6210 	str = palloc(len);
6211 	snprintf(str, len, "%se%+03d", sig_out, exponent);
6212 
6213 	pfree(sig_out);
6214 
6215 	return str;
6216 }
6217 
6218 
6219 /*
6220  * make_result_opt_error() -
6221  *
6222  *	Create the packed db numeric format in palloc()'d memory from
6223  *	a variable.  If "*have_error" flag is provided, on error it's set to
6224  *	true, NULL returned.  This is helpful when caller need to handle errors
6225  *	by itself.
6226  */
6227 static Numeric
make_result_opt_error(const NumericVar * var,bool * have_error)6228 make_result_opt_error(const NumericVar *var, bool *have_error)
6229 {
6230 	Numeric		result;
6231 	NumericDigit *digits = var->digits;
6232 	int			weight = var->weight;
6233 	int			sign = var->sign;
6234 	int			n;
6235 	Size		len;
6236 
6237 	if (sign == NUMERIC_NAN)
6238 	{
6239 		result = (Numeric) palloc(NUMERIC_HDRSZ_SHORT);
6240 
6241 		SET_VARSIZE(result, NUMERIC_HDRSZ_SHORT);
6242 		result->choice.n_header = NUMERIC_NAN;
6243 		/* the header word is all we need */
6244 
6245 		dump_numeric("make_result()", result);
6246 		return result;
6247 	}
6248 
6249 	n = var->ndigits;
6250 
6251 	/* truncate leading zeroes */
6252 	while (n > 0 && *digits == 0)
6253 	{
6254 		digits++;
6255 		weight--;
6256 		n--;
6257 	}
6258 	/* truncate trailing zeroes */
6259 	while (n > 0 && digits[n - 1] == 0)
6260 		n--;
6261 
6262 	/* If zero result, force to weight=0 and positive sign */
6263 	if (n == 0)
6264 	{
6265 		weight = 0;
6266 		sign = NUMERIC_POS;
6267 	}
6268 
6269 	/* Build the result */
6270 	if (NUMERIC_CAN_BE_SHORT(var->dscale, weight))
6271 	{
6272 		len = NUMERIC_HDRSZ_SHORT + n * sizeof(NumericDigit);
6273 		result = (Numeric) palloc(len);
6274 		SET_VARSIZE(result, len);
6275 		result->choice.n_short.n_header =
6276 			(sign == NUMERIC_NEG ? (NUMERIC_SHORT | NUMERIC_SHORT_SIGN_MASK)
6277 			 : NUMERIC_SHORT)
6278 			| (var->dscale << NUMERIC_SHORT_DSCALE_SHIFT)
6279 			| (weight < 0 ? NUMERIC_SHORT_WEIGHT_SIGN_MASK : 0)
6280 			| (weight & NUMERIC_SHORT_WEIGHT_MASK);
6281 	}
6282 	else
6283 	{
6284 		len = NUMERIC_HDRSZ + n * sizeof(NumericDigit);
6285 		result = (Numeric) palloc(len);
6286 		SET_VARSIZE(result, len);
6287 		result->choice.n_long.n_sign_dscale =
6288 			sign | (var->dscale & NUMERIC_DSCALE_MASK);
6289 		result->choice.n_long.n_weight = weight;
6290 	}
6291 
6292 	Assert(NUMERIC_NDIGITS(result) == n);
6293 	if (n > 0)
6294 		memcpy(NUMERIC_DIGITS(result), digits, n * sizeof(NumericDigit));
6295 
6296 	/* Check for overflow of int16 fields */
6297 	if (NUMERIC_WEIGHT(result) != weight ||
6298 		NUMERIC_DSCALE(result) != var->dscale)
6299 	{
6300 		if (have_error)
6301 		{
6302 			*have_error = true;
6303 			return NULL;
6304 		}
6305 		else
6306 		{
6307 			ereport(ERROR,
6308 					(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
6309 					 errmsg("value overflows numeric format")));
6310 		}
6311 	}
6312 
6313 	dump_numeric("make_result()", result);
6314 	return result;
6315 }
6316 
6317 
6318 /*
6319  * make_result() -
6320  *
6321  *	An interface to make_result_opt_error() without "have_error" argument.
6322  */
6323 static Numeric
make_result(const NumericVar * var)6324 make_result(const NumericVar *var)
6325 {
6326 	return make_result_opt_error(var, NULL);
6327 }
6328 
6329 
6330 /*
6331  * apply_typmod() -
6332  *
6333  *	Do bounds checking and rounding according to the attributes
6334  *	typmod field.
6335  */
6336 static void
apply_typmod(NumericVar * var,int32 typmod)6337 apply_typmod(NumericVar *var, int32 typmod)
6338 {
6339 	int			precision;
6340 	int			scale;
6341 	int			maxdigits;
6342 	int			ddigits;
6343 	int			i;
6344 
6345 	/* Do nothing if we have a default typmod (-1) */
6346 	if (typmod < (int32) (VARHDRSZ))
6347 		return;
6348 
6349 	typmod -= VARHDRSZ;
6350 	precision = (typmod >> 16) & 0xffff;
6351 	scale = typmod & 0xffff;
6352 	maxdigits = precision - scale;
6353 
6354 	/* Round to target scale (and set var->dscale) */
6355 	round_var(var, scale);
6356 
6357 	/*
6358 	 * Check for overflow - note we can't do this before rounding, because
6359 	 * rounding could raise the weight.  Also note that the var's weight could
6360 	 * be inflated by leading zeroes, which will be stripped before storage
6361 	 * but perhaps might not have been yet. In any case, we must recognize a
6362 	 * true zero, whose weight doesn't mean anything.
6363 	 */
6364 	ddigits = (var->weight + 1) * DEC_DIGITS;
6365 	if (ddigits > maxdigits)
6366 	{
6367 		/* Determine true weight; and check for all-zero result */
6368 		for (i = 0; i < var->ndigits; i++)
6369 		{
6370 			NumericDigit dig = var->digits[i];
6371 
6372 			if (dig)
6373 			{
6374 				/* Adjust for any high-order decimal zero digits */
6375 #if DEC_DIGITS == 4
6376 				if (dig < 10)
6377 					ddigits -= 3;
6378 				else if (dig < 100)
6379 					ddigits -= 2;
6380 				else if (dig < 1000)
6381 					ddigits -= 1;
6382 #elif DEC_DIGITS == 2
6383 				if (dig < 10)
6384 					ddigits -= 1;
6385 #elif DEC_DIGITS == 1
6386 				/* no adjustment */
6387 #else
6388 #error unsupported NBASE
6389 #endif
6390 				if (ddigits > maxdigits)
6391 					ereport(ERROR,
6392 							(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
6393 							 errmsg("numeric field overflow"),
6394 							 errdetail("A field with precision %d, scale %d must round to an absolute value less than %s%d.",
6395 									   precision, scale,
6396 					/* Display 10^0 as 1 */
6397 									   maxdigits ? "10^" : "",
6398 									   maxdigits ? maxdigits : 1
6399 									   )));
6400 				break;
6401 			}
6402 			ddigits -= DEC_DIGITS;
6403 		}
6404 	}
6405 }
6406 
6407 /*
6408  * Convert numeric to int8, rounding if needed.
6409  *
6410  * If overflow, return false (no error is raised).  Return true if okay.
6411  */
6412 static bool
numericvar_to_int64(const NumericVar * var,int64 * result)6413 numericvar_to_int64(const NumericVar *var, int64 *result)
6414 {
6415 	NumericDigit *digits;
6416 	int			ndigits;
6417 	int			weight;
6418 	int			i;
6419 	int64		val;
6420 	bool		neg;
6421 	NumericVar	rounded;
6422 
6423 	/* Round to nearest integer */
6424 	init_var(&rounded);
6425 	set_var_from_var(var, &rounded);
6426 	round_var(&rounded, 0);
6427 
6428 	/* Check for zero input */
6429 	strip_var(&rounded);
6430 	ndigits = rounded.ndigits;
6431 	if (ndigits == 0)
6432 	{
6433 		*result = 0;
6434 		free_var(&rounded);
6435 		return true;
6436 	}
6437 
6438 	/*
6439 	 * For input like 10000000000, we must treat stripped digits as real. So
6440 	 * the loop assumes there are weight+1 digits before the decimal point.
6441 	 */
6442 	weight = rounded.weight;
6443 	Assert(weight >= 0 && ndigits <= weight + 1);
6444 
6445 	/*
6446 	 * Construct the result. To avoid issues with converting a value
6447 	 * corresponding to INT64_MIN (which can't be represented as a positive 64
6448 	 * bit two's complement integer), accumulate value as a negative number.
6449 	 */
6450 	digits = rounded.digits;
6451 	neg = (rounded.sign == NUMERIC_NEG);
6452 	val = -digits[0];
6453 	for (i = 1; i <= weight; i++)
6454 	{
6455 		if (unlikely(pg_mul_s64_overflow(val, NBASE, &val)))
6456 		{
6457 			free_var(&rounded);
6458 			return false;
6459 		}
6460 
6461 		if (i < ndigits)
6462 		{
6463 			if (unlikely(pg_sub_s64_overflow(val, digits[i], &val)))
6464 			{
6465 				free_var(&rounded);
6466 				return false;
6467 			}
6468 		}
6469 	}
6470 
6471 	free_var(&rounded);
6472 
6473 	if (!neg)
6474 	{
6475 		if (unlikely(val == PG_INT64_MIN))
6476 			return false;
6477 		val = -val;
6478 	}
6479 	*result = val;
6480 
6481 	return true;
6482 }
6483 
6484 /*
6485  * Convert int8 value to numeric.
6486  */
6487 static void
int64_to_numericvar(int64 val,NumericVar * var)6488 int64_to_numericvar(int64 val, NumericVar *var)
6489 {
6490 	uint64		uval,
6491 				newuval;
6492 	NumericDigit *ptr;
6493 	int			ndigits;
6494 
6495 	/* int64 can require at most 19 decimal digits; add one for safety */
6496 	alloc_var(var, 20 / DEC_DIGITS);
6497 	if (val < 0)
6498 	{
6499 		var->sign = NUMERIC_NEG;
6500 		uval = -val;
6501 	}
6502 	else
6503 	{
6504 		var->sign = NUMERIC_POS;
6505 		uval = val;
6506 	}
6507 	var->dscale = 0;
6508 	if (val == 0)
6509 	{
6510 		var->ndigits = 0;
6511 		var->weight = 0;
6512 		return;
6513 	}
6514 	ptr = var->digits + var->ndigits;
6515 	ndigits = 0;
6516 	do
6517 	{
6518 		ptr--;
6519 		ndigits++;
6520 		newuval = uval / NBASE;
6521 		*ptr = uval - newuval * NBASE;
6522 		uval = newuval;
6523 	} while (uval);
6524 	var->digits = ptr;
6525 	var->ndigits = ndigits;
6526 	var->weight = ndigits - 1;
6527 }
6528 
6529 #ifdef HAVE_INT128
6530 /*
6531  * Convert numeric to int128, rounding if needed.
6532  *
6533  * If overflow, return false (no error is raised).  Return true if okay.
6534  */
6535 static bool
numericvar_to_int128(const NumericVar * var,int128 * result)6536 numericvar_to_int128(const NumericVar *var, int128 *result)
6537 {
6538 	NumericDigit *digits;
6539 	int			ndigits;
6540 	int			weight;
6541 	int			i;
6542 	int128		val,
6543 				oldval;
6544 	bool		neg;
6545 	NumericVar	rounded;
6546 
6547 	/* Round to nearest integer */
6548 	init_var(&rounded);
6549 	set_var_from_var(var, &rounded);
6550 	round_var(&rounded, 0);
6551 
6552 	/* Check for zero input */
6553 	strip_var(&rounded);
6554 	ndigits = rounded.ndigits;
6555 	if (ndigits == 0)
6556 	{
6557 		*result = 0;
6558 		free_var(&rounded);
6559 		return true;
6560 	}
6561 
6562 	/*
6563 	 * For input like 10000000000, we must treat stripped digits as real. So
6564 	 * the loop assumes there are weight+1 digits before the decimal point.
6565 	 */
6566 	weight = rounded.weight;
6567 	Assert(weight >= 0 && ndigits <= weight + 1);
6568 
6569 	/* Construct the result */
6570 	digits = rounded.digits;
6571 	neg = (rounded.sign == NUMERIC_NEG);
6572 	val = digits[0];
6573 	for (i = 1; i <= weight; i++)
6574 	{
6575 		oldval = val;
6576 		val *= NBASE;
6577 		if (i < ndigits)
6578 			val += digits[i];
6579 
6580 		/*
6581 		 * The overflow check is a bit tricky because we want to accept
6582 		 * INT128_MIN, which will overflow the positive accumulator.  We can
6583 		 * detect this case easily though because INT128_MIN is the only
6584 		 * nonzero value for which -val == val (on a two's complement machine,
6585 		 * anyway).
6586 		 */
6587 		if ((val / NBASE) != oldval)	/* possible overflow? */
6588 		{
6589 			if (!neg || (-val) != val || val == 0 || oldval < 0)
6590 			{
6591 				free_var(&rounded);
6592 				return false;
6593 			}
6594 		}
6595 	}
6596 
6597 	free_var(&rounded);
6598 
6599 	*result = neg ? -val : val;
6600 	return true;
6601 }
6602 
6603 /*
6604  * Convert 128 bit integer to numeric.
6605  */
6606 static void
int128_to_numericvar(int128 val,NumericVar * var)6607 int128_to_numericvar(int128 val, NumericVar *var)
6608 {
6609 	uint128		uval,
6610 				newuval;
6611 	NumericDigit *ptr;
6612 	int			ndigits;
6613 
6614 	/* int128 can require at most 39 decimal digits; add one for safety */
6615 	alloc_var(var, 40 / DEC_DIGITS);
6616 	if (val < 0)
6617 	{
6618 		var->sign = NUMERIC_NEG;
6619 		uval = -val;
6620 	}
6621 	else
6622 	{
6623 		var->sign = NUMERIC_POS;
6624 		uval = val;
6625 	}
6626 	var->dscale = 0;
6627 	if (val == 0)
6628 	{
6629 		var->ndigits = 0;
6630 		var->weight = 0;
6631 		return;
6632 	}
6633 	ptr = var->digits + var->ndigits;
6634 	ndigits = 0;
6635 	do
6636 	{
6637 		ptr--;
6638 		ndigits++;
6639 		newuval = uval / NBASE;
6640 		*ptr = uval - newuval * NBASE;
6641 		uval = newuval;
6642 	} while (uval);
6643 	var->digits = ptr;
6644 	var->ndigits = ndigits;
6645 	var->weight = ndigits - 1;
6646 }
6647 #endif
6648 
6649 /*
6650  * Convert numeric to float8; if out of range, return +/- HUGE_VAL
6651  */
6652 static double
numeric_to_double_no_overflow(Numeric num)6653 numeric_to_double_no_overflow(Numeric num)
6654 {
6655 	char	   *tmp;
6656 	double		val;
6657 	char	   *endptr;
6658 
6659 	tmp = DatumGetCString(DirectFunctionCall1(numeric_out,
6660 											  NumericGetDatum(num)));
6661 
6662 	/* unlike float8in, we ignore ERANGE from strtod */
6663 	val = strtod(tmp, &endptr);
6664 	if (*endptr != '\0')
6665 	{
6666 		/* shouldn't happen ... */
6667 		ereport(ERROR,
6668 				(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
6669 				 errmsg("invalid input syntax for type %s: \"%s\"",
6670 						"double precision", tmp)));
6671 	}
6672 
6673 	pfree(tmp);
6674 
6675 	return val;
6676 }
6677 
6678 /* As above, but work from a NumericVar */
6679 static double
numericvar_to_double_no_overflow(const NumericVar * var)6680 numericvar_to_double_no_overflow(const NumericVar *var)
6681 {
6682 	char	   *tmp;
6683 	double		val;
6684 	char	   *endptr;
6685 
6686 	tmp = get_str_from_var(var);
6687 
6688 	/* unlike float8in, we ignore ERANGE from strtod */
6689 	val = strtod(tmp, &endptr);
6690 	if (*endptr != '\0')
6691 	{
6692 		/* shouldn't happen ... */
6693 		ereport(ERROR,
6694 				(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
6695 				 errmsg("invalid input syntax for type %s: \"%s\"",
6696 						"double precision", tmp)));
6697 	}
6698 
6699 	pfree(tmp);
6700 
6701 	return val;
6702 }
6703 
6704 
6705 /*
6706  * cmp_var() -
6707  *
6708  *	Compare two values on variable level.  We assume zeroes have been
6709  *	truncated to no digits.
6710  */
6711 static int
cmp_var(const NumericVar * var1,const NumericVar * var2)6712 cmp_var(const NumericVar *var1, const NumericVar *var2)
6713 {
6714 	return cmp_var_common(var1->digits, var1->ndigits,
6715 						  var1->weight, var1->sign,
6716 						  var2->digits, var2->ndigits,
6717 						  var2->weight, var2->sign);
6718 }
6719 
6720 /*
6721  * cmp_var_common() -
6722  *
6723  *	Main routine of cmp_var(). This function can be used by both
6724  *	NumericVar and Numeric.
6725  */
6726 static int
cmp_var_common(const NumericDigit * var1digits,int var1ndigits,int var1weight,int var1sign,const NumericDigit * var2digits,int var2ndigits,int var2weight,int var2sign)6727 cmp_var_common(const NumericDigit *var1digits, int var1ndigits,
6728 			   int var1weight, int var1sign,
6729 			   const NumericDigit *var2digits, int var2ndigits,
6730 			   int var2weight, int var2sign)
6731 {
6732 	if (var1ndigits == 0)
6733 	{
6734 		if (var2ndigits == 0)
6735 			return 0;
6736 		if (var2sign == NUMERIC_NEG)
6737 			return 1;
6738 		return -1;
6739 	}
6740 	if (var2ndigits == 0)
6741 	{
6742 		if (var1sign == NUMERIC_POS)
6743 			return 1;
6744 		return -1;
6745 	}
6746 
6747 	if (var1sign == NUMERIC_POS)
6748 	{
6749 		if (var2sign == NUMERIC_NEG)
6750 			return 1;
6751 		return cmp_abs_common(var1digits, var1ndigits, var1weight,
6752 							  var2digits, var2ndigits, var2weight);
6753 	}
6754 
6755 	if (var2sign == NUMERIC_POS)
6756 		return -1;
6757 
6758 	return cmp_abs_common(var2digits, var2ndigits, var2weight,
6759 						  var1digits, var1ndigits, var1weight);
6760 }
6761 
6762 
6763 /*
6764  * add_var() -
6765  *
6766  *	Full version of add functionality on variable level (handling signs).
6767  *	result might point to one of the operands too without danger.
6768  */
6769 static void
add_var(const NumericVar * var1,const NumericVar * var2,NumericVar * result)6770 add_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result)
6771 {
6772 	/*
6773 	 * Decide on the signs of the two variables what to do
6774 	 */
6775 	if (var1->sign == NUMERIC_POS)
6776 	{
6777 		if (var2->sign == NUMERIC_POS)
6778 		{
6779 			/*
6780 			 * Both are positive result = +(ABS(var1) + ABS(var2))
6781 			 */
6782 			add_abs(var1, var2, result);
6783 			result->sign = NUMERIC_POS;
6784 		}
6785 		else
6786 		{
6787 			/*
6788 			 * var1 is positive, var2 is negative Must compare absolute values
6789 			 */
6790 			switch (cmp_abs(var1, var2))
6791 			{
6792 				case 0:
6793 					/* ----------
6794 					 * ABS(var1) == ABS(var2)
6795 					 * result = ZERO
6796 					 * ----------
6797 					 */
6798 					zero_var(result);
6799 					result->dscale = Max(var1->dscale, var2->dscale);
6800 					break;
6801 
6802 				case 1:
6803 					/* ----------
6804 					 * ABS(var1) > ABS(var2)
6805 					 * result = +(ABS(var1) - ABS(var2))
6806 					 * ----------
6807 					 */
6808 					sub_abs(var1, var2, result);
6809 					result->sign = NUMERIC_POS;
6810 					break;
6811 
6812 				case -1:
6813 					/* ----------
6814 					 * ABS(var1) < ABS(var2)
6815 					 * result = -(ABS(var2) - ABS(var1))
6816 					 * ----------
6817 					 */
6818 					sub_abs(var2, var1, result);
6819 					result->sign = NUMERIC_NEG;
6820 					break;
6821 			}
6822 		}
6823 	}
6824 	else
6825 	{
6826 		if (var2->sign == NUMERIC_POS)
6827 		{
6828 			/* ----------
6829 			 * var1 is negative, var2 is positive
6830 			 * Must compare absolute values
6831 			 * ----------
6832 			 */
6833 			switch (cmp_abs(var1, var2))
6834 			{
6835 				case 0:
6836 					/* ----------
6837 					 * ABS(var1) == ABS(var2)
6838 					 * result = ZERO
6839 					 * ----------
6840 					 */
6841 					zero_var(result);
6842 					result->dscale = Max(var1->dscale, var2->dscale);
6843 					break;
6844 
6845 				case 1:
6846 					/* ----------
6847 					 * ABS(var1) > ABS(var2)
6848 					 * result = -(ABS(var1) - ABS(var2))
6849 					 * ----------
6850 					 */
6851 					sub_abs(var1, var2, result);
6852 					result->sign = NUMERIC_NEG;
6853 					break;
6854 
6855 				case -1:
6856 					/* ----------
6857 					 * ABS(var1) < ABS(var2)
6858 					 * result = +(ABS(var2) - ABS(var1))
6859 					 * ----------
6860 					 */
6861 					sub_abs(var2, var1, result);
6862 					result->sign = NUMERIC_POS;
6863 					break;
6864 			}
6865 		}
6866 		else
6867 		{
6868 			/* ----------
6869 			 * Both are negative
6870 			 * result = -(ABS(var1) + ABS(var2))
6871 			 * ----------
6872 			 */
6873 			add_abs(var1, var2, result);
6874 			result->sign = NUMERIC_NEG;
6875 		}
6876 	}
6877 }
6878 
6879 
6880 /*
6881  * sub_var() -
6882  *
6883  *	Full version of sub functionality on variable level (handling signs).
6884  *	result might point to one of the operands too without danger.
6885  */
6886 static void
sub_var(const NumericVar * var1,const NumericVar * var2,NumericVar * result)6887 sub_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result)
6888 {
6889 	/*
6890 	 * Decide on the signs of the two variables what to do
6891 	 */
6892 	if (var1->sign == NUMERIC_POS)
6893 	{
6894 		if (var2->sign == NUMERIC_NEG)
6895 		{
6896 			/* ----------
6897 			 * var1 is positive, var2 is negative
6898 			 * result = +(ABS(var1) + ABS(var2))
6899 			 * ----------
6900 			 */
6901 			add_abs(var1, var2, result);
6902 			result->sign = NUMERIC_POS;
6903 		}
6904 		else
6905 		{
6906 			/* ----------
6907 			 * Both are positive
6908 			 * Must compare absolute values
6909 			 * ----------
6910 			 */
6911 			switch (cmp_abs(var1, var2))
6912 			{
6913 				case 0:
6914 					/* ----------
6915 					 * ABS(var1) == ABS(var2)
6916 					 * result = ZERO
6917 					 * ----------
6918 					 */
6919 					zero_var(result);
6920 					result->dscale = Max(var1->dscale, var2->dscale);
6921 					break;
6922 
6923 				case 1:
6924 					/* ----------
6925 					 * ABS(var1) > ABS(var2)
6926 					 * result = +(ABS(var1) - ABS(var2))
6927 					 * ----------
6928 					 */
6929 					sub_abs(var1, var2, result);
6930 					result->sign = NUMERIC_POS;
6931 					break;
6932 
6933 				case -1:
6934 					/* ----------
6935 					 * ABS(var1) < ABS(var2)
6936 					 * result = -(ABS(var2) - ABS(var1))
6937 					 * ----------
6938 					 */
6939 					sub_abs(var2, var1, result);
6940 					result->sign = NUMERIC_NEG;
6941 					break;
6942 			}
6943 		}
6944 	}
6945 	else
6946 	{
6947 		if (var2->sign == NUMERIC_NEG)
6948 		{
6949 			/* ----------
6950 			 * Both are negative
6951 			 * Must compare absolute values
6952 			 * ----------
6953 			 */
6954 			switch (cmp_abs(var1, var2))
6955 			{
6956 				case 0:
6957 					/* ----------
6958 					 * ABS(var1) == ABS(var2)
6959 					 * result = ZERO
6960 					 * ----------
6961 					 */
6962 					zero_var(result);
6963 					result->dscale = Max(var1->dscale, var2->dscale);
6964 					break;
6965 
6966 				case 1:
6967 					/* ----------
6968 					 * ABS(var1) > ABS(var2)
6969 					 * result = -(ABS(var1) - ABS(var2))
6970 					 * ----------
6971 					 */
6972 					sub_abs(var1, var2, result);
6973 					result->sign = NUMERIC_NEG;
6974 					break;
6975 
6976 				case -1:
6977 					/* ----------
6978 					 * ABS(var1) < ABS(var2)
6979 					 * result = +(ABS(var2) - ABS(var1))
6980 					 * ----------
6981 					 */
6982 					sub_abs(var2, var1, result);
6983 					result->sign = NUMERIC_POS;
6984 					break;
6985 			}
6986 		}
6987 		else
6988 		{
6989 			/* ----------
6990 			 * var1 is negative, var2 is positive
6991 			 * result = -(ABS(var1) + ABS(var2))
6992 			 * ----------
6993 			 */
6994 			add_abs(var1, var2, result);
6995 			result->sign = NUMERIC_NEG;
6996 		}
6997 	}
6998 }
6999 
7000 
/*
 * mul_var() -
 *
 *	Multiplication on variable level. Product of var1 * var2 is stored
 *	in result.  Result is rounded to no more than rscale fractional digits.
 *
 *	If either input has no digits (is zero), result is set to zero with
 *	dscale = rscale.  result may be the same NumericVar as var1 or var2:
 *	every input digit is folded into the private dig[] workspace before
 *	result is reallocated by alloc_var().
 */
static void
mul_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result,
		int rscale)
{
	int			res_ndigits;
	int			res_sign;
	int			res_weight;
	int			maxdigits;
	int		   *dig;
	int			carry;
	int			maxdig;
	int			newdig;
	int			var1ndigits;
	int			var2ndigits;
	NumericDigit *var1digits;
	NumericDigit *var2digits;
	NumericDigit *res_digits;
	int			i,
				i1,
				i2;

	/*
	 * Arrange for var1 to be the shorter of the two numbers.  This improves
	 * performance because the inner multiplication loop is much simpler than
	 * the outer loop, so it's better to have a smaller number of iterations
	 * of the outer loop.  This also reduces the number of times that the
	 * accumulator array needs to be normalized.
	 */
	if (var1->ndigits > var2->ndigits)
	{
		const NumericVar *tmp = var1;

		var1 = var2;
		var2 = tmp;
	}

	/* copy these values into local vars for speed in inner loop */
	var1ndigits = var1->ndigits;
	var2ndigits = var2->ndigits;
	var1digits = var1->digits;
	var2digits = var2->digits;

	if (var1ndigits == 0 || var2ndigits == 0)
	{
		/* one or both inputs is zero; so is result */
		zero_var(result);
		result->dscale = rscale;
		return;
	}

	/* Determine result sign and (maximum possible) weight */
	if (var1->sign == var2->sign)
		res_sign = NUMERIC_POS;
	else
		res_sign = NUMERIC_NEG;

	/*
	 * dig[k] below holds the digit of weight res_weight - k.  The product of
	 * var1 digit i1 and var2 digit i2 has weight
	 * var1->weight + var2->weight - i1 - i2, so it lands in dig[i1 + i2 + 2].
	 * That leaves dig[0] and dig[1] free to absorb carries.
	 */
	res_weight = var1->weight + var2->weight + 2;

	/*
	 * Determine the number of result digits to compute.  If the exact result
	 * would have more than rscale fractional digits, truncate the computation
	 * with MUL_GUARD_DIGITS guard digits, i.e., ignore input digits that
	 * would only contribute to the right of that.  (This will give the exact
	 * rounded-to-rscale answer unless carries out of the ignored positions
	 * would have propagated through more than MUL_GUARD_DIGITS digits.)
	 *
	 * Note: an exact computation could not produce more than var1ndigits +
	 * var2ndigits digits, but we allocate one extra output digit in case
	 * rscale-driven rounding produces a carry out of the highest exact digit.
	 */
	res_ndigits = var1ndigits + var2ndigits + 1;
	maxdigits = res_weight + 1 + (rscale + DEC_DIGITS - 1) / DEC_DIGITS +
		MUL_GUARD_DIGITS;
	res_ndigits = Min(res_ndigits, maxdigits);

	if (res_ndigits < 3)
	{
		/* All input digits will be ignored; so result is zero */
		zero_var(result);
		result->dscale = rscale;
		return;
	}

	/*
	 * We do the arithmetic in an array "dig[]" of signed int's.  Since
	 * INT_MAX is noticeably larger than NBASE*NBASE, this gives us headroom
	 * to avoid normalizing carries immediately.
	 *
	 * maxdig tracks the maximum possible value of any dig[] entry; when this
	 * threatens to exceed INT_MAX, we take the time to propagate carries.
	 * Furthermore, we need to ensure that overflow doesn't occur during the
	 * carry propagation passes either.  The carry values could be as much as
	 * INT_MAX/NBASE, so really we must normalize when digits threaten to
	 * exceed INT_MAX - INT_MAX/NBASE.
	 *
	 * To avoid overflow in maxdig itself, it actually represents the max
	 * possible value divided by NBASE-1, ie, at the top of the loop it is
	 * known that no dig[] entry exceeds maxdig * (NBASE-1).
	 */
	dig = (int *) palloc0(res_ndigits * sizeof(int));
	maxdig = 0;

	/*
	 * The least significant digits of var1 should be ignored if they don't
	 * contribute directly to the first res_ndigits digits of the result that
	 * we are computing.
	 *
	 * Digit i1 of var1 and digit i2 of var2 are multiplied and added to digit
	 * i1+i2+2 of the accumulator array, so we need only consider digits of
	 * var1 for which i1 <= res_ndigits - 3.
	 */
	for (i1 = Min(var1ndigits - 1, res_ndigits - 3); i1 >= 0; i1--)
	{
		int			var1digit = var1digits[i1];

		/* a zero digit adds nothing to any dig[] entry */
		if (var1digit == 0)
			continue;

		/* Time to normalize? */
		maxdig += var1digit;
		if (maxdig > (INT_MAX - INT_MAX / NBASE) / (NBASE - 1))
		{
			/* Yes, do it */
			carry = 0;
			for (i = res_ndigits - 1; i >= 0; i--)
			{
				newdig = dig[i] + carry;
				if (newdig >= NBASE)
				{
					carry = newdig / NBASE;
					newdig -= carry * NBASE;
				}
				else
					carry = 0;
				dig[i] = newdig;
			}
			Assert(carry == 0);
			/* Reset maxdig to indicate new worst-case */
			maxdig = 1 + var1digit;
		}

		/*
		 * Add the appropriate multiple of var2 into the accumulator.
		 *
		 * As above, digits of var2 can be ignored if they don't contribute,
		 * so we only include digits for which i1+i2+2 <= res_ndigits - 1.
		 */
		for (i2 = Min(var2ndigits - 1, res_ndigits - i1 - 3), i = i1 + i2 + 2;
			 i2 >= 0; i2--)
			dig[i--] += var1digit * var2digits[i2];
	}

	/*
	 * Now we do a final carry propagation pass to normalize the result, which
	 * we combine with storing the result digits into the output. Note that
	 * this is still done at full precision w/guard digits.
	 */
	alloc_var(result, res_ndigits);
	res_digits = result->digits;
	carry = 0;
	for (i = res_ndigits - 1; i >= 0; i--)
	{
		newdig = dig[i] + carry;
		if (newdig >= NBASE)
		{
			carry = newdig / NBASE;
			newdig -= carry * NBASE;
		}
		else
			carry = 0;
		res_digits[i] = newdig;
	}
	Assert(carry == 0);

	pfree(dig);

	/*
	 * Finally, round the result to the requested precision.
	 */
	result->weight = res_weight;
	result->sign = res_sign;

	/* Round to target rscale (and set result->dscale) */
	round_var(result, rscale);

	/* Strip leading and trailing zeroes */
	strip_var(result);
}
7194 
7195 
/*
 * div_var() -
 *
 *	Division on variable level. Quotient of var1 / var2 is stored in result.
 *	The quotient is figured to exactly rscale fractional digits.
 *	If round is true, it is rounded at the rscale'th digit; if false, it
 *	is truncated (towards zero) at that digit.
 *
 *	A divisor with no digits, or with a zero leading digit, raises ERROR
 *	(ERRCODE_DIVISION_BY_ZERO).  result may be the same NumericVar as
 *	either input, because both inputs are copied into a private workspace
 *	before result is reallocated.
 */
static void
div_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result,
		int rscale, bool round)
{
	int			div_ndigits;
	int			res_ndigits;
	int			res_sign;
	int			res_weight;
	int			carry;
	int			borrow;
	int			divisor1;
	int			divisor2;
	NumericDigit *dividend;
	NumericDigit *divisor;
	NumericDigit *res_digits;
	int			i;
	int			j;

	/* copy these values into local vars for speed in inner loop */
	int			var1ndigits = var1->ndigits;
	int			var2ndigits = var2->ndigits;

	/*
	 * First of all division by zero check; we must not be handed an
	 * unnormalized divisor.
	 */
	if (var2ndigits == 0 || var2->digits[0] == 0)
		ereport(ERROR,
				(errcode(ERRCODE_DIVISION_BY_ZERO),
				 errmsg("division by zero")));

	/*
	 * Now result zero check
	 */
	if (var1ndigits == 0)
	{
		zero_var(result);
		result->dscale = rscale;
		return;
	}

	/*
	 * Determine the result sign, weight and number of digits to calculate.
	 * The weight figured here is correct if the emitted quotient has no
	 * leading zero digits; otherwise strip_var() will fix things up.
	 */
	if (var1->sign == var2->sign)
		res_sign = NUMERIC_POS;
	else
		res_sign = NUMERIC_NEG;
	res_weight = var1->weight - var2->weight;
	/* The number of accurate result digits we need to produce: */
	res_ndigits = res_weight + 1 + (rscale + DEC_DIGITS - 1) / DEC_DIGITS;
	/* ... but always at least 1 */
	res_ndigits = Max(res_ndigits, 1);
	/* If rounding needed, figure one more digit to ensure correct result */
	if (round)
		res_ndigits++;

	/*
	 * The working dividend normally requires res_ndigits + var2ndigits
	 * digits, but make it at least var1ndigits so we can load all of var1
	 * into it.  (There will be an additional digit dividend[0] in the
	 * dividend space, but for consistency with Knuth's notation we don't
	 * count that in div_ndigits.)
	 */
	div_ndigits = res_ndigits + var2ndigits;
	div_ndigits = Max(div_ndigits, var1ndigits);

	/*
	 * We need a workspace with room for the working dividend (div_ndigits+1
	 * digits) plus room for the possibly-normalized divisor (var2ndigits
	 * digits).  It is convenient also to have a zero at divisor[0] with the
	 * actual divisor data in divisor[1 .. var2ndigits].  Transferring the
	 * digits into the workspace also allows us to realloc the result (which
	 * might be the same as either input var) before we begin the main loop.
	 * Note that we use palloc0 to ensure that divisor[0], dividend[0], and
	 * any additional dividend positions beyond var1ndigits, start out 0.
	 */
	dividend = (NumericDigit *)
		palloc0((div_ndigits + var2ndigits + 2) * sizeof(NumericDigit));
	divisor = dividend + (div_ndigits + 1);
	memcpy(dividend + 1, var1->digits, var1ndigits * sizeof(NumericDigit));
	memcpy(divisor + 1, var2->digits, var2ndigits * sizeof(NumericDigit));

	/*
	 * Now we can realloc the result to hold the generated quotient digits.
	 */
	alloc_var(result, res_ndigits);
	res_digits = result->digits;

	if (var2ndigits == 1)
	{
		/*
		 * If there's only a single divisor digit, we can use a fast path (cf.
		 * Knuth section 4.3.1 exercise 16).
		 *
		 * carry is always < divisor1 on entry to each step, so
		 * carry * NBASE + digit stays below NBASE * NBASE and fits in an int.
		 */
		divisor1 = divisor[1];
		carry = 0;
		for (i = 0; i < res_ndigits; i++)
		{
			carry = carry * NBASE + dividend[i + 1];
			res_digits[i] = carry / divisor1;
			carry = carry % divisor1;
		}
	}
	else
	{
		/*
		 * The full multiple-place algorithm is taken from Knuth volume 2,
		 * Algorithm 4.3.1D.
		 *
		 * We need the first divisor digit to be >= NBASE/2.  If it isn't,
		 * make it so by scaling up both the divisor and dividend by the
		 * factor "d".  (The reason for allocating dividend[0] above is to
		 * leave room for possible carry here.)
		 */
		if (divisor[1] < HALF_NBASE)
		{
			int			d = NBASE / (divisor[1] + 1);

			carry = 0;
			for (i = var2ndigits; i > 0; i--)
			{
				carry += divisor[i] * d;
				divisor[i] = carry % NBASE;
				carry = carry / NBASE;
			}
			Assert(carry == 0);
			carry = 0;
			/* at this point only var1ndigits of dividend can be nonzero */
			for (i = var1ndigits; i >= 0; i--)
			{
				carry += dividend[i] * d;
				dividend[i] = carry % NBASE;
				carry = carry / NBASE;
			}
			Assert(carry == 0);
			Assert(divisor[1] >= HALF_NBASE);
		}
		/* First 2 divisor digits are used repeatedly in main loop */
		divisor1 = divisor[1];
		divisor2 = divisor[2];

		/*
		 * Begin the main loop.  Each iteration of this loop produces the j'th
		 * quotient digit by dividing dividend[j .. j + var2ndigits] by the
		 * divisor; this is essentially the same as the common manual
		 * procedure for long division.
		 */
		for (j = 0; j < res_ndigits; j++)
		{
			/* Estimate quotient digit from the first two dividend digits */
			int			next2digits = dividend[j] * NBASE + dividend[j + 1];
			int			qhat;

			/*
			 * If next2digits are 0, then quotient digit must be 0 and there's
			 * no need to adjust the working dividend.  It's worth testing
			 * here to fall out ASAP when processing trailing zeroes in a
			 * dividend.
			 */
			if (next2digits == 0)
			{
				res_digits[j] = 0;
				continue;
			}

			/* Clamp the trial digit to NBASE - 1 when the leading digits tie */
			if (dividend[j] == divisor1)
				qhat = NBASE - 1;
			else
				qhat = next2digits / divisor1;

			/*
			 * Adjust quotient digit if it's too large.  Knuth proves that
			 * after this step, the quotient digit will be either correct or
			 * just one too large.  (Note: it's OK to use dividend[j+2] here
			 * because we know the divisor length is at least 2.)
			 */
			while (divisor2 * qhat >
				   (next2digits - qhat * divisor1) * NBASE + dividend[j + 2])
				qhat--;

			/* As above, need do nothing more when quotient digit is 0 */
			if (qhat > 0)
			{
				/*
				 * Multiply the divisor by qhat, and subtract that from the
				 * working dividend.  "carry" tracks the multiplication,
				 * "borrow" the subtraction (could we fold these together?)
				 */
				carry = 0;
				borrow = 0;
				for (i = var2ndigits; i >= 0; i--)
				{
					carry += divisor[i] * qhat;
					borrow -= carry % NBASE;
					carry = carry / NBASE;
					borrow += dividend[j + i];
					if (borrow < 0)
					{
						dividend[j + i] = borrow + NBASE;
						borrow = -1;
					}
					else
					{
						dividend[j + i] = borrow;
						borrow = 0;
					}
				}
				Assert(carry == 0);

				/*
				 * If we got a borrow out of the top dividend digit, then
				 * indeed qhat was one too large.  Fix it, and add back the
				 * divisor to correct the working dividend.  (Knuth proves
				 * that this will occur only about 3/NBASE of the time; hence,
				 * it's a good idea to test this code with small NBASE to be
				 * sure this section gets exercised.)
				 */
				if (borrow)
				{
					qhat--;
					carry = 0;
					for (i = var2ndigits; i >= 0; i--)
					{
						carry += dividend[j + i] + divisor[i];
						if (carry >= NBASE)
						{
							dividend[j + i] = carry - NBASE;
							carry = 1;
						}
						else
						{
							dividend[j + i] = carry;
							carry = 0;
						}
					}
					/* A carry should occur here to cancel the borrow above */
					Assert(carry == 1);
				}
			}

			/* And we're done with this quotient digit */
			res_digits[j] = qhat;
		}
	}

	pfree(dividend);

	/*
	 * Finally, round or truncate the result to the requested precision.
	 */
	result->weight = res_weight;
	result->sign = res_sign;

	/* Round or truncate to target rscale (and set result->dscale) */
	if (round)
		round_var(result, rscale);
	else
		trunc_var(result, rscale);

	/* Strip leading and trailing zeroes */
	strip_var(result);
}
7469 
7470 
7471 /*
7472  * div_var_fast() -
7473  *
7474  *	This has the same API as div_var, but is implemented using the division
7475  *	algorithm from the "FM" library, rather than Knuth's schoolbook-division
7476  *	approach.  This is significantly faster but can produce inaccurate
7477  *	results, because it sometimes has to propagate rounding to the left,
7478  *	and so we can never be entirely sure that we know the requested digits
7479  *	exactly.  We compute DIV_GUARD_DIGITS extra digits, but there is
7480  *	no certainty that that's enough.  We use this only in the transcendental
7481  *	function calculation routines, where everything is approximate anyway.
7482  *
7483  *	Although we provide a "round" argument for consistency with div_var,
7484  *	it is unwise to use this function with round=false.  In truncation mode
7485  *	it is possible to get a result with no significant digits, for example
7486  *	with rscale=0 we might compute 0.99999... and truncate that to 0 when
7487  *	the correct answer is 1.
7488  */
7489 static void
div_var_fast(const NumericVar * var1,const NumericVar * var2,NumericVar * result,int rscale,bool round)7490 div_var_fast(const NumericVar *var1, const NumericVar *var2,
7491 			 NumericVar *result, int rscale, bool round)
7492 {
7493 	int			div_ndigits;
7494 	int			res_sign;
7495 	int			res_weight;
7496 	int		   *div;
7497 	int			qdigit;
7498 	int			carry;
7499 	int			maxdiv;
7500 	int			newdig;
7501 	NumericDigit *res_digits;
7502 	double		fdividend,
7503 				fdivisor,
7504 				fdivisorinverse,
7505 				fquotient;
7506 	int			qi;
7507 	int			i;
7508 
7509 	/* copy these values into local vars for speed in inner loop */
7510 	int			var1ndigits = var1->ndigits;
7511 	int			var2ndigits = var2->ndigits;
7512 	NumericDigit *var1digits = var1->digits;
7513 	NumericDigit *var2digits = var2->digits;
7514 
7515 	/*
7516 	 * First of all division by zero check; we must not be handed an
7517 	 * unnormalized divisor.
7518 	 */
7519 	if (var2ndigits == 0 || var2digits[0] == 0)
7520 		ereport(ERROR,
7521 				(errcode(ERRCODE_DIVISION_BY_ZERO),
7522 				 errmsg("division by zero")));
7523 
7524 	/*
7525 	 * Now result zero check
7526 	 */
7527 	if (var1ndigits == 0)
7528 	{
7529 		zero_var(result);
7530 		result->dscale = rscale;
7531 		return;
7532 	}
7533 
7534 	/*
7535 	 * Determine the result sign, weight and number of digits to calculate
7536 	 */
7537 	if (var1->sign == var2->sign)
7538 		res_sign = NUMERIC_POS;
7539 	else
7540 		res_sign = NUMERIC_NEG;
7541 	res_weight = var1->weight - var2->weight + 1;
7542 	/* The number of accurate result digits we need to produce: */
7543 	div_ndigits = res_weight + 1 + (rscale + DEC_DIGITS - 1) / DEC_DIGITS;
7544 	/* Add guard digits for roundoff error */
7545 	div_ndigits += DIV_GUARD_DIGITS;
7546 	if (div_ndigits < DIV_GUARD_DIGITS)
7547 		div_ndigits = DIV_GUARD_DIGITS;
7548 	/* Must be at least var1ndigits, too, to simplify data-loading loop */
7549 	if (div_ndigits < var1ndigits)
7550 		div_ndigits = var1ndigits;
7551 
7552 	/*
7553 	 * We do the arithmetic in an array "div[]" of signed int's.  Since
7554 	 * INT_MAX is noticeably larger than NBASE*NBASE, this gives us headroom
7555 	 * to avoid normalizing carries immediately.
7556 	 *
7557 	 * We start with div[] containing one zero digit followed by the
7558 	 * dividend's digits (plus appended zeroes to reach the desired precision
7559 	 * including guard digits).  Each step of the main loop computes an
7560 	 * (approximate) quotient digit and stores it into div[], removing one
7561 	 * position of dividend space.  A final pass of carry propagation takes
7562 	 * care of any mistaken quotient digits.
7563 	 */
7564 	div = (int *) palloc0((div_ndigits + 1) * sizeof(int));
7565 	for (i = 0; i < var1ndigits; i++)
7566 		div[i + 1] = var1digits[i];
7567 
7568 	/*
7569 	 * We estimate each quotient digit using floating-point arithmetic, taking
7570 	 * the first four digits of the (current) dividend and divisor.  This must
7571 	 * be float to avoid overflow.  The quotient digits will generally be off
7572 	 * by no more than one from the exact answer.
7573 	 */
7574 	fdivisor = (double) var2digits[0];
7575 	for (i = 1; i < 4; i++)
7576 	{
7577 		fdivisor *= NBASE;
7578 		if (i < var2ndigits)
7579 			fdivisor += (double) var2digits[i];
7580 	}
7581 	fdivisorinverse = 1.0 / fdivisor;
7582 
7583 	/*
7584 	 * maxdiv tracks the maximum possible absolute value of any div[] entry;
7585 	 * when this threatens to exceed INT_MAX, we take the time to propagate
7586 	 * carries.  Furthermore, we need to ensure that overflow doesn't occur
7587 	 * during the carry propagation passes either.  The carry values may have
7588 	 * an absolute value as high as INT_MAX/NBASE + 1, so really we must
7589 	 * normalize when digits threaten to exceed INT_MAX - INT_MAX/NBASE - 1.
7590 	 *
7591 	 * To avoid overflow in maxdiv itself, it represents the max absolute
7592 	 * value divided by NBASE-1, ie, at the top of the loop it is known that
7593 	 * no div[] entry has an absolute value exceeding maxdiv * (NBASE-1).
7594 	 *
7595 	 * Actually, though, that holds good only for div[] entries after div[qi];
7596 	 * the adjustment done at the bottom of the loop may cause div[qi + 1] to
7597 	 * exceed the maxdiv limit, so that div[qi] in the next iteration is
7598 	 * beyond the limit.  This does not cause problems, as explained below.
7599 	 */
7600 	maxdiv = 1;
7601 
7602 	/*
7603 	 * Outer loop computes next quotient digit, which will go into div[qi]
7604 	 */
7605 	for (qi = 0; qi < div_ndigits; qi++)
7606 	{
7607 		/* Approximate the current dividend value */
7608 		fdividend = (double) div[qi];
7609 		for (i = 1; i < 4; i++)
7610 		{
7611 			fdividend *= NBASE;
7612 			if (qi + i <= div_ndigits)
7613 				fdividend += (double) div[qi + i];
7614 		}
7615 		/* Compute the (approximate) quotient digit */
7616 		fquotient = fdividend * fdivisorinverse;
7617 		qdigit = (fquotient >= 0.0) ? ((int) fquotient) :
7618 			(((int) fquotient) - 1);	/* truncate towards -infinity */
7619 
7620 		if (qdigit != 0)
7621 		{
7622 			/* Do we need to normalize now? */
7623 			maxdiv += Abs(qdigit);
7624 			if (maxdiv > (INT_MAX - INT_MAX / NBASE - 1) / (NBASE - 1))
7625 			{
7626 				/* Yes, do it */
7627 				carry = 0;
7628 				for (i = div_ndigits; i > qi; i--)
7629 				{
7630 					newdig = div[i] + carry;
7631 					if (newdig < 0)
7632 					{
7633 						carry = -((-newdig - 1) / NBASE) - 1;
7634 						newdig -= carry * NBASE;
7635 					}
7636 					else if (newdig >= NBASE)
7637 					{
7638 						carry = newdig / NBASE;
7639 						newdig -= carry * NBASE;
7640 					}
7641 					else
7642 						carry = 0;
7643 					div[i] = newdig;
7644 				}
7645 				newdig = div[qi] + carry;
7646 				div[qi] = newdig;
7647 
7648 				/*
7649 				 * All the div[] digits except possibly div[qi] are now in the
7650 				 * range 0..NBASE-1.  We do not need to consider div[qi] in
7651 				 * the maxdiv value anymore, so we can reset maxdiv to 1.
7652 				 */
7653 				maxdiv = 1;
7654 
7655 				/*
7656 				 * Recompute the quotient digit since new info may have
7657 				 * propagated into the top four dividend digits
7658 				 */
7659 				fdividend = (double) div[qi];
7660 				for (i = 1; i < 4; i++)
7661 				{
7662 					fdividend *= NBASE;
7663 					if (qi + i <= div_ndigits)
7664 						fdividend += (double) div[qi + i];
7665 				}
7666 				/* Compute the (approximate) quotient digit */
7667 				fquotient = fdividend * fdivisorinverse;
7668 				qdigit = (fquotient >= 0.0) ? ((int) fquotient) :
7669 					(((int) fquotient) - 1);	/* truncate towards -infinity */
7670 				maxdiv += Abs(qdigit);
7671 			}
7672 
7673 			/*
7674 			 * Subtract off the appropriate multiple of the divisor.
7675 			 *
7676 			 * The digits beyond div[qi] cannot overflow, because we know they
7677 			 * will fall within the maxdiv limit.  As for div[qi] itself, note
7678 			 * that qdigit is approximately trunc(div[qi] / vardigits[0]),
7679 			 * which would make the new value simply div[qi] mod vardigits[0].
7680 			 * The lower-order terms in qdigit can change this result by not
7681 			 * more than about twice INT_MAX/NBASE, so overflow is impossible.
7682 			 */
7683 			if (qdigit != 0)
7684 			{
7685 				int			istop = Min(var2ndigits, div_ndigits - qi + 1);
7686 
7687 				for (i = 0; i < istop; i++)
7688 					div[qi + i] -= qdigit * var2digits[i];
7689 			}
7690 		}
7691 
7692 		/*
7693 		 * The dividend digit we are about to replace might still be nonzero.
7694 		 * Fold it into the next digit position.
7695 		 *
7696 		 * There is no risk of overflow here, although proving that requires
7697 		 * some care.  Much as with the argument for div[qi] not overflowing,
7698 		 * if we consider the first two terms in the numerator and denominator
7699 		 * of qdigit, we can see that the final value of div[qi + 1] will be
7700 		 * approximately a remainder mod (vardigits[0]*NBASE + vardigits[1]).
7701 		 * Accounting for the lower-order terms is a bit complicated but ends
7702 		 * up adding not much more than INT_MAX/NBASE to the possible range.
7703 		 * Thus, div[qi + 1] cannot overflow here, and in its role as div[qi]
7704 		 * in the next loop iteration, it can't be large enough to cause
7705 		 * overflow in the carry propagation step (if any), either.
7706 		 *
7707 		 * But having said that: div[qi] can be more than INT_MAX/NBASE, as
7708 		 * noted above, which means that the product div[qi] * NBASE *can*
7709 		 * overflow.  When that happens, adding it to div[qi + 1] will always
7710 		 * cause a canceling overflow so that the end result is correct.  We
7711 		 * could avoid the intermediate overflow by doing the multiplication
7712 		 * and addition in int64 arithmetic, but so far there appears no need.
7713 		 */
7714 		div[qi + 1] += div[qi] * NBASE;
7715 
7716 		div[qi] = qdigit;
7717 	}
7718 
7719 	/*
7720 	 * Approximate and store the last quotient digit (div[div_ndigits])
7721 	 */
7722 	fdividend = (double) div[qi];
7723 	for (i = 1; i < 4; i++)
7724 		fdividend *= NBASE;
7725 	fquotient = fdividend * fdivisorinverse;
7726 	qdigit = (fquotient >= 0.0) ? ((int) fquotient) :
7727 		(((int) fquotient) - 1);	/* truncate towards -infinity */
7728 	div[qi] = qdigit;
7729 
7730 	/*
7731 	 * Because the quotient digits might be off by one, some of them might be
7732 	 * -1 or NBASE at this point.  The represented value is correct in a
7733 	 * mathematical sense, but it doesn't look right.  We do a final carry
7734 	 * propagation pass to normalize the digits, which we combine with storing
7735 	 * the result digits into the output.  Note that this is still done at
7736 	 * full precision w/guard digits.
7737 	 */
7738 	alloc_var(result, div_ndigits + 1);
7739 	res_digits = result->digits;
7740 	carry = 0;
7741 	for (i = div_ndigits; i >= 0; i--)
7742 	{
7743 		newdig = div[i] + carry;
7744 		if (newdig < 0)
7745 		{
7746 			carry = -((-newdig - 1) / NBASE) - 1;
7747 			newdig -= carry * NBASE;
7748 		}
7749 		else if (newdig >= NBASE)
7750 		{
7751 			carry = newdig / NBASE;
7752 			newdig -= carry * NBASE;
7753 		}
7754 		else
7755 			carry = 0;
7756 		res_digits[i] = newdig;
7757 	}
7758 	Assert(carry == 0);
7759 
7760 	pfree(div);
7761 
7762 	/*
7763 	 * Finally, round the result to the requested precision.
7764 	 */
7765 	result->weight = res_weight;
7766 	result->sign = res_sign;
7767 
7768 	/* Round to target rscale (and set result->dscale) */
7769 	if (round)
7770 		round_var(result, rscale);
7771 	else
7772 		trunc_var(result, rscale);
7773 
7774 	/* Strip leading and trailing zeroes */
7775 	strip_var(result);
7776 }
7777 
7778 
7779 /*
7780  * Default scale selection for division
7781  *
7782  * Returns the appropriate result scale for the division result.
7783  */
7784 static int
select_div_scale(const NumericVar * var1,const NumericVar * var2)7785 select_div_scale(const NumericVar *var1, const NumericVar *var2)
7786 {
7787 	int			weight1,
7788 				weight2,
7789 				qweight,
7790 				i;
7791 	NumericDigit firstdigit1,
7792 				firstdigit2;
7793 	int			rscale;
7794 
7795 	/*
7796 	 * The result scale of a division isn't specified in any SQL standard. For
7797 	 * PostgreSQL we select a result scale that will give at least
7798 	 * NUMERIC_MIN_SIG_DIGITS significant digits, so that numeric gives a
7799 	 * result no less accurate than float8; but use a scale not less than
7800 	 * either input's display scale.
7801 	 */
7802 
7803 	/* Get the actual (normalized) weight and first digit of each input */
7804 
7805 	weight1 = 0;				/* values to use if var1 is zero */
7806 	firstdigit1 = 0;
7807 	for (i = 0; i < var1->ndigits; i++)
7808 	{
7809 		firstdigit1 = var1->digits[i];
7810 		if (firstdigit1 != 0)
7811 		{
7812 			weight1 = var1->weight - i;
7813 			break;
7814 		}
7815 	}
7816 
7817 	weight2 = 0;				/* values to use if var2 is zero */
7818 	firstdigit2 = 0;
7819 	for (i = 0; i < var2->ndigits; i++)
7820 	{
7821 		firstdigit2 = var2->digits[i];
7822 		if (firstdigit2 != 0)
7823 		{
7824 			weight2 = var2->weight - i;
7825 			break;
7826 		}
7827 	}
7828 
7829 	/*
7830 	 * Estimate weight of quotient.  If the two first digits are equal, we
7831 	 * can't be sure, but assume that var1 is less than var2.
7832 	 */
7833 	qweight = weight1 - weight2;
7834 	if (firstdigit1 <= firstdigit2)
7835 		qweight--;
7836 
7837 	/* Select result scale */
7838 	rscale = NUMERIC_MIN_SIG_DIGITS - qweight * DEC_DIGITS;
7839 	rscale = Max(rscale, var1->dscale);
7840 	rscale = Max(rscale, var2->dscale);
7841 	rscale = Max(rscale, NUMERIC_MIN_DISPLAY_SCALE);
7842 	rscale = Min(rscale, NUMERIC_MAX_DISPLAY_SCALE);
7843 
7844 	return rscale;
7845 }
7846 
7847 
7848 /*
7849  * mod_var() -
7850  *
7851  *	Calculate the modulo of two numerics at variable level
7852  */
7853 static void
mod_var(const NumericVar * var1,const NumericVar * var2,NumericVar * result)7854 mod_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result)
7855 {
7856 	NumericVar	tmp;
7857 
7858 	init_var(&tmp);
7859 
7860 	/* ---------
7861 	 * We do this using the equation
7862 	 *		mod(x,y) = x - trunc(x/y)*y
7863 	 * div_var can be persuaded to give us trunc(x/y) directly.
7864 	 * ----------
7865 	 */
7866 	div_var(var1, var2, &tmp, 0, false);
7867 
7868 	mul_var(var2, &tmp, &tmp, var2->dscale);
7869 
7870 	sub_var(var1, &tmp, result);
7871 
7872 	free_var(&tmp);
7873 }
7874 
7875 
7876 /*
7877  * ceil_var() -
7878  *
7879  *	Return the smallest integer greater than or equal to the argument
7880  *	on variable level
7881  */
7882 static void
ceil_var(const NumericVar * var,NumericVar * result)7883 ceil_var(const NumericVar *var, NumericVar *result)
7884 {
7885 	NumericVar	tmp;
7886 
7887 	init_var(&tmp);
7888 	set_var_from_var(var, &tmp);
7889 
7890 	trunc_var(&tmp, 0);
7891 
7892 	if (var->sign == NUMERIC_POS && cmp_var(var, &tmp) != 0)
7893 		add_var(&tmp, &const_one, &tmp);
7894 
7895 	set_var_from_var(&tmp, result);
7896 	free_var(&tmp);
7897 }
7898 
7899 
7900 /*
7901  * floor_var() -
7902  *
7903  *	Return the largest integer equal to or less than the argument
7904  *	on variable level
7905  */
7906 static void
floor_var(const NumericVar * var,NumericVar * result)7907 floor_var(const NumericVar *var, NumericVar *result)
7908 {
7909 	NumericVar	tmp;
7910 
7911 	init_var(&tmp);
7912 	set_var_from_var(var, &tmp);
7913 
7914 	trunc_var(&tmp, 0);
7915 
7916 	if (var->sign == NUMERIC_NEG && cmp_var(var, &tmp) != 0)
7917 		sub_var(&tmp, &const_one, &tmp);
7918 
7919 	set_var_from_var(&tmp, result);
7920 	free_var(&tmp);
7921 }
7922 
7923 
7924 /*
7925  * sqrt_var() -
7926  *
7927  *	Compute the square root of x using Newton's algorithm
7928  */
7929 static void
sqrt_var(const NumericVar * arg,NumericVar * result,int rscale)7930 sqrt_var(const NumericVar *arg, NumericVar *result, int rscale)
7931 {
7932 	NumericVar	tmp_arg;
7933 	NumericVar	tmp_val;
7934 	NumericVar	last_val;
7935 	int			local_rscale;
7936 	int			stat;
7937 
7938 	local_rscale = rscale + 8;
7939 
7940 	stat = cmp_var(arg, &const_zero);
7941 	if (stat == 0)
7942 	{
7943 		zero_var(result);
7944 		result->dscale = rscale;
7945 		return;
7946 	}
7947 
7948 	/*
7949 	 * SQL2003 defines sqrt() in terms of power, so we need to emit the right
7950 	 * SQLSTATE error code if the operand is negative.
7951 	 */
7952 	if (stat < 0)
7953 		ereport(ERROR,
7954 				(errcode(ERRCODE_INVALID_ARGUMENT_FOR_POWER_FUNCTION),
7955 				 errmsg("cannot take square root of a negative number")));
7956 
7957 	init_var(&tmp_arg);
7958 	init_var(&tmp_val);
7959 	init_var(&last_val);
7960 
7961 	/* Copy arg in case it is the same var as result */
7962 	set_var_from_var(arg, &tmp_arg);
7963 
7964 	/*
7965 	 * Initialize the result to the first guess
7966 	 */
7967 	alloc_var(result, 1);
7968 	result->digits[0] = tmp_arg.digits[0] / 2;
7969 	if (result->digits[0] == 0)
7970 		result->digits[0] = 1;
7971 	result->weight = tmp_arg.weight / 2;
7972 	result->sign = NUMERIC_POS;
7973 
7974 	set_var_from_var(result, &last_val);
7975 
7976 	for (;;)
7977 	{
7978 		div_var_fast(&tmp_arg, result, &tmp_val, local_rscale, true);
7979 
7980 		add_var(result, &tmp_val, result);
7981 		mul_var(result, &const_zero_point_five, result, local_rscale);
7982 
7983 		if (cmp_var(&last_val, result) == 0)
7984 			break;
7985 		set_var_from_var(result, &last_val);
7986 	}
7987 
7988 	free_var(&last_val);
7989 	free_var(&tmp_val);
7990 	free_var(&tmp_arg);
7991 
7992 	/* Round to requested precision */
7993 	round_var(result, rscale);
7994 }
7995 
7996 
7997 /*
7998  * exp_var() -
7999  *
8000  *	Raise e to the power of x, computed to rscale fractional digits
8001  */
8002 static void
exp_var(const NumericVar * arg,NumericVar * result,int rscale)8003 exp_var(const NumericVar *arg, NumericVar *result, int rscale)
8004 {
8005 	NumericVar	x;
8006 	NumericVar	elem;
8007 	NumericVar	ni;
8008 	double		val;
8009 	int			dweight;
8010 	int			ndiv2;
8011 	int			sig_digits;
8012 	int			local_rscale;
8013 
8014 	init_var(&x);
8015 	init_var(&elem);
8016 	init_var(&ni);
8017 
8018 	set_var_from_var(arg, &x);
8019 
8020 	/*
8021 	 * Estimate the dweight of the result using floating point arithmetic, so
8022 	 * that we can choose an appropriate local rscale for the calculation.
8023 	 */
8024 	val = numericvar_to_double_no_overflow(&x);
8025 
8026 	/* Guard against overflow/underflow */
8027 	/* If you change this limit, see also power_var()'s limit */
8028 	if (Abs(val) >= NUMERIC_MAX_RESULT_SCALE * 3)
8029 	{
8030 		if (val > 0)
8031 			ereport(ERROR,
8032 					(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
8033 					 errmsg("value overflows numeric format")));
8034 		zero_var(result);
8035 		result->dscale = rscale;
8036 		return;
8037 	}
8038 
8039 	/* decimal weight = log10(e^x) = x * log10(e) */
8040 	dweight = (int) (val * 0.434294481903252);
8041 
8042 	/*
8043 	 * Reduce x to the range -0.01 <= x <= 0.01 (approximately) by dividing by
8044 	 * 2^n, to improve the convergence rate of the Taylor series.
8045 	 */
8046 	if (Abs(val) > 0.01)
8047 	{
8048 		NumericVar	tmp;
8049 
8050 		init_var(&tmp);
8051 		set_var_from_var(&const_two, &tmp);
8052 
8053 		ndiv2 = 1;
8054 		val /= 2;
8055 
8056 		while (Abs(val) > 0.01)
8057 		{
8058 			ndiv2++;
8059 			val /= 2;
8060 			add_var(&tmp, &tmp, &tmp);
8061 		}
8062 
8063 		local_rscale = x.dscale + ndiv2;
8064 		div_var_fast(&x, &tmp, &x, local_rscale, true);
8065 
8066 		free_var(&tmp);
8067 	}
8068 	else
8069 		ndiv2 = 0;
8070 
8071 	/*
8072 	 * Set the scale for the Taylor series expansion.  The final result has
8073 	 * (dweight + rscale + 1) significant digits.  In addition, we have to
8074 	 * raise the Taylor series result to the power 2^ndiv2, which introduces
8075 	 * an error of up to around log10(2^ndiv2) digits, so work with this many
8076 	 * extra digits of precision (plus a few more for good measure).
8077 	 */
8078 	sig_digits = 1 + dweight + rscale + (int) (ndiv2 * 0.301029995663981);
8079 	sig_digits = Max(sig_digits, 0) + 8;
8080 
8081 	local_rscale = sig_digits - 1;
8082 
8083 	/*
8084 	 * Use the Taylor series
8085 	 *
8086 	 * exp(x) = 1 + x + x^2/2! + x^3/3! + ...
8087 	 *
8088 	 * Given the limited range of x, this should converge reasonably quickly.
8089 	 * We run the series until the terms fall below the local_rscale limit.
8090 	 */
8091 	add_var(&const_one, &x, result);
8092 
8093 	mul_var(&x, &x, &elem, local_rscale);
8094 	set_var_from_var(&const_two, &ni);
8095 	div_var_fast(&elem, &ni, &elem, local_rscale, true);
8096 
8097 	while (elem.ndigits != 0)
8098 	{
8099 		add_var(result, &elem, result);
8100 
8101 		mul_var(&elem, &x, &elem, local_rscale);
8102 		add_var(&ni, &const_one, &ni);
8103 		div_var_fast(&elem, &ni, &elem, local_rscale, true);
8104 	}
8105 
8106 	/*
8107 	 * Compensate for the argument range reduction.  Since the weight of the
8108 	 * result doubles with each multiplication, we can reduce the local rscale
8109 	 * as we proceed.
8110 	 */
8111 	while (ndiv2-- > 0)
8112 	{
8113 		local_rscale = sig_digits - result->weight * 2 * DEC_DIGITS;
8114 		local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
8115 		mul_var(result, result, result, local_rscale);
8116 	}
8117 
8118 	/* Round to requested rscale */
8119 	round_var(result, rscale);
8120 
8121 	free_var(&x);
8122 	free_var(&elem);
8123 	free_var(&ni);
8124 }
8125 
8126 
8127 /*
8128  * Estimate the dweight of the most significant decimal digit of the natural
8129  * logarithm of a number.
8130  *
8131  * Essentially, we're approximating log10(abs(ln(var))).  This is used to
8132  * determine the appropriate rscale when computing natural logarithms.
8133  */
8134 static int
estimate_ln_dweight(const NumericVar * var)8135 estimate_ln_dweight(const NumericVar *var)
8136 {
8137 	int			ln_dweight;
8138 
8139 	if (cmp_var(var, &const_zero_point_nine) >= 0 &&
8140 		cmp_var(var, &const_one_point_one) <= 0)
8141 	{
8142 		/*
8143 		 * 0.9 <= var <= 1.1
8144 		 *
8145 		 * ln(var) has a negative weight (possibly very large).  To get a
8146 		 * reasonably accurate result, estimate it using ln(1+x) ~= x.
8147 		 */
8148 		NumericVar	x;
8149 
8150 		init_var(&x);
8151 		sub_var(var, &const_one, &x);
8152 
8153 		if (x.ndigits > 0)
8154 		{
8155 			/* Use weight of most significant decimal digit of x */
8156 			ln_dweight = x.weight * DEC_DIGITS + (int) log10(x.digits[0]);
8157 		}
8158 		else
8159 		{
8160 			/* x = 0.  Since ln(1) = 0 exactly, we don't need extra digits */
8161 			ln_dweight = 0;
8162 		}
8163 
8164 		free_var(&x);
8165 	}
8166 	else
8167 	{
8168 		/*
8169 		 * Estimate the logarithm using the first couple of digits from the
8170 		 * input number.  This will give an accurate result whenever the input
8171 		 * is not too close to 1.
8172 		 */
8173 		if (var->ndigits > 0)
8174 		{
8175 			int			digits;
8176 			int			dweight;
8177 			double		ln_var;
8178 
8179 			digits = var->digits[0];
8180 			dweight = var->weight * DEC_DIGITS;
8181 
8182 			if (var->ndigits > 1)
8183 			{
8184 				digits = digits * NBASE + var->digits[1];
8185 				dweight -= DEC_DIGITS;
8186 			}
8187 
8188 			/*----------
8189 			 * We have var ~= digits * 10^dweight
8190 			 * so ln(var) ~= ln(digits) + dweight * ln(10)
8191 			 *----------
8192 			 */
8193 			ln_var = log((double) digits) + dweight * 2.302585092994046;
8194 			ln_dweight = (int) log10(Abs(ln_var));
8195 		}
8196 		else
8197 		{
8198 			/* Caller should fail on ln(0), but for the moment return zero */
8199 			ln_dweight = 0;
8200 		}
8201 	}
8202 
8203 	return ln_dweight;
8204 }
8205 
8206 
8207 /*
8208  * ln_var() -
8209  *
8210  *	Compute the natural log of x
8211  */
8212 static void
ln_var(const NumericVar * arg,NumericVar * result,int rscale)8213 ln_var(const NumericVar *arg, NumericVar *result, int rscale)
8214 {
8215 	NumericVar	x;
8216 	NumericVar	xx;
8217 	NumericVar	ni;
8218 	NumericVar	elem;
8219 	NumericVar	fact;
8220 	int			local_rscale;
8221 	int			cmp;
8222 
8223 	cmp = cmp_var(arg, &const_zero);
8224 	if (cmp == 0)
8225 		ereport(ERROR,
8226 				(errcode(ERRCODE_INVALID_ARGUMENT_FOR_LOG),
8227 				 errmsg("cannot take logarithm of zero")));
8228 	else if (cmp < 0)
8229 		ereport(ERROR,
8230 				(errcode(ERRCODE_INVALID_ARGUMENT_FOR_LOG),
8231 				 errmsg("cannot take logarithm of a negative number")));
8232 
8233 	init_var(&x);
8234 	init_var(&xx);
8235 	init_var(&ni);
8236 	init_var(&elem);
8237 	init_var(&fact);
8238 
8239 	set_var_from_var(arg, &x);
8240 	set_var_from_var(&const_two, &fact);
8241 
8242 	/*
8243 	 * Reduce input into range 0.9 < x < 1.1 with repeated sqrt() operations.
8244 	 *
8245 	 * The final logarithm will have up to around rscale+6 significant digits.
8246 	 * Each sqrt() will roughly halve the weight of x, so adjust the local
8247 	 * rscale as we work so that we keep this many significant digits at each
8248 	 * step (plus a few more for good measure).
8249 	 */
8250 	while (cmp_var(&x, &const_zero_point_nine) <= 0)
8251 	{
8252 		local_rscale = rscale - x.weight * DEC_DIGITS / 2 + 8;
8253 		local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
8254 		sqrt_var(&x, &x, local_rscale);
8255 		mul_var(&fact, &const_two, &fact, 0);
8256 	}
8257 	while (cmp_var(&x, &const_one_point_one) >= 0)
8258 	{
8259 		local_rscale = rscale - x.weight * DEC_DIGITS / 2 + 8;
8260 		local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
8261 		sqrt_var(&x, &x, local_rscale);
8262 		mul_var(&fact, &const_two, &fact, 0);
8263 	}
8264 
8265 	/*
8266 	 * We use the Taylor series for 0.5 * ln((1+z)/(1-z)),
8267 	 *
8268 	 * z + z^3/3 + z^5/5 + ...
8269 	 *
8270 	 * where z = (x-1)/(x+1) is in the range (approximately) -0.053 .. 0.048
8271 	 * due to the above range-reduction of x.
8272 	 *
8273 	 * The convergence of this is not as fast as one would like, but is
8274 	 * tolerable given that z is small.
8275 	 */
8276 	local_rscale = rscale + 8;
8277 
8278 	sub_var(&x, &const_one, result);
8279 	add_var(&x, &const_one, &elem);
8280 	div_var_fast(result, &elem, result, local_rscale, true);
8281 	set_var_from_var(result, &xx);
8282 	mul_var(result, result, &x, local_rscale);
8283 
8284 	set_var_from_var(&const_one, &ni);
8285 
8286 	for (;;)
8287 	{
8288 		add_var(&ni, &const_two, &ni);
8289 		mul_var(&xx, &x, &xx, local_rscale);
8290 		div_var_fast(&xx, &ni, &elem, local_rscale, true);
8291 
8292 		if (elem.ndigits == 0)
8293 			break;
8294 
8295 		add_var(result, &elem, result);
8296 
8297 		if (elem.weight < (result->weight - local_rscale * 2 / DEC_DIGITS))
8298 			break;
8299 	}
8300 
8301 	/* Compensate for argument range reduction, round to requested rscale */
8302 	mul_var(result, &fact, result, rscale);
8303 
8304 	free_var(&x);
8305 	free_var(&xx);
8306 	free_var(&ni);
8307 	free_var(&elem);
8308 	free_var(&fact);
8309 }
8310 
8311 
8312 /*
8313  * log_var() -
8314  *
8315  *	Compute the logarithm of num in a given base.
8316  *
8317  *	Note: this routine chooses dscale of the result.
8318  */
8319 static void
log_var(const NumericVar * base,const NumericVar * num,NumericVar * result)8320 log_var(const NumericVar *base, const NumericVar *num, NumericVar *result)
8321 {
8322 	NumericVar	ln_base;
8323 	NumericVar	ln_num;
8324 	int			ln_base_dweight;
8325 	int			ln_num_dweight;
8326 	int			result_dweight;
8327 	int			rscale;
8328 	int			ln_base_rscale;
8329 	int			ln_num_rscale;
8330 
8331 	init_var(&ln_base);
8332 	init_var(&ln_num);
8333 
8334 	/* Estimated dweights of ln(base), ln(num) and the final result */
8335 	ln_base_dweight = estimate_ln_dweight(base);
8336 	ln_num_dweight = estimate_ln_dweight(num);
8337 	result_dweight = ln_num_dweight - ln_base_dweight;
8338 
8339 	/*
8340 	 * Select the scale of the result so that it will have at least
8341 	 * NUMERIC_MIN_SIG_DIGITS significant digits and is not less than either
8342 	 * input's display scale.
8343 	 */
8344 	rscale = NUMERIC_MIN_SIG_DIGITS - result_dweight;
8345 	rscale = Max(rscale, base->dscale);
8346 	rscale = Max(rscale, num->dscale);
8347 	rscale = Max(rscale, NUMERIC_MIN_DISPLAY_SCALE);
8348 	rscale = Min(rscale, NUMERIC_MAX_DISPLAY_SCALE);
8349 
8350 	/*
8351 	 * Set the scales for ln(base) and ln(num) so that they each have more
8352 	 * significant digits than the final result.
8353 	 */
8354 	ln_base_rscale = rscale + result_dweight - ln_base_dweight + 8;
8355 	ln_base_rscale = Max(ln_base_rscale, NUMERIC_MIN_DISPLAY_SCALE);
8356 
8357 	ln_num_rscale = rscale + result_dweight - ln_num_dweight + 8;
8358 	ln_num_rscale = Max(ln_num_rscale, NUMERIC_MIN_DISPLAY_SCALE);
8359 
8360 	/* Form natural logarithms */
8361 	ln_var(base, &ln_base, ln_base_rscale);
8362 	ln_var(num, &ln_num, ln_num_rscale);
8363 
8364 	/* Divide and round to the required scale */
8365 	div_var_fast(&ln_num, &ln_base, result, rscale, true);
8366 
8367 	free_var(&ln_num);
8368 	free_var(&ln_base);
8369 }
8370 
8371 
8372 /*
8373  * power_var() -
8374  *
8375  *	Raise base to the power of exp
8376  *
8377  *	Note: this routine chooses dscale of the result.
8378  */
8379 static void
power_var(const NumericVar * base,const NumericVar * exp,NumericVar * result)8380 power_var(const NumericVar *base, const NumericVar *exp, NumericVar *result)
8381 {
8382 	int			res_sign;
8383 	NumericVar	abs_base;
8384 	NumericVar	ln_base;
8385 	NumericVar	ln_num;
8386 	int			ln_dweight;
8387 	int			rscale;
8388 	int			sig_digits;
8389 	int			local_rscale;
8390 	double		val;
8391 
8392 	/* If exp can be represented as an integer, use power_var_int */
8393 	if (exp->ndigits == 0 || exp->ndigits <= exp->weight + 1)
8394 	{
8395 		/* exact integer, but does it fit in int? */
8396 		int64		expval64;
8397 
8398 		if (numericvar_to_int64(exp, &expval64))
8399 		{
8400 			if (expval64 >= PG_INT32_MIN && expval64 <= PG_INT32_MAX)
8401 			{
8402 				/* Okay, select rscale */
8403 				rscale = NUMERIC_MIN_SIG_DIGITS;
8404 				rscale = Max(rscale, base->dscale);
8405 				rscale = Max(rscale, NUMERIC_MIN_DISPLAY_SCALE);
8406 				rscale = Min(rscale, NUMERIC_MAX_DISPLAY_SCALE);
8407 
8408 				power_var_int(base, (int) expval64, result, rscale);
8409 				return;
8410 			}
8411 		}
8412 	}
8413 
8414 	/*
8415 	 * This avoids log(0) for cases of 0 raised to a non-integer.  0 ^ 0 is
8416 	 * handled by power_var_int().
8417 	 */
8418 	if (cmp_var(base, &const_zero) == 0)
8419 	{
8420 		set_var_from_var(&const_zero, result);
8421 		result->dscale = NUMERIC_MIN_SIG_DIGITS;	/* no need to round */
8422 		return;
8423 	}
8424 
8425 	init_var(&abs_base);
8426 	init_var(&ln_base);
8427 	init_var(&ln_num);
8428 
8429 	/*
8430 	 * If base is negative, insist that exp be an integer.  The result is then
8431 	 * positive if exp is even and negative if exp is odd.
8432 	 */
8433 	if (base->sign == NUMERIC_NEG)
8434 	{
8435 		/*
8436 		 * Check that exp is an integer.  This error code is defined by the
8437 		 * SQL standard, and matches other errors in numeric_power().
8438 		 */
8439 		if (exp->ndigits > 0 && exp->ndigits > exp->weight + 1)
8440 			ereport(ERROR,
8441 					(errcode(ERRCODE_INVALID_ARGUMENT_FOR_POWER_FUNCTION),
8442 					 errmsg("a negative number raised to a non-integer power yields a complex result")));
8443 
8444 		/* Test if exp is odd or even */
8445 		if (exp->ndigits > 0 && exp->ndigits == exp->weight + 1 &&
8446 			(exp->digits[exp->ndigits - 1] & 1))
8447 			res_sign = NUMERIC_NEG;
8448 		else
8449 			res_sign = NUMERIC_POS;
8450 
8451 		/* Then work with abs(base) below */
8452 		set_var_from_var(base, &abs_base);
8453 		abs_base.sign = NUMERIC_POS;
8454 		base = &abs_base;
8455 	}
8456 	else
8457 		res_sign = NUMERIC_POS;
8458 
8459 	/*----------
8460 	 * Decide on the scale for the ln() calculation.  For this we need an
8461 	 * estimate of the weight of the result, which we obtain by doing an
8462 	 * initial low-precision calculation of exp * ln(base).
8463 	 *
8464 	 * We want result = e ^ (exp * ln(base))
8465 	 * so result dweight = log10(result) = exp * ln(base) * log10(e)
8466 	 *
8467 	 * We also perform a crude overflow test here so that we can exit early if
8468 	 * the full-precision result is sure to overflow, and to guard against
8469 	 * integer overflow when determining the scale for the real calculation.
8470 	 * exp_var() supports inputs up to NUMERIC_MAX_RESULT_SCALE * 3, so the
8471 	 * result will overflow if exp * ln(base) >= NUMERIC_MAX_RESULT_SCALE * 3.
8472 	 * Since the values here are only approximations, we apply a small fuzz
8473 	 * factor to this overflow test and let exp_var() determine the exact
8474 	 * overflow threshold so that it is consistent for all inputs.
8475 	 *----------
8476 	 */
8477 	ln_dweight = estimate_ln_dweight(base);
8478 
8479 	/*
8480 	 * Set the scale for the low-precision calculation, computing ln(base) to
8481 	 * around 8 significant digits.  Note that ln_dweight may be as small as
8482 	 * -SHRT_MAX, so the scale may exceed NUMERIC_MAX_DISPLAY_SCALE here.
8483 	 */
8484 	local_rscale = 8 - ln_dweight;
8485 	local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
8486 
8487 	ln_var(base, &ln_base, local_rscale);
8488 
8489 	mul_var(&ln_base, exp, &ln_num, local_rscale);
8490 
8491 	val = numericvar_to_double_no_overflow(&ln_num);
8492 
8493 	/* initial overflow/underflow test with fuzz factor */
8494 	if (Abs(val) > NUMERIC_MAX_RESULT_SCALE * 3.01)
8495 	{
8496 		if (val > 0)
8497 			ereport(ERROR,
8498 					(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
8499 					 errmsg("value overflows numeric format")));
8500 		zero_var(result);
8501 		result->dscale = NUMERIC_MAX_DISPLAY_SCALE;
8502 		return;
8503 	}
8504 
8505 	val *= 0.434294481903252;	/* approximate decimal result weight */
8506 
8507 	/* choose the result scale */
8508 	rscale = NUMERIC_MIN_SIG_DIGITS - (int) val;
8509 	rscale = Max(rscale, base->dscale);
8510 	rscale = Max(rscale, exp->dscale);
8511 	rscale = Max(rscale, NUMERIC_MIN_DISPLAY_SCALE);
8512 	rscale = Min(rscale, NUMERIC_MAX_DISPLAY_SCALE);
8513 
8514 	/* significant digits required in the result */
8515 	sig_digits = rscale + (int) val;
8516 	sig_digits = Max(sig_digits, 0);
8517 
8518 	/* set the scale for the real exp * ln(base) calculation */
8519 	local_rscale = sig_digits - ln_dweight + 8;
8520 	local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
8521 
8522 	/* and do the real calculation */
8523 
8524 	ln_var(base, &ln_base, local_rscale);
8525 
8526 	mul_var(&ln_base, exp, &ln_num, local_rscale);
8527 
8528 	exp_var(&ln_num, result, rscale);
8529 
8530 	if (res_sign == NUMERIC_NEG && result->ndigits > 0)
8531 		result->sign = NUMERIC_NEG;
8532 
8533 	free_var(&ln_num);
8534 	free_var(&ln_base);
8535 	free_var(&abs_base);
8536 }
8537 
8538 /*
8539  * power_var_int() -
8540  *
8541  *	Raise base to the power of exp, where exp is an integer.
8542  */
8543 static void
power_var_int(const NumericVar * base,int exp,NumericVar * result,int rscale)8544 power_var_int(const NumericVar *base, int exp, NumericVar *result, int rscale)
8545 {
8546 	double		f;
8547 	int			p;
8548 	int			i;
8549 	int			sig_digits;
8550 	unsigned int mask;
8551 	bool		neg;
8552 	NumericVar	base_prod;
8553 	int			local_rscale;
8554 
8555 	/* Handle some common special cases, as well as corner cases */
8556 	switch (exp)
8557 	{
8558 		case 0:
8559 
8560 			/*
8561 			 * While 0 ^ 0 can be either 1 or indeterminate (error), we treat
8562 			 * it as 1 because most programming languages do this. SQL:2003
8563 			 * also requires a return value of 1.
8564 			 * https://en.wikipedia.org/wiki/Exponentiation#Zero_to_the_zero_power
8565 			 */
8566 			set_var_from_var(&const_one, result);
8567 			result->dscale = rscale;	/* no need to round */
8568 			return;
8569 		case 1:
8570 			set_var_from_var(base, result);
8571 			round_var(result, rscale);
8572 			return;
8573 		case -1:
8574 			div_var(&const_one, base, result, rscale, true);
8575 			return;
8576 		case 2:
8577 			mul_var(base, base, result, rscale);
8578 			return;
8579 		default:
8580 			break;
8581 	}
8582 
8583 	/* Handle the special case where the base is zero */
8584 	if (base->ndigits == 0)
8585 	{
8586 		if (exp < 0)
8587 			ereport(ERROR,
8588 					(errcode(ERRCODE_DIVISION_BY_ZERO),
8589 					 errmsg("division by zero")));
8590 		zero_var(result);
8591 		result->dscale = rscale;
8592 		return;
8593 	}
8594 
8595 	/*
8596 	 * The general case repeatedly multiplies base according to the bit
8597 	 * pattern of exp.
8598 	 *
8599 	 * First we need to estimate the weight of the result so that we know how
8600 	 * many significant digits are needed.
8601 	 */
8602 	f = base->digits[0];
8603 	p = base->weight * DEC_DIGITS;
8604 
8605 	for (i = 1; i < base->ndigits && i * DEC_DIGITS < 16; i++)
8606 	{
8607 		f = f * NBASE + base->digits[i];
8608 		p -= DEC_DIGITS;
8609 	}
8610 
8611 	/*----------
8612 	 * We have base ~= f * 10^p
8613 	 * so log10(result) = log10(base^exp) ~= exp * (log10(f) + p)
8614 	 *----------
8615 	 */
8616 	f = exp * (log10(f) + p);
8617 
8618 	/*
8619 	 * Apply crude overflow/underflow tests so we can exit early if the result
8620 	 * certainly will overflow/underflow.
8621 	 */
8622 	if (f > 3 * SHRT_MAX * DEC_DIGITS)
8623 		ereport(ERROR,
8624 				(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
8625 				 errmsg("value overflows numeric format")));
8626 	if (f + 1 < -rscale || f + 1 < -NUMERIC_MAX_DISPLAY_SCALE)
8627 	{
8628 		zero_var(result);
8629 		result->dscale = rscale;
8630 		return;
8631 	}
8632 
8633 	/*
8634 	 * Approximate number of significant digits in the result.  Note that the
8635 	 * underflow test above means that this is necessarily >= 0.
8636 	 */
8637 	sig_digits = 1 + rscale + (int) f;
8638 
8639 	/*
8640 	 * The multiplications to produce the result may introduce an error of up
8641 	 * to around log10(abs(exp)) digits, so work with this many extra digits
8642 	 * of precision (plus a few more for good measure).
8643 	 */
8644 	sig_digits += (int) log(fabs((double) exp)) + 8;
8645 
8646 	/*
8647 	 * Now we can proceed with the multiplications.
8648 	 */
8649 	neg = (exp < 0);
8650 	mask = Abs(exp);
8651 
8652 	init_var(&base_prod);
8653 	set_var_from_var(base, &base_prod);
8654 
8655 	if (mask & 1)
8656 		set_var_from_var(base, result);
8657 	else
8658 		set_var_from_var(&const_one, result);
8659 
8660 	while ((mask >>= 1) > 0)
8661 	{
8662 		/*
8663 		 * Do the multiplications using rscales large enough to hold the
8664 		 * results to the required number of significant digits, but don't
8665 		 * waste time by exceeding the scales of the numbers themselves.
8666 		 */
8667 		local_rscale = sig_digits - 2 * base_prod.weight * DEC_DIGITS;
8668 		local_rscale = Min(local_rscale, 2 * base_prod.dscale);
8669 		local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
8670 
8671 		mul_var(&base_prod, &base_prod, &base_prod, local_rscale);
8672 
8673 		if (mask & 1)
8674 		{
8675 			local_rscale = sig_digits -
8676 				(base_prod.weight + result->weight) * DEC_DIGITS;
8677 			local_rscale = Min(local_rscale,
8678 							   base_prod.dscale + result->dscale);
8679 			local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
8680 
8681 			mul_var(&base_prod, result, result, local_rscale);
8682 		}
8683 
8684 		/*
8685 		 * When abs(base) > 1, the number of digits to the left of the decimal
8686 		 * point in base_prod doubles at each iteration, so if exp is large we
8687 		 * could easily spend large amounts of time and memory space doing the
8688 		 * multiplications.  But once the weight exceeds what will fit in
8689 		 * int16, the final result is guaranteed to overflow (or underflow, if
8690 		 * exp < 0), so we can give up before wasting too many cycles.
8691 		 */
8692 		if (base_prod.weight > SHRT_MAX || result->weight > SHRT_MAX)
8693 		{
8694 			/* overflow, unless neg, in which case result should be 0 */
8695 			if (!neg)
8696 				ereport(ERROR,
8697 						(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
8698 						 errmsg("value overflows numeric format")));
8699 			zero_var(result);
8700 			neg = false;
8701 			break;
8702 		}
8703 	}
8704 
8705 	free_var(&base_prod);
8706 
8707 	/* Compensate for input sign, and round to requested rscale */
8708 	if (neg)
8709 		div_var_fast(&const_one, result, result, rscale, true);
8710 	else
8711 		round_var(result, rscale);
8712 }
8713 
8714 /*
8715  * power_ten_int() -
8716  *
8717  *	Raise ten to the power of exp, where exp is an integer.  Note that unlike
8718  *	power_var_int(), this does no overflow/underflow checking or rounding.
8719  */
8720 static void
power_ten_int(int exp,NumericVar * result)8721 power_ten_int(int exp, NumericVar *result)
8722 {
8723 	/* Construct the result directly, starting from 10^0 = 1 */
8724 	set_var_from_var(&const_one, result);
8725 
8726 	/* Scale needed to represent the result exactly */
8727 	result->dscale = exp < 0 ? -exp : 0;
8728 
8729 	/* Base-NBASE weight of result and remaining exponent */
8730 	if (exp >= 0)
8731 		result->weight = exp / DEC_DIGITS;
8732 	else
8733 		result->weight = (exp + 1) / DEC_DIGITS - 1;
8734 
8735 	exp -= result->weight * DEC_DIGITS;
8736 
8737 	/* Final adjustment of the result's single NBASE digit */
8738 	while (exp-- > 0)
8739 		result->digits[0] *= 10;
8740 }
8741 
8742 
8743 /* ----------------------------------------------------------------------
8744  *
 * The following are the lowest-level functions, which work on NumericVars
 * at the digit level and mostly ignore their signs
8747  *
8748  * ----------------------------------------------------------------------
8749  */
8750 
8751 
8752 /* ----------
8753  * cmp_abs() -
8754  *
8755  *	Compare the absolute values of var1 and var2
8756  *	Returns:	-1 for ABS(var1) < ABS(var2)
8757  *				0  for ABS(var1) == ABS(var2)
8758  *				1  for ABS(var1) > ABS(var2)
8759  * ----------
8760  */
8761 static int
cmp_abs(const NumericVar * var1,const NumericVar * var2)8762 cmp_abs(const NumericVar *var1, const NumericVar *var2)
8763 {
8764 	return cmp_abs_common(var1->digits, var1->ndigits, var1->weight,
8765 						  var2->digits, var2->ndigits, var2->weight);
8766 }
8767 
8768 /* ----------
8769  * cmp_abs_common() -
8770  *
8771  *	Main routine of cmp_abs(). This function can be used by both
8772  *	NumericVar and Numeric.
8773  * ----------
8774  */
8775 static int
cmp_abs_common(const NumericDigit * var1digits,int var1ndigits,int var1weight,const NumericDigit * var2digits,int var2ndigits,int var2weight)8776 cmp_abs_common(const NumericDigit *var1digits, int var1ndigits, int var1weight,
8777 			   const NumericDigit *var2digits, int var2ndigits, int var2weight)
8778 {
8779 	int			i1 = 0;
8780 	int			i2 = 0;
8781 
8782 	/* Check any digits before the first common digit */
8783 
8784 	while (var1weight > var2weight && i1 < var1ndigits)
8785 	{
8786 		if (var1digits[i1++] != 0)
8787 			return 1;
8788 		var1weight--;
8789 	}
8790 	while (var2weight > var1weight && i2 < var2ndigits)
8791 	{
8792 		if (var2digits[i2++] != 0)
8793 			return -1;
8794 		var2weight--;
8795 	}
8796 
8797 	/* At this point, either w1 == w2 or we've run out of digits */
8798 
8799 	if (var1weight == var2weight)
8800 	{
8801 		while (i1 < var1ndigits && i2 < var2ndigits)
8802 		{
8803 			int			stat = var1digits[i1++] - var2digits[i2++];
8804 
8805 			if (stat)
8806 			{
8807 				if (stat > 0)
8808 					return 1;
8809 				return -1;
8810 			}
8811 		}
8812 	}
8813 
8814 	/*
8815 	 * At this point, we've run out of digits on one side or the other; so any
8816 	 * remaining nonzero digits imply that side is larger
8817 	 */
8818 	while (i1 < var1ndigits)
8819 	{
8820 		if (var1digits[i1++] != 0)
8821 			return 1;
8822 	}
8823 	while (i2 < var2ndigits)
8824 	{
8825 		if (var2digits[i2++] != 0)
8826 			return -1;
8827 	}
8828 
8829 	return 0;
8830 }
8831 
8832 
8833 /*
8834  * add_abs() -
8835  *
8836  *	Add the absolute values of two variables into result.
8837  *	result might point to one of the operands without danger.
8838  */
8839 static void
add_abs(const NumericVar * var1,const NumericVar * var2,NumericVar * result)8840 add_abs(const NumericVar *var1, const NumericVar *var2, NumericVar *result)
8841 {
8842 	NumericDigit *res_buf;
8843 	NumericDigit *res_digits;
8844 	int			res_ndigits;
8845 	int			res_weight;
8846 	int			res_rscale,
8847 				rscale1,
8848 				rscale2;
8849 	int			res_dscale;
8850 	int			i,
8851 				i1,
8852 				i2;
8853 	int			carry = 0;
8854 
8855 	/* copy these values into local vars for speed in inner loop */
8856 	int			var1ndigits = var1->ndigits;
8857 	int			var2ndigits = var2->ndigits;
8858 	NumericDigit *var1digits = var1->digits;
8859 	NumericDigit *var2digits = var2->digits;
8860 
8861 	res_weight = Max(var1->weight, var2->weight) + 1;
8862 
8863 	res_dscale = Max(var1->dscale, var2->dscale);
8864 
8865 	/* Note: here we are figuring rscale in base-NBASE digits */
8866 	rscale1 = var1->ndigits - var1->weight - 1;
8867 	rscale2 = var2->ndigits - var2->weight - 1;
8868 	res_rscale = Max(rscale1, rscale2);
8869 
8870 	res_ndigits = res_rscale + res_weight + 1;
8871 	if (res_ndigits <= 0)
8872 		res_ndigits = 1;
8873 
8874 	res_buf = digitbuf_alloc(res_ndigits + 1);
8875 	res_buf[0] = 0;				/* spare digit for later rounding */
8876 	res_digits = res_buf + 1;
8877 
8878 	i1 = res_rscale + var1->weight + 1;
8879 	i2 = res_rscale + var2->weight + 1;
8880 	for (i = res_ndigits - 1; i >= 0; i--)
8881 	{
8882 		i1--;
8883 		i2--;
8884 		if (i1 >= 0 && i1 < var1ndigits)
8885 			carry += var1digits[i1];
8886 		if (i2 >= 0 && i2 < var2ndigits)
8887 			carry += var2digits[i2];
8888 
8889 		if (carry >= NBASE)
8890 		{
8891 			res_digits[i] = carry - NBASE;
8892 			carry = 1;
8893 		}
8894 		else
8895 		{
8896 			res_digits[i] = carry;
8897 			carry = 0;
8898 		}
8899 	}
8900 
8901 	Assert(carry == 0);			/* else we failed to allow for carry out */
8902 
8903 	digitbuf_free(result->buf);
8904 	result->ndigits = res_ndigits;
8905 	result->buf = res_buf;
8906 	result->digits = res_digits;
8907 	result->weight = res_weight;
8908 	result->dscale = res_dscale;
8909 
8910 	/* Remove leading/trailing zeroes */
8911 	strip_var(result);
8912 }
8913 
8914 
8915 /*
8916  * sub_abs()
8917  *
8918  *	Subtract the absolute value of var2 from the absolute value of var1
8919  *	and store in result. result might point to one of the operands
8920  *	without danger.
8921  *
8922  *	ABS(var1) MUST BE GREATER OR EQUAL ABS(var2) !!!
8923  */
8924 static void
sub_abs(const NumericVar * var1,const NumericVar * var2,NumericVar * result)8925 sub_abs(const NumericVar *var1, const NumericVar *var2, NumericVar *result)
8926 {
8927 	NumericDigit *res_buf;
8928 	NumericDigit *res_digits;
8929 	int			res_ndigits;
8930 	int			res_weight;
8931 	int			res_rscale,
8932 				rscale1,
8933 				rscale2;
8934 	int			res_dscale;
8935 	int			i,
8936 				i1,
8937 				i2;
8938 	int			borrow = 0;
8939 
8940 	/* copy these values into local vars for speed in inner loop */
8941 	int			var1ndigits = var1->ndigits;
8942 	int			var2ndigits = var2->ndigits;
8943 	NumericDigit *var1digits = var1->digits;
8944 	NumericDigit *var2digits = var2->digits;
8945 
8946 	res_weight = var1->weight;
8947 
8948 	res_dscale = Max(var1->dscale, var2->dscale);
8949 
8950 	/* Note: here we are figuring rscale in base-NBASE digits */
8951 	rscale1 = var1->ndigits - var1->weight - 1;
8952 	rscale2 = var2->ndigits - var2->weight - 1;
8953 	res_rscale = Max(rscale1, rscale2);
8954 
8955 	res_ndigits = res_rscale + res_weight + 1;
8956 	if (res_ndigits <= 0)
8957 		res_ndigits = 1;
8958 
8959 	res_buf = digitbuf_alloc(res_ndigits + 1);
8960 	res_buf[0] = 0;				/* spare digit for later rounding */
8961 	res_digits = res_buf + 1;
8962 
8963 	i1 = res_rscale + var1->weight + 1;
8964 	i2 = res_rscale + var2->weight + 1;
8965 	for (i = res_ndigits - 1; i >= 0; i--)
8966 	{
8967 		i1--;
8968 		i2--;
8969 		if (i1 >= 0 && i1 < var1ndigits)
8970 			borrow += var1digits[i1];
8971 		if (i2 >= 0 && i2 < var2ndigits)
8972 			borrow -= var2digits[i2];
8973 
8974 		if (borrow < 0)
8975 		{
8976 			res_digits[i] = borrow + NBASE;
8977 			borrow = -1;
8978 		}
8979 		else
8980 		{
8981 			res_digits[i] = borrow;
8982 			borrow = 0;
8983 		}
8984 	}
8985 
8986 	Assert(borrow == 0);		/* else caller gave us var1 < var2 */
8987 
8988 	digitbuf_free(result->buf);
8989 	result->ndigits = res_ndigits;
8990 	result->buf = res_buf;
8991 	result->digits = res_digits;
8992 	result->weight = res_weight;
8993 	result->dscale = res_dscale;
8994 
8995 	/* Remove leading/trailing zeroes */
8996 	strip_var(result);
8997 }
8998 
8999 /*
9000  * round_var
9001  *
9002  * Round the value of a variable to no more than rscale decimal digits
9003  * after the decimal point.  NOTE: we allow rscale < 0 here, implying
9004  * rounding before the decimal point.
9005  */
9006 static void
round_var(NumericVar * var,int rscale)9007 round_var(NumericVar *var, int rscale)
9008 {
9009 	NumericDigit *digits = var->digits;
9010 	int			di;
9011 	int			ndigits;
9012 	int			carry;
9013 
9014 	var->dscale = rscale;
9015 
9016 	/* decimal digits wanted */
9017 	di = (var->weight + 1) * DEC_DIGITS + rscale;
9018 
9019 	/*
9020 	 * If di = 0, the value loses all digits, but could round up to 1 if its
9021 	 * first extra digit is >= 5.  If di < 0 the result must be 0.
9022 	 */
9023 	if (di < 0)
9024 	{
9025 		var->ndigits = 0;
9026 		var->weight = 0;
9027 		var->sign = NUMERIC_POS;
9028 	}
9029 	else
9030 	{
9031 		/* NBASE digits wanted */
9032 		ndigits = (di + DEC_DIGITS - 1) / DEC_DIGITS;
9033 
9034 		/* 0, or number of decimal digits to keep in last NBASE digit */
9035 		di %= DEC_DIGITS;
9036 
9037 		if (ndigits < var->ndigits ||
9038 			(ndigits == var->ndigits && di > 0))
9039 		{
9040 			var->ndigits = ndigits;
9041 
9042 #if DEC_DIGITS == 1
9043 			/* di must be zero */
9044 			carry = (digits[ndigits] >= HALF_NBASE) ? 1 : 0;
9045 #else
9046 			if (di == 0)
9047 				carry = (digits[ndigits] >= HALF_NBASE) ? 1 : 0;
9048 			else
9049 			{
9050 				/* Must round within last NBASE digit */
9051 				int			extra,
9052 							pow10;
9053 
9054 #if DEC_DIGITS == 4
9055 				pow10 = round_powers[di];
9056 #elif DEC_DIGITS == 2
9057 				pow10 = 10;
9058 #else
9059 #error unsupported NBASE
9060 #endif
9061 				extra = digits[--ndigits] % pow10;
9062 				digits[ndigits] -= extra;
9063 				carry = 0;
9064 				if (extra >= pow10 / 2)
9065 				{
9066 					pow10 += digits[ndigits];
9067 					if (pow10 >= NBASE)
9068 					{
9069 						pow10 -= NBASE;
9070 						carry = 1;
9071 					}
9072 					digits[ndigits] = pow10;
9073 				}
9074 			}
9075 #endif
9076 
9077 			/* Propagate carry if needed */
9078 			while (carry)
9079 			{
9080 				carry += digits[--ndigits];
9081 				if (carry >= NBASE)
9082 				{
9083 					digits[ndigits] = carry - NBASE;
9084 					carry = 1;
9085 				}
9086 				else
9087 				{
9088 					digits[ndigits] = carry;
9089 					carry = 0;
9090 				}
9091 			}
9092 
9093 			if (ndigits < 0)
9094 			{
9095 				Assert(ndigits == -1);	/* better not have added > 1 digit */
9096 				Assert(var->digits > var->buf);
9097 				var->digits--;
9098 				var->ndigits++;
9099 				var->weight++;
9100 			}
9101 		}
9102 	}
9103 }
9104 
9105 /*
9106  * trunc_var
9107  *
9108  * Truncate (towards zero) the value of a variable at rscale decimal digits
9109  * after the decimal point.  NOTE: we allow rscale < 0 here, implying
9110  * truncation before the decimal point.
9111  */
9112 static void
trunc_var(NumericVar * var,int rscale)9113 trunc_var(NumericVar *var, int rscale)
9114 {
9115 	int			di;
9116 	int			ndigits;
9117 
9118 	var->dscale = rscale;
9119 
9120 	/* decimal digits wanted */
9121 	di = (var->weight + 1) * DEC_DIGITS + rscale;
9122 
9123 	/*
9124 	 * If di <= 0, the value loses all digits.
9125 	 */
9126 	if (di <= 0)
9127 	{
9128 		var->ndigits = 0;
9129 		var->weight = 0;
9130 		var->sign = NUMERIC_POS;
9131 	}
9132 	else
9133 	{
9134 		/* NBASE digits wanted */
9135 		ndigits = (di + DEC_DIGITS - 1) / DEC_DIGITS;
9136 
9137 		if (ndigits <= var->ndigits)
9138 		{
9139 			var->ndigits = ndigits;
9140 
9141 #if DEC_DIGITS == 1
9142 			/* no within-digit stuff to worry about */
9143 #else
9144 			/* 0, or number of decimal digits to keep in last NBASE digit */
9145 			di %= DEC_DIGITS;
9146 
9147 			if (di > 0)
9148 			{
9149 				/* Must truncate within last NBASE digit */
9150 				NumericDigit *digits = var->digits;
9151 				int			extra,
9152 							pow10;
9153 
9154 #if DEC_DIGITS == 4
9155 				pow10 = round_powers[di];
9156 #elif DEC_DIGITS == 2
9157 				pow10 = 10;
9158 #else
9159 #error unsupported NBASE
9160 #endif
9161 				extra = digits[--ndigits] % pow10;
9162 				digits[ndigits] -= extra;
9163 			}
9164 #endif
9165 		}
9166 	}
9167 }
9168 
9169 /*
9170  * strip_var
9171  *
9172  * Strip any leading and trailing zeroes from a numeric variable
9173  */
9174 static void
strip_var(NumericVar * var)9175 strip_var(NumericVar *var)
9176 {
9177 	NumericDigit *digits = var->digits;
9178 	int			ndigits = var->ndigits;
9179 
9180 	/* Strip leading zeroes */
9181 	while (ndigits > 0 && *digits == 0)
9182 	{
9183 		digits++;
9184 		var->weight--;
9185 		ndigits--;
9186 	}
9187 
9188 	/* Strip trailing zeroes */
9189 	while (ndigits > 0 && digits[ndigits - 1] == 0)
9190 		ndigits--;
9191 
9192 	/* If it's zero, normalize the sign and weight */
9193 	if (ndigits == 0)
9194 	{
9195 		var->sign = NUMERIC_POS;
9196 		var->weight = 0;
9197 	}
9198 
9199 	var->digits = digits;
9200 	var->ndigits = ndigits;
9201 }
9202 
9203 
9204 /* ----------------------------------------------------------------------
9205  *
9206  * Fast sum accumulator functions
9207  *
9208  * ----------------------------------------------------------------------
9209  */
9210 
9211 /*
9212  * Reset the accumulator's value to zero.  The buffers to hold the digits
9213  * are not free'd.
9214  */
9215 static void
accum_sum_reset(NumericSumAccum * accum)9216 accum_sum_reset(NumericSumAccum *accum)
9217 {
9218 	int			i;
9219 
9220 	accum->dscale = 0;
9221 	for (i = 0; i < accum->ndigits; i++)
9222 	{
9223 		accum->pos_digits[i] = 0;
9224 		accum->neg_digits[i] = 0;
9225 	}
9226 }
9227 
9228 /*
9229  * Accumulate a new value.
9230  */
9231 static void
accum_sum_add(NumericSumAccum * accum,const NumericVar * val)9232 accum_sum_add(NumericSumAccum *accum, const NumericVar *val)
9233 {
9234 	int32	   *accum_digits;
9235 	int			i,
9236 				val_i;
9237 	int			val_ndigits;
9238 	NumericDigit *val_digits;
9239 
9240 	/*
9241 	 * If we have accumulated too many values since the last carry
9242 	 * propagation, do it now, to avoid overflowing.  (We could allow more
9243 	 * than NBASE - 1, if we reserved two extra digits, rather than one, for
9244 	 * carry propagation.  But even with NBASE - 1, this needs to be done so
9245 	 * seldom, that the performance difference is negligible.)
9246 	 */
9247 	if (accum->num_uncarried == NBASE - 1)
9248 		accum_sum_carry(accum);
9249 
9250 	/*
9251 	 * Adjust the weight or scale of the old value, so that it can accommodate
9252 	 * the new value.
9253 	 */
9254 	accum_sum_rescale(accum, val);
9255 
9256 	/* */
9257 	if (val->sign == NUMERIC_POS)
9258 		accum_digits = accum->pos_digits;
9259 	else
9260 		accum_digits = accum->neg_digits;
9261 
9262 	/* copy these values into local vars for speed in loop */
9263 	val_ndigits = val->ndigits;
9264 	val_digits = val->digits;
9265 
9266 	i = accum->weight - val->weight;
9267 	for (val_i = 0; val_i < val_ndigits; val_i++)
9268 	{
9269 		accum_digits[i] += (int32) val_digits[val_i];
9270 		i++;
9271 	}
9272 
9273 	accum->num_uncarried++;
9274 }
9275 
9276 /*
9277  * Propagate carries.
9278  */
9279 static void
accum_sum_carry(NumericSumAccum * accum)9280 accum_sum_carry(NumericSumAccum *accum)
9281 {
9282 	int			i;
9283 	int			ndigits;
9284 	int32	   *dig;
9285 	int32		carry;
9286 	int32		newdig = 0;
9287 
9288 	/*
9289 	 * If no new values have been added since last carry propagation, nothing
9290 	 * to do.
9291 	 */
9292 	if (accum->num_uncarried == 0)
9293 		return;
9294 
9295 	/*
9296 	 * We maintain that the weight of the accumulator is always one larger
9297 	 * than needed to hold the current value, before carrying, to make sure
9298 	 * there is enough space for the possible extra digit when carry is
9299 	 * propagated.  We cannot expand the buffer here, unless we require
9300 	 * callers of accum_sum_final() to switch to the right memory context.
9301 	 */
9302 	Assert(accum->pos_digits[0] == 0 && accum->neg_digits[0] == 0);
9303 
9304 	ndigits = accum->ndigits;
9305 
9306 	/* Propagate carry in the positive sum */
9307 	dig = accum->pos_digits;
9308 	carry = 0;
9309 	for (i = ndigits - 1; i >= 0; i--)
9310 	{
9311 		newdig = dig[i] + carry;
9312 		if (newdig >= NBASE)
9313 		{
9314 			carry = newdig / NBASE;
9315 			newdig -= carry * NBASE;
9316 		}
9317 		else
9318 			carry = 0;
9319 		dig[i] = newdig;
9320 	}
9321 	/* Did we use up the digit reserved for carry propagation? */
9322 	if (newdig > 0)
9323 		accum->have_carry_space = false;
9324 
9325 	/* And the same for the negative sum */
9326 	dig = accum->neg_digits;
9327 	carry = 0;
9328 	for (i = ndigits - 1; i >= 0; i--)
9329 	{
9330 		newdig = dig[i] + carry;
9331 		if (newdig >= NBASE)
9332 		{
9333 			carry = newdig / NBASE;
9334 			newdig -= carry * NBASE;
9335 		}
9336 		else
9337 			carry = 0;
9338 		dig[i] = newdig;
9339 	}
9340 	if (newdig > 0)
9341 		accum->have_carry_space = false;
9342 
9343 	accum->num_uncarried = 0;
9344 }
9345 
9346 /*
9347  * Re-scale accumulator to accommodate new value.
9348  *
9349  * If the new value has more digits than the current digit buffers in the
9350  * accumulator, enlarge the buffers.
9351  */
9352 static void
accum_sum_rescale(NumericSumAccum * accum,const NumericVar * val)9353 accum_sum_rescale(NumericSumAccum *accum, const NumericVar *val)
9354 {
9355 	int			old_weight = accum->weight;
9356 	int			old_ndigits = accum->ndigits;
9357 	int			accum_ndigits;
9358 	int			accum_weight;
9359 	int			accum_rscale;
9360 	int			val_rscale;
9361 
9362 	accum_weight = old_weight;
9363 	accum_ndigits = old_ndigits;
9364 
9365 	/*
9366 	 * Does the new value have a larger weight? If so, enlarge the buffers,
9367 	 * and shift the existing value to the new weight, by adding leading
9368 	 * zeros.
9369 	 *
9370 	 * We enforce that the accumulator always has a weight one larger than
9371 	 * needed for the inputs, so that we have space for an extra digit at the
9372 	 * final carry-propagation phase, if necessary.
9373 	 */
9374 	if (val->weight >= accum_weight)
9375 	{
9376 		accum_weight = val->weight + 1;
9377 		accum_ndigits = accum_ndigits + (accum_weight - old_weight);
9378 	}
9379 
9380 	/*
9381 	 * Even though the new value is small, we might've used up the space
9382 	 * reserved for the carry digit in the last call to accum_sum_carry().  If
9383 	 * so, enlarge to make room for another one.
9384 	 */
9385 	else if (!accum->have_carry_space)
9386 	{
9387 		accum_weight++;
9388 		accum_ndigits++;
9389 	}
9390 
9391 	/* Is the new value wider on the right side? */
9392 	accum_rscale = accum_ndigits - accum_weight - 1;
9393 	val_rscale = val->ndigits - val->weight - 1;
9394 	if (val_rscale > accum_rscale)
9395 		accum_ndigits = accum_ndigits + (val_rscale - accum_rscale);
9396 
9397 	if (accum_ndigits != old_ndigits ||
9398 		accum_weight != old_weight)
9399 	{
9400 		int32	   *new_pos_digits;
9401 		int32	   *new_neg_digits;
9402 		int			weightdiff;
9403 
9404 		weightdiff = accum_weight - old_weight;
9405 
9406 		new_pos_digits = palloc0(accum_ndigits * sizeof(int32));
9407 		new_neg_digits = palloc0(accum_ndigits * sizeof(int32));
9408 
9409 		if (accum->pos_digits)
9410 		{
9411 			memcpy(&new_pos_digits[weightdiff], accum->pos_digits,
9412 				   old_ndigits * sizeof(int32));
9413 			pfree(accum->pos_digits);
9414 
9415 			memcpy(&new_neg_digits[weightdiff], accum->neg_digits,
9416 				   old_ndigits * sizeof(int32));
9417 			pfree(accum->neg_digits);
9418 		}
9419 
9420 		accum->pos_digits = new_pos_digits;
9421 		accum->neg_digits = new_neg_digits;
9422 
9423 		accum->weight = accum_weight;
9424 		accum->ndigits = accum_ndigits;
9425 
9426 		Assert(accum->pos_digits[0] == 0 && accum->neg_digits[0] == 0);
9427 		accum->have_carry_space = true;
9428 	}
9429 
9430 	if (val->dscale > accum->dscale)
9431 		accum->dscale = val->dscale;
9432 }
9433 
9434 /*
9435  * Return the current value of the accumulator.  This perform final carry
9436  * propagation, and adds together the positive and negative sums.
9437  *
9438  * Unlike all the other routines, the caller is not required to switch to
9439  * the memory context that holds the accumulator.
9440  */
9441 static void
accum_sum_final(NumericSumAccum * accum,NumericVar * result)9442 accum_sum_final(NumericSumAccum *accum, NumericVar *result)
9443 {
9444 	int			i;
9445 	NumericVar	pos_var;
9446 	NumericVar	neg_var;
9447 
9448 	if (accum->ndigits == 0)
9449 	{
9450 		set_var_from_var(&const_zero, result);
9451 		return;
9452 	}
9453 
9454 	/* Perform final carry */
9455 	accum_sum_carry(accum);
9456 
9457 	/* Create NumericVars representing the positive and negative sums */
9458 	init_var(&pos_var);
9459 	init_var(&neg_var);
9460 
9461 	pos_var.ndigits = neg_var.ndigits = accum->ndigits;
9462 	pos_var.weight = neg_var.weight = accum->weight;
9463 	pos_var.dscale = neg_var.dscale = accum->dscale;
9464 	pos_var.sign = NUMERIC_POS;
9465 	neg_var.sign = NUMERIC_NEG;
9466 
9467 	pos_var.buf = pos_var.digits = digitbuf_alloc(accum->ndigits);
9468 	neg_var.buf = neg_var.digits = digitbuf_alloc(accum->ndigits);
9469 
9470 	for (i = 0; i < accum->ndigits; i++)
9471 	{
9472 		Assert(accum->pos_digits[i] < NBASE);
9473 		pos_var.digits[i] = (int16) accum->pos_digits[i];
9474 
9475 		Assert(accum->neg_digits[i] < NBASE);
9476 		neg_var.digits[i] = (int16) accum->neg_digits[i];
9477 	}
9478 
9479 	/* And add them together */
9480 	add_var(&pos_var, &neg_var, result);
9481 
9482 	/* Remove leading/trailing zeroes */
9483 	strip_var(result);
9484 }
9485 
9486 /*
9487  * Copy an accumulator's state.
9488  *
9489  * 'dst' is assumed to be uninitialized beforehand.  No attempt is made at
9490  * freeing old values.
9491  */
9492 static void
accum_sum_copy(NumericSumAccum * dst,NumericSumAccum * src)9493 accum_sum_copy(NumericSumAccum *dst, NumericSumAccum *src)
9494 {
9495 	dst->pos_digits = palloc(src->ndigits * sizeof(int32));
9496 	dst->neg_digits = palloc(src->ndigits * sizeof(int32));
9497 
9498 	memcpy(dst->pos_digits, src->pos_digits, src->ndigits * sizeof(int32));
9499 	memcpy(dst->neg_digits, src->neg_digits, src->ndigits * sizeof(int32));
9500 	dst->num_uncarried = src->num_uncarried;
9501 	dst->ndigits = src->ndigits;
9502 	dst->weight = src->weight;
9503 	dst->dscale = src->dscale;
9504 }
9505 
9506 /*
9507  * Add the current value of 'accum2' into 'accum'.
9508  */
9509 static void
accum_sum_combine(NumericSumAccum * accum,NumericSumAccum * accum2)9510 accum_sum_combine(NumericSumAccum *accum, NumericSumAccum *accum2)
9511 {
9512 	NumericVar	tmp_var;
9513 
9514 	init_var(&tmp_var);
9515 
9516 	accum_sum_final(accum2, &tmp_var);
9517 	accum_sum_add(accum, &tmp_var);
9518 
9519 	free_var(&tmp_var);
9520 }
9521