1 /*-------------------------------------------------------------------------
2  *
3  * numeric.c
4  *	  An exact numeric data type for the Postgres database system
5  *
6  * Original coding 1998, Jan Wieck.  Heavily revised 2003, Tom Lane.
7  *
8  * Many of the algorithmic ideas are borrowed from David M. Smith's "FM"
9  * multiple-precision math library, most recently published as Algorithm
10  * 786: Multiple-Precision Complex Arithmetic and Functions, ACM
11  * Transactions on Mathematical Software, Vol. 24, No. 4, December 1998,
12  * pages 359-367.
13  *
14  * Copyright (c) 1998-2016, PostgreSQL Global Development Group
15  *
16  * IDENTIFICATION
17  *	  src/backend/utils/adt/numeric.c
18  *
19  *-------------------------------------------------------------------------
20  */
21 
22 #include "postgres.h"
23 
24 #include <ctype.h>
25 #include <float.h>
26 #include <limits.h>
27 #include <math.h>
28 
29 #include "access/hash.h"
30 #include "catalog/pg_type.h"
31 #include "funcapi.h"
32 #include "lib/hyperloglog.h"
33 #include "libpq/pqformat.h"
34 #include "miscadmin.h"
35 #include "nodes/nodeFuncs.h"
36 #include "utils/array.h"
37 #include "utils/builtins.h"
38 #include "utils/guc.h"
39 #include "utils/int8.h"
40 #include "utils/numeric.h"
41 #include "utils/sortsupport.h"
42 
43 /* ----------
44  * Uncomment the following to enable compilation of dump_numeric()
45  * and dump_var() and to get a dump of any result produced by make_result().
46  * ----------
47 #define NUMERIC_DEBUG
48  */
49 
50 
51 /* ----------
52  * Local data types
53  *
54  * Numeric values are represented in a base-NBASE floating point format.
55  * Each "digit" ranges from 0 to NBASE-1.  The type NumericDigit is signed
56  * and wide enough to store a digit.  We assume that NBASE*NBASE can fit in
57  * an int.  Although the purely calculational routines could handle any even
58  * NBASE that's less than sqrt(INT_MAX), in practice we are only interested
59  * in NBASE a power of ten, so that I/O conversions and decimal rounding
60  * are easy.  Also, it's actually more efficient if NBASE is rather less than
61  * sqrt(INT_MAX), so that there is "headroom" for mul_var and div_var_fast to
62  * postpone processing carries.
63  *
64  * Values of NBASE other than 10000 are considered of historical interest only
65  * and are no longer supported in any sense; no mechanism exists for the client
66  * to discover the base, so every client supporting binary mode expects the
67  * base-10000 format.  If you plan to change this, also note the numeric
68  * abbreviation code, which assumes NBASE=10000.
69  * ----------
70  */
71 
/*
 * Unused variant (compiled out): NBASE = 10, one decimal digit per
 * base-NBASE digit.  HALF_NBASE is NBASE / 2.
 */
#if 0
#define NBASE		10
#define HALF_NBASE	5
#define DEC_DIGITS	1			/* decimal digits per NBASE digit */
#define MUL_GUARD_DIGITS	4	/* these are measured in NBASE digits */
#define DIV_GUARD_DIGITS	8

typedef signed char NumericDigit;
#endif

/*
 * Unused variant (compiled out): NBASE = 100, two decimal digits per
 * base-NBASE digit.
 */
#if 0
#define NBASE		100
#define HALF_NBASE	50
#define DEC_DIGITS	2			/* decimal digits per NBASE digit */
#define MUL_GUARD_DIGITS	3	/* these are measured in NBASE digits */
#define DIV_GUARD_DIGITS	6

typedef signed char NumericDigit;
#endif

/*
 * Active configuration: NBASE = 10000, four decimal digits per base-NBASE
 * digit.  This is the only variant that is actually compiled; each digit
 * (0 .. NBASE-1) is held in a signed 16-bit NumericDigit.
 */
#if 1
#define NBASE		10000
#define HALF_NBASE	5000
#define DEC_DIGITS	4			/* decimal digits per NBASE digit */
#define MUL_GUARD_DIGITS	2	/* these are measured in NBASE digits */
#define DIV_GUARD_DIGITS	4

typedef int16 NumericDigit;
#endif
101 
102 /*
103  * The Numeric type as stored on disk.
104  *
105  * If the high bits of the first word of a NumericChoice (n_header, or
106  * n_short.n_header, or n_long.n_sign_dscale) are NUMERIC_SHORT, then the
107  * numeric follows the NumericShort format; if they are NUMERIC_POS or
108  * NUMERIC_NEG, it follows the NumericLong format.  If they are NUMERIC_NAN,
109  * it is a NaN.  We currently always store a NaN using just two bytes (i.e.
110  * only n_header), but previous releases used only the NumericLong format,
111  * so we might find 4-byte NaNs on disk if a database has been migrated using
112  * pg_upgrade.  In either case, when the high bits indicate a NaN, the
113  * remaining bits are never examined.  Currently, we always initialize these
114  * to zero, but it might be possible to use them for some other purpose in
115  * the future.
116  *
 * In the NumericShort format, the remaining 14 bits of the header word
 * (n_short.n_header) are allocated as follows: 1 for sign (positive or
 * negative), 6 for display scale, and 7 for weight.  In practice, most
 * commonly-encountered values can be represented this way.
121  *
122  * In the NumericLong format, the remaining 14 bits of the header word
123  * (n_long.n_sign_dscale) represent the display scale; and the weight is
124  * stored separately in n_weight.
125  *
126  * NOTE: by convention, values in the packed form have been stripped of
127  * all leading and trailing zero digits (where a "digit" is of base NBASE).
128  * In particular, if the value is zero, there will be no digits at all!
129  * The weight is arbitrary in that case, but we normally set it to zero.
130  */
131 
/*
 * Short on-disk layout: one 16-bit header word that packs the format flag,
 * sign, display scale and weight, followed directly by the digit array.
 */
struct NumericShort
{
	uint16		n_header;		/* Sign + display scale + weight */
	NumericDigit n_data[FLEXIBLE_ARRAY_MEMBER]; /* Digits */
};
137 
/*
 * Long on-disk layout: a 16-bit word holding sign and display scale, a
 * separate 16-bit weight, and then the digit array.
 */
struct NumericLong
{
	uint16		n_sign_dscale;	/* Sign + display scale */
	int16		n_weight;		/* Weight of 1st digit	*/
	NumericDigit n_data[FLEXIBLE_ARRAY_MEMBER]; /* Digits */
};
144 
/*
 * Overlay of the two on-disk layouts.  All three members begin with a
 * uint16, so n_header can be inspected to find out which layout applies
 * before touching n_long or n_short.
 */
union NumericChoice
{
	uint16		n_header;		/* Header word */
	struct NumericLong n_long;	/* Long form (4-byte header) */
	struct NumericShort n_short;	/* Short form (2-byte header) */
};
151 
/*
 * Complete stored value: the varlena length word followed by either the
 * short or the long numeric layout.
 */
struct NumericData
{
	int32		vl_len_;		/* varlena header (do not touch directly!) */
	union NumericChoice choice; /* choice of format */
};
157 
158 
159 /*
160  * Interpretation of high bits.
161  */
162 
/* The two most significant bits of the 16-bit header word select the form. */
#define NUMERIC_SIGN_MASK	0xC000
#define NUMERIC_POS			0x0000
#define NUMERIC_NEG			0x4000
#define NUMERIC_SHORT		0x8000
#define NUMERIC_NAN			0xC000

/* Extract the two flag bits, and test them for the NaN / short-format codes */
#define NUMERIC_FLAGBITS(n) ((n)->choice.n_header & NUMERIC_SIGN_MASK)
#define NUMERIC_IS_NAN(n)		(NUMERIC_FLAGBITS(n) == NUMERIC_NAN)
#define NUMERIC_IS_SHORT(n)		(NUMERIC_FLAGBITS(n) == NUMERIC_SHORT)

/*
 * Total header sizes in bytes, including the varlena length word: the long
 * form carries a uint16 plus an int16, the short form only a uint16.
 */
#define NUMERIC_HDRSZ	(VARHDRSZ + sizeof(uint16) + sizeof(int16))
#define NUMERIC_HDRSZ_SHORT (VARHDRSZ + sizeof(uint16))

/*
 * If the flag bits are NUMERIC_SHORT or NUMERIC_NAN, we want the short header;
 * otherwise, we want the long one.  Instead of testing against each value, we
 * can just look at the high bit, for a slight efficiency gain.
 */
#define NUMERIC_HEADER_IS_SHORT(n)	(((n)->choice.n_header & 0x8000) != 0)
/* Header size of a particular value: adds the int16 weight only for long form */
#define NUMERIC_HEADER_SIZE(n) \
	(VARHDRSZ + sizeof(uint16) + \
	 (NUMERIC_HEADER_IS_SHORT(n) ? 0 : sizeof(int16)))
185 
186 /*
187  * Short format definitions.
188  */
189 
/*
 * Bit layout of the short header below the two flag bits:
 *	 bit 13		sign (set means negative, see NUMERIC_SIGN)
 *	 bits 7-12	display scale (6 bits, so at most 63)
 *	 bit 6		sign of the weight
 *	 bits 0-5	low bits of the weight (range -64 .. 63 once sign-extended)
 */
#define NUMERIC_SHORT_SIGN_MASK			0x2000
#define NUMERIC_SHORT_DSCALE_MASK		0x1F80
#define NUMERIC_SHORT_DSCALE_SHIFT		7
#define NUMERIC_SHORT_DSCALE_MAX		\
	(NUMERIC_SHORT_DSCALE_MASK >> NUMERIC_SHORT_DSCALE_SHIFT)
#define NUMERIC_SHORT_WEIGHT_SIGN_MASK	0x0040
#define NUMERIC_SHORT_WEIGHT_MASK		0x003F
#define NUMERIC_SHORT_WEIGHT_MAX		NUMERIC_SHORT_WEIGHT_MASK
#define NUMERIC_SHORT_WEIGHT_MIN		(-(NUMERIC_SHORT_WEIGHT_MASK+1))
199 
200 /*
201  * Extract sign, display scale, weight.
202  */
203 
/* In the long form the display scale occupies the low 14 bits of the header */
#define NUMERIC_DSCALE_MASK			0x3FFF
#define NUMERIC_DSCALE_MAX			NUMERIC_DSCALE_MASK

/*
 * Sign of a value.  For the short form the single sign bit is translated to
 * NUMERIC_NEG / NUMERIC_POS; otherwise the raw flag bits are returned, which
 * are NUMERIC_POS, NUMERIC_NEG or NUMERIC_NAN.
 */
#define NUMERIC_SIGN(n) \
	(NUMERIC_IS_SHORT(n) ? \
		(((n)->choice.n_short.n_header & NUMERIC_SHORT_SIGN_MASK) ? \
		NUMERIC_NEG : NUMERIC_POS) : NUMERIC_FLAGBITS(n))
/*
 * Display scale.  Note that the short-header branch is chosen by
 * NUMERIC_HEADER_IS_SHORT, which is also true for a NaN.
 */
#define NUMERIC_DSCALE(n)	(NUMERIC_HEADER_IS_SHORT((n)) ? \
	((n)->choice.n_short.n_header & NUMERIC_SHORT_DSCALE_MASK) \
		>> NUMERIC_SHORT_DSCALE_SHIFT \
	: ((n)->choice.n_long.n_sign_dscale & NUMERIC_DSCALE_MASK))
/*
 * Weight of the first digit.  In the short form the weight is sign-extended
 * by OR-ing in ~NUMERIC_SHORT_WEIGHT_MASK when the weight sign bit is set.
 */
#define NUMERIC_WEIGHT(n)	(NUMERIC_HEADER_IS_SHORT((n)) ? \
	(((n)->choice.n_short.n_header & NUMERIC_SHORT_WEIGHT_SIGN_MASK ? \
		~NUMERIC_SHORT_WEIGHT_MASK : 0) \
	 | ((n)->choice.n_short.n_header & NUMERIC_SHORT_WEIGHT_MASK)) \
	: ((n)->choice.n_long.n_weight))
220 
221 /* ----------
222  * NumericVar is the format we use for arithmetic.  The digit-array part
223  * is the same as the NumericData storage format, but the header is more
224  * complex.
225  *
226  * The value represented by a NumericVar is determined by the sign, weight,
227  * ndigits, and digits[] array.
228  *
229  * Note: the first digit of a NumericVar's value is assumed to be multiplied
230  * by NBASE ** weight.  Another way to say it is that there are weight+1
231  * digits before the decimal point.  It is possible to have weight < 0.
232  *
233  * buf points at the physical start of the palloc'd digit buffer for the
234  * NumericVar.  digits points at the first digit in actual use (the one
235  * with the specified weight).  We normally leave an unused digit or two
236  * (preset to zeroes) between buf and digits, so that there is room to store
237  * a carry out of the top digit without reallocating space.  We just need to
238  * decrement digits (and increment weight) to make room for the carry digit.
239  * (There is no such extra space in a numeric value stored in the database,
240  * only in a NumericVar in memory.)
241  *
242  * If buf is NULL then the digit buffer isn't actually palloc'd and should
243  * not be freed --- see the constants below for an example.
244  *
245  * dscale, or display scale, is the nominal precision expressed as number
246  * of digits after the decimal point (it must always be >= 0 at present).
247  * dscale may be more than the number of physically stored fractional digits,
248  * implying that we have suppressed storage of significant trailing zeroes.
249  * It should never be less than the number of stored digits, since that would
250  * imply hiding digits that are present.  NOTE that dscale is always expressed
251  * in *decimal* digits, and so it may correspond to a fractional number of
252  * base-NBASE digits --- divide by DEC_DIGITS to convert to NBASE digits.
253  *
254  * rscale, or result scale, is the target precision for a computation.
255  * Like dscale it is expressed as number of *decimal* digits after the decimal
256  * point, and is always >= 0 at present.
257  * Note that rscale is not stored in variables --- it's figured on-the-fly
258  * from the dscales of the inputs.
259  *
260  * While we consistently use "weight" to refer to the base-NBASE weight of
261  * a numeric value, it is convenient in some scale-related calculations to
262  * make use of the base-10 weight (ie, the approximate log10 of the value).
263  * To avoid confusion, such a decimal-units weight is called a "dweight".
264  *
265  * NB: All the variable-level functions are written in a style that makes it
266  * possible to give one and the same variable as argument and destination.
267  * This is feasible because the digit buffer is separate from the variable.
268  * ----------
269  */
/* In-memory working representation of a numeric value used for arithmetic. */
typedef struct NumericVar
{
	int			ndigits;		/* # of digits in digits[] - can be 0! */
	int			weight;			/* weight of first digit */
	int			sign;			/* NUMERIC_POS, NUMERIC_NEG, or NUMERIC_NAN */
	int			dscale;			/* display scale */
	NumericDigit *buf;			/* start of palloc'd space for digits[] */
	NumericDigit *digits;		/* base-NBASE digits */
} NumericVar;
279 
280 
281 /* ----------
282  * Data for generate_series
283  * ----------
284  */
/*
 * Per-call state holding three NumericVars for generate_series.
 * NOTE(review): the field names suggest the current value, the end bound and
 * the increment; the consuming code is not visible here -- confirm there.
 */
typedef struct
{
	NumericVar	current;
	NumericVar	stop;
	NumericVar	step;
} generate_series_numeric_fctx;
291 
292 
293 /* ----------
294  * Sort support.
295  * ----------
296  */
/*
 * Working state for numeric sort support, including the bookkeeping used to
 * estimate the cardinality of abbreviated keys.
 */
typedef struct
{
	void	   *buf;			/* buffer for short varlenas */
	int64		input_count;	/* number of non-null values seen */
	bool		estimating;		/* true if estimating cardinality */

	hyperLogLogState abbr_card; /* cardinality estimator */
} NumericSortSupport;
305 
306 /*
307  * We define our own macros for packing and unpacking abbreviated-key
308  * representations for numeric values in order to avoid depending on
309  * USE_FLOAT8_BYVAL.  The type of abbreviation we use is based only on
310  * the size of a datum, not the argument-passing convention for float8.
311  */
/* Width of an abbreviated key in bits: the full width of a Datum */
#define NUMERIC_ABBREV_BITS (SIZEOF_DATUM * BITS_PER_BYTE)
#if SIZEOF_DATUM == 8
/* 8-byte Datum: abbreviated keys are int64; NaN maps to the minimum int64 */
#define NumericAbbrevGetDatum(X) ((Datum) SET_8_BYTES(X))
#define DatumGetNumericAbbrev(X) ((int64) GET_8_BYTES(X))
#define NUMERIC_ABBREV_NAN		 NumericAbbrevGetDatum(PG_INT64_MIN)
#else
/* Otherwise: abbreviated keys are int32; NaN maps to the minimum int32 */
#define NumericAbbrevGetDatum(X) ((Datum) SET_4_BYTES(X))
#define DatumGetNumericAbbrev(X) ((int32) GET_4_BYTES(X))
#define NUMERIC_ABBREV_NAN		 NumericAbbrevGetDatum(PG_INT32_MIN)
#endif
322 
323 
324 /* ----------
325  * Some preinitialized constants
326  * ----------
327  */
/*
 * Each constant below is a NumericVar initialized in field order
 * {ndigits, weight, sign, dscale, buf, digits}.  buf is NULL in every one of
 * them because the digit arrays are static storage, not palloc'd.
 */

/* 0: no digits in use (ndigits = 0), although a digit array is provided */
static NumericDigit const_zero_data[1] = {0};
static NumericVar const_zero =
{0, 0, NUMERIC_POS, 0, NULL, const_zero_data};

/* 1 */
static NumericDigit const_one_data[1] = {1};
static NumericVar const_one =
{1, 0, NUMERIC_POS, 0, NULL, const_one_data};

/* 2 */
static NumericDigit const_two_data[1] = {2};
static NumericVar const_two =
{1, 0, NUMERIC_POS, 0, NULL, const_two_data};

/* 0.5: a single digit of weight -1 equal to NBASE/2, display scale 1 */
#if DEC_DIGITS == 4
static NumericDigit const_zero_point_five_data[1] = {5000};
#elif DEC_DIGITS == 2
static NumericDigit const_zero_point_five_data[1] = {50};
#elif DEC_DIGITS == 1
static NumericDigit const_zero_point_five_data[1] = {5};
#endif
static NumericVar const_zero_point_five =
{1, -1, NUMERIC_POS, 1, NULL, const_zero_point_five_data};

/* 0.9: a single digit of weight -1, display scale 1 */
#if DEC_DIGITS == 4
static NumericDigit const_zero_point_nine_data[1] = {9000};
#elif DEC_DIGITS == 2
static NumericDigit const_zero_point_nine_data[1] = {90};
#elif DEC_DIGITS == 1
static NumericDigit const_zero_point_nine_data[1] = {9};
#endif
static NumericVar const_zero_point_nine =
{1, -1, NUMERIC_POS, 1, NULL, const_zero_point_nine_data};

/* 1.1: integer digit 1 followed by one fractional digit, display scale 1 */
#if DEC_DIGITS == 4
static NumericDigit const_one_point_one_data[2] = {1, 1000};
#elif DEC_DIGITS == 2
static NumericDigit const_one_point_one_data[2] = {1, 10};
#elif DEC_DIGITS == 1
static NumericDigit const_one_point_one_data[2] = {1, 1};
#endif
static NumericVar const_one_point_one =
{2, 0, NUMERIC_POS, 1, NULL, const_one_point_one_data};

/* NaN: only the sign field is meaningful; there is no digit array at all */
static NumericVar const_nan =
{0, 0, NUMERIC_NAN, 0, NULL, NULL};

/*
 * Powers of ten indexed 0..3, defined only when a base-NBASE digit holds four
 * decimal digits.  NOTE(review): presumably used when rounding inside a
 * single base-NBASE digit; the user of this table is not visible here.
 */
#if DEC_DIGITS == 4
static const int round_powers[4] = {0, 1000, 100, 10};
#endif
376 
377 
378 /* ----------
379  * Local functions
380  * ----------
381  */
382 
/*
 * Debug dump helpers: real functions when NUMERIC_DEBUG is defined,
 * otherwise macros that expand to nothing.
 */
#ifdef NUMERIC_DEBUG
static void dump_numeric(const char *str, Numeric num);
static void dump_var(const char *str, NumericVar *var);
#else
#define dump_numeric(s,n)
#define dump_var(s,v)
#endif

/* palloc an array of ndigits NumericDigits */
#define digitbuf_alloc(ndigits)  \
	((NumericDigit *) palloc((ndigits) * sizeof(NumericDigit)))
/* pfree a digit buffer; a NULL pointer is silently ignored */
#define digitbuf_free(buf)	\
	do { \
		 if ((buf) != NULL) \
			 pfree(buf); \
	} while (0)

/* Zero every field of a NumericVar (so buf and digits become NULL) */
#define init_var(v)		MemSetAligned(v, 0, sizeof(NumericVar))

/* Pointer to the first stored digit, allowing for either header form */
#define NUMERIC_DIGITS(num) (NUMERIC_HEADER_IS_SHORT(num) ? \
	(num)->choice.n_short.n_data : (num)->choice.n_long.n_data)
/* Number of stored digits: bytes following the header / size of one digit */
#define NUMERIC_NDIGITS(num) \
	((VARSIZE(num) - NUMERIC_HEADER_SIZE(num)) / sizeof(NumericDigit))
/* True if both the scale and the weight fit in the short header's bit fields */
#define NUMERIC_CAN_BE_SHORT(scale,weight) \
	((scale) <= NUMERIC_SHORT_DSCALE_MAX && \
	(weight) <= NUMERIC_SHORT_WEIGHT_MAX && \
	(weight) >= NUMERIC_SHORT_WEIGHT_MIN)
409 
/*
 * Forward declarations of the file-local (static) helper routines.  Their
 * definitions are not part of this section of the file; refer to each
 * definition for its exact contract.
 */
static void alloc_var(NumericVar *var, int ndigits);
static void free_var(NumericVar *var);
static void zero_var(NumericVar *var);

static const char *set_var_from_str(const char *str, const char *cp,
				 NumericVar *dest);
static void set_var_from_num(Numeric value, NumericVar *dest);
static void init_var_from_num(Numeric num, NumericVar *dest);
static void set_var_from_var(NumericVar *value, NumericVar *dest);
static char *get_str_from_var(NumericVar *var);
static char *get_str_from_var_sci(NumericVar *var, int rscale);

static Numeric make_result(NumericVar *var);

static void apply_typmod(NumericVar *var, int32 typmod);

static int32 numericvar_to_int32(NumericVar *var);
static bool numericvar_to_int64(NumericVar *var, int64 *result);
static void int64_to_numericvar(int64 val, NumericVar *var);
/* The 128-bit conversions exist only when the platform provides int128 */
#ifdef HAVE_INT128
static bool numericvar_to_int128(NumericVar *var, int128 *result);
static void int128_to_numericvar(int128 val, NumericVar *var);
#endif
static double numeric_to_double_no_overflow(Numeric num);
static double numericvar_to_double_no_overflow(NumericVar *var);

static Datum numeric_abbrev_convert(Datum original_datum, SortSupport ssup);
static bool numeric_abbrev_abort(int memtupcount, SortSupport ssup);
static int	numeric_fast_cmp(Datum x, Datum y, SortSupport ssup);
static int	numeric_cmp_abbrev(Datum x, Datum y, SortSupport ssup);

static Datum numeric_abbrev_convert_var(NumericVar *var, NumericSortSupport *nss);

static int	cmp_numerics(Numeric num1, Numeric num2);
static int	cmp_var(NumericVar *var1, NumericVar *var2);
static int cmp_var_common(const NumericDigit *var1digits, int var1ndigits,
			   int var1weight, int var1sign,
			   const NumericDigit *var2digits, int var2ndigits,
			   int var2weight, int var2sign);
static void add_var(NumericVar *var1, NumericVar *var2, NumericVar *result);
static void sub_var(NumericVar *var1, NumericVar *var2, NumericVar *result);
static void mul_var(NumericVar *var1, NumericVar *var2, NumericVar *result,
		int rscale);
static void div_var(NumericVar *var1, NumericVar *var2, NumericVar *result,
		int rscale, bool round);
static void div_var_fast(NumericVar *var1, NumericVar *var2, NumericVar *result,
			 int rscale, bool round);
static int	select_div_scale(NumericVar *var1, NumericVar *var2);
static void mod_var(NumericVar *var1, NumericVar *var2, NumericVar *result);
static void ceil_var(NumericVar *var, NumericVar *result);
static void floor_var(NumericVar *var, NumericVar *result);

static void sqrt_var(NumericVar *arg, NumericVar *result, int rscale);
static void exp_var(NumericVar *arg, NumericVar *result, int rscale);
static int	estimate_ln_dweight(NumericVar *var);
static void ln_var(NumericVar *arg, NumericVar *result, int rscale);
static void log_var(NumericVar *base, NumericVar *num, NumericVar *result);
static void power_var(NumericVar *base, NumericVar *exp, NumericVar *result);
static void power_var_int(NumericVar *base, int exp, NumericVar *result,
			  int rscale);
static void power_ten_int(int exp, NumericVar *result);

static int	cmp_abs(NumericVar *var1, NumericVar *var2);
static int cmp_abs_common(const NumericDigit *var1digits, int var1ndigits,
			   int var1weight,
			   const NumericDigit *var2digits, int var2ndigits,
			   int var2weight);
static void add_abs(NumericVar *var1, NumericVar *var2, NumericVar *result);
static void sub_abs(NumericVar *var1, NumericVar *var2, NumericVar *result);
static void round_var(NumericVar *var, int rscale);
static void trunc_var(NumericVar *var, int rscale);
static void strip_var(NumericVar *var);
static void compute_bucket(Numeric operand, Numeric bound1, Numeric bound2,
			   NumericVar *count_var, NumericVar *result_var);
484 
485 
486 /* ----------------------------------------------------------------------
487  *
488  * Input-, output- and rounding-functions
489  *
490  * ----------------------------------------------------------------------
491  */
492 
493 
494 /*
495  * numeric_in() -
496  *
497  *	Input function for numeric data type
498  */
499 Datum
numeric_in(PG_FUNCTION_ARGS)500 numeric_in(PG_FUNCTION_ARGS)
501 {
502 	char	   *str = PG_GETARG_CSTRING(0);
503 
504 #ifdef NOT_USED
505 	Oid			typelem = PG_GETARG_OID(1);
506 #endif
507 	int32		typmod = PG_GETARG_INT32(2);
508 	Numeric		res;
509 	const char *cp;
510 
511 	/* Skip leading spaces */
512 	cp = str;
513 	while (*cp)
514 	{
515 		if (!isspace((unsigned char) *cp))
516 			break;
517 		cp++;
518 	}
519 
520 	/*
521 	 * Check for NaN
522 	 */
523 	if (pg_strncasecmp(cp, "NaN", 3) == 0)
524 	{
525 		res = make_result(&const_nan);
526 
527 		/* Should be nothing left but spaces */
528 		cp += 3;
529 		while (*cp)
530 		{
531 			if (!isspace((unsigned char) *cp))
532 				ereport(ERROR,
533 						(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
534 					  errmsg("invalid input syntax for type numeric: \"%s\"",
535 							 str)));
536 			cp++;
537 		}
538 	}
539 	else
540 	{
541 		/*
542 		 * Use set_var_from_str() to parse a normal numeric value
543 		 */
544 		NumericVar	value;
545 
546 		init_var(&value);
547 
548 		cp = set_var_from_str(str, cp, &value);
549 
550 		/*
551 		 * We duplicate a few lines of code here because we would like to
552 		 * throw any trailing-junk syntax error before any semantic error
553 		 * resulting from apply_typmod.  We can't easily fold the two cases
554 		 * together because we mustn't apply apply_typmod to a NaN.
555 		 */
556 		while (*cp)
557 		{
558 			if (!isspace((unsigned char) *cp))
559 				ereport(ERROR,
560 						(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
561 					  errmsg("invalid input syntax for type numeric: \"%s\"",
562 							 str)));
563 			cp++;
564 		}
565 
566 		apply_typmod(&value, typmod);
567 
568 		res = make_result(&value);
569 		free_var(&value);
570 	}
571 
572 	PG_RETURN_NUMERIC(res);
573 }
574 
575 
576 /*
577  * numeric_out() -
578  *
579  *	Output function for numeric data type
580  */
581 Datum
numeric_out(PG_FUNCTION_ARGS)582 numeric_out(PG_FUNCTION_ARGS)
583 {
584 	Numeric		num = PG_GETARG_NUMERIC(0);
585 	NumericVar	x;
586 	char	   *str;
587 
588 	/*
589 	 * Handle NaN
590 	 */
591 	if (NUMERIC_IS_NAN(num))
592 		PG_RETURN_CSTRING(pstrdup("NaN"));
593 
594 	/*
595 	 * Get the number in the variable format.
596 	 */
597 	init_var_from_num(num, &x);
598 
599 	str = get_str_from_var(&x);
600 
601 	PG_RETURN_CSTRING(str);
602 }
603 
604 /*
605  * numeric_is_nan() -
606  *
607  *	Is Numeric value a NaN?
608  */
609 bool
numeric_is_nan(Numeric num)610 numeric_is_nan(Numeric num)
611 {
612 	return NUMERIC_IS_NAN(num);
613 }
614 
615 /*
616  * numeric_maximum_size() -
617  *
618  *	Maximum size of a numeric with given typmod, or -1 if unlimited/unknown.
619  */
620 int32
numeric_maximum_size(int32 typmod)621 numeric_maximum_size(int32 typmod)
622 {
623 	int			precision;
624 	int			numeric_digits;
625 
626 	if (typmod < (int32) (VARHDRSZ))
627 		return -1;
628 
629 	/* precision (ie, max # of digits) is in upper bits of typmod */
630 	precision = ((typmod - VARHDRSZ) >> 16) & 0xffff;
631 
632 	/*
633 	 * This formula computes the maximum number of NumericDigits we could need
634 	 * in order to store the specified number of decimal digits. Because the
635 	 * weight is stored as a number of NumericDigits rather than a number of
636 	 * decimal digits, it's possible that the first NumericDigit will contain
637 	 * only a single decimal digit.  Thus, the first two decimal digits can
638 	 * require two NumericDigits to store, but it isn't until we reach
639 	 * DEC_DIGITS + 2 decimal digits that we potentially need a third
640 	 * NumericDigit.
641 	 */
642 	numeric_digits = (precision + 2 * (DEC_DIGITS - 1)) / DEC_DIGITS;
643 
644 	/*
645 	 * In most cases, the size of a numeric will be smaller than the value
646 	 * computed below, because the varlena header will typically get toasted
647 	 * down to a single byte before being stored on disk, and it may also be
648 	 * possible to use a short numeric header.  But our job here is to compute
649 	 * the worst case.
650 	 */
651 	return NUMERIC_HDRSZ + (numeric_digits * sizeof(NumericDigit));
652 }
653 
654 /*
655  * numeric_out_sci() -
656  *
657  *	Output function for numeric data type in scientific notation.
658  */
659 char *
numeric_out_sci(Numeric num,int scale)660 numeric_out_sci(Numeric num, int scale)
661 {
662 	NumericVar	x;
663 	char	   *str;
664 
665 	/*
666 	 * Handle NaN
667 	 */
668 	if (NUMERIC_IS_NAN(num))
669 		return pstrdup("NaN");
670 
671 	init_var_from_num(num, &x);
672 
673 	str = get_str_from_var_sci(&x, scale);
674 
675 	return str;
676 }
677 
678 /*
679  * numeric_normalize() -
680  *
681  *	Output function for numeric data type, suppressing insignificant trailing
682  *	zeroes and then any trailing decimal point.  The intent of this is to
683  *	produce strings that are equal if and only if the input numeric values
684  *	compare equal.
685  */
686 char *
numeric_normalize(Numeric num)687 numeric_normalize(Numeric num)
688 {
689 	NumericVar	x;
690 	char	   *str;
691 	int			last;
692 
693 	/*
694 	 * Handle NaN
695 	 */
696 	if (NUMERIC_IS_NAN(num))
697 		return pstrdup("NaN");
698 
699 	init_var_from_num(num, &x);
700 
701 	str = get_str_from_var(&x);
702 
703 	/* If there's no decimal point, there's certainly nothing to remove. */
704 	if (strchr(str, '.') != NULL)
705 	{
706 		/*
707 		 * Back up over trailing fractional zeroes.  Since there is a decimal
708 		 * point, this loop will terminate safely.
709 		 */
710 		last = strlen(str) - 1;
711 		while (str[last] == '0')
712 			last--;
713 
714 		/* We want to get rid of the decimal point too, if it's now last. */
715 		if (str[last] == '.')
716 			last--;
717 
718 		/* Delete whatever we backed up over. */
719 		str[last + 1] = '\0';
720 	}
721 
722 	return str;
723 }
724 
725 /*
726  *		numeric_recv			- converts external binary format to numeric
727  *
728  * External format is a sequence of int16's:
729  * ndigits, weight, sign, dscale, NumericDigits.
730  */
731 Datum
numeric_recv(PG_FUNCTION_ARGS)732 numeric_recv(PG_FUNCTION_ARGS)
733 {
734 	StringInfo	buf = (StringInfo) PG_GETARG_POINTER(0);
735 
736 #ifdef NOT_USED
737 	Oid			typelem = PG_GETARG_OID(1);
738 #endif
739 	int32		typmod = PG_GETARG_INT32(2);
740 	NumericVar	value;
741 	Numeric		res;
742 	int			len,
743 				i;
744 
745 	init_var(&value);
746 
747 	len = (uint16) pq_getmsgint(buf, sizeof(uint16));
748 
749 	alloc_var(&value, len);
750 
751 	value.weight = (int16) pq_getmsgint(buf, sizeof(int16));
752 	/* we allow any int16 for weight --- OK? */
753 
754 	value.sign = (uint16) pq_getmsgint(buf, sizeof(uint16));
755 	if (!(value.sign == NUMERIC_POS ||
756 		  value.sign == NUMERIC_NEG ||
757 		  value.sign == NUMERIC_NAN))
758 		ereport(ERROR,
759 				(errcode(ERRCODE_INVALID_BINARY_REPRESENTATION),
760 				 errmsg("invalid sign in external \"numeric\" value")));
761 
762 	value.dscale = (uint16) pq_getmsgint(buf, sizeof(uint16));
763 	if ((value.dscale & NUMERIC_DSCALE_MASK) != value.dscale)
764 		ereport(ERROR,
765 				(errcode(ERRCODE_INVALID_BINARY_REPRESENTATION),
766 				 errmsg("invalid scale in external \"numeric\" value")));
767 
768 	for (i = 0; i < len; i++)
769 	{
770 		NumericDigit d = pq_getmsgint(buf, sizeof(NumericDigit));
771 
772 		if (d < 0 || d >= NBASE)
773 			ereport(ERROR,
774 					(errcode(ERRCODE_INVALID_BINARY_REPRESENTATION),
775 					 errmsg("invalid digit in external \"numeric\" value")));
776 		value.digits[i] = d;
777 	}
778 
779 	/*
780 	 * If the given dscale would hide any digits, truncate those digits away.
781 	 * We could alternatively throw an error, but that would take a bunch of
782 	 * extra code (about as much as trunc_var involves), and it might cause
783 	 * client compatibility issues.
784 	 */
785 	trunc_var(&value, value.dscale);
786 
787 	apply_typmod(&value, typmod);
788 
789 	res = make_result(&value);
790 	free_var(&value);
791 
792 	PG_RETURN_NUMERIC(res);
793 }
794 
795 /*
796  *		numeric_send			- converts numeric to binary format
797  */
798 Datum
numeric_send(PG_FUNCTION_ARGS)799 numeric_send(PG_FUNCTION_ARGS)
800 {
801 	Numeric		num = PG_GETARG_NUMERIC(0);
802 	NumericVar	x;
803 	StringInfoData buf;
804 	int			i;
805 
806 	init_var_from_num(num, &x);
807 
808 	pq_begintypsend(&buf);
809 
810 	pq_sendint(&buf, x.ndigits, sizeof(int16));
811 	pq_sendint(&buf, x.weight, sizeof(int16));
812 	pq_sendint(&buf, x.sign, sizeof(int16));
813 	pq_sendint(&buf, x.dscale, sizeof(int16));
814 	for (i = 0; i < x.ndigits; i++)
815 		pq_sendint(&buf, x.digits[i], sizeof(NumericDigit));
816 
817 	PG_RETURN_BYTEA_P(pq_endtypsend(&buf));
818 }
819 
820 
821 /*
822  * numeric_transform() -
823  *
824  * Flatten calls to numeric's length coercion function that solely represent
825  * increases in allowable precision.  Scale changes mutate every datum, so
826  * they are unoptimizable.  Some values, e.g. 1E-1001, can only fit into an
827  * unconstrained numeric, so a change from an unconstrained numeric to any
828  * constrained numeric is also unoptimizable.
829  */
830 Datum
numeric_transform(PG_FUNCTION_ARGS)831 numeric_transform(PG_FUNCTION_ARGS)
832 {
833 	FuncExpr   *expr = (FuncExpr *) PG_GETARG_POINTER(0);
834 	Node	   *ret = NULL;
835 	Node	   *typmod;
836 
837 	Assert(IsA(expr, FuncExpr));
838 	Assert(list_length(expr->args) >= 2);
839 
840 	typmod = (Node *) lsecond(expr->args);
841 
842 	if (IsA(typmod, Const) &&!((Const *) typmod)->constisnull)
843 	{
844 		Node	   *source = (Node *) linitial(expr->args);
845 		int32		old_typmod = exprTypmod(source);
846 		int32		new_typmod = DatumGetInt32(((Const *) typmod)->constvalue);
847 		int32		old_scale = (old_typmod - VARHDRSZ) & 0xffff;
848 		int32		new_scale = (new_typmod - VARHDRSZ) & 0xffff;
849 		int32		old_precision = (old_typmod - VARHDRSZ) >> 16 & 0xffff;
850 		int32		new_precision = (new_typmod - VARHDRSZ) >> 16 & 0xffff;
851 
852 		/*
853 		 * If new_typmod < VARHDRSZ, the destination is unconstrained; that's
854 		 * always OK.  If old_typmod >= VARHDRSZ, the source is constrained,
855 		 * and we're OK if the scale is unchanged and the precision is not
856 		 * decreasing.  See further notes in function header comment.
857 		 */
858 		if (new_typmod < (int32) VARHDRSZ ||
859 			(old_typmod >= (int32) VARHDRSZ &&
860 			 new_scale == old_scale && new_precision >= old_precision))
861 			ret = relabel_to_typmod(source, new_typmod);
862 	}
863 
864 	PG_RETURN_POINTER(ret);
865 }
866 
867 /*
868  * numeric() -
869  *
870  *	This is a special function called by the Postgres database system
871  *	before a value is stored in a tuple's attribute. The precision and
872  *	scale of the attribute have to be applied on the value.
873  */
874 Datum
numeric(PG_FUNCTION_ARGS)875 numeric		(PG_FUNCTION_ARGS)
876 {
877 	Numeric		num = PG_GETARG_NUMERIC(0);
878 	int32		typmod = PG_GETARG_INT32(1);
879 	Numeric		new;
880 	int32		tmp_typmod;
881 	int			precision;
882 	int			scale;
883 	int			ddigits;
884 	int			maxdigits;
885 	NumericVar	var;
886 
887 	/*
888 	 * Handle NaN
889 	 */
890 	if (NUMERIC_IS_NAN(num))
891 		PG_RETURN_NUMERIC(make_result(&const_nan));
892 
893 	/*
894 	 * If the value isn't a valid type modifier, simply return a copy of the
895 	 * input value
896 	 */
897 	if (typmod < (int32) (VARHDRSZ))
898 	{
899 		new = (Numeric) palloc(VARSIZE(num));
900 		memcpy(new, num, VARSIZE(num));
901 		PG_RETURN_NUMERIC(new);
902 	}
903 
904 	/*
905 	 * Get the precision and scale out of the typmod value
906 	 */
907 	tmp_typmod = typmod - VARHDRSZ;
908 	precision = (tmp_typmod >> 16) & 0xffff;
909 	scale = tmp_typmod & 0xffff;
910 	maxdigits = precision - scale;
911 
912 	/*
913 	 * If the number is certainly in bounds and due to the target scale no
914 	 * rounding could be necessary, just make a copy of the input and modify
915 	 * its scale fields, unless the larger scale forces us to abandon the
916 	 * short representation.  (Note we assume the existing dscale is
917 	 * honest...)
918 	 */
919 	ddigits = (NUMERIC_WEIGHT(num) + 1) * DEC_DIGITS;
920 	if (ddigits <= maxdigits && scale >= NUMERIC_DSCALE(num)
921 		&& (NUMERIC_CAN_BE_SHORT(scale, NUMERIC_WEIGHT(num))
922 			|| !NUMERIC_IS_SHORT(num)))
923 	{
924 		new = (Numeric) palloc(VARSIZE(num));
925 		memcpy(new, num, VARSIZE(num));
926 		if (NUMERIC_IS_SHORT(num))
927 			new->choice.n_short.n_header =
928 				(num->choice.n_short.n_header & ~NUMERIC_SHORT_DSCALE_MASK)
929 				| (scale << NUMERIC_SHORT_DSCALE_SHIFT);
930 		else
931 			new->choice.n_long.n_sign_dscale = NUMERIC_SIGN(new) |
932 				((uint16) scale & NUMERIC_DSCALE_MASK);
933 		PG_RETURN_NUMERIC(new);
934 	}
935 
936 	/*
937 	 * We really need to fiddle with things - unpack the number into a
938 	 * variable and let apply_typmod() do it.
939 	 */
940 	init_var(&var);
941 
942 	set_var_from_num(num, &var);
943 	apply_typmod(&var, typmod);
944 	new = make_result(&var);
945 
946 	free_var(&var);
947 
948 	PG_RETURN_NUMERIC(new);
949 }
950 
951 Datum
numerictypmodin(PG_FUNCTION_ARGS)952 numerictypmodin(PG_FUNCTION_ARGS)
953 {
954 	ArrayType  *ta = PG_GETARG_ARRAYTYPE_P(0);
955 	int32	   *tl;
956 	int			n;
957 	int32		typmod;
958 
959 	tl = ArrayGetIntegerTypmods(ta, &n);
960 
961 	if (n == 2)
962 	{
963 		if (tl[0] < 1 || tl[0] > NUMERIC_MAX_PRECISION)
964 			ereport(ERROR,
965 					(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
966 					 errmsg("NUMERIC precision %d must be between 1 and %d",
967 							tl[0], NUMERIC_MAX_PRECISION)));
968 		if (tl[1] < 0 || tl[1] > tl[0])
969 			ereport(ERROR,
970 					(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
971 				errmsg("NUMERIC scale %d must be between 0 and precision %d",
972 					   tl[1], tl[0])));
973 		typmod = ((tl[0] << 16) | tl[1]) + VARHDRSZ;
974 	}
975 	else if (n == 1)
976 	{
977 		if (tl[0] < 1 || tl[0] > NUMERIC_MAX_PRECISION)
978 			ereport(ERROR,
979 					(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
980 					 errmsg("NUMERIC precision %d must be between 1 and %d",
981 							tl[0], NUMERIC_MAX_PRECISION)));
982 		/* scale defaults to zero */
983 		typmod = (tl[0] << 16) + VARHDRSZ;
984 	}
985 	else
986 	{
987 		ereport(ERROR,
988 				(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
989 				 errmsg("invalid NUMERIC type modifier")));
990 		typmod = 0;				/* keep compiler quiet */
991 	}
992 
993 	PG_RETURN_INT32(typmod);
994 }
995 
996 Datum
numerictypmodout(PG_FUNCTION_ARGS)997 numerictypmodout(PG_FUNCTION_ARGS)
998 {
999 	int32		typmod = PG_GETARG_INT32(0);
1000 	char	   *res = (char *) palloc(64);
1001 
1002 	if (typmod >= 0)
1003 		snprintf(res, 64, "(%d,%d)",
1004 				 ((typmod - VARHDRSZ) >> 16) & 0xffff,
1005 				 (typmod - VARHDRSZ) & 0xffff);
1006 	else
1007 		*res = '\0';
1008 
1009 	PG_RETURN_CSTRING(res);
1010 }
1011 
1012 
1013 /* ----------------------------------------------------------------------
1014  *
1015  * Sign manipulation, rounding and the like
1016  *
1017  * ----------------------------------------------------------------------
1018  */
1019 
1020 Datum
numeric_abs(PG_FUNCTION_ARGS)1021 numeric_abs(PG_FUNCTION_ARGS)
1022 {
1023 	Numeric		num = PG_GETARG_NUMERIC(0);
1024 	Numeric		res;
1025 
1026 	/*
1027 	 * Handle NaN
1028 	 */
1029 	if (NUMERIC_IS_NAN(num))
1030 		PG_RETURN_NUMERIC(make_result(&const_nan));
1031 
1032 	/*
1033 	 * Do it the easy way directly on the packed format
1034 	 */
1035 	res = (Numeric) palloc(VARSIZE(num));
1036 	memcpy(res, num, VARSIZE(num));
1037 
1038 	if (NUMERIC_IS_SHORT(num))
1039 		res->choice.n_short.n_header =
1040 			num->choice.n_short.n_header & ~NUMERIC_SHORT_SIGN_MASK;
1041 	else
1042 		res->choice.n_long.n_sign_dscale = NUMERIC_POS | NUMERIC_DSCALE(num);
1043 
1044 	PG_RETURN_NUMERIC(res);
1045 }
1046 
1047 
1048 Datum
numeric_uminus(PG_FUNCTION_ARGS)1049 numeric_uminus(PG_FUNCTION_ARGS)
1050 {
1051 	Numeric		num = PG_GETARG_NUMERIC(0);
1052 	Numeric		res;
1053 
1054 	/*
1055 	 * Handle NaN
1056 	 */
1057 	if (NUMERIC_IS_NAN(num))
1058 		PG_RETURN_NUMERIC(make_result(&const_nan));
1059 
1060 	/*
1061 	 * Do it the easy way directly on the packed format
1062 	 */
1063 	res = (Numeric) palloc(VARSIZE(num));
1064 	memcpy(res, num, VARSIZE(num));
1065 
1066 	/*
1067 	 * The packed format is known to be totally zero digit trimmed always. So
1068 	 * we can identify a ZERO by the fact that there are no digits at all.  Do
1069 	 * nothing to a zero.
1070 	 */
1071 	if (NUMERIC_NDIGITS(num) != 0)
1072 	{
1073 		/* Else, flip the sign */
1074 		if (NUMERIC_IS_SHORT(num))
1075 			res->choice.n_short.n_header =
1076 				num->choice.n_short.n_header ^ NUMERIC_SHORT_SIGN_MASK;
1077 		else if (NUMERIC_SIGN(num) == NUMERIC_POS)
1078 			res->choice.n_long.n_sign_dscale =
1079 				NUMERIC_NEG | NUMERIC_DSCALE(num);
1080 		else
1081 			res->choice.n_long.n_sign_dscale =
1082 				NUMERIC_POS | NUMERIC_DSCALE(num);
1083 	}
1084 
1085 	PG_RETURN_NUMERIC(res);
1086 }
1087 
1088 
1089 Datum
numeric_uplus(PG_FUNCTION_ARGS)1090 numeric_uplus(PG_FUNCTION_ARGS)
1091 {
1092 	Numeric		num = PG_GETARG_NUMERIC(0);
1093 	Numeric		res;
1094 
1095 	res = (Numeric) palloc(VARSIZE(num));
1096 	memcpy(res, num, VARSIZE(num));
1097 
1098 	PG_RETURN_NUMERIC(res);
1099 }
1100 
1101 /*
1102  * numeric_sign() -
1103  *
1104  * returns -1 if the argument is less than 0, 0 if the argument is equal
1105  * to 0, and 1 if the argument is greater than zero.
1106  */
1107 Datum
numeric_sign(PG_FUNCTION_ARGS)1108 numeric_sign(PG_FUNCTION_ARGS)
1109 {
1110 	Numeric		num = PG_GETARG_NUMERIC(0);
1111 	Numeric		res;
1112 	NumericVar	result;
1113 
1114 	/*
1115 	 * Handle NaN
1116 	 */
1117 	if (NUMERIC_IS_NAN(num))
1118 		PG_RETURN_NUMERIC(make_result(&const_nan));
1119 
1120 	init_var(&result);
1121 
1122 	/*
1123 	 * The packed format is known to be totally zero digit trimmed always. So
1124 	 * we can identify a ZERO by the fact that there are no digits at all.
1125 	 */
1126 	if (NUMERIC_NDIGITS(num) == 0)
1127 		set_var_from_var(&const_zero, &result);
1128 	else
1129 	{
1130 		/*
1131 		 * And if there are some, we return a copy of ONE with the sign of our
1132 		 * argument
1133 		 */
1134 		set_var_from_var(&const_one, &result);
1135 		result.sign = NUMERIC_SIGN(num);
1136 	}
1137 
1138 	res = make_result(&result);
1139 	free_var(&result);
1140 
1141 	PG_RETURN_NUMERIC(res);
1142 }
1143 
1144 
1145 /*
1146  * numeric_round() -
1147  *
1148  *	Round a value to have 'scale' digits after the decimal point.
1149  *	We allow negative 'scale', implying rounding before the decimal
1150  *	point --- Oracle interprets rounding that way.
1151  */
1152 Datum
numeric_round(PG_FUNCTION_ARGS)1153 numeric_round(PG_FUNCTION_ARGS)
1154 {
1155 	Numeric		num = PG_GETARG_NUMERIC(0);
1156 	int32		scale = PG_GETARG_INT32(1);
1157 	Numeric		res;
1158 	NumericVar	arg;
1159 
1160 	/*
1161 	 * Handle NaN
1162 	 */
1163 	if (NUMERIC_IS_NAN(num))
1164 		PG_RETURN_NUMERIC(make_result(&const_nan));
1165 
1166 	/*
1167 	 * Limit the scale value to avoid possible overflow in calculations
1168 	 */
1169 	scale = Max(scale, -NUMERIC_MAX_RESULT_SCALE);
1170 	scale = Min(scale, NUMERIC_MAX_RESULT_SCALE);
1171 
1172 	/*
1173 	 * Unpack the argument and round it at the proper digit position
1174 	 */
1175 	init_var(&arg);
1176 	set_var_from_num(num, &arg);
1177 
1178 	round_var(&arg, scale);
1179 
1180 	/* We don't allow negative output dscale */
1181 	if (scale < 0)
1182 		arg.dscale = 0;
1183 
1184 	/*
1185 	 * Return the rounded result
1186 	 */
1187 	res = make_result(&arg);
1188 
1189 	free_var(&arg);
1190 	PG_RETURN_NUMERIC(res);
1191 }
1192 
1193 
1194 /*
1195  * numeric_trunc() -
1196  *
1197  *	Truncate a value to have 'scale' digits after the decimal point.
1198  *	We allow negative 'scale', implying a truncation before the decimal
1199  *	point --- Oracle interprets truncation that way.
1200  */
1201 Datum
numeric_trunc(PG_FUNCTION_ARGS)1202 numeric_trunc(PG_FUNCTION_ARGS)
1203 {
1204 	Numeric		num = PG_GETARG_NUMERIC(0);
1205 	int32		scale = PG_GETARG_INT32(1);
1206 	Numeric		res;
1207 	NumericVar	arg;
1208 
1209 	/*
1210 	 * Handle NaN
1211 	 */
1212 	if (NUMERIC_IS_NAN(num))
1213 		PG_RETURN_NUMERIC(make_result(&const_nan));
1214 
1215 	/*
1216 	 * Limit the scale value to avoid possible overflow in calculations
1217 	 */
1218 	scale = Max(scale, -NUMERIC_MAX_RESULT_SCALE);
1219 	scale = Min(scale, NUMERIC_MAX_RESULT_SCALE);
1220 
1221 	/*
1222 	 * Unpack the argument and truncate it at the proper digit position
1223 	 */
1224 	init_var(&arg);
1225 	set_var_from_num(num, &arg);
1226 
1227 	trunc_var(&arg, scale);
1228 
1229 	/* We don't allow negative output dscale */
1230 	if (scale < 0)
1231 		arg.dscale = 0;
1232 
1233 	/*
1234 	 * Return the truncated result
1235 	 */
1236 	res = make_result(&arg);
1237 
1238 	free_var(&arg);
1239 	PG_RETURN_NUMERIC(res);
1240 }
1241 
1242 
1243 /*
1244  * numeric_ceil() -
1245  *
1246  *	Return the smallest integer greater than or equal to the argument
1247  */
1248 Datum
numeric_ceil(PG_FUNCTION_ARGS)1249 numeric_ceil(PG_FUNCTION_ARGS)
1250 {
1251 	Numeric		num = PG_GETARG_NUMERIC(0);
1252 	Numeric		res;
1253 	NumericVar	result;
1254 
1255 	if (NUMERIC_IS_NAN(num))
1256 		PG_RETURN_NUMERIC(make_result(&const_nan));
1257 
1258 	init_var_from_num(num, &result);
1259 	ceil_var(&result, &result);
1260 
1261 	res = make_result(&result);
1262 	free_var(&result);
1263 
1264 	PG_RETURN_NUMERIC(res);
1265 }
1266 
1267 
1268 /*
1269  * numeric_floor() -
1270  *
1271  *	Return the largest integer equal to or less than the argument
1272  */
1273 Datum
numeric_floor(PG_FUNCTION_ARGS)1274 numeric_floor(PG_FUNCTION_ARGS)
1275 {
1276 	Numeric		num = PG_GETARG_NUMERIC(0);
1277 	Numeric		res;
1278 	NumericVar	result;
1279 
1280 	if (NUMERIC_IS_NAN(num))
1281 		PG_RETURN_NUMERIC(make_result(&const_nan));
1282 
1283 	init_var_from_num(num, &result);
1284 	floor_var(&result, &result);
1285 
1286 	res = make_result(&result);
1287 	free_var(&result);
1288 
1289 	PG_RETURN_NUMERIC(res);
1290 }
1291 
1292 
1293 /*
1294  * generate_series_numeric() -
1295  *
1296  *	Generate series of numeric.
1297  */
1298 Datum
generate_series_numeric(PG_FUNCTION_ARGS)1299 generate_series_numeric(PG_FUNCTION_ARGS)
1300 {
1301 	return generate_series_step_numeric(fcinfo);
1302 }
1303 
1304 Datum
generate_series_step_numeric(PG_FUNCTION_ARGS)1305 generate_series_step_numeric(PG_FUNCTION_ARGS)
1306 {
1307 	generate_series_numeric_fctx *fctx;
1308 	FuncCallContext *funcctx;
1309 	MemoryContext oldcontext;
1310 
1311 	if (SRF_IS_FIRSTCALL())
1312 	{
1313 		Numeric		start_num = PG_GETARG_NUMERIC(0);
1314 		Numeric		stop_num = PG_GETARG_NUMERIC(1);
1315 		NumericVar	steploc = const_one;
1316 
1317 		/* handle NaN in start and stop values */
1318 		if (NUMERIC_IS_NAN(start_num))
1319 			ereport(ERROR,
1320 					(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
1321 					 errmsg("start value cannot be NaN")));
1322 
1323 		if (NUMERIC_IS_NAN(stop_num))
1324 			ereport(ERROR,
1325 					(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
1326 					 errmsg("stop value cannot be NaN")));
1327 
1328 		/* see if we were given an explicit step size */
1329 		if (PG_NARGS() == 3)
1330 		{
1331 			Numeric		step_num = PG_GETARG_NUMERIC(2);
1332 
1333 			if (NUMERIC_IS_NAN(step_num))
1334 				ereport(ERROR,
1335 						(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
1336 						 errmsg("step size cannot be NaN")));
1337 
1338 			init_var_from_num(step_num, &steploc);
1339 
1340 			if (cmp_var(&steploc, &const_zero) == 0)
1341 				ereport(ERROR,
1342 						(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
1343 						 errmsg("step size cannot equal zero")));
1344 		}
1345 
1346 		/* create a function context for cross-call persistence */
1347 		funcctx = SRF_FIRSTCALL_INIT();
1348 
1349 		/*
1350 		 * Switch to memory context appropriate for multiple function calls.
1351 		 */
1352 		oldcontext = MemoryContextSwitchTo(funcctx->multi_call_memory_ctx);
1353 
1354 		/* allocate memory for user context */
1355 		fctx = (generate_series_numeric_fctx *)
1356 			palloc(sizeof(generate_series_numeric_fctx));
1357 
1358 		/*
1359 		 * Use fctx to keep state from call to call. Seed current with the
1360 		 * original start value. We must copy the start_num and stop_num
1361 		 * values rather than pointing to them, since we may have detoasted
1362 		 * them in the per-call context.
1363 		 */
1364 		init_var(&fctx->current);
1365 		init_var(&fctx->stop);
1366 		init_var(&fctx->step);
1367 
1368 		set_var_from_num(start_num, &fctx->current);
1369 		set_var_from_num(stop_num, &fctx->stop);
1370 		set_var_from_var(&steploc, &fctx->step);
1371 
1372 		funcctx->user_fctx = fctx;
1373 		MemoryContextSwitchTo(oldcontext);
1374 	}
1375 
1376 	/* stuff done on every call of the function */
1377 	funcctx = SRF_PERCALL_SETUP();
1378 
1379 	/*
1380 	 * Get the saved state and use current state as the result of this
1381 	 * iteration.
1382 	 */
1383 	fctx = funcctx->user_fctx;
1384 
1385 	if ((fctx->step.sign == NUMERIC_POS &&
1386 		 cmp_var(&fctx->current, &fctx->stop) <= 0) ||
1387 		(fctx->step.sign == NUMERIC_NEG &&
1388 		 cmp_var(&fctx->current, &fctx->stop) >= 0))
1389 	{
1390 		Numeric		result = make_result(&fctx->current);
1391 
1392 		/* switch to memory context appropriate for iteration calculation */
1393 		oldcontext = MemoryContextSwitchTo(funcctx->multi_call_memory_ctx);
1394 
1395 		/* increment current in preparation for next iteration */
1396 		add_var(&fctx->current, &fctx->step, &fctx->current);
1397 		MemoryContextSwitchTo(oldcontext);
1398 
1399 		/* do when there is more left to send */
1400 		SRF_RETURN_NEXT(funcctx, NumericGetDatum(result));
1401 	}
1402 	else
1403 		/* do when there is no more left */
1404 		SRF_RETURN_DONE(funcctx);
1405 }
1406 
1407 
1408 /*
1409  * Implements the numeric version of the width_bucket() function
1410  * defined by SQL2003. See also width_bucket_float8().
1411  *
1412  * 'bound1' and 'bound2' are the lower and upper bounds of the
1413  * histogram's range, respectively. 'count' is the number of buckets
1414  * in the histogram. width_bucket() returns an integer indicating the
1415  * bucket number that 'operand' belongs to in an equiwidth histogram
1416  * with the specified characteristics. An operand smaller than the
1417  * lower bound is assigned to bucket 0. An operand greater than the
1418  * upper bound is assigned to an additional bucket (with number
1419  * count+1). We don't allow "NaN" for any of the numeric arguments.
1420  */
1421 Datum
width_bucket_numeric(PG_FUNCTION_ARGS)1422 width_bucket_numeric(PG_FUNCTION_ARGS)
1423 {
1424 	Numeric		operand = PG_GETARG_NUMERIC(0);
1425 	Numeric		bound1 = PG_GETARG_NUMERIC(1);
1426 	Numeric		bound2 = PG_GETARG_NUMERIC(2);
1427 	int32		count = PG_GETARG_INT32(3);
1428 	NumericVar	count_var;
1429 	NumericVar	result_var;
1430 	int32		result;
1431 
1432 	if (count <= 0)
1433 		ereport(ERROR,
1434 				(errcode(ERRCODE_INVALID_ARGUMENT_FOR_WIDTH_BUCKET_FUNCTION),
1435 				 errmsg("count must be greater than zero")));
1436 
1437 	if (NUMERIC_IS_NAN(operand) ||
1438 		NUMERIC_IS_NAN(bound1) ||
1439 		NUMERIC_IS_NAN(bound2))
1440 		ereport(ERROR,
1441 				(errcode(ERRCODE_INVALID_ARGUMENT_FOR_WIDTH_BUCKET_FUNCTION),
1442 			 errmsg("operand, lower bound, and upper bound cannot be NaN")));
1443 
1444 	init_var(&result_var);
1445 	init_var(&count_var);
1446 
1447 	/* Convert 'count' to a numeric, for ease of use later */
1448 	int64_to_numericvar((int64) count, &count_var);
1449 
1450 	switch (cmp_numerics(bound1, bound2))
1451 	{
1452 		case 0:
1453 			ereport(ERROR,
1454 				(errcode(ERRCODE_INVALID_ARGUMENT_FOR_WIDTH_BUCKET_FUNCTION),
1455 				 errmsg("lower bound cannot equal upper bound")));
1456 
1457 			/* bound1 < bound2 */
1458 		case -1:
1459 			if (cmp_numerics(operand, bound1) < 0)
1460 				set_var_from_var(&const_zero, &result_var);
1461 			else if (cmp_numerics(operand, bound2) >= 0)
1462 				add_var(&count_var, &const_one, &result_var);
1463 			else
1464 				compute_bucket(operand, bound1, bound2,
1465 							   &count_var, &result_var);
1466 			break;
1467 
1468 			/* bound1 > bound2 */
1469 		case 1:
1470 			if (cmp_numerics(operand, bound1) > 0)
1471 				set_var_from_var(&const_zero, &result_var);
1472 			else if (cmp_numerics(operand, bound2) <= 0)
1473 				add_var(&count_var, &const_one, &result_var);
1474 			else
1475 				compute_bucket(operand, bound1, bound2,
1476 							   &count_var, &result_var);
1477 			break;
1478 	}
1479 
1480 	/* if result exceeds the range of a legal int4, we ereport here */
1481 	result = numericvar_to_int32(&result_var);
1482 
1483 	free_var(&count_var);
1484 	free_var(&result_var);
1485 
1486 	PG_RETURN_INT32(result);
1487 }
1488 
1489 /*
1490  * If 'operand' is not outside the bucket range, determine the correct
1491  * bucket for it to go. The calculations performed by this function
1492  * are derived directly from the SQL2003 spec.
1493  */
1494 static void
compute_bucket(Numeric operand,Numeric bound1,Numeric bound2,NumericVar * count_var,NumericVar * result_var)1495 compute_bucket(Numeric operand, Numeric bound1, Numeric bound2,
1496 			   NumericVar *count_var, NumericVar *result_var)
1497 {
1498 	NumericVar	bound1_var;
1499 	NumericVar	bound2_var;
1500 	NumericVar	operand_var;
1501 
1502 	init_var_from_num(bound1, &bound1_var);
1503 	init_var_from_num(bound2, &bound2_var);
1504 	init_var_from_num(operand, &operand_var);
1505 
1506 	if (cmp_var(&bound1_var, &bound2_var) < 0)
1507 	{
1508 		sub_var(&operand_var, &bound1_var, &operand_var);
1509 		sub_var(&bound2_var, &bound1_var, &bound2_var);
1510 		div_var(&operand_var, &bound2_var, result_var,
1511 				select_div_scale(&operand_var, &bound2_var), true);
1512 	}
1513 	else
1514 	{
1515 		sub_var(&bound1_var, &operand_var, &operand_var);
1516 		sub_var(&bound1_var, &bound2_var, &bound1_var);
1517 		div_var(&operand_var, &bound1_var, result_var,
1518 				select_div_scale(&operand_var, &bound1_var), true);
1519 	}
1520 
1521 	mul_var(result_var, count_var, result_var,
1522 			result_var->dscale + count_var->dscale);
1523 	add_var(result_var, &const_one, result_var);
1524 	floor_var(result_var, result_var);
1525 
1526 	free_var(&bound1_var);
1527 	free_var(&bound2_var);
1528 	free_var(&operand_var);
1529 }
1530 
1531 /* ----------------------------------------------------------------------
1532  *
1533  * Comparison functions
1534  *
1535  * Note: btree indexes need these routines not to leak memory; therefore,
1536  * be careful to free working copies of toasted datums.  Most places don't
1537  * need to be so careful.
1538  *
1539  * Sort support:
1540  *
1541  * We implement the sortsupport strategy routine in order to get the benefit of
1542  * abbreviation. The ordinary numeric comparison can be quite slow as a result
1543  * of palloc/pfree cycles (due to detoasting packed values for alignment);
1544  * while this could be worked on itself, the abbreviation strategy gives more
1545  * speedup in many common cases.
1546  *
1547  * Two different representations are used for the abbreviated form, one in
1548  * int32 and one in int64, whichever fits into a by-value Datum.  In both cases
1549  * the representation is negated relative to the original value, because we use
1550  * the largest negative value for NaN, which sorts higher than other values. We
1551  * convert the absolute value of the numeric to a 31-bit or 63-bit positive
1552  * value, and then negate it if the original number was positive.
1553  *
1554  * We abort the abbreviation process if the abbreviation cardinality is below
1555  * 0.01% of the row count (1 per 10k non-null rows).  The actual break-even
1556  * point is somewhat below that, perhaps 1 per 30k (at 1 per 100k there's a
1557  * very small penalty), but we don't want to build up too many abbreviated
1558  * values before first testing for abort, so we take the slightly pessimistic
1559  * number.  We make no attempt to estimate the cardinality of the real values,
1560  * since it plays no part in the cost model here (if the abbreviation is equal,
1561  * the cost of comparing equal and unequal underlying values is comparable).
1562  * We discontinue even checking for abort (saving us the hashing overhead) if
1563  * the estimated cardinality gets to 100k; that would be enough to support many
1564  * billions of rows while doing no worse than breaking even.
1565  *
1566  * ----------------------------------------------------------------------
1567  */
1568 
1569 /*
1570  * Sort support strategy routine.
1571  */
1572 Datum
numeric_sortsupport(PG_FUNCTION_ARGS)1573 numeric_sortsupport(PG_FUNCTION_ARGS)
1574 {
1575 	SortSupport ssup = (SortSupport) PG_GETARG_POINTER(0);
1576 
1577 	ssup->comparator = numeric_fast_cmp;
1578 
1579 	if (ssup->abbreviate)
1580 	{
1581 		NumericSortSupport *nss;
1582 		MemoryContext oldcontext = MemoryContextSwitchTo(ssup->ssup_cxt);
1583 
1584 		nss = palloc(sizeof(NumericSortSupport));
1585 
1586 		/*
1587 		 * palloc a buffer for handling unaligned packed values in addition to
1588 		 * the support struct
1589 		 */
1590 		nss->buf = palloc(VARATT_SHORT_MAX + VARHDRSZ + 1);
1591 
1592 		nss->input_count = 0;
1593 		nss->estimating = true;
1594 		initHyperLogLog(&nss->abbr_card, 10);
1595 
1596 		ssup->ssup_extra = nss;
1597 
1598 		ssup->abbrev_full_comparator = ssup->comparator;
1599 		ssup->comparator = numeric_cmp_abbrev;
1600 		ssup->abbrev_converter = numeric_abbrev_convert;
1601 		ssup->abbrev_abort = numeric_abbrev_abort;
1602 
1603 		MemoryContextSwitchTo(oldcontext);
1604 	}
1605 
1606 	PG_RETURN_VOID();
1607 }
1608 
1609 /*
1610  * Abbreviate a numeric datum, handling NaNs and detoasting
1611  * (must not leak memory!)
1612  */
1613 static Datum
numeric_abbrev_convert(Datum original_datum,SortSupport ssup)1614 numeric_abbrev_convert(Datum original_datum, SortSupport ssup)
1615 {
1616 	NumericSortSupport *nss = ssup->ssup_extra;
1617 	void	   *original_varatt = PG_DETOAST_DATUM_PACKED(original_datum);
1618 	Numeric		value;
1619 	Datum		result;
1620 
1621 	nss->input_count += 1;
1622 
1623 	/*
1624 	 * This is to handle packed datums without needing a palloc/pfree cycle;
1625 	 * we keep and reuse a buffer large enough to handle any short datum.
1626 	 */
1627 	if (VARATT_IS_SHORT(original_varatt))
1628 	{
1629 		void	   *buf = nss->buf;
1630 		Size		sz = VARSIZE_SHORT(original_varatt) - VARHDRSZ_SHORT;
1631 
1632 		Assert(sz <= VARATT_SHORT_MAX - VARHDRSZ_SHORT);
1633 
1634 		SET_VARSIZE(buf, VARHDRSZ + sz);
1635 		memcpy(VARDATA(buf), VARDATA_SHORT(original_varatt), sz);
1636 
1637 		value = (Numeric) buf;
1638 	}
1639 	else
1640 		value = (Numeric) original_varatt;
1641 
1642 	if (NUMERIC_IS_NAN(value))
1643 	{
1644 		result = NUMERIC_ABBREV_NAN;
1645 	}
1646 	else
1647 	{
1648 		NumericVar	var;
1649 
1650 		init_var_from_num(value, &var);
1651 
1652 		result = numeric_abbrev_convert_var(&var, nss);
1653 	}
1654 
1655 	/* should happen only for external/compressed toasts */
1656 	if ((Pointer) original_varatt != DatumGetPointer(original_datum))
1657 		pfree(original_varatt);
1658 
1659 	return result;
1660 }
1661 
1662 /*
1663  * Consider whether to abort abbreviation.
1664  *
1665  * We pay no attention to the cardinality of the non-abbreviated data. There is
1666  * no reason to do so: unlike text, we have no fast check for equal values, so
1667  * we pay the full overhead whenever the abbreviations are equal regardless of
1668  * whether the underlying values are also equal.
1669  */
1670 static bool
numeric_abbrev_abort(int memtupcount,SortSupport ssup)1671 numeric_abbrev_abort(int memtupcount, SortSupport ssup)
1672 {
1673 	NumericSortSupport *nss = ssup->ssup_extra;
1674 	double		abbr_card;
1675 
1676 	if (memtupcount < 10000 || nss->input_count < 10000 || !nss->estimating)
1677 		return false;
1678 
1679 	abbr_card = estimateHyperLogLog(&nss->abbr_card);
1680 
1681 	/*
1682 	 * If we have >100k distinct values, then even if we were sorting many
1683 	 * billion rows we'd likely still break even, and the penalty of undoing
1684 	 * that many rows of abbrevs would probably not be worth it. Stop even
1685 	 * counting at that point.
1686 	 */
1687 	if (abbr_card > 100000.0)
1688 	{
1689 #ifdef TRACE_SORT
1690 		if (trace_sort)
1691 			elog(LOG,
1692 				 "numeric_abbrev: estimation ends at cardinality %f"
1693 				 " after " INT64_FORMAT " values (%d rows)",
1694 				 abbr_card, nss->input_count, memtupcount);
1695 #endif
1696 		nss->estimating = false;
1697 		return false;
1698 	}
1699 
1700 	/*
1701 	 * Target minimum cardinality is 1 per ~10k of non-null inputs.  (The
1702 	 * break even point is somewhere between one per 100k rows, where
1703 	 * abbreviation has a very slight penalty, and 1 per 10k where it wins by
1704 	 * a measurable percentage.)  We use the relatively pessimistic 10k
1705 	 * threshold, and add a 0.5 row fudge factor, because it allows us to
1706 	 * abort earlier on genuinely pathological data where we've had exactly
1707 	 * one abbreviated value in the first 10k (non-null) rows.
1708 	 */
1709 	if (abbr_card < nss->input_count / 10000.0 + 0.5)
1710 	{
1711 #ifdef TRACE_SORT
1712 		if (trace_sort)
1713 			elog(LOG,
1714 				 "numeric_abbrev: aborting abbreviation at cardinality %f"
1715 			   " below threshold %f after " INT64_FORMAT " values (%d rows)",
1716 				 abbr_card, nss->input_count / 10000.0 + 0.5,
1717 				 nss->input_count, memtupcount);
1718 #endif
1719 		return true;
1720 	}
1721 
1722 #ifdef TRACE_SORT
1723 	if (trace_sort)
1724 		elog(LOG,
1725 			 "numeric_abbrev: cardinality %f"
1726 			 " after " INT64_FORMAT " values (%d rows)",
1727 			 abbr_card, nss->input_count, memtupcount);
1728 #endif
1729 
1730 	return false;
1731 }
1732 
1733 /*
1734  * Non-fmgr interface to the comparison routine to allow sortsupport to elide
1735  * the fmgr call.  The saving here is small given how slow numeric comparisons
1736  * are, but it is a required part of the sort support API when abbreviations
1737  * are performed.
1738  *
1739  * Two palloc/pfree cycles could be saved here by using persistent buffers for
1740  * aligning short-varlena inputs, but this has not so far been considered to
1741  * be worth the effort.
1742  */
1743 static int
numeric_fast_cmp(Datum x,Datum y,SortSupport ssup)1744 numeric_fast_cmp(Datum x, Datum y, SortSupport ssup)
1745 {
1746 	Numeric		nx = DatumGetNumeric(x);
1747 	Numeric		ny = DatumGetNumeric(y);
1748 	int			result;
1749 
1750 	result = cmp_numerics(nx, ny);
1751 
1752 	if ((Pointer) nx != DatumGetPointer(x))
1753 		pfree(nx);
1754 	if ((Pointer) ny != DatumGetPointer(y))
1755 		pfree(ny);
1756 
1757 	return result;
1758 }
1759 
1760 /*
1761  * Compare abbreviations of values. (Abbreviations may be equal where the true
1762  * values differ, but if the abbreviations differ, they must reflect the
1763  * ordering of the true values.)
1764  */
1765 static int
numeric_cmp_abbrev(Datum x,Datum y,SortSupport ssup)1766 numeric_cmp_abbrev(Datum x, Datum y, SortSupport ssup)
1767 {
1768 	/*
1769 	 * NOTE WELL: this is intentionally backwards, because the abbreviation is
1770 	 * negated relative to the original value, to handle NaN.
1771 	 */
1772 	if (DatumGetNumericAbbrev(x) < DatumGetNumericAbbrev(y))
1773 		return 1;
1774 	if (DatumGetNumericAbbrev(x) > DatumGetNumericAbbrev(y))
1775 		return -1;
1776 	return 0;
1777 }
1778 
1779 /*
1780  * Abbreviate a NumericVar according to the available bit size.
1781  *
1782  * The 31-bit value is constructed as:
1783  *
1784  *	0 + 7bits digit weight + 24 bits digit value
1785  *
1786  * where the digit weight is in single decimal digits, not digit words, and
1787  * stored in excess-44 representation[1]. The 24-bit digit value is the 7 most
1788  * significant decimal digits of the value converted to binary. Values whose
1789  * weights would fall outside the representable range are rounded off to zero
1790  * (which is also used to represent actual zeros) or to 0x7FFFFFFF (which
1791  * otherwise cannot occur). Abbreviation therefore fails to gain any advantage
1792  * where values are outside the range 10^-44 to 10^83, which is not considered
1793  * to be a serious limitation, or when values are of the same magnitude and
1794  * equal in the first 7 decimal digits, which is considered to be an
1795  * unavoidable limitation given the available bits. (Stealing three more bits
1796  * to compare another digit would narrow the range of representable weights by
1797  * a factor of 8, which starts to look like a real limiting factor.)
1798  *
1799  * (The value 44 for the excess is essentially arbitrary)
1800  *
1801  * The 63-bit value is constructed as:
1802  *
1803  *	0 + 7bits weight + 4 x 14-bit packed digit words
1804  *
1805  * The weight in this case is again stored in excess-44, but this time it is
1806  * the original weight in digit words (i.e. powers of 10000). The first four
1807  * digit words of the value (if present; trailing zeros are assumed as needed)
1808  * are packed into 14 bits each to form the rest of the value. Again,
1809  * out-of-range values are rounded off to 0 or 0x7FFFFFFFFFFFFFFF. The
1810  * representable range in this case is 10^-176 to 10^332, which is considered
1811  * to be good enough for all practical purposes, and comparison of 4 words
1812  * means that at least 13 decimal digits are compared, which is considered to
1813  * be a reasonable compromise between effectiveness and efficiency in computing
1814  * the abbreviation.
1815  *
1816  * (The value 44 for the excess is even more arbitrary here, it was chosen just
1817  * to match the value used in the 31-bit case)
1818  *
1819  * [1] - Excess-k representation means that the value is offset by adding 'k'
1820  * and then treated as unsigned, so the smallest representable value is stored
1821  * with all bits zero. This allows simple comparisons to work on the composite
1822  * value.
1823  */
1824 
#if NUMERIC_ABBREV_BITS == 64

/*
 * 64-bit abbreviation: excess-44 weight in the bits from 56 upwards,
 * followed by up to four digit words of 14 bits each, most significant
 * word first.  Missing words count as zero.  The magnitude is then
 * negated for positive inputs.
 */
static Datum
numeric_abbrev_convert_var(NumericVar *var, NumericSortSupport *nss)
{
	int			ndigits = var->ndigits;
	int			weight = var->weight;
	int64		result;

	if (ndigits == 0 || weight < -44)
	{
		/* zero, or too small to represent */
		result = 0;
	}
	else if (weight > 83)
	{
		/* too large to represent */
		result = PG_INT64_MAX;
	}
	else
	{
		int			nwords = Min(ndigits, 4);
		int			i;

		result = ((int64) (weight + 44) << 56);

		/* word i lands at bit offset 42, 28, 14, 0 for i = 0..3 */
		for (i = 0; i < nwords; i++)
			result |= ((int64) var->digits[i]) << (42 - 14 * i);
	}

	/* the abbrev is negated relative to the original */
	if (var->sign == NUMERIC_POS)
		result = -result;

	/* While estimating, feed a 32-bit fold of the key to the estimator */
	if (nss->estimating)
	{
		uint32		lo = (uint32) result;
		uint32		hi = (uint32) ((uint64) result >> 32);
		uint32		tmp = lo ^ hi;

		addHyperLogLog(&nss->abbr_card, DatumGetUInt32(hash_uint32(tmp)));
	}

	return NumericAbbrevGetDatum(result);
}

#endif   /* NUMERIC_ABBREV_BITS == 64 */
1879 
#if NUMERIC_ABBREV_BITS == 32

/*
 * 32-bit abbreviation: a weight counted in decimal digits is placed in
 * the bits from 24 upwards, above the leading 7 decimal digits of the
 * value.  The magnitude is then negated for positive inputs.
 */
static Datum
numeric_abbrev_convert_var(NumericVar *var, NumericSortSupport *nss)
{
	int			ndigits = var->ndigits;
	int			weight = var->weight;
	int32		result;

	if (ndigits == 0 || weight < -11)
	{
		/* zero, or too small to represent */
		result = 0;
	}
	else if (weight > 20)
	{
		/* too large to represent */
		result = PG_INT32_MAX;
	}
	else
	{
		/* leading digit words; absent ones count as zero */
		NumericDigit d0 = var->digits[0];
		NumericDigit d1 = (ndigits > 1) ? var->digits[1] : 0;
		NumericDigit d2 = (ndigits > 2) ? var->digits[2] : 0;
		int			dweight = (weight + 11) * 4;
		int32		mantissa;

		/*
		 * d0 holds 1 to 4 nonzero decimal digits.  Top it up from the
		 * following words to 7 digits in total, the most that fits in 24
		 * bits, and bump the decimal weight by the digits d0 has beyond
		 * the first.
		 */
		if (d0 <= 9)
		{
			/* 1 digit so far, take 6 more */
			mantissa = (d0 * 1000000) + (d1 * 100) + (d2 / 100);
		}
		else if (d0 <= 99)
		{
			/* 2 digits so far, take 5 more */
			mantissa = (d0 * 100000) + (d1 * 10) + (d2 / 1000);
			dweight += 1;
		}
		else if (d0 <= 999)
		{
			/* 3 digits so far, take 4 more */
			mantissa = (d0 * 10000) + d1;
			dweight += 2;
		}
		else
		{
			/* 4 digits so far, take 3 more */
			mantissa = (d0 * 1000) + (d1 / 10);
			dweight += 3;
		}

		result = mantissa | (dweight << 24);
	}

	/* the abbrev is negated relative to the original */
	if (var->sign == NUMERIC_POS)
		result = -result;

	/* While estimating, feed the key to the cardinality estimator */
	if (nss->estimating)
	{
		uint32		tmp = (uint32) result;

		addHyperLogLog(&nss->abbr_card, DatumGetUInt32(hash_uint32(tmp)));
	}

	return NumericAbbrevGetDatum(result);
}

#endif   /* NUMERIC_ABBREV_BITS == 32 */
1956 
1957 /*
1958  * Ordinary (non-sortsupport) comparisons follow.
1959  */
1960 
1961 Datum
numeric_cmp(PG_FUNCTION_ARGS)1962 numeric_cmp(PG_FUNCTION_ARGS)
1963 {
1964 	Numeric		num1 = PG_GETARG_NUMERIC(0);
1965 	Numeric		num2 = PG_GETARG_NUMERIC(1);
1966 	int			result;
1967 
1968 	result = cmp_numerics(num1, num2);
1969 
1970 	PG_FREE_IF_COPY(num1, 0);
1971 	PG_FREE_IF_COPY(num2, 1);
1972 
1973 	PG_RETURN_INT32(result);
1974 }
1975 
1976 
1977 Datum
numeric_eq(PG_FUNCTION_ARGS)1978 numeric_eq(PG_FUNCTION_ARGS)
1979 {
1980 	Numeric		num1 = PG_GETARG_NUMERIC(0);
1981 	Numeric		num2 = PG_GETARG_NUMERIC(1);
1982 	bool		result;
1983 
1984 	result = cmp_numerics(num1, num2) == 0;
1985 
1986 	PG_FREE_IF_COPY(num1, 0);
1987 	PG_FREE_IF_COPY(num2, 1);
1988 
1989 	PG_RETURN_BOOL(result);
1990 }
1991 
1992 Datum
numeric_ne(PG_FUNCTION_ARGS)1993 numeric_ne(PG_FUNCTION_ARGS)
1994 {
1995 	Numeric		num1 = PG_GETARG_NUMERIC(0);
1996 	Numeric		num2 = PG_GETARG_NUMERIC(1);
1997 	bool		result;
1998 
1999 	result = cmp_numerics(num1, num2) != 0;
2000 
2001 	PG_FREE_IF_COPY(num1, 0);
2002 	PG_FREE_IF_COPY(num2, 1);
2003 
2004 	PG_RETURN_BOOL(result);
2005 }
2006 
2007 Datum
numeric_gt(PG_FUNCTION_ARGS)2008 numeric_gt(PG_FUNCTION_ARGS)
2009 {
2010 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2011 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2012 	bool		result;
2013 
2014 	result = cmp_numerics(num1, num2) > 0;
2015 
2016 	PG_FREE_IF_COPY(num1, 0);
2017 	PG_FREE_IF_COPY(num2, 1);
2018 
2019 	PG_RETURN_BOOL(result);
2020 }
2021 
2022 Datum
numeric_ge(PG_FUNCTION_ARGS)2023 numeric_ge(PG_FUNCTION_ARGS)
2024 {
2025 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2026 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2027 	bool		result;
2028 
2029 	result = cmp_numerics(num1, num2) >= 0;
2030 
2031 	PG_FREE_IF_COPY(num1, 0);
2032 	PG_FREE_IF_COPY(num2, 1);
2033 
2034 	PG_RETURN_BOOL(result);
2035 }
2036 
2037 Datum
numeric_lt(PG_FUNCTION_ARGS)2038 numeric_lt(PG_FUNCTION_ARGS)
2039 {
2040 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2041 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2042 	bool		result;
2043 
2044 	result = cmp_numerics(num1, num2) < 0;
2045 
2046 	PG_FREE_IF_COPY(num1, 0);
2047 	PG_FREE_IF_COPY(num2, 1);
2048 
2049 	PG_RETURN_BOOL(result);
2050 }
2051 
2052 Datum
numeric_le(PG_FUNCTION_ARGS)2053 numeric_le(PG_FUNCTION_ARGS)
2054 {
2055 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2056 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2057 	bool		result;
2058 
2059 	result = cmp_numerics(num1, num2) <= 0;
2060 
2061 	PG_FREE_IF_COPY(num1, 0);
2062 	PG_FREE_IF_COPY(num2, 1);
2063 
2064 	PG_RETURN_BOOL(result);
2065 }
2066 
2067 static int
cmp_numerics(Numeric num1,Numeric num2)2068 cmp_numerics(Numeric num1, Numeric num2)
2069 {
2070 	int			result;
2071 
2072 	/*
2073 	 * We consider all NANs to be equal and larger than any non-NAN. This is
2074 	 * somewhat arbitrary; the important thing is to have a consistent sort
2075 	 * order.
2076 	 */
2077 	if (NUMERIC_IS_NAN(num1))
2078 	{
2079 		if (NUMERIC_IS_NAN(num2))
2080 			result = 0;			/* NAN = NAN */
2081 		else
2082 			result = 1;			/* NAN > non-NAN */
2083 	}
2084 	else if (NUMERIC_IS_NAN(num2))
2085 	{
2086 		result = -1;			/* non-NAN < NAN */
2087 	}
2088 	else
2089 	{
2090 		result = cmp_var_common(NUMERIC_DIGITS(num1), NUMERIC_NDIGITS(num1),
2091 								NUMERIC_WEIGHT(num1), NUMERIC_SIGN(num1),
2092 								NUMERIC_DIGITS(num2), NUMERIC_NDIGITS(num2),
2093 								NUMERIC_WEIGHT(num2), NUMERIC_SIGN(num2));
2094 	}
2095 
2096 	return result;
2097 }
2098 
2099 Datum
hash_numeric(PG_FUNCTION_ARGS)2100 hash_numeric(PG_FUNCTION_ARGS)
2101 {
2102 	Numeric		key = PG_GETARG_NUMERIC(0);
2103 	Datum		digit_hash;
2104 	Datum		result;
2105 	int			weight;
2106 	int			start_offset;
2107 	int			end_offset;
2108 	int			i;
2109 	int			hash_len;
2110 	NumericDigit *digits;
2111 
2112 	/* If it's NaN, don't try to hash the rest of the fields */
2113 	if (NUMERIC_IS_NAN(key))
2114 		PG_RETURN_UINT32(0);
2115 
2116 	weight = NUMERIC_WEIGHT(key);
2117 	start_offset = 0;
2118 	end_offset = 0;
2119 
2120 	/*
2121 	 * Omit any leading or trailing zeros from the input to the hash. The
2122 	 * numeric implementation *should* guarantee that leading and trailing
2123 	 * zeros are suppressed, but we're paranoid. Note that we measure the
2124 	 * starting and ending offsets in units of NumericDigits, not bytes.
2125 	 */
2126 	digits = NUMERIC_DIGITS(key);
2127 	for (i = 0; i < NUMERIC_NDIGITS(key); i++)
2128 	{
2129 		if (digits[i] != (NumericDigit) 0)
2130 			break;
2131 
2132 		start_offset++;
2133 
2134 		/*
2135 		 * The weight is effectively the # of digits before the decimal point,
2136 		 * so decrement it for each leading zero we skip.
2137 		 */
2138 		weight--;
2139 	}
2140 
2141 	/*
2142 	 * If there are no non-zero digits, then the value of the number is zero,
2143 	 * regardless of any other fields.
2144 	 */
2145 	if (NUMERIC_NDIGITS(key) == start_offset)
2146 		PG_RETURN_UINT32(-1);
2147 
2148 	for (i = NUMERIC_NDIGITS(key) - 1; i >= 0; i--)
2149 	{
2150 		if (digits[i] != (NumericDigit) 0)
2151 			break;
2152 
2153 		end_offset++;
2154 	}
2155 
2156 	/* If we get here, there should be at least one non-zero digit */
2157 	Assert(start_offset + end_offset < NUMERIC_NDIGITS(key));
2158 
2159 	/*
2160 	 * Note that we don't hash on the Numeric's scale, since two numerics can
2161 	 * compare equal but have different scales. We also don't hash on the
2162 	 * sign, although we could: since a sign difference implies inequality,
2163 	 * this shouldn't affect correctness.
2164 	 */
2165 	hash_len = NUMERIC_NDIGITS(key) - start_offset - end_offset;
2166 	digit_hash = hash_any((unsigned char *) (NUMERIC_DIGITS(key) + start_offset),
2167 						  hash_len * sizeof(NumericDigit));
2168 
2169 	/* Mix in the weight, via XOR */
2170 	result = digit_hash ^ weight;
2171 
2172 	PG_RETURN_DATUM(result);
2173 }
2174 
2175 
2176 /* ----------------------------------------------------------------------
2177  *
2178  * Basic arithmetic functions
2179  *
2180  * ----------------------------------------------------------------------
2181  */
2182 
2183 
2184 /*
2185  * numeric_add() -
2186  *
2187  *	Add two numerics
2188  */
2189 Datum
numeric_add(PG_FUNCTION_ARGS)2190 numeric_add(PG_FUNCTION_ARGS)
2191 {
2192 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2193 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2194 	NumericVar	arg1;
2195 	NumericVar	arg2;
2196 	NumericVar	result;
2197 	Numeric		res;
2198 
2199 	/*
2200 	 * Handle NaN
2201 	 */
2202 	if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
2203 		PG_RETURN_NUMERIC(make_result(&const_nan));
2204 
2205 	/*
2206 	 * Unpack the values, let add_var() compute the result and return it.
2207 	 */
2208 	init_var_from_num(num1, &arg1);
2209 	init_var_from_num(num2, &arg2);
2210 
2211 	init_var(&result);
2212 	add_var(&arg1, &arg2, &result);
2213 
2214 	res = make_result(&result);
2215 
2216 	free_var(&result);
2217 
2218 	PG_RETURN_NUMERIC(res);
2219 }
2220 
2221 
2222 /*
2223  * numeric_sub() -
2224  *
2225  *	Subtract one numeric from another
2226  */
2227 Datum
numeric_sub(PG_FUNCTION_ARGS)2228 numeric_sub(PG_FUNCTION_ARGS)
2229 {
2230 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2231 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2232 	NumericVar	arg1;
2233 	NumericVar	arg2;
2234 	NumericVar	result;
2235 	Numeric		res;
2236 
2237 	/*
2238 	 * Handle NaN
2239 	 */
2240 	if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
2241 		PG_RETURN_NUMERIC(make_result(&const_nan));
2242 
2243 	/*
2244 	 * Unpack the values, let sub_var() compute the result and return it.
2245 	 */
2246 	init_var_from_num(num1, &arg1);
2247 	init_var_from_num(num2, &arg2);
2248 
2249 	init_var(&result);
2250 	sub_var(&arg1, &arg2, &result);
2251 
2252 	res = make_result(&result);
2253 
2254 	free_var(&result);
2255 
2256 	PG_RETURN_NUMERIC(res);
2257 }
2258 
2259 
2260 /*
2261  * numeric_mul() -
2262  *
2263  *	Calculate the product of two numerics
2264  */
2265 Datum
numeric_mul(PG_FUNCTION_ARGS)2266 numeric_mul(PG_FUNCTION_ARGS)
2267 {
2268 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2269 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2270 	NumericVar	arg1;
2271 	NumericVar	arg2;
2272 	NumericVar	result;
2273 	Numeric		res;
2274 
2275 	/*
2276 	 * Handle NaN
2277 	 */
2278 	if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
2279 		PG_RETURN_NUMERIC(make_result(&const_nan));
2280 
2281 	/*
2282 	 * Unpack the values, let mul_var() compute the result and return it.
2283 	 * Unlike add_var() and sub_var(), mul_var() will round its result. In the
2284 	 * case of numeric_mul(), which is invoked for the * operator on numerics,
2285 	 * we request exact representation for the product (rscale = sum(dscale of
2286 	 * arg1, dscale of arg2)).  If the exact result has more digits after the
2287 	 * decimal point than can be stored in a numeric, we round it.  Rounding
2288 	 * after computing the exact result ensures that the final result is
2289 	 * correctly rounded (rounding in mul_var() using a truncated product
2290 	 * would not guarantee this).
2291 	 */
2292 	init_var_from_num(num1, &arg1);
2293 	init_var_from_num(num2, &arg2);
2294 
2295 	init_var(&result);
2296 	mul_var(&arg1, &arg2, &result, arg1.dscale + arg2.dscale);
2297 
2298 	if (result.dscale > NUMERIC_DSCALE_MAX)
2299 		round_var(&result, NUMERIC_DSCALE_MAX);
2300 
2301 	res = make_result(&result);
2302 
2303 	free_var(&result);
2304 
2305 	PG_RETURN_NUMERIC(res);
2306 }
2307 
2308 
2309 /*
2310  * numeric_div() -
2311  *
2312  *	Divide one numeric into another
2313  */
2314 Datum
numeric_div(PG_FUNCTION_ARGS)2315 numeric_div(PG_FUNCTION_ARGS)
2316 {
2317 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2318 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2319 	NumericVar	arg1;
2320 	NumericVar	arg2;
2321 	NumericVar	result;
2322 	Numeric		res;
2323 	int			rscale;
2324 
2325 	/*
2326 	 * Handle NaN
2327 	 */
2328 	if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
2329 		PG_RETURN_NUMERIC(make_result(&const_nan));
2330 
2331 	/*
2332 	 * Unpack the arguments
2333 	 */
2334 	init_var_from_num(num1, &arg1);
2335 	init_var_from_num(num2, &arg2);
2336 
2337 	init_var(&result);
2338 
2339 	/*
2340 	 * Select scale for division result
2341 	 */
2342 	rscale = select_div_scale(&arg1, &arg2);
2343 
2344 	/*
2345 	 * Do the divide and return the result
2346 	 */
2347 	div_var(&arg1, &arg2, &result, rscale, true);
2348 
2349 	res = make_result(&result);
2350 
2351 	free_var(&result);
2352 
2353 	PG_RETURN_NUMERIC(res);
2354 }
2355 
2356 
2357 /*
2358  * numeric_div_trunc() -
2359  *
2360  *	Divide one numeric into another, truncating the result to an integer
2361  */
2362 Datum
numeric_div_trunc(PG_FUNCTION_ARGS)2363 numeric_div_trunc(PG_FUNCTION_ARGS)
2364 {
2365 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2366 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2367 	NumericVar	arg1;
2368 	NumericVar	arg2;
2369 	NumericVar	result;
2370 	Numeric		res;
2371 
2372 	/*
2373 	 * Handle NaN
2374 	 */
2375 	if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
2376 		PG_RETURN_NUMERIC(make_result(&const_nan));
2377 
2378 	/*
2379 	 * Unpack the arguments
2380 	 */
2381 	init_var_from_num(num1, &arg1);
2382 	init_var_from_num(num2, &arg2);
2383 
2384 	init_var(&result);
2385 
2386 	/*
2387 	 * Do the divide and return the result
2388 	 */
2389 	div_var(&arg1, &arg2, &result, 0, false);
2390 
2391 	res = make_result(&result);
2392 
2393 	free_var(&result);
2394 
2395 	PG_RETURN_NUMERIC(res);
2396 }
2397 
2398 
2399 /*
2400  * numeric_mod() -
2401  *
2402  *	Calculate the modulo of two numerics
2403  */
2404 Datum
numeric_mod(PG_FUNCTION_ARGS)2405 numeric_mod(PG_FUNCTION_ARGS)
2406 {
2407 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2408 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2409 	Numeric		res;
2410 	NumericVar	arg1;
2411 	NumericVar	arg2;
2412 	NumericVar	result;
2413 
2414 	if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
2415 		PG_RETURN_NUMERIC(make_result(&const_nan));
2416 
2417 	init_var_from_num(num1, &arg1);
2418 	init_var_from_num(num2, &arg2);
2419 
2420 	init_var(&result);
2421 
2422 	mod_var(&arg1, &arg2, &result);
2423 
2424 	res = make_result(&result);
2425 
2426 	free_var(&result);
2427 
2428 	PG_RETURN_NUMERIC(res);
2429 }
2430 
2431 
2432 /*
2433  * numeric_inc() -
2434  *
2435  *	Increment a number by one
2436  */
2437 Datum
numeric_inc(PG_FUNCTION_ARGS)2438 numeric_inc(PG_FUNCTION_ARGS)
2439 {
2440 	Numeric		num = PG_GETARG_NUMERIC(0);
2441 	NumericVar	arg;
2442 	Numeric		res;
2443 
2444 	/*
2445 	 * Handle NaN
2446 	 */
2447 	if (NUMERIC_IS_NAN(num))
2448 		PG_RETURN_NUMERIC(make_result(&const_nan));
2449 
2450 	/*
2451 	 * Compute the result and return it
2452 	 */
2453 	init_var_from_num(num, &arg);
2454 
2455 	add_var(&arg, &const_one, &arg);
2456 
2457 	res = make_result(&arg);
2458 
2459 	free_var(&arg);
2460 
2461 	PG_RETURN_NUMERIC(res);
2462 }
2463 
2464 
2465 /*
2466  * numeric_smaller() -
2467  *
2468  *	Return the smaller of two numbers
2469  */
2470 Datum
numeric_smaller(PG_FUNCTION_ARGS)2471 numeric_smaller(PG_FUNCTION_ARGS)
2472 {
2473 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2474 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2475 
2476 	/*
2477 	 * Use cmp_numerics so that this will agree with the comparison operators,
2478 	 * particularly as regards comparisons involving NaN.
2479 	 */
2480 	if (cmp_numerics(num1, num2) < 0)
2481 		PG_RETURN_NUMERIC(num1);
2482 	else
2483 		PG_RETURN_NUMERIC(num2);
2484 }
2485 
2486 
2487 /*
2488  * numeric_larger() -
2489  *
2490  *	Return the larger of two numbers
2491  */
2492 Datum
numeric_larger(PG_FUNCTION_ARGS)2493 numeric_larger(PG_FUNCTION_ARGS)
2494 {
2495 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2496 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2497 
2498 	/*
2499 	 * Use cmp_numerics so that this will agree with the comparison operators,
2500 	 * particularly as regards comparisons involving NaN.
2501 	 */
2502 	if (cmp_numerics(num1, num2) > 0)
2503 		PG_RETURN_NUMERIC(num1);
2504 	else
2505 		PG_RETURN_NUMERIC(num2);
2506 }
2507 
2508 
2509 /* ----------------------------------------------------------------------
2510  *
2511  * Advanced math functions
2512  *
2513  * ----------------------------------------------------------------------
2514  */
2515 
2516 /*
2517  * numeric_fac()
2518  *
2519  * Compute factorial
2520  */
2521 Datum
numeric_fac(PG_FUNCTION_ARGS)2522 numeric_fac(PG_FUNCTION_ARGS)
2523 {
2524 	int64		num = PG_GETARG_INT64(0);
2525 	Numeric		res;
2526 	NumericVar	fact;
2527 	NumericVar	result;
2528 
2529 	if (num <= 1)
2530 	{
2531 		res = make_result(&const_one);
2532 		PG_RETURN_NUMERIC(res);
2533 	}
2534 	/* Fail immediately if the result would overflow */
2535 	if (num > 32177)
2536 		ereport(ERROR,
2537 				(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
2538 				 errmsg("value overflows numeric format")));
2539 
2540 	init_var(&fact);
2541 	init_var(&result);
2542 
2543 	int64_to_numericvar(num, &result);
2544 
2545 	for (num = num - 1; num > 1; num--)
2546 	{
2547 		/* this loop can take awhile, so allow it to be interrupted */
2548 		CHECK_FOR_INTERRUPTS();
2549 
2550 		int64_to_numericvar(num, &fact);
2551 
2552 		mul_var(&result, &fact, &result, 0);
2553 	}
2554 
2555 	res = make_result(&result);
2556 
2557 	free_var(&fact);
2558 	free_var(&result);
2559 
2560 	PG_RETURN_NUMERIC(res);
2561 }
2562 
2563 
2564 /*
2565  * numeric_sqrt() -
2566  *
2567  *	Compute the square root of a numeric.
2568  */
2569 Datum
numeric_sqrt(PG_FUNCTION_ARGS)2570 numeric_sqrt(PG_FUNCTION_ARGS)
2571 {
2572 	Numeric		num = PG_GETARG_NUMERIC(0);
2573 	Numeric		res;
2574 	NumericVar	arg;
2575 	NumericVar	result;
2576 	int			sweight;
2577 	int			rscale;
2578 
2579 	/*
2580 	 * Handle NaN
2581 	 */
2582 	if (NUMERIC_IS_NAN(num))
2583 		PG_RETURN_NUMERIC(make_result(&const_nan));
2584 
2585 	/*
2586 	 * Unpack the argument and determine the result scale.  We choose a scale
2587 	 * to give at least NUMERIC_MIN_SIG_DIGITS significant digits; but in any
2588 	 * case not less than the input's dscale.
2589 	 */
2590 	init_var_from_num(num, &arg);
2591 
2592 	init_var(&result);
2593 
2594 	/* Assume the input was normalized, so arg.weight is accurate */
2595 	sweight = (arg.weight + 1) * DEC_DIGITS / 2 - 1;
2596 
2597 	rscale = NUMERIC_MIN_SIG_DIGITS - sweight;
2598 	rscale = Max(rscale, arg.dscale);
2599 	rscale = Max(rscale, NUMERIC_MIN_DISPLAY_SCALE);
2600 	rscale = Min(rscale, NUMERIC_MAX_DISPLAY_SCALE);
2601 
2602 	/*
2603 	 * Let sqrt_var() do the calculation and return the result.
2604 	 */
2605 	sqrt_var(&arg, &result, rscale);
2606 
2607 	res = make_result(&result);
2608 
2609 	free_var(&result);
2610 
2611 	PG_RETURN_NUMERIC(res);
2612 }
2613 
2614 
2615 /*
2616  * numeric_exp() -
2617  *
2618  *	Raise e to the power of x
2619  */
2620 Datum
numeric_exp(PG_FUNCTION_ARGS)2621 numeric_exp(PG_FUNCTION_ARGS)
2622 {
2623 	Numeric		num = PG_GETARG_NUMERIC(0);
2624 	Numeric		res;
2625 	NumericVar	arg;
2626 	NumericVar	result;
2627 	int			rscale;
2628 	double		val;
2629 
2630 	/*
2631 	 * Handle NaN
2632 	 */
2633 	if (NUMERIC_IS_NAN(num))
2634 		PG_RETURN_NUMERIC(make_result(&const_nan));
2635 
2636 	/*
2637 	 * Unpack the argument and determine the result scale.  We choose a scale
2638 	 * to give at least NUMERIC_MIN_SIG_DIGITS significant digits; but in any
2639 	 * case not less than the input's dscale.
2640 	 */
2641 	init_var_from_num(num, &arg);
2642 
2643 	init_var(&result);
2644 
2645 	/* convert input to float8, ignoring overflow */
2646 	val = numericvar_to_double_no_overflow(&arg);
2647 
2648 	/*
2649 	 * log10(result) = num * log10(e), so this is approximately the decimal
2650 	 * weight of the result:
2651 	 */
2652 	val *= 0.434294481903252;
2653 
2654 	/* limit to something that won't cause integer overflow */
2655 	val = Max(val, -NUMERIC_MAX_RESULT_SCALE);
2656 	val = Min(val, NUMERIC_MAX_RESULT_SCALE);
2657 
2658 	rscale = NUMERIC_MIN_SIG_DIGITS - (int) val;
2659 	rscale = Max(rscale, arg.dscale);
2660 	rscale = Max(rscale, NUMERIC_MIN_DISPLAY_SCALE);
2661 	rscale = Min(rscale, NUMERIC_MAX_DISPLAY_SCALE);
2662 
2663 	/*
2664 	 * Let exp_var() do the calculation and return the result.
2665 	 */
2666 	exp_var(&arg, &result, rscale);
2667 
2668 	res = make_result(&result);
2669 
2670 	free_var(&result);
2671 
2672 	PG_RETURN_NUMERIC(res);
2673 }
2674 
2675 
2676 /*
2677  * numeric_ln() -
2678  *
2679  *	Compute the natural logarithm of x
2680  */
2681 Datum
numeric_ln(PG_FUNCTION_ARGS)2682 numeric_ln(PG_FUNCTION_ARGS)
2683 {
2684 	Numeric		num = PG_GETARG_NUMERIC(0);
2685 	Numeric		res;
2686 	NumericVar	arg;
2687 	NumericVar	result;
2688 	int			ln_dweight;
2689 	int			rscale;
2690 
2691 	/*
2692 	 * Handle NaN
2693 	 */
2694 	if (NUMERIC_IS_NAN(num))
2695 		PG_RETURN_NUMERIC(make_result(&const_nan));
2696 
2697 	init_var_from_num(num, &arg);
2698 	init_var(&result);
2699 
2700 	/* Estimated dweight of logarithm */
2701 	ln_dweight = estimate_ln_dweight(&arg);
2702 
2703 	rscale = NUMERIC_MIN_SIG_DIGITS - ln_dweight;
2704 	rscale = Max(rscale, arg.dscale);
2705 	rscale = Max(rscale, NUMERIC_MIN_DISPLAY_SCALE);
2706 	rscale = Min(rscale, NUMERIC_MAX_DISPLAY_SCALE);
2707 
2708 	ln_var(&arg, &result, rscale);
2709 
2710 	res = make_result(&result);
2711 
2712 	free_var(&result);
2713 
2714 	PG_RETURN_NUMERIC(res);
2715 }
2716 
2717 
2718 /*
2719  * numeric_log() -
2720  *
2721  *	Compute the logarithm of x in a given base
2722  */
2723 Datum
numeric_log(PG_FUNCTION_ARGS)2724 numeric_log(PG_FUNCTION_ARGS)
2725 {
2726 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2727 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2728 	Numeric		res;
2729 	NumericVar	arg1;
2730 	NumericVar	arg2;
2731 	NumericVar	result;
2732 
2733 	/*
2734 	 * Handle NaN
2735 	 */
2736 	if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
2737 		PG_RETURN_NUMERIC(make_result(&const_nan));
2738 
2739 	/*
2740 	 * Initialize things
2741 	 */
2742 	init_var_from_num(num1, &arg1);
2743 	init_var_from_num(num2, &arg2);
2744 	init_var(&result);
2745 
2746 	/*
2747 	 * Call log_var() to compute and return the result; note it handles scale
2748 	 * selection itself.
2749 	 */
2750 	log_var(&arg1, &arg2, &result);
2751 
2752 	res = make_result(&result);
2753 
2754 	free_var(&result);
2755 
2756 	PG_RETURN_NUMERIC(res);
2757 }
2758 
2759 
2760 /*
2761  * numeric_power() -
2762  *
2763  *	Raise b to the power of x
2764  */
2765 Datum
numeric_power(PG_FUNCTION_ARGS)2766 numeric_power(PG_FUNCTION_ARGS)
2767 {
2768 	Numeric		num1 = PG_GETARG_NUMERIC(0);
2769 	Numeric		num2 = PG_GETARG_NUMERIC(1);
2770 	Numeric		res;
2771 	NumericVar	arg1;
2772 	NumericVar	arg2;
2773 	NumericVar	arg2_trunc;
2774 	NumericVar	result;
2775 
2776 	/*
2777 	 * Handle NaN
2778 	 */
2779 	if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
2780 		PG_RETURN_NUMERIC(make_result(&const_nan));
2781 
2782 	/*
2783 	 * Initialize things
2784 	 */
2785 	init_var(&arg2_trunc);
2786 	init_var(&result);
2787 	init_var_from_num(num1, &arg1);
2788 	init_var_from_num(num2, &arg2);
2789 
2790 	set_var_from_var(&arg2, &arg2_trunc);
2791 	trunc_var(&arg2_trunc, 0);
2792 
2793 	/*
2794 	 * The SQL spec requires that we emit a particular SQLSTATE error code for
2795 	 * certain error conditions.  Specifically, we don't return a
2796 	 * divide-by-zero error code for 0 ^ -1.  Raising a negative number to a
2797 	 * non-integer power must produce the same error code, but that case is
2798 	 * handled in power_var().
2799 	 */
2800 	if (cmp_var(&arg1, &const_zero) == 0 &&
2801 		cmp_var(&arg2, &const_zero) < 0)
2802 		ereport(ERROR,
2803 				(errcode(ERRCODE_INVALID_ARGUMENT_FOR_POWER_FUNCTION),
2804 				 errmsg("zero raised to a negative power is undefined")));
2805 
2806 	/*
2807 	 * Call power_var() to compute and return the result; note it handles
2808 	 * scale selection itself.
2809 	 */
2810 	power_var(&arg1, &arg2, &result);
2811 
2812 	res = make_result(&result);
2813 
2814 	free_var(&result);
2815 	free_var(&arg2_trunc);
2816 
2817 	PG_RETURN_NUMERIC(res);
2818 }
2819 
2820 /*
2821  * numeric_scale() -
2822  *
2823  *	Returns the scale, i.e. the count of decimal digits in the fractional part
2824  */
2825 Datum
numeric_scale(PG_FUNCTION_ARGS)2826 numeric_scale(PG_FUNCTION_ARGS)
2827 {
2828 	Numeric		num = PG_GETARG_NUMERIC(0);
2829 
2830 	if (NUMERIC_IS_NAN(num))
2831 		PG_RETURN_NULL();
2832 
2833 	PG_RETURN_INT32(NUMERIC_DSCALE(num));
2834 }
2835 
2836 
2837 
2838 /* ----------------------------------------------------------------------
2839  *
2840  * Type conversion functions
2841  *
2842  * ----------------------------------------------------------------------
2843  */
2844 
2845 
2846 Datum
int4_numeric(PG_FUNCTION_ARGS)2847 int4_numeric(PG_FUNCTION_ARGS)
2848 {
2849 	int32		val = PG_GETARG_INT32(0);
2850 	Numeric		res;
2851 	NumericVar	result;
2852 
2853 	init_var(&result);
2854 
2855 	int64_to_numericvar((int64) val, &result);
2856 
2857 	res = make_result(&result);
2858 
2859 	free_var(&result);
2860 
2861 	PG_RETURN_NUMERIC(res);
2862 }
2863 
2864 
2865 Datum
numeric_int4(PG_FUNCTION_ARGS)2866 numeric_int4(PG_FUNCTION_ARGS)
2867 {
2868 	Numeric		num = PG_GETARG_NUMERIC(0);
2869 	NumericVar	x;
2870 	int32		result;
2871 
2872 	/* XXX would it be better to return NULL? */
2873 	if (NUMERIC_IS_NAN(num))
2874 		ereport(ERROR,
2875 				(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
2876 				 errmsg("cannot convert NaN to integer")));
2877 
2878 	/* Convert to variable format, then convert to int4 */
2879 	init_var_from_num(num, &x);
2880 	result = numericvar_to_int32(&x);
2881 	PG_RETURN_INT32(result);
2882 }
2883 
2884 /*
2885  * Given a NumericVar, convert it to an int32. If the NumericVar
2886  * exceeds the range of an int32, raise the appropriate error via
2887  * ereport(). The input NumericVar is *not* free'd.
2888  */
2889 static int32
numericvar_to_int32(NumericVar * var)2890 numericvar_to_int32(NumericVar *var)
2891 {
2892 	int64		val;
2893 
2894 	if (!numericvar_to_int64(var, &val))
2895 		ereport(ERROR,
2896 				(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
2897 				 errmsg("integer out of range")));
2898 
2899 	if (val < PG_INT32_MIN || val > PG_INT32_MAX)
2900 		ereport(ERROR,
2901 				(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
2902 				 errmsg("integer out of range")));
2903 
2904 	/* Down-convert to int4 */
2905 	return (int32) val;
2906 }
2907 
2908 Datum
int8_numeric(PG_FUNCTION_ARGS)2909 int8_numeric(PG_FUNCTION_ARGS)
2910 {
2911 	int64		val = PG_GETARG_INT64(0);
2912 	Numeric		res;
2913 	NumericVar	result;
2914 
2915 	init_var(&result);
2916 
2917 	int64_to_numericvar(val, &result);
2918 
2919 	res = make_result(&result);
2920 
2921 	free_var(&result);
2922 
2923 	PG_RETURN_NUMERIC(res);
2924 }
2925 
2926 
2927 Datum
numeric_int8(PG_FUNCTION_ARGS)2928 numeric_int8(PG_FUNCTION_ARGS)
2929 {
2930 	Numeric		num = PG_GETARG_NUMERIC(0);
2931 	NumericVar	x;
2932 	int64		result;
2933 
2934 	/* XXX would it be better to return NULL? */
2935 	if (NUMERIC_IS_NAN(num))
2936 		ereport(ERROR,
2937 				(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
2938 				 errmsg("cannot convert NaN to bigint")));
2939 
2940 	/* Convert to variable format and thence to int8 */
2941 	init_var_from_num(num, &x);
2942 
2943 	if (!numericvar_to_int64(&x, &result))
2944 		ereport(ERROR,
2945 				(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
2946 				 errmsg("bigint out of range")));
2947 
2948 	PG_RETURN_INT64(result);
2949 }
2950 
2951 
2952 Datum
int2_numeric(PG_FUNCTION_ARGS)2953 int2_numeric(PG_FUNCTION_ARGS)
2954 {
2955 	int16		val = PG_GETARG_INT16(0);
2956 	Numeric		res;
2957 	NumericVar	result;
2958 
2959 	init_var(&result);
2960 
2961 	int64_to_numericvar((int64) val, &result);
2962 
2963 	res = make_result(&result);
2964 
2965 	free_var(&result);
2966 
2967 	PG_RETURN_NUMERIC(res);
2968 }
2969 
2970 
2971 Datum
numeric_int2(PG_FUNCTION_ARGS)2972 numeric_int2(PG_FUNCTION_ARGS)
2973 {
2974 	Numeric		num = PG_GETARG_NUMERIC(0);
2975 	NumericVar	x;
2976 	int64		val;
2977 	int16		result;
2978 
2979 	/* XXX would it be better to return NULL? */
2980 	if (NUMERIC_IS_NAN(num))
2981 		ereport(ERROR,
2982 				(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
2983 				 errmsg("cannot convert NaN to smallint")));
2984 
2985 	/* Convert to variable format and thence to int8 */
2986 	init_var_from_num(num, &x);
2987 
2988 	if (!numericvar_to_int64(&x, &val))
2989 		ereport(ERROR,
2990 				(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
2991 				 errmsg("smallint out of range")));
2992 
2993 	if (val < PG_INT16_MIN || val > PG_INT16_MAX)
2994 		ereport(ERROR,
2995 				(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
2996 				 errmsg("smallint out of range")));
2997 
2998 	/* Down-convert to int2 */
2999 	result = (int16) val;
3000 
3001 	PG_RETURN_INT16(result);
3002 }
3003 
3004 
3005 Datum
float8_numeric(PG_FUNCTION_ARGS)3006 float8_numeric(PG_FUNCTION_ARGS)
3007 {
3008 	float8		val = PG_GETARG_FLOAT8(0);
3009 	Numeric		res;
3010 	NumericVar	result;
3011 	char		buf[DBL_DIG + 100];
3012 
3013 	if (isnan(val))
3014 		PG_RETURN_NUMERIC(make_result(&const_nan));
3015 
3016 	if (isinf(val))
3017 		ereport(ERROR,
3018 				(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
3019 				 errmsg("cannot convert infinity to numeric")));
3020 
3021 	snprintf(buf, sizeof(buf), "%.*g", DBL_DIG, val);
3022 
3023 	init_var(&result);
3024 
3025 	/* Assume we need not worry about leading/trailing spaces */
3026 	(void) set_var_from_str(buf, buf, &result);
3027 
3028 	res = make_result(&result);
3029 
3030 	free_var(&result);
3031 
3032 	PG_RETURN_NUMERIC(res);
3033 }
3034 
3035 
3036 Datum
numeric_float8(PG_FUNCTION_ARGS)3037 numeric_float8(PG_FUNCTION_ARGS)
3038 {
3039 	Numeric		num = PG_GETARG_NUMERIC(0);
3040 	char	   *tmp;
3041 	Datum		result;
3042 
3043 	if (NUMERIC_IS_NAN(num))
3044 		PG_RETURN_FLOAT8(get_float8_nan());
3045 
3046 	tmp = DatumGetCString(DirectFunctionCall1(numeric_out,
3047 											  NumericGetDatum(num)));
3048 
3049 	result = DirectFunctionCall1(float8in, CStringGetDatum(tmp));
3050 
3051 	pfree(tmp);
3052 
3053 	PG_RETURN_DATUM(result);
3054 }
3055 
3056 
3057 /* Convert numeric to float8; if out of range, return +/- HUGE_VAL */
3058 Datum
numeric_float8_no_overflow(PG_FUNCTION_ARGS)3059 numeric_float8_no_overflow(PG_FUNCTION_ARGS)
3060 {
3061 	Numeric		num = PG_GETARG_NUMERIC(0);
3062 	double		val;
3063 
3064 	if (NUMERIC_IS_NAN(num))
3065 		PG_RETURN_FLOAT8(get_float8_nan());
3066 
3067 	val = numeric_to_double_no_overflow(num);
3068 
3069 	PG_RETURN_FLOAT8(val);
3070 }
3071 
3072 Datum
float4_numeric(PG_FUNCTION_ARGS)3073 float4_numeric(PG_FUNCTION_ARGS)
3074 {
3075 	float4		val = PG_GETARG_FLOAT4(0);
3076 	Numeric		res;
3077 	NumericVar	result;
3078 	char		buf[FLT_DIG + 100];
3079 
3080 	if (isnan(val))
3081 		PG_RETURN_NUMERIC(make_result(&const_nan));
3082 
3083 	if (isinf(val))
3084 		ereport(ERROR,
3085 				(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
3086 				 errmsg("cannot convert infinity to numeric")));
3087 
3088 	snprintf(buf, sizeof(buf), "%.*g", FLT_DIG, val);
3089 
3090 	init_var(&result);
3091 
3092 	/* Assume we need not worry about leading/trailing spaces */
3093 	(void) set_var_from_str(buf, buf, &result);
3094 
3095 	res = make_result(&result);
3096 
3097 	free_var(&result);
3098 
3099 	PG_RETURN_NUMERIC(res);
3100 }
3101 
3102 
3103 Datum
numeric_float4(PG_FUNCTION_ARGS)3104 numeric_float4(PG_FUNCTION_ARGS)
3105 {
3106 	Numeric		num = PG_GETARG_NUMERIC(0);
3107 	char	   *tmp;
3108 	Datum		result;
3109 
3110 	if (NUMERIC_IS_NAN(num))
3111 		PG_RETURN_FLOAT4(get_float4_nan());
3112 
3113 	tmp = DatumGetCString(DirectFunctionCall1(numeric_out,
3114 											  NumericGetDatum(num)));
3115 
3116 	result = DirectFunctionCall1(float4in, CStringGetDatum(tmp));
3117 
3118 	pfree(tmp);
3119 
3120 	PG_RETURN_DATUM(result);
3121 }
3122 
3123 
3124 /* ----------------------------------------------------------------------
3125  *
3126  * Aggregate functions
3127  *
3128  * The transition datatype for all these aggregates is declared as INTERNAL.
3129  * Actually, it's a pointer to a NumericAggState allocated in the aggregate
3130  * context.  The digit buffers for the NumericVars will be there too.
3131  *
3132  * On platforms which support 128-bit integers some aggregates instead use a
3133  * 128-bit integer based transition datatype to speed up calculations.
3134  *
3135  * ----------------------------------------------------------------------
3136  */
3137 
3138 typedef struct NumericAggState
3139 {
3140 	bool		calcSumX2;		/* if true, calculate sumX2 */
3141 	MemoryContext agg_context;	/* context we're calculating in */
3142 	int64		N;				/* count of processed numbers */
3143 	NumericVar	sumX;			/* sum of processed numbers */
3144 	NumericVar	sumX2;			/* sum of squares of processed numbers */
3145 	int			maxScale;		/* maximum scale seen so far */
3146 	int64		maxScaleCount;	/* number of values seen with maximum scale */
3147 	int64		NaNcount;		/* count of NaN values (not included in N!) */
3148 } NumericAggState;
3149 
3150 /*
3151  * Prepare state data for a numeric aggregate function that needs to compute
3152  * sum, count and optionally sum of squares of the input.
3153  */
3154 static NumericAggState *
makeNumericAggState(FunctionCallInfo fcinfo,bool calcSumX2)3155 makeNumericAggState(FunctionCallInfo fcinfo, bool calcSumX2)
3156 {
3157 	NumericAggState *state;
3158 	MemoryContext agg_context;
3159 	MemoryContext old_context;
3160 
3161 	if (!AggCheckCallContext(fcinfo, &agg_context))
3162 		elog(ERROR, "aggregate function called in non-aggregate context");
3163 
3164 	old_context = MemoryContextSwitchTo(agg_context);
3165 
3166 	state = (NumericAggState *) palloc0(sizeof(NumericAggState));
3167 	state->calcSumX2 = calcSumX2;
3168 	state->agg_context = agg_context;
3169 
3170 	MemoryContextSwitchTo(old_context);
3171 
3172 	return state;
3173 }
3174 
3175 /*
3176  * Like makeNumericAggState(), but allocate the state in the current memory
3177  * context.
3178  */
3179 static NumericAggState *
makeNumericAggStateCurrentContext(bool calcSumX2)3180 makeNumericAggStateCurrentContext(bool calcSumX2)
3181 {
3182 	NumericAggState *state;
3183 
3184 	state = (NumericAggState *) palloc0(sizeof(NumericAggState));
3185 	state->calcSumX2 = calcSumX2;
3186 	state->agg_context = CurrentMemoryContext;
3187 
3188 	return state;
3189 }
3190 
3191 /*
3192  * Accumulate a new input value for numeric aggregate functions.
3193  */
3194 static void
do_numeric_accum(NumericAggState * state,Numeric newval)3195 do_numeric_accum(NumericAggState *state, Numeric newval)
3196 {
3197 	NumericVar	X;
3198 	NumericVar	X2;
3199 	MemoryContext old_context;
3200 
3201 	/* Count NaN inputs separately from all else */
3202 	if (NUMERIC_IS_NAN(newval))
3203 	{
3204 		state->NaNcount++;
3205 		return;
3206 	}
3207 
3208 	/* load processed number in short-lived context */
3209 	init_var_from_num(newval, &X);
3210 
3211 	/*
3212 	 * Track the highest input dscale that we've seen, to support inverse
3213 	 * transitions (see do_numeric_discard).
3214 	 */
3215 	if (X.dscale > state->maxScale)
3216 	{
3217 		state->maxScale = X.dscale;
3218 		state->maxScaleCount = 1;
3219 	}
3220 	else if (X.dscale == state->maxScale)
3221 		state->maxScaleCount++;
3222 
3223 	/* if we need X^2, calculate that in short-lived context */
3224 	if (state->calcSumX2)
3225 	{
3226 		init_var(&X2);
3227 		mul_var(&X, &X, &X2, X.dscale * 2);
3228 	}
3229 
3230 	/* The rest of this needs to work in the aggregate context */
3231 	old_context = MemoryContextSwitchTo(state->agg_context);
3232 
3233 	if (state->N++ > 0)
3234 	{
3235 		/* Accumulate sums */
3236 		add_var(&X, &(state->sumX), &(state->sumX));
3237 
3238 		if (state->calcSumX2)
3239 			add_var(&X2, &(state->sumX2), &(state->sumX2));
3240 	}
3241 	else
3242 	{
3243 		/* First input, so initialize sums */
3244 		set_var_from_var(&X, &(state->sumX));
3245 
3246 		if (state->calcSumX2)
3247 			set_var_from_var(&X2, &(state->sumX2));
3248 	}
3249 
3250 	MemoryContextSwitchTo(old_context);
3251 }
3252 
3253 /*
3254  * Attempt to remove an input value from the aggregated state.
3255  *
3256  * If the value cannot be removed then the function will return false; the
3257  * possible reasons for failing are described below.
3258  *
3259  * If we aggregate the values 1.01 and 2 then the result will be 3.01.
3260  * If we are then asked to un-aggregate the 1.01 then we must fail as we
3261  * won't be able to tell what the new aggregated value's dscale should be.
3262  * We don't want to return 2.00 (dscale = 2), since the sum's dscale would
3263  * have been zero if we'd really aggregated only 2.
3264  *
3265  * Note: alternatively, we could count the number of inputs with each possible
3266  * dscale (up to some sane limit).  Not yet clear if it's worth the trouble.
3267  */
3268 static bool
do_numeric_discard(NumericAggState * state,Numeric newval)3269 do_numeric_discard(NumericAggState *state, Numeric newval)
3270 {
3271 	NumericVar	X;
3272 	NumericVar	X2;
3273 	MemoryContext old_context;
3274 
3275 	/* Count NaN inputs separately from all else */
3276 	if (NUMERIC_IS_NAN(newval))
3277 	{
3278 		state->NaNcount--;
3279 		return true;
3280 	}
3281 
3282 	/* load processed number in short-lived context */
3283 	init_var_from_num(newval, &X);
3284 
3285 	/*
3286 	 * state->sumX's dscale is the maximum dscale of any of the inputs.
3287 	 * Removing the last input with that dscale would require us to recompute
3288 	 * the maximum dscale of the *remaining* inputs, which we cannot do unless
3289 	 * no more non-NaN inputs remain at all.  So we report a failure instead,
3290 	 * and force the aggregation to be redone from scratch.
3291 	 */
3292 	if (X.dscale == state->maxScale)
3293 	{
3294 		if (state->maxScaleCount > 1 || state->maxScale == 0)
3295 		{
3296 			/*
3297 			 * Some remaining inputs have same dscale, or dscale hasn't gotten
3298 			 * above zero anyway
3299 			 */
3300 			state->maxScaleCount--;
3301 		}
3302 		else if (state->N == 1)
3303 		{
3304 			/* No remaining non-NaN inputs at all, so reset maxScale */
3305 			state->maxScale = 0;
3306 			state->maxScaleCount = 0;
3307 		}
3308 		else
3309 		{
3310 			/* Correct new maxScale is uncertain, must fail */
3311 			return false;
3312 		}
3313 	}
3314 
3315 	/* if we need X^2, calculate that in short-lived context */
3316 	if (state->calcSumX2)
3317 	{
3318 		init_var(&X2);
3319 		mul_var(&X, &X, &X2, X.dscale * 2);
3320 	}
3321 
3322 	/* The rest of this needs to work in the aggregate context */
3323 	old_context = MemoryContextSwitchTo(state->agg_context);
3324 
3325 	if (state->N-- > 1)
3326 	{
3327 		/* De-accumulate sums */
3328 		sub_var(&(state->sumX), &X, &(state->sumX));
3329 
3330 		if (state->calcSumX2)
3331 			sub_var(&(state->sumX2), &X2, &(state->sumX2));
3332 	}
3333 	else
3334 	{
3335 		/* Sums will be reset by next call to do_numeric_accum */
3336 		Assert(state->N == 0);
3337 	}
3338 
3339 	MemoryContextSwitchTo(old_context);
3340 
3341 	return true;
3342 }
3343 
3344 /*
3345  * Generic transition function for numeric aggregates that require sumX2.
3346  */
3347 Datum
numeric_accum(PG_FUNCTION_ARGS)3348 numeric_accum(PG_FUNCTION_ARGS)
3349 {
3350 	NumericAggState *state;
3351 
3352 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
3353 
3354 	/* Create the state data on the first call */
3355 	if (state == NULL)
3356 		state = makeNumericAggState(fcinfo, true);
3357 
3358 	if (!PG_ARGISNULL(1))
3359 		do_numeric_accum(state, PG_GETARG_NUMERIC(1));
3360 
3361 	PG_RETURN_POINTER(state);
3362 }
3363 
3364 /*
3365  * Generic combine function for numeric aggregates which require sumX2
3366  */
3367 Datum
numeric_combine(PG_FUNCTION_ARGS)3368 numeric_combine(PG_FUNCTION_ARGS)
3369 {
3370 	NumericAggState *state1;
3371 	NumericAggState *state2;
3372 	MemoryContext agg_context;
3373 	MemoryContext old_context;
3374 
3375 	if (!AggCheckCallContext(fcinfo, &agg_context))
3376 		elog(ERROR, "aggregate function called in non-aggregate context");
3377 
3378 	state1 = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
3379 	state2 = PG_ARGISNULL(1) ? NULL : (NumericAggState *) PG_GETARG_POINTER(1);
3380 
3381 	if (state2 == NULL)
3382 		PG_RETURN_POINTER(state1);
3383 
3384 	/* manually copy all fields from state2 to state1 */
3385 	if (state1 == NULL)
3386 	{
3387 		old_context = MemoryContextSwitchTo(agg_context);
3388 
3389 		state1 = makeNumericAggStateCurrentContext(true);
3390 		state1->N = state2->N;
3391 		state1->NaNcount = state2->NaNcount;
3392 		state1->maxScale = state2->maxScale;
3393 		state1->maxScaleCount = state2->maxScaleCount;
3394 
3395 		init_var(&state1->sumX);
3396 		set_var_from_var(&state2->sumX, &state1->sumX);
3397 
3398 		init_var(&state1->sumX2);
3399 		set_var_from_var(&state2->sumX2, &state1->sumX2);
3400 
3401 		MemoryContextSwitchTo(old_context);
3402 
3403 		PG_RETURN_POINTER(state1);
3404 	}
3405 
3406 	state1->N += state2->N;
3407 	state1->NaNcount += state2->NaNcount;
3408 
3409 	if (state2->N > 0)
3410 	{
3411 		/*
3412 		 * These are currently only needed for moving aggregates, but let's do
3413 		 * the right thing anyway...
3414 		 */
3415 		if (state2->maxScale > state1->maxScale)
3416 		{
3417 			state1->maxScale = state2->maxScale;
3418 			state1->maxScaleCount = state2->maxScaleCount;
3419 		}
3420 		else if (state2->maxScale == state1->maxScale)
3421 			state1->maxScaleCount += state2->maxScaleCount;
3422 
3423 		/* The rest of this needs to work in the aggregate context */
3424 		old_context = MemoryContextSwitchTo(agg_context);
3425 
3426 		/* Accumulate sums */
3427 		add_var(&(state1->sumX), &(state2->sumX), &(state1->sumX));
3428 		add_var(&(state1->sumX2), &(state2->sumX2), &(state1->sumX2));
3429 
3430 		MemoryContextSwitchTo(old_context);
3431 	}
3432 	PG_RETURN_POINTER(state1);
3433 }
3434 
3435 /*
3436  * Generic transition function for numeric aggregates that don't require sumX2.
3437  */
3438 Datum
numeric_avg_accum(PG_FUNCTION_ARGS)3439 numeric_avg_accum(PG_FUNCTION_ARGS)
3440 {
3441 	NumericAggState *state;
3442 
3443 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
3444 
3445 	/* Create the state data on the first call */
3446 	if (state == NULL)
3447 		state = makeNumericAggState(fcinfo, false);
3448 
3449 	if (!PG_ARGISNULL(1))
3450 		do_numeric_accum(state, PG_GETARG_NUMERIC(1));
3451 
3452 	PG_RETURN_POINTER(state);
3453 }
3454 
3455 /*
3456  * Combine function for numeric aggregates which don't require sumX2
3457  */
3458 Datum
numeric_avg_combine(PG_FUNCTION_ARGS)3459 numeric_avg_combine(PG_FUNCTION_ARGS)
3460 {
3461 	NumericAggState *state1;
3462 	NumericAggState *state2;
3463 	MemoryContext agg_context;
3464 	MemoryContext old_context;
3465 
3466 	if (!AggCheckCallContext(fcinfo, &agg_context))
3467 		elog(ERROR, "aggregate function called in non-aggregate context");
3468 
3469 	state1 = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
3470 	state2 = PG_ARGISNULL(1) ? NULL : (NumericAggState *) PG_GETARG_POINTER(1);
3471 
3472 	if (state2 == NULL)
3473 		PG_RETURN_POINTER(state1);
3474 
3475 	/* manually copy all fields from state2 to state1 */
3476 	if (state1 == NULL)
3477 	{
3478 		old_context = MemoryContextSwitchTo(agg_context);
3479 
3480 		state1 = makeNumericAggStateCurrentContext(false);
3481 		state1->N = state2->N;
3482 		state1->NaNcount = state2->NaNcount;
3483 		state1->maxScale = state2->maxScale;
3484 		state1->maxScaleCount = state2->maxScaleCount;
3485 
3486 		init_var(&state1->sumX);
3487 		set_var_from_var(&state2->sumX, &state1->sumX);
3488 
3489 		MemoryContextSwitchTo(old_context);
3490 
3491 		PG_RETURN_POINTER(state1);
3492 	}
3493 
3494 	state1->N += state2->N;
3495 	state1->NaNcount += state2->NaNcount;
3496 
3497 	if (state2->N > 0)
3498 	{
3499 		/*
3500 		 * These are currently only needed for moving aggregates, but let's do
3501 		 * the right thing anyway...
3502 		 */
3503 		if (state2->maxScale > state1->maxScale)
3504 		{
3505 			state1->maxScale = state2->maxScale;
3506 			state1->maxScaleCount = state2->maxScaleCount;
3507 		}
3508 		else if (state2->maxScale == state1->maxScale)
3509 			state1->maxScaleCount += state2->maxScaleCount;
3510 
3511 		/* The rest of this needs to work in the aggregate context */
3512 		old_context = MemoryContextSwitchTo(agg_context);
3513 
3514 		/* Accumulate sums */
3515 		add_var(&(state1->sumX), &(state2->sumX), &(state1->sumX));
3516 
3517 		MemoryContextSwitchTo(old_context);
3518 	}
3519 	PG_RETURN_POINTER(state1);
3520 }
3521 
3522 /*
3523  * numeric_avg_serialize
3524  *		Serialize NumericAggState for numeric aggregates that don't require
3525  *		sumX2.
3526  */
3527 Datum
numeric_avg_serialize(PG_FUNCTION_ARGS)3528 numeric_avg_serialize(PG_FUNCTION_ARGS)
3529 {
3530 	NumericAggState *state;
3531 	StringInfoData buf;
3532 	Datum		temp;
3533 	bytea	   *sumX;
3534 	bytea	   *result;
3535 
3536 	/* Ensure we disallow calling when not in aggregate context */
3537 	if (!AggCheckCallContext(fcinfo, NULL))
3538 		elog(ERROR, "aggregate function called in non-aggregate context");
3539 
3540 	state = (NumericAggState *) PG_GETARG_POINTER(0);
3541 
3542 	/*
3543 	 * This is a little wasteful since make_result converts the NumericVar
3544 	 * into a Numeric and numeric_send converts it back again. Is it worth
3545 	 * splitting the tasks in numeric_send into separate functions to stop
3546 	 * this? Doing so would also remove the fmgr call overhead.
3547 	 */
3548 	temp = DirectFunctionCall1(numeric_send,
3549 							   NumericGetDatum(make_result(&state->sumX)));
3550 	sumX = DatumGetByteaP(temp);
3551 
3552 	pq_begintypsend(&buf);
3553 
3554 	/* N */
3555 	pq_sendint64(&buf, state->N);
3556 
3557 	/* sumX */
3558 	pq_sendbytes(&buf, VARDATA(sumX), VARSIZE(sumX) - VARHDRSZ);
3559 
3560 	/* maxScale */
3561 	pq_sendint(&buf, state->maxScale, 4);
3562 
3563 	/* maxScaleCount */
3564 	pq_sendint64(&buf, state->maxScaleCount);
3565 
3566 	/* NaNcount */
3567 	pq_sendint64(&buf, state->NaNcount);
3568 
3569 	result = pq_endtypsend(&buf);
3570 
3571 	PG_RETURN_BYTEA_P(result);
3572 }
3573 
3574 /*
3575  * numeric_avg_deserialize
3576  *		Deserialize bytea into NumericAggState for numeric aggregates that
3577  *		don't require sumX2.
3578  */
3579 Datum
numeric_avg_deserialize(PG_FUNCTION_ARGS)3580 numeric_avg_deserialize(PG_FUNCTION_ARGS)
3581 {
3582 	bytea	   *sstate;
3583 	NumericAggState *result;
3584 	Datum		temp;
3585 	StringInfoData buf;
3586 
3587 	if (!AggCheckCallContext(fcinfo, NULL))
3588 		elog(ERROR, "aggregate function called in non-aggregate context");
3589 
3590 	sstate = PG_GETARG_BYTEA_P(0);
3591 
3592 	/*
3593 	 * Copy the bytea into a StringInfo so that we can "receive" it using the
3594 	 * standard recv-function infrastructure.
3595 	 */
3596 	initStringInfo(&buf);
3597 	appendBinaryStringInfo(&buf, VARDATA(sstate), VARSIZE(sstate) - VARHDRSZ);
3598 
3599 	result = makeNumericAggStateCurrentContext(false);
3600 
3601 	/* N */
3602 	result->N = pq_getmsgint64(&buf);
3603 
3604 	/* sumX */
3605 	temp = DirectFunctionCall3(numeric_recv,
3606 							   PointerGetDatum(&buf),
3607 							   InvalidOid,
3608 							   -1);
3609 	set_var_from_num(DatumGetNumeric(temp), &result->sumX);
3610 
3611 	/* maxScale */
3612 	result->maxScale = pq_getmsgint(&buf, 4);
3613 
3614 	/* maxScaleCount */
3615 	result->maxScaleCount = pq_getmsgint64(&buf);
3616 
3617 	/* NaNcount */
3618 	result->NaNcount = pq_getmsgint64(&buf);
3619 
3620 	pq_getmsgend(&buf);
3621 	pfree(buf.data);
3622 
3623 	PG_RETURN_POINTER(result);
3624 }
3625 
3626 /*
3627  * numeric_serialize
3628  *		Serialization function for NumericAggState for numeric aggregates that
3629  *		require sumX2.
3630  */
3631 Datum
numeric_serialize(PG_FUNCTION_ARGS)3632 numeric_serialize(PG_FUNCTION_ARGS)
3633 {
3634 	NumericAggState *state;
3635 	StringInfoData buf;
3636 	Datum		temp;
3637 	bytea	   *sumX;
3638 	bytea	   *sumX2;
3639 	bytea	   *result;
3640 
3641 	/* Ensure we disallow calling when not in aggregate context */
3642 	if (!AggCheckCallContext(fcinfo, NULL))
3643 		elog(ERROR, "aggregate function called in non-aggregate context");
3644 
3645 	state = (NumericAggState *) PG_GETARG_POINTER(0);
3646 
3647 	/*
3648 	 * This is a little wasteful since make_result converts the NumericVar
3649 	 * into a Numeric and numeric_send converts it back again. Is it worth
3650 	 * splitting the tasks in numeric_send into separate functions to stop
3651 	 * this? Doing so would also remove the fmgr call overhead.
3652 	 */
3653 	temp = DirectFunctionCall1(numeric_send,
3654 							   NumericGetDatum(make_result(&state->sumX)));
3655 	sumX = DatumGetByteaP(temp);
3656 
3657 	temp = DirectFunctionCall1(numeric_send,
3658 							   NumericGetDatum(make_result(&state->sumX2)));
3659 	sumX2 = DatumGetByteaP(temp);
3660 
3661 	pq_begintypsend(&buf);
3662 
3663 	/* N */
3664 	pq_sendint64(&buf, state->N);
3665 
3666 	/* sumX */
3667 	pq_sendbytes(&buf, VARDATA(sumX), VARSIZE(sumX) - VARHDRSZ);
3668 
3669 	/* sumX2 */
3670 	pq_sendbytes(&buf, VARDATA(sumX2), VARSIZE(sumX2) - VARHDRSZ);
3671 
3672 	/* maxScale */
3673 	pq_sendint(&buf, state->maxScale, 4);
3674 
3675 	/* maxScaleCount */
3676 	pq_sendint64(&buf, state->maxScaleCount);
3677 
3678 	/* NaNcount */
3679 	pq_sendint64(&buf, state->NaNcount);
3680 
3681 	result = pq_endtypsend(&buf);
3682 
3683 	PG_RETURN_BYTEA_P(result);
3684 }
3685 
3686 /*
3687  * numeric_deserialize
3688  *		Deserialization function for NumericAggState for numeric aggregates that
3689  *		require sumX2.
3690  */
3691 Datum
numeric_deserialize(PG_FUNCTION_ARGS)3692 numeric_deserialize(PG_FUNCTION_ARGS)
3693 {
3694 	bytea	   *sstate;
3695 	NumericAggState *result;
3696 	Datum		temp;
3697 	StringInfoData buf;
3698 
3699 	if (!AggCheckCallContext(fcinfo, NULL))
3700 		elog(ERROR, "aggregate function called in non-aggregate context");
3701 
3702 	sstate = PG_GETARG_BYTEA_P(0);
3703 
3704 	/*
3705 	 * Copy the bytea into a StringInfo so that we can "receive" it using the
3706 	 * standard recv-function infrastructure.
3707 	 */
3708 	initStringInfo(&buf);
3709 	appendBinaryStringInfo(&buf, VARDATA(sstate), VARSIZE(sstate) - VARHDRSZ);
3710 
3711 	result = makeNumericAggStateCurrentContext(false);
3712 
3713 	/* N */
3714 	result->N = pq_getmsgint64(&buf);
3715 
3716 	/* sumX */
3717 	temp = DirectFunctionCall3(numeric_recv,
3718 							   PointerGetDatum(&buf),
3719 							   InvalidOid,
3720 							   -1);
3721 	set_var_from_num(DatumGetNumeric(temp), &result->sumX);
3722 
3723 	/* sumX2 */
3724 	temp = DirectFunctionCall3(numeric_recv,
3725 							   PointerGetDatum(&buf),
3726 							   InvalidOid,
3727 							   -1);
3728 	set_var_from_num(DatumGetNumeric(temp), &result->sumX2);
3729 
3730 	/* maxScale */
3731 	result->maxScale = pq_getmsgint(&buf, 4);
3732 
3733 	/* maxScaleCount */
3734 	result->maxScaleCount = pq_getmsgint64(&buf);
3735 
3736 	/* NaNcount */
3737 	result->NaNcount = pq_getmsgint64(&buf);
3738 
3739 	pq_getmsgend(&buf);
3740 	pfree(buf.data);
3741 
3742 	PG_RETURN_POINTER(result);
3743 }
3744 
3745 /*
3746  * Generic inverse transition function for numeric aggregates
3747  * (with or without requirement for X^2).
3748  */
3749 Datum
numeric_accum_inv(PG_FUNCTION_ARGS)3750 numeric_accum_inv(PG_FUNCTION_ARGS)
3751 {
3752 	NumericAggState *state;
3753 
3754 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
3755 
3756 	/* Should not get here with no state */
3757 	if (state == NULL)
3758 		elog(ERROR, "numeric_accum_inv called with NULL state");
3759 
3760 	if (!PG_ARGISNULL(1))
3761 	{
3762 		/* If we fail to perform the inverse transition, return NULL */
3763 		if (!do_numeric_discard(state, PG_GETARG_NUMERIC(1)))
3764 			PG_RETURN_NULL();
3765 	}
3766 
3767 	PG_RETURN_POINTER(state);
3768 }
3769 
3770 
3771 /*
3772  * Integer data types in general use Numeric accumulators to share code
3773  * and avoid risk of overflow.
3774  *
3775  * However for performance reasons optimized special-purpose accumulator
3776  * routines are used when possible.
3777  *
3778  * On platforms with 128-bit integer support, the 128-bit routines will be
3779  * used when sum(X) or sum(X*X) fit into 128-bit.
3780  *
3781  * For 16 and 32 bit inputs, the N and sum(X) fit into 64-bit so the 64-bit
3782  * accumulators will be used for SUM and AVG of these data types.
3783  */
3784 
3785 #ifdef HAVE_INT128
3786 typedef struct Int128AggState
3787 {
3788 	bool		calcSumX2;		/* if true, calculate sumX2 */
3789 	int64		N;				/* count of processed numbers */
3790 	int128		sumX;			/* sum of processed numbers */
3791 	int128		sumX2;			/* sum of squares of processed numbers */
3792 } Int128AggState;
3793 
3794 /*
3795  * Prepare state data for a 128-bit aggregate function that needs to compute
3796  * sum, count and optionally sum of squares of the input.
3797  */
3798 static Int128AggState *
makeInt128AggState(FunctionCallInfo fcinfo,bool calcSumX2)3799 makeInt128AggState(FunctionCallInfo fcinfo, bool calcSumX2)
3800 {
3801 	Int128AggState *state;
3802 	MemoryContext agg_context;
3803 	MemoryContext old_context;
3804 
3805 	if (!AggCheckCallContext(fcinfo, &agg_context))
3806 		elog(ERROR, "aggregate function called in non-aggregate context");
3807 
3808 	old_context = MemoryContextSwitchTo(agg_context);
3809 
3810 	state = (Int128AggState *) palloc0(sizeof(Int128AggState));
3811 	state->calcSumX2 = calcSumX2;
3812 
3813 	MemoryContextSwitchTo(old_context);
3814 
3815 	return state;
3816 }
3817 
3818 /*
3819  * Like makeInt128AggState(), but allocate the state in the current memory
3820  * context.
3821  */
3822 static Int128AggState *
makeInt128AggStateCurrentContext(bool calcSumX2)3823 makeInt128AggStateCurrentContext(bool calcSumX2)
3824 {
3825 	Int128AggState *state;
3826 
3827 	state = (Int128AggState *) palloc0(sizeof(Int128AggState));
3828 	state->calcSumX2 = calcSumX2;
3829 
3830 	return state;
3831 }
3832 
3833 /*
3834  * Accumulate a new input value for 128-bit aggregate functions.
3835  */
3836 static void
do_int128_accum(Int128AggState * state,int128 newval)3837 do_int128_accum(Int128AggState *state, int128 newval)
3838 {
3839 	if (state->calcSumX2)
3840 		state->sumX2 += newval * newval;
3841 
3842 	state->sumX += newval;
3843 	state->N++;
3844 }
3845 
3846 /*
3847  * Remove an input value from the aggregated state.
3848  */
3849 static void
do_int128_discard(Int128AggState * state,int128 newval)3850 do_int128_discard(Int128AggState *state, int128 newval)
3851 {
3852 	if (state->calcSumX2)
3853 		state->sumX2 -= newval * newval;
3854 
3855 	state->sumX -= newval;
3856 	state->N--;
3857 }
3858 
3859 typedef Int128AggState PolyNumAggState;
3860 #define makePolyNumAggState makeInt128AggState
3861 #define makePolyNumAggStateCurrentContext makeInt128AggStateCurrentContext
3862 #else
3863 typedef NumericAggState PolyNumAggState;
3864 #define makePolyNumAggState makeNumericAggState
3865 #define makePolyNumAggStateCurrentContext makeNumericAggStateCurrentContext
3866 #endif
3867 
3868 Datum
int2_accum(PG_FUNCTION_ARGS)3869 int2_accum(PG_FUNCTION_ARGS)
3870 {
3871 	PolyNumAggState *state;
3872 
3873 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
3874 
3875 	/* Create the state data on the first call */
3876 	if (state == NULL)
3877 		state = makePolyNumAggState(fcinfo, true);
3878 
3879 	if (!PG_ARGISNULL(1))
3880 	{
3881 #ifdef HAVE_INT128
3882 		do_int128_accum(state, (int128) PG_GETARG_INT16(1));
3883 #else
3884 		Numeric		newval;
3885 
3886 		newval = DatumGetNumeric(DirectFunctionCall1(int2_numeric,
3887 													 PG_GETARG_DATUM(1)));
3888 		do_numeric_accum(state, newval);
3889 #endif
3890 	}
3891 
3892 	PG_RETURN_POINTER(state);
3893 }
3894 
3895 Datum
int4_accum(PG_FUNCTION_ARGS)3896 int4_accum(PG_FUNCTION_ARGS)
3897 {
3898 	PolyNumAggState *state;
3899 
3900 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
3901 
3902 	/* Create the state data on the first call */
3903 	if (state == NULL)
3904 		state = makePolyNumAggState(fcinfo, true);
3905 
3906 	if (!PG_ARGISNULL(1))
3907 	{
3908 #ifdef HAVE_INT128
3909 		do_int128_accum(state, (int128) PG_GETARG_INT32(1));
3910 #else
3911 		Numeric		newval;
3912 
3913 		newval = DatumGetNumeric(DirectFunctionCall1(int4_numeric,
3914 													 PG_GETARG_DATUM(1)));
3915 		do_numeric_accum(state, newval);
3916 #endif
3917 	}
3918 
3919 	PG_RETURN_POINTER(state);
3920 }
3921 
3922 Datum
int8_accum(PG_FUNCTION_ARGS)3923 int8_accum(PG_FUNCTION_ARGS)
3924 {
3925 	NumericAggState *state;
3926 
3927 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
3928 
3929 	/* Create the state data on the first call */
3930 	if (state == NULL)
3931 		state = makeNumericAggState(fcinfo, true);
3932 
3933 	if (!PG_ARGISNULL(1))
3934 	{
3935 		Numeric		newval;
3936 
3937 		newval = DatumGetNumeric(DirectFunctionCall1(int8_numeric,
3938 													 PG_GETARG_DATUM(1)));
3939 		do_numeric_accum(state, newval);
3940 	}
3941 
3942 	PG_RETURN_POINTER(state);
3943 }
3944 
3945 /*
3946  * Combine function for numeric aggregates which require sumX2
3947  */
3948 Datum
numeric_poly_combine(PG_FUNCTION_ARGS)3949 numeric_poly_combine(PG_FUNCTION_ARGS)
3950 {
3951 	PolyNumAggState *state1;
3952 	PolyNumAggState *state2;
3953 	MemoryContext agg_context;
3954 	MemoryContext old_context;
3955 
3956 	if (!AggCheckCallContext(fcinfo, &agg_context))
3957 		elog(ERROR, "aggregate function called in non-aggregate context");
3958 
3959 	state1 = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
3960 	state2 = PG_ARGISNULL(1) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(1);
3961 
3962 	if (state2 == NULL)
3963 		PG_RETURN_POINTER(state1);
3964 
3965 	/* manually copy all fields from state2 to state1 */
3966 	if (state1 == NULL)
3967 	{
3968 		old_context = MemoryContextSwitchTo(agg_context);
3969 
3970 		state1 = makePolyNumAggState(fcinfo, true);
3971 		state1->N = state2->N;
3972 
3973 #ifdef HAVE_INT128
3974 		state1->sumX = state2->sumX;
3975 		state1->sumX2 = state2->sumX2;
3976 #else
3977 		init_var(&(state1->sumX));
3978 		set_var_from_var(&(state2->sumX), &(state1->sumX));
3979 
3980 		init_var(&state1->sumX2);
3981 		set_var_from_var(&(state2->sumX2), &(state1->sumX2));
3982 #endif
3983 
3984 		MemoryContextSwitchTo(old_context);
3985 
3986 		PG_RETURN_POINTER(state1);
3987 	}
3988 
3989 	if (state2->N > 0)
3990 	{
3991 		state1->N += state2->N;
3992 
3993 #ifdef HAVE_INT128
3994 		state1->sumX += state2->sumX;
3995 		state1->sumX2 += state2->sumX2;
3996 #else
3997 		/* The rest of this needs to work in the aggregate context */
3998 		old_context = MemoryContextSwitchTo(agg_context);
3999 
4000 		/* Accumulate sums */
4001 		add_var(&(state1->sumX), &(state2->sumX), &(state1->sumX));
4002 		add_var(&(state1->sumX2), &(state2->sumX2), &(state1->sumX2));
4003 
4004 		MemoryContextSwitchTo(old_context);
4005 #endif
4006 
4007 	}
4008 	PG_RETURN_POINTER(state1);
4009 }
4010 
4011 /*
4012  * numeric_poly_serialize
4013  *		Serialize PolyNumAggState into bytea for aggregate functions which
4014  *		require sumX2.
4015  */
4016 Datum
numeric_poly_serialize(PG_FUNCTION_ARGS)4017 numeric_poly_serialize(PG_FUNCTION_ARGS)
4018 {
4019 	PolyNumAggState *state;
4020 	StringInfoData buf;
4021 	bytea	   *sumX;
4022 	bytea	   *sumX2;
4023 	bytea	   *result;
4024 
4025 	/* Ensure we disallow calling when not in aggregate context */
4026 	if (!AggCheckCallContext(fcinfo, NULL))
4027 		elog(ERROR, "aggregate function called in non-aggregate context");
4028 
4029 	state = (PolyNumAggState *) PG_GETARG_POINTER(0);
4030 
4031 	/*
4032 	 * If the platform supports int128 then sumX and sumX2 will be a 128 bit
4033 	 * integer type. Here we'll convert that into a numeric type so that the
4034 	 * combine state is in the same format for both int128 enabled machines
4035 	 * and machines which don't support that type. The logic here is that one
4036 	 * day we might like to send these over to another server for further
4037 	 * processing and we want a standard format to work with.
4038 	 */
4039 	{
4040 		Datum		temp;
4041 
4042 #ifdef HAVE_INT128
4043 		NumericVar	num;
4044 
4045 		init_var(&num);
4046 		int128_to_numericvar(state->sumX, &num);
4047 		temp = DirectFunctionCall1(numeric_send,
4048 								   NumericGetDatum(make_result(&num)));
4049 		sumX = DatumGetByteaP(temp);
4050 
4051 		int128_to_numericvar(state->sumX2, &num);
4052 		temp = DirectFunctionCall1(numeric_send,
4053 								   NumericGetDatum(make_result(&num)));
4054 		sumX2 = DatumGetByteaP(temp);
4055 		free_var(&num);
4056 #else
4057 		temp = DirectFunctionCall1(numeric_send,
4058 								 NumericGetDatum(make_result(&state->sumX)));
4059 		sumX = DatumGetByteaP(temp);
4060 
4061 		temp = DirectFunctionCall1(numeric_send,
4062 								NumericGetDatum(make_result(&state->sumX2)));
4063 		sumX2 = DatumGetByteaP(temp);
4064 #endif
4065 	}
4066 
4067 	pq_begintypsend(&buf);
4068 
4069 	/* N */
4070 	pq_sendint64(&buf, state->N);
4071 
4072 	/* sumX */
4073 	pq_sendbytes(&buf, VARDATA(sumX), VARSIZE(sumX) - VARHDRSZ);
4074 
4075 	/* sumX2 */
4076 	pq_sendbytes(&buf, VARDATA(sumX2), VARSIZE(sumX2) - VARHDRSZ);
4077 
4078 	result = pq_endtypsend(&buf);
4079 
4080 	PG_RETURN_BYTEA_P(result);
4081 }
4082 
4083 /*
4084  * numeric_poly_deserialize
4085  *		Deserialize PolyNumAggState from bytea for aggregate functions which
4086  *		require sumX2.
4087  */
4088 Datum
numeric_poly_deserialize(PG_FUNCTION_ARGS)4089 numeric_poly_deserialize(PG_FUNCTION_ARGS)
4090 {
4091 	bytea	   *sstate;
4092 	PolyNumAggState *result;
4093 	Datum		sumX;
4094 	Datum		sumX2;
4095 	StringInfoData buf;
4096 
4097 	if (!AggCheckCallContext(fcinfo, NULL))
4098 		elog(ERROR, "aggregate function called in non-aggregate context");
4099 
4100 	sstate = PG_GETARG_BYTEA_P(0);
4101 
4102 	/*
4103 	 * Copy the bytea into a StringInfo so that we can "receive" it using the
4104 	 * standard recv-function infrastructure.
4105 	 */
4106 	initStringInfo(&buf);
4107 	appendBinaryStringInfo(&buf, VARDATA(sstate), VARSIZE(sstate) - VARHDRSZ);
4108 
4109 	result = makePolyNumAggStateCurrentContext(false);
4110 
4111 	/* N */
4112 	result->N = pq_getmsgint64(&buf);
4113 
4114 	/* sumX */
4115 	sumX = DirectFunctionCall3(numeric_recv,
4116 							   PointerGetDatum(&buf),
4117 							   InvalidOid,
4118 							   -1);
4119 
4120 	/* sumX2 */
4121 	sumX2 = DirectFunctionCall3(numeric_recv,
4122 								PointerGetDatum(&buf),
4123 								InvalidOid,
4124 								-1);
4125 
4126 #ifdef HAVE_INT128
4127 	{
4128 		NumericVar	num;
4129 
4130 		init_var(&num);
4131 		set_var_from_num(DatumGetNumeric(sumX), &num);
4132 		numericvar_to_int128(&num, &result->sumX);
4133 
4134 		set_var_from_num(DatumGetNumeric(sumX2), &num);
4135 		numericvar_to_int128(&num, &result->sumX2);
4136 
4137 		free_var(&num);
4138 	}
4139 #else
4140 	set_var_from_num(DatumGetNumeric(sumX), &result->sumX);
4141 	set_var_from_num(DatumGetNumeric(sumX2), &result->sumX2);
4142 #endif
4143 
4144 	pq_getmsgend(&buf);
4145 	pfree(buf.data);
4146 
4147 	PG_RETURN_POINTER(result);
4148 }
4149 
4150 /*
4151  * Transition function for int8 input when we don't need sumX2.
4152  */
4153 Datum
int8_avg_accum(PG_FUNCTION_ARGS)4154 int8_avg_accum(PG_FUNCTION_ARGS)
4155 {
4156 	PolyNumAggState *state;
4157 
4158 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
4159 
4160 	/* Create the state data on the first call */
4161 	if (state == NULL)
4162 		state = makePolyNumAggState(fcinfo, false);
4163 
4164 	if (!PG_ARGISNULL(1))
4165 	{
4166 #ifdef HAVE_INT128
4167 		do_int128_accum(state, (int128) PG_GETARG_INT64(1));
4168 #else
4169 		Numeric		newval;
4170 
4171 		newval = DatumGetNumeric(DirectFunctionCall1(int8_numeric,
4172 													 PG_GETARG_DATUM(1)));
4173 		do_numeric_accum(state, newval);
4174 #endif
4175 	}
4176 
4177 	PG_RETURN_POINTER(state);
4178 }
4179 
4180 /*
4181  * Combine function for PolyNumAggState for aggregates which don't require
4182  * sumX2
4183  */
4184 Datum
int8_avg_combine(PG_FUNCTION_ARGS)4185 int8_avg_combine(PG_FUNCTION_ARGS)
4186 {
4187 	PolyNumAggState *state1;
4188 	PolyNumAggState *state2;
4189 	MemoryContext agg_context;
4190 	MemoryContext old_context;
4191 
4192 	if (!AggCheckCallContext(fcinfo, &agg_context))
4193 		elog(ERROR, "aggregate function called in non-aggregate context");
4194 
4195 	state1 = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
4196 	state2 = PG_ARGISNULL(1) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(1);
4197 
4198 	if (state2 == NULL)
4199 		PG_RETURN_POINTER(state1);
4200 
4201 	/* manually copy all fields from state2 to state1 */
4202 	if (state1 == NULL)
4203 	{
4204 		old_context = MemoryContextSwitchTo(agg_context);
4205 
4206 		state1 = makePolyNumAggState(fcinfo, false);
4207 		state1->N = state2->N;
4208 
4209 #ifdef HAVE_INT128
4210 		state1->sumX = state2->sumX;
4211 #else
4212 		init_var(&state1->sumX);
4213 		set_var_from_var(&state2->sumX, &state1->sumX);
4214 #endif
4215 		MemoryContextSwitchTo(old_context);
4216 
4217 		PG_RETURN_POINTER(state1);
4218 	}
4219 
4220 	if (state2->N > 0)
4221 	{
4222 		state1->N += state2->N;
4223 
4224 #ifdef HAVE_INT128
4225 		state1->sumX += state2->sumX;
4226 #else
4227 		/* The rest of this needs to work in the aggregate context */
4228 		old_context = MemoryContextSwitchTo(agg_context);
4229 
4230 		/* Accumulate sums */
4231 		add_var(&(state1->sumX), &(state2->sumX), &(state1->sumX));
4232 
4233 		MemoryContextSwitchTo(old_context);
4234 #endif
4235 
4236 	}
4237 	PG_RETURN_POINTER(state1);
4238 }
4239 
4240 /*
4241  * int8_avg_serialize
4242  *		Serialize PolyNumAggState into bytea using the standard
4243  *		recv-function infrastructure.
4244  */
4245 Datum
int8_avg_serialize(PG_FUNCTION_ARGS)4246 int8_avg_serialize(PG_FUNCTION_ARGS)
4247 {
4248 	PolyNumAggState *state;
4249 	StringInfoData buf;
4250 	bytea	   *sumX;
4251 	bytea	   *result;
4252 
4253 	/* Ensure we disallow calling when not in aggregate context */
4254 	if (!AggCheckCallContext(fcinfo, NULL))
4255 		elog(ERROR, "aggregate function called in non-aggregate context");
4256 
4257 	state = (PolyNumAggState *) PG_GETARG_POINTER(0);
4258 
4259 	/*
4260 	 * If the platform supports int128 then sumX will be a 128 integer type.
4261 	 * Here we'll convert that into a numeric type so that the combine state
4262 	 * is in the same format for both int128 enabled machines and machines
4263 	 * which don't support that type. The logic here is that one day we might
4264 	 * like to send these over to another server for further processing and we
4265 	 * want a standard format to work with.
4266 	 */
4267 	{
4268 		Datum		temp;
4269 #ifdef HAVE_INT128
4270 		NumericVar	num;
4271 
4272 		init_var(&num);
4273 		int128_to_numericvar(state->sumX, &num);
4274 		temp = DirectFunctionCall1(numeric_send,
4275 								   NumericGetDatum(make_result(&num)));
4276 		free_var(&num);
4277 		sumX = DatumGetByteaP(temp);
4278 #else
4279 		temp = DirectFunctionCall1(numeric_send,
4280 								 NumericGetDatum(make_result(&state->sumX)));
4281 		sumX = DatumGetByteaP(temp);
4282 #endif
4283 	}
4284 
4285 	pq_begintypsend(&buf);
4286 
4287 	/* N */
4288 	pq_sendint64(&buf, state->N);
4289 
4290 	/* sumX */
4291 	pq_sendbytes(&buf, VARDATA(sumX), VARSIZE(sumX) - VARHDRSZ);
4292 
4293 	result = pq_endtypsend(&buf);
4294 
4295 	PG_RETURN_BYTEA_P(result);
4296 }
4297 
4298 /*
4299  * int8_avg_deserialize
4300  *		Deserialize bytea back into PolyNumAggState.
4301  */
4302 Datum
int8_avg_deserialize(PG_FUNCTION_ARGS)4303 int8_avg_deserialize(PG_FUNCTION_ARGS)
4304 {
4305 	bytea	   *sstate;
4306 	PolyNumAggState *result;
4307 	StringInfoData buf;
4308 	Datum		temp;
4309 
4310 	if (!AggCheckCallContext(fcinfo, NULL))
4311 		elog(ERROR, "aggregate function called in non-aggregate context");
4312 
4313 	sstate = PG_GETARG_BYTEA_P(0);
4314 
4315 	/*
4316 	 * Copy the bytea into a StringInfo so that we can "receive" it using the
4317 	 * standard recv-function infrastructure.
4318 	 */
4319 	initStringInfo(&buf);
4320 	appendBinaryStringInfo(&buf, VARDATA(sstate), VARSIZE(sstate) - VARHDRSZ);
4321 
4322 	result = makePolyNumAggStateCurrentContext(false);
4323 
4324 	/* N */
4325 	result->N = pq_getmsgint64(&buf);
4326 
4327 	/* sumX */
4328 	temp = DirectFunctionCall3(numeric_recv,
4329 							   PointerGetDatum(&buf),
4330 							   InvalidOid,
4331 							   -1);
4332 
4333 #ifdef HAVE_INT128
4334 	{
4335 		NumericVar	num;
4336 
4337 		init_var(&num);
4338 		set_var_from_num(DatumGetNumeric(temp), &num);
4339 		numericvar_to_int128(&num, &result->sumX);
4340 		free_var(&num);
4341 	}
4342 #else
4343 	set_var_from_num(DatumGetNumeric(temp), &result->sumX);
4344 #endif
4345 
4346 	pq_getmsgend(&buf);
4347 	pfree(buf.data);
4348 
4349 	PG_RETURN_POINTER(result);
4350 }
4351 
4352 /*
4353  * Inverse transition functions to go with the above.
4354  */
4355 
4356 Datum
int2_accum_inv(PG_FUNCTION_ARGS)4357 int2_accum_inv(PG_FUNCTION_ARGS)
4358 {
4359 	PolyNumAggState *state;
4360 
4361 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
4362 
4363 	/* Should not get here with no state */
4364 	if (state == NULL)
4365 		elog(ERROR, "int2_accum_inv called with NULL state");
4366 
4367 	if (!PG_ARGISNULL(1))
4368 	{
4369 #ifdef HAVE_INT128
4370 		do_int128_discard(state, (int128) PG_GETARG_INT16(1));
4371 #else
4372 		Numeric		newval;
4373 
4374 		newval = DatumGetNumeric(DirectFunctionCall1(int2_numeric,
4375 													 PG_GETARG_DATUM(1)));
4376 
4377 		/* Should never fail, all inputs have dscale 0 */
4378 		if (!do_numeric_discard(state, newval))
4379 			elog(ERROR, "do_numeric_discard failed unexpectedly");
4380 #endif
4381 	}
4382 
4383 	PG_RETURN_POINTER(state);
4384 }
4385 
4386 Datum
int4_accum_inv(PG_FUNCTION_ARGS)4387 int4_accum_inv(PG_FUNCTION_ARGS)
4388 {
4389 	PolyNumAggState *state;
4390 
4391 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
4392 
4393 	/* Should not get here with no state */
4394 	if (state == NULL)
4395 		elog(ERROR, "int4_accum_inv called with NULL state");
4396 
4397 	if (!PG_ARGISNULL(1))
4398 	{
4399 #ifdef HAVE_INT128
4400 		do_int128_discard(state, (int128) PG_GETARG_INT32(1));
4401 #else
4402 		Numeric		newval;
4403 
4404 		newval = DatumGetNumeric(DirectFunctionCall1(int4_numeric,
4405 													 PG_GETARG_DATUM(1)));
4406 
4407 		/* Should never fail, all inputs have dscale 0 */
4408 		if (!do_numeric_discard(state, newval))
4409 			elog(ERROR, "do_numeric_discard failed unexpectedly");
4410 #endif
4411 	}
4412 
4413 	PG_RETURN_POINTER(state);
4414 }
4415 
4416 Datum
int8_accum_inv(PG_FUNCTION_ARGS)4417 int8_accum_inv(PG_FUNCTION_ARGS)
4418 {
4419 	NumericAggState *state;
4420 
4421 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
4422 
4423 	/* Should not get here with no state */
4424 	if (state == NULL)
4425 		elog(ERROR, "int8_accum_inv called with NULL state");
4426 
4427 	if (!PG_ARGISNULL(1))
4428 	{
4429 		Numeric		newval;
4430 
4431 		newval = DatumGetNumeric(DirectFunctionCall1(int8_numeric,
4432 													 PG_GETARG_DATUM(1)));
4433 
4434 		/* Should never fail, all inputs have dscale 0 */
4435 		if (!do_numeric_discard(state, newval))
4436 			elog(ERROR, "do_numeric_discard failed unexpectedly");
4437 	}
4438 
4439 	PG_RETURN_POINTER(state);
4440 }
4441 
4442 Datum
int8_avg_accum_inv(PG_FUNCTION_ARGS)4443 int8_avg_accum_inv(PG_FUNCTION_ARGS)
4444 {
4445 	PolyNumAggState *state;
4446 
4447 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
4448 
4449 	/* Should not get here with no state */
4450 	if (state == NULL)
4451 		elog(ERROR, "int8_avg_accum_inv called with NULL state");
4452 
4453 	if (!PG_ARGISNULL(1))
4454 	{
4455 #ifdef HAVE_INT128
4456 		do_int128_discard(state, (int128) PG_GETARG_INT64(1));
4457 #else
4458 		Numeric		newval;
4459 
4460 		newval = DatumGetNumeric(DirectFunctionCall1(int8_numeric,
4461 													 PG_GETARG_DATUM(1)));
4462 
4463 		/* Should never fail, all inputs have dscale 0 */
4464 		if (!do_numeric_discard(state, newval))
4465 			elog(ERROR, "do_numeric_discard failed unexpectedly");
4466 #endif
4467 	}
4468 
4469 	PG_RETURN_POINTER(state);
4470 }
4471 
4472 Datum
numeric_poly_sum(PG_FUNCTION_ARGS)4473 numeric_poly_sum(PG_FUNCTION_ARGS)
4474 {
4475 #ifdef HAVE_INT128
4476 	PolyNumAggState *state;
4477 	Numeric		res;
4478 	NumericVar	result;
4479 
4480 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
4481 
4482 	/* If there were no non-null inputs, return NULL */
4483 	if (state == NULL || state->N == 0)
4484 		PG_RETURN_NULL();
4485 
4486 	init_var(&result);
4487 
4488 	int128_to_numericvar(state->sumX, &result);
4489 
4490 	res = make_result(&result);
4491 
4492 	free_var(&result);
4493 
4494 	PG_RETURN_NUMERIC(res);
4495 #else
4496 	return numeric_sum(fcinfo);
4497 #endif
4498 }
4499 
4500 Datum
numeric_poly_avg(PG_FUNCTION_ARGS)4501 numeric_poly_avg(PG_FUNCTION_ARGS)
4502 {
4503 #ifdef HAVE_INT128
4504 	PolyNumAggState *state;
4505 	NumericVar	result;
4506 	Datum		countd,
4507 				sumd;
4508 
4509 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
4510 
4511 	/* If there were no non-null inputs, return NULL */
4512 	if (state == NULL || state->N == 0)
4513 		PG_RETURN_NULL();
4514 
4515 	init_var(&result);
4516 
4517 	int128_to_numericvar(state->sumX, &result);
4518 
4519 	countd = DirectFunctionCall1(int8_numeric,
4520 								 Int64GetDatumFast(state->N));
4521 	sumd = NumericGetDatum(make_result(&result));
4522 
4523 	free_var(&result);
4524 
4525 	PG_RETURN_DATUM(DirectFunctionCall2(numeric_div, sumd, countd));
4526 #else
4527 	return numeric_avg(fcinfo);
4528 #endif
4529 }
4530 
4531 Datum
numeric_avg(PG_FUNCTION_ARGS)4532 numeric_avg(PG_FUNCTION_ARGS)
4533 {
4534 	NumericAggState *state;
4535 	Datum		N_datum;
4536 	Datum		sumX_datum;
4537 
4538 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
4539 
4540 	/* If there were no non-null inputs, return NULL */
4541 	if (state == NULL || (state->N + state->NaNcount) == 0)
4542 		PG_RETURN_NULL();
4543 
4544 	if (state->NaNcount > 0)	/* there was at least one NaN input */
4545 		PG_RETURN_NUMERIC(make_result(&const_nan));
4546 
4547 	N_datum = DirectFunctionCall1(int8_numeric, Int64GetDatum(state->N));
4548 	sumX_datum = NumericGetDatum(make_result(&state->sumX));
4549 
4550 	PG_RETURN_DATUM(DirectFunctionCall2(numeric_div, sumX_datum, N_datum));
4551 }
4552 
4553 Datum
numeric_sum(PG_FUNCTION_ARGS)4554 numeric_sum(PG_FUNCTION_ARGS)
4555 {
4556 	NumericAggState *state;
4557 
4558 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
4559 
4560 	/* If there were no non-null inputs, return NULL */
4561 	if (state == NULL || (state->N + state->NaNcount) == 0)
4562 		PG_RETURN_NULL();
4563 
4564 	if (state->NaNcount > 0)	/* there was at least one NaN input */
4565 		PG_RETURN_NUMERIC(make_result(&const_nan));
4566 
4567 	PG_RETURN_NUMERIC(make_result(&(state->sumX)));
4568 }
4569 
4570 /*
4571  * Workhorse routine for the standard deviance and variance
4572  * aggregates. 'state' is aggregate's transition state.
4573  * 'variance' specifies whether we should calculate the
4574  * variance or the standard deviation. 'sample' indicates whether the
4575  * caller is interested in the sample or the population
4576  * variance/stddev.
4577  *
4578  * If appropriate variance statistic is undefined for the input,
4579  * *is_null is set to true and NULL is returned.
4580  */
4581 static Numeric
numeric_stddev_internal(NumericAggState * state,bool variance,bool sample,bool * is_null)4582 numeric_stddev_internal(NumericAggState *state,
4583 						bool variance, bool sample,
4584 						bool *is_null)
4585 {
4586 	Numeric		res;
4587 	NumericVar	vN,
4588 				vsumX,
4589 				vsumX2,
4590 				vNminus1;
4591 	NumericVar *comp;
4592 	int			rscale;
4593 
4594 	/* Deal with empty input and NaN-input cases */
4595 	if (state == NULL || (state->N + state->NaNcount) == 0)
4596 	{
4597 		*is_null = true;
4598 		return NULL;
4599 	}
4600 
4601 	*is_null = false;
4602 
4603 	if (state->NaNcount > 0)
4604 		return make_result(&const_nan);
4605 
4606 	init_var(&vN);
4607 	init_var(&vsumX);
4608 	init_var(&vsumX2);
4609 
4610 	int64_to_numericvar(state->N, &vN);
4611 	set_var_from_var(&(state->sumX), &vsumX);
4612 	set_var_from_var(&(state->sumX2), &vsumX2);
4613 
4614 	/*
4615 	 * Sample stddev and variance are undefined when N <= 1; population stddev
4616 	 * is undefined when N == 0. Return NULL in either case.
4617 	 */
4618 	if (sample)
4619 		comp = &const_one;
4620 	else
4621 		comp = &const_zero;
4622 
4623 	if (cmp_var(&vN, comp) <= 0)
4624 	{
4625 		*is_null = true;
4626 		return NULL;
4627 	}
4628 
4629 	init_var(&vNminus1);
4630 	sub_var(&vN, &const_one, &vNminus1);
4631 
4632 	/* compute rscale for mul_var calls */
4633 	rscale = vsumX.dscale * 2;
4634 
4635 	mul_var(&vsumX, &vsumX, &vsumX, rscale);	/* vsumX = sumX * sumX */
4636 	mul_var(&vN, &vsumX2, &vsumX2, rscale);		/* vsumX2 = N * sumX2 */
4637 	sub_var(&vsumX2, &vsumX, &vsumX2);	/* N * sumX2 - sumX * sumX */
4638 
4639 	if (cmp_var(&vsumX2, &const_zero) <= 0)
4640 	{
4641 		/* Watch out for roundoff error producing a negative numerator */
4642 		res = make_result(&const_zero);
4643 	}
4644 	else
4645 	{
4646 		if (sample)
4647 			mul_var(&vN, &vNminus1, &vNminus1, 0);		/* N * (N - 1) */
4648 		else
4649 			mul_var(&vN, &vN, &vNminus1, 0);	/* N * N */
4650 		rscale = select_div_scale(&vsumX2, &vNminus1);
4651 		div_var(&vsumX2, &vNminus1, &vsumX, rscale, true);		/* variance */
4652 		if (!variance)
4653 			sqrt_var(&vsumX, &vsumX, rscale);	/* stddev */
4654 
4655 		res = make_result(&vsumX);
4656 	}
4657 
4658 	free_var(&vNminus1);
4659 	free_var(&vsumX);
4660 	free_var(&vsumX2);
4661 
4662 	return res;
4663 }
4664 
4665 Datum
numeric_var_samp(PG_FUNCTION_ARGS)4666 numeric_var_samp(PG_FUNCTION_ARGS)
4667 {
4668 	NumericAggState *state;
4669 	Numeric		res;
4670 	bool		is_null;
4671 
4672 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
4673 
4674 	res = numeric_stddev_internal(state, true, true, &is_null);
4675 
4676 	if (is_null)
4677 		PG_RETURN_NULL();
4678 	else
4679 		PG_RETURN_NUMERIC(res);
4680 }
4681 
4682 Datum
numeric_stddev_samp(PG_FUNCTION_ARGS)4683 numeric_stddev_samp(PG_FUNCTION_ARGS)
4684 {
4685 	NumericAggState *state;
4686 	Numeric		res;
4687 	bool		is_null;
4688 
4689 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
4690 
4691 	res = numeric_stddev_internal(state, false, true, &is_null);
4692 
4693 	if (is_null)
4694 		PG_RETURN_NULL();
4695 	else
4696 		PG_RETURN_NUMERIC(res);
4697 }
4698 
4699 Datum
numeric_var_pop(PG_FUNCTION_ARGS)4700 numeric_var_pop(PG_FUNCTION_ARGS)
4701 {
4702 	NumericAggState *state;
4703 	Numeric		res;
4704 	bool		is_null;
4705 
4706 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
4707 
4708 	res = numeric_stddev_internal(state, true, false, &is_null);
4709 
4710 	if (is_null)
4711 		PG_RETURN_NULL();
4712 	else
4713 		PG_RETURN_NUMERIC(res);
4714 }
4715 
4716 Datum
numeric_stddev_pop(PG_FUNCTION_ARGS)4717 numeric_stddev_pop(PG_FUNCTION_ARGS)
4718 {
4719 	NumericAggState *state;
4720 	Numeric		res;
4721 	bool		is_null;
4722 
4723 	state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
4724 
4725 	res = numeric_stddev_internal(state, false, false, &is_null);
4726 
4727 	if (is_null)
4728 		PG_RETURN_NULL();
4729 	else
4730 		PG_RETURN_NUMERIC(res);
4731 }
4732 
#ifdef HAVE_INT128
/*
 * numeric_poly_stddev_internal
 *		Variance/stddev workhorse for Int128AggState.  The int128 sums are
 *		converted into a temporary NumericAggState, which is then handed to
 *		numeric_stddev_internal together with the 'variance', 'sample' and
 *		'is_null' arguments.
 *
 * A NULL 'state' is treated as a state with N = 0.
 */
static Numeric
numeric_poly_stddev_internal(Int128AggState *state,
							 bool variance, bool sample,
							 bool *is_null)
{
	NumericAggState converted;
	Numeric		outcome;

	/* Only the fields filled in below are set up in the temporary state */
	converted.N = (state != NULL) ? state->N : 0;
	converted.NaNcount = 0;
	converted.agg_context = NULL;
	init_var(&converted.sumX);
	init_var(&converted.sumX2);

	if (state != NULL)
	{
		int128_to_numericvar(state->sumX, &converted.sumX);
		int128_to_numericvar(state->sumX2, &converted.sumX2);
	}

	outcome = numeric_stddev_internal(&converted, variance, sample, is_null);

	free_var(&converted.sumX);
	free_var(&converted.sumX2);

	return outcome;
}
#endif
4766 
4767 Datum
numeric_poly_var_samp(PG_FUNCTION_ARGS)4768 numeric_poly_var_samp(PG_FUNCTION_ARGS)
4769 {
4770 #ifdef HAVE_INT128
4771 	PolyNumAggState *state;
4772 	Numeric		res;
4773 	bool		is_null;
4774 
4775 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
4776 
4777 	res = numeric_poly_stddev_internal(state, true, true, &is_null);
4778 
4779 	if (is_null)
4780 		PG_RETURN_NULL();
4781 	else
4782 		PG_RETURN_NUMERIC(res);
4783 #else
4784 	return numeric_var_samp(fcinfo);
4785 #endif
4786 }
4787 
4788 Datum
numeric_poly_stddev_samp(PG_FUNCTION_ARGS)4789 numeric_poly_stddev_samp(PG_FUNCTION_ARGS)
4790 {
4791 #ifdef HAVE_INT128
4792 	PolyNumAggState *state;
4793 	Numeric		res;
4794 	bool		is_null;
4795 
4796 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
4797 
4798 	res = numeric_poly_stddev_internal(state, false, true, &is_null);
4799 
4800 	if (is_null)
4801 		PG_RETURN_NULL();
4802 	else
4803 		PG_RETURN_NUMERIC(res);
4804 #else
4805 	return numeric_stddev_samp(fcinfo);
4806 #endif
4807 }
4808 
4809 Datum
numeric_poly_var_pop(PG_FUNCTION_ARGS)4810 numeric_poly_var_pop(PG_FUNCTION_ARGS)
4811 {
4812 #ifdef HAVE_INT128
4813 	PolyNumAggState *state;
4814 	Numeric		res;
4815 	bool		is_null;
4816 
4817 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
4818 
4819 	res = numeric_poly_stddev_internal(state, true, false, &is_null);
4820 
4821 	if (is_null)
4822 		PG_RETURN_NULL();
4823 	else
4824 		PG_RETURN_NUMERIC(res);
4825 #else
4826 	return numeric_var_pop(fcinfo);
4827 #endif
4828 }
4829 
4830 Datum
numeric_poly_stddev_pop(PG_FUNCTION_ARGS)4831 numeric_poly_stddev_pop(PG_FUNCTION_ARGS)
4832 {
4833 #ifdef HAVE_INT128
4834 	PolyNumAggState *state;
4835 	Numeric		res;
4836 	bool		is_null;
4837 
4838 	state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
4839 
4840 	res = numeric_poly_stddev_internal(state, false, false, &is_null);
4841 
4842 	if (is_null)
4843 		PG_RETURN_NULL();
4844 	else
4845 		PG_RETURN_NUMERIC(res);
4846 #else
4847 	return numeric_stddev_pop(fcinfo);
4848 #endif
4849 }
4850 
4851 /*
4852  * SUM transition functions for integer datatypes.
4853  *
4854  * To avoid overflow, we use accumulators wider than the input datatype.
4855  * A Numeric accumulator is needed for int8 input; for int4 and int2
4856  * inputs, we use int8 accumulators which should be sufficient for practical
4857  * purposes.  (The latter two therefore don't really belong in this file,
4858  * but we keep them here anyway.)
4859  *
4860  * Because SQL defines the SUM() of no values to be NULL, not zero,
4861  * the initial condition of the transition data value needs to be NULL. This
4862  * means we can't rely on ExecAgg to automatically insert the first non-null
4863  * data value into the transition data: it doesn't know how to do the type
4864  * conversion.  The upshot is that these routines have to be marked non-strict
4865  * and handle substitution of the first non-null input themselves.
4866  *
4867  * Note: these functions are used only in plain aggregation mode.
4868  * In moving-aggregate mode, we use intX_avg_accum and intX_avg_accum_inv.
4869  */
4870 
4871 Datum
int2_sum(PG_FUNCTION_ARGS)4872 int2_sum(PG_FUNCTION_ARGS)
4873 {
4874 	int64		newval;
4875 
4876 	if (PG_ARGISNULL(0))
4877 	{
4878 		/* No non-null input seen so far... */
4879 		if (PG_ARGISNULL(1))
4880 			PG_RETURN_NULL();	/* still no non-null */
4881 		/* This is the first non-null input. */
4882 		newval = (int64) PG_GETARG_INT16(1);
4883 		PG_RETURN_INT64(newval);
4884 	}
4885 
4886 	/*
4887 	 * If we're invoked as an aggregate, we can cheat and modify our first
4888 	 * parameter in-place to avoid palloc overhead. If not, we need to return
4889 	 * the new value of the transition variable. (If int8 is pass-by-value,
4890 	 * then of course this is useless as well as incorrect, so just ifdef it
4891 	 * out.)
4892 	 */
4893 #ifndef USE_FLOAT8_BYVAL		/* controls int8 too */
4894 	if (AggCheckCallContext(fcinfo, NULL))
4895 	{
4896 		int64	   *oldsum = (int64 *) PG_GETARG_POINTER(0);
4897 
4898 		/* Leave the running sum unchanged in the new input is null */
4899 		if (!PG_ARGISNULL(1))
4900 			*oldsum = *oldsum + (int64) PG_GETARG_INT16(1);
4901 
4902 		PG_RETURN_POINTER(oldsum);
4903 	}
4904 	else
4905 #endif
4906 	{
4907 		int64		oldsum = PG_GETARG_INT64(0);
4908 
4909 		/* Leave sum unchanged if new input is null. */
4910 		if (PG_ARGISNULL(1))
4911 			PG_RETURN_INT64(oldsum);
4912 
4913 		/* OK to do the addition. */
4914 		newval = oldsum + (int64) PG_GETARG_INT16(1);
4915 
4916 		PG_RETURN_INT64(newval);
4917 	}
4918 }
4919 
4920 Datum
int4_sum(PG_FUNCTION_ARGS)4921 int4_sum(PG_FUNCTION_ARGS)
4922 {
4923 	int64		newval;
4924 
4925 	if (PG_ARGISNULL(0))
4926 	{
4927 		/* No non-null input seen so far... */
4928 		if (PG_ARGISNULL(1))
4929 			PG_RETURN_NULL();	/* still no non-null */
4930 		/* This is the first non-null input. */
4931 		newval = (int64) PG_GETARG_INT32(1);
4932 		PG_RETURN_INT64(newval);
4933 	}
4934 
4935 	/*
4936 	 * If we're invoked as an aggregate, we can cheat and modify our first
4937 	 * parameter in-place to avoid palloc overhead. If not, we need to return
4938 	 * the new value of the transition variable. (If int8 is pass-by-value,
4939 	 * then of course this is useless as well as incorrect, so just ifdef it
4940 	 * out.)
4941 	 */
4942 #ifndef USE_FLOAT8_BYVAL		/* controls int8 too */
4943 	if (AggCheckCallContext(fcinfo, NULL))
4944 	{
4945 		int64	   *oldsum = (int64 *) PG_GETARG_POINTER(0);
4946 
4947 		/* Leave the running sum unchanged in the new input is null */
4948 		if (!PG_ARGISNULL(1))
4949 			*oldsum = *oldsum + (int64) PG_GETARG_INT32(1);
4950 
4951 		PG_RETURN_POINTER(oldsum);
4952 	}
4953 	else
4954 #endif
4955 	{
4956 		int64		oldsum = PG_GETARG_INT64(0);
4957 
4958 		/* Leave sum unchanged if new input is null. */
4959 		if (PG_ARGISNULL(1))
4960 			PG_RETURN_INT64(oldsum);
4961 
4962 		/* OK to do the addition. */
4963 		newval = oldsum + (int64) PG_GETARG_INT32(1);
4964 
4965 		PG_RETURN_INT64(newval);
4966 	}
4967 }
4968 
4969 /*
4970  * Note: this function is obsolete, it's no longer used for SUM(int8).
4971  */
4972 Datum
int8_sum(PG_FUNCTION_ARGS)4973 int8_sum(PG_FUNCTION_ARGS)
4974 {
4975 	Numeric		oldsum;
4976 	Datum		newval;
4977 
4978 	if (PG_ARGISNULL(0))
4979 	{
4980 		/* No non-null input seen so far... */
4981 		if (PG_ARGISNULL(1))
4982 			PG_RETURN_NULL();	/* still no non-null */
4983 		/* This is the first non-null input. */
4984 		newval = DirectFunctionCall1(int8_numeric, PG_GETARG_DATUM(1));
4985 		PG_RETURN_DATUM(newval);
4986 	}
4987 
4988 	/*
4989 	 * Note that we cannot special-case the aggregate case here, as we do for
4990 	 * int2_sum and int4_sum: numeric is of variable size, so we cannot modify
4991 	 * our first parameter in-place.
4992 	 */
4993 
4994 	oldsum = PG_GETARG_NUMERIC(0);
4995 
4996 	/* Leave sum unchanged if new input is null. */
4997 	if (PG_ARGISNULL(1))
4998 		PG_RETURN_NUMERIC(oldsum);
4999 
5000 	/* OK to do the addition. */
5001 	newval = DirectFunctionCall1(int8_numeric, PG_GETARG_DATUM(1));
5002 
5003 	PG_RETURN_DATUM(DirectFunctionCall2(numeric_add,
5004 										NumericGetDatum(oldsum), newval));
5005 }
5006 
5007 
5008 /*
5009  * Routines for avg(int2) and avg(int4).  The transition datatype
5010  * is a two-element int8 array, holding count and sum.
5011  *
5012  * These functions are also used for sum(int2) and sum(int4) when
5013  * operating in moving-aggregate mode, since for correct inverse transitions
5014  * we need to count the inputs.
5015  */
5016 
/*
 * Transition data for the int2/int4 avg routines.  The functions using it
 * obtain it by casting the data area of a two-element int8 array, so the
 * two fields must stay exactly two consecutive int64 values.
 */
typedef struct Int8TransTypeData
{
	int64		count;			/* number of inputs accumulated */
	int64		sum;			/* running sum of those inputs */
} Int8TransTypeData;
5022 
5023 Datum
int2_avg_accum(PG_FUNCTION_ARGS)5024 int2_avg_accum(PG_FUNCTION_ARGS)
5025 {
5026 	ArrayType  *transarray;
5027 	int16		newval = PG_GETARG_INT16(1);
5028 	Int8TransTypeData *transdata;
5029 
5030 	/*
5031 	 * If we're invoked as an aggregate, we can cheat and modify our first
5032 	 * parameter in-place to reduce palloc overhead. Otherwise we need to make
5033 	 * a copy of it before scribbling on it.
5034 	 */
5035 	if (AggCheckCallContext(fcinfo, NULL))
5036 		transarray = PG_GETARG_ARRAYTYPE_P(0);
5037 	else
5038 		transarray = PG_GETARG_ARRAYTYPE_P_COPY(0);
5039 
5040 	if (ARR_HASNULL(transarray) ||
5041 		ARR_SIZE(transarray) != ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData))
5042 		elog(ERROR, "expected 2-element int8 array");
5043 
5044 	transdata = (Int8TransTypeData *) ARR_DATA_PTR(transarray);
5045 	transdata->count++;
5046 	transdata->sum += newval;
5047 
5048 	PG_RETURN_ARRAYTYPE_P(transarray);
5049 }
5050 
5051 Datum
int4_avg_accum(PG_FUNCTION_ARGS)5052 int4_avg_accum(PG_FUNCTION_ARGS)
5053 {
5054 	ArrayType  *transarray;
5055 	int32		newval = PG_GETARG_INT32(1);
5056 	Int8TransTypeData *transdata;
5057 
5058 	/*
5059 	 * If we're invoked as an aggregate, we can cheat and modify our first
5060 	 * parameter in-place to reduce palloc overhead. Otherwise we need to make
5061 	 * a copy of it before scribbling on it.
5062 	 */
5063 	if (AggCheckCallContext(fcinfo, NULL))
5064 		transarray = PG_GETARG_ARRAYTYPE_P(0);
5065 	else
5066 		transarray = PG_GETARG_ARRAYTYPE_P_COPY(0);
5067 
5068 	if (ARR_HASNULL(transarray) ||
5069 		ARR_SIZE(transarray) != ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData))
5070 		elog(ERROR, "expected 2-element int8 array");
5071 
5072 	transdata = (Int8TransTypeData *) ARR_DATA_PTR(transarray);
5073 	transdata->count++;
5074 	transdata->sum += newval;
5075 
5076 	PG_RETURN_ARRAYTYPE_P(transarray);
5077 }
5078 
5079 Datum
int4_avg_combine(PG_FUNCTION_ARGS)5080 int4_avg_combine(PG_FUNCTION_ARGS)
5081 {
5082 	ArrayType  *transarray1;
5083 	ArrayType  *transarray2;
5084 	Int8TransTypeData *state1;
5085 	Int8TransTypeData *state2;
5086 
5087 	if (!AggCheckCallContext(fcinfo, NULL))
5088 		elog(ERROR, "aggregate function called in non-aggregate context");
5089 
5090 	transarray1 = PG_GETARG_ARRAYTYPE_P(0);
5091 	transarray2 = PG_GETARG_ARRAYTYPE_P(1);
5092 
5093 	if (ARR_HASNULL(transarray1) ||
5094 		ARR_SIZE(transarray1) != ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData))
5095 		elog(ERROR, "expected 2-element int8 array");
5096 
5097 	if (ARR_HASNULL(transarray2) ||
5098 		ARR_SIZE(transarray2) != ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData))
5099 		elog(ERROR, "expected 2-element int8 array");
5100 
5101 	state1 = (Int8TransTypeData *) ARR_DATA_PTR(transarray1);
5102 	state2 = (Int8TransTypeData *) ARR_DATA_PTR(transarray2);
5103 
5104 	state1->count += state2->count;
5105 	state1->sum += state2->sum;
5106 
5107 	PG_RETURN_ARRAYTYPE_P(transarray1);
5108 }
5109 
5110 Datum
int2_avg_accum_inv(PG_FUNCTION_ARGS)5111 int2_avg_accum_inv(PG_FUNCTION_ARGS)
5112 {
5113 	ArrayType  *transarray;
5114 	int16		newval = PG_GETARG_INT16(1);
5115 	Int8TransTypeData *transdata;
5116 
5117 	/*
5118 	 * If we're invoked as an aggregate, we can cheat and modify our first
5119 	 * parameter in-place to reduce palloc overhead. Otherwise we need to make
5120 	 * a copy of it before scribbling on it.
5121 	 */
5122 	if (AggCheckCallContext(fcinfo, NULL))
5123 		transarray = PG_GETARG_ARRAYTYPE_P(0);
5124 	else
5125 		transarray = PG_GETARG_ARRAYTYPE_P_COPY(0);
5126 
5127 	if (ARR_HASNULL(transarray) ||
5128 		ARR_SIZE(transarray) != ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData))
5129 		elog(ERROR, "expected 2-element int8 array");
5130 
5131 	transdata = (Int8TransTypeData *) ARR_DATA_PTR(transarray);
5132 	transdata->count--;
5133 	transdata->sum -= newval;
5134 
5135 	PG_RETURN_ARRAYTYPE_P(transarray);
5136 }
5137 
5138 Datum
int4_avg_accum_inv(PG_FUNCTION_ARGS)5139 int4_avg_accum_inv(PG_FUNCTION_ARGS)
5140 {
5141 	ArrayType  *transarray;
5142 	int32		newval = PG_GETARG_INT32(1);
5143 	Int8TransTypeData *transdata;
5144 
5145 	/*
5146 	 * If we're invoked as an aggregate, we can cheat and modify our first
5147 	 * parameter in-place to reduce palloc overhead. Otherwise we need to make
5148 	 * a copy of it before scribbling on it.
5149 	 */
5150 	if (AggCheckCallContext(fcinfo, NULL))
5151 		transarray = PG_GETARG_ARRAYTYPE_P(0);
5152 	else
5153 		transarray = PG_GETARG_ARRAYTYPE_P_COPY(0);
5154 
5155 	if (ARR_HASNULL(transarray) ||
5156 		ARR_SIZE(transarray) != ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData))
5157 		elog(ERROR, "expected 2-element int8 array");
5158 
5159 	transdata = (Int8TransTypeData *) ARR_DATA_PTR(transarray);
5160 	transdata->count--;
5161 	transdata->sum -= newval;
5162 
5163 	PG_RETURN_ARRAYTYPE_P(transarray);
5164 }
5165 
5166 Datum
int8_avg(PG_FUNCTION_ARGS)5167 int8_avg(PG_FUNCTION_ARGS)
5168 {
5169 	ArrayType  *transarray = PG_GETARG_ARRAYTYPE_P(0);
5170 	Int8TransTypeData *transdata;
5171 	Datum		countd,
5172 				sumd;
5173 
5174 	if (ARR_HASNULL(transarray) ||
5175 		ARR_SIZE(transarray) != ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData))
5176 		elog(ERROR, "expected 2-element int8 array");
5177 	transdata = (Int8TransTypeData *) ARR_DATA_PTR(transarray);
5178 
5179 	/* SQL defines AVG of no values to be NULL */
5180 	if (transdata->count == 0)
5181 		PG_RETURN_NULL();
5182 
5183 	countd = DirectFunctionCall1(int8_numeric,
5184 								 Int64GetDatumFast(transdata->count));
5185 	sumd = DirectFunctionCall1(int8_numeric,
5186 							   Int64GetDatumFast(transdata->sum));
5187 
5188 	PG_RETURN_DATUM(DirectFunctionCall2(numeric_div, sumd, countd));
5189 }
5190 
5191 /*
5192  * SUM(int2) and SUM(int4) both return int8, so we can use this
5193  * final function for both.
5194  */
5195 Datum
int2int4_sum(PG_FUNCTION_ARGS)5196 int2int4_sum(PG_FUNCTION_ARGS)
5197 {
5198 	ArrayType  *transarray = PG_GETARG_ARRAYTYPE_P(0);
5199 	Int8TransTypeData *transdata;
5200 
5201 	if (ARR_HASNULL(transarray) ||
5202 		ARR_SIZE(transarray) != ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData))
5203 		elog(ERROR, "expected 2-element int8 array");
5204 	transdata = (Int8TransTypeData *) ARR_DATA_PTR(transarray);
5205 
5206 	/* SQL defines SUM of no values to be NULL */
5207 	if (transdata->count == 0)
5208 		PG_RETURN_NULL();
5209 
5210 	PG_RETURN_DATUM(Int64GetDatumFast(transdata->sum));
5211 }
5212 
5213 
5214 /* ----------------------------------------------------------------------
5215  *
5216  * Debug support
5217  *
5218  * ----------------------------------------------------------------------
5219  */
5220 
5221 #ifdef NUMERIC_DEBUG
5222 
5223 /*
5224  * dump_numeric() - Dump a value in the db storage format for debugging
5225  */
5226 static void
dump_numeric(const char * str,Numeric num)5227 dump_numeric(const char *str, Numeric num)
5228 {
5229 	NumericDigit *digits = NUMERIC_DIGITS(num);
5230 	int			ndigits;
5231 	int			i;
5232 
5233 	ndigits = NUMERIC_NDIGITS(num);
5234 
5235 	printf("%s: NUMERIC w=%d d=%d ", str,
5236 		   NUMERIC_WEIGHT(num), NUMERIC_DSCALE(num));
5237 	switch (NUMERIC_SIGN(num))
5238 	{
5239 		case NUMERIC_POS:
5240 			printf("POS");
5241 			break;
5242 		case NUMERIC_NEG:
5243 			printf("NEG");
5244 			break;
5245 		case NUMERIC_NAN:
5246 			printf("NaN");
5247 			break;
5248 		default:
5249 			printf("SIGN=0x%x", NUMERIC_SIGN(num));
5250 			break;
5251 	}
5252 
5253 	for (i = 0; i < ndigits; i++)
5254 		printf(" %0*d", DEC_DIGITS, digits[i]);
5255 	printf("\n");
5256 }
5257 
5258 
5259 /*
5260  * dump_var() - Dump a value in the variable format for debugging
5261  */
5262 static void
dump_var(const char * str,NumericVar * var)5263 dump_var(const char *str, NumericVar *var)
5264 {
5265 	int			i;
5266 
5267 	printf("%s: VAR w=%d d=%d ", str, var->weight, var->dscale);
5268 	switch (var->sign)
5269 	{
5270 		case NUMERIC_POS:
5271 			printf("POS");
5272 			break;
5273 		case NUMERIC_NEG:
5274 			printf("NEG");
5275 			break;
5276 		case NUMERIC_NAN:
5277 			printf("NaN");
5278 			break;
5279 		default:
5280 			printf("SIGN=0x%x", var->sign);
5281 			break;
5282 	}
5283 
5284 	for (i = 0; i < var->ndigits; i++)
5285 		printf(" %0*d", DEC_DIGITS, var->digits[i]);
5286 
5287 	printf("\n");
5288 }
5289 #endif   /* NUMERIC_DEBUG */
5290 
5291 
5292 /* ----------------------------------------------------------------------
5293  *
5294  * Local functions follow
5295  *
5296  * In general, these do not support NaNs --- callers must eliminate
5297  * the possibility of NaN first.  (make_result() is an exception.)
5298  *
5299  * ----------------------------------------------------------------------
5300  */
5301 
5302 
5303 /*
5304  * alloc_var() -
5305  *
5306  *	Allocate a digit buffer of ndigits digits (plus a spare digit for rounding)
5307  */
5308 static void
alloc_var(NumericVar * var,int ndigits)5309 alloc_var(NumericVar *var, int ndigits)
5310 {
5311 	digitbuf_free(var->buf);
5312 	var->buf = digitbuf_alloc(ndigits + 1);
5313 	var->buf[0] = 0;			/* spare digit for rounding */
5314 	var->digits = var->buf + 1;
5315 	var->ndigits = ndigits;
5316 }
5317 
5318 
5319 /*
5320  * free_var() -
5321  *
5322  *	Return the digit buffer of a variable to the free pool
5323  */
5324 static void
free_var(NumericVar * var)5325 free_var(NumericVar *var)
5326 {
5327 	digitbuf_free(var->buf);
5328 	var->buf = NULL;
5329 	var->digits = NULL;
5330 	var->sign = NUMERIC_NAN;
5331 }
5332 
5333 
5334 /*
5335  * zero_var() -
5336  *
5337  *	Set a variable to ZERO.
5338  *	Note: its dscale is not touched.
5339  */
5340 static void
zero_var(NumericVar * var)5341 zero_var(NumericVar *var)
5342 {
5343 	digitbuf_free(var->buf);
5344 	var->buf = NULL;
5345 	var->digits = NULL;
5346 	var->ndigits = 0;
5347 	var->weight = 0;			/* by convention; doesn't really matter */
5348 	var->sign = NUMERIC_POS;	/* anything but NAN... */
5349 }
5350 
5351 
/*
 * set_var_from_str()
 *
 *	Parse a string and put the number into a variable
 *
 * This function does not handle leading or trailing spaces, and it doesn't
 * accept "NaN" either.  It returns the end+1 position so that caller can
 * check for trailing spaces/garbage if deemed necessary.
 *
 * cp is the place to actually start parsing; str is what to use in error
 * reports.  (Typically cp would be the same except advanced over spaces.)
 *
 * Syntax accepted, as implemented below: an optional '+' or '-', then
 * decimal digits with at most one '.', where a digit must immediately follow
 * the sign or a leading '.'; then optionally 'e' or 'E' followed by an
 * integer exponent as parsed by strtol().
 *
 * Errors raised via ereport(ERROR):
 *	- ERRCODE_INVALID_TEXT_REPRESENTATION for a missing first digit, a second
 *	  decimal point, or an exponent marker with no number after it;
 *	- ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE when the exponent's magnitude is
 *	  INT_MAX/2 or more.
 *
 * On success dest has been (re)allocated and holds the parsed value with
 * leading/trailing zero digits stripped; dest->dscale is the number of
 * fractional digits written, reduced by the exponent and clamped at zero.
 */
static const char *
set_var_from_str(const char *str, const char *cp, NumericVar *dest)
{
	bool		have_dp = FALSE;	/* seen the decimal point yet? */
	int			i;
	unsigned char *decdigits;	/* digit values 0..9, one per byte, padded */
	int			sign = NUMERIC_POS;
	int			dweight = -1;	/* decimal weight; becomes 0 at first integer
								 * digit */
	int			ddigits;		/* count of decimal digits collected */
	int			dscale = 0;		/* count of digits after the decimal point */
	int			weight;
	int			ndigits;
	int			offset;
	NumericDigit *digits;

	/*
	 * We first parse the string to extract decimal digits and determine the
	 * correct decimal weight.  Then convert to NBASE representation.
	 */
	switch (*cp)
	{
		case '+':
			sign = NUMERIC_POS;
			cp++;
			break;

		case '-':
			sign = NUMERIC_NEG;
			cp++;
			break;
	}

	/* A leading '.' is allowed, provided a digit follows (checked next) */
	if (*cp == '.')
	{
		have_dp = TRUE;
		cp++;
	}

	if (!isdigit((unsigned char) *cp))
		ereport(ERROR,
				(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
			  errmsg("invalid input syntax for type numeric: \"%s\"", str)));

	/*
	 * Room for every remaining input character plus DEC_DIGITS of leading
	 * padding and up to DEC_DIGITS-1 of trailing padding.
	 */
	decdigits = (unsigned char *) palloc(strlen(cp) + DEC_DIGITS * 2);

	/* leading padding for digit alignment later */
	memset(decdigits, 0, DEC_DIGITS);
	i = DEC_DIGITS;

	/* Collect digits; stop at the first character that is neither digit nor '.' */
	while (*cp)
	{
		if (isdigit((unsigned char) *cp))
		{
			decdigits[i++] = *cp++ - '0';
			if (!have_dp)
				dweight++;
			else
				dscale++;
		}
		else if (*cp == '.')
		{
			if (have_dp)
				ereport(ERROR,
						(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
					  errmsg("invalid input syntax for type numeric: \"%s\"",
							 str)));
			have_dp = TRUE;
			cp++;
		}
		else
			break;
	}

	ddigits = i - DEC_DIGITS;
	/* trailing padding for digit alignment later */
	memset(decdigits + i, 0, DEC_DIGITS - 1);

	/* Handle exponent, if any */
	if (*cp == 'e' || *cp == 'E')
	{
		long		exponent;
		char	   *endptr;

		cp++;
		exponent = strtol(cp, &endptr, 10);
		/* strtol consumed nothing: there was no number after the 'e' */
		if (endptr == cp)
			ereport(ERROR,
					(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
					 errmsg("invalid input syntax for type numeric: \"%s\"",
							str)));
		cp = endptr;

		/*
		 * At this point, dweight and dscale can't be more than about
		 * INT_MAX/2 due to the MaxAllocSize limit on string length, so
		 * constraining the exponent similarly should be enough to prevent
		 * integer overflow in this function.  If the value is too large to
		 * fit in storage format, make_result() will complain about it later;
		 * for consistency use the same ereport errcode/text as make_result().
		 */
		if (exponent >= INT_MAX / 2 || exponent <= -(INT_MAX / 2))
			ereport(ERROR,
					(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
					 errmsg("value overflows numeric format")));
		/* The exponent shifts the decimal point; dscale never goes negative */
		dweight += (int) exponent;
		dscale -= (int) exponent;
		if (dscale < 0)
			dscale = 0;
	}

	/*
	 * Okay, convert pure-decimal representation to base NBASE.  First we need
	 * to determine the converted weight and ndigits.  offset is the number of
	 * decimal zeroes to insert before the first given digit to have a
	 * correctly aligned first NBASE digit.
	 */
	if (dweight >= 0)
		weight = (dweight + 1 + DEC_DIGITS - 1) / DEC_DIGITS - 1;
	else
		weight = -((-dweight - 1) / DEC_DIGITS + 1);
	offset = (weight + 1) * DEC_DIGITS - (dweight + 1);
	ndigits = (ddigits + offset + DEC_DIGITS - 1) / DEC_DIGITS;

	alloc_var(dest, ndigits);
	dest->sign = sign;
	dest->weight = weight;
	dest->dscale = dscale;

	/*
	 * Start reading inside the leading zero padding, so that exactly
	 * "offset" zeroes precede the first real decimal digit.
	 */
	i = DEC_DIGITS - offset;
	digits = dest->digits;

	/* Pack DEC_DIGITS decimal digits into each NBASE digit */
	while (ndigits-- > 0)
	{
#if DEC_DIGITS == 4
		*digits++ = ((decdigits[i] * 10 + decdigits[i + 1]) * 10 +
					 decdigits[i + 2]) * 10 + decdigits[i + 3];
#elif DEC_DIGITS == 2
		*digits++ = decdigits[i] * 10 + decdigits[i + 1];
#elif DEC_DIGITS == 1
		*digits++ = decdigits[i];
#else
#error unsupported NBASE
#endif
		i += DEC_DIGITS;
	}

	pfree(decdigits);

	/* Strip any leading/trailing zeroes, and normalize weight if zero */
	strip_var(dest);

	/* Return end+1 position for caller */
	return cp;
}
5518 
5519 
5520 /*
5521  * set_var_from_num() -
5522  *
5523  *	Convert the packed db format into a variable
5524  */
5525 static void
set_var_from_num(Numeric num,NumericVar * dest)5526 set_var_from_num(Numeric num, NumericVar *dest)
5527 {
5528 	int			ndigits;
5529 
5530 	ndigits = NUMERIC_NDIGITS(num);
5531 
5532 	alloc_var(dest, ndigits);
5533 
5534 	dest->weight = NUMERIC_WEIGHT(num);
5535 	dest->sign = NUMERIC_SIGN(num);
5536 	dest->dscale = NUMERIC_DSCALE(num);
5537 
5538 	memcpy(dest->digits, NUMERIC_DIGITS(num), ndigits * sizeof(NumericDigit));
5539 }
5540 
5541 
5542 /*
5543  * init_var_from_num() -
5544  *
5545  *	Initialize a variable from packed db format. The digits array is not
5546  *	copied, which saves some cycles when the resulting var is not modified.
5547  *	Also, there's no need to call free_var(), as long as you don't assign any
5548  *	other value to it (with set_var_* functions, or by using the var as the
5549  *	destination of a function like add_var())
5550  *
5551  *	CAUTION: Do not modify the digits buffer of a var initialized with this
5552  *	function, e.g by calling round_var() or trunc_var(), as the changes will
5553  *	propagate to the original Numeric! It's OK to use it as the destination
5554  *	argument of one of the calculational functions, though.
5555  */
5556 static void
init_var_from_num(Numeric num,NumericVar * dest)5557 init_var_from_num(Numeric num, NumericVar *dest)
5558 {
5559 	dest->ndigits = NUMERIC_NDIGITS(num);
5560 	dest->weight = NUMERIC_WEIGHT(num);
5561 	dest->sign = NUMERIC_SIGN(num);
5562 	dest->dscale = NUMERIC_DSCALE(num);
5563 	dest->digits = NUMERIC_DIGITS(num);
5564 	dest->buf = NULL;			/* digits array is not palloc'd */
5565 }
5566 
5567 
5568 /*
5569  * set_var_from_var() -
5570  *
5571  *	Copy one variable into another
5572  */
5573 static void
set_var_from_var(NumericVar * value,NumericVar * dest)5574 set_var_from_var(NumericVar *value, NumericVar *dest)
5575 {
5576 	NumericDigit *newbuf;
5577 
5578 	newbuf = digitbuf_alloc(value->ndigits + 1);
5579 	newbuf[0] = 0;				/* spare digit for rounding */
5580 	if (value->ndigits > 0)		/* else value->digits might be null */
5581 		memcpy(newbuf + 1, value->digits,
5582 			   value->ndigits * sizeof(NumericDigit));
5583 
5584 	digitbuf_free(dest->buf);
5585 
5586 	memmove(dest, value, sizeof(NumericVar));
5587 	dest->buf = newbuf;
5588 	dest->digits = newbuf + 1;
5589 }
5590 
5591 
/*
 * get_str_from_var() -
 *
 *	Convert a var to text representation (guts of numeric_out).
 *	The var is displayed to the number of digits indicated by its dscale.
 *	Returns a palloc'd string.
 *
 *	Output shape, as produced below: an optional '-', the integer part (a
 *	single '0' when var->weight is negative), and, only if dscale > 0, a '.'
 *	followed by exactly dscale fractional digits.  No rounding is done here;
 *	fractional digits beyond dscale are simply cut off.
 */
static char *
get_str_from_var(NumericVar *var)
{
	int			dscale;
	char	   *str;
	char	   *cp;				/* write cursor into str */
	char	   *endcp;			/* where the fractional part must end */
	int			i;
	int			d;				/* index of current NBASE digit; may be
								 * negative or beyond ndigits, meaning zero */
	NumericDigit dig;

#if DEC_DIGITS > 1
	NumericDigit d1;
#endif

	dscale = var->dscale;

	/*
	 * Allocate space for the result.
	 *
	 * i is set to the # of decimal digits before decimal point. dscale is the
	 * # of decimal digits we will print after decimal point. We may generate
	 * as many as DEC_DIGITS-1 excess digits at the end, and in addition we
	 * need room for sign, decimal point, null terminator.
	 */
	i = (var->weight + 1) * DEC_DIGITS;
	if (i <= 0)
		i = 1;

	str = palloc(i + dscale + DEC_DIGITS + 2);
	cp = str;

	/*
	 * Output a dash for negative values
	 */
	if (var->sign == NUMERIC_NEG)
		*cp++ = '-';

	/*
	 * Output all digits before the decimal point
	 */
	if (var->weight < 0)
	{
		/*
		 * No integer part: print a lone '0'.  d starts at or below zero so
		 * the fractional loop below emits zero groups until it reaches
		 * digits[0].
		 */
		d = var->weight + 1;
		*cp++ = '0';
	}
	else
	{
		for (d = 0; d <= var->weight; d++)
		{
			/* positions past the stored digits are implicit zeroes */
			dig = (d < var->ndigits) ? var->digits[d] : 0;
			/* In the first digit, suppress extra leading decimal zeroes */
#if DEC_DIGITS == 4
			{
				/* becomes true once any decimal digit has been emitted */
				bool		putit = (d > 0);

				d1 = dig / 1000;
				dig -= d1 * 1000;
				putit |= (d1 > 0);
				if (putit)
					*cp++ = d1 + '0';
				d1 = dig / 100;
				dig -= d1 * 100;
				putit |= (d1 > 0);
				if (putit)
					*cp++ = d1 + '0';
				d1 = dig / 10;
				dig -= d1 * 10;
				putit |= (d1 > 0);
				if (putit)
					*cp++ = d1 + '0';
				/* the units digit is always printed */
				*cp++ = dig + '0';
			}
#elif DEC_DIGITS == 2
			d1 = dig / 10;
			dig -= d1 * 10;
			if (d1 > 0 || d > 0)
				*cp++ = d1 + '0';
			*cp++ = dig + '0';
#elif DEC_DIGITS == 1
			*cp++ = dig + '0';
#else
#error unsupported NBASE
#endif
		}
	}

	/*
	 * If requested, output a decimal point and all the digits that follow it.
	 * We initially put out a multiple of DEC_DIGITS digits, then truncate if
	 * needed.
	 */
	if (dscale > 0)
	{
		*cp++ = '.';
		endcp = cp + dscale;
		for (i = 0; i < dscale; d++, i += DEC_DIGITS)
		{
			/* out-of-range positions on either side count as zero */
			dig = (d >= 0 && d < var->ndigits) ? var->digits[d] : 0;
#if DEC_DIGITS == 4
			d1 = dig / 1000;
			dig -= d1 * 1000;
			*cp++ = d1 + '0';
			d1 = dig / 100;
			dig -= d1 * 100;
			*cp++ = d1 + '0';
			d1 = dig / 10;
			dig -= d1 * 10;
			*cp++ = d1 + '0';
			*cp++ = dig + '0';
#elif DEC_DIGITS == 2
			d1 = dig / 10;
			dig -= d1 * 10;
			*cp++ = d1 + '0';
			*cp++ = dig + '0';
#elif DEC_DIGITS == 1
			*cp++ = dig + '0';
#else
#error unsupported NBASE
#endif
		}
		/* back up over any excess digits from the final group */
		cp = endcp;
	}

	/*
	 * terminate the string and return it
	 */
	*cp = '\0';
	return str;
}
5729 
5730 /*
5731  * get_str_from_var_sci() -
5732  *
5733  *	Convert a var to a normalised scientific notation text representation.
5734  *	This function does the heavy lifting for numeric_out_sci().
5735  *
5736  *	This notation has the general form a * 10^b, where a is known as the
5737  *	"significand" and b is known as the "exponent".
5738  *
5739  *	Because we can't do superscript in ASCII (and because we want to copy
5740  *	printf's behaviour) we display the exponent using E notation, with a
5741  *	minimum of two exponent digits.
5742  *
5743  *	For example, the value 1234 could be output as 1.2e+03.
5744  *
5745  *	We assume that the exponent can fit into an int32.
5746  *
5747  *	rscale is the number of decimal digits desired after the decimal point in
5748  *	the output, negative values will be treated as meaning zero.
5749  *
5750  *	Returns a palloc'd string.
5751  */
5752 static char *
get_str_from_var_sci(NumericVar * var,int rscale)5753 get_str_from_var_sci(NumericVar *var, int rscale)
5754 {
5755 	int32		exponent;
5756 	NumericVar	tmp_var;
5757 	size_t		len;
5758 	char	   *str;
5759 	char	   *sig_out;
5760 
5761 	if (rscale < 0)
5762 		rscale = 0;
5763 
5764 	/*
5765 	 * Determine the exponent of this number in normalised form.
5766 	 *
5767 	 * This is the exponent required to represent the number with only one
5768 	 * significant digit before the decimal place.
5769 	 */
5770 	if (var->ndigits > 0)
5771 	{
5772 		exponent = (var->weight + 1) * DEC_DIGITS;
5773 
5774 		/*
5775 		 * Compensate for leading decimal zeroes in the first numeric digit by
5776 		 * decrementing the exponent.
5777 		 */
5778 		exponent -= DEC_DIGITS - (int) log10(var->digits[0]);
5779 	}
5780 	else
5781 	{
5782 		/*
5783 		 * If var has no digits, then it must be zero.
5784 		 *
5785 		 * Zero doesn't technically have a meaningful exponent in normalised
5786 		 * notation, but we just display the exponent as zero for consistency
5787 		 * of output.
5788 		 */
5789 		exponent = 0;
5790 	}
5791 
5792 	/*
5793 	 * Divide var by 10^exponent to get the significand, rounding to rscale
5794 	 * decimal digits in the process.
5795 	 */
5796 	init_var(&tmp_var);
5797 
5798 	power_ten_int(exponent, &tmp_var);
5799 	div_var(var, &tmp_var, &tmp_var, rscale, true);
5800 	sig_out = get_str_from_var(&tmp_var);
5801 
5802 	free_var(&tmp_var);
5803 
5804 	/*
5805 	 * Allocate space for the result.
5806 	 *
5807 	 * In addition to the significand, we need room for the exponent
5808 	 * decoration ("e"), the sign of the exponent, up to 10 digits for the
5809 	 * exponent itself, and of course the null terminator.
5810 	 */
5811 	len = strlen(sig_out) + 13;
5812 	str = palloc(len);
5813 	snprintf(str, len, "%se%+03d", sig_out, exponent);
5814 
5815 	pfree(sig_out);
5816 
5817 	return str;
5818 }
5819 
5820 
5821 /*
5822  * make_result() -
5823  *
5824  *	Create the packed db numeric format in palloc()'d memory from
5825  *	a variable.
5826  */
5827 static Numeric
make_result(NumericVar * var)5828 make_result(NumericVar *var)
5829 {
5830 	Numeric		result;
5831 	NumericDigit *digits = var->digits;
5832 	int			weight = var->weight;
5833 	int			sign = var->sign;
5834 	int			n;
5835 	Size		len;
5836 
5837 	if (sign == NUMERIC_NAN)
5838 	{
5839 		result = (Numeric) palloc(NUMERIC_HDRSZ_SHORT);
5840 
5841 		SET_VARSIZE(result, NUMERIC_HDRSZ_SHORT);
5842 		result->choice.n_header = NUMERIC_NAN;
5843 		/* the header word is all we need */
5844 
5845 		dump_numeric("make_result()", result);
5846 		return result;
5847 	}
5848 
5849 	n = var->ndigits;
5850 
5851 	/* truncate leading zeroes */
5852 	while (n > 0 && *digits == 0)
5853 	{
5854 		digits++;
5855 		weight--;
5856 		n--;
5857 	}
5858 	/* truncate trailing zeroes */
5859 	while (n > 0 && digits[n - 1] == 0)
5860 		n--;
5861 
5862 	/* If zero result, force to weight=0 and positive sign */
5863 	if (n == 0)
5864 	{
5865 		weight = 0;
5866 		sign = NUMERIC_POS;
5867 	}
5868 
5869 	/* Build the result */
5870 	if (NUMERIC_CAN_BE_SHORT(var->dscale, weight))
5871 	{
5872 		len = NUMERIC_HDRSZ_SHORT + n * sizeof(NumericDigit);
5873 		result = (Numeric) palloc(len);
5874 		SET_VARSIZE(result, len);
5875 		result->choice.n_short.n_header =
5876 			(sign == NUMERIC_NEG ? (NUMERIC_SHORT | NUMERIC_SHORT_SIGN_MASK)
5877 			 : NUMERIC_SHORT)
5878 			| (var->dscale << NUMERIC_SHORT_DSCALE_SHIFT)
5879 			| (weight < 0 ? NUMERIC_SHORT_WEIGHT_SIGN_MASK : 0)
5880 			| (weight & NUMERIC_SHORT_WEIGHT_MASK);
5881 	}
5882 	else
5883 	{
5884 		len = NUMERIC_HDRSZ + n * sizeof(NumericDigit);
5885 		result = (Numeric) palloc(len);
5886 		SET_VARSIZE(result, len);
5887 		result->choice.n_long.n_sign_dscale =
5888 			sign | (var->dscale & NUMERIC_DSCALE_MASK);
5889 		result->choice.n_long.n_weight = weight;
5890 	}
5891 
5892 	Assert(NUMERIC_NDIGITS(result) == n);
5893 	if (n > 0)
5894 		memcpy(NUMERIC_DIGITS(result), digits, n * sizeof(NumericDigit));
5895 
5896 	/* Check for overflow of int16 fields */
5897 	if (NUMERIC_WEIGHT(result) != weight ||
5898 		NUMERIC_DSCALE(result) != var->dscale)
5899 		ereport(ERROR,
5900 				(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
5901 				 errmsg("value overflows numeric format")));
5902 
5903 	dump_numeric("make_result()", result);
5904 	return result;
5905 }
5906 
5907 
5908 /*
5909  * apply_typmod() -
5910  *
5911  *	Do bounds checking and rounding according to the attributes
5912  *	typmod field.
5913  */
5914 static void
apply_typmod(NumericVar * var,int32 typmod)5915 apply_typmod(NumericVar *var, int32 typmod)
5916 {
5917 	int			precision;
5918 	int			scale;
5919 	int			maxdigits;
5920 	int			ddigits;
5921 	int			i;
5922 
5923 	/* Do nothing if we have a default typmod (-1) */
5924 	if (typmod < (int32) (VARHDRSZ))
5925 		return;
5926 
5927 	typmod -= VARHDRSZ;
5928 	precision = (typmod >> 16) & 0xffff;
5929 	scale = typmod & 0xffff;
5930 	maxdigits = precision - scale;
5931 
5932 	/* Round to target scale (and set var->dscale) */
5933 	round_var(var, scale);
5934 
5935 	/*
5936 	 * Check for overflow - note we can't do this before rounding, because
5937 	 * rounding could raise the weight.  Also note that the var's weight could
5938 	 * be inflated by leading zeroes, which will be stripped before storage
5939 	 * but perhaps might not have been yet. In any case, we must recognize a
5940 	 * true zero, whose weight doesn't mean anything.
5941 	 */
5942 	ddigits = (var->weight + 1) * DEC_DIGITS;
5943 	if (ddigits > maxdigits)
5944 	{
5945 		/* Determine true weight; and check for all-zero result */
5946 		for (i = 0; i < var->ndigits; i++)
5947 		{
5948 			NumericDigit dig = var->digits[i];
5949 
5950 			if (dig)
5951 			{
5952 				/* Adjust for any high-order decimal zero digits */
5953 #if DEC_DIGITS == 4
5954 				if (dig < 10)
5955 					ddigits -= 3;
5956 				else if (dig < 100)
5957 					ddigits -= 2;
5958 				else if (dig < 1000)
5959 					ddigits -= 1;
5960 #elif DEC_DIGITS == 2
5961 				if (dig < 10)
5962 					ddigits -= 1;
5963 #elif DEC_DIGITS == 1
5964 				/* no adjustment */
5965 #else
5966 #error unsupported NBASE
5967 #endif
5968 				if (ddigits > maxdigits)
5969 					ereport(ERROR,
5970 							(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
5971 							 errmsg("numeric field overflow"),
5972 							 errdetail("A field with precision %d, scale %d must round to an absolute value less than %s%d.",
5973 									   precision, scale,
5974 					/* Display 10^0 as 1 */
5975 									   maxdigits ? "10^" : "",
5976 									   maxdigits ? maxdigits : 1
5977 									   )));
5978 				break;
5979 			}
5980 			ddigits -= DEC_DIGITS;
5981 		}
5982 	}
5983 }
5984 
5985 /*
5986  * Convert numeric to int8, rounding if needed.
5987  *
5988  * If overflow, return FALSE (no error is raised).  Return TRUE if okay.
5989  */
5990 static bool
numericvar_to_int64(NumericVar * var,int64 * result)5991 numericvar_to_int64(NumericVar *var, int64 *result)
5992 {
5993 	NumericDigit *digits;
5994 	int			ndigits;
5995 	int			weight;
5996 	int			i;
5997 	int64		val,
5998 				oldval;
5999 	bool		neg;
6000 	NumericVar	rounded;
6001 
6002 	/* Round to nearest integer */
6003 	init_var(&rounded);
6004 	set_var_from_var(var, &rounded);
6005 	round_var(&rounded, 0);
6006 
6007 	/* Check for zero input */
6008 	strip_var(&rounded);
6009 	ndigits = rounded.ndigits;
6010 	if (ndigits == 0)
6011 	{
6012 		*result = 0;
6013 		free_var(&rounded);
6014 		return true;
6015 	}
6016 
6017 	/*
6018 	 * For input like 10000000000, we must treat stripped digits as real. So
6019 	 * the loop assumes there are weight+1 digits before the decimal point.
6020 	 */
6021 	weight = rounded.weight;
6022 	Assert(weight >= 0 && ndigits <= weight + 1);
6023 
6024 	/* Construct the result */
6025 	digits = rounded.digits;
6026 	neg = (rounded.sign == NUMERIC_NEG);
6027 	val = digits[0];
6028 	for (i = 1; i <= weight; i++)
6029 	{
6030 		oldval = val;
6031 		val *= NBASE;
6032 		if (i < ndigits)
6033 			val += digits[i];
6034 
6035 		/*
6036 		 * The overflow check is a bit tricky because we want to accept
6037 		 * INT64_MIN, which will overflow the positive accumulator.  We can
6038 		 * detect this case easily though because INT64_MIN is the only
6039 		 * nonzero value for which -val == val (on a two's complement machine,
6040 		 * anyway).
6041 		 */
6042 		if ((val / NBASE) != oldval)	/* possible overflow? */
6043 		{
6044 			if (!neg || (-val) != val || val == 0 || oldval < 0)
6045 			{
6046 				free_var(&rounded);
6047 				return false;
6048 			}
6049 		}
6050 	}
6051 
6052 	free_var(&rounded);
6053 
6054 	*result = neg ? -val : val;
6055 	return true;
6056 }
6057 
6058 /*
6059  * Convert int8 value to numeric.
6060  */
6061 static void
int64_to_numericvar(int64 val,NumericVar * var)6062 int64_to_numericvar(int64 val, NumericVar *var)
6063 {
6064 	uint64		uval,
6065 				newuval;
6066 	NumericDigit *ptr;
6067 	int			ndigits;
6068 
6069 	/* int64 can require at most 19 decimal digits; add one for safety */
6070 	alloc_var(var, 20 / DEC_DIGITS);
6071 	if (val < 0)
6072 	{
6073 		var->sign = NUMERIC_NEG;
6074 		uval = -val;
6075 	}
6076 	else
6077 	{
6078 		var->sign = NUMERIC_POS;
6079 		uval = val;
6080 	}
6081 	var->dscale = 0;
6082 	if (val == 0)
6083 	{
6084 		var->ndigits = 0;
6085 		var->weight = 0;
6086 		return;
6087 	}
6088 	ptr = var->digits + var->ndigits;
6089 	ndigits = 0;
6090 	do
6091 	{
6092 		ptr--;
6093 		ndigits++;
6094 		newuval = uval / NBASE;
6095 		*ptr = uval - newuval * NBASE;
6096 		uval = newuval;
6097 	} while (uval);
6098 	var->digits = ptr;
6099 	var->ndigits = ndigits;
6100 	var->weight = ndigits - 1;
6101 }
6102 
6103 #ifdef HAVE_INT128
6104 /*
6105  * Convert numeric to int128, rounding if needed.
6106  *
6107  * If overflow, return FALSE (no error is raised).  Return TRUE if okay.
6108  */
6109 static bool
numericvar_to_int128(NumericVar * var,int128 * result)6110 numericvar_to_int128(NumericVar *var, int128 *result)
6111 {
6112 	NumericDigit *digits;
6113 	int			ndigits;
6114 	int			weight;
6115 	int			i;
6116 	int128		val,
6117 				oldval;
6118 	bool		neg;
6119 	NumericVar	rounded;
6120 
6121 	/* Round to nearest integer */
6122 	init_var(&rounded);
6123 	set_var_from_var(var, &rounded);
6124 	round_var(&rounded, 0);
6125 
6126 	/* Check for zero input */
6127 	strip_var(&rounded);
6128 	ndigits = rounded.ndigits;
6129 	if (ndigits == 0)
6130 	{
6131 		*result = 0;
6132 		free_var(&rounded);
6133 		return true;
6134 	}
6135 
6136 	/*
6137 	 * For input like 10000000000, we must treat stripped digits as real. So
6138 	 * the loop assumes there are weight+1 digits before the decimal point.
6139 	 */
6140 	weight = rounded.weight;
6141 	Assert(weight >= 0 && ndigits <= weight + 1);
6142 
6143 	/* Construct the result */
6144 	digits = rounded.digits;
6145 	neg = (rounded.sign == NUMERIC_NEG);
6146 	val = digits[0];
6147 	for (i = 1; i <= weight; i++)
6148 	{
6149 		oldval = val;
6150 		val *= NBASE;
6151 		if (i < ndigits)
6152 			val += digits[i];
6153 
6154 		/*
6155 		 * The overflow check is a bit tricky because we want to accept
6156 		 * INT128_MIN, which will overflow the positive accumulator.  We can
6157 		 * detect this case easily though because INT128_MIN is the only
6158 		 * nonzero value for which -val == val (on a two's complement machine,
6159 		 * anyway).
6160 		 */
6161 		if ((val / NBASE) != oldval)	/* possible overflow? */
6162 		{
6163 			if (!neg || (-val) != val || val == 0 || oldval < 0)
6164 			{
6165 				free_var(&rounded);
6166 				return false;
6167 			}
6168 		}
6169 	}
6170 
6171 	free_var(&rounded);
6172 
6173 	*result = neg ? -val : val;
6174 	return true;
6175 }
6176 
6177 /*
6178  * Convert 128 bit integer to numeric.
6179  */
6180 static void
int128_to_numericvar(int128 val,NumericVar * var)6181 int128_to_numericvar(int128 val, NumericVar *var)
6182 {
6183 	uint128		uval,
6184 				newuval;
6185 	NumericDigit *ptr;
6186 	int			ndigits;
6187 
6188 	/* int128 can require at most 39 decimal digits; add one for safety */
6189 	alloc_var(var, 40 / DEC_DIGITS);
6190 	if (val < 0)
6191 	{
6192 		var->sign = NUMERIC_NEG;
6193 		uval = -val;
6194 	}
6195 	else
6196 	{
6197 		var->sign = NUMERIC_POS;
6198 		uval = val;
6199 	}
6200 	var->dscale = 0;
6201 	if (val == 0)
6202 	{
6203 		var->ndigits = 0;
6204 		var->weight = 0;
6205 		return;
6206 	}
6207 	ptr = var->digits + var->ndigits;
6208 	ndigits = 0;
6209 	do
6210 	{
6211 		ptr--;
6212 		ndigits++;
6213 		newuval = uval / NBASE;
6214 		*ptr = uval - newuval * NBASE;
6215 		uval = newuval;
6216 	} while (uval);
6217 	var->digits = ptr;
6218 	var->ndigits = ndigits;
6219 	var->weight = ndigits - 1;
6220 }
6221 #endif
6222 
6223 /*
6224  * Convert numeric to float8; if out of range, return +/- HUGE_VAL
6225  */
6226 static double
numeric_to_double_no_overflow(Numeric num)6227 numeric_to_double_no_overflow(Numeric num)
6228 {
6229 	char	   *tmp;
6230 	double		val;
6231 	char	   *endptr;
6232 
6233 	tmp = DatumGetCString(DirectFunctionCall1(numeric_out,
6234 											  NumericGetDatum(num)));
6235 
6236 	/* unlike float8in, we ignore ERANGE from strtod */
6237 	val = strtod(tmp, &endptr);
6238 	if (*endptr != '\0')
6239 	{
6240 		/* shouldn't happen ... */
6241 		ereport(ERROR,
6242 				(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
6243 			 errmsg("invalid input syntax for type double precision: \"%s\"",
6244 					tmp)));
6245 	}
6246 
6247 	pfree(tmp);
6248 
6249 	return val;
6250 }
6251 
6252 /* As above, but work from a NumericVar */
6253 static double
numericvar_to_double_no_overflow(NumericVar * var)6254 numericvar_to_double_no_overflow(NumericVar *var)
6255 {
6256 	char	   *tmp;
6257 	double		val;
6258 	char	   *endptr;
6259 
6260 	tmp = get_str_from_var(var);
6261 
6262 	/* unlike float8in, we ignore ERANGE from strtod */
6263 	val = strtod(tmp, &endptr);
6264 	if (*endptr != '\0')
6265 	{
6266 		/* shouldn't happen ... */
6267 		ereport(ERROR,
6268 				(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
6269 			 errmsg("invalid input syntax for type double precision: \"%s\"",
6270 					tmp)));
6271 	}
6272 
6273 	pfree(tmp);
6274 
6275 	return val;
6276 }
6277 
6278 
6279 /*
6280  * cmp_var() -
6281  *
6282  *	Compare two values on variable level.  We assume zeroes have been
6283  *	truncated to no digits.
6284  */
6285 static int
cmp_var(NumericVar * var1,NumericVar * var2)6286 cmp_var(NumericVar *var1, NumericVar *var2)
6287 {
6288 	return cmp_var_common(var1->digits, var1->ndigits,
6289 						  var1->weight, var1->sign,
6290 						  var2->digits, var2->ndigits,
6291 						  var2->weight, var2->sign);
6292 }
6293 
6294 /*
6295  * cmp_var_common() -
6296  *
6297  *	Main routine of cmp_var(). This function can be used by both
6298  *	NumericVar and Numeric.
6299  */
6300 static int
cmp_var_common(const NumericDigit * var1digits,int var1ndigits,int var1weight,int var1sign,const NumericDigit * var2digits,int var2ndigits,int var2weight,int var2sign)6301 cmp_var_common(const NumericDigit *var1digits, int var1ndigits,
6302 			   int var1weight, int var1sign,
6303 			   const NumericDigit *var2digits, int var2ndigits,
6304 			   int var2weight, int var2sign)
6305 {
6306 	if (var1ndigits == 0)
6307 	{
6308 		if (var2ndigits == 0)
6309 			return 0;
6310 		if (var2sign == NUMERIC_NEG)
6311 			return 1;
6312 		return -1;
6313 	}
6314 	if (var2ndigits == 0)
6315 	{
6316 		if (var1sign == NUMERIC_POS)
6317 			return 1;
6318 		return -1;
6319 	}
6320 
6321 	if (var1sign == NUMERIC_POS)
6322 	{
6323 		if (var2sign == NUMERIC_NEG)
6324 			return 1;
6325 		return cmp_abs_common(var1digits, var1ndigits, var1weight,
6326 							  var2digits, var2ndigits, var2weight);
6327 	}
6328 
6329 	if (var2sign == NUMERIC_POS)
6330 		return -1;
6331 
6332 	return cmp_abs_common(var2digits, var2ndigits, var2weight,
6333 						  var1digits, var1ndigits, var1weight);
6334 }
6335 
6336 
6337 /*
6338  * add_var() -
6339  *
6340  *	Full version of add functionality on variable level (handling signs).
6341  *	result might point to one of the operands too without danger.
6342  */
6343 static void
add_var(NumericVar * var1,NumericVar * var2,NumericVar * result)6344 add_var(NumericVar *var1, NumericVar *var2, NumericVar *result)
6345 {
6346 	/*
6347 	 * Decide on the signs of the two variables what to do
6348 	 */
6349 	if (var1->sign == NUMERIC_POS)
6350 	{
6351 		if (var2->sign == NUMERIC_POS)
6352 		{
6353 			/*
6354 			 * Both are positive result = +(ABS(var1) + ABS(var2))
6355 			 */
6356 			add_abs(var1, var2, result);
6357 			result->sign = NUMERIC_POS;
6358 		}
6359 		else
6360 		{
6361 			/*
6362 			 * var1 is positive, var2 is negative Must compare absolute values
6363 			 */
6364 			switch (cmp_abs(var1, var2))
6365 			{
6366 				case 0:
6367 					/* ----------
6368 					 * ABS(var1) == ABS(var2)
6369 					 * result = ZERO
6370 					 * ----------
6371 					 */
6372 					zero_var(result);
6373 					result->dscale = Max(var1->dscale, var2->dscale);
6374 					break;
6375 
6376 				case 1:
6377 					/* ----------
6378 					 * ABS(var1) > ABS(var2)
6379 					 * result = +(ABS(var1) - ABS(var2))
6380 					 * ----------
6381 					 */
6382 					sub_abs(var1, var2, result);
6383 					result->sign = NUMERIC_POS;
6384 					break;
6385 
6386 				case -1:
6387 					/* ----------
6388 					 * ABS(var1) < ABS(var2)
6389 					 * result = -(ABS(var2) - ABS(var1))
6390 					 * ----------
6391 					 */
6392 					sub_abs(var2, var1, result);
6393 					result->sign = NUMERIC_NEG;
6394 					break;
6395 			}
6396 		}
6397 	}
6398 	else
6399 	{
6400 		if (var2->sign == NUMERIC_POS)
6401 		{
6402 			/* ----------
6403 			 * var1 is negative, var2 is positive
6404 			 * Must compare absolute values
6405 			 * ----------
6406 			 */
6407 			switch (cmp_abs(var1, var2))
6408 			{
6409 				case 0:
6410 					/* ----------
6411 					 * ABS(var1) == ABS(var2)
6412 					 * result = ZERO
6413 					 * ----------
6414 					 */
6415 					zero_var(result);
6416 					result->dscale = Max(var1->dscale, var2->dscale);
6417 					break;
6418 
6419 				case 1:
6420 					/* ----------
6421 					 * ABS(var1) > ABS(var2)
6422 					 * result = -(ABS(var1) - ABS(var2))
6423 					 * ----------
6424 					 */
6425 					sub_abs(var1, var2, result);
6426 					result->sign = NUMERIC_NEG;
6427 					break;
6428 
6429 				case -1:
6430 					/* ----------
6431 					 * ABS(var1) < ABS(var2)
6432 					 * result = +(ABS(var2) - ABS(var1))
6433 					 * ----------
6434 					 */
6435 					sub_abs(var2, var1, result);
6436 					result->sign = NUMERIC_POS;
6437 					break;
6438 			}
6439 		}
6440 		else
6441 		{
6442 			/* ----------
6443 			 * Both are negative
6444 			 * result = -(ABS(var1) + ABS(var2))
6445 			 * ----------
6446 			 */
6447 			add_abs(var1, var2, result);
6448 			result->sign = NUMERIC_NEG;
6449 		}
6450 	}
6451 }
6452 
6453 
6454 /*
6455  * sub_var() -
6456  *
6457  *	Full version of sub functionality on variable level (handling signs).
6458  *	result might point to one of the operands too without danger.
6459  */
6460 static void
sub_var(NumericVar * var1,NumericVar * var2,NumericVar * result)6461 sub_var(NumericVar *var1, NumericVar *var2, NumericVar *result)
6462 {
6463 	/*
6464 	 * Decide on the signs of the two variables what to do
6465 	 */
6466 	if (var1->sign == NUMERIC_POS)
6467 	{
6468 		if (var2->sign == NUMERIC_NEG)
6469 		{
6470 			/* ----------
6471 			 * var1 is positive, var2 is negative
6472 			 * result = +(ABS(var1) + ABS(var2))
6473 			 * ----------
6474 			 */
6475 			add_abs(var1, var2, result);
6476 			result->sign = NUMERIC_POS;
6477 		}
6478 		else
6479 		{
6480 			/* ----------
6481 			 * Both are positive
6482 			 * Must compare absolute values
6483 			 * ----------
6484 			 */
6485 			switch (cmp_abs(var1, var2))
6486 			{
6487 				case 0:
6488 					/* ----------
6489 					 * ABS(var1) == ABS(var2)
6490 					 * result = ZERO
6491 					 * ----------
6492 					 */
6493 					zero_var(result);
6494 					result->dscale = Max(var1->dscale, var2->dscale);
6495 					break;
6496 
6497 				case 1:
6498 					/* ----------
6499 					 * ABS(var1) > ABS(var2)
6500 					 * result = +(ABS(var1) - ABS(var2))
6501 					 * ----------
6502 					 */
6503 					sub_abs(var1, var2, result);
6504 					result->sign = NUMERIC_POS;
6505 					break;
6506 
6507 				case -1:
6508 					/* ----------
6509 					 * ABS(var1) < ABS(var2)
6510 					 * result = -(ABS(var2) - ABS(var1))
6511 					 * ----------
6512 					 */
6513 					sub_abs(var2, var1, result);
6514 					result->sign = NUMERIC_NEG;
6515 					break;
6516 			}
6517 		}
6518 	}
6519 	else
6520 	{
6521 		if (var2->sign == NUMERIC_NEG)
6522 		{
6523 			/* ----------
6524 			 * Both are negative
6525 			 * Must compare absolute values
6526 			 * ----------
6527 			 */
6528 			switch (cmp_abs(var1, var2))
6529 			{
6530 				case 0:
6531 					/* ----------
6532 					 * ABS(var1) == ABS(var2)
6533 					 * result = ZERO
6534 					 * ----------
6535 					 */
6536 					zero_var(result);
6537 					result->dscale = Max(var1->dscale, var2->dscale);
6538 					break;
6539 
6540 				case 1:
6541 					/* ----------
6542 					 * ABS(var1) > ABS(var2)
6543 					 * result = -(ABS(var1) - ABS(var2))
6544 					 * ----------
6545 					 */
6546 					sub_abs(var1, var2, result);
6547 					result->sign = NUMERIC_NEG;
6548 					break;
6549 
6550 				case -1:
6551 					/* ----------
6552 					 * ABS(var1) < ABS(var2)
6553 					 * result = +(ABS(var2) - ABS(var1))
6554 					 * ----------
6555 					 */
6556 					sub_abs(var2, var1, result);
6557 					result->sign = NUMERIC_POS;
6558 					break;
6559 			}
6560 		}
6561 		else
6562 		{
6563 			/* ----------
6564 			 * var1 is negative, var2 is positive
6565 			 * result = -(ABS(var1) + ABS(var2))
6566 			 * ----------
6567 			 */
6568 			add_abs(var1, var2, result);
6569 			result->sign = NUMERIC_NEG;
6570 		}
6571 	}
6572 }
6573 
6574 
6575 /*
6576  * mul_var() -
6577  *
6578  *	Multiplication on variable level. Product of var1 * var2 is stored
6579  *	in result.  Result is rounded to no more than rscale fractional digits.
6580  */
6581 static void
mul_var(NumericVar * var1,NumericVar * var2,NumericVar * result,int rscale)6582 mul_var(NumericVar *var1, NumericVar *var2, NumericVar *result,
6583 		int rscale)
6584 {
6585 	int			res_ndigits;
6586 	int			res_sign;
6587 	int			res_weight;
6588 	int			maxdigits;
6589 	int		   *dig;
6590 	int			carry;
6591 	int			maxdig;
6592 	int			newdig;
6593 	int			var1ndigits;
6594 	int			var2ndigits;
6595 	NumericDigit *var1digits;
6596 	NumericDigit *var2digits;
6597 	NumericDigit *res_digits;
6598 	int			i,
6599 				i1,
6600 				i2;
6601 
6602 	/*
6603 	 * Arrange for var1 to be the shorter of the two numbers.  This improves
6604 	 * performance because the inner multiplication loop is much simpler than
6605 	 * the outer loop, so it's better to have a smaller number of iterations
6606 	 * of the outer loop.  This also reduces the number of times that the
6607 	 * accumulator array needs to be normalized.
6608 	 */
6609 	if (var1->ndigits > var2->ndigits)
6610 	{
6611 		NumericVar *tmp = var1;
6612 
6613 		var1 = var2;
6614 		var2 = tmp;
6615 	}
6616 
6617 	/* copy these values into local vars for speed in inner loop */
6618 	var1ndigits = var1->ndigits;
6619 	var2ndigits = var2->ndigits;
6620 	var1digits = var1->digits;
6621 	var2digits = var2->digits;
6622 
6623 	if (var1ndigits == 0 || var2ndigits == 0)
6624 	{
6625 		/* one or both inputs is zero; so is result */
6626 		zero_var(result);
6627 		result->dscale = rscale;
6628 		return;
6629 	}
6630 
6631 	/* Determine result sign and (maximum possible) weight */
6632 	if (var1->sign == var2->sign)
6633 		res_sign = NUMERIC_POS;
6634 	else
6635 		res_sign = NUMERIC_NEG;
6636 	res_weight = var1->weight + var2->weight + 2;
6637 
6638 	/*
6639 	 * Determine the number of result digits to compute.  If the exact result
6640 	 * would have more than rscale fractional digits, truncate the computation
6641 	 * with MUL_GUARD_DIGITS guard digits, i.e., ignore input digits that
6642 	 * would only contribute to the right of that.  (This will give the exact
6643 	 * rounded-to-rscale answer unless carries out of the ignored positions
6644 	 * would have propagated through more than MUL_GUARD_DIGITS digits.)
6645 	 *
6646 	 * Note: an exact computation could not produce more than var1ndigits +
6647 	 * var2ndigits digits, but we allocate one extra output digit in case
6648 	 * rscale-driven rounding produces a carry out of the highest exact digit.
6649 	 */
6650 	res_ndigits = var1ndigits + var2ndigits + 1;
6651 	maxdigits = res_weight + 1 + (rscale + DEC_DIGITS - 1) / DEC_DIGITS +
6652 		MUL_GUARD_DIGITS;
6653 	res_ndigits = Min(res_ndigits, maxdigits);
6654 
6655 	if (res_ndigits < 3)
6656 	{
6657 		/* All input digits will be ignored; so result is zero */
6658 		zero_var(result);
6659 		result->dscale = rscale;
6660 		return;
6661 	}
6662 
6663 	/*
6664 	 * We do the arithmetic in an array "dig[]" of signed int's.  Since
6665 	 * INT_MAX is noticeably larger than NBASE*NBASE, this gives us headroom
6666 	 * to avoid normalizing carries immediately.
6667 	 *
6668 	 * maxdig tracks the maximum possible value of any dig[] entry; when this
6669 	 * threatens to exceed INT_MAX, we take the time to propagate carries.
6670 	 * Furthermore, we need to ensure that overflow doesn't occur during the
6671 	 * carry propagation passes either.  The carry values could be as much as
6672 	 * INT_MAX/NBASE, so really we must normalize when digits threaten to
6673 	 * exceed INT_MAX - INT_MAX/NBASE.
6674 	 *
6675 	 * To avoid overflow in maxdig itself, it actually represents the max
6676 	 * possible value divided by NBASE-1, ie, at the top of the loop it is
6677 	 * known that no dig[] entry exceeds maxdig * (NBASE-1).
6678 	 */
6679 	dig = (int *) palloc0(res_ndigits * sizeof(int));
6680 	maxdig = 0;
6681 
6682 	/*
6683 	 * The least significant digits of var1 should be ignored if they don't
6684 	 * contribute directly to the first res_ndigits digits of the result that
6685 	 * we are computing.
6686 	 *
6687 	 * Digit i1 of var1 and digit i2 of var2 are multiplied and added to digit
6688 	 * i1+i2+2 of the accumulator array, so we need only consider digits of
6689 	 * var1 for which i1 <= res_ndigits - 3.
6690 	 */
6691 	for (i1 = Min(var1ndigits - 1, res_ndigits - 3); i1 >= 0; i1--)
6692 	{
6693 		int			var1digit = var1digits[i1];
6694 
6695 		if (var1digit == 0)
6696 			continue;
6697 
6698 		/* Time to normalize? */
6699 		maxdig += var1digit;
6700 		if (maxdig > (INT_MAX - INT_MAX / NBASE) / (NBASE - 1))
6701 		{
6702 			/* Yes, do it */
6703 			carry = 0;
6704 			for (i = res_ndigits - 1; i >= 0; i--)
6705 			{
6706 				newdig = dig[i] + carry;
6707 				if (newdig >= NBASE)
6708 				{
6709 					carry = newdig / NBASE;
6710 					newdig -= carry * NBASE;
6711 				}
6712 				else
6713 					carry = 0;
6714 				dig[i] = newdig;
6715 			}
6716 			Assert(carry == 0);
6717 			/* Reset maxdig to indicate new worst-case */
6718 			maxdig = 1 + var1digit;
6719 		}
6720 
6721 		/*
6722 		 * Add the appropriate multiple of var2 into the accumulator.
6723 		 *
6724 		 * As above, digits of var2 can be ignored if they don't contribute,
6725 		 * so we only include digits for which i1+i2+2 <= res_ndigits - 1.
6726 		 */
6727 		for (i2 = Min(var2ndigits - 1, res_ndigits - i1 - 3), i = i1 + i2 + 2;
6728 			 i2 >= 0; i2--)
6729 			dig[i--] += var1digit * var2digits[i2];
6730 	}
6731 
6732 	/*
6733 	 * Now we do a final carry propagation pass to normalize the result, which
6734 	 * we combine with storing the result digits into the output. Note that
6735 	 * this is still done at full precision w/guard digits.
6736 	 */
6737 	alloc_var(result, res_ndigits);
6738 	res_digits = result->digits;
6739 	carry = 0;
6740 	for (i = res_ndigits - 1; i >= 0; i--)
6741 	{
6742 		newdig = dig[i] + carry;
6743 		if (newdig >= NBASE)
6744 		{
6745 			carry = newdig / NBASE;
6746 			newdig -= carry * NBASE;
6747 		}
6748 		else
6749 			carry = 0;
6750 		res_digits[i] = newdig;
6751 	}
6752 	Assert(carry == 0);
6753 
6754 	pfree(dig);
6755 
6756 	/*
6757 	 * Finally, round the result to the requested precision.
6758 	 */
6759 	result->weight = res_weight;
6760 	result->sign = res_sign;
6761 
6762 	/* Round to target rscale (and set result->dscale) */
6763 	round_var(result, rscale);
6764 
6765 	/* Strip leading and trailing zeroes */
6766 	strip_var(result);
6767 }
6768 
6769 
/*
 * div_var() -
 *
 *	Division on variable level. Quotient of var1 / var2 is stored in result.
 *	The quotient is figured to exactly rscale fractional digits.
 *	If round is true, it is rounded at the rscale'th digit; if false, it
 *	is truncated (towards zero) at that digit.
 *
 *	var1	- the dividend
 *	var2	- the divisor; a "division by zero" ERROR is reported if it has
 *			  no digits or if its first digit is zero
 *	result	- receives the quotient.  Both inputs are copied into a private
 *			  workspace before result is reallocated.
 *	rscale	- number of fractional decimal digits to produce
 *	round	- true to round at the rscale'th digit, false to truncate
 */
static void
div_var(NumericVar *var1, NumericVar *var2, NumericVar *result,
		int rscale, bool round)
{
	int			div_ndigits;
	int			res_ndigits;
	int			res_sign;
	int			res_weight;
	int			carry;
	int			borrow;
	int			divisor1;
	int			divisor2;
	NumericDigit *dividend;
	NumericDigit *divisor;
	NumericDigit *res_digits;
	int			i;
	int			j;

	/* copy these values into local vars for speed in inner loop */
	int			var1ndigits = var1->ndigits;
	int			var2ndigits = var2->ndigits;

	/*
	 * First of all division by zero check; we must not be handed an
	 * unnormalized divisor.  (A divisor whose leading digit is zero is
	 * reported as division by zero here, whatever its other digits are.)
	 */
	if (var2ndigits == 0 || var2->digits[0] == 0)
		ereport(ERROR,
				(errcode(ERRCODE_DIVISION_BY_ZERO),
				 errmsg("division by zero")));

	/*
	 * Now result zero check: a dividend with no digits yields zero, with
	 * the requested display scale.
	 */
	if (var1ndigits == 0)
	{
		zero_var(result);
		result->dscale = rscale;
		return;
	}

	/*
	 * Determine the result sign, weight and number of digits to calculate.
	 * The weight figured here is correct if the emitted quotient has no
	 * leading zero digits; otherwise strip_var() will fix things up.
	 */
	if (var1->sign == var2->sign)
		res_sign = NUMERIC_POS;
	else
		res_sign = NUMERIC_NEG;
	res_weight = var1->weight - var2->weight;
	/* The number of accurate result digits we need to produce: */
	res_ndigits = res_weight + 1 + (rscale + DEC_DIGITS - 1) / DEC_DIGITS;
	/* ... but always at least 1 */
	res_ndigits = Max(res_ndigits, 1);
	/* If rounding needed, figure one more digit to ensure correct result */
	if (round)
		res_ndigits++;

	/*
	 * The working dividend normally requires res_ndigits + var2ndigits
	 * digits, but make it at least var1ndigits so we can load all of var1
	 * into it.  (There will be an additional digit dividend[0] in the
	 * dividend space, but for consistency with Knuth's notation we don't
	 * count that in div_ndigits.)
	 */
	div_ndigits = res_ndigits + var2ndigits;
	div_ndigits = Max(div_ndigits, var1ndigits);

	/*
	 * We need a workspace with room for the working dividend (div_ndigits+1
	 * digits) plus room for the possibly-normalized divisor (var2ndigits
	 * digits).  It is convenient also to have a zero at divisor[0] with the
	 * actual divisor data in divisor[1 .. var2ndigits].  Transferring the
	 * digits into the workspace also allows us to realloc the result (which
	 * might be the same as either input var) before we begin the main loop.
	 * Note that we use palloc0 to ensure that divisor[0], dividend[0], and
	 * any additional dividend positions beyond var1ndigits, start out 0.
	 *
	 * Both arrays live in the one allocation: the divisor starts right
	 * after the div_ndigits + 1 dividend entries, so the single pfree() of
	 * dividend further down releases both.
	 */
	dividend = (NumericDigit *)
		palloc0((div_ndigits + var2ndigits + 2) * sizeof(NumericDigit));
	divisor = dividend + (div_ndigits + 1);
	memcpy(dividend + 1, var1->digits, var1ndigits * sizeof(NumericDigit));
	memcpy(divisor + 1, var2->digits, var2ndigits * sizeof(NumericDigit));

	/*
	 * Now we can realloc the result to hold the generated quotient digits.
	 * From here on var1->digits and var2->digits are no longer read; only
	 * the workspace copies are used.
	 */
	alloc_var(result, res_ndigits);
	res_digits = result->digits;

	if (var2ndigits == 1)
	{
		/*
		 * If there's only a single divisor digit, we can use a fast path (cf.
		 * Knuth section 4.3.1 exercise 16).
		 *
		 * "carry" holds the running remainder, which is combined with the
		 * next dividend digit on each iteration.
		 */
		divisor1 = divisor[1];
		carry = 0;
		for (i = 0; i < res_ndigits; i++)
		{
			carry = carry * NBASE + dividend[i + 1];
			res_digits[i] = carry / divisor1;
			carry = carry % divisor1;
		}
	}
	else
	{
		/*
		 * The full multiple-place algorithm is taken from Knuth volume 2,
		 * Algorithm 4.3.1D.
		 *
		 * We need the first divisor digit to be >= NBASE/2.  If it isn't,
		 * make it so by scaling up both the divisor and dividend by the
		 * factor "d".  (The reason for allocating dividend[0] above is to
		 * leave room for possible carry here.)
		 */
		if (divisor[1] < HALF_NBASE)
		{
			int			d = NBASE / (divisor[1] + 1);

			/* scale the divisor, least significant digit first */
			carry = 0;
			for (i = var2ndigits; i > 0; i--)
			{
				carry += divisor[i] * d;
				divisor[i] = carry % NBASE;
				carry = carry / NBASE;
			}
			Assert(carry == 0);
			carry = 0;
			/* at this point only var1ndigits of dividend can be nonzero */
			/* (the loop runs down to index 0 so dividend[0] takes any carry) */
			for (i = var1ndigits; i >= 0; i--)
			{
				carry += dividend[i] * d;
				dividend[i] = carry % NBASE;
				carry = carry / NBASE;
			}
			Assert(carry == 0);
			Assert(divisor[1] >= HALF_NBASE);
		}
		/* First 2 divisor digits are used repeatedly in main loop */
		divisor1 = divisor[1];
		divisor2 = divisor[2];

		/*
		 * Begin the main loop.  Each iteration of this loop produces the j'th
		 * quotient digit by dividing dividend[j .. j + var2ndigits] by the
		 * divisor; this is essentially the same as the common manual
		 * procedure for long division.
		 */
		for (j = 0; j < res_ndigits; j++)
		{
			/* Estimate quotient digit from the first two dividend digits */
			int			next2digits = dividend[j] * NBASE + dividend[j + 1];
			int			qhat;

			/*
			 * If next2digits are 0, then quotient digit must be 0 and there's
			 * no need to adjust the working dividend.  It's worth testing
			 * here to fall out ASAP when processing trailing zeroes in a
			 * dividend.
			 */
			if (next2digits == 0)
			{
				res_digits[j] = 0;
				continue;
			}

			/*
			 * When the leading dividend digit equals the leading divisor
			 * digit, next2digits / divisor1 would come out as NBASE or
			 * more, so start from the largest possible digit instead.
			 */
			if (dividend[j] == divisor1)
				qhat = NBASE - 1;
			else
				qhat = next2digits / divisor1;

			/*
			 * Adjust quotient digit if it's too large.  Knuth proves that
			 * after this step, the quotient digit will be either correct or
			 * just one too large.  (Note: it's OK to use dividend[j+2] here
			 * because we know the divisor length is at least 2.)
			 */
			while (divisor2 * qhat >
				   (next2digits - qhat * divisor1) * NBASE + dividend[j + 2])
				qhat--;

			/* As above, need do nothing more when quotient digit is 0 */
			if (qhat > 0)
			{
				/*
				 * Multiply the divisor by qhat, and subtract that from the
				 * working dividend.  "carry" tracks the multiplication,
				 * "borrow" the subtraction (could we fold these together?)
				 *
				 * The loop includes i = 0, where divisor[0] is zero, so
				 * that the final multiplication carry and any borrow are
				 * applied to dividend[j].
				 */
				carry = 0;
				borrow = 0;
				for (i = var2ndigits; i >= 0; i--)
				{
					carry += divisor[i] * qhat;
					borrow -= carry % NBASE;
					carry = carry / NBASE;
					borrow += dividend[j + i];
					if (borrow < 0)
					{
						dividend[j + i] = borrow + NBASE;
						borrow = -1;
					}
					else
					{
						dividend[j + i] = borrow;
						borrow = 0;
					}
				}
				Assert(carry == 0);

				/*
				 * If we got a borrow out of the top dividend digit, then
				 * indeed qhat was one too large.  Fix it, and add back the
				 * divisor to correct the working dividend.  (Knuth proves
				 * that this will occur only about 3/NBASE of the time; hence,
				 * it's a good idea to test this code with small NBASE to be
				 * sure this section gets exercised.)
				 */
				if (borrow)
				{
					qhat--;
					carry = 0;
					for (i = var2ndigits; i >= 0; i--)
					{
						carry += dividend[j + i] + divisor[i];
						if (carry >= NBASE)
						{
							dividend[j + i] = carry - NBASE;
							carry = 1;
						}
						else
						{
							dividend[j + i] = carry;
							carry = 0;
						}
					}
					/* A carry should occur here to cancel the borrow above */
					Assert(carry == 1);
				}
			}

			/* And we're done with this quotient digit */
			res_digits[j] = qhat;
		}
	}

	/* This releases the divisor copy too, since it shares the allocation */
	pfree(dividend);

	/*
	 * Finally, round or truncate the result to the requested precision.
	 */
	result->weight = res_weight;
	result->sign = res_sign;

	/* Round or truncate to target rscale (and set result->dscale) */
	if (round)
		round_var(result, rscale);
	else
		trunc_var(result, rscale);

	/* Strip leading and trailing zeroes */
	strip_var(result);
}
7043 
7044 
/*
 * div_var_fast() -
 *
 *	This has the same API as div_var, but is implemented using the division
 *	algorithm from the "FM" library, rather than Knuth's schoolbook-division
 *	approach.  This is significantly faster but can produce inaccurate
 *	results, because it sometimes has to propagate rounding to the left,
 *	and so we can never be entirely sure that we know the requested digits
 *	exactly.  We compute DIV_GUARD_DIGITS extra digits, but there is
 *	no certainty that that's enough.  We use this only in the transcendental
 *	function calculation routines, where everything is approximate anyway.
 *
 *	Although we provide a "round" argument for consistency with div_var,
 *	it is unwise to use this function with round=false.  In truncation mode
 *	it is possible to get a result with no significant digits, for example
 *	with rscale=0 we might compute 0.99999... and truncate that to 0 when
 *	the correct answer is 1.
 *
 *	var1	- the dividend
 *	var2	- the divisor; a "division by zero" ERROR is reported if it has
 *			  no digits or if its first digit is zero
 *	result	- receives the (approximate) quotient.  It is reallocated only
 *			  after the last use of the input digit arrays.
 *	rscale	- number of fractional decimal digits to produce
 *	round	- true to round at the rscale'th digit, false to truncate
 */
static void
div_var_fast(NumericVar *var1, NumericVar *var2, NumericVar *result,
			 int rscale, bool round)
{
	int			div_ndigits;
	int			res_sign;
	int			res_weight;
	int		   *div;
	int			qdigit;
	int			carry;
	int			maxdiv;
	int			newdig;
	NumericDigit *res_digits;
	double		fdividend,
				fdivisor,
				fdivisorinverse,
				fquotient;
	int			qi;
	int			i;

	/* copy these values into local vars for speed in inner loop */
	int			var1ndigits = var1->ndigits;
	int			var2ndigits = var2->ndigits;
	NumericDigit *var1digits = var1->digits;
	NumericDigit *var2digits = var2->digits;

	/*
	 * First of all division by zero check; we must not be handed an
	 * unnormalized divisor.
	 */
	if (var2ndigits == 0 || var2digits[0] == 0)
		ereport(ERROR,
				(errcode(ERRCODE_DIVISION_BY_ZERO),
				 errmsg("division by zero")));

	/*
	 * Now result zero check
	 */
	if (var1ndigits == 0)
	{
		zero_var(result);
		result->dscale = rscale;
		return;
	}

	/*
	 * Determine the result sign, weight and number of digits to calculate
	 */
	if (var1->sign == var2->sign)
		res_sign = NUMERIC_POS;
	else
		res_sign = NUMERIC_NEG;
	res_weight = var1->weight - var2->weight + 1;
	/* The number of accurate result digits we need to produce: */
	div_ndigits = res_weight + 1 + (rscale + DEC_DIGITS - 1) / DEC_DIGITS;
	/* Add guard digits for roundoff error */
	div_ndigits += DIV_GUARD_DIGITS;
	if (div_ndigits < DIV_GUARD_DIGITS)
		div_ndigits = DIV_GUARD_DIGITS;
	/* Must be at least var1ndigits, too, to simplify data-loading loop */
	if (div_ndigits < var1ndigits)
		div_ndigits = var1ndigits;

	/*
	 * We do the arithmetic in an array "div[]" of signed int's.  Since
	 * INT_MAX is noticeably larger than NBASE*NBASE, this gives us headroom
	 * to avoid normalizing carries immediately.
	 *
	 * We start with div[] containing one zero digit followed by the
	 * dividend's digits (plus appended zeroes to reach the desired precision
	 * including guard digits).  Each step of the main loop computes an
	 * (approximate) quotient digit and stores it into div[], removing one
	 * position of dividend space.  A final pass of carry propagation takes
	 * care of any mistaken quotient digits.
	 *
	 * The array has div_ndigits + 1 entries, so valid indexes run from 0
	 * through div_ndigits inclusive.
	 */
	div = (int *) palloc0((div_ndigits + 1) * sizeof(int));
	for (i = 0; i < var1ndigits; i++)
		div[i + 1] = var1digits[i];

	/*
	 * We estimate each quotient digit using floating-point arithmetic, taking
	 * the first four digits of the (current) dividend and divisor.  This must
	 * be float to avoid overflow.  The quotient digits will generally be off
	 * by no more than one from the exact answer.
	 *
	 * A divisor with fewer than four digits is treated as if padded with
	 * zeroes.  The reciprocal is computed once here so that the main loop
	 * can multiply instead of divide.
	 */
	fdivisor = (double) var2digits[0];
	for (i = 1; i < 4; i++)
	{
		fdivisor *= NBASE;
		if (i < var2ndigits)
			fdivisor += (double) var2digits[i];
	}
	fdivisorinverse = 1.0 / fdivisor;

	/*
	 * maxdiv tracks the maximum possible absolute value of any div[] entry;
	 * when this threatens to exceed INT_MAX, we take the time to propagate
	 * carries.  Furthermore, we need to ensure that overflow doesn't occur
	 * during the carry propagation passes either.  The carry values may have
	 * an absolute value as high as INT_MAX/NBASE + 1, so really we must
	 * normalize when digits threaten to exceed INT_MAX - INT_MAX/NBASE - 1.
	 *
	 * To avoid overflow in maxdiv itself, it represents the max absolute
	 * value divided by NBASE-1, ie, at the top of the loop it is known that
	 * no div[] entry has an absolute value exceeding maxdiv * (NBASE-1).
	 *
	 * Actually, though, that holds good only for div[] entries after div[qi];
	 * the adjustment done at the bottom of the loop may cause div[qi + 1] to
	 * exceed the maxdiv limit, so that div[qi] in the next iteration is
	 * beyond the limit.  This does not cause problems, as explained below.
	 */
	maxdiv = 1;

	/*
	 * Outer loop computes next quotient digit, which will go into div[qi]
	 */
	for (qi = 0; qi < div_ndigits; qi++)
	{
		/*
		 * Approximate the current dividend value.  Positions past the end
		 * of div[] (index > div_ndigits) count as zero.
		 */
		fdividend = (double) div[qi];
		for (i = 1; i < 4; i++)
		{
			fdividend *= NBASE;
			if (qi + i <= div_ndigits)
				fdividend += (double) div[qi + i];
		}
		/* Compute the (approximate) quotient digit */
		fquotient = fdividend * fdivisorinverse;
		qdigit = (fquotient >= 0.0) ? ((int) fquotient) :
			(((int) fquotient) - 1);	/* truncate towards -infinity */

		if (qdigit != 0)
		{
			/* Do we need to normalize now? */
			maxdiv += Abs(qdigit);
			if (maxdiv > (INT_MAX - INT_MAX / NBASE - 1) / (NBASE - 1))
			{
				/*
				 * Yes, do it.  Entries may be negative here, so the carry
				 * is computed so as to leave each digit in 0..NBASE-1.
				 */
				carry = 0;
				for (i = div_ndigits; i > qi; i--)
				{
					newdig = div[i] + carry;
					if (newdig < 0)
					{
						carry = -((-newdig - 1) / NBASE) - 1;
						newdig -= carry * NBASE;
					}
					else if (newdig >= NBASE)
					{
						carry = newdig / NBASE;
						newdig -= carry * NBASE;
					}
					else
						carry = 0;
					div[i] = newdig;
				}
				/* div[qi] absorbs the final carry without being reduced */
				newdig = div[qi] + carry;
				div[qi] = newdig;

				/*
				 * All the div[] digits except possibly div[qi] are now in the
				 * range 0..NBASE-1.  We do not need to consider div[qi] in
				 * the maxdiv value anymore, so we can reset maxdiv to 1.
				 */
				maxdiv = 1;

				/*
				 * Recompute the quotient digit since new info may have
				 * propagated into the top four dividend digits
				 */
				fdividend = (double) div[qi];
				for (i = 1; i < 4; i++)
				{
					fdividend *= NBASE;
					if (qi + i <= div_ndigits)
						fdividend += (double) div[qi + i];
				}
				/* Compute the (approximate) quotient digit */
				fquotient = fdividend * fdivisorinverse;
				qdigit = (fquotient >= 0.0) ? ((int) fquotient) :
					(((int) fquotient) - 1);	/* truncate towards -infinity */
				maxdiv += Abs(qdigit);
			}

			/*
			 * Subtract off the appropriate multiple of the divisor.
			 *
			 * The digits beyond div[qi] cannot overflow, because we know they
			 * will fall within the maxdiv limit.  As for div[qi] itself, note
			 * that qdigit is approximately trunc(div[qi] / var2digits[0]),
			 * which would make the new value simply div[qi] mod var2digits[0].
			 * The lower-order terms in qdigit can change this result by not
			 * more than about twice INT_MAX/NBASE, so overflow is impossible.
			 *
			 * qdigit is tested again because the recomputation in the
			 * normalization branch above may have changed it to zero.
			 * istop limits the subtraction to entries that exist in div[].
			 */
			if (qdigit != 0)
			{
				int			istop = Min(var2ndigits, div_ndigits - qi + 1);

				for (i = 0; i < istop; i++)
					div[qi + i] -= qdigit * var2digits[i];
			}
		}

		/*
		 * The dividend digit we are about to replace might still be nonzero.
		 * Fold it into the next digit position.
		 *
		 * There is no risk of overflow here, although proving that requires
		 * some care.  Much as with the argument for div[qi] not overflowing,
		 * if we consider the first two terms in the numerator and denominator
		 * of qdigit, we can see that the final value of div[qi + 1] will be
		 * approximately a remainder mod (var2digits[0]*NBASE + var2digits[1]).
		 * Accounting for the lower-order terms is a bit complicated but ends
		 * up adding not much more than INT_MAX/NBASE to the possible range.
		 * Thus, div[qi + 1] cannot overflow here, and in its role as div[qi]
		 * in the next loop iteration, it can't be large enough to cause
		 * overflow in the carry propagation step (if any), either.
		 *
		 * But having said that: div[qi] can be more than INT_MAX/NBASE, as
		 * noted above, which means that the product div[qi] * NBASE *can*
		 * overflow.  When that happens, adding it to div[qi + 1] will always
		 * cause a canceling overflow so that the end result is correct.  We
		 * could avoid the intermediate overflow by doing the multiplication
		 * and addition in int64 arithmetic, but so far there appears no need.
		 *
		 * NOTE(review): the canceling-overflow argument presumes that signed
		 * int arithmetic wraps around; confirm that the build guarantees
		 * that behavior.
		 */
		div[qi + 1] += div[qi] * NBASE;

		div[qi] = qdigit;
	}

	/*
	 * Approximate and store the last quotient digit (div[div_ndigits]).
	 * The loop above exits with qi == div_ndigits; no dividend digits
	 * remain beyond this one, so the estimate uses div[qi] alone, scaled
	 * up by NBASE three times to match the four-digit divisor estimate.
	 */
	fdividend = (double) div[qi];
	for (i = 1; i < 4; i++)
		fdividend *= NBASE;
	fquotient = fdividend * fdivisorinverse;
	qdigit = (fquotient >= 0.0) ? ((int) fquotient) :
		(((int) fquotient) - 1);	/* truncate towards -infinity */
	div[qi] = qdigit;

	/*
	 * Because the quotient digits might be off by one, some of them might be
	 * -1 or NBASE at this point.  The represented value is correct in a
	 * mathematical sense, but it doesn't look right.  We do a final carry
	 * propagation pass to normalize the digits, which we combine with storing
	 * the result digits into the output.  Note that this is still done at
	 * full precision w/guard digits.
	 */
	alloc_var(result, div_ndigits + 1);
	res_digits = result->digits;
	carry = 0;
	for (i = div_ndigits; i >= 0; i--)
	{
		newdig = div[i] + carry;
		if (newdig < 0)
		{
			carry = -((-newdig - 1) / NBASE) - 1;
			newdig -= carry * NBASE;
		}
		else if (newdig >= NBASE)
		{
			carry = newdig / NBASE;
			newdig -= carry * NBASE;
		}
		else
			carry = 0;
		res_digits[i] = newdig;
	}
	Assert(carry == 0);

	pfree(div);

	/*
	 * Finally, round the result to the requested precision.
	 */
	result->weight = res_weight;
	result->sign = res_sign;

	/* Round to target rscale (and set result->dscale) */
	if (round)
		round_var(result, rscale);
	else
		trunc_var(result, rscale);

	/* Strip leading and trailing zeroes */
	strip_var(result);
}
7351 
7352 
7353 /*
7354  * Default scale selection for division
7355  *
7356  * Returns the appropriate result scale for the division result.
7357  */
7358 static int
select_div_scale(NumericVar * var1,NumericVar * var2)7359 select_div_scale(NumericVar *var1, NumericVar *var2)
7360 {
7361 	int			weight1,
7362 				weight2,
7363 				qweight,
7364 				i;
7365 	NumericDigit firstdigit1,
7366 				firstdigit2;
7367 	int			rscale;
7368 
7369 	/*
7370 	 * The result scale of a division isn't specified in any SQL standard. For
7371 	 * PostgreSQL we select a result scale that will give at least
7372 	 * NUMERIC_MIN_SIG_DIGITS significant digits, so that numeric gives a
7373 	 * result no less accurate than float8; but use a scale not less than
7374 	 * either input's display scale.
7375 	 */
7376 
7377 	/* Get the actual (normalized) weight and first digit of each input */
7378 
7379 	weight1 = 0;				/* values to use if var1 is zero */
7380 	firstdigit1 = 0;
7381 	for (i = 0; i < var1->ndigits; i++)
7382 	{
7383 		firstdigit1 = var1->digits[i];
7384 		if (firstdigit1 != 0)
7385 		{
7386 			weight1 = var1->weight - i;
7387 			break;
7388 		}
7389 	}
7390 
7391 	weight2 = 0;				/* values to use if var2 is zero */
7392 	firstdigit2 = 0;
7393 	for (i = 0; i < var2->ndigits; i++)
7394 	{
7395 		firstdigit2 = var2->digits[i];
7396 		if (firstdigit2 != 0)
7397 		{
7398 			weight2 = var2->weight - i;
7399 			break;
7400 		}
7401 	}
7402 
7403 	/*
7404 	 * Estimate weight of quotient.  If the two first digits are equal, we
7405 	 * can't be sure, but assume that var1 is less than var2.
7406 	 */
7407 	qweight = weight1 - weight2;
7408 	if (firstdigit1 <= firstdigit2)
7409 		qweight--;
7410 
7411 	/* Select result scale */
7412 	rscale = NUMERIC_MIN_SIG_DIGITS - qweight * DEC_DIGITS;
7413 	rscale = Max(rscale, var1->dscale);
7414 	rscale = Max(rscale, var2->dscale);
7415 	rscale = Max(rscale, NUMERIC_MIN_DISPLAY_SCALE);
7416 	rscale = Min(rscale, NUMERIC_MAX_DISPLAY_SCALE);
7417 
7418 	return rscale;
7419 }
7420 
7421 
7422 /*
7423  * mod_var() -
7424  *
7425  *	Calculate the modulo of two numerics at variable level
7426  */
7427 static void
mod_var(NumericVar * var1,NumericVar * var2,NumericVar * result)7428 mod_var(NumericVar *var1, NumericVar *var2, NumericVar *result)
7429 {
7430 	NumericVar	tmp;
7431 
7432 	init_var(&tmp);
7433 
7434 	/* ---------
7435 	 * We do this using the equation
7436 	 *		mod(x,y) = x - trunc(x/y)*y
7437 	 * div_var can be persuaded to give us trunc(x/y) directly.
7438 	 * ----------
7439 	 */
7440 	div_var(var1, var2, &tmp, 0, false);
7441 
7442 	mul_var(var2, &tmp, &tmp, var2->dscale);
7443 
7444 	sub_var(var1, &tmp, result);
7445 
7446 	free_var(&tmp);
7447 }
7448 
7449 
7450 /*
7451  * ceil_var() -
7452  *
7453  *	Return the smallest integer greater than or equal to the argument
7454  *	on variable level
7455  */
7456 static void
ceil_var(NumericVar * var,NumericVar * result)7457 ceil_var(NumericVar *var, NumericVar *result)
7458 {
7459 	NumericVar	tmp;
7460 
7461 	init_var(&tmp);
7462 	set_var_from_var(var, &tmp);
7463 
7464 	trunc_var(&tmp, 0);
7465 
7466 	if (var->sign == NUMERIC_POS && cmp_var(var, &tmp) != 0)
7467 		add_var(&tmp, &const_one, &tmp);
7468 
7469 	set_var_from_var(&tmp, result);
7470 	free_var(&tmp);
7471 }
7472 
7473 
7474 /*
7475  * floor_var() -
7476  *
7477  *	Return the largest integer equal to or less than the argument
7478  *	on variable level
7479  */
7480 static void
floor_var(NumericVar * var,NumericVar * result)7481 floor_var(NumericVar *var, NumericVar *result)
7482 {
7483 	NumericVar	tmp;
7484 
7485 	init_var(&tmp);
7486 	set_var_from_var(var, &tmp);
7487 
7488 	trunc_var(&tmp, 0);
7489 
7490 	if (var->sign == NUMERIC_NEG && cmp_var(var, &tmp) != 0)
7491 		sub_var(&tmp, &const_one, &tmp);
7492 
7493 	set_var_from_var(&tmp, result);
7494 	free_var(&tmp);
7495 }
7496 
7497 
7498 /*
7499  * sqrt_var() -
7500  *
7501  *	Compute the square root of x using Newton's algorithm
7502  */
7503 static void
sqrt_var(NumericVar * arg,NumericVar * result,int rscale)7504 sqrt_var(NumericVar *arg, NumericVar *result, int rscale)
7505 {
7506 	NumericVar	tmp_arg;
7507 	NumericVar	tmp_val;
7508 	NumericVar	last_val;
7509 	int			local_rscale;
7510 	int			stat;
7511 
7512 	local_rscale = rscale + 8;
7513 
7514 	stat = cmp_var(arg, &const_zero);
7515 	if (stat == 0)
7516 	{
7517 		zero_var(result);
7518 		result->dscale = rscale;
7519 		return;
7520 	}
7521 
7522 	/*
7523 	 * SQL2003 defines sqrt() in terms of power, so we need to emit the right
7524 	 * SQLSTATE error code if the operand is negative.
7525 	 */
7526 	if (stat < 0)
7527 		ereport(ERROR,
7528 				(errcode(ERRCODE_INVALID_ARGUMENT_FOR_POWER_FUNCTION),
7529 				 errmsg("cannot take square root of a negative number")));
7530 
7531 	init_var(&tmp_arg);
7532 	init_var(&tmp_val);
7533 	init_var(&last_val);
7534 
7535 	/* Copy arg in case it is the same var as result */
7536 	set_var_from_var(arg, &tmp_arg);
7537 
7538 	/*
7539 	 * Initialize the result to the first guess
7540 	 */
7541 	alloc_var(result, 1);
7542 	result->digits[0] = tmp_arg.digits[0] / 2;
7543 	if (result->digits[0] == 0)
7544 		result->digits[0] = 1;
7545 	result->weight = tmp_arg.weight / 2;
7546 	result->sign = NUMERIC_POS;
7547 
7548 	set_var_from_var(result, &last_val);
7549 
7550 	for (;;)
7551 	{
7552 		div_var_fast(&tmp_arg, result, &tmp_val, local_rscale, true);
7553 
7554 		add_var(result, &tmp_val, result);
7555 		mul_var(result, &const_zero_point_five, result, local_rscale);
7556 
7557 		if (cmp_var(&last_val, result) == 0)
7558 			break;
7559 		set_var_from_var(result, &last_val);
7560 	}
7561 
7562 	free_var(&last_val);
7563 	free_var(&tmp_val);
7564 	free_var(&tmp_arg);
7565 
7566 	/* Round to requested precision */
7567 	round_var(result, rscale);
7568 }
7569 
7570 
7571 /*
7572  * exp_var() -
7573  *
7574  *	Raise e to the power of x, computed to rscale fractional digits
7575  */
7576 static void
exp_var(NumericVar * arg,NumericVar * result,int rscale)7577 exp_var(NumericVar *arg, NumericVar *result, int rscale)
7578 {
7579 	NumericVar	x;
7580 	NumericVar	elem;
7581 	NumericVar	ni;
7582 	double		val;
7583 	int			dweight;
7584 	int			ndiv2;
7585 	int			sig_digits;
7586 	int			local_rscale;
7587 
7588 	init_var(&x);
7589 	init_var(&elem);
7590 	init_var(&ni);
7591 
7592 	set_var_from_var(arg, &x);
7593 
7594 	/*
7595 	 * Estimate the dweight of the result using floating point arithmetic, so
7596 	 * that we can choose an appropriate local rscale for the calculation.
7597 	 */
7598 	val = numericvar_to_double_no_overflow(&x);
7599 
7600 	/* Guard against overflow/underflow */
7601 	/* If you change this limit, see also power_var()'s limit */
7602 	if (Abs(val) >= NUMERIC_MAX_RESULT_SCALE * 3)
7603 	{
7604 		if (val > 0)
7605 			ereport(ERROR,
7606 					(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
7607 					 errmsg("value overflows numeric format")));
7608 		zero_var(result);
7609 		result->dscale = rscale;
7610 		return;
7611 	}
7612 
7613 	/* decimal weight = log10(e^x) = x * log10(e) */
7614 	dweight = (int) (val * 0.434294481903252);
7615 
7616 	/*
7617 	 * Reduce x to the range -0.01 <= x <= 0.01 (approximately) by dividing by
7618 	 * 2^n, to improve the convergence rate of the Taylor series.
7619 	 */
7620 	if (Abs(val) > 0.01)
7621 	{
7622 		NumericVar	tmp;
7623 
7624 		init_var(&tmp);
7625 		set_var_from_var(&const_two, &tmp);
7626 
7627 		ndiv2 = 1;
7628 		val /= 2;
7629 
7630 		while (Abs(val) > 0.01)
7631 		{
7632 			ndiv2++;
7633 			val /= 2;
7634 			add_var(&tmp, &tmp, &tmp);
7635 		}
7636 
7637 		local_rscale = x.dscale + ndiv2;
7638 		div_var_fast(&x, &tmp, &x, local_rscale, true);
7639 
7640 		free_var(&tmp);
7641 	}
7642 	else
7643 		ndiv2 = 0;
7644 
7645 	/*
7646 	 * Set the scale for the Taylor series expansion.  The final result has
7647 	 * (dweight + rscale + 1) significant digits.  In addition, we have to
7648 	 * raise the Taylor series result to the power 2^ndiv2, which introduces
7649 	 * an error of up to around log10(2^ndiv2) digits, so work with this many
7650 	 * extra digits of precision (plus a few more for good measure).
7651 	 */
7652 	sig_digits = 1 + dweight + rscale + (int) (ndiv2 * 0.301029995663981);
7653 	sig_digits = Max(sig_digits, 0) + 8;
7654 
7655 	local_rscale = sig_digits - 1;
7656 
7657 	/*
7658 	 * Use the Taylor series
7659 	 *
7660 	 * exp(x) = 1 + x + x^2/2! + x^3/3! + ...
7661 	 *
7662 	 * Given the limited range of x, this should converge reasonably quickly.
7663 	 * We run the series until the terms fall below the local_rscale limit.
7664 	 */
7665 	add_var(&const_one, &x, result);
7666 
7667 	mul_var(&x, &x, &elem, local_rscale);
7668 	set_var_from_var(&const_two, &ni);
7669 	div_var_fast(&elem, &ni, &elem, local_rscale, true);
7670 
7671 	while (elem.ndigits != 0)
7672 	{
7673 		add_var(result, &elem, result);
7674 
7675 		mul_var(&elem, &x, &elem, local_rscale);
7676 		add_var(&ni, &const_one, &ni);
7677 		div_var_fast(&elem, &ni, &elem, local_rscale, true);
7678 	}
7679 
7680 	/*
7681 	 * Compensate for the argument range reduction.  Since the weight of the
7682 	 * result doubles with each multiplication, we can reduce the local rscale
7683 	 * as we proceed.
7684 	 */
7685 	while (ndiv2-- > 0)
7686 	{
7687 		local_rscale = sig_digits - result->weight * 2 * DEC_DIGITS;
7688 		local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
7689 		mul_var(result, result, result, local_rscale);
7690 	}
7691 
7692 	/* Round to requested rscale */
7693 	round_var(result, rscale);
7694 
7695 	free_var(&x);
7696 	free_var(&elem);
7697 	free_var(&ni);
7698 }
7699 
7700 
7701 /*
7702  * Estimate the dweight of the most significant decimal digit of the natural
7703  * logarithm of a number.
7704  *
7705  * Essentially, we're approximating log10(abs(ln(var))).  This is used to
7706  * determine the appropriate rscale when computing natural logarithms.
7707  */
7708 static int
estimate_ln_dweight(NumericVar * var)7709 estimate_ln_dweight(NumericVar *var)
7710 {
7711 	int			ln_dweight;
7712 
7713 	if (cmp_var(var, &const_zero_point_nine) >= 0 &&
7714 		cmp_var(var, &const_one_point_one) <= 0)
7715 	{
7716 		/*
7717 		 * 0.9 <= var <= 1.1
7718 		 *
7719 		 * ln(var) has a negative weight (possibly very large).  To get a
7720 		 * reasonably accurate result, estimate it using ln(1+x) ~= x.
7721 		 */
7722 		NumericVar	x;
7723 
7724 		init_var(&x);
7725 		sub_var(var, &const_one, &x);
7726 
7727 		if (x.ndigits > 0)
7728 		{
7729 			/* Use weight of most significant decimal digit of x */
7730 			ln_dweight = x.weight * DEC_DIGITS + (int) log10(x.digits[0]);
7731 		}
7732 		else
7733 		{
7734 			/* x = 0.  Since ln(1) = 0 exactly, we don't need extra digits */
7735 			ln_dweight = 0;
7736 		}
7737 
7738 		free_var(&x);
7739 	}
7740 	else
7741 	{
7742 		/*
7743 		 * Estimate the logarithm using the first couple of digits from the
7744 		 * input number.  This will give an accurate result whenever the input
7745 		 * is not too close to 1.
7746 		 */
7747 		if (var->ndigits > 0)
7748 		{
7749 			int			digits;
7750 			int			dweight;
7751 			double		ln_var;
7752 
7753 			digits = var->digits[0];
7754 			dweight = var->weight * DEC_DIGITS;
7755 
7756 			if (var->ndigits > 1)
7757 			{
7758 				digits = digits * NBASE + var->digits[1];
7759 				dweight -= DEC_DIGITS;
7760 			}
7761 
7762 			/*----------
7763 			 * We have var ~= digits * 10^dweight
7764 			 * so ln(var) ~= ln(digits) + dweight * ln(10)
7765 			 *----------
7766 			 */
7767 			ln_var = log((double) digits) + dweight * 2.302585092994046;
7768 			ln_dweight = (int) log10(Abs(ln_var));
7769 		}
7770 		else
7771 		{
7772 			/* Caller should fail on ln(0), but for the moment return zero */
7773 			ln_dweight = 0;
7774 		}
7775 	}
7776 
7777 	return ln_dweight;
7778 }
7779 
7780 
7781 /*
7782  * ln_var() -
7783  *
7784  *	Compute the natural log of x
7785  */
7786 static void
ln_var(NumericVar * arg,NumericVar * result,int rscale)7787 ln_var(NumericVar *arg, NumericVar *result, int rscale)
7788 {
7789 	NumericVar	x;
7790 	NumericVar	xx;
7791 	NumericVar	ni;
7792 	NumericVar	elem;
7793 	NumericVar	fact;
7794 	int			local_rscale;
7795 	int			cmp;
7796 
7797 	cmp = cmp_var(arg, &const_zero);
7798 	if (cmp == 0)
7799 		ereport(ERROR,
7800 				(errcode(ERRCODE_INVALID_ARGUMENT_FOR_LOG),
7801 				 errmsg("cannot take logarithm of zero")));
7802 	else if (cmp < 0)
7803 		ereport(ERROR,
7804 				(errcode(ERRCODE_INVALID_ARGUMENT_FOR_LOG),
7805 				 errmsg("cannot take logarithm of a negative number")));
7806 
7807 	init_var(&x);
7808 	init_var(&xx);
7809 	init_var(&ni);
7810 	init_var(&elem);
7811 	init_var(&fact);
7812 
7813 	set_var_from_var(arg, &x);
7814 	set_var_from_var(&const_two, &fact);
7815 
7816 	/*
7817 	 * Reduce input into range 0.9 < x < 1.1 with repeated sqrt() operations.
7818 	 *
7819 	 * The final logarithm will have up to around rscale+6 significant digits.
7820 	 * Each sqrt() will roughly halve the weight of x, so adjust the local
7821 	 * rscale as we work so that we keep this many significant digits at each
7822 	 * step (plus a few more for good measure).
7823 	 */
7824 	while (cmp_var(&x, &const_zero_point_nine) <= 0)
7825 	{
7826 		local_rscale = rscale - x.weight * DEC_DIGITS / 2 + 8;
7827 		local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
7828 		sqrt_var(&x, &x, local_rscale);
7829 		mul_var(&fact, &const_two, &fact, 0);
7830 	}
7831 	while (cmp_var(&x, &const_one_point_one) >= 0)
7832 	{
7833 		local_rscale = rscale - x.weight * DEC_DIGITS / 2 + 8;
7834 		local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
7835 		sqrt_var(&x, &x, local_rscale);
7836 		mul_var(&fact, &const_two, &fact, 0);
7837 	}
7838 
7839 	/*
7840 	 * We use the Taylor series for 0.5 * ln((1+z)/(1-z)),
7841 	 *
7842 	 * z + z^3/3 + z^5/5 + ...
7843 	 *
7844 	 * where z = (x-1)/(x+1) is in the range (approximately) -0.053 .. 0.048
7845 	 * due to the above range-reduction of x.
7846 	 *
7847 	 * The convergence of this is not as fast as one would like, but is
7848 	 * tolerable given that z is small.
7849 	 */
7850 	local_rscale = rscale + 8;
7851 
7852 	sub_var(&x, &const_one, result);
7853 	add_var(&x, &const_one, &elem);
7854 	div_var_fast(result, &elem, result, local_rscale, true);
7855 	set_var_from_var(result, &xx);
7856 	mul_var(result, result, &x, local_rscale);
7857 
7858 	set_var_from_var(&const_one, &ni);
7859 
7860 	for (;;)
7861 	{
7862 		add_var(&ni, &const_two, &ni);
7863 		mul_var(&xx, &x, &xx, local_rscale);
7864 		div_var_fast(&xx, &ni, &elem, local_rscale, true);
7865 
7866 		if (elem.ndigits == 0)
7867 			break;
7868 
7869 		add_var(result, &elem, result);
7870 
7871 		if (elem.weight < (result->weight - local_rscale * 2 / DEC_DIGITS))
7872 			break;
7873 	}
7874 
7875 	/* Compensate for argument range reduction, round to requested rscale */
7876 	mul_var(result, &fact, result, rscale);
7877 
7878 	free_var(&x);
7879 	free_var(&xx);
7880 	free_var(&ni);
7881 	free_var(&elem);
7882 	free_var(&fact);
7883 }
7884 
7885 
7886 /*
7887  * log_var() -
7888  *
7889  *	Compute the logarithm of num in a given base.
7890  *
7891  *	Note: this routine chooses dscale of the result.
7892  */
7893 static void
log_var(NumericVar * base,NumericVar * num,NumericVar * result)7894 log_var(NumericVar *base, NumericVar *num, NumericVar *result)
7895 {
7896 	NumericVar	ln_base;
7897 	NumericVar	ln_num;
7898 	int			ln_base_dweight;
7899 	int			ln_num_dweight;
7900 	int			result_dweight;
7901 	int			rscale;
7902 	int			ln_base_rscale;
7903 	int			ln_num_rscale;
7904 
7905 	init_var(&ln_base);
7906 	init_var(&ln_num);
7907 
7908 	/* Estimated dweights of ln(base), ln(num) and the final result */
7909 	ln_base_dweight = estimate_ln_dweight(base);
7910 	ln_num_dweight = estimate_ln_dweight(num);
7911 	result_dweight = ln_num_dweight - ln_base_dweight;
7912 
7913 	/*
7914 	 * Select the scale of the result so that it will have at least
7915 	 * NUMERIC_MIN_SIG_DIGITS significant digits and is not less than either
7916 	 * input's display scale.
7917 	 */
7918 	rscale = NUMERIC_MIN_SIG_DIGITS - result_dweight;
7919 	rscale = Max(rscale, base->dscale);
7920 	rscale = Max(rscale, num->dscale);
7921 	rscale = Max(rscale, NUMERIC_MIN_DISPLAY_SCALE);
7922 	rscale = Min(rscale, NUMERIC_MAX_DISPLAY_SCALE);
7923 
7924 	/*
7925 	 * Set the scales for ln(base) and ln(num) so that they each have more
7926 	 * significant digits than the final result.
7927 	 */
7928 	ln_base_rscale = rscale + result_dweight - ln_base_dweight + 8;
7929 	ln_base_rscale = Max(ln_base_rscale, NUMERIC_MIN_DISPLAY_SCALE);
7930 
7931 	ln_num_rscale = rscale + result_dweight - ln_num_dweight + 8;
7932 	ln_num_rscale = Max(ln_num_rscale, NUMERIC_MIN_DISPLAY_SCALE);
7933 
7934 	/* Form natural logarithms */
7935 	ln_var(base, &ln_base, ln_base_rscale);
7936 	ln_var(num, &ln_num, ln_num_rscale);
7937 
7938 	/* Divide and round to the required scale */
7939 	div_var_fast(&ln_num, &ln_base, result, rscale, true);
7940 
7941 	free_var(&ln_num);
7942 	free_var(&ln_base);
7943 }
7944 
7945 
7946 /*
7947  * power_var() -
7948  *
7949  *	Raise base to the power of exp
7950  *
7951  *	Note: this routine chooses dscale of the result.
7952  */
7953 static void
power_var(NumericVar * base,NumericVar * exp,NumericVar * result)7954 power_var(NumericVar *base, NumericVar *exp, NumericVar *result)
7955 {
7956 	int			res_sign;
7957 	NumericVar	abs_base;
7958 	NumericVar	ln_base;
7959 	NumericVar	ln_num;
7960 	int			ln_dweight;
7961 	int			rscale;
7962 	int			sig_digits;
7963 	int			local_rscale;
7964 	double		val;
7965 
7966 	/* If exp can be represented as an integer, use power_var_int */
7967 	if (exp->ndigits == 0 || exp->ndigits <= exp->weight + 1)
7968 	{
7969 		/* exact integer, but does it fit in int? */
7970 		int64		expval64;
7971 
7972 		if (numericvar_to_int64(exp, &expval64))
7973 		{
7974 			if (expval64 >= PG_INT32_MIN && expval64 <= PG_INT32_MAX)
7975 			{
7976 				/* Okay, select rscale */
7977 				rscale = NUMERIC_MIN_SIG_DIGITS;
7978 				rscale = Max(rscale, base->dscale);
7979 				rscale = Max(rscale, NUMERIC_MIN_DISPLAY_SCALE);
7980 				rscale = Min(rscale, NUMERIC_MAX_DISPLAY_SCALE);
7981 
7982 				power_var_int(base, (int) expval64, result, rscale);
7983 				return;
7984 			}
7985 		}
7986 	}
7987 
7988 	/*
7989 	 * This avoids log(0) for cases of 0 raised to a non-integer.  0 ^ 0 is
7990 	 * handled by power_var_int().
7991 	 */
7992 	if (cmp_var(base, &const_zero) == 0)
7993 	{
7994 		set_var_from_var(&const_zero, result);
7995 		result->dscale = NUMERIC_MIN_SIG_DIGITS;		/* no need to round */
7996 		return;
7997 	}
7998 
7999 	init_var(&abs_base);
8000 	init_var(&ln_base);
8001 	init_var(&ln_num);
8002 
8003 	/*
8004 	 * If base is negative, insist that exp be an integer.  The result is then
8005 	 * positive if exp is even and negative if exp is odd.
8006 	 */
8007 	if (base->sign == NUMERIC_NEG)
8008 	{
8009 		/*
8010 		 * Check that exp is an integer.  This error code is defined by the
8011 		 * SQL standard, and matches other errors in numeric_power().
8012 		 */
8013 		if (exp->ndigits > 0 && exp->ndigits > exp->weight + 1)
8014 			ereport(ERROR,
8015 					(errcode(ERRCODE_INVALID_ARGUMENT_FOR_POWER_FUNCTION),
8016 					 errmsg("a negative number raised to a non-integer power yields a complex result")));
8017 
8018 		/* Test if exp is odd or even */
8019 		if (exp->ndigits > 0 && exp->ndigits == exp->weight + 1 &&
8020 			(exp->digits[exp->ndigits - 1] & 1))
8021 			res_sign = NUMERIC_NEG;
8022 		else
8023 			res_sign = NUMERIC_POS;
8024 
8025 		/* Then work with abs(base) below */
8026 		set_var_from_var(base, &abs_base);
8027 		abs_base.sign = NUMERIC_POS;
8028 		base = &abs_base;
8029 	}
8030 	else
8031 		res_sign = NUMERIC_POS;
8032 
8033 	/*----------
8034 	 * Decide on the scale for the ln() calculation.  For this we need an
8035 	 * estimate of the weight of the result, which we obtain by doing an
8036 	 * initial low-precision calculation of exp * ln(base).
8037 	 *
8038 	 * We want result = e ^ (exp * ln(base))
8039 	 * so result dweight = log10(result) = exp * ln(base) * log10(e)
8040 	 *
8041 	 * We also perform a crude overflow test here so that we can exit early if
8042 	 * the full-precision result is sure to overflow, and to guard against
8043 	 * integer overflow when determining the scale for the real calculation.
8044 	 * exp_var() supports inputs up to NUMERIC_MAX_RESULT_SCALE * 3, so the
8045 	 * result will overflow if exp * ln(base) >= NUMERIC_MAX_RESULT_SCALE * 3.
8046 	 * Since the values here are only approximations, we apply a small fuzz
8047 	 * factor to this overflow test and let exp_var() determine the exact
8048 	 * overflow threshold so that it is consistent for all inputs.
8049 	 *----------
8050 	 */
8051 	ln_dweight = estimate_ln_dweight(base);
8052 
8053 	/*
8054 	 * Set the scale for the low-precision calculation, computing ln(base) to
8055 	 * around 8 significant digits.  Note that ln_dweight may be as small as
8056 	 * -SHRT_MAX, so the scale may exceed NUMERIC_MAX_DISPLAY_SCALE here.
8057 	 */
8058 	local_rscale = 8 - ln_dweight;
8059 	local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
8060 
8061 	ln_var(base, &ln_base, local_rscale);
8062 
8063 	mul_var(&ln_base, exp, &ln_num, local_rscale);
8064 
8065 	val = numericvar_to_double_no_overflow(&ln_num);
8066 
8067 	/* initial overflow/underflow test with fuzz factor */
8068 	if (Abs(val) > NUMERIC_MAX_RESULT_SCALE * 3.01)
8069 	{
8070 		if (val > 0)
8071 			ereport(ERROR,
8072 					(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
8073 					 errmsg("value overflows numeric format")));
8074 		zero_var(result);
8075 		result->dscale = NUMERIC_MAX_DISPLAY_SCALE;
8076 		return;
8077 	}
8078 
8079 	val *= 0.434294481903252;	/* approximate decimal result weight */
8080 
8081 	/* choose the result scale */
8082 	rscale = NUMERIC_MIN_SIG_DIGITS - (int) val;
8083 	rscale = Max(rscale, base->dscale);
8084 	rscale = Max(rscale, exp->dscale);
8085 	rscale = Max(rscale, NUMERIC_MIN_DISPLAY_SCALE);
8086 	rscale = Min(rscale, NUMERIC_MAX_DISPLAY_SCALE);
8087 
8088 	/* significant digits required in the result */
8089 	sig_digits = rscale + (int) val;
8090 	sig_digits = Max(sig_digits, 0);
8091 
8092 	/* set the scale for the real exp * ln(base) calculation */
8093 	local_rscale = sig_digits - ln_dweight + 8;
8094 	local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
8095 
8096 	/* and do the real calculation */
8097 
8098 	ln_var(base, &ln_base, local_rscale);
8099 
8100 	mul_var(&ln_base, exp, &ln_num, local_rscale);
8101 
8102 	exp_var(&ln_num, result, rscale);
8103 
8104 	if (res_sign == NUMERIC_NEG && result->ndigits > 0)
8105 		result->sign = NUMERIC_NEG;
8106 
8107 	free_var(&ln_num);
8108 	free_var(&ln_base);
8109 	free_var(&abs_base);
8110 }
8111 
8112 /*
8113  * power_var_int() -
8114  *
8115  *	Raise base to the power of exp, where exp is an integer.
8116  */
8117 static void
power_var_int(NumericVar * base,int exp,NumericVar * result,int rscale)8118 power_var_int(NumericVar *base, int exp, NumericVar *result, int rscale)
8119 {
8120 	double		f;
8121 	int			p;
8122 	int			i;
8123 	int			sig_digits;
8124 	unsigned int mask;
8125 	bool		neg;
8126 	NumericVar	base_prod;
8127 	int			local_rscale;
8128 
8129 	/* Handle some common special cases, as well as corner cases */
8130 	switch (exp)
8131 	{
8132 		case 0:
8133 
8134 			/*
8135 			 * While 0 ^ 0 can be either 1 or indeterminate (error), we treat
8136 			 * it as 1 because most programming languages do this. SQL:2003
8137 			 * also requires a return value of 1.
8138 			 * http://en.wikipedia.org/wiki/Exponentiation#Zero_to_the_zero_pow
8139 			 * er
8140 			 */
8141 			set_var_from_var(&const_one, result);
8142 			result->dscale = rscale;	/* no need to round */
8143 			return;
8144 		case 1:
8145 			set_var_from_var(base, result);
8146 			round_var(result, rscale);
8147 			return;
8148 		case -1:
8149 			div_var(&const_one, base, result, rscale, true);
8150 			return;
8151 		case 2:
8152 			mul_var(base, base, result, rscale);
8153 			return;
8154 		default:
8155 			break;
8156 	}
8157 
8158 	/* Handle the special case where the base is zero */
8159 	if (base->ndigits == 0)
8160 	{
8161 		if (exp < 0)
8162 			ereport(ERROR,
8163 					(errcode(ERRCODE_DIVISION_BY_ZERO),
8164 					 errmsg("division by zero")));
8165 		zero_var(result);
8166 		result->dscale = rscale;
8167 		return;
8168 	}
8169 
8170 	/*
8171 	 * The general case repeatedly multiplies base according to the bit
8172 	 * pattern of exp.
8173 	 *
8174 	 * First we need to estimate the weight of the result so that we know how
8175 	 * many significant digits are needed.
8176 	 */
8177 	f = base->digits[0];
8178 	p = base->weight * DEC_DIGITS;
8179 
8180 	for (i = 1; i < base->ndigits && i * DEC_DIGITS < 16; i++)
8181 	{
8182 		f = f * NBASE + base->digits[i];
8183 		p -= DEC_DIGITS;
8184 	}
8185 
8186 	/*----------
8187 	 * We have base ~= f * 10^p
8188 	 * so log10(result) = log10(base^exp) ~= exp * (log10(f) + p)
8189 	 *----------
8190 	 */
8191 	f = exp * (log10(f) + p);
8192 
8193 	/*
8194 	 * Apply crude overflow/underflow tests so we can exit early if the result
8195 	 * certainly will overflow/underflow.
8196 	 */
8197 	if (f > 3 * SHRT_MAX * DEC_DIGITS)
8198 		ereport(ERROR,
8199 				(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
8200 				 errmsg("value overflows numeric format")));
8201 	if (f + 1 < -rscale || f + 1 < -NUMERIC_MAX_DISPLAY_SCALE)
8202 	{
8203 		zero_var(result);
8204 		result->dscale = rscale;
8205 		return;
8206 	}
8207 
8208 	/*
8209 	 * Approximate number of significant digits in the result.  Note that the
8210 	 * underflow test above means that this is necessarily >= 0.
8211 	 */
8212 	sig_digits = 1 + rscale + (int) f;
8213 
8214 	/*
8215 	 * The multiplications to produce the result may introduce an error of up
8216 	 * to around log10(abs(exp)) digits, so work with this many extra digits
8217 	 * of precision (plus a few more for good measure).
8218 	 */
8219 	sig_digits += (int) log(fabs((double) exp)) + 8;
8220 
8221 	/*
8222 	 * Now we can proceed with the multiplications.
8223 	 */
8224 	neg = (exp < 0);
8225 	mask = Abs(exp);
8226 
8227 	init_var(&base_prod);
8228 	set_var_from_var(base, &base_prod);
8229 
8230 	if (mask & 1)
8231 		set_var_from_var(base, result);
8232 	else
8233 		set_var_from_var(&const_one, result);
8234 
8235 	while ((mask >>= 1) > 0)
8236 	{
8237 		/*
8238 		 * Do the multiplications using rscales large enough to hold the
8239 		 * results to the required number of significant digits, but don't
8240 		 * waste time by exceeding the scales of the numbers themselves.
8241 		 */
8242 		local_rscale = sig_digits - 2 * base_prod.weight * DEC_DIGITS;
8243 		local_rscale = Min(local_rscale, 2 * base_prod.dscale);
8244 		local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
8245 
8246 		mul_var(&base_prod, &base_prod, &base_prod, local_rscale);
8247 
8248 		if (mask & 1)
8249 		{
8250 			local_rscale = sig_digits -
8251 				(base_prod.weight + result->weight) * DEC_DIGITS;
8252 			local_rscale = Min(local_rscale,
8253 							   base_prod.dscale + result->dscale);
8254 			local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
8255 
8256 			mul_var(&base_prod, result, result, local_rscale);
8257 		}
8258 
8259 		/*
8260 		 * When abs(base) > 1, the number of digits to the left of the decimal
8261 		 * point in base_prod doubles at each iteration, so if exp is large we
8262 		 * could easily spend large amounts of time and memory space doing the
8263 		 * multiplications.  But once the weight exceeds what will fit in
8264 		 * int16, the final result is guaranteed to overflow (or underflow, if
8265 		 * exp < 0), so we can give up before wasting too many cycles.
8266 		 */
8267 		if (base_prod.weight > SHRT_MAX || result->weight > SHRT_MAX)
8268 		{
8269 			/* overflow, unless neg, in which case result should be 0 */
8270 			if (!neg)
8271 				ereport(ERROR,
8272 						(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
8273 						 errmsg("value overflows numeric format")));
8274 			zero_var(result);
8275 			neg = false;
8276 			break;
8277 		}
8278 	}
8279 
8280 	free_var(&base_prod);
8281 
8282 	/* Compensate for input sign, and round to requested rscale */
8283 	if (neg)
8284 		div_var_fast(&const_one, result, result, rscale, true);
8285 	else
8286 		round_var(result, rscale);
8287 }
8288 
8289 /*
8290  * power_ten_int() -
8291  *
8292  *	Raise ten to the power of exp, where exp is an integer.  Note that unlike
8293  *	power_var_int(), this does no overflow/underflow checking or rounding.
8294  */
8295 static void
power_ten_int(int exp,NumericVar * result)8296 power_ten_int(int exp, NumericVar *result)
8297 {
8298 	/* Construct the result directly, starting from 10^0 = 1 */
8299 	set_var_from_var(&const_one, result);
8300 
8301 	/* Scale needed to represent the result exactly */
8302 	result->dscale = exp < 0 ? -exp : 0;
8303 
8304 	/* Base-NBASE weight of result and remaining exponent */
8305 	if (exp >= 0)
8306 		result->weight = exp / DEC_DIGITS;
8307 	else
8308 		result->weight = (exp + 1) / DEC_DIGITS - 1;
8309 
8310 	exp -= result->weight * DEC_DIGITS;
8311 
8312 	/* Final adjustment of the result's single NBASE digit */
8313 	while (exp-- > 0)
8314 		result->digits[0] *= 10;
8315 }
8316 
8317 
/* ----------------------------------------------------------------------
 *
 * The following are the lowest-level functions, which operate on the
 * absolute (unsigned) values of variables
 *
 * ----------------------------------------------------------------------
 */
8325 
8326 
8327 /* ----------
8328  * cmp_abs() -
8329  *
8330  *	Compare the absolute values of var1 and var2
8331  *	Returns:	-1 for ABS(var1) < ABS(var2)
8332  *				0  for ABS(var1) == ABS(var2)
8333  *				1  for ABS(var1) > ABS(var2)
8334  * ----------
8335  */
8336 static int
cmp_abs(NumericVar * var1,NumericVar * var2)8337 cmp_abs(NumericVar *var1, NumericVar *var2)
8338 {
8339 	return cmp_abs_common(var1->digits, var1->ndigits, var1->weight,
8340 						  var2->digits, var2->ndigits, var2->weight);
8341 }
8342 
8343 /* ----------
8344  * cmp_abs_common() -
8345  *
8346  *	Main routine of cmp_abs(). This function can be used by both
8347  *	NumericVar and Numeric.
8348  * ----------
8349  */
8350 static int
cmp_abs_common(const NumericDigit * var1digits,int var1ndigits,int var1weight,const NumericDigit * var2digits,int var2ndigits,int var2weight)8351 cmp_abs_common(const NumericDigit *var1digits, int var1ndigits, int var1weight,
8352 			 const NumericDigit *var2digits, int var2ndigits, int var2weight)
8353 {
8354 	int			i1 = 0;
8355 	int			i2 = 0;
8356 
8357 	/* Check any digits before the first common digit */
8358 
8359 	while (var1weight > var2weight && i1 < var1ndigits)
8360 	{
8361 		if (var1digits[i1++] != 0)
8362 			return 1;
8363 		var1weight--;
8364 	}
8365 	while (var2weight > var1weight && i2 < var2ndigits)
8366 	{
8367 		if (var2digits[i2++] != 0)
8368 			return -1;
8369 		var2weight--;
8370 	}
8371 
8372 	/* At this point, either w1 == w2 or we've run out of digits */
8373 
8374 	if (var1weight == var2weight)
8375 	{
8376 		while (i1 < var1ndigits && i2 < var2ndigits)
8377 		{
8378 			int			stat = var1digits[i1++] - var2digits[i2++];
8379 
8380 			if (stat)
8381 			{
8382 				if (stat > 0)
8383 					return 1;
8384 				return -1;
8385 			}
8386 		}
8387 	}
8388 
8389 	/*
8390 	 * At this point, we've run out of digits on one side or the other; so any
8391 	 * remaining nonzero digits imply that side is larger
8392 	 */
8393 	while (i1 < var1ndigits)
8394 	{
8395 		if (var1digits[i1++] != 0)
8396 			return 1;
8397 	}
8398 	while (i2 < var2ndigits)
8399 	{
8400 		if (var2digits[i2++] != 0)
8401 			return -1;
8402 	}
8403 
8404 	return 0;
8405 }
8406 
8407 
8408 /*
8409  * add_abs() -
8410  *
8411  *	Add the absolute values of two variables into result.
8412  *	result might point to one of the operands without danger.
8413  */
8414 static void
add_abs(NumericVar * var1,NumericVar * var2,NumericVar * result)8415 add_abs(NumericVar *var1, NumericVar *var2, NumericVar *result)
8416 {
8417 	NumericDigit *res_buf;
8418 	NumericDigit *res_digits;
8419 	int			res_ndigits;
8420 	int			res_weight;
8421 	int			res_rscale,
8422 				rscale1,
8423 				rscale2;
8424 	int			res_dscale;
8425 	int			i,
8426 				i1,
8427 				i2;
8428 	int			carry = 0;
8429 
8430 	/* copy these values into local vars for speed in inner loop */
8431 	int			var1ndigits = var1->ndigits;
8432 	int			var2ndigits = var2->ndigits;
8433 	NumericDigit *var1digits = var1->digits;
8434 	NumericDigit *var2digits = var2->digits;
8435 
8436 	res_weight = Max(var1->weight, var2->weight) + 1;
8437 
8438 	res_dscale = Max(var1->dscale, var2->dscale);
8439 
8440 	/* Note: here we are figuring rscale in base-NBASE digits */
8441 	rscale1 = var1->ndigits - var1->weight - 1;
8442 	rscale2 = var2->ndigits - var2->weight - 1;
8443 	res_rscale = Max(rscale1, rscale2);
8444 
8445 	res_ndigits = res_rscale + res_weight + 1;
8446 	if (res_ndigits <= 0)
8447 		res_ndigits = 1;
8448 
8449 	res_buf = digitbuf_alloc(res_ndigits + 1);
8450 	res_buf[0] = 0;				/* spare digit for later rounding */
8451 	res_digits = res_buf + 1;
8452 
8453 	i1 = res_rscale + var1->weight + 1;
8454 	i2 = res_rscale + var2->weight + 1;
8455 	for (i = res_ndigits - 1; i >= 0; i--)
8456 	{
8457 		i1--;
8458 		i2--;
8459 		if (i1 >= 0 && i1 < var1ndigits)
8460 			carry += var1digits[i1];
8461 		if (i2 >= 0 && i2 < var2ndigits)
8462 			carry += var2digits[i2];
8463 
8464 		if (carry >= NBASE)
8465 		{
8466 			res_digits[i] = carry - NBASE;
8467 			carry = 1;
8468 		}
8469 		else
8470 		{
8471 			res_digits[i] = carry;
8472 			carry = 0;
8473 		}
8474 	}
8475 
8476 	Assert(carry == 0);			/* else we failed to allow for carry out */
8477 
8478 	digitbuf_free(result->buf);
8479 	result->ndigits = res_ndigits;
8480 	result->buf = res_buf;
8481 	result->digits = res_digits;
8482 	result->weight = res_weight;
8483 	result->dscale = res_dscale;
8484 
8485 	/* Remove leading/trailing zeroes */
8486 	strip_var(result);
8487 }
8488 
8489 
8490 /*
8491  * sub_abs()
8492  *
8493  *	Subtract the absolute value of var2 from the absolute value of var1
8494  *	and store in result. result might point to one of the operands
8495  *	without danger.
8496  *
8497  *	ABS(var1) MUST BE GREATER OR EQUAL ABS(var2) !!!
8498  */
8499 static void
sub_abs(NumericVar * var1,NumericVar * var2,NumericVar * result)8500 sub_abs(NumericVar *var1, NumericVar *var2, NumericVar *result)
8501 {
8502 	NumericDigit *res_buf;
8503 	NumericDigit *res_digits;
8504 	int			res_ndigits;
8505 	int			res_weight;
8506 	int			res_rscale,
8507 				rscale1,
8508 				rscale2;
8509 	int			res_dscale;
8510 	int			i,
8511 				i1,
8512 				i2;
8513 	int			borrow = 0;
8514 
8515 	/* copy these values into local vars for speed in inner loop */
8516 	int			var1ndigits = var1->ndigits;
8517 	int			var2ndigits = var2->ndigits;
8518 	NumericDigit *var1digits = var1->digits;
8519 	NumericDigit *var2digits = var2->digits;
8520 
8521 	res_weight = var1->weight;
8522 
8523 	res_dscale = Max(var1->dscale, var2->dscale);
8524 
8525 	/* Note: here we are figuring rscale in base-NBASE digits */
8526 	rscale1 = var1->ndigits - var1->weight - 1;
8527 	rscale2 = var2->ndigits - var2->weight - 1;
8528 	res_rscale = Max(rscale1, rscale2);
8529 
8530 	res_ndigits = res_rscale + res_weight + 1;
8531 	if (res_ndigits <= 0)
8532 		res_ndigits = 1;
8533 
8534 	res_buf = digitbuf_alloc(res_ndigits + 1);
8535 	res_buf[0] = 0;				/* spare digit for later rounding */
8536 	res_digits = res_buf + 1;
8537 
8538 	i1 = res_rscale + var1->weight + 1;
8539 	i2 = res_rscale + var2->weight + 1;
8540 	for (i = res_ndigits - 1; i >= 0; i--)
8541 	{
8542 		i1--;
8543 		i2--;
8544 		if (i1 >= 0 && i1 < var1ndigits)
8545 			borrow += var1digits[i1];
8546 		if (i2 >= 0 && i2 < var2ndigits)
8547 			borrow -= var2digits[i2];
8548 
8549 		if (borrow < 0)
8550 		{
8551 			res_digits[i] = borrow + NBASE;
8552 			borrow = -1;
8553 		}
8554 		else
8555 		{
8556 			res_digits[i] = borrow;
8557 			borrow = 0;
8558 		}
8559 	}
8560 
8561 	Assert(borrow == 0);		/* else caller gave us var1 < var2 */
8562 
8563 	digitbuf_free(result->buf);
8564 	result->ndigits = res_ndigits;
8565 	result->buf = res_buf;
8566 	result->digits = res_digits;
8567 	result->weight = res_weight;
8568 	result->dscale = res_dscale;
8569 
8570 	/* Remove leading/trailing zeroes */
8571 	strip_var(result);
8572 }
8573 
8574 /*
8575  * round_var
8576  *
8577  * Round the value of a variable to no more than rscale decimal digits
8578  * after the decimal point.  NOTE: we allow rscale < 0 here, implying
8579  * rounding before the decimal point.
8580  */
8581 static void
round_var(NumericVar * var,int rscale)8582 round_var(NumericVar *var, int rscale)
8583 {
8584 	NumericDigit *digits = var->digits;
8585 	int			di;
8586 	int			ndigits;
8587 	int			carry;
8588 
8589 	var->dscale = rscale;
8590 
8591 	/* decimal digits wanted */
8592 	di = (var->weight + 1) * DEC_DIGITS + rscale;
8593 
8594 	/*
8595 	 * If di = 0, the value loses all digits, but could round up to 1 if its
8596 	 * first extra digit is >= 5.  If di < 0 the result must be 0.
8597 	 */
8598 	if (di < 0)
8599 	{
8600 		var->ndigits = 0;
8601 		var->weight = 0;
8602 		var->sign = NUMERIC_POS;
8603 	}
8604 	else
8605 	{
8606 		/* NBASE digits wanted */
8607 		ndigits = (di + DEC_DIGITS - 1) / DEC_DIGITS;
8608 
8609 		/* 0, or number of decimal digits to keep in last NBASE digit */
8610 		di %= DEC_DIGITS;
8611 
8612 		if (ndigits < var->ndigits ||
8613 			(ndigits == var->ndigits && di > 0))
8614 		{
8615 			var->ndigits = ndigits;
8616 
8617 #if DEC_DIGITS == 1
8618 			/* di must be zero */
8619 			carry = (digits[ndigits] >= HALF_NBASE) ? 1 : 0;
8620 #else
8621 			if (di == 0)
8622 				carry = (digits[ndigits] >= HALF_NBASE) ? 1 : 0;
8623 			else
8624 			{
8625 				/* Must round within last NBASE digit */
8626 				int			extra,
8627 							pow10;
8628 
8629 #if DEC_DIGITS == 4
8630 				pow10 = round_powers[di];
8631 #elif DEC_DIGITS == 2
8632 				pow10 = 10;
8633 #else
8634 #error unsupported NBASE
8635 #endif
8636 				extra = digits[--ndigits] % pow10;
8637 				digits[ndigits] -= extra;
8638 				carry = 0;
8639 				if (extra >= pow10 / 2)
8640 				{
8641 					pow10 += digits[ndigits];
8642 					if (pow10 >= NBASE)
8643 					{
8644 						pow10 -= NBASE;
8645 						carry = 1;
8646 					}
8647 					digits[ndigits] = pow10;
8648 				}
8649 			}
8650 #endif
8651 
8652 			/* Propagate carry if needed */
8653 			while (carry)
8654 			{
8655 				carry += digits[--ndigits];
8656 				if (carry >= NBASE)
8657 				{
8658 					digits[ndigits] = carry - NBASE;
8659 					carry = 1;
8660 				}
8661 				else
8662 				{
8663 					digits[ndigits] = carry;
8664 					carry = 0;
8665 				}
8666 			}
8667 
8668 			if (ndigits < 0)
8669 			{
8670 				Assert(ndigits == -1);	/* better not have added > 1 digit */
8671 				Assert(var->digits > var->buf);
8672 				var->digits--;
8673 				var->ndigits++;
8674 				var->weight++;
8675 			}
8676 		}
8677 	}
8678 }
8679 
8680 /*
8681  * trunc_var
8682  *
8683  * Truncate (towards zero) the value of a variable at rscale decimal digits
8684  * after the decimal point.  NOTE: we allow rscale < 0 here, implying
8685  * truncation before the decimal point.
8686  */
8687 static void
trunc_var(NumericVar * var,int rscale)8688 trunc_var(NumericVar *var, int rscale)
8689 {
8690 	int			di;
8691 	int			ndigits;
8692 
8693 	var->dscale = rscale;
8694 
8695 	/* decimal digits wanted */
8696 	di = (var->weight + 1) * DEC_DIGITS + rscale;
8697 
8698 	/*
8699 	 * If di <= 0, the value loses all digits.
8700 	 */
8701 	if (di <= 0)
8702 	{
8703 		var->ndigits = 0;
8704 		var->weight = 0;
8705 		var->sign = NUMERIC_POS;
8706 	}
8707 	else
8708 	{
8709 		/* NBASE digits wanted */
8710 		ndigits = (di + DEC_DIGITS - 1) / DEC_DIGITS;
8711 
8712 		if (ndigits <= var->ndigits)
8713 		{
8714 			var->ndigits = ndigits;
8715 
8716 #if DEC_DIGITS == 1
8717 			/* no within-digit stuff to worry about */
8718 #else
8719 			/* 0, or number of decimal digits to keep in last NBASE digit */
8720 			di %= DEC_DIGITS;
8721 
8722 			if (di > 0)
8723 			{
8724 				/* Must truncate within last NBASE digit */
8725 				NumericDigit *digits = var->digits;
8726 				int			extra,
8727 							pow10;
8728 
8729 #if DEC_DIGITS == 4
8730 				pow10 = round_powers[di];
8731 #elif DEC_DIGITS == 2
8732 				pow10 = 10;
8733 #else
8734 #error unsupported NBASE
8735 #endif
8736 				extra = digits[--ndigits] % pow10;
8737 				digits[ndigits] -= extra;
8738 			}
8739 #endif
8740 		}
8741 	}
8742 }
8743 
8744 /*
8745  * strip_var
8746  *
8747  * Strip any leading and trailing zeroes from a numeric variable
8748  */
8749 static void
strip_var(NumericVar * var)8750 strip_var(NumericVar *var)
8751 {
8752 	NumericDigit *digits = var->digits;
8753 	int			ndigits = var->ndigits;
8754 
8755 	/* Strip leading zeroes */
8756 	while (ndigits > 0 && *digits == 0)
8757 	{
8758 		digits++;
8759 		var->weight--;
8760 		ndigits--;
8761 	}
8762 
8763 	/* Strip trailing zeroes */
8764 	while (ndigits > 0 && digits[ndigits - 1] == 0)
8765 		ndigits--;
8766 
8767 	/* If it's zero, normalize the sign and weight */
8768 	if (ndigits == 0)
8769 	{
8770 		var->sign = NUMERIC_POS;
8771 		var->weight = 0;
8772 	}
8773 
8774 	var->digits = digits;
8775 	var->ndigits = ndigits;
8776 }
8777