1# Copyright 2018 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import pytest
16import sympy
17
18import cirq
19
20
def test_periodic_value_equality():
    """PeriodicValue equality is taken modulo the period, and periods must match.

    Values congruent to 1 (mod 2) all land in one group, while instances with a
    different period (e.g. period 3 or 4) land in groups of their own.
    """
    tester = cirq.testing.EqualsTester()

    # Each generator yields a fresh PeriodicValue instance per entry, so the
    # tester compares distinct objects rather than one shared instance.
    congruent_to_one = (cirq.PeriodicValue(v, 2) for v in (1, 1, 3, 3, 5, -1))
    tester.add_equality_group(*congruent_to_one)

    half_odd = (cirq.PeriodicValue(1.5, 2.0) for _ in range(2))
    tester.add_equality_group(*half_odd)

    # Singleton groups: must differ from every group above and from each other.
    for value, period in ((0, 2), (1, 3), (2, 4)):
        tester.add_equality_group(cirq.PeriodicValue(value, period))
38
39
def test_periodic_value_approx_eq_basic():
    """approx_eq tolerates value differences up to atol but rejects differing periods."""
    reference = cirq.PeriodicValue(1.0, 2.0)

    # Columns: other value, other period, atol, expected truthiness of approx_eq.
    cases = [
        (1.0, 2.0, 0.1, True),
        (1.2, 2.0, 0.3, True),
        (1.2, 2.0, 0.1, False),
        # Any period mismatch is rejected, even when it is within atol.
        (1.0, 2.2, 0.3, False),
        (1.0, 2.2, 0.1, False),
        (1.2, 2.2, 0.3, False),
        (1.2, 2.2, 0.1, False),
    ]
    for other_value, other_period, tolerance, expected in cases:
        other = cirq.PeriodicValue(other_value, other_period)
        outcome = cirq.approx_eq(reference, other, atol=tolerance)
        assert bool(outcome) == expected
48
49
def test_periodic_value_approx_eq_normalized():
    """Values outside [0, period) are compared after reduction modulo the period."""
    base = cirq.PeriodicValue(1.0, 3.0)
    # 4.1 and -2.1 both reduce (mod 3.0) to within 0.1 of 1.0.
    for shifted in (4.1, -2.1):
        assert cirq.approx_eq(base, cirq.PeriodicValue(shifted, 3.0), atol=0.2)
53
54
def test_periodic_value_approx_eq_boundary():
    """Closeness is measured across the wrap point of the period, in either order."""
    # Columns: left value, right value, shared period, atol, expected truthiness.
    cases = [
        (0.0, 1.9, 2.0, 0.2, True),
        (0.1, 1.9, 2.0, 0.3, True),
        (1.9, 0.1, 2.0, 0.3, True),
        (0.1, 1.9, 2.0, 0.1, False),
        (0, 0.5, 1.0, 0.6, True),
        (0, 0.5, 1.0, 0.1, False),
        (0.4, 0.6, 1.0, 0.3, True),
    ]
    for left, right, period, tolerance, expected in cases:
        outcome = cirq.approx_eq(
            cirq.PeriodicValue(left, period),
            cirq.PeriodicValue(right, period),
            atol=tolerance,
        )
        assert bool(outcome) == expected
63
64
def test_periodic_value_types_mismatch():
    """A PeriodicValue is not approx-equal to a bare float, whichever side it is on."""
    periodic = cirq.PeriodicValue(0.0, 2.0)
    for lhs, rhs in ((periodic, 0.0), (0.0, periodic)):
        assert not cirq.approx_eq(lhs, rhs, atol=0.2)
68
69
@pytest.mark.parametrize(
    'value, is_parameterized, parameter_names',
    [
        (cirq.PeriodicValue(1.0, 3.0), False, set()),
        (cirq.PeriodicValue(0.0, sympy.Symbol('p')), True, {'p'}),
        (cirq.PeriodicValue(sympy.Symbol('v'), 3.0), True, {'v'}),
        (cirq.PeriodicValue(sympy.Symbol('v'), sympy.Symbol('p')), True, {'p', 'v'}),
    ],
)
@pytest.mark.parametrize('resolve_fn', [cirq.resolve_parameters, cirq.resolve_parameters_once])
def test_periodic_value_is_parameterized(value, is_parameterized, parameter_names, resolve_fn):
    """Symbols in the value and/or the period are reported and can be resolved away.

    Both the recursive and the single-pass resolver are exercised. Binding every
    reported parameter name to 1 must yield a PeriodicValue with no symbols left.
    """
    assert cirq.is_parameterized(value) == is_parameterized
    assert cirq.parameter_names(value) == parameter_names

    resolved = resolve_fn(value, {p: 1 for p in parameter_names})

    # Resolution must keep the type, not collapse it into a bare expression.
    assert isinstance(resolved, cirq.PeriodicValue)
    assert not cirq.is_parameterized(resolved)
    # Stronger than the is_parameterized check: no symbol name may survive
    # resolution in either the value or the period.
    assert cirq.parameter_names(resolved) == set()
85
86
@pytest.mark.parametrize(
    'val',
    [
        cirq.PeriodicValue(value, period)
        for value, period in [
            # Concrete values, including an int period and a negative value.
            (0.4, 1.0),
            (0.0, 2.0),
            (1.0, 3),
            (-2.1, 3.0),
            # Symbolic value and/or period.
            (sympy.Symbol('v'), sympy.Symbol('p')),
            (2.0, sympy.Symbol('p')),
            (sympy.Symbol('v'), 3),
        ]
    ],
)
def test_periodic_value_repr(val):
    """Run cirq's repr-equivalence check on concrete and symbolic PeriodicValues."""
    cirq.testing.assert_equivalent_repr(val)
101