1package cassandra
2
3import (
4	"context"
5	"crypto/tls"
6	"fmt"
7	"strings"
8	"sync"
9	"time"
10
11	"github.com/mitchellh/mapstructure"
12
13	"github.com/gocql/gocql"
14	"github.com/hashicorp/errwrap"
15	"github.com/hashicorp/vault/sdk/database/dbplugin"
16	"github.com/hashicorp/vault/sdk/database/helper/connutil"
17	"github.com/hashicorp/vault/sdk/database/helper/dbutil"
18	"github.com/hashicorp/vault/sdk/helper/certutil"
19	"github.com/hashicorp/vault/sdk/helper/parseutil"
20	"github.com/hashicorp/vault/sdk/helper/tlsutil"
21)
22
// cassandraConnectionProducer implements ConnectionProducer and provides an
// interface for cassandra databases to make connections.
//
// The exported, tagged fields are the user-facing configuration; Init fills
// them from the raw config map with mapstructure.WeakDecode. The unexported
// fields hold values Init derives from that configuration, plus the cached
// session.
type cassandraConnectionProducer struct {
	// Hosts is a comma-separated host list (split on "," when connecting).
	// Init rejects an empty value.
	Hosts              string      `json:"hosts" structs:"hosts" mapstructure:"hosts"`
	// Port is applied to the cluster config only when non-zero.
	Port               int         `json:"port" structs:"port" mapstructure:"port"`
	// Username and Password feed gocql's PasswordAuthenticator; Init rejects
	// empty values for either.
	Username           string      `json:"username" structs:"username" mapstructure:"username"`
	Password           string      `json:"password" structs:"password" mapstructure:"password"`
	// TLS enables TLS. Init forces it to true whenever pem_json or pem_bundle
	// is supplied.
	TLS                bool        `json:"tls" structs:"tls" mapstructure:"tls"`
	// InsecureTLS requests that server certificate verification be skipped;
	// consumed when the session's TLS settings are built.
	InsecureTLS        bool        `json:"insecure_tls" structs:"insecure_tls" mapstructure:"insecure_tls"`
	// ProtocolVersion is the CQL protocol version; 0 means use version 2.
	ProtocolVersion    int         `json:"protocol_version" structs:"protocol_version" mapstructure:"protocol_version"`
	// ConnectTimeoutRaw and SocketKeepAliveRaw hold the raw configured values.
	// Init defaults nil to "0s" and parses them with
	// parseutil.ParseDurationSecond into connectTimeout / socketKeepAlive.
	ConnectTimeoutRaw  interface{} `json:"connect_timeout" structs:"connect_timeout" mapstructure:"connect_timeout"`
	SocketKeepAliveRaw interface{} `json:"socket_keep_alive" structs:"socket_keep_alive" mapstructure:"socket_keep_alive"`
	// TLSMinVersion is looked up in tlsutil.TLSLookup; empty leaves MinVersion
	// at zero.
	TLSMinVersion      string      `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"`
	// Consistency is parsed with gocql.ParseConsistencyWrapper; empty leaves
	// the session's consistency unchanged.
	Consistency        string      `json:"consistency" structs:"consistency" mapstructure:"consistency"`
	// LocalDatacenter, when set, selects gocql's DC-aware round-robin host
	// selection policy for that datacenter.
	LocalDatacenter    string      `json:"local_datacenter" structs:"local_datacenter" mapstructure:"local_datacenter"`
	// PemBundle and PemJSON carry client certificate / CA material. If both are
	// set, Init uses PemJSON.
	PemBundle          string      `json:"pem_bundle" structs:"pem_bundle" mapstructure:"pem_bundle"`
	PemJSON            string      `json:"pem_json" structs:"pem_json" mapstructure:"pem_json"`

	// Values derived by Init from the configuration above.
	connectTimeout  time.Duration
	socketKeepAlive time.Duration
	// PEM-encoded strings extracted from PemJSON / PemBundle.
	certificate     string
	privateKey      string
	issuingCA       string
	// rawConfig is the map most recently passed to Init.
	rawConfig       map[string]interface{}

	// Initialized is set by Init once configuration validates; Connection
	// returns connutil.ErrNotInitialized until then.
	Initialized bool
	// NOTE(review): Type is not referenced in this file; presumably set or
	// read elsewhere in the package — confirm.
	Type        string
	// session is the cached session that Connection reuses while it is open.
	session     *gocql.Session
	// Init and Close acquire this lock; Connection does not lock on its own.
	sync.Mutex
}
53
// Initialize is the error-only counterpart of Init. It runs Init with the same
// arguments and discards the configuration map Init returns.
func (c *cassandraConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
	if _, initErr := c.Init(ctx, conf, verifyConnection); initErr != nil {
		return initErr
	}
	return nil
}
58
// Init decodes conf onto the producer, validates the required settings, loads
// any PEM material, and marks the producer initialized. When verifyConnection
// is true, it also opens (and caches) a session through Connection. On success
// it returns conf unchanged.
//
// The producer is flagged as initialized before verification runs. A failed
// verification therefore still leaves a producer whose connection can be
// established later.
func (c *cassandraConnectionProducer) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (map[string]interface{}, error) {
	c.Lock()
	defer c.Unlock()

	c.rawConfig = conf

	if decodeErr := mapstructure.WeakDecode(conf, c); decodeErr != nil {
		return nil, decodeErr
	}

	// Durations that were not configured default to zero ("0s").
	var err error
	if c.ConnectTimeoutRaw == nil {
		c.ConnectTimeoutRaw = "0s"
	}
	if c.connectTimeout, err = parseutil.ParseDurationSecond(c.ConnectTimeoutRaw); err != nil {
		return nil, errwrap.Wrapf("invalid connect_timeout: {{err}}", err)
	}

	if c.SocketKeepAliveRaw == nil {
		c.SocketKeepAliveRaw = "0s"
	}
	if c.socketKeepAlive, err = parseutil.ParseDurationSecond(c.SocketKeepAliveRaw); err != nil {
		return nil, errwrap.Wrapf("invalid socket_keep_alive: {{err}}", err)
	}

	if c.Hosts == "" {
		return nil, fmt.Errorf("hosts cannot be empty")
	}
	if c.Username == "" {
		return nil, fmt.Errorf("username cannot be empty")
	}
	if c.Password == "" {
		return nil, fmt.Errorf("password cannot be empty")
	}

	// adopt copies the PEM-encoded pieces of a parsed bundle onto the producer.
	// Supplying certificate material always switches TLS on.
	adopt := func(parsed *certutil.ParsedCertBundle) error {
		bundle, marshalErr := parsed.ToCertBundle()
		if marshalErr != nil {
			return errwrap.Wrapf("Error marshaling PEM information: {{err}}", marshalErr)
		}
		c.certificate = bundle.Certificate
		c.privateKey = bundle.PrivateKey
		c.issuingCA = bundle.IssuingCA
		c.TLS = true
		return nil
	}

	// pem_json takes precedence over pem_bundle when both are present.
	switch {
	case c.PemJSON != "":
		parsed, parseErr := certutil.ParsePKIJSON([]byte(c.PemJSON))
		if parseErr != nil {
			return nil, errwrap.Wrapf("could not parse given JSON; it must be in the format of the output of the PKI backend certificate issuing command: {{err}}", parseErr)
		}
		if adoptErr := adopt(parsed); adoptErr != nil {
			return nil, adoptErr
		}
	case c.PemBundle != "":
		parsed, parseErr := certutil.ParsePEMBundle(c.PemBundle)
		if parseErr != nil {
			return nil, errwrap.Wrapf("Error parsing the given PEM information: {{err}}", parseErr)
		}
		if adoptErr := adopt(parsed); adoptErr != nil {
			return nil, adoptErr
		}
	}

	// Every field is populated now, so a connection may be established at any
	// later point even if verification below is skipped or fails.
	c.Initialized = true

	if !verifyConnection {
		return conf, nil
	}
	if _, connErr := c.Connection(ctx); connErr != nil {
		return nil, errwrap.Wrapf("error verifying connection: {{err}}", connErr)
	}
	return conf, nil
}
139
// Connection returns the cached *gocql.Session as an interface{}. It creates
// and caches a new session when none exists or the cached one has been closed,
// and returns connutil.ErrNotInitialized if Init has not completed.
//
// Connection does not acquire c's lock. Init calls it while already holding
// that lock, and sync.Mutex is not reentrant. NOTE(review): other callers
// presumably must hold the lock as well — confirm.
func (c *cassandraConnectionProducer) Connection(ctx context.Context) (interface{}, error) {
	switch {
	case !c.Initialized:
		return nil, connutil.ErrNotInitialized
	case c.session != nil && !c.session.Closed():
		// Reuse the live session.
		return c.session, nil
	}

	fresh, err := c.createSession(ctx)
	if err != nil {
		return nil, err
	}
	c.session = fresh

	return fresh, nil
}
160
161func (c *cassandraConnectionProducer) Close() error {
162	// Grab the write lock
163	c.Lock()
164	defer c.Unlock()
165
166	if c.session != nil {
167		c.session.Close()
168	}
169
170	c.session = nil
171
172	return nil
173}
174
// createSession builds a gocql cluster configuration from the producer's
// settings, opens a session, applies the configured consistency level, and
// verifies that the configured credentials can manage roles.
//
// If anything fails after the session has been opened, the session is closed
// before returning, so its connections are not leaked.
func (c *cassandraConnectionProducer) createSession(ctx context.Context) (*gocql.Session, error) {
	hosts := strings.Split(c.Hosts, ",")
	clusterConfig := gocql.NewCluster(hosts...)
	clusterConfig.Authenticator = gocql.PasswordAuthenticator{
		Username: c.Username,
		Password: c.Password,
	}

	if c.Port != 0 {
		clusterConfig.Port = c.Port
	}

	clusterConfig.ProtoVersion = c.ProtocolVersion
	if clusterConfig.ProtoVersion == 0 {
		clusterConfig.ProtoVersion = 2
	}

	clusterConfig.Timeout = c.connectTimeout
	clusterConfig.SocketKeepalive = c.socketKeepAlive

	if c.TLS {
		sslOpts, err := c.sslOptions()
		if err != nil {
			return nil, err
		}
		clusterConfig.SslOpts = sslOpts
	}

	if c.LocalDatacenter != "" {
		clusterConfig.PoolConfig.HostSelectionPolicy = gocql.DCAwareRoundRobinPolicy(c.LocalDatacenter)
	}

	// Parse the consistency level before connecting. An invalid value then
	// fails fast instead of opening (and previously leaking) a live session.
	var consistency gocql.Consistency
	if c.Consistency != "" {
		var err error
		consistency, err = gocql.ParseConsistencyWrapper(c.Consistency)
		if err != nil {
			return nil, err
		}
	}

	session, err := clusterConfig.CreateSession()
	if err != nil {
		return nil, errwrap.Wrapf("error creating session: {{err}}", err)
	}

	if c.Consistency != "" {
		session.SetConsistency(consistency)
	}

	if err := c.verifySession(ctx, session); err != nil {
		session.Close()
		return nil, err
	}

	return session, nil
}

