1package salt
2
3import (
4	"context"
5	"crypto/hmac"
6	"crypto/sha1"
7	"crypto/sha256"
8	"encoding/hex"
9	"fmt"
10	"hash"
11
12	"github.com/hashicorp/errwrap"
13	uuid "github.com/hashicorp/go-uuid"
14	"github.com/hashicorp/vault/sdk/logical"
15)
16
// DefaultLocation is the storage path under which the salt is kept when
// Config.Location is left empty.
const DefaultLocation = "salt"
22
23// Salt is used to manage a persistent salt key which is used to
24// hash values. This allows keys to be generated and recovered
25// using the global salt. Primarily, this allows paths in the storage
26// backend to be obfuscated if they may contain sensitive information.
27type Salt struct {
28	config    *Config
29	salt      string
30	generated bool
31}
32
33type HashFunc func([]byte) []byte
34
35// Config is used to parameterize the Salt
36type Config struct {
37	// Location is the path in the storage backend for the
38	// salt. Uses DefaultLocation if not specified.
39	Location string
40
41	// HashFunc is the hashing function to use for salting.
42	// Defaults to SHA1 if not provided.
43	HashFunc HashFunc
44
45	// HMAC allows specification of a hash function to use for
46	// the HMAC helpers
47	HMAC func() hash.Hash
48
49	// String prepended to HMAC strings for identification.
50	// Required if using HMAC
51	HMACType string
52}
53
54// NewSalt creates a new salt based on the configuration
55func NewSalt(ctx context.Context, view logical.Storage, config *Config) (*Salt, error) {
56	// Setup the configuration
57	if config == nil {
58		config = &Config{}
59	}
60	if config.Location == "" {
61		config.Location = DefaultLocation
62	}
63	if config.HashFunc == nil {
64		config.HashFunc = SHA256Hash
65	}
66	if config.HMAC == nil {
67		config.HMAC = sha256.New
68		config.HMACType = "hmac-sha256"
69	}
70
71	// Create the salt
72	s := &Salt{
73		config: config,
74	}
75
76	// Look for the salt
77	var raw *logical.StorageEntry
78	var err error
79	if view != nil {
80		raw, err = view.Get(ctx, config.Location)
81		if err != nil {
82			return nil, errwrap.Wrapf("failed to read salt: {{err}}", err)
83		}
84	}
85
86	// Restore the salt if it exists
87	if raw != nil {
88		s.salt = string(raw.Value)
89	}
90
91	// Generate a new salt if necessary
92	if s.salt == "" {
93		s.salt, err = uuid.GenerateUUID()
94		if err != nil {
95			return nil, errwrap.Wrapf("failed to generate uuid: {{err}}", err)
96		}
97		s.generated = true
98		if view != nil {
99			raw := &logical.StorageEntry{
100				Key:   config.Location,
101				Value: []byte(s.salt),
102			}
103			if err := view.Put(ctx, raw); err != nil {
104				return nil, errwrap.Wrapf("failed to persist salt: {{err}}", err)
105			}
106		}
107	}
108
109	if config.HMAC != nil {
110		if len(config.HMACType) == 0 {
111			return nil, fmt.Errorf("HMACType must be defined")
112		}
113	}
114
115	return s, nil
116}
117
118// SaltID is used to apply a salt and hash function to an ID to make sure
119// it is not reversible
120func (s *Salt) SaltID(id string) string {
121	return SaltID(s.salt, id, s.config.HashFunc)
122}
123
// GetHMAC computes an HMAC over data, keyed with this Salt's secret value
// and built on the hash constructor in Config.HMAC, and returns the MAC
// hex-encoded.
func (s *Salt) GetHMAC(data string) string {
	key := []byte(s.salt)
	mac := hmac.New(s.config.HMAC, key)
	// hash.Hash.Write is documented never to return an error.
	_, _ = mac.Write([]byte(data))
	digest := mac.Sum(nil)
	return hex.EncodeToString(digest)
}
131
132// GetIdentifiedHMAC is used to apply a salt and hash function to data to make
133// sure it is not reversible, with an additional HMAC, and ID prepended
134func (s *Salt) GetIdentifiedHMAC(data string) string {
135	return s.config.HMACType + ":" + s.GetHMAC(data)
136}
137
138// DidGenerate returns if the underlying salt value was generated
139// on initialization or if an existing salt value was loaded
140func (s *Salt) DidGenerate() bool {
141	return s.generated
142}
143
144// SaltIDHashFunc uses the supplied hash function instead of the configured
145// hash func in the salt.
146func (s *Salt) SaltIDHashFunc(id string, hashFunc HashFunc) string {
147	return SaltID(s.salt, id, hashFunc)
148}
149
// SaltID hashes salt concatenated with id using hash and returns the digest
// hex-encoded. Without knowing salt, the output cannot be used to recover
// id.
func SaltID(salt, id string, hash HashFunc) string {
	// salt and id are joined with no separator; changing this layout would
	// change every value derived through this function.
	digest := hash([]byte(salt + id))
	return hex.EncodeToString(digest)
}
157
// HMACValue computes the HMAC of val keyed with salt, using hashFunc as the
// underlying hash constructor, and returns it hex-encoded.
func HMACValue(salt, val string, hashFunc func() hash.Hash) string {
	key := []byte(salt)
	mac := hmac.New(hashFunc, key)
	// hash.Hash.Write is documented never to return an error.
	_, _ = mac.Write([]byte(val))
	digest := mac.Sum(nil)
	return hex.EncodeToString(digest)
}
163
// HMACIdentifiedValue returns HMACValue(salt, val, hashFunc) prefixed with
// hmacType and a colon, so the algorithm travels with the MAC.
func HMACIdentifiedValue(salt, val, hmacType string, hashFunc func() hash.Hash) string {
	mac := HMACValue(salt, val, hashFunc)
	return hmacType + ":" + mac
}
167
// SHA1Hash returns the 20-byte SHA-1 digest of inp. SHA-1 is no longer
// collision resistant; prefer SHA256Hash for new uses.
func SHA1Hash(inp []byte) []byte {
	h := sha1.New()
	h.Write(inp)
	return h.Sum(nil)
}
173
// SHA256Hash returns the 32-byte SHA-256 digest of inp.
func SHA256Hash(inp []byte) []byte {
	h := sha256.New()
	h.Write(inp)
	return h.Sum(nil)
}
179