1from __future__ import absolute_import, unicode_literals
2import logging
3import os
4import socket
5import subprocess
6import sys
7import time
8import traceback
9try:
10    from urllib2 import (
11        Request,
12        urlopen,
13    )
14except ImportError:
15    from urllib.request import (
16        Request,
17        urlopen,
18    )
19from collections import (
20    deque,
21    OrderedDict,
22)
23try:
24    from collections.abc import Iterable
25except ImportError:
26    from collections import Iterable
27from difflib import (
28    Match,
29    SequenceMatcher,
30)
31from functools import wraps
32from itertools import chain
33try:
34    from itertools import izip as zip
35    from Queue import (
36        Empty,
37        Queue,
38    )
39except ImportError:
40    from queue import (
41        Empty,
42        Queue,
43    )
44from threading import Thread
45from weakref import WeakKeyDictionary
46
47from .exceptions import Abort
48
49
class StreamHandler(logging.StreamHandler):
    """Stream handler that annotates records with the elapsed time.

    Before a record is emitted, ``record.timestamp`` is set to the number of
    seconds elapsed between this handler's creation and the record's
    creation.
    """

    def __init__(self):
        logging.StreamHandler.__init__(self)
        self._start_time = time.time()

    def emit(self, record):
        elapsed = record.created - self._start_time
        record.timestamp = elapsed
        logging.StreamHandler.emit(self, record)
58
59
class Formatter(logging.Formatter):
    """Formatter picking a layout from the record's logger and level.

    - records from the root logger: level and message;
    - records at WARNING or above: level, logger name and message;
    - anything else: ``record.timestamp`` (3 decimals), level, logger name
      and message.
    Every layout starts with a carriage return.
    """

    def __init__(self):
        timed_format = '\r%(timestamp).3f %(levelname)s [%(name)s] %(message)s'
        super(Formatter, self).__init__(timed_format)
        self._root_formatter = logging.Formatter('\r%(levelname)s %(message)s')
        self._no_timestamp_formatter = logging.Formatter(
            '\r%(levelname)s [%(name)s] %(message)s')

    def format(self, record):
        if record.name == 'root':
            chosen = self._root_formatter
        elif record.levelno >= logging.WARNING:
            chosen = self._no_timestamp_formatter
        else:
            return super(Formatter, self).format(record)
        return chosen.format(record)
74
75
def init_logging():
    """Install the stderr log handler and configure per-logger levels.

    Levels come from the `cinnabar.log` value returned by `Git.config`: a
    comma-separated list of ``name:N`` entries, where ``*`` designates the
    root logger and a larger ``N`` means more verbose output (the level is
    ``max(DEBUG, FATAL - 10 * N)``). Malformed entries are ignored. Nothing
    is configured when the value is empty and neither the `memory` nor the
    `cpu` check is enabled.
    """
    # Per the original notes, the GIT_CINNABAR_LOG environment variable
    # takes precedence over the cinnabar.log configuration (presumably
    # handled inside Git.config; not visible here). The configuration is
    # still read so that the git config cache is filled before logging is
    # set up, and the output of `git config -l` is never logged.
    from .git import Git
    handler = StreamHandler()
    handler.setFormatter(Formatter())
    logging.getLogger().addHandler(handler)
    log_conf = Git.config('cinnabar.log') or b''
    if not (log_conf or check_enabled('memory') or check_enabled('cpu')):
        return
    for entry in log_conf.split(b','):
        try:
            raw_name, raw_level = entry.split(b':', 1)
            verbosity = int(raw_level)
            logger_name = raw_name.decode('ascii')
            if logger_name == '*':
                logger_name = ''
            level = max(logging.DEBUG, logging.FATAL - verbosity * 10)
            logging.getLogger(logger_name).setLevel(level)
        except Exception:
            # Best effort: entries that can't be parsed are skipped.
            pass
102
103
class ConfigSetFunc(object):
    """Callable telling whether a name is enabled by a git config setting.

    The setting ``key`` (scoped to ``remote`` when one is given) is read
    lazily on the first call, then cached. Its value, or ``default`` when it
    is unset or empty, is a comma-separated list of entries:
    - ``true`` or ``all``: enable every name from ``values``;
    - ``-name``: disable ``name`` if a previous entry enabled it;
    - a name from ``values`` or ``extra_values``: enable that name.
    Any other entry is reported with a warning on the `config` logger and
    ignored.
    """

    def __init__(self, key, values, extra_values=(), default='', remote=None):
        self._config = None
        self._key = key
        self._values = values
        self._extra_values = extra_values
        self._default = default.encode('ascii')
        self._remote = remote

    def __call__(self, name):
        """Return whether ``name`` is enabled by the configuration."""
        if self._config is None:
            from .git import Git
            if self._remote:
                config = Git.config(self._key, self._remote) or self._default
            else:
                config = Git.config(self._key) or self._default
            entries = config.decode('ascii').split(',') if config else []
            enabled = set()
            logger = logging.getLogger('config')
            for c in entries:
                if c in ('true', 'all'):
                    enabled |= set(self._values)
                elif c.startswith('-'):
                    c = c[1:]
                    try:
                        # `c` is already text (decoded above); it must not
                        # be decoded again, which fails on Python 3.
                        enabled.remove(c)
                    except KeyError:
                        logger.warning(
                            '%s: %s is not one of (%s)',
                            self._key, c, ', '.join(enabled))
                elif c in self._values or c in self._extra_values:
                    enabled.add(c)
                else:
                    logger.warning(
                        '%s: %s is not one of (%s)',
                        self._key, c, ', '.join(self._values))
            # Only cache the result once parsing completed.
            self._config = enabled
        return name in self._config
