1package scheduler
2
3import (
4	"context"
5	"encoding/json"
6	"fmt"
7	"sync"
8	"time"
9
10	dcontext "github.com/docker/distribution/context"
11	"github.com/docker/distribution/reference"
12	"github.com/docker/distribution/registry/storage/driver"
13)
14
// expiryFunc is the callback type invoked when a scheduled entry's TTL
// expires. It receives the reference parsed from the entry's key. A non-nil
// return value is logged by the scheduler, and the entry is removed from the
// index either way.
type expiryFunc func(reference.Reference) error
17
// Entry kinds recorded in schedulerEntry.EntryType; they select which expiry
// callback handles an entry when its timer fires.
const (
	entryTypeBlob     = 0
	entryTypeManifest = 1
)

// indexSaveFrequency is how often the background saver checks whether the
// in-memory index has unsaved changes and, if so, persists it.
const indexSaveFrequency = 5 * time.Second
23
// schedulerEntry represents a single scheduled expiry.
//
// Key is the string form of the scheduled reference (it is parsed back with
// reference.Parse when the entry expires), Expiry is the absolute time at
// which the entry is due, and EntryType is entryTypeBlob or
// entryTypeManifest.
//
// The exported fields are what writeState marshals to, and readState
// unmarshals from, the state file. Their JSON tag names (note "ExpiryData"
// for Expiry) define that on-disk format, so renaming a tag would stop
// previously saved state from decoding into the field.
type schedulerEntry struct {
	Key       string    `json:"Key"`
	Expiry    time.Time `json:"ExpiryData"`
	EntryType int       `json:"EntryType"`

	// timer is the pending expiry timer created by startTimer. It is
	// unexported, so it is never persisted; entries decoded by readState
	// have no timer until Start arms one.
	timer *time.Timer
}
33
34// New returns a new instance of the scheduler
35func New(ctx context.Context, driver driver.StorageDriver, path string) *TTLExpirationScheduler {
36	return &TTLExpirationScheduler{
37		entries:         make(map[string]*schedulerEntry),
38		driver:          driver,
39		pathToStateFile: path,
40		ctx:             ctx,
41		stopped:         true,
42		doneChan:        make(chan struct{}),
43		saveTimer:       time.NewTicker(indexSaveFrequency),
44	}
45}
46
// TTLExpirationScheduler is a scheduler used to perform actions
// when TTLs expire. Its index of scheduled entries is saved through a
// storage driver so that a later Start can restore it.
type TTLExpirationScheduler struct {
	// The embedded mutex is held by every method and by the timer and
	// background-save callbacks while they read or modify the fields below.
	sync.Mutex

	// entries maps a reference's string form to its scheduled expiry.
	entries map[string]*schedulerEntry

	// driver and pathToStateFile locate the persisted index; ctx is passed
	// to driver calls and used for logging.
	driver          driver.StorageDriver
	ctx             context.Context
	pathToStateFile string

	// stopped is true from New until Start succeeds and again after Stop;
	// AddBlob and AddManifest are rejected while it is set.
	stopped bool

	// Callbacks chosen by entry type when an entry's timer fires.
	onBlobExpire     expiryFunc
	onManifestExpire expiryFunc

	// indexDirty marks unsaved changes to entries. The goroutine launched
	// by Start writes the index on each saveTimer tick while it is set, and
	// exits when doneChan is closed.
	indexDirty bool
	saveTimer  *time.Ticker
	doneChan   chan struct{}
}
67
68// OnBlobExpire is called when a scheduled blob's TTL expires
69func (ttles *TTLExpirationScheduler) OnBlobExpire(f expiryFunc) {
70	ttles.Lock()
71	defer ttles.Unlock()
72
73	ttles.onBlobExpire = f
74}
75
76// OnManifestExpire is called when a scheduled manifest's TTL expires
77func (ttles *TTLExpirationScheduler) OnManifestExpire(f expiryFunc) {
78	ttles.Lock()
79	defer ttles.Unlock()
80
81	ttles.onManifestExpire = f
82}
83
84// AddBlob schedules a blob cleanup after ttl expires
85func (ttles *TTLExpirationScheduler) AddBlob(blobRef reference.Canonical, ttl time.Duration) error {
86	ttles.Lock()
87	defer ttles.Unlock()
88
89	if ttles.stopped {
90		return fmt.Errorf("scheduler not started")
91	}
92
93	ttles.add(blobRef, ttl, entryTypeBlob)
94	return nil
95}
96
97// AddManifest schedules a manifest cleanup after ttl expires
98func (ttles *TTLExpirationScheduler) AddManifest(manifestRef reference.Canonical, ttl time.Duration) error {
99	ttles.Lock()
100	defer ttles.Unlock()
101
102	if ttles.stopped {
103		return fmt.Errorf("scheduler not started")
104	}
105
106	ttles.add(manifestRef, ttl, entryTypeManifest)
107	return nil
108}
109
110// Start starts the scheduler
111func (ttles *TTLExpirationScheduler) Start() error {
112	ttles.Lock()
113	defer ttles.Unlock()
114
115	err := ttles.readState()
116	if err != nil {
117		return err
118	}
119
120	if !ttles.stopped {
121		return fmt.Errorf("Scheduler already started")
122	}
123
124	dcontext.GetLogger(ttles.ctx).Infof("Starting cached object TTL expiration scheduler...")
125	ttles.stopped = false
126
127	// Start timer for each deserialized entry
128	for _, entry := range ttles.entries {
129		entry.timer = ttles.startTimer(entry, entry.Expiry.Sub(time.Now()))
130	}
131
132	// Start a ticker to periodically save the entries index
133
134	go func() {
135		for {
136			select {
137			case <-ttles.saveTimer.C:
138				ttles.Lock()
139				if !ttles.indexDirty {
140					ttles.Unlock()
141					continue
142				}
143
144				err := ttles.writeState()
145				if err != nil {
146					dcontext.GetLogger(ttles.ctx).Errorf("Error writing scheduler state: %s", err)
147				} else {
148					ttles.indexDirty = false
149				}
150				ttles.Unlock()
151
152			case <-ttles.doneChan:
153				return
154			}
155		}
156	}()
157
158	return nil
159}
160
161func (ttles *TTLExpirationScheduler) add(r reference.Reference, ttl time.Duration, eType int) {
162	entry := &schedulerEntry{
163		Key:       r.String(),
164		Expiry:    time.Now().Add(ttl),
165		EntryType: eType,
166	}
167	dcontext.GetLogger(ttles.ctx).Infof("Adding new scheduler entry for %s with ttl=%s", entry.Key, entry.Expiry.Sub(time.Now()))
168	if oldEntry, present := ttles.entries[entry.Key]; present && oldEntry.timer != nil {
169		oldEntry.timer.Stop()
170	}
171	ttles.entries[entry.Key] = entry
172	entry.timer = ttles.startTimer(entry, ttl)
173	ttles.indexDirty = true
174}
175
176func (ttles *TTLExpirationScheduler) startTimer(entry *schedulerEntry, ttl time.Duration) *time.Timer {
177	return time.AfterFunc(ttl, func() {
178		ttles.Lock()
179		defer ttles.Unlock()
180
181		var f expiryFunc
182
183		switch entry.EntryType {
184		case entryTypeBlob:
185			f = ttles.onBlobExpire
186		case entryTypeManifest:
187			f = ttles.onManifestExpire
188		default:
189			f = func(reference.Reference) error {
190				return fmt.Errorf("scheduler entry type")
191			}
192		}
193
194		ref, err := reference.Parse(entry.Key)
195		if err == nil {
196			if err := f(ref); err != nil {
197				dcontext.GetLogger(ttles.ctx).Errorf("Scheduler error returned from OnExpire(%s): %s", entry.Key, err)
198			}
199		} else {
200			dcontext.GetLogger(ttles.ctx).Errorf("Error unpacking reference: %s", err)
201		}
202
203		delete(ttles.entries, entry.Key)
204		ttles.indexDirty = true
205	})
206}
207
208// Stop stops the scheduler.
209func (ttles *TTLExpirationScheduler) Stop() {
210	ttles.Lock()
211	defer ttles.Unlock()
212
213	if err := ttles.writeState(); err != nil {
214		dcontext.GetLogger(ttles.ctx).Errorf("Error writing scheduler state: %s", err)
215	}
216
217	for _, entry := range ttles.entries {
218		entry.timer.Stop()
219	}
220
221	close(ttles.doneChan)
222	ttles.saveTimer.Stop()
223	ttles.stopped = true
224}
225
226func (ttles *TTLExpirationScheduler) writeState() error {
227	jsonBytes, err := json.Marshal(ttles.entries)
228	if err != nil {
229		return err
230	}
231
232	err = ttles.driver.PutContent(ttles.ctx, ttles.pathToStateFile, jsonBytes)
233	if err != nil {
234		return err
235	}
236
237	return nil
238}
239
240func (ttles *TTLExpirationScheduler) readState() error {
241	if _, err := ttles.driver.Stat(ttles.ctx, ttles.pathToStateFile); err != nil {
242		switch err := err.(type) {
243		case driver.PathNotFoundError:
244			return nil
245		default:
246			return err
247		}
248	}
249
250	bytes, err := ttles.driver.GetContent(ttles.ctx, ttles.pathToStateFile)
251	if err != nil {
252		return err
253	}
254
255	err = json.Unmarshal(bytes, &ttles.entries)
256	if err != nil {
257		return err
258	}
259	return nil
260}
261