1/*
2Copyright 2015 The Kubernetes Authors.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package errors
18
19import (
20	"errors"
21	"fmt"
22
23	"k8s.io/apimachinery/pkg/util/sets"
24)
25
// MessageCountMap maps each distinct error message to the number of times
// that message occurred.
type MessageCountMap map[string]int
28
// Aggregate is an error that bundles several errors together without the
// bundle itself carrying any particular meaning.
//
// An Aggregate can be passed to errors.Is to check whether a specific error
// occurs among its members. errors.As is deliberately not supported: a caller
// asking for a given error type presumably wants one specific error, and an
// Aggregate may hold several errors that all match that type.
type Aggregate interface {
	error
	Errors() []error
	Is(error) bool
}
40
41// NewAggregate converts a slice of errors into an Aggregate interface, which
42// is itself an implementation of the error interface.  If the slice is empty,
43// this returns nil.
44// It will check if any of the element of input error list is nil, to avoid
45// nil pointer panic when call Error().
46func NewAggregate(errlist []error) Aggregate {
47	if len(errlist) == 0 {
48		return nil
49	}
50	// In case of input error list contains nil
51	var errs []error
52	for _, e := range errlist {
53		if e != nil {
54			errs = append(errs, e)
55		}
56	}
57	if len(errs) == 0 {
58		return nil
59	}
60	return aggregate(errs)
61}
62
// aggregate is the unexported Aggregate implementation that NewAggregate
// returns. Keeping the type private stops callers from building an aggregate
// of zero errors: such a value would satisfy the error interface without
// representing any error. NewAggregate returns nil in that case instead.
type aggregate []error
67
68// Error is part of the error interface.
69func (agg aggregate) Error() string {
70	if len(agg) == 0 {
71		// This should never happen, really.
72		return ""
73	}
74	if len(agg) == 1 {
75		return agg[0].Error()
76	}
77	seenerrs := sets.NewString()
78	result := ""
79	agg.visit(func(err error) bool {
80		msg := err.Error()
81		if seenerrs.Has(msg) {
82			return false
83		}
84		seenerrs.Insert(msg)
85		if len(seenerrs) > 1 {
86			result += ", "
87		}
88		result += msg
89		return false
90	})
91	if len(seenerrs) == 1 {
92		return result
93	}
94	return "[" + result + "]"
95}
96
97func (agg aggregate) Is(target error) bool {
98	return agg.visit(func(err error) bool {
99		return errors.Is(err, target)
100	})
101}
102
103func (agg aggregate) visit(f func(err error) bool) bool {
104	for _, err := range agg {
105		switch err := err.(type) {
106		case aggregate:
107			if match := err.visit(f); match {
108				return match
109			}
110		case Aggregate:
111			for _, nestedErr := range err.Errors() {
112				if match := f(nestedErr); match {
113					return match
114				}
115			}
116		default:
117			if match := f(err); match {
118				return match
119			}
120		}
121	}
122
123	return false
124}
125
126// Errors is part of the Aggregate interface.
127func (agg aggregate) Errors() []error {
128	return []error(agg)
129}
130
// Matcher is a predicate over errors: it returns true when the given error
// matches. FilterOut uses Matchers to decide which errors to drop.
type Matcher func(error) bool
133
134// FilterOut removes all errors that match any of the matchers from the input
135// error.  If the input is a singular error, only that error is tested.  If the
136// input implements the Aggregate interface, the list of errors will be
137// processed recursively.
138//
139// This can be used, for example, to remove known-OK errors (such as io.EOF or
140// os.PathNotFound) from a list of errors.
141func FilterOut(err error, fns ...Matcher) error {
142	if err == nil {
143		return nil
144	}
145	if agg, ok := err.(Aggregate); ok {
146		return NewAggregate(filterErrors(agg.Errors(), fns...))
147	}
148	if !matchesError(err, fns...) {
149		return err
150	}
151	return nil
152}
153
154// matchesError returns true if any Matcher returns true
155func matchesError(err error, fns ...Matcher) bool {
156	for _, fn := range fns {
157		if fn(err) {
158			return true
159		}
160	}
161	return false
162}
163
164// filterErrors returns any errors (or nested errors, if the list contains
165// nested Errors) for which all fns return false. If no errors
166// remain a nil list is returned. The resulting silec will have all
167// nested slices flattened as a side effect.
168func filterErrors(list []error, fns ...Matcher) []error {
169	result := []error{}
170	for _, err := range list {
171		r := FilterOut(err, fns...)
172		if r != nil {
173			result = append(result, r)
174		}
175	}
176	return result
177}
178
179// Flatten takes an Aggregate, which may hold other Aggregates in arbitrary
180// nesting, and flattens them all into a single Aggregate, recursively.
181func Flatten(agg Aggregate) Aggregate {
182	result := []error{}
183	if agg == nil {
184		return nil
185	}
186	for _, err := range agg.Errors() {
187		if a, ok := err.(Aggregate); ok {
188			r := Flatten(a)
189			if r != nil {
190				result = append(result, r.Errors()...)
191			}
192		} else {
193			if err != nil {
194				result = append(result, err)
195			}
196		}
197	}
198	return NewAggregate(result)
199}
200
201// CreateAggregateFromMessageCountMap converts MessageCountMap Aggregate
202func CreateAggregateFromMessageCountMap(m MessageCountMap) Aggregate {
203	if m == nil {
204		return nil
205	}
206	result := make([]error, 0, len(m))
207	for errStr, count := range m {
208		var countStr string
209		if count > 1 {
210			countStr = fmt.Sprintf(" (repeated %v times)", count)
211		}
212		result = append(result, fmt.Errorf("%v%v", errStr, countStr))
213	}
214	return NewAggregate(result)
215}
216
217// Reduce will return err or, if err is an Aggregate and only has one item,
218// the first item in the aggregate.
219func Reduce(err error) error {
220	if agg, ok := err.(Aggregate); ok && err != nil {
221		switch len(agg.Errors()) {
222		case 1:
223			return agg.Errors()[0]
224		case 0:
225			return nil
226		}
227	}
228	return err
229}
230
231// AggregateGoroutines runs the provided functions in parallel, stuffing all
232// non-nil errors into the returned Aggregate.
233// Returns nil if all the functions complete successfully.
234func AggregateGoroutines(funcs ...func() error) Aggregate {
235	errChan := make(chan error, len(funcs))
236	for _, f := range funcs {
237		go func(f func() error) { errChan <- f() }(f)
238	}
239	errs := make([]error, 0)
240	for i := 0; i < cap(errChan); i++ {
241		if err := <-errChan; err != nil {
242			errs = append(errs, err)
243		}
244	}
245	return NewAggregate(errs)
246}
247
var (
	// ErrPreconditionViolated is the sentinel error reported when a required
	// precondition does not hold.
	ErrPreconditionViolated = errors.New("precondition is violated")
)
250