1// Copyright (c) 2019 Uber Technologies, Inc.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a copy
4// of this software and associated documentation files (the "Software"), to deal
5// in the Software without restriction, including without limitation the rights
6// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7// copies of the Software, and to permit persons to whom the Software is
8// furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19// THE SOFTWARE.
20
21package multierr
22
23import (
24	"errors"
25	"fmt"
26	"io"
27	"sync"
28	"testing"
29
30	"github.com/stretchr/testify/assert"
31	"github.com/stretchr/testify/require"
32)
33
// richFormatError is an error whose formatted output depends on the verb:
// "%+v" produces a multi-line message, while every other verb (including
// plain "%v" and "%s") produces the single-line "without plus".
type richFormatError struct{}

// Error renders the error through Format, so it always matches the "%v" form.
func (e richFormatError) Error() string {
	return fmt.Sprint(e)
}

// Format implements fmt.Formatter. Write errors are ignored because this is
// a test fixture.
func (richFormatError) Format(f fmt.State, c rune) {
	msg := "without plus"
	if c == 'v' && f.Flag('+') {
		msg = "multiline\nmessage\nwith plus"
	}
	io.WriteString(f, msg)
}
49
50func appendN(initial, err error, n int) error {
51	errs := initial
52	for i := 0; i < n; i++ {
53		errs = Append(errs, err)
54	}
55	return errs
56}
57
58func newMultiErr(errors ...error) error {
59	return &multiError{errors: errors}
60}
61
62func TestCombine(t *testing.T) {
63	tests := []struct {
64		// Input
65		giveErrors []error
66
67		// Resulting error
68		wantError error
69
70		// %+v and %v string representations
71		wantMultiline  string
72		wantSingleline string
73	}{
74		{
75			giveErrors: nil,
76			wantError:  nil,
77		},
78		{
79			giveErrors: []error{},
80			wantError:  nil,
81		},
82		{
83			giveErrors: []error{
84				errors.New("foo"),
85				nil,
86				newMultiErr(
87					errors.New("bar"),
88				),
89				nil,
90			},
91			wantError: newMultiErr(
92				errors.New("foo"),
93				errors.New("bar"),
94			),
95			wantMultiline: "the following errors occurred:\n" +
96				" -  foo\n" +
97				" -  bar",
98			wantSingleline: "foo; bar",
99		},
100		{
101			giveErrors: []error{
102				errors.New("foo"),
103				newMultiErr(
104					errors.New("bar"),
105				),
106			},
107			wantError: newMultiErr(
108				errors.New("foo"),
109				errors.New("bar"),
110			),
111			wantMultiline: "the following errors occurred:\n" +
112				" -  foo\n" +
113				" -  bar",
114			wantSingleline: "foo; bar",
115		},
116		{
117			giveErrors:     []error{errors.New("great sadness")},
118			wantError:      errors.New("great sadness"),
119			wantMultiline:  "great sadness",
120			wantSingleline: "great sadness",
121		},
122		{
123			giveErrors: []error{
124				errors.New("foo"),
125				errors.New("bar"),
126			},
127			wantError: newMultiErr(
128				errors.New("foo"),
129				errors.New("bar"),
130			),
131			wantMultiline: "the following errors occurred:\n" +
132				" -  foo\n" +
133				" -  bar",
134			wantSingleline: "foo; bar",
135		},
136		{
137			giveErrors: []error{
138				errors.New("great sadness"),
139				errors.New("multi\n  line\nerror message"),
140				errors.New("single line error message"),
141			},
142			wantError: newMultiErr(
143				errors.New("great sadness"),
144				errors.New("multi\n  line\nerror message"),
145				errors.New("single line error message"),
146			),
147			wantMultiline: "the following errors occurred:\n" +
148				" -  great sadness\n" +
149				" -  multi\n" +
150				"      line\n" +
151				"    error message\n" +
152				" -  single line error message",
153			wantSingleline: "great sadness; " +
154				"multi\n  line\nerror message; " +
155				"single line error message",
156		},
157		{
158			giveErrors: []error{
159				errors.New("foo"),
160				newMultiErr(
161					errors.New("bar"),
162					errors.New("baz"),
163				),
164				errors.New("qux"),
165			},
166			wantError: newMultiErr(
167				errors.New("foo"),
168				errors.New("bar"),
169				errors.New("baz"),
170				errors.New("qux"),
171			),
172			wantMultiline: "the following errors occurred:\n" +
173				" -  foo\n" +
174				" -  bar\n" +
175				" -  baz\n" +
176				" -  qux",
177			wantSingleline: "foo; bar; baz; qux",
178		},
179		{
180			giveErrors: []error{
181				errors.New("foo"),
182				nil,
183				newMultiErr(
184					errors.New("bar"),
185				),
186				nil,
187			},
188			wantError: newMultiErr(
189				errors.New("foo"),
190				errors.New("bar"),
191			),
192			wantMultiline: "the following errors occurred:\n" +
193				" -  foo\n" +
194				" -  bar",
195			wantSingleline: "foo; bar",
196		},
197		{
198			giveErrors: []error{
199				errors.New("foo"),
200				newMultiErr(
201					errors.New("bar"),
202				),
203			},
204			wantError: newMultiErr(
205				errors.New("foo"),
206				errors.New("bar"),
207			),
208			wantMultiline: "the following errors occurred:\n" +
209				" -  foo\n" +
210				" -  bar",
211			wantSingleline: "foo; bar",
212		},
213		{
214			giveErrors: []error{
215				errors.New("foo"),
216				richFormatError{},
217				errors.New("bar"),
218			},
219			wantError: newMultiErr(
220				errors.New("foo"),
221				richFormatError{},
222				errors.New("bar"),
223			),
224			wantMultiline: "the following errors occurred:\n" +
225				" -  foo\n" +
226				" -  multiline\n" +
227				"    message\n" +
228				"    with plus\n" +
229				" -  bar",
230			wantSingleline: "foo; without plus; bar",
231		},
232	}
233
234	for i, tt := range tests {
235		t.Run(fmt.Sprint(i), func(t *testing.T) {
236			err := Combine(tt.giveErrors...)
237			require.Equal(t, tt.wantError, err)
238
239			if tt.wantMultiline != "" {
240				t.Run("Sprintf/multiline", func(t *testing.T) {
241					assert.Equal(t, tt.wantMultiline, fmt.Sprintf("%+v", err))
242				})
243			}
244
245			if tt.wantSingleline != "" {
246				t.Run("Sprintf/singleline", func(t *testing.T) {
247					assert.Equal(t, tt.wantSingleline, fmt.Sprintf("%v", err))
248				})
249
250				t.Run("Error()", func(t *testing.T) {
251					assert.Equal(t, tt.wantSingleline, err.Error())
252				})
253
254				if s, ok := err.(fmt.Stringer); ok {
255					t.Run("String()", func(t *testing.T) {
256						assert.Equal(t, tt.wantSingleline, s.String())
257					})
258				}
259			}
260		})
261	}
262}
263
264func TestCombineDoesNotModifySlice(t *testing.T) {
265	errors := []error{
266		errors.New("foo"),
267		nil,
268		errors.New("bar"),
269	}
270
271	assert.NotNil(t, Combine(errors...))
272	assert.Len(t, errors, 3)
273	assert.Nil(t, errors[1], 3)
274}
275
276func TestAppend(t *testing.T) {
277	tests := []struct {
278		left  error
279		right error
280		want  error
281	}{
282		{
283			left:  nil,
284			right: nil,
285			want:  nil,
286		},
287		{
288			left:  nil,
289			right: errors.New("great sadness"),
290			want:  errors.New("great sadness"),
291		},
292		{
293			left:  errors.New("great sadness"),
294			right: nil,
295			want:  errors.New("great sadness"),
296		},
297		{
298			left:  errors.New("foo"),
299			right: errors.New("bar"),
300			want: newMultiErr(
301				errors.New("foo"),
302				errors.New("bar"),
303			),
304		},
305		{
306			left: newMultiErr(
307				errors.New("foo"),
308				errors.New("bar"),
309			),
310			right: errors.New("baz"),
311			want: newMultiErr(
312				errors.New("foo"),
313				errors.New("bar"),
314				errors.New("baz"),
315			),
316		},
317		{
318			left: errors.New("baz"),
319			right: newMultiErr(
320				errors.New("foo"),
321				errors.New("bar"),
322			),
323			want: newMultiErr(
324				errors.New("baz"),
325				errors.New("foo"),
326				errors.New("bar"),
327			),
328		},
329		{
330			left: newMultiErr(
331				errors.New("foo"),
332			),
333			right: newMultiErr(
334				errors.New("bar"),
335			),
336			want: newMultiErr(
337				errors.New("foo"),
338				errors.New("bar"),
339			),
340		},
341	}
342
343	for _, tt := range tests {
344		assert.Equal(t, tt.want, Append(tt.left, tt.right))
345	}
346}
347
348type notMultiErr struct{}
349
350var _ errorGroup = notMultiErr{}
351
352func (notMultiErr) Error() string {
353	return "great sadness"
354}
355
356func (notMultiErr) Errors() []error {
357	return []error{errors.New("great sadness")}
358}
359
360func TestErrors(t *testing.T) {
361	tests := []struct {
362		give error
363		want []error
364
365		// Don't attempt to cast to errorGroup or *multiError
366		dontCast bool
367	}{
368		{dontCast: true}, // nil
369		{
370			give:     errors.New("hi"),
371			want:     []error{errors.New("hi")},
372			dontCast: true,
373		},
374		{
375			// We don't yet support non-multierr errors.
376			give:     notMultiErr{},
377			want:     []error{notMultiErr{}},
378			dontCast: true,
379		},
380		{
381			give: Combine(
382				errors.New("foo"),
383				errors.New("bar"),
384			),
385			want: []error{
386				errors.New("foo"),
387				errors.New("bar"),
388			},
389		},
390		{
391			give: Append(
392				errors.New("foo"),
393				errors.New("bar"),
394			),
395			want: []error{
396				errors.New("foo"),
397				errors.New("bar"),
398			},
399		},
400		{
401			give: Append(
402				errors.New("foo"),
403				Combine(
404					errors.New("bar"),
405				),
406			),
407			want: []error{
408				errors.New("foo"),
409				errors.New("bar"),
410			},
411		},
412		{
413			give: Combine(
414				errors.New("foo"),
415				Append(
416					errors.New("bar"),
417					errors.New("baz"),
418				),
419				errors.New("qux"),
420			),
421			want: []error{
422				errors.New("foo"),
423				errors.New("bar"),
424				errors.New("baz"),
425				errors.New("qux"),
426			},
427		},
428	}
429
430	for i, tt := range tests {
431		t.Run(fmt.Sprint(i), func(t *testing.T) {
432			t.Run("Errors()", func(t *testing.T) {
433				require.Equal(t, tt.want, Errors(tt.give))
434			})
435
436			if tt.dontCast {
437				return
438			}
439
440			t.Run("multiError", func(t *testing.T) {
441				require.Equal(t, tt.want, tt.give.(*multiError).Errors())
442			})
443
444			t.Run("errorGroup", func(t *testing.T) {
445				require.Equal(t, tt.want, tt.give.(errorGroup).Errors())
446			})
447		})
448	}
449}
450
451func createMultiErrWithCapacity() error {
452	// Create a multiError that has capacity for more errors so Append will
453	// modify the underlying array that may be shared.
454	return appendN(nil, errors.New("append"), 50)
455}
456
457func TestAppendDoesNotModify(t *testing.T) {
458	initial := createMultiErrWithCapacity()
459	err1 := Append(initial, errors.New("err1"))
460	err2 := Append(initial, errors.New("err2"))
461
462	// Make sure the error messages match, since we do modify the copyNeeded
463	// atomic, the values cannot be compared.
464	assert.EqualError(t, initial, createMultiErrWithCapacity().Error(), "Initial should not be modified")
465
466	assert.EqualError(t, err1, Append(createMultiErrWithCapacity(), errors.New("err1")).Error())
467	assert.EqualError(t, err2, Append(createMultiErrWithCapacity(), errors.New("err2")).Error())
468}
469
470func TestAppendRace(t *testing.T) {
471	initial := createMultiErrWithCapacity()
472
473	var wg sync.WaitGroup
474	for i := 0; i < 10; i++ {
475		wg.Add(1)
476		go func() {
477			defer wg.Done()
478
479			err := initial
480			for j := 0; j < 10; j++ {
481				err = Append(err, errors.New("err"))
482			}
483		}()
484	}
485
486	wg.Wait()
487}
488
489func TestErrorsSliceIsImmutable(t *testing.T) {
490	err1 := errors.New("err1")
491	err2 := errors.New("err2")
492
493	err := Append(err1, err2)
494	gotErrors := Errors(err)
495	require.Equal(t, []error{err1, err2}, gotErrors, "errors must match")
496
497	gotErrors[0] = nil
498	gotErrors[1] = errors.New("err3")
499
500	require.Equal(t, []error{err1, err2}, Errors(err),
501		"errors must match after modification")
502}
503
504func TestNilMultierror(t *testing.T) {
505	// For safety, all operations on multiError should be safe even if it is
506	// nil.
507	var err *multiError
508
509	require.Empty(t, err.Error())
510	require.Empty(t, err.Errors())
511}
512
513func TestAppendInto(t *testing.T) {
514	tests := []struct {
515		desc string
516		into *error
517		give error
518		want error
519	}{
520		{
521			desc: "append into empty",
522			into: new(error),
523			give: errors.New("foo"),
524			want: errors.New("foo"),
525		},
526		{
527			desc: "append into non-empty, non-multierr",
528			into: errorPtr(errors.New("foo")),
529			give: errors.New("bar"),
530			want: Combine(
531				errors.New("foo"),
532				errors.New("bar"),
533			),
534		},
535		{
536			desc: "append into non-empty multierr",
537			into: errorPtr(Combine(
538				errors.New("foo"),
539				errors.New("bar"),
540			)),
541			give: errors.New("baz"),
542			want: Combine(
543				errors.New("foo"),
544				errors.New("bar"),
545				errors.New("baz"),
546			),
547		},
548	}
549
550	for _, tt := range tests {
551		t.Run(tt.desc, func(t *testing.T) {
552			assert.True(t, AppendInto(tt.into, tt.give))
553			assert.Equal(t, tt.want, *tt.into)
554		})
555	}
556}
557
558func TestAppendIntoNil(t *testing.T) {
559	t.Run("nil pointer panics", func(t *testing.T) {
560		assert.Panics(t, func() {
561			AppendInto(nil, errors.New("foo"))
562		})
563	})
564
565	t.Run("nil error is no-op", func(t *testing.T) {
566		t.Run("empty left", func(t *testing.T) {
567			var err error
568			assert.False(t, AppendInto(&err, nil))
569			assert.Nil(t, err)
570		})
571
572		t.Run("non-empty left", func(t *testing.T) {
573			err := errors.New("foo")
574			assert.False(t, AppendInto(&err, nil))
575			assert.Equal(t, errors.New("foo"), err)
576		})
577	})
578}
579
// errorPtr returns a pointer to a newly allocated error variable initialised
// to err. Each call yields a distinct pointer.
func errorPtr(err error) *error {
	p := new(error)
	*p = err
	return p
}
583