1# coding: utf-8
2# Copyright (c) Pymatgen Development Team.
3# Distributed under the terms of the MIT License.
4
5"""
This module implements FloatWithUnit, a subclass of float that carries a
unit. It also defines the supported units for commonly used quantities such
as energy, length, mass, temperature, time, charge and memory. FloatWithUnit
supports conversion between compatible units, and addition and subtraction
perform automatic conversion when units are detected. ArrayWithUnit, a
subclass of numpy's ndarray with similar unit features, is also implemented.
12"""
13
14import collections
15import numbers
16from functools import partial
17
18import numpy as np
19import scipy.constants as const
20
__author__ = __maintainer__ = "Shyue Ping Ong, Matteo Giantomassi"
__copyright__ = "Copyright 2011, The Materials Project"
__version__ = "1.0"
__status__ = "Production"
__date__ = "Aug 30, 2013"

# --- Conversion factors (CODATA values as shipped with scipy.constants) ---
_codata = const.physical_constants

# Energy: hartree -> eV. scipy tabulates the value of 1 eV in hartree, hence
# the inverse. One rydberg is half a hartree.
Ha_to_eV = 1 / _codata["electron volt-hartree relationship"][0]
eV_to_Ha = 1 / Ha_to_eV
Ry_to_eV = Ha_to_eV / 2

# Mass: atomic mass unit -> kg.
amu_to_kg = _codata["atomic mass unit-kilogram relationship"][0]

# Length: mile -> m, and bohr <-> angstrom (scipy gives the Bohr radius in m).
mile_to_meters = const.mile
bohr_to_angstrom = bohr_to_ang = _codata["Bohr radius"][0] * 1e10
ang_to_bohr = 1 / bohr_to_ang

# Energy: const.calorie is one calorie in joules, i.e. kJ per kcal.
kCal_to_kJ = const.calorie

# Boltzmann constant in eV/K.
kb = _codata["Boltzmann constant in eV/K"][0]

del _codata

# --- Base units ---
# Each value is the factor that converts the unit into the SI unit of its
# type. Only the relative values matter, but the SI unit itself must have
# factor 1, because that is how it is recognised.
BASE_UNITS = {
    "length": {
        "m": 1,
        "km": 1000,
        "mile": mile_to_meters,
        "ang": 1e-10,
        "cm": 1e-2,
        "pm": 1e-12,
        "bohr": bohr_to_angstrom * 1e-10,
    },
    "mass": {"kg": 1, "g": 1e-3, "amu": amu_to_kg},
    "time": {"s": 1, "min": 60, "h": 3600, "d": 3600 * 24},
    "current": {"A": 1},
    "temperature": {"K": 1},
    "amount": {"mol": 1, "atom": 1 / const.N_A},
    "intensity": {"cd": 1},
    # Binary multiples: each step up is a factor of 1024.
    "memory": {name: 1024 ** power for power, name in enumerate(("byte", "Kb", "Mb", "Gb", "Tb"))},
}

# Lower-case aliases (kb, mb, gb, ...) are accepted for memory units as well.
_memory_units = BASE_UNITS["memory"]
_memory_units.update({name.lower(): factor for name, factor in _memory_units.items()})
del _memory_units
85
# Supported derived units, grouped by unit type. Each unit is defined by the
# powers of the SI base units it is built from ("kg", "m", "s", "A"). Numeric
# keys are scaling constants: a key c with power p scales the unit by c ** p,
# e.g. 1 eV = const.e * kg m^2 s^-2.

# Dimension templates shared by the prefixed variants of a unit type. Merging
# them with ``{**template, ...}`` keeps the key order of a literal definition.
_ENERGY_DIMS = {"kg": 1, "m": 2, "s": -2}
_CHARGE_DIMS = {"A": 1, "s": 1}
_FORCE_DIMS = {"kg": 1, "m": 1, "s": -2}
_FREQUENCY_DIMS = {"s": -1}
_PRESSURE_DIMS = {"kg": 1, "m": -1, "s": -2}
_POWER_DIMS = {"m": 2, "kg": 1, "s": -3}

