1// Copyright (C) 2020 Storj Labs, Inc. 2// See LICENSE for copying information. 3 4package usedserials 5 6import ( 7 "encoding/binary" 8 "math/rand" 9 "sort" 10 "sync" 11 "time" 12 13 "github.com/spacemonkeygo/monkit/v3" 14 "github.com/zeebo/errs" 15 16 "storj.io/common/memory" 17 "storj.io/common/storj" 18) 19 20var ( 21 // ErrSerials defines the usedserials store error class. 22 ErrSerials = errs.Class("usedserials") 23 // ErrSerialAlreadyExists defines an error class for duplicate usedserials. 24 ErrSerialAlreadyExists = errs.Class("used serial already exists in store") 25 26 mon = monkit.Package() 27) 28 29const ( 30 // PartialSize is the size of a partial serial number. 31 PartialSize = memory.Size(len(Partial{})) 32 // FullSize is the size of a full serial number. 33 FullSize = memory.Size(len(storj.SerialNumber{})) 34) 35 36// Partial represents the last 8 bytes of a serial number. It is used when the first 8 are based on the expiration date. 37type Partial [8]byte 38 39// Less returns true if partial serial a is less than partial serial b and false otherwise. 40func (a Partial) Less(b Partial) bool { 41 return binary.BigEndian.Uint64(a[:]) < binary.BigEndian.Uint64(b[:]) 42} 43 44// Full is a copy of the SerialNumber type. It is necessary so we can define a Less function on it. 45type Full storj.SerialNumber 46 47// Less returns true if partial serial a is less than partial serial b and false otherwise. 48func (a Full) Less(b Full) bool { 49 return binary.BigEndian.Uint64(a[:]) < binary.BigEndian.Uint64(b[:]) 50} 51 52// serialsList is a structure that contains a list of partial serials and a list of full serials. 53// 54// For serials where expiration time is the first 8 bytes, it uses partialSerials. 55// It uses fullSerials otherwise. 56type serialsList struct { 57 partialSerials []Partial 58 fullSerials []storj.SerialNumber 59} 60 61// Table is an in-memory store for serial numbers. 62type Table struct { 63 mu sync.Mutex 64 65 // key 1: satellite ID, key 2: expiration hour (in unix time), value: a list of serial numbers 66 serials map[storj.NodeID]map[int64]serialsList 67 68 maxMemory memory.Size 69 memoryUsed memory.Size 70} 71 72// NewTable creates and returns a new usedserials in-memory store. 73func NewTable(maxMemory memory.Size) *Table { 74 if maxMemory <= 0 { 75 panic("max memory for usedserials store is 0") 76 } 77 return &Table{ 78 serials: make(map[storj.NodeID]map[int64]serialsList), 79 maxMemory: maxMemory, 80 } 81} 82 83// Add adds a serial to the store, or returns an error if the serial number was already added. 84// It randomly deletes items from the store if the set maxMemory is exceeded. 85func (table *Table) Add(satelliteID storj.NodeID, serialNumber storj.SerialNumber, expiration time.Time) error { 86 table.mu.Lock() 87 defer table.mu.Unlock() 88 89 satMap, ok := table.serials[satelliteID] 90 if !ok { 91 satMap = make(map[int64]serialsList) 92 table.serials[satelliteID] = satMap 93 } 94 95 expirationHour := ceilExpirationHour(expiration) 96 list, ok := satMap[expirationHour] 97 if !ok { 98 list = serialsList{} 99 satMap[expirationHour] = list 100 } 101 102 // determine whether we can use a partial serial number 103 partialSerial, usePartial := tryTruncate(serialNumber, expiration) 104 105 if usePartial { 106 partialList := list.partialSerials 107 partialList, err := insertPartial(partialList, partialSerial) 108 if err != nil { 109 return err 110 } 111 112 list.partialSerials = partialList 113 table.serials[satelliteID][expirationHour] = list 114 table.memoryUsed += PartialSize 115 } else { 116 fullList := list.fullSerials 117 fullList, err := insertSerial(fullList, serialNumber) 118 if err != nil { 119 return err 120 } 121 122 list.fullSerials = fullList 123 table.serials[satelliteID][expirationHour] = list 124 table.memoryUsed += FullSize 125 } 126 127 // Check to see if the structure exceeds the max allowed size. 128 // If so, delete random items until there is enough space. 129 for table.memoryUsed > table.maxMemory { 130 err := table.deleteRandomSerial() 131 if err != nil { 132 return err 133 } 134 } 135 136 return nil 137} 138 139// DeleteExpired deletes expired serial numbers if their expiration hour has passed. 140func (table *Table) DeleteExpired(now time.Time) { 141 table.mu.Lock() 142 defer table.mu.Unlock() 143 144 partialToDelete := 0 145 fullToDelete := 0 146 for _, satMap := range table.serials { 147 for expirationHour, list := range satMap { 148 if expirationHour < now.Unix() { 149 partialToDelete += len(list.partialSerials) 150 fullToDelete += len(list.fullSerials) 151 152 delete(satMap, expirationHour) 153 } 154 } 155 } 156 157 table.memoryUsed -= memory.Size(partialToDelete) * PartialSize 158 table.memoryUsed -= memory.Size(fullToDelete) * FullSize 159} 160 161// Exists determines whether a serial number exists in the table. 162func (table *Table) Exists(satelliteID storj.NodeID, serialNumber storj.SerialNumber, expiration time.Time) bool { 163 table.mu.Lock() 164 defer table.mu.Unlock() 165 166 expirationHour := ceilExpirationHour(expiration) 167 serialsList := table.serials[satelliteID][expirationHour] 168 169 partial, usePartial := tryTruncate(serialNumber, expiration) 170 if usePartial { 171 for _, serial := range serialsList.partialSerials { 172 if serial == partial { 173 return true 174 } 175 } 176 } else { 177 for _, serial := range serialsList.fullSerials { 178 if serial == serialNumber { 179 return true 180 } 181 } 182 } 183 return false 184} 185 186// Count iterates over all the items in the table and returns the number. 187func (table *Table) Count() int { 188 table.mu.Lock() 189 defer table.mu.Unlock() 190 191 count := 0 192 for _, satMap := range table.serials { 193 for _, serialsList := range satMap { 194 count += len(serialsList.fullSerials) 195 count += len(serialsList.partialSerials) 196 } 197 } 198 199 return count 200} 201 202// deleteRandomSerial deletes a random item. 203// It expects the mutex to be locked before being called. 204func (table *Table) deleteRandomSerial() error { 205 mon.Meter("delete_random_serial").Mark(1) //mon:locked 206 for _, satMap := range table.serials { 207 for expirationHour, serialList := range satMap { 208 if len(serialList.partialSerials) > 0 { 209 i := rand.Intn(len(serialList.partialSerials)) 210 // shift all elements after i once, to overwrite i 211 copy(serialList.partialSerials[i:], serialList.partialSerials[i+1:]) 212 // truncate to get rid of last item 213 serialList.partialSerials = serialList.partialSerials[:len(serialList.partialSerials)-1] 214 satMap[expirationHour] = serialList 215 table.memoryUsed -= PartialSize 216 return nil 217 } else if len(serialList.fullSerials) > 0 { 218 i := rand.Intn(len(serialList.fullSerials)) 219 // shift all elements after i once, to overwrite i 220 copy(serialList.fullSerials[i:], serialList.fullSerials[i+1:]) 221 // truncate to get rid of last item 222 serialList.fullSerials = serialList.fullSerials[:len(serialList.fullSerials)-1] 223 satMap[expirationHour] = serialList 224 table.memoryUsed -= FullSize 225 return nil 226 } 227 } 228 } 229 // we should never get to this path unless config.MaxTableSize is 0 230 return ErrSerials.New("could not delete a random item") 231} 232 233// insertPartial inserts a partial serial in the correct position in a sorted list, 234// or returns an error if it is already in the list. 235func insertPartial(list []Partial, serial Partial) ([]Partial, error) { 236 i := sort.Search(len(list), func(h int) bool { 237 return serial.Less(list[h]) 238 }) 239 // if serial is already in the list, it will be at index i-1 240 if i > 0 && list[i-1] == serial { 241 return nil, ErrSerialAlreadyExists.New("") 242 } 243 244 // insert new serial at index i and shift everything up 245 // 1. grow the slice by one element. 246 list = append(list, Partial{}) 247 // 2. move the upper part of the slice out of the way and open a hole. 248 copy(list[i+1:], list[i:]) 249 // 3. store the new value. 250 list[i] = serial 251 252 return list, nil 253} 254 255// insertSerial inserts a serial in the correct position in a sorted list, 256// or returns an error if it is already in the list. 257func insertSerial(list []storj.SerialNumber, serial storj.SerialNumber) ([]storj.SerialNumber, error) { 258 i := sort.Search(len(list), func(h int) bool { 259 return serial.Less(list[h]) 260 }) 261 // if serial is already in the list, it will be at index i-1 262 if i > 0 && list[i-1] == serial { 263 return nil, ErrSerialAlreadyExists.New("") 264 } 265 266 // insert new serial at index i and shift everything up 267 // 1. grow the slice by one element. 268 list = append(list, storj.SerialNumber{}) 269 // 2. move the upper part of the slice out of the way and open a hole. 270 copy(list[i+1:], list[i:]) 271 // 3. store the new value. 272 list[i] = serial 273 274 return list, nil 275} 276 277func tryTruncate(serial storj.SerialNumber, expiration time.Time) (partial Partial, succeeded bool) { 278 // If the first 8 bytes of the serial number are based on the expiration date 279 // then we can use a partial serial number with the last 8 bytes. 280 // Otherwise, we need to use the full serial number. 281 // see satellite/orders/service.go, createSerial() for how expiration date is used in the serial number. 282 if binary.BigEndian.Uint64(serial[0:8]) == uint64(expiration.Unix()) { 283 partialSerial := Partial{} 284 copy(partialSerial[:], serial[8:]) 285 return partialSerial, true 286 } 287 288 return Partial{}, false 289} 290 291func ceilExpirationHour(expiration time.Time) int64 { 292 // time.Truncate rounds down; adding (Hour-Nanosecond) ensures that we round down to the actual expiration hour 293 return expiration.Add(time.Hour - time.Nanosecond).Truncate(time.Hour).Unix() 294} 295