1from __future__ import division
2from __future__ import print_function
3
4import argparse
5import json
6import netrc
7import ntpath
8import os
9import platform
10import re
11import subprocess
12import sys
13import types
14from datetime import datetime
15from decimal import Decimal
16from functools import partial
17from os.path import basename
18from os.path import dirname
19from os.path import exists
20from os.path import join
21from os.path import split
22
23import genericpath
24
25from .compat import PY3
26from .compat import PY38
27
28# This is here (in the utils module) because it might be used by
29# various other modules.
30try:
31    from pathlib2 import Path  # noqa: F401
32except ImportError:
33    from pathlib import Path  # noqa: F401
34
35try:
36    from urllib.parse import parse_qs
37    from urllib.parse import urlparse
38except ImportError:
39    from urlparse import parse_qs
40    from urlparse import urlparse
41
# Prefer the standard library's ``check_output``/``CalledProcessError``. The
# ``except`` branch only runs when importing them from ``subprocess`` fails and
# then defines local stand-ins with the same names.
try:
    from subprocess import CalledProcessError
    from subprocess import check_output
except ImportError:
    class CalledProcessError(subprocess.CalledProcessError):
        # Same as the base exception, but always exposes an ``output`` attribute.
        def __init__(self, returncode, cmd, output=None):
            super(CalledProcessError, self).__init__(returncode, cmd)
            self.output = output

    def check_output(*popenargs, **kwargs):
        """Run a command and return what it wrote to stdout.

        Arguments are forwarded to ``subprocess.Popen``; ``stdout`` may not be
        passed because it is always set to a pipe. Raises ``ValueError`` if
        ``stdout`` is given and ``CalledProcessError`` (carrying the captured
        output) when the process exits with a non-zero status.
        """
        if 'stdout' in kwargs:
            raise ValueError('stdout argument not allowed, it will be overridden.')
        process = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs)
        output, unused_err = process.communicate()
        retcode = process.poll()
        if retcode:
            # Report the command the way the caller supplied it: the ``args``
            # keyword if present, otherwise the first positional argument.
            cmd = kwargs.get("args")
            if cmd is None:
                cmd = popenargs[0]
            raise CalledProcessError(retcode, cmd, output)
        return output
63
# Human readable label for each time unit prefix; the keys are the same
# prefixes that time_unit() returns.
TIME_UNITS = {
    "": "Seconds",
    "m": "Milliseconds (ms)",
    "u": "Microseconds (us)",
    "n": "Nanoseconds (ns)"
}
# Column names accepted by parse_columns().
ALLOWED_COLUMNS = ["min", "max", "mean", "stddev", "median", "iqr", "ops", "outliers", "rounds", "iterations"]
71
72
class SecondsDecimal(Decimal):
    """A ``Decimal`` amount of seconds whose ``str()`` is a scaled duration.

    ``str()`` goes through ``format_time`` and appends ``s``; the untouched
    decimal text is available through ``as_string``.
    """

    @property
    def as_string(self):
        """The plain ``Decimal`` text, without any unit scaling."""
        return Decimal.__str__(self)

    def __str__(self):
        seconds = float(Decimal.__str__(self))
        return "{0}s".format(format_time(seconds))

    def __float__(self):
        # Convert via the plain decimal text rather than the overridden __str__.
        plain = Decimal.__str__(self)
        return float(plain)
83
84
class NameWrapper(object):
    """Wrap an object so that ``str()`` yields its dotted ``module.name``.

    The module prefix is only included when the target has ``__module__``;
    when the target has no ``__name__`` its ``repr()`` is used instead.
    """

    def __init__(self, target):
        self.target = target

    def __repr__(self):
        inner = repr(self.target)
        return "NameWrapper(%s)" % inner

    def __str__(self):
        target = self.target
        if hasattr(target, '__module__'):
            prefix = target.__module__ + "."
        else:
            prefix = ""
        if hasattr(target, '__name__'):
            return prefix + target.__name__
        return prefix + repr(target)
