1###############################################################################
2#
3# The MIT License (MIT)
4#
5# Copyright (c) Crossbar.io Technologies GmbH
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in
15# all copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23# THE SOFTWARE.
24#
25###############################################################################
26
27from __future__ import absolute_import, print_function
28
29import os
30import sys
31import time
32import weakref
33import functools
34import traceback
35import logging
36import inspect
37
38from datetime import datetime
39
40from txaio.interfaces import IFailedFuture, ILogger, log_levels
41from txaio._iotype import guess_stream_needs_encoding
42from txaio._common import _BatchedTimer
43from txaio import _Config
44
45import six
46
47try:
48    import asyncio
49    from asyncio import iscoroutine
50    from asyncio import Future
51
52except ImportError:
53    # Trollius >= 0.3 was renamed
54    # noinspection PyUnresolvedReferences
55    import trollius as asyncio
56    from trollius import iscoroutine
57    from trollius import Future
58
try:
    # ``types.AsyncGeneratorType`` only exists on Python 3.6+ (the first
    # version with async generators).
    from types import AsyncGeneratorType
except ImportError:
    class AsyncGeneratorType(object):
        """
        Placeholder for Pythons without async generators. Nothing in this
        module creates instances of it, so ``isinstance`` checks against it
        simply never match.
        """
64
65
66def _create_future_of_loop(loop):
67    return loop.create_future()
68
69
70def _create_future_directly(loop=None):
71    return Future(loop=loop)
72
73
74def _create_task_of_loop(res, loop):
75    return loop.create_task(res)
76
77
78def _create_task_directly(res, loop=None):
79    return asyncio.Task(res, loop=loop)
80
81
82if sys.version_info >= (3, 4, 2):
83    _create_task = _create_task_of_loop
84    if sys.version_info >= (3, 5, 2):
85        _create_future = _create_future_of_loop
86    else:
87        _create_future = _create_future_directly
88else:
89    _create_task = _create_task_directly
90    _create_future = _create_future_directly
91
92
# Module-wide configuration; the default API instance (whose bound methods
# are exported as the module-level txaio functions) is built from it.
config = _Config()


def with_config(loop=None):
    """
    :return: an instance of the txaio API with the given
        configuration. This won't affect anything using the 'global'
        config nor other instances created using this function.

    If you need to customize txaio configuration separately (e.g. to
    use multiple event-loops in asyncio), you can take code like this:

        import txaio


        class FunTimes(object):

            def something_async(self):
                return txaio.call_later(1, lambda: 'some result')

    and instead do this:

        import txaio


        class FunTimes(object):
            txaio = txaio

            def something_async(self):
                # this runs on whatever event loop self.txaio was
                # configured with
                return self.txaio.call_later(1, lambda: 'some result')

        fun0 = FunTimes()
        fun1 = FunTimes()
        fun1.txaio = txaio.with_config(loop=asyncio.new_event_loop())

    So `fun1` will run its futures on the newly-created event loop,
    while `fun0` will work just as it did before this `with_config`
    method was introduced (after 2.6.2).

    :param loop: the event loop the returned API instance should use.
        If None, the fresh ``_Config``'s own loop setting is kept; the
        API object replaces a None loop with ``asyncio.get_event_loop()``.
    """
    cfg = _Config()
    if loop is not None:
        cfg.loop = loop
    return _AsyncioApi(cfg)
137
138
139# logging should probably all be folded into _AsyncioApi as well
140_stderr, _stdout = sys.stderr, sys.stdout
141_loggers = weakref.WeakSet()  # weak-ref's of each logger we've created before start_logging()
142_log_level = 'info'  # re-set by start_logging
143_started_logging = False
144_categories = {}
145
146
147def add_log_categories(categories):
148    _categories.update(categories)
149
150
class FailedFuture(IFailedFuture):
    """
    A minimal stand-in for the parts of Twisted's ``Failure`` that
    Autobahn classes using FutureMixin rely on.

    The exception type, value and traceback are captured together, so an
    errback running outside the original ``except`` block can still get
    at (and print) the traceback.
    """

    def __init__(self, type_, value, traceback):
        """
        The arguments mirror the 3-tuple returned by ``sys.exc_info()``.

        :param type_: exception type
        :param value: the Exception instance
        :param traceback: a traceback object (may be None)
        """
        self._type, self._value, self._traceback = type_, value, traceback

    @property
    def value(self):
        """The wrapped exception instance."""
        return self._value

    def __str__(self):
        """Render as the wrapped exception's own ``str()``."""
        exc = self.value
        return str(exc)
