1package pgtype 2 3import ( 4 "database/sql/driver" 5 "encoding/binary" 6 "math" 7 "math/big" 8 "strconv" 9 "strings" 10 11 "github.com/jackc/pgx/pgio" 12 "github.com/pkg/errors" 13) 14 15// PostgreSQL internal numeric storage uses 16-bit "digits" with base of 10,000 16const nbase = 10000 17 18var big0 *big.Int = big.NewInt(0) 19var big1 *big.Int = big.NewInt(1) 20var big10 *big.Int = big.NewInt(10) 21var big100 *big.Int = big.NewInt(100) 22var big1000 *big.Int = big.NewInt(1000) 23 24var bigMaxInt8 *big.Int = big.NewInt(math.MaxInt8) 25var bigMinInt8 *big.Int = big.NewInt(math.MinInt8) 26var bigMaxInt16 *big.Int = big.NewInt(math.MaxInt16) 27var bigMinInt16 *big.Int = big.NewInt(math.MinInt16) 28var bigMaxInt32 *big.Int = big.NewInt(math.MaxInt32) 29var bigMinInt32 *big.Int = big.NewInt(math.MinInt32) 30var bigMaxInt64 *big.Int = big.NewInt(math.MaxInt64) 31var bigMinInt64 *big.Int = big.NewInt(math.MinInt64) 32var bigMaxInt *big.Int = big.NewInt(int64(maxInt)) 33var bigMinInt *big.Int = big.NewInt(int64(minInt)) 34 35var bigMaxUint8 *big.Int = big.NewInt(math.MaxUint8) 36var bigMaxUint16 *big.Int = big.NewInt(math.MaxUint16) 37var bigMaxUint32 *big.Int = big.NewInt(math.MaxUint32) 38var bigMaxUint64 *big.Int = (&big.Int{}).SetUint64(uint64(math.MaxUint64)) 39var bigMaxUint *big.Int = (&big.Int{}).SetUint64(uint64(maxUint)) 40 41var bigNBase *big.Int = big.NewInt(nbase) 42var bigNBaseX2 *big.Int = big.NewInt(nbase * nbase) 43var bigNBaseX3 *big.Int = big.NewInt(nbase * nbase * nbase) 44var bigNBaseX4 *big.Int = big.NewInt(nbase * nbase * nbase * nbase) 45 46type Numeric struct { 47 Int *big.Int 48 Exp int32 49 Status Status 50} 51 52func (dst *Numeric) Set(src interface{}) error { 53 if src == nil { 54 *dst = Numeric{Status: Null} 55 return nil 56 } 57 58 switch value := src.(type) { 59 case float32: 60 num, exp, err := parseNumericString(strconv.FormatFloat(float64(value), 'f', -1, 64)) 61 if err != nil { 62 return err 63 } 64 *dst = Numeric{Int: num, Exp: exp, Status: Present} 65 case float64: 66 num, exp, err := parseNumericString(strconv.FormatFloat(value, 'f', -1, 64)) 67 if err != nil { 68 return err 69 } 70 *dst = Numeric{Int: num, Exp: exp, Status: Present} 71 case int8: 72 *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} 73 case uint8: 74 *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} 75 case int16: 76 *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} 77 case uint16: 78 *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} 79 case int32: 80 *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} 81 case uint32: 82 *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} 83 case int64: 84 *dst = Numeric{Int: big.NewInt(value), Status: Present} 85 case uint64: 86 *dst = Numeric{Int: (&big.Int{}).SetUint64(value), Status: Present} 87 case int: 88 *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} 89 case uint: 90 *dst = Numeric{Int: (&big.Int{}).SetUint64(uint64(value)), Status: Present} 91 case string: 92 num, exp, err := parseNumericString(value) 93 if err != nil { 94 return err 95 } 96 *dst = Numeric{Int: num, Exp: exp, Status: Present} 97 default: 98 if originalSrc, ok := underlyingNumberType(src); ok { 99 return dst.Set(originalSrc) 100 } 101 return errors.Errorf("cannot convert %v to Numeric", value) 102 } 103 104 return nil 105} 106 107func (dst *Numeric) Get() interface{} { 108 switch dst.Status { 109 case Present: 110 return dst 111 case Null: 112 return nil 113 default: 114 return dst.Status 115 } 116} 117 118func (src *Numeric) AssignTo(dst interface{}) error { 119 switch src.Status { 120 case Present: 121 switch v := dst.(type) { 122 case *float32: 123 f, err := src.toFloat64() 124 if err != nil { 125 return err 126 } 127 return float64AssignTo(f, src.Status, dst) 128 case *float64: 129 f, err := src.toFloat64() 130 if err != nil { 131 return err 132 } 133 return float64AssignTo(f, src.Status, dst) 134 case *int: 135 normalizedInt, err := src.toBigInt() 136 if err != nil { 137 return err 138 } 139 if normalizedInt.Cmp(bigMaxInt) > 0 { 140 return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) 141 } 142 if normalizedInt.Cmp(bigMinInt) < 0 { 143 return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v) 144 } 145 *v = int(normalizedInt.Int64()) 146 case *int8: 147 normalizedInt, err := src.toBigInt() 148 if err != nil { 149 return err 150 } 151 if normalizedInt.Cmp(bigMaxInt8) > 0 { 152 return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) 153 } 154 if normalizedInt.Cmp(bigMinInt8) < 0 { 155 return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v) 156 } 157 *v = int8(normalizedInt.Int64()) 158 case *int16: 159 normalizedInt, err := src.toBigInt() 160 if err != nil { 161 return err 162 } 163 if normalizedInt.Cmp(bigMaxInt16) > 0 { 164 return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) 165 } 166 if normalizedInt.Cmp(bigMinInt16) < 0 { 167 return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v) 168 } 169 *v = int16(normalizedInt.Int64()) 170 case *int32: 171 normalizedInt, err := src.toBigInt() 172 if err != nil { 173 return err 174 } 175 if normalizedInt.Cmp(bigMaxInt32) > 0 { 176 return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) 177 } 178 if normalizedInt.Cmp(bigMinInt32) < 0 { 179 return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v) 180 } 181 *v = int32(normalizedInt.Int64()) 182 case *int64: 183 normalizedInt, err := src.toBigInt() 184 if err != nil { 185 return err 186 } 187 if normalizedInt.Cmp(bigMaxInt64) > 0 { 188 return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) 189 } 190 if normalizedInt.Cmp(bigMinInt64) < 0 { 191 return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v) 192 } 193 *v = normalizedInt.Int64() 194 case *uint: 195 normalizedInt, err := src.toBigInt() 196 if err != nil { 197 return err 198 } 199 if normalizedInt.Cmp(big0) < 0 { 200 return errors.Errorf("%d is less than zero for %T", normalizedInt, *v) 201 } else if normalizedInt.Cmp(bigMaxUint) > 0 { 202 return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) 203 } 204 *v = uint(normalizedInt.Uint64()) 205 case *uint8: 206 normalizedInt, err := src.toBigInt() 207 if err != nil { 208 return err 209 } 210 if normalizedInt.Cmp(big0) < 0 { 211 return errors.Errorf("%d is less than zero for %T", normalizedInt, *v) 212 } else if normalizedInt.Cmp(bigMaxUint8) > 0 { 213 return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) 214 } 215 *v = uint8(normalizedInt.Uint64()) 216 case *uint16: 217 normalizedInt, err := src.toBigInt() 218 if err != nil { 219 return err 220 } 221 if normalizedInt.Cmp(big0) < 0 { 222 return errors.Errorf("%d is less than zero for %T", normalizedInt, *v) 223 } else if normalizedInt.Cmp(bigMaxUint16) > 0 { 224 return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) 225 } 226 *v = uint16(normalizedInt.Uint64()) 227 case *uint32: 228 normalizedInt, err := src.toBigInt() 229 if err != nil { 230 return err 231 } 232 if normalizedInt.Cmp(big0) < 0 { 233 return errors.Errorf("%d is less than zero for %T", normalizedInt, *v) 234 } else if normalizedInt.Cmp(bigMaxUint32) > 0 { 235 return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) 236 } 237 *v = uint32(normalizedInt.Uint64()) 238 case *uint64: 239 normalizedInt, err := src.toBigInt() 240 if err != nil { 241 return err 242 } 243 if normalizedInt.Cmp(big0) < 0 { 244 return errors.Errorf("%d is less than zero for %T", normalizedInt, *v) 245 } else if normalizedInt.Cmp(bigMaxUint64) > 0 { 246 return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) 247 } 248 *v = normalizedInt.Uint64() 249 default: 250 if nextDst, retry := GetAssignToDstType(dst); retry { 251 return src.AssignTo(nextDst) 252 } 253 } 254 case Null: 255 return NullAssignTo(dst) 256 } 257 258 return nil 259} 260 261func (dst *Numeric) toBigInt() (*big.Int, error) { 262 if dst.Exp == 0 { 263 return dst.Int, nil 264 } 265 266 num := &big.Int{} 267 num.Set(dst.Int) 268 if dst.Exp > 0 { 269 mul := &big.Int{} 270 mul.Exp(big10, big.NewInt(int64(dst.Exp)), nil) 271 num.Mul(num, mul) 272 return num, nil 273 } 274 275 div := &big.Int{} 276 div.Exp(big10, big.NewInt(int64(-dst.Exp)), nil) 277 remainder := &big.Int{} 278 num.DivMod(num, div, remainder) 279 if remainder.Cmp(big0) != 0 { 280 return nil, errors.Errorf("cannot convert %v to integer", dst) 281 } 282 return num, nil 283} 284 285func (src *Numeric) toFloat64() (float64, error) { 286 f, err := strconv.ParseFloat(src.Int.String(), 64) 287 if err != nil { 288 return 0, err 289 } 290 if src.Exp > 0 { 291 for i := 0; i < int(src.Exp); i++ { 292 f *= 10 293 } 294 } else if src.Exp < 0 { 295 for i := 0; i > int(src.Exp); i-- { 296 f /= 10 297 } 298 } 299 return f, nil 300} 301 302func (dst *Numeric) DecodeText(ci *ConnInfo, src []byte) error { 303 if src == nil { 304 *dst = Numeric{Status: Null} 305 return nil 306 } 307 308 num, exp, err := parseNumericString(string(src)) 309 if err != nil { 310 return err 311 } 312 313 *dst = Numeric{Int: num, Exp: exp, Status: Present} 314 return nil 315} 316 317func parseNumericString(str string) (n *big.Int, exp int32, err error) { 318 parts := strings.SplitN(str, ".", 2) 319 digits := strings.Join(parts, "") 320 321 if len(parts) > 1 { 322 exp = int32(-len(parts[1])) 323 } else { 324 for len(digits) > 1 && digits[len(digits)-1] == '0' { 325 digits = digits[:len(digits)-1] 326 exp++ 327 } 328 } 329 330 accum := &big.Int{} 331 if _, ok := accum.SetString(digits, 10); !ok { 332 return nil, 0, errors.Errorf("%s is not a number", str) 333 } 334 335 return accum, exp, nil 336} 337 338func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { 339 if src == nil { 340 *dst = Numeric{Status: Null} 341 return nil 342 } 343 344 if len(src) < 8 { 345 return errors.Errorf("numeric incomplete %v", src) 346 } 347 348 rp := 0 349 ndigits := int16(binary.BigEndian.Uint16(src[rp:])) 350 rp += 2 351 352 if ndigits == 0 { 353 *dst = Numeric{Int: big.NewInt(0), Status: Present} 354 return nil 355 } 356 357 weight := int16(binary.BigEndian.Uint16(src[rp:])) 358 rp += 2 359 sign := int16(binary.BigEndian.Uint16(src[rp:])) 360 rp += 2 361 dscale := int16(binary.BigEndian.Uint16(src[rp:])) 362 rp += 2 363 364 if len(src[rp:]) < int(ndigits)*2 { 365 return errors.Errorf("numeric incomplete %v", src) 366 } 367 368 accum := &big.Int{} 369 370 for i := 0; i < int(ndigits+3)/4; i++ { 371 int64accum, bytesRead, digitsRead := nbaseDigitsToInt64(src[rp:]) 372 rp += bytesRead 373 374 if i > 0 { 375 var mul *big.Int 376 switch digitsRead { 377 case 1: 378 mul = bigNBase 379 case 2: 380 mul = bigNBaseX2 381 case 3: 382 mul = bigNBaseX3 383 case 4: 384 mul = bigNBaseX4 385 default: 386 return errors.Errorf("invalid digitsRead: %d (this can't happen)", digitsRead) 387 } 388 accum.Mul(accum, mul) 389 } 390 391 accum.Add(accum, big.NewInt(int64accum)) 392 } 393 394 exp := (int32(weight) - int32(ndigits) + 1) * 4 395 396 if dscale > 0 { 397 fracNBaseDigits := ndigits - weight - 1 398 fracDecimalDigits := fracNBaseDigits * 4 399 400 if dscale > fracDecimalDigits { 401 multCount := int(dscale - fracDecimalDigits) 402 for i := 0; i < multCount; i++ { 403 accum.Mul(accum, big10) 404 exp-- 405 } 406 } else if dscale < fracDecimalDigits { 407 divCount := int(fracDecimalDigits - dscale) 408 for i := 0; i < divCount; i++ { 409 accum.Div(accum, big10) 410 exp++ 411 } 412 } 413 } 414 415 reduced := &big.Int{} 416 remainder := &big.Int{} 417 if exp >= 0 { 418 for { 419 reduced.DivMod(accum, big10, remainder) 420 if remainder.Cmp(big0) != 0 { 421 break 422 } 423 accum.Set(reduced) 424 exp++ 425 } 426 } 427 428 if sign != 0 { 429 accum.Neg(accum) 430 } 431 432 *dst = Numeric{Int: accum, Exp: exp, Status: Present} 433 434 return nil 435 436} 437 438func nbaseDigitsToInt64(src []byte) (accum int64, bytesRead, digitsRead int) { 439 digits := len(src) / 2 440 if digits > 4 { 441 digits = 4 442 } 443 444 rp := 0 445 446 for i := 0; i < digits; i++ { 447 if i > 0 { 448 accum *= nbase 449 } 450 accum += int64(binary.BigEndian.Uint16(src[rp:])) 451 rp += 2 452 } 453 454 return accum, rp, digits 455} 456 457func (src *Numeric) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { 458 switch src.Status { 459 case Null: 460 return nil, nil 461 case Undefined: 462 return nil, errUndefined 463 } 464 465 buf = append(buf, src.Int.String()...) 466 buf = append(buf, 'e') 467 buf = append(buf, strconv.FormatInt(int64(src.Exp), 10)...) 468 return buf, nil 469} 470 471func (src *Numeric) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { 472 switch src.Status { 473 case Null: 474 return nil, nil 475 case Undefined: 476 return nil, errUndefined 477 } 478 479 var sign int16 480 if src.Int.Cmp(big0) < 0 { 481 sign = 16384 482 } 483 484 absInt := &big.Int{} 485 wholePart := &big.Int{} 486 fracPart := &big.Int{} 487 remainder := &big.Int{} 488 absInt.Abs(src.Int) 489 490 // Normalize absInt and exp to where exp is always a multiple of 4. This makes 491 // converting to 16-bit base 10,000 digits easier. 492 var exp int32 493 switch src.Exp % 4 { 494 case 1, -3: 495 exp = src.Exp - 1 496 absInt.Mul(absInt, big10) 497 case 2, -2: 498 exp = src.Exp - 2 499 absInt.Mul(absInt, big100) 500 case 3, -1: 501 exp = src.Exp - 3 502 absInt.Mul(absInt, big1000) 503 default: 504 exp = src.Exp 505 } 506 507 if exp < 0 { 508 divisor := &big.Int{} 509 divisor.Exp(big10, big.NewInt(int64(-exp)), nil) 510 wholePart.DivMod(absInt, divisor, fracPart) 511 fracPart.Add(fracPart, divisor) 512 } else { 513 wholePart = absInt 514 } 515 516 var wholeDigits, fracDigits []int16 517 518 for wholePart.Cmp(big0) != 0 { 519 wholePart.DivMod(wholePart, bigNBase, remainder) 520 wholeDigits = append(wholeDigits, int16(remainder.Int64())) 521 } 522 523 if fracPart.Cmp(big0) != 0 { 524 for fracPart.Cmp(big1) != 0 { 525 fracPart.DivMod(fracPart, bigNBase, remainder) 526 fracDigits = append(fracDigits, int16(remainder.Int64())) 527 } 528 } 529 530 buf = pgio.AppendInt16(buf, int16(len(wholeDigits)+len(fracDigits))) 531 532 var weight int16 533 if len(wholeDigits) > 0 { 534 weight = int16(len(wholeDigits) - 1) 535 if exp > 0 { 536 weight += int16(exp / 4) 537 } 538 } else { 539 weight = int16(exp/4) - 1 + int16(len(fracDigits)) 540 } 541 buf = pgio.AppendInt16(buf, weight) 542 543 buf = pgio.AppendInt16(buf, sign) 544 545 var dscale int16 546 if src.Exp < 0 { 547 dscale = int16(-src.Exp) 548 } 549 buf = pgio.AppendInt16(buf, dscale) 550 551 for i := len(wholeDigits) - 1; i >= 0; i-- { 552 buf = pgio.AppendInt16(buf, wholeDigits[i]) 553 } 554 555 for i := len(fracDigits) - 1; i >= 0; i-- { 556 buf = pgio.AppendInt16(buf, fracDigits[i]) 557 } 558 559 return buf, nil 560} 561 562// Scan implements the database/sql Scanner interface. 563func (dst *Numeric) Scan(src interface{}) error { 564 if src == nil { 565 *dst = Numeric{Status: Null} 566 return nil 567 } 568 569 switch src := src.(type) { 570 case float64: 571 // TODO 572 // *dst = Numeric{Float: src, Status: Present} 573 return nil 574 case string: 575 return dst.DecodeText(nil, []byte(src)) 576 case []byte: 577 srcCopy := make([]byte, len(src)) 578 copy(srcCopy, src) 579 return dst.DecodeText(nil, srcCopy) 580 } 581 582 return errors.Errorf("cannot scan %T", src) 583} 584 585// Value implements the database/sql/driver Valuer interface. 586func (src *Numeric) Value() (driver.Value, error) { 587 switch src.Status { 588 case Present: 589 buf, err := src.EncodeText(nil, nil) 590 if err != nil { 591 return nil, err 592 } 593 594 return string(buf), nil 595 case Null: 596 return nil, nil 597 default: 598 return nil, errUndefined 599 } 600} 601