96
97
def get_tag(project_name=None):
    """Return a tag made of the commit id and the current UTC timestamp.

    The pieces are joined with underscores; a trailing marker is appended when
    the commit info reports a dirty checkout.
    """
    commit = get_commit_info(project_name)
    tag = "_".join((commit['id'], get_current_time()))
    if not commit['dirty']:
        return tag
    return "_".join((tag, "uncommited-changes"))
104
105
def get_machine_id():
    """Return ``<system>-<implementation>-<major.minor>-<architecture bits>``."""
    major_minor = ".".join(platform.python_version_tuple()[:2])
    bits = platform.architecture()[0]
    fields = (platform.system(), platform.python_implementation(), major_minor, bits)
    return "%s-%s-%s-%s" % fields
113
114
class Fallback(object):
    """Callable that tries registered functions before a fallback.

    Registered functions are tried in registration order. A function that
    raises one of ``exceptions`` or returns a falsy value is skipped; the first
    truthy result wins. If none produces one, ``fallback`` is called.
    """

    def __init__(self, fallback, exceptions):
        self.exceptions = exceptions
        self.fallback = fallback
        self.functions = []

    def register(self, other):
        """Add ``other`` to the candidates; returns this ``Fallback`` itself."""
        self.functions.append(other)
        return self

    def __call__(self, *args, **kwargs):
        for candidate in self.functions:
            try:
                result = candidate(*args, **kwargs)
            except self.exceptions:
                # An expected failure just means "try the next candidate".
                result = None
            if result:
                return result
        return self.fallback(*args, **kwargs)
136
137
@partial(Fallback, exceptions=(IndexError, CalledProcessError, OSError))
def get_project_name():
    """Last-resort project name: the name of the current working directory."""
    working_dir = os.getcwd()
    return basename(working_dir)
141
142
@get_project_name.register
def get_project_name_git():
    """Derive the project name from the git ``remote.origin.url`` setting.

    Returns the last non-empty piece of the URL after splitting on slashes,
    colons, backslashes, whitespace and ``.git``. Returns ``None`` when
    ``git rev-parse --git-dir`` prints nothing.
    """
    git_dir = check_output(['git', 'rev-parse', '--git-dir'], stderr=subprocess.STDOUT)
    if not git_dir:
        return None
    remote = check_output(['git', 'config', '--local', 'remote.origin.url'])
    if str != bytes and isinstance(remote, bytes):
        remote = remote.decode()
    pieces = [piece for piece in re.split(r'[/:\s\\]|\.git', remote) if piece]
    return pieces[-1].strip()
152
153
@get_project_name.register
def get_project_name_hg():
    """Derive the project name from the last segment of ``hg path default``."""
    with open(os.devnull, 'w') as sink:
        # stderr is discarded so a missing repository stays quiet.
        raw = check_output(['hg', 'path', 'default'], stderr=sink)
    last_segment = raw.decode().split("/")[-1]
    return last_segment.strip()
161
162
def in_any_parent(name, path=None):
    """Tell whether ``name`` exists in ``path`` or in any of its parents.

    ``path`` defaults to the current working directory. The walk stops at the
    first directory containing ``name`` or once ``dirname`` stops changing.
    """
    current = path or os.getcwd()
    previous = None
    while current and current != previous:
        if exists(join(current, name)):
            break
        previous = current
        current = dirname(current)
    return exists(join(current, name))
171
172
def subprocess_output(cmd):
    """Run ``cmd`` (split on whitespace) and return its stripped text output.

    stderr is merged into stdout. Because the command is split with
    ``str.split``, quote characters stay part of the arguments.
    """
    argv = cmd.split()
    raw = check_output(argv, stderr=subprocess.STDOUT, universal_newlines=True)
    return raw.strip()
