1package mysql
2
3import (
4	"context"
5	"database/sql"
6	"errors"
7	"strings"
8	"time"
9
10	stdmysql "github.com/go-sql-driver/mysql"
11	"github.com/hashicorp/vault/api"
12	"github.com/hashicorp/vault/sdk/database/dbplugin"
13	"github.com/hashicorp/vault/sdk/database/helper/connutil"
14	"github.com/hashicorp/vault/sdk/database/helper/credsutil"
15	"github.com/hashicorp/vault/sdk/database/helper/dbutil"
16	"github.com/hashicorp/vault/sdk/helper/strutil"
17)
18
const (
	// mySQLTypeName is the database type name reported by Type and assigned
	// to the connection producer.
	mySQLTypeName = "mysql"

	// defaultMysqlRevocationStmts is the revocation SQL used when a role
	// supplies none. Every {{name}} placeholder is replaced with the username
	// being revoked.
	defaultMysqlRevocationStmts = `
		REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'%';
		DROP USER '{{name}}'@'%'
	`

	// defaultMySQLRotateCredentialsSQL is the fallback SQL for changing a
	// user's password when the caller supplies no rotation statements. It
	// expects the {{username}} and {{password}} placeholders to be filled in.
	defaultMySQLRotateCredentialsSQL = `
		ALTER USER '{{username}}'@'%' IDENTIFIED BY '{{password}}';
	`
)
31
// Length settings passed to New by runCommon. The Legacy* values are the ones
// used when the plugin is started in legacy mode.
var (
	// MetadataLen is used for both the display-name and role-name lengths in
	// non-legacy mode.
	MetadataLen = 10
	// LegacyMetadataLen is the role-name length used in legacy mode.
	LegacyMetadataLen = 4
	// UsernameLen is the username length used in non-legacy mode.
	UsernameLen = 32
	// LegacyUsernameLen is the username length used in legacy mode.
	LegacyUsernameLen = 16
)
38
// Compile-time assertion that *MySQL satisfies the dbplugin.Database interface.
var _ dbplugin.Database = (*MySQL)(nil)