141
142
# Checks controlled by the `cinnabar.check` configuration. `true`/`all`
# only enables the names in `values`; the `extra_values` ones must be
# listed explicitly.
check_enabled = ConfigSetFunc(
    key='cinnabar.check',
    values=('nodeid', 'manifests', 'helper'),
    extra_values=('bundle', 'files', 'memory', 'cpu', 'time', 'traceback',
                  'no-mercurial', 'no-bundle2', 'cinnabarclone',
                  'clonebundles', 'no-version-check'),
)

# Experimental features controlled by the `cinnabar.experiments`
# configuration.
experiment = ConfigSetFunc(
    key='cinnabar.experiments',
    values=('wire', 'merge', 'store'),
    extra_values=(),
)
155
156
def interval_expired(config_key, interval, globl=False):
    """Tell whether ``interval`` seconds passed since the recorded timestamp.

    The timestamp is an integer stored in the `cinnabar.<config_key>` git
    configuration entry (global when ``globl`` is true, local otherwise).
    When a non-zero timestamp is recorded and the interval is not over yet,
    returns False without touching the configuration. Otherwise the current
    time is recorded, and the return value is whether a (non-zero) timestamp
    existed before. This means the very first call returns False.
    """
    from .git import Git
    full_key = 'cinnabar.{}'.format(config_key)
    try:
        last = int(Git.config(full_key))
    except (ValueError, TypeError):
        last = None
    now = time.time()
    if last and last + interval > now:
        return False
    if globl is not True and full_key == 'cinnabar.fsck':
        # cinnabar.fsck used to be global and is now local: drop the stale
        # global value.
        Git.run('config', '--global', '--unset', full_key)
    scope = '--global' if globl else '--local'
    Git.run('config', scope, full_key, str(int(now)))
    return bool(last)
175
176
# Module-wide switch: when False, Progress instances print nothing.
progress = True


class Progress(object):
    """Throttled progress display on stderr.

    ``fmt`` is a `str.format` template; a count (a single value or a tuple
    of values) is substituted into it. Output is rewritten in place with a
    carriage return. `progress` prints only when more than 0.1s passed since
    the last print; `finish` always prints and ends the line. When the
    `time` check is enabled, the elapsed time since creation is appended.
    """

    def __init__(self, fmt):
        self._count = 0
        self._start = self._t0 = time.time()
        self._fmt = fmt

    def progress(self, count=None):
        """Record ``count`` (default: previous count + 1), printing it when
        the last print is more than 0.1s old."""
        if not progress:
            return
        if count is None:
            count = self._count + 1
        t1 = time.time()
        if t1 - self._t0 > 0.1:
            self._print_count(count, t1)
        self._count = count

    def _print_count(self, count, t1=None):
        if not isinstance(count, tuple):
            count = (count,)
        # Always resolve the current time: it becomes the reference for the
        # throttling check in progress(), so it must never be None.
        if t1 is None:
            t1 = time.time()
        timed = ''
        if check_enabled('time'):
            timed = ' in %.1fs' % (t1 - self._start)
        sys.stderr.write('\r' + self._fmt.format(*count) + timed)
        sys.stderr.flush()
        self._t0 = t1

    def finish(self, count=None):
        """Print the final count (default: the last recorded one) and end
        the line."""
        if not progress:
            return
        self._print_count(count or self._count)
        sys.stderr.write('\n')
        sys.stderr.flush()
213
214
def progress_iter(fmt, iter):
    """Yield the items of ``iter`` while displaying progress with ``fmt``.

    Items are counted starting from 1; see `progress_enum`.
    """
    numbered = enumerate(iter, 1)
    return progress_enum(fmt, numbered)
217
218
def progress_enum(fmt, enum_iter):
    """Yield the items of ``enum_iter``, a stream of ``(count, item)`` pairs,
    feeding each count to a `Progress` built from ``fmt``.

    When iteration stops for any reason (exhaustion, close or error), the
    last count seen is printed as final, provided it is truthy.
    """
    reporter = Progress(fmt)
    last = 0
    try:
        for last, item in enum_iter:
            reporter.progress(last)
            yield item
    finally:
        if last:
            reporter.finish(last)
229
230
class IOLogger(object):
    """File-like wrapper logging what is read from ``reader`` and written to
    ``writer`` (``writer`` defaults to ``reader``).

    Each transfer is logged on ``logger`` with its data shown via ``%r``,
    preceded by ``prefix`` and ``<=`` for reads or ``=>`` for writes. When
    the wrapped stream is itself an IOLogger, this layer doesn't log, to
    avoid duplicate entries.
    """

    def __init__(self, logger, reader, writer=None, prefix=''):
        self._logger = logger
        self._reader = reader
        self._writer = writer or reader
        self._prefix = prefix + ' ' if prefix else ''

    def _log(self, stream, template, level, data):
        # Skip logging when the underlying stream already logs.
        if isinstance(stream, IOLogger):
            return
        self._logger.log(level, template, self._prefix, data)

    def read(self, length=None, level=logging.INFO):
        if length is None:
            data = self._reader.read()
        else:
            data = self._reader.read(length)
        self._log(self._reader, '%s<= %r', level, data)
        return data

    def readline(self, level=logging.INFO):
        data = self._reader.readline()
        self._log(self._reader, '%s<= %r', level, data)
        return data

    def write(self, data, level=logging.INFO):
        self._log(self._writer, '%s=> %r', level, data)
        return self._writer.write(data)

    def flush(self):
        self._writer.flush()

    def __iter__(self):
        while self._reader:
            line = self.readline()
            if not line:
                return
            yield line
