1###############################################################################
2#
3# The MIT License (MIT)
4#
5# Copyright (c) Crossbar.io Technologies GmbH
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in
15# all copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23# THE SOFTWARE.
24#
25###############################################################################
26
27from __future__ import absolute_import, division, print_function
28
29import os
30import sys
31import weakref
32import inspect
33
34from functools import partial
35
36from twisted.python.failure import Failure
37from twisted.internet.defer import maybeDeferred, Deferred, DeferredList
38from twisted.internet.defer import succeed, fail
39from twisted.internet.interfaces import IReactorTime
40
41from zope.interface import provider
42
43from txaio.interfaces import IFailedFuture, ILogger, log_levels
44from txaio._iotype import guess_stream_needs_encoding
45from txaio import _Config
46from txaio._common import _BatchedTimer
47
48import six
49
# Native-coroutine support for ``as_future`` needs both Twisted's
# ``ensureDeferred`` and asyncio's ``iscoroutinefunction``; the flag is
# only raised on Python 3 when both imports succeed.
PY3_CORO = False
if six.PY3:
    try:
        from twisted.internet.defer import ensureDeferred
        from asyncio import iscoroutinefunction
    except ImportError:
        pass
    else:
        PY3_CORO = True
58
# Flags identifying which async framework this module implements.
using_twisted = True
using_asyncio = False

# Module-wide configuration; ``config.loop`` is filled in lazily with the
# global reactor by ``_TxApi._get_loop`` the first time it is needed.
config = _Config()
# The standard streams as they were at import time; ``start_logging``
# defaults its ``out`` argument to the stdout captured here.
_stderr, _stdout = sys.stderr, sys.stdout

# Some book-keeping variables. ``_observer`` is used as a global by the
# "backwards compatible" (Twisted < 15) fallback loggers and holds the
# observer installed by start_logging(). ``_loggers`` is a weak-ref set
# to which every txaio ``Logger`` adds itself on construction;
# set_global_log_level() walks it and re-applies the level to each logger
# whose level was not set explicitly. New ``Logger`` instances take their
# initial level from ``_log_level``.
_observer = None     # for Twisted legacy logging support; see below
_loggers = weakref.WeakSet()  # weak-references of each logger we've created
_log_level = 'info'  # global log level; possibly changed in start_logging()
_started_logging = False  # set by the first start_logging(); later calls return early

# log_category -> format string, filled in by add_log_categories()
_categories = {}

# Register Twisted's Failure with the IFailedFuture interface.
IFailedFuture.register(Failure)
79
_NEW_LOGGER = False
try:
    # Twisted 15+
    from twisted.logger import Logger as _Logger, formatEvent, ILogObserver
    from twisted.logger import globalLogBeginner, formatTime, LogLevel
    ILogger.register(_Logger)
    _NEW_LOGGER = True

except ImportError:
    # we still support Twisted 12 and 13, which doesn't have new-logger
    from zope.interface import Interface
    from datetime import datetime
    import time

    # provide our own simple versions of what Twisted new-logger does

    class ILogObserver(Interface):
        """Stand-in for ``twisted.logger.ILogObserver`` on old Twisted."""

    def formatTime(t):  # noqa
        """Format POSIX timestamp ``t`` (local time) as ``YYYY-mm-ddTHH:MM:SS``."""
        dt = datetime.fromtimestamp(t)
        return six.u(dt.strftime("%Y-%m-%dT%H:%M:%S%z"))

    def formatEvent(event):  # noqa
        """Render ``event['log_format']`` with the event's own keys."""
        msg = event['log_format']
        return msg.format(**event)

    class LogLevel:
        """Stand-in for ``twisted.logger.LogLevel``; levels are plain strings."""
        critical = 'critical'
        error = 'error'
        warn = 'warn'
        info = 'info'
        debug = 'debug'
        trace = 'trace'

        @classmethod
        def lookupByName(cls, name):  # noqa
            return getattr(cls, name)

    class _Logger(ILogger):
        """
        Minimal stand-in for ``twisted.logger.Logger``: builds an event dict
        and hands it to the module-level ``_observer`` (if any).
        """

        def __init__(self, **kwargs):
            self.namespace = kwargs.get('namespace', None)

        def emit(self, level, format='', **kwargs):
            kwargs['log_time'] = time.time()
            kwargs['log_level'] = level
            kwargs['log_format'] = format
            kwargs['log_namespace'] = self.namespace
            # NOTE: the other loggers are ignoring any log messages
            # before start_logging() as well
            if _observer:
                _observer(kwargs)

        def failure(self, format, failure=None, level=LogLevel.critical,
                    **kwargs):
            """
            Log a failure, mirroring ``twisted.logger.Logger.failure``.

            txaio's ``Logger.failure`` delegates here, so without this method
            it raised AttributeError on old Twisted.

            :param format: format string; may be None, in which case
                ``_LogObserver`` prints the failure's traceback instead.
            :param failure: the Failure to log; when omitted, the exception
                currently being handled is captured via ``Failure()``.
            :param level: level to emit at (critical by default).
            """
            if failure is None:
                failure = Failure()
            self.emit(level, format, log_failure=failure, **kwargs)
