1#  _________________________________________________________________________
2#
3#  PyUtilib: A Python utility library.
4#  Copyright (c) 2008 Sandia Corporation.
5#  This software is distributed under the BSD License.
6#  Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
7#  the U.S. Government retains certain rights in this software.
8#  _________________________________________________________________________
9
10__all__ = ['Dispatcher', 'DispatcherServer']
11
12import os
13import sys
14import uuid
15from collections import defaultdict
16
17from pyutilib.pyro.util import get_nameserver, using_pyro3, using_pyro4
18from pyutilib.pyro.util import Pyro as _pyro
19from pyutilib.pyro.util import set_maxconnections, get_dispatchers
20
21if sys.version_info >= (3, 0):
22    import queue as Queue
23else:
24    import Queue
25
26from six import iteritems
27
# Start from the plain-Python defaults: an ordinary ``object`` base class
# and decorators that hand back their argument untouched.
base = object
oneway = lambda method: method
expose = lambda obj: obj

# Override only the pieces that the detected Pyro flavour provides.
if using_pyro3:
    # Pyro3 server objects derive from ObjBase; decorators stay no-ops.
    base = _pyro.core.ObjBase
elif using_pyro4:
    # Pyro4 supplies its own ``oneway`` and ``expose`` decorators.
    oneway = _pyro.oneway
    expose = _pyro.expose
40
41
def _clear_queue_threadsafe(q):
    """Discard every item currently held by the queue ``q``.

    Items are pulled off with non-blocking gets and each one that is
    actually removed is acknowledged with ``task_done()``.  If a get
    raises ``Queue.Empty`` (the queue emptied after the ``empty()``
    check), the exception is ignored and the emptiness check is repeated.
    """
    while not q.empty():
        try:
            q.get_nowait()
        except Queue.Empty:
            # Nothing was removed, so there is nothing to acknowledge.
            pass
        else:
            q.task_done()
