1package pgtype
2
3import (
4	"math"
5	"reflect"
6	"time"
7
8	errors "golang.org/x/xerrors"
9)
10
// Bounds of the platform-dependent uint and int types. They are derived from
// an all-ones uint, so they are correct for both 32- and 64-bit builds.
const (
	maxUint = ^uint(0)
	maxInt  = int(maxUint >> 1)
	minInt  = -maxInt - 1
)
14
// underlyingNumberType peels one layer off val on the way to a type that Int2,
// Int4, Int8, Float4 or Float8 can convert.
//
// A non-nil pointer is dereferenced and returned with true. A value whose kind
// is one of the built-in signed integer, unsigned integer, float or string
// kinds is converted to the predeclared type of that kind; the bool result is
// true only when val's dynamic type was a different (named) type, e.g.
// type Celsius float64. Anything else, including a nil pointer or an untyped
// nil, yields nil, false.
func underlyingNumberType(val interface{}) (interface{}, bool) {
	rv := reflect.ValueOf(val)

	var base interface{}
	switch rv.Kind() {
	case reflect.Ptr:
		if rv.IsNil() {
			return nil, false
		}
		return rv.Elem().Interface(), true
	case reflect.Int:
		base = int(rv.Int())
	case reflect.Int8:
		base = int8(rv.Int())
	case reflect.Int16:
		base = int16(rv.Int())
	case reflect.Int32:
		base = int32(rv.Int())
	case reflect.Int64:
		base = rv.Int()
	case reflect.Uint:
		base = uint(rv.Uint())
	case reflect.Uint8:
		base = uint8(rv.Uint())
	case reflect.Uint16:
		base = uint16(rv.Uint())
	case reflect.Uint32:
		base = uint32(rv.Uint())
	case reflect.Uint64:
		base = rv.Uint()
	case reflect.Float32:
		base = float32(rv.Float())
	case reflect.Float64:
		base = rv.Float()
	case reflect.String:
		base = rv.String()
	default:
		return nil, false
	}

	// Only a named type over the base kind counts as a change.
	return base, reflect.TypeOf(base) != rv.Type()
}
69
// underlyingBoolType peels one layer off val on the way to a type Bool can
// convert.
//
// A non-nil pointer is dereferenced and returned with true. A value of bool
// kind is converted to the predeclared bool, with true reported only when val
// was a named bool type. Anything else, including a nil pointer or an untyped
// nil, yields nil, false.
func underlyingBoolType(val interface{}) (interface{}, bool) {
	rv := reflect.ValueOf(val)

	if rv.Kind() == reflect.Ptr {
		if rv.IsNil() {
			return nil, false
		}
		return rv.Elem().Interface(), true
	}

	if rv.Kind() != reflect.Bool {
		return nil, false
	}

	b := rv.Bool()
	return b, reflect.TypeOf(b) != rv.Type()
}
88
// underlyingBytesType peels one layer off val on the way to a []byte.
//
// A non-nil pointer is dereferenced and returned with true. A slice whose
// element kind is uint8 is turned into a []byte via reflect.Value.Bytes, with
// true reported only when val's type was not already []byte. Anything else,
// including a nil pointer or an untyped nil, yields nil, false.
func underlyingBytesType(val interface{}) (interface{}, bool) {
	rv := reflect.ValueOf(val)

	if rv.Kind() == reflect.Ptr {
		if rv.IsNil() {
			return nil, false
		}
		return rv.Elem().Interface(), true
	}

	isByteSlice := rv.Kind() == reflect.Slice && rv.Type().Elem().Kind() == reflect.Uint8
	if !isByteSlice {
		return nil, false
	}

	raw := rv.Bytes()
	return raw, reflect.TypeOf(raw) != rv.Type()
}
109
// underlyingStringType peels one layer off val on the way to a type String can
// convert.
//
// A non-nil pointer is dereferenced and returned with true. A value of string
// kind is converted to the predeclared string, with true reported only when
// val was a named string type. Anything else, including a nil pointer or an
// untyped nil, yields nil, false.
func underlyingStringType(val interface{}) (interface{}, bool) {
	rv := reflect.ValueOf(val)

	if rv.Kind() == reflect.Ptr {
		if rv.IsNil() {
			return nil, false
		}
		return rv.Elem().Interface(), true
	}

	if rv.Kind() != reflect.String {
		return nil, false
	}

	s := rv.String()
	return s, reflect.TypeOf(s) != rv.Type()
}
128
// underlyingPtrType dereferences val when it is a non-nil pointer, returning
// the pointed-to value and true. Any other input, including a nil pointer or
// an untyped nil, yields nil, false.
func underlyingPtrType(val interface{}) (interface{}, bool) {
	rv := reflect.ValueOf(val)
	if rv.Kind() != reflect.Ptr || rv.IsNil() {
		return nil, false
	}
	return rv.Elem().Interface(), true
}
144
// underlyingTimeType peels one layer off val on the way to a time.Time.
//
// A non-nil pointer is dereferenced and returned with true. A value whose type
// is convertible to time.Time is converted and returned. The bool is true only
// when val's type differed from time.Time (e.g. type MyTime time.Time). This
// matches the other underlying*Type helpers and keeps callers that recurse on
// true from looping on a plain time.Time. Anything else, including a nil
// pointer or an untyped nil, yields nil, false.
func underlyingTimeType(val interface{}) (interface{}, bool) {
	refVal := reflect.ValueOf(val)

	switch refVal.Kind() {
	case reflect.Invalid:
		// untyped nil: there is nothing to convert, and refVal.Type() would panic
		return nil, false
	case reflect.Ptr:
		if refVal.IsNil() {
			return nil, false
		}
		return refVal.Elem().Interface(), true
	}

	timeType := reflect.TypeOf(time.Time{})
	if refVal.Type().ConvertibleTo(timeType) {
		return refVal.Convert(timeType).Interface(), refVal.Type() != timeType
	}

	return nil, false
}
165
// underlyingUUIDType peels one layer off val on the way to a [16]byte.
//
// A non-nil pointer is dereferenced and returned with true. A value whose type
// is convertible to [16]byte (e.g. type MyUUID [16]byte) is converted and
// returned. The bool is true only when val's type differed from [16]byte,
// matching the other underlying*Type helpers. Anything else, including a nil
// pointer or an untyped nil, yields nil, false.
func underlyingUUIDType(val interface{}) (interface{}, bool) {
	refVal := reflect.ValueOf(val)

	switch refVal.Kind() {
	case reflect.Invalid:
		// untyped nil: there is nothing to convert, and refVal.Type() would panic
		return nil, false
	case reflect.Ptr:
		if refVal.IsNil() {
			return nil, false
		}
		return refVal.Elem().Interface(), true
	}

	uuidType := reflect.TypeOf([16]byte{})
	if !refVal.Type().ConvertibleTo(uuidType) {
		return nil, false
	}

	// Since Go 1.20 a byte slice is also convertible to [16]byte. However,
	// Convert panics when the slice is shorter than 16 and silently drops
	// bytes when it is longer, so only exact-length slices are accepted.
	if refVal.Kind() == reflect.Slice && refVal.Len() != uuidType.Len() {
		return nil, false
	}

	return refVal.Convert(uuidType).Interface(), refVal.Type() != uuidType
}
186
// underlyingSliceType peels one layer off val on the way to an unnamed slice.
//
// A non-nil pointer is dereferenced and returned with true. A value of slice
// kind is converted to []E, where E is its element type. The bool is true only
// when val's type was a named slice type (e.g. type IDs []int32). Anything
// else, including a nil pointer or an untyped nil, yields nil, false.
func underlyingSliceType(val interface{}) (interface{}, bool) {
	rv := reflect.ValueOf(val)

	switch rv.Kind() {
	case reflect.Ptr:
		if rv.IsNil() {
			return nil, false
		}
		return rv.Elem().Interface(), true
	case reflect.Slice:
		plain := reflect.SliceOf(rv.Type().Elem())
		if !rv.Type().ConvertibleTo(plain) {
			return nil, false
		}
		out := rv.Convert(plain).Interface()
		return out, reflect.TypeOf(out) != rv.Type()
	default:
		return nil, false
	}
}
208
209func int64AssignTo(srcVal int64, srcStatus Status, dst interface{}) error {
210	if srcStatus == Present {
211		switch v := dst.(type) {
212		case *int:
213			if srcVal < int64(minInt) {
214				return errors.Errorf("%d is less than minimum value for int", srcVal)
215			} else if srcVal > int64(maxInt) {
216				return errors.Errorf("%d is greater than maximum value for int", srcVal)
217			}
218			*v = int(srcVal)
219		case *int8:
220			if srcVal < math.MinInt8 {
221				return errors.Errorf("%d is less than minimum value for int8", srcVal)
222			} else if srcVal > math.MaxInt8 {
223				return errors.Errorf("%d is greater than maximum value for int8", srcVal)
224			}
225			*v = int8(srcVal)
226		case *int16:
227			if srcVal < math.MinInt16 {
228				return errors.Errorf("%d is less than minimum value for int16", srcVal)
229			} else if srcVal > math.MaxInt16 {
230				return errors.Errorf("%d is greater than maximum value for int16", srcVal)
231			}
232			*v = int16(srcVal)
233		case *int32:
234			if srcVal < math.MinInt32 {
235				return errors.Errorf("%d is less than minimum value for int32", srcVal)
236			} else if srcVal > math.MaxInt32 {
237				return errors.Errorf("%d is greater than maximum value for int32", srcVal)
238			}
239			*v = int32(srcVal)
240		case *int64:
241			if srcVal < math.MinInt64 {
242				return errors.Errorf("%d is less than minimum value for int64", srcVal)
243			} else if srcVal > math.MaxInt64 {
244				return errors.Errorf("%d is greater than maximum value for int64", srcVal)
245			}
246			*v = int64(srcVal)
247		case *uint:
248			if srcVal < 0 {
249				return errors.Errorf("%d is less than zero for uint", srcVal)
250			} else if uint64(srcVal) > uint64(maxUint) {
251				return errors.Errorf("%d is greater than maximum value for uint", srcVal)
252			}
253			*v = uint(srcVal)
254		case *uint8:
255			if srcVal < 0 {
256				return errors.Errorf("%d is less than zero for uint8", srcVal)
257			} else if srcVal > math.MaxUint8 {
258				return errors.Errorf("%d is greater than maximum value for uint8", srcVal)
259			}
260			*v = uint8(srcVal)
261		case *uint16:
262			if srcVal < 0 {
263				return errors.Errorf("%d is less than zero for uint32", srcVal)
264			} else if srcVal > math.MaxUint16 {
265				return errors.Errorf("%d is greater than maximum value for uint16", srcVal)
266			}
267			*v = uint16(srcVal)
268		case *uint32:
269			if srcVal < 0 {
270				return errors.Errorf("%d is less than zero for uint32", srcVal)
271			} else if srcVal > math.MaxUint32 {
272				return errors.Errorf("%d is greater than maximum value for uint32", srcVal)
273			}
274			*v = uint32(srcVal)
275		case *uint64:
276			if srcVal < 0 {
277				return errors.Errorf("%d is less than zero for uint64", srcVal)
278			}
279			*v = uint64(srcVal)
280		default:
281			if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr {
282				el := v.Elem()
283				switch el.Kind() {
284				// if dst is a pointer to pointer, strip the pointer and try again
285				case reflect.Ptr:
286					if el.IsNil() {
287						// allocate destination
288						el.Set(reflect.New(el.Type().Elem()))
289					}
290					return int64AssignTo(srcVal, srcStatus, el.Interface())
291				case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
292					if el.OverflowInt(int64(srcVal)) {
293						return errors.Errorf("cannot put %d into %T", srcVal, dst)
294					}
295					el.SetInt(int64(srcVal))
296					return nil
297				case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
298					if srcVal < 0 {
299						return errors.Errorf("%d is less than zero for %T", srcVal, dst)
300					}
301					if el.OverflowUint(uint64(srcVal)) {
302						return errors.Errorf("cannot put %d into %T", srcVal, dst)
303					}
304					el.SetUint(uint64(srcVal))
305					return nil
306				}
307			}
308			return errors.Errorf("cannot assign %v into %T", srcVal, dst)
309		}
310		return nil
311	}
312
313	// if dst is a pointer to pointer and srcStatus is not Present, nil it out
314	if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr {
315		el := v.Elem()
316		if el.Kind() == reflect.Ptr {
317			el.Set(reflect.Zero(el.Type()))
318			return nil
319		}
320	}
321
322	return errors.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst)
323}
324
325func float64AssignTo(srcVal float64, srcStatus Status, dst interface{}) error {
326	if srcStatus == Present {
327		switch v := dst.(type) {
328		case *float32:
329			*v = float32(srcVal)
330		case *float64:
331			*v = srcVal
332		default:
333			if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr {
334				el := v.Elem()
335				switch el.Kind() {
336				// if dst is a pointer to pointer, strip the pointer and try again
337				case reflect.Ptr:
338					if el.IsNil() {
339						// allocate destination
340						el.Set(reflect.New(el.Type().Elem()))
341					}
342					return float64AssignTo(srcVal, srcStatus, el.Interface())
343				case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
344					i64 := int64(srcVal)
345					if float64(i64) == srcVal {
346						return int64AssignTo(i64, srcStatus, dst)
347					}
348				}
349			}
350			return errors.Errorf("cannot assign %v into %T", srcVal, dst)
351		}
352		return nil
353	}
354
355	// if dst is a pointer to pointer and srcStatus is not Present, nil it out
356	if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr {
357		el := v.Elem()
358		if el.Kind() == reflect.Ptr {
359			el.Set(reflect.Zero(el.Type()))
360			return nil
361		}
362	}
363
364	return errors.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst)
365}
366
367func NullAssignTo(dst interface{}) error {
368	dstPtr := reflect.ValueOf(dst)
369
370	// AssignTo dst must always be a pointer
371	if dstPtr.Kind() != reflect.Ptr {
372		return &nullAssignmentError{dst: dst}
373	}
374
375	dstVal := dstPtr.Elem()
376
377	switch dstVal.Kind() {
378	case reflect.Ptr, reflect.Slice, reflect.Map:
379		dstVal.Set(reflect.Zero(dstVal.Type()))
380		return nil
381	}
382
383	return &nullAssignmentError{dst: dst}
384}
385
// kindTypes maps each basic reflect.Kind to the predeclared Go type of that
// kind. GetAssignToDstType uses it to turn pointers to named types (such as
// *Foo where type Foo int16) into pointers to the predeclared type (*int16).
// It is initialized at declaration, so it is ready before any init function
// in the package runs.
var kindTypes = map[reflect.Kind]reflect.Type{
	reflect.Bool:    reflect.TypeOf(false),
	reflect.Float32: reflect.TypeOf(float32(0)),
	reflect.Float64: reflect.TypeOf(float64(0)),
	reflect.Int:     reflect.TypeOf(int(0)),
	reflect.Int8:    reflect.TypeOf(int8(0)),
	reflect.Int16:   reflect.TypeOf(int16(0)),
	reflect.Int32:   reflect.TypeOf(int32(0)),
	reflect.Int64:   reflect.TypeOf(int64(0)),
	reflect.Uint:    reflect.TypeOf(uint(0)),
	reflect.Uint8:   reflect.TypeOf(uint8(0)),
	reflect.Uint16:  reflect.TypeOf(uint16(0)),
	reflect.Uint32:  reflect.TypeOf(uint32(0)),
	reflect.Uint64:  reflect.TypeOf(uint64(0)),
	reflect.String:  reflect.TypeOf(""),
}

