1import functools
2import hashlib
3import importlib
4import itertools as it
5import os
6import pickle
7import shutil
8import tempfile
9import unittest
10
11import matplotlib
12import numpy as np
13from more_itertools import always_iterable
14from numpy.random import RandomState
15from unyt.exceptions import UnitOperationError
16
17from yt.config import ytcfg
18from yt.funcs import is_sequence
19from yt.loaders import load
20from yt.units.yt_array import YTArray, YTQuantity
21
# We import these in a weird way from numpy.testing to avoid triggering
# flake8 errors from the unused imports. These test functions are imported
# elsewhere in yt from here, so we want them to be imported here.
25from numpy.testing import assert_array_equal, assert_almost_equal  # NOQA isort:skip
26from numpy.testing import assert_equal, assert_array_less  # NOQA isort:skip
27from numpy.testing import assert_string_equal  # NOQA isort:skip
28from numpy.testing import assert_array_almost_equal_nulp  # NOQA isort:skip
29from numpy.testing import assert_allclose, assert_raises  # NOQA isort:skip
30from numpy.testing import assert_approx_equal  # NOQA isort:skip
31from numpy.testing import assert_array_almost_equal  # NOQA isort:skip
32
ANSWER_TEST_TAG = "answer_test"


# Expose assert_true and assert_less_equal from unittest.TestCase.
# This is adopted from nose. Doing this here allows us to avoid importing
# nose at the top level.
class _Dummy(unittest.TestCase):
    def nop(self):
        pass
40
41
42_t = _Dummy("nop")
43
44assert_true = getattr(_t, "assertTrue")  # noqa: B009
45assert_less_equal = getattr(_t, "assertLessEqual")  # noqa: B009
46
47
48def assert_rel_equal(a1, a2, decimals, err_msg="", verbose=True):
49    # We have nan checks in here because occasionally we have fields that get
50    # weighted without non-zero weights.  I'm looking at you, particle fields!
51    if isinstance(a1, np.ndarray):
52        assert a1.size == a2.size
53        # Mask out NaNs
54        assert (np.isnan(a1) == np.isnan(a2)).all()
55        a1[np.isnan(a1)] = 1.0
56        a2[np.isnan(a2)] = 1.0
57        # Mask out 0
58        ind1 = np.array(np.abs(a1) < np.finfo(a1.dtype).eps)
59        ind2 = np.array(np.abs(a2) < np.finfo(a2.dtype).eps)
60        assert (ind1 == ind2).all()
61        a1[ind1] = 1.0
62        a2[ind2] = 1.0
63    elif np.any(np.isnan(a1)) and np.any(np.isnan(a2)):
64        return True
65    if not isinstance(a1, np.ndarray) and a1 == a2 == 0.0:
66        # NANS!
67        a1 = a2 = 1.0
68    return assert_almost_equal(
69        np.array(a1) / np.array(a2), 1.0, decimals, err_msg=err_msg, verbose=verbose
70    )
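
# Illustrative usage of assert_rel_equal (editorial sketch, not part of the
# original module): NaNs must occur at the same indices in both arrays, and the
# remaining entries must agree to the requested number of decimals in a
# relative sense.
#
#     >>> a = np.array([1.0, 2.0, np.nan])
#     >>> b = np.array([1.0, 2.0000001, np.nan])
#     >>> assert_rel_equal(a, b, 6)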
71
72
73def amrspace(extent, levels=7, cells=8):
74    """Creates two numpy arrays representing the left and right bounds of
75    an AMR grid as well as an array for the AMR level of each cell.
76
77    Parameters
78    ----------
79    extent : array-like
        This is a sequence of length 2*ndims giving the bounds of each dimension.
81        For example, the 2D unit square would be given by [0.0, 1.0, 0.0, 1.0].
82        A 3D cylindrical grid may look like [0.0, 2.0, -1.0, 1.0, 0.0, 2*np.pi].
83    levels : int or sequence of ints, optional
84        This is the number of AMR refinement levels.  If given as a sequence (of
85        length ndims), then each dimension will be refined down to this level.
        All values in this array must be the same or zero.  A zero-valued dimension
        indicates that this dimension should not be refined.  Taking the 3D
        cylindrical example above, if we don't want to refine theta but want r and
        z refined to level 5, we would set levels=(5, 5, 0).
90    cells : int, optional
91        This is the number of cells per refinement level.
92
93    Returns
94    -------
95    left : float ndarray, shape=(npoints, ndims)
96        The left AMR grid points.
97    right : float ndarray, shape=(npoints, ndims)
98        The right AMR grid points.
99    level : int ndarray, shape=(npoints,)
100        The AMR level for each point.
101
102    Examples
103    --------
104    >>> l, r, lvl = amrspace([0.0, 2.0, 1.0, 2.0, 0.0, 3.14], levels=(3, 3, 0), cells=2)
105    >>> print(l)
106    [[ 0.     1.     0.   ]
107     [ 0.25   1.     0.   ]
108     [ 0.     1.125  0.   ]
109     [ 0.25   1.125  0.   ]
110     [ 0.5    1.     0.   ]
111     [ 0.     1.25   0.   ]
112     [ 0.5    1.25   0.   ]
113     [ 1.     1.     0.   ]
114     [ 0.     1.5    0.   ]
115     [ 1.     1.5    0.   ]]
116
117    """
118    extent = np.asarray(extent, dtype="f8")
119    dextent = extent[1::2] - extent[::2]
120    ndims = len(dextent)
121
122    if isinstance(levels, int):
123        minlvl = maxlvl = levels
124        levels = np.array([levels] * ndims, dtype="int32")
125    else:
126        levels = np.asarray(levels, dtype="int32")
127        minlvl = levels.min()
128        maxlvl = levels.max()
129        if minlvl != maxlvl and (minlvl != 0 or {minlvl, maxlvl} != set(levels)):
130            raise ValueError("all levels must have the same value or zero.")
131    dims_zero = levels == 0
132    dims_nonzero = ~dims_zero
133    ndims_nonzero = dims_nonzero.sum()
134
135    npoints = (cells ** ndims_nonzero - 1) * maxlvl + 1
136    left = np.empty((npoints, ndims), dtype="float64")
137    right = np.empty((npoints, ndims), dtype="float64")
138    level = np.empty(npoints, dtype="int32")
139
140    # fill zero dims
141    left[:, dims_zero] = extent[::2][dims_zero]
142    right[:, dims_zero] = extent[1::2][dims_zero]
143
144    # fill non-zero dims
145    dcell = 1.0 / cells
146    left_slice = tuple(
147        slice(extent[2 * n], extent[2 * n + 1], extent[2 * n + 1])
148        if dims_zero[n]
149        else slice(0.0, 1.0, dcell)
150        for n in range(ndims)
151    )
152    right_slice = tuple(
153        slice(extent[2 * n + 1], extent[2 * n], -extent[2 * n + 1])
154        if dims_zero[n]
155        else slice(dcell, 1.0 + dcell, dcell)
156        for n in range(ndims)
157    )
158    left_norm_grid = np.reshape(np.mgrid[left_slice].T.flat[ndims:], (-1, ndims))
159    lng_zero = left_norm_grid[:, dims_zero]
160    lng_nonzero = left_norm_grid[:, dims_nonzero]
161
162    right_norm_grid = np.reshape(np.mgrid[right_slice].T.flat[ndims:], (-1, ndims))
163    rng_zero = right_norm_grid[:, dims_zero]
164    rng_nonzero = right_norm_grid[:, dims_nonzero]
165
166    level[0] = maxlvl
167    left[0, :] = extent[::2]
168    right[0, dims_zero] = extent[1::2][dims_zero]
169    right[0, dims_nonzero] = (dcell ** maxlvl) * dextent[dims_nonzero] + extent[::2][
170        dims_nonzero
171    ]
172    for i, lvl in enumerate(range(maxlvl, 0, -1)):
173        start = (cells ** ndims_nonzero - 1) * i + 1
174        stop = (cells ** ndims_nonzero - 1) * (i + 1) + 1
175        dsize = dcell ** (lvl - 1) * dextent[dims_nonzero]
176        level[start:stop] = lvl
177        left[start:stop, dims_zero] = lng_zero
178        left[start:stop, dims_nonzero] = lng_nonzero * dsize + extent[::2][dims_nonzero]
179        right[start:stop, dims_zero] = rng_zero
180        right[start:stop, dims_nonzero] = (
181            rng_nonzero * dsize + extent[::2][dims_nonzero]
182        )
183
184    return left, right, level
185
186
187def _check_field_unit_args_helper(args: dict, default_args: dict):
188    values = list(args.values())
189    keys = list(args.keys())
190    if all(v is None for v in values):
191        for key in keys:
192            args[key] = default_args[key]
193    elif None in values:
194        raise ValueError(
195            "Error in creating a fake dataset:"
196            f" either all or none of the following arguments need to specified: {keys}."
197        )
198    elif any(len(v) != len(values[0]) for v in values):
199        raise ValueError(
200            "Error in creating a fake dataset:"
201            f" all the following arguments must have the same length: {keys}."
202        )
203    return list(args.values())
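
# Illustrative behavior of _check_field_unit_args_helper (editorial sketch): if
# every argument is None the defaults are substituted wholesale; supplying only
# some of them, or sequences of mismatched length, raises a ValueError.
#
#     >>> _check_field_unit_args_helper(
#     ...     {"fields": None, "units": None},
#     ...     {"fields": ("Density",), "units": ("g/cm**3",)},
#     ... )
#     [('Density',), ('g/cm**3',)]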
204
205
206_fake_random_ds_default_fields = ("density", "velocity_x", "velocity_y", "velocity_z")
207_fake_random_ds_default_units = ("g/cm**3", "cm/s", "cm/s", "cm/s")
208_fake_random_ds_default_negative = (False, False, False, False)
209
210
211def fake_random_ds(
212    ndims,
213    peak_value=1.0,
214    fields=None,
215    units=None,
216    particle_fields=None,
217    particle_field_units=None,
218    negative=False,
219    nprocs=1,
220    particles=0,
221    length_unit=1.0,
222    unit_system="cgs",
223    bbox=None,
224    default_species_fields=None,
225):
226    from yt.loaders import load_uniform_grid
227
228    prng = RandomState(0x4D3D3D3)
229    if not is_sequence(ndims):
230        ndims = [ndims, ndims, ndims]
231    else:
232        assert len(ndims) == 3
233    if not is_sequence(negative):
234        if fields:
235            negative = [negative for f in fields]
236        else:
237            negative = None
238
239    fields, units, negative = _check_field_unit_args_helper(
240        {
241            "fields": fields,
242            "units": units,
243            "negative": negative,
244        },
245        {
246            "fields": _fake_random_ds_default_fields,
247            "units": _fake_random_ds_default_units,
248            "negative": _fake_random_ds_default_negative,
249        },
250    )
251
252    offsets = []
253    for n in negative:
254        if n:
255            offsets.append(0.5)
256        else:
257            offsets.append(0.0)
258    data = {}
259    for field, offset, u in zip(fields, offsets, units):
260        v = (prng.random_sample(ndims) - offset) * peak_value
261        if field[0] == "all":
262            v = v.ravel()
263        data[field] = (v, u)
264    if particles:
265        if particle_fields is not None:
266            for field, unit in zip(particle_fields, particle_field_units):
267                if field in ("particle_position", "particle_velocity"):
268                    data["io", field] = (prng.random_sample((int(particles), 3)), unit)
269                else:
270                    data["io", field] = (prng.random_sample(size=int(particles)), unit)
271        else:
272            for f in (f"particle_position_{ax}" for ax in "xyz"):
273                data["io", f] = (prng.random_sample(size=particles), "code_length")
274            for f in (f"particle_velocity_{ax}" for ax in "xyz"):
275                data["io", f] = (prng.random_sample(size=particles) - 0.5, "cm/s")
276            data["io", "particle_mass"] = (prng.random_sample(particles), "g")
277    ug = load_uniform_grid(
278        data,
279        ndims,
280        length_unit=length_unit,
281        nprocs=nprocs,
282        unit_system=unit_system,
283        bbox=bbox,
284        default_species_fields=default_species_fields,
285    )
286    return ug
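
# Illustrative usage of fake_random_ds (editorial sketch): the default field
# set (density plus velocity components) is used when no fields are passed.
#
#     >>> ds = fake_random_ds(16, particles=100)
#     >>> ad = ds.all_data()
#     >>> ad["gas", "density"].shape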
287
288
289_geom_transforms = {
290    # These are the bounds we want.  Cartesian we just assume goes 0 .. 1.
291    "cartesian": ((0.0, 0.0, 0.0), (1.0, 1.0, 1.0)),
292    "spherical": ((0.0, 0.0, 0.0), (1.0, np.pi, 2 * np.pi)),
293    "cylindrical": ((0.0, 0.0, 0.0), (1.0, 1.0, 2.0 * np.pi)),  # rzt
294    "polar": ((0.0, 0.0, 0.0), (1.0, 2.0 * np.pi, 1.0)),  # rtz
295    "geographic": ((-90.0, -180.0, 0.0), (90.0, 180.0, 1000.0)),  # latlonalt
296    "internal_geographic": ((-90.0, -180.0, 0.0), (90.0, 180.0, 1000.0)),  # latlondep
297}
298
299
300_fake_amr_ds_default_fields = ("Density",)
301_fake_amr_ds_default_units = ("g/cm**3",)
302
303
304def fake_amr_ds(
305    fields=None, units=None, geometry="cartesian", particles=0, length_unit=None
306):
307    from yt.loaders import load_amr_grids
308
309    fields, units = _check_field_unit_args_helper(
310        {
311            "fields": fields,
312            "units": units,
313        },
314        {
315            "fields": _fake_amr_ds_default_fields,
316            "units": _fake_amr_ds_default_units,
317        },
318    )
319
320    prng = RandomState(0x4D3D3D3)
321    LE, RE = _geom_transforms[geometry]
322    LE = np.array(LE)
323    RE = np.array(RE)
324    data = []
325    for gspec in _amr_grid_index:
326        level, left_edge, right_edge, dims = gspec
327        left_edge = left_edge * (RE - LE) + LE
328        right_edge = right_edge * (RE - LE) + LE
329        gdata = dict(
330            level=level, left_edge=left_edge, right_edge=right_edge, dimensions=dims
331        )
332        for f, u in zip(fields, units):
333            gdata[f] = (prng.random_sample(dims), u)
334        if particles:
335            for i, f in enumerate(f"particle_position_{ax}" for ax in "xyz"):
336                pdata = prng.random_sample(particles)
337                pdata /= right_edge[i] - left_edge[i]
338                pdata += left_edge[i]
339                gdata["io", f] = (pdata, "code_length")
340            for f in (f"particle_velocity_{ax}" for ax in "xyz"):
341                gdata["io", f] = (prng.random_sample(particles) - 0.5, "cm/s")
342            gdata["io", "particle_mass"] = (prng.random_sample(particles), "g")
343        data.append(gdata)
344    bbox = np.array([LE, RE]).T
345    return load_amr_grids(
346        data, [32, 32, 32], geometry=geometry, bbox=bbox, length_unit=length_unit
347    )
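
# Illustrative usage of fake_amr_ds (editorial sketch): the field names and
# units here are arbitrary examples.
#
#     >>> ds = fake_amr_ds(
#     ...     fields=("Density", "temperature"),
#     ...     units=("g/cm**3", "K"),
#     ...     geometry="spherical",
#     ... )
#     >>> ds.index.num_grids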
348
349
350_fake_particle_ds_default_fields = (
351    "particle_position_x",
352    "particle_position_y",
353    "particle_position_z",
354    "particle_mass",
355    "particle_velocity_x",
356    "particle_velocity_y",
357    "particle_velocity_z",
358)
359_fake_particle_ds_default_units = ("cm", "cm", "cm", "g", "cm/s", "cm/s", "cm/s")
360_fake_particle_ds_default_negative = (False, False, False, False, True, True, True)
361
362
363def fake_particle_ds(
364    fields=None,
365    units=None,
366    negative=None,
367    npart=16 ** 3,
368    length_unit=1.0,
369    data=None,
370):
371    from yt.loaders import load_particles
372
373    prng = RandomState(0x4D3D3D3)
374    if negative is not None and not is_sequence(negative):
375        negative = [negative for f in fields]
376
377    fields, units, negative = _check_field_unit_args_helper(
378        {
379            "fields": fields,
380            "units": units,
381            "negative": negative,
382        },
383        {
384            "fields": _fake_particle_ds_default_fields,
385            "units": _fake_particle_ds_default_units,
386            "negative": _fake_particle_ds_default_negative,
387        },
388    )
389
390    offsets = []
391    for n in negative:
392        if n:
393            offsets.append(0.5)
394        else:
395            offsets.append(0.0)
396    data = data if data else {}
397    for field, offset, u in zip(fields, offsets, units):
        if field in data:
            continue
        if "position" in field:
            # positions are drawn from a normal distribution clipped to the
            # unit box instead of the uniform sample used for other fields
            v = prng.normal(loc=0.5, scale=0.25, size=npart)
            np.clip(v, 0.0, 1.0, v)
        else:
            v = prng.random_sample(npart) - offset
        data[field] = (v, u)
406    bbox = np.array([[0.0, 1.0], [0.0, 1.0], [0.0, 1.0]])
407    ds = load_particles(data, 1.0, bbox=bbox)
408    return ds
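
# Illustrative usage of fake_particle_ds (editorial sketch): with the default
# field list, positions live in the unit box and velocities may be negative.
#
#     >>> ds = fake_particle_ds(npart=1000)
#     >>> ad = ds.all_data()
#     >>> ad["io", "particle_mass"].size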
409
410
411def fake_tetrahedral_ds():
412    from yt.frontends.stream.sample_data.tetrahedral_mesh import (
413        _connectivity,
414        _coordinates,
415    )
416    from yt.loaders import load_unstructured_mesh
417
418    prng = RandomState(0x4D3D3D3)
419
420    # the distance from the origin
421    node_data = {}
422    dist = np.sum(_coordinates ** 2, 1)
423    node_data[("connect1", "test")] = dist[_connectivity]
424
425    # each element gets a random number
426    elem_data = {}
427    elem_data[("connect1", "elem")] = prng.rand(_connectivity.shape[0])
428
429    ds = load_unstructured_mesh(
430        _connectivity, _coordinates, node_data=node_data, elem_data=elem_data
431    )
432    return ds
433
434
435def fake_hexahedral_ds(fields=None):
436    from yt.frontends.stream.sample_data.hexahedral_mesh import (
437        _connectivity,
438        _coordinates,
439    )
440    from yt.loaders import load_unstructured_mesh
441
442    prng = RandomState(0x4D3D3D3)
443    # the distance from the origin
444    node_data = {}
445    dist = np.sum(_coordinates ** 2, 1)
446    node_data[("connect1", "test")] = dist[_connectivity - 1]
447
448    for field in always_iterable(fields):
449        node_data[("connect1", field)] = dist[_connectivity - 1]
450
451    # each element gets a random number
452    elem_data = {}
453    elem_data[("connect1", "elem")] = prng.rand(_connectivity.shape[0])
454
455    ds = load_unstructured_mesh(
456        _connectivity - 1, _coordinates, node_data=node_data, elem_data=elem_data
457    )
458    return ds
459
460
461def small_fake_hexahedral_ds():
462    from yt.loaders import load_unstructured_mesh
463
464    _coordinates = np.array(
465        [
466            [-1.0, -1.0, -1.0],
467            [0.0, -1.0, -1.0],
468            [-0.0, 0.0, -1.0],
469            [-1.0, -0.0, -1.0],
470            [-1.0, -1.0, 0.0],
471            [-0.0, -1.0, 0.0],
472            [-0.0, 0.0, -0.0],
473            [-1.0, 0.0, -0.0],
474        ]
475    )
476    _connectivity = np.array([[1, 2, 3, 4, 5, 6, 7, 8]])
477
478    # the distance from the origin
479    node_data = {}
480    dist = np.sum(_coordinates ** 2, 1)
481    node_data[("connect1", "test")] = dist[_connectivity - 1]
482
483    ds = load_unstructured_mesh(_connectivity - 1, _coordinates, node_data=node_data)
484    return ds
485
486
487def fake_vr_orientation_test_ds(N=96, scale=1):
488    """
    Create a toy dataset that puts a sphere at (0, 0, 0), a single cube
    on +x, two cubes on +y, and three cubes on +z in a domain spanning
    [-1*scale,1*scale]**3.  The lower planes
    (x = -1*scale, y = -1*scale, z = -1*scale) are also given non-zero
    values.

    This dataset allows you to easily explore orientations and
    handedness in VR and other renderings.
497
498    Parameters
499    ----------
500
501    N : integer
502       The number of cells along each direction
503
504    scale : float
       A spatial scale: the domain boundaries will be multiplied by scale to
       test datasets that have different spatial scales (e.g. data in CGS units).
507
508    """
509    from yt.loaders import load_uniform_grid
510
511    xmin = ymin = zmin = -1.0 * scale
512    xmax = ymax = zmax = 1.0 * scale
513
514    dcoord = (xmax - xmin) / N
515
516    arr = np.zeros((N, N, N), dtype=np.float64)
517    arr[:, :, :] = 1.0e-4
518
519    bbox = np.array([[xmin, xmax], [ymin, ymax], [zmin, zmax]])
520
521    # coordinates -- in the notation data[i, j, k]
522    x = (np.arange(N) + 0.5) * dcoord + xmin
523    y = (np.arange(N) + 0.5) * dcoord + ymin
524    z = (np.arange(N) + 0.5) * dcoord + zmin
525
526    x3d, y3d, z3d = np.meshgrid(x, y, z, indexing="ij")
527
528    # sphere at the origin
529    c = np.array([0.5 * (xmin + xmax), 0.5 * (ymin + ymax), 0.5 * (zmin + zmax)])
530    r = np.sqrt((x3d - c[0]) ** 2 + (y3d - c[1]) ** 2 + (z3d - c[2]) ** 2)
531    arr[r < 0.05] = 1.0
532
533    arr[abs(x3d - xmin) < 2 * dcoord] = 0.3
534    arr[abs(y3d - ymin) < 2 * dcoord] = 0.3
535    arr[abs(z3d - zmin) < 2 * dcoord] = 0.3
536
537    # single cube on +x
538    xc = 0.75 * scale
539    dx = 0.05 * scale
540    idx = np.logical_and(
541        np.logical_and(x3d > xc - dx, x3d < xc + dx),
542        np.logical_and(
543            np.logical_and(y3d > -dx, y3d < dx), np.logical_and(z3d > -dx, z3d < dx)
544        ),
545    )
546    arr[idx] = 1.0
547
548    # two cubes on +y
549    dy = 0.05 * scale
550    for yc in [0.65 * scale, 0.85 * scale]:
551        idx = np.logical_and(
552            np.logical_and(y3d > yc - dy, y3d < yc + dy),
553            np.logical_and(
554                np.logical_and(x3d > -dy, x3d < dy), np.logical_and(z3d > -dy, z3d < dy)
555            ),
556        )
557        arr[idx] = 0.8
558
559    # three cubes on +z
560    dz = 0.05 * scale
561    for zc in [0.5 * scale, 0.7 * scale, 0.9 * scale]:
562        idx = np.logical_and(
563            np.logical_and(z3d > zc - dz, z3d < zc + dz),
564            np.logical_and(
565                np.logical_and(x3d > -dz, x3d < dz), np.logical_and(y3d > -dz, y3d < dz)
566            ),
567        )
568        arr[idx] = 0.6
569
570    data = dict(density=(arr, "g/cm**3"))
571    ds = load_uniform_grid(data, arr.shape, bbox=bbox)
572    return ds
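
# Illustrative usage of fake_vr_orientation_test_ds (editorial sketch): the
# asymmetric cube pattern makes it easy to check that +x, +y and +z end up
# where you expect in a rendering.
#
#     >>> ds = fake_vr_orientation_test_ds(N=64, scale=2.0)
#     >>> ds.domain_left_edge, ds.domain_right_edge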
573
574
575def fake_sph_orientation_ds():
576    """Returns an in-memory SPH dataset useful for testing
577
578    This dataset should have one particle at the origin, one more particle
579    along the x axis, two along y, and three along z. All particles will
    have non-overlapping smoothing regions with a radius of 0.25, masses of 1,
    densities of 1, and zero velocity.
582    """
583    from yt import load_particles
584
585    npart = 7
586
587    # one particle at the origin, one particle along x-axis, two along y,
588    # three along z
589    data = {
590        "particle_position_x": (np.array([0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]), "cm"),
591        "particle_position_y": (np.array([0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 0.0]), "cm"),
592        "particle_position_z": (np.array([0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0]), "cm"),
593        "particle_mass": (np.ones(npart), "g"),
594        "particle_velocity_x": (np.zeros(npart), "cm/s"),
595        "particle_velocity_y": (np.zeros(npart), "cm/s"),
596        "particle_velocity_z": (np.zeros(npart), "cm/s"),
597        "smoothing_length": (0.25 * np.ones(npart), "cm"),
598        "density": (np.ones(npart), "g/cm**3"),
599        "temperature": (np.ones(npart), "K"),
600    }
601
602    bbox = np.array([[-4, 4], [-4, 4], [-4, 4]])
603
604    return load_particles(data=data, length_unit=1.0, bbox=bbox)
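
# Illustrative usage of fake_sph_orientation_ds (editorial sketch): the uneven
# particle counts along the axes (1 on +x, 2 on +y, 3 on +z, plus the origin)
# make axis mix-ups easy to spot in projections and slices.
#
#     >>> ds = fake_sph_orientation_ds()
#     >>> ad = ds.all_data()
#     >>> ad["io", "particle_position_x"].size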
605
606
607def fake_sph_grid_ds(hsml_factor=1.0):
608    """Returns an in-memory SPH dataset useful for testing
609
610    This dataset should have 27 particles with the particles arranged uniformly
611    on a 3D grid. The bottom left corner is (0.5,0.5,0.5) and the top right
612    corner is (2.5,2.5,2.5). All particles will have non-overlapping smoothing
    regions with a radius of 0.05, masses of 1, densities of 1, and zero
    velocity.
615    """
616    from yt import load_particles
617
618    npart = 27
619
620    x = np.empty(npart)
621    y = np.empty(npart)
622    z = np.empty(npart)
623
624    tot = 0
625    for i in range(0, 3):
626        for j in range(0, 3):
627            for k in range(0, 3):
628                x[tot] = i + 0.5
629                y[tot] = j + 0.5
630                z[tot] = k + 0.5
631                tot += 1
632
633    data = {
634        "particle_position_x": (x, "cm"),
635        "particle_position_y": (y, "cm"),
636        "particle_position_z": (z, "cm"),
637        "particle_mass": (np.ones(npart), "g"),
638        "particle_velocity_x": (np.zeros(npart), "cm/s"),
639        "particle_velocity_y": (np.zeros(npart), "cm/s"),
640        "particle_velocity_z": (np.zeros(npart), "cm/s"),
641        "smoothing_length": (0.05 * np.ones(npart) * hsml_factor, "cm"),
642        "density": (np.ones(npart), "g/cm**3"),
643        "temperature": (np.ones(npart), "K"),
644    }
645
646    bbox = np.array([[0, 3], [0, 3], [0, 3]])
647
648    return load_particles(data=data, length_unit=1.0, bbox=bbox)
649
650
651def construct_octree_mask(prng=RandomState(0x1D3D3D3), refined=None):  # noqa B008
652    # Implementation taken from url:
653    # http://docs.hyperion-rt.org/en/stable/advanced/indepth_oct.html
654
655    if refined in (None, True):
656        refined = [True]
657    if not refined:
658        refined = [False]
659        return refined
660
661    # Loop over subcells
662    for _ in range(8):
663        # Insert criterion for whether cell should be sub-divided. Here we
664        # just use a random number to demonstrate.
665        divide = prng.random_sample() < 0.12
666
667        # Append boolean to overall list
668        refined.append(divide)
669
670        # If the cell is sub-divided, recursively divide it further
671        if divide:
672            construct_octree_mask(prng, refined)
673    return refined
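
# Illustrative usage of construct_octree_mask (editorial sketch): the returned
# list is a depth-first refinement mask in which True entries are refined into
# eight children that appear later in the list and False entries are leaves.
#
#     >>> mask = construct_octree_mask(prng=RandomState(42))
#     >>> mask[0], len(mask)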
674
675
676def fake_octree_ds(
677    prng=RandomState(0x4D3D3D3),  # noqa B008
678    refined=None,
679    quantities=None,
680    bbox=None,
681    sim_time=0.0,
682    length_unit=None,
683    mass_unit=None,
684    time_unit=None,
685    velocity_unit=None,
686    magnetic_unit=None,
687    periodicity=(True, True, True),
688    over_refine_factor=1,
689    partial_coverage=1,
690    unit_system="cgs",
691):
692    from yt.loaders import load_octree
693
694    octree_mask = np.asarray(
695        construct_octree_mask(prng=prng, refined=refined), dtype=np.uint8
696    )
697    particles = np.sum(np.invert(octree_mask))
698
699    if quantities is None:
700        quantities = {}
701        quantities[("gas", "density")] = prng.random_sample((particles, 1))
702        quantities[("gas", "velocity_x")] = prng.random_sample((particles, 1))
703        quantities[("gas", "velocity_y")] = prng.random_sample((particles, 1))
704        quantities[("gas", "velocity_z")] = prng.random_sample((particles, 1))
705
706    ds = load_octree(
707        octree_mask=octree_mask,
708        data=quantities,
709        bbox=bbox,
710        sim_time=sim_time,
711        length_unit=length_unit,
712        mass_unit=mass_unit,
713        time_unit=time_unit,
714        velocity_unit=velocity_unit,
715        magnetic_unit=magnetic_unit,
716        periodicity=periodicity,
717        partial_coverage=partial_coverage,
718        over_refine_factor=over_refine_factor,
719        unit_system=unit_system,
720    )
721    return ds
722
723
724def add_noise_fields(ds):
725    """Add 4 classes of noise fields to a dataset"""
726    prng = RandomState(0x4D3D3D3)
727
728    def _binary_noise(field, data):
729        """random binary data"""
730        return prng.randint(low=0, high=2, size=data.size).astype("float64")
731
732    def _positive_noise(field, data):
733        """random strictly positive data"""
734        return prng.random_sample(data.size) + 1e-16
735
736    def _negative_noise(field, data):
737        """random negative data"""
738        return -prng.random_sample(data.size)
739
740    def _even_noise(field, data):
741        """random data with mixed signs"""
742        return 2 * prng.random_sample(data.size) - 1
743
744    ds.add_field(("gas", "noise0"), _binary_noise, sampling_type="cell")
745    ds.add_field(("gas", "noise1"), _positive_noise, sampling_type="cell")
746    ds.add_field(("gas", "noise2"), _negative_noise, sampling_type="cell")
747    ds.add_field(("gas", "noise3"), _even_noise, sampling_type="cell")
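
# Illustrative usage of add_noise_fields (editorial sketch): attach the four
# noise fields to any dataset to exercise code paths that are sensitive to
# field sign.
#
#     >>> ds = fake_random_ds(16)
#     >>> add_noise_fields(ds)
#     >>> ds.all_data()["gas", "noise3"]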
748
749
750def expand_keywords(keywords, full=False):
751    """
752    expand_keywords is a means for testing all possible keyword
753    arguments in the nosetests.  Simply pass it a dictionary of all the
754    keyword arguments and all of the values for these arguments that you
755    want to test.
756
757    It will return a list of kwargs dicts containing combinations of
758    the various kwarg values you passed it.  These can then be passed
759    to the appropriate function in nosetests.
760
761    If full=True, then every possible combination of keywords is produced,
762    otherwise, every keyword option is included at least once in the output
763    list.  Be careful, by using full=True, you may be in for an exponentially
764    larger number of tests!
765
766    Parameters
767    ----------
768
769    keywords : dict
770        a dictionary where the keys are the keywords for the function,
771        and the values of each key are the possible values that this key
772        can take in the function
773
774    full : bool
775        if set to True, every possible combination of given keywords is
776        returned
777
778    Returns
779    -------
780
781    array of dicts
782        An array of dictionaries to be individually passed to the appropriate
783        function matching these kwargs.
784
785    Examples
786    --------
787
788    >>> keywords = {}
789    >>> keywords["dpi"] = (50, 100, 200)
790    >>> keywords["cmap"] = ("arbre", "kelp")
791    >>> list_of_kwargs = expand_keywords(keywords)
792    >>> print(list_of_kwargs)
793
794    array([{'cmap': 'arbre', 'dpi': 50},
795           {'cmap': 'kelp', 'dpi': 100},
796           {'cmap': 'arbre', 'dpi': 200}], dtype=object)
797
798    >>> list_of_kwargs = expand_keywords(keywords, full=True)
799    >>> print(list_of_kwargs)
800
801    array([{'cmap': 'arbre', 'dpi': 50},
802           {'cmap': 'arbre', 'dpi': 100},
803           {'cmap': 'arbre', 'dpi': 200},
804           {'cmap': 'kelp', 'dpi': 50},
805           {'cmap': 'kelp', 'dpi': 100},
806           {'cmap': 'kelp', 'dpi': 200}], dtype=object)
807
808    >>> for kwargs in list_of_kwargs:
809    ...     write_projection(*args, **kwargs)
810    """
811
812    # if we want every possible combination of keywords, use iter magic
813    if full:
814        keys = sorted(keywords)
815        list_of_kwarg_dicts = np.array(
816            [
817                dict(zip(keys, prod))
818                for prod in it.product(*(keywords[key] for key in keys))
819            ]
820        )
821
822    # if we just want to probe each keyword, but not necessarily every
823    # combination
824    else:
825        # Determine the maximum number of values any of the keywords has
826        num_lists = 0
827        for val in keywords.values():
828            if isinstance(val, str):
                num_lists = max(1, num_lists)
830            else:
831                num_lists = max(len(val), num_lists)
832
833        # Construct array of kwargs dicts, each element of the list is a different
834        # **kwargs dict.  each kwargs dict gives a different combination of
835        # the possible values of the kwargs
836
837        # initialize array
838        list_of_kwarg_dicts = np.array([dict() for x in range(num_lists)])
839
840        # fill in array
841        for i in np.arange(num_lists):
842            list_of_kwarg_dicts[i] = {}
843            for key in keywords.keys():
844                # if it's a string, use it (there's only one)
845                if isinstance(keywords[key], str):
846                    list_of_kwarg_dicts[i][key] = keywords[key]
847                # if there are more options, use the i'th val
848                elif i < len(keywords[key]):
849                    list_of_kwarg_dicts[i][key] = keywords[key][i]
850                # if there are not more options, use the 0'th val
851                else:
852                    list_of_kwarg_dicts[i][key] = keywords[key][0]
853
854    return list_of_kwarg_dicts
855
856
857def requires_module(module):
858    """
859    Decorator that takes a module name as an argument and tries to import it.
    If the module imports without issue, the decorated function is returned
    unchanged; if not, a wrapper that raises ``SkipTest`` is returned.  This
    way, tests that depend on certain modules being importable are skipped
    rather than failed when those modules are not installed on the testing
    platform.
864    """
865    from nose import SkipTest
866
867    def ffalse(func):
868        @functools.wraps(func)
869        def false_wrapper(*args, **kwargs):
870            raise SkipTest
871
872        return false_wrapper
873
874    def ftrue(func):
875        @functools.wraps(func)
876        def true_wrapper(*args, **kwargs):
877            return func(*args, **kwargs)
878
879        return true_wrapper
880
881    try:
882        importlib.import_module(module)
883    except ImportError:
884        return ffalse
885    else:
886        return ftrue
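
# Illustrative usage of requires_module (editorial sketch): skip a test when an
# optional dependency is missing.  "h5py" is only an example module name.
#
#     >>> @requires_module("h5py")
#     ... def test_something_with_h5py():
#     ...     import h5py
#     ...     ...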
887
888
889def requires_module_pytest(*module_names):
890    """
    This is a replacement for yt.testing.requires_module that's
    compatible with pytest. It accepts an arbitrary number of requirements to
    avoid stacking decorators.

    Important: this is meant to decorate test functions only; it won't work as a
    decorator for fixture functions.
    It's meant to be imported as
    >>> from yt.testing import requires_module_pytest as requires_module

    so that it can later be renamed to `requires_module`.
901    """
902    import pytest
903
904    from yt.utilities import on_demand_imports as odi
905
906    def deco(func):
907        required_modules = {
908            name: getattr(odi, f"_{name}")._module for name in module_names
909        }
910        missing = [
911            name
912            for name, mod in required_modules.items()
913            if isinstance(mod, odi.NotAModule)
914        ]
915
916        # note that order between these two decorators matters
917        @pytest.mark.skipif(
918            missing,
919            reason=f"missing requirement(s): {', '.join(missing)}",
920        )
921        @functools.wraps(func)
922        def inner_func(*args, **kwargs):
923            return func(*args, **kwargs)
924
925        return inner_func
926
927    return deco
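
# Illustrative usage of requires_module_pytest (editorial sketch): several
# requirements can be expressed with a single decorator.  The names are
# examples and must match attributes of yt.utilities.on_demand_imports.
#
#     >>> @requires_module_pytest("astropy", "h5py")
#     ... def test_requiring_two_modules():
#     ...     ...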
928
929
930def requires_file(req_file):
931    from nose import SkipTest
932
933    path = ytcfg.get("yt", "test_data_dir")
934
935    def ffalse(func):
936        @functools.wraps(func)
937        def false_wrapper(*args, **kwargs):
938            if ytcfg.get("yt", "internals", "strict_requires"):
939                raise FileNotFoundError(req_file)
940            raise SkipTest
941
942        return false_wrapper
943
944    def ftrue(func):
945        @functools.wraps(func)
946        def true_wrapper(*args, **kwargs):
947            return func(*args, **kwargs)
948
949        return true_wrapper
950
951    if os.path.exists(req_file):
952        return ftrue
953    else:
954        if os.path.exists(os.path.join(path, req_file)):
955            return ftrue
956        else:
957            return ffalse
958
959
960def disable_dataset_cache(func):
961    @functools.wraps(func)
962    def newfunc(*args, **kwargs):
963        restore_cfg_state = False
        if not ytcfg.get("yt", "skip_dataset_cache"):
            ytcfg["yt", "skip_dataset_cache"] = True
            restore_cfg_state = True
966        rv = func(*args, **kwargs)
967        if restore_cfg_state:
968            ytcfg["yt", "skip_dataset_cache"] = False
969        return rv
970
971    return newfunc
972
973
974@disable_dataset_cache
975def units_override_check(fn):
976    units_list = ["length", "time", "mass", "velocity", "magnetic", "temperature"]
977    ds1 = load(fn)
978    units_override = {}
979    attrs1 = []
980    attrs2 = []
981    for u in units_list:
982        unit_attr = getattr(ds1, f"{u}_unit", None)
983        if unit_attr is not None:
984            attrs1.append(unit_attr)
985            units_override[f"{u}_unit"] = (unit_attr.v, unit_attr.units)
986    del ds1
987    ds2 = load(fn, units_override=units_override)
988    assert len(ds2.units_override) > 0
989    for u in units_list:
990        unit_attr = getattr(ds2, f"{u}_unit", None)
991        if unit_attr is not None:
992            attrs2.append(unit_attr)
993    assert_equal(attrs1, attrs2)
994
995
996# This is an export of the 40 grids in IsolatedGalaxy that are of level 4 or
997# lower.  It's just designed to give a sample AMR index to deal with.
998_amr_grid_index = [
999    [0, [0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [32, 32, 32]],
1000    [1, [0.25, 0.21875, 0.25], [0.5, 0.5, 0.5], [16, 18, 16]],
1001    [1, [0.5, 0.21875, 0.25], [0.75, 0.5, 0.5], [16, 18, 16]],
1002    [1, [0.21875, 0.5, 0.25], [0.5, 0.75, 0.5], [18, 16, 16]],
1003    [1, [0.5, 0.5, 0.25], [0.75, 0.75, 0.5], [16, 16, 16]],
1004    [1, [0.25, 0.25, 0.5], [0.5, 0.5, 0.75], [16, 16, 16]],
1005    [1, [0.5, 0.25, 0.5], [0.75, 0.5, 0.75], [16, 16, 16]],
1006    [1, [0.25, 0.5, 0.5], [0.5, 0.75, 0.75], [16, 16, 16]],
1007    [1, [0.5, 0.5, 0.5], [0.75, 0.75, 0.75], [16, 16, 16]],
1008    [2, [0.5, 0.5, 0.5], [0.71875, 0.71875, 0.71875], [28, 28, 28]],
1009    [3, [0.5, 0.5, 0.5], [0.6640625, 0.65625, 0.6796875], [42, 40, 46]],
1010    [4, [0.5, 0.5, 0.5], [0.59765625, 0.6015625, 0.6015625], [50, 52, 52]],
1011    [2, [0.28125, 0.5, 0.5], [0.5, 0.734375, 0.71875], [28, 30, 28]],
1012    [3, [0.3359375, 0.5, 0.5], [0.5, 0.671875, 0.6640625], [42, 44, 42]],
1013    [4, [0.40625, 0.5, 0.5], [0.5, 0.59765625, 0.59765625], [48, 50, 50]],
1014    [2, [0.5, 0.28125, 0.5], [0.71875, 0.5, 0.71875], [28, 28, 28]],
1015    [3, [0.5, 0.3359375, 0.5], [0.671875, 0.5, 0.6640625], [44, 42, 42]],
1016    [4, [0.5, 0.40625, 0.5], [0.6015625, 0.5, 0.59765625], [52, 48, 50]],
1017    [2, [0.28125, 0.28125, 0.5], [0.5, 0.5, 0.71875], [28, 28, 28]],
1018    [3, [0.3359375, 0.3359375, 0.5], [0.5, 0.5, 0.671875], [42, 42, 44]],
1019    [
1020        4,
1021        [0.46484375, 0.37890625, 0.50390625],
1022        [0.4765625, 0.390625, 0.515625],
1023        [6, 6, 6],
1024    ],
1025    [4, [0.40625, 0.40625, 0.5], [0.5, 0.5, 0.59765625], [48, 48, 50]],
1026    [2, [0.5, 0.5, 0.28125], [0.71875, 0.71875, 0.5], [28, 28, 28]],
1027    [3, [0.5, 0.5, 0.3359375], [0.6796875, 0.6953125, 0.5], [46, 50, 42]],
1028    [4, [0.5, 0.5, 0.40234375], [0.59375, 0.6015625, 0.5], [48, 52, 50]],
1029    [2, [0.265625, 0.5, 0.28125], [0.5, 0.71875, 0.5], [30, 28, 28]],
1030    [3, [0.3359375, 0.5, 0.328125], [0.5, 0.65625, 0.5], [42, 40, 44]],
1031    [4, [0.40234375, 0.5, 0.40625], [0.5, 0.60546875, 0.5], [50, 54, 48]],
1032    [2, [0.5, 0.265625, 0.28125], [0.71875, 0.5, 0.5], [28, 30, 28]],
1033    [3, [0.5, 0.3203125, 0.328125], [0.6640625, 0.5, 0.5], [42, 46, 44]],
1034    [4, [0.5, 0.3984375, 0.40625], [0.546875, 0.5, 0.5], [24, 52, 48]],
1035    [4, [0.546875, 0.41796875, 0.4453125], [0.5625, 0.4375, 0.5], [8, 10, 28]],
1036    [4, [0.546875, 0.453125, 0.41796875], [0.5546875, 0.48046875, 0.4375], [4, 14, 10]],
1037    [4, [0.546875, 0.4375, 0.4375], [0.609375, 0.5, 0.5], [32, 32, 32]],
1038    [4, [0.546875, 0.4921875, 0.41796875], [0.56640625, 0.5, 0.4375], [10, 4, 10]],
1039    [
1040        4,
1041        [0.546875, 0.48046875, 0.41796875],
1042        [0.5703125, 0.4921875, 0.4375],
1043        [12, 6, 10],
1044    ],
1045    [4, [0.55859375, 0.46875, 0.43359375], [0.5703125, 0.48046875, 0.4375], [6, 6, 2]],
1046    [2, [0.265625, 0.28125, 0.28125], [0.5, 0.5, 0.5], [30, 28, 28]],
1047    [3, [0.328125, 0.3359375, 0.328125], [0.5, 0.5, 0.5], [44, 42, 44]],
1048    [4, [0.4140625, 0.40625, 0.40625], [0.5, 0.5, 0.5], [44, 48, 48]],
1049]
1050
1051
1052def check_results(func):
1053    r"""This is a decorator for a function to verify that the (numpy ndarray)
1054    result of a function is what it should be.
1055
1056    This function is designed to be used for very light answer testing.
1057    Essentially, it wraps around a larger function that returns a numpy array,
1058    and that has results that should not change.  It is not necessarily used
1059    inside the testing scripts themselves, but inside testing scripts written
1060    by developers during the testing of pull requests and new functionality.
1061    If a hash is specified, it "wins" and the others are ignored.  Otherwise,
    tolerance is 1e-8 (just above single precision).
1063
1064    The correct results will be stored if the command line contains
    --answer-reference, and otherwise it will compare against the results on
1066    disk.  The filename will be func_results_ref_FUNCNAME.cpkl where FUNCNAME
1067    is the name of the function being tested.
1068
1069    If you would like more control over the name of the pickle file the results
1070    are stored in, you can pass the result_basename keyword argument to the
1071    function you are testing.  The check_results decorator will use the value
1072    of the keyword to construct the filename of the results data file.  If
1073    result_basename is not specified, the name of the testing function is used.
1074
1075    This will raise an exception if the results are not correct.
1076
1077    Examples
1078    --------
1079
1080    >>> @check_results
1081    ... def my_func(ds):
1082    ...     return ds.domain_width
1083
1084    >>> my_func(ds)
1085
1086    >>> @check_results
1087    ... def field_checker(dd, field_name):
1088    ...     return dd[field_name]
1089
1090    >>> field_checker(ds.all_data(), "density", result_basename="density")
1091
1092    """
1093
1094    def compute_results(func):
1095        @functools.wraps(func)
1096        def _func(*args, **kwargs):
1097            name = kwargs.pop("result_basename", func.__name__)
1098            rv = func(*args, **kwargs)
1099            if hasattr(rv, "convert_to_base"):
1100                rv.convert_to_base()
1101                _rv = rv.ndarray_view()
1102            else:
1103                _rv = rv
1104            mi = _rv.min()
1105            ma = _rv.max()
1106            st = _rv.std(dtype="float64")
1107            su = _rv.sum(dtype="float64")
1108            si = _rv.size
1109            ha = hashlib.md5(_rv.tobytes()).hexdigest()
1110            fn = f"func_results_ref_{name}.cpkl"
1111            with open(fn, "wb") as f:
1112                pickle.dump((mi, ma, st, su, si, ha), f)
1113            return rv
1114
1115        return _func
1116
1117    from yt.mods import unparsed_args
1118
1119    if "--answer-reference" in unparsed_args:
1120        return compute_results(func)
1121
1122    def compare_results(func):
1123        @functools.wraps(func)
1124        def _func(*args, **kwargs):
1125            name = kwargs.pop("result_basename", func.__name__)
1126            rv = func(*args, **kwargs)
1127            if hasattr(rv, "convert_to_base"):
1128                rv.convert_to_base()
1129                _rv = rv.ndarray_view()
1130            else:
1131                _rv = rv
1132            vals = (
1133                _rv.min(),
1134                _rv.max(),
1135                _rv.std(dtype="float64"),
1136                _rv.sum(dtype="float64"),
1137                _rv.size,
1138                hashlib.md5(_rv.tobytes()).hexdigest(),
1139            )
1140            fn = f"func_results_ref_{name}.cpkl"
1141            if not os.path.exists(fn):
1142                print("Answers need to be created with --answer-reference .")
1143                return False
1144            with open(fn, "rb") as f:
1145                ref = pickle.load(f)
1146            print(f"Sizes: {vals[4] == ref[4]} ({vals[4]}, {ref[4]})")
1147            assert_allclose(vals[0], ref[0], 1e-8, err_msg="min")
1148            assert_allclose(vals[1], ref[1], 1e-8, err_msg="max")
1149            assert_allclose(vals[2], ref[2], 1e-8, err_msg="std")
1150            assert_allclose(vals[3], ref[3], 1e-8, err_msg="sum")
1151            assert_equal(vals[4], ref[4])
1152            print("Hashes equal: %s" % (vals[-1] == ref[-1]))
1153            return rv
1154
1155        return _func
1156
1157    return compare_results(func)
1158
1159
1160def periodicity_cases(ds):
1161    # This is a generator that yields things near the corners.  It's good for
1162    # getting different places to check periodicity.
1163    yield (ds.domain_left_edge + ds.domain_right_edge) / 2.0
1164    dx = ds.domain_width / ds.domain_dimensions
1165    # We start one dx in, and only go to one in as well.
1166    for i in (1, ds.domain_dimensions[0] - 2):
1167        for j in (1, ds.domain_dimensions[1] - 2):
1168            for k in (1, ds.domain_dimensions[2] - 2):
1169                center = dx * np.array([i, j, k]) + ds.domain_left_edge
1170                yield center
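
# Illustrative usage of periodicity_cases (editorial sketch): build a small
# sphere at each generated center to exercise selection near the domain edges.
#
#     >>> ds = fake_random_ds(16)
#     >>> for center in periodicity_cases(ds):
#     ...     sp = ds.sphere(center, (0.1, "unitary"))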
1171
1172
1173def run_nose(
1174    verbose=False,
1175    run_answer_tests=False,
1176    answer_big_data=False,
1177    call_pdb=False,
1178    module=None,
1179):
1180    import sys
1181
1182    from yt.utilities.logger import ytLogger as mylog
1183    from yt.utilities.on_demand_imports import _nose
1184
1185    orig_level = mylog.getEffectiveLevel()
1186    mylog.setLevel(50)
1187    nose_argv = sys.argv
1188    nose_argv += ["--exclude=answer_testing", "--detailed-errors", "--exe"]
1189    if call_pdb:
1190        nose_argv += ["--pdb", "--pdb-failures"]
1191    if verbose:
1192        nose_argv.append("-v")
1193    if run_answer_tests:
1194        nose_argv.append("--with-answer-testing")
1195    if answer_big_data:
1196        nose_argv.append("--answer-big-data")
1197    if module:
1198        nose_argv.append(module)
1199    initial_dir = os.getcwd()
1200    yt_file = os.path.abspath(__file__)
1201    yt_dir = os.path.dirname(yt_file)
1202    if os.path.samefile(os.path.dirname(yt_dir), initial_dir):
1203        # Provide a nice error message to work around nose bug
1204        # see https://github.com/nose-devs/nose/issues/701
1205        raise RuntimeError(
1206            """
1207    The yt.run_nose function does not work correctly when invoked in
1208    the same directory as the installed yt package. Try starting
1209    a python session in a different directory before invoking yt.run_nose
1210    again. Alternatively, you can also run the "nosetests" executable in
1211    the current directory like so:
1212
1213        $ nosetests
1214            """
1215        )
1216    os.chdir(yt_dir)
1217    try:
1218        _nose.run(argv=nose_argv)
1219    finally:
1220        os.chdir(initial_dir)
1221        mylog.setLevel(orig_level)
1222
1223
1224def assert_allclose_units(actual, desired, rtol=1e-7, atol=0, **kwargs):
1225    """Raise an error if two objects are not equal up to desired tolerance
1226
1227    This is a wrapper for :func:`numpy.testing.assert_allclose` that also
1228    verifies unit consistency
1229
1230    Parameters
1231    ----------
1232    actual : array-like
1233        Array obtained (possibly with attached units)
1234    desired : array-like
1235        Array to compare with (possibly with attached units)
1236    rtol : float, optional
1237        Relative tolerance, defaults to 1e-7
1238    atol : float or quantity, optional
1239        Absolute tolerance. If units are attached, they must be consistent
1240        with the units of ``actual`` and ``desired``. If no units are attached,
1241        assumes the same units as ``desired``. Defaults to zero.
1242
1243    Notes
1244    -----
1245    Also accepts additional keyword arguments accepted by
1246    :func:`numpy.testing.assert_allclose`, see the documentation of that
1247    function for details.
1248
1249    """
1250    # Create a copy to ensure this function does not alter input arrays
1251    act = YTArray(actual)
1252    des = YTArray(desired)
1253
1254    try:
1255        des = des.in_units(act.units)
1256    except UnitOperationError as e:
1257        raise AssertionError(
1258            "Units of actual (%s) and desired (%s) do not have "
1259            "equivalent dimensions" % (act.units, des.units)
1260        ) from e
1261
1262    rt = YTArray(rtol)
1263    if not rt.units.is_dimensionless:
1264        raise AssertionError(f"Units of rtol ({rt.units}) are not dimensionless")
1265
    if not isinstance(atol, YTArray):
        at = YTQuantity(atol, des.units)
    else:
        at = atol
1268
1269    try:
1270        at = at.in_units(act.units)
1271    except UnitOperationError as e:
1272        raise AssertionError(
1273            "Units of atol (%s) and actual (%s) do not have "
1274            "equivalent dimensions" % (at.units, act.units)
1275        ) from e
1276
1277    # units have been validated, so we strip units before calling numpy
1278    # to avoid spurious errors
1279    act = act.value
1280    des = des.value
1281    rt = rt.value
1282    at = at.value
1283
1284    return assert_allclose(act, des, rt, at, **kwargs)
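
# Illustrative usage of assert_allclose_units (editorial sketch): values are
# converted to common units before comparison, so mixing centimeters and meters
# is fine as long as the dimensions agree.
#
#     >>> assert_allclose_units(YTArray([1.0, 2.0], "cm"), YTArray([0.01, 0.02], "m"))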
1285
1286
1287def assert_fname(fname):
1288    """Function that checks file type using libmagic"""
1289    if fname is None:
1290        return
1291
1292    with open(fname, "rb") as fimg:
1293        data = fimg.read()
1294    image_type = ""
1295
1296    # see http://www.w3.org/TR/PNG/#5PNG-file-signature
1297    if data.startswith(b"\211PNG\r\n\032\n"):
1298        image_type = ".png"
1299    # see http://www.mathguide.de/info/tools/media-types/image/jpeg
1300    elif data.startswith(b"\377\330"):
1301        image_type = ".jpeg"
1302    elif data.startswith(b"%!PS-Adobe"):
1303        data_str = data.decode("utf-8", "ignore")
1304        if "EPSF" in data_str[: data_str.index("\n")]:
1305            image_type = ".eps"
1306        else:
1307            image_type = ".ps"
1308    elif data.startswith(b"%PDF"):
1309        image_type = ".pdf"
1310
1311    extension = os.path.splitext(fname)[1]
1312
1313    assert (
1314        image_type == extension
1315    ), "Expected an image of type '{}' but '{}' is an image of type '{}'".format(
1316        extension,
1317        fname,
1318        image_type,
1319    )
1320
1321
1322def requires_backend(backend):
1323    """Decorator to check for a specified matplotlib backend.
1324
    This decorator returns the decorated function if the specified `backend`
    matches the value returned by `matplotlib.get_backend()`; otherwise it
    returns a no-op function.  It can be used to execute a function only when
    a particular matplotlib `backend` is in use.
1329
1330    Parameters
1331    ----------
    backend : str
        The matplotlib backend name to compare against the backend currently in use.
1334
1335    Returns
1336    -------
1337    Decorated function or null function
1338
1339    """
1340    import pytest
1341
1342    def ffalse(func):
1343        # returning a lambda : None causes an error when using pytest. Having
1344        # a function (skip) that returns None does work, but pytest marks the
1345        # test as having passed, which seems bad, since it wasn't actually run.
1346        # Using pytest.skip() means that a change to test_requires_backend was
1347        # needed since None is no longer returned, so we check for the skip
1348        # exception in the xfail case for that test
1349        def skip(*args, **kwargs):
1350            msg = f"`{backend}` backend not found, skipping: `{func.__name__}`"
1351            print(msg)
1352            pytest.skip(msg)
1353
1354        if ytcfg.get("yt", "internals", "within_pytest"):
1355            return skip
1356        else:
1357            return lambda: None
1358
1359    def ftrue(func):
1360        return func
1361
1362    if backend.lower() == matplotlib.get_backend().lower():
1363        return ftrue
1364    return ffalse
1365
1366
1367class TempDirTest(unittest.TestCase):
1368    """
1369    A test class that runs in a temporary directory and
1370    removes it afterward.
1371    """
1372
1373    def setUp(self):
1374        self.curdir = os.getcwd()
1375        self.tmpdir = tempfile.mkdtemp()
1376        os.chdir(self.tmpdir)
1377
1378    def tearDown(self):
1379        os.chdir(self.curdir)
1380        shutil.rmtree(self.tmpdir)
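
# Illustrative usage of TempDirTest (editorial sketch): subclass it so that any
# files written by a test land in a throwaway directory.
#
#     >>> class TestScratchOutput(TempDirTest):
#     ...     def test_save(self):
#     ...         with open("output.txt", "w") as f:
#     ...             f.write("written inside the temporary directory")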
1381
1382
1383class ParticleSelectionComparison:
1384    """
    This is a test helper class that takes a particle dataset, caches the
    particles it has on disk (manually reading them using lower-level IO
    routines), and then receives a data object whose selection it compares
    against manually running the data object's selection routines.  All
    supplied data objects must be created from the input dataset.
1390    """
1391
1392    def __init__(self, ds):
1393        self.ds = ds
1394        # Construct an index so that we get all the data_files
1395        ds.index
1396        particles = {}
1397        # hsml is the smoothing length we use for radial selection
1398        hsml = {}
1399        for data_file in ds.index.data_files:
1400            for ptype, pos_arr in ds.index.io._yield_coordinates(data_file):
1401                particles.setdefault(ptype, []).append(pos_arr)
1402                if ptype in getattr(ds, "_sph_ptypes", ()):
1403                    hsml.setdefault(ptype, []).append(
1404                        ds.index.io._get_smoothing_length(
1405                            data_file, pos_arr.dtype, pos_arr.shape
1406                        )
1407                    )
1408        for ptype in particles:
1409            particles[ptype] = np.concatenate(particles[ptype])
1410            if ptype in hsml:
1411                hsml[ptype] = np.concatenate(hsml[ptype])
1412        self.particles = particles
1413        self.hsml = hsml
1414
1415    def compare_dobj_selection(self, dobj):
1416        for ptype in sorted(self.particles):
1417            x, y, z = self.particles[ptype].T
1418            # Set our radii to zero for now, I guess?
1419            radii = self.hsml.get(ptype, 0.0)
1420            sel_index = dobj.selector.select_points(x, y, z, radii)
1421            if sel_index is None:
1422                sel_pos = np.empty((0, 3))
1423            else:
1424                sel_pos = self.particles[ptype][sel_index, :]
1425
1426            obj_results = []
1427            for chunk in dobj.chunks([], "io"):
1428                obj_results.append(chunk[ptype, "particle_position"])
1429            if any(_.size > 0 for _ in obj_results):
1430                obj_results = np.concatenate(obj_results, axis=0)
1431            else:
1432                obj_results = np.empty((0, 3))
1433            # Sometimes we get unitary scaling or other floating point noise. 5
1434            # NULP should be OK.  This is mostly for stuff like Rockstar, where
1435            # the f32->f64 casting happens at different places depending on
1436            # which code path we use.
1437            assert_array_almost_equal_nulp(sel_pos, obj_results, 5)
1438
1439    def run_defaults(self):
1440        """
1441        This runs lots of samples that touch different types of wraparounds.
1442
1443        Specifically, it does:
1444
1445            * sphere in center with radius 0.1 unitary
1446            * sphere in center with radius 0.2 unitary
1447            * sphere in each of the eight corners of the domain with radius 0.1 unitary
1448            * sphere in center with radius 0.5 unitary
1449            * box that covers 0.1 .. 0.9
1450            * box from 0.8 .. 0.85
1451            * box from 0.3..0.6, 0.2..0.8, 0.0..0.1
1452        """
1453        sp1 = self.ds.sphere("c", (0.1, "unitary"))
1454        self.compare_dobj_selection(sp1)
1455
1456        sp2 = self.ds.sphere("c", (0.2, "unitary"))
1457        self.compare_dobj_selection(sp2)
1458
1459        centers = [
1460            [0.04, 0.5, 0.5],
1461            [0.5, 0.04, 0.5],
1462            [0.5, 0.5, 0.04],
1463            [0.04, 0.04, 0.04],
1464            [0.96, 0.5, 0.5],
1465            [0.5, 0.96, 0.5],
1466            [0.5, 0.5, 0.96],
1467            [0.96, 0.96, 0.96],
1468        ]
1469        r = self.ds.quan(0.1, "unitary")
1470        for center in centers:
1471            c = self.ds.arr(center, "unitary") + self.ds.domain_left_edge.in_units(
1472                "unitary"
1473            )
1474            if not all(self.ds.periodicity):
1475                # filter out the periodic bits for non-periodic datasets
1476                if any(c - r < self.ds.domain_left_edge) or any(
1477                    c + r > self.ds.domain_right_edge
1478                ):
1479                    continue
1480            sp = self.ds.sphere(c, (0.1, "unitary"))
1481            self.compare_dobj_selection(sp)
1482
1483        sp = self.ds.sphere("c", (0.5, "unitary"))
1484        self.compare_dobj_selection(sp)
1485
1486        dd = self.ds.all_data()
1487        self.compare_dobj_selection(dd)
1488
1489        # This is in raw numbers, so we can offset for the left edge
1490        LE = self.ds.domain_left_edge.in_units("unitary").d
1491
1492        reg1 = self.ds.r[
1493            (0.1 + LE[0], "unitary") : (0.9 + LE[0], "unitary"),
1494            (0.1 + LE[1], "unitary") : (0.9 + LE[1], "unitary"),
1495            (0.1 + LE[2], "unitary") : (0.9 + LE[2], "unitary"),
1496        ]
1497        self.compare_dobj_selection(reg1)
1498
1499        reg2 = self.ds.r[
1500            (0.8 + LE[0], "unitary") : (0.85 + LE[0], "unitary"),
1501            (0.8 + LE[1], "unitary") : (0.85 + LE[1], "unitary"),
1502            (0.8 + LE[2], "unitary") : (0.85 + LE[2], "unitary"),
1503        ]
1504        self.compare_dobj_selection(reg2)
1505
1506        reg3 = self.ds.r[
1507            (0.3 + LE[0], "unitary") : (0.6 + LE[0], "unitary"),
1508            (0.2 + LE[1], "unitary") : (0.8 + LE[1], "unitary"),
1509            (0.0 + LE[2], "unitary") : (0.1 + LE[2], "unitary"),
1510        ]
1511        self.compare_dobj_selection(reg3)
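
# Illustrative usage of ParticleSelectionComparison (editorial sketch, assuming
# a particle dataset whose index exposes data_files, as the stream fakes above
# do): compare a dataset's selection machinery against a brute-force
# re-selection of the cached particle positions.
#
#     >>> ds = fake_sph_orientation_ds()
#     >>> psc = ParticleSelectionComparison(ds)
#     >>> psc.run_defaults()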
1512