49
50
class Dispatcher(base):
    """Broker of task and result queues, served to clients through Pyro.

    Tasks and results live in separate ``Queue.Queue`` objects keyed by
    a hashable queue "type" (``None`` is the default type).  A queue is
    created automatically the first time its type is referenced.  The
    object also keeps the set of registered worker names and the subset
    of those names that has been acquired.

    Keyword arguments:
        verbose: when true, print a message for many of the incoming
            requests (default ``False``).
        worker_limit: maximum number of simultaneously registered
            workers, or ``None`` (default) for no limit.

    Raises:
        ImportError: if the module-level ``_pyro`` handle is ``None``.
    """

    def __init__(self, **kwds):
        if _pyro is None:
            raise ImportError("Pyro or Pyro4 is not available")
        if using_pyro3:
            _pyro.core.ObjBase.__init__(self)
        # Maps queue type -> Queue; entries are created on first access.
        self._task_queue = defaultdict(Queue.Queue)
        self._result_queue = defaultdict(Queue.Queue)
        self._verbose = kwds.pop("verbose", False)
        self._registered_workers = set()
        self._acquired_workers = set()
        self._worker_limit = kwds.pop("worker_limit", None)
        if self._verbose:
            print("Verbose output enabled...")

    #
    # One-way methods (Pyro4 only)
    #

    @oneway
    def release_acquired_workers(self, names):
        """Mark the worker names in ``names`` as no longer acquired.

        Raises:
            ValueError: if any of the names is not currently registered.
        """
        names = set(names)
        if not names.issubset(self._registered_workers):
            raise ValueError("List contains one or more worker names that "
                             "were not registered")
        self._acquired_workers.difference_update(names)

    @oneway
    def unregister_worker(self, name):
        """Remove ``name`` from the set of registered workers.

        Raises:
            ValueError: if ``name`` is not currently registered.
        """
        if name not in self._registered_workers:
            raise ValueError(
                "Worker name '%s' has not been registered" % (name,))
        if self._verbose:
            print("Unregistering worker with name: %s" % (name,))
        # NOTE(review): the name is intentionally left in
        # _acquired_workers, as before; confirm whether an unregistered
        # worker should also be dropped from the acquired set.
        self._registered_workers.remove(name)

    @oneway
    def shutdown(self):
        """Ask the Pyro daemon serving this object to shut down."""
        print("Dispatcher received request to shut down - initiating...")
        if using_pyro3:
            self.getDaemon().shutdown()
        else:
            self._pyroDaemon.shutdown()

    @oneway
    def add_task(self, task, type=None):
        """Append ``task`` to the task queue of the given ``type``.

        ``task`` must support ``task['id']`` when verbose output is on.
        """
        if self._verbose:
            print("Received request to add task=<Task id=" + str(task['id']) +
                  ">; queue type=" + str(type))
        self._task_queue[type].put(task)

    # process a set of tasks in one shot - the input
    # is a dictionary from queue type (including None)
    # to a list of tasks to be added to that queue.
    @oneway
    def add_tasks(self, tasks):
        """Append several tasks; ``tasks`` maps queue type -> task list."""
        if self._verbose:
            print("Received request to add bulk task set. Task ids=%s" % (
                {task_type: [task['id'] for task in tasks[task_type]]
                 for task_type in tasks}))
        for task_type in tasks:
            task_queue = self._task_queue[task_type]
            for task in tasks[task_type]:
                task_queue.put(task)

    @oneway
    def add_result(self, result, type=None):
        """Append ``result`` to the result queue of the given ``type``."""
        if self._verbose:
            print("Received request to add result with "
                  "result=" + str(result) + "; queue type=" + str(type))
        self._result_queue[type].put(result)

    # process a set of results in one shot - the input
    # is a dictionary from queue type (including None)
    # to a list of results to be added to that queue.
    @oneway
    def add_results(self, results):
        """Append several results; ``results`` maps type -> result list.

        Each result must support ``result['id']`` when verbose output
        is on.
        """
        if self._verbose:
            print("Received request to add bulk result set for task ids=%s" %
                  ({result_type: [result['id']
                                  for result in results[result_type]]
                    for result_type in results}))
        for result_type in results:
            result_queue = self._result_queue[result_type]
            for result in results[result_type]:
                result_queue.put(result)

    #
    # Methods that do not return anything but are
    # not marked oneway for Pyro4 to avoid race conditions
    #

    def clear_queue(self, type=None):
        """Empty both the task and the result queue of ``type``."""
        if self._verbose:
            print("Received request to clear task and result "
                  "queues for queue type=" + str(type))
        # Indexing a defaultdict never raises KeyError (a missing queue
        # is created on the spot), so no exception handling is needed.
        _clear_queue_threadsafe(self._task_queue[type])
        _clear_queue_threadsafe(self._result_queue[type])

    def clear_queues(self, types):
        """Call ``clear_queue`` for every entry of ``types``."""
        for type in types:
            self.clear_queue(type=type)

    def clear_all_queues(self):
        """Replace all task and result queues with fresh, empty maps."""
        self._task_queue = defaultdict(Queue.Queue)
        self._result_queue = defaultdict(Queue.Queue)

    def clear_task_queue(self, type=None):
        """Empty the task queue of ``type``."""
        if self._verbose:
            print("Received request to clear task "
                  "queue for queue type=" + str(type))
        _clear_queue_threadsafe(self._task_queue[type])

    def clear_task_queues(self, types):
        """Call ``clear_task_queue`` for every entry of ``types``."""
        for type in types:
            self.clear_task_queue(type=type)

    def clear_all_task_queues(self):
        """Replace all task queues with a fresh, empty map."""
        self._task_queue = defaultdict(Queue.Queue)

    def clear_result_queue(self, type=None):
        """Empty the result queue of ``type``."""
        if self._verbose:
            print("Received request to clear result "
                  "queue for queue type=" + str(type))
        _clear_queue_threadsafe(self._result_queue[type])

    def clear_result_queues(self, types):
        """Call ``clear_result_queue`` for every entry of ``types``."""
        for type in types:
            self.clear_result_queue(type=type)

    def clear_all_result_queues(self):
        """Replace all result queues with a fresh, empty map."""
        self._result_queue = defaultdict(Queue.Queue)

    #
    # Methods that do return something, so can't
    # be marked as oneway for Pyro4
    #

    def acquire_available_workers(self):
        """Acquire and return every registered, not-yet-acquired worker.

        Returns:
            set: the worker names that were newly marked as acquired.
        """
        worker_names = self._registered_workers - self._acquired_workers
        self._acquired_workers.update(worker_names)
        return worker_names

    def register_worker(self, name):
        """Register a worker under ``name`` if the worker limit allows.

        Returns:
            bool: ``True`` if the worker was registered, ``False`` if
            the configured ``worker_limit`` has been reached.

        Raises:
            ValueError: if ``name`` is already registered.
        """
        if name in self._registered_workers:
            raise ValueError(
                "Worker name '%s' has already been registered" % (name,))
        if (self._worker_limit is not None) and \
           (len(self._registered_workers) >= self._worker_limit):
            return False
        self._registered_workers.add(name)
        if self._verbose:
            print("Registering worker %s with name: %s" %
                  (len(self._registered_workers), name))
        return True

    def get_task(self, type=None, block=True, timeout=5):
        """Pop one task from the task queue of ``type``.

        ``block`` and ``timeout`` are forwarded to ``Queue.get``.

        Returns:
            The task, or ``None`` if the get raised ``Queue.Empty``.
        """
        if self._verbose:
            print("Received request to get a task from "
                  "queue type=" + str(type) + "; block=" + str(block) +
                  "; timeout=" + str(timeout) + " seconds")
        try:
            return self._task_queue[type].get(block=block, timeout=timeout)
        except Queue.Empty:
            return None

    def _get_bulk(self, queues, type_block_timeout_list):
        """Shared implementation of ``get_tasks`` and ``get_results``.

        For each ``(type, block, timeout)`` request the first item is
        fetched with the requested blocking behaviour; if one arrives,
        whatever else is queued at that moment is drained as well.

        Returns:
            dict: queue type -> list of items, containing only the
            types for which at least one item was obtained.
        """
        ret = {}
        for queue_type, block, timeout in type_block_timeout_list:
            source = queues[queue_type]
            try:
                items = [source.get(block=block, timeout=timeout)]
            except Queue.Empty:
                continue
            while source.qsize():
                try:
                    # Non-blocking on purpose: qsize() already says an
                    # item is there, and if another consumer takes it
                    # first we must not wait for the caller's timeout
                    # (possibly forever) on a queue that is now empty.
                    items.append(source.get(block=False))
                except Queue.Empty:
                    pass
            ret.setdefault(queue_type, []).extend(items)
        return ret

    def get_tasks(self, type_block_timeout_list):
        """Fetch tasks in bulk from several task queues.

        Args:
            type_block_timeout_list: iterable of
                ``(type, block, timeout)`` triples, one per request.

        Returns:
            dict: queue type -> list of tasks; types that produced no
            task are omitted.
        """
        if self._verbose:
            print("Received request to get tasks in bulk. "
                  "Queue request types=" + str(type_block_timeout_list))
        return self._get_bulk(self._task_queue, type_block_timeout_list)

    def get_result(self, type=None, block=True, timeout=5):
        """Pop one result from the result queue of ``type``.

        ``block`` and ``timeout`` are forwarded to ``Queue.get``.

        Returns:
            The result, or ``None`` if the get raised ``Queue.Empty``.
        """
        if self._verbose:
            print("Received request to get a result from "
                  "queue type=" + str(type) + "; block=" + str(block) +
                  "; timeout=" + str(timeout))
        try:
            return self._result_queue[type].get(block=block, timeout=timeout)
        except Queue.Empty:
            return None

    def get_results(self, type_block_timeout_list):
        """Fetch results in bulk from several result queues.

        Args:
            type_block_timeout_list: iterable of
                ``(type, block, timeout)`` triples, one per request.

        Returns:
            dict: queue type -> list of results; types that produced no
            result are omitted.
        """
        if self._verbose:
            print("Received request to get results in bulk. "
                  "Queue request types=" + str(type_block_timeout_list))
        return self._get_bulk(self._result_queue, type_block_timeout_list)

    def num_tasks(self, type=None):
        """Return ``qsize()`` of the task queue of ``type``."""
        if self._verbose:
            print("Received request for number of tasks in "
                  "queue with type=" + str(type))
        return self._task_queue[type].qsize()

    def num_results(self, type=None):
        """Return ``qsize()`` of the result queue of ``type``."""
        if self._verbose:
            print("Received request for number of results in "
                  "queue with type=" + str(type))
        return self._result_queue[type].qsize()

    def queues_with_results(self):
        """Return the list of queue types whose result queue is non-empty."""
        if self._verbose:
            print("Received request for the set of queues with results")
        #
        # Iterate over a snapshot of the type -> queue map, since
        # the map may change while iterating.
        #
        return [queue_name
                for queue_name, result_queue
                in list(self._result_queue.items())
                if result_queue.qsize() > 0]

    def get_results_all_queues(self):
        """Drain every result queue and return the results as one list.

        The returned list does not record which queue a result came from.
        """
        if self._verbose:
            print("Received request to obtain all available "
                  "results from all queues")

        results = []
        #
        # Iterate over a snapshot of the queues, since the
        # type -> queue map may change while iterating.
        #
        for result_queue in list(self._result_queue.values()):
            while result_queue.qsize() > 0:
                try:
                    results.append(result_queue.get(block=False))
                except Queue.Empty:
                    pass
        return results