175
176
def get_commit_info(project_name=None):
    """Collect version-control details for the current working directory.

    Looks for a ``.git`` or ``.hg`` entry in the working directory or any of
    its parents and queries the matching tool. Returns a dict with the keys
    ``id``, ``time``, ``author_time``, ``dirty``, ``project`` and ``branch``.
    Outside of a repository ``id`` is ``'unversioned'``. If anything raises,
    ``id`` is ``'unknown'``, the times are ``None`` and an extra ``error`` key
    describes the failure; ``dirty`` and ``branch`` keep whatever value they
    had reached before the failure.

    ``project_name`` falls back to ``get_project_name()`` when empty.
    """
    dirty = False
    commit = 'unversioned'
    commit_time = None
    author_time = None
    project_name = project_name or get_project_name()
    branch = '(unknown)'
    try:
        if in_any_parent('.git'):
            desc = subprocess_output('git describe --dirty --always --long --abbrev=40')
            desc = desc.split('-')
            # A trailing "dirty" component marks uncommitted changes.
            if desc[-1].strip() == 'dirty':
                dirty = True
                desc.pop()
            # The last remaining component is the hash, possibly prefixed with "g".
            commit = desc[-1].strip('g')
            # subprocess_output splits on whitespace only, so the quotes reach
            # git literally and have to be stripped from its output.
            commit_time = subprocess_output('git show -s --pretty=format:"%cI"').strip('"')
            author_time = subprocess_output('git show -s --pretty=format:"%aI"').strip('"')
            branch = subprocess_output('git rev-parse --abbrev-ref HEAD')
            if branch == 'HEAD':
                branch = '(detached head)'
        elif in_any_parent('.hg'):
            desc = subprocess_output('hg id --id --debug')
            # A trailing "+" marks uncommitted changes.
            if desc[-1] == '+':
                dirty = True
            commit = desc.strip('+')
            commit_time = subprocess_output('hg tip --template "{date|rfc3339date}"').strip('"')
            branch = subprocess_output('hg branch')
        return {
            'id': commit,
            'time': commit_time,
            'author_time': author_time,
            'dirty': dirty,
            'project': project_name,
            'branch': branch,
        }
    except Exception as exc:
        # Best effort: never let version-control problems propagate; report
        # them in the result instead.
        return {
            'id': 'unknown',
            'time': None,
            'author_time': None,
            'dirty': dirty,
            'error': 'CalledProcessError({0.returncode}, {0.output!r})'.format(exc)
                     if isinstance(exc, CalledProcessError) else repr(exc),
            'project': project_name,
            'branch': branch,
        }
223
224
def get_current_time():
    """Return the current UTC time formatted as ``YYYYMMDD_HHMMSS``."""
    now = datetime.utcnow()
    return now.strftime("%Y%m%d_%H%M%S")
227
228
def first_or_value(obj, value):
    """Return the single item of ``obj``, or ``value`` when ``obj`` is falsy.

    A truthy ``obj`` must contain exactly one item, otherwise the unpacking
    raises ``ValueError``.
    """
    if not obj:
        return value
    (only,) = obj
    return only
234
235
def short_filename(path, machine_id=None):
    """Return a compact ``/``-joined form of ``path``.

    A leading component equal to ``machine_id`` is dropped and the final
    component loses its last extension. Objects without a ``parts`` attribute
    are returned as ``str(path)``.
    """
    try:
        components = list(path.parts)
    except AttributeError:
        return str(path)
    # Drop the machine id first, so a lone component is compared unmodified.
    if components and components[0] == machine_id:
        components = components[1:]
    if components:
        components[-1] = components[-1].rsplit('.', 1)[0]
    return '/'.join(components)
251
252
def load_timer(string):
    """Resolve a dotted ``module.attr`` string to a wrapped timer callable.

    The pseudo-module ``pep418`` maps to the standard ``time`` module on
    Python 3 and to the bundled ``pep418`` module otherwise. Raises
    ``argparse.ArgumentTypeError`` when ``string`` contains no dot.
    """
    if "." not in string:
        raise argparse.ArgumentTypeError("Value for --benchmark-timer must be in dotted form. Eg: 'module.attr'.")
    mod, attr = string.rsplit(".", 1)
    if mod != 'pep418':
        __import__(mod)
        return NameWrapper(getattr(sys.modules[mod], attr))
    if PY3:
        import time as source
    else:
        from . import pep418 as source
    return NameWrapper(getattr(source, attr))
268
269
class RegressionCheck(object):
    """Base for threshold checks on one statistics field.

    Subclasses supply ``compute(current, compared)``; ``fails`` reports when
    the computed value is above ``threshold``.
    """

    def __init__(self, field, threshold):
        self.threshold = threshold
        self.field = field

    def fails(self, current, compared):
        """Return a failure message, or ``None`` when within the threshold."""
        measured = self.compute(current, compared)
        if not measured > self.threshold:
            return None
        details = (self.field, self.__class__.__name__, measured, self.threshold)
        return "Field %r has failed %s: %.9f > %.9f" % details
