1package jwk
2
3import (
4	"crypto"
5	"crypto/ecdsa"
6	"crypto/elliptic"
7	"fmt"
8	"math/big"
9
10	"github.com/lestrrat-go/blackmagic"
11	"github.com/lestrrat-go/jwx/internal/base64"
12	"github.com/lestrrat-go/jwx/internal/ecutil"
13	"github.com/lestrrat-go/jwx/jwa"
14	"github.com/pkg/errors"
15)
16
17func init() {
18	ecutil.RegisterCurve(elliptic.P256(), jwa.P256)
19	ecutil.RegisterCurve(elliptic.P384(), jwa.P384)
20	ecutil.RegisterCurve(elliptic.P521(), jwa.P521)
21}
22
23func (k *ecdsaPublicKey) FromRaw(rawKey *ecdsa.PublicKey) error {
24	k.mu.Lock()
25	defer k.mu.Unlock()
26
27	if rawKey.X == nil {
28		return errors.Errorf(`invalid ecdsa.PublicKey`)
29	}
30
31	if rawKey.Y == nil {
32		return errors.Errorf(`invalid ecdsa.PublicKey`)
33	}
34
35	xbuf := ecutil.AllocECPointBuffer(rawKey.X, rawKey.Curve)
36	ybuf := ecutil.AllocECPointBuffer(rawKey.Y, rawKey.Curve)
37	defer ecutil.ReleaseECPointBuffer(xbuf)
38	defer ecutil.ReleaseECPointBuffer(ybuf)
39
40	k.x = make([]byte, len(xbuf))
41	copy(k.x, xbuf)
42	k.y = make([]byte, len(ybuf))
43	copy(k.y, ybuf)
44
45	var crv jwa.EllipticCurveAlgorithm
46	if tmp, ok := ecutil.AlgorithmForCurve(rawKey.Curve); ok {
47		crv = tmp
48	} else {
49		return errors.Errorf(`invalid elliptic curve %s`, rawKey.Curve)
50	}
51	k.crv = &crv
52
53	return nil
54}
55
56func (k *ecdsaPrivateKey) FromRaw(rawKey *ecdsa.PrivateKey) error {
57	k.mu.Lock()
58	defer k.mu.Unlock()
59
60	if rawKey.PublicKey.X == nil {
61		return errors.Errorf(`invalid ecdsa.PrivateKey`)
62	}
63	if rawKey.PublicKey.Y == nil {
64		return errors.Errorf(`invalid ecdsa.PrivateKey`)
65	}
66	if rawKey.D == nil {
67		return errors.Errorf(`invalid ecdsa.PrivateKey`)
68	}
69
70	xbuf := ecutil.AllocECPointBuffer(rawKey.PublicKey.X, rawKey.Curve)
71	ybuf := ecutil.AllocECPointBuffer(rawKey.PublicKey.Y, rawKey.Curve)
72	dbuf := ecutil.AllocECPointBuffer(rawKey.D, rawKey.Curve)
73	defer ecutil.ReleaseECPointBuffer(xbuf)
74	defer ecutil.ReleaseECPointBuffer(ybuf)
75	defer ecutil.ReleaseECPointBuffer(dbuf)
76
77	k.x = make([]byte, len(xbuf))
78	copy(k.x, xbuf)
79	k.y = make([]byte, len(ybuf))
80	copy(k.y, ybuf)
81	k.d = make([]byte, len(dbuf))
82	copy(k.d, dbuf)
83
84	var crv jwa.EllipticCurveAlgorithm
85	if tmp, ok := ecutil.AlgorithmForCurve(rawKey.Curve); ok {
86		crv = tmp
87	} else {
88		return errors.Errorf(`invalid elliptic curve %s`, rawKey.Curve)
89	}
90	k.crv = &crv
91
92	return nil
93}
94
95func buildECDSAPublicKey(alg jwa.EllipticCurveAlgorithm, xbuf, ybuf []byte) (*ecdsa.PublicKey, error) {
96	var crv elliptic.Curve
97	if tmp, ok := ecutil.CurveForAlgorithm(alg); ok {
98		crv = tmp
99	} else {
100		return nil, errors.Errorf(`invalid curve algorithm %s`, alg)
101	}
102
103	var x, y big.Int
104	x.SetBytes(xbuf)
105	y.SetBytes(ybuf)
106
107	return &ecdsa.PublicKey{Curve: crv, X: &x, Y: &y}, nil
108}
109
110// Raw returns the EC-DSA public key represented by this JWK
111func (k *ecdsaPublicKey) Raw(v interface{}) error {
112	k.mu.RLock()
113	defer k.mu.RUnlock()
114
115	pubk, err := buildECDSAPublicKey(k.Crv(), k.x, k.y)
116	if err != nil {
117		return errors.Wrap(err, `failed to build public key`)
118	}
119
120	return blackmagic.AssignIfCompatible(v, pubk)
121}
122
123func (k *ecdsaPrivateKey) Raw(v interface{}) error {
124	k.mu.RLock()
125	defer k.mu.RUnlock()
126
127	pubk, err := buildECDSAPublicKey(k.Crv(), k.x, k.y)
128	if err != nil {
129		return errors.Wrap(err, `failed to build public key`)
130	}
131
132	var key ecdsa.PrivateKey
133	var d big.Int
134	d.SetBytes(k.d)
135	key.D = &d
136	key.PublicKey = *pubk
137
138	return blackmagic.AssignIfCompatible(v, &key)
139}
140
141func makeECDSAPublicKey(v interface {
142	makePairs() []*HeaderPair
143}) (Key, error) {
144	newKey := NewECDSAPublicKey()
145
146	// Iterate and copy everything except for the bits that should not be in the public key
147	for _, pair := range v.makePairs() {
148		switch pair.Key {
149		case ECDSADKey:
150			continue
151		default:
152			if err := newKey.Set(pair.Key.(string), pair.Value); err != nil {
153				return nil, errors.Wrapf(err, `failed to set field %s`, pair.Key)
154			}
155		}
156	}
157
158	return newKey, nil
159}
160
161func (k *ecdsaPrivateKey) PublicKey() (Key, error) {
162	return makeECDSAPublicKey(k)
163}
164
165func (k *ecdsaPublicKey) PublicKey() (Key, error) {
166	return makeECDSAPublicKey(k)
167}
168
// ecdsaThumbprint hashes the RFC 7638 canonical JSON form of an EC JWK,
// {"crv":...,"kty":"EC","x":...,"y":...}, with the given hash function.
//
// The members appear in lexicographic order with no whitespace, as the RFC
// requires. crv, x and y are written verbatim: the caller supplies x and y
// already base64url-encoded. hash must be linked into the binary;
// hash.New panics otherwise.
func ecdsaThumbprint(hash crypto.Hash, crv, x, y string) []byte {
	digest := hash.New()
	fmt.Fprintf(digest, `{"crv":"%s","kty":"EC","x":"%s","y":"%s"}`, crv, x, y)
	return digest.Sum(nil)
}
180
181// Thumbprint returns the JWK thumbprint using the indicated
182// hashing algorithm, according to RFC 7638
183func (k ecdsaPublicKey) Thumbprint(hash crypto.Hash) ([]byte, error) {
184	k.mu.RLock()
185	defer k.mu.RUnlock()
186
187	var key ecdsa.PublicKey
188	if err := k.Raw(&key); err != nil {
189		return nil, errors.Wrap(err, `failed to materialize ecdsa.PublicKey for thumbprint generation`)
190	}
191
192	xbuf := ecutil.AllocECPointBuffer(key.X, key.Curve)
193	ybuf := ecutil.AllocECPointBuffer(key.Y, key.Curve)
194	defer ecutil.ReleaseECPointBuffer(xbuf)
195	defer ecutil.ReleaseECPointBuffer(ybuf)
196
197	return ecdsaThumbprint(
198		hash,
199		key.Curve.Params().Name,
200		base64.EncodeToString(xbuf),
201		base64.EncodeToString(ybuf),
202	), nil
203}
204
205// Thumbprint returns the JWK thumbprint using the indicated
206// hashing algorithm, according to RFC 7638
207func (k ecdsaPrivateKey) Thumbprint(hash crypto.Hash) ([]byte, error) {
208	k.mu.RLock()
209	defer k.mu.RUnlock()
210
211	var key ecdsa.PrivateKey
212	if err := k.Raw(&key); err != nil {
213		return nil, errors.Wrap(err, `failed to materialize ecdsa.PrivateKey for thumbprint generation`)
214	}
215
216	xbuf := ecutil.AllocECPointBuffer(key.X, key.Curve)
217	ybuf := ecutil.AllocECPointBuffer(key.Y, key.Curve)
218	defer ecutil.ReleaseECPointBuffer(xbuf)
219	defer ecutil.ReleaseECPointBuffer(ybuf)
220
221	return ecdsaThumbprint(
222		hash,
223		key.Curve.Params().Name,
224		base64.EncodeToString(xbuf),
225		base64.EncodeToString(ybuf),
226	), nil
227}
228