// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package command

import (
	"context"

	"go.mongodb.org/mongo-driver/bson"
	"go.mongodb.org/mongo-driver/mongo/writeconcern"
	"go.mongodb.org/mongo-driver/x/bsonx"
	"go.mongodb.org/mongo-driver/x/mongo/driver/session"
	"go.mongodb.org/mongo-driver/x/network/description"
	"go.mongodb.org/mongo-driver/x/network/result"
	"go.mongodb.org/mongo-driver/x/network/wiremessage"
)

// Update represents the update command.
//
// The update command updates a set of documents in the database.
type Update struct {
	// ContinueOnError is forwarded to roundTripBatches; presumably it controls
	// whether remaining batches are still sent after a write error (confirm).
	ContinueOnError bool

	// Clock is attached to every Write built for this command.
	Clock *session.ClusterClock

	// NS supplies the target: NS.DB is the database each Write runs against
	// and NS.Collection is the collection named in the update command.
	NS Namespace

	// Docs holds the update statement documents. They are split into batches
	// and copied before per-statement options are appended to them.
	Docs []bsonx.Doc

	// Opts holds command options. "upsert", "collation" and "arrayFilters" are
	// appended to every statement document; all other options are encoded at
	// the command level.
	Opts []bsonx.Elem

	// WriteConcern is attached to every Write built for this command.
	WriteConcern *writeconcern.WriteConcern

	// Session is attached to every Write and passed to roundTripBatches.
	Session *session.Client

	// batches holds the encoded command batches produced by encode. After a
	// RoundTrip, any batches left unsent are stored here so a retry can
	// resume with them.
	batches []*WriteBatch
	// result is the decoded server reply (see decode).
	result result.Update
	// err is a deferred decoding error, surfaced by Result and Err.
	err error
}

// Encode will encode this command into a wire message for the given server description.
// u.Docs is split into batches using desc's MaxBatchCount and MaxDocumentSize, each batch
// is encoded as an update command, and the batches are converted to wire messages.
39func (u *Update) Encode(desc description.SelectedServer) ([]wiremessage.WireMessage, error) { 40 err := u.encode(desc) 41 if err != nil { 42 return nil, err 43 } 44 45 return batchesToWireMessage(u.batches, desc) 46} 47 48func (u *Update) encode(desc description.SelectedServer) error { 49 batches, err := splitBatches(u.Docs, int(desc.MaxBatchCount), int(desc.MaxDocumentSize)) 50 if err != nil { 51 return err 52 } 53 54 for _, docs := range batches { 55 cmd, err := u.encodeBatch(docs, desc) 56 if err != nil { 57 return err 58 } 59 60 u.batches = append(u.batches, cmd) 61 } 62 63 return nil 64} 65 66func (u *Update) encodeBatch(docs []bsonx.Doc, desc description.SelectedServer) (*WriteBatch, error) { 67 copyDocs := make([]bsonx.Doc, 0, len(docs)) // copy of all the documents 68 for _, doc := range docs { 69 newDoc := doc.Copy() 70 copyDocs = append(copyDocs, newDoc) 71 } 72 73 var options []bsonx.Elem 74 for _, opt := range u.Opts { 75 switch opt.Key { 76 case "upsert", "collation", "arrayFilters": 77 // options that are encoded on each individual document 78 for idx := range copyDocs { 79 copyDocs[idx] = append(copyDocs[idx], opt) 80 } 81 default: 82 options = append(options, opt) 83 } 84 } 85 86 command, err := encodeBatch(copyDocs, options, UpdateCommand, u.NS.Collection) 87 if err != nil { 88 return nil, err 89 } 90 91 return &WriteBatch{ 92 &Write{ 93 Clock: u.Clock, 94 DB: u.NS.DB, 95 Command: command, 96 WriteConcern: u.WriteConcern, 97 Session: u.Session, 98 }, 99 len(docs), 100 }, nil 101} 102 103// Decode will decode the wire message using the provided server description. Errors during decoding 104// are deferred until either the Result or Err methods are called. 
105func (u *Update) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *Update { 106 rdr, err := (&Write{}).Decode(desc, wm).Result() 107 if err != nil { 108 u.err = err 109 return u 110 } 111 return u.decode(desc, rdr) 112} 113 114func (u *Update) decode(desc description.SelectedServer, rdr bson.Raw) *Update { 115 u.err = bson.Unmarshal(rdr, &u.result) 116 return u 117} 118 119// Result returns the result of a decoded wire message and server description. 120func (u *Update) Result() (result.Update, error) { 121 if u.err != nil { 122 return result.Update{}, u.err 123 } 124 return u.result, nil 125} 126 127// Err returns the error set on this command. 128func (u *Update) Err() error { return u.err } 129 130// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter. 131func (u *Update) RoundTrip( 132 ctx context.Context, 133 desc description.SelectedServer, 134 rw wiremessage.ReadWriter, 135) (result.Update, error) { 136 if u.batches == nil { 137 err := u.encode(desc) 138 if err != nil { 139 return result.Update{}, err 140 } 141 } 142 143 r, batches, err := roundTripBatches( 144 ctx, desc, rw, 145 u.batches, 146 u.ContinueOnError, 147 u.Session, 148 UpdateCommand, 149 ) 150 151 // if there are leftover batches, save them for retry 152 if batches != nil { 153 u.batches = batches 154 } 155 156 if err != nil { 157 return result.Update{}, err 158 } 159 160 return r.(result.Update), nil 161} 162