132
133
134def _no_op(*args, **kwargs):
135    pass
136
137
def add_log_categories(categories):
    """
    Register format strings for log categories.

    :param categories: a mapping (or iterable of key/value pairs) from a
        ``log_category`` value to the format string that ``Logger._log``
        substitutes when a log call passes that category.
    """
    new_entries = dict(categories)
    _categories.update(new_entries)
140
141
def with_config(loop=None):
    """
    Return a txaio API object that uses the given reactor.

    Twisted has only a single, global reactor. ``loop`` must therefore be
    ``None``, be the reactor already stored in ``txaio.config.loop``, or be
    passed while no reactor has been configured yet.

    :param loop: optional reactor; it must provide ``IReactorTime`` for
        ``call_later`` to work.
    :returns: a ``_TxApi`` bound to the resulting configuration.
    :raises RuntimeError: if ``loop`` differs from an already-configured
        ``config.loop``.
    """
    if loop is None or config.loop is loop:
        return _TxApi(config)

    if config.loop is not None:
        raise RuntimeError(
            "Twisted has only a single, global reactor. You passed in "
            "a reactor different from the one already configured "
            "in txaio.config.loop"
        )

    # No reactor is configured yet. Honour the caller's ``loop`` for the
    # returned API instead of silently falling back to the global reactor.
    # A shallow copy keeps every other setting without pinning ``loop``
    # into the shared module-level config.
    import copy
    local_config = copy.copy(config)
    local_config.loop = loop
    return _TxApi(local_config)
