1/*-
2 * Copyright 2014 Square Inc.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package jose
18
19import (
20	"crypto/ecdsa"
21	"crypto/rsa"
22	"errors"
23	"fmt"
24	"reflect"
25
26	"gopkg.in/square/go-jose.v2/json"
27)
28
29// Encrypter represents an encrypter which produces an encrypted JWE object.
30type Encrypter interface {
31	Encrypt(plaintext []byte) (*JSONWebEncryption, error)
32	EncryptWithAuthData(plaintext []byte, aad []byte) (*JSONWebEncryption, error)
33	Options() EncrypterOptions
34}
35
36// A generic content cipher
37type contentCipher interface {
38	keySize() int
39	encrypt(cek []byte, aad, plaintext []byte) (*aeadParts, error)
40	decrypt(cek []byte, aad []byte, parts *aeadParts) ([]byte, error)
41}
42
43// A key generator (for generating/getting a CEK)
44type keyGenerator interface {
45	keySize() int
46	genKey() ([]byte, rawHeader, error)
47}
48
49// A generic key encrypter
50type keyEncrypter interface {
51	encryptKey(cek []byte, alg KeyAlgorithm) (recipientInfo, error) // Encrypt a key
52}
53
54// A generic key decrypter
55type keyDecrypter interface {
56	decryptKey(headers rawHeader, recipient *recipientInfo, generator keyGenerator) ([]byte, error) // Decrypt a key
57}
58
// A generic encrypter based on the given key encrypter and content cipher.
// It is the concrete value returned as an Encrypter by NewEncrypter and
// NewMultiEncrypter.
type genericEncrypter struct {
	// contentAlg is stored in the protected header under headerEncryption and
	// selects the content cipher.
	contentAlg     ContentEncryption
	// compressionAlg, when not NONE, compresses the plaintext before it is
	// encrypted and is recorded in the protected header.
	compressionAlg CompressionAlgorithm
	// cipher encrypts the (possibly compressed) plaintext under the CEK.
	cipher         contentCipher
	// recipients each wrap the same CEK with their own key encrypter.
	recipients     []recipientKeyInfo
	// keyGenerator yields the CEK (and extra protected-header values) for
	// every encryption: random for key-wrapping algorithms, the shared key
	// itself for DIRECT, and key agreement for ECDH_ES.
	keyGenerator   keyGenerator
	// extraHeaders are JSON-marshaled into the protected header on each
	// encryption.
	extraHeaders   map[HeaderKey]interface{}
}
68
// recipientKeyInfo holds what is needed to wrap the CEK for one recipient.
type recipientKeyInfo struct {
	// keyID, when non-empty, is emitted as the recipient's key ID header.
	keyID        string
	// keyAlg is emitted as the recipient's algorithm header and is passed to
	// keyEncrypter.encryptKey.
	keyAlg       KeyAlgorithm
	// keyEncrypter wraps the CEK for this recipient.
	keyEncrypter keyEncrypter
}
74
// EncrypterOptions represents options that can be set on new encrypters.
type EncrypterOptions struct {
	// Compression selects the algorithm used to compress the plaintext before
	// encryption. Any value other than NONE is also recorded in the protected
	// header.
	Compression CompressionAlgorithm

	// Optional map of additional keys to be inserted into the protected header
	// of a JWE object. Some specifications which make use of JWE like to insert
	// additional values here. All values must be JSON-serializable; a value
	// that fails to marshal makes encryption fail.
	//
	// NOTE(review): these entries are written after the encrypter's own header
	// fields, so a key that collides with one of them (e.g. the "enc" or "zip"
	// field) replaces the value the library set.
	ExtraHeaders map[HeaderKey]interface{}
}
84
85// WithHeader adds an arbitrary value to the ExtraHeaders map, initializing it
86// if necessary. It returns itself and so can be used in a fluent style.
87func (eo *EncrypterOptions) WithHeader(k HeaderKey, v interface{}) *EncrypterOptions {
88	if eo.ExtraHeaders == nil {
89		eo.ExtraHeaders = map[HeaderKey]interface{}{}
90	}
91	eo.ExtraHeaders[k] = v
92	return eo
93}
94
95// WithContentType adds a content type ("cty") header and returns the updated
96// EncrypterOptions.
97func (eo *EncrypterOptions) WithContentType(contentType ContentType) *EncrypterOptions {
98	return eo.WithHeader(HeaderContentType, contentType)
99}
100
101// WithType adds a type ("typ") header and returns the updated EncrypterOptions.
102func (eo *EncrypterOptions) WithType(typ ContentType) *EncrypterOptions {
103	return eo.WithHeader(HeaderType, typ)
104}
105
// Recipient represents an algorithm/key to encrypt messages to.
//
// Which Key types are accepted depends on Algorithm; see NewEncrypter and
// makeJWERecipient. A non-empty KeyID overrides any key ID carried by Key
// itself (for example a JSONWebKey's KeyID).
//
// PBES2Count and PBES2Salt correspond with the "p2c" and "p2s" headers used
// on the password-based encryption algorithms PBES2-HS256+A128KW,
// PBES2-HS384+A192KW, and PBES2-HS512+A256KW. If they are not provided a safe
// default of 100000 will be used for the count and a 128-bit random salt will
// be generated.
type Recipient struct {
	Algorithm  KeyAlgorithm
	Key        interface{}
	KeyID      string
	PBES2Count int
	PBES2Salt  []byte
}
120
121// NewEncrypter creates an appropriate encrypter based on the key type
122func NewEncrypter(enc ContentEncryption, rcpt Recipient, opts *EncrypterOptions) (Encrypter, error) {
123	encrypter := &genericEncrypter{
124		contentAlg: enc,
125		recipients: []recipientKeyInfo{},
126		cipher:     getContentCipher(enc),
127	}
128	if opts != nil {
129		encrypter.compressionAlg = opts.Compression
130		encrypter.extraHeaders = opts.ExtraHeaders
131	}
132
133	if encrypter.cipher == nil {
134		return nil, ErrUnsupportedAlgorithm
135	}
136
137	var keyID string
138	var rawKey interface{}
139	switch encryptionKey := rcpt.Key.(type) {
140	case JSONWebKey:
141		keyID, rawKey = encryptionKey.KeyID, encryptionKey.Key
142	case *JSONWebKey:
143		keyID, rawKey = encryptionKey.KeyID, encryptionKey.Key
144	case OpaqueKeyEncrypter:
145		keyID, rawKey = encryptionKey.KeyID(), encryptionKey
146	default:
147		rawKey = encryptionKey
148	}
149
150	switch rcpt.Algorithm {
151	case DIRECT:
152		// Direct encryption mode must be treated differently
153		if reflect.TypeOf(rawKey) != reflect.TypeOf([]byte{}) {
154			return nil, ErrUnsupportedKeyType
155		}
156		if encrypter.cipher.keySize() != len(rawKey.([]byte)) {
157			return nil, ErrInvalidKeySize
158		}
159		encrypter.keyGenerator = staticKeyGenerator{
160			key: rawKey.([]byte),
161		}
162		recipientInfo, _ := newSymmetricRecipient(rcpt.Algorithm, rawKey.([]byte))
163		recipientInfo.keyID = keyID
164		if rcpt.KeyID != "" {
165			recipientInfo.keyID = rcpt.KeyID
166		}
167		encrypter.recipients = []recipientKeyInfo{recipientInfo}
168		return encrypter, nil
169	case ECDH_ES:
170		// ECDH-ES (w/o key wrapping) is similar to DIRECT mode
171		typeOf := reflect.TypeOf(rawKey)
172		if typeOf != reflect.TypeOf(&ecdsa.PublicKey{}) {
173			return nil, ErrUnsupportedKeyType
174		}
175		encrypter.keyGenerator = ecKeyGenerator{
176			size:      encrypter.cipher.keySize(),
177			algID:     string(enc),
178			publicKey: rawKey.(*ecdsa.PublicKey),
179		}
180		recipientInfo, _ := newECDHRecipient(rcpt.Algorithm, rawKey.(*ecdsa.PublicKey))
181		recipientInfo.keyID = keyID
182		if rcpt.KeyID != "" {
183			recipientInfo.keyID = rcpt.KeyID
184		}
185		encrypter.recipients = []recipientKeyInfo{recipientInfo}
186		return encrypter, nil
187	default:
188		// Can just add a standard recipient
189		encrypter.keyGenerator = randomKeyGenerator{
190			size: encrypter.cipher.keySize(),
191		}
192		err := encrypter.addRecipient(rcpt)
193		return encrypter, err
194	}
195}
196
197// NewMultiEncrypter creates a multi-encrypter based on the given parameters
198func NewMultiEncrypter(enc ContentEncryption, rcpts []Recipient, opts *EncrypterOptions) (Encrypter, error) {
199	cipher := getContentCipher(enc)
200
201	if cipher == nil {
202		return nil, ErrUnsupportedAlgorithm
203	}
204	if rcpts == nil || len(rcpts) == 0 {
205		return nil, fmt.Errorf("square/go-jose: recipients is nil or empty")
206	}
207
208	encrypter := &genericEncrypter{
209		contentAlg: enc,
210		recipients: []recipientKeyInfo{},
211		cipher:     cipher,
212		keyGenerator: randomKeyGenerator{
213			size: cipher.keySize(),
214		},
215	}
216
217	if opts != nil {
218		encrypter.compressionAlg = opts.Compression
219	}
220
221	for _, recipient := range rcpts {
222		err := encrypter.addRecipient(recipient)
223		if err != nil {
224			return nil, err
225		}
226	}
227
228	return encrypter, nil
229}
230
231func (ctx *genericEncrypter) addRecipient(recipient Recipient) (err error) {
232	var recipientInfo recipientKeyInfo
233
234	switch recipient.Algorithm {
235	case DIRECT, ECDH_ES:
236		return fmt.Errorf("square/go-jose: key algorithm '%s' not supported in multi-recipient mode", recipient.Algorithm)
237	}
238
239	recipientInfo, err = makeJWERecipient(recipient.Algorithm, recipient.Key)
240	if recipient.KeyID != "" {
241		recipientInfo.keyID = recipient.KeyID
242	}
243
244	switch recipient.Algorithm {
245	case PBES2_HS256_A128KW, PBES2_HS384_A192KW, PBES2_HS512_A256KW:
246		if sr, ok := recipientInfo.keyEncrypter.(*symmetricKeyCipher); ok {
247			sr.p2c = recipient.PBES2Count
248			sr.p2s = recipient.PBES2Salt
249		}
250	}
251
252	if err == nil {
253		ctx.recipients = append(ctx.recipients, recipientInfo)
254	}
255	return err
256}
257
258func makeJWERecipient(alg KeyAlgorithm, encryptionKey interface{}) (recipientKeyInfo, error) {
259	switch encryptionKey := encryptionKey.(type) {
260	case *rsa.PublicKey:
261		return newRSARecipient(alg, encryptionKey)
262	case *ecdsa.PublicKey:
263		return newECDHRecipient(alg, encryptionKey)
264	case []byte:
265		return newSymmetricRecipient(alg, encryptionKey)
266	case string:
267		return newSymmetricRecipient(alg, []byte(encryptionKey))
268	case *JSONWebKey:
269		recipient, err := makeJWERecipient(alg, encryptionKey.Key)
270		recipient.keyID = encryptionKey.KeyID
271		return recipient, err
272	}
273	if encrypter, ok := encryptionKey.(OpaqueKeyEncrypter); ok {
274		return newOpaqueKeyEncrypter(alg, encrypter)
275	}
276	return recipientKeyInfo{}, ErrUnsupportedKeyType
277}
278
279// newDecrypter creates an appropriate decrypter based on the key type
280func newDecrypter(decryptionKey interface{}) (keyDecrypter, error) {
281	switch decryptionKey := decryptionKey.(type) {
282	case *rsa.PrivateKey:
283		return &rsaDecrypterSigner{
284			privateKey: decryptionKey,
285		}, nil
286	case *ecdsa.PrivateKey:
287		return &ecDecrypterSigner{
288			privateKey: decryptionKey,
289		}, nil
290	case []byte:
291		return &symmetricKeyCipher{
292			key: decryptionKey,
293		}, nil
294	case string:
295		return &symmetricKeyCipher{
296			key: []byte(decryptionKey),
297		}, nil
298	case JSONWebKey:
299		return newDecrypter(decryptionKey.Key)
300	case *JSONWebKey:
301		return newDecrypter(decryptionKey.Key)
302	}
303	if okd, ok := decryptionKey.(OpaqueKeyDecrypter); ok {
304		return &opaqueKeyDecrypter{decrypter: okd}, nil
305	}
306	return nil, ErrUnsupportedKeyType
307}
308
309// Implementation of encrypt method producing a JWE object.
310func (ctx *genericEncrypter) Encrypt(plaintext []byte) (*JSONWebEncryption, error) {
311	return ctx.EncryptWithAuthData(plaintext, nil)
312}
313
// EncryptWithAuthData encrypts plaintext into a new JWE object, attaching aad
// (which may be nil) as additional authenticated data.
//
// The steps are order-sensitive: all protected-header fields are put in place
// before computeAuthData is called. NOTE(review): presumably because the
// protected header is part of the authenticated data — confirm in
// computeAuthData.
//
// NOTE(review): ctx.extraHeaders are written last and unconditionally, so an
// extra header whose key collides with a field already in the protected
// header (e.g. the encryption or compression field) overwrites it.
func (ctx *genericEncrypter) EncryptWithAuthData(plaintext, aad []byte) (*JSONWebEncryption, error) {
	obj := &JSONWebEncryption{}
	obj.aad = aad

	// The protected header always starts with the content encryption algorithm.
	obj.protected = &rawHeader{}
	err := obj.protected.set(headerEncryption, ctx.contentAlg)
	if err != nil {
		return nil, err
	}

	obj.recipients = make([]recipientInfo, len(ctx.recipients))

	if len(ctx.recipients) == 0 {
		return nil, fmt.Errorf("square/go-jose: no recipients to encrypt to")
	}

	// Obtain the CEK; any header values the generator returns go into the
	// protected header.
	cek, headers, err := ctx.keyGenerator.genKey()
	if err != nil {
		return nil, err
	}

	obj.protected.merge(&headers)

	// Wrap the same CEK once per recipient and label each result with its
	// algorithm and, if known, key ID.
	for i, info := range ctx.recipients {
		recipient, err := info.keyEncrypter.encryptKey(cek, info.keyAlg)
		if err != nil {
			return nil, err
		}

		err = recipient.header.set(headerAlgorithm, info.keyAlg)
		if err != nil {
			return nil, err
		}

		if info.keyID != "" {
			err = recipient.header.set(headerKeyID, info.keyID)
			if err != nil {
				return nil, err
			}
		}
		obj.recipients[i] = recipient
	}

	if len(ctx.recipients) == 1 {
		// Move per-recipient headers into main protected header if there's
		// only a single recipient.
		obj.protected.merge(obj.recipients[0].header)
		obj.recipients[0].header = nil
	}

	// Compression is applied to the plaintext before encryption and advertised
	// in the protected header.
	if ctx.compressionAlg != NONE {
		plaintext, err = compress(ctx.compressionAlg, plaintext)
		if err != nil {
			return nil, err
		}

		err = obj.protected.set(headerCompression, ctx.compressionAlg)
		if err != nil {
			return nil, err
		}
	}

	// User-supplied headers are JSON-marshaled into the protected header.
	for k, v := range ctx.extraHeaders {
		b, err := json.Marshal(v)
		if err != nil {
			return nil, err
		}
		(*obj.protected)[k] = makeRawMessage(b)
	}

	// The protected header is complete at this point; only now is the
	// authenticated data computed and the content encrypted.
	authData := obj.computeAuthData()
	parts, err := ctx.cipher.encrypt(cek, authData, plaintext)
	if err != nil {
		return nil, err
	}

	obj.iv = parts.iv
	obj.ciphertext = parts.ciphertext
	obj.tag = parts.tag

	return obj, nil
}
397
398func (ctx *genericEncrypter) Options() EncrypterOptions {
399	return EncrypterOptions{
400		Compression:  ctx.compressionAlg,
401		ExtraHeaders: ctx.extraHeaders,
402	}
403}
404
405// Decrypt and validate the object and return the plaintext. Note that this
406// function does not support multi-recipient, if you desire multi-recipient
407// decryption use DecryptMulti instead.
408func (obj JSONWebEncryption) Decrypt(decryptionKey interface{}) ([]byte, error) {
409	headers := obj.mergedHeaders(nil)
410
411	if len(obj.recipients) > 1 {
412		return nil, errors.New("square/go-jose: too many recipients in payload; expecting only one")
413	}
414
415	critical, err := headers.getCritical()
416	if err != nil {
417		return nil, fmt.Errorf("square/go-jose: invalid crit header")
418	}
419
420	if len(critical) > 0 {
421		return nil, fmt.Errorf("square/go-jose: unsupported crit header")
422	}
423
424	decrypter, err := newDecrypter(decryptionKey)
425	if err != nil {
426		return nil, err
427	}
428
429	cipher := getContentCipher(headers.getEncryption())
430	if cipher == nil {
431		return nil, fmt.Errorf("square/go-jose: unsupported enc value '%s'", string(headers.getEncryption()))
432	}
433
434	generator := randomKeyGenerator{
435		size: cipher.keySize(),
436	}
437
438	parts := &aeadParts{
439		iv:         obj.iv,
440		ciphertext: obj.ciphertext,
441		tag:        obj.tag,
442	}
443
444	authData := obj.computeAuthData()
445
446	var plaintext []byte
447	recipient := obj.recipients[0]
448	recipientHeaders := obj.mergedHeaders(&recipient)
449
450	cek, err := decrypter.decryptKey(recipientHeaders, &recipient, generator)
451	if err == nil {
452		// Found a valid CEK -- let's try to decrypt.
453		plaintext, err = cipher.decrypt(cek, authData, parts)
454	}
455
456	if plaintext == nil {
457		return nil, ErrCryptoFailure
458	}
459
460	// The "zip" header parameter may only be present in the protected header.
461	if comp := obj.protected.getCompression(); comp != "" {
462		plaintext, err = decompress(comp, plaintext)
463	}
464
465	return plaintext, err
466}
467
468// DecryptMulti decrypts and validates the object and returns the plaintexts,
469// with support for multiple recipients. It returns the index of the recipient
470// for which the decryption was successful, the merged headers for that recipient,
471// and the plaintext.
472func (obj JSONWebEncryption) DecryptMulti(decryptionKey interface{}) (int, Header, []byte, error) {
473	globalHeaders := obj.mergedHeaders(nil)
474
475	critical, err := globalHeaders.getCritical()
476	if err != nil {
477		return -1, Header{}, nil, fmt.Errorf("square/go-jose: invalid crit header")
478	}
479
480	if len(critical) > 0 {
481		return -1, Header{}, nil, fmt.Errorf("square/go-jose: unsupported crit header")
482	}
483
484	decrypter, err := newDecrypter(decryptionKey)
485	if err != nil {
486		return -1, Header{}, nil, err
487	}
488
489	encryption := globalHeaders.getEncryption()
490	cipher := getContentCipher(encryption)
491	if cipher == nil {
492		return -1, Header{}, nil, fmt.Errorf("square/go-jose: unsupported enc value '%s'", string(encryption))
493	}
494
495	generator := randomKeyGenerator{
496		size: cipher.keySize(),
497	}
498
499	parts := &aeadParts{
500		iv:         obj.iv,
501		ciphertext: obj.ciphertext,
502		tag:        obj.tag,
503	}
504
505	authData := obj.computeAuthData()
506
507	index := -1
508	var plaintext []byte
509	var headers rawHeader
510
511	for i, recipient := range obj.recipients {
512		recipientHeaders := obj.mergedHeaders(&recipient)
513
514		cek, err := decrypter.decryptKey(recipientHeaders, &recipient, generator)
515		if err == nil {
516			// Found a valid CEK -- let's try to decrypt.
517			plaintext, err = cipher.decrypt(cek, authData, parts)
518			if err == nil {
519				index = i
520				headers = recipientHeaders
521				break
522			}
523		}
524	}
525
526	if plaintext == nil || err != nil {
527		return -1, Header{}, nil, ErrCryptoFailure
528	}
529
530	// The "zip" header parameter may only be present in the protected header.
531	if comp := obj.protected.getCompression(); comp != "" {
532		plaintext, err = decompress(comp, plaintext)
533	}
534
535	sanitized, err := headers.sanitized()
536	if err != nil {
537		return -1, Header{}, nil, fmt.Errorf("square/go-jose: failed to sanitize header: %v", err)
538	}
539
540	return index, sanitized, plaintext, err
541}
542