from __future__ import division
from __future__ import print_function

import argparse
import json
import netrc
import ntpath
import os
import platform
import re
import subprocess
import sys
import types
from datetime import datetime
from decimal import Decimal
from functools import partial
from os.path import basename
from os.path import dirname
from os.path import exists
from os.path import join
from os.path import split

import genericpath

from .compat import PY3
from .compat import PY38

# This is here (in the utils module) because it might be used by
# various other modules.
try:
    from pathlib2 import Path  # noqa: F401
except ImportError:
    from pathlib import Path  # noqa: F401

try:
    from urllib.parse import parse_qs
    from urllib.parse import urlparse
except ImportError:
    from urlparse import parse_qs
    from urlparse import urlparse

try:
    from subprocess import CalledProcessError
    from subprocess import check_output
except ImportError:
    # Backport of check_output/CalledProcessError for interpreters whose
    # subprocess module does not provide them.
    class CalledProcessError(subprocess.CalledProcessError):
        def __init__(self, returncode, cmd, output=None):
            super(CalledProcessError, self).__init__(returncode, cmd)
            self.output = output

    def check_output(*popenargs, **kwargs):
        if 'stdout' in kwargs:
            raise ValueError('stdout argument not allowed, it will be overridden.')
        process = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs)
        output, unused_err = process.communicate()
        retcode = process.poll()
        if retcode:
            cmd = kwargs.get("args")
            if cmd is None:
                cmd = popenargs[0]
            raise CalledProcessError(retcode, cmd, output)
        return output

# Display names for the supported time-unit suffixes.
TIME_UNITS = {
    "": "Seconds",
    "m": "Milliseconds (ms)",
    "u": "Microseconds (us)",
    "n": "Nanoseconds (ns)"
}
# Column names accepted by --benchmark-columns.
ALLOWED_COLUMNS = ["min", "max", "mean", "stddev", "median", "iqr", "ops", "outliers", "rounds", "iterations"]


class SecondsDecimal(Decimal):
    """A Decimal amount of seconds that renders itself with a time-unit suffix."""

    def __float__(self):
        return float(super(SecondsDecimal, self).__str__())

    def __str__(self):
        # Human-friendly rendering, e.g. "1.50ms".
        return "{0}s".format(format_time(float(super(SecondsDecimal, self).__str__())))

    @property
    def as_string(self):
        # The plain Decimal string, without any unit suffix.
        return super(SecondsDecimal, self).__str__()


class NameWrapper(object):
    """Wraps an object so that str() yields its dotted name (module.name)."""

    def __init__(self, target):
        self.target = target

    def __str__(self):
        prefix = self.target.__module__ + "." if hasattr(self.target, '__module__') else ""
        if hasattr(self.target, '__name__'):
            return prefix + self.target.__name__
        return prefix + repr(self.target)

    def __repr__(self):
        return "NameWrapper(%s)" % repr(self.target)


def get_tag(project_name=None):
    """Build a save tag from the commit id, the current time and the dirty state."""
    info = get_commit_info(project_name)
    pieces = [info['id'], get_current_time()]
    if info['dirty']:
        pieces.append("uncommited-changes")
    return "_".join(pieces)


def get_machine_id():
    """Identify the host: OS, interpreter, python major.minor and architecture."""
    return "%s-%s-%s-%s" % (
        platform.system(),
        platform.python_implementation(),
        ".".join(platform.python_version_tuple()[:2]),
        platform.architecture()[0]
    )


class Fallback(object):
    """Call registered functions in order; fall back when all fail or return falsy.

    A registered function is skipped when it raises one of ``exceptions`` or
    returns a falsy value.
    """

    def __init__(self, fallback, exceptions):
        self.fallback = fallback
        self.functions = []
        self.exceptions = exceptions

    def __call__(self, *args, **kwargs):
        for candidate in self.functions:
            try:
                result = candidate(*args, **kwargs)
            except self.exceptions:
                continue
            if result:
                return result
        return self.fallback(*args, **kwargs)

    def register(self, other):
        """Append *other* to the list of candidates; usable as a decorator."""
        self.functions.append(other)
        return self


