1/*
2Copyright 2014 SAP SE
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package driver
18
19import (
20	"database/sql/driver"
21	"errors"
22	"fmt"
23	"math"
24	"math/big"
25	"sync"
26)
27
// Size of a big.Word, derived the same way as in math/big
// (src/pkg/math/big/arith.go). decodeDecimal uses _S to pack bytes into
// []big.Word.
const (
	// Compute the size _S of a Word in bytes: _m has every bit set, so each
	// _m>>n&1 term is 1 exactly when a Word is wider than n bits. Their sum
	// is log2 of the word size in bytes (2 for 32-bit words, 3 for 64-bit
	// words), and _S is the corresponding byte count (4 or 8).
	_m    = ^big.Word(0)
	_logS = _m>>8&1 + _m>>16&1 + _m>>32&1
	_S    = 1 << _logS
)
35
// Parameters of the IEEE 754-2008 decimal128 format used for database
// decimal fields (http://en.wikipedia.org/wiki/Decimal128_floating-point_format).
// The exponents refer to the least significant digit of the significand.
const (
	dec128Digits = 34    // significand precision in decimal digits
	dec128Bias   = 6176  // exponent bias of the binary encoding
	dec128MinExp = -6176 // smallest exponent
	dec128MaxExp = 6111  // largest exponent

	decimalSize = 16 // size of an encoded decimal in bytes
)
47
// Shared, read-only big.Int constants.
var (
	natZero = big.NewInt(0)
	natOne  = big.NewInt(1)
	natTen  = big.NewInt(10)
)

// nat caches the powers of ten 10^0 through 10^10 (used by exp10). The first
// two entries are natOne and natTen themselves, not copies. The entries are
// shared and must never be modified.
var nat = func() []*big.Int {
	powers := make([]*big.Int, 11)
	powers[0], powers[1] = natOne, natTen
	for i := 2; i < len(powers); i++ {
		powers[i] = new(big.Int).Mul(powers[i-1], natTen)
	}
	return powers
}()
65
// lg10 is log2(10) (about 3.32), the number of bits per decimal digit.
const lg10 = math.Ln10 / math.Ln2

// maxDecimal is the largest decimal128 significand: 10^34 - 1 (34 nines).
var maxDecimal = new(big.Int).Sub(new(big.Int).Exp(big.NewInt(10), big.NewInt(34), nil), big.NewInt(1))
69
// decFlags is a bit set describing how convertRatToDecimal's result deviates
// from its input.
type decFlags byte

// Conversion flags; each one occupies its own bit.
const (
	dfNotExact  decFlags = 1 << 0 // the significand was rounded
	dfOverflow  decFlags = 1 << 1 // the exponent exceeds the maximum
	dfUnderflow decFlags = 1 << 2 // the value is below the smallest exponent
)
77
// ErrDecimalOutOfRange is returned when a big.Rat is too large to be stored
// in an hdb decimal field.
var ErrDecimalOutOfRange = errors.New("decimal out of range error")

// Free lists of scratch values that the decimal conversion code borrows and
// returns, to avoid allocating a big.Int / big.Rat per conversion.
var (
	bigIntFree = sync.Pool{New: func() interface{} { return new(big.Int) }}
	bigRatFree = sync.Pool{New: func() interface{} { return new(big.Rat) }}
)
90
// A Decimal is the driver representation of a database decimal field value as big.Rat.
//
// Decimal has big.Rat as its underlying type; convert a *Decimal with
// (*big.Rat)(d) to use the math/big API on it. Scan decodes the 16-byte
// decimal128 encoding delivered by the database (infinity and NaN are
// rejected). Value encodes it back, rounding to at most 34 significant digits.
type Decimal big.Rat
93
94// Scan implements the database/sql/Scanner interface.
95func (d *Decimal) Scan(src interface{}) error {
96
97	b, ok := src.([]byte)
98	if !ok {
99		return fmt.Errorf("decimal: invalid data type %T", src)
100	}
101
102	if len(b) != decimalSize {
103		return fmt.Errorf("decimal: invalid size %d of %v - %d expected", len(b), b, decimalSize)
104	}
105
106	if (b[15] & 0x60) == 0x60 {
107		return fmt.Errorf("decimal: format (infinity, nan, ...) not supported : %v", b)
108	}
109
110	v := (*big.Rat)(d)
111	p := v.Num()
112	q := v.Denom()
113
114	neg, exp := decodeDecimal(b, p)
115
116	switch {
117	case exp < 0:
118		q.Set(exp10(exp * -1))
119	case exp == 0:
120		q.Set(natOne)
121	case exp > 0:
122		p.Mul(p, exp10(exp))
123		q.Set(natOne)
124	}
125
126	if neg {
127		v.Neg(v)
128	}
129	return nil
130}
131
132// Value implements the database/sql/Valuer interface.
133func (d Decimal) Value() (driver.Value, error) {
134	m := bigIntFree.Get().(*big.Int)
135	neg, exp, df := convertRatToDecimal((*big.Rat)(&d), m, dec128Digits, dec128MinExp, dec128MaxExp)
136
137	var v driver.Value
138	var err error
139
140	switch {
141	default:
142		v, err = encodeDecimal(m, neg, exp)
143	case df&dfUnderflow != 0: // set to zero
144		m.Set(natZero)
145		v, err = encodeDecimal(m, false, 0)
146	case df&dfOverflow != 0:
147		err = ErrDecimalOutOfRange
148	}
149
150	// performance (avoid expensive defer)
151	bigIntFree.Put(m)
152
153	return v, err
154}
155
156func convertRatToDecimal(x *big.Rat, m *big.Int, digits, minExp, maxExp int) (bool, int, decFlags) {
157
158	neg := x.Sign() < 0 //store sign
159
160	if x.Num().Cmp(natZero) == 0 { // zero
161		m.Set(natZero)
162		return neg, 0, 0
163	}
164
165	c := bigRatFree.Get().(*big.Rat).Abs(x) // copy && abs
166	a := c.Num()
167	b := c.Denom()
168
169	exp, shift := 0, 0
170
171	if c.IsInt() {
172		exp = digits10(a) - 1
173	} else {
174		shift = digits10(a) - digits10(b)
175		switch {
176		case shift < 0:
177			a.Mul(a, exp10(shift*-1))
178		case shift > 0:
179			b.Mul(b, exp10(shift))
180		}
181		if a.Cmp(b) == -1 {
182			exp = shift - 1
183		} else {
184			exp = shift
185		}
186	}
187
188	var df decFlags
189
190	switch {
191	default:
192		exp = max(exp-digits+1, minExp)
193	case exp < minExp:
194		df |= dfUnderflow
195		exp = exp - digits + 1
196	}
197
198	if exp > maxExp {
199		df |= dfOverflow
200	}
201
202	shift = exp - shift
203	switch {
204	case shift < 0:
205		a.Mul(a, exp10(shift*-1))
206	case exp > 0:
207		b.Mul(b, exp10(shift))
208	}
209
210	m.QuoRem(a, b, a) // reuse a as rest
211	if a.Cmp(natZero) != 0 {
212		// round (business >= 0.5 up)
213		df |= dfNotExact
214		if a.Add(a, a).Cmp(b) >= 0 {
215			m.Add(m, natOne)
216			if m.Cmp(exp10(digits)) == 0 {
217				shift := min(digits, maxExp-exp)
218				if shift < 1 { // overflow -> shift one at minimum
219					df |= dfOverflow
220					shift = 1
221				}
222				m.Set(exp10(digits - shift))
223				exp += shift
224			}
225		}
226	}
227
228	// norm
229	for exp < maxExp {
230		a.QuoRem(m, natTen, b) // reuse a, b
231		if b.Cmp(natZero) != 0 {
232			break
233		}
234		m.Set(a)
235		exp++
236	}
237
238	// performance (avoid expensive defer)
239	bigRatFree.Put(c)
240
241	return neg, exp, df
242}
243
// min returns the smaller of a and b.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
250
// max returns the larger of a and b.
func max(a, b int) int {
	if b > a {
		return b
	}
	return a
}
257
258// performance: tested with reference work variable
259// - but int.Set is expensive, so let's live with big.Int creation for n >= len(nat)
260func exp10(n int) *big.Int {
261	if n < len(nat) {
262		return nat[n]
263	}
264	r := big.NewInt(int64(n))
265	return r.Exp(natTen, r, nil)
266}
267
268func digits10(p *big.Int) int {
269	k := p.BitLen() // 2^k <= p < 2^(k+1) - 1
270	//i := int(float64(k) / lg10) //minimal digits base 10
271	//i := int(float64(k) / lg10) //minimal digits base 10
272	i := k * 100 / 332
273	if i < 1 {
274		i = 1
275	}
276
277	for ; ; i++ {
278		if p.Cmp(exp10(i)) < 0 {
279			return i
280		}
281	}
282}
283
284func decodeDecimal(b []byte, m *big.Int) (bool, int) {
285
286	neg := (b[15] & 0x80) != 0
287	exp := int((((uint16(b[15])<<8)|uint16(b[14]))<<1)>>2) - dec128Bias
288
289	b14 := b[14]  // save b[14]
290	b[14] &= 0x01 // keep the mantissa bit (rest: sign and exp)
291
292	//most significand byte
293	msb := 14
294	for msb > 0 {
295		if b[msb] != 0 {
296			break
297		}
298		msb--
299	}
300
301	//calc number of words
302	numWords := (msb / _S) + 1
303	w := make([]big.Word, numWords)
304
305	k := numWords - 1
306	d := big.Word(0)
307	for i := msb; i >= 0; i-- {
308		d |= big.Word(b[i])
309		if k*_S == i {
310			w[k] = d
311			k--
312			d = 0
313		}
314		d <<= 8
315	}
316	b[14] = b14 // restore b[14]
317	m.SetBits(w)
318	return neg, exp
319}
320
321func encodeDecimal(m *big.Int, neg bool, exp int) (driver.Value, error) {
322
323	b := make([]byte, decimalSize)
324
325	// little endian bigint words (significand) -> little endian db decimal format
326	j := 0
327	for _, d := range m.Bits() {
328		for i := 0; i < 8; i++ {
329			b[j] = byte(d)
330			d >>= 8
331			j++
332		}
333	}
334
335	exp += dec128Bias
336	b[14] |= (byte(exp) << 1)
337	b[15] = byte(uint16(exp) >> 7)
338
339	if neg {
340		b[15] |= 0x80
341	}
342
343	return b, nil
344}
345
346// NullDecimal represents an Decimal that may be null.
347// NullDecimal implements the Scanner interface so
348// it can be used as a scan destination, similar to NullString.
349type NullDecimal struct {
350	Decimal *Decimal
351	Valid   bool // Valid is true if Decimal is not NULL
352}
353
354// Scan implements the Scanner interface.
355func (n *NullDecimal) Scan(value interface{}) error {
356	var b []byte
357
358	b, n.Valid = value.([]byte)
359	if !n.Valid {
360		return nil
361	}
362	if n.Decimal == nil {
363		return fmt.Errorf("invalid decimal value %v", n.Decimal)
364	}
365	return n.Decimal.Scan(b)
366}
367
368// Value implements the driver Valuer interface.
369func (n NullDecimal) Value() (driver.Value, error) {
370	if !n.Valid {
371		return nil, nil
372	}
373	if n.Decimal == nil {
374		return nil, fmt.Errorf("invalid decimal value %v", n.Decimal)
375	}
376	return n.Decimal.Value()
377}
378