1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package bsoncodec
8
9import (
10	"errors"
11	"fmt"
12	"reflect"
13	"sort"
14	"strings"
15	"sync"
16	"time"
17
18	"go.mongodb.org/mongo-driver/bson/bsonoptions"
19	"go.mongodb.org/mongo-driver/bson/bsonrw"
20	"go.mongodb.org/mongo-driver/bson/bsontype"
21)
22
// DecodeError represents an error that occurs when unmarshalling BSON bytes into a native Go type.
//
// As the error propagates out of nested documents, each enclosing decoder appends the key it was
// decoding, so the keys are stored innermost-first. Keys and Error present them outermost-first.
type DecodeError struct {
	keys    []string // BSON keys, innermost first
	wrapped error    // the underlying decoding failure
}

// Unwrap returns the underlying error.
func (de *DecodeError) Unwrap() error {
	return de.wrapped
}

// Error implements the error interface. The message names the dotted key path (outermost key
// first) followed by the underlying error, e.g. "error decoding key a.b.c: <cause>".
func (de *DecodeError) Error() string {
	path := strings.Join(de.Keys(), ".")
	return fmt.Sprintf("error decoding key %s: %v", path, de.wrapped)
}

// Keys returns the BSON key path that caused an error as a slice of strings. The keys in the slice are in top-down
// order. For example, if the document being unmarshalled was {a: {b: {c: 1}}} and the value for c was supposed to be
// a string, the keys slice will be ["a", "b", "c"]. The returned slice is a fresh copy; the error's
// internal state is not modified.
func (de *DecodeError) Keys() []string {
	count := len(de.keys)
	path := make([]string, count)
	for i, key := range de.keys {
		path[count-1-i] = key
	}
	return path
}
53
54// Zeroer allows custom struct types to implement a report of zero
55// state. All struct types that don't implement Zeroer or where IsZero
56// returns false are considered to be not zero.
57type Zeroer interface {
58	IsZero() bool
59}
60
61// StructCodec is the Codec used for struct values.
62type StructCodec struct {
63	cache                            map[reflect.Type]*structDescription
64	l                                sync.RWMutex
65	parser                           StructTagParser
66	DecodeZeroStruct                 bool
67	DecodeDeepZeroInline             bool
68	EncodeOmitDefaultStruct          bool
69	AllowUnexportedFields            bool
70	OverwriteDuplicatedInlinedFields bool
71}
72
73var _ ValueEncoder = &StructCodec{}
74var _ ValueDecoder = &StructCodec{}
75
76// NewStructCodec returns a StructCodec that uses p for struct tag parsing.
77func NewStructCodec(p StructTagParser, opts ...*bsonoptions.StructCodecOptions) (*StructCodec, error) {
78	if p == nil {
79		return nil, errors.New("a StructTagParser must be provided to NewStructCodec")
80	}
81
82	structOpt := bsonoptions.MergeStructCodecOptions(opts...)
83
84	codec := &StructCodec{
85		cache:  make(map[reflect.Type]*structDescription),
86		parser: p,
87	}
88
89	if structOpt.DecodeZeroStruct != nil {
90		codec.DecodeZeroStruct = *structOpt.DecodeZeroStruct
91	}
92	if structOpt.DecodeDeepZeroInline != nil {
93		codec.DecodeDeepZeroInline = *structOpt.DecodeDeepZeroInline
94	}
95	if structOpt.EncodeOmitDefaultStruct != nil {
96		codec.EncodeOmitDefaultStruct = *structOpt.EncodeOmitDefaultStruct
97	}
98	if structOpt.OverwriteDuplicatedInlinedFields != nil {
99		codec.OverwriteDuplicatedInlinedFields = *structOpt.OverwriteDuplicatedInlinedFields
100	}
101	if structOpt.AllowUnexportedFields != nil {
102		codec.AllowUnexportedFields = *structOpt.AllowUnexportedFields
103	}
104
105	return codec, nil
106}
107
// EncodeValue handles encoding generic struct types.
//
// It returns a ValueEncoderError if val is not a valid struct. Fields are written in the order of
// the cached struct description (sd.fl). Fields promoted through an inline struct pointer that is
// nil are skipped. If the struct has an inline map, its entries are encoded after the regular
// fields, skipping keys that collide with a struct field's BSON key.
func (sc *StructCodec) EncodeValue(r EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
	if !val.IsValid() || val.Kind() != reflect.Struct {
		return ValueEncoderError{Name: "StructCodec.EncodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val}
	}

	sd, err := sc.describeStruct(r.Registry, val.Type())
	if err != nil {
		return err
	}

	dw, err := vw.WriteDocument()
	if err != nil {
		return err
	}
	var rv reflect.Value
	// desc is a copy of the cached description, so reassigning desc.encoder below does not
	// modify the cache.
	for _, desc := range sd.fl {
		if desc.inline == nil {
			rv = val.Field(desc.idx)
		} else {
			// fieldByIndexErr fails when a pointer on the inline index path is nil; such fields are
			// left out of the document without error.
			rv, err = fieldByIndexErr(val, desc.inline)
			if err != nil {
				continue
			}
		}

		// lookupElementEncoder returns the encoder and value to use for this field; errInvalidValue
		// is handled below rather than treated as fatal.
		desc.encoder, rv, err = defaultValueEncoders.lookupElementEncoder(r, desc.encoder, rv)

		if err != nil && err != errInvalidValue {
			return err
		}

		// An invalid value is omitted under omitempty and written as BSON null otherwise.
		if err == errInvalidValue {
			if desc.omitEmpty {
				continue
			}
			vw2, err := dw.WriteDocumentElement(desc.name)
			if err != nil {
				return err
			}
			err = vw2.WriteNull()
			if err != nil {
				return err
			}
			continue
		}

		if desc.encoder == nil {
			return ErrNoEncoder{Type: rv.Type()}
		}

		encoder := desc.encoder

		// Zero detection for omitempty: an encoder implementing CodecZeroer decides; otherwise a
		// nil interface is zero, and everything else goes through sc.isZero.
		// NOTE(review): rv.Interface() panics for values read from unexported fields (e.g. an
		// unexported embedded field admitted by AllowUnexportedFields); confirm whether that path
		// is reachable here.
		var isZero bool
		rvInterface := rv.Interface()
		if cz, ok := encoder.(CodecZeroer); ok {
			isZero = cz.IsTypeZero(rvInterface)
		} else if rv.Kind() == reflect.Interface {
			// sc.isZero will not treat an interface rv as an interface, so we need to check for the zero interface separately.
			isZero = rv.IsNil()
		} else {
			isZero = sc.isZero(rvInterface)
		}
		if desc.omitEmpty && isZero {
			continue
		}

		vw2, err := dw.WriteDocumentElement(desc.name)
		if err != nil {
			return err
		}

		// Each field is encoded with a fresh context carrying only the registry and the field's
		// own minsize option.
		ectx := EncodeContext{Registry: r.Registry, MinSize: desc.minSize}
		err = encoder.EncodeValue(ectx, vw2, rv)
		if err != nil {
			return err
		}
	}

	if sd.inlineMap >= 0 {
		rv := val.Field(sd.inlineMap)
		// Keys already produced by struct fields are not emitted again from the inline map.
		collisionFn := func(key string) bool {
			_, exists := sd.fm[key]
			return exists
		}

		// NOTE(review): this path returns without calling dw.WriteDocumentEnd, so mapEncodeValue
		// presumably closes the document; confirm.
		return defaultMapCodec.mapEncodeValue(r, dw, rv, collisionFn)
	}

	return dw.WriteDocumentEnd()
}
199
200func newDecodeError(key string, original error) error {
201	de, ok := original.(*DecodeError)
202	if !ok {
203		return &DecodeError{
204			keys:    []string{key},
205			wrapped: original,
206		}
207	}
208
209	de.keys = append(de.keys, key)
210	return de
211}
212
213// DecodeValue implements the Codec interface.
214// By default, map types in val will not be cleared. If a map has existing key/value pairs, it will be extended with the new ones from vr.
215// For slices, the decoder will set the length of the slice to zero and append all elements. The underlying array will not be cleared.
216func (sc *StructCodec) DecodeValue(r DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
217	if !val.CanSet() || val.Kind() != reflect.Struct {
218		return ValueDecoderError{Name: "StructCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val}
219	}
220
221	switch vrType := vr.Type(); vrType {
222	case bsontype.Type(0), bsontype.EmbeddedDocument:
223	case bsontype.Null:
224		if err := vr.ReadNull(); err != nil {
225			return err
226		}
227
228		val.Set(reflect.Zero(val.Type()))
229		return nil
230	case bsontype.Undefined:
231		if err := vr.ReadUndefined(); err != nil {
232			return err
233		}
234
235		val.Set(reflect.Zero(val.Type()))
236		return nil
237	default:
238		return fmt.Errorf("cannot decode %v into a %s", vrType, val.Type())
239	}
240
241	sd, err := sc.describeStruct(r.Registry, val.Type())
242	if err != nil {
243		return err
244	}
245
246	if sc.DecodeZeroStruct {
247		val.Set(reflect.Zero(val.Type()))
248	}
249	if sc.DecodeDeepZeroInline && sd.inline {
250		val.Set(deepZero(val.Type()))
251	}
252
253	var decoder ValueDecoder
254	var inlineMap reflect.Value
255	if sd.inlineMap >= 0 {
256		inlineMap = val.Field(sd.inlineMap)
257		decoder, err = r.LookupDecoder(inlineMap.Type().Elem())
258		if err != nil {
259			return err
260		}
261	}
262
263	dr, err := vr.ReadDocument()
264	if err != nil {
265		return err
266	}
267
268	for {
269		name, vr, err := dr.ReadElement()
270		if err == bsonrw.ErrEOD {
271			break
272		}
273		if err != nil {
274			return err
275		}
276
277		fd, exists := sd.fm[name]
278		if !exists {
279			// if the original name isn't found in the struct description, try again with the name in lowercase
280			// this could match if a BSON tag isn't specified because by default, describeStruct lowercases all field
281			// names
282			fd, exists = sd.fm[strings.ToLower(name)]
283		}
284
285		if !exists {
286			if sd.inlineMap < 0 {
287				// The encoding/json package requires a flag to return on error for non-existent fields.
288				// This functionality seems appropriate for the struct codec.
289				err = vr.Skip()
290				if err != nil {
291					return err
292				}
293				continue
294			}
295
296			if inlineMap.IsNil() {
297				inlineMap.Set(reflect.MakeMap(inlineMap.Type()))
298			}
299
300			elem := reflect.New(inlineMap.Type().Elem()).Elem()
301			r.Ancestor = inlineMap.Type()
302			err = decoder.DecodeValue(r, vr, elem)
303			if err != nil {
304				return err
305			}
306			inlineMap.SetMapIndex(reflect.ValueOf(name), elem)
307			continue
308		}
309
310		var field reflect.Value
311		if fd.inline == nil {
312			field = val.Field(fd.idx)
313		} else {
314			field, err = getInlineField(val, fd.inline)
315			if err != nil {
316				return err
317			}
318		}
319
320		if !field.CanSet() { // Being settable is a super set of being addressable.
321			innerErr := fmt.Errorf("field %v is not settable", field)
322			return newDecodeError(fd.name, innerErr)
323		}
324		if field.Kind() == reflect.Ptr && field.IsNil() {
325			field.Set(reflect.New(field.Type().Elem()))
326		}
327		field = field.Addr()
328
329		dctx := DecodeContext{Registry: r.Registry, Truncate: fd.truncate || r.Truncate}
330		if fd.decoder == nil {
331			return newDecodeError(fd.name, ErrNoDecoder{Type: field.Elem().Type()})
332		}
333
334		err = fd.decoder.DecodeValue(dctx, vr, field.Elem())
335		if err != nil {
336			return newDecodeError(fd.name, err)
337		}
338	}
339
340	return nil
341}
342
343func (sc *StructCodec) isZero(i interface{}) bool {
344	v := reflect.ValueOf(i)
345
346	// check the value validity
347	if !v.IsValid() {
348		return true
349	}
350
351	if z, ok := v.Interface().(Zeroer); ok && (v.Kind() != reflect.Ptr || !v.IsNil()) {
352		return z.IsZero()
353	}
354
355	switch v.Kind() {
356	case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
357		return v.Len() == 0
358	case reflect.Bool:
359		return !v.Bool()
360	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
361		return v.Int() == 0
362	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
363		return v.Uint() == 0
364	case reflect.Float32, reflect.Float64:
365		return v.Float() == 0
366	case reflect.Interface, reflect.Ptr:
367		return v.IsNil()
368	case reflect.Struct:
369		if sc.EncodeOmitDefaultStruct {
370			vt := v.Type()
371			if vt == tTime {
372				return v.Interface().(time.Time).IsZero()
373			}
374			for i := 0; i < v.NumField(); i++ {
375				if vt.Field(i).PkgPath != "" && !vt.Field(i).Anonymous {
376					continue // Private field
377				}
378				fld := v.Field(i)
379				if !sc.isZero(fld.Interface()) {
380					return false
381				}
382			}
383			return true
384		}
385	}
386
387	return false
388}
389
390type structDescription struct {
391	fm        map[string]fieldDescription
392	fl        []fieldDescription
393	inlineMap int
394	inline    bool
395}
396
397type fieldDescription struct {
398	name      string // BSON key name
399	fieldName string // struct field name
400	idx       int
401	omitEmpty bool
402	minSize   bool
403	truncate  bool
404	inline    []int
405	encoder   ValueEncoder
406	decoder   ValueDecoder
407}
408
409type byIndex []fieldDescription
410
411func (bi byIndex) Len() int { return len(bi) }
412
413func (bi byIndex) Swap(i, j int) { bi[i], bi[j] = bi[j], bi[i] }
414
415func (bi byIndex) Less(i, j int) bool {
416	// If a field is inlined, its index in the top level struct is stored at inline[0]
417	iIdx, jIdx := bi[i].idx, bi[j].idx
418	if len(bi[i].inline) > 0 {
419		iIdx = bi[i].inline[0]
420	}
421	if len(bi[j].inline) > 0 {
422		jIdx = bi[j].inline[0]
423	}
424	if iIdx != jIdx {
425		return iIdx < jIdx
426	}
427	for k, biik := range bi[i].inline {
428		if k >= len(bi[j].inline) {
429			return false
430		}
431		if biik != bi[j].inline[k] {
432			return biik < bi[j].inline[k]
433		}
434	}
435	return len(bi[i].inline) < len(bi[j].inline)
436}
437
// describeStruct returns the structDescription for struct type t, computing it on first use and
// caching it in sc.cache.
//
// Building a description:
//   - skips unexported fields (except unexported embedded fields when AllowUnexportedFields is
//     set) and fields whose tags request skipping;
//   - looks up an encoder and decoder for each field's type in r; a failed lookup leaves the codec
//     nil rather than failing here (EncodeValue/DecodeValue check for nil at use time);
//   - flattens inline struct and struct-pointer fields into this description, recording each
//     promoted field's index path, and records the index of at most one inline map, which must
//     have string keys;
//   - resolves fields that share a BSON key. Without OverwriteDuplicatedInlinedFields any
//     duplicate is an error. With it, the shallowest field wins, and a tie at that depth is still
//     an error.
//
// NOTE(review): the description is built outside the lock, so concurrent first calls for the same
// type may each build one (the last write to the cache wins). A struct that inlines itself,
// directly or through a pointer, would make the recursive describeStruct call below recurse
// without bound, since t is only cached after it is fully built.
func (sc *StructCodec) describeStruct(r *Registry, t reflect.Type) (*structDescription, error) {
	// We need to analyze the struct, including getting the tags, collecting
	// information about inlining, and create a map of the field name to the field.
	sc.l.RLock()
	ds, exists := sc.cache[t]
	sc.l.RUnlock()
	if exists {
		return ds, nil
	}

	numFields := t.NumField()
	sd := &structDescription{
		fm:        make(map[string]fieldDescription, numFields),
		fl:        make([]fieldDescription, 0, numFields),
		inlineMap: -1,
	}

	// fields collects every candidate field, including those promoted from inline structs, before
	// duplicate BSON keys are resolved.
	var fields []fieldDescription
	for i := 0; i < numFields; i++ {
		sf := t.Field(i)
		if sf.PkgPath != "" && (!sc.AllowUnexportedFields || !sf.Anonymous) {
			// The field is unexported and is not an embedded field admitted by
			// AllowUnexportedFields; ignore it.
			continue
		}

		sfType := sf.Type
		// Lookup failures are deliberately tolerated here; the nil codec is reported only if the
		// field is actually encoded or decoded.
		encoder, err := r.LookupEncoder(sfType)
		if err != nil {
			encoder = nil
		}
		decoder, err := r.LookupDecoder(sfType)
		if err != nil {
			decoder = nil
		}

		description := fieldDescription{
			fieldName: sf.Name,
			idx:       i,
			encoder:   encoder,
			decoder:   decoder,
		}

		stags, err := sc.parser.ParseStructTags(sf)
		if err != nil {
			return nil, err
		}
		if stags.Skip {
			continue
		}
		description.name = stags.Name
		description.omitEmpty = stags.OmitEmpty
		description.minSize = stags.MinSize
		description.truncate = stags.Truncate

		if stags.Inline {
			sd.inline = true
			switch sfType.Kind() {
			case reflect.Map:
				// Only one inline map is allowed; DecodeValue stores unmatched keys in it.
				if sd.inlineMap >= 0 {
					return nil, errors.New("(struct " + t.String() + ") multiple inline maps")
				}
				if sfType.Key() != tString {
					return nil, errors.New("(struct " + t.String() + ") inline map must have a string keys")
				}
				sd.inlineMap = description.idx
			case reflect.Ptr:
				// A struct pointer is inlined like the struct it points to.
				sfType = sfType.Elem()
				if sfType.Kind() != reflect.Struct {
					return nil, fmt.Errorf("(struct %s) inline fields must be a struct, a struct pointer, or a map", t.String())
				}
				fallthrough
			case reflect.Struct:
				inlinesf, err := sc.describeStruct(r, sfType)
				if err != nil {
					return nil, err
				}
				// Prefix each promoted field's index path with this field's index i so the field
				// can be reached from t via FieldByIndex (see fieldByIndexErr).
				for _, fd := range inlinesf.fl {
					if fd.inline == nil {
						fd.inline = []int{i, fd.idx}
					} else {
						fd.inline = append([]int{i}, fd.inline...)
					}
					fields = append(fields, fd)

				}
			default:
				return nil, fmt.Errorf("(struct %s) inline fields must be a struct, a struct pointer, or a map", t.String())
			}
			continue
		}
		fields = append(fields, description)
	}

	// Sort fieldDescriptions by name and use dominance rules to determine which should be added for each name
	sort.Slice(fields, func(i, j int) bool {
		x := fields
		// sort field by name, breaking ties with depth, then
		// breaking ties with index sequence.
		if x[i].name != x[j].name {
			return x[i].name < x[j].name
		}
		if len(x[i].inline) != len(x[j].inline) {
			return len(x[i].inline) < len(x[j].inline)
		}
		return byIndex(x).Less(i, j)
	})

	for advance, i := 0, 0; i < len(fields); i += advance {
		// One iteration per name.
		// Find the sequence of fields with the name of this first field.
		fi := fields[i]
		name := fi.name
		for advance = 1; i+advance < len(fields); advance++ {
			fj := fields[i+advance]
			if fj.name != name {
				break
			}
		}
		if advance == 1 { // Only one field with this name
			sd.fl = append(sd.fl, fi)
			sd.fm[name] = fi
			continue
		}
		// Duplicate key: an error unless overwriting is enabled and a unique shallowest field exists.
		dominant, ok := dominantField(fields[i : i+advance])
		if !ok || !sc.OverwriteDuplicatedInlinedFields {
			return nil, fmt.Errorf("struct %s has duplicated key %s", t.String(), name)
		}
		sd.fl = append(sd.fl, dominant)
		sd.fm[name] = dominant
	}

	// Restore struct order for encoding; fields were sorted by name above.
	sort.Sort(byIndex(sd.fl))

	sc.l.Lock()
	sc.cache[t] = sd
	sc.l.Unlock()

	return sd, nil
}
577
578// dominantField looks through the fields, all of which are known to
579// have the same name, to find the single field that dominates the
580// others using Go's inlining rules. If there are multiple top-level
581// fields, the boolean will be false: This condition is an error in Go
582// and we skip all the fields.
583func dominantField(fields []fieldDescription) (fieldDescription, bool) {
584	// The fields are sorted in increasing index-length order, then by presence of tag.
585	// That means that the first field is the dominant one. We need only check
586	// for error cases: two fields at top level.
587	if len(fields) > 1 &&
588		len(fields[0].inline) == len(fields[1].inline) {
589		return fieldDescription{}, false
590	}
591	return fields[0], true
592}
593
// fieldByIndexErr behaves like v.FieldByIndex(index), but it turns the panic raised for an
// unusable path (for example, one passing through a nil embedded struct pointer) into an error.
// On error the returned Value is the zero Value.
func fieldByIndexErr(v reflect.Value, index []int) (result reflect.Value, err error) {
	defer func() {
		recovered := recover()
		if recovered == nil {
			return
		}
		switch r := recovered.(type) {
		case string:
			err = fmt.Errorf("%s", r)
		case error:
			err = r
		default:
			// Any other panic value must still be reported; swallowing it would hand the caller
			// an invalid Value together with a nil error.
			err = fmt.Errorf("%v", r)
		}
	}()

	result = v.FieldByIndex(index)
	return
}

// getInlineField returns the field of val at index. Any nil embedded struct pointers along the
// path are allocated first, so val must be settable along that path.
//
// When the direct lookup fails, the immediate parent of the target must be a nil pointer (all
// shallower pointers resolved). The parent is fetched, recursively repairing its own ancestors if
// necessary, then set to a fresh zero value, and the lookup is retried.
func getInlineField(val reflect.Value, index []int) (reflect.Value, error) {
	if field, err := fieldByIndexErr(val, index); err == nil {
		return field, nil
	}

	parentIndex := index[:len(index)-1]
	parent, err := fieldByIndexErr(val, parentIndex)
	if err != nil {
		if parent, err = getInlineField(val, parentIndex); err != nil {
			return parent, err
		}
	}
	parent.Set(reflect.New(parent.Type().Elem()))

	return fieldByIndexErr(val, index)
}
629
// deepZero returns a zero value of type st in which every settable (exported) pointer-to-struct
// field is allocated, recursively, rather than left nil. Non-struct types yield their plain zero
// value.
//
// A pointer field whose element type is already being expanded higher up the recursion is left
// nil. This happens when the type graph is cyclic, e.g. `type Node struct{ Next *Node }`, and
// expanding such a field would otherwise recurse forever.
func deepZero(st reflect.Type) (result reflect.Value) {
	return deepZeroGuarded(st, make(map[reflect.Type]bool))
}

// deepZeroGuarded implements deepZero; expanding holds the struct types on the current recursion
// path.
func deepZeroGuarded(st reflect.Type, expanding map[reflect.Type]bool) reflect.Value {
	result := reflect.New(st).Elem()
	if result.Kind() != reflect.Struct {
		return result
	}

	expanding[st] = true
	defer delete(expanding, st)

	for i := 0; i < result.NumField(); i++ {
		field := result.Field(i)
		// Unexported fields cannot be set through reflection, so they stay nil.
		if field.Kind() != reflect.Ptr || !field.CanInterface() {
			continue
		}
		elem := field.Type().Elem()
		if elem.Kind() != reflect.Struct || expanding[elem] {
			continue
		}
		field.Set(recursivePointerTo(deepZeroGuarded(elem, expanding)))
	}

	return result
}

// recursivePointerTo returns a pointer to a fresh zero value of v's type (v is dereferenced first
// if it is a pointer). For every pointer-to-struct field that is non-nil in v, the corresponding
// field of the new value is likewise set to a fresh pointer built the same way. Only the pattern
// of allocated struct pointers is mirrored; other field contents of v are not copied.
// NOTE(review): setting an unexported field panics, and cyclic pointer graphs are not guarded
// against; deepZero only passes acyclic values whose non-nil pointer fields are exported.
func recursivePointerTo(v reflect.Value) reflect.Value {
	v = reflect.Indirect(v)
	result := reflect.New(v.Type())
	if v.Kind() == reflect.Struct {
		for i := 0; i < v.NumField(); i++ {
			if f := v.Field(i); f.Kind() == reflect.Ptr {
				if f.Elem().Kind() == reflect.Struct {
					result.Elem().Field(i).Set(recursivePointerTo(f))
				}
			}
		}
	}

	return result
}
665