1package trace
2
3import (
4	"encoding/json"
5	"fmt"
6	"net/http"
7)
8
9// WriteError sets up HTTP error response and writes it to writer w
10func WriteError(w http.ResponseWriter, err error) {
11	if IsAggregate(err) {
12		for i := 0; i < maxHops; i++ {
13			var aggErr Aggregate
14			var ok bool
15			if aggErr, ok = Unwrap(err).(Aggregate); !ok {
16				break
17			}
18			errors := aggErr.Errors()
19			if len(errors) == 0 {
20				break
21			}
22			err = errors[0]
23		}
24	}
25	replyJSON(w, ErrorToCode(err), err)
26}
27
28// ErrorToCode returns an appropriate HTTP status code based on the provided error type
29func ErrorToCode(err error) int {
30	switch {
31	case IsAggregate(err):
32		return http.StatusGatewayTimeout
33	case IsNotFound(err):
34		return http.StatusNotFound
35	case IsBadParameter(err) || IsOAuth2(err):
36		return http.StatusBadRequest
37	case IsNotImplemented(err):
38		return http.StatusNotImplemented
39	case IsCompareFailed(err):
40		return http.StatusPreconditionFailed
41	case IsAccessDenied(err):
42		return http.StatusForbidden
43	case IsAlreadyExists(err):
44		return http.StatusConflict
45	case IsLimitExceeded(err):
46		return http.StatusTooManyRequests
47	case IsConnectionProblem(err):
48		return http.StatusGatewayTimeout
49	default:
50		return http.StatusInternalServerError
51	}
52}
53
54// ReadError converts http error to internal error type
55// based on HTTP response code and HTTP body contents
56// if status code does not indicate error, it will return nil
57func ReadError(statusCode int, re []byte) error {
58	var e error
59	switch statusCode {
60	case http.StatusNotFound:
61		e = &NotFoundError{Message: string(re)}
62	case http.StatusBadRequest:
63		e = &BadParameterError{Message: string(re)}
64	case http.StatusNotImplemented:
65		e = &NotImplementedError{Message: string(re)}
66	case http.StatusPreconditionFailed:
67		e = &CompareFailedError{Message: string(re)}
68	case http.StatusForbidden:
69		e = &AccessDeniedError{Message: string(re)}
70	case http.StatusConflict:
71		e = &AlreadyExistsError{Message: string(re)}
72	case http.StatusTooManyRequests:
73		e = &LimitExceededError{Message: string(re)}
74	case http.StatusGatewayTimeout:
75		e = &ConnectionProblemError{Message: string(re)}
76	default:
77		if statusCode < 200 || statusCode >= 400 {
78			return Errorf(string(re))
79		}
80		return nil
81	}
82	return unmarshalError(e, re)
83}
84
85func replyJSON(w http.ResponseWriter, code int, err error) {
86	w.Header().Set("Content-Type", "application/json")
87	w.WriteHeader(code)
88
89	var out []byte
90	if IsDebug() {
91		// trace error can marshal itself,
92		// otherwise capture error message and marshal it explicitly
93		var obj interface{} = err
94		if _, ok := err.(*TraceErr); !ok {
95			obj = message{Message: err.Error()}
96		}
97		out, err = json.MarshalIndent(obj, "", "    ")
98		if err != nil {
99			out = []byte(fmt.Sprintf(`{"message": "internal marshal error: %v"}`, err))
100		}
101	} else {
102		innerError := err
103		if terr, ok := err.(Error); ok {
104			innerError = terr.OrigError()
105		}
106		out, err = json.Marshal(message{Message: innerError.Error()})
107	}
108	w.Write(out)
109}
110
// message is the JSON envelope replyJSON uses for errors
// that are sent as a plain message string
type message struct {
	// Message is the error text, encoded under the "message" key
	Message string `json:"message"`
}
114
115func unmarshalError(err error, responseBody []byte) error {
116	if len(responseBody) == 0 {
117		return err
118	}
119	var raw RawTrace
120	if err2 := json.Unmarshal(responseBody, &raw); err2 != nil {
121		return err
122	}
123	if len(raw.Traces) != 0 && len(raw.Err) != 0 {
124		// try to capture traces, if there are any
125		err2 := json.Unmarshal(raw.Err, err)
126		if err2 != nil {
127			return err
128		}
129		return &TraceErr{Traces: raw.Traces, Err: err, Message: raw.Message}
130	}
131	json.Unmarshal(responseBody, err)
132	return err
133}
134