1// Copyright 2016 The Cockroach Authors.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
12// implied. See the License for the specific language governing
13// permissions and limitations under the License.
14
15package apd
16
17import (
18	"encoding/json"
19	"fmt"
20	"math"
21	"math/big"
22	"testing"
23	"unsafe"
24)
25
26var (
27	testCtx = &BaseContext
28)
29
30func (d *Decimal) GoString() string {
31	return fmt.Sprintf(`{Coeff: %s, Exponent: %d, Negative: %v, Form: %s}`, d.Coeff.String(), d.Exponent, d.Negative, d.Form)
32}
33
34// testExponentError skips t if err was caused by an exponent being outside
35// of the package's supported exponent range. Since the exponent is so large,
36// we don't support those tests yet (i.e., it's an expected failure, so we
37// skip it).
38func testExponentError(t *testing.T, err error) {
39	if err == nil {
40		return
41	}
42	if err.Error() == errExponentOutOfRangeStr {
43		t.Skip(err)
44	}
45}
46
47func newDecimal(t *testing.T, c *Context, s string) *Decimal {
48	d, _, err := c.NewFromString(s)
49	testExponentError(t, err)
50	if err != nil {
51		t.Fatalf("%s: %+v", s, err)
52	}
53	return d
54}
55
56func TestUpscale(t *testing.T) {
57	tests := []struct {
58		x, y *Decimal
59		a, b *big.Int
60		s    int32
61	}{
62		{x: New(1, 0), y: New(100, -1), a: big.NewInt(10), b: big.NewInt(100), s: -1},
63		{x: New(1, 0), y: New(10, -1), a: big.NewInt(10), b: big.NewInt(10), s: -1},
64		{x: New(1, 0), y: New(10, 0), a: big.NewInt(1), b: big.NewInt(10), s: 0},
65		{x: New(1, 1), y: New(1, 0), a: big.NewInt(10), b: big.NewInt(1), s: 0},
66		{x: New(10, -2), y: New(1, -1), a: big.NewInt(10), b: big.NewInt(10), s: -2},
67		{x: New(1, -2), y: New(100, 1), a: big.NewInt(1), b: big.NewInt(100000), s: -2},
68	}
69	for _, tc := range tests {
70		t.Run(fmt.Sprintf("%s, %s", tc.x, tc.y), func(t *testing.T) {
71			a, b, s, err := upscale(tc.x, tc.y)
72			if err != nil {
73				t.Fatal(err)
74			}
75			if a.Cmp(tc.a) != 0 {
76				t.Errorf("a: expected %s, got %s", tc.a, a)
77			}
78			if b.Cmp(tc.b) != 0 {
79				t.Errorf("b: expected %s, got %s", tc.b, b)
80			}
81			if s != tc.s {
82				t.Errorf("s: expected %d, got %d", tc.s, s)
83			}
84		})
85	}
86}
87
88func TestAdd(t *testing.T) {
89	tests := []struct {
90		x, y string
91		r    string
92	}{
93		{x: "1", y: "10", r: "11"},
94		{x: "1", y: "1e1", r: "11"},
95		{x: "1e1", y: "1", r: "11"},
96		{x: ".1e1", y: "100e-1", r: "11.0"},
97	}
98	for _, tc := range tests {
99		t.Run(fmt.Sprintf("%s, %s", tc.x, tc.y), func(t *testing.T) {
100			x := newDecimal(t, testCtx, tc.x)
101			y := newDecimal(t, testCtx, tc.y)
102			d := new(Decimal)
103			_, err := testCtx.Add(d, x, y)
104			if err != nil {
105				t.Fatal(err)
106			}
107			s := d.String()
108			if s != tc.r {
109				t.Fatalf("expected: %s, got: %s", tc.r, s)
110			}
111		})
112	}
113}
114
115func TestCmp(t *testing.T) {
116	tests := []struct {
117		x, y string
118		c    int
119	}{
120		{x: "1", y: "10", c: -1},
121		{x: "1", y: "1e1", c: -1},
122		{x: "1e1", y: "1", c: 1},
123		{x: ".1e1", y: "100e-1", c: -1},
124
125		{x: ".1e1", y: "100e-2", c: 0},
126		{x: "1", y: ".1e1", c: 0},
127		{x: "1", y: "1", c: 0},
128	}
129	for _, tc := range tests {
130		t.Run(fmt.Sprintf("%s, %s", tc.x, tc.y), func(t *testing.T) {
131			x := newDecimal(t, testCtx, tc.x)
132			y := newDecimal(t, testCtx, tc.y)
133			c := x.Cmp(y)
134			if c != tc.c {
135				t.Fatalf("expected: %d, got: %d", tc.c, c)
136			}
137		})
138	}
139}
140
141func TestModf(t *testing.T) {
142	tests := []struct {
143		x string
144		i string
145		f string
146	}{
147		{x: "1", i: "1", f: "0"},
148		{x: "1.0", i: "1", f: "0.0"},
149		{x: "1.0e1", i: "10", f: "0"},
150		{x: "1.0e2", i: "1.0E+2", f: "0"},
151		{x: "1.0e-1", i: "0", f: "0.10"},
152		{x: "1.0e-2", i: "0", f: "0.010"},
153		{x: "1.1", i: "1", f: "0.1"},
154		{x: "1234.56", i: "1234", f: "0.56"},
155		{x: "1234.56e2", i: "123456", f: "0"},
156		{x: "1234.56e4", i: "1.23456E+7", f: "0"},
157		{x: "1234.56e-2", i: "12", f: "0.3456"},
158		{x: "1234.56e-4", i: "0", f: "0.123456"},
159		{x: "1234.56e-6", i: "0", f: "0.00123456"},
160		{x: "123456e-8", i: "0", f: "0.00123456"},
161		{x: ".123456e8", i: "1.23456E+7", f: "0"},
162
163		{x: "-1", i: "-1", f: "-0"},
164		{x: "-1.0", i: "-1", f: "-0.0"},
165		{x: "-1.0e1", i: "-10", f: "-0"},
166		{x: "-1.0e2", i: "-1.0E+2", f: "-0"},
167		{x: "-1.0e-1", i: "-0", f: "-0.10"},
168		{x: "-1.0e-2", i: "-0", f: "-0.010"},
169		{x: "-1.1", i: "-1", f: "-0.1"},
170		{x: "-1234.56", i: "-1234", f: "-0.56"},
171		{x: "-1234.56e2", i: "-123456", f: "-0"},
172		{x: "-1234.56e4", i: "-1.23456E+7", f: "-0"},
173		{x: "-1234.56e-2", i: "-12", f: "-0.3456"},
174		{x: "-1234.56e-4", i: "-0", f: "-0.123456"},
175		{x: "-1234.56e-6", i: "-0", f: "-0.00123456"},
176		{x: "-123456e-8", i: "-0", f: "-0.00123456"},
177		{x: "-.123456e8", i: "-1.23456E+7", f: "-0"},
178	}
179	for _, tc := range tests {
180		t.Run(tc.x, func(t *testing.T) {
181			x := newDecimal(t, testCtx, tc.x)
182			integ, frac := new(Decimal), new(Decimal)
183			x.Modf(integ, frac)
184			if tc.i != integ.String() {
185				t.Fatalf("integ: expected: %s, got: %s", tc.i, integ)
186			}
187			if tc.f != frac.String() {
188				t.Fatalf("frac: expected: %s, got: %s", tc.f, frac)
189			}
190			a := new(Decimal)
191			if _, err := testCtx.Add(a, integ, frac); err != nil {
192				t.Fatal(err)
193			}
194			if a.Cmp(x) != 0 {
195				t.Fatalf("%s != %s", a, x)
196			}
197			if integ.Exponent < 0 {
198				t.Fatal(integ.Exponent)
199			}
200			if frac.Exponent > 0 {
201				t.Fatal(frac.Exponent)
202			}
203
204			integ2, frac2 := new(Decimal), new(Decimal)
205			x.Modf(integ2, nil)
206			x.Modf(nil, frac2)
207			if integ.CmpTotal(integ2) != 0 {
208				t.Fatalf("got %s, expected %s", integ2, integ)
209			}
210			if frac.CmpTotal(frac2) != 0 {
211				t.Fatalf("got %s, expected %s", frac2, frac)
212			}
213		})
214	}
215
216	// Ensure we don't panic on both nil.
217	a := new(Decimal)
218	a.Modf(nil, nil)
219}
220
221func TestInt64(t *testing.T) {
222	tests := []struct {
223		x   string
224		i   int64
225		err bool
226	}{
227		{x: "0.12e1", err: true},
228		{x: "0.1e1", i: 1},
229		{x: "10", i: 10},
230		{x: "12.3e3", i: 12300},
231		{x: "1e-1", err: true},
232		{x: "1e2", i: 100},
233		{x: "1", i: 1},
234		{x: "NaN", err: true},
235		{x: "Inf", err: true},
236		{x: "9223372036854775807", i: 9223372036854775807},
237		{x: "-9223372036854775808", i: -9223372036854775808},
238		{x: "9223372036854775808", err: true},
239	}
240	for _, tc := range tests {
241		t.Run(tc.x, func(t *testing.T) {
242			x := newDecimal(t, testCtx, tc.x)
243			i, err := x.Int64()
244			hasErr := err != nil
245			if tc.err != hasErr {
246				t.Fatalf("expected error: %v, got error: %v", tc.err, err)
247			}
248			if hasErr {
249				return
250			}
251			if i != tc.i {
252				t.Fatalf("expected: %v, got %v", tc.i, i)
253			}
254		})
255	}
256}
257
258func TestQuoErr(t *testing.T) {
259	tests := []struct {
260		x, y string
261		p    uint32
262		err  string
263	}{
264		{x: "1", y: "1", p: 0, err: errZeroPrecisionStr},
265		{x: "1", y: "0", p: 1, err: "division by zero"},
266	}
267	for _, tc := range tests {
268		c := testCtx.WithPrecision(tc.p)
269		x := newDecimal(t, testCtx, tc.x)
270		y := newDecimal(t, testCtx, tc.y)
271		d := new(Decimal)
272		_, err := c.Quo(d, x, y)
273		if err == nil {
274			t.Fatal("expected error")
275		}
276		if err.Error() != tc.err {
277			t.Fatalf("expected %s, got %s", tc.err, err)
278		}
279	}
280}
281
282func TestConditionString(t *testing.T) {
283	tests := map[Condition]string{
284		Overflow:             "overflow",
285		Overflow | Underflow: "overflow, underflow",
286		Subnormal | Inexact:  "inexact, subnormal",
287	}
288	for c, s := range tests {
289		t.Run(s, func(t *testing.T) {
290			cs := c.String()
291			if cs != s {
292				t.Errorf("expected %s; got %s", s, cs)
293			}
294		})
295	}
296}
297
298func TestFloat64(t *testing.T) {
299	tests := []float64{
300		0,
301		1,
302		-1,
303		math.MaxFloat32,
304		math.SmallestNonzeroFloat32,
305		math.MaxFloat64,
306		math.SmallestNonzeroFloat64,
307	}
308
309	for _, tc := range tests {
310		t.Run(fmt.Sprint(tc), func(t *testing.T) {
311			d := new(Decimal)
312			d.SetFloat64(tc)
313			f, err := d.Float64()
314			if err != nil {
315				t.Fatal(err)
316			}
317			if tc != f {
318				t.Fatalf("expected %v, got %v", tc, f)
319			}
320		})
321	}
322}
323
324func TestCeil(t *testing.T) {
325	tests := map[float64]int64{
326		0:    0,
327		-0.1: 0,
328		0.1:  1,
329		-0.9: 0,
330		0.9:  1,
331		-1:   -1,
332		1:    1,
333		-1.1: -1,
334		1.1:  2,
335	}
336
337	for f, r := range tests {
338		t.Run(fmt.Sprint(f), func(t *testing.T) {
339			d, err := new(Decimal).SetFloat64(f)
340			if err != nil {
341				t.Fatal(err)
342			}
343			_, err = testCtx.Ceil(d, d)
344			if err != nil {
345				t.Fatal(err)
346			}
347			i, err := d.Int64()
348			if err != nil {
349				t.Fatal(err)
350			}
351			if i != r {
352				t.Fatalf("got %v, expected %v", i, r)
353			}
354		})
355	}
356}
357
358func TestFloor(t *testing.T) {
359	tests := map[float64]int64{
360		0:    0,
361		-0.1: -1,
362		0.1:  0,
363		-0.9: -1,
364		0.9:  0,
365		-1:   -1,
366		1:    1,
367		-1.1: -2,
368		1.1:  1,
369	}
370
371	for f, r := range tests {
372		t.Run(fmt.Sprint(f), func(t *testing.T) {
373			d, err := new(Decimal).SetFloat64(f)
374			if err != nil {
375				t.Fatal(err)
376			}
377			_, err = testCtx.Floor(d, d)
378			if err != nil {
379				t.Fatal(err)
380			}
381			i, err := d.Int64()
382			if err != nil {
383				t.Fatal(err)
384			}
385			if i != r {
386				t.Fatalf("got %v, expected %v", i, r)
387			}
388		})
389	}
390}
391
392func TestFormat(t *testing.T) {
393	tests := map[string]struct {
394		e, E, f, g, G string
395	}{
396		"NaN":       {},
397		"Infinity":  {},
398		"-Infinity": {},
399		"sNaN":      {},
400		"0": {
401			e: "0e+0",
402			E: "0E+0",
403		},
404		"-0": {
405			e: "-0e+0",
406			E: "-0E+0",
407		},
408		"0.0": {
409			e: "0e-1",
410			E: "0E-1",
411		},
412		"-0.0": {
413			e: "-0e-1",
414			E: "-0E-1",
415		},
416		"0E+2": {
417			e: "0e+2",
418			f: "000",
419			g: "0e+2",
420		},
421	}
422	verbs := []string{"%e", "%E", "%f", "%g", "%G"}
423
424	for input, tc := range tests {
425		t.Run(input, func(t *testing.T) {
426			d, _, err := NewFromString(input)
427			if err != nil {
428				t.Fatal(err)
429			}
430			for i, s := range []string{tc.e, tc.E, tc.f, tc.g, tc.G} {
431				if s == "" {
432					s = input
433				}
434				v := verbs[i]
435				t.Run(v, func(t *testing.T) {
436					out := fmt.Sprintf(v, d)
437					if out != s {
438						t.Fatalf("expected %s, got %s", s, out)
439					}
440				})
441			}
442		})
443	}
444}
445
446func TestFormatFlags(t *testing.T) {
447	const stdD = "1.23E+56"
448	tests := []struct {
449		d   string
450		fmt string
451		out string
452	}{
453		{
454			d:   stdD,
455			fmt: "%3G",
456			out: "1.23E+56",
457		},
458		{
459			d:   stdD,
460			fmt: "%010G",
461			out: "001.23E+56",
462		},
463		{
464			d:   stdD,
465			fmt: "%10G",
466			out: "  1.23E+56",
467		},
468		{
469			d:   stdD,
470			fmt: "%+G",
471			out: "+1.23E+56",
472		},
473		{
474			d:   stdD,
475			fmt: "% G",
476			out: " 1.23E+56",
477		},
478		{
479			d:   stdD,
480			fmt: "%-10G",
481			out: "1.23E+56  ",
482		},
483		{
484			d:   stdD,
485			fmt: "%-010G",
486			out: "1.23E+56  ",
487		},
488		{
489			d:   "nan",
490			fmt: "%-10G",
491			out: "NaN       ",
492		},
493		{
494			d:   "nan",
495			fmt: "%10G",
496			out: "       NaN",
497		},
498		{
499			d:   "nan",
500			fmt: "%010G",
501			out: "       NaN",
502		},
503		{
504			d:   "inf",
505			fmt: "%-10G",
506			out: "Infinity  ",
507		},
508		{
509			d:   "inf",
510			fmt: "%10G",
511			out: "  Infinity",
512		},
513		{
514			d:   "inf",
515			fmt: "%010G",
516			out: "  Infinity",
517		},
518		{
519			d:   "-inf",
520			fmt: "%-10G",
521			out: "-Infinity ",
522		},
523		{
524			d:   "-inf",
525			fmt: "%10G",
526			out: " -Infinity",
527		},
528		{
529			d:   "-inf",
530			fmt: "%010G",
531			out: " -Infinity",
532		},
533		{
534			d:   "0",
535			fmt: "%d",
536			out: "%!d(*apd.Decimal=0)",
537		},
538	}
539	for _, tc := range tests {
540		t.Run(fmt.Sprintf("%s: %s", tc.d, tc.fmt), func(t *testing.T) {
541			d := newDecimal(t, &BaseContext, tc.d)
542			s := fmt.Sprintf(tc.fmt, d)
543			if s != tc.out {
544				t.Fatalf("expected %q, got %q", tc.out, s)
545			}
546		})
547	}
548}
549
550func TestContextSetStringt(t *testing.T) {
551	tests := []struct {
552		s      string
553		c      *Context
554		expect string
555	}{
556		{
557			s:      "1.234",
558			c:      &BaseContext,
559			expect: "1.234",
560		},
561		{
562			s:      "1.234",
563			c:      BaseContext.WithPrecision(2),
564			expect: "1.2",
565		},
566	}
567	for i, tc := range tests {
568		t.Run(fmt.Sprintf("%d: %s", i, tc.s), func(t *testing.T) {
569			d := new(Decimal)
570			if _, _, err := tc.c.SetString(d, tc.s); err != nil {
571				t.Fatal(err)
572			}
573			got := d.String()
574			if got != tc.expect {
575				t.Fatalf("expected: %s, got: %s", tc.expect, got)
576			}
577		})
578	}
579}
580
581func TestQuantize(t *testing.T) {
582	tests := []struct {
583		s      string
584		e      int32
585		expect string
586	}{
587		{
588			s:      "1.00",
589			e:      -1,
590			expect: "1.0",
591		},
592		{
593			s:      "2.0",
594			e:      -1,
595			expect: "2.0",
596		},
597		{
598			s:      "3",
599			e:      -1,
600			expect: "3.0",
601		},
602		{
603			s:      "9.9999",
604			e:      -2,
605			expect: "10.00",
606		},
607	}
608	c := BaseContext.WithPrecision(10)
609	for _, tc := range tests {
610		t.Run(fmt.Sprintf("%s: %d", tc.s, tc.e), func(t *testing.T) {
611			d, _, err := NewFromString(tc.s)
612			if err != nil {
613				t.Fatal(err)
614			}
615			if _, err := c.Quantize(d, d, tc.e); err != nil {
616				t.Fatal(err)
617			}
618			s := d.String()
619			if s != tc.expect {
620				t.Fatalf("expected: %s, got: %s", tc.expect, s)
621			}
622		})
623	}
624}
625
626func TestCmpOrder(t *testing.T) {
627	tests := []struct {
628		s     string
629		order int
630	}{
631		{s: "-NaN", order: -4},
632		{s: "-sNaN", order: -3},
633		{s: "-Infinity", order: -2},
634		{s: "-127", order: -1},
635		{s: "-1.00", order: -1},
636		{s: "-1", order: -1},
637		{s: "-0.000", order: -1},
638		{s: "-0", order: -1},
639		{s: "0", order: 1},
640		{s: "1.2300", order: 1},
641		{s: "1.23", order: 1},
642		{s: "1E+9", order: 1},
643		{s: "Infinity", order: 2},
644		{s: "sNaN", order: 3},
645		{s: "NaN", order: 4},
646	}
647
648	for _, tc := range tests {
649		t.Run(tc.s, func(t *testing.T) {
650			d, _, err := NewFromString(tc.s)
651			if err != nil {
652				t.Fatal(err)
653			}
654			o := d.cmpOrder()
655			if o != tc.order {
656				t.Fatalf("got %d, expected %d", o, tc.order)
657			}
658		})
659	}
660}
661
662func TestIsZero(t *testing.T) {
663	tests := []struct {
664		s    string
665		zero bool
666	}{
667		{s: "-NaN", zero: false},
668		{s: "-sNaN", zero: false},
669		{s: "-Infinity", zero: false},
670		{s: "-127", zero: false},
671		{s: "-1.00", zero: false},
672		{s: "-1", zero: false},
673		{s: "-0.000", zero: true},
674		{s: "-0", zero: true},
675		{s: "0", zero: true},
676		{s: "1.2300", zero: false},
677		{s: "1.23", zero: false},
678		{s: "1E+9", zero: false},
679		{s: "Infinity", zero: false},
680		{s: "sNaN", zero: false},
681		{s: "NaN", zero: false},
682	}
683
684	for _, tc := range tests {
685		t.Run(tc.s, func(t *testing.T) {
686			d, _, err := NewFromString(tc.s)
687			if err != nil {
688				t.Fatal(err)
689			}
690			z := d.IsZero()
691			if z != tc.zero {
692				t.Fatalf("got %v, expected %v", z, tc.zero)
693			}
694		})
695	}
696}
697
698func TestNeg(t *testing.T) {
699	tests := map[string]string{
700		"0":          "0",
701		"-0":         "0",
702		"-0.000":     "0.000",
703		"-00.000100": "0.000100",
704	}
705
706	for tc, expect := range tests {
707		t.Run(tc, func(t *testing.T) {
708			d, _, err := NewFromString(tc)
709			if err != nil {
710				t.Fatal(err)
711			}
712			var z Decimal
713			z.Neg(d)
714			s := z.String()
715			if s != expect {
716				t.Fatalf("expected %s, got %s", expect, s)
717			}
718		})
719	}
720}
721
722func TestReduce(t *testing.T) {
723	tests := map[string]int{
724		"-0":        0,
725		"0":         0,
726		"0.0":       0,
727		"00":        0,
728		"0.00":      0,
729		"-01000":    3,
730		"01000":     3,
731		"-1":        0,
732		"1":         0,
733		"-10.000E4": 4,
734		"10.000E4":  4,
735		"-10.00":    3,
736		"10.00":     3,
737		"-10":       1,
738		"10":        1,
739		"-143200000000000000000000000000000000000000000000000000000000": 56,
740		"143200000000000000000000000000000000000000000000000000000000":  56,
741		"Inf": 0,
742		"NaN": 0,
743	}
744
745	for s, n := range tests {
746		t.Run(s, func(t *testing.T) {
747			d, _, err := NewFromString(s)
748			if err != nil {
749				t.Fatal(err)
750			}
751			_, got := d.Reduce(d)
752			if n != got {
753				t.Fatalf("got %v, expected %v", got, n)
754			}
755		})
756	}
757}
758
759// TestSizeof is meant to catch changes that unexpectedly increase
760// the size of the Decimal struct.
761func TestSizeof(t *testing.T) {
762	var d Decimal
763	if s := unsafe.Sizeof(d); s != 48 {
764		t.Errorf("sizeof(Decimal) changed: %d", s)
765	}
766	var c Context
767	if s := unsafe.Sizeof(c); s != 32 {
768		t.Errorf("sizeof(Context) changed: %d", s)
769	}
770}
771
772func TestJSONEncoding(t *testing.T) {
773	var encodingTests = []string{
774		"0",
775		"1",
776		"2",
777		"10",
778		"1000",
779		"1234567890",
780		"298472983472983471903246121093472394872319615612417471234712061",
781		"0.0",
782		"NaN",
783		"Inf",
784		"123.456",
785		"1E1",
786		"1E-1",
787		"1.2E3",
788	}
789
790	for _, test := range encodingTests {
791		for _, sign := range []string{"", "+", "-"} {
792			x := sign + test
793			var tx Decimal
794			tx.SetString(x)
795			b, err := json.Marshal(&tx)
796			if err != nil {
797				t.Errorf("marshaling of %s failed: %s", &tx, err)
798				continue
799			}
800			var rx Decimal
801			if err := json.Unmarshal(b, &rx); err != nil {
802				t.Errorf("unmarshaling of %s failed: %s", &tx, err)
803				continue
804			}
805			if rx.CmpTotal(&tx) != 0 {
806				t.Errorf("JSON encoding of %s failed: got %s want %s", &tx, &rx, &tx)
807			}
808		}
809	}
810}
811