1# -*- coding: utf-8 -*-
2# Licensed under a 3-clause BSD style license - see LICENSE.rst
3import numpy as np
4
5from astropy.units import (dimensionless_unscaled, photometric, Unit,
6                           CompositeUnit, UnitsError, UnitTypeError,
7                           UnitConversionError)
8
9from .core import FunctionUnitBase, FunctionQuantity
10from .units import dex, dB, mag
11
12
__all__ = [
    # logarithmic function units
    'LogUnit',
    'MagUnit',
    'DexUnit',
    'DecibelUnit',
    # quantities carrying those units
    'LogQuantity',
    'Magnitude',
    'Decibel',
    'Dex',
    # predefined magnitude units
    'STmag',
    'ABmag',
    'M_bol',
    'm_bol',
]
16
17
class LogUnit(FunctionUnitBase):
    """Logarithmic unit containing a physical one

    Usually, logarithmic units are instantiated via specific subclasses
    such as `MagUnit`, `DecibelUnit`, and `DexUnit`.

    Parameters
    ----------
    physical_unit : `~astropy.units.Unit` or `string`
        Unit that is encapsulated within the logarithmic function unit.
        If not given, dimensionless.

    function_unit :  `~astropy.units.Unit` or `string`
        By default, the same as the logarithmic unit set by the subclass.

    """
    # the four essential overrides of FunctionUnitBase
    @property
    def _default_function_unit(self):
        # Plain logarithmic unit used when no function unit is given;
        # subclasses override this with their own (mag, dB, ...).
        return dex

    @property
    def _quantity_class(self):
        # Quantity class used for values carrying this unit.
        return LogQuantity

    def from_physical(self, x):
        """Transformation from value in physical to value in logarithmic units.
        Used in equivalency.

        ``log10(x)`` is a value in dex, which is then rescaled to this
        unit's function unit.
        """
        return dex.to(self._function_unit, np.log10(x))

    def to_physical(self, x):
        """Transformation from value in logarithmic to value in physical units.
        Used in equivalency.

        Inverse of `from_physical`: rescale ``x`` to dex, then raise 10 to
        that power.
        """
        return 10 ** self._function_unit.to(dex, x)
    # ^^^^ the four essential overrides of FunctionUnitBase

    # add addition and subtraction, which imply multiplication/division of
    # the underlying physical units
    def _add_and_adjust_physical_unit(self, other, sign_self, sign_other):
        """Add/subtract LogUnit to/from another unit, and adjust physical unit.

        self and other are multiplied by sign_self and sign_other, resp.

        We wish to do:   ±lu_1 + ±lu_2  -> lu_f          (lu=logarithmic unit)
                  and     pu_1^(±1) * pu_2^(±1) -> pu_f  (pu=physical unit)

        Returns
        -------
        unit or NotImplemented
            A copy of ``self`` carrying the combined physical unit, or
            `NotImplemented` if ``other`` is not a unit at all.

        Raises
        ------
        UnitsError
            If function units are not equivalent.
        """
        # First, insist on compatible logarithmic type. Here, plain u.mag,
        # u.dex, and u.dB are OK, i.e., other does not have to be LogUnit
        # (this will indirectly test whether other is a unit at all).
        try:
            getattr(other, 'function_unit', other)._to(self._function_unit)
        except AttributeError:
            # if other is not a unit (i.e., does not have _to).
            return NotImplemented
        except UnitsError:
            raise UnitsError("Can only add/subtract logarithmic units of "
                             "compatible type.")

        # Plain function units (u.mag, u.dex, u.dB) have no physical unit;
        # treat them as dimensionless.
        other_physical_unit = getattr(other, 'physical_unit',
                                      dimensionless_unscaled)
        # pu_f = pu_self**sign_self * pu_other**sign_other
        physical_unit = CompositeUnit(
            1, [self._physical_unit, other_physical_unit],
            [sign_self, sign_other])

        return self._copy(physical_unit)

    def __neg__(self):
        # Negating a logarithm corresponds to inverting the physical unit.
        return self._copy(self.physical_unit**(-1))

    def __add__(self, other):
        # Only know how to add to a logarithmic unit with compatible type,
        # be it a plain one (u.mag, etc.,) or another LogUnit
        return self._add_and_adjust_physical_unit(other, +1, +1)

    def __radd__(self, other):
        # Addition is symmetric in the physical units.
        return self._add_and_adjust_physical_unit(other, +1, +1)

    def __sub__(self, other):
        # self - other: divide by other's physical unit.
        return self._add_and_adjust_physical_unit(other, +1, -1)

    def __rsub__(self, other):
        # here, in normal usage other cannot be LogUnit; only equivalent one
        # would be u.mag, u.dB, u.dex.  But might as well use common routine.
        return self._add_and_adjust_physical_unit(other, -1, +1)
