1package memberlist
2
3import (
4	"bytes"
5	"crypto/aes"
6	"crypto/cipher"
7	"crypto/rand"
8	"fmt"
9	"io"
10)
11
12/*
13
Encrypted messages are prefixed with an encryptionVersion byte
that identifies the format, so the receiver knows how to decode
the rest of the message. We currently support the following versions:

 0 - AES-GCM 128, using PKCS7 padding
 1 - AES-GCM 128, no padding. GCM does not need padding, and it only added bloat.
20
21*/
// encryptionVersion is the wire-format tag written as the first byte of
// every encrypted message.
type encryptionVersion uint8

// Inclusive range of encryption versions understood by this package.
const (
	minEncryptionVersion encryptionVersion = 0
	maxEncryptionVersion encryptionVersion = 1
)

// Sizes, in bytes, of the pieces that make up an encrypted message.
const (
	blockSize      = aes.BlockSize // AES block size, used to compute version 0 padding
	versionSize    = 1             // leading encryptionVersion byte
	nonceSize      = 12            // GCM nonce that follows the version byte
	tagSize        = 16            // GCM authentication tag
	maxPadOverhead = 16            // worst-case PKCS7 padding added by version 0
)
36
// pkcs7encode appends PKCS7 padding to buf. The first ignore bytes of buf
// (for example a header that precedes the data) are excluded when measuring,
// so only the remaining bytes are padded up to a multiple of blockSize. Every
// padding byte holds the padding length, and already-aligned data receives a
// whole extra block of padding.
func pkcs7encode(buf *bytes.Buffer, ignore, blockSize int) {
	dataLen := buf.Len() - ignore
	padLen := blockSize - dataLen%blockSize
	padByte := byte(padLen)
	for remaining := padLen; remaining > 0; remaining-- {
		buf.WriteByte(padByte)
	}
}
46
// pkcs7decode strips PKCS7 padding from buf by dropping as many trailing
// bytes as the value of the final byte, and returns the shortened slice
// (which shares buf's backing array).
//
// The padding is NOT validated: blockSize is currently unused, a final byte
// of 0 returns buf unchanged, and a final byte larger than len(buf) triggers
// a slice-bounds panic. Callers must validate the padding first. It panics
// on an empty buffer.
func pkcs7decode(buf []byte, blockSize int) []byte {
	size := len(buf)
	if size == 0 {
		panic("Cannot decode a PKCS7 buffer of zero length")
	}
	padLen := int(buf[size-1])
	return buf[:size-padLen]
}
57
58// encryptOverhead returns the maximum possible overhead of encryption by version
59func encryptOverhead(vsn encryptionVersion) int {
60	switch vsn {
61	case 0:
62		return 45 // Version: 1, IV: 12, Padding: 16, Tag: 16
63	case 1:
64		return 29 // Version: 1, IV: 12, Tag: 16
65	default:
66		panic("unsupported version")
67	}
68}
69
70// encryptedLength is used to compute the buffer size needed
71// for a message of given length
72func encryptedLength(vsn encryptionVersion, inp int) int {
73	// If we are on version 1, there is no padding
74	if vsn >= 1 {
75		return versionSize + nonceSize + inp + tagSize
76	}
77
78	// Determine the padding size
79	padding := blockSize - (inp % blockSize)
80
81	// Sum the extra parts to get total size
82	return versionSize + nonceSize + inp + padding + tagSize
83}
84
85// encryptPayload is used to encrypt a message with a given key.
86// We make use of AES-128 in GCM mode. New byte buffer is the version,
87// nonce, ciphertext and tag
88func encryptPayload(vsn encryptionVersion, key []byte, msg []byte, data []byte, dst *bytes.Buffer) error {
89	// Get the AES block cipher
90	aesBlock, err := aes.NewCipher(key)
91	if err != nil {
92		return err
93	}
94
95	// Get the GCM cipher mode
96	gcm, err := cipher.NewGCM(aesBlock)
97	if err != nil {
98		return err
99	}
100
101	// Grow the buffer to make room for everything
102	offset := dst.Len()
103	dst.Grow(encryptedLength(vsn, len(msg)))
104
105	// Write the encryption version
106	dst.WriteByte(byte(vsn))
107
108	// Add a random nonce
109	_, err = io.CopyN(dst, rand.Reader, nonceSize)
110	if err != nil {
111		return err
112	}
113	afterNonce := dst.Len()
114
115	// Ensure we are correctly padded (only version 0)
116	if vsn == 0 {
117		io.Copy(dst, bytes.NewReader(msg))
118		pkcs7encode(dst, offset+versionSize+nonceSize, aes.BlockSize)
119	}
120
121	// Encrypt message using GCM
122	slice := dst.Bytes()[offset:]
123	nonce := slice[versionSize : versionSize+nonceSize]
124
125	// Message source depends on the encryption version.
126	// Version 0 uses padding, version 1 does not
127	var src []byte
128	if vsn == 0 {
129		src = slice[versionSize+nonceSize:]
130	} else {
131		src = msg
132	}
133	out := gcm.Seal(nil, nonce, src, data)
134
135	// Truncate the plaintext, and write the cipher text
136	dst.Truncate(afterNonce)
137	dst.Write(out)
138	return nil
139}
140
141// decryptMessage performs the actual decryption of ciphertext. This is in its
142// own function to allow it to be called on all keys easily.
143func decryptMessage(key, msg []byte, data []byte) ([]byte, error) {
144	// Get the AES block cipher
145	aesBlock, err := aes.NewCipher(key)
146	if err != nil {
147		return nil, err
148	}
149
150	// Get the GCM cipher mode
151	gcm, err := cipher.NewGCM(aesBlock)
152	if err != nil {
153		return nil, err
154	}
155
156	// Decrypt the message
157	nonce := msg[versionSize : versionSize+nonceSize]
158	ciphertext := msg[versionSize+nonceSize:]
159	plain, err := gcm.Open(nil, nonce, ciphertext, data)
160	if err != nil {
161		return nil, err
162	}
163
164	// Success!
165	return plain, nil
166}
167
168// decryptPayload is used to decrypt a message with a given key,
169// and verify it's contents. Any padding will be removed, and a
170// slice to the plaintext is returned. Decryption is done IN PLACE!
171func decryptPayload(keys [][]byte, msg []byte, data []byte) ([]byte, error) {
172	// Ensure we have at least one byte
173	if len(msg) == 0 {
174		return nil, fmt.Errorf("Cannot decrypt empty payload")
175	}
176
177	// Verify the version
178	vsn := encryptionVersion(msg[0])
179	if vsn > maxEncryptionVersion {
180		return nil, fmt.Errorf("Unsupported encryption version %d", msg[0])
181	}
182
183	// Ensure the length is sane
184	if len(msg) < encryptedLength(vsn, 0) {
185		return nil, fmt.Errorf("Payload is too small to decrypt: %d", len(msg))
186	}
187
188	for _, key := range keys {
189		plain, err := decryptMessage(key, msg, data)
190		if err == nil {
191			// Remove the PKCS7 padding for vsn 0
192			if vsn == 0 {
193				return pkcs7decode(plain, aes.BlockSize), nil
194			} else {
195				return plain, nil
196			}
197		}
198	}
199
200	return nil, fmt.Errorf("No installed keys could decrypt the message")
201}
202