// MySQL is the MySQL database plugin. It embeds a SQL connection producer,
// which supplies the connection, the lock and the stored configuration used by
// the methods in this file, and a credentials producer, which generates
// usernames, passwords and expiration strings.
type MySQL struct {
	*connutil.SQLConnectionProducer
	credsutil.CredentialsProducer
}
45
46// New implements builtinplugins.BuiltinFactory
47func New(displayNameLen, roleNameLen, usernameLen int) func() (interface{}, error) {
48	return func() (interface{}, error) {
49		db := new(displayNameLen, roleNameLen, usernameLen)
50		// Wrap the plugin with middleware to sanitize errors
51		dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.SecretValues)
52
53		return dbType, nil
54	}
55}
56
57func new(displayNameLen, roleNameLen, usernameLen int) *MySQL {
58	connProducer := &connutil.SQLConnectionProducer{}
59	connProducer.Type = mySQLTypeName
60
61	credsProducer := &credsutil.SQLCredentialsProducer{
62		DisplayNameLen: displayNameLen,
63		RoleNameLen:    roleNameLen,
64		UsernameLen:    usernameLen,
65		Separator:      "-",
66	}
67
68	return &MySQL{
69		SQLConnectionProducer: connProducer,
70		CredentialsProducer:   credsProducer,
71	}
72}
73
// Run instantiates a MySQL object in non-legacy mode and runs the RPC server
// for the plugin, using apiTLSConfig to build the TLS provider.
func Run(apiTLSConfig *api.TLSConfig) error {
	return runCommon(false, apiTLSConfig)
}
78
// RunLegacy instantiates a MySQL object in legacy mode and runs the RPC server
// for the plugin, using apiTLSConfig to build the TLS provider.
func RunLegacy(apiTLSConfig *api.TLSConfig) error {
	return runCommon(true, apiTLSConfig)
}
83
84func runCommon(legacy bool, apiTLSConfig *api.TLSConfig) error {
85	var f func() (interface{}, error)
86	if legacy {
87		f = New(credsutil.NoneLength, LegacyMetadataLen, LegacyUsernameLen)
88	} else {
89		f = New(MetadataLen, MetadataLen, UsernameLen)
90	}
91	dbType, err := f()
92	if err != nil {
93		return err
94	}
95
96	dbplugin.Serve(dbType.(dbplugin.Database), api.VaultPluginTLSProvider(apiTLSConfig))
97
98	return nil
99}
100
// Type returns the database type name, mySQLTypeName. The error is always nil.
func (m *MySQL) Type() (string, error) {
	return mySQLTypeName, nil
}
104
105func (m *MySQL) getConnection(ctx context.Context) (*sql.DB, error) {
106	db, err := m.Connection(ctx)
107	if err != nil {
108		return nil, err
109	}
110
111	return db.(*sql.DB), nil
112}
113
114func (m *MySQL) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
115	statements = dbutil.StatementCompatibilityHelper(statements)
116
117	if len(statements.Creation) == 0 {
118		return "", "", dbutil.ErrEmptyCreationStatement
119	}
120
121	username, err = m.GenerateUsername(usernameConfig)
122	if err != nil {
123		return "", "", err
124	}
125
126	password, err = m.GeneratePassword()
127	if err != nil {
128		return "", "", err
129	}
130
131	expirationStr, err := m.GenerateExpiration(expiration)
132	if err != nil {
133		return "", "", err
134	}
135
136	queryMap := map[string]string{
137		"name":       username,
138		"password":   password,
139		"expiration": expirationStr,
140	}
141
142	if err := m.executePreparedStatmentsWithMap(ctx, statements.Creation, queryMap); err != nil {
143		return "", "", err
144	}
145	return username, password, nil
146}
147
// RenewUser is a no-op for MySQL: it ignores all of its arguments and always
// returns nil.
func (m *MySQL) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
	return nil
}
152
153func (m *MySQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
154	// Grab the read lock
155	m.Lock()
156	defer m.Unlock()
157
158	statements = dbutil.StatementCompatibilityHelper(statements)
159
160	// Get the connection
161	db, err := m.getConnection(ctx)
162	if err != nil {
163		return err
164	}
165
166	revocationStmts := statements.Revocation
167	// Use a default SQL statement for revocation if one cannot be fetched from the role
168	if len(revocationStmts) == 0 {
169		revocationStmts = []string{defaultMysqlRevocationStmts}
170	}
171
172	// Start a transaction
173	tx, err := db.BeginTx(ctx, nil)
174	if err != nil {
175		return err
176	}
177	defer tx.Rollback()
178
179	for _, stmt := range revocationStmts {
180		for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
181			query = strings.TrimSpace(query)
182			if len(query) == 0 {
183				continue
184			}
185
186			// This is not a prepared statement because not all commands are supported
187			// 1295: This command is not supported in the prepared statement protocol yet
188			// Reference https://mariadb.com/kb/en/mariadb/prepare-statement/
189			query = strings.Replace(query, "{{name}}", username, -1)
190			_, err = tx.ExecContext(ctx, query)
191			if err != nil {
192				return err
193			}
194		}
195	}
196
197	// Commit the transaction
198	if err := tx.Commit(); err != nil {
199		return err
200	}
201
202	return nil
203}
204
205func (m *MySQL) RotateRootCredentials(ctx context.Context, statements []string) (map[string]interface{}, error) {
206	m.Lock()
207	defer m.Unlock()
208
209	if len(m.Username) == 0 || len(m.Password) == 0 {
210		return nil, errors.New("username and password are required to rotate")
211	}
212
213	rotateStatements := statements
214	if len(rotateStatements) == 0 {
215		rotateStatements = []string{defaultMySQLRotateCredentialsSQL}
216	}
217
218	db, err := m.getConnection(ctx)
219	if err != nil {
220		return nil, err
221	}
222
223	tx, err := db.BeginTx(ctx, nil)
224	if err != nil {
225		return nil, err
226	}
227	defer func() {
228		tx.Rollback()
229	}()
230
231	password, err := m.GeneratePassword()
232	if err != nil {
233		return nil, err
234	}
235
236	for _, stmt := range rotateStatements {
237		for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
238			query = strings.TrimSpace(query)
239			if len(query) == 0 {
240				continue
241			}
242
243			// This is not a prepared statement because not all commands are supported
244			// 1295: This command is not supported in the prepared statement protocol yet
245			// Reference https://mariadb.com/kb/en/mariadb/prepare-statement/
246			query = strings.Replace(query, "{{username}}", m.Username, -1)
247			query = strings.Replace(query, "{{password}}", password, -1)
248
249			if _, err := tx.ExecContext(ctx, query); err != nil {
250				return nil, err
251			}
252		}
253	}
254
255	if err := tx.Commit(); err != nil {
256		return nil, err
257	}
258
259	if err := db.Close(); err != nil {
260		return nil, err
261	}
262
263	m.RawConfig["password"] = password
264	return m.RawConfig, nil
265}
266
267// SetCredentials uses provided information to set the password to a user in the
268// database. Unlike CreateUser, this method requires a username be provided and
269// uses the name given, instead of generating a name. This is used for setting
270// the password of static accounts, as well as rolling back passwords in the
271// database in the event an updated database fails to save in Vault's storage.
272func (m *MySQL) SetCredentials(ctx context.Context, statements dbplugin.Statements, staticUser dbplugin.StaticUserConfig) (username, password string, err error) {
273	rotateStatements := statements.Rotation
274	if len(rotateStatements) == 0 {
275		rotateStatements = []string{defaultMySQLRotateCredentialsSQL}
276	}
277
278	username = staticUser.Username
279	password = staticUser.Password
280	if username == "" || password == "" {
281		return "", "", errors.New("must provide both username and password")
282	}
283
284	queryMap := map[string]string{
285		"name":     username,
286		"password": password,
287	}
288
289	if err := m.executePreparedStatmentsWithMap(ctx, statements.Rotation, queryMap); err != nil {
290		return "", "", err
291	}
292	return username, password, nil
293}
294
295// executePreparedStatmentsWithMap loops through the given templated SQL statements and
296// applies the a map to them, interpolating values into the templates,returning
297// tthe resulting username and password
298func (m *MySQL) executePreparedStatmentsWithMap(ctx context.Context, statements []string, queryMap map[string]string) error {
299	// Grab the lock
300	m.Lock()
301	defer m.Unlock()
302
303	// Get the connection
304	db, err := m.getConnection(ctx)
305	if err != nil {
306		return err
307	}
308	// Start a transaction
309	tx, err := db.BeginTx(ctx, nil)
310	if err != nil {
311		return err
312	}
313	defer func() {
314		_ = tx.Rollback()
315	}()
316
317	// Execute each query
318	for _, stmt := range statements {
319		for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
320			query = strings.TrimSpace(query)
321			if len(query) == 0 {
322				continue
323			}
324
325			query = dbutil.QueryHelper(query, queryMap)
326
327			stmt, err := tx.PrepareContext(ctx, query)
328			if err != nil {
329				// If the error code we get back is Error 1295: This command is not
330				// supported in the prepared statement protocol yet, we will execute
331				// the statement without preparing it. This allows the caller to
332				// manually prepare statements, as well as run other not yet
333				// prepare supported commands. If there is no error when running we
334				// will continue to the next statement.
335				if e, ok := err.(*stdmysql.MySQLError); ok && e.Number == 1295 {
336					_, err = tx.ExecContext(ctx, query)
337					if err != nil {
338						stmt.Close()
339						return err
340					}
341					continue
342				}
343
344				return err
345			}
346			if _, err := stmt.ExecContext(ctx); err != nil {
347				stmt.Close()
348				return err
349			}
350			stmt.Close()
351		}
352	}
353
354	// Commit the transaction
355	if err := tx.Commit(); err != nil {
356		return err
357	}
358	return nil
359}
360