1// Copyright (c) 2019 Uber Technologies, Inc. 2// 3// Permission is hereby granted, free of charge, to any person obtaining a copy 4// of this software and associated documentation files (the "Software"), to deal 5// in the Software without restriction, including without limitation the rights 6// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7// copies of the Software, and to permit persons to whom the Software is 8// furnished to do so, subject to the following conditions: 9// 10// The above copyright notice and this permission notice shall be included in 11// all copies or substantial portions of the Software. 12// 13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19// THE SOFTWARE. 20 21package multierr 22 23import ( 24 "errors" 25 "fmt" 26 "io" 27 "sync" 28 "testing" 29 30 "github.com/stretchr/testify/assert" 31 "github.com/stretchr/testify/require" 32) 33 34// richFormatError is an error that prints a different output depending on 35// whether %v or %+v was used. 36type richFormatError struct{} 37 38func (r richFormatError) Error() string { 39 return fmt.Sprint(r) 40} 41 42func (richFormatError) Format(f fmt.State, c rune) { 43 if c == 'v' && f.Flag('+') { 44 io.WriteString(f, "multiline\nmessage\nwith plus") 45 } else { 46 io.WriteString(f, "without plus") 47 } 48} 49 50func appendN(initial, err error, n int) error { 51 errs := initial 52 for i := 0; i < n; i++ { 53 errs = Append(errs, err) 54 } 55 return errs 56} 57 58func newMultiErr(errors ...error) error { 59 return &multiError{errors: errors} 60} 61 62func TestCombine(t *testing.T) { 63 tests := []struct { 64 // Input 65 giveErrors []error 66 67 // Resulting error 68 wantError error 69 70 // %+v and %v string representations 71 wantMultiline string 72 wantSingleline string 73 }{ 74 { 75 giveErrors: nil, 76 wantError: nil, 77 }, 78 { 79 giveErrors: []error{}, 80 wantError: nil, 81 }, 82 { 83 giveErrors: []error{ 84 errors.New("foo"), 85 nil, 86 newMultiErr( 87 errors.New("bar"), 88 ), 89 nil, 90 }, 91 wantError: newMultiErr( 92 errors.New("foo"), 93 errors.New("bar"), 94 ), 95 wantMultiline: "the following errors occurred:\n" + 96 " - foo\n" + 97 " - bar", 98 wantSingleline: "foo; bar", 99 }, 100 { 101 giveErrors: []error{ 102 errors.New("foo"), 103 newMultiErr( 104 errors.New("bar"), 105 ), 106 }, 107 wantError: newMultiErr( 108 errors.New("foo"), 109 errors.New("bar"), 110 ), 111 wantMultiline: "the following errors occurred:\n" + 112 " - foo\n" + 113 " - bar", 114 wantSingleline: "foo; bar", 115 }, 116 { 117 giveErrors: []error{errors.New("great sadness")}, 118 wantError: errors.New("great sadness"), 119 wantMultiline: "great sadness", 120 wantSingleline: "great sadness", 121 }, 122 { 123 giveErrors: []error{ 124 errors.New("foo"), 125 errors.New("bar"), 126 }, 127 wantError: newMultiErr( 128 errors.New("foo"), 129 errors.New("bar"), 130 ), 131 wantMultiline: "the following errors occurred:\n" + 132 " - foo\n" + 133 " - bar", 134 wantSingleline: "foo; bar", 135 }, 136 { 137 giveErrors: []error{ 138 errors.New("great sadness"), 139 errors.New("multi\n line\nerror message"), 140 errors.New("single line error message"), 141 }, 142 wantError: newMultiErr( 143 errors.New("great sadness"), 144 errors.New("multi\n line\nerror message"), 145 errors.New("single line error message"), 146 ), 147 wantMultiline: "the following errors occurred:\n" + 148 " - great sadness\n" + 149 " - multi\n" + 150 " line\n" + 151 " error message\n" + 152 " - single line error message", 153 wantSingleline: "great sadness; " + 154 "multi\n line\nerror message; " + 155 "single line error message", 156 }, 157 { 158 giveErrors: []error{ 159 errors.New("foo"), 160 newMultiErr( 161 errors.New("bar"), 162 errors.New("baz"), 163 ), 164 errors.New("qux"), 165 }, 166 wantError: newMultiErr( 167 errors.New("foo"), 168 errors.New("bar"), 169 errors.New("baz"), 170 errors.New("qux"), 171 ), 172 wantMultiline: "the following errors occurred:\n" + 173 " - foo\n" + 174 " - bar\n" + 175 " - baz\n" + 176 " - qux", 177 wantSingleline: "foo; bar; baz; qux", 178 }, 179 { 180 giveErrors: []error{ 181 errors.New("foo"), 182 nil, 183 newMultiErr( 184 errors.New("bar"), 185 ), 186 nil, 187 }, 188 wantError: newMultiErr( 189 errors.New("foo"), 190 errors.New("bar"), 191 ), 192 wantMultiline: "the following errors occurred:\n" + 193 " - foo\n" + 194 " - bar", 195 wantSingleline: "foo; bar", 196 }, 197 { 198 giveErrors: []error{ 199 errors.New("foo"), 200 newMultiErr( 201 errors.New("bar"), 202 ), 203 }, 204 wantError: newMultiErr( 205 errors.New("foo"), 206 errors.New("bar"), 207 ), 208 wantMultiline: "the following errors occurred:\n" + 209 " - foo\n" + 210 " - bar", 211 wantSingleline: "foo; bar", 212 }, 213 { 214 giveErrors: []error{ 215 errors.New("foo"), 216 richFormatError{}, 217 errors.New("bar"), 218 }, 219 wantError: newMultiErr( 220 errors.New("foo"), 221 richFormatError{}, 222 errors.New("bar"), 223 ), 224 wantMultiline: "the following errors occurred:\n" + 225 " - foo\n" + 226 " - multiline\n" + 227 " message\n" + 228 " with plus\n" + 229 " - bar", 230 wantSingleline: "foo; without plus; bar", 231 }, 232 } 233 234 for i, tt := range tests { 235 t.Run(fmt.Sprint(i), func(t *testing.T) { 236 err := Combine(tt.giveErrors...) 237 require.Equal(t, tt.wantError, err) 238 239 if tt.wantMultiline != "" { 240 t.Run("Sprintf/multiline", func(t *testing.T) { 241 assert.Equal(t, tt.wantMultiline, fmt.Sprintf("%+v", err)) 242 }) 243 } 244 245 if tt.wantSingleline != "" { 246 t.Run("Sprintf/singleline", func(t *testing.T) { 247 assert.Equal(t, tt.wantSingleline, fmt.Sprintf("%v", err)) 248 }) 249 250 t.Run("Error()", func(t *testing.T) { 251 assert.Equal(t, tt.wantSingleline, err.Error()) 252 }) 253 254 if s, ok := err.(fmt.Stringer); ok { 255 t.Run("String()", func(t *testing.T) { 256 assert.Equal(t, tt.wantSingleline, s.String()) 257 }) 258 } 259 } 260 }) 261 } 262} 263 264func TestCombineDoesNotModifySlice(t *testing.T) { 265 errors := []error{ 266 errors.New("foo"), 267 nil, 268 errors.New("bar"), 269 } 270 271 assert.NotNil(t, Combine(errors...)) 272 assert.Len(t, errors, 3) 273 assert.Nil(t, errors[1], 3) 274} 275 276func TestAppend(t *testing.T) { 277 tests := []struct { 278 left error 279 right error 280 want error 281 }{ 282 { 283 left: nil, 284 right: nil, 285 want: nil, 286 }, 287 { 288 left: nil, 289 right: errors.New("great sadness"), 290 want: errors.New("great sadness"), 291 }, 292 { 293 left: errors.New("great sadness"), 294 right: nil, 295 want: errors.New("great sadness"), 296 }, 297 { 298 left: errors.New("foo"), 299 right: errors.New("bar"), 300 want: newMultiErr( 301 errors.New("foo"), 302 errors.New("bar"), 303 ), 304 }, 305 { 306 left: newMultiErr( 307 errors.New("foo"), 308 errors.New("bar"), 309 ), 310 right: errors.New("baz"), 311 want: newMultiErr( 312 errors.New("foo"), 313 errors.New("bar"), 314 errors.New("baz"), 315 ), 316 }, 317 { 318 left: errors.New("baz"), 319 right: newMultiErr( 320 errors.New("foo"), 321 errors.New("bar"), 322 ), 323 want: newMultiErr( 324 errors.New("baz"), 325 errors.New("foo"), 326 errors.New("bar"), 327 ), 328 }, 329 { 330 left: newMultiErr( 331 errors.New("foo"), 332 ), 333 right: newMultiErr( 334 errors.New("bar"), 335 ), 336 want: newMultiErr( 337 errors.New("foo"), 338 errors.New("bar"), 339 ), 340 }, 341 } 342 343 for _, tt := range tests { 344 assert.Equal(t, tt.want, Append(tt.left, tt.right)) 345 } 346} 347 348type notMultiErr struct{} 349 350var _ errorGroup = notMultiErr{} 351 352func (notMultiErr) Error() string { 353 return "great sadness" 354} 355 356func (notMultiErr) Errors() []error { 357 return []error{errors.New("great sadness")} 358} 359 360func TestErrors(t *testing.T) { 361 tests := []struct { 362 give error 363 want []error 364 365 // Don't attempt to cast to errorGroup or *multiError 366 dontCast bool 367 }{ 368 {dontCast: true}, // nil 369 { 370 give: errors.New("hi"), 371 want: []error{errors.New("hi")}, 372 dontCast: true, 373 }, 374 { 375 // We don't yet support non-multierr errors. 376 give: notMultiErr{}, 377 want: []error{notMultiErr{}}, 378 dontCast: true, 379 }, 380 { 381 give: Combine( 382 errors.New("foo"), 383 errors.New("bar"), 384 ), 385 want: []error{ 386 errors.New("foo"), 387 errors.New("bar"), 388 }, 389 }, 390 { 391 give: Append( 392 errors.New("foo"), 393 errors.New("bar"), 394 ), 395 want: []error{ 396 errors.New("foo"), 397 errors.New("bar"), 398 }, 399 }, 400 { 401 give: Append( 402 errors.New("foo"), 403 Combine( 404 errors.New("bar"), 405 ), 406 ), 407 want: []error{ 408 errors.New("foo"), 409 errors.New("bar"), 410 }, 411 }, 412 { 413 give: Combine( 414 errors.New("foo"), 415 Append( 416 errors.New("bar"), 417 errors.New("baz"), 418 ), 419 errors.New("qux"), 420 ), 421 want: []error{ 422 errors.New("foo"), 423 errors.New("bar"), 424 errors.New("baz"), 425 errors.New("qux"), 426 }, 427 }, 428 } 429 430 for i, tt := range tests { 431 t.Run(fmt.Sprint(i), func(t *testing.T) { 432 t.Run("Errors()", func(t *testing.T) { 433 require.Equal(t, tt.want, Errors(tt.give)) 434 }) 435 436 if tt.dontCast { 437 return 438 } 439 440 t.Run("multiError", func(t *testing.T) { 441 require.Equal(t, tt.want, tt.give.(*multiError).Errors()) 442 }) 443 444 t.Run("errorGroup", func(t *testing.T) { 445 require.Equal(t, tt.want, tt.give.(errorGroup).Errors()) 446 }) 447 }) 448 } 449} 450 451func createMultiErrWithCapacity() error { 452 // Create a multiError that has capacity for more errors so Append will 453 // modify the underlying array that may be shared. 454 return appendN(nil, errors.New("append"), 50) 455} 456 457func TestAppendDoesNotModify(t *testing.T) { 458 initial := createMultiErrWithCapacity() 459 err1 := Append(initial, errors.New("err1")) 460 err2 := Append(initial, errors.New("err2")) 461 462 // Make sure the error messages match, since we do modify the copyNeeded 463 // atomic, the values cannot be compared. 464 assert.EqualError(t, initial, createMultiErrWithCapacity().Error(), "Initial should not be modified") 465 466 assert.EqualError(t, err1, Append(createMultiErrWithCapacity(), errors.New("err1")).Error()) 467 assert.EqualError(t, err2, Append(createMultiErrWithCapacity(), errors.New("err2")).Error()) 468} 469 470func TestAppendRace(t *testing.T) { 471 initial := createMultiErrWithCapacity() 472 473 var wg sync.WaitGroup 474 for i := 0; i < 10; i++ { 475 wg.Add(1) 476 go func() { 477 defer wg.Done() 478 479 err := initial 480 for j := 0; j < 10; j++ { 481 err = Append(err, errors.New("err")) 482 } 483 }() 484 } 485 486 wg.Wait() 487} 488 489func TestErrorsSliceIsImmutable(t *testing.T) { 490 err1 := errors.New("err1") 491 err2 := errors.New("err2") 492 493 err := Append(err1, err2) 494 gotErrors := Errors(err) 495 require.Equal(t, []error{err1, err2}, gotErrors, "errors must match") 496 497 gotErrors[0] = nil 498 gotErrors[1] = errors.New("err3") 499 500 require.Equal(t, []error{err1, err2}, Errors(err), 501 "errors must match after modification") 502} 503 504func TestNilMultierror(t *testing.T) { 505 // For safety, all operations on multiError should be safe even if it is 506 // nil. 507 var err *multiError 508 509 require.Empty(t, err.Error()) 510 require.Empty(t, err.Errors()) 511} 512 513func TestAppendInto(t *testing.T) { 514 tests := []struct { 515 desc string 516 into *error 517 give error 518 want error 519 }{ 520 { 521 desc: "append into empty", 522 into: new(error), 523 give: errors.New("foo"), 524 want: errors.New("foo"), 525 }, 526 { 527 desc: "append into non-empty, non-multierr", 528 into: errorPtr(errors.New("foo")), 529 give: errors.New("bar"), 530 want: Combine( 531 errors.New("foo"), 532 errors.New("bar"), 533 ), 534 }, 535 { 536 desc: "append into non-empty multierr", 537 into: errorPtr(Combine( 538 errors.New("foo"), 539 errors.New("bar"), 540 )), 541 give: errors.New("baz"), 542 want: Combine( 543 errors.New("foo"), 544 errors.New("bar"), 545 errors.New("baz"), 546 ), 547 }, 548 } 549 550 for _, tt := range tests { 551 t.Run(tt.desc, func(t *testing.T) { 552 assert.True(t, AppendInto(tt.into, tt.give)) 553 assert.Equal(t, tt.want, *tt.into) 554 }) 555 } 556} 557 558func TestAppendIntoNil(t *testing.T) { 559 t.Run("nil pointer panics", func(t *testing.T) { 560 assert.Panics(t, func() { 561 AppendInto(nil, errors.New("foo")) 562 }) 563 }) 564 565 t.Run("nil error is no-op", func(t *testing.T) { 566 t.Run("empty left", func(t *testing.T) { 567 var err error 568 assert.False(t, AppendInto(&err, nil)) 569 assert.Nil(t, err) 570 }) 571 572 t.Run("non-empty left", func(t *testing.T) { 573 err := errors.New("foo") 574 assert.False(t, AppendInto(&err, nil)) 575 assert.Equal(t, errors.New("foo"), err) 576 }) 577 }) 578} 579 580func errorPtr(err error) *error { 581 return &err 582} 583