1import base64 2import builtins 3import contextlib 4import copy 5import errno 6import getpass 7import glob 8import inspect 9import itertools 10import os 11import pdb 12import re 13import struct 14import subprocess 15import sys 16import time 17import traceback 18import urllib.parse 19import urllib.request 20import warnings 21from functools import lru_cache, wraps 22from numbers import Number as numeric_type 23from typing import Any, Callable, Type 24 25import matplotlib 26import numpy as np 27from more_itertools import always_iterable, collapse, first 28from packaging.version import parse as parse_version 29from tqdm import tqdm 30 31from yt.units import YTArray, YTQuantity 32from yt.utilities.exceptions import YTInvalidWidthError 33from yt.utilities.logger import ytLogger as mylog 34from yt.utilities.on_demand_imports import _requests as requests 35 36# Some functions for handling sequences and other types 37 38 39def is_sequence(obj): 40 """ 41 Grabbed from Python Cookbook / matplotlib.cbook. Returns true/false for 42 43 Parameters 44 ---------- 45 obj : iterable 46 """ 47 try: 48 len(obj) 49 return True 50 except TypeError: 51 return False 52 53 54def iter_fields(field_or_fields): 55 """ 56 Create an iterator for field names, specified as single strings or tuples(fname, 57 ftype) alike. 58 This can safely be used in places where we accept a single field or a list as input. 59 60 Parameters 61 ---------- 62 field_or_fields: str, tuple(str, str), or any iterable of the previous types. 63 64 Examples 65 -------- 66 67 >>> fields = ("gas", "density") 68 >>> for field in iter_fields(fields): 69 ... print(field) 70 density 71 72 >>> fields = ("gas", "density") 73 >>> for field in iter_fields(fields): 74 ... print(field) 75 ('gas', 'density') 76 77 >>> fields = [("gas", "density"), ("gas", "temperature"), ("index", "dx")] 78 >>> for field in iter_fields(fields): 79 ... print(field) 80 density 81 temperature 82 ('index', 'dx') 83 """ 84 return always_iterable(field_or_fields, base_type=(tuple, str, bytes)) 85 86 87def ensure_numpy_array(obj): 88 """ 89 This function ensures that *obj* is a numpy array. Typically used to 90 convert scalar, list or tuple argument passed to functions using Cython. 91 """ 92 if isinstance(obj, np.ndarray): 93 if obj.shape == (): 94 return np.array([obj]) 95 # We cast to ndarray to catch ndarray subclasses 96 return np.array(obj) 97 elif isinstance(obj, (list, tuple)): 98 return np.asarray(obj) 99 else: 100 return np.asarray([obj]) 101 102 103def read_struct(f, fmt): 104 """ 105 This reads a struct, and only that struct, from an open file. 106 """ 107 s = f.read(struct.calcsize(fmt)) 108 return struct.unpack(fmt, s) 109 110 111def just_one(obj): 112 # If we have an iterable, sometimes we only want one item 113 return first(collapse(obj)) 114 115 116def compare_dicts(dict1, dict2): 117 if not set(dict1) <= set(dict2): 118 return False 119 for key in dict1.keys(): 120 if dict1[key] is not None and dict2[key] is not None: 121 if isinstance(dict1[key], dict): 122 if compare_dicts(dict1[key], dict2[key]): 123 continue 124 else: 125 return False 126 try: 127 comparison = np.array_equal(dict1[key], dict2[key]) 128 except TypeError: 129 comparison = dict1[key] == dict2[key] 130 if not comparison: 131 return False 132 return True 133 134 135# Taken from 136# http://www.goldb.org/goldblog/2008/02/06/PythonConvertSecsIntoHumanReadableTimeStringHHMMSS.aspx 137def humanize_time(secs): 138 """ 139 Takes *secs* and returns a nicely formatted string 140 """ 141 mins, secs = divmod(secs, 60) 142 hours, mins = divmod(mins, 60) 143 return "%02d:%02d:%02d" % (hours, mins, secs) 144 145 146# 147# Some function wrappers that come in handy once in a while 148# 149 150# we use the resource module to get the memory page size 151 152try: 153 import resource 154except ImportError: 155 pass 156 157 158def get_memory_usage(subtract_share=False): 159 """ 160 Returning resident size in megabytes 161 """ 162 pid = os.getpid() 163 try: 164 pagesize = resource.getpagesize() 165 except NameError: 166 return -1024 167 status_file = f"/proc/{pid}/statm" 168 if not os.path.isfile(status_file): 169 return -1024 170 line = open(status_file).read() 171 size, resident, share, text, library, data, dt = (int(i) for i in line.split()) 172 if subtract_share: 173 resident -= share 174 return resident * pagesize / (1024 * 1024) # return in megs 175 176 177def time_execution(func): 178 r""" 179 Decorator for seeing how long a given function takes, depending on whether 180 or not the global 'yt.time_functions' config parameter is set. 181 """ 182 183 @wraps(func) 184 def wrapper(*arg, **kw): 185 t1 = time.time() 186 res = func(*arg, **kw) 187 t2 = time.time() 188 mylog.debug("%s took %0.3f s", func.__name__, (t2 - t1)) 189 return res 190 191 from yt.config import ytcfg 192 193 if ytcfg.get("yt", "time_functions"): 194 return wrapper 195 else: 196 return func 197 198 199def print_tb(func): 200 """ 201 This function is used as a decorate on a function to have the calling stack 202 printed whenever that function is entered. 203 204 This can be used like so: 205 206 >>> @print_tb 207 ... def some_deeply_nested_function(*args, **kwargs): 208 ... ... 209 210 """ 211 212 @wraps(func) 213 def run_func(*args, **kwargs): 214 traceback.print_stack() 215 return func(*args, **kwargs) 216 217 return run_func 218 219 220def rootonly(func): 221 """ 222 This is a decorator that, when used, will only call the function on the 223 root processor and then broadcast the results of the function to all other 224 processors. 225 226 This can be used like so: 227 228 .. code-block:: python 229 230 @rootonly 231 def some_root_only_function(*args, **kwargs): 232 ... 233 """ 234 from yt.config import ytcfg 235 236 @wraps(func) 237 def check_parallel_rank(*args, **kwargs): 238 if ytcfg.get("yt", "internals", "topcomm_parallel_rank") > 0: 239 return 240 return func(*args, **kwargs) 241 242 return check_parallel_rank 243 244 245def pdb_run(func): 246 """ 247 This decorator inserts a pdb session on top of the call-stack into a 248 function. 249 250 This can be used like so: 251 252 >>> @pdb_run 253 ... def some_function_to_debug(*args, **kwargs): 254 ... ... 255 256 """ 257 258 @wraps(func) 259 def wrapper(*args, **kw): 260 pdb.runcall(func, *args, **kw) 261 262 return wrapper 263 264 265__header = """ 266== Welcome to the embedded IPython Shell == 267 268 You are currently inside the function: 269 %(fname)s 270 271 Defined in: 272 %(filename)s:%(lineno)s 273""" 274 275 276def insert_ipython(num_up=1): 277 """ 278 Placed inside a function, this will insert an IPython interpreter at that 279 current location. This will enabled detailed inspection of the current 280 execution environment, as well as (optional) modification of that environment. 281 *num_up* refers to how many frames of the stack get stripped off, and 282 defaults to 1 so that this function itself is stripped off. 283 """ 284 import IPython 285 from IPython.terminal.embed import InteractiveShellEmbed 286 287 try: 288 from traitlets.config.loader import Config 289 except ImportError: 290 from IPython.config.loader import Config 291 292 frame = inspect.stack()[num_up] 293 loc = frame[0].f_locals.copy() 294 glo = frame[0].f_globals 295 dd = dict(fname=frame[3], filename=frame[1], lineno=frame[2]) 296 cfg = Config() 297 cfg.InteractiveShellEmbed.local_ns = loc 298 cfg.InteractiveShellEmbed.global_ns = glo 299 IPython.embed(config=cfg, banner2=__header % dd) 300 ipshell = InteractiveShellEmbed(config=cfg) 301 302 del ipshell 303 304 305# 306# Our progress bar types and how to get one 307# 308 309 310class TqdmProgressBar: 311 # This is a drop in replacement for pbar 312 # called tqdm 313 def __init__(self, title, maxval): 314 self._pbar = tqdm(leave=True, total=maxval, desc=title) 315 self.i = 0 316 317 def update(self, i=None): 318 if i is None: 319 i = self.i + 1 320 n = i - self.i 321 self.i = i 322 self._pbar.update(n) 323 324 def finish(self): 325 self._pbar.close() 326 327 328class DummyProgressBar: 329 # This progressbar gets handed if we don't 330 # want ANY output 331 def __init__(self, *args, **kwargs): 332 return 333 334 def update(self, *args, **kwargs): 335 return 336 337 def finish(self, *args, **kwargs): 338 return 339 340 341def get_pbar(title, maxval): 342 """ 343 This returns a progressbar of the most appropriate type, given a *title* 344 and a *maxval*. 345 """ 346 maxval = max(maxval, 1) 347 from yt.config import ytcfg 348 349 if ( 350 ytcfg.get("yt", "suppress_stream_logging") 351 or ytcfg.get("yt", "internals", "within_testing") 352 or maxval == 1 353 or not is_root() 354 ): 355 return DummyProgressBar() 356 return TqdmProgressBar(title, maxval) 357 358 359def only_on_root(func, *args, **kwargs): 360 """ 361 This function accepts a *func*, a set of *args* and *kwargs* and then only 362 on the root processor calls the function. All other processors get "None" 363 handed back. 364 """ 365 from yt.config import ytcfg 366 367 if kwargs.pop("global_rootonly", False): 368 cfg_option = "global_parallel_rank" 369 else: 370 cfg_option = "topcomm_parallel_rank" 371 if not ytcfg.get("yt", "internals", "parallel"): 372 return func(*args, **kwargs) 373 if ytcfg.get("yt", "internals", cfg_option) > 0: 374 return 375 return func(*args, **kwargs) 376 377 378def is_root(): 379 """ 380 This function returns True if it is on the root processor of the 381 topcomm and False otherwise. 382 """ 383 from yt.config import ytcfg 384 385 if not ytcfg.get("yt", "internals", "parallel"): 386 return True 387 return ytcfg.get("yt", "internals", "topcomm_parallel_rank") == 0 388 389 390# 391# Our signal and traceback handling functions 392# 393 394 395def signal_print_traceback(signo, frame): 396 print(traceback.print_stack(frame)) 397 398 399def signal_problem(signo, frame): 400 raise RuntimeError() 401 402 403def signal_ipython(signo, frame): 404 insert_ipython(2) 405 406 407def paste_traceback(exc_type, exc, tb): 408 """ 409 This is a traceback handler that knows how to paste to the pastebin. 410 Should only be used in sys.excepthook. 411 """ 412 sys.__excepthook__(exc_type, exc, tb) 413 import xmlrpc.client 414 from io import StringIO 415 416 p = xmlrpc.client.ServerProxy( 417 "http://paste.yt-project.org/xmlrpc/", allow_none=True 418 ) 419 s = StringIO() 420 traceback.print_exception(exc_type, exc, tb, file=s) 421 s = s.getvalue() 422 ret = p.pastes.newPaste("pytb", s, None, "", "", True) 423 print() 424 print(f"Traceback pasted to http://paste.yt-project.org/show/{ret}") 425 print() 426 427 428def paste_traceback_detailed(exc_type, exc, tb): 429 """ 430 This is a traceback handler that knows how to paste to the pastebin. 431 Should only be used in sys.excepthook. 432 """ 433 import cgitb 434 import xmlrpc.client 435 from io import StringIO 436 437 s = StringIO() 438 handler = cgitb.Hook(format="text", file=s) 439 handler(exc_type, exc, tb) 440 s = s.getvalue() 441 print(s) 442 p = xmlrpc.client.ServerProxy( 443 "http://paste.yt-project.org/xmlrpc/", allow_none=True 444 ) 445 ret = p.pastes.newPaste("text", s, None, "", "", True) 446 print() 447 print(f"Traceback pasted to http://paste.yt-project.org/show/{ret}") 448 print() 449 450 451_ss = "fURbBUUBE0cLXgETJnZgJRMXVhVGUQpQAUBuehQMUhJWRFFRAV1ERAtBXw1dAxMLXT4zXBFfABNN\nC0ZEXw1YUURHCxMXVlFERwxWCQw=\n" 452 453 454def _rdbeta(key): 455 enc_s = base64.decodestring(_ss) 456 dec_s = "".join(chr(ord(a) ^ ord(b)) for a, b in zip(enc_s, itertools.cycle(key))) 457 print(dec_s) 458 459 460# 461# Some exceptions 462# 463 464 465class NoCUDAException(Exception): 466 pass 467 468 469class YTEmptyClass: 470 pass 471 472 473def update_git(path): 474 try: 475 import git 476 except ImportError: 477 print("Updating and precise version information requires ") 478 print("gitpython to be installed.") 479 print("Try: python -m pip install gitpython") 480 return -1 481 with open(os.path.join(path, "yt_updater.log"), "a") as f: 482 repo = git.Repo(path) 483 if repo.is_dirty(untracked_files=True): 484 print("Changes have been made to the yt source code so I won't ") 485 print("update the code. You will have to do this yourself.") 486 print("Here's a set of sample commands:") 487 print("") 488 print(f" $ cd {path}") 489 print(" $ git stash") 490 print(" $ git checkout main") 491 print(" $ git pull") 492 print(" $ git stash pop") 493 print(f" $ {sys.executable} setup.py develop") 494 print("") 495 return 1 496 if repo.active_branch.name != "main": 497 print("yt repository is not tracking the main branch so I won't ") 498 print("update the code. You will have to do this yourself.") 499 print("Here's a set of sample commands:") 500 print("") 501 print(f" $ cd {path}") 502 print(" $ git checkout main") 503 print(" $ git pull") 504 print(f" $ {sys.executable} setup.py develop") 505 print("") 506 return 1 507 print("Updating the repository") 508 f.write("Updating the repository\n\n") 509 old_version = repo.git.rev_parse("HEAD", short=12) 510 try: 511 remote = repo.remotes.yt_upstream 512 except AttributeError: 513 remote = repo.create_remote( 514 "yt_upstream", url="https://github.com/yt-project/yt" 515 ) 516 remote.fetch() 517 main = repo.heads.main 518 main.set_tracking_branch(remote.refs.main) 519 main.checkout() 520 remote.pull() 521 new_version = repo.git.rev_parse("HEAD", short=12) 522 f.write(f"Updated from {old_version} to {new_version}\n\n") 523 rebuild_modules(path, f) 524 print("Updated successfully") 525 526 527def rebuild_modules(path, f): 528 f.write("Rebuilding modules\n\n") 529 p = subprocess.Popen( 530 [sys.executable, "setup.py", "build_ext", "-i"], 531 cwd=path, 532 stdout=subprocess.PIPE, 533 stderr=subprocess.STDOUT, 534 ) 535 stdout, stderr = p.communicate() 536 f.write(stdout.decode("utf-8")) 537 f.write("\n\n") 538 if p.returncode: 539 print(f"BROKEN: See {os.path.join(path, 'yt_updater.log')}") 540 sys.exit(1) 541 f.write("Successful!\n") 542 543 544def get_git_version(path): 545 try: 546 import git 547 except ImportError: 548 print("Updating and precise version information requires ") 549 print("gitpython to be installed.") 550 print("Try: python -m pip install gitpython") 551 return None 552 try: 553 repo = git.Repo(path) 554 return repo.git.rev_parse("HEAD", short=12) 555 except git.InvalidGitRepositoryError: 556 # path is not a git repository 557 return None 558 559 560def get_yt_version(): 561 import pkg_resources 562 563 yt_provider = pkg_resources.get_provider("yt") 564 path = os.path.dirname(yt_provider.module_path) 565 version = get_git_version(path) 566 if version is None: 567 return version 568 else: 569 v_str = version[:12].strip() 570 if hasattr(v_str, "decode"): 571 v_str = v_str.decode("utf-8") 572 return v_str 573 574 575def get_version_stack(): 576 version_info = {} 577 version_info["yt"] = get_yt_version() 578 version_info["numpy"] = np.version.version 579 version_info["matplotlib"] = matplotlib.__version__ 580 return version_info 581 582 583def get_script_contents(): 584 top_frame = inspect.stack()[-1] 585 finfo = inspect.getframeinfo(top_frame[0]) 586 if finfo[2] != "<module>": 587 return None 588 if not os.path.exists(finfo[0]): 589 return None 590 try: 591 contents = open(finfo[0]).read() 592 except Exception: 593 contents = None 594 return contents 595 596 597def download_file(url, filename): 598 try: 599 return fancy_download_file(url, filename, requests) 600 except ImportError: 601 # fancy_download_file requires requests 602 return simple_download_file(url, filename) 603 604 605def fancy_download_file(url, filename, requests=None): 606 response = requests.get(url, stream=True) 607 total_length = response.headers.get("content-length") 608 609 with open(filename, "wb") as fh: 610 if total_length is None: 611 fh.write(response.content) 612 else: 613 blocksize = 4 * 1024 ** 2 614 iterations = int(float(total_length) / float(blocksize)) 615 616 pbar = get_pbar( 617 "Downloading %s to %s " % os.path.split(filename)[::-1], iterations 618 ) 619 iteration = 0 620 for chunk in response.iter_content(chunk_size=blocksize): 621 fh.write(chunk) 622 iteration += 1 623 pbar.update(iteration) 624 pbar.finish() 625 return filename 626 627 628def simple_download_file(url, filename): 629 class MyURLopener(urllib.request.FancyURLopener): 630 def http_error_default(self, url, fp, errcode, errmsg, headers): 631 raise RuntimeError( 632 "Attempt to download file from %s failed with error %s: %s." 633 % (url, errcode, errmsg) 634 ) 635 636 fn, h = MyURLopener().retrieve(url, filename) 637 return fn 638 639 640# This code snippet is modified from Georg Brandl 641def bb_apicall(endpoint, data, use_pass=True): 642 uri = f"https://api.bitbucket.org/1.0/{endpoint}/" 643 # since bitbucket doesn't return the required WWW-Authenticate header when 644 # making a request without Authorization, we cannot use the standard urllib2 645 # auth handlers; we have to add the requisite header from the start 646 if data is not None: 647 data = urllib.parse.urlencode(data) 648 req = urllib.request.Request(uri, data) 649 if use_pass: 650 username = input("Bitbucket Username? ") 651 password = getpass.getpass() 652 upw = f"{username}:{password}" 653 req.add_header("Authorization", f"Basic {base64.b64encode(upw).strip()}") 654 return urllib.request.urlopen(req).read() 655 656 657def fix_length(length, ds): 658 registry = ds.unit_registry 659 if isinstance(length, YTArray): 660 if registry is not None: 661 length.units.registry = registry 662 return length.in_units("code_length") 663 if isinstance(length, numeric_type): 664 return YTArray(length, "code_length", registry=registry) 665 length_valid_tuple = isinstance(length, (list, tuple)) and len(length) == 2 666 unit_is_string = isinstance(length[1], str) 667 length_is_number = isinstance(length[0], numeric_type) and not isinstance( 668 length[0], YTArray 669 ) 670 if length_valid_tuple and unit_is_string and length_is_number: 671 return YTArray(*length, registry=registry) 672 else: 673 raise RuntimeError(f"Length {str(length)} is invalid") 674 675 676@contextlib.contextmanager 677def parallel_profile(prefix): 678 r"""A context manager for profiling parallel code execution using cProfile 679 680 This is a simple context manager that automatically profiles the execution 681 of a snippet of code. 682 683 Parameters 684 ---------- 685 prefix : string 686 A string name to prefix outputs with. 687 688 Examples 689 -------- 690 691 >>> from yt import PhasePlot 692 >>> from yt.testing import fake_random_ds 693 >>> fields = ("density", "temperature", "cell_mass") 694 >>> units = ("g/cm**3", "K", "g") 695 >>> ds = fake_random_ds(16, fields=fields, units=units) 696 >>> with parallel_profile("my_profile"): 697 ... plot = PhasePlot(ds.all_data(), *fields) 698 """ 699 import cProfile 700 701 from yt.config import ytcfg 702 703 fn = "%s_%04i_%04i.cprof" % ( 704 prefix, 705 ytcfg.get("yt", "internals", "topcomm_parallel_size"), 706 ytcfg.get("yt", "internals", "topcomm_parallel_rank"), 707 ) 708 p = cProfile.Profile() 709 p.enable() 710 yield fn 711 p.disable() 712 p.dump_stats(fn) 713 714 715def get_num_threads(): 716 from .config import ytcfg 717 718 nt = ytcfg.get("yt", "num_threads") 719 if nt < 0: 720 return os.environ.get("OMP_NUM_THREADS", 0) 721 return nt 722 723 724def fix_axis(axis, ds): 725 return ds.coordinates.axis_id.get(axis, axis) 726 727 728def get_output_filename(name, keyword, suffix): 729 r"""Return an appropriate filename for output. 730 731 With a name provided by the user, this will decide how to appropriately name the 732 output file by the following rules: 733 734 1. if name is None, the filename will be the keyword plus the suffix. 735 2. if name ends with "/" (resp "\" on Windows), assume name is a directory and the 736 file will be named name/(keyword+suffix). If the directory does not exist, first 737 try to create it and raise an exception if an error occurs. 738 3. if name does not end in the suffix, add the suffix. 739 740 Parameters 741 ---------- 742 name : str 743 A filename given by the user. 744 keyword : str 745 A default filename prefix if name is None. 746 suffix : str 747 Suffix that must appear at end of the filename. 748 This will be added if not present. 749 750 Examples 751 -------- 752 753 >>> get_output_filename(None, "Projection_x", ".png") 754 'Projection_x.png' 755 >>> get_output_filename("my_file", "Projection_x", ".png") 756 'my_file.png' 757 >>> get_output_filename("my_dir/", "Projection_x", ".png") 758 'my_dir/Projection_x.png' 759 760 """ 761 if name is None: 762 name = keyword 763 name = os.path.expanduser(name) 764 if name.endswith(os.sep) and not os.path.isdir(name): 765 ensure_dir(name) 766 if os.path.isdir(name): 767 name = os.path.join(name, keyword) 768 if not name.endswith(suffix): 769 name += suffix 770 return name 771 772 773def ensure_dir_exists(path): 774 r"""Create all directories in path recursively in a parallel safe manner""" 775 my_dir = os.path.dirname(path) 776 # If path is a file in the current directory, like "test.txt", then my_dir 777 # would be an empty string, resulting in FileNotFoundError when passed to 778 # ensure_dir. Let's avoid that. 779 if my_dir: 780 ensure_dir(my_dir) 781 782 783def ensure_dir(path): 784 r"""Parallel safe directory maker.""" 785 if os.path.exists(path): 786 return path 787 788 try: 789 os.makedirs(path) 790 except OSError as e: 791 if e.errno == errno.EEXIST: 792 pass 793 else: 794 raise 795 return path 796 797 798def validate_width_tuple(width): 799 if not is_sequence(width) or len(width) != 2: 800 raise YTInvalidWidthError(f"width ({width}) is not a two element tuple") 801 is_numeric = isinstance(width[0], numeric_type) 802 length_has_units = isinstance(width[0], YTArray) 803 unit_is_string = isinstance(width[1], str) 804 if not is_numeric or length_has_units and unit_is_string: 805 msg = f"width ({str(width)}) is invalid. " 806 msg += "Valid widths look like this: (12, 'au')" 807 raise YTInvalidWidthError(msg) 808 809 810_first_cap_re = re.compile("(.)([A-Z][a-z]+)") 811_all_cap_re = re.compile("([a-z0-9])([A-Z])") 812 813 814@lru_cache(maxsize=128, typed=False) 815def camelcase_to_underscore(name): 816 s1 = _first_cap_re.sub(r"\1_\2", name) 817 return _all_cap_re.sub(r"\1_\2", s1).lower() 818 819 820def set_intersection(some_list): 821 if len(some_list) == 0: 822 return set() 823 # This accepts a list of iterables, which we get the intersection of. 824 s = set(some_list[0]) 825 for l in some_list[1:]: 826 s.intersection_update(l) 827 return s 828 829 830@contextlib.contextmanager 831def memory_checker(interval=15, dest=None): 832 r"""This is a context manager that monitors memory usage. 833 834 Parameters 835 ---------- 836 interval : int 837 The number of seconds between printing the current memory usage in 838 gigabytes of the current Python interpreter. 839 840 Examples 841 -------- 842 843 >>> with memory_checker(10): 844 ... arr = np.zeros(1024 * 1024 * 1024, dtype="float64") 845 ... time.sleep(15) 846 ... del arr 847 MEMORY: -1.000e+00 gb 848 """ 849 import threading 850 851 if dest is None: 852 dest = sys.stdout 853 854 class MemoryChecker(threading.Thread): 855 def __init__(self, event, interval): 856 self.event = event 857 self.interval = interval 858 threading.Thread.__init__(self) 859 860 def run(self): 861 while not self.event.wait(self.interval): 862 print(f"MEMORY: {get_memory_usage() / 1024.0:0.3e} gb", file=dest) 863 864 e = threading.Event() 865 mem_check = MemoryChecker(e, interval) 866 mem_check.start() 867 try: 868 yield 869 finally: 870 e.set() 871 872 873def enable_plugins(plugin_filename=None): 874 """Forces a plugin file to be parsed. 875 876 A plugin file is a means of creating custom fields, quantities, 877 data objects, colormaps, and other code classes and objects to be used 878 in yt scripts without modifying the yt source directly. 879 880 If ``plugin_filename`` is omitted, this function will look for a plugin file at 881 ``$HOME/.config/yt/my_plugins.py``, which is the prefered behaviour for a 882 system-level configuration. 883 884 Warning: a script using this function will only be reproducible if your plugin 885 file is shared with it. 886 """ 887 import yt 888 from yt.config import config_dir, old_config_dir, ytcfg 889 from yt.fields.my_plugin_fields import my_plugins_fields 890 891 if plugin_filename is not None: 892 _fn = plugin_filename 893 if not os.path.isfile(_fn): 894 raise FileNotFoundError(_fn) 895 else: 896 # Determine global plugin location. By decreasing priority order: 897 # - absolute path 898 # - CONFIG_DIR 899 # - obsolete config dir. 900 my_plugin_name = ytcfg.get("yt", "plugin_filename") 901 for base_prefix in ("", config_dir(), old_config_dir()): 902 if os.path.isfile(os.path.join(base_prefix, my_plugin_name)): 903 _fn = os.path.join(base_prefix, my_plugin_name) 904 break 905 else: 906 raise FileNotFoundError("Could not find a global system plugin file.") 907 908 if _fn.startswith(old_config_dir()): 909 mylog.warning( 910 "Your plugin file is located in a deprecated directory. " 911 "Please move it from %s to %s", 912 os.path.join(old_config_dir(), my_plugin_name), 913 os.path.join(config_dir(), my_plugin_name), 914 ) 915 916 mylog.info("Loading plugins from %s", _fn) 917 ytdict = yt.__dict__ 918 execdict = ytdict.copy() 919 execdict["add_field"] = my_plugins_fields.add_field 920 with open(_fn) as f: 921 code = compile(f.read(), _fn, "exec") 922 exec(code, execdict, execdict) 923 ytnamespace = list(ytdict.keys()) 924 for k in execdict.keys(): 925 if k not in ytnamespace: 926 if callable(execdict[k]): 927 setattr(yt, k, execdict[k]) 928 929 930def subchunk_count(n_total, chunk_size): 931 handled = 0 932 while handled < n_total: 933 tr = min(n_total - handled, chunk_size) 934 yield tr 935 handled += tr 936 937 938def fix_unitary(u): 939 if u == "1": 940 return "unitary" 941 else: 942 return u 943 944 945def get_hash(infile, algorithm="md5", BLOCKSIZE=65536): 946 """Generate file hash without reading in the entire file at once. 947 948 Original code licensed under MIT. Source: 949 https://www.pythoncentral.io/hashing-files-with-python/ 950 951 Parameters 952 ---------- 953 infile : str 954 File of interest (including the path). 955 algorithm : str (optional) 956 Hash algorithm of choice. Defaults to 'md5'. 957 BLOCKSIZE : int (optional) 958 How much data in bytes to read in at once. 959 960 Returns 961 ------- 962 hash : str 963 The hash of the file. 964 965 Examples 966 -------- 967 >>> from tempfile import NamedTemporaryFile 968 >>> with NamedTemporaryFile() as file: 969 ... get_hash(file.name) 970 'd41d8cd98f00b204e9800998ecf8427e' 971 """ 972 import hashlib 973 974 try: 975 hasher = getattr(hashlib, algorithm)() 976 except AttributeError as e: 977 raise NotImplementedError( 978 f"'{algorithm}' not available! Available algorithms: {hashlib.algorithms}" 979 ) from e 980 981 filesize = os.path.getsize(infile) 982 iterations = int(float(filesize) / float(BLOCKSIZE)) 983 984 pbar = get_pbar(f"Generating {algorithm} hash", iterations) 985 986 iter = 0 987 with open(infile, "rb") as f: 988 buf = f.read(BLOCKSIZE) 989 while len(buf) > 0: 990 hasher.update(buf) 991 buf = f.read(BLOCKSIZE) 992 iter += 1 993 pbar.update(iter) 994 pbar.finish() 995 996 return hasher.hexdigest() 997 998 999def get_brewer_cmap(cmap): 1000 """Returns a colorbrewer colormap from palettable""" 1001 try: 1002 import brewer2mpl 1003 except ImportError: 1004 brewer2mpl = None 1005 try: 1006 import palettable 1007 except ImportError: 1008 palettable = None 1009 if palettable is not None: 1010 bmap = palettable.colorbrewer.get_map(*cmap) 1011 elif brewer2mpl is not None: 1012 warnings.warn( 1013 "Using brewer2mpl colormaps is deprecated. " 1014 "Please install the successor to brewer2mpl, " 1015 "palettable, with `pip install palettable`. " 1016 "Colormap tuple names remain unchanged." 1017 ) 1018 bmap = brewer2mpl.get_map(*cmap) 1019 else: 1020 raise RuntimeError("Please install palettable to use colorbrewer colormaps") 1021 return bmap.get_mpl_colormap(N=cmap[2]) 1022 1023 1024@contextlib.contextmanager 1025def dummy_context_manager(*args, **kwargs): 1026 yield 1027 1028 1029def matplotlib_style_context(style_name=None, after_reset=False): 1030 """Returns a context manager for controlling matplotlib style. 1031 1032 Arguments are passed to matplotlib.style.context() if specified. Defaults 1033 to setting "classic" style, after resetting to the default config parameters. 1034 1035 On older matplotlib versions (<=1.5.0) where matplotlib.style isn't 1036 available, returns a dummy context manager. 1037 """ 1038 if style_name is None: 1039 import matplotlib 1040 1041 style_name = {"mathtext.fontset": "cm"} 1042 if parse_version(matplotlib.__version__) >= parse_version("3.3.0"): 1043 style_name["mathtext.fallback"] = "cm" 1044 else: 1045 style_name["mathtext.fallback_to_cm"] = True 1046 try: 1047 import matplotlib.style 1048 1049 return matplotlib.style.context(style_name, after_reset=after_reset) 1050 except ImportError: 1051 pass 1052 return dummy_context_manager() 1053 1054 1055interactivity = False 1056 1057"""Sets the condition that interactive backends can be used.""" 1058 1059 1060def toggle_interactivity(): 1061 global interactivity 1062 interactivity = not interactivity 1063 if interactivity: 1064 if "__IPYTHON__" in dir(builtins): 1065 import IPython 1066 1067 shell = IPython.get_ipython() 1068 shell.magic("matplotlib") 1069 else: 1070 import matplotlib 1071 1072 matplotlib.interactive(True) 1073 1074 1075def get_interactivity(): 1076 return interactivity 1077 1078 1079def setdefaultattr(obj, name, value): 1080 """Set attribute with *name* on *obj* with *value* if it doesn't exist yet 1081 1082 Analogous to dict.setdefault 1083 """ 1084 if not hasattr(obj, name): 1085 setattr(obj, name, value) 1086 return getattr(obj, name) 1087 1088 1089def parse_h5_attr(f, attr): 1090 """A Python3-safe function for getting hdf5 attributes. 1091 1092 If an attribute is supposed to be a string, this will return it as such. 1093 """ 1094 val = f.attrs.get(attr, None) 1095 if isinstance(val, bytes): 1096 return val.decode("utf8") 1097 else: 1098 return val 1099 1100 1101def obj_length(v): 1102 if is_sequence(v): 1103 return len(v) 1104 else: 1105 # If something isn't iterable, we return 0 1106 # to signify zero length (aka a scalar). 1107 return 0 1108 1109 1110def array_like_field(data, x, field): 1111 field = data._determine_fields(field)[0] 1112 if isinstance(field, tuple): 1113 finfo = data.ds._get_field_info(field[0], field[1]) 1114 else: 1115 finfo = data.ds._get_field_info(field) 1116 if finfo.sampling_type == "particle": 1117 units = finfo.output_units 1118 else: 1119 units = finfo.units 1120 if isinstance(x, YTArray): 1121 arr = copy.deepcopy(x) 1122 arr.convert_to_units(units) 1123 return arr 1124 if isinstance(x, np.ndarray): 1125 return data.ds.arr(x, units) 1126 else: 1127 return data.ds.quan(x, units) 1128 1129 1130def validate_3d_array(obj): 1131 if not is_sequence(obj) or len(obj) != 3: 1132 raise TypeError( 1133 "Expected an array of size (3,), received '%s' of " 1134 "length %s" % (str(type(obj)).split("'")[1], len(obj)) 1135 ) 1136 1137 1138def validate_float(obj): 1139 """Validates if the passed argument is a float value. 1140 1141 Raises an exception if `obj` is a single float value 1142 or a YTQuantity of size 1. 1143 1144 Parameters 1145 ---------- 1146 obj : Any 1147 Any argument which needs to be checked for a single float value. 1148 1149 Raises 1150 ------ 1151 TypeError 1152 Raised if `obj` is not a single float value or YTQunatity 1153 1154 Examples 1155 -------- 1156 >>> validate_float(1) 1157 >>> validate_float(1.50) 1158 >>> validate_float(YTQuantity(1, "cm")) 1159 >>> validate_float((1, "cm")) 1160 >>> validate_float([1, 1, 1]) 1161 Traceback (most recent call last): 1162 ... 1163 TypeError: Expected a numeric value (or size-1 array), received 'list' of length 3 1164 1165 >>> validate_float([YTQuantity(1, "cm"), YTQuantity(2, "cm")]) 1166 Traceback (most recent call last): 1167 ... 1168 TypeError: Expected a numeric value (or size-1 array), received 'list' of length 2 1169 """ 1170 if isinstance(obj, tuple): 1171 if ( 1172 len(obj) != 2 1173 or not isinstance(obj[0], numeric_type) 1174 or not isinstance(obj[1], str) 1175 ): 1176 raise TypeError( 1177 "Expected a numeric value (or tuple of format " 1178 "(float, String)), received an inconsistent tuple " 1179 "'%s'." % str(obj) 1180 ) 1181 else: 1182 return 1183 if is_sequence(obj) and (len(obj) != 1 or not isinstance(obj[0], numeric_type)): 1184 raise TypeError( 1185 "Expected a numeric value (or size-1 array), " 1186 "received '%s' of length %s" % (str(type(obj)).split("'")[1], len(obj)) 1187 ) 1188 1189 1190def validate_sequence(obj): 1191 if obj is not None and not is_sequence(obj): 1192 raise TypeError( 1193 "Expected an iterable object," 1194 " received '%s'" % str(type(obj)).split("'")[1] 1195 ) 1196 1197 1198def validate_object(obj, data_type): 1199 if obj is not None and not isinstance(obj, data_type): 1200 raise TypeError( 1201 "Expected an object of '%s' type, received '%s'" 1202 % (str(data_type).split("'")[1], str(type(obj)).split("'")[1]) 1203 ) 1204 1205 1206def validate_axis(ds, axis): 1207 if ds is not None: 1208 valid_axis = ds.coordinates.axis_name.keys() 1209 else: 1210 valid_axis = [0, 1, 2, "x", "y", "z", "X", "Y", "Z"] 1211 if axis not in valid_axis: 1212 raise TypeError( 1213 "Expected axis of int or char type (can be %s), " 1214 "received '%s'." % (list(valid_axis), axis) 1215 ) 1216 1217 1218def validate_center(center): 1219 if isinstance(center, str): 1220 c = center.lower() 1221 if ( 1222 c not in ["c", "center", "m", "max", "min"] 1223 and not c.startswith("max_") 1224 and not c.startswith("min_") 1225 ): 1226 raise TypeError( 1227 "Expected 'center' to be in ['c', 'center', " 1228 "'m', 'max', 'min'] or the prefix to be " 1229 "'max_'/'min_', received '%s'." % center 1230 ) 1231 elif not isinstance(center, (numeric_type, YTQuantity)) and not is_sequence(center): 1232 raise TypeError( 1233 "Expected 'center' to be a numeric object of type " 1234 "list/tuple/np.ndarray/YTArray/YTQuantity, " 1235 "received '%s'." % str(type(center)).split("'")[1] 1236 ) 1237 1238 1239def sglob(pattern): 1240 """ 1241 Return the results of a glob through the sorted() function. 1242 """ 1243 return sorted(glob.glob(pattern)) 1244 1245 1246def dictWithFactory(factory: Callable[[Any], Any]) -> Type: 1247 """ 1248 Create a dictionary class with a default factory function. 1249 Contrary to `collections.defaultdict`, the factory takes 1250 the missing key as input parameter. 1251 1252 Parameters 1253 ---------- 1254 factory : callable(key) -> value 1255 The factory to call when hitting a missing key 1256 1257 Returns 1258 ------- 1259 DictWithFactory class 1260 A class to create new dictionaries handling missing keys. 1261 """ 1262 1263 class DictWithFactory(dict): 1264 def __init__(self, *args, **kwargs): 1265 self.factory = factory 1266 super().__init__(*args, **kwargs) 1267 1268 def __missing__(self, key): 1269 val = self.factory(key) 1270 self[key] = val 1271 return val 1272 1273 return DictWithFactory 1274 1275 1276def levenshtein_distance(seq1, seq2, max_dist=None): 1277 """ 1278 Compute the levenshtein distance between seq1 and seq2. 1279 From https://stackabuse.com/levenshtein-distance-and-text-similarity-in-python/ 1280 1281 Parameters 1282 ---------- 1283 seq1 : str 1284 seq2 : str 1285 The strings to compute the distance between 1286 max_dist : integer 1287 If not None, maximum distance returned (see notes). 1288 1289 Returns 1290 ------- 1291 The Levenshtein distance as an integer. 1292 1293 Notes 1294 ----- 1295 This computes the Levenshtein distance, i.e. the number of edits to change 1296 seq1 into seq2. If a maximum distance is passed, the algorithm will stop as soon 1297 as the number of edits goes above the value. This allows for an earlier break 1298 and speeds calculations up. 1299 """ 1300 size_x = len(seq1) + 1 1301 size_y = len(seq2) + 1 1302 if max_dist is None: 1303 max_dist = max(size_x, size_y) 1304 1305 if abs(size_x - size_y) > max_dist: 1306 return max_dist + 1 1307 matrix = np.zeros((size_x, size_y), dtype=int) 1308 for x in range(size_x): 1309 matrix[x, 0] = x 1310 for y in range(size_y): 1311 matrix[0, y] = y 1312 1313 for x in range(1, size_x): 1314 for y in range(1, size_y): 1315 if seq1[x - 1] == seq2[y - 1]: 1316 matrix[x, y] = min( 1317 matrix[x - 1, y] + 1, matrix[x - 1, y - 1], matrix[x, y - 1] + 1 1318 ) 1319 else: 1320 matrix[x, y] = min( 1321 matrix[x - 1, y] + 1, matrix[x - 1, y - 1] + 1, matrix[x, y - 1] + 1 1322 ) 1323 1324 # Early break: the minimum distance is already larger than 1325 # maximum allow value, can return safely. 1326 if matrix[x].min() > max_dist: 1327 return max_dist + 1 1328 return matrix[size_x - 1, size_y - 1] 1329