1package cassandra 2 3import ( 4 "context" 5 "crypto/tls" 6 "fmt" 7 "strings" 8 "sync" 9 "time" 10 11 "github.com/mitchellh/mapstructure" 12 13 "github.com/gocql/gocql" 14 "github.com/hashicorp/errwrap" 15 "github.com/hashicorp/vault/sdk/database/dbplugin" 16 "github.com/hashicorp/vault/sdk/database/helper/connutil" 17 "github.com/hashicorp/vault/sdk/database/helper/dbutil" 18 "github.com/hashicorp/vault/sdk/helper/certutil" 19 "github.com/hashicorp/vault/sdk/helper/parseutil" 20 "github.com/hashicorp/vault/sdk/helper/tlsutil" 21) 22 23// cassandraConnectionProducer implements ConnectionProducer and provides an 24// interface for cassandra databases to make connections. 25type cassandraConnectionProducer struct { 26 Hosts string `json:"hosts" structs:"hosts" mapstructure:"hosts"` 27 Port int `json:"port" structs:"port" mapstructure:"port"` 28 Username string `json:"username" structs:"username" mapstructure:"username"` 29 Password string `json:"password" structs:"password" mapstructure:"password"` 30 TLS bool `json:"tls" structs:"tls" mapstructure:"tls"` 31 InsecureTLS bool `json:"insecure_tls" structs:"insecure_tls" mapstructure:"insecure_tls"` 32 ProtocolVersion int `json:"protocol_version" structs:"protocol_version" mapstructure:"protocol_version"` 33 ConnectTimeoutRaw interface{} `json:"connect_timeout" structs:"connect_timeout" mapstructure:"connect_timeout"` 34 SocketKeepAliveRaw interface{} `json:"socket_keep_alive" structs:"socket_keep_alive" mapstructure:"socket_keep_alive"` 35 TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"` 36 Consistency string `json:"consistency" structs:"consistency" mapstructure:"consistency"` 37 LocalDatacenter string `json:"local_datacenter" structs:"local_datacenter" mapstructure:"local_datacenter"` 38 PemBundle string `json:"pem_bundle" structs:"pem_bundle" mapstructure:"pem_bundle"` 39 PemJSON string `json:"pem_json" structs:"pem_json" mapstructure:"pem_json"` 40 41 connectTimeout time.Duration 42 socketKeepAlive time.Duration 43 certificate string 44 privateKey string 45 issuingCA string 46 rawConfig map[string]interface{} 47 48 Initialized bool 49 Type string 50 session *gocql.Session 51 sync.Mutex 52} 53 54func (c *cassandraConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error { 55 _, err := c.Init(ctx, conf, verifyConnection) 56 return err 57} 58 59func (c *cassandraConnectionProducer) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (map[string]interface{}, error) { 60 c.Lock() 61 defer c.Unlock() 62 63 c.rawConfig = conf 64 65 err := mapstructure.WeakDecode(conf, c) 66 if err != nil { 67 return nil, err 68 } 69 70 if c.ConnectTimeoutRaw == nil { 71 c.ConnectTimeoutRaw = "0s" 72 } 73 c.connectTimeout, err = parseutil.ParseDurationSecond(c.ConnectTimeoutRaw) 74 if err != nil { 75 return nil, errwrap.Wrapf("invalid connect_timeout: {{err}}", err) 76 } 77 78 if c.SocketKeepAliveRaw == nil { 79 c.SocketKeepAliveRaw = "0s" 80 } 81 c.socketKeepAlive, err = parseutil.ParseDurationSecond(c.SocketKeepAliveRaw) 82 if err != nil { 83 return nil, errwrap.Wrapf("invalid socket_keep_alive: {{err}}", err) 84 } 85 86 switch { 87 case len(c.Hosts) == 0: 88 return nil, fmt.Errorf("hosts cannot be empty") 89 case len(c.Username) == 0: 90 return nil, fmt.Errorf("username cannot be empty") 91 case len(c.Password) == 0: 92 return nil, fmt.Errorf("password cannot be empty") 93 } 94 95 var certBundle *certutil.CertBundle 96 var parsedCertBundle *certutil.ParsedCertBundle 97 switch { 98 case len(c.PemJSON) != 0: 99 parsedCertBundle, err = certutil.ParsePKIJSON([]byte(c.PemJSON)) 100 if err != nil { 101 return nil, errwrap.Wrapf("could not parse given JSON; it must be in the format of the output of the PKI backend certificate issuing command: {{err}}", err) 102 } 103 certBundle, err = parsedCertBundle.ToCertBundle() 104 if err != nil { 105 return nil, errwrap.Wrapf("Error marshaling PEM information: {{err}}", err) 106 } 107 c.certificate = certBundle.Certificate 108 c.privateKey = certBundle.PrivateKey 109 c.issuingCA = certBundle.IssuingCA 110 c.TLS = true 111 112 case len(c.PemBundle) != 0: 113 parsedCertBundle, err = certutil.ParsePEMBundle(c.PemBundle) 114 if err != nil { 115 return nil, errwrap.Wrapf("Error parsing the given PEM information: {{err}}", err) 116 } 117 certBundle, err = parsedCertBundle.ToCertBundle() 118 if err != nil { 119 return nil, errwrap.Wrapf("Error marshaling PEM information: {{err}}", err) 120 } 121 c.certificate = certBundle.Certificate 122 c.privateKey = certBundle.PrivateKey 123 c.issuingCA = certBundle.IssuingCA 124 c.TLS = true 125 } 126 127 // Set initialized to true at this point since all fields are set, 128 // and the connection can be established at a later time. 129 c.Initialized = true 130 131 if verifyConnection { 132 if _, err := c.Connection(ctx); err != nil { 133 return nil, errwrap.Wrapf("error verifying connection: {{err}}", err) 134 } 135 } 136 137 return conf, nil 138} 139 140func (c *cassandraConnectionProducer) Connection(ctx context.Context) (interface{}, error) { 141 if !c.Initialized { 142 return nil, connutil.ErrNotInitialized 143 } 144 145 // If we already have a DB, return it 146 if c.session != nil && !c.session.Closed() { 147 return c.session, nil 148 } 149 150 session, err := c.createSession(ctx) 151 if err != nil { 152 return nil, err 153 } 154 155 // Store the session in backend for reuse 156 c.session = session 157 158 return session, nil 159} 160 161func (c *cassandraConnectionProducer) Close() error { 162 // Grab the write lock 163 c.Lock() 164 defer c.Unlock() 165 166 if c.session != nil { 167 c.session.Close() 168 } 169 170 c.session = nil 171 172 return nil 173} 174 175func (c *cassandraConnectionProducer) createSession(ctx context.Context) (*gocql.Session, error) { 176 hosts := strings.Split(c.Hosts, ",") 177 clusterConfig := gocql.NewCluster(hosts...) 178 clusterConfig.Authenticator = gocql.PasswordAuthenticator{ 179 Username: c.Username, 180 Password: c.Password, 181 } 182 183 if c.Port != 0 { 184 clusterConfig.Port = c.Port 185 } 186 187 clusterConfig.ProtoVersion = c.ProtocolVersion 188 if clusterConfig.ProtoVersion == 0 { 189 clusterConfig.ProtoVersion = 2 190 } 191 192 clusterConfig.Timeout = c.connectTimeout 193 clusterConfig.SocketKeepalive = c.socketKeepAlive 194 if c.TLS { 195 var tlsConfig *tls.Config 196 if len(c.certificate) > 0 || len(c.issuingCA) > 0 { 197 if len(c.certificate) > 0 && len(c.privateKey) == 0 { 198 return nil, fmt.Errorf("found certificate for TLS authentication but no private key") 199 } 200 201 certBundle := &certutil.CertBundle{} 202 if len(c.certificate) > 0 { 203 certBundle.Certificate = c.certificate 204 certBundle.PrivateKey = c.privateKey 205 } 206 if len(c.issuingCA) > 0 { 207 certBundle.IssuingCA = c.issuingCA 208 } 209 210 parsedCertBundle, err := certBundle.ToParsedCertBundle() 211 if err != nil { 212 return nil, errwrap.Wrapf("failed to parse certificate bundle: {{err}}", err) 213 } 214 215 tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient) 216 if err != nil || tlsConfig == nil { 217 return nil, errwrap.Wrapf(fmt.Sprintf("failed to get TLS configuration: tlsConfig:%#v err:{{err}}", tlsConfig), err) 218 } 219 tlsConfig.InsecureSkipVerify = c.InsecureTLS 220 221 if c.TLSMinVersion != "" { 222 var ok bool 223 tlsConfig.MinVersion, ok = tlsutil.TLSLookup[c.TLSMinVersion] 224 if !ok { 225 return nil, fmt.Errorf("invalid 'tls_min_version' in config") 226 } 227 } else { 228 // MinVersion was not being set earlier. Reset it to 229 // zero to gracefully handle upgrades. 230 tlsConfig.MinVersion = 0 231 } 232 } 233 234 clusterConfig.SslOpts = &gocql.SslOptions{ 235 Config: tlsConfig, 236 } 237 } 238 239 if c.LocalDatacenter != "" { 240 clusterConfig.PoolConfig.HostSelectionPolicy = gocql.DCAwareRoundRobinPolicy(c.LocalDatacenter) 241 } 242 243 session, err := clusterConfig.CreateSession() 244 if err != nil { 245 return nil, errwrap.Wrapf("error creating session: {{err}}", err) 246 } 247 248 // Set consistency 249 if c.Consistency != "" { 250 consistencyValue, err := gocql.ParseConsistencyWrapper(c.Consistency) 251 if err != nil { 252 return nil, err 253 } 254 255 session.SetConsistency(consistencyValue) 256 } 257 258 // Verify the info 259 err = session.Query(`LIST ALL`).WithContext(ctx).Exec() 260 if err != nil && len(c.Username) != 0 && strings.Contains(err.Error(), "not authorized") { 261 rowNum := session.Query(dbutil.QueryHelper(`LIST CREATE ON ALL ROLES OF '{{username}}';`, map[string]string{ 262 "username": c.Username, 263 })).Iter().NumRows() 264 265 if rowNum < 1 { 266 return nil, errwrap.Wrapf("error validating connection info: No role create permissions found, previous error: {{err}}", err) 267 } 268 } else if err != nil { 269 return nil, errwrap.Wrapf("error validating connection info: {{err}}", err) 270 } 271 272 return session, nil 273} 274 275func (c *cassandraConnectionProducer) secretValues() map[string]interface{} { 276 return map[string]interface{}{ 277 c.Password: "[password]", 278 c.PemBundle: "[pem_bundle]", 279 c.PemJSON: "[pem_json]", 280 } 281} 282 283// SetCredentials uses provided information to set/create a user in the 284// database. Unlike CreateUser, this method requires a username be provided and 285// uses the name given, instead of generating a name. This is used for creating 286// and setting the password of static accounts, as well as rolling back 287// passwords in the database in the event an updated database fails to save in 288// Vault's storage. 289func (c *cassandraConnectionProducer) SetCredentials(ctx context.Context, statements dbplugin.Statements, staticUser dbplugin.StaticUserConfig) (username, password string, err error) { 290 return "", "", dbutil.Unimplemented() 291} 292