179
180
181# logging API methods
182
183
184def _log(logger, level, format=u'', **kwargs):
185
186    # Look for a log_category, switch it in if we have it
187    if "log_category" in kwargs and kwargs["log_category"] in _categories:
188        format = _categories.get(kwargs["log_category"])
189
190    kwargs['log_time'] = time.time()
191    kwargs['log_level'] = level
192    kwargs['log_format'] = format
193    # NOTE: turning kwargs into a single "argument which
194    # is a dict" on purpose, since a LogRecord only keeps
195    # args, not kwargs.
196    if level == 'trace':
197        level = 'debug'
198        kwargs['txaio_trace'] = True
199
200    msg = format.format(**kwargs)
201    getattr(logger._logger, level)(msg)
202
203
204def _no_op(*args, **kw):
205    pass
206
207
class _TxaioLogWrapper(ILogger):
    """
    Wraps a stdlib ``logging.Logger`` behind txaio's ILogger interface.

    One attribute per name in ``log_levels`` is installed on the instance.
    Enabled levels forward to ``_log``; disabled ones are ``_no_op``.
    """

    def __init__(self, logger):
        # the stdlib logger that enabled levels ultimately write to
        self._logger = logger
        # start out at whatever the module-wide level currently is
        self._set_log_level(_log_level)

    def emit(self, level, *args, **kwargs):
        """Dispatch to the per-level method called ``level``."""
        return getattr(self, level)(*args, **kwargs)

    def _set_log_level(self, level):
        """
        (Re)bind the per-level methods. ``level`` and every name listed
        before it in ``log_levels`` will log; every later name is a no-op.

        :raises ValueError: if ``level`` is not in ``log_levels``.
        """
        cutoff = log_levels.index(level)
        for position, level_name in enumerate(log_levels):
            if position > cutoff:
                method = _no_op
            else:
                method = functools.partial(_log, self, level_name)
            setattr(self, level_name, method)
        self._log_level = level
228
229
230class _TxaioFileHandler(logging.Handler, object):
231    def __init__(self, fileobj, **kw):
232        super(_TxaioFileHandler, self).__init__(**kw)
233        self._file = fileobj
234        self._encode = guess_stream_needs_encoding(fileobj)
235
236    def emit(self, record):
237        if isinstance(record.args, dict):
238            fmt = record.args.get(
239                'log_format',
240                record.args.get('log_message', u'')
241            )
242            message = fmt.format(**record.args)
243            dt = datetime.fromtimestamp(record.args.get('log_time', 0))
244        else:
245            message = record.getMessage()
246            if record.levelno == logging.ERROR and record.exc_info:
247                message += '\n'
248                for line in traceback.format_exception(*record.exc_info):
249                    message = message + line
250            dt = datetime.fromtimestamp(record.created)
251        msg = u'{0} {1}{2}'.format(
252            dt.strftime("%Y-%m-%dT%H:%M:%S%z"),
253            message,
254            os.linesep
255        )
256        if self._encode:
257            msg = msg.encode('utf8')
258        self._file.write(msg)
259
260
def make_logger():
    """
    Create a txaio logger named after the code that called this function.

    When the caller's frame has a ``self`` local, the name is
    ``<module>.<ClassName>`` for that object's class. Otherwise it is the
    caller's module ``__name__``, with the calling function's name appended
    unless the caller is module-level code. (Without this, the root logger
    would be used.)

    The new logger is also added to the weak set of known loggers, so that
    start_logging() can set its level later.
    """
    caller = inspect.currentframe().f_back
    caller_locals = caller.f_locals
    if "self" in caller_locals:
        owner_cls = caller_locals["self"].__class__
        namespace = '{0}.{1}'.format(owner_cls.__module__, owner_cls.__name__)
    elif caller.f_code.co_name == "<module>":
        namespace = caller.f_globals["__name__"]
    else:
        namespace = caller.f_globals["__name__"] + "." + caller.f_code.co_name

    txaio_logger = _TxaioLogWrapper(logging.getLogger(name=namespace))
    _loggers.add(txaio_logger)
    return txaio_logger
281
282
# txaio level name -> stdlib logging level. The stdlib has no 'trace', so it
# maps to DEBUG; 'none' sits above CRITICAL so nothing passes the root logger.
# NOTE(review): 'none' is assumed to be one of txaio's ``log_levels`` (the
# old inline table lacked it, so it raised KeyError) -- confirm in
# txaio.interfaces.
_LEVEL_TO_STDLIB = {
    'none': logging.CRITICAL + 1,
    'critical': logging.CRITICAL,
    'error': logging.ERROR,
    'warn': logging.WARNING,
    'info': logging.INFO,
    'debug': logging.DEBUG,
    'trace': logging.DEBUG,
}