281
282
class PercentageRegressionCheck(RegressionCheck):
    """Regression measured as the percentage increase over the compared run."""

    def compute(self, current, compared):
        baseline = compared[self.field]
        if baseline:
            return current[self.field] / baseline * 100 - 100
        # A zero/empty baseline counts as an infinite regression.
        return float("inf")
289
290
class DifferenceRegressionCheck(RegressionCheck):
    """Regression measured as the absolute increase over the compared run."""

    def compute(self, current, compared):
        field = self.field
        return current[field] - compared[field]
294
295
def parse_compare_fail(string,
                       rex=re.compile(r'^(?P<field>min|max|mean|median|stddev|iqr):'
                                      r'((?P<percentage>[0-9]?[0-9])%|(?P<difference>[0-9]*\.?[0-9]+([eE][-+]?['
                                      r'0-9]+)?))$')):
    """Turn a ``field:limit`` expression into a regression check.

    ``field:NN%`` yields a ``PercentageRegressionCheck`` and ``field:NUMBER``
    a ``DifferenceRegressionCheck``. Anything else raises
    ``argparse.ArgumentTypeError``.
    """
    found = rex.match(string)
    if found:
        groups = found.groupdict()
        field = groups['field']
        if groups['percentage']:
            return PercentageRegressionCheck(field, int(groups['percentage']))
        if groups['difference']:
            return DifferenceRegressionCheck(field, float(groups['difference']))
    raise argparse.ArgumentTypeError("Could not parse value: %r." % string)
309
310
def parse_warmup(string):
    """Parse the warmup option into a boolean.

    ``auto`` enables warmup only on PyPy; the empty string counts as "on".
    Unknown values raise ``argparse.ArgumentTypeError``.
    """
    value = string.lower().strip()
    if value == "auto":
        return platform.python_implementation() == "PyPy"
    if value in ("on", "true", "yes", ""):
        return True
    if value in ("off", "false", "no"):
        return False
    raise argparse.ArgumentTypeError("Could not parse value: %r." % value)
321
322
def name_formatter_short(bench):
    """Short label: name without ``test_`` plus 4 chars of the source file name."""
    label = bench["name"]
    source = bench["source"]
    if source:
        label = "%s (%.4s)" % (label, split(source)[-1])
    return label[5:] if label.startswith("test_") else label
330
331
def name_formatter_normal(bench):
    """Label with the source path, its last segment cut to 12 characters."""
    label = bench["name"]
    source = bench["source"]
    if not source:
        return label
    head, sep, tail = source.rpartition('/')
    return "%s (%s)" % (label, head + sep + tail[:12])
339
340
def name_formatter_long(bench):
    """Full name, followed by the source in parentheses when there is one."""
    if not bench["source"]:
        return bench["fullname"]
    return "%(fullname)s (%(source)s)" % bench
346
347
def name_formatter_trial(bench):
    """First 4 characters of the source file name, or ``????`` without source."""
    source = bench["source"]
    if not source:
        return '????'
    file_name = split(source)[-1]
    return "%.4s" % file_name
353
354
# Name format identifier -> function turning a benchmark dict into a label.
# parse_name_format() accepts exactly these keys.
NAME_FORMATTERS = {
    "short": name_formatter_short,
    "normal": name_formatter_normal,
    "long": name_formatter_long,
    "trial": name_formatter_trial,
}
361
362
def parse_name_format(string):
    """Normalize a name format and check it is a key of ``NAME_FORMATTERS``.

    Returns the lowercased, stripped name; raises
    ``argparse.ArgumentTypeError`` for unknown formats.
    """
    key = string.lower().strip()
    if key not in NAME_FORMATTERS:
        raise argparse.ArgumentTypeError("Could not parse value: %r." % key)
    return key
369
370
def parse_timer(string):
    """Validate a dotted timer name by loading it; return its string form."""
    timer = load_timer(string)
    return str(timer)
