1// Copyright (c) 2015-2016 The btcsuite developers
2// Use of this source code is governed by an ISC
3// license that can be found in the LICENSE file.
4
5package btcec
6
7import (
8	"bytes"
9	"crypto/aes"
10	"crypto/cipher"
11	"crypto/hmac"
12	"crypto/rand"
13	"crypto/sha256"
14	"crypto/sha512"
15	"errors"
16	"io"
17)
18
var (
	// ErrInvalidMAC occurs when the Message Authentication Code (MAC) check
	// fails during decryption. This happens because of either an invalid
	// private key or corrupt ciphertext.
	ErrInvalidMAC = errors.New("invalid mac hash")

	// errInputTooShort occurs when the input ciphertext to the Decrypt
	// function is less than 134 bytes long: a 16-byte IV, the 70-byte encoded
	// public key, one 16-byte cipher block and the 32-byte HMAC.
	errInputTooShort = errors.New("ciphertext too short")

	// errUnsupportedCurve occurs when the first two bytes of the encoded
	// public key (immediately after the IV) aren't 0x02CA (= 714, OpenSSL's
	// curve identifier for secp256k1).
	errUnsupportedCurve = errors.New("unsupported curve")

	// errInvalidXLength occurs when the length prefix of the public key's X
	// coordinate is not 0x0020.
	errInvalidXLength = errors.New("invalid X length, must be 32")
	// errInvalidYLength occurs when the length prefix of the public key's Y
	// coordinate is not 0x0020.
	errInvalidYLength = errors.New("invalid Y length, must be 32")
	// errInvalidPadding occurs when the ciphertext is not a whole number of
	// AES blocks, or when the decrypted plaintext's PKCS#7 padding is
	// rejected.
	errInvalidPadding = errors.New("invalid PKCS#7 padding")

	// ciphCurveBytes is the big-endian curve identifier 0x02CA (= 714) that
	// starts the encoded public key.
	ciphCurveBytes = [2]byte{0x02, 0xCA}
	// ciphCoordLength is the big-endian coordinate length 0x0020 (= 32) that
	// precedes each of X and Y in the encoded public key.
	ciphCoordLength = [2]byte{0x00, 0x20}
)
42
43// GenerateSharedSecret generates a shared secret based on a private key and a
44// public key using Diffie-Hellman key exchange (ECDH) (RFC 4753).
45// RFC5903 Section 9 states we should only return x.
46func GenerateSharedSecret(privkey *PrivateKey, pubkey *PublicKey) []byte {
47	x, _ := pubkey.Curve.ScalarMult(pubkey.X, pubkey.Y, privkey.D.Bytes())
48	return x.Bytes()
49}
50
51// Encrypt encrypts data for the target public key using AES-256-CBC. It also
52// generates a private key (the pubkey of which is also in the output). The only
53// supported curve is secp256k1. The `structure' that it encodes everything into
54// is:
55//
56//	struct {
57//		// Initialization Vector used for AES-256-CBC
58//		IV [16]byte
59//		// Public Key: curve(2) + len_of_pubkeyX(2) + pubkeyX +
60//		// len_of_pubkeyY(2) + pubkeyY (curve = 714)
61//		PublicKey [70]byte
62//		// Cipher text
63//		Data []byte
64//		// HMAC-SHA-256 Message Authentication Code
65//		HMAC [32]byte
66//	}
67//
68// The primary aim is to ensure byte compatibility with Pyelliptic.  Also, refer
69// to section 5.8.1 of ANSI X9.63 for rationale on this format.
70func Encrypt(pubkey *PublicKey, in []byte) ([]byte, error) {
71	ephemeral, err := NewPrivateKey(S256())
72	if err != nil {
73		return nil, err
74	}
75	ecdhKey := GenerateSharedSecret(ephemeral, pubkey)
76	derivedKey := sha512.Sum512(ecdhKey)
77	keyE := derivedKey[:32]
78	keyM := derivedKey[32:]
79
80	paddedIn := addPKCSPadding(in)
81	// IV + Curve params/X/Y + padded plaintext/ciphertext + HMAC-256
82	out := make([]byte, aes.BlockSize+70+len(paddedIn)+sha256.Size)
83	iv := out[:aes.BlockSize]
84	if _, err = io.ReadFull(rand.Reader, iv); err != nil {
85		return nil, err
86	}
87	// start writing public key
88	pb := ephemeral.PubKey().SerializeUncompressed()
89	offset := aes.BlockSize
90
91	// curve and X length
92	copy(out[offset:offset+4], append(ciphCurveBytes[:], ciphCoordLength[:]...))
93	offset += 4
94	// X
95	copy(out[offset:offset+32], pb[1:33])
96	offset += 32
97	// Y length
98	copy(out[offset:offset+2], ciphCoordLength[:])
99	offset += 2
100	// Y
101	copy(out[offset:offset+32], pb[33:])
102	offset += 32
103
104	// start encryption
105	block, err := aes.NewCipher(keyE)
106	if err != nil {
107		return nil, err
108	}
109	mode := cipher.NewCBCEncrypter(block, iv)
110	mode.CryptBlocks(out[offset:len(out)-sha256.Size], paddedIn)
111
112	// start HMAC-SHA-256
113	hm := hmac.New(sha256.New, keyM)
114	hm.Write(out[:len(out)-sha256.Size])          // everything is hashed
115	copy(out[len(out)-sha256.Size:], hm.Sum(nil)) // write checksum
116
117	return out, nil
118}
119
120// Decrypt decrypts data that was encrypted using the Encrypt function.
121func Decrypt(priv *PrivateKey, in []byte) ([]byte, error) {
122	// IV + Curve params/X/Y + 1 block + HMAC-256
123	if len(in) < aes.BlockSize+70+aes.BlockSize+sha256.Size {
124		return nil, errInputTooShort
125	}
126
127	// read iv
128	iv := in[:aes.BlockSize]
129	offset := aes.BlockSize
130
131	// start reading pubkey
132	if !bytes.Equal(in[offset:offset+2], ciphCurveBytes[:]) {
133		return nil, errUnsupportedCurve
134	}
135	offset += 2
136
137	if !bytes.Equal(in[offset:offset+2], ciphCoordLength[:]) {
138		return nil, errInvalidXLength
139	}
140	offset += 2
141
142	xBytes := in[offset : offset+32]
143	offset += 32
144
145	if !bytes.Equal(in[offset:offset+2], ciphCoordLength[:]) {
146		return nil, errInvalidYLength
147	}
148	offset += 2
149
150	yBytes := in[offset : offset+32]
151	offset += 32
152
153	pb := make([]byte, 65)
154	pb[0] = byte(0x04) // uncompressed
155	copy(pb[1:33], xBytes)
156	copy(pb[33:], yBytes)
157	// check if (X, Y) lies on the curve and create a Pubkey if it does
158	pubkey, err := ParsePubKey(pb, S256())
159	if err != nil {
160		return nil, err
161	}
162
163	// check for cipher text length
164	if (len(in)-aes.BlockSize-offset-sha256.Size)%aes.BlockSize != 0 {
165		return nil, errInvalidPadding // not padded to 16 bytes
166	}
167
168	// read hmac
169	messageMAC := in[len(in)-sha256.Size:]
170
171	// generate shared secret
172	ecdhKey := GenerateSharedSecret(priv, pubkey)
173	derivedKey := sha512.Sum512(ecdhKey)
174	keyE := derivedKey[:32]
175	keyM := derivedKey[32:]
176
177	// verify mac
178	hm := hmac.New(sha256.New, keyM)
179	hm.Write(in[:len(in)-sha256.Size]) // everything is hashed
180	expectedMAC := hm.Sum(nil)
181	if !hmac.Equal(messageMAC, expectedMAC) {
182		return nil, ErrInvalidMAC
183	}
184
185	// start decryption
186	block, err := aes.NewCipher(keyE)
187	if err != nil {
188		return nil, err
189	}
190	mode := cipher.NewCBCDecrypter(block, iv)
191	// same length as ciphertext
192	plaintext := make([]byte, len(in)-offset-sha256.Size)
193	mode.CryptBlocks(plaintext, in[offset:len(in)-sha256.Size])
194
195	return removePKCSPadding(plaintext)
196}
197
198// Implement PKCS#7 padding with block size of 16 (AES block size).
199
// addPKCSPadding returns a copy of src with PKCS#7 padding appended, so that
// the result's length is a multiple of aes.BlockSize. Between 1 and
// aes.BlockSize bytes are always added, each equal to the number of padding
// bytes.
//
// A new slice is always allocated. Appending directly to src would write the
// padding into any spare capacity of the caller's backing array and corrupt
// data the caller still holds there.
func addPKCSPadding(src []byte) []byte {
	padding := aes.BlockSize - len(src)%aes.BlockSize
	padded := make([]byte, len(src)+padding)
	copy(padded, src)
	for i := len(src); i < len(padded); i++ {
		padded[i] = byte(padding)
	}
	return padded
}
206
207// removePKCSPadding removes padding from data that was added with addPKCSPadding
208func removePKCSPadding(src []byte) ([]byte, error) {
209	length := len(src)
210	padLength := int(src[length-1])
211	if padLength > aes.BlockSize || length < aes.BlockSize {
212		return nil, errInvalidPadding
213	}
214
215	return src[:length-padLength], nil
216}
217