267
268
def one(iterable):
    """Return the only item of ``iterable``, or None when it is empty.

    More than one item is an assertion failure.
    """
    items = list(iterable)
    if not items:
        return None
    assert len(items) == 1
    return items[0]
275
276
def strip_suffix(s, suffix):
    """Return ``s`` without its trailing ``suffix``.

    ``s`` is returned unchanged when it doesn't end with ``suffix``, or when
    ``suffix`` is empty: slicing with ``s[:-0]`` would otherwise return an
    empty string. Works on both text and bytes.
    """
    if suffix and s.endswith(suffix):
        return s[:-len(suffix)]
    return s
281
282
class OrderedDefaultDict(OrderedDict):
    """OrderedDict that creates missing entries on lookup, like
    `collections.defaultdict`, by calling ``default_factory()``.

    Remaining positional and keyword arguments initialize the mapping as
    for OrderedDict.
    """

    def __init__(self, default_factory, *args, **kwargs):
        super(OrderedDefaultDict, self).__init__(*args, **kwargs)
        self._default_factory = default_factory

    def __missing__(self, key):
        self[key] = created = self._default_factory()
        return created
291
292
class VersionedDict(object):
    """Dict-like layer recording changes on top of an underlying mapping.

    Reads fall through to ``_previous`` (a dict or another VersionedDict)
    unless the key was set or deleted in this layer. Writes and deletions
    only touch this layer, so `iterchanges` can report what changed relative
    to ``_previous``, and `flattened` can collapse a chain of layers into a
    single layer over the bottom-most mapping.

    Note: a dict or VersionedDict given without keyword arguments is
    referenced, not copied.
    """

    def __init__(self, content=None, **kwargs):
        if content:
            if kwargs:
                self._previous = VersionedDict(content)
                for k, v in kwargs.items():
                    self._previous[k] = v
            elif isinstance(content, (VersionedDict, dict)):
                self._previous = content
            else:
                self._previous = dict(content)
        else:
            self._previous = dict(**kwargs)
        # Values set in this layer.
        self._dict = {}
        # Keys deleted in this layer; they hide entries from _previous.
        self._deleted = set()

    def update(self, content=None, **kwargs):
        if content:
            if isinstance(content, (VersionedDict, dict)):
                content = content.items()
            for k, v in content:
                self[k] = v
        for k, v in kwargs.items():
            self[k] = v

    def __getitem__(self, key):
        # A key deleted in this layer must not resurface from _previous.
        if key in self._deleted:
            raise KeyError(key)
        if key in self._dict:
            return self._dict[key]
        return self._previous[key]

    def get(self, key, default=None):
        try:
            return self[key]
        except KeyError:
            return default

    def __contains__(self, key):
        if key in self._deleted:
            return False
        if key in self._dict:
            return True
        return key in self._previous

    def __delitem__(self, key):
        # Check first, so that deleting a missing or already-deleted key
        # raises KeyError like dict does, without recording a stale entry
        # in _deleted.
        if key not in self:
            raise KeyError(key)
        self._dict.pop(key, None)
        self._deleted.add(key)

    def __setitem__(self, key, value):
        if key in self._deleted:
            self._deleted.remove(key)
        self._dict[key] = value

    def __len__(self):
        return len(self._dict) + sum(1 for k in self._previous
                                     if k not in self._deleted and
                                     k not in self._dict)

    def keys(self):
        if self._previous:
            return list(self)
        return self._dict.keys()

    def values(self):
        if self._previous:
            return list(chain(
                self._dict.values(),
                (v for k, v in self._previous.items()
                 if k not in self._deleted and k not in self._dict)))
        return self._dict.values()

    def __iter__(self):
        if self._previous:
            return chain(self._dict,
                         (k for k in self._previous
                          if k not in self._deleted and k not in self._dict))
        return iter(self._dict)

    def items(self):
        return self.iteritems()

    def iteritems(self):
        if self._previous:
            return chain(
                self._dict.items(),
                ((k, v) for k, v in self._previous.items()
                 if k not in self._deleted and k not in self._dict))
        return self._dict.items()

    CREATED = 1
    MODIFIED = 2
    REMOVED = 3

    def iterchanges(self):
        """Yield ``(status, key, value)`` for each change of this layer
        relative to ``_previous``. For REMOVED entries, ``value`` is the
        previous value. Keys set to their previous value are skipped."""
        for k, v in self._dict.items():
            if k in self._previous:
                if self._previous[k] == v:
                    continue
                status = self.MODIFIED
            else:
                status = self.CREATED
            yield status, k, v
        for k in self._deleted:
            if k in self._previous:
                yield self.REMOVED, k, self._previous[k]

    def flattened(self):
        """Return a single VersionedDict layer over the bottom-most mapping
        of the chain, with all intermediate layers' changes applied."""
        previous = self
        changes = []
        while isinstance(previous, VersionedDict):
            changes.append(previous.iterchanges())
            previous = previous._previous

        ret = VersionedDict(previous)

        # This can probably be optimized, but it shouldn't matter much that
        # it's not.
        for c in reversed(changes):
            for status, k, v in c:
                if status == self.REMOVED:
                    del ret[k]
                else:
                    ret[k] = v
        return ret
