# Copyright 2018 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import sympy

import cirq
from cirq.study import ParamResolver

# Both resolution entry points must satisfy the non-recursive behaviors below.
_RESOLVE_FNS = [
    cirq.resolve_parameters,
    cirq.resolve_parameters_once,
]


@pytest.mark.parametrize('resolve_fn', _RESOLVE_FNS)
def test_resolve_parameters(resolve_fn):
    """Check `resolve_fn` on protocol objects, symbols, containers and numbers."""

    class NoMethod:
        """Implements neither `_is_parameterized_` nor `_resolve_parameters_`."""

    class ReturnsNotImplemented:
        """Implements both protocol methods, but each returns NotImplemented."""

        def _is_parameterized_(self):
            return NotImplemented

        def _resolve_parameters_(self, resolver, recursive):
            return NotImplemented

    class SimpleParameterSwitch:
        """Reports parameterized only when `parameter == 0`; resolves in place."""

        def __init__(self, var):
            self.parameter = var

        def _is_parameterized_(self) -> bool:
            return self.parameter == 0

        def _resolve_parameters_(self, resolver: ParamResolver, recursive: bool):
            self.parameter = resolver.value_of(self.parameter, recursive)
            return self

    assert not cirq.is_parameterized(NoMethod())
    assert not cirq.is_parameterized(ReturnsNotImplemented())
    assert not cirq.is_parameterized(SimpleParameterSwitch('a'))
    assert cirq.is_parameterized(SimpleParameterSwitch(0))

    not_implemented = ReturnsNotImplemented()
    mapping = {'a': 0}
    resolver = cirq.ParamResolver(mapping)
    no_method = NoMethod()
    # Objects that cannot resolve themselves are handed back unchanged; the
    # resolver may be given either as a ParamResolver or as a plain dict.
    assert resolve_fn(no_method, resolver) is no_method
    assert resolve_fn(no_method, mapping) is no_method
    assert resolve_fn(not_implemented, resolver) is not_implemented
    assert resolve_fn(SimpleParameterSwitch(0), resolver).parameter == 0
    assert resolve_fn(SimpleParameterSwitch('a'), resolver).parameter == 0
    assert resolve_fn(SimpleParameterSwitch('a'), mapping).parameter == 0
    assert resolve_fn(sympy.Symbol('a'), resolver) == 0

    a, b, c = sympy.symbols('a b c')
    x, y, z = 0, 4, 7
    symbol_map = {a: x, b: y, c: z}

    # Containers are resolved element-wise and keep their container type.
    assert resolve_fn((a, b, c), symbol_map) == (x, y, z)
    assert resolve_fn([a, b, c], symbol_map) == [x, y, z]
    assert resolve_fn((x, y, z), symbol_map) == (x, y, z)
    assert resolve_fn([x, y, z], symbol_map) == [x, y, z]
    assert resolve_fn((), symbol_map) == ()
    assert resolve_fn([], symbol_map) == []
    # Plain numbers pass through untouched.
    assert resolve_fn(1, symbol_map) == 1
    assert resolve_fn(1.1, symbol_map) == 1.1
    assert resolve_fn(1j, symbol_map) == 1j


def test_is_parameterized():
    """A container is parameterized iff at least one element is a symbol."""
    a, b = sympy.symbols('a b')
    x, y = 0, 4
    assert not cirq.is_parameterized((x, y))
    assert not cirq.is_parameterized([x, y])
    assert cirq.is_parameterized([a, b])
    assert cirq.is_parameterized([a, x])
    assert cirq.is_parameterized((a, b))
    assert cirq.is_parameterized((a, x))
    assert not cirq.is_parameterized(())
    assert not cirq.is_parameterized([])
    assert not cirq.is_parameterized(1)
    assert not cirq.is_parameterized(1.1)
    assert not cirq.is_parameterized(1j)


def test_parameter_names():
    """`parameter_names` collects symbol names from containers; numbers give none."""
    a, b, c = sympy.symbols('a b c')
    x, y, z = 0, 4, 7
    assert cirq.parameter_names((a, b, c)) == {'a', 'b', 'c'}
    assert cirq.parameter_names([a, b, c]) == {'a', 'b', 'c'}
    assert cirq.parameter_names((x, y, z)) == set()
    assert cirq.parameter_names([x, y, z]) == set()
    assert cirq.parameter_names(()) == set()
    assert cirq.parameter_names([]) == set()
    assert cirq.parameter_names(1) == set()
    assert cirq.parameter_names(1.1) == set()
    assert cirq.parameter_names(1j) == set()


@pytest.mark.parametrize('resolve_fn', _RESOLVE_FNS)
def test_skips_empty_resolution(resolve_fn):
    """An empty resolver returns the value itself without calling its protocol."""

    class Tester:
        def _resolve_parameters_(self, resolver, recursive):
            return 5

    tester = Tester()
    # With `{}` the object's own `_resolve_parameters_` (which returns 5)
    # must not be consulted, so the identical object comes back.
    assert resolve_fn(tester, {}) is tester
    assert resolve_fn(tester, {'x': 2}) == 5


def test_recursive_resolve():
    """`_once` substitutes one level; the recursive form follows chains to the end."""
    a, b, c = sympy.symbols('a b c')
    resolver = cirq.ParamResolver({a: b + 3, b: c + 2, c: 1})
    assert cirq.resolve_parameters_once(a, resolver) == b + 3
    assert cirq.resolve_parameters(a, resolver) == 6
    assert cirq.resolve_parameters_once(b, resolver) == c + 2
    assert cirq.resolve_parameters(b, resolver) == 3
    assert cirq.resolve_parameters_once(c, resolver) == 1
    assert cirq.resolve_parameters(c, resolver) == 1

    assert cirq.resolve_parameters_once([a, b], {a: b, b: c}) == [b, c]
    assert cirq.resolve_parameters_once(a, {}) == a

    # A cyclic mapping is fine for a single step but must not loop forever
    # when resolved recursively.
    resolver = cirq.ParamResolver({a: b, b: a})
    assert cirq.resolve_parameters_once(a, resolver) == b
    with pytest.raises(RecursionError):
        _ = cirq.resolve_parameters(a, resolver)