1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package mongo // import "go.mongodb.org/mongo-driver/mongo"
8
9import (
10	"context"
11	"errors"
12	"fmt"
13	"net"
14	"reflect"
15	"strconv"
16	"strings"
17
18	"go.mongodb.org/mongo-driver/mongo/options"
19	"go.mongodb.org/mongo-driver/x/bsonx"
20	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
21
22	"go.mongodb.org/mongo-driver/bson"
23	"go.mongodb.org/mongo-driver/bson/bsoncodec"
24	"go.mongodb.org/mongo-driver/bson/bsontype"
25	"go.mongodb.org/mongo-driver/bson/primitive"
26)
27
// Dialer is the abstraction the driver uses to open network connections. Any type that provides
// a DialContext method with this signature (for example *net.Dialer) satisfies it.
type Dialer interface {
	// DialContext connects to address over the named network using the supplied context.
	DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
32
// BSONAppender is implemented by types that marshal a value into BSON and append the resulting
// bytes to a caller-supplied []byte.
//
// AppendBSON is allowed to return a non-nil error together with a non-nil []byte, and it may have
// written incomplete BSON into that []byte.
type BSONAppender interface {
	AppendBSON([]byte, interface{}) ([]byte, error)
}

// BSONAppenderFunc lets a plain function with the AppendBSON signature be used anywhere a
// BSONAppender is expected.
type BSONAppenderFunc func([]byte, interface{}) ([]byte, error)

// AppendBSON implements the BSONAppender interface by invoking the function itself.
func (fn BSONAppenderFunc) AppendBSON(dst []byte, val interface{}) ([]byte, error) {
	return fn(dst, val)
}
50
// MarshalError is returned when transforming a value into a BSON document fails. Value holds the
// value that could not be transformed and Err holds the underlying cause.
type MarshalError struct {
	Value interface{}
	Err   error
}

// Error implements the error interface.
//
// %T prints the same text as reflect.TypeOf(me.Value).String() for a non-nil Value, and prints
// "<nil>" when Value is nil. Formatting a nil reflect.Type with %s would instead produce the
// "%!s(<nil>)" formatting-error marker.
func (me MarshalError) Error() string {
	return fmt.Sprintf("cannot transform type %T to a BSON Document: %v", me.Value, me.Err)
}
62
63// Pipeline is a type that makes creating aggregation pipelines easier. It is a
64// helper and is intended for serializing to BSON.
65//
66// Example usage:
67//
68//		mongo.Pipeline{
69//			{{"$group", bson.D{{"_id", "$state"}, {"totalPop", bson.D{{"$sum", "$pop"}}}}}},
70//			{{"$match", bson.D{{"totalPop", bson.D{{"$gte", 10*1000*1000}}}}}},
71//		}
72//
73type Pipeline []bson.D
74
75// transformAndEnsureID is a hack that makes it easy to get a RawValue as the _id value. This will
76// be removed when we switch from using bsonx to bsoncore for the driver package.
77func transformAndEnsureID(registry *bsoncodec.Registry, val interface{}) (bsonx.Doc, interface{}, error) {
78	// TODO: performance is going to be pretty bad for bsonx.Doc here since we turn it into a []byte
79	// only to turn it back into a bsonx.Doc. We can fix this post beta1 when we refactor the driver
80	// package to use bsoncore.Document instead of bsonx.Doc.
81	if registry == nil {
82		registry = bson.NewRegistryBuilder().Build()
83	}
84	switch tt := val.(type) {
85	case nil:
86		return nil, nil, ErrNilDocument
87	case bsonx.Doc:
88		val = tt.Copy()
89	case []byte:
90		// Slight optimization so we'll just use MarshalBSON and not go through the codec machinery.
91		val = bson.Raw(tt)
92	}
93
94	// TODO(skriptble): Use a pool of these instead.
95	buf := make([]byte, 0, 256)
96	b, err := bson.MarshalAppendWithRegistry(registry, buf, val)
97	if err != nil {
98		return nil, nil, MarshalError{Value: val, Err: err}
99	}
100
101	d, err := bsonx.ReadDoc(b)
102	if err != nil {
103		return nil, nil, err
104	}
105
106	var id interface{}
107
108	idx := d.IndexOf("_id")
109	var idElem bsonx.Elem
110	switch idx {
111	case -1:
112		idElem = bsonx.Elem{"_id", bsonx.ObjectID(primitive.NewObjectID())}
113		d = append(d, bsonx.Elem{})
114		copy(d[1:], d)
115		d[0] = idElem
116	default:
117		idElem = d[idx]
118		copy(d[1:idx+1], d[0:idx])
119		d[0] = idElem
120	}
121
122	idBuf := make([]byte, 0, 256)
123	t, data, err := idElem.Value.MarshalAppendBSONValue(idBuf[:0])
124	if err != nil {
125		return nil, nil, err
126	}
127
128	err = bson.RawValue{Type: t, Value: data}.UnmarshalWithRegistry(registry, &id)
129	if err != nil {
130		return nil, nil, err
131	}
132
133	return d, id, nil
134}
135
136// transformAndEnsureIDv2 is a hack that makes it easy to get a RawValue as the _id value. This will
137// be removed when we switch from using bsonx to bsoncore for the driver package.
138func transformAndEnsureIDv2(registry *bsoncodec.Registry, val interface{}) (bsoncore.Document, interface{}, error) {
139	if registry == nil {
140		registry = bson.NewRegistryBuilder().Build()
141	}
142	switch tt := val.(type) {
143	case nil:
144		return nil, nil, ErrNilDocument
145	case bsonx.Doc:
146		val = tt.Copy()
147	case []byte:
148		// Slight optimization so we'll just use MarshalBSON and not go through the codec machinery.
149		val = bson.Raw(tt)
150	}
151
152	// TODO(skriptble): Use a pool of these instead.
153	doc := make(bsoncore.Document, 0, 256)
154	doc, err := bson.MarshalAppendWithRegistry(registry, doc, val)
155	if err != nil {
156		return nil, nil, MarshalError{Value: val, Err: err}
157	}
158
159	var id interface{}
160
161	value := doc.Lookup("_id")
162	switch value.Type {
163	case bsontype.Type(0):
164		value = bsoncore.Value{Type: bsontype.ObjectID, Data: bsoncore.AppendObjectID(nil, primitive.NewObjectID())}
165		olddoc := doc
166		doc = make(bsoncore.Document, 0, len(olddoc)+17) // type byte + _id + null byte + object ID
167		_, doc = bsoncore.ReserveLength(doc)
168		doc = bsoncore.AppendValueElement(doc, "_id", value)
169		doc = append(doc, olddoc[4:]...) // remove the length
170		doc = bsoncore.UpdateLength(doc, 0, int32(len(doc)))
171	default:
172		// We copy the bytes here to ensure that any bytes returned to the user aren't modified
173		// later.
174		buf := make([]byte, len(value.Data))
175		copy(buf, value.Data)
176		value.Data = buf
177	}
178
179	err = bson.RawValue{Type: value.Type, Value: value.Data}.UnmarshalWithRegistry(registry, &id)
180	if err != nil {
181		return nil, nil, err
182	}
183
184	return doc, id, nil
185}
186
187func transformDocument(registry *bsoncodec.Registry, val interface{}) (bsonx.Doc, error) {
188	if doc, ok := val.(bsonx.Doc); ok {
189		return doc.Copy(), nil
190	}
191	b, err := transformBsoncoreDocument(registry, val)
192	if err != nil {
193		return nil, err
194	}
195	return bsonx.ReadDoc(b)
196}
197
198func transformBsoncoreDocument(registry *bsoncodec.Registry, val interface{}) (bsoncore.Document, error) {
199	if registry == nil {
200		registry = bson.DefaultRegistry
201	}
202	if val == nil {
203		return nil, ErrNilDocument
204	}
205	if bs, ok := val.([]byte); ok {
206		// Slight optimization so we'll just use MarshalBSON and not go through the codec machinery.
207		val = bson.Raw(bs)
208	}
209
210	// TODO(skriptble): Use a pool of these instead.
211	buf := make([]byte, 0, 256)
212	b, err := bson.MarshalAppendWithRegistry(registry, buf[:0], val)
213	if err != nil {
214		return nil, MarshalError{Value: val, Err: err}
215	}
216	return b, nil
217}
218
219func ensureID(d bsonx.Doc) (bsonx.Doc, interface{}) {
220	var id interface{}
221
222	elem, err := d.LookupElementErr("_id")
223	switch err.(type) {
224	case nil:
225		id = elem
226	default:
227		oid := primitive.NewObjectID()
228		d = append(d, bsonx.Elem{"_id", bsonx.ObjectID(oid)})
229		id = oid
230	}
231	return d, id
232}
233
234func ensureDollarKey(doc bsonx.Doc) error {
235	if len(doc) == 0 {
236		return errors.New("update document must have at least one element")
237	}
238	if !strings.HasPrefix(doc[0].Key, "$") {
239		return errors.New("update document must contain key beginning with '$'")
240	}
241	return nil
242}
243
244func ensureDollarKeyv2(doc bsoncore.Document) error {
245	firstElem, err := doc.IndexErr(0)
246	if err != nil {
247		return errors.New("update document must have at least one element")
248	}
249
250	if !strings.HasPrefix(firstElem.Key(), "$") {
251		return errors.New("update document must contain key beginning with '$'")
252	}
253	return nil
254}
255
256func transformAggregatePipeline(registry *bsoncodec.Registry, pipeline interface{}) (bsonx.Arr, error) {
257	pipelineArr := bsonx.Arr{}
258	switch t := pipeline.(type) {
259	case bsoncodec.ValueMarshaler:
260		btype, val, err := t.MarshalBSONValue()
261		if err != nil {
262			return nil, err
263		}
264		if btype != bsontype.Array {
265			return nil, fmt.Errorf("ValueMarshaler returned a %v, but was expecting %v", btype, bsontype.Array)
266		}
267		err = pipelineArr.UnmarshalBSONValue(btype, val)
268		if err != nil {
269			return nil, err
270		}
271	default:
272		val := reflect.ValueOf(t)
273		if !val.IsValid() || (val.Kind() != reflect.Slice && val.Kind() != reflect.Array) {
274			return nil, fmt.Errorf("can only transform slices and arrays into aggregation pipelines, but got %v", val.Kind())
275		}
276		for idx := 0; idx < val.Len(); idx++ {
277			elem, err := transformDocument(registry, val.Index(idx).Interface())
278			if err != nil {
279				return nil, err
280			}
281			pipelineArr = append(pipelineArr, bsonx.Document(elem))
282		}
283	}
284
285	return pipelineArr, nil
286}
287
288func transformAggregatePipelinev2(registry *bsoncodec.Registry, pipeline interface{}) (bsoncore.Document, bool, error) {
289	switch t := pipeline.(type) {
290	case bsoncodec.ValueMarshaler:
291		btype, val, err := t.MarshalBSONValue()
292		if err != nil {
293			return nil, false, err
294		}
295		if btype != bsontype.Array {
296			return nil, false, fmt.Errorf("ValueMarshaler returned a %v, but was expecting %v", btype, bsontype.Array)
297		}
298
299		var hasOutputStage bool
300		pipelineDoc := bsoncore.Document(val)
301		if _, err := pipelineDoc.LookupErr("$out"); err == nil {
302			hasOutputStage = true
303		}
304		if _, err := pipelineDoc.LookupErr("$merge"); err == nil {
305			hasOutputStage = true
306		}
307
308		return pipelineDoc, hasOutputStage, nil
309	default:
310		val := reflect.ValueOf(t)
311		if !val.IsValid() || (val.Kind() != reflect.Slice && val.Kind() != reflect.Array) {
312			return nil, false, fmt.Errorf("can only transform slices and arrays into aggregation pipelines, but got %v", val.Kind())
313		}
314
315		aidx, arr := bsoncore.AppendArrayStart(nil)
316		var hasOutputStage bool
317		valLen := val.Len()
318		for idx := 0; idx < valLen; idx++ {
319			doc, err := transformBsoncoreDocument(registry, val.Index(idx).Interface())
320			if err != nil {
321				return nil, false, err
322			}
323
324			if idx == valLen-1 {
325				if elem, err := doc.IndexErr(0); err == nil && (elem.Key() == "$out" || elem.Key() == "$merge") {
326					hasOutputStage = true
327				}
328			}
329			arr = bsoncore.AppendDocumentElement(arr, strconv.Itoa(idx), doc)
330		}
331		arr, _ = bsoncore.AppendArrayEnd(arr, aidx)
332		return arr, hasOutputStage, nil
333	}
334}
335
336func transformUpdateValue(registry *bsoncodec.Registry, update interface{}, checkDocDollarKey bool) (bsoncore.Value, error) {
337	var u bsoncore.Value
338	var err error
339	switch t := update.(type) {
340	case nil:
341		return u, ErrNilDocument
342	case primitive.D, bsonx.Doc:
343		u.Type = bsontype.EmbeddedDocument
344		u.Data, err = transformBsoncoreDocument(registry, update)
345		if err != nil {
346			return u, err
347		}
348
349		if checkDocDollarKey {
350			err = ensureDollarKeyv2(u.Data)
351		}
352		return u, err
353	case bson.Raw:
354		u.Type = bsontype.EmbeddedDocument
355		u.Data = t
356		if checkDocDollarKey {
357			err = ensureDollarKeyv2(u.Data)
358		}
359		return u, err
360	case bsoncore.Document:
361		u.Type = bsontype.EmbeddedDocument
362		u.Data = t
363		if checkDocDollarKey {
364			err = ensureDollarKeyv2(u.Data)
365		}
366		return u, err
367	case []byte:
368		u.Type = bsontype.EmbeddedDocument
369		u.Data = t
370		if checkDocDollarKey {
371			err = ensureDollarKeyv2(u.Data)
372		}
373		return u, err
374	case bsoncodec.Marshaler:
375		u.Type = bsontype.EmbeddedDocument
376		u.Data, err = t.MarshalBSON()
377		if err != nil {
378			return u, err
379		}
380
381		if checkDocDollarKey {
382			err = ensureDollarKeyv2(u.Data)
383		}
384		return u, err
385	case bsoncodec.ValueMarshaler:
386		u.Type, u.Data, err = t.MarshalBSONValue()
387		if err != nil {
388			return u, err
389		}
390		if u.Type != bsontype.Array && u.Type != bsontype.EmbeddedDocument {
391			return u, fmt.Errorf("ValueMarshaler returned a %v, but was expecting %v or %v", u.Type, bsontype.Array, bsontype.EmbeddedDocument)
392		}
393		return u, err
394	default:
395		val := reflect.ValueOf(t)
396		if !val.IsValid() {
397			return u, fmt.Errorf("can only transform slices and arrays into update pipelines, but got %v", val.Kind())
398		}
399		if val.Kind() != reflect.Slice && val.Kind() != reflect.Array {
400			u.Type = bsontype.EmbeddedDocument
401			u.Data, err = transformBsoncoreDocument(registry, update)
402			if err != nil {
403				return u, err
404			}
405
406			if checkDocDollarKey {
407				err = ensureDollarKeyv2(u.Data)
408			}
409			return u, err
410		}
411
412		u.Type = bsontype.Array
413		aidx, arr := bsoncore.AppendArrayStart(nil)
414		valLen := val.Len()
415		for idx := 0; idx < valLen; idx++ {
416			doc, err := transformBsoncoreDocument(registry, val.Index(idx).Interface())
417			if err != nil {
418				return u, err
419			}
420
421			if err := ensureDollarKeyv2(doc); err != nil {
422				return u, err
423			}
424
425			arr = bsoncore.AppendDocumentElement(arr, strconv.Itoa(idx), doc)
426		}
427		u.Data, _ = bsoncore.AppendArrayEnd(arr, aidx)
428		return u, err
429	}
430}
431
432func transformValue(registry *bsoncodec.Registry, val interface{}) (bsoncore.Value, error) {
433	switch conv := val.(type) {
434	case string:
435		return bsoncore.Value{Type: bsontype.String, Data: bsoncore.AppendString(nil, conv)}, nil
436	default:
437		doc, err := transformBsoncoreDocument(registry, val)
438		if err != nil {
439			return bsoncore.Value{}, err
440		}
441
442		return bsoncore.Value{Type: bsontype.EmbeddedDocument, Data: doc}, nil
443	}
444}
445
446// Build the aggregation pipeline for the CountDocument command.
447func countDocumentsAggregatePipeline(registry *bsoncodec.Registry, filter interface{}, opts *options.CountOptions) (bsoncore.Document, error) {
448	filterDoc, err := transformBsoncoreDocument(registry, filter)
449	if err != nil {
450		return nil, err
451	}
452
453	aidx, arr := bsoncore.AppendArrayStart(nil)
454	didx, arr := bsoncore.AppendDocumentElementStart(arr, strconv.Itoa(0))
455	arr = bsoncore.AppendDocumentElement(arr, "$match", filterDoc)
456	arr, _ = bsoncore.AppendDocumentEnd(arr, didx)
457
458	index := 1
459	if opts != nil {
460		if opts.Skip != nil {
461			didx, arr = bsoncore.AppendDocumentElementStart(arr, strconv.Itoa(index))
462			arr = bsoncore.AppendInt64Element(arr, "$skip", *opts.Skip)
463			arr, _ = bsoncore.AppendDocumentEnd(arr, didx)
464			index++
465		}
466		if opts.Limit != nil {
467			didx, arr = bsoncore.AppendDocumentElementStart(arr, strconv.Itoa(index))
468			arr = bsoncore.AppendInt64Element(arr, "$limit", *opts.Limit)
469			arr, _ = bsoncore.AppendDocumentEnd(arr, didx)
470			index++
471		}
472	}
473
474	didx, arr = bsoncore.AppendDocumentElementStart(arr, strconv.Itoa(index))
475	iidx, arr := bsoncore.AppendDocumentElementStart(arr, "$group")
476	arr = bsoncore.AppendInt32Element(arr, "_id", 1)
477	iiidx, arr := bsoncore.AppendDocumentElementStart(arr, "n")
478	arr = bsoncore.AppendInt32Element(arr, "$sum", 1)
479	arr, _ = bsoncore.AppendDocumentEnd(arr, iiidx)
480	arr, _ = bsoncore.AppendDocumentEnd(arr, iidx)
481	arr, _ = bsoncore.AppendDocumentEnd(arr, didx)
482
483	return bsoncore.AppendArrayEnd(arr, aidx)
484}
485