# Pass the class through the module-level ``expose`` decorator
# (Pyro4's ``expose`` under Pyro4, an identity function otherwise).
Dispatcher = expose(Dispatcher)
338
339
def DispatcherServer(group=":PyUtilibServer",
                     daemon_host=None,
                     daemon_port=0,
                     nameserver_host=None,
                     nameserver_port=None,
                     verbose=False,
                     max_allowed_connections=None,
                     worker_limit=None,
                     clear_group=True):
    """Create a Dispatcher, publish it on the nameserver and serve it.

    Args:
        group: nameserver group; the dispatcher is registered as
            ``group + ".dispatcher." + <uuid4>``.
        daemon_host, daemon_port: forwarded to the Pyro daemon
            constructor.
        nameserver_host, nameserver_port: forwarded to
            ``get_nameserver``.
        verbose, worker_limit: forwarded to ``Dispatcher``.
        max_allowed_connections: forwarded to ``set_maxconnections``.
        clear_group: when true, refuse to start if ``get_dispatchers``
            reports any dispatcher for ``group``; otherwise remove the
            ``group + ".dispatcher"`` nameserver entry before
            registering.

    Returns:
        ``1`` if start-up was refused because a dispatcher already
        exists, otherwise whatever ``daemon.requestLoop()`` returns.
    """

    def _best_effort(call):
        # Run ``call()`` and ignore a Pyro NamingError if it raises one.
        try:
            call()
        except _pyro.errors.NamingError:
            pass

    set_maxconnections(max_allowed_connections=max_allowed_connections)

    #
    # main program
    #
    ns = get_nameserver(host=nameserver_host,
                        port=nameserver_port,
                        caller_name="Dispatcher")

    if clear_group:
        # Bail out as soon as a single existing dispatcher is reported.
        for _name, _uri in get_dispatchers(group=group, ns=ns):
            print("Multiple dispatchers not allowed.")
            print("dispatch_srvr is shutting down...")
            return 1

    if using_pyro3:
        daemon = _pyro.core.Daemon(host=daemon_host, port=daemon_port)
        daemon.useNameServer(ns)
        _best_effort(lambda: ns.createGroup(group))
        _best_effort(lambda: ns.createGroup(group + ".dispatcher"))
        if clear_group:
            _best_effort(lambda: ns.unregister(group + ".dispatcher"))
    else:
        daemon = _pyro.Daemon(host=daemon_host, port=daemon_port)
        if clear_group:
            _best_effort(lambda: ns.remove(group + ".dispatcher"))

    dispatcher = Dispatcher(verbose=verbose, worker_limit=worker_limit)
    proxy_name = group + ".dispatcher." + str(uuid.uuid4())
    if using_pyro3:
        daemon.connect(dispatcher, proxy_name)
    else:
        uri = daemon.register(dispatcher, proxy_name)
        ns.register(proxy_name, uri)

    # There is no need to retain the proxy connection to the
    # nameserver, so free up resources on the nameserver thread
    release_nameserver = ns._pyroRelease if using_pyro4 else ns._release
    release_nameserver()

    print("Dispatcher is ready.")
    return daemon.requestLoop()
408