1// Copyright (C) 2019 Storj Labs, Inc.
2// See LICENSE for copying information.
3
4package authorization
5
6import (
7	"bytes"
8	"encoding/gob"
9	"net"
10	"testing"
11	"time"
12
13	"github.com/btcsuite/btcutil/base58"
14	"github.com/stretchr/testify/assert"
15	"github.com/stretchr/testify/require"
16
17	"storj.io/common/identity/testidentity"
18	"storj.io/common/peertls/tlsopts"
19	"storj.io/common/rpc"
20	"storj.io/common/storj"
21	"storj.io/common/testcontext"
22	"storj.io/storj/certificate/certificateclient"
23	"storj.io/storj/certificate/certificatepb"
24)
25
26var (
27	t1 = Token{
28		UserID: "user@mail.test",
29		Data:   [tokenDataLength]byte{1, 2, 3},
30	}
31	t2 = Token{
32		UserID: "user2@mail.test",
33		Data:   [tokenDataLength]byte{4, 5, 6},
34	}
35)
36
37func TestNewAuthorization(t *testing.T) {
38	userID := "user@mail.test"
39	auth, err := NewAuthorization(userID)
40	require.NoError(t, err)
41	require.NotNil(t, auth)
42
43	assert.NotZero(t, auth.Token)
44	assert.Equal(t, userID, auth.Token.UserID)
45	assert.NotEmpty(t, auth.Token.Data)
46}
47
48func TestAuthorizations_Marshal(t *testing.T) {
49	expectedAuths := Group{
50		{Token: t1},
51		{Token: t2},
52	}
53
54	authsBytes, err := expectedAuths.Marshal()
55	require.NoError(t, err)
56	require.NotEmpty(t, authsBytes)
57
58	var actualAuths Group
59	decoder := gob.NewDecoder(bytes.NewBuffer(authsBytes))
60	err = decoder.Decode(&actualAuths)
61	assert.NoError(t, err)
62	assert.NotNil(t, actualAuths)
63	assert.Equal(t, expectedAuths, actualAuths)
64}
65
66func TestAuthorizations_Unmarshal(t *testing.T) {
67	expectedAuths := Group{
68		{Token: t1},
69		{Token: t2},
70	}
71
72	authsBytes, err := expectedAuths.Marshal()
73	require.NoError(t, err)
74	require.NotEmpty(t, authsBytes)
75
76	var actualAuths Group
77	err = actualAuths.Unmarshal(authsBytes)
78	assert.NoError(t, err)
79	assert.NotNil(t, actualAuths)
80	assert.Equal(t, expectedAuths, actualAuths)
81}
82
83func TestAuthorizations_Group(t *testing.T) {
84	auths := make(Group, 10)
85	for i := 0; i < 10; i++ {
86		if i%2 == 0 {
87			auths[i] = &Authorization{
88				Token: t1,
89				Claim: &Claim{
90					Timestamp: time.Now().Unix(),
91				},
92			}
93		} else {
94			auths[i] = &Authorization{
95				Token: t2,
96			}
97		}
98	}
99
100	claimed, open := auths.GroupByClaimed()
101	for _, a := range claimed {
102		assert.NotNil(t, a.Claim)
103	}
104	for _, a := range open {
105		assert.Nil(t, a.Claim)
106	}
107}
108
109func TestParseToken_Valid(t *testing.T) {
110	userID := "user@mail.test"
111	data := [tokenDataLength]byte{1, 2, 3}
112
113	cases := []struct {
114		testID string
115		userID string
116	}{
117		{
118			"valid token",
119			userID,
120		},
121		{
122			"multiple delimiters",
123			"us" + tokenDelimiter + "er@mail.test",
124		},
125	}
126
127	for _, c := range cases {
128		testCase := c
129		t.Run(testCase.testID, func(t *testing.T) {
130			b58Data := base58.CheckEncode(data[:], tokenVersion)
131			tokenString := testCase.userID + tokenDelimiter + b58Data
132			token, err := ParseToken(tokenString)
133			require.NoError(t, err)
134			require.NotNil(t, token)
135
136			assert.Equal(t, testCase.userID, token.UserID)
137			assert.Equal(t, data[:], token.Data[:])
138		})
139	}
140}
141
142func TestParseToken_Invalid(t *testing.T) {
143	userID := "user@mail.test"
144	data := [tokenDataLength]byte{1, 2, 3}
145
146	cases := []struct {
147		testID      string
148		tokenString string
149	}{
150		{
151			"no delimiter",
152			userID + base58.CheckEncode(data[:], tokenVersion),
153		},
154		{
155			"missing userID",
156			tokenDelimiter + base58.CheckEncode(data[:], tokenVersion),
157		},
158		{
159			"not enough data",
160			userID + tokenDelimiter + base58.CheckEncode(data[:len(data)-10], tokenVersion),
161		},
162		{
163			"too much data",
164			userID + tokenDelimiter + base58.CheckEncode(append(data[:], []byte{0, 0, 0}...), tokenVersion),
165		},
166		{
167			"data checksum or format error",
168			userID + tokenDelimiter + base58.CheckEncode(data[:], tokenVersion)[:len(base58.CheckEncode(data[:], tokenVersion))-4] + "0000",
169		},
170	}
171
172	for _, c := range cases {
173		testCase := c
174		t.Run(testCase.testID, func(t *testing.T) {
175			token, err := ParseToken(testCase.tokenString)
176			assert.Nil(t, token)
177			assert.True(t, ErrInvalidToken.Has(err))
178		})
179	}
180}
181
182func TestToken_Equal(t *testing.T) {
183	assert.True(t, t1.Equal(&t1))
184	assert.False(t, t1.Equal(&t2))
185}
186
187func TestNewClient(t *testing.T) {
188	t.Skip("needs proper rpc listener to work")
189
190	ctx := testcontext.New(t)
191	defer ctx.Cleanup()
192
193	ident, err := testidentity.PregeneratedIdentity(0, storj.LatestIDVersion())
194	require.NoError(t, err)
195	require.NotNil(t, ident)
196
197	listener, err := net.Listen("tcp", "127.0.0.1:0")
198	require.NoError(t, err)
199	require.NotNil(t, listener)
200
201	defer ctx.Check(listener.Close)
202	ctx.Go(func() error {
203		for {
204			conn, err := listener.Accept()
205			if err != nil {
206				return nil //nolint: nilerr // ignore closing error
207			}
208			if err := conn.Close(); err != nil {
209				return err
210			}
211		}
212	})
213
214	tlsOptions, err := tlsopts.NewOptions(ident, tlsopts.Config{}, nil)
215	require.NoError(t, err)
216
217	dialer := rpc.NewDefaultDialer(tlsOptions)
218
219	t.Run("Basic", func(t *testing.T) {
220		client, err := certificateclient.New(ctx, dialer, listener.Addr().String())
221		assert.NoError(t, err)
222		assert.NotNil(t, client)
223
224		defer ctx.Check(client.Close)
225	})
226
227	t.Run("ClientFrom", func(t *testing.T) {
228		conn, err := dialer.DialAddressInsecure(ctx, listener.Addr().String())
229		require.NoError(t, err)
230		require.NotNil(t, conn)
231
232		defer ctx.Check(conn.Close)
233
234		client := certificateclient.NewClientFrom(certificatepb.NewDRPCCertificatesClient(conn))
235		assert.NoError(t, err)
236		assert.NotNil(t, client)
237
238		defer ctx.Check(client.Close)
239	})
240}
241