1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package command
8
9import (
10	"context"
11
12	"go.mongodb.org/mongo-driver/bson"
13	"go.mongodb.org/mongo-driver/mongo/writeconcern"
14	"go.mongodb.org/mongo-driver/x/bsonx"
15	"go.mongodb.org/mongo-driver/x/mongo/driver/session"
16	"go.mongodb.org/mongo-driver/x/network/description"
17	"go.mongodb.org/mongo-driver/x/network/result"
18	"go.mongodb.org/mongo-driver/x/network/wiremessage"
19)
20
// Update represents the update command.
//
// The update command updates a set of documents in the database.
type Update struct {
	// ContinueOnError is forwarded to roundTripBatches by RoundTrip.
	// NOTE(review): presumably controls whether later batches are still
	// attempted after one fails — confirm against roundTripBatches.
	ContinueOnError bool
	// Clock is copied into the Write built for every batch.
	Clock *session.ClusterClock
	// NS supplies the database (NS.DB) and collection (NS.Collection) that
	// each encoded batch targets.
	NS Namespace
	// Docs are the update documents; encode splits them into batches.
	Docs []bsonx.Doc
	// Opts are the command options. "upsert", "collation" and "arrayFilters"
	// are appended to every individual update document; all other options are
	// passed through as command-level options.
	Opts []bsonx.Elem
	// WriteConcern is copied into the Write built for every batch.
	WriteConcern *writeconcern.WriteConcern
	// Session is copied into the Write built for every batch and is also
	// passed to roundTripBatches.
	Session *session.Client

	// batches holds the encoded write batches. It is populated by encode and
	// overwritten by RoundTrip with any leftover batches so they can be retried.
	batches []*WriteBatch
	// result is the reply unmarshalled by decode.
	result result.Update
	// err is the deferred error recorded by Decode/decode and reported by
	// Result and Err.
	err error
}
37
38// Encode will encode this command into a wire message for the given server description.
39func (u *Update) Encode(desc description.SelectedServer) ([]wiremessage.WireMessage, error) {
40	err := u.encode(desc)
41	if err != nil {
42		return nil, err
43	}
44
45	return batchesToWireMessage(u.batches, desc)
46}
47
48func (u *Update) encode(desc description.SelectedServer) error {
49	batches, err := splitBatches(u.Docs, int(desc.MaxBatchCount), int(desc.MaxDocumentSize))
50	if err != nil {
51		return err
52	}
53
54	for _, docs := range batches {
55		cmd, err := u.encodeBatch(docs, desc)
56		if err != nil {
57			return err
58		}
59
60		u.batches = append(u.batches, cmd)
61	}
62
63	return nil
64}
65
66func (u *Update) encodeBatch(docs []bsonx.Doc, desc description.SelectedServer) (*WriteBatch, error) {
67	copyDocs := make([]bsonx.Doc, 0, len(docs)) // copy of all the documents
68	for _, doc := range docs {
69		newDoc := doc.Copy()
70		copyDocs = append(copyDocs, newDoc)
71	}
72
73	var options []bsonx.Elem
74	for _, opt := range u.Opts {
75		switch opt.Key {
76		case "upsert", "collation", "arrayFilters":
77			// options that are encoded on each individual document
78			for idx := range copyDocs {
79				copyDocs[idx] = append(copyDocs[idx], opt)
80			}
81		default:
82			options = append(options, opt)
83		}
84	}
85
86	command, err := encodeBatch(copyDocs, options, UpdateCommand, u.NS.Collection)
87	if err != nil {
88		return nil, err
89	}
90
91	return &WriteBatch{
92		&Write{
93			Clock:        u.Clock,
94			DB:           u.NS.DB,
95			Command:      command,
96			WriteConcern: u.WriteConcern,
97			Session:      u.Session,
98		},
99		len(docs),
100	}, nil
101}
102
103// Decode will decode the wire message using the provided server description. Errors during decoding
104// are deferred until either the Result or Err methods are called.
105func (u *Update) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *Update {
106	rdr, err := (&Write{}).Decode(desc, wm).Result()
107	if err != nil {
108		u.err = err
109		return u
110	}
111	return u.decode(desc, rdr)
112}
113
114func (u *Update) decode(desc description.SelectedServer, rdr bson.Raw) *Update {
115	u.err = bson.Unmarshal(rdr, &u.result)
116	return u
117}
118
119// Result returns the result of a decoded wire message and server description.
120func (u *Update) Result() (result.Update, error) {
121	if u.err != nil {
122		return result.Update{}, u.err
123	}
124	return u.result, nil
125}
126
127// Err returns the error set on this command.
128func (u *Update) Err() error { return u.err }
129
130// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
131func (u *Update) RoundTrip(
132	ctx context.Context,
133	desc description.SelectedServer,
134	rw wiremessage.ReadWriter,
135) (result.Update, error) {
136	if u.batches == nil {
137		err := u.encode(desc)
138		if err != nil {
139			return result.Update{}, err
140		}
141	}
142
143	r, batches, err := roundTripBatches(
144		ctx, desc, rw,
145		u.batches,
146		u.ContinueOnError,
147		u.Session,
148		UpdateCommand,
149	)
150
151	// if there are leftover batches, save them for retry
152	if batches != nil {
153		u.batches = batches
154	}
155
156	if err != nil {
157		return result.Update{}, err
158	}
159
160	return r.(result.Update), nil
161}
162