107
108
class MagUnit(LogUnit):
    """Logarithmic physical units expressed in magnitudes

    Parameters
    ----------
    physical_unit : `~astropy.units.Unit` or `string`
        Physical unit wrapped by the magnitude function unit.
        Dimensionless when omitted.

    function_unit :  `~astropy.units.Unit` or `string`
        Defaults to ``mag``; any equivalent unit, such as ``2 mag``,
        may be supplied instead.
    """

    # Values carrying this unit are represented as `Magnitude` quantities.
    @property
    def _quantity_class(self):
        return Magnitude

    # Without an explicit function unit, plain ``mag`` is used.
    @property
    def _default_function_unit(self):
        return mag
129
130
class DexUnit(LogUnit):
    """Logarithmic physical units expressed in dex

    Parameters
    ----------
    physical_unit : `~astropy.units.Unit` or `string`
        Unit that is encapsulated within the dex function unit.
        If not given, dimensionless.

    function_unit :  `~astropy.units.Unit` or `string`
        By default, this is ``dex``, but this allows one to use an equivalent
        unit such as ``0.5 dex``.
    """

    @property
    def _default_function_unit(self):
        return dex

    @property
    def _quantity_class(self):
        return Dex

    def to_string(self, format='generic'):
        """Output the unit as a string in the given format.

        For the ``'cds'`` format, dex units are written as the physical unit
        in square brackets (``[-]`` when dimensionless).  Every other format
        is handled by the base class, which receives the requested format.
        """
        if format == 'cds':
            if self.physical_unit == dimensionless_unscaled:
                return "[-]"  # by default, would get "[---]".
            return f"[{self.physical_unit.to_string(format=format)}]"
        # Forward the format; dropping it would silently fall back to the
        # generic representation (e.g. for 'latex').
        return super().to_string(format)
161
162
class DecibelUnit(LogUnit):
    """Logarithmic physical units expressed in dB

    Parameters
    ----------
    physical_unit : `~astropy.units.Unit` or `string`
        Physical unit wrapped by the decibel function unit.
        Dimensionless when omitted.

    function_unit :  `~astropy.units.Unit` or `string`
        Defaults to ``dB``; any equivalent unit, such as ``2 dB``,
        may be supplied instead.
    """

    # Values carrying this unit are represented as `Decibel` quantities.
    @property
    def _quantity_class(self):
        return Decibel

    # Without an explicit function unit, plain ``dB`` is used.
    @property
    def _default_function_unit(self):
        return dB
