1package pgtype
2
3import (
4	"database/sql/driver"
5	"encoding/binary"
6	"math"
7	"math/big"
8	"strconv"
9	"strings"
10
11	"github.com/jackc/pgx/pgio"
12	"github.com/pkg/errors"
13)
14
15// PostgreSQL internal numeric storage uses 16-bit "digits" with base of 10,000
16const nbase = 10000
17
18var big0 *big.Int = big.NewInt(0)
19var big1 *big.Int = big.NewInt(1)
20var big10 *big.Int = big.NewInt(10)
21var big100 *big.Int = big.NewInt(100)
22var big1000 *big.Int = big.NewInt(1000)
23
24var bigMaxInt8 *big.Int = big.NewInt(math.MaxInt8)
25var bigMinInt8 *big.Int = big.NewInt(math.MinInt8)
26var bigMaxInt16 *big.Int = big.NewInt(math.MaxInt16)
27var bigMinInt16 *big.Int = big.NewInt(math.MinInt16)
28var bigMaxInt32 *big.Int = big.NewInt(math.MaxInt32)
29var bigMinInt32 *big.Int = big.NewInt(math.MinInt32)
30var bigMaxInt64 *big.Int = big.NewInt(math.MaxInt64)
31var bigMinInt64 *big.Int = big.NewInt(math.MinInt64)
32var bigMaxInt *big.Int = big.NewInt(int64(maxInt))
33var bigMinInt *big.Int = big.NewInt(int64(minInt))
34
35var bigMaxUint8 *big.Int = big.NewInt(math.MaxUint8)
36var bigMaxUint16 *big.Int = big.NewInt(math.MaxUint16)
37var bigMaxUint32 *big.Int = big.NewInt(math.MaxUint32)
38var bigMaxUint64 *big.Int = (&big.Int{}).SetUint64(uint64(math.MaxUint64))
39var bigMaxUint *big.Int = (&big.Int{}).SetUint64(uint64(maxUint))
40
41var bigNBase *big.Int = big.NewInt(nbase)
42var bigNBaseX2 *big.Int = big.NewInt(nbase * nbase)
43var bigNBaseX3 *big.Int = big.NewInt(nbase * nbase * nbase)
44var bigNBaseX4 *big.Int = big.NewInt(nbase * nbase * nbase * nbase)
45
// Numeric is the Go representation of a PostgreSQL numeric value. When Status
// is Present the represented value is Int × 10^Exp.
type Numeric struct {
	// Int holds the unscaled digits, including the sign.
	Int    *big.Int
	// Exp is the base-10 exponent applied to Int.
	Exp    int32
	// Status tells whether a value is Present, Null, or Undefined.
	Status Status
}
51
52func (dst *Numeric) Set(src interface{}) error {
53	if src == nil {
54		*dst = Numeric{Status: Null}
55		return nil
56	}
57
58	switch value := src.(type) {
59	case float32:
60		num, exp, err := parseNumericString(strconv.FormatFloat(float64(value), 'f', -1, 64))
61		if err != nil {
62			return err
63		}
64		*dst = Numeric{Int: num, Exp: exp, Status: Present}
65	case float64:
66		num, exp, err := parseNumericString(strconv.FormatFloat(value, 'f', -1, 64))
67		if err != nil {
68			return err
69		}
70		*dst = Numeric{Int: num, Exp: exp, Status: Present}
71	case int8:
72		*dst = Numeric{Int: big.NewInt(int64(value)), Status: Present}
73	case uint8:
74		*dst = Numeric{Int: big.NewInt(int64(value)), Status: Present}
75	case int16:
76		*dst = Numeric{Int: big.NewInt(int64(value)), Status: Present}
77	case uint16:
78		*dst = Numeric{Int: big.NewInt(int64(value)), Status: Present}
79	case int32:
80		*dst = Numeric{Int: big.NewInt(int64(value)), Status: Present}
81	case uint32:
82		*dst = Numeric{Int: big.NewInt(int64(value)), Status: Present}
83	case int64:
84		*dst = Numeric{Int: big.NewInt(value), Status: Present}
85	case uint64:
86		*dst = Numeric{Int: (&big.Int{}).SetUint64(value), Status: Present}
87	case int:
88		*dst = Numeric{Int: big.NewInt(int64(value)), Status: Present}
89	case uint:
90		*dst = Numeric{Int: (&big.Int{}).SetUint64(uint64(value)), Status: Present}
91	case string:
92		num, exp, err := parseNumericString(value)
93		if err != nil {
94			return err
95		}
96		*dst = Numeric{Int: num, Exp: exp, Status: Present}
97	default:
98		if originalSrc, ok := underlyingNumberType(src); ok {
99			return dst.Set(originalSrc)
100		}
101		return errors.Errorf("cannot convert %v to Numeric", value)
102	}
103
104	return nil
105}
106
107func (dst *Numeric) Get() interface{} {
108	switch dst.Status {
109	case Present:
110		return dst
111	case Null:
112		return nil
113	default:
114		return dst.Status
115	}
116}
117
118func (src *Numeric) AssignTo(dst interface{}) error {
119	switch src.Status {
120	case Present:
121		switch v := dst.(type) {
122		case *float32:
123			f, err := src.toFloat64()
124			if err != nil {
125				return err
126			}
127			return float64AssignTo(f, src.Status, dst)
128		case *float64:
129			f, err := src.toFloat64()
130			if err != nil {
131				return err
132			}
133			return float64AssignTo(f, src.Status, dst)
134		case *int:
135			normalizedInt, err := src.toBigInt()
136			if err != nil {
137				return err
138			}
139			if normalizedInt.Cmp(bigMaxInt) > 0 {
140				return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v)
141			}
142			if normalizedInt.Cmp(bigMinInt) < 0 {
143				return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v)
144			}
145			*v = int(normalizedInt.Int64())
146		case *int8:
147			normalizedInt, err := src.toBigInt()
148			if err != nil {
149				return err
150			}
151			if normalizedInt.Cmp(bigMaxInt8) > 0 {
152				return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v)
153			}
154			if normalizedInt.Cmp(bigMinInt8) < 0 {
155				return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v)
156			}
157			*v = int8(normalizedInt.Int64())
158		case *int16:
159			normalizedInt, err := src.toBigInt()
160			if err != nil {
161				return err
162			}
163			if normalizedInt.Cmp(bigMaxInt16) > 0 {
164				return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v)
165			}
166			if normalizedInt.Cmp(bigMinInt16) < 0 {
167				return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v)
168			}
169			*v = int16(normalizedInt.Int64())
170		case *int32:
171			normalizedInt, err := src.toBigInt()
172			if err != nil {
173				return err
174			}
175			if normalizedInt.Cmp(bigMaxInt32) > 0 {
176				return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v)
177			}
178			if normalizedInt.Cmp(bigMinInt32) < 0 {
179				return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v)
180			}
181			*v = int32(normalizedInt.Int64())
182		case *int64:
183			normalizedInt, err := src.toBigInt()
184			if err != nil {
185				return err
186			}
187			if normalizedInt.Cmp(bigMaxInt64) > 0 {
188				return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v)
189			}
190			if normalizedInt.Cmp(bigMinInt64) < 0 {
191				return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v)
192			}
193			*v = normalizedInt.Int64()
194		case *uint:
195			normalizedInt, err := src.toBigInt()
196			if err != nil {
197				return err
198			}
199			if normalizedInt.Cmp(big0) < 0 {
200				return errors.Errorf("%d is less than zero for %T", normalizedInt, *v)
201			} else if normalizedInt.Cmp(bigMaxUint) > 0 {
202				return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v)
203			}
204			*v = uint(normalizedInt.Uint64())
205		case *uint8:
206			normalizedInt, err := src.toBigInt()
207			if err != nil {
208				return err
209			}
210			if normalizedInt.Cmp(big0) < 0 {
211				return errors.Errorf("%d is less than zero for %T", normalizedInt, *v)
212			} else if normalizedInt.Cmp(bigMaxUint8) > 0 {
213				return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v)
214			}
215			*v = uint8(normalizedInt.Uint64())
216		case *uint16:
217			normalizedInt, err := src.toBigInt()
218			if err != nil {
219				return err
220			}
221			if normalizedInt.Cmp(big0) < 0 {
222				return errors.Errorf("%d is less than zero for %T", normalizedInt, *v)
223			} else if normalizedInt.Cmp(bigMaxUint16) > 0 {
224				return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v)
225			}
226			*v = uint16(normalizedInt.Uint64())
227		case *uint32:
228			normalizedInt, err := src.toBigInt()
229			if err != nil {
230				return err
231			}
232			if normalizedInt.Cmp(big0) < 0 {
233				return errors.Errorf("%d is less than zero for %T", normalizedInt, *v)
234			} else if normalizedInt.Cmp(bigMaxUint32) > 0 {
235				return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v)
236			}
237			*v = uint32(normalizedInt.Uint64())
238		case *uint64:
239			normalizedInt, err := src.toBigInt()
240			if err != nil {
241				return err
242			}
243			if normalizedInt.Cmp(big0) < 0 {
244				return errors.Errorf("%d is less than zero for %T", normalizedInt, *v)
245			} else if normalizedInt.Cmp(bigMaxUint64) > 0 {
246				return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v)
247			}
248			*v = normalizedInt.Uint64()
249		default:
250			if nextDst, retry := GetAssignToDstType(dst); retry {
251				return src.AssignTo(nextDst)
252			}
253		}
254	case Null:
255		return NullAssignTo(dst)
256	}
257
258	return nil
259}
260
261func (dst *Numeric) toBigInt() (*big.Int, error) {
262	if dst.Exp == 0 {
263		return dst.Int, nil
264	}
265
266	num := &big.Int{}
267	num.Set(dst.Int)
268	if dst.Exp > 0 {
269		mul := &big.Int{}
270		mul.Exp(big10, big.NewInt(int64(dst.Exp)), nil)
271		num.Mul(num, mul)
272		return num, nil
273	}
274
275	div := &big.Int{}
276	div.Exp(big10, big.NewInt(int64(-dst.Exp)), nil)
277	remainder := &big.Int{}
278	num.DivMod(num, div, remainder)
279	if remainder.Cmp(big0) != 0 {
280		return nil, errors.Errorf("cannot convert %v to integer", dst)
281	}
282	return num, nil
283}
284
285func (src *Numeric) toFloat64() (float64, error) {
286	f, err := strconv.ParseFloat(src.Int.String(), 64)
287	if err != nil {
288		return 0, err
289	}
290	if src.Exp > 0 {
291		for i := 0; i < int(src.Exp); i++ {
292			f *= 10
293		}
294	} else if src.Exp < 0 {
295		for i := 0; i > int(src.Exp); i-- {
296			f /= 10
297		}
298	}
299	return f, nil
300}
301
302func (dst *Numeric) DecodeText(ci *ConnInfo, src []byte) error {
303	if src == nil {
304		*dst = Numeric{Status: Null}
305		return nil
306	}
307
308	num, exp, err := parseNumericString(string(src))
309	if err != nil {
310		return err
311	}
312
313	*dst = Numeric{Int: num, Exp: exp, Status: Present}
314	return nil
315}
316
317func parseNumericString(str string) (n *big.Int, exp int32, err error) {
318	parts := strings.SplitN(str, ".", 2)
319	digits := strings.Join(parts, "")
320
321	if len(parts) > 1 {
322		exp = int32(-len(parts[1]))
323	} else {
324		for len(digits) > 1 && digits[len(digits)-1] == '0' {
325			digits = digits[:len(digits)-1]
326			exp++
327		}
328	}
329
330	accum := &big.Int{}
331	if _, ok := accum.SetString(digits, 10); !ok {
332		return nil, 0, errors.Errorf("%s is not a number", str)
333	}
334
335	return accum, exp, nil
336}
337
338func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error {
339	if src == nil {
340		*dst = Numeric{Status: Null}
341		return nil
342	}
343
344	if len(src) < 8 {
345		return errors.Errorf("numeric incomplete %v", src)
346	}
347
348	rp := 0
349	ndigits := int16(binary.BigEndian.Uint16(src[rp:]))
350	rp += 2
351
352	if ndigits == 0 {
353		*dst = Numeric{Int: big.NewInt(0), Status: Present}
354		return nil
355	}
356
357	weight := int16(binary.BigEndian.Uint16(src[rp:]))
358	rp += 2
359	sign := int16(binary.BigEndian.Uint16(src[rp:]))
360	rp += 2
361	dscale := int16(binary.BigEndian.Uint16(src[rp:]))
362	rp += 2
363
364	if len(src[rp:]) < int(ndigits)*2 {
365		return errors.Errorf("numeric incomplete %v", src)
366	}
367
368	accum := &big.Int{}
369
370	for i := 0; i < int(ndigits+3)/4; i++ {
371		int64accum, bytesRead, digitsRead := nbaseDigitsToInt64(src[rp:])
372		rp += bytesRead
373
374		if i > 0 {
375			var mul *big.Int
376			switch digitsRead {
377			case 1:
378				mul = bigNBase
379			case 2:
380				mul = bigNBaseX2
381			case 3:
382				mul = bigNBaseX3
383			case 4:
384				mul = bigNBaseX4
385			default:
386				return errors.Errorf("invalid digitsRead: %d (this can't happen)", digitsRead)
387			}
388			accum.Mul(accum, mul)
389		}
390
391		accum.Add(accum, big.NewInt(int64accum))
392	}
393
394	exp := (int32(weight) - int32(ndigits) + 1) * 4
395
396	if dscale > 0 {
397		fracNBaseDigits := ndigits - weight - 1
398		fracDecimalDigits := fracNBaseDigits * 4
399
400		if dscale > fracDecimalDigits {
401			multCount := int(dscale - fracDecimalDigits)
402			for i := 0; i < multCount; i++ {
403				accum.Mul(accum, big10)
404				exp--
405			}
406		} else if dscale < fracDecimalDigits {
407			divCount := int(fracDecimalDigits - dscale)
408			for i := 0; i < divCount; i++ {
409				accum.Div(accum, big10)
410				exp++
411			}
412		}
413	}
414
415	reduced := &big.Int{}
416	remainder := &big.Int{}
417	if exp >= 0 {
418		for {
419			reduced.DivMod(accum, big10, remainder)
420			if remainder.Cmp(big0) != 0 {
421				break
422			}
423			accum.Set(reduced)
424			exp++
425		}
426	}
427
428	if sign != 0 {
429		accum.Neg(accum)
430	}
431
432	*dst = Numeric{Int: accum, Exp: exp, Status: Present}
433
434	return nil
435
436}
437
438func nbaseDigitsToInt64(src []byte) (accum int64, bytesRead, digitsRead int) {
439	digits := len(src) / 2
440	if digits > 4 {
441		digits = 4
442	}
443
444	rp := 0
445
446	for i := 0; i < digits; i++ {
447		if i > 0 {
448			accum *= nbase
449		}
450		accum += int64(binary.BigEndian.Uint16(src[rp:]))
451		rp += 2
452	}
453
454	return accum, rp, digits
455}
456
457func (src *Numeric) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
458	switch src.Status {
459	case Null:
460		return nil, nil
461	case Undefined:
462		return nil, errUndefined
463	}
464
465	buf = append(buf, src.Int.String()...)
466	buf = append(buf, 'e')
467	buf = append(buf, strconv.FormatInt(int64(src.Exp), 10)...)
468	return buf, nil
469}
470
// EncodeBinary appends src to buf in PostgreSQL's binary numeric layout.
//
// The layout is four int16 header fields (ndigits, weight, sign, dscale),
// followed by ndigits base-10,000 digits, most significant first. A NULL src
// encodes to a nil slice, and an Undefined src is an error.
func (src *Numeric) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
	switch src.Status {
	case Null:
		return nil, nil
	case Undefined:
		return nil, errUndefined
	}

	// 16384 (0x4000) is the sign value for a negative number; 0 means positive.
	var sign int16
	if src.Int.Cmp(big0) < 0 {
		sign = 16384
	}

	absInt := &big.Int{}
	wholePart := &big.Int{}
	fracPart := &big.Int{}
	remainder := &big.Int{}
	absInt.Abs(src.Int)

	// Rewrite absInt × 10^exp so that exp is a multiple of 4, which makes
	// splitting into 16-bit base-10,000 digits straightforward. Go's % keeps
	// the sign of the dividend, so negative exponents hit the -1/-2/-3 labels.
	var exp int32
	switch src.Exp % 4 {
	case 1, -3:
		exp = src.Exp - 1
		absInt.Mul(absInt, big10)
	case 2, -2:
		exp = src.Exp - 2
		absInt.Mul(absInt, big100)
	case 3, -1:
		exp = src.Exp - 3
		absInt.Mul(absInt, big1000)
	default:
		exp = src.Exp
	}

	if exp < 0 {
		// Split into integer and fractional parts. Adding divisor plants a
		// leading 1 "sentinel" above the fraction, so its leading zero digits
		// survive the base-10,000 extraction below.
		divisor := &big.Int{}
		divisor.Exp(big10, big.NewInt(int64(-exp)), nil)
		wholePart.DivMod(absInt, divisor, fracPart)
		fracPart.Add(fracPart, divisor)
	} else {
		// No fractional part. wholePart aliases absInt (both are local).
		wholePart = absInt
	}

	// Both digit slices are filled least significant digit first.
	var wholeDigits, fracDigits []int16

	for wholePart.Cmp(big0) != 0 {
		wholePart.DivMod(wholePart, bigNBase, remainder)
		wholeDigits = append(wholeDigits, int16(remainder.Int64()))
	}

	// Extract fractional digits until only the sentinel 1 remains.
	if fracPart.Cmp(big0) != 0 {
		for fracPart.Cmp(big1) != 0 {
			fracPart.DivMod(fracPart, bigNBase, remainder)
			fracDigits = append(fracDigits, int16(remainder.Int64()))
		}
	}

	// ndigits
	buf = pgio.AppendInt16(buf, int16(len(wholeDigits)+len(fracDigits)))

	// weight is the base-10,000 exponent of the first (most significant)
	// emitted digit. With no whole digits, it is derived from where the
	// fraction begins.
	var weight int16
	if len(wholeDigits) > 0 {
		weight = int16(len(wholeDigits) - 1)
		if exp > 0 {
			weight += int16(exp / 4)
		}
	} else {
		weight = int16(exp/4) - 1 + int16(len(fracDigits))
	}
	buf = pgio.AppendInt16(buf, weight)

	buf = pgio.AppendInt16(buf, sign)

	// dscale is the number of decimal digits after the point. It comes from
	// the original (un-normalized) exponent.
	var dscale int16
	if src.Exp < 0 {
		dscale = int16(-src.Exp)
	}
	buf = pgio.AppendInt16(buf, dscale)

	// Emit the digits most significant first, reversing the extraction order.
	for i := len(wholeDigits) - 1; i >= 0; i-- {
		buf = pgio.AppendInt16(buf, wholeDigits[i])
	}

	for i := len(fracDigits) - 1; i >= 0; i-- {
		buf = pgio.AppendInt16(buf, fracDigits[i])
	}

	return buf, nil
}
561
562// Scan implements the database/sql Scanner interface.
563func (dst *Numeric) Scan(src interface{}) error {
564	if src == nil {
565		*dst = Numeric{Status: Null}
566		return nil
567	}
568
569	switch src := src.(type) {
570	case float64:
571		// TODO
572		// *dst = Numeric{Float: src, Status: Present}
573		return nil
574	case string:
575		return dst.DecodeText(nil, []byte(src))
576	case []byte:
577		srcCopy := make([]byte, len(src))
578		copy(srcCopy, src)
579		return dst.DecodeText(nil, srcCopy)
580	}
581
582	return errors.Errorf("cannot scan %T", src)
583}
584
585// Value implements the database/sql/driver Valuer interface.
586func (src *Numeric) Value() (driver.Value, error) {
587	switch src.Status {
588	case Present:
589		buf, err := src.EncodeText(nil, nil)
590		if err != nil {
591			return nil, err
592		}
593
594		return string(buf), nil
595	case Null:
596		return nil, nil
597	default:
598		return nil, errUndefined
599	}
600}
601