1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package command
8
9import (
10	"context"
11
12	"go.mongodb.org/mongo-driver/bson"
13	"go.mongodb.org/mongo-driver/mongo/readconcern"
14	"go.mongodb.org/mongo-driver/mongo/readpref"
15	"go.mongodb.org/mongo-driver/mongo/writeconcern"
16	"go.mongodb.org/mongo-driver/x/bsonx"
17	"go.mongodb.org/mongo-driver/x/mongo/driver/session"
18	"go.mongodb.org/mongo-driver/x/network/description"
19	"go.mongodb.org/mongo-driver/x/network/result"
20	"go.mongodb.org/mongo-driver/x/network/wiremessage"
21)
22
// Aggregate represents the aggregate command.
//
// The aggregate command performs an aggregation. The exported fields configure
// the command before it is encoded; the unexported fields hold the outcome of
// decoding the server's reply.
type Aggregate struct {
	// NS is the target namespace: NS.DB selects the database the command runs
	// against and NS.Collection is sent as the "aggregate" value. It is checked
	// with NS.Validate before encoding.
	NS           Namespace
	// Pipeline is sent as the "pipeline" array. HasDollarOut inspects its last
	// element to detect a $out stage.
	Pipeline     bsonx.Arr
	// CursorOpts is not read by the methods in this file.
	// NOTE(review): presumably consumed by whatever builds a cursor from the
	// reply (e.g. for getMore) — confirm against callers.
	CursorOpts   []bsonx.Elem
	// Opts are extra command options. "batchSize" is placed inside the "cursor"
	// subdocument (and dropped when it is zero and the pipeline ends in $out);
	// every other option is appended to the top level of the command.
	Opts         []bsonx.Elem
	// ReadPref is passed through to the underlying Read command.
	ReadPref     *readpref.ReadPref
	// WriteConcern is attached to the command only when the pipeline ends in
	// $out and the server's maximum wire version is at least 5.
	WriteConcern *writeconcern.WriteConcern
	// ReadConcern is passed through to the underlying Read command.
	ReadConcern  *readconcern.ReadConcern
	// Clock is passed through to the underlying Read command.
	Clock        *session.ClusterClock
	// Session is passed through to the underlying Read command.
	Session      *session.Client

	// result holds the raw reply recorded when the reply is decoded.
	result bson.Raw
	// err holds a Decode failure or a writeConcernError found in the reply.
	err    error
}
40
41// Encode will encode this command into a wire message for the given server description.
42func (a *Aggregate) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
43	cmd, err := a.encode(desc)
44	if err != nil {
45		return nil, err
46	}
47
48	return cmd.Encode(desc)
49}
50
51func (a *Aggregate) encode(desc description.SelectedServer) (*Read, error) {
52	if err := a.NS.Validate(); err != nil {
53		return nil, err
54	}
55
56	command := bsonx.Doc{
57		{"aggregate", bsonx.String(a.NS.Collection)},
58		{"pipeline", bsonx.Array(a.Pipeline)},
59	}
60
61	cursor := bsonx.Doc{}
62	hasOutStage := a.HasDollarOut()
63
64	for _, opt := range a.Opts {
65		switch opt.Key {
66		case "batchSize":
67			if opt.Value.Int32() == 0 && hasOutStage {
68				continue
69			}
70			cursor = append(cursor, opt)
71		default:
72			command = append(command, opt)
73		}
74	}
75	command = append(command, bsonx.Elem{"cursor", bsonx.Document(cursor)})
76
77	// add write concern because it won't be added by the Read command's Encode()
78	if desc.WireVersion.Max >= 5 && hasOutStage && a.WriteConcern != nil {
79		t, data, err := a.WriteConcern.MarshalBSONValue()
80		if err != nil {
81			return nil, err
82		}
83		var xval bsonx.Val
84		err = xval.UnmarshalBSONValue(t, data)
85		if err != nil {
86			return nil, err
87		}
88		command = append(command, bsonx.Elem{Key: "writeConcern", Value: xval})
89	}
90
91	return &Read{
92		DB:          a.NS.DB,
93		Command:     command,
94		ReadPref:    a.ReadPref,
95		ReadConcern: a.ReadConcern,
96		Clock:       a.Clock,
97		Session:     a.Session,
98	}, nil
99}
100
101// HasDollarOut returns true if the Pipeline field contains a $out stage.
102func (a *Aggregate) HasDollarOut() bool {
103	if a.Pipeline == nil {
104		return false
105	}
106	if len(a.Pipeline) == 0 {
107		return false
108	}
109
110	val := a.Pipeline[len(a.Pipeline)-1]
111
112	doc, ok := val.DocumentOK()
113	if !ok || len(doc) != 1 {
114		return false
115	}
116	return doc[0].Key == "$out"
117}
118
119// Decode will decode the wire message using the provided server description. Errors during decoding
120// are deferred until either the Result or Err methods are called.
121func (a *Aggregate) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *Aggregate {
122	rdr, err := (&Read{}).Decode(desc, wm).Result()
123	if err != nil {
124		a.err = err
125		return a
126	}
127
128	return a.decode(desc, rdr)
129}
130
131func (a *Aggregate) decode(desc description.SelectedServer, rdr bson.Raw) *Aggregate {
132	a.result = rdr
133	if val, err := rdr.LookupErr("writeConcernError"); err == nil {
134		var wce result.WriteConcernError
135		_ = val.Unmarshal(&wce)
136		a.err = wce
137	}
138	return a
139}
140
141// Result returns the result of a decoded wire message and server description.
142func (a *Aggregate) Result() (bson.Raw, error) {
143	if a.err != nil {
144		return nil, a.err
145	}
146	return a.result, nil
147}
148
149// Err returns the error set on this command.
150func (a *Aggregate) Err() error { return a.err }
151
152// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
153func (a *Aggregate) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (bson.Raw, error) {
154	cmd, err := a.encode(desc)
155	if err != nil {
156		return nil, err
157	}
158
159	rdr, err := cmd.RoundTrip(ctx, desc, rw)
160	if err != nil {
161		return nil, err
162	}
163
164	return a.decode(desc, rdr).Result()
165}
166