1"""
2unyt_array class.
3
4
5
6"""
7
8# -----------------------------------------------------------------------------
9# Copyright (c) 2018, yt Development Team.
10#
11# Distributed under the terms of the Modified BSD License.
12#
13# The full license is in the LICENSE file, distributed with this software.
14# -----------------------------------------------------------------------------
15
16import copy
17
18from functools import lru_cache
19from numbers import Number as numeric_type
20import numpy as np
21from numpy import (
22    add,
23    subtract,
24    multiply,
25    divide,
26    logaddexp,
27    logaddexp2,
28    true_divide,
29    floor_divide,
30    negative,
31    power,
32    remainder,
33    mod,
34    absolute,
35    rint,
36    sign,
37    conj,
38    exp,
39    exp2,
40    log,
41    log2,
42    log10,
43    expm1,
44    log1p,
45    sqrt,
46    square,
47    reciprocal,
48    sin,
49    cos,
50    tan,
51    arcsin,
52    arccos,
53    arctan,
54    arctan2,
55    hypot,
56    sinh,
57    cosh,
58    tanh,
59    arcsinh,
60    arccosh,
61    arctanh,
62    deg2rad,
63    rad2deg,
64    bitwise_and,
65    bitwise_or,
66    bitwise_xor,
67    invert,
68    left_shift,
69    right_shift,
70    greater,
71    greater_equal,
72    less,
73    less_equal,
74    not_equal,
75    equal,
76    logical_and,
77    logical_or,
78    logical_xor,
79    logical_not,
80    maximum,
81    minimum,
82    fmax,
83    fmin,
84    isreal,
85    iscomplex,
86    isfinite,
87    isinf,
88    isnan,
89    signbit,
90    copysign,
91    nextafter,
92    modf,
93    ldexp,
94    frexp,
95    fmod,
96    floor,
97    ceil,
98    trunc,
99    fabs,
100    spacing,
101    positive,
102    divmod as divmod_,
103    isnat,
104    heaviside,
105    ones_like,
106    matmul,
107)
108from numpy.core.umath import _ones_like
109
110try:
111    from numpy.core.umath import clip
112except ImportError:
113    clip = None
114from sympy import Rational
115import warnings
116
117from unyt.dimensions import angle, temperature
118from unyt.exceptions import (
119    IterableUnitCoercionError,
120    InvalidUnitEquivalence,
121    InvalidUnitOperation,
122    MKSCGSConversionError,
123    UnitOperationError,
124    UnitConversionError,
125    UnitsNotReducible,
126    SymbolNotFoundError,
127)
128from unyt.equivalencies import equivalence_registry
129from unyt._on_demand_imports import _astropy, _pint
130from unyt._pint_conversions import convert_pint_units
131from unyt._unit_lookup_table import default_unit_symbol_lut
132from unyt.unit_object import _check_em_conversion, _em_conversion, Unit
133from unyt.unit_registry import (
134    _sanitize_unit_system,
135    UnitRegistry,
136    default_unit_registry,
137    _correct_old_unit_registry,
138)
139
# Dimensionless unit. Used as the fallback unit for objects that carry no
# ``units`` attribute and as the result unit of ``np.arctan2``.
NULL_UNIT = Unit()
# Exponent sign associated with multiply (+1) and divide (-1).
# NOTE(review): not referenced in this chunk; presumably used when combining
# unit powers elsewhere in the module — confirm.
POWER_SIGN_MAPPING = {multiply: 1, divide: -1}

# Optional third-party packages that specific doctests need; presumably read
# by the doctest runner to skip those examples when the package is missing.
__doctest_requires__ = {
    ("unyt_array.from_pint", "unyt_array.to_pint"): ["pint"],
    ("unyt_array.from_astropy", "unyt_array.to_astropy"): ["astropy"],
}
147
148
149def _iterable(obj):
150    try:
151        len(obj)
152    except Exception:
153        return False
154    return True
155
156
157@lru_cache(maxsize=128, typed=False)
158def _sqrt_unit(unit):
159    return 1, unit ** 0.5
160
161
162@lru_cache(maxsize=128, typed=False)
163def _multiply_units(unit1, unit2):
164    try:
165        ret = (unit1 * unit2).simplify()
166    except SymbolNotFoundError:
167        # Some operators are not natively commutative when operands are
168        # defined within different unit registries, and conversion
169        # is defined one way but not the other.
170        ret = (unit2 * unit1).simplify()
171    return ret.as_coeff_unit()
172
173
# Message for the FutureWarning that _preserve_units emits when an
# absolute-scale temperature (K or R) is combined with a temperature unit
# that has a nonzero base offset.
TEMPERATURE_WARNING = """
    Ambiguous operation with heterogeneous temperature units.
    In the future, such operations will generate UnitOperationError.
    Use delta_degC or delta_degF to avoid the ambiguity.
"""
179
180
181@lru_cache(maxsize=128, typed=False)
182def _preserve_units(unit1, unit2=None):
183    if unit2 is None or unit1.dimensions is not temperature:
184        return 1, unit1
185    if unit1.base_offset == 0.0 and unit2.base_offset != 0.0:
186        if str(unit1.expr) in ["K", "R"]:
187            warnings.warn(TEMPERATURE_WARNING, FutureWarning, stacklevel=3)
188            return 1, unit1
189        return 1, unit2
190    return 1, unit1
191
192
193@lru_cache(maxsize=128, typed=False)
194def _power_unit(unit, power):
195    return 1, unit ** power
196
197
198@lru_cache(maxsize=128, typed=False)
199def _square_unit(unit):
200    return 1, unit * unit
201
202
203@lru_cache(maxsize=128, typed=False)
204def _divide_units(unit1, unit2):
205    try:
206        ret = (unit1 / unit2).simplify()
207    except SymbolNotFoundError:
208        ret = (1 / (unit2 / unit1).simplify()).units
209    return ret.as_coeff_unit()
210
211
212@lru_cache(maxsize=128, typed=False)
213def _reciprocal_unit(unit):
214    return 1, unit ** -1
215
216
217def _passthrough_unit(unit, unit2=None):
218    return 1, unit
219
220
221def _return_without_unit(unit, unit2=None):
222    return 1, None
223
224
def _arctan2_unit(unit1, unit2):
    """Unit rule for ``np.arctan2``: the result is the dimensionless NULL_UNIT."""
    result = (1, NULL_UNIT)
    return result
227
228
229def _comparison_unit(unit1, unit2=None):
230    return 1, None
231
232
233def _invert_units(unit):
234    raise TypeError("Bit-twiddling operators are not defined for unyt_array instances")
235
236
237def _bitop_units(unit1, unit2):
238    raise TypeError("Bit-twiddling operators are not defined for unyt_array instances")
239
240
241def _coerce_iterable_units(input_object, registry=None):
242    if isinstance(input_object, np.ndarray):
243        return input_object
244    if _iterable(input_object):
245        if any([isinstance(o, unyt_array) for o in input_object]):
246            ff = getattr(input_object[0], "units", NULL_UNIT)
247            if any([ff != getattr(_, "units", NULL_UNIT) for _ in input_object]):
248                raise IterableUnitCoercionError(input_object)
249            # This will create a copy of the data in the iterable.
250            return unyt_array(np.array(input_object), ff, registry=registry)
251    return np.asarray(input_object)
252
253
def _sanitize_units_convert(possible_units, registry):
    """Return ``possible_units`` as a Unit object.

    Unit instances are returned as-is; anything else (e.g. a unit string) is
    handed to ``Unit(..., registry=registry)`` to parse.
    """
    if not isinstance(possible_units, Unit):
        possible_units = Unit(possible_units, registry=registry)
    return possible_units
262
263
# numpy callables that take a single array input. modf and frexp are listed
# here even though they produce two outputs (see multiple_output_operators).
unary_operators = (
    negative,
    absolute,
    rint,
    sign,
    conj,
    exp,
    exp2,
    log,
    log2,
    log10,
    expm1,
    log1p,
    sqrt,
    square,
    reciprocal,
    sin,
    cos,
    tan,
    arcsin,
    arccos,
    arctan,
    sinh,
    cosh,
    tanh,
    arcsinh,
    arccosh,
    arctanh,
    deg2rad,
    rad2deg,
    invert,
    logical_not,
    isreal,
    iscomplex,
    isfinite,
    isinf,
    isnan,
    signbit,
    floor,
    ceil,
    trunc,
    modf,
    frexp,
    fabs,
    spacing,
    positive,
    isnat,
    ones_like,
)

# numpy ufuncs that take two array inputs.
# NOTE(review): floor_divide has an entry in unyt_array._ufunc_registry but is
# not listed here — confirm whether that is intentional.
binary_operators = (
    add,
    subtract,
    multiply,
    divide,
    logaddexp,
    logaddexp2,
    true_divide,
    power,
    remainder,
    mod,
    arctan2,
    hypot,
    bitwise_and,
    bitwise_or,
    bitwise_xor,
    left_shift,
    right_shift,
    greater,
    greater_equal,
    less,
    less_equal,
    not_equal,
    equal,
    logical_and,
    logical_or,
    logical_xor,
    maximum,
    minimum,
    fmax,
    fmin,
    copysign,
    nextafter,
    ldexp,
    fmod,
    divmod_,
    heaviside,
)

trigonometric_operators = (sin, cos, tan)

# ufunc -> number of outputs it returns.
multiple_output_operators = {modf: 2, frexp: 2, divmod_: 2}

