1/*-
2 * Copyright 2014 Square Inc.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package jose
18
19import (
20	"bytes"
21	"crypto/aes"
22	"crypto/cipher"
23	"crypto/hmac"
24	"crypto/rand"
25	"crypto/sha256"
26	"crypto/sha512"
27	"crypto/subtle"
28	"errors"
29	"fmt"
30	"hash"
31	"io"
32
33	"golang.org/x/crypto/pbkdf2"
34	"gopkg.in/square/go-jose.v2/cipher"
35)
36
// RandReader is the source of randomness used in this file for generated
// content-encryption keys, AEAD nonces and PBES2 salts. It defaults to
// crypto/rand.Reader and is a variable so it can be stubbed out in tests.
var RandReader = rand.Reader
39
const (
	// defaultP2C is the PBES2 iteration count ("p2c") used when encrypting
	// with a symmetricKeyCipher whose p2c is not positive.
	//
	// RFC7518 recommends a minimum of 1,000 iterations:
	// https://tools.ietf.org/html/rfc7518#section-4.8.1.2
	// NIST recommends a minimum of 10,000:
	// https://pages.nist.gov/800-63-3/sp800-63b.html
	// 1Password uses 100,000:
	// https://support.1password.com/pbkdf2/
	defaultP2C = 100000
	// defaultP2SSize is the length in bytes of the random PBES2 salt input
	// ("p2s") generated when none is configured. Default salt size: 128 bits.
	defaultP2SSize = 16
)
51
// symmetricKeyCipher encrypts and decrypts content-encryption keys with a
// shared symmetric secret. Depending on the key management algorithm, key is
// returned directly as the CEK (DIRECT), used as an AES key for AES key wrap
// or AES-GCM key wrap, or used as the PBES2 password from which the wrapping
// key is derived.
type symmetricKeyCipher struct {
	key []byte // Shared secret: CEK (DIRECT), wrapping key (KW/GCMKW) or PBES2 password
	p2c int    // PBES2 Count; a non-positive value means defaultP2C is used
	p2s []byte // PBES2 Salt Input; empty means a random salt is generated
}
58
// symmetricMac signs and verifies payloads with HMAC (HS256, HS384 or HS512)
// using a shared key.
type symmetricMac struct {
	key []byte // HMAC key
}
63
// aeadParts holds the input/output of an AEAD operation, split into the
// nonce (iv), the ciphertext proper and the authentication tag, so each can
// be stored separately (e.g. the iv and tag header values for AES-GCM key wrap).
type aeadParts struct {
	iv, ciphertext, tag []byte
}
68
// aeadContentCipher is a content cipher based on an AEAD construction.
//
//   - keyBytes is the key length in bytes reported by keySize.
//   - authtagBytes is the length of the tag at the end of the Seal output;
//     encrypt splits it off, and decrypt requires at least that many tag bytes.
//   - getAead constructs the AEAD instance for a given key.
type aeadContentCipher struct {
	keyBytes     int
	authtagBytes int
	getAead      func(key []byte) (cipher.AEAD, error)
}
75
// randomKeyGenerator produces a new random key of size bytes, read from
// RandReader, on every call to genKey.
type randomKeyGenerator struct {
	size int
}
80
// staticKeyGenerator always yields a copy of the same pre-shared key; it is
// the key generator for direct mode.
type staticKeyGenerator struct {
	key []byte
}
85
86// Create a new content cipher based on AES-GCM
87func newAESGCM(keySize int) contentCipher {
88	return &aeadContentCipher{
89		keyBytes:     keySize,
90		authtagBytes: 16,
91		getAead: func(key []byte) (cipher.AEAD, error) {
92			aes, err := aes.NewCipher(key)
93			if err != nil {
94				return nil, err
95			}
96
97			return cipher.NewGCM(aes)
98		},
99	}
100}
101
102// Create a new content cipher based on AES-CBC+HMAC
103func newAESCBC(keySize int) contentCipher {
104	return &aeadContentCipher{
105		keyBytes:     keySize * 2,
106		authtagBytes: keySize,
107		getAead: func(key []byte) (cipher.AEAD, error) {
108			return josecipher.NewCBCHMAC(key, aes.NewCipher)
109		},
110	}
111}
112
113// Get an AEAD cipher object for the given content encryption algorithm
114func getContentCipher(alg ContentEncryption) contentCipher {
115	switch alg {
116	case A128GCM:
117		return newAESGCM(16)
118	case A192GCM:
119		return newAESGCM(24)
120	case A256GCM:
121		return newAESGCM(32)
122	case A128CBC_HS256:
123		return newAESCBC(16)
124	case A192CBC_HS384:
125		return newAESCBC(24)
126	case A256CBC_HS512:
127		return newAESCBC(32)
128	default:
129		return nil
130	}
131}
132
133// getPbkdf2Params returns the key length and hash function used in
134// pbkdf2.Key.
135func getPbkdf2Params(alg KeyAlgorithm) (int, func() hash.Hash) {
136	switch alg {
137	case PBES2_HS256_A128KW:
138		return 16, sha256.New
139	case PBES2_HS384_A192KW:
140		return 24, sha512.New384
141	case PBES2_HS512_A256KW:
142		return 32, sha512.New
143	default:
144		panic("invalid algorithm")
145	}
146}
147
148// getRandomSalt generates a new salt of the given size.
149func getRandomSalt(size int) ([]byte, error) {
150	salt := make([]byte, size)
151	_, err := io.ReadFull(RandReader, salt)
152	if err != nil {
153		return nil, err
154	}
155
156	return salt, nil
157}
158
159// newSymmetricRecipient creates a JWE encrypter based on AES-GCM key wrap.
160func newSymmetricRecipient(keyAlg KeyAlgorithm, key []byte) (recipientKeyInfo, error) {
161	switch keyAlg {
162	case DIRECT, A128GCMKW, A192GCMKW, A256GCMKW, A128KW, A192KW, A256KW:
163	case PBES2_HS256_A128KW, PBES2_HS384_A192KW, PBES2_HS512_A256KW:
164	default:
165		return recipientKeyInfo{}, ErrUnsupportedAlgorithm
166	}
167
168	return recipientKeyInfo{
169		keyAlg: keyAlg,
170		keyEncrypter: &symmetricKeyCipher{
171			key: key,
172		},
173	}, nil
174}
175
176// newSymmetricSigner creates a recipientSigInfo based on the given key.
177func newSymmetricSigner(sigAlg SignatureAlgorithm, key []byte) (recipientSigInfo, error) {
178	// Verify that key management algorithm is supported by this encrypter
179	switch sigAlg {
180	case HS256, HS384, HS512:
181	default:
182		return recipientSigInfo{}, ErrUnsupportedAlgorithm
183	}
184
185	return recipientSigInfo{
186		sigAlg: sigAlg,
187		signer: &symmetricMac{
188			key: key,
189		},
190	}, nil
191}
192
193// Generate a random key for the given content cipher
194func (ctx randomKeyGenerator) genKey() ([]byte, rawHeader, error) {
195	key := make([]byte, ctx.size)
196	_, err := io.ReadFull(RandReader, key)
197	if err != nil {
198		return nil, rawHeader{}, err
199	}
200
201	return key, rawHeader{}, nil
202}
203
204// Key size for random generator
205func (ctx randomKeyGenerator) keySize() int {
206	return ctx.size
207}
208
209// Generate a static key (for direct mode)
210func (ctx staticKeyGenerator) genKey() ([]byte, rawHeader, error) {
211	cek := make([]byte, len(ctx.key))
212	copy(cek, ctx.key)
213	return cek, rawHeader{}, nil
214}
215
216// Key size for static generator
217func (ctx staticKeyGenerator) keySize() int {
218	return len(ctx.key)
219}
220
221// Get key size for this cipher
222func (ctx aeadContentCipher) keySize() int {
223	return ctx.keyBytes
224}
225
226// Encrypt some data
227func (ctx aeadContentCipher) encrypt(key, aad, pt []byte) (*aeadParts, error) {
228	// Get a new AEAD instance
229	aead, err := ctx.getAead(key)
230	if err != nil {
231		return nil, err
232	}
233
234	// Initialize a new nonce
235	iv := make([]byte, aead.NonceSize())
236	_, err = io.ReadFull(RandReader, iv)
237	if err != nil {
238		return nil, err
239	}
240
241	ciphertextAndTag := aead.Seal(nil, iv, pt, aad)
242	offset := len(ciphertextAndTag) - ctx.authtagBytes
243
244	return &aeadParts{
245		iv:         iv,
246		ciphertext: ciphertextAndTag[:offset],
247		tag:        ciphertextAndTag[offset:],
248	}, nil
249}
250
251// Decrypt some data
252func (ctx aeadContentCipher) decrypt(key, aad []byte, parts *aeadParts) ([]byte, error) {
253	aead, err := ctx.getAead(key)
254	if err != nil {
255		return nil, err
256	}
257
258	if len(parts.iv) != aead.NonceSize() || len(parts.tag) < ctx.authtagBytes {
259		return nil, ErrCryptoFailure
260	}
261
262	return aead.Open(nil, parts.iv, append(parts.ciphertext, parts.tag...), aad)
263}
264
265// Encrypt the content encryption key.
266func (ctx *symmetricKeyCipher) encryptKey(cek []byte, alg KeyAlgorithm) (recipientInfo, error) {
267	switch alg {
268	case DIRECT:
269		return recipientInfo{
270			header: &rawHeader{},
271		}, nil
272	case A128GCMKW, A192GCMKW, A256GCMKW:
273		aead := newAESGCM(len(ctx.key))
274
275		parts, err := aead.encrypt(ctx.key, []byte{}, cek)
276		if err != nil {
277			return recipientInfo{}, err
278		}
279
280		header := &rawHeader{}
281		header.set(headerIV, newBuffer(parts.iv))
282		header.set(headerTag, newBuffer(parts.tag))
283
284		return recipientInfo{
285			header:       header,
286			encryptedKey: parts.ciphertext,
287		}, nil
288	case A128KW, A192KW, A256KW:
289		block, err := aes.NewCipher(ctx.key)
290		if err != nil {
291			return recipientInfo{}, err
292		}
293
294		jek, err := josecipher.KeyWrap(block, cek)
295		if err != nil {
296			return recipientInfo{}, err
297		}
298
299		return recipientInfo{
300			encryptedKey: jek,
301			header:       &rawHeader{},
302		}, nil
303	case PBES2_HS256_A128KW, PBES2_HS384_A192KW, PBES2_HS512_A256KW:
304		if len(ctx.p2s) == 0 {
305			salt, err := getRandomSalt(defaultP2SSize)
306			if err != nil {
307				return recipientInfo{}, err
308			}
309			ctx.p2s = salt
310		}
311
312		if ctx.p2c <= 0 {
313			ctx.p2c = defaultP2C
314		}
315
316		// salt is UTF8(Alg) || 0x00 || Salt Input
317		salt := bytes.Join([][]byte{[]byte(alg), ctx.p2s}, []byte{0x00})
318
319		// derive key
320		keyLen, h := getPbkdf2Params(alg)
321		key := pbkdf2.Key(ctx.key, salt, ctx.p2c, keyLen, h)
322
323		// use AES cipher with derived key
324		block, err := aes.NewCipher(key)
325		if err != nil {
326			return recipientInfo{}, err
327		}
328
329		jek, err := josecipher.KeyWrap(block, cek)
330		if err != nil {
331			return recipientInfo{}, err
332		}
333
334		header := &rawHeader{}
335		header.set(headerP2C, ctx.p2c)
336		header.set(headerP2S, newBuffer(ctx.p2s))
337
338		return recipientInfo{
339			encryptedKey: jek,
340			header:       header,
341		}, nil
342	}
343
344	return recipientInfo{}, ErrUnsupportedAlgorithm
345}
346
347// Decrypt the content encryption key.
348func (ctx *symmetricKeyCipher) decryptKey(headers rawHeader, recipient *recipientInfo, generator keyGenerator) ([]byte, error) {
349	switch headers.getAlgorithm() {
350	case DIRECT:
351		cek := make([]byte, len(ctx.key))
352		copy(cek, ctx.key)
353		return cek, nil
354	case A128GCMKW, A192GCMKW, A256GCMKW:
355		aead := newAESGCM(len(ctx.key))
356
357		iv, err := headers.getIV()
358		if err != nil {
359			return nil, fmt.Errorf("square/go-jose: invalid IV: %v", err)
360		}
361		tag, err := headers.getTag()
362		if err != nil {
363			return nil, fmt.Errorf("square/go-jose: invalid tag: %v", err)
364		}
365
366		parts := &aeadParts{
367			iv:         iv.bytes(),
368			ciphertext: recipient.encryptedKey,
369			tag:        tag.bytes(),
370		}
371
372		cek, err := aead.decrypt(ctx.key, []byte{}, parts)
373		if err != nil {
374			return nil, err
375		}
376
377		return cek, nil
378	case A128KW, A192KW, A256KW:
379		block, err := aes.NewCipher(ctx.key)
380		if err != nil {
381			return nil, err
382		}
383
384		cek, err := josecipher.KeyUnwrap(block, recipient.encryptedKey)
385		if err != nil {
386			return nil, err
387		}
388		return cek, nil
389	case PBES2_HS256_A128KW, PBES2_HS384_A192KW, PBES2_HS512_A256KW:
390		p2s, err := headers.getP2S()
391		if err != nil {
392			return nil, fmt.Errorf("square/go-jose: invalid P2S: %v", err)
393		}
394		if p2s == nil || len(p2s.data) == 0 {
395			return nil, fmt.Errorf("square/go-jose: invalid P2S: must be present")
396		}
397
398		p2c, err := headers.getP2C()
399		if err != nil {
400			return nil, fmt.Errorf("square/go-jose: invalid P2C: %v", err)
401		}
402		if p2c <= 0 {
403			return nil, fmt.Errorf("square/go-jose: invalid P2C: must be a positive integer")
404		}
405
406		// salt is UTF8(Alg) || 0x00 || Salt Input
407		alg := headers.getAlgorithm()
408		salt := bytes.Join([][]byte{[]byte(alg), p2s.bytes()}, []byte{0x00})
409
410		// derive key
411		keyLen, h := getPbkdf2Params(alg)
412		key := pbkdf2.Key(ctx.key, salt, p2c, keyLen, h)
413
414		// use AES cipher with derived key
415		block, err := aes.NewCipher(key)
416		if err != nil {
417			return nil, err
418		}
419
420		cek, err := josecipher.KeyUnwrap(block, recipient.encryptedKey)
421		if err != nil {
422			return nil, err
423		}
424		return cek, nil
425	}
426
427	return nil, ErrUnsupportedAlgorithm
428}
429
430// Sign the given payload
431func (ctx symmetricMac) signPayload(payload []byte, alg SignatureAlgorithm) (Signature, error) {
432	mac, err := ctx.hmac(payload, alg)
433	if err != nil {
434		return Signature{}, errors.New("square/go-jose: failed to compute hmac")
435	}
436
437	return Signature{
438		Signature: mac,
439		protected: &rawHeader{},
440	}, nil
441}
442
443// Verify the given payload
444func (ctx symmetricMac) verifyPayload(payload []byte, mac []byte, alg SignatureAlgorithm) error {
445	expected, err := ctx.hmac(payload, alg)
446	if err != nil {
447		return errors.New("square/go-jose: failed to compute hmac")
448	}
449
450	if len(mac) != len(expected) {
451		return errors.New("square/go-jose: invalid hmac")
452	}
453
454	match := subtle.ConstantTimeCompare(mac, expected)
455	if match != 1 {
456		return errors.New("square/go-jose: invalid hmac")
457	}
458
459	return nil
460}
461
462// Compute the HMAC based on the given alg value
463func (ctx symmetricMac) hmac(payload []byte, alg SignatureAlgorithm) ([]byte, error) {
464	var hash func() hash.Hash
465
466	switch alg {
467	case HS256:
468		hash = sha256.New
469	case HS384:
470		hash = sha512.New384
471	case HS512:
472		hash = sha512.New
473	default:
474		return nil, ErrUnsupportedAlgorithm
475	}
476
477	hmac := hmac.New(hash, ctx.key)
478
479	// According to documentation, Write() on hash never fails
480	_, _ = hmac.Write(payload)
481	return hmac.Sum(nil), nil
482}
483