1# Copyright 2018 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
import itertools
import numbers
from decimal import Decimal
from fractions import Fraction
from typing import Any, Union, Iterable

import numpy as np
import sympy
from typing_extensions import Protocol

from cirq._doc import doc_private
26
27
class SupportsApproximateEquality(Protocol):
    """Protocol for objects that know how to compare themselves approximately."""

    @doc_private
    def _approx_eq_(self, other: Any, *, atol: Union[int, float]) -> bool:
        """Decides whether `self` is approximately equal to `other`.

        Implementers supply their own notion of closeness, and may restrict
        it to the particular types of `other` they understand.

        Args:
            other: The object to compare against.
            atol: The minimum absolute tolerance. See the np.isclose()
                documentation for details. Passed by keyword only.

        Returns:
            True when the two objects are approximately equal and False when
            they are not. Implementations return NotImplemented when they do
            not know how to compare against the given type of `other`.
        """
48
49
def approx_eq(val: Any, other: Any, *, atol: Union[int, float] = 1e-8) -> bool:
    """Approximately compares two objects.

    The following rules are tried in order, and the first one that reaches a
    verdict decides the result:

     - If `val` implements the SupportsApproximateEquality protocol, its
       `_approx_eq_` method is invoked. Any result other than NotImplemented
       is returned as-is and takes precedence over all other checks.
     - Otherwise `other`'s `_approx_eq_` is tried with the arguments swapped,
       which keeps the comparison symmetric.
     - If `val` is a `numbers.Number`, `other` must be one too; otherwise the
       result is False. The two are compared with numpy.isclose() using
       `atol` and a relative tolerance of zero. For complex values this
       bounds the modulus of the difference. If numpy cannot compare the
       values, the remaining rules are tried.
     - If `val` is a string, it is compared to `other` exactly.
     - If either value is a sympy expression, the absolute value of the
       simplified difference must not exceed `atol`.
     - If both values are iterable, consecutive elements are compared
       recursively. The types of `val` and `other` need not match. They only
       need to be iterables of the same length and nested structure.
     - As a last resort, `val == other` is returned.

    Args:
        val: Source object for approximate comparison.
        other: Target object for approximate comparison.
        atol: The minimum absolute tolerance. See np.isclose() documentation for
              details. Defaults to 1e-8 which matches np.isclose() default
              absolute tolerance.

    Returns:
        True if objects are approximately equal, False otherwise.

    Raises:
        AttributeError: If a sympy expression is involved and the difference
            between `val` and `other` does not simplify to a number, so
            approximate equality cannot be decided.
    """

    # Check if val defines approximate equality via _approx_eq_. This takes
    # precedence over all other overloads. The protocol declares `atol` as
    # keyword-only, so it must be passed by name; passing it positionally
    # raises TypeError for every conforming implementation.
    approx_eq_getter = getattr(val, '_approx_eq_', None)
    if approx_eq_getter is not None:
        result = approx_eq_getter(other, atol=atol)
        if result is not NotImplemented:
            return result

    # The same for other to make approx_eq symmetric.
    other_approx_eq_getter = getattr(other, '_approx_eq_', None)
    if other_approx_eq_getter is not None:
        result = other_approx_eq_getter(val, atol=atol)
        if result is not NotImplemented:
            return result

    # Compare primitive types directly.
    if isinstance(val, numbers.Number):
        if not isinstance(other, numbers.Number):
            return False
        result = _isclose(val, other, atol=atol)
        # NotImplemented means numpy could not handle the types; fall through.
        if result is not NotImplemented:
            return result

    if isinstance(val, str):
        return val == other

    if isinstance(val, sympy.Basic) or isinstance(other, sympy.Basic):
        delta = sympy.Abs(other - val).simplify()
        # A symbolic (non-numeric) difference cannot be bounded by atol.
        if not delta.is_number:
            raise AttributeError(
                'Insufficient information to decide whether '
                'expressions are approximately equal '
                f'[{val}] vs [{other}]'
            )
        return sympy.LessThan(delta, atol) == sympy.true

    # If the values are iterable, try comparing recursively on items.
    if isinstance(val, Iterable) and isinstance(other, Iterable):
        return _approx_eq_iterables(val, other, atol=atol)

    # Last resort: exact equality.
    return val == other
119
120
121# pylint: enable=missing-raises-doc
def _approx_eq_iterables(val: Iterable, other: Iterable, *, atol: Union[int, float]) -> bool:
    """Iterates over arguments and calls approx_eq recursively.

    Types of `val` and `other` need not match each other. They only need to
    be iterables of the same length and nested structure; approx_eq() is
    called on each pair of consecutive elements of `val` and `other`.

    Args:
        val: Source for approximate comparison.
        other: Target for approximate comparison.
        atol: The minimum absolute tolerance, forwarded to approx_eq() for
              every element pair. See np.isclose() documentation for details.

    Returns:
        True if both iterables have the same length and every pair of
        elements is approximately equal, False otherwise.
    """

    # Private sentinel used to pad the shorter iterable; it can never be an
    # element of either input.
    missing = object()
    for item1, item2 in itertools.zip_longest(val, other, fillvalue=missing):
        # Decide a length mismatch here instead of handing the sentinel to
        # approx_eq(): element types such as sympy expressions, or custom
        # `_approx_eq_` / `__eq__` implementations, may raise on (or even
        # claim equality with) an unfamiliar object.
        if item1 is missing or item2 is missing:
            return False
        if not approx_eq(item1, item2, atol=atol):
            return False

    return True
161
162
def _isclose(a: Any, b: Any, *, atol: Union[int, float]) -> bool:
    """Compares two scalars with np.isclose using only an absolute tolerance.

    Fraction and Decimal inputs are converted to float64 first, because numpy
    would otherwise hold them in generic object arrays.

    Returns:
        Whether `a` and `b` are within `atol` of each other (rtol is 0), or
        NotImplemented when numpy raises TypeError for the given types.
    """

    def as_array(value: Any) -> np.ndarray:
        # Wrap the scalar in a one-element array, casting the standard-library
        # rational/decimal types to float so numpy can do arithmetic on them.
        wrapped = np.asarray([value])
        if isinstance(value, (Fraction, Decimal)):
            return wrapped.astype(np.float64)
        return wrapped

    lhs = as_array(a)
    rhs = as_array(b)

    # np.isclose raises TypeError for dtypes it cannot handle (e.g. where
    # np.isfinite is unsupported); report that as "cannot compare". The
    # result is converted to a plain bool rather than np.bool_.
    try:
        return bool(np.isclose(lhs, rhs, atol=atol, rtol=0.0)[0])
    except TypeError:
        return NotImplemented
181