// sslOptions builds the gocql TLS options for the producer.
//
// A tls.Config is always produced, even when no client certificate or CA was
// configured, and EnableHostVerification is set from insecure_tls. Previously a
// nil Config was passed when no PEM material existed. gocql then substitutes a
// config whose InsecureSkipVerify is !EnableHostVerification (true by default),
// which silently disabled server verification and ignored insecure_tls and
// tls_min_version.
func (c *cassandraConnectionProducer) sslOptions() (*gocql.SslOptions, error) {
	tlsConfig := &tls.Config{}
	if len(c.certificate) > 0 || len(c.issuingCA) > 0 {
		if len(c.certificate) > 0 && len(c.privateKey) == 0 {
			return nil, fmt.Errorf("found certificate for TLS authentication but no private key")
		}

		certBundle := &certutil.CertBundle{}
		if len(c.certificate) > 0 {
			certBundle.Certificate = c.certificate
			certBundle.PrivateKey = c.privateKey
		}
		if len(c.issuingCA) > 0 {
			certBundle.IssuingCA = c.issuingCA
		}

		parsedCertBundle, err := certBundle.ToParsedCertBundle()
		if err != nil {
			return nil, errwrap.Wrapf("failed to parse certificate bundle: {{err}}", err)
		}

		tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient)
		if err != nil || tlsConfig == nil {
			return nil, errwrap.Wrapf(fmt.Sprintf("failed to get TLS configuration: tlsConfig:%#v err:{{err}}", tlsConfig), err)
		}
	}

	tlsConfig.InsecureSkipVerify = c.InsecureTLS

	if c.TLSMinVersion != "" {
		var ok bool
		tlsConfig.MinVersion, ok = tlsutil.TLSLookup[c.TLSMinVersion]
		if !ok {
			return nil, fmt.Errorf("invalid 'tls_min_version' in config")
		}
	} else {
		// MinVersion was not being set earlier. Reset it to
		// zero to gracefully handle upgrades.
		tlsConfig.MinVersion = 0
	}

	return &gocql.SslOptions{
		Config: tlsConfig,
		// Ask gocql to verify the server unless the operator explicitly opted
		// out with insecure_tls.
		EnableHostVerification: !c.InsecureTLS,
	}, nil
}

