1# Copyright 2018 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from typing import AbstractSet, Any, TYPE_CHECKING, Union
16
17import sympy
18
19from cirq._compat import proper_repr
20
21
22if TYPE_CHECKING:
23    import cirq
24
25
26class PeriodicValue:
27    """Wrapper for periodic numerical values.
28
29    Wrapper for periodic numerical types which implements `__eq__`, `__ne__`,
30    `__hash__` and `_approx_eq_` so that values which are in the same
31    equivalence class are treated as equal.
32
33    Internally the `value` passed to `__init__` is normalized to the interval
34    [0, `period`) and stored as that. Specialized version of `_approx_eq_` is
35    provided to cover values which end up at the opposite edges of this
36    interval.
37    """
38
39    def __init__(self, value: Union[int, float], period: Union[int, float]):
40        """Initializes the equivalence class.
41
42        Args:
43            value: numerical value to wrap.
44            period: periodicity of the numerical value.
45        """
46        self.value = value % period
47        self.period = period
48
49    def __eq__(self, other: Any) -> bool:
50        if not isinstance(other, type(self)):
51            return NotImplemented
52        return (self.value, self.period) == (other.value, other.period)
53
54    def __ne__(self, other: Any) -> bool:
55        return not self == other
56
57    def __hash__(self) -> int:
58        return hash((type(self), self.value, self.period))
59
60    def _approx_eq_(self, other: Any, atol: float) -> bool:
61        """Implementation of `SupportsApproximateEquality` protocol."""
62        # HACK: Avoids circular dependencies.
63        from cirq.protocols import approx_eq
64
65        if not isinstance(other, type(self)):
66            return NotImplemented
67
68        # self.value = value % period in __init__() creates a Mod
69        if isinstance(other.value, sympy.Mod):
70            return self.value == other.value
71        # Periods must be exactly equal to avoid drift of normalized value when
72        # original value increases.
73        if self.period != other.period:
74            return False
75
76        low = min(self.value, other.value)
77        high = max(self.value, other.value)
78
79        # Shift lower value outside of normalization interval in case low and
80        # high values are at the opposite borders of normalization interval.
81        if high - low > self.period / 2:
82            low += self.period
83
84        return approx_eq(low, high, atol=atol)
85
86    def __repr__(self) -> str:
87        v = proper_repr(self.value)
88        p = proper_repr(self.period)
89        return f'cirq.PeriodicValue({v}, {p})'
90
91    def _is_parameterized_(self) -> bool:
92        # HACK: Avoids circular dependencies.
93        from cirq.protocols import is_parameterized
94
95        return is_parameterized(self.value) or is_parameterized(self.period)
96
97    def _parameter_names_(self) -> AbstractSet[str]:
98        # HACK: Avoids circular dependencies.
99        from cirq.protocols import parameter_names
100
101        return parameter_names(self.value) | parameter_names(self.period)
102
103    def _resolve_parameters_(
104        self, resolver: 'cirq.ParamResolver', recursive: bool
105    ) -> 'PeriodicValue':
106        # HACK: Avoids circular dependencies.
107        from cirq.protocols import resolve_parameters
108
109        return PeriodicValue(
110            value=resolve_parameters(self.value, resolver, recursive),
111            period=resolve_parameters(self.period, resolver, recursive),
112        )
113