1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package bsoncodec
8
9import (
10	"errors"
11	"fmt"
12	"reflect"
13	"sort"
14	"strings"
15	"sync"
16	"time"
17
18	"go.mongodb.org/mongo-driver/bson/bsonoptions"
19	"go.mongodb.org/mongo-driver/bson/bsonrw"
20	"go.mongodb.org/mongo-driver/bson/bsontype"
21)
22
// DecodeError represents an error that occurs when unmarshalling BSON bytes into a native Go type.
// It records the path of BSON keys leading to the failing value together with the underlying error.
type DecodeError struct {
	// keys holds the key path innermost-first: each enclosing level appends its own key as the
	// error propagates upward. Use Keys to obtain the path in top-down order.
	keys []string
	// wrapped is the underlying error that caused decoding to fail.
	wrapped error
}

// Unwrap returns the underlying error.
func (de *DecodeError) Unwrap() error {
	return de.wrapped
}

// Error implements the error interface. The message contains the dot-separated, top-down key
// path followed by the underlying error.
func (de *DecodeError) Error() string {
	// de.keys is stored innermost-first, so the top-down view from Keys is what gets joined.
	path := strings.Join(de.Keys(), ".")
	return fmt.Sprintf("error decoding key %s: %v", path, de.wrapped)
}

// Keys returns the BSON key path that caused an error as a slice of strings. The keys in the slice are in top-down
// order. For example, if the document being unmarshalled was {a: {b: {c: 1}}} and the value for c was supposed to be
// a string, the keys slice will be ["a", "b", "c"]. The returned slice is a fresh copy.
func (de *DecodeError) Keys() []string {
	total := len(de.keys)
	topDown := make([]string, total)
	for pos, key := range de.keys {
		// Mirror each position so the last-recorded (outermost) key comes first.
		topDown[total-1-pos] = key
	}
	return topDown
}
53
// Zeroer allows custom types to report whether they are in their zero
// state. When a value implements Zeroer (and is not a nil pointer), the
// struct codec's zero check uses the result of IsZero instead of its
// built-in, kind-based rules.
type Zeroer interface {
	// IsZero reports whether the receiver should be treated as a zero value.
	IsZero() bool
}
60
// StructCodec is the Codec used for struct values.
type StructCodec struct {
	// cache memoizes the description computed for each struct type by describeStruct.
	cache map[reflect.Type]*structDescription
	// l guards cache.
	l sync.RWMutex
	// parser extracts the BSON options from each struct field's tags.
	parser StructTagParser
	// DecodeZeroStruct makes DecodeValue reset the destination struct to its zero value before decoding.
	DecodeZeroStruct bool
	// DecodeDeepZeroInline makes DecodeValue replace the destination with the result of deepZero
	// before decoding when the struct has at least one field tagged inline.
	DecodeDeepZeroInline bool
	// EncodeOmitDefaultStruct makes the zero check treat a struct as zero when all of the fields
	// it inspects are zero; when false, structs that do not implement Zeroer are never zero.
	EncodeOmitDefaultStruct bool
	// AllowUnexportedFields makes describeStruct keep unexported anonymous (embedded) fields
	// instead of ignoring them.
	AllowUnexportedFields bool
	// OverwriteDuplicatedInlinedFields makes describeStruct resolve duplicated keys by picking the
	// dominant field; when false, any duplicated key is an error.
	OverwriteDuplicatedInlinedFields bool
}

