1/*-
2 * Copyright 2014 Square Inc.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package jose
18
19import (
20	"crypto/ecdsa"
21	"crypto/rsa"
22	"errors"
23	"fmt"
24	"reflect"
25
26	"gopkg.in/square/go-jose.v2/json"
27)
28
// Encrypter represents an encrypter which produces an encrypted JWE object.
type Encrypter interface {
	// Encrypt encrypts plaintext and returns the resulting JWE object. The
	// implementation in this file forwards to EncryptWithAuthData with a nil
	// aad.
	Encrypt(plaintext []byte) (*JSONWebEncryption, error)
	// EncryptWithAuthData encrypts plaintext like Encrypt and additionally
	// attaches aad to the resulting JWE object as additional authenticated
	// data.
	EncryptWithAuthData(plaintext []byte, aad []byte) (*JSONWebEncryption, error)
	// Options returns the options (compression and extra headers) the
	// encrypter was configured with.
	Options() EncrypterOptions
}
35
// contentCipher is a generic content cipher: it encrypts and decrypts the
// message payload under a content encryption key (CEK).
type contentCipher interface {
	// keySize returns the CEK length the cipher expects. NewEncrypter compares
	// it against len() of a raw []byte key, so it is measured in bytes.
	keySize() int
	// encrypt encrypts plaintext under cek, passing aad to the cipher as
	// authenticated data, and returns the resulting parts (callers in this
	// file read the iv, ciphertext and tag from them).
	encrypt(cek []byte, aad, plaintext []byte) (*aeadParts, error)
	// decrypt recovers the plaintext from parts using cek and aad.
	decrypt(cek []byte, aad []byte, parts *aeadParts) ([]byte, error)
}
42
// A key generator (for generating/getting a CEK)
type keyGenerator interface {
	// keySize reports the size of the keys this generator yields (presumably
	// in bytes, matching contentCipher.keySize, from which the generators in
	// this file take their size).
	keySize() int
	// genKey returns the CEK to use for a message together with any header
	// values that go along with it; EncryptWithAuthData merges those headers
	// into the protected header.
	genKey() ([]byte, rawHeader, error)
}
48
// A generic key encrypter
type keyEncrypter interface {
	// encryptKey encrypts cek for a single recipient using key algorithm alg
	// and returns the per-recipient info. The caller adds the "alg" and
	// (optional) key ID headers to the returned info's header afterwards.
	encryptKey(cek []byte, alg KeyAlgorithm) (recipientInfo, error) // Encrypt a key
}
53
// A generic key decrypter
type keyDecrypter interface {
	// decryptKey recovers the CEK for recipient. headers are the headers
	// merged for that recipient and generator is a key generator sized for
	// the message's content cipher. The returned key is handed to the content
	// cipher's decrypt method by the callers in this file.
	decryptKey(headers rawHeader, recipient *recipientInfo, generator keyGenerator) ([]byte, error) // Decrypt a key
}
58
// A generic encrypter based on the given key encrypter and content cipher.
type genericEncrypter struct {
	// contentAlg is the content encryption algorithm written to the protected
	// header of every object produced.
	contentAlg ContentEncryption
	// compressionAlg, when not NONE, is applied to the plaintext before
	// encryption and recorded in the protected header.
	compressionAlg CompressionAlgorithm
	// cipher is the content cipher resolved from contentAlg.
	cipher contentCipher
	// recipients holds one entry per recipient the CEK is encrypted to.
	recipients []recipientKeyInfo
	// keyGenerator supplies the CEK (and accompanying headers) per message.
	keyGenerator keyGenerator
	// extraHeaders are caller-supplied values written to the protected header
	// after all other headers have been set.
	extraHeaders map[HeaderKey]interface{}
}
68
// recipientKeyInfo describes a single recipient of an encrypter: how the CEK
// is encrypted for it and which headers identify it.
type recipientKeyInfo struct {
	// keyID, when non-empty, is written to the recipient's key ID header.
	keyID string
	// keyAlg is the key management algorithm, written to the "alg" header and
	// passed to keyEncrypter.encryptKey.
	keyAlg KeyAlgorithm
	// keyEncrypter encrypts the CEK for this recipient.
	keyEncrypter keyEncrypter
}
74
// EncrypterOptions represents options that can be set on new encrypters.
type EncrypterOptions struct {
	// Compression selects the compression algorithm. When it is not NONE the
	// encrypter compresses the plaintext before encryption and records the
	// algorithm in the protected header.
	Compression CompressionAlgorithm

	// Optional map of additional keys to be inserted into the protected header
	// of a JWE object. Some specifications which make use of JWE like to insert
	// additional values here. All values must be JSON-serializable. The
	// entries are written after the headers the encrypter sets itself, so an
	// entry with the same key replaces the encrypter's value.
	ExtraHeaders map[HeaderKey]interface{}
}
84
85// WithHeader adds an arbitrary value to the ExtraHeaders map, initializing it
86// if necessary. It returns itself and so can be used in a fluent style.
87func (eo *EncrypterOptions) WithHeader(k HeaderKey, v interface{}) *EncrypterOptions {
88	if eo.ExtraHeaders == nil {
89		eo.ExtraHeaders = map[HeaderKey]interface{}{}
90	}
91	eo.ExtraHeaders[k] = v
92	return eo
93}
94
95// WithContentType adds a content type ("cty") header and returns the updated
96// EncrypterOptions.
97func (eo *EncrypterOptions) WithContentType(contentType ContentType) *EncrypterOptions {
98	return eo.WithHeader(HeaderContentType, contentType)
99}
100
101// WithType adds a type ("typ") header and returns the updated EncrypterOptions.
102func (eo *EncrypterOptions) WithType(typ ContentType) *EncrypterOptions {
103	return eo.WithHeader(HeaderType, typ)
104}
105
// Recipient represents an algorithm/key to encrypt messages to.
//
// PBES2Count and PBES2Salt correspond with the "p2c" and "p2s" headers used
// on the password-based encryption algorithms PBES2-HS256+A128KW,
// PBES2-HS384+A192KW, and PBES2-HS512+A256KW. If they are not provided a safe
// default of 100000 will be used for the count and a 128-bit random salt will
// be generated.
type Recipient struct {
	// Algorithm is the key management algorithm used for this recipient.
	Algorithm KeyAlgorithm
	// Key is the encryption key. The constructors in this file accept raw
	// keys as well as keys wrapped in a JSONWebKey or an OpaqueKeyEncrypter.
	Key interface{}
	// KeyID, when non-empty, takes precedence over a key ID carried by Key.
	KeyID string
	// PBES2Count is only consulted for the PBES2 algorithms listed above.
	PBES2Count int
	// PBES2Salt is only consulted for the PBES2 algorithms listed above.
	PBES2Salt []byte
}
120
121// NewEncrypter creates an appropriate encrypter based on the key type
122func NewEncrypter(enc ContentEncryption, rcpt Recipient, opts *EncrypterOptions) (Encrypter, error) {
123	encrypter := &genericEncrypter{
124		contentAlg: enc,
125		recipients: []recipientKeyInfo{},
126		cipher:     getContentCipher(enc),
127	}
128	if opts != nil {
129		encrypter.compressionAlg = opts.Compression
130		encrypter.extraHeaders = opts.ExtraHeaders
131	}
132
133	if encrypter.cipher == nil {
134		return nil, ErrUnsupportedAlgorithm
135	}
136
137	var keyID string
138	var rawKey interface{}
139	switch encryptionKey := rcpt.Key.(type) {
140	case JSONWebKey:
141		keyID, rawKey = encryptionKey.KeyID, encryptionKey.Key
142	case *JSONWebKey:
143		keyID, rawKey = encryptionKey.KeyID, encryptionKey.Key
144	case OpaqueKeyEncrypter:
145		keyID, rawKey = encryptionKey.KeyID(), encryptionKey
146	default:
147		rawKey = encryptionKey
148	}
149
150	switch rcpt.Algorithm {
151	case DIRECT:
152		// Direct encryption mode must be treated differently
153		if reflect.TypeOf(rawKey) != reflect.TypeOf([]byte{}) {
154			return nil, ErrUnsupportedKeyType
155		}
156		if encrypter.cipher.keySize() != len(rawKey.([]byte)) {
157			return nil, ErrInvalidKeySize
158		}
159		encrypter.keyGenerator = staticKeyGenerator{
160			key: rawKey.([]byte),
161		}
162		recipientInfo, _ := newSymmetricRecipient(rcpt.Algorithm, rawKey.([]byte))
163		recipientInfo.keyID = keyID
164		if rcpt.KeyID != "" {
165			recipientInfo.keyID = rcpt.KeyID
166		}
167		encrypter.recipients = []recipientKeyInfo{recipientInfo}
168		return encrypter, nil
169	case ECDH_ES:
170		// ECDH-ES (w/o key wrapping) is similar to DIRECT mode
171		typeOf := reflect.TypeOf(rawKey)
172		if typeOf != reflect.TypeOf(&ecdsa.PublicKey{}) {
173			return nil, ErrUnsupportedKeyType
174		}
175		encrypter.keyGenerator = ecKeyGenerator{
176			size:      encrypter.cipher.keySize(),
177			algID:     string(enc),
178			publicKey: rawKey.(*ecdsa.PublicKey),
179		}
180		recipientInfo, _ := newECDHRecipient(rcpt.Algorithm, rawKey.(*ecdsa.PublicKey))
181		recipientInfo.keyID = keyID
182		if rcpt.KeyID != "" {
183			recipientInfo.keyID = rcpt.KeyID
184		}
185		encrypter.recipients = []recipientKeyInfo{recipientInfo}
186		return encrypter, nil
187	default:
188		// Can just add a standard recipient
189		encrypter.keyGenerator = randomKeyGenerator{
190			size: encrypter.cipher.keySize(),
191		}
192		err := encrypter.addRecipient(rcpt)
193		return encrypter, err
194	}
195}
196
197// NewMultiEncrypter creates a multi-encrypter based on the given parameters
198func NewMultiEncrypter(enc ContentEncryption, rcpts []Recipient, opts *EncrypterOptions) (Encrypter, error) {
199	cipher := getContentCipher(enc)
200
201	if cipher == nil {
202		return nil, ErrUnsupportedAlgorithm
203	}
204	if rcpts == nil || len(rcpts) == 0 {
205		return nil, fmt.Errorf("square/go-jose: recipients is nil or empty")
206	}
207
208	encrypter := &genericEncrypter{
209		contentAlg: enc,
210		recipients: []recipientKeyInfo{},
211		cipher:     cipher,
212		keyGenerator: randomKeyGenerator{
213			size: cipher.keySize(),
214		},
215	}
216
217	if opts != nil {
218		encrypter.compressionAlg = opts.Compression
219		encrypter.extraHeaders = opts.ExtraHeaders
220	}
221
222	for _, recipient := range rcpts {
223		err := encrypter.addRecipient(recipient)
224		if err != nil {
225			return nil, err
226		}
227	}
228
229	return encrypter, nil
230}
231
232func (ctx *genericEncrypter) addRecipient(recipient Recipient) (err error) {
233	var recipientInfo recipientKeyInfo
234
235	switch recipient.Algorithm {
236	case DIRECT, ECDH_ES:
237		return fmt.Errorf("square/go-jose: key algorithm '%s' not supported in multi-recipient mode", recipient.Algorithm)
238	}
239
240	recipientInfo, err = makeJWERecipient(recipient.Algorithm, recipient.Key)
241	if recipient.KeyID != "" {
242		recipientInfo.keyID = recipient.KeyID
243	}
244
245	switch recipient.Algorithm {
246	case PBES2_HS256_A128KW, PBES2_HS384_A192KW, PBES2_HS512_A256KW:
247		if sr, ok := recipientInfo.keyEncrypter.(*symmetricKeyCipher); ok {
248			sr.p2c = recipient.PBES2Count
249			sr.p2s = recipient.PBES2Salt
250		}
251	}
252
253	if err == nil {
254		ctx.recipients = append(ctx.recipients, recipientInfo)
255	}
256	return err
257}
258
259func makeJWERecipient(alg KeyAlgorithm, encryptionKey interface{}) (recipientKeyInfo, error) {
260	switch encryptionKey := encryptionKey.(type) {
261	case *rsa.PublicKey:
262		return newRSARecipient(alg, encryptionKey)
263	case *ecdsa.PublicKey:
264		return newECDHRecipient(alg, encryptionKey)
265	case []byte:
266		return newSymmetricRecipient(alg, encryptionKey)
267	case string:
268		return newSymmetricRecipient(alg, []byte(encryptionKey))
269	case *JSONWebKey:
270		recipient, err := makeJWERecipient(alg, encryptionKey.Key)
271		recipient.keyID = encryptionKey.KeyID
272		return recipient, err
273	}
274	if encrypter, ok := encryptionKey.(OpaqueKeyEncrypter); ok {
275		return newOpaqueKeyEncrypter(alg, encrypter)
276	}
277	return recipientKeyInfo{}, ErrUnsupportedKeyType
278}
279
280// newDecrypter creates an appropriate decrypter based on the key type
281func newDecrypter(decryptionKey interface{}) (keyDecrypter, error) {
282	switch decryptionKey := decryptionKey.(type) {
283	case *rsa.PrivateKey:
284		return &rsaDecrypterSigner{
285			privateKey: decryptionKey,
286		}, nil
287	case *ecdsa.PrivateKey:
288		return &ecDecrypterSigner{
289			privateKey: decryptionKey,
290		}, nil
291	case []byte:
292		return &symmetricKeyCipher{
293			key: decryptionKey,
294		}, nil
295	case string:
296		return &symmetricKeyCipher{
297			key: []byte(decryptionKey),
298		}, nil
299	case JSONWebKey:
300		return newDecrypter(decryptionKey.Key)
301	case *JSONWebKey:
302		return newDecrypter(decryptionKey.Key)
303	}
304	if okd, ok := decryptionKey.(OpaqueKeyDecrypter); ok {
305		return &opaqueKeyDecrypter{decrypter: okd}, nil
306	}
307	return nil, ErrUnsupportedKeyType
308}
309
310// Implementation of encrypt method producing a JWE object.
311func (ctx *genericEncrypter) Encrypt(plaintext []byte) (*JSONWebEncryption, error) {
312	return ctx.EncryptWithAuthData(plaintext, nil)
313}
314
315// Implementation of encrypt method producing a JWE object.
316func (ctx *genericEncrypter) EncryptWithAuthData(plaintext, aad []byte) (*JSONWebEncryption, error) {
317	obj := &JSONWebEncryption{}
318	obj.aad = aad
319
320	obj.protected = &rawHeader{}
321	err := obj.protected.set(headerEncryption, ctx.contentAlg)
322	if err != nil {
323		return nil, err
324	}
325
326	obj.recipients = make([]recipientInfo, len(ctx.recipients))
327
328	if len(ctx.recipients) == 0 {
329		return nil, fmt.Errorf("square/go-jose: no recipients to encrypt to")
330	}
331
332	cek, headers, err := ctx.keyGenerator.genKey()
333	if err != nil {
334		return nil, err
335	}
336
337	obj.protected.merge(&headers)
338
339	for i, info := range ctx.recipients {
340		recipient, err := info.keyEncrypter.encryptKey(cek, info.keyAlg)
341		if err != nil {
342			return nil, err
343		}
344
345		err = recipient.header.set(headerAlgorithm, info.keyAlg)
346		if err != nil {
347			return nil, err
348		}
349
350		if info.keyID != "" {
351			err = recipient.header.set(headerKeyID, info.keyID)
352			if err != nil {
353				return nil, err
354			}
355		}
356		obj.recipients[i] = recipient
357	}
358
359	if len(ctx.recipients) == 1 {
360		// Move per-recipient headers into main protected header if there's
361		// only a single recipient.
362		obj.protected.merge(obj.recipients[0].header)
363		obj.recipients[0].header = nil
364	}
365
366	if ctx.compressionAlg != NONE {
367		plaintext, err = compress(ctx.compressionAlg, plaintext)
368		if err != nil {
369			return nil, err
370		}
371
372		err = obj.protected.set(headerCompression, ctx.compressionAlg)
373		if err != nil {
374			return nil, err
375		}
376	}
377
378	for k, v := range ctx.extraHeaders {
379		b, err := json.Marshal(v)
380		if err != nil {
381			return nil, err
382		}
383		(*obj.protected)[k] = makeRawMessage(b)
384	}
385
386	authData := obj.computeAuthData()
387	parts, err := ctx.cipher.encrypt(cek, authData, plaintext)
388	if err != nil {
389		return nil, err
390	}
391
392	obj.iv = parts.iv
393	obj.ciphertext = parts.ciphertext
394	obj.tag = parts.tag
395
396	return obj, nil
397}
398
399func (ctx *genericEncrypter) Options() EncrypterOptions {
400	return EncrypterOptions{
401		Compression:  ctx.compressionAlg,
402		ExtraHeaders: ctx.extraHeaders,
403	}
404}
405
406// Decrypt and validate the object and return the plaintext. Note that this
407// function does not support multi-recipient, if you desire multi-recipient
408// decryption use DecryptMulti instead.
409func (obj JSONWebEncryption) Decrypt(decryptionKey interface{}) ([]byte, error) {
410	headers := obj.mergedHeaders(nil)
411
412	if len(obj.recipients) > 1 {
413		return nil, errors.New("square/go-jose: too many recipients in payload; expecting only one")
414	}
415
416	critical, err := headers.getCritical()
417	if err != nil {
418		return nil, fmt.Errorf("square/go-jose: invalid crit header")
419	}
420
421	if len(critical) > 0 {
422		return nil, fmt.Errorf("square/go-jose: unsupported crit header")
423	}
424
425	decrypter, err := newDecrypter(decryptionKey)
426	if err != nil {
427		return nil, err
428	}
429
430	cipher := getContentCipher(headers.getEncryption())
431	if cipher == nil {
432		return nil, fmt.Errorf("square/go-jose: unsupported enc value '%s'", string(headers.getEncryption()))
433	}
434
435	generator := randomKeyGenerator{
436		size: cipher.keySize(),
437	}
438
439	parts := &aeadParts{
440		iv:         obj.iv,
441		ciphertext: obj.ciphertext,
442		tag:        obj.tag,
443	}
444
445	authData := obj.computeAuthData()
446
447	var plaintext []byte
448	recipient := obj.recipients[0]
449	recipientHeaders := obj.mergedHeaders(&recipient)
450
451	cek, err := decrypter.decryptKey(recipientHeaders, &recipient, generator)
452	if err == nil {
453		// Found a valid CEK -- let's try to decrypt.
454		plaintext, err = cipher.decrypt(cek, authData, parts)
455	}
456
457	if plaintext == nil {
458		return nil, ErrCryptoFailure
459	}
460
461	// The "zip" header parameter may only be present in the protected header.
462	if comp := obj.protected.getCompression(); comp != "" {
463		plaintext, err = decompress(comp, plaintext)
464	}
465
466	return plaintext, err
467}
468
469// DecryptMulti decrypts and validates the object and returns the plaintexts,
470// with support for multiple recipients. It returns the index of the recipient
471// for which the decryption was successful, the merged headers for that recipient,
472// and the plaintext.
473func (obj JSONWebEncryption) DecryptMulti(decryptionKey interface{}) (int, Header, []byte, error) {
474	globalHeaders := obj.mergedHeaders(nil)
475
476	critical, err := globalHeaders.getCritical()
477	if err != nil {
478		return -1, Header{}, nil, fmt.Errorf("square/go-jose: invalid crit header")
479	}
480
481	if len(critical) > 0 {
482		return -1, Header{}, nil, fmt.Errorf("square/go-jose: unsupported crit header")
483	}
484
485	decrypter, err := newDecrypter(decryptionKey)
486	if err != nil {
487		return -1, Header{}, nil, err
488	}
489
490	encryption := globalHeaders.getEncryption()
491	cipher := getContentCipher(encryption)
492	if cipher == nil {
493		return -1, Header{}, nil, fmt.Errorf("square/go-jose: unsupported enc value '%s'", string(encryption))
494	}
495
496	generator := randomKeyGenerator{
497		size: cipher.keySize(),
498	}
499
500	parts := &aeadParts{
501		iv:         obj.iv,
502		ciphertext: obj.ciphertext,
503		tag:        obj.tag,
504	}
505
506	authData := obj.computeAuthData()
507
508	index := -1
509	var plaintext []byte
510	var headers rawHeader
511
512	for i, recipient := range obj.recipients {
513		recipientHeaders := obj.mergedHeaders(&recipient)
514
515		cek, err := decrypter.decryptKey(recipientHeaders, &recipient, generator)
516		if err == nil {
517			// Found a valid CEK -- let's try to decrypt.
518			plaintext, err = cipher.decrypt(cek, authData, parts)
519			if err == nil {
520				index = i
521				headers = recipientHeaders
522				break
523			}
524		}
525	}
526
527	if plaintext == nil || err != nil {
528		return -1, Header{}, nil, ErrCryptoFailure
529	}
530
531	// The "zip" header parameter may only be present in the protected header.
532	if comp := obj.protected.getCompression(); comp != "" {
533		plaintext, err = decompress(comp, plaintext)
534	}
535
536	sanitized, err := headers.sanitized()
537	if err != nil {
538		return -1, Header{}, nil, fmt.Errorf("square/go-jose: failed to sanitize header: %v", err)
539	}
540
541	return index, sanitized, plaintext, err
542}
543