@partial(Fallback, exceptions=(IndexError, CalledProcessError, OSError))
def get_project_name():
    """Last resort: use the name of the current working directory."""
    return basename(os.getcwd())


@get_project_name.register
def get_project_name_git():
    """Derive the project name from the git remote origin URL."""
    is_git = check_output(['git', 'rev-parse', '--git-dir'], stderr=subprocess.STDOUT)
    if is_git:
        project_address = check_output(['git', 'config', '--local', 'remote.origin.url'])
        if isinstance(project_address, bytes) and str != bytes:
            project_address = project_address.decode()
        # Last non-empty component once separators and ".git" are stripped.
        project_name = [i for i in re.split(r'[/:\s\\]|\.git', project_address) if i][-1]
        return project_name.strip()


@get_project_name.register
def get_project_name_hg():
    """Derive the project name from the mercurial default path."""
    with open(os.devnull, 'w') as devnull:
        project_address = check_output(['hg', 'path', 'default'], stderr=devnull)
    project_address = project_address.decode()
    project_name = project_address.split("/")[-1]
    return project_name.strip()


def in_any_parent(name, path=None):
    """Return True if *name* exists in *path* (default: cwd) or any ancestor."""
    previous = None
    if not path:
        path = os.getcwd()
    # Stop at the filesystem root (dirname() becomes a fixed point there).
    while path and previous != path and not exists(join(path, name)):
        previous = path
        path = dirname(path)
    return exists(join(path, name))


def subprocess_output(cmd):
    """Run *cmd* (a whitespace-separated command string) and return stripped output."""
    return check_output(cmd.split(), stderr=subprocess.STDOUT, universal_newlines=True).strip()


def get_commit_info(project_name=None):
    """Collect VCS metadata (commit id, times, branch, dirty flag) for the cwd.

    Never raises: on any failure a dict with ``id='unknown'`` and an ``error``
    entry is returned instead.
    """
    dirty = False
    commit = 'unversioned'
    commit_time = None
    author_time = None
    project_name = project_name or get_project_name()
    branch = '(unknown)'
    try:
        if in_any_parent('.git'):
            desc = subprocess_output('git describe --dirty --always --long --abbrev=40')
            desc = desc.split('-')
            if desc[-1].strip() == 'dirty':
                dirty = True
                desc.pop()
            # git describe prefixes the abbreviated hash with "g".
            commit = desc[-1].strip('g')
            commit_time = subprocess_output('git show -s --pretty=format:"%cI"').strip('"')
            author_time = subprocess_output('git show -s --pretty=format:"%aI"').strip('"')
            branch = subprocess_output('git rev-parse --abbrev-ref HEAD')
            if branch == 'HEAD':
                branch = '(detached head)'
        elif in_any_parent('.hg'):
            desc = subprocess_output('hg id --id --debug')
            if desc[-1] == '+':
                dirty = True
            commit = desc.strip('+')
            commit_time = subprocess_output('hg tip --template "{date|rfc3339date}"').strip('"')
            branch = subprocess_output('hg branch')
        return {
            'id': commit,
            'time': commit_time,
            'author_time': author_time,
            'dirty': dirty,
            'project': project_name,
            'branch': branch,
        }
    except Exception as exc:
        return {
            'id': 'unknown',
            'time': None,
            'author_time': None,
            'dirty': dirty,
            'error': 'CalledProcessError({0.returncode}, {0.output!r})'.format(exc)
            if isinstance(exc, CalledProcessError) else repr(exc),
            'project': project_name,
            'branch': branch,
        }


def get_current_time():
    """Return the current UTC time as a compact "YYYYMMDD_HHMMSS" string."""
    return datetime.utcnow().strftime("%Y%m%d_%H%M%S")


def first_or_value(obj, value):
    """Return the single element of *obj* if it is non-empty, else *value*."""
    if obj:
        value, = obj
    return value


