1package pgtype
2
3import (
4	"database/sql"
5	"fmt"
6	"math"
7	"reflect"
8	"time"
9)
10
// Bounds of the platform-sized int and uint types. They are derived from the
// all-ones uint bit pattern rather than hard-coded, so the values follow the
// word size of the build target.
const (
	// maxUint has every bit set: the largest value a uint can hold.
	maxUint = ^uint(0)
	// maxInt drops the top bit of maxUint: the largest value an int can hold.
	maxInt  = int(maxUint >> 1)
	// minInt is the smallest value an int can hold.
	minInt  = -maxInt - 1
)
16
// underlyingNumberType unwraps val into a value that can be converted to Int2,
// Int4, Int8, Float4, or Float8.
//
// A non-nil pointer is dereferenced one level and reported with true. A value
// whose kind is a signed integer, unsigned integer, float, or string is
// returned as the matching built-in type; the bool is true only when val's own
// type differs from that built-in type (i.e. val was a named type). Nil
// pointers and every other kind yield (nil, false).
func underlyingNumberType(val interface{}) (interface{}, bool) {
	rv := reflect.ValueOf(val)

	var base interface{}
	switch rv.Kind() {
	case reflect.Ptr:
		if rv.IsNil() {
			return nil, false
		}
		return rv.Elem().Interface(), true
	case reflect.Int:
		base = int(rv.Int())
	case reflect.Int8:
		base = int8(rv.Int())
	case reflect.Int16:
		base = int16(rv.Int())
	case reflect.Int32:
		base = int32(rv.Int())
	case reflect.Int64:
		base = rv.Int()
	case reflect.Uint:
		base = uint(rv.Uint())
	case reflect.Uint8:
		base = uint8(rv.Uint())
	case reflect.Uint16:
		base = uint16(rv.Uint())
	case reflect.Uint32:
		base = uint32(rv.Uint())
	case reflect.Uint64:
		base = rv.Uint()
	case reflect.Float32:
		base = float32(rv.Float())
	case reflect.Float64:
		base = rv.Float()
	case reflect.String:
		base = rv.String()
	default:
		return nil, false
	}

	// Only report a change when val was a named type over the built-in one.
	return base, rv.Type() != reflect.TypeOf(base)
}
71
// underlyingBoolType unwraps val into a value that can be converted to Bool.
//
// A non-nil pointer is dereferenced one level and reported with true. A value
// of bool kind is returned as a plain bool; the flag is true only when val's
// type is a named type rather than bool itself. Anything else, including a nil
// pointer, yields (nil, false).
func underlyingBoolType(val interface{}) (interface{}, bool) {
	rv := reflect.ValueOf(val)

	if rv.Kind() == reflect.Ptr {
		if rv.IsNil() {
			return nil, false
		}
		return rv.Elem().Interface(), true
	}

	if rv.Kind() != reflect.Bool {
		return nil, false
	}

	plain := rv.Bool()
	return plain, rv.Type() != reflect.TypeOf(plain)
}
90
// underlyingBytesType unwraps val into a value that can be converted to []byte.
//
// A non-nil pointer is dereferenced one level and reported with true. A slice
// whose element kind is uint8 is returned as a plain []byte; the flag is true
// only when val's type is not []byte itself. Anything else, including a nil
// pointer, yields (nil, false).
func underlyingBytesType(val interface{}) (interface{}, bool) {
	rv := reflect.ValueOf(val)
	kind := rv.Kind()

	if kind == reflect.Ptr {
		if rv.IsNil() {
			return nil, false
		}
		return rv.Elem().Interface(), true
	}

	if kind == reflect.Slice && rv.Type().Elem().Kind() == reflect.Uint8 {
		raw := rv.Bytes()
		return raw, rv.Type() != reflect.TypeOf(raw)
	}

	return nil, false
}
111
// underlyingStringType unwraps val into a value that can be converted to
// String.
//
// A non-nil pointer is dereferenced one level and reported with true. A value
// of string kind is returned as a plain string; the flag is true only when
// val's type is a named type rather than string itself. Anything else,
// including a nil pointer, yields (nil, false).
func underlyingStringType(val interface{}) (interface{}, bool) {
	rv := reflect.ValueOf(val)

	if rv.Kind() == reflect.Ptr {
		if rv.IsNil() {
			return nil, false
		}
		return rv.Elem().Interface(), true
	}

	if rv.Kind() != reflect.String {
		return nil, false
	}

	plain := rv.String()
	return plain, rv.Type() != reflect.TypeOf(plain)
}
130
// underlyingPtrType dereferences val one level when it is a non-nil pointer.
//
// It returns the pointed-to value and true on success, and (nil, false) when
// val is a nil pointer or is not a pointer at all.
func underlyingPtrType(val interface{}) (interface{}, bool) {
	rv := reflect.ValueOf(val)

	if rv.Kind() != reflect.Ptr || rv.IsNil() {
		return nil, false
	}

	return rv.Elem().Interface(), true
}
146
// underlyingTimeType unwraps val into a value that can be converted to
// time.Time.
//
// A non-nil pointer is dereferenced one level and reported with true. Any
// other value whose type is convertible to time.Time is converted and reported
// with true (this includes a value that already is a time.Time). An untyped
// nil, a nil pointer, or a value of any other type yields (nil, false).
func underlyingTimeType(val interface{}) (interface{}, bool) {
	refVal := reflect.ValueOf(val)

	switch refVal.Kind() {
	case reflect.Invalid:
		// val is an untyped nil: there is no type to inspect, and calling
		// Type on the zero reflect.Value below would panic.
		return nil, false
	case reflect.Ptr:
		if refVal.IsNil() {
			return nil, false
		}
		convVal := refVal.Elem().Interface()
		return convVal, true
	}

	timeType := reflect.TypeOf(time.Time{})
	if refVal.Type().ConvertibleTo(timeType) {
		return refVal.Convert(timeType).Interface(), true
	}

	return nil, false
}
167
// underlyingUUIDType unwraps val into a value that can be converted to
// [16]byte.
//
// A non-nil pointer is dereferenced one level and reported with true. Any
// other value whose type is convertible to [16]byte is converted and reported
// with true. An untyped nil, a nil pointer, a slice holding fewer than 16
// elements, or a value of any other type yields (nil, false).
func underlyingUUIDType(val interface{}) (interface{}, bool) {
	refVal := reflect.ValueOf(val)

	switch refVal.Kind() {
	case reflect.Invalid:
		// val is an untyped nil: calling Type on the zero reflect.Value
		// below would panic.
		return nil, false
	case reflect.Ptr:
		if refVal.IsNil() {
			return nil, false
		}
		convVal := refVal.Elem().Interface()
		return convVal, true
	}

	uuidType := reflect.TypeOf([16]byte{})

	// Toolchains that allow slice-to-array conversion report the type as
	// convertible regardless of length, but Convert panics when the slice is
	// shorter than the array. Decline instead of panicking.
	if refVal.Kind() == reflect.Slice && refVal.Len() < uuidType.Len() {
		return nil, false
	}

	if refVal.Type().ConvertibleTo(uuidType) {
		return refVal.Convert(uuidType).Interface(), true
	}

	return nil, false
}
188
// underlyingSliceType unwraps val into its unnamed slice form.
//
// A non-nil pointer is dereferenced one level and reported with true. A slice
// is converted to the unnamed slice type over the same element type; the flag
// is true only when that differs from val's own type (i.e. val was a named
// slice type). Anything else, including a nil pointer, yields (nil, false).
func underlyingSliceType(val interface{}) (interface{}, bool) {
	rv := reflect.ValueOf(val)

	if rv.Kind() == reflect.Ptr {
		if rv.IsNil() {
			return nil, false
		}
		return rv.Elem().Interface(), true
	}

	if rv.Kind() != reflect.Slice {
		return nil, false
	}

	srcType := rv.Type()
	plainType := reflect.SliceOf(srcType.Elem())
	if !srcType.ConvertibleTo(plainType) {
		return nil, false
	}

	plain := rv.Convert(plainType).Interface()
	return plain, reflect.TypeOf(plain) != srcType
}
210
211func int64AssignTo(srcVal int64, srcStatus Status, dst interface{}) error {
212	if srcStatus == Present {
213		switch v := dst.(type) {
214		case *int:
215			if srcVal < int64(minInt) {
216				return fmt.Errorf("%d is less than minimum value for int", srcVal)
217			} else if srcVal > int64(maxInt) {
218				return fmt.Errorf("%d is greater than maximum value for int", srcVal)
219			}
220			*v = int(srcVal)
221		case *int8:
222			if srcVal < math.MinInt8 {
223				return fmt.Errorf("%d is less than minimum value for int8", srcVal)
224			} else if srcVal > math.MaxInt8 {
225				return fmt.Errorf("%d is greater than maximum value for int8", srcVal)
226			}
227			*v = int8(srcVal)
228		case *int16:
229			if srcVal < math.MinInt16 {
230				return fmt.Errorf("%d is less than minimum value for int16", srcVal)
231			} else if srcVal > math.MaxInt16 {
232				return fmt.Errorf("%d is greater than maximum value for int16", srcVal)
233			}
234			*v = int16(srcVal)
235		case *int32:
236			if srcVal < math.MinInt32 {
237				return fmt.Errorf("%d is less than minimum value for int32", srcVal)
238			} else if srcVal > math.MaxInt32 {
239				return fmt.Errorf("%d is greater than maximum value for int32", srcVal)
240			}
241			*v = int32(srcVal)
242		case *int64:
243			if srcVal < math.MinInt64 {
244				return fmt.Errorf("%d is less than minimum value for int64", srcVal)
245			} else if srcVal > math.MaxInt64 {
246				return fmt.Errorf("%d is greater than maximum value for int64", srcVal)
247			}
248			*v = int64(srcVal)
249		case *uint:
250			if srcVal < 0 {
251				return fmt.Errorf("%d is less than zero for uint", srcVal)
252			} else if uint64(srcVal) > uint64(maxUint) {
253				return fmt.Errorf("%d is greater than maximum value for uint", srcVal)
254			}
255			*v = uint(srcVal)
256		case *uint8:
257			if srcVal < 0 {
258				return fmt.Errorf("%d is less than zero for uint8", srcVal)
259			} else if srcVal > math.MaxUint8 {
260				return fmt.Errorf("%d is greater than maximum value for uint8", srcVal)
261			}
262			*v = uint8(srcVal)
263		case *uint16:
264			if srcVal < 0 {
265				return fmt.Errorf("%d is less than zero for uint32", srcVal)
266			} else if srcVal > math.MaxUint16 {
267				return fmt.Errorf("%d is greater than maximum value for uint16", srcVal)
268			}
269			*v = uint16(srcVal)
270		case *uint32:
271			if srcVal < 0 {
272				return fmt.Errorf("%d is less than zero for uint32", srcVal)
273			} else if srcVal > math.MaxUint32 {
274				return fmt.Errorf("%d is greater than maximum value for uint32", srcVal)
275			}
276			*v = uint32(srcVal)
277		case *uint64:
278			if srcVal < 0 {
279				return fmt.Errorf("%d is less than zero for uint64", srcVal)
280			}
281			*v = uint64(srcVal)
282		case sql.Scanner:
283			return v.Scan(srcVal)
284		default:
285			if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr {
286				el := v.Elem()
287				switch el.Kind() {
288				// if dst is a pointer to pointer, strip the pointer and try again
289				case reflect.Ptr:
290					if el.IsNil() {
291						// allocate destination
292						el.Set(reflect.New(el.Type().Elem()))
293					}
294					return int64AssignTo(srcVal, srcStatus, el.Interface())
295				case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
296					if el.OverflowInt(int64(srcVal)) {
297						return fmt.Errorf("cannot put %d into %T", srcVal, dst)
298					}
299					el.SetInt(int64(srcVal))
300					return nil
301				case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
302					if srcVal < 0 {
303						return fmt.Errorf("%d is less than zero for %T", srcVal, dst)
304					}
305					if el.OverflowUint(uint64(srcVal)) {
306						return fmt.Errorf("cannot put %d into %T", srcVal, dst)
307					}
308					el.SetUint(uint64(srcVal))
309					return nil
310				}
311			}
312			return fmt.Errorf("cannot assign %v into %T", srcVal, dst)
313		}
314		return nil
315	}
316
317	// if dst is a pointer to pointer and srcStatus is not Present, nil it out
318	if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr {
319		el := v.Elem()
320		if el.Kind() == reflect.Ptr {
321			el.Set(reflect.Zero(el.Type()))
322			return nil
323		}
324	}
325
326	return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst)
327}
328
329func float64AssignTo(srcVal float64, srcStatus Status, dst interface{}) error {
330	if srcStatus == Present {
331		switch v := dst.(type) {
332		case *float32:
333			*v = float32(srcVal)
334		case *float64:
335			*v = srcVal
336		default:
337			if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr {
338				el := v.Elem()
339				switch el.Kind() {
340				// if dst is a pointer to pointer, strip the pointer and try again
341				case reflect.Ptr:
342					if el.IsNil() {
343						// allocate destination
344						el.Set(reflect.New(el.Type().Elem()))
345					}
346					return float64AssignTo(srcVal, srcStatus, el.Interface())
347				case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
348					i64 := int64(srcVal)
349					if float64(i64) == srcVal {
350						return int64AssignTo(i64, srcStatus, dst)
351					}
352				}
353			}
354			return fmt.Errorf("cannot assign %v into %T", srcVal, dst)
355		}
356		return nil
357	}
358
359	// if dst is a pointer to pointer and srcStatus is not Present, nil it out
360	if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr {
361		el := v.Elem()
362		if el.Kind() == reflect.Ptr {
363			el.Set(reflect.Zero(el.Type()))
364			return nil
365		}
366	}
367
368	return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst)
369}
370
371func NullAssignTo(dst interface{}) error {
372	dstPtr := reflect.ValueOf(dst)
373
374	// AssignTo dst must always be a pointer
375	if dstPtr.Kind() != reflect.Ptr {
376		return &nullAssignmentError{dst: dst}
377	}
378
379	dstVal := dstPtr.Elem()
380
381	switch dstVal.Kind() {
382	case reflect.Ptr, reflect.Slice, reflect.Map:
383		dstVal.Set(reflect.Zero(dstVal.Type()))
384		return nil
385	}
386
387	return &nullAssignmentError{dst: dst}
388}
389
// kindTypes maps each basic reflect.Kind (bool, the float, integer and
// unsigned integer kinds, and string) to the reflect.Type of the matching
// unnamed built-in type. It is populated in the declaration so it is ready
// before any init function or dependent package-level initializer runs.
var kindTypes = map[reflect.Kind]reflect.Type{
	reflect.Bool:    reflect.TypeOf(false),
	reflect.Float32: reflect.TypeOf(float32(0)),
	reflect.Float64: reflect.TypeOf(float64(0)),
	reflect.Int:     reflect.TypeOf(int(0)),
	reflect.Int8:    reflect.TypeOf(int8(0)),
	reflect.Int16:   reflect.TypeOf(int16(0)),
	reflect.Int32:   reflect.TypeOf(int32(0)),
	reflect.Int64:   reflect.TypeOf(int64(0)),
	reflect.Uint:    reflect.TypeOf(uint(0)),
	reflect.Uint8:   reflect.TypeOf(uint8(0)),
	reflect.Uint16:  reflect.TypeOf(uint16(0)),
	reflect.Uint32:  reflect.TypeOf(uint32(0)),
	reflect.Uint64:  reflect.TypeOf(uint64(0)),
	reflect.String:  reflect.TypeOf(""),
}