419
420
421def _iter_diff_blocks(a, b):
422    m = SequenceMatcher(a=a, b=b, autojunk=False).get_matching_blocks()
423    for start, end in zip(chain((Match(0, 0, 0),), m), m):
424        if start.a + start.size != end.a or start.b + start.size != end.b:
425            yield start.a + start.size, end.a, start.b + start.size, end.b
426
427
def byte_diff(a, b):
    '''Compute a byte-level diff between two byte strings.

    Lines are compared first, then each differing run of lines is diffed
    byte by byte. Yields ``(start, end, replacement)`` tuples, where
    ``start`` and ``end`` are offsets into ``a`` and ``replacement`` is the
    data from ``b`` that replaces ``a[start:end]``. Far from optimal
    results, but works well enough.'''
    lines_a = tuple(a.splitlines(True))
    lines_b = tuple(b.splitlines(True))
    base = 0
    counted = 0
    line_blocks = _iter_diff_blocks(lines_a, lines_b)
    for a_start, a_end, b_start, b_end in line_blocks:
        chunk_a = b''.join(lines_a[a_start:a_end])
        chunk_b = b''.join(lines_b[b_start:b_end])
        # Advance the byte offset to the start of line a_start.
        base += sum(map(len, lines_a[counted:a_start]))
        counted = a_start
        for sa, ea, sb, eb in _iter_diff_blocks(chunk_a, chunk_b):
            yield base + sa, base + ea, chunk_b[sb:eb]
444
445
def sorted_merge(iter_a, iter_b, key=lambda i: i[0], non_key=lambda i: i[1:]):
    """Merge two iterables sorted by ``key`` into one sorted stream.

    Yields ``(key, non_key(item_a), non_key(item_b))`` tuples. When a key
    only appears on one side, the other side's part is ``()``; when both
    current items have the same key, they are paired in one tuple. ``None``
    marks the end of an input, so inputs must not contain None items.

    Items are compared to None explicitly rather than by truthiness, so
    falsy items (e.g. ``0`` with a custom ``key``) don't stall the merge.
    """
    iter_a = iter(iter_a)
    iter_b = iter(iter_b)
    item_a = next(iter_a, None)
    item_b = next(iter_b, None)
    while item_a is not None or item_b is not None:
        if item_b is None or (item_a is not None and
                              key(item_a) < key(item_b)):
            yield key(item_a), non_key(item_a), ()
            item_a = next(iter_a, None)
        elif item_a is None or key(item_b) < key(item_a):
            yield key(item_b), (), non_key(item_b)
            item_b = next(iter_b, None)
        else:
            # Neither key sorts before the other: same key on both sides.
            yield key(item_a), non_key(item_a), non_key(item_b)
            item_a = next(iter_a, None)
            item_b = next(iter_b, None)
467
468
class lrucache(object):
    """Fixed-size least-recently-used cache, also usable as a decorator.

    Entries are kept in ``_cache`` (key -> node) and in a circular doubly
    linked list anchored on the sentinel ``_top``: ``_top.next`` is the most
    recently used entry and ``_top.prev`` the least recently used one, which
    is evicted first once more than ``size`` entries (at least 2) are held.
    When used as a decorator, the cache key is the tuple of positional
    arguments, and the wrapper gains an ``invalidate`` method.
    """

    class node(object):
        """Cell of the linked list, holding one cache entry."""
        __slots__ = ('next', 'prev', 'key', 'value')

        def __init__(self):
            self.prev = None
            self.next = None

        def insert(self, after):
            """Place this node right after ``after``, unlinking it from its
            current position first when it is linked."""
            if self.next and self.prev:
                self.next.prev = self.prev
                self.prev.next = self.next
            following = after.next
            self.prev = after
            self.next = following
            after.next = self
            following.prev = self

        def detach(self):
            """Unlink this node from the list."""
            assert self.next
            assert self.prev
            before, following = self.prev, self.next
            before.next = following
            following.prev = before
            self.prev = None
            self.next = None

    def __init__(self, size):
        self._size = max(size, 2)
        self._cache = {}
        sentinel = self.node()
        sentinel.next = sentinel
        sentinel.prev = sentinel
        self._top = sentinel

    def __call__(self, func):
        cache = self

        @wraps(func)
        def wrapper(*args):
            try:
                return cache[args]
            except KeyError:
                value = func(*args)
                cache[args] = value
                return value
        wrapper.invalidate = cache.invalidate
        return wrapper

    def invalidate(self, *args):
        """Drop the entry for ``args``, or all entries when called without
        arguments."""
        if args:
            try:
                del self[args]
            except KeyError:
                pass
            return
        for key in list(self._cache):
            del self[key]

    def __getitem__(self, key):
        entry = self._cache[key]
        entry.insert(self._top)
        return entry.value

    def __setitem__(self, key, value):
        try:
            entry = self._cache[key]
        except KeyError:
            entry = self.node()
            entry.key = key
        entry.value = value
        entry.insert(self._top)
        self._cache[key] = entry
        while len(self._cache) > self._size:
            oldest = self._top.prev
            oldest.detach()
            del self._cache[oldest.key]

    def __delitem__(self, key):
        self._cache.pop(key).detach()

    def __len__(self):
        count = 0
        cursor = self._top.next
        while cursor is not self._top:
            count += 1
            cursor = cursor.next
        assert count == len(self._cache)
        return len(self._cache)
