1// Package kms provides Key Management Services support
2package kms
3
4import (
5	"encoding/json"
6	"errors"
7	"os"
8	"strings"
9	"sync"
10
11	"github.com/drakkan/sftpgo/v2/logger"
12	"github.com/drakkan/sftpgo/v2/util"
13)
14
15// SecretProvider defines the interface for a KMS secrets provider
16type SecretProvider interface {
17	Name() string
18	Encrypt() error
19	Decrypt() error
20	IsEncrypted() bool
21	GetStatus() SecretStatus
22	GetPayload() string
23	GetKey() string
24	GetAdditionalData() string
25	GetMode() int
26	SetKey(string)
27	SetAdditionalData(string)
28	SetStatus(SecretStatus)
29	Clone() SecretProvider
30}
31
// logSender is the sender name this package passes to every logger call.
const logSender = "kms"
35
// SecretStatus names the state of a Secret: plain text, redacted, or
// encrypted by a specific backend. It is an alias of string, so plain
// string values are accepted wherever a SecretStatus is expected.
type SecretStatus = string

// Secret statuses known to this package.
const (
	// SecretStatusPlain means the secret is in plain text and must be encrypted
	SecretStatusPlain SecretStatus = "Plain"
	// SecretStatusRedacted means the secret is redacted
	SecretStatusRedacted SecretStatus = "Redacted"

	// SecretStatusAES256GCM means the secret is encrypted using AES-256-GCM
	SecretStatusAES256GCM SecretStatus = "AES-256-GCM"
	// SecretStatusSecretBox means the secret is encrypted using a locally provided symmetric key
	SecretStatusSecretBox SecretStatus = "Secretbox"

	// SecretStatusAWS means we use customer master keys from Amazon Web Service's
	// Key Management Service (AWS KMS) to keep information secret
	SecretStatusAWS SecretStatus = "AWS"
	// SecretStatusGCP means we use keys from Google Cloud Platform's Key Management Service
	// (GCP KMS) to keep information secret
	SecretStatusGCP SecretStatus = "GCP"
	// SecretStatusVaultTransit means we use the transit secrets engine in Vault
	// to keep information secret
	SecretStatusVaultTransit SecretStatus = "VaultTransit"
	// SecretStatusAzureKeyVault means we use Azure KeyVault to keep information secret
	SecretStatusAzureKeyVault SecretStatus = "AzureKeyVault"
)
60
// Scheme is a URL scheme used to pick a secret provider. It is an alias of
// string, so plain strings can be used wherever a Scheme is expected.
type Scheme = string

