package fairshare

import (
	"container/list"
	"fmt"
	"io"
	"math"
	"sync"
	"time"

	"github.com/armon/go-metrics"
	log "github.com/hashicorp/go-hclog"
	uuid "github.com/hashicorp/go-uuid"
	"github.com/hashicorp/vault/helper/metricsutil"
	"github.com/hashicorp/vault/sdk/helper/logging"
)

type JobManager struct {
	name   string
	queues map[string]*list.List

	quit    chan struct{}
	newWork chan struct{} // must be buffered

	workerPool  *dispatcher
	workerCount map[string]int

	onceStart sync.Once
	onceStop  sync.Once

	logger log.Logger

	totalJobs  int
	metricSink *metricsutil.ClusterMetricSink

	// waitgroup for testing stop functionality
	wg sync.WaitGroup

	// protects `queues`, `workerCount`, `queuesIndex`, `lastQueueAccessed`
	l sync.RWMutex

	// track queues by index for round-robin worker assignment
	queuesIndex       []string
	lastQueueAccessed int
}

// NewJobManager creates a job manager, with an optional name
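//
// a minimal usage sketch (illustrative only; logger, jobA, and jobB are
// assumed to exist elsewhere, with the jobs satisfying this package's
// Job interface):
//
//	jm := NewJobManager("my-manager", 8, logger, nil)
//	jm.Start()
//	defer jm.Stop()
//
//	// jobs sharing a queue ID are processed in FIFO order within that
//	// queue, while distinct queues are drained round-robin
//	jm.AddJob(jobA, "tenant-1")
//	jm.AddJob(jobB, "tenant-2")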
func NewJobManager(name string, numWorkers int, l log.Logger, metricSink *metricsutil.ClusterMetricSink) *JobManager {
	if l == nil {
		l = logging.NewVaultLoggerWithWriter(io.Discard, log.NoLevel)
	}
	if name == "" {
		guid, err := uuid.GenerateUUID()
		if err != nil {
			l.Warn("uuid generator failed, using 'no-uuid'", "err", err)
			guid = "no-uuid"
		}

		name = fmt.Sprintf("jobmanager-%s", guid)
	}

	wp := newDispatcher(fmt.Sprintf("%s-dispatcher", name), numWorkers, l)

	j := JobManager{
		name:              name,
		queues:            make(map[string]*list.List),
		quit:              make(chan struct{}),
		newWork:           make(chan struct{}, 1),
		workerPool:        wp,
		workerCount:       make(map[string]int),
		logger:            l,
		metricSink:        metricSink,
		queuesIndex:       make([]string, 0),
		lastQueueAccessed: -1,
	}

	j.logger.Trace("created job manager", "name", name, "pool_size", numWorkers)
	return &j
}

// Start starts the job manager
// note: a given job manager cannot be restarted after it has been stopped
func (j *JobManager) Start() {
	j.onceStart.Do(func() {
		j.logger.Trace("starting job manager", "name", j.name)
		j.workerPool.start()
		j.assignWork()
	})
}

// Stop stops the job manager asynchronously
func (j *JobManager) Stop() {
	j.onceStop.Do(func() {
		j.logger.Trace("terminating job manager...")
		close(j.quit)
		j.workerPool.stop()
	})
}

// AddJob adds a job to the given queue, creating the queue if it doesn't exist
func (j *JobManager) AddJob(job Job, queueID string) {
	j.l.Lock()
	if len(j.queues) == 0 {
		defer func() {
			// newWork must be buffered to avoid deadlocks if work is added
			// before the job manager is started
			j.newWork <- struct{}{}
		}()
	}
	defer j.l.Unlock()

	if _, ok := j.queues[queueID]; !ok {
		j.addQueue(queueID)
	}

	j.queues[queueID].PushBack(job)
	j.totalJobs++

	if j.metricSink != nil {
		j.metricSink.AddSampleWithLabels([]string{j.name, "job_manager", "queue_length"},
			float32(j.queues[queueID].Len()),
			[]metrics.Label{{Name: "queue_id", Value: queueID}})
		j.metricSink.AddSample([]string{j.name, "job_manager", "total_jobs"}, float32(j.totalJobs))
	}
}

// GetPendingJobCount returns the total number of pending jobs in the job manager
func (j *JobManager) GetPendingJobCount() int {
	j.l.RLock()
	defer j.l.RUnlock()

	cnt := 0
	for _, q := range j.queues {
		cnt += q.Len()
	}

	return cnt
}

// GetWorkerCounts returns a map of queue ID to number of active workers.
// the returned map is a copy, so callers cannot race with the job manager's
// own updates to the underlying map
func (j *JobManager) GetWorkerCounts() map[string]int {
	j.l.RLock()
	defer j.l.RUnlock()

	counts := make(map[string]int, len(j.workerCount))
	for k, v := range j.workerCount {
		counts[k] = v
	}

	return counts
}

// GetWorkQueueLengths returns a map of queue ID to number of jobs in the queue
func (j *JobManager) GetWorkQueueLengths() map[string]int {
	out := make(map[string]int)

	j.l.RLock()
	defer j.l.RUnlock()

	for k, v := range j.queues {
		out[k] = v.Len()
	}

	return out
}

// getNextJob pops the next job to be processed and prunes empty queues.
// it also returns the ID of the queue the job is associated with
func (j *JobManager) getNextJob() (Job, string) {
	j.l.Lock()
	defer j.l.Unlock()

	if len(j.queues) == 0 {
		return nil, ""
	}

	queueID, canAssignWorker := j.getNextQueue()
	if !canAssignWorker {
		return nil, ""
	}

	jobElement := j.queues[queueID].Front()
	jobRaw := j.queues[queueID].Remove(jobElement)

	j.totalJobs--

	if j.metricSink != nil {
		j.metricSink.AddSampleWithLabels([]string{j.name, "job_manager", "queue_length"},
			float32(j.queues[queueID].Len()),
			[]metrics.Label{{Name: "queue_id", Value: queueID}})
		j.metricSink.AddSample([]string{j.name, "job_manager", "total_jobs"}, float32(j.totalJobs))
	}

	if j.queues[queueID].Len() == 0 {
		// we remove the empty queue, but we don't remove the worker count
		// in case we are still working on previous jobs from this queue.
		// worker count cleanup is handled in j.decrementWorkerCount
		j.removeLastQueueAccessed()
	}

	return jobRaw.(Job), queueID
}

// returns the next queue to assign work from, and a bool indicating whether
// a queue eligible for another worker was found. if work can be assigned,
// j.lastQueueAccessed will be updated to that queue.
// note: this must be called with j.l held
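//
// for example (illustrative walkthrough): with queuesIndex ["a", "b", "c"]
// and lastQueueAccessed == 0, the scan starts at index 1 ("b"); if "b" and
// "c" are saturated but "a" is not, the loop wraps around and selects "a",
// setting lastQueueAccessed back to 0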
func (j *JobManager) getNextQueue() (string, bool) {
	var nextQueue string
	var canAssignWorker bool

	// ensure we loop through all existing queues until we find an eligible
	// queue, if one exists.
	queueIdx := j.nextQueueIndex(j.lastQueueAccessed)
	for i := 0; i < len(j.queuesIndex); i++ {
		potentialQueueID := j.queuesIndex[queueIdx]

		if !j.queueWorkersSaturated(potentialQueueID) {
			nextQueue = potentialQueueID
			canAssignWorker = true
			j.lastQueueAccessed = queueIdx
			break
		}

		queueIdx = j.nextQueueIndex(queueIdx)
	}

	return nextQueue, canAssignWorker
}

// get the index of the next queue in round-robin order
// note: this must be called with j.l held
func (j *JobManager) nextQueueIndex(currentIdx int) int {
	return (currentIdx + 1) % len(j.queuesIndex)
}

// returns true if there are already too many workers on this queue
// note: this must be called with j.l held (at least for read).
// note: we may want to eventually factor in queue length relative to num queues
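//
// for example (illustrative numbers): with 16 total workers and 4 active
// queues, maxWorkersPerQueue = ceil(0.9 * 16 / 4) = ceil(3.6) = 4, so a
// queue that already has 4 active workers is considered saturated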
func (j *JobManager) queueWorkersSaturated(queueID string) bool {
	numActiveQueues := float64(len(j.queues))
	numTotalWorkers := float64(j.workerPool.numWorkers)
	maxWorkersPerQueue := math.Ceil(0.9 * numTotalWorkers / numActiveQueues)

	return j.workerCount[queueID] >= int(maxWorkersPerQueue)
}

// increment the worker count for this queue
func (j *JobManager) incrementWorkerCount(queueID string) {
	j.l.Lock()
	defer j.l.Unlock()

	j.workerCount[queueID]++
}

// decrement the worker count for this queue
// this also removes worker tracking for this queue if needed
func (j *JobManager) decrementWorkerCount(queueID string) {
	j.l.Lock()
	defer j.l.Unlock()

	j.workerCount[queueID]--

	_, queueExists := j.queues[queueID]
	if !queueExists && j.workerCount[queueID] < 1 {
		delete(j.workerCount, queueID)
	}
}

// assignWork continually checks for new jobs and dispatches them to the
// worker pool
func (j *JobManager) assignWork() {
	j.wg.Add(1)

	go func() {
		for {
			for {
				// assign work while there are jobs to distribute
				select {
				case <-j.quit:
					j.wg.Done()
					return
				case <-j.newWork:
					// keep the channel empty since we're already processing work
				default:
				}

				job, queueID := j.getNextJob()
				if job != nil {
					j.workerPool.dispatch(job,
						func() {
							j.incrementWorkerCount(queueID)
						},
						func() {
							j.decrementWorkerCount(queueID)
						})
				} else {
					break
				}
			}

			select {
			case <-j.quit:
				j.wg.Done()
				return
			case <-j.newWork:
				// listen for wake-up when an empty job manager has been given work
			case <-time.After(50 * time.Millisecond):
				// periodically check if new workers can be assigned. with the
				// fairsharing worker distribution it can be the case that there
				// is work waiting, but no queues are eligible for another worker
			}
		}
	}()
}

// addQueue generates a new queue if a queue for `queueID` doesn't exist.
// it also starts tracking workers on that queue, if not already tracked
// note: this must be called with j.l held for write
func (j *JobManager) addQueue(queueID string) {
	if _, ok := j.queues[queueID]; !ok {
		j.queues[queueID] = list.New()
		j.queuesIndex = append(j.queuesIndex, queueID)
	}

	// it's possible the queue ran out of work and was pruned, but there were
	// still workers operating on data formerly in that queue, which were still
	// being tracked. if that is the case, we don't want to wipe out that worker
	// count when the queue is re-initialized.
	if _, ok := j.workerCount[queueID]; !ok {
		j.workerCount[queueID] = 0
	}
}

// removes the queue and index tracker for the last queue accessed.
// it is to be used when the last queue accessed has emptied.
// note: this must be called with j.l held.
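//
// for example (illustrative walkthrough): with queuesIndex ["a", "b", "c"]
// and lastQueueAccessed == 1 (queue "b" just emptied), removing "b" leaves
// ["a", "c"] and decrements lastQueueAccessed to 0, so the next round-robin
// pick is index 1 ("c"), preserving the rotation order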
func (j *JobManager) removeLastQueueAccessed() {
	if j.lastQueueAccessed == -1 || j.lastQueueAccessed > len(j.queuesIndex)-1 {
		j.logger.Warn("call to remove queue out of bounds", "idx", j.lastQueueAccessed)
		return
	}

	queueID := j.queuesIndex[j.lastQueueAccessed]

	// remove the queue
	delete(j.queues, queueID)

	// remove the index for the queue
	j.queuesIndex = append(j.queuesIndex[:j.lastQueueAccessed], j.queuesIndex[j.lastQueueAccessed+1:]...)

	// correct the last queue accessed for round-robin ordering
	if j.lastQueueAccessed > 0 {
		j.lastQueueAccessed--
	} else {
		j.lastQueueAccessed = len(j.queuesIndex) - 1
	}
}