1package encryption
2
3import (
4	"crypto/aes"
5	"crypto/cipher"
6	"crypto/rand"
7	"encoding/base64"
8	"fmt"
9	"io"
10)
11
// Cipher provides methods to encrypt and decrypt byte slices.
//
// The implementations visible in this file are symmetric: the output of
// Encrypt is the input expected by Decrypt of the same Cipher.
type Cipher interface {
	// Encrypt returns the encrypted form of value, or an error.
	Encrypt(value []byte) ([]byte, error)
	// Decrypt returns the plaintext recovered from ciphertext, or an error.
	Decrypt(ciphertext []byte) ([]byte, error)
}
17
18type base64Cipher struct {
19	Cipher Cipher
20}
21
22// NewBase64Cipher returns a new AES Cipher for encrypting cookie values
23// and wrapping them in Base64 -- Supports Legacy encryption scheme
24func NewBase64Cipher(c Cipher) Cipher {
25	return &base64Cipher{Cipher: c}
26}
27
28// Encrypt encrypts a value with the embedded Cipher & Base64 encodes it
29func (c *base64Cipher) Encrypt(value []byte) ([]byte, error) {
30	encrypted, err := c.Cipher.Encrypt(value)
31	if err != nil {
32		return nil, err
33	}
34
35	return []byte(base64.StdEncoding.EncodeToString(encrypted)), nil
36}
37
38// Decrypt Base64 decodes a value & decrypts it with the embedded Cipher
39func (c *base64Cipher) Decrypt(ciphertext []byte) ([]byte, error) {
40	encrypted, err := base64.StdEncoding.DecodeString(string(ciphertext))
41	if err != nil {
42		return nil, fmt.Errorf("failed to base64 decode value %s", err)
43	}
44
45	return c.Cipher.Decrypt(encrypted)
46}
47
48type cfbCipher struct {
49	cipher.Block
50}
51
52// NewCFBCipher returns a new AES CFB Cipher
53func NewCFBCipher(secret []byte) (Cipher, error) {
54	c, err := aes.NewCipher(secret)
55	if err != nil {
56		return nil, err
57	}
58	return &cfbCipher{Block: c}, err
59}
60
61// Encrypt with AES CFB
62func (c *cfbCipher) Encrypt(value []byte) ([]byte, error) {
63	ciphertext := make([]byte, aes.BlockSize+len(value))
64	iv := ciphertext[:aes.BlockSize]
65	if _, err := io.ReadFull(rand.Reader, iv); err != nil {
66		return nil, fmt.Errorf("failed to create initialization vector %s", err)
67	}
68
69	stream := cipher.NewCFBEncrypter(c.Block, iv)
70	stream.XORKeyStream(ciphertext[aes.BlockSize:], value)
71	return ciphertext, nil
72}
73
74// Decrypt an AES CFB ciphertext
75func (c *cfbCipher) Decrypt(ciphertext []byte) ([]byte, error) {
76	if len(ciphertext) < aes.BlockSize {
77		return nil, fmt.Errorf("encrypted value should be at least %d bytes, but is only %d bytes", aes.BlockSize, len(ciphertext))
78	}
79
80	iv, ciphertext := ciphertext[:aes.BlockSize], ciphertext[aes.BlockSize:]
81	plaintext := make([]byte, len(ciphertext))
82	stream := cipher.NewCFBDecrypter(c.Block, iv)
83	stream.XORKeyStream(plaintext, ciphertext)
84
85	return plaintext, nil
86}
87
88type gcmCipher struct {
89	cipher.Block
90}
91
92// NewGCMCipher returns a new AES GCM Cipher
93func NewGCMCipher(secret []byte) (Cipher, error) {
94	c, err := aes.NewCipher(secret)
95	if err != nil {
96		return nil, err
97	}
98	return &gcmCipher{Block: c}, err
99}
100
101// Encrypt with AES GCM on raw bytes
102func (c *gcmCipher) Encrypt(value []byte) ([]byte, error) {
103	gcm, err := cipher.NewGCM(c.Block)
104	if err != nil {
105		return nil, err
106	}
107	nonce := make([]byte, gcm.NonceSize())
108	if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
109		return nil, err
110	}
111	// Using nonce as Seal's dst argument results in it being the first
112	// chunk of bytes in the ciphertext. Decrypt retrieves the nonce/IV from this.
113	ciphertext := gcm.Seal(nonce, nonce, value, nil)
114	return ciphertext, nil
115}
116
117// Decrypt an AES GCM ciphertext
118func (c *gcmCipher) Decrypt(ciphertext []byte) ([]byte, error) {
119	gcm, err := cipher.NewGCM(c.Block)
120	if err != nil {
121		return nil, err
122	}
123
124	nonceSize := gcm.NonceSize()
125	nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:]
126
127	plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
128	if err != nil {
129		return nil, err
130	}
131	return plaintext, nil
132}
133