1# Copyright (C) 2015-2021 Regents of the University of California
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14import json
15import logging
16import os
17import time
18from collections import defaultdict
19from itertools import islice
20
21from toil.batchSystems.abstractBatchSystem import (AbstractScalableBatchSystem,
22                                                   NodeInfo)
23from toil.common import defaultTargetTime
24from toil.job import ServiceJobDescription
25from toil.lib.retry import old_retry
26from toil.lib.threading import ExceptionalThread
27from toil.lib.throttle import throttle
28from toil.provisioners.abstractProvisioner import Shape
29
# Module-level logger, named after this module; shared by the bin-packing helpers and ClusterScaler.
logger = logging.getLogger(__name__)
31
32
class BinPackedFit(object):
    """
    If jobShapes is a set of tasks with run requirements (mem/disk/cpu), and nodeShapes is a
    collection of available computers to run these jobs on, this class estimates the minimum
    set of computerNode computers needed to run the tasks in jobShapes.

    Uses a first fit decreasing (FFD) bin packing like algorithm to calculate an approximate minimum
    number of nodes that will fit the given list of jobs. The constructor sorts nodeShapes, and that
    sorted order (Shape's own ordering, not the caller's order) is the "node preference": when
    virtually "creating" nodes, the first node shape in self.nodeShapes that fits the job is the
    one that's added.

    :param nodeShapes: The properties of an atomic node allocation, in terms of wall-time,
                       memory, cores, disk, and whether it is preemptable or not. Any iterable
                       of shapes is accepted; it is materialized and sorted into self.nodeShapes.
    :param targetTime: The time before which all jobs should at least be started.

    Usage: call binPack() (any number of times) and then getRequiredNodes(), which returns the
    estimated number of node allocations of each shape required to run all the packed jobs.
    """
    def __init__(self, nodeShapes, targetTime=defaultTargetTime):
        self.nodeShapes = sorted(nodeShapes)
        self.targetTime = targetTime
        # Key the reservations off the sorted copy, not the argument: if nodeShapes was a
        # one-shot iterator, sorted() above has already consumed it.
        self.nodeReservations = {nodeShape: [] for nodeShape in self.nodeShapes}

    def binPack(self, jobShapes):
        """
        Pack jobShapes into the fewest nodes reasonable. Can be run multiple times.

        Jobs are packed in descending Shape order (the FFD-like strategy). The caller's
        collection is not modified, and any iterable of job shapes is accepted.
        """
        # TODO: Check for redundancy with batchsystems.mesos.JobQueue() sorting
        # Sort ascending and then reverse (instead of sorted(..., reverse=True)) so that
        # shapes comparing equal keep the same relative order the in-place sort()/reverse()
        # used to produce, while leaving the caller's list untouched.
        orderedJobShapes = sorted(jobShapes)
        orderedJobShapes.reverse()
        logger.debug('Running bin packing for node shapes %s and %s job(s).',
                     self.nodeShapes, len(orderedJobShapes))
        assert len(orderedJobShapes) == 0 or orderedJobShapes[0] >= orderedJobShapes[-1]
        for jobShape in orderedJobShapes:
            self.addJobShape(jobShape)

    def addJobShape(self, jobShape):
        """
        Add the job to the first node reservation in which it will fit (this is the
        bin-packing aspect).

        The node shape used is the first one in self.nodeShapes whose resources can hold the
        job at all. The job is then placed into an existing reservation of that shape if one
        has room (see NodeReservation.attemptToAddJob); otherwise a new reservation is opened.
        Jobs that fit no node shape are logged with a warning and dropped.
        """
        # First node shape (in preference order) whose resources can hold this job.
        chosenNodeShape = next((nodeShape for nodeShape in self.nodeShapes
                                if NodeReservation(nodeShape).fits(jobShape)), None)

        if chosenNodeShape is None:
            logger.warning("Couldn't fit job with requirements %r into any nodes in the nodeTypes "
                           "list.", jobShape)
            return

        # grab current list of job objects appended to this instance type
        nodeReservations = self.nodeReservations[chosenNodeShape]
        for nodeReservation in nodeReservations:
            if nodeReservation.attemptToAddJob(jobShape, chosenNodeShape, self.targetTime):
                # We succeeded adding the job to this node reservation. Now we're done.
                return

        # No existing reservation had room: open a new one with the job at its start.
        reservation = NodeReservation(chosenNodeShape)
        currentTimeAllocated = chosenNodeShape.wallTime
        adjustEndingReservationForJob(reservation, jobShape, 0)
        nodeReservations.append(reservation)

        # Extend the reservation if necessary to cover the job's entire runtime. Each added
        # slice copies the (already job-reduced) shape of the previous slice, so the job's
        # resources stay reserved for every full node wall-time it spans; the last slice is
        # reserved in full even if the job ends partway through it.
        while currentTimeAllocated < jobShape.wallTime:
            extendThisReservation = NodeReservation(reservation.shape)
            currentTimeAllocated += chosenNodeShape.wallTime
            reservation.nReservation = extendThisReservation
            reservation = extendThisReservation

    def getRequiredNodes(self):
        """
        Returns a dict from node shape to number of nodes required to run the packed jobs.
        """
        return {nodeShape: len(self.nodeReservations[nodeShape]) for nodeShape in self.nodeShapes}
