1package pgtype 2 3import ( 4 "database/sql" 5 "fmt" 6 "math" 7 "reflect" 8 "time" 9) 10 11const ( 12 maxUint = ^uint(0) 13 maxInt = int(maxUint >> 1) 14 minInt = -maxInt - 1 15) 16 17// underlyingNumberType gets the underlying type that can be converted to Int2, Int4, Int8, Float4, or Float8 18func underlyingNumberType(val interface{}) (interface{}, bool) { 19 refVal := reflect.ValueOf(val) 20 21 switch refVal.Kind() { 22 case reflect.Ptr: 23 if refVal.IsNil() { 24 return nil, false 25 } 26 convVal := refVal.Elem().Interface() 27 return convVal, true 28 case reflect.Int: 29 convVal := int(refVal.Int()) 30 return convVal, reflect.TypeOf(convVal) != refVal.Type() 31 case reflect.Int8: 32 convVal := int8(refVal.Int()) 33 return convVal, reflect.TypeOf(convVal) != refVal.Type() 34 case reflect.Int16: 35 convVal := int16(refVal.Int()) 36 return convVal, reflect.TypeOf(convVal) != refVal.Type() 37 case reflect.Int32: 38 convVal := int32(refVal.Int()) 39 return convVal, reflect.TypeOf(convVal) != refVal.Type() 40 case reflect.Int64: 41 convVal := int64(refVal.Int()) 42 return convVal, reflect.TypeOf(convVal) != refVal.Type() 43 case reflect.Uint: 44 convVal := uint(refVal.Uint()) 45 return convVal, reflect.TypeOf(convVal) != refVal.Type() 46 case reflect.Uint8: 47 convVal := uint8(refVal.Uint()) 48 return convVal, reflect.TypeOf(convVal) != refVal.Type() 49 case reflect.Uint16: 50 convVal := uint16(refVal.Uint()) 51 return convVal, reflect.TypeOf(convVal) != refVal.Type() 52 case reflect.Uint32: 53 convVal := uint32(refVal.Uint()) 54 return convVal, reflect.TypeOf(convVal) != refVal.Type() 55 case reflect.Uint64: 56 convVal := uint64(refVal.Uint()) 57 return convVal, reflect.TypeOf(convVal) != refVal.Type() 58 case reflect.Float32: 59 convVal := float32(refVal.Float()) 60 return convVal, reflect.TypeOf(convVal) != refVal.Type() 61 case reflect.Float64: 62 convVal := refVal.Float() 63 return convVal, reflect.TypeOf(convVal) != refVal.Type() 64 case reflect.String: 65 convVal := refVal.String() 66 return convVal, reflect.TypeOf(convVal) != refVal.Type() 67 } 68 69 return nil, false 70} 71 72// underlyingBoolType gets the underlying type that can be converted to Bool 73func underlyingBoolType(val interface{}) (interface{}, bool) { 74 refVal := reflect.ValueOf(val) 75 76 switch refVal.Kind() { 77 case reflect.Ptr: 78 if refVal.IsNil() { 79 return nil, false 80 } 81 convVal := refVal.Elem().Interface() 82 return convVal, true 83 case reflect.Bool: 84 convVal := refVal.Bool() 85 return convVal, reflect.TypeOf(convVal) != refVal.Type() 86 } 87 88 return nil, false 89} 90 91// underlyingBytesType gets the underlying type that can be converted to []byte 92func underlyingBytesType(val interface{}) (interface{}, bool) { 93 refVal := reflect.ValueOf(val) 94 95 switch refVal.Kind() { 96 case reflect.Ptr: 97 if refVal.IsNil() { 98 return nil, false 99 } 100 convVal := refVal.Elem().Interface() 101 return convVal, true 102 case reflect.Slice: 103 if refVal.Type().Elem().Kind() == reflect.Uint8 { 104 convVal := refVal.Bytes() 105 return convVal, reflect.TypeOf(convVal) != refVal.Type() 106 } 107 } 108 109 return nil, false 110} 111 112// underlyingStringType gets the underlying type that can be converted to String 113func underlyingStringType(val interface{}) (interface{}, bool) { 114 refVal := reflect.ValueOf(val) 115 116 switch refVal.Kind() { 117 case reflect.Ptr: 118 if refVal.IsNil() { 119 return nil, false 120 } 121 convVal := refVal.Elem().Interface() 122 return convVal, true 123 case reflect.String: 124 convVal := refVal.String() 125 return convVal, reflect.TypeOf(convVal) != refVal.Type() 126 } 127 128 return nil, false 129} 130 131// underlyingPtrType dereferences a pointer 132func underlyingPtrType(val interface{}) (interface{}, bool) { 133 refVal := reflect.ValueOf(val) 134 135 switch refVal.Kind() { 136 case reflect.Ptr: 137 if refVal.IsNil() { 138 return nil, false 139 } 140 convVal := refVal.Elem().Interface() 141 return convVal, true 142 } 143 144 return nil, false 145} 146 147// underlyingTimeType gets the underlying type that can be converted to time.Time 148func underlyingTimeType(val interface{}) (interface{}, bool) { 149 refVal := reflect.ValueOf(val) 150 151 switch refVal.Kind() { 152 case reflect.Ptr: 153 if refVal.IsNil() { 154 return nil, false 155 } 156 convVal := refVal.Elem().Interface() 157 return convVal, true 158 } 159 160 timeType := reflect.TypeOf(time.Time{}) 161 if refVal.Type().ConvertibleTo(timeType) { 162 return refVal.Convert(timeType).Interface(), true 163 } 164 165 return nil, false 166} 167 168// underlyingUUIDType gets the underlying type that can be converted to [16]byte 169func underlyingUUIDType(val interface{}) (interface{}, bool) { 170 refVal := reflect.ValueOf(val) 171 172 switch refVal.Kind() { 173 case reflect.Ptr: 174 if refVal.IsNil() { 175 return time.Time{}, false 176 } 177 convVal := refVal.Elem().Interface() 178 return convVal, true 179 } 180 181 uuidType := reflect.TypeOf([16]byte{}) 182 if refVal.Type().ConvertibleTo(uuidType) { 183 return refVal.Convert(uuidType).Interface(), true 184 } 185 186 return nil, false 187} 188 189// underlyingSliceType gets the underlying slice type 190func underlyingSliceType(val interface{}) (interface{}, bool) { 191 refVal := reflect.ValueOf(val) 192 193 switch refVal.Kind() { 194 case reflect.Ptr: 195 if refVal.IsNil() { 196 return nil, false 197 } 198 convVal := refVal.Elem().Interface() 199 return convVal, true 200 case reflect.Slice: 201 baseSliceType := reflect.SliceOf(refVal.Type().Elem()) 202 if refVal.Type().ConvertibleTo(baseSliceType) { 203 convVal := refVal.Convert(baseSliceType) 204 return convVal.Interface(), reflect.TypeOf(convVal.Interface()) != refVal.Type() 205 } 206 } 207 208 return nil, false 209} 210 211func int64AssignTo(srcVal int64, srcStatus Status, dst interface{}) error { 212 if srcStatus == Present { 213 switch v := dst.(type) { 214 case *int: 215 if srcVal < int64(minInt) { 216 return fmt.Errorf("%d is less than minimum value for int", srcVal) 217 } else if srcVal > int64(maxInt) { 218 return fmt.Errorf("%d is greater than maximum value for int", srcVal) 219 } 220 *v = int(srcVal) 221 case *int8: 222 if srcVal < math.MinInt8 { 223 return fmt.Errorf("%d is less than minimum value for int8", srcVal) 224 } else if srcVal > math.MaxInt8 { 225 return fmt.Errorf("%d is greater than maximum value for int8", srcVal) 226 } 227 *v = int8(srcVal) 228 case *int16: 229 if srcVal < math.MinInt16 { 230 return fmt.Errorf("%d is less than minimum value for int16", srcVal) 231 } else if srcVal > math.MaxInt16 { 232 return fmt.Errorf("%d is greater than maximum value for int16", srcVal) 233 } 234 *v = int16(srcVal) 235 case *int32: 236 if srcVal < math.MinInt32 { 237 return fmt.Errorf("%d is less than minimum value for int32", srcVal) 238 } else if srcVal > math.MaxInt32 { 239 return fmt.Errorf("%d is greater than maximum value for int32", srcVal) 240 } 241 *v = int32(srcVal) 242 case *int64: 243 if srcVal < math.MinInt64 { 244 return fmt.Errorf("%d is less than minimum value for int64", srcVal) 245 } else if srcVal > math.MaxInt64 { 246 return fmt.Errorf("%d is greater than maximum value for int64", srcVal) 247 } 248 *v = int64(srcVal) 249 case *uint: 250 if srcVal < 0 { 251 return fmt.Errorf("%d is less than zero for uint", srcVal) 252 } else if uint64(srcVal) > uint64(maxUint) { 253 return fmt.Errorf("%d is greater than maximum value for uint", srcVal) 254 } 255 *v = uint(srcVal) 256 case *uint8: 257 if srcVal < 0 { 258 return fmt.Errorf("%d is less than zero for uint8", srcVal) 259 } else if srcVal > math.MaxUint8 { 260 return fmt.Errorf("%d is greater than maximum value for uint8", srcVal) 261 } 262 *v = uint8(srcVal) 263 case *uint16: 264 if srcVal < 0 { 265 return fmt.Errorf("%d is less than zero for uint32", srcVal) 266 } else if srcVal > math.MaxUint16 { 267 return fmt.Errorf("%d is greater than maximum value for uint16", srcVal) 268 } 269 *v = uint16(srcVal) 270 case *uint32: 271 if srcVal < 0 { 272 return fmt.Errorf("%d is less than zero for uint32", srcVal) 273 } else if srcVal > math.MaxUint32 { 274 return fmt.Errorf("%d is greater than maximum value for uint32", srcVal) 275 } 276 *v = uint32(srcVal) 277 case *uint64: 278 if srcVal < 0 { 279 return fmt.Errorf("%d is less than zero for uint64", srcVal) 280 } 281 *v = uint64(srcVal) 282 case sql.Scanner: 283 return v.Scan(srcVal) 284 default: 285 if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { 286 el := v.Elem() 287 switch el.Kind() { 288 // if dst is a pointer to pointer, strip the pointer and try again 289 case reflect.Ptr: 290 if el.IsNil() { 291 // allocate destination 292 el.Set(reflect.New(el.Type().Elem())) 293 } 294 return int64AssignTo(srcVal, srcStatus, el.Interface()) 295 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 296 if el.OverflowInt(int64(srcVal)) { 297 return fmt.Errorf("cannot put %d into %T", srcVal, dst) 298 } 299 el.SetInt(int64(srcVal)) 300 return nil 301 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 302 if srcVal < 0 { 303 return fmt.Errorf("%d is less than zero for %T", srcVal, dst) 304 } 305 if el.OverflowUint(uint64(srcVal)) { 306 return fmt.Errorf("cannot put %d into %T", srcVal, dst) 307 } 308 el.SetUint(uint64(srcVal)) 309 return nil 310 } 311 } 312 return fmt.Errorf("cannot assign %v into %T", srcVal, dst) 313 } 314 return nil 315 } 316 317 // if dst is a pointer to pointer and srcStatus is not Present, nil it out 318 if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { 319 el := v.Elem() 320 if el.Kind() == reflect.Ptr { 321 el.Set(reflect.Zero(el.Type())) 322 return nil 323 } 324 } 325 326 return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst) 327} 328 329func float64AssignTo(srcVal float64, srcStatus Status, dst interface{}) error { 330 if srcStatus == Present { 331 switch v := dst.(type) { 332 case *float32: 333 *v = float32(srcVal) 334 case *float64: 335 *v = srcVal 336 default: 337 if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { 338 el := v.Elem() 339 switch el.Kind() { 340 // if dst is a pointer to pointer, strip the pointer and try again 341 case reflect.Ptr: 342 if el.IsNil() { 343 // allocate destination 344 el.Set(reflect.New(el.Type().Elem())) 345 } 346 return float64AssignTo(srcVal, srcStatus, el.Interface()) 347 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 348 i64 := int64(srcVal) 349 if float64(i64) == srcVal { 350 return int64AssignTo(i64, srcStatus, dst) 351 } 352 } 353 } 354 return fmt.Errorf("cannot assign %v into %T", srcVal, dst) 355 } 356 return nil 357 } 358 359 // if dst is a pointer to pointer and srcStatus is not Present, nil it out 360 if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { 361 el := v.Elem() 362 if el.Kind() == reflect.Ptr { 363 el.Set(reflect.Zero(el.Type())) 364 return nil 365 } 366 } 367 368 return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst) 369} 370 371func NullAssignTo(dst interface{}) error { 372 dstPtr := reflect.ValueOf(dst) 373 374 // AssignTo dst must always be a pointer 375 if dstPtr.Kind() != reflect.Ptr { 376 return &nullAssignmentError{dst: dst} 377 } 378 379 dstVal := dstPtr.Elem() 380 381 switch dstVal.Kind() { 382 case reflect.Ptr, reflect.Slice, reflect.Map: 383 dstVal.Set(reflect.Zero(dstVal.Type())) 384 return nil 385 } 386 387 return &nullAssignmentError{dst: dst} 388} 389 390var kindTypes map[reflect.Kind]reflect.Type 391 392func toInterface(dst reflect.Value, t reflect.Type) (interface{}, bool) { 393 nextDst := dst.Convert(t) 394 return nextDst.Interface(), dst.Type() != nextDst.Type() 395} 396 397// GetAssignToDstType attempts to convert dst to something AssignTo can assign 398// to. If dst is a pointer to pointer it allocates a value and returns the 399// dereferences pointer. If dst is a named type such as *Foo where Foo is type 400// Foo int16, it converts dst to *int16. 401// 402// GetAssignToDstType returns the converted dst and a bool representing if any 403// change was made. 404func GetAssignToDstType(dst interface{}) (interface{}, bool) { 405 dstPtr := reflect.ValueOf(dst) 406 407 // AssignTo dst must always be a pointer 408 if dstPtr.Kind() != reflect.Ptr { 409 return nil, false 410 } 411 412 dstVal := dstPtr.Elem() 413 414 // if dst is a pointer to pointer, allocate space try again with the dereferenced pointer 415 if dstVal.Kind() == reflect.Ptr { 416 dstVal.Set(reflect.New(dstVal.Type().Elem())) 417 return dstVal.Interface(), true 418 } 419 420 // if dst is pointer to a base type that has been renamed 421 if baseValType, ok := kindTypes[dstVal.Kind()]; ok { 422 return toInterface(dstPtr, reflect.PtrTo(baseValType)) 423 } 424 425 if dstVal.Kind() == reflect.Slice { 426 if baseElemType, ok := kindTypes[dstVal.Type().Elem().Kind()]; ok { 427 return toInterface(dstPtr, reflect.PtrTo(reflect.SliceOf(baseElemType))) 428 } 429 } 430 431 if dstVal.Kind() == reflect.Array { 432 if baseElemType, ok := kindTypes[dstVal.Type().Elem().Kind()]; ok { 433 return toInterface(dstPtr, reflect.PtrTo(reflect.ArrayOf(dstVal.Len(), baseElemType))) 434 } 435 } 436 437 if dstVal.Kind() == reflect.Struct { 438 if dstVal.Type().NumField() == 1 && dstVal.Type().Field(0).Anonymous { 439 dstPtr = dstVal.Field(0).Addr() 440 nested := dstVal.Type().Field(0).Type 441 if nested.Kind() == reflect.Array { 442 if baseElemType, ok := kindTypes[nested.Elem().Kind()]; ok { 443 return toInterface(dstPtr, reflect.PtrTo(reflect.ArrayOf(nested.Len(), baseElemType))) 444 } 445 } 446 if _, ok := kindTypes[nested.Kind()]; ok && dstPtr.CanInterface() { 447 return dstPtr.Interface(), true 448 } 449 } 450 } 451 452 return nil, false 453} 454 455func init() { 456 kindTypes = map[reflect.Kind]reflect.Type{ 457 reflect.Bool: reflect.TypeOf(false), 458 reflect.Float32: reflect.TypeOf(float32(0)), 459 reflect.Float64: reflect.TypeOf(float64(0)), 460 reflect.Int: reflect.TypeOf(int(0)), 461 reflect.Int8: reflect.TypeOf(int8(0)), 462 reflect.Int16: reflect.TypeOf(int16(0)), 463 reflect.Int32: reflect.TypeOf(int32(0)), 464 reflect.Int64: reflect.TypeOf(int64(0)), 465 reflect.Uint: reflect.TypeOf(uint(0)), 466 reflect.Uint8: reflect.TypeOf(uint8(0)), 467 reflect.Uint16: reflect.TypeOf(uint16(0)), 468 reflect.Uint32: reflect.TypeOf(uint32(0)), 469 reflect.Uint64: reflect.TypeOf(uint64(0)), 470 reflect.String: reflect.TypeOf(""), 471 } 472} 473