1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package bson
8
9import (
10	"reflect"
11	"testing"
12
13	"github.com/google/go-cmp/cmp"
14	"go.mongodb.org/mongo-driver/bson/bsoncodec"
15	"go.mongodb.org/mongo-driver/bson/bsonrw"
16	"go.mongodb.org/mongo-driver/internal/testutil/assert"
17	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
18)
19
20func TestUnmarshal(t *testing.T) {
21	for _, tc := range unmarshalingTestCases {
22		t.Run(tc.name, func(t *testing.T) {
23			if tc.reg != nil {
24				t.Skip() // test requires custom registry
25			}
26			got := reflect.New(tc.sType).Interface()
27			err := Unmarshal(tc.data, got)
28			noerr(t, err)
29			if !cmp.Equal(got, tc.want) {
30				t.Errorf("Did not unmarshal as expected. got %v; want %v", got, tc.want)
31			}
32		})
33	}
34}
35
36func TestUnmarshalWithRegistry(t *testing.T) {
37	for _, tc := range unmarshalingTestCases {
38		t.Run(tc.name, func(t *testing.T) {
39			var reg *bsoncodec.Registry
40			if tc.reg != nil {
41				reg = tc.reg
42			} else {
43				reg = DefaultRegistry
44			}
45			got := reflect.New(tc.sType).Interface()
46			err := UnmarshalWithRegistry(reg, tc.data, got)
47			noerr(t, err)
48			if !cmp.Equal(got, tc.want) {
49				t.Errorf("Did not unmarshal as expected. got %v; want %v", got, tc.want)
50			}
51		})
52	}
53}
54
55func TestUnmarshalWithContext(t *testing.T) {
56	for _, tc := range unmarshalingTestCases {
57		t.Run(tc.name, func(t *testing.T) {
58			var reg *bsoncodec.Registry
59			if tc.reg != nil {
60				reg = tc.reg
61			} else {
62				reg = DefaultRegistry
63			}
64			dc := bsoncodec.DecodeContext{Registry: reg}
65			got := reflect.New(tc.sType).Interface()
66			err := UnmarshalWithContext(dc, tc.data, got)
67			noerr(t, err)
68			if !cmp.Equal(got, tc.want) {
69				t.Errorf("Did not unmarshal as expected. got %v; want %v", got, tc.want)
70			}
71		})
72	}
73}
74
75func TestUnmarshalExtJSONWithRegistry(t *testing.T) {
76	t.Run("UnmarshalExtJSONWithContext", func(t *testing.T) {
77		type teststruct struct{ Foo int }
78		var got teststruct
79		data := []byte("{\"foo\":1}")
80		err := UnmarshalExtJSONWithRegistry(DefaultRegistry, data, true, &got)
81		noerr(t, err)
82		want := teststruct{1}
83		if !cmp.Equal(got, want) {
84			t.Errorf("Did not unmarshal as expected. got %v; want %v", got, want)
85		}
86	})
87
88	t.Run("UnmarshalExtJSONInvalidInput", func(t *testing.T) {
89		data := []byte("invalid")
90		err := UnmarshalExtJSONWithRegistry(DefaultRegistry, data, true, &M{})
91		if err != bsonrw.ErrInvalidJSON {
92			t.Fatalf("wanted ErrInvalidJSON, got %v", err)
93		}
94	})
95}
96
97func TestUnmarshalExtJSONWithContext(t *testing.T) {
98	t.Run("UnmarshalExtJSONWithContext", func(t *testing.T) {
99		type teststruct struct{ Foo int }
100		var got teststruct
101		data := []byte("{\"foo\":1}")
102		dc := bsoncodec.DecodeContext{Registry: DefaultRegistry}
103		err := UnmarshalExtJSONWithContext(dc, data, true, &got)
104		noerr(t, err)
105		want := teststruct{1}
106		if !cmp.Equal(got, want) {
107			t.Errorf("Did not unmarshal as expected. got %v; want %v", got, want)
108		}
109	})
110}
111
112func TestCachingDecodersNotSharedAcrossRegistries(t *testing.T) {
113	// Decoders that have caches for recursive decoder lookup should not be shared across Registry instances. Otherwise,
114	// the first DecodeValue call would cache an decoder and a subsequent call would see that decoder even if a
115	// different Registry is used.
116
117	// Create a custom Registry that negates BSON int32 values when decoding.
118	var decodeInt32 bsoncodec.ValueDecoderFunc = func(_ bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
119		i32, err := vr.ReadInt32()
120		if err != nil {
121			return err
122		}
123
124		val.SetInt(int64(-1 * i32))
125		return nil
126	}
127	customReg := NewRegistryBuilder().
128		RegisterTypeDecoder(tInt32, bsoncodec.ValueDecoderFunc(decodeInt32)).
129		Build()
130
131	docBytes := bsoncore.BuildDocumentFromElements(
132		nil,
133		bsoncore.AppendInt32Element(nil, "x", 1),
134	)
135
136	// For all sub-tests, unmarshal docBytes into a struct and assert that value for "x" is 1 when using the default
137	// registry and -1 when using the custom registry.
138	t.Run("struct", func(t *testing.T) {
139		type Struct struct {
140			X int32
141		}
142
143		var first Struct
144		err := Unmarshal(docBytes, &first)
145		assert.Nil(t, err, "Unmarshal error: %v", err)
146		assert.Equal(t, int32(1), first.X, "expected X value to be 1, got %v", first.X)
147
148		var second Struct
149		err = UnmarshalWithRegistry(customReg, docBytes, &second)
150		assert.Nil(t, err, "Unmarshal error: %v", err)
151		assert.Equal(t, int32(-1), second.X, "expected X value to be -1, got %v", second.X)
152	})
153	t.Run("pointer", func(t *testing.T) {
154		type Struct struct {
155			X *int32
156		}
157
158		var first Struct
159		err := Unmarshal(docBytes, &first)
160		assert.Nil(t, err, "Unmarshal error: %v", err)
161		assert.Equal(t, int32(1), *first.X, "expected X value to be 1, got %v", *first.X)
162
163		var second Struct
164		err = UnmarshalWithRegistry(customReg, docBytes, &second)
165		assert.Nil(t, err, "Unmarshal error: %v", err)
166		assert.Equal(t, int32(-1), *second.X, "expected X value to be -1, got %v", *second.X)
167	})
168}
169
170func TestUnmarshalExtJSONWithUndefinedField(t *testing.T) {
171	// When unmarshalling, fields that are undefined in the destination struct are skipped.
172	// This process must not skip other, defined fields and must not raise errors.
173	type expectedResponse struct {
174		DefinedField string
175	}
176
177	unmarshalExpectedResponse := func(t *testing.T, extJSON string) *expectedResponse {
178		t.Helper()
179		responseDoc := expectedResponse{}
180		err := UnmarshalExtJSON([]byte(extJSON), false, &responseDoc)
181		assert.Nil(t, err, "UnmarshalExtJSON error: %v", err)
182		return &responseDoc
183	}
184
185	testCases := []struct {
186		name     string
187		testJSON string
188	}{
189		{
190			"no array",
191			`{
192				"UndefinedField": {"key": 1},
193				"DefinedField": "value"
194			}`,
195		},
196		{
197			"outer array",
198			`{
199				"UndefinedField": [{"key": 1}],
200				"DefinedField": "value"
201			}`,
202		},
203		{
204			"embedded array",
205			`{
206				"UndefinedField": {"keys": [2]},
207				"DefinedField": "value"
208			}`,
209		},
210		{
211			"outer array and embedded array",
212			`{
213				"UndefinedField": [{"keys": [2]}],
214				"DefinedField": "value"
215			}`,
216		},
217		{
218			"embedded document",
219			`{
220				"UndefinedField": {"key": {"one": "two"}},
221				"DefinedField": "value"
222			}`,
223		},
224		{
225			"doubly embedded document",
226			`{
227				"UndefinedField": {"key": {"one": {"two": "three"}}},
228				"DefinedField": "value"
229			}`,
230		},
231		{
232			"embedded document and embedded array",
233			`{
234				"UndefinedField": {"key": {"one": {"two": [3]}}},
235				"DefinedField": "value"
236			}`,
237		},
238		{
239			"embedded document and embedded array in outer array",
240			`{
241				"UndefinedField": [{"key": {"one": [3]}}],
242				"DefinedField": "value"
243			}`,
244		},
245		{
246			"code with scope",
247			`{
248				"UndefinedField": {"logic": {"$code": "foo", "$scope": {"bar": 1}}},
249				"DefinedField": "value"
250			}`,
251		},
252		{
253			"embedded array of code with scope",
254			`{
255				"UndefinedField": {"logic": [{"$code": "foo", "$scope": {"bar": 1}}]},
256				"DefinedField": "value"
257			}`,
258		},
259		{
260			"type definition embedded document",
261			`{
262				"UndefinedField": {"myDouble": {"$numberDouble": "1.24"}},
263				"DefinedField": "value"
264			}`,
265		},
266		{
267			"empty embedded document",
268			`{
269				"UndefinedField": {"empty": {}},
270				"DefinedField": "value"
271			}`,
272		},
273	}
274	for _, tc := range testCases {
275		t.Run(tc.name, func(t *testing.T) {
276			responseDoc := unmarshalExpectedResponse(t, tc.testJSON)
277			assert.Equal(t, "value", responseDoc.DefinedField, "expected DefinedField to be 'value', got %q", responseDoc.DefinedField)
278		})
279	}
280}
281