1"""
2Helpers/utils for working with tornado asynchronous stuff
3"""
4
5
6import contextlib
7import logging
8import sys
9import threading
10
11import salt.ext.tornado.concurrent
12import salt.ext.tornado.ioloop
13
# Logger for this module; records are emitted under the module's dotted name.
log = logging.getLogger(name=__name__)
15
16
@contextlib.contextmanager
def current_ioloop(io_loop):
    """
    Temporarily make ``io_loop`` the current IOLoop.

    The loop reported by ``IOLoop.current()`` on entry is remembered and made
    current again when the ``with`` block exits, whether it exits normally or
    because of an exception.
    """
    loop_cls = salt.ext.tornado.ioloop.IOLoop
    previous_loop = loop_cls.current()
    io_loop.make_current()
    try:
        yield
    finally:
        # Put back whichever loop was current before the switch.
        previous_loop.make_current()
28
29
class SyncWrapper:
    """
    A wrapper to make Async classes synchronous

    This is used as a simple wrapper, for example:

    asynchronous = AsyncClass()
    # this method would normally return a future
    future = asynchronous.async_method()

    sync = SyncWrapper(async_factory_method, (arg1, arg2), {'kwarg1': 'val'})
    # the sync wrapper will automatically wait on the future
    ret = sync.async_method()

    Constructor arguments:

    cls
        Callable that builds the wrapped object; it is called as
        ``cls(*args, **kwargs)``.
    args / kwargs
        Positional and keyword arguments for ``cls``. ``kwargs`` is copied,
        so the caller's dict is never modified.
    async_methods
        Names of methods whose calls are run to completion on the wrapper's
        private IOLoop. Merged with the wrapped object's own
        ``async_methods`` attribute, if it has one.
    close_methods
        Names of methods invoked, in order, by :meth:`close`. Merged with the
        wrapped object's own ``close_methods`` attribute, if it has one.
    loop_kwarg
        When set, the wrapper's private IOLoop is passed to ``cls`` under
        this keyword argument name.
    """

    def __init__(
        self,
        cls,
        args=None,
        kwargs=None,
        async_methods=None,
        close_methods=None,
        loop_kwarg=None,
    ):
        # Each wrapper owns a private IOLoop on which async calls are driven.
        self.io_loop = salt.ext.tornado.ioloop.IOLoop()
        if args is None:
            args = []
        # Copy so that injecting the loop below never mutates the caller's dict.
        kwargs = {} if kwargs is None else dict(kwargs)
        if async_methods is None:
            async_methods = []
        if close_methods is None:
            close_methods = []
        self.loop_kwarg = loop_kwarg
        self.cls = cls
        if loop_kwarg:
            kwargs[self.loop_kwarg] = self.io_loop
        try:
            self.obj = cls(*args, **kwargs)
        except BaseException:
            # Construction failed: release the loop's descriptors, then re-raise.
            self.io_loop.close(all_fds=True)
            raise
        # Merge without going through a set: close methods run in list order,
        # and set iteration order for strings varies between interpreter runs.
        self._async_methods = self._merge_method_names(
            async_methods, getattr(self.obj, "async_methods", [])
        )
        self._close_methods = self._merge_method_names(
            close_methods, getattr(self.obj, "close_methods", [])
        )

    @staticmethod
    def _merge_method_names(*groups):
        """
        Concatenate iterables of method names into one list, keeping only the
        first occurrence of each name and preserving order.
        """
        seen = set()
        merged = []
        for group in groups:
            for name in group:
                if name not in seen:
                    seen.add(name)
                    merged.append(name)
        return merged

    def _populate_async_methods(self):
        """
        We need the '_coroutines' attribute on classes until we can deprecate
        tornado<4.5. After that 'is_coroutine_function' will always be
        available.

        Appends the wrapped object's ``_coroutines`` (when present) to the
        async method names. Not called from within this class.
        """
        if hasattr(self.obj, "_coroutines"):
            self._async_methods += self.obj._coroutines

    def __repr__(self):
        return "<SyncWrapper(cls={})>".format(self.cls)

    def close(self):
        """
        Run every configured close method on the wrapped object, then stop and
        close the private IOLoop (including all of its file descriptors).

        Close methods that are also registered as async methods are run to
        completion on the private loop; the others are called directly. A
        missing method, or an ``Exception`` raised by one close method, is
        logged and does not prevent the remaining close methods or the loop
        shutdown from running.
        """
        for name in self._close_methods:
            if name in self._async_methods:
                method = self._wrap(name)
            else:
                try:
                    method = getattr(self.obj, name)
                except AttributeError:
                    log.error("No sync method %s on object %r", name, self.obj)
                    continue
            try:
                method()
            except AttributeError:
                # Log the method *name*, not the callable bound above.
                log.error("No async method %s on object %r", name, self.obj)
            except Exception:  # pylint: disable=broad-except
                log.exception("Exception encountered while running stop method")
        io_loop = self.io_loop
        io_loop.stop()
        io_loop.close(all_fds=True)

    def __getattr__(self, key):
        """
        Delegate unknown attribute lookups to the wrapped object.

        Names registered as async methods are returned as blocking wrappers
        (see :meth:`_wrap`); every other name is looked up on the wrapped
        object unchanged.
        """
        # __getattr__ only runs when normal lookup fails. Read our own state
        # straight from __dict__ so that a partially-initialized instance
        # (e.g. __init__ raised, or copy/pickle built it via __new__) raises
        # AttributeError instead of recursing into __getattr__ forever.
        state = self.__dict__
        if "obj" not in state or "_async_methods" not in state:
            raise AttributeError(
                "{!r} object has no attribute {!r}".format(type(self).__name__, key)
            )
        if key in state["_async_methods"]:
            return self._wrap(key)
        return getattr(state["obj"], key)

    def _wrap(self, key):
        """
        Return a blocking callable for the async method named ``key``.

        Each call starts a new thread that runs the method on ``self.io_loop``
        via :meth:`_target`, joins that thread, and then returns the result or
        re-raises the captured exception (with its original traceback) in the
        calling thread.
        """
        # NOTE(review): presumably a separate thread is used so run_sync works
        # even when the calling thread is already running an IOLoop — confirm.

        def wrap(*args, **kwargs):
            results = []
            thread = threading.Thread(
                target=self._target,
                args=(key, args, kwargs, results, self.io_loop),
            )
            thread.start()
            thread.join()
            if not results:
                # _target only records Exception subclasses; anything else
                # (e.g. SystemExit) ends the thread without a result.
                raise RuntimeError(
                    "Thread running {} exited without producing a result".format(key)
                )
            if results[0]:
                return results[1]
            exc_info = results[1]
            raise exc_info[1].with_traceback(exc_info[2])

        return wrap

    def _target(self, key, args, kwargs, results, io_loop):
        """
        Thread body used by :meth:`_wrap`.

        Runs ``self.obj.<key>(*args, **kwargs)`` to completion with
        ``io_loop.run_sync`` and appends either ``True, result`` or
        ``False, sys.exc_info()`` to ``results``.
        """
        try:
            result = io_loop.run_sync(lambda: getattr(self.obj, key)(*args, **kwargs))
            results.append(True)
            results.append(result)
        except Exception:  # pylint: disable=broad-except
            results.append(False)
            results.append(sys.exc_info())

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, tb):
        self.close()
143