1package fasthttp
2
3import (
4	"sync"
5	"sync/atomic"
6	"time"
7)
8
9// BalancingClient is the interface for clients, which may be passed
10// to LBClient.Clients.
11type BalancingClient interface {
12	DoDeadline(req *Request, resp *Response, deadline time.Time) error
13	PendingRequests() int
14}
15
16// LBClient balances requests among available LBClient.Clients.
17//
18// It has the following features:
19//
20//   - Balances load among available clients using 'least loaded' + 'round robin'
21//     hybrid technique.
22//   - Dynamically decreases load on unhealthy clients.
23//
24// It is forbidden copying LBClient instances. Create new instances instead.
25//
26// It is safe calling LBClient methods from concurrently running goroutines.
27type LBClient struct {
28	noCopy noCopy
29
30	// Clients must contain non-zero clients list.
31	// Incoming requests are balanced among these clients.
32	Clients []BalancingClient
33
34	// HealthCheck is a callback called after each request.
35	//
36	// The request, response and the error returned by the client
37	// is passed to HealthCheck, so the callback may determine whether
38	// the client is healthy.
39	//
40	// Load on the current client is decreased if HealthCheck returns false.
41	//
42	// By default HealthCheck returns false if err != nil.
43	HealthCheck func(req *Request, resp *Response, err error) bool
44
45	// Timeout is the request timeout used when calling LBClient.Do.
46	//
47	// DefaultLBClientTimeout is used by default.
48	Timeout time.Duration
49
50	cs []*lbClient
51
52	// nextIdx is for spreading requests among equally loaded clients
53	// in a round-robin fashion.
54	nextIdx uint32
55
56	once sync.Once
57}
58
// DefaultLBClientTimeout is the request timeout LBClient.Do falls back to
// when LBClient.Timeout is not set to a positive value.
//
// Set LBClient.Timeout to override it.
const DefaultLBClientTimeout = 1 * time.Second
64
65// DoDeadline calls DoDeadline on the least loaded client
66func (cc *LBClient) DoDeadline(req *Request, resp *Response, deadline time.Time) error {
67	return cc.get().DoDeadline(req, resp, deadline)
68}
69
70// DoTimeout calculates deadline and calls DoDeadline on the least loaded client
71func (cc *LBClient) DoTimeout(req *Request, resp *Response, timeout time.Duration) error {
72	deadline := time.Now().Add(timeout)
73	return cc.get().DoDeadline(req, resp, deadline)
74}
75
76// Do calls calculates deadline using LBClient.Timeout and calls DoDeadline
77// on the least loaded client.
78func (cc *LBClient) Do(req *Request, resp *Response) error {
79	timeout := cc.Timeout
80	if timeout <= 0 {
81		timeout = DefaultLBClientTimeout
82	}
83	return cc.DoTimeout(req, resp, timeout)
84}
85
86func (cc *LBClient) init() {
87	if len(cc.Clients) == 0 {
88		panic("BUG: LBClient.Clients cannot be empty")
89	}
90	for _, c := range cc.Clients {
91		cc.cs = append(cc.cs, &lbClient{
92			c:           c,
93			healthCheck: cc.HealthCheck,
94		})
95	}
96
97	// Randomize nextIdx in order to prevent initial servers'
98	// hammering from a cluster of identical LBClients.
99	cc.nextIdx = uint32(time.Now().UnixNano())
100}
101
102func (cc *LBClient) get() *lbClient {
103	cc.once.Do(cc.init)
104
105	cs := cc.cs
106	idx := atomic.AddUint32(&cc.nextIdx, 1)
107	idx %= uint32(len(cs))
108
109	minC := cs[idx]
110	minN := minC.PendingRequests()
111	if minN == 0 {
112		return minC
113	}
114	for _, c := range cs[idx+1:] {
115		n := c.PendingRequests()
116		if n == 0 {
117			return c
118		}
119		if n < minN {
120			minC = c
121			minN = n
122		}
123	}
124	for _, c := range cs[:idx] {
125		n := c.PendingRequests()
126		if n == 0 {
127			return c
128		}
129		if n < minN {
130			minC = c
131			minN = n
132		}
133	}
134	return minC
135}
136
137type lbClient struct {
138	c           BalancingClient
139	healthCheck func(req *Request, resp *Response, err error) bool
140	penalty     uint32
141}
142
143func (c *lbClient) DoDeadline(req *Request, resp *Response, deadline time.Time) error {
144	err := c.c.DoDeadline(req, resp, deadline)
145	if !c.isHealthy(req, resp, err) && c.incPenalty() {
146		// Penalize the client returning error, so the next requests
147		// are routed to another clients.
148		time.AfterFunc(penaltyDuration, c.decPenalty)
149	}
150	return err
151}
152
153func (c *lbClient) PendingRequests() int {
154	n := c.c.PendingRequests()
155	m := atomic.LoadUint32(&c.penalty)
156	return n + int(m)
157}
158
159func (c *lbClient) isHealthy(req *Request, resp *Response, err error) bool {
160	if c.healthCheck == nil {
161		return err == nil
162	}
163	return c.healthCheck(req, resp, err)
164}
165
166func (c *lbClient) incPenalty() bool {
167	m := atomic.AddUint32(&c.penalty, 1)
168	if m > maxPenalty {
169		c.decPenalty()
170		return false
171	}
172	return true
173}
174
175func (c *lbClient) decPenalty() {
176	atomic.AddUint32(&c.penalty, ^uint32(0))
177}
178
// Tuning for penalizing unhealthy clients.
const (
	// penaltyDuration is how long one penalty point stays on a client
	// after an unhealthy request.
	penaltyDuration = 3 * time.Second

	// maxPenalty caps the number of penalty points a single client can hold.
	maxPenalty = 300
)
184