1package clockwork
2
3import (
4	"sync"
5	"time"
6)
7
// Clock is the abstraction over the time package that code should depend on
// instead of calling time functions directly. Production code uses the
// implementation returned by NewRealClock; tests can substitute a FakeClock so
// that time-dependent behavior is deterministic.
type Clock interface {
	// Now reports the current time according to this clock.
	Now() time.Time
	// Since reports how much time has elapsed on this clock since t.
	Since(t time.Time) time.Duration
	// After returns a channel that receives the clock's time once d has
	// elapsed, mirroring time.After.
	After(d time.Duration) <-chan time.Time
	// Sleep blocks the caller until d has elapsed on this clock, mirroring
	// time.Sleep.
	Sleep(d time.Duration)

	// AfterTime returns a channel that receives the clock's time once the
	// absolute instant t has been surpassed. It is not part of the time
	// package API; it exists so that callers of the fake clock can wait for a
	// deadline without racing a concurrent Advance between computing a
	// duration and calling After.
	AfterTime(t time.Time) <-chan time.Time
}
21
22// FakeClock provides an interface for a clock which can be
23// manually advanced through time
24type FakeClock interface {
25	Clock
26	// Advance advances the FakeClock to a new point in time, ensuring any existing
27	// sleepers are notified appropriately before returning
28	Advance(d time.Duration)
29	// BlockUntil will block until the FakeClock has the given number of
30	// sleepers (callers of Sleep or After)
31	BlockUntil(n int)
32}
33
34// NewRealClock returns a Clock which simply delegates calls to the actual time
35// package; it should be used by packages in production.
36func NewRealClock() Clock {
37	return &realClock{}
38}
39
40// NewFakeClock returns a FakeClock implementation which can be
41// manually advanced through time for testing. The initial time of the
42// FakeClock will be an arbitrary non-zero time.
43func NewFakeClock() FakeClock {
44	// use a fixture that does not fulfill Time.IsZero()
45	return NewFakeClockAt(time.Date(1984, time.April, 4, 0, 0, 0, 0, time.UTC))
46}
47
48// NewFakeClockAt returns a FakeClock initialised at the given time.Time.
49func NewFakeClockAt(t time.Time) FakeClock {
50	return &fakeClock{
51		time: t,
52	}
53}
54
// realClock is the production Clock: every method forwards to the standard
// time package.
type realClock struct{}

// Now returns time.Now().
func (rc *realClock) Now() time.Time {
	return time.Now()
}

// Since returns the time elapsed between t and rc.Now().
func (rc *realClock) Since(t time.Time) time.Duration {
	now := rc.Now()
	return now.Sub(t)
}

// Sleep pauses the calling goroutine for d via time.Sleep.
func (rc *realClock) Sleep(d time.Duration) {
	time.Sleep(d)
}

// After returns the channel produced by time.After(d).
func (rc *realClock) After(d time.Duration) <-chan time.Time {
	return time.After(d)
}