def short_filename(path, machine_id=None):
    """Shorten a pathlib path: drop a leading machine-id part and the extension."""
    pieces = []
    try:
        last = len(path.parts) - 1
    except AttributeError:
        # Not a pathlib-like object; just stringify it.
        return str(path)
    for pos, part in enumerate(path.parts):
        if not pos and part == machine_id:
            continue
        if pos == last:
            part = part.rsplit('.', 1)[0]
        pieces.append(part)
    return '/'.join(pieces)


def load_timer(string):
    """Resolve a dotted "module.attr" spec to a NameWrapper'd timer callable."""
    if "." not in string:
        raise argparse.ArgumentTypeError("Value for --benchmark-timer must be in dotted form. Eg: 'module.attr'.")
    mod, attr = string.rsplit(".", 1)
    if mod == 'pep418':
        # On Python 3 the PEP-418 clocks live in the stdlib time module.
        if PY3:
            import time
            return NameWrapper(getattr(time, attr))
        from . import pep418
        return NameWrapper(getattr(pep418, attr))
    __import__(mod)
    return NameWrapper(getattr(sys.modules[mod], attr))


class RegressionCheck(object):
    """Base class: compare a stats field against a threshold."""

    def __init__(self, field, threshold):
        self.field = field
        self.threshold = threshold

    def fails(self, current, compared):
        """Return a failure message when the computed value exceeds the threshold."""
        val = self.compute(current, compared)
        if val > self.threshold:
            return "Field %r has failed %s: %.9f > %.9f" % (
                self.field, self.__class__.__name__, val, self.threshold
            )


class PercentageRegressionCheck(RegressionCheck):
    """Fails when the field grew by more than N percent."""

    def compute(self, current, compared):
        val = compared[self.field]
        if not val:
            # Avoid division by zero: any growth from zero is infinite.
            return float("inf")
        return current[self.field] / val * 100 - 100


class DifferenceRegressionCheck(RegressionCheck):
    """Fails when the field grew by more than an absolute amount."""

    def compute(self, current, compared):
        return current[self.field] - compared[self.field]


def parse_compare_fail(string,
                       rex=re.compile(r'^(?P<field>min|max|mean|median|stddev|iqr):'
                                      r'((?P<percentage>[0-9]?[0-9])%|'
                                      r'(?P<difference>[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?))$')):
    """Parse "field:N%" or "field:N.N" into the matching RegressionCheck."""
    m = rex.match(string)
    if m:
        g = m.groupdict()
        if g['percentage']:
            return PercentageRegressionCheck(g['field'], int(g['percentage']))
        if g['difference']:
            return DifferenceRegressionCheck(g['field'], float(g['difference']))

    raise argparse.ArgumentTypeError("Could not parse value: %r." % string)


def parse_warmup(string):
    """Parse the --benchmark-warmup value; "auto" means "warm up on PyPy only"."""
    string = string.lower().strip()
    if string == "auto":
        return platform.python_implementation() == "PyPy"
    if string in ("off", "false", "no"):
        return False
    if string in ("on", "true", "yes", ""):
        return True
    raise argparse.ArgumentTypeError("Could not parse value: %r." % string)


def name_formatter_short(bench):
    """Benchmark display name: test name (minus "test_") plus 4-char source tag."""
    name = bench["name"]
    if bench["source"]:
        name = "%s (%.4s)" % (name, split(bench["source"])[-1])
    if name.startswith("test_"):
        name = name[5:]
    return name


def name_formatter_normal(bench):
    """Benchmark display name with the source path (last part truncated to 12)."""
    name = bench["name"]
    if bench["source"]:
        pieces = bench["source"].split('/')
        pieces[-1] = pieces[-1][:12]
        name = "%s (%s)" % (name, '/'.join(pieces))
    return name


def name_formatter_long(bench):
    """Full node id, with the source appended when available."""
    if bench["source"]:
        return "%(fullname)s (%(source)s)" % bench
    return bench["fullname"]


def name_formatter_trial(bench):
    """Only the 4-char source tag (or "????" when unsourced)."""
    if bench["source"]:
        return "%.4s" % split(bench["source"])[-1]
    return '????'


