import base64
import builtins
import contextlib
import copy
import errno
import getpass
import glob
import inspect
import itertools
import os
import pdb
import re
import struct
import subprocess
import sys
import time
import traceback
import urllib.parse
import urllib.request
import warnings
from functools import lru_cache, wraps
from numbers import Number as numeric_type
from typing import Any, Callable, Type

import matplotlib
import numpy as np
from more_itertools import always_iterable, collapse, first
from packaging.version import parse as parse_version
from tqdm import tqdm

from yt.units import YTArray, YTQuantity
from yt.utilities.exceptions import YTInvalidWidthError
from yt.utilities.logger import ytLogger as mylog
from yt.utilities.on_demand_imports import _requests as requests

# Some functions for handling sequences and other types


def is_sequence(obj):
    """
    Grabbed from Python Cookbook / matplotlib.cbook.  Returns True if *obj*
    is iterable (i.e., has a length), False otherwise.

    Parameters
    ----------
    obj : iterable
    """
    try:
        len(obj)
        return True
    except TypeError:
        return False


def iter_fields(field_or_fields):
    """
    Create an iterator over field names, specified as single strings or
    (ftype, fname) tuples alike.
    This can safely be used in places where we accept a single field or a
    list as input.

    Parameters
    ----------
    field_or_fields : str, tuple(str, str), or any iterable of the previous types.

    Examples
    --------

    >>> fields = "density"
    >>> for field in iter_fields(fields):
    ...     print(field)
    density

    >>> fields = ("gas", "density")
    >>> for field in iter_fields(fields):
    ...     print(field)
    ('gas', 'density')

    >>> fields = ["density", "temperature", ("index", "dx")]
    >>> for field in iter_fields(fields):
    ...     print(field)
    density
    temperature
    ('index', 'dx')
    """
    return always_iterable(field_or_fields, base_type=(tuple, str, bytes))


def ensure_numpy_array(obj):
    """
    This function ensures that *obj* is a numpy array. Typically used to
    convert scalar, list or tuple argument passed to functions using Cython.
    """
    if isinstance(obj, np.ndarray):
        if obj.shape == ():
            return np.array([obj])
        # We cast to ndarray to catch ndarray subclasses
        return np.array(obj)
    elif isinstance(obj, (list, tuple)):
        return np.asarray(obj)
    else:
        return np.asarray([obj])

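# Illustrative sketch (comment only, not executed at import time):
# ensure_numpy_array normalizes scalars, sequences and 0-d arrays to 1-d
# ndarrays, e.g.
#     >>> ensure_numpy_array(3.0)
#     array([3.])
#     >>> ensure_numpy_array((1, 2))
#     array([1, 2])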

def read_struct(f, fmt):
    """
    This reads a struct, and only that struct, from an open file.
    """
    s = f.read(struct.calcsize(fmt))
    return struct.unpack(fmt, s)

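# Illustrative usage sketch: reading two little-endian int32 values from an
# open binary file ("data.bin" is a hypothetical file; the format string
# follows the struct module conventions):
#     with open("data.bin", "rb") as f:
#         a, b = read_struct(f, "<ii")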

def just_one(obj):
    # If we have an iterable, sometimes we only want one item
    return first(collapse(obj))

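# Illustrative sketch: nested iterables are collapsed and the first element
# is returned; scalars pass through unchanged:
#     >>> just_one([[1, 2], 3])
#     1
#     >>> just_one(5.0)
#     5.0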

def compare_dicts(dict1, dict2):
    if not set(dict1) <= set(dict2):
        return False
    for key in dict1.keys():
        if dict1[key] is not None and dict2[key] is not None:
            if isinstance(dict1[key], dict):
                if compare_dicts(dict1[key], dict2[key]):
                    continue
                else:
                    return False
            try:
                comparison = np.array_equal(dict1[key], dict2[key])
            except TypeError:
                comparison = dict1[key] == dict2[key]
            if not comparison:
                return False
    return True


# Taken from
# http://www.goldb.org/goldblog/2008/02/06/PythonConvertSecsIntoHumanReadableTimeStringHHMMSS.aspx
def humanize_time(secs):
    """
    Takes *secs* and returns a nicely formatted string
    """
    mins, secs = divmod(secs, 60)
    hours, mins = divmod(mins, 60)
    return "%02d:%02d:%02d" % (hours, mins, secs)

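# Illustrative sketch:
#     >>> humanize_time(3661)
#     '01:01:01'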

#
# Some function wrappers that come in handy once in a while
#

# we use the resource module to get the memory page size

try:
    import resource
except ImportError:
    pass


def get_memory_usage(subtract_share=False):
    """
    Return the resident memory size in megabytes.
    """
    pid = os.getpid()
    try:
        pagesize = resource.getpagesize()
    except NameError:
        return -1024
    status_file = f"/proc/{pid}/statm"
    if not os.path.isfile(status_file):
        return -1024
    with open(status_file) as fh:
        line = fh.read()
    size, resident, share, text, library, data, dt = (int(i) for i in line.split())
    if subtract_share:
        resident -= share
    return resident * pagesize / (1024 * 1024)  # return in megs


def time_execution(func):
    r"""
    Decorator for seeing how long a given function takes, depending on whether
    or not the global 'yt.time_functions' config parameter is set.
    """

    @wraps(func)
    def wrapper(*arg, **kw):
        t1 = time.time()
        res = func(*arg, **kw)
        t2 = time.time()
        mylog.debug("%s took %0.3f s", func.__name__, (t2 - t1))
        return res

    from yt.config import ytcfg

    if ytcfg.get("yt", "time_functions"):
        return wrapper
    else:
        return func


def print_tb(func):
    """
    This function is used as a decorator on a function to have the calling stack
    printed whenever that function is entered.

    This can be used like so:

    >>> @print_tb
    ... def some_deeply_nested_function(*args, **kwargs):
    ...     ...

    """

    @wraps(func)
    def run_func(*args, **kwargs):
        traceback.print_stack()
        return func(*args, **kwargs)

    return run_func


def rootonly(func):
    """
    This is a decorator that, when used, will only call the function on the
    root processor.  On all other processors the call returns None without
    executing the function.

    This can be used like so:

    .. code-block:: python

       @rootonly
       def some_root_only_function(*args, **kwargs):
           ...
    """
    from yt.config import ytcfg

    @wraps(func)
    def check_parallel_rank(*args, **kwargs):
        if ytcfg.get("yt", "internals", "topcomm_parallel_rank") > 0:
            return
        return func(*args, **kwargs)

    return check_parallel_rank


def pdb_run(func):
    """
    This decorator inserts a pdb session on top of the call-stack into a
    function.

    This can be used like so:

    >>> @pdb_run
    ... def some_function_to_debug(*args, **kwargs):
    ...     ...

    """

    @wraps(func)
    def wrapper(*args, **kw):
        pdb.runcall(func, *args, **kw)

    return wrapper


__header = """
== Welcome to the embedded IPython Shell ==

   You are currently inside the function:
     %(fname)s

   Defined in:
     %(filename)s:%(lineno)s
"""


def insert_ipython(num_up=1):
    """
    Placed inside a function, this will insert an IPython interpreter at that
    current location.  This will enable detailed inspection of the current
    execution environment, as well as (optional) modification of that environment.
    *num_up* refers to how many frames of the stack get stripped off, and
    defaults to 1 so that this function itself is stripped off.
    """
    import IPython
    from IPython.terminal.embed import InteractiveShellEmbed

    try:
        from traitlets.config.loader import Config
    except ImportError:
        from IPython.config.loader import Config

    frame = inspect.stack()[num_up]
    loc = frame[0].f_locals.copy()
    glo = frame[0].f_globals
    dd = dict(fname=frame[3], filename=frame[1], lineno=frame[2])
    cfg = Config()
    cfg.InteractiveShellEmbed.local_ns = loc
    cfg.InteractiveShellEmbed.global_ns = glo
    IPython.embed(config=cfg, banner2=__header % dd)
    ipshell = InteractiveShellEmbed(config=cfg)

    del ipshell


#
# Our progress bar types and how to get one
#


class TqdmProgressBar:
    # This is a drop in replacement for pbar
    # called tqdm
    def __init__(self, title, maxval):
        self._pbar = tqdm(leave=True, total=maxval, desc=title)
        self.i = 0

    def update(self, i=None):
        if i is None:
            i = self.i + 1
        n = i - self.i
        self.i = i
        self._pbar.update(n)

    def finish(self):
        self._pbar.close()


class DummyProgressBar:
    # This progressbar gets handed if we don't
    # want ANY output
    def __init__(self, *args, **kwargs):
        return

    def update(self, *args, **kwargs):
        return

    def finish(self, *args, **kwargs):
        return


def get_pbar(title, maxval):
    """
    This returns a progressbar of the most appropriate type, given a *title*
    and a *maxval*.
    """
    maxval = max(maxval, 1)
    from yt.config import ytcfg

    if (
        ytcfg.get("yt", "suppress_stream_logging")
        or ytcfg.get("yt", "internals", "within_testing")
        or maxval == 1
        or not is_root()
    ):
        return DummyProgressBar()
    return TqdmProgressBar(title, maxval)

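# Illustrative usage sketch ("work_items" is a hypothetical iterable; the
# actual bar type depends on the yt configuration and the processor rank):
#     pbar = get_pbar("Processing items", len(work_items))
#     for i, item in enumerate(work_items):
#         ...  # do the work
#         pbar.update(i + 1)
#     pbar.finish()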

def only_on_root(func, *args, **kwargs):
    """
    This function accepts a *func*, a set of *args* and *kwargs* and then only
    on the root processor calls the function.  All other processors get "None"
    handed back.
    """
    from yt.config import ytcfg

    if kwargs.pop("global_rootonly", False):
        cfg_option = "global_parallel_rank"
    else:
        cfg_option = "topcomm_parallel_rank"
    if not ytcfg.get("yt", "internals", "parallel"):
        return func(*args, **kwargs)
    if ytcfg.get("yt", "internals", cfg_option) > 0:
        return
    return func(*args, **kwargs)


def is_root():
    """
    This function returns True if it is on the root processor of the
    topcomm and False otherwise.
    """
    from yt.config import ytcfg

    if not ytcfg.get("yt", "internals", "parallel"):
        return True
    return ytcfg.get("yt", "internals", "topcomm_parallel_rank") == 0


#
# Our signal and traceback handling functions
#


def signal_print_traceback(signo, frame):
    # traceback.print_stack writes directly to stderr and returns None,
    # so there is nothing useful to wrap in print() here.
    traceback.print_stack(frame)


def signal_problem(signo, frame):
    raise RuntimeError()


def signal_ipython(signo, frame):
    insert_ipython(2)


def paste_traceback(exc_type, exc, tb):
    """
    This is a traceback handler that knows how to paste to the pastebin.
    Should only be used in sys.excepthook.
    """
    sys.__excepthook__(exc_type, exc, tb)
    import xmlrpc.client
    from io import StringIO

    p = xmlrpc.client.ServerProxy(
        "http://paste.yt-project.org/xmlrpc/", allow_none=True
    )
    s = StringIO()
    traceback.print_exception(exc_type, exc, tb, file=s)
    s = s.getvalue()
    ret = p.pastes.newPaste("pytb", s, None, "", "", True)
    print()
    print(f"Traceback pasted to http://paste.yt-project.org/show/{ret}")
    print()


def paste_traceback_detailed(exc_type, exc, tb):
    """
    This is a traceback handler that knows how to paste to the pastebin.
    Should only be used in sys.excepthook.
    """
    import cgitb
    import xmlrpc.client
    from io import StringIO

    s = StringIO()
    handler = cgitb.Hook(format="text", file=s)
    handler(exc_type, exc, tb)
    s = s.getvalue()
    print(s)
    p = xmlrpc.client.ServerProxy(
        "http://paste.yt-project.org/xmlrpc/", allow_none=True
    )
    ret = p.pastes.newPaste("text", s, None, "", "", True)
    print()
    print(f"Traceback pasted to http://paste.yt-project.org/show/{ret}")
    print()


_ss = "fURbBUUBE0cLXgETJnZgJRMXVhVGUQpQAUBuehQMUhJWRFFRAV1ERAtBXw1dAxMLXT4zXBFfABNN\nC0ZEXw1YUURHCxMXVlFERwxWCQw=\n"


def _rdbeta(key):
    # base64.decodestring was removed in Python 3.9; decodebytes is the
    # replacement and expects bytes, so encode the stored string first.
    # Iterating over bytes yields ints, so no ord() is needed on that side.
    enc_s = base64.decodebytes(_ss.encode())
    dec_s = "".join(chr(a ^ ord(b)) for a, b in zip(enc_s, itertools.cycle(key)))
    print(dec_s)


#
# Some exceptions
#


class NoCUDAException(Exception):
    pass


class YTEmptyClass:
    pass


def update_git(path):
    try:
        import git
    except ImportError:
        print("Updating and precise version information requires ")
        print("gitpython to be installed.")
        print("Try: python -m pip install gitpython")
        return -1
    with open(os.path.join(path, "yt_updater.log"), "a") as f:
        repo = git.Repo(path)
        if repo.is_dirty(untracked_files=True):
            print("Changes have been made to the yt source code so I won't ")
            print("update the code. You will have to do this yourself.")
            print("Here's a set of sample commands:")
            print("")
            print(f"    $ cd {path}")
            print("    $ git stash")
            print("    $ git checkout main")
            print("    $ git pull")
            print("    $ git stash pop")
            print(f"    $ {sys.executable} setup.py develop")
            print("")
            return 1
        if repo.active_branch.name != "main":
            print("yt repository is not tracking the main branch so I won't ")
            print("update the code. You will have to do this yourself.")
            print("Here's a set of sample commands:")
            print("")
            print(f"    $ cd {path}")
            print("    $ git checkout main")
            print("    $ git pull")
            print(f"    $ {sys.executable} setup.py develop")
            print("")
            return 1
        print("Updating the repository")
        f.write("Updating the repository\n\n")
        old_version = repo.git.rev_parse("HEAD", short=12)
        try:
            remote = repo.remotes.yt_upstream
        except AttributeError:
            remote = repo.create_remote(
                "yt_upstream", url="https://github.com/yt-project/yt"
            )
            remote.fetch()
        main = repo.heads.main
        main.set_tracking_branch(remote.refs.main)
        main.checkout()
        remote.pull()
        new_version = repo.git.rev_parse("HEAD", short=12)
        f.write(f"Updated from {old_version} to {new_version}\n\n")
        rebuild_modules(path, f)
    print("Updated successfully")


def rebuild_modules(path, f):
    f.write("Rebuilding modules\n\n")
    p = subprocess.Popen(
        [sys.executable, "setup.py", "build_ext", "-i"],
        cwd=path,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    )
    stdout, stderr = p.communicate()
    f.write(stdout.decode("utf-8"))
    f.write("\n\n")
    if p.returncode:
        print(f"BROKEN: See {os.path.join(path, 'yt_updater.log')}")
        sys.exit(1)
    f.write("Successful!\n")


def get_git_version(path):
    try:
        import git
    except ImportError:
        print("Updating and precise version information requires ")
        print("gitpython to be installed.")
        print("Try: python -m pip install gitpython")
        return None
    try:
        repo = git.Repo(path)
        return repo.git.rev_parse("HEAD", short=12)
    except git.InvalidGitRepositoryError:
        # path is not a git repository
        return None


def get_yt_version():
    import pkg_resources

    yt_provider = pkg_resources.get_provider("yt")
    path = os.path.dirname(yt_provider.module_path)
    version = get_git_version(path)
    if version is None:
        return version
    else:
        v_str = version[:12].strip()
        if hasattr(v_str, "decode"):
            v_str = v_str.decode("utf-8")
        return v_str


def get_version_stack():
    version_info = {}
    version_info["yt"] = get_yt_version()
    version_info["numpy"] = np.version.version
    version_info["matplotlib"] = matplotlib.__version__
    return version_info


def get_script_contents():
    top_frame = inspect.stack()[-1]
    finfo = inspect.getframeinfo(top_frame[0])
    if finfo[2] != "<module>":
        return None
    if not os.path.exists(finfo[0]):
        return None
    try:
        with open(finfo[0]) as f:
            contents = f.read()
    except Exception:
        contents = None
    return contents


def download_file(url, filename):
    try:
        return fancy_download_file(url, filename, requests)
    except ImportError:
        # fancy_download_file requires requests
        return simple_download_file(url, filename)


def fancy_download_file(url, filename, requests=None):
    response = requests.get(url, stream=True)
    total_length = response.headers.get("content-length")

    with open(filename, "wb") as fh:
        if total_length is None:
            fh.write(response.content)
        else:
            blocksize = 4 * 1024 ** 2
            iterations = int(float(total_length) / float(blocksize))

            pbar = get_pbar(
                "Downloading %s to %s " % os.path.split(filename)[::-1], iterations
            )
            iteration = 0
            for chunk in response.iter_content(chunk_size=blocksize):
                fh.write(chunk)
                iteration += 1
                pbar.update(iteration)
            pbar.finish()
    return filename


def simple_download_file(url, filename):
    class MyURLopener(urllib.request.FancyURLopener):
        def http_error_default(self, url, fp, errcode, errmsg, headers):
            raise RuntimeError(
                "Attempt to download file from %s failed with error %s: %s."
                % (url, errcode, errmsg)
            )

    fn, h = MyURLopener().retrieve(url, filename)
    return fn


# This code snippet is modified from Georg Brandl
def bb_apicall(endpoint, data, use_pass=True):
    uri = f"https://api.bitbucket.org/1.0/{endpoint}/"
    # since bitbucket doesn't return the required WWW-Authenticate header when
    # making a request without Authorization, we cannot use the standard urllib2
    # auth handlers; we have to add the requisite header from the start
    if data is not None:
        # urlencode returns str; Request expects the body as bytes
        data = urllib.parse.urlencode(data).encode("utf-8")
    req = urllib.request.Request(uri, data)
    if use_pass:
        username = input("Bitbucket Username? ")
        password = getpass.getpass()
        upw = f"{username}:{password}".encode("utf-8")
        auth = base64.b64encode(upw).decode("ascii").strip()
        req.add_header("Authorization", f"Basic {auth}")
    return urllib.request.urlopen(req).read()


def fix_length(length, ds):
    registry = ds.unit_registry
    if isinstance(length, YTArray):
        if registry is not None:
            length.units.registry = registry
        return length.in_units("code_length")
    if isinstance(length, numeric_type):
        return YTArray(length, "code_length", registry=registry)
    length_valid_tuple = isinstance(length, (list, tuple)) and len(length) == 2
    unit_is_string = isinstance(length[1], str)
    length_is_number = isinstance(length[0], numeric_type) and not isinstance(
        length[0], YTArray
    )
    if length_valid_tuple and unit_is_string and length_is_number:
        return YTArray(*length, registry=registry)
    else:
        raise RuntimeError(f"Length {str(length)} is invalid")


@contextlib.contextmanager
def parallel_profile(prefix):
    r"""A context manager for profiling parallel code execution using cProfile

    This is a simple context manager that automatically profiles the execution
    of a snippet of code.

    Parameters
    ----------
    prefix : string
        A string name to prefix outputs with.

    Examples
    --------

    >>> from yt import PhasePlot
    >>> from yt.testing import fake_random_ds
    >>> fields = ("density", "temperature", "cell_mass")
    >>> units = ("g/cm**3", "K", "g")
    >>> ds = fake_random_ds(16, fields=fields, units=units)
    >>> with parallel_profile("my_profile"):
    ...     plot = PhasePlot(ds.all_data(), *fields)
    """
    import cProfile

    from yt.config import ytcfg

    fn = "%s_%04i_%04i.cprof" % (
        prefix,
        ytcfg.get("yt", "internals", "topcomm_parallel_size"),
        ytcfg.get("yt", "internals", "topcomm_parallel_rank"),
    )
    p = cProfile.Profile()
    p.enable()
    yield fn
    p.disable()
    p.dump_stats(fn)


def get_num_threads():
    from .config import ytcfg

    nt = ytcfg.get("yt", "num_threads")
    if nt < 0:
        # environment variables are strings, so cast to int before returning
        return int(os.environ.get("OMP_NUM_THREADS", 0))
    return nt


def fix_axis(axis, ds):
    return ds.coordinates.axis_id.get(axis, axis)


def get_output_filename(name, keyword, suffix):
    r"""Return an appropriate filename for output.

    With a name provided by the user, this will decide how to appropriately name the
    output file by the following rules:

    1. if name is None, the filename will be the keyword plus the suffix.
    2. if name ends with "/" (resp "\" on Windows), assume name is a directory and the
       file will be named name/(keyword+suffix).  If the directory does not exist, first
       try to create it and raise an exception if an error occurs.
    3. if name does not end in the suffix, add the suffix.

    Parameters
    ----------
    name : str
        A filename given by the user.
    keyword : str
        A default filename prefix if name is None.
    suffix : str
        Suffix that must appear at end of the filename.
        This will be added if not present.

    Examples
    --------

    >>> get_output_filename(None, "Projection_x", ".png")
    'Projection_x.png'
    >>> get_output_filename("my_file", "Projection_x", ".png")
    'my_file.png'
    >>> get_output_filename("my_dir/", "Projection_x", ".png")
    'my_dir/Projection_x.png'

    """
    if name is None:
        name = keyword
    name = os.path.expanduser(name)
    if name.endswith(os.sep) and not os.path.isdir(name):
        ensure_dir(name)
    if os.path.isdir(name):
        name = os.path.join(name, keyword)
    if not name.endswith(suffix):
        name += suffix
    return name


def ensure_dir_exists(path):
    r"""Create all directories in path recursively in a parallel safe manner"""
    my_dir = os.path.dirname(path)
    # If path is a file in the current directory, like "test.txt", then my_dir
    # would be an empty string, resulting in FileNotFoundError when passed to
    # ensure_dir. Let's avoid that.
    if my_dir:
        ensure_dir(my_dir)


def ensure_dir(path):
    r"""Parallel safe directory maker."""
    if os.path.exists(path):
        return path

    try:
        os.makedirs(path)
    except OSError as e:
        if e.errno == errno.EEXIST:
            pass
        else:
            raise
    return path


def validate_width_tuple(width):
    if not is_sequence(width) or len(width) != 2:
        raise YTInvalidWidthError(f"width ({width}) is not a two element tuple")
    is_numeric = isinstance(width[0], numeric_type)
    length_has_units = isinstance(width[0], YTArray)
    unit_is_string = isinstance(width[1], str)
    if not is_numeric or length_has_units and unit_is_string:
        msg = f"width ({str(width)}) is invalid. "
        msg += "Valid widths look like this: (12, 'au')"
        raise YTInvalidWidthError(msg)


_first_cap_re = re.compile("(.)([A-Z][a-z]+)")
_all_cap_re = re.compile("([a-z0-9])([A-Z])")


@lru_cache(maxsize=128, typed=False)
def camelcase_to_underscore(name):
    s1 = _first_cap_re.sub(r"\1_\2", name)
    return _all_cap_re.sub(r"\1_\2", s1).lower()

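# Illustrative sketch:
#     >>> camelcase_to_underscore("CamelCase")
#     'camel_case'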

def set_intersection(some_list):
    if len(some_list) == 0:
        return set()
    # This accepts a list of iterables, which we get the intersection of.
    s = set(some_list[0])
    for l in some_list[1:]:
        s.intersection_update(l)
    return s


@contextlib.contextmanager
def memory_checker(interval=15, dest=None):
    r"""This is a context manager that monitors memory usage.

    Parameters
    ----------
    interval : int
        The number of seconds between printing the current memory usage in
        gigabytes of the current Python interpreter.

    Examples
    --------

    >>> with memory_checker(10):
    ...     arr = np.zeros(1024 * 1024 * 1024, dtype="float64")
    ...     time.sleep(15)
    ...     del arr
    MEMORY: -1.000e+00 gb
    """
    import threading

    if dest is None:
        dest = sys.stdout

    class MemoryChecker(threading.Thread):
        def __init__(self, event, interval):
            self.event = event
            self.interval = interval
            threading.Thread.__init__(self)

        def run(self):
            while not self.event.wait(self.interval):
                print(f"MEMORY: {get_memory_usage() / 1024.0:0.3e} gb", file=dest)

    e = threading.Event()
    mem_check = MemoryChecker(e, interval)
    mem_check.start()
    try:
        yield
    finally:
        e.set()


def enable_plugins(plugin_filename=None):
    """Forces a plugin file to be parsed.

    A plugin file is a means of creating custom fields, quantities,
    data objects, colormaps, and other code classes and objects to be used
    in yt scripts without modifying the yt source directly.

    If ``plugin_filename`` is omitted, this function will look for a plugin file at
    ``$HOME/.config/yt/my_plugins.py``, which is the preferred behaviour for a
    system-level configuration.

    Warning: a script using this function will only be reproducible if your plugin
    file is shared with it.
    """
    import yt
    from yt.config import config_dir, old_config_dir, ytcfg
    from yt.fields.my_plugin_fields import my_plugins_fields

    if plugin_filename is not None:
        _fn = plugin_filename
        if not os.path.isfile(_fn):
            raise FileNotFoundError(_fn)
    else:
        # Determine global plugin location. By decreasing priority order:
        # - absolute path
        # - CONFIG_DIR
        # - obsolete config dir.
        my_plugin_name = ytcfg.get("yt", "plugin_filename")
        for base_prefix in ("", config_dir(), old_config_dir()):
            if os.path.isfile(os.path.join(base_prefix, my_plugin_name)):
                _fn = os.path.join(base_prefix, my_plugin_name)
                break
        else:
            raise FileNotFoundError("Could not find a global system plugin file.")

        if _fn.startswith(old_config_dir()):
            mylog.warning(
                "Your plugin file is located in a deprecated directory. "
                "Please move it from %s to %s",
                os.path.join(old_config_dir(), my_plugin_name),
                os.path.join(config_dir(), my_plugin_name),
            )

    mylog.info("Loading plugins from %s", _fn)
    ytdict = yt.__dict__
    execdict = ytdict.copy()
    execdict["add_field"] = my_plugins_fields.add_field
    with open(_fn) as f:
        code = compile(f.read(), _fn, "exec")
        exec(code, execdict, execdict)
    ytnamespace = list(ytdict.keys())
    for k in execdict.keys():
        if k not in ytnamespace:
            if callable(execdict[k]):
                setattr(yt, k, execdict[k])


def subchunk_count(n_total, chunk_size):
    handled = 0
    while handled < n_total:
        tr = min(n_total - handled, chunk_size)
        yield tr
        handled += tr

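# Illustrative sketch: splitting 10 items into chunks of at most 3 yields
#     >>> list(subchunk_count(10, 3))
#     [3, 3, 3, 1]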

def fix_unitary(u):
    if u == "1":
        return "unitary"
    else:
        return u


def get_hash(infile, algorithm="md5", BLOCKSIZE=65536):
    """Generate file hash without reading in the entire file at once.

    Original code licensed under MIT.  Source:
    https://www.pythoncentral.io/hashing-files-with-python/

    Parameters
    ----------
    infile : str
        File of interest (including the path).
    algorithm : str (optional)
        Hash algorithm of choice. Defaults to 'md5'.
    BLOCKSIZE : int (optional)
        How much data in bytes to read in at once.

    Returns
    -------
    hash : str
        The hash of the file.

    Examples
    --------
    >>> from tempfile import NamedTemporaryFile
    >>> with NamedTemporaryFile() as file:
    ...     get_hash(file.name)
    'd41d8cd98f00b204e9800998ecf8427e'
    """
    import hashlib

    try:
        hasher = getattr(hashlib, algorithm)()
    except AttributeError as e:
        raise NotImplementedError(
            f"'{algorithm}' not available!  "
            f"Available algorithms: {hashlib.algorithms_available}"
        ) from e

    filesize = os.path.getsize(infile)
    iterations = int(float(filesize) / float(BLOCKSIZE))

    pbar = get_pbar(f"Generating {algorithm} hash", iterations)

    iteration = 0
    with open(infile, "rb") as f:
        buf = f.read(BLOCKSIZE)
        while len(buf) > 0:
            hasher.update(buf)
            buf = f.read(BLOCKSIZE)
            iteration += 1
            pbar.update(iteration)
        pbar.finish()

    return hasher.hexdigest()


def get_brewer_cmap(cmap):
    """Returns a colorbrewer colormap from palettable"""
    try:
        import brewer2mpl
    except ImportError:
        brewer2mpl = None
    try:
        import palettable
    except ImportError:
        palettable = None
    if palettable is not None:
        bmap = palettable.colorbrewer.get_map(*cmap)
    elif brewer2mpl is not None:
        warnings.warn(
            "Using brewer2mpl colormaps is deprecated. "
            "Please install the successor to brewer2mpl, "
            "palettable, with `pip install palettable`. "
            "Colormap tuple names remain unchanged."
        )
        bmap = brewer2mpl.get_map(*cmap)
    else:
        raise RuntimeError("Please install palettable to use colorbrewer colormaps")
    return bmap.get_mpl_colormap(N=cmap[2])


@contextlib.contextmanager
def dummy_context_manager(*args, **kwargs):
    yield


def matplotlib_style_context(style_name=None, after_reset=False):
    """Returns a context manager for controlling matplotlib style.

    Arguments are passed to matplotlib.style.context() if specified. If no
    style is specified, a minimal style setting the mathtext fonts to
    Computer Modern is used, after resetting to the default config parameters.

    On older matplotlib versions (<=1.5.0) where matplotlib.style isn't
    available, returns a dummy context manager.
    """
    if style_name is None:
        import matplotlib

        style_name = {"mathtext.fontset": "cm"}
        if parse_version(matplotlib.__version__) >= parse_version("3.3.0"):
            style_name["mathtext.fallback"] = "cm"
        else:
            style_name["mathtext.fallback_to_cm"] = True
    try:
        import matplotlib.style

        return matplotlib.style.context(style_name, after_reset=after_reset)
    except ImportError:
        pass
    return dummy_context_manager()


interactivity = False

"""Sets the condition that interactive backends can be used."""


def toggle_interactivity():
    global interactivity
    interactivity = not interactivity
    if interactivity:
        if "__IPYTHON__" in dir(builtins):
            import IPython

            shell = IPython.get_ipython()
            shell.magic("matplotlib")
        else:
            import matplotlib

            matplotlib.interactive(True)


def get_interactivity():
    return interactivity


def setdefaultattr(obj, name, value):
    """Set attribute with *name* on *obj* with *value* if it doesn't exist yet

    Analogous to dict.setdefault
    """
    if not hasattr(obj, name):
        setattr(obj, name, value)
    return getattr(obj, name)

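# Illustrative sketch (Holder is a throwaway example class): the existing
# attribute wins, mirroring dict.setdefault:
#     >>> class Holder: pass
#     >>> h = Holder()
#     >>> setdefaultattr(h, "x", 1)
#     1
#     >>> setdefaultattr(h, "x", 2)
#     1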

def parse_h5_attr(f, attr):
    """A Python3-safe function for getting hdf5 attributes.

    If an attribute is supposed to be a string, this will return it as such.
    """
    val = f.attrs.get(attr, None)
    if isinstance(val, bytes):
        return val.decode("utf8")
    else:
        return val


def obj_length(v):
    if is_sequence(v):
        return len(v)
    else:
        # If something isn't iterable, we return 0
        # to signify zero length (aka a scalar).
        return 0


def array_like_field(data, x, field):
    field = data._determine_fields(field)[0]
    if isinstance(field, tuple):
        finfo = data.ds._get_field_info(field[0], field[1])
    else:
        finfo = data.ds._get_field_info(field)
    if finfo.sampling_type == "particle":
        units = finfo.output_units
    else:
        units = finfo.units
    if isinstance(x, YTArray):
        arr = copy.deepcopy(x)
        arr.convert_to_units(units)
        return arr
    if isinstance(x, np.ndarray):
        return data.ds.arr(x, units)
    else:
        return data.ds.quan(x, units)


def validate_3d_array(obj):
    if not is_sequence(obj) or len(obj) != 3:
        raise TypeError(
            "Expected an array of size (3,), received '%s' of "
            "length %s" % (str(type(obj)).split("'")[1], len(obj))
        )


def validate_float(obj):
    """Validates if the passed argument is a float value.

    Raises an exception if `obj` is not a single float value
    or a YTQuantity of size 1.

    Parameters
    ----------
    obj : Any
        Any argument which needs to be checked for a single float value.

    Raises
    ------
    TypeError
        Raised if `obj` is not a single float value or YTQuantity

    Examples
    --------
    >>> validate_float(1)
    >>> validate_float(1.50)
    >>> validate_float(YTQuantity(1, "cm"))
    >>> validate_float((1, "cm"))
    >>> validate_float([1, 1, 1])
    Traceback (most recent call last):
    ...
    TypeError: Expected a numeric value (or size-1 array), received 'list' of length 3

    >>> validate_float([YTQuantity(1, "cm"), YTQuantity(2, "cm")])
    Traceback (most recent call last):
    ...
    TypeError: Expected a numeric value (or size-1 array), received 'list' of length 2
    """
    if isinstance(obj, tuple):
        if (
            len(obj) != 2
            or not isinstance(obj[0], numeric_type)
            or not isinstance(obj[1], str)
        ):
            raise TypeError(
                "Expected a numeric value (or tuple of format "
                "(float, String)), received an inconsistent tuple "
                "'%s'." % str(obj)
            )
        else:
            return
    if is_sequence(obj) and (len(obj) != 1 or not isinstance(obj[0], numeric_type)):
        raise TypeError(
            "Expected a numeric value (or size-1 array), "
            "received '%s' of length %s" % (str(type(obj)).split("'")[1], len(obj))
        )


def validate_sequence(obj):
    if obj is not None and not is_sequence(obj):
        raise TypeError(
            "Expected an iterable object,"
            " received '%s'" % str(type(obj)).split("'")[1]
        )


def validate_object(obj, data_type):
    if obj is not None and not isinstance(obj, data_type):
        raise TypeError(
            "Expected an object of '%s' type, received '%s'"
            % (str(data_type).split("'")[1], str(type(obj)).split("'")[1])
        )


def validate_axis(ds, axis):
    if ds is not None:
        valid_axis = ds.coordinates.axis_name.keys()
    else:
        valid_axis = [0, 1, 2, "x", "y", "z", "X", "Y", "Z"]
    if axis not in valid_axis:
        raise TypeError(
            "Expected axis of int or char type (can be %s), "
            "received '%s'." % (list(valid_axis), axis)
        )


def validate_center(center):
    if isinstance(center, str):
        c = center.lower()
        if (
            c not in ["c", "center", "m", "max", "min"]
            and not c.startswith("max_")
            and not c.startswith("min_")
        ):
            raise TypeError(
                "Expected 'center' to be in ['c', 'center', "
                "'m', 'max', 'min'] or the prefix to be "
                "'max_'/'min_', received '%s'." % center
            )
    elif not isinstance(center, (numeric_type, YTQuantity)) and not is_sequence(center):
        raise TypeError(
            "Expected 'center' to be a numeric object of type "
            "list/tuple/np.ndarray/YTArray/YTQuantity, "
            "received '%s'." % str(type(center)).split("'")[1]
        )


def sglob(pattern):
    """
    Return the results of a glob through the sorted() function.
    """
    return sorted(glob.glob(pattern))


def dictWithFactory(factory: Callable[[Any], Any]) -> Type:
    """
    Create a dictionary class with a default factory function.
    Contrary to `collections.defaultdict`, the factory takes
    the missing key as input parameter.

    Parameters
    ----------
    factory : callable(key) -> value
        The factory to call when hitting a missing key

    Returns
    -------
    DictWithFactory class
        A class to create new dictionaries handling missing keys.
    """

    class DictWithFactory(dict):
        def __init__(self, *args, **kwargs):
            self.factory = factory
            super().__init__(*args, **kwargs)

        def __missing__(self, key):
            val = self.factory(key)
            self[key] = val
            return val

    return DictWithFactory

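# Illustrative sketch: missing keys are filled by calling the factory with
# the key itself (unlike collections.defaultdict, whose factory takes no
# arguments):
#     >>> lengths = dictWithFactory(len)()
#     >>> lengths["density"]
#     7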

def levenshtein_distance(seq1, seq2, max_dist=None):
    """
    Compute the Levenshtein distance between seq1 and seq2.
    From https://stackabuse.com/levenshtein-distance-and-text-similarity-in-python/

    Parameters
    ----------
    seq1 : str
    seq2 : str
        The strings to compute the distance between
    max_dist : integer
        If not None, maximum distance returned (see notes).

    Returns
    -------
    The Levenshtein distance as an integer.

    Notes
    -----
    This computes the Levenshtein distance, i.e. the number of edits to change
    seq1 into seq2. If a maximum distance is passed, the algorithm will stop as soon
    as the number of edits goes above the value. This allows for an earlier break
    and speeds calculations up.
    """
    size_x = len(seq1) + 1
    size_y = len(seq2) + 1
    if max_dist is None:
        max_dist = max(size_x, size_y)

    if abs(size_x - size_y) > max_dist:
        return max_dist + 1
    matrix = np.zeros((size_x, size_y), dtype=int)
    for x in range(size_x):
        matrix[x, 0] = x
    for y in range(size_y):
        matrix[0, y] = y

    for x in range(1, size_x):
        for y in range(1, size_y):
            if seq1[x - 1] == seq2[y - 1]:
                matrix[x, y] = min(
                    matrix[x - 1, y] + 1, matrix[x - 1, y - 1], matrix[x, y - 1] + 1
                )
            else:
                matrix[x, y] = min(
                    matrix[x - 1, y] + 1, matrix[x - 1, y - 1] + 1, matrix[x, y - 1] + 1
                )

        # Early break: the minimum distance is already larger than the
        # maximum allowed value, so we can return safely.
        if matrix[x].min() > max_dist:
            return max_dist + 1
    return matrix[size_x - 1, size_y - 1]

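# Illustrative sketch: the classic example has distance 3; when max_dist is
# given, the result is capped at max_dist + 1 once the budget is exceeded:
#     >>> levenshtein_distance("kitten", "sitting")
#     3
#     >>> levenshtein_distance("kitten", "sitting", max_dist=1)
#     2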