DERIVED_UNITS = {
    "energy": {
        "eV": {**_ENERGY_DIMS, const.e: 1},
        "meV": {**_ENERGY_DIMS, const.e * 1e-3: 1},
        "Ha": {**_ENERGY_DIMS, const.e * Ha_to_eV: 1},
        "Ry": {**_ENERGY_DIMS, const.e * Ry_to_eV: 1},
        "J": {**_ENERGY_DIMS},
        "kJ": {**_ENERGY_DIMS, 1000: 1},
        "kCal": {**_ENERGY_DIMS, 1000: 1, kCal_to_kJ: 1},
    },
    "charge": {
        "C": {**_CHARGE_DIMS},
        "e": {**_CHARGE_DIMS, const.e: 1},
    },
    "force": {
        "N": {**_FORCE_DIMS},
        "KN": {**_FORCE_DIMS, 1000: 1},
        "MN": {**_FORCE_DIMS, 1e6: 1},
        "GN": {**_FORCE_DIMS, 1e9: 1},
    },
    "frequency": {
        "Hz": {**_FREQUENCY_DIMS},
        "KHz": {**_FREQUENCY_DIMS, 1000: 1},
        "MHz": {**_FREQUENCY_DIMS, 1e6: 1},
        "GHz": {**_FREQUENCY_DIMS, 1e9: 1},
        "THz": {**_FREQUENCY_DIMS, 1e12: 1},
    },
    "pressure": {
        "Pa": {**_PRESSURE_DIMS},
        "KPa": {**_PRESSURE_DIMS, 1000: 1},
        "MPa": {**_PRESSURE_DIMS, 1e6: 1},
        "GPa": {**_PRESSURE_DIMS, 1e9: 1},
    },
    "power": {
        "W": {**_POWER_DIMS},
        "KW": {**_POWER_DIMS, 1000: 1},
        "MW": {**_POWER_DIMS, 1e6: 1},
        "GW": {**_POWER_DIMS, 1e9: 1},
    },
    "emf": {"V": {**_POWER_DIMS, "A": -1}},
    "capacitance": {"F": {"m": -2, "kg": -1, "s": 4, "A": 2}},
    "resistance": {"ohm": {**_POWER_DIMS, "A": -2}},
    "conductance": {"S": {"m": -2, "kg": -1, "s": 3, "A": 2}},
    "magnetic_flux": {"Wb": {"m": 2, "kg": 1, "s": -2, "A": -1}},
    "cross_section": {"barn": {"m": 2, 1e-28: 1}, "mbarn": {"m": 2, 1e-31: 1}},
}

del _ENERGY_DIMS, _CHARGE_DIMS, _FORCE_DIMS, _FREQUENCY_DIMS, _PRESSURE_DIMS, _POWER_DIMS
134
# Every supported unit, keyed by unit type (base types first, then derived).
ALL_UNITS = {**BASE_UNITS, **DERIVED_UNITS}  # type: ignore
# Flat tuple of every supported unit name, in ALL_UNITS order.
SUPPORTED_UNIT_NAMES = tuple(name for units in ALL_UNITS.values() for name in units)

# Reverse lookup: unit name --> unit type. This requires unit names to be
# unique across all unit types, which the assert checks at import time.
_UNAME2UTYPE = {}  # type: ignore
for utype, d in ALL_UNITS.items():
    assert not _UNAME2UTYPE.keys() & d.keys()
    _UNAME2UTYPE.update(dict.fromkeys(d, utype))
del utype, d
144
145
def _get_si_unit(unit):
    """
    Look up the SI unit that shares the type of a base unit.

    Args:
        unit (str): Name of a base unit, e.g. "ang" or "Mb".

    Returns:
        (si_unit, factor): the unit of the same type whose BASE_UNITS factor
        is 1, and the BASE_UNITS factor of ``unit`` itself.

    Raises:
        KeyError: if ``unit`` is not a base unit.
    """
    units_of_type = BASE_UNITS[_UNAME2UTYPE[unit]]
    si_candidates = [name for name, factor in units_of_type.items() if factor == 1]
    return si_candidates[0], units_of_type[unit]
150
151
class UnitError(Exception):
    """
    Exception class for unit errors.

    Derives from ``Exception`` rather than ``BaseException``, so generic
    ``except Exception`` handlers catch it like any other ordinary error.
    ``except UnitError`` and ``except BaseException`` keep working unchanged.
    """
