1package main
2
3import (
4	"bytes"
5	"context"
6	"crypto/tls"
7	"fmt"
8	"io"
9	"io/ioutil"
10	"net/http/httptrace"
11	"strings"
12	"time"
13
14	"github.com/aws/aws-sdk-go/aws/request"
15)
16
// RequestLatency provides latencies for the SDK API request and its attempts,
// as computed by RequestTrace.Latency. A duration is zero when either of the
// time stamps it is derived from was not recorded, or the end precedes the
// start.
type RequestLatency struct {
	// Latency spans the whole request, from RequestTrace.Start to Finish.
	// Validate and Build are the durations of the Validate and Build handler
	// phases.
	Latency  time.Duration
	Validate time.Duration
	Build    time.Duration

	// Attempts has one entry per recorded attempt, in attempt order.
	Attempts []RequestAttemptLatency
}
25
// RequestAttemptLatency provides latencies for an individual request attempt,
// as computed by RequestAttemptTrace.Latency. A duration is zero when either
// of the time stamps it is derived from was not recorded, or the end precedes
// the start.
type RequestAttemptLatency struct {
	// Latency spans the attempt, from the start of its Sign handlers to its
	// CompleteAttempt handlers. Err is the request error captured when the
	// Retry handlers ran for this attempt.
	Latency time.Duration
	Err     error

	// Sign and Send are the durations of the Sign and Send handler phases.
	Sign time.Duration
	Send time.Duration

	// HTTP breaks down the HTTP round trip made during the Send phase.
	HTTP HTTPLatency

	// Unmarshal and UnmarshalError are the durations of the Unmarshal and
	// UnmarshalError handler phases.
	Unmarshal      time.Duration
	UnmarshalError time.Duration

	// WillRetry is the request's retry decision as read after the AfterRetry
	// handlers; Retry spans the Retry and AfterRetry handlers.
	WillRetry bool
	Retry     time.Duration
}
42
// HTTPLatency provides latencies for an HTTP request, as computed by
// HTTPTrace.Latency. A duration is zero when either of the time stamps it is
// derived from was not recorded, or the end precedes the start.
type HTTPLatency struct {
	// Latency is the overall HTTP trace time (HTTPTrace Start to Finish).
	// ConnReused mirrors httptrace.GotConnInfo.Reused for the connection used.
	Latency    time.Duration
	ConnReused bool

	// GetConn is the time to obtain a connection. For a new connection it is
	// measured from the GetConn hook; for a reused one, from HTTPTrace.Start.
	GetConn time.Duration

	// DNS, Connect, and TLS are the connection setup phases. They are only
	// populated when the connection was not reused.
	DNS     time.Duration
	Connect time.Duration
	TLS     time.Duration

	// WriteHeader and WriteRequest are both measured from when the connection
	// was obtained, so WriteRequest includes the header write.
	// WaitResponseFirstByte runs from the end of the request write to the
	// first response byte, and ReadHeader from that byte until the tracing
	// Send handler ran. ReadBody is only populated when
	// RequestTrace.ReadResponseBody is set.
	WriteHeader           time.Duration
	WriteRequest          time.Duration
	WaitResponseFirstByte time.Duration
	ReadHeader            time.Duration
	ReadBody              time.Duration
}
60
// RequestTrace provides the structure to store SDK request attempt latencies.
// Use TraceRequest as an API operation request option to capture trace metrics
// for the individual request.
type RequestTrace struct {
	// Start is set when TraceRequest is applied; Finish when the request's
	// Complete handlers run.
	Start, Finish time.Time

	// These bracket the Validate and Build handler phases.
	ValidateStart, ValidateDone time.Time
	BuildStart, BuildDone       time.Time

	// ReadResponseBody makes the traced Send phase read the whole response
	// body into memory so the body read time can be measured; the body is then
	// replaced with the in-memory copy. It is checked at send time, so set it
	// before the request is sent.
	ReadResponseBody bool

	// Attempts collects one trace per attempt, appended when each attempt's
	// CompleteAttempt handlers run.
	Attempts []*RequestAttemptTrace
}
74
75// Latency returns the latencies of the request trace components.
76func (t RequestTrace) Latency() RequestLatency {
77	var attempts []RequestAttemptLatency
78	for _, a := range t.Attempts {
79		attempts = append(attempts, a.Latency())
80	}
81
82	latency := RequestLatency{
83		Latency:  safeTimeDelta(t.Start, t.Finish),
84		Validate: safeTimeDelta(t.ValidateStart, t.ValidateDone),
85		Build:    safeTimeDelta(t.BuildStart, t.BuildDone),
86		Attempts: attempts,
87	}
88
89	return latency
90}
91
92// TraceRequest is a SDK request Option that will add request handlers to an
93// individual request to track request latencies per attempt. Must be used only
94// for a single request call per RequestTrace value.
95func (t *RequestTrace) TraceRequest(r *request.Request) {
96	t.Start = time.Now()
97	r.Handlers.Complete.PushBack(t.onComplete)
98
99	r.Handlers.Validate.PushFront(t.onValidateStart)
100	r.Handlers.Validate.PushBack(t.onValidateDone)
101
102	r.Handlers.Build.PushFront(t.onBuildStart)
103	r.Handlers.Build.PushBack(t.onBuildDone)
104
105	var attempt *RequestAttemptTrace
106
107	// Signing and Start attempt
108	r.Handlers.Sign.PushFront(func(rr *request.Request) {
109		attempt = &RequestAttemptTrace{Start: time.Now()}
110		attempt.SignStart = attempt.Start
111	})
112	r.Handlers.Sign.PushBack(func(rr *request.Request) {
113		attempt.SignDone = time.Now()
114	})
115
116	// Ensure that the http trace added to the request always uses the original
117	// context instead of each following attempt's context to prevent conflict
118	// with previous http traces used.
119	origContext := r.Context()
120
121	// Send
122	r.Handlers.Send.PushFront(func(rr *request.Request) {
123		attempt.SendStart = time.Now()
124		attempt.HTTPTrace = NewHTTPTrace(origContext)
125		rr.SetContext(attempt.HTTPTrace)
126	})
127	r.Handlers.Send.PushBack(func(rr *request.Request) {
128		attempt.SendDone = time.Now()
129		defer func() {
130			attempt.HTTPTrace.Finish = time.Now()
131		}()
132
133		if rr.Error != nil {
134			return
135		}
136
137		attempt.HTTPTrace.ReadHeaderDone = time.Now()
138		if t.ReadResponseBody {
139			attempt.HTTPTrace.ReadBodyStart = time.Now()
140			var w bytes.Buffer
141			if _, err := io.Copy(&w, rr.HTTPResponse.Body); err != nil {
142				rr.Error = err
143				return
144			}
145			rr.HTTPResponse.Body.Close()
146			rr.HTTPResponse.Body = ioutil.NopCloser(&w)
147
148			attempt.HTTPTrace.ReadBodyDone = time.Now()
149		}
150	})
151
152	// Unmarshal
153	r.Handlers.Unmarshal.PushFront(func(rr *request.Request) {
154		attempt.UnmarshalStart = time.Now()
155	})
156	r.Handlers.Unmarshal.PushBack(func(rr *request.Request) {
157		attempt.UnmarshalDone = time.Now()
158	})
159
160	// Unmarshal Error
161	r.Handlers.UnmarshalError.PushFront(func(rr *request.Request) {
162		attempt.UnmarshalErrorStart = time.Now()
163	})
164	r.Handlers.UnmarshalError.PushBack(func(rr *request.Request) {
165		attempt.UnmarshalErrorDone = time.Now()
166	})
167
168	// Retry handling and delay
169	r.Handlers.Retry.PushFront(func(rr *request.Request) {
170		attempt.RetryStart = time.Now()
171		attempt.Err = rr.Error
172	})
173	r.Handlers.AfterRetry.PushBack(func(rr *request.Request) {
174		attempt.RetryDone = time.Now()
175		attempt.WillRetry = rr.WillRetry()
176	})
177
178	// Complete Attempt
179	r.Handlers.CompleteAttempt.PushBack(func(rr *request.Request) {
180		attempt.Finish = time.Now()
181		t.Attempts = append(t.Attempts, attempt)
182	})
183}
184
185func (t *RequestTrace) String() string {
186	var w strings.Builder
187
188	l := t.Latency()
189	writeDurField(&w, "Latency", l.Latency)
190	writeDurField(&w, "Validate", l.Validate)
191	writeDurField(&w, "Build", l.Build)
192	writeField(&w, "Attempts", "%d", len(t.Attempts))
193
194	for i, a := range t.Attempts {
195		fmt.Fprintf(&w, "\n\tAttempt: %d, %s", i, a)
196	}
197
198	return w.String()
199}
200
201func (t *RequestTrace) onComplete(*request.Request) {
202	t.Finish = time.Now()
203}
204func (t *RequestTrace) onValidateStart(*request.Request) { t.ValidateStart = time.Now() }
205func (t *RequestTrace) onValidateDone(*request.Request)  { t.ValidateDone = time.Now() }
206func (t *RequestTrace) onBuildStart(*request.Request)    { t.BuildStart = time.Now() }
207func (t *RequestTrace) onBuildDone(*request.Request)     { t.BuildDone = time.Now() }
208
// RequestAttemptTrace provides a structure for storing trace information on
// the SDK's request attempt. Its fields are filled in by the handlers that
// RequestTrace.TraceRequest registers.
type RequestAttemptTrace struct {
	// Start is when the attempt's Sign handlers began and Finish when its
	// CompleteAttempt handlers ran. Err is the request error captured by the
	// Retry handlers.
	Start, Finish time.Time
	Err           error

	// SignStart is the same instant as Start.
	SignStart, SignDone time.Time

	// HTTPTrace is created at the start of the Send phase and is nil before
	// that.
	SendStart, SendDone time.Time
	HTTPTrace           *HTTPTrace

	UnmarshalStart, UnmarshalDone           time.Time
	UnmarshalErrorStart, UnmarshalErrorDone time.Time

	// WillRetry is the request's retry decision as read after the AfterRetry
	// handlers; RetryStart/RetryDone bracket the Retry and AfterRetry
	// handlers.
	WillRetry             bool
	RetryStart, RetryDone time.Time
}
226
227// Latency returns the latencies of the request attempt trace components.
228func (t *RequestAttemptTrace) Latency() RequestAttemptLatency {
229	return RequestAttemptLatency{
230		Latency: safeTimeDelta(t.Start, t.Finish),
231		Err:     t.Err,
232
233		Sign: safeTimeDelta(t.SignStart, t.SignDone),
234		Send: safeTimeDelta(t.SendStart, t.SendDone),
235
236		HTTP: t.HTTPTrace.Latency(),
237
238		Unmarshal:      safeTimeDelta(t.UnmarshalStart, t.UnmarshalDone),
239		UnmarshalError: safeTimeDelta(t.UnmarshalErrorStart, t.UnmarshalErrorDone),
240
241		WillRetry: t.WillRetry,
242		Retry:     safeTimeDelta(t.RetryStart, t.RetryDone),
243	}
244}
245
246func (t *RequestAttemptTrace) String() string {
247	var w strings.Builder
248
249	l := t.Latency()
250	writeDurField(&w, "Latency", l.Latency)
251	writeDurField(&w, "Sign", l.Sign)
252	writeDurField(&w, "Send", l.Send)
253
254	writeDurField(&w, "Unmarshal", l.Unmarshal)
255	writeDurField(&w, "UnmarshalError", l.UnmarshalError)
256
257	writeField(&w, "WillRetry", "%t", l.WillRetry)
258	writeDurField(&w, "Retry", l.Retry)
259
260	fmt.Fprintf(&w, "\n\t\tHTTP: %s", t.HTTPTrace)
261	if t.Err != nil {
262		fmt.Fprintf(&w, "\n\t\tError: %v", t.Err)
263	}
264
265	return w.String()
266}
267
// HTTPTrace provides the trace time stamps of the HTTP request's segments.
// Create one with NewHTTPTrace. Its embedded Context carries the
// httptrace.ClientTrace hooks that record the time stamps, so the HTTP request
// must be sent with it as its context.
//
// NOTE(review): the hooks write these fields without any locking; confirm
// they are not read while the HTTP request is still in flight.
type HTTPTrace struct {
	context.Context

	// Start is when the trace was created; Finish is set by the tracing Send
	// handler once the Send phase is done.
	Start, Finish time.Time

	// Reused mirrors httptrace.GotConnInfo.Reused for the obtained connection.
	GetConnStart, GetConnDone time.Time
	Reused                    bool

	// Connection setup, request write, and first response byte time stamps,
	// each recorded by the corresponding ClientTrace hook.
	DNSStart, DNSDone                   time.Time
	ConnectStart, ConnectDone           time.Time
	TLSHandshakeStart, TLSHandshakeDone time.Time
	WriteHeaderDone                     time.Time
	WriteRequestDone                    time.Time
	FirstResponseByte                   time.Time

	// ReadHeaderStart equals FirstResponseByte. ReadHeaderDone and the
	// ReadBody pair are set by the tracing Send handler. The ReadBody pair is
	// only set when RequestTrace.ReadResponseBody is set.
	ReadHeaderStart, ReadHeaderDone time.Time
	ReadBodyStart, ReadBodyDone     time.Time
}
287
288// NewHTTPTrace returns a initialized HTTPTrace for an
289// httptrace.ClientTrace, based on the context passed.
290func NewHTTPTrace(ctx context.Context) *HTTPTrace {
291	t := &HTTPTrace{
292		Start: time.Now(),
293	}
294
295	trace := &httptrace.ClientTrace{
296		GetConn:              t.getConn,
297		GotConn:              t.gotConn,
298		PutIdleConn:          t.putIdleConn,
299		GotFirstResponseByte: t.gotFirstResponseByte,
300		Got100Continue:       t.got100Continue,
301		DNSStart:             t.dnsStart,
302		DNSDone:              t.dnsDone,
303		ConnectStart:         t.connectStart,
304		ConnectDone:          t.connectDone,
305		TLSHandshakeStart:    t.tlsHandshakeStart,
306		TLSHandshakeDone:     t.tlsHandshakeDone,
307		WroteHeaders:         t.wroteHeaders,
308		Wait100Continue:      t.wait100Continue,
309		WroteRequest:         t.wroteRequest,
310	}
311
312	t.Context = httptrace.WithClientTrace(ctx, trace)
313
314	return t
315}
316
317// Latency returns the latencies for an HTTP request.
318func (t *HTTPTrace) Latency() HTTPLatency {
319	latency := HTTPLatency{
320		Latency:    safeTimeDelta(t.Start, t.Finish),
321		ConnReused: t.Reused,
322
323		WriteHeader:           safeTimeDelta(t.GetConnDone, t.WriteHeaderDone),
324		WriteRequest:          safeTimeDelta(t.GetConnDone, t.WriteRequestDone),
325		WaitResponseFirstByte: safeTimeDelta(t.WriteRequestDone, t.FirstResponseByte),
326		ReadHeader:            safeTimeDelta(t.ReadHeaderStart, t.ReadHeaderDone),
327		ReadBody:              safeTimeDelta(t.ReadBodyStart, t.ReadBodyDone),
328	}
329
330	if !t.Reused {
331		latency.GetConn = safeTimeDelta(t.GetConnStart, t.GetConnDone)
332		latency.DNS = safeTimeDelta(t.DNSStart, t.DNSDone)
333		latency.Connect = safeTimeDelta(t.ConnectStart, t.ConnectDone)
334		latency.TLS = safeTimeDelta(t.TLSHandshakeStart, t.TLSHandshakeDone)
335	} else {
336		latency.GetConn = safeTimeDelta(t.Start, t.GetConnDone)
337	}
338
339	return latency
340}
341
342func (t *HTTPTrace) String() string {
343	var w strings.Builder
344
345	l := t.Latency()
346	writeDurField(&w, "Latency", l.Latency)
347	writeField(&w, "ConnReused", "%t", l.ConnReused)
348	writeDurField(&w, "GetConn", l.GetConn)
349
350	writeDurField(&w, "WriteHeader", l.WriteHeader)
351	writeDurField(&w, "WriteRequest", l.WriteRequest)
352	writeDurField(&w, "WaitResponseFirstByte", l.WaitResponseFirstByte)
353	writeDurField(&w, "ReadHeader", l.ReadHeader)
354	writeDurField(&w, "ReadBody", l.ReadBody)
355
356	if !t.Reused {
357		fmt.Fprintf(&w, "\n\t\t\tConn: ")
358		writeDurField(&w, "DNS", l.DNS)
359		writeDurField(&w, "Connect", l.Connect)
360		writeDurField(&w, "TLS", l.TLS)
361	}
362
363	return w.String()
364}
365
366func (t *HTTPTrace) getConn(hostPort string) {
367	t.GetConnStart = time.Now()
368}
369func (t *HTTPTrace) gotConn(info httptrace.GotConnInfo) {
370	t.GetConnDone = time.Now()
371	t.Reused = info.Reused
372}
373func (t *HTTPTrace) putIdleConn(err error) {}
374func (t *HTTPTrace) gotFirstResponseByte() {
375	t.FirstResponseByte = time.Now()
376	t.ReadHeaderStart = t.FirstResponseByte
377}
378func (t *HTTPTrace) got100Continue() {}
379func (t *HTTPTrace) dnsStart(info httptrace.DNSStartInfo) {
380	t.DNSStart = time.Now()
381}
382func (t *HTTPTrace) dnsDone(info httptrace.DNSDoneInfo) {
383	t.DNSDone = time.Now()
384}
385func (t *HTTPTrace) connectStart(network, addr string) {
386	t.ConnectStart = time.Now()
387}
388func (t *HTTPTrace) connectDone(network, addr string, err error) {
389	t.ConnectDone = time.Now()
390}
391func (t *HTTPTrace) tlsHandshakeStart() {
392	t.TLSHandshakeStart = time.Now()
393}
394func (t *HTTPTrace) tlsHandshakeDone(state tls.ConnectionState, err error) {
395	t.TLSHandshakeDone = time.Now()
396}
397func (t *HTTPTrace) wroteHeaders() {
398	t.WriteHeaderDone = time.Now()
399}
400func (t *HTTPTrace) wait100Continue() {}
401func (t *HTTPTrace) wroteRequest(info httptrace.WroteRequestInfo) {
402	t.WriteRequestDone = time.Now()
403}
404
// safeTimeDelta returns the duration from start to end. It returns zero when
// either time stamp is the zero time (never recorded), or when start is after
// end.
func safeTimeDelta(start, end time.Time) time.Duration {
	if !start.IsZero() && !end.IsZero() && !start.After(end) {
		return end.Sub(start)
	}
	return 0
}
412
// writeField writes "field: <value>, " to w, where value is args formatted
// with format. The field name is passed as the first formatting argument, so
// explicit argument indexes in format (e.g. %[2]v) count it as argument 1.
// It returns any error from the underlying write.
func writeField(w io.Writer, field string, format string, args ...interface{}) error {
	fmtArgs := make([]interface{}, 0, 1+len(args))
	fmtArgs = append(fmtArgs, field)
	fmtArgs = append(fmtArgs, args...)

	_, err := fmt.Fprintf(w, "%s: "+format+", ", fmtArgs...)
	return err
}
417
// writeDurField writes "field: <dur>, " to w. Zero durations are skipped
// entirely; safeTimeDelta reports zero for phases that were not recorded.
// It returns any error from the underlying write.
func writeDurField(w io.Writer, field string, dur time.Duration) error {
	if dur != 0 {
		_, err := io.WriteString(w, field+": "+dur.String()+", ")
		return err
	}
	return nil
}
426