152
153
# NOTE: beware that twisted.logger._logger.Logger copies itself via an
# overridden __get__ method when used, as recommended, as a class
# descriptor.  So, we override __get__ to just return ``self``, which
# means ``log_source`` will be wrong, but we don't document that as a
# key that you can depend on anyway :/
class Logger(object):
    """
    txaio's logging front-end wrapping an underlying logger object.

    The per-level methods (one per entry of ``log_levels`` except
    ``'none'``) and ``failure`` are not defined on the class. They are
    bound as instance attributes by ``_set_log_level``:

    * levels up to and including the configured one log through the
      wrapped logger;
    * levels listed after it in ``log_levels`` are bound to ``_no_op``.

    Create instances via ``make_logger()`` rather than directly.
    """

    def __init__(self, level=None, logger=None, namespace=None, observer=None):
        """
        :param level: explicit log level for this logger. If omitted, the
            global ``_log_level`` is used and the logger follows later
            ``set_global_log_level`` calls.
        :param logger: factory for the wrapped logger, called as
            ``logger(observer=observer, namespace=namespace)``.
        :param namespace: namespace passed to the wrapped logger.
        :param observer: observer passed to the wrapped logger.
        """
        assert logger, "Should not be instantiated directly."

        self._logger = logger(observer=observer, namespace=namespace)
        self._log_level_set_explicitly = False

        if level:
            self.set_log_level(level)
        else:
            self._set_log_level(_log_level)

        # tracked weakly so set_global_log_level() can update us later
        _loggers.add(self)

    def __get__(self, oself, type=None):
        # this causes the Logger to lie about the "source=", but
        # otherwise we create a new Logger instance every time we do
        # "self.log.info()" if we use it like:
        # class Foo:
        #     log = make_logger()
        return self

    def _log(self, level, *args, **kwargs):
        """
        Forward a log call to the wrapped logger's ``emit``.

        If ``log_category`` names a category registered through
        ``add_log_categories``, that category's format string is passed as
        ``format`` and any positional arguments are dropped.
        """
        if "log_category" in kwargs and kwargs["log_category"] in _categories:
            args = tuple()
            kwargs["format"] = _categories.get(kwargs["log_category"])

        self._logger.emit(level, *args, **kwargs)

    def emit(self, level, *args, **kwargs):
        """
        Log at the level named by the string ``level``.

        Calls for levels listed after this logger's level in ``log_levels``
        are dropped. ``'trace'`` is routed through ``_trace``, because
        Twisted has no trace level.
        """
        if log_levels.index(self._log_level) < log_levels.index(level):
            return

        if level == "trace":
            return self._trace(*args, **kwargs)

        level = LogLevel.lookupByName(level)
        return self._log(level, *args, **kwargs)

    def set_log_level(self, level, keep=True):
        """
        Set the log level. If keep is True, then it will not change along with
        global log changes.
        """
        self._set_log_level(level)
        self._log_level_set_explicitly = keep

    def _set_log_level(self, level):
        """
        (Re)bind the per-level methods for ``level``.

        :raises ValueError: if ``level`` is not in ``log_levels``.
        """
        # up to the desired level, we don't do anything, as we're a
        # "real" Twisted new-logger; for methods *after* the desired
        # level, we bind to the no_op method
        desired_index = log_levels.index(level)

        for (idx, name) in enumerate(log_levels):
            if name == 'none':
                continue

            if idx > desired_index:
                # disabled level ('failure' is tied to 'error')
                if getattr(self, name, None) is not _no_op:
                    setattr(self, name, _no_op)
                if name == 'error':
                    setattr(self, 'failure', _no_op)

            else:
                # enabled level: only bind when missing or currently disabled,
                # so an already-bound method is left in place
                if getattr(self, name, None) in (_no_op, None):

                    if name == 'trace':
                        setattr(self, "trace", self._trace)
                    else:
                        setattr(self, name,
                                partial(self._log, LogLevel.lookupByName(name)))

                    if name == 'error':
                        setattr(self, "failure", self._failure)

        self._log_level = level

    def _failure(self, format=None, *args, **kw):
        """Delegate to the wrapped logger's ``failure`` method."""
        return self._logger.failure(format, *args, **kw)

    def _trace(self, *args, **kw):
        # there is no "trace" level in Twisted -- but this whole
        # method will be no-op'd unless we are at the 'trace' level.
        self.debug(*args, txaio_trace=True, **kw)
249
250
def make_logger(level=None, logger=_Logger, observer=None):
    """
    Create a :class:`Logger` whose namespace is taken from the caller.

    The namespace is chosen as follows:

    * ``module.ClassName`` when the caller has a local named ``self``;
    * ``module.function`` when called from any other function;
    * just ``module`` when called at module level.

    If the interpreter provides no stack frames, ``inspect.currentframe()``
    returns None and the namespace is left as None.

    :param level: explicit log level for the new logger (optional).
    :param logger: factory for the wrapped logger (``_Logger`` by default).
    :param observer: observer passed to the wrapped logger.
    """
    # we want the namespace to be the calling context of "make_logger"
    # -- so we *have* to pass namespace kwarg to Logger (or else it
    # will always say the context is "make_logger")
    caller = inspect.currentframe()
    if caller is not None:
        caller = caller.f_back
    try:
        if caller is None:
            namespace = None
        elif "self" in caller.f_locals:
            # We're probably in a class init or method
            cls = caller.f_locals["self"].__class__
            namespace = '{0}.{1}'.format(cls.__module__, cls.__name__)
        else:
            namespace = caller.f_globals["__name__"]
            if caller.f_code.co_name != "<module>":
                # If it's not the module, and not in a class instance, add
                # the code object's name.
                namespace = namespace + "." + caller.f_code.co_name
    finally:
        # drop the frame reference promptly, as the inspect docs advise
        del caller
    return Logger(level=level, namespace=namespace, logger=logger,
                  observer=observer)
