1package crypto_test
2
3import (
4	"bytes"
5	"crypto"
6	"crypto/ecdsa"
7	"crypto/ed25519"
8	"crypto/elliptic"
9	"crypto/rand"
10	"crypto/rsa"
11	"crypto/x509"
12	"fmt"
13	"reflect"
14	"testing"
15
16	btcec "github.com/btcsuite/btcd/btcec"
17	. "github.com/libp2p/go-libp2p-core/crypto"
18	pb "github.com/libp2p/go-libp2p-core/crypto/pb"
19	"github.com/libp2p/go-libp2p-core/test"
20	sha256 "github.com/minio/sha256-simd"
21)
22
// TestKeys runs the signature, encoding and equality checks in testKeyType
// once for each entry in KeyTypes.
func TestKeys(t *testing.T) {
	for i := range KeyTypes {
		testKeyType(KeyTypes[i], t)
	}
}
28
// TestKeyPairFromKey checks, for ECDSA, secp256k1, RSA and Ed25519 keys, that:
//   - KeyPairFromStdKey wraps the native private key into a libp2p key pair of
//     the expected type,
//   - the wrapped public key verifies a signature made with the native key,
//   - PubKeyToStdKey / PrivKeyToStdKey convert back to native keys whose
//     marshaled bytes match the libp2p Raw() bytes.
func TestKeyPairFromKey(t *testing.T) {
	var (
		data   = []byte(`hello world`)
		hashed = sha256.Sum256(data)
	)

	privk, err := btcec.NewPrivateKey(btcec.S256())
	if err != nil {
		t.Fatalf("err generating btcec priv key:\n%v", err)
	}
	sigK, err := privk.Sign(hashed[:])
	if err != nil {
		t.Fatalf("err generating btcec sig:\n%v", err)
	}

	eKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatalf("err generating ecdsa priv key:\n%v", err)
	}
	sigE, err := eKey.Sign(rand.Reader, hashed[:], crypto.SHA256)
	if err != nil {
		t.Fatalf("err generating ecdsa sig:\n%v", err)
	}

	rKey, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		t.Fatalf("err generating rsa priv key:\n%v", err)
	}
	sigR, err := rKey.Sign(rand.Reader, hashed[:], crypto.SHA256)
	if err != nil {
		t.Fatalf("err generating rsa sig:\n%v", err)
	}

	// Check the generation error before the key is used for signing.
	_, edKey, err := ed25519.GenerateKey(rand.Reader)
	if err != nil {
		t.Fatalf("err generating ed25519 priv key:\n%v", err)
	}
	// Unlike the other entries, the ed25519 signature is over data, not hashed.
	sigEd := ed25519.Sign(edKey, data)

	for i, tt := range []struct {
		in  crypto.PrivateKey
		typ pb.KeyType
		sig []byte
	}{
		{
			eKey,
			ECDSA,
			sigE,
		},
		{
			privk,
			Secp256k1,
			sigK.Serialize(),
		},
		{
			rKey,
			RSA,
			sigR,
		},
		{
			&edKey,
			Ed25519,
			sigEd,
		},
	} {
		t.Run(fmt.Sprintf("%v", i), func(t *testing.T) {
			priv, pub, err := KeyPairFromStdKey(tt.in)
			if err != nil {
				t.Fatal(err)
			}

			// Everything below dereferences priv and pub, so a nil key must
			// stop the subtest instead of panicking later.
			if priv == nil || pub == nil {
				t.Fatalf("received nil private key or public key: %v, %v", priv, pub)
			}

			if priv.Type() != tt.typ {
				t.Errorf("want %v; got %v", tt.typ, priv.Type())
			}

			v, err := pub.Verify(data, tt.sig)
			if err != nil {
				t.Error(err)
			}

			if !v {
				t.Error("signature was not verified")
			}

			stdPub, err := PubKeyToStdKey(pub)
			if err != nil || stdPub == nil {
				t.Fatalf("err getting std public key from key: %v", err)
			}

			var stdPubBytes []byte

			// err is nil here, so branches that do not assign it leave it nil.
			switch p := stdPub.(type) {
			case *Secp256k1PublicKey:
				stdPubBytes, err = p.Raw()
			case ed25519.PublicKey:
				stdPubBytes = []byte(p)
			default:
				stdPubBytes, err = x509.MarshalPKIXPublicKey(stdPub)
			}

			if err != nil {
				t.Errorf("Error while marshaling %v key: %v", reflect.TypeOf(stdPub), err)
			}

			pubBytes, err := pub.Raw()
			if err != nil {
				t.Errorf("err getting raw bytes for %v key: %v", reflect.TypeOf(pub), err)
			}
			if !bytes.Equal(stdPubBytes, pubBytes) {
				t.Errorf("err roundtripping %v key", reflect.TypeOf(pub))
			}

			stdPriv, err := PrivKeyToStdKey(priv)
			if err != nil || stdPriv == nil {
				t.Fatalf("err getting std private key from key: %v", err)
			}

			var stdPrivBytes []byte

			switch p := stdPriv.(type) {
			case *Secp256k1PrivateKey:
				stdPrivBytes, err = p.Raw()
			case *ecdsa.PrivateKey:
				stdPrivBytes, err = x509.MarshalECPrivateKey(p)
			case *ed25519.PrivateKey:
				stdPrivBytes = *p
			case *rsa.PrivateKey:
				stdPrivBytes = x509.MarshalPKCS1PrivateKey(p)
			default:
				// Without this branch stdPrivBytes would stay nil and the
				// failure would surface only as a confusing byte mismatch.
				t.Fatalf("unexpected std private key type %v", reflect.TypeOf(stdPriv))
			}

			if err != nil {
				t.Errorf("err marshaling %v key: %v", reflect.TypeOf(stdPriv), err)
			}

			privBytes, err := priv.Raw()
			if err != nil {
				t.Errorf("err getting raw bytes for %v key: %v", reflect.TypeOf(priv), err)
			}

			if !bytes.Equal(stdPrivBytes, privBytes) {
				t.Errorf("err roundtripping %v key", reflect.TypeOf(priv))
			}
		})
	}
}
178
179func testKeyType(typ int, t *testing.T) {
180	bits := 512
181	if typ == RSA {
182		bits = 2048
183	}
184	sk, pk, err := test.RandTestKeyPair(typ, bits)
185	if err != nil {
186		t.Fatal(err)
187	}
188
189	testKeySignature(t, sk)
190	testKeyEncoding(t, sk)
191	testKeyEquals(t, sk)
192	testKeyEquals(t, pk)
193}
194
// testKeySignature signs 16 random bytes with sk and requires that the
// corresponding public key reports the signature as valid.
func testKeySignature(t *testing.T, sk PrivKey) {
	pub := sk.GetPublic()

	msg := make([]byte, 16)
	if _, err := rand.Read(msg); err != nil {
		t.Fatal(err)
	}

	sig, err := sk.Sign(msg)
	if err != nil {
		t.Fatal(err)
	}

	ok, err := pub.Verify(msg, sig)
	switch {
	case err != nil:
		t.Fatal(err)
	case !ok:
		t.Fatal("Invalid signature.")
	}
}
217
// testKeyEncoding round-trips sk and its public key through
// Marshal*/Unmarshal*. It checks that each decoded key Equals the original and
// that re-marshaling the decoded key reproduces the original bytes.
func testKeyEncoding(t *testing.T, sk PrivKey) {
	skbm, err := MarshalPrivateKey(sk)
	if err != nil {
		t.Fatal(err)
	}

	sk2, err := UnmarshalPrivateKey(skbm)
	if err != nil {
		t.Fatal(err)
	}

	if !sk.Equals(sk2) {
		t.Error("Unmarshaled private key didn't match original.\n")
	}

	skbm2, err := MarshalPrivateKey(sk2)
	if err != nil {
		t.Fatal(err)
	}

	if !bytes.Equal(skbm, skbm2) {
		t.Error("skb -> marshal -> unmarshal -> skb failed.\n", skbm, "\n", skbm2)
	}

	pk := sk.GetPublic()
	pkbm, err := MarshalPublicKey(pk)
	if err != nil {
		t.Fatal(err)
	}

	pk2, err := UnmarshalPublicKey(pkbm)
	if err != nil {
		t.Fatal(err)
	}

	if !pk.Equals(pk2) {
		t.Error("Unmarshaled public key didn't match original.\n")
	}

	// Re-marshal the *decoded* key. Marshaling pk again would compare pk
	// against itself and could never fail.
	pkbm2, err := MarshalPublicKey(pk2)
	if err != nil {
		t.Fatal(err)
	}

	if !bytes.Equal(pkbm, pkbm2) {
		t.Error("pkb -> marshal -> unmarshal -> pkb failed.\n", pkbm, "\n", pkbm2)
	}
}
266
// testKeyEquals requires that k is KeyEqual to itself. It also requires that
// k is not KeyEqual to either half of a freshly generated 2048-bit RSA key pair.
//
// A check comparing k with a testkey built from k.Raw() was deliberately left
// out, because it depends on each key type's internal raw encoding.
func testKeyEquals(t *testing.T, k Key) {
	if !KeyEqual(k, k) {
		t.Fatal("Key not equal to itself.")
	}

	otherPriv, otherPub, err := test.RandTestKeyPair(RSA, 2048)
	if err != nil {
		t.Fatal(err)
	}

	for _, other := range []Key{otherPriv, otherPub} {
		if KeyEqual(k, other) {
			t.Fatal("Keys should not equal.")
		}
	}
}
295
296type testkey []byte
297
298func (pk testkey) Bytes() ([]byte, error) {
299	return pk, nil
300}
301
302func (pk testkey) Type() pb.KeyType {
303	return pb.KeyType_RSA
304}
305
306func (pk testkey) Raw() ([]byte, error) {
307	return pk, nil
308}
309
310func (pk testkey) Equals(k Key) bool {
311	if pk.Type() != k.Type() {
312		return false
313	}
314	a, err := pk.Raw()
315	if err != nil {
316		return false
317	}
318
319	b, err := k.Raw()
320	if err != nil {
321		return false
322	}
323
324	return bytes.Equal(a, b)
325}
326
// TestUnknownCurveErrors checks that GenerateEKeyPair accepts the "P-256"
// curve name and returns an error for an unrecognized one.
func TestUnknownCurveErrors(t *testing.T) {
	if _, _, err := GenerateEKeyPair("P-256"); err != nil {
		t.Fatal(err)
	}

	if _, _, err := GenerateEKeyPair("error-please"); err == nil {
		t.Fatal("expected invalid key type to error")
	}
}
338
339func TestPanicOnUnknownCipherType(t *testing.T) {
340	passed := false
341	defer func() {
342		if !passed {
343			t.Fatal("expected known cipher and hash to succeed")
344		}
345		err := recover()
346		errStr, ok := err.(string)
347		if !ok {
348			t.Fatal("expected string in panic")
349		}
350		if errStr != "Unrecognized cipher, programmer error?" {
351			t.Fatal("expected \"Unrecognized cipher, programmer error?\"")
352		}
353	}()
354	KeyStretcher("AES-256", "SHA1", []byte("foo"))
355	passed = true
356	KeyStretcher("Fooba", "SHA1", []byte("foo"))
357}
358