1"""The Python scheduler for rich scheduling.
2
3The Pure ZMQ scheduler does not allow routing schemes other than LRU,
4nor does it check msg_id DAG dependencies. For those, a slightly slower
5Python Scheduler exists.
6"""
7
8# Copyright (c) IPython Development Team.
9# Distributed under the terms of the Modified BSD License.
10
11import logging
12import time
13
14from collections import deque
15from random import randint, random
16from types import FunctionType
17
18try:
19    import numpy
20except ImportError:
21    numpy = None
22
23import zmq
24from zmq.eventloop import zmqstream
25
26# local imports
27from decorator import decorator
28from traitlets.config.application import Application
29from traitlets.config.loader import Config
30from traitlets import Instance, Dict, List, Set, Integer, Enum, CBytes, observe
31from ipython_genutils.py3compat import cast_bytes
32
33from ipyparallel import error, util
34from ipyparallel.factory import SessionFactory
35from ipyparallel.util import connect_logger, local_logger, ioloop
36
37from .dependency import Dependency
38
39@decorator
40def logged(f,self,*args,**kwargs):
41    # print ("#--------------------")
42    self.log.debug("scheduler::%s(*%s,**%s)", f.__name__, args, kwargs)
43    # print ("#--")
44    return f(self,*args, **kwargs)

#----------------------------------------------------------------------
# Chooser functions
#----------------------------------------------------------------------

def plainrandom(loads):
    """Plain random pick."""
    n = len(loads)
    return randint(0, n - 1)

def lru(loads):
    """Always pick the front of the line.

    The content of `loads` is ignored.

    Assumes LRU ordering of loads, with oldest first.
    """
    return 0

def twobin(loads):
    """Pick two at random, use the LRU of the two.

    The content of `loads` is ignored.

    Assumes LRU ordering of loads, with oldest first.
    This is the classic 'power of two random choices' trick,
    applied to LRU order rather than to reported load.
    """
    n = len(loads)
    a = randint(0, n - 1)
    b = randint(0, n - 1)
    return min(a, b)

def weighted(loads):
    """Pick two at random using inverse load as weight.

    Return the less loaded of the two.
    Requires numpy.
    """
    # weight a load of 0 a million times more than a load of 1:
    weights = 1. / (1e-6 + numpy.array(loads))
    sums = weights.cumsum()
    t = sums[-1]
    x = random() * t
    y = random() * t
    idx = 0
    idy = 0
    while sums[idx] < x:
        idx += 1
    while sums[idy] < y:
        idy += 1
    if weights[idy] > weights[idx]:
        return idy
    else:
        return idx

def leastload(loads):
    """Always choose the lowest load.

    If the lowest load occurs more than once, the first
    occurrence will be used.  If loads has LRU ordering, this means
    the LRU of those with the lowest load is chosen.
    """
    return loads.index(min(loads))
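

# A hedged illustration (not part of the scheduler and never called by it):
# empirically compare the chooser functions on a fixed load vector. The name
# `_demo_choosers` is ours, for demonstration only; lower-loaded slots should
# win more often under `weighted`, while `lru` always takes the front slot.
def _demo_choosers(trials=1000):
    loads = [5, 0, 2, 1]
    schemes = [plainrandom, lru, twobin, leastload]
    if numpy is not None:
        schemes.append(weighted)  # weighted needs numpy
    for scheme in schemes:
        counts = [0] * len(loads)
        for _ in range(trials):
            counts[scheme(loads)] += 1
        print(scheme.__name__, counts)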

#---------------------------------------------------------------------
# Classes
#---------------------------------------------------------------------


# store empty default dependency:
MET = Dependency([])

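# A sketch of the Dependency contract as used below (see ipyparallel's
# dependency module); names and values are illustrative, not executed:
#
#   dep = Dependency(['a', 'b'], all=True, success=True, failure=False)
#   dep.check(completed, failed)        # True once the dependency is satisfied
#   dep.unreachable(completed, failed)  # True if it can never be satisfied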

class Job(object):
    """Simple container for a job"""
    def __init__(self, msg_id, raw_msg, idents, msg, header, metadata,
                 targets, after, follow, timeout):
        self.msg_id = msg_id
        self.raw_msg = raw_msg
        self.idents = idents
        self.msg = msg
        self.header = header
        self.metadata = metadata
        self.targets = targets
        self.after = after
        self.follow = follow
        self.timeout = timeout

        self.removed = False # used for lazy-delete from sorted queue
        self.timestamp = time.time()
        self.timeout_id = 0
        self.blacklist = set()

    def __lt__(self, other):
        return self.timestamp < other.timestamp

    def __cmp__(self, other):
        # Python 2 fallback; Python 3 ignores __cmp__ and uses __lt__
        return cmp(self.timestamp, other.timestamp)

    @property
    def dependents(self):
        return self.follow.union(self.after)


class TaskScheduler(SessionFactory):
    """Python TaskScheduler object.

    This is the simplest object that supports msg_id based
    DAG dependencies. *Only* task msg_ids are checked, not
    msg_ids of jobs submitted via the MUX queue.

    """

    hwm = Integer(1, config=True,
        help="""specify the High Water Mark (HWM) for the downstream
        socket in the Task scheduler. This is the maximum number
        of allowed outstanding tasks on each engine.

        The default (1) means that only one task can be outstanding on each
        engine.  Setting TaskScheduler.hwm=0 means there is no limit, and the
        engines continue to be assigned tasks while they are working,
        effectively hiding network latency behind computation, but can result
        in an imbalance of work when submitting many heterogeneous tasks all at
        once.  Any positive value greater than one is a compromise between the
        two.
        """
    )
    scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
        'leastload', config=True,
        help="""select the task scheduler scheme  [default: 'leastload']
        Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin', 'leastload'"""
    )
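
    # A sketch of how these are typically configured (e.g. in an
    # ipcontroller_config.py file; values here are illustrative only):
    #
    #   c.TaskScheduler.hwm = 0              # unlimited pipelining per engine
    #   c.TaskScheduler.scheme_name = 'twobin'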

    @observe('scheme_name')
    def _scheme_name_changed(self, change):
        self.log.debug("Using scheme %r" % change['new'])
        self.scheme = globals()[change['new']]

    # input arguments:
    scheme = Instance(FunctionType) # function for determining the destination
    def _scheme_default(self):
        return leastload
    client_stream = Instance(zmqstream.ZMQStream, allow_none=True) # client-facing stream
    engine_stream = Instance(zmqstream.ZMQStream, allow_none=True) # engine-facing stream
    notifier_stream = Instance(zmqstream.ZMQStream, allow_none=True) # hub-facing sub stream
    mon_stream = Instance(zmqstream.ZMQStream, allow_none=True) # hub-facing pub stream
    query_stream = Instance(zmqstream.ZMQStream, allow_none=True) # hub-facing DEALER stream

    # internals:
    queue = Instance(deque) # sorted list of Jobs
    def _queue_default(self):
        return deque()
    queue_map = Dict() # dict by msg_id of Jobs (for O(1) access to the Queue)
    graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
    retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
    pending = Dict() # dict by engine_uuid of submitted tasks
    completed = Dict() # dict by engine_uuid of completed tasks
    failed = Dict() # dict by engine_uuid of failed tasks
    destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
    clients = Dict() # dict by msg_id for who submitted the task
    targets = List() # list of target IDENTs
    loads = List() # list of engine loads
    all_completed = Set() # set of all completed tasks
    all_failed = Set() # set of all failed tasks
    all_done = Set() # set of all finished tasks=union(completed,failed)
    all_ids = Set() # set of all submitted task IDs

    ident = CBytes() # ZMQ identity. This should just be self.session.session
                     # but ensure Bytes
    def _ident_default(self):
        return self.session.bsession

    def start(self):
        self.query_stream.on_recv(self.dispatch_query_reply)
        self.session.send(self.query_stream, "connection_request", {})

        self.engine_stream.on_recv(self.dispatch_result, copy=False)
        self.client_stream.on_recv(self.dispatch_submission, copy=False)

        self._notification_handlers = dict(
            registration_notification = self._register_engine,
            unregistration_notification = self._unregister_engine
        )
        self.notifier_stream.on_recv(self.dispatch_notification)
        self.log.info("Scheduler started [%s]" % self.scheme_name)

    def resume_receiving(self):
        """Resume accepting jobs."""
        self.client_stream.on_recv(self.dispatch_submission, copy=False)

    def stop_receiving(self):
        """Stop accepting jobs while there are no engines.
        Leave them in the ZMQ queue."""
        self.client_stream.on_recv(None)

    #-----------------------------------------------------------------------
    # [Un]Registration Handling
    #-----------------------------------------------------------------------


    def dispatch_query_reply(self, msg):
        """handle reply to our initial connection request"""
        try:
            idents, msg = self.session.feed_identities(msg)
        except ValueError:
            self.log.warning("task::Invalid Message: %r", msg)
            return
        try:
            msg = self.session.deserialize(msg)
        except ValueError:
            self.log.warning("task::Unauthorized message from: %r", idents)
            return

        content = msg['content']
        for uuid in content.get('engines', {}).values():
            self._register_engine(cast_bytes(uuid))


    @util.log_errors
    def dispatch_notification(self, msg):
        """dispatch register/unregister events."""
        try:
            idents, msg = self.session.feed_identities(msg)
        except ValueError:
            self.log.warning("task::Invalid Message: %r", msg)
            return
        try:
            msg = self.session.deserialize(msg)
        except ValueError:
            self.log.warning("task::Unauthorized message from: %r", idents)
            return

        msg_type = msg['header']['msg_type']

        handler = self._notification_handlers.get(msg_type, None)
        if handler is None:
            self.log.error("Unhandled message type: %r", msg_type)
        else:
            try:
                handler(cast_bytes(msg['content']['uuid']))
            except Exception:
                self.log.error("task::Invalid notification msg: %r", msg, exc_info=True)

    def _register_engine(self, uid):
        """New engine with ident `uid` became available."""
        # head of the line:
        self.targets.insert(0, uid)
        self.loads.insert(0, 0)

        # initialize sets
        self.completed[uid] = set()
        self.failed[uid] = set()
        self.pending[uid] = {}

        # rescan the graph:
        self.update_graph(None)

    def _unregister_engine(self, uid):
        """Existing engine with ident `uid` became unavailable."""
        # handle any potentially finished tasks:
        self.engine_stream.flush()

        # don't pop destinations, because they might be used later

        # prevent this engine from receiving work
        idx = self.targets.index(uid)
        self.targets.pop(idx)
        self.loads.pop(idx)

        # wait 5 seconds before cleaning up pending jobs, since the results might
        # still be incoming
        if self.pending[uid]:
            self.loop.add_timeout(
                self.loop.time() + 5,
                lambda: self.handle_stranded_tasks(uid),
            )
        else:
            self.completed.pop(uid)
            self.failed.pop(uid)

    def handle_stranded_tasks(self, engine):
        """Deal with jobs resident in an engine that died."""
        lost = self.pending[engine]
        for msg_id in list(lost.keys()):
            if msg_id not in lost:
                # prevent double-handling of messages
                continue

            raw_msg = lost[msg_id].raw_msg
            idents, msg = self.session.feed_identities(raw_msg, copy=False)
            parent = self.session.unpack(msg[1].bytes)
            idents = [engine, idents[0]]

            # build fake error reply
            try:
                raise error.EngineError("Engine %r died while running task %r" % (engine, msg_id))
            except Exception:
                content = error.wrap_exception()
            # build fake metadata
            md = dict(
                status=u'error',
                engine=engine.decode('ascii'),
                date=util.utcnow(),
            )
            msg = self.session.msg('apply_reply', content, parent=parent, metadata=md)
            raw_reply = list(map(zmq.Message, self.session.serialize(msg, ident=idents)))
            # and dispatch it
            self.dispatch_result(raw_reply)

        # finally scrub completed/failed lists
        self.completed.pop(engine)
        self.failed.pop(engine)


    #-----------------------------------------------------------------------
    # Job Submission
    #-----------------------------------------------------------------------


    @util.log_errors
    def dispatch_submission(self, raw_msg):
        """Dispatch job submission to appropriate handlers."""
        # ensure targets up to date:
        self.notifier_stream.flush()
        try:
            idents, msg = self.session.feed_identities(raw_msg, copy=False)
            msg = self.session.deserialize(msg, content=False, copy=False)
        except Exception:
            self.log.error("task::Invalid task msg: %r", raw_msg, exc_info=True)
            return

        # send to monitor
        self.mon_stream.send_multipart([b'intask'] + raw_msg, copy=False)

        header = msg['header']
        md = msg['metadata']
        msg_id = header['msg_id']
        self.all_ids.add(msg_id)

        # get targets as a set of bytes objects
        # from a list of unicode objects
        targets = md.get('targets', [])
        targets = set(map(cast_bytes, targets))

        retries = md.get('retries', 0)
        self.retries[msg_id] = retries

        # time dependencies
        after = md.get('after', None)
        if after:
            after = Dependency(after)
            if after.all:
                if after.success:
                    after = Dependency(after.difference(self.all_completed),
                                success=after.success,
                                failure=after.failure,
                                all=after.all,
                    )
                if after.failure:
                    after = Dependency(after.difference(self.all_failed),
                                success=after.success,
                                failure=after.failure,
                                all=after.all,
                    )
            if after.check(self.all_completed, self.all_failed):
                # recast as empty set, if `after` already met,
                # to prevent unnecessary set comparisons
                after = MET
        else:
            after = MET

        # location dependencies
        follow = Dependency(md.get('follow', []))

        timeout = md.get('timeout', None)
        if timeout:
            timeout = float(timeout)

        job = Job(msg_id=msg_id, raw_msg=raw_msg, idents=idents, msg=msg,
                  header=header, targets=targets, after=after, follow=follow,
                  timeout=timeout, metadata=md,
        )
        # validate and reduce dependencies:
        for dep in (after, follow):
            if not dep: # empty dependency
                continue
            # check valid:
            if msg_id in dep or dep.difference(self.all_ids):
                self.queue_map[msg_id] = job
                return self.fail_unreachable(msg_id, error.InvalidDependency)
            # check if unreachable:
            if dep.unreachable(self.all_completed, self.all_failed):
                self.queue_map[msg_id] = job
                return self.fail_unreachable(msg_id)

        if after.check(self.all_completed, self.all_failed):
            # time deps already met, try to run
            if not self.maybe_run(job):
                # can't run yet
                if msg_id not in self.all_failed:
                    # could have failed as unreachable
                    self.save_unmet(job)
        else:
            self.save_unmet(job)
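
    # The scheduling metadata consumed above ('targets', 'retries', 'after',
    # 'follow', 'timeout') is set client-side; a hedged sketch of what a
    # submission might look like from a LoadBalancedView (names like fn,
    # ar0, ar1 are placeholders):
    #
    #   view = client.load_balanced_view()
    #   view.set_flags(after=[ar0], follow=[ar1], timeout=10, retries=2)
    #   ar = view.apply(fn, *args)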

    def job_timeout(self, job, timeout_id):
        """callback for a job's timeout.

        The job may or may not have been run at this point.
        """
        if job.timeout_id != timeout_id:
            # not the most recent call
            return
        now = time.time()
        if job.timeout >= (now + 1):
            self.log.warning("task %s timeout fired prematurely: %s > %s",
                job.msg_id, job.timeout, now
            )
        if job.msg_id in self.queue_map:
            # still waiting, but ran out of time
            self.log.info("task %r timed out", job.msg_id)
            self.fail_unreachable(job.msg_id, error.TaskTimeout)

    def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
        """a task has become unreachable, send a reply with an ImpossibleDependency
        error."""
        if msg_id not in self.queue_map:
            self.log.error("task %r already failed!", msg_id)
            return
        job = self.queue_map.pop(msg_id)
        # lazy-delete from the queue
        job.removed = True
        for mid in job.dependents:
            if mid in self.graph:
                self.graph[mid].remove(msg_id)

        try:
            raise why()
        except Exception:
            content = error.wrap_exception()
        self.log.debug("task %r failing as unreachable with: %s", msg_id, content['ename'])

        self.all_done.add(msg_id)
        self.all_failed.add(msg_id)

        msg = self.session.send(self.client_stream, 'apply_reply', content,
                                parent=job.header, ident=job.idents)
        self.session.send(self.mon_stream, msg, ident=[b'outtask'] + job.idents)

        self.update_graph(msg_id, success=False)

    def available_engines(self):
        """return a list of available engine indices based on HWM"""
        if not self.hwm:
            return list(range(len(self.targets)))
        available = []
        for idx in range(len(self.targets)):
            if self.loads[idx] < self.hwm:
                available.append(idx)
        return available

    def maybe_run(self, job):
        """check location dependencies, and run if they are met."""
        msg_id = job.msg_id
        self.log.debug("Attempting to assign task %s", msg_id)
        available = self.available_engines()
        if not available:
            # no engines, definitely can't run
            return False

        if job.follow or job.targets or job.blacklist or self.hwm:
            # we need a can_run filter
            def can_run(idx):
                # check hwm
                if self.hwm and self.loads[idx] == self.hwm:
                    return False
                target = self.targets[idx]
                # check blacklist
                if target in job.blacklist:
                    return False
                # check targets
                if job.targets and target not in job.targets:
                    return False
                # check follow
                return job.follow.check(self.completed[target], self.failed[target])

            indices = list(filter(can_run, available))

            if not indices:
                # couldn't run
                if job.follow.all:
                    # check follow for impossibility
                    dests = set()
                    relevant = set()
                    if job.follow.success:
                        relevant = self.all_completed
                    if job.follow.failure:
                        relevant = relevant.union(self.all_failed)
                    for m in job.follow.intersection(relevant):
                        dests.add(self.destinations[m])
                    if len(dests) > 1:
                        self.queue_map[msg_id] = job
                        self.fail_unreachable(msg_id)
                        return False
                if job.targets:
                    # check blacklist+targets for impossibility
                    job.targets.difference_update(job.blacklist)
                    if not job.targets or not job.targets.intersection(self.targets):
                        self.queue_map[msg_id] = job
                        self.fail_unreachable(msg_id)
                        return False
                return False
        else:
            indices = None

        self.submit_task(job, indices)
        return True

    def save_unmet(self, job):
        """Save a message for later submission when its dependencies are met."""
        msg_id = job.msg_id
        self.log.debug("Adding task %s to the queue", msg_id)
        self.queue_map[msg_id] = job
        self.queue.append(job)
        # track the ids in follow or after, but not those already finished
        for dep_id in job.after.union(job.follow).difference(self.all_done):
            if dep_id not in self.graph:
                self.graph[dep_id] = set()
            self.graph[dep_id].add(msg_id)

        # schedule timeout callback
        if job.timeout:
            timeout_id = job.timeout_id = job.timeout_id + 1
            self.loop.add_timeout(time.time() + job.timeout,
                lambda: self.job_timeout(job, timeout_id)
            )


    def submit_task(self, job, indices=None):
        """Submit a task to any of a subset of our targets."""
        if indices:
            loads = [self.loads[i] for i in indices]
        else:
            loads = self.loads
        idx = self.scheme(loads)
        if indices:
            idx = indices[idx]
        target = self.targets[idx]
        # send job to the engine
        self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
        self.engine_stream.send_multipart(job.raw_msg, copy=False)
        # update load
        self.add_job(idx)
        self.pending[target][job.msg_id] = job
        # notify Hub
        content = dict(msg_id=job.msg_id, engine_id=target.decode('ascii'))
        self.session.send(self.mon_stream, 'task_destination', content=content,
                        ident=[b'tracktask', self.ident])


    #-----------------------------------------------------------------------
    # Result Handling
    #-----------------------------------------------------------------------


    @util.log_errors
    def dispatch_result(self, raw_msg):
        """dispatch method for result replies"""
        try:
            idents, msg = self.session.feed_identities(raw_msg, copy=False)
            msg = self.session.deserialize(msg, content=False, copy=False)
            engine = idents[0]
            try:
                idx = self.targets.index(engine)
            except ValueError:
                pass # skip load-update for dead engines
            else:
                self.finish_job(idx)
        except Exception:
            self.log.error("task::Invalid result: %r", raw_msg, exc_info=True)
            return

        md = msg['metadata']
        parent = msg['parent_header']
        if md.get('dependencies_met', True):
            success = (md['status'] == 'ok')
            msg_id = parent['msg_id']
            retries = self.retries[msg_id]
            if not success and retries > 0:
                # failed, but retries remain: resubmit instead of relaying
                self.retries[msg_id] = retries - 1
                self.handle_unmet_dependency(idents, parent)
            else:
                del self.retries[msg_id]
                # relay to client and update graph
                self.handle_result(idents, parent, raw_msg, success)
                # send to Hub monitor
                self.mon_stream.send_multipart([b'outtask'] + raw_msg, copy=False)
        else:
            self.handle_unmet_dependency(idents, parent)

    def handle_result(self, idents, parent, raw_msg, success=True):
        """handle a real task result, either success or failure"""
        # first, relay result to client
        engine = idents[0]
        client = idents[1]
        # swap_ids for ROUTER-ROUTER mirror
        raw_msg[:2] = [client, engine]
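        # (the client identity must come first so our client-facing ROUTER
        # routes the reply back to the submitter; the engine ident follows)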
        self.client_stream.send_multipart(raw_msg, copy=False)
        # now, update our data structures
        msg_id = parent['msg_id']
        self.pending[engine].pop(msg_id)
        if success:
            self.completed[engine].add(msg_id)
            self.all_completed.add(msg_id)
        else:
            self.failed[engine].add(msg_id)
            self.all_failed.add(msg_id)
        self.all_done.add(msg_id)
        self.destinations[msg_id] = engine

        self.update_graph(msg_id, success)

    def handle_unmet_dependency(self, idents, parent):
        """handle an unmet dependency"""
        engine = idents[0]
        msg_id = parent['msg_id']

        job = self.pending[engine].pop(msg_id)
        job.blacklist.add(engine)

        if job.blacklist == job.targets:
            self.queue_map[msg_id] = job
            self.fail_unreachable(msg_id)
        elif not self.maybe_run(job):
            # resubmit failed
            if msg_id not in self.all_failed:
                # put it back in our dependency tree
                self.save_unmet(job)

        if self.hwm:
            try:
                idx = self.targets.index(engine)
            except ValueError:
                pass # skip load-update for dead engines
            else:
                if self.loads[idx] == self.hwm - 1:
                    # this engine just dropped below its HWM; rescan the whole
                    # queue in case a job can now run
                    self.update_graph(None)

    def update_graph(self, dep_id=None, success=True):
        """dep_id just finished. Update our dependency
        graph and submit any jobs that just became runnable.

        Called with dep_id=None to update the entire graph for hwm, but without finishing a task.
        """
        # update any jobs that depended on the dependency
        msg_ids = self.graph.pop(dep_id, [])

        # recheck *all* jobs if
        # a) we have HWM and an engine just became no longer full
        # or b) dep_id was given as None

        if dep_id is None or (self.hwm and any(load == self.hwm - 1 for load in self.loads)):
            jobs = self.queue
            using_queue = True
        else:
            using_queue = False
            jobs = deque(sorted(self.queue_map[msg_id] for msg_id in msg_ids))

        to_restore = []
        while jobs:
            job = jobs.popleft()
            if job.removed:
                continue
            msg_id = job.msg_id

            put_it_back = True

            if job.after.unreachable(self.all_completed, self.all_failed)\
                    or job.follow.unreachable(self.all_completed, self.all_failed):
                self.fail_unreachable(msg_id)
                put_it_back = False

            elif job.after.check(self.all_completed, self.all_failed): # time deps met, maybe run
                if self.maybe_run(job):
                    put_it_back = False
                    self.queue_map.pop(msg_id)
                    for mid in job.dependents:
                        if mid in self.graph:
                            self.graph[mid].remove(msg_id)

                    # abort the loop if we just filled up all of our engines.
                    # avoids an O(N) operation in situation of full queue,
                    # where graph update is triggered as soon as an engine becomes
                    # non-full, and all tasks after the first are checked,
                    # even though they can't run.
                    if not self.available_engines():
                        break

            if using_queue and put_it_back:
                # popped a job from the queue but it neither ran nor failed,
                # so we need to put it back when we are done
                # make sure to_restore preserves the same ordering
                to_restore.append(job)

        # put back any tasks we popped but didn't run
        if using_queue:
            # extendleft pushes items one at a time, which would reverse their
            # order; reverse first so the queue keeps its original ordering
            self.queue.extendleft(reversed(to_restore))

    #----------------------------------------------------------------------
    # methods to be overridden by subclasses
    #----------------------------------------------------------------------

    def add_job(self, idx):
        """Called after self.targets[idx] just got the job with header.
        Override in subclasses.  The default ordering is simple LRU.
        The default loads are the number of outstanding jobs."""
        self.loads[idx] += 1
        for lis in (self.targets, self.loads):
            lis.append(lis.pop(idx))
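        # e.g. targets=[A, B, C], idx=0 -> [B, C, A]: the engine that just
        # received work rotates to the back of the LRU line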

    def finish_job(self, idx):
        """Called after self.targets[idx] just finished a job.
        Override in subclasses."""
        self.loads[idx] -= 1


def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, reg_addr, config=None,
                        logname='root', log_url=None, loglevel=logging.DEBUG,
                        identity=b'task', in_thread=False):

    ZMQStream = zmqstream.ZMQStream

    if config:
        # unwrap dict back into Config
        config = Config(config)

    if in_thread:
        # use instance() to get the same Context/Loop as our parent
        ctx = zmq.Context.instance()
        loop = ioloop.IOLoop.current()
    else:
        # in a process, don't use instance()
        # for safety with multiprocessing
        ctx = zmq.Context()
        loop = ioloop.IOLoop()
    ins = ZMQStream(ctx.socket(zmq.ROUTER), loop)
    util.set_hwm(ins, 0)
    ins.setsockopt(zmq.IDENTITY, identity + b'_in')
    ins.bind(in_addr)

    outs = ZMQStream(ctx.socket(zmq.ROUTER), loop)
    util.set_hwm(outs, 0)
    outs.setsockopt(zmq.IDENTITY, identity + b'_out')
    outs.bind(out_addr)
    mons = zmqstream.ZMQStream(ctx.socket(zmq.PUB), loop)
    util.set_hwm(mons, 0)
    mons.connect(mon_addr)
    nots = zmqstream.ZMQStream(ctx.socket(zmq.SUB), loop)
    nots.setsockopt(zmq.SUBSCRIBE, b'')
    nots.connect(not_addr)

    querys = ZMQStream(ctx.socket(zmq.DEALER), loop)
    querys.connect(reg_addr)

    # setup logging.
    if in_thread:
        log = Application.instance().log
    else:
        if log_url:
            log = connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel)
        else:
            log = local_logger(logname, loglevel)

    scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
                            mon_stream=mons, notifier_stream=nots,
                            query_stream=querys,
                            loop=loop, log=log,
                            config=config)
    scheduler.start()
    if not in_thread:
        try:
            loop.start()
        except KeyboardInterrupt:
            scheduler.log.critical("Interrupted, exiting...")
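
# A hedged sketch of a standalone launch; the tcp addresses below are
# hypothetical placeholders (a real deployment reads them from the
# controller's connection files), and this call blocks in loop.start():
#
#   launch_scheduler(
#       in_addr='tcp://127.0.0.1:5555',   # clients submit here (ROUTER)
#       out_addr='tcp://127.0.0.1:5556',  # engines connect here (ROUTER)
#       mon_addr='tcp://127.0.0.1:5557',  # Hub monitor (PUB)
#       not_addr='tcp://127.0.0.1:5558',  # Hub notifications (SUB)
#       reg_addr='tcp://127.0.0.1:5559',  # Hub registration (DEALER)
#   )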