NAME_FORMATTERS = {
    "short": name_formatter_short,
    "normal": name_formatter_normal,
    "long": name_formatter_long,
    "trial": name_formatter_trial,
}


def parse_name_format(string):
    """Validate a --benchmark-name value against NAME_FORMATTERS."""
    string = string.lower().strip()
    if string in NAME_FORMATTERS:
        return string
    raise argparse.ArgumentTypeError("Could not parse value: %r." % string)


def parse_timer(string):
    """Resolve a timer spec and return its dotted-name string form."""
    return str(load_timer(string))


def parse_sort(string):
    """Validate the --benchmark-sort column name."""
    string = string.lower().strip()
    if string not in ("min", "max", "mean", "stddev", "name", "fullname"):
        raise argparse.ArgumentTypeError(
            "Unacceptable value: %r. "
            "Value for --benchmark-sort must be one of: 'min', 'max', 'mean', "
            "'stddev', 'name', 'fullname'." % string)
    return string


def parse_columns(string):
    """Parse a comma-separated column list, rejecting unknown names."""
    columns = [s.strip() for s in string.lower().split(',')]
    invalid = set(columns) - set(ALLOWED_COLUMNS)
    if invalid:
        # there are extra items in columns!
        msg = "Invalid column name(s): %s. " % ', '.join(invalid)
        msg += "The only valid column names are: %s" % ', '.join(ALLOWED_COLUMNS)
        raise argparse.ArgumentTypeError(msg)
    return columns


def parse_rounds(string):
    """Parse a positive integer round count."""
    try:
        value = int(string)
    except ValueError as exc:
        raise argparse.ArgumentTypeError(exc)
    if value < 1:
        raise argparse.ArgumentTypeError("Value for --benchmark-rounds must be at least 1.")
    return value


def parse_seconds(string):
    """Parse a decimal seconds value, returning its canonical string form."""
    try:
        return SecondsDecimal(string).as_string
    except Exception as exc:
        raise argparse.ArgumentTypeError("Invalid decimal value %r: %r" % (string, exc))


def parse_save(string):
    """Validate a save name: non-empty and free of path-hostile characters."""
    if not string:
        raise argparse.ArgumentTypeError("Can't be empty.")
    illegal = ''.join(c for c in r"\/:*?<>|" if c in string)
    if illegal:
        raise argparse.ArgumentTypeError("Must not contain any of these characters: /:*?<>|\\ (it has %r)" % illegal)
    return string


def _parse_hosts(storage_url, netrc_file):
    """Expand a comma-separated netloc into URLs, injecting netrc credentials."""
    # load creds from netrc file
    path = os.path.expanduser(netrc_file)
    creds = None
    if netrc_file and os.path.isfile(path):
        creds = netrc.netrc(path)

    # add creds to urls
    urls = []
    for netloc in storage_url.netloc.split(','):
        auth = ""
        if creds and '@' not in netloc:
            # Look up the host (sans port) in the netrc database.
            host = netloc.split(':').pop(0)
            res = creds.authenticators(host)
            if res:
                user, _, secret = res
                auth = "{user}:{secret}@".format(user=user, secret=secret)
        url = "{scheme}://{auth}{netloc}".format(scheme=storage_url.scheme,
                                                 netloc=netloc, auth=auth)
        urls.append(url)
    return urls


def parse_elasticsearch_storage(string, default_index="benchmark",
                                default_doctype="benchmark", netrc_file=''):
    """Split an elasticsearch URL into (hosts, index, doctype, project_name)."""
    storage_url = urlparse(string)
    hosts = _parse_hosts(storage_url, netrc_file)
    index = default_index
    doctype = default_doctype
    if storage_url.path and storage_url.path != "/":
        # Path components override the default index (and optionally doctype).
        splitted = storage_url.path.strip("/").split("/")
        index = splitted[0]
        if len(splitted) >= 2:
            doctype = splitted[1]
    query = parse_qs(storage_url.query)
    try:
        project_name = query["project_name"][0]
    except KeyError:
        project_name = get_project_name()
    return hosts, index, doctype, project_name