156
157
def _check_mappings(u):
    """
    Collapse a unit mapping into a named derived unit when it matches one.

    Args:
        u: Mapping of unit name -> power.

    Returns:
        ``{name: 1}`` for the first derived unit in DERIVED_UNITS whose
        definition agrees with ``u`` on every key (a missing key counts as
        power 0), e.g. {"kg": 1, "m": 2, "s": -2} -> {"J": 1}. Otherwise
        ``u`` itself is returned unchanged. Definitions that contain a numeric
        scaling constant can never match a mapping without that key.
    """
    for unit_defs in DERIVED_UNITS.values():
        for name, definition in unit_defs.items():
            # Every power in u must appear in the definition...
            if not all(definition.get(key, 0) == power for key, power in u.items()):
                continue
            # ...and every power of the definition must appear in u.
            if all(u.get(key, 0) == power for key, power in definition.items()):
                return {name: 1}
    return u
166
167
class Unit(collections.abc.Mapping):
    """
    Represents a unit, e.g., "m" for meters, etc. Supports compound units.
    Only integer powers are supported for units.

    A Unit is a read-only mapping of unit name -> power. Units with power zero
    are never stored, so e.g. Unit("m^2 m^-2") == Unit({}).
    """

    Error = UnitError

    def __init__(self, unit_def):
        """
        Constructs a unit.

        Args:
            unit_def: A definition for the unit. Either a mapping of unit to
                powers, e.g., {"m": 2, "s": -1} represents "m^2 s^-1",
                or simply as a string "kg m^2 s^-1". Note that the supported
                format uses "^" as the power operator and all units must be
                space-separated. In a string, repeated units are combined by
                adding their powers.
        """
        if isinstance(unit_def, str):
            import re

            unit = collections.defaultdict(int)
            for m in re.finditer(r"([A-Za-z]+)\s*\^*\s*([\-0-9]*)", unit_def):
                power = m.group(2)
                unit[m.group(1)] += int(power) if power else 1
        else:
            unit = dict(unit_def)
        # Store a plain dict without zero powers for both input forms. The
        # string branch builds a defaultdict, which would silently insert
        # missing keys on lookup (making e.g. ``"s" in Unit("m")`` True and
        # changing len/equality). Zero powers would also break equality with
        # equivalent units.
        self._unit = _check_mappings({k: v for k, v in unit.items() if v != 0})

    def __mul__(self, other):
        """Product of two units: powers of matching unit names are added."""
        new_units = collections.defaultdict(int)
        for k, v in self.items():
            new_units[k] += v
        for k, v in other.items():
            new_units[k] += v
        return Unit(new_units)

    def __rmul__(self, other):
        return self.__mul__(other)

    def __div__(self, other):
        """Quotient of two units: powers of ``other`` are subtracted."""
        new_units = collections.defaultdict(int)
        for k, v in self.items():
            new_units[k] += v
        for k, v in other.items():
            new_units[k] -= v
        return Unit(new_units)

    def __truediv__(self, other):
        return self.__div__(other)

    def __pow__(self, i):
        """Raise the unit to power ``i`` by multiplying every power by ``i``."""
        return Unit({k: v * i for k, v in self.items()})

    def __iter__(self):
        return self._unit.__iter__()

    def __getitem__(self, i):
        return self._unit[i]

    def __len__(self):
        return len(self._unit)

    def __repr__(self):
        # Highest powers first, ties broken alphabetically; power 1 is implicit.
        sorted_keys = sorted(self._unit.keys(), key=lambda k: (-self._unit[k], k))
        return " ".join(
            ["{}^{}".format(k, self._unit[k]) if self._unit[k] != 1 else k for k in sorted_keys if self._unit[k] != 0]
        )

    def __str__(self):
        return self.__repr__()

    @property
    def as_base_units(self):
        """
        Converts all units to base SI units, including derived units.

        Returns:
            (base_units_dict, scaling factor). base_units_dict will not
            contain any constants, which are gathered in the scaling factor.
        """
        b = collections.defaultdict(int)
        factor = 1
        for k, v in self.items():
            derived = False
            for d in DERIVED_UNITS.values():
                if k in d:
                    for k2, v2 in d[k].items():
                        # Numeric keys in a derived definition are constants.
                        if isinstance(k2, numbers.Number):
                            factor *= k2 ** (v2 * v)
                        else:
                            b[k2] += v2 * v
                    derived = True
                    break
            if not derived:
                si, f = _get_si_unit(k)
                b[si] += v
                factor *= f ** v
        return {k: v for k, v in b.items() if v != 0}, factor

    def get_conversion_factor(self, new_unit):
        """
        Returns a conversion factor between this unit and a new unit.
        Compound units are supported, but must have the same powers in each
        unit type.

        Args:
            new_unit: The new unit.

        Raises:
            UnitError: if the two units do not reduce to the same SI base
                units (e.g. "m" vs "s", or "m" vs "m s^-1").
        """
        new_unit = Unit(new_unit)
        uo_base, ofactor = self.as_base_units
        un_base, nfactor = new_unit.as_base_units
        # as_base_units expresses both sides purely in SI base units (each of
        # which has factor 1), so the units are compatible iff these dicts are
        # equal, and the conversion factor is the ratio of the scaling factors.
        # Comparing whole dicts also rejects units with a different number of
        # components, which zipping their sorted items silently accepted.
        if uo_base != un_base:
            raise UnitError("Units %s and %s are not compatible!" % (self, new_unit))
        return ofactor / nfactor
