1/*-
2 * Copyright 2014 Square Inc.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package jose
18
19import (
20	"crypto/elliptic"
21	"crypto/x509"
22	"encoding/base64"
23	"errors"
24	"fmt"
25
26	"gopkg.in/square/go-jose.v2/json"
27)
28
// String-backed identifier types used throughout the package. Each one wraps
// the textual name that appears in a JOSE header.
type (
	// KeyAlgorithm represents a key management algorithm.
	KeyAlgorithm string

	// SignatureAlgorithm represents a signature (or MAC) algorithm.
	SignatureAlgorithm string

	// ContentEncryption represents a content encryption algorithm.
	ContentEncryption string

	// CompressionAlgorithm represents an algorithm used for plaintext
	// compression.
	CompressionAlgorithm string

	// ContentType represents the type of the contained data.
	ContentType string
)
43
// Sentinel errors returned by this package. They are created once with
// errors.New, so callers can compare a returned error against them directly.
var (
	// ErrCryptoFailure represents an error in cryptographic primitive. This
	// occurs when, for example, a message had an invalid authentication tag or
	// could not be decrypted.
	ErrCryptoFailure = errors.New("square/go-jose: error in cryptographic primitive")

	// ErrUnsupportedAlgorithm indicates that a selected algorithm is not
	// supported. This occurs when trying to instantiate an encrypter for an
	// algorithm that is not yet implemented.
	ErrUnsupportedAlgorithm = errors.New("square/go-jose: unknown/unsupported algorithm")

	// ErrUnsupportedKeyType indicates that the given key type/format is not
	// supported. This occurs when trying to instantiate an encrypter and passing
	// it a key of an unrecognized type or with unsupported parameters, such as
	// an RSA private key with more than two primes.
	ErrUnsupportedKeyType = errors.New("square/go-jose: unsupported key type/format")

	// ErrInvalidKeySize indicates that the given key is not the correct size
	// for the selected algorithm. This can occur, for example, when trying to
	// encrypt with AES-256 but passing only a 128-bit key as input.
	ErrInvalidKeySize = errors.New("square/go-jose: invalid key size for algorithm")

	// ErrNotSupported indicates that serialization of the object is not
	// supported. This occurs when trying to compact-serialize an object which
	// can't be represented in compact form.
	ErrNotSupported = errors.New("square/go-jose: compact serialization not supported for object")

	// ErrUnprotectedNonce indicates that while parsing a JWS or JWE object, a
	// nonce header parameter was included in an unprotected header object.
	ErrUnprotectedNonce = errors.New("square/go-jose: Nonce parameter included in unprotected header")
)
75
76// Key management algorithms
77const (
78	ED25519            = KeyAlgorithm("ED25519")
79	RSA1_5             = KeyAlgorithm("RSA1_5")             // RSA-PKCS1v1.5
80	RSA_OAEP           = KeyAlgorithm("RSA-OAEP")           // RSA-OAEP-SHA1
81	RSA_OAEP_256       = KeyAlgorithm("RSA-OAEP-256")       // RSA-OAEP-SHA256
82	A128KW             = KeyAlgorithm("A128KW")             // AES key wrap (128)
83	A192KW             = KeyAlgorithm("A192KW")             // AES key wrap (192)
84	A256KW             = KeyAlgorithm("A256KW")             // AES key wrap (256)
85	DIRECT             = KeyAlgorithm("dir")                // Direct encryption
86	ECDH_ES            = KeyAlgorithm("ECDH-ES")            // ECDH-ES
87	ECDH_ES_A128KW     = KeyAlgorithm("ECDH-ES+A128KW")     // ECDH-ES + AES key wrap (128)
88	ECDH_ES_A192KW     = KeyAlgorithm("ECDH-ES+A192KW")     // ECDH-ES + AES key wrap (192)
89	ECDH_ES_A256KW     = KeyAlgorithm("ECDH-ES+A256KW")     // ECDH-ES + AES key wrap (256)
90	A128GCMKW          = KeyAlgorithm("A128GCMKW")          // AES-GCM key wrap (128)
91	A192GCMKW          = KeyAlgorithm("A192GCMKW")          // AES-GCM key wrap (192)
92	A256GCMKW          = KeyAlgorithm("A256GCMKW")          // AES-GCM key wrap (256)
93	PBES2_HS256_A128KW = KeyAlgorithm("PBES2-HS256+A128KW") // PBES2 + HMAC-SHA256 + AES key wrap (128)
94	PBES2_HS384_A192KW = KeyAlgorithm("PBES2-HS384+A192KW") // PBES2 + HMAC-SHA384 + AES key wrap (192)
95	PBES2_HS512_A256KW = KeyAlgorithm("PBES2-HS512+A256KW") // PBES2 + HMAC-SHA512 + AES key wrap (256)
96)
97
98// Signature algorithms
99const (
100	EdDSA = SignatureAlgorithm("EdDSA")
101	HS256 = SignatureAlgorithm("HS256") // HMAC using SHA-256
102	HS384 = SignatureAlgorithm("HS384") // HMAC using SHA-384
103	HS512 = SignatureAlgorithm("HS512") // HMAC using SHA-512
104	RS256 = SignatureAlgorithm("RS256") // RSASSA-PKCS-v1.5 using SHA-256
105	RS384 = SignatureAlgorithm("RS384") // RSASSA-PKCS-v1.5 using SHA-384
106	RS512 = SignatureAlgorithm("RS512") // RSASSA-PKCS-v1.5 using SHA-512
107	ES256 = SignatureAlgorithm("ES256") // ECDSA using P-256 and SHA-256
108	ES384 = SignatureAlgorithm("ES384") // ECDSA using P-384 and SHA-384
109	ES512 = SignatureAlgorithm("ES512") // ECDSA using P-521 and SHA-512
110	PS256 = SignatureAlgorithm("PS256") // RSASSA-PSS using SHA256 and MGF1-SHA256
111	PS384 = SignatureAlgorithm("PS384") // RSASSA-PSS using SHA384 and MGF1-SHA384
112	PS512 = SignatureAlgorithm("PS512") // RSASSA-PSS using SHA512 and MGF1-SHA512
113)
114
115// Content encryption algorithms
116const (
117	A128CBC_HS256 = ContentEncryption("A128CBC-HS256") // AES-CBC + HMAC-SHA256 (128)
118	A192CBC_HS384 = ContentEncryption("A192CBC-HS384") // AES-CBC + HMAC-SHA384 (192)
119	A256CBC_HS512 = ContentEncryption("A256CBC-HS512") // AES-CBC + HMAC-SHA512 (256)
120	A128GCM       = ContentEncryption("A128GCM")       // AES-GCM (128)
121	A192GCM       = ContentEncryption("A192GCM")       // AES-GCM (192)
122	A256GCM       = ContentEncryption("A256GCM")       // AES-GCM (256)
123)
124
125// Compression algorithms
126const (
127	NONE    = CompressionAlgorithm("")    // No compression
128	DEFLATE = CompressionAlgorithm("DEF") // DEFLATE (RFC 1951)
129)
130
// HeaderKey is a key in the protected header of a JWS object. Use of the
// Header... constants is preferred to enhance type safety.
type HeaderKey string
134
// JOSE header parameter names. The trailing comment on each constant names
// the Go type the corresponding header value is handled as.
//
// NOTE: only HeaderType carries the explicit HeaderKey type. Every other
// constant in this block, including the exported HeaderContentType, is
// declared with its own expression and no type, which makes it an untyped
// string constant. Such constants convert implicitly wherever a HeaderKey is
// expected, but they are also usable as plain strings.
const (
	HeaderType        HeaderKey = "typ" // string
	HeaderContentType           = "cty" // string

	// These are set by go-jose and shouldn't need to be set by consumers of the
	// library.
	headerAlgorithm   = "alg"  // string
	headerEncryption  = "enc"  // ContentEncryption
	headerCompression = "zip"  // CompressionAlgorithm
	headerCritical    = "crit" // []string

	headerAPU = "apu" // *byteBuffer
	headerAPV = "apv" // *byteBuffer
	headerEPK = "epk" // *JSONWebKey
	headerIV  = "iv"  // *byteBuffer
	headerTag = "tag" // *byteBuffer
	headerX5c = "x5c" // []*x509.Certificate

	headerJWK   = "jwk"   // *JSONWebKey
	headerKeyID = "kid"   // string
	headerNonce = "nonce" // string

	headerP2C = "p2c" // *byteBuffer (int)
	headerP2S = "p2s" // *byteBuffer ([]byte)

)
161
// rawHeader represents the JOSE header for JWE/JWS objects (used for parsing).
//
// The decoding of the constituent items is deferred because we want to marshal
// some members into particular structs rather than generic maps, but at the
// same time we need to receive any extra fields unhandled by this library to
// pass through to consuming code in case it wants to examine them.
//
// Each value is a pointer to the still-encoded JSON for that header; the
// accessor methods on this type treat a nil pointer the same as a missing key.
type rawHeader map[HeaderKey]*json.RawMessage
169
// Header represents the read-only JOSE header for JWE/JWS objects.
//
// When produced by rawHeader.sanitized, KeyID is decoded from the "kid"
// header, JSONWebKey from "jwk", Algorithm from "alg" and Nonce from "nonce".
// The "x5c" header is parsed into the unexported certificate chain, which is
// only reachable through the Certificates method.
type Header struct {
	KeyID      string
	JSONWebKey *JSONWebKey
	Algorithm  string
	Nonce      string

	// Unverified certificate chain parsed from x5c header.
	certificates []*x509.Certificate

	// Any headers not recognised above get unmarshaled
	// from JSON in a generic manner and placed in this map.
	ExtraHeaders map[HeaderKey]interface{}
}
184
185// Certificates verifies & returns the certificate chain present
186// in the x5c header field of a message, if one was present. Returns
187// an error if there was no x5c header present or the chain could
188// not be validated with the given verify options.
189func (h Header) Certificates(opts x509.VerifyOptions) ([][]*x509.Certificate, error) {
190	if len(h.certificates) == 0 {
191		return nil, errors.New("square/go-jose: no x5c header present in message")
192	}
193
194	leaf := h.certificates[0]
195	if opts.Intermediates == nil {
196		opts.Intermediates = x509.NewCertPool()
197		for _, intermediate := range h.certificates[1:] {
198			opts.Intermediates.AddCert(intermediate)
199		}
200	}
201
202	return leaf.Verify(opts)
203}
204
205func (parsed rawHeader) set(k HeaderKey, v interface{}) error {
206	b, err := json.Marshal(v)
207	if err != nil {
208		return err
209	}
210
211	parsed[k] = makeRawMessage(b)
212	return nil
213}
214
215// getString gets a string from the raw JSON, defaulting to "".
216func (parsed rawHeader) getString(k HeaderKey) string {
217	v, ok := parsed[k]
218	if !ok || v == nil {
219		return ""
220	}
221	var s string
222	err := json.Unmarshal(*v, &s)
223	if err != nil {
224		return ""
225	}
226	return s
227}
228
229// getByteBuffer gets a byte buffer from the raw JSON. Returns (nil, nil) if
230// not specified.
231func (parsed rawHeader) getByteBuffer(k HeaderKey) (*byteBuffer, error) {
232	v := parsed[k]
233	if v == nil {
234		return nil, nil
235	}
236	var bb *byteBuffer
237	err := json.Unmarshal(*v, &bb)
238	if err != nil {
239		return nil, err
240	}
241	return bb, nil
242}
243
244// getAlgorithm extracts parsed "alg" from the raw JSON as a KeyAlgorithm.
245func (parsed rawHeader) getAlgorithm() KeyAlgorithm {
246	return KeyAlgorithm(parsed.getString(headerAlgorithm))
247}
248
249// getSignatureAlgorithm extracts parsed "alg" from the raw JSON as a SignatureAlgorithm.
250func (parsed rawHeader) getSignatureAlgorithm() SignatureAlgorithm {
251	return SignatureAlgorithm(parsed.getString(headerAlgorithm))
252}
253
254// getEncryption extracts parsed "enc" from the raw JSON.
255func (parsed rawHeader) getEncryption() ContentEncryption {
256	return ContentEncryption(parsed.getString(headerEncryption))
257}
258
259// getCompression extracts parsed "zip" from the raw JSON.
260func (parsed rawHeader) getCompression() CompressionAlgorithm {
261	return CompressionAlgorithm(parsed.getString(headerCompression))
262}
263
// getNonce extracts parsed "nonce" from the raw JSON as a string, via
// getString.
func (parsed rawHeader) getNonce() string {
	return parsed.getString(headerNonce)
}
267
268// getEPK extracts parsed "epk" from the raw JSON.
269func (parsed rawHeader) getEPK() (*JSONWebKey, error) {
270	v := parsed[headerEPK]
271	if v == nil {
272		return nil, nil
273	}
274	var epk *JSONWebKey
275	err := json.Unmarshal(*v, &epk)
276	if err != nil {
277		return nil, err
278	}
279	return epk, nil
280}
281
// getAPU extracts parsed "apu" from the raw JSON as a byte buffer, delegating
// to getByteBuffer.
func (parsed rawHeader) getAPU() (*byteBuffer, error) {
	return parsed.getByteBuffer(headerAPU)
}
286
// getAPV extracts parsed "apv" from the raw JSON as a byte buffer, delegating
// to getByteBuffer.
func (parsed rawHeader) getAPV() (*byteBuffer, error) {
	return parsed.getByteBuffer(headerAPV)
}
291
// getIV extracts parsed "iv" from the raw JSON as a byte buffer, delegating to
// getByteBuffer.
func (parsed rawHeader) getIV() (*byteBuffer, error) {
	return parsed.getByteBuffer(headerIV)
}
296
// getTag extracts parsed "tag" from the raw JSON as a byte buffer, delegating
// to getByteBuffer.
func (parsed rawHeader) getTag() (*byteBuffer, error) {
	return parsed.getByteBuffer(headerTag)
}
301
302// getJWK extracts parsed "jwk" from the raw JSON.
303func (parsed rawHeader) getJWK() (*JSONWebKey, error) {
304	v := parsed[headerJWK]
305	if v == nil {
306		return nil, nil
307	}
308	var jwk *JSONWebKey
309	err := json.Unmarshal(*v, &jwk)
310	if err != nil {
311		return nil, err
312	}
313	return jwk, nil
314}
315
316// getCritical extracts parsed "crit" from the raw JSON. If omitted, it
317// returns an empty slice.
318func (parsed rawHeader) getCritical() ([]string, error) {
319	v := parsed[headerCritical]
320	if v == nil {
321		return nil, nil
322	}
323
324	var q []string
325	err := json.Unmarshal(*v, &q)
326	if err != nil {
327		return nil, err
328	}
329	return q, nil
330}
331
332// getS2C extracts parsed "p2c" from the raw JSON.
333func (parsed rawHeader) getP2C() (int, error) {
334	v := parsed[headerP2C]
335	if v == nil {
336		return 0, nil
337	}
338
339	var p2c int
340	err := json.Unmarshal(*v, &p2c)
341	if err != nil {
342		return 0, err
343	}
344	return p2c, nil
345}
346
// getP2S extracts parsed "p2s" from the raw JSON as a byte buffer, delegating
// to getByteBuffer.
func (parsed rawHeader) getP2S() (*byteBuffer, error) {
	return parsed.getByteBuffer(headerP2S)
}
351
352// sanitized produces a cleaned-up header object from the raw JSON.
353func (parsed rawHeader) sanitized() (h Header, err error) {
354	for k, v := range parsed {
355		if v == nil {
356			continue
357		}
358		switch k {
359		case headerJWK:
360			var jwk *JSONWebKey
361			err = json.Unmarshal(*v, &jwk)
362			if err != nil {
363				err = fmt.Errorf("failed to unmarshal JWK: %v: %#v", err, string(*v))
364				return
365			}
366			h.JSONWebKey = jwk
367		case headerKeyID:
368			var s string
369			err = json.Unmarshal(*v, &s)
370			if err != nil {
371				err = fmt.Errorf("failed to unmarshal key ID: %v: %#v", err, string(*v))
372				return
373			}
374			h.KeyID = s
375		case headerAlgorithm:
376			var s string
377			err = json.Unmarshal(*v, &s)
378			if err != nil {
379				err = fmt.Errorf("failed to unmarshal algorithm: %v: %#v", err, string(*v))
380				return
381			}
382			h.Algorithm = s
383		case headerNonce:
384			var s string
385			err = json.Unmarshal(*v, &s)
386			if err != nil {
387				err = fmt.Errorf("failed to unmarshal nonce: %v: %#v", err, string(*v))
388				return
389			}
390			h.Nonce = s
391		case headerX5c:
392			c := []string{}
393			err = json.Unmarshal(*v, &c)
394			if err != nil {
395				err = fmt.Errorf("failed to unmarshal x5c header: %v: %#v", err, string(*v))
396				return
397			}
398			h.certificates, err = parseCertificateChain(c)
399			if err != nil {
400				err = fmt.Errorf("failed to unmarshal x5c header: %v: %#v", err, string(*v))
401				return
402			}
403		default:
404			if h.ExtraHeaders == nil {
405				h.ExtraHeaders = map[HeaderKey]interface{}{}
406			}
407			var v2 interface{}
408			err = json.Unmarshal(*v, &v2)
409			if err != nil {
410				err = fmt.Errorf("failed to unmarshal value: %v: %#v", err, string(*v))
411				return
412			}
413			h.ExtraHeaders[k] = v2
414		}
415	}
416	return
417}
418
// parseCertificateChain converts a list of standard-base64 encoded DER
// certificates into parsed certificates, preserving order. It returns a nil
// slice together with the first base64 or certificate parsing error it
// encounters. An empty input yields an empty, non-nil slice.
func parseCertificateChain(chain []string) ([]*x509.Certificate, error) {
	certs := make([]*x509.Certificate, 0, len(chain))
	for _, encoded := range chain {
		der, err := base64.StdEncoding.DecodeString(encoded)
		if err != nil {
			return nil, err
		}

		cert, err := x509.ParseCertificate(der)
		if err != nil {
			return nil, err
		}
		certs = append(certs, cert)
	}
	return certs, nil
}
433
434func (dst rawHeader) isSet(k HeaderKey) bool {
435	dvr := dst[k]
436	if dvr == nil {
437		return false
438	}
439
440	var dv interface{}
441	err := json.Unmarshal(*dvr, &dv)
442	if err != nil {
443		return true
444	}
445
446	if dvStr, ok := dv.(string); ok {
447		return dvStr != ""
448	}
449
450	return true
451}
452
453// Merge headers from src into dst, giving precedence to headers from l.
454func (dst rawHeader) merge(src *rawHeader) {
455	if src == nil {
456		return
457	}
458
459	for k, v := range *src {
460		if dst.isSet(k) {
461			continue
462		}
463
464		dst[k] = v
465	}
466}
467
// curveName returns the JOSE name ("P-256", "P-384" or "P-521") of crv. Only
// the standard library's P256, P384 and P521 curve values are recognised; any
// other curve yields an empty name and an error.
func curveName(crv elliptic.Curve) (string, error) {
	if crv == elliptic.P256() {
		return "P-256", nil
	}
	if crv == elliptic.P384() {
		return "P-384", nil
	}
	if crv == elliptic.P521() {
		return "P-521", nil
	}
	return "", fmt.Errorf("square/go-jose: unsupported/unknown elliptic curve")
}
481
// curveSize returns the size of crv in bytes: its bit size divided by eight,
// rounded up when the bit size is not a multiple of eight (for example 521
// bits gives 66 bytes).
func curveSize(crv elliptic.Curve) int {
	bits := crv.Params().BitSize

	size := bits / 8
	if bits%8 != 0 {
		// A partial trailing byte still needs a whole byte of storage.
		size++
	}
	return size
}
495
// makeRawMessage wraps b in a json.RawMessage and returns a pointer to it. The
// bytes are not copied, so the result shares b's backing array.
func makeRawMessage(b []byte) *json.RawMessage {
	rm := new(json.RawMessage)
	*rm = b
	return rm
}
500