1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package mongo
8
9import (
10	"context"
11
12	"go.mongodb.org/mongo-driver/bson/bsoncodec"
13	"go.mongodb.org/mongo-driver/mongo/options"
14	"go.mongodb.org/mongo-driver/mongo/writeconcern"
15	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
16	"go.mongodb.org/mongo-driver/x/mongo/driver"
17	"go.mongodb.org/mongo-driver/x/mongo/driver/description"
18	"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
19	"go.mongodb.org/mongo-driver/x/mongo/driver/session"
20)
21
// bulkWriteBatch is a group of write models that runBatch sends to the server
// as a single write command.
//
// models holds the batch's write models; createBatches and
// createOrderedBatches only group models of the same writeCommandKind.
// canRetry reports whether the command may be retried when the client has
// retryable writes enabled; it is false for batches containing
// DeleteManyModel or UpdateManyModel. indexes[i] is the position of models[i]
// in the caller's original model slice, used to report write errors and
// upserted IDs against that slice.
type bulkWriteBatch struct {
	models   []WriteModel
	canRetry bool
	indexes  []int
}
27
// bulkWrite performs a bulk write operation: execute splits models into
// batches of insert, update and delete commands and runs them against
// collection.
//
// ordered controls whether execution stops at the first failing batch; a nil
// value is treated as true. bypassDocumentValidation is only sent with insert
// and update commands, and only when it is set to true. session, selector and
// writeConcern are passed to every command. result is populated by execute.
type bulkWrite struct {
	ordered                  *bool
	bypassDocumentValidation *bool
	models                   []WriteModel
	session                  *session.Client
	collection               *Collection
	selector                 description.ServerSelector
	writeConcern             *writeconcern.WriteConcern
	result                   BulkWriteResult
}
39
40func (bw *bulkWrite) execute(ctx context.Context) error {
41	ordered := true
42	if bw.ordered != nil {
43		ordered = *bw.ordered
44	}
45
46	batches := createBatches(bw.models, ordered)
47	bw.result = BulkWriteResult{
48		UpsertedIDs: make(map[int64]interface{}),
49	}
50
51	bwErr := BulkWriteException{
52		WriteErrors: make([]BulkWriteError, 0),
53	}
54
55	var lastErr error
56	continueOnError := !ordered
57	for _, batch := range batches {
58		if len(batch.models) == 0 {
59			continue
60		}
61
62		bypassDocValidation := bw.bypassDocumentValidation
63		if bypassDocValidation != nil && !*bypassDocValidation {
64			bypassDocValidation = nil
65		}
66
67		batchRes, batchErr, err := bw.runBatch(ctx, batch)
68
69		bw.mergeResults(batchRes)
70
71		bwErr.WriteConcernError = batchErr.WriteConcernError
72		bwErr.Labels = append(bwErr.Labels, batchErr.Labels...)
73
74		bwErr.WriteErrors = append(bwErr.WriteErrors, batchErr.WriteErrors...)
75
76		commandErrorOccurred := err != nil && err != driver.ErrUnacknowledgedWrite
77		writeErrorOccurred := len(batchErr.WriteErrors) > 0 || batchErr.WriteConcernError != nil
78		if !continueOnError && (commandErrorOccurred || writeErrorOccurred) {
79			if err != nil {
80				return err
81			}
82
83			return bwErr
84		}
85
86		if err != nil {
87			lastErr = err
88		}
89	}
90
91	bw.result.MatchedCount -= bw.result.UpsertedCount
92	if lastErr != nil {
93		_, lastErr = processWriteError(lastErr)
94		return lastErr
95	}
96	if len(bwErr.WriteErrors) > 0 || bwErr.WriteConcernError != nil {
97		return bwErr
98	}
99	return nil
100}
101
102func (bw *bulkWrite) runBatch(ctx context.Context, batch bulkWriteBatch) (BulkWriteResult, BulkWriteException, error) {
103	batchRes := BulkWriteResult{
104		UpsertedIDs: make(map[int64]interface{}),
105	}
106	batchErr := BulkWriteException{}
107
108	var writeErrors []driver.WriteError
109	switch batch.models[0].(type) {
110	case *InsertOneModel:
111		res, err := bw.runInsert(ctx, batch)
112		if err != nil {
113			writeErr, ok := err.(driver.WriteCommandError)
114			if !ok {
115				return BulkWriteResult{}, batchErr, err
116			}
117			writeErrors = writeErr.WriteErrors
118			batchErr.Labels = writeErr.Labels
119			batchErr.WriteConcernError = convertDriverWriteConcernError(writeErr.WriteConcernError)
120		}
121		batchRes.InsertedCount = int64(res.N)
122	case *DeleteOneModel, *DeleteManyModel:
123		res, err := bw.runDelete(ctx, batch)
124		if err != nil {
125			writeErr, ok := err.(driver.WriteCommandError)
126			if !ok {
127				return BulkWriteResult{}, batchErr, err
128			}
129			writeErrors = writeErr.WriteErrors
130			batchErr.Labels = writeErr.Labels
131			batchErr.WriteConcernError = convertDriverWriteConcernError(writeErr.WriteConcernError)
132		}
133		batchRes.DeletedCount = int64(res.N)
134	case *ReplaceOneModel, *UpdateOneModel, *UpdateManyModel:
135		res, err := bw.runUpdate(ctx, batch)
136		if err != nil {
137			writeErr, ok := err.(driver.WriteCommandError)
138			if !ok {
139				return BulkWriteResult{}, batchErr, err
140			}
141			writeErrors = writeErr.WriteErrors
142			batchErr.Labels = writeErr.Labels
143			batchErr.WriteConcernError = convertDriverWriteConcernError(writeErr.WriteConcernError)
144		}
145		batchRes.MatchedCount = int64(res.N)
146		batchRes.ModifiedCount = int64(res.NModified)
147		batchRes.UpsertedCount = int64(len(res.Upserted))
148		for _, upsert := range res.Upserted {
149			batchRes.UpsertedIDs[int64(batch.indexes[upsert.Index])] = upsert.ID
150		}
151	}
152
153	batchErr.WriteErrors = make([]BulkWriteError, 0, len(writeErrors))
154	convWriteErrors := writeErrorsFromDriverWriteErrors(writeErrors)
155	for _, we := range convWriteErrors {
156		request := batch.models[we.Index]
157		we.Index = batch.indexes[we.Index]
158		batchErr.WriteErrors = append(batchErr.WriteErrors, BulkWriteError{
159			WriteError: we,
160			Request:    request,
161		})
162	}
163	return batchRes, batchErr, nil
164}
165
166func (bw *bulkWrite) runInsert(ctx context.Context, batch bulkWriteBatch) (operation.InsertResult, error) {
167	docs := make([]bsoncore.Document, len(batch.models))
168	var i int
169	for _, model := range batch.models {
170		converted := model.(*InsertOneModel)
171		doc, _, err := transformAndEnsureIDv2(bw.collection.registry, converted.Document)
172		if err != nil {
173			return operation.InsertResult{}, err
174		}
175
176		docs[i] = doc
177		i++
178	}
179
180	op := operation.NewInsert(docs...).
181		Session(bw.session).WriteConcern(bw.writeConcern).CommandMonitor(bw.collection.client.monitor).
182		ServerSelector(bw.selector).ClusterClock(bw.collection.client.clock).
183		Database(bw.collection.db.name).Collection(bw.collection.name).
184		Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.crypt)
185	if bw.bypassDocumentValidation != nil && *bw.bypassDocumentValidation {
186		op = op.BypassDocumentValidation(*bw.bypassDocumentValidation)
187	}
188	if bw.ordered != nil {
189		op = op.Ordered(*bw.ordered)
190	}
191
192	retry := driver.RetryNone
193	if bw.collection.client.retryWrites && batch.canRetry {
194		retry = driver.RetryOncePerCommand
195	}
196	op = op.Retry(retry)
197
198	err := op.Execute(ctx)
199
200	return op.Result(), err
201}
202
203func (bw *bulkWrite) runDelete(ctx context.Context, batch bulkWriteBatch) (operation.DeleteResult, error) {
204	docs := make([]bsoncore.Document, len(batch.models))
205	var i int
206	var hasHint bool
207
208	for _, model := range batch.models {
209		var doc bsoncore.Document
210		var err error
211
212		switch converted := model.(type) {
213		case *DeleteOneModel:
214			doc, err = createDeleteDoc(converted.Filter, converted.Collation, converted.Hint, true, bw.collection.registry)
215			hasHint = hasHint || (converted.Hint != nil)
216		case *DeleteManyModel:
217			doc, err = createDeleteDoc(converted.Filter, converted.Collation, converted.Hint, false, bw.collection.registry)
218			hasHint = hasHint || (converted.Hint != nil)
219		}
220
221		if err != nil {
222			return operation.DeleteResult{}, err
223		}
224
225		docs[i] = doc
226		i++
227	}
228
229	op := operation.NewDelete(docs...).
230		Session(bw.session).WriteConcern(bw.writeConcern).CommandMonitor(bw.collection.client.monitor).
231		ServerSelector(bw.selector).ClusterClock(bw.collection.client.clock).
232		Database(bw.collection.db.name).Collection(bw.collection.name).
233		Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.crypt).Hint(hasHint)
234	if bw.ordered != nil {
235		op = op.Ordered(*bw.ordered)
236	}
237	retry := driver.RetryNone
238	if bw.collection.client.retryWrites && batch.canRetry {
239		retry = driver.RetryOncePerCommand
240	}
241	op = op.Retry(retry)
242
243	err := op.Execute(ctx)
244
245	return op.Result(), err
246}
247
248func createDeleteDoc(filter interface{}, collation *options.Collation, hint interface{}, deleteOne bool,
249	registry *bsoncodec.Registry) (bsoncore.Document, error) {
250
251	f, err := transformBsoncoreDocument(registry, filter)
252	if err != nil {
253		return nil, err
254	}
255
256	var limit int32
257	if deleteOne {
258		limit = 1
259	}
260	didx, doc := bsoncore.AppendDocumentStart(nil)
261	doc = bsoncore.AppendDocumentElement(doc, "q", f)
262	doc = bsoncore.AppendInt32Element(doc, "limit", limit)
263	if collation != nil {
264		doc = bsoncore.AppendDocumentElement(doc, "collation", collation.ToDocument())
265	}
266	if hint != nil {
267		hintVal, err := transformValue(registry, hint)
268		if err != nil {
269			return nil, err
270		}
271		doc = bsoncore.AppendValueElement(doc, "hint", hintVal)
272	}
273	doc, _ = bsoncore.AppendDocumentEnd(doc, didx)
274
275	return doc, nil
276}
277
278func (bw *bulkWrite) runUpdate(ctx context.Context, batch bulkWriteBatch) (operation.UpdateResult, error) {
279	docs := make([]bsoncore.Document, len(batch.models))
280	var hasHint bool
281	var hasArrayFilters bool
282	for i, model := range batch.models {
283		var doc bsoncore.Document
284		var err error
285
286		switch converted := model.(type) {
287		case *ReplaceOneModel:
288			doc, err = createUpdateDoc(converted.Filter, converted.Replacement, converted.Hint, nil, converted.Collation, converted.Upsert, false,
289				false, bw.collection.registry)
290			hasHint = hasHint || (converted.Hint != nil)
291		case *UpdateOneModel:
292			doc, err = createUpdateDoc(converted.Filter, converted.Update, converted.Hint, converted.ArrayFilters, converted.Collation, converted.Upsert, false,
293				true, bw.collection.registry)
294			hasHint = hasHint || (converted.Hint != nil)
295			hasArrayFilters = hasArrayFilters || (converted.ArrayFilters != nil)
296		case *UpdateManyModel:
297			doc, err = createUpdateDoc(converted.Filter, converted.Update, converted.Hint, converted.ArrayFilters, converted.Collation, converted.Upsert, true,
298				true, bw.collection.registry)
299			hasHint = hasHint || (converted.Hint != nil)
300			hasArrayFilters = hasArrayFilters || (converted.ArrayFilters != nil)
301		}
302		if err != nil {
303			return operation.UpdateResult{}, err
304		}
305
306		docs[i] = doc
307	}
308
309	op := operation.NewUpdate(docs...).
310		Session(bw.session).WriteConcern(bw.writeConcern).CommandMonitor(bw.collection.client.monitor).
311		ServerSelector(bw.selector).ClusterClock(bw.collection.client.clock).
312		Database(bw.collection.db.name).Collection(bw.collection.name).
313		Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.crypt).Hint(hasHint).
314		ArrayFilters(hasArrayFilters)
315	if bw.ordered != nil {
316		op = op.Ordered(*bw.ordered)
317	}
318	if bw.bypassDocumentValidation != nil && *bw.bypassDocumentValidation {
319		op = op.BypassDocumentValidation(*bw.bypassDocumentValidation)
320	}
321	retry := driver.RetryNone
322	if bw.collection.client.retryWrites && batch.canRetry {
323		retry = driver.RetryOncePerCommand
324	}
325	op = op.Retry(retry)
326
327	err := op.Execute(ctx)
328
329	return op.Result(), err
330}
331func createUpdateDoc(
332	filter interface{},
333	update interface{},
334	hint interface{},
335	arrayFilters *options.ArrayFilters,
336	collation *options.Collation,
337	upsert *bool,
338	multi bool,
339	checkDollarKey bool,
340	registry *bsoncodec.Registry,
341) (bsoncore.Document, error) {
342	f, err := transformBsoncoreDocument(registry, filter)
343	if err != nil {
344		return nil, err
345	}
346
347	uidx, updateDoc := bsoncore.AppendDocumentStart(nil)
348	updateDoc = bsoncore.AppendDocumentElement(updateDoc, "q", f)
349
350	u, err := transformUpdateValue(registry, update, checkDollarKey)
351	if err != nil {
352		return nil, err
353	}
354
355	updateDoc = bsoncore.AppendValueElement(updateDoc, "u", u)
356
357	if multi {
358		updateDoc = bsoncore.AppendBooleanElement(updateDoc, "multi", multi)
359	}
360
361	if arrayFilters != nil {
362		arr, err := arrayFilters.ToArrayDocument()
363		if err != nil {
364			return nil, err
365		}
366		updateDoc = bsoncore.AppendArrayElement(updateDoc, "arrayFilters", arr)
367	}
368
369	if collation != nil {
370		updateDoc = bsoncore.AppendDocumentElement(updateDoc, "collation", bsoncore.Document(collation.ToDocument()))
371	}
372
373	if upsert != nil {
374		updateDoc = bsoncore.AppendBooleanElement(updateDoc, "upsert", *upsert)
375	}
376
377	if hint != nil {
378		hintVal, err := transformValue(registry, hint)
379		if err != nil {
380			return nil, err
381		}
382		updateDoc = bsoncore.AppendValueElement(updateDoc, "hint", hintVal)
383	}
384
385	updateDoc, _ = bsoncore.AppendDocumentEnd(updateDoc, uidx)
386
387	return updateDoc, nil
388}
389
390func createBatches(models []WriteModel, ordered bool) []bulkWriteBatch {
391	if ordered {
392		return createOrderedBatches(models)
393	}
394
395	batches := make([]bulkWriteBatch, 5)
396	batches[insertCommand].canRetry = true
397	batches[deleteOneCommand].canRetry = true
398	batches[updateOneCommand].canRetry = true
399
400	// TODO(GODRIVER-1157): fix batching once operation retryability is fixed
401	for i, model := range models {
402		switch model.(type) {
403		case *InsertOneModel:
404			batches[insertCommand].models = append(batches[insertCommand].models, model)
405			batches[insertCommand].indexes = append(batches[insertCommand].indexes, i)
406		case *DeleteOneModel:
407			batches[deleteOneCommand].models = append(batches[deleteOneCommand].models, model)
408			batches[deleteOneCommand].indexes = append(batches[deleteOneCommand].indexes, i)
409		case *DeleteManyModel:
410			batches[deleteManyCommand].models = append(batches[deleteManyCommand].models, model)
411			batches[deleteManyCommand].indexes = append(batches[deleteManyCommand].indexes, i)
412		case *ReplaceOneModel, *UpdateOneModel:
413			batches[updateOneCommand].models = append(batches[updateOneCommand].models, model)
414			batches[updateOneCommand].indexes = append(batches[updateOneCommand].indexes, i)
415		case *UpdateManyModel:
416			batches[updateManyCommand].models = append(batches[updateManyCommand].models, model)
417			batches[updateManyCommand].indexes = append(batches[updateManyCommand].indexes, i)
418		}
419	}
420
421	return batches
422}
423
424func createOrderedBatches(models []WriteModel) []bulkWriteBatch {
425	var batches []bulkWriteBatch
426	var prevKind writeCommandKind = -1
427	i := -1 // batch index
428
429	for ind, model := range models {
430		var createNewBatch bool
431		var canRetry bool
432		var newKind writeCommandKind
433
434		// TODO(GODRIVER-1157): fix batching once operation retryability is fixed
435		switch model.(type) {
436		case *InsertOneModel:
437			createNewBatch = prevKind != insertCommand
438			canRetry = true
439			newKind = insertCommand
440		case *DeleteOneModel:
441			createNewBatch = prevKind != deleteOneCommand
442			canRetry = true
443			newKind = deleteOneCommand
444		case *DeleteManyModel:
445			createNewBatch = prevKind != deleteManyCommand
446			newKind = deleteManyCommand
447		case *ReplaceOneModel, *UpdateOneModel:
448			createNewBatch = prevKind != updateOneCommand
449			canRetry = true
450			newKind = updateOneCommand
451		case *UpdateManyModel:
452			createNewBatch = prevKind != updateManyCommand
453			newKind = updateManyCommand
454		}
455
456		if createNewBatch {
457			batches = append(batches, bulkWriteBatch{
458				models:   []WriteModel{model},
459				canRetry: canRetry,
460				indexes:  []int{ind},
461			})
462			i++
463		} else {
464			batches[i].models = append(batches[i].models, model)
465			if !canRetry {
466				batches[i].canRetry = false // don't make it true if it was already false
467			}
468			batches[i].indexes = append(batches[i].indexes, ind)
469		}
470
471		prevKind = newKind
472	}
473
474	return batches
475}
476
477func (bw *bulkWrite) mergeResults(newResult BulkWriteResult) {
478	bw.result.InsertedCount += newResult.InsertedCount
479	bw.result.MatchedCount += newResult.MatchedCount
480	bw.result.ModifiedCount += newResult.ModifiedCount
481	bw.result.DeletedCount += newResult.DeletedCount
482	bw.result.UpsertedCount += newResult.UpsertedCount
483
484	for index, upsertID := range newResult.UpsertedIDs {
485		bw.result.UpsertedIDs[index] = upsertID
486	}
487}
488
// writeCommandKind identifies the write command a WriteModel is sent as.
// The values double as indexes into the fixed-size slice that createBatches
// builds for unordered writes. Their order and count must therefore stay in
// sync with that function.
type writeCommandKind int8

// These constants represent the valid types of write commands. ReplaceOneModel
// is batched under updateOneCommand together with UpdateOneModel.
const (
	insertCommand writeCommandKind = iota
	updateOneCommand
	updateManyCommand
	deleteOneCommand
	deleteManyCommand
)
500