1package cassandra
2
3import (
4	"context"
5	"crypto/tls"
6	"fmt"
7	"strings"
8	"sync"
9	"time"
10
11	"github.com/mitchellh/mapstructure"
12
13	"github.com/gocql/gocql"
14	"github.com/hashicorp/errwrap"
15	"github.com/hashicorp/vault/sdk/database/dbplugin"
16	"github.com/hashicorp/vault/sdk/database/helper/connutil"
17	"github.com/hashicorp/vault/sdk/database/helper/dbutil"
18	"github.com/hashicorp/vault/sdk/helper/certutil"
19	"github.com/hashicorp/vault/sdk/helper/parseutil"
20	"github.com/hashicorp/vault/sdk/helper/tlsutil"
21)
22
// cassandraConnectionProducer implements ConnectionProducer and provides an
// interface for cassandra databases to make connections.
//
// The exported fields are filled from the plugin configuration map by Init
// (via mapstructure.WeakDecode); the unexported fields are derived from them
// during Init and consumed when a session is created.
type cassandraConnectionProducer struct {
	// Hosts is a comma-separated list of contact points; it must not be empty.
	Hosts              string      `json:"hosts" structs:"hosts" mapstructure:"hosts"`
	// Port replaces the driver's default port when non-zero.
	Port               int         `json:"port" structs:"port" mapstructure:"port"`
	// Username and Password are used for password authentication; both are
	// required by Init.
	Username           string      `json:"username" structs:"username" mapstructure:"username"`
	Password           string      `json:"password" structs:"password" mapstructure:"password"`
	// TLS enables TLS for the session. Init forces it on whenever pem_json or
	// pem_bundle is supplied.
	TLS                bool        `json:"tls" structs:"tls" mapstructure:"tls"`
	// InsecureTLS requests that server certificate verification be skipped.
	InsecureTLS        bool        `json:"insecure_tls" structs:"insecure_tls" mapstructure:"insecure_tls"`
	// ProtocolVersion is the CQL native protocol version; 0 is treated as 2.
	ProtocolVersion    int         `json:"protocol_version" structs:"protocol_version" mapstructure:"protocol_version"`
	// ConnectTimeoutRaw and SocketKeepAliveRaw accept anything
	// parseutil.ParseDurationSecond accepts; a missing value is treated as "0s".
	ConnectTimeoutRaw  interface{} `json:"connect_timeout" structs:"connect_timeout" mapstructure:"connect_timeout"`
	SocketKeepAliveRaw interface{} `json:"socket_keep_alive" structs:"socket_keep_alive" mapstructure:"socket_keep_alive"`
	// TLSMinVersion, when set, must be a key of tlsutil.TLSLookup.
	TLSMinVersion      string      `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"`
	// Consistency, when set, is parsed with gocql.ParseConsistencyWrapper and
	// applied to the session via SetConsistency.
	Consistency        string      `json:"consistency" structs:"consistency" mapstructure:"consistency"`
	// LocalDatacenter, when set, selects gocql's DC-aware round-robin host
	// selection policy for that datacenter.
	LocalDatacenter    string      `json:"local_datacenter" structs:"local_datacenter" mapstructure:"local_datacenter"`
	// PemBundle and PemJSON supply certificate material; PemJSON takes
	// precedence when both are set.
	PemBundle          string      `json:"pem_bundle" structs:"pem_bundle" mapstructure:"pem_bundle"`
	PemJSON            string      `json:"pem_json" structs:"pem_json" mapstructure:"pem_json"`
	// SkipVerification skips the permission check run right after a session
	// is opened.
	SkipVerification   bool        `json:"skip_verification" structs:"skip_verification" mapstructure:"skip_verification"`

	// Derived by Init: the parsed durations and the PEM-encoded certificate,
	// private key and issuing CA extracted from PemJSON/PemBundle.
	connectTimeout  time.Duration
	socketKeepAlive time.Duration
	certificate     string
	privateKey      string
	issuingCA       string
	// rawConfig is the configuration map most recently passed to Init.
	rawConfig       map[string]interface{}

	// Initialized is set by Init once the configuration has been validated;
	// Connection refuses to run while it is false.
	Initialized bool
	// Type is not referenced in this file. NOTE(review): presumably set by the
	// owning plugin — confirm.
	Type        string
	// session is the cached session that Connection reuses until it is closed.
	session     *gocql.Session
	// The embedded mutex is taken by Init and Close.
	sync.Mutex
}
54
55func (c *cassandraConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
56	_, err := c.Init(ctx, conf, verifyConnection)
57	return err
58}
59
60func (c *cassandraConnectionProducer) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (map[string]interface{}, error) {
61	c.Lock()
62	defer c.Unlock()
63
64	c.rawConfig = conf
65
66	err := mapstructure.WeakDecode(conf, c)
67	if err != nil {
68		return nil, err
69	}
70
71	if c.ConnectTimeoutRaw == nil {
72		c.ConnectTimeoutRaw = "0s"
73	}
74	c.connectTimeout, err = parseutil.ParseDurationSecond(c.ConnectTimeoutRaw)
75	if err != nil {
76		return nil, errwrap.Wrapf("invalid connect_timeout: {{err}}", err)
77	}
78
79	if c.SocketKeepAliveRaw == nil {
80		c.SocketKeepAliveRaw = "0s"
81	}
82	c.socketKeepAlive, err = parseutil.ParseDurationSecond(c.SocketKeepAliveRaw)
83	if err != nil {
84		return nil, errwrap.Wrapf("invalid socket_keep_alive: {{err}}", err)
85	}
86
87	switch {
88	case len(c.Hosts) == 0:
89		return nil, fmt.Errorf("hosts cannot be empty")
90	case len(c.Username) == 0:
91		return nil, fmt.Errorf("username cannot be empty")
92	case len(c.Password) == 0:
93		return nil, fmt.Errorf("password cannot be empty")
94	}
95
96	var certBundle *certutil.CertBundle
97	var parsedCertBundle *certutil.ParsedCertBundle
98	switch {
99	case len(c.PemJSON) != 0:
100		parsedCertBundle, err = certutil.ParsePKIJSON([]byte(c.PemJSON))
101		if err != nil {
102			return nil, errwrap.Wrapf("could not parse given JSON; it must be in the format of the output of the PKI backend certificate issuing command: {{err}}", err)
103		}
104		certBundle, err = parsedCertBundle.ToCertBundle()
105		if err != nil {
106			return nil, errwrap.Wrapf("Error marshaling PEM information: {{err}}", err)
107		}
108		c.certificate = certBundle.Certificate
109		c.privateKey = certBundle.PrivateKey
110		c.issuingCA = certBundle.IssuingCA
111		c.TLS = true
112
113	case len(c.PemBundle) != 0:
114		parsedCertBundle, err = certutil.ParsePEMBundle(c.PemBundle)
115		if err != nil {
116			return nil, errwrap.Wrapf("Error parsing the given PEM information: {{err}}", err)
117		}
118		certBundle, err = parsedCertBundle.ToCertBundle()
119		if err != nil {
120			return nil, errwrap.Wrapf("Error marshaling PEM information: {{err}}", err)
121		}
122		c.certificate = certBundle.Certificate
123		c.privateKey = certBundle.PrivateKey
124		c.issuingCA = certBundle.IssuingCA
125		c.TLS = true
126	}
127
128	// Set initialized to true at this point since all fields are set,
129	// and the connection can be established at a later time.
130	c.Initialized = true
131
132	if verifyConnection {
133		if _, err := c.Connection(ctx); err != nil {
134			return nil, errwrap.Wrapf("error verifying connection: {{err}}", err)
135		}
136	}
137
138	return conf, nil
139}
140
141func (c *cassandraConnectionProducer) Connection(ctx context.Context) (interface{}, error) {
142	if !c.Initialized {
143		return nil, connutil.ErrNotInitialized
144	}
145
146	// If we already have a DB, return it
147	if c.session != nil && !c.session.Closed() {
148		return c.session, nil
149	}
150
151	session, err := c.createSession(ctx)
152	if err != nil {
153		return nil, err
154	}
155
156	//  Store the session in backend for reuse
157	c.session = session
158
159	return session, nil
160}
161
162func (c *cassandraConnectionProducer) Close() error {
163	// Grab the write lock
164	c.Lock()
165	defer c.Unlock()
166
167	if c.session != nil {
168		c.session.Close()
169	}
170
171	c.session = nil
172
173	return nil
174}
175
176func (c *cassandraConnectionProducer) createSession(ctx context.Context) (*gocql.Session, error) {
177	hosts := strings.Split(c.Hosts, ",")
178	clusterConfig := gocql.NewCluster(hosts...)
179	clusterConfig.Authenticator = gocql.PasswordAuthenticator{
180		Username: c.Username,
181		Password: c.Password,
182	}
183
184	if c.Port != 0 {
185		clusterConfig.Port = c.Port
186	}
187
188	clusterConfig.ProtoVersion = c.ProtocolVersion
189	if clusterConfig.ProtoVersion == 0 {
190		clusterConfig.ProtoVersion = 2
191	}
192
193	clusterConfig.Timeout = c.connectTimeout
194	clusterConfig.SocketKeepalive = c.socketKeepAlive
195	if c.TLS {
196		var tlsConfig *tls.Config
197		if len(c.certificate) > 0 || len(c.issuingCA) > 0 {
198			if len(c.certificate) > 0 && len(c.privateKey) == 0 {
199				return nil, fmt.Errorf("found certificate for TLS authentication but no private key")
200			}
201
202			certBundle := &certutil.CertBundle{}
203			if len(c.certificate) > 0 {
204				certBundle.Certificate = c.certificate
205				certBundle.PrivateKey = c.privateKey
206			}
207			if len(c.issuingCA) > 0 {
208				certBundle.IssuingCA = c.issuingCA
209			}
210
211			parsedCertBundle, err := certBundle.ToParsedCertBundle()
212			if err != nil {
213				return nil, errwrap.Wrapf("failed to parse certificate bundle: {{err}}", err)
214			}
215
216			tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient)
217			if err != nil || tlsConfig == nil {
218				return nil, errwrap.Wrapf(fmt.Sprintf("failed to get TLS configuration: tlsConfig:%#v err:{{err}}", tlsConfig), err)
219			}
220			tlsConfig.InsecureSkipVerify = c.InsecureTLS
221
222			if c.TLSMinVersion != "" {
223				var ok bool
224				tlsConfig.MinVersion, ok = tlsutil.TLSLookup[c.TLSMinVersion]
225				if !ok {
226					return nil, fmt.Errorf("invalid 'tls_min_version' in config")
227				}
228			} else {
229				// MinVersion was not being set earlier. Reset it to
230				// zero to gracefully handle upgrades.
231				tlsConfig.MinVersion = 0
232			}
233		}
234
235		clusterConfig.SslOpts = &gocql.SslOptions{
236			Config: tlsConfig,
237		}
238	}
239
240	if c.LocalDatacenter != "" {
241		clusterConfig.PoolConfig.HostSelectionPolicy = gocql.DCAwareRoundRobinPolicy(c.LocalDatacenter)
242	}
243
244	session, err := clusterConfig.CreateSession()
245	if err != nil {
246		return nil, errwrap.Wrapf("error creating session: {{err}}", err)
247	}
248
249	// Set consistency
250	if c.Consistency != "" {
251		consistencyValue, err := gocql.ParseConsistencyWrapper(c.Consistency)
252		if err != nil {
253			session.Close()
254			return nil, err
255		}
256
257		session.SetConsistency(consistencyValue)
258	}
259
260	// Verify the info
261	if !c.SkipVerification {
262		err = session.Query(`LIST ALL`).WithContext(ctx).Exec()
263		if err != nil && len(c.Username) != 0 && strings.Contains(err.Error(), "not authorized") {
264			rowNum := session.Query(dbutil.QueryHelper(`LIST CREATE ON ALL ROLES OF '{{username}}';`, map[string]string{
265				"username": c.Username,
266			})).Iter().NumRows()
267
268			if rowNum < 1 {
269				session.Close()
270				return nil, errwrap.Wrapf("error validating connection info: No role create permissions found, previous error: {{err}}", err)
271			}
272		} else if err != nil {
273			session.Close()
274			return nil, errwrap.Wrapf("error validating connection info: {{err}}", err)
275		}
276	}
277
278	return session, nil
279}
280
281func (c *cassandraConnectionProducer) secretValues() map[string]interface{} {
282	return map[string]interface{}{
283		c.Password:  "[password]",
284		c.PemBundle: "[pem_bundle]",
285		c.PemJSON:   "[pem_json]",
286	}
287}
288
289// SetCredentials uses provided information to set/create a user in the
290// database. Unlike CreateUser, this method requires a username be provided and
291// uses the name given, instead of generating a name. This is used for creating
292// and setting the password of static accounts, as well as rolling back
293// passwords in the database in the event an updated database fails to save in
294// Vault's storage.
295func (c *cassandraConnectionProducer) SetCredentials(ctx context.Context, statements dbplugin.Statements, staticUser dbplugin.StaticUserConfig) (username, password string, err error) {
296	return "", "", dbutil.Unimplemented()
297}
298