1package defaults 2 3import ( 4 "encoding/json" 5 "errors" 6 "reflect" 7 "strconv" 8 "time" 9) 10 11var ( 12 errInvalidType = errors.New("not a struct pointer") 13) 14 15const ( 16 fieldName = "default" 17) 18 19// Set initializes members in a struct referenced by a pointer. 20// Maps and slices are initialized by `make` and other primitive types are set with default values. 21// `ptr` should be a struct pointer 22func Set(ptr interface{}) error { 23 if reflect.TypeOf(ptr).Kind() != reflect.Ptr { 24 return errInvalidType 25 } 26 27 v := reflect.ValueOf(ptr).Elem() 28 t := v.Type() 29 30 if t.Kind() != reflect.Struct { 31 return errInvalidType 32 } 33 34 for i := 0; i < t.NumField(); i++ { 35 if defaultVal := t.Field(i).Tag.Get(fieldName); defaultVal != "-" { 36 if err := setField(v.Field(i), defaultVal); err != nil { 37 return err 38 } 39 } 40 } 41 callSetter(ptr) 42 return nil 43} 44 45// MustSet function is a wrapper of Set function 46// It will call Set and panic if err not equals nil. 
func MustSet(ptr interface{}) {
	if err := Set(ptr); err != nil {
		panic(err)
	}
}

