1import functools
2import hashlib
3import importlib
4import itertools as it
5import os
6import pickle
7import shutil
8import tempfile
9import unittest
11import matplotlib
12import numpy as np
13from more_itertools import always_iterable
14from numpy.random import RandomState
15from unyt.exceptions import UnitOperationError
17from yt.config import ytcfg
18from yt.funcs import is_sequence
19from yt.loaders import load
20from yt.units.yt_array import YTArray, YTQuantity
22# we import this in a weird way from numpy.testing to avoid triggering
23# flake8 errors from the unused imports. These test functions are imported
24# elsewhere in yt from here so we want them to be imported here.
25from numpy.testing import assert_array_equal, assert_almost_equal  # NOQA isort:skip
26from numpy.testing import assert_equal, assert_array_less  # NOQA isort:skip
27from numpy.testing import assert_string_equal  # NOQA isort:skip
28from numpy.testing import assert_array_almost_equal_nulp  # NOQA isort:skip
29from numpy.testing import assert_allclose, assert_raises  # NOQA isort:skip
30from numpy.testing import assert_approx_equal  # NOQA isort:skip
31from numpy.testing import assert_array_almost_equal  # NOQA isort:skip
33ANSWER_TEST_TAG = "answer_test"
34# Expose assert_true and assert_less_equal from unittest.TestCase
35# this is adopted from nose. Doing this here allows us to avoid importing
36# nose at the top level.
37class _Dummy(unittest.TestCase):
38    def nop():
39        pass
42_t = _Dummy("nop")
44assert_true = getattr(_t, "assertTrue")  # noqa: B009
45assert_less_equal = getattr(_t, "assertLessEqual")  # noqa: B009
def assert_rel_equal(a1, a2, decimals, err_msg="", verbose=True):
    """Assert that ``a1`` and ``a2`` agree to ``decimals`` in relative terms.

    The final check is ``assert_almost_equal(a1 / a2, 1.0, decimals)``.

    For ndarray ``a1``, NaNs must sit at the same positions in both inputs,
    and so must values whose magnitude is below the dtype's machine epsilon
    (treated as zero).  Those entries are replaced by 1.0 in private copies
    before dividing, so they neither produce NaN/inf ratios nor modify the
    caller's arrays.  Units carried by array subclasses are not compared:
    only raw values are.

    For scalar ``a1``, two NaNs are considered equal (returns True), and two
    exact zeros are considered equal.

    Parameters
    ----------
    a1, a2 : array-like or scalar
        Values to compare.
    decimals : int
        Number of decimal places the ratio must match 1.0 to.
    err_msg : str, optional
        Message passed to ``assert_almost_equal``.
    verbose : bool, optional
        Passed to ``assert_almost_equal``.

    Returns
    -------
    True if both scalar inputs contain NaN, otherwise None.

    Raises
    ------
    AssertionError
        If sizes, NaN positions or zero positions differ, or the ratio is
        not close enough to 1.
    """
    # We have nan checks in here because occasionally we have fields that get
    # weighted without non-zero weights.  I'm looking at you, particle fields!
    if isinstance(a1, np.ndarray):
        # Work on plain-ndarray copies: the masking below writes into the
        # arrays, and the caller's data must not change as a side effect.
        a1 = np.array(a1)
        a2 = np.array(a2)
        # np.finfo only accepts inexact dtypes, so promote integer input.
        if not np.issubdtype(a1.dtype, np.inexact):
            a1 = a1.astype("float64")
        if not np.issubdtype(a2.dtype, np.inexact):
            a2 = a2.astype("float64")
        assert a1.size == a2.size
        # Mask out NaNs
        nan1 = np.isnan(a1)
        nan2 = np.isnan(a2)
        assert (nan1 == nan2).all()
        a1[nan1] = 1.0
        a2[nan2] = 1.0
        # Mask out 0 (np.array keeps 0-d comparisons usable as masks)
        ind1 = np.array(np.abs(a1) < np.finfo(a1.dtype).eps)
        ind2 = np.array(np.abs(a2) < np.finfo(a2.dtype).eps)
        assert (ind1 == ind2).all()
        a1[ind1] = 1.0
        a2[ind2] = 1.0
    elif np.any(np.isnan(a1)) and np.any(np.isnan(a2)):
        return True
    if not isinstance(a1, np.ndarray) and a1 == a2 == 0.0:
        # Both exactly zero: compare 1 / 1 instead of producing 0 / 0 = NaN.
        a1 = a2 = 1.0
    return assert_almost_equal(
        np.array(a1) / np.array(a2), 1.0, decimals, err_msg=err_msg, verbose=verbose
    )
def amrspace(extent, levels=7, cells=8):
    """Creates two numpy arrays representing the left and right bounds of
    an AMR grid as well as an array for the AMR level of each cell.

    Level ``lvl`` tiles the box ``[lower, lower + (1/cells)**(lvl-1) * width)``
    along each refined dimension with cells of size ``(1/cells)**lvl * width``,
    omitting its lower-left cell, which is covered by the finer levels.  The
    first row is the single lower-left cell of the finest level.

    Parameters
    ----------
    extent : array-like
        This is a sequence of length 2*ndims that holds the bounds of each
        dimension.  For example, the 2D unit square would be given by
        [0.0, 1.0, 0.0, 1.0].  A 3D cylindrical grid may look like
        [0.0, 2.0, -1.0, 1.0, 0.0, 2*np.pi].
    levels : int or sequence of ints, optional
        This is the number of AMR refinement levels.  If given as a sequence (of
        length ndims), then each dimension will be refined down to this level.
        All values in this array must be the same or zero.  A zero valued dimension
        indicates that this dim should not be refined.  Taking the 3D cylindrical
        example above if we don't want to refine theta but want r and z at 5 we
        would set levels=(5, 5, 0).
    cells : int, optional
        This is the number of cells per refinement level along each refined
        dimension.

    Returns
    -------
    left : float ndarray, shape=(npoints, ndims)
        The left AMR grid points.
    right : float ndarray, shape=(npoints, ndims)
        The right AMR grid points.
    level : int ndarray, shape=(npoints,)
        The AMR level for each point.

    Raises
    ------
    ValueError
        If the non-zero entries of a sequence ``levels`` are not all equal.

    Examples
    --------
    >>> l, r, lvl = amrspace([0.0, 2.0, 1.0, 2.0, 0.0, 3.14], levels=(3, 3, 0), cells=2)
    >>> print(l)
    [[ 0.     1.     0.   ]
     [ 0.25   1.     0.   ]
     [ 0.     1.125  0.   ]
     [ 0.25   1.125  0.   ]
     [ 0.5    1.     0.   ]
     [ 0.     1.25   0.   ]
     [ 0.5    1.25   0.   ]
     [ 1.     1.     0.   ]
     [ 0.     1.5    0.   ]
     [ 1.     1.5    0.   ]]

    """
    extent = np.asarray(extent, dtype="f8")
    # width of the domain along each dimension
    dextent = extent[1::2] - extent[::2]
    ndims = len(dextent)

    # NOTE(review): numpy integer scalars are not ``int`` and take the else
    # branch as a 0-d array, which the boolean masks below are unlikely to
    # handle; pass a Python int or a sequence.
    if isinstance(levels, int):
        minlvl = maxlvl = levels
        levels = np.array([levels] * ndims, dtype="int32")
    else:
        levels = np.asarray(levels, dtype="int32")
        minlvl = levels.min()
        maxlvl = levels.max()
        # allowed: all equal, or exactly two distinct values {0, maxlvl}
        if minlvl != maxlvl and (minlvl != 0 or {minlvl, maxlvl} != set(levels)):
            raise ValueError("all levels must have the same value or zero.")
    dims_zero = levels == 0
    dims_nonzero = ~dims_zero
    ndims_nonzero = dims_nonzero.sum()

    # every level contributes cells**ndims_nonzero - 1 cells (its lower-left
    # cell is skipped), plus the single finest-level cell in row 0
    npoints = (cells ** ndims_nonzero - 1) * maxlvl + 1
    left = np.empty((npoints, ndims), dtype="float64")
    right = np.empty((npoints, ndims), dtype="float64")
    level = np.empty(npoints, dtype="int32")

    # fill zero dims
    left[:, dims_zero] = extent[::2][dims_zero]
    right[:, dims_zero] = extent[1::2][dims_zero]

    # fill non-zero dims
    dcell = 1.0 / cells
    # Unit-normalised cell corners for one level: refined dims step through
    # [0, 1) by dcell; unrefined dims are meant to yield a single value (their
    # lower bound here, their upper bound in right_slice).
    left_slice = tuple(
        slice(extent[2 * n], extent[2 * n + 1], extent[2 * n + 1])
        if dims_zero[n]
        else slice(0.0, 1.0, dcell)
        for n in range(ndims)
    )
    right_slice = tuple(
        slice(extent[2 * n + 1], extent[2 * n], -extent[2 * n + 1])
        if dims_zero[n]
        else slice(dcell, 1.0 + dcell, dcell)
        for n in range(ndims)
    )
    # .T makes the first axis vary fastest when flattened into points;
    # [ndims:] drops the first point, i.e. the lower-left cell of the level.
    left_norm_grid = np.reshape(np.mgrid[left_slice].T.flat[ndims:], (-1, ndims))
    lng_zero = left_norm_grid[:, dims_zero]
    lng_nonzero = left_norm_grid[:, dims_nonzero]

    right_norm_grid = np.reshape(np.mgrid[right_slice].T.flat[ndims:], (-1, ndims))
    rng_zero = right_norm_grid[:, dims_zero]
    rng_nonzero = right_norm_grid[:, dims_nonzero]

    # row 0: the finest-level cell sitting in the domain's lower corner
    level[0] = maxlvl
    left[0, :] = extent[::2]
    right[0, dims_zero] = extent[1::2][dims_zero]
    right[0, dims_nonzero] = (dcell ** maxlvl) * dextent[dims_nonzero] + extent[::2][
        dims_nonzero
    ]
    # finest to coarsest; dsize is the physical size of the region that
    # level ``lvl`` subdivides into cells of size dcell * dsize
    for i, lvl in enumerate(range(maxlvl, 0, -1)):
        start = (cells ** ndims_nonzero - 1) * i + 1
        stop = (cells ** ndims_nonzero - 1) * (i + 1) + 1
        dsize = dcell ** (lvl - 1) * dextent[dims_nonzero]
        level[start:stop] = lvl
        left[start:stop, dims_zero] = lng_zero
        left[start:stop, dims_nonzero] = lng_nonzero * dsize + extent[::2][dims_nonzero]
        right[start:stop, dims_zero] = rng_zero
        right[start:stop, dims_nonzero] = (
            rng_nonzero * dsize + extent[::2][dims_nonzero]
        )

    return left, right, level