111
class NodeReservation(object):
    """
    Represents a node "reservation": the amount of resources that we
    expect to be available on a given node at each point in time. To
    represent the resources available in a reservation, we represent a
    reservation as a linked list of NodeReservations, each giving the
    resources free within a single timeslice.
    """
    def __init__(self, shape):
        # The wall-time of this slice and resources available in this timeslice
        self.shape = shape
        # The next portion of the reservation (None if this is the end)
        self.nReservation = None

    def __str__(self):
        return "-------------------\n" \
               "Current Reservation\n" \
               "-------------------\n" \
               "Shape wallTime: %s\n" \
               "Shape memory: %s\n" \
               "Shape cores: %s\n" \
               "Shape disk: %s\n" \
               "Shape preempt: %s\n" \
               "\n" \
               "nReserv wallTime: %s\n" \
               "nReserv memory: %s\n" \
               "nReserv cores: %s\n" \
               "nReserv disk: %s\n" \
               "nReserv preempt: %s\n" \
               "\n" \
               "Time slices: %s\n" \
               "\n" % \
               (self.shape.wallTime,
                self.shape.memory,
                self.shape.cores,
                self.shape.disk,
                self.shape.preemptable,
                self.nReservation.shape.wallTime if self.nReservation is not None else str(None),
                self.nReservation.shape.memory if self.nReservation is not None else str(None),
                self.nReservation.shape.cores if self.nReservation is not None else str(None),
                self.nReservation.shape.disk if self.nReservation is not None else str(None),
                self.nReservation.shape.preemptable if self.nReservation is not None else str(None),
                str(len(self.shapes())))

    def fits(self, jobShape):
        """
        Check if a job shape's resource requirements will fit within this allocation.

        Only memory, cores and disk are compared (wall-time is ignored). A non-preemptable
        job never fits in a preemptable slice; any job may use a non-preemptable slice.
        """
        return jobShape.memory <= self.shape.memory and \
               jobShape.cores <= self.shape.cores and \
               jobShape.disk <= self.shape.disk and \
               (jobShape.preemptable or not self.shape.preemptable)

    def shapes(self):
        """Get all time-slice shapes, in order, from this reservation on."""
        shapes = []
        curRes = self
        while curRes is not None:
            shapes.append(curRes.shape)
            curRes = curRes.nReservation
        return shapes

    def subtract(self, jobShape):
        """
        Subtracts the resources necessary to run a jobShape from the reservation.

        Only memory, cores and disk are reduced; the slice's wall-time and preemptability
        are kept.
        """
        self.shape = Shape(self.shape.wallTime,
                           self.shape.memory - jobShape.memory,
                           self.shape.cores - jobShape.cores,
                           self.shape.disk - jobShape.disk,
                           self.shape.preemptable)

    def attemptToAddJob(self, jobShape, nodeShape, targetTime):
        """
        Attempt to pack a job into this reservation timeslice and/or the reservations after it.

        jobShape is the Shape of the job requirements, nodeShape is the Shape of the node this
        is a reservation for, and targetTime is the maximum time to wait before starting this job.

        The job needs a run of consecutive slices that each fit its resources and whose
        wall-times add up to at least the job's wall-time. When a slice cannot hold the job,
        the candidate start moves past it, but only while the time elapsed so far stays within
        targetTime. If the job starts at the head of the reservation and runs off its end, the
        reservation is extended with fresh nodeShape slices.

        :returns: True if the job was packed (the reservation is modified in place), else False.
        """
        # starting slice of time that we can fit in so far
        startingReservation = self
        # current end of the slices we can fit in so far
        endingReservation = startingReservation
        # the amount of runtime of the job currently covered by slices
        availableTime = 0
        # total time from when the instance started up to startingReservation
        startingReservationTime = 0

        while True:
            # True == can run the job (resources & preemptable only; NO time)
            if endingReservation.fits(jobShape):
                # add the time left available on the reservation
                availableTime += endingReservation.shape.wallTime
                # does the job time fit in the reservation's remaining time?
                if availableTime >= jobShape.wallTime:
                    timeSlice = 0
                    # Slices are compared by identity: we walk the linked list until we
                    # reach the exact ending slice object.
                    while startingReservation is not endingReservation:
                        # removes resources only (NO time) from startingReservation
                        startingReservation.subtract(jobShape)
                        # set aside the timeSlice
                        timeSlice += startingReservation.shape.wallTime
                        startingReservation = startingReservation.nReservation
                    assert jobShape.wallTime - timeSlice <= startingReservation.shape.wallTime
                    adjustEndingReservationForJob(endingReservation, jobShape, timeSlice)
                    # Packed the job.
                    return True

                # If the job would fit, but is longer than the total node allocation
                # extend the node allocation
                elif endingReservation.nReservation is None and startingReservation is self:
                    # Extend the node reservation to accommodate jobShape
                    endingReservation.nReservation = NodeReservation(nodeShape)
            # can't run the job with the current resources
            else:
                # Start the job after this blocking slice, unless that would delay its start
                # beyond targetTime.
                if startingReservationTime + availableTime + endingReservation.shape.wallTime <= targetTime:
                    startingReservation = endingReservation.nReservation
                    startingReservationTime += availableTime + endingReservation.shape.wallTime
                    availableTime = 0
                else:
                    break

            endingReservation = endingReservation.nReservation
            if endingReservation is None:
                # Reached the end of the reservation without success so stop trying to
                # add to reservation
                break
        # Couldn't pack the job.
        return False
238
def adjustEndingReservationForJob(reservation, jobShape, wallTime):
    """
    Charge a job against the last time-slice it occupies.

    ``wallTime`` is how much of the job's runtime the preceding slices already cover, so
    ``jobShape.wallTime - wallTime`` is what is left to place in ``reservation``. When that
    remainder is shorter than the slice, the slice is split in two: the front part (with
    the job's resources deducted) stays in ``reservation`` and the untouched rest is linked
    in directly after it. Otherwise the job's resources are deducted from the whole slice.
    """
    remaining = jobShape.wallTime - wallTime
    if not remaining < reservation.shape.wallTime:
        # The job occupies this slice right up to its end.
        reservation.subtract(jobShape)
        return
    # The job stops partway through the slice: carve off the unused tail as its own slice.
    occupiedShape, tail = split(reservation.shape, jobShape, remaining)
    reservation.shape = occupiedShape
    tail.nReservation = reservation.nReservation
    reservation.nReservation = tail
