1// Copyright (C) 2020 Storj Labs, Inc.
2// See LICENSE for copying information.
3
4package usedserials_test
5
6import (
7	"encoding/binary"
8	"testing"
9	"time"
10
11	"github.com/stretchr/testify/require"
12
13	"storj.io/common/identity/testidentity"
14	"storj.io/common/memory"
15	"storj.io/common/storj"
16	"storj.io/common/testrand"
17	"storj.io/storj/storagenode/piecestore/usedserials"
18)
19
// Serial is a test fixture describing one used serial. It records the satellite
// ID passed to the table, the serial number itself, and the bytes of that serial
// from offset 8 onwards copied into a usedserials.Partial. It also records the
// expiration passed to Add/Exists.
//
// The fixtures build Serial values with positional composite literals, so the
// field order below is load-bearing.
// NOTE(review): PartialSerialNumber is filled in by the fixtures but appears not
// to be read by any assertion in this file.
type Serial struct {
	SatelliteID         storj.NodeID
	SerialNumber        storj.SerialNumber
	PartialSerialNumber usedserials.Partial
	Expiration          time.Time
}
26
27func TestUsedSerials(t *testing.T) {
28	usedSerials := usedserials.NewTable(memory.MiB)
29
30	node0 := testidentity.MustPregeneratedIdentity(0, storj.LatestIDVersion())
31	node1 := testidentity.MustPregeneratedIdentity(1, storj.LatestIDVersion())
32
33	serial1 := testrand.SerialNumber()
34	serial2 := testrand.SerialNumber()
35	serial3 := testrand.SerialNumber()
36	serial4 := testrand.SerialNumber()
37	serial5 := testrand.SerialNumber()
38
39	var partialSerial1, partialSerial2, partialSerial3, partialSerial4, partialSerial5 usedserials.Partial
40	copy(partialSerial1[:], serial1[8:])
41	copy(partialSerial2[:], serial2[8:])
42	copy(partialSerial3[:], serial3[8:])
43	copy(partialSerial4[:], serial4[8:])
44	copy(partialSerial5[:], serial5[8:])
45
46	now := time.Now()
47
48	// queries on empty table
49	usedSerials.DeleteExpired(now.Add(6 * time.Minute))
50	require.Zero(t, usedSerials.Count())
51
52	// let's start adding data
53	// use different timezones
54	location := time.FixedZone("XYZ", int((8 * time.Hour).Seconds()))
55
56	// the serials with expiration times embedded are based on serial4 and serial5
57	serialWithExp1 := createExpirationSerial(serial4, now.Add(8*time.Hour))
58	serialWithExp2 := createExpirationSerial(serial5, now.Add(time.Hour))
59
60	serialNumbers := []Serial{
61		{node0.ID, serial1, partialSerial1, now.Add(time.Hour)},
62		{node0.ID, serial2, partialSerial2, now.Add(4 * time.Hour)},
63		{node0.ID, serial3, partialSerial3, now.In(location).Add(8 * time.Hour)},
64		{node1.ID, serial1, partialSerial1, now.In(location).Add(time.Hour)},
65		{node1.ID, serial2, partialSerial2, now.Add(4 * time.Hour)},
66		{node1.ID, serial3, partialSerial3, now.Add(8 * time.Hour)},
67
68		{node0.ID, serialWithExp1, partialSerial4, now.Add(8 * time.Hour)},
69		{node0.ID, serialWithExp2, partialSerial5, now.Add(time.Hour)},
70		{node1.ID, serialWithExp1, partialSerial4, now.Add(8 * time.Hour)},
71		{node1.ID, serialWithExp2, partialSerial5, now.Add(time.Hour)},
72	}
73
74	// basic adding
75	for _, serial := range serialNumbers {
76		err := usedSerials.Add(serial.SatelliteID, serial.SerialNumber, serial.Expiration)
77		require.NoError(t, err)
78	}
79
80	// duplicate adds should fail
81	for _, serial := range serialNumbers {
82		err := usedSerials.Add(serial.SatelliteID, serial.SerialNumber, serial.Expiration)
83		require.Error(t, err)
84		require.True(t, usedserials.ErrSerialAlreadyExists.Has(err))
85	}
86
87	// ensure all the serials exist
88	require.Equal(t, len(serialNumbers), usedSerials.Count())
89	for _, serial := range serialNumbers {
90		require.True(t, usedSerials.Exists(serial.SatelliteID, serial.SerialNumber, serial.Expiration))
91	}
92
93	// ensure we can delete expired
94	usedSerials.DeleteExpired(now.Add(6 * time.Hour))
95
96	// check that we have actually deleted things
97	expectedAfterDelete := []Serial{
98		{node0.ID, serial3, partialSerial3, now.Add(8 * time.Hour)},
99		{node1.ID, serial3, partialSerial3, now.Add(8 * time.Hour)},
100		{node0.ID, serialWithExp1, partialSerial4, now.Add(8 * time.Hour)},
101		{node1.ID, serialWithExp1, partialSerial4, now.Add(8 * time.Hour)},
102	}
103
104	require.Equal(t, len(expectedAfterDelete), usedSerials.Count())
105	for _, serial := range expectedAfterDelete {
106		require.True(t, usedSerials.Exists(serial.SatelliteID, serial.SerialNumber, serial.Expiration))
107	}
108}
109
110// TestUsedSerialsMemory ensures that random serials are deleted if the allocated memory size is exceeded.
111func TestUsedSerialsMemory(t *testing.T) {
112	// first, test with partial serial numbers
113	entrySize := usedserials.PartialSize
114
115	// allow for up to three items
116	// add one byte so that we don't remove items at exactly the threshold when adding a duplicate.
117	usedSerials := usedserials.NewTable(3 * entrySize)
118	require.Zero(t, usedSerials.Count())
119
120	for i := 0; i < 10; i++ {
121		newNodeID := testrand.NodeID()
122		expiration := time.Now().Add(time.Hour)
123		newSerial := createExpirationSerial(testrand.SerialNumber(), expiration)
124
125		err := usedSerials.Add(newNodeID, newSerial, expiration)
126		require.NoError(t, err)
127
128		expectedCount := 3
129		if i < 2 {
130			expectedCount = i + 1
131		}
132
133		// expect count to be correct
134		require.EqualValues(t, expectedCount, usedSerials.Count())
135	}
136
137	// now, test with full serial numbers
138	entrySize = usedserials.FullSize
139
140	// allow for up to three items
141	usedSerials = usedserials.NewTable(3 * entrySize)
142	require.Zero(t, usedSerials.Count())
143
144	for i := 0; i < 10; i++ {
145		newNodeID := testrand.NodeID()
146		expiration := time.Now().Add(time.Hour)
147		newSerial := testrand.SerialNumber()
148
149		err := usedSerials.Add(newNodeID, newSerial, expiration)
150		require.NoError(t, err)
151
152		expectedCount := 3
153		if i < 2 {
154			expectedCount = i + 1
155		}
156
157		// expect count to be correct
158		require.EqualValues(t, expectedCount, usedSerials.Count())
159	}
160}
161
162func createExpirationSerial(originalSerial storj.SerialNumber, expiration time.Time) storj.SerialNumber {
163	serialWithExp := storj.SerialNumber{}
164	copy(serialWithExp[:], originalSerial[:])
165	// make first 8 bytes of serial expiration so that it is stored as a partial serial
166	binary.BigEndian.PutUint64(serialWithExp[0:8], uint64(expiration.Unix()))
167
168	return serialWithExp
169}
170