1//go:generate ./gen.sh
2
3// Package jwk implements JWK as described in https://tools.ietf.org/html/rfc7517
4package jwk
5
6import (
7	"bytes"
8	"context"
9	"crypto"
10	"crypto/ecdsa"
11	"crypto/ed25519"
12	"crypto/rsa"
13	"crypto/x509"
14	"encoding/pem"
15	"io"
16	"io/ioutil"
17	"math/big"
18	"net/http"
19
20	"github.com/lestrrat-go/backoff/v2"
21	"github.com/lestrrat-go/jwx/internal/base64"
22	"github.com/lestrrat-go/jwx/internal/json"
23	"github.com/lestrrat-go/jwx/jwa"
24	"github.com/lestrrat-go/jwx/x25519"
25	"github.com/pkg/errors"
26)
27
// registry is the package-wide registry that RegisterCustomField adds
// custom field definitions to.
var registry = json.NewRegistry()
29
// bigIntToBytes returns the result of n.Bytes(). A nil n is rejected
// with an error instead of being dereferenced.
func bigIntToBytes(n *big.Int) ([]byte, error) {
	if n != nil {
		return n.Bytes(), nil
	}
	return nil, errors.New(`invalid *big.Int value`)
}
36
37// New creates a jwk.Key from the given key (RSA/ECDSA/symmetric keys).
38//
39// The constructor auto-detects the type of key to be instantiated
40// based on the input type:
41//
42//   * "crypto/rsa".PrivateKey and "crypto/rsa".PublicKey creates an RSA based key
43//   * "crypto/ecdsa".PrivateKey and "crypto/ecdsa".PublicKey creates an EC based key
44//   * "crypto/ed25519".PrivateKey and "crypto/ed25519".PublicKey creates an OKP based key
45//   * []byte creates a symmetric key
46func New(key interface{}) (Key, error) {
47	if key == nil {
48		return nil, errors.New(`jwk.New requires a non-nil key`)
49	}
50
51	var ptr interface{}
52	switch v := key.(type) {
53	case rsa.PrivateKey:
54		ptr = &v
55	case rsa.PublicKey:
56		ptr = &v
57	case ecdsa.PrivateKey:
58		ptr = &v
59	case ecdsa.PublicKey:
60		ptr = &v
61	default:
62		ptr = v
63	}
64
65	switch rawKey := ptr.(type) {
66	case *rsa.PrivateKey:
67		k := NewRSAPrivateKey()
68		if err := k.FromRaw(rawKey); err != nil {
69			return nil, errors.Wrapf(err, `failed to initialize %T from %T`, k, rawKey)
70		}
71		return k, nil
72	case *rsa.PublicKey:
73		k := NewRSAPublicKey()
74		if err := k.FromRaw(rawKey); err != nil {
75			return nil, errors.Wrapf(err, `failed to initialize %T from %T`, k, rawKey)
76		}
77		return k, nil
78	case *ecdsa.PrivateKey:
79		k := NewECDSAPrivateKey()
80		if err := k.FromRaw(rawKey); err != nil {
81			return nil, errors.Wrapf(err, `failed to initialize %T from %T`, k, rawKey)
82		}
83		return k, nil
84	case *ecdsa.PublicKey:
85		k := NewECDSAPublicKey()
86		if err := k.FromRaw(rawKey); err != nil {
87			return nil, errors.Wrapf(err, `failed to initialize %T from %T`, k, rawKey)
88		}
89		return k, nil
90	case ed25519.PrivateKey:
91		k := NewOKPPrivateKey()
92		if err := k.FromRaw(rawKey); err != nil {
93			return nil, errors.Wrapf(err, `failed to initialize %T from %T`, k, rawKey)
94		}
95		return k, nil
96	case ed25519.PublicKey:
97		k := NewOKPPublicKey()
98		if err := k.FromRaw(rawKey); err != nil {
99			return nil, errors.Wrapf(err, `failed to initialize %T from %T`, k, rawKey)
100		}
101		return k, nil
102	case x25519.PrivateKey:
103		k := NewOKPPrivateKey()
104		if err := k.FromRaw(rawKey); err != nil {
105			return nil, errors.Wrapf(err, `failed to initialize %T from %T`, k, rawKey)
106		}
107		return k, nil
108	case x25519.PublicKey:
109		k := NewOKPPublicKey()
110		if err := k.FromRaw(rawKey); err != nil {
111			return nil, errors.Wrapf(err, `failed to initialize %T from %T`, k, rawKey)
112		}
113		return k, nil
114	case []byte:
115		k := NewSymmetricKey()
116		if err := k.FromRaw(rawKey); err != nil {
117			return nil, errors.Wrapf(err, `failed to initialize %T from %T`, k, rawKey)
118		}
119		return k, nil
120	default:
121		return nil, errors.Errorf(`invalid key type '%T' for jwk.New`, key)
122	}
123}
124
125// PublicSetOf returns a new jwk.Set consisting of
126// public keys of the keys contained in the set.
127//
128// This is useful when you are generating a set of private keys, and
129// you want to generate the corresponding public versions for the
130// users to verify with.
131//
132// Be aware that all fields will be copied onto the new public key. It is the caller's
133// responsibility to remove any fields, if necessary.
134func PublicSetOf(v Set) (Set, error) {
135	newSet := NewSet()
136
137	n := v.Len()
138	for i := 0; i < n; i++ {
139		k, ok := v.Get(i)
140		if !ok {
141			return nil, errors.New("key not found")
142		}
143		pubKey, err := PublicKeyOf(k)
144		if err != nil {
145			return nil, errors.Wrapf(err, `failed to get public key of %T`, k)
146		}
147		newSet.Add(pubKey)
148	}
149
150	return newSet, nil
151}
152
153// PublicKeyOf returns the corresponding public version of the jwk.Key.
154// If `v` is a SymmetricKey, then the same value is returned.
155// If `v` is already a public key, the key itself is returned.
156//
157// If `v` is a private key type that has a `PublicKey()` method, be aware
158// that all fields will be copied onto the new public key. It is the caller's
159// responsibility to remove any fields, if necessary
160//
161// If `v` is a raw key, the key is first converted to a `jwk.Key`
162func PublicKeyOf(v interface{}) (Key, error) {
163	if pk, ok := v.(PublicKeyer); ok {
164		return pk.PublicKey()
165	}
166
167	jk, err := New(v)
168	if err != nil {
169		return nil, errors.Wrapf(err, `failed to convert key into JWK`)
170	}
171
172	return jk.PublicKey()
173}
174
175// PublicRawKeyOf returns the corresponding public key of the given
176// value `v` (e.g. given *rsa.PrivateKey, *rsa.PublicKey is returned)
177// If `v` is already a public key, the key itself is returned.
178//
179// The returned value will always be a pointer to the public key,
180// except when a []byte (e.g. symmetric key, ed25519 key) is passed to `v`.
181// In this case, the same []byte value is returned.
182func PublicRawKeyOf(v interface{}) (interface{}, error) {
183	if pk, ok := v.(PublicKeyer); ok {
184		pubk, err := pk.PublicKey()
185		if err != nil {
186			return nil, errors.Wrapf(err, `failed to obtain public key from %T`, v)
187		}
188
189		var raw interface{}
190		if err := pubk.Raw(&raw); err != nil {
191			return nil, errors.Wrapf(err, `failed to obtain raw key from %T`, pubk)
192		}
193		return raw, nil
194	}
195
196	// This may be a silly idea, but if the user gave us a non-pointer value...
197	var ptr interface{}
198	switch v := v.(type) {
199	case rsa.PrivateKey:
200		ptr = &v
201	case rsa.PublicKey:
202		ptr = &v
203	case ecdsa.PrivateKey:
204		ptr = &v
205	case ecdsa.PublicKey:
206		ptr = &v
207	default:
208		ptr = v
209	}
210
211	switch x := ptr.(type) {
212	case *rsa.PrivateKey:
213		return &x.PublicKey, nil
214	case *rsa.PublicKey:
215		return x, nil
216	case *ecdsa.PrivateKey:
217		return &x.PublicKey, nil
218	case *ecdsa.PublicKey:
219		return x, nil
220	case ed25519.PrivateKey:
221		return x.Public(), nil
222	case ed25519.PublicKey:
223		return x, nil
224	case x25519.PrivateKey:
225		return x.Public(), nil
226	case x25519.PublicKey:
227		return x, nil
228	case []byte:
229		return x, nil
230	default:
231		return nil, errors.Errorf(`invalid key type passed to PublicKeyOf (%T)`, v)
232	}
233}
234
235// Fetch fetches a JWK resource specified by a URL. The url must be
236// pointing to a resource that is supported by `net/http`.
237//
238// If you are using the same `jwk.Set` for long periods of time during
239// the lifecycle of your program, and would like to periodically refresh the
240// contents of the object with the data at the remote resource,
241// consider using `jwk.AutoRefresh`, which automatically refreshes
242// jwk.Set objects asynchronously.
243func Fetch(ctx context.Context, urlstring string, options ...FetchOption) (Set, error) {
244	res, err := fetch(ctx, urlstring, options...)
245	if err != nil {
246		return nil, err
247	}
248
249	defer res.Body.Close()
250	keyset, err := ParseReader(res.Body)
251	if err != nil {
252		return nil, errors.Wrap(err, `failed to parse JWK set`)
253	}
254	return keyset, nil
255}
256
257func fetch(ctx context.Context, urlstring string, options ...FetchOption) (*http.Response, error) {
258	var httpcl HTTPClient = http.DefaultClient
259	bo := backoff.Null()
260	for _, option := range options {
261		//nolint:forcetypeassert
262		switch option.Ident() {
263		case identHTTPClient{}:
264			httpcl = option.Value().(HTTPClient)
265		case identFetchBackoff{}:
266			bo = option.Value().(backoff.Policy)
267		}
268	}
269
270	req, err := http.NewRequestWithContext(ctx, http.MethodGet, urlstring, nil)
271	if err != nil {
272		return nil, errors.Wrap(err, "failed to new request to remote JWK")
273	}
274
275	b := bo.Start(ctx)
276	var lastError error
277	for backoff.Continue(b) {
278		res, err := httpcl.Do(req)
279		if err != nil {
280			lastError = errors.Wrap(err, "failed to fetch remote JWK")
281			continue
282		}
283
284		if res.StatusCode != http.StatusOK {
285			lastError = errors.Errorf("failed to fetch remote JWK (status = %d)", res.StatusCode)
286			continue
287		}
288		return res, nil
289	}
290
291	// It's possible for us to get here without populating lastError.
292	// e.g. what if we bailed out of `for backoff.Contineu(b)` without making
293	// a single request? or, <-ctx.Done() returned?
294	if lastError == nil {
295		lastError = errors.New(`fetching remote JWK did not complete`)
296	}
297	return nil, lastError
298}
299
300// ParseRawKey is a combination of ParseKey and Raw. It parses a single JWK key,
301// and assigns the "raw" key to the given parameter. The key must either be
302// a pointer to an empty interface, or a pointer to the actual raw key type
303// such as *rsa.PrivateKey, *ecdsa.PublicKey, *[]byte, etc.
304func ParseRawKey(data []byte, rawkey interface{}) error {
305	key, err := ParseKey(data)
306	if err != nil {
307		return errors.Wrap(err, `failed to parse key`)
308	}
309
310	if err := key.Raw(rawkey); err != nil {
311		return errors.Wrap(err, `failed to assign to raw key variable`)
312	}
313
314	return nil
315}
316
317// parsePEMEncodedRawKey parses a key in PEM encoded ASN.1 DER format. It tires its
318// best to determine the key type, but when it just can't, it will return
319// an error
320func parsePEMEncodedRawKey(src []byte) (interface{}, []byte, error) {
321	block, rest := pem.Decode(src)
322	if block == nil {
323		return nil, nil, errors.New(`failed to decode PEM data`)
324	}
325
326	switch block.Type {
327	// Handle the semi-obvious cases
328	case "RSA PRIVATE KEY":
329		key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
330		if err != nil {
331			return nil, nil, errors.Wrap(err, `failed to parse PKCS1 private key`)
332		}
333		return key, rest, nil
334	case "RSA PUBLIC KEY":
335		key, err := x509.ParsePKCS1PublicKey(block.Bytes)
336		if err != nil {
337			return nil, nil, errors.Wrap(err, `failed to parse PKCS1 public key`)
338		}
339		return key, rest, nil
340	case "EC PRIVATE KEY":
341		key, err := x509.ParseECPrivateKey(block.Bytes)
342		if err != nil {
343			return nil, nil, errors.Wrap(err, `failed to parse EC private key`)
344		}
345		return key, rest, nil
346	case "PUBLIC KEY":
347		// XXX *could* return dsa.PublicKey
348		key, err := x509.ParsePKIXPublicKey(block.Bytes)
349		if err != nil {
350			return nil, nil, errors.Wrap(err, `failed to parse PKIX public key`)
351		}
352		return key, rest, nil
353	case "PRIVATE KEY":
354		key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
355		if err != nil {
356			return nil, nil, errors.Wrap(err, `failed to parse PKCS8 private key`)
357		}
358		return key, rest, nil
359	case "CERTIFICATE":
360		cert, err := x509.ParseCertificate(block.Bytes)
361		if err != nil {
362			return nil, nil, errors.Wrap(err, `failed to parse certificate`)
363		}
364		return cert.PublicKey, rest, nil
365	default:
366		return nil, nil, errors.Errorf(`invalid PEM block type %s`, block.Type)
367	}
368}
369
370// ParseKey parses a single key JWK. Unlike `jwk.Parse` this method will
371// report failure if you attempt to pass a JWK set. Only use this function
372// when you know that the data is a single JWK.
373//
374// Given a WithPEM(true) option, this function assumes that the given input
375// is PEM encoded ASN.1 DER format key.
376//
377// Note that a successful parsing of any type of key does NOT necessarily
378// guarantee a valid key. For example, no checks against expiration dates
379// are performed for certificate expiration, no checks against missing
380// parameters are performed, etc.
381func ParseKey(data []byte, options ...ParseOption) (Key, error) {
382	var parsePEM bool
383	var localReg *json.Registry
384	for _, option := range options {
385		//nolint:forcetypeassert
386		switch option.Ident() {
387		case identPEM{}:
388			parsePEM = option.Value().(bool)
389		case identLocalRegistry{}:
390			// in reality you can only pass either withLocalRegistry or
391			// WithTypedField, but since withLocalRegistry is used only by us,
392			// we skip checking
393			localReg = option.Value().(*json.Registry)
394		case identTypedField{}:
395			pair := option.Value().(typedFieldPair)
396			if localReg == nil {
397				localReg = json.NewRegistry()
398			}
399			localReg.Register(pair.Name, pair.Value)
400		}
401	}
402
403	if parsePEM {
404		raw, _, err := parsePEMEncodedRawKey(data)
405		if err != nil {
406			return nil, errors.Wrap(err, `failed to parse PEM encoded key`)
407		}
408		return New(raw)
409	}
410
411	var hint struct {
412		Kty string          `json:"kty"`
413		D   json.RawMessage `json:"d"`
414	}
415
416	if err := json.Unmarshal(data, &hint); err != nil {
417		return nil, errors.Wrap(err, `failed to unmarshal JSON into key hint`)
418	}
419
420	var key Key
421	switch jwa.KeyType(hint.Kty) {
422	case jwa.RSA:
423		if len(hint.D) > 0 {
424			key = newRSAPrivateKey()
425		} else {
426			key = newRSAPublicKey()
427		}
428	case jwa.EC:
429		if len(hint.D) > 0 {
430			key = newECDSAPrivateKey()
431		} else {
432			key = newECDSAPublicKey()
433		}
434	case jwa.OctetSeq:
435		key = newSymmetricKey()
436	case jwa.OKP:
437		if len(hint.D) > 0 {
438			key = newOKPPrivateKey()
439		} else {
440			key = newOKPPublicKey()
441		}
442	default:
443		return nil, errors.Errorf(`invalid key type from JSON (%s)`, hint.Kty)
444	}
445
446	if localReg != nil {
447		dcKey, ok := key.(KeyWithDecodeCtx)
448		if !ok {
449			return nil, errors.Errorf(`typed field was requested, but the key (%T) does not support DecodeCtx`, key)
450		}
451		dc := json.NewDecodeCtx(localReg)
452		dcKey.SetDecodeCtx(dc)
453		defer func() { dcKey.SetDecodeCtx(nil) }()
454	}
455
456	if err := json.Unmarshal(data, key); err != nil {
457		return nil, errors.Wrapf(err, `failed to unmarshal JSON into key (%T)`, key)
458	}
459
460	return key, nil
461}
462
463// Parse parses JWK from the incoming []byte.
464//
465// For JWK sets, this is a convenience function. You could just as well
466// call `json.Unmarshal` against an empty set created by `jwk.NewSet()`
467// to parse a JSON buffer into a `jwk.Set`.
468//
469// This method exists because many times the user does not know before hand
470// if a JWK(s) resource at a remote location contains a single JWK key or
471// a JWK set, and `jwk.Parse()` can handle either case, returning a JWK Set
472// even if the data only contains a single JWK key
473//
474// If you are looking for more information on how JWKs are parsed, or if
475// you know for sure that you have a single key, please see the documentation
476// for `jwk.ParseKey()`.
477func Parse(src []byte, options ...ParseOption) (Set, error) {
478	var parsePEM bool
479	var localReg *json.Registry
480	for _, option := range options {
481		//nolint:forcetypeassert
482		switch option.Ident() {
483		case identPEM{}:
484			parsePEM = option.Value().(bool)
485		case identTypedField{}:
486			pair := option.Value().(typedFieldPair)
487			if localReg == nil {
488				localReg = json.NewRegistry()
489			}
490			localReg.Register(pair.Name, pair.Value)
491		}
492	}
493
494	s := NewSet()
495
496	if parsePEM {
497		src = bytes.TrimSpace(src)
498		for len(src) > 0 {
499			raw, rest, err := parsePEMEncodedRawKey(src)
500			if err != nil {
501				return nil, errors.Wrap(err, `failed to parse PEM encoded key`)
502			}
503			key, err := New(raw)
504			if err != nil {
505				return nil, errors.Wrapf(err, `failed to create jwk.Key from %T`, raw)
506			}
507			s.Add(key)
508			src = bytes.TrimSpace(rest)
509		}
510		return s, nil
511	}
512
513	if localReg != nil {
514		dcKs, ok := s.(KeyWithDecodeCtx)
515		if !ok {
516			return nil, errors.Errorf(`typed field was requested, but the key set (%T) does not support DecodeCtx`, s)
517		}
518		dc := json.NewDecodeCtx(localReg)
519		dcKs.SetDecodeCtx(dc)
520		defer func() { dcKs.SetDecodeCtx(nil) }()
521	}
522
523	if err := json.Unmarshal(src, s); err != nil {
524		return nil, errors.Wrap(err, "failed to unmarshal JWK set")
525	}
526	return s, nil
527}
528
529// ParseReader parses a JWK set from the incoming byte buffer.
530func ParseReader(src io.Reader, options ...ParseOption) (Set, error) {
531	// meh, there's no way to tell if a stream has "ended" a single
532	// JWKs except when we encounter an EOF, so just... ReadAll
533	buf, err := ioutil.ReadAll(src)
534	if err != nil {
535		return nil, errors.Wrap(err, `failed to read from io.Reader`)
536	}
537
538	return Parse(buf, options...)
539}
540
541// ParseString parses a JWK set from the incoming string.
542func ParseString(s string, options ...ParseOption) (Set, error) {
543	return Parse([]byte(s), options...)
544}
545
546// AssignKeyID is a convenience function to automatically assign the "kid"
547// section of the key, if it already doesn't have one. It uses Key.Thumbprint
548// method with crypto.SHA256 as the default hashing algorithm
549func AssignKeyID(key Key, options ...Option) error {
550	if _, ok := key.Get(KeyIDKey); ok {
551		return nil
552	}
553
554	hash := crypto.SHA256
555	for _, option := range options {
556		//nolint:forcetypeassert
557		switch option.Ident() {
558		case identThumbprintHash{}:
559			hash = option.Value().(crypto.Hash)
560		}
561	}
562
563	h, err := key.Thumbprint(hash)
564	if err != nil {
565		return errors.Wrap(err, `failed to generate thumbprint`)
566	}
567
568	if err := key.Set(KeyIDKey, base64.EncodeToString(h)); err != nil {
569		return errors.Wrap(err, `failed to set "kid"`)
570	}
571
572	return nil
573}
574
575func cloneKey(src Key) (Key, error) {
576	var dst Key
577	switch src.(type) {
578	case RSAPrivateKey:
579		dst = NewRSAPrivateKey()
580	case RSAPublicKey:
581		dst = NewRSAPublicKey()
582	case ECDSAPrivateKey:
583		dst = NewECDSAPrivateKey()
584	case ECDSAPublicKey:
585		dst = NewECDSAPublicKey()
586	case OKPPrivateKey:
587		dst = NewOKPPrivateKey()
588	case OKPPublicKey:
589		dst = NewOKPPublicKey()
590	case SymmetricKey:
591		dst = NewSymmetricKey()
592	default:
593		return nil, errors.Errorf(`unknown key type %T`, src)
594	}
595
596	for _, pair := range src.makePairs() {
597		if err := dst.Set(pair.Key.(string), pair.Value); err != nil {
598			return nil, errors.Wrapf(err, `failed to set %s`, pair.Key.(string))
599		}
600	}
601	return dst, nil
602}
603
604// Pem serializes the given jwk.Key in PEM encoded ASN.1 DER format,
605// using either PKCS8 for private keys and PKIX for public keys.
606// If you need to encode using PKCS1 or SEC1, you must do it yourself.
607//
608// Argument must be of type jwk.Key or jwk.Set
609//
610// Currently only EC (including Ed25519) and RSA keys (and jwk.Set
611// comprised of these key types) are supported.
612func Pem(v interface{}) ([]byte, error) {
613	var set Set
614	switch v := v.(type) {
615	case Key:
616		set = NewSet()
617		set.Add(v)
618	case Set:
619		set = v
620	default:
621		return nil, errors.Errorf(`argument to Pem must be either jwk.Key or jwk.Set: %T`, v)
622	}
623
624	var ret []byte
625	for i := 0; i < set.Len(); i++ {
626		key, _ := set.Get(i)
627		typ, buf, err := asnEncode(key)
628		if err != nil {
629			return nil, errors.Wrapf(err, `failed to encode content for key #%d`, i)
630		}
631
632		var block pem.Block
633		block.Type = typ
634		block.Bytes = buf
635		ret = append(ret, pem.EncodeToMemory(&block)...)
636	}
637	return ret, nil
638}
639
640func asnEncode(key Key) (string, []byte, error) {
641	switch key := key.(type) {
642	case RSAPrivateKey, ECDSAPrivateKey, OKPPrivateKey:
643		var rawkey interface{}
644		if err := key.Raw(&rawkey); err != nil {
645			return "", nil, errors.Wrap(err, `failed to get raw key from jwk.Key`)
646		}
647		buf, err := x509.MarshalPKCS8PrivateKey(rawkey)
648		if err != nil {
649			return "", nil, errors.Wrap(err, `failed to marshal PKCS8`)
650		}
651		return "PRIVATE KEY", buf, nil
652	case RSAPublicKey, ECDSAPublicKey, OKPPublicKey:
653		var rawkey interface{}
654		if err := key.Raw(&rawkey); err != nil {
655			return "", nil, errors.Wrap(err, `failed to get raw key from jwk.Key`)
656		}
657		buf, err := x509.MarshalPKIXPublicKey(rawkey)
658		if err != nil {
659			return "", nil, errors.Wrap(err, `failed to marshal PKIX`)
660		}
661		return "PUBLIC KEY", buf, nil
662	default:
663		return "", nil, errors.Errorf(`unsupported key type %T`, key)
664	}
665}
666
667// RegisterCustomField allows users to specify that a private field
668// be decoded as an instance of the specified type. This option has
669// a global effect.
670//
671// For example, suppose you have a custom field `x-birthday`, which
672// you want to represent as a string formatted in RFC3339 in JSON,
673// but want it back as `time.Time`.
674//
675// In that case you would register a custom field as follows
676//
677//   jwk.RegisterCustomField(`x-birthday`, timeT)
678//
679// Then `key.Get("x-birthday")` will still return an `interface{}`,
680// but you can convert its type to `time.Time`
681//
682//   bdayif, _ := key.Get(`x-birthday`)
683//   bday := bdayif.(time.Time)
684//
685func RegisterCustomField(name string, object interface{}) {
686	registry.Register(name, object)
687}
688