// GetAssignToDstType attempts to convert dst to something AssignTo can assign
// to. If dst is a pointer to pointer, it allocates a new value, stores it in
// the outer pointer, and returns the dereferenced pointer. If dst is a named
// type such as *Foo, where type Foo int16, it converts dst to *int16. Named
// slice and array types of basic elements (type Foos []int16) are converted
// the same way.
//
// Slices and arrays whose element type is itself a named type (e.g. []Foo)
// cannot be converted by a pointer conversion. For those, nil, false is
// returned so the caller can fall back to another strategy.
//
// GetAssignToDstType returns the converted dst and a bool representing if any
// change was made.
func GetAssignToDstType(dst interface{}) (interface{}, bool) {
	dstPtr := reflect.ValueOf(dst)

	// AssignTo dst must always be a pointer
	if dstPtr.Kind() != reflect.Ptr {
		return nil, false
	}

	dstVal := dstPtr.Elem()

	// if dst is a pointer to pointer, allocate space and try again with the dereferenced pointer
	if dstVal.Kind() == reflect.Ptr {
		dstVal.Set(reflect.New(dstVal.Type().Elem()))
		return dstVal.Interface(), true
	}

	// convert performs the pointer conversion only when Go's conversion rules
	// allow it. Otherwise reflect.Value.Convert would panic, e.g. for
	// *[]Foo -> *[]int16, whose base types have different underlying types.
	convert := func(target reflect.Type) (interface{}, bool) {
		if !dstPtr.Type().ConvertibleTo(target) {
			return nil, false
		}
		next := dstPtr.Convert(target)
		return next.Interface(), dstPtr.Type() != next.Type()
	}

	// if dst is pointer to a base type that has been renamed
	if baseValType, ok := kindTypes[dstVal.Kind()]; ok {
		return convert(reflect.PtrTo(baseValType))
	}

	switch dstVal.Kind() {
	case reflect.Slice:
		if baseElemType, ok := kindTypes[dstVal.Type().Elem().Kind()]; ok {
			return convert(reflect.PtrTo(reflect.SliceOf(baseElemType)))
		}
	case reflect.Array:
		if baseElemType, ok := kindTypes[dstVal.Type().Elem().Kind()]; ok {
			return convert(reflect.PtrTo(reflect.ArrayOf(dstVal.Len(), baseElemType)))
		}
	}

	return nil, false
}
454