1package pgtype
2
3import (
4	"database/sql/driver"
5	"encoding/binary"
6	"math"
7	"math/big"
8	"strconv"
9	"strings"
10
11	"github.com/jackc/pgio"
12	errors "golang.org/x/xerrors"
13)
14
// PostgreSQL's internal numeric storage uses 16-bit "digits" in base 10,000.
const nbase = 10000

// Binary wire-format markers for NaN.
//
// PostgreSQL flags NaN with the 16-bit sign field 0xC000 (NUMERIC_NAN).
//
// pgNumericNaN is the complete 8-byte binary header of a NaN value
// (ndigits=0, weight=0, sign=0xC000, dscale=0) packed as one big-endian
// 64-bit integer.
//
// pgNumericNaNSign is that same sign field expressed as the value it takes
// when the field is read as a signed 16-bit integer, which is how
// DecodeBinary reads it: int16(0xC000) == -0x4000.
const (
	pgNumericNaN     = 0x00000000c0000000
	pgNumericNaNSign = -0x4000
)
22
23var big0 *big.Int = big.NewInt(0)
24var big1 *big.Int = big.NewInt(1)
25var big10 *big.Int = big.NewInt(10)
26var big100 *big.Int = big.NewInt(100)
27var big1000 *big.Int = big.NewInt(1000)
28
29var bigMaxInt8 *big.Int = big.NewInt(math.MaxInt8)
30var bigMinInt8 *big.Int = big.NewInt(math.MinInt8)
31var bigMaxInt16 *big.Int = big.NewInt(math.MaxInt16)
32var bigMinInt16 *big.Int = big.NewInt(math.MinInt16)
33var bigMaxInt32 *big.Int = big.NewInt(math.MaxInt32)
34var bigMinInt32 *big.Int = big.NewInt(math.MinInt32)
35var bigMaxInt64 *big.Int = big.NewInt(math.MaxInt64)
36var bigMinInt64 *big.Int = big.NewInt(math.MinInt64)
37var bigMaxInt *big.Int = big.NewInt(int64(maxInt))
38var bigMinInt *big.Int = big.NewInt(int64(minInt))
39
40var bigMaxUint8 *big.Int = big.NewInt(math.MaxUint8)
41var bigMaxUint16 *big.Int = big.NewInt(math.MaxUint16)
42var bigMaxUint32 *big.Int = big.NewInt(math.MaxUint32)
43var bigMaxUint64 *big.Int = (&big.Int{}).SetUint64(uint64(math.MaxUint64))
44var bigMaxUint *big.Int = (&big.Int{}).SetUint64(uint64(maxUint))
45
46var bigNBase *big.Int = big.NewInt(nbase)
47var bigNBaseX2 *big.Int = big.NewInt(nbase * nbase)
48var bigNBaseX3 *big.Int = big.NewInt(nbase * nbase * nbase)
49var bigNBaseX4 *big.Int = big.NewInt(nbase * nbase * nbase * nbase)
50
// Numeric holds a PostgreSQL numeric value as an arbitrary-precision integer
// coefficient scaled by a power of ten: the represented value is
// Int * 10^Exp.
type Numeric struct {
	// Int is the coefficient. The code in this file leaves it nil for Null and
	// NaN values.
	Int *big.Int
	// Exp is the base-10 exponent applied to Int. Negative values denote
	// fractional digits.
	Exp int32
	// Status tells whether the value is Present, Null or Undefined.
	Status Status
	// NaN marks a Present value as "not a number"; Int and Exp are then unused.
	NaN bool
}
57
58func (dst *Numeric) Set(src interface{}) error {
59	if src == nil {
60		*dst = Numeric{Status: Null}
61		return nil
62	}
63
64	if value, ok := src.(interface{ Get() interface{} }); ok {
65		value2 := value.Get()
66		if value2 != value {
67			return dst.Set(value2)
68		}
69	}
70
71	switch value := src.(type) {
72	case float32:
73		if math.IsNaN(float64(value)) {
74			*dst = Numeric{Status: Present, NaN: true}
75			return nil
76		}
77		num, exp, err := parseNumericString(strconv.FormatFloat(float64(value), 'f', -1, 64))
78		if err != nil {
79			return err
80		}
81		*dst = Numeric{Int: num, Exp: exp, Status: Present}
82	case float64:
83		if math.IsNaN(value) {
84			*dst = Numeric{Status: Present, NaN: true}
85			return nil
86		}
87		num, exp, err := parseNumericString(strconv.FormatFloat(value, 'f', -1, 64))
88		if err != nil {
89			return err
90		}
91		*dst = Numeric{Int: num, Exp: exp, Status: Present}
92	case int8:
93		*dst = Numeric{Int: big.NewInt(int64(value)), Status: Present}
94	case uint8:
95		*dst = Numeric{Int: big.NewInt(int64(value)), Status: Present}
96	case int16:
97		*dst = Numeric{Int: big.NewInt(int64(value)), Status: Present}
98	case uint16:
99		*dst = Numeric{Int: big.NewInt(int64(value)), Status: Present}
100	case int32:
101		*dst = Numeric{Int: big.NewInt(int64(value)), Status: Present}
102	case uint32:
103		*dst = Numeric{Int: big.NewInt(int64(value)), Status: Present}
104	case int64:
105		*dst = Numeric{Int: big.NewInt(value), Status: Present}
106	case uint64:
107		*dst = Numeric{Int: (&big.Int{}).SetUint64(value), Status: Present}
108	case int:
109		*dst = Numeric{Int: big.NewInt(int64(value)), Status: Present}
110	case uint:
111		*dst = Numeric{Int: (&big.Int{}).SetUint64(uint64(value)), Status: Present}
112	case string:
113		num, exp, err := parseNumericString(value)
114		if err != nil {
115			return err
116		}
117		*dst = Numeric{Int: num, Exp: exp, Status: Present}
118	case *float64:
119		if value == nil {
120			*dst = Numeric{Status: Null}
121		} else {
122			return dst.Set(*value)
123		}
124	case *float32:
125		if value == nil {
126			*dst = Numeric{Status: Null}
127		} else {
128			return dst.Set(*value)
129		}
130	case *int8:
131		if value == nil {
132			*dst = Numeric{Status: Null}
133		} else {
134			return dst.Set(*value)
135		}
136	case *uint8:
137		if value == nil {
138			*dst = Numeric{Status: Null}
139		} else {
140			return dst.Set(*value)
141		}
142	case *int16:
143		if value == nil {
144			*dst = Numeric{Status: Null}
145		} else {
146			return dst.Set(*value)
147		}
148	case *uint16:
149		if value == nil {
150			*dst = Numeric{Status: Null}
151		} else {
152			return dst.Set(*value)
153		}
154	case *int32:
155		if value == nil {
156			*dst = Numeric{Status: Null}
157		} else {
158			return dst.Set(*value)
159		}
160	case *uint32:
161		if value == nil {
162			*dst = Numeric{Status: Null}
163		} else {
164			return dst.Set(*value)
165		}
166	case *int64:
167		if value == nil {
168			*dst = Numeric{Status: Null}
169		} else {
170			return dst.Set(*value)
171		}
172	case *uint64:
173		if value == nil {
174			*dst = Numeric{Status: Null}
175		} else {
176			return dst.Set(*value)
177		}
178	case *int:
179		if value == nil {
180			*dst = Numeric{Status: Null}
181		} else {
182			return dst.Set(*value)
183		}
184	case *uint:
185		if value == nil {
186			*dst = Numeric{Status: Null}
187		} else {
188			return dst.Set(*value)
189		}
190	case *string:
191		if value == nil {
192			*dst = Numeric{Status: Null}
193		} else {
194			return dst.Set(*value)
195		}
196	default:
197		if originalSrc, ok := underlyingNumberType(src); ok {
198			return dst.Set(originalSrc)
199		}
200		return errors.Errorf("cannot convert %v to Numeric", value)
201	}
202
203	return nil
204}
205
206func (dst Numeric) Get() interface{} {
207	switch dst.Status {
208	case Present:
209		return dst
210	case Null:
211		return nil
212	default:
213		return dst.Status
214	}
215}
216
217func (src *Numeric) AssignTo(dst interface{}) error {
218	switch src.Status {
219	case Present:
220		switch v := dst.(type) {
221		case *float32:
222			f, err := src.toFloat64()
223			if err != nil {
224				return err
225			}
226			return float64AssignTo(f, src.Status, dst)
227		case *float64:
228			f, err := src.toFloat64()
229			if err != nil {
230				return err
231			}
232			return float64AssignTo(f, src.Status, dst)
233		case *int:
234			normalizedInt, err := src.toBigInt()
235			if err != nil {
236				return err
237			}
238			if normalizedInt.Cmp(bigMaxInt) > 0 {
239				return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v)
240			}
241			if normalizedInt.Cmp(bigMinInt) < 0 {
242				return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v)
243			}
244			*v = int(normalizedInt.Int64())
245		case *int8:
246			normalizedInt, err := src.toBigInt()
247			if err != nil {
248				return err
249			}
250			if normalizedInt.Cmp(bigMaxInt8) > 0 {
251				return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v)
252			}
253			if normalizedInt.Cmp(bigMinInt8) < 0 {
254				return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v)
255			}
256			*v = int8(normalizedInt.Int64())
257		case *int16:
258			normalizedInt, err := src.toBigInt()
259			if err != nil {
260				return err
261			}
262			if normalizedInt.Cmp(bigMaxInt16) > 0 {
263				return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v)
264			}
265			if normalizedInt.Cmp(bigMinInt16) < 0 {
266				return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v)
267			}
268			*v = int16(normalizedInt.Int64())
269		case *int32:
270			normalizedInt, err := src.toBigInt()
271			if err != nil {
272				return err
273			}
274			if normalizedInt.Cmp(bigMaxInt32) > 0 {
275				return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v)
276			}
277			if normalizedInt.Cmp(bigMinInt32) < 0 {
278				return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v)
279			}
280			*v = int32(normalizedInt.Int64())
281		case *int64:
282			normalizedInt, err := src.toBigInt()
283			if err != nil {
284				return err
285			}
286			if normalizedInt.Cmp(bigMaxInt64) > 0 {
287				return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v)
288			}
289			if normalizedInt.Cmp(bigMinInt64) < 0 {
290				return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v)
291			}
292			*v = normalizedInt.Int64()
293		case *uint:
294			normalizedInt, err := src.toBigInt()
295			if err != nil {
296				return err
297			}
298			if normalizedInt.Cmp(big0) < 0 {
299				return errors.Errorf("%d is less than zero for %T", normalizedInt, *v)
300			} else if normalizedInt.Cmp(bigMaxUint) > 0 {
301				return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v)
302			}
303			*v = uint(normalizedInt.Uint64())
304		case *uint8:
305			normalizedInt, err := src.toBigInt()
306			if err != nil {
307				return err
308			}
309			if normalizedInt.Cmp(big0) < 0 {
310				return errors.Errorf("%d is less than zero for %T", normalizedInt, *v)
311			} else if normalizedInt.Cmp(bigMaxUint8) > 0 {
312				return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v)
313			}
314			*v = uint8(normalizedInt.Uint64())
315		case *uint16:
316			normalizedInt, err := src.toBigInt()
317			if err != nil {
318				return err
319			}
320			if normalizedInt.Cmp(big0) < 0 {
321				return errors.Errorf("%d is less than zero for %T", normalizedInt, *v)
322			} else if normalizedInt.Cmp(bigMaxUint16) > 0 {
323				return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v)
324			}
325			*v = uint16(normalizedInt.Uint64())
326		case *uint32:
327			normalizedInt, err := src.toBigInt()
328			if err != nil {
329				return err
330			}
331			if normalizedInt.Cmp(big0) < 0 {
332				return errors.Errorf("%d is less than zero for %T", normalizedInt, *v)
333			} else if normalizedInt.Cmp(bigMaxUint32) > 0 {
334				return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v)
335			}
336			*v = uint32(normalizedInt.Uint64())
337		case *uint64:
338			normalizedInt, err := src.toBigInt()
339			if err != nil {
340				return err
341			}
342			if normalizedInt.Cmp(big0) < 0 {
343				return errors.Errorf("%d is less than zero for %T", normalizedInt, *v)
344			} else if normalizedInt.Cmp(bigMaxUint64) > 0 {
345				return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v)
346			}
347			*v = normalizedInt.Uint64()
348		default:
349			if nextDst, retry := GetAssignToDstType(dst); retry {
350				return src.AssignTo(nextDst)
351			}
352			return errors.Errorf("unable to assign to %T", dst)
353		}
354	case Null:
355		return NullAssignTo(dst)
356	}
357
358	return nil
359}
360
361func (dst *Numeric) toBigInt() (*big.Int, error) {
362	if dst.Exp == 0 {
363		return dst.Int, nil
364	}
365
366	num := &big.Int{}
367	num.Set(dst.Int)
368	if dst.Exp > 0 {
369		mul := &big.Int{}
370		mul.Exp(big10, big.NewInt(int64(dst.Exp)), nil)
371		num.Mul(num, mul)
372		return num, nil
373	}
374
375	div := &big.Int{}
376	div.Exp(big10, big.NewInt(int64(-dst.Exp)), nil)
377	remainder := &big.Int{}
378	num.DivMod(num, div, remainder)
379	if remainder.Cmp(big0) != 0 {
380		return nil, errors.Errorf("cannot convert %v to integer", dst)
381	}
382	return num, nil
383}
384
385func (src *Numeric) toFloat64() (float64, error) {
386	if src.NaN {
387		return math.NaN(), nil
388	}
389
390	buf := make([]byte, 0, 32)
391
392	buf = append(buf, src.Int.String()...)
393	buf = append(buf, 'e')
394	buf = append(buf, strconv.FormatInt(int64(src.Exp), 10)...)
395
396	f, err := strconv.ParseFloat(string(buf), 64)
397	if err != nil {
398		return 0, err
399	}
400	return f, nil
401}
402
403func (dst *Numeric) DecodeText(ci *ConnInfo, src []byte) error {
404	if src == nil {
405		*dst = Numeric{Status: Null}
406		return nil
407	}
408
409	if string(src) == "'NaN'" { // includes single quotes, see EncodeText for details.
410		*dst = Numeric{Status: Present, NaN: true}
411		return nil
412	}
413
414	num, exp, err := parseNumericString(string(src))
415	if err != nil {
416		return err
417	}
418
419	*dst = Numeric{Int: num, Exp: exp, Status: Present}
420	return nil
421}
422
423func parseNumericString(str string) (n *big.Int, exp int32, err error) {
424	parts := strings.SplitN(str, ".", 2)
425	digits := strings.Join(parts, "")
426
427	if len(parts) > 1 {
428		exp = int32(-len(parts[1]))
429	} else {
430		for len(digits) > 1 && digits[len(digits)-1] == '0' && digits[len(digits)-2] != '-' {
431			digits = digits[:len(digits)-1]
432			exp++
433		}
434	}
435
436	accum := &big.Int{}
437	if _, ok := accum.SetString(digits, 10); !ok {
438		return nil, 0, errors.Errorf("%s is not a number", str)
439	}
440
441	return accum, exp, nil
442}
443
444func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error {
445	if src == nil {
446		*dst = Numeric{Status: Null}
447		return nil
448	}
449
450	if len(src) < 8 {
451		return errors.Errorf("numeric incomplete %v", src)
452	}
453
454	rp := 0
455	ndigits := int16(binary.BigEndian.Uint16(src[rp:]))
456	rp += 2
457	weight := int16(binary.BigEndian.Uint16(src[rp:]))
458	rp += 2
459	sign := int16(binary.BigEndian.Uint16(src[rp:]))
460	rp += 2
461	dscale := int16(binary.BigEndian.Uint16(src[rp:]))
462	rp += 2
463
464	if sign == pgNumericNaNSign {
465		*dst = Numeric{Status: Present, NaN: true}
466		return nil
467	}
468
469	if ndigits == 0 {
470		*dst = Numeric{Int: big.NewInt(0), Status: Present}
471		return nil
472	}
473
474	if len(src[rp:]) < int(ndigits)*2 {
475		return errors.Errorf("numeric incomplete %v", src)
476	}
477
478	accum := &big.Int{}
479
480	for i := 0; i < int(ndigits+3)/4; i++ {
481		int64accum, bytesRead, digitsRead := nbaseDigitsToInt64(src[rp:])
482		rp += bytesRead
483
484		if i > 0 {
485			var mul *big.Int
486			switch digitsRead {
487			case 1:
488				mul = bigNBase
489			case 2:
490				mul = bigNBaseX2
491			case 3:
492				mul = bigNBaseX3
493			case 4:
494				mul = bigNBaseX4
495			default:
496				return errors.Errorf("invalid digitsRead: %d (this can't happen)", digitsRead)
497			}
498			accum.Mul(accum, mul)
499		}
500
501		accum.Add(accum, big.NewInt(int64accum))
502	}
503
504	exp := (int32(weight) - int32(ndigits) + 1) * 4
505
506	if dscale > 0 {
507		fracNBaseDigits := ndigits - weight - 1
508		fracDecimalDigits := fracNBaseDigits * 4
509
510		if dscale > fracDecimalDigits {
511			multCount := int(dscale - fracDecimalDigits)
512			for i := 0; i < multCount; i++ {
513				accum.Mul(accum, big10)
514				exp--
515			}
516		} else if dscale < fracDecimalDigits {
517			divCount := int(fracDecimalDigits - dscale)
518			for i := 0; i < divCount; i++ {
519				accum.Div(accum, big10)
520				exp++
521			}
522		}
523	}
524
525	reduced := &big.Int{}
526	remainder := &big.Int{}
527	if exp >= 0 {
528		for {
529			reduced.DivMod(accum, big10, remainder)
530			if remainder.Cmp(big0) != 0 {
531				break
532			}
533			accum.Set(reduced)
534			exp++
535		}
536	}
537
538	if sign != 0 {
539		accum.Neg(accum)
540	}
541
542	*dst = Numeric{Int: accum, Exp: exp, Status: Present}
543
544	return nil
545
546}
547
548func nbaseDigitsToInt64(src []byte) (accum int64, bytesRead, digitsRead int) {
549	digits := len(src) / 2
550	if digits > 4 {
551		digits = 4
552	}
553
554	rp := 0
555
556	for i := 0; i < digits; i++ {
557		if i > 0 {
558			accum *= nbase
559		}
560		accum += int64(binary.BigEndian.Uint16(src[rp:]))
561		rp += 2
562	}
563
564	return accum, rp, digits
565}
566
567func (src Numeric) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
568	switch src.Status {
569	case Null:
570		return nil, nil
571	case Undefined:
572		return nil, errUndefined
573	}
574
575	if src.NaN {
576		// encode as 'NaN' including single quotes,
577		// "When writing this value [NaN] as a constant in an SQL command,
578		// you must put quotes around it, for example UPDATE table SET x = 'NaN'"
579		// https://www.postgresql.org/docs/9.3/datatype-numeric.html
580		buf = append(buf, "'NaN'"...)
581		return buf, nil
582	}
583
584	buf = append(buf, src.Int.String()...)
585	buf = append(buf, 'e')
586	buf = append(buf, strconv.FormatInt(int64(src.Exp), 10)...)
587	return buf, nil
588}
589
590func (src Numeric) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
591	switch src.Status {
592	case Null:
593		return nil, nil
594	case Undefined:
595		return nil, errUndefined
596	}
597
598	if src.NaN {
599		buf = pgio.AppendUint64(buf, pgNumericNaN)
600		return buf, nil
601	}
602
603	var sign int16
604	if src.Int.Cmp(big0) < 0 {
605		sign = 16384
606	}
607
608	absInt := &big.Int{}
609	wholePart := &big.Int{}
610	fracPart := &big.Int{}
611	remainder := &big.Int{}
612	absInt.Abs(src.Int)
613
614	// Normalize absInt and exp to where exp is always a multiple of 4. This makes
615	// converting to 16-bit base 10,000 digits easier.
616	var exp int32
617	switch src.Exp % 4 {
618	case 1, -3:
619		exp = src.Exp - 1
620		absInt.Mul(absInt, big10)
621	case 2, -2:
622		exp = src.Exp - 2
623		absInt.Mul(absInt, big100)
624	case 3, -1:
625		exp = src.Exp - 3
626		absInt.Mul(absInt, big1000)
627	default:
628		exp = src.Exp
629	}
630
631	if exp < 0 {
632		divisor := &big.Int{}
633		divisor.Exp(big10, big.NewInt(int64(-exp)), nil)
634		wholePart.DivMod(absInt, divisor, fracPart)
635		fracPart.Add(fracPart, divisor)
636	} else {
637		wholePart = absInt
638	}
639
640	var wholeDigits, fracDigits []int16
641
642	for wholePart.Cmp(big0) != 0 {
643		wholePart.DivMod(wholePart, bigNBase, remainder)
644		wholeDigits = append(wholeDigits, int16(remainder.Int64()))
645	}
646
647	if fracPart.Cmp(big0) != 0 {
648		for fracPart.Cmp(big1) != 0 {
649			fracPart.DivMod(fracPart, bigNBase, remainder)
650			fracDigits = append(fracDigits, int16(remainder.Int64()))
651		}
652	}
653
654	buf = pgio.AppendInt16(buf, int16(len(wholeDigits)+len(fracDigits)))
655
656	var weight int16
657	if len(wholeDigits) > 0 {
658		weight = int16(len(wholeDigits) - 1)
659		if exp > 0 {
660			weight += int16(exp / 4)
661		}
662	} else {
663		weight = int16(exp/4) - 1 + int16(len(fracDigits))
664	}
665	buf = pgio.AppendInt16(buf, weight)
666
667	buf = pgio.AppendInt16(buf, sign)
668
669	var dscale int16
670	if src.Exp < 0 {
671		dscale = int16(-src.Exp)
672	}
673	buf = pgio.AppendInt16(buf, dscale)
674
675	for i := len(wholeDigits) - 1; i >= 0; i-- {
676		buf = pgio.AppendInt16(buf, wholeDigits[i])
677	}
678
679	for i := len(fracDigits) - 1; i >= 0; i-- {
680		buf = pgio.AppendInt16(buf, fracDigits[i])
681	}
682
683	return buf, nil
684}
685
686// Scan implements the database/sql Scanner interface.
687func (dst *Numeric) Scan(src interface{}) error {
688	if src == nil {
689		*dst = Numeric{Status: Null}
690		return nil
691	}
692
693	switch src := src.(type) {
694	case string:
695		return dst.DecodeText(nil, []byte(src))
696	case []byte:
697		srcCopy := make([]byte, len(src))
698		copy(srcCopy, src)
699		return dst.DecodeText(nil, srcCopy)
700	}
701
702	return errors.Errorf("cannot scan %T", src)
703}
704
705// Value implements the database/sql/driver Valuer interface.
706func (src Numeric) Value() (driver.Value, error) {
707	switch src.Status {
708	case Present:
709		buf, err := src.EncodeText(nil, nil)
710		if err != nil {
711			return nil, err
712		}
713
714		return string(buf), nil
715	case Null:
716		return nil, nil
717	default:
718		return nil, errUndefined
719	}
720}
721