1package pgtype
2
3import (
4	"database/sql/driver"
5	"encoding/binary"
6	"math"
7	"math/big"
8	"strconv"
9	"strings"
10
11	"github.com/jackc/pgio"
12	errors "golang.org/x/xerrors"
13)
14
15// PostgreSQL internal numeric storage uses 16-bit "digits" with base of 10,000
16const nbase = 10000
17
18var big0 *big.Int = big.NewInt(0)
19var big1 *big.Int = big.NewInt(1)
20var big10 *big.Int = big.NewInt(10)
21var big100 *big.Int = big.NewInt(100)
22var big1000 *big.Int = big.NewInt(1000)
23
24var bigMaxInt8 *big.Int = big.NewInt(math.MaxInt8)
25var bigMinInt8 *big.Int = big.NewInt(math.MinInt8)
26var bigMaxInt16 *big.Int = big.NewInt(math.MaxInt16)
27var bigMinInt16 *big.Int = big.NewInt(math.MinInt16)
28var bigMaxInt32 *big.Int = big.NewInt(math.MaxInt32)
29var bigMinInt32 *big.Int = big.NewInt(math.MinInt32)
30var bigMaxInt64 *big.Int = big.NewInt(math.MaxInt64)
31var bigMinInt64 *big.Int = big.NewInt(math.MinInt64)
32var bigMaxInt *big.Int = big.NewInt(int64(maxInt))
33var bigMinInt *big.Int = big.NewInt(int64(minInt))
34
35var bigMaxUint8 *big.Int = big.NewInt(math.MaxUint8)
36var bigMaxUint16 *big.Int = big.NewInt(math.MaxUint16)
37var bigMaxUint32 *big.Int = big.NewInt(math.MaxUint32)
38var bigMaxUint64 *big.Int = (&big.Int{}).SetUint64(uint64(math.MaxUint64))
39var bigMaxUint *big.Int = (&big.Int{}).SetUint64(uint64(maxUint))
40
41var bigNBase *big.Int = big.NewInt(nbase)
42var bigNBaseX2 *big.Int = big.NewInt(nbase * nbase)
43var bigNBaseX3 *big.Int = big.NewInt(nbase * nbase * nbase)
44var bigNBaseX4 *big.Int = big.NewInt(nbase * nbase * nbase * nbase)
45
// Numeric holds a PostgreSQL numeric value as the exact decimal Int * 10^Exp
// (the scaling used by toBigInt and toFloat64 below).
type Numeric struct {
	// Int is the unscaled value, including its sign.
	Int    *big.Int
	// Exp is the base-10 exponent applied to Int.
	Exp    int32
	// Status records whether a value is present (e.g. Present, Null or
	// Undefined as handled by the methods of this type).
	Status Status
}
51
// Set converts src into a Numeric and stores it in dst.
//
// A nil src yields a NULL Numeric. A value with a Get() interface{} method is
// unwrapped and converted recursively. Floats are converted through their
// shortest decimal representation at their own precision, integers are
// converted exactly, and strings are parsed as plain decimal literals such as
// "-12.50". Any other type is retried via underlyingNumberType (presumably to
// unwrap named numeric types) and is otherwise rejected with an error.
func (dst *Numeric) Set(src interface{}) error {
	if src == nil {
		*dst = Numeric{Status: Null}
		return nil
	}

	if value, ok := src.(interface{ Get() interface{} }); ok {
		value2 := value.Get()
		if value2 != value {
			return dst.Set(value2)
		}
	}

	switch value := src.(type) {
	case float32:
		// Format at 32-bit precision. Formatting float64(value) with bitSize 64
		// would expose the widening error, e.g. float32(1.1) would be stored as
		// 1.100000023841858 instead of 1.1.
		num, exp, err := parseNumericString(strconv.FormatFloat(float64(value), 'f', -1, 32))
		if err != nil {
			return err
		}
		*dst = Numeric{Int: num, Exp: exp, Status: Present}
	case float64:
		num, exp, err := parseNumericString(strconv.FormatFloat(value, 'f', -1, 64))
		if err != nil {
			return err
		}
		*dst = Numeric{Int: num, Exp: exp, Status: Present}
	case int8:
		*dst = Numeric{Int: big.NewInt(int64(value)), Status: Present}
	case uint8:
		*dst = Numeric{Int: big.NewInt(int64(value)), Status: Present}
	case int16:
		*dst = Numeric{Int: big.NewInt(int64(value)), Status: Present}
	case uint16:
		*dst = Numeric{Int: big.NewInt(int64(value)), Status: Present}
	case int32:
		*dst = Numeric{Int: big.NewInt(int64(value)), Status: Present}
	case uint32:
		*dst = Numeric{Int: big.NewInt(int64(value)), Status: Present}
	case int64:
		*dst = Numeric{Int: big.NewInt(value), Status: Present}
	case uint64:
		// uint64 may exceed int64, so go through SetUint64.
		*dst = Numeric{Int: (&big.Int{}).SetUint64(value), Status: Present}
	case int:
		*dst = Numeric{Int: big.NewInt(int64(value)), Status: Present}
	case uint:
		*dst = Numeric{Int: (&big.Int{}).SetUint64(uint64(value)), Status: Present}
	case string:
		num, exp, err := parseNumericString(value)
		if err != nil {
			return err
		}
		*dst = Numeric{Int: num, Exp: exp, Status: Present}
	default:
		if originalSrc, ok := underlyingNumberType(src); ok {
			return dst.Set(originalSrc)
		}
		return errors.Errorf("cannot convert %v to Numeric", value)
	}

	return nil
}
113
114func (dst Numeric) Get() interface{} {
115	switch dst.Status {
116	case Present:
117		return dst
118	case Null:
119		return nil
120	default:
121		return dst.Status
122	}
123}
124
125func (src *Numeric) AssignTo(dst interface{}) error {
126	switch src.Status {
127	case Present:
128		switch v := dst.(type) {
129		case *float32:
130			f, err := src.toFloat64()
131			if err != nil {
132				return err
133			}
134			return float64AssignTo(f, src.Status, dst)
135		case *float64:
136			f, err := src.toFloat64()
137			if err != nil {
138				return err
139			}
140			return float64AssignTo(f, src.Status, dst)
141		case *int:
142			normalizedInt, err := src.toBigInt()
143			if err != nil {
144				return err
145			}
146			if normalizedInt.Cmp(bigMaxInt) > 0 {
147				return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v)
148			}
149			if normalizedInt.Cmp(bigMinInt) < 0 {
150				return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v)
151			}
152			*v = int(normalizedInt.Int64())
153		case *int8:
154			normalizedInt, err := src.toBigInt()
155			if err != nil {
156				return err
157			}
158			if normalizedInt.Cmp(bigMaxInt8) > 0 {
159				return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v)
160			}
161			if normalizedInt.Cmp(bigMinInt8) < 0 {
162				return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v)
163			}
164			*v = int8(normalizedInt.Int64())
165		case *int16:
166			normalizedInt, err := src.toBigInt()
167			if err != nil {
168				return err
169			}
170			if normalizedInt.Cmp(bigMaxInt16) > 0 {
171				return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v)
172			}
173			if normalizedInt.Cmp(bigMinInt16) < 0 {
174				return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v)
175			}
176			*v = int16(normalizedInt.Int64())
177		case *int32:
178			normalizedInt, err := src.toBigInt()
179			if err != nil {
180				return err
181			}
182			if normalizedInt.Cmp(bigMaxInt32) > 0 {
183				return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v)
184			}
185			if normalizedInt.Cmp(bigMinInt32) < 0 {
186				return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v)
187			}
188			*v = int32(normalizedInt.Int64())
189		case *int64:
190			normalizedInt, err := src.toBigInt()
191			if err != nil {
192				return err
193			}
194			if normalizedInt.Cmp(bigMaxInt64) > 0 {
195				return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v)
196			}
197			if normalizedInt.Cmp(bigMinInt64) < 0 {
198				return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v)
199			}
200			*v = normalizedInt.Int64()
201		case *uint:
202			normalizedInt, err := src.toBigInt()
203			if err != nil {
204				return err
205			}
206			if normalizedInt.Cmp(big0) < 0 {
207				return errors.Errorf("%d is less than zero for %T", normalizedInt, *v)
208			} else if normalizedInt.Cmp(bigMaxUint) > 0 {
209				return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v)
210			}
211			*v = uint(normalizedInt.Uint64())
212		case *uint8:
213			normalizedInt, err := src.toBigInt()
214			if err != nil {
215				return err
216			}
217			if normalizedInt.Cmp(big0) < 0 {
218				return errors.Errorf("%d is less than zero for %T", normalizedInt, *v)
219			} else if normalizedInt.Cmp(bigMaxUint8) > 0 {
220				return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v)
221			}
222			*v = uint8(normalizedInt.Uint64())
223		case *uint16:
224			normalizedInt, err := src.toBigInt()
225			if err != nil {
226				return err
227			}
228			if normalizedInt.Cmp(big0) < 0 {
229				return errors.Errorf("%d is less than zero for %T", normalizedInt, *v)
230			} else if normalizedInt.Cmp(bigMaxUint16) > 0 {
231				return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v)
232			}
233			*v = uint16(normalizedInt.Uint64())
234		case *uint32:
235			normalizedInt, err := src.toBigInt()
236			if err != nil {
237				return err
238			}
239			if normalizedInt.Cmp(big0) < 0 {
240				return errors.Errorf("%d is less than zero for %T", normalizedInt, *v)
241			} else if normalizedInt.Cmp(bigMaxUint32) > 0 {
242				return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v)
243			}
244			*v = uint32(normalizedInt.Uint64())
245		case *uint64:
246			normalizedInt, err := src.toBigInt()
247			if err != nil {
248				return err
249			}
250			if normalizedInt.Cmp(big0) < 0 {
251				return errors.Errorf("%d is less than zero for %T", normalizedInt, *v)
252			} else if normalizedInt.Cmp(bigMaxUint64) > 0 {
253				return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v)
254			}
255			*v = normalizedInt.Uint64()
256		default:
257			if nextDst, retry := GetAssignToDstType(dst); retry {
258				return src.AssignTo(nextDst)
259			}
260			return errors.Errorf("unable to assign to %T", dst)
261		}
262	case Null:
263		return NullAssignTo(dst)
264	}
265
266	return nil
267}
268
// toBigInt returns the numeric value as an integer, i.e. Int scaled by 10^Exp.
//
// When Exp is zero the stored Int itself is returned rather than a copy, so
// callers must not modify the result. A negative Exp that would leave a
// non-zero fractional part is reported as an error.
func (dst *Numeric) toBigInt() (*big.Int, error) {
	switch {
	case dst.Exp == 0:
		return dst.Int, nil
	case dst.Exp > 0:
		scale := new(big.Int).Exp(big10, big.NewInt(int64(dst.Exp)), nil)
		return new(big.Int).Mul(dst.Int, scale), nil
	}

	scale := new(big.Int).Exp(big10, big.NewInt(int64(-dst.Exp)), nil)
	remainder := new(big.Int)
	quotient := new(big.Int).DivMod(dst.Int, scale, remainder)
	if remainder.Sign() != 0 {
		return nil, errors.Errorf("cannot convert %v to integer", dst)
	}
	return quotient, nil
}
292
// toFloat64 converts the numeric value Int * 10^Exp to the nearest float64.
//
// The value is rendered as "<Int>e<Exp>" and parsed in one step, so the
// result is correctly rounded. Scaling a parsed float by 10 once per unit of
// Exp, as this used to do, rounds at every step and drifts for large
// exponents. A value too large for float64 is reported as an error (the
// ErrRange from strconv.ParseFloat) rather than silently becoming ±Inf.
func (src *Numeric) toFloat64() (float64, error) {
	literal := src.Int.String() + "e" + strconv.FormatInt(int64(src.Exp), 10)
	f, err := strconv.ParseFloat(literal, 64)
	if err != nil {
		return 0, err
	}
	return f, nil
}
309
310func (dst *Numeric) DecodeText(ci *ConnInfo, src []byte) error {
311	if src == nil {
312		*dst = Numeric{Status: Null}
313		return nil
314	}
315
316	num, exp, err := parseNumericString(string(src))
317	if err != nil {
318		return err
319	}
320
321	*dst = Numeric{Int: num, Exp: exp, Status: Present}
322	return nil
323}
324
// parseNumericString parses a plain decimal literal (optional sign, digits,
// optional '.' and fraction) into an unscaled integer n and a base-10 exponent
// exp such that the value is n * 10^exp.
//
// With a fractional part, exp is minus the number of fraction digits, e.g.
// "1.50" gives (150, -2). Without one, trailing zeros are folded into exp,
// e.g. "1200" gives (12, 2). Anything big.Int cannot parse as base 10 is an
// error.
func parseNumericString(str string) (n *big.Int, exp int32, err error) {
	parts := strings.SplitN(str, ".", 2)
	digits := strings.Join(parts, "")

	if len(parts) > 1 {
		exp = int32(-len(parts[1]))
	} else {
		// Fold trailing zeros into exp, but only while the preceding byte is
		// itself a digit. Stopping at '-' alone turned "+0" into "+", which
		// then failed to parse.
		for len(digits) > 1 && digits[len(digits)-1] == '0' {
			if prev := digits[len(digits)-2]; prev < '0' || prev > '9' {
				break
			}
			digits = digits[:len(digits)-1]
			exp++
		}
	}

	accum := &big.Int{}
	if _, ok := accum.SetString(digits, 10); !ok {
		return nil, 0, errors.Errorf("%s is not a number", str)
	}

	return accum, exp, nil
}
345
// DecodeBinary decodes a numeric value in PostgreSQL's binary wire format.
//
// A nil src decodes to a NULL Numeric. Otherwise src must hold an 8-byte
// header of four big-endian 16-bit fields (ndigits, weight, sign, dscale)
// followed by ndigits base-10,000 digits, each a big-endian uint16, most
// significant first. weight is the base-10,000 exponent of the first digit
// and dscale is the number of decimal places to keep.
//
// Numeric has no representation for NaN or the infinities, so any sign other
// than positive or negative is reported as an error instead of silently
// decoding as zero. A negative digit count or a truncated digit section is
// also an error.
func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error {
	if src == nil {
		*dst = Numeric{Status: Null}
		return nil
	}

	if len(src) < 8 {
		return errors.Errorf("numeric incomplete %v", src)
	}

	// Sign field values Numeric can represent. PostgreSQL also sends 0xC000
	// for NaN and, since version 14, 0xD000/0xF000 for +/-Infinity.
	const (
		signPositive = 0x0000
		signNegative = 0x4000
	)

	rp := 0
	ndigits := int16(binary.BigEndian.Uint16(src[rp:]))
	rp += 2
	weight := int16(binary.BigEndian.Uint16(src[rp:]))
	rp += 2
	sign := binary.BigEndian.Uint16(src[rp:])
	rp += 2
	dscale := int16(binary.BigEndian.Uint16(src[rp:]))
	rp += 2

	// A negative count would pass the length check below and leave accum at
	// zero, which made the trailing-zero loop spin forever.
	if ndigits < 0 {
		return errors.Errorf("numeric has invalid digit count %d", ndigits)
	}
	// NaN and the infinities arrive with ndigits == 0; without this check they
	// would decode as 0.
	if sign != signPositive && sign != signNegative {
		return errors.Errorf("numeric sign %#x is not supported (NaN and infinity cannot be represented)", sign)
	}

	if ndigits == 0 {
		*dst = Numeric{Int: big.NewInt(0), Status: Present}
		return nil
	}

	// int arithmetic: int16 ndigits*2 (or ndigits+3) can wrap near 32767.
	digitBytes := int(ndigits) * 2
	if len(src[rp:]) < digitBytes {
		return errors.Errorf("numeric incomplete %v", src)
	}
	// Restrict the reader to exactly ndigits digits so trailing bytes, if any,
	// are never folded into the value.
	digits := src[rp : rp+digitBytes]

	// Fold the base-10,000 digits into accum, up to four digits per step.
	accum := &big.Int{}
	dp := 0
	for i := 0; i < (int(ndigits)+3)/4; i++ {
		chunk, bytesRead, digitsRead := nbaseDigitsToInt64(digits[dp:])
		dp += bytesRead

		if i > 0 {
			var mul *big.Int
			switch digitsRead {
			case 1:
				mul = bigNBase
			case 2:
				mul = bigNBaseX2
			case 3:
				mul = bigNBaseX3
			case 4:
				mul = bigNBaseX4
			default:
				return errors.Errorf("invalid digitsRead: %d (this can't happen)", digitsRead)
			}
			accum.Mul(accum, mul)
		}

		accum.Add(accum, big.NewInt(chunk))
	}

	// The last digit has base-10,000 exponent weight-ndigits+1, which is four
	// times that as a decimal exponent.
	exp := (int32(weight) - int32(ndigits) + 1) * 4

	if dscale > 0 {
		// The digits carry whole groups of four decimal places; rescale so
		// exactly dscale places remain. int32 keeps malformed headers from
		// overflowing the int16 arithmetic.
		fracDecimalDigits := (int32(ndigits) - int32(weight) - 1) * 4
		scale := int32(dscale)
		switch {
		case scale > fracDecimalDigits:
			shift := scale - fracDecimalDigits
			accum.Mul(accum, new(big.Int).Exp(big10, big.NewInt(int64(shift)), nil))
			exp -= shift
		case scale < fracDecimalDigits:
			// accum is still non-negative here, so Div truncates toward zero.
			shift := fracDecimalDigits - scale
			accum.Div(accum, new(big.Int).Exp(big10, big.NewInt(int64(shift)), nil))
			exp += shift
		}
	}

	// For integral values, move trailing decimal zeros into exp. A zero accum
	// (only possible with malformed input) must not enter the loop, because
	// 0 is divisible by 10 forever.
	if exp >= 0 {
		reduced := &big.Int{}
		remainder := &big.Int{}
		for accum.Sign() != 0 {
			reduced.DivMod(accum, big10, remainder)
			if remainder.Sign() != 0 {
				break
			}
			accum.Set(reduced)
			exp++
		}
	}

	if sign == signNegative {
		accum.Neg(accum)
	}

	*dst = Numeric{Int: accum, Exp: exp, Status: Present}
	return nil
}
445
// nbaseDigitsToInt64 reads up to four leading base-10,000 digits from src,
// each a big-endian uint16, and combines them most-significant first into one
// int64. It returns that value, the number of bytes consumed and the number of
// digits consumed. A trailing odd byte is ignored.
func nbaseDigitsToInt64(src []byte) (accum int64, bytesRead, digitsRead int) {
	digitsRead = len(src) / 2
	if digitsRead > 4 {
		digitsRead = 4
	}

	for bytesRead < digitsRead*2 {
		accum = accum*nbase + int64(binary.BigEndian.Uint16(src[bytesRead:]))
		bytesRead += 2
	}

	return accum, bytesRead, digitsRead
}
464
// EncodeText appends src to buf in exponent notation, "<Int>e<Exp>" (for
// example 1.50 is written as "150e-2"). NULL encodes as a nil slice, and an
// Undefined status is an error.
func (src Numeric) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
	if src.Status == Null {
		return nil, nil
	}
	if src.Status == Undefined {
		return nil, errUndefined
	}

	buf = src.Int.Append(buf, 10)
	buf = append(buf, 'e')
	return strconv.AppendInt(buf, int64(src.Exp), 10), nil
}
478
// EncodeBinary appends src to buf in PostgreSQL's binary numeric layout: four
// int16 header fields (digit count, weight, sign, dscale) followed by the
// base-10,000 digits, most significant first. NULL encodes as a nil slice,
// and an Undefined status is an error.
func (src Numeric) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
	switch src.Status {
	case Null:
		return nil, nil
	case Undefined:
		return nil, errUndefined
	}

	// 16384 (0x4000) is the sign value written for negative numbers; 0 means
	// non-negative.
	var sign int16
	if src.Int.Cmp(big0) < 0 {
		sign = 16384
	}

	// absInt is a fresh copy, so the mutations below never touch src.Int.
	absInt := &big.Int{}
	wholePart := &big.Int{}
	fracPart := &big.Int{}
	remainder := &big.Int{}
	absInt.Abs(src.Int)

	// Normalize absInt and exp to where exp is always a multiple of 4. This makes
	// converting to 16-bit base 10,000 digits easier.
	// Go's % keeps the dividend's sign, so negative exponents land on -1..-3;
	// exp is always lowered and absInt scaled up to compensate.
	var exp int32
	switch src.Exp % 4 {
	case 1, -3:
		exp = src.Exp - 1
		absInt.Mul(absInt, big10)
	case 2, -2:
		exp = src.Exp - 2
		absInt.Mul(absInt, big100)
	case 3, -1:
		exp = src.Exp - 3
		absInt.Mul(absInt, big1000)
	default:
		exp = src.Exp
	}

	if exp < 0 {
		// Split into integer and fractional parts. Adding divisor places a
		// sentinel 1 just above the fraction so that leading zero digits of
		// the fraction survive extraction; the loop below stops once only the
		// sentinel remains.
		divisor := &big.Int{}
		divisor.Exp(big10, big.NewInt(int64(-exp)), nil)
		wholePart.DivMod(absInt, divisor, fracPart)
		fracPart.Add(fracPart, divisor)
	} else {
		// No fraction: wholePart aliases absInt (a private copy) and fracPart
		// stays zero.
		wholePart = absInt
	}

	// Both digit slices are filled least-significant digit first.
	var wholeDigits, fracDigits []int16

	for wholePart.Cmp(big0) != 0 {
		wholePart.DivMod(wholePart, bigNBase, remainder)
		wholeDigits = append(wholeDigits, int16(remainder.Int64()))
	}

	// fracDigits may contain zero digits (leading or trailing); they are
	// written as-is.
	if fracPart.Cmp(big0) != 0 {
		for fracPart.Cmp(big1) != 0 {
			fracPart.DivMod(fracPart, bigNBase, remainder)
			fracDigits = append(fracDigits, int16(remainder.Int64()))
		}
	}

	// Header field 1: total number of base-10,000 digits.
	buf = pgio.AppendInt16(buf, int16(len(wholeDigits)+len(fracDigits)))

	// Header field 2: weight, the base-10,000 exponent of the first digit.
	var weight int16
	if len(wholeDigits) > 0 {
		weight = int16(len(wholeDigits) - 1)
		if exp > 0 {
			weight += int16(exp / 4)
		}
	} else {
		weight = int16(exp/4) - 1 + int16(len(fracDigits))
	}
	buf = pgio.AppendInt16(buf, weight)

	// Header field 3: sign.
	buf = pgio.AppendInt16(buf, sign)

	// Header field 4: dscale, the number of decimal places, taken from the
	// original (un-normalized) exponent.
	var dscale int16
	if src.Exp < 0 {
		dscale = int16(-src.Exp)
	}
	buf = pgio.AppendInt16(buf, dscale)

	// Emit digits most significant first by walking the slices backwards.
	for i := len(wholeDigits) - 1; i >= 0; i-- {
		buf = pgio.AppendInt16(buf, wholeDigits[i])
	}

	for i := len(fracDigits) - 1; i >= 0; i-- {
		buf = pgio.AppendInt16(buf, fracDigits[i])
	}

	return buf, nil
}
569
570// Scan implements the database/sql Scanner interface.
571func (dst *Numeric) Scan(src interface{}) error {
572	if src == nil {
573		*dst = Numeric{Status: Null}
574		return nil
575	}
576
577	switch src := src.(type) {
578	case string:
579		return dst.DecodeText(nil, []byte(src))
580	case []byte:
581		srcCopy := make([]byte, len(src))
582		copy(srcCopy, src)
583		return dst.DecodeText(nil, srcCopy)
584	}
585
586	return errors.Errorf("cannot scan %T", src)
587}
588
// Value implements the database/sql/driver Valuer interface.
//
// A present value is returned as the string produced by EncodeText, NULL as
// nil, and any other status as errUndefined.
func (src Numeric) Value() (driver.Value, error) {
	if src.Status == Null {
		return nil, nil
	}
	if src.Status != Present {
		return nil, errUndefined
	}

	text, err := src.EncodeText(nil, nil)
	if err != nil {
		return nil, err
	}
	return string(text), nil
}
605