252
def split(nodeShape, jobShape, wallTime):
    """
    Cut a node time-slice in two at ``wallTime`` to make room for a job.

    :param nodeShape: Shape of the slice being split.
    :param jobShape: Shape of the job occupying the front part of the slice.
    :param wallTime: Length of the front part, i.e. the time the job occupies.
    :return: ``(occupiedShape, remainderReservation)``. The first is a Shape lasting
             ``wallTime`` with the job's memory/cores/disk deducted. The second is a new
             NodeReservation covering the remaining ``nodeShape.wallTime - wallTime`` with
             the slice's full resources.
    """
    occupied = Shape(wallTime,
                     nodeShape.memory - jobShape.memory,
                     nodeShape.cores - jobShape.cores,
                     nodeShape.disk - jobShape.disk,
                     nodeShape.preemptable)
    leftover = Shape(nodeShape.wallTime - wallTime,
                     nodeShape.memory,
                     nodeShape.cores,
                     nodeShape.disk,
                     nodeShape.preemptable)
    return occupied, NodeReservation(leftover)
269
def binPacking(nodeShapes, jobShapes, goalTime):
    """
    Estimate how many nodes of each shape are needed to run ``jobShapes``.

    Convenience wrapper that builds a BinPackedFit over ``nodeShapes`` with ``goalTime``
    as its target time, packs ``jobShapes`` into it, and returns the result of
    getRequiredNodes(): a dict from node shape to node count.
    """
    packer = BinPackedFit(nodeShapes, goalTime)
    packer.binPack(jobShapes)
    requiredNodes = packer.getRequiredNodes()
    return requiredNodes