187def _check_field_unit_args_helper(args: dict, default_args: dict):
188    values = list(args.values())
189    keys = list(args.keys())
190    if all(v is None for v in values):
191        for key in keys:
192            args[key] = default_args[key]
193    elif None in values:
194        raise ValueError(
195            "Error in creating a fake dataset:"
196            f" either all or none of the following arguments need to specified: {keys}."
197        )
198    elif any(len(v) != len(values[0]) for v in values):
199        raise ValueError(
200            "Error in creating a fake dataset:"
201            f" all the following arguments must have the same length: {keys}."
202        )
203    return list(args.values())
206_fake_random_ds_default_fields = ("density", "velocity_x", "velocity_y", "velocity_z")
207_fake_random_ds_default_units = ("g/cm**3", "cm/s", "cm/s", "cm/s")
208_fake_random_ds_default_negative = (False, False, False, False)
def fake_random_ds(
    ndims,
    peak_value=1.0,
    fields=None,
    units=None,
    particle_fields=None,
    particle_field_units=None,
    negative=False,
    nprocs=1,
    particles=0,
    length_unit=1.0,
    unit_system="cgs",
    bbox=None,
    default_species_fields=None,
):
    """Return an in-memory uniform-grid dataset filled with seeded random data.

    Parameters
    ----------
    ndims : int or sequence of 3 ints
        Grid dimensions; a single int is used for all three axes.
    peak_value : float, optional
        Every grid field is ``(uniform[0, 1) - offset) * peak_value``.
    fields, units : sequences, optional
        Grid field names and unit strings.  Either both (and ``negative``) are
        given with equal lengths, or none, in which case module defaults are
        used.
    negative : bool or sequence of bool, optional
        Per-field flag; a true flag uses ``offset = 0.5`` so values span
        ``[-0.5, 0.5) * peak_value``.  A scalar is broadcast over ``fields``.
    particle_fields, particle_field_units : sequences, optional
        Explicit particle fields.  ``"particle_position"`` and
        ``"particle_velocity"`` get ``(particles, 3)`` arrays, all others 1-D.
        Without them, x/y/z positions (code_length), velocities (cm/s, in
        [-0.5, 0.5)) and masses (g) are generated.
    nprocs : int, optional
        Passed to ``load_uniform_grid``.
    particles : int, optional
        Number of particles; 0 means no particle fields.  Cast with ``int``,
        so e.g. ``1e4`` is accepted.
    length_unit, unit_system, bbox, default_species_fields : optional
        Passed to ``load_uniform_grid``.
    """
    from yt.loaders import load_uniform_grid

    prng = RandomState(0x4D3D3D3)
    if not is_sequence(ndims):
        ndims = [ndims, ndims, ndims]
    else:
        assert len(ndims) == 3
    if not is_sequence(negative):
        # Broadcast a scalar flag over the requested fields; without fields
        # it must become None so the defaults are taken as a complete set.
        negative = [negative for _ in fields] if fields else None

    fields, units, negative = _check_field_unit_args_helper(
        {
            "fields": fields,
            "units": units,
            "negative": negative,
        },
        {
            "fields": _fake_random_ds_default_fields,
            "units": _fake_random_ds_default_units,
            "negative": _fake_random_ds_default_negative,
        },
    )

    offsets = [0.5 if n else 0.0 for n in negative]
    data = {}
    for field, offset, u in zip(fields, offsets, units):
        v = (prng.random_sample(ndims) - offset) * peak_value
        if field[0] == "all":
            v = v.ravel()
        data[field] = (v, u)

    if particles:
        # Cast once so both branches accept a float count; random_sample
        # rejects non-integer sizes.
        npart = int(particles)
        if particle_fields is not None:
            for field, unit in zip(particle_fields, particle_field_units):
                if field in ("particle_position", "particle_velocity"):
                    data["io", field] = (prng.random_sample((npart, 3)), unit)
                else:
                    data["io", field] = (prng.random_sample(size=npart), unit)
        else:
            for f in (f"particle_position_{ax}" for ax in "xyz"):
                data["io", f] = (prng.random_sample(size=npart), "code_length")
            for f in (f"particle_velocity_{ax}" for ax in "xyz"):
                data["io", f] = (prng.random_sample(size=npart) - 0.5, "cm/s")
            data["io", "particle_mass"] = (prng.random_sample(npart), "g")

    ug = load_uniform_grid(
        data,
        ndims,
        length_unit=length_unit,
        nprocs=nprocs,
        unit_system=unit_system,
        bbox=bbox,
        default_species_fields=default_species_fields,
    )
    return ug
