1// Copyright 2015 go-swagger maintainers
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//    http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package errors
16
17import (
18	"encoding/json"
19	"fmt"
20	"net/http"
21	"reflect"
22	"strings"
23)
24
var (
	// DefaultHTTPCode is the status ServeError falls back to when an error's
	// Code() is not usable as an HTTP status code. It starts out as 422
	// (Unprocessable Entity); being an exported variable, applications may
	// reassign it.
	DefaultHTTPCode = http.StatusUnprocessableEntity
)
27
// Error is the interface implemented by all swagger framework errors: a
// standard Go error that additionally exposes a numeric code through Code.
type Error interface {
	// Code returns the numeric code associated with the error.
	Code() int32

	error
}
33
// apiError is the unexported Error implementation built by New.
type apiError struct {
	code    int32  // value reported by Code
	message string // value reported by Error
}

// Error implements the error interface by returning the stored message.
func (e *apiError) Error() string { return e.message }

// Code returns the code the error was created with.
func (e *apiError) Code() int32 { return e.code }
46
47// New creates a new API error with a code and a message
48func New(code int32, message string, args ...interface{}) Error {
49	if len(args) > 0 {
50		return &apiError{code, fmt.Sprintf(message, args...)}
51	}
52	return &apiError{code, message}
53}
54
55// NotFound creates a new not found error
56func NotFound(message string, args ...interface{}) Error {
57	if message == "" {
58		message = "Not found"
59	}
60	return New(http.StatusNotFound, fmt.Sprintf(message, args...))
61}
62
63// NotImplemented creates a new not implemented error
64func NotImplemented(message string) Error {
65	return New(http.StatusNotImplemented, message)
66}
67
// MethodNotAllowedError represents an error for when a request path matches
// a route but its HTTP method does not. ServeError advertises Allowed in the
// response's "Allow" header.
type MethodNotAllowedError struct {
	code    int32    // value reported by Code
	Allowed []string // methods accepted for the matched path
	message string   // value reported by Error
}

// Error implements the error interface by returning the stored message.
func (e *MethodNotAllowedError) Error() string { return e.message }

// Code the error code (MethodNotAllowed sets it to 405).
func (e *MethodNotAllowedError) Code() int32 { return e.code }
83
84func errorAsJSON(err Error) []byte {
85	b, _ := json.Marshal(struct {
86		Code    int32  `json:"code"`
87		Message string `json:"message"`
88	}{err.Code(), err.Error()})
89	return b
90}
91
92func flattenComposite(errs *CompositeError) *CompositeError {
93	var res []error
94	for _, er := range errs.Errors {
95		switch e := er.(type) {
96		case *CompositeError:
97			if len(e.Errors) > 0 {
98				flat := flattenComposite(e)
99				if len(flat.Errors) > 0 {
100					res = append(res, flat.Errors...)
101				}
102			}
103		default:
104			if e != nil {
105				res = append(res, e)
106			}
107		}
108	}
109	return CompositeValidationError(res...)
110}
111
112// MethodNotAllowed creates a new method not allowed error
113func MethodNotAllowed(requested string, allow []string) Error {
114	msg := fmt.Sprintf("method %s is not allowed, but [%s] are", requested, strings.Join(allow, ","))
115	return &MethodNotAllowedError{code: http.StatusMethodNotAllowed, Allowed: allow, message: msg}
116}
117
118// ServeError the error handler interface implementation
119func ServeError(rw http.ResponseWriter, r *http.Request, err error) {
120	rw.Header().Set("Content-Type", "application/json")
121	switch e := err.(type) {
122	case *CompositeError:
123		er := flattenComposite(e)
124		// strips composite errors to first element only
125		if len(er.Errors) > 0 {
126			ServeError(rw, r, er.Errors[0])
127		} else {
128			// guard against empty CompositeError (invalid construct)
129			ServeError(rw, r, nil)
130		}
131	case *MethodNotAllowedError:
132		rw.Header().Add("Allow", strings.Join(err.(*MethodNotAllowedError).Allowed, ","))
133		rw.WriteHeader(asHTTPCode(int(e.Code())))
134		if r == nil || r.Method != http.MethodHead {
135			_, _ = rw.Write(errorAsJSON(e))
136		}
137	case Error:
138		value := reflect.ValueOf(e)
139		if value.Kind() == reflect.Ptr && value.IsNil() {
140			rw.WriteHeader(http.StatusInternalServerError)
141			_, _ = rw.Write(errorAsJSON(New(http.StatusInternalServerError, "Unknown error")))
142			return
143		}
144		rw.WriteHeader(asHTTPCode(int(e.Code())))
145		if r == nil || r.Method != http.MethodHead {
146			_, _ = rw.Write(errorAsJSON(e))
147		}
148	case nil:
149		rw.WriteHeader(http.StatusInternalServerError)
150		_, _ = rw.Write(errorAsJSON(New(http.StatusInternalServerError, "Unknown error")))
151	default:
152		rw.WriteHeader(http.StatusInternalServerError)
153		if r == nil || r.Method != http.MethodHead {
154			_, _ = rw.Write(errorAsJSON(New(http.StatusInternalServerError, err.Error())))
155		}
156	}
157}
158
159func asHTTPCode(input int) int {
160	if input >= 600 {
161		return DefaultHTTPCode
162	}
163	return input
164}
165