1// Copyright 2018, OpenCensus Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package ochttp
16
17import (
18	"context"
19	"io"
20	"net/http"
21	"strconv"
22	"sync"
23	"time"
24
25	"go.opencensus.io/stats"
26	"go.opencensus.io/tag"
27)
28
29// statsTransport is an http.RoundTripper that collects stats for the outgoing requests.
30type statsTransport struct {
31	base http.RoundTripper
32}
33
34// RoundTrip implements http.RoundTripper, delegating to Base and recording stats for the request.
35func (t statsTransport) RoundTrip(req *http.Request) (*http.Response, error) {
36	ctx, _ := tag.New(req.Context(),
37		tag.Upsert(KeyClientHost, req.Host),
38		tag.Upsert(Host, req.Host),
39		tag.Upsert(KeyClientPath, req.URL.Path),
40		tag.Upsert(Path, req.URL.Path),
41		tag.Upsert(KeyClientMethod, req.Method),
42		tag.Upsert(Method, req.Method))
43	req = req.WithContext(ctx)
44	track := &tracker{
45		start: time.Now(),
46		ctx:   ctx,
47	}
48	if req.Body == nil {
49		// TODO: Handle cases where ContentLength is not set.
50		track.reqSize = -1
51	} else if req.ContentLength > 0 {
52		track.reqSize = req.ContentLength
53	}
54	stats.Record(ctx, ClientRequestCount.M(1))
55
56	// Perform request.
57	resp, err := t.base.RoundTrip(req)
58
59	if err != nil {
60		track.statusCode = http.StatusInternalServerError
61		track.end()
62	} else {
63		track.statusCode = resp.StatusCode
64		if req.Method != "HEAD" {
65			track.respContentLength = resp.ContentLength
66		}
67		if resp.Body == nil {
68			track.end()
69		} else {
70			track.body = resp.Body
71			resp.Body = wrappedBody(track, resp.Body)
72		}
73	}
74	return resp, err
75}
76
77// CancelRequest cancels an in-flight request by closing its connection.
78func (t statsTransport) CancelRequest(req *http.Request) {
79	type canceler interface {
80		CancelRequest(*http.Request)
81	}
82	if cr, ok := t.base.(canceler); ok {
83		cr.CancelRequest(req)
84	}
85}
86
87type tracker struct {
88	ctx               context.Context
89	respSize          int64
90	respContentLength int64
91	reqSize           int64
92	start             time.Time
93	body              io.ReadCloser
94	statusCode        int
95	endOnce           sync.Once
96}
97
98var _ io.ReadCloser = (*tracker)(nil)
99
100func (t *tracker) end() {
101	t.endOnce.Do(func() {
102		latencyMs := float64(time.Since(t.start)) / float64(time.Millisecond)
103		respSize := t.respSize
104		if t.respSize == 0 && t.respContentLength > 0 {
105			respSize = t.respContentLength
106		}
107		m := []stats.Measurement{
108			ClientSentBytes.M(t.reqSize),
109			ClientReceivedBytes.M(respSize),
110			ClientRoundtripLatency.M(latencyMs),
111			ClientLatency.M(latencyMs),
112			ClientResponseBytes.M(t.respSize),
113		}
114		if t.reqSize >= 0 {
115			m = append(m, ClientRequestBytes.M(t.reqSize))
116		}
117
118		stats.RecordWithTags(t.ctx, []tag.Mutator{
119			tag.Upsert(StatusCode, strconv.Itoa(t.statusCode)),
120			tag.Upsert(KeyClientStatus, strconv.Itoa(t.statusCode)),
121		}, m...)
122	})
123}
124
125func (t *tracker) Read(b []byte) (int, error) {
126	n, err := t.body.Read(b)
127	t.respSize += int64(n)
128	switch err {
129	case nil:
130		return n, nil
131	case io.EOF:
132		t.end()
133	}
134	return n, err
135}
136
137func (t *tracker) Close() error {
138	// Invoking endSpan on Close will help catch the cases
139	// in which a read returned a non-nil error, we set the
140	// span status but didn't end the span.
141	t.end()
142	return t.body.Close()
143}
144