1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package bsoncodec
8
9import (
10	"errors"
11	"fmt"
12	"reflect"
13	"strings"
14	"sync"
15	"time"
16
17	"go.mongodb.org/mongo-driver/bson/bsonoptions"
18	"go.mongodb.org/mongo-driver/bson/bsonrw"
19	"go.mongodb.org/mongo-driver/bson/bsontype"
20)
21
22var defaultStructCodec = &StructCodec{
23	cache:  make(map[reflect.Type]*structDescription),
24	parser: DefaultStructTagParser,
25}
26
// Zeroer allows custom types to report whether they are in their zero state.
//
// StructCodec consults IsZero when deciding whether a field tagged omitempty
// should be omitted: a value whose IsZero returns true is treated as empty.
// IsZero is never called on a nil pointer, even if its type implements Zeroer.
// If the field's encoder implements CodecZeroer, that check takes precedence.
// Values that do not implement Zeroer fall back to a kind-based check, under
// which a struct is considered non-empty unless the codec's
// EncodeOmitDefaultStruct option is set.
type Zeroer interface {
	IsZero() bool
}
33
34// StructCodec is the Codec used for struct values.
35type StructCodec struct {
36	cache                   map[reflect.Type]*structDescription
37	l                       sync.RWMutex
38	parser                  StructTagParser
39	DecodeZeroStruct        bool
40	DecodeDeepZeroInline    bool
41	EncodeOmitDefaultStruct bool
42	AllowUnexportedFields   bool
43}
44
45var _ ValueEncoder = &StructCodec{}
46var _ ValueDecoder = &StructCodec{}
47
48// NewStructCodec returns a StructCodec that uses p for struct tag parsing.
49func NewStructCodec(p StructTagParser, opts ...*bsonoptions.StructCodecOptions) (*StructCodec, error) {
50	if p == nil {
51		return nil, errors.New("a StructTagParser must be provided to NewStructCodec")
52	}
53
54	structOpt := bsonoptions.MergeStructCodecOptions(opts...)
55
56	codec := &StructCodec{
57		cache:  make(map[reflect.Type]*structDescription),
58		parser: p,
59	}
60
61	if structOpt.DecodeZeroStruct != nil {
62		codec.DecodeZeroStruct = *structOpt.DecodeZeroStruct
63	}
64	if structOpt.DecodeDeepZeroInline != nil {
65		codec.DecodeDeepZeroInline = *structOpt.DecodeDeepZeroInline
66	}
67	if structOpt.EncodeOmitDefaultStruct != nil {
68		codec.EncodeOmitDefaultStruct = *structOpt.EncodeOmitDefaultStruct
69	}
70	if structOpt.AllowUnexportedFields != nil {
71		codec.AllowUnexportedFields = *structOpt.AllowUnexportedFields
72	}
73
74	return codec, nil
75}
76
77// EncodeValue handles encoding generic struct types.
78func (sc *StructCodec) EncodeValue(r EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
79	if !val.IsValid() || val.Kind() != reflect.Struct {
80		return ValueEncoderError{Name: "StructCodec.EncodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val}
81	}
82
83	sd, err := sc.describeStruct(r.Registry, val.Type())
84	if err != nil {
85		return err
86	}
87
88	dw, err := vw.WriteDocument()
89	if err != nil {
90		return err
91	}
92	var rv reflect.Value
93	for _, desc := range sd.fl {
94		if desc.inline == nil {
95			rv = val.Field(desc.idx)
96		} else {
97			rv, err = fieldByIndexErr(val, desc.inline)
98			if err != nil {
99				continue
100			}
101		}
102
103		desc.encoder, rv, err = defaultValueEncoders.lookupElementEncoder(r, desc.encoder, rv)
104
105		if err != nil && err != errInvalidValue {
106			return err
107		}
108
109		if err == errInvalidValue {
110			if desc.omitEmpty {
111				continue
112			}
113			vw2, err := dw.WriteDocumentElement(desc.name)
114			if err != nil {
115				return err
116			}
117			err = vw2.WriteNull()
118			if err != nil {
119				return err
120			}
121			continue
122		}
123
124		if desc.encoder == nil {
125			return ErrNoEncoder{Type: rv.Type()}
126		}
127
128		encoder := desc.encoder
129
130		var isZero bool
131		rvInterface := rv.Interface()
132		if cz, ok := encoder.(CodecZeroer); ok {
133			isZero = cz.IsTypeZero(rvInterface)
134		} else if rv.Kind() == reflect.Interface {
135			// sc.isZero will not treat an interface rv as an interface, so we need to check for the zero interface separately.
136			isZero = rv.IsNil()
137		} else {
138			isZero = sc.isZero(rvInterface)
139		}
140		if desc.omitEmpty && isZero {
141			continue
142		}
143
144		vw2, err := dw.WriteDocumentElement(desc.name)
145		if err != nil {
146			return err
147		}
148
149		ectx := EncodeContext{Registry: r.Registry, MinSize: desc.minSize}
150		err = encoder.EncodeValue(ectx, vw2, rv)
151		if err != nil {
152			return err
153		}
154	}
155
156	if sd.inlineMap >= 0 {
157		rv := val.Field(sd.inlineMap)
158		collisionFn := func(key string) bool {
159			_, exists := sd.fm[key]
160			return exists
161		}
162
163		return defaultMapCodec.mapEncodeValue(r, dw, rv, collisionFn)
164	}
165
166	return dw.WriteDocumentEnd()
167}
168
169// DecodeValue implements the Codec interface.
170// By default, map types in val will not be cleared. If a map has existing key/value pairs, it will be extended with the new ones from vr.
171// For slices, the decoder will set the length of the slice to zero and append all elements. The underlying array will not be cleared.
172func (sc *StructCodec) DecodeValue(r DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
173	if !val.CanSet() || val.Kind() != reflect.Struct {
174		return ValueDecoderError{Name: "StructCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val}
175	}
176
177	switch vr.Type() {
178	case bsontype.Type(0), bsontype.EmbeddedDocument:
179	default:
180		return fmt.Errorf("cannot decode %v into a %s", vr.Type(), val.Type())
181	}
182
183	sd, err := sc.describeStruct(r.Registry, val.Type())
184	if err != nil {
185		return err
186	}
187
188	if sc.DecodeZeroStruct {
189		val.Set(reflect.Zero(val.Type()))
190	}
191	if sc.DecodeDeepZeroInline && sd.inline {
192		val.Set(deepZero(val.Type()))
193	}
194
195	var decoder ValueDecoder
196	var inlineMap reflect.Value
197	if sd.inlineMap >= 0 {
198		inlineMap = val.Field(sd.inlineMap)
199		decoder, err = r.LookupDecoder(inlineMap.Type().Elem())
200		if err != nil {
201			return err
202		}
203	}
204
205	dr, err := vr.ReadDocument()
206	if err != nil {
207		return err
208	}
209
210	for {
211		name, vr, err := dr.ReadElement()
212		if err == bsonrw.ErrEOD {
213			break
214		}
215		if err != nil {
216			return err
217		}
218
219		fd, exists := sd.fm[name]
220		if !exists {
221			// if the original name isn't found in the struct description, try again with the name in lowercase
222			// this could match if a BSON tag isn't specified because by default, describeStruct lowercases all field
223			// names
224			fd, exists = sd.fm[strings.ToLower(name)]
225		}
226
227		if !exists {
228			if sd.inlineMap < 0 {
229				// The encoding/json package requires a flag to return on error for non-existent fields.
230				// This functionality seems appropriate for the struct codec.
231				err = vr.Skip()
232				if err != nil {
233					return err
234				}
235				continue
236			}
237
238			if inlineMap.IsNil() {
239				inlineMap.Set(reflect.MakeMap(inlineMap.Type()))
240			}
241
242			elem := reflect.New(inlineMap.Type().Elem()).Elem()
243			err = decoder.DecodeValue(r, vr, elem)
244			if err != nil {
245				return err
246			}
247			inlineMap.SetMapIndex(reflect.ValueOf(name), elem)
248			continue
249		}
250
251		var field reflect.Value
252		if fd.inline == nil {
253			field = val.Field(fd.idx)
254		} else {
255			field, err = getInlineField(val, fd.inline)
256			if err != nil {
257				return err
258			}
259		}
260
261		if !field.CanSet() { // Being settable is a super set of being addressable.
262			return fmt.Errorf("cannot decode element '%s' into field %v; it is not settable", name, field)
263		}
264		if field.Kind() == reflect.Ptr && field.IsNil() {
265			field.Set(reflect.New(field.Type().Elem()))
266		}
267		field = field.Addr()
268
269		dctx := DecodeContext{Registry: r.Registry, Truncate: fd.truncate || r.Truncate}
270		if fd.decoder == nil {
271			return ErrNoDecoder{Type: field.Elem().Type()}
272		}
273
274		if decoder, ok := fd.decoder.(ValueDecoder); ok {
275			err = decoder.DecodeValue(dctx, vr, field.Elem())
276			if err != nil {
277				return err
278			}
279			continue
280		}
281		err = fd.decoder.DecodeValue(dctx, vr, field)
282		if err != nil {
283			return err
284		}
285	}
286
287	return nil
288}
289
290func (sc *StructCodec) isZero(i interface{}) bool {
291	v := reflect.ValueOf(i)
292
293	// check the value validity
294	if !v.IsValid() {
295		return true
296	}
297
298	if z, ok := v.Interface().(Zeroer); ok && (v.Kind() != reflect.Ptr || !v.IsNil()) {
299		return z.IsZero()
300	}
301
302	switch v.Kind() {
303	case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
304		return v.Len() == 0
305	case reflect.Bool:
306		return !v.Bool()
307	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
308		return v.Int() == 0
309	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
310		return v.Uint() == 0
311	case reflect.Float32, reflect.Float64:
312		return v.Float() == 0
313	case reflect.Interface, reflect.Ptr:
314		return v.IsNil()
315	case reflect.Struct:
316		if sc.EncodeOmitDefaultStruct {
317			vt := v.Type()
318			if vt == tTime {
319				return v.Interface().(time.Time).IsZero()
320			}
321			for i := 0; i < v.NumField(); i++ {
322				if vt.Field(i).PkgPath != "" && !vt.Field(i).Anonymous {
323					continue // Private field
324				}
325				fld := v.Field(i)
326				if !sc.isZero(fld.Interface()) {
327					return false
328				}
329			}
330			return true
331		}
332	}
333
334	return false
335}
336
337type structDescription struct {
338	fm        map[string]fieldDescription
339	fl        []fieldDescription
340	inlineMap int
341	inline    bool
342}
343
344type fieldDescription struct {
345	name      string
346	idx       int
347	omitEmpty bool
348	minSize   bool
349	truncate  bool
350	inline    []int
351	encoder   ValueEncoder
352	decoder   ValueDecoder
353}
354
355func (sc *StructCodec) describeStruct(r *Registry, t reflect.Type) (*structDescription, error) {
356	// We need to analyze the struct, including getting the tags, collecting
357	// information about inlining, and create a map of the field name to the field.
358	sc.l.RLock()
359	ds, exists := sc.cache[t]
360	sc.l.RUnlock()
361	if exists {
362		return ds, nil
363	}
364
365	numFields := t.NumField()
366	sd := &structDescription{
367		fm:        make(map[string]fieldDescription, numFields),
368		fl:        make([]fieldDescription, 0, numFields),
369		inlineMap: -1,
370	}
371
372	for i := 0; i < numFields; i++ {
373		sf := t.Field(i)
374		if sf.PkgPath != "" && (!sc.AllowUnexportedFields || !sf.Anonymous) {
375			// field is private or unexported fields aren't allowed, ignore
376			continue
377		}
378
379		sfType := sf.Type
380		encoder, err := r.LookupEncoder(sfType)
381		if err != nil {
382			encoder = nil
383		}
384		decoder, err := r.LookupDecoder(sfType)
385		if err != nil {
386			decoder = nil
387		}
388
389		description := fieldDescription{idx: i, encoder: encoder, decoder: decoder}
390
391		stags, err := sc.parser.ParseStructTags(sf)
392		if err != nil {
393			return nil, err
394		}
395		if stags.Skip {
396			continue
397		}
398		description.name = stags.Name
399		description.omitEmpty = stags.OmitEmpty
400		description.minSize = stags.MinSize
401		description.truncate = stags.Truncate
402
403		if stags.Inline {
404			sd.inline = true
405			switch sfType.Kind() {
406			case reflect.Map:
407				if sd.inlineMap >= 0 {
408					return nil, errors.New("(struct " + t.String() + ") multiple inline maps")
409				}
410				if sfType.Key() != tString {
411					return nil, errors.New("(struct " + t.String() + ") inline map must have a string keys")
412				}
413				sd.inlineMap = description.idx
414			case reflect.Ptr:
415				sfType = sfType.Elem()
416				if sfType.Kind() != reflect.Struct {
417					return nil, fmt.Errorf("(struct %s) inline fields must be a struct, a struct pointer, or a map", t.String())
418				}
419				fallthrough
420			case reflect.Struct:
421				inlinesf, err := sc.describeStruct(r, sfType)
422				if err != nil {
423					return nil, err
424				}
425				for _, fd := range inlinesf.fl {
426					if _, exists := sd.fm[fd.name]; exists {
427						return nil, fmt.Errorf("(struct %s) duplicated key %s", t.String(), fd.name)
428					}
429					if fd.inline == nil {
430						fd.inline = []int{i, fd.idx}
431					} else {
432						fd.inline = append([]int{i}, fd.inline...)
433					}
434					sd.fm[fd.name] = fd
435					sd.fl = append(sd.fl, fd)
436				}
437			default:
438				return nil, fmt.Errorf("(struct %s) inline fields must be a struct, a struct pointer, or a map", t.String())
439			}
440			continue
441		}
442
443		if _, exists := sd.fm[description.name]; exists {
444			return nil, fmt.Errorf("struct %s) duplicated key %s", t.String(), description.name)
445		}
446
447		sd.fm[description.name] = description
448		sd.fl = append(sd.fl, description)
449	}
450
451	sc.l.Lock()
452	sc.cache[t] = sd
453	sc.l.Unlock()
454
455	return sd, nil
456}
457
// fieldByIndexErr behaves like v.FieldByIndex(index), but converts a panic
// into an error. An example is a nil embedded-struct pointer along the path.
// On error the returned reflect.Value is the zero Value.
func fieldByIndexErr(v reflect.Value, index []int) (result reflect.Value, err error) {
	defer func() {
		recovered := recover()
		if recovered == nil {
			return
		}
		switch r := recovered.(type) {
		case string:
			err = errors.New(r)
		case error:
			err = r
		default:
			// Never report success for a recovered panic, whatever its payload.
			err = fmt.Errorf("%v", r)
		}
	}()

	result = v.FieldByIndex(index)
	return
}
473
474func getInlineField(val reflect.Value, index []int) (reflect.Value, error) {
475	field, err := fieldByIndexErr(val, index)
476	if err == nil {
477		return field, nil
478	}
479
480	// if parent of this element doesn't exist, fix its parent
481	inlineParent := index[:len(index)-1]
482	var fParent reflect.Value
483	if fParent, err = fieldByIndexErr(val, inlineParent); err != nil {
484		fParent, err = getInlineField(val, inlineParent)
485		if err != nil {
486			return fParent, err
487		}
488	}
489	fParent.Set(reflect.New(fParent.Type().Elem()))
490
491	return fieldByIndexErr(val, index)
492}
493
494// DeepZero returns recursive zero object
495func deepZero(st reflect.Type) (result reflect.Value) {
496	result = reflect.Indirect(reflect.New(st))
497
498	if result.Kind() == reflect.Struct {
499		for i := 0; i < result.NumField(); i++ {
500			if f := result.Field(i); f.Kind() == reflect.Ptr {
501				if f.CanInterface() {
502					if ft := reflect.TypeOf(f.Interface()); ft.Elem().Kind() == reflect.Struct {
503						result.Field(i).Set(recursivePointerTo(deepZero(ft.Elem())))
504					}
505				}
506			}
507		}
508	}
509
510	return
511}
512
// recursivePointerTo returns a pointer to a freshly allocated zero value of
// v's type (after dereferencing v if it is a pointer). For every pointer field
// of v that points at a struct, the matching field of the new value is set to
// a fresh pointer built the same way. The result therefore mirrors v's
// pointer structure without sharing memory with it. Non-pointer field values
// of v are not copied.
func recursivePointerTo(v reflect.Value) reflect.Value {
	src := reflect.Indirect(v)
	ptr := reflect.New(src.Type())
	if src.Kind() != reflect.Struct {
		return ptr
	}

	dst := ptr.Elem()
	for i, n := 0, src.NumField(); i < n; i++ {
		child := src.Field(i)
		if child.Kind() == reflect.Ptr && child.Elem().Kind() == reflect.Struct {
			dst.Field(i).Set(recursivePointerTo(child))
		}
	}

	return ptr
}
529