1// Copyright 2018, OpenCensus Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package ochttp
16
17import (
18	"io"
19	"net/http"
20	"net/http/httptrace"
21
22	"go.opencensus.io/plugin/ochttp/propagation/b3"
23	"go.opencensus.io/trace"
24	"go.opencensus.io/trace/propagation"
25)
26
27// TODO(jbd): Add godoc examples.
28
29var defaultFormat propagation.HTTPFormat = &b3.HTTPFormat{}
30
// Span attribute keys recorded by this package for HTTP requests and
// responses. They are only of interest to trace exporters.
const (
	// MethodAttribute is the key for the request method.
	MethodAttribute = "http.method"
	// HostAttribute is the key for the request's Host.
	HostAttribute = "http.host"
	// PathAttribute is the key for the path component of the request URL.
	PathAttribute = "http.path"
	// URLAttribute is the key for the full request URL.
	URLAttribute = "http.url"
	// UserAgentAttribute is the key for the request's User-Agent header.
	UserAgentAttribute = "http.user_agent"
	// StatusCodeAttribute is the key for the numeric HTTP response status code.
	StatusCodeAttribute = "http.status_code"
)
41
// traceTransport is an http.RoundTripper that starts a client span for each
// request, optionally propagates the span context in the request headers,
// and delegates the actual round trip to base.
type traceTransport struct {
	// base performs the underlying HTTP round trip.
	base           http.RoundTripper
	// startOptions supplies the Sampler for new spans; RoundTrip does not
	// consult its other fields (the span kind is always client).
	startOptions   trace.StartOptions
	// format, when non-nil, injects the span context into outgoing headers.
	format         propagation.HTTPFormat
	// formatSpanName derives the span name from the request. RoundTrip calls
	// it unconditionally, so it must be non-nil.
	formatSpanName func(*http.Request) string
	// newClientTrace, when non-nil, builds an httptrace.ClientTrace that is
	// attached to the outgoing request's context.
	newClientTrace func(*http.Request, *trace.Span) *httptrace.ClientTrace
}
49
50// TODO(jbd): Add message events for request and response size.
51
// RoundTrip creates a trace.Span for req, propagates it in the outgoing
// request's headers when a propagation format is configured, and delegates
// the exchange to the base transport. The new span follows any parent span
// present in req's context.
//
// On a transport error the span is ended immediately with StatusCodeUnknown.
// Otherwise the span's status is derived from the HTTP status code, and the
// span is ended later: once a read from the response body returns io.EOF or
// the body is closed.
func (t *traceTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	// TODO(jbd): Discuss whether we want to prefix
	// outgoing requests with Sent.
	spanName := t.formatSpanName(req)
	ctx, span := trace.StartSpan(
		req.Context(),
		spanName,
		trace.WithSampler(t.startOptions.Sampler),
		trace.WithSpanKind(trace.SpanKindClient),
	)

	// Attach the optional httptrace callbacks to ctx first, so that a single
	// WithContext call carries both the span and the client trace.
	if t.newClientTrace != nil {
		ctx = httptrace.WithClientTrace(ctx, t.newClientTrace(req, span))
	}
	req = req.WithContext(ctx)

	if t.format != nil {
		// An http.RoundTripper must not modify the caller's Request, yet
		// SpanContextToRequest writes headers. WithContext already produced a
		// shallow copy of the Request, so only the Header map still aliases
		// the caller's and needs to be copied.
		hdr := make(http.Header, len(req.Header))
		for name, values := range req.Header {
			hdr[name] = values
		}
		req.Header = hdr
		t.format.SpanContextToRequest(span.SpanContext(), req)
	}

	span.AddAttributes(requestAttrs(req)...)

	resp, err := t.base.RoundTrip(req)
	if err != nil {
		span.SetStatus(trace.Status{Code: trace.StatusCodeUnknown, Message: err.Error()})
		span.End()
		return resp, err
	}

	span.AddAttributes(responseAttrs(resp)...)
	span.SetStatus(TraceStatus(resp.StatusCode, resp.Status))

	// From here on the body wrapper owns the span: span.End() is invoked once
	// a read from resp.Body returns io.EOF or resp.Body.Close() is called.
	tracker := &bodyTracker{rc: resp.Body, span: span}
	resp.Body = wrappedBody(tracker, resp.Body)
	return resp, nil
}
101
// bodyTracker wraps an http.Response.Body and ends the associated span when
// a read from the original body returns io.EOF, or when the body is closed.
// This keeps the span open while the caller is still streaming the response.
type bodyTracker struct {
	// rc is the original response body; reads and Close are forwarded to it.
	rc   io.ReadCloser
	// span is the client span started in traceTransport.RoundTrip. It is
	// ended by Read (on io.EOF) or by Close.
	span *trace.Span
}