274
275class ClusterScaler(object):
276    def __init__(self, provisioner, leader, config):
277        """
278        Class manages automatically scaling the number of worker nodes.
279
280        :param AbstractProvisioner provisioner: Provisioner instance to scale.
281        :param toil.leader.Leader leader:
282        :param Config config: Config object from which to draw parameters.
283        """
284        self.provisioner = provisioner
285        self.leader = leader
286        self.config = config
287        self.static = {}
288
289        # Dictionary of job names to their average runtime, used to estimate wall time of queued
290        # jobs for bin-packing
291        self.jobNameToAvgRuntime = {}
292        self.jobNameToNumCompleted = {}
293        self.totalAvgRuntime = 0.0
294        self.totalJobsCompleted = 0
295
296        self.targetTime = config.targetTime
297        if self.targetTime <= 0:
298            raise RuntimeError('targetTime (%s) must be a positive integer!' % self.targetTime)
299        self.betaInertia = config.betaInertia
300        if not 0.0 <= self.betaInertia <= 0.9:
301            raise RuntimeError('betaInertia (%f) must be between 0.0 and 0.9!' % self.betaInertia)
302
303
304        # Pull scaling information from the provisioner.
305        self.nodeShapeToType = provisioner.getAutoscaledInstanceShapes()
306        self.instance_types = list(self.nodeShapeToType.values())
307        self.nodeShapes = list(self.nodeShapeToType.keys())
308
309        self.ignoredNodes = set()
310
311        # A *deficit* exists when we have more jobs that can run on preemptable
312        # nodes than we have preemptable nodes. In order to not block these jobs,
313        # we want to increase the number of non-preemptable nodes that we have and
314        # need for just non-preemptable jobs. However, we may still
315        # prefer waiting for preemptable instances to come available.
316        # To accommodate this, we set the delta to the difference between the number
317        # of provisioned preemptable nodes and the number of nodes that were requested.
318        # Then, when provisioning non-preemptable nodes of the same type, we attempt to
319        # make up the deficit.
320        self.preemptableNodeDeficit = {instance_type: 0 for instance_type in self.instance_types}
321
322        # Keeps track of the last raw (i.e. float, not limited by
323        # max/min nodes) estimates of the number of nodes needed for
324        # each node shape. NB: we start with an estimate of 0, so
325        # scaling up is smoothed as well.
326        self.previousWeightedEstimate = {nodeShape:0.0 for nodeShape in self.nodeShapes}
327
328        assert len(self.nodeShapes) > 0
329
330        # Minimum/maximum number of either preemptable or non-preemptable nodes in the cluster
331        minNodes = config.minNodes
332        if minNodes is None:
333            minNodes = [0 for node in self.instance_types]
334        maxNodes = config.maxNodes
335        while len(maxNodes) < len(self.instance_types):
336            # Pad out the max node counts if we didn't get one per type.
337            maxNodes.append(maxNodes[0])
338        while len(minNodes) < len(self.instance_types):
339            # Pad out the min node counts with 0s, so we can have fewer than
340            # the node types without crashing.
341            minNodes.append(0)
342        self.minNodes = dict(zip(self.nodeShapes, minNodes))
343        self.maxNodes = dict(zip(self.nodeShapes, maxNodes))
344
345        self.nodeShapes.sort()
346
347        #Node shape to number of currently provisioned nodes
348        totalNodes = defaultdict(int)
349        if isinstance(leader.batchSystem, AbstractScalableBatchSystem):
350            for preemptable in (True, False):
351                nodes = []
352                for nodeShape, instance_type in self.nodeShapeToType.items():
353                    nodes_thisType = leader.provisioner.getProvisionedWorkers(instance_type=instance_type,
354                                                                              preemptable=preemptable)
355                    totalNodes[nodeShape] += len(nodes_thisType)
356                    nodes.extend(nodes_thisType)
357
358                self.setStaticNodes(nodes, preemptable)
359
360        logger.debug('Starting with the following nodes in the cluster: %s' % totalNodes)
361
362        if not sum(config.maxNodes) > 0:
363            raise RuntimeError('Not configured to create nodes of any type.')
364
365    def _round(self, number):
366        """
367        Helper function for rounding-as-taught-in-school (X.5 rounds to X+1 if positive).
368        Python 3 now rounds 0.5 to whichever side is even (i.e. 2.5 rounds to 2).
369
370        :param int number: a float to round.
371        :return: closest integer to number, rounding ties away from 0.
372        """
373
374        sign = 1 if number >= 0 else -1
375
376        rounded = int(round(number))
377        nextRounded = int(round(number + 1 * sign))
378
379        if nextRounded == rounded:
380            # We rounded X.5 to even, and it was also away from 0.
381            return rounded
382        elif nextRounded == rounded + 1 * sign:
383            # We rounded normally (we are in Python 2)
384            return rounded
385        elif nextRounded == rounded + 2 * sign:
386            # We rounded X.5 to even, but it was towards 0.
387            # Go away from 0 instead.
388            return rounded + 1 * sign
389        else:
390            # If we get here, something has gone wrong.
391            raise RuntimeError("Could not round {}".format(number))
392
393    def getAverageRuntime(self, jobName, service=False):
394        if service:
395            # We short-circuit service jobs and assume that they will
396            # take a very long time, because if they are assumed to
397            # take a short time, we may try to pack multiple services
398            # into the same core/memory/disk "reservation", one after
399            # the other. That could easily lead to underprovisioning
400            # and a deadlock, because often multiple services need to
401            # be running at once for any actual work to get done.
402            return self.targetTime * 24 + 3600
403        if jobName in self.jobNameToAvgRuntime:
404            #Have seen jobs of this type before, so estimate
405            #the runtime based on average of previous jobs of this type
406            return self.jobNameToAvgRuntime[jobName]
407        elif self.totalAvgRuntime > 0:
408            #Haven't seen this job yet, so estimate its runtime as
409            #the average runtime of all completed jobs
410            return self.totalAvgRuntime
411        else:
412            #Have no information whatsoever
413            return 1.0
414
415    def addCompletedJob(self, job, wallTime):
416        """
417        Adds the shape of a completed job to the queue, allowing the scalar to use the last N
418        completed jobs in factoring how many nodes are required in the cluster.
419        :param toil.job.JobDescription job: The description of the completed job
420        :param int wallTime: The wall-time taken to complete the job in seconds.
421        """
422
423        #Adjust average runtimes to include this job.
424        if job.jobName in self.jobNameToAvgRuntime:
425            prevAvg = self.jobNameToAvgRuntime[job.jobName]
426            prevNum = self.jobNameToNumCompleted[job.jobName]
427            self.jobNameToAvgRuntime[job.jobName] = float(prevAvg*prevNum + wallTime)/(prevNum + 1)
428            self.jobNameToNumCompleted[job.jobName] += 1
429        else:
430            self.jobNameToAvgRuntime[job.jobName] = wallTime
431            self.jobNameToNumCompleted[job.jobName] = 1
432
433        self.totalJobsCompleted += 1
434        self.totalAvgRuntime = float(self.totalAvgRuntime * (self.totalJobsCompleted - 1) + \
435                                     wallTime)/self.totalJobsCompleted
436
437    def setStaticNodes(self, nodes, preemptable):
438        """
439        Used to track statically provisioned nodes. This method must be called
440        before any auto-scaled nodes are provisioned.
441
442        These nodes are treated differently than auto-scaled nodes in that they should
443        not be automatically terminated.
444
445        :param nodes: list of Node objects
446        """
447        prefix = 'non-' if not preemptable else ''
448        logger.debug("Adding %s to %spreemptable static nodes", nodes, prefix)
449        if nodes is not None:
450            self.static[preemptable] = {node.privateIP : node for node in nodes}
451
452    def getStaticNodes(self, preemptable):
453        """
454        Returns nodes set in setStaticNodes().
455
456        :param preemptable:
457        :return: Statically provisioned nodes.
458        """
459        return self.static[preemptable]
460
461    def smoothEstimate(self, nodeShape, estimatedNodeCount):
462        """
463        Smooth out fluctuations in the estimate for this node compared to
464        previous runs. Returns an integer.
465        """
466        weightedEstimate = (1 - self.betaInertia) * estimatedNodeCount + \
467                           self.betaInertia * self.previousWeightedEstimate[nodeShape]
468        self.previousWeightedEstimate[nodeShape] = weightedEstimate
469        return self._round(weightedEstimate)
470
471    def getEstimatedNodeCounts(self, queuedJobShapes, currentNodeCounts):
472        """
473        Given the resource requirements of queued jobs and the current size of the cluster, returns
474        a dict mapping from nodeShape to the number of nodes we want in the cluster right now.
475        """
476        nodesToRunQueuedJobs = binPacking(jobShapes=queuedJobShapes,
477                                          nodeShapes=self.nodeShapes,
478                                          goalTime=self.targetTime)
479        estimatedNodeCounts = {}
480        for nodeShape in self.nodeShapes:
481            instance_type = self.nodeShapeToType[nodeShape]
482
483            logger.debug("Nodes of type %s to run queued jobs = "
484                        "%s" % (instance_type, nodesToRunQueuedJobs[nodeShape]))
485            # Actual calculation of the estimated number of nodes required
486            estimatedNodeCount = 0 if nodesToRunQueuedJobs[nodeShape] == 0 \
487                else max(1, self._round(nodesToRunQueuedJobs[nodeShape]))
488            logger.debug("Estimating %i nodes of shape %s" % (estimatedNodeCount, nodeShape))
489
490            # Use inertia parameter to smooth out fluctuations according to an exponentially
491            # weighted moving average.
492            estimatedNodeCount = self.smoothEstimate(nodeShape, estimatedNodeCount)
493
494            # If we're scaling a non-preemptable node type, we need to see if we have a
495            # deficit of preemptable nodes of this type that we should compensate for.
496            if not nodeShape.preemptable:
497                compensation = self.config.preemptableCompensation
498                assert 0.0 <= compensation <= 1.0
499                # The number of nodes we provision as compensation for missing preemptable
500                # nodes is the product of the deficit (the number of preemptable nodes we did
501                # _not_ allocate) and configuration preference.
502                compensationNodes = self._round(self.preemptableNodeDeficit[instance_type] * compensation)
503                if compensationNodes > 0:
504                    logger.debug('Adding %d non-preemptable nodes of type %s to compensate for a '
505                                'deficit of %d preemptable ones.', compensationNodes,
506                                instance_type,
507                                self.preemptableNodeDeficit[instance_type])
508                estimatedNodeCount += compensationNodes
509
510            logger.debug("Currently %i nodes of type %s in cluster" % (currentNodeCounts[nodeShape],
511                                                                      instance_type))
512            if self.leader.toilMetrics:
513                self.leader.toilMetrics.logClusterSize(instance_type=instance_type,
514                                                       currentSize=currentNodeCounts[nodeShape],
515                                                       desiredSize=estimatedNodeCount)
516
517            # Bound number using the max and min node parameters
518            if estimatedNodeCount > self.maxNodes[nodeShape]:
519                logger.debug('Limiting the estimated number of necessary %s (%s) to the '
520                             'configured maximum (%s).', instance_type,
521                             estimatedNodeCount,
522                             self.maxNodes[nodeShape])
523                estimatedNodeCount = self.maxNodes[nodeShape]
524            elif estimatedNodeCount < self.minNodes[nodeShape]:
525                logger.debug('Raising the estimated number of necessary %s (%s) to the '
526                            'configured minimum (%s).', instance_type,
527                            estimatedNodeCount,
528                            self.minNodes[nodeShape])
529                estimatedNodeCount = self.minNodes[nodeShape]
530            estimatedNodeCounts[nodeShape] = estimatedNodeCount
531        return estimatedNodeCounts
532
533    def updateClusterSize(self, estimatedNodeCounts):
534        """
535        Given the desired and current size of the cluster, attempts to launch/remove instances to
536        get to the desired size. Also attempts to remove ignored nodes that were marked for graceful
537        removal.
538
539        Returns the new size of the cluster.
540        """
541        newNodeCounts = defaultdict(int)
542        for nodeShape, estimatedNodeCount in estimatedNodeCounts.items():
543            instance_type = self.nodeShapeToType[nodeShape]
544
545            newNodeCount = self.setNodeCount(instance_type, estimatedNodeCount, preemptable=nodeShape.preemptable)
546            # If we were scaling up a preemptable node type and failed to meet
547            # our target, we will attempt to compensate for the deficit while scaling
548            # non-preemptable nodes of this type.
549            if nodeShape.preemptable:
550                if newNodeCount < estimatedNodeCount:
551                    deficit = estimatedNodeCount - newNodeCount
552                    logger.debug('Preemptable scaler detected deficit of %d nodes of type %s.' % (deficit, instance_type))
553                    self.preemptableNodeDeficit[instance_type] = deficit
554                else:
555                    self.preemptableNodeDeficit[instance_type] = 0
556            newNodeCounts[nodeShape] = newNodeCount
557
558        #Attempt to terminate any nodes that we previously designated for
559        #termination, but which still had workers running.
560        self._terminateIgnoredNodes()
561        return newNodeCounts
562
563    def setNodeCount(self, instance_type, numNodes, preemptable=False, force=False):
564        """
565        Attempt to grow or shrink the number of preemptable or non-preemptable worker nodes in
566        the cluster to the given value, or as close a value as possible, and, after performing
567        the necessary additions or removals of worker nodes, return the resulting number of
568        preemptable or non-preemptable nodes currently in the cluster.
569
570        :param str instance_type: The instance type to add or remove.
571
572        :param int numNodes: Desired size of the cluster
573
574        :param bool preemptable: whether the added nodes will be preemptable, i.e. whether they
575               may be removed spontaneously by the underlying platform at any time.
576
577        :param bool force: If False, the provisioner is allowed to deviate from the given number
578               of nodes. For example, when downsizing a cluster, a provisioner might leave nodes
579               running if they have active jobs running on them.
580
581        :rtype: int :return: the number of worker nodes in the cluster after making the necessary
582                adjustments. This value should be, but is not guaranteed to be, close or equal to
583                the `numNodes` argument. It represents the closest possible approximation of the
584                actual cluster size at the time this method returns.
585        """
586        for attempt in old_retry(predicate=self.provisioner.retryPredicate):
587            with attempt:
588                workerInstances = self.getNodes(preemptable=preemptable)
589                logger.debug("Cluster contains %i instances" % len(workerInstances))
590                # Reduce to nodes of the correct type
591                workerInstances = {node:workerInstances[node] for node in workerInstances if node.nodeType == instance_type}
592                ignoredNodes = [node for node in workerInstances if node.privateIP in self.ignoredNodes]
593                numIgnoredNodes = len(ignoredNodes)
594                numCurrentNodes = len(workerInstances)
595                logger.debug("Cluster contains %i instances of type %s (%i ignored and draining jobs until "
596                            "they can be safely terminated)" % (numCurrentNodes, instance_type, numIgnoredNodes))
597                if not force:
598                    delta = numNodes - (numCurrentNodes - numIgnoredNodes)
599                else:
600                    delta = numNodes - numCurrentNodes
601                if delta > 0 and numIgnoredNodes > 0:
602                        # We can un-ignore a few nodes to compensate for the additional nodes we want.
603                        numNodesToUnignore = min(delta, numIgnoredNodes)
604                        logger.debug('Unignoring %i nodes because we want to scale back up again.' % numNodesToUnignore)
605                        delta -= numNodesToUnignore
606                        for node in ignoredNodes[:numNodesToUnignore]:
607                            self.ignoredNodes.remove(node.privateIP)
608                            self.leader.batchSystem.unignoreNode(node.privateIP)
609                if delta > 0:
610                    logger.info('Adding %i %s nodes to get to desired cluster size of %i.',
611                                delta,
612                                'preemptable' if preemptable else 'non-preemptable',
613                                numNodes)
614                    numNodes = numCurrentNodes + self._addNodes(instance_type, numNodes=delta,
615                                                                preemptable=preemptable)
616                elif delta < 0:
617                    logger.info('Removing %i %s nodes to get to desired cluster size of %i.', -delta, 'preemptable' if preemptable else 'non-preemptable', numNodes)
618                    numNodes = numCurrentNodes - self._removeNodes(workerInstances,
619                                                                   instance_type = instance_type,
620                                                                   numNodes=-delta,
621                                                                   preemptable=preemptable,
622                                                                   force=force)
623                else:
624                    if not force:
625                        logger.debug('Cluster (minus ignored nodes) already at desired size of %i. Nothing to do.', numNodes)
626                    else:
627                        logger.debug('Cluster already at desired size of %i. Nothing to do.', numNodes)
628        return numNodes
629
630    def _addNodes(self, instance_type, numNodes, preemptable):
631        return self.provisioner.addNodes(nodeTypes={instance_type}, numNodes=numNodes, preemptable=preemptable)
632
633    def _removeNodes(self, nodeToNodeInfo, instance_type, numNodes, preemptable=False, force=False):
634        # If the batch system is scalable, we can use the number of currently running workers on
635        # each node as the primary criterion to select which nodes to terminate.
636        if isinstance(self.leader.batchSystem, AbstractScalableBatchSystem):
637            # Unless forced, exclude nodes with running workers. Note that it is possible for
638            # the batch system to report stale nodes for which the corresponding instance was
639            # terminated already. There can also be instances that the batch system doesn't have
640            # nodes for yet. We'll ignore those, too, unless forced.
641            nodeToNodeInfo = self.getNodes(preemptable)
642            #Filter down to nodes of the correct node type
643            nodeToNodeInfo = {node:nodeToNodeInfo[node] for node in nodeToNodeInfo if node.nodeType == instance_type}
644
645            nodesToTerminate = self.chooseNodes(nodeToNodeInfo, force, preemptable=preemptable)
646
647            nodesToTerminate = nodesToTerminate[:numNodes]
648
649            # Join nodes and instances on private IP address.
650            logger.debug('Nodes considered to terminate: %s', ' '.join(map(str, nodeToNodeInfo)))
651
652            #Tell the batch system to stop sending jobs to these nodes
653            for (node, nodeInfo) in nodesToTerminate:
654                self.ignoredNodes.add(node.privateIP)
655                self.leader.batchSystem.ignoreNode(node.privateIP)
656
657            if not force:
658                # Filter out nodes with jobs still running. These
659                # will be terminated in _removeIgnoredNodes later on
660                # once all jobs have finished, but they will be ignored by
661                # the batch system and cluster scaler from now on
662                nodesToTerminate = [(node,nodeInfo) for (node,nodeInfo) in nodesToTerminate if nodeInfo is not None and nodeInfo.workers < 1]
663            nodesToTerminate = {node:nodeInfo for (node, nodeInfo) in nodesToTerminate}
664            nodeToNodeInfo = nodesToTerminate
665        else:
666            # Without load info all we can do is sort instances by time left in billing cycle.
667            nodeToNodeInfo = sorted(nodeToNodeInfo, key=lambda x: x.remainingBillingInterval())
668            nodeToNodeInfo = [instance for instance in islice(nodeToNodeInfo, numNodes)]
669        logger.debug('Terminating %i instance(s).', len(nodeToNodeInfo))
670        if nodeToNodeInfo:
671            for node in nodeToNodeInfo:
672                if node.privateIP in self.ignoredNodes:
673                    self.ignoredNodes.remove(node.privateIP)
674                    self.leader.batchSystem.unignoreNode(node.privateIP)
675            self.provisioner.terminateNodes(nodeToNodeInfo)
676        return len(nodeToNodeInfo)
677
678    def _terminateIgnoredNodes(self):
679        #Try to terminate any straggling nodes that we designated for
680        #termination, but which still has workers running
681        nodeToNodeInfo = self.getNodes(preemptable=None)
682
683        #Remove any nodes that have already been terminated from the list
684        # of ignored nodes
685        allNodeIPs = [node.privateIP for node in nodeToNodeInfo]
686        terminatedIPs = set([ip for ip in self.ignoredNodes if ip not in allNodeIPs])
687        for ip in terminatedIPs:
688            self.ignoredNodes.remove(ip)
689            self.leader.batchSystem.unignoreNode(ip)
690
691        logger.debug("There are %i nodes being ignored by the batch system, "
692                    "checking if they can be terminated" % len(self.ignoredNodes))
693        nodeToNodeInfo = {node:nodeToNodeInfo[node] for node in nodeToNodeInfo
694                          if node.privateIP in self.ignoredNodes}
695        nodeToNodeInfo = {node:nodeToNodeInfo[node] for node in nodeToNodeInfo
696                          if nodeToNodeInfo[node] is not None and nodeToNodeInfo[node].workers < 1}
697
698        for node in nodeToNodeInfo:
699            self.ignoredNodes.remove(node.privateIP)
700            self.leader.batchSystem.unignoreNode(node.privateIP)
701        if len(nodeToNodeInfo) > 0:
702            logger.debug("Terminating %i nodes that were being ignored by the batch system."
703                        "" % len(nodeToNodeInfo))
704            self.provisioner.terminateNodes(nodeToNodeInfo)
705
706    def chooseNodes(self, nodeToNodeInfo, force=False, preemptable=False):
707        nodesToTerminate = []
708        for node, nodeInfo in list(nodeToNodeInfo.items()):
709            if node is None:
710                logger.debug("Node with info %s was not found in our node list", nodeInfo)
711                continue
712            staticNodes = self.getStaticNodes(preemptable)
713            prefix = 'non-' if not preemptable else ''
714            if node.privateIP in staticNodes:
715                # we don't want to automatically terminate any statically
716                # provisioned nodes
717                logger.debug("Found %s in %spreemptable static nodes", node.privateIP, prefix)
718                continue
719            else:
720                logger.debug("Did not find %s in %spreemptable static nodes", node.privateIP, prefix)
721            nodesToTerminate.append((node, nodeInfo))
722        # Sort nodes by number of workers and time left in billing cycle
723        nodesToTerminate.sort(key=lambda node_nodeInfo: (
724            node_nodeInfo[1].workers if node_nodeInfo[1] else 1, node_nodeInfo[0].remainingBillingInterval()))
725        return nodesToTerminate
726
727    def getNodes(self, preemptable):
728        """
729        Returns a dictionary mapping node identifiers of preemptable or non-preemptable nodes to
730        NodeInfo objects, one for each node.
731
732        This method is the definitive source on nodes in cluster, & is responsible for consolidating
733        cluster state between the provisioner & batch system.
734
735        :param bool preemptable: If True (False) only (non-)preemptable nodes will be returned.
736               If None, all nodes will be returned.
737
738        :rtype: dict[Node, NodeInfo]
739        """
740        def _getInfo(allMesosNodes, ip):
741            info = None
742            try:
743                info = allMesosNodes[ip]
744            except KeyError:
745                # never seen by mesos - 1 of 3 possibilities:
746                # 1) node is still launching mesos & will come online soon
747                # 2) no jobs have been assigned to this worker. This means the executor was never
748                #    launched, so we don't even get an executorInfo back indicating 0 workers running
749                # 3) mesos crashed before launching, worker will never come online
750                # In all 3 situations it's safe to fake executor info with 0 workers, since in all
751                # cases there are no workers running.
752                info = NodeInfo(coresTotal=1, coresUsed=0, requestedCores=0,
753                                memoryTotal=1, memoryUsed=0, requestedMemory=0,
754                                workers=0)
755            else:
756                # Node was tracked but we haven't seen this in the last 10 minutes
757                inUse = self.leader.batchSystem.nodeInUse(ip)
758                if not inUse and info:
759                    # The node hasn't reported in the last 10 minutes & last we know
760                    # there weren't any tasks running. We will fake executorInfo with no
761                    # worker to reflect this, since otherwise this node will never
762                    # be considered for termination
763                    info.workers = 0
764                else:
765                    pass
766                    # despite the node not reporting to mesos jobs may still be running
767                    # so we can't terminate the node
768            return info
769
770        allMesosNodes = self.leader.batchSystem.getNodes(preemptable, timeout=None)
771        recentMesosNodes = self.leader.batchSystem.getNodes(preemptable)
772        provisionerNodes = self.provisioner.getProvisionedWorkers(preemptable=preemptable)
773
774        if len(recentMesosNodes) != len(provisionerNodes):
775            logger.debug("Consolidating state between mesos and provisioner")
776        nodeToInfo = {}
777        # fixme: what happens if awsFilterImpairedNodes is used?
778        # if this assertion is false it means that user-managed nodes are being
779        # used that are outside the provisioner's control
780        # this would violate many basic assumptions in autoscaling so it currently not allowed
781        for node, ip in ((node, node.privateIP) for node in provisionerNodes):
782            info = None
783            if ip not in recentMesosNodes:
784                logger.debug("Worker node at %s is not reporting executor information", ip)
785                # we don't have up to date information about the node
786                info = _getInfo(allMesosNodes, ip)
787            else:
788                # mesos knows about the ip & we have up to date information - easy!
789                info = recentMesosNodes[ip]
790            # add info to dict to return
791            nodeToInfo[node] = info
792        return nodeToInfo
793
794    def shutDown(self):
795        logger.debug('Forcing provisioner to reduce cluster size to zero.')
796        for nodeShape in self.nodeShapes:
797            preemptable = nodeShape.preemptable
798            instance_type = self.nodeShapeToType[nodeShape]
799            self.setNodeCount(instance_type=instance_type, numNodes=0, preemptable=preemptable, force=True)
800
class ScalerThread(ExceptionalThread):
    """
    A thread that automatically scales the number of either preemptable or non-preemptable worker
    nodes according to the resource requirements of the queued jobs.
    The scaling calculation is essentially as follows: start with 0 estimated worker nodes. For
    each queued job, check if we expect it can be scheduled into a worker node before a certain time
    (currently one hour). Otherwise, attempt to add a single new node of the smallest type that
    can fit that job.
    At each scaling decision point a comparison between the current, C, and newly estimated
    number of nodes is made. If the absolute difference is less than beta * C then no change
    is made, else the size of the cluster is adapted. The beta factor is an inertia parameter
    that prevents continual fluctuations in the number of nodes.
    """
    def __init__(self, provisioner, leader, config):
        """
        Create the scaler thread. The thread is not started here.

        :param provisioner: provisioner used to add and remove worker nodes; its ``clusterName``
               names the cluster statistics output.
        :param leader: the leader, which supplies the batch system and the queued jobs to scale for.
        :param config: workflow configuration; when ``config.clusterStats`` is set, statistics are
               gathered for both preemptable and non-preemptable nodes.
        """
        super(ScalerThread, self).__init__(name='scaler')
        self.scaler = ClusterScaler(provisioner, leader, config)

        # Indicates that the scaling thread should shutdown
        self.stop = False

        self.stats = None
        if config.clusterStats:
            logger.debug("Starting up cluster statistics...")
            self.stats = ClusterStats(leader.config.clusterStats,
                                      leader.batchSystem,
                                      provisioner.clusterName)
            for preemptable in [True, False]:
                self.stats.startStats(preemptable=preemptable)
            logger.debug("...Cluster stats started.")

    def check(self):
        """
        Attempt to join any existing scaler threads that may have died or finished. This ensures
        any exceptions raised in the threads are propagated in a timely fashion.
        """
        try:
            self.join(timeout=0)
        except Exception as e:
            logger.exception(e)
            raise

    def shutdown(self):
        """
        Stop the scaling loop, shut down statistics gathering, and wait for this thread to exit.
        The loop scales the cluster down to zero on its way out.
        """
        self.stop = True
        if self.stats:
            self.stats.shutDownStats()
        self.join()

    def addCompletedJob(self, job, wallTime):
        """Forward a finished job's wall time to the scaler's runtime bookkeeping."""
        self.scaler.addCompletedJob(job, wallTime)

    def tryRun(self):
        """
        Repeatedly estimate the required nodes per shape and resize the cluster until
        :meth:`shutdown` sets ``stop``, then scale the cluster down to zero. Each iteration is
        wrapped in ``throttle(config.scaleInterval)``.
        """
        while not self.stop:
            with throttle(self.scaler.config.scaleInterval):
                try:
                    queuedJobs = self.scaler.leader.getJobs()
                    queuedJobShapes = [
                        Shape(wallTime=self.scaler.getAverageRuntime(
                            jobName=job.jobName,
                            service=isinstance(job, ServiceJobDescription)),
                            memory=job.memory,
                            cores=job.cores,
                            disk=job.disk,
                            preemptable=job.preemptable) for job in queuedJobs]
                    currentNodeCounts = {}
                    for nodeShape in self.scaler.nodeShapes:
                        instance_type = self.scaler.nodeShapeToType[nodeShape]
                        currentNodeCounts[nodeShape] = len(
                            self.scaler.leader.provisioner.getProvisionedWorkers(instance_type=instance_type,
                                                                                 preemptable=nodeShape.preemptable))
                    estimatedNodeCounts = self.scaler.getEstimatedNodeCounts(queuedJobShapes,
                                                                             currentNodeCounts)
                    self.scaler.updateClusterSize(estimatedNodeCounts)
                    if self.stats:
                        self.stats.checkStats()
                except Exception:
                    # Best effort: ordinary errors are logged and the loop continues. Catching
                    # Exception rather than using a bare except lets SystemExit/KeyboardInterrupt
                    # end the thread instead of being swallowed.
                    logger.exception("Exception encountered in scaler thread. Making a best-effort "
                                     "attempt to keep going, but things may go wrong from now on.")
        self.scaler.shutDown()