292
293
class FloatWithUnit(float):
    """
    Subclasses float to attach a unit type. Typically, you should use the
    pre-defined unit-type factories such as Energy, Length, etc. instead of
    using FloatWithUnit directly.

    Supports conversion, addition and subtraction of the same unit type. E.g.,
    1 m + 20 cm will be automatically converted to 1.2 m (units follow the
    leftmost quantity). Note that FloatWithUnit does not override the eq
    method for float, i.e., units are not checked when testing for equality.
    The reason is to allow this class to be used transparently wherever floats
    are expected.

    >>> a = Energy(1.1, "Ha")
    >>> b = Energy(3, "eV")
    >>> c = a + b
    >>> print(c)  # doctest: +ELLIPSIS
    1.2102... Ha
    >>> print(c.to("eV"))  # doctest: +ELLIPSIS
    32.93... eV
    """

    Error = UnitError

    @classmethod
    def from_string(cls, s):
        """
        Initialize a FloatWithUnit from a string. Example Memory.from_string("1. Mb")

        The string is a number followed by a unit, optionally separated by
        whitespace. The number may use scientific notation ("1.5e-3 eV"): an
        "e"/"E" directly followed by an optionally signed integer is read as an
        exponent rather than as the start of the unit.

        Args:
            s (str): The string to parse.

        Returns:
            A FloatWithUnit. Its unit type is looked up from the unit name
            (base or derived), or is None if the unit is not a single
            supported unit name.

        Raises:
            ValueError: if the unit is missing or the number cannot be parsed.
        """
        import re

        s = s.strip()
        # The leading number ends at the first letter or whitespace...
        i = 0
        while i < len(s) and not (s[i].isalpha() or s[i].isspace()):
            i += 1
        # ...unless that letter starts an exponent, e.g. the "e-3" of "1e-3".
        exponent = re.match(r"[eE][-+]?\d+", s[i:])
        if i > 0 and exponent:
            i += exponent.end()
        # Strip the unit so "1. Mb" yields "Mb" rather than " Mb"; the
        # unit-type lookup below needs the bare name.
        unit = s[i:].strip()
        if not unit:
            raise ValueError("Unit is missing in string %s" % s)
        num = float(s[:i])

        # Find unit type (None if it cannot be detected).
        unit_type = _UNAME2UTYPE.get(unit)

        return cls(num, unit, unit_type=unit_type)

    def __new__(cls, val, unit, unit_type=None):
        """Overrides __new__ since we are subclassing a Python primitive."""
        new = float.__new__(cls, val)
        new._unit = Unit(unit)
        new._unit_type = unit_type
        return new

    def __init__(self, val, unit, unit_type=None):
        """
        Initializes a float with unit.

        Args:
            val (float): Value
            unit (Unit): A unit. E.g., "C".
            unit_type (str): A type of unit. E.g., "charge"

        Raises:
            UnitError: if unit_type is given and unit is not one of its
                supported units.
        """
        unit = Unit(unit)
        # Validate the parsed Unit rather than str() of the raw argument, so
        # equivalent spellings are accepted. Notably this includes the
        # base-unit dict passed by ``as_base_units`` (e.g. {"kg": 1, "m": 2,
        # "s": -2}, which Unit collapses to "J"). str() of a raw dict never
        # matches a unit name, so ``as_base_units`` used to raise for every
        # typed quantity.
        if unit_type is not None and str(unit) not in ALL_UNITS[unit_type]:
            raise UnitError("{} is not a supported unit for {}".format(unit, unit_type))
        self._unit = unit
        self._unit_type = unit_type

    def __repr__(self):
        # Plain float repr (no unit), so the object reads like a float.
        return super().__repr__()

    def __str__(self):
        s = super().__str__()
        return "{} {}".format(s, self._unit)

    def __add__(self, other):
        """
        Add ``other``, converting it to this unit first if it has a unit type.

        Operands without a ``unit_type`` attribute use plain float addition.

        Raises:
            UnitError: if the unit types differ.
        """
        if not hasattr(other, "unit_type"):
            return super().__add__(other)
        if other.unit_type != self._unit_type:
            raise UnitError("Adding different types of units is not allowed")
        val = other
        if other.unit != self._unit:
            val = other.to(self._unit)
        return FloatWithUnit(float(self) + val, unit_type=self._unit_type, unit=self._unit)

    def __sub__(self, other):
        """
        Subtract ``other``, converting it to this unit first if it has a unit
        type. Operands without a ``unit_type`` attribute use plain float
        subtraction.

        Raises:
            UnitError: if the unit types differ.
        """
        if not hasattr(other, "unit_type"):
            return super().__sub__(other)
        if other.unit_type != self._unit_type:
            raise UnitError("Subtracting different units is not allowed")
        val = other
        if other.unit != self._unit:
            val = other.to(self._unit)
        return FloatWithUnit(float(self) - val, unit_type=self._unit_type, unit=self._unit)

    def __mul__(self, other):
        """
        Multiply by ``other``. A plain factor keeps this unit and unit type.
        A FloatWithUnit factor gives the product unit and no unit type.
        """
        if not isinstance(other, FloatWithUnit):
            return FloatWithUnit(float(self) * other, unit_type=self._unit_type, unit=self._unit)
        return FloatWithUnit(float(self) * other, unit_type=None, unit=self._unit * other._unit)

    def __rmul__(self, other):
        if not isinstance(other, FloatWithUnit):
            return FloatWithUnit(float(self) * other, unit_type=self._unit_type, unit=self._unit)
        return FloatWithUnit(float(self) * other, unit_type=None, unit=self._unit * other._unit)

    def __pow__(self, i):
        """Raise value and unit to power ``i``; the unit type is dropped."""
        return FloatWithUnit(float(self) ** i, unit_type=None, unit=self._unit ** i)

    def __truediv__(self, other):
        """
        Divide by ``other``. A plain divisor keeps this unit and unit type.
        A FloatWithUnit divisor gives the quotient unit and no unit type.
        """
        val = super().__truediv__(other)
        if not isinstance(other, FloatWithUnit):
            return FloatWithUnit(val, unit_type=self._unit_type, unit=self._unit)
        return FloatWithUnit(val, unit_type=None, unit=self._unit / other._unit)

    def __neg__(self):
        return FloatWithUnit(super().__neg__(), unit_type=self._unit_type, unit=self._unit)

    def __getnewargs__(self):
        """Function used by pickle to recreate object."""
        # _unit_type might not be defined, e.g. on objects unpickled without
        # it (see __setstate__), so fall back to None.
        return float(self), self._unit, getattr(self, "_unit_type", None)

    def __getstate__(self):
        state = self.__dict__.copy()
        state["val"] = float(self)
        return state

    def __setstate__(self, state):
        self._unit = state["_unit"]
        # __getstate__ stores _unit_type as well. Restore it so the attribute
        # exists even when the object was rebuilt without passing it to
        # __new__ (e.g. pickle protocols < 2, which do not use __getnewargs__).
        self._unit_type = state.get("_unit_type", getattr(self, "_unit_type", None))

    @property
    def unit_type(self) -> str:
        """
        :return: The type of unit. Energy, Charge, etc. None if not set.
        """
        return self._unit_type

    @property
    def unit(self) -> "Unit":
        """
        :return: The unit (a Unit instance), e.g., "eV".
        """
        return self._unit

    def to(self, new_unit):
        """
        Conversion to a new_unit. Right now, only supports 1 to 1 mapping of
        units of each type.

        Args:
            new_unit: New unit type.

        Returns:
            A FloatWithUnit object in the new units.

        Example usage:
        >>> e = Energy(1.1, "Ha")
        >>> print(e.to("eV"))  # doctest: +ELLIPSIS
        29.93... eV
        """
        return FloatWithUnit(
            self * self.unit.get_conversion_factor(new_unit),
            unit_type=self._unit_type,
            unit=new_unit,
        )

    @property
    def as_base_units(self):
        """
        Returns this FloatWithUnit in base SI units, including derived units.

        Returns:
            A FloatWithUnit object in base SI units
        """
        return self.to(self.unit.as_base_units[0])

    @property
    def supported_units(self):
        """
        Supported units for specific unit type.
        """
        return tuple(ALL_UNITS[self._unit_type].keys())
