# Copyright 2018 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Union, Iterable
from fractions import Fraction
from decimal import Decimal

import itertools
import numbers
import numpy as np
import sympy

from typing_extensions import Protocol

from cirq._doc import doc_private


class SupportsApproximateEquality(Protocol):
    """Object which can be compared approximately."""

    @doc_private
    def _approx_eq_(self, other: Any, *, atol: Union[int, float]) -> bool:
        """Approximate comparator.

        Types implementing this protocol define their own logic for approximate
        comparison with other types.

        Args:
            other: Target object for approximate comparison.
            atol: The minimum absolute tolerance. See np.isclose() documentation
                for details.

        Returns:
            True if objects are approximately equal, False otherwise. Returns
            NotImplemented when approximate equality is not implemented for
            given types.
        """


# pylint: disable=missing-raises-doc
def approx_eq(val: Any, other: Any, *, atol: Union[int, float] = 1e-8) -> bool:
    """Approximately compares two objects.

    If `val` implements the SupportsApproximateEquality protocol then its
    `_approx_eq_` method is invoked and takes precedence over all other checks.
    If it is missing or returns NotImplemented, `other._approx_eq_` is tried in
    the same way (with the arguments swapped). Otherwise the following checks
    are applied in order:
     - If `val` is a `numbers.Number`: the result is False when `other` is not
       a number, else the two are compared with np.isclose() using `atol` and
       a relative tolerance of zero. `Fraction` and `Decimal` values are
       converted to float first. If np.isclose() cannot handle the operands,
       the remaining checks below are tried.
     - If `val` is a `str`: exact equality.
     - If either argument is a sympy expression: the simplified absolute
       difference must be a number that is at most `atol`.
     - For `val` and `other` both iterable of the same length, consecutive
       elements are compared recursively. The types of `val` and `other` do
       not need to match each other. They just need to be iterable and have
       the same structure.
     - Last resort: exact equality (`val == other`).

    Args:
        val: Source object for approximate comparison.
        other: Target object for approximate comparison.
        atol: The minimum absolute tolerance. See np.isclose() documentation for
            details. Defaults to 1e-8 which matches np.isclose() default
            absolute tolerance.

    Returns:
        True if objects are approximately equal, False otherwise.

    Raises:
        AttributeError: If sympy expressions are compared and their difference
            does not simplify to a number, so no decision can be made.
    """

    # Check if val defines approximate equality via _approx_eq_. This takes
    # precedence over all other overloads. The protocol declares `atol` as
    # keyword-only, so it must be passed by name.
    approx_eq_getter = getattr(val, '_approx_eq_', None)
    if approx_eq_getter is not None:
        result = approx_eq_getter(other, atol=atol)
        if result is not NotImplemented:
            return result

    # The same for other to make approx_eq symmetric.
    other_approx_eq_getter = getattr(other, '_approx_eq_', None)
    if other_approx_eq_getter is not None:
        result = other_approx_eq_getter(val, atol=atol)
        if result is not NotImplemented:
            return result

    # Compare primitive types directly.
    if isinstance(val, numbers.Number):
        if not isinstance(other, numbers.Number):
            return False
        result = _isclose(val, other, atol=atol)
        # NotImplemented means np.isclose() rejected the operands; fall through
        # to the remaining strategies instead of failing.
        if result is not NotImplemented:
            return result

    if isinstance(val, str):
        return val == other

    if isinstance(val, sympy.Basic) or isinstance(other, sympy.Basic):
        delta = sympy.Abs(other - val).simplify()
        if not delta.is_number:
            raise AttributeError(
                'Insufficient information to decide whether '
                'expressions are approximately equal '
                f'[{val}] vs [{other}]'
            )
        return sympy.LessThan(delta, atol) == sympy.true

    # If the values are iterable, try comparing recursively on items.
    if isinstance(val, Iterable) and isinstance(other, Iterable):
        return _approx_eq_iterables(val, other, atol=atol)

    # Last resort: exact equality.
    return val == other


# pylint: enable=missing-raises-doc
def _approx_eq_iterables(val: Iterable, other: Iterable, *, atol: Union[int, float]) -> bool:
    """Iterates over arguments and calls approx_eq recursively.

    The types of `val` and `other` do not need to match each other. They just
    need to be iterable of the same length and have the same structure;
    approx_eq() is called on each consecutive pair of elements of `val` and
    `other`.

    Args:
        val: Source for approximate comparison.
        other: Target for approximate comparison.
        atol: The minimum absolute tolerance. See np.isclose() documentation for
            details.

    Returns:
        True if both iterables have the same length and all corresponding
        elements are approximately equal, False otherwise.
    """

    # Unique marker used to pad the shorter iterable; it can never be equal to
    # (or identical with) a real element.
    missing = object()

    for item, other_item in itertools.zip_longest(val, other, fillvalue=missing):
        # One side ran out first: the lengths differ. Decide here rather than
        # handing the private marker to approx_eq(), where it could reach user
        # defined _approx_eq_ hooks or arithmetic on sympy/numpy elements.
        if item is missing or other_item is missing:
            return False
        if not approx_eq(item, other_item, atol=atol):
            return False

    return True


def _isclose(a: Any, b: Any, *, atol: Union[int, float]) -> bool:
    """Convenience wrapper around np.isclose.

    Compares the two scalars with absolute tolerance `atol` and a relative
    tolerance of zero.

    Args:
        a: First value to compare.
        b: Second value to compare.
        atol: The absolute tolerance passed to np.isclose().

    Returns:
        True if the values are within `atol` of each other, False otherwise.
        Returns NotImplemented if np.isclose() raises TypeError for the given
        operands.
    """

    # support casting some standard numeric types
    x1 = np.asarray([a])
    if isinstance(a, (Fraction, Decimal)):
        x1 = x1.astype(np.float64)
    x2 = np.asarray([b])
    if isinstance(b, (Fraction, Decimal)):
        x2 = x2.astype(np.float64)

    # workaround np.isfinite type limitations. Cast to bool to avoid np.bool_
    try:
        result = bool(np.isclose(x1, x2, atol=atol, rtol=0.0)[0])
    except TypeError:
        return NotImplemented

    return result