1package pgtype 2 3import ( 4 "database/sql/driver" 5 "encoding/binary" 6 "math" 7 "math/big" 8 "strconv" 9 "strings" 10 11 "github.com/jackc/pgio" 12 errors "golang.org/x/xerrors" 13) 14 15// PostgreSQL internal numeric storage uses 16-bit "digits" with base of 10,000 16const nbase = 10000 17 18var big0 *big.Int = big.NewInt(0) 19var big1 *big.Int = big.NewInt(1) 20var big10 *big.Int = big.NewInt(10) 21var big100 *big.Int = big.NewInt(100) 22var big1000 *big.Int = big.NewInt(1000) 23 24var bigMaxInt8 *big.Int = big.NewInt(math.MaxInt8) 25var bigMinInt8 *big.Int = big.NewInt(math.MinInt8) 26var bigMaxInt16 *big.Int = big.NewInt(math.MaxInt16) 27var bigMinInt16 *big.Int = big.NewInt(math.MinInt16) 28var bigMaxInt32 *big.Int = big.NewInt(math.MaxInt32) 29var bigMinInt32 *big.Int = big.NewInt(math.MinInt32) 30var bigMaxInt64 *big.Int = big.NewInt(math.MaxInt64) 31var bigMinInt64 *big.Int = big.NewInt(math.MinInt64) 32var bigMaxInt *big.Int = big.NewInt(int64(maxInt)) 33var bigMinInt *big.Int = big.NewInt(int64(minInt)) 34 35var bigMaxUint8 *big.Int = big.NewInt(math.MaxUint8) 36var bigMaxUint16 *big.Int = big.NewInt(math.MaxUint16) 37var bigMaxUint32 *big.Int = big.NewInt(math.MaxUint32) 38var bigMaxUint64 *big.Int = (&big.Int{}).SetUint64(uint64(math.MaxUint64)) 39var bigMaxUint *big.Int = (&big.Int{}).SetUint64(uint64(maxUint)) 40 41var bigNBase *big.Int = big.NewInt(nbase) 42var bigNBaseX2 *big.Int = big.NewInt(nbase * nbase) 43var bigNBaseX3 *big.Int = big.NewInt(nbase * nbase * nbase) 44var bigNBaseX4 *big.Int = big.NewInt(nbase * nbase * nbase * nbase) 45 46type Numeric struct { 47 Int *big.Int 48 Exp int32 49 Status Status 50} 51 52func (dst *Numeric) Set(src interface{}) error { 53 if src == nil { 54 *dst = Numeric{Status: Null} 55 return nil 56 } 57 58 if value, ok := src.(interface{ Get() interface{} }); ok { 59 value2 := value.Get() 60 if value2 != value { 61 return dst.Set(value2) 62 } 63 } 64 65 switch value := src.(type) { 66 case float32: 67 num, exp, err := parseNumericString(strconv.FormatFloat(float64(value), 'f', -1, 64)) 68 if err != nil { 69 return err 70 } 71 *dst = Numeric{Int: num, Exp: exp, Status: Present} 72 case float64: 73 num, exp, err := parseNumericString(strconv.FormatFloat(value, 'f', -1, 64)) 74 if err != nil { 75 return err 76 } 77 *dst = Numeric{Int: num, Exp: exp, Status: Present} 78 case int8: 79 *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} 80 case uint8: 81 *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} 82 case int16: 83 *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} 84 case uint16: 85 *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} 86 case int32: 87 *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} 88 case uint32: 89 *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} 90 case int64: 91 *dst = Numeric{Int: big.NewInt(value), Status: Present} 92 case uint64: 93 *dst = Numeric{Int: (&big.Int{}).SetUint64(value), Status: Present} 94 case int: 95 *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} 96 case uint: 97 *dst = Numeric{Int: (&big.Int{}).SetUint64(uint64(value)), Status: Present} 98 case string: 99 num, exp, err := parseNumericString(value) 100 if err != nil { 101 return err 102 } 103 *dst = Numeric{Int: num, Exp: exp, Status: Present} 104 default: 105 if originalSrc, ok := underlyingNumberType(src); ok { 106 return dst.Set(originalSrc) 107 } 108 return errors.Errorf("cannot convert %v to Numeric", value) 109 } 110 111 return nil 112} 113 114func (dst Numeric) Get() interface{} { 115 switch dst.Status { 116 case Present: 117 return dst 118 case Null: 119 return nil 120 default: 121 return dst.Status 122 } 123} 124 125func (src *Numeric) AssignTo(dst interface{}) error { 126 switch src.Status { 127 case Present: 128 switch v := dst.(type) { 129 case *float32: 130 f, err := src.toFloat64() 131 if err != nil { 132 return err 133 } 134 return float64AssignTo(f, src.Status, dst) 135 case *float64: 136 f, err := src.toFloat64() 137 if err != nil { 138 return err 139 } 140 return float64AssignTo(f, src.Status, dst) 141 case *int: 142 normalizedInt, err := src.toBigInt() 143 if err != nil { 144 return err 145 } 146 if normalizedInt.Cmp(bigMaxInt) > 0 { 147 return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) 148 } 149 if normalizedInt.Cmp(bigMinInt) < 0 { 150 return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v) 151 } 152 *v = int(normalizedInt.Int64()) 153 case *int8: 154 normalizedInt, err := src.toBigInt() 155 if err != nil { 156 return err 157 } 158 if normalizedInt.Cmp(bigMaxInt8) > 0 { 159 return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) 160 } 161 if normalizedInt.Cmp(bigMinInt8) < 0 { 162 return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v) 163 } 164 *v = int8(normalizedInt.Int64()) 165 case *int16: 166 normalizedInt, err := src.toBigInt() 167 if err != nil { 168 return err 169 } 170 if normalizedInt.Cmp(bigMaxInt16) > 0 { 171 return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) 172 } 173 if normalizedInt.Cmp(bigMinInt16) < 0 { 174 return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v) 175 } 176 *v = int16(normalizedInt.Int64()) 177 case *int32: 178 normalizedInt, err := src.toBigInt() 179 if err != nil { 180 return err 181 } 182 if normalizedInt.Cmp(bigMaxInt32) > 0 { 183 return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) 184 } 185 if normalizedInt.Cmp(bigMinInt32) < 0 { 186 return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v) 187 } 188 *v = int32(normalizedInt.Int64()) 189 case *int64: 190 normalizedInt, err := src.toBigInt() 191 if err != nil { 192 return err 193 } 194 if normalizedInt.Cmp(bigMaxInt64) > 0 { 195 return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) 196 } 197 if normalizedInt.Cmp(bigMinInt64) < 0 { 198 return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v) 199 } 200 *v = normalizedInt.Int64() 201 case *uint: 202 normalizedInt, err := src.toBigInt() 203 if err != nil { 204 return err 205 } 206 if normalizedInt.Cmp(big0) < 0 { 207 return errors.Errorf("%d is less than zero for %T", normalizedInt, *v) 208 } else if normalizedInt.Cmp(bigMaxUint) > 0 { 209 return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) 210 } 211 *v = uint(normalizedInt.Uint64()) 212 case *uint8: 213 normalizedInt, err := src.toBigInt() 214 if err != nil { 215 return err 216 } 217 if normalizedInt.Cmp(big0) < 0 { 218 return errors.Errorf("%d is less than zero for %T", normalizedInt, *v) 219 } else if normalizedInt.Cmp(bigMaxUint8) > 0 { 220 return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) 221 } 222 *v = uint8(normalizedInt.Uint64()) 223 case *uint16: 224 normalizedInt, err := src.toBigInt() 225 if err != nil { 226 return err 227 } 228 if normalizedInt.Cmp(big0) < 0 { 229 return errors.Errorf("%d is less than zero for %T", normalizedInt, *v) 230 } else if normalizedInt.Cmp(bigMaxUint16) > 0 { 231 return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) 232 } 233 *v = uint16(normalizedInt.Uint64()) 234 case *uint32: 235 normalizedInt, err := src.toBigInt() 236 if err != nil { 237 return err 238 } 239 if normalizedInt.Cmp(big0) < 0 { 240 return errors.Errorf("%d is less than zero for %T", normalizedInt, *v) 241 } else if normalizedInt.Cmp(bigMaxUint32) > 0 { 242 return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) 243 } 244 *v = uint32(normalizedInt.Uint64()) 245 case *uint64: 246 normalizedInt, err := src.toBigInt() 247 if err != nil { 248 return err 249 } 250 if normalizedInt.Cmp(big0) < 0 { 251 return errors.Errorf("%d is less than zero for %T", normalizedInt, *v) 252 } else if normalizedInt.Cmp(bigMaxUint64) > 0 { 253 return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) 254 } 255 *v = normalizedInt.Uint64() 256 default: 257 if nextDst, retry := GetAssignToDstType(dst); retry { 258 return src.AssignTo(nextDst) 259 } 260 return errors.Errorf("unable to assign to %T", dst) 261 } 262 case Null: 263 return NullAssignTo(dst) 264 } 265 266 return nil 267} 268 269func (dst *Numeric) toBigInt() (*big.Int, error) { 270 if dst.Exp == 0 { 271 return dst.Int, nil 272 } 273 274 num := &big.Int{} 275 num.Set(dst.Int) 276 if dst.Exp > 0 { 277 mul := &big.Int{} 278 mul.Exp(big10, big.NewInt(int64(dst.Exp)), nil) 279 num.Mul(num, mul) 280 return num, nil 281 } 282 283 div := &big.Int{} 284 div.Exp(big10, big.NewInt(int64(-dst.Exp)), nil) 285 remainder := &big.Int{} 286 num.DivMod(num, div, remainder) 287 if remainder.Cmp(big0) != 0 { 288 return nil, errors.Errorf("cannot convert %v to integer", dst) 289 } 290 return num, nil 291} 292 293func (src *Numeric) toFloat64() (float64, error) { 294 f, err := strconv.ParseFloat(src.Int.String(), 64) 295 if err != nil { 296 return 0, err 297 } 298 if src.Exp > 0 { 299 for i := 0; i < int(src.Exp); i++ { 300 f *= 10 301 } 302 } else if src.Exp < 0 { 303 for i := 0; i > int(src.Exp); i-- { 304 f /= 10 305 } 306 } 307 return f, nil 308} 309 310func (dst *Numeric) DecodeText(ci *ConnInfo, src []byte) error { 311 if src == nil { 312 *dst = Numeric{Status: Null} 313 return nil 314 } 315 316 num, exp, err := parseNumericString(string(src)) 317 if err != nil { 318 return err 319 } 320 321 *dst = Numeric{Int: num, Exp: exp, Status: Present} 322 return nil 323} 324 325func parseNumericString(str string) (n *big.Int, exp int32, err error) { 326 parts := strings.SplitN(str, ".", 2) 327 digits := strings.Join(parts, "") 328 329 if len(parts) > 1 { 330 exp = int32(-len(parts[1])) 331 } else { 332 for len(digits) > 1 && digits[len(digits)-1] == '0' && digits[len(digits)-2] != '-' { 333 digits = digits[:len(digits)-1] 334 exp++ 335 } 336 } 337 338 accum := &big.Int{} 339 if _, ok := accum.SetString(digits, 10); !ok { 340 return nil, 0, errors.Errorf("%s is not a number", str) 341 } 342 343 return accum, exp, nil 344} 345 346func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { 347 if src == nil { 348 *dst = Numeric{Status: Null} 349 return nil 350 } 351 352 if len(src) < 8 { 353 return errors.Errorf("numeric incomplete %v", src) 354 } 355 356 rp := 0 357 ndigits := int16(binary.BigEndian.Uint16(src[rp:])) 358 rp += 2 359 360 if ndigits == 0 { 361 *dst = Numeric{Int: big.NewInt(0), Status: Present} 362 return nil 363 } 364 365 weight := int16(binary.BigEndian.Uint16(src[rp:])) 366 rp += 2 367 sign := int16(binary.BigEndian.Uint16(src[rp:])) 368 rp += 2 369 dscale := int16(binary.BigEndian.Uint16(src[rp:])) 370 rp += 2 371 372 if len(src[rp:]) < int(ndigits)*2 { 373 return errors.Errorf("numeric incomplete %v", src) 374 } 375 376 accum := &big.Int{} 377 378 for i := 0; i < int(ndigits+3)/4; i++ { 379 int64accum, bytesRead, digitsRead := nbaseDigitsToInt64(src[rp:]) 380 rp += bytesRead 381 382 if i > 0 { 383 var mul *big.Int 384 switch digitsRead { 385 case 1: 386 mul = bigNBase 387 case 2: 388 mul = bigNBaseX2 389 case 3: 390 mul = bigNBaseX3 391 case 4: 392 mul = bigNBaseX4 393 default: 394 return errors.Errorf("invalid digitsRead: %d (this can't happen)", digitsRead) 395 } 396 accum.Mul(accum, mul) 397 } 398 399 accum.Add(accum, big.NewInt(int64accum)) 400 } 401 402 exp := (int32(weight) - int32(ndigits) + 1) * 4 403 404 if dscale > 0 { 405 fracNBaseDigits := ndigits - weight - 1 406 fracDecimalDigits := fracNBaseDigits * 4 407 408 if dscale > fracDecimalDigits { 409 multCount := int(dscale - fracDecimalDigits) 410 for i := 0; i < multCount; i++ { 411 accum.Mul(accum, big10) 412 exp-- 413 } 414 } else if dscale < fracDecimalDigits { 415 divCount := int(fracDecimalDigits - dscale) 416 for i := 0; i < divCount; i++ { 417 accum.Div(accum, big10) 418 exp++ 419 } 420 } 421 } 422 423 reduced := &big.Int{} 424 remainder := &big.Int{} 425 if exp >= 0 { 426 for { 427 reduced.DivMod(accum, big10, remainder) 428 if remainder.Cmp(big0) != 0 { 429 break 430 } 431 accum.Set(reduced) 432 exp++ 433 } 434 } 435 436 if sign != 0 { 437 accum.Neg(accum) 438 } 439 440 *dst = Numeric{Int: accum, Exp: exp, Status: Present} 441 442 return nil 443 444} 445 446func nbaseDigitsToInt64(src []byte) (accum int64, bytesRead, digitsRead int) { 447 digits := len(src) / 2 448 if digits > 4 { 449 digits = 4 450 } 451 452 rp := 0 453 454 for i := 0; i < digits; i++ { 455 if i > 0 { 456 accum *= nbase 457 } 458 accum += int64(binary.BigEndian.Uint16(src[rp:])) 459 rp += 2 460 } 461 462 return accum, rp, digits 463} 464 465func (src Numeric) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { 466 switch src.Status { 467 case Null: 468 return nil, nil 469 case Undefined: 470 return nil, errUndefined 471 } 472 473 buf = append(buf, src.Int.String()...) 474 buf = append(buf, 'e') 475 buf = append(buf, strconv.FormatInt(int64(src.Exp), 10)...) 476 return buf, nil 477} 478 479func (src Numeric) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { 480 switch src.Status { 481 case Null: 482 return nil, nil 483 case Undefined: 484 return nil, errUndefined 485 } 486 487 var sign int16 488 if src.Int.Cmp(big0) < 0 { 489 sign = 16384 490 } 491 492 absInt := &big.Int{} 493 wholePart := &big.Int{} 494 fracPart := &big.Int{} 495 remainder := &big.Int{} 496 absInt.Abs(src.Int) 497 498 // Normalize absInt and exp to where exp is always a multiple of 4. This makes 499 // converting to 16-bit base 10,000 digits easier. 500 var exp int32 501 switch src.Exp % 4 { 502 case 1, -3: 503 exp = src.Exp - 1 504 absInt.Mul(absInt, big10) 505 case 2, -2: 506 exp = src.Exp - 2 507 absInt.Mul(absInt, big100) 508 case 3, -1: 509 exp = src.Exp - 3 510 absInt.Mul(absInt, big1000) 511 default: 512 exp = src.Exp 513 } 514 515 if exp < 0 { 516 divisor := &big.Int{} 517 divisor.Exp(big10, big.NewInt(int64(-exp)), nil) 518 wholePart.DivMod(absInt, divisor, fracPart) 519 fracPart.Add(fracPart, divisor) 520 } else { 521 wholePart = absInt 522 } 523 524 var wholeDigits, fracDigits []int16 525 526 for wholePart.Cmp(big0) != 0 { 527 wholePart.DivMod(wholePart, bigNBase, remainder) 528 wholeDigits = append(wholeDigits, int16(remainder.Int64())) 529 } 530 531 if fracPart.Cmp(big0) != 0 { 532 for fracPart.Cmp(big1) != 0 { 533 fracPart.DivMod(fracPart, bigNBase, remainder) 534 fracDigits = append(fracDigits, int16(remainder.Int64())) 535 } 536 } 537 538 buf = pgio.AppendInt16(buf, int16(len(wholeDigits)+len(fracDigits))) 539 540 var weight int16 541 if len(wholeDigits) > 0 { 542 weight = int16(len(wholeDigits) - 1) 543 if exp > 0 { 544 weight += int16(exp / 4) 545 } 546 } else { 547 weight = int16(exp/4) - 1 + int16(len(fracDigits)) 548 } 549 buf = pgio.AppendInt16(buf, weight) 550 551 buf = pgio.AppendInt16(buf, sign) 552 553 var dscale int16 554 if src.Exp < 0 { 555 dscale = int16(-src.Exp) 556 } 557 buf = pgio.AppendInt16(buf, dscale) 558 559 for i := len(wholeDigits) - 1; i >= 0; i-- { 560 buf = pgio.AppendInt16(buf, wholeDigits[i]) 561 } 562 563 for i := len(fracDigits) - 1; i >= 0; i-- { 564 buf = pgio.AppendInt16(buf, fracDigits[i]) 565 } 566 567 return buf, nil 568} 569 570// Scan implements the database/sql Scanner interface. 571func (dst *Numeric) Scan(src interface{}) error { 572 if src == nil { 573 *dst = Numeric{Status: Null} 574 return nil 575 } 576 577 switch src := src.(type) { 578 case string: 579 return dst.DecodeText(nil, []byte(src)) 580 case []byte: 581 srcCopy := make([]byte, len(src)) 582 copy(srcCopy, src) 583 return dst.DecodeText(nil, srcCopy) 584 } 585 586 return errors.Errorf("cannot scan %T", src) 587} 588 589// Value implements the database/sql/driver Valuer interface. 590func (src Numeric) Value() (driver.Value, error) { 591 switch src.Status { 592 case Present: 593 buf, err := src.EncodeText(nil, nil) 594 if err != nil { 595 return nil, err 596 } 597 598 return string(buf), nil 599 case Null: 600 return nil, nil 601 default: 602 return nil, errUndefined 603 } 604} 605