1// Copyright 2018, OpenCensus Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package ochttp
16
17import (
18	"io"
19	"net/http"
20	"net/http/httptrace"
21
22	"go.opencensus.io/plugin/ochttp/propagation/b3"
23	"go.opencensus.io/trace"
24	"go.opencensus.io/trace/propagation"
25)
26
27// TODO(jbd): Add godoc examples.
28
29var defaultFormat propagation.HTTPFormat = &b3.HTTPFormat{}
30
// Attribute keys recorded on the span for each request.
// Only trace exporters need them.
const (
	// HostAttribute is the key under which the request host is recorded.
	HostAttribute = "http.host"
	// MethodAttribute is the key under which the HTTP method is recorded.
	MethodAttribute = "http.method"
	// PathAttribute is the key under which the URL path is recorded.
	PathAttribute = "http.path"
	// StatusCodeAttribute is the key under which the numeric response
	// status code is recorded.
	StatusCodeAttribute = "http.status_code"
	// UserAgentAttribute is the key under which the User-Agent is recorded.
	UserAgentAttribute = "http.user_agent"
)
40
41type traceTransport struct {
42	base           http.RoundTripper
43	startOptions   trace.StartOptions
44	format         propagation.HTTPFormat
45	formatSpanName func(*http.Request) string
46	newClientTrace func(*http.Request, *trace.Span) *httptrace.ClientTrace
47}
48
49// TODO(jbd): Add message events for request and response size.
50
51// RoundTrip creates a trace.Span and inserts it into the outgoing request's headers.
52// The created span can follow a parent span, if a parent is presented in
53// the request's context.
54func (t *traceTransport) RoundTrip(req *http.Request) (*http.Response, error) {
55	name := t.formatSpanName(req)
56	// TODO(jbd): Discuss whether we want to prefix
57	// outgoing requests with Sent.
58	ctx, span := trace.StartSpan(req.Context(), name,
59		trace.WithSampler(t.startOptions.Sampler),
60		trace.WithSpanKind(trace.SpanKindClient))
61
62	if t.newClientTrace != nil {
63		req = req.WithContext(httptrace.WithClientTrace(ctx, t.newClientTrace(req, span)))
64	} else {
65		req = req.WithContext(ctx)
66	}
67
68	if t.format != nil {
69		// SpanContextToRequest will modify its Request argument, which is
70		// contrary to the contract for http.RoundTripper, so we need to
71		// pass it a copy of the Request.
72		// However, the Request struct itself was already copied by
73		// the WithContext calls above and so we just need to copy the header.
74		header := make(http.Header)
75		for k, v := range req.Header {
76			header[k] = v
77		}
78		req.Header = header
79		t.format.SpanContextToRequest(span.SpanContext(), req)
80	}
81
82	span.AddAttributes(requestAttrs(req)...)
83	resp, err := t.base.RoundTrip(req)
84	if err != nil {
85		span.SetStatus(trace.Status{Code: trace.StatusCodeUnknown, Message: err.Error()})
86		span.End()
87		return resp, err
88	}
89
90	span.AddAttributes(responseAttrs(resp)...)
91	span.SetStatus(TraceStatus(resp.StatusCode, resp.Status))
92
93	// span.End() will be invoked after
94	// a read from resp.Body returns io.EOF or when
95	// resp.Body.Close() is invoked.
96	resp.Body = &bodyTracker{rc: resp.Body, span: span}
97	return resp, err
98}
99
100// bodyTracker wraps a response.Body and invokes
101// trace.EndSpan on encountering io.EOF on reading
102// the body of the original response.
103type bodyTracker struct {
104	rc   io.ReadCloser
105	span *trace.Span
106}
107
108var _ io.ReadCloser = (*bodyTracker)(nil)
109
110func (bt *bodyTracker) Read(b []byte) (int, error) {
111	n, err := bt.rc.Read(b)
112
113	switch err {
114	case nil:
115		return n, nil
116	case io.EOF:
117		bt.span.End()
118	default:
119		// For all other errors, set the span status
120		bt.span.SetStatus(trace.Status{
121			// Code 2 is the error code for Internal server error.
122			Code:    2,
123			Message: err.Error(),
124		})
125	}
126	return n, err
127}
128
129func (bt *bodyTracker) Close() error {
130	// Invoking endSpan on Close will help catch the cases
131	// in which a read returned a non-nil error, we set the
132	// span status but didn't end the span.
133	bt.span.End()
134	return bt.rc.Close()
135}
136
137// CancelRequest cancels an in-flight request by closing its connection.
138func (t *traceTransport) CancelRequest(req *http.Request) {
139	type canceler interface {
140		CancelRequest(*http.Request)
141	}
142	if cr, ok := t.base.(canceler); ok {
143		cr.CancelRequest(req)
144	}
145}
146
// spanNameFromURL names a span after the path of the request's URL.
func spanNameFromURL(req *http.Request) string {
	u := req.URL
	return u.Path
}
150
151func requestAttrs(r *http.Request) []trace.Attribute {
152	return []trace.Attribute{
153		trace.StringAttribute(PathAttribute, r.URL.Path),
154		trace.StringAttribute(HostAttribute, r.URL.Host),
155		trace.StringAttribute(MethodAttribute, r.Method),
156		trace.StringAttribute(UserAgentAttribute, r.UserAgent()),
157	}
158}
159
160func responseAttrs(resp *http.Response) []trace.Attribute {
161	return []trace.Attribute{
162		trace.Int64Attribute(StatusCodeAttribute, int64(resp.StatusCode)),
163	}
164}
165
166// TraceStatus is a utility to convert the HTTP status code to a trace.Status that
167// represents the outcome as closely as possible.
168func TraceStatus(httpStatusCode int, statusLine string) trace.Status {
169	var code int32
170	if httpStatusCode < 200 || httpStatusCode >= 400 {
171		code = trace.StatusCodeUnknown
172	}
173	switch httpStatusCode {
174	case 499:
175		code = trace.StatusCodeCancelled
176	case http.StatusBadRequest:
177		code = trace.StatusCodeInvalidArgument
178	case http.StatusGatewayTimeout:
179		code = trace.StatusCodeDeadlineExceeded
180	case http.StatusNotFound:
181		code = trace.StatusCodeNotFound
182	case http.StatusForbidden:
183		code = trace.StatusCodePermissionDenied
184	case http.StatusUnauthorized: // 401 is actually unauthenticated.
185		code = trace.StatusCodeUnauthenticated
186	case http.StatusTooManyRequests:
187		code = trace.StatusCodeResourceExhausted
188	case http.StatusNotImplemented:
189		code = trace.StatusCodeUnimplemented
190	case http.StatusServiceUnavailable:
191		code = trace.StatusCodeUnavailable
192	case http.StatusOK:
193		code = trace.StatusCodeOK
194	}
195	return trace.Status{Code: code, Message: codeToStr[code]}
196}
197
198var codeToStr = map[int32]string{
199	trace.StatusCodeOK:                 `OK`,
200	trace.StatusCodeCancelled:          `CANCELLED`,
201	trace.StatusCodeUnknown:            `UNKNOWN`,
202	trace.StatusCodeInvalidArgument:    `INVALID_ARGUMENT`,
203	trace.StatusCodeDeadlineExceeded:   `DEADLINE_EXCEEDED`,
204	trace.StatusCodeNotFound:           `NOT_FOUND`,
205	trace.StatusCodeAlreadyExists:      `ALREADY_EXISTS`,
206	trace.StatusCodePermissionDenied:   `PERMISSION_DENIED`,
207	trace.StatusCodeResourceExhausted:  `RESOURCE_EXHAUSTED`,
208	trace.StatusCodeFailedPrecondition: `FAILED_PRECONDITION`,
209	trace.StatusCodeAborted:            `ABORTED`,
210	trace.StatusCodeOutOfRange:         `OUT_OF_RANGE`,
211	trace.StatusCodeUnimplemented:      `UNIMPLEMENTED`,
212	trace.StatusCodeInternal:           `INTERNAL`,
213	trace.StatusCodeUnavailable:        `UNAVAILABLE`,
214	trace.StatusCodeDataLoss:           `DATA_LOSS`,
215	trace.StatusCodeUnauthenticated:    `UNAUTHENTICATED`,
216}
217
// isHealthEndpoint reports whether path is exactly one of the canonical
// health-check endpoints, "/healthz" or "/_ah/health".
//
// Health checks are frequent, and tracing them is noisy and expensive, so
// these paths are singled out.
func isHealthEndpoint(path string) bool {
	switch path {
	case "/healthz", "/_ah/health":
		return true
	default:
		return false
	}
}
229