# Integer itemsize -> overflow threshold used when integer data is converted to
# a float of the same itemsize. 16777217 == 2**24 + 1 and
# 9007199254740993 == 2**53 + 1, the smallest positive integers that float32
# and float64 cannot represent exactly. The conversion methods warn when any
# |value| exceeds the threshold.
LARGE_INPUT = {4: 16777217, 8: 9007199254740993}
358
359
class unyt_array(np.ndarray):
    """
    An ndarray subclass that attaches a symbolic unit object to the array data.

    Parameters
    ----------

    input_array : iterable
        A tuple, list, or array to attach units to
    units : String unit name, unit symbol object, or astropy unit
        The units of the array. Powers must be specified using python
        syntax (cm**3, not cm^3).
    registry : :class:`unyt.unit_registry.UnitRegistry`
        The registry to create units from. If units is already associated
        with a unit registry and this is specified, this will be used instead
        of the registry associated with the unit object.
    dtype : numpy dtype or dtype name
        The dtype of the array data. Defaults to the dtype numpy infers from
        the input data.
    bypass_validation : boolean
        If True, all input validation is skipped. Using this option may produce
        corrupted, invalid units or array data, but can give significant
        speedups when the input validation logic would add significant
        overhead. If set, input_array must be a numpy array and units *must*
        be a valid unit object. Defaults to False.
    input_units : deprecated
        Deprecated alias for ``units``. Passing it emits a DeprecationWarning,
        and ``units`` takes precedence if both are given.
    name : string
        The name of the array. Defaults to None. This attribute does not propagate
        through mathematical operations, but is preserved under indexing
        and unit conversions.

    Examples
    --------

    >>> from unyt import unyt_array
    >>> a = unyt_array([1, 2, 3], 'cm')
    >>> b = unyt_array([4, 5, 6], 'm')
    >>> a + b
    unyt_array([401., 502., 603.], 'cm')
    >>> b + a
    unyt_array([4.01, 5.02, 6.03], 'm')

    NumPy ufuncs will pass through units where appropriate.

    >>> from unyt import g, cm
    >>> import numpy as np
    >>> a = (np.arange(8) - 4)*g/cm**3
    >>> np.abs(a)
    unyt_array([4, 3, 2, 1, 0, 1, 2, 3], 'g/cm**3')

    and strip them when it would be annoying to deal with them.

    >>> np.log10(np.arange(8)+1)
    array([0.        , 0.30103   , 0.47712125, 0.60205999, 0.69897   ,
           0.77815125, 0.84509804, 0.90308999])

    """

    # Maps each supported numpy ufunc to a unit rule. A rule takes the input
    # unit(s) and returns a (coefficient, output unit) pair. Rules returning
    # None as the unit are used for results that carry no units (log10 in the
    # example above). _bitop_units / _invert_units raise TypeError instead.
    # If the guarded import at the top of the module failed, ``clip`` is None
    # and its entry is keyed on None.
    _ufunc_registry = {
        add: _preserve_units,
        subtract: _preserve_units,
        multiply: _multiply_units,
        divide: _divide_units,
        logaddexp: _return_without_unit,
        logaddexp2: _return_without_unit,
        true_divide: _divide_units,
        floor_divide: _divide_units,
        negative: _passthrough_unit,
        power: _power_unit,
        remainder: _preserve_units,
        mod: _preserve_units,
        fmod: _preserve_units,
        absolute: _passthrough_unit,
        fabs: _passthrough_unit,
        rint: _return_without_unit,
        sign: _return_without_unit,
        conj: _passthrough_unit,
        exp: _return_without_unit,
        exp2: _return_without_unit,
        log: _return_without_unit,
        log2: _return_without_unit,
        log10: _return_without_unit,
        expm1: _return_without_unit,
        log1p: _return_without_unit,
        sqrt: _sqrt_unit,
        square: _square_unit,
        reciprocal: _reciprocal_unit,
        sin: _return_without_unit,
        cos: _return_without_unit,
        tan: _return_without_unit,
        sinh: _return_without_unit,
        cosh: _return_without_unit,
        tanh: _return_without_unit,
        arcsin: _return_without_unit,
        arccos: _return_without_unit,
        arctan: _return_without_unit,
        arctan2: _arctan2_unit,
        arcsinh: _return_without_unit,
        arccosh: _return_without_unit,
        arctanh: _return_without_unit,
        hypot: _preserve_units,
        deg2rad: _return_without_unit,
        rad2deg: _return_without_unit,
        bitwise_and: _bitop_units,
        bitwise_or: _bitop_units,
        bitwise_xor: _bitop_units,
        invert: _invert_units,
        left_shift: _bitop_units,
        right_shift: _bitop_units,
        greater: _comparison_unit,
        greater_equal: _comparison_unit,
        less: _comparison_unit,
        less_equal: _comparison_unit,
        not_equal: _comparison_unit,
        equal: _comparison_unit,
        logical_and: _comparison_unit,
        logical_or: _comparison_unit,
        logical_xor: _comparison_unit,
        logical_not: _return_without_unit,
        maximum: _preserve_units,
        minimum: _preserve_units,
        fmax: _preserve_units,
        fmin: _preserve_units,
        isreal: _return_without_unit,
        iscomplex: _return_without_unit,
        isfinite: _return_without_unit,
        isinf: _return_without_unit,
        isnan: _return_without_unit,
        signbit: _return_without_unit,
        copysign: _passthrough_unit,
        nextafter: _preserve_units,
        modf: _passthrough_unit,
        ldexp: _bitop_units,
        frexp: _return_without_unit,
        floor: _passthrough_unit,
        ceil: _passthrough_unit,
        trunc: _passthrough_unit,
        spacing: _passthrough_unit,
        positive: _passthrough_unit,
        divmod_: _passthrough_unit,
        isnat: _return_without_unit,
        heaviside: _preserve_units,
        _ones_like: _preserve_units,
        matmul: _multiply_units,
        clip: _passthrough_unit,
    }

    # numpy's __array_priority__ hint for mixed-type operations (see the numpy
    # subclassing documentation).
    __array_priority__ = 2.0
506
507    def __new__(
508        cls,
509        input_array,
510        units=None,
511        registry=None,
512        dtype=None,
513        bypass_validation=False,
514        input_units=None,
515        name=None,
516    ):
517        # deprecate input_units in favor of units
518        if input_units is not None:
519            warnings.warn(
520                "input_units has been deprecated, please use units instead",
521                DeprecationWarning,
522                stacklevel=2,
523            )
524        if units is not None:
525            input_units = units
526        if bypass_validation is True:
527            if dtype is None:
528                dtype = input_array.dtype
529            obj = input_array.view(type=cls, dtype=dtype)
530            obj.units = input_units
531            if registry is not None:
532                obj.units.registry = registry
533            obj.name = name
534            return obj
535        if isinstance(input_array, unyt_array):
536            ret = input_array.view(cls)
537            if input_units is None:
538                if registry is None:
539                    ret.units = input_array.units
540                else:
541                    units = Unit(str(input_array.units), registry=registry)
542                    ret.units = units
543            elif isinstance(input_units, Unit):
544                ret.units = input_units
545            else:
546                ret.units = Unit(input_units, registry=registry)
547            ret.name = name
548            return ret
549        elif isinstance(input_array, np.ndarray):
550            pass
551        elif _iterable(input_array) and input_array:
552            if isinstance(input_array[0], unyt_array):
553                return _coerce_iterable_units(input_array, registry)
554
555        # Input array is an already formed ndarray instance
556        # We first cast to be our class type
557
558        obj = np.asarray(input_array, dtype=dtype).view(cls)
559
560        # Check units type
561        if input_units is None:
562            # Nothing provided. Make dimensionless...
563            units = Unit()
564        elif isinstance(input_units, Unit):
565            if registry and registry is not input_units.registry:
566                units = Unit(str(input_units), registry=registry)
567            else:
568                units = input_units
569        else:
570            # units kwarg set, but it's not a Unit object.
571            # don't handle all the cases here, let the Unit class handle if
572            # it's a str.
573            units = Unit(input_units, registry=registry)
574
575        # Attach the units and name
576        obj.units = units
577        obj.name = name
578        return obj
579
580    def __repr__(self):
581        rep = super(unyt_array, self).__repr__()
582        units_repr = self.units.__repr__()
583        if "=" in rep:
584            return rep[:-1] + ", units='" + units_repr + "')"
585        else:
586            return rep[:-1] + ", '" + units_repr + "')"
587
588    def __str__(self):
589        return str(self.view(np.ndarray)) + " " + str(self.units)
590
591    def __format__(self, format_spec):
592        ret = super(unyt_array, self).__format__(format_spec)
593        return ret + " {}".format(self.units)
594
595    #
596    # Start unit conversion methods
597    #
598
    def convert_to_units(self, units, equivalence=None, **kwargs):
        """
        Convert the array to the given units in-place.

        Optionally, an equivalence can be specified to convert to an
        equivalent quantity which is not in the same dimensions.

        Integer arrays are switched in place to a float dtype of the same
        itemsize before scaling. A RuntimeWarning is emitted when values are
        too large for that float type to hold exactly.

        Parameters
        ----------
        units : Unit object or string
            The units you want to convert to.
        equivalence : string, optional
            The equivalence you wish to use. To see which equivalencies
            are supported for this object, try the ``list_equivalencies``
            method. Default: None
        kwargs: optional
            Any additional keyword arguments are supplied to the equivalence

        Raises
        ------
        UnitConversionError
            If the provided unit does not have the same dimensions as the
            array.

        Examples
        --------

        >>> from unyt import cm, km
        >>> length = [3000, 2000, 1000]*cm
        >>> length.convert_to_units('m')
        >>> print(length)
        [30. 20. 10.] m
        """
        # Normalize ``units`` to a Unit object bound to this array's registry.
        units = _sanitize_units_convert(units, self.units.registry)
        if equivalence is None:
            # A truthy entry in conv_data routes the conversion through
            # _em_conversion (presumably electromagnetic units that need
            # special handling — see _check_em_conversion).
            conv_data = _check_em_conversion(
                self.units, units, registry=self.units.registry
            )
            if any(conv_data):
                # NOTE(review): in_units forces offset = 0 after
                # _em_conversion, but this method uses the returned offset —
                # confirm which is intended.
                new_units, (conv_factor, offset) = _em_conversion(
                    self.units, conv_data, units
                )
            else:
                new_units = units
                (conv_factor, offset) = self.units.get_conversion_factor(
                    new_units, self.dtype
                )

            # NOTE(review): units are updated before the data is rescaled, so
            # an exception below leaves the units and data out of sync.
            self.units = new_units
            values = self.d
            # if our dtype is an integer do the following somewhat awkward
            # dance to change the dtype in-place. We can't use astype
            # directly because that will create a copy and not update self
            if self.dtype.kind in ("u", "i"):
                # create a copy of the original data in floating point
                # form, it's possible this may lose precision for very
                # large integers
                dsize = values.dtype.itemsize
                # NOTE(review): for 1-byte integers this builds "f1", which
                # numpy does not define as a dtype, so that case would raise.
                new_dtype = "f" + str(dsize)
                # Threshold above which the same-size float cannot hold every
                # integer exactly (0 = no check for this itemsize).
                large = LARGE_INPUT.get(dsize, 0)
                if large and np.any(np.abs(values) > large):
                    warnings.warn(
                        "Overflow encountered while converting to units '%s'"
                        % new_units,
                        RuntimeWarning,
                        stacklevel=2,
                    )
                float_values = values.astype(new_dtype)
                # change the dtypes in-place, this does not change the
                # underlying memory buffer
                values.dtype = new_dtype
                self.dtype = new_dtype
                # actually fill in the new float values now that our
                # dtype is correct
                np.copyto(values, float_values)
            values *= conv_factor

            # Affine conversions (nonzero offset) subtract the offset after
            # scaling, writing into the same buffer.
            if offset:
                np.subtract(values, offset, values)
        else:
            self.convert_to_equivalent(units, equivalence, **kwargs)