373
374
def parse_sort(string):
    """Normalize a sort column name, rejecting anything unsupported.

    Returns the lowercased, stripped name; raises
    ``argparse.ArgumentTypeError`` for other values.
    """
    key = string.lower().strip()
    if key in ("min", "max", "mean", "stddev", "name", "fullname"):
        return key
    raise argparse.ArgumentTypeError(
        "Unacceptable value: %r. "
        "Value for --benchmark-sort must be one of: 'min', 'max', 'mean', "
        "'stddev', 'name', 'fullname'." % key)
383
384
def parse_columns(string):
    """Split a comma separated column list and validate every name.

    Names are lowercased and stripped of surrounding whitespace. Returns the
    list of names in the order given. Raises ``argparse.ArgumentTypeError``
    when any name is not in ``ALLOWED_COLUMNS``; the message lists each
    invalid name once, in the order it was given, so it is reproducible.
    """
    columns = [name.strip() for name in string.lower().split(',')]
    invalid = []
    for name in columns:
        if name not in ALLOWED_COLUMNS and name not in invalid:
            invalid.append(name)
    if invalid:
        # there are extra items in columns!
        msg = "Invalid column name(s): %s. " % ', '.join(invalid)
        msg += "The only valid column names are: %s" % ', '.join(ALLOWED_COLUMNS)
        raise argparse.ArgumentTypeError(msg)
    return columns
394
395
def parse_rounds(string):
    """Parse the rounds option as an integer that must be at least 1.

    Raises ``argparse.ArgumentTypeError`` for non-integers and values below 1.
    """
    try:
        rounds = int(string)
    except ValueError as exc:
        raise argparse.ArgumentTypeError(exc)
    if rounds >= 1:
        return rounds
    raise argparse.ArgumentTypeError("Value for --benchmark-rounds must be at least 1.")
405
406
def parse_seconds(string):
    """Validate a decimal number of seconds and return its plain string form.

    Any failure while building the ``SecondsDecimal`` is reported as
    ``argparse.ArgumentTypeError``.
    """
    try:
        seconds = SecondsDecimal(string)
        return seconds.as_string
    except Exception as exc:
        raise argparse.ArgumentTypeError("Invalid decimal value %r: %r" % (string, exc))
412
413
def parse_save(string):
    """Validate a save name: non-empty and free of path-hostile characters.

    Returns ``string`` unchanged; raises ``argparse.ArgumentTypeError``
    otherwise, naming the offending characters.
    """
    if not string:
        raise argparse.ArgumentTypeError("Can't be empty.")
    found = [char for char in r"\/:*?<>|" if char in string]
    if found:
        raise argparse.ArgumentTypeError(
            "Must not contain any of these characters: /:*?<>|\\ (it has %r)" % ''.join(found))
    return string
421
422
def _parse_hosts(storage_url, netrc_file):
    """Build one URL per host of ``storage_url``, adding netrc credentials.

    ``storage_url`` is a parsed URL whose ``netloc`` may hold several comma
    separated hosts. ``netrc_file`` is the path of a netrc file; it may be
    empty or ``None``, in which case no credentials are looked up. Hosts that
    already carry credentials (contain ``@``) are left untouched.

    Returns a list of ``scheme://[user:secret@]netloc`` strings.
    """
    # load creds from netrc file; the path is only expanded once we know
    # there is one, so a missing (None) setting is not an error
    creds = None
    if netrc_file:
        path = os.path.expanduser(netrc_file)
        if os.path.isfile(path):
            creds = netrc.netrc(path)

    # add creds to urls
    urls = []
    for netloc in storage_url.netloc.split(','):
        auth = ""
        if creds is not None and '@' not in netloc:
            # netrc entries are keyed by host name, without the port
            host = netloc.split(':')[0]
            res = creds.authenticators(host)
            if res:
                user, _, secret = res
                auth = "{user}:{secret}@".format(user=user, secret=secret)
        url = "{scheme}://{auth}{netloc}".format(scheme=storage_url.scheme,
                                                 netloc=netloc, auth=auth)
        urls.append(url)
    return urls
