1package credsutil
2
import (
	"context"
	"fmt"
	"time"
	"unicode/utf8"

	"github.com/hashicorp/vault/sdk/database/dbplugin"
)
10
// NoneLength is a sentinel length. Assigning it to
// SQLCredentialsProducer.DisplayNameLen or RoleNameLen makes GenerateUsername
// leave that segment (and its separator) out of the username entirely.
const NoneLength int = -1
14
// SQLCredentialsProducer implements CredentialsProducer and provides a generic
// credentials producer for most SQL database types.
//
// The fields tune how GenerateUsername shapes usernames:
//   - DisplayNameLen / RoleNameLen: a positive value caps the byte length of
//     the display-name / role-name segment, NoneLength drops the segment, and
//     any other value leaves it uncapped.
//   - UsernameLen: a positive value caps the byte length of the whole
//     username; otherwise it is uncapped.
//   - Separator: the string placed between username segments.
type SQLCredentialsProducer struct {
	DisplayNameLen, RoleNameLen, UsernameLen int
	Separator                                string
}
22
23func (scp *SQLCredentialsProducer) GenerateCredentials(ctx context.Context) (string, error) {
24	password, err := scp.GeneratePassword()
25	if err != nil {
26		return "", err
27	}
28	return password, nil
29}
30
31func (scp *SQLCredentialsProducer) GenerateUsername(config dbplugin.UsernameConfig) (string, error) {
32	username := "v"
33
34	displayName := config.DisplayName
35	if scp.DisplayNameLen > 0 && len(displayName) > scp.DisplayNameLen {
36		displayName = displayName[:scp.DisplayNameLen]
37	} else if scp.DisplayNameLen == NoneLength {
38		displayName = ""
39	}
40
41	if len(displayName) > 0 {
42		username = fmt.Sprintf("%s%s%s", username, scp.Separator, displayName)
43	}
44
45	roleName := config.RoleName
46	if scp.RoleNameLen > 0 && len(roleName) > scp.RoleNameLen {
47		roleName = roleName[:scp.RoleNameLen]
48	} else if scp.RoleNameLen == NoneLength {
49		roleName = ""
50	}
51
52	if len(roleName) > 0 {
53		username = fmt.Sprintf("%s%s%s", username, scp.Separator, roleName)
54	}
55
56	userUUID, err := RandomAlphaNumeric(20, false)
57	if err != nil {
58		return "", err
59	}
60
61	username = fmt.Sprintf("%s%s%s", username, scp.Separator, userUUID)
62	username = fmt.Sprintf("%s%s%s", username, scp.Separator, fmt.Sprint(time.Now().Unix()))
63	if scp.UsernameLen > 0 && len(username) > scp.UsernameLen {
64		username = username[:scp.UsernameLen]
65	}
66
67	return username, nil
68}
69
70func (scp *SQLCredentialsProducer) GeneratePassword() (string, error) {
71	password, err := RandomAlphaNumeric(20, true)
72	if err != nil {
73		return "", err
74	}
75
76	return password, nil
77}
78
79func (scp *SQLCredentialsProducer) GenerateExpiration(ttl time.Time) (string, error) {
80	return ttl.Format("2006-01-02 15:04:05-0700"), nil
81}
82