1package influxdb
2
3import (
4	"context"
5	"fmt"
6	"strings"
7
8	multierror "github.com/hashicorp/go-multierror"
9	dbplugin "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
10	"github.com/hashicorp/vault/sdk/database/helper/dbutil"
11	"github.com/hashicorp/vault/sdk/helper/strutil"
12	"github.com/hashicorp/vault/sdk/helper/template"
13	influx "github.com/influxdata/influxdb/client/v2"
14)
15
// Default IFQL statements used when a request supplies none. Each is a
// template: {{username}} and {{password}} are substituted before execution.
const (
	defaultUserCreationIFQL           = `CREATE USER "{{username}}" WITH PASSWORD '{{password}}';`
	defaultUserDeletionIFQL           = `DROP USER "{{username}}";`
	defaultRootCredentialRotationIFQL = `SET PASSWORD FOR "{{username}}" = '{{password}}';`
)

// influxdbTypeName is the database type name reported by this plugin.
const influxdbTypeName = "influxdb"

// defaultUserNameTemplate renders "v_<display>_<role>_<random>_<unixtime>",
// with display and role names truncated to 15 characters and 20 random
// characters. The whole result is truncated to 100 characters, dashes are
// replaced with underscores, and it is lowercased.
const defaultUserNameTemplate = `{{ printf "v_%s_%s_%s_%s" (.DisplayName | truncate 15) (.RoleName | truncate 15) (random 20) (unix_time) | truncate 100 | replace "-" "_" | lowercase }}`
24
25var _ dbplugin.Database = &Influxdb{}
26
27// Influxdb is an implementation of Database interface
28type Influxdb struct {
29	*influxdbConnectionProducer
30
31	usernameProducer template.StringTemplate
32}
33
34// New returns a new Cassandra instance
35func New() (interface{}, error) {
36	db := new()
37	dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.secretValues)
38
39	return dbType, nil
40}
41
42func new() *Influxdb {
43	connProducer := &influxdbConnectionProducer{}
44	connProducer.Type = influxdbTypeName
45
46	return &Influxdb{
47		influxdbConnectionProducer: connProducer,
48	}
49}
50
51// Type returns the TypeName for this backend
52func (i *Influxdb) Type() (string, error) {
53	return influxdbTypeName, nil
54}
55
56func (i *Influxdb) getConnection(ctx context.Context) (influx.Client, error) {
57	cli, err := i.Connection(ctx)
58	if err != nil {
59		return nil, err
60	}
61
62	return cli.(influx.Client), nil
63}
64
65func (i *Influxdb) Initialize(ctx context.Context, req dbplugin.InitializeRequest) (resp dbplugin.InitializeResponse, err error) {
66	usernameTemplate, err := strutil.GetString(req.Config, "username_template")
67	if err != nil {
68		return dbplugin.InitializeResponse{}, fmt.Errorf("failed to retrieve username_template: %w", err)
69	}
70	if usernameTemplate == "" {
71		usernameTemplate = defaultUserNameTemplate
72	}
73
74	up, err := template.NewTemplate(template.Template(usernameTemplate))
75	if err != nil {
76		return dbplugin.InitializeResponse{}, fmt.Errorf("unable to initialize username template: %w", err)
77	}
78	i.usernameProducer = up
79
80	_, err = i.usernameProducer.Generate(dbplugin.UsernameMetadata{})
81	if err != nil {
82		return dbplugin.InitializeResponse{}, fmt.Errorf("invalid username template: %w", err)
83	}
84
85	return i.influxdbConnectionProducer.Initialize(ctx, req)
86}
87
88// NewUser generates the username/password on the underlying Influxdb secret backend as instructed by
89// the statements provided.
90func (i *Influxdb) NewUser(ctx context.Context, req dbplugin.NewUserRequest) (resp dbplugin.NewUserResponse, err error) {
91	i.Lock()
92	defer i.Unlock()
93
94	cli, err := i.getConnection(ctx)
95	if err != nil {
96		return dbplugin.NewUserResponse{}, fmt.Errorf("unable to get connection: %w", err)
97	}
98
99	creationIFQL := req.Statements.Commands
100	if len(creationIFQL) == 0 {
101		creationIFQL = []string{defaultUserCreationIFQL}
102	}
103
104	rollbackIFQL := req.RollbackStatements.Commands
105	if len(rollbackIFQL) == 0 {
106		rollbackIFQL = []string{defaultUserDeletionIFQL}
107	}
108
109	username, err := i.usernameProducer.Generate(req.UsernameConfig)
110	if err != nil {
111		return dbplugin.NewUserResponse{}, err
112	}
113
114	for _, stmt := range creationIFQL {
115		for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
116			query = strings.TrimSpace(query)
117			if len(query) == 0 {
118				continue
119			}
120
121			m := map[string]string{
122				"username": username,
123				"password": req.Password,
124			}
125			qry := influx.NewQuery(dbutil.QueryHelper(query, m), "", "")
126			response, err := cli.Query(qry)
127			// err can be nil with response.Error() being not nil, so both need to be handled
128			merr := multierror.Append(err, response.Error())
129			if merr.ErrorOrNil() != nil {
130				// Attempt rollback only when the response has an error
131				if response != nil && response.Error() != nil {
132					attemptRollback(cli, username, rollbackIFQL)
133				}
134
135				return dbplugin.NewUserResponse{}, fmt.Errorf("failed to run query in InfluxDB: %w", merr)
136			}
137		}
138	}
139	resp = dbplugin.NewUserResponse{
140		Username: username,
141	}
142	return resp, nil
143}
144
145// attemptRollback will attempt to roll back user creation if an error occurs in
146// CreateUser
147func attemptRollback(cli influx.Client, username string, rollbackStatements []string) error {
148	for _, stmt := range rollbackStatements {
149		for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
150			query = strings.TrimSpace(query)
151
152			if len(query) == 0 {
153				continue
154			}
155			q := influx.NewQuery(dbutil.QueryHelper(query, map[string]string{
156				"username": username,
157			}), "", "")
158
159			response, err := cli.Query(q)
160			// err can be nil with response.Error() being not nil, so both need to be handled
161			merr := multierror.Append(err, response.Error())
162			if merr.ErrorOrNil() != nil {
163				return merr
164			}
165		}
166	}
167	return nil
168}
169
170func (i *Influxdb) DeleteUser(ctx context.Context, req dbplugin.DeleteUserRequest) (dbplugin.DeleteUserResponse, error) {
171	i.Lock()
172	defer i.Unlock()
173
174	cli, err := i.getConnection(ctx)
175	if err != nil {
176		return dbplugin.DeleteUserResponse{}, fmt.Errorf("unable to get connection: %w", err)
177	}
178
179	revocationIFQL := req.Statements.Commands
180	if len(revocationIFQL) == 0 {
181		revocationIFQL = []string{defaultUserDeletionIFQL}
182	}
183
184	var result *multierror.Error
185	for _, stmt := range revocationIFQL {
186		for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
187			query = strings.TrimSpace(query)
188			if len(query) == 0 {
189				continue
190			}
191			m := map[string]string{
192				"username": req.Username,
193			}
194			q := influx.NewQuery(dbutil.QueryHelper(query, m), "", "")
195			response, err := cli.Query(q)
196			result = multierror.Append(result, err)
197			if response != nil {
198				result = multierror.Append(result, response.Error())
199			}
200		}
201	}
202	if result.ErrorOrNil() != nil {
203		return dbplugin.DeleteUserResponse{}, fmt.Errorf("failed to delete user cleanly: %w", result.ErrorOrNil())
204	}
205	return dbplugin.DeleteUserResponse{}, nil
206}
207
208func (i *Influxdb) UpdateUser(ctx context.Context, req dbplugin.UpdateUserRequest) (dbplugin.UpdateUserResponse, error) {
209	if req.Password == nil && req.Expiration == nil {
210		return dbplugin.UpdateUserResponse{}, fmt.Errorf("no changes requested")
211	}
212
213	i.Lock()
214	defer i.Unlock()
215
216	if req.Password != nil {
217		err := i.changeUserPassword(ctx, req.Username, req.Password)
218		if err != nil {
219			return dbplugin.UpdateUserResponse{}, fmt.Errorf("failed to change %q password: %w", req.Username, err)
220		}
221	}
222	// Expiration is a no-op
223	return dbplugin.UpdateUserResponse{}, nil
224}
225
226func (i *Influxdb) changeUserPassword(ctx context.Context, username string, changePassword *dbplugin.ChangePassword) error {
227	cli, err := i.getConnection(ctx)
228	if err != nil {
229		return fmt.Errorf("unable to get connection: %w", err)
230	}
231
232	rotateIFQL := changePassword.Statements.Commands
233	if len(rotateIFQL) == 0 {
234		rotateIFQL = []string{defaultRootCredentialRotationIFQL}
235	}
236
237	var result *multierror.Error
238	for _, stmt := range rotateIFQL {
239		for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
240			query = strings.TrimSpace(query)
241			if len(query) == 0 {
242				continue
243			}
244			m := map[string]string{
245				"username": username,
246				"password": changePassword.NewPassword,
247			}
248			q := influx.NewQuery(dbutil.QueryHelper(query, m), "", "")
249			response, err := cli.Query(q)
250			result = multierror.Append(result, err)
251			if response != nil {
252				result = multierror.Append(result, response.Error())
253			}
254		}
255	}
256
257	err = result.ErrorOrNil()
258	if err != nil {
259		return fmt.Errorf("failed to execute rotation queries: %w", err)
260	}
261
262	return nil
263}
264