1package memberlist
2
3import (
4	"bytes"
5	"crypto/aes"
6	"crypto/cipher"
7	"crypto/rand"
8	"fmt"
9	"io"
10)
11
/*

Encrypted messages are prefixed with an encryptionVersion byte that tells the
receiver how the rest of the payload was produced, so that it can be decoded
correctly. The following versions are currently supported:

 0 - AES-GCM, plaintext PKCS7-padded to the AES block size before sealing
 1 - AES-GCM, no padding. GCM does not need padding, so it only added bloat.

The AES variant (AES-128, AES-192 or AES-256) follows from the length of the
key handed to aes.NewCipher; it is not recorded in the version byte.

*/
type encryptionVersion uint8

// Oldest and newest encryption versions understood by this code. Payloads
// whose leading byte exceeds maxEncryptionVersion are rejected on decrypt.
const (
	minEncryptionVersion encryptionVersion = iota
	maxEncryptionVersion
)
28
// Byte sizes of the pieces that make up an encrypted payload, plus the
// padding parameters used by encryption version 0.
const (
	// versionSize is the single leading encryptionVersion byte.
	versionSize = 1
	// nonceSize is the length of the random GCM nonce that follows the version.
	nonceSize = 12
	// tagSize is the length of the GCM authentication tag appended by Seal.
	tagSize = 16
	// maxPadOverhead is the most PKCS7 padding can add: one full AES block.
	maxPadOverhead = 16
	// blockSize is the AES block size, the unit version-0 padding rounds up to.
	blockSize = aes.BlockSize
)
36
// pkcs7encode appends PKCS7 padding to buf. The first ignore bytes of buf
// (for example a header that is not part of the plaintext) are excluded when
// measuring, so that the bytes after them become a whole multiple of
// blockSize. Every pad byte holds the pad length, and an already aligned
// payload receives one full block of padding.
func pkcs7encode(buf *bytes.Buffer, ignore, blockSize int) {
	padLen := blockSize - (buf.Len()-ignore)%blockSize
	padByte := byte(padLen)
	for remaining := padLen; remaining > 0; remaining-- {
		buf.WriteByte(padByte)
	}
}
46
// pkcs7decode strips PKCS7 padding from buf and returns the shortened slice,
// which shares buf's backing array. The final byte is trusted as the pad
// length without further checks: blockSize is not consulted, a pad length of
// zero returns buf unchanged, and a pad length larger than len(buf) makes the
// slice expression panic. Callers should validate the padding first when buf
// may be malformed. It panics if buf is empty.
func pkcs7decode(buf []byte, blockSize int) []byte {
	size := len(buf)
	if size == 0 {
		panic("Cannot decode a PKCS7 buffer of zero length")
	}
	padLen := int(buf[size-1])
	return buf[:size-padLen]
}
57
58// encryptOverhead returns the maximum possible overhead of encryption by version
59func encryptOverhead(vsn encryptionVersion) int {
60	switch vsn {
61	case 0:
62		return 45 // Version: 1, IV: 12, Padding: 16, Tag: 16
63	case 1:
64		return 29 // Version: 1, IV: 12, Tag: 16
65	default:
66		panic("unsupported version")
67	}
68}
69
70// encryptedLength is used to compute the buffer size needed
71// for a message of given length
72func encryptedLength(vsn encryptionVersion, inp int) int {
73	// If we are on version 1, there is no padding
74	if vsn >= 1 {
75		return versionSize + nonceSize + inp + tagSize
76	}
77
78	// Determine the padding size
79	padding := blockSize - (inp % blockSize)
80
81	// Sum the extra parts to get total size
82	return versionSize + nonceSize + inp + padding + tagSize
83}
84
85// encryptPayload is used to encrypt a message with a given key.
86// We make use of AES-128 in GCM mode. New byte buffer is the version,
87// nonce, ciphertext and tag
88func encryptPayload(vsn encryptionVersion, key []byte, msg []byte, data []byte, dst *bytes.Buffer) error {
89	// Get the AES block cipher
90	aesBlock, err := aes.NewCipher(key)
91	if err != nil {
92		return err
93	}
94
95	// Get the GCM cipher mode
96	gcm, err := cipher.NewGCM(aesBlock)
97	if err != nil {
98		return err
99	}
100
101	// Grow the buffer to make room for everything
102	offset := dst.Len()
103	dst.Grow(encryptedLength(vsn, len(msg)))
104
105	// Write the encryption version
106	dst.WriteByte(byte(vsn))
107
108	// Add a random nonce
109	io.CopyN(dst, rand.Reader, nonceSize)
110	afterNonce := dst.Len()
111
112	// Ensure we are correctly padded (only version 0)
113	if vsn == 0 {
114		io.Copy(dst, bytes.NewReader(msg))
115		pkcs7encode(dst, offset+versionSize+nonceSize, aes.BlockSize)
116	}
117
118	// Encrypt message using GCM
119	slice := dst.Bytes()[offset:]
120	nonce := slice[versionSize : versionSize+nonceSize]
121
122	// Message source depends on the encryption version.
123	// Version 0 uses padding, version 1 does not
124	var src []byte
125	if vsn == 0 {
126		src = slice[versionSize+nonceSize:]
127	} else {
128		src = msg
129	}
130	out := gcm.Seal(nil, nonce, src, data)
131
132	// Truncate the plaintext, and write the cipher text
133	dst.Truncate(afterNonce)
134	dst.Write(out)
135	return nil
136}
137
138// decryptMessage performs the actual decryption of ciphertext. This is in its
139// own function to allow it to be called on all keys easily.
140func decryptMessage(key, msg []byte, data []byte) ([]byte, error) {
141	// Get the AES block cipher
142	aesBlock, err := aes.NewCipher(key)
143	if err != nil {
144		return nil, err
145	}
146
147	// Get the GCM cipher mode
148	gcm, err := cipher.NewGCM(aesBlock)
149	if err != nil {
150		return nil, err
151	}
152
153	// Decrypt the message
154	nonce := msg[versionSize : versionSize+nonceSize]
155	ciphertext := msg[versionSize+nonceSize:]
156	plain, err := gcm.Open(nil, nonce, ciphertext, data)
157	if err != nil {
158		return nil, err
159	}
160
161	// Success!
162	return plain, nil
163}
164
165// decryptPayload is used to decrypt a message with a given key,
166// and verify it's contents. Any padding will be removed, and a
167// slice to the plaintext is returned. Decryption is done IN PLACE!
168func decryptPayload(keys [][]byte, msg []byte, data []byte) ([]byte, error) {
169	// Ensure we have at least one byte
170	if len(msg) == 0 {
171		return nil, fmt.Errorf("Cannot decrypt empty payload")
172	}
173
174	// Verify the version
175	vsn := encryptionVersion(msg[0])
176	if vsn > maxEncryptionVersion {
177		return nil, fmt.Errorf("Unsupported encryption version %d", msg[0])
178	}
179
180	// Ensure the length is sane
181	if len(msg) < encryptedLength(vsn, 0) {
182		return nil, fmt.Errorf("Payload is too small to decrypt: %d", len(msg))
183	}
184
185	for _, key := range keys {
186		plain, err := decryptMessage(key, msg, data)
187		if err == nil {
188			// Remove the PKCS7 padding for vsn 0
189			if vsn == 0 {
190				return pkcs7decode(plain, aes.BlockSize), nil
191			} else {
192				return plain, nil
193			}
194		}
195	}
196
197	return nil, fmt.Errorf("No installed keys could decrypt the message")
198}
199