// URL schemes for the supported KMS backends. A provider is chosen when the
// configured KMS URL starts with the scheme it was registered under;
// SchemeLocal is the default used when no URL is configured.
const (
	SchemeLocal         Scheme = "local"
	SchemeBuiltin       Scheme = "builtin"
	SchemeAWS           Scheme = "awskms"
	SchemeGCP           Scheme = "gcpkms"
	SchemeVaultTransit  Scheme = "hashivault"
	SchemeAzureKeyVault Scheme = "azurekeyvault"
)
73
type (
	// Configuration defines the KMS configuration.
	Configuration struct {
		Secrets Secrets `json:"secrets" mapstructure:"secrets"`
	}

	// Secrets defines the KMS settings used for encryption/decryption.
	//
	// URL selects the secret provider by scheme prefix. The effective master
	// key is stored in the unexported masterKey field, which JSON encoding
	// skips. Configuration.Initialize fills it from MasterKeyString or, when
	// that is empty, from the trimmed contents of the file at MasterKeyPath.
	Secrets struct {
		URL             string `json:"url" mapstructure:"url"`
		MasterKeyPath   string `json:"master_key_path" mapstructure:"master_key_path"`
		MasterKeyString string `json:"master_key" mapstructure:"master_key"`
		masterKey       string
	}
)
86
87type registeredSecretProvider struct {
88	encryptedStatus SecretStatus
89	newFn           func(base BaseSecret, url, masterKey string) SecretProvider
90}
91
92var (
93	// ErrWrongSecretStatus defines the error to return if the secret status is not appropriate
94	// for the request operation
95	ErrWrongSecretStatus = errors.New("wrong secret status")
96	// ErrInvalidSecret defines the error to return if a secret is not valid
97	ErrInvalidSecret       = errors.New("invalid secret")
98	errMalformedCiphertext = errors.New("malformed ciphertext")
99	validSecretStatuses    = []string{SecretStatusPlain, SecretStatusAES256GCM, SecretStatusSecretBox,
100		SecretStatusVaultTransit, SecretStatusAWS, SecretStatusGCP, SecretStatusRedacted}
101	config          Configuration
102	secretProviders = make(map[string]registeredSecretProvider)
103)
104
105// RegisterSecretProvider register a new secret provider
106func RegisterSecretProvider(scheme string, encryptedStatus SecretStatus, fn func(base BaseSecret, url, masterKey string) SecretProvider) {
107	secretProviders[scheme] = registeredSecretProvider{
108		encryptedStatus: encryptedStatus,
109		newFn:           fn,
110	}
111}
112
113// NewSecret builds a new Secret using the provided arguments
114func NewSecret(status SecretStatus, payload, key, data string) *Secret {
115	return config.newSecret(status, payload, key, data)
116}
117
118// NewEmptySecret returns an empty secret
119func NewEmptySecret() *Secret {
120	return NewSecret("", "", "", "")
121}
122
123// NewPlainSecret stores the give payload in a plain text secret
124func NewPlainSecret(payload string) *Secret {
125	return NewSecret(SecretStatusPlain, payload, "", "")
126}
127
128// GetSecretFromCompatString returns a secret from the previous format
129func GetSecretFromCompatString(secret string) (*Secret, error) {
130	plain, err := util.DecryptData(secret)
131	if err != nil {
132		return &Secret{}, errMalformedCiphertext
133	}
134	return NewSecret(SecretStatusPlain, plain, "", ""), nil
135}
136
137// Initialize configures the KMS support
138func (c *Configuration) Initialize() error {
139	if c.Secrets.MasterKeyString != "" {
140		c.Secrets.masterKey = c.Secrets.MasterKeyString
141	}
142	if c.Secrets.masterKey == "" && c.Secrets.MasterKeyPath != "" {
143		mKey, err := os.ReadFile(c.Secrets.MasterKeyPath)
144		if err != nil {
145			return err
146		}
147		c.Secrets.masterKey = strings.TrimSpace(string(mKey))
148	}
149	config = *c
150	if config.Secrets.URL == "" {
151		config.Secrets.URL = SchemeLocal + "://"
152	}
153	for k, v := range secretProviders {
154		logger.Debug(logSender, "", "secret provider registered for scheme: %#v, encrypted status: %#v",
155			k, v.encryptedStatus)
156	}
157	return nil
158}
159
160func (c *Configuration) newSecret(status SecretStatus, payload, key, data string) *Secret {
161	base := BaseSecret{
162		Status:         status,
163		Key:            key,
164		Payload:        payload,
165		AdditionalData: data,
166	}
167	return &Secret{
168		provider: c.getSecretProvider(base),
169	}
170}
171
172func (c *Configuration) getSecretProvider(base BaseSecret) SecretProvider {
173	for k, v := range secretProviders {
174		if strings.HasPrefix(c.Secrets.URL, k) {
175			return v.newFn(base, c.Secrets.URL, c.Secrets.masterKey)
176		}
177	}
178	logger.Warn(logSender, "", "no secret provider registered for URL %v, fallback to local provider", c.Secrets.URL)
179	return NewLocalSecret(base, c.Secrets.URL, c.Secrets.masterKey)
180}
181
// Secret defines the struct used to store confidential data.
//
// All secret fields live in the delegated SecretProvider. The embedded
// RWMutex is taken by Secret's methods around every provider access: read
// lock for queries, write lock for mutations. Because the mutex is embedded,
// Lock/Unlock/RLock/RUnlock are part of Secret's exported method set.
// A zero Secret has a nil provider and its methods would panic, so build
// secrets with the New*Secret constructors.
type Secret struct {
	sync.RWMutex
	provider SecretProvider
}
187
188// MarshalJSON return the JSON encoding of the Secret object
189func (s *Secret) MarshalJSON() ([]byte, error) {
190	s.RLock()
191	defer s.RUnlock()
192
193	return json.Marshal(&BaseSecret{
194		Status:         s.provider.GetStatus(),
195		Payload:        s.provider.GetPayload(),
196		Key:            s.provider.GetKey(),
197		AdditionalData: s.provider.GetAdditionalData(),
198		Mode:           s.provider.GetMode(),
199	})
200}
201
202// UnmarshalJSON parses the JSON-encoded data and stores the result
203// in the Secret object
204func (s *Secret) UnmarshalJSON(data []byte) error {
205	s.Lock()
206	defer s.Unlock()
207
208	baseSecret := BaseSecret{}
209	err := json.Unmarshal(data, &baseSecret)
210	if err != nil {
211		return err
212	}
213	if baseSecret.isEmpty() {
214		s.provider = config.getSecretProvider(baseSecret)
215		return nil
216	}
217
218	if baseSecret.Status == SecretStatusPlain || baseSecret.Status == SecretStatusRedacted {
219		s.provider = config.getSecretProvider(baseSecret)
220		return nil
221	}
222
223	for _, v := range secretProviders {
224		if v.encryptedStatus == baseSecret.Status {
225			s.provider = v.newFn(baseSecret, config.Secrets.URL, config.Secrets.masterKey)
226			return nil
227		}
228	}
229	logger.Debug(logSender, "", "no provider registered for status %#v", baseSecret.Status)
230
231	return ErrInvalidSecret
232}
233
234// IsEqual returns true if all the secrets fields are equal
235func (s *Secret) IsEqual(other *Secret) bool {
236	if s.GetStatus() != other.GetStatus() {
237		return false
238	}
239	if s.GetPayload() != other.GetPayload() {
240		return false
241	}
242	if s.GetKey() != other.GetKey() {
243		return false
244	}
245	if s.GetAdditionalData() != other.GetAdditionalData() {
246		return false
247	}
248	if s.GetMode() != other.GetMode() {
249		return false
250	}
251	return true
252}
253
254// Clone returns a copy of the secret object
255func (s *Secret) Clone() *Secret {
256	s.RLock()
257	defer s.RUnlock()
258
259	return &Secret{
260		provider: s.provider.Clone(),
261	}
262}
263
264// IsEncrypted returns true if the secret is encrypted
265// This isn't a pointer receiver because we don't want to pass
266// a pointer to html template
267func (s *Secret) IsEncrypted() bool {
268	s.RLock()
269	defer s.RUnlock()
270
271	return s.provider.IsEncrypted()
272}
273
274// IsPlain returns true if the secret is in plain text
275func (s *Secret) IsPlain() bool {
276	s.RLock()
277	defer s.RUnlock()
278
279	return s.provider.GetStatus() == SecretStatusPlain
280}
281
282// IsNotPlainAndNotEmpty returns true if the secret is not plain and not empty.
283// This is an utility method, we update the secret for an existing user
284// if it is empty or plain
285func (s *Secret) IsNotPlainAndNotEmpty() bool {
286	s.RLock()
287	defer s.RUnlock()
288
289	return !s.IsPlain() && !s.IsEmpty()
290}
291
292// IsRedacted returns true if the secret is redacted
293func (s *Secret) IsRedacted() bool {
294	s.RLock()
295	defer s.RUnlock()
296
297	return s.provider.GetStatus() == SecretStatusRedacted
298}
299
300// GetPayload returns the secret payload
301func (s *Secret) GetPayload() string {
302	s.RLock()
303	defer s.RUnlock()
304
305	return s.provider.GetPayload()
306}
307
308// GetAdditionalData returns the secret additional data
309func (s *Secret) GetAdditionalData() string {
310	s.RLock()
311	defer s.RUnlock()
312
313	return s.provider.GetAdditionalData()
314}
315
316// GetStatus returns the secret status
317func (s *Secret) GetStatus() SecretStatus {
318	s.RLock()
319	defer s.RUnlock()
320
321	return s.provider.GetStatus()
322}
323
324// GetKey returns the secret key
325func (s *Secret) GetKey() string {
326	s.RLock()
327	defer s.RUnlock()
328
329	return s.provider.GetKey()
330}
331
332// GetMode returns the secret mode
333func (s *Secret) GetMode() int {
334	s.RLock()
335	defer s.RUnlock()
336
337	return s.provider.GetMode()
338}
339
340// SetAdditionalData sets the given additional data
341func (s *Secret) SetAdditionalData(value string) {
342	s.Lock()
343	defer s.Unlock()
344
345	s.provider.SetAdditionalData(value)
346}
347
348// SetStatus sets the status for this secret
349func (s *Secret) SetStatus(value SecretStatus) {
350	s.Lock()
351	defer s.Unlock()
352
353	s.provider.SetStatus(value)
354}
355
356// SetKey sets the key for this secret
357func (s *Secret) SetKey(value string) {
358	s.Lock()
359	defer s.Unlock()
360
361	s.provider.SetKey(value)
362}
363
364// IsEmpty returns true if all fields are empty
365func (s *Secret) IsEmpty() bool {
366	s.RLock()
367	defer s.RUnlock()
368
369	if s.provider.GetStatus() != "" {
370		return false
371	}
372	if s.provider.GetPayload() != "" {
373		return false
374	}
375	if s.provider.GetKey() != "" {
376		return false
377	}
378	if s.provider.GetAdditionalData() != "" {
379		return false
380	}
381	return true
382}
383
384// IsValid returns true if the secret is not empty and valid
385func (s *Secret) IsValid() bool {
386	s.RLock()
387	defer s.RUnlock()
388
389	if !s.IsValidInput() {
390		return false
391	}
392	switch s.provider.GetStatus() {
393	case SecretStatusAES256GCM, SecretStatusSecretBox:
394		if len(s.provider.GetKey()) != 64 {
395			return false
396		}
397	case SecretStatusAWS, SecretStatusGCP, SecretStatusVaultTransit:
398		key := s.provider.GetKey()
399		if key != "" && len(key) != 64 {
400			return false
401		}
402	}
403	return true
404}
405
406// IsValidInput returns true if the secret is a valid user input
407func (s *Secret) IsValidInput() bool {
408	s.RLock()
409	defer s.RUnlock()
410
411	if !util.IsStringInSlice(s.provider.GetStatus(), validSecretStatuses) {
412		return false
413	}
414	if s.provider.GetPayload() == "" {
415		return false
416	}
417	return true
418}
419
420// Hide hides info to decrypt data
421func (s *Secret) Hide() {
422	s.Lock()
423	defer s.Unlock()
424
425	s.provider.SetKey("")
426	s.provider.SetAdditionalData("")
427}
428
429// Encrypt encrypts a plain text Secret object
430func (s *Secret) Encrypt() error {
431	s.Lock()
432	defer s.Unlock()
433
434	return s.provider.Encrypt()
435}
436
437// Decrypt decrypts a Secret object
438func (s *Secret) Decrypt() error {
439	s.Lock()
440	defer s.Unlock()
441
442	return s.provider.Decrypt()
443}
444
445// TryDecrypt decrypts a Secret object if encrypted.
446// It returns a nil error if the object is not encrypted
447func (s *Secret) TryDecrypt() error {
448	s.Lock()
449	defer s.Unlock()
450
451	if s.provider.IsEncrypted() {
452		return s.provider.Decrypt()
453	}
454	return nil
455}
456