# Copyright 2018 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import sympy

import cirq


def test_periodic_value_equality():
    """Values congruent modulo a shared period compare equal; the others here do not."""
    tester = cirq.testing.EqualsTester()
    # 1, 3, 5 and -1 are all congruent to 1 modulo 2, so they form a single class.
    tester.add_equality_group(*(cirq.PeriodicValue(v, 2) for v in (1, 1, 3, 3, 5, -1)))
    tester.add_equality_group(*(cirq.PeriodicValue(1.5, 2.0) for _ in range(2)))
    # Each remaining value sits alone in its own equality group.
    for value, period in ((0, 2), (1, 3), (2, 4)):
        tester.add_equality_group(cirq.PeriodicValue(value, period))


def test_periodic_value_approx_eq_basic():
    """approx_eq honours atol on the value and rejects any period mismatch."""
    reference = cirq.PeriodicValue(1.0, 2.0)
    # Each row is (other value, atol, whether approx_eq should hold).
    table = [
        (cirq.PeriodicValue(1.0, 2.0), 0.1, True),
        (cirq.PeriodicValue(1.2, 2.0), 0.3, True),
        (cirq.PeriodicValue(1.2, 2.0), 0.1, False),
        (cirq.PeriodicValue(1.0, 2.2), 0.3, False),
        (cirq.PeriodicValue(1.0, 2.2), 0.1, False),
        (cirq.PeriodicValue(1.2, 2.2), 0.3, False),
        (cirq.PeriodicValue(1.2, 2.2), 0.1, False),
    ]
    for other, atol, expected in table:
        assert bool(cirq.approx_eq(reference, other, atol=atol)) is expected


def test_periodic_value_approx_eq_normalized():
    """Values are compared after reduction modulo the period."""
    anchor = cirq.PeriodicValue(1.0, 3.0)
    # 4.1 and -2.1 both lie within 0.2 of 1.0 once shifted by one period of 3.
    for shifted in (4.1, -2.1):
        assert cirq.approx_eq(anchor, cirq.PeriodicValue(shifted, 3.0), atol=0.2)


def test_periodic_value_approx_eq_boundary():
    """Values on opposite ends of the period interval are treated as neighbours."""
    # Each row is ((value, period), (value, period), atol, whether approx_eq should hold).
    table = [
        ((0.0, 2.0), (1.9, 2.0), 0.2, True),
        ((0.1, 2.0), (1.9, 2.0), 0.3, True),
        ((1.9, 2.0), (0.1, 2.0), 0.3, True),
        ((0.1, 2.0), (1.9, 2.0), 0.1, False),
        ((0, 1.0), (0.5, 1.0), 0.6, True),
        ((0, 1.0), (0.5, 1.0), 0.1, False),
        ((0.4, 1.0), (0.6, 1.0), 0.3, True),
    ]
    for lhs, rhs, atol, expected in table:
        outcome = cirq.approx_eq(cirq.PeriodicValue(*lhs), cirq.PeriodicValue(*rhs), atol=atol)
        assert bool(outcome) is expected


def test_periodic_value_types_mismatch():
    """A PeriodicValue is not approximately equal to a plain float, in either argument order."""
    periodic = cirq.PeriodicValue(0.0, 2.0)
    for first, second in ((periodic, 0.0), (0.0, periodic)):
        assert not cirq.approx_eq(first, second, atol=0.2)


@pytest.mark.parametrize(
    'value, is_parameterized, parameter_names',
    [
        (cirq.PeriodicValue(1.0, 3.0), False, set()),
        (cirq.PeriodicValue(0.0, sympy.Symbol('p')), True, {'p'}),
        (cirq.PeriodicValue(sympy.Symbol('v'), 3.0), True, {'v'}),
        (cirq.PeriodicValue(sympy.Symbol('v'), sympy.Symbol('p')), True, {'p', 'v'}),
    ],
)
@pytest.mark.parametrize('resolve_fn', [cirq.resolve_parameters, cirq.resolve_parameters_once])
def test_periodic_value_is_parameterized(value, is_parameterized, parameter_names, resolve_fn):
    """Symbols in the value or the period are reported, and binding them all clears them."""
    assert cirq.is_parameterized(value) == is_parameterized
    assert cirq.parameter_names(value) == parameter_names
    # Bind every reported symbol to 1; nothing symbolic should survive resolution.
    bindings = dict.fromkeys(parameter_names, 1)
    assert not cirq.is_parameterized(resolve_fn(value, bindings))


# Numeric, symbolic and mixed instances (including integer periods) used for repr round-trips.
_REPR_CASES = [
    cirq.PeriodicValue(0.4, 1.0),
    cirq.PeriodicValue(0.0, 2.0),
    cirq.PeriodicValue(1.0, 3),
    cirq.PeriodicValue(-2.1, 3.0),
    cirq.PeriodicValue(sympy.Symbol('v'), sympy.Symbol('p')),
    cirq.PeriodicValue(2.0, sympy.Symbol('p')),
    cirq.PeriodicValue(sympy.Symbol('v'), 3),
]


@pytest.mark.parametrize('val', _REPR_CASES)
def test_periodic_value_repr(val):
    """repr(val) evaluates back to an object equivalent to val."""
    cirq.testing.assert_equivalent_repr(val)