1package pgtype 2 3import ( 4 "database/sql/driver" 5 "encoding/binary" 6 "math" 7 "math/big" 8 "strconv" 9 "strings" 10 11 "github.com/jackc/pgio" 12 errors "golang.org/x/xerrors" 13) 14 15// PostgreSQL internal numeric storage uses 16-bit "digits" with base of 10,000 16const nbase = 10000 17 18const ( 19 pgNumericNaN = 0x000000000c000000 20 pgNumericNaNSign = 0x0c00 21) 22 23var big0 *big.Int = big.NewInt(0) 24var big1 *big.Int = big.NewInt(1) 25var big10 *big.Int = big.NewInt(10) 26var big100 *big.Int = big.NewInt(100) 27var big1000 *big.Int = big.NewInt(1000) 28 29var bigMaxInt8 *big.Int = big.NewInt(math.MaxInt8) 30var bigMinInt8 *big.Int = big.NewInt(math.MinInt8) 31var bigMaxInt16 *big.Int = big.NewInt(math.MaxInt16) 32var bigMinInt16 *big.Int = big.NewInt(math.MinInt16) 33var bigMaxInt32 *big.Int = big.NewInt(math.MaxInt32) 34var bigMinInt32 *big.Int = big.NewInt(math.MinInt32) 35var bigMaxInt64 *big.Int = big.NewInt(math.MaxInt64) 36var bigMinInt64 *big.Int = big.NewInt(math.MinInt64) 37var bigMaxInt *big.Int = big.NewInt(int64(maxInt)) 38var bigMinInt *big.Int = big.NewInt(int64(minInt)) 39 40var bigMaxUint8 *big.Int = big.NewInt(math.MaxUint8) 41var bigMaxUint16 *big.Int = big.NewInt(math.MaxUint16) 42var bigMaxUint32 *big.Int = big.NewInt(math.MaxUint32) 43var bigMaxUint64 *big.Int = (&big.Int{}).SetUint64(uint64(math.MaxUint64)) 44var bigMaxUint *big.Int = (&big.Int{}).SetUint64(uint64(maxUint)) 45 46var bigNBase *big.Int = big.NewInt(nbase) 47var bigNBaseX2 *big.Int = big.NewInt(nbase * nbase) 48var bigNBaseX3 *big.Int = big.NewInt(nbase * nbase * nbase) 49var bigNBaseX4 *big.Int = big.NewInt(nbase * nbase * nbase * nbase) 50 51type Numeric struct { 52 Int *big.Int 53 Exp int32 54 Status Status 55 NaN bool 56} 57 58func (dst *Numeric) Set(src interface{}) error { 59 if src == nil { 60 *dst = Numeric{Status: Null} 61 return nil 62 } 63 64 if value, ok := src.(interface{ Get() interface{} }); ok { 65 value2 := value.Get() 66 if value2 != value { 67 return dst.Set(value2) 68 } 69 } 70 71 switch value := src.(type) { 72 case float32: 73 if math.IsNaN(float64(value)) { 74 *dst = Numeric{Status: Present, NaN: true} 75 return nil 76 } 77 num, exp, err := parseNumericString(strconv.FormatFloat(float64(value), 'f', -1, 64)) 78 if err != nil { 79 return err 80 } 81 *dst = Numeric{Int: num, Exp: exp, Status: Present} 82 case float64: 83 if math.IsNaN(value) { 84 *dst = Numeric{Status: Present, NaN: true} 85 return nil 86 } 87 num, exp, err := parseNumericString(strconv.FormatFloat(value, 'f', -1, 64)) 88 if err != nil { 89 return err 90 } 91 *dst = Numeric{Int: num, Exp: exp, Status: Present} 92 case int8: 93 *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} 94 case uint8: 95 *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} 96 case int16: 97 *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} 98 case uint16: 99 *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} 100 case int32: 101 *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} 102 case uint32: 103 *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} 104 case int64: 105 *dst = Numeric{Int: big.NewInt(value), Status: Present} 106 case uint64: 107 *dst = Numeric{Int: (&big.Int{}).SetUint64(value), Status: Present} 108 case int: 109 *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} 110 case uint: 111 *dst = Numeric{Int: (&big.Int{}).SetUint64(uint64(value)), Status: Present} 112 case string: 113 num, exp, err := parseNumericString(value) 114 if err != nil { 115 return err 116 } 117 *dst = Numeric{Int: num, Exp: exp, Status: Present} 118 case *float64: 119 if value == nil { 120 *dst = Numeric{Status: Null} 121 } else { 122 return dst.Set(*value) 123 } 124 case *float32: 125 if value == nil { 126 *dst = Numeric{Status: Null} 127 } else { 128 return dst.Set(*value) 129 } 130 case *int8: 131 if value == nil { 132 *dst = Numeric{Status: Null} 133 } else { 134 return dst.Set(*value) 135 } 136 case *uint8: 137 if value == nil { 138 *dst = Numeric{Status: Null} 139 } else { 140 return dst.Set(*value) 141 } 142 case *int16: 143 if value == nil { 144 *dst = Numeric{Status: Null} 145 } else { 146 return dst.Set(*value) 147 } 148 case *uint16: 149 if value == nil { 150 *dst = Numeric{Status: Null} 151 } else { 152 return dst.Set(*value) 153 } 154 case *int32: 155 if value == nil { 156 *dst = Numeric{Status: Null} 157 } else { 158 return dst.Set(*value) 159 } 160 case *uint32: 161 if value == nil { 162 *dst = Numeric{Status: Null} 163 } else { 164 return dst.Set(*value) 165 } 166 case *int64: 167 if value == nil { 168 *dst = Numeric{Status: Null} 169 } else { 170 return dst.Set(*value) 171 } 172 case *uint64: 173 if value == nil { 174 *dst = Numeric{Status: Null} 175 } else { 176 return dst.Set(*value) 177 } 178 case *int: 179 if value == nil { 180 *dst = Numeric{Status: Null} 181 } else { 182 return dst.Set(*value) 183 } 184 case *uint: 185 if value == nil { 186 *dst = Numeric{Status: Null} 187 } else { 188 return dst.Set(*value) 189 } 190 case *string: 191 if value == nil { 192 *dst = Numeric{Status: Null} 193 } else { 194 return dst.Set(*value) 195 } 196 default: 197 if originalSrc, ok := underlyingNumberType(src); ok { 198 return dst.Set(originalSrc) 199 } 200 return errors.Errorf("cannot convert %v to Numeric", value) 201 } 202 203 return nil 204} 205 206func (dst Numeric) Get() interface{} { 207 switch dst.Status { 208 case Present: 209 return dst 210 case Null: 211 return nil 212 default: 213 return dst.Status 214 } 215} 216 217func (src *Numeric) AssignTo(dst interface{}) error { 218 switch src.Status { 219 case Present: 220 switch v := dst.(type) { 221 case *float32: 222 f, err := src.toFloat64() 223 if err != nil { 224 return err 225 } 226 return float64AssignTo(f, src.Status, dst) 227 case *float64: 228 f, err := src.toFloat64() 229 if err != nil { 230 return err 231 } 232 return float64AssignTo(f, src.Status, dst) 233 case *int: 234 normalizedInt, err := src.toBigInt() 235 if err != nil { 236 return err 237 } 238 if normalizedInt.Cmp(bigMaxInt) > 0 { 239 return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) 240 } 241 if normalizedInt.Cmp(bigMinInt) < 0 { 242 return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v) 243 } 244 *v = int(normalizedInt.Int64()) 245 case *int8: 246 normalizedInt, err := src.toBigInt() 247 if err != nil { 248 return err 249 } 250 if normalizedInt.Cmp(bigMaxInt8) > 0 { 251 return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) 252 } 253 if normalizedInt.Cmp(bigMinInt8) < 0 { 254 return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v) 255 } 256 *v = int8(normalizedInt.Int64()) 257 case *int16: 258 normalizedInt, err := src.toBigInt() 259 if err != nil { 260 return err 261 } 262 if normalizedInt.Cmp(bigMaxInt16) > 0 { 263 return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) 264 } 265 if normalizedInt.Cmp(bigMinInt16) < 0 { 266 return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v) 267 } 268 *v = int16(normalizedInt.Int64()) 269 case *int32: 270 normalizedInt, err := src.toBigInt() 271 if err != nil { 272 return err 273 } 274 if normalizedInt.Cmp(bigMaxInt32) > 0 { 275 return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) 276 } 277 if normalizedInt.Cmp(bigMinInt32) < 0 { 278 return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v) 279 } 280 *v = int32(normalizedInt.Int64()) 281 case *int64: 282 normalizedInt, err := src.toBigInt() 283 if err != nil { 284 return err 285 } 286 if normalizedInt.Cmp(bigMaxInt64) > 0 { 287 return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) 288 } 289 if normalizedInt.Cmp(bigMinInt64) < 0 { 290 return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v) 291 } 292 *v = normalizedInt.Int64() 293 case *uint: 294 normalizedInt, err := src.toBigInt() 295 if err != nil { 296 return err 297 } 298 if normalizedInt.Cmp(big0) < 0 { 299 return errors.Errorf("%d is less than zero for %T", normalizedInt, *v) 300 } else if normalizedInt.Cmp(bigMaxUint) > 0 { 301 return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) 302 } 303 *v = uint(normalizedInt.Uint64()) 304 case *uint8: 305 normalizedInt, err := src.toBigInt() 306 if err != nil { 307 return err 308 } 309 if normalizedInt.Cmp(big0) < 0 { 310 return errors.Errorf("%d is less than zero for %T", normalizedInt, *v) 311 } else if normalizedInt.Cmp(bigMaxUint8) > 0 { 312 return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) 313 } 314 *v = uint8(normalizedInt.Uint64()) 315 case *uint16: 316 normalizedInt, err := src.toBigInt() 317 if err != nil { 318 return err 319 } 320 if normalizedInt.Cmp(big0) < 0 { 321 return errors.Errorf("%d is less than zero for %T", normalizedInt, *v) 322 } else if normalizedInt.Cmp(bigMaxUint16) > 0 { 323 return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) 324 } 325 *v = uint16(normalizedInt.Uint64()) 326 case *uint32: 327 normalizedInt, err := src.toBigInt() 328 if err != nil { 329 return err 330 } 331 if normalizedInt.Cmp(big0) < 0 { 332 return errors.Errorf("%d is less than zero for %T", normalizedInt, *v) 333 } else if normalizedInt.Cmp(bigMaxUint32) > 0 { 334 return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) 335 } 336 *v = uint32(normalizedInt.Uint64()) 337 case *uint64: 338 normalizedInt, err := src.toBigInt() 339 if err != nil { 340 return err 341 } 342 if normalizedInt.Cmp(big0) < 0 { 343 return errors.Errorf("%d is less than zero for %T", normalizedInt, *v) 344 } else if normalizedInt.Cmp(bigMaxUint64) > 0 { 345 return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) 346 } 347 *v = normalizedInt.Uint64() 348 default: 349 if nextDst, retry := GetAssignToDstType(dst); retry { 350 return src.AssignTo(nextDst) 351 } 352 return errors.Errorf("unable to assign to %T", dst) 353 } 354 case Null: 355 return NullAssignTo(dst) 356 } 357 358 return nil 359} 360 361func (dst *Numeric) toBigInt() (*big.Int, error) { 362 if dst.Exp == 0 { 363 return dst.Int, nil 364 } 365 366 num := &big.Int{} 367 num.Set(dst.Int) 368 if dst.Exp > 0 { 369 mul := &big.Int{} 370 mul.Exp(big10, big.NewInt(int64(dst.Exp)), nil) 371 num.Mul(num, mul) 372 return num, nil 373 } 374 375 div := &big.Int{} 376 div.Exp(big10, big.NewInt(int64(-dst.Exp)), nil) 377 remainder := &big.Int{} 378 num.DivMod(num, div, remainder) 379 if remainder.Cmp(big0) != 0 { 380 return nil, errors.Errorf("cannot convert %v to integer", dst) 381 } 382 return num, nil 383} 384 385func (src *Numeric) toFloat64() (float64, error) { 386 if src.NaN { 387 return math.NaN(), nil 388 } 389 390 buf := make([]byte, 0, 32) 391 392 buf = append(buf, src.Int.String()...) 393 buf = append(buf, 'e') 394 buf = append(buf, strconv.FormatInt(int64(src.Exp), 10)...) 395 396 f, err := strconv.ParseFloat(string(buf), 64) 397 if err != nil { 398 return 0, err 399 } 400 return f, nil 401} 402 403func (dst *Numeric) DecodeText(ci *ConnInfo, src []byte) error { 404 if src == nil { 405 *dst = Numeric{Status: Null} 406 return nil 407 } 408 409 if string(src) == "'NaN'" { // includes single quotes, see EncodeText for details. 410 *dst = Numeric{Status: Present, NaN: true} 411 return nil 412 } 413 414 num, exp, err := parseNumericString(string(src)) 415 if err != nil { 416 return err 417 } 418 419 *dst = Numeric{Int: num, Exp: exp, Status: Present} 420 return nil 421} 422 423func parseNumericString(str string) (n *big.Int, exp int32, err error) { 424 parts := strings.SplitN(str, ".", 2) 425 digits := strings.Join(parts, "") 426 427 if len(parts) > 1 { 428 exp = int32(-len(parts[1])) 429 } else { 430 for len(digits) > 1 && digits[len(digits)-1] == '0' && digits[len(digits)-2] != '-' { 431 digits = digits[:len(digits)-1] 432 exp++ 433 } 434 } 435 436 accum := &big.Int{} 437 if _, ok := accum.SetString(digits, 10); !ok { 438 return nil, 0, errors.Errorf("%s is not a number", str) 439 } 440 441 return accum, exp, nil 442} 443 444func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { 445 if src == nil { 446 *dst = Numeric{Status: Null} 447 return nil 448 } 449 450 if len(src) < 8 { 451 return errors.Errorf("numeric incomplete %v", src) 452 } 453 454 rp := 0 455 ndigits := int16(binary.BigEndian.Uint16(src[rp:])) 456 rp += 2 457 weight := int16(binary.BigEndian.Uint16(src[rp:])) 458 rp += 2 459 sign := int16(binary.BigEndian.Uint16(src[rp:])) 460 rp += 2 461 dscale := int16(binary.BigEndian.Uint16(src[rp:])) 462 rp += 2 463 464 if sign == pgNumericNaNSign { 465 *dst = Numeric{Status: Present, NaN: true} 466 return nil 467 } 468 469 if ndigits == 0 { 470 *dst = Numeric{Int: big.NewInt(0), Status: Present} 471 return nil 472 } 473 474 if len(src[rp:]) < int(ndigits)*2 { 475 return errors.Errorf("numeric incomplete %v", src) 476 } 477 478 accum := &big.Int{} 479 480 for i := 0; i < int(ndigits+3)/4; i++ { 481 int64accum, bytesRead, digitsRead := nbaseDigitsToInt64(src[rp:]) 482 rp += bytesRead 483 484 if i > 0 { 485 var mul *big.Int 486 switch digitsRead { 487 case 1: 488 mul = bigNBase 489 case 2: 490 mul = bigNBaseX2 491 case 3: 492 mul = bigNBaseX3 493 case 4: 494 mul = bigNBaseX4 495 default: 496 return errors.Errorf("invalid digitsRead: %d (this can't happen)", digitsRead) 497 } 498 accum.Mul(accum, mul) 499 } 500 501 accum.Add(accum, big.NewInt(int64accum)) 502 } 503 504 exp := (int32(weight) - int32(ndigits) + 1) * 4 505 506 if dscale > 0 { 507 fracNBaseDigits := ndigits - weight - 1 508 fracDecimalDigits := fracNBaseDigits * 4 509 510 if dscale > fracDecimalDigits { 511 multCount := int(dscale - fracDecimalDigits) 512 for i := 0; i < multCount; i++ { 513 accum.Mul(accum, big10) 514 exp-- 515 } 516 } else if dscale < fracDecimalDigits { 517 divCount := int(fracDecimalDigits - dscale) 518 for i := 0; i < divCount; i++ { 519 accum.Div(accum, big10) 520 exp++ 521 } 522 } 523 } 524 525 reduced := &big.Int{} 526 remainder := &big.Int{} 527 if exp >= 0 { 528 for { 529 reduced.DivMod(accum, big10, remainder) 530 if remainder.Cmp(big0) != 0 { 531 break 532 } 533 accum.Set(reduced) 534 exp++ 535 } 536 } 537 538 if sign != 0 { 539 accum.Neg(accum) 540 } 541 542 *dst = Numeric{Int: accum, Exp: exp, Status: Present} 543 544 return nil 545 546} 547 548func nbaseDigitsToInt64(src []byte) (accum int64, bytesRead, digitsRead int) { 549 digits := len(src) / 2 550 if digits > 4 { 551 digits = 4 552 } 553 554 rp := 0 555 556 for i := 0; i < digits; i++ { 557 if i > 0 { 558 accum *= nbase 559 } 560 accum += int64(binary.BigEndian.Uint16(src[rp:])) 561 rp += 2 562 } 563 564 return accum, rp, digits 565} 566 567func (src Numeric) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { 568 switch src.Status { 569 case Null: 570 return nil, nil 571 case Undefined: 572 return nil, errUndefined 573 } 574 575 if src.NaN { 576 // encode as 'NaN' including single quotes, 577 // "When writing this value [NaN] as a constant in an SQL command, 578 // you must put quotes around it, for example UPDATE table SET x = 'NaN'" 579 // https://www.postgresql.org/docs/9.3/datatype-numeric.html 580 buf = append(buf, "'NaN'"...) 581 return buf, nil 582 } 583 584 buf = append(buf, src.Int.String()...) 585 buf = append(buf, 'e') 586 buf = append(buf, strconv.FormatInt(int64(src.Exp), 10)...) 587 return buf, nil 588} 589 590func (src Numeric) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { 591 switch src.Status { 592 case Null: 593 return nil, nil 594 case Undefined: 595 return nil, errUndefined 596 } 597 598 if src.NaN { 599 buf = pgio.AppendUint64(buf, pgNumericNaN) 600 return buf, nil 601 } 602 603 var sign int16 604 if src.Int.Cmp(big0) < 0 { 605 sign = 16384 606 } 607 608 absInt := &big.Int{} 609 wholePart := &big.Int{} 610 fracPart := &big.Int{} 611 remainder := &big.Int{} 612 absInt.Abs(src.Int) 613 614 // Normalize absInt and exp to where exp is always a multiple of 4. This makes 615 // converting to 16-bit base 10,000 digits easier. 616 var exp int32 617 switch src.Exp % 4 { 618 case 1, -3: 619 exp = src.Exp - 1 620 absInt.Mul(absInt, big10) 621 case 2, -2: 622 exp = src.Exp - 2 623 absInt.Mul(absInt, big100) 624 case 3, -1: 625 exp = src.Exp - 3 626 absInt.Mul(absInt, big1000) 627 default: 628 exp = src.Exp 629 } 630 631 if exp < 0 { 632 divisor := &big.Int{} 633 divisor.Exp(big10, big.NewInt(int64(-exp)), nil) 634 wholePart.DivMod(absInt, divisor, fracPart) 635 fracPart.Add(fracPart, divisor) 636 } else { 637 wholePart = absInt 638 } 639 640 var wholeDigits, fracDigits []int16 641 642 for wholePart.Cmp(big0) != 0 { 643 wholePart.DivMod(wholePart, bigNBase, remainder) 644 wholeDigits = append(wholeDigits, int16(remainder.Int64())) 645 } 646 647 if fracPart.Cmp(big0) != 0 { 648 for fracPart.Cmp(big1) != 0 { 649 fracPart.DivMod(fracPart, bigNBase, remainder) 650 fracDigits = append(fracDigits, int16(remainder.Int64())) 651 } 652 } 653 654 buf = pgio.AppendInt16(buf, int16(len(wholeDigits)+len(fracDigits))) 655 656 var weight int16 657 if len(wholeDigits) > 0 { 658 weight = int16(len(wholeDigits) - 1) 659 if exp > 0 { 660 weight += int16(exp / 4) 661 } 662 } else { 663 weight = int16(exp/4) - 1 + int16(len(fracDigits)) 664 } 665 buf = pgio.AppendInt16(buf, weight) 666 667 buf = pgio.AppendInt16(buf, sign) 668 669 var dscale int16 670 if src.Exp < 0 { 671 dscale = int16(-src.Exp) 672 } 673 buf = pgio.AppendInt16(buf, dscale) 674 675 for i := len(wholeDigits) - 1; i >= 0; i-- { 676 buf = pgio.AppendInt16(buf, wholeDigits[i]) 677 } 678 679 for i := len(fracDigits) - 1; i >= 0; i-- { 680 buf = pgio.AppendInt16(buf, fracDigits[i]) 681 } 682 683 return buf, nil 684} 685 686// Scan implements the database/sql Scanner interface. 687func (dst *Numeric) Scan(src interface{}) error { 688 if src == nil { 689 *dst = Numeric{Status: Null} 690 return nil 691 } 692 693 switch src := src.(type) { 694 case string: 695 return dst.DecodeText(nil, []byte(src)) 696 case []byte: 697 srcCopy := make([]byte, len(src)) 698 copy(srcCopy, src) 699 return dst.DecodeText(nil, srcCopy) 700 } 701 702 return errors.Errorf("cannot scan %T", src) 703} 704 705// Value implements the database/sql/driver Valuer interface. 706func (src Numeric) Value() (driver.Value, error) { 707 switch src.Status { 708 case Present: 709 buf, err := src.EncodeText(nil, nil) 710 if err != nil { 711 return nil, err 712 } 713 714 return string(buf), nil 715 case Null: 716 return nil, nil 717 default: 718 return nil, errUndefined 719 } 720} 721