1package mongodb
2
3import (
4	"context"
5	"crypto/tls"
6	"crypto/x509"
7	"encoding/base64"
8	"encoding/json"
9	"fmt"
10	"sync"
11	"time"
12
13	"github.com/hashicorp/vault/sdk/database/helper/connutil"
14	"github.com/hashicorp/vault/sdk/database/helper/dbutil"
15	"github.com/mitchellh/mapstructure"
16	"go.mongodb.org/mongo-driver/mongo"
17	"go.mongodb.org/mongo-driver/mongo/options"
18	"go.mongodb.org/mongo-driver/mongo/readpref"
19	"go.mongodb.org/mongo-driver/mongo/writeconcern"
20)
21
// mongoDBConnectionProducer holds the MongoDB connection settings decoded by
// loadConfig and caches a single *mongo.Client, which Connection hands out and
// Close tears down.
//
// NOTE(review): this type was previously described as implementing
// ConnectionProducer; Connection here returns a *mongo.Client, so confirm
// against that interface's method set which interface (if any) it satisfies.
type mongoDBConnectionProducer struct {
	// ConnectionURL is the MongoDB connection string; username/password
	// placeholders in it are expanded by getConnectionURL. WriteConcern is a
	// JSON (optionally base64-encoded) write concern parsed by getWriteConcern.
	ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"`
	WriteConcern  string `json:"write_concern" structs:"write_concern" mapstructure:"write_concern"`

	// Username and Password are substituted into ConnectionURL. Username is
	// also used as the MONGODB-X509 auth user when a client certificate is set.
	Username string `json:"username" structs:"username" mapstructure:"username"`
	Password string `json:"password" structs:"password" mapstructure:"password"`

	// TLSCertificateKeyData is a PEM bundle holding both the client
	// certificate and its private key; TLSCAData is PEM CA certificates used
	// as the trusted roots. See getTLSAuth.
	TLSCertificateKeyData []byte `json:"tls_certificate_key" structs:"-" mapstructure:"tls_certificate_key"`
	TLSCAData             []byte `json:"tls_ca"              structs:"-" mapstructure:"tls_ca"`

	// Driver timeouts; must be >= 0. Zero socket/connect timeouts fall back
	// to one minute, and a zero server selection timeout is left unset
	// (see timeoutOpts).
	SocketTimeout          time.Duration `json:"socket_timeout"           structs:"-" mapstructure:"socket_timeout"`
	ConnectTimeout         time.Duration `json:"connect_timeout"          structs:"-" mapstructure:"connect_timeout"`
	ServerSelectionTimeout time.Duration `json:"server_selection_timeout" structs:"-" mapstructure:"server_selection_timeout"`

	// Initialized must be true before Connection or createClient will build a
	// client (where it is set is not visible here). RawConfig and Type are not
	// read in this section — presumably used by the enclosing plugin; confirm.
	// clientOptions is built by loadConfig and merged into every new client by
	// createClient. client is the cached connection returned by Connection and
	// cleared by Close; the embedded mutex is held by both while they read or
	// replace it.
	Initialized   bool
	RawConfig     map[string]interface{}
	Type          string
	clientOptions *options.ClientOptions
	client        *mongo.Client
	sync.Mutex
}
45
// writeConcern is the JSON shape accepted by the write_concern setting (see
// getWriteConcern). It has no json tags, so encoding/json matches keys
// case-insensitively against the Go field names (e.g. "W", "wmode", "J").
type writeConcern struct {
	W        int    // Min # of servers to ack before success; takes precedence over WMode when non-zero
	WMode    string // Write mode for MongoDB 2.0+ (e.g. "majority"); applied as a tag set when W is 0
	WTimeout int    // Milliseconds to wait for W before timing out
	FSync    bool   // DEPRECATED: Is now handled by J (either flag enables journaling). See: https://jira.mongodb.org/browse/CXX-910
	J        bool   // Sync via the journal if present
}
54
55func (c *mongoDBConnectionProducer) loadConfig(cfg map[string]interface{}) error {
56	err := mapstructure.WeakDecode(cfg, c)
57	if err != nil {
58		return err
59	}
60
61	if len(c.ConnectionURL) == 0 {
62		return fmt.Errorf("connection_url cannot be empty")
63	}
64
65	if c.SocketTimeout < 0 {
66		return fmt.Errorf("socket_timeout must be >= 0")
67	}
68	if c.ConnectTimeout < 0 {
69		return fmt.Errorf("connect_timeout must be >= 0")
70	}
71	if c.ServerSelectionTimeout < 0 {
72		return fmt.Errorf("server_selection_timeout must be >= 0")
73	}
74
75	opts, err := c.makeClientOpts()
76	if err != nil {
77		return err
78	}
79
80	c.clientOptions = opts
81
82	return nil
83}
84
85// Connection creates or returns an existing a database connection. If the session fails
86// on a ping check, the session will be closed and then re-created.
87// This method does locks the mutex on its own.
88func (c *mongoDBConnectionProducer) Connection(ctx context.Context) (*mongo.Client, error) {
89	if !c.Initialized {
90		return nil, connutil.ErrNotInitialized
91	}
92
93	c.Mutex.Lock()
94	defer c.Mutex.Unlock()
95
96	if c.client != nil {
97		if err := c.client.Ping(ctx, readpref.Primary()); err == nil {
98			return c.client, nil
99		}
100		// Ignore error on purpose since we want to re-create a session
101		_ = c.client.Disconnect(ctx)
102	}
103
104	client, err := c.createClient(ctx)
105	if err != nil {
106		return nil, err
107	}
108	c.client = client
109	return c.client, nil
110}
111
112func (c *mongoDBConnectionProducer) createClient(ctx context.Context) (client *mongo.Client, err error) {
113	if !c.Initialized {
114		return nil, fmt.Errorf("failed to create client: connection producer is not initialized")
115	}
116	if c.clientOptions == nil {
117		return nil, fmt.Errorf("missing client options")
118	}
119	client, err = mongo.Connect(ctx, options.MergeClientOptions(options.Client().ApplyURI(c.getConnectionURL()), c.clientOptions))
120	if err != nil {
121		return nil, err
122	}
123	return client, nil
124}
125
126// Close terminates the database connection.
127func (c *mongoDBConnectionProducer) Close() error {
128	c.Lock()
129	defer c.Unlock()
130
131	if c.client != nil {
132		ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
133		defer cancel()
134		if err := c.client.Disconnect(ctx); err != nil {
135			return err
136		}
137	}
138
139	c.client = nil
140
141	return nil
142}
143
144func (c *mongoDBConnectionProducer) secretValues() map[string]string {
145	return map[string]string{
146		c.Password: "[password]",
147	}
148}
149
150func (c *mongoDBConnectionProducer) getConnectionURL() (connURL string) {
151	connURL = dbutil.QueryHelper(c.ConnectionURL, map[string]string{
152		"username": c.Username,
153		"password": c.Password,
154	})
155	return connURL
156}
157
158func (c *mongoDBConnectionProducer) makeClientOpts() (*options.ClientOptions, error) {
159	writeOpts, err := c.getWriteConcern()
160	if err != nil {
161		return nil, err
162	}
163
164	authOpts, err := c.getTLSAuth()
165	if err != nil {
166		return nil, err
167	}
168
169	timeoutOpts, err := c.timeoutOpts()
170	if err != nil {
171		return nil, err
172	}
173
174	opts := options.MergeClientOptions(writeOpts, authOpts, timeoutOpts)
175	return opts, nil
176}
177
178func (c *mongoDBConnectionProducer) getWriteConcern() (opts *options.ClientOptions, err error) {
179	if c.WriteConcern == "" {
180		return nil, nil
181	}
182
183	input := c.WriteConcern
184
185	// Try to base64 decode the input. If successful, consider the decoded
186	// value as input.
187	inputBytes, err := base64.StdEncoding.DecodeString(input)
188	if err == nil {
189		input = string(inputBytes)
190	}
191
192	concern := &writeConcern{}
193	err = json.Unmarshal([]byte(input), concern)
194	if err != nil {
195		return nil, fmt.Errorf("error unmarshalling write_concern: %w", err)
196	}
197
198	// Translate write concern to mongo options
199	var w writeconcern.Option
200	switch {
201	case concern.W != 0:
202		w = writeconcern.W(concern.W)
203	case concern.WMode != "":
204		w = writeconcern.WTagSet(concern.WMode)
205	default:
206		w = writeconcern.WMajority()
207	}
208
209	var j writeconcern.Option
210	switch {
211	case concern.FSync:
212		j = writeconcern.J(concern.FSync)
213	case concern.J:
214		j = writeconcern.J(concern.J)
215	default:
216		j = writeconcern.J(false)
217	}
218
219	writeConcern := writeconcern.New(
220		w,
221		j,
222		writeconcern.WTimeout(time.Duration(concern.WTimeout)*time.Millisecond))
223
224	opts = options.Client()
225	opts.SetWriteConcern(writeConcern)
226	return opts, nil
227}
228
229func (c *mongoDBConnectionProducer) getTLSAuth() (opts *options.ClientOptions, err error) {
230	if len(c.TLSCAData) == 0 && len(c.TLSCertificateKeyData) == 0 {
231		return nil, nil
232	}
233
234	opts = options.Client()
235
236	tlsConfig := &tls.Config{}
237
238	if len(c.TLSCAData) > 0 {
239		tlsConfig.RootCAs = x509.NewCertPool()
240
241		ok := tlsConfig.RootCAs.AppendCertsFromPEM(c.TLSCAData)
242		if !ok {
243			return nil, fmt.Errorf("failed to append CA to client options")
244		}
245	}
246
247	if len(c.TLSCertificateKeyData) > 0 {
248		certificate, err := tls.X509KeyPair(c.TLSCertificateKeyData, c.TLSCertificateKeyData)
249		if err != nil {
250			return nil, fmt.Errorf("unable to load tls_certificate_key_data: %w", err)
251		}
252
253		opts.SetAuth(options.Credential{
254			AuthMechanism: "MONGODB-X509",
255			Username:      c.Username,
256		})
257
258		tlsConfig.Certificates = append(tlsConfig.Certificates, certificate)
259	}
260
261	opts.SetTLSConfig(tlsConfig)
262	return opts, nil
263}
264
265func (c *mongoDBConnectionProducer) timeoutOpts() (opts *options.ClientOptions, err error) {
266	opts = options.Client()
267
268	if c.SocketTimeout < 0 {
269		return nil, fmt.Errorf("socket_timeout must be >= 0")
270	}
271
272	if c.SocketTimeout == 0 {
273		opts.SetSocketTimeout(1 * time.Minute)
274	} else {
275		opts.SetSocketTimeout(c.SocketTimeout)
276	}
277
278	if c.ConnectTimeout == 0 {
279		opts.SetConnectTimeout(1 * time.Minute)
280	} else {
281		opts.SetConnectTimeout(c.ConnectTimeout)
282	}
283
284	if c.ServerSelectionTimeout != 0 {
285		opts.SetServerSelectionTimeout(c.ServerSelectionTimeout)
286	}
287
288	return opts, nil
289}
290