1// Copyright 2015 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Transport code's client connection pooling.
6
7package http2
8
9import (
10	"crypto/tls"
11	"net/http"
12	"sync"
13)
14
15// ClientConnPool manages a pool of HTTP/2 client connections.
16type ClientConnPool interface {
17	GetClientConn(req *http.Request, addr string) (*ClientConn, error)
18	MarkDead(*ClientConn)
19}
20
21// TODO: use singleflight for dialing and addConnCalls?
22type clientConnPool struct {
23	t *Transport
24
25	mu sync.Mutex // TODO: maybe switch to RWMutex
26	// TODO: add support for sharing conns based on cert names
27	// (e.g. share conn for googleapis.com and appspot.com)
28	conns        map[string][]*ClientConn // key is host:port
29	dialing      map[string]*dialCall     // currently in-flight dials
30	keys         map[*ClientConn][]string
31	addConnCalls map[string]*addConnCall // in-flight addConnIfNeede calls
32}
33
34func (p *clientConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) {
35	return p.getClientConn(req, addr, dialOnMiss)
36}
37
// Values for getClientConn's dialOnMiss argument.
const (
	// dialOnMiss makes a cache miss start (or join) a dial to the address.
	dialOnMiss = true
	// noDialOnMiss makes a cache miss fail with ErrNoCachedConn instead.
	noDialOnMiss = false
)
42
43func (p *clientConnPool) getClientConn(_ *http.Request, addr string, dialOnMiss bool) (*ClientConn, error) {
44	p.mu.Lock()
45	for _, cc := range p.conns[addr] {
46		if cc.CanTakeNewRequest() {
47			p.mu.Unlock()
48			return cc, nil
49		}
50	}
51	if !dialOnMiss {
52		p.mu.Unlock()
53		return nil, ErrNoCachedConn
54	}
55	call := p.getStartDialLocked(addr)
56	p.mu.Unlock()
57	<-call.done
58	return call.res, call.err
59}
60
61// dialCall is an in-flight Transport dial call to a host.
62type dialCall struct {
63	p    *clientConnPool
64	done chan struct{} // closed when done
65	res  *ClientConn   // valid after done is closed
66	err  error         // valid after done is closed
67}
68
69// requires p.mu is held.
70func (p *clientConnPool) getStartDialLocked(addr string) *dialCall {
71	if call, ok := p.dialing[addr]; ok {
72		// A dial is already in-flight. Don't start another.
73		return call
74	}
75	call := &dialCall{p: p, done: make(chan struct{})}
76	if p.dialing == nil {
77		p.dialing = make(map[string]*dialCall)
78	}
79	p.dialing[addr] = call
80	go call.dial(addr)
81	return call
82}
83
84// run in its own goroutine.
85func (c *dialCall) dial(addr string) {
86	c.res, c.err = c.p.t.dialClientConn(addr)
87	close(c.done)
88
89	c.p.mu.Lock()
90	delete(c.p.dialing, addr)
91	if c.err == nil {
92		c.p.addConnLocked(addr, c.res)
93	}
94	c.p.mu.Unlock()
95}
96
97// addConnIfNeeded makes a NewClientConn out of c if a connection for key doesn't
98// already exist. It coalesces concurrent calls with the same key.
99// This is used by the http1 Transport code when it creates a new connection. Because
100// the http1 Transport doesn't de-dup TCP dials to outbound hosts (because it doesn't know
101// the protocol), it can get into a situation where it has multiple TLS connections.
102// This code decides which ones live or die.
103// The return value used is whether c was used.
104// c is never closed.
105func (p *clientConnPool) addConnIfNeeded(key string, t *Transport, c *tls.Conn) (used bool, err error) {
106	p.mu.Lock()
107	for _, cc := range p.conns[key] {
108		if cc.CanTakeNewRequest() {
109			p.mu.Unlock()
110			return false, nil
111		}
112	}
113	call, dup := p.addConnCalls[key]
114	if !dup {
115		if p.addConnCalls == nil {
116			p.addConnCalls = make(map[string]*addConnCall)
117		}
118		call = &addConnCall{
119			p:    p,
120			done: make(chan struct{}),
121		}
122		p.addConnCalls[key] = call
123		go call.run(t, key, c)
124	}
125	p.mu.Unlock()
126
127	<-call.done
128	if call.err != nil {
129		return false, call.err
130	}
131	return !dup, nil
132}
133
134type addConnCall struct {
135	p    *clientConnPool
136	done chan struct{} // closed when done
137	err  error
138}
139
140func (c *addConnCall) run(t *Transport, key string, tc *tls.Conn) {
141	cc, err := t.NewClientConn(tc)
142
143	p := c.p
144	p.mu.Lock()
145	if err != nil {
146		c.err = err
147	} else {
148		p.addConnLocked(key, cc)
149	}
150	delete(p.addConnCalls, key)
151	p.mu.Unlock()
152	close(c.done)
153}
154
155func (p *clientConnPool) addConn(key string, cc *ClientConn) {
156	p.mu.Lock()
157	p.addConnLocked(key, cc)
158	p.mu.Unlock()
159}
160
161// p.mu must be held
162func (p *clientConnPool) addConnLocked(key string, cc *ClientConn) {
163	for _, v := range p.conns[key] {
164		if v == cc {
165			return
166		}
167	}
168	if p.conns == nil {
169		p.conns = make(map[string][]*ClientConn)
170	}
171	if p.keys == nil {
172		p.keys = make(map[*ClientConn][]string)
173	}
174	p.conns[key] = append(p.conns[key], cc)
175	p.keys[cc] = append(p.keys[cc], key)
176}
177
178func (p *clientConnPool) MarkDead(cc *ClientConn) {
179	p.mu.Lock()
180	defer p.mu.Unlock()
181	for _, key := range p.keys[cc] {
182		vv, ok := p.conns[key]
183		if !ok {
184			continue
185		}
186		newList := filterOutClientConn(vv, cc)
187		if len(newList) > 0 {
188			p.conns[key] = newList
189		} else {
190			delete(p.conns, key)
191		}
192	}
193	delete(p.keys, cc)
194}
195
196func (p *clientConnPool) closeIdleConnections() {
197	p.mu.Lock()
198	defer p.mu.Unlock()
199	// TODO: don't close a cc if it was just added to the pool
200	// milliseconds ago and has never been used. There's currently
201	// a small race window with the HTTP/1 Transport's integration
202	// where it can add an idle conn just before using it, and
203	// somebody else can concurrently call CloseIdleConns and
204	// break some caller's RoundTrip.
205	for _, vv := range p.conns {
206		for _, cc := range vv {
207			cc.closeIfIdle()
208		}
209	}
210}
211
212func filterOutClientConn(in []*ClientConn, exclude *ClientConn) []*ClientConn {
213	out := in[:0]
214	for _, v := range in {
215		if v != exclude {
216			out = append(out, v)
217		}
218	}
219	// If we filtered it out, zero out the last item to prevent
220	// the GC from seeing it.
221	if len(in) != len(out) {
222		in[len(in)-1] = nil
223	}
224	return out
225}
226