// AfterTime converts the absolute deadline t into a duration relative to
// rc.Now() and waits that long via rc.After. A deadline at or before the
// current time yields a zero duration, so the channel fires essentially
// immediately.
func (rc *realClock) AfterTime(t time.Time) <-chan time.Time {
	wait := time.Duration(0)
	if now := rc.Now(); now.Before(t) {
		wait = t.Sub(now)
	}
	return rc.After(wait)
}
84
type (
	// fakeClock is the FakeClock implementation. Its methods access the
	// fields below only while holding l.
	fakeClock struct {
		sleepers []*sleeper
		blockers []*blocker
		time     time.Time

		l sync.RWMutex
	}

	// sleeper is one pending After, AfterTime or Sleep call. It is released
	// by a send on done once the fake time reaches until.
	sleeper struct {
		until time.Time
		done  chan time.Time
	}

	// blocker is one pending BlockUntil call. ch is closed once the number of
	// registered sleepers equals count.
	blocker struct {
		count int
		ch    chan struct{}
	}
)
104
105// After mimics time.After; it waits for the given duration to elapse on the
106// fakeClock, then sends the current time on the returned channel.
107func (fc *fakeClock) After(d time.Duration) <-chan time.Time {
108	fc.l.Lock()
109	defer fc.l.Unlock()
110	now := fc.time
111	done := make(chan time.Time, 1)
112	if d.Nanoseconds() == 0 {
113		// special case - trigger immediately
114		done <- now
115	} else {
116		// otherwise, add to the set of sleepers
117		s := &sleeper{
118			until: now.Add(d),
119			done:  done,
120		}
121		fc.sleepers = append(fc.sleepers, s)
122		// and notify any blockers
123		fc.blockers = notifyBlockers(fc.blockers, len(fc.sleepers))
124	}
125	return done
126}
127
128// AfterTime, when called with time `t`, is similar to calling `After(t.Sub(Now()))`.
129// Crucially, it is race-free, and therefore can be called from one go-routine
130// while another is calling `Advance`.
131func (fc *fakeClock) AfterTime(t time.Time) <-chan time.Time {
132	fc.l.Lock()
133	defer fc.l.Unlock()
134	done := make(chan time.Time, 1)
135	if now := fc.time; !t.After(now) {
136		done <- now
137	} else {
138		s := &sleeper{
139			until: t,
140			done:  done,
141		}
142		fc.sleepers = append(fc.sleepers, s)
143		// and notify any blockers
144		fc.blockers = notifyBlockers(fc.blockers, len(fc.sleepers))
145	}
146	return done
147}
148
149// notifyBlockers notifies all the blockers waiting until the
150// given number of sleepers are waiting on the fakeClock. It
151// returns an updated slice of blockers (i.e. those still waiting)
152func notifyBlockers(blockers []*blocker, count int) (newBlockers []*blocker) {
153	for _, b := range blockers {
154		if b.count == count {
155			close(b.ch)
156		} else {
157			newBlockers = append(newBlockers, b)
158		}
159	}
160	return
161}
162
163// Sleep blocks until the given duration has passed on the fakeClock
164func (fc *fakeClock) Sleep(d time.Duration) {
165	<-fc.After(d)
166}
167
168// Time returns the current time of the fakeClock
169func (fc *fakeClock) Now() time.Time {
170	fc.l.RLock()
171	t := fc.time
172	fc.l.RUnlock()
173	return t
174}
175
176// Since returns the duration that has passed since the given time on the fakeClock
177func (fc *fakeClock) Since(t time.Time) time.Duration {
178	return fc.Now().Sub(t)
179}
180
181// Advance advances fakeClock to a new point in time, ensuring channels from any
182// previous invocations of After are notified appropriately before returning
183func (fc *fakeClock) Advance(d time.Duration) {
184	fc.l.Lock()
185	defer fc.l.Unlock()
186	end := fc.time.Add(d)
187	var newSleepers []*sleeper
188	for _, s := range fc.sleepers {
189		if end.Sub(s.until) >= 0 {
190			s.done <- end
191		} else {
192			newSleepers = append(newSleepers, s)
193		}
194	}
195	fc.sleepers = newSleepers
196	fc.blockers = notifyBlockers(fc.blockers, len(fc.sleepers))
197	fc.time = end
198}
199
200// BlockUntil will block until the fakeClock has the given number of sleepers
201// (callers of Sleep or After)
202func (fc *fakeClock) BlockUntil(n int) {
203	fc.l.Lock()
204	// Fast path: current number of sleepers is what we're looking for
205	if len(fc.sleepers) == n {
206		fc.l.Unlock()
207		return
208	}
209	// Otherwise, set up a new blocker
210	b := &blocker{
211		count: n,
212		ch:    make(chan struct{}),
213	}
214	fc.blockers = append(fc.blockers, b)
215	fc.l.Unlock()
216	<-b.ch
217}
218