1package defaults
2
3import (
4	"encoding/json"
5	"errors"
6	"reflect"
7	"strconv"
8	"time"
9)
10
// errInvalidType is returned by Set when its argument is not a pointer to a struct.
var errInvalidType = errors.New("not a struct pointer")

// fieldName is the struct tag key whose value holds a field's default.
const fieldName = "default"
18
19// Set initializes members in a struct referenced by a pointer.
20// Maps and slices are initialized by `make` and other primitive types are set with default values.
21// `ptr` should be a struct pointer
22func Set(ptr interface{}) error {
23	if reflect.TypeOf(ptr).Kind() != reflect.Ptr {
24		return errInvalidType
25	}
26
27	v := reflect.ValueOf(ptr).Elem()
28	t := v.Type()
29
30	if t.Kind() != reflect.Struct {
31		return errInvalidType
32	}
33
34	for i := 0; i < t.NumField(); i++ {
35		if defaultVal := t.Field(i).Tag.Get(fieldName); defaultVal != "-" {
36			if err := setField(v.Field(i), defaultVal); err != nil {
37				return err
38			}
39		}
40	}
41	callSetter(ptr)
42	return nil
43}
44
45// MustSet function is a wrapper of Set function
46// It will call Set and panic if err not equals nil.
47func MustSet(ptr interface{}) {
48	if err := Set(ptr); err != nil {
49		panic(err)
50	}
51}
52
53func setField(field reflect.Value, defaultVal string) error {
54	if !field.CanSet() {
55		return nil
56	}
57
58	if !shouldInitializeField(field, defaultVal) {
59		return nil
60	}
61
62	isInitial := isInitialValue(field)
63	if isInitial {
64		switch field.Kind() {
65		case reflect.Bool:
66			if val, err := strconv.ParseBool(defaultVal); err == nil {
67				field.Set(reflect.ValueOf(val).Convert(field.Type()))
68			}
69		case reflect.Int:
70			if val, err := strconv.ParseInt(defaultVal, 0, strconv.IntSize); err == nil {
71				field.Set(reflect.ValueOf(int(val)).Convert(field.Type()))
72			}
73		case reflect.Int8:
74			if val, err := strconv.ParseInt(defaultVal, 0, 8); err == nil {
75				field.Set(reflect.ValueOf(int8(val)).Convert(field.Type()))
76			}
77		case reflect.Int16:
78			if val, err := strconv.ParseInt(defaultVal, 0, 16); err == nil {
79				field.Set(reflect.ValueOf(int16(val)).Convert(field.Type()))
80			}
81		case reflect.Int32:
82			if val, err := strconv.ParseInt(defaultVal, 0, 32); err == nil {
83				field.Set(reflect.ValueOf(int32(val)).Convert(field.Type()))
84			}
85		case reflect.Int64:
86			if val, err := time.ParseDuration(defaultVal); err == nil {
87				field.Set(reflect.ValueOf(val).Convert(field.Type()))
88			} else if val, err := strconv.ParseInt(defaultVal, 0, 64); err == nil {
89				field.Set(reflect.ValueOf(val).Convert(field.Type()))
90			}
91		case reflect.Uint:
92			if val, err := strconv.ParseUint(defaultVal, 0, strconv.IntSize); err == nil {
93				field.Set(reflect.ValueOf(uint(val)).Convert(field.Type()))
94			}
95		case reflect.Uint8:
96			if val, err := strconv.ParseUint(defaultVal, 0, 8); err == nil {
97				field.Set(reflect.ValueOf(uint8(val)).Convert(field.Type()))
98			}
99		case reflect.Uint16:
100			if val, err := strconv.ParseUint(defaultVal, 0, 16); err == nil {
101				field.Set(reflect.ValueOf(uint16(val)).Convert(field.Type()))
102			}
103		case reflect.Uint32:
104			if val, err := strconv.ParseUint(defaultVal, 0, 32); err == nil {
105				field.Set(reflect.ValueOf(uint32(val)).Convert(field.Type()))
106			}
107		case reflect.Uint64:
108			if val, err := strconv.ParseUint(defaultVal, 0, 64); err == nil {
109				field.Set(reflect.ValueOf(val).Convert(field.Type()))
110			}
111		case reflect.Uintptr:
112			if val, err := strconv.ParseUint(defaultVal, 0, strconv.IntSize); err == nil {
113				field.Set(reflect.ValueOf(uintptr(val)).Convert(field.Type()))
114			}
115		case reflect.Float32:
116			if val, err := strconv.ParseFloat(defaultVal, 32); err == nil {
117				field.Set(reflect.ValueOf(float32(val)).Convert(field.Type()))
118			}
119		case reflect.Float64:
120			if val, err := strconv.ParseFloat(defaultVal, 64); err == nil {
121				field.Set(reflect.ValueOf(val).Convert(field.Type()))
122			}
123		case reflect.String:
124			field.Set(reflect.ValueOf(defaultVal).Convert(field.Type()))
125
126		case reflect.Slice:
127			ref := reflect.New(field.Type())
128			ref.Elem().Set(reflect.MakeSlice(field.Type(), 0, 0))
129			if defaultVal != "" && defaultVal != "[]" {
130				if err := json.Unmarshal([]byte(defaultVal), ref.Interface()); err != nil {
131					return err
132				}
133			}
134			field.Set(ref.Elem().Convert(field.Type()))
135		case reflect.Map:
136			ref := reflect.New(field.Type())
137			ref.Elem().Set(reflect.MakeMap(field.Type()))
138			if defaultVal != "" && defaultVal != "{}" {
139				if err := json.Unmarshal([]byte(defaultVal), ref.Interface()); err != nil {
140					return err
141				}
142			}
143			field.Set(ref.Elem().Convert(field.Type()))
144		case reflect.Struct:
145			if defaultVal != "" && defaultVal != "{}" {
146				if err := json.Unmarshal([]byte(defaultVal), field.Addr().Interface()); err != nil {
147					return err
148				}
149			}
150		case reflect.Ptr:
151			field.Set(reflect.New(field.Type().Elem()))
152		}
153	}
154
155	switch field.Kind() {
156	case reflect.Ptr:
157		if isInitial || field.Elem().Kind() == reflect.Struct {
158			setField(field.Elem(), defaultVal)
159			callSetter(field.Interface())
160		}
161	case reflect.Struct:
162		if err := Set(field.Addr().Interface()); err != nil {
163			return err
164		}
165	case reflect.Slice:
166		for j := 0; j < field.Len(); j++ {
167			if err := setField(field.Index(j), defaultVal); err != nil {
168				return err
169			}
170		}
171	}
172
173	return nil
174}
175
// isInitialValue reports whether field currently holds the zero value of its
// type. The comparison uses reflect.DeepEqual, so an empty but non-nil slice
// or map does not count as initial. field must be valid and must allow
// Interface() (e.g. not an unexported struct field), or this panics.
func isInitialValue(field reflect.Value) bool {
	zero := reflect.Zero(field.Type())
	return reflect.DeepEqual(zero.Interface(), field.Interface())
}
179
// shouldInitializeField reports whether setField should process field for the
// given `default` tag value. Structs and non-nil pointers to structs are always
// visited so their nested defaults can be applied. Non-empty slices are
// visited so their elements can be processed. Every other field is visited
// only when tag is non-empty.
func shouldInitializeField(field reflect.Value, tag string) bool {
	hasTag := tag != ""
	kind := field.Kind()

	if kind == reflect.Struct {
		return true
	}
	if kind == reflect.Ptr && !field.IsNil() && field.Elem().Kind() == reflect.Struct {
		return true
	}
	if kind == reflect.Slice && field.Len() > 0 {
		return true
	}
	return hasTag
}
194
195// CanUpdate returns true when the given value is an initial value of its type
196func CanUpdate(v interface{}) bool {
197	return isInitialValue(reflect.ValueOf(v))
198}
199