1package jwxtest
2
3import (
4	"bytes"
5	"context"
6	"crypto/ecdsa"
7	"crypto/ed25519"
8	"crypto/elliptic"
9	"crypto/rand"
10	"crypto/rsa"
11	"encoding/json"
12	"io"
13	"io/ioutil"
14	"os"
15	"strings"
16	"testing"
17
18	"github.com/lestrrat-go/jwx/internal/ecutil"
19	"github.com/lestrrat-go/jwx/jwa"
20	"github.com/lestrrat-go/jwx/jwe"
21	"github.com/lestrrat-go/jwx/jwk"
22	"github.com/lestrrat-go/jwx/jws"
23	"github.com/lestrrat-go/jwx/x25519"
24	"github.com/pkg/errors"
25	"github.com/stretchr/testify/assert"
26)
27
// GenerateRsaKey returns a freshly generated 2048-bit RSA private key, using
// crypto/rand as the source of randomness.
func GenerateRsaKey() (*rsa.PrivateKey, error) {
	const keyBits = 2048
	return rsa.GenerateKey(rand.Reader, keyBits)
}
31
32func GenerateRsaJwk() (jwk.Key, error) {
33	key, err := GenerateRsaKey()
34	if err != nil {
35		return nil, errors.Wrap(err, `failed to generate RSA private key`)
36	}
37
38	k, err := jwk.New(key)
39	if err != nil {
40		return nil, errors.Wrap(err, `failed to generate jwk.RSAPrivateKey`)
41	}
42
43	return k, nil
44}
45
46func GenerateRsaPublicJwk() (jwk.Key, error) {
47	key, err := GenerateRsaJwk()
48	if err != nil {
49		return nil, errors.Wrap(err, `failed to generate jwk.RSAPrivateKey`)
50	}
51
52	return jwk.PublicKeyOf(key)
53}
54
55func GenerateEcdsaKey(alg jwa.EllipticCurveAlgorithm) (*ecdsa.PrivateKey, error) {
56	var crv elliptic.Curve
57	if tmp, ok := ecutil.CurveForAlgorithm(alg); ok {
58		crv = tmp
59	} else {
60		return nil, errors.Errorf(`invalid curve algorithm %s`, alg)
61	}
62
63	return ecdsa.GenerateKey(crv, rand.Reader)
64}
65
66func GenerateEcdsaJwk() (jwk.Key, error) {
67	key, err := GenerateEcdsaKey(jwa.P521)
68	if err != nil {
69		return nil, errors.Wrap(err, `failed to generate ECDSA private key`)
70	}
71
72	k, err := jwk.New(key)
73	if err != nil {
74		return nil, errors.Wrap(err, `failed to generate jwk.ECDSAPrivateKey`)
75	}
76
77	return k, nil
78}
79
80func GenerateEcdsaPublicJwk() (jwk.Key, error) {
81	key, err := GenerateEcdsaJwk()
82	if err != nil {
83		return nil, errors.Wrap(err, `failed to generate jwk.ECDSAPrivateKey`)
84	}
85
86	return jwk.PublicKeyOf(key)
87}
88
// GenerateSymmetricKey returns 64 bytes read from crypto/rand, for use as a
// symmetric key in tests.
//
// It panics if the system random source fails. Ignoring that error would
// hand back a predictable (all-zero) key and silently invalidate every test
// that relies on it; the signature has no error return to report it instead.
func GenerateSymmetricKey() []byte {
	sharedKey := make([]byte, 64)
	if _, err := rand.Read(sharedKey); err != nil {
		panic("jwxtest: failed to read random bytes for symmetric key: " + err.Error())
	}
	return sharedKey
}
95
96func GenerateSymmetricJwk() (jwk.Key, error) {
97	key, err := jwk.New(GenerateSymmetricKey())
98	if err != nil {
99		return nil, errors.Wrap(err, `failed to generate jwk.SymmetricKey`)
100	}
101
102	return key, nil
103}
104
// GenerateEd25519Key returns a new Ed25519 private key generated from
// crypto/rand. The public key returned alongside it by ed25519.GenerateKey is
// discarded.
func GenerateEd25519Key() (ed25519.PrivateKey, error) {
	var (
		privateKey ed25519.PrivateKey
		err        error
	)
	_, privateKey, err = ed25519.GenerateKey(rand.Reader)
	return privateKey, err
}
109
110func GenerateEd25519Jwk() (jwk.Key, error) {
111	key, err := GenerateEd25519Key()
112	if err != nil {
113		return nil, errors.Wrap(err, `failed to generate Ed25519 private key`)
114	}
115
116	k, err := jwk.New(key)
117	if err != nil {
118		return nil, errors.Wrap(err, `failed to generate jwk.OKPPrivateKey`)
119	}
120
121	return k, nil
122}
123
124func GenerateX25519Key() (x25519.PrivateKey, error) {
125	_, priv, err := x25519.GenerateKey(rand.Reader)
126	return priv, err
127}
128
129func GenerateX25519Jwk() (jwk.Key, error) {
130	key, err := GenerateX25519Key()
131	if err != nil {
132		return nil, errors.Wrap(err, `failed to generate X25519 private key`)
133	}
134
135	k, err := jwk.New(key)
136	if err != nil {
137		return nil, errors.Wrap(err, `failed to generate jwk.OKPPrivateKey`)
138	}
139
140	return k, nil
141}
142
143func WriteFile(template string, src io.Reader) (string, func(), error) {
144	file, cleanup, err := CreateTempFile(template)
145	if err != nil {
146		return "", nil, errors.Wrap(err, `failed to create temporary file`)
147	}
148
149	if _, err := io.Copy(file, src); err != nil {
150		defer cleanup()
151		return "", nil, errors.Wrap(err, `failed to copy content to temporary file`)
152	}
153
154	if err := file.Sync(); err != nil {
155		defer cleanup()
156		return "", nil, errors.Wrap(err, `failed to sync file`)
157	}
158	return file.Name(), cleanup, nil
159}
160
161func WriteJSONFile(template string, v interface{}) (string, func(), error) {
162	var buf bytes.Buffer
163
164	enc := json.NewEncoder(&buf)
165	if err := enc.Encode(v); err != nil {
166		return "", nil, errors.Wrap(err, `failed to encode object to JSON`)
167	}
168	return WriteFile(template, &buf)
169}
170
171func DumpFile(t *testing.T, file string) {
172	buf, err := ioutil.ReadFile(file)
173	if !assert.NoError(t, err, `failed to read file %s for debugging`, file) {
174		return
175	}
176
177	if isHash, isArray := bytes.ContainsRune(buf, '{'), bytes.ContainsRune(buf, '['); isHash || isArray {
178		// Looks like a JSON-like thing. Dump that in a formatted manner, and
179		// be done with it
180
181		var v interface{}
182		if isHash {
183			v = map[string]interface{}{}
184		} else {
185			v = []interface{}{}
186		}
187
188		if !assert.NoError(t, json.Unmarshal(buf, &v), `failed to parse contents as JSON`) {
189			return
190		}
191
192		buf, _ = json.MarshalIndent(v, "", "  ")
193		t.Logf("=== BEGIN %s (formatted JSON) ===", file)
194		t.Logf("%s", buf)
195		t.Logf("=== END   %s (formatted JSON) ===", file)
196		return
197	}
198
199	// If the contents do not look like JSON, then we attempt to parse each content
200	// based on heuristics (from its file name) and do our best
201	t.Logf("=== BEGIN %s (raw) ===", file)
202	t.Logf("%s", buf)
203	t.Logf("=== END   %s (raw) ===", file)
204
205	if strings.HasSuffix(file, ".jwe") {
206		// cross our fingers our jwe implementation works
207		m, err := jwe.Parse(buf)
208		if !assert.NoError(t, err, `failed to parse JWE encrypted message`) {
209			return
210		}
211
212		buf, _ = json.MarshalIndent(m, "", "  ")
213	}
214
215	t.Logf("=== BEGIN %s (formatted JSON) ===", file)
216	t.Logf("%s", buf)
217	t.Logf("=== END   %s (formatted JSON) ===", file)
218}
219
220func CreateTempFile(template string) (*os.File, func(), error) {
221	file, err := ioutil.TempFile("", template)
222	if err != nil {
223		return nil, nil, errors.Wrap(err, "failed to create temporary file")
224	}
225
226	cleanup := func() {
227		file.Close()
228		os.Remove(file.Name())
229	}
230
231	return file, cleanup, nil
232}
233
234func ReadFile(file string) ([]byte, error) {
235	f, err := os.Open(file)
236	if err != nil {
237		return nil, errors.Wrapf(err, `failed to open file %s`, file)
238	}
239	defer f.Close()
240
241	buf, err := ioutil.ReadAll(f)
242	if err != nil {
243		return nil, errors.Wrapf(err, `failed to read from key file %s`, file)
244	}
245
246	return buf, nil
247}
248
249func ParseJwkFile(_ context.Context, file string) (jwk.Key, error) {
250	buf, err := ReadFile(file)
251	if err != nil {
252		return nil, errors.Wrapf(err, `failed to read from key file %s`, file)
253	}
254
255	key, err := jwk.ParseKey(buf)
256	if err != nil {
257		return nil, errors.Wrapf(err, `filed to parse JWK in key file %s`, file)
258	}
259
260	return key, nil
261}
262
263func DecryptJweFile(ctx context.Context, file string, alg jwa.KeyEncryptionAlgorithm, jwkfile string) ([]byte, error) {
264	key, err := ParseJwkFile(ctx, jwkfile)
265	if err != nil {
266		return nil, errors.Wrapf(err, `failed to parse keyfile %s`, file)
267	}
268
269	buf, err := ReadFile(file)
270	if err != nil {
271		return nil, errors.Wrapf(err, `failed to read from encrypted file %s`, file)
272	}
273
274	var rawkey interface{}
275	if err := key.Raw(&rawkey); err != nil {
276		return nil, errors.Wrap(err, `failed to obtain raw key from JWK`)
277	}
278
279	return jwe.Decrypt(buf, alg, rawkey)
280}
281
282func EncryptJweFile(ctx context.Context, payload []byte, keyalg jwa.KeyEncryptionAlgorithm, keyfile string, contentalg jwa.ContentEncryptionAlgorithm, compressalg jwa.CompressionAlgorithm) (string, func(), error) {
283	key, err := ParseJwkFile(ctx, keyfile)
284	if err != nil {
285		return "", nil, errors.Wrapf(err, `failed to parse keyfile %s`, keyfile)
286	}
287
288	var keyif interface{}
289
290	switch keyalg {
291	case jwa.RSA1_5, jwa.RSA_OAEP, jwa.RSA_OAEP_256:
292		var rawkey rsa.PrivateKey
293		if err := key.Raw(&rawkey); err != nil {
294			return "", nil, errors.Wrap(err, `failed to obtain raw key`)
295		}
296		keyif = rawkey.PublicKey
297	case jwa.ECDH_ES, jwa.ECDH_ES_A128KW, jwa.ECDH_ES_A192KW, jwa.ECDH_ES_A256KW:
298		var rawkey ecdsa.PrivateKey
299		if err := key.Raw(&rawkey); err != nil {
300			return "", nil, errors.Wrap(err, `failed to obtain raw key`)
301		}
302		keyif = rawkey.PublicKey
303	default:
304		var rawkey []byte
305		if err := key.Raw(&rawkey); err != nil {
306			return "", nil, errors.Wrap(err, `failed to obtain raw key`)
307		}
308		keyif = rawkey
309	}
310
311	buf, err := jwe.Encrypt(payload, keyalg, keyif, contentalg, compressalg)
312	if err != nil {
313		return "", nil, errors.Wrap(err, `failed to encrypt payload`)
314	}
315
316	return WriteFile("jwx-test-*.jwe", bytes.NewReader(buf))
317}
318
319func VerifyJwsFile(ctx context.Context, file string, alg jwa.SignatureAlgorithm, jwkfile string) ([]byte, error) {
320	key, err := ParseJwkFile(ctx, jwkfile)
321	if err != nil {
322		return nil, errors.Wrapf(err, `failed to parse keyfile %s`, file)
323	}
324
325	buf, err := ReadFile(file)
326	if err != nil {
327		return nil, errors.Wrapf(err, `failed to read from encrypted file %s`, file)
328	}
329
330	var rawkey, pubkey interface{}
331	if err := key.Raw(&rawkey); err != nil {
332		return nil, errors.Wrap(err, `failed to obtain raw key from JWK`)
333	}
334	pubkey = rawkey
335	switch tkey := rawkey.(type) {
336	case *ecdsa.PrivateKey:
337		pubkey = tkey.PublicKey
338	case *rsa.PrivateKey:
339		pubkey = tkey.PublicKey
340	case *ed25519.PrivateKey:
341		pubkey = tkey.Public()
342	}
343
344	return jws.Verify(buf, alg, pubkey)
345}
346
347func SignJwsFile(ctx context.Context, payload []byte, alg jwa.SignatureAlgorithm, keyfile string) (string, func(), error) {
348	key, err := ParseJwkFile(ctx, keyfile)
349	if err != nil {
350		return "", nil, errors.Wrapf(err, `failed to parse keyfile %s`, keyfile)
351	}
352
353	buf, err := jws.Sign(payload, alg, key)
354	if err != nil {
355		return "", nil, errors.Wrap(err, `failed to sign payload`)
356	}
357
358	return WriteFile("jwx-test-*.jws", bytes.NewReader(buf))
359}
360