1// Copyright 2015 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Transport code's client connection pooling.
6
7package http2
8
9import (
10	"crypto/tls"
11	"net/http"
12	"sync"
13)
14
15// ClientConnPool manages a pool of HTTP/2 client connections.
16type ClientConnPool interface {
17	GetClientConn(req *http.Request, addr string) (*ClientConn, error)
18	MarkDead(*ClientConn)
19}
20
// clientConnPoolIdleCloser is the interface implemented by ClientConnPool
// implementations which can close their idle connections.
//
// In this file both *clientConnPool and noDialClientConnPool are checked at
// compile time to satisfy it.
type clientConnPoolIdleCloser interface {
	ClientConnPool
	// closeIdleConnections closes pooled connections that are currently idle.
	closeIdleConnections()
}
27
28var (
29	_ clientConnPoolIdleCloser = (*clientConnPool)(nil)
30	_ clientConnPoolIdleCloser = noDialClientConnPool{}
31)
32
// clientConnPool is a ClientConnPool that caches *ClientConn values per
// "host:port" key and coalesces concurrent dials (dialing) and concurrent
// addConnIfNeeded calls (addConnCalls) for the same key.
//
// TODO: use singleflight for dialing and addConnCalls?
type clientConnPool struct {
	t *Transport // transport used to dial new connections

	mu sync.Mutex // guards the maps below; TODO: maybe switch to RWMutex
	// TODO: add support for sharing conns based on cert names
	// (e.g. share conn for googleapis.com and appspot.com)
	conns        map[string][]*ClientConn // key is host:port
	dialing      map[string]*dialCall     // currently in-flight dials
	keys         map[*ClientConn][]string // reverse index of conns; used by MarkDead
	addConnCalls map[string]*addConnCall  // in-flight addConnIfNeeded calls
}
45
46func (p *clientConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) {
47	return p.getClientConn(req, addr, dialOnMiss)
48}
49
// Values for getClientConn's dialOnMiss argument: whether a cache miss
// should start (or join) a dial, or fail immediately with ErrNoCachedConn.
const dialOnMiss, noDialOnMiss = true, false
54
55// shouldTraceGetConn reports whether getClientConn should call any
56// ClientTrace.GetConn hook associated with the http.Request.
57//
58// This complexity is needed to avoid double calls of the GetConn hook
59// during the back-and-forth between net/http and x/net/http2 (when the
60// net/http.Transport is upgraded to also speak http2), as well as support
61// the case where x/net/http2 is being used directly.
62func (p *clientConnPool) shouldTraceGetConn(st clientConnIdleState) bool {
63	// If our Transport wasn't made via ConfigureTransport, always
64	// trace the GetConn hook if provided, because that means the
65	// http2 package is being used directly and it's the one
66	// dialing, as opposed to net/http.
67	if _, ok := p.t.ConnPool.(noDialClientConnPool); !ok {
68		return true
69	}
70	// Otherwise, only use the GetConn hook if this connection has
71	// been used previously for other requests. For fresh
72	// connections, the net/http package does the dialing.
73	return !st.freshConn
74}
75
76func (p *clientConnPool) getClientConn(req *http.Request, addr string, dialOnMiss bool) (*ClientConn, error) {
77	if isConnectionCloseRequest(req) && dialOnMiss {
78		// It gets its own connection.
79		traceGetConn(req, addr)
80		const singleUse = true
81		cc, err := p.t.dialClientConn(addr, singleUse)
82		if err != nil {
83			return nil, err
84		}
85		return cc, nil
86	}
87	p.mu.Lock()
88	for _, cc := range p.conns[addr] {
89		if st := cc.idleState(); st.canTakeNewRequest {
90			if p.shouldTraceGetConn(st) {
91				traceGetConn(req, addr)
92			}
93			p.mu.Unlock()
94			return cc, nil
95		}
96	}
97	if !dialOnMiss {
98		p.mu.Unlock()
99		return nil, ErrNoCachedConn
100	}
101	traceGetConn(req, addr)
102	call := p.getStartDialLocked(addr)
103	p.mu.Unlock()
104	<-call.done
105	return call.res, call.err
106}
107
// dialCall is an in-flight Transport dial call to a host. Concurrent
// getClientConn callers for the same address share one dialCall (via
// clientConnPool.dialing) and all wait on done.
type dialCall struct {
	_    incomparable
	p    *clientConnPool // pool the resulting connection is added to
	done chan struct{}   // closed when done
	res  *ClientConn     // valid after done is closed
	err  error           // valid after done is closed
}
116
117// requires p.mu is held.
118func (p *clientConnPool) getStartDialLocked(addr string) *dialCall {
119	if call, ok := p.dialing[addr]; ok {
120		// A dial is already in-flight. Don't start another.
121		return call
122	}
123	call := &dialCall{p: p, done: make(chan struct{})}
124	if p.dialing == nil {
125		p.dialing = make(map[string]*dialCall)
126	}
127	p.dialing[addr] = call
128	go call.dial(addr)
129	return call
130}
131
132// run in its own goroutine.
133func (c *dialCall) dial(addr string) {
134	const singleUse = false // shared conn
135	c.res, c.err = c.p.t.dialClientConn(addr, singleUse)
136	close(c.done)
137
138	c.p.mu.Lock()
139	delete(c.p.dialing, addr)
140	if c.err == nil {
141		c.p.addConnLocked(addr, c.res)
142	}
143	c.p.mu.Unlock()
144}
145
146// addConnIfNeeded makes a NewClientConn out of c if a connection for key doesn't
147// already exist. It coalesces concurrent calls with the same key.
148// This is used by the http1 Transport code when it creates a new connection. Because
149// the http1 Transport doesn't de-dup TCP dials to outbound hosts (because it doesn't know
150// the protocol), it can get into a situation where it has multiple TLS connections.
151// This code decides which ones live or die.
152// The return value used is whether c was used.
153// c is never closed.
154func (p *clientConnPool) addConnIfNeeded(key string, t *Transport, c *tls.Conn) (used bool, err error) {
155	p.mu.Lock()
156	for _, cc := range p.conns[key] {
157		if cc.CanTakeNewRequest() {
158			p.mu.Unlock()
159			return false, nil
160		}
161	}
162	call, dup := p.addConnCalls[key]
163	if !dup {
164		if p.addConnCalls == nil {
165			p.addConnCalls = make(map[string]*addConnCall)
166		}
167		call = &addConnCall{
168			p:    p,
169			done: make(chan struct{}),
170		}
171		p.addConnCalls[key] = call
172		go call.run(t, key, c)
173	}
174	p.mu.Unlock()
175
176	<-call.done
177	if call.err != nil {
178		return false, call.err
179	}
180	return !dup, nil
181}
182
// addConnCall is an in-flight addConnIfNeeded call for one key. Concurrent
// addConnIfNeeded callers for the same key share it (via
// clientConnPool.addConnCalls) and wait on done.
type addConnCall struct {
	_    incomparable
	p    *clientConnPool // pool the new connection is added to
	done chan struct{}   // closed when done
	err  error           // error from NewClientConn; valid after done is closed
}
189
190func (c *addConnCall) run(t *Transport, key string, tc *tls.Conn) {
191	cc, err := t.NewClientConn(tc)
192
193	p := c.p
194	p.mu.Lock()
195	if err != nil {
196		c.err = err
197	} else {
198		p.addConnLocked(key, cc)
199	}
200	delete(p.addConnCalls, key)
201	p.mu.Unlock()
202	close(c.done)
203}
204
205// p.mu must be held
206func (p *clientConnPool) addConnLocked(key string, cc *ClientConn) {
207	for _, v := range p.conns[key] {
208		if v == cc {
209			return
210		}
211	}
212	if p.conns == nil {
213		p.conns = make(map[string][]*ClientConn)
214	}
215	if p.keys == nil {
216		p.keys = make(map[*ClientConn][]string)
217	}
218	p.conns[key] = append(p.conns[key], cc)
219	p.keys[cc] = append(p.keys[cc], key)
220}
221
222func (p *clientConnPool) MarkDead(cc *ClientConn) {
223	p.mu.Lock()
224	defer p.mu.Unlock()
225	for _, key := range p.keys[cc] {
226		vv, ok := p.conns[key]
227		if !ok {
228			continue
229		}
230		newList := filterOutClientConn(vv, cc)
231		if len(newList) > 0 {
232			p.conns[key] = newList
233		} else {
234			delete(p.conns, key)
235		}
236	}
237	delete(p.keys, cc)
238}
239
240func (p *clientConnPool) closeIdleConnections() {
241	p.mu.Lock()
242	defer p.mu.Unlock()
243	// TODO: don't close a cc if it was just added to the pool
244	// milliseconds ago and has never been used. There's currently
245	// a small race window with the HTTP/1 Transport's integration
246	// where it can add an idle conn just before using it, and
247	// somebody else can concurrently call CloseIdleConns and
248	// break some caller's RoundTrip.
249	for _, vv := range p.conns {
250		for _, cc := range vv {
251			cc.closeIfIdle()
252		}
253	}
254}
255
256func filterOutClientConn(in []*ClientConn, exclude *ClientConn) []*ClientConn {
257	out := in[:0]
258	for _, v := range in {
259		if v != exclude {
260			out = append(out, v)
261		}
262	}
263	// If we filtered it out, zero out the last item to prevent
264	// the GC from seeing it.
265	if len(in) != len(out) {
266		in[len(in)-1] = nil
267	}
268	return out
269}
270
// noDialClientConnPool is an implementation of http2.ClientConnPool
// which never dials. We let the HTTP/1.1 client dial and use its TLS
// connection instead.
//
// It embeds *clientConnPool, so connection bookkeeping (MarkDead,
// closeIdleConnections, addConnIfNeeded) is shared. Only GetClientConn is
// overridden, and it returns ErrNoCachedConn on a cache miss.
type noDialClientConnPool struct{ *clientConnPool }
275
276func (p noDialClientConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) {
277	return p.getClientConn(req, addr, noDialOnMiss)
278}
279