269
270
@provider(ILogObserver)
class _LogObserver(object):
    """
    Internal helper.

    An observer which formats events to a given file.
    """
    to_tx = {
        'critical': LogLevel.critical,
        'error': LogLevel.error,
        'warn': LogLevel.warn,
        'info': LogLevel.info,
        'debug': LogLevel.debug,
        'trace': LogLevel.debug,
    }

    def __init__(self, out):
        self._file = out
        self._encode = guess_stream_needs_encoding(out)

        # LogLevel constants accepted at the global level recorded in
        # ``_levels_for``. Both are rebuilt whenever the global level changes.
        self._levels = None
        self._levels_for = None

    def _acceptable_level(self, level):
        """
        Return True if ``level`` (a LogLevel constant) is enabled at the
        current global log level.
        """
        # Recompute when the global level changes. A set_global_log_level()
        # call made after the first event would otherwise be ignored here.
        if self._levels is None or self._levels_for != _log_level:
            target_level = log_levels.index(_log_level)
            self._levels = [
                self.to_tx[lvl]
                for lvl in log_levels
                if log_levels.index(lvl) <= target_level and lvl != "none"
            ]
            self._levels_for = _log_level
        return level in self._levels

    def _write(self, event, text):
        """Write ``text`` prefixed by the event's timestamp, plus a newline."""
        msg = u'{0} {1}{2}'.format(
            formatTime(event["log_time"]),
            text,
            os.linesep,
        )
        if self._encode:
            msg = msg.encode('utf8')
        self._file.write(msg)

    def __call__(self, event):
        # it seems if a twisted.logger.Logger() has .failure() called
        # on it, the log_format will be None for the traceback after
        # "Unhandled error in Deferred" -- perhaps this is a Twisted
        # bug? Use .get(): not every event carries log_format/log_failure.
        if event.get('log_format') is None:
            log_failure = event.get('log_failure')
            if log_failure is not None:
                self._write(event, failure_format_traceback(log_failure))
        # although Logger will already have filtered out unwanted
        # levels, bare Logger instances from Twisted code won't have.
        elif 'log_level' in event and self._acceptable_level(event['log_level']):
            self._write(event, formatEvent(event))
330
331
def start_logging(out=_stdout, level='info'):
    """
    Start logging to the file-like object in ``out``. By default, this
    is stdout.

    The level is validated on every call, but only the first call with a
    valid level has any effect; later calls return immediately.

    :param out: file-like object to write formatted events to. It may be
        falsy only with Twisted 15+, in which case no observer is installed.
    :param level: one of ``log_levels``; becomes the global log level.
    :raises RuntimeError: if ``level`` is not a known log level.
    """
    global _observer, _log_level, _started_logging

    if level not in log_levels:
        valid = ', '.join(log_levels)
        raise RuntimeError(
            "Invalid log level '{0}'; valid are: {1}".format(level, valid)
        )

    if _started_logging:
        return
    _started_logging = True

    _log_level = level
    set_global_log_level(level)

    if out:
        _observer = _LogObserver(out)

    if not _NEW_LOGGER:
        assert out, "out needs to be given a value if using Twisteds before 15.2"
        from twisted.python import log
        log.startLogging(out)
        return

    globalLogBeginner.beginLoggingTo([_observer] if _observer else [])
