1package queue
2
3import (
4	"context"
5	"sync"
6	"time"
7
8	"github.com/grafana/dskit/services"
9	"github.com/pkg/errors"
10	"github.com/prometheus/client_golang/prometheus"
11	"go.uber.org/atomic"
12)
13
// forgetCheckPeriod is the interval between successive checks for
// disconnected queriers that should be forgotten.
const forgetCheckPeriod = 5 * time.Second
18
var (
	// ErrTooManyRequests is returned by EnqueueRequest when the request
	// cannot be added to the user's queue without blocking.
	ErrTooManyRequests = errors.New("too many outstanding requests")
	// ErrStopped is returned by EnqueueRequest and GetNextRequestForQuerier
	// once the queue has been marked as stopped.
	ErrStopped         = errors.New("queue is stopped")
)
23
// UserIndex is an opaque type that allows iteration over users to be resumed
// between successive calls of the RequestQueue.GetNextRequestForQuerier method.
type UserIndex struct {
	// last is the index passed to the user-queue lookup on the next call;
	// GetNextRequestForQuerier overwrites it with the index that lookup
	// returns. FirstUser initialises it to -1.
	last int
}
29
30// Modify index to start iteration on the same user, for which last queue was returned.
31func (ui UserIndex) ReuseLastUser() UserIndex {
32	if ui.last >= 0 {
33		return UserIndex{last: ui.last - 1}
34	}
35	return ui
36}
37
38// FirstUser returns UserIndex that starts iteration over user queues from the very first user.
39func FirstUser() UserIndex {
40	return UserIndex{last: -1}
41}
42
// Request is a single item stored in the queue. The queue places no
// constraint on its concrete type: EnqueueRequest stores the value and
// GetNextRequestForQuerier hands it back as received.
type Request interface{}
45
// RequestQueue holds incoming requests in per-user queues. It also assigns each user a specified number of queriers,
// and when a querier asks for the next request to handle (using GetNextRequestForQuerier), it returns requests
// in a fair fashion.
type RequestQueue struct {
	// Service is the embedded service; NewRequestQueue sets it to a timer
	// service named "request queue".
	services.Service

	// connectedQuerierWorkers counts registered querier connections. It is
	// incremented and decremented without holding mtx.
	connectedQuerierWorkers *atomic.Int32

	// mtx is held while accessing queues and stopped, and is the lock
	// associated with cond.
	mtx     sync.Mutex
	cond    *sync.Cond // Notified when request is enqueued or dequeued, or querier is disconnected.
	queues  *queues
	stopped bool

	// NOTE(review): queueLength was originally described as "per user and
	// reason", but every call in this file supplies a single label value
	// (the user ID) — confirm the expected label set.
	queueLength       *prometheus.GaugeVec
	discardedRequests *prometheus.CounterVec // Per user.
}
62
63func NewRequestQueue(maxOutstandingPerTenant int, forgetDelay time.Duration, queueLength *prometheus.GaugeVec, discardedRequests *prometheus.CounterVec) *RequestQueue {
64	q := &RequestQueue{
65		queues:                  newUserQueues(maxOutstandingPerTenant, forgetDelay),
66		connectedQuerierWorkers: atomic.NewInt32(0),
67		queueLength:             queueLength,
68		discardedRequests:       discardedRequests,
69	}
70
71	q.cond = sync.NewCond(&q.mtx)
72	q.Service = services.NewTimerService(forgetCheckPeriod, nil, q.forgetDisconnectedQueriers, q.stopping).WithName("request queue")
73
74	return q
75}
76
77// EnqueueRequest puts the request into the queue. MaxQueries is user-specific value that specifies how many queriers can
78// this user use (zero or negative = all queriers). It is passed to each EnqueueRequest, because it can change
79// between calls.
80//
81// If request is successfully enqueued, successFn is called with the lock held, before any querier can receive the request.
82func (q *RequestQueue) EnqueueRequest(userID string, req Request, maxQueriers int, successFn func()) error {
83	q.mtx.Lock()
84	defer q.mtx.Unlock()
85
86	if q.stopped {
87		return ErrStopped
88	}
89
90	queue := q.queues.getOrAddQueue(userID, maxQueriers)
91	if queue == nil {
92		// This can only happen if userID is "".
93		return errors.New("no queue found")
94	}
95
96	select {
97	case queue <- req:
98		q.queueLength.WithLabelValues(userID).Inc()
99		q.cond.Broadcast()
100		// Call this function while holding a lock. This guarantees that no querier can fetch the request before function returns.
101		if successFn != nil {
102			successFn()
103		}
104		return nil
105	default:
106		q.discardedRequests.WithLabelValues(userID).Inc()
107		return ErrTooManyRequests
108	}
109}
110
111// GetNextRequestForQuerier find next user queue and takes the next request off of it. Will block if there are no requests.
112// By passing user index from previous call of this method, querier guarantees that it iterates over all users fairly.
113// If querier finds that request from the user is already expired, it can get a request for the same user by using UserIndex.ReuseLastUser.
114func (q *RequestQueue) GetNextRequestForQuerier(ctx context.Context, last UserIndex, querierID string) (Request, UserIndex, error) {
115	q.mtx.Lock()
116	defer q.mtx.Unlock()
117
118	querierWait := false
119
120FindQueue:
121	// We need to wait if there are no users, or no pending requests for given querier.
122	for (q.queues.len() == 0 || querierWait) && ctx.Err() == nil && !q.stopped {
123		querierWait = false
124		q.cond.Wait()
125	}
126
127	if q.stopped {
128		return nil, last, ErrStopped
129	}
130
131	if err := ctx.Err(); err != nil {
132		return nil, last, err
133	}
134
135	for {
136		queue, userID, idx := q.queues.getNextQueueForQuerier(last.last, querierID)
137		last.last = idx
138		if queue == nil {
139			break
140		}
141
142		// Pick next request from the queue.
143		for {
144			request := <-queue
145			if len(queue) == 0 {
146				q.queues.deleteQueue(userID)
147			}
148
149			q.queueLength.WithLabelValues(userID).Dec()
150
151			// Tell close() we've processed a request.
152			q.cond.Broadcast()
153
154			return request, last, nil
155		}
156	}
157
158	// There are no unexpired requests, so we can get back
159	// and wait for more requests.
160	querierWait = true
161	goto FindQueue
162}
163
164func (q *RequestQueue) forgetDisconnectedQueriers(_ context.Context) error {
165	q.mtx.Lock()
166	defer q.mtx.Unlock()
167
168	if q.queues.forgetDisconnectedQueriers(time.Now()) > 0 {
169		// We need to notify goroutines cause having removed some queriers
170		// may have caused a resharding.
171		q.cond.Broadcast()
172	}
173
174	return nil
175}
176
177func (q *RequestQueue) stopping(_ error) error {
178	q.mtx.Lock()
179	defer q.mtx.Unlock()
180
181	for q.queues.len() > 0 && q.connectedQuerierWorkers.Load() > 0 {
182		q.cond.Wait()
183	}
184
185	// Only stop after dispatching enqueued requests.
186	q.stopped = true
187
188	// If there are still goroutines in GetNextRequestForQuerier method, they get notified.
189	q.cond.Broadcast()
190
191	return nil
192}
193
// RegisterQuerierConnection records a new worker connection from the given
// querier: it increments the connected-worker counter and then, under the
// queue lock, reports the connection to the per-user queues.
func (q *RequestQueue) RegisterQuerierConnection(querier string) {
	// The counter is atomic and is updated before the lock is taken.
	q.connectedQuerierWorkers.Inc()

	q.mtx.Lock()
	defer q.mtx.Unlock()
	q.queues.addQuerierConnection(querier)
}
201
202func (q *RequestQueue) UnregisterQuerierConnection(querier string) {
203	q.connectedQuerierWorkers.Dec()
204
205	q.mtx.Lock()
206	defer q.mtx.Unlock()
207	q.queues.removeQuerierConnection(querier, time.Now())
208}
209
// NotifyQuerierShutdown forwards, under the queue lock, a shutdown
// notification for the given querier to the per-user queues.
// NOTE(review): the effect on the querier's bookkeeping is defined by
// queues.notifyQuerierShutdown, which is not visible here — confirm there.
func (q *RequestQueue) NotifyQuerierShutdown(querierID string) {
	q.mtx.Lock()
	defer q.mtx.Unlock()
	q.queues.notifyQuerierShutdown(querierID)
}
215
216// When querier is waiting for next request, this unblocks the method.
217func (q *RequestQueue) QuerierDisconnecting() {
218	q.cond.Broadcast()
219}
220
221func (q *RequestQueue) GetConnectedQuerierWorkersMetric() float64 {
222	return float64(q.connectedQuerierWorkers.Load())
223}
224