1# Copyright 2018 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import pytest, sympy
16
17import cirq
18from cirq.study import ParamResolver
19
20
@pytest.mark.parametrize(
    'resolve_fn',
    [
        cirq.resolve_parameters,
        cirq.resolve_parameters_once,
    ],
)
def test_resolve_parameters(resolve_fn):
    """Checks resolution of hook-less objects, protocol implementers, symbols and containers.

    Runs once per resolver entry point (full recursive resolution and single-step
    resolution); every case here needs only one substitution step, so both must agree.
    """

    class NoMethod:
        pass

    class ReturnsNotImplemented:
        def _is_parameterized_(self):
            return NotImplemented

        def _resolve_parameters_(self, resolver, recursive):
            return NotImplemented

    class SimpleParameterSwitch:
        """Holds a single value that is resolved through the given ParamResolver."""

        def __init__(self, var):
            self.parameter = var

        def _is_parameterized_(self) -> bool:
            # Still parameterized while the value is unresolved (e.g. the name 'a');
            # 0 is the resolved value the resolvers below map 'a' to.
            return self.parameter != 0

        def _resolve_parameters_(self, resolver: ParamResolver, recursive: bool):
            self.parameter = resolver.value_of(self.parameter, recursive)
            return self

    assert not cirq.is_parameterized(NoMethod())
    assert not cirq.is_parameterized(ReturnsNotImplemented())
    assert cirq.is_parameterized(SimpleParameterSwitch('a'))
    assert not cirq.is_parameterized(SimpleParameterSwitch(0))

    ni = ReturnsNotImplemented()
    d = {'a': 0}
    r = cirq.ParamResolver(d)
    no = NoMethod()
    # Objects without a usable hook come back unchanged.
    assert resolve_fn(no, r) == no
    assert resolve_fn(no, d) == no
    assert resolve_fn(ni, r) == ni
    # Resolution works with both a ParamResolver and a plain dict.
    assert resolve_fn(SimpleParameterSwitch(0), r).parameter == 0
    assert resolve_fn(SimpleParameterSwitch('a'), r).parameter == 0
    assert resolve_fn(SimpleParameterSwitch('a'), d).parameter == 0
    assert resolve_fn(sympy.Symbol('a'), r) == 0

    a, b, c = tuple(sympy.Symbol(name) for name in 'abc')
    x, y, z = 0, 4, 7
    resolver = {a: x, b: y, c: z}

    # Containers are resolved element-wise and keep their type.
    assert resolve_fn((a, b, c), resolver) == (x, y, z)
    assert resolve_fn([a, b, c], resolver) == [x, y, z]
    assert resolve_fn((x, y, z), resolver) == (x, y, z)
    assert resolve_fn([x, y, z], resolver) == [x, y, z]
    assert resolve_fn((), resolver) == ()
    assert resolve_fn([], resolver) == []
    # Plain numbers pass through untouched.
    assert resolve_fn(1, resolver) == 1
    assert resolve_fn(1.1, resolver) == 1.1
    assert resolve_fn(1j, resolver) == 1j
80
81
def test_is_parameterized():
    """Plain numbers and containers of them are unparameterized; any symbol makes a container parameterized."""
    a, b = sympy.symbols('a b')
    x, y = 0, 4
    for container_type in (tuple, list):
        assert not cirq.is_parameterized(container_type((x, y)))
        assert cirq.is_parameterized(container_type((a, b)))
        assert cirq.is_parameterized(container_type((a, x)))
        assert not cirq.is_parameterized(container_type())
    for scalar in (1, 1.1, 1j):
        assert not cirq.is_parameterized(scalar)
96
97
def test_parameter_names():
    """Symbols contribute their names; plain numbers and empty containers contribute none."""
    a, b, c = sympy.symbols('a b c')
    x, y, z = 0, 4, 7
    for container_type in (tuple, list):
        assert cirq.parameter_names(container_type((a, b, c))) == {'a', 'b', 'c'}
        assert cirq.parameter_names(container_type((x, y, z))) == set()
        assert cirq.parameter_names(container_type()) == set()
    for scalar in (1, 1.1, 1j):
        assert cirq.parameter_names(scalar) == set()
110
111
@pytest.mark.parametrize('resolve_fn', [cirq.resolve_parameters, cirq.resolve_parameters_once])
def test_skips_empty_resolution(resolve_fn):
    """An empty resolver yields the input object itself; a non-empty one yields the hook's result."""

    class AlwaysFive:
        def _resolve_parameters_(self, resolver, recursive):
            return 5

    target = AlwaysFive()

    untouched = resolve_fn(target, {})
    assert untouched is target

    resolved = resolve_fn(target, {'x': 2})
    assert resolved == 5
127
128
def test_recursive_resolve():
    """resolve_parameters follows symbol chains to the end; resolve_parameters_once substitutes once."""
    a, b, c = sympy.symbols('a b c')
    chain = cirq.ParamResolver({a: b + 3, b: c + 2, c: 1})

    # (symbol, result after one substitution, fully resolved result)
    cases = [
        (a, b + 3, 6),
        (b, c + 2, 3),
        (c, 1, 1),
    ]
    for symbol, one_step, fully in cases:
        assert cirq.resolve_parameters_once(symbol, chain) == one_step
        assert cirq.resolve_parameters(symbol, chain) == fully

    assert cirq.resolve_parameters_once([a, b], {a: b, b: c}) == [b, c]
    assert cirq.resolve_parameters_once(a, {}) == a

    # A cyclic mapping is fine for one step but never terminates when followed fully.
    cycle = cirq.ParamResolver({a: b, b: a})
    assert cirq.resolve_parameters_once(a, cycle) == b
    with pytest.raises(RecursionError):
        cirq.resolve_parameters(a, cycle)
146