1// Copyright 2016 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// +build go1.7
6
7package http2
8
9import (
10	"context"
11	"net"
12	"net/http"
13	"net/http/httptrace"
14	"time"
15)
16
// contextContext is an interface whose method set is exactly that of
// context.Context. Any context.Context value satisfies it, and any
// contextContext value may be passed where a context.Context is expected.
// This file is only compiled for Go 1.7+ (see the +build constraint);
// presumably the indirection lets a pre-1.7 file supply its own definition —
// confirm against the sibling build-tagged file.
type contextContext interface{ context.Context }
20
21func serverConnBaseContext(c net.Conn, opts *ServeConnOpts) (ctx contextContext, cancel func()) {
22	ctx, cancel = context.WithCancel(context.Background())
23	ctx = context.WithValue(ctx, http.LocalAddrContextKey, c.LocalAddr())
24	if hs := opts.baseConfig(); hs != nil {
25		ctx = context.WithValue(ctx, http.ServerContextKey, hs)
26	}
27	return
28}
29
30func contextWithCancel(ctx contextContext) (_ contextContext, cancel func()) {
31	return context.WithCancel(ctx)
32}
33
34func requestWithContext(req *http.Request, ctx contextContext) *http.Request {
35	return req.WithContext(ctx)
36}
37
type (
	// clientTrace is a package-local defined type with the same underlying
	// struct as httptrace.ClientTrace, so *httptrace.ClientTrace and
	// *clientTrace convert to each other freely.
	clientTrace httptrace.ClientTrace
)
39
// reqContext returns the context attached to r, as reported by r.Context().
func reqContext(r *http.Request) context.Context {
	ctx := r.Context()
	return ctx
}
41
42func (t *Transport) idleConnTimeout() time.Duration {
43	if t.t1 != nil {
44		return t.t1.IdleConnTimeout
45	}
46	return 0
47}
48
// setResponseUncompressed sets res.Uncompressed. Per the net/http docs, that
// flag reports that the body was sent compressed but has already been
// decompressed for the caller.
func setResponseUncompressed(res *http.Response) {
	res.Uncompressed = true
}
50
51func traceGotConn(req *http.Request, cc *ClientConn) {
52	trace := httptrace.ContextClientTrace(req.Context())
53	if trace == nil || trace.GotConn == nil {
54		return
55	}
56	ci := httptrace.GotConnInfo{Conn: cc.tconn}
57	cc.mu.Lock()
58	ci.Reused = cc.nextStreamID > 1
59	ci.WasIdle = len(cc.streams) == 0 && ci.Reused
60	if ci.WasIdle && !cc.lastActive.IsZero() {
61		ci.IdleTime = time.Now().Sub(cc.lastActive)
62	}
63	cc.mu.Unlock()
64
65	trace.GotConn(ci)
66}
67
68func traceWroteHeaders(trace *clientTrace) {
69	if trace != nil && trace.WroteHeaders != nil {
70		trace.WroteHeaders()
71	}
72}
73
74func traceGot100Continue(trace *clientTrace) {
75	if trace != nil && trace.Got100Continue != nil {
76		trace.Got100Continue()
77	}
78}
79
80func traceWait100Continue(trace *clientTrace) {
81	if trace != nil && trace.Wait100Continue != nil {
82		trace.Wait100Continue()
83	}
84}
85
86func traceWroteRequest(trace *clientTrace, err error) {
87	if trace != nil && trace.WroteRequest != nil {
88		trace.WroteRequest(httptrace.WroteRequestInfo{Err: err})
89	}
90}
91
92func traceFirstResponseByte(trace *clientTrace) {
93	if trace != nil && trace.GotFirstResponseByte != nil {
94		trace.GotFirstResponseByte()
95	}
96}
97
98func requestTrace(req *http.Request) *clientTrace {
99	trace := httptrace.ContextClientTrace(req.Context())
100	return (*clientTrace)(trace)
101}
102
103// Ping sends a PING frame to the server and waits for the ack.
104func (cc *ClientConn) Ping(ctx context.Context) error {
105	return cc.ping(ctx)
106}
107