679
680    def convert_to_base(self, unit_system=None, equivalence=None, **kwargs):
681        """
682        Convert the array in-place to the equivalent base units in
683        the specified unit system.
684
685        Optionally, an equivalence can be specified to convert to an
686        equivalent quantity which is not in the same dimensions.
687
688        Parameters
689        ----------
690        unit_system : string, optional
691            The unit system to be used in the conversion. If not specified,
692            the configured base units are used (defaults to MKS).
693        equivalence : string, optional
694            The equivalence you wish to use. To see which equivalencies
695            are supported for this object, try the ``list_equivalencies``
696            method. Default: None
697        kwargs: optional
698            Any additional keyword arguments are supplied to the equivalence
699
700        Raises
701        ------
702        If the provided unit does not have the same dimensions as the array
703        this will raise a UnitConversionError
704
705        Examples
706        --------
707        >>> from unyt import erg, s
708        >>> E = 2.5*erg/s
709        >>> E.convert_to_base("mks")
710        >>> E
711        unyt_quantity(2.5e-07, 'W')
712        """
713        self.convert_to_units(
714            self.units.get_base_equivalent(unit_system),
715            equivalence=equivalence,
716            **kwargs
717        )
718
719    def convert_to_cgs(self, equivalence=None, **kwargs):
720        """
721        Convert the array and in-place to the equivalent cgs units.
722
723        Optionally, an equivalence can be specified to convert to an
724        equivalent quantity which is not in the same dimensions.
725
726        Parameters
727        ----------
728        equivalence : string, optional
729            The equivalence you wish to use. To see which equivalencies
730            are supported for this object, try the ``list_equivalencies``
731            method. Default: None
732        kwargs: optional
733            Any additional keyword arguments are supplied to the equivalence
734
735        Raises
736        ------
737        If the provided unit does not have the same dimensions as the array
738        this will raise a UnitConversionError
739
740        Examples
741        --------
742        >>> from unyt import Newton
743        >>> data = [1., 2., 3.]*Newton
744        >>> data.convert_to_cgs()
745        >>> data
746        unyt_array([100000., 200000., 300000.], 'dyn')
747
748        """
749        self.convert_to_units(
750            self.units.get_cgs_equivalent(), equivalence=equivalence, **kwargs
751        )
752
753    def convert_to_mks(self, equivalence=None, **kwargs):
754        """
755        Convert the array and units to the equivalent mks units.
756
757        Optionally, an equivalence can be specified to convert to an
758        equivalent quantity which is not in the same dimensions.
759
760        Parameters
761        ----------
762        equivalence : string, optional
763            The equivalence you wish to use. To see which equivalencies
764            are supported for this object, try the ``list_equivalencies``
765            method. Default: None
766        kwargs: optional
767            Any additional keyword arguments are supplied to the equivalence
768
769        Raises
770        ------
771        If the provided unit does not have the same dimensions as the array
772        this will raise a UnitConversionError
773
774        Examples
775        --------
776        >>> from unyt import dyne, erg
777        >>> data = [1., 2., 3.]*erg
778        >>> data
779        unyt_array([1., 2., 3.], 'erg')
780        >>> data.convert_to_mks()
781        >>> data
782        unyt_array([1.e-07, 2.e-07, 3.e-07], 'J')
783        """
784        self.convert_to_units(self.units.get_mks_equivalent(), equivalence, **kwargs)
785
786    def in_units(self, units, equivalence=None, **kwargs):
787        """
788        Creates a copy of this array with the data converted to the
789        supplied units, and returns it.
790
791        Optionally, an equivalence can be specified to convert to an
792        equivalent quantity which is not in the same dimensions.
793
794        Parameters
795        ----------
796        units : Unit object or string
797            The units you want to get a new quantity in.
798        equivalence : string, optional
799            The equivalence you wish to use. To see which equivalencies
800            are supported for this object, try the ``list_equivalencies``
801            method. Default: None
802        kwargs: optional
803            Any additional keyword arguments are supplied to the
804            equivalence
805
806        Raises
807        ------
808        If the provided unit does not have the same dimensions as the array
809        this will raise a UnitConversionError
810
811        Examples
812        --------
813        >>> from unyt import c, gram
814        >>> m = 10*gram
815        >>> E = m*c**2
816        >>> print(E.in_units('erg'))
817        8.987551787368176e+21 erg
818        >>> print(E.in_units('J'))
819        898755178736817.6 J
820        """
821        units = _sanitize_units_convert(units, self.units.registry)
822        if equivalence is None:
823            conv_data = _check_em_conversion(
824                self.units, units, registry=self.units.registry
825            )
826            if any(conv_data):
827                new_units, (conversion_factor, offset) = _em_conversion(
828                    self.units, conv_data, units
829                )
830                offset = 0
831            else:
832                new_units = units
833                (conversion_factor, offset) = self.units.get_conversion_factor(
834                    new_units, self.dtype
835                )
836            dsize = self.dtype.itemsize
837            if self.dtype.kind in ("u", "i"):
838                large = LARGE_INPUT.get(dsize, 0)
839                if large and np.any(np.abs(self.d) > large):
840                    warnings.warn(
841                        "Overflow encountered while converting to units '%s'"
842                        % new_units,
843                        RuntimeWarning,
844                        stacklevel=2,
845                    )
846            new_dtype = np.dtype("f" + str(dsize))
847            conversion_factor = new_dtype.type(conversion_factor)
848            ret = np.asarray(self.ndview * conversion_factor, dtype=new_dtype)
849            if offset:
850                np.subtract(ret, offset, ret)
851
852            try:
853                new_array = type(self)(
854                    ret, new_units, bypass_validation=True, name=self.name
855                )
856            except TypeError:
857                # subclasses might not take name as a kwarg
858                new_array = type(self)(ret, new_units, bypass_validation=True)
859
860            return new_array
861        else:
862            return self.to_equivalent(units, equivalence, **kwargs)
863
864    def to(self, units, equivalence=None, **kwargs):
865        """
866        Creates a copy of this array with the data converted to the
867        supplied units, and returns it.
868
869        Optionally, an equivalence can be specified to convert to an
870        equivalent quantity which is not in the same dimensions.
871
872        .. note::
873
874            All additional keyword arguments are passed to the
875            equivalency, which should be used if that particular
876            equivalency requires them.
877
878        Parameters
879        ----------
880        units : Unit object or string
881            The units you want to get a new quantity in.
882        equivalence : string, optional
883            The equivalence you wish to use. To see which
884            equivalencies are supported for this unitful
885            quantity, try the :meth:`list_equivalencies`
886            method. Default: None
887        kwargs: optional
888            Any additional keywoard arguments are supplied to the
889            equivalence
890
891        Raises
892        ------
893        If the provided unit does not have the same dimensions as the array
894        this will raise a UnitConversionError
895
896        Examples
897        --------
898        >>> from unyt import c, gram
899        >>> m = 10*gram
900        >>> E = m*c**2
901        >>> print(E.to('erg'))
902        8.987551787368176e+21 erg
903        >>> print(E.to('J'))
904        898755178736817.6 J
905        """
906        return self.in_units(units, equivalence=equivalence, **kwargs)
907
908    def to_value(self, units=None, equivalence=None, **kwargs):
909        """
910        Creates a copy of this array with the data in the supplied
911        units, and returns it without units. Output is therefore a
912        bare NumPy array.
913
914        Optionally, an equivalence can be specified to convert to an
915        equivalent quantity which is not in the same dimensions.
916
917        .. note::
918
919            All additional keyword arguments are passed to the
920            equivalency, which should be used if that particular
921            equivalency requires them.
922
923        Parameters
924        ----------
925        units : Unit object or string, optional
926            The units you want to get the bare quantity in. If not
927            specified, the value will be returned in the current units.
928
929        equivalence : string, optional
930            The equivalence you wish to use. To see which
931            equivalencies are supported for this unitful
932            quantity, try the :meth:`list_equivalencies`
933            method. Default: None
934
935        Examples
936        --------
937        >>> from unyt import km
938        >>> a = [3, 4, 5]*km
939        >>> print(a.to_value('cm'))
940        [300000. 400000. 500000.]
941        """
942        if units is None:
943            v = self.value
944        else:
945            v = self.in_units(units, equivalence=equivalence, **kwargs).value
946        if isinstance(self, unyt_quantity):
947            return float(v)
948        else:
949            return v
950
951    def in_base(self, unit_system=None):
952        """
953        Creates a copy of this array with the data in the specified unit
954        system, and returns it in that system's base units.
955
956        Parameters
957        ----------
958        unit_system : string, optional
959            The unit system to be used in the conversion. If not specified,
960            the configured default base units of are used (defaults to MKS).
961
962        Examples
963        --------
964        >>> from unyt import erg, s
965        >>> E = 2.5*erg/s
966        >>> print(E.in_base("mks"))
967        2.5e-07 W
968        """
969        us = _sanitize_unit_system(unit_system, self)
970        try:
971            conv_data = _check_em_conversion(
972                self.units, unit_system=us, registry=self.units.registry
973            )
974        except MKSCGSConversionError:
975            raise UnitsNotReducible(self.units, us)
976        if any(conv_data):
977            um = us.units_map
978            u = self.units
979            if u.dimensions in um and u.expr == um[self.units.dimensions]:
980                return self.copy()
981            to_units, (conv, offset) = _em_conversion(u, conv_data, unit_system=us)
982        else:
983            to_units = self.units.get_base_equivalent(unit_system)
984            conv, offset = self.units.get_conversion_factor(to_units, self.dtype)
985        new_dtype = np.dtype("f" + str(self.dtype.itemsize))
986        conv = new_dtype.type(conv)
987        ret = self.v * conv
988        if offset:
989            ret = ret - offset
990        return type(self)(ret, to_units)
991
992    def in_cgs(self):
993        """
994        Creates a copy of this array with the data in the equivalent cgs units,
995        and returns it.
996
997        Returns
998        -------
999        unyt_array object with data in this array converted to cgs units.
1000
1001        Example
1002        -------
1003        >>> from unyt import Newton, km
1004        >>> print((10*Newton/km).in_cgs())
1005        10.0 g/s**2
1006        """
1007        return self.in_base("cgs")
1008
1009    def in_mks(self):
1010        """
1011        Creates a copy of this array with the data in the equivalent mks units,
1012        and returns it.
1013
1014        Returns
1015        -------
1016        unyt_array object with data in this array converted to mks units.
1017
1018        Example
1019        -------
1020        >>> from unyt import mile
1021        >>> print((1.*mile).in_mks())
1022        1609.344 m
1023        """
1024        return self.in_base("mks")
1025
1026    def convert_to_equivalent(self, unit, equivalence, **kwargs):
1027        """
1028        Convert the array in-place to the specified units, assuming
1029        the given equivalency. The dimensions of the specified units and the
1030        dimensions of the original array need not match so long as there is an
1031        appropriate conversion in the specified equivalency.
1032
1033        Parameters
1034        ----------
1035        unit : string
1036            The unit that you wish to convert to.
1037        equivalence : string
1038            The equivalence you wish to use. To see which equivalencies are
1039            supported for this unitful quantity, try the
1040            :meth:`list_equivalencies` method.
1041
1042        Examples
1043        --------
1044        >>> from unyt import K
1045        >>> a = [10, 20, 30]*(1e7*K)
1046        >>> a.convert_to_equivalent("keV", "thermal")
1047        >>> a
1048        unyt_array([ 8.6173324, 17.2346648, 25.8519972], 'keV')
1049        """
1050        conv_unit = Unit(unit, registry=self.units.registry)
1051        if self.units.same_dimensions_as(conv_unit):
1052            self.convert_to_units(conv_unit)
1053            return
1054        this_equiv = equivalence_registry[equivalence](in_place=True)
1055        if self.has_equivalent(equivalence):
1056            this_equiv.convert(self, conv_unit.dimensions, **kwargs)
1057            self.convert_to_units(conv_unit)
1058            # set name to None since the semantic meaning has changed
1059            self.name = None
1060        else:
1061            raise InvalidUnitEquivalence(equivalence, self.units, conv_unit)
1062
1063    def to_equivalent(self, unit, equivalence, **kwargs):
1064        """
1065        Return a copy of the unyt_array in the units specified units, assuming
1066        the given equivalency. The dimensions of the specified units and the
1067        dimensions of the original array need not match so long as there is an
1068        appropriate conversion in the specified equivalency.
1069
1070        Parameters
1071        ----------
1072        unit : string
1073            The unit that you wish to convert to.
1074        equivalence : string
1075            The equivalence you wish to use. To see which equivalencies are
1076            supported for this unitful quantity, try the
1077            :meth:`list_equivalencies` method.
1078
1079        Examples
1080        --------
1081        >>> from unyt import K
1082        >>> a = 1.0e7*K
1083        >>> print(a.to_equivalent("keV", "thermal"))
1084        0.8617332401096504 keV
1085        """
1086        conv_unit = Unit(unit, registry=self.units.registry)
1087        if self.units.same_dimensions_as(conv_unit):
1088            return self.in_units(conv_unit)
1089        this_equiv = equivalence_registry[equivalence]()
1090        if self.has_equivalent(equivalence):
1091            new_arr = this_equiv.convert(self, conv_unit.dimensions, **kwargs)
1092            return new_arr.in_units(conv_unit)
1093        else:
1094            raise InvalidUnitEquivalence(equivalence, self.units, unit)
1095
1096    def list_equivalencies(self):
1097        """
1098        Lists the possible equivalencies associated with this unyt_array or
1099        unyt_quantity.
1100
1101        Example
1102        -------
1103        >>> from unyt import km
1104        >>> (1.0*km).list_equivalencies()
1105        spectral: length <-> spatial_frequency <-> frequency <-> energy
1106        schwarzschild: mass <-> length
1107        compton: mass <-> length
1108        """
1109        self.units.list_equivalencies()
1110
1111    def has_equivalent(self, equivalence):
1112        """
1113        Check to see if this unyt_array or unyt_quantity has an equivalent
1114        unit in *equiv*.
1115
1116        Example
1117        -------
1118        >>> from unyt import km, keV
1119        >>> (1.0*km).has_equivalent('spectral')
1120        True
1121        >>> print((1*km).to_equivalent('MHz', equivalence='spectral'))
1122        0.299792458 MHz
1123        >>> print((1*keV).to_equivalent('angstrom', equivalence='spectral'))
1124        12.39841931521966 Å
1125        """
1126        return self.units.has_equivalent(equivalence)
1127
1128    def ndarray_view(self):
1129        """
1130        Returns a view into the array as a numpy array
1131
1132        Returns
1133        -------
1134        View of this array's data.
1135
1136        Example
1137        -------
1138
1139        >>> from unyt import km
1140        >>> a = [3, 4, 5]*km
1141        >>> a
1142        unyt_array([3, 4, 5], 'km')
1143        >>> a.ndarray_view()
1144        array([3, 4, 5])
1145
1146        This function returns a view that shares the same underlying memory
1147        as the original array.
1148
1149        >>> b = a.ndarray_view()
1150        >>> b.base is a.base
1151        True
1152        >>> b[2] = 4
1153        >>> b
1154        array([3, 4, 4])
1155        >>> a
1156        unyt_array([3, 4, 4], 'km')
1157        """
1158        return self.view(np.ndarray)
1159
1160    def to_ndarray(self):
1161        """
1162        Creates a copy of this array with the unit information stripped
1163
1164        Example
1165        -------
1166        >>> from unyt import km
1167        >>> a = [3, 4, 5]*km
1168        >>> a
1169        unyt_array([3, 4, 5], 'km')
1170        >>> b = a.to_ndarray()
1171        >>> b
1172        array([3, 4, 5])
1173
1174        The returned array will contain a copy of the data contained in
1175        the original array.
1176
1177        >>> a.base is not b.base
1178        True
1179
1180        """
1181        return np.array(self)
1182
1183    def argsort(self, axis=-1, kind="quicksort", order=None):
1184        """
1185        Returns the indices that would sort the array.
1186
1187        See the documentation of ndarray.argsort for details about the keyword
1188        arguments.
1189
1190        Example
1191        -------
1192        >>> from unyt import km
1193        >>> data = [3, 8, 7]*km
1194        >>> print(np.argsort(data))
1195        [0 2 1]
1196        >>> print(data.argsort())
1197        [0 2 1]
1198        """
1199        return self.view(np.ndarray).argsort(axis, kind, order)
1200
1201    @classmethod
1202    def from_astropy(cls, arr, unit_registry=None):
1203        """
1204        Convert an AstroPy "Quantity" to a unyt_array or unyt_quantity.
1205
1206        Parameters
1207        ----------
1208        arr : AstroPy Quantity
1209            The Quantity to convert from.
1210        unit_registry : yt UnitRegistry, optional
1211            A yt unit registry to use in the conversion. If one is not
1212            supplied, the default one will be used.
1213
1214        Example
1215        -------
1216        >>> from astropy.units import km
1217        >>> unyt_quantity.from_astropy(km)
1218        unyt_quantity(1., 'km')
1219        >>> a = [1, 2, 3]*km
1220        >>> a
1221        <Quantity [1., 2., 3.] km>
1222        >>> unyt_array.from_astropy(a)
1223        unyt_array([1., 2., 3.], 'km')
1224        """
1225        # Converting from AstroPy Quantity
1226        try:
1227            u = arr.unit
1228            _arr = arr
1229        except AttributeError:
1230            u = arr
1231            _arr = 1.0 * u
1232        ap_units = []
1233        for base, exponent in zip(u.bases, u.powers):
1234            unit_str = base.to_string()
1235            # we have to do this because AstroPy is silly and defines
1236            # hour as "h"
1237            if unit_str == "h":
1238                unit_str = "hr"
1239            ap_units.append("%s**(%s)" % (unit_str, Rational(exponent)))
1240        ap_units = "*".join(ap_units)
1241        if isinstance(_arr.value, np.ndarray) and _arr.shape != ():
1242            return unyt_array(_arr.value, ap_units, registry=unit_registry)
1243        else:
1244            return unyt_quantity(_arr.value, ap_units, registry=unit_registry)
1245
1246    def to_astropy(self, **kwargs):
1247        """
1248        Creates a new AstroPy quantity with the same unit information.
1249
1250        Example
1251        -------
1252        >>> from unyt import g, cm
1253        >>> data = [3, 4, 5]*g/cm**3
1254        >>> data.to_astropy()
1255        <Quantity [3., 4., 5.] g / cm3>
1256        """
1257        return self.value * _astropy.units.Unit(str(self.units), **kwargs)
1258
1259    @classmethod
1260    def from_pint(cls, arr, unit_registry=None):
1261        """
1262        Convert a Pint "Quantity" to a unyt_array or unyt_quantity.
1263
1264        Parameters
1265        ----------
1266        arr : Pint Quantity
1267            The Quantity to convert from.
1268        unit_registry : yt UnitRegistry, optional
1269            A yt unit registry to use in the conversion. If one is not
1270            supplied, the default one will be used.
1271
1272        Examples
1273        --------
1274        >>> from pint import UnitRegistry
1275        >>> import numpy as np
1276        >>> ureg = UnitRegistry()
1277        >>> a = np.arange(4)
1278        >>> b = ureg.Quantity(a, "erg/cm**3")
1279        >>> b
1280        <Quantity([0 1 2 3], 'erg / centimeter ** 3')>
1281        >>> c = unyt_array.from_pint(b)
1282        >>> c
1283        unyt_array([0, 1, 2, 3], 'erg/cm**3')
1284        """
1285        p_units = []
1286        for base, exponent in arr._units.items():
1287            bs = convert_pint_units(base)
1288            p_units.append("%s**(%s)" % (bs, Rational(exponent)))
1289        p_units = "*".join(p_units)
1290        if isinstance(arr.magnitude, np.ndarray):
1291            return unyt_array(arr.magnitude, p_units, registry=unit_registry)
1292        else:
1293            return unyt_quantity(arr.magnitude, p_units, registry=unit_registry)
1294
1295    def to_pint(self, unit_registry=None):
1296        """
1297        Convert a unyt_array or unyt_quantity to a Pint Quantity.
1298
1299        Parameters
1300        ----------
1301        arr : unyt_array or unyt_quantity
1302            The unitful quantity to convert from.
1303        unit_registry : Pint UnitRegistry, optional
1304            The Pint UnitRegistry to use in the conversion. If one is not
1305            supplied, the default one will be used. NOTE: This is not
1306            the same as a yt UnitRegistry object.
1307
1308        Examples
1309        --------
1310        >>> from unyt import cm, s
1311        >>> a = 4*cm**2/s
1312        >>> print(a)
1313        4 cm**2/s
1314        >>> a.to_pint()
1315        <Quantity(4, 'centimeter ** 2 / second')>
1316        """
1317        if unit_registry is None:
1318            unit_registry = _pint.UnitRegistry()
1319        powers_dict = self.units.expr.as_powers_dict()
1320        units = []
1321        for unit, pow in powers_dict.items():
1322            # we have to do this because Pint doesn't recognize
1323            # "yr" as "year"
1324            if str(unit).endswith("yr") and len(str(unit)) in [2, 3]:
1325                unit = str(unit).replace("yr", "year")
1326            units.append("%s**(%s)" % (unit, Rational(pow)))
1327        units = "*".join(units)
1328        return unit_registry.Quantity(self.value, units)
1329
1330    #
1331    # End unit conversion methods
1332    #
1333
1334    def write_hdf5(self, filename, dataset_name=None, info=None, group_name=None):
1335        r"""Writes a unyt_array to hdf5 file.
1336
1337        Parameters
1338        ----------
1339        filename: string
1340            The filename to create and write a dataset to
1341
1342        dataset_name: string
1343            The name of the dataset to create in the file.
1344
1345        info: dictionary
1346            A dictionary of supplementary info to write to append as attributes
1347            to the dataset.
1348
1349        group_name: string
1350            An optional group to write the arrays to. If not specified, the
1351            arrays are datasets at the top level by default.
1352
1353        Examples
1354        --------
1355        >>> from unyt import cm
1356        >>> a = [1,2,3]*cm
1357        >>> myinfo = {'field':'dinosaurs', 'type':'field_data'}
1358        >>> a.write_hdf5('test_array_data.h5', dataset_name='dinosaurs',
1359        ...              info=myinfo)  # doctest: +SKIP
1360        """
1361        from unyt._on_demand_imports import _h5py as h5py
1362        import pickle
1363
1364        if info is None:
1365            info = {}
1366
1367        info["units"] = str(self.units)
1368        lut = {}
1369        for k, v in self.units.registry.lut.items():
1370            if k not in default_unit_registry.lut:
1371                lut[k] = v
1372        info["unit_registry"] = np.void(pickle.dumps(lut))
1373
1374        if dataset_name is None:
1375            dataset_name = "array_data"
1376
1377        f = h5py.File(filename, "a")
1378        if group_name is not None:
1379            if group_name in f:
1380                g = f[group_name]
1381            else:
1382                g = f.create_group(group_name)
1383        else:
1384            g = f
1385        if dataset_name in g.keys():
1386            d = g[dataset_name]
1387            # Overwrite without deleting if we can get away with it.
1388            if d.shape == self.shape and d.dtype == self.dtype:
1389                d[...] = self
1390                for k in d.attrs.keys():
1391                    del d.attrs[k]
1392            else:
1393                del f[dataset_name]
1394                d = g.create_dataset(dataset_name, data=self)
1395        else:
1396            d = g.create_dataset(dataset_name, data=self)
1397
1398        for k, v in info.items():
1399            d.attrs[k] = v
1400        f.close()
1401
1402    @classmethod
1403    def from_hdf5(cls, filename, dataset_name=None, group_name=None):
1404        r"""Attempts read in and convert a dataset in an hdf5 file into a
1405        unyt_array.
1406
1407        Parameters
1408        ----------
1409        filename: string
1410        The filename to of the hdf5 file.
1411
1412        dataset_name: string
1413            The name of the dataset to read from.  If the dataset has a units
1414            attribute, attempt to infer units as well.
1415
1416        group_name: string
1417            An optional group to read the arrays from. If not specified, the
1418            arrays are datasets at the top level by default.
1419
1420        """
1421        from unyt._on_demand_imports import _h5py as h5py
1422        import pickle
1423
1424        if dataset_name is None:
1425            dataset_name = "array_data"
1426
1427        f = h5py.File(filename, "r")
1428        if group_name is not None:
1429            g = f[group_name]
1430        else:
1431            g = f
1432        dataset = g[dataset_name]
1433        data = dataset[:]
1434        units = dataset.attrs.get("units", "")
1435        unit_lut = default_unit_symbol_lut.copy()
1436        unit_lut_load = pickle.loads(dataset.attrs["unit_registry"].tobytes())
1437        unit_lut.update(unit_lut_load)
1438        f.close()
1439        registry = UnitRegistry(lut=unit_lut, add_default_symbols=False)
1440        return cls(data, units, registry=registry)
1441
1442    #
1443    # Start convenience methods
1444    #
1445
1446    @property
1447    def value(self):
1448        """
1449        Creates a copy of this array with the unit information stripped
1450
1451        Example
1452        -------
1453        >>> from unyt import km
1454        >>> a = [3, 4, 5]*km
1455        >>> a
1456        unyt_array([3, 4, 5], 'km')
1457        >>> b = a.value
1458        >>> b
1459        array([3, 4, 5])
1460
1461        The returned array will contain a copy of the data contained in
1462        the original array.
1463
1464        >>> a.base is not b.base
1465        True
1466
1467        """
1468        return np.array(self)
1469
1470    @property
1471    def v(self):
1472        """
1473        Creates a copy of this array with the unit information stripped
1474
1475        Example
1476        -------
1477        >>> from unyt import km
1478        >>> a = [3, 4, 5]*km
1479        >>> a
1480        unyt_array([3, 4, 5], 'km')
1481        >>> b = a.v
1482        >>> b
1483        array([3, 4, 5])
1484
1485        The returned array will contain a copy of the data contained in
1486        the original array.
1487
1488        >>> a.base is not b.base
1489        True
1490
1491        """
1492        return np.array(self)
1493
1494    @property
1495    def ndview(self):
1496        """
1497        Returns a view into the array as a numpy array
1498
1499        Returns
1500        -------
1501        View of this array's data.
1502
1503        Example
1504        -------
1505
1506        >>> from unyt import km
1507        >>> a = [3, 4, 5]*km
1508        >>> a
1509        unyt_array([3, 4, 5], 'km')
1510        >>> a.ndview
1511        array([3, 4, 5])
1512
1513        This function returns a view that shares the same underlying memory
1514        as the original array.
1515
1516        >>> b = a.ndview
1517        >>> b.base is a.base
1518        True
1519        >>> b[2] = 4
1520        >>> b
1521        array([3, 4, 4])
1522        >>> a
1523        unyt_array([3, 4, 4], 'km')
1524
1525        """
1526        return self.view(np.ndarray)
1527
1528    @property
1529    def d(self):
1530        """
1531        Returns a view into the array as a numpy array
1532
1533        Returns
1534        -------
1535        View of this array's data.
1536
1537        Example
1538        -------
1539
1540        >>> from unyt import km
1541        >>> a = [3, 4, 5]*km
1542        >>> a
1543        unyt_array([3, 4, 5], 'km')
1544        >>> a.d
1545        array([3, 4, 5])
1546
1547        This function returns a view that shares the same underlying memory
1548        as the original array.
1549
1550        >>> b = a.d
1551        >>> b.base is a.base
1552        True
1553        >>> b[2] = 4
1554        >>> b
1555        array([3, 4, 4])
1556        >>> a
1557        unyt_array([3, 4, 4], 'km')
1558        """
1559        return self.view(np.ndarray)
1560
1561    @property
1562    def unit_quantity(self):
1563        """
1564        Return a quantity with a value of 1 and the same units as this array
1565
1566        Example
1567        -------
1568        >>> from unyt import km
1569        >>> a = [4, 5, 6]*km
1570        >>> a.unit_quantity
1571        unyt_quantity(1, 'km')
1572        >>> print(a + 7*a.unit_quantity)
1573        [11 12 13] km
1574        """
1575        return unyt_quantity(1, self.units)
1576
1577    @property
1578    def uq(self):
1579        """
1580        Return a quantity with a value of 1 and the same units as this array
1581
1582        Example
1583        -------
1584        >>> from unyt import km
1585        >>> a = [4, 5, 6]*km
1586        >>> a.uq
1587        unyt_quantity(1, 'km')
1588        >>> print(a + 7*a.uq)
1589        [11 12 13] km
1590        """
1591        return unyt_quantity(1, self.units)
1592
1593    @property
1594    def unit_array(self):
1595        """
1596        Return an array filled with ones with the same units as this array
1597
1598        Example
1599        -------
1600        >>> from unyt import km
1601        >>> a = [4, 5, 6]*km
1602        >>> a.unit_array
1603        unyt_array([1, 1, 1], 'km')
1604        >>> print(a + 7*a.unit_array)
1605        [11 12 13] km
1606        """
1607        return np.ones_like(self)
1608
1609    @property
1610    def ua(self):
1611        """
1612        Return an array filled with ones with the same units as this array
1613
1614        Example
1615        -------
1616        >>> from unyt import km
1617        >>> a = [4, 5, 6]*km
1618        >>> a.unit_array
1619        unyt_array([1, 1, 1], 'km')
1620        >>> print(a + 7*a.unit_array)
1621        [11 12 13] km
1622        """
1623        return np.ones_like(self)
1624
1625    def __getitem__(self, item):
1626        ret = super(unyt_array, self).__getitem__(item)
1627        if ret.shape == ():
1628            return unyt_quantity(
1629                ret, self.units, bypass_validation=True, name=self.name
1630            )
1631        else:
1632            if hasattr(self, "units"):
1633                ret.units = self.units
1634            return ret
1635
1636    #
1637    # Start operation methods
1638    #
1639
    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        """
        Implement numpy's ufunc protocol for unyt arrays.

        The numerical result is computed on unit-stripped data. The result
        unit comes from the unit operator registered for ``ufunc`` in
        ``self._ufunc_registry``. Unary and binary ufuncs are supported, as
        is ``clip`` with more inputs; any other input count raises
        RuntimeError.

        Parameters
        ----------
        ufunc : numpy.ufunc
            The ufunc being applied.
        method : str
            Name of the ufunc method to call (e.g. "__call__", "reduce");
            looked up with ``getattr(ufunc, method)``.
        *inputs
            The ufunc operands.
        **kwargs
            Extra keyword arguments forwarded to the ufunc. ``out`` is popped
            and handled here so units can be attached to it.

        Returns
        -------
        One of the following:

        - a ``ret_class`` instance;
        - a unyt_quantity for size-1 results;
        - a plain ndarray when the unit operator yields no unit;
        - a tuple for ``modf``/``divmod``.

        When the unit operator returns a multiplier ``mul`` other than 1,
        the result is multiplied by it.

        Raises
        ------
        UnitOperationError
            For dimensionally incompatible operands. Also raised when an
            array is raised to an array-valued power, or to an exponent that
            is not dimensionless.
        InvalidUnitOperation
            For Fahrenheit/Celsius-type (offset) units in operations that
            need conversion or in multiplication/division.
        RuntimeError
            For unsupported numbers of inputs.
        """
        func = getattr(ufunc, method)
        if "out" not in kwargs:
            if ufunc in multiple_output_operators:
                out = (None,) * multiple_output_operators[ufunc]
                out_func = out
            else:
                out = None
                out_func = None
        else:
            # we need to get both the actual "out" object and a view onto it
            # in case we need to do in-place operations
            out = kwargs.pop("out")
            if ufunc in multiple_output_operators:
                out_func = []
                for arr in out:
                    out_func.append(arr.view(np.ndarray))
                out_func = tuple(out_func)
            else:
                out = out[0]
                # integer output buffers are reinterpreted in place as floats
                # of the same itemsize and refilled with the converted
                # values, so floating-point results can be stored in them
                if out.dtype.kind in ("u", "i"):
                    new_dtype = "f" + str(out.dtype.itemsize)
                    float_values = out.astype(new_dtype)
                    out.dtype = new_dtype
                    np.copyto(out, float_values)
                out_func = out.view(np.ndarray)
        if len(inputs) == 1:
            # Unary ufuncs
            inp = inputs[0]
            u = getattr(inp, "units", None)
            # NOTE(review): if the single input carries no units (e.g. a
            # plain ndarray combined with a unyt_array passed as out=), u is
            # None and u.dimensions raises AttributeError — confirm whether
            # that path can be reached.
            if u.dimensions is angle and ufunc in trigonometric_operators:
                # ensure np.sin(90*degrees) works as expected
                inp = inp.in_units("radian").v
            # evaluate the ufunc
            out_arr = func(np.asarray(inp), out=out_func, **kwargs)
            if ufunc in (multiply, divide) and method == "reduce":
                # a reduction of a multiply or divide corresponds to
                # a repeated product which we implement as an exponent
                mul = 1
                power_sign = POWER_SIGN_MAPPING[ufunc]
                if "axis" in kwargs and kwargs["axis"] is not None:
                    unit = u ** (power_sign * inp.shape[kwargs["axis"]])
                else:
                    unit = u ** (power_sign * inp.size)
            else:
                # get unit of result
                mul, unit = self._ufunc_registry[ufunc](u)
            # use type(self) here so we can support user-defined
            # subclasses of unyt_array
            ret_class = type(self)
        elif len(inputs) == 2:
            # binary ufuncs
            i0 = inputs[0]
            i1 = inputs[1]
            # coerce inputs to be ndarrays if they aren't already
            inp0 = _coerce_iterable_units(i0)
            inp1 = _coerce_iterable_units(i1)
            u0 = getattr(i0, "units", None) or getattr(inp0, "units", None)
            u1 = getattr(i1, "units", None) or getattr(inp1, "units", None)
            ret_class = _get_binary_op_return_class(type(i0), type(i1))
            # a unitless operand is treated as dimensionless, using the
            # other operand's registry
            if u0 is None:
                u0 = Unit(registry=getattr(u1, "registry", None))
            if u1 is None and ufunc is not power:
                u1 = Unit(registry=getattr(u0, "registry", None))
            elif ufunc is power:
                # for power, u1 carries the numeric exponent, not a unit
                u1 = inp1
                if inp0.shape != () and inp1.shape != ():
                    raise UnitOperationError(ufunc, u0, u1)
                if isinstance(u1, unyt_array):
                    if u1.units.is_dimensionless:
                        pass
                    else:
                        raise UnitOperationError(ufunc, u0, u1.units)
                if u1.shape == ():
                    u1 = float(u1)
                else:
                    u1 = 1.0
            unit_operator = self._ufunc_registry[ufunc]
            if unit_operator in (_preserve_units, _comparison_unit, _arctan2_unit):
                # check "is" equality first for speed
                if u0 is not u1 and u0 != u1:
                    # we allow adding, multiplying, comparisons with
                    # zero-filled arrays, lists, etc or scalar zero. We
                    # do not allow zero-filled unyt_array instances for
                    # performance reasons. If we did allow it, every
                    # binary operation would need to scan over all the
                    # elements of both arrays to check for arrays filled
                    # with zeros
                    if not isinstance(i0, unyt_array) or not isinstance(i1, unyt_array):
                        any_nonzero = [np.count_nonzero(i0), np.count_nonzero(i1)]
                        if any_nonzero[0] == 0:
                            u0 = u1
                        elif any_nonzero[1] == 0:
                            u1 = u0
                    if not u0.same_dimensions_as(u1):
                        if unit_operator is _comparison_unit:
                            # we allow comparisons between data with
                            # units and dimensionless data
                            if u0.is_dimensionless:
                                u0 = u1
                            elif u1.is_dimensionless:
                                u1 = u0
                            else:
                                # comparison with different units, so need to check if
                                # this is == and != which we allow and handle in a
                                # special way using an early return from __array_ufunc__
                                if ufunc in (equal, not_equal):
                                    if ufunc is equal:
                                        func = np.zeros_like
                                    else:
                                        func = np.ones_like
                                    # NOTE(review): the result shape follows
                                    # inp1 only, so comparing an array with a
                                    # scalar of incompatible units yields a
                                    # scalar rather than an elementwise array;
                                    # out[:] = ret[:] also fails for 0-d
                                    # results — confirm this is intended.
                                    ret = func(np.asarray(inp1), dtype=bool)
                                    if out is not None:
                                        out[:] = ret[:]
                                        if isinstance(out, unyt_array):
                                            out.units = Unit(
                                                "", registry=self.units.registry
                                            )
                                    if ret.shape == ():
                                        ret = bool(ret)
                                    return ret
                                else:
                                    raise UnitOperationError(ufunc, u0, u1)
                        else:
                            raise UnitOperationError(ufunc, u0, u1)
                    # rescale the second operand into the first operand's
                    # units (as floats) before evaluating the ufunc
                    conv, offset = u1.get_conversion_factor(u0, inp1.dtype)
                    new_dtype = np.dtype("f" + str(inp1.dtype.itemsize))
                    conv = new_dtype.type(conv)
                    if offset is not None:
                        raise InvalidUnitOperation(
                            "Quantities with units of Fahrenheit or Celsius "
                            "cannot by multiplied, divided, subtracted or "
                            "added with data that has different units."
                        )
                    inp1 = np.asarray(inp1, dtype=new_dtype) * conv
            # get the unit of the result
            mul, unit = unit_operator(u0, u1)
            # actually evaluate the ufunc
            out_arr = func(
                inp0.view(np.ndarray), inp1.view(np.ndarray), out=out_func, **kwargs
            )
            if unit_operator in (_multiply_units, _divide_units):
                # a dimensionless product/quotient of same-dimension operands
                # whose unit still carries a scale factor (e.g. km/m) is folded
                # into the data so the result unit is plain dimensionless
                if unit.is_dimensionless and unit.base_value != 1.0:
                    if not u0.is_dimensionless:
                        if u0.dimensions == u1.dimensions:
                            out_arr = np.multiply(
                                out_arr.view(np.ndarray), unit.base_value, out=out_func
                            )
                            unit = Unit(registry=unit.registry)
                if (
                    u0.base_offset
                    and u0.dimensions is temperature
                    or u1.base_offset
                    and u1.dimensions is temperature
                ):
                    raise InvalidUnitOperation(
                        "Quantities with units of Fahrenheit or Celsius "
                        "cannot by multiplied, divide, subtracted or added."
                    )
        else:
            if ufunc is clip:
                # convert every unitful bound into the units of the first
                # input before clipping
                inp = []
                for i in inputs:
                    if isinstance(i, unyt_array):
                        inp.append(i.to(inputs[0].units).view(np.ndarray))
                    else:
                        inp.append(i)
                if out is not None:
                    _out = out.view(np.ndarray)
                else:
                    _out = None
                out_arr = ufunc(*inp, out=_out)
                unit = inputs[0].units
                ret_class = type(inputs[0])
                mul = 1
            else:
                raise RuntimeError(
                    "Support for the %s ufunc with %i inputs has not been "
                    "added to unyt_array." % (str(ufunc), len(inputs))
                )
        # wrap the raw result according to the computed unit
        if unit is None:
            out_arr = np.array(out_arr, copy=False)
        elif ufunc in (modf, divmod_):
            out_arr = tuple((ret_class(o, unit) for o in out_arr))
        elif out_arr.size == 1:
            out_arr = unyt_quantity(np.asarray(out_arr), unit)
        else:
            if ret_class is unyt_quantity:
                # This happens if you do ndarray * unyt_quantity.
                # Explicitly casting to unyt_array avoids creating a
                # unyt_quantity with size > 1
                out_arr = unyt_array(out_arr, unit)
            else:
                out_arr = ret_class(out_arr, unit, bypass_validation=True)
        if out is not None:
            # apply the unit multiplier to the caller's buffer; if the result
            # shares that memory it is already scaled, so skip scaling it again
            if mul != 1:
                multiply(out, mul, out=out)
                if np.shares_memory(out_arr, out):
                    mul = 1
            if isinstance(out, unyt_array):
                try:
                    out.units = out_arr.units
                except AttributeError:
                    # out_arr is an ndarray
                    out.units = Unit("", registry=self.units.registry)
            elif isinstance(out, tuple):
                for o, oa in zip(out, out_arr):
                    if o is None:
                        continue
                    try:
                        o.units = oa.units
                    except AttributeError:
                        o.units = Unit("", registry=self.units.registry)
        if mul == 1:
            return out_arr
        return mul * out_arr
