1package pgtype
2
3import (
4	"database/sql/driver"
5	"encoding/binary"
6	"fmt"
7	"math"
8	"math/big"
9	"strconv"
10	"strings"
11
12	"github.com/jackc/pgio"
13)
14
// Parameters of PostgreSQL's numeric storage and binary wire format.
const (
	// nbase is the radix of PostgreSQL's internal numeric representation:
	// each 16-bit "digit" holds a value from 0 to 9999.
	nbase = 10000

	// pgNumericNaN is the complete 8-byte binary header emitted for NaN:
	// ndigits=0, weight=0, sign=pgNumericNaNSign, dscale=0, and no digits.
	pgNumericNaN = 0x00000000c0000000

	// pgNumericNaNSign is the value of the header's sign field that marks NaN.
	pgNumericNaNSign = 0xc000
)
22
23var big0 *big.Int = big.NewInt(0)
24var big1 *big.Int = big.NewInt(1)
25var big10 *big.Int = big.NewInt(10)
26var big100 *big.Int = big.NewInt(100)
27var big1000 *big.Int = big.NewInt(1000)
28
29var bigMaxInt8 *big.Int = big.NewInt(math.MaxInt8)
30var bigMinInt8 *big.Int = big.NewInt(math.MinInt8)
31var bigMaxInt16 *big.Int = big.NewInt(math.MaxInt16)
32var bigMinInt16 *big.Int = big.NewInt(math.MinInt16)
33var bigMaxInt32 *big.Int = big.NewInt(math.MaxInt32)
34var bigMinInt32 *big.Int = big.NewInt(math.MinInt32)
35var bigMaxInt64 *big.Int = big.NewInt(math.MaxInt64)
36var bigMinInt64 *big.Int = big.NewInt(math.MinInt64)
37var bigMaxInt *big.Int = big.NewInt(int64(maxInt))
38var bigMinInt *big.Int = big.NewInt(int64(minInt))
39
40var bigMaxUint8 *big.Int = big.NewInt(math.MaxUint8)
41var bigMaxUint16 *big.Int = big.NewInt(math.MaxUint16)
42var bigMaxUint32 *big.Int = big.NewInt(math.MaxUint32)
43var bigMaxUint64 *big.Int = (&big.Int{}).SetUint64(uint64(math.MaxUint64))
44var bigMaxUint *big.Int = (&big.Int{}).SetUint64(uint64(maxUint))
45
46var bigNBase *big.Int = big.NewInt(nbase)
47var bigNBaseX2 *big.Int = big.NewInt(nbase * nbase)
48var bigNBaseX3 *big.Int = big.NewInt(nbase * nbase * nbase)
49var bigNBaseX4 *big.Int = big.NewInt(nbase * nbase * nbase * nbase)
50
// Numeric represents a PostgreSQL numeric value.
//
// A non-NaN value equals Int * 10^Exp; EncodeText and toFloat64 both render it
// as "<Int>e<Exp>". When NaN is true the value is PostgreSQL's numeric NaN: Set,
// DecodeText and DecodeBinary leave Int nil and Exp zero in that case, and the
// encoders check NaN before looking at Int.
type Numeric struct {
	Int    *big.Int // unscaled value, sign included
	Exp    int32    // decimal exponent applied to Int
	Status Status   // Present, Null or Undefined
	NaN    bool     // numeric NaN (see type comment)
}
57
58func (dst *Numeric) Set(src interface{}) error {
59	if src == nil {
60		*dst = Numeric{Status: Null}
61		return nil
62	}
63
64	if value, ok := src.(interface{ Get() interface{} }); ok {
65		value2 := value.Get()
66		if value2 != value {
67			return dst.Set(value2)
68		}
69	}
70
71	switch value := src.(type) {
72	case float32:
73		if math.IsNaN(float64(value)) {
74			*dst = Numeric{Status: Present, NaN: true}
75			return nil
76		}
77		num, exp, err := parseNumericString(strconv.FormatFloat(float64(value), 'f', -1, 64))
78		if err != nil {
79			return err
80		}
81		*dst = Numeric{Int: num, Exp: exp, Status: Present}
82	case float64:
83		if math.IsNaN(value) {
84			*dst = Numeric{Status: Present, NaN: true}
85			return nil
86		}
87		num, exp, err := parseNumericString(strconv.FormatFloat(value, 'f', -1, 64))
88		if err != nil {
89			return err
90		}
91		*dst = Numeric{Int: num, Exp: exp, Status: Present}
92	case int8:
93		*dst = Numeric{Int: big.NewInt(int64(value)), Status: Present}
94	case uint8:
95		*dst = Numeric{Int: big.NewInt(int64(value)), Status: Present}
96	case int16:
97		*dst = Numeric{Int: big.NewInt(int64(value)), Status: Present}
98	case uint16:
99		*dst = Numeric{Int: big.NewInt(int64(value)), Status: Present}
100	case int32:
101		*dst = Numeric{Int: big.NewInt(int64(value)), Status: Present}
102	case uint32:
103		*dst = Numeric{Int: big.NewInt(int64(value)), Status: Present}
104	case int64:
105		*dst = Numeric{Int: big.NewInt(value), Status: Present}
106	case uint64:
107		*dst = Numeric{Int: (&big.Int{}).SetUint64(value), Status: Present}
108	case int:
109		*dst = Numeric{Int: big.NewInt(int64(value)), Status: Present}
110	case uint:
111		*dst = Numeric{Int: (&big.Int{}).SetUint64(uint64(value)), Status: Present}
112	case string:
113		num, exp, err := parseNumericString(value)
114		if err != nil {
115			return err
116		}
117		*dst = Numeric{Int: num, Exp: exp, Status: Present}
118	case *float64:
119		if value == nil {
120			*dst = Numeric{Status: Null}
121		} else {
122			return dst.Set(*value)
123		}
124	case *float32:
125		if value == nil {
126			*dst = Numeric{Status: Null}
127		} else {
128			return dst.Set(*value)
129		}
130	case *int8:
131		if value == nil {
132			*dst = Numeric{Status: Null}
133		} else {
134			return dst.Set(*value)
135		}
136	case *uint8:
137		if value == nil {
138			*dst = Numeric{Status: Null}
139		} else {
140			return dst.Set(*value)
141		}
142	case *int16:
143		if value == nil {
144			*dst = Numeric{Status: Null}
145		} else {
146			return dst.Set(*value)
147		}
148	case *uint16:
149		if value == nil {
150			*dst = Numeric{Status: Null}
151		} else {
152			return dst.Set(*value)
153		}
154	case *int32:
155		if value == nil {
156			*dst = Numeric{Status: Null}
157		} else {
158			return dst.Set(*value)
159		}
160	case *uint32:
161		if value == nil {
162			*dst = Numeric{Status: Null}
163		} else {
164			return dst.Set(*value)
165		}
166	case *int64:
167		if value == nil {
168			*dst = Numeric{Status: Null}
169		} else {
170			return dst.Set(*value)
171		}
172	case *uint64:
173		if value == nil {
174			*dst = Numeric{Status: Null}
175		} else {
176			return dst.Set(*value)
177		}
178	case *int:
179		if value == nil {
180			*dst = Numeric{Status: Null}
181		} else {
182			return dst.Set(*value)
183		}
184	case *uint:
185		if value == nil {
186			*dst = Numeric{Status: Null}
187		} else {
188			return dst.Set(*value)
189		}
190	case *string:
191		if value == nil {
192			*dst = Numeric{Status: Null}
193		} else {
194			return dst.Set(*value)
195		}
196	default:
197		if originalSrc, ok := underlyingNumberType(src); ok {
198			return dst.Set(originalSrc)
199		}
200		return fmt.Errorf("cannot convert %v to Numeric", value)
201	}
202
203	return nil
204}
205
206func (dst Numeric) Get() interface{} {
207	switch dst.Status {
208	case Present:
209		return dst
210	case Null:
211		return nil
212	default:
213		return dst.Status
214	}
215}
216
217func (src *Numeric) AssignTo(dst interface{}) error {
218	switch src.Status {
219	case Present:
220		switch v := dst.(type) {
221		case *float32:
222			f, err := src.toFloat64()
223			if err != nil {
224				return err
225			}
226			return float64AssignTo(f, src.Status, dst)
227		case *float64:
228			f, err := src.toFloat64()
229			if err != nil {
230				return err
231			}
232			return float64AssignTo(f, src.Status, dst)
233		case *int:
234			normalizedInt, err := src.toBigInt()
235			if err != nil {
236				return err
237			}
238			if normalizedInt.Cmp(bigMaxInt) > 0 {
239				return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v)
240			}
241			if normalizedInt.Cmp(bigMinInt) < 0 {
242				return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v)
243			}
244			*v = int(normalizedInt.Int64())
245		case *int8:
246			normalizedInt, err := src.toBigInt()
247			if err != nil {
248				return err
249			}
250			if normalizedInt.Cmp(bigMaxInt8) > 0 {
251				return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v)
252			}
253			if normalizedInt.Cmp(bigMinInt8) < 0 {
254				return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v)
255			}
256			*v = int8(normalizedInt.Int64())
257		case *int16:
258			normalizedInt, err := src.toBigInt()
259			if err != nil {
260				return err
261			}
262			if normalizedInt.Cmp(bigMaxInt16) > 0 {
263				return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v)
264			}
265			if normalizedInt.Cmp(bigMinInt16) < 0 {
266				return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v)
267			}
268			*v = int16(normalizedInt.Int64())
269		case *int32:
270			normalizedInt, err := src.toBigInt()
271			if err != nil {
272				return err
273			}
274			if normalizedInt.Cmp(bigMaxInt32) > 0 {
275				return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v)
276			}
277			if normalizedInt.Cmp(bigMinInt32) < 0 {
278				return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v)
279			}
280			*v = int32(normalizedInt.Int64())
281		case *int64:
282			normalizedInt, err := src.toBigInt()
283			if err != nil {
284				return err
285			}
286			if normalizedInt.Cmp(bigMaxInt64) > 0 {
287				return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v)
288			}
289			if normalizedInt.Cmp(bigMinInt64) < 0 {
290				return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v)
291			}
292			*v = normalizedInt.Int64()
293		case *uint:
294			normalizedInt, err := src.toBigInt()
295			if err != nil {
296				return err
297			}
298			if normalizedInt.Cmp(big0) < 0 {
299				return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v)
300			} else if normalizedInt.Cmp(bigMaxUint) > 0 {
301				return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v)
302			}
303			*v = uint(normalizedInt.Uint64())
304		case *uint8:
305			normalizedInt, err := src.toBigInt()
306			if err != nil {
307				return err
308			}
309			if normalizedInt.Cmp(big0) < 0 {
310				return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v)
311			} else if normalizedInt.Cmp(bigMaxUint8) > 0 {
312				return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v)
313			}
314			*v = uint8(normalizedInt.Uint64())
315		case *uint16:
316			normalizedInt, err := src.toBigInt()
317			if err != nil {
318				return err
319			}
320			if normalizedInt.Cmp(big0) < 0 {
321				return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v)
322			} else if normalizedInt.Cmp(bigMaxUint16) > 0 {
323				return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v)
324			}
325			*v = uint16(normalizedInt.Uint64())
326		case *uint32:
327			normalizedInt, err := src.toBigInt()
328			if err != nil {
329				return err
330			}
331			if normalizedInt.Cmp(big0) < 0 {
332				return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v)
333			} else if normalizedInt.Cmp(bigMaxUint32) > 0 {
334				return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v)
335			}
336			*v = uint32(normalizedInt.Uint64())
337		case *uint64:
338			normalizedInt, err := src.toBigInt()
339			if err != nil {
340				return err
341			}
342			if normalizedInt.Cmp(big0) < 0 {
343				return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v)
344			} else if normalizedInt.Cmp(bigMaxUint64) > 0 {
345				return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v)
346			}
347			*v = normalizedInt.Uint64()
348		default:
349			if nextDst, retry := GetAssignToDstType(dst); retry {
350				return src.AssignTo(nextDst)
351			}
352			return fmt.Errorf("unable to assign to %T", dst)
353		}
354	case Null:
355		return NullAssignTo(dst)
356	}
357
358	return nil
359}
360
361func (dst *Numeric) toBigInt() (*big.Int, error) {
362	if dst.Exp == 0 {
363		return dst.Int, nil
364	}
365
366	num := &big.Int{}
367	num.Set(dst.Int)
368	if dst.Exp > 0 {
369		mul := &big.Int{}
370		mul.Exp(big10, big.NewInt(int64(dst.Exp)), nil)
371		num.Mul(num, mul)
372		return num, nil
373	}
374
375	div := &big.Int{}
376	div.Exp(big10, big.NewInt(int64(-dst.Exp)), nil)
377	remainder := &big.Int{}
378	num.DivMod(num, div, remainder)
379	if remainder.Cmp(big0) != 0 {
380		return nil, fmt.Errorf("cannot convert %v to integer", dst)
381	}
382	return num, nil
383}
384
385func (src *Numeric) toFloat64() (float64, error) {
386	if src.NaN {
387		return math.NaN(), nil
388	}
389
390	buf := make([]byte, 0, 32)
391
392	buf = append(buf, src.Int.String()...)
393	buf = append(buf, 'e')
394	buf = append(buf, strconv.FormatInt(int64(src.Exp), 10)...)
395
396	f, err := strconv.ParseFloat(string(buf), 64)
397	if err != nil {
398		return 0, err
399	}
400	return f, nil
401}
402
403func (dst *Numeric) DecodeText(ci *ConnInfo, src []byte) error {
404	if src == nil {
405		*dst = Numeric{Status: Null}
406		return nil
407	}
408
409	if string(src) == "NaN" {
410		*dst = Numeric{Status: Present, NaN: true}
411		return nil
412	}
413
414	num, exp, err := parseNumericString(string(src))
415	if err != nil {
416		return err
417	}
418
419	*dst = Numeric{Int: num, Exp: exp, Status: Present}
420	return nil
421}
422
// parseNumericString parses a decimal string into an unscaled integer n and a
// power-of-ten exponent exp such that the value is n * 10^exp.
//
// The input is an optional sign, digits, an optional fractional part after a
// '.', and an optional exponent introduced by 'e' or 'E'. EncodeText (and so
// Value) produces the exponent form, e.g. "12345e-2", so it must parse back.
//
// Digits after the '.' are kept as written, so "1.50" yields (150, -2). An
// integer with no fractional part has its trailing zeros folded into exp, so
// "1500" yields (15, 2).
//
// An error is returned if str is not a number or if the resulting exponent
// does not fit in an int32.
func parseNumericString(str string) (n *big.Int, exp int32, err error) {
	mantissa := str
	var expPart int64
	if i := strings.IndexAny(str, "eE"); i >= 0 {
		if expPart, err = strconv.ParseInt(str[i+1:], 10, 32); err != nil {
			return nil, 0, fmt.Errorf("%s is not a number", str)
		}
		mantissa = str[:i]
	}

	parts := strings.SplitN(mantissa, ".", 2)
	digits := strings.Join(parts, "")

	var scale int64
	if len(parts) > 1 {
		scale = -int64(len(parts[1]))
	} else {
		// Fold trailing zeros of an integer into the exponent. Stop when the
		// preceding byte is not a digit, so the lone digit in "-0" or "+0" is
		// never stripped away from its sign.
		for len(digits) > 1 && digits[len(digits)-1] == '0' {
			if prev := digits[len(digits)-2]; prev < '0' || prev > '9' {
				break
			}
			digits = digits[:len(digits)-1]
			scale++
		}
	}

	accum := &big.Int{}
	if _, ok := accum.SetString(digits, 10); !ok {
		return nil, 0, fmt.Errorf("%s is not a number", str)
	}

	total := scale + expPart
	if total < math.MinInt32 || total > math.MaxInt32 {
		return nil, 0, fmt.Errorf("exponent of %s is out of range", str)
	}
	return accum, int32(total), nil
}
443
444func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error {
445	if src == nil {
446		*dst = Numeric{Status: Null}
447		return nil
448	}
449
450	if len(src) < 8 {
451		return fmt.Errorf("numeric incomplete %v", src)
452	}
453
454	rp := 0
455	ndigits := int16(binary.BigEndian.Uint16(src[rp:]))
456	rp += 2
457	weight := int16(binary.BigEndian.Uint16(src[rp:]))
458	rp += 2
459	sign := uint16(binary.BigEndian.Uint16(src[rp:]))
460	rp += 2
461	dscale := int16(binary.BigEndian.Uint16(src[rp:]))
462	rp += 2
463
464	if sign == pgNumericNaNSign {
465		*dst = Numeric{Status: Present, NaN: true}
466		return nil
467	}
468
469	if ndigits == 0 {
470		*dst = Numeric{Int: big.NewInt(0), Status: Present}
471		return nil
472	}
473
474	if len(src[rp:]) < int(ndigits)*2 {
475		return fmt.Errorf("numeric incomplete %v", src)
476	}
477
478	accum := &big.Int{}
479
480	for i := 0; i < int(ndigits+3)/4; i++ {
481		int64accum, bytesRead, digitsRead := nbaseDigitsToInt64(src[rp:])
482		rp += bytesRead
483
484		if i > 0 {
485			var mul *big.Int
486			switch digitsRead {
487			case 1:
488				mul = bigNBase
489			case 2:
490				mul = bigNBaseX2
491			case 3:
492				mul = bigNBaseX3
493			case 4:
494				mul = bigNBaseX4
495			default:
496				return fmt.Errorf("invalid digitsRead: %d (this can't happen)", digitsRead)
497			}
498			accum.Mul(accum, mul)
499		}
500
501		accum.Add(accum, big.NewInt(int64accum))
502	}
503
504	exp := (int32(weight) - int32(ndigits) + 1) * 4
505
506	if dscale > 0 {
507		fracNBaseDigits := ndigits - weight - 1
508		fracDecimalDigits := fracNBaseDigits * 4
509
510		if dscale > fracDecimalDigits {
511			multCount := int(dscale - fracDecimalDigits)
512			for i := 0; i < multCount; i++ {
513				accum.Mul(accum, big10)
514				exp--
515			}
516		} else if dscale < fracDecimalDigits {
517			divCount := int(fracDecimalDigits - dscale)
518			for i := 0; i < divCount; i++ {
519				accum.Div(accum, big10)
520				exp++
521			}
522		}
523	}
524
525	reduced := &big.Int{}
526	remainder := &big.Int{}
527	if exp >= 0 {
528		for {
529			reduced.DivMod(accum, big10, remainder)
530			if remainder.Cmp(big0) != 0 {
531				break
532			}
533			accum.Set(reduced)
534			exp++
535		}
536	}
537
538	if sign != 0 {
539		accum.Neg(accum)
540	}
541
542	*dst = Numeric{Int: accum, Exp: exp, Status: Present}
543
544	return nil
545
546}
547
548func nbaseDigitsToInt64(src []byte) (accum int64, bytesRead, digitsRead int) {
549	digits := len(src) / 2
550	if digits > 4 {
551		digits = 4
552	}
553
554	rp := 0
555
556	for i := 0; i < digits; i++ {
557		if i > 0 {
558			accum *= nbase
559		}
560		accum += int64(binary.BigEndian.Uint16(src[rp:]))
561		rp += 2
562	}
563
564	return accum, rp, digits
565}
566
567func (src Numeric) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
568	switch src.Status {
569	case Null:
570		return nil, nil
571	case Undefined:
572		return nil, errUndefined
573	}
574
575	if src.NaN {
576		buf = append(buf, "NaN"...)
577		return buf, nil
578	}
579
580	buf = append(buf, src.Int.String()...)
581	buf = append(buf, 'e')
582	buf = append(buf, strconv.FormatInt(int64(src.Exp), 10)...)
583	return buf, nil
584}
585
// EncodeBinary appends the PostgreSQL binary numeric representation of src to
// buf. Null encodes as nil and Undefined returns errUndefined.
//
// The output is a header of four 16-bit fields (ndigits, weight, sign, dscale,
// in the order DecodeBinary reads them), followed by ndigits base-10000
// digits, most significant first. NaN is written as the fixed 8-byte header
// pgNumericNaN with no digits.
func (src Numeric) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
	switch src.Status {
	case Null:
		return nil, nil
	case Undefined:
		return nil, errUndefined
	}

	if src.NaN {
		buf = pgio.AppendUint64(buf, pgNumericNaN)
		return buf, nil
	}

	// 16384 (0x4000) is the sign word written for negative values and 0 for
	// non-negative ones. DecodeBinary treats any non-zero, non-NaN sign as
	// negative.
	var sign int16
	if src.Int.Cmp(big0) < 0 {
		sign = 16384
	}

	absInt := &big.Int{}
	wholePart := &big.Int{}
	fracPart := &big.Int{}
	remainder := &big.Int{}
	absInt.Abs(src.Int)

	// Normalize absInt and exp to where exp is always a multiple of 4. This makes
	// converting to 16-bit base 10,000 digits easier.
	// Go's % truncates toward zero, so negative exponents produce the negative
	// remainders listed alongside the positive ones.
	var exp int32
	switch src.Exp % 4 {
	case 1, -3:
		exp = src.Exp - 1
		absInt.Mul(absInt, big10)
	case 2, -2:
		exp = src.Exp - 2
		absInt.Mul(absInt, big100)
	case 3, -1:
		exp = src.Exp - 3
		absInt.Mul(absInt, big1000)
	default:
		exp = src.Exp
	}

	// Split absInt * 10^exp into whole and fractional parts. Adding divisor to
	// fracPart plants a leading 1 as a sentinel, so leading zeros of the
	// fraction survive the base-10000 conversion below.
	if exp < 0 {
		divisor := &big.Int{}
		divisor.Exp(big10, big.NewInt(int64(-exp)), nil)
		wholePart.DivMod(absInt, divisor, fracPart)
		fracPart.Add(fracPart, divisor)
	} else {
		wholePart = absInt
	}

	// Both slices collect base-10000 digits least significant first; they are
	// written back to front at the end of the function.
	var wholeDigits, fracDigits []int16

	for wholePart.Cmp(big0) != 0 {
		wholePart.DivMod(wholePart, bigNBase, remainder)
		wholeDigits = append(wholeDigits, int16(remainder.Int64()))
	}

	// Extract fractional digits until only the sentinel 1 is left.
	if fracPart.Cmp(big0) != 0 {
		for fracPart.Cmp(big1) != 0 {
			fracPart.DivMod(fracPart, bigNBase, remainder)
			fracDigits = append(fracDigits, int16(remainder.Int64()))
		}
	}

	buf = pgio.AppendInt16(buf, int16(len(wholeDigits)+len(fracDigits)))

	// weight is the base-10000 exponent of the first digit written. A positive
	// exp contributes exp/4 extra positions to the right of the whole digits.
	var weight int16
	if len(wholeDigits) > 0 {
		weight = int16(len(wholeDigits) - 1)
		if exp > 0 {
			weight += int16(exp / 4)
		}
	} else {
		weight = int16(exp/4) - 1 + int16(len(fracDigits))
	}
	buf = pgio.AppendInt16(buf, weight)

	buf = pgio.AppendInt16(buf, sign)

	// dscale (decimal digits after the point) comes from the unnormalized
	// src.Exp, so it reflects the scale the caller supplied.
	var dscale int16
	if src.Exp < 0 {
		dscale = int16(-src.Exp)
	}
	buf = pgio.AppendInt16(buf, dscale)

	for i := len(wholeDigits) - 1; i >= 0; i-- {
		buf = pgio.AppendInt16(buf, wholeDigits[i])
	}

	for i := len(fracDigits) - 1; i >= 0; i-- {
		buf = pgio.AppendInt16(buf, fracDigits[i])
	}

	return buf, nil
}
681
682// Scan implements the database/sql Scanner interface.
683func (dst *Numeric) Scan(src interface{}) error {
684	if src == nil {
685		*dst = Numeric{Status: Null}
686		return nil
687	}
688
689	switch src := src.(type) {
690	case string:
691		return dst.DecodeText(nil, []byte(src))
692	case []byte:
693		srcCopy := make([]byte, len(src))
694		copy(srcCopy, src)
695		return dst.DecodeText(nil, srcCopy)
696	}
697
698	return fmt.Errorf("cannot scan %T", src)
699}
700
701// Value implements the database/sql/driver Valuer interface.
702func (src Numeric) Value() (driver.Value, error) {
703	switch src.Status {
704	case Present:
705		buf, err := src.EncodeText(nil, nil)
706		if err != nil {
707			return nil, err
708		}
709
710		return string(buf), nil
711	case Null:
712		return nil, nil
713	default:
714		return nil, errUndefined
715	}
716}
717