1// Copyright (C) 2019 Storj Labs, Inc.
2// See LICENSE for copying information.
3
4package identity_test
5
6import (
7	"bytes"
8	"context"
9	"crypto"
10	"crypto/x509"
11	"crypto/x509/pkix"
12	"encoding/asn1"
13	"fmt"
14	"os"
15	"runtime"
16	"testing"
17
18	"github.com/stretchr/testify/assert"
19	"github.com/stretchr/testify/require"
20
21	"storj.io/common/identity"
22	"storj.io/common/identity/testidentity"
23	"storj.io/common/peertls"
24	"storj.io/common/peertls/extensions"
25	"storj.io/common/peertls/testpeertls"
26	"storj.io/common/peertls/tlsopts"
27	"storj.io/common/pkcrypto"
28	"storj.io/common/storj"
29	"storj.io/common/testcontext"
30	"storj.io/common/testrand"
31)
32
33func TestPeerIdentityFromCertChain(t *testing.T) {
34	caKey, err := pkcrypto.GeneratePrivateKey()
35	require.NoError(t, err)
36
37	caTemplate, err := peertls.CATemplate()
38	require.NoError(t, err)
39
40	caCert, err := peertls.CreateSelfSignedCertificate(caKey, caTemplate)
41	require.NoError(t, err)
42
43	leafTemplate, err := peertls.LeafTemplate()
44	require.NoError(t, err)
45
46	leafKey, err := pkcrypto.GeneratePrivateKey()
47	require.NoError(t, err)
48
49	pubKey, err := pkcrypto.PublicKeyFromPrivate(leafKey)
50	require.NoError(t, err)
51	leafCert, err := peertls.CreateCertificate(pubKey, caKey, leafTemplate, caTemplate)
52	require.NoError(t, err)
53
54	peerIdent, err := identity.PeerIdentityFromChain([]*x509.Certificate{leafCert, caCert})
55	require.NoError(t, err)
56	assert.Equal(t, caCert, peerIdent.CA)
57	assert.Equal(t, leafCert, peerIdent.Leaf)
58	assert.NotEmpty(t, peerIdent.ID)
59}
60
61func TestFullIdentityFromPEM(t *testing.T) {
62	caKey, err := pkcrypto.GeneratePrivateKey()
63	require.NoError(t, err)
64
65	caTemplate, err := peertls.CATemplate()
66	require.NoError(t, err)
67
68	caCert, err := peertls.CreateSelfSignedCertificate(caKey, caTemplate)
69	require.NoError(t, err)
70	require.NoError(t, err)
71	require.NotEmpty(t, caCert)
72
73	leafTemplate, err := peertls.LeafTemplate()
74	require.NoError(t, err)
75
76	leafKey, err := pkcrypto.GeneratePrivateKey()
77	require.NoError(t, err)
78
79	pubKey, err := pkcrypto.PublicKeyFromPrivate(leafKey)
80	require.NoError(t, err)
81	leafCert, err := peertls.CreateCertificate(pubKey, caKey, leafTemplate, caTemplate)
82	require.NoError(t, err)
83	require.NotEmpty(t, leafCert)
84
85	chainPEM := bytes.NewBuffer([]byte{})
86	require.NoError(t, pkcrypto.WriteCertPEM(chainPEM, leafCert))
87	require.NoError(t, pkcrypto.WriteCertPEM(chainPEM, caCert))
88
89	keyPEM := bytes.NewBuffer([]byte{})
90	require.NoError(t, pkcrypto.WritePrivateKeyPEM(keyPEM, leafKey))
91
92	fullIdent, err := identity.FullIdentityFromPEM(chainPEM.Bytes(), keyPEM.Bytes())
93	assert.NoError(t, err)
94	assert.Equal(t, leafCert.Raw, fullIdent.Leaf.Raw)
95	assert.Equal(t, caCert.Raw, fullIdent.CA.Raw)
96	assert.Equal(t, leafKey, fullIdent.Key)
97}
98
99func TestConfig_Save_with_extension(t *testing.T) {
100	ctx := testcontext.New(t)
101
102	testidentity.CompleteIdentityVersionsTest(t, func(t *testing.T, version storj.IDVersion, ident *identity.FullIdentity) {
103		identCfg := &identity.Config{
104			CertPath: ctx.File("chain.pem"),
105			KeyPath:  ctx.File("key.pem"),
106		}
107
108		{ // pre-save version assertions
109			assert.Equal(t, version.Number, ident.ID.Version().Number)
110
111			caVersion, err := storj.IDVersionFromCert(ident.CA)
112			require.NoError(t, err)
113			assert.Equal(t, version.Number, caVersion.Number)
114
115			versionExt := tlsopts.NewExtensionsMap(ident.CA)[extensions.IdentityVersionExtID.String()]
116			if ident.ID.Version().Number == 0 {
117				require.NotEmpty(t, versionExt)
118				assert.Equal(t, ident.ID.Version().Number, storj.IDVersionNumber(versionExt.Value[0]))
119			} else {
120				assert.Empty(t, versionExt)
121			}
122		}
123
124		{ // test saving
125			err := identCfg.Save(ident)
126			assert.NoError(t, err)
127
128			certInfo, err := os.Stat(identCfg.CertPath)
129			assert.NoError(t, err)
130
131			keyInfo, err := os.Stat(identCfg.KeyPath)
132			assert.NoError(t, err)
133
134			// TODO (windows): ignoring for windows due to different default permissions
135			if runtime.GOOS != "windows" {
136				assert.Equal(t, os.FileMode(0644), certInfo.Mode())
137				assert.Equal(t, os.FileMode(0600), keyInfo.Mode())
138			}
139		}
140
141		{ // test loading
142			loadedFi, err := identCfg.Load()
143			require.NoError(t, err)
144			assert.Equal(t, ident.Key, loadedFi.Key)
145			assert.Equal(t, ident.Leaf, loadedFi.Leaf)
146			assert.Equal(t, ident.CA, loadedFi.CA)
147			assert.Equal(t, ident.ID, loadedFi.ID)
148
149			versionExt := tlsopts.NewExtensionsMap(ident.CA)[extensions.IdentityVersionExtID.String()]
150			if ident.ID.Version().Number == 0 {
151				require.NotEmpty(t, versionExt)
152				assert.Equal(t, ident.ID.Version().Number, storj.IDVersionNumber(versionExt.Value[0]))
153			} else {
154				assert.Empty(t, versionExt)
155			}
156		}
157	})
158}
159
160func TestConfig_Save(t *testing.T) {
161	ctx := testcontext.New(t)
162
163	testidentity.IdentityVersionsTest(t, func(t *testing.T, version storj.IDVersion, ident *identity.FullIdentity) {
164		identCfg := &identity.Config{
165			CertPath: ctx.File("chain.pem"),
166			KeyPath:  ctx.File("key.pem"),
167		}
168
169		chainPEM := bytes.NewBuffer([]byte{})
170		require.NoError(t, pkcrypto.WriteCertPEM(chainPEM, ident.Leaf))
171		require.NoError(t, pkcrypto.WriteCertPEM(chainPEM, ident.CA))
172
173		privateKey := ident.Key
174		require.NotEmpty(t, privateKey)
175
176		keyPEM := bytes.NewBuffer([]byte{})
177		require.NoError(t, pkcrypto.WritePrivateKeyPEM(keyPEM, privateKey))
178
179		{ // test saving
180			err := identCfg.Save(ident)
181			assert.NoError(t, err)
182
183			certInfo, err := os.Stat(identCfg.CertPath)
184			assert.NoError(t, err)
185
186			keyInfo, err := os.Stat(identCfg.KeyPath)
187			assert.NoError(t, err)
188
189			// TODO (windows): ignoring for windows due to different default permissions
190			if runtime.GOOS != "windows" {
191				assert.Equal(t, os.FileMode(0644), certInfo.Mode())
192				assert.Equal(t, os.FileMode(0600), keyInfo.Mode())
193			}
194		}
195
196		{ // test loading
197			loadedFi, err := identCfg.Load()
198			assert.NoError(t, err)
199			assert.Equal(t, ident.Key, loadedFi.Key)
200			assert.Equal(t, ident.Leaf, loadedFi.Leaf)
201			assert.Equal(t, ident.CA, loadedFi.CA)
202			assert.Equal(t, ident.ID, loadedFi.ID)
203		}
204	})
205}
206
207func TestVersionedNodeIDFromKey(t *testing.T) {
208	_, chain, err := testpeertls.NewCertChain(1, storj.LatestIDVersion().Number)
209	require.NoError(t, err)
210
211	pubKey, ok := chain[peertls.LeafIndex].PublicKey.(crypto.PublicKey)
212	require.True(t, ok)
213
214	for _, v := range storj.IDVersions {
215		version := v
216		t.Run(fmt.Sprintf("IdentityV%d", version.Number), func(t *testing.T) {
217			id, err := identity.NodeIDFromKey(pubKey, version)
218			require.NoError(t, err)
219			assert.Equal(t, version.Number, id.Version().Number)
220		})
221	}
222}
223
224func TestVerifyPeer(t *testing.T) {
225	ca, err := identity.NewCA(context.Background(), identity.NewCAOptions{
226		Difficulty:  12,
227		Concurrency: 4,
228	})
229	require.NoError(t, err)
230	require.NotNil(t, ca)
231
232	fi, err := ca.NewIdentity()
233	require.NoError(t, err)
234	require.NotNil(t, fi)
235
236	err = peertls.VerifyPeerFunc(peertls.VerifyPeerCertChains)([][]byte{fi.Leaf.Raw, fi.CA.Raw}, nil)
237	assert.NoError(t, err)
238}
239
240func TestManageablePeerIdentity_AddExtension(t *testing.T) {
241	ctx := testcontext.New(t)
242
243	manageablePeerIdentity, err := testidentity.NewTestManageablePeerIdentity(ctx)
244	require.NoError(t, err)
245
246	oldLeaf := manageablePeerIdentity.Leaf
247	assert.Len(t, manageablePeerIdentity.CA.Cert.ExtraExtensions, 0)
248
249	randBytes := testrand.Bytes(10)
250	randExt := pkix.Extension{
251		Id:    asn1.ObjectIdentifier{2, 999, int(randBytes[0])},
252		Value: randBytes,
253	}
254
255	err = manageablePeerIdentity.AddExtension(randExt)
256	require.NoError(t, err)
257
258	assert.Len(t, manageablePeerIdentity.Leaf.ExtraExtensions, 0)
259	assert.Len(t, manageablePeerIdentity.Leaf.Extensions, len(oldLeaf.Extensions)+1)
260
261	assert.Equal(t, oldLeaf.SerialNumber, manageablePeerIdentity.Leaf.SerialNumber)
262	assert.Equal(t, oldLeaf.IsCA, manageablePeerIdentity.Leaf.IsCA)
263	assert.Equal(t, oldLeaf.PublicKey, manageablePeerIdentity.Leaf.PublicKey)
264	ext := tlsopts.NewExtensionsMap(manageablePeerIdentity.Leaf)[randExt.Id.String()]
265	assert.Equal(t, randExt, ext)
266
267	assert.Equal(t, randExt, tlsopts.NewExtensionsMap(manageablePeerIdentity.Leaf)[randExt.Id.String()])
268
269	assert.NotEqual(t, oldLeaf.Raw, manageablePeerIdentity.Leaf.Raw)
270	assert.NotEqual(t, oldLeaf.RawTBSCertificate, manageablePeerIdentity.Leaf.RawTBSCertificate)
271	assert.NotEqual(t, oldLeaf.Signature, manageablePeerIdentity.Leaf.Signature)
272}
273
274func TestManageableFullIdentity_Revoke(t *testing.T) {
275	ctx := testcontext.New(t)
276
277	manageableFullIdentity, err := testidentity.NewTestManageableFullIdentity(ctx)
278	require.NoError(t, err)
279
280	oldLeaf := manageableFullIdentity.Leaf
281	assert.Len(t, manageableFullIdentity.CA.Cert.ExtraExtensions, 0)
282
283	err = manageableFullIdentity.Revoke()
284	require.NoError(t, err)
285
286	assert.Len(t, manageableFullIdentity.Leaf.ExtraExtensions, 0)
287	assert.Len(t, manageableFullIdentity.Leaf.Extensions, len(oldLeaf.Extensions)+1)
288
289	assert.Equal(t, oldLeaf.IsCA, manageableFullIdentity.Leaf.IsCA)
290
291	assert.NotEqual(t, oldLeaf.PublicKey, manageableFullIdentity.Leaf.PublicKey)
292	assert.NotEqual(t, oldLeaf.SerialNumber, manageableFullIdentity.Leaf.SerialNumber)
293	assert.NotEqual(t, oldLeaf.Raw, manageableFullIdentity.Leaf.Raw)
294	assert.NotEqual(t, oldLeaf.RawTBSCertificate, manageableFullIdentity.Leaf.RawTBSCertificate)
295	assert.NotEqual(t, oldLeaf.Signature, manageableFullIdentity.Leaf.Signature)
296
297	revocationExt := tlsopts.NewExtensionsMap(manageableFullIdentity.Leaf)[extensions.RevocationExtID.String()]
298	assert.True(t, extensions.RevocationExtID.Equal(revocationExt.Id))
299
300	var rev extensions.Revocation
301	err = rev.Unmarshal(revocationExt.Value)
302	require.NoError(t, err)
303
304	err = rev.Verify(manageableFullIdentity.CA.Cert)
305	require.NoError(t, err)
306}
307
308func TestEncodeDecodePeerIdentity(t *testing.T) {
309	ctx := testcontext.New(t)
310
311	peerID, err := testidentity.NewTestIdentity(ctx)
312	require.NoError(t, err)
313	pi := peerID.PeerIdentity()
314
315	// encode the peer identity
316	encodedPiBytes := identity.EncodePeerIdentity(pi)
317	assert.NotNil(t, encodedPiBytes)
318	// decode the peer identity
319	decodedPi, err := identity.DecodePeerIdentity(ctx, encodedPiBytes)
320	assert.NoError(t, err)
321	// again encode the above decoded peer identity and compare
322	decodedPiBytes := identity.EncodePeerIdentity(decodedPi)
323	assert.Equal(t, encodedPiBytes, decodedPiBytes)
324}
325