1856
1857    def copy(self, order="C"):
1858        """
1859        Return a copy of the array.
1860
1861        Parameters
1862        ----------
1863        order : {'C', 'F', 'A', 'K'}, optional
1864            Controls the memory layout of the copy. 'C' means C-order,
1865            'F' means F-order, 'A' means 'F' if `a` is Fortran contiguous,
1866            'C' otherwise. 'K' means match the layout of `a` as closely
1867            as possible. (Note that this function and :func:`numpy.copy`
1868            are very similar, but have different default values for their
1869            order= arguments.)
1870
1871        See also
1872        --------
1873        numpy.copy
1874        numpy.copyto
1875
1876        Examples
1877        --------
1878        >>> from unyt import km
1879        >>> x = [[1,2,3],[4,5,6]] * km
1880        >>> y = x.copy()
1881        >>> x.fill(0)
1882        >>> print(x)
1883        [[0 0 0]
1884         [0 0 0]] km
1885
1886        >>> print(y)
1887        [[1 2 3]
1888         [4 5 6]] km
1889
1890        """
1891        name = getattr(self, "name", None)
1892        try:
1893            return type(self)(np.copy(np.asarray(self)), self.units, name=name)
1894        except TypeError:
1895            # subclasses might not take name as a kwarg
1896            return type(self)(np.copy(np.asarray(self)), self.units)
1897
1898    def __array_finalize__(self, obj):
1899        self.units = getattr(obj, "units", NULL_UNIT)
1900        self.name = getattr(obj, "name", None)
1901
1902    def __pos__(self):
1903        """ Posify the data. """
1904        # this needs to be defined for all numpy versions, see
1905        # numpy issue #9081
1906        return type(self)(super(unyt_array, self).__pos__(), self.units)
1907
1908    def dot(self, b, out=None):
1909        """dot product of two arrays.
1910
1911        Refer to `numpy.dot` for full documentation.
1912
1913        See Also
1914        --------
1915        numpy.dot : equivalent function
1916
1917        Examples
1918        --------
1919        >>> from unyt import km, s
1920        >>> a = np.eye(2)*km
1921        >>> b = (np.ones((2, 2)) * 2)*s
1922        >>> print(a.dot(b))
1923        [[2. 2.]
1924         [2. 2.]] km*s
1925
1926        This array method can be conveniently chained:
1927
1928        >>> print(a.dot(b).dot(b))
1929        [[8. 8.]
1930         [8. 8.]] km*s**2
1931        """
1932        res_units = self.units * getattr(b, "units", NULL_UNIT)
1933        ret = self.view(np.ndarray).dot(np.asarray(b), out=out) * res_units
1934        if out is not None:
1935            out.units = res_units
1936        return ret
1937
1938    def __reduce__(self):
1939        """Pickle reduction method
1940
1941        See the documentation for the standard library pickle module:
1942        http://docs.python.org/2/library/pickle.html
1943
1944        Unit metadata is encoded in the zeroth element of third element of the
1945        returned tuple, itself a tuple used to restore the state of the
1946        ndarray. This is always defined for numpy arrays.
1947        """
1948        np_ret = super(unyt_array, self).__reduce__()
1949        obj_state = np_ret[2]
1950        unit_state = (((str(self.units), self.units.registry.lut),) + obj_state[:],)
1951        new_ret = np_ret[:2] + unit_state + np_ret[3:]
1952        return new_ret
1953
1954    def __setstate__(self, state):
1955        """Pickle setstate method
1956
1957        This is called inside pickle.read() and restores the unit data from the
1958        metadata extracted in __reduce__ and then serialized by pickle.
1959        """
1960        super(unyt_array, self).__setstate__(state[1:])
1961        unit, lut = state[0]
1962        lut = _correct_old_unit_registry(lut)
1963        registry = UnitRegistry(lut=lut, add_default_symbols=False)
1964        self.units = Unit(unit, registry=registry)
1965
1966    def __deepcopy__(self, memodict=None):
1967        """copy.deepcopy implementation
1968
1969        This is necessary for stdlib deepcopy of arrays and quantities.
1970        """
1971        ret = super(unyt_array, self).__deepcopy__(memodict)
1972        try:
1973            return type(self)(ret, copy.deepcopy(self.units), name=self.name)
1974        except TypeError:
1975            # subclasses might not take name as a kwarg
1976            return type(self)(ret, copy.deepcopy(self.units))
1977
1978
class unyt_quantity(unyt_array):
    """
    A scalar associated with a unit.

    Parameters
    ----------

    input_scalar : an integer or floating point scalar
        The scalar to attach units to
    units : String unit specification, unit symbol object, or astropy
            units
        The units of the quantity. Powers must be specified using python syntax
        (cm**3, not cm^3). If omitted, the ``units`` attribute of
        ``input_scalar`` (if any) is used.
    registry : A UnitRegistry object
        The registry to create units from. If units is already associated
        with a unit registry and this is specified, this will be used instead
        of the registry associated with the unit object.
    dtype : data-type
        The dtype of the array data.
    bypass_validation : boolean
        If True, skip the check that ``input_scalar`` is numeric and forward
        the flag to :class:`unyt_array`. Defaults to False.
    input_units : deprecated
        Old spelling of ``units``; emits a DeprecationWarning. If both are
        given, ``units`` takes precedence.
    name : string
        The name of the scalar. Defaults to None. This attribute does not propagate
        through mathematical operations, but is preserved under indexing
        and unit conversions.

    Examples
    --------

    >>> a = unyt_quantity(3., 'cm')
    >>> b = unyt_quantity(2., 'm')
    >>> print(a + b)
    203.0 cm
    >>> print(b + a)
    2.03 m

    NumPy ufuncs will pass through units where appropriate.

    >>> import numpy as np
    >>> from unyt import g, cm
    >>> a = 12*g/cm**3
    >>> print(np.abs(a))
    12 g/cm**3

    and strip them when it would be annoying to deal with them.

    >>> print(np.log10(a))
    1.0791812460476249

    """

    def __new__(
        cls,
        input_scalar,
        units=None,
        registry=None,
        dtype=None,
        bypass_validation=False,
        input_units=None,
        name=None,
    ):
        if input_units is not None:
            warnings.warn(
                "input_units has been deprecated, please use units instead",
                DeprecationWarning,
                stacklevel=2,
            )
        # ``units`` wins over the deprecated ``input_units`` spelling.
        if units is not None:
            input_units = units
        if not (
            bypass_validation
            or isinstance(input_scalar, (numeric_type, np.number, np.ndarray))
        ):
            raise RuntimeError("unyt_quantity values must be numeric")
        if input_units is None:
            # Fall back to units carried by the input itself, if any.
            units = getattr(input_scalar, "units", None)
        else:
            units = input_units
        ret = unyt_array.__new__(
            cls,
            np.asarray(input_scalar),
            units,
            registry,
            dtype=dtype,
            bypass_validation=bypass_validation,
            name=name,
        )
        if ret.size > 1:
            raise RuntimeError("unyt_quantity instances must be scalars")
        return ret

    def __round__(self, ndigits=None):
        """Round the value while keeping the units.

        Called by the builtin :func:`round`. ``round(q)`` rounds to the
        nearest integer, as before. ``round(q, n)`` now rounds to ``n``
        decimal places instead of raising TypeError. Both mirror
        ``round(float(q), ndigits)``.
        """
        return type(self)(round(float(self), ndigits), self.units)
