1// Copyright (C) MongoDB, Inc. 2017-present. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); you may 4// not use this file except in compliance with the License. You may obtain 5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 6 7package bsoncodec 8 9import ( 10 "errors" 11 "fmt" 12 "reflect" 13 "sort" 14 "strings" 15 "sync" 16 "time" 17 18 "go.mongodb.org/mongo-driver/bson/bsonoptions" 19 "go.mongodb.org/mongo-driver/bson/bsonrw" 20 "go.mongodb.org/mongo-driver/bson/bsontype" 21) 22 23// DecodeError represents an error that occurs when unmarshalling BSON bytes into a native Go type. 24type DecodeError struct { 25 keys []string 26 wrapped error 27} 28 29// Unwrap returns the underlying error 30func (de *DecodeError) Unwrap() error { 31 return de.wrapped 32} 33 34// Error implements the error interface. 35func (de *DecodeError) Error() string { 36 // The keys are stored in reverse order because the de.keys slice is builtup while propagating the error up the 37 // stack of BSON keys, so we call de.Keys(), which reverses them. 38 keyPath := strings.Join(de.Keys(), ".") 39 return fmt.Sprintf("error decoding key %s: %v", keyPath, de.wrapped) 40} 41 42// Keys returns the BSON key path that caused an error as a slice of strings. The keys in the slice are in top-down 43// order. For example, if the document being unmarshalled was {a: {b: {c: 1}}} and the value for c was supposed to be 44// a string, the keys slice will be ["a", "b", "c"]. 45func (de *DecodeError) Keys() []string { 46 reversedKeys := make([]string, 0, len(de.keys)) 47 for idx := len(de.keys) - 1; idx >= 0; idx-- { 48 reversedKeys = append(reversedKeys, de.keys[idx]) 49 } 50 51 return reversedKeys 52} 53 54// Zeroer allows custom struct types to implement a report of zero 55// state. All struct types that don't implement Zeroer or where IsZero 56// returns false are considered to be not zero. 57type Zeroer interface { 58 IsZero() bool 59} 60 61// StructCodec is the Codec used for struct values. 62type StructCodec struct { 63 cache map[reflect.Type]*structDescription 64 l sync.RWMutex 65 parser StructTagParser 66 DecodeZeroStruct bool 67 DecodeDeepZeroInline bool 68 EncodeOmitDefaultStruct bool 69 AllowUnexportedFields bool 70 OverwriteDuplicatedInlinedFields bool 71} 72 73var _ ValueEncoder = &StructCodec{} 74var _ ValueDecoder = &StructCodec{} 75 76// NewStructCodec returns a StructCodec that uses p for struct tag parsing. 77func NewStructCodec(p StructTagParser, opts ...*bsonoptions.StructCodecOptions) (*StructCodec, error) { 78 if p == nil { 79 return nil, errors.New("a StructTagParser must be provided to NewStructCodec") 80 } 81 82 structOpt := bsonoptions.MergeStructCodecOptions(opts...) 83 84 codec := &StructCodec{ 85 cache: make(map[reflect.Type]*structDescription), 86 parser: p, 87 } 88 89 if structOpt.DecodeZeroStruct != nil { 90 codec.DecodeZeroStruct = *structOpt.DecodeZeroStruct 91 } 92 if structOpt.DecodeDeepZeroInline != nil { 93 codec.DecodeDeepZeroInline = *structOpt.DecodeDeepZeroInline 94 } 95 if structOpt.EncodeOmitDefaultStruct != nil { 96 codec.EncodeOmitDefaultStruct = *structOpt.EncodeOmitDefaultStruct 97 } 98 if structOpt.OverwriteDuplicatedInlinedFields != nil { 99 codec.OverwriteDuplicatedInlinedFields = *structOpt.OverwriteDuplicatedInlinedFields 100 } 101 if structOpt.AllowUnexportedFields != nil { 102 codec.AllowUnexportedFields = *structOpt.AllowUnexportedFields 103 } 104 105 return codec, nil 106} 107 108// EncodeValue handles encoding generic struct types. 109func (sc *StructCodec) EncodeValue(r EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { 110 if !val.IsValid() || val.Kind() != reflect.Struct { 111 return ValueEncoderError{Name: "StructCodec.EncodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val} 112 } 113 114 sd, err := sc.describeStruct(r.Registry, val.Type()) 115 if err != nil { 116 return err 117 } 118 119 dw, err := vw.WriteDocument() 120 if err != nil { 121 return err 122 } 123 var rv reflect.Value 124 for _, desc := range sd.fl { 125 if desc.inline == nil { 126 rv = val.Field(desc.idx) 127 } else { 128 rv, err = fieldByIndexErr(val, desc.inline) 129 if err != nil { 130 continue 131 } 132 } 133 134 desc.encoder, rv, err = defaultValueEncoders.lookupElementEncoder(r, desc.encoder, rv) 135 136 if err != nil && err != errInvalidValue { 137 return err 138 } 139 140 if err == errInvalidValue { 141 if desc.omitEmpty { 142 continue 143 } 144 vw2, err := dw.WriteDocumentElement(desc.name) 145 if err != nil { 146 return err 147 } 148 err = vw2.WriteNull() 149 if err != nil { 150 return err 151 } 152 continue 153 } 154 155 if desc.encoder == nil { 156 return ErrNoEncoder{Type: rv.Type()} 157 } 158 159 encoder := desc.encoder 160 161 var isZero bool 162 rvInterface := rv.Interface() 163 if cz, ok := encoder.(CodecZeroer); ok { 164 isZero = cz.IsTypeZero(rvInterface) 165 } else if rv.Kind() == reflect.Interface { 166 // sc.isZero will not treat an interface rv as an interface, so we need to check for the zero interface separately. 167 isZero = rv.IsNil() 168 } else { 169 isZero = sc.isZero(rvInterface) 170 } 171 if desc.omitEmpty && isZero { 172 continue 173 } 174 175 vw2, err := dw.WriteDocumentElement(desc.name) 176 if err != nil { 177 return err 178 } 179 180 ectx := EncodeContext{Registry: r.Registry, MinSize: desc.minSize} 181 err = encoder.EncodeValue(ectx, vw2, rv) 182 if err != nil { 183 return err 184 } 185 } 186 187 if sd.inlineMap >= 0 { 188 rv := val.Field(sd.inlineMap) 189 collisionFn := func(key string) bool { 190 _, exists := sd.fm[key] 191 return exists 192 } 193 194 return defaultMapCodec.mapEncodeValue(r, dw, rv, collisionFn) 195 } 196 197 return dw.WriteDocumentEnd() 198} 199 200func newDecodeError(key string, original error) error { 201 de, ok := original.(*DecodeError) 202 if !ok { 203 return &DecodeError{ 204 keys: []string{key}, 205 wrapped: original, 206 } 207 } 208 209 de.keys = append(de.keys, key) 210 return de 211} 212 213// DecodeValue implements the Codec interface. 214// By default, map types in val will not be cleared. If a map has existing key/value pairs, it will be extended with the new ones from vr. 215// For slices, the decoder will set the length of the slice to zero and append all elements. The underlying array will not be cleared. 216func (sc *StructCodec) DecodeValue(r DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { 217 if !val.CanSet() || val.Kind() != reflect.Struct { 218 return ValueDecoderError{Name: "StructCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val} 219 } 220 221 switch vrType := vr.Type(); vrType { 222 case bsontype.Type(0), bsontype.EmbeddedDocument: 223 case bsontype.Null: 224 if err := vr.ReadNull(); err != nil { 225 return err 226 } 227 228 val.Set(reflect.Zero(val.Type())) 229 return nil 230 case bsontype.Undefined: 231 if err := vr.ReadUndefined(); err != nil { 232 return err 233 } 234 235 val.Set(reflect.Zero(val.Type())) 236 return nil 237 default: 238 return fmt.Errorf("cannot decode %v into a %s", vrType, val.Type()) 239 } 240 241 sd, err := sc.describeStruct(r.Registry, val.Type()) 242 if err != nil { 243 return err 244 } 245 246 if sc.DecodeZeroStruct { 247 val.Set(reflect.Zero(val.Type())) 248 } 249 if sc.DecodeDeepZeroInline && sd.inline { 250 val.Set(deepZero(val.Type())) 251 } 252 253 var decoder ValueDecoder 254 var inlineMap reflect.Value 255 if sd.inlineMap >= 0 { 256 inlineMap = val.Field(sd.inlineMap) 257 decoder, err = r.LookupDecoder(inlineMap.Type().Elem()) 258 if err != nil { 259 return err 260 } 261 } 262 263 dr, err := vr.ReadDocument() 264 if err != nil { 265 return err 266 } 267 268 for { 269 name, vr, err := dr.ReadElement() 270 if err == bsonrw.ErrEOD { 271 break 272 } 273 if err != nil { 274 return err 275 } 276 277 fd, exists := sd.fm[name] 278 if !exists { 279 // if the original name isn't found in the struct description, try again with the name in lowercase 280 // this could match if a BSON tag isn't specified because by default, describeStruct lowercases all field 281 // names 282 fd, exists = sd.fm[strings.ToLower(name)] 283 } 284 285 if !exists { 286 if sd.inlineMap < 0 { 287 // The encoding/json package requires a flag to return on error for non-existent fields. 288 // This functionality seems appropriate for the struct codec. 289 err = vr.Skip() 290 if err != nil { 291 return err 292 } 293 continue 294 } 295 296 if inlineMap.IsNil() { 297 inlineMap.Set(reflect.MakeMap(inlineMap.Type())) 298 } 299 300 elem := reflect.New(inlineMap.Type().Elem()).Elem() 301 r.Ancestor = inlineMap.Type() 302 err = decoder.DecodeValue(r, vr, elem) 303 if err != nil { 304 return err 305 } 306 inlineMap.SetMapIndex(reflect.ValueOf(name), elem) 307 continue 308 } 309 310 var field reflect.Value 311 if fd.inline == nil { 312 field = val.Field(fd.idx) 313 } else { 314 field, err = getInlineField(val, fd.inline) 315 if err != nil { 316 return err 317 } 318 } 319 320 if !field.CanSet() { // Being settable is a super set of being addressable. 321 innerErr := fmt.Errorf("field %v is not settable", field) 322 return newDecodeError(fd.name, innerErr) 323 } 324 if field.Kind() == reflect.Ptr && field.IsNil() { 325 field.Set(reflect.New(field.Type().Elem())) 326 } 327 field = field.Addr() 328 329 dctx := DecodeContext{Registry: r.Registry, Truncate: fd.truncate || r.Truncate} 330 if fd.decoder == nil { 331 return newDecodeError(fd.name, ErrNoDecoder{Type: field.Elem().Type()}) 332 } 333 334 if decoder, ok := fd.decoder.(ValueDecoder); ok { 335 err = decoder.DecodeValue(dctx, vr, field.Elem()) 336 if err != nil { 337 return newDecodeError(fd.name, err) 338 } 339 continue 340 } 341 err = fd.decoder.DecodeValue(dctx, vr, field) 342 if err != nil { 343 return newDecodeError(fd.name, err) 344 } 345 } 346 347 return nil 348} 349 350func (sc *StructCodec) isZero(i interface{}) bool { 351 v := reflect.ValueOf(i) 352 353 // check the value validity 354 if !v.IsValid() { 355 return true 356 } 357 358 if z, ok := v.Interface().(Zeroer); ok && (v.Kind() != reflect.Ptr || !v.IsNil()) { 359 return z.IsZero() 360 } 361 362 switch v.Kind() { 363 case reflect.Array, reflect.Map, reflect.Slice, reflect.String: 364 return v.Len() == 0 365 case reflect.Bool: 366 return !v.Bool() 367 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 368 return v.Int() == 0 369 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: 370 return v.Uint() == 0 371 case reflect.Float32, reflect.Float64: 372 return v.Float() == 0 373 case reflect.Interface, reflect.Ptr: 374 return v.IsNil() 375 case reflect.Struct: 376 if sc.EncodeOmitDefaultStruct { 377 vt := v.Type() 378 if vt == tTime { 379 return v.Interface().(time.Time).IsZero() 380 } 381 for i := 0; i < v.NumField(); i++ { 382 if vt.Field(i).PkgPath != "" && !vt.Field(i).Anonymous { 383 continue // Private field 384 } 385 fld := v.Field(i) 386 if !sc.isZero(fld.Interface()) { 387 return false 388 } 389 } 390 return true 391 } 392 } 393 394 return false 395} 396 397type structDescription struct { 398 fm map[string]fieldDescription 399 fl []fieldDescription 400 inlineMap int 401 inline bool 402} 403 404type fieldDescription struct { 405 name string // BSON key name 406 fieldName string // struct field name 407 idx int 408 omitEmpty bool 409 minSize bool 410 truncate bool 411 inline []int 412 encoder ValueEncoder 413 decoder ValueDecoder 414} 415 416type byIndex []fieldDescription 417 418func (bi byIndex) Len() int { return len(bi) } 419 420func (bi byIndex) Swap(i, j int) { bi[i], bi[j] = bi[j], bi[i] } 421 422func (bi byIndex) Less(i, j int) bool { 423 // If a field is inlined, its index in the top level struct is stored at inline[0] 424 iIdx, jIdx := bi[i].idx, bi[j].idx 425 if len(bi[i].inline) > 0 { 426 iIdx = bi[i].inline[0] 427 } 428 if len(bi[j].inline) > 0 { 429 jIdx = bi[j].inline[0] 430 } 431 if iIdx != jIdx { 432 return iIdx < jIdx 433 } 434 for k, biik := range bi[i].inline { 435 if k >= len(bi[j].inline) { 436 return false 437 } 438 if biik != bi[j].inline[k] { 439 return biik < bi[j].inline[k] 440 } 441 } 442 return len(bi[i].inline) < len(bi[j].inline) 443} 444 445func (sc *StructCodec) describeStruct(r *Registry, t reflect.Type) (*structDescription, error) { 446 // We need to analyze the struct, including getting the tags, collecting 447 // information about inlining, and create a map of the field name to the field. 448 sc.l.RLock() 449 ds, exists := sc.cache[t] 450 sc.l.RUnlock() 451 if exists { 452 return ds, nil 453 } 454 455 numFields := t.NumField() 456 sd := &structDescription{ 457 fm: make(map[string]fieldDescription, numFields), 458 fl: make([]fieldDescription, 0, numFields), 459 inlineMap: -1, 460 } 461 462 var fields []fieldDescription 463 for i := 0; i < numFields; i++ { 464 sf := t.Field(i) 465 if sf.PkgPath != "" && (!sc.AllowUnexportedFields || !sf.Anonymous) { 466 // field is private or unexported fields aren't allowed, ignore 467 continue 468 } 469 470 sfType := sf.Type 471 encoder, err := r.LookupEncoder(sfType) 472 if err != nil { 473 encoder = nil 474 } 475 decoder, err := r.LookupDecoder(sfType) 476 if err != nil { 477 decoder = nil 478 } 479 480 description := fieldDescription{ 481 fieldName: sf.Name, 482 idx: i, 483 encoder: encoder, 484 decoder: decoder, 485 } 486 487 stags, err := sc.parser.ParseStructTags(sf) 488 if err != nil { 489 return nil, err 490 } 491 if stags.Skip { 492 continue 493 } 494 description.name = stags.Name 495 description.omitEmpty = stags.OmitEmpty 496 description.minSize = stags.MinSize 497 description.truncate = stags.Truncate 498 499 if stags.Inline { 500 sd.inline = true 501 switch sfType.Kind() { 502 case reflect.Map: 503 if sd.inlineMap >= 0 { 504 return nil, errors.New("(struct " + t.String() + ") multiple inline maps") 505 } 506 if sfType.Key() != tString { 507 return nil, errors.New("(struct " + t.String() + ") inline map must have a string keys") 508 } 509 sd.inlineMap = description.idx 510 case reflect.Ptr: 511 sfType = sfType.Elem() 512 if sfType.Kind() != reflect.Struct { 513 return nil, fmt.Errorf("(struct %s) inline fields must be a struct, a struct pointer, or a map", t.String()) 514 } 515 fallthrough 516 case reflect.Struct: 517 inlinesf, err := sc.describeStruct(r, sfType) 518 if err != nil { 519 return nil, err 520 } 521 for _, fd := range inlinesf.fl { 522 if fd.inline == nil { 523 fd.inline = []int{i, fd.idx} 524 } else { 525 fd.inline = append([]int{i}, fd.inline...) 526 } 527 fields = append(fields, fd) 528 529 } 530 default: 531 return nil, fmt.Errorf("(struct %s) inline fields must be a struct, a struct pointer, or a map", t.String()) 532 } 533 continue 534 } 535 fields = append(fields, description) 536 } 537 538 // Sort fieldDescriptions by name and use dominance rules to determine which should be added for each name 539 sort.Slice(fields, func(i, j int) bool { 540 x := fields 541 // sort field by name, breaking ties with depth, then 542 // breaking ties with index sequence. 543 if x[i].name != x[j].name { 544 return x[i].name < x[j].name 545 } 546 if len(x[i].inline) != len(x[j].inline) { 547 return len(x[i].inline) < len(x[j].inline) 548 } 549 return byIndex(x).Less(i, j) 550 }) 551 552 for advance, i := 0, 0; i < len(fields); i += advance { 553 // One iteration per name. 554 // Find the sequence of fields with the name of this first field. 555 fi := fields[i] 556 name := fi.name 557 for advance = 1; i+advance < len(fields); advance++ { 558 fj := fields[i+advance] 559 if fj.name != name { 560 break 561 } 562 } 563 if advance == 1 { // Only one field with this name 564 sd.fl = append(sd.fl, fi) 565 sd.fm[name] = fi 566 continue 567 } 568 dominant, ok := dominantField(fields[i : i+advance]) 569 if !ok || !sc.OverwriteDuplicatedInlinedFields { 570 return nil, fmt.Errorf("struct %s) duplicated key %s", t.String(), name) 571 } 572 sd.fl = append(sd.fl, dominant) 573 sd.fm[name] = dominant 574 } 575 576 sort.Sort(byIndex(sd.fl)) 577 578 sc.l.Lock() 579 sc.cache[t] = sd 580 sc.l.Unlock() 581 582 return sd, nil 583} 584 585// dominantField looks through the fields, all of which are known to 586// have the same name, to find the single field that dominates the 587// others using Go's inlining rules. If there are multiple top-level 588// fields, the boolean will be false: This condition is an error in Go 589// and we skip all the fields. 590func dominantField(fields []fieldDescription) (fieldDescription, bool) { 591 // The fields are sorted in increasing index-length order, then by presence of tag. 592 // That means that the first field is the dominant one. We need only check 593 // for error cases: two fields at top level. 594 if len(fields) > 1 && 595 len(fields[0].inline) == len(fields[1].inline) { 596 return fieldDescription{}, false 597 } 598 return fields[0], true 599} 600 601func fieldByIndexErr(v reflect.Value, index []int) (result reflect.Value, err error) { 602 defer func() { 603 if recovered := recover(); recovered != nil { 604 switch r := recovered.(type) { 605 case string: 606 err = fmt.Errorf("%s", r) 607 case error: 608 err = r 609 } 610 } 611 }() 612 613 result = v.FieldByIndex(index) 614 return 615} 616 617func getInlineField(val reflect.Value, index []int) (reflect.Value, error) { 618 field, err := fieldByIndexErr(val, index) 619 if err == nil { 620 return field, nil 621 } 622 623 // if parent of this element doesn't exist, fix its parent 624 inlineParent := index[:len(index)-1] 625 var fParent reflect.Value 626 if fParent, err = fieldByIndexErr(val, inlineParent); err != nil { 627 fParent, err = getInlineField(val, inlineParent) 628 if err != nil { 629 return fParent, err 630 } 631 } 632 fParent.Set(reflect.New(fParent.Type().Elem())) 633 634 return fieldByIndexErr(val, index) 635} 636 637// DeepZero returns recursive zero object 638func deepZero(st reflect.Type) (result reflect.Value) { 639 result = reflect.Indirect(reflect.New(st)) 640 641 if result.Kind() == reflect.Struct { 642 for i := 0; i < result.NumField(); i++ { 643 if f := result.Field(i); f.Kind() == reflect.Ptr { 644 if f.CanInterface() { 645 if ft := reflect.TypeOf(f.Interface()); ft.Elem().Kind() == reflect.Struct { 646 result.Field(i).Set(recursivePointerTo(deepZero(ft.Elem()))) 647 } 648 } 649 } 650 } 651 } 652 653 return 654} 655 656// recursivePointerTo calls reflect.New(v.Type) but recursively for its fields inside 657func recursivePointerTo(v reflect.Value) reflect.Value { 658 v = reflect.Indirect(v) 659 result := reflect.New(v.Type()) 660 if v.Kind() == reflect.Struct { 661 for i := 0; i < v.NumField(); i++ { 662 if f := v.Field(i); f.Kind() == reflect.Ptr { 663 if f.Elem().Kind() == reflect.Struct { 664 result.Elem().Field(i).Set(recursivePointerTo(f)) 665 } 666 } 667 } 668 } 669 670 return result 671} 672