1 /*-------------------------------------------------------------------------
2 *
3 * numeric.c
4 * An exact numeric data type for the Postgres database system
5 *
6 * Original coding 1998, Jan Wieck. Heavily revised 2003, Tom Lane.
7 *
8 * Many of the algorithmic ideas are borrowed from David M. Smith's "FM"
9 * multiple-precision math library, most recently published as Algorithm
10 * 786: Multiple-Precision Complex Arithmetic and Functions, ACM
11 * Transactions on Mathematical Software, Vol. 24, No. 4, December 1998,
12 * pages 359-367.
13 *
14 * Copyright (c) 1998-2020, PostgreSQL Global Development Group
15 *
16 * IDENTIFICATION
17 * src/backend/utils/adt/numeric.c
18 *
19 *-------------------------------------------------------------------------
20 */
21
22 #include "postgres.h"
23
24 #include <ctype.h>
25 #include <float.h>
26 #include <limits.h>
27 #include <math.h>
28
29 #include "catalog/pg_type.h"
30 #include "common/hashfn.h"
31 #include "common/int.h"
32 #include "funcapi.h"
33 #include "lib/hyperloglog.h"
34 #include "libpq/pqformat.h"
35 #include "miscadmin.h"
36 #include "nodes/nodeFuncs.h"
37 #include "nodes/supportnodes.h"
38 #include "utils/array.h"
39 #include "utils/builtins.h"
40 #include "utils/float.h"
41 #include "utils/guc.h"
42 #include "utils/int8.h"
43 #include "utils/numeric.h"
44 #include "utils/sortsupport.h"
45
46 /* ----------
47 * Uncomment the following to enable compilation of dump_numeric()
48 * and dump_var() and to get a dump of any result produced by make_result().
49 * ----------
50 #define NUMERIC_DEBUG
51 */
52
53
54 /* ----------
55 * Local data types
56 *
57 * Numeric values are represented in a base-NBASE floating point format.
58 * Each "digit" ranges from 0 to NBASE-1. The type NumericDigit is signed
59 * and wide enough to store a digit. We assume that NBASE*NBASE can fit in
60 * an int. Although the purely calculational routines could handle any even
61 * NBASE that's less than sqrt(INT_MAX), in practice we are only interested
62 * in NBASE a power of ten, so that I/O conversions and decimal rounding
63 * are easy. Also, it's actually more efficient if NBASE is rather less than
64 * sqrt(INT_MAX), so that there is "headroom" for mul_var and div_var_fast to
65 * postpone processing carries.
66 *
67 * Values of NBASE other than 10000 are considered of historical interest only
68 * and are no longer supported in any sense; no mechanism exists for the client
69 * to discover the base, so every client supporting binary mode expects the
70 * base-10000 format. If you plan to change this, also note the numeric
71 * abbreviation code, which assumes NBASE=10000.
72 * ----------
73 */
74
/*
 * Alternative NBASE configurations.  Exactly one branch must be enabled;
 * only the NBASE = 10000 branch ("#if 1") is supported today, the base-10
 * and base-100 branches being kept purely for historical interest.
 *
 * Each branch defines:
 *	NBASE			 - radix of one NumericDigit
 *	HALF_NBASE		 - NBASE / 2
 *	DEC_DIGITS		 - decimal digits represented by one NumericDigit
 *	MUL_GUARD_DIGITS - guard digits, measured in NBASE digits; presumably
 *					   consumed by the multiplication code (not visible in
 *					   this chunk -- confirm at the use site)
 *	DIV_GUARD_DIGITS - likewise, presumably for the division code
 *	NumericDigit	 - signed type wide enough to hold one digit
 */
#if 0
#define NBASE		10
#define HALF_NBASE	5
#define DEC_DIGITS	1			/* decimal digits per NBASE digit */
#define MUL_GUARD_DIGITS	4	/* these are measured in NBASE digits */
#define DIV_GUARD_DIGITS	8

typedef signed char NumericDigit;
#endif

#if 0
#define NBASE		100
#define HALF_NBASE	50
#define DEC_DIGITS	2			/* decimal digits per NBASE digit */
#define MUL_GUARD_DIGITS	3	/* these are measured in NBASE digits */
#define DIV_GUARD_DIGITS	6

typedef signed char NumericDigit;
#endif

/* The only supported configuration: base 10000, four decimal digits each. */
#if 1
#define NBASE		10000
#define HALF_NBASE	5000
#define DEC_DIGITS	4			/* decimal digits per NBASE digit */
#define MUL_GUARD_DIGITS	2	/* these are measured in NBASE digits */
#define DIV_GUARD_DIGITS	4

typedef int16 NumericDigit;
#endif
104
105 /*
106 * The Numeric type as stored on disk.
107 *
108 * If the high bits of the first word of a NumericChoice (n_header, or
109 * n_short.n_header, or n_long.n_sign_dscale) are NUMERIC_SHORT, then the
110 * numeric follows the NumericShort format; if they are NUMERIC_POS or
111 * NUMERIC_NEG, it follows the NumericLong format. If they are NUMERIC_NAN,
112 * it is a NaN. We currently always store a NaN using just two bytes (i.e.
113 * only n_header), but previous releases used only the NumericLong format,
114 * so we might find 4-byte NaNs on disk if a database has been migrated using
115 * pg_upgrade. In either case, when the high bits indicate a NaN, the
116 * remaining bits are never examined. Currently, we always initialize these
117 * to zero, but it might be possible to use them for some other purpose in
118 * the future.
119 *
120 * In the NumericShort format, the remaining 14 bits of the header word
121 * (n_short.n_header) are allocated as follows: 1 for sign (positive or
122 * negative), 6 for dynamic scale, and 7 for weight. In practice, most
123 * commonly-encountered values can be represented this way.
124 *
125 * In the NumericLong format, the remaining 14 bits of the header word
126 * (n_long.n_sign_dscale) represent the display scale; and the weight is
127 * stored separately in n_weight.
128 *
129 * NOTE: by convention, values in the packed form have been stripped of
130 * all leading and trailing zero digits (where a "digit" is of base NBASE).
131 * In particular, if the value is zero, there will be no digits at all!
132 * The weight is arbitrary in that case, but we normally set it to zero.
133 */
134
/* Short format: a single 2-byte header packs sign, dscale and weight. */
struct NumericShort
{
	uint16		n_header;		/* Sign + display scale + weight */
	NumericDigit n_data[FLEXIBLE_ARRAY_MEMBER]; /* Digits */
};

/* Long format: sign + dscale word followed by a separate int16 weight. */
struct NumericLong
{
	uint16		n_sign_dscale;	/* Sign + display scale */
	int16		n_weight;		/* Weight of 1st digit */
	NumericDigit n_data[FLEXIBLE_ARRAY_MEMBER]; /* Digits */
};

/*
 * All members begin at offset 0, so n_header aliases the first uint16 of
 * either format; that is what lets the flag bits be tested through
 * n_header before deciding which variant to read.
 */
union NumericChoice
{
	uint16		n_header;		/* Header word */
	struct NumericLong n_long;	/* Long form (4-byte header) */
	struct NumericShort n_short;	/* Short form (2-byte header) */
};

/* Complete varlena datum: length word followed by one of the formats. */
struct NumericData
{
	int32		vl_len_;		/* varlena header (do not touch directly!) */
	union NumericChoice choice; /* choice of format */
};
160
161
162 /*
163 * Interpretation of high bits.
164 */
165
/* Possible values of the top two bits of the first header word. */
#define NUMERIC_SIGN_MASK	0xC000
#define NUMERIC_POS			0x0000
#define NUMERIC_NEG			0x4000
#define NUMERIC_SHORT		0x8000
#define NUMERIC_NAN			0xC000

/* Classify a Numeric by those two flag bits. */
#define NUMERIC_FLAGBITS(n) ((n)->choice.n_header & NUMERIC_SIGN_MASK)
#define NUMERIC_IS_NAN(n)		(NUMERIC_FLAGBITS(n) == NUMERIC_NAN)
#define NUMERIC_IS_SHORT(n)		(NUMERIC_FLAGBITS(n) == NUMERIC_SHORT)

/* Total header sizes in bytes, including the 4-byte varlena length word. */
#define NUMERIC_HDRSZ	(VARHDRSZ + sizeof(uint16) + sizeof(int16))
#define NUMERIC_HDRSZ_SHORT (VARHDRSZ + sizeof(uint16))

/*
 * If the flag bits are NUMERIC_SHORT or NUMERIC_NAN, we want the short header;
 * otherwise, we want the long one.  Instead of testing against each value, we
 * can just look at the high bit, for a slight efficiency gain.
 */
#define NUMERIC_HEADER_IS_SHORT(n)	(((n)->choice.n_header & 0x8000) != 0)
#define NUMERIC_HEADER_SIZE(n) \
	(VARHDRSZ + sizeof(uint16) + \
	 (NUMERIC_HEADER_IS_SHORT(n) ? 0 : sizeof(int16)))

/*
 * Short format definitions.
 *
 * Below the two flag bits, n_short.n_header is laid out as:
 *	0x2000	sign bit (set = negative)
 *	0x1F80	display scale, 6 bits, stored shifted left by 7 (so 0..63)
 *	0x0040	weight sign bit
 *	0x003F	low 6 bits of the weight (two's complement when negative)
 * which gives a weight range of -64 .. 63.
 */
#define NUMERIC_SHORT_SIGN_MASK			0x2000
#define NUMERIC_SHORT_DSCALE_MASK		0x1F80
#define NUMERIC_SHORT_DSCALE_SHIFT		7
#define NUMERIC_SHORT_DSCALE_MAX		\
	(NUMERIC_SHORT_DSCALE_MASK >> NUMERIC_SHORT_DSCALE_SHIFT)
#define NUMERIC_SHORT_WEIGHT_SIGN_MASK	0x0040
#define NUMERIC_SHORT_WEIGHT_MASK		0x003F
#define NUMERIC_SHORT_WEIGHT_MAX		NUMERIC_SHORT_WEIGHT_MASK
#define NUMERIC_SHORT_WEIGHT_MIN		(-(NUMERIC_SHORT_WEIGHT_MASK+1))

/*
 * Extract sign, display scale, weight.
 *
 * NUMERIC_SIGN maps a short header's sign bit onto NUMERIC_POS/NUMERIC_NEG
 * and otherwise returns the raw flag bits (so a NaN yields NUMERIC_NAN).
 * NUMERIC_DSCALE and NUMERIC_WEIGHT choose the format with
 * NUMERIC_HEADER_IS_SHORT, so a NaN is decoded through the short-format
 * bits.  For the short format, NUMERIC_WEIGHT sign-extends the weight by
 * OR-ing in ~NUMERIC_SHORT_WEIGHT_MASK when the weight sign bit is set.
 */
#define NUMERIC_DSCALE_MASK			0x3FFF
#define NUMERIC_DSCALE_MAX			NUMERIC_DSCALE_MASK

#define NUMERIC_SIGN(n) \
	(NUMERIC_IS_SHORT(n) ? \
		(((n)->choice.n_short.n_header & NUMERIC_SHORT_SIGN_MASK) ? \
		 NUMERIC_NEG : NUMERIC_POS) : NUMERIC_FLAGBITS(n))
#define NUMERIC_DSCALE(n)	(NUMERIC_HEADER_IS_SHORT((n)) ? \
	((n)->choice.n_short.n_header & NUMERIC_SHORT_DSCALE_MASK) \
		>> NUMERIC_SHORT_DSCALE_SHIFT \
	: ((n)->choice.n_long.n_sign_dscale & NUMERIC_DSCALE_MASK))
#define NUMERIC_WEIGHT(n)	(NUMERIC_HEADER_IS_SHORT((n)) ? \
	(((n)->choice.n_short.n_header & NUMERIC_SHORT_WEIGHT_SIGN_MASK ? \
		~NUMERIC_SHORT_WEIGHT_MASK : 0) \
	 | ((n)->choice.n_short.n_header & NUMERIC_SHORT_WEIGHT_MASK)) \
	: ((n)->choice.n_long.n_weight))
223
224 /* ----------
225 * NumericVar is the format we use for arithmetic. The digit-array part
226 * is the same as the NumericData storage format, but the header is more
227 * complex.
228 *
229 * The value represented by a NumericVar is determined by the sign, weight,
230 * ndigits, and digits[] array.
231 *
232 * Note: the first digit of a NumericVar's value is assumed to be multiplied
233 * by NBASE ** weight. Another way to say it is that there are weight+1
234 * digits before the decimal point. It is possible to have weight < 0.
235 *
236 * buf points at the physical start of the palloc'd digit buffer for the
237 * NumericVar. digits points at the first digit in actual use (the one
238 * with the specified weight). We normally leave an unused digit or two
239 * (preset to zeroes) between buf and digits, so that there is room to store
240 * a carry out of the top digit without reallocating space. We just need to
241 * decrement digits (and increment weight) to make room for the carry digit.
242 * (There is no such extra space in a numeric value stored in the database,
243 * only in a NumericVar in memory.)
244 *
245 * If buf is NULL then the digit buffer isn't actually palloc'd and should
246 * not be freed --- see the constants below for an example.
247 *
248 * dscale, or display scale, is the nominal precision expressed as number
249 * of digits after the decimal point (it must always be >= 0 at present).
250 * dscale may be more than the number of physically stored fractional digits,
251 * implying that we have suppressed storage of significant trailing zeroes.
 * It should never be less than the number of decimal digits actually stored
 * after the decimal point, since that would imply hiding digits that are
 * present.  NOTE that dscale is always expressed in *decimal* digits, and so
 * it may correspond to a fractional number of base-NBASE digits --- divide by
 * DEC_DIGITS to convert to NBASE digits.
256 *
257 * rscale, or result scale, is the target precision for a computation.
258 * Like dscale it is expressed as number of *decimal* digits after the decimal
259 * point, and is always >= 0 at present.
260 * Note that rscale is not stored in variables --- it's figured on-the-fly
261 * from the dscales of the inputs.
262 *
263 * While we consistently use "weight" to refer to the base-NBASE weight of
264 * a numeric value, it is convenient in some scale-related calculations to
265 * make use of the base-10 weight (ie, the approximate log10 of the value).
266 * To avoid confusion, such a decimal-units weight is called a "dweight".
267 *
268 * NB: All the variable-level functions are written in a style that makes it
269 * possible to give one and the same variable as argument and destination.
270 * This is feasible because the digit buffer is separate from the variable.
271 * ----------
272 */
/* In-memory working representation of a numeric; see the comment above. */
typedef struct NumericVar
{
	int			ndigits;		/* # of digits in digits[] - can be 0! */
	int			weight;			/* weight of first digit */
	int			sign;			/* NUMERIC_POS, NUMERIC_NEG, or NUMERIC_NAN */
	int			dscale;			/* display scale */
	NumericDigit *buf;			/* start of palloc'd space for digits[];
								 * NULL if the digits are not owned */
	NumericDigit *digits;		/* base-NBASE digits */
} NumericVar;
282
283
284 /* ----------
285 * Data for generate_series
286 * ----------
287 */
/*
 * Cross-call state for the numeric generate_series set-returning function.
 * The field roles below are inferred from their names; the SRF that uses
 * this struct is outside this chunk.
 */
typedef struct
{
	NumericVar	current;		/* presumably the next value to return */
	NumericVar	stop;			/* presumably the inclusive end bound */
	NumericVar	step;			/* presumably the increment per call */
} generate_series_numeric_fctx;
294
295
296 /* ----------
297 * Sort support.
298 * ----------
299 */
/* Private state hung off SortSupport for numeric comparisons/abbreviation. */
typedef struct
{
	void	   *buf;			/* buffer for short varlenas */
	int64		input_count;	/* number of non-null values seen */
	bool		estimating;		/* true if estimating cardinality */

	hyperLogLogState abbr_card; /* cardinality estimator */
} NumericSortSupport;
308
309
310 /* ----------
311 * Fast sum accumulator.
312 *
313 * NumericSumAccum is used to implement SUM(), and other standard aggregates
314 * that track the sum of input values. It uses 32-bit integers to store the
315 * digits, instead of the normal 16-bit integers (with NBASE=10000). This
316 * way, we can safely accumulate up to NBASE - 1 values without propagating
317 * carry, before risking overflow of any of the digits. 'num_uncarried'
318 * tracks how many values have been accumulated without propagating carry.
319 *
320 * Positive and negative values are accumulated separately, in 'pos_digits'
321 * and 'neg_digits'. This is simpler and faster than deciding whether to add
322 * or subtract from the current value, for each new value (see sub_var() for
323 * the logic we avoid by doing this). Both buffers are of same size, and
324 * have the same weight and scale. In accum_sum_final(), the positive and
325 * negative sums are added together to produce the final result.
326 *
327 * When a new value has a larger ndigits or weight than the accumulator
328 * currently does, the accumulator is enlarged to accommodate the new value.
329 * We normally have one zero digit reserved for carry propagation, and that
330 * is indicated by the 'have_carry_space' flag. When accum_sum_carry() uses
331 * up the reserved digit, it clears the 'have_carry_space' flag. The next
332 * call to accum_sum_add() will enlarge the buffer, to make room for the
333 * extra digit, and set the flag again.
334 *
335 * To initialize a new accumulator, simply reset all fields to zeros.
336 *
337 * The accumulator does not handle NaNs.
338 * ----------
339 */
/* Fast sum accumulator; see the comment above for the overall scheme. */
typedef struct NumericSumAccum
{
	int			ndigits;		/* length of each of pos_digits/neg_digits */
	int			weight;			/* weight shared by both digit buffers */
	int			dscale;			/* display scale shared by both buffers */
	int			num_uncarried;	/* values added since carry was last
								 * propagated */
	bool		have_carry_space;	/* true if a zero digit is reserved at the
									 * top for carry propagation */
	int32	   *pos_digits;		/* running sum of positive inputs */
	int32	   *neg_digits;		/* running sum of negative inputs */
} NumericSumAccum;
350
351
352 /*
353 * We define our own macros for packing and unpacking abbreviated-key
354 * representations for numeric values in order to avoid depending on
355 * USE_FLOAT8_BYVAL. The type of abbreviation we use is based only on
356 * the size of a datum, not the argument-passing convention for float8.
357 */
/* Width in bits of an abbreviated key: one full Datum. */
#define NUMERIC_ABBREV_BITS (SIZEOF_DATUM * BITS_PER_BYTE)
#if SIZEOF_DATUM == 8
#define NumericAbbrevGetDatum(X) ((Datum) (X))
#define DatumGetNumericAbbrev(X) ((int64) (X))
/* NaN is encoded as the smallest representable signed key. */
#define NUMERIC_ABBREV_NAN		 NumericAbbrevGetDatum(PG_INT64_MIN)
#else
#define NumericAbbrevGetDatum(X) ((Datum) (X))
#define DatumGetNumericAbbrev(X) ((int32) (X))
/* NaN is encoded as the smallest representable signed key. */
#define NUMERIC_ABBREV_NAN		 NumericAbbrevGetDatum(PG_INT32_MIN)
#endif
368
369
370 /* ----------
371 * Some preinitialized constants
372 * ----------
373 */
/*
 * Each constant's digit array is a static const array, and buf is NULL so
 * that nothing ever tries to pfree it.  Fields are in NumericVar order:
 * {ndigits, weight, sign, dscale, buf, digits}.
 */

/* 0 (ndigits = 0, so the single stored digit is never read as a value) */
static const NumericDigit const_zero_data[1] = {0};
static const NumericVar const_zero =
{0, 0, NUMERIC_POS, 0, NULL, (NumericDigit *) const_zero_data};

/* 1 */
static const NumericDigit const_one_data[1] = {1};
static const NumericVar const_one =
{1, 0, NUMERIC_POS, 0, NULL, (NumericDigit *) const_one_data};

/* 2 */
static const NumericDigit const_two_data[1] = {2};
static const NumericVar const_two =
{1, 0, NUMERIC_POS, 0, NULL, (NumericDigit *) const_two_data};

/* 0.9: a single digit of weight -1 holding 9 * NBASE / 10 */
#if DEC_DIGITS == 4
static const NumericDigit const_zero_point_nine_data[1] = {9000};
#elif DEC_DIGITS == 2
static const NumericDigit const_zero_point_nine_data[1] = {90};
#elif DEC_DIGITS == 1
static const NumericDigit const_zero_point_nine_data[1] = {9};
#endif
static const NumericVar const_zero_point_nine =
{1, -1, NUMERIC_POS, 1, NULL, (NumericDigit *) const_zero_point_nine_data};

/* 1.1: integer digit 1 followed by a fractional digit worth 0.1 */
#if DEC_DIGITS == 4
static const NumericDigit const_one_point_one_data[2] = {1, 1000};
#elif DEC_DIGITS == 2
static const NumericDigit const_one_point_one_data[2] = {1, 10};
#elif DEC_DIGITS == 1
static const NumericDigit const_one_point_one_data[2] = {1, 1};
#endif
static const NumericVar const_one_point_one =
{2, 0, NUMERIC_POS, 1, NULL, (NumericDigit *) const_one_point_one_data};

/* NaN: no digits at all */
static const NumericVar const_nan =
{0, 0, NUMERIC_NAN, 0, NULL, NULL};

/*
 * Powers of ten within one NBASE digit.  Presumably indexed by a decimal
 * position modulo DEC_DIGITS by the rounding code (not in this chunk --
 * confirm at the use site).
 */
#if DEC_DIGITS == 4
static const int round_powers[4] = {0, 1000, 100, 10};
#endif
412
413
414 /* ----------
415 * Local functions
416 * ----------
417 */
418
/* Debug dump helpers compile to nothing unless NUMERIC_DEBUG is defined. */
#ifdef NUMERIC_DEBUG
static void dump_numeric(const char *str, Numeric num);
static void dump_var(const char *str, NumericVar *var);
#else
#define dump_numeric(s,n)
#define dump_var(s,v)
#endif

/* palloc/pfree a digit buffer; freeing a NULL buffer is a no-op. */
#define digitbuf_alloc(ndigits)  \
	((NumericDigit *) palloc((ndigits) * sizeof(NumericDigit)))
#define digitbuf_free(buf)	\
	do { \
		 if ((buf) != NULL) \
			 pfree(buf); \
	} while (0)

/* Zero every field of a NumericVar (buf = digits = NULL, sign = POS). */
#define init_var(v)		MemSetAligned(v, 0, sizeof(NumericVar))

/* Digit array and digit count of a packed Numeric, for either header form. */
#define NUMERIC_DIGITS(num) (NUMERIC_HEADER_IS_SHORT(num) ? \
	(num)->choice.n_short.n_data : (num)->choice.n_long.n_data)
#define NUMERIC_NDIGITS(num) \
	((VARSIZE(num) - NUMERIC_HEADER_SIZE(num)) / sizeof(NumericDigit))
/* True when scale and weight both fit in the 2-byte short header. */
#define NUMERIC_CAN_BE_SHORT(scale,weight) \
	((scale) <= NUMERIC_SHORT_DSCALE_MAX && \
	(weight) <= NUMERIC_SHORT_WEIGHT_MAX && \
	(weight) >= NUMERIC_SHORT_WEIGHT_MIN)
445
/* NumericVar storage management */
static void alloc_var(NumericVar *var, int ndigits);
static void free_var(NumericVar *var);
static void zero_var(NumericVar *var);

/* Conversions between text, packed Numeric and NumericVar */
static const char *set_var_from_str(const char *str, const char *cp,
									NumericVar *dest);
static void set_var_from_num(Numeric value, NumericVar *dest);
static void init_var_from_num(Numeric num, NumericVar *dest);
static void set_var_from_var(const NumericVar *value, NumericVar *dest);
static char *get_str_from_var(const NumericVar *var);
static char *get_str_from_var_sci(const NumericVar *var, int rscale);

static Numeric make_result(const NumericVar *var);
static Numeric make_result_opt_error(const NumericVar *var, bool *error);

static void apply_typmod(NumericVar *var, int32 typmod);

/* Conversions to and from native integer / floating-point types */
static bool numericvar_to_int32(const NumericVar *var, int32 *result);
static bool numericvar_to_int64(const NumericVar *var, int64 *result);
static void int64_to_numericvar(int64 val, NumericVar *var);
#ifdef HAVE_INT128
static bool numericvar_to_int128(const NumericVar *var, int128 *result);
static void int128_to_numericvar(int128 val, NumericVar *var);
#endif
static double numeric_to_double_no_overflow(Numeric num);
static double numericvar_to_double_no_overflow(const NumericVar *var);

/* Sort support and abbreviated keys */
static Datum numeric_abbrev_convert(Datum original_datum, SortSupport ssup);
static bool numeric_abbrev_abort(int memtupcount, SortSupport ssup);
static int	numeric_fast_cmp(Datum x, Datum y, SortSupport ssup);
static int	numeric_cmp_abbrev(Datum x, Datum y, SortSupport ssup);

static Datum numeric_abbrev_convert_var(const NumericVar *var,
										NumericSortSupport *nss);

/* Signed comparison and arithmetic on NumericVars */
static int	cmp_numerics(Numeric num1, Numeric num2);
static int	cmp_var(const NumericVar *var1, const NumericVar *var2);
static int	cmp_var_common(const NumericDigit *var1digits, int var1ndigits,
						   int var1weight, int var1sign,
						   const NumericDigit *var2digits, int var2ndigits,
						   int var2weight, int var2sign);
static void add_var(const NumericVar *var1, const NumericVar *var2,
					NumericVar *result);
static void sub_var(const NumericVar *var1, const NumericVar *var2,
					NumericVar *result);
static void mul_var(const NumericVar *var1, const NumericVar *var2,
					NumericVar *result,
					int rscale);
static void div_var(const NumericVar *var1, const NumericVar *var2,
					NumericVar *result,
					int rscale, bool round);
static void div_var_fast(const NumericVar *var1, const NumericVar *var2,
						 NumericVar *result, int rscale, bool round);
static int	select_div_scale(const NumericVar *var1, const NumericVar *var2);
static void mod_var(const NumericVar *var1, const NumericVar *var2,
					NumericVar *result);
static void div_mod_var(const NumericVar *var1, const NumericVar *var2,
						NumericVar *quot, NumericVar *rem);
static void ceil_var(const NumericVar *var, NumericVar *result);
static void floor_var(const NumericVar *var, NumericVar *result);

/* Higher-level mathematical functions */
static void gcd_var(const NumericVar *var1, const NumericVar *var2,
					NumericVar *result);
static void sqrt_var(const NumericVar *arg, NumericVar *result, int rscale);
static void exp_var(const NumericVar *arg, NumericVar *result, int rscale);
static int	estimate_ln_dweight(const NumericVar *var);
static void ln_var(const NumericVar *arg, NumericVar *result, int rscale);
static void log_var(const NumericVar *base, const NumericVar *num,
					NumericVar *result);
static void power_var(const NumericVar *base, const NumericVar *exp,
					  NumericVar *result);
static void power_var_int(const NumericVar *base, int exp, NumericVar *result,
						  int rscale);
static void power_ten_int(int exp, NumericVar *result);

/* Magnitude-only helpers, rounding/truncation, and width_bucket support */
static int	cmp_abs(const NumericVar *var1, const NumericVar *var2);
static int	cmp_abs_common(const NumericDigit *var1digits, int var1ndigits,
						   int var1weight,
						   const NumericDigit *var2digits, int var2ndigits,
						   int var2weight);
static void add_abs(const NumericVar *var1, const NumericVar *var2,
					NumericVar *result);
static void sub_abs(const NumericVar *var1, const NumericVar *var2,
					NumericVar *result);
static void round_var(NumericVar *var, int rscale);
static void trunc_var(NumericVar *var, int rscale);
static void strip_var(NumericVar *var);
static void compute_bucket(Numeric operand, Numeric bound1, Numeric bound2,
						   const NumericVar *count_var, NumericVar *result_var);

/* NumericSumAccum operations */
static void accum_sum_add(NumericSumAccum *accum, const NumericVar *var1);
static void accum_sum_rescale(NumericSumAccum *accum, const NumericVar *val);
static void accum_sum_carry(NumericSumAccum *accum);
static void accum_sum_reset(NumericSumAccum *accum);
static void accum_sum_final(NumericSumAccum *accum, NumericVar *result);
static void accum_sum_copy(NumericSumAccum *dst, NumericSumAccum *src);
static void accum_sum_combine(NumericSumAccum *accum, NumericSumAccum *accum2);
543
544
545 /* ----------------------------------------------------------------------
546 *
547 * Input-, output- and rounding-functions
548 *
549 * ----------------------------------------------------------------------
550 */
551
552
553 /*
554 * numeric_in() -
555 *
556 * Input function for numeric data type
557 */
558 Datum
numeric_in(PG_FUNCTION_ARGS)559 numeric_in(PG_FUNCTION_ARGS)
560 {
561 char *str = PG_GETARG_CSTRING(0);
562
563 #ifdef NOT_USED
564 Oid typelem = PG_GETARG_OID(1);
565 #endif
566 int32 typmod = PG_GETARG_INT32(2);
567 Numeric res;
568 const char *cp;
569
570 /* Skip leading spaces */
571 cp = str;
572 while (*cp)
573 {
574 if (!isspace((unsigned char) *cp))
575 break;
576 cp++;
577 }
578
579 /*
580 * Check for NaN
581 */
582 if (pg_strncasecmp(cp, "NaN", 3) == 0)
583 {
584 res = make_result(&const_nan);
585
586 /* Should be nothing left but spaces */
587 cp += 3;
588 while (*cp)
589 {
590 if (!isspace((unsigned char) *cp))
591 ereport(ERROR,
592 (errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
593 errmsg("invalid input syntax for type %s: \"%s\"",
594 "numeric", str)));
595 cp++;
596 }
597 }
598 else
599 {
600 /*
601 * Use set_var_from_str() to parse a normal numeric value
602 */
603 NumericVar value;
604
605 init_var(&value);
606
607 cp = set_var_from_str(str, cp, &value);
608
609 /*
610 * We duplicate a few lines of code here because we would like to
611 * throw any trailing-junk syntax error before any semantic error
612 * resulting from apply_typmod. We can't easily fold the two cases
613 * together because we mustn't apply apply_typmod to a NaN.
614 */
615 while (*cp)
616 {
617 if (!isspace((unsigned char) *cp))
618 ereport(ERROR,
619 (errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
620 errmsg("invalid input syntax for type %s: \"%s\"",
621 "numeric", str)));
622 cp++;
623 }
624
625 apply_typmod(&value, typmod);
626
627 res = make_result(&value);
628 free_var(&value);
629 }
630
631 PG_RETURN_NUMERIC(res);
632 }
633
634
635 /*
636 * numeric_out() -
637 *
638 * Output function for numeric data type
639 */
640 Datum
numeric_out(PG_FUNCTION_ARGS)641 numeric_out(PG_FUNCTION_ARGS)
642 {
643 Numeric num = PG_GETARG_NUMERIC(0);
644 NumericVar x;
645 char *str;
646
647 /*
648 * Handle NaN
649 */
650 if (NUMERIC_IS_NAN(num))
651 PG_RETURN_CSTRING(pstrdup("NaN"));
652
653 /*
654 * Get the number in the variable format.
655 */
656 init_var_from_num(num, &x);
657
658 str = get_str_from_var(&x);
659
660 PG_RETURN_CSTRING(str);
661 }
662
663 /*
664 * numeric_is_nan() -
665 *
666 * Is Numeric value a NaN?
667 */
668 bool
numeric_is_nan(Numeric num)669 numeric_is_nan(Numeric num)
670 {
671 return NUMERIC_IS_NAN(num);
672 }
673
674 /*
675 * numeric_maximum_size() -
676 *
677 * Maximum size of a numeric with given typmod, or -1 if unlimited/unknown.
678 */
679 int32
numeric_maximum_size(int32 typmod)680 numeric_maximum_size(int32 typmod)
681 {
682 int precision;
683 int numeric_digits;
684
685 if (typmod < (int32) (VARHDRSZ))
686 return -1;
687
688 /* precision (ie, max # of digits) is in upper bits of typmod */
689 precision = ((typmod - VARHDRSZ) >> 16) & 0xffff;
690
691 /*
692 * This formula computes the maximum number of NumericDigits we could need
693 * in order to store the specified number of decimal digits. Because the
694 * weight is stored as a number of NumericDigits rather than a number of
695 * decimal digits, it's possible that the first NumericDigit will contain
696 * only a single decimal digit. Thus, the first two decimal digits can
697 * require two NumericDigits to store, but it isn't until we reach
698 * DEC_DIGITS + 2 decimal digits that we potentially need a third
699 * NumericDigit.
700 */
701 numeric_digits = (precision + 2 * (DEC_DIGITS - 1)) / DEC_DIGITS;
702
703 /*
704 * In most cases, the size of a numeric will be smaller than the value
705 * computed below, because the varlena header will typically get toasted
706 * down to a single byte before being stored on disk, and it may also be
707 * possible to use a short numeric header. But our job here is to compute
708 * the worst case.
709 */
710 return NUMERIC_HDRSZ + (numeric_digits * sizeof(NumericDigit));
711 }
712
713 /*
714 * numeric_out_sci() -
715 *
716 * Output function for numeric data type in scientific notation.
717 */
718 char *
numeric_out_sci(Numeric num,int scale)719 numeric_out_sci(Numeric num, int scale)
720 {
721 NumericVar x;
722 char *str;
723
724 /*
725 * Handle NaN
726 */
727 if (NUMERIC_IS_NAN(num))
728 return pstrdup("NaN");
729
730 init_var_from_num(num, &x);
731
732 str = get_str_from_var_sci(&x, scale);
733
734 return str;
735 }
736
737 /*
738 * numeric_normalize() -
739 *
740 * Output function for numeric data type, suppressing insignificant trailing
741 * zeroes and then any trailing decimal point. The intent of this is to
742 * produce strings that are equal if and only if the input numeric values
743 * compare equal.
744 */
745 char *
numeric_normalize(Numeric num)746 numeric_normalize(Numeric num)
747 {
748 NumericVar x;
749 char *str;
750 int last;
751
752 /*
753 * Handle NaN
754 */
755 if (NUMERIC_IS_NAN(num))
756 return pstrdup("NaN");
757
758 init_var_from_num(num, &x);
759
760 str = get_str_from_var(&x);
761
762 /* If there's no decimal point, there's certainly nothing to remove. */
763 if (strchr(str, '.') != NULL)
764 {
765 /*
766 * Back up over trailing fractional zeroes. Since there is a decimal
767 * point, this loop will terminate safely.
768 */
769 last = strlen(str) - 1;
770 while (str[last] == '0')
771 last--;
772
773 /* We want to get rid of the decimal point too, if it's now last. */
774 if (str[last] == '.')
775 last--;
776
777 /* Delete whatever we backed up over. */
778 str[last + 1] = '\0';
779 }
780
781 return str;
782 }
783
784 /*
785 * numeric_recv - converts external binary format to numeric
786 *
787 * External format is a sequence of int16's:
788 * ndigits, weight, sign, dscale, NumericDigits.
789 */
790 Datum
numeric_recv(PG_FUNCTION_ARGS)791 numeric_recv(PG_FUNCTION_ARGS)
792 {
793 StringInfo buf = (StringInfo) PG_GETARG_POINTER(0);
794
795 #ifdef NOT_USED
796 Oid typelem = PG_GETARG_OID(1);
797 #endif
798 int32 typmod = PG_GETARG_INT32(2);
799 NumericVar value;
800 Numeric res;
801 int len,
802 i;
803
804 init_var(&value);
805
806 len = (uint16) pq_getmsgint(buf, sizeof(uint16));
807
808 alloc_var(&value, len);
809
810 value.weight = (int16) pq_getmsgint(buf, sizeof(int16));
811 /* we allow any int16 for weight --- OK? */
812
813 value.sign = (uint16) pq_getmsgint(buf, sizeof(uint16));
814 if (!(value.sign == NUMERIC_POS ||
815 value.sign == NUMERIC_NEG ||
816 value.sign == NUMERIC_NAN))
817 ereport(ERROR,
818 (errcode(ERRCODE_INVALID_BINARY_REPRESENTATION),
819 errmsg("invalid sign in external \"numeric\" value")));
820
821 value.dscale = (uint16) pq_getmsgint(buf, sizeof(uint16));
822 if ((value.dscale & NUMERIC_DSCALE_MASK) != value.dscale)
823 ereport(ERROR,
824 (errcode(ERRCODE_INVALID_BINARY_REPRESENTATION),
825 errmsg("invalid scale in external \"numeric\" value")));
826
827 for (i = 0; i < len; i++)
828 {
829 NumericDigit d = pq_getmsgint(buf, sizeof(NumericDigit));
830
831 if (d < 0 || d >= NBASE)
832 ereport(ERROR,
833 (errcode(ERRCODE_INVALID_BINARY_REPRESENTATION),
834 errmsg("invalid digit in external \"numeric\" value")));
835 value.digits[i] = d;
836 }
837
838 /*
839 * If the given dscale would hide any digits, truncate those digits away.
840 * We could alternatively throw an error, but that would take a bunch of
841 * extra code (about as much as trunc_var involves), and it might cause
842 * client compatibility issues.
843 */
844 trunc_var(&value, value.dscale);
845
846 apply_typmod(&value, typmod);
847
848 res = make_result(&value);
849 free_var(&value);
850
851 PG_RETURN_NUMERIC(res);
852 }
853
854 /*
855 * numeric_send - converts numeric to binary format
856 */
857 Datum
numeric_send(PG_FUNCTION_ARGS)858 numeric_send(PG_FUNCTION_ARGS)
859 {
860 Numeric num = PG_GETARG_NUMERIC(0);
861 NumericVar x;
862 StringInfoData buf;
863 int i;
864
865 init_var_from_num(num, &x);
866
867 pq_begintypsend(&buf);
868
869 pq_sendint16(&buf, x.ndigits);
870 pq_sendint16(&buf, x.weight);
871 pq_sendint16(&buf, x.sign);
872 pq_sendint16(&buf, x.dscale);
873 for (i = 0; i < x.ndigits; i++)
874 pq_sendint16(&buf, x.digits[i]);
875
876 PG_RETURN_BYTEA_P(pq_endtypsend(&buf));
877 }
878
879
880 /*
881 * numeric_support()
882 *
883 * Planner support function for the numeric() length coercion function.
884 *
885 * Flatten calls that solely represent increases in allowable precision.
886 * Scale changes mutate every datum, so they are unoptimizable. Some values,
887 * e.g. 1E-1001, can only fit into an unconstrained numeric, so a change from
888 * an unconstrained numeric to any constrained numeric is also unoptimizable.
889 */
890 Datum
numeric_support(PG_FUNCTION_ARGS)891 numeric_support(PG_FUNCTION_ARGS)
892 {
893 Node *rawreq = (Node *) PG_GETARG_POINTER(0);
894 Node *ret = NULL;
895
896 if (IsA(rawreq, SupportRequestSimplify))
897 {
898 SupportRequestSimplify *req = (SupportRequestSimplify *) rawreq;
899 FuncExpr *expr = req->fcall;
900 Node *typmod;
901
902 Assert(list_length(expr->args) >= 2);
903
904 typmod = (Node *) lsecond(expr->args);
905
906 if (IsA(typmod, Const) && !((Const *) typmod)->constisnull)
907 {
908 Node *source = (Node *) linitial(expr->args);
909 int32 old_typmod = exprTypmod(source);
910 int32 new_typmod = DatumGetInt32(((Const *) typmod)->constvalue);
911 int32 old_scale = (old_typmod - VARHDRSZ) & 0xffff;
912 int32 new_scale = (new_typmod - VARHDRSZ) & 0xffff;
913 int32 old_precision = (old_typmod - VARHDRSZ) >> 16 & 0xffff;
914 int32 new_precision = (new_typmod - VARHDRSZ) >> 16 & 0xffff;
915
916 /*
917 * If new_typmod < VARHDRSZ, the destination is unconstrained;
918 * that's always OK. If old_typmod >= VARHDRSZ, the source is
919 * constrained, and we're OK if the scale is unchanged and the
920 * precision is not decreasing. See further notes in function
921 * header comment.
922 */
923 if (new_typmod < (int32) VARHDRSZ ||
924 (old_typmod >= (int32) VARHDRSZ &&
925 new_scale == old_scale && new_precision >= old_precision))
926 ret = relabel_to_typmod(source, new_typmod);
927 }
928 }
929
930 PG_RETURN_POINTER(ret);
931 }
932
933 /*
934 * numeric() -
935 *
936 * This is a special function called by the Postgres database system
937 * before a value is stored in a tuple's attribute. The precision and
938 * scale of the attribute have to be applied on the value.
939 */
940 Datum
numeric(PG_FUNCTION_ARGS)941 numeric (PG_FUNCTION_ARGS)
942 {
943 Numeric num = PG_GETARG_NUMERIC(0);
944 int32 typmod = PG_GETARG_INT32(1);
945 Numeric new;
946 int32 tmp_typmod;
947 int precision;
948 int scale;
949 int ddigits;
950 int maxdigits;
951 NumericVar var;
952
953 /*
954 * Handle NaN
955 */
956 if (NUMERIC_IS_NAN(num))
957 PG_RETURN_NUMERIC(make_result(&const_nan));
958
959 /*
960 * If the value isn't a valid type modifier, simply return a copy of the
961 * input value
962 */
963 if (typmod < (int32) (VARHDRSZ))
964 {
965 new = (Numeric) palloc(VARSIZE(num));
966 memcpy(new, num, VARSIZE(num));
967 PG_RETURN_NUMERIC(new);
968 }
969
970 /*
971 * Get the precision and scale out of the typmod value
972 */
973 tmp_typmod = typmod - VARHDRSZ;
974 precision = (tmp_typmod >> 16) & 0xffff;
975 scale = tmp_typmod & 0xffff;
976 maxdigits = precision - scale;
977
978 /*
979 * If the number is certainly in bounds and due to the target scale no
980 * rounding could be necessary, just make a copy of the input and modify
981 * its scale fields, unless the larger scale forces us to abandon the
982 * short representation. (Note we assume the existing dscale is
983 * honest...)
984 */
985 ddigits = (NUMERIC_WEIGHT(num) + 1) * DEC_DIGITS;
986 if (ddigits <= maxdigits && scale >= NUMERIC_DSCALE(num)
987 && (NUMERIC_CAN_BE_SHORT(scale, NUMERIC_WEIGHT(num))
988 || !NUMERIC_IS_SHORT(num)))
989 {
990 new = (Numeric) palloc(VARSIZE(num));
991 memcpy(new, num, VARSIZE(num));
992 if (NUMERIC_IS_SHORT(num))
993 new->choice.n_short.n_header =
994 (num->choice.n_short.n_header & ~NUMERIC_SHORT_DSCALE_MASK)
995 | (scale << NUMERIC_SHORT_DSCALE_SHIFT);
996 else
997 new->choice.n_long.n_sign_dscale = NUMERIC_SIGN(new) |
998 ((uint16) scale & NUMERIC_DSCALE_MASK);
999 PG_RETURN_NUMERIC(new);
1000 }
1001
1002 /*
1003 * We really need to fiddle with things - unpack the number into a
1004 * variable and let apply_typmod() do it.
1005 */
1006 init_var(&var);
1007
1008 set_var_from_num(num, &var);
1009 apply_typmod(&var, typmod);
1010 new = make_result(&var);
1011
1012 free_var(&var);
1013
1014 PG_RETURN_NUMERIC(new);
1015 }
1016
1017 Datum
numerictypmodin(PG_FUNCTION_ARGS)1018 numerictypmodin(PG_FUNCTION_ARGS)
1019 {
1020 ArrayType *ta = PG_GETARG_ARRAYTYPE_P(0);
1021 int32 *tl;
1022 int n;
1023 int32 typmod;
1024
1025 tl = ArrayGetIntegerTypmods(ta, &n);
1026
1027 if (n == 2)
1028 {
1029 if (tl[0] < 1 || tl[0] > NUMERIC_MAX_PRECISION)
1030 ereport(ERROR,
1031 (errcode(ERRCODE_INVALID_PARAMETER_VALUE),
1032 errmsg("NUMERIC precision %d must be between 1 and %d",
1033 tl[0], NUMERIC_MAX_PRECISION)));
1034 if (tl[1] < 0 || tl[1] > tl[0])
1035 ereport(ERROR,
1036 (errcode(ERRCODE_INVALID_PARAMETER_VALUE),
1037 errmsg("NUMERIC scale %d must be between 0 and precision %d",
1038 tl[1], tl[0])));
1039 typmod = ((tl[0] << 16) | tl[1]) + VARHDRSZ;
1040 }
1041 else if (n == 1)
1042 {
1043 if (tl[0] < 1 || tl[0] > NUMERIC_MAX_PRECISION)
1044 ereport(ERROR,
1045 (errcode(ERRCODE_INVALID_PARAMETER_VALUE),
1046 errmsg("NUMERIC precision %d must be between 1 and %d",
1047 tl[0], NUMERIC_MAX_PRECISION)));
1048 /* scale defaults to zero */
1049 typmod = (tl[0] << 16) + VARHDRSZ;
1050 }
1051 else
1052 {
1053 ereport(ERROR,
1054 (errcode(ERRCODE_INVALID_PARAMETER_VALUE),
1055 errmsg("invalid NUMERIC type modifier")));
1056 typmod = 0; /* keep compiler quiet */
1057 }
1058
1059 PG_RETURN_INT32(typmod);
1060 }
1061
1062 Datum
numerictypmodout(PG_FUNCTION_ARGS)1063 numerictypmodout(PG_FUNCTION_ARGS)
1064 {
1065 int32 typmod = PG_GETARG_INT32(0);
1066 char *res = (char *) palloc(64);
1067
1068 if (typmod >= 0)
1069 snprintf(res, 64, "(%d,%d)",
1070 ((typmod - VARHDRSZ) >> 16) & 0xffff,
1071 (typmod - VARHDRSZ) & 0xffff);
1072 else
1073 *res = '\0';
1074
1075 PG_RETURN_CSTRING(res);
1076 }
1077
1078
1079 /* ----------------------------------------------------------------------
1080 *
1081 * Sign manipulation, rounding and the like
1082 *
1083 * ----------------------------------------------------------------------
1084 */
1085
1086 Datum
numeric_abs(PG_FUNCTION_ARGS)1087 numeric_abs(PG_FUNCTION_ARGS)
1088 {
1089 Numeric num = PG_GETARG_NUMERIC(0);
1090 Numeric res;
1091
1092 /*
1093 * Handle NaN
1094 */
1095 if (NUMERIC_IS_NAN(num))
1096 PG_RETURN_NUMERIC(make_result(&const_nan));
1097
1098 /*
1099 * Do it the easy way directly on the packed format
1100 */
1101 res = (Numeric) palloc(VARSIZE(num));
1102 memcpy(res, num, VARSIZE(num));
1103
1104 if (NUMERIC_IS_SHORT(num))
1105 res->choice.n_short.n_header =
1106 num->choice.n_short.n_header & ~NUMERIC_SHORT_SIGN_MASK;
1107 else
1108 res->choice.n_long.n_sign_dscale = NUMERIC_POS | NUMERIC_DSCALE(num);
1109
1110 PG_RETURN_NUMERIC(res);
1111 }
1112
1113
1114 Datum
numeric_uminus(PG_FUNCTION_ARGS)1115 numeric_uminus(PG_FUNCTION_ARGS)
1116 {
1117 Numeric num = PG_GETARG_NUMERIC(0);
1118 Numeric res;
1119
1120 /*
1121 * Handle NaN
1122 */
1123 if (NUMERIC_IS_NAN(num))
1124 PG_RETURN_NUMERIC(make_result(&const_nan));
1125
1126 /*
1127 * Do it the easy way directly on the packed format
1128 */
1129 res = (Numeric) palloc(VARSIZE(num));
1130 memcpy(res, num, VARSIZE(num));
1131
1132 /*
1133 * The packed format is known to be totally zero digit trimmed always. So
1134 * we can identify a ZERO by the fact that there are no digits at all. Do
1135 * nothing to a zero.
1136 */
1137 if (NUMERIC_NDIGITS(num) != 0)
1138 {
1139 /* Else, flip the sign */
1140 if (NUMERIC_IS_SHORT(num))
1141 res->choice.n_short.n_header =
1142 num->choice.n_short.n_header ^ NUMERIC_SHORT_SIGN_MASK;
1143 else if (NUMERIC_SIGN(num) == NUMERIC_POS)
1144 res->choice.n_long.n_sign_dscale =
1145 NUMERIC_NEG | NUMERIC_DSCALE(num);
1146 else
1147 res->choice.n_long.n_sign_dscale =
1148 NUMERIC_POS | NUMERIC_DSCALE(num);
1149 }
1150
1151 PG_RETURN_NUMERIC(res);
1152 }
1153
1154
1155 Datum
numeric_uplus(PG_FUNCTION_ARGS)1156 numeric_uplus(PG_FUNCTION_ARGS)
1157 {
1158 Numeric num = PG_GETARG_NUMERIC(0);
1159 Numeric res;
1160
1161 res = (Numeric) palloc(VARSIZE(num));
1162 memcpy(res, num, VARSIZE(num));
1163
1164 PG_RETURN_NUMERIC(res);
1165 }
1166
1167 /*
1168 * numeric_sign() -
1169 *
1170 * returns -1 if the argument is less than 0, 0 if the argument is equal
1171 * to 0, and 1 if the argument is greater than zero.
1172 */
1173 Datum
numeric_sign(PG_FUNCTION_ARGS)1174 numeric_sign(PG_FUNCTION_ARGS)
1175 {
1176 Numeric num = PG_GETARG_NUMERIC(0);
1177 Numeric res;
1178 NumericVar result;
1179
1180 /*
1181 * Handle NaN
1182 */
1183 if (NUMERIC_IS_NAN(num))
1184 PG_RETURN_NUMERIC(make_result(&const_nan));
1185
1186 init_var(&result);
1187
1188 /*
1189 * The packed format is known to be totally zero digit trimmed always. So
1190 * we can identify a ZERO by the fact that there are no digits at all.
1191 */
1192 if (NUMERIC_NDIGITS(num) == 0)
1193 set_var_from_var(&const_zero, &result);
1194 else
1195 {
1196 /*
1197 * And if there are some, we return a copy of ONE with the sign of our
1198 * argument
1199 */
1200 set_var_from_var(&const_one, &result);
1201 result.sign = NUMERIC_SIGN(num);
1202 }
1203
1204 res = make_result(&result);
1205 free_var(&result);
1206
1207 PG_RETURN_NUMERIC(res);
1208 }
1209
1210
1211 /*
1212 * numeric_round() -
1213 *
1214 * Round a value to have 'scale' digits after the decimal point.
1215 * We allow negative 'scale', implying rounding before the decimal
1216 * point --- Oracle interprets rounding that way.
1217 */
1218 Datum
numeric_round(PG_FUNCTION_ARGS)1219 numeric_round(PG_FUNCTION_ARGS)
1220 {
1221 Numeric num = PG_GETARG_NUMERIC(0);
1222 int32 scale = PG_GETARG_INT32(1);
1223 Numeric res;
1224 NumericVar arg;
1225
1226 /*
1227 * Handle NaN
1228 */
1229 if (NUMERIC_IS_NAN(num))
1230 PG_RETURN_NUMERIC(make_result(&const_nan));
1231
1232 /*
1233 * Limit the scale value to avoid possible overflow in calculations
1234 */
1235 scale = Max(scale, -NUMERIC_MAX_RESULT_SCALE);
1236 scale = Min(scale, NUMERIC_MAX_RESULT_SCALE);
1237
1238 /*
1239 * Unpack the argument and round it at the proper digit position
1240 */
1241 init_var(&arg);
1242 set_var_from_num(num, &arg);
1243
1244 round_var(&arg, scale);
1245
1246 /* We don't allow negative output dscale */
1247 if (scale < 0)
1248 arg.dscale = 0;
1249
1250 /*
1251 * Return the rounded result
1252 */
1253 res = make_result(&arg);
1254
1255 free_var(&arg);
1256 PG_RETURN_NUMERIC(res);
1257 }
1258
1259
1260 /*
1261 * numeric_trunc() -
1262 *
1263 * Truncate a value to have 'scale' digits after the decimal point.
1264 * We allow negative 'scale', implying a truncation before the decimal
1265 * point --- Oracle interprets truncation that way.
1266 */
1267 Datum
numeric_trunc(PG_FUNCTION_ARGS)1268 numeric_trunc(PG_FUNCTION_ARGS)
1269 {
1270 Numeric num = PG_GETARG_NUMERIC(0);
1271 int32 scale = PG_GETARG_INT32(1);
1272 Numeric res;
1273 NumericVar arg;
1274
1275 /*
1276 * Handle NaN
1277 */
1278 if (NUMERIC_IS_NAN(num))
1279 PG_RETURN_NUMERIC(make_result(&const_nan));
1280
1281 /*
1282 * Limit the scale value to avoid possible overflow in calculations
1283 */
1284 scale = Max(scale, -NUMERIC_MAX_RESULT_SCALE);
1285 scale = Min(scale, NUMERIC_MAX_RESULT_SCALE);
1286
1287 /*
1288 * Unpack the argument and truncate it at the proper digit position
1289 */
1290 init_var(&arg);
1291 set_var_from_num(num, &arg);
1292
1293 trunc_var(&arg, scale);
1294
1295 /* We don't allow negative output dscale */
1296 if (scale < 0)
1297 arg.dscale = 0;
1298
1299 /*
1300 * Return the truncated result
1301 */
1302 res = make_result(&arg);
1303
1304 free_var(&arg);
1305 PG_RETURN_NUMERIC(res);
1306 }
1307
1308
1309 /*
1310 * numeric_ceil() -
1311 *
1312 * Return the smallest integer greater than or equal to the argument
1313 */
1314 Datum
numeric_ceil(PG_FUNCTION_ARGS)1315 numeric_ceil(PG_FUNCTION_ARGS)
1316 {
1317 Numeric num = PG_GETARG_NUMERIC(0);
1318 Numeric res;
1319 NumericVar result;
1320
1321 if (NUMERIC_IS_NAN(num))
1322 PG_RETURN_NUMERIC(make_result(&const_nan));
1323
1324 init_var_from_num(num, &result);
1325 ceil_var(&result, &result);
1326
1327 res = make_result(&result);
1328 free_var(&result);
1329
1330 PG_RETURN_NUMERIC(res);
1331 }
1332
1333
1334 /*
1335 * numeric_floor() -
1336 *
1337 * Return the largest integer equal to or less than the argument
1338 */
1339 Datum
numeric_floor(PG_FUNCTION_ARGS)1340 numeric_floor(PG_FUNCTION_ARGS)
1341 {
1342 Numeric num = PG_GETARG_NUMERIC(0);
1343 Numeric res;
1344 NumericVar result;
1345
1346 if (NUMERIC_IS_NAN(num))
1347 PG_RETURN_NUMERIC(make_result(&const_nan));
1348
1349 init_var_from_num(num, &result);
1350 floor_var(&result, &result);
1351
1352 res = make_result(&result);
1353 free_var(&result);
1354
1355 PG_RETURN_NUMERIC(res);
1356 }
1357
1358
1359 /*
1360 * generate_series_numeric() -
1361 *
1362 * Generate series of numeric.
1363 */
1364 Datum
generate_series_numeric(PG_FUNCTION_ARGS)1365 generate_series_numeric(PG_FUNCTION_ARGS)
1366 {
1367 return generate_series_step_numeric(fcinfo);
1368 }
1369
/*
 * generate_series_step_numeric() -
 *
 * Set-returning generate_series(start, stop [, step]) for numeric.  It emits
 * start, start+step, start+2*step, ... for as long as the current value has
 * not passed 'stop': <= stop for a positive step, >= stop for a negative one.
 * With only two arguments the step is one.  Raises an error if start, stop
 * or step is NaN, or if the step is zero.
 */
Datum
generate_series_step_numeric(PG_FUNCTION_ARGS)
{
	generate_series_numeric_fctx *fctx;
	FuncCallContext *funcctx;
	MemoryContext oldcontext;

	/* First call: validate arguments and build the cross-call state. */
	if (SRF_IS_FIRSTCALL())
	{
		Numeric		start_num = PG_GETARG_NUMERIC(0);
		Numeric		stop_num = PG_GETARG_NUMERIC(1);
		/* default step of one, replaced below if a step argument is given */
		NumericVar	steploc = const_one;

		/* handle NaN in start and stop values */
		if (NUMERIC_IS_NAN(start_num))
			ereport(ERROR,
					(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
					 errmsg("start value cannot be NaN")));

		if (NUMERIC_IS_NAN(stop_num))
			ereport(ERROR,
					(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
					 errmsg("stop value cannot be NaN")));

		/* see if we were given an explicit step size */
		if (PG_NARGS() == 3)
		{
			Numeric		step_num = PG_GETARG_NUMERIC(2);

			if (NUMERIC_IS_NAN(step_num))
				ereport(ERROR,
						(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
						 errmsg("step size cannot be NaN")));

			init_var_from_num(step_num, &steploc);

			/* a zero step would never reach 'stop' */
			if (cmp_var(&steploc, &const_zero) == 0)
				ereport(ERROR,
						(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
						 errmsg("step size cannot equal zero")));
		}

		/* create a function context for cross-call persistence */
		funcctx = SRF_FIRSTCALL_INIT();

		/*
		 * Switch to memory context appropriate for multiple function calls.
		 */
		oldcontext = MemoryContextSwitchTo(funcctx->multi_call_memory_ctx);

		/* allocate memory for user context */
		fctx = (generate_series_numeric_fctx *)
			palloc(sizeof(generate_series_numeric_fctx));

		/*
		 * Use fctx to keep state from call to call. Seed current with the
		 * original start value. We must copy the start_num and stop_num
		 * values rather than pointing to them, since we may have detoasted
		 * them in the per-call context.  (steploc is likewise copied into
		 * fctx->step while the multi-call context is current.)
		 */
		init_var(&fctx->current);
		init_var(&fctx->stop);
		init_var(&fctx->step);

		set_var_from_num(start_num, &fctx->current);
		set_var_from_num(stop_num, &fctx->stop);
		set_var_from_var(&steploc, &fctx->step);

		funcctx->user_fctx = fctx;
		MemoryContextSwitchTo(oldcontext);
	}

	/* stuff done on every call of the function */
	funcctx = SRF_PERCALL_SETUP();

	/*
	 * Get the saved state and use current state as the result of this
	 * iteration.
	 */
	fctx = funcctx->user_fctx;

	/* keep going while 'current' has not passed 'stop' in the step direction */
	if ((fctx->step.sign == NUMERIC_POS &&
		 cmp_var(&fctx->current, &fctx->stop) <= 0) ||
		(fctx->step.sign == NUMERIC_NEG &&
		 cmp_var(&fctx->current, &fctx->stop) >= 0))
	{
		/* the returned datum is built before the increment below */
		Numeric		result = make_result(&fctx->current);

		/* switch to memory context appropriate for iteration calculation */
		oldcontext = MemoryContextSwitchTo(funcctx->multi_call_memory_ctx);

		/* increment current in preparation for next iteration */
		add_var(&fctx->current, &fctx->step, &fctx->current);
		MemoryContextSwitchTo(oldcontext);

		/* do when there is more left to send */
		SRF_RETURN_NEXT(funcctx, NumericGetDatum(result));
	}
	else
		/* do when there is no more left */
		SRF_RETURN_DONE(funcctx);
}
1472
1473
1474 /*
1475 * Implements the numeric version of the width_bucket() function
1476 * defined by SQL2003. See also width_bucket_float8().
1477 *
1478 * 'bound1' and 'bound2' are the lower and upper bounds of the
1479 * histogram's range, respectively. 'count' is the number of buckets
1480 * in the histogram. width_bucket() returns an integer indicating the
1481 * bucket number that 'operand' belongs to in an equiwidth histogram
1482 * with the specified characteristics. An operand smaller than the
1483 * lower bound is assigned to bucket 0. An operand greater than the
1484 * upper bound is assigned to an additional bucket (with number
1485 * count+1). We don't allow "NaN" for any of the numeric arguments.
1486 */
1487 Datum
width_bucket_numeric(PG_FUNCTION_ARGS)1488 width_bucket_numeric(PG_FUNCTION_ARGS)
1489 {
1490 Numeric operand = PG_GETARG_NUMERIC(0);
1491 Numeric bound1 = PG_GETARG_NUMERIC(1);
1492 Numeric bound2 = PG_GETARG_NUMERIC(2);
1493 int32 count = PG_GETARG_INT32(3);
1494 NumericVar count_var;
1495 NumericVar result_var;
1496 int32 result;
1497
1498 if (count <= 0)
1499 ereport(ERROR,
1500 (errcode(ERRCODE_INVALID_ARGUMENT_FOR_WIDTH_BUCKET_FUNCTION),
1501 errmsg("count must be greater than zero")));
1502
1503 if (NUMERIC_IS_NAN(operand) ||
1504 NUMERIC_IS_NAN(bound1) ||
1505 NUMERIC_IS_NAN(bound2))
1506 ereport(ERROR,
1507 (errcode(ERRCODE_INVALID_ARGUMENT_FOR_WIDTH_BUCKET_FUNCTION),
1508 errmsg("operand, lower bound, and upper bound cannot be NaN")));
1509
1510 init_var(&result_var);
1511 init_var(&count_var);
1512
1513 /* Convert 'count' to a numeric, for ease of use later */
1514 int64_to_numericvar((int64) count, &count_var);
1515
1516 switch (cmp_numerics(bound1, bound2))
1517 {
1518 case 0:
1519 ereport(ERROR,
1520 (errcode(ERRCODE_INVALID_ARGUMENT_FOR_WIDTH_BUCKET_FUNCTION),
1521 errmsg("lower bound cannot equal upper bound")));
1522 break;
1523
1524 /* bound1 < bound2 */
1525 case -1:
1526 if (cmp_numerics(operand, bound1) < 0)
1527 set_var_from_var(&const_zero, &result_var);
1528 else if (cmp_numerics(operand, bound2) >= 0)
1529 add_var(&count_var, &const_one, &result_var);
1530 else
1531 compute_bucket(operand, bound1, bound2,
1532 &count_var, &result_var);
1533 break;
1534
1535 /* bound1 > bound2 */
1536 case 1:
1537 if (cmp_numerics(operand, bound1) > 0)
1538 set_var_from_var(&const_zero, &result_var);
1539 else if (cmp_numerics(operand, bound2) <= 0)
1540 add_var(&count_var, &const_one, &result_var);
1541 else
1542 compute_bucket(operand, bound1, bound2,
1543 &count_var, &result_var);
1544 break;
1545 }
1546
1547 /* if result exceeds the range of a legal int4, we ereport here */
1548 if (!numericvar_to_int32(&result_var, &result))
1549 ereport(ERROR,
1550 (errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
1551 errmsg("integer out of range")));
1552
1553 free_var(&count_var);
1554 free_var(&result_var);
1555
1556 PG_RETURN_INT32(result);
1557 }
1558
1559 /*
1560 * If 'operand' is not outside the bucket range, determine the correct
1561 * bucket for it to go. The calculations performed by this function
1562 * are derived directly from the SQL2003 spec.
1563 */
1564 static void
compute_bucket(Numeric operand,Numeric bound1,Numeric bound2,const NumericVar * count_var,NumericVar * result_var)1565 compute_bucket(Numeric operand, Numeric bound1, Numeric bound2,
1566 const NumericVar *count_var, NumericVar *result_var)
1567 {
1568 NumericVar bound1_var;
1569 NumericVar bound2_var;
1570 NumericVar operand_var;
1571
1572 init_var_from_num(bound1, &bound1_var);
1573 init_var_from_num(bound2, &bound2_var);
1574 init_var_from_num(operand, &operand_var);
1575
1576 if (cmp_var(&bound1_var, &bound2_var) < 0)
1577 {
1578 sub_var(&operand_var, &bound1_var, &operand_var);
1579 sub_var(&bound2_var, &bound1_var, &bound2_var);
1580 div_var(&operand_var, &bound2_var, result_var,
1581 select_div_scale(&operand_var, &bound2_var), true);
1582 }
1583 else
1584 {
1585 sub_var(&bound1_var, &operand_var, &operand_var);
1586 sub_var(&bound1_var, &bound2_var, &bound1_var);
1587 div_var(&operand_var, &bound1_var, result_var,
1588 select_div_scale(&operand_var, &bound1_var), true);
1589 }
1590
1591 mul_var(result_var, count_var, result_var,
1592 result_var->dscale + count_var->dscale);
1593 add_var(result_var, &const_one, result_var);
1594 floor_var(result_var, result_var);
1595
1596 free_var(&bound1_var);
1597 free_var(&bound2_var);
1598 free_var(&operand_var);
1599 }
1600
1601 /* ----------------------------------------------------------------------
1602 *
1603 * Comparison functions
1604 *
1605 * Note: btree indexes need these routines not to leak memory; therefore,
1606 * be careful to free working copies of toasted datums. Most places don't
1607 * need to be so careful.
1608 *
1609 * Sort support:
1610 *
1611 * We implement the sortsupport strategy routine in order to get the benefit of
1612 * abbreviation. The ordinary numeric comparison can be quite slow as a result
1613 * of palloc/pfree cycles (due to detoasting packed values for alignment);
1614 * while this could be worked on itself, the abbreviation strategy gives more
1615 * speedup in many common cases.
1616 *
1617 * Two different representations are used for the abbreviated form, one in
1618 * int32 and one in int64, whichever fits into a by-value Datum. In both cases
1619 * the representation is negated relative to the original value, because we use
1620 * the largest negative value for NaN, which sorts higher than other values. We
1621 * convert the absolute value of the numeric to a 31-bit or 63-bit positive
1622 * value, and then negate it if the original number was positive.
1623 *
1624 * We abort the abbreviation process if the abbreviation cardinality is below
1625 * 0.01% of the row count (1 per 10k non-null rows). The actual break-even
1626 * point is somewhat below that, perhaps 1 per 30k (at 1 per 100k there's a
1627 * very small penalty), but we don't want to build up too many abbreviated
1628 * values before first testing for abort, so we take the slightly pessimistic
1629 * number. We make no attempt to estimate the cardinality of the real values,
1630 * since it plays no part in the cost model here (if the abbreviation is equal,
1631 * the cost of comparing equal and unequal underlying values is comparable).
1632 * We discontinue even checking for abort (saving us the hashing overhead) if
1633 * the estimated cardinality gets to 100k; that would be enough to support many
1634 * billions of rows while doing no worse than breaking even.
1635 *
1636 * ----------------------------------------------------------------------
1637 */
1638
1639 /*
1640 * Sort support strategy routine.
1641 */
1642 Datum
numeric_sortsupport(PG_FUNCTION_ARGS)1643 numeric_sortsupport(PG_FUNCTION_ARGS)
1644 {
1645 SortSupport ssup = (SortSupport) PG_GETARG_POINTER(0);
1646
1647 ssup->comparator = numeric_fast_cmp;
1648
1649 if (ssup->abbreviate)
1650 {
1651 NumericSortSupport *nss;
1652 MemoryContext oldcontext = MemoryContextSwitchTo(ssup->ssup_cxt);
1653
1654 nss = palloc(sizeof(NumericSortSupport));
1655
1656 /*
1657 * palloc a buffer for handling unaligned packed values in addition to
1658 * the support struct
1659 */
1660 nss->buf = palloc(VARATT_SHORT_MAX + VARHDRSZ + 1);
1661
1662 nss->input_count = 0;
1663 nss->estimating = true;
1664 initHyperLogLog(&nss->abbr_card, 10);
1665
1666 ssup->ssup_extra = nss;
1667
1668 ssup->abbrev_full_comparator = ssup->comparator;
1669 ssup->comparator = numeric_cmp_abbrev;
1670 ssup->abbrev_converter = numeric_abbrev_convert;
1671 ssup->abbrev_abort = numeric_abbrev_abort;
1672
1673 MemoryContextSwitchTo(oldcontext);
1674 }
1675
1676 PG_RETURN_VOID();
1677 }
1678
1679 /*
1680 * Abbreviate a numeric datum, handling NaNs and detoasting
1681 * (must not leak memory!)
1682 */
1683 static Datum
numeric_abbrev_convert(Datum original_datum,SortSupport ssup)1684 numeric_abbrev_convert(Datum original_datum, SortSupport ssup)
1685 {
1686 NumericSortSupport *nss = ssup->ssup_extra;
1687 void *original_varatt = PG_DETOAST_DATUM_PACKED(original_datum);
1688 Numeric value;
1689 Datum result;
1690
1691 nss->input_count += 1;
1692
1693 /*
1694 * This is to handle packed datums without needing a palloc/pfree cycle;
1695 * we keep and reuse a buffer large enough to handle any short datum.
1696 */
1697 if (VARATT_IS_SHORT(original_varatt))
1698 {
1699 void *buf = nss->buf;
1700 Size sz = VARSIZE_SHORT(original_varatt) - VARHDRSZ_SHORT;
1701
1702 Assert(sz <= VARATT_SHORT_MAX - VARHDRSZ_SHORT);
1703
1704 SET_VARSIZE(buf, VARHDRSZ + sz);
1705 memcpy(VARDATA(buf), VARDATA_SHORT(original_varatt), sz);
1706
1707 value = (Numeric) buf;
1708 }
1709 else
1710 value = (Numeric) original_varatt;
1711
1712 if (NUMERIC_IS_NAN(value))
1713 {
1714 result = NUMERIC_ABBREV_NAN;
1715 }
1716 else
1717 {
1718 NumericVar var;
1719
1720 init_var_from_num(value, &var);
1721
1722 result = numeric_abbrev_convert_var(&var, nss);
1723 }
1724
1725 /* should happen only for external/compressed toasts */
1726 if ((Pointer) original_varatt != DatumGetPointer(original_datum))
1727 pfree(original_varatt);
1728
1729 return result;
1730 }
1731
1732 /*
1733 * Consider whether to abort abbreviation.
1734 *
1735 * We pay no attention to the cardinality of the non-abbreviated data. There is
1736 * no reason to do so: unlike text, we have no fast check for equal values, so
1737 * we pay the full overhead whenever the abbreviations are equal regardless of
1738 * whether the underlying values are also equal.
1739 */
1740 static bool
numeric_abbrev_abort(int memtupcount,SortSupport ssup)1741 numeric_abbrev_abort(int memtupcount, SortSupport ssup)
1742 {
1743 NumericSortSupport *nss = ssup->ssup_extra;
1744 double abbr_card;
1745
1746 if (memtupcount < 10000 || nss->input_count < 10000 || !nss->estimating)
1747 return false;
1748
1749 abbr_card = estimateHyperLogLog(&nss->abbr_card);
1750
1751 /*
1752 * If we have >100k distinct values, then even if we were sorting many
1753 * billion rows we'd likely still break even, and the penalty of undoing
1754 * that many rows of abbrevs would probably not be worth it. Stop even
1755 * counting at that point.
1756 */
1757 if (abbr_card > 100000.0)
1758 {
1759 #ifdef TRACE_SORT
1760 if (trace_sort)
1761 elog(LOG,
1762 "numeric_abbrev: estimation ends at cardinality %f"
1763 " after " INT64_FORMAT " values (%d rows)",
1764 abbr_card, nss->input_count, memtupcount);
1765 #endif
1766 nss->estimating = false;
1767 return false;
1768 }
1769
1770 /*
1771 * Target minimum cardinality is 1 per ~10k of non-null inputs. (The
1772 * break even point is somewhere between one per 100k rows, where
1773 * abbreviation has a very slight penalty, and 1 per 10k where it wins by
1774 * a measurable percentage.) We use the relatively pessimistic 10k
1775 * threshold, and add a 0.5 row fudge factor, because it allows us to
1776 * abort earlier on genuinely pathological data where we've had exactly
1777 * one abbreviated value in the first 10k (non-null) rows.
1778 */
1779 if (abbr_card < nss->input_count / 10000.0 + 0.5)
1780 {
1781 #ifdef TRACE_SORT
1782 if (trace_sort)
1783 elog(LOG,
1784 "numeric_abbrev: aborting abbreviation at cardinality %f"
1785 " below threshold %f after " INT64_FORMAT " values (%d rows)",
1786 abbr_card, nss->input_count / 10000.0 + 0.5,
1787 nss->input_count, memtupcount);
1788 #endif
1789 return true;
1790 }
1791
1792 #ifdef TRACE_SORT
1793 if (trace_sort)
1794 elog(LOG,
1795 "numeric_abbrev: cardinality %f"
1796 " after " INT64_FORMAT " values (%d rows)",
1797 abbr_card, nss->input_count, memtupcount);
1798 #endif
1799
1800 return false;
1801 }
1802
1803 /*
1804 * Non-fmgr interface to the comparison routine to allow sortsupport to elide
1805 * the fmgr call. The saving here is small given how slow numeric comparisons
1806 * are, but it is a required part of the sort support API when abbreviations
1807 * are performed.
1808 *
1809 * Two palloc/pfree cycles could be saved here by using persistent buffers for
1810 * aligning short-varlena inputs, but this has not so far been considered to
1811 * be worth the effort.
1812 */
1813 static int
numeric_fast_cmp(Datum x,Datum y,SortSupport ssup)1814 numeric_fast_cmp(Datum x, Datum y, SortSupport ssup)
1815 {
1816 Numeric nx = DatumGetNumeric(x);
1817 Numeric ny = DatumGetNumeric(y);
1818 int result;
1819
1820 result = cmp_numerics(nx, ny);
1821
1822 if ((Pointer) nx != DatumGetPointer(x))
1823 pfree(nx);
1824 if ((Pointer) ny != DatumGetPointer(y))
1825 pfree(ny);
1826
1827 return result;
1828 }
1829
1830 /*
1831 * Compare abbreviations of values. (Abbreviations may be equal where the true
1832 * values differ, but if the abbreviations differ, they must reflect the
1833 * ordering of the true values.)
1834 */
1835 static int
numeric_cmp_abbrev(Datum x,Datum y,SortSupport ssup)1836 numeric_cmp_abbrev(Datum x, Datum y, SortSupport ssup)
1837 {
1838 /*
1839 * NOTE WELL: this is intentionally backwards, because the abbreviation is
1840 * negated relative to the original value, to handle NaN.
1841 */
1842 if (DatumGetNumericAbbrev(x) < DatumGetNumericAbbrev(y))
1843 return 1;
1844 if (DatumGetNumericAbbrev(x) > DatumGetNumericAbbrev(y))
1845 return -1;
1846 return 0;
1847 }
1848
1849 /*
1850 * Abbreviate a NumericVar according to the available bit size.
1851 *
1852 * The 31-bit value is constructed as:
1853 *
1854 * 0 + 7bits digit weight + 24 bits digit value
1855 *
1856 * where the digit weight is in single decimal digits, not digit words, and
1857 * stored in excess-44 representation[1]. The 24-bit digit value is the 7 most
1858 * significant decimal digits of the value converted to binary. Values whose
1859 * weights would fall outside the representable range are rounded off to zero
1860 * (which is also used to represent actual zeros) or to 0x7FFFFFFF (which
1861 * otherwise cannot occur). Abbreviation therefore fails to gain any advantage
1862 * where values are outside the range 10^-44 to 10^83, which is not considered
1863 * to be a serious limitation, or when values are of the same magnitude and
1864 * equal in the first 7 decimal digits, which is considered to be an
1865 * unavoidable limitation given the available bits. (Stealing three more bits
1866 * to compare another digit would narrow the range of representable weights by
1867 * a factor of 8, which starts to look like a real limiting factor.)
1868 *
1869 * (The value 44 for the excess is essentially arbitrary)
1870 *
1871 * The 63-bit value is constructed as:
1872 *
1873 * 0 + 7bits weight + 4 x 14-bit packed digit words
1874 *
1875 * The weight in this case is again stored in excess-44, but this time it is
1876 * the original weight in digit words (i.e. powers of 10000). The first four
1877 * digit words of the value (if present; trailing zeros are assumed as needed)
1878 * are packed into 14 bits each to form the rest of the value. Again,
1879 * out-of-range values are rounded off to 0 or 0x7FFFFFFFFFFFFFFF. The
1880 * representable range in this case is 10^-176 to 10^332, which is considered
1881 * to be good enough for all practical purposes, and comparison of 4 words
1882 * means that at least 13 decimal digits are compared, which is considered to
1883 * be a reasonable compromise between effectiveness and efficiency in computing
1884 * the abbreviation.
1885 *
1886 * (The value 44 for the excess is even more arbitrary here, it was chosen just
1887 * to match the value used in the 31-bit case)
1888 *
1889 * [1] - Excess-k representation means that the value is offset by adding 'k'
1890 * and then treated as unsigned, so the smallest representable value is stored
1891 * with all bits zero. This allows simple comparisons to work on the composite
1892 * value.
1893 */
1894
1895 #if NUMERIC_ABBREV_BITS == 64
1896
/*
 * numeric_abbrev_convert_var() - 64-bit variant
 *
 * Build the abbreviated sort key for "var".  Layout: the weight, biased by
 * 44 (excess-44), in bits 56 and up; below it, up to four leading NBASE
 * digits in 14-bit slots (shifts 42, 28, 14, 0), most significant first.
 * Zero and values with weight < -44 abbreviate to 0; values with weight > 83
 * saturate to PG_INT64_MAX.  The key for a positive value is negated (the
 * abbreviation is ordered opposite to the original value).
 *
 * When nss->estimating is set, the key is folded to 32 bits and added to the
 * HyperLogLog estimator nss->abbr_card.
 */
static Datum
numeric_abbrev_convert_var(const NumericVar *var, NumericSortSupport *nss)
{
	int			ndigits = var->ndigits;
	int			weight = var->weight;
	int64		result;

	if (ndigits == 0 || weight < -44)
	{
		/* zero, or too small in magnitude to be represented */
		result = 0;
	}
	else if (weight > 83)
	{
		/* too large in magnitude: saturate */
		result = PG_INT64_MAX;
	}
	else
	{
		/* biased weight goes in the top byte */
		result = ((int64) (weight + 44) << 56);

		/*
		 * Pack as many of the first four digits as exist; slots for missing
		 * trailing digits stay zero.
		 */
		switch (ndigits)
		{
			default:
				result |= ((int64) var->digits[3]);
				/* FALLTHROUGH */
			case 3:
				result |= ((int64) var->digits[2]) << 14;
				/* FALLTHROUGH */
			case 2:
				result |= ((int64) var->digits[1]) << 28;
				/* FALLTHROUGH */
			case 1:
				result |= ((int64) var->digits[0]) << 42;
				break;
		}
	}

	/* the abbrev is negated relative to the original */
	if (var->sign == NUMERIC_POS)
		result = -result;

	if (nss->estimating)
	{
		/* fold the 64-bit key to 32 bits before hashing it */
		uint32		tmp = ((uint32) result
						   ^ (uint32) ((uint64) result >> 32));

		addHyperLogLog(&nss->abbr_card, DatumGetUInt32(hash_uint32(tmp)));
	}

	return NumericAbbrevGetDatum(result);
}
1947
1948 #endif /* NUMERIC_ABBREV_BITS == 64 */
1949
1950 #if NUMERIC_ABBREV_BITS == 32
1951
/*
 * numeric_abbrev_convert_var() - 32-bit variant
 *
 * Build the abbreviated sort key for "var".  The low 24 bits hold the first
 * seven significant decimal digits.  Bits 24 and up hold a decimal-scaled
 * weight: (weight + 11) * 4, plus 0-3 depending on how many decimal digits
 * the leading NBASE digit contributes.  Zero and values with weight < -11
 * abbreviate to 0; values with weight > 20 saturate to PG_INT32_MAX.  The
 * key for a positive value is negated (the abbreviation is ordered opposite
 * to the original value).
 *
 * When nss->estimating is set, the key is added to the HyperLogLog
 * estimator nss->abbr_card.
 */
static Datum
numeric_abbrev_convert_var(const NumericVar *var, NumericSortSupport *nss)
{
	int			ndigits = var->ndigits;
	int			weight = var->weight;
	int32		result;

	if (ndigits == 0 || weight < -11)
	{
		/* zero, or too small in magnitude to be represented */
		result = 0;
	}
	else if (weight > 20)
	{
		/* too large in magnitude: saturate */
		result = PG_INT32_MAX;
	}
	else
	{
		/* second NBASE digit, or 0 if there is none */
		NumericDigit nxt1 = (ndigits > 1) ? var->digits[1] : 0;

		/* bias the weight and convert it to units of decimal digits */
		weight = (weight + 11) * 4;

		result = var->digits[0];

		/*
		 * "result" now has 1 to 4 nonzero decimal digits. We pack in more
		 * digits to make 7 in total (largest we can fit in 24 bits)
		 */

		if (result > 999)
		{
			/* already have 4 digits, add 3 more */
			result = (result * 1000) + (nxt1 / 10);
			weight += 3;
		}
		else if (result > 99)
		{
			/* already have 3 digits, add 4 more */
			result = (result * 10000) + nxt1;
			weight += 2;
		}
		else if (result > 9)
		{
			NumericDigit nxt2 = (ndigits > 2) ? var->digits[2] : 0;

			/* already have 2 digits, add 5 more */
			result = (result * 100000) + (nxt1 * 10) + (nxt2 / 1000);
			weight += 1;
		}
		else
		{
			NumericDigit nxt2 = (ndigits > 2) ? var->digits[2] : 0;

			/* already have 1 digit, add 6 more */
			result = (result * 1000000) + (nxt1 * 100) + (nxt2 / 100);
		}

		/* place the weight above the 24 bits of digits */
		result = result | (weight << 24);
	}

	/* the abbrev is negated relative to the original */
	if (var->sign == NUMERIC_POS)
		result = -result;

	if (nss->estimating)
	{
		uint32		tmp = (uint32) result;

		addHyperLogLog(&nss->abbr_card, DatumGetUInt32(hash_uint32(tmp)));
	}

	return NumericAbbrevGetDatum(result);
}
2024
2025 #endif /* NUMERIC_ABBREV_BITS == 32 */
2026
2027 /*
2028 * Ordinary (non-sortsupport) comparisons follow.
2029 */
2030
2031 Datum
numeric_cmp(PG_FUNCTION_ARGS)2032 numeric_cmp(PG_FUNCTION_ARGS)
2033 {
2034 Numeric num1 = PG_GETARG_NUMERIC(0);
2035 Numeric num2 = PG_GETARG_NUMERIC(1);
2036 int result;
2037
2038 result = cmp_numerics(num1, num2);
2039
2040 PG_FREE_IF_COPY(num1, 0);
2041 PG_FREE_IF_COPY(num2, 1);
2042
2043 PG_RETURN_INT32(result);
2044 }
2045
2046
2047 Datum
numeric_eq(PG_FUNCTION_ARGS)2048 numeric_eq(PG_FUNCTION_ARGS)
2049 {
2050 Numeric num1 = PG_GETARG_NUMERIC(0);
2051 Numeric num2 = PG_GETARG_NUMERIC(1);
2052 bool result;
2053
2054 result = cmp_numerics(num1, num2) == 0;
2055
2056 PG_FREE_IF_COPY(num1, 0);
2057 PG_FREE_IF_COPY(num2, 1);
2058
2059 PG_RETURN_BOOL(result);
2060 }
2061
2062 Datum
numeric_ne(PG_FUNCTION_ARGS)2063 numeric_ne(PG_FUNCTION_ARGS)
2064 {
2065 Numeric num1 = PG_GETARG_NUMERIC(0);
2066 Numeric num2 = PG_GETARG_NUMERIC(1);
2067 bool result;
2068
2069 result = cmp_numerics(num1, num2) != 0;
2070
2071 PG_FREE_IF_COPY(num1, 0);
2072 PG_FREE_IF_COPY(num2, 1);
2073
2074 PG_RETURN_BOOL(result);
2075 }
2076
2077 Datum
numeric_gt(PG_FUNCTION_ARGS)2078 numeric_gt(PG_FUNCTION_ARGS)
2079 {
2080 Numeric num1 = PG_GETARG_NUMERIC(0);
2081 Numeric num2 = PG_GETARG_NUMERIC(1);
2082 bool result;
2083
2084 result = cmp_numerics(num1, num2) > 0;
2085
2086 PG_FREE_IF_COPY(num1, 0);
2087 PG_FREE_IF_COPY(num2, 1);
2088
2089 PG_RETURN_BOOL(result);
2090 }
2091
2092 Datum
numeric_ge(PG_FUNCTION_ARGS)2093 numeric_ge(PG_FUNCTION_ARGS)
2094 {
2095 Numeric num1 = PG_GETARG_NUMERIC(0);
2096 Numeric num2 = PG_GETARG_NUMERIC(1);
2097 bool result;
2098
2099 result = cmp_numerics(num1, num2) >= 0;
2100
2101 PG_FREE_IF_COPY(num1, 0);
2102 PG_FREE_IF_COPY(num2, 1);
2103
2104 PG_RETURN_BOOL(result);
2105 }
2106
2107 Datum
numeric_lt(PG_FUNCTION_ARGS)2108 numeric_lt(PG_FUNCTION_ARGS)
2109 {
2110 Numeric num1 = PG_GETARG_NUMERIC(0);
2111 Numeric num2 = PG_GETARG_NUMERIC(1);
2112 bool result;
2113
2114 result = cmp_numerics(num1, num2) < 0;
2115
2116 PG_FREE_IF_COPY(num1, 0);
2117 PG_FREE_IF_COPY(num2, 1);
2118
2119 PG_RETURN_BOOL(result);
2120 }
2121
2122 Datum
numeric_le(PG_FUNCTION_ARGS)2123 numeric_le(PG_FUNCTION_ARGS)
2124 {
2125 Numeric num1 = PG_GETARG_NUMERIC(0);
2126 Numeric num2 = PG_GETARG_NUMERIC(1);
2127 bool result;
2128
2129 result = cmp_numerics(num1, num2) <= 0;
2130
2131 PG_FREE_IF_COPY(num1, 0);
2132 PG_FREE_IF_COPY(num2, 1);
2133
2134 PG_RETURN_BOOL(result);
2135 }
2136
2137 static int
cmp_numerics(Numeric num1,Numeric num2)2138 cmp_numerics(Numeric num1, Numeric num2)
2139 {
2140 int result;
2141
2142 /*
2143 * We consider all NANs to be equal and larger than any non-NAN. This is
2144 * somewhat arbitrary; the important thing is to have a consistent sort
2145 * order.
2146 */
2147 if (NUMERIC_IS_NAN(num1))
2148 {
2149 if (NUMERIC_IS_NAN(num2))
2150 result = 0; /* NAN = NAN */
2151 else
2152 result = 1; /* NAN > non-NAN */
2153 }
2154 else if (NUMERIC_IS_NAN(num2))
2155 {
2156 result = -1; /* non-NAN < NAN */
2157 }
2158 else
2159 {
2160 result = cmp_var_common(NUMERIC_DIGITS(num1), NUMERIC_NDIGITS(num1),
2161 NUMERIC_WEIGHT(num1), NUMERIC_SIGN(num1),
2162 NUMERIC_DIGITS(num2), NUMERIC_NDIGITS(num2),
2163 NUMERIC_WEIGHT(num2), NUMERIC_SIGN(num2));
2164 }
2165
2166 return result;
2167 }
2168
2169 /*
2170 * in_range support function for numeric.
2171 */
2172 Datum
in_range_numeric_numeric(PG_FUNCTION_ARGS)2173 in_range_numeric_numeric(PG_FUNCTION_ARGS)
2174 {
2175 Numeric val = PG_GETARG_NUMERIC(0);
2176 Numeric base = PG_GETARG_NUMERIC(1);
2177 Numeric offset = PG_GETARG_NUMERIC(2);
2178 bool sub = PG_GETARG_BOOL(3);
2179 bool less = PG_GETARG_BOOL(4);
2180 bool result;
2181
2182 /*
2183 * Reject negative or NaN offset. Negative is per spec, and NaN is
2184 * because appropriate semantics for that seem non-obvious.
2185 */
2186 if (NUMERIC_IS_NAN(offset) || NUMERIC_SIGN(offset) == NUMERIC_NEG)
2187 ereport(ERROR,
2188 (errcode(ERRCODE_INVALID_PRECEDING_OR_FOLLOWING_SIZE),
2189 errmsg("invalid preceding or following size in window function")));
2190
2191 /*
2192 * Deal with cases where val and/or base is NaN, following the rule that
2193 * NaN sorts after non-NaN (cf cmp_numerics). The offset cannot affect
2194 * the conclusion.
2195 */
2196 if (NUMERIC_IS_NAN(val))
2197 {
2198 if (NUMERIC_IS_NAN(base))
2199 result = true; /* NAN = NAN */
2200 else
2201 result = !less; /* NAN > non-NAN */
2202 }
2203 else if (NUMERIC_IS_NAN(base))
2204 {
2205 result = less; /* non-NAN < NAN */
2206 }
2207 else
2208 {
2209 /*
2210 * Otherwise go ahead and compute base +/- offset. While it's
2211 * possible for this to overflow the numeric format, it's unlikely
2212 * enough that we don't take measures to prevent it.
2213 */
2214 NumericVar valv;
2215 NumericVar basev;
2216 NumericVar offsetv;
2217 NumericVar sum;
2218
2219 init_var_from_num(val, &valv);
2220 init_var_from_num(base, &basev);
2221 init_var_from_num(offset, &offsetv);
2222 init_var(&sum);
2223
2224 if (sub)
2225 sub_var(&basev, &offsetv, &sum);
2226 else
2227 add_var(&basev, &offsetv, &sum);
2228
2229 if (less)
2230 result = (cmp_var(&valv, &sum) <= 0);
2231 else
2232 result = (cmp_var(&valv, &sum) >= 0);
2233
2234 free_var(&sum);
2235 }
2236
2237 PG_FREE_IF_COPY(val, 0);
2238 PG_FREE_IF_COPY(base, 1);
2239 PG_FREE_IF_COPY(offset, 2);
2240
2241 PG_RETURN_BOOL(result);
2242 }
2243
2244 Datum
hash_numeric(PG_FUNCTION_ARGS)2245 hash_numeric(PG_FUNCTION_ARGS)
2246 {
2247 Numeric key = PG_GETARG_NUMERIC(0);
2248 Datum digit_hash;
2249 Datum result;
2250 int weight;
2251 int start_offset;
2252 int end_offset;
2253 int i;
2254 int hash_len;
2255 NumericDigit *digits;
2256
2257 /* If it's NaN, don't try to hash the rest of the fields */
2258 if (NUMERIC_IS_NAN(key))
2259 PG_RETURN_UINT32(0);
2260
2261 weight = NUMERIC_WEIGHT(key);
2262 start_offset = 0;
2263 end_offset = 0;
2264
2265 /*
2266 * Omit any leading or trailing zeros from the input to the hash. The
2267 * numeric implementation *should* guarantee that leading and trailing
2268 * zeros are suppressed, but we're paranoid. Note that we measure the
2269 * starting and ending offsets in units of NumericDigits, not bytes.
2270 */
2271 digits = NUMERIC_DIGITS(key);
2272 for (i = 0; i < NUMERIC_NDIGITS(key); i++)
2273 {
2274 if (digits[i] != (NumericDigit) 0)
2275 break;
2276
2277 start_offset++;
2278
2279 /*
2280 * The weight is effectively the # of digits before the decimal point,
2281 * so decrement it for each leading zero we skip.
2282 */
2283 weight--;
2284 }
2285
2286 /*
2287 * If there are no non-zero digits, then the value of the number is zero,
2288 * regardless of any other fields.
2289 */
2290 if (NUMERIC_NDIGITS(key) == start_offset)
2291 PG_RETURN_UINT32(-1);
2292
2293 for (i = NUMERIC_NDIGITS(key) - 1; i >= 0; i--)
2294 {
2295 if (digits[i] != (NumericDigit) 0)
2296 break;
2297
2298 end_offset++;
2299 }
2300
2301 /* If we get here, there should be at least one non-zero digit */
2302 Assert(start_offset + end_offset < NUMERIC_NDIGITS(key));
2303
2304 /*
2305 * Note that we don't hash on the Numeric's scale, since two numerics can
2306 * compare equal but have different scales. We also don't hash on the
2307 * sign, although we could: since a sign difference implies inequality,
2308 * this shouldn't affect correctness.
2309 */
2310 hash_len = NUMERIC_NDIGITS(key) - start_offset - end_offset;
2311 digit_hash = hash_any((unsigned char *) (NUMERIC_DIGITS(key) + start_offset),
2312 hash_len * sizeof(NumericDigit));
2313
2314 /* Mix in the weight, via XOR */
2315 result = digit_hash ^ weight;
2316
2317 PG_RETURN_DATUM(result);
2318 }
2319
2320 /*
2321 * Returns 64-bit value by hashing a value to a 64-bit value, with a seed.
2322 * Otherwise, similar to hash_numeric.
2323 */
2324 Datum
hash_numeric_extended(PG_FUNCTION_ARGS)2325 hash_numeric_extended(PG_FUNCTION_ARGS)
2326 {
2327 Numeric key = PG_GETARG_NUMERIC(0);
2328 uint64 seed = PG_GETARG_INT64(1);
2329 Datum digit_hash;
2330 Datum result;
2331 int weight;
2332 int start_offset;
2333 int end_offset;
2334 int i;
2335 int hash_len;
2336 NumericDigit *digits;
2337
2338 if (NUMERIC_IS_NAN(key))
2339 PG_RETURN_UINT64(seed);
2340
2341 weight = NUMERIC_WEIGHT(key);
2342 start_offset = 0;
2343 end_offset = 0;
2344
2345 digits = NUMERIC_DIGITS(key);
2346 for (i = 0; i < NUMERIC_NDIGITS(key); i++)
2347 {
2348 if (digits[i] != (NumericDigit) 0)
2349 break;
2350
2351 start_offset++;
2352
2353 weight--;
2354 }
2355
2356 if (NUMERIC_NDIGITS(key) == start_offset)
2357 PG_RETURN_UINT64(seed - 1);
2358
2359 for (i = NUMERIC_NDIGITS(key) - 1; i >= 0; i--)
2360 {
2361 if (digits[i] != (NumericDigit) 0)
2362 break;
2363
2364 end_offset++;
2365 }
2366
2367 Assert(start_offset + end_offset < NUMERIC_NDIGITS(key));
2368
2369 hash_len = NUMERIC_NDIGITS(key) - start_offset - end_offset;
2370 digit_hash = hash_any_extended((unsigned char *) (NUMERIC_DIGITS(key)
2371 + start_offset),
2372 hash_len * sizeof(NumericDigit),
2373 seed);
2374
2375 result = UInt64GetDatum(DatumGetUInt64(digit_hash) ^ weight);
2376
2377 PG_RETURN_DATUM(result);
2378 }
2379
2380
2381 /* ----------------------------------------------------------------------
2382 *
2383 * Basic arithmetic functions
2384 *
2385 * ----------------------------------------------------------------------
2386 */
2387
2388
2389 /*
2390 * numeric_add() -
2391 *
2392 * Add two numerics
2393 */
2394 Datum
numeric_add(PG_FUNCTION_ARGS)2395 numeric_add(PG_FUNCTION_ARGS)
2396 {
2397 Numeric num1 = PG_GETARG_NUMERIC(0);
2398 Numeric num2 = PG_GETARG_NUMERIC(1);
2399 Numeric res;
2400
2401 res = numeric_add_opt_error(num1, num2, NULL);
2402
2403 PG_RETURN_NUMERIC(res);
2404 }
2405
2406 /*
2407 * numeric_add_opt_error() -
2408 *
2409 * Internal version of numeric_add(). If "*have_error" flag is provided,
2410 * on error it's set to true, NULL returned. This is helpful when caller
2411 * need to handle errors by itself.
2412 */
2413 Numeric
numeric_add_opt_error(Numeric num1,Numeric num2,bool * have_error)2414 numeric_add_opt_error(Numeric num1, Numeric num2, bool *have_error)
2415 {
2416 NumericVar arg1;
2417 NumericVar arg2;
2418 NumericVar result;
2419 Numeric res;
2420
2421 /*
2422 * Handle NaN
2423 */
2424 if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
2425 return make_result(&const_nan);
2426
2427 /*
2428 * Unpack the values, let add_var() compute the result and return it.
2429 */
2430 init_var_from_num(num1, &arg1);
2431 init_var_from_num(num2, &arg2);
2432
2433 init_var(&result);
2434 add_var(&arg1, &arg2, &result);
2435
2436 res = make_result_opt_error(&result, have_error);
2437
2438 free_var(&result);
2439
2440 return res;
2441 }
2442
2443
2444 /*
2445 * numeric_sub() -
2446 *
2447 * Subtract one numeric from another
2448 */
2449 Datum
numeric_sub(PG_FUNCTION_ARGS)2450 numeric_sub(PG_FUNCTION_ARGS)
2451 {
2452 Numeric num1 = PG_GETARG_NUMERIC(0);
2453 Numeric num2 = PG_GETARG_NUMERIC(1);
2454 Numeric res;
2455
2456 res = numeric_sub_opt_error(num1, num2, NULL);
2457
2458 PG_RETURN_NUMERIC(res);
2459 }
2460
2461
2462 /*
2463 * numeric_sub_opt_error() -
2464 *
2465 * Internal version of numeric_sub(). If "*have_error" flag is provided,
2466 * on error it's set to true, NULL returned. This is helpful when caller
2467 * need to handle errors by itself.
2468 */
2469 Numeric
numeric_sub_opt_error(Numeric num1,Numeric num2,bool * have_error)2470 numeric_sub_opt_error(Numeric num1, Numeric num2, bool *have_error)
2471 {
2472 NumericVar arg1;
2473 NumericVar arg2;
2474 NumericVar result;
2475 Numeric res;
2476
2477 /*
2478 * Handle NaN
2479 */
2480 if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
2481 return make_result(&const_nan);
2482
2483 /*
2484 * Unpack the values, let sub_var() compute the result and return it.
2485 */
2486 init_var_from_num(num1, &arg1);
2487 init_var_from_num(num2, &arg2);
2488
2489 init_var(&result);
2490 sub_var(&arg1, &arg2, &result);
2491
2492 res = make_result_opt_error(&result, have_error);
2493
2494 free_var(&result);
2495
2496 return res;
2497 }
2498
2499
2500 /*
2501 * numeric_mul() -
2502 *
2503 * Calculate the product of two numerics
2504 */
2505 Datum
numeric_mul(PG_FUNCTION_ARGS)2506 numeric_mul(PG_FUNCTION_ARGS)
2507 {
2508 Numeric num1 = PG_GETARG_NUMERIC(0);
2509 Numeric num2 = PG_GETARG_NUMERIC(1);
2510 Numeric res;
2511
2512 res = numeric_mul_opt_error(num1, num2, NULL);
2513
2514 PG_RETURN_NUMERIC(res);
2515 }
2516
2517
2518 /*
2519 * numeric_mul_opt_error() -
2520 *
2521 * Internal version of numeric_mul(). If "*have_error" flag is provided,
2522 * on error it's set to true, NULL returned. This is helpful when caller
2523 * need to handle errors by itself.
2524 */
2525 Numeric
numeric_mul_opt_error(Numeric num1,Numeric num2,bool * have_error)2526 numeric_mul_opt_error(Numeric num1, Numeric num2, bool *have_error)
2527 {
2528 NumericVar arg1;
2529 NumericVar arg2;
2530 NumericVar result;
2531 Numeric res;
2532
2533 /*
2534 * Handle NaN
2535 */
2536 if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
2537 return make_result(&const_nan);
2538
2539 /*
2540 * Unpack the values, let mul_var() compute the result and return it.
2541 * Unlike add_var() and sub_var(), mul_var() will round its result. In the
2542 * case of numeric_mul(), which is invoked for the * operator on numerics,
2543 * we request exact representation for the product (rscale = sum(dscale of
2544 * arg1, dscale of arg2)). If the exact result has more digits after the
2545 * decimal point than can be stored in a numeric, we round it. Rounding
2546 * after computing the exact result ensures that the final result is
2547 * correctly rounded (rounding in mul_var() using a truncated product
2548 * would not guarantee this).
2549 */
2550 init_var_from_num(num1, &arg1);
2551 init_var_from_num(num2, &arg2);
2552
2553 init_var(&result);
2554 mul_var(&arg1, &arg2, &result, arg1.dscale + arg2.dscale);
2555
2556 if (result.dscale > NUMERIC_DSCALE_MAX)
2557 round_var(&result, NUMERIC_DSCALE_MAX);
2558
2559 res = make_result_opt_error(&result, have_error);
2560
2561 free_var(&result);
2562
2563 return res;
2564 }
2565
2566
2567 /*
2568 * numeric_div() -
2569 *
2570 * Divide one numeric into another
2571 */
2572 Datum
numeric_div(PG_FUNCTION_ARGS)2573 numeric_div(PG_FUNCTION_ARGS)
2574 {
2575 Numeric num1 = PG_GETARG_NUMERIC(0);
2576 Numeric num2 = PG_GETARG_NUMERIC(1);
2577 Numeric res;
2578
2579 res = numeric_div_opt_error(num1, num2, NULL);
2580
2581 PG_RETURN_NUMERIC(res);
2582 }
2583
2584
2585 /*
2586 * numeric_div_opt_error() -
2587 *
2588 * Internal version of numeric_div(). If "*have_error" flag is provided,
2589 * on error it's set to true, NULL returned. This is helpful when caller
2590 * need to handle errors by itself.
2591 */
2592 Numeric
numeric_div_opt_error(Numeric num1,Numeric num2,bool * have_error)2593 numeric_div_opt_error(Numeric num1, Numeric num2, bool *have_error)
2594 {
2595 NumericVar arg1;
2596 NumericVar arg2;
2597 NumericVar result;
2598 Numeric res;
2599 int rscale;
2600
2601 if (have_error)
2602 *have_error = false;
2603
2604 /*
2605 * Handle NaN
2606 */
2607 if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
2608 return make_result(&const_nan);
2609
2610 /*
2611 * Unpack the arguments
2612 */
2613 init_var_from_num(num1, &arg1);
2614 init_var_from_num(num2, &arg2);
2615
2616 init_var(&result);
2617
2618 /*
2619 * Select scale for division result
2620 */
2621 rscale = select_div_scale(&arg1, &arg2);
2622
2623 /*
2624 * If "have_error" is provided, check for division by zero here
2625 */
2626 if (have_error && (arg2.ndigits == 0 || arg2.digits[0] == 0))
2627 {
2628 *have_error = true;
2629 return NULL;
2630 }
2631
2632 /*
2633 * Do the divide and return the result
2634 */
2635 div_var(&arg1, &arg2, &result, rscale, true);
2636
2637 res = make_result_opt_error(&result, have_error);
2638
2639 free_var(&result);
2640
2641 return res;
2642 }
2643
2644
2645 /*
2646 * numeric_div_trunc() -
2647 *
2648 * Divide one numeric into another, truncating the result to an integer
2649 */
2650 Datum
numeric_div_trunc(PG_FUNCTION_ARGS)2651 numeric_div_trunc(PG_FUNCTION_ARGS)
2652 {
2653 Numeric num1 = PG_GETARG_NUMERIC(0);
2654 Numeric num2 = PG_GETARG_NUMERIC(1);
2655 NumericVar arg1;
2656 NumericVar arg2;
2657 NumericVar result;
2658 Numeric res;
2659
2660 /*
2661 * Handle NaN
2662 */
2663 if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
2664 PG_RETURN_NUMERIC(make_result(&const_nan));
2665
2666 /*
2667 * Unpack the arguments
2668 */
2669 init_var_from_num(num1, &arg1);
2670 init_var_from_num(num2, &arg2);
2671
2672 init_var(&result);
2673
2674 /*
2675 * Do the divide and return the result
2676 */
2677 div_var(&arg1, &arg2, &result, 0, false);
2678
2679 res = make_result(&result);
2680
2681 free_var(&result);
2682
2683 PG_RETURN_NUMERIC(res);
2684 }
2685
2686
2687 /*
2688 * numeric_mod() -
2689 *
2690 * Calculate the modulo of two numerics
2691 */
2692 Datum
numeric_mod(PG_FUNCTION_ARGS)2693 numeric_mod(PG_FUNCTION_ARGS)
2694 {
2695 Numeric num1 = PG_GETARG_NUMERIC(0);
2696 Numeric num2 = PG_GETARG_NUMERIC(1);
2697 Numeric res;
2698
2699 res = numeric_mod_opt_error(num1, num2, NULL);
2700
2701 PG_RETURN_NUMERIC(res);
2702 }
2703
2704
2705 /*
2706 * numeric_mod_opt_error() -
2707 *
2708 * Internal version of numeric_mod(). If "*have_error" flag is provided,
2709 * on error it's set to true, NULL returned. This is helpful when caller
2710 * need to handle errors by itself.
2711 */
2712 Numeric
numeric_mod_opt_error(Numeric num1,Numeric num2,bool * have_error)2713 numeric_mod_opt_error(Numeric num1, Numeric num2, bool *have_error)
2714 {
2715 Numeric res;
2716 NumericVar arg1;
2717 NumericVar arg2;
2718 NumericVar result;
2719
2720 if (have_error)
2721 *have_error = false;
2722
2723 if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
2724 return make_result(&const_nan);
2725
2726 init_var_from_num(num1, &arg1);
2727 init_var_from_num(num2, &arg2);
2728
2729 init_var(&result);
2730
2731 /*
2732 * If "have_error" is provided, check for division by zero here
2733 */
2734 if (have_error && (arg2.ndigits == 0 || arg2.digits[0] == 0))
2735 {
2736 *have_error = true;
2737 return NULL;
2738 }
2739
2740 mod_var(&arg1, &arg2, &result);
2741
2742 res = make_result_opt_error(&result, NULL);
2743
2744 free_var(&result);
2745
2746 return res;
2747 }
2748
2749
2750 /*
2751 * numeric_inc() -
2752 *
2753 * Increment a number by one
2754 */
2755 Datum
numeric_inc(PG_FUNCTION_ARGS)2756 numeric_inc(PG_FUNCTION_ARGS)
2757 {
2758 Numeric num = PG_GETARG_NUMERIC(0);
2759 NumericVar arg;
2760 Numeric res;
2761
2762 /*
2763 * Handle NaN
2764 */
2765 if (NUMERIC_IS_NAN(num))
2766 PG_RETURN_NUMERIC(make_result(&const_nan));
2767
2768 /*
2769 * Compute the result and return it
2770 */
2771 init_var_from_num(num, &arg);
2772
2773 add_var(&arg, &const_one, &arg);
2774
2775 res = make_result(&arg);
2776
2777 free_var(&arg);
2778
2779 PG_RETURN_NUMERIC(res);
2780 }
2781
2782
2783 /*
2784 * numeric_smaller() -
2785 *
2786 * Return the smaller of two numbers
2787 */
2788 Datum
numeric_smaller(PG_FUNCTION_ARGS)2789 numeric_smaller(PG_FUNCTION_ARGS)
2790 {
2791 Numeric num1 = PG_GETARG_NUMERIC(0);
2792 Numeric num2 = PG_GETARG_NUMERIC(1);
2793
2794 /*
2795 * Use cmp_numerics so that this will agree with the comparison operators,
2796 * particularly as regards comparisons involving NaN.
2797 */
2798 if (cmp_numerics(num1, num2) < 0)
2799 PG_RETURN_NUMERIC(num1);
2800 else
2801 PG_RETURN_NUMERIC(num2);
2802 }
2803
2804
2805 /*
2806 * numeric_larger() -
2807 *
2808 * Return the larger of two numbers
2809 */
2810 Datum
numeric_larger(PG_FUNCTION_ARGS)2811 numeric_larger(PG_FUNCTION_ARGS)
2812 {
2813 Numeric num1 = PG_GETARG_NUMERIC(0);
2814 Numeric num2 = PG_GETARG_NUMERIC(1);
2815
2816 /*
2817 * Use cmp_numerics so that this will agree with the comparison operators,
2818 * particularly as regards comparisons involving NaN.
2819 */
2820 if (cmp_numerics(num1, num2) > 0)
2821 PG_RETURN_NUMERIC(num1);
2822 else
2823 PG_RETURN_NUMERIC(num2);
2824 }
2825
2826
2827 /* ----------------------------------------------------------------------
2828 *
2829 * Advanced math functions
2830 *
2831 * ----------------------------------------------------------------------
2832 */
2833
2834 /*
2835 * numeric_gcd() -
2836 *
2837 * Calculate the greatest common divisor of two numerics
2838 */
2839 Datum
numeric_gcd(PG_FUNCTION_ARGS)2840 numeric_gcd(PG_FUNCTION_ARGS)
2841 {
2842 Numeric num1 = PG_GETARG_NUMERIC(0);
2843 Numeric num2 = PG_GETARG_NUMERIC(1);
2844 NumericVar arg1;
2845 NumericVar arg2;
2846 NumericVar result;
2847 Numeric res;
2848
2849 /*
2850 * Handle NaN
2851 */
2852 if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
2853 PG_RETURN_NUMERIC(make_result(&const_nan));
2854
2855 /*
2856 * Unpack the arguments
2857 */
2858 init_var_from_num(num1, &arg1);
2859 init_var_from_num(num2, &arg2);
2860
2861 init_var(&result);
2862
2863 /*
2864 * Find the GCD and return the result
2865 */
2866 gcd_var(&arg1, &arg2, &result);
2867
2868 res = make_result(&result);
2869
2870 free_var(&result);
2871
2872 PG_RETURN_NUMERIC(res);
2873 }
2874
2875
2876 /*
2877 * numeric_lcm() -
2878 *
2879 * Calculate the least common multiple of two numerics
2880 */
2881 Datum
numeric_lcm(PG_FUNCTION_ARGS)2882 numeric_lcm(PG_FUNCTION_ARGS)
2883 {
2884 Numeric num1 = PG_GETARG_NUMERIC(0);
2885 Numeric num2 = PG_GETARG_NUMERIC(1);
2886 NumericVar arg1;
2887 NumericVar arg2;
2888 NumericVar result;
2889 Numeric res;
2890
2891 /*
2892 * Handle NaN
2893 */
2894 if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
2895 PG_RETURN_NUMERIC(make_result(&const_nan));
2896
2897 /*
2898 * Unpack the arguments
2899 */
2900 init_var_from_num(num1, &arg1);
2901 init_var_from_num(num2, &arg2);
2902
2903 init_var(&result);
2904
2905 /*
2906 * Compute the result using lcm(x, y) = abs(x / gcd(x, y) * y), returning
2907 * zero if either input is zero.
2908 *
2909 * Note that the division is guaranteed to be exact, returning an integer
2910 * result, so the LCM is an integral multiple of both x and y. A display
2911 * scale of Min(x.dscale, y.dscale) would be sufficient to represent it,
2912 * but as with other numeric functions, we choose to return a result whose
2913 * display scale is no smaller than either input.
2914 */
2915 if (arg1.ndigits == 0 || arg2.ndigits == 0)
2916 set_var_from_var(&const_zero, &result);
2917 else
2918 {
2919 gcd_var(&arg1, &arg2, &result);
2920 div_var(&arg1, &result, &result, 0, false);
2921 mul_var(&arg2, &result, &result, arg2.dscale);
2922 result.sign = NUMERIC_POS;
2923 }
2924
2925 result.dscale = Max(arg1.dscale, arg2.dscale);
2926
2927 res = make_result(&result);
2928
2929 free_var(&result);
2930
2931 PG_RETURN_NUMERIC(res);
2932 }
2933
2934
2935 /*
2936 * numeric_fac()
2937 *
2938 * Compute factorial
2939 */
2940 Datum
numeric_fac(PG_FUNCTION_ARGS)2941 numeric_fac(PG_FUNCTION_ARGS)
2942 {
2943 int64 num = PG_GETARG_INT64(0);
2944 Numeric res;
2945 NumericVar fact;
2946 NumericVar result;
2947
2948 if (num <= 1)
2949 {
2950 res = make_result(&const_one);
2951 PG_RETURN_NUMERIC(res);
2952 }
2953 /* Fail immediately if the result would overflow */
2954 if (num > 32177)
2955 ereport(ERROR,
2956 (errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
2957 errmsg("value overflows numeric format")));
2958
2959 init_var(&fact);
2960 init_var(&result);
2961
2962 int64_to_numericvar(num, &result);
2963
2964 for (num = num - 1; num > 1; num--)
2965 {
2966 /* this loop can take awhile, so allow it to be interrupted */
2967 CHECK_FOR_INTERRUPTS();
2968
2969 int64_to_numericvar(num, &fact);
2970
2971 mul_var(&result, &fact, &result, 0);
2972 }
2973
2974 res = make_result(&result);
2975
2976 free_var(&fact);
2977 free_var(&result);
2978
2979 PG_RETURN_NUMERIC(res);
2980 }
2981
2982
2983 /*
2984 * numeric_sqrt() -
2985 *
2986 * Compute the square root of a numeric.
2987 */
2988 Datum
numeric_sqrt(PG_FUNCTION_ARGS)2989 numeric_sqrt(PG_FUNCTION_ARGS)
2990 {
2991 Numeric num = PG_GETARG_NUMERIC(0);
2992 Numeric res;
2993 NumericVar arg;
2994 NumericVar result;
2995 int sweight;
2996 int rscale;
2997
2998 /*
2999 * Handle NaN
3000 */
3001 if (NUMERIC_IS_NAN(num))
3002 PG_RETURN_NUMERIC(make_result(&const_nan));
3003
3004 /*
3005 * Unpack the argument and determine the result scale. We choose a scale
3006 * to give at least NUMERIC_MIN_SIG_DIGITS significant digits; but in any
3007 * case not less than the input's dscale.
3008 */
3009 init_var_from_num(num, &arg);
3010
3011 init_var(&result);
3012
3013 /* Assume the input was normalized, so arg.weight is accurate */
3014 sweight = (arg.weight + 1) * DEC_DIGITS / 2 - 1;
3015
3016 rscale = NUMERIC_MIN_SIG_DIGITS - sweight;
3017 rscale = Max(rscale, arg.dscale);
3018 rscale = Max(rscale, NUMERIC_MIN_DISPLAY_SCALE);
3019 rscale = Min(rscale, NUMERIC_MAX_DISPLAY_SCALE);
3020
3021 /*
3022 * Let sqrt_var() do the calculation and return the result.
3023 */
3024 sqrt_var(&arg, &result, rscale);
3025
3026 res = make_result(&result);
3027
3028 free_var(&result);
3029
3030 PG_RETURN_NUMERIC(res);
3031 }
3032
3033
3034 /*
3035 * numeric_exp() -
3036 *
3037 * Raise e to the power of x
3038 */
3039 Datum
numeric_exp(PG_FUNCTION_ARGS)3040 numeric_exp(PG_FUNCTION_ARGS)
3041 {
3042 Numeric num = PG_GETARG_NUMERIC(0);
3043 Numeric res;
3044 NumericVar arg;
3045 NumericVar result;
3046 int rscale;
3047 double val;
3048
3049 /*
3050 * Handle NaN
3051 */
3052 if (NUMERIC_IS_NAN(num))
3053 PG_RETURN_NUMERIC(make_result(&const_nan));
3054
3055 /*
3056 * Unpack the argument and determine the result scale. We choose a scale
3057 * to give at least NUMERIC_MIN_SIG_DIGITS significant digits; but in any
3058 * case not less than the input's dscale.
3059 */
3060 init_var_from_num(num, &arg);
3061
3062 init_var(&result);
3063
3064 /* convert input to float8, ignoring overflow */
3065 val = numericvar_to_double_no_overflow(&arg);
3066
3067 /*
3068 * log10(result) = num * log10(e), so this is approximately the decimal
3069 * weight of the result:
3070 */
3071 val *= 0.434294481903252;
3072
3073 /* limit to something that won't cause integer overflow */
3074 val = Max(val, -NUMERIC_MAX_RESULT_SCALE);
3075 val = Min(val, NUMERIC_MAX_RESULT_SCALE);
3076
3077 rscale = NUMERIC_MIN_SIG_DIGITS - (int) val;
3078 rscale = Max(rscale, arg.dscale);
3079 rscale = Max(rscale, NUMERIC_MIN_DISPLAY_SCALE);
3080 rscale = Min(rscale, NUMERIC_MAX_DISPLAY_SCALE);
3081
3082 /*
3083 * Let exp_var() do the calculation and return the result.
3084 */
3085 exp_var(&arg, &result, rscale);
3086
3087 res = make_result(&result);
3088
3089 free_var(&result);
3090
3091 PG_RETURN_NUMERIC(res);
3092 }
3093
3094
3095 /*
3096 * numeric_ln() -
3097 *
3098 * Compute the natural logarithm of x
3099 */
3100 Datum
numeric_ln(PG_FUNCTION_ARGS)3101 numeric_ln(PG_FUNCTION_ARGS)
3102 {
3103 Numeric num = PG_GETARG_NUMERIC(0);
3104 Numeric res;
3105 NumericVar arg;
3106 NumericVar result;
3107 int ln_dweight;
3108 int rscale;
3109
3110 /*
3111 * Handle NaN
3112 */
3113 if (NUMERIC_IS_NAN(num))
3114 PG_RETURN_NUMERIC(make_result(&const_nan));
3115
3116 init_var_from_num(num, &arg);
3117 init_var(&result);
3118
3119 /* Estimated dweight of logarithm */
3120 ln_dweight = estimate_ln_dweight(&arg);
3121
3122 rscale = NUMERIC_MIN_SIG_DIGITS - ln_dweight;
3123 rscale = Max(rscale, arg.dscale);
3124 rscale = Max(rscale, NUMERIC_MIN_DISPLAY_SCALE);
3125 rscale = Min(rscale, NUMERIC_MAX_DISPLAY_SCALE);
3126
3127 ln_var(&arg, &result, rscale);
3128
3129 res = make_result(&result);
3130
3131 free_var(&result);
3132
3133 PG_RETURN_NUMERIC(res);
3134 }
3135
3136
3137 /*
3138 * numeric_log() -
3139 *
3140 * Compute the logarithm of x in a given base
3141 */
3142 Datum
numeric_log(PG_FUNCTION_ARGS)3143 numeric_log(PG_FUNCTION_ARGS)
3144 {
3145 Numeric num1 = PG_GETARG_NUMERIC(0);
3146 Numeric num2 = PG_GETARG_NUMERIC(1);
3147 Numeric res;
3148 NumericVar arg1;
3149 NumericVar arg2;
3150 NumericVar result;
3151
3152 /*
3153 * Handle NaN
3154 */
3155 if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
3156 PG_RETURN_NUMERIC(make_result(&const_nan));
3157
3158 /*
3159 * Initialize things
3160 */
3161 init_var_from_num(num1, &arg1);
3162 init_var_from_num(num2, &arg2);
3163 init_var(&result);
3164
3165 /*
3166 * Call log_var() to compute and return the result; note it handles scale
3167 * selection itself.
3168 */
3169 log_var(&arg1, &arg2, &result);
3170
3171 res = make_result(&result);
3172
3173 free_var(&result);
3174
3175 PG_RETURN_NUMERIC(res);
3176 }
3177
3178
3179 /*
3180 * numeric_power() -
3181 *
3182 * Raise b to the power of x
3183 */
3184 Datum
numeric_power(PG_FUNCTION_ARGS)3185 numeric_power(PG_FUNCTION_ARGS)
3186 {
3187 Numeric num1 = PG_GETARG_NUMERIC(0);
3188 Numeric num2 = PG_GETARG_NUMERIC(1);
3189 Numeric res;
3190 NumericVar arg1;
3191 NumericVar arg2;
3192 NumericVar arg2_trunc;
3193 NumericVar result;
3194
3195 /*
3196 * Handle NaN cases. We follow the POSIX spec for pow(3), which says that
3197 * NaN ^ 0 = 1, and 1 ^ NaN = 1, while all other cases with NaN inputs
3198 * yield NaN (with no error).
3199 */
3200 if (NUMERIC_IS_NAN(num1))
3201 {
3202 if (!NUMERIC_IS_NAN(num2))
3203 {
3204 init_var_from_num(num2, &arg2);
3205 if (cmp_var(&arg2, &const_zero) == 0)
3206 PG_RETURN_NUMERIC(make_result(&const_one));
3207 }
3208 PG_RETURN_NUMERIC(make_result(&const_nan));
3209 }
3210 if (NUMERIC_IS_NAN(num2))
3211 {
3212 init_var_from_num(num1, &arg1);
3213 if (cmp_var(&arg1, &const_one) == 0)
3214 PG_RETURN_NUMERIC(make_result(&const_one));
3215 PG_RETURN_NUMERIC(make_result(&const_nan));
3216 }
3217
3218 /*
3219 * Initialize things
3220 */
3221 init_var(&arg2_trunc);
3222 init_var(&result);
3223 init_var_from_num(num1, &arg1);
3224 init_var_from_num(num2, &arg2);
3225
3226 set_var_from_var(&arg2, &arg2_trunc);
3227 trunc_var(&arg2_trunc, 0);
3228
3229 /*
3230 * The SQL spec requires that we emit a particular SQLSTATE error code for
3231 * certain error conditions. Specifically, we don't return a
3232 * divide-by-zero error code for 0 ^ -1. Raising a negative number to a
3233 * non-integer power must produce the same error code, but that case is
3234 * handled in power_var().
3235 */
3236 if (cmp_var(&arg1, &const_zero) == 0 &&
3237 cmp_var(&arg2, &const_zero) < 0)
3238 ereport(ERROR,
3239 (errcode(ERRCODE_INVALID_ARGUMENT_FOR_POWER_FUNCTION),
3240 errmsg("zero raised to a negative power is undefined")));
3241
3242 /*
3243 * Call power_var() to compute and return the result; note it handles
3244 * scale selection itself.
3245 */
3246 power_var(&arg1, &arg2, &result);
3247
3248 res = make_result(&result);
3249
3250 free_var(&result);
3251 free_var(&arg2_trunc);
3252
3253 PG_RETURN_NUMERIC(res);
3254 }
3255
3256 /*
3257 * numeric_scale() -
3258 *
3259 * Returns the scale, i.e. the count of decimal digits in the fractional part
3260 */
3261 Datum
numeric_scale(PG_FUNCTION_ARGS)3262 numeric_scale(PG_FUNCTION_ARGS)
3263 {
3264 Numeric num = PG_GETARG_NUMERIC(0);
3265
3266 if (NUMERIC_IS_NAN(num))
3267 PG_RETURN_NULL();
3268
3269 PG_RETURN_INT32(NUMERIC_DSCALE(num));
3270 }
3271
3272 /*
3273 * Calculate minimum scale for value.
3274 */
3275 static int
get_min_scale(NumericVar * var)3276 get_min_scale(NumericVar *var)
3277 {
3278 int min_scale;
3279 int last_digit_pos;
3280
3281 /*
3282 * Ordinarily, the input value will be "stripped" so that the last
3283 * NumericDigit is nonzero. But we don't want to get into an infinite
3284 * loop if it isn't, so explicitly find the last nonzero digit.
3285 */
3286 last_digit_pos = var->ndigits - 1;
3287 while (last_digit_pos >= 0 &&
3288 var->digits[last_digit_pos] == 0)
3289 last_digit_pos--;
3290
3291 if (last_digit_pos >= 0)
3292 {
3293 /* compute min_scale assuming that last ndigit has no zeroes */
3294 min_scale = (last_digit_pos - var->weight) * DEC_DIGITS;
3295
3296 /*
3297 * We could get a negative result if there are no digits after the
3298 * decimal point. In this case the min_scale must be zero.
3299 */
3300 if (min_scale > 0)
3301 {
3302 /*
3303 * Reduce min_scale if trailing digit(s) in last NumericDigit are
3304 * zero.
3305 */
3306 NumericDigit last_digit = var->digits[last_digit_pos];
3307
3308 while (last_digit % 10 == 0)
3309 {
3310 min_scale--;
3311 last_digit /= 10;
3312 }
3313 }
3314 else
3315 min_scale = 0;
3316 }
3317 else
3318 min_scale = 0; /* result if input is zero */
3319
3320 return min_scale;
3321 }
3322
3323 /*
3324 * Returns minimum scale required to represent supplied value without loss.
3325 */
3326 Datum
numeric_min_scale(PG_FUNCTION_ARGS)3327 numeric_min_scale(PG_FUNCTION_ARGS)
3328 {
3329 Numeric num = PG_GETARG_NUMERIC(0);
3330 NumericVar arg;
3331 int min_scale;
3332
3333 if (NUMERIC_IS_NAN(num))
3334 PG_RETURN_NULL();
3335
3336 init_var_from_num(num, &arg);
3337 min_scale = get_min_scale(&arg);
3338 free_var(&arg);
3339
3340 PG_RETURN_INT32(min_scale);
3341 }
3342
3343 /*
3344 * Reduce scale of numeric value to represent supplied value without loss.
3345 */
3346 Datum
numeric_trim_scale(PG_FUNCTION_ARGS)3347 numeric_trim_scale(PG_FUNCTION_ARGS)
3348 {
3349 Numeric num = PG_GETARG_NUMERIC(0);
3350 Numeric res;
3351 NumericVar result;
3352
3353 if (NUMERIC_IS_NAN(num))
3354 PG_RETURN_NUMERIC(make_result(&const_nan));
3355
3356 init_var_from_num(num, &result);
3357 result.dscale = get_min_scale(&result);
3358 res = make_result(&result);
3359 free_var(&result);
3360
3361 PG_RETURN_NUMERIC(res);
3362 }
3363
3364
3365 /* ----------------------------------------------------------------------
3366 *
3367 * Type conversion functions
3368 *
3369 * ----------------------------------------------------------------------
3370 */
3371
3372
3373 Datum
int4_numeric(PG_FUNCTION_ARGS)3374 int4_numeric(PG_FUNCTION_ARGS)
3375 {
3376 int32 val = PG_GETARG_INT32(0);
3377 Numeric res;
3378 NumericVar result;
3379
3380 init_var(&result);
3381
3382 int64_to_numericvar((int64) val, &result);
3383
3384 res = make_result(&result);
3385
3386 free_var(&result);
3387
3388 PG_RETURN_NUMERIC(res);
3389 }
3390
3391 int32
numeric_int4_opt_error(Numeric num,bool * have_error)3392 numeric_int4_opt_error(Numeric num, bool *have_error)
3393 {
3394 NumericVar x;
3395 int32 result;
3396
3397 if (have_error)
3398 *have_error = false;
3399
3400 /* XXX would it be better to return NULL? */
3401 if (NUMERIC_IS_NAN(num))
3402 {
3403 if (have_error)
3404 {
3405 *have_error = true;
3406 return 0;
3407 }
3408 else
3409 {
3410 ereport(ERROR,
3411 (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
3412 errmsg("cannot convert NaN to integer")));
3413 }
3414 }
3415
3416 /* Convert to variable format, then convert to int4 */
3417 init_var_from_num(num, &x);
3418
3419 if (!numericvar_to_int32(&x, &result))
3420 {
3421 if (have_error)
3422 {
3423 *have_error = true;
3424 return 0;
3425 }
3426 else
3427 {
3428 ereport(ERROR,
3429 (errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
3430 errmsg("integer out of range")));
3431 }
3432 }
3433
3434 return result;
3435 }
3436
3437 Datum
numeric_int4(PG_FUNCTION_ARGS)3438 numeric_int4(PG_FUNCTION_ARGS)
3439 {
3440 Numeric num = PG_GETARG_NUMERIC(0);
3441
3442 PG_RETURN_INT32(numeric_int4_opt_error(num, NULL));
3443 }
3444
3445 /*
3446 * Given a NumericVar, convert it to an int32. If the NumericVar
3447 * exceeds the range of an int32, false is returned, otherwise true is returned.
3448 * The input NumericVar is *not* free'd.
3449 */
3450 static bool
numericvar_to_int32(const NumericVar * var,int32 * result)3451 numericvar_to_int32(const NumericVar *var, int32 *result)
3452 {
3453 int64 val;
3454
3455 if (!numericvar_to_int64(var, &val))
3456 return false;
3457
3458 if (unlikely(val < PG_INT32_MIN) || unlikely(val > PG_INT32_MAX))
3459 return false;
3460
3461 /* Down-convert to int4 */
3462 *result = (int32) val;
3463
3464 return true;
3465 }
3466
3467 Datum
int8_numeric(PG_FUNCTION_ARGS)3468 int8_numeric(PG_FUNCTION_ARGS)
3469 {
3470 int64 val = PG_GETARG_INT64(0);
3471 Numeric res;
3472 NumericVar result;
3473
3474 init_var(&result);
3475
3476 int64_to_numericvar(val, &result);
3477
3478 res = make_result(&result);
3479
3480 free_var(&result);
3481
3482 PG_RETURN_NUMERIC(res);
3483 }
3484
3485
3486 Datum
numeric_int8(PG_FUNCTION_ARGS)3487 numeric_int8(PG_FUNCTION_ARGS)
3488 {
3489 Numeric num = PG_GETARG_NUMERIC(0);
3490 NumericVar x;
3491 int64 result;
3492
3493 /* XXX would it be better to return NULL? */
3494 if (NUMERIC_IS_NAN(num))
3495 ereport(ERROR,
3496 (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
3497 errmsg("cannot convert NaN to bigint")));
3498
3499 /* Convert to variable format and thence to int8 */
3500 init_var_from_num(num, &x);
3501
3502 if (!numericvar_to_int64(&x, &result))
3503 ereport(ERROR,
3504 (errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
3505 errmsg("bigint out of range")));
3506
3507 PG_RETURN_INT64(result);
3508 }
3509
3510
3511 Datum
int2_numeric(PG_FUNCTION_ARGS)3512 int2_numeric(PG_FUNCTION_ARGS)
3513 {
3514 int16 val = PG_GETARG_INT16(0);
3515 Numeric res;
3516 NumericVar result;
3517
3518 init_var(&result);
3519
3520 int64_to_numericvar((int64) val, &result);
3521
3522 res = make_result(&result);
3523
3524 free_var(&result);
3525
3526 PG_RETURN_NUMERIC(res);
3527 }
3528
3529
3530 Datum
numeric_int2(PG_FUNCTION_ARGS)3531 numeric_int2(PG_FUNCTION_ARGS)
3532 {
3533 Numeric num = PG_GETARG_NUMERIC(0);
3534 NumericVar x;
3535 int64 val;
3536 int16 result;
3537
3538 /* XXX would it be better to return NULL? */
3539 if (NUMERIC_IS_NAN(num))
3540 ereport(ERROR,
3541 (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
3542 errmsg("cannot convert NaN to smallint")));
3543
3544 /* Convert to variable format and thence to int8 */
3545 init_var_from_num(num, &x);
3546
3547 if (!numericvar_to_int64(&x, &val))
3548 ereport(ERROR,
3549 (errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
3550 errmsg("smallint out of range")));
3551
3552 if (unlikely(val < PG_INT16_MIN) || unlikely(val > PG_INT16_MAX))
3553 ereport(ERROR,
3554 (errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
3555 errmsg("smallint out of range")));
3556
3557 /* Down-convert to int2 */
3558 result = (int16) val;
3559
3560 PG_RETURN_INT16(result);
3561 }
3562
3563
3564 Datum
float8_numeric(PG_FUNCTION_ARGS)3565 float8_numeric(PG_FUNCTION_ARGS)
3566 {
3567 float8 val = PG_GETARG_FLOAT8(0);
3568 Numeric res;
3569 NumericVar result;
3570 char buf[DBL_DIG + 100];
3571
3572 if (isnan(val))
3573 PG_RETURN_NUMERIC(make_result(&const_nan));
3574
3575 if (isinf(val))
3576 ereport(ERROR,
3577 (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
3578 errmsg("cannot convert infinity to numeric")));
3579
3580 snprintf(buf, sizeof(buf), "%.*g", DBL_DIG, val);
3581
3582 init_var(&result);
3583
3584 /* Assume we need not worry about leading/trailing spaces */
3585 (void) set_var_from_str(buf, buf, &result);
3586
3587 res = make_result(&result);
3588
3589 free_var(&result);
3590
3591 PG_RETURN_NUMERIC(res);
3592 }
3593
3594
3595 Datum
numeric_float8(PG_FUNCTION_ARGS)3596 numeric_float8(PG_FUNCTION_ARGS)
3597 {
3598 Numeric num = PG_GETARG_NUMERIC(0);
3599 char *tmp;
3600 Datum result;
3601
3602 if (NUMERIC_IS_NAN(num))
3603 PG_RETURN_FLOAT8(get_float8_nan());
3604
3605 tmp = DatumGetCString(DirectFunctionCall1(numeric_out,
3606 NumericGetDatum(num)));
3607
3608 result = DirectFunctionCall1(float8in, CStringGetDatum(tmp));
3609
3610 pfree(tmp);
3611
3612 PG_RETURN_DATUM(result);
3613 }
3614
3615
3616 /*
3617 * Convert numeric to float8; if out of range, return +/- HUGE_VAL
3618 *
3619 * (internal helper function, not directly callable from SQL)
3620 */
3621 Datum
numeric_float8_no_overflow(PG_FUNCTION_ARGS)3622 numeric_float8_no_overflow(PG_FUNCTION_ARGS)
3623 {
3624 Numeric num = PG_GETARG_NUMERIC(0);
3625 double val;
3626
3627 if (NUMERIC_IS_NAN(num))
3628 PG_RETURN_FLOAT8(get_float8_nan());
3629
3630 val = numeric_to_double_no_overflow(num);
3631
3632 PG_RETURN_FLOAT8(val);
3633 }
3634
3635 Datum
float4_numeric(PG_FUNCTION_ARGS)3636 float4_numeric(PG_FUNCTION_ARGS)
3637 {
3638 float4 val = PG_GETARG_FLOAT4(0);
3639 Numeric res;
3640 NumericVar result;
3641 char buf[FLT_DIG + 100];
3642
3643 if (isnan(val))
3644 PG_RETURN_NUMERIC(make_result(&const_nan));
3645
3646 if (isinf(val))
3647 ereport(ERROR,
3648 (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
3649 errmsg("cannot convert infinity to numeric")));
3650
3651 snprintf(buf, sizeof(buf), "%.*g", FLT_DIG, val);
3652
3653 init_var(&result);
3654
3655 /* Assume we need not worry about leading/trailing spaces */
3656 (void) set_var_from_str(buf, buf, &result);
3657
3658 res = make_result(&result);
3659
3660 free_var(&result);
3661
3662 PG_RETURN_NUMERIC(res);
3663 }
3664
3665
3666 Datum
numeric_float4(PG_FUNCTION_ARGS)3667 numeric_float4(PG_FUNCTION_ARGS)
3668 {
3669 Numeric num = PG_GETARG_NUMERIC(0);
3670 char *tmp;
3671 Datum result;
3672
3673 if (NUMERIC_IS_NAN(num))
3674 PG_RETURN_FLOAT4(get_float4_nan());
3675
3676 tmp = DatumGetCString(DirectFunctionCall1(numeric_out,
3677 NumericGetDatum(num)));
3678
3679 result = DirectFunctionCall1(float4in, CStringGetDatum(tmp));
3680
3681 pfree(tmp);
3682
3683 PG_RETURN_DATUM(result);
3684 }
3685
3686
3687 /* ----------------------------------------------------------------------
3688 *
3689 * Aggregate functions
3690 *
3691 * The transition datatype for all these aggregates is declared as INTERNAL.
3692 * Actually, it's a pointer to a NumericAggState allocated in the aggregate
3693 * context. The digit buffers for the NumericVars will be there too.
3694 *
3695 * On platforms which support 128-bit integers some aggregates instead use a
3696 * 128-bit integer based transition datatype to speed up calculations.
3697 *
3698 * ----------------------------------------------------------------------
3699 */
3700
/*
 * Transition state for the numeric aggregates below.  When 128-bit integers
 * are unavailable, it also backs the integer aggregates via PolyNumAggState.
 *
 * NaN inputs are only tallied in NaNcount; N, sumX and sumX2 cover the
 * non-NaN inputs.  maxScale/maxScaleCount record the largest input dscale
 * and how many inputs had it.  do_numeric_discard() uses them to tell when
 * removing an input would leave the result's dscale undetermined.
 */
typedef struct NumericAggState
{
	bool		calcSumX2;		/* if true, calculate sumX2 */
	MemoryContext agg_context;	/* context we're calculating in */
	int64		N;				/* count of processed (non-NaN) numbers */
	NumericSumAccum sumX;		/* sum of processed numbers */
	NumericSumAccum sumX2;		/* sum of squares of processed numbers;
								 * maintained only when calcSumX2 is true */
	int			maxScale;		/* maximum scale seen so far */
	int64		maxScaleCount;	/* number of values seen with maximum scale */
	int64		NaNcount;		/* count of NaN values (not included in N!) */
} NumericAggState;
3712
3713 /*
3714 * Prepare state data for a numeric aggregate function that needs to compute
3715 * sum, count and optionally sum of squares of the input.
3716 */
3717 static NumericAggState *
makeNumericAggState(FunctionCallInfo fcinfo,bool calcSumX2)3718 makeNumericAggState(FunctionCallInfo fcinfo, bool calcSumX2)
3719 {
3720 NumericAggState *state;
3721 MemoryContext agg_context;
3722 MemoryContext old_context;
3723
3724 if (!AggCheckCallContext(fcinfo, &agg_context))
3725 elog(ERROR, "aggregate function called in non-aggregate context");
3726
3727 old_context = MemoryContextSwitchTo(agg_context);
3728
3729 state = (NumericAggState *) palloc0(sizeof(NumericAggState));
3730 state->calcSumX2 = calcSumX2;
3731 state->agg_context = agg_context;
3732
3733 MemoryContextSwitchTo(old_context);
3734
3735 return state;
3736 }
3737
3738 /*
3739 * Like makeNumericAggState(), but allocate the state in the current memory
3740 * context.
3741 */
3742 static NumericAggState *
makeNumericAggStateCurrentContext(bool calcSumX2)3743 makeNumericAggStateCurrentContext(bool calcSumX2)
3744 {
3745 NumericAggState *state;
3746
3747 state = (NumericAggState *) palloc0(sizeof(NumericAggState));
3748 state->calcSumX2 = calcSumX2;
3749 state->agg_context = CurrentMemoryContext;
3750
3751 return state;
3752 }
3753
3754 /*
3755 * Accumulate a new input value for numeric aggregate functions.
3756 */
3757 static void
do_numeric_accum(NumericAggState * state,Numeric newval)3758 do_numeric_accum(NumericAggState *state, Numeric newval)
3759 {
3760 NumericVar X;
3761 NumericVar X2;
3762 MemoryContext old_context;
3763
3764 /* Count NaN inputs separately from all else */
3765 if (NUMERIC_IS_NAN(newval))
3766 {
3767 state->NaNcount++;
3768 return;
3769 }
3770
3771 /* load processed number in short-lived context */
3772 init_var_from_num(newval, &X);
3773
3774 /*
3775 * Track the highest input dscale that we've seen, to support inverse
3776 * transitions (see do_numeric_discard).
3777 */
3778 if (X.dscale > state->maxScale)
3779 {
3780 state->maxScale = X.dscale;
3781 state->maxScaleCount = 1;
3782 }
3783 else if (X.dscale == state->maxScale)
3784 state->maxScaleCount++;
3785
3786 /* if we need X^2, calculate that in short-lived context */
3787 if (state->calcSumX2)
3788 {
3789 init_var(&X2);
3790 mul_var(&X, &X, &X2, X.dscale * 2);
3791 }
3792
3793 /* The rest of this needs to work in the aggregate context */
3794 old_context = MemoryContextSwitchTo(state->agg_context);
3795
3796 state->N++;
3797
3798 /* Accumulate sums */
3799 accum_sum_add(&(state->sumX), &X);
3800
3801 if (state->calcSumX2)
3802 accum_sum_add(&(state->sumX2), &X2);
3803
3804 MemoryContextSwitchTo(old_context);
3805 }
3806
3807 /*
3808 * Attempt to remove an input value from the aggregated state.
3809 *
3810 * If the value cannot be removed then the function will return false; the
3811 * possible reasons for failing are described below.
3812 *
3813 * If we aggregate the values 1.01 and 2 then the result will be 3.01.
3814 * If we are then asked to un-aggregate the 1.01 then we must fail as we
3815 * won't be able to tell what the new aggregated value's dscale should be.
3816 * We don't want to return 2.00 (dscale = 2), since the sum's dscale would
3817 * have been zero if we'd really aggregated only 2.
3818 *
3819 * Note: alternatively, we could count the number of inputs with each possible
3820 * dscale (up to some sane limit). Not yet clear if it's worth the trouble.
3821 */
3822 static bool
do_numeric_discard(NumericAggState * state,Numeric newval)3823 do_numeric_discard(NumericAggState *state, Numeric newval)
3824 {
3825 NumericVar X;
3826 NumericVar X2;
3827 MemoryContext old_context;
3828
3829 /* Count NaN inputs separately from all else */
3830 if (NUMERIC_IS_NAN(newval))
3831 {
3832 state->NaNcount--;
3833 return true;
3834 }
3835
3836 /* load processed number in short-lived context */
3837 init_var_from_num(newval, &X);
3838
3839 /*
3840 * state->sumX's dscale is the maximum dscale of any of the inputs.
3841 * Removing the last input with that dscale would require us to recompute
3842 * the maximum dscale of the *remaining* inputs, which we cannot do unless
3843 * no more non-NaN inputs remain at all. So we report a failure instead,
3844 * and force the aggregation to be redone from scratch.
3845 */
3846 if (X.dscale == state->maxScale)
3847 {
3848 if (state->maxScaleCount > 1 || state->maxScale == 0)
3849 {
3850 /*
3851 * Some remaining inputs have same dscale, or dscale hasn't gotten
3852 * above zero anyway
3853 */
3854 state->maxScaleCount--;
3855 }
3856 else if (state->N == 1)
3857 {
3858 /* No remaining non-NaN inputs at all, so reset maxScale */
3859 state->maxScale = 0;
3860 state->maxScaleCount = 0;
3861 }
3862 else
3863 {
3864 /* Correct new maxScale is uncertain, must fail */
3865 return false;
3866 }
3867 }
3868
3869 /* if we need X^2, calculate that in short-lived context */
3870 if (state->calcSumX2)
3871 {
3872 init_var(&X2);
3873 mul_var(&X, &X, &X2, X.dscale * 2);
3874 }
3875
3876 /* The rest of this needs to work in the aggregate context */
3877 old_context = MemoryContextSwitchTo(state->agg_context);
3878
3879 if (state->N-- > 1)
3880 {
3881 /* Negate X, to subtract it from the sum */
3882 X.sign = (X.sign == NUMERIC_POS ? NUMERIC_NEG : NUMERIC_POS);
3883 accum_sum_add(&(state->sumX), &X);
3884
3885 if (state->calcSumX2)
3886 {
3887 /* Negate X^2. X^2 is always positive */
3888 X2.sign = NUMERIC_NEG;
3889 accum_sum_add(&(state->sumX2), &X2);
3890 }
3891 }
3892 else
3893 {
3894 /* Zero the sums */
3895 Assert(state->N == 0);
3896
3897 accum_sum_reset(&state->sumX);
3898 if (state->calcSumX2)
3899 accum_sum_reset(&state->sumX2);
3900 }
3901
3902 MemoryContextSwitchTo(old_context);
3903
3904 return true;
3905 }
3906
3907 /*
3908 * Generic transition function for numeric aggregates that require sumX2.
3909 */
3910 Datum
numeric_accum(PG_FUNCTION_ARGS)3911 numeric_accum(PG_FUNCTION_ARGS)
3912 {
3913 NumericAggState *state;
3914
3915 state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
3916
3917 /* Create the state data on the first call */
3918 if (state == NULL)
3919 state = makeNumericAggState(fcinfo, true);
3920
3921 if (!PG_ARGISNULL(1))
3922 do_numeric_accum(state, PG_GETARG_NUMERIC(1));
3923
3924 PG_RETURN_POINTER(state);
3925 }
3926
3927 /*
3928 * Generic combine function for numeric aggregates which require sumX2
3929 */
3930 Datum
numeric_combine(PG_FUNCTION_ARGS)3931 numeric_combine(PG_FUNCTION_ARGS)
3932 {
3933 NumericAggState *state1;
3934 NumericAggState *state2;
3935 MemoryContext agg_context;
3936 MemoryContext old_context;
3937
3938 if (!AggCheckCallContext(fcinfo, &agg_context))
3939 elog(ERROR, "aggregate function called in non-aggregate context");
3940
3941 state1 = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
3942 state2 = PG_ARGISNULL(1) ? NULL : (NumericAggState *) PG_GETARG_POINTER(1);
3943
3944 if (state2 == NULL)
3945 PG_RETURN_POINTER(state1);
3946
3947 /* manually copy all fields from state2 to state1 */
3948 if (state1 == NULL)
3949 {
3950 old_context = MemoryContextSwitchTo(agg_context);
3951
3952 state1 = makeNumericAggStateCurrentContext(true);
3953 state1->N = state2->N;
3954 state1->NaNcount = state2->NaNcount;
3955 state1->maxScale = state2->maxScale;
3956 state1->maxScaleCount = state2->maxScaleCount;
3957
3958 accum_sum_copy(&state1->sumX, &state2->sumX);
3959 accum_sum_copy(&state1->sumX2, &state2->sumX2);
3960
3961 MemoryContextSwitchTo(old_context);
3962
3963 PG_RETURN_POINTER(state1);
3964 }
3965
3966 state1->N += state2->N;
3967 state1->NaNcount += state2->NaNcount;
3968
3969 if (state2->N > 0)
3970 {
3971 /*
3972 * These are currently only needed for moving aggregates, but let's do
3973 * the right thing anyway...
3974 */
3975 if (state2->maxScale > state1->maxScale)
3976 {
3977 state1->maxScale = state2->maxScale;
3978 state1->maxScaleCount = state2->maxScaleCount;
3979 }
3980 else if (state2->maxScale == state1->maxScale)
3981 state1->maxScaleCount += state2->maxScaleCount;
3982
3983 /* The rest of this needs to work in the aggregate context */
3984 old_context = MemoryContextSwitchTo(agg_context);
3985
3986 /* Accumulate sums */
3987 accum_sum_combine(&state1->sumX, &state2->sumX);
3988 accum_sum_combine(&state1->sumX2, &state2->sumX2);
3989
3990 MemoryContextSwitchTo(old_context);
3991 }
3992 PG_RETURN_POINTER(state1);
3993 }
3994
3995 /*
3996 * Generic transition function for numeric aggregates that don't require sumX2.
3997 */
3998 Datum
numeric_avg_accum(PG_FUNCTION_ARGS)3999 numeric_avg_accum(PG_FUNCTION_ARGS)
4000 {
4001 NumericAggState *state;
4002
4003 state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
4004
4005 /* Create the state data on the first call */
4006 if (state == NULL)
4007 state = makeNumericAggState(fcinfo, false);
4008
4009 if (!PG_ARGISNULL(1))
4010 do_numeric_accum(state, PG_GETARG_NUMERIC(1));
4011
4012 PG_RETURN_POINTER(state);
4013 }
4014
4015 /*
4016 * Combine function for numeric aggregates which don't require sumX2
4017 */
4018 Datum
numeric_avg_combine(PG_FUNCTION_ARGS)4019 numeric_avg_combine(PG_FUNCTION_ARGS)
4020 {
4021 NumericAggState *state1;
4022 NumericAggState *state2;
4023 MemoryContext agg_context;
4024 MemoryContext old_context;
4025
4026 if (!AggCheckCallContext(fcinfo, &agg_context))
4027 elog(ERROR, "aggregate function called in non-aggregate context");
4028
4029 state1 = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
4030 state2 = PG_ARGISNULL(1) ? NULL : (NumericAggState *) PG_GETARG_POINTER(1);
4031
4032 if (state2 == NULL)
4033 PG_RETURN_POINTER(state1);
4034
4035 /* manually copy all fields from state2 to state1 */
4036 if (state1 == NULL)
4037 {
4038 old_context = MemoryContextSwitchTo(agg_context);
4039
4040 state1 = makeNumericAggStateCurrentContext(false);
4041 state1->N = state2->N;
4042 state1->NaNcount = state2->NaNcount;
4043 state1->maxScale = state2->maxScale;
4044 state1->maxScaleCount = state2->maxScaleCount;
4045
4046 accum_sum_copy(&state1->sumX, &state2->sumX);
4047
4048 MemoryContextSwitchTo(old_context);
4049
4050 PG_RETURN_POINTER(state1);
4051 }
4052
4053 state1->N += state2->N;
4054 state1->NaNcount += state2->NaNcount;
4055
4056 if (state2->N > 0)
4057 {
4058 /*
4059 * These are currently only needed for moving aggregates, but let's do
4060 * the right thing anyway...
4061 */
4062 if (state2->maxScale > state1->maxScale)
4063 {
4064 state1->maxScale = state2->maxScale;
4065 state1->maxScaleCount = state2->maxScaleCount;
4066 }
4067 else if (state2->maxScale == state1->maxScale)
4068 state1->maxScaleCount += state2->maxScaleCount;
4069
4070 /* The rest of this needs to work in the aggregate context */
4071 old_context = MemoryContextSwitchTo(agg_context);
4072
4073 /* Accumulate sums */
4074 accum_sum_combine(&state1->sumX, &state2->sumX);
4075
4076 MemoryContextSwitchTo(old_context);
4077 }
4078 PG_RETURN_POINTER(state1);
4079 }
4080
4081 /*
4082 * numeric_avg_serialize
4083 * Serialize NumericAggState for numeric aggregates that don't require
4084 * sumX2.
4085 */
4086 Datum
numeric_avg_serialize(PG_FUNCTION_ARGS)4087 numeric_avg_serialize(PG_FUNCTION_ARGS)
4088 {
4089 NumericAggState *state;
4090 StringInfoData buf;
4091 Datum temp;
4092 bytea *sumX;
4093 bytea *result;
4094 NumericVar tmp_var;
4095
4096 /* Ensure we disallow calling when not in aggregate context */
4097 if (!AggCheckCallContext(fcinfo, NULL))
4098 elog(ERROR, "aggregate function called in non-aggregate context");
4099
4100 state = (NumericAggState *) PG_GETARG_POINTER(0);
4101
4102 /*
4103 * This is a little wasteful since make_result converts the NumericVar
4104 * into a Numeric and numeric_send converts it back again. Is it worth
4105 * splitting the tasks in numeric_send into separate functions to stop
4106 * this? Doing so would also remove the fmgr call overhead.
4107 */
4108 init_var(&tmp_var);
4109 accum_sum_final(&state->sumX, &tmp_var);
4110
4111 temp = DirectFunctionCall1(numeric_send,
4112 NumericGetDatum(make_result(&tmp_var)));
4113 sumX = DatumGetByteaPP(temp);
4114 free_var(&tmp_var);
4115
4116 pq_begintypsend(&buf);
4117
4118 /* N */
4119 pq_sendint64(&buf, state->N);
4120
4121 /* sumX */
4122 pq_sendbytes(&buf, VARDATA_ANY(sumX), VARSIZE_ANY_EXHDR(sumX));
4123
4124 /* maxScale */
4125 pq_sendint32(&buf, state->maxScale);
4126
4127 /* maxScaleCount */
4128 pq_sendint64(&buf, state->maxScaleCount);
4129
4130 /* NaNcount */
4131 pq_sendint64(&buf, state->NaNcount);
4132
4133 result = pq_endtypsend(&buf);
4134
4135 PG_RETURN_BYTEA_P(result);
4136 }
4137
4138 /*
4139 * numeric_avg_deserialize
4140 * Deserialize bytea into NumericAggState for numeric aggregates that
4141 * don't require sumX2.
4142 */
4143 Datum
numeric_avg_deserialize(PG_FUNCTION_ARGS)4144 numeric_avg_deserialize(PG_FUNCTION_ARGS)
4145 {
4146 bytea *sstate;
4147 NumericAggState *result;
4148 Datum temp;
4149 NumericVar tmp_var;
4150 StringInfoData buf;
4151
4152 if (!AggCheckCallContext(fcinfo, NULL))
4153 elog(ERROR, "aggregate function called in non-aggregate context");
4154
4155 sstate = PG_GETARG_BYTEA_PP(0);
4156
4157 /*
4158 * Copy the bytea into a StringInfo so that we can "receive" it using the
4159 * standard recv-function infrastructure.
4160 */
4161 initStringInfo(&buf);
4162 appendBinaryStringInfo(&buf,
4163 VARDATA_ANY(sstate), VARSIZE_ANY_EXHDR(sstate));
4164
4165 result = makeNumericAggStateCurrentContext(false);
4166
4167 /* N */
4168 result->N = pq_getmsgint64(&buf);
4169
4170 /* sumX */
4171 temp = DirectFunctionCall3(numeric_recv,
4172 PointerGetDatum(&buf),
4173 ObjectIdGetDatum(InvalidOid),
4174 Int32GetDatum(-1));
4175 init_var_from_num(DatumGetNumeric(temp), &tmp_var);
4176 accum_sum_add(&(result->sumX), &tmp_var);
4177
4178 /* maxScale */
4179 result->maxScale = pq_getmsgint(&buf, 4);
4180
4181 /* maxScaleCount */
4182 result->maxScaleCount = pq_getmsgint64(&buf);
4183
4184 /* NaNcount */
4185 result->NaNcount = pq_getmsgint64(&buf);
4186
4187 pq_getmsgend(&buf);
4188 pfree(buf.data);
4189
4190 PG_RETURN_POINTER(result);
4191 }
4192
4193 /*
4194 * numeric_serialize
4195 * Serialization function for NumericAggState for numeric aggregates that
4196 * require sumX2.
4197 */
4198 Datum
numeric_serialize(PG_FUNCTION_ARGS)4199 numeric_serialize(PG_FUNCTION_ARGS)
4200 {
4201 NumericAggState *state;
4202 StringInfoData buf;
4203 Datum temp;
4204 bytea *sumX;
4205 NumericVar tmp_var;
4206 bytea *sumX2;
4207 bytea *result;
4208
4209 /* Ensure we disallow calling when not in aggregate context */
4210 if (!AggCheckCallContext(fcinfo, NULL))
4211 elog(ERROR, "aggregate function called in non-aggregate context");
4212
4213 state = (NumericAggState *) PG_GETARG_POINTER(0);
4214
4215 /*
4216 * This is a little wasteful since make_result converts the NumericVar
4217 * into a Numeric and numeric_send converts it back again. Is it worth
4218 * splitting the tasks in numeric_send into separate functions to stop
4219 * this? Doing so would also remove the fmgr call overhead.
4220 */
4221 init_var(&tmp_var);
4222
4223 accum_sum_final(&state->sumX, &tmp_var);
4224 temp = DirectFunctionCall1(numeric_send,
4225 NumericGetDatum(make_result(&tmp_var)));
4226 sumX = DatumGetByteaPP(temp);
4227
4228 accum_sum_final(&state->sumX2, &tmp_var);
4229 temp = DirectFunctionCall1(numeric_send,
4230 NumericGetDatum(make_result(&tmp_var)));
4231 sumX2 = DatumGetByteaPP(temp);
4232
4233 free_var(&tmp_var);
4234
4235 pq_begintypsend(&buf);
4236
4237 /* N */
4238 pq_sendint64(&buf, state->N);
4239
4240 /* sumX */
4241 pq_sendbytes(&buf, VARDATA_ANY(sumX), VARSIZE_ANY_EXHDR(sumX));
4242
4243 /* sumX2 */
4244 pq_sendbytes(&buf, VARDATA_ANY(sumX2), VARSIZE_ANY_EXHDR(sumX2));
4245
4246 /* maxScale */
4247 pq_sendint32(&buf, state->maxScale);
4248
4249 /* maxScaleCount */
4250 pq_sendint64(&buf, state->maxScaleCount);
4251
4252 /* NaNcount */
4253 pq_sendint64(&buf, state->NaNcount);
4254
4255 result = pq_endtypsend(&buf);
4256
4257 PG_RETURN_BYTEA_P(result);
4258 }
4259
4260 /*
4261 * numeric_deserialize
4262 * Deserialization function for NumericAggState for numeric aggregates that
4263 * require sumX2.
4264 */
4265 Datum
numeric_deserialize(PG_FUNCTION_ARGS)4266 numeric_deserialize(PG_FUNCTION_ARGS)
4267 {
4268 bytea *sstate;
4269 NumericAggState *result;
4270 Datum temp;
4271 NumericVar sumX_var;
4272 NumericVar sumX2_var;
4273 StringInfoData buf;
4274
4275 if (!AggCheckCallContext(fcinfo, NULL))
4276 elog(ERROR, "aggregate function called in non-aggregate context");
4277
4278 sstate = PG_GETARG_BYTEA_PP(0);
4279
4280 /*
4281 * Copy the bytea into a StringInfo so that we can "receive" it using the
4282 * standard recv-function infrastructure.
4283 */
4284 initStringInfo(&buf);
4285 appendBinaryStringInfo(&buf,
4286 VARDATA_ANY(sstate), VARSIZE_ANY_EXHDR(sstate));
4287
4288 result = makeNumericAggStateCurrentContext(false);
4289
4290 /* N */
4291 result->N = pq_getmsgint64(&buf);
4292
4293 /* sumX */
4294 temp = DirectFunctionCall3(numeric_recv,
4295 PointerGetDatum(&buf),
4296 ObjectIdGetDatum(InvalidOid),
4297 Int32GetDatum(-1));
4298 init_var_from_num(DatumGetNumeric(temp), &sumX_var);
4299 accum_sum_add(&(result->sumX), &sumX_var);
4300
4301 /* sumX2 */
4302 temp = DirectFunctionCall3(numeric_recv,
4303 PointerGetDatum(&buf),
4304 ObjectIdGetDatum(InvalidOid),
4305 Int32GetDatum(-1));
4306 init_var_from_num(DatumGetNumeric(temp), &sumX2_var);
4307 accum_sum_add(&(result->sumX2), &sumX2_var);
4308
4309 /* maxScale */
4310 result->maxScale = pq_getmsgint(&buf, 4);
4311
4312 /* maxScaleCount */
4313 result->maxScaleCount = pq_getmsgint64(&buf);
4314
4315 /* NaNcount */
4316 result->NaNcount = pq_getmsgint64(&buf);
4317
4318 pq_getmsgend(&buf);
4319 pfree(buf.data);
4320
4321 PG_RETURN_POINTER(result);
4322 }
4323
4324 /*
4325 * Generic inverse transition function for numeric aggregates
4326 * (with or without requirement for X^2).
4327 */
4328 Datum
numeric_accum_inv(PG_FUNCTION_ARGS)4329 numeric_accum_inv(PG_FUNCTION_ARGS)
4330 {
4331 NumericAggState *state;
4332
4333 state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
4334
4335 /* Should not get here with no state */
4336 if (state == NULL)
4337 elog(ERROR, "numeric_accum_inv called with NULL state");
4338
4339 if (!PG_ARGISNULL(1))
4340 {
4341 /* If we fail to perform the inverse transition, return NULL */
4342 if (!do_numeric_discard(state, PG_GETARG_NUMERIC(1)))
4343 PG_RETURN_NULL();
4344 }
4345
4346 PG_RETURN_POINTER(state);
4347 }
4348
4349
4350 /*
4351 * Integer data types in general use Numeric accumulators to share code
4352 * and avoid risk of overflow.
4353 *
4354 * However for performance reasons optimized special-purpose accumulator
4355 * routines are used when possible.
4356 *
 * On platforms with 128-bit integer support, the 128-bit routines will be
 * used when sum(X) or sum(X*X) fits into 128 bits.
 *
 * For 16- and 32-bit inputs, N and sum(X) fit into 64 bits, so the 64-bit
 * accumulators will be used for SUM and AVG of these data types.
4362 */
4363
4364 #ifdef HAVE_INT128
/*
 * Transition state for integer aggregates on platforms with int128.
 * Unlike NumericAggState, there is no NaN count, scale tracking or memory
 * context here.  Inputs are plain integers, so removing one
 * (do_int128_discard) never fails.
 */
typedef struct Int128AggState
{
	bool		calcSumX2;		/* if true, calculate sumX2 */
	int64		N;				/* count of processed numbers */
	int128		sumX;			/* sum of processed numbers */
	int128		sumX2;			/* sum of squares of processed numbers;
								 * maintained only when calcSumX2 is true */
} Int128AggState;
4372
4373 /*
4374 * Prepare state data for a 128-bit aggregate function that needs to compute
4375 * sum, count and optionally sum of squares of the input.
4376 */
4377 static Int128AggState *
makeInt128AggState(FunctionCallInfo fcinfo,bool calcSumX2)4378 makeInt128AggState(FunctionCallInfo fcinfo, bool calcSumX2)
4379 {
4380 Int128AggState *state;
4381 MemoryContext agg_context;
4382 MemoryContext old_context;
4383
4384 if (!AggCheckCallContext(fcinfo, &agg_context))
4385 elog(ERROR, "aggregate function called in non-aggregate context");
4386
4387 old_context = MemoryContextSwitchTo(agg_context);
4388
4389 state = (Int128AggState *) palloc0(sizeof(Int128AggState));
4390 state->calcSumX2 = calcSumX2;
4391
4392 MemoryContextSwitchTo(old_context);
4393
4394 return state;
4395 }
4396
4397 /*
4398 * Like makeInt128AggState(), but allocate the state in the current memory
4399 * context.
4400 */
4401 static Int128AggState *
makeInt128AggStateCurrentContext(bool calcSumX2)4402 makeInt128AggStateCurrentContext(bool calcSumX2)
4403 {
4404 Int128AggState *state;
4405
4406 state = (Int128AggState *) palloc0(sizeof(Int128AggState));
4407 state->calcSumX2 = calcSumX2;
4408
4409 return state;
4410 }
4411
4412 /*
4413 * Accumulate a new input value for 128-bit aggregate functions.
4414 */
4415 static void
do_int128_accum(Int128AggState * state,int128 newval)4416 do_int128_accum(Int128AggState *state, int128 newval)
4417 {
4418 if (state->calcSumX2)
4419 state->sumX2 += newval * newval;
4420
4421 state->sumX += newval;
4422 state->N++;
4423 }
4424
4425 /*
4426 * Remove an input value from the aggregated state.
4427 */
4428 static void
do_int128_discard(Int128AggState * state,int128 newval)4429 do_int128_discard(Int128AggState *state, int128 newval)
4430 {
4431 if (state->calcSumX2)
4432 state->sumX2 -= newval * newval;
4433
4434 state->sumX -= newval;
4435 state->N--;
4436 }
4437
4438 typedef Int128AggState PolyNumAggState;
4439 #define makePolyNumAggState makeInt128AggState
4440 #define makePolyNumAggStateCurrentContext makeInt128AggStateCurrentContext
4441 #else
4442 typedef NumericAggState PolyNumAggState;
4443 #define makePolyNumAggState makeNumericAggState
4444 #define makePolyNumAggStateCurrentContext makeNumericAggStateCurrentContext
4445 #endif
4446
4447 Datum
int2_accum(PG_FUNCTION_ARGS)4448 int2_accum(PG_FUNCTION_ARGS)
4449 {
4450 PolyNumAggState *state;
4451
4452 state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
4453
4454 /* Create the state data on the first call */
4455 if (state == NULL)
4456 state = makePolyNumAggState(fcinfo, true);
4457
4458 if (!PG_ARGISNULL(1))
4459 {
4460 #ifdef HAVE_INT128
4461 do_int128_accum(state, (int128) PG_GETARG_INT16(1));
4462 #else
4463 Numeric newval;
4464
4465 newval = DatumGetNumeric(DirectFunctionCall1(int2_numeric,
4466 PG_GETARG_DATUM(1)));
4467 do_numeric_accum(state, newval);
4468 #endif
4469 }
4470
4471 PG_RETURN_POINTER(state);
4472 }
4473
4474 Datum
int4_accum(PG_FUNCTION_ARGS)4475 int4_accum(PG_FUNCTION_ARGS)
4476 {
4477 PolyNumAggState *state;
4478
4479 state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
4480
4481 /* Create the state data on the first call */
4482 if (state == NULL)
4483 state = makePolyNumAggState(fcinfo, true);
4484
4485 if (!PG_ARGISNULL(1))
4486 {
4487 #ifdef HAVE_INT128
4488 do_int128_accum(state, (int128) PG_GETARG_INT32(1));
4489 #else
4490 Numeric newval;
4491
4492 newval = DatumGetNumeric(DirectFunctionCall1(int4_numeric,
4493 PG_GETARG_DATUM(1)));
4494 do_numeric_accum(state, newval);
4495 #endif
4496 }
4497
4498 PG_RETURN_POINTER(state);
4499 }
4500
4501 Datum
int8_accum(PG_FUNCTION_ARGS)4502 int8_accum(PG_FUNCTION_ARGS)
4503 {
4504 NumericAggState *state;
4505
4506 state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
4507
4508 /* Create the state data on the first call */
4509 if (state == NULL)
4510 state = makeNumericAggState(fcinfo, true);
4511
4512 if (!PG_ARGISNULL(1))
4513 {
4514 Numeric newval;
4515
4516 newval = DatumGetNumeric(DirectFunctionCall1(int8_numeric,
4517 PG_GETARG_DATUM(1)));
4518 do_numeric_accum(state, newval);
4519 }
4520
4521 PG_RETURN_POINTER(state);
4522 }
4523
4524 /*
4525 * Combine function for numeric aggregates which require sumX2
4526 */
4527 Datum
numeric_poly_combine(PG_FUNCTION_ARGS)4528 numeric_poly_combine(PG_FUNCTION_ARGS)
4529 {
4530 PolyNumAggState *state1;
4531 PolyNumAggState *state2;
4532 MemoryContext agg_context;
4533 MemoryContext old_context;
4534
4535 if (!AggCheckCallContext(fcinfo, &agg_context))
4536 elog(ERROR, "aggregate function called in non-aggregate context");
4537
4538 state1 = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
4539 state2 = PG_ARGISNULL(1) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(1);
4540
4541 if (state2 == NULL)
4542 PG_RETURN_POINTER(state1);
4543
4544 /* manually copy all fields from state2 to state1 */
4545 if (state1 == NULL)
4546 {
4547 old_context = MemoryContextSwitchTo(agg_context);
4548
4549 state1 = makePolyNumAggState(fcinfo, true);
4550 state1->N = state2->N;
4551
4552 #ifdef HAVE_INT128
4553 state1->sumX = state2->sumX;
4554 state1->sumX2 = state2->sumX2;
4555 #else
4556 accum_sum_copy(&state1->sumX, &state2->sumX);
4557 accum_sum_copy(&state1->sumX2, &state2->sumX2);
4558 #endif
4559
4560 MemoryContextSwitchTo(old_context);
4561
4562 PG_RETURN_POINTER(state1);
4563 }
4564
4565 if (state2->N > 0)
4566 {
4567 state1->N += state2->N;
4568
4569 #ifdef HAVE_INT128
4570 state1->sumX += state2->sumX;
4571 state1->sumX2 += state2->sumX2;
4572 #else
4573 /* The rest of this needs to work in the aggregate context */
4574 old_context = MemoryContextSwitchTo(agg_context);
4575
4576 /* Accumulate sums */
4577 accum_sum_combine(&state1->sumX, &state2->sumX);
4578 accum_sum_combine(&state1->sumX2, &state2->sumX2);
4579
4580 MemoryContextSwitchTo(old_context);
4581 #endif
4582
4583 }
4584 PG_RETURN_POINTER(state1);
4585 }
4586
4587 /*
4588 * numeric_poly_serialize
4589 * Serialize PolyNumAggState into bytea for aggregate functions which
4590 * require sumX2.
4591 */
4592 Datum
numeric_poly_serialize(PG_FUNCTION_ARGS)4593 numeric_poly_serialize(PG_FUNCTION_ARGS)
4594 {
4595 PolyNumAggState *state;
4596 StringInfoData buf;
4597 bytea *sumX;
4598 bytea *sumX2;
4599 bytea *result;
4600
4601 /* Ensure we disallow calling when not in aggregate context */
4602 if (!AggCheckCallContext(fcinfo, NULL))
4603 elog(ERROR, "aggregate function called in non-aggregate context");
4604
4605 state = (PolyNumAggState *) PG_GETARG_POINTER(0);
4606
4607 /*
4608 * If the platform supports int128 then sumX and sumX2 will be a 128 bit
4609 * integer type. Here we'll convert that into a numeric type so that the
4610 * combine state is in the same format for both int128 enabled machines
4611 * and machines which don't support that type. The logic here is that one
4612 * day we might like to send these over to another server for further
4613 * processing and we want a standard format to work with.
4614 */
4615 {
4616 Datum temp;
4617 NumericVar num;
4618
4619 init_var(&num);
4620
4621 #ifdef HAVE_INT128
4622 int128_to_numericvar(state->sumX, &num);
4623 #else
4624 accum_sum_final(&state->sumX, &num);
4625 #endif
4626 temp = DirectFunctionCall1(numeric_send,
4627 NumericGetDatum(make_result(&num)));
4628 sumX = DatumGetByteaPP(temp);
4629
4630 #ifdef HAVE_INT128
4631 int128_to_numericvar(state->sumX2, &num);
4632 #else
4633 accum_sum_final(&state->sumX2, &num);
4634 #endif
4635 temp = DirectFunctionCall1(numeric_send,
4636 NumericGetDatum(make_result(&num)));
4637 sumX2 = DatumGetByteaPP(temp);
4638
4639 free_var(&num);
4640 }
4641
4642 pq_begintypsend(&buf);
4643
4644 /* N */
4645 pq_sendint64(&buf, state->N);
4646
4647 /* sumX */
4648 pq_sendbytes(&buf, VARDATA_ANY(sumX), VARSIZE_ANY_EXHDR(sumX));
4649
4650 /* sumX2 */
4651 pq_sendbytes(&buf, VARDATA_ANY(sumX2), VARSIZE_ANY_EXHDR(sumX2));
4652
4653 result = pq_endtypsend(&buf);
4654
4655 PG_RETURN_BYTEA_P(result);
4656 }
4657
4658 /*
4659 * numeric_poly_deserialize
4660 * Deserialize PolyNumAggState from bytea for aggregate functions which
4661 * require sumX2.
4662 */
4663 Datum
numeric_poly_deserialize(PG_FUNCTION_ARGS)4664 numeric_poly_deserialize(PG_FUNCTION_ARGS)
4665 {
4666 bytea *sstate;
4667 PolyNumAggState *result;
4668 Datum sumX;
4669 NumericVar sumX_var;
4670 Datum sumX2;
4671 NumericVar sumX2_var;
4672 StringInfoData buf;
4673
4674 if (!AggCheckCallContext(fcinfo, NULL))
4675 elog(ERROR, "aggregate function called in non-aggregate context");
4676
4677 sstate = PG_GETARG_BYTEA_PP(0);
4678
4679 /*
4680 * Copy the bytea into a StringInfo so that we can "receive" it using the
4681 * standard recv-function infrastructure.
4682 */
4683 initStringInfo(&buf);
4684 appendBinaryStringInfo(&buf,
4685 VARDATA_ANY(sstate), VARSIZE_ANY_EXHDR(sstate));
4686
4687 result = makePolyNumAggStateCurrentContext(false);
4688
4689 /* N */
4690 result->N = pq_getmsgint64(&buf);
4691
4692 /* sumX */
4693 sumX = DirectFunctionCall3(numeric_recv,
4694 PointerGetDatum(&buf),
4695 ObjectIdGetDatum(InvalidOid),
4696 Int32GetDatum(-1));
4697
4698 /* sumX2 */
4699 sumX2 = DirectFunctionCall3(numeric_recv,
4700 PointerGetDatum(&buf),
4701 ObjectIdGetDatum(InvalidOid),
4702 Int32GetDatum(-1));
4703
4704 init_var_from_num(DatumGetNumeric(sumX), &sumX_var);
4705 #ifdef HAVE_INT128
4706 numericvar_to_int128(&sumX_var, &result->sumX);
4707 #else
4708 accum_sum_add(&result->sumX, &sumX_var);
4709 #endif
4710
4711 init_var_from_num(DatumGetNumeric(sumX2), &sumX2_var);
4712 #ifdef HAVE_INT128
4713 numericvar_to_int128(&sumX2_var, &result->sumX2);
4714 #else
4715 accum_sum_add(&result->sumX2, &sumX2_var);
4716 #endif
4717
4718 pq_getmsgend(&buf);
4719 pfree(buf.data);
4720
4721 PG_RETURN_POINTER(result);
4722 }
4723
4724 /*
4725 * Transition function for int8 input when we don't need sumX2.
4726 */
4727 Datum
int8_avg_accum(PG_FUNCTION_ARGS)4728 int8_avg_accum(PG_FUNCTION_ARGS)
4729 {
4730 PolyNumAggState *state;
4731
4732 state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
4733
4734 /* Create the state data on the first call */
4735 if (state == NULL)
4736 state = makePolyNumAggState(fcinfo, false);
4737
4738 if (!PG_ARGISNULL(1))
4739 {
4740 #ifdef HAVE_INT128
4741 do_int128_accum(state, (int128) PG_GETARG_INT64(1));
4742 #else
4743 Numeric newval;
4744
4745 newval = DatumGetNumeric(DirectFunctionCall1(int8_numeric,
4746 PG_GETARG_DATUM(1)));
4747 do_numeric_accum(state, newval);
4748 #endif
4749 }
4750
4751 PG_RETURN_POINTER(state);
4752 }
4753
4754 /*
4755 * Combine function for PolyNumAggState for aggregates which don't require
4756 * sumX2
4757 */
4758 Datum
int8_avg_combine(PG_FUNCTION_ARGS)4759 int8_avg_combine(PG_FUNCTION_ARGS)
4760 {
4761 PolyNumAggState *state1;
4762 PolyNumAggState *state2;
4763 MemoryContext agg_context;
4764 MemoryContext old_context;
4765
4766 if (!AggCheckCallContext(fcinfo, &agg_context))
4767 elog(ERROR, "aggregate function called in non-aggregate context");
4768
4769 state1 = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
4770 state2 = PG_ARGISNULL(1) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(1);
4771
4772 if (state2 == NULL)
4773 PG_RETURN_POINTER(state1);
4774
4775 /* manually copy all fields from state2 to state1 */
4776 if (state1 == NULL)
4777 {
4778 old_context = MemoryContextSwitchTo(agg_context);
4779
4780 state1 = makePolyNumAggState(fcinfo, false);
4781 state1->N = state2->N;
4782
4783 #ifdef HAVE_INT128
4784 state1->sumX = state2->sumX;
4785 #else
4786 accum_sum_copy(&state1->sumX, &state2->sumX);
4787 #endif
4788 MemoryContextSwitchTo(old_context);
4789
4790 PG_RETURN_POINTER(state1);
4791 }
4792
4793 if (state2->N > 0)
4794 {
4795 state1->N += state2->N;
4796
4797 #ifdef HAVE_INT128
4798 state1->sumX += state2->sumX;
4799 #else
4800 /* The rest of this needs to work in the aggregate context */
4801 old_context = MemoryContextSwitchTo(agg_context);
4802
4803 /* Accumulate sums */
4804 accum_sum_combine(&state1->sumX, &state2->sumX);
4805
4806 MemoryContextSwitchTo(old_context);
4807 #endif
4808
4809 }
4810 PG_RETURN_POINTER(state1);
4811 }
4812
4813 /*
4814 * int8_avg_serialize
4815 * Serialize PolyNumAggState into bytea using the standard
4816 * recv-function infrastructure.
4817 */
4818 Datum
int8_avg_serialize(PG_FUNCTION_ARGS)4819 int8_avg_serialize(PG_FUNCTION_ARGS)
4820 {
4821 PolyNumAggState *state;
4822 StringInfoData buf;
4823 bytea *sumX;
4824 bytea *result;
4825
4826 /* Ensure we disallow calling when not in aggregate context */
4827 if (!AggCheckCallContext(fcinfo, NULL))
4828 elog(ERROR, "aggregate function called in non-aggregate context");
4829
4830 state = (PolyNumAggState *) PG_GETARG_POINTER(0);
4831
4832 /*
4833 * If the platform supports int128 then sumX will be a 128 integer type.
4834 * Here we'll convert that into a numeric type so that the combine state
4835 * is in the same format for both int128 enabled machines and machines
4836 * which don't support that type. The logic here is that one day we might
4837 * like to send these over to another server for further processing and we
4838 * want a standard format to work with.
4839 */
4840 {
4841 Datum temp;
4842 NumericVar num;
4843
4844 init_var(&num);
4845
4846 #ifdef HAVE_INT128
4847 int128_to_numericvar(state->sumX, &num);
4848 #else
4849 accum_sum_final(&state->sumX, &num);
4850 #endif
4851 temp = DirectFunctionCall1(numeric_send,
4852 NumericGetDatum(make_result(&num)));
4853 sumX = DatumGetByteaPP(temp);
4854
4855 free_var(&num);
4856 }
4857
4858 pq_begintypsend(&buf);
4859
4860 /* N */
4861 pq_sendint64(&buf, state->N);
4862
4863 /* sumX */
4864 pq_sendbytes(&buf, VARDATA_ANY(sumX), VARSIZE_ANY_EXHDR(sumX));
4865
4866 result = pq_endtypsend(&buf);
4867
4868 PG_RETURN_BYTEA_P(result);
4869 }
4870
4871 /*
4872 * int8_avg_deserialize
4873 * Deserialize bytea back into PolyNumAggState.
4874 */
4875 Datum
int8_avg_deserialize(PG_FUNCTION_ARGS)4876 int8_avg_deserialize(PG_FUNCTION_ARGS)
4877 {
4878 bytea *sstate;
4879 PolyNumAggState *result;
4880 StringInfoData buf;
4881 Datum temp;
4882 NumericVar num;
4883
4884 if (!AggCheckCallContext(fcinfo, NULL))
4885 elog(ERROR, "aggregate function called in non-aggregate context");
4886
4887 sstate = PG_GETARG_BYTEA_PP(0);
4888
4889 /*
4890 * Copy the bytea into a StringInfo so that we can "receive" it using the
4891 * standard recv-function infrastructure.
4892 */
4893 initStringInfo(&buf);
4894 appendBinaryStringInfo(&buf,
4895 VARDATA_ANY(sstate), VARSIZE_ANY_EXHDR(sstate));
4896
4897 result = makePolyNumAggStateCurrentContext(false);
4898
4899 /* N */
4900 result->N = pq_getmsgint64(&buf);
4901
4902 /* sumX */
4903 temp = DirectFunctionCall3(numeric_recv,
4904 PointerGetDatum(&buf),
4905 ObjectIdGetDatum(InvalidOid),
4906 Int32GetDatum(-1));
4907 init_var_from_num(DatumGetNumeric(temp), &num);
4908 #ifdef HAVE_INT128
4909 numericvar_to_int128(&num, &result->sumX);
4910 #else
4911 accum_sum_add(&result->sumX, &num);
4912 #endif
4913
4914 pq_getmsgend(&buf);
4915 pfree(buf.data);
4916
4917 PG_RETURN_POINTER(result);
4918 }
4919
4920 /*
4921 * Inverse transition functions to go with the above.
4922 */
4923
4924 Datum
int2_accum_inv(PG_FUNCTION_ARGS)4925 int2_accum_inv(PG_FUNCTION_ARGS)
4926 {
4927 PolyNumAggState *state;
4928
4929 state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
4930
4931 /* Should not get here with no state */
4932 if (state == NULL)
4933 elog(ERROR, "int2_accum_inv called with NULL state");
4934
4935 if (!PG_ARGISNULL(1))
4936 {
4937 #ifdef HAVE_INT128
4938 do_int128_discard(state, (int128) PG_GETARG_INT16(1));
4939 #else
4940 Numeric newval;
4941
4942 newval = DatumGetNumeric(DirectFunctionCall1(int2_numeric,
4943 PG_GETARG_DATUM(1)));
4944
4945 /* Should never fail, all inputs have dscale 0 */
4946 if (!do_numeric_discard(state, newval))
4947 elog(ERROR, "do_numeric_discard failed unexpectedly");
4948 #endif
4949 }
4950
4951 PG_RETURN_POINTER(state);
4952 }
4953
4954 Datum
int4_accum_inv(PG_FUNCTION_ARGS)4955 int4_accum_inv(PG_FUNCTION_ARGS)
4956 {
4957 PolyNumAggState *state;
4958
4959 state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
4960
4961 /* Should not get here with no state */
4962 if (state == NULL)
4963 elog(ERROR, "int4_accum_inv called with NULL state");
4964
4965 if (!PG_ARGISNULL(1))
4966 {
4967 #ifdef HAVE_INT128
4968 do_int128_discard(state, (int128) PG_GETARG_INT32(1));
4969 #else
4970 Numeric newval;
4971
4972 newval = DatumGetNumeric(DirectFunctionCall1(int4_numeric,
4973 PG_GETARG_DATUM(1)));
4974
4975 /* Should never fail, all inputs have dscale 0 */
4976 if (!do_numeric_discard(state, newval))
4977 elog(ERROR, "do_numeric_discard failed unexpectedly");
4978 #endif
4979 }
4980
4981 PG_RETURN_POINTER(state);
4982 }
4983
4984 Datum
int8_accum_inv(PG_FUNCTION_ARGS)4985 int8_accum_inv(PG_FUNCTION_ARGS)
4986 {
4987 NumericAggState *state;
4988
4989 state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
4990
4991 /* Should not get here with no state */
4992 if (state == NULL)
4993 elog(ERROR, "int8_accum_inv called with NULL state");
4994
4995 if (!PG_ARGISNULL(1))
4996 {
4997 Numeric newval;
4998
4999 newval = DatumGetNumeric(DirectFunctionCall1(int8_numeric,
5000 PG_GETARG_DATUM(1)));
5001
5002 /* Should never fail, all inputs have dscale 0 */
5003 if (!do_numeric_discard(state, newval))
5004 elog(ERROR, "do_numeric_discard failed unexpectedly");
5005 }
5006
5007 PG_RETURN_POINTER(state);
5008 }
5009
5010 Datum
int8_avg_accum_inv(PG_FUNCTION_ARGS)5011 int8_avg_accum_inv(PG_FUNCTION_ARGS)
5012 {
5013 PolyNumAggState *state;
5014
5015 state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
5016
5017 /* Should not get here with no state */
5018 if (state == NULL)
5019 elog(ERROR, "int8_avg_accum_inv called with NULL state");
5020
5021 if (!PG_ARGISNULL(1))
5022 {
5023 #ifdef HAVE_INT128
5024 do_int128_discard(state, (int128) PG_GETARG_INT64(1));
5025 #else
5026 Numeric newval;
5027
5028 newval = DatumGetNumeric(DirectFunctionCall1(int8_numeric,
5029 PG_GETARG_DATUM(1)));
5030
5031 /* Should never fail, all inputs have dscale 0 */
5032 if (!do_numeric_discard(state, newval))
5033 elog(ERROR, "do_numeric_discard failed unexpectedly");
5034 #endif
5035 }
5036
5037 PG_RETURN_POINTER(state);
5038 }
5039
5040 Datum
numeric_poly_sum(PG_FUNCTION_ARGS)5041 numeric_poly_sum(PG_FUNCTION_ARGS)
5042 {
5043 #ifdef HAVE_INT128
5044 PolyNumAggState *state;
5045 Numeric res;
5046 NumericVar result;
5047
5048 state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
5049
5050 /* If there were no non-null inputs, return NULL */
5051 if (state == NULL || state->N == 0)
5052 PG_RETURN_NULL();
5053
5054 init_var(&result);
5055
5056 int128_to_numericvar(state->sumX, &result);
5057
5058 res = make_result(&result);
5059
5060 free_var(&result);
5061
5062 PG_RETURN_NUMERIC(res);
5063 #else
5064 return numeric_sum(fcinfo);
5065 #endif
5066 }
5067
5068 Datum
numeric_poly_avg(PG_FUNCTION_ARGS)5069 numeric_poly_avg(PG_FUNCTION_ARGS)
5070 {
5071 #ifdef HAVE_INT128
5072 PolyNumAggState *state;
5073 NumericVar result;
5074 Datum countd,
5075 sumd;
5076
5077 state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
5078
5079 /* If there were no non-null inputs, return NULL */
5080 if (state == NULL || state->N == 0)
5081 PG_RETURN_NULL();
5082
5083 init_var(&result);
5084
5085 int128_to_numericvar(state->sumX, &result);
5086
5087 countd = DirectFunctionCall1(int8_numeric,
5088 Int64GetDatumFast(state->N));
5089 sumd = NumericGetDatum(make_result(&result));
5090
5091 free_var(&result);
5092
5093 PG_RETURN_DATUM(DirectFunctionCall2(numeric_div, sumd, countd));
5094 #else
5095 return numeric_avg(fcinfo);
5096 #endif
5097 }
5098
5099 Datum
numeric_avg(PG_FUNCTION_ARGS)5100 numeric_avg(PG_FUNCTION_ARGS)
5101 {
5102 NumericAggState *state;
5103 Datum N_datum;
5104 Datum sumX_datum;
5105 NumericVar sumX_var;
5106
5107 state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
5108
5109 /* If there were no non-null inputs, return NULL */
5110 if (state == NULL || (state->N + state->NaNcount) == 0)
5111 PG_RETURN_NULL();
5112
5113 if (state->NaNcount > 0) /* there was at least one NaN input */
5114 PG_RETURN_NUMERIC(make_result(&const_nan));
5115
5116 N_datum = DirectFunctionCall1(int8_numeric, Int64GetDatum(state->N));
5117
5118 init_var(&sumX_var);
5119 accum_sum_final(&state->sumX, &sumX_var);
5120 sumX_datum = NumericGetDatum(make_result(&sumX_var));
5121 free_var(&sumX_var);
5122
5123 PG_RETURN_DATUM(DirectFunctionCall2(numeric_div, sumX_datum, N_datum));
5124 }
5125
5126 Datum
numeric_sum(PG_FUNCTION_ARGS)5127 numeric_sum(PG_FUNCTION_ARGS)
5128 {
5129 NumericAggState *state;
5130 NumericVar sumX_var;
5131 Numeric result;
5132
5133 state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
5134
5135 /* If there were no non-null inputs, return NULL */
5136 if (state == NULL || (state->N + state->NaNcount) == 0)
5137 PG_RETURN_NULL();
5138
5139 if (state->NaNcount > 0) /* there was at least one NaN input */
5140 PG_RETURN_NUMERIC(make_result(&const_nan));
5141
5142 init_var(&sumX_var);
5143 accum_sum_final(&state->sumX, &sumX_var);
5144 result = make_result(&sumX_var);
5145 free_var(&sumX_var);
5146
5147 PG_RETURN_NUMERIC(result);
5148 }
5149
5150 /*
5151 * Workhorse routine for the standard deviance and variance
5152 * aggregates. 'state' is aggregate's transition state.
5153 * 'variance' specifies whether we should calculate the
5154 * variance or the standard deviation. 'sample' indicates whether the
5155 * caller is interested in the sample or the population
5156 * variance/stddev.
5157 *
5158 * If appropriate variance statistic is undefined for the input,
5159 * *is_null is set to true and NULL is returned.
5160 */
5161 static Numeric
numeric_stddev_internal(NumericAggState * state,bool variance,bool sample,bool * is_null)5162 numeric_stddev_internal(NumericAggState *state,
5163 bool variance, bool sample,
5164 bool *is_null)
5165 {
5166 Numeric res;
5167 NumericVar vN,
5168 vsumX,
5169 vsumX2,
5170 vNminus1;
5171 const NumericVar *comp;
5172 int rscale;
5173
5174 /* Deal with empty input and NaN-input cases */
5175 if (state == NULL || (state->N + state->NaNcount) == 0)
5176 {
5177 *is_null = true;
5178 return NULL;
5179 }
5180
5181 *is_null = false;
5182
5183 if (state->NaNcount > 0)
5184 return make_result(&const_nan);
5185
5186 init_var(&vN);
5187 init_var(&vsumX);
5188 init_var(&vsumX2);
5189
5190 int64_to_numericvar(state->N, &vN);
5191 accum_sum_final(&(state->sumX), &vsumX);
5192 accum_sum_final(&(state->sumX2), &vsumX2);
5193
5194 /*
5195 * Sample stddev and variance are undefined when N <= 1; population stddev
5196 * is undefined when N == 0. Return NULL in either case.
5197 */
5198 if (sample)
5199 comp = &const_one;
5200 else
5201 comp = &const_zero;
5202
5203 if (cmp_var(&vN, comp) <= 0)
5204 {
5205 *is_null = true;
5206 return NULL;
5207 }
5208
5209 init_var(&vNminus1);
5210 sub_var(&vN, &const_one, &vNminus1);
5211
5212 /* compute rscale for mul_var calls */
5213 rscale = vsumX.dscale * 2;
5214
5215 mul_var(&vsumX, &vsumX, &vsumX, rscale); /* vsumX = sumX * sumX */
5216 mul_var(&vN, &vsumX2, &vsumX2, rscale); /* vsumX2 = N * sumX2 */
5217 sub_var(&vsumX2, &vsumX, &vsumX2); /* N * sumX2 - sumX * sumX */
5218
5219 if (cmp_var(&vsumX2, &const_zero) <= 0)
5220 {
5221 /* Watch out for roundoff error producing a negative numerator */
5222 res = make_result(&const_zero);
5223 }
5224 else
5225 {
5226 if (sample)
5227 mul_var(&vN, &vNminus1, &vNminus1, 0); /* N * (N - 1) */
5228 else
5229 mul_var(&vN, &vN, &vNminus1, 0); /* N * N */
5230 rscale = select_div_scale(&vsumX2, &vNminus1);
5231 div_var(&vsumX2, &vNminus1, &vsumX, rscale, true); /* variance */
5232 if (!variance)
5233 sqrt_var(&vsumX, &vsumX, rscale); /* stddev */
5234
5235 res = make_result(&vsumX);
5236 }
5237
5238 free_var(&vNminus1);
5239 free_var(&vsumX);
5240 free_var(&vsumX2);
5241
5242 return res;
5243 }
5244
5245 Datum
numeric_var_samp(PG_FUNCTION_ARGS)5246 numeric_var_samp(PG_FUNCTION_ARGS)
5247 {
5248 NumericAggState *state;
5249 Numeric res;
5250 bool is_null;
5251
5252 state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
5253
5254 res = numeric_stddev_internal(state, true, true, &is_null);
5255
5256 if (is_null)
5257 PG_RETURN_NULL();
5258 else
5259 PG_RETURN_NUMERIC(res);
5260 }
5261
5262 Datum
numeric_stddev_samp(PG_FUNCTION_ARGS)5263 numeric_stddev_samp(PG_FUNCTION_ARGS)
5264 {
5265 NumericAggState *state;
5266 Numeric res;
5267 bool is_null;
5268
5269 state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
5270
5271 res = numeric_stddev_internal(state, false, true, &is_null);
5272
5273 if (is_null)
5274 PG_RETURN_NULL();
5275 else
5276 PG_RETURN_NUMERIC(res);
5277 }
5278
5279 Datum
numeric_var_pop(PG_FUNCTION_ARGS)5280 numeric_var_pop(PG_FUNCTION_ARGS)
5281 {
5282 NumericAggState *state;
5283 Numeric res;
5284 bool is_null;
5285
5286 state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
5287
5288 res = numeric_stddev_internal(state, true, false, &is_null);
5289
5290 if (is_null)
5291 PG_RETURN_NULL();
5292 else
5293 PG_RETURN_NUMERIC(res);
5294 }
5295
5296 Datum
numeric_stddev_pop(PG_FUNCTION_ARGS)5297 numeric_stddev_pop(PG_FUNCTION_ARGS)
5298 {
5299 NumericAggState *state;
5300 Numeric res;
5301 bool is_null;
5302
5303 state = PG_ARGISNULL(0) ? NULL : (NumericAggState *) PG_GETARG_POINTER(0);
5304
5305 res = numeric_stddev_internal(state, false, false, &is_null);
5306
5307 if (is_null)
5308 PG_RETURN_NULL();
5309 else
5310 PG_RETURN_NUMERIC(res);
5311 }
5312
#ifdef HAVE_INT128
/*
 * numeric_poly_stddev_internal
 *		int128 front end for numeric_stddev_internal().
 *
 * Copies the count and the 128-bit sums into a temporary local
 * NumericAggState, delegates the computation, and then releases the
 * accumulator digit arrays that were allocated for the copy.  A NULL
 * 'state' is treated as an empty one.
 */
static Numeric
numeric_poly_stddev_internal(Int128AggState *state,
							 bool variance, bool sample,
							 bool *is_null)
{
	NumericAggState tmpstate;
	Numeric		res;

	/* Start from an all-zero state: N = 0 and empty accumulators */
	memset(&tmpstate, 0, sizeof(tmpstate));

	if (state != NULL)
	{
		NumericVar	v;

		tmpstate.N = state->N;

		init_var(&v);
		int128_to_numericvar(state->sumX, &v);
		accum_sum_add(&tmpstate.sumX, &v);
		int128_to_numericvar(state->sumX2, &v);
		accum_sum_add(&tmpstate.sumX2, &v);
		free_var(&v);
	}

	res = numeric_stddev_internal(&tmpstate, variance, sample, is_null);

	/* Free whatever digit arrays the accumulators acquired */
	if (tmpstate.sumX.ndigits > 0)
	{
		pfree(tmpstate.sumX.pos_digits);
		pfree(tmpstate.sumX.neg_digits);
	}
	if (tmpstate.sumX2.ndigits > 0)
	{
		pfree(tmpstate.sumX2.pos_digits);
		pfree(tmpstate.sumX2.neg_digits);
	}

	return res;
}
#endif
5358
5359 Datum
numeric_poly_var_samp(PG_FUNCTION_ARGS)5360 numeric_poly_var_samp(PG_FUNCTION_ARGS)
5361 {
5362 #ifdef HAVE_INT128
5363 PolyNumAggState *state;
5364 Numeric res;
5365 bool is_null;
5366
5367 state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
5368
5369 res = numeric_poly_stddev_internal(state, true, true, &is_null);
5370
5371 if (is_null)
5372 PG_RETURN_NULL();
5373 else
5374 PG_RETURN_NUMERIC(res);
5375 #else
5376 return numeric_var_samp(fcinfo);
5377 #endif
5378 }
5379
5380 Datum
numeric_poly_stddev_samp(PG_FUNCTION_ARGS)5381 numeric_poly_stddev_samp(PG_FUNCTION_ARGS)
5382 {
5383 #ifdef HAVE_INT128
5384 PolyNumAggState *state;
5385 Numeric res;
5386 bool is_null;
5387
5388 state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
5389
5390 res = numeric_poly_stddev_internal(state, false, true, &is_null);
5391
5392 if (is_null)
5393 PG_RETURN_NULL();
5394 else
5395 PG_RETURN_NUMERIC(res);
5396 #else
5397 return numeric_stddev_samp(fcinfo);
5398 #endif
5399 }
5400
5401 Datum
numeric_poly_var_pop(PG_FUNCTION_ARGS)5402 numeric_poly_var_pop(PG_FUNCTION_ARGS)
5403 {
5404 #ifdef HAVE_INT128
5405 PolyNumAggState *state;
5406 Numeric res;
5407 bool is_null;
5408
5409 state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
5410
5411 res = numeric_poly_stddev_internal(state, true, false, &is_null);
5412
5413 if (is_null)
5414 PG_RETURN_NULL();
5415 else
5416 PG_RETURN_NUMERIC(res);
5417 #else
5418 return numeric_var_pop(fcinfo);
5419 #endif
5420 }
5421
5422 Datum
numeric_poly_stddev_pop(PG_FUNCTION_ARGS)5423 numeric_poly_stddev_pop(PG_FUNCTION_ARGS)
5424 {
5425 #ifdef HAVE_INT128
5426 PolyNumAggState *state;
5427 Numeric res;
5428 bool is_null;
5429
5430 state = PG_ARGISNULL(0) ? NULL : (PolyNumAggState *) PG_GETARG_POINTER(0);
5431
5432 res = numeric_poly_stddev_internal(state, false, false, &is_null);
5433
5434 if (is_null)
5435 PG_RETURN_NULL();
5436 else
5437 PG_RETURN_NUMERIC(res);
5438 #else
5439 return numeric_stddev_pop(fcinfo);
5440 #endif
5441 }
5442
5443 /*
5444 * SUM transition functions for integer datatypes.
5445 *
5446 * To avoid overflow, we use accumulators wider than the input datatype.
5447 * A Numeric accumulator is needed for int8 input; for int4 and int2
5448 * inputs, we use int8 accumulators which should be sufficient for practical
5449 * purposes. (The latter two therefore don't really belong in this file,
5450 * but we keep them here anyway.)
5451 *
5452 * Because SQL defines the SUM() of no values to be NULL, not zero,
5453 * the initial condition of the transition data value needs to be NULL. This
5454 * means we can't rely on ExecAgg to automatically insert the first non-null
5455 * data value into the transition data: it doesn't know how to do the type
5456 * conversion. The upshot is that these routines have to be marked non-strict
5457 * and handle substitution of the first non-null input themselves.
5458 *
5459 * Note: these functions are used only in plain aggregation mode.
5460 * In moving-aggregate mode, we use intX_avg_accum and intX_avg_accum_inv.
5461 */
5462
5463 Datum
int2_sum(PG_FUNCTION_ARGS)5464 int2_sum(PG_FUNCTION_ARGS)
5465 {
5466 int64 newval;
5467
5468 if (PG_ARGISNULL(0))
5469 {
5470 /* No non-null input seen so far... */
5471 if (PG_ARGISNULL(1))
5472 PG_RETURN_NULL(); /* still no non-null */
5473 /* This is the first non-null input. */
5474 newval = (int64) PG_GETARG_INT16(1);
5475 PG_RETURN_INT64(newval);
5476 }
5477
5478 /*
5479 * If we're invoked as an aggregate, we can cheat and modify our first
5480 * parameter in-place to avoid palloc overhead. If not, we need to return
5481 * the new value of the transition variable. (If int8 is pass-by-value,
5482 * then of course this is useless as well as incorrect, so just ifdef it
5483 * out.)
5484 */
5485 #ifndef USE_FLOAT8_BYVAL /* controls int8 too */
5486 if (AggCheckCallContext(fcinfo, NULL))
5487 {
5488 int64 *oldsum = (int64 *) PG_GETARG_POINTER(0);
5489
5490 /* Leave the running sum unchanged in the new input is null */
5491 if (!PG_ARGISNULL(1))
5492 *oldsum = *oldsum + (int64) PG_GETARG_INT16(1);
5493
5494 PG_RETURN_POINTER(oldsum);
5495 }
5496 else
5497 #endif
5498 {
5499 int64 oldsum = PG_GETARG_INT64(0);
5500
5501 /* Leave sum unchanged if new input is null. */
5502 if (PG_ARGISNULL(1))
5503 PG_RETURN_INT64(oldsum);
5504
5505 /* OK to do the addition. */
5506 newval = oldsum + (int64) PG_GETARG_INT16(1);
5507
5508 PG_RETURN_INT64(newval);
5509 }
5510 }
5511
5512 Datum
int4_sum(PG_FUNCTION_ARGS)5513 int4_sum(PG_FUNCTION_ARGS)
5514 {
5515 int64 newval;
5516
5517 if (PG_ARGISNULL(0))
5518 {
5519 /* No non-null input seen so far... */
5520 if (PG_ARGISNULL(1))
5521 PG_RETURN_NULL(); /* still no non-null */
5522 /* This is the first non-null input. */
5523 newval = (int64) PG_GETARG_INT32(1);
5524 PG_RETURN_INT64(newval);
5525 }
5526
5527 /*
5528 * If we're invoked as an aggregate, we can cheat and modify our first
5529 * parameter in-place to avoid palloc overhead. If not, we need to return
5530 * the new value of the transition variable. (If int8 is pass-by-value,
5531 * then of course this is useless as well as incorrect, so just ifdef it
5532 * out.)
5533 */
5534 #ifndef USE_FLOAT8_BYVAL /* controls int8 too */
5535 if (AggCheckCallContext(fcinfo, NULL))
5536 {
5537 int64 *oldsum = (int64 *) PG_GETARG_POINTER(0);
5538
5539 /* Leave the running sum unchanged in the new input is null */
5540 if (!PG_ARGISNULL(1))
5541 *oldsum = *oldsum + (int64) PG_GETARG_INT32(1);
5542
5543 PG_RETURN_POINTER(oldsum);
5544 }
5545 else
5546 #endif
5547 {
5548 int64 oldsum = PG_GETARG_INT64(0);
5549
5550 /* Leave sum unchanged if new input is null. */
5551 if (PG_ARGISNULL(1))
5552 PG_RETURN_INT64(oldsum);
5553
5554 /* OK to do the addition. */
5555 newval = oldsum + (int64) PG_GETARG_INT32(1);
5556
5557 PG_RETURN_INT64(newval);
5558 }
5559 }
5560
5561 /*
5562 * Note: this function is obsolete, it's no longer used for SUM(int8).
5563 */
5564 Datum
int8_sum(PG_FUNCTION_ARGS)5565 int8_sum(PG_FUNCTION_ARGS)
5566 {
5567 Numeric oldsum;
5568 Datum newval;
5569
5570 if (PG_ARGISNULL(0))
5571 {
5572 /* No non-null input seen so far... */
5573 if (PG_ARGISNULL(1))
5574 PG_RETURN_NULL(); /* still no non-null */
5575 /* This is the first non-null input. */
5576 newval = DirectFunctionCall1(int8_numeric, PG_GETARG_DATUM(1));
5577 PG_RETURN_DATUM(newval);
5578 }
5579
5580 /*
5581 * Note that we cannot special-case the aggregate case here, as we do for
5582 * int2_sum and int4_sum: numeric is of variable size, so we cannot modify
5583 * our first parameter in-place.
5584 */
5585
5586 oldsum = PG_GETARG_NUMERIC(0);
5587
5588 /* Leave sum unchanged if new input is null. */
5589 if (PG_ARGISNULL(1))
5590 PG_RETURN_NUMERIC(oldsum);
5591
5592 /* OK to do the addition. */
5593 newval = DirectFunctionCall1(int8_numeric, PG_GETARG_DATUM(1));
5594
5595 PG_RETURN_DATUM(DirectFunctionCall2(numeric_add,
5596 NumericGetDatum(oldsum), newval));
5597 }
5598
5599
/*
 * Routines for avg(int2) and avg(int4).  The transition datatype
 * is a two-element int8 array, holding count and sum.
 *
 * These functions are also used for sum(int2) and sum(int4) when
 * operating in moving-aggregate mode, since for correct inverse transitions
 * we need to count the inputs.
 */

/*
 * Overlay for the data area of the two-element int8 transition array.
 * Every user below first checks that the array has no null bitmap and that
 * its total size is exactly ARR_OVERHEAD_NONULLS(1) plus the size of this
 * struct, then casts ARR_DATA_PTR() to this type.  Field order therefore
 * mirrors the array elements: element 0 is count, element 1 is sum.
 */
typedef struct Int8TransTypeData
{
	int64		count;			/* incremented by *_accum, decremented by
								 * *_accum_inv */
	int64		sum;			/* running sum of the accumulated inputs */
} Int8TransTypeData;
5614
5615 Datum
int2_avg_accum(PG_FUNCTION_ARGS)5616 int2_avg_accum(PG_FUNCTION_ARGS)
5617 {
5618 ArrayType *transarray;
5619 int16 newval = PG_GETARG_INT16(1);
5620 Int8TransTypeData *transdata;
5621
5622 /*
5623 * If we're invoked as an aggregate, we can cheat and modify our first
5624 * parameter in-place to reduce palloc overhead. Otherwise we need to make
5625 * a copy of it before scribbling on it.
5626 */
5627 if (AggCheckCallContext(fcinfo, NULL))
5628 transarray = PG_GETARG_ARRAYTYPE_P(0);
5629 else
5630 transarray = PG_GETARG_ARRAYTYPE_P_COPY(0);
5631
5632 if (ARR_HASNULL(transarray) ||
5633 ARR_SIZE(transarray) != ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData))
5634 elog(ERROR, "expected 2-element int8 array");
5635
5636 transdata = (Int8TransTypeData *) ARR_DATA_PTR(transarray);
5637 transdata->count++;
5638 transdata->sum += newval;
5639
5640 PG_RETURN_ARRAYTYPE_P(transarray);
5641 }
5642
5643 Datum
int4_avg_accum(PG_FUNCTION_ARGS)5644 int4_avg_accum(PG_FUNCTION_ARGS)
5645 {
5646 ArrayType *transarray;
5647 int32 newval = PG_GETARG_INT32(1);
5648 Int8TransTypeData *transdata;
5649
5650 /*
5651 * If we're invoked as an aggregate, we can cheat and modify our first
5652 * parameter in-place to reduce palloc overhead. Otherwise we need to make
5653 * a copy of it before scribbling on it.
5654 */
5655 if (AggCheckCallContext(fcinfo, NULL))
5656 transarray = PG_GETARG_ARRAYTYPE_P(0);
5657 else
5658 transarray = PG_GETARG_ARRAYTYPE_P_COPY(0);
5659
5660 if (ARR_HASNULL(transarray) ||
5661 ARR_SIZE(transarray) != ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData))
5662 elog(ERROR, "expected 2-element int8 array");
5663
5664 transdata = (Int8TransTypeData *) ARR_DATA_PTR(transarray);
5665 transdata->count++;
5666 transdata->sum += newval;
5667
5668 PG_RETURN_ARRAYTYPE_P(transarray);
5669 }
5670
5671 Datum
int4_avg_combine(PG_FUNCTION_ARGS)5672 int4_avg_combine(PG_FUNCTION_ARGS)
5673 {
5674 ArrayType *transarray1;
5675 ArrayType *transarray2;
5676 Int8TransTypeData *state1;
5677 Int8TransTypeData *state2;
5678
5679 if (!AggCheckCallContext(fcinfo, NULL))
5680 elog(ERROR, "aggregate function called in non-aggregate context");
5681
5682 transarray1 = PG_GETARG_ARRAYTYPE_P(0);
5683 transarray2 = PG_GETARG_ARRAYTYPE_P(1);
5684
5685 if (ARR_HASNULL(transarray1) ||
5686 ARR_SIZE(transarray1) != ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData))
5687 elog(ERROR, "expected 2-element int8 array");
5688
5689 if (ARR_HASNULL(transarray2) ||
5690 ARR_SIZE(transarray2) != ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData))
5691 elog(ERROR, "expected 2-element int8 array");
5692
5693 state1 = (Int8TransTypeData *) ARR_DATA_PTR(transarray1);
5694 state2 = (Int8TransTypeData *) ARR_DATA_PTR(transarray2);
5695
5696 state1->count += state2->count;
5697 state1->sum += state2->sum;
5698
5699 PG_RETURN_ARRAYTYPE_P(transarray1);
5700 }
5701
5702 Datum
int2_avg_accum_inv(PG_FUNCTION_ARGS)5703 int2_avg_accum_inv(PG_FUNCTION_ARGS)
5704 {
5705 ArrayType *transarray;
5706 int16 newval = PG_GETARG_INT16(1);
5707 Int8TransTypeData *transdata;
5708
5709 /*
5710 * If we're invoked as an aggregate, we can cheat and modify our first
5711 * parameter in-place to reduce palloc overhead. Otherwise we need to make
5712 * a copy of it before scribbling on it.
5713 */
5714 if (AggCheckCallContext(fcinfo, NULL))
5715 transarray = PG_GETARG_ARRAYTYPE_P(0);
5716 else
5717 transarray = PG_GETARG_ARRAYTYPE_P_COPY(0);
5718
5719 if (ARR_HASNULL(transarray) ||
5720 ARR_SIZE(transarray) != ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData))
5721 elog(ERROR, "expected 2-element int8 array");
5722
5723 transdata = (Int8TransTypeData *) ARR_DATA_PTR(transarray);
5724 transdata->count--;
5725 transdata->sum -= newval;
5726
5727 PG_RETURN_ARRAYTYPE_P(transarray);
5728 }
5729
5730 Datum
int4_avg_accum_inv(PG_FUNCTION_ARGS)5731 int4_avg_accum_inv(PG_FUNCTION_ARGS)
5732 {
5733 ArrayType *transarray;
5734 int32 newval = PG_GETARG_INT32(1);
5735 Int8TransTypeData *transdata;
5736
5737 /*
5738 * If we're invoked as an aggregate, we can cheat and modify our first
5739 * parameter in-place to reduce palloc overhead. Otherwise we need to make
5740 * a copy of it before scribbling on it.
5741 */
5742 if (AggCheckCallContext(fcinfo, NULL))
5743 transarray = PG_GETARG_ARRAYTYPE_P(0);
5744 else
5745 transarray = PG_GETARG_ARRAYTYPE_P_COPY(0);
5746
5747 if (ARR_HASNULL(transarray) ||
5748 ARR_SIZE(transarray) != ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData))
5749 elog(ERROR, "expected 2-element int8 array");
5750
5751 transdata = (Int8TransTypeData *) ARR_DATA_PTR(transarray);
5752 transdata->count--;
5753 transdata->sum -= newval;
5754
5755 PG_RETURN_ARRAYTYPE_P(transarray);
5756 }
5757
5758 Datum
int8_avg(PG_FUNCTION_ARGS)5759 int8_avg(PG_FUNCTION_ARGS)
5760 {
5761 ArrayType *transarray = PG_GETARG_ARRAYTYPE_P(0);
5762 Int8TransTypeData *transdata;
5763 Datum countd,
5764 sumd;
5765
5766 if (ARR_HASNULL(transarray) ||
5767 ARR_SIZE(transarray) != ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData))
5768 elog(ERROR, "expected 2-element int8 array");
5769 transdata = (Int8TransTypeData *) ARR_DATA_PTR(transarray);
5770
5771 /* SQL defines AVG of no values to be NULL */
5772 if (transdata->count == 0)
5773 PG_RETURN_NULL();
5774
5775 countd = DirectFunctionCall1(int8_numeric,
5776 Int64GetDatumFast(transdata->count));
5777 sumd = DirectFunctionCall1(int8_numeric,
5778 Int64GetDatumFast(transdata->sum));
5779
5780 PG_RETURN_DATUM(DirectFunctionCall2(numeric_div, sumd, countd));
5781 }
5782
5783 /*
5784 * SUM(int2) and SUM(int4) both return int8, so we can use this
5785 * final function for both.
5786 */
5787 Datum
int2int4_sum(PG_FUNCTION_ARGS)5788 int2int4_sum(PG_FUNCTION_ARGS)
5789 {
5790 ArrayType *transarray = PG_GETARG_ARRAYTYPE_P(0);
5791 Int8TransTypeData *transdata;
5792
5793 if (ARR_HASNULL(transarray) ||
5794 ARR_SIZE(transarray) != ARR_OVERHEAD_NONULLS(1) + sizeof(Int8TransTypeData))
5795 elog(ERROR, "expected 2-element int8 array");
5796 transdata = (Int8TransTypeData *) ARR_DATA_PTR(transarray);
5797
5798 /* SQL defines SUM of no values to be NULL */
5799 if (transdata->count == 0)
5800 PG_RETURN_NULL();
5801
5802 PG_RETURN_DATUM(Int64GetDatumFast(transdata->sum));
5803 }
5804
5805
5806 /* ----------------------------------------------------------------------
5807 *
5808 * Debug support
5809 *
5810 * ----------------------------------------------------------------------
5811 */
5812
#ifdef NUMERIC_DEBUG

/*
 * dump_numeric() -
 *
 *	Debug aid: print a packed (storage-format) Numeric to stdout as
 *	"<str>: NUMERIC w=<weight> d=<dscale> <sign>" followed by each NBASE
 *	digit zero-padded to DEC_DIGITS characters.
 */
static void
dump_numeric(const char *str, Numeric num)
{
	NumericDigit *dig = NUMERIC_DIGITS(num);
	int			nd = NUMERIC_NDIGITS(num);
	int			k;

	printf("%s: NUMERIC w=%d d=%d ", str,
		   NUMERIC_WEIGHT(num), NUMERIC_DSCALE(num));

	if (NUMERIC_SIGN(num) == NUMERIC_POS)
		printf("POS");
	else if (NUMERIC_SIGN(num) == NUMERIC_NEG)
		printf("NEG");
	else if (NUMERIC_SIGN(num) == NUMERIC_NAN)
		printf("NaN");
	else
		printf("SIGN=0x%x", NUMERIC_SIGN(num));

	for (k = 0; k < nd; k++)
		printf(" %0*d", DEC_DIGITS, dig[k]);
	printf("\n");
}


/*
 * dump_var() -
 *
 *	Debug aid: print a NumericVar (calculation format) to stdout in the
 *	same layout as dump_numeric(), tagged "VAR" instead of "NUMERIC".
 */
static void
dump_var(const char *str, NumericVar *var)
{
	int			k;

	printf("%s: VAR w=%d d=%d ", str, var->weight, var->dscale);

	if (var->sign == NUMERIC_POS)
		printf("POS");
	else if (var->sign == NUMERIC_NEG)
		printf("NEG");
	else if (var->sign == NUMERIC_NAN)
		printf("NaN");
	else
		printf("SIGN=0x%x", var->sign);

	for (k = 0; k < var->ndigits; k++)
		printf(" %0*d", DEC_DIGITS, var->digits[k]);

	printf("\n");
}
#endif							/* NUMERIC_DEBUG */
5882
5883
5884 /* ----------------------------------------------------------------------
5885 *
5886 * Local functions follow
5887 *
5888 * In general, these do not support NaNs --- callers must eliminate
5889 * the possibility of NaN first. (make_result() is an exception.)
5890 *
5891 * ----------------------------------------------------------------------
5892 */
5893
5894
5895 /*
5896 * alloc_var() -
5897 *
5898 * Allocate a digit buffer of ndigits digits (plus a spare digit for rounding)
5899 */
5900 static void
alloc_var(NumericVar * var,int ndigits)5901 alloc_var(NumericVar *var, int ndigits)
5902 {
5903 digitbuf_free(var->buf);
5904 var->buf = digitbuf_alloc(ndigits + 1);
5905 var->buf[0] = 0; /* spare digit for rounding */
5906 var->digits = var->buf + 1;
5907 var->ndigits = ndigits;
5908 }
5909
5910
5911 /*
5912 * free_var() -
5913 *
5914 * Return the digit buffer of a variable to the free pool
5915 */
5916 static void
free_var(NumericVar * var)5917 free_var(NumericVar *var)
5918 {
5919 digitbuf_free(var->buf);
5920 var->buf = NULL;
5921 var->digits = NULL;
5922 var->sign = NUMERIC_NAN;
5923 }
5924
5925
5926 /*
5927 * zero_var() -
5928 *
5929 * Set a variable to ZERO.
5930 * Note: its dscale is not touched.
5931 */
5932 static void
zero_var(NumericVar * var)5933 zero_var(NumericVar *var)
5934 {
5935 digitbuf_free(var->buf);
5936 var->buf = NULL;
5937 var->digits = NULL;
5938 var->ndigits = 0;
5939 var->weight = 0; /* by convention; doesn't really matter */
5940 var->sign = NUMERIC_POS; /* anything but NAN... */
5941 }
5942
5943
/*
 * set_var_from_str()
 *
 *	Parse a string and put the number into a variable
 *
 * This function does not handle leading or trailing spaces, and it doesn't
 * accept "NaN" either.  It returns the end+1 position so that caller can
 * check for trailing spaces/garbage if deemed necessary.
 *
 * cp is the place to actually start parsing; str is what to use in error
 * reports.  (Typically cp would be the same except advanced over spaces.)
 *
 * Syntax accepted by the code below: an optional '+' or '-'; then decimal
 * digits containing at most one '.', where a digit must come right after
 * the sign and/or a leading '.'; then an optional exponent introduced by
 * 'e' or 'E' and parsed with strtol().  Digit scanning stops at the first
 * character that is neither a digit nor '.', and that position (after any
 * exponent) is what gets returned.
 *
 * Errors: ERRCODE_INVALID_TEXT_REPRESENTATION for malformed input,
 * ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE if the exponent's magnitude is
 * INT_MAX/2 or more.  On success dest is stripped via strip_var(), and its
 * dscale is the count of digits written after '.', reduced by the exponent
 * and clamped at zero.
 */
static const char *
set_var_from_str(const char *str, const char *cp, NumericVar *dest)
{
	bool		have_dp = false;
	int			i;
	unsigned char *decdigits;
	int			sign = NUMERIC_POS;
	int			dweight = -1;	/* decimal weight: incremented once per digit
								 * before the decimal point */
	int			ddigits;
	int			dscale = 0;		/* incremented once per digit after the point */
	int			weight;
	int			ndigits;
	int			offset;
	NumericDigit *digits;

	/*
	 * We first parse the string to extract decimal digits and determine the
	 * correct decimal weight.  Then convert to NBASE representation.
	 */
	switch (*cp)
	{
		case '+':
			sign = NUMERIC_POS;
			cp++;
			break;

		case '-':
			sign = NUMERIC_NEG;
			cp++;
			break;
	}

	if (*cp == '.')
	{
		have_dp = true;
		cp++;
	}

	/* At least one digit is required after the sign and/or leading '.' */
	if (!isdigit((unsigned char) *cp))
		ereport(ERROR,
				(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
				 errmsg("invalid input syntax for type %s: \"%s\"",
						"numeric", str)));

	/*
	 * Room for DEC_DIGITS leading pad zeroes, at most strlen(cp) digits, and
	 * DEC_DIGITS-1 trailing pad zeroes.
	 */
	decdigits = (unsigned char *) palloc(strlen(cp) + DEC_DIGITS * 2);

	/* leading padding for digit alignment later */
	memset(decdigits, 0, DEC_DIGITS);
	i = DEC_DIGITS;

	/* Collect digit values (0-9, not characters) into decdigits[] */
	while (*cp)
	{
		if (isdigit((unsigned char) *cp))
		{
			decdigits[i++] = *cp++ - '0';
			if (!have_dp)
				dweight++;
			else
				dscale++;
		}
		else if (*cp == '.')
		{
			/* A second decimal point is a syntax error */
			if (have_dp)
				ereport(ERROR,
						(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
						 errmsg("invalid input syntax for type %s: \"%s\"",
								"numeric", str)));
			have_dp = true;
			cp++;
		}
		else
			break;
	}

	ddigits = i - DEC_DIGITS;	/* number of decimal digits collected */
	/* trailing padding for digit alignment later */
	memset(decdigits + i, 0, DEC_DIGITS - 1);

	/* Handle exponent, if any */
	if (*cp == 'e' || *cp == 'E')
	{
		long		exponent;
		char	   *endptr;

		cp++;
		exponent = strtol(cp, &endptr, 10);
		if (endptr == cp)
			ereport(ERROR,
					(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
					 errmsg("invalid input syntax for type %s: \"%s\"",
							"numeric", str)));
		cp = endptr;

		/*
		 * At this point, dweight and dscale can't be more than about
		 * INT_MAX/2 due to the MaxAllocSize limit on string length, so
		 * constraining the exponent similarly should be enough to prevent
		 * integer overflow in this function.  If the value is too large to
		 * fit in storage format, make_result() will complain about it later;
		 * for consistency use the same ereport errcode/text as make_result().
		 */
		if (exponent >= INT_MAX / 2 || exponent <= -(INT_MAX / 2))
			ereport(ERROR,
					(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
					 errmsg("value overflows numeric format")));
		dweight += (int) exponent;
		dscale -= (int) exponent;
		if (dscale < 0)
			dscale = 0;
	}

	/*
	 * Okay, convert pure-decimal representation to base NBASE.  First we need
	 * to determine the converted weight and ndigits.  offset is the number of
	 * decimal zeroes to insert before the first given digit to have a
	 * correctly aligned first NBASE digit.
	 */
	if (dweight >= 0)
		weight = (dweight + 1 + DEC_DIGITS - 1) / DEC_DIGITS - 1;
	else
		weight = -((-dweight - 1) / DEC_DIGITS + 1);
	offset = (weight + 1) * DEC_DIGITS - (dweight + 1);
	ndigits = (ddigits + offset + DEC_DIGITS - 1) / DEC_DIGITS;

	alloc_var(dest, ndigits);
	dest->sign = sign;
	dest->weight = weight;
	dest->dscale = dscale;

	/* Start offset positions into the leading pad so the groups align */
	i = DEC_DIGITS - offset;
	digits = dest->digits;

	/* Pack each group of DEC_DIGITS decimal digits into one NBASE digit */
	while (ndigits-- > 0)
	{
#if DEC_DIGITS == 4
		*digits++ = ((decdigits[i] * 10 + decdigits[i + 1]) * 10 +
					 decdigits[i + 2]) * 10 + decdigits[i + 3];
#elif DEC_DIGITS == 2
		*digits++ = decdigits[i] * 10 + decdigits[i + 1];
#elif DEC_DIGITS == 1
		*digits++ = decdigits[i];
#else
#error unsupported NBASE
#endif
		i += DEC_DIGITS;
	}

	pfree(decdigits);

	/* Strip any leading/trailing zeroes, and normalize weight if zero */
	strip_var(dest);

	/* Return end+1 position for caller */
	return cp;
}
6111
6112
6113 /*
6114 * set_var_from_num() -
6115 *
6116 * Convert the packed db format into a variable
6117 */
6118 static void
set_var_from_num(Numeric num,NumericVar * dest)6119 set_var_from_num(Numeric num, NumericVar *dest)
6120 {
6121 int ndigits;
6122
6123 ndigits = NUMERIC_NDIGITS(num);
6124
6125 alloc_var(dest, ndigits);
6126
6127 dest->weight = NUMERIC_WEIGHT(num);
6128 dest->sign = NUMERIC_SIGN(num);
6129 dest->dscale = NUMERIC_DSCALE(num);
6130
6131 memcpy(dest->digits, NUMERIC_DIGITS(num), ndigits * sizeof(NumericDigit));
6132 }
6133
6134
6135 /*
6136 * init_var_from_num() -
6137 *
6138 * Initialize a variable from packed db format. The digits array is not
6139 * copied, which saves some cycles when the resulting var is not modified.
6140 * Also, there's no need to call free_var(), as long as you don't assign any
6141 * other value to it (with set_var_* functions, or by using the var as the
6142 * destination of a function like add_var())
6143 *
6144 * CAUTION: Do not modify the digits buffer of a var initialized with this
6145 * function, e.g by calling round_var() or trunc_var(), as the changes will
6146 * propagate to the original Numeric! It's OK to use it as the destination
6147 * argument of one of the calculational functions, though.
6148 */
6149 static void
init_var_from_num(Numeric num,NumericVar * dest)6150 init_var_from_num(Numeric num, NumericVar *dest)
6151 {
6152 dest->ndigits = NUMERIC_NDIGITS(num);
6153 dest->weight = NUMERIC_WEIGHT(num);
6154 dest->sign = NUMERIC_SIGN(num);
6155 dest->dscale = NUMERIC_DSCALE(num);
6156 dest->digits = NUMERIC_DIGITS(num);
6157 dest->buf = NULL; /* digits array is not palloc'd */
6158 }
6159
6160
6161 /*
6162 * set_var_from_var() -
6163 *
6164 * Copy one variable into another
6165 */
6166 static void
set_var_from_var(const NumericVar * value,NumericVar * dest)6167 set_var_from_var(const NumericVar *value, NumericVar *dest)
6168 {
6169 NumericDigit *newbuf;
6170
6171 newbuf = digitbuf_alloc(value->ndigits + 1);
6172 newbuf[0] = 0; /* spare digit for rounding */
6173 if (value->ndigits > 0) /* else value->digits might be null */
6174 memcpy(newbuf + 1, value->digits,
6175 value->ndigits * sizeof(NumericDigit));
6176
6177 digitbuf_free(dest->buf);
6178
6179 memmove(dest, value, sizeof(NumericVar));
6180 dest->buf = newbuf;
6181 dest->digits = newbuf + 1;
6182 }
6183
6184
/*
 * get_str_from_var() -
 *
 *	Convert a var to text representation (guts of numeric_out).
 *	The var is displayed to the number of digits indicated by its dscale.
 *	Returns a palloc'd string.
 *
 *	Layout: a '-' when var->sign is NUMERIC_NEG (NaN is not special-cased
 *	here); then the integer part, with leading decimal zeroes of the first
 *	NBASE digit suppressed, or a single "0" when weight < 0; then, only if
 *	dscale > 0, a '.' followed by exactly dscale fractional digits.  Digits
 *	beyond the var's stored digits are printed as zeroes.
 */
static char *
get_str_from_var(const NumericVar *var)
{
	int			dscale;
	char	   *str;
	char	   *cp;
	char	   *endcp;
	int			i;
	int			d;				/* index into var->digits; carried from the
								 * integer-part loop into the fraction loop */
	NumericDigit dig;

#if DEC_DIGITS > 1
	NumericDigit d1;
#endif

	dscale = var->dscale;

	/*
	 * Allocate space for the result.
	 *
	 * i is set to the # of decimal digits before decimal point. dscale is the
	 * # of decimal digits we will print after decimal point. We may generate
	 * as many as DEC_DIGITS-1 excess digits at the end, and in addition we
	 * need room for sign, decimal point, null terminator.
	 */
	i = (var->weight + 1) * DEC_DIGITS;
	if (i <= 0)
		i = 1;

	str = palloc(i + dscale + DEC_DIGITS + 2);
	cp = str;

	/*
	 * Output a dash for negative values
	 */
	if (var->sign == NUMERIC_NEG)
		*cp++ = '-';

	/*
	 * Output all digits before the decimal point
	 */
	if (var->weight < 0)
	{
		/* d becomes <= 0: the fraction loop emits zeroes until d reaches 0 */
		d = var->weight + 1;
		*cp++ = '0';
	}
	else
	{
		for (d = 0; d <= var->weight; d++)
		{
			/* Positions past the stored digits are implicit zeroes */
			dig = (d < var->ndigits) ? var->digits[d] : 0;
			/* In the first digit, suppress extra leading decimal zeroes */
#if DEC_DIGITS == 4
			{
				bool		putit = (d > 0);

				d1 = dig / 1000;
				dig -= d1 * 1000;
				putit |= (d1 > 0);
				if (putit)
					*cp++ = d1 + '0';
				d1 = dig / 100;
				dig -= d1 * 100;
				putit |= (d1 > 0);
				if (putit)
					*cp++ = d1 + '0';
				d1 = dig / 10;
				dig -= d1 * 10;
				putit |= (d1 > 0);
				if (putit)
					*cp++ = d1 + '0';
				/* The units position is always printed */
				*cp++ = dig + '0';
			}
#elif DEC_DIGITS == 2
			d1 = dig / 10;
			dig -= d1 * 10;
			if (d1 > 0 || d > 0)
				*cp++ = d1 + '0';
			*cp++ = dig + '0';
#elif DEC_DIGITS == 1
			*cp++ = dig + '0';
#else
#error unsupported NBASE
#endif
		}
	}

	/*
	 * If requested, output a decimal point and all the digits that follow it.
	 * We initially put out a multiple of DEC_DIGITS digits, then truncate if
	 * needed.
	 */
	if (dscale > 0)
	{
		*cp++ = '.';
		endcp = cp + dscale;	/* where the string must end */
		for (i = 0; i < dscale; d++, i += DEC_DIGITS)
		{
			/* d < 0 covers leading fractional zeroes when weight < -1 */
			dig = (d >= 0 && d < var->ndigits) ? var->digits[d] : 0;
#if DEC_DIGITS == 4
			d1 = dig / 1000;
			dig -= d1 * 1000;
			*cp++ = d1 + '0';
			d1 = dig / 100;
			dig -= d1 * 100;
			*cp++ = d1 + '0';
			d1 = dig / 10;
			dig -= d1 * 10;
			*cp++ = d1 + '0';
			*cp++ = dig + '0';
#elif DEC_DIGITS == 2
			d1 = dig / 10;
			dig -= d1 * 10;
			*cp++ = d1 + '0';
			*cp++ = dig + '0';
#elif DEC_DIGITS == 1
			*cp++ = dig + '0';
#else
#error unsupported NBASE
#endif
		}
		/* Drop any excess digits emitted by the last DEC_DIGITS-wide group */
		cp = endcp;
	}

	/*
	 * terminate the string and return it
	 */
	*cp = '\0';
	return str;
}
6322
6323 /*
6324 * get_str_from_var_sci() -
6325 *
6326 * Convert a var to a normalised scientific notation text representation.
6327 * This function does the heavy lifting for numeric_out_sci().
6328 *
6329 * This notation has the general form a * 10^b, where a is known as the
6330 * "significand" and b is known as the "exponent".
6331 *
6332 * Because we can't do superscript in ASCII (and because we want to copy
6333 * printf's behaviour) we display the exponent using E notation, with a
6334 * minimum of two exponent digits.
6335 *
6336 * For example, the value 1234 could be output as 1.2e+03.
6337 *
6338 * We assume that the exponent can fit into an int32.
6339 *
6340 * rscale is the number of decimal digits desired after the decimal point in
6341 * the output, negative values will be treated as meaning zero.
6342 *
6343 * Returns a palloc'd string.
6344 */
6345 static char *
get_str_from_var_sci(const NumericVar * var,int rscale)6346 get_str_from_var_sci(const NumericVar *var, int rscale)
6347 {
6348 int32 exponent;
6349 NumericVar tmp_var;
6350 size_t len;
6351 char *str;
6352 char *sig_out;
6353
6354 if (rscale < 0)
6355 rscale = 0;
6356
6357 /*
6358 * Determine the exponent of this number in normalised form.
6359 *
6360 * This is the exponent required to represent the number with only one
6361 * significant digit before the decimal place.
6362 */
6363 if (var->ndigits > 0)
6364 {
6365 exponent = (var->weight + 1) * DEC_DIGITS;
6366
6367 /*
6368 * Compensate for leading decimal zeroes in the first numeric digit by
6369 * decrementing the exponent.
6370 */
6371 exponent -= DEC_DIGITS - (int) log10(var->digits[0]);
6372 }
6373 else
6374 {
6375 /*
6376 * If var has no digits, then it must be zero.
6377 *
6378 * Zero doesn't technically have a meaningful exponent in normalised
6379 * notation, but we just display the exponent as zero for consistency
6380 * of output.
6381 */
6382 exponent = 0;
6383 }
6384
6385 /*
6386 * Divide var by 10^exponent to get the significand, rounding to rscale
6387 * decimal digits in the process.
6388 */
6389 init_var(&tmp_var);
6390
6391 power_ten_int(exponent, &tmp_var);
6392 div_var(var, &tmp_var, &tmp_var, rscale, true);
6393 sig_out = get_str_from_var(&tmp_var);
6394
6395 free_var(&tmp_var);
6396
6397 /*
6398 * Allocate space for the result.
6399 *
6400 * In addition to the significand, we need room for the exponent
6401 * decoration ("e"), the sign of the exponent, up to 10 digits for the
6402 * exponent itself, and of course the null terminator.
6403 */
6404 len = strlen(sig_out) + 13;
6405 str = palloc(len);
6406 snprintf(str, len, "%se%+03d", sig_out, exponent);
6407
6408 pfree(sig_out);
6409
6410 return str;
6411 }
6412
6413
6414 /*
6415 * make_result_opt_error() -
6416 *
6417 * Create the packed db numeric format in palloc()'d memory from
6418 * a variable. If "*have_error" flag is provided, on error it's set to
6419 * true, NULL returned. This is helpful when caller need to handle errors
6420 * by itself.
6421 */
6422 static Numeric
make_result_opt_error(const NumericVar * var,bool * have_error)6423 make_result_opt_error(const NumericVar *var, bool *have_error)
6424 {
6425 Numeric result;
6426 NumericDigit *digits = var->digits;
6427 int weight = var->weight;
6428 int sign = var->sign;
6429 int n;
6430 Size len;
6431
6432 if (have_error)
6433 *have_error = false;
6434
6435 if (sign == NUMERIC_NAN)
6436 {
6437 result = (Numeric) palloc(NUMERIC_HDRSZ_SHORT);
6438
6439 SET_VARSIZE(result, NUMERIC_HDRSZ_SHORT);
6440 result->choice.n_header = NUMERIC_NAN;
6441 /* the header word is all we need */
6442
6443 dump_numeric("make_result()", result);
6444 return result;
6445 }
6446
6447 n = var->ndigits;
6448
6449 /* truncate leading zeroes */
6450 while (n > 0 && *digits == 0)
6451 {
6452 digits++;
6453 weight--;
6454 n--;
6455 }
6456 /* truncate trailing zeroes */
6457 while (n > 0 && digits[n - 1] == 0)
6458 n--;
6459
6460 /* If zero result, force to weight=0 and positive sign */
6461 if (n == 0)
6462 {
6463 weight = 0;
6464 sign = NUMERIC_POS;
6465 }
6466
6467 /* Build the result */
6468 if (NUMERIC_CAN_BE_SHORT(var->dscale, weight))
6469 {
6470 len = NUMERIC_HDRSZ_SHORT + n * sizeof(NumericDigit);
6471 result = (Numeric) palloc(len);
6472 SET_VARSIZE(result, len);
6473 result->choice.n_short.n_header =
6474 (sign == NUMERIC_NEG ? (NUMERIC_SHORT | NUMERIC_SHORT_SIGN_MASK)
6475 : NUMERIC_SHORT)
6476 | (var->dscale << NUMERIC_SHORT_DSCALE_SHIFT)
6477 | (weight < 0 ? NUMERIC_SHORT_WEIGHT_SIGN_MASK : 0)
6478 | (weight & NUMERIC_SHORT_WEIGHT_MASK);
6479 }
6480 else
6481 {
6482 len = NUMERIC_HDRSZ + n * sizeof(NumericDigit);
6483 result = (Numeric) palloc(len);
6484 SET_VARSIZE(result, len);
6485 result->choice.n_long.n_sign_dscale =
6486 sign | (var->dscale & NUMERIC_DSCALE_MASK);
6487 result->choice.n_long.n_weight = weight;
6488 }
6489
6490 Assert(NUMERIC_NDIGITS(result) == n);
6491 if (n > 0)
6492 memcpy(NUMERIC_DIGITS(result), digits, n * sizeof(NumericDigit));
6493
6494 /* Check for overflow of int16 fields */
6495 if (NUMERIC_WEIGHT(result) != weight ||
6496 NUMERIC_DSCALE(result) != var->dscale)
6497 {
6498 if (have_error)
6499 {
6500 *have_error = true;
6501 return NULL;
6502 }
6503 else
6504 {
6505 ereport(ERROR,
6506 (errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
6507 errmsg("value overflows numeric format")));
6508 }
6509 }
6510
6511 dump_numeric("make_result()", result);
6512 return result;
6513 }
6514
6515
6516 /*
6517 * make_result() -
6518 *
6519 * An interface to make_result_opt_error() without "have_error" argument.
6520 */
6521 static Numeric
make_result(const NumericVar * var)6522 make_result(const NumericVar *var)
6523 {
6524 return make_result_opt_error(var, NULL);
6525 }
6526
6527
/*
 * apply_typmod() -
 *
 *	Do bounds checking and rounding according to the attributes
 *	typmod field.
 *
 *	A typmod below VARHDRSZ (e.g. the default -1) means "no constraint" and
 *	leaves var untouched.  Otherwise typmod - VARHDRSZ carries the precision
 *	in its high 16 bits and the scale in its low 16 bits.  var is rounded in
 *	place to that scale.  If it then has more than (precision - scale)
 *	significant decimal digits before the decimal point, a "numeric field
 *	overflow" ERROR is raised.
 */
static void
apply_typmod(NumericVar *var, int32 typmod)
{
	int			precision;
	int			scale;
	int			maxdigits;		/* allowed decimal digits before the point */
	int			ddigits;
	int			i;

	/* Do nothing if we have a default typmod (-1) */
	if (typmod < (int32) (VARHDRSZ))
		return;

	typmod -= VARHDRSZ;
	precision = (typmod >> 16) & 0xffff;
	scale = typmod & 0xffff;
	maxdigits = precision - scale;

	/* Round to target scale (and set var->dscale) */
	round_var(var, scale);

	/*
	 * Check for overflow - note we can't do this before rounding, because
	 * rounding could raise the weight.  Also note that the var's weight could
	 * be inflated by leading zeroes, which will be stripped before storage
	 * but perhaps might not have been yet. In any case, we must recognize a
	 * true zero, whose weight doesn't mean anything.
	 */
	ddigits = (var->weight + 1) * DEC_DIGITS;	/* upper bound on integer
												 * decimal digits */
	if (ddigits > maxdigits)
	{
		/* Determine true weight; and check for all-zero result */
		for (i = 0; i < var->ndigits; i++)
		{
			NumericDigit dig = var->digits[i];

			if (dig)
			{
				/* Adjust for any high-order decimal zero digits */
#if DEC_DIGITS == 4
				if (dig < 10)
					ddigits -= 3;
				else if (dig < 100)
					ddigits -= 2;
				else if (dig < 1000)
					ddigits -= 1;
#elif DEC_DIGITS == 2
				if (dig < 10)
					ddigits -= 1;
#elif DEC_DIGITS == 1
				/* no adjustment */
#else
#error unsupported NBASE
#endif
				if (ddigits > maxdigits)
					ereport(ERROR,
							(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
							 errmsg("numeric field overflow"),
							 errdetail("A field with precision %d, scale %d must round to an absolute value less than %s%d.",
									   precision, scale,
					/* Display 10^0 as 1 */
									   maxdigits ? "10^" : "",
									   maxdigits ? maxdigits : 1
									   )));
				/* First nonzero digit found and it fits: no overflow */
				break;
			}
			/* A leading zero NBASE digit: the true weight is one lower */
			ddigits -= DEC_DIGITS;
		}
	}
}
6604
6605 /*
6606 * Convert numeric to int8, rounding if needed.
6607 *
6608 * If overflow, return false (no error is raised). Return true if okay.
6609 */
6610 static bool
numericvar_to_int64(const NumericVar * var,int64 * result)6611 numericvar_to_int64(const NumericVar *var, int64 *result)
6612 {
6613 NumericDigit *digits;
6614 int ndigits;
6615 int weight;
6616 int i;
6617 int64 val;
6618 bool neg;
6619 NumericVar rounded;
6620
6621 /* Round to nearest integer */
6622 init_var(&rounded);
6623 set_var_from_var(var, &rounded);
6624 round_var(&rounded, 0);
6625
6626 /* Check for zero input */
6627 strip_var(&rounded);
6628 ndigits = rounded.ndigits;
6629 if (ndigits == 0)
6630 {
6631 *result = 0;
6632 free_var(&rounded);
6633 return true;
6634 }
6635
6636 /*
6637 * For input like 10000000000, we must treat stripped digits as real. So
6638 * the loop assumes there are weight+1 digits before the decimal point.
6639 */
6640 weight = rounded.weight;
6641 Assert(weight >= 0 && ndigits <= weight + 1);
6642
6643 /*
6644 * Construct the result. To avoid issues with converting a value
6645 * corresponding to INT64_MIN (which can't be represented as a positive 64
6646 * bit two's complement integer), accumulate value as a negative number.
6647 */
6648 digits = rounded.digits;
6649 neg = (rounded.sign == NUMERIC_NEG);
6650 val = -digits[0];
6651 for (i = 1; i <= weight; i++)
6652 {
6653 if (unlikely(pg_mul_s64_overflow(val, NBASE, &val)))
6654 {
6655 free_var(&rounded);
6656 return false;
6657 }
6658
6659 if (i < ndigits)
6660 {
6661 if (unlikely(pg_sub_s64_overflow(val, digits[i], &val)))
6662 {
6663 free_var(&rounded);
6664 return false;
6665 }
6666 }
6667 }
6668
6669 free_var(&rounded);
6670
6671 if (!neg)
6672 {
6673 if (unlikely(val == PG_INT64_MIN))
6674 return false;
6675 val = -val;
6676 }
6677 *result = val;
6678
6679 return true;
6680 }
6681
6682 /*
6683 * Convert int8 value to numeric.
6684 */
6685 static void
int64_to_numericvar(int64 val,NumericVar * var)6686 int64_to_numericvar(int64 val, NumericVar *var)
6687 {
6688 uint64 uval,
6689 newuval;
6690 NumericDigit *ptr;
6691 int ndigits;
6692
6693 /* int64 can require at most 19 decimal digits; add one for safety */
6694 alloc_var(var, 20 / DEC_DIGITS);
6695 if (val < 0)
6696 {
6697 var->sign = NUMERIC_NEG;
6698 uval = -val;
6699 }
6700 else
6701 {
6702 var->sign = NUMERIC_POS;
6703 uval = val;
6704 }
6705 var->dscale = 0;
6706 if (val == 0)
6707 {
6708 var->ndigits = 0;
6709 var->weight = 0;
6710 return;
6711 }
6712 ptr = var->digits + var->ndigits;
6713 ndigits = 0;
6714 do
6715 {
6716 ptr--;
6717 ndigits++;
6718 newuval = uval / NBASE;
6719 *ptr = uval - newuval * NBASE;
6720 uval = newuval;
6721 } while (uval);
6722 var->digits = ptr;
6723 var->ndigits = ndigits;
6724 var->weight = ndigits - 1;
6725 }
6726
6727 #ifdef HAVE_INT128
6728 /*
6729 * Convert numeric to int128, rounding if needed.
6730 *
6731 * If overflow, return false (no error is raised). Return true if okay.
6732 */
6733 static bool
numericvar_to_int128(const NumericVar * var,int128 * result)6734 numericvar_to_int128(const NumericVar *var, int128 *result)
6735 {
6736 NumericDigit *digits;
6737 int ndigits;
6738 int weight;
6739 int i;
6740 int128 val,
6741 oldval;
6742 bool neg;
6743 NumericVar rounded;
6744
6745 /* Round to nearest integer */
6746 init_var(&rounded);
6747 set_var_from_var(var, &rounded);
6748 round_var(&rounded, 0);
6749
6750 /* Check for zero input */
6751 strip_var(&rounded);
6752 ndigits = rounded.ndigits;
6753 if (ndigits == 0)
6754 {
6755 *result = 0;
6756 free_var(&rounded);
6757 return true;
6758 }
6759
6760 /*
6761 * For input like 10000000000, we must treat stripped digits as real. So
6762 * the loop assumes there are weight+1 digits before the decimal point.
6763 */
6764 weight = rounded.weight;
6765 Assert(weight >= 0 && ndigits <= weight + 1);
6766
6767 /* Construct the result */
6768 digits = rounded.digits;
6769 neg = (rounded.sign == NUMERIC_NEG);
6770 val = digits[0];
6771 for (i = 1; i <= weight; i++)
6772 {
6773 oldval = val;
6774 val *= NBASE;
6775 if (i < ndigits)
6776 val += digits[i];
6777
6778 /*
6779 * The overflow check is a bit tricky because we want to accept
6780 * INT128_MIN, which will overflow the positive accumulator. We can
6781 * detect this case easily though because INT128_MIN is the only
6782 * nonzero value for which -val == val (on a two's complement machine,
6783 * anyway).
6784 */
6785 if ((val / NBASE) != oldval) /* possible overflow? */
6786 {
6787 if (!neg || (-val) != val || val == 0 || oldval < 0)
6788 {
6789 free_var(&rounded);
6790 return false;
6791 }
6792 }
6793 }
6794
6795 free_var(&rounded);
6796
6797 *result = neg ? -val : val;
6798 return true;
6799 }
6800
6801 /*
6802 * Convert 128 bit integer to numeric.
6803 */
6804 static void
int128_to_numericvar(int128 val,NumericVar * var)6805 int128_to_numericvar(int128 val, NumericVar *var)
6806 {
6807 uint128 uval,
6808 newuval;
6809 NumericDigit *ptr;
6810 int ndigits;
6811
6812 /* int128 can require at most 39 decimal digits; add one for safety */
6813 alloc_var(var, 40 / DEC_DIGITS);
6814 if (val < 0)
6815 {
6816 var->sign = NUMERIC_NEG;
6817 uval = -val;
6818 }
6819 else
6820 {
6821 var->sign = NUMERIC_POS;
6822 uval = val;
6823 }
6824 var->dscale = 0;
6825 if (val == 0)
6826 {
6827 var->ndigits = 0;
6828 var->weight = 0;
6829 return;
6830 }
6831 ptr = var->digits + var->ndigits;
6832 ndigits = 0;
6833 do
6834 {
6835 ptr--;
6836 ndigits++;
6837 newuval = uval / NBASE;
6838 *ptr = uval - newuval * NBASE;
6839 uval = newuval;
6840 } while (uval);
6841 var->digits = ptr;
6842 var->ndigits = ndigits;
6843 var->weight = ndigits - 1;
6844 }
6845 #endif
6846
6847 /*
6848 * Convert numeric to float8; if out of range, return +/- HUGE_VAL
6849 */
6850 static double
numeric_to_double_no_overflow(Numeric num)6851 numeric_to_double_no_overflow(Numeric num)
6852 {
6853 char *tmp;
6854 double val;
6855 char *endptr;
6856
6857 tmp = DatumGetCString(DirectFunctionCall1(numeric_out,
6858 NumericGetDatum(num)));
6859
6860 /* unlike float8in, we ignore ERANGE from strtod */
6861 val = strtod(tmp, &endptr);
6862 if (*endptr != '\0')
6863 {
6864 /* shouldn't happen ... */
6865 ereport(ERROR,
6866 (errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
6867 errmsg("invalid input syntax for type %s: \"%s\"",
6868 "double precision", tmp)));
6869 }
6870
6871 pfree(tmp);
6872
6873 return val;
6874 }
6875
6876 /* As above, but work from a NumericVar */
6877 static double
numericvar_to_double_no_overflow(const NumericVar * var)6878 numericvar_to_double_no_overflow(const NumericVar *var)
6879 {
6880 char *tmp;
6881 double val;
6882 char *endptr;
6883
6884 tmp = get_str_from_var(var);
6885
6886 /* unlike float8in, we ignore ERANGE from strtod */
6887 val = strtod(tmp, &endptr);
6888 if (*endptr != '\0')
6889 {
6890 /* shouldn't happen ... */
6891 ereport(ERROR,
6892 (errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
6893 errmsg("invalid input syntax for type %s: \"%s\"",
6894 "double precision", tmp)));
6895 }
6896
6897 pfree(tmp);
6898
6899 return val;
6900 }
6901
6902
6903 /*
6904 * cmp_var() -
6905 *
6906 * Compare two values on variable level. We assume zeroes have been
6907 * truncated to no digits.
6908 */
6909 static int
cmp_var(const NumericVar * var1,const NumericVar * var2)6910 cmp_var(const NumericVar *var1, const NumericVar *var2)
6911 {
6912 return cmp_var_common(var1->digits, var1->ndigits,
6913 var1->weight, var1->sign,
6914 var2->digits, var2->ndigits,
6915 var2->weight, var2->sign);
6916 }
6917
6918 /*
6919 * cmp_var_common() -
6920 *
6921 * Main routine of cmp_var(). This function can be used by both
6922 * NumericVar and Numeric.
6923 */
6924 static int
cmp_var_common(const NumericDigit * var1digits,int var1ndigits,int var1weight,int var1sign,const NumericDigit * var2digits,int var2ndigits,int var2weight,int var2sign)6925 cmp_var_common(const NumericDigit *var1digits, int var1ndigits,
6926 int var1weight, int var1sign,
6927 const NumericDigit *var2digits, int var2ndigits,
6928 int var2weight, int var2sign)
6929 {
6930 if (var1ndigits == 0)
6931 {
6932 if (var2ndigits == 0)
6933 return 0;
6934 if (var2sign == NUMERIC_NEG)
6935 return 1;
6936 return -1;
6937 }
6938 if (var2ndigits == 0)
6939 {
6940 if (var1sign == NUMERIC_POS)
6941 return 1;
6942 return -1;
6943 }
6944
6945 if (var1sign == NUMERIC_POS)
6946 {
6947 if (var2sign == NUMERIC_NEG)
6948 return 1;
6949 return cmp_abs_common(var1digits, var1ndigits, var1weight,
6950 var2digits, var2ndigits, var2weight);
6951 }
6952
6953 if (var2sign == NUMERIC_POS)
6954 return -1;
6955
6956 return cmp_abs_common(var2digits, var2ndigits, var2weight,
6957 var1digits, var1ndigits, var1weight);
6958 }
6959
6960
6961 /*
6962 * add_var() -
6963 *
6964 * Full version of add functionality on variable level (handling signs).
6965 * result might point to one of the operands too without danger.
6966 */
6967 static void
add_var(const NumericVar * var1,const NumericVar * var2,NumericVar * result)6968 add_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result)
6969 {
6970 /*
6971 * Decide on the signs of the two variables what to do
6972 */
6973 if (var1->sign == NUMERIC_POS)
6974 {
6975 if (var2->sign == NUMERIC_POS)
6976 {
6977 /*
6978 * Both are positive result = +(ABS(var1) + ABS(var2))
6979 */
6980 add_abs(var1, var2, result);
6981 result->sign = NUMERIC_POS;
6982 }
6983 else
6984 {
6985 /*
6986 * var1 is positive, var2 is negative Must compare absolute values
6987 */
6988 switch (cmp_abs(var1, var2))
6989 {
6990 case 0:
6991 /* ----------
6992 * ABS(var1) == ABS(var2)
6993 * result = ZERO
6994 * ----------
6995 */
6996 zero_var(result);
6997 result->dscale = Max(var1->dscale, var2->dscale);
6998 break;
6999
7000 case 1:
7001 /* ----------
7002 * ABS(var1) > ABS(var2)
7003 * result = +(ABS(var1) - ABS(var2))
7004 * ----------
7005 */
7006 sub_abs(var1, var2, result);
7007 result->sign = NUMERIC_POS;
7008 break;
7009
7010 case -1:
7011 /* ----------
7012 * ABS(var1) < ABS(var2)
7013 * result = -(ABS(var2) - ABS(var1))
7014 * ----------
7015 */
7016 sub_abs(var2, var1, result);
7017 result->sign = NUMERIC_NEG;
7018 break;
7019 }
7020 }
7021 }
7022 else
7023 {
7024 if (var2->sign == NUMERIC_POS)
7025 {
7026 /* ----------
7027 * var1 is negative, var2 is positive
7028 * Must compare absolute values
7029 * ----------
7030 */
7031 switch (cmp_abs(var1, var2))
7032 {
7033 case 0:
7034 /* ----------
7035 * ABS(var1) == ABS(var2)
7036 * result = ZERO
7037 * ----------
7038 */
7039 zero_var(result);
7040 result->dscale = Max(var1->dscale, var2->dscale);
7041 break;
7042
7043 case 1:
7044 /* ----------
7045 * ABS(var1) > ABS(var2)
7046 * result = -(ABS(var1) - ABS(var2))
7047 * ----------
7048 */
7049 sub_abs(var1, var2, result);
7050 result->sign = NUMERIC_NEG;
7051 break;
7052
7053 case -1:
7054 /* ----------
7055 * ABS(var1) < ABS(var2)
7056 * result = +(ABS(var2) - ABS(var1))
7057 * ----------
7058 */
7059 sub_abs(var2, var1, result);
7060 result->sign = NUMERIC_POS;
7061 break;
7062 }
7063 }
7064 else
7065 {
7066 /* ----------
7067 * Both are negative
7068 * result = -(ABS(var1) + ABS(var2))
7069 * ----------
7070 */
7071 add_abs(var1, var2, result);
7072 result->sign = NUMERIC_NEG;
7073 }
7074 }
7075 }
7076
7077
7078 /*
7079 * sub_var() -
7080 *
7081 * Full version of sub functionality on variable level (handling signs).
7082 * result might point to one of the operands too without danger.
7083 */
7084 static void
sub_var(const NumericVar * var1,const NumericVar * var2,NumericVar * result)7085 sub_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result)
7086 {
7087 /*
7088 * Decide on the signs of the two variables what to do
7089 */
7090 if (var1->sign == NUMERIC_POS)
7091 {
7092 if (var2->sign == NUMERIC_NEG)
7093 {
7094 /* ----------
7095 * var1 is positive, var2 is negative
7096 * result = +(ABS(var1) + ABS(var2))
7097 * ----------
7098 */
7099 add_abs(var1, var2, result);
7100 result->sign = NUMERIC_POS;
7101 }
7102 else
7103 {
7104 /* ----------
7105 * Both are positive
7106 * Must compare absolute values
7107 * ----------
7108 */
7109 switch (cmp_abs(var1, var2))
7110 {
7111 case 0:
7112 /* ----------
7113 * ABS(var1) == ABS(var2)
7114 * result = ZERO
7115 * ----------
7116 */
7117 zero_var(result);
7118 result->dscale = Max(var1->dscale, var2->dscale);
7119 break;
7120
7121 case 1:
7122 /* ----------
7123 * ABS(var1) > ABS(var2)
7124 * result = +(ABS(var1) - ABS(var2))
7125 * ----------
7126 */
7127 sub_abs(var1, var2, result);
7128 result->sign = NUMERIC_POS;
7129 break;
7130
7131 case -1:
7132 /* ----------
7133 * ABS(var1) < ABS(var2)
7134 * result = -(ABS(var2) - ABS(var1))
7135 * ----------
7136 */
7137 sub_abs(var2, var1, result);
7138 result->sign = NUMERIC_NEG;
7139 break;
7140 }
7141 }
7142 }
7143 else
7144 {
7145 if (var2->sign == NUMERIC_NEG)
7146 {
7147 /* ----------
7148 * Both are negative
7149 * Must compare absolute values
7150 * ----------
7151 */
7152 switch (cmp_abs(var1, var2))
7153 {
7154 case 0:
7155 /* ----------
7156 * ABS(var1) == ABS(var2)
7157 * result = ZERO
7158 * ----------
7159 */
7160 zero_var(result);
7161 result->dscale = Max(var1->dscale, var2->dscale);
7162 break;
7163
7164 case 1:
7165 /* ----------
7166 * ABS(var1) > ABS(var2)
7167 * result = -(ABS(var1) - ABS(var2))
7168 * ----------
7169 */
7170 sub_abs(var1, var2, result);
7171 result->sign = NUMERIC_NEG;
7172 break;
7173
7174 case -1:
7175 /* ----------
7176 * ABS(var1) < ABS(var2)
7177 * result = +(ABS(var2) - ABS(var1))
7178 * ----------
7179 */
7180 sub_abs(var2, var1, result);
7181 result->sign = NUMERIC_POS;
7182 break;
7183 }
7184 }
7185 else
7186 {
7187 /* ----------
7188 * var1 is negative, var2 is positive
7189 * result = -(ABS(var1) + ABS(var2))
7190 * ----------
7191 */
7192 add_abs(var1, var2, result);
7193 result->sign = NUMERIC_NEG;
7194 }
7195 }
7196 }
7197
7198
/*
 * mul_var() -
 *
 *	Multiplication on variable level. Product of var1 * var2 is stored
 *	in result.  Result is rounded to no more than rscale fractional digits.
 *
 *	var1, var2: the factors.  If either has ndigits == 0 (zero), the result
 *		is zero with dscale = rscale.
 *	result: receives the product.  Its digit buffer is (re)allocated with
 *		alloc_var() only after all input digits have been consumed.
 *	rscale: number of decimal fractional digits to keep; round_var()
 *		applies it and sets result->dscale.
 *
 *	The final value has leading and trailing zero digits stripped.
 */
static void
mul_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result,
		int rscale)
{
	int			res_ndigits;
	int			res_sign;
	int			res_weight;
	int			maxdigits;
	int		   *dig;
	int			carry;
	int			maxdig;
	int			newdig;
	int			var1ndigits;
	int			var2ndigits;
	NumericDigit *var1digits;
	NumericDigit *var2digits;
	NumericDigit *res_digits;
	int			i,
				i1,
				i2;

	/*
	 * Arrange for var1 to be the shorter of the two numbers. This improves
	 * performance because the inner multiplication loop is much simpler than
	 * the outer loop, so it's better to have a smaller number of iterations
	 * of the outer loop. This also reduces the number of times that the
	 * accumulator array needs to be normalized.  (Only the local parameter
	 * pointers are swapped; the caller's variables are untouched.)
	 */
	if (var1->ndigits > var2->ndigits)
	{
		const NumericVar *tmp = var1;

		var1 = var2;
		var2 = tmp;
	}

	/* copy these values into local vars for speed in inner loop */
	var1ndigits = var1->ndigits;
	var2ndigits = var2->ndigits;
	var1digits = var1->digits;
	var2digits = var2->digits;

	if (var1ndigits == 0 || var2ndigits == 0)
	{
		/* one or both inputs is zero; so is result */
		zero_var(result);
		result->dscale = rscale;
		return;
	}

	/*
	 * Determine result sign and (maximum possible) weight.  Product digits
	 * are accumulated starting at dig[2] (see the i1+i2+2 indexing below),
	 * so dig[0] and dig[1] receive only carries.  Hence the "+ 2" here.
	 */
	if (var1->sign == var2->sign)
		res_sign = NUMERIC_POS;
	else
		res_sign = NUMERIC_NEG;
	res_weight = var1->weight + var2->weight + 2;

	/*
	 * Determine the number of result digits to compute. If the exact result
	 * would have more than rscale fractional digits, truncate the computation
	 * with MUL_GUARD_DIGITS guard digits, i.e., ignore input digits that
	 * would only contribute to the right of that. (This will give the exact
	 * rounded-to-rscale answer unless carries out of the ignored positions
	 * would have propagated through more than MUL_GUARD_DIGITS digits.)
	 *
	 * Note: an exact computation could not produce more than var1ndigits +
	 * var2ndigits digits, but we allocate one extra output digit in case
	 * rscale-driven rounding produces a carry out of the highest exact digit.
	 */
	res_ndigits = var1ndigits + var2ndigits + 1;
	maxdigits = res_weight + 1 + (rscale + DEC_DIGITS - 1) / DEC_DIGITS +
		MUL_GUARD_DIGITS;
	res_ndigits = Min(res_ndigits, maxdigits);

	if (res_ndigits < 3)
	{
		/* All input digits will be ignored; so result is zero */
		zero_var(result);
		result->dscale = rscale;
		return;
	}

	/*
	 * We do the arithmetic in an array "dig[]" of signed int's. Since
	 * INT_MAX is noticeably larger than NBASE*NBASE, this gives us headroom
	 * to avoid normalizing carries immediately.
	 *
	 * maxdig tracks the maximum possible value of any dig[] entry; when this
	 * threatens to exceed INT_MAX, we take the time to propagate carries.
	 * Furthermore, we need to ensure that overflow doesn't occur during the
	 * carry propagation passes either. The carry values could be as much as
	 * INT_MAX/NBASE, so really we must normalize when digits threaten to
	 * exceed INT_MAX - INT_MAX/NBASE.
	 *
	 * To avoid overflow in maxdig itself, it actually represents the max
	 * possible value divided by NBASE-1, ie, at the top of the loop it is
	 * known that no dig[] entry exceeds maxdig * (NBASE-1).
	 */
	dig = (int *) palloc0(res_ndigits * sizeof(int));
	maxdig = 0;

	/*
	 * The least significant digits of var1 should be ignored if they don't
	 * contribute directly to the first res_ndigits digits of the result that
	 * we are computing.
	 *
	 * Digit i1 of var1 and digit i2 of var2 are multiplied and added to digit
	 * i1+i2+2 of the accumulator array, so we need only consider digits of
	 * var1 for which i1 <= res_ndigits - 3.
	 */
	for (i1 = Min(var1ndigits - 1, res_ndigits - 3); i1 >= 0; i1--)
	{
		int			var1digit = var1digits[i1];

		/* a zero digit contributes nothing, and does not raise maxdig */
		if (var1digit == 0)
			continue;

		/* Time to normalize? */
		maxdig += var1digit;
		if (maxdig > (INT_MAX - INT_MAX / NBASE) / (NBASE - 1))
		{
			/* Yes, do it */
			carry = 0;
			for (i = res_ndigits - 1; i >= 0; i--)
			{
				newdig = dig[i] + carry;
				if (newdig >= NBASE)
				{
					carry = newdig / NBASE;
					newdig -= carry * NBASE;
				}
				else
					carry = 0;
				dig[i] = newdig;
			}
			Assert(carry == 0);

			/*
			 * Reset maxdig to indicate new worst-case.  Every entry is now at
			 * most NBASE-1 (1 in maxdig units), plus the var1digit * (NBASE-1)
			 * about to be added below.
			 */
			maxdig = 1 + var1digit;
		}

		/*
		 * Add the appropriate multiple of var2 into the accumulator.
		 *
		 * As above, digits of var2 can be ignored if they don't contribute,
		 * so we only include digits for which i1+i2+2 <= res_ndigits - 1.
		 */
		for (i2 = Min(var2ndigits - 1, res_ndigits - i1 - 3), i = i1 + i2 + 2;
			 i2 >= 0; i2--)
			dig[i--] += var1digit * var2digits[i2];
	}

	/*
	 * Now we do a final carry propagation pass to normalize the result, which
	 * we combine with storing the result digits into the output. Note that
	 * this is still done at full precision w/guard digits.
	 */
	alloc_var(result, res_ndigits);
	res_digits = result->digits;
	carry = 0;
	for (i = res_ndigits - 1; i >= 0; i--)
	{
		newdig = dig[i] + carry;
		if (newdig >= NBASE)
		{
			carry = newdig / NBASE;
			newdig -= carry * NBASE;
		}
		else
			carry = 0;
		res_digits[i] = newdig;
	}
	Assert(carry == 0);

	pfree(dig);

	/*
	 * Finally, round the result to the requested precision.
	 */
	result->weight = res_weight;
	result->sign = res_sign;

	/* Round to target rscale (and set result->dscale) */
	round_var(result, rscale);

	/* Strip leading and trailing zeroes */
	strip_var(result);
}
7392
7393
/*
 * div_var() -
 *
 *	Division on variable level. Quotient of var1 / var2 is stored in result.
 *	The quotient is figured to exactly rscale fractional digits.
 *	If round is true, it is rounded at the rscale'th digit; if false, it
 *	is truncated (towards zero) at that digit.
 *
 *	var1: the dividend.  If ndigits == 0 (zero), the result is zero with
 *		dscale = rscale.
 *	var2: the divisor.  It must be normalized (nonzero leading digit); a
 *		zero or unnormalized divisor raises ERROR "division by zero".
 *	result: receives the quotient.  It may be the same as either input,
 *		because the input digits are copied into a private workspace before
 *		result's buffer is reallocated.
 *	rscale: number of decimal fractional digits in the quotient; round_var()
 *		or trunc_var() applies it and sets result->dscale.
 *
 *	The final value has leading and trailing zero digits stripped.
 */
static void
div_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result,
		int rscale, bool round)
{
	int			div_ndigits;
	int			res_ndigits;
	int			res_sign;
	int			res_weight;
	int			carry;
	int			borrow;
	int			divisor1;
	int			divisor2;
	NumericDigit *dividend;
	NumericDigit *divisor;
	NumericDigit *res_digits;
	int			i;
	int			j;

	/* copy these values into local vars for speed in inner loop */
	int			var1ndigits = var1->ndigits;
	int			var2ndigits = var2->ndigits;

	/*
	 * First of all division by zero check; we must not be handed an
	 * unnormalized divisor.
	 */
	if (var2ndigits == 0 || var2->digits[0] == 0)
		ereport(ERROR,
				(errcode(ERRCODE_DIVISION_BY_ZERO),
				 errmsg("division by zero")));

	/*
	 * Now result zero check
	 */
	if (var1ndigits == 0)
	{
		zero_var(result);
		result->dscale = rscale;
		return;
	}

	/*
	 * Determine the result sign, weight and number of digits to calculate.
	 * The weight figured here is correct if the emitted quotient has no
	 * leading zero digits; otherwise strip_var() will fix things up.
	 */
	if (var1->sign == var2->sign)
		res_sign = NUMERIC_POS;
	else
		res_sign = NUMERIC_NEG;
	res_weight = var1->weight - var2->weight;
	/* The number of accurate result digits we need to produce: */
	res_ndigits = res_weight + 1 + (rscale + DEC_DIGITS - 1) / DEC_DIGITS;
	/* ... but always at least 1 */
	res_ndigits = Max(res_ndigits, 1);
	/* If rounding needed, figure one more digit to ensure correct result */
	if (round)
		res_ndigits++;

	/*
	 * The working dividend normally requires res_ndigits + var2ndigits
	 * digits, but make it at least var1ndigits so we can load all of var1
	 * into it. (There will be an additional digit dividend[0] in the
	 * dividend space, but for consistency with Knuth's notation we don't
	 * count that in div_ndigits.)
	 */
	div_ndigits = res_ndigits + var2ndigits;
	div_ndigits = Max(div_ndigits, var1ndigits);

	/*
	 * We need a workspace with room for the working dividend (div_ndigits+1
	 * digits) plus room for the possibly-normalized divisor (var2ndigits
	 * digits). It is convenient also to have a zero at divisor[0] with the
	 * actual divisor data in divisor[1 .. var2ndigits]. Transferring the
	 * digits into the workspace also allows us to realloc the result (which
	 * might be the same as either input var) before we begin the main loop.
	 * Note that we use palloc0 to ensure that divisor[0], dividend[0], and
	 * any additional dividend positions beyond var1ndigits, start out 0.
	 */
	dividend = (NumericDigit *)
		palloc0((div_ndigits + var2ndigits + 2) * sizeof(NumericDigit));
	divisor = dividend + (div_ndigits + 1);
	memcpy(dividend + 1, var1->digits, var1ndigits * sizeof(NumericDigit));
	memcpy(divisor + 1, var2->digits, var2ndigits * sizeof(NumericDigit));

	/*
	 * Now we can realloc the result to hold the generated quotient digits.
	 */
	alloc_var(result, res_ndigits);
	res_digits = result->digits;

	if (var2ndigits == 1)
	{
		/*
		 * If there's only a single divisor digit, we can use a fast path (cf.
		 * Knuth section 4.3.1 exercise 16).  The running remainder "carry"
		 * stays below divisor1 < NBASE, so carry * NBASE + digit is below
		 * NBASE*NBASE and fits in an int.
		 */
		divisor1 = divisor[1];
		carry = 0;
		for (i = 0; i < res_ndigits; i++)
		{
			carry = carry * NBASE + dividend[i + 1];
			res_digits[i] = carry / divisor1;
			carry = carry % divisor1;
		}
	}
	else
	{
		/*
		 * The full multiple-place algorithm is taken from Knuth volume 2,
		 * Algorithm 4.3.1D.
		 *
		 * We need the first divisor digit to be >= NBASE/2. If it isn't,
		 * make it so by scaling up both the divisor and dividend by the
		 * factor "d". (The reason for allocating dividend[0] above is to
		 * leave room for possible carry here.)
		 */
		if (divisor[1] < HALF_NBASE)
		{
			int			d = NBASE / (divisor[1] + 1);

			carry = 0;
			for (i = var2ndigits; i > 0; i--)
			{
				carry += divisor[i] * d;
				divisor[i] = carry % NBASE;
				carry = carry / NBASE;
			}
			Assert(carry == 0);
			carry = 0;
			/* at this point only var1ndigits of dividend can be nonzero */
			for (i = var1ndigits; i >= 0; i--)
			{
				carry += dividend[i] * d;
				dividend[i] = carry % NBASE;
				carry = carry / NBASE;
			}
			Assert(carry == 0);
			Assert(divisor[1] >= HALF_NBASE);
		}
		/* First 2 divisor digits are used repeatedly in main loop */
		divisor1 = divisor[1];
		divisor2 = divisor[2];

		/*
		 * Begin the main loop. Each iteration of this loop produces the j'th
		 * quotient digit by dividing dividend[j .. j + var2ndigits] by the
		 * divisor; this is essentially the same as the common manual
		 * procedure for long division.
		 */
		for (j = 0; j < res_ndigits; j++)
		{
			/* Estimate quotient digit from the first two dividend digits */
			int			next2digits = dividend[j] * NBASE + dividend[j + 1];
			int			qhat;

			/*
			 * If next2digits are 0, then quotient digit must be 0 and there's
			 * no need to adjust the working dividend. It's worth testing
			 * here to fall out ASAP when processing trailing zeroes in a
			 * dividend.
			 */
			if (next2digits == 0)
			{
				res_digits[j] = 0;
				continue;
			}

			/* Knuth step D3: cap the estimate at NBASE - 1 */
			if (dividend[j] == divisor1)
				qhat = NBASE - 1;
			else
				qhat = next2digits / divisor1;

			/*
			 * Adjust quotient digit if it's too large. Knuth proves that
			 * after this step, the quotient digit will be either correct or
			 * just one too large. (Note: it's OK to use dividend[j+2] here
			 * because we know the divisor length is at least 2.)
			 */
			while (divisor2 * qhat >
				   (next2digits - qhat * divisor1) * NBASE + dividend[j + 2])
				qhat--;

			/* As above, need do nothing more when quotient digit is 0 */
			if (qhat > 0)
			{
				/*
				 * Multiply the divisor by qhat, and subtract that from the
				 * working dividend. "carry" tracks the multiplication,
				 * "borrow" the subtraction (could we fold these together?)
				 */
				carry = 0;
				borrow = 0;
				for (i = var2ndigits; i >= 0; i--)
				{
					carry += divisor[i] * qhat;
					borrow -= carry % NBASE;
					carry = carry / NBASE;
					borrow += dividend[j + i];
					if (borrow < 0)
					{
						dividend[j + i] = borrow + NBASE;
						borrow = -1;
					}
					else
					{
						dividend[j + i] = borrow;
						borrow = 0;
					}
				}
				Assert(carry == 0);

				/*
				 * If we got a borrow out of the top dividend digit, then
				 * indeed qhat was one too large. Fix it, and add back the
				 * divisor to correct the working dividend. (Knuth proves
				 * that this will occur only about 3/NBASE of the time; hence,
				 * it's a good idea to test this code with small NBASE to be
				 * sure this section gets exercised.)
				 */
				if (borrow)
				{
					qhat--;
					carry = 0;
					for (i = var2ndigits; i >= 0; i--)
					{
						carry += dividend[j + i] + divisor[i];
						if (carry >= NBASE)
						{
							dividend[j + i] = carry - NBASE;
							carry = 1;
						}
						else
						{
							dividend[j + i] = carry;
							carry = 0;
						}
					}
					/* A carry should occur here to cancel the borrow above */
					Assert(carry == 1);
				}
			}

			/* And we're done with this quotient digit */
			res_digits[j] = qhat;
		}
	}

	pfree(dividend);

	/*
	 * Finally, round or truncate the result to the requested precision.
	 */
	result->weight = res_weight;
	result->sign = res_sign;

	/* Round or truncate to target rscale (and set result->dscale) */
	if (round)
		round_var(result, rscale);
	else
		trunc_var(result, rscale);

	/* Strip leading and trailing zeroes */
	strip_var(result);
}
7667
7668
7669 /*
7670 * div_var_fast() -
7671 *
7672 * This has the same API as div_var, but is implemented using the division
7673 * algorithm from the "FM" library, rather than Knuth's schoolbook-division
7674 * approach. This is significantly faster but can produce inaccurate
7675 * results, because it sometimes has to propagate rounding to the left,
7676 * and so we can never be entirely sure that we know the requested digits
7677 * exactly. We compute DIV_GUARD_DIGITS extra digits, but there is
7678 * no certainty that that's enough. We use this only in the transcendental
7679 * function calculation routines, where everything is approximate anyway.
7680 *
7681 * Although we provide a "round" argument for consistency with div_var,
7682 * it is unwise to use this function with round=false. In truncation mode
7683 * it is possible to get a result with no significant digits, for example
7684 * with rscale=0 we might compute 0.99999... and truncate that to 0 when
7685 * the correct answer is 1.
7686 */
7687 static void
div_var_fast(const NumericVar * var1,const NumericVar * var2,NumericVar * result,int rscale,bool round)7688 div_var_fast(const NumericVar *var1, const NumericVar *var2,
7689 NumericVar *result, int rscale, bool round)
7690 {
7691 int div_ndigits;
7692 int load_ndigits;
7693 int res_sign;
7694 int res_weight;
7695 int *div;
7696 int qdigit;
7697 int carry;
7698 int maxdiv;
7699 int newdig;
7700 NumericDigit *res_digits;
7701 double fdividend,
7702 fdivisor,
7703 fdivisorinverse,
7704 fquotient;
7705 int qi;
7706 int i;
7707
7708 /* copy these values into local vars for speed in inner loop */
7709 int var1ndigits = var1->ndigits;
7710 int var2ndigits = var2->ndigits;
7711 NumericDigit *var1digits = var1->digits;
7712 NumericDigit *var2digits = var2->digits;
7713
7714 /*
7715 * First of all division by zero check; we must not be handed an
7716 * unnormalized divisor.
7717 */
7718 if (var2ndigits == 0 || var2digits[0] == 0)
7719 ereport(ERROR,
7720 (errcode(ERRCODE_DIVISION_BY_ZERO),
7721 errmsg("division by zero")));
7722
7723 /*
7724 * Now result zero check
7725 */
7726 if (var1ndigits == 0)
7727 {
7728 zero_var(result);
7729 result->dscale = rscale;
7730 return;
7731 }
7732
7733 /*
7734 * Determine the result sign, weight and number of digits to calculate
7735 */
7736 if (var1->sign == var2->sign)
7737 res_sign = NUMERIC_POS;
7738 else
7739 res_sign = NUMERIC_NEG;
7740 res_weight = var1->weight - var2->weight + 1;
7741 /* The number of accurate result digits we need to produce: */
7742 div_ndigits = res_weight + 1 + (rscale + DEC_DIGITS - 1) / DEC_DIGITS;
7743 /* Add guard digits for roundoff error */
7744 div_ndigits += DIV_GUARD_DIGITS;
7745 if (div_ndigits < DIV_GUARD_DIGITS)
7746 div_ndigits = DIV_GUARD_DIGITS;
7747
7748 /*
7749 * We do the arithmetic in an array "div[]" of signed int's. Since
7750 * INT_MAX is noticeably larger than NBASE*NBASE, this gives us headroom
7751 * to avoid normalizing carries immediately.
7752 *
7753 * We start with div[] containing one zero digit followed by the
7754 * dividend's digits (plus appended zeroes to reach the desired precision
7755 * including guard digits). Each step of the main loop computes an
7756 * (approximate) quotient digit and stores it into div[], removing one
7757 * position of dividend space. A final pass of carry propagation takes
7758 * care of any mistaken quotient digits.
7759 *
7760 * Note that div[] doesn't necessarily contain all of the digits from the
7761 * dividend --- the desired precision plus guard digits might be less than
7762 * the dividend's precision. This happens, for example, in the square
7763 * root algorithm, where we typically divide a 2N-digit number by an
7764 * N-digit number, and only require a result with N digits of precision.
7765 */
7766 div = (int *) palloc0((div_ndigits + 1) * sizeof(int));
7767 load_ndigits = Min(div_ndigits, var1ndigits);
7768 for (i = 0; i < load_ndigits; i++)
7769 div[i + 1] = var1digits[i];
7770
7771 /*
7772 * We estimate each quotient digit using floating-point arithmetic, taking
7773 * the first four digits of the (current) dividend and divisor. This must
7774 * be float to avoid overflow. The quotient digits will generally be off
7775 * by no more than one from the exact answer.
7776 */
7777 fdivisor = (double) var2digits[0];
7778 for (i = 1; i < 4; i++)
7779 {
7780 fdivisor *= NBASE;
7781 if (i < var2ndigits)
7782 fdivisor += (double) var2digits[i];
7783 }
7784 fdivisorinverse = 1.0 / fdivisor;
7785
7786 /*
7787 * maxdiv tracks the maximum possible absolute value of any div[] entry;
7788 * when this threatens to exceed INT_MAX, we take the time to propagate
7789 * carries. Furthermore, we need to ensure that overflow doesn't occur
7790 * during the carry propagation passes either. The carry values may have
7791 * an absolute value as high as INT_MAX/NBASE + 1, so really we must
7792 * normalize when digits threaten to exceed INT_MAX - INT_MAX/NBASE - 1.
7793 *
7794 * To avoid overflow in maxdiv itself, it represents the max absolute
7795 * value divided by NBASE-1, ie, at the top of the loop it is known that
7796 * no div[] entry has an absolute value exceeding maxdiv * (NBASE-1).
7797 *
7798 * Actually, though, that holds good only for div[] entries after div[qi];
7799 * the adjustment done at the bottom of the loop may cause div[qi + 1] to
7800 * exceed the maxdiv limit, so that div[qi] in the next iteration is
7801 * beyond the limit. This does not cause problems, as explained below.
7802 */
7803 maxdiv = 1;
7804
7805 /*
7806 * Outer loop computes next quotient digit, which will go into div[qi]
7807 */
7808 for (qi = 0; qi < div_ndigits; qi++)
7809 {
7810 /* Approximate the current dividend value */
7811 fdividend = (double) div[qi];
7812 for (i = 1; i < 4; i++)
7813 {
7814 fdividend *= NBASE;
7815 if (qi + i <= div_ndigits)
7816 fdividend += (double) div[qi + i];
7817 }
7818 /* Compute the (approximate) quotient digit */
7819 fquotient = fdividend * fdivisorinverse;
7820 qdigit = (fquotient >= 0.0) ? ((int) fquotient) :
7821 (((int) fquotient) - 1); /* truncate towards -infinity */
7822
7823 if (qdigit != 0)
7824 {
7825 /* Do we need to normalize now? */
7826 maxdiv += Abs(qdigit);
7827 if (maxdiv > (INT_MAX - INT_MAX / NBASE - 1) / (NBASE - 1))
7828 {
7829 /*
7830 * Yes, do it. Note that if var2ndigits is much smaller than
7831 * div_ndigits, we can save a significant amount of effort
7832 * here by noting that we only need to normalise those div[]
7833 * entries touched where prior iterations subtracted multiples
7834 * of the divisor.
7835 */
7836 carry = 0;
7837 for (i = Min(qi + var2ndigits - 2, div_ndigits); i > qi; i--)
7838 {
7839 newdig = div[i] + carry;
7840 if (newdig < 0)
7841 {
7842 carry = -((-newdig - 1) / NBASE) - 1;
7843 newdig -= carry * NBASE;
7844 }
7845 else if (newdig >= NBASE)
7846 {
7847 carry = newdig / NBASE;
7848 newdig -= carry * NBASE;
7849 }
7850 else
7851 carry = 0;
7852 div[i] = newdig;
7853 }
7854 newdig = div[qi] + carry;
7855 div[qi] = newdig;
7856
7857 /*
7858 * All the div[] digits except possibly div[qi] are now in the
7859 * range 0..NBASE-1. We do not need to consider div[qi] in
7860 * the maxdiv value anymore, so we can reset maxdiv to 1.
7861 */
7862 maxdiv = 1;
7863
7864 /*
7865 * Recompute the quotient digit since new info may have
7866 * propagated into the top four dividend digits
7867 */
7868 fdividend = (double) div[qi];
7869 for (i = 1; i < 4; i++)
7870 {
7871 fdividend *= NBASE;
7872 if (qi + i <= div_ndigits)
7873 fdividend += (double) div[qi + i];
7874 }
7875 /* Compute the (approximate) quotient digit */
7876 fquotient = fdividend * fdivisorinverse;
7877 qdigit = (fquotient >= 0.0) ? ((int) fquotient) :
7878 (((int) fquotient) - 1); /* truncate towards -infinity */
7879 maxdiv += Abs(qdigit);
7880 }
7881
7882 /*
7883 * Subtract off the appropriate multiple of the divisor.
7884 *
7885 * The digits beyond div[qi] cannot overflow, because we know they
7886 * will fall within the maxdiv limit. As for div[qi] itself, note
7887 * that qdigit is approximately trunc(div[qi] / vardigits[0]),
7888 * which would make the new value simply div[qi] mod vardigits[0].
7889 * The lower-order terms in qdigit can change this result by not
7890 * more than about twice INT_MAX/NBASE, so overflow is impossible.
7891 */
7892 if (qdigit != 0)
7893 {
7894 int istop = Min(var2ndigits, div_ndigits - qi + 1);
7895
7896 for (i = 0; i < istop; i++)
7897 div[qi + i] -= qdigit * var2digits[i];
7898 }
7899 }
7900
7901 /*
7902 * The dividend digit we are about to replace might still be nonzero.
7903 * Fold it into the next digit position.
7904 *
7905 * There is no risk of overflow here, although proving that requires
7906 * some care. Much as with the argument for div[qi] not overflowing,
7907 * if we consider the first two terms in the numerator and denominator
7908 * of qdigit, we can see that the final value of div[qi + 1] will be
7909 * approximately a remainder mod (vardigits[0]*NBASE + vardigits[1]).
7910 * Accounting for the lower-order terms is a bit complicated but ends
7911 * up adding not much more than INT_MAX/NBASE to the possible range.
7912 * Thus, div[qi + 1] cannot overflow here, and in its role as div[qi]
7913 * in the next loop iteration, it can't be large enough to cause
7914 * overflow in the carry propagation step (if any), either.
7915 *
7916 * But having said that: div[qi] can be more than INT_MAX/NBASE, as
7917 * noted above, which means that the product div[qi] * NBASE *can*
7918 * overflow. When that happens, adding it to div[qi + 1] will always
7919 * cause a canceling overflow so that the end result is correct. We
7920 * could avoid the intermediate overflow by doing the multiplication
7921 * and addition in int64 arithmetic, but so far there appears no need.
7922 */
7923 div[qi + 1] += div[qi] * NBASE;
7924
7925 div[qi] = qdigit;
7926 }
7927
7928 /*
7929 * Approximate and store the last quotient digit (div[div_ndigits])
7930 */
7931 fdividend = (double) div[qi];
7932 for (i = 1; i < 4; i++)
7933 fdividend *= NBASE;
7934 fquotient = fdividend * fdivisorinverse;
7935 qdigit = (fquotient >= 0.0) ? ((int) fquotient) :
7936 (((int) fquotient) - 1); /* truncate towards -infinity */
7937 div[qi] = qdigit;
7938
7939 /*
7940 * Because the quotient digits might be off by one, some of them might be
7941 * -1 or NBASE at this point. The represented value is correct in a
7942 * mathematical sense, but it doesn't look right. We do a final carry
7943 * propagation pass to normalize the digits, which we combine with storing
7944 * the result digits into the output. Note that this is still done at
7945 * full precision w/guard digits.
7946 */
7947 alloc_var(result, div_ndigits + 1);
7948 res_digits = result->digits;
7949 carry = 0;
7950 for (i = div_ndigits; i >= 0; i--)
7951 {
7952 newdig = div[i] + carry;
7953 if (newdig < 0)
7954 {
7955 carry = -((-newdig - 1) / NBASE) - 1;
7956 newdig -= carry * NBASE;
7957 }
7958 else if (newdig >= NBASE)
7959 {
7960 carry = newdig / NBASE;
7961 newdig -= carry * NBASE;
7962 }
7963 else
7964 carry = 0;
7965 res_digits[i] = newdig;
7966 }
7967 Assert(carry == 0);
7968
7969 pfree(div);
7970
7971 /*
7972 * Finally, round the result to the requested precision.
7973 */
7974 result->weight = res_weight;
7975 result->sign = res_sign;
7976
7977 /* Round to target rscale (and set result->dscale) */
7978 if (round)
7979 round_var(result, rscale);
7980 else
7981 trunc_var(result, rscale);
7982
7983 /* Strip leading and trailing zeroes */
7984 strip_var(result);
7985 }
7986
7987
7988 /*
7989 * Default scale selection for division
7990 *
7991 * Returns the appropriate result scale for the division result.
7992 */
7993 static int
select_div_scale(const NumericVar * var1,const NumericVar * var2)7994 select_div_scale(const NumericVar *var1, const NumericVar *var2)
7995 {
7996 int weight1,
7997 weight2,
7998 qweight,
7999 i;
8000 NumericDigit firstdigit1,
8001 firstdigit2;
8002 int rscale;
8003
8004 /*
8005 * The result scale of a division isn't specified in any SQL standard. For
8006 * PostgreSQL we select a result scale that will give at least
8007 * NUMERIC_MIN_SIG_DIGITS significant digits, so that numeric gives a
8008 * result no less accurate than float8; but use a scale not less than
8009 * either input's display scale.
8010 */
8011
8012 /* Get the actual (normalized) weight and first digit of each input */
8013
8014 weight1 = 0; /* values to use if var1 is zero */
8015 firstdigit1 = 0;
8016 for (i = 0; i < var1->ndigits; i++)
8017 {
8018 firstdigit1 = var1->digits[i];
8019 if (firstdigit1 != 0)
8020 {
8021 weight1 = var1->weight - i;
8022 break;
8023 }
8024 }
8025
8026 weight2 = 0; /* values to use if var2 is zero */
8027 firstdigit2 = 0;
8028 for (i = 0; i < var2->ndigits; i++)
8029 {
8030 firstdigit2 = var2->digits[i];
8031 if (firstdigit2 != 0)
8032 {
8033 weight2 = var2->weight - i;
8034 break;
8035 }
8036 }
8037
8038 /*
8039 * Estimate weight of quotient. If the two first digits are equal, we
8040 * can't be sure, but assume that var1 is less than var2.
8041 */
8042 qweight = weight1 - weight2;
8043 if (firstdigit1 <= firstdigit2)
8044 qweight--;
8045
8046 /* Select result scale */
8047 rscale = NUMERIC_MIN_SIG_DIGITS - qweight * DEC_DIGITS;
8048 rscale = Max(rscale, var1->dscale);
8049 rscale = Max(rscale, var2->dscale);
8050 rscale = Max(rscale, NUMERIC_MIN_DISPLAY_SCALE);
8051 rscale = Min(rscale, NUMERIC_MAX_DISPLAY_SCALE);
8052
8053 return rscale;
8054 }
8055
8056
8057 /*
8058 * mod_var() -
8059 *
8060 * Calculate the modulo of two numerics at variable level
8061 */
8062 static void
mod_var(const NumericVar * var1,const NumericVar * var2,NumericVar * result)8063 mod_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result)
8064 {
8065 NumericVar tmp;
8066
8067 init_var(&tmp);
8068
8069 /* ---------
8070 * We do this using the equation
8071 * mod(x,y) = x - trunc(x/y)*y
8072 * div_var can be persuaded to give us trunc(x/y) directly.
8073 * ----------
8074 */
8075 div_var(var1, var2, &tmp, 0, false);
8076
8077 mul_var(var2, &tmp, &tmp, var2->dscale);
8078
8079 sub_var(var1, &tmp, result);
8080
8081 free_var(&tmp);
8082 }
8083
8084
8085 /*
8086 * div_mod_var() -
8087 *
8088 * Calculate the truncated integer quotient and numeric remainder of two
8089 * numeric variables. The remainder is precise to var2's dscale.
8090 */
8091 static void
div_mod_var(const NumericVar * var1,const NumericVar * var2,NumericVar * quot,NumericVar * rem)8092 div_mod_var(const NumericVar *var1, const NumericVar *var2,
8093 NumericVar *quot, NumericVar *rem)
8094 {
8095 NumericVar q;
8096 NumericVar r;
8097
8098 init_var(&q);
8099 init_var(&r);
8100
8101 /*
8102 * Use div_var_fast() to get an initial estimate for the integer quotient.
8103 * This might be inaccurate (per the warning in div_var_fast's comments),
8104 * but we can correct it below.
8105 */
8106 div_var_fast(var1, var2, &q, 0, false);
8107
8108 /* Compute initial estimate of remainder using the quotient estimate. */
8109 mul_var(var2, &q, &r, var2->dscale);
8110 sub_var(var1, &r, &r);
8111
8112 /*
8113 * Adjust the results if necessary --- the remainder should have the same
8114 * sign as var1, and its absolute value should be less than the absolute
8115 * value of var2.
8116 */
8117 while (r.ndigits != 0 && r.sign != var1->sign)
8118 {
8119 /* The absolute value of the quotient is too large */
8120 if (var1->sign == var2->sign)
8121 {
8122 sub_var(&q, &const_one, &q);
8123 add_var(&r, var2, &r);
8124 }
8125 else
8126 {
8127 add_var(&q, &const_one, &q);
8128 sub_var(&r, var2, &r);
8129 }
8130 }
8131
8132 while (cmp_abs(&r, var2) >= 0)
8133 {
8134 /* The absolute value of the quotient is too small */
8135 if (var1->sign == var2->sign)
8136 {
8137 add_var(&q, &const_one, &q);
8138 sub_var(&r, var2, &r);
8139 }
8140 else
8141 {
8142 sub_var(&q, &const_one, &q);
8143 add_var(&r, var2, &r);
8144 }
8145 }
8146
8147 set_var_from_var(&q, quot);
8148 set_var_from_var(&r, rem);
8149
8150 free_var(&q);
8151 free_var(&r);
8152 }
8153
8154
8155 /*
8156 * ceil_var() -
8157 *
8158 * Return the smallest integer greater than or equal to the argument
8159 * on variable level
8160 */
8161 static void
ceil_var(const NumericVar * var,NumericVar * result)8162 ceil_var(const NumericVar *var, NumericVar *result)
8163 {
8164 NumericVar tmp;
8165
8166 init_var(&tmp);
8167 set_var_from_var(var, &tmp);
8168
8169 trunc_var(&tmp, 0);
8170
8171 if (var->sign == NUMERIC_POS && cmp_var(var, &tmp) != 0)
8172 add_var(&tmp, &const_one, &tmp);
8173
8174 set_var_from_var(&tmp, result);
8175 free_var(&tmp);
8176 }
8177
8178
8179 /*
8180 * floor_var() -
8181 *
8182 * Return the largest integer equal to or less than the argument
8183 * on variable level
8184 */
8185 static void
floor_var(const NumericVar * var,NumericVar * result)8186 floor_var(const NumericVar *var, NumericVar *result)
8187 {
8188 NumericVar tmp;
8189
8190 init_var(&tmp);
8191 set_var_from_var(var, &tmp);
8192
8193 trunc_var(&tmp, 0);
8194
8195 if (var->sign == NUMERIC_NEG && cmp_var(var, &tmp) != 0)
8196 sub_var(&tmp, &const_one, &tmp);
8197
8198 set_var_from_var(&tmp, result);
8199 free_var(&tmp);
8200 }
8201
8202
8203 /*
8204 * gcd_var() -
8205 *
8206 * Calculate the greatest common divisor of two numerics at variable level
8207 */
8208 static void
gcd_var(const NumericVar * var1,const NumericVar * var2,NumericVar * result)8209 gcd_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result)
8210 {
8211 int res_dscale;
8212 int cmp;
8213 NumericVar tmp_arg;
8214 NumericVar mod;
8215
8216 res_dscale = Max(var1->dscale, var2->dscale);
8217
8218 /*
8219 * Arrange for var1 to be the number with the greater absolute value.
8220 *
8221 * This would happen automatically in the loop below, but avoids an
8222 * expensive modulo operation.
8223 */
8224 cmp = cmp_abs(var1, var2);
8225 if (cmp < 0)
8226 {
8227 const NumericVar *tmp = var1;
8228
8229 var1 = var2;
8230 var2 = tmp;
8231 }
8232
8233 /*
8234 * Also avoid the taking the modulo if the inputs have the same absolute
8235 * value, or if the smaller input is zero.
8236 */
8237 if (cmp == 0 || var2->ndigits == 0)
8238 {
8239 set_var_from_var(var1, result);
8240 result->sign = NUMERIC_POS;
8241 result->dscale = res_dscale;
8242 return;
8243 }
8244
8245 init_var(&tmp_arg);
8246 init_var(&mod);
8247
8248 /* Use the Euclidean algorithm to find the GCD */
8249 set_var_from_var(var1, &tmp_arg);
8250 set_var_from_var(var2, result);
8251
8252 for (;;)
8253 {
8254 /* this loop can take a while, so allow it to be interrupted */
8255 CHECK_FOR_INTERRUPTS();
8256
8257 mod_var(&tmp_arg, result, &mod);
8258 if (mod.ndigits == 0)
8259 break;
8260 set_var_from_var(result, &tmp_arg);
8261 set_var_from_var(&mod, result);
8262 }
8263 result->sign = NUMERIC_POS;
8264 result->dscale = res_dscale;
8265
8266 free_var(&tmp_arg);
8267 free_var(&mod);
8268 }
8269
8270
/*
 * sqrt_var() -
 *
 *	Compute the square root of x using the Karatsuba Square Root algorithm.
 *	NOTE: we allow rscale < 0 here, implying rounding before the decimal
 *	point.
 *
 *	arg:	the radicand.  Zero yields zero (with dscale = rscale); a negative
 *			value raises an error with SQLSTATE
 *			ERRCODE_INVALID_ARGUMENT_FOR_POWER_FUNCTION.
 *	result:	receives sqrt(arg), rounded with round_var() to rscale and then
 *			stripped of leading and trailing zeroes.
 *	rscale:	number of decimal digits to keep after the decimal point (may be
 *			negative).
 */
static void
sqrt_var(const NumericVar *arg, NumericVar *result, int rscale)
{
	int			stat;
	int			res_weight;
	int			res_ndigits;
	int			src_ndigits;
	int			step;
	int			ndigits[32];	/* per-iteration input digit counts; the
								 * comment below bounds the steps at < 32 */
	int			blen;
	int64		arg_int64;
	int			src_idx;
	int64		s_int64;
	int64		r_int64;
	NumericVar	s_var;
	NumericVar	r_var;
	NumericVar	a0_var;
	NumericVar	a1_var;
	NumericVar	q_var;
	NumericVar	u_var;

	stat = cmp_var(arg, &const_zero);
	if (stat == 0)
	{
		/* sqrt(0) is exactly 0; just report it at the requested scale */
		zero_var(result);
		result->dscale = rscale;
		return;
	}

	/*
	 * SQL2003 defines sqrt() in terms of power, so we need to emit the right
	 * SQLSTATE error code if the operand is negative.
	 */
	if (stat < 0)
		ereport(ERROR,
				(errcode(ERRCODE_INVALID_ARGUMENT_FOR_POWER_FUNCTION),
				 errmsg("cannot take square root of a negative number")));

	init_var(&s_var);
	init_var(&r_var);
	init_var(&a0_var);
	init_var(&a1_var);
	init_var(&q_var);
	init_var(&u_var);

	/*
	 * The result weight is half the input weight, rounded towards minus
	 * infinity --- res_weight = floor(arg->weight / 2).
	 */
	if (arg->weight >= 0)
		res_weight = arg->weight / 2;
	else
		res_weight = -((-arg->weight - 1) / 2 + 1);

	/*
	 * Number of NBASE digits to compute.  To ensure correct rounding, compute
	 * at least 1 extra decimal digit.  We explicitly allow rscale to be
	 * negative here, but must always compute at least 1 NBASE digit.  Thus
	 * res_ndigits = res_weight + 1 + ceil((rscale + 1) / DEC_DIGITS) or 1.
	 */
	if (rscale + 1 >= 0)
		res_ndigits = res_weight + 1 + (rscale + DEC_DIGITS) / DEC_DIGITS;
	else
		res_ndigits = res_weight + 1 - (-rscale - 1) / DEC_DIGITS;
	res_ndigits = Max(res_ndigits, 1);

	/*
	 * Number of source NBASE digits logically required to produce a result
	 * with this precision --- every digit before the decimal point, plus 2
	 * for each result digit after the decimal point (or minus 2 for each
	 * result digit we round before the decimal point).
	 */
	src_ndigits = arg->weight + 1 + (res_ndigits - res_weight - 1) * 2;
	src_ndigits = Max(src_ndigits, 1);

	/* ----------
	 * From this point on, we treat the input and the result as integers and
	 * compute the integer square root and remainder using the Karatsuba
	 * Square Root algorithm, which may be written recursively as follows:
	 *
	 *	SqrtRem(n = a3*b^3 + a2*b^2 + a1*b + a0):
	 *		[ for some base b, and coefficients a0,a1,a2,a3 chosen so that
	 *		  0 <= a0,a1,a2 < b and a3 >= b/4 ]
	 *		Let (s,r) = SqrtRem(a3*b + a2)
	 *		Let (q,u) = DivRem(r*b + a1, 2*s)
	 *		Let s = s*b + q
	 *		Let r = u*b + a0 - q^2
	 *		If r < 0 Then
	 *			Let r = r + s
	 *			Let s = s - 1
	 *			Let r = r + s
	 *		Return (s,r)
	 *
	 * See "Karatsuba Square Root", Paul Zimmermann, INRIA Research Report
	 * RR-3805, November 1999.  At the time of writing this was available
	 * on the net at <https://hal.inria.fr/inria-00072854>.
	 *
	 * The way to read the assumption "n = a3*b^3 + a2*b^2 + a1*b + a0" is
	 * "choose a base b such that n requires at least four base-b digits to
	 * express; then those digits are a3,a2,a1,a0, with a3 possibly larger
	 * than b".  For optimal performance, b should have approximately a
	 * quarter the number of digits in the input, so that the outer square
	 * root computes roughly twice as many digits as the inner one.  For
	 * simplicity, we choose b = NBASE^blen, an integer power of NBASE.
	 *
	 * We implement the algorithm iteratively rather than recursively, to
	 * allow the working variables to be reused.  With this approach, each
	 * digit of the input is read precisely once --- src_idx tracks the number
	 * of input digits used so far.
	 *
	 * The array ndigits[] holds the number of NBASE digits of the input that
	 * will have been used at the end of each iteration, which roughly doubles
	 * each time.  Note that the array elements are stored in reverse order,
	 * so if the final iteration requires src_ndigits = 37 input digits, the
	 * array will contain [37,19,11,7,5,3], and we would start by computing
	 * the square root of the 3 most significant NBASE digits.
	 *
	 * In each iteration, we choose blen to be the largest integer for which
	 * the input number has a3 >= b/4, when written in the form above.  In
	 * general, this means blen = src_ndigits / 4 (truncated), but if
	 * src_ndigits is a multiple of 4, that might lead to the coefficient a3
	 * being less than b/4 (if the first input digit is less than NBASE/4), in
	 * which case we choose blen = src_ndigits / 4 - 1.  The number of digits
	 * in the inner square root is then src_ndigits - 2*blen.  So, for
	 * example, if we have src_ndigits = 26 initially, the array ndigits[]
	 * will be either [26,14,8,4] or [26,14,8,6,4], depending on the size of
	 * the first input digit.
	 *
	 * Additionally, we can put an upper bound on the number of steps required
	 * as follows --- suppose that the number of source digits is an n-bit
	 * number in the range [2^(n-1), 2^n-1], then blen will be in the range
	 * [2^(n-3)-1, 2^(n-2)-1] and the number of digits in the inner square
	 * root will be in the range [2^(n-2), 2^(n-1)+1].  In the next step, blen
	 * will be in the range [2^(n-4)-1, 2^(n-3)] and the number of digits in
	 * the next inner square root will be in the range [2^(n-3), 2^(n-2)+1].
	 * This pattern repeats, and in the worst case the array ndigits[] will
	 * contain [2^n-1, 2^(n-1)+1, 2^(n-2)+1, ... 9, 5, 3], and the computation
	 * will require n steps.  Therefore, since all digit array sizes are
	 * signed 32-bit integers, the number of steps required is guaranteed to
	 * be less than 32.
	 * ----------
	 */
	step = 0;
	while ((ndigits[step] = src_ndigits) > 4)
	{
		/* Choose b so that a3 >= b/4, as described above */
		blen = src_ndigits / 4;
		if (blen * 4 == src_ndigits && arg->digits[0] < NBASE / 4)
			blen--;

		/* Number of digits in the next step (inner square root) */
		src_ndigits -= 2 * blen;
		step++;
	}

	/*
	 * First iteration (innermost square root and remainder):
	 *
	 * Here src_ndigits <= 4, and the input fits in an int64.  Its square root
	 * has at most 9 decimal digits, so estimate it using double precision
	 * arithmetic, which will in fact almost certainly return the correct
	 * result with no further correction required.
	 */
	/* Digits past arg->ndigits are implicit zeroes, hence the bound check */
	arg_int64 = arg->digits[0];
	for (src_idx = 1; src_idx < src_ndigits; src_idx++)
	{
		arg_int64 *= NBASE;
		if (src_idx < arg->ndigits)
			arg_int64 += arg->digits[src_idx];
	}

	s_int64 = (int64) sqrt((double) arg_int64);
	r_int64 = arg_int64 - s_int64 * s_int64;

	/*
	 * Use Newton's method to correct the result, if necessary.
	 *
	 * This uses integer division with truncation to compute the truncated
	 * integer square root by iterating using the formula x -> (x + n/x) / 2.
	 * This is known to converge to isqrt(n), unless n+1 is a perfect square.
	 * If n+1 is a perfect square, the sequence will oscillate between the two
	 * values isqrt(n) and isqrt(n)+1, so we can be assured of convergence by
	 * checking the remainder.
	 */
	/* s is the integer square root exactly when 0 <= n - s^2 <= 2s */
	while (r_int64 < 0 || r_int64 > 2 * s_int64)
	{
		s_int64 = (s_int64 + arg_int64 / s_int64) / 2;
		r_int64 = arg_int64 - s_int64 * s_int64;
	}

	/*
	 * Iterations with src_ndigits <= 8:
	 *
	 * The next 1 or 2 iterations compute larger (outer) square roots with
	 * src_ndigits <= 8, so the result still fits in an int64 (even though the
	 * input no longer does) and we can continue to compute using int64
	 * variables to avoid more expensive numeric computations.
	 *
	 * It is fairly easy to see that there is no risk of the intermediate
	 * values below overflowing 64-bit integers.  In the worst case, the
	 * previous iteration will have computed a 3-digit square root (of a
	 * 6-digit input less than NBASE^6 / 4), so at the start of this
	 * iteration, s will be less than NBASE^3 / 2 = 10^12 / 2, and r will be
	 * less than 10^12.  In this case, blen will be 1, so numer will be less
	 * than 10^17, and denom will be less than 10^12 (and hence u will also be
	 * less than 10^12).  Finally, since q^2 = u*b + a0 - r, we can also be
	 * sure that q^2 < 10^17.  Therefore all these quantities fit comfortably
	 * in 64-bit integers.
	 */
	step--;
	while (step >= 0 && (src_ndigits = ndigits[step]) <= 8)
	{
		int			b;
		int			a0;
		int			a1;
		int			i;
		int64		numer;
		int64		denom;
		int64		q;
		int64		u;

		/* The digits not yet consumed split evenly into a1 and a0 */
		blen = (src_ndigits - src_idx) / 2;

		/* Extract a1 and a0, and compute b */
		a0 = 0;
		a1 = 0;
		b = 1;

		for (i = 0; i < blen; i++, src_idx++)
		{
			b *= NBASE;
			a1 *= NBASE;
			if (src_idx < arg->ndigits)
				a1 += arg->digits[src_idx];
		}

		for (i = 0; i < blen; i++, src_idx++)
		{
			a0 *= NBASE;
			if (src_idx < arg->ndigits)
				a0 += arg->digits[src_idx];
		}

		/* Compute (q,u) = DivRem(r*b + a1, 2*s) */
		numer = r_int64 * b + a1;
		denom = 2 * s_int64;
		q = numer / denom;
		u = numer - q * denom;

		/* Compute s = s*b + q and r = u*b + a0 - q^2 */
		s_int64 = s_int64 * b + q;
		r_int64 = u * b + a0 - q * q;

		if (r_int64 < 0)
		{
			/* s is too large by 1; set r += s, s--, r += s */
			r_int64 += s_int64;
			s_int64--;
			r_int64 += s_int64;
		}

		Assert(src_idx == src_ndigits); /* All input digits consumed */
		step--;
	}

	/*
	 * On platforms with 128-bit integer support, we can further delay the
	 * need to use numeric variables.
	 */
#ifdef HAVE_INT128
	if (step >= 0)
	{
		int128		s_int128;
		int128		r_int128;

		s_int128 = s_int64;
		r_int128 = r_int64;

		/*
		 * Iterations with src_ndigits <= 16:
		 *
		 * The result fits in an int128 (even though the input doesn't) so we
		 * use int128 variables to avoid more expensive numeric computations.
		 */
		while (step >= 0 && (src_ndigits = ndigits[step]) <= 16)
		{
			int64		b;
			int64		a0;
			int64		a1;
			int64		i;
			int128		numer;
			int128		denom;
			int128		q;
			int128		u;

			blen = (src_ndigits - src_idx) / 2;

			/* Extract a1 and a0, and compute b */
			a0 = 0;
			a1 = 0;
			b = 1;

			for (i = 0; i < blen; i++, src_idx++)
			{
				b *= NBASE;
				a1 *= NBASE;
				if (src_idx < arg->ndigits)
					a1 += arg->digits[src_idx];
			}

			for (i = 0; i < blen; i++, src_idx++)
			{
				a0 *= NBASE;
				if (src_idx < arg->ndigits)
					a0 += arg->digits[src_idx];
			}

			/* Compute (q,u) = DivRem(r*b + a1, 2*s) */
			numer = r_int128 * b + a1;
			denom = 2 * s_int128;
			q = numer / denom;
			u = numer - q * denom;

			/* Compute s = s*b + q and r = u*b + a0 - q^2 */
			s_int128 = s_int128 * b + q;
			r_int128 = u * b + a0 - q * q;

			if (r_int128 < 0)
			{
				/* s is too large by 1; set r += s, s--, r += s */
				r_int128 += s_int128;
				s_int128--;
				r_int128 += s_int128;
			}

			Assert(src_idx == src_ndigits); /* All input digits consumed */
			step--;
		}

		/*
		 * All remaining iterations require numeric variables.  Convert the
		 * integer values to NumericVar and continue.  Note that in the final
		 * iteration we don't need the remainder, so we can save a few cycles
		 * there by not fully computing it.
		 */
		int128_to_numericvar(s_int128, &s_var);
		if (step >= 0)
			int128_to_numericvar(r_int128, &r_var);
	}
	else
	{
		int64_to_numericvar(s_int64, &s_var);
		/* step < 0, so we certainly don't need r */
	}
#else							/* !HAVE_INT128 */
	int64_to_numericvar(s_int64, &s_var);
	if (step >= 0)
		int64_to_numericvar(r_int64, &r_var);
#endif							/* HAVE_INT128 */

	/*
	 * The remaining iterations with src_ndigits > 8 (or 16, if have int128)
	 * use numeric variables.
	 */
	while (step >= 0)
	{
		int			tmp_len;

		src_ndigits = ndigits[step];
		blen = (src_ndigits - src_idx) / 2;

		/*
		 * Extract a1 and a0.  Each is built as an integer of blen NBASE
		 * digits (weight blen - 1) by copying the next input digits; input
		 * digits beyond arg->ndigits are implicit zeroes, so fewer digits may
		 * be copied, or the value may be zero outright.
		 */
		if (src_idx < arg->ndigits)
		{
			tmp_len = Min(blen, arg->ndigits - src_idx);
			alloc_var(&a1_var, tmp_len);
			memcpy(a1_var.digits, arg->digits + src_idx,
				   tmp_len * sizeof(NumericDigit));
			a1_var.weight = blen - 1;
			a1_var.sign = NUMERIC_POS;
			a1_var.dscale = 0;
			strip_var(&a1_var);
		}
		else
		{
			zero_var(&a1_var);
			a1_var.dscale = 0;
		}
		src_idx += blen;

		if (src_idx < arg->ndigits)
		{
			tmp_len = Min(blen, arg->ndigits - src_idx);
			alloc_var(&a0_var, tmp_len);
			memcpy(a0_var.digits, arg->digits + src_idx,
				   tmp_len * sizeof(NumericDigit));
			a0_var.weight = blen - 1;
			a0_var.sign = NUMERIC_POS;
			a0_var.dscale = 0;
			strip_var(&a0_var);
		}
		else
		{
			zero_var(&a0_var);
			a0_var.dscale = 0;
		}
		src_idx += blen;

		/*
		 * Compute (q,u) = DivRem(r*b + a1, 2*s).  Multiplying by
		 * b = NBASE^blen is done by bumping the weight.
		 */
		set_var_from_var(&r_var, &q_var);
		q_var.weight += blen;
		add_var(&q_var, &a1_var, &q_var);
		add_var(&s_var, &s_var, &u_var);
		div_mod_var(&q_var, &u_var, &q_var, &u_var);

		/* Compute s = s*b + q */
		s_var.weight += blen;
		add_var(&s_var, &q_var, &s_var);

		/*
		 * Compute r = u*b + a0 - q^2.
		 *
		 * In the final iteration, we don't actually need r; we just need to
		 * know whether it is negative, so that we know whether to adjust s.
		 * So instead of the final subtraction we can just compare.
		 */
		u_var.weight += blen;
		add_var(&u_var, &a0_var, &u_var);
		mul_var(&q_var, &q_var, &q_var, 0);

		if (step > 0)
		{
			/* Need r for later iterations */
			sub_var(&u_var, &q_var, &r_var);
			if (r_var.sign == NUMERIC_NEG)
			{
				/* s is too large by 1; set r += s, s--, r += s */
				add_var(&r_var, &s_var, &r_var);
				sub_var(&s_var, &const_one, &s_var);
				add_var(&r_var, &s_var, &r_var);
			}
		}
		else
		{
			/* Don't need r anymore, except to test if s is too large by 1 */
			if (cmp_var(&u_var, &q_var) < 0)
				sub_var(&s_var, &const_one, &s_var);
		}

		Assert(src_idx == src_ndigits); /* All input digits consumed */
		step--;
	}

	/*
	 * Construct the final result, rounding it to the requested precision.
	 * s_var holds the integer square root computed above; overriding its
	 * weight with res_weight places the decimal point correctly.
	 */
	set_var_from_var(&s_var, result);
	result->weight = res_weight;
	result->sign = NUMERIC_POS;

	/* Round to target rscale (and set result->dscale) */
	round_var(result, rscale);

	/* Strip leading and trailing zeroes */
	strip_var(result);

	/* Release working storage */
	free_var(&s_var);
	free_var(&r_var);
	free_var(&a0_var);
	free_var(&a1_var);
	free_var(&q_var);
	free_var(&u_var);
}
8751
8752
8753 /*
8754 * exp_var() -
8755 *
8756 * Raise e to the power of x, computed to rscale fractional digits
8757 */
8758 static void
exp_var(const NumericVar * arg,NumericVar * result,int rscale)8759 exp_var(const NumericVar *arg, NumericVar *result, int rscale)
8760 {
8761 NumericVar x;
8762 NumericVar elem;
8763 NumericVar ni;
8764 double val;
8765 int dweight;
8766 int ndiv2;
8767 int sig_digits;
8768 int local_rscale;
8769
8770 init_var(&x);
8771 init_var(&elem);
8772 init_var(&ni);
8773
8774 set_var_from_var(arg, &x);
8775
8776 /*
8777 * Estimate the dweight of the result using floating point arithmetic, so
8778 * that we can choose an appropriate local rscale for the calculation.
8779 */
8780 val = numericvar_to_double_no_overflow(&x);
8781
8782 /* Guard against overflow/underflow */
8783 /* If you change this limit, see also power_var()'s limit */
8784 if (Abs(val) >= NUMERIC_MAX_RESULT_SCALE * 3)
8785 {
8786 if (val > 0)
8787 ereport(ERROR,
8788 (errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
8789 errmsg("value overflows numeric format")));
8790 zero_var(result);
8791 result->dscale = rscale;
8792 return;
8793 }
8794
8795 /* decimal weight = log10(e^x) = x * log10(e) */
8796 dweight = (int) (val * 0.434294481903252);
8797
8798 /*
8799 * Reduce x to the range -0.01 <= x <= 0.01 (approximately) by dividing by
8800 * 2^n, to improve the convergence rate of the Taylor series.
8801 */
8802 if (Abs(val) > 0.01)
8803 {
8804 NumericVar tmp;
8805
8806 init_var(&tmp);
8807 set_var_from_var(&const_two, &tmp);
8808
8809 ndiv2 = 1;
8810 val /= 2;
8811
8812 while (Abs(val) > 0.01)
8813 {
8814 ndiv2++;
8815 val /= 2;
8816 add_var(&tmp, &tmp, &tmp);
8817 }
8818
8819 local_rscale = x.dscale + ndiv2;
8820 div_var_fast(&x, &tmp, &x, local_rscale, true);
8821
8822 free_var(&tmp);
8823 }
8824 else
8825 ndiv2 = 0;
8826
8827 /*
8828 * Set the scale for the Taylor series expansion. The final result has
8829 * (dweight + rscale + 1) significant digits. In addition, we have to
8830 * raise the Taylor series result to the power 2^ndiv2, which introduces
8831 * an error of up to around log10(2^ndiv2) digits, so work with this many
8832 * extra digits of precision (plus a few more for good measure).
8833 */
8834 sig_digits = 1 + dweight + rscale + (int) (ndiv2 * 0.301029995663981);
8835 sig_digits = Max(sig_digits, 0) + 8;
8836
8837 local_rscale = sig_digits - 1;
8838
8839 /*
8840 * Use the Taylor series
8841 *
8842 * exp(x) = 1 + x + x^2/2! + x^3/3! + ...
8843 *
8844 * Given the limited range of x, this should converge reasonably quickly.
8845 * We run the series until the terms fall below the local_rscale limit.
8846 */
8847 add_var(&const_one, &x, result);
8848
8849 mul_var(&x, &x, &elem, local_rscale);
8850 set_var_from_var(&const_two, &ni);
8851 div_var_fast(&elem, &ni, &elem, local_rscale, true);
8852
8853 while (elem.ndigits != 0)
8854 {
8855 add_var(result, &elem, result);
8856
8857 mul_var(&elem, &x, &elem, local_rscale);
8858 add_var(&ni, &const_one, &ni);
8859 div_var_fast(&elem, &ni, &elem, local_rscale, true);
8860 }
8861
8862 /*
8863 * Compensate for the argument range reduction. Since the weight of the
8864 * result doubles with each multiplication, we can reduce the local rscale
8865 * as we proceed.
8866 */
8867 while (ndiv2-- > 0)
8868 {
8869 local_rscale = sig_digits - result->weight * 2 * DEC_DIGITS;
8870 local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
8871 mul_var(result, result, result, local_rscale);
8872 }
8873
8874 /* Round to requested rscale */
8875 round_var(result, rscale);
8876
8877 free_var(&x);
8878 free_var(&elem);
8879 free_var(&ni);
8880 }
8881
8882
8883 /*
8884 * Estimate the dweight of the most significant decimal digit of the natural
8885 * logarithm of a number.
8886 *
8887 * Essentially, we're approximating log10(abs(ln(var))). This is used to
8888 * determine the appropriate rscale when computing natural logarithms.
8889 */
8890 static int
estimate_ln_dweight(const NumericVar * var)8891 estimate_ln_dweight(const NumericVar *var)
8892 {
8893 int ln_dweight;
8894
8895 if (cmp_var(var, &const_zero_point_nine) >= 0 &&
8896 cmp_var(var, &const_one_point_one) <= 0)
8897 {
8898 /*
8899 * 0.9 <= var <= 1.1
8900 *
8901 * ln(var) has a negative weight (possibly very large). To get a
8902 * reasonably accurate result, estimate it using ln(1+x) ~= x.
8903 */
8904 NumericVar x;
8905
8906 init_var(&x);
8907 sub_var(var, &const_one, &x);
8908
8909 if (x.ndigits > 0)
8910 {
8911 /* Use weight of most significant decimal digit of x */
8912 ln_dweight = x.weight * DEC_DIGITS + (int) log10(x.digits[0]);
8913 }
8914 else
8915 {
8916 /* x = 0. Since ln(1) = 0 exactly, we don't need extra digits */
8917 ln_dweight = 0;
8918 }
8919
8920 free_var(&x);
8921 }
8922 else
8923 {
8924 /*
8925 * Estimate the logarithm using the first couple of digits from the
8926 * input number. This will give an accurate result whenever the input
8927 * is not too close to 1.
8928 */
8929 if (var->ndigits > 0)
8930 {
8931 int digits;
8932 int dweight;
8933 double ln_var;
8934
8935 digits = var->digits[0];
8936 dweight = var->weight * DEC_DIGITS;
8937
8938 if (var->ndigits > 1)
8939 {
8940 digits = digits * NBASE + var->digits[1];
8941 dweight -= DEC_DIGITS;
8942 }
8943
8944 /*----------
8945 * We have var ~= digits * 10^dweight
8946 * so ln(var) ~= ln(digits) + dweight * ln(10)
8947 *----------
8948 */
8949 ln_var = log((double) digits) + dweight * 2.302585092994046;
8950 ln_dweight = (int) log10(Abs(ln_var));
8951 }
8952 else
8953 {
8954 /* Caller should fail on ln(0), but for the moment return zero */
8955 ln_dweight = 0;
8956 }
8957 }
8958
8959 return ln_dweight;
8960 }
8961
8962
/*
 * ln_var() -
 *
 *	Compute the natural log of arg, storing it in result rounded to rscale
 *	decimal digits after the decimal point.
 *
 *	Raises ERROR (ERRCODE_INVALID_ARGUMENT_FOR_LOG) when arg is zero or
 *	negative.
 *
 *	Method: take nsqrt repeated square roots of arg until it lies strictly
 *	between 0.9 and 1.1, sum the series for 0.5 * ln((1+z)/(1-z)) with
 *	z = (x-1)/(x+1), then multiply by fact = 2^(nsqrt+1) to undo both the
 *	square roots and the factor 0.5.
 */
static void
ln_var(const NumericVar *arg, NumericVar *result, int rscale)
{
	NumericVar	x;				/* reduced argument; later reused for z^2 */
	NumericVar	xx;				/* current odd power of z */
	NumericVar	ni;				/* current odd divisor 3, 5, 7, ... */
	NumericVar	elem;			/* current series term (briefly x+1) */
	NumericVar	fact;			/* final multiplier 2^(nsqrt+1) */
	int			nsqrt;
	int			local_rscale;
	int			cmp;

	cmp = cmp_var(arg, &const_zero);
	if (cmp == 0)
		ereport(ERROR,
				(errcode(ERRCODE_INVALID_ARGUMENT_FOR_LOG),
				 errmsg("cannot take logarithm of zero")));
	else if (cmp < 0)
		ereport(ERROR,
				(errcode(ERRCODE_INVALID_ARGUMENT_FOR_LOG),
				 errmsg("cannot take logarithm of a negative number")));

	init_var(&x);
	init_var(&xx);
	init_var(&ni);
	init_var(&elem);
	init_var(&fact);

	set_var_from_var(arg, &x);
	/* fact starts at 2 (the series computes 0.5 * ln) and doubles per sqrt */
	set_var_from_var(&const_two, &fact);

	/*
	 * Reduce input into range 0.9 < x < 1.1 with repeated sqrt() operations.
	 *
	 * The final logarithm will have up to around rscale+6 significant digits.
	 * Each sqrt() will roughly halve the weight of x, so adjust the local
	 * rscale as we work so that we keep this many significant digits at each
	 * step (plus a few more for good measure).
	 *
	 * Note that we allow local_rscale < 0 during this input reduction
	 * process, which implies rounding before the decimal point.  sqrt_var()
	 * explicitly supports this, and it significantly reduces the work
	 * required to reduce very large inputs to the required range.  Once the
	 * input reduction is complete, x.weight will be 0 and its display scale
	 * will be non-negative again.
	 */
	nsqrt = 0;
	/* Small inputs: sqrt moves x up towards 1 (result stays below 1) */
	while (cmp_var(&x, &const_zero_point_nine) <= 0)
	{
		local_rscale = rscale - x.weight * DEC_DIGITS / 2 + 8;
		sqrt_var(&x, &x, local_rscale);
		mul_var(&fact, &const_two, &fact, 0);
		nsqrt++;
	}
	/* Large inputs: sqrt moves x down towards 1 (result stays above 1) */
	while (cmp_var(&x, &const_one_point_one) >= 0)
	{
		local_rscale = rscale - x.weight * DEC_DIGITS / 2 + 8;
		sqrt_var(&x, &x, local_rscale);
		mul_var(&fact, &const_two, &fact, 0);
		nsqrt++;
	}

	/*
	 * We use the Taylor series for 0.5 * ln((1+z)/(1-z)),
	 *
	 * z + z^3/3 + z^5/5 + ...
	 *
	 * where z = (x-1)/(x+1) is in the range (approximately) -0.053 .. 0.048
	 * due to the above range-reduction of x.
	 *
	 * The convergence of this is not as fast as one would like, but is
	 * tolerable given that z is small.
	 *
	 * The Taylor series result will be multiplied by 2^(nsqrt+1), which has a
	 * decimal weight of (nsqrt+1) * log10(2), so work with this many extra
	 * digits of precision (plus a few more for good measure).
	 */
	local_rscale = rscale + (int) ((nsqrt + 1) * 0.301029995663981) + 8;

	/* result = z = (x-1)/(x+1); xx = z; x is then overwritten with z^2 */
	sub_var(&x, &const_one, result);
	add_var(&x, &const_one, &elem);
	div_var_fast(result, &elem, result, local_rscale, true);
	set_var_from_var(result, &xx);
	mul_var(result, result, &x, local_rscale);

	set_var_from_var(&const_one, &ni);

	for (;;)
	{
		/* next term: xx = z^ni, elem = z^ni / ni for ni = 3, 5, 7, ... */
		add_var(&ni, &const_two, &ni);
		mul_var(&xx, &x, &xx, local_rscale);
		div_var_fast(&xx, &ni, &elem, local_rscale, true);

		/* term rounded away entirely at local_rscale: series has converged */
		if (elem.ndigits == 0)
			break;

		add_var(result, &elem, result);

		/* stop once the term's weight is far below the running sum's */
		if (elem.weight < (result->weight - local_rscale * 2 / DEC_DIGITS))
			break;
	}

	/* Compensate for argument range reduction, round to requested rscale */
	mul_var(result, &fact, result, rscale);

	free_var(&x);
	free_var(&xx);
	free_var(&ni);
	free_var(&elem);
	free_var(&fact);
}
9079
9080
9081 /*
9082 * log_var() -
9083 *
9084 * Compute the logarithm of num in a given base.
9085 *
9086 * Note: this routine chooses dscale of the result.
9087 */
9088 static void
log_var(const NumericVar * base,const NumericVar * num,NumericVar * result)9089 log_var(const NumericVar *base, const NumericVar *num, NumericVar *result)
9090 {
9091 NumericVar ln_base;
9092 NumericVar ln_num;
9093 int ln_base_dweight;
9094 int ln_num_dweight;
9095 int result_dweight;
9096 int rscale;
9097 int ln_base_rscale;
9098 int ln_num_rscale;
9099
9100 init_var(&ln_base);
9101 init_var(&ln_num);
9102
9103 /* Estimated dweights of ln(base), ln(num) and the final result */
9104 ln_base_dweight = estimate_ln_dweight(base);
9105 ln_num_dweight = estimate_ln_dweight(num);
9106 result_dweight = ln_num_dweight - ln_base_dweight;
9107
9108 /*
9109 * Select the scale of the result so that it will have at least
9110 * NUMERIC_MIN_SIG_DIGITS significant digits and is not less than either
9111 * input's display scale.
9112 */
9113 rscale = NUMERIC_MIN_SIG_DIGITS - result_dweight;
9114 rscale = Max(rscale, base->dscale);
9115 rscale = Max(rscale, num->dscale);
9116 rscale = Max(rscale, NUMERIC_MIN_DISPLAY_SCALE);
9117 rscale = Min(rscale, NUMERIC_MAX_DISPLAY_SCALE);
9118
9119 /*
9120 * Set the scales for ln(base) and ln(num) so that they each have more
9121 * significant digits than the final result.
9122 */
9123 ln_base_rscale = rscale + result_dweight - ln_base_dweight + 8;
9124 ln_base_rscale = Max(ln_base_rscale, NUMERIC_MIN_DISPLAY_SCALE);
9125
9126 ln_num_rscale = rscale + result_dweight - ln_num_dweight + 8;
9127 ln_num_rscale = Max(ln_num_rscale, NUMERIC_MIN_DISPLAY_SCALE);
9128
9129 /* Form natural logarithms */
9130 ln_var(base, &ln_base, ln_base_rscale);
9131 ln_var(num, &ln_num, ln_num_rscale);
9132
9133 /* Divide and round to the required scale */
9134 div_var_fast(&ln_num, &ln_base, result, rscale, true);
9135
9136 free_var(&ln_num);
9137 free_var(&ln_base);
9138 }
9139
9140
9141 /*
9142 * power_var() -
9143 *
9144 * Raise base to the power of exp
9145 *
9146 * Note: this routine chooses dscale of the result.
9147 */
9148 static void
power_var(const NumericVar * base,const NumericVar * exp,NumericVar * result)9149 power_var(const NumericVar *base, const NumericVar *exp, NumericVar *result)
9150 {
9151 int res_sign;
9152 NumericVar abs_base;
9153 NumericVar ln_base;
9154 NumericVar ln_num;
9155 int ln_dweight;
9156 int rscale;
9157 int sig_digits;
9158 int local_rscale;
9159 double val;
9160
9161 /* If exp can be represented as an integer, use power_var_int */
9162 if (exp->ndigits == 0 || exp->ndigits <= exp->weight + 1)
9163 {
9164 /* exact integer, but does it fit in int? */
9165 int64 expval64;
9166
9167 if (numericvar_to_int64(exp, &expval64))
9168 {
9169 if (expval64 >= PG_INT32_MIN && expval64 <= PG_INT32_MAX)
9170 {
9171 /* Okay, select rscale */
9172 rscale = NUMERIC_MIN_SIG_DIGITS;
9173 rscale = Max(rscale, base->dscale);
9174 rscale = Max(rscale, NUMERIC_MIN_DISPLAY_SCALE);
9175 rscale = Min(rscale, NUMERIC_MAX_DISPLAY_SCALE);
9176
9177 power_var_int(base, (int) expval64, result, rscale);
9178 return;
9179 }
9180 }
9181 }
9182
9183 /*
9184 * This avoids log(0) for cases of 0 raised to a non-integer. 0 ^ 0 is
9185 * handled by power_var_int().
9186 */
9187 if (cmp_var(base, &const_zero) == 0)
9188 {
9189 set_var_from_var(&const_zero, result);
9190 result->dscale = NUMERIC_MIN_SIG_DIGITS; /* no need to round */
9191 return;
9192 }
9193
9194 init_var(&abs_base);
9195 init_var(&ln_base);
9196 init_var(&ln_num);
9197
9198 /*
9199 * If base is negative, insist that exp be an integer. The result is then
9200 * positive if exp is even and negative if exp is odd.
9201 */
9202 if (base->sign == NUMERIC_NEG)
9203 {
9204 /*
9205 * Check that exp is an integer. This error code is defined by the
9206 * SQL standard, and matches other errors in numeric_power().
9207 */
9208 if (exp->ndigits > 0 && exp->ndigits > exp->weight + 1)
9209 ereport(ERROR,
9210 (errcode(ERRCODE_INVALID_ARGUMENT_FOR_POWER_FUNCTION),
9211 errmsg("a negative number raised to a non-integer power yields a complex result")));
9212
9213 /* Test if exp is odd or even */
9214 if (exp->ndigits > 0 && exp->ndigits == exp->weight + 1 &&
9215 (exp->digits[exp->ndigits - 1] & 1))
9216 res_sign = NUMERIC_NEG;
9217 else
9218 res_sign = NUMERIC_POS;
9219
9220 /* Then work with abs(base) below */
9221 set_var_from_var(base, &abs_base);
9222 abs_base.sign = NUMERIC_POS;
9223 base = &abs_base;
9224 }
9225 else
9226 res_sign = NUMERIC_POS;
9227
9228 /*----------
9229 * Decide on the scale for the ln() calculation. For this we need an
9230 * estimate of the weight of the result, which we obtain by doing an
9231 * initial low-precision calculation of exp * ln(base).
9232 *
9233 * We want result = e ^ (exp * ln(base))
9234 * so result dweight = log10(result) = exp * ln(base) * log10(e)
9235 *
9236 * We also perform a crude overflow test here so that we can exit early if
9237 * the full-precision result is sure to overflow, and to guard against
9238 * integer overflow when determining the scale for the real calculation.
9239 * exp_var() supports inputs up to NUMERIC_MAX_RESULT_SCALE * 3, so the
9240 * result will overflow if exp * ln(base) >= NUMERIC_MAX_RESULT_SCALE * 3.
9241 * Since the values here are only approximations, we apply a small fuzz
9242 * factor to this overflow test and let exp_var() determine the exact
9243 * overflow threshold so that it is consistent for all inputs.
9244 *----------
9245 */
9246 ln_dweight = estimate_ln_dweight(base);
9247
9248 /*
9249 * Set the scale for the low-precision calculation, computing ln(base) to
9250 * around 8 significant digits. Note that ln_dweight may be as small as
9251 * -SHRT_MAX, so the scale may exceed NUMERIC_MAX_DISPLAY_SCALE here.
9252 */
9253 local_rscale = 8 - ln_dweight;
9254 local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
9255
9256 ln_var(base, &ln_base, local_rscale);
9257
9258 mul_var(&ln_base, exp, &ln_num, local_rscale);
9259
9260 val = numericvar_to_double_no_overflow(&ln_num);
9261
9262 /* initial overflow/underflow test with fuzz factor */
9263 if (Abs(val) > NUMERIC_MAX_RESULT_SCALE * 3.01)
9264 {
9265 if (val > 0)
9266 ereport(ERROR,
9267 (errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
9268 errmsg("value overflows numeric format")));
9269 zero_var(result);
9270 result->dscale = NUMERIC_MAX_DISPLAY_SCALE;
9271 return;
9272 }
9273
9274 val *= 0.434294481903252; /* approximate decimal result weight */
9275
9276 /* choose the result scale */
9277 rscale = NUMERIC_MIN_SIG_DIGITS - (int) val;
9278 rscale = Max(rscale, base->dscale);
9279 rscale = Max(rscale, exp->dscale);
9280 rscale = Max(rscale, NUMERIC_MIN_DISPLAY_SCALE);
9281 rscale = Min(rscale, NUMERIC_MAX_DISPLAY_SCALE);
9282
9283 /* significant digits required in the result */
9284 sig_digits = rscale + (int) val;
9285 sig_digits = Max(sig_digits, 0);
9286
9287 /* set the scale for the real exp * ln(base) calculation */
9288 local_rscale = sig_digits - ln_dweight + 8;
9289 local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
9290
9291 /* and do the real calculation */
9292
9293 ln_var(base, &ln_base, local_rscale);
9294
9295 mul_var(&ln_base, exp, &ln_num, local_rscale);
9296
9297 exp_var(&ln_num, result, rscale);
9298
9299 if (res_sign == NUMERIC_NEG && result->ndigits > 0)
9300 result->sign = NUMERIC_NEG;
9301
9302 free_var(&ln_num);
9303 free_var(&ln_base);
9304 free_var(&abs_base);
9305 }
9306
9307 /*
9308 * power_var_int() -
9309 *
9310 * Raise base to the power of exp, where exp is an integer.
9311 */
9312 static void
power_var_int(const NumericVar * base,int exp,NumericVar * result,int rscale)9313 power_var_int(const NumericVar *base, int exp, NumericVar *result, int rscale)
9314 {
9315 double f;
9316 int p;
9317 int i;
9318 int sig_digits;
9319 unsigned int mask;
9320 bool neg;
9321 NumericVar base_prod;
9322 int local_rscale;
9323
9324 /* Handle some common special cases, as well as corner cases */
9325 switch (exp)
9326 {
9327 case 0:
9328
9329 /*
9330 * While 0 ^ 0 can be either 1 or indeterminate (error), we treat
9331 * it as 1 because most programming languages do this. SQL:2003
9332 * also requires a return value of 1.
9333 * https://en.wikipedia.org/wiki/Exponentiation#Zero_to_the_zero_power
9334 */
9335 set_var_from_var(&const_one, result);
9336 result->dscale = rscale; /* no need to round */
9337 return;
9338 case 1:
9339 set_var_from_var(base, result);
9340 round_var(result, rscale);
9341 return;
9342 case -1:
9343 div_var(&const_one, base, result, rscale, true);
9344 return;
9345 case 2:
9346 mul_var(base, base, result, rscale);
9347 return;
9348 default:
9349 break;
9350 }
9351
9352 /* Handle the special case where the base is zero */
9353 if (base->ndigits == 0)
9354 {
9355 if (exp < 0)
9356 ereport(ERROR,
9357 (errcode(ERRCODE_DIVISION_BY_ZERO),
9358 errmsg("division by zero")));
9359 zero_var(result);
9360 result->dscale = rscale;
9361 return;
9362 }
9363
9364 /*
9365 * The general case repeatedly multiplies base according to the bit
9366 * pattern of exp.
9367 *
9368 * First we need to estimate the weight of the result so that we know how
9369 * many significant digits are needed.
9370 */
9371 f = base->digits[0];
9372 p = base->weight * DEC_DIGITS;
9373
9374 for (i = 1; i < base->ndigits && i * DEC_DIGITS < 16; i++)
9375 {
9376 f = f * NBASE + base->digits[i];
9377 p -= DEC_DIGITS;
9378 }
9379
9380 /*----------
9381 * We have base ~= f * 10^p
9382 * so log10(result) = log10(base^exp) ~= exp * (log10(f) + p)
9383 *----------
9384 */
9385 f = exp * (log10(f) + p);
9386
9387 /*
9388 * Apply crude overflow/underflow tests so we can exit early if the result
9389 * certainly will overflow/underflow.
9390 */
9391 if (f > 3 * SHRT_MAX * DEC_DIGITS)
9392 ereport(ERROR,
9393 (errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
9394 errmsg("value overflows numeric format")));
9395 if (f + 1 < -rscale || f + 1 < -NUMERIC_MAX_DISPLAY_SCALE)
9396 {
9397 zero_var(result);
9398 result->dscale = rscale;
9399 return;
9400 }
9401
9402 /*
9403 * Approximate number of significant digits in the result. Note that the
9404 * underflow test above means that this is necessarily >= 0.
9405 */
9406 sig_digits = 1 + rscale + (int) f;
9407
9408 /*
9409 * The multiplications to produce the result may introduce an error of up
9410 * to around log10(abs(exp)) digits, so work with this many extra digits
9411 * of precision (plus a few more for good measure).
9412 */
9413 sig_digits += (int) log(fabs((double) exp)) + 8;
9414
9415 /*
9416 * Now we can proceed with the multiplications.
9417 */
9418 neg = (exp < 0);
9419 mask = Abs(exp);
9420
9421 init_var(&base_prod);
9422 set_var_from_var(base, &base_prod);
9423
9424 if (mask & 1)
9425 set_var_from_var(base, result);
9426 else
9427 set_var_from_var(&const_one, result);
9428
9429 while ((mask >>= 1) > 0)
9430 {
9431 /*
9432 * Do the multiplications using rscales large enough to hold the
9433 * results to the required number of significant digits, but don't
9434 * waste time by exceeding the scales of the numbers themselves.
9435 */
9436 local_rscale = sig_digits - 2 * base_prod.weight * DEC_DIGITS;
9437 local_rscale = Min(local_rscale, 2 * base_prod.dscale);
9438 local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
9439
9440 mul_var(&base_prod, &base_prod, &base_prod, local_rscale);
9441
9442 if (mask & 1)
9443 {
9444 local_rscale = sig_digits -
9445 (base_prod.weight + result->weight) * DEC_DIGITS;
9446 local_rscale = Min(local_rscale,
9447 base_prod.dscale + result->dscale);
9448 local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
9449
9450 mul_var(&base_prod, result, result, local_rscale);
9451 }
9452
9453 /*
9454 * When abs(base) > 1, the number of digits to the left of the decimal
9455 * point in base_prod doubles at each iteration, so if exp is large we
9456 * could easily spend large amounts of time and memory space doing the
9457 * multiplications. But once the weight exceeds what will fit in
9458 * int16, the final result is guaranteed to overflow (or underflow, if
9459 * exp < 0), so we can give up before wasting too many cycles.
9460 */
9461 if (base_prod.weight > SHRT_MAX || result->weight > SHRT_MAX)
9462 {
9463 /* overflow, unless neg, in which case result should be 0 */
9464 if (!neg)
9465 ereport(ERROR,
9466 (errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
9467 errmsg("value overflows numeric format")));
9468 zero_var(result);
9469 neg = false;
9470 break;
9471 }
9472 }
9473
9474 free_var(&base_prod);
9475
9476 /* Compensate for input sign, and round to requested rscale */
9477 if (neg)
9478 div_var_fast(&const_one, result, result, rscale, true);
9479 else
9480 round_var(result, rscale);
9481 }
9482
9483 /*
9484 * power_ten_int() -
9485 *
9486 * Raise ten to the power of exp, where exp is an integer. Note that unlike
9487 * power_var_int(), this does no overflow/underflow checking or rounding.
9488 */
9489 static void
power_ten_int(int exp,NumericVar * result)9490 power_ten_int(int exp, NumericVar *result)
9491 {
9492 /* Construct the result directly, starting from 10^0 = 1 */
9493 set_var_from_var(&const_one, result);
9494
9495 /* Scale needed to represent the result exactly */
9496 result->dscale = exp < 0 ? -exp : 0;
9497
9498 /* Base-NBASE weight of result and remaining exponent */
9499 if (exp >= 0)
9500 result->weight = exp / DEC_DIGITS;
9501 else
9502 result->weight = (exp + 1) / DEC_DIGITS - 1;
9503
9504 exp -= result->weight * DEC_DIGITS;
9505
9506 /* Final adjustment of the result's single NBASE digit */
9507 while (exp-- > 0)
9508 result->digits[0] *= 10;
9509 }
9510
9511
9512 /* ----------------------------------------------------------------------
9513 *
9514 * Following are the lowest level functions that operate unsigned
9515 * on the variable level
9516 *
9517 * ----------------------------------------------------------------------
9518 */
9519
9520
9521 /* ----------
9522 * cmp_abs() -
9523 *
9524 * Compare the absolute values of var1 and var2
9525 * Returns: -1 for ABS(var1) < ABS(var2)
9526 * 0 for ABS(var1) == ABS(var2)
9527 * 1 for ABS(var1) > ABS(var2)
9528 * ----------
9529 */
9530 static int
cmp_abs(const NumericVar * var1,const NumericVar * var2)9531 cmp_abs(const NumericVar *var1, const NumericVar *var2)
9532 {
9533 return cmp_abs_common(var1->digits, var1->ndigits, var1->weight,
9534 var2->digits, var2->ndigits, var2->weight);
9535 }
9536
9537 /* ----------
9538 * cmp_abs_common() -
9539 *
9540 * Main routine of cmp_abs(). This function can be used by both
9541 * NumericVar and Numeric.
9542 * ----------
9543 */
9544 static int
cmp_abs_common(const NumericDigit * var1digits,int var1ndigits,int var1weight,const NumericDigit * var2digits,int var2ndigits,int var2weight)9545 cmp_abs_common(const NumericDigit *var1digits, int var1ndigits, int var1weight,
9546 const NumericDigit *var2digits, int var2ndigits, int var2weight)
9547 {
9548 int i1 = 0;
9549 int i2 = 0;
9550
9551 /* Check any digits before the first common digit */
9552
9553 while (var1weight > var2weight && i1 < var1ndigits)
9554 {
9555 if (var1digits[i1++] != 0)
9556 return 1;
9557 var1weight--;
9558 }
9559 while (var2weight > var1weight && i2 < var2ndigits)
9560 {
9561 if (var2digits[i2++] != 0)
9562 return -1;
9563 var2weight--;
9564 }
9565
9566 /* At this point, either w1 == w2 or we've run out of digits */
9567
9568 if (var1weight == var2weight)
9569 {
9570 while (i1 < var1ndigits && i2 < var2ndigits)
9571 {
9572 int stat = var1digits[i1++] - var2digits[i2++];
9573
9574 if (stat)
9575 {
9576 if (stat > 0)
9577 return 1;
9578 return -1;
9579 }
9580 }
9581 }
9582
9583 /*
9584 * At this point, we've run out of digits on one side or the other; so any
9585 * remaining nonzero digits imply that side is larger
9586 */
9587 while (i1 < var1ndigits)
9588 {
9589 if (var1digits[i1++] != 0)
9590 return 1;
9591 }
9592 while (i2 < var2ndigits)
9593 {
9594 if (var2digits[i2++] != 0)
9595 return -1;
9596 }
9597
9598 return 0;
9599 }
9600
9601
/*
 * add_abs() -
 *
 *	Add the absolute values of two variables into result.
 *	result might point to one of the operands without danger.
 *
 *	The operands' signs are ignored and result->sign is not assigned here
 *	(strip_var() normalizes it to positive if the sum is zero).
 *	result->dscale becomes the larger of the two operands' dscales.  A new
 *	digit buffer with one spare leading digit is allocated.  The old result
 *	buffer is freed only after both operands have been fully read, which is
 *	what makes aliasing safe.
 */
static void
add_abs(const NumericVar *var1, const NumericVar *var2, NumericVar *result)
{
	NumericDigit *res_buf;
	NumericDigit *res_digits;
	int			res_ndigits;
	int			res_weight;
	int			res_rscale,
				rscale1,
				rscale2;
	int			res_dscale;
	int			i,
				i1,
				i2;
	int			carry = 0;

	/* copy these values into local vars for speed in inner loop */
	int			var1ndigits = var1->ndigits;
	int			var2ndigits = var2->ndigits;
	NumericDigit *var1digits = var1->digits;
	NumericDigit *var2digits = var2->digits;

	/* one extra NBASE digit of weight to hold a possible final carry */
	res_weight = Max(var1->weight, var2->weight) + 1;

	res_dscale = Max(var1->dscale, var2->dscale);

	/* Note: here we are figuring rscale in base-NBASE digits */
	rscale1 = var1->ndigits - var1->weight - 1;
	rscale2 = var2->ndigits - var2->weight - 1;
	res_rscale = Max(rscale1, rscale2);

	res_ndigits = res_rscale + res_weight + 1;
	if (res_ndigits <= 0)
		res_ndigits = 1;

	res_buf = digitbuf_alloc(res_ndigits + 1);
	res_buf[0] = 0;				/* spare digit for later rounding */
	res_digits = res_buf + 1;

	/*
	 * Walk the result digits from least to most significant.  After the
	 * decrements at the top of each iteration, i1 and i2 index the var1 and
	 * var2 digits aligned with result digit i; an index outside
	 * [0, varNndigits) contributes zero.
	 */
	i1 = res_rscale + var1->weight + 1;
	i2 = res_rscale + var2->weight + 1;
	for (i = res_ndigits - 1; i >= 0; i--)
	{
		i1--;
		i2--;
		if (i1 >= 0 && i1 < var1ndigits)
			carry += var1digits[i1];
		if (i2 >= 0 && i2 < var2ndigits)
			carry += var2digits[i2];

		if (carry >= NBASE)
		{
			res_digits[i] = carry - NBASE;
			carry = 1;
		}
		else
		{
			res_digits[i] = carry;
			carry = 0;
		}
	}

	Assert(carry == 0);			/* else we failed to allow for carry out */

	digitbuf_free(result->buf);
	result->ndigits = res_ndigits;
	result->buf = res_buf;
	result->digits = res_digits;
	result->weight = res_weight;
	result->dscale = res_dscale;

	/* Remove leading/trailing zeroes */
	strip_var(result);
}
9682
9683
/*
 * sub_abs()
 *
 *	Subtract the absolute value of var2 from the absolute value of var1
 *	and store in result. result might point to one of the operands
 *	without danger.
 *
 *	ABS(var1) MUST BE GREATER OR EQUAL ABS(var2) !!!
 *	(only checked by an Assert on the final borrow)
 *
 *	result->sign is not assigned here (strip_var() normalizes it to
 *	positive if the difference is zero).  result->dscale becomes the larger
 *	of the two operands' dscales.  A new digit buffer with one spare leading
 *	digit is allocated.  The old result buffer is freed only after both
 *	operands have been fully read.
 */
static void
sub_abs(const NumericVar *var1, const NumericVar *var2, NumericVar *result)
{
	NumericDigit *res_buf;
	NumericDigit *res_digits;
	int			res_ndigits;
	int			res_weight;
	int			res_rscale,
				rscale1,
				rscale2;
	int			res_dscale;
	int			i,
				i1,
				i2;
	int			borrow = 0;

	/* copy these values into local vars for speed in inner loop */
	int			var1ndigits = var1->ndigits;
	int			var2ndigits = var2->ndigits;
	NumericDigit *var1digits = var1->digits;
	NumericDigit *var2digits = var2->digits;

	/* since ABS(var1) >= ABS(var2), the difference cannot outweigh var1 */
	res_weight = var1->weight;

	res_dscale = Max(var1->dscale, var2->dscale);

	/* Note: here we are figuring rscale in base-NBASE digits */
	rscale1 = var1->ndigits - var1->weight - 1;
	rscale2 = var2->ndigits - var2->weight - 1;
	res_rscale = Max(rscale1, rscale2);

	res_ndigits = res_rscale + res_weight + 1;
	if (res_ndigits <= 0)
		res_ndigits = 1;

	res_buf = digitbuf_alloc(res_ndigits + 1);
	res_buf[0] = 0;				/* spare digit for later rounding */
	res_digits = res_buf + 1;

	/*
	 * Walk the result digits from least to most significant.  After the
	 * decrements at the top of each iteration, i1 and i2 index the var1 and
	 * var2 digits aligned with result digit i; an index outside
	 * [0, varNndigits) contributes zero.
	 */
	i1 = res_rscale + var1->weight + 1;
	i2 = res_rscale + var2->weight + 1;
	for (i = res_ndigits - 1; i >= 0; i--)
	{
		i1--;
		i2--;
		if (i1 >= 0 && i1 < var1ndigits)
			borrow += var1digits[i1];
		if (i2 >= 0 && i2 < var2ndigits)
			borrow -= var2digits[i2];

		if (borrow < 0)
		{
			res_digits[i] = borrow + NBASE;
			borrow = -1;
		}
		else
		{
			res_digits[i] = borrow;
			borrow = 0;
		}
	}

	Assert(borrow == 0);		/* else caller gave us var1 < var2 */

	digitbuf_free(result->buf);
	result->ndigits = res_ndigits;
	result->buf = res_buf;
	result->digits = res_digits;
	result->weight = res_weight;
	result->dscale = res_dscale;

	/* Remove leading/trailing zeroes */
	strip_var(result);
}
9767
/*
 * round_var
 *
 * Round the value of a variable to no more than rscale decimal digits
 * after the decimal point.  NOTE: we allow rscale < 0 here, implying
 * rounding before the decimal point.
 *
 * Rounding is half-up on the magnitude: a discarded part of at least half
 * a unit in the last kept place rounds away from zero.  The digit buffer is
 * modified in place.  A carry out of the leading digit is written into the
 * digit just before var->digits, which must exist (asserted).  Spare
 * leading digits such as the one add_abs()/sub_abs() reserve serve this
 * purpose.  var->dscale is always set to rscale.
 */
static void
round_var(NumericVar *var, int rscale)
{
	NumericDigit *digits = var->digits;
	int			di;
	int			ndigits;
	int			carry;

	var->dscale = rscale;

	/* decimal digits wanted */
	di = (var->weight + 1) * DEC_DIGITS + rscale;

	/*
	 * If di = 0, the value loses all digits, but could round up to 1 if its
	 * first extra digit is >= 5.  If di < 0 the result must be 0.
	 */
	if (di < 0)
	{
		var->ndigits = 0;
		var->weight = 0;
		var->sign = NUMERIC_POS;
	}
	else
	{
		/* NBASE digits wanted */
		ndigits = (di + DEC_DIGITS - 1) / DEC_DIGITS;

		/* 0, or number of decimal digits to keep in last NBASE digit */
		di %= DEC_DIGITS;

		/* Only act when some existing digits are actually being discarded */
		if (ndigits < var->ndigits ||
			(ndigits == var->ndigits && di > 0))
		{
			var->ndigits = ndigits;

#if DEC_DIGITS == 1
			/* di must be zero */
			carry = (digits[ndigits] >= HALF_NBASE) ? 1 : 0;
#else
			/* di == 0: decide from the first discarded NBASE digit */
			if (di == 0)
				carry = (digits[ndigits] >= HALF_NBASE) ? 1 : 0;
			else
			{
				/* Must round within last NBASE digit */
				int			extra,
							pow10;

#if DEC_DIGITS == 4
				pow10 = round_powers[di];
#elif DEC_DIGITS == 2
				pow10 = 10;
#else
#error unsupported NBASE
#endif
				/* clear the unwanted low-order decimal digits ... */
				extra = digits[--ndigits] % pow10;
				digits[ndigits] -= extra;
				carry = 0;
				/* ... and bump the kept part if they were at least half */
				if (extra >= pow10 / 2)
				{
					pow10 += digits[ndigits];
					if (pow10 >= NBASE)
					{
						pow10 -= NBASE;
						carry = 1;
					}
					digits[ndigits] = pow10;
				}
			}
#endif

			/* Propagate carry if needed */
			while (carry)
			{
				carry += digits[--ndigits];
				if (carry >= NBASE)
				{
					digits[ndigits] = carry - NBASE;
					carry = 1;
				}
				else
				{
					digits[ndigits] = carry;
					carry = 0;
				}
			}

			/* Carry ran past the first digit: adopt the spare leading digit */
			if (ndigits < 0)
			{
				Assert(ndigits == -1);	/* better not have added > 1 digit */
				Assert(var->digits > var->buf);
				var->digits--;
				var->ndigits++;
				var->weight++;
			}
		}
	}
}
9873
9874 /*
9875 * trunc_var
9876 *
9877 * Truncate (towards zero) the value of a variable at rscale decimal digits
9878 * after the decimal point. NOTE: we allow rscale < 0 here, implying
9879 * truncation before the decimal point.
9880 */
9881 static void
trunc_var(NumericVar * var,int rscale)9882 trunc_var(NumericVar *var, int rscale)
9883 {
9884 int di;
9885 int ndigits;
9886
9887 var->dscale = rscale;
9888
9889 /* decimal digits wanted */
9890 di = (var->weight + 1) * DEC_DIGITS + rscale;
9891
9892 /*
9893 * If di <= 0, the value loses all digits.
9894 */
9895 if (di <= 0)
9896 {
9897 var->ndigits = 0;
9898 var->weight = 0;
9899 var->sign = NUMERIC_POS;
9900 }
9901 else
9902 {
9903 /* NBASE digits wanted */
9904 ndigits = (di + DEC_DIGITS - 1) / DEC_DIGITS;
9905
9906 if (ndigits <= var->ndigits)
9907 {
9908 var->ndigits = ndigits;
9909
9910 #if DEC_DIGITS == 1
9911 /* no within-digit stuff to worry about */
9912 #else
9913 /* 0, or number of decimal digits to keep in last NBASE digit */
9914 di %= DEC_DIGITS;
9915
9916 if (di > 0)
9917 {
9918 /* Must truncate within last NBASE digit */
9919 NumericDigit *digits = var->digits;
9920 int extra,
9921 pow10;
9922
9923 #if DEC_DIGITS == 4
9924 pow10 = round_powers[di];
9925 #elif DEC_DIGITS == 2
9926 pow10 = 10;
9927 #else
9928 #error unsupported NBASE
9929 #endif
9930 extra = digits[--ndigits] % pow10;
9931 digits[ndigits] -= extra;
9932 }
9933 #endif
9934 }
9935 }
9936 }
9937
9938 /*
9939 * strip_var
9940 *
9941 * Strip any leading and trailing zeroes from a numeric variable
9942 */
9943 static void
strip_var(NumericVar * var)9944 strip_var(NumericVar *var)
9945 {
9946 NumericDigit *digits = var->digits;
9947 int ndigits = var->ndigits;
9948
9949 /* Strip leading zeroes */
9950 while (ndigits > 0 && *digits == 0)
9951 {
9952 digits++;
9953 var->weight--;
9954 ndigits--;
9955 }
9956
9957 /* Strip trailing zeroes */
9958 while (ndigits > 0 && digits[ndigits - 1] == 0)
9959 ndigits--;
9960
9961 /* If it's zero, normalize the sign and weight */
9962 if (ndigits == 0)
9963 {
9964 var->sign = NUMERIC_POS;
9965 var->weight = 0;
9966 }
9967
9968 var->digits = digits;
9969 var->ndigits = ndigits;
9970 }
9971
9972
9973 /* ----------------------------------------------------------------------
9974 *
9975 * Fast sum accumulator functions
9976 *
9977 * ----------------------------------------------------------------------
9978 */
9979
9980 /*
9981 * Reset the accumulator's value to zero. The buffers to hold the digits
9982 * are not free'd.
9983 */
9984 static void
accum_sum_reset(NumericSumAccum * accum)9985 accum_sum_reset(NumericSumAccum *accum)
9986 {
9987 int i;
9988
9989 accum->dscale = 0;
9990 for (i = 0; i < accum->ndigits; i++)
9991 {
9992 accum->pos_digits[i] = 0;
9993 accum->neg_digits[i] = 0;
9994 }
9995 }
9996
9997 /*
9998 * Accumulate a new value.
9999 */
10000 static void
accum_sum_add(NumericSumAccum * accum,const NumericVar * val)10001 accum_sum_add(NumericSumAccum *accum, const NumericVar *val)
10002 {
10003 int32 *accum_digits;
10004 int i,
10005 val_i;
10006 int val_ndigits;
10007 NumericDigit *val_digits;
10008
10009 /*
10010 * If we have accumulated too many values since the last carry
10011 * propagation, do it now, to avoid overflowing. (We could allow more
10012 * than NBASE - 1, if we reserved two extra digits, rather than one, for
10013 * carry propagation. But even with NBASE - 1, this needs to be done so
10014 * seldom, that the performance difference is negligible.)
10015 */
10016 if (accum->num_uncarried == NBASE - 1)
10017 accum_sum_carry(accum);
10018
10019 /*
10020 * Adjust the weight or scale of the old value, so that it can accommodate
10021 * the new value.
10022 */
10023 accum_sum_rescale(accum, val);
10024
10025 /* */
10026 if (val->sign == NUMERIC_POS)
10027 accum_digits = accum->pos_digits;
10028 else
10029 accum_digits = accum->neg_digits;
10030
10031 /* copy these values into local vars for speed in loop */
10032 val_ndigits = val->ndigits;
10033 val_digits = val->digits;
10034
10035 i = accum->weight - val->weight;
10036 for (val_i = 0; val_i < val_ndigits; val_i++)
10037 {
10038 accum_digits[i] += (int32) val_digits[val_i];
10039 i++;
10040 }
10041
10042 accum->num_uncarried++;
10043 }
10044
10045 /*
10046 * Propagate carries.
10047 */
10048 static void
accum_sum_carry(NumericSumAccum * accum)10049 accum_sum_carry(NumericSumAccum *accum)
10050 {
10051 int i;
10052 int ndigits;
10053 int32 *dig;
10054 int32 carry;
10055 int32 newdig = 0;
10056
10057 /*
10058 * If no new values have been added since last carry propagation, nothing
10059 * to do.
10060 */
10061 if (accum->num_uncarried == 0)
10062 return;
10063
10064 /*
10065 * We maintain that the weight of the accumulator is always one larger
10066 * than needed to hold the current value, before carrying, to make sure
10067 * there is enough space for the possible extra digit when carry is
10068 * propagated. We cannot expand the buffer here, unless we require
10069 * callers of accum_sum_final() to switch to the right memory context.
10070 */
10071 Assert(accum->pos_digits[0] == 0 && accum->neg_digits[0] == 0);
10072
10073 ndigits = accum->ndigits;
10074
10075 /* Propagate carry in the positive sum */
10076 dig = accum->pos_digits;
10077 carry = 0;
10078 for (i = ndigits - 1; i >= 0; i--)
10079 {
10080 newdig = dig[i] + carry;
10081 if (newdig >= NBASE)
10082 {
10083 carry = newdig / NBASE;
10084 newdig -= carry * NBASE;
10085 }
10086 else
10087 carry = 0;
10088 dig[i] = newdig;
10089 }
10090 /* Did we use up the digit reserved for carry propagation? */
10091 if (newdig > 0)
10092 accum->have_carry_space = false;
10093
10094 /* And the same for the negative sum */
10095 dig = accum->neg_digits;
10096 carry = 0;
10097 for (i = ndigits - 1; i >= 0; i--)
10098 {
10099 newdig = dig[i] + carry;
10100 if (newdig >= NBASE)
10101 {
10102 carry = newdig / NBASE;
10103 newdig -= carry * NBASE;
10104 }
10105 else
10106 carry = 0;
10107 dig[i] = newdig;
10108 }
10109 if (newdig > 0)
10110 accum->have_carry_space = false;
10111
10112 accum->num_uncarried = 0;
10113 }
10114
10115 /*
10116 * Re-scale accumulator to accommodate new value.
10117 *
10118 * If the new value has more digits than the current digit buffers in the
10119 * accumulator, enlarge the buffers.
10120 */
10121 static void
accum_sum_rescale(NumericSumAccum * accum,const NumericVar * val)10122 accum_sum_rescale(NumericSumAccum *accum, const NumericVar *val)
10123 {
10124 int old_weight = accum->weight;
10125 int old_ndigits = accum->ndigits;
10126 int accum_ndigits;
10127 int accum_weight;
10128 int accum_rscale;
10129 int val_rscale;
10130
10131 accum_weight = old_weight;
10132 accum_ndigits = old_ndigits;
10133
10134 /*
10135 * Does the new value have a larger weight? If so, enlarge the buffers,
10136 * and shift the existing value to the new weight, by adding leading
10137 * zeros.
10138 *
10139 * We enforce that the accumulator always has a weight one larger than
10140 * needed for the inputs, so that we have space for an extra digit at the
10141 * final carry-propagation phase, if necessary.
10142 */
10143 if (val->weight >= accum_weight)
10144 {
10145 accum_weight = val->weight + 1;
10146 accum_ndigits = accum_ndigits + (accum_weight - old_weight);
10147 }
10148
10149 /*
10150 * Even though the new value is small, we might've used up the space
10151 * reserved for the carry digit in the last call to accum_sum_carry(). If
10152 * so, enlarge to make room for another one.
10153 */
10154 else if (!accum->have_carry_space)
10155 {
10156 accum_weight++;
10157 accum_ndigits++;
10158 }
10159
10160 /* Is the new value wider on the right side? */
10161 accum_rscale = accum_ndigits - accum_weight - 1;
10162 val_rscale = val->ndigits - val->weight - 1;
10163 if (val_rscale > accum_rscale)
10164 accum_ndigits = accum_ndigits + (val_rscale - accum_rscale);
10165
10166 if (accum_ndigits != old_ndigits ||
10167 accum_weight != old_weight)
10168 {
10169 int32 *new_pos_digits;
10170 int32 *new_neg_digits;
10171 int weightdiff;
10172
10173 weightdiff = accum_weight - old_weight;
10174
10175 new_pos_digits = palloc0(accum_ndigits * sizeof(int32));
10176 new_neg_digits = palloc0(accum_ndigits * sizeof(int32));
10177
10178 if (accum->pos_digits)
10179 {
10180 memcpy(&new_pos_digits[weightdiff], accum->pos_digits,
10181 old_ndigits * sizeof(int32));
10182 pfree(accum->pos_digits);
10183
10184 memcpy(&new_neg_digits[weightdiff], accum->neg_digits,
10185 old_ndigits * sizeof(int32));
10186 pfree(accum->neg_digits);
10187 }
10188
10189 accum->pos_digits = new_pos_digits;
10190 accum->neg_digits = new_neg_digits;
10191
10192 accum->weight = accum_weight;
10193 accum->ndigits = accum_ndigits;
10194
10195 Assert(accum->pos_digits[0] == 0 && accum->neg_digits[0] == 0);
10196 accum->have_carry_space = true;
10197 }
10198
10199 if (val->dscale > accum->dscale)
10200 accum->dscale = val->dscale;
10201 }
10202
10203 /*
10204 * Return the current value of the accumulator. This perform final carry
10205 * propagation, and adds together the positive and negative sums.
10206 *
10207 * Unlike all the other routines, the caller is not required to switch to
10208 * the memory context that holds the accumulator.
10209 */
10210 static void
accum_sum_final(NumericSumAccum * accum,NumericVar * result)10211 accum_sum_final(NumericSumAccum *accum, NumericVar *result)
10212 {
10213 int i;
10214 NumericVar pos_var;
10215 NumericVar neg_var;
10216
10217 if (accum->ndigits == 0)
10218 {
10219 set_var_from_var(&const_zero, result);
10220 return;
10221 }
10222
10223 /* Perform final carry */
10224 accum_sum_carry(accum);
10225
10226 /* Create NumericVars representing the positive and negative sums */
10227 init_var(&pos_var);
10228 init_var(&neg_var);
10229
10230 pos_var.ndigits = neg_var.ndigits = accum->ndigits;
10231 pos_var.weight = neg_var.weight = accum->weight;
10232 pos_var.dscale = neg_var.dscale = accum->dscale;
10233 pos_var.sign = NUMERIC_POS;
10234 neg_var.sign = NUMERIC_NEG;
10235
10236 pos_var.buf = pos_var.digits = digitbuf_alloc(accum->ndigits);
10237 neg_var.buf = neg_var.digits = digitbuf_alloc(accum->ndigits);
10238
10239 for (i = 0; i < accum->ndigits; i++)
10240 {
10241 Assert(accum->pos_digits[i] < NBASE);
10242 pos_var.digits[i] = (int16) accum->pos_digits[i];
10243
10244 Assert(accum->neg_digits[i] < NBASE);
10245 neg_var.digits[i] = (int16) accum->neg_digits[i];
10246 }
10247
10248 /* And add them together */
10249 add_var(&pos_var, &neg_var, result);
10250
10251 /* Remove leading/trailing zeroes */
10252 strip_var(result);
10253 }
10254
10255 /*
10256 * Copy an accumulator's state.
10257 *
10258 * 'dst' is assumed to be uninitialized beforehand. No attempt is made at
10259 * freeing old values.
10260 */
10261 static void
accum_sum_copy(NumericSumAccum * dst,NumericSumAccum * src)10262 accum_sum_copy(NumericSumAccum *dst, NumericSumAccum *src)
10263 {
10264 dst->pos_digits = palloc(src->ndigits * sizeof(int32));
10265 dst->neg_digits = palloc(src->ndigits * sizeof(int32));
10266
10267 memcpy(dst->pos_digits, src->pos_digits, src->ndigits * sizeof(int32));
10268 memcpy(dst->neg_digits, src->neg_digits, src->ndigits * sizeof(int32));
10269 dst->num_uncarried = src->num_uncarried;
10270 dst->ndigits = src->ndigits;
10271 dst->weight = src->weight;
10272 dst->dscale = src->dscale;
10273 }
10274
10275 /*
10276 * Add the current value of 'accum2' into 'accum'.
10277 */
10278 static void
accum_sum_combine(NumericSumAccum * accum,NumericSumAccum * accum2)10279 accum_sum_combine(NumericSumAccum *accum, NumericSumAccum *accum2)
10280 {
10281 NumericVar tmp_var;
10282
10283 init_var(&tmp_var);
10284
10285 accum_sum_final(accum2, &tmp_var);
10286 accum_sum_add(accum, &tmp_var);
10287
10288 free_var(&tmp_var);
10289 }
10290