2070
2071
2072def _validate_numpy_wrapper_units(v, arrs):
2073    if not any(isinstance(a, unyt_array) for a in arrs):
2074        return v
2075    if not all(isinstance(a, unyt_array) for a in arrs):
2076        raise RuntimeError("Not all of your arrays are unyt_arrays.")
2077    a1 = arrs[0]
2078    if not all(a.units == a1.units for a in arrs[1:]):
2079        raise RuntimeError("Your arrays must have identical units.")
2080    v.units = a1.units
2081    return v
2082
2083
def uconcatenate(arrs, axis=0):
    """Concatenate a sequence of arrays.

    A unit-preserving wrapper around numpy.concatenate. If any input is a
    unyt_array, all inputs must be unyt_arrays with identical units, and the
    result carries those units. See the documentation of numpy.concatenate for
    full details.

    Examples
    --------
    >>> from unyt import cm
    >>> A = [1, 2, 3]*cm
    >>> B = [2, 3, 4]*cm
    >>> uconcatenate((A, B))
    unyt_array([1, 2, 3, 2, 3, 4], 'cm')

    """
    joined = np.concatenate(arrs, axis=axis)
    return _validate_numpy_wrapper_units(joined, arrs)
2103
2104
def ucross(arr1, arr2, registry=None, axisa=-1, axisb=-1, axisc=-1, axis=None):
    """Applies the cross product to two unyt arrays.

    This wrapper around numpy.cross preserves units. The result is a
    unyt_array whose units are ``arr1.units * arr2.units``. ``registry`` is
    forwarded to the unyt_array constructor. See the documentation of
    numpy.cross for the meaning of ``axisa``, ``axisb``, ``axisc`` and
    ``axis``.
    """
    crossed = np.cross(arr1, arr2, axisa=axisa, axisb=axisb, axisc=axisc, axis=axis)
    return unyt_array(crossed, arr1.units * arr2.units, registry=registry)