def load_storage(storage, **kwargs):
    """Instantiate the storage backend named by a file:// or elasticsearch+ URL."""
    if "://" not in storage:
        storage = "file://" + storage
    netrc_file = kwargs.pop('netrc')  # only used by elasticsearch storage
    if storage.startswith("file://"):
        from .storage.file import FileStorage
        return FileStorage(storage[len("file://"):], **kwargs)
    elif storage.startswith("elasticsearch+"):
        from .storage.elasticsearch import ElasticsearchStorage

        # TODO update benchmark_autosave
        args = parse_elasticsearch_storage(storage[len("elasticsearch+"):],
                                           netrc_file=netrc_file)
        return ElasticsearchStorage(*args, **kwargs)
    else:
        raise argparse.ArgumentTypeError("Storage must be in form of file://path or "
                                         "elasticsearch+http[s]://host1,host2/index/doctype")


def time_unit(value):
    """Pick the (suffix, multiplier) that best displays *value* seconds."""
    if value < 1e-6:
        return "n", 1e9
    elif value < 1e-3:
        return "u", 1e6
    elif value < 1:
        return "m", 1e3
    else:
        return "", 1.


def operations_unit(value):
    """Pick the (suffix, multiplier) that best displays an ops/sec *value*."""
    if value > 1e+6:
        return "M", 1e-6
    if value > 1e+3:
        return "K", 1e-3
    return "", 1.
def format_time(value):
    """Format *value* seconds with two decimals and the best-fitting unit suffix."""
    unit, adjustment = time_unit(value)
    return "{0:.2f}{1:s}".format(value * adjustment, unit)


class cached_property(object):
    """Descriptor that computes the value once and caches it on the instance."""

    def __init__(self, func):
        self.__doc__ = getattr(func, '__doc__')
        self.func = func

    def __get__(self, obj, cls):
        if obj is None:
            return self
        # Storing into __dict__ shadows the descriptor for subsequent access.
        value = obj.__dict__[self.func.__name__] = self.func(obj)
        return value


def funcname(f):
    """Best-effort name for *f*: unwraps functools.partial, falls back to str()."""
    target = f.func if isinstance(f, partial) else f
    try:
        return target.__name__
    except AttributeError:
        return str(f)


# from: https://bitbucket.org/antocuni/pypytools/src/tip/pypytools/util.py?at=default
def clonefunc(f):
    """Deep clone the given function to create a new one.

    By default, the PyPy JIT specializes the assembler based on f.__code__:
    clonefunc makes sure that you will get a new function with a **different**
    __code__, so that PyPy will produce independent assembler. This is useful
    e.g. for benchmarks and microbenchmarks, so you can make sure to compare
    apples to apples.

    Use it with caution: if abused, this might easily produce an explosion of
    produced assembler.
    """
    # Objects without a code object (builtins, etc.) are returned unchanged.
    if not hasattr(f, '__code__'):
        return f
    # Clone the code object; the CodeType argument order is version-dependent,
    # so the py3-only counts are spliced in after building the common list.
    co = f.__code__
    args = [co.co_argcount, co.co_nlocals, co.co_stacksize, co.co_flags, co.co_code,
            co.co_consts, co.co_names, co.co_varnames, co.co_filename, co.co_name,
            co.co_firstlineno, co.co_lnotab, co.co_freevars, co.co_cellvars]
    if PY38:
        args.insert(1, co.co_posonlyargcount)
    if PY3:
        args.insert(1, co.co_kwonlyargcount)
    co2 = types.CodeType(*args)
    # Rebuild the function around the fresh code object.
    f2 = types.FunctionType(co2, f.__globals__, f.__name__, f.__defaults__, f.__closure__)
    return f2


def format_dict(obj):
    """Render a dict deterministically: sorted keys, JSON-encoded values."""
    body = ", ".join("%s: %s" % (key, json.dumps(val) )
                     for key, val in sorted(obj.items()))
    return "{%s}" % body


class SafeJSONEncoder(json.JSONEncoder):
    """JSON encoder that stringifies anything json can't serialize natively."""

    def default(self, o):
        return "UNSERIALIZABLE[%r]" % o


