1# Licensed under a 3-clause BSD style license - see LICENSE.rst
2# This module implements the Arithmetic mixin to the NDData class.
3
4
5from copy import deepcopy
6
7import numpy as np
8
9from astropy.nddata.nduncertainty import NDUncertainty
10from astropy.units import dimensionless_unscaled
11from astropy.utils import format_doc, sharedmethod
12
__all__ = ["NDArithmeticMixin"]
14
15# Global so it doesn't pollute the class dict unnecessarily:
16
17# Docstring templates for add, subtract, multiply, divide methods.
_arit_doc = """
    Performs {name} by evaluating ``self`` {op} ``operand``.

    Parameters
    ----------
    operand, operand2 : `NDData`-like instance
        If ``operand2`` is ``None`` or not given it will perform the operation
        ``self`` {op} ``operand``.
        If ``operand2`` is given it will perform ``operand`` {op} ``operand2``.
        If the method was called on a class rather than on the instance
        ``operand2`` must be given.

    propagate_uncertainties : `bool` or ``None``, optional
        If ``None`` the result will have no uncertainty. If ``False`` the
        result will have a copy of the uncertainty of the first operand that
        has an uncertainty. If ``True`` the result will have a correctly
        propagated uncertainty from the uncertainties of the operands but this
        assumes that the uncertainties are `NDUncertainty`-like. Default is
        ``True``.

        .. versionchanged:: 1.2
            This parameter must be given as keyword-parameter. Using it as
            positional parameter is deprecated.
            ``None`` was added as valid parameter value.

    handle_mask : callable, ``'first_found'`` or ``None``, optional
        If ``None`` the result will have no mask. If ``'first_found'`` the
        result will have a copy of the mask of the first operand that has a
        mask. If it is a callable then the specified callable must
        create the result's ``mask`` and if necessary provide a copy.
        Default is `numpy.logical_or`.

        .. versionadded:: 1.2

    handle_meta : callable, ``'first_found'`` or ``None``, optional
        If ``None`` the result will have no meta. If ``'first_found'`` the
        result will have a copy of the meta of the first operand that has a
        (not empty) meta. If it is a callable then the specified callable must
        create the result's ``meta`` and if necessary provide a copy.
        Default is ``None``.

        .. versionadded:: 1.2

    compare_wcs : callable, ``'first_found'`` or ``None``, optional
        If ``None`` the result will have no wcs and no comparison between
        the wcs of the operands is made. If ``'first_found'`` the
        result will have a copy of the wcs of the first operand that has a
        wcs. If it is a callable then the specified callable must
        compare the ``wcs`` of both operands: if the comparison is
        successful the result will have a copy of the ``wcs`` of the first
        operand, otherwise a ``ValueError`` is raised.
        Default is ``'first_found'``.

        .. versionadded:: 1.2

    uncertainty_correlation : number or `~numpy.ndarray`, optional
        The correlation between the two operands is used for correct error
        propagation for correlated data as given in:
        https://en.wikipedia.org/wiki/Propagation_of_uncertainty#Example_formulas
        Default is 0.

        .. versionadded:: 1.2

    kwargs :
        Any other parameter that should be passed to the callables used.
        Each name must start with one of the prefixes described in the Notes.

    Returns
    -------
    result : `~astropy.nddata.NDData`-like
        The resulting dataset

    Notes
    -----
    If a ``callable`` is used for ``mask``, ``wcs`` or ``meta`` the
    callable must accept the corresponding attributes as first two
    parameters. If the callable also needs additional parameters these can be
    defined as ``kwargs`` and must start with ``"mask_"`` (for mask callable),
    ``"wcs_"`` (for wcs callable) or ``"meta_"`` (for meta callable). This
    prefix is removed before the callable is called.

    ``"first_found"`` can also be abbreviated with ``"ff"``.
    """