2116
2117
def uintersect1d(arr1, arr2, assume_unique=False):
    """Find the sorted unique elements of the two input arrays.

    A unit-preserving wrapper around numpy.intersect1d. If either input is a
    unyt_array, both must be unyt_arrays with identical units. See the
    documentation of numpy.intersect1d for full details.

    Examples
    --------
    >>> from unyt import cm
    >>> A = [1, 2, 3]*cm
    >>> B = [2, 3, 4]*cm
    >>> uintersect1d(A, B)
    unyt_array([2, 3], 'cm')

    """
    common = np.intersect1d(arr1, arr2, assume_unique=assume_unique)
    return _validate_numpy_wrapper_units(common, [arr1, arr2])
2137
2138
def uunion1d(arr1, arr2):
    """Find the union of two arrays.

    A wrapper around numpy.union1d that preserves units. The result holds
    the sorted unique values found in either input. All input arrays must
    have the same units. See the documentation of numpy.union1d for full
    details.

    Examples
    --------
    >>> from unyt import cm
    >>> A = [1, 2, 3]*cm
    >>> B = [2, 3, 4]*cm
    >>> uunion1d(A, B)
    unyt_array([1, 2, 3, 4], 'cm')

    """
    v = np.union1d(arr1, arr2)
    v = _validate_numpy_wrapper_units(v, [arr1, arr2])
    return v
