1package clockwork
2
3import (
4	"sync"
5	"time"
6)
7
// Clock abstracts the parts of the standard time package that deal with the
// current time and with waiting. Code that depends on Clock instead of
// calling the time package directly can be exercised in tests with a fake
// implementation.
type Clock interface {
	// Now reports the clock's current time.
	Now() time.Time

	// After returns a channel that receives a time value once d has elapsed
	// on this clock.
	After(d time.Duration) <-chan time.Time

	// Sleep blocks the calling goroutine until d has elapsed on this clock.
	Sleep(d time.Duration)
}
15
16// FakeClock provides an interface for a clock which can be
17// manually advanced through time
18type FakeClock interface {
19	Clock
20	// Advance advances the FakeClock to a new point in time, ensuring any existing
21	// sleepers are notified appropriately before returning
22	Advance(d time.Duration)
23	// BlockUntil will block until the FakeClock has the given number of
24	// sleepers (callers of Sleep or After)
25	BlockUntil(n int)
26}
27
28// NewRealClock returns a Clock which simply delegates calls to the actual time
29// package; it should be used by packages in production.
30func NewRealClock() Clock {
31	return &realClock{}
32}
33
34// NewFakeClock returns a FakeClock implementation which can be
35// manually advanced through time for testing. The initial time of the
36// FakeClock will be an arbitrary non-zero time.
37func NewFakeClock() FakeClock {
38	// use a fixture that does not fulfill Time.IsZero()
39	return NewFakeClockAt(time.Date(1984, time.April, 4, 0, 0, 0, 0, time.UTC))
40}
41
42// NewFakeClockAt returns a FakeClock initialised at the given time.Time.
43func NewFakeClockAt(t time.Time) FakeClock {
44	return &fakeClock{
45		time: t,
46	}
47}
48
// realClock is the production Clock. Each method forwards directly to the
// function of the same name in the standard time package.
type realClock struct{}

// Now returns time.Now().
func (*realClock) Now() time.Time { return time.Now() }

// Sleep pauses the calling goroutine with time.Sleep.
func (*realClock) Sleep(d time.Duration) { time.Sleep(d) }

// After returns the channel produced by time.After.
func (*realClock) After(d time.Duration) <-chan time.Time { return time.After(d) }
62
type (
	// fakeClock is the FakeClock implementation returned by NewFakeClock and
	// NewFakeClockAt.
	fakeClock struct {
		sleepers []*sleeper // pending After/Sleep calls, in registration order
		blockers []*blocker // pending BlockUntil calls
		time     time.Time  // the clock's current (fake) time

		// l guards the fields above. Every method acquires it before reading
		// or writing them.
		l sync.RWMutex
	}

	// sleeper records one outstanding After (or Sleep) call.
	sleeper struct {
		until time.Time      // fake time at which the sleeper becomes due
		done  chan time.Time // receives the clock's time once due
	}

	// blocker records one outstanding BlockUntil call.
	blocker struct {
		count int           // number of sleepers the caller is waiting for
		ch    chan struct{} // closed once the sleeper count equals count
	}
)
82
83// After mimics time.After; it waits for the given duration to elapse on the
84// fakeClock, then sends the current time on the returned channel.
85func (fc *fakeClock) After(d time.Duration) <-chan time.Time {
86	fc.l.Lock()
87	defer fc.l.Unlock()
88	now := fc.time
89	done := make(chan time.Time, 1)
90	if d.Nanoseconds() == 0 {
91		// special case - trigger immediately
92		done <- now
93	} else {
94		// otherwise, add to the set of sleepers
95		s := &sleeper{
96			until: now.Add(d),
97			done:  done,
98		}
99		fc.sleepers = append(fc.sleepers, s)
100		// and notify any blockers
101		fc.blockers = notifyBlockers(fc.blockers, len(fc.sleepers))
102	}
103	return done
104}
105
106// notifyBlockers notifies all the blockers waiting until the
107// given number of sleepers are waiting on the fakeClock. It
108// returns an updated slice of blockers (i.e. those still waiting)
109func notifyBlockers(blockers []*blocker, count int) (newBlockers []*blocker) {
110	for _, b := range blockers {
111		if b.count == count {
112			close(b.ch)
113		} else {
114			newBlockers = append(newBlockers, b)
115		}
116	}
117	return
118}
119
120// Sleep blocks until the given duration has passed on the fakeClock
121func (fc *fakeClock) Sleep(d time.Duration) {
122	<-fc.After(d)
123}
124
125// Time returns the current time of the fakeClock
126func (fc *fakeClock) Now() time.Time {
127	fc.l.RLock()
128	t := fc.time
129	fc.l.RUnlock()
130	return t
131}
132
133// Advance advances fakeClock to a new point in time, ensuring channels from any
134// previous invocations of After are notified appropriately before returning
135func (fc *fakeClock) Advance(d time.Duration) {
136	fc.l.Lock()
137	defer fc.l.Unlock()
138	end := fc.time.Add(d)
139	var newSleepers []*sleeper
140	for _, s := range fc.sleepers {
141		if end.Sub(s.until) >= 0 {
142			s.done <- end
143		} else {
144			newSleepers = append(newSleepers, s)
145		}
146	}
147	fc.sleepers = newSleepers
148	fc.blockers = notifyBlockers(fc.blockers, len(fc.sleepers))
149	fc.time = end
150}
151
152// BlockUntil will block until the fakeClock has the given number of sleepers
153// (callers of Sleep or After)
154func (fc *fakeClock) BlockUntil(n int) {
155	fc.l.Lock()
156	// Fast path: current number of sleepers is what we're looking for
157	if len(fc.sleepers) == n {
158		fc.l.Unlock()
159		return
160	}
161	// Otherwise, set up a new blocker
162	b := &blocker{
163		count: n,
164		ch:    make(chan struct{}),
165	}
166	fc.blockers = append(fc.blockers, b)
167	fc.l.Unlock()
168	<-b.ch
169}
170