# Copyright 2018 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import AbstractSet, Any, TYPE_CHECKING, Union

import sympy

from cirq._compat import proper_repr


if TYPE_CHECKING:
    import cirq


def _is_symbolic(x: Any) -> bool:
    """Returns True if `x` is a sympy expression containing free symbols.

    Such values cannot be ordered or compared approximately. Sympy numbers
    (e.g. `sympy.Float(0.5)`) have no free symbols and are not symbolic.
    """
    return isinstance(x, sympy.Basic) and bool(x.free_symbols)


class PeriodicValue:
    """Wrapper for periodic numerical values.

    Wrapper for periodic numerical types which implements `__eq__`, `__ne__`,
    `__hash__` and `_approx_eq_` so that values which are in the same
    equivalence class are treated as equal.

    Internally the `value` passed to `__init__` is normalized (for a positive
    `period`) to the interval [0, `period`) and stored as that. A specialized
    version of `_approx_eq_` is provided to cover values which end up at the
    opposite edges of this interval.
    """

    def __init__(self, value: Union[int, float], period: Union[int, float]):
        """Initializes the equivalence class.

        Args:
            value: numerical value to wrap.
            period: periodicity of the numerical value.
        """
        normalized = value % period
        # Floating-point `%` can round a tiny negative value up to exactly
        # `period` (e.g. `-1e-20 % 2 == 2.0`). Map it back to zero so the
        # stored value stays inside the normalization interval and __eq__ /
        # __hash__ agree for equivalent inputs. `==` is structural for sympy
        # expressions, so symbolic values are left untouched.
        if normalized == period:
            normalized -= period
        self.value = normalized
        self.period = period

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, type(self)):
            return NotImplemented
        return (self.value, self.period) == (other.value, other.period)

    def __ne__(self, other: Any) -> bool:
        return not self == other

    def __hash__(self) -> int:
        return hash((type(self), self.value, self.period))

    def _approx_eq_(self, other: Any, atol: float) -> bool:
        """Implementation of `SupportsApproximateEquality` protocol.

        Args:
            other: value to compare against.
            atol: absolute tolerance forwarded to `cirq.approx_eq` when both
                values and the period are numeric.

        Returns:
            `NotImplemented` when `other` is not a `PeriodicValue`. Otherwise:
            False when the periods differ; exact equality of the normalized
            values when any value or the period is symbolic; and otherwise
            whether the values are within `atol` of each other modulo the
            period.
        """
        # HACK: Avoids circular dependencies.
        from cirq.protocols import approx_eq

        if not isinstance(other, type(self)):
            return NotImplemented

        # Periods must be exactly equal to avoid drift of normalized value when
        # original value increases.
        if self.period != other.period:
            return False

        # self.value = value % period in __init__() creates a sympy expression
        # (often, but not always, a `sympy.Mod`) for symbolic inputs. Such
        # values cannot be ordered with min()/max(), so fall back to exact
        # structural equality. Checking both operands keeps this symmetric.
        if any(_is_symbolic(v) for v in (self.value, other.value, self.period)):
            return self.value == other.value

        # Use the magnitude so that negative periods (values normalized into
        # (period, 0]) wrap around correctly as well.
        period = abs(self.period)
        low = min(self.value, other.value)
        high = max(self.value, other.value)

        # Shift lower value outside of normalization interval in case low and
        # high values are at the opposite borders of normalization interval.
        if high - low > period / 2:
            low += period

        return approx_eq(low, high, atol=atol)

    def __repr__(self) -> str:
        v = proper_repr(self.value)
        p = proper_repr(self.period)
        return f'cirq.PeriodicValue({v}, {p})'

    def _is_parameterized_(self) -> bool:
        # HACK: Avoids circular dependencies.
        from cirq.protocols import is_parameterized

        return is_parameterized(self.value) or is_parameterized(self.period)

    def _parameter_names_(self) -> AbstractSet[str]:
        # HACK: Avoids circular dependencies.
        from cirq.protocols import parameter_names

        return parameter_names(self.value) | parameter_names(self.period)

    def _resolve_parameters_(
        self, resolver: 'cirq.ParamResolver', recursive: bool
    ) -> 'PeriodicValue':
        """Returns a new PeriodicValue with value and period resolved."""
        # HACK: Avoids circular dependencies.
        from cirq.protocols import resolve_parameters

        return PeriodicValue(
            value=resolve_parameters(self.value, resolver, recursive),
            period=resolve_parameters(self.period, resolver, recursive),
        )