1package cassandra
2
3import (
4	"context"
5	"crypto/tls"
6	"fmt"
7	"strings"
8	"sync"
9	"time"
10
11	"github.com/mitchellh/mapstructure"
12
13	"github.com/gocql/gocql"
14	"github.com/hashicorp/errwrap"
15	"github.com/hashicorp/vault/helper/certutil"
16	"github.com/hashicorp/vault/helper/parseutil"
17	"github.com/hashicorp/vault/helper/tlsutil"
18	"github.com/hashicorp/vault/plugins/helper/database/connutil"
19	"github.com/hashicorp/vault/plugins/helper/database/dbutil"
20)
21
// cassandraConnectionProducer implements ConnectionProducer and provides an
// interface for cassandra databases to make connections.
//
// The exported, tagged fields are decoded from the plugin configuration map
// by Init (via mapstructure.WeakDecode); the unexported fields are derived
// from them by Init and consumed by createSession.
type cassandraConnectionProducer struct {
	// Hosts is a comma-separated list of contact points; required by Init.
	Hosts              string      `json:"hosts" structs:"hosts" mapstructure:"hosts"`
	// Port overrides the gocql cluster port when non-zero.
	Port               int         `json:"port" structs:"port" mapstructure:"port"`
	// Username and Password are required by Init and are sent through
	// gocql.PasswordAuthenticator.
	Username           string      `json:"username" structs:"username" mapstructure:"username"`
	Password           string      `json:"password" structs:"password" mapstructure:"password"`
	// TLS enables TLS for the cluster; Init also forces it on whenever
	// pem_json or pem_bundle is supplied.
	TLS                bool        `json:"tls" structs:"tls" mapstructure:"tls"`
	// InsecureTLS is copied into the TLS config's InsecureSkipVerify when a
	// certificate or issuing CA has been configured.
	InsecureTLS        bool        `json:"insecure_tls" structs:"insecure_tls" mapstructure:"insecure_tls"`
	// ProtocolVersion is passed to gocql as ProtoVersion; 0 selects 2.
	ProtocolVersion    int         `json:"protocol_version" structs:"protocol_version" mapstructure:"protocol_version"`
	// ConnectTimeoutRaw and SocketKeepAliveRaw hold the raw configured values;
	// Init treats nil as "0s" and parses them with
	// parseutil.ParseDurationSecond into connectTimeout / socketKeepAlive.
	ConnectTimeoutRaw  interface{} `json:"connect_timeout" structs:"connect_timeout" mapstructure:"connect_timeout"`
	SocketKeepAliveRaw interface{} `json:"socket_keep_alive" structs:"socket_keep_alive" mapstructure:"socket_keep_alive"`
	// TLSMinVersion, when set, must be a key of tlsutil.TLSLookup; it is
	// applied alongside the certificate-based TLS config.
	TLSMinVersion      string      `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"`
	// Consistency, when set, is parsed with gocql.ParseConsistencyWrapper and
	// applied as the session's consistency level.
	Consistency        string      `json:"consistency" structs:"consistency" mapstructure:"consistency"`
	// PemBundle is PEM-encoded certificate material; ignored by Init when
	// PemJSON is also set.
	PemBundle          string      `json:"pem_bundle" structs:"pem_bundle" mapstructure:"pem_bundle"`
	// PemJSON is certificate material in the JSON format produced by the PKI
	// backend's issuing command; takes precedence over PemBundle.
	PemJSON            string      `json:"pem_json" structs:"pem_json" mapstructure:"pem_json"`

	// Values derived by Init from the raw fields above.
	connectTimeout  time.Duration
	socketKeepAlive time.Duration
	certificate     string
	privateKey      string
	issuingCA       string
	// rawConfig is the configuration map most recently passed to Init.
	rawConfig       map[string]interface{}

	// Initialized is set by Init once the configuration has validated;
	// Connection refuses to run until it is true.
	Initialized bool
	// NOTE(review): Type is not referenced in this file; presumably set or
	// read by the enclosing plugin — confirm.
	Type        string
	// session is the cached session returned (and lazily created) by
	// Connection and closed by Close.
	session     *gocql.Session
	// Init and Close hold this lock; Connection does not acquire it itself.
	sync.Mutex
}
51
52func (c *cassandraConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
53	_, err := c.Init(ctx, conf, verifyConnection)
54	return err
55}
56
57func (c *cassandraConnectionProducer) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (map[string]interface{}, error) {
58	c.Lock()
59	defer c.Unlock()
60
61	c.rawConfig = conf
62
63	err := mapstructure.WeakDecode(conf, c)
64	if err != nil {
65		return nil, err
66	}
67
68	if c.ConnectTimeoutRaw == nil {
69		c.ConnectTimeoutRaw = "0s"
70	}
71	c.connectTimeout, err = parseutil.ParseDurationSecond(c.ConnectTimeoutRaw)
72	if err != nil {
73		return nil, errwrap.Wrapf("invalid connect_timeout: {{err}}", err)
74	}
75
76	if c.SocketKeepAliveRaw == nil {
77		c.SocketKeepAliveRaw = "0s"
78	}
79	c.socketKeepAlive, err = parseutil.ParseDurationSecond(c.SocketKeepAliveRaw)
80	if err != nil {
81		return nil, errwrap.Wrapf("invalid socket_keep_alive: {{err}}", err)
82	}
83
84	switch {
85	case len(c.Hosts) == 0:
86		return nil, fmt.Errorf("hosts cannot be empty")
87	case len(c.Username) == 0:
88		return nil, fmt.Errorf("username cannot be empty")
89	case len(c.Password) == 0:
90		return nil, fmt.Errorf("password cannot be empty")
91	}
92
93	var certBundle *certutil.CertBundle
94	var parsedCertBundle *certutil.ParsedCertBundle
95	switch {
96	case len(c.PemJSON) != 0:
97		parsedCertBundle, err = certutil.ParsePKIJSON([]byte(c.PemJSON))
98		if err != nil {
99			return nil, errwrap.Wrapf("could not parse given JSON; it must be in the format of the output of the PKI backend certificate issuing command: {{err}}", err)
100		}
101		certBundle, err = parsedCertBundle.ToCertBundle()
102		if err != nil {
103			return nil, errwrap.Wrapf("Error marshaling PEM information: {{err}}", err)
104		}
105		c.certificate = certBundle.Certificate
106		c.privateKey = certBundle.PrivateKey
107		c.issuingCA = certBundle.IssuingCA
108		c.TLS = true
109
110	case len(c.PemBundle) != 0:
111		parsedCertBundle, err = certutil.ParsePEMBundle(c.PemBundle)
112		if err != nil {
113			return nil, errwrap.Wrapf("Error parsing the given PEM information: {{err}}", err)
114		}
115		certBundle, err = parsedCertBundle.ToCertBundle()
116		if err != nil {
117			return nil, errwrap.Wrapf("Error marshaling PEM information: {{err}}", err)
118		}
119		c.certificate = certBundle.Certificate
120		c.privateKey = certBundle.PrivateKey
121		c.issuingCA = certBundle.IssuingCA
122		c.TLS = true
123	}
124
125	// Set initialized to true at this point since all fields are set,
126	// and the connection can be established at a later time.
127	c.Initialized = true
128
129	if verifyConnection {
130		if _, err := c.Connection(ctx); err != nil {
131			return nil, errwrap.Wrapf("error verifying connection: {{err}}", err)
132		}
133	}
134
135	return conf, nil
136}
137
138func (c *cassandraConnectionProducer) Connection(_ context.Context) (interface{}, error) {
139	if !c.Initialized {
140		return nil, connutil.ErrNotInitialized
141	}
142
143	// If we already have a DB, return it
144	if c.session != nil && !c.session.Closed() {
145		return c.session, nil
146	}
147
148	session, err := c.createSession()
149	if err != nil {
150		return nil, err
151	}
152
153	//  Store the session in backend for reuse
154	c.session = session
155
156	return session, nil
157}
158
159func (c *cassandraConnectionProducer) Close() error {
160	// Grab the write lock
161	c.Lock()
162	defer c.Unlock()
163
164	if c.session != nil {
165		c.session.Close()
166	}
167
168	c.session = nil
169
170	return nil
171}
172
173func (c *cassandraConnectionProducer) createSession() (*gocql.Session, error) {
174	hosts := strings.Split(c.Hosts, ",")
175	clusterConfig := gocql.NewCluster(hosts...)
176	clusterConfig.Authenticator = gocql.PasswordAuthenticator{
177		Username: c.Username,
178		Password: c.Password,
179	}
180
181	if c.Port != 0 {
182		clusterConfig.Port = c.Port
183	}
184
185	clusterConfig.ProtoVersion = c.ProtocolVersion
186	if clusterConfig.ProtoVersion == 0 {
187		clusterConfig.ProtoVersion = 2
188	}
189
190	clusterConfig.Timeout = c.connectTimeout
191	clusterConfig.SocketKeepalive = c.socketKeepAlive
192	if c.TLS {
193		var tlsConfig *tls.Config
194		if len(c.certificate) > 0 || len(c.issuingCA) > 0 {
195			if len(c.certificate) > 0 && len(c.privateKey) == 0 {
196				return nil, fmt.Errorf("found certificate for TLS authentication but no private key")
197			}
198
199			certBundle := &certutil.CertBundle{}
200			if len(c.certificate) > 0 {
201				certBundle.Certificate = c.certificate
202				certBundle.PrivateKey = c.privateKey
203			}
204			if len(c.issuingCA) > 0 {
205				certBundle.IssuingCA = c.issuingCA
206			}
207
208			parsedCertBundle, err := certBundle.ToParsedCertBundle()
209			if err != nil {
210				return nil, errwrap.Wrapf("failed to parse certificate bundle: {{err}}", err)
211			}
212
213			tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient)
214			if err != nil || tlsConfig == nil {
215				return nil, errwrap.Wrapf(fmt.Sprintf("failed to get TLS configuration: tlsConfig:%#v err:{{err}}", tlsConfig), err)
216			}
217			tlsConfig.InsecureSkipVerify = c.InsecureTLS
218
219			if c.TLSMinVersion != "" {
220				var ok bool
221				tlsConfig.MinVersion, ok = tlsutil.TLSLookup[c.TLSMinVersion]
222				if !ok {
223					return nil, fmt.Errorf("invalid 'tls_min_version' in config")
224				}
225			} else {
226				// MinVersion was not being set earlier. Reset it to
227				// zero to gracefully handle upgrades.
228				tlsConfig.MinVersion = 0
229			}
230		}
231
232		clusterConfig.SslOpts = &gocql.SslOptions{
233			Config: tlsConfig,
234		}
235	}
236
237	session, err := clusterConfig.CreateSession()
238	if err != nil {
239		return nil, errwrap.Wrapf("error creating session: {{err}}", err)
240	}
241
242	// Set consistency
243	if c.Consistency != "" {
244		consistencyValue, err := gocql.ParseConsistencyWrapper(c.Consistency)
245		if err != nil {
246			return nil, err
247		}
248
249		session.SetConsistency(consistencyValue)
250	}
251
252	// Verify the info
253	err = session.Query(`LIST ALL`).Exec()
254	if err != nil && len(c.Username) != 0 && strings.Contains(err.Error(), "not authorized") {
255		rowNum := session.Query(dbutil.QueryHelper(`LIST CREATE ON ALL ROLES OF '{{username}}';`, map[string]string{
256			"username": c.Username,
257		})).Iter().NumRows()
258
259		if rowNum < 1 {
260			return nil, errwrap.Wrapf("error validating connection info: No role create permissions found, previous error: {{err}}", err)
261		}
262	} else if err != nil {
263		return nil, errwrap.Wrapf("error validating connection info: {{err}}", err)
264	}
265
266	return session, nil
267}
268
269func (c *cassandraConnectionProducer) secretValues() map[string]interface{} {
270	return map[string]interface{}{
271		c.Password:  "[password]",
272		c.PemBundle: "[pem_bundle]",
273		c.PemJSON:   "[pem_json]",
274	}
275}
276