1package mongodb
2
3import (
4	"context"
5	"crypto/tls"
6	"crypto/x509"
7	"encoding/base64"
8	"encoding/json"
9	"fmt"
10	"sync"
11	"time"
12
13	"github.com/hashicorp/errwrap"
14	"github.com/hashicorp/vault/sdk/database/helper/connutil"
15	"github.com/hashicorp/vault/sdk/database/helper/dbutil"
16	"github.com/mitchellh/mapstructure"
17	"go.mongodb.org/mongo-driver/mongo"
18	"go.mongodb.org/mongo-driver/mongo/options"
19	"go.mongodb.org/mongo-driver/mongo/readpref"
20	"go.mongodb.org/mongo-driver/mongo/writeconcern"
21)
22
// mongoDBConnectionProducer implements ConnectionProducer and provides an
// interface for databases to make connections.
//
// The exported configuration fields are populated from the config map by Init
// via mapstructure.WeakDecode. Init and Close take the embedded mutex;
// Connection does not and expects its caller to hold it.
type mongoDBConnectionProducer struct {
	// ConnectionURL is the MongoDB connection string. Init rejects an empty
	// value and then overwrites it with getConnectionURL's result, i.e. the
	// URL with Username and Password substituted in.
	ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"`
	// WriteConcern is an optional JSON document (optionally base64-encoded)
	// describing a writeConcern; it is parsed by getWriteConcern.
	WriteConcern  string `json:"write_concern" structs:"write_concern" mapstructure:"write_concern"`

	// Username and Password are substituted into ConnectionURL. Username is
	// also used as the MONGODB-X509 credential username when
	// TLSCertificateKeyData is set. Password is listed by secretValues.
	Username string `json:"username" structs:"username" mapstructure:"username"`
	Password string `json:"password" structs:"password" mapstructure:"password"`

	// TLSCertificateKeyData is a single PEM bundle holding both the client
	// certificate and its private key; when set, it becomes the TLS client
	// certificate and MONGODB-X509 authentication is enabled.
	// TLSCAData is PEM CA certificate data used as the TLS root CA pool.
	// NOTE(review): structs:"-" presumably keeps these out of structs-based
	// config serialization — confirm.
	TLSCertificateKeyData []byte `json:"tls_certificate_key" structs:"-" mapstructure:"tls_certificate_key"`
	TLSCAData             []byte `json:"tls_ca"              structs:"-" mapstructure:"tls_ca"`

	// Initialized is set by Init once the client options are built (even if
	// connection verification subsequently fails); Connection returns
	// connutil.ErrNotInitialized until it is true.
	Initialized   bool
	// RawConfig is the config map passed to the most recent Init call,
	// stored by reference (not copied).
	RawConfig     map[string]interface{}
	// Type is not referenced in this file. NOTE(review): presumably set by
	// the embedding plugin — confirm.
	Type          string
	// clientOptions holds the merged write-concern and TLS/auth options
	// built by Init.
	clientOptions *options.ClientOptions
	// client is the cached client created by Connection and cleared by Close.
	client        *mongo.Client
	sync.Mutex
}
42
// writeConcern defines the write concern options accepted by the
// write_concern setting. getWriteConcern decodes it with encoding/json; the
// fields have no json tags, so JSON keys match the Go field names.
//
// Precedence applied by getWriteConcern:
//   - a non-zero W wins over WMode, and if neither is set the write concern
//     falls back to "majority" (so W == 0 cannot request unacknowledged writes);
//   - journaling is requested when either FSync or J is true.
type writeConcern struct {
	W        int    // Min # of servers to ack before success
	WMode    string // Write mode for MongoDB 2.0+ (e.g. "majority")
	WTimeout int    // Milliseconds to wait for W before timing out
	FSync    bool   // DEPRECATED: Is now handled by J. See: https://jira.mongodb.org/browse/CXX-910
	J        bool   // Sync via the journal if present
}
51
52func (c *mongoDBConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
53	_, err := c.Init(ctx, conf, verifyConnection)
54	return err
55}
56
57// Initialize parses connection configuration.
58func (c *mongoDBConnectionProducer) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (map[string]interface{}, error) {
59	c.Lock()
60	defer c.Unlock()
61
62	c.RawConfig = conf
63
64	err := mapstructure.WeakDecode(conf, c)
65	if err != nil {
66		return nil, err
67	}
68
69	if len(c.ConnectionURL) == 0 {
70		return nil, fmt.Errorf("connection_url cannot be empty")
71	}
72
73	writeOpts, err := c.getWriteConcern()
74	if err != nil {
75		return nil, err
76	}
77
78	authOpts, err := c.getTLSAuth()
79	if err != nil {
80		return nil, err
81	}
82
83	c.ConnectionURL = c.getConnectionURL()
84	c.clientOptions = options.MergeClientOptions(writeOpts, authOpts)
85
86	// Set initialized to true at this point since all fields are set,
87	// and the connection can be established at a later time.
88	c.Initialized = true
89
90	if verifyConnection {
91		if _, err := c.Connection(ctx); err != nil {
92			return nil, errwrap.Wrapf("error verifying connection: {{err}}", err)
93		}
94
95		if err := c.client.Ping(ctx, readpref.Primary()); err != nil {
96			return nil, errwrap.Wrapf("error verifying connection: {{err}}", err)
97		}
98	}
99
100	return conf, nil
101}
102
103// Connection creates or returns an existing a database connection. If the session fails
104// on a ping check, the session will be closed and then re-created.
105// This method does not lock the mutex and it is intended that this is the callers
106// responsibility.
107func (c *mongoDBConnectionProducer) Connection(ctx context.Context) (interface{}, error) {
108	if !c.Initialized {
109		return nil, connutil.ErrNotInitialized
110	}
111
112	if c.client != nil {
113		if err := c.client.Ping(ctx, readpref.Primary()); err == nil {
114			return c.client, nil
115		}
116		// Ignore error on purpose since we want to re-create a session
117		_ = c.client.Disconnect(ctx)
118	}
119
120	if c.clientOptions == nil {
121		c.clientOptions = options.Client()
122	}
123	c.clientOptions.SetSocketTimeout(1 * time.Minute)
124	c.clientOptions.SetConnectTimeout(1 * time.Minute)
125
126	var err error
127	opts := c.clientOptions.ApplyURI(c.ConnectionURL)
128	c.client, err = mongo.Connect(ctx, opts)
129	if err != nil {
130		return nil, err
131	}
132	return c.client, nil
133}
134
135// Close terminates the database connection.
136func (c *mongoDBConnectionProducer) Close() error {
137	c.Lock()
138	defer c.Unlock()
139
140	if c.client != nil {
141		ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
142		defer cancel()
143		if err := c.client.Disconnect(ctx); err != nil {
144			return err
145		}
146	}
147
148	c.client = nil
149
150	return nil
151}
152
153func (c *mongoDBConnectionProducer) secretValues() map[string]interface{} {
154	return map[string]interface{}{
155		c.Password: "[password]",
156	}
157}
158
159func (c *mongoDBConnectionProducer) getConnectionURL() (connURL string) {
160	connURL = dbutil.QueryHelper(c.ConnectionURL, map[string]string{
161		"username": c.Username,
162		"password": c.Password,
163	})
164	return connURL
165}
166
167func (c *mongoDBConnectionProducer) getWriteConcern() (opts *options.ClientOptions, err error) {
168	if c.WriteConcern == "" {
169		return nil, nil
170	}
171
172	input := c.WriteConcern
173
174	// Try to base64 decode the input. If successful, consider the decoded
175	// value as input.
176	inputBytes, err := base64.StdEncoding.DecodeString(input)
177	if err == nil {
178		input = string(inputBytes)
179	}
180
181	concern := &writeConcern{}
182	err = json.Unmarshal([]byte(input), concern)
183	if err != nil {
184		return nil, errwrap.Wrapf("error unmarshalling write_concern: {{err}}", err)
185	}
186
187	// Translate write concern to mongo options
188	var w writeconcern.Option
189	switch {
190	case concern.W != 0:
191		w = writeconcern.W(concern.W)
192	case concern.WMode != "":
193		w = writeconcern.WTagSet(concern.WMode)
194	default:
195		w = writeconcern.WMajority()
196	}
197
198	var j writeconcern.Option
199	switch {
200	case concern.FSync:
201		j = writeconcern.J(concern.FSync)
202	case concern.J:
203		j = writeconcern.J(concern.J)
204	default:
205		j = writeconcern.J(false)
206	}
207
208	writeConcern := writeconcern.New(
209		w,
210		j,
211		writeconcern.WTimeout(time.Duration(concern.WTimeout)*time.Millisecond))
212
213	opts = options.Client()
214	opts.SetWriteConcern(writeConcern)
215	return opts, nil
216}
217
218func (c *mongoDBConnectionProducer) getTLSAuth() (opts *options.ClientOptions, err error) {
219	if len(c.TLSCAData) == 0 && len(c.TLSCertificateKeyData) == 0 {
220		return nil, nil
221	}
222
223	opts = options.Client()
224
225	tlsConfig := &tls.Config{}
226
227	if len(c.TLSCAData) > 0 {
228		tlsConfig.RootCAs = x509.NewCertPool()
229
230		ok := tlsConfig.RootCAs.AppendCertsFromPEM(c.TLSCAData)
231		if !ok {
232			return nil, fmt.Errorf("failed to append CA to client options")
233		}
234	}
235
236	if len(c.TLSCertificateKeyData) > 0 {
237		certificate, err := tls.X509KeyPair(c.TLSCertificateKeyData, c.TLSCertificateKeyData)
238		if err != nil {
239			return nil, fmt.Errorf("unable to load tls_certificate_key_data: %w", err)
240		}
241
242		opts.SetAuth(options.Credential{
243			AuthMechanism: "MONGODB-X509",
244			Username:      c.Username,
245		})
246
247		tlsConfig.Certificates = append(tlsConfig.Certificates, certificate)
248	}
249
250	opts.SetTLSConfig(tlsConfig)
251	return opts, nil
252}
253