1package structure 2 3// references: https://github.com/mitchellh/mapstructure 4 5import ( 6 "fmt" 7 "reflect" 8 "strconv" 9 "strings" 10) 11 12// Option is the configuration that is used to create a new decoder 13type Option struct { 14 TagName string 15 WeaklyTypedInput bool 16} 17 18// Decoder is the core of structure 19type Decoder struct { 20 option *Option 21} 22 23// NewDecoder return a Decoder by Option 24func NewDecoder(option Option) *Decoder { 25 if option.TagName == "" { 26 option.TagName = "structure" 27 } 28 return &Decoder{option: &option} 29} 30 31// Decode transform a map[string]interface{} to a struct 32func (d *Decoder) Decode(src map[string]interface{}, dst interface{}) error { 33 if reflect.TypeOf(dst).Kind() != reflect.Ptr { 34 return fmt.Errorf("Decode must recive a ptr struct") 35 } 36 t := reflect.TypeOf(dst).Elem() 37 v := reflect.ValueOf(dst).Elem() 38 for idx := 0; idx < v.NumField(); idx++ { 39 field := t.Field(idx) 40 41 tag := field.Tag.Get(d.option.TagName) 42 str := strings.SplitN(tag, ",", 2) 43 key := str[0] 44 omitempty := false 45 if len(str) > 1 { 46 omitempty = str[1] == "omitempty" 47 } 48 49 value, ok := src[key] 50 if !ok || value == nil { 51 if omitempty { 52 continue 53 } 54 return fmt.Errorf("key '%s' missing", key) 55 } 56 57 err := d.decode(key, value, v.Field(idx)) 58 if err != nil { 59 return err 60 } 61 } 62 return nil 63} 64 65func (d *Decoder) decode(name string, data interface{}, val reflect.Value) error { 66 switch val.Kind() { 67 case reflect.Int: 68 return d.decodeInt(name, data, val) 69 case reflect.String: 70 return d.decodeString(name, data, val) 71 case reflect.Bool: 72 return d.decodeBool(name, data, val) 73 case reflect.Slice: 74 return d.decodeSlice(name, data, val) 75 case reflect.Map: 76 return d.decodeMap(name, data, val) 77 case reflect.Interface: 78 return d.setInterface(name, data, val) 79 case reflect.Struct: 80 return d.decodeStruct(name, data, val) 81 default: 82 return fmt.Errorf("type %s not support", val.Kind().String()) 83 } 84} 85 86func (d *Decoder) decodeInt(name string, data interface{}, val reflect.Value) (err error) { 87 dataVal := reflect.ValueOf(data) 88 kind := dataVal.Kind() 89 switch { 90 case kind == reflect.Int: 91 val.SetInt(dataVal.Int()) 92 case kind == reflect.String && d.option.WeaklyTypedInput: 93 var i int64 94 i, err = strconv.ParseInt(dataVal.String(), 0, val.Type().Bits()) 95 if err == nil { 96 val.SetInt(i) 97 } else { 98 err = fmt.Errorf("cannot parse '%s' as int: %s", name, err) 99 } 100 default: 101 err = fmt.Errorf( 102 "'%s' expected type '%s', got unconvertible type '%s'", 103 name, val.Type(), dataVal.Type(), 104 ) 105 } 106 return err 107} 108 109func (d *Decoder) decodeString(name string, data interface{}, val reflect.Value) (err error) { 110 dataVal := reflect.ValueOf(data) 111 kind := dataVal.Kind() 112 switch { 113 case kind == reflect.String: 114 val.SetString(dataVal.String()) 115 case kind == reflect.Int && d.option.WeaklyTypedInput: 116 val.SetString(strconv.FormatInt(dataVal.Int(), 10)) 117 default: 118 err = fmt.Errorf( 119 "'%s' expected type '%s', got unconvertible type '%s'", 120 name, val.Type(), dataVal.Type(), 121 ) 122 } 123 return err 124} 125 126func (d *Decoder) decodeBool(name string, data interface{}, val reflect.Value) (err error) { 127 dataVal := reflect.ValueOf(data) 128 kind := dataVal.Kind() 129 switch { 130 case kind == reflect.Bool: 131 val.SetBool(dataVal.Bool()) 132 case kind == reflect.Int && d.option.WeaklyTypedInput: 133 val.SetBool(dataVal.Int() != 0) 134 default: 135 err = fmt.Errorf( 136 "'%s' expected type '%s', got unconvertible type '%s'", 137 name, val.Type(), dataVal.Type(), 138 ) 139 } 140 return err 141} 142 143func (d *Decoder) decodeSlice(name string, data interface{}, val reflect.Value) error { 144 dataVal := reflect.Indirect(reflect.ValueOf(data)) 145 valType := val.Type() 146 valElemType := valType.Elem() 147 148 if dataVal.Kind() != reflect.Slice { 149 return fmt.Errorf("'%s' is not a slice", name) 150 } 151 152 valSlice := val 153 for i := 0; i < dataVal.Len(); i++ { 154 currentData := dataVal.Index(i).Interface() 155 for valSlice.Len() <= i { 156 valSlice = reflect.Append(valSlice, reflect.Zero(valElemType)) 157 } 158 currentField := valSlice.Index(i) 159 160 fieldName := fmt.Sprintf("%s[%d]", name, i) 161 if err := d.decode(fieldName, currentData, currentField); err != nil { 162 return err 163 } 164 } 165 166 val.Set(valSlice) 167 return nil 168} 169 170func (d *Decoder) decodeMap(name string, data interface{}, val reflect.Value) error { 171 valType := val.Type() 172 valKeyType := valType.Key() 173 valElemType := valType.Elem() 174 175 valMap := val 176 177 if valMap.IsNil() { 178 mapType := reflect.MapOf(valKeyType, valElemType) 179 valMap = reflect.MakeMap(mapType) 180 } 181 182 dataVal := reflect.Indirect(reflect.ValueOf(data)) 183 if dataVal.Kind() != reflect.Map { 184 return fmt.Errorf("'%s' expected a map, got '%s'", name, dataVal.Kind()) 185 } 186 187 return d.decodeMapFromMap(name, dataVal, val, valMap) 188} 189 190func (d *Decoder) decodeMapFromMap(name string, dataVal reflect.Value, val reflect.Value, valMap reflect.Value) error { 191 valType := val.Type() 192 valKeyType := valType.Key() 193 valElemType := valType.Elem() 194 195 errors := make([]string, 0) 196 197 if dataVal.Len() == 0 { 198 if dataVal.IsNil() { 199 if !val.IsNil() { 200 val.Set(dataVal) 201 } 202 } else { 203 val.Set(valMap) 204 } 205 206 return nil 207 } 208 209 for _, k := range dataVal.MapKeys() { 210 fieldName := fmt.Sprintf("%s[%s]", name, k) 211 212 currentKey := reflect.Indirect(reflect.New(valKeyType)) 213 if err := d.decode(fieldName, k.Interface(), currentKey); err != nil { 214 errors = append(errors, err.Error()) 215 continue 216 } 217 218 v := dataVal.MapIndex(k).Interface() 219 if v == nil { 220 errors = append(errors, fmt.Sprintf("filed %s invalid", fieldName)) 221 continue 222 } 223 224 currentVal := reflect.Indirect(reflect.New(valElemType)) 225 if err := d.decode(fieldName, v, currentVal); err != nil { 226 errors = append(errors, err.Error()) 227 continue 228 } 229 230 valMap.SetMapIndex(currentKey, currentVal) 231 } 232 233 val.Set(valMap) 234 235 if len(errors) > 0 { 236 return fmt.Errorf(strings.Join(errors, ",")) 237 } 238 239 return nil 240} 241 242func (d *Decoder) decodeStruct(name string, data interface{}, val reflect.Value) error { 243 dataVal := reflect.Indirect(reflect.ValueOf(data)) 244 245 // If the type of the value to write to and the data match directly, 246 // then we just set it directly instead of recursing into the structure. 247 if dataVal.Type() == val.Type() { 248 val.Set(dataVal) 249 return nil 250 } 251 252 dataValKind := dataVal.Kind() 253 switch dataValKind { 254 case reflect.Map: 255 return d.decodeStructFromMap(name, dataVal, val) 256 default: 257 return fmt.Errorf("'%s' expected a map, got '%s'", name, dataVal.Kind()) 258 } 259} 260 261func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) error { 262 dataValType := dataVal.Type() 263 if kind := dataValType.Key().Kind(); kind != reflect.String && kind != reflect.Interface { 264 return fmt.Errorf( 265 "'%s' needs a map with string keys, has '%s' keys", 266 name, dataValType.Key().Kind()) 267 } 268 269 dataValKeys := make(map[reflect.Value]struct{}) 270 dataValKeysUnused := make(map[interface{}]struct{}) 271 for _, dataValKey := range dataVal.MapKeys() { 272 dataValKeys[dataValKey] = struct{}{} 273 dataValKeysUnused[dataValKey.Interface()] = struct{}{} 274 } 275 276 errors := make([]string, 0) 277 278 // This slice will keep track of all the structs we'll be decoding. 279 // There can be more than one struct if there are embedded structs 280 // that are squashed. 281 structs := make([]reflect.Value, 1, 5) 282 structs[0] = val 283 284 // Compile the list of all the fields that we're going to be decoding 285 // from all the structs. 286 type field struct { 287 field reflect.StructField 288 val reflect.Value 289 } 290 fields := []field{} 291 for len(structs) > 0 { 292 structVal := structs[0] 293 structs = structs[1:] 294 295 structType := structVal.Type() 296 297 for i := 0; i < structType.NumField(); i++ { 298 fieldType := structType.Field(i) 299 fieldKind := fieldType.Type.Kind() 300 301 // If "squash" is specified in the tag, we squash the field down. 302 squash := false 303 tagParts := strings.Split(fieldType.Tag.Get(d.option.TagName), ",") 304 for _, tag := range tagParts[1:] { 305 if tag == "squash" { 306 squash = true 307 break 308 } 309 } 310 311 if squash { 312 if fieldKind != reflect.Struct { 313 errors = append(errors, 314 fmt.Errorf("%s: unsupported type for squash: %s", fieldType.Name, fieldKind).Error()) 315 } else { 316 structs = append(structs, structVal.FieldByName(fieldType.Name)) 317 } 318 continue 319 } 320 321 // Normal struct field, store it away 322 fields = append(fields, field{fieldType, structVal.Field(i)}) 323 } 324 } 325 326 // for fieldType, field := range fields { 327 for _, f := range fields { 328 field, fieldValue := f.field, f.val 329 fieldName := field.Name 330 331 tagValue := field.Tag.Get(d.option.TagName) 332 tagValue = strings.SplitN(tagValue, ",", 2)[0] 333 if tagValue != "" { 334 fieldName = tagValue 335 } 336 337 rawMapKey := reflect.ValueOf(fieldName) 338 rawMapVal := dataVal.MapIndex(rawMapKey) 339 if !rawMapVal.IsValid() { 340 // Do a slower search by iterating over each key and 341 // doing case-insensitive search. 342 for dataValKey := range dataValKeys { 343 mK, ok := dataValKey.Interface().(string) 344 if !ok { 345 // Not a string key 346 continue 347 } 348 349 if strings.EqualFold(mK, fieldName) { 350 rawMapKey = dataValKey 351 rawMapVal = dataVal.MapIndex(dataValKey) 352 break 353 } 354 } 355 356 if !rawMapVal.IsValid() { 357 // There was no matching key in the map for the value in 358 // the struct. Just ignore. 359 continue 360 } 361 } 362 363 // Delete the key we're using from the unused map so we stop tracking 364 delete(dataValKeysUnused, rawMapKey.Interface()) 365 366 if !fieldValue.IsValid() { 367 // This should never happen 368 panic("field is not valid") 369 } 370 371 // If we can't set the field, then it is unexported or something, 372 // and we just continue onwards. 373 if !fieldValue.CanSet() { 374 continue 375 } 376 377 // If the name is empty string, then we're at the root, and we 378 // don't dot-join the fields. 379 if name != "" { 380 fieldName = fmt.Sprintf("%s.%s", name, fieldName) 381 } 382 383 if err := d.decode(fieldName, rawMapVal.Interface(), fieldValue); err != nil { 384 errors = append(errors, err.Error()) 385 } 386 } 387 388 if len(errors) > 0 { 389 return fmt.Errorf(strings.Join(errors, ",")) 390 } 391 392 return nil 393} 394 395func (d *Decoder) setInterface(name string, data interface{}, val reflect.Value) (err error) { 396 dataVal := reflect.ValueOf(data) 397 val.Set(dataVal) 398 return nil 399} 400