// Compile-time checks that *StructCodec satisfies both codec interfaces.
var _ ValueEncoder = &StructCodec{}
var _ ValueDecoder = &StructCodec{}
75
76// NewStructCodec returns a StructCodec that uses p for struct tag parsing.
77func NewStructCodec(p StructTagParser, opts ...*bsonoptions.StructCodecOptions) (*StructCodec, error) {
78	if p == nil {
79		return nil, errors.New("a StructTagParser must be provided to NewStructCodec")
80	}
81
82	structOpt := bsonoptions.MergeStructCodecOptions(opts...)
83
84	codec := &StructCodec{
85		cache:  make(map[reflect.Type]*structDescription),
86		parser: p,
87	}
88
89	if structOpt.DecodeZeroStruct != nil {
90		codec.DecodeZeroStruct = *structOpt.DecodeZeroStruct
91	}
92	if structOpt.DecodeDeepZeroInline != nil {
93		codec.DecodeDeepZeroInline = *structOpt.DecodeDeepZeroInline
94	}
95	if structOpt.EncodeOmitDefaultStruct != nil {
96		codec.EncodeOmitDefaultStruct = *structOpt.EncodeOmitDefaultStruct
97	}
98	if structOpt.OverwriteDuplicatedInlinedFields != nil {
99		codec.OverwriteDuplicatedInlinedFields = *structOpt.OverwriteDuplicatedInlinedFields
100	}
101	if structOpt.AllowUnexportedFields != nil {
102		codec.AllowUnexportedFields = *structOpt.AllowUnexportedFields
103	}
104
105	return codec, nil
106}
107
108// EncodeValue handles encoding generic struct types.
109func (sc *StructCodec) EncodeValue(r EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
110	if !val.IsValid() || val.Kind() != reflect.Struct {
111		return ValueEncoderError{Name: "StructCodec.EncodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val}
112	}
113
114	sd, err := sc.describeStruct(r.Registry, val.Type())
115	if err != nil {
116		return err
117	}
118
119	dw, err := vw.WriteDocument()
120	if err != nil {
121		return err
122	}
123	var rv reflect.Value
124	for _, desc := range sd.fl {
125		if desc.inline == nil {
126			rv = val.Field(desc.idx)
127		} else {
128			rv, err = fieldByIndexErr(val, desc.inline)
129			if err != nil {
130				continue
131			}
132		}
133
134		desc.encoder, rv, err = defaultValueEncoders.lookupElementEncoder(r, desc.encoder, rv)
135
136		if err != nil && err != errInvalidValue {
137			return err
138		}
139
140		if err == errInvalidValue {
141			if desc.omitEmpty {
142				continue
143			}
144			vw2, err := dw.WriteDocumentElement(desc.name)
145			if err != nil {
146				return err
147			}
148			err = vw2.WriteNull()
149			if err != nil {
150				return err
151			}
152			continue
153		}
154
155		if desc.encoder == nil {
156			return ErrNoEncoder{Type: rv.Type()}
157		}
158
159		encoder := desc.encoder
160
161		var isZero bool
162		rvInterface := rv.Interface()
163		if cz, ok := encoder.(CodecZeroer); ok {
164			isZero = cz.IsTypeZero(rvInterface)
165		} else if rv.Kind() == reflect.Interface {
166			// sc.isZero will not treat an interface rv as an interface, so we need to check for the zero interface separately.
167			isZero = rv.IsNil()
168		} else {
169			isZero = sc.isZero(rvInterface)
170		}
171		if desc.omitEmpty && isZero {
172			continue
173		}
174
175		vw2, err := dw.WriteDocumentElement(desc.name)
176		if err != nil {
177			return err
178		}
179
180		ectx := EncodeContext{Registry: r.Registry, MinSize: desc.minSize}
181		err = encoder.EncodeValue(ectx, vw2, rv)
182		if err != nil {
183			return err
184		}
185	}
186
187	if sd.inlineMap >= 0 {
188		rv := val.Field(sd.inlineMap)
189		collisionFn := func(key string) bool {
190			_, exists := sd.fm[key]
191			return exists
192		}
193
194		return defaultMapCodec.mapEncodeValue(r, dw, rv, collisionFn)
195	}
196
197	return dw.WriteDocumentEnd()
198}
199
200func newDecodeError(key string, original error) error {
201	de, ok := original.(*DecodeError)
202	if !ok {
203		return &DecodeError{
204			keys:    []string{key},
205			wrapped: original,
206		}
207	}
208
209	de.keys = append(de.keys, key)
210	return de
211}
212
213// DecodeValue implements the Codec interface.
214// By default, map types in val will not be cleared. If a map has existing key/value pairs, it will be extended with the new ones from vr.
215// For slices, the decoder will set the length of the slice to zero and append all elements. The underlying array will not be cleared.
216func (sc *StructCodec) DecodeValue(r DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
217	if !val.CanSet() || val.Kind() != reflect.Struct {
218		return ValueDecoderError{Name: "StructCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val}
219	}
220
221	switch vrType := vr.Type(); vrType {
222	case bsontype.Type(0), bsontype.EmbeddedDocument:
223	case bsontype.Null:
224		if err := vr.ReadNull(); err != nil {
225			return err
226		}
227
228		val.Set(reflect.Zero(val.Type()))
229		return nil
230	case bsontype.Undefined:
231		if err := vr.ReadUndefined(); err != nil {
232			return err
233		}
234
235		val.Set(reflect.Zero(val.Type()))
236		return nil
237	default:
238		return fmt.Errorf("cannot decode %v into a %s", vrType, val.Type())
239	}
240
241	sd, err := sc.describeStruct(r.Registry, val.Type())
242	if err != nil {
243		return err
244	}
245
246	if sc.DecodeZeroStruct {
247		val.Set(reflect.Zero(val.Type()))
248	}
249	if sc.DecodeDeepZeroInline && sd.inline {
250		val.Set(deepZero(val.Type()))
251	}
252
253	var decoder ValueDecoder
254	var inlineMap reflect.Value
255	if sd.inlineMap >= 0 {
256		inlineMap = val.Field(sd.inlineMap)
257		decoder, err = r.LookupDecoder(inlineMap.Type().Elem())
258		if err != nil {
259			return err
260		}
261	}
262
263	dr, err := vr.ReadDocument()
264	if err != nil {
265		return err
266	}
267
268	for {
269		name, vr, err := dr.ReadElement()
270		if err == bsonrw.ErrEOD {
271			break
272		}
273		if err != nil {
274			return err
275		}
276
277		fd, exists := sd.fm[name]
278		if !exists {
279			// if the original name isn't found in the struct description, try again with the name in lowercase
280			// this could match if a BSON tag isn't specified because by default, describeStruct lowercases all field
281			// names
282			fd, exists = sd.fm[strings.ToLower(name)]
283		}
284
285		if !exists {
286			if sd.inlineMap < 0 {
287				// The encoding/json package requires a flag to return on error for non-existent fields.
288				// This functionality seems appropriate for the struct codec.
289				err = vr.Skip()
290				if err != nil {
291					return err
292				}
293				continue
294			}
295
296			if inlineMap.IsNil() {
297				inlineMap.Set(reflect.MakeMap(inlineMap.Type()))
298			}
299
300			elem := reflect.New(inlineMap.Type().Elem()).Elem()
301			r.Ancestor = inlineMap.Type()
302			err = decoder.DecodeValue(r, vr, elem)
303			if err != nil {
304				return err
305			}
306			inlineMap.SetMapIndex(reflect.ValueOf(name), elem)
307			continue
308		}
309
310		var field reflect.Value
311		if fd.inline == nil {
312			field = val.Field(fd.idx)
313		} else {
314			field, err = getInlineField(val, fd.inline)
315			if err != nil {
316				return err
317			}
318		}
319
320		if !field.CanSet() { // Being settable is a super set of being addressable.
321			innerErr := fmt.Errorf("field %v is not settable", field)
322			return newDecodeError(fd.name, innerErr)
323		}
324		if field.Kind() == reflect.Ptr && field.IsNil() {
325			field.Set(reflect.New(field.Type().Elem()))
326		}
327		field = field.Addr()
328
329		dctx := DecodeContext{Registry: r.Registry, Truncate: fd.truncate || r.Truncate}
330		if fd.decoder == nil {
331			return newDecodeError(fd.name, ErrNoDecoder{Type: field.Elem().Type()})
332		}
333
334		if decoder, ok := fd.decoder.(ValueDecoder); ok {
335			err = decoder.DecodeValue(dctx, vr, field.Elem())
336			if err != nil {
337				return newDecodeError(fd.name, err)
338			}
339			continue
340		}
341		err = fd.decoder.DecodeValue(dctx, vr, field)
342		if err != nil {
343			return newDecodeError(fd.name, err)
344		}
345	}
346
347	return nil
348}
349
350func (sc *StructCodec) isZero(i interface{}) bool {
351	v := reflect.ValueOf(i)
352
353	// check the value validity
354	if !v.IsValid() {
355		return true
356	}
357
358	if z, ok := v.Interface().(Zeroer); ok && (v.Kind() != reflect.Ptr || !v.IsNil()) {
359		return z.IsZero()
360	}
361
362	switch v.Kind() {
363	case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
364		return v.Len() == 0
365	case reflect.Bool:
366		return !v.Bool()
367	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
368		return v.Int() == 0
369	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
370		return v.Uint() == 0
371	case reflect.Float32, reflect.Float64:
372		return v.Float() == 0
373	case reflect.Interface, reflect.Ptr:
374		return v.IsNil()
375	case reflect.Struct:
376		if sc.EncodeOmitDefaultStruct {
377			vt := v.Type()
378			if vt == tTime {
379				return v.Interface().(time.Time).IsZero()
380			}
381			for i := 0; i < v.NumField(); i++ {
382				if vt.Field(i).PkgPath != "" && !vt.Field(i).Anonymous {
383					continue // Private field
384				}
385				fld := v.Field(i)
386				if !sc.isZero(fld.Interface()) {
387					return false
388				}
389			}
390			return true
391		}
392	}
393
394	return false
395}
396
// structDescription is the cached result of analyzing a struct type with
// describeStruct.
type structDescription struct {
	// fm maps each BSON key name to the description of the field that owns it.
	fm map[string]fieldDescription
	// fl lists the same field descriptions, sorted with byIndex.
	fl []fieldDescription
	// inlineMap is the index of the struct's inline map field, or -1 when there is none.
	inlineMap int
	// inline is true when at least one field of the struct is tagged inline.
	inline bool
}
403
// fieldDescription holds what the struct codec needs to know to encode or
// decode a single struct field.
type fieldDescription struct {
	name      string // BSON key name
	fieldName string // struct field name
	idx       int    // index of the field within the struct that declares it
	omitEmpty bool   // from the struct tags: omit the field when it is zero
	minSize   bool   // from the struct tags: passed to the encoder as EncodeContext.MinSize
	truncate  bool   // from the struct tags: passed to the decoder as DecodeContext.Truncate
	// inline is nil for fields declared directly on the struct. For fields
	// promoted from an inlined struct it is the index path from the top-level
	// struct down to the field.
	inline  []int
	encoder ValueEncoder // encoder looked up for the field's type; nil if the lookup failed
	decoder ValueDecoder // decoder looked up for the field's type; nil if the lookup failed
}
415
416type byIndex []fieldDescription
417
418func (bi byIndex) Len() int { return len(bi) }
419
420func (bi byIndex) Swap(i, j int) { bi[i], bi[j] = bi[j], bi[i] }
421
422func (bi byIndex) Less(i, j int) bool {
423	// If a field is inlined, its index in the top level struct is stored at inline[0]
424	iIdx, jIdx := bi[i].idx, bi[j].idx
425	if len(bi[i].inline) > 0 {
426		iIdx = bi[i].inline[0]
427	}
428	if len(bi[j].inline) > 0 {
429		jIdx = bi[j].inline[0]
430	}
431	if iIdx != jIdx {
432		return iIdx < jIdx
433	}
434	for k, biik := range bi[i].inline {
435		if k >= len(bi[j].inline) {
436			return false
437		}
438		if biik != bi[j].inline[k] {
439			return biik < bi[j].inline[k]
440		}
441	}
442	return len(bi[i].inline) < len(bi[j].inline)
443}
444
445func (sc *StructCodec) describeStruct(r *Registry, t reflect.Type) (*structDescription, error) {
446	// We need to analyze the struct, including getting the tags, collecting
447	// information about inlining, and create a map of the field name to the field.
448	sc.l.RLock()
449	ds, exists := sc.cache[t]
450	sc.l.RUnlock()
451	if exists {
452		return ds, nil
453	}
454
455	numFields := t.NumField()
456	sd := &structDescription{
457		fm:        make(map[string]fieldDescription, numFields),
458		fl:        make([]fieldDescription, 0, numFields),
459		inlineMap: -1,
460	}
461
462	var fields []fieldDescription
463	for i := 0; i < numFields; i++ {
464		sf := t.Field(i)
465		if sf.PkgPath != "" && (!sc.AllowUnexportedFields || !sf.Anonymous) {
466			// field is private or unexported fields aren't allowed, ignore
467			continue
468		}
469
470		sfType := sf.Type
471		encoder, err := r.LookupEncoder(sfType)
472		if err != nil {
473			encoder = nil
474		}
475		decoder, err := r.LookupDecoder(sfType)
476		if err != nil {
477			decoder = nil
478		}
479
480		description := fieldDescription{
481			fieldName: sf.Name,
482			idx:       i,
483			encoder:   encoder,
484			decoder:   decoder,
485		}
486
487		stags, err := sc.parser.ParseStructTags(sf)
488		if err != nil {
489			return nil, err
490		}
491		if stags.Skip {
492			continue
493		}
494		description.name = stags.Name
495		description.omitEmpty = stags.OmitEmpty
496		description.minSize = stags.MinSize
497		description.truncate = stags.Truncate
498
499		if stags.Inline {
500			sd.inline = true
501			switch sfType.Kind() {
502			case reflect.Map:
503				if sd.inlineMap >= 0 {
504					return nil, errors.New("(struct " + t.String() + ") multiple inline maps")
505				}
506				if sfType.Key() != tString {
507					return nil, errors.New("(struct " + t.String() + ") inline map must have a string keys")
508				}
509				sd.inlineMap = description.idx
510			case reflect.Ptr:
511				sfType = sfType.Elem()
512				if sfType.Kind() != reflect.Struct {
513					return nil, fmt.Errorf("(struct %s) inline fields must be a struct, a struct pointer, or a map", t.String())
514				}
515				fallthrough
516			case reflect.Struct:
517				inlinesf, err := sc.describeStruct(r, sfType)
518				if err != nil {
519					return nil, err
520				}
521				for _, fd := range inlinesf.fl {
522					if fd.inline == nil {
523						fd.inline = []int{i, fd.idx}
524					} else {
525						fd.inline = append([]int{i}, fd.inline...)
526					}
527					fields = append(fields, fd)
528
529				}
530			default:
531				return nil, fmt.Errorf("(struct %s) inline fields must be a struct, a struct pointer, or a map", t.String())
532			}
533			continue
534		}
535		fields = append(fields, description)
536	}
537
538	// Sort fieldDescriptions by name and use dominance rules to determine which should be added for each name
539	sort.Slice(fields, func(i, j int) bool {
540		x := fields
541		// sort field by name, breaking ties with depth, then
542		// breaking ties with index sequence.
543		if x[i].name != x[j].name {
544			return x[i].name < x[j].name
545		}
546		if len(x[i].inline) != len(x[j].inline) {
547			return len(x[i].inline) < len(x[j].inline)
548		}
549		return byIndex(x).Less(i, j)
550	})
551
552	for advance, i := 0, 0; i < len(fields); i += advance {
553		// One iteration per name.
554		// Find the sequence of fields with the name of this first field.
555		fi := fields[i]
556		name := fi.name
557		for advance = 1; i+advance < len(fields); advance++ {
558			fj := fields[i+advance]
559			if fj.name != name {
560				break
561			}
562		}
563		if advance == 1 { // Only one field with this name
564			sd.fl = append(sd.fl, fi)
565			sd.fm[name] = fi
566			continue
567		}
568		dominant, ok := dominantField(fields[i : i+advance])
569		if !ok || !sc.OverwriteDuplicatedInlinedFields {
570			return nil, fmt.Errorf("struct %s) duplicated key %s", t.String(), name)
571		}
572		sd.fl = append(sd.fl, dominant)
573		sd.fm[name] = dominant
574	}
575
576	sort.Sort(byIndex(sd.fl))
577
578	sc.l.Lock()
579	sc.cache[t] = sd
580	sc.l.Unlock()
581
582	return sd, nil
583}
584
585// dominantField looks through the fields, all of which are known to
586// have the same name, to find the single field that dominates the
587// others using Go's inlining rules. If there are multiple top-level
588// fields, the boolean will be false: This condition is an error in Go
589// and we skip all the fields.
590func dominantField(fields []fieldDescription) (fieldDescription, bool) {
591	// The fields are sorted in increasing index-length order, then by presence of tag.
592	// That means that the first field is the dominant one. We need only check
593	// for error cases: two fields at top level.
594	if len(fields) > 1 &&
595		len(fields[0].inline) == len(fields[1].inline) {
596		return fieldDescription{}, false
597	}
598	return fields[0], true
599}
600
// fieldByIndexErr returns the nested field of v addressed by index, like
// reflect.Value.FieldByIndex, but converts a panic raised by that call into
// an error instead of propagating it. On failure the returned Value is the
// zero (invalid) Value.
func fieldByIndexErr(v reflect.Value, index []int) (result reflect.Value, err error) {
	defer func() {
		recovered := recover()
		if recovered == nil {
			return
		}
		switch r := recovered.(type) {
		case string:
			err = fmt.Errorf("%s", r)
		case error:
			err = r
		default:
			// Never report success after a panic, whatever its payload was.
			err = fmt.Errorf("%v", r)
		}
	}()

	result = v.FieldByIndex(index)
	return result, nil
}
616
617func getInlineField(val reflect.Value, index []int) (reflect.Value, error) {
618	field, err := fieldByIndexErr(val, index)
619	if err == nil {
620		return field, nil
621	}
622
623	// if parent of this element doesn't exist, fix its parent
624	inlineParent := index[:len(index)-1]
625	var fParent reflect.Value
626	if fParent, err = fieldByIndexErr(val, inlineParent); err != nil {
627		fParent, err = getInlineField(val, inlineParent)
628		if err != nil {
629			return fParent, err
630		}
631	}
632	fParent.Set(reflect.New(fParent.Type().Elem()))
633
634	return fieldByIndexErr(val, index)
635}
636
637// DeepZero returns recursive zero object
638func deepZero(st reflect.Type) (result reflect.Value) {
639	result = reflect.Indirect(reflect.New(st))
640
641	if result.Kind() == reflect.Struct {
642		for i := 0; i < result.NumField(); i++ {
643			if f := result.Field(i); f.Kind() == reflect.Ptr {
644				if f.CanInterface() {
645					if ft := reflect.TypeOf(f.Interface()); ft.Elem().Kind() == reflect.Struct {
646						result.Field(i).Set(recursivePointerTo(deepZero(ft.Elem())))
647					}
648				}
649			}
650		}
651	}
652
653	return
654}
655
// recursivePointerTo returns a pointer to a new zero value of v's type (after
// dereferencing v if it is a pointer). For structs, every field of v that is
// a non-nil pointer to a struct gets a correspondingly fresh allocation in
// the result, recursively; the pointed-to data of v is not copied.
func recursivePointerTo(v reflect.Value) reflect.Value {
	source := reflect.Indirect(v)
	fresh := reflect.New(source.Type())
	if source.Kind() != reflect.Struct {
		return fresh
	}

	target := fresh.Elem()
	for i, n := 0, source.NumField(); i < n; i++ {
		field := source.Field(i)
		// Elem of a nil pointer is the invalid Value, whose Kind is not Struct,
		// so nil pointers are skipped here as well.
		if field.Kind() != reflect.Ptr || field.Elem().Kind() != reflect.Struct {
			continue
		}
		target.Field(i).Set(recursivePointerTo(field))
	}

	return fresh
}
672