1package copier 2 3import ( 4 "database/sql" 5 "database/sql/driver" 6 "fmt" 7 "reflect" 8 "strings" 9) 10 11// These flags define options for tag handling 12const ( 13 // Denotes that a destination field must be copied to. If copying fails then a panic will ensue. 14 tagMust uint8 = 1 << iota 15 16 // Denotes that the program should not panic when the must flag is on and 17 // value is not copied. The program will return an error instead. 18 tagNoPanic 19 20 // Ignore a destination field from being copied to. 21 tagIgnore 22 23 // Denotes that the value as been copied 24 hasCopied 25) 26 27// Option sets copy options 28type Option struct { 29 // setting this value to true will ignore copying zero values of all the fields, including bools, as well as a 30 // struct having all it's fields set to their zero values respectively (see IsZero() in reflect/value.go) 31 IgnoreEmpty bool 32 DeepCopy bool 33} 34 35// Copy copy things 36func Copy(toValue interface{}, fromValue interface{}) (err error) { 37 return copier(toValue, fromValue, Option{}) 38} 39 40// CopyWithOption copy with option 41func CopyWithOption(toValue interface{}, fromValue interface{}, opt Option) (err error) { 42 return copier(toValue, fromValue, opt) 43} 44 45func copier(toValue interface{}, fromValue interface{}, opt Option) (err error) { 46 var ( 47 isSlice bool 48 amount = 1 49 from = indirect(reflect.ValueOf(fromValue)) 50 to = indirect(reflect.ValueOf(toValue)) 51 ) 52 53 if !to.CanAddr() { 54 return ErrInvalidCopyDestination 55 } 56 57 // Return is from value is invalid 58 if !from.IsValid() { 59 return ErrInvalidCopyFrom 60 } 61 62 fromType, isPtrFrom := indirectType(from.Type()) 63 toType, _ := indirectType(to.Type()) 64 65 if fromType.Kind() == reflect.Interface { 66 fromType = reflect.TypeOf(from.Interface()) 67 } 68 69 if toType.Kind() == reflect.Interface { 70 toType, _ = indirectType(reflect.TypeOf(to.Interface())) 71 oldTo := to 72 to = reflect.New(reflect.TypeOf(to.Interface())).Elem() 73 defer func() { 74 oldTo.Set(to) 75 }() 76 } 77 78 // Just set it if possible to assign for normal types 79 if from.Kind() != reflect.Slice && from.Kind() != reflect.Struct && from.Kind() != reflect.Map && (from.Type().AssignableTo(to.Type()) || from.Type().ConvertibleTo(to.Type())) { 80 if !isPtrFrom || !opt.DeepCopy { 81 to.Set(from.Convert(to.Type())) 82 } else { 83 fromCopy := reflect.New(from.Type()) 84 fromCopy.Set(from.Elem()) 85 to.Set(fromCopy.Convert(to.Type())) 86 } 87 return 88 } 89 90 if fromType.Kind() == reflect.Map && toType.Kind() == reflect.Map { 91 if !fromType.Key().ConvertibleTo(toType.Key()) { 92 return ErrMapKeyNotMatch 93 } 94 95 if to.IsNil() { 96 to.Set(reflect.MakeMapWithSize(toType, from.Len())) 97 } 98 99 for _, k := range from.MapKeys() { 100 toKey := indirect(reflect.New(toType.Key())) 101 if !set(toKey, k, opt.DeepCopy) { 102 return fmt.Errorf("%w map, old key: %v, new key: %v", ErrNotSupported, k.Type(), toType.Key()) 103 } 104 105 elemType, _ := indirectType(toType.Elem()) 106 toValue := indirect(reflect.New(elemType)) 107 if !set(toValue, from.MapIndex(k), opt.DeepCopy) { 108 if err = copier(toValue.Addr().Interface(), from.MapIndex(k).Interface(), opt); err != nil { 109 return err 110 } 111 } 112 113 for { 114 if elemType == toType.Elem() { 115 to.SetMapIndex(toKey, toValue) 116 break 117 } 118 elemType = reflect.PtrTo(elemType) 119 toValue = toValue.Addr() 120 } 121 } 122 return 123 } 124 125 if from.Kind() == reflect.Slice && to.Kind() == reflect.Slice && fromType.ConvertibleTo(toType) { 126 if to.IsNil() { 127 slice := reflect.MakeSlice(reflect.SliceOf(to.Type().Elem()), from.Len(), from.Cap()) 128 to.Set(slice) 129 } 130 131 for i := 0; i < from.Len(); i++ { 132 if to.Len() < i+1 { 133 to.Set(reflect.Append(to, reflect.New(to.Type().Elem()).Elem())) 134 } 135 136 if !set(to.Index(i), from.Index(i), opt.DeepCopy) { 137 err = CopyWithOption(to.Index(i).Addr().Interface(), from.Index(i).Interface(), opt) 138 if err != nil { 139 continue 140 } 141 } 142 } 143 return 144 } 145 146 if fromType.Kind() != reflect.Struct || toType.Kind() != reflect.Struct { 147 // skip not supported type 148 return 149 } 150 151 if to.Kind() == reflect.Slice { 152 isSlice = true 153 if from.Kind() == reflect.Slice { 154 amount = from.Len() 155 } 156 } 157 158 for i := 0; i < amount; i++ { 159 var dest, source reflect.Value 160 161 if isSlice { 162 // source 163 if from.Kind() == reflect.Slice { 164 source = indirect(from.Index(i)) 165 } else { 166 source = indirect(from) 167 } 168 // dest 169 dest = indirect(reflect.New(toType).Elem()) 170 } else { 171 source = indirect(from) 172 dest = indirect(to) 173 } 174 175 destKind := dest.Kind() 176 initDest := false 177 if destKind == reflect.Interface { 178 initDest = true 179 dest = indirect(reflect.New(toType)) 180 } 181 182 // Get tag options 183 tagBitFlags := map[string]uint8{} 184 if dest.IsValid() { 185 tagBitFlags = getBitFlags(toType) 186 } 187 188 // check source 189 if source.IsValid() { 190 // Copy from source field to dest field or method 191 fromTypeFields := deepFields(fromType) 192 for _, field := range fromTypeFields { 193 name := field.Name 194 195 // Get bit flags for field 196 fieldFlags, _ := tagBitFlags[name] 197 198 // Check if we should ignore copying 199 if (fieldFlags & tagIgnore) != 0 { 200 continue 201 } 202 203 if fromField := source.FieldByName(name); fromField.IsValid() && !shouldIgnore(fromField, opt.IgnoreEmpty) { 204 // process for nested anonymous field 205 destFieldNotSet := false 206 if f, ok := dest.Type().FieldByName(name); ok { 207 for idx := range f.Index { 208 destField := dest.FieldByIndex(f.Index[:idx+1]) 209 210 if destField.Kind() != reflect.Ptr { 211 continue 212 } 213 214 if !destField.IsNil() { 215 continue 216 } 217 if !destField.CanSet() { 218 destFieldNotSet = true 219 break 220 } 221 222 // destField is a nil pointer that can be set 223 newValue := reflect.New(destField.Type().Elem()) 224 destField.Set(newValue) 225 } 226 } 227 228 if destFieldNotSet { 229 break 230 } 231 232 toField := dest.FieldByName(name) 233 if toField.IsValid() { 234 if toField.CanSet() { 235 if !set(toField, fromField, opt.DeepCopy) { 236 if err := copier(toField.Addr().Interface(), fromField.Interface(), opt); err != nil { 237 return err 238 } 239 } 240 if fieldFlags != 0 { 241 // Note that a copy was made 242 tagBitFlags[name] = fieldFlags | hasCopied 243 } 244 } 245 } else { 246 // try to set to method 247 var toMethod reflect.Value 248 if dest.CanAddr() { 249 toMethod = dest.Addr().MethodByName(name) 250 } else { 251 toMethod = dest.MethodByName(name) 252 } 253 254 if toMethod.IsValid() && toMethod.Type().NumIn() == 1 && fromField.Type().AssignableTo(toMethod.Type().In(0)) { 255 toMethod.Call([]reflect.Value{fromField}) 256 } 257 } 258 } 259 } 260 261 // Copy from from method to dest field 262 for _, field := range deepFields(toType) { 263 name := field.Name 264 265 var fromMethod reflect.Value 266 if source.CanAddr() { 267 fromMethod = source.Addr().MethodByName(name) 268 } else { 269 fromMethod = source.MethodByName(name) 270 } 271 272 if fromMethod.IsValid() && fromMethod.Type().NumIn() == 0 && fromMethod.Type().NumOut() == 1 && !shouldIgnore(fromMethod, opt.IgnoreEmpty) { 273 if toField := dest.FieldByName(name); toField.IsValid() && toField.CanSet() { 274 values := fromMethod.Call([]reflect.Value{}) 275 if len(values) >= 1 { 276 set(toField, values[0], opt.DeepCopy) 277 } 278 } 279 } 280 } 281 } 282 283 if isSlice { 284 if dest.Addr().Type().AssignableTo(to.Type().Elem()) { 285 to.Set(reflect.Append(to, dest.Addr())) 286 } else if dest.Type().AssignableTo(to.Type().Elem()) { 287 to.Set(reflect.Append(to, dest)) 288 } 289 } else if initDest { 290 to.Set(dest) 291 } 292 293 err = checkBitFlags(tagBitFlags) 294 } 295 296 return 297} 298 299func shouldIgnore(v reflect.Value, ignoreEmpty bool) bool { 300 if !ignoreEmpty { 301 return false 302 } 303 304 return v.IsZero() 305} 306 307func deepFields(reflectType reflect.Type) []reflect.StructField { 308 if reflectType, _ = indirectType(reflectType); reflectType.Kind() == reflect.Struct { 309 fields := make([]reflect.StructField, 0, reflectType.NumField()) 310 311 for i := 0; i < reflectType.NumField(); i++ { 312 v := reflectType.Field(i) 313 if v.Anonymous { 314 fields = append(fields, deepFields(v.Type)...) 315 } else { 316 fields = append(fields, v) 317 } 318 } 319 320 return fields 321 } 322 323 return nil 324} 325 326func indirect(reflectValue reflect.Value) reflect.Value { 327 for reflectValue.Kind() == reflect.Ptr { 328 reflectValue = reflectValue.Elem() 329 } 330 return reflectValue 331} 332 333func indirectType(reflectType reflect.Type) (_ reflect.Type, isPtr bool) { 334 for reflectType.Kind() == reflect.Ptr || reflectType.Kind() == reflect.Slice { 335 reflectType = reflectType.Elem() 336 isPtr = true 337 } 338 return reflectType, isPtr 339} 340 341func set(to, from reflect.Value, deepCopy bool) bool { 342 if from.IsValid() { 343 if to.Kind() == reflect.Ptr { 344 // set `to` to nil if from is nil 345 if from.Kind() == reflect.Ptr && from.IsNil() { 346 to.Set(reflect.Zero(to.Type())) 347 return true 348 } else if to.IsNil() { 349 // `from` -> `to` 350 // sql.NullString -> *string 351 if fromValuer, ok := driverValuer(from); ok { 352 v, err := fromValuer.Value() 353 if err != nil { 354 return false 355 } 356 // if `from` is not valid do nothing with `to` 357 if v == nil { 358 return true 359 } 360 } 361 // allocate new `to` variable with default value (eg. *string -> new(string)) 362 to.Set(reflect.New(to.Type().Elem())) 363 } 364 // depointer `to` 365 to = to.Elem() 366 } 367 368 if deepCopy { 369 toKind := to.Kind() 370 if toKind == reflect.Interface && to.IsNil() { 371 to.Set(reflect.New(reflect.TypeOf(from.Interface())).Elem()) 372 toKind = reflect.TypeOf(to.Interface()).Kind() 373 } 374 if toKind == reflect.Struct || toKind == reflect.Map || toKind == reflect.Slice { 375 return false 376 } 377 } 378 379 if from.Type().ConvertibleTo(to.Type()) { 380 to.Set(from.Convert(to.Type())) 381 } else if toScanner, ok := to.Addr().Interface().(sql.Scanner); ok { 382 // `from` -> `to` 383 // *string -> sql.NullString 384 if from.Kind() == reflect.Ptr { 385 // if `from` is nil do nothing with `to` 386 if from.IsNil() { 387 return true 388 } 389 // depointer `from` 390 from = indirect(from) 391 } 392 // `from` -> `to` 393 // string -> sql.NullString 394 // set `to` by invoking method Scan(`from`) 395 err := toScanner.Scan(from.Interface()) 396 if err != nil { 397 return false 398 } 399 } else if fromValuer, ok := driverValuer(from); ok { 400 // `from` -> `to` 401 // sql.NullString -> string 402 v, err := fromValuer.Value() 403 if err != nil { 404 return false 405 } 406 // if `from` is not valid do nothing with `to` 407 if v == nil { 408 return true 409 } 410 rv := reflect.ValueOf(v) 411 if rv.Type().AssignableTo(to.Type()) { 412 to.Set(rv) 413 } 414 } else if from.Kind() == reflect.Ptr { 415 return set(to, from.Elem(), deepCopy) 416 } else { 417 return false 418 } 419 } 420 421 return true 422} 423 424// parseTags Parses struct tags and returns uint8 bit flags. 425func parseTags(tag string) (flags uint8) { 426 for _, t := range strings.Split(tag, ",") { 427 switch t { 428 case "-": 429 flags = tagIgnore 430 return 431 case "must": 432 flags = flags | tagMust 433 case "nopanic": 434 flags = flags | tagNoPanic 435 } 436 } 437 return 438} 439 440// getBitFlags Parses struct tags for bit flags. 441func getBitFlags(toType reflect.Type) map[string]uint8 { 442 flags := map[string]uint8{} 443 toTypeFields := deepFields(toType) 444 445 // Get a list dest of tags 446 for _, field := range toTypeFields { 447 tags := field.Tag.Get("copier") 448 if tags != "" { 449 flags[field.Name] = parseTags(tags) 450 } 451 } 452 return flags 453} 454 455// checkBitFlags Checks flags for error or panic conditions. 456func checkBitFlags(flagsList map[string]uint8) (err error) { 457 // Check flag conditions were met 458 for name, flags := range flagsList { 459 if flags&hasCopied == 0 { 460 switch { 461 case flags&tagMust != 0 && flags&tagNoPanic != 0: 462 err = fmt.Errorf("field %s has must tag but was not copied", name) 463 return 464 case flags&(tagMust) != 0: 465 panic(fmt.Sprintf("Field %s has must tag but was not copied", name)) 466 } 467 } 468 } 469 return 470} 471 472func driverValuer(v reflect.Value) (i driver.Valuer, ok bool) { 473 474 if !v.CanAddr() { 475 i, ok = v.Interface().(driver.Valuer) 476 return 477 } 478 479 i, ok = v.Addr().Interface().(driver.Valuer) 480 return 481} 482