1package dependency
2
3import (
4	"crypto/tls"
5	"fmt"
6	"log"
7	"net"
8	"net/http"
9	"sync"
10	"time"
11
12	consulapi "github.com/hashicorp/consul/api"
13	rootcerts "github.com/hashicorp/go-rootcerts"
14	vaultapi "github.com/hashicorp/vault/api"
15)
16
17// ClientSet is a collection of clients that dependencies use to communicate
18// with remote services like Consul or Vault.
19type ClientSet struct {
20	sync.RWMutex
21
22	vault  *vaultClient
23	consul *consulClient
24}
25
26// consulClient is a wrapper around a real Consul API client.
27type consulClient struct {
28	client    *consulapi.Client
29	transport *http.Transport
30}
31
32// vaultClient is a wrapper around a real Vault API client.
33type vaultClient struct {
34	client     *vaultapi.Client
35	httpClient *http.Client
36}
37
// CreateConsulClientInput is used as input to the CreateConsulClient function.
type CreateConsulClientInput struct {
	// Address, Namespace and Token replace the consulapi defaults only when
	// they are non-empty.
	Address   string
	Namespace string
	Token     string

	// When AuthEnabled is true, AuthUsername and AuthPassword are sent as
	// HTTP basic-auth credentials.
	AuthEnabled  bool
	AuthUsername string
	AuthPassword string

	// SSLEnabled switches the scheme to https and builds a TLS config from
	// the remaining SSL* fields. A false SSLVerify disables server
	// certificate verification. If SSLCert is given without SSLKey, the
	// SSLCert file is used for both the certificate and the key.
	// SSLCACert and SSLCAPath are handed to rootcerts as CAFile and CAPath.
	SSLEnabled bool
	SSLVerify  bool
	SSLCert    string
	SSLKey     string
	SSLCACert  string
	SSLCAPath  string

	// ServerName, when non-empty, becomes the TLS ServerName.
	ServerName string

	// The Transport* values are copied verbatim onto the net.Dialer and
	// http.Transport used for Consul requests, so zero values carry those
	// types' zero-value meaning.
	TransportDialKeepAlive       time.Duration
	TransportDialTimeout         time.Duration
	TransportDisableKeepAlives   bool
	TransportIdleConnTimeout     time.Duration
	TransportMaxIdleConns        int
	TransportMaxIdleConnsPerHost int
	TransportTLSHandshakeTimeout time.Duration
}
62
// CreateVaultClientInput is used as input to the CreateVaultClient function.
type CreateVaultClientInput struct {
	// Address replaces the vaultapi default address when non-empty;
	// Namespace and Token are set on the client when non-empty.
	Address   string
	Namespace string
	Token     string

	// UnwrapToken makes CreateVaultClient unwrap Token and switch to the
	// client token carried in the unwrapped secret.
	UnwrapToken bool

	// SSLEnabled builds a custom TLS config from the remaining SSL* fields.
	// A false SSLVerify disables server certificate verification. If
	// SSLCert is given without SSLKey, the SSLCert file is used for both the
	// certificate and the key. SSLCACert and SSLCAPath are handed to
	// rootcerts as CAFile and CAPath.
	SSLEnabled bool
	SSLVerify  bool
	SSLCert    string
	SSLKey     string
	SSLCACert  string
	SSLCAPath  string

	// ServerName, when non-empty, becomes the TLS ServerName.
	ServerName string

	// The Transport* values are copied verbatim onto the net.Dialer and
	// http.Transport used for Vault requests, so zero values carry those
	// types' zero-value meaning.
	TransportDialKeepAlive       time.Duration
	TransportDialTimeout         time.Duration
	TransportDisableKeepAlives   bool
	TransportIdleConnTimeout     time.Duration
	TransportMaxIdleConns        int
	TransportMaxIdleConnsPerHost int
	TransportTLSHandshakeTimeout time.Duration
}
85
86// NewClientSet creates a new client set that is ready to accept clients.
87func NewClientSet() *ClientSet {
88	return &ClientSet{}
89}
90
91// CreateConsulClient creates a new Consul API client from the given input.
92func (c *ClientSet) CreateConsulClient(i *CreateConsulClientInput) error {
93	consulConfig := consulapi.DefaultConfig()
94
95	if i.Address != "" {
96		consulConfig.Address = i.Address
97	}
98
99	if i.Namespace != "" {
100		consulConfig.Namespace = i.Namespace
101	}
102
103	if i.Token != "" {
104		consulConfig.Token = i.Token
105	}
106
107	if i.AuthEnabled {
108		consulConfig.HttpAuth = &consulapi.HttpBasicAuth{
109			Username: i.AuthUsername,
110			Password: i.AuthPassword,
111		}
112	}
113
114	// This transport will attempt to keep connections open to the Consul server.
115	transport := &http.Transport{
116		Proxy: http.ProxyFromEnvironment,
117		Dial: (&net.Dialer{
118			Timeout:   i.TransportDialTimeout,
119			KeepAlive: i.TransportDialKeepAlive,
120		}).Dial,
121		DisableKeepAlives:   i.TransportDisableKeepAlives,
122		MaxIdleConns:        i.TransportMaxIdleConns,
123		IdleConnTimeout:     i.TransportIdleConnTimeout,
124		MaxIdleConnsPerHost: i.TransportMaxIdleConnsPerHost,
125		TLSHandshakeTimeout: i.TransportTLSHandshakeTimeout,
126	}
127
128	// Configure SSL
129	if i.SSLEnabled {
130		consulConfig.Scheme = "https"
131
132		var tlsConfig tls.Config
133
134		// Custom certificate or certificate and key
135		if i.SSLCert != "" && i.SSLKey != "" {
136			cert, err := tls.LoadX509KeyPair(i.SSLCert, i.SSLKey)
137			if err != nil {
138				return fmt.Errorf("client set: consul: %s", err)
139			}
140			tlsConfig.Certificates = []tls.Certificate{cert}
141		} else if i.SSLCert != "" {
142			cert, err := tls.LoadX509KeyPair(i.SSLCert, i.SSLCert)
143			if err != nil {
144				return fmt.Errorf("client set: consul: %s", err)
145			}
146			tlsConfig.Certificates = []tls.Certificate{cert}
147		}
148
149		// Custom CA certificate
150		if i.SSLCACert != "" || i.SSLCAPath != "" {
151			rootConfig := &rootcerts.Config{
152				CAFile: i.SSLCACert,
153				CAPath: i.SSLCAPath,
154			}
155			if err := rootcerts.ConfigureTLS(&tlsConfig, rootConfig); err != nil {
156				return fmt.Errorf("client set: consul configuring TLS failed: %s", err)
157			}
158		}
159
160		// Construct all the certificates now
161		tlsConfig.BuildNameToCertificate()
162
163		// SSL verification
164		if i.ServerName != "" {
165			tlsConfig.ServerName = i.ServerName
166			tlsConfig.InsecureSkipVerify = false
167		}
168		if !i.SSLVerify {
169			log.Printf("[WARN] (clients) disabling consul SSL verification")
170			tlsConfig.InsecureSkipVerify = true
171		}
172
173		// Save the TLS config on our transport
174		transport.TLSClientConfig = &tlsConfig
175	}
176
177	// Setup the new transport
178	consulConfig.Transport = transport
179
180	// Create the API client
181	client, err := consulapi.NewClient(consulConfig)
182	if err != nil {
183		return fmt.Errorf("client set: consul: %s", err)
184	}
185
186	// Save the data on ourselves
187	c.Lock()
188	c.consul = &consulClient{
189		client:    client,
190		transport: transport,
191	}
192	c.Unlock()
193
194	return nil
195}
196
197func (c *ClientSet) CreateVaultClient(i *CreateVaultClientInput) error {
198	vaultConfig := vaultapi.DefaultConfig()
199
200	if i.Address != "" {
201		vaultConfig.Address = i.Address
202	}
203
204	// This transport will attempt to keep connections open to the Vault server.
205	transport := &http.Transport{
206		Proxy: http.ProxyFromEnvironment,
207		Dial: (&net.Dialer{
208			Timeout:   i.TransportDialTimeout,
209			KeepAlive: i.TransportDialKeepAlive,
210		}).Dial,
211		DisableKeepAlives:   i.TransportDisableKeepAlives,
212		MaxIdleConns:        i.TransportMaxIdleConns,
213		IdleConnTimeout:     i.TransportIdleConnTimeout,
214		MaxIdleConnsPerHost: i.TransportMaxIdleConnsPerHost,
215		TLSHandshakeTimeout: i.TransportTLSHandshakeTimeout,
216	}
217
218	// Configure SSL
219	if i.SSLEnabled {
220		var tlsConfig tls.Config
221
222		// Custom certificate or certificate and key
223		if i.SSLCert != "" && i.SSLKey != "" {
224			cert, err := tls.LoadX509KeyPair(i.SSLCert, i.SSLKey)
225			if err != nil {
226				return fmt.Errorf("client set: vault: %s", err)
227			}
228			tlsConfig.Certificates = []tls.Certificate{cert}
229		} else if i.SSLCert != "" {
230			cert, err := tls.LoadX509KeyPair(i.SSLCert, i.SSLCert)
231			if err != nil {
232				return fmt.Errorf("client set: vault: %s", err)
233			}
234			tlsConfig.Certificates = []tls.Certificate{cert}
235		}
236
237		// Custom CA certificate
238		if i.SSLCACert != "" || i.SSLCAPath != "" {
239			rootConfig := &rootcerts.Config{
240				CAFile: i.SSLCACert,
241				CAPath: i.SSLCAPath,
242			}
243			if err := rootcerts.ConfigureTLS(&tlsConfig, rootConfig); err != nil {
244				return fmt.Errorf("client set: vault configuring TLS failed: %s", err)
245			}
246		}
247
248		// Construct all the certificates now
249		tlsConfig.BuildNameToCertificate()
250
251		// SSL verification
252		if i.ServerName != "" {
253			tlsConfig.ServerName = i.ServerName
254			tlsConfig.InsecureSkipVerify = false
255		}
256		if !i.SSLVerify {
257			log.Printf("[WARN] (clients) disabling vault SSL verification")
258			tlsConfig.InsecureSkipVerify = true
259		}
260
261		// Save the TLS config on our transport
262		transport.TLSClientConfig = &tlsConfig
263	}
264
265	// Setup the new transport
266	vaultConfig.HttpClient.Transport = transport
267
268	// Create the client
269	client, err := vaultapi.NewClient(vaultConfig)
270	if err != nil {
271		return fmt.Errorf("client set: vault: %s", err)
272	}
273
274	// Set the namespace if given.
275	if i.Namespace != "" {
276		client.SetNamespace(i.Namespace)
277	}
278
279	// Set the token if given
280	if i.Token != "" {
281		client.SetToken(i.Token)
282	}
283
284	// Check if we are unwrapping
285	if i.UnwrapToken {
286		secret, err := client.Logical().Unwrap(i.Token)
287		if err != nil {
288			return fmt.Errorf("client set: vault unwrap: %s", err)
289		}
290
291		if secret == nil {
292			return fmt.Errorf("client set: vault unwrap: no secret")
293		}
294
295		if secret.Auth == nil {
296			return fmt.Errorf("client set: vault unwrap: no secret auth")
297		}
298
299		if secret.Auth.ClientToken == "" {
300			return fmt.Errorf("client set: vault unwrap: no token returned")
301		}
302
303		client.SetToken(secret.Auth.ClientToken)
304	}
305
306	// Save the data on ourselves
307	c.Lock()
308	c.vault = &vaultClient{
309		client:     client,
310		httpClient: vaultConfig.HttpClient,
311	}
312	c.Unlock()
313
314	return nil
315}
316
317// Consul returns the Consul client for this set.
318func (c *ClientSet) Consul() *consulapi.Client {
319	c.RLock()
320	defer c.RUnlock()
321	return c.consul.client
322}
323
324// Vault returns the Vault client for this set.
325func (c *ClientSet) Vault() *vaultapi.Client {
326	c.RLock()
327	defer c.RUnlock()
328	return c.vault.client
329}
330
331// Stop closes all idle connections for any attached clients.
332func (c *ClientSet) Stop() {
333	c.Lock()
334	defer c.Unlock()
335
336	if c.consul != nil {
337		c.consul.transport.CloseIdleConnections()
338	}
339
340	if c.vault != nil {
341		c.vault.httpClient.Transport.(*http.Transport).CloseIdleConnections()
342	}
343}
344