555
556
557# The following class was copied from mercurial.
558#  Copyright 2005 K. Thananchayan <thananck@yahoo.com>
559#  Copyright 2005-2007 Matt Mackall <mpm@selenic.com>
560#  Copyright 2006 Vadim Gelfer <vadim.gelfer@gmail.com>
class chunkbuffer(object):
    """Allow arbitrary sized chunks of data to be efficiently read from an
    iterator over chunks of arbitrary size."""

    def __init__(self, in_iter):
        """in_iter is the iterator that's iterating over the input chunks."""
        def splitbig(chunks):
            # Split chunks larger than 2**20 bytes into pieces of at most
            # 2**18 bytes.
            for chunk in chunks:
                if len(chunk) > 2 ** 20:
                    pos = 0
                    while pos < len(chunk):
                        end = pos + 2 ** 18
                        yield chunk[pos:end]
                        pos = end
                else:
                    yield chunk
        self.iter = splitbig(in_iter)
        self._queue = deque()
        self._chunkoffset = 0

    def read(self, l=None):
        """Read L bytes of data from the iterator of chunks of data.
        Returns less than L bytes if the iterator runs dry.

        If size parameter is omitted, read everything, including data
        already buffered by previous sized reads."""
        if l is None:
            # Sized reads may have pulled chunks from self.iter into the
            # queue already; their unconsumed part must come first.
            queue = self._queue
            parts = []
            if queue:
                parts.append(queue.popleft()[self._chunkoffset:])
                parts.extend(queue)
                queue.clear()
            self._chunkoffset = 0
            parts.extend(self.iter)
            return b''.join(parts)

        left = l
        buf = []
        queue = self._queue
        while left > 0:
            # refill the queue
            if not queue:
                target = 2 ** 18
                for chunk in self.iter:
                    queue.append(chunk)
                    target -= len(chunk)
                    if target <= 0:
                        break
                if not queue:
                    break

            # The easy way to do this would be to queue.popleft(), modify the
            # chunk (if necessary), then queue.appendleft(). However, for cases
            # where we read partial chunk content, this incurs 2 dequeue
            # mutations and creates a new str for the remaining chunk in the
            # queue. Our code below avoids this overhead.

            chunk = queue[0]
            chunkl = len(chunk)
            offset = self._chunkoffset

            # Use full chunk.
            if offset == 0 and left >= chunkl:
                left -= chunkl
                queue.popleft()
                buf.append(chunk)
                # self._chunkoffset remains at 0.
                continue

            chunkremaining = chunkl - offset

            # Use all of unconsumed part of chunk.
            if left >= chunkremaining:
                left -= chunkremaining
                queue.popleft()
                # offset == 0 is enabled by block above, so this won't merely
                # copy via ``chunk[0:]``.
                buf.append(chunk[offset:])
                self._chunkoffset = 0

            # Partial chunk needed.
            else:
                buf.append(chunk[offset:offset + left])
                self._chunkoffset += left
                left -= chunkremaining

        return b''.join(buf)
640
641
class HTTPReader(object):
    """Readable file-like object over a URL, able to resume interrupted
    transfers.

    When the server sends ``Accept-Ranges: bytes`` and the response length
    is known, a read that stops early (socket error or premature EOF) is
    resumed by reopening the URL with a ``Range`` request starting at the
    current offset, sleeping with a linear backoff between attempts.
    """

    def __init__(self, url):
        url = fsdecode(url)
        self.fh = urlopen(url)
        # If the url was redirected, get the final url for possible future
        # range requests.
        self.url = self.fh.geturl()
        try:
            length = self.fh.headers['content-length']
            self.length = None if length is None else int(length)
        except (ValueError, KeyError):
            self.length = None
        self.can_recover = \
            self.fh.headers.get('Accept-Ranges') == 'bytes'
        self.backoff_period = 0
        self.offset = 0
        self.closed = False

    def read(self, size):
        """Read up to ``size`` bytes, resuming the transfer when possible.

        Fewer bytes are returned only when the data is exhausted or the
        transfer could not be resumed.
        """
        result = []
        length = 0
        while length < size:
            if self.fh is None:
                # A previous resume attempt got an unusable response.
                break
            try:
                buf = self.fh.read(size - length)
            except socket.error:
                buf = b''
            if not buf:
                # Only resume when the total length is known and not reached
                # yet. The explicit None check matters on Python 3, where
                # comparing an int with None raises TypeError.
                if self.can_recover and self.length is not None and \
                        self.offset < self.length:
                    while True:
                        # Linear backoff.
                        self.backoff_period += 1
                        time.sleep(self.backoff_period)
                        try:
                            self.fh = self._reopen()
                            break
                        except Exception:
                            if self.backoff_period >= 10:
                                raise
                    if self.fh:
                        continue
                break
            length += len(buf)
            self.offset += len(buf)
            result.append(buf)
        return b''.join(result)

    def _reopen(self):
        """Reopen the URL with a Range request starting at self.offset.

        Returns the new response, positioned at self.offset, or None when
        the response can't be used to resume (it is closed in that case).
        """
        req = Request(self.url)
        req.add_header('Range', 'bytes=%d-' % self.offset)
        fh = urlopen(req)
        try:
            usable = self._position_at_offset(fh)
        except Exception:
            fh.close()
            raise
        if not usable:
            # Don't leak the connection of a response we can't use.
            fh.close()
            return None
        return fh

    def _position_at_offset(self, fh):
        """Check that ``fh`` is a partial response starting at or before
        self.offset, and skip data up to self.offset. Returns whether ``fh``
        can be used to continue reading."""
        if fh.getcode() != 206:
            return False
        content_range = fh.headers.get('Content-Range') or ''
        unit, _, content_range = content_range.partition(' ')
        if unit != 'bytes':
            return False
        start = content_range.lstrip().partition('-')[0]
        try:
            start = int(start)
        except (TypeError, ValueError):
            start = 0
        if start > self.offset:
            return False
        logging.getLogger('httpreader').debug('Retrying from offset %d', start)
        while start < self.offset:
            skipped = len(fh.read(self.offset - start))
            if not skipped:
                return False
            start += skipped
        return True

    def readable(self):
        return True

    def readinto(self, b):
        buf = self.read(len(b))
        b[:len(buf)] = buf
        return len(buf)
