1// Copyright (C) 2020 Storj Labs, Inc.
2// See LICENSE for copying information.
3
4package usedserials
5
6import (
7	"encoding/binary"
8	"math/rand"
9	"sort"
10	"sync"
11	"time"
12
13	"github.com/spacemonkeygo/monkit/v3"
14	"github.com/zeebo/errs"
15
16	"storj.io/common/memory"
17	"storj.io/common/storj"
18)
19
20var (
21	// ErrSerials defines the usedserials store error class.
22	ErrSerials = errs.Class("usedserials")
23	// ErrSerialAlreadyExists defines an error class for duplicate usedserials.
24	ErrSerialAlreadyExists = errs.Class("used serial already exists in store")
25
26	mon = monkit.Package()
27)
28
29const (
30	// PartialSize is the size of a partial serial number.
31	PartialSize = memory.Size(len(Partial{}))
32	// FullSize is the size of a full serial number.
33	FullSize = memory.Size(len(storj.SerialNumber{}))
34)
35
36// Partial represents the last 8 bytes of a serial number. It is used when the first 8 are based on the expiration date.
37type Partial [8]byte
38
39// Less returns true if partial serial a is less than partial serial b and false otherwise.
40func (a Partial) Less(b Partial) bool {
41	return binary.BigEndian.Uint64(a[:]) < binary.BigEndian.Uint64(b[:])
42}
43
44// Full is a copy of the SerialNumber type. It is necessary so we can define a Less function on it.
45type Full storj.SerialNumber
46
47// Less returns true if partial serial a is less than partial serial b and false otherwise.
48func (a Full) Less(b Full) bool {
49	return binary.BigEndian.Uint64(a[:]) < binary.BigEndian.Uint64(b[:])
50}
51
52// serialsList is a structure that contains a list of partial serials and a list of full serials.
53//
54// For serials where expiration time is the first 8 bytes, it uses partialSerials.
55// It uses fullSerials otherwise.
56type serialsList struct {
57	partialSerials []Partial
58	fullSerials    []storj.SerialNumber
59}
60
61// Table is an in-memory store for serial numbers.
62type Table struct {
63	mu sync.Mutex
64
65	// key 1: satellite ID, key 2: expiration hour (in unix time), value: a list of serial numbers
66	serials map[storj.NodeID]map[int64]serialsList
67
68	maxMemory  memory.Size
69	memoryUsed memory.Size
70}
71
72// NewTable creates and returns a new usedserials in-memory store.
73func NewTable(maxMemory memory.Size) *Table {
74	if maxMemory <= 0 {
75		panic("max memory for usedserials store is 0")
76	}
77	return &Table{
78		serials:   make(map[storj.NodeID]map[int64]serialsList),
79		maxMemory: maxMemory,
80	}
81}
82
// Add records serialNumber for satelliteID in the bucket for its rounded-up
// expiration hour. It returns an ErrSerialAlreadyExists error when that serial
// is already recorded in the same bucket.
//
// After a successful insert, if the tracked memory exceeds maxMemory, randomly
// chosen serials are evicted until it fits again; the serial just added may be
// among those evicted.
func (table *Table) Add(satelliteID storj.NodeID, serialNumber storj.SerialNumber, expiration time.Time) error {
	table.mu.Lock()
	defer table.mu.Unlock()

	hourMap, found := table.serials[satelliteID]
	if !found {
		hourMap = make(map[int64]serialsList)
		table.serials[satelliteID] = hourMap
	}

	// A missing bucket reads as an empty serialsList; it is stored below once
	// the insert succeeds (an insert into an empty bucket cannot fail).
	hour := ceilExpirationHour(expiration)
	bucket := hourMap[hour]

	if partial, usePartial := tryTruncate(serialNumber, expiration); usePartial {
		updated, err := insertPartial(bucket.partialSerials, partial)
		if err != nil {
			return err
		}
		bucket.partialSerials = updated
		table.memoryUsed += PartialSize
	} else {
		updated, err := insertSerial(bucket.fullSerials, serialNumber)
		if err != nil {
			return err
		}
		bucket.fullSerials = updated
		table.memoryUsed += FullSize
	}
	hourMap[hour] = bucket

	// Enforce the memory budget by evicting random entries.
	for table.memoryUsed > table.maxMemory {
		if err := table.deleteRandomSerial(); err != nil {
			return err
		}
	}
	return nil
}
138
// DeleteExpired removes every hour bucket whose (rounded-up) expiration hour, in
// Unix seconds, is strictly before now.Unix(). It subtracts the removed serials
// from the tracked memory usage.
//
// Satellites left with no buckets are removed from the table as well, so the
// outer map does not keep one entry per satellite ever seen.
func (table *Table) DeleteExpired(now time.Time) {
	table.mu.Lock()
	defer table.mu.Unlock()

	cutoff := now.Unix()
	// Deleting map entries while ranging over the same map is permitted in Go.
	for satelliteID, satMap := range table.serials {
		for expirationHour, list := range satMap {
			if expirationHour >= cutoff {
				continue
			}
			table.memoryUsed -= memory.Size(len(list.partialSerials)) * PartialSize
			table.memoryUsed -= memory.Size(len(list.fullSerials)) * FullSize
			delete(satMap, expirationHour)
		}

		// Add recreates the per-satellite map on demand, so dropping an empty one is safe.
		if len(satMap) == 0 {
			delete(table.serials, satelliteID)
		}
	}
}
160
// Exists reports whether serialNumber is recorded for satelliteID with the given
// expiration. The hour bucket and the storage form (partial or full) are both
// derived from expiration, exactly as Add derives them.
//
// Add keeps each list sorted, and deleteRandomSerial preserves that order.
// Lookups therefore use binary search rather than a linear scan. They use the
// same probe as insertPartial/insertSerial, so Exists returns true exactly when
// Add would report a duplicate.
func (table *Table) Exists(satelliteID storj.NodeID, serialNumber storj.SerialNumber, expiration time.Time) bool {
	table.mu.Lock()
	defer table.mu.Unlock()

	// Indexing a missing satellite or hour yields an empty serialsList, so the
	// searches below simply report false.
	list := table.serials[satelliteID][ceilExpirationHour(expiration)]

	if partial, usePartial := tryTruncate(serialNumber, expiration); usePartial {
		partials := list.partialSerials
		// A match, if present, sits right before the first element sorting after partial.
		i := sort.Search(len(partials), func(h int) bool {
			return partial.Less(partials[h])
		})
		return i > 0 && partials[i-1] == partial
	}

	fulls := list.fullSerials
	i := sort.Search(len(fulls), func(h int) bool {
		return serialNumber.Less(fulls[h])
	})
	return i > 0 && fulls[i-1] == serialNumber
}
185
186// Count iterates over all the items in the table and returns the number.
187func (table *Table) Count() int {
188	table.mu.Lock()
189	defer table.mu.Unlock()
190
191	count := 0
192	for _, satMap := range table.serials {
193		for _, serialsList := range satMap {
194			count += len(serialsList.fullSerials)
195			count += len(serialsList.partialSerials)
196		}
197	}
198
199	return count
200}
201
// deleteRandomSerial evicts one serial and subtracts its size from memoryUsed.
// It walks satellites and hour buckets in Go's (unspecified) map iteration order
// and picks the first bucket that still holds serials, checking partial serials
// before full ones. From that slice it removes an element chosen by rand.Intn,
// shifting the rest down so their relative order is preserved.
//
// The caller must already hold table.mu.
func (table *Table) deleteRandomSerial() error {
	mon.Meter("delete_random_serial").Mark(1) //mon:locked
	for _, satMap := range table.serials {
		for expirationHour, bucket := range satMap {
			switch {
			case len(bucket.partialSerials) > 0:
				last := len(bucket.partialSerials) - 1
				victim := rand.Intn(last + 1)
				copy(bucket.partialSerials[victim:], bucket.partialSerials[victim+1:])
				bucket.partialSerials = bucket.partialSerials[:last]
				table.memoryUsed -= PartialSize
			case len(bucket.fullSerials) > 0:
				last := len(bucket.fullSerials) - 1
				victim := rand.Intn(last + 1)
				copy(bucket.fullSerials[victim:], bucket.fullSerials[victim+1:])
				bucket.fullSerials = bucket.fullSerials[:last]
				table.memoryUsed -= FullSize
			default:
				// Empty bucket: nothing to evict here, keep looking.
				continue
			}
			// bucket is a copy of the map value, so write it back.
			satMap[expirationHour] = bucket
			return nil
		}
	}
	// Reached only when every hour bucket in the table is empty.
	return ErrSerials.New("could not delete a random item")
}
232
233// insertPartial inserts a partial serial in the correct position in a sorted list,
234// or returns an error if it is already in the list.
235func insertPartial(list []Partial, serial Partial) ([]Partial, error) {
236	i := sort.Search(len(list), func(h int) bool {
237		return serial.Less(list[h])
238	})
239	// if serial is already in the list, it will be at index i-1
240	if i > 0 && list[i-1] == serial {
241		return nil, ErrSerialAlreadyExists.New("")
242	}
243
244	// insert new serial at index i and shift everything up
245	// 1. grow the slice by one element.
246	list = append(list, Partial{})
247	// 2. move the upper part of the slice out of the way and open a hole.
248	copy(list[i+1:], list[i:])
249	// 3. store the new value.
250	list[i] = serial
251
252	return list, nil
253}
254
255// insertSerial inserts a serial in the correct position in a sorted list,
256// or returns an error if it is already in the list.
257func insertSerial(list []storj.SerialNumber, serial storj.SerialNumber) ([]storj.SerialNumber, error) {
258	i := sort.Search(len(list), func(h int) bool {
259		return serial.Less(list[h])
260	})
261	// if serial is already in the list, it will be at index i-1
262	if i > 0 && list[i-1] == serial {
263		return nil, ErrSerialAlreadyExists.New("")
264	}
265
266	// insert new serial at index i and shift everything up
267	// 1. grow the slice by one element.
268	list = append(list, storj.SerialNumber{})
269	// 2. move the upper part of the slice out of the way and open a hole.
270	copy(list[i+1:], list[i:])
271	// 3. store the new value.
272	list[i] = serial
273
274	return list, nil
275}
276
// tryTruncate reports whether serial can be stored in its shortened Partial form.
// That is the case when the first 8 bytes of serial, read as a big-endian uint64,
// equal expiration's Unix time. See satellite/orders/service.go, createSerial(),
// for how the expiration date is embedded in the serial number. On success, the
// bytes of serial from offset 8 onward are copied into the returned Partial;
// otherwise a zero Partial and false are returned.
func tryTruncate(serial storj.SerialNumber, expiration time.Time) (partial Partial, succeeded bool) {
	prefix := binary.BigEndian.Uint64(serial[0:8])
	if prefix != uint64(expiration.Unix()) {
		// The prefix is not the expiration, so the full serial must be kept.
		return Partial{}, false
	}

	copy(partial[:], serial[8:])
	return partial, true
}
290
// ceilExpirationHour rounds expiration up to the next whole hour and returns it
// in Unix seconds. A time that already sits exactly on an hour boundary is
// returned unchanged. Rounding uses time.Truncate, which works on absolute time,
// so the result does not depend on expiration's location.
func ceilExpirationHour(expiration time.Time) int64 {
	hour := expiration.Truncate(time.Hour)
	if hour.Before(expiration) {
		// Something below the hour was truncated away, so move up one hour.
		hour = hour.Add(time.Hour)
	}
	return hour.Unix()
}
295