1// Copyright (C) MongoDB, Inc. 2017-present. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); you may 4// not use this file except in compliance with the License. You may obtain 5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 6 7package bsoncodec 8 9import ( 10 "errors" 11 "fmt" 12 "reflect" 13 "strings" 14 "sync" 15 "time" 16 17 "go.mongodb.org/mongo-driver/bson/bsonoptions" 18 "go.mongodb.org/mongo-driver/bson/bsonrw" 19 "go.mongodb.org/mongo-driver/bson/bsontype" 20) 21 22var defaultStructCodec = &StructCodec{ 23 cache: make(map[reflect.Type]*structDescription), 24 parser: DefaultStructTagParser, 25} 26 27// Zeroer allows custom struct types to implement a report of zero 28// state. All struct types that don't implement Zeroer or where IsZero 29// returns false are considered to be not zero. 30type Zeroer interface { 31 IsZero() bool 32} 33 34// StructCodec is the Codec used for struct values. 35type StructCodec struct { 36 cache map[reflect.Type]*structDescription 37 l sync.RWMutex 38 parser StructTagParser 39 DecodeZeroStruct bool 40 DecodeDeepZeroInline bool 41 EncodeOmitDefaultStruct bool 42 AllowUnexportedFields bool 43} 44 45var _ ValueEncoder = &StructCodec{} 46var _ ValueDecoder = &StructCodec{} 47 48// NewStructCodec returns a StructCodec that uses p for struct tag parsing. 49func NewStructCodec(p StructTagParser, opts ...*bsonoptions.StructCodecOptions) (*StructCodec, error) { 50 if p == nil { 51 return nil, errors.New("a StructTagParser must be provided to NewStructCodec") 52 } 53 54 structOpt := bsonoptions.MergeStructCodecOptions(opts...) 55 56 codec := &StructCodec{ 57 cache: make(map[reflect.Type]*structDescription), 58 parser: p, 59 } 60 61 if structOpt.DecodeZeroStruct != nil { 62 codec.DecodeZeroStruct = *structOpt.DecodeZeroStruct 63 } 64 if structOpt.DecodeDeepZeroInline != nil { 65 codec.DecodeDeepZeroInline = *structOpt.DecodeDeepZeroInline 66 } 67 if structOpt.EncodeOmitDefaultStruct != nil { 68 codec.EncodeOmitDefaultStruct = *structOpt.EncodeOmitDefaultStruct 69 } 70 if structOpt.AllowUnexportedFields != nil { 71 codec.AllowUnexportedFields = *structOpt.AllowUnexportedFields 72 } 73 74 return codec, nil 75} 76 77// EncodeValue handles encoding generic struct types. 78func (sc *StructCodec) EncodeValue(r EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { 79 if !val.IsValid() || val.Kind() != reflect.Struct { 80 return ValueEncoderError{Name: "StructCodec.EncodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val} 81 } 82 83 sd, err := sc.describeStruct(r.Registry, val.Type()) 84 if err != nil { 85 return err 86 } 87 88 dw, err := vw.WriteDocument() 89 if err != nil { 90 return err 91 } 92 var rv reflect.Value 93 for _, desc := range sd.fl { 94 if desc.inline == nil { 95 rv = val.Field(desc.idx) 96 } else { 97 rv, err = fieldByIndexErr(val, desc.inline) 98 if err != nil { 99 continue 100 } 101 } 102 103 desc.encoder, rv, err = defaultValueEncoders.lookupElementEncoder(r, desc.encoder, rv) 104 105 if err != nil && err != errInvalidValue { 106 return err 107 } 108 109 if err == errInvalidValue { 110 if desc.omitEmpty { 111 continue 112 } 113 vw2, err := dw.WriteDocumentElement(desc.name) 114 if err != nil { 115 return err 116 } 117 err = vw2.WriteNull() 118 if err != nil { 119 return err 120 } 121 continue 122 } 123 124 if desc.encoder == nil { 125 return ErrNoEncoder{Type: rv.Type()} 126 } 127 128 encoder := desc.encoder 129 130 var isZero bool 131 rvInterface := rv.Interface() 132 if cz, ok := encoder.(CodecZeroer); ok { 133 isZero = cz.IsTypeZero(rvInterface) 134 } else if rv.Kind() == reflect.Interface { 135 // sc.isZero will not treat an interface rv as an interface, so we need to check for the zero interface separately. 136 isZero = rv.IsNil() 137 } else { 138 isZero = sc.isZero(rvInterface) 139 } 140 if desc.omitEmpty && isZero { 141 continue 142 } 143 144 vw2, err := dw.WriteDocumentElement(desc.name) 145 if err != nil { 146 return err 147 } 148 149 ectx := EncodeContext{Registry: r.Registry, MinSize: desc.minSize} 150 err = encoder.EncodeValue(ectx, vw2, rv) 151 if err != nil { 152 return err 153 } 154 } 155 156 if sd.inlineMap >= 0 { 157 rv := val.Field(sd.inlineMap) 158 collisionFn := func(key string) bool { 159 _, exists := sd.fm[key] 160 return exists 161 } 162 163 return defaultMapCodec.mapEncodeValue(r, dw, rv, collisionFn) 164 } 165 166 return dw.WriteDocumentEnd() 167} 168 169// DecodeValue implements the Codec interface. 170// By default, map types in val will not be cleared. If a map has existing key/value pairs, it will be extended with the new ones from vr. 171// For slices, the decoder will set the length of the slice to zero and append all elements. The underlying array will not be cleared. 172func (sc *StructCodec) DecodeValue(r DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { 173 if !val.CanSet() || val.Kind() != reflect.Struct { 174 return ValueDecoderError{Name: "StructCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val} 175 } 176 177 switch vr.Type() { 178 case bsontype.Type(0), bsontype.EmbeddedDocument: 179 default: 180 return fmt.Errorf("cannot decode %v into a %s", vr.Type(), val.Type()) 181 } 182 183 sd, err := sc.describeStruct(r.Registry, val.Type()) 184 if err != nil { 185 return err 186 } 187 188 if sc.DecodeZeroStruct { 189 val.Set(reflect.Zero(val.Type())) 190 } 191 if sc.DecodeDeepZeroInline && sd.inline { 192 val.Set(deepZero(val.Type())) 193 } 194 195 var decoder ValueDecoder 196 var inlineMap reflect.Value 197 if sd.inlineMap >= 0 { 198 inlineMap = val.Field(sd.inlineMap) 199 decoder, err = r.LookupDecoder(inlineMap.Type().Elem()) 200 if err != nil { 201 return err 202 } 203 } 204 205 dr, err := vr.ReadDocument() 206 if err != nil { 207 return err 208 } 209 210 for { 211 name, vr, err := dr.ReadElement() 212 if err == bsonrw.ErrEOD { 213 break 214 } 215 if err != nil { 216 return err 217 } 218 219 fd, exists := sd.fm[name] 220 if !exists { 221 // if the original name isn't found in the struct description, try again with the name in lowercase 222 // this could match if a BSON tag isn't specified because by default, describeStruct lowercases all field 223 // names 224 fd, exists = sd.fm[strings.ToLower(name)] 225 } 226 227 if !exists { 228 if sd.inlineMap < 0 { 229 // The encoding/json package requires a flag to return on error for non-existent fields. 230 // This functionality seems appropriate for the struct codec. 231 err = vr.Skip() 232 if err != nil { 233 return err 234 } 235 continue 236 } 237 238 if inlineMap.IsNil() { 239 inlineMap.Set(reflect.MakeMap(inlineMap.Type())) 240 } 241 242 elem := reflect.New(inlineMap.Type().Elem()).Elem() 243 err = decoder.DecodeValue(r, vr, elem) 244 if err != nil { 245 return err 246 } 247 inlineMap.SetMapIndex(reflect.ValueOf(name), elem) 248 continue 249 } 250 251 var field reflect.Value 252 if fd.inline == nil { 253 field = val.Field(fd.idx) 254 } else { 255 field, err = getInlineField(val, fd.inline) 256 if err != nil { 257 return err 258 } 259 } 260 261 if !field.CanSet() { // Being settable is a super set of being addressable. 262 return fmt.Errorf("cannot decode element '%s' into field %v; it is not settable", name, field) 263 } 264 if field.Kind() == reflect.Ptr && field.IsNil() { 265 field.Set(reflect.New(field.Type().Elem())) 266 } 267 field = field.Addr() 268 269 dctx := DecodeContext{Registry: r.Registry, Truncate: fd.truncate || r.Truncate} 270 if fd.decoder == nil { 271 return ErrNoDecoder{Type: field.Elem().Type()} 272 } 273 274 if decoder, ok := fd.decoder.(ValueDecoder); ok { 275 err = decoder.DecodeValue(dctx, vr, field.Elem()) 276 if err != nil { 277 return err 278 } 279 continue 280 } 281 err = fd.decoder.DecodeValue(dctx, vr, field) 282 if err != nil { 283 return err 284 } 285 } 286 287 return nil 288} 289 290func (sc *StructCodec) isZero(i interface{}) bool { 291 v := reflect.ValueOf(i) 292 293 // check the value validity 294 if !v.IsValid() { 295 return true 296 } 297 298 if z, ok := v.Interface().(Zeroer); ok && (v.Kind() != reflect.Ptr || !v.IsNil()) { 299 return z.IsZero() 300 } 301 302 switch v.Kind() { 303 case reflect.Array, reflect.Map, reflect.Slice, reflect.String: 304 return v.Len() == 0 305 case reflect.Bool: 306 return !v.Bool() 307 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 308 return v.Int() == 0 309 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: 310 return v.Uint() == 0 311 case reflect.Float32, reflect.Float64: 312 return v.Float() == 0 313 case reflect.Interface, reflect.Ptr: 314 return v.IsNil() 315 case reflect.Struct: 316 if sc.EncodeOmitDefaultStruct { 317 vt := v.Type() 318 if vt == tTime { 319 return v.Interface().(time.Time).IsZero() 320 } 321 for i := 0; i < v.NumField(); i++ { 322 if vt.Field(i).PkgPath != "" && !vt.Field(i).Anonymous { 323 continue // Private field 324 } 325 fld := v.Field(i) 326 if !sc.isZero(fld.Interface()) { 327 return false 328 } 329 } 330 return true 331 } 332 } 333 334 return false 335} 336 337type structDescription struct { 338 fm map[string]fieldDescription 339 fl []fieldDescription 340 inlineMap int 341 inline bool 342} 343 344type fieldDescription struct { 345 name string 346 idx int 347 omitEmpty bool 348 minSize bool 349 truncate bool 350 inline []int 351 encoder ValueEncoder 352 decoder ValueDecoder 353} 354 355func (sc *StructCodec) describeStruct(r *Registry, t reflect.Type) (*structDescription, error) { 356 // We need to analyze the struct, including getting the tags, collecting 357 // information about inlining, and create a map of the field name to the field. 358 sc.l.RLock() 359 ds, exists := sc.cache[t] 360 sc.l.RUnlock() 361 if exists { 362 return ds, nil 363 } 364 365 numFields := t.NumField() 366 sd := &structDescription{ 367 fm: make(map[string]fieldDescription, numFields), 368 fl: make([]fieldDescription, 0, numFields), 369 inlineMap: -1, 370 } 371 372 for i := 0; i < numFields; i++ { 373 sf := t.Field(i) 374 if sf.PkgPath != "" && (!sc.AllowUnexportedFields || !sf.Anonymous) { 375 // field is private or unexported fields aren't allowed, ignore 376 continue 377 } 378 379 sfType := sf.Type 380 encoder, err := r.LookupEncoder(sfType) 381 if err != nil { 382 encoder = nil 383 } 384 decoder, err := r.LookupDecoder(sfType) 385 if err != nil { 386 decoder = nil 387 } 388 389 description := fieldDescription{idx: i, encoder: encoder, decoder: decoder} 390 391 stags, err := sc.parser.ParseStructTags(sf) 392 if err != nil { 393 return nil, err 394 } 395 if stags.Skip { 396 continue 397 } 398 description.name = stags.Name 399 description.omitEmpty = stags.OmitEmpty 400 description.minSize = stags.MinSize 401 description.truncate = stags.Truncate 402 403 if stags.Inline { 404 sd.inline = true 405 switch sfType.Kind() { 406 case reflect.Map: 407 if sd.inlineMap >= 0 { 408 return nil, errors.New("(struct " + t.String() + ") multiple inline maps") 409 } 410 if sfType.Key() != tString { 411 return nil, errors.New("(struct " + t.String() + ") inline map must have a string keys") 412 } 413 sd.inlineMap = description.idx 414 case reflect.Ptr: 415 sfType = sfType.Elem() 416 if sfType.Kind() != reflect.Struct { 417 return nil, fmt.Errorf("(struct %s) inline fields must be a struct, a struct pointer, or a map", t.String()) 418 } 419 fallthrough 420 case reflect.Struct: 421 inlinesf, err := sc.describeStruct(r, sfType) 422 if err != nil { 423 return nil, err 424 } 425 for _, fd := range inlinesf.fl { 426 if _, exists := sd.fm[fd.name]; exists { 427 return nil, fmt.Errorf("(struct %s) duplicated key %s", t.String(), fd.name) 428 } 429 if fd.inline == nil { 430 fd.inline = []int{i, fd.idx} 431 } else { 432 fd.inline = append([]int{i}, fd.inline...) 433 } 434 sd.fm[fd.name] = fd 435 sd.fl = append(sd.fl, fd) 436 } 437 default: 438 return nil, fmt.Errorf("(struct %s) inline fields must be a struct, a struct pointer, or a map", t.String()) 439 } 440 continue 441 } 442 443 if _, exists := sd.fm[description.name]; exists { 444 return nil, fmt.Errorf("struct %s) duplicated key %s", t.String(), description.name) 445 } 446 447 sd.fm[description.name] = description 448 sd.fl = append(sd.fl, description) 449 } 450 451 sc.l.Lock() 452 sc.cache[t] = sd 453 sc.l.Unlock() 454 455 return sd, nil 456} 457 458func fieldByIndexErr(v reflect.Value, index []int) (result reflect.Value, err error) { 459 defer func() { 460 if recovered := recover(); recovered != nil { 461 switch r := recovered.(type) { 462 case string: 463 err = fmt.Errorf("%s", r) 464 case error: 465 err = r 466 } 467 } 468 }() 469 470 result = v.FieldByIndex(index) 471 return 472} 473 474func getInlineField(val reflect.Value, index []int) (reflect.Value, error) { 475 field, err := fieldByIndexErr(val, index) 476 if err == nil { 477 return field, nil 478 } 479 480 // if parent of this element doesn't exist, fix its parent 481 inlineParent := index[:len(index)-1] 482 var fParent reflect.Value 483 if fParent, err = fieldByIndexErr(val, inlineParent); err != nil { 484 fParent, err = getInlineField(val, inlineParent) 485 if err != nil { 486 return fParent, err 487 } 488 } 489 fParent.Set(reflect.New(fParent.Type().Elem())) 490 491 return fieldByIndexErr(val, index) 492} 493 494// DeepZero returns recursive zero object 495func deepZero(st reflect.Type) (result reflect.Value) { 496 result = reflect.Indirect(reflect.New(st)) 497 498 if result.Kind() == reflect.Struct { 499 for i := 0; i < result.NumField(); i++ { 500 if f := result.Field(i); f.Kind() == reflect.Ptr { 501 if f.CanInterface() { 502 if ft := reflect.TypeOf(f.Interface()); ft.Elem().Kind() == reflect.Struct { 503 result.Field(i).Set(recursivePointerTo(deepZero(ft.Elem()))) 504 } 505 } 506 } 507 } 508 } 509 510 return 511} 512 513// recursivePointerTo calls reflect.New(v.Type) but recursively for its fields inside 514func recursivePointerTo(v reflect.Value) reflect.Value { 515 v = reflect.Indirect(v) 516 result := reflect.New(v.Type()) 517 if v.Kind() == reflect.Struct { 518 for i := 0; i < v.NumField(); i++ { 519 if f := v.Field(i); f.Kind() == reflect.Ptr { 520 if f.Elem().Kind() == reflect.Struct { 521 result.Elem().Field(i).Set(recursivePointerTo(f)) 522 } 523 } 524 } 525 } 526 527 return result 528} 529