1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package mongo
8
9import (
10	"bytes"
11	"errors"
12	"fmt"
13
14	"go.mongodb.org/mongo-driver/bson"
15	"go.mongodb.org/mongo-driver/x/mongo/driver"
16	"go.mongodb.org/mongo-driver/x/mongo/driver/topology"
17	"go.mongodb.org/mongo-driver/x/network/command"
18	"go.mongodb.org/mongo-driver/x/network/result"
19)
20
// Sentinel errors returned by this package.
var (
	// ErrUnacknowledgedWrite is returned from functions that have an
	// unacknowledged write concern.
	ErrUnacknowledgedWrite = errors.New("unacknowledged write")

	// ErrClientDisconnected is returned when a user attempts to call a method
	// on a disconnected client.
	ErrClientDisconnected = errors.New("client is disconnected")

	// ErrNilDocument is returned when a user attempts to pass a nil document
	// or filter to a function where the field is required.
	ErrNilDocument = errors.New("document is nil")

	// ErrEmptySlice is returned when a user attempts to pass an empty slice as
	// input to a function where the field is required.
	ErrEmptySlice = errors.New("must provide at least one element in input slice")
)
36
37func replaceErrors(err error) error {
38	if err == topology.ErrTopologyClosed {
39		return ErrClientDisconnected
40	}
41	if ce, ok := err.(command.Error); ok {
42		return CommandError{Code: ce.Code, Message: ce.Message, Labels: ce.Labels, Name: ce.Name}
43	}
44	if conv, ok := err.(driver.BulkWriteException); ok {
45		return BulkWriteException{
46			WriteConcernError: convertWriteConcernError(conv.WriteConcernError),
47			WriteErrors:       convertBulkWriteErrors(conv.WriteErrors),
48		}
49	}
50
51	return err
52}
53
// CommandError represents an error in execution of a command against the database.
type CommandError struct {
	// Code is the numeric error code.
	Code int32
	// Message is the human-readable error message.
	Message string
	// Labels holds the error labels attached to this error; it may be nil.
	Labels []string
	// Name is the error name. When non-empty it is included as a prefix in
	// the string returned by Error.
	Name string
}

// Error implements the error interface. The result is "(Name) Message" when
// Name is set, and just Message otherwise.
func (e CommandError) Error() string {
	if e.Name != "" {
		return fmt.Sprintf("(%v) %v", e.Name, e.Message)
	}
	return e.Message
}

// HasErrorLabel returns true if the error contains the specified label.
func (e CommandError) HasErrorLabel(label string) bool {
	// Ranging over a nil slice performs zero iterations, so no explicit nil
	// check on Labels is required.
	for _, l := range e.Labels {
		if l == label {
			return true
		}
	}
	return false
}
81
// WriteError is a non-write concern failure that occurred as a result of a write
// operation.
type WriteError struct {
	Index   int
	Code    int
	Message string
}

// Error implements the error interface by returning the write error's Message.
func (we WriteError) Error() string {
	return we.Message
}

// WriteErrors is a group of non-write concern failures that occurred as a result
// of a write operation.
type WriteErrors []WriteError