724
725
# Transforms a file object without seek() or tell() into one that has them.
# This only implements enough to make GzipFile happy. It wants to seek to
# the end of the file and back; it also rewinds 8 bytes for the CRC.
class Seekable(object):
    """Minimal seekable view over a forward-only ``reader`` of known
    ``length``.

    Only the last (up to) 8 bytes consumed from ``reader`` are retained, so
    after moving backwards with seek() at most 8 bytes can be re-read.
    Seeking to the end (``SEEK_END`` with offset 0) only moves the logical
    position.
    """

    def __init__(self, reader, length):
        self._reader = reader
        self._length = length
        # Number of bytes consumed from the underlying reader.
        self._read = 0
        # Current logical position.
        self._pos = 0
        # Last (up to) 8 bytes consumed from the reader: the bytes at
        # offsets [_read - len(_buf), _read).
        self._buf = b''

    def read(self, length):
        if self._pos < self._read:
            # Rewound position: serve the bytes from the retained tail,
            # starting where _pos falls within it.
            back = self._read - self._pos
            assert back <= len(self._buf)
            start = len(self._buf) - back
            data = self._buf[start:start + length]
            self._pos += len(data)
            if len(data) < length:
                # The read goes past the retained bytes; _pos == _read now,
                # so the rest comes from the underlying reader.
                data += self.read(length - len(data))
            return data
        assert self._read == self._pos
        data = self._reader.read(length)
        self._read += len(data)
        self._pos = self._read
        # Keep the last 8 bytes we read for GzipFile, even across reads
        # shorter than 8 bytes.
        self._buf = (self._buf + data[-8:])[-8:]
        return data

    def tell(self):
        return self._pos

    def seek(self, pos, how=os.SEEK_SET):
        if how == os.SEEK_END:
            if pos:
                raise NotImplementedError()
            self._pos = self._length
        elif how == os.SEEK_SET:
            self._pos = pos
        elif how == os.SEEK_CUR:
            self._pos += pos
        else:
            raise NotImplementedError()
        return self._pos
768
769
class Process(object):
    """Wrapper around subprocess.Popen with optional I/O logging.

    Positional arguments form the command line. Keyword arguments:
    - stdin: bytes (written as-is) or any other iterable (each item written
      followed by a newline; note that file objects are iterables too) is
      fed to the process through a pipe, which is then closed. Any other
      value (None, subprocess.PIPE, a file descriptor) is given to Popen.
    - stdout (default subprocess.PIPE), stderr, cwd: given to Popen.
    - logger: name of the logger for I/O logging, defaulting to the first
      positional argument. At INFO level, `stdin` is wrapped in an IOLogger
      (which also reads from the process stdout); at DEBUG level, `stdout`
      is that same wrapper.
    - env: variables added on top of `environ()`.
    """

    def __init__(self, *args, **kwargs):
        stdin = kwargs.pop('stdin', None)
        stdout = kwargs.pop('stdout', subprocess.PIPE)
        stderr = kwargs.pop('stderr', None)
        logger = kwargs.pop('logger', args[0])
        env = kwargs.pop('env', {})
        cwd = kwargs.pop('cwd', None)
        assert not kwargs
        # Test for `bytes`, not `str`: on Python 3, `str` is text, and bytes
        # input would otherwise be iterated as integers and fail below.
        if isinstance(stdin, (bytes, Iterable)):
            proc_stdin = subprocess.PIPE
        else:
            proc_stdin = stdin

        full_env = VersionedDict(environ())
        if env:
            full_env.update(env)

        self._proc = self._popen(args, stdin=proc_stdin, stdout=stdout,
                                 stderr=stderr, env=full_env, cwd=cwd)

        logger = logging.getLogger(logger)
        if logger.isEnabledFor(logging.INFO):
            self._stdin = IOLogger(logger, self._proc.stdout, self._proc.stdin,
                                   prefix='[%d]' % self.pid)
        else:
            self._stdin = self._proc.stdin

        if logger.isEnabledFor(logging.DEBUG):
            self._stdout = self._stdin
        else:
            self._stdout = self._proc.stdout

        if proc_stdin == subprocess.PIPE:
            if isinstance(stdin, bytes):
                self._stdin.write(stdin)
            elif isinstance(stdin, Iterable):
                for line in stdin:
                    self._stdin.write(b'%s\n' % line)
            # Close the pipe when we fed the data ourselves (as opposed to
            # the caller asking for subprocess.PIPE explicitly).
            if proc_stdin != stdin:
                self._proc.stdin.close()

    def _env_strings(self, env):
        # `KEY=value` strings for the variables set on top of the base
        # environment, for logging.
        for k, v in sorted((k, v) for s, k, v in env.iterchanges()
                           if s != env.REMOVED):
            yield '%s=%s' % (k, v)

    def _popen(self, cmd, env, **kwargs):
        assert isinstance(env, VersionedDict)
        logger = logging.getLogger('process')
        if logger.isEnabledFor(logging.INFO):
            full_cmd = ' '.join(chain(self._env_strings(env), cmd))
        if not getattr(os, 'supports_bytes_environ', True):
            env = {
                fsdecode(k): fsdecode(v) for k, v in iteritems(env)
            }
        proc = subprocess.Popen(cmd, env=env, **kwargs)
        if logger.isEnabledFor(logging.INFO):
            logger.info('[%d] %s', proc.pid, full_cmd)
        return proc

    def wait(self):
        """Close the process pipes, wait for it, and return its exit code."""
        for fh in (self._proc.stdin, self._proc.stdout, self._proc.stderr):
            if fh:
                fh.close()
        pid = self._proc.pid
        retcode = self._proc.wait()
        logger = logging.getLogger('process')
        if logger.isEnabledFor(logging.INFO):
            logger.info('[%d] Exited with code %d', pid, retcode)
        return retcode

    @property
    def pid(self):
        return self._proc.pid

    @property
    def stdin(self):
        return self._stdin

    @property
    def stdout(self):
        return self._stdout

    @property
    def stderr(self):
        return self._proc.stderr