// verifySession checks that the session's credentials are usable. It first
// runs `LIST ALL`. If that is refused with "not authorized", it falls back to
// checking that the configured user holds CREATE on all roles.
func (c *cassandraConnectionProducer) verifySession(ctx context.Context, session *gocql.Session) error {
	err := session.Query(`LIST ALL`).WithContext(ctx).Exec()
	if err == nil {
		return nil
	}
	if len(c.Username) == 0 || !strings.Contains(err.Error(), "not authorized") {
		return errwrap.Wrapf("error validating connection info: {{err}}", err)
	}

	iter := session.Query(dbutil.QueryHelper(`LIST CREATE ON ALL ROLES OF '{{username}}';`, map[string]string{
		"username": c.Username,
	})).WithContext(ctx).Iter()
	rowNum := iter.NumRows()
	// Close releases the iterator's resources. Its error is deliberately not
	// checked separately: a failed query yields zero rows, which is reported
	// below together with the original authorization error.
	_ = iter.Close()

	if rowNum < 1 {
		return errwrap.Wrapf("error validating connection info: No role create permissions found, previous error: {{err}}", err)
	}
	return nil
}
274
275func (c *cassandraConnectionProducer) secretValues() map[string]interface{} {
276	return map[string]interface{}{
277		c.Password:  "[password]",
278		c.PemBundle: "[pem_bundle]",
279		c.PemJSON:   "[pem_json]",
280	}
281}
282
// SetCredentials uses provided information to set/create a user in the
// database. Unlike CreateUser, this method requires a username be provided and
// uses the name given, instead of generating a name. This is used for creating
// and setting the password of static accounts, as well as rolling back
// passwords in the database in the event an updated database fails to save in
// Vault's storage.
//
// This producer does not implement it. Every call returns empty credentials
// together with the error produced by dbutil.Unimplemented, and all arguments
// are ignored.
func (c *cassandraConnectionProducer) SetCredentials(ctx context.Context, statements dbplugin.Statements, staticUser dbplugin.StaticUserConfig) (username, password string, err error) {
	err = dbutil.Unimplemented()
	return "", "", err
}
292