445
446
def parse_elasticsearch_storage(string, default_index="benchmark",
                                default_doctype="benchmark", netrc_file=''):
    """Split an Elasticsearch storage URL into its connection settings.

    The first path segment overrides the index and the second the doctype.
    The project name comes from the ``project_name`` query parameter, falling
    back to ``get_project_name()``.

    Returns ``(hosts, index, doctype, project_name)``.
    """
    storage_url = urlparse(string)
    hosts = _parse_hosts(storage_url, netrc_file)

    index, doctype = default_index, default_doctype
    url_path = storage_url.path
    if url_path and url_path != "/":
        segments = url_path.strip("/").split("/")
        index = segments[0]
        if len(segments) > 1:
            doctype = segments[1]

    query = parse_qs(storage_url.query)
    if "project_name" in query:
        project_name = query["project_name"][0]
    else:
        project_name = get_project_name()
    return hosts, index, doctype, project_name
464
465
def load_storage(storage, **kwargs):
    """Create the storage backend described by the ``storage`` URI.

    A value without ``://`` is treated as a ``file://`` path. The ``netrc``
    keyword is required and is removed from ``kwargs`` before they are passed
    on to the backend. Unsupported schemes raise
    ``argparse.ArgumentTypeError``.
    """
    file_scheme = "file://"
    es_scheme = "elasticsearch+"
    if "://" not in storage:
        storage = file_scheme + storage
    netrc_file = kwargs.pop('netrc')  # only used by elasticsearch storage

    if storage.startswith(file_scheme):
        from .storage.file import FileStorage
        return FileStorage(storage[len(file_scheme):], **kwargs)

    if storage.startswith(es_scheme):
        from .storage.elasticsearch import ElasticsearchStorage

        # TODO update benchmark_autosave
        args = parse_elasticsearch_storage(storage[len(es_scheme):],
                                           netrc_file=netrc_file)
        return ElasticsearchStorage(*args, **kwargs)

    raise argparse.ArgumentTypeError("Storage must be in form of file://path or "
                                     "elasticsearch+http[s]://host1,host2/index/doctype")
483
484
def time_unit(value):
    """Pick a unit prefix and multiplier for a duration given in seconds.

    Returns ``(prefix, multiplier)`` where prefix is ``n``, ``u``, ``m`` or
    the empty string.
    """
    scales = ((1e-6, "n", 1e9), (1e-3, "u", 1e6), (1, "m", 1e3))
    for limit, prefix, multiplier in scales:
        if value < limit:
            return prefix, multiplier
    return "", 1.
494
495
def operations_unit(value):
    """Pick a unit prefix and multiplier for an operations count.

    Returns ``("M", 1e-6)`` above a million, ``("K", 1e-3)`` above a thousand
    and ``("", 1.)`` otherwise.
    """
    for limit, prefix, multiplier in ((1e+6, "M", 1e-6), (1e+3, "K", 1e-3)):
        if value > limit:
            return prefix, multiplier
    return "", 1.
502
503
def format_time(value):
    """Format seconds as a two-decimal number followed by its unit prefix."""
    prefix, multiplier = time_unit(value)
    scaled = value * multiplier
    return "{0:.2f}{1:s}".format(scaled, prefix)
507
508
class cached_property(object):
    """Descriptor that computes a value once and stores it on the instance.

    The result is written into the instance ``__dict__`` under the wrapped
    function's name.
    """

    def __init__(self, func):
        self.func = func
        self.__doc__ = func.__doc__

    def __get__(self, obj, cls):
        if obj is None:
            # Accessed on the class: hand back the descriptor itself.
            return self
        result = self.func(obj)
        obj.__dict__[self.func.__name__] = result
        return result
