1package cassandra
2
3import (
4	"context"
5	"crypto/tls"
6	"fmt"
7	"strings"
8	"sync"
9	"time"
10
11	"github.com/mitchellh/mapstructure"
12
13	"github.com/gocql/gocql"
14	"github.com/hashicorp/errwrap"
15	"github.com/hashicorp/vault/helper/certutil"
16	"github.com/hashicorp/vault/helper/parseutil"
17	"github.com/hashicorp/vault/helper/tlsutil"
18	"github.com/hashicorp/vault/plugins/helper/database/connutil"
19)
20
// cassandraConnectionProducer implements ConnectionProducer and provides an
// interface for cassandra databases to make connections.
//
// Init decodes the plugin configuration into the exported fields and derives
// the unexported ones. The gocql session is created lazily by Connection and
// cached in session. Init and Close acquire the embedded mutex. Connection
// does not acquire it (Init calls it while already holding the lock), so
// other callers of Connection presumably must hold the lock themselves.
// NOTE(review): confirm at call sites.
type cassandraConnectionProducer struct {
	// Fields populated from the plugin configuration map by
	// mapstructure.WeakDecode in Init:
	//   - Hosts: comma-separated list of hosts handed to gocql; required.
	//   - Port: applied to the gocql cluster config only when non-zero.
	//   - Username, Password: credentials for gocql.PasswordAuthenticator;
	//     both required.
	//   - TLS: enables TLS; Init also forces it on when PemJSON or
	//     PemBundle is supplied.
	//   - InsecureTLS: controls TLS server-certificate verification when
	//     creating sessions.
	//   - ProtocolVersion: CQL protocol version; 0 is replaced with 2.
	//   - ConnectTimeoutRaw: anything parseutil.ParseDurationSecond accepts;
	//     nil is treated as "0s". The parsed value is stored in
	//     connectTimeout.
	//   - TLSMinVersion: key looked up in tlsutil.TLSLookup.
	//   - Consistency: parsed with gocql.ParseConsistencyWrapper and applied
	//     to new sessions when non-empty.
	//   - PemJSON, PemBundle: client certificate material; PemJSON takes
	//     precedence when both are set.
	Hosts             string      `json:"hosts" structs:"hosts" mapstructure:"hosts"`
	Port              int         `json:"port" structs:"port" mapstructure:"port"`
	Username          string      `json:"username" structs:"username" mapstructure:"username"`
	Password          string      `json:"password" structs:"password" mapstructure:"password"`
	TLS               bool        `json:"tls" structs:"tls" mapstructure:"tls"`
	InsecureTLS       bool        `json:"insecure_tls" structs:"insecure_tls" mapstructure:"insecure_tls"`
	ProtocolVersion   int         `json:"protocol_version" structs:"protocol_version" mapstructure:"protocol_version"`
	ConnectTimeoutRaw interface{} `json:"connect_timeout" structs:"connect_timeout" mapstructure:"connect_timeout"`
	TLSMinVersion     string      `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"`
	Consistency       string      `json:"consistency" structs:"consistency" mapstructure:"consistency"`
	PemBundle         string      `json:"pem_bundle" structs:"pem_bundle" mapstructure:"pem_bundle"`
	PemJSON           string      `json:"pem_json" structs:"pem_json" mapstructure:"pem_json"`

	// Values derived by Init:
	//   - connectTimeout: the parsed connect_timeout, used as the gocql
	//     cluster timeout.
	//   - certificate, privateKey, issuingCA: PEM strings extracted from
	//     PemJSON or PemBundle.
	//   - rawConfig: the configuration map exactly as passed to Init.
	connectTimeout time.Duration
	certificate    string
	privateKey     string
	issuingCA      string
	rawConfig      map[string]interface{}

	// Initialized is set by Init after the configuration has been decoded
	// and validated. Connection refuses to run until it is true.
	// Type is not referenced in this file.
	// session caches the gocql session that Connection hands out.
	// The embedded Mutex guards Init and Close.
	Initialized bool
	Type        string
	session     *gocql.Session
	sync.Mutex
}
48
49func (c *cassandraConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
50	_, err := c.Init(ctx, conf, verifyConnection)
51	return err
52}
53
54func (c *cassandraConnectionProducer) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (map[string]interface{}, error) {
55	c.Lock()
56	defer c.Unlock()
57
58	c.rawConfig = conf
59
60	err := mapstructure.WeakDecode(conf, c)
61	if err != nil {
62		return nil, err
63	}
64
65	if c.ConnectTimeoutRaw == nil {
66		c.ConnectTimeoutRaw = "0s"
67	}
68	c.connectTimeout, err = parseutil.ParseDurationSecond(c.ConnectTimeoutRaw)
69	if err != nil {
70		return nil, errwrap.Wrapf("invalid connect_timeout: {{err}}", err)
71	}
72
73	switch {
74	case len(c.Hosts) == 0:
75		return nil, fmt.Errorf("hosts cannot be empty")
76	case len(c.Username) == 0:
77		return nil, fmt.Errorf("username cannot be empty")
78	case len(c.Password) == 0:
79		return nil, fmt.Errorf("password cannot be empty")
80	}
81
82	var certBundle *certutil.CertBundle
83	var parsedCertBundle *certutil.ParsedCertBundle
84	switch {
85	case len(c.PemJSON) != 0:
86		parsedCertBundle, err = certutil.ParsePKIJSON([]byte(c.PemJSON))
87		if err != nil {
88			return nil, errwrap.Wrapf("could not parse given JSON; it must be in the format of the output of the PKI backend certificate issuing command: {{err}}", err)
89		}
90		certBundle, err = parsedCertBundle.ToCertBundle()
91		if err != nil {
92			return nil, errwrap.Wrapf("Error marshaling PEM information: {{err}}", err)
93		}
94		c.certificate = certBundle.Certificate
95		c.privateKey = certBundle.PrivateKey
96		c.issuingCA = certBundle.IssuingCA
97		c.TLS = true
98
99	case len(c.PemBundle) != 0:
100		parsedCertBundle, err = certutil.ParsePEMBundle(c.PemBundle)
101		if err != nil {
102			return nil, errwrap.Wrapf("Error parsing the given PEM information: {{err}}", err)
103		}
104		certBundle, err = parsedCertBundle.ToCertBundle()
105		if err != nil {
106			return nil, errwrap.Wrapf("Error marshaling PEM information: {{err}}", err)
107		}
108		c.certificate = certBundle.Certificate
109		c.privateKey = certBundle.PrivateKey
110		c.issuingCA = certBundle.IssuingCA
111		c.TLS = true
112	}
113
114	// Set initialized to true at this point since all fields are set,
115	// and the connection can be established at a later time.
116	c.Initialized = true
117
118	if verifyConnection {
119		if _, err := c.Connection(ctx); err != nil {
120			return nil, errwrap.Wrapf("error verifying connection: {{err}}", err)
121		}
122	}
123
124	return conf, nil
125}
126
127func (c *cassandraConnectionProducer) Connection(_ context.Context) (interface{}, error) {
128	if !c.Initialized {
129		return nil, connutil.ErrNotInitialized
130	}
131
132	// If we already have a DB, return it
133	if c.session != nil && !c.session.Closed() {
134		return c.session, nil
135	}
136
137	session, err := c.createSession()
138	if err != nil {
139		return nil, err
140	}
141
142	//  Store the session in backend for reuse
143	c.session = session
144
145	return session, nil
146}
147
148func (c *cassandraConnectionProducer) Close() error {
149	// Grab the write lock
150	c.Lock()
151	defer c.Unlock()
152
153	if c.session != nil {
154		c.session.Close()
155	}
156
157	c.session = nil
158
159	return nil
160}
161
162func (c *cassandraConnectionProducer) createSession() (*gocql.Session, error) {
163	hosts := strings.Split(c.Hosts, ",")
164	clusterConfig := gocql.NewCluster(hosts...)
165	clusterConfig.Authenticator = gocql.PasswordAuthenticator{
166		Username: c.Username,
167		Password: c.Password,
168	}
169
170	if c.Port != 0 {
171		clusterConfig.Port = c.Port
172	}
173
174	clusterConfig.ProtoVersion = c.ProtocolVersion
175	if clusterConfig.ProtoVersion == 0 {
176		clusterConfig.ProtoVersion = 2
177	}
178
179	clusterConfig.Timeout = c.connectTimeout
180	if c.TLS {
181		var tlsConfig *tls.Config
182		if len(c.certificate) > 0 || len(c.issuingCA) > 0 {
183			if len(c.certificate) > 0 && len(c.privateKey) == 0 {
184				return nil, fmt.Errorf("found certificate for TLS authentication but no private key")
185			}
186
187			certBundle := &certutil.CertBundle{}
188			if len(c.certificate) > 0 {
189				certBundle.Certificate = c.certificate
190				certBundle.PrivateKey = c.privateKey
191			}
192			if len(c.issuingCA) > 0 {
193				certBundle.IssuingCA = c.issuingCA
194			}
195
196			parsedCertBundle, err := certBundle.ToParsedCertBundle()
197			if err != nil {
198				return nil, errwrap.Wrapf("failed to parse certificate bundle: {{err}}", err)
199			}
200
201			tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient)
202			if err != nil || tlsConfig == nil {
203				return nil, errwrap.Wrapf(fmt.Sprintf("failed to get TLS configuration: tlsConfig:%#v err:{{err}}", tlsConfig), err)
204			}
205			tlsConfig.InsecureSkipVerify = c.InsecureTLS
206
207			if c.TLSMinVersion != "" {
208				var ok bool
209				tlsConfig.MinVersion, ok = tlsutil.TLSLookup[c.TLSMinVersion]
210				if !ok {
211					return nil, fmt.Errorf("invalid 'tls_min_version' in config")
212				}
213			} else {
214				// MinVersion was not being set earlier. Reset it to
215				// zero to gracefully handle upgrades.
216				tlsConfig.MinVersion = 0
217			}
218		}
219
220		clusterConfig.SslOpts = &gocql.SslOptions{
221			Config: tlsConfig,
222		}
223	}
224
225	session, err := clusterConfig.CreateSession()
226	if err != nil {
227		return nil, errwrap.Wrapf("error creating session: {{err}}", err)
228	}
229
230	// Set consistency
231	if c.Consistency != "" {
232		consistencyValue, err := gocql.ParseConsistencyWrapper(c.Consistency)
233		if err != nil {
234			return nil, err
235		}
236
237		session.SetConsistency(consistencyValue)
238	}
239
240	// Verify the info
241	err = session.Query(`LIST ALL`).Exec()
242	if err != nil {
243		return nil, errwrap.Wrapf("error validating connection info: {{err}}", err)
244	}
245
246	return session, nil
247}
248
249func (c *cassandraConnectionProducer) secretValues() map[string]interface{} {
250	return map[string]interface{}{
251		c.Password:  "[password]",
252		c.PemBundle: "[pem_bundle]",
253		c.PemJSON:   "[pem_json]",
254	}
255}
256