857
858
class TypedProperty(object):
    """Data descriptor converting assigned values with ``cls``.

    Values are stored per instance in a WeakKeyDictionary, so instances
    must be hashable and weak-referenceable. Reading the attribute on an
    instance where it was never set returns None; reading it on the class
    returns the descriptor itself.
    """

    def __init__(self, cls):
        self.cls = cls
        self.values = WeakKeyDictionary()

    def __get__(self, obj, cls=None):
        if obj is None:
            # Class-level access (e.g. introspection): a weak reference to
            # None can't be created, so the lookup below would raise.
            return self
        return self.values.get(obj)

    def __set__(self, obj, value):
        # If the class has a "from_obj" static or class method, use it.
        # Otherwise, just use cls(value)
        self.values[obj] = getattr(self.cls, 'from_obj', self.cls)(value)
871
872
class MemoryCPUReporter(Thread):
    """Thread logging memory and/or CPU usage of this process tree.

    About once per second, and one last time when stopping, it logs (on the
    `report` logger, set to INFO) the psutil memory info and/or CPU times of
    the current process and of all its descendants. The thread starts as
    soon as it is created; `shutdown` stops it and waits for it.
    """

    def __init__(self, memory=False, cpu=False):
        assert memory or cpu
        super(MemoryCPUReporter, self).__init__()
        self._queue = Queue(1)
        self._logger = logging.getLogger('report')
        self._logger.setLevel(logging.INFO)
        getters = []
        if memory:
            getters.append(lambda p: p.memory_info())
        if cpu:
            getters.append(lambda p: p.cpu_times())
        self._format = '[%s(%d)] %r'
        if len(getters) == 2:
            self._format += ' %r'
        self._info = lambda p: tuple(get(p) for get in getters)
        self.start()

    def _report(self, proc):
        self._logger.info(
            self._format, proc.name(), proc.pid, *self._info(proc))

    def _report_all(self, proc):
        # Descendants are listed before anything is reported.
        descendants = proc.children(recursive=True)
        self._report(proc)
        for child in descendants:
            self._report(child)

    def run(self):
        import psutil
        current = psutil.Process()
        stop = False
        while not stop:
            try:
                # Anything put in the queue (see shutdown) stops the loop;
                # a timeout just triggers another report.
                self._queue.get(True, 1)
                stop = True
            except Empty:
                pass
            except Exception:
                stop = True
            finally:
                self._report_all(current)

    def shutdown(self):
        self._queue.put(None)
        self.join()
914
915
class VersionCheck(Thread):
    """Thread checking, at most once a day, whether git-cinnabar has updates.

    The check only happens when git-cinnabar runs from a git checkout and
    the `no-version-check` check is not enabled. For development versions
    (ending with "a"), it checks whether the tip of the upstream `next`
    branch (when the patch level is 0) or `master` branch (otherwise) is an
    ancestor of the local HEAD. For other versions, it looks for a newer
    version tag. The resulting message, if any, is printed to stderr by
    `join`. The thread starts as soon as it is created.
    """

    def __init__(self):
        super(VersionCheck, self).__init__()
        self.message = None
        self.start()

    def run(self):
        from cinnabar import VERSION
        from cinnabar.git import Git, GitProcess
        parent_dir = os.path.dirname(os.path.dirname(__file__))
        if not os.path.exists(os.path.join(parent_dir, '.git')) or \
                check_enabled('no-version-check') or \
                not interval_expired('version-check', 86400, globl=True):
            return
        # Only import distutils when a check actually happens: it is
        # deprecated (PEP 632) and gone from Python 3.12, so disabled or
        # throttled checks keep working there.
        from distutils.version import StrictVersion
        REPO = 'https://github.com/glandium/git-cinnabar'
        # Context manager, so that the devnull handle is not leaked.
        with open(os.devnull, 'wb') as devnull:
            if VERSION.endswith('a'):
                _, _, extra = StrictVersion(VERSION[:-1]).version
                ref = 'refs/heads/next' if extra == 0 else 'refs/heads/master'
                for line in Git.iter('ls-remote', REPO, ref, stderr=devnull):
                    sha1, head = line.split()
                    if fsdecode(head) != ref:
                        continue
                    proc = GitProcess(
                        '-C', parent_dir, 'merge-base', '--is-ancestor', sha1,
                        'HEAD', stdout=devnull, stderr=devnull)
                    if proc.wait() != 0:
                        self.message = (
                            'The `{}` branch of git-cinnabar was updated. '
                            'Please update your copy.\n'
                            'You can switch to the `release` branch if you '
                            'want to reduce these update notifications.'
                            .format(ref.partition('refs/heads/')[-1]))
                        break
            else:
                version = StrictVersion(VERSION)
                newer_version = version
                for line in Git.iter('ls-remote', REPO, 'refs/tags/*',
                                     stderr=devnull):
                    sha1, tag = line.split()
                    tag = fsdecode(tag).partition('refs/tags/')[-1]
                    try:
                        v = StrictVersion(tag)
                    except ValueError:
                        continue
                    if v > newer_version:
                        newer_version = v
                if newer_version != version:
                    self.message = (
                        'New git-cinnabar version available: {} '
                        '(current version: {})'
                        .format(newer_version, version))

    def join(self, timeout=None):
        """Wait for the check (like Thread.join) and print its message, if
        any."""
        super(VersionCheck, self).join(timeout)
        if self.message:
            sys.stderr.write('\n' + self.message + '\n')