519
520
521def funcname(f):
522    try:
523        if isinstance(f, partial):
524            return f.func.__name__
525        else:
526            return f.__name__
527    except AttributeError:
528        return str(f)
529
530
531# from: https://bitbucket.org/antocuni/pypytools/src/tip/pypytools/util.py?at=default
def clonefunc(f):
    """Deep clone the given function to create a new one.

    By default, the PyPy JIT specializes the assembler based on f.__code__:
    clonefunc makes sure that you will get a new function with a **different**
    __code__, so that PyPy will produce independent assembler. This is useful
    e.g. for benchmarks and microbenchmarks, so you can make sure to compare
    apples to apples.

    Use it with caution: if abused, this might easily produce an explosion of
    produced assembler.

    Objects without a ``__code__`` attribute (builtins, callables implemented
    in C, ...) are returned unchanged. The clone shares the globals, positional
    defaults and closure of ``f`` and gets a copy of its keyword-only defaults.
    """
    # first of all, we clone the code object
    if not hasattr(f, '__code__'):
        return f
    co = f.__code__
    if hasattr(co, 'replace'):
        # Code objects that offer replace() can copy themselves; this avoids
        # depending on the positional signature of types.CodeType, which
        # differs between interpreter versions.
        co2 = co.replace()
    else:
        args = [co.co_argcount, co.co_nlocals, co.co_stacksize, co.co_flags, co.co_code,
                co.co_consts, co.co_names, co.co_varnames, co.co_filename, co.co_name,
                co.co_firstlineno, co.co_lnotab, co.co_freevars, co.co_cellvars]
        # The constructor expects (argcount, posonlyargcount, kwonlyargcount,
        # nlocals, ...), so kwonly is inserted first and posonly in front of it.
        if hasattr(co, 'co_kwonlyargcount'):
            args.insert(1, co.co_kwonlyargcount)
        if hasattr(co, 'co_posonlyargcount'):
            args.insert(1, co.co_posonlyargcount)
        co2 = types.CodeType(*args)
    #
    # then, we clone the function itself, using the new co2
    f2 = types.FunctionType(co2, f.__globals__, f.__name__, f.__defaults__, f.__closure__)
    # FunctionType() has no parameter for keyword-only defaults; without this
    # the clone would demand values for arguments the original made optional.
    kwdefaults = getattr(f, '__kwdefaults__', None)
    if kwdefaults:
        f2.__kwdefaults__ = dict(kwdefaults)
    return f2
560
561
def format_dict(obj):
    """Render a dict as ``{key: json, ...}`` with the items sorted."""
    rendered = []
    for key, value in sorted(obj.items()):
        rendered.append("%s: %s" % (key, json.dumps(value)))
    return "{%s}" % ", ".join(rendered)
564
565
class SafeJSONEncoder(json.JSONEncoder):
    """JSON encoder that never fails on unsupported objects.

    Anything the base encoder cannot serialize is replaced by a placeholder
    string containing its ``repr``.
    """

    def default(self, o):
        placeholder = "UNSERIALIZABLE[%r]" % o
        return placeholder
569
570
def safe_dumps(obj, **kwargs):
    """Serialize ``obj`` to JSON using ``SafeJSONEncoder``.

    Extra keyword arguments are forwarded to ``json.dumps``.
    """
    serialized = json.dumps(obj, cls=SafeJSONEncoder, **kwargs)
    return serialized
573
574
def report_progress(iterable, terminal_reporter, format_string, **kwargs):
    """Iterate ``iterable`` while rewriting a progress line on the terminal.

    ``format_string`` is formatted with ``pos`` (1-based), ``total``,
    ``value`` and any extra keyword arguments. The length of ``iterable`` is
    taken immediately; the reporting itself happens lazily.

    Returns a generator of ``(message, item)`` pairs.
    """
    total = len(iterable)

    def announce_each():
        position = 0
        for item in iterable:
            position += 1
            message = format_string.format(pos=position, total=total, value=item, **kwargs)
            terminal_reporter.rewrite(message, black=True, bold=True)
            yield message, item

    return announce_each()
585
586
def report_noprogress(iterable, *args, **kwargs):
    """Silent counterpart of ``report_progress``.

    Yields ``("", item)`` for every item; the extra arguments are ignored.
    """
    for item in iterable:
        yield "", item
590
591
def report_online_progress(progress_reporter, tr, line):
    """Push a single ``line`` through ``progress_reporter``.

    The reporter is called with a one-item list and the ``{value}`` format,
    then advanced exactly once.
    """
    reporter = progress_reporter([line], tr, "{value}")
    next(reporter)
594
595
def slugify(name):
    """Replace path-hostile characters and spaces in ``name`` by underscores.

    After each character is replaced, double underscores are collapsed once.
    """
    slug = name
    for unwanted in r"\/:*?<>| ":
        slug = slug.replace(unwanted, '_')
        slug = slug.replace('__', '_')
    return slug