99
100
class NDArithmeticMixin:
    """
    Mixin class to add arithmetic to an NDData object.

    When subclassing, be sure to list the superclasses in the correct order
    so that the subclass sees NDData as the main superclass. See
    `~astropy.nddata.NDDataArray` for an example.

    Notes
    -----
    This class only aims at covering the most common cases so there are certain
    restrictions on the saved attributes:

    - ``uncertainty`` : has to be something that has a `NDUncertainty`-like
      interface for uncertainty propagation
    - ``mask`` : has to be something that can be combined by the
      ``handle_mask`` callable (by default `numpy.logical_or`).
    - ``wcs`` : is only compared if a callable is given as ``compare_wcs``;
      that callable decides whether the two ``wcs`` allow the operation.

    But there is a workaround that allows disabling the handling of a specific
    attribute and to simply set the result's attribute to ``None`` or to
    copy the existing attribute (ignoring the other one).
    For example for uncertainties not representing an `NDUncertainty`-like
    interface you can alter the ``propagate_uncertainties`` parameter in
    :meth:`NDArithmeticMixin.add`. ``None`` means that the result will have no
    uncertainty, ``False`` means it takes the uncertainty of the first operand
    (if this does not exist from the second operand) as the result's
    uncertainty. This behavior is also explained in the docstring for the
    different arithmetic operations.

    Decomposing the units is not attempted, mainly due to the internal mechanics
    of `~astropy.units.Quantity`, so the resulting data might have units like
    ``km/m`` if you divided for example 100km by 5m. So this Mixin has adopted
    this behavior.

    Examples
    --------
    Using this Mixin with `~astropy.nddata.NDData`:

        >>> from astropy.nddata import NDData, NDArithmeticMixin
        >>> class NDDataWithMath(NDArithmeticMixin, NDData):
        ...     pass

    Using it with one operand on an instance::

        >>> ndd = NDDataWithMath(100)
        >>> ndd.add(20)
        NDDataWithMath(120)

    Using it with two operands on an instance::

        >>> ndd = NDDataWithMath(-4)
        >>> ndd.divide(1, ndd)
        NDDataWithMath(-0.25)

    Using it as classmethod requires two operands::

        >>> NDDataWithMath.subtract(5, 4)
        NDDataWithMath(1)

    """

    def _arithmetic(self, operation, operand,
                    propagate_uncertainties=True, handle_mask=np.logical_or,
                    handle_meta=None, uncertainty_correlation=0,
                    compare_wcs='first_found', **kwds):
        """
        Base method which calculates the result of the arithmetic operation.

        This method determines the result of the arithmetic operation on the
        ``data`` including their units and then forwards to other methods
        to calculate the other properties for the result (like uncertainty).

        Parameters
        ----------
        operation : callable
            The operation that is performed on the `NDData`. Supported are
            `numpy.add`, `numpy.subtract`, `numpy.multiply` and
            `numpy.true_divide`.

        operand : same type (class) as self
            see :meth:`NDArithmeticMixin.add`

        propagate_uncertainties : `bool` or ``None``, optional
            see :meth:`NDArithmeticMixin.add`

        handle_mask : callable, ``'first_found'`` or ``None``, optional
            see :meth:`NDArithmeticMixin.add`

        handle_meta : callable, ``'first_found'`` or ``None``, optional
            see :meth:`NDArithmeticMixin.add`

        compare_wcs : callable, ``'first_found'`` or ``None``, optional
            see :meth:`NDArithmeticMixin.add`

        uncertainty_correlation : ``Number`` or `~numpy.ndarray`, optional
            see :meth:`NDArithmeticMixin.add`

        kwargs :
            Any other parameter that should be passed to the
            different :meth:`NDArithmeticMixin._arithmetic_mask` (or wcs, ...)
            methods. Each name must start with ``mask_``, ``meta_``,
            ``wcs_``, ``data_`` or ``uncertainty_``; the prefix selects the
            target method and is stripped before forwarding.

        Returns
        -------
        result : ndarray or `~astropy.units.Quantity`
            The resulting data as array (in case both operands were without
            unit) or as quantity if at least one had a unit.

        kwargs : `dict`
            The kwargs should contain all the other attributes (besides data
            and unit) needed to create a new instance for the result. Creating
            the new instance is up to the calling method, for example
            :meth:`NDArithmeticMixin.add`.

        Raises
        ------
        KeyError
            If a name in ``kwargs`` does not start with one of the recognized
            prefixes (including names that contain no underscore at all).
        """
        # Route each extra keyword to the helper named by its prefix, e.g.
        # ``mask_foo=1`` is forwarded as ``foo=1`` to ``_arithmetic_mask``.
        # (Not sure if data and uncertainty are ever used ...)
        kwds2 = {'mask': {}, 'meta': {}, 'wcs': {},
                 'data': {}, 'uncertainty': {}}
        for key, value in kwds.items():
            prefix, sep, name = key.partition('_')
            # A key without any underscore has no prefix at all; report it the
            # same way as an unknown prefix instead of an opaque IndexError.
            if not sep or prefix not in kwds2:
                raise KeyError(f'Unknown prefix {prefix} for parameter {key}')
            kwds2[prefix][name] = value

        kwargs = {}

        # First check that the WCS allows the arithmetic operation
        if compare_wcs is None:
            kwargs['wcs'] = None
        elif compare_wcs in ['ff', 'first_found']:
            if self.wcs is None:
                kwargs['wcs'] = deepcopy(operand.wcs)
            else:
                kwargs['wcs'] = deepcopy(self.wcs)
        else:
            kwargs['wcs'] = self._arithmetic_wcs(operation, operand,
                                                 compare_wcs, **kwds2['wcs'])

        # Then calculate the resulting data (which can but not needs to be a
        # quantity)
        result = self._arithmetic_data(operation, operand, **kwds2['data'])

        # Determine the other properties
        if propagate_uncertainties is None:
            kwargs['uncertainty'] = None
        elif not propagate_uncertainties:
            if self.uncertainty is None:
                kwargs['uncertainty'] = deepcopy(operand.uncertainty)
            else:
                kwargs['uncertainty'] = deepcopy(self.uncertainty)
        else:
            kwargs['uncertainty'] = self._arithmetic_uncertainty(
                operation, operand, result, uncertainty_correlation,
                **kwds2['uncertainty'])

        if handle_mask is None:
            kwargs['mask'] = None
        elif handle_mask in ['ff', 'first_found']:
            if self.mask is None:
                kwargs['mask'] = deepcopy(operand.mask)
            else:
                kwargs['mask'] = deepcopy(self.mask)
        else:
            kwargs['mask'] = self._arithmetic_mask(operation, operand,
                                                   handle_mask,
                                                   **kwds2['mask'])

        if handle_meta is None:
            kwargs['meta'] = None
        elif handle_meta in ['ff', 'first_found']:
            # ``meta`` is compared by truthiness, so an empty meta counts as
            # "not found".
            if not self.meta:
                kwargs['meta'] = deepcopy(operand.meta)
            else:
                kwargs['meta'] = deepcopy(self.meta)
        else:
            kwargs['meta'] = self._arithmetic_meta(
                operation, operand, handle_meta, **kwds2['meta'])

        # The caller wraps the data and these attributes into a new instance.
        return result, kwargs

    def _arithmetic_data(self, operation, operand, **kwds):
        """
        Calculate the resulting data.

        Parameters
        ----------
        operation : callable
            see :meth:`NDArithmeticMixin._arithmetic` parameter description.

        operand : `NDData`-like instance
            The second operand wrapped in an instance of the same class as
            self.

        kwds :
            Additional parameters (currently unused).

        Returns
        -------
        result_data : ndarray or `~astropy.units.Quantity`
            If both operands had no unit the resulting data is a simple numpy
            array, but if any of the operands had a unit the return is a
            Quantity.
        """

        # Do the calculation with or without units; an operand without unit is
        # treated as dimensionless when the other one has a unit.
        if self.unit is None and operand.unit is None:
            result = operation(self.data, operand.data)
        elif self.unit is None:
            result = operation(self.data << dimensionless_unscaled,
                               operand.data << operand.unit)
        elif operand.unit is None:
            result = operation(self.data << self.unit,
                               operand.data << dimensionless_unscaled)
        else:
            result = operation(self.data << self.unit,
                               operand.data << operand.unit)

        return result

    def _arithmetic_uncertainty(self, operation, operand, result, correlation,
                                **kwds):
        """
        Calculate the resulting uncertainty.

        Parameters
        ----------
        operation : callable
            see :meth:`NDArithmeticMixin._arithmetic` parameter description.

        operand : `NDData`-like instance
            The second operand wrapped in an instance of the same class as
            self.

        result : `~astropy.units.Quantity` or `~numpy.ndarray`
            The result of :meth:`NDArithmeticMixin._arithmetic_data`.

        correlation : number or `~numpy.ndarray`
            see :meth:`NDArithmeticMixin.add` parameter description.

        kwds :
            Additional parameters (currently unused).

        Returns
        -------
        result_uncertainty : `NDUncertainty` subclass instance or None
            The resulting uncertainty already saved in the same `NDUncertainty`
            subclass that ``self`` had (or ``operand`` if self had no
            uncertainty). ``None`` only if both had no uncertainty.

        Raises
        ------
        TypeError
            If an uncertainty that is present is not an `NDUncertainty`.

        Notes
        -----
        If only one operand has an uncertainty, an empty uncertainty of the
        same class is temporarily attached to the other operand. It is always
        removed again, even if the propagation raises.
        """

        # Make sure these uncertainties are NDUncertainties so this kind of
        # propagation is possible.
        if (self.uncertainty is not None and
                not isinstance(self.uncertainty, NDUncertainty)):
            raise TypeError("Uncertainty propagation is only defined for "
                            "subclasses of NDUncertainty.")
        if (operand.uncertainty is not None and
                not isinstance(operand.uncertainty, NDUncertainty)):
            raise TypeError("Uncertainty propagation is only defined for "
                            "subclasses of NDUncertainty.")

        # Now do the uncertainty propagation
        # TODO: There is no enforced requirement that actually forbids the
        # uncertainty to have negative entries but with correlation the
        # sign of the uncertainty DOES matter.
        if self.uncertainty is None and operand.uncertainty is None:
            # Neither has uncertainties so the result should have none.
            return None
        elif self.uncertainty is None:
            # Create a temporary uncertainty to allow uncertainty propagation
            # to yield the correct results. (issue #4152)
            self.uncertainty = operand.uncertainty.__class__(None)
            try:
                return self.uncertainty.propagate(operation, operand,
                                                  result, correlation)
            finally:
                # Delete the temporary uncertainty again, even on failure:
                # ``self`` may be the user's own instance.
                self.uncertainty = None

        elif operand.uncertainty is None:
            # As with self.uncertainty is None but the other way around.
            operand.uncertainty = self.uncertainty.__class__(None)
            try:
                return self.uncertainty.propagate(operation, operand,
                                                  result, correlation)
            finally:
                operand.uncertainty = None

        else:
            # Both have uncertainties so just propagate.
            return self.uncertainty.propagate(operation, operand, result,
                                              correlation)

    def _arithmetic_mask(self, operation, operand, handle_mask, **kwds):
        """
        Calculate the resulting mask.

        If both operands have a mask they are combined with ``handle_mask``
        (which :meth:`NDArithmeticMixin._arithmetic` defaults to
        `numpy.logical_or`).

        Parameters
        ----------
        operation : callable
            see :meth:`NDArithmeticMixin._arithmetic` parameter description.
            By default, the ``operation`` will be ignored.

        operand : `NDData`-like instance
            The second operand wrapped in an instance of the same class as
            self.

        handle_mask : callable
            see :meth:`NDArithmeticMixin.add`

        kwds :
            Additional parameters given to ``handle_mask``.

        Returns
        -------
        result_mask : any type
            If only one mask was present a copy of this mask is returned.
            If neither had a mask ``None`` is returned. Otherwise
            ``handle_mask`` must create (and copy) the returned mask.
        """

        # If only one mask is present we need not bother about any type checks
        if self.mask is None and operand.mask is None:
            return None
        elif self.mask is None:
            # Make a copy so there is no reference in the result.
            return deepcopy(operand.mask)
        elif operand.mask is None:
            return deepcopy(self.mask)
        else:
            # Now lets calculate the resulting mask (operation enforces copy)
            return handle_mask(self.mask, operand.mask, **kwds)

    def _arithmetic_wcs(self, operation, operand, compare_wcs, **kwds):
        """
        Calculate the resulting wcs.

        There is actually no calculation involved but it is a good place to
        compare wcs information of both operands. This is currently not working
        properly with `~astropy.wcs.WCS` (which is the suggested class for
        storing as wcs property) but it will not break it either.

        Parameters
        ----------
        operation : callable
            see :meth:`NDArithmeticMixin._arithmetic` parameter description.
            By default, the ``operation`` will be ignored.

        operand : `NDData` instance or subclass
            The second operand wrapped in an instance of the same class as
            self.

        compare_wcs : callable
            see :meth:`NDArithmeticMixin.add` parameter description.

        kwds :
            Additional parameters given to ``compare_wcs``.

        Raises
        ------
        ValueError
            If ``compare_wcs`` returns a falsy value (e.g. ``False``/``None``).

        Returns
        -------
        result_wcs : any type
            A copy of the ``wcs`` of the first operand.
        """

        # ok, not really arithmetics but we need to check which wcs makes sense
        # for the result and this is an ideal place to compare the two WCS,
        # too.

        # I'll assume that the comparison returned None or False in case they
        # are not equal.
        if not compare_wcs(self.wcs, operand.wcs, **kwds):
            raise ValueError("WCS are not equal.")

        return deepcopy(self.wcs)

    def _arithmetic_meta(self, operation, operand, handle_meta, **kwds):
        """
        Calculate the resulting meta.

        Parameters
        ----------
        operation : callable
            see :meth:`NDArithmeticMixin._arithmetic` parameter description.
            By default, the ``operation`` will be ignored.

        operand : `NDData`-like instance
            The second operand wrapped in an instance of the same class as
            self.

        handle_meta : callable
            see :meth:`NDArithmeticMixin.add`

        kwds :
            Additional parameters given to ``handle_meta``.

        Returns
        -------
        result_meta : any type
            The result of ``handle_meta``.
        """
        # Just return what handle_meta does with both of the metas.
        return handle_meta(self.meta, operand.meta, **kwds)

    @sharedmethod
    @format_doc(_arit_doc, name='addition', op='+')
    def add(self, operand, operand2=None, **kwargs):
        return self._prepare_then_do_arithmetic(np.add, operand, operand2,
                                                **kwargs)

    @sharedmethod
    @format_doc(_arit_doc, name='subtraction', op='-')
    def subtract(self, operand, operand2=None, **kwargs):
        return self._prepare_then_do_arithmetic(np.subtract, operand, operand2,
                                                **kwargs)

    @sharedmethod
    @format_doc(_arit_doc, name="multiplication", op="*")
    def multiply(self, operand, operand2=None, **kwargs):
        return self._prepare_then_do_arithmetic(np.multiply, operand, operand2,
                                                **kwargs)

    @sharedmethod
    @format_doc(_arit_doc, name="division", op="/")
    def divide(self, operand, operand2=None, **kwargs):
        return self._prepare_then_do_arithmetic(np.true_divide, operand,
                                                operand2, **kwargs)

    @sharedmethod
    def _prepare_then_do_arithmetic(self_or_cls, operation, operand, operand2,
                                    **kwargs):
        """Intermediate method called by public arithmetics (i.e. ``add``)
        before the processing method (``_arithmetic``) is invoked.

        .. warning::
            Do not override this method in subclasses.

        This method checks if it was called as instance or as class method and
        then wraps the operands and the result from ``_arithmetic`` in the
        appropriate subclass.

        Parameters
        ----------
        self_or_cls : instance or class
            ``sharedmethod`` behaves like a normal method if called on the
            instance (then this parameter is ``self``) but like a classmethod
            when called on the class (then this parameter is ``cls``).

        operation : callable
            The operation (normally a numpy-ufunc) that represents the
            appropriate action.

        operand, operand2, kwargs :
            See for example ``add``.

        Returns
        -------
        result : `~astropy.nddata.NDData`-like
            Depending how this method was called either ``self_or_cls``
            (called on class) or ``self_or_cls.__class__`` (called on instance)
            is the NDData-subclass that is used as wrapper for the result.

        Raises
        ------
        TypeError
            If called on the class without ``operand2``.
        """
        # DO NOT OVERRIDE THIS METHOD IN SUBCLASSES.

        if isinstance(self_or_cls, NDArithmeticMixin):
            # True means it was called on the instance, so self_or_cls is
            # a reference to self
            cls = self_or_cls.__class__

            if operand2 is None:
                # Only one operand was given. Set operand2 to operand and
                # operand to self so that we call the appropriate method of the
                # operand.
                operand2 = operand
                operand = self_or_cls
            else:
                # Convert the first operand to the class of this method.
                # This is important so that the ``_arithmetic`` of this class
                # is the one that is called later.
                operand = cls(operand)

        else:
            # It was used as classmethod so self_or_cls represents the cls
            cls = self_or_cls

            # It was called on the class so we expect two operands!
            if operand2 is None:
                raise TypeError("operand2 must be given when the method isn't "
                                "called on an instance.")

            # Convert to this class. See above comment why.
            operand = cls(operand)

        # At this point operand, operand2, kwargs and cls are determined.

        # Let's try to convert operand2 to the class of operand to allow for
        # arithmetic operations with numbers, lists, numpy arrays, numpy masked
        # arrays, astropy quantities, masked quantities and of other subclasses
        # of NDData.
        operand2 = cls(operand2)

        # Now call the _arithmetic method to do the arithmetics.
        result, init_kwds = operand._arithmetic(operation, operand2, **kwargs)

        # Return a new class based on the result
        return cls(result, **init_kwds)
617