974
975
def run(func, args):
    """Run a git-cinnabar command and exit the process with its status.

    `func(args)` is the command implementation; its return value becomes
    the exit code. Before calling it, this optionally re-executes the
    script under `coverage` (when GIT_CINNABAR_COVERAGE is set), sets up
    logging, and starts the optional memory/CPU reporter and the
    background version check.

    Exit codes:
      - 65 (data format error) when the repository's object format is
        not sha1;
      - 1 when `func` raises Abort (message logged, no traceback);
      - 70 (internal software error) for any other Exception, or when the
        `no-mercurial` check finds Mercurial modules were imported.

    Never returns: always ends with sys.exit().
    """
    reexec = None
    if os.environ.pop('GIT_CINNABAR_COVERAGE', None):
        # The variable is popped from the environment first, so the
        # re-executed process won't re-execute itself again.
        reexec = [sys.executable, '-m', 'coverage', 'run', '--append']
    init_logging()
    if reexec:
        reexec.append(os.path.abspath(sys.argv[0]))
        reexec.extend(sys.argv[1:])
        os.execlp(reexec[0], *reexec)
        assert False  # execlp() replaces the process or raises.

    reporter = None
    if check_enabled('memory') or check_enabled('cpu'):
        reporter = MemoryCPUReporter(memory=check_enabled('memory'),
                                     cpu=check_enabled('cpu'))

    version_check = VersionCheck()
    try:
        from cinnabar.git import Git
        objectformat = Git.config('extensions.objectformat') or 'sha1'
        if objectformat != 'sha1':
            sys.stderr.write(
                'Git repository uses unsupported %s object format\n'
                % objectformat)
            retcode = 65  # Data format error
        else:
            retcode = func(args)
    except Abort as e:
        # These exceptions are normal abort and require no traceback
        retcode = 1
        logging.error(str(e))
    except Exception as e:
        # Catch all exceptions and provide a nice message
        retcode = 70  # Internal software error
        message = getattr(e, 'message', None) or getattr(e, 'reason', None)
        message = message or str(e)
        if check_enabled('traceback') or not message:
            traceback.print_exc()
        else:
            logging.error(message)

            sys.stderr.write(
                'Run the command again with '
                '`git -c cinnabar.check=traceback <command>` to see the '
                'full traceback.\n')
    finally:
        # Stop the reporter only if one was actually started above.
        if reporter is not None:
            reporter.shutdown()
        version_check.join()
    if check_enabled('no-mercurial'):
        if any(k.startswith('mercurial.') or k == 'mercurial'
               for k in sys.modules):
            sys.stderr.write('Mercurial libraries were loaded!\n')
            retcode = 70
    sys.exit(retcode)
1031
1032
# Python 2/3 compatibility helpers. Both branches define the same names:
#   iteritems(d) / itervalues(d): iterators over a dict's items / values.
#   fsencode(s) / fsdecode(s): convert between filesystem bytes and str
#       (identity functions on Python 2, where str is already bytes).
#   environ(k=None): the process environment with byte-string keys and
#       values; given `k`, that variable's value, or None when unset.
if sys.version_info[0] == 3:
    def iteritems(d):
        return iter(d.items())

    def itervalues(d):
        return iter(d.values())

    fsencode = os.fsencode
    fsdecode = os.fsdecode

    def environ(k=None):
        if os.supports_bytes_environ:
            if k is None:
                return os.environb
            # os.environb only accepts bytes keys. Encode the key (a no-op
            # for bytes) so a str key works here just as it does in the
            # branch below.
            return os.environb.get(fsencode(k))

        # No native bytes environment (e.g. Windows): derive byte strings
        # from the str-based os.environ.
        if k is None:
            return {
                fsencode(key): fsencode(value)
                for key, value in iteritems(os.environ)
            }
        v = os.environ.get(fsdecode(k))
        if v is None:
            return None
        return fsencode(v)
else:
    def iteritems(d):
        return d.iteritems()

    def itervalues(d):
        return d.itervalues()

    # On Python 2, str is already a byte string: nothing to convert.
    def fsencode(s):
        return s

    def fsdecode(s):
        return s

    def environ(k=None):
        if k is None:
            return os.environ
        return os.environ.get(k)
1076
# Byte-oriented standard streams: the underlying binary buffer when the
# stream has one (Python 3), otherwise the stream itself (Python 2, where it
# already deals in bytes). Each stream is probed on its own, so replacing
# only one of them with an object lacking `.buffer` doesn't break importing
# this module.
bytes_stdout = getattr(sys.stdout, 'buffer', sys.stdout)
bytes_stdin = getattr(sys.stdin, 'buffer', sys.stdin)
1083