184
185
class LogQuantity(FunctionQuantity):
    """A representation of a (scaled) logarithm of a number with a unit

    Parameters
    ----------
    value : number, `~astropy.units.Quantity`, `~astropy.units.function.logarithmic.LogQuantity`, or sequence of quantity-like.
        The numerical value of the logarithmic quantity. If a number or
        a `~astropy.units.Quantity` with a logarithmic unit, it will be
        converted to ``unit`` and the physical unit will be inferred from
        ``unit``.  If a `~astropy.units.Quantity` with just a physical unit,
        it will be converted to the logarithmic unit, after, if necessary,
        converting it to the physical unit inferred from ``unit``.

    unit : str, `~astropy.units.UnitBase`, or `~astropy.units.function.FunctionUnitBase`, optional
        For an `~astropy.units.function.FunctionUnitBase` instance, the
        physical unit will be taken from it; for other input, it will be
        inferred from ``value``. By default, ``unit`` is set by the subclass.

    dtype : `~numpy.dtype`, optional
        The ``dtype`` of the resulting Numpy array or scalar that will
        hold the value.  If not provided, it is determined automatically
        from the input value.

    copy : bool, optional
        If `True` (default), then the value is copied.  Otherwise, a copy will
        only be made if ``__array__`` returns a copy, if value is a nested
        sequence, or if a copy is needed to satisfy an explicitly given
        ``dtype``.  (The `False` option is intended mostly for internal use,
        to speed up initialization where a copy is known to have been made.
        Use with care.)

    Examples
    --------
    Typically, use is made of an `~astropy.units.function.FunctionQuantity`
    subclass, as in::

        >>> import astropy.units as u
        >>> u.Magnitude(-2.5)
        <Magnitude -2.5 mag>
        >>> u.Magnitude(10.*u.count/u.second)
        <Magnitude -2.5 mag(ct / s)>
        >>> u.Decibel(1.*u.W, u.DecibelUnit(u.mW))  # doctest: +FLOAT_CMP
        <Decibel 30. dB(mW)>

    """
    # only override of FunctionQuantity
    _unit_class = LogUnit

    # additions that work just for logarithmic units
    def __add__(self, other):
        # Add function units, thus multiplying physical units. If no unit is
        # given, assume dimensionless_unscaled; this will give the appropriate
        # exception in LogUnit.__add__.
        new_unit = self.unit + getattr(other, 'unit', dimensionless_unscaled)
        # Add actual logarithmic values, rescaling, e.g., dB -> dex.
        result = self._function_view + getattr(other, '_function_view', other)
        return self._new_view(result, new_unit)

    def __radd__(self, other):
        # Addition of logarithmic quantities is commutative.
        return self.__add__(other)

    def __iadd__(self, other):
        # Determine the new unit first, so an incompatible ``other`` fails
        # before any values are modified.
        new_unit = self.unit + getattr(other, 'unit', dimensionless_unscaled)
        # Do calculation in-place using _function_view of array.
        function_view = self._function_view
        function_view += getattr(other, '_function_view', other)
        self._set_unit(new_unit)
        return self

    def __sub__(self, other):
        # Subtract function units, thus dividing physical units.
        new_unit = self.unit - getattr(other, 'unit', dimensionless_unscaled)
        # Subtract actual logarithmic values, rescaling, e.g., dB -> dex.
        result = self._function_view - getattr(other, '_function_view', other)
        return self._new_view(result, new_unit)

    def __rsub__(self, other):
        # other - self: the unit arithmetic is delegated to LogUnit.__rsub__.
        new_unit = self.unit.__rsub__(
            getattr(other, 'unit', dimensionless_unscaled))
        result = self._function_view.__rsub__(
            getattr(other, '_function_view', other))
        # Ensure the result is in right function unit scale
        # (with rsub, this does not have to be one's own).
        result = result.to(new_unit.function_unit)
        return self._new_view(result, new_unit)

    def __isub__(self, other):
        # Determine the new unit first, so an incompatible ``other`` fails
        # before any values are modified.
        new_unit = self.unit - getattr(other, 'unit', dimensionless_unscaled)
        # Do calculation in-place using _function_view of array.
        function_view = self._function_view
        function_view -= getattr(other, '_function_view', other)
        self._set_unit(new_unit)
        return self

    def __pow__(self, other):
        # We check if this power is OK by applying it first to the unit.
        # Only scalar exponents convertible to float are supported; anything
        # else is left to Python's operator fallback.
        try:
            other = float(other)
        except TypeError:
            return NotImplemented
        new_unit = self.unit ** other
        new_value = self.view(np.ndarray) ** other
        return self._new_view(new_value, new_unit)

    def __ilshift__(self, other):
        # In-place unit change (``q <<= unit``).  Returning NotImplemented
        # lets Python fall back to the non-in-place operator.
        try:
            other = Unit(other)
        except UnitTypeError:
            return NotImplemented

        if not isinstance(other, self._unit_class):
            return NotImplemented

        try:
            factor = self.unit.physical_unit._to(other.physical_unit)
        except UnitConversionError:
            # Maybe via equivalencies?  Now we do make a temporary copy.
            try:
                value = self._to_value(other)
            except UnitConversionError:
                return NotImplemented

            self.view(np.ndarray)[...] = value
        else:
            # A physical scale factor is an additive offset in log space.
            # NOTE(review): the offset is expressed in self's function unit;
            # if ``other`` has a different function unit (e.g. ``2 mag``),
            # the values are not rescaled here -- confirm intended.
            self.view(np.ndarray)[...] += self.unit.from_physical(factor)

        self._set_unit(other)
        return self

    # Could add __mul__ and __div__ and try interpreting other as a power,
    # but this seems just too error-prone.

    # Methods that do not work for function units generally but are OK for
    # logarithmic units as they imply differences and independence of
    # physical unit.
    def var(self, axis=None, dtype=None, out=None, ddof=0):
        # The variance is in the function unit squared (a plain unit), not a
        # logarithmic unit.
        return self._wrap_function(np.var, axis, dtype, out=out, ddof=ddof,
                                   unit=self.unit.function_unit**2)

    def std(self, axis=None, dtype=None, out=None, ddof=0):
        # Spreads of logarithms do not depend on the physical unit, so the
        # result keeps only the function unit (dimensionless physical unit).
        return self._wrap_function(np.std, axis, dtype, out=out, ddof=ddof,
                                   unit=self.unit._copy(dimensionless_unscaled))

    def ptp(self, axis=None, out=None):
        # Peak-to-peak is a difference of logarithms: physical unit drops out.
        return self._wrap_function(np.ptp, axis, out=out,
                                   unit=self.unit._copy(dimensionless_unscaled))

    def diff(self, n=1, axis=-1):
        # Differences of logarithms: physical unit drops out.
        return self._wrap_function(np.diff, n, axis,
                                   unit=self.unit._copy(dimensionless_unscaled))

    def ediff1d(self, to_end=None, to_begin=None):
        # Differences of logarithms: physical unit drops out.
        return self._wrap_function(np.ediff1d, to_end, to_begin,
                                   unit=self.unit._copy(dimensionless_unscaled))

    # The numpy functions implemented by the methods above are additionally
    # supported for logarithmic quantities.
    _supported_functions = (FunctionQuantity._supported_functions |
                            {getattr(np, function) for function in
                             ('var', 'std', 'ptp', 'diff', 'ediff1d')})
