1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package command // import "go.mongodb.org/mongo-driver/x/network/command"
8
9import (
10	"errors"
11
12	"context"
13
14	"fmt"
15
16	"go.mongodb.org/mongo-driver/bson"
17	"go.mongodb.org/mongo-driver/bson/bsontype"
18	"go.mongodb.org/mongo-driver/bson/primitive"
19	"go.mongodb.org/mongo-driver/mongo/readconcern"
20	"go.mongodb.org/mongo-driver/mongo/writeconcern"
21	"go.mongodb.org/mongo-driver/x/bsonx"
22	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
23	"go.mongodb.org/mongo-driver/x/mongo/driver/session"
24	"go.mongodb.org/mongo-driver/x/network/description"
25	"go.mongodb.org/mongo-driver/x/network/result"
26	"go.mongodb.org/mongo-driver/x/network/wiremessage"
27)
28
// WriteBatch represents a single batch for a write operation.
//
// Encode and RoundTrip are invoked on *WriteBatch values in this file;
// presumably they are promoted from the embedded *Write (TODO confirm).
// numDocs is added to the running upsert index after each update batch so that
// upserted indexes reported by later batches are offset correctly.
type WriteBatch struct {
	*Write
	numDocs int // number of documents in this batch
}
34
35// DecodeError attempts to decode the wiremessage as an error
36func DecodeError(wm wiremessage.WireMessage) error {
37	var rdr bson.Raw
38	switch msg := wm.(type) {
39	case wiremessage.Msg:
40		for _, section := range msg.Sections {
41			switch converted := section.(type) {
42			case wiremessage.SectionBody:
43				rdr = converted.Document
44			}
45		}
46	case wiremessage.Reply:
47		if msg.ResponseFlags&wiremessage.QueryFailure != wiremessage.QueryFailure {
48			return nil
49		}
50		rdr = msg.Documents[0]
51	}
52
53	err := rdr.Validate()
54	if err != nil {
55		return nil
56	}
57
58	extractedError := extractError(rdr)
59
60	// If parsed successfully return the error
61	if _, ok := extractedError.(Error); ok {
62		return err
63	}
64
65	return nil
66}
67
68// helper method to extract an error from a reader if there is one; first returned item is the
69// error if it exists, the second holds parsing errors
70func extractError(rdr bson.Raw) error {
71	var errmsg, codeName string
72	var code int32
73	var labels []string
74	elems, err := rdr.Elements()
75	if err != nil {
76		return err
77	}
78
79	for _, elem := range elems {
80		switch elem.Key() {
81		case "ok":
82			switch elem.Value().Type {
83			case bson.TypeInt32:
84				if elem.Value().Int32() == 1 {
85					return nil
86				}
87			case bson.TypeInt64:
88				if elem.Value().Int64() == 1 {
89					return nil
90				}
91			case bson.TypeDouble:
92				if elem.Value().Double() == 1 {
93					return nil
94				}
95			}
96		case "errmsg":
97			if str, okay := elem.Value().StringValueOK(); okay {
98				errmsg = str
99			}
100		case "codeName":
101			if str, okay := elem.Value().StringValueOK(); okay {
102				codeName = str
103			}
104		case "code":
105			if c, okay := elem.Value().Int32OK(); okay {
106				code = c
107			}
108		case "errorLabels":
109			if arr, okay := elem.Value().ArrayOK(); okay {
110				elems, err := arr.Elements()
111				if err != nil {
112					continue
113				}
114				for _, elem := range elems {
115					if str, ok := elem.Value().StringValueOK(); ok {
116						labels = append(labels, str)
117					}
118				}
119
120			}
121		}
122	}
123
124	if errmsg == "" {
125		errmsg = "command failed"
126	}
127
128	return Error{
129		Code:    code,
130		Message: errmsg,
131		Name:    codeName,
132		Labels:  labels,
133	}
134}
135
136func responseClusterTime(response bson.Raw) bson.Raw {
137	clusterTime, err := response.LookupErr("$clusterTime")
138	if err != nil {
139		// $clusterTime not included by the server
140		return nil
141	}
142	idx, doc := bsoncore.AppendDocumentStart(nil)
143	doc = bsoncore.AppendHeader(doc, clusterTime.Type, "$clusterTime")
144	doc = append(doc, clusterTime.Value...)
145	doc, _ = bsoncore.AppendDocumentEnd(doc, idx)
146	return doc
147}
148
149func updateClusterTimes(sess *session.Client, clock *session.ClusterClock, response bson.Raw) error {
150	clusterTime := responseClusterTime(response)
151	if clusterTime == nil {
152		return nil
153	}
154
155	if sess != nil {
156		err := sess.AdvanceClusterTime(clusterTime)
157		if err != nil {
158			return err
159		}
160	}
161
162	if clock != nil {
163		clock.AdvanceClusterTime(clusterTime)
164	}
165
166	return nil
167}
168
169func updateOperationTime(sess *session.Client, response bson.Raw) error {
170	if sess == nil {
171		return nil
172	}
173
174	opTimeElem, err := response.LookupErr("operationTime")
175	if err != nil {
176		// operationTime not included by the server
177		return nil
178	}
179
180	t, i := opTimeElem.Timestamp()
181	return sess.AdvanceOperationTime(&primitive.Timestamp{
182		T: t,
183		I: i,
184	})
185}
186
187func marshalCommand(cmd bsonx.Doc) (bson.Raw, error) {
188	if cmd == nil {
189		return bson.Raw{5, 0, 0, 0, 0}, nil
190	}
191
192	return cmd.MarshalBSON()
193}
194
195// adds session related fields to a BSON doc representing a command
196func addSessionFields(cmd bsonx.Doc, desc description.SelectedServer, client *session.Client) (bsonx.Doc, error) {
197	if client == nil || !description.SessionsSupported(desc.WireVersion) || desc.SessionTimeoutMinutes == 0 {
198		return cmd, nil
199	}
200
201	if client.Terminated {
202		return cmd, session.ErrSessionEnded
203	}
204
205	if _, err := cmd.LookupElementErr("lsid"); err != nil {
206		cmd = cmd.Delete("lsid")
207	}
208
209	cmd = append(cmd, bsonx.Elem{"lsid", bsonx.Document(client.SessionID)})
210
211	if client.TransactionRunning() ||
212		client.RetryingCommit {
213		cmd = addTransaction(cmd, client)
214	}
215
216	client.ApplyCommand() // advance the state machine based on a command executing
217
218	return cmd, nil
219}
220
221// if in a transaction, add the transaction fields
222func addTransaction(cmd bsonx.Doc, client *session.Client) bsonx.Doc {
223	cmd = append(cmd, bsonx.Elem{"txnNumber", bsonx.Int64(client.TxnNumber)})
224	if client.TransactionStarting() {
225		// When starting transaction, always transition to the next state, even on error
226		cmd = append(cmd, bsonx.Elem{"startTransaction", bsonx.Boolean(true)})
227	}
228	return append(cmd, bsonx.Elem{"autocommit", bsonx.Boolean(false)})
229}
230
231func addClusterTime(cmd bsonx.Doc, desc description.SelectedServer, sess *session.Client, clock *session.ClusterClock) bsonx.Doc {
232	if (clock == nil && sess == nil) || !description.SessionsSupported(desc.WireVersion) {
233		return cmd
234	}
235
236	var clusterTime bson.Raw
237	if clock != nil {
238		clusterTime = clock.GetClusterTime()
239	}
240
241	if sess != nil {
242		if clusterTime == nil {
243			clusterTime = sess.ClusterTime
244		} else {
245			clusterTime = session.MaxClusterTime(clusterTime, sess.ClusterTime)
246		}
247	}
248
249	if clusterTime == nil {
250		return cmd
251	}
252
253	d, err := bsonx.ReadDoc(clusterTime)
254	if err != nil {
255		return cmd // broken clusterTime
256	}
257
258	cmd = cmd.Delete("$clusterTime")
259
260	return append(cmd, d...)
261}
262
263// add a read concern to a BSON doc representing a command
264func addReadConcern(cmd bsonx.Doc, desc description.SelectedServer, rc *readconcern.ReadConcern, sess *session.Client) (bsonx.Doc, error) {
265	// Starting transaction's read concern overrides all others
266	if sess != nil && sess.TransactionStarting() && sess.CurrentRc != nil {
267		rc = sess.CurrentRc
268	}
269
270	// start transaction must append afterclustertime IF causally consistent and operation time exists
271	if rc == nil && sess != nil && sess.TransactionStarting() && sess.Consistent && sess.OperationTime != nil {
272		rc = readconcern.New()
273	}
274
275	if rc == nil {
276		return cmd, nil
277	}
278
279	t, data, err := rc.MarshalBSONValue()
280	if err != nil {
281		return cmd, err
282	}
283
284	var rcDoc bsonx.Doc
285	err = rcDoc.UnmarshalBSONValue(t, data)
286	if err != nil {
287		return cmd, err
288	}
289	if description.SessionsSupported(desc.WireVersion) && sess != nil && sess.Consistent && sess.OperationTime != nil {
290		rcDoc = append(rcDoc, bsonx.Elem{"afterClusterTime", bsonx.Timestamp(sess.OperationTime.T, sess.OperationTime.I)})
291	}
292
293	cmd = cmd.Delete("readConcern")
294
295	if len(rcDoc) != 0 {
296		cmd = append(cmd, bsonx.Elem{"readConcern", bsonx.Document(rcDoc)})
297	}
298	return cmd, nil
299}
300
301// add a write concern to a BSON doc representing a command
302func addWriteConcern(cmd bsonx.Doc, wc *writeconcern.WriteConcern) (bsonx.Doc, error) {
303	if wc == nil {
304		return cmd, nil
305	}
306
307	t, data, err := wc.MarshalBSONValue()
308	if err != nil {
309		if err == writeconcern.ErrEmptyWriteConcern {
310			return cmd, nil
311		}
312		return cmd, err
313	}
314
315	var xval bsonx.Val
316	err = xval.UnmarshalBSONValue(t, data)
317	if err != nil {
318		return cmd, err
319	}
320
321	// delete if doc already has write concern
322	cmd = cmd.Delete("writeConcern")
323
324	return append(cmd, bsonx.Elem{Key: "writeConcern", Value: xval}), nil
325}
326
327// Get the error labels from a command response
328func getErrorLabels(rdr *bson.Raw) ([]string, error) {
329	var labels []string
330	labelsElem, err := rdr.LookupErr("errorLabels")
331	if err != bsoncore.ErrElementNotFound {
332		return nil, err
333	}
334	if labelsElem.Type == bsontype.Array {
335		labelsIt, err := labelsElem.Array().Elements()
336		if err != nil {
337			return nil, err
338		}
339		for _, elem := range labelsIt {
340			labels = append(labels, elem.Value().StringValue())
341		}
342	}
343	return labels, nil
344}
345
346// Remove command arguments for insert, update, and delete commands from the BSON document so they can be encoded
347// as a Section 1 payload in OP_MSG
348func opmsgRemoveArray(cmd bsonx.Doc) (bsonx.Doc, bsonx.Arr, string) {
349	var array bsonx.Arr
350	var id string
351
352	keys := []string{"documents", "updates", "deletes"}
353
354	for _, key := range keys {
355		val, err := cmd.LookupErr(key)
356		if err != nil {
357			continue
358		}
359
360		array = val.Array()
361		cmd = cmd.Delete(key)
362		id = key
363		break
364	}
365
366	return cmd, array, id
367}
368
369// Add the $db and $readPreference keys to the command
370// If the command has no read preference, pass nil for rpDoc
371func opmsgAddGlobals(cmd bsonx.Doc, dbName string, rpDoc bsonx.Doc) (bson.Raw, error) {
372	cmd = append(cmd, bsonx.Elem{"$db", bsonx.String(dbName)})
373	if rpDoc != nil {
374		cmd = append(cmd, bsonx.Elem{"$readPreference", bsonx.Document(rpDoc)})
375	}
376
377	return cmd.MarshalBSON() // bsonx.Doc.MarshalBSON never returns an error.
378}
379
380func opmsgCreateDocSequence(arr bsonx.Arr, identifier string) (wiremessage.SectionDocumentSequence, error) {
381	docSequence := wiremessage.SectionDocumentSequence{
382		PayloadType: wiremessage.DocumentSequence,
383		Identifier:  identifier,
384		Documents:   make([]bson.Raw, 0, len(arr)),
385	}
386
387	for _, val := range arr {
388		d, _ := val.Document().MarshalBSON()
389		docSequence.Documents = append(docSequence.Documents, d)
390	}
391
392	docSequence.Size = int32(docSequence.PayloadLen())
393	return docSequence, nil
394}
395
396func splitBatches(docs []bsonx.Doc, maxCount, targetBatchSize int) ([][]bsonx.Doc, error) {
397	batches := [][]bsonx.Doc{}
398
399	if targetBatchSize > reservedCommandBufferBytes {
400		targetBatchSize -= reservedCommandBufferBytes
401	}
402
403	if maxCount <= 0 {
404		maxCount = 1
405	}
406
407	startAt := 0
408splitInserts:
409	for {
410		size := 0
411		batch := []bsonx.Doc{}
412	assembleBatch:
413		for idx := startAt; idx < len(docs); idx++ {
414			raw, _ := docs[idx].MarshalBSON()
415
416			if len(raw) > targetBatchSize {
417				return nil, ErrDocumentTooLarge
418			}
419			if size+len(raw) > targetBatchSize {
420				break assembleBatch
421			}
422
423			size += len(raw)
424			batch = append(batch, docs[idx])
425			startAt++
426			if len(batch) == maxCount {
427				break assembleBatch
428			}
429		}
430		batches = append(batches, batch)
431		if startAt == len(docs) {
432			break splitInserts
433		}
434	}
435
436	return batches, nil
437}
438
439func encodeBatch(
440	docs []bsonx.Doc,
441	opts []bsonx.Elem,
442	cmdKind WriteCommandKind,
443	collName string,
444) (bsonx.Doc, error) {
445	var cmdName string
446	var docString string
447
448	switch cmdKind {
449	case InsertCommand:
450		cmdName = "insert"
451		docString = "documents"
452	case UpdateCommand:
453		cmdName = "update"
454		docString = "updates"
455	case DeleteCommand:
456		cmdName = "delete"
457		docString = "deletes"
458	}
459
460	cmd := bsonx.Doc{{cmdName, bsonx.String(collName)}}
461
462	vals := make(bsonx.Arr, 0, len(docs))
463	for _, doc := range docs {
464		vals = append(vals, bsonx.Document(doc))
465	}
466	cmd = append(cmd, bsonx.Elem{docString, bsonx.Array(vals)})
467	cmd = append(cmd, opts...)
468
469	return cmd, nil
470}
471
472// converts batches of Write Commands to wire messages
473func batchesToWireMessage(batches []*WriteBatch, desc description.SelectedServer) ([]wiremessage.WireMessage, error) {
474	wms := make([]wiremessage.WireMessage, len(batches))
475	for _, cmd := range batches {
476		wm, err := cmd.Encode(desc)
477		if err != nil {
478			return nil, err
479		}
480
481		wms = append(wms, wm)
482	}
483
484	return wms, nil
485}
486
487// Roundtrips the write batches, returning the result structs (as interface),
488// the write batches that weren't round tripped and any errors
489func roundTripBatches(
490	ctx context.Context,
491	desc description.SelectedServer,
492	rw wiremessage.ReadWriter,
493	batches []*WriteBatch,
494	continueOnError bool,
495	sess *session.Client,
496	cmdKind WriteCommandKind,
497) (interface{}, []*WriteBatch, error) {
498	var res interface{}
499	var upsertIndex int64 // the operation index for the upserted IDs map
500
501	// hold onto txnNumber, reset it when loop exits to ensure reuse of same
502	// transaction number if retry is needed
503	var txnNumber int64
504	if sess != nil && sess.RetryWrite {
505		txnNumber = sess.TxnNumber
506	}
507	for j, cmd := range batches {
508		rdr, err := cmd.RoundTrip(ctx, desc, rw)
509		if err != nil {
510			if sess != nil && sess.RetryWrite {
511				sess.TxnNumber = txnNumber + int64(j)
512			}
513			return res, batches, err
514		}
515
516		// TODO can probably DRY up this code
517		switch cmdKind {
518		case InsertCommand:
519			if res == nil {
520				res = result.Insert{}
521			}
522
523			conv, _ := res.(result.Insert)
524			insertCmd := &Insert{}
525			r, err := insertCmd.decode(desc, rdr).Result()
526			if err != nil {
527				return res, batches, err
528			}
529
530			conv.WriteErrors = append(conv.WriteErrors, r.WriteErrors...)
531
532			if r.WriteConcernError != nil {
533				conv.WriteConcernError = r.WriteConcernError
534				if sess != nil && sess.RetryWrite {
535					sess.TxnNumber = txnNumber
536					return conv, batches, nil // report writeconcernerror for retry
537				}
538			}
539
540			conv.N += r.N
541
542			if !continueOnError && len(conv.WriteErrors) > 0 {
543				return conv, batches, nil
544			}
545
546			res = conv
547		case UpdateCommand:
548			if res == nil {
549				res = result.Update{}
550			}
551
552			conv, _ := res.(result.Update)
553			updateCmd := &Update{}
554			r, err := updateCmd.decode(desc, rdr).Result()
555			if err != nil {
556				return conv, batches, err
557			}
558
559			conv.WriteErrors = append(conv.WriteErrors, r.WriteErrors...)
560
561			if r.WriteConcernError != nil {
562				conv.WriteConcernError = r.WriteConcernError
563				if sess != nil && sess.RetryWrite {
564					sess.TxnNumber = txnNumber
565					return conv, batches, nil // report writeconcernerror for retry
566				}
567			}
568
569			conv.MatchedCount += r.MatchedCount
570			conv.ModifiedCount += r.ModifiedCount
571			for _, upsert := range r.Upserted {
572				conv.Upserted = append(conv.Upserted, result.Upsert{
573					Index: upsert.Index + upsertIndex,
574					ID:    upsert.ID,
575				})
576			}
577
578			if !continueOnError && len(conv.WriteErrors) > 0 {
579				return conv, batches, nil
580			}
581
582			res = conv
583			upsertIndex += int64(cmd.numDocs)
584		case DeleteCommand:
585			if res == nil {
586				res = result.Delete{}
587			}
588
589			conv, _ := res.(result.Delete)
590			deleteCmd := &Delete{}
591			r, err := deleteCmd.decode(desc, rdr).Result()
592			if err != nil {
593				return conv, batches, err
594			}
595
596			conv.WriteErrors = append(conv.WriteErrors, r.WriteErrors...)
597
598			if r.WriteConcernError != nil {
599				conv.WriteConcernError = r.WriteConcernError
600				if sess != nil && sess.RetryWrite {
601					sess.TxnNumber = txnNumber
602					return conv, batches, nil // report writeconcernerror for retry
603				}
604			}
605
606			conv.N += r.N
607
608			if !continueOnError && len(conv.WriteErrors) > 0 {
609				return conv, batches, nil
610			}
611
612			res = conv
613		}
614
615		// Increment txnNumber for each batch
616		if sess != nil && sess.RetryWrite {
617			sess.IncrementTxnNumber()
618			batches = batches[1:] // if batch encoded successfully, remove it from the slice
619		}
620	}
621
622	if sess != nil && sess.RetryWrite {
623		// if retryable write succeeded, transaction number will be incremented one extra time,
624		// so we decrement it here
625		sess.TxnNumber--
626	}
627
628	return res, batches, nil
629}
630
631// get the firstBatch, cursor ID, and namespace from a bson.Raw
632func getCursorValues(result bson.Raw) ([]bson.RawValue, Namespace, int64, error) {
633	cur, err := result.LookupErr("cursor")
634	if err != nil {
635		return nil, Namespace{}, 0, err
636	}
637	if cur.Type != bson.TypeEmbeddedDocument {
638		return nil, Namespace{}, 0, fmt.Errorf("cursor should be an embedded document but it is a BSON %s", cur.Type)
639	}
640
641	elems, err := cur.Document().Elements()
642	if err != nil {
643		return nil, Namespace{}, 0, err
644	}
645
646	var ok bool
647	var arr bson.Raw
648	var namespace Namespace
649	var cursorID int64
650
651	for _, elem := range elems {
652		switch elem.Key() {
653		case "firstBatch":
654			arr, ok = elem.Value().ArrayOK()
655			if !ok {
656				return nil, Namespace{}, 0, fmt.Errorf("firstBatch should be an array but it is a BSON %s", elem.Value().Type)
657			}
658			if err != nil {
659				return nil, Namespace{}, 0, err
660			}
661		case "ns":
662			if elem.Value().Type != bson.TypeString {
663				return nil, Namespace{}, 0, fmt.Errorf("namespace should be a string but it is a BSON %s", elem.Value().Type)
664			}
665			namespace = ParseNamespace(elem.Value().StringValue())
666			err = namespace.Validate()
667			if err != nil {
668				return nil, Namespace{}, 0, err
669			}
670		case "id":
671			cursorID, ok = elem.Value().Int64OK()
672			if !ok {
673				return nil, Namespace{}, 0, fmt.Errorf("id should be an int64 but it is a BSON %s", elem.Value().Type)
674			}
675		}
676	}
677
678	vals, err := arr.Values()
679	if err != nil {
680		return nil, Namespace{}, 0, err
681	}
682
683	return vals, namespace, cursorID, nil
684}
685
686func getBatchSize(opts []bsonx.Elem) int32 {
687	for _, opt := range opts {
688		if opt.Key == "batchSize" {
689			return opt.Value.Int32()
690		}
691	}
692
693	return 0
694}
695
var (
	// ErrUnacknowledgedWrite is returned from functions whose write concern is
	// unacknowledged.
	ErrUnacknowledgedWrite = errors.New("unacknowledged write")
)
699
// WriteCommandKind is the type of write command (insert, update, or delete)
// represented by a Write.
type WriteCommandKind int8

// These constants represent the valid types of write commands. Their numeric
// values are spelled out explicitly.
const (
	InsertCommand WriteCommandKind = 0 // an insert command
	UpdateCommand WriteCommandKind = 1 // an update command
	DeleteCommand WriteCommandKind = 2 // a delete command
)
709