289_geom_transforms = {
290    # These are the bounds we want.  Cartesian we just assume goes 0 .. 1.
291    "cartesian": ((0.0, 0.0, 0.0), (1.0, 1.0, 1.0)),
292    "spherical": ((0.0, 0.0, 0.0), (1.0, np.pi, 2 * np.pi)),
293    "cylindrical": ((0.0, 0.0, 0.0), (1.0, 1.0, 2.0 * np.pi)),  # rzt
294    "polar": ((0.0, 0.0, 0.0), (1.0, 2.0 * np.pi, 1.0)),  # rtz
295    "geographic": ((-90.0, -180.0, 0.0), (90.0, 180.0, 1000.0)),  # latlonalt
296    "internal_geographic": ((-90.0, -180.0, 0.0), (90.0, 180.0, 1000.0)),  # latlondep
300_fake_amr_ds_default_fields = ("Density",)
301_fake_amr_ds_default_units = ("g/cm**3",)
def fake_amr_ds(
    fields=None, units=None, geometry="cartesian", particles=0, length_unit=None
):
    """Return an in-memory AMR dataset with seeded random grid data.

    Grids come from ``_amr_grid_index`` (``(level, left_edge, right_edge,
    dims)`` entries defined elsewhere in this module; their edges are
    presumably normalised to [0, 1]).  Each grid's edges are rescaled onto
    the bounds ``_geom_transforms[geometry]``.

    Parameters
    ----------
    fields, units : sequences, optional
        Grid field names and unit strings; give both or neither.  Values are
        uniform in [0, 1).
    geometry : str, optional
        A key of ``_geom_transforms``.
    particles : int, optional
        Particles per grid; 0 means none.  Positions are uniform within the
        grid's extent (code_length), velocities in [-0.5, 0.5) cm/s and
        masses in [0, 1) g.
    length_unit : optional
        Passed to ``load_amr_grids``.
    """
    from yt.loaders import load_amr_grids

    fields, units = _check_field_unit_args_helper(
        {
            "fields": fields,
            "units": units,
        },
        {
            "fields": _fake_amr_ds_default_fields,
            "units": _fake_amr_ds_default_units,
        },
    )

    prng = RandomState(0x4D3D3D3)
    LE, RE = _geom_transforms[geometry]
    LE = np.array(LE)
    RE = np.array(RE)
    data = []
    for gspec in _amr_grid_index:
        level, left_edge, right_edge, dims = gspec
        left_edge = left_edge * (RE - LE) + LE
        right_edge = right_edge * (RE - LE) + LE
        gdata = dict(
            level=level, left_edge=left_edge, right_edge=right_edge, dimensions=dims
        )
        for f, u in zip(fields, units):
            gdata[f] = (prng.random_sample(dims), u)
        if particles:
            for i, f in enumerate(f"particle_position_{ax}" for ax in "xyz"):
                pdata = prng.random_sample(particles)
                # map [0, 1) onto [left_edge, right_edge): scale by the grid
                # width (dividing would scatter particles outside the grid)
                pdata *= right_edge[i] - left_edge[i]
                pdata += left_edge[i]
                gdata["io", f] = (pdata, "code_length")
            for f in (f"particle_velocity_{ax}" for ax in "xyz"):
                gdata["io", f] = (prng.random_sample(particles) - 0.5, "cm/s")
            gdata["io", "particle_mass"] = (prng.random_sample(particles), "g")
        data.append(gdata)
    bbox = np.array([LE, RE]).T
    return load_amr_grids(
        data, [32, 32, 32], geometry=geometry, bbox=bbox, length_unit=length_unit
    )
350_fake_particle_ds_default_fields = (
351    "particle_position_x",
352    "particle_position_y",
353    "particle_position_z",
354    "particle_mass",
355    "particle_velocity_x",
356    "particle_velocity_y",
357    "particle_velocity_z",
359_fake_particle_ds_default_units = ("cm", "cm", "cm", "g", "cm/s", "cm/s", "cm/s")
360_fake_particle_ds_default_negative = (False, False, False, False, True, True, True)
def fake_particle_ds(
    fields=None,
    units=None,
    negative=None,
    npart=16 ** 3,
    length_unit=1.0,
    data=None,
):
    """Return an in-memory particle dataset filled with seeded random values.

    Parameters
    ----------
    fields, units, negative : sequences, optional
        Per-field names, unit strings and "allow negative values" flags.
        Either all three are given (with equal lengths) or none, in which
        case the module's default particle fields are used; note that
        ``negative`` must be passed whenever ``fields`` is.  A scalar
        ``negative`` is broadcast over ``fields``.
    npart : int, optional
        Number of particles.
    length_unit : float, optional
        Length unit passed to ``load_particles``.
    data : dict, optional
        Pre-built field data keyed by field name; fields present here are
        used unchanged instead of random values.  The dict is copied, so the
        caller's object is not modified.

    Generated values are uniform in [0, 1), or in [-0.5, 0.5) for fields
    flagged ``negative``.  The domain is the unit cube.

    Raises
    ------
    ValueError
        If only some of ``fields``/``units``/``negative`` are given, or their
        lengths differ.
    """
    from yt.loaders import load_particles

    prng = RandomState(0x4D3D3D3)
    if negative is not None and not is_sequence(negative) and fields is not None:
        negative = [negative for _ in fields]

    fields, units, negative = _check_field_unit_args_helper(
        {
            "fields": fields,
            "units": units,
            "negative": negative,
        },
        {
            "fields": _fake_particle_ds_default_fields,
            "units": _fake_particle_ds_default_units,
            "negative": _fake_particle_ds_default_negative,
        },
    )

    offsets = [0.5 if n else 0.0 for n in negative]
    # work on a copy so generated fields don't leak into the caller's dict
    data = dict(data) if data else {}
    for field, offset, u in zip(fields, offsets, units):
        if field in data:
            # caller-supplied values take precedence and consume no randoms
            continue
        if "position" in field:
            # NOTE(review): this normal-distributed sample used to be stored
            # and then immediately overwritten below, so positions are in
            # fact uniform on [0, 1).  The draw is kept only because it
            # advances ``prng``; removing it would change every value this
            # function generates for downstream tests.
            prng.normal(loc=0.5, scale=0.25, size=npart)
        data[field] = (prng.random_sample(npart) - offset, u)
    bbox = np.array([[0.0, 1.0], [0.0, 1.0], [0.0, 1.0]])
    ds = load_particles(data, length_unit, bbox=bbox)
    return ds
def fake_tetrahedral_ds():
    """Return the sample tetrahedral mesh as an in-memory dataset.

    Node field ``("connect1", "test")`` holds, for every element vertex, the
    squared distance of that vertex from the origin.  Element field
    ``("connect1", "elem")`` holds one seeded uniform random number per
    element.
    """
    from yt.frontends.stream.sample_data.tetrahedral_mesh import (
        _connectivity,
        _coordinates,
    )
    from yt.loaders import load_unstructured_mesh

    rng = RandomState(0x4D3D3D3)

    # squared distance from the origin, gathered per element vertex
    sq_dist = (_coordinates ** 2).sum(axis=1)
    node_data = {("connect1", "test"): sq_dist[_connectivity]}

    # one random value per element
    n_elements = _connectivity.shape[0]
    elem_data = {("connect1", "elem"): rng.rand(n_elements)}

    return load_unstructured_mesh(
        _connectivity, _coordinates, node_data=node_data, elem_data=elem_data
    )
def fake_hexahedral_ds(fields=None):
    """Return the sample hexahedral mesh as an in-memory dataset.

    Node fields hold, for every element vertex, its squared distance from
    the origin: ``("connect1", "test")`` always, plus ``("connect1", name)``
    for each name in ``fields`` (a single name or an iterable of names).
    Element field ``("connect1", "elem")`` holds one seeded uniform random
    number per element.
    """
    from yt.frontends.stream.sample_data.hexahedral_mesh import (
        _connectivity,
        _coordinates,
    )
    from yt.loaders import load_unstructured_mesh

    rng = RandomState(0x4D3D3D3)

    # the sample connectivity is shifted by one everywhere, i.e. treated as
    # 1-based and converted to 0-based indices here
    conn = _connectivity - 1
    sq_dist = (_coordinates ** 2).sum(axis=1)

    # each field gets its own gathered array (fancy indexing copies)
    node_data = {("connect1", "test"): sq_dist[conn]}
    for name in always_iterable(fields):
        node_data["connect1", name] = sq_dist[conn]

    elem_data = {("connect1", "elem"): rng.rand(conn.shape[0])}

    return load_unstructured_mesh(
        conn, _coordinates, node_data=node_data, elem_data=elem_data
    )
def small_fake_hexahedral_ds():
    """Return a one-element hexahedral mesh dataset spanning [-1, 0]**3.

    Node field ``("connect1", "test")`` holds each vertex's squared distance
    from the origin.
    """
    from yt.loaders import load_unstructured_mesh

    vertices = np.array(
        [
            [-1.0, -1.0, -1.0],
            [0.0, -1.0, -1.0],
            [-0.0, 0.0, -1.0],
            [-1.0, -0.0, -1.0],
            [-1.0, -1.0, 0.0],
            [-0.0, -1.0, 0.0],
            [-0.0, 0.0, -0.0],
            [-1.0, 0.0, -0.0],
        ]
    )
    # a single hexahedron, vertices written 1-based and shifted to 0-based
    element = np.array([[1, 2, 3, 4, 5, 6, 7, 8]])
    zero_based = element - 1

    sq_dist = np.sum(vertices ** 2, 1)
    node_data = {("connect1", "test"): sq_dist[zero_based]}

    return load_unstructured_mesh(zero_based, vertices, node_data=node_data)
def fake_vr_orientation_test_ds(N=96, scale=1):
    """
    create a toy dataset that puts a sphere at (0,0,0), a single cube
    on +x, two cubes on +y, and three cubes on +z in a domain from
    [-1*scale,1*scale]**3.  The lower planes
    (x = -1*scale, y = -1*scale, z = -1*scale) are also given non-zero
    values.

    This dataset allows you to easily explore orientations and
    handedness in VR and other renderings.

    Field ``density`` (g/cm**3) is 1e-4 in the background, 0.3 in slabs two
    cells thick on the lower faces, 1.0 in the sphere (radius 0.05*scale) and
    the +x cube, 0.8 in the +y cubes and 0.6 in the +z cubes.

    Parameters
    ----------

    N : integer
       The number of cells along each direction

    scale : float
       A spatial scale, the domain boundaries will be multiplied by scale to
       test datasets that have spatial different scales (e.g. data in CGS
       units).  Every feature, including the sphere, scales with it.

    """
    from yt.loaders import load_uniform_grid

    xmin = ymin = zmin = -1.0 * scale
    xmax = ymax = zmax = 1.0 * scale

    dcoord = (xmax - xmin) / N

    arr = np.zeros((N, N, N), dtype=np.float64)
    arr[:, :, :] = 1.0e-4

    bbox = np.array([[xmin, xmax], [ymin, ymax], [zmin, zmax]])

    # cell-centre coordinates -- in the notation data[i, j, k]
    x = (np.arange(N) + 0.5) * dcoord + xmin
    y = (np.arange(N) + 0.5) * dcoord + ymin
    z = (np.arange(N) + 0.5) * dcoord + zmin

    x3d, y3d, z3d = np.meshgrid(x, y, z, indexing="ij")

    # sphere at the origin.  The radius scales with the domain like every
    # other feature: an unscaled 0.05 vanishes between cell centres for large
    # ``scale`` and can cover the whole box for small ones.
    c = np.array([0.5 * (xmin + xmax), 0.5 * (ymin + ymax), 0.5 * (zmin + zmax)])
    r = np.sqrt((x3d - c[0]) ** 2 + (y3d - c[1]) ** 2 + (z3d - c[2]) ** 2)
    arr[r < 0.05 * scale] = 1.0

    # slabs on the three lower faces of the domain
    arr[abs(x3d - xmin) < 2 * dcoord] = 0.3
    arr[abs(y3d - ymin) < 2 * dcoord] = 0.3
    arr[abs(z3d - zmin) < 2 * dcoord] = 0.3

    # single cube on +x
    xc = 0.75 * scale
    dx = 0.05 * scale
    idx = np.logical_and(
        np.logical_and(x3d > xc - dx, x3d < xc + dx),
        np.logical_and(
            np.logical_and(y3d > -dx, y3d < dx), np.logical_and(z3d > -dx, z3d < dx)
        ),
    )
    arr[idx] = 1.0

    # two cubes on +y
    dy = 0.05 * scale
    for yc in [0.65 * scale, 0.85 * scale]:
        idx = np.logical_and(
            np.logical_and(y3d > yc - dy, y3d < yc + dy),
            np.logical_and(
                np.logical_and(x3d > -dy, x3d < dy), np.logical_and(z3d > -dy, z3d < dy)
            ),
        )
        arr[idx] = 0.8

    # three cubes on +z
    dz = 0.05 * scale
    for zc in [0.5 * scale, 0.7 * scale, 0.9 * scale]:
        idx = np.logical_and(
            np.logical_and(z3d > zc - dz, z3d < zc + dz),
            np.logical_and(
                np.logical_and(x3d > -dz, x3d < dz), np.logical_and(y3d > -dz, y3d < dz)
            ),
        )
        arr[idx] = 0.6

    data = dict(density=(arr, "g/cm**3"))
    ds = load_uniform_grid(data, arr.shape, bbox=bbox)
    return ds
def fake_sph_orientation_ds():
    """Returns an in-memory SPH dataset useful for testing

    Seven particles (positions in cm): one at the origin, one at x = 1, two
    at y = 1, 2 and three at z = 1, 2, 3.  All have mass 1 g, density
    1 g/cm**3, temperature 1 K, zero velocity and a 0.25 cm smoothing length,
    so their smoothing regions don't overlap.  The box is [-4, 4]**3.
    """
    from yt import load_particles

    # one row per particle: (x, y, z)
    positions = np.array(
        [
            [0.0, 0.0, 0.0],
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 2.0, 0.0],
            [0.0, 0.0, 1.0],
            [0.0, 0.0, 2.0],
            [0.0, 0.0, 3.0],
        ]
    )
    npart = len(positions)

    # np.array(...) gives each component its own contiguous array
    data = {
        f"particle_position_{ax}": (np.array(positions[:, col]), "cm")
        for col, ax in enumerate("xyz")
    }
    data["particle_mass"] = (np.ones(npart), "g")
    for ax in "xyz":
        data[f"particle_velocity_{ax}"] = (np.zeros(npart), "cm/s")
    data["smoothing_length"] = (0.25 * np.ones(npart), "cm")
    data["density"] = (np.ones(npart), "g/cm**3")
    data["temperature"] = (np.ones(npart), "K")

    bbox = np.array([[-4, 4]] * 3)

    return load_particles(data=data, length_unit=1.0, bbox=bbox)
def fake_sph_grid_ds(hsml_factor=1.0):
    """Returns an in-memory SPH dataset useful for testing

    27 particles on a regular 3x3x3 lattice with 1 cm spacing, from
    (0.5, 0.5, 0.5) to (2.5, 2.5, 2.5), inside the box [0, 3]**3.  Every
    particle has mass 1 g, density 1 g/cm**3, temperature 1 K, zero velocity
    and a smoothing length of ``0.05 * hsml_factor`` cm (non-overlapping for
    the default factor).
    """
    from yt import load_particles

    # "ij" indexing + C-order ravel: z varies fastest, then y, then x
    centres = np.arange(3) + 0.5
    x, y, z = (
        axis.ravel() for axis in np.meshgrid(centres, centres, centres, indexing="ij")
    )
    npart = x.size

    data = {
        "particle_position_x": (x, "cm"),
        "particle_position_y": (y, "cm"),
        "particle_position_z": (z, "cm"),
        "particle_mass": (np.ones(npart), "g"),
    }
    for ax in "xyz":
        data[f"particle_velocity_{ax}"] = (np.zeros(npart), "cm/s")
    data["smoothing_length"] = (0.05 * np.ones(npart) * hsml_factor, "cm")
    data["density"] = (np.ones(npart), "g/cm**3")
    data["temperature"] = (np.ones(npart), "K")

    bbox = np.array([[0, 3]] * 3)

    return load_particles(data=data, length_unit=1.0, bbox=bbox)
def construct_octree_mask(prng=RandomState(0x1D3D3D3), refined=None):  # noqa B008
    """Build a random depth-first octree refinement mask.

    Adapted from
    http://docs.hyperion-rt.org/en/stable/advanced/indepth_oct.html

    Parameters
    ----------
    prng : numpy.random.RandomState, optional
        Source of randomness.  The default is created once at import time
        and shared between calls, so successive default calls differ.
    refined : None, True, falsy value, or list of bool, optional
        ``None``/``True`` start a new mask with a refined root; a falsy value
        returns ``[False]`` immediately; a list is appended to in place
        (used by the recursion).

    Returns
    -------
    list of bool
        The mask: each refined entry is followed, depth first, by its eight
        children, each refined with probability 0.12.
    """
    if refined in (None, True):
        mask = [True]
    elif not refined:
        return [False]
    else:
        mask = refined

    for _child in range(8):
        # refinement criterion: a plain coin toss with p = 0.12
        is_split = prng.random_sample() < 0.12
        mask.append(is_split)
        if is_split:
            # descend immediately so children follow their parent
            construct_octree_mask(prng, mask)
    return mask
def fake_octree_ds(
    prng=RandomState(0x4D3D3D3),  # noqa B008
    refined=None,
    quantities=None,
    bbox=None,
    sim_time=0.0,
    length_unit=None,
    mass_unit=None,
    time_unit=None,
    velocity_unit=None,
    magnetic_unit=None,
    periodicity=(True, True, True),
    over_refine_factor=1,
    partial_coverage=1,
    unit_system="cgs",
):
    """Return an in-memory octree dataset with a random refinement mask.

    The mask comes from ``construct_octree_mask(prng=prng, refined=refined)``.
    When ``quantities`` is None, four ``("gas", ...)`` fields (density and
    x/y/z velocity) are filled with uniform random values of shape
    ``(n, 1)``.  All remaining arguments are forwarded to ``load_octree``.

    NOTE(review): the default ``prng`` is created once at import time and
    shared, so repeated calls without an explicit ``prng`` yield different
    datasets.
    """
    from yt.loaders import load_octree

    mask = np.asarray(
        construct_octree_mask(prng=prng, refined=refined), dtype=np.uint8
    )
    # NOTE(review): np.invert on uint8 is a *bitwise* NOT (1 -> 254,
    # 0 -> 255), so this is not a count of unrefined entries; it gives
    # 254 per refined and 255 per unrefined entry.  Confirm what load_octree
    # actually needs before changing it -- the size also decides how many
    # random numbers are drawn.
    n_values = np.sum(np.invert(mask))

    if quantities is None:
        # evaluated in order, so the draws match density, vx, vy, vz
        quantities = {
            ("gas", name): prng.random_sample((n_values, 1))
            for name in ("density", "velocity_x", "velocity_y", "velocity_z")
        }

    return load_octree(
        octree_mask=mask,
        data=quantities,
        bbox=bbox,
        sim_time=sim_time,
        length_unit=length_unit,
        mass_unit=mass_unit,
        time_unit=time_unit,
        velocity_unit=velocity_unit,
        magnetic_unit=magnetic_unit,
        periodicity=periodicity,
        partial_coverage=partial_coverage,
        over_refine_factor=over_refine_factor,
        unit_system=unit_system,
    )
def add_noise_fields(ds):
    """Add 4 classes of noise fields to a dataset

    Registers ``("gas", "noise0")`` .. ``("gas", "noise3")`` as cell fields:
    binary {0, 1}, strictly positive, non-positive and mixed-sign noise.
    All four share one seeded generator, so values depend on call order.
    """
    prng = RandomState(0x4D3D3D3)

    def _binary_noise(field, data):
        """0.0 or 1.0 at random"""
        return prng.randint(low=0, high=2, size=data.size).astype("float64")

    def _positive_noise(field, data):
        """uniform in [1e-16, 1 + 1e-16)"""
        return prng.random_sample(data.size) + 1e-16

    def _negative_noise(field, data):
        """uniform in (-1, 0]"""
        return -prng.random_sample(data.size)

    def _even_noise(field, data):
        """uniform in [-1, 1)"""
        return 2 * prng.random_sample(data.size) - 1

    generators = (_binary_noise, _positive_noise, _negative_noise, _even_noise)
    for number, generator in enumerate(generators):
        ds.add_field(("gas", f"noise{number}"), generator, sampling_type="cell")
def expand_keywords(keywords, full=False):
    """
    expand_keywords is a means for testing all possible keyword
    arguments in the nosetests.  Simply pass it a dictionary of all the
    keyword arguments and all of the values for these arguments that you
    want to test.

    It will return a list of kwargs dicts containing combinations of
    the various kwarg values you passed it.  These can then be passed
    to the appropriate function in nosetests.

    If full=True, then every possible combination of keywords is produced,
    otherwise, every keyword option is included at least once in the output
    list.  Be careful, by using full=True, you may be in for an exponentially
    larger number of tests!

    A string value is always treated as a single option, never as a
    sequence of characters.

    Parameters
    ----------

    keywords : dict
        a dictionary where the keys are the keywords for the function,
        and the values of each key are the possible values that this key
        can take in the function

    full : bool
        if set to True, every possible combination of given keywords is
        returned

    Returns
    -------

    array of dicts
        An array of dictionaries to be individually passed to the appropriate
        function matching these kwargs.

    Examples
    --------

    >>> keywords = {}
    >>> keywords["dpi"] = (50, 100, 200)
    >>> keywords["cmap"] = ("arbre", "kelp")
    >>> list_of_kwargs = expand_keywords(keywords)
    >>> print(list_of_kwargs)

    array([{'dpi': 50, 'cmap': 'arbre'},
           {'dpi': 100, 'cmap': 'kelp'},
           {'dpi': 200, 'cmap': 'arbre'}], dtype=object)

    >>> list_of_kwargs = expand_keywords(keywords, full=True)
    >>> print(list_of_kwargs)

    array([{'cmap': 'arbre', 'dpi': 50},
           {'cmap': 'arbre', 'dpi': 100},
           {'cmap': 'arbre', 'dpi': 200},
           {'cmap': 'kelp', 'dpi': 50},
           {'cmap': 'kelp', 'dpi': 100},
           {'cmap': 'kelp', 'dpi': 200}], dtype=object)

    >>> for kwargs in list_of_kwargs:
    ...     write_projection(*args, **kwargs)
    """

    def _options(value):
        # A bare string is one option, not a sequence of characters.
        return (value,) if isinstance(value, str) else value

    # if we want every possible combination of keywords, use iter magic
    if full:
        keys = sorted(keywords)
        list_of_kwarg_dicts = np.array(
            [
                dict(zip(keys, prod))
                for prod in it.product(*(_options(keywords[key]) for key in keys))
            ]
        )

    # if we just want to probe each keyword, but not necessarily every
    # combination
    else:
        options = {key: _options(val) for key, val in keywords.items()}
        # One kwargs dict per option of the longest list.  This stays an int
        # (the old ``max(1.0, ...)`` produced a float that range() rejects).
        num_lists = max((len(opts) for opts in options.values()), default=0)

        # the i'th dict takes each keyword's i'th option, falling back to the
        # keyword's first option once it has run out
        list_of_kwarg_dicts = np.array(
            [
                {
                    key: opts[i] if i < len(opts) else opts[0]
                    for key, opts in options.items()
                }
                for i in range(num_lists)
            ]
        )

    return list_of_kwarg_dicts
def requires_module(module):
    """
    Decorator factory that skips a (nose) test when ``module`` is missing.

    The import of ``module`` is attempted once, when the decorator is built.
    If it succeeds, decorated functions run unchanged; if it raises
    ``ImportError``, decorated functions are replaced by wrappers that raise
    ``nose.SkipTest`` when called.  Tests depending on optional modules are
    therefore skipped instead of failing.
    """
    from nose import SkipTest

    try:
        importlib.import_module(module)
    except ImportError:
        available = False
    else:
        available = True

    def decorator(func):
        if available:

            @functools.wraps(func)
            def true_wrapper(*args, **kwargs):
                return func(*args, **kwargs)

            return true_wrapper

        @functools.wraps(func)
        def false_wrapper(*args, **kwargs):
            raise SkipTest

        return false_wrapper

    return decorator
def requires_module_pytest(*module_names):
    """
    Pytest-compatible replacement for yt.testing.requires_module that accepts
    any number of module names, so decorators don't need stacking.  The test
    is marked ``skipif`` when any of them is unavailable.

    Important: this is meant to decorate test functions only, it won't work as a
    decorator to fixture functions.
    It's meant to be imported as
    >>> from yt.testing import requires_module_pytest as requires_module

    So that it can be later renamed to `requires_module`.
    """
    import pytest

    from yt.utilities import on_demand_imports as odi

    def deco(func):
        # Each name maps to the on-demand importer ``odi._<name>``; a module
        # that failed to import is represented by an ``odi.NotAModule``.
        resolved = dict(
            (name, getattr(odi, f"_{name}")._module) for name in module_names
        )
        missing = [
            name
            for name in resolved
            if isinstance(resolved[name], odi.NotAModule)
        ]
        skip_if_missing = pytest.mark.skipif(
            missing,
            reason=f"missing requirement(s): {', '.join(missing)}",
        )

        # the skip mark must wrap the functools.wraps-decorated function
        @skip_if_missing
        @functools.wraps(func)
        def inner_func(*args, **kwargs):
            return func(*args, **kwargs)

        return inner_func

    return deco
def requires_file(req_file):
    """
    Decorator factory that skips a nose test when ``req_file`` is unavailable.

    ``req_file`` is looked up first as given, then relative to the configured
    ``("yt", "test_data_dir")`` directory. If it is found, the decorated
    function runs unchanged. Otherwise, calling it raises ``nose.SkipTest``.
    It raises ``FileNotFoundError`` instead when
    ``("yt", "internals", "strict_requires")`` is set at call time.
    """
    from nose import SkipTest

    data_dir = ytcfg.get("yt", "test_data_dir")
    found = os.path.exists(req_file) or os.path.exists(
        os.path.join(data_dir, req_file)
    )

    def decorator(func):
        if found:

            @functools.wraps(func)
            def true_wrapper(*args, **kwargs):
                return func(*args, **kwargs)

            return true_wrapper

        @functools.wraps(func)
        def false_wrapper(*args, **kwargs):
            if ytcfg.get("yt", "internals", "strict_requires"):
                raise FileNotFoundError(req_file)
            raise SkipTest

        return false_wrapper

    return decorator
def disable_dataset_cache(func):
    """
    Decorator that runs ``func`` with yt's dataset cache disabled.

    If the ``("yt", "skip_dataset_cache")`` config option is not already set,
    it is set to True for the duration of the call and reset to False
    afterwards, even if ``func`` raises. If it was already set, it is left
    untouched.
    """

    @functools.wraps(func)
    def newfunc(*args, **kwargs):
        # Only undo the change if this wrapper was the one that made it.
        restore_cfg_state = not ytcfg.get("yt", "skip_dataset_cache")
        if restore_cfg_state:
            ytcfg["yt", "skip_dataset_cache"] = True
        try:
            return func(*args, **kwargs)
        finally:
            if restore_cfg_state:
                ytcfg["yt", "skip_dataset_cache"] = False

    return newfunc
def units_override_check(fn):
    """
    Check that reloading ``fn`` with ``units_override`` reproduces its units.

    The dataset is loaded once to record every ``<name>_unit`` attribute that
    is present (length, time, mass, velocity, magnetic, temperature). It is
    then reloaded with those values passed as ``units_override``. The test
    asserts that the override dict is non-empty and that the reloaded unit
    attributes equal the originals.
    """
    unit_names = ("length", "time", "mass", "velocity", "magnetic", "temperature")

    reference = load(fn)
    overrides = {}
    original_units = []
    for name in unit_names:
        attr = getattr(reference, f"{name}_unit", None)
        if attr is None:
            continue
        original_units.append(attr)
        overrides[f"{name}_unit"] = (attr.v, attr.units)
    del reference

    overridden = load(fn, units_override=overrides)
    assert len(overridden.units_override) > 0
    new_units = [
        attr
        for attr in (getattr(overridden, f"{name}_unit", None) for name in unit_names)
        if attr is not None
    ]
    assert_equal(original_units, new_units)
# This is an export of the 40 grids in IsolatedGalaxy that are of level 4 or
# lower. It provides a small sample AMR index for tests to work with; each
# entry appears to be [level, left_edge, right_edge, dimensions].
998_amr_grid_index = [
999    [0, [0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [32, 32, 32]],
1000    [1, [0.25, 0.21875, 0.25], [0.5, 0.5, 0.5], [16, 18, 16]],
1001    [1, [0.5, 0.21875, 0.25], [0.75, 0.5, 0.5], [16, 18, 16]],
1002    [1, [0.21875, 0.5, 0.25], [0.5, 0.75, 0.5], [18, 16, 16]],
1003    [1, [0.5, 0.5, 0.25], [0.75, 0.75, 0.5], [16, 16, 16]],
1004    [1, [0.25, 0.25, 0.5], [0.5, 0.5, 0.75], [16, 16, 16]],
1005    [1, [0.5, 0.25, 0.5], [0.75, 0.5, 0.75], [16, 16, 16]],
1006    [1, [0.25, 0.5, 0.5], [0.5, 0.75, 0.75], [16, 16, 16]],
1007    [1, [0.5, 0.5, 0.5], [0.75, 0.75, 0.75], [16, 16, 16]],
1008    [2, [0.5, 0.5, 0.5], [0.71875, 0.71875, 0.71875], [28, 28, 28]],
1009    [3, [0.5, 0.5, 0.5], [0.6640625, 0.65625, 0.6796875], [42, 40, 46]],
1010    [4, [0.5, 0.5, 0.5], [0.59765625, 0.6015625, 0.6015625], [50, 52, 52]],
1011    [2, [0.28125, 0.5, 0.5], [0.5, 0.734375, 0.71875], [28, 30, 28]],
1012    [3, [0.3359375, 0.5, 0.5], [0.5, 0.671875, 0.6640625], [42, 44, 42]],
1013    [4, [0.40625, 0.5, 0.5], [0.5, 0.59765625, 0.59765625], [48, 50, 50]],
1014    [2, [0.5, 0.28125, 0.5], [0.71875, 0.5, 0.71875], [28, 28, 28]],
1015    [3, [0.5, 0.3359375, 0.5], [0.671875, 0.5, 0.6640625], [44, 42, 42]],
1016    [4, [0.5, 0.40625, 0.5], [0.6015625, 0.5, 0.59765625], [52, 48, 50]],
1017    [2, [0.28125, 0.28125, 0.5], [0.5, 0.5, 0.71875], [28, 28, 28]],
1018    [3, [0.3359375, 0.3359375, 0.5], [0.5, 0.5, 0.671875], [42, 42, 44]],
1019    [
1020        4,
1021        [0.46484375, 0.37890625, 0.50390625],
1022        [0.4765625, 0.390625, 0.515625],
1023        [6, 6, 6],
1024    ],
1025    [4, [0.40625, 0.40625, 0.5], [0.5, 0.5, 0.59765625], [48, 48, 50]],
1026    [2, [0.5, 0.5, 0.28125], [0.71875, 0.71875, 0.5], [28, 28, 28]],
1027    [3, [0.5, 0.5, 0.3359375], [0.6796875, 0.6953125, 0.5], [46, 50, 42]],
1028    [4, [0.5, 0.5, 0.40234375], [0.59375, 0.6015625, 0.5], [48, 52, 50]],
1029    [2, [0.265625, 0.5, 0.28125], [0.5, 0.71875, 0.5], [30, 28, 28]],
1030    [3, [0.3359375, 0.5, 0.328125], [0.5, 0.65625, 0.5], [42, 40, 44]],
1031    [4, [0.40234375, 0.5, 0.40625], [0.5, 0.60546875, 0.5], [50, 54, 48]],
1032    [2, [0.5, 0.265625, 0.28125], [0.71875, 0.5, 0.5], [28, 30, 28]],
1033    [3, [0.5, 0.3203125, 0.328125], [0.6640625, 0.5, 0.5], [42, 46, 44]],
1034    [4, [0.5, 0.3984375, 0.40625], [0.546875, 0.5, 0.5], [24, 52, 48]],
1035    [4, [0.546875, 0.41796875, 0.4453125], [0.5625, 0.4375, 0.5], [8, 10, 28]],
1036    [4, [0.546875, 0.453125, 0.41796875], [0.5546875, 0.48046875, 0.4375], [4, 14, 10]],
1037    [4, [0.546875, 0.4375, 0.4375], [0.609375, 0.5, 0.5], [32, 32, 32]],
1038    [4, [0.546875, 0.4921875, 0.41796875], [0.56640625, 0.5, 0.4375], [10, 4, 10]],
1039    [
1040        4,
1041        [0.546875, 0.48046875, 0.41796875],
1042        [0.5703125, 0.4921875, 0.4375],
1043        [12, 6, 10],
1044    ],
1045    [4, [0.55859375, 0.46875, 0.43359375], [0.5703125, 0.48046875, 0.4375], [6, 6, 2]],
1046    [2, [0.265625, 0.28125, 0.28125], [0.5, 0.5, 0.5], [30, 28, 28]],
1047    [3, [0.328125, 0.3359375, 0.328125], [0.5, 0.5, 0.5], [44, 42, 44]],
1048    [4, [0.4140625, 0.40625, 0.40625], [0.5, 0.5, 0.5], [44, 48, 48]],
def check_results(func):
    r"""This is a decorator for a function to verify that the (numpy ndarray)
    result of a function is what it should be.

    This function is designed to be used for very light answer testing.
    Essentially, it wraps around a larger function that returns a numpy array,
    and that has results that should not change.  It is not necessarily used
    inside the testing scripts themselves, but inside testing scripts written
    by developers during the testing of pull requests and new functionality.

    The result is summarized by its min, max, standard deviation, sum, size
    and an MD5 hash of its raw bytes. Results that have a ``convert_to_base``
    method are converted to base units before being summarized. When
    comparing, min, max, std and sum must match the stored reference to a
    relative tolerance of 1e-8, and the size must match exactly. Whether the
    hashes match is printed but not asserted.

    The correct results will be stored if the command line contains
    --answer-reference , and otherwise it will compare against the results on
    disk.  The filename will be func_results_ref_FUNCNAME.cpkl where FUNCNAME
    is the name of the function being tested. If no reference file exists
    when comparing, a message is printed and the wrapped call returns False
    instead of the function's result.

    If you would like more control over the name of the pickle file the results
    are stored in, you can pass the result_basename keyword argument to the
    function you are testing.  The check_results decorator will use the value
    of the keyword to construct the filename of the results data file.  If
    result_basename is not specified, the name of the testing function is used.

    This will raise an AssertionError if the results are not correct.

    Examples
    --------

    >>> @check_results
    ... def my_func(ds):
    ...     return ds.domain_width

    >>> my_func(ds)

    >>> @check_results
    ... def field_checker(dd, field_name):
    ...     return dd[field_name]

    >>> field_checker(ds.all_data(), "density", result_basename="density")

    """

    def _summarize(rv):
        # Objects with a convert_to_base method are converted first. Its
        # return value is ignored, so this relies on the conversion happening
        # in place on ``rv``.
        if hasattr(rv, "convert_to_base"):
            rv.convert_to_base()
            arr = rv.ndarray_view()
        else:
            arr = rv
        return (
            arr.min(),
            arr.max(),
            arr.std(dtype="float64"),
            arr.sum(dtype="float64"),
            arr.size,
            hashlib.md5(arr.tobytes()).hexdigest(),
        )

    def _reference_file(name):
        # Name of the pickle file holding the reference summary for ``name``.
        return f"func_results_ref_{name}.cpkl"

    def compute_results(func):
        @functools.wraps(func)
        def _func(*args, **kwargs):
            name = kwargs.pop("result_basename", func.__name__)
            rv = func(*args, **kwargs)
            summary = _summarize(rv)
            with open(_reference_file(name), "wb") as f:
                pickle.dump(summary, f)
            return rv

        return _func

    from yt.mods import unparsed_args

    if "--answer-reference" in unparsed_args:
        return compute_results(func)

    def compare_results(func):
        @functools.wraps(func)
        def _func(*args, **kwargs):
            name = kwargs.pop("result_basename", func.__name__)
            rv = func(*args, **kwargs)
            vals = _summarize(rv)
            fn = _reference_file(name)
            if not os.path.exists(fn):
                print("Answers need to be created with --answer-reference .")
                return False
            # The reference file is expected to come from compute_results on
            # this machine. pickle.load must never be used on untrusted files.
            with open(fn, "rb") as f:
                ref = pickle.load(f)
            print(f"Sizes: {vals[4] == ref[4]} ({vals[4]}, {ref[4]})")
            assert_allclose(vals[0], ref[0], 1e-8, err_msg="min")
            assert_allclose(vals[1], ref[1], 1e-8, err_msg="max")
            assert_allclose(vals[2], ref[2], 1e-8, err_msg="std")
            assert_allclose(vals[3], ref[3], 1e-8, err_msg="sum")
            assert_equal(vals[4], ref[4])
            print(f"Hashes equal: {vals[-1] == ref[-1]}")
            return rv

        return _func

    return compare_results(func)
def periodicity_cases(ds):
    """
    Yield sample points for exercising periodic boundary handling.

    The first point is the center of the domain. It is followed by the eight
    (i, j, k) combinations where each index is either 1 or
    ``domain_dimensions[axis] - 2``. Each such point lies that many cell
    widths (``domain_width / domain_dimensions``) from ``domain_left_edge``.
    """
    yield (ds.domain_left_edge + ds.domain_right_edge) / 2.0
    cell = ds.domain_width / ds.domain_dimensions
    near_edges = [(1, ds.domain_dimensions[axis] - 2) for axis in range(3)]
    for index in it.product(*near_edges):
        yield cell * np.array(index) + ds.domain_left_edge
def run_nose(
    verbose=False,
    run_answer_tests=False,
    answer_big_data=False,
    call_pdb=False,
    module=None,
):
    """
    Run yt's nose test suite from the directory containing this file.

    The nose command line is built from a copy of ``sys.argv``. To it are
    added ``--exclude=answer_testing --detailed-errors --exe`` and the
    options selected below. While the tests run, yt's logger is set to
    level 50 (CRITICAL) and the working directory is switched to the package
    directory. Both are restored afterwards, even if nose raises.

    Parameters
    ----------
    verbose : bool
        Pass ``-v`` to nose.
    run_answer_tests : bool
        Pass ``--with-answer-testing`` to nose.
    answer_big_data : bool
        Pass ``--answer-big-data`` to nose.
    call_pdb : bool
        Pass ``--pdb --pdb-failures`` to nose.
    module : str, optional
        Extra argument (e.g. a module to test) appended to the nose argv.

    Raises
    ------
    RuntimeError
        If called from the parent directory of the installed yt package,
        where nose misbehaves.
    """
    import sys

    from yt.utilities.logger import ytLogger as mylog
    from yt.utilities.on_demand_imports import _nose

    # Work on a copy: extending sys.argv in place would leak these flags into
    # the interpreter's argv and accumulate them across repeated calls.
    nose_argv = list(sys.argv)
    nose_argv += ["--exclude=answer_testing", "--detailed-errors", "--exe"]
    if call_pdb:
        nose_argv += ["--pdb", "--pdb-failures"]
    if verbose:
        nose_argv.append("-v")
    if run_answer_tests:
        nose_argv.append("--with-answer-testing")
    if answer_big_data:
        nose_argv.append("--answer-big-data")
    if module:
        nose_argv.append(module)
    initial_dir = os.getcwd()
    yt_file = os.path.abspath(__file__)
    yt_dir = os.path.dirname(yt_file)
    if os.path.samefile(os.path.dirname(yt_dir), initial_dir):
        # Provide a nice error message to work around nose bug
        # see https://github.com/nose-devs/nose/issues/701
        raise RuntimeError(
            """
    The yt.run_nose function does not work correctly when invoked in
    the same directory as the installed yt package. Try starting
    a python session in a different directory before invoking yt.run_nose
    again. Alternatively, you can also run the "nosetests" executable in
    the current directory like so:

        $ nosetests
            """
        )
    # Silence logging only once we are actually going to run, and restore
    # both the log level and the working directory no matter what happens.
    orig_level = mylog.getEffectiveLevel()
    mylog.setLevel(50)
    try:
        os.chdir(yt_dir)
        _nose.run(argv=nose_argv)
    finally:
        os.chdir(initial_dir)
        mylog.setLevel(orig_level)
def assert_allclose_units(actual, desired, rtol=1e-7, atol=0, **kwargs):
    """Raise an error if two objects are not equal up to desired tolerance

    This is a wrapper for :func:`numpy.testing.assert_allclose` that also
    verifies unit consistency

    Parameters
    ----------
    actual : array-like
        Array obtained (possibly with attached units)
    desired : array-like
        Array to compare with (possibly with attached units)
    rtol : float, optional
        Relative tolerance, defaults to 1e-7
    atol : float or quantity, optional
        Absolute tolerance. If units are attached, they must be consistent
        with the units of ``actual`` and ``desired``. If no units are attached,
        assumes the same units as ``desired``. Defaults to zero.

    Raises
    ------
    AssertionError
        If ``actual`` and ``desired`` (or ``atol``) have units with
        inequivalent dimensions, if ``rtol`` is not dimensionless, or if the
        values differ by more than the tolerances.

    Notes
    -----
    Also accepts additional keyword arguments accepted by
    :func:`numpy.testing.assert_allclose`, see the documentation of that
    function for details.

    """
    # Create a copy to ensure this function does not alter input arrays
    act = YTArray(actual)
    des = YTArray(desired)

    try:
        des = des.in_units(act.units)
    except UnitOperationError as e:
        raise AssertionError(
            f"Units of actual ({act.units}) and desired ({des.units}) do not have "
            "equivalent dimensions"
        ) from e

    rt = YTArray(rtol)
    if not rt.units.is_dimensionless:
        raise AssertionError(f"Units of rtol ({rt.units}) are not dimensionless")

    if isinstance(atol, YTArray):
        # atol already carries units; they are validated against actual below.
        # Previously this case left ``at`` unbound (UnboundLocalError).
        at = atol
    else:
        # A bare number is interpreted in the units of desired.
        at = YTQuantity(atol, des.units)

    try:
        at = at.in_units(act.units)
    except UnitOperationError as e:
        raise AssertionError(
            f"Units of atol ({at.units}) and actual ({act.units}) do not have "
            "equivalent dimensions"
        ) from e

    # units have been validated, so we strip units before calling numpy
    # to avoid spurious errors
    act = act.value
    des = des.value
    rt = rt.value
    at = at.value

    return assert_allclose(act, des, rt, at, **kwargs)
def assert_fname(fname):
    """Assert that the image file ``fname`` has contents matching its extension.

    The image type is detected from the file's leading signature bytes
    (checked by hand; no libmagic is involved). PNG, JPEG, PostScript/EPS and
    PDF are recognized. The detected type is compared with the file
    extension, so a PNG must be named ``*.png`` and a JPEG ``*.jpeg``.
    Unrecognized contents are treated as type ``""``.

    Parameters
    ----------
    fname : str or None
        Path to the file to check. ``None`` is accepted and ignored.

    Raises
    ------
    AssertionError
        If the detected type does not match the extension.
    """
    if fname is None:
        return

    with open(fname, "rb") as fimg:
        data = fimg.read()
    image_type = ""

    # see http://www.w3.org/TR/PNG/#5PNG-file-signature
    if data.startswith(b"\211PNG\r\n\032\n"):
        image_type = ".png"
    # see http://www.mathguide.de/info/tools/media-types/image/jpeg
    elif data.startswith(b"\377\330"):
        image_type = ".jpeg"
    elif data.startswith(b"%!PS-Adobe"):
        # EPS files mark themselves with "EPSF" on the first header line.
        # partition() also copes with a file that has no newline at all, where
        # str.index("\n") would raise ValueError. Only that line is decoded.
        first_line = data.partition(b"\n")[0].decode("utf-8", "ignore")
        image_type = ".eps" if "EPSF" in first_line else ".ps"
    elif data.startswith(b"%PDF"):
        image_type = ".pdf"

    extension = os.path.splitext(fname)[1]

    assert image_type == extension, (
        f"Expected an image of type '{extension}' but '{fname}' "
        f"is an image of type '{image_type}'"
    )
def requires_backend(backend):
    """Decorator to check for a specified matplotlib backend.

    This decorator returns the decorated function if the specified `backend`
    is the same as `matplotlib.get_backend()` (compared case-insensitively),
    otherwise returns a null function.
    It could be used to execute function only when a particular `backend` of
    matplotlib is being used.

    When the backend does not match and yt's
    ``("yt", "internals", "within_pytest")`` config option is set, the
    replacement calls ``pytest.skip`` so the test is reported as skipped.
    Otherwise the replacement accepts any arguments and returns None.

    Parameters
    ----------
    backend : String
        The value which is compared with the current matplotlib backend in use.

    Returns
    -------
    Decorated function or null function

    """

    def ffalse(func):
        # returning a lambda : None causes an error when using pytest. Having
        # a function (skip) that returns None does work, but pytest marks the
        # test as having passed, which seems bad, since it wasn't actually run.
        # Using pytest.skip() means that a change to test_requires_backend was
        # needed since None is no longer returned, so we check for the skip
        # exception in the xfail case for that test
        def skip(*args, **kwargs):
            # pytest is only needed when this path is actually taken
            import pytest

            msg = f"`{backend}` backend not found, skipping: `{func.__name__}`"
            print(msg)
            pytest.skip(msg)

        def null(*args, **kwargs):
            # Accept (and ignore) any arguments so that calling the replaced
            # function with its original signature does not raise TypeError.
            return None

        if ytcfg.get("yt", "internals", "within_pytest"):
            return skip
        return null

    def ftrue(func):
        return func

    if backend.lower() == matplotlib.get_backend().lower():
        return ftrue
    return ffalse
class TempDirTest(unittest.TestCase):
    """
    Base test case whose tests each run inside a fresh temporary directory.

    ``setUp`` records the current working directory in ``self.curdir``,
    creates ``self.tmpdir`` and changes into it. ``tearDown`` changes back to
    ``self.curdir`` and deletes ``self.tmpdir`` together with its contents.
    """

    def setUp(self):
        self.curdir = os.getcwd()
        self.tmpdir = scratch = tempfile.mkdtemp()
        os.chdir(scratch)

    def tearDown(self):
        home, scratch = self.curdir, self.tmpdir
        os.chdir(home)
        shutil.rmtree(scratch)
class ParticleSelectionComparison:
    """
    This is a test helper class that takes a particle dataset, caches the
    particles it has on disk (manually reading them using lower-level IO
    routines) and then receives data objects that it compares against manually
    running the data object's selection routines.  All supplied data objects
    must be created from the input dataset.
    """

    def __init__(self, ds):
        self.ds = ds
        # Construct an index so that we get all the data_files
        ds.index
        particles = {}
        # hsml is the smoothing length we use for radial selection. It is only
        # collected for particle types listed in ds._sph_ptypes.
        hsml = {}
        for data_file in ds.index.data_files:
            for ptype, pos_arr in ds.index.io._yield_coordinates(data_file):
                particles.setdefault(ptype, []).append(pos_arr)
                if ptype in getattr(ds, "_sph_ptypes", ()):
                    hsml.setdefault(ptype, []).append(
                        ds.index.io._get_smoothing_length(
                            data_file, pos_arr.dtype, pos_arr.shape
                        )
                    )
        # Merge the per-data-file chunks into one array per particle type
        for ptype in particles:
            particles[ptype] = np.concatenate(particles[ptype])
            if ptype in hsml:
                hsml[ptype] = np.concatenate(hsml[ptype])
        self.particles = particles
        self.hsml = hsml

    def compare_dobj_selection(self, dobj):
        """
        Check that ``dobj`` returns exactly the cached particles its selector picks.

        For every particle type (in sorted order), the cached positions are
        passed through ``dobj.selector.select_points``. The cached smoothing
        lengths are used as radii when available, else 0.0. The selected
        positions are compared, to within 5 ULP, against the
        ``particle_position`` field read chunk by chunk from ``dobj``.
        """
        for ptype in sorted(self.particles):
            x, y, z = self.particles[ptype].T
            # Set our radii to zero for now, I guess?
            radii = self.hsml.get(ptype, 0.0)
            sel_index = dobj.selector.select_points(x, y, z, radii)
            if sel_index is None:
                sel_pos = np.empty((0, 3))
            else:
                sel_pos = self.particles[ptype][sel_index, :]

            obj_results = []
            for chunk in dobj.chunks([], "io"):
                obj_results.append(chunk[ptype, "particle_position"])
            if any(_.size > 0 for _ in obj_results):
                obj_results = np.concatenate(obj_results, axis=0)
            else:
                obj_results = np.empty((0, 3))
            # Sometimes we get unitary scaling or other floating point noise. 5
            # NULP should be OK.  This is mostly for stuff like Rockstar, where
            # the f32->f64 casting happens at different places depending on
            # which code path we use.
            assert_array_almost_equal_nulp(sel_pos, obj_results, 5)

    def run_defaults(self):
        """
        This runs lots of samples that touch different types of wraparounds.

        Specifically, it compares the selection of:

            * a sphere in the center with radius 0.1 unitary
            * a sphere in the center with radius 0.2 unitary
            * eight spheres of radius 0.1 unitary centered (relative to the
              domain left edge) at 0.04 or 0.96 unitary along one axis and 0.5
              along the others, plus one at 0.04 and one at 0.96 on every
              axis. For datasets that are not periodic along every axis,
              spheres that would extend outside the domain are skipped.
            * a sphere in the center with radius 0.5 unitary
            * the whole domain (``all_data``)
            * a box that covers 0.1 .. 0.9
            * a box from 0.8 .. 0.85
            * a box from 0.3..0.6, 0.2..0.8, 0.0..0.1

        Box bounds are in unitary units, relative to the domain left edge.
        """
        sp1 = self.ds.sphere("c", (0.1, "unitary"))
        self.compare_dobj_selection(sp1)

        sp2 = self.ds.sphere("c", (0.2, "unitary"))
        self.compare_dobj_selection(sp2)

        centers = [
            [0.04, 0.5, 0.5],
            [0.5, 0.04, 0.5],
            [0.5, 0.5, 0.04],
            [0.04, 0.04, 0.04],
            [0.96, 0.5, 0.5],
            [0.5, 0.96, 0.5],
            [0.5, 0.5, 0.96],
            [0.96, 0.96, 0.96],
        ]
        r = self.ds.quan(0.1, "unitary")
        for center in centers:
            c = self.ds.arr(center, "unitary") + self.ds.domain_left_edge.in_units(
                "unitary"
            )
            if not all(self.ds.periodicity):
                # filter out the periodic bits for non-periodic datasets
                if any(c - r < self.ds.domain_left_edge) or any(
                    c + r > self.ds.domain_right_edge
                ):
                    continue
            sp = self.ds.sphere(c, (0.1, "unitary"))
            self.compare_dobj_selection(sp)

        sp = self.ds.sphere("c", (0.5, "unitary"))
        self.compare_dobj_selection(sp)

        dd = self.ds.all_data()
        self.compare_dobj_selection(dd)

        # This is in raw numbers, so we can offset for the left edge
        LE = self.ds.domain_left_edge.in_units("unitary").d

        # One (lo, hi) pair per axis, in unitary units relative to LE
        boxes = (
            ((0.1, 0.9), (0.1, 0.9), (0.1, 0.9)),
            ((0.8, 0.85), (0.8, 0.85), (0.8, 0.85)),
            ((0.3, 0.6), (0.2, 0.8), (0.0, 0.1)),
        )
        for box in boxes:
            # Equivalent to ds.r[(lo0, "unitary"):(hi0, "unitary"), ...]
            reg = self.ds.r[
                tuple(
                    slice((lo + le, "unitary"), (hi + le, "unitary"))
                    for (lo, hi), le in zip(box, LE)
                )
            ]
            self.compare_dobj_selection(reg)