600
601
def commonpath(paths):
    """Given a sequence of path names, returns the longest common sub-path.

    Paths are always handled with backslash as the separator: ``/`` is
    converted to a backslash first, components are compared in lowercase, and
    the result is joined with backslashes. The components of the result keep
    the spelling used in the first path.

    Raises ``ValueError`` for an empty sequence, when absolute and relative
    paths are mixed, or when the paths do not share the same drive.
    """

    if not paths:
        raise ValueError('commonpath() arg is an empty sequence')

    # Use separators of the same type (bytes or str) as the paths themselves.
    if isinstance(paths[0], bytes):
        sep = b'\\'
        altsep = b'/'
        curdir = b'.'
    else:
        sep = '\\'
        altsep = '/'
        curdir = '.'

    try:
        # Normalized (single separator, lowercase) copies used for comparing.
        drivesplits = [ntpath.splitdrive(p.replace(altsep, sep).lower()) for p in paths]
        split_paths = [p.split(sep) for d, p in drivesplits]

        # Unpacking a one-element set only works if every path agrees on
        # being absolute (or not).
        try:
            isabs, = set(p[:1] == sep for d, p in drivesplits)
        except ValueError:
            raise ValueError("Can't mix absolute and relative paths")

        # Check that all drive letters or UNC paths match. The check is made only
        # now otherwise type errors for mixing strings and bytes would not be
        # caught.
        if len(set(d for d, p in drivesplits)) != 1:
            raise ValueError("Paths don't have the same drive")

        # The result is built from the first path with its original casing.
        drive, path = ntpath.splitdrive(paths[0].replace(altsep, sep))
        common = path.split(sep)
        common = [c for c in common if c and c != curdir]

        # Drop empty and "." components, then compare only the smallest and
        # largest component lists: they differ at the earliest position where
        # any two of the lists differ.
        split_paths = [[c for c in s if c and c != curdir] for s in split_paths]
        s1 = min(split_paths)
        s2 = max(split_paths)
        for i, c in enumerate(s1):
            if c != s2[i]:
                common = common[:i]
                break
        else:
            # s1 is a prefix of every other path.
            common = common[:len(s1)]

        prefix = drive + sep if isabs else drive
        return prefix + sep.join(common)
    except (TypeError, AttributeError):
        # NOTE(review): presumably lets genericpath raise a clearer error for
        # bad or mixed argument types; otherwise the original error is re-raised.
        genericpath._check_arg_types('commonpath', *paths)
        raise
651
652
def get_cprofile_functions(stats):
    """
    Convert pstats structure to a list of dicts, one per profiled function.

    File paths located under the parent of the current working directory are
    made relative to it. Each dict holds ``ncalls_recursion``, ``ncalls``,
    ``tottime``, ``tottime_per``, ``cumtime``, ``cumtime_per`` and
    ``function_name``.
    """
    # this assumes that you run py.test from project root dir
    project_dir_parent = dirname(os.getcwd())
    prefix_length = len(project_dir_parent)

    rows = []
    for function_info, run_info in stats.stats.items():
        location = function_info[0]
        if location.startswith(project_dir_parent):
            location = location[prefix_length:].lstrip('/')
        label = '{0}:{1}({2})'.format(location, function_info[1], function_info[2])

        first_count = run_info[0]
        second_count = run_info[1]
        total_time = run_info[2]
        cumulative_time = run_info[3]

        # if the function is recursive write number of 'total calls/primitive calls'
        if first_count == second_count:
            calls = str(first_count)
        else:
            calls = '{1}/{0}'.format(first_count, second_count)

        if first_count > 0:
            total_per_call = total_time / first_count
            cumulative_per_call = cumulative_time / first_count
        else:
            total_per_call = 0
            cumulative_per_call = 0

        rows.append({
            'ncalls_recursion': calls,
            'ncalls': second_count,
            'tottime': total_time,
            'tottime_per': total_per_call,
            'cumtime': cumulative_time,
            'cumtime_per': cumulative_per_call,
            'function_name': label,
        })

    return rows
682