366
367
# Sentinel meaning "argument not passed" in ``create_future``, so that an
# explicit ``None`` can still be used as a result or error value.
_unspecified = object()
369
370
class _TxApi(object):
    """
    Twisted implementation of the txaio API.

    Scheduling uses the reactor in ``config.loop``. When that is unset,
    ``_get_loop`` fills it in with the global Twisted reactor on first use.
    """

    def __init__(self, config):
        self._config = config

    def failure_message(self, fail):
        """
        :param fail: must be an IFailedFuture
        returns a unicode error-message
        """
        try:
            return u'{0}: {1}'.format(
                fail.value.__class__.__name__,
                fail.getErrorMessage(),
            )
        except Exception:
            # best-effort: describing a failure must never itself raise
            return 'Failed to produce failure message for "{0}"'.format(fail)

    def failure_traceback(self, fail):
        """
        :param fail: must be an IFailedFuture
        returns a traceback instance
        """
        return fail.tb

    def failure_format_traceback(self, fail):
        """
        :param fail: must be an IFailedFuture
        returns a string
        """
        try:
            f = six.StringIO()
            fail.printTraceback(file=f)
            return f.getvalue()
        except Exception:
            return u"Failed to format failure traceback for '{0}'".format(fail)

    def create_future(self, result=_unspecified, error=_unspecified, canceller=None):
        """
        Create a new Deferred, optionally already resolved or rejected.

        :param result: if given, the Deferred is resolved with it.
        :param error: if given, the Deferred is rejected with it (see
            ``reject`` for the accepted types).
        :param canceller: passed through to ``Deferred``.
        :raises ValueError: if both ``result`` and ``error`` are given.
        """
        if result is not _unspecified and error is not _unspecified:
            raise ValueError("Cannot have both result and error.")

        f = Deferred(canceller=canceller)
        if result is not _unspecified:
            self.resolve(f, result)
        elif error is not _unspecified:
            self.reject(f, error)
        return f

    def create_future_success(self, result):
        """Return a Deferred that has already fired with ``result``."""
        return succeed(result)

    def create_future_error(self, error=None):
        """
        Return a Deferred that has already failed.

        :param error: exception to fail with. If None, the exception
            currently being handled is used (see ``create_failure``).
        """
        return fail(self.create_failure(error))

    def as_future(self, fun, *args, **kwargs):
        """
        Call ``fun(*args, **kwargs)`` and return the outcome as a Deferred.

        Coroutine functions (Python 3 only) are wrapped with
        ``ensureDeferred``; everything else goes through ``maybeDeferred``.
        """
        # Twisted doesn't automagically deal with coroutines on Py3
        if PY3_CORO and iscoroutinefunction(fun):
            return ensureDeferred(fun(*args, **kwargs))
        return maybeDeferred(fun, *args, **kwargs)

    def is_future(self, obj):
        """Return True if ``obj`` is a Twisted Deferred."""
        return isinstance(obj, Deferred)

    def call_later(self, delay, fun, *args, **kwargs):
        """
        Schedule ``fun(*args, **kwargs)`` after ``delay`` seconds on the
        reactor and return whatever ``IReactorTime.callLater`` returns.
        """
        return IReactorTime(self._get_loop()).callLater(delay, fun, *args, **kwargs)

    def make_batched_timer(self, bucket_seconds, chunk_size=100):
        """
        Creates and returns an object implementing
        :class:`txaio.IBatchedTimer`.

        :param bucket_seconds: the number of seconds in each bucket. That
            is, a value of 5 means that any timeout within a 5 second
            window will be in the same bucket, and get notified at the
            same time. This is only accurate to "milliseconds".

        :param chunk_size: when "doing" the callbacks in a particular
            bucket, this controls how many we do at once before yielding to
            the reactor.
        """

        def get_seconds():
            return self._get_loop().seconds()

        def create_delayed_call(delay, fun, *args, **kwargs):
            return self._get_loop().callLater(delay, fun, *args, **kwargs)

        return _BatchedTimer(
            bucket_seconds * 1000.0, chunk_size,
            seconds_provider=get_seconds,
            delayed_call_creator=create_delayed_call,
        )

    def is_called(self, future):
        """Return the Deferred's ``called`` flag."""
        return future.called

    def resolve(self, future, result=None):
        """Fire ``future``'s callback chain with ``result``."""
        future.callback(result)

    def reject(self, future, error=None):
        """
        Fire ``future``'s errback chain.

        :param error: None (use the exception currently being handled), an
            Exception instance, or a Failure.
        :raises RuntimeError: if ``error`` is any other type.
        """
        if error is None:
            error = self.create_failure()
        elif isinstance(error, Exception):
            error = Failure(error)
        else:
            if not isinstance(error, Failure):
                raise RuntimeError("reject requires a Failure or Exception")
        future.errback(error)

    def cancel(self, future):
        """Cancel ``future`` (delegates to ``Deferred.cancel``)."""
        future.cancel()

    def create_failure(self, exception=None):
        """
        Create a Failure instance.

        if ``exception`` is None (the default), we MUST be inside an
        "except" block. This encapsulates the exception into an object
        that implements IFailedFuture
        """
        # Compare against None explicitly. An exception instance can be
        # falsy (e.g. if its class defines __len__ returning 0), and a
        # truthiness test would wrongly fall through to the bare Failure().
        if exception is not None:
            return Failure(exception)
        return Failure()

    def add_callbacks(self, future, callback, errback):
        """
        callback or errback may be None, but at least one must be
        non-None.
        """
        assert future is not None
        if callback is None:
            assert errback is not None
            future.addErrback(errback)
        else:
            # Twisted allows errback to be None here
            future.addCallbacks(callback, errback)
        return future

    def gather(self, futures, consume_exceptions=True):
        """
        Return a Deferred that fires with a list of the results of
        ``futures``, in order.

        :param consume_exceptions: passed to ``DeferredList`` as
            ``consumeErrors``. When False, the first failure found in the
            results is re-raised from the aggregating callback.
        """
        def completed(res):
            rtn = []
            for (ok, value) in res:
                rtn.append(value)
                if not ok and not consume_exceptions:
                    value.raiseException()
            return rtn

        # XXX if consume_exceptions is False in asyncio.gather(), it will
        # abort on the first raised exception -- should we set
        # fireOnOneErrback=True (if consume_exceptions=False?) -- but then
        # we'll have to wrap the errback() to extract the "real" failure
        # from the FirstError that gets thrown if you set that ...

        dl = DeferredList(list(futures), consumeErrors=consume_exceptions)
        # we unpack the (ok, value) tuples into just a list of values, so
        # that the callback() gets the same value in asyncio and Twisted.
        self.add_callbacks(dl, completed, None)
        return dl

    def sleep(self, delay):
        """
        Inline sleep for use in co-routines.

        :param delay: Time to sleep in seconds.
        :type delay: float
        """
        d = Deferred()
        self._get_loop().callLater(delay, d.callback, None)
        return d

    def _get_loop(self):
        """
        internal helper
        """
        # we import and assign the default here (and not, e.g., when
        # making Config) so as to delay importing reactor as long as
        # possible in case someone is installing a custom one.
        if self._config.loop is None:
            from twisted.internet import reactor
            self._config.loop = reactor
        return self._config.loop
