1package memberlist 2 3import ( 4 "bytes" 5 "crypto/aes" 6 "crypto/cipher" 7 "crypto/rand" 8 "fmt" 9 "io" 10) 11 12/* 13 14Encrypted messages are prefixed with an encryptionVersion byte 15that is used for us to be able to properly encode/decode. We 16currently support the following versions: 17 18 0 - AES-GCM 128, using PKCS7 padding 19 1 - AES-GCM 128, no padding. Padding not needed, caused bloat. 20 21*/ 22type encryptionVersion uint8 23 24const ( 25 minEncryptionVersion encryptionVersion = 0 26 maxEncryptionVersion encryptionVersion = 1 27) 28 29const ( 30 versionSize = 1 31 nonceSize = 12 32 tagSize = 16 33 maxPadOverhead = 16 34 blockSize = aes.BlockSize 35) 36 37// pkcs7encode is used to pad a byte buffer to a specific block size using 38// the PKCS7 algorithm. "Ignores" some bytes to compensate for IV 39func pkcs7encode(buf *bytes.Buffer, ignore, blockSize int) { 40 n := buf.Len() - ignore 41 more := blockSize - (n % blockSize) 42 for i := 0; i < more; i++ { 43 buf.WriteByte(byte(more)) 44 } 45} 46 47// pkcs7decode is used to decode a buffer that has been padded 48func pkcs7decode(buf []byte, blockSize int) []byte { 49 if len(buf) == 0 { 50 panic("Cannot decode a PKCS7 buffer of zero length") 51 } 52 n := len(buf) 53 last := buf[n-1] 54 n -= int(last) 55 return buf[:n] 56} 57 58// encryptOverhead returns the maximum possible overhead of encryption by version 59func encryptOverhead(vsn encryptionVersion) int { 60 switch vsn { 61 case 0: 62 return 45 // Version: 1, IV: 12, Padding: 16, Tag: 16 63 case 1: 64 return 29 // Version: 1, IV: 12, Tag: 16 65 default: 66 panic("unsupported version") 67 } 68} 69 70// encryptedLength is used to compute the buffer size needed 71// for a message of given length 72func encryptedLength(vsn encryptionVersion, inp int) int { 73 // If we are on version 1, there is no padding 74 if vsn >= 1 { 75 return versionSize + nonceSize + inp + tagSize 76 } 77 78 // Determine the padding size 79 padding := blockSize - (inp % blockSize) 80 81 // Sum the extra parts to get total size 82 return versionSize + nonceSize + inp + padding + tagSize 83} 84 85// encryptPayload is used to encrypt a message with a given key. 86// We make use of AES-128 in GCM mode. New byte buffer is the version, 87// nonce, ciphertext and tag 88func encryptPayload(vsn encryptionVersion, key []byte, msg []byte, data []byte, dst *bytes.Buffer) error { 89 // Get the AES block cipher 90 aesBlock, err := aes.NewCipher(key) 91 if err != nil { 92 return err 93 } 94 95 // Get the GCM cipher mode 96 gcm, err := cipher.NewGCM(aesBlock) 97 if err != nil { 98 return err 99 } 100 101 // Grow the buffer to make room for everything 102 offset := dst.Len() 103 dst.Grow(encryptedLength(vsn, len(msg))) 104 105 // Write the encryption version 106 dst.WriteByte(byte(vsn)) 107 108 // Add a random nonce 109 _, err = io.CopyN(dst, rand.Reader, nonceSize) 110 if err != nil { 111 return err 112 } 113 afterNonce := dst.Len() 114 115 // Ensure we are correctly padded (only version 0) 116 if vsn == 0 { 117 io.Copy(dst, bytes.NewReader(msg)) 118 pkcs7encode(dst, offset+versionSize+nonceSize, aes.BlockSize) 119 } 120 121 // Encrypt message using GCM 122 slice := dst.Bytes()[offset:] 123 nonce := slice[versionSize : versionSize+nonceSize] 124 125 // Message source depends on the encryption version. 126 // Version 0 uses padding, version 1 does not 127 var src []byte 128 if vsn == 0 { 129 src = slice[versionSize+nonceSize:] 130 } else { 131 src = msg 132 } 133 out := gcm.Seal(nil, nonce, src, data) 134 135 // Truncate the plaintext, and write the cipher text 136 dst.Truncate(afterNonce) 137 dst.Write(out) 138 return nil 139} 140 141// decryptMessage performs the actual decryption of ciphertext. This is in its 142// own function to allow it to be called on all keys easily. 143func decryptMessage(key, msg []byte, data []byte) ([]byte, error) { 144 // Get the AES block cipher 145 aesBlock, err := aes.NewCipher(key) 146 if err != nil { 147 return nil, err 148 } 149 150 // Get the GCM cipher mode 151 gcm, err := cipher.NewGCM(aesBlock) 152 if err != nil { 153 return nil, err 154 } 155 156 // Decrypt the message 157 nonce := msg[versionSize : versionSize+nonceSize] 158 ciphertext := msg[versionSize+nonceSize:] 159 plain, err := gcm.Open(nil, nonce, ciphertext, data) 160 if err != nil { 161 return nil, err 162 } 163 164 // Success! 165 return plain, nil 166} 167 168// decryptPayload is used to decrypt a message with a given key, 169// and verify it's contents. Any padding will be removed, and a 170// slice to the plaintext is returned. Decryption is done IN PLACE! 171func decryptPayload(keys [][]byte, msg []byte, data []byte) ([]byte, error) { 172 // Ensure we have at least one byte 173 if len(msg) == 0 { 174 return nil, fmt.Errorf("Cannot decrypt empty payload") 175 } 176 177 // Verify the version 178 vsn := encryptionVersion(msg[0]) 179 if vsn > maxEncryptionVersion { 180 return nil, fmt.Errorf("Unsupported encryption version %d", msg[0]) 181 } 182 183 // Ensure the length is sane 184 if len(msg) < encryptedLength(vsn, 0) { 185 return nil, fmt.Errorf("Payload is too small to decrypt: %d", len(msg)) 186 } 187 188 for _, key := range keys { 189 plain, err := decryptMessage(key, msg, data) 190 if err == nil { 191 // Remove the PKCS7 padding for vsn 0 192 if vsn == 0 { 193 return pkcs7decode(plain, aes.BlockSize), nil 194 } else { 195 return plain, nil 196 } 197 } 198 } 199 200 return nil, fmt.Errorf("No installed keys could decrypt the message") 201} 202