490
491
class ArrayWithUnit(np.ndarray):
    """
    Subclasses `numpy.ndarray` to attach a unit type. Typically, you should
    use the pre-defined unit-type factories such as EnergyArray,
    LengthArray, etc. instead of using ArrayWithUnit directly.

    Supports conversion, addition and subtraction of the same unit type. E.g.,
    1 m + 20 cm will be automatically converted to 1.2 m (units follow the
    leftmost quantity).

    >>> a = EnergyArray([1, 2], "Ha")
    >>> b = EnergyArray([1, 2], "eV")
    >>> c = a + b
    >>> print(c)  # doctest: +ELLIPSIS
    [...1.0367... 2.0734...] Ha
    >>> print(c.to("eV"))  # doctest: +ELLIPSIS
    [...28.21... 56.42...] eV
    """

    Error = UnitError

    def __new__(cls, input_array, unit, unit_type=None):
        """
        Override __new__: view ``input_array`` as this class and attach the
        unit and unit type.
        """
        obj = np.asarray(input_array).view(cls)
        obj._unit = Unit(unit)
        obj._unit_type = unit_type
        return obj

    def __array_finalize__(self, obj):
        """
        Copy the unit attributes from ``obj`` onto arrays derived from it
        (e.g. views and slices). See
        http://docs.scipy.org/doc/numpy/user/basics.subclassing.html for
        comments.
        """
        if obj is None:
            return
        self._unit = getattr(obj, "_unit", None)
        self._unit_type = getattr(obj, "_unit_type", None)

    @property
    def unit_type(self) -> str:
        """
        :return: The type of unit. Energy, Charge, etc. None if not set.
        """
        return self._unit_type

    @property
    def unit(self) -> "Unit":
        """
        :return: The unit (a Unit instance), e.g., "eV".
        """
        return self._unit

    def __reduce__(self):
        """Pickle support: wrap ndarray's state together with the unit metadata."""
        reduce = list(super().__reduce__())
        # Store _unit_type as well. The array is rebuilt without calling
        # __new__ with it, so it would otherwise be missing after unpickling.
        reduce[2] = {
            "np_state": reduce[2],
            "_unit": self._unit,
            "_unit_type": getattr(self, "_unit_type", None),
        }
        return tuple(reduce)

    def __setstate__(self, state):
        # pylint: disable=E1101
        super().__setstate__(state["np_state"])
        self._unit = state["_unit"]
        # Pickles produced without the "_unit_type" key fall back to None.
        self._unit_type = state.get("_unit_type")

    def __repr__(self):
        return "{} {}".format(np.array(self).__repr__(), self.unit)

    def __str__(self):
        return "{} {}".format(np.array(self).__str__(), self.unit)

    def __add__(self, other):
        """
        Elementwise addition; ``other`` is converted to this unit first if it
        has a unit type.

        Raises:
            UnitError: if the unit types differ.
        """
        if hasattr(other, "unit_type"):
            if other.unit_type != self.unit_type:
                raise UnitError("Adding different types of units is not allowed")

            if other.unit != self.unit:
                other = other.to(self.unit)

        return self.__class__(np.array(self) + np.array(other), unit_type=self.unit_type, unit=self.unit)

    def __sub__(self, other):
        """
        Elementwise subtraction; ``other`` is converted to this unit first if
        it has a unit type.

        Raises:
            UnitError: if the unit types differ.
        """
        if hasattr(other, "unit_type"):
            if other.unit_type != self.unit_type:
                raise UnitError("Subtracting different units is not allowed")

            if other.unit != self.unit:
                other = other.to(self.unit)

        return self.__class__(np.array(self) - np.array(other), unit_type=self.unit_type, unit=self.unit)

    def __mul__(self, other):
        # If other has no units, the result keeps the unit and unit type of
        # self. If other *has* units, the result carries the product unit but
        # no unit type, since the type of a derived quantity is not tracked.
        # __rmul__ and __truediv__ follow the same protocol.
        # Arithmetic is done on bare numpy arrays so numpy does not wrap the
        # intermediate result in this class.
        if not hasattr(other, "unit_type"):
            return self.__class__(
                np.array(self).__mul__(np.array(other)),
                unit_type=self._unit_type,
                unit=self._unit,
            )
        return self.__class__(np.array(self).__mul__(np.array(other)), unit=self.unit * other.unit)

    def __rmul__(self, other):
        # pylint: disable=E1101
        if not hasattr(other, "unit_type"):
            return self.__class__(
                np.array(self).__rmul__(np.array(other)),
                unit_type=self._unit_type,
                unit=self._unit,
            )
        return self.__class__(np.array(self).__rmul__(np.array(other)), unit=self.unit * other.unit)

    def __div__(self, other):
        # Python 2 spelling of division. ndarray has no __div__ on Python 3,
        # so delegate to __truediv__ instead of calling it.
        return self.__truediv__(other)

    def __truediv__(self, other):
        # pylint: disable=E1101
        if not hasattr(other, "unit_type"):
            return self.__class__(
                np.array(self).__truediv__(np.array(other)),
                unit_type=self._unit_type,
                unit=self._unit,
            )
        return self.__class__(np.array(self).__truediv__(np.array(other)), unit=self.unit / other.unit)

    def __neg__(self):
        return self.__class__(np.array(self).__neg__(), unit_type=self.unit_type, unit=self.unit)

    def to(self, new_unit):
        """
        Conversion to a new_unit.

        Args:
            new_unit:
                New unit type.

        Returns:
            An ArrayWithUnit object in the new units.

        Example usage:
        >>> e = EnergyArray([1, 1.1], "Ha")
        >>> print(e.to("eV"))  # doctest: +ELLIPSIS
        [...27.21... 29.93...] eV
        """
        return self.__class__(
            np.array(self) * self.unit.get_conversion_factor(new_unit),
            unit_type=self.unit_type,
            unit=new_unit,
        )

    @property
    def as_base_units(self):
        """
        Returns this ArrayWithUnit in base SI units, including derived units.

        Returns:
            An ArrayWithUnit object in base SI units
        """
        return self.to(self.unit.as_base_units[0])

    # TODO abstract base class property?
    @property
    def supported_units(self):
        """
        Supported units for specific unit type (the ALL_UNITS entry, a dict of
        unit name -> definition).
        """
        return ALL_UNITS[self.unit_type]

    # TODO abstract base class method?
    def conversions(self):
        """
        Returns a string showing the available conversions.
        Useful tool in interactive mode.
        """
        return "\n".join(str(self.to(unit)) for unit in self.supported_units)
