1package cassandra 2 3import ( 4 "context" 5 "crypto/tls" 6 "fmt" 7 "strings" 8 "sync" 9 "time" 10 11 "github.com/mitchellh/mapstructure" 12 13 "github.com/gocql/gocql" 14 "github.com/hashicorp/errwrap" 15 "github.com/hashicorp/vault/helper/certutil" 16 "github.com/hashicorp/vault/helper/parseutil" 17 "github.com/hashicorp/vault/helper/tlsutil" 18 "github.com/hashicorp/vault/plugins/helper/database/connutil" 19 "github.com/hashicorp/vault/plugins/helper/database/dbutil" 20) 21 22// cassandraConnectionProducer implements ConnectionProducer and provides an 23// interface for cassandra databases to make connections. 24type cassandraConnectionProducer struct { 25 Hosts string `json:"hosts" structs:"hosts" mapstructure:"hosts"` 26 Port int `json:"port" structs:"port" mapstructure:"port"` 27 Username string `json:"username" structs:"username" mapstructure:"username"` 28 Password string `json:"password" structs:"password" mapstructure:"password"` 29 TLS bool `json:"tls" structs:"tls" mapstructure:"tls"` 30 InsecureTLS bool `json:"insecure_tls" structs:"insecure_tls" mapstructure:"insecure_tls"` 31 ProtocolVersion int `json:"protocol_version" structs:"protocol_version" mapstructure:"protocol_version"` 32 ConnectTimeoutRaw interface{} `json:"connect_timeout" structs:"connect_timeout" mapstructure:"connect_timeout"` 33 SocketKeepAliveRaw interface{} `json:"socket_keep_alive" structs:"socket_keep_alive" mapstructure:"socket_keep_alive"` 34 TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"` 35 Consistency string `json:"consistency" structs:"consistency" mapstructure:"consistency"` 36 PemBundle string `json:"pem_bundle" structs:"pem_bundle" mapstructure:"pem_bundle"` 37 PemJSON string `json:"pem_json" structs:"pem_json" mapstructure:"pem_json"` 38 39 connectTimeout time.Duration 40 socketKeepAlive time.Duration 41 certificate string 42 privateKey string 43 issuingCA string 44 rawConfig map[string]interface{} 45 46 Initialized bool 47 Type string 48 session *gocql.Session 49 sync.Mutex 50} 51 52func (c *cassandraConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error { 53 _, err := c.Init(ctx, conf, verifyConnection) 54 return err 55} 56 57func (c *cassandraConnectionProducer) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (map[string]interface{}, error) { 58 c.Lock() 59 defer c.Unlock() 60 61 c.rawConfig = conf 62 63 err := mapstructure.WeakDecode(conf, c) 64 if err != nil { 65 return nil, err 66 } 67 68 if c.ConnectTimeoutRaw == nil { 69 c.ConnectTimeoutRaw = "0s" 70 } 71 c.connectTimeout, err = parseutil.ParseDurationSecond(c.ConnectTimeoutRaw) 72 if err != nil { 73 return nil, errwrap.Wrapf("invalid connect_timeout: {{err}}", err) 74 } 75 76 if c.SocketKeepAliveRaw == nil { 77 c.SocketKeepAliveRaw = "0s" 78 } 79 c.socketKeepAlive, err = parseutil.ParseDurationSecond(c.SocketKeepAliveRaw) 80 if err != nil { 81 return nil, errwrap.Wrapf("invalid socket_keep_alive: {{err}}", err) 82 } 83 84 switch { 85 case len(c.Hosts) == 0: 86 return nil, fmt.Errorf("hosts cannot be empty") 87 case len(c.Username) == 0: 88 return nil, fmt.Errorf("username cannot be empty") 89 case len(c.Password) == 0: 90 return nil, fmt.Errorf("password cannot be empty") 91 } 92 93 var certBundle *certutil.CertBundle 94 var parsedCertBundle *certutil.ParsedCertBundle 95 switch { 96 case len(c.PemJSON) != 0: 97 parsedCertBundle, err = certutil.ParsePKIJSON([]byte(c.PemJSON)) 98 if err != nil { 99 return nil, errwrap.Wrapf("could not parse given JSON; it must be in the format of the output of the PKI backend certificate issuing command: {{err}}", err) 100 } 101 certBundle, err = parsedCertBundle.ToCertBundle() 102 if err != nil { 103 return nil, errwrap.Wrapf("Error marshaling PEM information: {{err}}", err) 104 } 105 c.certificate = certBundle.Certificate 106 c.privateKey = certBundle.PrivateKey 107 c.issuingCA = certBundle.IssuingCA 108 c.TLS = true 109 110 case len(c.PemBundle) != 0: 111 parsedCertBundle, err = certutil.ParsePEMBundle(c.PemBundle) 112 if err != nil { 113 return nil, errwrap.Wrapf("Error parsing the given PEM information: {{err}}", err) 114 } 115 certBundle, err = parsedCertBundle.ToCertBundle() 116 if err != nil { 117 return nil, errwrap.Wrapf("Error marshaling PEM information: {{err}}", err) 118 } 119 c.certificate = certBundle.Certificate 120 c.privateKey = certBundle.PrivateKey 121 c.issuingCA = certBundle.IssuingCA 122 c.TLS = true 123 } 124 125 // Set initialized to true at this point since all fields are set, 126 // and the connection can be established at a later time. 127 c.Initialized = true 128 129 if verifyConnection { 130 if _, err := c.Connection(ctx); err != nil { 131 return nil, errwrap.Wrapf("error verifying connection: {{err}}", err) 132 } 133 } 134 135 return conf, nil 136} 137 138func (c *cassandraConnectionProducer) Connection(_ context.Context) (interface{}, error) { 139 if !c.Initialized { 140 return nil, connutil.ErrNotInitialized 141 } 142 143 // If we already have a DB, return it 144 if c.session != nil && !c.session.Closed() { 145 return c.session, nil 146 } 147 148 session, err := c.createSession() 149 if err != nil { 150 return nil, err 151 } 152 153 // Store the session in backend for reuse 154 c.session = session 155 156 return session, nil 157} 158 159func (c *cassandraConnectionProducer) Close() error { 160 // Grab the write lock 161 c.Lock() 162 defer c.Unlock() 163 164 if c.session != nil { 165 c.session.Close() 166 } 167 168 c.session = nil 169 170 return nil 171} 172 173func (c *cassandraConnectionProducer) createSession() (*gocql.Session, error) { 174 hosts := strings.Split(c.Hosts, ",") 175 clusterConfig := gocql.NewCluster(hosts...) 176 clusterConfig.Authenticator = gocql.PasswordAuthenticator{ 177 Username: c.Username, 178 Password: c.Password, 179 } 180 181 if c.Port != 0 { 182 clusterConfig.Port = c.Port 183 } 184 185 clusterConfig.ProtoVersion = c.ProtocolVersion 186 if clusterConfig.ProtoVersion == 0 { 187 clusterConfig.ProtoVersion = 2 188 } 189 190 clusterConfig.Timeout = c.connectTimeout 191 clusterConfig.SocketKeepalive = c.socketKeepAlive 192 if c.TLS { 193 var tlsConfig *tls.Config 194 if len(c.certificate) > 0 || len(c.issuingCA) > 0 { 195 if len(c.certificate) > 0 && len(c.privateKey) == 0 { 196 return nil, fmt.Errorf("found certificate for TLS authentication but no private key") 197 } 198 199 certBundle := &certutil.CertBundle{} 200 if len(c.certificate) > 0 { 201 certBundle.Certificate = c.certificate 202 certBundle.PrivateKey = c.privateKey 203 } 204 if len(c.issuingCA) > 0 { 205 certBundle.IssuingCA = c.issuingCA 206 } 207 208 parsedCertBundle, err := certBundle.ToParsedCertBundle() 209 if err != nil { 210 return nil, errwrap.Wrapf("failed to parse certificate bundle: {{err}}", err) 211 } 212 213 tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient) 214 if err != nil || tlsConfig == nil { 215 return nil, errwrap.Wrapf(fmt.Sprintf("failed to get TLS configuration: tlsConfig:%#v err:{{err}}", tlsConfig), err) 216 } 217 tlsConfig.InsecureSkipVerify = c.InsecureTLS 218 219 if c.TLSMinVersion != "" { 220 var ok bool 221 tlsConfig.MinVersion, ok = tlsutil.TLSLookup[c.TLSMinVersion] 222 if !ok { 223 return nil, fmt.Errorf("invalid 'tls_min_version' in config") 224 } 225 } else { 226 // MinVersion was not being set earlier. Reset it to 227 // zero to gracefully handle upgrades. 228 tlsConfig.MinVersion = 0 229 } 230 } 231 232 clusterConfig.SslOpts = &gocql.SslOptions{ 233 Config: tlsConfig, 234 } 235 } 236 237 session, err := clusterConfig.CreateSession() 238 if err != nil { 239 return nil, errwrap.Wrapf("error creating session: {{err}}", err) 240 } 241 242 // Set consistency 243 if c.Consistency != "" { 244 consistencyValue, err := gocql.ParseConsistencyWrapper(c.Consistency) 245 if err != nil { 246 return nil, err 247 } 248 249 session.SetConsistency(consistencyValue) 250 } 251 252 // Verify the info 253 err = session.Query(`LIST ALL`).Exec() 254 if err != nil && len(c.Username) != 0 && strings.Contains(err.Error(), "not authorized") { 255 rowNum := session.Query(dbutil.QueryHelper(`LIST CREATE ON ALL ROLES OF '{{username}}';`, map[string]string{ 256 "username": c.Username, 257 })).Iter().NumRows() 258 259 if rowNum < 1 { 260 return nil, errwrap.Wrapf("error validating connection info: No role create permissions found, previous error: {{err}}", err) 261 } 262 } else if err != nil { 263 return nil, errwrap.Wrapf("error validating connection info: {{err}}", err) 264 } 265 266 return session, nil 267} 268 269func (c *cassandraConnectionProducer) secretValues() map[string]interface{} { 270 return map[string]interface{}{ 271 c.Password: "[password]", 272 c.PemBundle: "[pem_bundle]", 273 c.PemJSON: "[pem_json]", 274 } 275} 276