344
345
class Dex(LogQuantity):
    """Logarithmic quantity whose unit is a `DexUnit`.

    See `LogQuantity` for the construction parameters.
    """
    _unit_class = DexUnit
348
349
class Decibel(LogQuantity):
    """Logarithmic quantity whose unit is a `DecibelUnit`.

    See `LogQuantity` for the construction parameters.
    """
    _unit_class = DecibelUnit
352
353
class Magnitude(LogQuantity):
    """Logarithmic quantity whose unit is a `MagUnit`.

    See `LogQuantity` for the construction parameters.
    """
    _unit_class = MagUnit
356
357
# Attach to each plain logarithmic unit the LogUnit subclass it corresponds
# to (presumably consulted by the function-unit machinery in .units/.core
# when building e.g. ``mag(physical_unit)`` -- defined outside this file).
for _plain_unit, _log_unit_class in ((dex, DexUnit),
                                     (dB, DecibelUnit),
                                     (mag, MagUnit)):
    _plain_unit._function_unit_class = _log_unit_class
del _plain_unit, _log_unit_class
361
362
# Predefined magnitude units, each wrapping a reference unit from
# `astropy.units.photometric`.
STmag = MagUnit(photometric.STflux)
ABmag = MagUnit(photometric.ABflux)
M_bol = MagUnit(photometric.Bol)
m_bol = MagUnit(photometric.bol)

STmag.__doc__ = "ST magnitude: STmag=-21.1 corresponds to 1 erg/s/cm2/A"
ABmag.__doc__ = "AB magnitude: ABmag=-48.6 corresponds to 1 erg/s/cm2/Hz"
M_bol.__doc__ = ("Absolute bolometric magnitude: M_bol=0 corresponds to "
                 f"L_bol0={photometric.Bol.si}")
m_bol.__doc__ = ("Apparent bolometric magnitude: m_bol=0 corresponds to "
                 f"f_bol0={photometric.bol.si}")
376