1package cassandra 2 3import ( 4 "context" 5 "crypto/tls" 6 "fmt" 7 "strings" 8 "sync" 9 "time" 10 11 "github.com/mitchellh/mapstructure" 12 13 "github.com/gocql/gocql" 14 "github.com/hashicorp/errwrap" 15 "github.com/hashicorp/vault/sdk/database/dbplugin" 16 "github.com/hashicorp/vault/sdk/database/helper/connutil" 17 "github.com/hashicorp/vault/sdk/database/helper/dbutil" 18 "github.com/hashicorp/vault/sdk/helper/certutil" 19 "github.com/hashicorp/vault/sdk/helper/parseutil" 20 "github.com/hashicorp/vault/sdk/helper/tlsutil" 21) 22 23// cassandraConnectionProducer implements ConnectionProducer and provides an 24// interface for cassandra databases to make connections. 25type cassandraConnectionProducer struct { 26 Hosts string `json:"hosts" structs:"hosts" mapstructure:"hosts"` 27 Port int `json:"port" structs:"port" mapstructure:"port"` 28 Username string `json:"username" structs:"username" mapstructure:"username"` 29 Password string `json:"password" structs:"password" mapstructure:"password"` 30 TLS bool `json:"tls" structs:"tls" mapstructure:"tls"` 31 InsecureTLS bool `json:"insecure_tls" structs:"insecure_tls" mapstructure:"insecure_tls"` 32 ProtocolVersion int `json:"protocol_version" structs:"protocol_version" mapstructure:"protocol_version"` 33 ConnectTimeoutRaw interface{} `json:"connect_timeout" structs:"connect_timeout" mapstructure:"connect_timeout"` 34 SocketKeepAliveRaw interface{} `json:"socket_keep_alive" structs:"socket_keep_alive" mapstructure:"socket_keep_alive"` 35 TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"` 36 Consistency string `json:"consistency" structs:"consistency" mapstructure:"consistency"` 37 LocalDatacenter string `json:"local_datacenter" structs:"local_datacenter" mapstructure:"local_datacenter"` 38 PemBundle string `json:"pem_bundle" structs:"pem_bundle" mapstructure:"pem_bundle"` 39 PemJSON string `json:"pem_json" structs:"pem_json" mapstructure:"pem_json"` 40 SkipVerification bool `json:"skip_verification" structs:"skip_verification" mapstructure:"skip_verification"` 41 42 connectTimeout time.Duration 43 socketKeepAlive time.Duration 44 certificate string 45 privateKey string 46 issuingCA string 47 rawConfig map[string]interface{} 48 49 Initialized bool 50 Type string 51 session *gocql.Session 52 sync.Mutex 53} 54 55func (c *cassandraConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error { 56 _, err := c.Init(ctx, conf, verifyConnection) 57 return err 58} 59 60func (c *cassandraConnectionProducer) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (map[string]interface{}, error) { 61 c.Lock() 62 defer c.Unlock() 63 64 c.rawConfig = conf 65 66 err := mapstructure.WeakDecode(conf, c) 67 if err != nil { 68 return nil, err 69 } 70 71 if c.ConnectTimeoutRaw == nil { 72 c.ConnectTimeoutRaw = "0s" 73 } 74 c.connectTimeout, err = parseutil.ParseDurationSecond(c.ConnectTimeoutRaw) 75 if err != nil { 76 return nil, errwrap.Wrapf("invalid connect_timeout: {{err}}", err) 77 } 78 79 if c.SocketKeepAliveRaw == nil { 80 c.SocketKeepAliveRaw = "0s" 81 } 82 c.socketKeepAlive, err = parseutil.ParseDurationSecond(c.SocketKeepAliveRaw) 83 if err != nil { 84 return nil, errwrap.Wrapf("invalid socket_keep_alive: {{err}}", err) 85 } 86 87 switch { 88 case len(c.Hosts) == 0: 89 return nil, fmt.Errorf("hosts cannot be empty") 90 case len(c.Username) == 0: 91 return nil, fmt.Errorf("username cannot be empty") 92 case len(c.Password) == 0: 93 return nil, fmt.Errorf("password cannot be empty") 94 } 95 96 var certBundle *certutil.CertBundle 97 var parsedCertBundle *certutil.ParsedCertBundle 98 switch { 99 case len(c.PemJSON) != 0: 100 parsedCertBundle, err = certutil.ParsePKIJSON([]byte(c.PemJSON)) 101 if err != nil { 102 return nil, errwrap.Wrapf("could not parse given JSON; it must be in the format of the output of the PKI backend certificate issuing command: {{err}}", err) 103 } 104 certBundle, err = parsedCertBundle.ToCertBundle() 105 if err != nil { 106 return nil, errwrap.Wrapf("Error marshaling PEM information: {{err}}", err) 107 } 108 c.certificate = certBundle.Certificate 109 c.privateKey = certBundle.PrivateKey 110 c.issuingCA = certBundle.IssuingCA 111 c.TLS = true 112 113 case len(c.PemBundle) != 0: 114 parsedCertBundle, err = certutil.ParsePEMBundle(c.PemBundle) 115 if err != nil { 116 return nil, errwrap.Wrapf("Error parsing the given PEM information: {{err}}", err) 117 } 118 certBundle, err = parsedCertBundle.ToCertBundle() 119 if err != nil { 120 return nil, errwrap.Wrapf("Error marshaling PEM information: {{err}}", err) 121 } 122 c.certificate = certBundle.Certificate 123 c.privateKey = certBundle.PrivateKey 124 c.issuingCA = certBundle.IssuingCA 125 c.TLS = true 126 } 127 128 // Set initialized to true at this point since all fields are set, 129 // and the connection can be established at a later time. 130 c.Initialized = true 131 132 if verifyConnection { 133 if _, err := c.Connection(ctx); err != nil { 134 return nil, errwrap.Wrapf("error verifying connection: {{err}}", err) 135 } 136 } 137 138 return conf, nil 139} 140 141func (c *cassandraConnectionProducer) Connection(ctx context.Context) (interface{}, error) { 142 if !c.Initialized { 143 return nil, connutil.ErrNotInitialized 144 } 145 146 // If we already have a DB, return it 147 if c.session != nil && !c.session.Closed() { 148 return c.session, nil 149 } 150 151 session, err := c.createSession(ctx) 152 if err != nil { 153 return nil, err 154 } 155 156 // Store the session in backend for reuse 157 c.session = session 158 159 return session, nil 160} 161 162func (c *cassandraConnectionProducer) Close() error { 163 // Grab the write lock 164 c.Lock() 165 defer c.Unlock() 166 167 if c.session != nil { 168 c.session.Close() 169 } 170 171 c.session = nil 172 173 return nil 174} 175 176func (c *cassandraConnectionProducer) createSession(ctx context.Context) (*gocql.Session, error) { 177 hosts := strings.Split(c.Hosts, ",") 178 clusterConfig := gocql.NewCluster(hosts...) 179 clusterConfig.Authenticator = gocql.PasswordAuthenticator{ 180 Username: c.Username, 181 Password: c.Password, 182 } 183 184 if c.Port != 0 { 185 clusterConfig.Port = c.Port 186 } 187 188 clusterConfig.ProtoVersion = c.ProtocolVersion 189 if clusterConfig.ProtoVersion == 0 { 190 clusterConfig.ProtoVersion = 2 191 } 192 193 clusterConfig.Timeout = c.connectTimeout 194 clusterConfig.SocketKeepalive = c.socketKeepAlive 195 if c.TLS { 196 var tlsConfig *tls.Config 197 if len(c.certificate) > 0 || len(c.issuingCA) > 0 { 198 if len(c.certificate) > 0 && len(c.privateKey) == 0 { 199 return nil, fmt.Errorf("found certificate for TLS authentication but no private key") 200 } 201 202 certBundle := &certutil.CertBundle{} 203 if len(c.certificate) > 0 { 204 certBundle.Certificate = c.certificate 205 certBundle.PrivateKey = c.privateKey 206 } 207 if len(c.issuingCA) > 0 { 208 certBundle.IssuingCA = c.issuingCA 209 } 210 211 parsedCertBundle, err := certBundle.ToParsedCertBundle() 212 if err != nil { 213 return nil, errwrap.Wrapf("failed to parse certificate bundle: {{err}}", err) 214 } 215 216 tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient) 217 if err != nil || tlsConfig == nil { 218 return nil, errwrap.Wrapf(fmt.Sprintf("failed to get TLS configuration: tlsConfig:%#v err:{{err}}", tlsConfig), err) 219 } 220 tlsConfig.InsecureSkipVerify = c.InsecureTLS 221 222 if c.TLSMinVersion != "" { 223 var ok bool 224 tlsConfig.MinVersion, ok = tlsutil.TLSLookup[c.TLSMinVersion] 225 if !ok { 226 return nil, fmt.Errorf("invalid 'tls_min_version' in config") 227 } 228 } else { 229 // MinVersion was not being set earlier. Reset it to 230 // zero to gracefully handle upgrades. 231 tlsConfig.MinVersion = 0 232 } 233 } 234 235 clusterConfig.SslOpts = &gocql.SslOptions{ 236 Config: tlsConfig, 237 } 238 } 239 240 if c.LocalDatacenter != "" { 241 clusterConfig.PoolConfig.HostSelectionPolicy = gocql.DCAwareRoundRobinPolicy(c.LocalDatacenter) 242 } 243 244 session, err := clusterConfig.CreateSession() 245 if err != nil { 246 return nil, errwrap.Wrapf("error creating session: {{err}}", err) 247 } 248 249 // Set consistency 250 if c.Consistency != "" { 251 consistencyValue, err := gocql.ParseConsistencyWrapper(c.Consistency) 252 if err != nil { 253 session.Close() 254 return nil, err 255 } 256 257 session.SetConsistency(consistencyValue) 258 } 259 260 // Verify the info 261 if !c.SkipVerification { 262 err = session.Query(`LIST ALL`).WithContext(ctx).Exec() 263 if err != nil && len(c.Username) != 0 && strings.Contains(err.Error(), "not authorized") { 264 rowNum := session.Query(dbutil.QueryHelper(`LIST CREATE ON ALL ROLES OF '{{username}}';`, map[string]string{ 265 "username": c.Username, 266 })).Iter().NumRows() 267 268 if rowNum < 1 { 269 session.Close() 270 return nil, errwrap.Wrapf("error validating connection info: No role create permissions found, previous error: {{err}}", err) 271 } 272 } else if err != nil { 273 session.Close() 274 return nil, errwrap.Wrapf("error validating connection info: {{err}}", err) 275 } 276 } 277 278 return session, nil 279} 280 281func (c *cassandraConnectionProducer) secretValues() map[string]interface{} { 282 return map[string]interface{}{ 283 c.Password: "[password]", 284 c.PemBundle: "[pem_bundle]", 285 c.PemJSON: "[pem_json]", 286 } 287} 288 289// SetCredentials uses provided information to set/create a user in the 290// database. Unlike CreateUser, this method requires a username be provided and 291// uses the name given, instead of generating a name. This is used for creating 292// and setting the password of static accounts, as well as rolling back 293// passwords in the database in the event an updated database fails to save in 294// Vault's storage. 295func (c *cassandraConnectionProducer) SetCredentials(ctx context.Context, statements dbplugin.Statements, staticUser dbplugin.StaticUserConfig) (username, password string, err error) { 296 return "", "", dbutil.Unimplemented() 297} 298