1// Copyright 2015 go-swagger maintainers
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//    http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package errors
16
17import (
18	"encoding/json"
19	"fmt"
20	"net/http"
21	"reflect"
22	"strings"
23)
24
var (
	// DefaultHTTPCode is the HTTP status that ServeError writes when an
	// error's Code() cannot be used as an HTTP status code (see asHTTPCode).
	// It starts out as 422 Unprocessable Entity and, being a variable rather
	// than a constant, may be reassigned by callers.
	DefaultHTTPCode = http.StatusUnprocessableEntity
)
27
// Error is the interface implemented by all swagger framework errors: an
// ordinary Go error that additionally reports a numeric code, which
// ServeError uses to pick the HTTP status of the response.
type Error interface {
	Code() int32
	error
}
33
// apiError is the package's basic Error implementation: a code paired with a
// message that has already been formatted.
type apiError struct {
	code    int32
	message string
}

// Error returns the stored message, satisfying the error interface.
func (a *apiError) Error() string { return a.message }

// Code returns the numeric code the error was created with.
func (a *apiError) Code() int32 { return a.code }

// MarshalJSON implements the JSON encoding interface, producing
// {"code":<code>,"message":<message>}. It uses a value receiver so that both
// apiError values and *apiError pointers encode the same way.
func (a apiError) MarshalJSON() ([]byte, error) {
	payload := struct {
		Code    int32  `json:"code"`
		Message string `json:"message"`
	}{
		Code:    a.code,
		Message: a.message,
	}
	return json.Marshal(payload)
}
54
55// New creates a new API error with a code and a message
56func New(code int32, message string, args ...interface{}) Error {
57	if len(args) > 0 {
58		return &apiError{code, fmt.Sprintf(message, args...)}
59	}
60	return &apiError{code, message}
61}
62
63// NotFound creates a new not found error
64func NotFound(message string, args ...interface{}) Error {
65	if message == "" {
66		message = "Not found"
67	}
68	return New(http.StatusNotFound, fmt.Sprintf(message, args...))
69}
70
71// NotImplemented creates a new not implemented error
72func NotImplemented(message string) Error {
73	return New(http.StatusNotImplemented, message)
74}
75
// MethodNotAllowedError represents an error for when the request path matches
// but the HTTP method doesn't. Allowed lists the methods that are accepted.
type MethodNotAllowedError struct {
	code    int32
	Allowed []string
	message string
}

// Error returns the human-readable message, satisfying the error interface.
func (m *MethodNotAllowedError) Error() string { return m.message }

// Code the error code
func (m *MethodNotAllowedError) Code() int32 { return m.code }

// MarshalJSON implements the JSON encoding interface, producing
// {"allowed":[...],"code":<code>,"message":<message>}. A nil Allowed slice
// encodes as null. It uses a value receiver so that values and pointers encode
// alike.
func (m MethodNotAllowedError) MarshalJSON() ([]byte, error) {
	// Field order mirrors the alphabetical key order json.Marshal uses for maps.
	body := struct {
		Allowed []string `json:"allowed"`
		Code    int32    `json:"code"`
		Message string   `json:"message"`
	}{
		Allowed: m.Allowed,
		Code:    m.code,
		Message: m.message,
	}
	return json.Marshal(body)
}
100
101func errorAsJSON(err Error) []byte {
102	b, _ := json.Marshal(struct {
103		Code    int32  `json:"code"`
104		Message string `json:"message"`
105	}{err.Code(), err.Error()})
106	return b
107}
108
109func flattenComposite(errs *CompositeError) *CompositeError {
110	var res []error
111	for _, er := range errs.Errors {
112		switch e := er.(type) {
113		case *CompositeError:
114			if len(e.Errors) > 0 {
115				flat := flattenComposite(e)
116				if len(flat.Errors) > 0 {
117					res = append(res, flat.Errors...)
118				}
119			}
120		default:
121			if e != nil {
122				res = append(res, e)
123			}
124		}
125	}
126	return CompositeValidationError(res...)
127}
128
129// MethodNotAllowed creates a new method not allowed error
130func MethodNotAllowed(requested string, allow []string) Error {
131	msg := fmt.Sprintf("method %s is not allowed, but [%s] are", requested, strings.Join(allow, ","))
132	return &MethodNotAllowedError{code: http.StatusMethodNotAllowed, Allowed: allow, message: msg}
133}
134
135// ServeError the error handler interface implementation
136func ServeError(rw http.ResponseWriter, r *http.Request, err error) {
137	rw.Header().Set("Content-Type", "application/json")
138	switch e := err.(type) {
139	case *CompositeError:
140		er := flattenComposite(e)
141		// strips composite errors to first element only
142		if len(er.Errors) > 0 {
143			ServeError(rw, r, er.Errors[0])
144		} else {
145			// guard against empty CompositeError (invalid construct)
146			ServeError(rw, r, nil)
147		}
148	case *MethodNotAllowedError:
149		rw.Header().Add("Allow", strings.Join(err.(*MethodNotAllowedError).Allowed, ","))
150		rw.WriteHeader(asHTTPCode(int(e.Code())))
151		if r == nil || r.Method != http.MethodHead {
152			_, _ = rw.Write(errorAsJSON(e))
153		}
154	case Error:
155		value := reflect.ValueOf(e)
156		if value.Kind() == reflect.Ptr && value.IsNil() {
157			rw.WriteHeader(http.StatusInternalServerError)
158			_, _ = rw.Write(errorAsJSON(New(http.StatusInternalServerError, "Unknown error")))
159			return
160		}
161		rw.WriteHeader(asHTTPCode(int(e.Code())))
162		if r == nil || r.Method != http.MethodHead {
163			_, _ = rw.Write(errorAsJSON(e))
164		}
165	case nil:
166		rw.WriteHeader(http.StatusInternalServerError)
167		_, _ = rw.Write(errorAsJSON(New(http.StatusInternalServerError, "Unknown error")))
168	default:
169		rw.WriteHeader(http.StatusInternalServerError)
170		if r == nil || r.Method != http.MethodHead {
171			_, _ = rw.Write(errorAsJSON(New(http.StatusInternalServerError, err.Error())))
172		}
173	}
174}
175
176func asHTTPCode(input int) int {
177	if input >= 600 {
178		return DefaultHTTPCode
179	}
180	return input
181}
182