552
553
def set_global_log_level(level):
    """
    Set the global log level on all loggers instantiated by txaio.

    Loggers whose level was set explicitly (``Logger.set_log_level`` with
    ``keep=True``) are left alone. The new level also becomes the initial
    level of loggers created afterwards.

    :param level: one of ``log_levels``.
    :raises ValueError: if ``level`` is not a known log level. The check
        runs before anything changes, so an invalid level never ends up in
        the global state, where it would break every later ``Logger()``.
    """
    global _log_level
    if level not in log_levels:
        raise ValueError(
            "Invalid log level '{0}'; valid are: {1}".format(
                level, ', '.join(log_levels)
            )
        )
    for item in _loggers:
        if not item._log_level_set_explicitly:
            item._set_log_level(level)
    _log_level = level
563
564
def get_global_log_level():
    """Return the name of the current global log level."""
    current_level = _log_level
    return current_level
567
568
# The module's default API instance, bound to the module-level ``config``.
_default_api = _TxApi(config)


# Module-level convenience aliases: each name below is the bound method of
# ``_default_api``, so calling e.g. ``create_future()`` from this module is
# the same as ``_default_api.create_future()``.
failure_message = _default_api.failure_message
failure_traceback = _default_api.failure_traceback
failure_format_traceback = _default_api.failure_format_traceback
create_future = _default_api.create_future
create_future_success = _default_api.create_future_success
create_future_error = _default_api.create_future_error
as_future = _default_api.as_future
is_future = _default_api.is_future
call_later = _default_api.call_later
make_batched_timer = _default_api.make_batched_timer
is_called = _default_api.is_called
resolve = _default_api.resolve
reject = _default_api.reject
cancel = _default_api.cancel
create_failure = _default_api.create_failure
add_callbacks = _default_api.add_callbacks
gather = _default_api.gather
sleep = _default_api.sleep
590