def safe_dumps(obj, **kwargs):
    """json.dumps that never raises on unserializable values."""
    return json.dumps(obj, cls=SafeJSONEncoder, **kwargs)


def report_progress(iterable, terminal_reporter, format_string, **kwargs):
    """Yield (formatted-line, item) pairs, rewriting the terminal line each step."""
    total = len(iterable)

    def progress_reporting_wrapper():
        for pos, item in enumerate(iterable):
            string = format_string.format(pos=pos + 1, total=total, value=item, **kwargs)
            terminal_reporter.rewrite(string, black=True, bold=True)
            yield string, item

    return progress_reporting_wrapper()


def report_noprogress(iterable, *args, **kwargs):
    """Drop-in for report_progress that emits no terminal output."""
    for item in iterable:
        yield "", item


def report_online_progress(progress_reporter, tr, line):
    """Push a single *line* through the given progress reporter."""
    next(progress_reporter([line], tr, "{value}"))


def slugify(name):
    """Replace path-hostile characters (and spaces) with single underscores."""
    for ch in r"\/:*?<>| ":
        name = name.replace(ch, '_').replace('__', '_')
    return name


def commonpath(paths):
    """Given a sequence of path names, returns the longest common sub-path."""
    if not paths:
        raise ValueError('commonpath() arg is an empty sequence')

    # Work in bytes or str depending on the inputs (Windows-style separators).
    if isinstance(paths[0], bytes):
        sep, altsep, curdir = b'\\', b'/', b'.'
    else:
        sep, altsep, curdir = '\\', '/', '.'

    try:
        drivesplits = [ntpath.splitdrive(p.replace(altsep, sep).lower()) for p in paths]
        split_paths = [p.split(sep) for d, p in drivesplits]

        # All paths must agree on absoluteness; a mixed set unpacks to >1 value.
        try:
            isabs, = set(p[:1] == sep for d, p in drivesplits)
        except ValueError:
            raise ValueError("Can't mix absolute and relative paths")

        # Check that all drive letters or UNC paths match. The check is made only
        # now otherwise type errors for mixing strings and bytes would not be
        # caught.
        if len(set(d for d, p in drivesplits)) != 1:
            raise ValueError("Paths don't have the same drive")

        # The result keeps the original casing of the first path.
        drive, path = ntpath.splitdrive(paths[0].replace(altsep, sep))
        common = [c for c in path.split(sep) if c and c != curdir]

        split_paths = [[c for c in s if c and c != curdir] for s in split_paths]
        # Comparing only the lexicographic min and max suffices: anything common
        # to those two is common to every path in between.
        s1 = min(split_paths)
        s2 = max(split_paths)
        for i, c in enumerate(s1):
            if c != s2[i]:
                common = common[:i]
                break
        else:
            common = common[:len(s1)]

        prefix = drive + sep if isabs else drive
        return prefix + sep.join(common)
    except (TypeError, AttributeError):
        genericpath._check_arg_types('commonpath', *paths)
        raise


def get_cprofile_functions(stats):
    """
    Convert pstats structure to list of sorted dicts about each function.
    """
    result = []
    # this assumes that you run py.test from project root dir
    project_dir_parent = dirname(os.getcwd())

    for function_info, run_info in stats.stats.items():
        file_path = function_info[0]
        if file_path.startswith(project_dir_parent):
            file_path = file_path[len(project_dir_parent):].lstrip('/')
        function_name = '{0}:{1}({2})'.format(file_path, function_info[1], function_info[2])

        # if the function is recursive write number of 'total calls/primitive calls'
        if run_info[0] == run_info[1]:
            calls = str(run_info[0])
        else:
            calls = '{1}/{0}'.format(run_info[0], run_info[1])

        result.append(dict(ncalls_recursion=calls,
                           ncalls=run_info[1],
                           tottime=run_info[2],
                           tottime_per=run_info[2] / run_info[0] if run_info[0] > 0 else 0,
                           cumtime=run_info[3],
                           cumtime_per=run_info[3] / run_info[0] if run_info[0] > 0 else 0,
                           function_name=function_name))

    return result