1// Copyright (C) MongoDB, Inc. 2017-present. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); you may 4// not use this file except in compliance with the License. You may obtain 5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 6 7package bsoncodec 8 9import ( 10 "errors" 11 "fmt" 12 "reflect" 13 "sort" 14 "strings" 15 "sync" 16 "time" 17 18 "go.mongodb.org/mongo-driver/bson/bsonoptions" 19 "go.mongodb.org/mongo-driver/bson/bsonrw" 20 "go.mongodb.org/mongo-driver/bson/bsontype" 21) 22 23// DecodeError represents an error that occurs when unmarshalling BSON bytes into a native Go type. 24type DecodeError struct { 25 keys []string 26 wrapped error 27} 28 29// Unwrap returns the underlying error 30func (de *DecodeError) Unwrap() error { 31 return de.wrapped 32} 33 34// Error implements the error interface. 35func (de *DecodeError) Error() string { 36 // The keys are stored in reverse order because the de.keys slice is builtup while propagating the error up the 37 // stack of BSON keys, so we call de.Keys(), which reverses them. 38 keyPath := strings.Join(de.Keys(), ".") 39 return fmt.Sprintf("error decoding key %s: %v", keyPath, de.wrapped) 40} 41 42// Keys returns the BSON key path that caused an error as a slice of strings. The keys in the slice are in top-down 43// order. For example, if the document being unmarshalled was {a: {b: {c: 1}}} and the value for c was supposed to be 44// a string, the keys slice will be ["a", "b", "c"]. 45func (de *DecodeError) Keys() []string { 46 reversedKeys := make([]string, 0, len(de.keys)) 47 for idx := len(de.keys) - 1; idx >= 0; idx-- { 48 reversedKeys = append(reversedKeys, de.keys[idx]) 49 } 50 51 return reversedKeys 52} 53 54// Zeroer allows custom struct types to implement a report of zero 55// state. All struct types that don't implement Zeroer or where IsZero 56// returns false are considered to be not zero. 57type Zeroer interface { 58 IsZero() bool 59} 60 61// StructCodec is the Codec used for struct values. 62type StructCodec struct { 63 cache map[reflect.Type]*structDescription 64 l sync.RWMutex 65 parser StructTagParser 66 DecodeZeroStruct bool 67 DecodeDeepZeroInline bool 68 EncodeOmitDefaultStruct bool 69 AllowUnexportedFields bool 70 OverwriteDuplicatedInlinedFields bool 71} 72 73var _ ValueEncoder = &StructCodec{} 74var _ ValueDecoder = &StructCodec{} 75 76// NewStructCodec returns a StructCodec that uses p for struct tag parsing. 77func NewStructCodec(p StructTagParser, opts ...*bsonoptions.StructCodecOptions) (*StructCodec, error) { 78 if p == nil { 79 return nil, errors.New("a StructTagParser must be provided to NewStructCodec") 80 } 81 82 structOpt := bsonoptions.MergeStructCodecOptions(opts...) 83 84 codec := &StructCodec{ 85 cache: make(map[reflect.Type]*structDescription), 86 parser: p, 87 } 88 89 if structOpt.DecodeZeroStruct != nil { 90 codec.DecodeZeroStruct = *structOpt.DecodeZeroStruct 91 } 92 if structOpt.DecodeDeepZeroInline != nil { 93 codec.DecodeDeepZeroInline = *structOpt.DecodeDeepZeroInline 94 } 95 if structOpt.EncodeOmitDefaultStruct != nil { 96 codec.EncodeOmitDefaultStruct = *structOpt.EncodeOmitDefaultStruct 97 } 98 if structOpt.OverwriteDuplicatedInlinedFields != nil { 99 codec.OverwriteDuplicatedInlinedFields = *structOpt.OverwriteDuplicatedInlinedFields 100 } 101 if structOpt.AllowUnexportedFields != nil { 102 codec.AllowUnexportedFields = *structOpt.AllowUnexportedFields 103 } 104 105 return codec, nil 106} 107 108// EncodeValue handles encoding generic struct types. 109func (sc *StructCodec) EncodeValue(r EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { 110 if !val.IsValid() || val.Kind() != reflect.Struct { 111 return ValueEncoderError{Name: "StructCodec.EncodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val} 112 } 113 114 sd, err := sc.describeStruct(r.Registry, val.Type()) 115 if err != nil { 116 return err 117 } 118 119 dw, err := vw.WriteDocument() 120 if err != nil { 121 return err 122 } 123 var rv reflect.Value 124 for _, desc := range sd.fl { 125 if desc.inline == nil { 126 rv = val.Field(desc.idx) 127 } else { 128 rv, err = fieldByIndexErr(val, desc.inline) 129 if err != nil { 130 continue 131 } 132 } 133 134 desc.encoder, rv, err = defaultValueEncoders.lookupElementEncoder(r, desc.encoder, rv) 135 136 if err != nil && err != errInvalidValue { 137 return err 138 } 139 140 if err == errInvalidValue { 141 if desc.omitEmpty { 142 continue 143 } 144 vw2, err := dw.WriteDocumentElement(desc.name) 145 if err != nil { 146 return err 147 } 148 err = vw2.WriteNull() 149 if err != nil { 150 return err 151 } 152 continue 153 } 154 155 if desc.encoder == nil { 156 return ErrNoEncoder{Type: rv.Type()} 157 } 158 159 encoder := desc.encoder 160 161 var isZero bool 162 rvInterface := rv.Interface() 163 if cz, ok := encoder.(CodecZeroer); ok { 164 isZero = cz.IsTypeZero(rvInterface) 165 } else if rv.Kind() == reflect.Interface { 166 // sc.isZero will not treat an interface rv as an interface, so we need to check for the zero interface separately. 167 isZero = rv.IsNil() 168 } else { 169 isZero = sc.isZero(rvInterface) 170 } 171 if desc.omitEmpty && isZero { 172 continue 173 } 174 175 vw2, err := dw.WriteDocumentElement(desc.name) 176 if err != nil { 177 return err 178 } 179 180 ectx := EncodeContext{Registry: r.Registry, MinSize: desc.minSize} 181 err = encoder.EncodeValue(ectx, vw2, rv) 182 if err != nil { 183 return err 184 } 185 } 186 187 if sd.inlineMap >= 0 { 188 rv := val.Field(sd.inlineMap) 189 collisionFn := func(key string) bool { 190 _, exists := sd.fm[key] 191 return exists 192 } 193 194 return defaultMapCodec.mapEncodeValue(r, dw, rv, collisionFn) 195 } 196 197 return dw.WriteDocumentEnd() 198} 199 200func newDecodeError(key string, original error) error { 201 de, ok := original.(*DecodeError) 202 if !ok { 203 return &DecodeError{ 204 keys: []string{key}, 205 wrapped: original, 206 } 207 } 208 209 de.keys = append(de.keys, key) 210 return de 211} 212 213// DecodeValue implements the Codec interface. 214// By default, map types in val will not be cleared. If a map has existing key/value pairs, it will be extended with the new ones from vr. 215// For slices, the decoder will set the length of the slice to zero and append all elements. The underlying array will not be cleared. 216func (sc *StructCodec) DecodeValue(r DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { 217 if !val.CanSet() || val.Kind() != reflect.Struct { 218 return ValueDecoderError{Name: "StructCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val} 219 } 220 221 switch vrType := vr.Type(); vrType { 222 case bsontype.Type(0), bsontype.EmbeddedDocument: 223 case bsontype.Null: 224 if err := vr.ReadNull(); err != nil { 225 return err 226 } 227 228 val.Set(reflect.Zero(val.Type())) 229 return nil 230 case bsontype.Undefined: 231 if err := vr.ReadUndefined(); err != nil { 232 return err 233 } 234 235 val.Set(reflect.Zero(val.Type())) 236 return nil 237 default: 238 return fmt.Errorf("cannot decode %v into a %s", vrType, val.Type()) 239 } 240 241 sd, err := sc.describeStruct(r.Registry, val.Type()) 242 if err != nil { 243 return err 244 } 245 246 if sc.DecodeZeroStruct { 247 val.Set(reflect.Zero(val.Type())) 248 } 249 if sc.DecodeDeepZeroInline && sd.inline { 250 val.Set(deepZero(val.Type())) 251 } 252 253 var decoder ValueDecoder 254 var inlineMap reflect.Value 255 if sd.inlineMap >= 0 { 256 inlineMap = val.Field(sd.inlineMap) 257 decoder, err = r.LookupDecoder(inlineMap.Type().Elem()) 258 if err != nil { 259 return err 260 } 261 } 262 263 dr, err := vr.ReadDocument() 264 if err != nil { 265 return err 266 } 267 268 for { 269 name, vr, err := dr.ReadElement() 270 if err == bsonrw.ErrEOD { 271 break 272 } 273 if err != nil { 274 return err 275 } 276 277 fd, exists := sd.fm[name] 278 if !exists { 279 // if the original name isn't found in the struct description, try again with the name in lowercase 280 // this could match if a BSON tag isn't specified because by default, describeStruct lowercases all field 281 // names 282 fd, exists = sd.fm[strings.ToLower(name)] 283 } 284 285 if !exists { 286 if sd.inlineMap < 0 { 287 // The encoding/json package requires a flag to return on error for non-existent fields. 288 // This functionality seems appropriate for the struct codec. 289 err = vr.Skip() 290 if err != nil { 291 return err 292 } 293 continue 294 } 295 296 if inlineMap.IsNil() { 297 inlineMap.Set(reflect.MakeMap(inlineMap.Type())) 298 } 299 300 elem := reflect.New(inlineMap.Type().Elem()).Elem() 301 r.Ancestor = inlineMap.Type() 302 err = decoder.DecodeValue(r, vr, elem) 303 if err != nil { 304 return err 305 } 306 inlineMap.SetMapIndex(reflect.ValueOf(name), elem) 307 continue 308 } 309 310 var field reflect.Value 311 if fd.inline == nil { 312 field = val.Field(fd.idx) 313 } else { 314 field, err = getInlineField(val, fd.inline) 315 if err != nil { 316 return err 317 } 318 } 319 320 if !field.CanSet() { // Being settable is a super set of being addressable. 321 innerErr := fmt.Errorf("field %v is not settable", field) 322 return newDecodeError(fd.name, innerErr) 323 } 324 if field.Kind() == reflect.Ptr && field.IsNil() { 325 field.Set(reflect.New(field.Type().Elem())) 326 } 327 field = field.Addr() 328 329 dctx := DecodeContext{Registry: r.Registry, Truncate: fd.truncate || r.Truncate} 330 if fd.decoder == nil { 331 return newDecodeError(fd.name, ErrNoDecoder{Type: field.Elem().Type()}) 332 } 333 334 err = fd.decoder.DecodeValue(dctx, vr, field.Elem()) 335 if err != nil { 336 return newDecodeError(fd.name, err) 337 } 338 } 339 340 return nil 341} 342 343func (sc *StructCodec) isZero(i interface{}) bool { 344 v := reflect.ValueOf(i) 345 346 // check the value validity 347 if !v.IsValid() { 348 return true 349 } 350 351 if z, ok := v.Interface().(Zeroer); ok && (v.Kind() != reflect.Ptr || !v.IsNil()) { 352 return z.IsZero() 353 } 354 355 switch v.Kind() { 356 case reflect.Array, reflect.Map, reflect.Slice, reflect.String: 357 return v.Len() == 0 358 case reflect.Bool: 359 return !v.Bool() 360 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 361 return v.Int() == 0 362 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: 363 return v.Uint() == 0 364 case reflect.Float32, reflect.Float64: 365 return v.Float() == 0 366 case reflect.Interface, reflect.Ptr: 367 return v.IsNil() 368 case reflect.Struct: 369 if sc.EncodeOmitDefaultStruct { 370 vt := v.Type() 371 if vt == tTime { 372 return v.Interface().(time.Time).IsZero() 373 } 374 for i := 0; i < v.NumField(); i++ { 375 if vt.Field(i).PkgPath != "" && !vt.Field(i).Anonymous { 376 continue // Private field 377 } 378 fld := v.Field(i) 379 if !sc.isZero(fld.Interface()) { 380 return false 381 } 382 } 383 return true 384 } 385 } 386 387 return false 388} 389 390type structDescription struct { 391 fm map[string]fieldDescription 392 fl []fieldDescription 393 inlineMap int 394 inline bool 395} 396 397type fieldDescription struct { 398 name string // BSON key name 399 fieldName string // struct field name 400 idx int 401 omitEmpty bool 402 minSize bool 403 truncate bool 404 inline []int 405 encoder ValueEncoder 406 decoder ValueDecoder 407} 408 409type byIndex []fieldDescription 410 411func (bi byIndex) Len() int { return len(bi) } 412 413func (bi byIndex) Swap(i, j int) { bi[i], bi[j] = bi[j], bi[i] } 414 415func (bi byIndex) Less(i, j int) bool { 416 // If a field is inlined, its index in the top level struct is stored at inline[0] 417 iIdx, jIdx := bi[i].idx, bi[j].idx 418 if len(bi[i].inline) > 0 { 419 iIdx = bi[i].inline[0] 420 } 421 if len(bi[j].inline) > 0 { 422 jIdx = bi[j].inline[0] 423 } 424 if iIdx != jIdx { 425 return iIdx < jIdx 426 } 427 for k, biik := range bi[i].inline { 428 if k >= len(bi[j].inline) { 429 return false 430 } 431 if biik != bi[j].inline[k] { 432 return biik < bi[j].inline[k] 433 } 434 } 435 return len(bi[i].inline) < len(bi[j].inline) 436} 437 438func (sc *StructCodec) describeStruct(r *Registry, t reflect.Type) (*structDescription, error) { 439 // We need to analyze the struct, including getting the tags, collecting 440 // information about inlining, and create a map of the field name to the field. 441 sc.l.RLock() 442 ds, exists := sc.cache[t] 443 sc.l.RUnlock() 444 if exists { 445 return ds, nil 446 } 447 448 numFields := t.NumField() 449 sd := &structDescription{ 450 fm: make(map[string]fieldDescription, numFields), 451 fl: make([]fieldDescription, 0, numFields), 452 inlineMap: -1, 453 } 454 455 var fields []fieldDescription 456 for i := 0; i < numFields; i++ { 457 sf := t.Field(i) 458 if sf.PkgPath != "" && (!sc.AllowUnexportedFields || !sf.Anonymous) { 459 // field is private or unexported fields aren't allowed, ignore 460 continue 461 } 462 463 sfType := sf.Type 464 encoder, err := r.LookupEncoder(sfType) 465 if err != nil { 466 encoder = nil 467 } 468 decoder, err := r.LookupDecoder(sfType) 469 if err != nil { 470 decoder = nil 471 } 472 473 description := fieldDescription{ 474 fieldName: sf.Name, 475 idx: i, 476 encoder: encoder, 477 decoder: decoder, 478 } 479 480 stags, err := sc.parser.ParseStructTags(sf) 481 if err != nil { 482 return nil, err 483 } 484 if stags.Skip { 485 continue 486 } 487 description.name = stags.Name 488 description.omitEmpty = stags.OmitEmpty 489 description.minSize = stags.MinSize 490 description.truncate = stags.Truncate 491 492 if stags.Inline { 493 sd.inline = true 494 switch sfType.Kind() { 495 case reflect.Map: 496 if sd.inlineMap >= 0 { 497 return nil, errors.New("(struct " + t.String() + ") multiple inline maps") 498 } 499 if sfType.Key() != tString { 500 return nil, errors.New("(struct " + t.String() + ") inline map must have a string keys") 501 } 502 sd.inlineMap = description.idx 503 case reflect.Ptr: 504 sfType = sfType.Elem() 505 if sfType.Kind() != reflect.Struct { 506 return nil, fmt.Errorf("(struct %s) inline fields must be a struct, a struct pointer, or a map", t.String()) 507 } 508 fallthrough 509 case reflect.Struct: 510 inlinesf, err := sc.describeStruct(r, sfType) 511 if err != nil { 512 return nil, err 513 } 514 for _, fd := range inlinesf.fl { 515 if fd.inline == nil { 516 fd.inline = []int{i, fd.idx} 517 } else { 518 fd.inline = append([]int{i}, fd.inline...) 519 } 520 fields = append(fields, fd) 521 522 } 523 default: 524 return nil, fmt.Errorf("(struct %s) inline fields must be a struct, a struct pointer, or a map", t.String()) 525 } 526 continue 527 } 528 fields = append(fields, description) 529 } 530 531 // Sort fieldDescriptions by name and use dominance rules to determine which should be added for each name 532 sort.Slice(fields, func(i, j int) bool { 533 x := fields 534 // sort field by name, breaking ties with depth, then 535 // breaking ties with index sequence. 536 if x[i].name != x[j].name { 537 return x[i].name < x[j].name 538 } 539 if len(x[i].inline) != len(x[j].inline) { 540 return len(x[i].inline) < len(x[j].inline) 541 } 542 return byIndex(x).Less(i, j) 543 }) 544 545 for advance, i := 0, 0; i < len(fields); i += advance { 546 // One iteration per name. 547 // Find the sequence of fields with the name of this first field. 548 fi := fields[i] 549 name := fi.name 550 for advance = 1; i+advance < len(fields); advance++ { 551 fj := fields[i+advance] 552 if fj.name != name { 553 break 554 } 555 } 556 if advance == 1 { // Only one field with this name 557 sd.fl = append(sd.fl, fi) 558 sd.fm[name] = fi 559 continue 560 } 561 dominant, ok := dominantField(fields[i : i+advance]) 562 if !ok || !sc.OverwriteDuplicatedInlinedFields { 563 return nil, fmt.Errorf("struct %s has duplicated key %s", t.String(), name) 564 } 565 sd.fl = append(sd.fl, dominant) 566 sd.fm[name] = dominant 567 } 568 569 sort.Sort(byIndex(sd.fl)) 570 571 sc.l.Lock() 572 sc.cache[t] = sd 573 sc.l.Unlock() 574 575 return sd, nil 576} 577 578// dominantField looks through the fields, all of which are known to 579// have the same name, to find the single field that dominates the 580// others using Go's inlining rules. If there are multiple top-level 581// fields, the boolean will be false: This condition is an error in Go 582// and we skip all the fields. 583func dominantField(fields []fieldDescription) (fieldDescription, bool) { 584 // The fields are sorted in increasing index-length order, then by presence of tag. 585 // That means that the first field is the dominant one. We need only check 586 // for error cases: two fields at top level. 587 if len(fields) > 1 && 588 len(fields[0].inline) == len(fields[1].inline) { 589 return fieldDescription{}, false 590 } 591 return fields[0], true 592} 593 594func fieldByIndexErr(v reflect.Value, index []int) (result reflect.Value, err error) { 595 defer func() { 596 if recovered := recover(); recovered != nil { 597 switch r := recovered.(type) { 598 case string: 599 err = fmt.Errorf("%s", r) 600 case error: 601 err = r 602 } 603 } 604 }() 605 606 result = v.FieldByIndex(index) 607 return 608} 609 610func getInlineField(val reflect.Value, index []int) (reflect.Value, error) { 611 field, err := fieldByIndexErr(val, index) 612 if err == nil { 613 return field, nil 614 } 615 616 // if parent of this element doesn't exist, fix its parent 617 inlineParent := index[:len(index)-1] 618 var fParent reflect.Value 619 if fParent, err = fieldByIndexErr(val, inlineParent); err != nil { 620 fParent, err = getInlineField(val, inlineParent) 621 if err != nil { 622 return fParent, err 623 } 624 } 625 fParent.Set(reflect.New(fParent.Type().Elem())) 626 627 return fieldByIndexErr(val, index) 628} 629 630// DeepZero returns recursive zero object 631func deepZero(st reflect.Type) (result reflect.Value) { 632 result = reflect.Indirect(reflect.New(st)) 633 634 if result.Kind() == reflect.Struct { 635 for i := 0; i < result.NumField(); i++ { 636 if f := result.Field(i); f.Kind() == reflect.Ptr { 637 if f.CanInterface() { 638 if ft := reflect.TypeOf(f.Interface()); ft.Elem().Kind() == reflect.Struct { 639 result.Field(i).Set(recursivePointerTo(deepZero(ft.Elem()))) 640 } 641 } 642 } 643 } 644 } 645 646 return 647} 648 649// recursivePointerTo calls reflect.New(v.Type) but recursively for its fields inside 650func recursivePointerTo(v reflect.Value) reflect.Value { 651 v = reflect.Indirect(v) 652 result := reflect.New(v.Type()) 653 if v.Kind() == reflect.Struct { 654 for i := 0; i < v.NumField(); i++ { 655 if f := v.Field(i); f.Kind() == reflect.Ptr { 656 if f.Elem().Kind() == reflect.Struct { 657 result.Elem().Field(i).Set(recursivePointerTo(f)) 658 } 659 } 660 } 661 } 662 663 return result 664} 665