1// Copyright (C) MongoDB, Inc. 2017-present. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); you may 4// not use this file except in compliance with the License. You may obtain 5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 6 7package mongo // import "go.mongodb.org/mongo-driver/mongo" 8 9import ( 10 "context" 11 "errors" 12 "fmt" 13 "net" 14 "reflect" 15 "strconv" 16 "strings" 17 18 "go.mongodb.org/mongo-driver/mongo/options" 19 "go.mongodb.org/mongo-driver/x/bsonx" 20 "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" 21 22 "go.mongodb.org/mongo-driver/bson" 23 "go.mongodb.org/mongo-driver/bson/bsoncodec" 24 "go.mongodb.org/mongo-driver/bson/bsontype" 25 "go.mongodb.org/mongo-driver/bson/primitive" 26) 27 28// Dialer is used to make network connections. 29type Dialer interface { 30 DialContext(ctx context.Context, network, address string) (net.Conn, error) 31} 32 33// BSONAppender is an interface implemented by types that can marshal a 34// provided type into BSON bytes and append those bytes to the provided []byte. 35// The AppendBSON can return a non-nil error and non-nil []byte. The AppendBSON 36// method may also write incomplete BSON to the []byte. 37type BSONAppender interface { 38 AppendBSON([]byte, interface{}) ([]byte, error) 39} 40 41// BSONAppenderFunc is an adapter function that allows any function that 42// satisfies the AppendBSON method signature to be used where a BSONAppender is 43// used. 44type BSONAppenderFunc func([]byte, interface{}) ([]byte, error) 45 46// AppendBSON implements the BSONAppender interface 47func (baf BSONAppenderFunc) AppendBSON(dst []byte, val interface{}) ([]byte, error) { 48 return baf(dst, val) 49} 50 51// MarshalError is returned when attempting to transform a value into a document 52// results in an error. 53type MarshalError struct { 54 Value interface{} 55 Err error 56} 57 58// Error implements the error interface. 59func (me MarshalError) Error() string { 60 return fmt.Sprintf("cannot transform type %s to a BSON Document: %v", reflect.TypeOf(me.Value), me.Err) 61} 62 63// Pipeline is a type that makes creating aggregation pipelines easier. It is a 64// helper and is intended for serializing to BSON. 65// 66// Example usage: 67// 68// mongo.Pipeline{ 69// {{"$group", bson.D{{"_id", "$state"}, {"totalPop", bson.D{{"$sum", "$pop"}}}}}}, 70// {{"$match", bson.D{{"totalPop", bson.D{{"$gte", 10*1000*1000}}}}}}, 71// } 72// 73type Pipeline []bson.D 74 75// transformAndEnsureID is a hack that makes it easy to get a RawValue as the _id value. This will 76// be removed when we switch from using bsonx to bsoncore for the driver package. 77func transformAndEnsureID(registry *bsoncodec.Registry, val interface{}) (bsonx.Doc, interface{}, error) { 78 // TODO: performance is going to be pretty bad for bsonx.Doc here since we turn it into a []byte 79 // only to turn it back into a bsonx.Doc. We can fix this post beta1 when we refactor the driver 80 // package to use bsoncore.Document instead of bsonx.Doc. 81 if registry == nil { 82 registry = bson.NewRegistryBuilder().Build() 83 } 84 switch tt := val.(type) { 85 case nil: 86 return nil, nil, ErrNilDocument 87 case bsonx.Doc: 88 val = tt.Copy() 89 case []byte: 90 // Slight optimization so we'll just use MarshalBSON and not go through the codec machinery. 91 val = bson.Raw(tt) 92 } 93 94 // TODO(skriptble): Use a pool of these instead. 95 buf := make([]byte, 0, 256) 96 b, err := bson.MarshalAppendWithRegistry(registry, buf, val) 97 if err != nil { 98 return nil, nil, MarshalError{Value: val, Err: err} 99 } 100 101 d, err := bsonx.ReadDoc(b) 102 if err != nil { 103 return nil, nil, err 104 } 105 106 var id interface{} 107 108 idx := d.IndexOf("_id") 109 var idElem bsonx.Elem 110 switch idx { 111 case -1: 112 idElem = bsonx.Elem{"_id", bsonx.ObjectID(primitive.NewObjectID())} 113 d = append(d, bsonx.Elem{}) 114 copy(d[1:], d) 115 d[0] = idElem 116 default: 117 idElem = d[idx] 118 copy(d[1:idx+1], d[0:idx]) 119 d[0] = idElem 120 } 121 122 idBuf := make([]byte, 0, 256) 123 t, data, err := idElem.Value.MarshalAppendBSONValue(idBuf[:0]) 124 if err != nil { 125 return nil, nil, err 126 } 127 128 err = bson.RawValue{Type: t, Value: data}.UnmarshalWithRegistry(registry, &id) 129 if err != nil { 130 return nil, nil, err 131 } 132 133 return d, id, nil 134} 135 136// transformAndEnsureIDv2 is a hack that makes it easy to get a RawValue as the _id value. This will 137// be removed when we switch from using bsonx to bsoncore for the driver package. 138func transformAndEnsureIDv2(registry *bsoncodec.Registry, val interface{}) (bsoncore.Document, interface{}, error) { 139 if registry == nil { 140 registry = bson.NewRegistryBuilder().Build() 141 } 142 switch tt := val.(type) { 143 case nil: 144 return nil, nil, ErrNilDocument 145 case bsonx.Doc: 146 val = tt.Copy() 147 case []byte: 148 // Slight optimization so we'll just use MarshalBSON and not go through the codec machinery. 149 val = bson.Raw(tt) 150 } 151 152 // TODO(skriptble): Use a pool of these instead. 153 doc := make(bsoncore.Document, 0, 256) 154 doc, err := bson.MarshalAppendWithRegistry(registry, doc, val) 155 if err != nil { 156 return nil, nil, MarshalError{Value: val, Err: err} 157 } 158 159 var id interface{} 160 161 value := doc.Lookup("_id") 162 switch value.Type { 163 case bsontype.Type(0): 164 value = bsoncore.Value{Type: bsontype.ObjectID, Data: bsoncore.AppendObjectID(nil, primitive.NewObjectID())} 165 olddoc := doc 166 doc = make(bsoncore.Document, 0, len(olddoc)+17) // type byte + _id + null byte + object ID 167 _, doc = bsoncore.ReserveLength(doc) 168 doc = bsoncore.AppendValueElement(doc, "_id", value) 169 doc = append(doc, olddoc[4:]...) // remove the length 170 doc = bsoncore.UpdateLength(doc, 0, int32(len(doc))) 171 default: 172 // We copy the bytes here to ensure that any bytes returned to the user aren't modified 173 // later. 174 buf := make([]byte, len(value.Data)) 175 copy(buf, value.Data) 176 value.Data = buf 177 } 178 179 err = bson.RawValue{Type: value.Type, Value: value.Data}.UnmarshalWithRegistry(registry, &id) 180 if err != nil { 181 return nil, nil, err 182 } 183 184 return doc, id, nil 185} 186 187func transformDocument(registry *bsoncodec.Registry, val interface{}) (bsonx.Doc, error) { 188 if doc, ok := val.(bsonx.Doc); ok { 189 return doc.Copy(), nil 190 } 191 b, err := transformBsoncoreDocument(registry, val) 192 if err != nil { 193 return nil, err 194 } 195 return bsonx.ReadDoc(b) 196} 197 198func transformBsoncoreDocument(registry *bsoncodec.Registry, val interface{}) (bsoncore.Document, error) { 199 if registry == nil { 200 registry = bson.DefaultRegistry 201 } 202 if val == nil { 203 return nil, ErrNilDocument 204 } 205 if bs, ok := val.([]byte); ok { 206 // Slight optimization so we'll just use MarshalBSON and not go through the codec machinery. 207 val = bson.Raw(bs) 208 } 209 210 // TODO(skriptble): Use a pool of these instead. 211 buf := make([]byte, 0, 256) 212 b, err := bson.MarshalAppendWithRegistry(registry, buf[:0], val) 213 if err != nil { 214 return nil, MarshalError{Value: val, Err: err} 215 } 216 return b, nil 217} 218 219func ensureID(d bsonx.Doc) (bsonx.Doc, interface{}) { 220 var id interface{} 221 222 elem, err := d.LookupElementErr("_id") 223 switch err.(type) { 224 case nil: 225 id = elem 226 default: 227 oid := primitive.NewObjectID() 228 d = append(d, bsonx.Elem{"_id", bsonx.ObjectID(oid)}) 229 id = oid 230 } 231 return d, id 232} 233 234func ensureDollarKey(doc bsonx.Doc) error { 235 if len(doc) == 0 { 236 return errors.New("update document must have at least one element") 237 } 238 if !strings.HasPrefix(doc[0].Key, "$") { 239 return errors.New("update document must contain key beginning with '$'") 240 } 241 return nil 242} 243 244func ensureDollarKeyv2(doc bsoncore.Document) error { 245 firstElem, err := doc.IndexErr(0) 246 if err != nil { 247 return errors.New("update document must have at least one element") 248 } 249 250 if !strings.HasPrefix(firstElem.Key(), "$") { 251 return errors.New("update document must contain key beginning with '$'") 252 } 253 return nil 254} 255 256func transformAggregatePipeline(registry *bsoncodec.Registry, pipeline interface{}) (bsonx.Arr, error) { 257 pipelineArr := bsonx.Arr{} 258 switch t := pipeline.(type) { 259 case bsoncodec.ValueMarshaler: 260 btype, val, err := t.MarshalBSONValue() 261 if err != nil { 262 return nil, err 263 } 264 if btype != bsontype.Array { 265 return nil, fmt.Errorf("ValueMarshaler returned a %v, but was expecting %v", btype, bsontype.Array) 266 } 267 err = pipelineArr.UnmarshalBSONValue(btype, val) 268 if err != nil { 269 return nil, err 270 } 271 default: 272 val := reflect.ValueOf(t) 273 if !val.IsValid() || (val.Kind() != reflect.Slice && val.Kind() != reflect.Array) { 274 return nil, fmt.Errorf("can only transform slices and arrays into aggregation pipelines, but got %v", val.Kind()) 275 } 276 for idx := 0; idx < val.Len(); idx++ { 277 elem, err := transformDocument(registry, val.Index(idx).Interface()) 278 if err != nil { 279 return nil, err 280 } 281 pipelineArr = append(pipelineArr, bsonx.Document(elem)) 282 } 283 } 284 285 return pipelineArr, nil 286} 287 288func transformAggregatePipelinev2(registry *bsoncodec.Registry, pipeline interface{}) (bsoncore.Document, bool, error) { 289 switch t := pipeline.(type) { 290 case bsoncodec.ValueMarshaler: 291 btype, val, err := t.MarshalBSONValue() 292 if err != nil { 293 return nil, false, err 294 } 295 if btype != bsontype.Array { 296 return nil, false, fmt.Errorf("ValueMarshaler returned a %v, but was expecting %v", btype, bsontype.Array) 297 } 298 299 var hasOutputStage bool 300 pipelineDoc := bsoncore.Document(val) 301 if _, err := pipelineDoc.LookupErr("$out"); err == nil { 302 hasOutputStage = true 303 } 304 if _, err := pipelineDoc.LookupErr("$merge"); err == nil { 305 hasOutputStage = true 306 } 307 308 return pipelineDoc, hasOutputStage, nil 309 default: 310 val := reflect.ValueOf(t) 311 if !val.IsValid() || (val.Kind() != reflect.Slice && val.Kind() != reflect.Array) { 312 return nil, false, fmt.Errorf("can only transform slices and arrays into aggregation pipelines, but got %v", val.Kind()) 313 } 314 315 aidx, arr := bsoncore.AppendArrayStart(nil) 316 var hasOutputStage bool 317 valLen := val.Len() 318 for idx := 0; idx < valLen; idx++ { 319 doc, err := transformBsoncoreDocument(registry, val.Index(idx).Interface()) 320 if err != nil { 321 return nil, false, err 322 } 323 324 if idx == valLen-1 { 325 if elem, err := doc.IndexErr(0); err == nil && (elem.Key() == "$out" || elem.Key() == "$merge") { 326 hasOutputStage = true 327 } 328 } 329 arr = bsoncore.AppendDocumentElement(arr, strconv.Itoa(idx), doc) 330 } 331 arr, _ = bsoncore.AppendArrayEnd(arr, aidx) 332 return arr, hasOutputStage, nil 333 } 334} 335 336func transformUpdateValue(registry *bsoncodec.Registry, update interface{}, checkDocDollarKey bool) (bsoncore.Value, error) { 337 var u bsoncore.Value 338 var err error 339 switch t := update.(type) { 340 case nil: 341 return u, ErrNilDocument 342 case primitive.D, bsonx.Doc: 343 u.Type = bsontype.EmbeddedDocument 344 u.Data, err = transformBsoncoreDocument(registry, update) 345 if err != nil { 346 return u, err 347 } 348 349 if checkDocDollarKey { 350 err = ensureDollarKeyv2(u.Data) 351 } 352 return u, err 353 case bson.Raw: 354 u.Type = bsontype.EmbeddedDocument 355 u.Data = t 356 if checkDocDollarKey { 357 err = ensureDollarKeyv2(u.Data) 358 } 359 return u, err 360 case bsoncore.Document: 361 u.Type = bsontype.EmbeddedDocument 362 u.Data = t 363 if checkDocDollarKey { 364 err = ensureDollarKeyv2(u.Data) 365 } 366 return u, err 367 case []byte: 368 u.Type = bsontype.EmbeddedDocument 369 u.Data = t 370 if checkDocDollarKey { 371 err = ensureDollarKeyv2(u.Data) 372 } 373 return u, err 374 case bsoncodec.Marshaler: 375 u.Type = bsontype.EmbeddedDocument 376 u.Data, err = t.MarshalBSON() 377 if err != nil { 378 return u, err 379 } 380 381 if checkDocDollarKey { 382 err = ensureDollarKeyv2(u.Data) 383 } 384 return u, err 385 case bsoncodec.ValueMarshaler: 386 u.Type, u.Data, err = t.MarshalBSONValue() 387 if err != nil { 388 return u, err 389 } 390 if u.Type != bsontype.Array && u.Type != bsontype.EmbeddedDocument { 391 return u, fmt.Errorf("ValueMarshaler returned a %v, but was expecting %v or %v", u.Type, bsontype.Array, bsontype.EmbeddedDocument) 392 } 393 return u, err 394 default: 395 val := reflect.ValueOf(t) 396 if !val.IsValid() { 397 return u, fmt.Errorf("can only transform slices and arrays into update pipelines, but got %v", val.Kind()) 398 } 399 if val.Kind() != reflect.Slice && val.Kind() != reflect.Array { 400 u.Type = bsontype.EmbeddedDocument 401 u.Data, err = transformBsoncoreDocument(registry, update) 402 if err != nil { 403 return u, err 404 } 405 406 if checkDocDollarKey { 407 err = ensureDollarKeyv2(u.Data) 408 } 409 return u, err 410 } 411 412 u.Type = bsontype.Array 413 aidx, arr := bsoncore.AppendArrayStart(nil) 414 valLen := val.Len() 415 for idx := 0; idx < valLen; idx++ { 416 doc, err := transformBsoncoreDocument(registry, val.Index(idx).Interface()) 417 if err != nil { 418 return u, err 419 } 420 421 if err := ensureDollarKeyv2(doc); err != nil { 422 return u, err 423 } 424 425 arr = bsoncore.AppendDocumentElement(arr, strconv.Itoa(idx), doc) 426 } 427 u.Data, _ = bsoncore.AppendArrayEnd(arr, aidx) 428 return u, err 429 } 430} 431 432func transformValue(registry *bsoncodec.Registry, val interface{}) (bsoncore.Value, error) { 433 switch conv := val.(type) { 434 case string: 435 return bsoncore.Value{Type: bsontype.String, Data: bsoncore.AppendString(nil, conv)}, nil 436 default: 437 doc, err := transformBsoncoreDocument(registry, val) 438 if err != nil { 439 return bsoncore.Value{}, err 440 } 441 442 return bsoncore.Value{Type: bsontype.EmbeddedDocument, Data: doc}, nil 443 } 444} 445 446// Build the aggregation pipeline for the CountDocument command. 447func countDocumentsAggregatePipeline(registry *bsoncodec.Registry, filter interface{}, opts *options.CountOptions) (bsoncore.Document, error) { 448 filterDoc, err := transformBsoncoreDocument(registry, filter) 449 if err != nil { 450 return nil, err 451 } 452 453 aidx, arr := bsoncore.AppendArrayStart(nil) 454 didx, arr := bsoncore.AppendDocumentElementStart(arr, strconv.Itoa(0)) 455 arr = bsoncore.AppendDocumentElement(arr, "$match", filterDoc) 456 arr, _ = bsoncore.AppendDocumentEnd(arr, didx) 457 458 index := 1 459 if opts != nil { 460 if opts.Skip != nil { 461 didx, arr = bsoncore.AppendDocumentElementStart(arr, strconv.Itoa(index)) 462 arr = bsoncore.AppendInt64Element(arr, "$skip", *opts.Skip) 463 arr, _ = bsoncore.AppendDocumentEnd(arr, didx) 464 index++ 465 } 466 if opts.Limit != nil { 467 didx, arr = bsoncore.AppendDocumentElementStart(arr, strconv.Itoa(index)) 468 arr = bsoncore.AppendInt64Element(arr, "$limit", *opts.Limit) 469 arr, _ = bsoncore.AppendDocumentEnd(arr, didx) 470 index++ 471 } 472 } 473 474 didx, arr = bsoncore.AppendDocumentElementStart(arr, strconv.Itoa(index)) 475 iidx, arr := bsoncore.AppendDocumentElementStart(arr, "$group") 476 arr = bsoncore.AppendInt32Element(arr, "_id", 1) 477 iiidx, arr := bsoncore.AppendDocumentElementStart(arr, "n") 478 arr = bsoncore.AppendInt32Element(arr, "$sum", 1) 479 arr, _ = bsoncore.AppendDocumentEnd(arr, iiidx) 480 arr, _ = bsoncore.AppendDocumentEnd(arr, iidx) 481 arr, _ = bsoncore.AppendDocumentEnd(arr, didx) 482 483 return bsoncore.AppendArrayEnd(arr, aidx) 484} 485