1package hostpool
2
3import (
4	"log"
5	"math/rand"
6	"time"
7)
8
9type epsilonHostPoolResponse struct {
10	standardHostPoolResponse
11	started time.Time
12	ended   time.Time
13}
14
15func (r *epsilonHostPoolResponse) Mark(err error) {
16	r.Do(func() {
17		r.ended = time.Now()
18		doMark(err, r)
19	})
20}
21
// epsilonGreedyHostPool is a HostPool that prefers faster hosts.
//
// Fields:
//   - standardHostPool: embedded base pool; supplies the host bookkeeping
//     (hosts, hostList) and the lock taken via p.Lock/p.Unlock.
//   - epsilon: probability of an exploration (round-robin) pick in
//     getEpsilonGreedy; it shrinks by epsilonDecay after every exploration,
//     never below minEpsilon.
//   - decayDuration: window over which response-time samples are kept; the
//     decay goroutine rotates one bucket every decayDuration/epsilonBuckets.
//   - EpsilonValueCalculator: turns a host's weighted-average response time
//     into a selection score.
//   - timer: measures request durations; an interface so tests can swap it.
//   - quit: closed by Close to stop the epsilonGreedyDecay goroutine.
type epsilonGreedyHostPool struct {
	standardHostPool               // TODO - would be nifty if we could embed HostPool and Locker interfaces
	epsilon                float32 // this is our exploration factor
	decayDuration          time.Duration
	EpsilonValueCalculator // embed the epsilonValueCalculator
	timer
	quit chan bool
}
30
31// Construct an Epsilon Greedy HostPool
32//
33// Epsilon Greedy is an algorithm that allows HostPool not only to track failure state,
34// but also to learn about "better" options in terms of speed, and to pick from available hosts
35// based on how well they perform. This gives a weighted request rate to better
36// performing hosts, while still distributing requests to all hosts (proportionate to their performance).
37// The interface is the same as the standard HostPool, but be sure to mark the HostResponse immediately
38// after executing the request to the host, as that will stop the implicitly running request timer.
39//
40// A good overview of Epsilon Greedy is here http://stevehanov.ca/blog/index.php?id=132
41//
42// To compute the weighting scores, we perform a weighted average of recent response times, over the course of
43// `decayDuration`. decayDuration may be set to 0 to use the default value of 5 minutes
44// We then use the supplied EpsilonValueCalculator to calculate a score from that weighted average response time.
45func NewEpsilonGreedy(hosts []string, decayDuration time.Duration, calc EpsilonValueCalculator) HostPool {
46
47	if decayDuration <= 0 {
48		decayDuration = defaultDecayDuration
49	}
50	stdHP := New(hosts).(*standardHostPool)
51	p := &epsilonGreedyHostPool{
52		standardHostPool:       *stdHP,
53		epsilon:                float32(initialEpsilon),
54		decayDuration:          decayDuration,
55		EpsilonValueCalculator: calc,
56		timer: &realTimer{},
57		quit:  make(chan bool),
58	}
59
60	// allocate structures
61	for _, h := range p.hostList {
62		h.epsilonCounts = make([]int64, epsilonBuckets)
63		h.epsilonValues = make([]int64, epsilonBuckets)
64	}
65	go p.epsilonGreedyDecay()
66	return p
67}
68
69func (p *epsilonGreedyHostPool) Close() {
70	// No need to do p.quit <- true as close(p.quit) does the trick.
71	close(p.quit)
72}
73
74func (p *epsilonGreedyHostPool) SetEpsilon(newEpsilon float32) {
75	p.Lock()
76	defer p.Unlock()
77	p.epsilon = newEpsilon
78}
79
80func (p *epsilonGreedyHostPool) SetHosts(hosts []string) {
81	p.Lock()
82	defer p.Unlock()
83	p.standardHostPool.setHosts(hosts)
84	for _, h := range p.hostList {
85		h.epsilonCounts = make([]int64, epsilonBuckets)
86		h.epsilonValues = make([]int64, epsilonBuckets)
87	}
88}
89
90func (p *epsilonGreedyHostPool) epsilonGreedyDecay() {
91	durationPerBucket := p.decayDuration / epsilonBuckets
92	ticker := time.NewTicker(durationPerBucket)
93	for {
94		select {
95		case <-p.quit:
96			ticker.Stop()
97			return
98		case <-ticker.C:
99			p.performEpsilonGreedyDecay()
100		}
101	}
102}
103func (p *epsilonGreedyHostPool) performEpsilonGreedyDecay() {
104	p.Lock()
105	for _, h := range p.hostList {
106		h.epsilonIndex += 1
107		h.epsilonIndex = h.epsilonIndex % epsilonBuckets
108		h.epsilonCounts[h.epsilonIndex] = 0
109		h.epsilonValues[h.epsilonIndex] = 0
110	}
111	p.Unlock()
112}
113
114func (p *epsilonGreedyHostPool) Get() HostPoolResponse {
115	p.Lock()
116	defer p.Unlock()
117	host := p.getEpsilonGreedy()
118	if host == "" {
119		return nil
120	}
121
122	started := time.Now()
123	return &epsilonHostPoolResponse{
124		standardHostPoolResponse: standardHostPoolResponse{host: host, pool: p},
125		started:                  started,
126	}
127}
128
// getEpsilonGreedy picks the host for the next request using an epsilon-greedy
// strategy.
//
// It reads and writes p.epsilon and the per-host epsilonValue and
// epsilonPercentage fields without taking a lock, so the caller must hold p's
// lock (Get calls it under p.Lock).
//
// Exploration: with probability p.epsilon it multiplies epsilon by
// epsilonDecay, never letting it drop below minEpsilon, and defers to
// getRoundRobin.
//
// Exploitation: otherwise, every host that passes canTryHost(now) and has a
// positive weighted-average response time is scored with the embedded
// EpsilonValueCalculator. The scores are normalized to sum to 1, and one host
// is drawn at random with probability proportional to its score. Hosts without
// a positive average (presumably those with no recent samples) are reachable
// only through the round-robin paths.
//
// If no host is eligible, or the weighted draw selects nothing, it falls back
// to getRoundRobin. Get treats an empty return value as "no host".
func (p *epsilonGreedyHostPool) getEpsilonGreedy() string {
	var hostToUse *hostEntry

	// this is our exploration phase
	if rand.Float32() < p.epsilon {
		p.epsilon = p.epsilon * epsilonDecay
		if p.epsilon < minEpsilon {
			p.epsilon = minEpsilon
		}
		return p.getRoundRobin()
	}

	// exploitation phase: score each eligible host. The scores are not yet
	// normalized; their range depends on the EpsilonValueCalculator.
	var possibleHosts []*hostEntry
	now := time.Now()
	var sumValues float64
	for _, h := range p.hostList {
		if h.canTryHost(now) {
			v := h.getWeightedAverageResponseTime()
			if v > 0 {
				ev := p.CalcValueFromAvgResponseTime(v)
				h.epsilonValue = ev
				sumValues += ev
				possibleHosts = append(possibleHosts, h)
			}
		}
	}

	if len(possibleHosts) != 0 {
		// now normalize to the 0..1 range to get a percentage
		for _, h := range possibleHosts {
			h.epsilonPercentage = h.epsilonValue / sumValues
		}

		// do a weighted random choice among hosts: walk the cumulative
		// distribution until it reaches the uniform draw in [0, 1)
		ceiling := 0.0
		pickPercentage := rand.Float64()
		for _, h := range possibleHosts {
			ceiling += h.epsilonPercentage
			if pickPercentage <= ceiling {
				hostToUse = h
				break
			}
		}
	}

	if hostToUse == nil {
		// If candidates exist, reaching this branch means the cumulative sum
		// never reached pickPercentage. NOTE(review): floating-point rounding can
		// leave the final ceiling a hair below 1.0. Non-finite percentages (for
		// example NaN from a zero sumValues) also compare false. Both cases fall
		// through to round robin here.
		if len(possibleHosts) != 0 {
			log.Println("Failed to randomly choose a host, Dan loses")
		}

		return p.getRoundRobin()
	}

	if hostToUse.dead {
		// A dead host won the draw. This presumably records that it is being
		// retried; see willRetryHost for the exact bookkeeping.
		hostToUse.willRetryHost(p.maxRetryInterval)
	}
	return hostToUse.host
}
188
189func (p *epsilonGreedyHostPool) markSuccess(hostR HostPoolResponse) {
190	// first do the base markSuccess - a little redundant with host lookup but cleaner than repeating logic
191	p.standardHostPool.markSuccess(hostR)
192	eHostR, ok := hostR.(*epsilonHostPoolResponse)
193	if !ok {
194		log.Printf("Incorrect type in eps markSuccess!") // TODO reflection to print out offending type
195		return
196	}
197	host := eHostR.host
198	duration := p.between(eHostR.started, eHostR.ended)
199
200	p.Lock()
201	defer p.Unlock()
202	h, ok := p.hosts[host]
203	if !ok {
204		log.Fatalf("host %s not in HostPool %v", host, p.Hosts())
205	}
206	h.epsilonCounts[h.epsilonIndex]++
207	h.epsilonValues[h.epsilonIndex] += int64(duration.Seconds() * 1000)
208}
209
210// --- timer: this just exists for testing
211
// timer measures the elapsed time between two instants. The pool embeds it so
// tests can substitute a deterministic implementation for realTimer.
type timer interface {
	between(start, end time.Time) time.Duration
}
215
// realTimer is the production timer used by NewEpsilonGreedy.
type realTimer struct{}

// between reports how much time passed from `from` to `to`. The result is
// negative when `to` is earlier than `from`.
func (*realTimer) between(from, to time.Time) time.Duration {
	elapsed := to.Sub(from)
	return elapsed
}
221