690
691
def _my_partial(func, *args, **kwargs):
    """
    Build a functools.partial of ``func`` that also exposes ``from_string``.

    A bare partial object does not inherit the class methods of FloatWithUnit,
    so FloatWithUnit.from_string is attached to the partial as an attribute
    before it is returned.
    """
    bound = partial(func, *args, **kwargs)
    setattr(bound, "from_string", FloatWithUnit.from_string)
    return bound
702
703
# Scalar factories are built with _my_partial (not a bare partial) so that
# every one of them, not just Memory, offers ``from_string``. Note that
# from_string infers the unit type from the unit in the string; it does not
# use the factory's unit_type.
Energy = _my_partial(FloatWithUnit, unit_type="energy")
"""
A float with an energy unit.

Args:
    val (float): Value
    unit (Unit): E.g., eV, kJ, etc. Must be valid unit or UnitError is raised.
"""
# An ndarray with an energy unit.
EnergyArray = partial(ArrayWithUnit, unit_type="energy")

Length = _my_partial(FloatWithUnit, unit_type="length")
"""
A float with a length unit.

Args:
    val (float): Value
    unit (Unit): E.g., m, ang, bohr, etc. Must be valid unit or UnitError is
        raised.
"""
# An ndarray with a length unit.
LengthArray = partial(ArrayWithUnit, unit_type="length")

