1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package mongo
8
9import (
10	"context"
11
12	"go.mongodb.org/mongo-driver/bson/bsoncodec"
13	"go.mongodb.org/mongo-driver/mongo/options"
14	"go.mongodb.org/mongo-driver/mongo/writeconcern"
15	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
16	"go.mongodb.org/mongo-driver/x/mongo/driver"
17	"go.mongodb.org/mongo-driver/x/mongo/driver/description"
18	"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
19	"go.mongodb.org/mongo-driver/x/mongo/driver/session"
20)
21
// bulkWriteBatch is a group of write models that runBatch sends to the server as a
// single write command.
type bulkWriteBatch struct {
	// models are the operations in this batch. runBatch picks the command from the
	// type of models[0], so every model is expected to map to the same writeCommandKind.
	models   []WriteModel
	// canRetry marks the batch as eligible for a retry; the run* helpers only enable
	// retries when the client's retryWrites setting is also on.
	canRetry bool
}
26
// bulkWrite performs a bulk write operation: execute splits models into batches, runs
// each batch as an insert, update, or delete command against collection, and
// accumulates the combined outcome in result.
type bulkWrite struct {
	ordered                  *bool // nil is treated as true by execute
	bypassDocumentValidation *bool // sent with insert/update commands only when non-nil and true
	models                   []WriteModel
	session                  *session.Client
	collection               *Collection
	selector                 description.ServerSelector
	writeConcern             *writeconcern.WriteConcern
	result                   BulkWriteResult // (re)initialized and filled in by execute
}
38
39func (bw *bulkWrite) execute(ctx context.Context) error {
40	ordered := true
41	if bw.ordered != nil {
42		ordered = *bw.ordered
43	}
44
45	batches := createBatches(bw.models, ordered)
46	bw.result = BulkWriteResult{
47		UpsertedIDs: make(map[int64]interface{}),
48	}
49
50	bwErr := BulkWriteException{
51		WriteErrors: make([]BulkWriteError, 0),
52	}
53
54	var lastErr error
55	var opIndex int64 // the operation index for the upsertedIDs map
56	continueOnError := !ordered
57	for _, batch := range batches {
58		if len(batch.models) == 0 {
59			continue
60		}
61
62		bypassDocValidation := bw.bypassDocumentValidation
63		if bypassDocValidation != nil && !*bypassDocValidation {
64			bypassDocValidation = nil
65		}
66
67		batchRes, batchErr, err := bw.runBatch(ctx, batch)
68
69		bw.mergeResults(batchRes, opIndex)
70
71		bwErr.WriteConcernError = batchErr.WriteConcernError
72		for i := range batchErr.WriteErrors {
73			batchErr.WriteErrors[i].Index = batchErr.WriteErrors[i].Index + int(opIndex)
74		}
75
76		bwErr.WriteErrors = append(bwErr.WriteErrors, batchErr.WriteErrors...)
77
78		if !continueOnError && (err != nil || len(batchErr.WriteErrors) > 0 || batchErr.WriteConcernError != nil) {
79			if err != nil {
80				return err
81			}
82
83			return bwErr
84		}
85
86		if err != nil {
87			lastErr = err
88		}
89
90		opIndex += int64(len(batch.models))
91	}
92
93	bw.result.MatchedCount -= bw.result.UpsertedCount
94	if lastErr != nil {
95		return lastErr
96	}
97	if len(bwErr.WriteErrors) > 0 || bwErr.WriteConcernError != nil {
98		return bwErr
99	}
100	return nil
101}
102
103func (bw *bulkWrite) runBatch(ctx context.Context, batch bulkWriteBatch) (BulkWriteResult, BulkWriteException, error) {
104	batchRes := BulkWriteResult{
105		UpsertedIDs: make(map[int64]interface{}),
106	}
107	batchErr := BulkWriteException{}
108
109	var writeErrors []driver.WriteError
110	switch batch.models[0].(type) {
111	case *InsertOneModel:
112		res, err := bw.runInsert(ctx, batch)
113		if err != nil {
114			writeErr, ok := err.(driver.WriteCommandError)
115			if !ok {
116				return BulkWriteResult{}, batchErr, err
117			}
118			writeErrors = writeErr.WriteErrors
119			batchErr.WriteConcernError = convertDriverWriteConcernError(writeErr.WriteConcernError)
120		}
121		batchRes.InsertedCount = int64(res.N)
122	case *DeleteOneModel, *DeleteManyModel:
123		res, err := bw.runDelete(ctx, batch)
124		if err != nil {
125			writeErr, ok := err.(driver.WriteCommandError)
126			if !ok {
127				return BulkWriteResult{}, batchErr, err
128			}
129			writeErrors = writeErr.WriteErrors
130			batchErr.WriteConcernError = convertDriverWriteConcernError(writeErr.WriteConcernError)
131		}
132		batchRes.DeletedCount = int64(res.N)
133	case *ReplaceOneModel, *UpdateOneModel, *UpdateManyModel:
134		res, err := bw.runUpdate(ctx, batch)
135		if err != nil {
136			writeErr, ok := err.(driver.WriteCommandError)
137			if !ok {
138				return BulkWriteResult{}, batchErr, err
139			}
140			writeErrors = writeErr.WriteErrors
141			batchErr.WriteConcernError = convertDriverWriteConcernError(writeErr.WriteConcernError)
142		}
143		batchRes.MatchedCount = int64(res.N)
144		batchRes.ModifiedCount = int64(res.NModified)
145		batchRes.UpsertedCount = int64(len(res.Upserted))
146		for _, upsert := range res.Upserted {
147			batchRes.UpsertedIDs[upsert.Index] = upsert.ID
148		}
149	}
150
151	batchErr.WriteErrors = make([]BulkWriteError, 0, len(writeErrors))
152	convWriteErrors := writeErrorsFromDriverWriteErrors(writeErrors)
153	for _, we := range convWriteErrors {
154		batchErr.WriteErrors = append(batchErr.WriteErrors, BulkWriteError{
155			WriteError: we,
156			Request:    batch.models[we.Index],
157		})
158	}
159	return batchRes, batchErr, nil
160}
161
162func (bw *bulkWrite) runInsert(ctx context.Context, batch bulkWriteBatch) (operation.InsertResult, error) {
163	docs := make([]bsoncore.Document, len(batch.models))
164	var i int
165	for _, model := range batch.models {
166		converted := model.(*InsertOneModel)
167		doc, _, err := transformAndEnsureIDv2(bw.collection.registry, converted.Document)
168		if err != nil {
169			return operation.InsertResult{}, err
170		}
171
172		docs[i] = doc
173		i++
174	}
175
176	op := operation.NewInsert(docs...).
177		Session(bw.session).WriteConcern(bw.writeConcern).CommandMonitor(bw.collection.client.monitor).
178		ServerSelector(bw.selector).ClusterClock(bw.collection.client.clock).
179		Database(bw.collection.db.name).Collection(bw.collection.name).
180		Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.crypt)
181	if bw.bypassDocumentValidation != nil && *bw.bypassDocumentValidation {
182		op = op.BypassDocumentValidation(*bw.bypassDocumentValidation)
183	}
184	if bw.ordered != nil {
185		op = op.Ordered(*bw.ordered)
186	}
187
188	retry := driver.RetryNone
189	if bw.collection.client.retryWrites && batch.canRetry {
190		retry = driver.RetryOncePerCommand
191	}
192	op = op.Retry(retry)
193
194	err := op.Execute(ctx)
195
196	return op.Result(), err
197}
198
199func (bw *bulkWrite) runDelete(ctx context.Context, batch bulkWriteBatch) (operation.DeleteResult, error) {
200	docs := make([]bsoncore.Document, len(batch.models))
201	var i int
202
203	for _, model := range batch.models {
204		var doc bsoncore.Document
205		var err error
206
207		switch converted := model.(type) {
208		case *DeleteOneModel:
209			doc, err = createDeleteDoc(converted.Filter, converted.Collation, true, bw.collection.registry)
210		case *DeleteManyModel:
211			doc, err = createDeleteDoc(converted.Filter, converted.Collation, false, bw.collection.registry)
212		}
213
214		if err != nil {
215			return operation.DeleteResult{}, err
216		}
217
218		docs[i] = doc
219		i++
220	}
221
222	op := operation.NewDelete(docs...).
223		Session(bw.session).WriteConcern(bw.writeConcern).CommandMonitor(bw.collection.client.monitor).
224		ServerSelector(bw.selector).ClusterClock(bw.collection.client.clock).
225		Database(bw.collection.db.name).Collection(bw.collection.name).
226		Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.crypt)
227	if bw.ordered != nil {
228		op = op.Ordered(*bw.ordered)
229	}
230	retry := driver.RetryNone
231	if bw.collection.client.retryWrites && batch.canRetry {
232		retry = driver.RetryOncePerCommand
233	}
234	op = op.Retry(retry)
235
236	err := op.Execute(ctx)
237
238	return op.Result(), err
239}
240
241func createDeleteDoc(filter interface{}, collation *options.Collation, deleteOne bool, registry *bsoncodec.Registry) (bsoncore.Document, error) {
242	f, err := transformBsoncoreDocument(registry, filter)
243	if err != nil {
244		return nil, err
245	}
246
247	var limit int32
248	if deleteOne {
249		limit = 1
250	}
251	didx, doc := bsoncore.AppendDocumentStart(nil)
252	doc = bsoncore.AppendDocumentElement(doc, "q", f)
253	doc = bsoncore.AppendInt32Element(doc, "limit", limit)
254	if collation != nil {
255		doc = bsoncore.AppendDocumentElement(doc, "collation", collation.ToDocument())
256	}
257	doc, _ = bsoncore.AppendDocumentEnd(doc, didx)
258
259	return doc, nil
260}
261
262func (bw *bulkWrite) runUpdate(ctx context.Context, batch bulkWriteBatch) (operation.UpdateResult, error) {
263	docs := make([]bsoncore.Document, len(batch.models))
264	for i, model := range batch.models {
265		var doc bsoncore.Document
266		var err error
267
268		switch converted := model.(type) {
269		case *ReplaceOneModel:
270			doc, err = createUpdateDoc(converted.Filter, converted.Replacement, nil, converted.Collation, converted.Upsert, false,
271				bw.collection.registry)
272		case *UpdateOneModel:
273			doc, err = createUpdateDoc(converted.Filter, converted.Update, converted.ArrayFilters, converted.Collation, converted.Upsert, false,
274				bw.collection.registry)
275		case *UpdateManyModel:
276			doc, err = createUpdateDoc(converted.Filter, converted.Update, converted.ArrayFilters, converted.Collation, converted.Upsert, true,
277				bw.collection.registry)
278		}
279		if err != nil {
280			return operation.UpdateResult{}, err
281		}
282
283		docs[i] = doc
284	}
285
286	op := operation.NewUpdate(docs...).
287		Session(bw.session).WriteConcern(bw.writeConcern).CommandMonitor(bw.collection.client.monitor).
288		ServerSelector(bw.selector).ClusterClock(bw.collection.client.clock).
289		Database(bw.collection.db.name).Collection(bw.collection.name).
290		Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.crypt)
291	if bw.ordered != nil {
292		op = op.Ordered(*bw.ordered)
293	}
294	if bw.bypassDocumentValidation != nil && *bw.bypassDocumentValidation {
295		op = op.BypassDocumentValidation(*bw.bypassDocumentValidation)
296	}
297	retry := driver.RetryNone
298	if bw.collection.client.retryWrites && batch.canRetry {
299		retry = driver.RetryOncePerCommand
300	}
301	op = op.Retry(retry)
302
303	err := op.Execute(ctx)
304
305	return op.Result(), err
306}
307func createUpdateDoc(
308	filter interface{},
309	update interface{},
310	arrayFilters *options.ArrayFilters,
311	collation *options.Collation,
312	upsert *bool,
313	multi bool,
314	registry *bsoncodec.Registry,
315) (bsoncore.Document, error) {
316	f, err := transformBsoncoreDocument(registry, filter)
317	if err != nil {
318		return nil, err
319	}
320
321	uidx, updateDoc := bsoncore.AppendDocumentStart(nil)
322	updateDoc = bsoncore.AppendDocumentElement(updateDoc, "q", f)
323
324	u, err := transformUpdateValue(registry, update, false)
325	if err != nil {
326		return nil, err
327	}
328	updateDoc = bsoncore.AppendValueElement(updateDoc, "u", u)
329
330	updateDoc = bsoncore.AppendBooleanElement(updateDoc, "multi", multi)
331
332	if arrayFilters != nil {
333		arr, err := arrayFilters.ToArrayDocument()
334		if err != nil {
335			return nil, err
336		}
337		updateDoc = bsoncore.AppendArrayElement(updateDoc, "arrayFilters", arr)
338	}
339
340	if collation != nil {
341		updateDoc = bsoncore.AppendDocumentElement(updateDoc, "collation", bsoncore.Document(collation.ToDocument()))
342	}
343
344	if upsert != nil {
345		updateDoc = bsoncore.AppendBooleanElement(updateDoc, "upsert", *upsert)
346	}
347	updateDoc, _ = bsoncore.AppendDocumentEnd(updateDoc, uidx)
348
349	return updateDoc, nil
350}
351
352func createBatches(models []WriteModel, ordered bool) []bulkWriteBatch {
353	if ordered {
354		return createOrderedBatches(models)
355	}
356
357	batches := make([]bulkWriteBatch, 5)
358	batches[insertCommand].canRetry = true
359	batches[deleteOneCommand].canRetry = true
360	batches[updateOneCommand].canRetry = true
361
362	// TODO(GODRIVER-1157): fix batching once operation retryability is fixed
363	for _, model := range models {
364		switch model.(type) {
365		case *InsertOneModel:
366			batches[insertCommand].models = append(batches[insertCommand].models, model)
367		case *DeleteOneModel:
368			batches[deleteOneCommand].models = append(batches[deleteOneCommand].models, model)
369		case *DeleteManyModel:
370			batches[deleteManyCommand].models = append(batches[deleteManyCommand].models, model)
371		case *ReplaceOneModel, *UpdateOneModel:
372			batches[updateOneCommand].models = append(batches[updateOneCommand].models, model)
373		case *UpdateManyModel:
374			batches[updateManyCommand].models = append(batches[updateManyCommand].models, model)
375		}
376	}
377
378	return batches
379}
380
381func createOrderedBatches(models []WriteModel) []bulkWriteBatch {
382	var batches []bulkWriteBatch
383	var prevKind writeCommandKind = -1
384	i := -1 // batch index
385
386	for _, model := range models {
387		var createNewBatch bool
388		var canRetry bool
389		var newKind writeCommandKind
390
391		// TODO(GODRIVER-1157): fix batching once operation retryability is fixed
392		switch model.(type) {
393		case *InsertOneModel:
394			createNewBatch = prevKind != insertCommand
395			canRetry = true
396			newKind = insertCommand
397		case *DeleteOneModel:
398			createNewBatch = prevKind != deleteOneCommand
399			canRetry = true
400			newKind = deleteOneCommand
401		case *DeleteManyModel:
402			createNewBatch = prevKind != deleteManyCommand
403			newKind = deleteManyCommand
404		case *ReplaceOneModel, *UpdateOneModel:
405			createNewBatch = prevKind != updateOneCommand
406			canRetry = true
407			newKind = updateOneCommand
408		case *UpdateManyModel:
409			createNewBatch = prevKind != updateManyCommand
410			newKind = updateManyCommand
411		}
412
413		if createNewBatch {
414			batches = append(batches, bulkWriteBatch{
415				models:   []WriteModel{model},
416				canRetry: canRetry,
417			})
418			i++
419		} else {
420			batches[i].models = append(batches[i].models, model)
421			if !canRetry {
422				batches[i].canRetry = false // don't make it true if it was already false
423			}
424		}
425
426		prevKind = newKind
427	}
428
429	return batches
430}
431
432func (bw *bulkWrite) mergeResults(newResult BulkWriteResult, opIndex int64) {
433	bw.result.InsertedCount += newResult.InsertedCount
434	bw.result.MatchedCount += newResult.MatchedCount
435	bw.result.ModifiedCount += newResult.ModifiedCount
436	bw.result.DeletedCount += newResult.DeletedCount
437	bw.result.UpsertedCount += newResult.UpsertedCount
438
439	for index, upsertID := range newResult.UpsertedIDs {
440		bw.result.UpsertedIDs[index+opIndex] = upsertID
441	}
442}
443
// writeCommandKind identifies which write command (insert, update, or delete) a batch
// of write models is sent as.
type writeCommandKind int8

// These constants represent the valid types of write commands. createBatches also uses
// them as indexes into its slice of unordered batches, so they must stay small,
// non-negative, and distinct.
const (
	insertCommand writeCommandKind = iota // *InsertOneModel
	updateOneCommand                      // *UpdateOneModel and *ReplaceOneModel
	updateManyCommand                     // *UpdateManyModel
	deleteOneCommand                      // *DeleteOneModel
	deleteManyCommand                     // *DeleteManyModel
)
455