// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package command

import (
	"context"

	"go.mongodb.org/mongo-driver/bson"
	"go.mongodb.org/mongo-driver/mongo/readconcern"
	"go.mongodb.org/mongo-driver/mongo/readpref"
	"go.mongodb.org/mongo-driver/mongo/writeconcern"
	"go.mongodb.org/mongo-driver/x/bsonx"
	"go.mongodb.org/mongo-driver/x/mongo/driver/session"
	"go.mongodb.org/mongo-driver/x/network/description"
	"go.mongodb.org/mongo-driver/x/network/result"
	"go.mongodb.org/mongo-driver/x/network/wiremessage"
)

// Aggregate represents the aggregate command.
//
// The aggregate command performs an aggregation. The command document is
// wrapped in a Read command for encoding and execution; the raw server
// response (or a deferred error) is then available through Result and Err.
type Aggregate struct {
	// NS is validated before encoding. NS.Collection becomes the value of the
	// "aggregate" field and NS.DB is the database the command is sent to.
	NS Namespace
	// Pipeline is sent as the "pipeline" array. Its last stage is inspected
	// by HasDollarOut to detect a $out stage.
	Pipeline bsonx.Arr
	// CursorOpts is not read by the encode/decode logic of this command.
	// NOTE(review): presumably consumed by callers when building the cursor — verify.
	CursorOpts []bsonx.Elem
	// Opts are extra command options. A "batchSize" option is moved into the
	// "cursor" subdocument (and dropped when it is zero and the pipeline ends
	// in $out); every other option is appended to the top-level command.
	Opts []bsonx.Elem
	// ReadPref is forwarded unchanged to the underlying Read command.
	ReadPref *readpref.ReadPref
	// WriteConcern is only added to the command when the pipeline ends with
	// $out and the server's maximum wire version is at least 5.
	WriteConcern *writeconcern.WriteConcern
	// ReadConcern is forwarded unchanged to the underlying Read command.
	ReadConcern *readconcern.ReadConcern
	// Clock is forwarded unchanged to the underlying Read command.
	Clock *session.ClusterClock
	// Session is forwarded unchanged to the underlying Read command.
	Session *session.Client

	// result holds the raw server response recorded by decode.
	result bson.Raw
	// err holds a deferred decoding error or a writeConcernError from the response.
	err error
}

// Encode will encode this command into a wire message for the given server description.
42func (a *Aggregate) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) { 43 cmd, err := a.encode(desc) 44 if err != nil { 45 return nil, err 46 } 47 48 return cmd.Encode(desc) 49} 50 51func (a *Aggregate) encode(desc description.SelectedServer) (*Read, error) { 52 if err := a.NS.Validate(); err != nil { 53 return nil, err 54 } 55 56 command := bsonx.Doc{ 57 {"aggregate", bsonx.String(a.NS.Collection)}, 58 {"pipeline", bsonx.Array(a.Pipeline)}, 59 } 60 61 cursor := bsonx.Doc{} 62 hasOutStage := a.HasDollarOut() 63 64 for _, opt := range a.Opts { 65 switch opt.Key { 66 case "batchSize": 67 if opt.Value.Int32() == 0 && hasOutStage { 68 continue 69 } 70 cursor = append(cursor, opt) 71 default: 72 command = append(command, opt) 73 } 74 } 75 command = append(command, bsonx.Elem{"cursor", bsonx.Document(cursor)}) 76 77 // add write concern because it won't be added by the Read command's Encode() 78 if desc.WireVersion.Max >= 5 && hasOutStage && a.WriteConcern != nil { 79 t, data, err := a.WriteConcern.MarshalBSONValue() 80 if err != nil { 81 return nil, err 82 } 83 var xval bsonx.Val 84 err = xval.UnmarshalBSONValue(t, data) 85 if err != nil { 86 return nil, err 87 } 88 command = append(command, bsonx.Elem{Key: "writeConcern", Value: xval}) 89 } 90 91 return &Read{ 92 DB: a.NS.DB, 93 Command: command, 94 ReadPref: a.ReadPref, 95 ReadConcern: a.ReadConcern, 96 Clock: a.Clock, 97 Session: a.Session, 98 }, nil 99} 100 101// HasDollarOut returns true if the Pipeline field contains a $out stage. 102func (a *Aggregate) HasDollarOut() bool { 103 if a.Pipeline == nil { 104 return false 105 } 106 if len(a.Pipeline) == 0 { 107 return false 108 } 109 110 val := a.Pipeline[len(a.Pipeline)-1] 111 112 doc, ok := val.DocumentOK() 113 if !ok || len(doc) != 1 { 114 return false 115 } 116 return doc[0].Key == "$out" 117} 118 119// Decode will decode the wire message using the provided server description. 
Errors during decoding 120// are deferred until either the Result or Err methods are called. 121func (a *Aggregate) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *Aggregate { 122 rdr, err := (&Read{}).Decode(desc, wm).Result() 123 if err != nil { 124 a.err = err 125 return a 126 } 127 128 return a.decode(desc, rdr) 129} 130 131func (a *Aggregate) decode(desc description.SelectedServer, rdr bson.Raw) *Aggregate { 132 a.result = rdr 133 if val, err := rdr.LookupErr("writeConcernError"); err == nil { 134 var wce result.WriteConcernError 135 _ = val.Unmarshal(&wce) 136 a.err = wce 137 } 138 return a 139} 140 141// Result returns the result of a decoded wire message and server description. 142func (a *Aggregate) Result() (bson.Raw, error) { 143 if a.err != nil { 144 return nil, a.err 145 } 146 return a.result, nil 147} 148 149// Err returns the error set on this command. 150func (a *Aggregate) Err() error { return a.err } 151 152// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter. 153func (a *Aggregate) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (bson.Raw, error) { 154 cmd, err := a.encode(desc) 155 if err != nil { 156 return nil, err 157 } 158 159 rdr, err := cmd.RoundTrip(ctx, desc, rw) 160 if err != nil { 161 return nil, err 162 } 163 164 return a.decode(desc, rdr).Result() 165} 166