885
class ClusterStats(object):
    """
    Periodically sample per-node usage reported by a scalable batch system, and on shutdown dump
    the collected samples as JSON into the ``path`` directory.
    """
    def __init__(self, path, batchSystem, clusterName):
        import threading  # local import: only this class needs it

        logger.debug("Initializing cluster statistics")
        self.stats = {}
        self.statsThreads = []
        self.statsPath = path
        self.stop = False
        # Set on shutdown so gathering threads wake from their sampling delay immediately instead
        # of sleeping out the rest of the interval (up to 60 s) before noticing ``stop``.
        self._stopEvent = threading.Event()
        self.clusterName = clusterName
        self.batchSystem = batchSystem
        self.scaleable = isinstance(self.batchSystem, AbstractScalableBatchSystem) \
            if batchSystem else False

    def shutDownStats(self):
        """
        Stop the gathering threads and wait for them. If a stats path was given and the batch
        system is scalable, write the collected stats to ``<path>/<clusterName>-statsNNN.json``,
        using the first unused counter. Calling this a second time does nothing.
        """
        if self.stop:
            return

        def getFileName():
            extension = '.json'
            file = '%s-stats' % self.clusterName
            counter = 0
            while True:
                suffix = str(counter).zfill(3) + extension
                fullName = os.path.join(self.statsPath, file + suffix)
                if not os.path.exists(fullName):
                    return fullName
                counter += 1

        # Always signal the threads, even when nothing will be written. Otherwise a scalable batch
        # system with no stats path would leave the gathering threads looping forever.
        self.stop = True
        self._stopEvent.set()
        for thread in self.statsThreads:
            # join() also propagates any error raised in the thread (see checkStats).
            thread.join()
        if self.statsPath and self.scaleable:
            fileName = getFileName()
            with open(fileName, 'w') as f:
                json.dump(self.stats, f)

    def startStats(self, preemptable):
        """Start a background thread gathering stats for (non-)preemptable nodes."""
        thread = ExceptionalThread(target=self._gatherStats, args=[preemptable])
        thread.start()
        self.statsThreads.append(thread)

    def checkStats(self):
        """Non-blockingly join the gathering threads to surface any errors they raised."""
        for thread in self.statsThreads:
            # propagate any errors raised in the threads execution
            thread.join(timeout=0)

    def _gatherStats(self, preemptable):
        """
        Sample ``batchSystem.getNodes(preemptable)`` roughly once a minute until stopped. When the
        loop ends, store the samples in ``self.stats`` under 'Preemptable' or 'Non-preemptable'.
        Does nothing if the batch system is not scalable.
        """
        def toDict(nodeInfo):
            # convert NodeInfo object to dict to improve JSON output
            return dict(memory=nodeInfo.memoryUsed,
                        cores=nodeInfo.coresUsed,
                        memoryTotal=nodeInfo.memoryTotal,
                        coresTotal=nodeInfo.coresTotal,
                        requestedCores=nodeInfo.requestedCores,
                        requestedMemory=nodeInfo.requestedMemory,
                        workers=nodeInfo.workers,
                        time=time.time()  # add time stamp
                        )
        if not self.scaleable:
            return
        logger.debug("Starting to gather statistics")
        stats = {}
        try:
            while not self.stop:
                nodeInfo = self.batchSystem.getNodes(preemptable)
                for nodeIP, nodeStats in list(nodeInfo.items()):
                    if nodeStats is not None:
                        # Append to the node's history, creating it on first sight.
                        stats.setdefault(nodeIP, []).append(toDict(nodeStats))
                # Wait up to 60 s between samples, returning early once shutDownStats() fires.
                self._stopEvent.wait(60)
        finally:
            threadName = 'Preemptable' if preemptable else 'Non-preemptable'
            logger.debug('%s provisioner stats thread shut down successfully.', threadName)
            self.stats[threadName] = stats
963