Mass = _my_partial(FloatWithUnit, unit_type="mass")
"""
A float with a mass unit.

Args:
    val (float): Value
    unit (Unit): E.g., amu, kg, etc. Must be valid unit or UnitError is
        raised.
"""
# An ndarray with a mass unit.
MassArray = partial(ArrayWithUnit, unit_type="mass")

Temp = _my_partial(FloatWithUnit, unit_type="temperature")
"""
A float with a temperature unit.

Args:
    val (float): Value
    unit (Unit): E.g., K. Only K (kelvin) is supported.
"""
# An ndarray with a temperature unit.
TempArray = partial(ArrayWithUnit, unit_type="temperature")

Time = _my_partial(FloatWithUnit, unit_type="time")
"""
A float with a time unit.

Args:
    val (float): Value
    unit (Unit): E.g., s, min, h. Must be valid unit or UnitError is
        raised.
"""
# An ndarray with a time unit.
TimeArray = partial(ArrayWithUnit, unit_type="time")

Charge = _my_partial(FloatWithUnit, unit_type="charge")
"""
A float with a charge unit.

Args:
    val (float): Value
    unit (Unit): E.g., C, e (electron charge). Must be valid unit or UnitError
        is raised.
"""
# An ndarray with a charge unit.
ChargeArray = partial(ArrayWithUnit, unit_type="charge")

