1// Copyright (C) MongoDB, Inc. 2017-present. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); you may 4// not use this file except in compliance with the License. You may obtain 5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 6 7package mongo 8 9import ( 10 "context" 11 12 "go.mongodb.org/mongo-driver/bson/bsoncodec" 13 "go.mongodb.org/mongo-driver/mongo/options" 14 "go.mongodb.org/mongo-driver/mongo/writeconcern" 15 "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" 16 "go.mongodb.org/mongo-driver/x/mongo/driver" 17 "go.mongodb.org/mongo-driver/x/mongo/driver/description" 18 "go.mongodb.org/mongo-driver/x/mongo/driver/operation" 19 "go.mongodb.org/mongo-driver/x/mongo/driver/session" 20) 21 22type bulkWriteBatch struct { 23 models []WriteModel 24 canRetry bool 25} 26 27// bulkWrite perfoms a bulkwrite operation 28type bulkWrite struct { 29 ordered *bool 30 bypassDocumentValidation *bool 31 models []WriteModel 32 session *session.Client 33 collection *Collection 34 selector description.ServerSelector 35 writeConcern *writeconcern.WriteConcern 36 result BulkWriteResult 37} 38 39func (bw *bulkWrite) execute(ctx context.Context) error { 40 ordered := true 41 if bw.ordered != nil { 42 ordered = *bw.ordered 43 } 44 45 batches := createBatches(bw.models, ordered) 46 bw.result = BulkWriteResult{ 47 UpsertedIDs: make(map[int64]interface{}), 48 } 49 50 bwErr := BulkWriteException{ 51 WriteErrors: make([]BulkWriteError, 0), 52 } 53 54 var lastErr error 55 var opIndex int64 // the operation index for the upsertedIDs map 56 continueOnError := !ordered 57 for _, batch := range batches { 58 if len(batch.models) == 0 { 59 continue 60 } 61 62 bypassDocValidation := bw.bypassDocumentValidation 63 if bypassDocValidation != nil && !*bypassDocValidation { 64 bypassDocValidation = nil 65 } 66 67 batchRes, batchErr, err := bw.runBatch(ctx, batch) 68 69 bw.mergeResults(batchRes, opIndex) 70 71 bwErr.WriteConcernError = batchErr.WriteConcernError 72 for i := range batchErr.WriteErrors { 73 batchErr.WriteErrors[i].Index = batchErr.WriteErrors[i].Index + int(opIndex) 74 } 75 76 bwErr.WriteErrors = append(bwErr.WriteErrors, batchErr.WriteErrors...) 77 78 if !continueOnError && (err != nil || len(batchErr.WriteErrors) > 0 || batchErr.WriteConcernError != nil) { 79 if err != nil { 80 return err 81 } 82 83 return bwErr 84 } 85 86 if err != nil { 87 lastErr = err 88 } 89 90 opIndex += int64(len(batch.models)) 91 } 92 93 bw.result.MatchedCount -= bw.result.UpsertedCount 94 if lastErr != nil { 95 return lastErr 96 } 97 if len(bwErr.WriteErrors) > 0 || bwErr.WriteConcernError != nil { 98 return bwErr 99 } 100 return nil 101} 102 103func (bw *bulkWrite) runBatch(ctx context.Context, batch bulkWriteBatch) (BulkWriteResult, BulkWriteException, error) { 104 batchRes := BulkWriteResult{ 105 UpsertedIDs: make(map[int64]interface{}), 106 } 107 batchErr := BulkWriteException{} 108 109 var writeErrors []driver.WriteError 110 switch batch.models[0].(type) { 111 case *InsertOneModel: 112 res, err := bw.runInsert(ctx, batch) 113 if err != nil { 114 writeErr, ok := err.(driver.WriteCommandError) 115 if !ok { 116 return BulkWriteResult{}, batchErr, err 117 } 118 writeErrors = writeErr.WriteErrors 119 batchErr.WriteConcernError = convertDriverWriteConcernError(writeErr.WriteConcernError) 120 } 121 batchRes.InsertedCount = int64(res.N) 122 case *DeleteOneModel, *DeleteManyModel: 123 res, err := bw.runDelete(ctx, batch) 124 if err != nil { 125 writeErr, ok := err.(driver.WriteCommandError) 126 if !ok { 127 return BulkWriteResult{}, batchErr, err 128 } 129 writeErrors = writeErr.WriteErrors 130 batchErr.WriteConcernError = convertDriverWriteConcernError(writeErr.WriteConcernError) 131 } 132 batchRes.DeletedCount = int64(res.N) 133 case *ReplaceOneModel, *UpdateOneModel, *UpdateManyModel: 134 res, err := bw.runUpdate(ctx, batch) 135 if err != nil { 136 writeErr, ok := err.(driver.WriteCommandError) 137 if !ok { 138 return BulkWriteResult{}, batchErr, err 139 } 140 writeErrors = writeErr.WriteErrors 141 batchErr.WriteConcernError = convertDriverWriteConcernError(writeErr.WriteConcernError) 142 } 143 batchRes.MatchedCount = int64(res.N) 144 batchRes.ModifiedCount = int64(res.NModified) 145 batchRes.UpsertedCount = int64(len(res.Upserted)) 146 for _, upsert := range res.Upserted { 147 batchRes.UpsertedIDs[upsert.Index] = upsert.ID 148 } 149 } 150 151 batchErr.WriteErrors = make([]BulkWriteError, 0, len(writeErrors)) 152 convWriteErrors := writeErrorsFromDriverWriteErrors(writeErrors) 153 for _, we := range convWriteErrors { 154 batchErr.WriteErrors = append(batchErr.WriteErrors, BulkWriteError{ 155 WriteError: we, 156 Request: batch.models[we.Index], 157 }) 158 } 159 return batchRes, batchErr, nil 160} 161 162func (bw *bulkWrite) runInsert(ctx context.Context, batch bulkWriteBatch) (operation.InsertResult, error) { 163 docs := make([]bsoncore.Document, len(batch.models)) 164 var i int 165 for _, model := range batch.models { 166 converted := model.(*InsertOneModel) 167 doc, _, err := transformAndEnsureIDv2(bw.collection.registry, converted.Document) 168 if err != nil { 169 return operation.InsertResult{}, err 170 } 171 172 docs[i] = doc 173 i++ 174 } 175 176 op := operation.NewInsert(docs...). 177 Session(bw.session).WriteConcern(bw.writeConcern).CommandMonitor(bw.collection.client.monitor). 178 ServerSelector(bw.selector).ClusterClock(bw.collection.client.clock). 179 Database(bw.collection.db.name).Collection(bw.collection.name). 180 Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.crypt) 181 if bw.bypassDocumentValidation != nil && *bw.bypassDocumentValidation { 182 op = op.BypassDocumentValidation(*bw.bypassDocumentValidation) 183 } 184 if bw.ordered != nil { 185 op = op.Ordered(*bw.ordered) 186 } 187 188 retry := driver.RetryNone 189 if bw.collection.client.retryWrites && batch.canRetry { 190 retry = driver.RetryOncePerCommand 191 } 192 op = op.Retry(retry) 193 194 err := op.Execute(ctx) 195 196 return op.Result(), err 197} 198 199func (bw *bulkWrite) runDelete(ctx context.Context, batch bulkWriteBatch) (operation.DeleteResult, error) { 200 docs := make([]bsoncore.Document, len(batch.models)) 201 var i int 202 203 for _, model := range batch.models { 204 var doc bsoncore.Document 205 var err error 206 207 switch converted := model.(type) { 208 case *DeleteOneModel: 209 doc, err = createDeleteDoc(converted.Filter, converted.Collation, true, bw.collection.registry) 210 case *DeleteManyModel: 211 doc, err = createDeleteDoc(converted.Filter, converted.Collation, false, bw.collection.registry) 212 } 213 214 if err != nil { 215 return operation.DeleteResult{}, err 216 } 217 218 docs[i] = doc 219 i++ 220 } 221 222 op := operation.NewDelete(docs...). 223 Session(bw.session).WriteConcern(bw.writeConcern).CommandMonitor(bw.collection.client.monitor). 224 ServerSelector(bw.selector).ClusterClock(bw.collection.client.clock). 225 Database(bw.collection.db.name).Collection(bw.collection.name). 226 Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.crypt) 227 if bw.ordered != nil { 228 op = op.Ordered(*bw.ordered) 229 } 230 retry := driver.RetryNone 231 if bw.collection.client.retryWrites && batch.canRetry { 232 retry = driver.RetryOncePerCommand 233 } 234 op = op.Retry(retry) 235 236 err := op.Execute(ctx) 237 238 return op.Result(), err 239} 240 241func createDeleteDoc(filter interface{}, collation *options.Collation, deleteOne bool, registry *bsoncodec.Registry) (bsoncore.Document, error) { 242 f, err := transformBsoncoreDocument(registry, filter) 243 if err != nil { 244 return nil, err 245 } 246 247 var limit int32 248 if deleteOne { 249 limit = 1 250 } 251 didx, doc := bsoncore.AppendDocumentStart(nil) 252 doc = bsoncore.AppendDocumentElement(doc, "q", f) 253 doc = bsoncore.AppendInt32Element(doc, "limit", limit) 254 if collation != nil { 255 doc = bsoncore.AppendDocumentElement(doc, "collation", collation.ToDocument()) 256 } 257 doc, _ = bsoncore.AppendDocumentEnd(doc, didx) 258 259 return doc, nil 260} 261 262func (bw *bulkWrite) runUpdate(ctx context.Context, batch bulkWriteBatch) (operation.UpdateResult, error) { 263 docs := make([]bsoncore.Document, len(batch.models)) 264 for i, model := range batch.models { 265 var doc bsoncore.Document 266 var err error 267 268 switch converted := model.(type) { 269 case *ReplaceOneModel: 270 doc, err = createUpdateDoc(converted.Filter, converted.Replacement, nil, converted.Collation, converted.Upsert, false, 271 bw.collection.registry) 272 case *UpdateOneModel: 273 doc, err = createUpdateDoc(converted.Filter, converted.Update, converted.ArrayFilters, converted.Collation, converted.Upsert, false, 274 bw.collection.registry) 275 case *UpdateManyModel: 276 doc, err = createUpdateDoc(converted.Filter, converted.Update, converted.ArrayFilters, converted.Collation, converted.Upsert, true, 277 bw.collection.registry) 278 } 279 if err != nil { 280 return operation.UpdateResult{}, err 281 } 282 283 docs[i] = doc 284 } 285 286 op := operation.NewUpdate(docs...). 287 Session(bw.session).WriteConcern(bw.writeConcern).CommandMonitor(bw.collection.client.monitor). 288 ServerSelector(bw.selector).ClusterClock(bw.collection.client.clock). 289 Database(bw.collection.db.name).Collection(bw.collection.name). 290 Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.crypt) 291 if bw.ordered != nil { 292 op = op.Ordered(*bw.ordered) 293 } 294 if bw.bypassDocumentValidation != nil && *bw.bypassDocumentValidation { 295 op = op.BypassDocumentValidation(*bw.bypassDocumentValidation) 296 } 297 retry := driver.RetryNone 298 if bw.collection.client.retryWrites && batch.canRetry { 299 retry = driver.RetryOncePerCommand 300 } 301 op = op.Retry(retry) 302 303 err := op.Execute(ctx) 304 305 return op.Result(), err 306} 307func createUpdateDoc( 308 filter interface{}, 309 update interface{}, 310 arrayFilters *options.ArrayFilters, 311 collation *options.Collation, 312 upsert *bool, 313 multi bool, 314 registry *bsoncodec.Registry, 315) (bsoncore.Document, error) { 316 f, err := transformBsoncoreDocument(registry, filter) 317 if err != nil { 318 return nil, err 319 } 320 321 uidx, updateDoc := bsoncore.AppendDocumentStart(nil) 322 updateDoc = bsoncore.AppendDocumentElement(updateDoc, "q", f) 323 324 u, err := transformUpdateValue(registry, update, false) 325 if err != nil { 326 return nil, err 327 } 328 updateDoc = bsoncore.AppendValueElement(updateDoc, "u", u) 329 330 updateDoc = bsoncore.AppendBooleanElement(updateDoc, "multi", multi) 331 332 if arrayFilters != nil { 333 arr, err := arrayFilters.ToArrayDocument() 334 if err != nil { 335 return nil, err 336 } 337 updateDoc = bsoncore.AppendArrayElement(updateDoc, "arrayFilters", arr) 338 } 339 340 if collation != nil { 341 updateDoc = bsoncore.AppendDocumentElement(updateDoc, "collation", bsoncore.Document(collation.ToDocument())) 342 } 343 344 if upsert != nil { 345 updateDoc = bsoncore.AppendBooleanElement(updateDoc, "upsert", *upsert) 346 } 347 updateDoc, _ = bsoncore.AppendDocumentEnd(updateDoc, uidx) 348 349 return updateDoc, nil 350} 351 352func createBatches(models []WriteModel, ordered bool) []bulkWriteBatch { 353 if ordered { 354 return createOrderedBatches(models) 355 } 356 357 batches := make([]bulkWriteBatch, 5) 358 batches[insertCommand].canRetry = true 359 batches[deleteOneCommand].canRetry = true 360 batches[updateOneCommand].canRetry = true 361 362 // TODO(GODRIVER-1157): fix batching once operation retryability is fixed 363 for _, model := range models { 364 switch model.(type) { 365 case *InsertOneModel: 366 batches[insertCommand].models = append(batches[insertCommand].models, model) 367 case *DeleteOneModel: 368 batches[deleteOneCommand].models = append(batches[deleteOneCommand].models, model) 369 case *DeleteManyModel: 370 batches[deleteManyCommand].models = append(batches[deleteManyCommand].models, model) 371 case *ReplaceOneModel, *UpdateOneModel: 372 batches[updateOneCommand].models = append(batches[updateOneCommand].models, model) 373 case *UpdateManyModel: 374 batches[updateManyCommand].models = append(batches[updateManyCommand].models, model) 375 } 376 } 377 378 return batches 379} 380 381func createOrderedBatches(models []WriteModel) []bulkWriteBatch { 382 var batches []bulkWriteBatch 383 var prevKind writeCommandKind = -1 384 i := -1 // batch index 385 386 for _, model := range models { 387 var createNewBatch bool 388 var canRetry bool 389 var newKind writeCommandKind 390 391 // TODO(GODRIVER-1157): fix batching once operation retryability is fixed 392 switch model.(type) { 393 case *InsertOneModel: 394 createNewBatch = prevKind != insertCommand 395 canRetry = true 396 newKind = insertCommand 397 case *DeleteOneModel: 398 createNewBatch = prevKind != deleteOneCommand 399 canRetry = true 400 newKind = deleteOneCommand 401 case *DeleteManyModel: 402 createNewBatch = prevKind != deleteManyCommand 403 newKind = deleteManyCommand 404 case *ReplaceOneModel, *UpdateOneModel: 405 createNewBatch = prevKind != updateOneCommand 406 canRetry = true 407 newKind = updateOneCommand 408 case *UpdateManyModel: 409 createNewBatch = prevKind != updateManyCommand 410 newKind = updateManyCommand 411 } 412 413 if createNewBatch { 414 batches = append(batches, bulkWriteBatch{ 415 models: []WriteModel{model}, 416 canRetry: canRetry, 417 }) 418 i++ 419 } else { 420 batches[i].models = append(batches[i].models, model) 421 if !canRetry { 422 batches[i].canRetry = false // don't make it true if it was already false 423 } 424 } 425 426 prevKind = newKind 427 } 428 429 return batches 430} 431 432func (bw *bulkWrite) mergeResults(newResult BulkWriteResult, opIndex int64) { 433 bw.result.InsertedCount += newResult.InsertedCount 434 bw.result.MatchedCount += newResult.MatchedCount 435 bw.result.ModifiedCount += newResult.ModifiedCount 436 bw.result.DeletedCount += newResult.DeletedCount 437 bw.result.UpsertedCount += newResult.UpsertedCount 438 439 for index, upsertID := range newResult.UpsertedIDs { 440 bw.result.UpsertedIDs[index+opIndex] = upsertID 441 } 442} 443 444// WriteCommandKind is the type of command represented by a Write 445type writeCommandKind int8 446 447// These constants represent the valid types of write commands. 448const ( 449 insertCommand writeCommandKind = iota 450 updateOneCommand 451 updateManyCommand 452 deleteOneCommand 453 deleteManyCommand 454) 455