def start_logging(out=_stdout, level='info'):
    """
    Begin logging.

    Only the first call configures anything; later calls validate
    ``level`` and then return without effect.

    :param out: if provided, a file-like object to log to. By default, this is
                stdout.
    :param level: the maximum log-level to emit (a string)
    :raises RuntimeError: if ``level`` is not one of ``log_levels``.
    """
    global _log_level, _started_logging
    if level not in log_levels:
        raise RuntimeError(
            "Invalid log level '{0}'; valid are: {1}".format(
                level, ', '.join(log_levels)
            )
        )

    if _started_logging:
        return

    # Do everything that can fail *before* touching global state, so a bad
    # level or output stream cannot leave logging marked as started while
    # only half-configured.
    stdlib_level = _LEVEL_TO_STDLIB[level]
    handler = _TxaioFileHandler(out)

    _started_logging = True
    _log_level = level

    root = logging.getLogger()
    root.addHandler(handler)
    # note: Don't need to call basicConfig() or similar, because we've
    # now added at least one handler to the root logger
    logging.raiseExceptions = True  # FIXME
    root.setLevel(stdlib_level)
    # make sure any loggers we created before now have their log-level
    # set (any created after now will get it from _log_level)
    for logger in _loggers:
        logger._set_log_level(level)


def set_global_log_level(level):
    """
    Set the global log level on all loggers instantiated by txaio.

    If start_logging() has already run, the stdlib root logger's level is
    updated too. txaio's loggers set no level of their own, so they
    normally inherit it from the root logger (unless an intermediate logger
    sets one). Without this update, raising verbosity here (e.g. 'info' ->
    'debug') would still have the extra messages dropped by the stdlib.

    :param level: one of ``log_levels``.
    :raises ValueError: if ``level`` is not one of ``log_levels``.
    """
    global _log_level
    if level not in log_levels:
        raise ValueError(
            "Invalid log level '{0}'; valid are: {1}".format(
                level, ', '.join(log_levels)
            )
        )
    for logger in _loggers:
        logger._set_log_level(level)
    _log_level = level
    if _started_logging and level in _LEVEL_TO_STDLIB:
        logging.getLogger().setLevel(_LEVEL_TO_STDLIB[level])


def get_global_log_level():
    """Return the current global txaio log level name."""
    return _log_level
