1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package mongo
8
9import (
10	"errors"
11	"fmt"
12	"testing"
13
14	"go.mongodb.org/mongo-driver/bson"
15	"go.mongodb.org/mongo-driver/bson/bsoncodec"
16	"go.mongodb.org/mongo-driver/bson/bsontype"
17	"go.mongodb.org/mongo-driver/bson/primitive"
18	"go.mongodb.org/mongo-driver/internal/testutil/assert"
19	"go.mongodb.org/mongo-driver/x/bsonx"
20	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
21)
22
23func TestMongoHelpers(t *testing.T) {
24	t.Run("transform document", func(t *testing.T) {
25		testCases := []struct {
26			name     string
27			document interface{}
28			want     bsonx.Doc
29			err      error
30		}{
31			{
32				"bson.Marshaler",
33				bMarsh{bsonx.Doc{{"foo", bsonx.String("bar")}}},
34				bsonx.Doc{{"foo", bsonx.String("bar")}},
35				nil,
36			},
37			{
38				"reflection",
39				reflectStruct{Foo: "bar"},
40				bsonx.Doc{{"foo", bsonx.String("bar")}},
41				nil,
42			},
43			{
44				"reflection pointer",
45				&reflectStruct{Foo: "bar"},
46				bsonx.Doc{{"foo", bsonx.String("bar")}},
47				nil,
48			},
49			{
50				"unsupported type",
51				[]string{"foo", "bar"},
52				nil,
53				MarshalError{
54					Value: []string{"foo", "bar"},
55					Err:   errors.New("WriteArray can only write a Array while positioned on a Element or Value but is positioned on a TopLevel")},
56			},
57			{
58				"nil",
59				nil,
60				nil,
61				ErrNilDocument,
62			},
63		}
64		for _, tc := range testCases {
65			t.Run(tc.name, func(t *testing.T) {
66				got, err := transformDocument(bson.NewRegistryBuilder().Build(), tc.document)
67				assert.Equal(t, tc.err, err, "expected error %v, got %v", tc.err, err)
68				assert.Equal(t, tc.want, got, "expected document %v, got %v", tc.want, got)
69			})
70		}
71	})
72	t.Run("transform and ensure ID", func(t *testing.T) {
73		t.Run("newly added _id should be first element", func(t *testing.T) {
74			doc := bson.D{{"foo", "bar"}, {"baz", "qux"}, {"hello", "world"}}
75			want := bsonx.Doc{
76				{"_id", bsonx.Null()}, {"foo", bsonx.String("bar")},
77				{"baz", bsonx.String("qux")}, {"hello", bsonx.String("world")},
78			}
79			got, id, err := transformAndEnsureID(bson.DefaultRegistry, doc)
80			assert.Nil(t, err, "transformAndEnsureID error: %v", err)
81			oid, ok := id.(primitive.ObjectID)
82			assert.True(t, ok, "expected returned id type %T, got %T", primitive.ObjectID{}, id)
83			want[0] = bsonx.Elem{"_id", bsonx.ObjectID(oid)}
84			assert.Equal(t, got, want, "expected document %v, got %v", got, want)
85		})
86		t.Run("existing _id should be first element", func(t *testing.T) {
87			doc := bson.D{{"foo", "bar"}, {"baz", "qux"}, {"_id", 3.14159}, {"hello", "world"}}
88			want := bsonx.Doc{
89				{"_id", bsonx.Double(3.14159)}, {"foo", bsonx.String("bar")},
90				{"baz", bsonx.String("qux")}, {"hello", bsonx.String("world")},
91			}
92			got, id, err := transformAndEnsureID(bson.DefaultRegistry, doc)
93			assert.Nil(t, err, "transformAndEnsureID error: %v", err)
94			_, ok := id.(float64)
95			assert.True(t, ok, "expected returned id type %T, got %T", float64(0), id)
96			assert.Equal(t, got, want, "expected document %v, got %v", got, want)
97		})
98		t.Run("existing _id as first element should remain first element", func(t *testing.T) {
99			doc := bson.D{{"_id", 3.14159}, {"foo", "bar"}, {"baz", "qux"}, {"hello", "world"}}
100			want := bsonx.Doc{
101				{"_id", bsonx.Double(3.14159)}, {"foo", bsonx.String("bar")},
102				{"baz", bsonx.String("qux")}, {"hello", bsonx.String("world")},
103			}
104			got, id, err := transformAndEnsureID(bson.DefaultRegistry, doc)
105			assert.Nil(t, err, "transformAndEnsureID error: %v", err)
106			_, ok := id.(float64)
107			assert.True(t, ok, "expected returned id type %T, got %T", float64(0), id)
108			assert.Equal(t, got, want, "expected document %v, got %v", got, want)
109		})
110		t.Run("existing _id should not overwrite a first binary field", func(t *testing.T) {
111			doc := bson.D{{"bin", []byte{0, 0, 0}}, {"_id", "LongEnoughIdentifier"}}
112			want := bsonx.Doc{
113				{"_id", bsonx.String("LongEnoughIdentifier")},
114				{"bin", bsonx.Binary(0x00, []byte{0x00, 0x00, 0x00})},
115			}
116			got, id, err := transformAndEnsureID(bson.DefaultRegistry, doc)
117			assert.Nil(t, err, "transformAndEnsureID error: %v", err)
118			_, ok := id.(string)
119			assert.True(t, ok, "expected returned id type %T, got %T", string(0), id)
120			assert.Equal(t, got, want, "expected document %v, got %v", got, want)
121		})
122	})
123	t.Run("transform aggregate pipeline", func(t *testing.T) {
124		index, arr := bsoncore.AppendArrayStart(nil)
125		dindex, arr := bsoncore.AppendDocumentElementStart(arr, "0")
126		arr = bsoncore.AppendInt32Element(arr, "$limit", 12345)
127		arr, _ = bsoncore.AppendDocumentEnd(arr, dindex)
128		arr, _ = bsoncore.AppendArrayEnd(arr, index)
129
130		testCases := []struct {
131			name     string
132			pipeline interface{}
133			arr      bsonx.Arr
134			err      error
135		}{
136			{
137				"Pipeline/error",
138				Pipeline{{{"hello", func() {}}}}, bsonx.Arr{},
139				MarshalError{Value: primitive.D{}, Err: errors.New("no encoder found for func()")},
140			},
141			{
142				"Pipeline/success",
143				Pipeline{{{"hello", "world"}}, {{"pi", 3.14159}}},
144				bsonx.Arr{
145					bsonx.Document(bsonx.Doc{{"hello", bsonx.String("world")}}),
146					bsonx.Document(bsonx.Doc{{"pi", bsonx.Double(3.14159)}}),
147				},
148				nil,
149			},
150			{
151				"bsonx.Arr",
152				bsonx.Arr{bsonx.Document(bsonx.Doc{{"$limit", bsonx.Int32(12345)}})},
153				bsonx.Arr{bsonx.Document(bsonx.Doc{{"$limit", bsonx.Int32(12345)}})},
154				nil,
155			},
156			{
157				"[]bsonx.Doc",
158				[]bsonx.Doc{{{"$limit", bsonx.Int32(12345)}}},
159				bsonx.Arr{bsonx.Document(bsonx.Doc{{"$limit", bsonx.Int32(12345)}})},
160				nil,
161			},
162			{
163				"primitive.A/error",
164				primitive.A{"5"},
165				bsonx.Arr{},
166				MarshalError{Value: string(""), Err: errors.New("WriteString can only write while positioned on a Element or Value but is positioned on a TopLevel")},
167			},
168			{
169				"primitive.A/success",
170				primitive.A{bson.D{{"$limit", int32(12345)}}, map[string]interface{}{"$count": "foobar"}},
171				bsonx.Arr{
172					bsonx.Document(bsonx.Doc{{"$limit", bsonx.Int32(12345)}}),
173					bsonx.Document(bsonx.Doc{{"$count", bsonx.String("foobar")}}),
174				},
175				nil,
176			},
177			{
178				"bson.A/error",
179				bson.A{"5"},
180				bsonx.Arr{},
181				MarshalError{Value: string(""), Err: errors.New("WriteString can only write while positioned on a Element or Value but is positioned on a TopLevel")},
182			},
183			{
184				"bson.A/success",
185				bson.A{bson.D{{"$limit", int32(12345)}}, map[string]interface{}{"$count": "foobar"}},
186				bsonx.Arr{
187					bsonx.Document(bsonx.Doc{{"$limit", bsonx.Int32(12345)}}),
188					bsonx.Document(bsonx.Doc{{"$count", bsonx.String("foobar")}}),
189				},
190				nil,
191			},
192			{
193				"[]interface{}/error",
194				[]interface{}{"5"},
195				bsonx.Arr{},
196				MarshalError{Value: string(""), Err: errors.New("WriteString can only write while positioned on a Element or Value but is positioned on a TopLevel")},
197			},
198			{
199				"[]interface{}/success",
200				[]interface{}{bson.D{{"$limit", int32(12345)}}, map[string]interface{}{"$count": "foobar"}},
201				bsonx.Arr{
202					bsonx.Document(bsonx.Doc{{"$limit", bsonx.Int32(12345)}}),
203					bsonx.Document(bsonx.Doc{{"$count", bsonx.String("foobar")}}),
204				},
205				nil,
206			},
207			{
208				"bsoncodec.ValueMarshaler/MarshalBSONValue error",
209				bvMarsh{err: errors.New("MarshalBSONValue error")},
210				bsonx.Arr{},
211				errors.New("MarshalBSONValue error"),
212			},
213			{
214				"bsoncodec.ValueMarshaler/not array",
215				bvMarsh{t: bsontype.String},
216				bsonx.Arr{},
217				fmt.Errorf("ValueMarshaler returned a %v, but was expecting %v", bsontype.String, bsontype.Array),
218			},
219			{
220				"bsoncodec.ValueMarshaler/UnmarshalBSONValue error",
221				bvMarsh{t: bsontype.Array},
222				bsonx.Arr{},
223				bsoncore.NewInsufficientBytesError(nil, nil),
224			},
225			{
226				"bsoncodec.ValueMarshaler/success",
227				bvMarsh{t: bsontype.Array, data: arr},
228				bsonx.Arr{bsonx.Document(bsonx.Doc{{"$limit", bsonx.Int32(12345)}})},
229				nil,
230			},
231			{
232				"nil",
233				nil,
234				bsonx.Arr{},
235				errors.New("can only transform slices and arrays into aggregation pipelines, but got invalid"),
236			},
237			{
238				"not array or slice",
239				int64(42),
240				bsonx.Arr{},
241				errors.New("can only transform slices and arrays into aggregation pipelines, but got int64"),
242			},
243			{
244				"array/error",
245				[1]interface{}{int64(42)},
246				bsonx.Arr{},
247				MarshalError{Value: int64(0), Err: errors.New("WriteInt64 can only write while positioned on a Element or Value but is positioned on a TopLevel")},
248			},
249			{
250				"array/success",
251				[1]interface{}{primitive.D{{"$limit", int64(12345)}}},
252				bsonx.Arr{bsonx.Document(bsonx.Doc{{"$limit", bsonx.Int64(12345)}})},
253				nil,
254			},
255			{
256				"slice/error",
257				[]interface{}{int64(42)},
258				bsonx.Arr{},
259				MarshalError{Value: int64(0), Err: errors.New("WriteInt64 can only write while positioned on a Element or Value but is positioned on a TopLevel")},
260			},
261			{
262				"slice/success",
263				[]interface{}{primitive.D{{"$limit", int64(12345)}}},
264				bsonx.Arr{bsonx.Document(bsonx.Doc{{"$limit", bsonx.Int64(12345)}})},
265				nil,
266			},
267		}
268
269		for _, tc := range testCases {
270			t.Run(tc.name, func(t *testing.T) {
271				arr, err := transformAggregatePipeline(bson.NewRegistryBuilder().Build(), tc.pipeline)
272				assert.Equal(t, tc.err, err, "expected error %v, got %v", tc.err, err)
273				assert.Equal(t, tc.arr, arr, "expected array %v, got %v", tc.arr, arr)
274			})
275		}
276	})
277}
278
279var _ bson.Marshaler = bMarsh{}
280
281type bMarsh struct {
282	bsonx.Doc
283}
284
285func (b bMarsh) MarshalBSON() ([]byte, error) {
286	return b.Doc.MarshalBSON()
287}
288
// reflectStruct has no BSON hooks of its own, so it goes through the registry's
// reflection-based struct encoding; the tests expect its exported field Foo to
// come out under the key "foo".
type reflectStruct struct{ Foo string }
292
293var _ bsoncodec.ValueMarshaler = bvMarsh{}
294
295type bvMarsh struct {
296	t    bsontype.Type
297	data []byte
298	err  error
299}
300
301func (b bvMarsh) MarshalBSONValue() (bsontype.Type, []byte, error) {
302	return b.t, b.data, b.err
303}
304