2158
2159
def unorm(data, ord=None, axis=None, keepdims=False):
    """Matrix or vector norm that preserves units

    Wraps np.linalg.norm and re-attaches ``data.units`` to the result. The
    result is a unyt_quantity when the norm is a scalar (shape ``()``) and a
    unyt_array otherwise. See np.linalg.norm for the meaning of the keyword
    arguments.

    Examples
    --------
    >>> from unyt import km
    >>> data = [1, 2, 3]*km
    >>> print(unorm(data))
    3.7416573867739413 km
    """
    magnitude = np.linalg.norm(data, ord=ord, axis=axis, keepdims=keepdims)
    wrapper = unyt_quantity if magnitude.shape == () else unyt_array
    return wrapper(magnitude, data.units)
2178
2179
def udot(op1, op2):
    """Matrix or vector dot product that preserves units

    Computes np.dot on the raw values (``.d``) of both operands and attaches
    ``op1.units * op2.units``. A scalar result (shape ``()``) becomes a
    unyt_quantity; anything else becomes a unyt_array.

    Examples
    --------
    >>> from unyt import km, s
    >>> a = np.eye(2)*km
    >>> b = (np.ones((2, 2)) * 2)*s
    >>> print(udot(a, b))
    [[2. 2.]
     [2. 2.]] km*s
    """
    raw_product = np.dot(op1.d, op2.d)
    product_units = op1.units * op2.units
    wrapper = unyt_quantity if raw_product.shape == () else unyt_array
    return wrapper(raw_product, product_units)