// Error implements the error interface. Each element is rendered as
// "{message}", the elements are joined with ", ", and the list is wrapped as
// "write errors: [...]".
func (we WriteErrors) Error() string {
	var out bytes.Buffer
	out.WriteString("write errors: [")
	for i := range we {
		if i > 0 {
			out.WriteString(", ")
		}
		fmt.Fprintf(&out, "{%s}", we[i])
	}
	out.WriteString("]")
	return out.String()
}
108
109func writeErrorsFromResult(rwes []result.WriteError) WriteErrors {
110	wes := make(WriteErrors, 0, len(rwes))
111	for _, err := range rwes {
112		wes = append(wes, WriteError{Index: err.Index, Code: err.Code, Message: err.ErrMsg})
113	}
114	return wes
115}
116
117// WriteConcernError is a write concern failure that occurred as a result of a
118// write operation.
119type WriteConcernError struct {
120	Code    int
121	Message string
122	Details bson.Raw
123}
124
125func (wce WriteConcernError) Error() string { return wce.Message }
126
127// WriteException is an error for a non-bulk write operation.
128type WriteException struct {
129	WriteConcernError *WriteConcernError
130	WriteErrors       WriteErrors
131}
132
133func (mwe WriteException) Error() string {
134	var buf bytes.Buffer
135	fmt.Fprint(&buf, "multiple write errors: [")
136	fmt.Fprintf(&buf, "{%s}, ", mwe.WriteErrors)
137	fmt.Fprintf(&buf, "{%s}]", mwe.WriteConcernError)
138	return buf.String()
139}
140
141func convertBulkWriteErrors(errors []driver.BulkWriteError) []BulkWriteError {
142	bwErrors := make([]BulkWriteError, 0, len(errors))
143	for _, err := range errors {
144		bwErrors = append(bwErrors, BulkWriteError{
145			WriteError{
146				Index:   err.Index,
147				Code:    err.Code,
148				Message: err.ErrMsg,
149			},
150			dispatchToMongoModel(err.Model),
151		})
152	}
153
154	return bwErrors
155}
156
157func convertWriteConcernError(wce *result.WriteConcernError) *WriteConcernError {
158	if wce == nil {
159		return nil
160	}
161
162	return &WriteConcernError{Code: wce.Code, Message: wce.ErrMsg, Details: wce.ErrInfo}
163}
164
165// BulkWriteError is an error for one operation in a bulk write.
166type BulkWriteError struct {
167	WriteError
168	Request WriteModel
169}
170
171func (bwe BulkWriteError) Error() string {
172	var buf bytes.Buffer
173	fmt.Fprintf(&buf, "{%s}", bwe.WriteError)
174	return buf.String()
175}
176
177// BulkWriteException is an error for a bulk write operation.
178type BulkWriteException struct {
179	WriteConcernError *WriteConcernError
180	WriteErrors       []BulkWriteError
181}
182
183func (bwe BulkWriteException) Error() string {
184	var buf bytes.Buffer
185	fmt.Fprint(&buf, "bulk write error: [")
186	fmt.Fprintf(&buf, "{%s}, ", bwe.WriteErrors)
187	fmt.Fprintf(&buf, "{%s}]", bwe.WriteConcernError)
188	return buf.String()
189}
190
// returnResult tells a caller of processWriteError whether it should return the
// operation's result alongside the error or return nil instead. Because
// processWriteError is shared by both *One and *Many methods, the value is a
// bit set that each caller tests against its own kind.
type returnResult int

// Individual returnResult flags; each occupies its own bit.
const (
	// rrNone means the result is never returned.
	rrNone returnResult = 1 << iota
	// rrOne means the result is returned if the caller is a *One method.
	rrOne
	// rrMany means the result is returned if the caller is a *Many method.
	rrMany
)

// rrAll means the result is always returned, for both *One and *Many methods.
const rrAll returnResult = rrOne | rrMany
204
205// processWriteError handles processing the result of a write operation. If the retrunResult matches
206// the calling method's type, it should return the result object in addition to the error.
207// This function will wrap the errors from other packages and return them as errors from this package.
208//
209// WriteConcernError will be returned over WriteErrors if both are present.
210func processWriteError(wce *result.WriteConcernError, wes []result.WriteError, err error) (returnResult, error) {
211	switch {
212	case err == command.ErrUnacknowledgedWrite:
213		return rrAll, ErrUnacknowledgedWrite
214	case err != nil:
215		return rrNone, replaceErrors(err)
216	case wce != nil || len(wes) > 0:
217		return rrMany, WriteException{
218			WriteConcernError: convertWriteConcernError(wce),
219			WriteErrors:       writeErrorsFromResult(wes),
220		}
221	default:
222		return rrAll, nil
223	}
224}
225