// Compile-time check that *bodyTracker implements io.ReadCloser.
var _ io.ReadCloser = (*bodyTracker)(nil)
111
// Read forwards to the wrapped body and returns its results unchanged.
//
// When the underlying reader reports io.EOF, the body has been fully consumed
// and the span is ended. Any other non-nil error is recorded on the span as
// trace.StatusCodeUnknown, the same code RoundTrip uses for transport errors.
// In that case the span is left open, and Close is expected to end it.
func (bt *bodyTracker) Read(b []byte) (int, error) {
	n, err := bt.rc.Read(b)

	switch err {
	case nil:
		return n, nil
	case io.EOF:
		bt.span.End()
	default:
		// A failure while streaming the body. Use the named constant rather
		// than a magic number: 2 is StatusCodeUnknown, not "internal error".
		bt.span.SetStatus(trace.Status{
			Code:    trace.StatusCodeUnknown,
			Message: err.Error(),
		})
	}
	return n, err
}
130
131func (bt *bodyTracker) Close() error {
132	// Invoking endSpan on Close will help catch the cases
133	// in which a read returned a non-nil error, we set the
134	// span status but didn't end the span.
135	bt.span.End()
136	return bt.rc.Close()
137}
138
139// CancelRequest cancels an in-flight request by closing its connection.
140func (t *traceTransport) CancelRequest(req *http.Request) {
141	type canceler interface {
142		CancelRequest(*http.Request)
143	}
144	if cr, ok := t.base.(canceler); ok {
145		cr.CancelRequest(req)
146	}
147}
148
// spanNameFromURL names a span after the path component of req.URL. The
// scheme, host and query string are not included.
func spanNameFromURL(req *http.Request) string {
	u := req.URL
	return u.Path
}
152
// requestAttrs returns the span attributes describing r. These are its URL
// path, full URL, Host and method, followed by the User-Agent header when
// one is present.
func requestAttrs(r *http.Request) []trace.Attribute {
	// Capacity covers the four unconditional attributes plus the optional
	// user agent.
	attrs := append(make([]trace.Attribute, 0, 5),
		trace.StringAttribute(PathAttribute, r.URL.Path),
		trace.StringAttribute(URLAttribute, r.URL.String()),
		trace.StringAttribute(HostAttribute, r.Host),
		trace.StringAttribute(MethodAttribute, r.Method),
	)
	if ua := r.UserAgent(); ua != "" {
		attrs = append(attrs, trace.StringAttribute(UserAgentAttribute, ua))
	}
	return attrs
}
170
// responseAttrs returns the span attributes describing resp. Currently this
// is only its numeric HTTP status code.
func responseAttrs(resp *http.Response) []trace.Attribute {
	statusCode := int64(resp.StatusCode)
	return []trace.Attribute{trace.Int64Attribute(StatusCodeAttribute, statusCode)}
}
176
// TraceStatus is a utility to convert the HTTP status code to a trace.Status that
// represents the outcome as closely as possible.
//
// Several well-known status codes map to a specific canonical code, for
// example 404 to NOT_FOUND, 503 to UNAVAILABLE, and the non-standard 499 to
// CANCELLED. Any other 2xx or 3xx code maps to OK, and everything else maps
// to UNKNOWN. The returned Message is the canonical name of the chosen code;
// statusLine is currently not used.
func TraceStatus(httpStatusCode int, statusLine string) trace.Status {
	var code int32 // zero value is trace.StatusCodeOK
	switch httpStatusCode {
	case http.StatusOK:
		code = trace.StatusCodeOK
	case 499: // non-standard "client closed request"
		code = trace.StatusCodeCancelled
	case http.StatusBadRequest, http.StatusUnprocessableEntity:
		code = trace.StatusCodeInvalidArgument
	case http.StatusUnauthorized: // 401 actually means unauthenticated.
		code = trace.StatusCodeUnauthenticated
	case http.StatusForbidden:
		code = trace.StatusCodePermissionDenied
	case http.StatusNotFound:
		code = trace.StatusCodeNotFound
	case http.StatusConflict:
		code = trace.StatusCodeAlreadyExists
	case http.StatusTooManyRequests:
		code = trace.StatusCodeResourceExhausted
	case http.StatusNotImplemented:
		code = trace.StatusCodeUnimplemented
	case http.StatusServiceUnavailable:
		code = trace.StatusCodeUnavailable
	case http.StatusGatewayTimeout:
		code = trace.StatusCodeDeadlineExceeded
	default:
		// Unlisted codes outside the 2xx/3xx range are UNKNOWN. Unlisted
		// 2xx/3xx codes keep the zero value (OK).
		if httpStatusCode < 200 || httpStatusCode >= 400 {
			code = trace.StatusCodeUnknown
		}
	}
	return trace.Status{Code: code, Message: codeToStr[code]}
}
213
214var codeToStr = map[int32]string{
215	trace.StatusCodeOK:                 `OK`,
216	trace.StatusCodeCancelled:          `CANCELLED`,
217	trace.StatusCodeUnknown:            `UNKNOWN`,
218	trace.StatusCodeInvalidArgument:    `INVALID_ARGUMENT`,
219	trace.StatusCodeDeadlineExceeded:   `DEADLINE_EXCEEDED`,
220	trace.StatusCodeNotFound:           `NOT_FOUND`,
221	trace.StatusCodeAlreadyExists:      `ALREADY_EXISTS`,
222	trace.StatusCodePermissionDenied:   `PERMISSION_DENIED`,
223	trace.StatusCodeResourceExhausted:  `RESOURCE_EXHAUSTED`,
224	trace.StatusCodeFailedPrecondition: `FAILED_PRECONDITION`,
225	trace.StatusCodeAborted:            `ABORTED`,
226	trace.StatusCodeOutOfRange:         `OUT_OF_RANGE`,
227	trace.StatusCodeUnimplemented:      `UNIMPLEMENTED`,
228	trace.StatusCodeInternal:           `INTERNAL`,
229	trace.StatusCodeUnavailable:        `UNAVAILABLE`,
230	trace.StatusCodeDataLoss:           `DATA_LOSS`,
231	trace.StatusCodeUnauthenticated:    `UNAUTHENTICATED`,
232}
233
// isHealthEndpoint reports whether path is exactly one of the canonical
// health-check endpoints ("/healthz" or "/_ah/health"). Health checks are
// frequent, and tracing them produces noisy, expensive traces, so callers
// can use this to skip them.
func isHealthEndpoint(path string) bool {
	switch path {
	case "/healthz", "/_ah/health":
		return true
	default:
		return false
	}
}
245