1package mongodb 2 3import ( 4 "context" 5 "crypto/tls" 6 "crypto/x509" 7 "encoding/base64" 8 "encoding/json" 9 "fmt" 10 "sync" 11 "time" 12 13 "github.com/hashicorp/vault/sdk/database/helper/connutil" 14 "github.com/hashicorp/vault/sdk/database/helper/dbutil" 15 "github.com/mitchellh/mapstructure" 16 "go.mongodb.org/mongo-driver/mongo" 17 "go.mongodb.org/mongo-driver/mongo/options" 18 "go.mongodb.org/mongo-driver/mongo/readpref" 19 "go.mongodb.org/mongo-driver/mongo/writeconcern" 20) 21 22// mongoDBConnectionProducer implements ConnectionProducer and provides an 23// interface for databases to make connections. 24type mongoDBConnectionProducer struct { 25 ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"` 26 WriteConcern string `json:"write_concern" structs:"write_concern" mapstructure:"write_concern"` 27 28 Username string `json:"username" structs:"username" mapstructure:"username"` 29 Password string `json:"password" structs:"password" mapstructure:"password"` 30 31 TLSCertificateKeyData []byte `json:"tls_certificate_key" structs:"-" mapstructure:"tls_certificate_key"` 32 TLSCAData []byte `json:"tls_ca" structs:"-" mapstructure:"tls_ca"` 33 34 SocketTimeout time.Duration `json:"socket_timeout" structs:"-" mapstructure:"socket_timeout"` 35 ConnectTimeout time.Duration `json:"connect_timeout" structs:"-" mapstructure:"connect_timeout"` 36 ServerSelectionTimeout time.Duration `json:"server_selection_timeout" structs:"-" mapstructure:"server_selection_timeout"` 37 38 Initialized bool 39 RawConfig map[string]interface{} 40 Type string 41 clientOptions *options.ClientOptions 42 client *mongo.Client 43 sync.Mutex 44} 45 46// writeConcern defines the write concern options 47type writeConcern struct { 48 W int // Min # of servers to ack before success 49 WMode string // Write mode for MongoDB 2.0+ (e.g. "majority") 50 WTimeout int // Milliseconds to wait for W before timing out 51 FSync bool // DEPRECATED: Is now handled by J. See: https://jira.mongodb.org/browse/CXX-910 52 J bool // Sync via the journal if present 53} 54 55func (c *mongoDBConnectionProducer) loadConfig(cfg map[string]interface{}) error { 56 err := mapstructure.WeakDecode(cfg, c) 57 if err != nil { 58 return err 59 } 60 61 if len(c.ConnectionURL) == 0 { 62 return fmt.Errorf("connection_url cannot be empty") 63 } 64 65 if c.SocketTimeout < 0 { 66 return fmt.Errorf("socket_timeout must be >= 0") 67 } 68 if c.ConnectTimeout < 0 { 69 return fmt.Errorf("connect_timeout must be >= 0") 70 } 71 if c.ServerSelectionTimeout < 0 { 72 return fmt.Errorf("server_selection_timeout must be >= 0") 73 } 74 75 opts, err := c.makeClientOpts() 76 if err != nil { 77 return err 78 } 79 80 c.clientOptions = opts 81 82 return nil 83} 84 85// Connection creates or returns an existing a database connection. If the session fails 86// on a ping check, the session will be closed and then re-created. 87// This method does locks the mutex on its own. 88func (c *mongoDBConnectionProducer) Connection(ctx context.Context) (*mongo.Client, error) { 89 if !c.Initialized { 90 return nil, connutil.ErrNotInitialized 91 } 92 93 c.Mutex.Lock() 94 defer c.Mutex.Unlock() 95 96 if c.client != nil { 97 if err := c.client.Ping(ctx, readpref.Primary()); err == nil { 98 return c.client, nil 99 } 100 // Ignore error on purpose since we want to re-create a session 101 _ = c.client.Disconnect(ctx) 102 } 103 104 client, err := c.createClient(ctx) 105 if err != nil { 106 return nil, err 107 } 108 c.client = client 109 return c.client, nil 110} 111 112func (c *mongoDBConnectionProducer) createClient(ctx context.Context) (client *mongo.Client, err error) { 113 if !c.Initialized { 114 return nil, fmt.Errorf("failed to create client: connection producer is not initialized") 115 } 116 if c.clientOptions == nil { 117 return nil, fmt.Errorf("missing client options") 118 } 119 client, err = mongo.Connect(ctx, options.MergeClientOptions(options.Client().ApplyURI(c.getConnectionURL()), c.clientOptions)) 120 if err != nil { 121 return nil, err 122 } 123 return client, nil 124} 125 126// Close terminates the database connection. 127func (c *mongoDBConnectionProducer) Close() error { 128 c.Lock() 129 defer c.Unlock() 130 131 if c.client != nil { 132 ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) 133 defer cancel() 134 if err := c.client.Disconnect(ctx); err != nil { 135 return err 136 } 137 } 138 139 c.client = nil 140 141 return nil 142} 143 144func (c *mongoDBConnectionProducer) secretValues() map[string]string { 145 return map[string]string{ 146 c.Password: "[password]", 147 } 148} 149 150func (c *mongoDBConnectionProducer) getConnectionURL() (connURL string) { 151 connURL = dbutil.QueryHelper(c.ConnectionURL, map[string]string{ 152 "username": c.Username, 153 "password": c.Password, 154 }) 155 return connURL 156} 157 158func (c *mongoDBConnectionProducer) makeClientOpts() (*options.ClientOptions, error) { 159 writeOpts, err := c.getWriteConcern() 160 if err != nil { 161 return nil, err 162 } 163 164 authOpts, err := c.getTLSAuth() 165 if err != nil { 166 return nil, err 167 } 168 169 timeoutOpts, err := c.timeoutOpts() 170 if err != nil { 171 return nil, err 172 } 173 174 opts := options.MergeClientOptions(writeOpts, authOpts, timeoutOpts) 175 return opts, nil 176} 177 178func (c *mongoDBConnectionProducer) getWriteConcern() (opts *options.ClientOptions, err error) { 179 if c.WriteConcern == "" { 180 return nil, nil 181 } 182 183 input := c.WriteConcern 184 185 // Try to base64 decode the input. If successful, consider the decoded 186 // value as input. 187 inputBytes, err := base64.StdEncoding.DecodeString(input) 188 if err == nil { 189 input = string(inputBytes) 190 } 191 192 concern := &writeConcern{} 193 err = json.Unmarshal([]byte(input), concern) 194 if err != nil { 195 return nil, fmt.Errorf("error unmarshalling write_concern: %w", err) 196 } 197 198 // Translate write concern to mongo options 199 var w writeconcern.Option 200 switch { 201 case concern.W != 0: 202 w = writeconcern.W(concern.W) 203 case concern.WMode != "": 204 w = writeconcern.WTagSet(concern.WMode) 205 default: 206 w = writeconcern.WMajority() 207 } 208 209 var j writeconcern.Option 210 switch { 211 case concern.FSync: 212 j = writeconcern.J(concern.FSync) 213 case concern.J: 214 j = writeconcern.J(concern.J) 215 default: 216 j = writeconcern.J(false) 217 } 218 219 writeConcern := writeconcern.New( 220 w, 221 j, 222 writeconcern.WTimeout(time.Duration(concern.WTimeout)*time.Millisecond)) 223 224 opts = options.Client() 225 opts.SetWriteConcern(writeConcern) 226 return opts, nil 227} 228 229func (c *mongoDBConnectionProducer) getTLSAuth() (opts *options.ClientOptions, err error) { 230 if len(c.TLSCAData) == 0 && len(c.TLSCertificateKeyData) == 0 { 231 return nil, nil 232 } 233 234 opts = options.Client() 235 236 tlsConfig := &tls.Config{} 237 238 if len(c.TLSCAData) > 0 { 239 tlsConfig.RootCAs = x509.NewCertPool() 240 241 ok := tlsConfig.RootCAs.AppendCertsFromPEM(c.TLSCAData) 242 if !ok { 243 return nil, fmt.Errorf("failed to append CA to client options") 244 } 245 } 246 247 if len(c.TLSCertificateKeyData) > 0 { 248 certificate, err := tls.X509KeyPair(c.TLSCertificateKeyData, c.TLSCertificateKeyData) 249 if err != nil { 250 return nil, fmt.Errorf("unable to load tls_certificate_key_data: %w", err) 251 } 252 253 opts.SetAuth(options.Credential{ 254 AuthMechanism: "MONGODB-X509", 255 Username: c.Username, 256 }) 257 258 tlsConfig.Certificates = append(tlsConfig.Certificates, certificate) 259 } 260 261 opts.SetTLSConfig(tlsConfig) 262 return opts, nil 263} 264 265func (c *mongoDBConnectionProducer) timeoutOpts() (opts *options.ClientOptions, err error) { 266 opts = options.Client() 267 268 if c.SocketTimeout < 0 { 269 return nil, fmt.Errorf("socket_timeout must be >= 0") 270 } 271 272 if c.SocketTimeout == 0 { 273 opts.SetSocketTimeout(1 * time.Minute) 274 } else { 275 opts.SetSocketTimeout(c.SocketTimeout) 276 } 277 278 if c.ConnectTimeout == 0 { 279 opts.SetConnectTimeout(1 * time.Minute) 280 } else { 281 opts.SetConnectTimeout(c.ConnectTimeout) 282 } 283 284 if c.ServerSelectionTimeout != 0 { 285 opts.SetServerSelectionTimeout(c.ServerSelectionTimeout) 286 } 287 288 return opts, nil 289} 290