// setField applies defaultVal (the raw text of a field's `default` tag) to
// field. It then descends into freshly allocated pointers, pointers to
// structs, structs and slice elements, so that nested tags are honoured too.
//
// The tag value is written only when field still holds the zero value of its
// type. Non-zero fields keep their value but are still descended into. Scalar
// defaults that fail to parse are silently ignored. Malformed JSON for a
// slice, map or struct default is returned as an error, including when it is
// reached through a pointer. Fields that cannot be set, such as unexported
// ones, are skipped.
func setField(field reflect.Value, defaultVal string) error {
	if !field.CanSet() {
		return nil
	}

	if !shouldInitializeField(field, defaultVal) {
		return nil
	}

	isInitial := isInitialValue(field)
	if isInitial {
		switch field.Kind() {
		case reflect.Bool:
			if val, err := strconv.ParseBool(defaultVal); err == nil {
				field.Set(reflect.ValueOf(val).Convert(field.Type()))
			}
		case reflect.Int:
			if val, err := strconv.ParseInt(defaultVal, 0, strconv.IntSize); err == nil {
				field.Set(reflect.ValueOf(int(val)).Convert(field.Type()))
			}
		case reflect.Int8:
			if val, err := strconv.ParseInt(defaultVal, 0, 8); err == nil {
				field.Set(reflect.ValueOf(int8(val)).Convert(field.Type()))
			}
		case reflect.Int16:
			if val, err := strconv.ParseInt(defaultVal, 0, 16); err == nil {
				field.Set(reflect.ValueOf(int16(val)).Convert(field.Type()))
			}
		case reflect.Int32:
			if val, err := strconv.ParseInt(defaultVal, 0, 32); err == nil {
				field.Set(reflect.ValueOf(int32(val)).Convert(field.Type()))
			}
		case reflect.Int64:
			// time.Duration has kind Int64, so duration strings such as "1s"
			// are tried first. This applies to every int64-kinded field.
			// Plain integers ("5", "0x10") fail ParseDuration and fall back
			// to ParseInt.
			if val, err := time.ParseDuration(defaultVal); err == nil {
				field.Set(reflect.ValueOf(val).Convert(field.Type()))
			} else if val, err := strconv.ParseInt(defaultVal, 0, 64); err == nil {
				field.Set(reflect.ValueOf(val).Convert(field.Type()))
			}
		case reflect.Uint:
			if val, err := strconv.ParseUint(defaultVal, 0, strconv.IntSize); err == nil {
				field.Set(reflect.ValueOf(uint(val)).Convert(field.Type()))
			}
		case reflect.Uint8:
			if val, err := strconv.ParseUint(defaultVal, 0, 8); err == nil {
				field.Set(reflect.ValueOf(uint8(val)).Convert(field.Type()))
			}
		case reflect.Uint16:
			if val, err := strconv.ParseUint(defaultVal, 0, 16); err == nil {
				field.Set(reflect.ValueOf(uint16(val)).Convert(field.Type()))
			}
		case reflect.Uint32:
			if val, err := strconv.ParseUint(defaultVal, 0, 32); err == nil {
				field.Set(reflect.ValueOf(uint32(val)).Convert(field.Type()))
			}
		case reflect.Uint64:
			if val, err := strconv.ParseUint(defaultVal, 0, 64); err == nil {
				field.Set(reflect.ValueOf(val).Convert(field.Type()))
			}
		case reflect.Uintptr:
			if val, err := strconv.ParseUint(defaultVal, 0, strconv.IntSize); err == nil {
				field.Set(reflect.ValueOf(uintptr(val)).Convert(field.Type()))
			}
		case reflect.Float32:
			if val, err := strconv.ParseFloat(defaultVal, 32); err == nil {
				field.Set(reflect.ValueOf(float32(val)).Convert(field.Type()))
			}
		case reflect.Float64:
			if val, err := strconv.ParseFloat(defaultVal, 64); err == nil {
				field.Set(reflect.ValueOf(val).Convert(field.Type()))
			}
		case reflect.String:
			field.Set(reflect.ValueOf(defaultVal).Convert(field.Type()))

		case reflect.Slice:
			// Always allocate a non-nil slice; decode the tag as a JSON array
			// unless it is empty or "[]".
			ref := reflect.New(field.Type())
			ref.Elem().Set(reflect.MakeSlice(field.Type(), 0, 0))
			if defaultVal != "" && defaultVal != "[]" {
				if err := json.Unmarshal([]byte(defaultVal), ref.Interface()); err != nil {
					return err
				}
			}
			field.Set(ref.Elem().Convert(field.Type()))
		case reflect.Map:
			// Always allocate a non-nil map; decode the tag as a JSON object
			// unless it is empty or "{}".
			ref := reflect.New(field.Type())
			ref.Elem().Set(reflect.MakeMap(field.Type()))
			if defaultVal != "" && defaultVal != "{}" {
				if err := json.Unmarshal([]byte(defaultVal), ref.Interface()); err != nil {
					return err
				}
			}
			field.Set(ref.Elem().Convert(field.Type()))
		case reflect.Struct:
			if defaultVal != "" && defaultVal != "{}" {
				if err := json.Unmarshal([]byte(defaultVal), field.Addr().Interface()); err != nil {
					return err
				}
			}
		case reflect.Ptr:
			// Allocate the pointee; it is populated by the recursion below.
			field.Set(reflect.New(field.Type().Elem()))
		}
	}

	switch field.Kind() {
	case reflect.Ptr:
		if isInitial || field.Elem().Kind() == reflect.Struct {
			// Propagate failures from the pointee (e.g. malformed JSON for a
			// *struct default) instead of dropping them, so pointer fields
			// report errors the same way non-pointer fields do.
			if err := setField(field.Elem(), defaultVal); err != nil {
				return err
			}
			callSetter(field.Interface())
		}
	case reflect.Struct:
		if err := Set(field.Addr().Interface()); err != nil {
			return err
		}
	case reflect.Slice:
		// The tag describes the slice as a whole (typically a JSON array),
		// not its individual elements. Elements are therefore visited with an
		// empty tag, so only struct and pointer-to-struct elements are
		// initialized, via their own field tags. Passing the slice tag down
		// would overwrite zero-valued string elements with the raw tag text
		// and fail to JSON-decode an array into zero-valued struct elements.
		for j := 0; j < field.Len(); j++ {
			if err := setField(field.Index(j), ""); err != nil {
				return err
			}
		}
	}

	return nil
}

// isInitialValue reports whether field holds the zero value of its type. The
// comparison uses reflect.DeepEqual, so an empty but non-nil slice or map is
// not considered initial.
func isInitialValue(field reflect.Value) bool {
	return reflect.DeepEqual(reflect.Zero(field.Type()).Interface(), field.Interface())
}

// shouldInitializeField reports whether setField should process field, given
// the raw text of its `default` tag. Structs and non-nil pointers to structs
// are always processed so that their nested tags apply. Slices are processed
// when they already have elements or carry a tag. Every other field is
// processed only when it carries a tag.
func shouldInitializeField(field reflect.Value, tag string) bool {
	switch field.Kind() {
	case reflect.Struct:
		return true
	case reflect.Ptr:
		if !field.IsNil() && field.Elem().Kind() == reflect.Struct {
			return true
		}
	case reflect.Slice:
		return field.Len() > 0 || tag != ""
	}

	return tag != ""
}

// CanUpdate returns true when the given value is an initial (zero) value of
// its type, i.e. when a default may be applied to it. A nil interface value
// is treated as initial rather than causing a panic.
func CanUpdate(v interface{}) bool {
	if v == nil {
		return true
	}
	return isInitialValue(reflect.ValueOf(v))
}