1package mssql 2 3import ( 4 "context" 5 "database/sql" 6 "fmt" 7 "sort" 8 "strconv" 9 "strings" 10 "time" 11 12 metrics "github.com/armon/go-metrics" 13 _ "github.com/denisenkom/go-mssqldb" 14 "github.com/hashicorp/errwrap" 15 log "github.com/hashicorp/go-hclog" 16 "github.com/hashicorp/vault/sdk/helper/strutil" 17 "github.com/hashicorp/vault/sdk/physical" 18) 19 20// Verify MSSQLBackend satisfies the correct interfaces 21var _ physical.Backend = (*MSSQLBackend)(nil) 22 23type MSSQLBackend struct { 24 dbTable string 25 client *sql.DB 26 statements map[string]*sql.Stmt 27 logger log.Logger 28 permitPool *physical.PermitPool 29} 30 31func NewMSSQLBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) { 32 username, ok := conf["username"] 33 if !ok { 34 username = "" 35 } 36 37 password, ok := conf["password"] 38 if !ok { 39 password = "" 40 } 41 42 server, ok := conf["server"] 43 if !ok || server == "" { 44 return nil, fmt.Errorf("missing server") 45 } 46 47 port, ok := conf["port"] 48 if !ok { 49 port = "" 50 } 51 52 maxParStr, ok := conf["max_parallel"] 53 var maxParInt int 54 var err error 55 if ok { 56 maxParInt, err = strconv.Atoi(maxParStr) 57 if err != nil { 58 return nil, errwrap.Wrapf("failed parsing max_parallel parameter: {{err}}", err) 59 } 60 if logger.IsDebug() { 61 logger.Debug("max_parallel set", "max_parallel", maxParInt) 62 } 63 } else { 64 maxParInt = physical.DefaultParallelOperations 65 } 66 67 database, ok := conf["database"] 68 if !ok { 69 database = "Vault" 70 } 71 72 table, ok := conf["table"] 73 if !ok { 74 table = "Vault" 75 } 76 77 appname, ok := conf["appname"] 78 if !ok { 79 appname = "Vault" 80 } 81 82 connectionTimeout, ok := conf["connectiontimeout"] 83 if !ok { 84 connectionTimeout = "30" 85 } 86 87 logLevel, ok := conf["loglevel"] 88 if !ok { 89 logLevel = "0" 90 } 91 92 schema, ok := conf["schema"] 93 if !ok || schema == "" { 94 schema = "dbo" 95 } 96 97 connectionString := fmt.Sprintf("server=%s;app name=%s;connection timeout=%s;log=%s", server, appname, connectionTimeout, logLevel) 98 if username != "" { 99 connectionString += ";user id=" + username 100 } 101 102 if password != "" { 103 connectionString += ";password=" + password 104 } 105 106 if port != "" { 107 connectionString += ";port=" + port 108 } 109 110 db, err := sql.Open("mssql", connectionString) 111 if err != nil { 112 return nil, errwrap.Wrapf("failed to connect to mssql: {{err}}", err) 113 } 114 115 db.SetMaxOpenConns(maxParInt) 116 117 if _, err := db.Exec("IF NOT EXISTS(SELECT * FROM sys.databases WHERE name = '" + database + "') CREATE DATABASE " + database); err != nil { 118 return nil, errwrap.Wrapf("failed to create mssql database: {{err}}", err) 119 } 120 121 dbTable := database + "." + schema + "." + table 122 createQuery := "IF NOT EXISTS(SELECT 1 FROM " + database + ".INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE='BASE TABLE' AND TABLE_NAME='" + table + "' AND TABLE_SCHEMA='" + schema + 123 "') CREATE TABLE " + dbTable + " (Path VARCHAR(512) PRIMARY KEY, Value VARBINARY(MAX))" 124 125 if schema != "dbo" { 126 127 var num int 128 err = db.QueryRow("SELECT 1 FROM " + database + ".sys.schemas WHERE name = '" + schema + "'").Scan(&num) 129 130 switch { 131 case err == sql.ErrNoRows: 132 if _, err := db.Exec("USE " + database + "; EXEC ('CREATE SCHEMA " + schema + "')"); err != nil { 133 return nil, errwrap.Wrapf("failed to create mssql schema: {{err}}", err) 134 } 135 136 case err != nil: 137 return nil, errwrap.Wrapf("failed to check if mssql schema exists: {{err}}", err) 138 } 139 } 140 141 if _, err := db.Exec(createQuery); err != nil { 142 return nil, errwrap.Wrapf("failed to create mssql table: {{err}}", err) 143 } 144 145 m := &MSSQLBackend{ 146 dbTable: dbTable, 147 client: db, 148 statements: make(map[string]*sql.Stmt), 149 logger: logger, 150 permitPool: physical.NewPermitPool(maxParInt), 151 } 152 153 statements := map[string]string{ 154 "put": "IF EXISTS(SELECT 1 FROM " + dbTable + " WHERE Path = ?) UPDATE " + dbTable + " SET Value = ? WHERE Path = ?" + 155 " ELSE INSERT INTO " + dbTable + " VALUES(?, ?)", 156 "get": "SELECT Value FROM " + dbTable + " WHERE Path = ?", 157 "delete": "DELETE FROM " + dbTable + " WHERE Path = ?", 158 "list": "SELECT Path FROM " + dbTable + " WHERE Path LIKE ?", 159 } 160 161 for name, query := range statements { 162 if err := m.prepare(name, query); err != nil { 163 return nil, err 164 } 165 } 166 167 return m, nil 168} 169 170func (m *MSSQLBackend) prepare(name, query string) error { 171 stmt, err := m.client.Prepare(query) 172 if err != nil { 173 return errwrap.Wrapf(fmt.Sprintf("failed to prepare %q: {{err}}", name), err) 174 } 175 176 m.statements[name] = stmt 177 178 return nil 179} 180 181func (m *MSSQLBackend) Put(ctx context.Context, entry *physical.Entry) error { 182 defer metrics.MeasureSince([]string{"mssql", "put"}, time.Now()) 183 184 m.permitPool.Acquire() 185 defer m.permitPool.Release() 186 187 _, err := m.statements["put"].Exec(entry.Key, entry.Value, entry.Key, entry.Key, entry.Value) 188 if err != nil { 189 return err 190 } 191 192 return nil 193} 194 195func (m *MSSQLBackend) Get(ctx context.Context, key string) (*physical.Entry, error) { 196 defer metrics.MeasureSince([]string{"mssql", "get"}, time.Now()) 197 198 m.permitPool.Acquire() 199 defer m.permitPool.Release() 200 201 var result []byte 202 err := m.statements["get"].QueryRow(key).Scan(&result) 203 if err == sql.ErrNoRows { 204 return nil, nil 205 } 206 207 if err != nil { 208 return nil, err 209 } 210 211 ent := &physical.Entry{ 212 Key: key, 213 Value: result, 214 } 215 216 return ent, nil 217} 218 219func (m *MSSQLBackend) Delete(ctx context.Context, key string) error { 220 defer metrics.MeasureSince([]string{"mssql", "delete"}, time.Now()) 221 222 m.permitPool.Acquire() 223 defer m.permitPool.Release() 224 225 _, err := m.statements["delete"].Exec(key) 226 if err != nil { 227 return err 228 } 229 230 return nil 231} 232 233func (m *MSSQLBackend) List(ctx context.Context, prefix string) ([]string, error) { 234 defer metrics.MeasureSince([]string{"mssql", "list"}, time.Now()) 235 236 m.permitPool.Acquire() 237 defer m.permitPool.Release() 238 239 likePrefix := prefix + "%" 240 rows, err := m.statements["list"].Query(likePrefix) 241 if err != nil { 242 return nil, err 243 } 244 var keys []string 245 for rows.Next() { 246 var key string 247 err = rows.Scan(&key) 248 if err != nil { 249 return nil, errwrap.Wrapf("failed to scan rows: {{err}}", err) 250 } 251 252 key = strings.TrimPrefix(key, prefix) 253 if i := strings.Index(key, "/"); i == -1 { 254 keys = append(keys, key) 255 } else if i != -1 { 256 keys = strutil.AppendIfMissing(keys, string(key[:i+1])) 257 } 258 } 259 260 sort.Strings(keys) 261 262 return keys, nil 263} 264