1package pgtype 2 3import ( 4 "math" 5 "reflect" 6 "time" 7 8 errors "golang.org/x/xerrors" 9) 10 11const maxUint = ^uint(0) 12const maxInt = int(maxUint >> 1) 13const minInt = -maxInt - 1 14 15// underlyingNumberType gets the underlying type that can be converted to Int2, Int4, Int8, Float4, or Float8 16func underlyingNumberType(val interface{}) (interface{}, bool) { 17 refVal := reflect.ValueOf(val) 18 19 switch refVal.Kind() { 20 case reflect.Ptr: 21 if refVal.IsNil() { 22 return nil, false 23 } 24 convVal := refVal.Elem().Interface() 25 return convVal, true 26 case reflect.Int: 27 convVal := int(refVal.Int()) 28 return convVal, reflect.TypeOf(convVal) != refVal.Type() 29 case reflect.Int8: 30 convVal := int8(refVal.Int()) 31 return convVal, reflect.TypeOf(convVal) != refVal.Type() 32 case reflect.Int16: 33 convVal := int16(refVal.Int()) 34 return convVal, reflect.TypeOf(convVal) != refVal.Type() 35 case reflect.Int32: 36 convVal := int32(refVal.Int()) 37 return convVal, reflect.TypeOf(convVal) != refVal.Type() 38 case reflect.Int64: 39 convVal := int64(refVal.Int()) 40 return convVal, reflect.TypeOf(convVal) != refVal.Type() 41 case reflect.Uint: 42 convVal := uint(refVal.Uint()) 43 return convVal, reflect.TypeOf(convVal) != refVal.Type() 44 case reflect.Uint8: 45 convVal := uint8(refVal.Uint()) 46 return convVal, reflect.TypeOf(convVal) != refVal.Type() 47 case reflect.Uint16: 48 convVal := uint16(refVal.Uint()) 49 return convVal, reflect.TypeOf(convVal) != refVal.Type() 50 case reflect.Uint32: 51 convVal := uint32(refVal.Uint()) 52 return convVal, reflect.TypeOf(convVal) != refVal.Type() 53 case reflect.Uint64: 54 convVal := uint64(refVal.Uint()) 55 return convVal, reflect.TypeOf(convVal) != refVal.Type() 56 case reflect.Float32: 57 convVal := float32(refVal.Float()) 58 return convVal, reflect.TypeOf(convVal) != refVal.Type() 59 case reflect.Float64: 60 convVal := refVal.Float() 61 return convVal, reflect.TypeOf(convVal) != refVal.Type() 62 case reflect.String: 63 convVal := refVal.String() 64 return convVal, reflect.TypeOf(convVal) != refVal.Type() 65 } 66 67 return nil, false 68} 69 70// underlyingBoolType gets the underlying type that can be converted to Bool 71func underlyingBoolType(val interface{}) (interface{}, bool) { 72 refVal := reflect.ValueOf(val) 73 74 switch refVal.Kind() { 75 case reflect.Ptr: 76 if refVal.IsNil() { 77 return nil, false 78 } 79 convVal := refVal.Elem().Interface() 80 return convVal, true 81 case reflect.Bool: 82 convVal := refVal.Bool() 83 return convVal, reflect.TypeOf(convVal) != refVal.Type() 84 } 85 86 return nil, false 87} 88 89// underlyingBytesType gets the underlying type that can be converted to []byte 90func underlyingBytesType(val interface{}) (interface{}, bool) { 91 refVal := reflect.ValueOf(val) 92 93 switch refVal.Kind() { 94 case reflect.Ptr: 95 if refVal.IsNil() { 96 return nil, false 97 } 98 convVal := refVal.Elem().Interface() 99 return convVal, true 100 case reflect.Slice: 101 if refVal.Type().Elem().Kind() == reflect.Uint8 { 102 convVal := refVal.Bytes() 103 return convVal, reflect.TypeOf(convVal) != refVal.Type() 104 } 105 } 106 107 return nil, false 108} 109 110// underlyingStringType gets the underlying type that can be converted to String 111func underlyingStringType(val interface{}) (interface{}, bool) { 112 refVal := reflect.ValueOf(val) 113 114 switch refVal.Kind() { 115 case reflect.Ptr: 116 if refVal.IsNil() { 117 return nil, false 118 } 119 convVal := refVal.Elem().Interface() 120 return convVal, true 121 case reflect.String: 122 convVal := refVal.String() 123 return convVal, reflect.TypeOf(convVal) != refVal.Type() 124 } 125 126 return nil, false 127} 128 129// underlyingPtrType dereferences a pointer 130func underlyingPtrType(val interface{}) (interface{}, bool) { 131 refVal := reflect.ValueOf(val) 132 133 switch refVal.Kind() { 134 case reflect.Ptr: 135 if refVal.IsNil() { 136 return nil, false 137 } 138 convVal := refVal.Elem().Interface() 139 return convVal, true 140 } 141 142 return nil, false 143} 144 145// underlyingTimeType gets the underlying type that can be converted to time.Time 146func underlyingTimeType(val interface{}) (interface{}, bool) { 147 refVal := reflect.ValueOf(val) 148 149 switch refVal.Kind() { 150 case reflect.Ptr: 151 if refVal.IsNil() { 152 return nil, false 153 } 154 convVal := refVal.Elem().Interface() 155 return convVal, true 156 } 157 158 timeType := reflect.TypeOf(time.Time{}) 159 if refVal.Type().ConvertibleTo(timeType) { 160 return refVal.Convert(timeType).Interface(), true 161 } 162 163 return nil, false 164} 165 166// underlyingUUIDType gets the underlying type that can be converted to [16]byte 167func underlyingUUIDType(val interface{}) (interface{}, bool) { 168 refVal := reflect.ValueOf(val) 169 170 switch refVal.Kind() { 171 case reflect.Ptr: 172 if refVal.IsNil() { 173 return time.Time{}, false 174 } 175 convVal := refVal.Elem().Interface() 176 return convVal, true 177 } 178 179 uuidType := reflect.TypeOf([16]byte{}) 180 if refVal.Type().ConvertibleTo(uuidType) { 181 return refVal.Convert(uuidType).Interface(), true 182 } 183 184 return nil, false 185} 186 187// underlyingSliceType gets the underlying slice type 188func underlyingSliceType(val interface{}) (interface{}, bool) { 189 refVal := reflect.ValueOf(val) 190 191 switch refVal.Kind() { 192 case reflect.Ptr: 193 if refVal.IsNil() { 194 return nil, false 195 } 196 convVal := refVal.Elem().Interface() 197 return convVal, true 198 case reflect.Slice: 199 baseSliceType := reflect.SliceOf(refVal.Type().Elem()) 200 if refVal.Type().ConvertibleTo(baseSliceType) { 201 convVal := refVal.Convert(baseSliceType) 202 return convVal.Interface(), reflect.TypeOf(convVal.Interface()) != refVal.Type() 203 } 204 } 205 206 return nil, false 207} 208 209func int64AssignTo(srcVal int64, srcStatus Status, dst interface{}) error { 210 if srcStatus == Present { 211 switch v := dst.(type) { 212 case *int: 213 if srcVal < int64(minInt) { 214 return errors.Errorf("%d is less than minimum value for int", srcVal) 215 } else if srcVal > int64(maxInt) { 216 return errors.Errorf("%d is greater than maximum value for int", srcVal) 217 } 218 *v = int(srcVal) 219 case *int8: 220 if srcVal < math.MinInt8 { 221 return errors.Errorf("%d is less than minimum value for int8", srcVal) 222 } else if srcVal > math.MaxInt8 { 223 return errors.Errorf("%d is greater than maximum value for int8", srcVal) 224 } 225 *v = int8(srcVal) 226 case *int16: 227 if srcVal < math.MinInt16 { 228 return errors.Errorf("%d is less than minimum value for int16", srcVal) 229 } else if srcVal > math.MaxInt16 { 230 return errors.Errorf("%d is greater than maximum value for int16", srcVal) 231 } 232 *v = int16(srcVal) 233 case *int32: 234 if srcVal < math.MinInt32 { 235 return errors.Errorf("%d is less than minimum value for int32", srcVal) 236 } else if srcVal > math.MaxInt32 { 237 return errors.Errorf("%d is greater than maximum value for int32", srcVal) 238 } 239 *v = int32(srcVal) 240 case *int64: 241 if srcVal < math.MinInt64 { 242 return errors.Errorf("%d is less than minimum value for int64", srcVal) 243 } else if srcVal > math.MaxInt64 { 244 return errors.Errorf("%d is greater than maximum value for int64", srcVal) 245 } 246 *v = int64(srcVal) 247 case *uint: 248 if srcVal < 0 { 249 return errors.Errorf("%d is less than zero for uint", srcVal) 250 } else if uint64(srcVal) > uint64(maxUint) { 251 return errors.Errorf("%d is greater than maximum value for uint", srcVal) 252 } 253 *v = uint(srcVal) 254 case *uint8: 255 if srcVal < 0 { 256 return errors.Errorf("%d is less than zero for uint8", srcVal) 257 } else if srcVal > math.MaxUint8 { 258 return errors.Errorf("%d is greater than maximum value for uint8", srcVal) 259 } 260 *v = uint8(srcVal) 261 case *uint16: 262 if srcVal < 0 { 263 return errors.Errorf("%d is less than zero for uint32", srcVal) 264 } else if srcVal > math.MaxUint16 { 265 return errors.Errorf("%d is greater than maximum value for uint16", srcVal) 266 } 267 *v = uint16(srcVal) 268 case *uint32: 269 if srcVal < 0 { 270 return errors.Errorf("%d is less than zero for uint32", srcVal) 271 } else if srcVal > math.MaxUint32 { 272 return errors.Errorf("%d is greater than maximum value for uint32", srcVal) 273 } 274 *v = uint32(srcVal) 275 case *uint64: 276 if srcVal < 0 { 277 return errors.Errorf("%d is less than zero for uint64", srcVal) 278 } 279 *v = uint64(srcVal) 280 default: 281 if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { 282 el := v.Elem() 283 switch el.Kind() { 284 // if dst is a pointer to pointer, strip the pointer and try again 285 case reflect.Ptr: 286 if el.IsNil() { 287 // allocate destination 288 el.Set(reflect.New(el.Type().Elem())) 289 } 290 return int64AssignTo(srcVal, srcStatus, el.Interface()) 291 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 292 if el.OverflowInt(int64(srcVal)) { 293 return errors.Errorf("cannot put %d into %T", srcVal, dst) 294 } 295 el.SetInt(int64(srcVal)) 296 return nil 297 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 298 if srcVal < 0 { 299 return errors.Errorf("%d is less than zero for %T", srcVal, dst) 300 } 301 if el.OverflowUint(uint64(srcVal)) { 302 return errors.Errorf("cannot put %d into %T", srcVal, dst) 303 } 304 el.SetUint(uint64(srcVal)) 305 return nil 306 } 307 } 308 return errors.Errorf("cannot assign %v into %T", srcVal, dst) 309 } 310 return nil 311 } 312 313 // if dst is a pointer to pointer and srcStatus is not Present, nil it out 314 if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { 315 el := v.Elem() 316 if el.Kind() == reflect.Ptr { 317 el.Set(reflect.Zero(el.Type())) 318 return nil 319 } 320 } 321 322 return errors.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst) 323} 324 325func float64AssignTo(srcVal float64, srcStatus Status, dst interface{}) error { 326 if srcStatus == Present { 327 switch v := dst.(type) { 328 case *float32: 329 *v = float32(srcVal) 330 case *float64: 331 *v = srcVal 332 default: 333 if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { 334 el := v.Elem() 335 switch el.Kind() { 336 // if dst is a pointer to pointer, strip the pointer and try again 337 case reflect.Ptr: 338 if el.IsNil() { 339 // allocate destination 340 el.Set(reflect.New(el.Type().Elem())) 341 } 342 return float64AssignTo(srcVal, srcStatus, el.Interface()) 343 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 344 i64 := int64(srcVal) 345 if float64(i64) == srcVal { 346 return int64AssignTo(i64, srcStatus, dst) 347 } 348 } 349 } 350 return errors.Errorf("cannot assign %v into %T", srcVal, dst) 351 } 352 return nil 353 } 354 355 // if dst is a pointer to pointer and srcStatus is not Present, nil it out 356 if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { 357 el := v.Elem() 358 if el.Kind() == reflect.Ptr { 359 el.Set(reflect.Zero(el.Type())) 360 return nil 361 } 362 } 363 364 return errors.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst) 365} 366 367func NullAssignTo(dst interface{}) error { 368 dstPtr := reflect.ValueOf(dst) 369 370 // AssignTo dst must always be a pointer 371 if dstPtr.Kind() != reflect.Ptr { 372 return &nullAssignmentError{dst: dst} 373 } 374 375 dstVal := dstPtr.Elem() 376 377 switch dstVal.Kind() { 378 case reflect.Ptr, reflect.Slice, reflect.Map: 379 dstVal.Set(reflect.Zero(dstVal.Type())) 380 return nil 381 } 382 383 return &nullAssignmentError{dst: dst} 384} 385 386var kindTypes map[reflect.Kind]reflect.Type 387 388// GetAssignToDstType attempts to convert dst to something AssignTo can assign 389// to. If dst is a pointer to pointer it allocates a value and returns the 390// dereferences pointer. If dst is a named type such as *Foo where Foo is type 391// Foo int16, it converts dst to *int16. 392// 393// GetAssignToDstType returns the converted dst and a bool representing if any 394// change was made. 395func GetAssignToDstType(dst interface{}) (interface{}, bool) { 396 dstPtr := reflect.ValueOf(dst) 397 398 // AssignTo dst must always be a pointer 399 if dstPtr.Kind() != reflect.Ptr { 400 return nil, false 401 } 402 403 dstVal := dstPtr.Elem() 404 405 // if dst is a pointer to pointer, allocate space try again with the dereferenced pointer 406 if dstVal.Kind() == reflect.Ptr { 407 dstVal.Set(reflect.New(dstVal.Type().Elem())) 408 return dstVal.Interface(), true 409 } 410 411 // if dst is pointer to a base type that has been renamed 412 if baseValType, ok := kindTypes[dstVal.Kind()]; ok { 413 nextDst := dstPtr.Convert(reflect.PtrTo(baseValType)) 414 return nextDst.Interface(), dstPtr.Type() != nextDst.Type() 415 } 416 417 if dstVal.Kind() == reflect.Slice { 418 if baseElemType, ok := kindTypes[dstVal.Type().Elem().Kind()]; ok { 419 baseSliceType := reflect.PtrTo(reflect.SliceOf(baseElemType)) 420 nextDst := dstPtr.Convert(baseSliceType) 421 return nextDst.Interface(), dstPtr.Type() != nextDst.Type() 422 } 423 } 424 425 if dstVal.Kind() == reflect.Array { 426 if baseElemType, ok := kindTypes[dstVal.Type().Elem().Kind()]; ok { 427 baseArrayType := reflect.PtrTo(reflect.ArrayOf(dstVal.Len(), baseElemType)) 428 nextDst := dstPtr.Convert(baseArrayType) 429 return nextDst.Interface(), dstPtr.Type() != nextDst.Type() 430 } 431 } 432 433 return nil, false 434} 435 436func init() { 437 kindTypes = map[reflect.Kind]reflect.Type{ 438 reflect.Bool: reflect.TypeOf(false), 439 reflect.Float32: reflect.TypeOf(float32(0)), 440 reflect.Float64: reflect.TypeOf(float64(0)), 441 reflect.Int: reflect.TypeOf(int(0)), 442 reflect.Int8: reflect.TypeOf(int8(0)), 443 reflect.Int16: reflect.TypeOf(int16(0)), 444 reflect.Int32: reflect.TypeOf(int32(0)), 445 reflect.Int64: reflect.TypeOf(int64(0)), 446 reflect.Uint: reflect.TypeOf(uint(0)), 447 reflect.Uint8: reflect.TypeOf(uint8(0)), 448 reflect.Uint16: reflect.TypeOf(uint16(0)), 449 reflect.Uint32: reflect.TypeOf(uint32(0)), 450 reflect.Uint64: reflect.TypeOf(uint64(0)), 451 reflect.String: reflect.TypeOf(""), 452 } 453} 454