1package logical
2
3import (
4	"encoding/json"
5	"errors"
6	"fmt"
7	"net/http"
8
9	"github.com/hashicorp/errwrap"
10	multierror "github.com/hashicorp/go-multierror"
11	"github.com/hashicorp/vault/sdk/helper/consts"
12)
13
14// RespondErrorCommon pulls most of the functionality from http's
15// respondErrorCommon and some of http's handleLogical and makes it available
16// to both the http package and elsewhere.
17func RespondErrorCommon(req *Request, resp *Response, err error) (int, error) {
18	if err == nil && (resp == nil || !resp.IsError()) {
19		switch {
20		case req.Operation == ReadOperation:
21			if resp == nil {
22				return http.StatusNotFound, nil
23			}
24
25		// Basically: if we have empty "keys" or no keys at all, 404. This
26		// provides consistency with GET.
27		case req.Operation == ListOperation && (resp == nil || resp.WrapInfo == nil):
28			if resp == nil {
29				return http.StatusNotFound, nil
30			}
31			if len(resp.Data) == 0 {
32				if len(resp.Warnings) > 0 {
33					return 0, nil
34				}
35				return http.StatusNotFound, nil
36			}
37			keysRaw, ok := resp.Data["keys"]
38			if !ok || keysRaw == nil {
39				// If we don't have keys but have other data, return as-is
40				if len(resp.Data) > 0 || len(resp.Warnings) > 0 {
41					return 0, nil
42				}
43				return http.StatusNotFound, nil
44			}
45
46			var keys []string
47			switch keysRaw.(type) {
48			case []interface{}:
49				keys = make([]string, len(keysRaw.([]interface{})))
50				for i, el := range keysRaw.([]interface{}) {
51					s, ok := el.(string)
52					if !ok {
53						return http.StatusInternalServerError, nil
54					}
55					keys[i] = s
56				}
57
58			case []string:
59				keys = keysRaw.([]string)
60			default:
61				return http.StatusInternalServerError, nil
62			}
63
64			if len(keys) == 0 {
65				return http.StatusNotFound, nil
66			}
67		}
68
69		return 0, nil
70	}
71
72	if errwrap.ContainsType(err, new(ReplicationCodedError)) {
73		var allErrors error
74		var codedErr *ReplicationCodedError
75		errwrap.Walk(err, func(inErr error) {
76			newErr, ok := inErr.(*ReplicationCodedError)
77			if ok {
78				codedErr = newErr
79			} else {
80				allErrors = multierror.Append(allErrors, inErr)
81			}
82		})
83		if allErrors != nil {
84			return codedErr.Code, multierror.Append(errors.New(fmt.Sprintf("errors from both primary and secondary; primary error was %v; secondary errors follow", codedErr.Msg)), allErrors)
85		}
86		return codedErr.Code, errors.New(codedErr.Msg)
87	}
88
89	// Start out with internal server error since in most of these cases there
90	// won't be a response so this won't be overridden
91	statusCode := http.StatusInternalServerError
92	// If we actually have a response, start out with bad request
93	if resp != nil {
94		statusCode = http.StatusBadRequest
95	}
96
97	// Now, check the error itself; if it has a specific logical error, set the
98	// appropriate code
99	if err != nil {
100		switch {
101		case errwrap.ContainsType(err, new(StatusBadRequest)):
102			statusCode = http.StatusBadRequest
103		case errwrap.Contains(err, ErrPermissionDenied.Error()):
104			statusCode = http.StatusForbidden
105		case errwrap.Contains(err, ErrUnsupportedOperation.Error()):
106			statusCode = http.StatusMethodNotAllowed
107		case errwrap.Contains(err, ErrUnsupportedPath.Error()):
108			statusCode = http.StatusNotFound
109		case errwrap.Contains(err, ErrInvalidRequest.Error()):
110			statusCode = http.StatusBadRequest
111		case errwrap.Contains(err, ErrUpstreamRateLimited.Error()):
112			statusCode = http.StatusBadGateway
113		}
114	}
115
116	if resp != nil && resp.IsError() {
117		err = fmt.Errorf("%s", resp.Data["error"].(string))
118	}
119
120	return statusCode, err
121}
122
123// AdjustErrorStatusCode adjusts the status that will be sent in error
124// conditions in a way that can be shared across http's respondError and other
125// locations.
126func AdjustErrorStatusCode(status *int, err error) {
127	// Handle nested errors
128	if t, ok := err.(*multierror.Error); ok {
129		for _, e := range t.Errors {
130			AdjustErrorStatusCode(status, e)
131		}
132	}
133
134	// Adjust status code when sealed
135	if errwrap.Contains(err, consts.ErrSealed.Error()) {
136		*status = http.StatusServiceUnavailable
137	}
138
139	// Adjust status code on
140	if errwrap.Contains(err, "http: request body too large") {
141		*status = http.StatusRequestEntityTooLarge
142	}
143
144	// Allow HTTPCoded error passthrough to specify a code
145	if t, ok := err.(HTTPCodedError); ok {
146		*status = t.Code()
147	}
148}
149
150func RespondError(w http.ResponseWriter, status int, err error) {
151	AdjustErrorStatusCode(&status, err)
152
153	w.Header().Set("Content-Type", "application/json")
154	w.WriteHeader(status)
155
156	type ErrorResponse struct {
157		Errors []string `json:"errors"`
158	}
159	resp := &ErrorResponse{Errors: make([]string, 0, 1)}
160	if err != nil {
161		resp.Errors = append(resp.Errors, err.Error())
162	}
163
164	enc := json.NewEncoder(w)
165	enc.Encode(resp)
166}
167