337
338
339# asyncio API methods; the module-level functions are (now, for
340# backwards-compat) exported from a default instance of this class
341
342
343_unspecified = object()
344
345
346class _AsyncioApi(object):
347    using_twisted = False
348    using_asyncio = True
349
350    def __init__(self, config):
351        if config.loop is None:
352            config.loop = asyncio.get_event_loop()
353        self._config = config
354
355    def failure_message(self, fail):
356        """
357        :param fail: must be an IFailedFuture
358        returns a unicode error-message
359        """
360        try:
361            return u'{0}: {1}'.format(
362                fail._value.__class__.__name__,
363                str(fail._value),
364            )
365        except Exception:
366            return u'Failed to produce failure message for "{0}"'.format(fail)
367
368    def failure_traceback(self, fail):
369        """
370        :param fail: must be an IFailedFuture
371        returns a traceback instance
372        """
373        return fail._traceback
374
375    def failure_format_traceback(self, fail):
376        """
377        :param fail: must be an IFailedFuture
378        returns a string
379        """
380        try:
381            f = six.StringIO()
382            traceback.print_exception(
383                fail._type,
384                fail.value,
385                fail._traceback,
386                file=f,
387            )
388            return f.getvalue()
389        except Exception:
390            return u"Failed to format failure traceback for '{0}'".format(fail)
391
392    def create_future(self, result=_unspecified, error=_unspecified, canceller=_unspecified):
393        if result is not _unspecified and error is not _unspecified:
394            raise ValueError("Cannot have both result and error.")
395
396        f = _create_future(loop=self._config.loop)
397        if result is not _unspecified:
398            resolve(f, result)
399        elif error is not _unspecified:
400            reject(f, error)
401
402        # Twisted's only API for cancelling is to pass a
403        # single-argument callable to the Deferred constructor, so
404        # txaio apes that here for asyncio. The argument is the Future
405        # that has been cancelled.
406        if canceller is not _unspecified:
407            def done(f):
408                try:
409                    f.exception()
410                except asyncio.CancelledError:
411                    canceller(f)
412            f.add_done_callback(done)
413
414        return f
415
416    def create_future_success(self, result):
417        return self.create_future(result=result)
418
419    def create_future_error(self, error=None):
420        f = self.create_future()
421        reject(f, error)
422        return f
423
424    def as_future(self, fun, *args, **kwargs):
425        try:
426            res = fun(*args, **kwargs)
427        except Exception:
428            return create_future_error(create_failure())
429        else:
430            if isinstance(res, Future):
431                return res
432            elif iscoroutine(res):
433                return _create_task(res, loop=self._config.loop)
434            elif isinstance(res, AsyncGeneratorType):
435                raise RuntimeError(
436                    "as_future() received an async generator function; does "
437                    "'{}' use 'yield' when you meant 'await'?".format(
438                        str(fun)
439                    )
440                )
441            else:
442                return create_future_success(res)
443
444    def is_future(self, obj):
445        return iscoroutine(obj) or isinstance(obj, Future)
446
447    def call_later(self, delay, fun, *args, **kwargs):
448        # loop.call_later doesn't support kwargs
449        real_call = functools.partial(fun, *args, **kwargs)
450        return self._config.loop.call_later(delay, real_call)
451
452    def make_batched_timer(self, bucket_seconds, chunk_size=100):
453        """
454        Creates and returns an object implementing
455        :class:`txaio.IBatchedTimer`.
456
457        :param bucket_seconds: the number of seconds in each bucket. That
458            is, a value of 5 means that any timeout within a 5 second
459            window will be in the same bucket, and get notified at the
460            same time. This is only accurate to "milliseconds".
461
462        :param chunk_size: when "doing" the callbacks in a particular
463            bucket, this controls how many we do at once before yielding to
464            the reactor.
465        """
466
467        def get_seconds():
468            return self._config.loop.time()
469
470        return _BatchedTimer(
471            bucket_seconds * 1000.0, chunk_size,
472            seconds_provider=get_seconds,
473            delayed_call_creator=self.call_later,
474        )
475
476    def is_called(self, future):
477        return future.done()
478
479    def resolve(self, future, result=None):
480        future.set_result(result)
481
482    def reject(self, future, error=None):
483        if error is None:
484            error = create_failure()  # will be error if we're not in an "except"
485        elif isinstance(error, Exception):
486            error = FailedFuture(type(error), error, None)
487        else:
488            if not isinstance(error, IFailedFuture):
489                raise RuntimeError("reject requires an IFailedFuture or Exception")
490        future.set_exception(error.value)
491
492    def cancel(self, future):
493        future.cancel()
494
495    def create_failure(self, exception=None):
496        """
497        This returns an object implementing IFailedFuture.
498
499        If exception is None (the default) we MUST be called within an
500        "except" block (such that sys.exc_info() returns useful
501        information).
502        """
503        if exception:
504            return FailedFuture(type(exception), exception, None)
505        return FailedFuture(*sys.exc_info())
506
507    def add_callbacks(self, future, callback, errback):
508        """
509        callback or errback may be None, but at least one must be
510        non-None.
511        """
512        def done(f):
513            try:
514                res = f.result()
515                if callback:
516                    callback(res)
517            except Exception:
518                if errback:
519                    errback(create_failure())
520        return future.add_done_callback(done)
521
522    def gather(self, futures, consume_exceptions=True):
523        """
524        This returns a Future that waits for all the Futures in the list
525        ``futures``
526
527        :param futures: a list of Futures (or coroutines?)
528
529        :param consume_exceptions: if True, any errors are eaten and
530        returned in the result list.
531        """
532
533        # from the asyncio docs: "If return_exceptions is True, exceptions
534        # in the tasks are treated the same as successful results, and
535        # gathered in the result list; otherwise, the first raised
536        # exception will be immediately propagated to the returned
537        # future."
538        return asyncio.gather(*futures, return_exceptions=consume_exceptions)
539
540    def sleep(self, delay):
541        """
542        Inline sleep for use in co-routines.
543
544        :param delay: Time to sleep in seconds.
545        :type delay: float
546        """
547        return asyncio.ensure_future(asyncio.sleep(delay))
548
549
# The default API instance, built from the module-level ``config``. The
# names below re-export its attributes as the module-level txaio API; use
# ``with_config()`` for an independently configured instance.
_default_api = _AsyncioApi(config)


using_twisted = _default_api.using_twisted
using_asyncio = _default_api.using_asyncio
failure_message = _default_api.failure_message
failure_traceback = _default_api.failure_traceback
failure_format_traceback = _default_api.failure_format_traceback
create_future = _default_api.create_future
create_future_success = _default_api.create_future_success
create_future_error = _default_api.create_future_error
as_future = _default_api.as_future
is_future = _default_api.is_future
call_later = _default_api.call_later
make_batched_timer = _default_api.make_batched_timer
is_called = _default_api.is_called
resolve = _default_api.resolve
reject = _default_api.reject
cancel = _default_api.cancel
create_failure = _default_api.create_failure
add_callbacks = _default_api.add_callbacks
gather = _default_api.gather
sleep = _default_api.sleep
574