1package scheduler 2 3import ( 4 "context" 5 "encoding/json" 6 "fmt" 7 "sync" 8 "time" 9 10 dcontext "github.com/docker/distribution/context" 11 "github.com/docker/distribution/reference" 12 "github.com/docker/distribution/registry/storage/driver" 13) 14 15// onTTLExpiryFunc is called when a repository's TTL expires 16type expiryFunc func(reference.Reference) error 17 18const ( 19 entryTypeBlob = iota 20 entryTypeManifest 21 indexSaveFrequency = 5 * time.Second 22) 23 24// schedulerEntry represents an entry in the scheduler 25// fields are exported for serialization 26type schedulerEntry struct { 27 Key string `json:"Key"` 28 Expiry time.Time `json:"ExpiryData"` 29 EntryType int `json:"EntryType"` 30 31 timer *time.Timer 32} 33 34// New returns a new instance of the scheduler 35func New(ctx context.Context, driver driver.StorageDriver, path string) *TTLExpirationScheduler { 36 return &TTLExpirationScheduler{ 37 entries: make(map[string]*schedulerEntry), 38 driver: driver, 39 pathToStateFile: path, 40 ctx: ctx, 41 stopped: true, 42 doneChan: make(chan struct{}), 43 saveTimer: time.NewTicker(indexSaveFrequency), 44 } 45} 46 47// TTLExpirationScheduler is a scheduler used to perform actions 48// when TTLs expire 49type TTLExpirationScheduler struct { 50 sync.Mutex 51 52 entries map[string]*schedulerEntry 53 54 driver driver.StorageDriver 55 ctx context.Context 56 pathToStateFile string 57 58 stopped bool 59 60 onBlobExpire expiryFunc 61 onManifestExpire expiryFunc 62 63 indexDirty bool 64 saveTimer *time.Ticker 65 doneChan chan struct{} 66} 67 68// OnBlobExpire is called when a scheduled blob's TTL expires 69func (ttles *TTLExpirationScheduler) OnBlobExpire(f expiryFunc) { 70 ttles.Lock() 71 defer ttles.Unlock() 72 73 ttles.onBlobExpire = f 74} 75 76// OnManifestExpire is called when a scheduled manifest's TTL expires 77func (ttles *TTLExpirationScheduler) OnManifestExpire(f expiryFunc) { 78 ttles.Lock() 79 defer ttles.Unlock() 80 81 ttles.onManifestExpire = f 82} 83 84// AddBlob schedules a blob cleanup after ttl expires 85func (ttles *TTLExpirationScheduler) AddBlob(blobRef reference.Canonical, ttl time.Duration) error { 86 ttles.Lock() 87 defer ttles.Unlock() 88 89 if ttles.stopped { 90 return fmt.Errorf("scheduler not started") 91 } 92 93 ttles.add(blobRef, ttl, entryTypeBlob) 94 return nil 95} 96 97// AddManifest schedules a manifest cleanup after ttl expires 98func (ttles *TTLExpirationScheduler) AddManifest(manifestRef reference.Canonical, ttl time.Duration) error { 99 ttles.Lock() 100 defer ttles.Unlock() 101 102 if ttles.stopped { 103 return fmt.Errorf("scheduler not started") 104 } 105 106 ttles.add(manifestRef, ttl, entryTypeManifest) 107 return nil 108} 109 110// Start starts the scheduler 111func (ttles *TTLExpirationScheduler) Start() error { 112 ttles.Lock() 113 defer ttles.Unlock() 114 115 err := ttles.readState() 116 if err != nil { 117 return err 118 } 119 120 if !ttles.stopped { 121 return fmt.Errorf("Scheduler already started") 122 } 123 124 dcontext.GetLogger(ttles.ctx).Infof("Starting cached object TTL expiration scheduler...") 125 ttles.stopped = false 126 127 // Start timer for each deserialized entry 128 for _, entry := range ttles.entries { 129 entry.timer = ttles.startTimer(entry, entry.Expiry.Sub(time.Now())) 130 } 131 132 // Start a ticker to periodically save the entries index 133 134 go func() { 135 for { 136 select { 137 case <-ttles.saveTimer.C: 138 ttles.Lock() 139 if !ttles.indexDirty { 140 ttles.Unlock() 141 continue 142 } 143 144 err := ttles.writeState() 145 if err != nil { 146 dcontext.GetLogger(ttles.ctx).Errorf("Error writing scheduler state: %s", err) 147 } else { 148 ttles.indexDirty = false 149 } 150 ttles.Unlock() 151 152 case <-ttles.doneChan: 153 return 154 } 155 } 156 }() 157 158 return nil 159} 160 161func (ttles *TTLExpirationScheduler) add(r reference.Reference, ttl time.Duration, eType int) { 162 entry := &schedulerEntry{ 163 Key: r.String(), 164 Expiry: time.Now().Add(ttl), 165 EntryType: eType, 166 } 167 dcontext.GetLogger(ttles.ctx).Infof("Adding new scheduler entry for %s with ttl=%s", entry.Key, entry.Expiry.Sub(time.Now())) 168 if oldEntry, present := ttles.entries[entry.Key]; present && oldEntry.timer != nil { 169 oldEntry.timer.Stop() 170 } 171 ttles.entries[entry.Key] = entry 172 entry.timer = ttles.startTimer(entry, ttl) 173 ttles.indexDirty = true 174} 175 176func (ttles *TTLExpirationScheduler) startTimer(entry *schedulerEntry, ttl time.Duration) *time.Timer { 177 return time.AfterFunc(ttl, func() { 178 ttles.Lock() 179 defer ttles.Unlock() 180 181 var f expiryFunc 182 183 switch entry.EntryType { 184 case entryTypeBlob: 185 f = ttles.onBlobExpire 186 case entryTypeManifest: 187 f = ttles.onManifestExpire 188 default: 189 f = func(reference.Reference) error { 190 return fmt.Errorf("scheduler entry type") 191 } 192 } 193 194 ref, err := reference.Parse(entry.Key) 195 if err == nil { 196 if err := f(ref); err != nil { 197 dcontext.GetLogger(ttles.ctx).Errorf("Scheduler error returned from OnExpire(%s): %s", entry.Key, err) 198 } 199 } else { 200 dcontext.GetLogger(ttles.ctx).Errorf("Error unpacking reference: %s", err) 201 } 202 203 delete(ttles.entries, entry.Key) 204 ttles.indexDirty = true 205 }) 206} 207 208// Stop stops the scheduler. 209func (ttles *TTLExpirationScheduler) Stop() { 210 ttles.Lock() 211 defer ttles.Unlock() 212 213 if err := ttles.writeState(); err != nil { 214 dcontext.GetLogger(ttles.ctx).Errorf("Error writing scheduler state: %s", err) 215 } 216 217 for _, entry := range ttles.entries { 218 entry.timer.Stop() 219 } 220 221 close(ttles.doneChan) 222 ttles.saveTimer.Stop() 223 ttles.stopped = true 224} 225 226func (ttles *TTLExpirationScheduler) writeState() error { 227 jsonBytes, err := json.Marshal(ttles.entries) 228 if err != nil { 229 return err 230 } 231 232 err = ttles.driver.PutContent(ttles.ctx, ttles.pathToStateFile, jsonBytes) 233 if err != nil { 234 return err 235 } 236 237 return nil 238} 239 240func (ttles *TTLExpirationScheduler) readState() error { 241 if _, err := ttles.driver.Stat(ttles.ctx, ttles.pathToStateFile); err != nil { 242 switch err := err.(type) { 243 case driver.PathNotFoundError: 244 return nil 245 default: 246 return err 247 } 248 } 249 250 bytes, err := ttles.driver.GetContent(ttles.ctx, ttles.pathToStateFile) 251 if err != nil { 252 return err 253 } 254 255 err = json.Unmarshal(bytes, &ttles.entries) 256 if err != nil { 257 return err 258 } 259 return nil 260} 261