1package dependency
2
3import (
4	"crypto/tls"
5	"fmt"
6	"log"
7	"net"
8	"net/http"
9	"sync"
10	"time"
11
12	consulapi "github.com/hashicorp/consul/api"
13	rootcerts "github.com/hashicorp/go-rootcerts"
14	vaultapi "github.com/hashicorp/vault/api"
15)
16
17// ClientSet is a collection of clients that dependencies use to communicate
18// with remote services like Consul or Vault.
// ClientSet is a collection of clients that dependencies use to communicate
// with remote services like Consul or Vault.
//
// The embedded RWMutex guards the vault and consul fields. CreateConsulClient,
// CreateVaultClient and Stop take the write lock; the Consul and Vault
// accessors take the read lock.
type ClientSet struct {
	sync.RWMutex

	// vault is populated by CreateVaultClient and is nil until then.
	vault  *vaultClient
	// consul is populated by CreateConsulClient and is nil until then.
	consul *consulClient
}
25
26// consulClient is a wrapper around a real Consul API client.
// consulClient is a wrapper around a real Consul API client.
//
// transport is the *http.Transport the Consul client was configured with in
// CreateConsulClient. It is kept alongside the client so Stop can close its
// idle connections.
type consulClient struct {
	client    *consulapi.Client
	transport *http.Transport
}
31
32// vaultClient is a wrapper around a real Vault API client.
// vaultClient is a wrapper around a real Vault API client.
//
// httpClient is the *http.Client from the Vault config whose Transport was
// replaced in CreateVaultClient. It is kept so Stop can close idle
// connections on that transport.
type vaultClient struct {
	client     *vaultapi.Client
	httpClient *http.Client
}
37
38// CreateConsulClientInput is used as input to the CreateConsulClient function.
// CreateConsulClientInput is used as input to the CreateConsulClient function.
//
// An empty Address or Token keeps the value supplied by
// consulapi.DefaultConfig. AuthUsername and AuthPassword are sent as HTTP
// basic auth only when AuthEnabled is true.
//
// The SSL* fields and ServerName are consulted only when SSLEnabled is true;
// SSLEnabled also switches the request scheme to https. SSLCert and SSLKey
// are file paths handed to tls.LoadX509KeyPair. If SSLKey is empty, SSLCert
// is loaded as both the certificate and the key, so it must contain both.
// SSLCACert and SSLCAPath are passed to go-rootcerts as CAFile and CAPath.
// When SSLVerify is false, server certificate verification is disabled and
// a warning is logged.
type CreateConsulClientInput struct {
	Address      string
	Token        string
	AuthEnabled  bool
	AuthUsername string
	AuthPassword string
	SSLEnabled   bool
	SSLVerify    bool
	SSLCert      string
	SSLKey       string
	SSLCACert    string
	SSLCAPath    string
	ServerName   string

	// The Transport* fields are copied onto the net.Dialer and http.Transport
	// fields of the same name. A zero value has whatever meaning net / net/http
	// document for that field.
	TransportDialKeepAlive       time.Duration
	TransportDialTimeout         time.Duration
	TransportDisableKeepAlives   bool
	TransportIdleConnTimeout     time.Duration
	TransportMaxIdleConns        int
	TransportMaxIdleConnsPerHost int
	TransportTLSHandshakeTimeout time.Duration
}
61
62// CreateVaultClientInput is used as input to the CreateVaultClient function.
// CreateVaultClientInput is used as input to the CreateVaultClient function.
//
// An empty Address keeps the value supplied by vaultapi.DefaultConfig.
// Namespace and Token are applied to the client only when non-empty. When
// UnwrapToken is true, Token is passed to the Vault client's Logical().Unwrap
// and the resulting Auth.ClientToken becomes the client's token.
//
// The SSL* fields and ServerName are consulted only when SSLEnabled is true.
// SSLCert and SSLKey are file paths handed to tls.LoadX509KeyPair. If SSLKey
// is empty, SSLCert is loaded as both the certificate and the key.
// SSLCACert and SSLCAPath are passed to go-rootcerts as CAFile and CAPath.
// When SSLVerify is false, server certificate verification is disabled and a
// warning is logged.
type CreateVaultClientInput struct {
	Address     string
	Namespace   string
	Token       string
	UnwrapToken bool
	SSLEnabled  bool
	SSLVerify   bool
	SSLCert     string
	SSLKey      string
	SSLCACert   string
	SSLCAPath   string
	ServerName  string

	// The Transport* fields are copied onto the net.Dialer and http.Transport
	// fields of the same name. A zero value has whatever meaning net / net/http
	// document for that field.
	TransportDialKeepAlive       time.Duration
	TransportDialTimeout         time.Duration
	TransportDisableKeepAlives   bool
	TransportIdleConnTimeout     time.Duration
	TransportMaxIdleConns        int
	TransportMaxIdleConnsPerHost int
	TransportTLSHandshakeTimeout time.Duration
}
84
85// NewClientSet creates a new client set that is ready to accept clients.
86func NewClientSet() *ClientSet {
87	return &ClientSet{}
88}
89
90// CreateConsulClient creates a new Consul API client from the given input.
91func (c *ClientSet) CreateConsulClient(i *CreateConsulClientInput) error {
92	consulConfig := consulapi.DefaultConfig()
93
94	if i.Address != "" {
95		consulConfig.Address = i.Address
96	}
97
98	if i.Token != "" {
99		consulConfig.Token = i.Token
100	}
101
102	if i.AuthEnabled {
103		consulConfig.HttpAuth = &consulapi.HttpBasicAuth{
104			Username: i.AuthUsername,
105			Password: i.AuthPassword,
106		}
107	}
108
109	// This transport will attempt to keep connections open to the Consul server.
110	transport := &http.Transport{
111		Proxy: http.ProxyFromEnvironment,
112		Dial: (&net.Dialer{
113			Timeout:   i.TransportDialTimeout,
114			KeepAlive: i.TransportDialKeepAlive,
115		}).Dial,
116		DisableKeepAlives:   i.TransportDisableKeepAlives,
117		MaxIdleConns:        i.TransportMaxIdleConns,
118		IdleConnTimeout:     i.TransportIdleConnTimeout,
119		MaxIdleConnsPerHost: i.TransportMaxIdleConnsPerHost,
120		TLSHandshakeTimeout: i.TransportTLSHandshakeTimeout,
121	}
122
123	// Configure SSL
124	if i.SSLEnabled {
125		consulConfig.Scheme = "https"
126
127		var tlsConfig tls.Config
128
129		// Custom certificate or certificate and key
130		if i.SSLCert != "" && i.SSLKey != "" {
131			cert, err := tls.LoadX509KeyPair(i.SSLCert, i.SSLKey)
132			if err != nil {
133				return fmt.Errorf("client set: consul: %s", err)
134			}
135			tlsConfig.Certificates = []tls.Certificate{cert}
136		} else if i.SSLCert != "" {
137			cert, err := tls.LoadX509KeyPair(i.SSLCert, i.SSLCert)
138			if err != nil {
139				return fmt.Errorf("client set: consul: %s", err)
140			}
141			tlsConfig.Certificates = []tls.Certificate{cert}
142		}
143
144		// Custom CA certificate
145		if i.SSLCACert != "" || i.SSLCAPath != "" {
146			rootConfig := &rootcerts.Config{
147				CAFile: i.SSLCACert,
148				CAPath: i.SSLCAPath,
149			}
150			if err := rootcerts.ConfigureTLS(&tlsConfig, rootConfig); err != nil {
151				return fmt.Errorf("client set: consul configuring TLS failed: %s", err)
152			}
153		}
154
155		// Construct all the certificates now
156		tlsConfig.BuildNameToCertificate()
157
158		// SSL verification
159		if i.ServerName != "" {
160			tlsConfig.ServerName = i.ServerName
161			tlsConfig.InsecureSkipVerify = false
162		}
163		if !i.SSLVerify {
164			log.Printf("[WARN] (clients) disabling consul SSL verification")
165			tlsConfig.InsecureSkipVerify = true
166		}
167
168		// Save the TLS config on our transport
169		transport.TLSClientConfig = &tlsConfig
170	}
171
172	// Setup the new transport
173	consulConfig.Transport = transport
174
175	// Create the API client
176	client, err := consulapi.NewClient(consulConfig)
177	if err != nil {
178		return fmt.Errorf("client set: consul: %s", err)
179	}
180
181	// Save the data on ourselves
182	c.Lock()
183	c.consul = &consulClient{
184		client:    client,
185		transport: transport,
186	}
187	c.Unlock()
188
189	return nil
190}
191
192func (c *ClientSet) CreateVaultClient(i *CreateVaultClientInput) error {
193	vaultConfig := vaultapi.DefaultConfig()
194
195	if i.Address != "" {
196		vaultConfig.Address = i.Address
197	}
198
199	// This transport will attempt to keep connections open to the Vault server.
200	transport := &http.Transport{
201		Proxy: http.ProxyFromEnvironment,
202		Dial: (&net.Dialer{
203			Timeout:   i.TransportDialTimeout,
204			KeepAlive: i.TransportDialKeepAlive,
205		}).Dial,
206		DisableKeepAlives:   i.TransportDisableKeepAlives,
207		MaxIdleConns:        i.TransportMaxIdleConns,
208		IdleConnTimeout:     i.TransportIdleConnTimeout,
209		MaxIdleConnsPerHost: i.TransportMaxIdleConnsPerHost,
210		TLSHandshakeTimeout: i.TransportTLSHandshakeTimeout,
211	}
212
213	// Configure SSL
214	if i.SSLEnabled {
215		var tlsConfig tls.Config
216
217		// Custom certificate or certificate and key
218		if i.SSLCert != "" && i.SSLKey != "" {
219			cert, err := tls.LoadX509KeyPair(i.SSLCert, i.SSLKey)
220			if err != nil {
221				return fmt.Errorf("client set: vault: %s", err)
222			}
223			tlsConfig.Certificates = []tls.Certificate{cert}
224		} else if i.SSLCert != "" {
225			cert, err := tls.LoadX509KeyPair(i.SSLCert, i.SSLCert)
226			if err != nil {
227				return fmt.Errorf("client set: vault: %s", err)
228			}
229			tlsConfig.Certificates = []tls.Certificate{cert}
230		}
231
232		// Custom CA certificate
233		if i.SSLCACert != "" || i.SSLCAPath != "" {
234			rootConfig := &rootcerts.Config{
235				CAFile: i.SSLCACert,
236				CAPath: i.SSLCAPath,
237			}
238			if err := rootcerts.ConfigureTLS(&tlsConfig, rootConfig); err != nil {
239				return fmt.Errorf("client set: vault configuring TLS failed: %s", err)
240			}
241		}
242
243		// Construct all the certificates now
244		tlsConfig.BuildNameToCertificate()
245
246		// SSL verification
247		if i.ServerName != "" {
248			tlsConfig.ServerName = i.ServerName
249			tlsConfig.InsecureSkipVerify = false
250		}
251		if !i.SSLVerify {
252			log.Printf("[WARN] (clients) disabling vault SSL verification")
253			tlsConfig.InsecureSkipVerify = true
254		}
255
256		// Save the TLS config on our transport
257		transport.TLSClientConfig = &tlsConfig
258	}
259
260	// Setup the new transport
261	vaultConfig.HttpClient.Transport = transport
262
263	// Create the client
264	client, err := vaultapi.NewClient(vaultConfig)
265	if err != nil {
266		return fmt.Errorf("client set: vault: %s", err)
267	}
268
269	// Set the namespace if given.
270	if i.Namespace != "" {
271		client.SetNamespace(i.Namespace)
272	}
273
274	// Set the token if given
275	if i.Token != "" {
276		client.SetToken(i.Token)
277	}
278
279	// Check if we are unwrapping
280	if i.UnwrapToken {
281		secret, err := client.Logical().Unwrap(i.Token)
282		if err != nil {
283			return fmt.Errorf("client set: vault unwrap: %s", err)
284		}
285
286		if secret == nil {
287			return fmt.Errorf("client set: vault unwrap: no secret")
288		}
289
290		if secret.Auth == nil {
291			return fmt.Errorf("client set: vault unwrap: no secret auth")
292		}
293
294		if secret.Auth.ClientToken == "" {
295			return fmt.Errorf("client set: vault unwrap: no token returned")
296		}
297
298		client.SetToken(secret.Auth.ClientToken)
299	}
300
301	// Save the data on ourselves
302	c.Lock()
303	c.vault = &vaultClient{
304		client:     client,
305		httpClient: vaultConfig.HttpClient,
306	}
307	c.Unlock()
308
309	return nil
310}
311
312// Consul returns the Consul client for this set.
313func (c *ClientSet) Consul() *consulapi.Client {
314	c.RLock()
315	defer c.RUnlock()
316	return c.consul.client
317}
318
319// Vault returns the Vault client for this set.
320func (c *ClientSet) Vault() *vaultapi.Client {
321	c.RLock()
322	defer c.RUnlock()
323	return c.vault.client
324}
325
326// Stop closes all idle connections for any attached clients.
327func (c *ClientSet) Stop() {
328	c.Lock()
329	defer c.Unlock()
330
331	if c.consul != nil {
332		c.consul.transport.CloseIdleConnections()
333	}
334
335	if c.vault != nil {
336		c.vault.httpClient.Transport.(*http.Transport).CloseIdleConnections()
337	}
338}
339