1/*-
2 * Copyright 2014 Square Inc.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package jose
18
19import (
20	"crypto/elliptic"
21	"crypto/x509"
22	"encoding/base64"
23	"errors"
24	"fmt"
25
26	"gopkg.in/square/go-jose.v2/json"
27)
28
// Named string types for the algorithm and content identifiers that appear in
// JOSE headers. Keeping them distinct lets the compiler reject, for example, a
// SignatureAlgorithm passed where a KeyAlgorithm is expected.
type (
	// KeyAlgorithm represents a key management algorithm, as carried in the
	// "alg" header of an encrypted object.
	KeyAlgorithm string

	// SignatureAlgorithm represents a signature (or MAC) algorithm, as
	// carried in the "alg" header of a signed object.
	SignatureAlgorithm string

	// ContentEncryption represents a content encryption algorithm, as carried
	// in the "enc" header.
	ContentEncryption string

	// CompressionAlgorithm represents an algorithm used for plaintext
	// compression, as carried in the "zip" header.
	CompressionAlgorithm string

	// ContentType represents the type of the contained data.
	ContentType string
)
43
var (
	// ErrCryptoFailure represents an error in a cryptographic primitive. This
	// occurs when, for example, a message had an invalid authentication tag or
	// could not be decrypted.
	ErrCryptoFailure = errors.New("square/go-jose: error in cryptographic primitive")

	// ErrUnsupportedAlgorithm indicates that a selected algorithm is not
	// supported. This occurs when trying to instantiate an encrypter for an
	// algorithm that is not yet implemented.
	ErrUnsupportedAlgorithm = errors.New("square/go-jose: unknown/unsupported algorithm")

	// ErrUnsupportedKeyType indicates that the given key type/format is not
	// supported. This occurs when trying to instantiate an encrypter and passing
	// it a key of an unrecognized type or with unsupported parameters, such as
	// an RSA private key with more than two primes.
	ErrUnsupportedKeyType = errors.New("square/go-jose: unsupported key type/format")

	// ErrInvalidKeySize indicates that the given key is not the correct size
	// for the selected algorithm. This can occur, for example, when trying to
	// encrypt with AES-256 but passing only a 128-bit key as input.
	ErrInvalidKeySize = errors.New("square/go-jose: invalid key size for algorithm")

	// ErrNotSupported indicates that the requested serialization of an object
	// is not supported. This occurs when trying to compact-serialize an object
	// that can't be represented in compact form.
	ErrNotSupported = errors.New("square/go-jose: compact serialization not supported for object")

	// ErrUnprotectedNonce indicates that while parsing a JWS or JWE object, a
	// nonce header parameter was included in an unprotected header object.
	//
	// The capitalised "Nonce" in the message breaks the usual lower-case error
	// string convention. It is kept verbatim in case callers match on the text.
	ErrUnprotectedNonce = errors.New("square/go-jose: Nonce parameter included in unprotected header")
)
75
76// Key management algorithms
77const (
78	ED25519            = KeyAlgorithm("ED25519")
79	RSA1_5             = KeyAlgorithm("RSA1_5")             // RSA-PKCS1v1.5
80	RSA_OAEP           = KeyAlgorithm("RSA-OAEP")           // RSA-OAEP-SHA1
81	RSA_OAEP_256       = KeyAlgorithm("RSA-OAEP-256")       // RSA-OAEP-SHA256
82	A128KW             = KeyAlgorithm("A128KW")             // AES key wrap (128)
83	A192KW             = KeyAlgorithm("A192KW")             // AES key wrap (192)
84	A256KW             = KeyAlgorithm("A256KW")             // AES key wrap (256)
85	DIRECT             = KeyAlgorithm("dir")                // Direct encryption
86	ECDH_ES            = KeyAlgorithm("ECDH-ES")            // ECDH-ES
87	ECDH_ES_A128KW     = KeyAlgorithm("ECDH-ES+A128KW")     // ECDH-ES + AES key wrap (128)
88	ECDH_ES_A192KW     = KeyAlgorithm("ECDH-ES+A192KW")     // ECDH-ES + AES key wrap (192)
89	ECDH_ES_A256KW     = KeyAlgorithm("ECDH-ES+A256KW")     // ECDH-ES + AES key wrap (256)
90	A128GCMKW          = KeyAlgorithm("A128GCMKW")          // AES-GCM key wrap (128)
91	A192GCMKW          = KeyAlgorithm("A192GCMKW")          // AES-GCM key wrap (192)
92	A256GCMKW          = KeyAlgorithm("A256GCMKW")          // AES-GCM key wrap (256)
93	PBES2_HS256_A128KW = KeyAlgorithm("PBES2-HS256+A128KW") // PBES2 + HMAC-SHA256 + AES key wrap (128)
94	PBES2_HS384_A192KW = KeyAlgorithm("PBES2-HS384+A192KW") // PBES2 + HMAC-SHA384 + AES key wrap (192)
95	PBES2_HS512_A256KW = KeyAlgorithm("PBES2-HS512+A256KW") // PBES2 + HMAC-SHA512 + AES key wrap (256)
96)
97
98// Signature algorithms
99const (
100	EdDSA = SignatureAlgorithm("EdDSA")
101	HS256 = SignatureAlgorithm("HS256") // HMAC using SHA-256
102	HS384 = SignatureAlgorithm("HS384") // HMAC using SHA-384
103	HS512 = SignatureAlgorithm("HS512") // HMAC using SHA-512
104	RS256 = SignatureAlgorithm("RS256") // RSASSA-PKCS-v1.5 using SHA-256
105	RS384 = SignatureAlgorithm("RS384") // RSASSA-PKCS-v1.5 using SHA-384
106	RS512 = SignatureAlgorithm("RS512") // RSASSA-PKCS-v1.5 using SHA-512
107	ES256 = SignatureAlgorithm("ES256") // ECDSA using P-256 and SHA-256
108	ES384 = SignatureAlgorithm("ES384") // ECDSA using P-384 and SHA-384
109	ES512 = SignatureAlgorithm("ES512") // ECDSA using P-521 and SHA-512
110	PS256 = SignatureAlgorithm("PS256") // RSASSA-PSS using SHA256 and MGF1-SHA256
111	PS384 = SignatureAlgorithm("PS384") // RSASSA-PSS using SHA384 and MGF1-SHA384
112	PS512 = SignatureAlgorithm("PS512") // RSASSA-PSS using SHA512 and MGF1-SHA512
113)
114
115// Content encryption algorithms
116const (
117	A128CBC_HS256 = ContentEncryption("A128CBC-HS256") // AES-CBC + HMAC-SHA256 (128)
118	A192CBC_HS384 = ContentEncryption("A192CBC-HS384") // AES-CBC + HMAC-SHA384 (192)
119	A256CBC_HS512 = ContentEncryption("A256CBC-HS512") // AES-CBC + HMAC-SHA512 (256)
120	A128GCM       = ContentEncryption("A128GCM")       // AES-GCM (128)
121	A192GCM       = ContentEncryption("A192GCM")       // AES-GCM (192)
122	A256GCM       = ContentEncryption("A256GCM")       // AES-GCM (256)
123)
124
125// Compression algorithms
126const (
127	NONE    = CompressionAlgorithm("")    // No compression
128	DEFLATE = CompressionAlgorithm("DEF") // DEFLATE (RFC 1951)
129)
130
// HeaderKey names an entry in a JOSE (JWS or JWE) header. Use of the Header...
// constants is preferred over string literals to enhance type safety.
type HeaderKey string

// Header names that consumers of the library may set.
//
// NOTE(review): only HeaderType is declared with type HeaderKey.
// HeaderContentType is an untyped string constant. Both are kept exactly as
// they were because changing a constant's type could break callers.
const (
	HeaderType        HeaderKey = "typ" // string
	HeaderContentType           = "cty" // string
)

// Header names that are set by go-jose itself; consumers of the library
// shouldn't need to set these. The trailing comments give the Go type each
// entry's value is decoded into.
const (
	headerAlgorithm   = "alg"  // string
	headerEncryption  = "enc"  // ContentEncryption
	headerCompression = "zip"  // CompressionAlgorithm
	headerCritical    = "crit" // []string

	headerAPU = "apu" // *byteBuffer
	headerAPV = "apv" // *byteBuffer
	headerEPK = "epk" // *JSONWebKey
	headerIV  = "iv"  // *byteBuffer
	headerTag = "tag" // *byteBuffer
	headerX5c = "x5c" // []*x509.Certificate

	headerJWK   = "jwk"   // *JSONWebKey
	headerKeyID = "kid"   // string
	headerNonce = "nonce" // string
	headerB64   = "b64"   // bool

	headerP2C = "p2c" // int
	headerP2S = "p2s" // *byteBuffer ([]byte)
)
162
// supportedCritical is the set of supported extensions that are understood and processed.
//
// Keys are header names; a true value marks the extension as supported. The
// only entry is headerB64 ("b64").
// NOTE(review): presumably consulted when validating the names listed in a
// "crit" header (see getCritical). Confirm at the call site.
var supportedCritical = map[string]bool{
	headerB64: true,
}
167
// rawHeader represents the JOSE header for JWE/JWS objects (used for parsing).
//
// The decoding of the constituent items is deferred because we want to unmarshal
// some members into particular structs rather than generic maps, but at the
// same time we need to receive any extra fields unhandled by this library to
// pass through to consuming code in case it wants to examine them.
//
// An entry may hold a nil *json.RawMessage. The accessor methods on this type
// treat a nil entry the same as a missing key.
type rawHeader map[HeaderKey]*json.RawMessage
175
// Header represents the read-only JOSE header for JWE/JWS objects.
//
// In this file a Header is built from a rawHeader by sanitized. That method
// fills the named fields from the header entries noted beside them.
type Header struct {
	KeyID      string      // from the "kid" header
	JSONWebKey *JSONWebKey // from the "jwk" header
	Algorithm  string      // from the "alg" header
	Nonce      string      // from the "nonce" header

	// Unverified certificate chain parsed from x5c header.
	// Call Certificates to verify it. That method treats the first element as
	// the leaf and the remaining elements as intermediates.
	certificates []*x509.Certificate

	// Any headers not recognised above get unmarshalled
	// from JSON in a generic manner and placed in this map.
	// In sanitized this includes entries such as "enc", "crit" or "typ",
	// which have no dedicated field here.
	ExtraHeaders map[HeaderKey]interface{}
}
190
191// Certificates verifies & returns the certificate chain present
192// in the x5c header field of a message, if one was present. Returns
193// an error if there was no x5c header present or the chain could
194// not be validated with the given verify options.
195func (h Header) Certificates(opts x509.VerifyOptions) ([][]*x509.Certificate, error) {
196	if len(h.certificates) == 0 {
197		return nil, errors.New("square/go-jose: no x5c header present in message")
198	}
199
200	leaf := h.certificates[0]
201	if opts.Intermediates == nil {
202		opts.Intermediates = x509.NewCertPool()
203		for _, intermediate := range h.certificates[1:] {
204			opts.Intermediates.AddCert(intermediate)
205		}
206	}
207
208	return leaf.Verify(opts)
209}
210
211func (parsed rawHeader) set(k HeaderKey, v interface{}) error {
212	b, err := json.Marshal(v)
213	if err != nil {
214		return err
215	}
216
217	parsed[k] = makeRawMessage(b)
218	return nil
219}
220
221// getString gets a string from the raw JSON, defaulting to "".
222func (parsed rawHeader) getString(k HeaderKey) string {
223	v, ok := parsed[k]
224	if !ok || v == nil {
225		return ""
226	}
227	var s string
228	err := json.Unmarshal(*v, &s)
229	if err != nil {
230		return ""
231	}
232	return s
233}
234
235// getByteBuffer gets a byte buffer from the raw JSON. Returns (nil, nil) if
236// not specified.
237func (parsed rawHeader) getByteBuffer(k HeaderKey) (*byteBuffer, error) {
238	v := parsed[k]
239	if v == nil {
240		return nil, nil
241	}
242	var bb *byteBuffer
243	err := json.Unmarshal(*v, &bb)
244	if err != nil {
245		return nil, err
246	}
247	return bb, nil
248}
249
250// getAlgorithm extracts parsed "alg" from the raw JSON as a KeyAlgorithm.
251func (parsed rawHeader) getAlgorithm() KeyAlgorithm {
252	return KeyAlgorithm(parsed.getString(headerAlgorithm))
253}
254
255// getSignatureAlgorithm extracts parsed "alg" from the raw JSON as a SignatureAlgorithm.
256func (parsed rawHeader) getSignatureAlgorithm() SignatureAlgorithm {
257	return SignatureAlgorithm(parsed.getString(headerAlgorithm))
258}
259
260// getEncryption extracts parsed "enc" from the raw JSON.
261func (parsed rawHeader) getEncryption() ContentEncryption {
262	return ContentEncryption(parsed.getString(headerEncryption))
263}
264
265// getCompression extracts parsed "zip" from the raw JSON.
266func (parsed rawHeader) getCompression() CompressionAlgorithm {
267	return CompressionAlgorithm(parsed.getString(headerCompression))
268}
269
270func (parsed rawHeader) getNonce() string {
271	return parsed.getString(headerNonce)
272}
273
274// getEPK extracts parsed "epk" from the raw JSON.
275func (parsed rawHeader) getEPK() (*JSONWebKey, error) {
276	v := parsed[headerEPK]
277	if v == nil {
278		return nil, nil
279	}
280	var epk *JSONWebKey
281	err := json.Unmarshal(*v, &epk)
282	if err != nil {
283		return nil, err
284	}
285	return epk, nil
286}
287
288// getAPU extracts parsed "apu" from the raw JSON.
289func (parsed rawHeader) getAPU() (*byteBuffer, error) {
290	return parsed.getByteBuffer(headerAPU)
291}
292
293// getAPV extracts parsed "apv" from the raw JSON.
294func (parsed rawHeader) getAPV() (*byteBuffer, error) {
295	return parsed.getByteBuffer(headerAPV)
296}
297
298// getIV extracts parsed "iv" from the raw JSON.
299func (parsed rawHeader) getIV() (*byteBuffer, error) {
300	return parsed.getByteBuffer(headerIV)
301}
302
303// getTag extracts parsed "tag" from the raw JSON.
304func (parsed rawHeader) getTag() (*byteBuffer, error) {
305	return parsed.getByteBuffer(headerTag)
306}
307
308// getJWK extracts parsed "jwk" from the raw JSON.
309func (parsed rawHeader) getJWK() (*JSONWebKey, error) {
310	v := parsed[headerJWK]
311	if v == nil {
312		return nil, nil
313	}
314	var jwk *JSONWebKey
315	err := json.Unmarshal(*v, &jwk)
316	if err != nil {
317		return nil, err
318	}
319	return jwk, nil
320}
321
322// getCritical extracts parsed "crit" from the raw JSON. If omitted, it
323// returns an empty slice.
324func (parsed rawHeader) getCritical() ([]string, error) {
325	v := parsed[headerCritical]
326	if v == nil {
327		return nil, nil
328	}
329
330	var q []string
331	err := json.Unmarshal(*v, &q)
332	if err != nil {
333		return nil, err
334	}
335	return q, nil
336}
337
338// getS2C extracts parsed "p2c" from the raw JSON.
339func (parsed rawHeader) getP2C() (int, error) {
340	v := parsed[headerP2C]
341	if v == nil {
342		return 0, nil
343	}
344
345	var p2c int
346	err := json.Unmarshal(*v, &p2c)
347	if err != nil {
348		return 0, err
349	}
350	return p2c, nil
351}
352
353// getS2S extracts parsed "p2s" from the raw JSON.
354func (parsed rawHeader) getP2S() (*byteBuffer, error) {
355	return parsed.getByteBuffer(headerP2S)
356}
357
358// getB64 extracts parsed "b64" from the raw JSON, defaulting to true.
359func (parsed rawHeader) getB64() (bool, error) {
360	v := parsed[headerB64]
361	if v == nil {
362		return true, nil
363	}
364
365	var b64 bool
366	err := json.Unmarshal(*v, &b64)
367	if err != nil {
368		return true, err
369	}
370	return b64, nil
371}
372
373// sanitized produces a cleaned-up header object from the raw JSON.
374func (parsed rawHeader) sanitized() (h Header, err error) {
375	for k, v := range parsed {
376		if v == nil {
377			continue
378		}
379		switch k {
380		case headerJWK:
381			var jwk *JSONWebKey
382			err = json.Unmarshal(*v, &jwk)
383			if err != nil {
384				err = fmt.Errorf("failed to unmarshal JWK: %v: %#v", err, string(*v))
385				return
386			}
387			h.JSONWebKey = jwk
388		case headerKeyID:
389			var s string
390			err = json.Unmarshal(*v, &s)
391			if err != nil {
392				err = fmt.Errorf("failed to unmarshal key ID: %v: %#v", err, string(*v))
393				return
394			}
395			h.KeyID = s
396		case headerAlgorithm:
397			var s string
398			err = json.Unmarshal(*v, &s)
399			if err != nil {
400				err = fmt.Errorf("failed to unmarshal algorithm: %v: %#v", err, string(*v))
401				return
402			}
403			h.Algorithm = s
404		case headerNonce:
405			var s string
406			err = json.Unmarshal(*v, &s)
407			if err != nil {
408				err = fmt.Errorf("failed to unmarshal nonce: %v: %#v", err, string(*v))
409				return
410			}
411			h.Nonce = s
412		case headerX5c:
413			c := []string{}
414			err = json.Unmarshal(*v, &c)
415			if err != nil {
416				err = fmt.Errorf("failed to unmarshal x5c header: %v: %#v", err, string(*v))
417				return
418			}
419			h.certificates, err = parseCertificateChain(c)
420			if err != nil {
421				err = fmt.Errorf("failed to unmarshal x5c header: %v: %#v", err, string(*v))
422				return
423			}
424		default:
425			if h.ExtraHeaders == nil {
426				h.ExtraHeaders = map[HeaderKey]interface{}{}
427			}
428			var v2 interface{}
429			err = json.Unmarshal(*v, &v2)
430			if err != nil {
431				err = fmt.Errorf("failed to unmarshal value: %v: %#v", err, string(*v))
432				return
433			}
434			h.ExtraHeaders[k] = v2
435		}
436	}
437	return
438}
439
// parseCertificateChain decodes each element of chain as standard
// (padded, non-URL) base64 DER and parses it as an X.509 certificate,
// preserving order. It stops at the first decode or parse failure and returns
// (nil, err). An empty chain yields a non-nil, zero-length slice.
func parseCertificateChain(chain []string) ([]*x509.Certificate, error) {
	certs := make([]*x509.Certificate, 0, len(chain))
	for _, encoded := range chain {
		der, err := base64.StdEncoding.DecodeString(encoded)
		if err != nil {
			return nil, err
		}
		cert, err := x509.ParseCertificate(der)
		if err != nil {
			return nil, err
		}
		certs = append(certs, cert)
	}
	return certs, nil
}
454
455func (dst rawHeader) isSet(k HeaderKey) bool {
456	dvr := dst[k]
457	if dvr == nil {
458		return false
459	}
460
461	var dv interface{}
462	err := json.Unmarshal(*dvr, &dv)
463	if err != nil {
464		return true
465	}
466
467	if dvStr, ok := dv.(string); ok {
468		return dvStr != ""
469	}
470
471	return true
472}
473
474// Merge headers from src into dst, giving precedence to headers from l.
475func (dst rawHeader) merge(src *rawHeader) {
476	if src == nil {
477		return
478	}
479
480	for k, v := range *src {
481		if dst.isSet(k) {
482			continue
483		}
484
485		dst[k] = v
486	}
487}
488
// curveName returns the JOSE name of a NIST curve ("P-256", "P-384" or
// "P-521"). Any other curve yields an empty name and an error.
func curveName(crv elliptic.Curve) (string, error) {
	var name string
	switch crv {
	case elliptic.P256():
		name = "P-256"
	case elliptic.P384():
		name = "P-384"
	case elliptic.P521():
		name = "P-521"
	default:
		return "", fmt.Errorf("square/go-jose: unsupported/unknown elliptic curve")
	}
	return name, nil
}
502
// curveSize returns the curve's bit size (crv.Params().BitSize) rounded up to a
// whole number of bytes.
func curveSize(crv elliptic.Curve) int {
	bits := crv.Params().BitSize
	size := bits / 8
	if bits%8 != 0 {
		size++ // a partial trailing byte still needs a full byte
	}
	return size
}
516
// makeRawMessage returns a fresh *json.RawMessage holding b. The bytes are not
// copied, so the message shares b's backing array.
func makeRawMessage(b []byte) *json.RawMessage {
	// RawMessage's underlying type is []byte, so the pointer converts directly.
	return (*json.RawMessage)(&b)
}
521