1/*
2Copyright 2015 The Kubernetes Authors.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package errors
18
19import (
20	"errors"
21	"fmt"
22
23	"k8s.io/apimachinery/pkg/util/sets"
24)
25
26// MessageCountMap contains occurrence for each error message.
27type MessageCountMap map[string]int
28
29// Aggregate represents an object that contains multiple errors, but does not
30// necessarily have singular semantic meaning.
31type Aggregate interface {
32	error
33	Errors() []error
34}
35
36// NewAggregate converts a slice of errors into an Aggregate interface, which
37// is itself an implementation of the error interface.  If the slice is empty,
38// this returns nil.
39// It will check if any of the element of input error list is nil, to avoid
40// nil pointer panic when call Error().
41func NewAggregate(errlist []error) Aggregate {
42	if len(errlist) == 0 {
43		return nil
44	}
45	// In case of input error list contains nil
46	var errs []error
47	for _, e := range errlist {
48		if e != nil {
49			errs = append(errs, e)
50		}
51	}
52	if len(errs) == 0 {
53		return nil
54	}
55	return aggregate(errs)
56}
57
58// This helper implements the error and Errors interfaces.  Keeping it private
59// prevents people from making an aggregate of 0 errors, which is not
60// an error, but does satisfy the error interface.
61type aggregate []error
62
63// Error is part of the error interface.
64func (agg aggregate) Error() string {
65	if len(agg) == 0 {
66		// This should never happen, really.
67		return ""
68	}
69	if len(agg) == 1 {
70		return agg[0].Error()
71	}
72	seenerrs := sets.NewString()
73	result := ""
74	agg.visit(func(err error) {
75		msg := err.Error()
76		if seenerrs.Has(msg) {
77			return
78		}
79		seenerrs.Insert(msg)
80		if len(seenerrs) > 1 {
81			result += ", "
82		}
83		result += msg
84	})
85	if len(seenerrs) == 1 {
86		return result
87	}
88	return "[" + result + "]"
89}
90
91func (agg aggregate) visit(f func(err error)) {
92	for _, err := range agg {
93		switch err := err.(type) {
94		case aggregate:
95			err.visit(f)
96		case Aggregate:
97			for _, nestedErr := range err.Errors() {
98				f(nestedErr)
99			}
100		default:
101			f(err)
102		}
103	}
104}
105
106// Errors is part of the Aggregate interface.
107func (agg aggregate) Errors() []error {
108	return []error(agg)
109}
110
111// Matcher is used to match errors.  Returns true if the error matches.
112type Matcher func(error) bool
113
114// FilterOut removes all errors that match any of the matchers from the input
115// error.  If the input is a singular error, only that error is tested.  If the
116// input implements the Aggregate interface, the list of errors will be
117// processed recursively.
118//
119// This can be used, for example, to remove known-OK errors (such as io.EOF or
120// os.PathNotFound) from a list of errors.
121func FilterOut(err error, fns ...Matcher) error {
122	if err == nil {
123		return nil
124	}
125	if agg, ok := err.(Aggregate); ok {
126		return NewAggregate(filterErrors(agg.Errors(), fns...))
127	}
128	if !matchesError(err, fns...) {
129		return err
130	}
131	return nil
132}
133
134// matchesError returns true if any Matcher returns true
135func matchesError(err error, fns ...Matcher) bool {
136	for _, fn := range fns {
137		if fn(err) {
138			return true
139		}
140	}
141	return false
142}
143
144// filterErrors returns any errors (or nested errors, if the list contains
145// nested Errors) for which all fns return false. If no errors
146// remain a nil list is returned. The resulting silec will have all
147// nested slices flattened as a side effect.
148func filterErrors(list []error, fns ...Matcher) []error {
149	result := []error{}
150	for _, err := range list {
151		r := FilterOut(err, fns...)
152		if r != nil {
153			result = append(result, r)
154		}
155	}
156	return result
157}
158
159// Flatten takes an Aggregate, which may hold other Aggregates in arbitrary
160// nesting, and flattens them all into a single Aggregate, recursively.
161func Flatten(agg Aggregate) Aggregate {
162	result := []error{}
163	if agg == nil {
164		return nil
165	}
166	for _, err := range agg.Errors() {
167		if a, ok := err.(Aggregate); ok {
168			r := Flatten(a)
169			if r != nil {
170				result = append(result, r.Errors()...)
171			}
172		} else {
173			if err != nil {
174				result = append(result, err)
175			}
176		}
177	}
178	return NewAggregate(result)
179}
180
181// CreateAggregateFromMessageCountMap converts MessageCountMap Aggregate
182func CreateAggregateFromMessageCountMap(m MessageCountMap) Aggregate {
183	if m == nil {
184		return nil
185	}
186	result := make([]error, 0, len(m))
187	for errStr, count := range m {
188		var countStr string
189		if count > 1 {
190			countStr = fmt.Sprintf(" (repeated %v times)", count)
191		}
192		result = append(result, fmt.Errorf("%v%v", errStr, countStr))
193	}
194	return NewAggregate(result)
195}
196
197// Reduce will return err or, if err is an Aggregate and only has one item,
198// the first item in the aggregate.
199func Reduce(err error) error {
200	if agg, ok := err.(Aggregate); ok && err != nil {
201		switch len(agg.Errors()) {
202		case 1:
203			return agg.Errors()[0]
204		case 0:
205			return nil
206		}
207	}
208	return err
209}
210
211// AggregateGoroutines runs the provided functions in parallel, stuffing all
212// non-nil errors into the returned Aggregate.
213// Returns nil if all the functions complete successfully.
214func AggregateGoroutines(funcs ...func() error) Aggregate {
215	errChan := make(chan error, len(funcs))
216	for _, f := range funcs {
217		go func(f func() error) { errChan <- f() }(f)
218	}
219	errs := make([]error, 0)
220	for i := 0; i < cap(errChan); i++ {
221		if err := <-errChan; err != nil {
222			errs = append(errs, err)
223		}
224	}
225	return NewAggregate(errs)
226}
227
228// ErrPreconditionViolated is returned when the precondition is violated
229var ErrPreconditionViolated = errors.New("precondition is violated")
230