Memory = _my_partial(FloatWithUnit, unit_type="memory")
"""
A float with a memory unit.

Args:
    val (float): Value
    unit (Unit): E.g., Kb, Mb, Gb, Tb. Must be valid unit or UnitError
        is raised.
"""
777
778
def obj_with_unit(obj, unit):
    """
    Returns a `FloatWithUnit` instance if obj is scalar, a dictionary of
    objects with units if obj is a mapping, else an instance of
    `ArrayWithUnit`.

    Args:
        obj: A number, a mapping (its values are converted recursively and a
            plain dict is returned), or any other object, which is passed to
            ArrayWithUnit.
        unit: Specific units (eV, Ha, m, ang, etc.).

    Raises:
        KeyError: if ``unit`` is not a single supported unit name.
    """
    unit_type = _UNAME2UTYPE[unit]

    if isinstance(obj, numbers.Number):
        return FloatWithUnit(obj, unit=unit, unit_type=unit_type)
    # The Mapping ABC lives in collections.abc; the old ``collections.Mapping``
    # alias was removed in Python 3.10 and raised AttributeError here.
    if isinstance(obj, collections.abc.Mapping):
        return {k: obj_with_unit(v, unit) for k, v in obj.items()}
    return ArrayWithUnit(obj, unit=unit, unit_type=unit_type)
795
796
def unitized(unit):
    """
    Useful decorator to assign units to the output of a function. You can also
    use it to standardize the output units of a function that already returns
    a FloatWithUnit or ArrayWithUnit. For sequences, all values in the sequences
    are assigned the same unit. It works with Python sequences only. The creation
    of numpy arrays loses all unit information. For mapping types, the values
    are assigned units.

    The wrapper preserves the decorated function's metadata (``__name__``,
    ``__doc__``, ``__wrapped__``, ...) via functools.wraps, so the
    documentation of decorated functions and properties is not lost.

    Args:
        unit: Specific unit (eV, Ha, m, ang, etc.).

    Returns:
        A decorator. The decorated function's result is handled as follows:
        FloatWithUnit/ArrayWithUnit results are converted to ``unit``;
        numbers become FloatWithUnit; sequences are rebuilt with the same
        type and FloatWithUnit items; mapping values are replaced in place;
        None is passed through. Anything else raises TypeError.

    Example usage::

        @unitized(unit="kg")
        def get_mass():
            return 123.45

    """
    from functools import wraps

    def wrap(f):
        @wraps(f)
        def wrapped_f(*args, **kwargs):
            val = f(*args, **kwargs)
            unit_type = _UNAME2UTYPE[unit]

            if isinstance(val, (FloatWithUnit, ArrayWithUnit)):
                return val.to(unit)

            if isinstance(val, collections.abc.Sequence):
                # TODO: why don't we return a ArrayWithUnit?
                # This complicated way is to ensure the sequence type is
                # preserved (list or tuple).
                return val.__class__([FloatWithUnit(i, unit_type=unit_type, unit=unit) for i in val])
            if isinstance(val, collections.abc.Mapping):
                for k, v in val.items():
                    val[k] = FloatWithUnit(v, unit_type=unit_type, unit=unit)
            elif isinstance(val, numbers.Number):
                return FloatWithUnit(val, unit_type=unit_type, unit=unit)
            elif val is None:
                pass
            else:
                raise TypeError("Don't know how to assign units to %s" % str(val))
            return val

        return wrapped_f

    return wrap
844
845
if __name__ == "__main__":
    # Run the examples embedded in this module's docstrings.
    from doctest import testmod

    testmod()
850