1// Copyright (C) MongoDB, Inc. 2017-present. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); you may 4// not use this file except in compliance with the License. You may obtain 5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 6 7package mongo 8 9import ( 10 "context" 11 12 "go.mongodb.org/mongo-driver/bson/bsoncodec" 13 "go.mongodb.org/mongo-driver/mongo/options" 14 "go.mongodb.org/mongo-driver/mongo/writeconcern" 15 "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" 16 "go.mongodb.org/mongo-driver/x/mongo/driver" 17 "go.mongodb.org/mongo-driver/x/mongo/driver/description" 18 "go.mongodb.org/mongo-driver/x/mongo/driver/operation" 19 "go.mongodb.org/mongo-driver/x/mongo/driver/session" 20) 21 22type bulkWriteBatch struct { 23 models []WriteModel 24 canRetry bool 25 indexes []int 26} 27 28// bulkWrite perfoms a bulkwrite operation 29type bulkWrite struct { 30 ordered *bool 31 bypassDocumentValidation *bool 32 models []WriteModel 33 session *session.Client 34 collection *Collection 35 selector description.ServerSelector 36 writeConcern *writeconcern.WriteConcern 37 result BulkWriteResult 38} 39 40func (bw *bulkWrite) execute(ctx context.Context) error { 41 ordered := true 42 if bw.ordered != nil { 43 ordered = *bw.ordered 44 } 45 46 batches := createBatches(bw.models, ordered) 47 bw.result = BulkWriteResult{ 48 UpsertedIDs: make(map[int64]interface{}), 49 } 50 51 bwErr := BulkWriteException{ 52 WriteErrors: make([]BulkWriteError, 0), 53 } 54 55 var lastErr error 56 continueOnError := !ordered 57 for _, batch := range batches { 58 if len(batch.models) == 0 { 59 continue 60 } 61 62 bypassDocValidation := bw.bypassDocumentValidation 63 if bypassDocValidation != nil && !*bypassDocValidation { 64 bypassDocValidation = nil 65 } 66 67 batchRes, batchErr, err := bw.runBatch(ctx, batch) 68 69 bw.mergeResults(batchRes) 70 71 bwErr.WriteConcernError = batchErr.WriteConcernError 72 bwErr.Labels = append(bwErr.Labels, batchErr.Labels...) 73 74 bwErr.WriteErrors = append(bwErr.WriteErrors, batchErr.WriteErrors...) 75 76 commandErrorOccurred := err != nil && err != driver.ErrUnacknowledgedWrite 77 writeErrorOccurred := len(batchErr.WriteErrors) > 0 || batchErr.WriteConcernError != nil 78 if !continueOnError && (commandErrorOccurred || writeErrorOccurred) { 79 if err != nil { 80 return err 81 } 82 83 return bwErr 84 } 85 86 if err != nil { 87 lastErr = err 88 } 89 } 90 91 bw.result.MatchedCount -= bw.result.UpsertedCount 92 if lastErr != nil { 93 _, lastErr = processWriteError(lastErr) 94 return lastErr 95 } 96 if len(bwErr.WriteErrors) > 0 || bwErr.WriteConcernError != nil { 97 return bwErr 98 } 99 return nil 100} 101 102func (bw *bulkWrite) runBatch(ctx context.Context, batch bulkWriteBatch) (BulkWriteResult, BulkWriteException, error) { 103 batchRes := BulkWriteResult{ 104 UpsertedIDs: make(map[int64]interface{}), 105 } 106 batchErr := BulkWriteException{} 107 108 var writeErrors []driver.WriteError 109 switch batch.models[0].(type) { 110 case *InsertOneModel: 111 res, err := bw.runInsert(ctx, batch) 112 if err != nil { 113 writeErr, ok := err.(driver.WriteCommandError) 114 if !ok { 115 return BulkWriteResult{}, batchErr, err 116 } 117 writeErrors = writeErr.WriteErrors 118 batchErr.Labels = writeErr.Labels 119 batchErr.WriteConcernError = convertDriverWriteConcernError(writeErr.WriteConcernError) 120 } 121 batchRes.InsertedCount = int64(res.N) 122 case *DeleteOneModel, *DeleteManyModel: 123 res, err := bw.runDelete(ctx, batch) 124 if err != nil { 125 writeErr, ok := err.(driver.WriteCommandError) 126 if !ok { 127 return BulkWriteResult{}, batchErr, err 128 } 129 writeErrors = writeErr.WriteErrors 130 batchErr.Labels = writeErr.Labels 131 batchErr.WriteConcernError = convertDriverWriteConcernError(writeErr.WriteConcernError) 132 } 133 batchRes.DeletedCount = int64(res.N) 134 case *ReplaceOneModel, *UpdateOneModel, *UpdateManyModel: 135 res, err := bw.runUpdate(ctx, batch) 136 if err != nil { 137 writeErr, ok := err.(driver.WriteCommandError) 138 if !ok { 139 return BulkWriteResult{}, batchErr, err 140 } 141 writeErrors = writeErr.WriteErrors 142 batchErr.Labels = writeErr.Labels 143 batchErr.WriteConcernError = convertDriverWriteConcernError(writeErr.WriteConcernError) 144 } 145 batchRes.MatchedCount = int64(res.N) 146 batchRes.ModifiedCount = int64(res.NModified) 147 batchRes.UpsertedCount = int64(len(res.Upserted)) 148 for _, upsert := range res.Upserted { 149 batchRes.UpsertedIDs[int64(batch.indexes[upsert.Index])] = upsert.ID 150 } 151 } 152 153 batchErr.WriteErrors = make([]BulkWriteError, 0, len(writeErrors)) 154 convWriteErrors := writeErrorsFromDriverWriteErrors(writeErrors) 155 for _, we := range convWriteErrors { 156 request := batch.models[we.Index] 157 we.Index = batch.indexes[we.Index] 158 batchErr.WriteErrors = append(batchErr.WriteErrors, BulkWriteError{ 159 WriteError: we, 160 Request: request, 161 }) 162 } 163 return batchRes, batchErr, nil 164} 165 166func (bw *bulkWrite) runInsert(ctx context.Context, batch bulkWriteBatch) (operation.InsertResult, error) { 167 docs := make([]bsoncore.Document, len(batch.models)) 168 var i int 169 for _, model := range batch.models { 170 converted := model.(*InsertOneModel) 171 doc, _, err := transformAndEnsureIDv2(bw.collection.registry, converted.Document) 172 if err != nil { 173 return operation.InsertResult{}, err 174 } 175 176 docs[i] = doc 177 i++ 178 } 179 180 op := operation.NewInsert(docs...). 181 Session(bw.session).WriteConcern(bw.writeConcern).CommandMonitor(bw.collection.client.monitor). 182 ServerSelector(bw.selector).ClusterClock(bw.collection.client.clock). 183 Database(bw.collection.db.name).Collection(bw.collection.name). 184 Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.crypt) 185 if bw.bypassDocumentValidation != nil && *bw.bypassDocumentValidation { 186 op = op.BypassDocumentValidation(*bw.bypassDocumentValidation) 187 } 188 if bw.ordered != nil { 189 op = op.Ordered(*bw.ordered) 190 } 191 192 retry := driver.RetryNone 193 if bw.collection.client.retryWrites && batch.canRetry { 194 retry = driver.RetryOncePerCommand 195 } 196 op = op.Retry(retry) 197 198 err := op.Execute(ctx) 199 200 return op.Result(), err 201} 202 203func (bw *bulkWrite) runDelete(ctx context.Context, batch bulkWriteBatch) (operation.DeleteResult, error) { 204 docs := make([]bsoncore.Document, len(batch.models)) 205 var i int 206 var hasHint bool 207 208 for _, model := range batch.models { 209 var doc bsoncore.Document 210 var err error 211 212 switch converted := model.(type) { 213 case *DeleteOneModel: 214 doc, err = createDeleteDoc(converted.Filter, converted.Collation, converted.Hint, true, bw.collection.registry) 215 hasHint = hasHint || (converted.Hint != nil) 216 case *DeleteManyModel: 217 doc, err = createDeleteDoc(converted.Filter, converted.Collation, converted.Hint, false, bw.collection.registry) 218 hasHint = hasHint || (converted.Hint != nil) 219 } 220 221 if err != nil { 222 return operation.DeleteResult{}, err 223 } 224 225 docs[i] = doc 226 i++ 227 } 228 229 op := operation.NewDelete(docs...). 230 Session(bw.session).WriteConcern(bw.writeConcern).CommandMonitor(bw.collection.client.monitor). 231 ServerSelector(bw.selector).ClusterClock(bw.collection.client.clock). 232 Database(bw.collection.db.name).Collection(bw.collection.name). 233 Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.crypt).Hint(hasHint) 234 if bw.ordered != nil { 235 op = op.Ordered(*bw.ordered) 236 } 237 retry := driver.RetryNone 238 if bw.collection.client.retryWrites && batch.canRetry { 239 retry = driver.RetryOncePerCommand 240 } 241 op = op.Retry(retry) 242 243 err := op.Execute(ctx) 244 245 return op.Result(), err 246} 247 248func createDeleteDoc(filter interface{}, collation *options.Collation, hint interface{}, deleteOne bool, 249 registry *bsoncodec.Registry) (bsoncore.Document, error) { 250 251 f, err := transformBsoncoreDocument(registry, filter) 252 if err != nil { 253 return nil, err 254 } 255 256 var limit int32 257 if deleteOne { 258 limit = 1 259 } 260 didx, doc := bsoncore.AppendDocumentStart(nil) 261 doc = bsoncore.AppendDocumentElement(doc, "q", f) 262 doc = bsoncore.AppendInt32Element(doc, "limit", limit) 263 if collation != nil { 264 doc = bsoncore.AppendDocumentElement(doc, "collation", collation.ToDocument()) 265 } 266 if hint != nil { 267 hintVal, err := transformValue(registry, hint) 268 if err != nil { 269 return nil, err 270 } 271 doc = bsoncore.AppendValueElement(doc, "hint", hintVal) 272 } 273 doc, _ = bsoncore.AppendDocumentEnd(doc, didx) 274 275 return doc, nil 276} 277 278func (bw *bulkWrite) runUpdate(ctx context.Context, batch bulkWriteBatch) (operation.UpdateResult, error) { 279 docs := make([]bsoncore.Document, len(batch.models)) 280 var hasHint bool 281 var hasArrayFilters bool 282 for i, model := range batch.models { 283 var doc bsoncore.Document 284 var err error 285 286 switch converted := model.(type) { 287 case *ReplaceOneModel: 288 doc, err = createUpdateDoc(converted.Filter, converted.Replacement, converted.Hint, nil, converted.Collation, converted.Upsert, false, 289 false, bw.collection.registry) 290 hasHint = hasHint || (converted.Hint != nil) 291 case *UpdateOneModel: 292 doc, err = createUpdateDoc(converted.Filter, converted.Update, converted.Hint, converted.ArrayFilters, converted.Collation, converted.Upsert, false, 293 true, bw.collection.registry) 294 hasHint = hasHint || (converted.Hint != nil) 295 hasArrayFilters = hasArrayFilters || (converted.ArrayFilters != nil) 296 case *UpdateManyModel: 297 doc, err = createUpdateDoc(converted.Filter, converted.Update, converted.Hint, converted.ArrayFilters, converted.Collation, converted.Upsert, true, 298 true, bw.collection.registry) 299 hasHint = hasHint || (converted.Hint != nil) 300 hasArrayFilters = hasArrayFilters || (converted.ArrayFilters != nil) 301 } 302 if err != nil { 303 return operation.UpdateResult{}, err 304 } 305 306 docs[i] = doc 307 } 308 309 op := operation.NewUpdate(docs...). 310 Session(bw.session).WriteConcern(bw.writeConcern).CommandMonitor(bw.collection.client.monitor). 311 ServerSelector(bw.selector).ClusterClock(bw.collection.client.clock). 312 Database(bw.collection.db.name).Collection(bw.collection.name). 313 Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.crypt).Hint(hasHint). 314 ArrayFilters(hasArrayFilters) 315 if bw.ordered != nil { 316 op = op.Ordered(*bw.ordered) 317 } 318 if bw.bypassDocumentValidation != nil && *bw.bypassDocumentValidation { 319 op = op.BypassDocumentValidation(*bw.bypassDocumentValidation) 320 } 321 retry := driver.RetryNone 322 if bw.collection.client.retryWrites && batch.canRetry { 323 retry = driver.RetryOncePerCommand 324 } 325 op = op.Retry(retry) 326 327 err := op.Execute(ctx) 328 329 return op.Result(), err 330} 331func createUpdateDoc( 332 filter interface{}, 333 update interface{}, 334 hint interface{}, 335 arrayFilters *options.ArrayFilters, 336 collation *options.Collation, 337 upsert *bool, 338 multi bool, 339 checkDollarKey bool, 340 registry *bsoncodec.Registry, 341) (bsoncore.Document, error) { 342 f, err := transformBsoncoreDocument(registry, filter) 343 if err != nil { 344 return nil, err 345 } 346 347 uidx, updateDoc := bsoncore.AppendDocumentStart(nil) 348 updateDoc = bsoncore.AppendDocumentElement(updateDoc, "q", f) 349 350 u, err := transformUpdateValue(registry, update, checkDollarKey) 351 if err != nil { 352 return nil, err 353 } 354 355 updateDoc = bsoncore.AppendValueElement(updateDoc, "u", u) 356 357 if multi { 358 updateDoc = bsoncore.AppendBooleanElement(updateDoc, "multi", multi) 359 } 360 361 if arrayFilters != nil { 362 arr, err := arrayFilters.ToArrayDocument() 363 if err != nil { 364 return nil, err 365 } 366 updateDoc = bsoncore.AppendArrayElement(updateDoc, "arrayFilters", arr) 367 } 368 369 if collation != nil { 370 updateDoc = bsoncore.AppendDocumentElement(updateDoc, "collation", bsoncore.Document(collation.ToDocument())) 371 } 372 373 if upsert != nil { 374 updateDoc = bsoncore.AppendBooleanElement(updateDoc, "upsert", *upsert) 375 } 376 377 if hint != nil { 378 hintVal, err := transformValue(registry, hint) 379 if err != nil { 380 return nil, err 381 } 382 updateDoc = bsoncore.AppendValueElement(updateDoc, "hint", hintVal) 383 } 384 385 updateDoc, _ = bsoncore.AppendDocumentEnd(updateDoc, uidx) 386 387 return updateDoc, nil 388} 389 390func createBatches(models []WriteModel, ordered bool) []bulkWriteBatch { 391 if ordered { 392 return createOrderedBatches(models) 393 } 394 395 batches := make([]bulkWriteBatch, 5) 396 batches[insertCommand].canRetry = true 397 batches[deleteOneCommand].canRetry = true 398 batches[updateOneCommand].canRetry = true 399 400 // TODO(GODRIVER-1157): fix batching once operation retryability is fixed 401 for i, model := range models { 402 switch model.(type) { 403 case *InsertOneModel: 404 batches[insertCommand].models = append(batches[insertCommand].models, model) 405 batches[insertCommand].indexes = append(batches[insertCommand].indexes, i) 406 case *DeleteOneModel: 407 batches[deleteOneCommand].models = append(batches[deleteOneCommand].models, model) 408 batches[deleteOneCommand].indexes = append(batches[deleteOneCommand].indexes, i) 409 case *DeleteManyModel: 410 batches[deleteManyCommand].models = append(batches[deleteManyCommand].models, model) 411 batches[deleteManyCommand].indexes = append(batches[deleteManyCommand].indexes, i) 412 case *ReplaceOneModel, *UpdateOneModel: 413 batches[updateOneCommand].models = append(batches[updateOneCommand].models, model) 414 batches[updateOneCommand].indexes = append(batches[updateOneCommand].indexes, i) 415 case *UpdateManyModel: 416 batches[updateManyCommand].models = append(batches[updateManyCommand].models, model) 417 batches[updateManyCommand].indexes = append(batches[updateManyCommand].indexes, i) 418 } 419 } 420 421 return batches 422} 423 424func createOrderedBatches(models []WriteModel) []bulkWriteBatch { 425 var batches []bulkWriteBatch 426 var prevKind writeCommandKind = -1 427 i := -1 // batch index 428 429 for ind, model := range models { 430 var createNewBatch bool 431 var canRetry bool 432 var newKind writeCommandKind 433 434 // TODO(GODRIVER-1157): fix batching once operation retryability is fixed 435 switch model.(type) { 436 case *InsertOneModel: 437 createNewBatch = prevKind != insertCommand 438 canRetry = true 439 newKind = insertCommand 440 case *DeleteOneModel: 441 createNewBatch = prevKind != deleteOneCommand 442 canRetry = true 443 newKind = deleteOneCommand 444 case *DeleteManyModel: 445 createNewBatch = prevKind != deleteManyCommand 446 newKind = deleteManyCommand 447 case *ReplaceOneModel, *UpdateOneModel: 448 createNewBatch = prevKind != updateOneCommand 449 canRetry = true 450 newKind = updateOneCommand 451 case *UpdateManyModel: 452 createNewBatch = prevKind != updateManyCommand 453 newKind = updateManyCommand 454 } 455 456 if createNewBatch { 457 batches = append(batches, bulkWriteBatch{ 458 models: []WriteModel{model}, 459 canRetry: canRetry, 460 indexes: []int{ind}, 461 }) 462 i++ 463 } else { 464 batches[i].models = append(batches[i].models, model) 465 if !canRetry { 466 batches[i].canRetry = false // don't make it true if it was already false 467 } 468 batches[i].indexes = append(batches[i].indexes, ind) 469 } 470 471 prevKind = newKind 472 } 473 474 return batches 475} 476 477func (bw *bulkWrite) mergeResults(newResult BulkWriteResult) { 478 bw.result.InsertedCount += newResult.InsertedCount 479 bw.result.MatchedCount += newResult.MatchedCount 480 bw.result.ModifiedCount += newResult.ModifiedCount 481 bw.result.DeletedCount += newResult.DeletedCount 482 bw.result.UpsertedCount += newResult.UpsertedCount 483 484 for index, upsertID := range newResult.UpsertedIDs { 485 bw.result.UpsertedIDs[index] = upsertID 486 } 487} 488 489// WriteCommandKind is the type of command represented by a Write 490type writeCommandKind int8 491 492// These constants represent the valid types of write commands. 493const ( 494 insertCommand writeCommandKind = iota 495 updateOneCommand 496 updateManyCommand 497 deleteOneCommand 498 deleteManyCommand 499) 500