/*
Copyright 2014 SAP SE

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package driver

import (
	"database/sql/driver"
	"errors"
	"fmt"
	"math"
	"math/big"
	"sync"
)

// Size of a big.Word, derived with constant arithmetic at compile time
// (same approach as src/pkg/math/big/arith.go).
const (
	// _m is a big.Word with every bit set.
	_m = ^big.Word(0)
	// _logS adds up bits 8, 16 and 32 of _m: 2 for a 32-bit Word, 3 for a
	// 64-bit Word, i.e. log2 of the Word size in bytes.
	_logS = _m>>8&1 + _m>>16&1 + _m>>32&1
	// _S is the size of a big.Word in bytes.
	_S = 1 << _logS
)

// Parameters of the decimal128 format, see
// http://en.wikipedia.org/wiki/Decimal128_floating-point_format
const (
	// dec128Digits is the number of significand digits Decimal.Value asks
	// convertRatToDecimal to produce.
	dec128Digits = 34
	// dec128Bias is added to the exponent by encodeDecimal and subtracted
	// again by decodeDecimal.
	dec128Bias = 6176
	// dec128MinExp is the minimum exponent Decimal.Value passes to
	// convertRatToDecimal.
	dec128MinExp = -6176
	// dec128MaxExp is the maximum exponent Decimal.Value passes to
	// convertRatToDecimal.
	dec128MaxExp = 6111
)

const (
	// decimalSize is the length in bytes of an encoded decimal value:
	// Decimal.Scan rejects any other length and encodeDecimal allocates
	// exactly this many bytes.
	decimalSize = 16
)

// Shared big.Int values 0, 1 and 10. The code in this file only reads them.
var natZero = big.NewInt(0)
var natOne = big.NewInt(1)
var natTen = big.NewInt(10)

// nat is a table of the powers of ten 10^0 .. 10^10. exp10 returns entries
// of this table for exponents below len(nat).
var nat = []*big.Int{
	natOne,                  // 10^0
	natTen,                  // 10^1
	big.NewInt(100),         // 10^2
	big.NewInt(1000),        // 10^3
	big.NewInt(10000),       // 10^4
	big.NewInt(100000),      // 10^5
	big.NewInt(1000000),     // 10^6
	big.NewInt(10000000),    // 10^7
	big.NewInt(100000000),   // 10^8
	big.NewInt(1000000000),  // 10^9
	big.NewInt(10000000000), // 10^10
}

// lg10 is the binary logarithm of ten.
const lg10 = math.Ln10 / math.Ln2 // ~log2(10)

// maxDecimal is 10^34 - 1 (given as big-endian bytes), the largest integer
// with dec128Digits decimal digits.
var maxDecimal = new(big.Int).SetBytes([]byte{0x01, 0xED, 0x09, 0xBE, 0xAD, 0x87, 0xC0, 0x37, 0x8D, 0x8E, 0x63, 0xFF, 0xFF, 0xFF, 0xFF})

// decFlags is the bit set returned by convertRatToDecimal to describe the
// outcome of a conversion.
type decFlags byte

const (
	// dfNotExact is set when the division left a remainder, i.e. the
	// significand had to be rounded.
	dfNotExact decFlags = 1 << iota
	// dfOverflow is set when the exponent is greater than the maximum
	// exponent.
	dfOverflow
	// dfUnderflow is set when the exponent is smaller than the minimum
	// exponent.
	dfUnderflow
)

// ErrDecimalOutOfRange means that a big.Rat exceeds the size of hdb decimal fields.
79var ErrDecimalOutOfRange = errors.New("decimal out of range error") 80 81// big.Int free list 82var bigIntFree = sync.Pool{ 83 New: func() interface{} { return new(big.Int) }, 84} 85 86// big.Rat free list 87var bigRatFree = sync.Pool{ 88 New: func() interface{} { return new(big.Rat) }, 89} 90 91// A Decimal is the driver representation of a database decimal field value as big.Rat. 92type Decimal big.Rat 93 94// Scan implements the database/sql/Scanner interface. 95func (d *Decimal) Scan(src interface{}) error { 96 97 b, ok := src.([]byte) 98 if !ok { 99 return fmt.Errorf("decimal: invalid data type %T", src) 100 } 101 102 if len(b) != decimalSize { 103 return fmt.Errorf("decimal: invalid size %d of %v - %d expected", len(b), b, decimalSize) 104 } 105 106 if (b[15] & 0x60) == 0x60 { 107 return fmt.Errorf("decimal: format (infinity, nan, ...) not supported : %v", b) 108 } 109 110 v := (*big.Rat)(d) 111 p := v.Num() 112 q := v.Denom() 113 114 neg, exp := decodeDecimal(b, p) 115 116 switch { 117 case exp < 0: 118 q.Set(exp10(exp * -1)) 119 case exp == 0: 120 q.Set(natOne) 121 case exp > 0: 122 p.Mul(p, exp10(exp)) 123 q.Set(natOne) 124 } 125 126 if neg { 127 v.Neg(v) 128 } 129 return nil 130} 131 132// Value implements the database/sql/Valuer interface. 
133func (d Decimal) Value() (driver.Value, error) { 134 m := bigIntFree.Get().(*big.Int) 135 neg, exp, df := convertRatToDecimal((*big.Rat)(&d), m, dec128Digits, dec128MinExp, dec128MaxExp) 136 137 var v driver.Value 138 var err error 139 140 switch { 141 default: 142 v, err = encodeDecimal(m, neg, exp) 143 case df&dfUnderflow != 0: // set to zero 144 m.Set(natZero) 145 v, err = encodeDecimal(m, false, 0) 146 case df&dfOverflow != 0: 147 err = ErrDecimalOutOfRange 148 } 149 150 // performance (avoid expensive defer) 151 bigIntFree.Put(m) 152 153 return v, err 154} 155 156func convertRatToDecimal(x *big.Rat, m *big.Int, digits, minExp, maxExp int) (bool, int, decFlags) { 157 158 neg := x.Sign() < 0 //store sign 159 160 if x.Num().Cmp(natZero) == 0 { // zero 161 m.Set(natZero) 162 return neg, 0, 0 163 } 164 165 c := bigRatFree.Get().(*big.Rat).Abs(x) // copy && abs 166 a := c.Num() 167 b := c.Denom() 168 169 exp, shift := 0, 0 170 171 if c.IsInt() { 172 exp = digits10(a) - 1 173 } else { 174 shift = digits10(a) - digits10(b) 175 switch { 176 case shift < 0: 177 a.Mul(a, exp10(shift*-1)) 178 case shift > 0: 179 b.Mul(b, exp10(shift)) 180 } 181 if a.Cmp(b) == -1 { 182 exp = shift - 1 183 } else { 184 exp = shift 185 } 186 } 187 188 var df decFlags 189 190 switch { 191 default: 192 exp = max(exp-digits+1, minExp) 193 case exp < minExp: 194 df |= dfUnderflow 195 exp = exp - digits + 1 196 } 197 198 if exp > maxExp { 199 df |= dfOverflow 200 } 201 202 shift = exp - shift 203 switch { 204 case shift < 0: 205 a.Mul(a, exp10(shift*-1)) 206 case exp > 0: 207 b.Mul(b, exp10(shift)) 208 } 209 210 m.QuoRem(a, b, a) // reuse a as rest 211 if a.Cmp(natZero) != 0 { 212 // round (business >= 0.5 up) 213 df |= dfNotExact 214 if a.Add(a, a).Cmp(b) >= 0 { 215 m.Add(m, natOne) 216 if m.Cmp(exp10(digits)) == 0 { 217 shift := min(digits, maxExp-exp) 218 if shift < 1 { // overflow -> shift one at minimum 219 df |= dfOverflow 220 shift = 1 221 } 222 m.Set(exp10(digits - shift)) 223 exp += 
shift 224 } 225 } 226 } 227 228 // norm 229 for exp < maxExp { 230 a.QuoRem(m, natTen, b) // reuse a, b 231 if b.Cmp(natZero) != 0 { 232 break 233 } 234 m.Set(a) 235 exp++ 236 } 237 238 // performance (avoid expensive defer) 239 bigRatFree.Put(c) 240 241 return neg, exp, df 242} 243 244func min(a, b int) int { 245 if a < b { 246 return a 247 } 248 return b 249} 250 251func max(a, b int) int { 252 if a > b { 253 return a 254 } 255 return b 256} 257 258// performance: tested with reference work variable 259// - but int.Set is expensive, so let's live with big.Int creation for n >= len(nat) 260func exp10(n int) *big.Int { 261 if n < len(nat) { 262 return nat[n] 263 } 264 r := big.NewInt(int64(n)) 265 return r.Exp(natTen, r, nil) 266} 267 268func digits10(p *big.Int) int { 269 k := p.BitLen() // 2^k <= p < 2^(k+1) - 1 270 //i := int(float64(k) / lg10) //minimal digits base 10 271 //i := int(float64(k) / lg10) //minimal digits base 10 272 i := k * 100 / 332 273 if i < 1 { 274 i = 1 275 } 276 277 for ; ; i++ { 278 if p.Cmp(exp10(i)) < 0 { 279 return i 280 } 281 } 282} 283 284func decodeDecimal(b []byte, m *big.Int) (bool, int) { 285 286 neg := (b[15] & 0x80) != 0 287 exp := int((((uint16(b[15])<<8)|uint16(b[14]))<<1)>>2) - dec128Bias 288 289 b14 := b[14] // save b[14] 290 b[14] &= 0x01 // keep the mantissa bit (rest: sign and exp) 291 292 //most significand byte 293 msb := 14 294 for msb > 0 { 295 if b[msb] != 0 { 296 break 297 } 298 msb-- 299 } 300 301 //calc number of words 302 numWords := (msb / _S) + 1 303 w := make([]big.Word, numWords) 304 305 k := numWords - 1 306 d := big.Word(0) 307 for i := msb; i >= 0; i-- { 308 d |= big.Word(b[i]) 309 if k*_S == i { 310 w[k] = d 311 k-- 312 d = 0 313 } 314 d <<= 8 315 } 316 b[14] = b14 // restore b[14] 317 m.SetBits(w) 318 return neg, exp 319} 320 321func encodeDecimal(m *big.Int, neg bool, exp int) (driver.Value, error) { 322 323 b := make([]byte, decimalSize) 324 325 // little endian bigint words (significand) -> little 
endian db decimal format 326 j := 0 327 for _, d := range m.Bits() { 328 for i := 0; i < 8; i++ { 329 b[j] = byte(d) 330 d >>= 8 331 j++ 332 } 333 } 334 335 exp += dec128Bias 336 b[14] |= (byte(exp) << 1) 337 b[15] = byte(uint16(exp) >> 7) 338 339 if neg { 340 b[15] |= 0x80 341 } 342 343 return b, nil 344} 345 346// NullDecimal represents an Decimal that may be null. 347// NullDecimal implements the Scanner interface so 348// it can be used as a scan destination, similar to NullString. 349type NullDecimal struct { 350 Decimal *Decimal 351 Valid bool // Valid is true if Decimal is not NULL 352} 353 354// Scan implements the Scanner interface. 355func (n *NullDecimal) Scan(value interface{}) error { 356 var b []byte 357 358 b, n.Valid = value.([]byte) 359 if !n.Valid { 360 return nil 361 } 362 if n.Decimal == nil { 363 return fmt.Errorf("invalid decimal value %v", n.Decimal) 364 } 365 return n.Decimal.Scan(b) 366} 367 368// Value implements the driver Valuer interface. 369func (n NullDecimal) Value() (driver.Value, error) { 370 if !n.Valid { 371 return nil, nil 372 } 373 if n.Decimal == nil { 374 return nil, fmt.Errorf("invalid decimal value %v", n.Decimal) 375 } 376 return n.Decimal.Value() 377} 378