2199
2200
def uvstack(arrs):
    """Stack arrays in sequence vertically (row wise) while preserving units

    A unit-preserving wrapper around np.vstack. If any input is a unyt_array,
    all inputs must be unyt_arrays sharing the same units.

    Examples
    --------
    >>> from unyt import km
    >>> a = [1, 2, 3]*km
    >>> b = [2, 3, 4]*km
    >>> print(uvstack([a, b]))
    [[1 2 3]
     [2 3 4]] km
    """
    stacked = np.vstack(arrs)
    return _validate_numpy_wrapper_units(stacked, arrs)
2218
2219
def uhstack(arrs):
    """Stack arrays in sequence horizontally while preserving units

    A unit-preserving wrapper around np.hstack. If any input is a unyt_array,
    all inputs must be unyt_arrays sharing the same units.

    Examples
    --------
    >>> from unyt import km
    >>> a = [1, 2, 3]*km
    >>> b = [2, 3, 4]*km
    >>> print(uhstack([a, b]))
    [1 2 3 2 3 4] km
    >>> a = [[1],[2],[3]]*km
    >>> b = [[2],[3],[4]]*km
    >>> print(uhstack([a, b]))
    [[1 2]
     [2 3]
     [3 4]] km
    """
    stacked = np.hstack(arrs)
    return _validate_numpy_wrapper_units(stacked, arrs)
2242
2243
def ustack(arrs, axis=0):
    """Join a sequence of arrays along a new axis while preserving units

    ``axis`` gives the position of the new axis in the result's dimensions:
    ``axis=0`` makes it the first dimension and ``axis=-1`` the last.

    A unit-preserving wrapper around np.stack. If any input is a unyt_array,
    all inputs must be unyt_arrays sharing the same units. See np.stack for
    full details.

    Examples
    --------
    >>> from unyt import km
    >>> a = [1, 2, 3]*km
    >>> b = [2, 3, 4]*km
    >>> print(ustack([a, b]))
    [[1 2 3]
     [2 3 4]] km
    """
    stacked = np.stack(arrs, axis=axis)
    return _validate_numpy_wrapper_units(stacked, arrs)
2266
2267
2268def _get_binary_op_return_class(cls1, cls2):
2269    if cls1 is cls2:
2270        return cls1
2271    if cls1 in (Unit, np.ndarray, np.matrix, np.ma.masked_array) or issubclass(
2272        cls1, (numeric_type, np.number, list, tuple)
2273    ):
2274        return cls2
2275    if cls2 in (Unit, np.ndarray, np.matrix, np.ma.masked_array) or issubclass(
2276        cls2, (numeric_type, np.number, list, tuple)
2277    ):
2278        return cls1
2279    if issubclass(cls1, unyt_quantity):
2280        return cls2
2281    if issubclass(cls2, unyt_quantity):
2282        return cls1
2283    if issubclass(cls1, cls2):
2284        return cls1
2285    if issubclass(cls2, cls1):
2286        return cls2
2287    else:
2288        raise RuntimeError(
2289            "Undefined operation for a unyt_array subclass. "
2290            "Received operand types (%s) and (%s)" % (cls1, cls2)
2291        )
2292
2293
def loadtxt(fname, dtype="float", delimiter="\t", usecols=None, comments="#"):
    r"""
    Load unyt_arrays with unit information from a text file. Each row in the
    text file must have the same number of values.

    Units are taken from the comment line that immediately follows a comment
    line containing the single word ``Units``. This is the header layout
    written by :func:`savetxt`. If no such header is found, or the number of
    units does not match the number of columns, every column is treated as
    dimensionless.

    Parameters
    ----------
    fname : str
        Filename to read.
    dtype : data-type, optional
        Data-type of the resulting array; default: float.
    delimiter : str, optional
        The string used to separate values. Defaults to a tab character;
        pass None to split on any whitespace.
    usecols : sequence, optional
        Which columns to read, with 0 being the first.  For example,
        ``usecols = (1,4,5)`` will extract the 2nd, 5th and 6th columns.
        The default, None, results in all columns being read.
    comments : str or sequence of str, optional
        The string(s) used to indicate the start of a comment;
        default: '#'.

    Returns
    -------
    A single unyt_array if only one column is loaded, otherwise a tuple
    of unyt_arrays, one per column.

    Examples
    --------
    >>> temp, velx = loadtxt(
    ...    "sphere.dat", usecols=(1,2), delimiter="\t")  # doctest: +SKIP
    """
    # Normalise ``comments`` into a tuple of non-empty prefixes, which is the
    # form str.startswith accepts (None or "" means "no comment lines").
    if isinstance(comments, str):
        comment_seq = [comments]
    else:
        comment_seq = comments or []
    comment_prefixes = tuple(c for c in comment_seq if c)

    next_one = False
    units = []
    num_cols = -1
    # The context manager closes the file even if float() below raises on a
    # malformed first data line.
    with open(fname, "r") as f:
        for line in f:
            stripped = line.strip()
            words = stripped.split()
            if len(words) == 0:
                continue
            if comment_prefixes and stripped.startswith(comment_prefixes):
                if next_one:
                    # Only the line right after the "Units" marker holds units.
                    units = words[1:]
                    next_one = False
                if len(words) == 2 and words[1] == "Units":
                    next_one = True
            else:
                # Here we catch the first line of numbers; float() raises early
                # if it is not numeric.
                col_words = stripped.split(delimiter)
                for word in col_words:
                    float(word)
                num_cols = len(col_words)
                break
    if len(units) != num_cols:
        units = ["dimensionless"] * num_cols
    arrays = np.loadtxt(
        fname,
        dtype=dtype,
        comments=comments,
        delimiter=delimiter,
        converters=None,
        unpack=True,
        usecols=usecols,
        ndmin=0,
    )
    # A single loaded column comes back 1-D; wrap it so the zip below yields
    # one array rather than iterating over its elements.
    if len(arrays.shape) < 2:
        arrays = [arrays]
    if usecols is not None:
        units = [units[col] for col in usecols]
    ret = tuple([unyt_array(arr, unit) for arr, unit in zip(arrays, units)])
    if len(ret) == 1:
        return ret[0]
    return ret
2362
2363
def savetxt(
    fname, arrays, fmt="%.18e", delimiter="\t", header="", footer="", comments="#"
):
    r"""
    Write unyt_arrays with unit information to a text file.

    Each array becomes one column. After the optional user ``header``, two
    comment lines are written: the word ``Units``, then the unit string of
    every column separated by tabs. Objects without a ``units`` attribute are
    recorded as ``dimensionless``. This is the layout :func:`loadtxt` parses.

    Parameters
    ----------
    fname : str
        The file to write the unyt_arrays to.
    arrays : list of unyt_arrays or single unyt_array
        The array(s) to write to the file. Anything that is not a list is
        treated as a single array.
    fmt : str or sequence of strs, optional
        A single format (%10.5f), or a sequence of formats.
    delimiter : str, optional
        String or character separating columns of data.
    header : str, optional
        String that will be written at the beginning of the file, before the
        unit header.
    footer : str, optional
        String that will be written at the end of the file.
    comments : str, optional
        String that will be prepended to the ``header`` and ``footer`` strings,
        to mark them as comments. Default: '#', as expected by e.g.
        ``unyt.loadtxt``.

    Examples
    --------
    >>> import unyt as u
    >>> a = [1, 2, 3]*u.cm
    >>> b = [8, 10, 12]*u.cm/u.s
    >>> c = [2, 85, 9]*u.g
    >>> savetxt("sphere.dat", [a,b,c], header='My sphere stuff',
    ...          delimiter="\t")  # doctest: +SKIP
    """
    if not isinstance(arrays, list):
        arrays = [arrays]
    units = []
    for array in arrays:
        if hasattr(array, "units"):
            units.append(str(array.units))
        else:
            units.append("dimensionless")
    # Keep the user header on its own line(s), ahead of the unit header.
    if header != "" and not header.endswith("\n"):
        header += "\n"
    header += " Units\n " + "\t".join(units)
    np.savetxt(
        fname,
        np.transpose(arrays),
        header=header,
        fmt=fmt,
        delimiter=delimiter,
        footer=footer,
        newline="\n",
        comments=comments,
    )
2420
2421
def allclose_units(actual, desired, rtol=1e-7, atol=0, **kwargs):
    """Returns False if two objects are not equal up to desired tolerance

    A unit-aware wrapper for :func:`numpy.allclose`. ``desired`` and
    ``atol`` are first converted to the units of ``actual``. If either
    conversion raises UnitOperationError or UnitConversionError, the
    function returns False. Otherwise the bare values are compared with
    :func:`numpy.allclose`.

    Parameters
    ----------
    actual : array-like
        Array obtained (possibly with attached units)
    desired : array-like
        Array to compare with (possibly with attached units)
    rtol : float, optional
        Relative tolerance, defaults to 1e-7
    atol : float or quantity, optional
        Absolute tolerance. If units are attached, they must be consistent
        with the units of ``actual`` and ``desired``. If no units are attached,
        assumes the same units as ``desired``. Defaults to zero.

    Raises
    ------
    RuntimeError
        If units of ``rtol`` are not dimensionless

    See Also
    --------
    :func:`unyt.testing.assert_allclose_units`

    Notes
    -----
    Also accepts additional keyword arguments accepted by
    :func:`numpy.allclose`, see the documentation of that
    function for details.

    Examples
    --------
    >>> import unyt as u
    >>> actual = [1e-5, 1e-3, 1e-1]*u.m
    >>> desired = actual.to("cm")
    >>> allclose_units(actual, desired)
    True
    """
    # Wrap the inputs in new unyt_array objects; the intent is that this
    # function never alters the caller's arrays.
    act = unyt_array(actual)
    des = unyt_array(desired)

    try:
        des = des.in_units(act.units)
    except (UnitOperationError, UnitConversionError):
        return False

    rt = unyt_array(rtol)
    if not rt.units.is_dimensionless:
        raise RuntimeError("Units of rtol (%s) are not dimensionless" % rt.units)

    # A bare atol is interpreted in the units of ``desired``.
    at = atol if isinstance(atol, unyt_array) else unyt_quantity(atol, des.units)

    try:
        at = at.in_units(act.units)
    except (UnitOperationError, UnitConversionError):
        return False

    # Units are validated at this point; compare bare values so numpy does
    # not raise spurious unit errors.
    return np.allclose(act.value, des.value, rt.value, at.value, **kwargs)
2495