// toInterface converts dst to type t and returns the result as an interface
// value together with a bool telling whether the type actually changed.
//
// It returns (nil, false) instead of panicking when dst cannot be converted to
// t (for example a pointer to a slice whose element type is itself a named
// type) or when dst was reached through an unexported struct field and so may
// not be exposed through Interface.
func toInterface(dst reflect.Value, t reflect.Type) (interface{}, bool) {
	if !dst.CanInterface() || !dst.Type().ConvertibleTo(t) {
		return nil, false
	}
	nextDst := dst.Convert(t)
	return nextDst.Interface(), dst.Type() != nextDst.Type()
}

// GetAssignToDstType attempts to convert dst to something AssignTo can assign
// to. If dst is a pointer to pointer it allocates a value and returns the
// dereferenced pointer. If dst is a named type such as *Foo where Foo is type
// Foo int16, it converts dst to *int16. Pointers to slices and arrays of basic
// element kinds are converted to pointers to the unnamed slice or array type,
// and a struct whose only field is an embedded basic or array type is
// unwrapped to a pointer to that field.
//
// GetAssignToDstType returns the converted dst and a bool representing if any
// change was made. It returns (nil, false) when dst is not a pointer or no
// supported conversion applies.
func GetAssignToDstType(dst interface{}) (interface{}, bool) {
	dstPtr := reflect.ValueOf(dst)

	// AssignTo dst must always be a pointer
	if dstPtr.Kind() != reflect.Ptr {
		return nil, false
	}

	dstVal := dstPtr.Elem()

	// if dst is a pointer to pointer, allocate space and try again with the dereferenced pointer
	if dstVal.Kind() == reflect.Ptr {
		dstVal.Set(reflect.New(dstVal.Type().Elem()))
		return dstVal.Interface(), true
	}

	// if dst is pointer to a base type that has been renamed
	if baseValType, ok := kindTypes[dstVal.Kind()]; ok {
		return toInterface(dstPtr, reflect.PtrTo(baseValType))
	}

	if dstVal.Kind() == reflect.Slice {
		if baseElemType, ok := kindTypes[dstVal.Type().Elem().Kind()]; ok {
			return toInterface(dstPtr, reflect.PtrTo(reflect.SliceOf(baseElemType)))
		}
	}

	if dstVal.Kind() == reflect.Array {
		if baseElemType, ok := kindTypes[dstVal.Type().Elem().Kind()]; ok {
			return toInterface(dstPtr, reflect.PtrTo(reflect.ArrayOf(dstVal.Len(), baseElemType)))
		}
	}

	if dstVal.Kind() == reflect.Struct {
		structType := dstVal.Type()
		// only a struct that merely wraps a single embedded field is unwrapped
		if structType.NumField() == 1 && structType.Field(0).Anonymous {
			fieldPtr := dstVal.Field(0).Addr()
			nested := structType.Field(0).Type
			if nested.Kind() == reflect.Array {
				if baseElemType, ok := kindTypes[nested.Elem().Kind()]; ok {
					return toInterface(fieldPtr, reflect.PtrTo(reflect.ArrayOf(nested.Len(), baseElemType)))
				}
			}
			// an unexported embedded field cannot be handed out via Interface
			if _, ok := kindTypes[nested.Kind()]; ok && fieldPtr.CanInterface() {
				return fieldPtr.Interface(), true
			}
		}
	}

	return nil, false
}
473