# Copyright 2018 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
15import pytest
16
17import numpy as np
18
19import cirq
20
21
class GoodPhaser:
    """Single-qubit test value whose `_phase_by_` is expected to pass the check.

    The unitary is anti-diagonal, with `1j ** -e` in the top-right entry and
    `1j ** e` in the bottom-left entry. Phasing shifts `e` by four times the
    requested number of turns.

    Args:
        e: Exponent applied to `1j` when building the unitary.
    """

    def __init__(self, e):
        self.e = e

    def _unitary_(self) -> np.ndarray:
        """Returns the 2x2 anti-diagonal matrix determined by `self.e`."""
        upper_right = 1j ** -self.e
        lower_left = 1j ** self.e
        return np.array([[0, upper_right], [lower_left, 0]])

    def _phase_by_(self, phase_turns: float, qubit_index: int) -> 'GoodPhaser':
        """Returns a new phaser with `e` advanced by `phase_turns * 4`.

        `qubit_index` is accepted for protocol compatibility but is not used.
        """
        return GoodPhaser(self.e + phase_turns * 4)

    def _resolve_parameters_(self, resolver, recursive) -> 'GoodPhaser':
        """Returns a new phaser whose `e` is `resolver.value_of(e, recursive)`."""
        return GoodPhaser(resolver.value_of(self.e, recursive))
34
35
class GoodQuditPhaser:
    """Three-level (qid shape `(3,)`) test value expected to pass the check.

    The unitary has `1j ** -e` at row 0, column 1, `1j ** e` at row 1,
    column 2, and `1` at row 2, column 0. Phasing shifts `e` by four times
    the requested number of turns.

    Args:
        e: Exponent applied to `1j` when building the unitary.
    """

    def __init__(self, e):
        self.e = e

    def _qid_shape_(self):
        """Returns `(3,)`: the value acts on a single three-level qid."""
        return (3,)

    def _unitary_(self) -> np.ndarray:
        """Returns the 3x3 matrix determined by `self.e`."""
        backward = 1j ** -self.e
        forward = 1j ** self.e
        return np.array(
            [
                [0, backward, 0],
                [0, 0, forward],
                [1, 0, 0],
            ]
        )

    def _phase_by_(self, phase_turns: float, qubit_index: int) -> 'GoodQuditPhaser':
        """Returns a new phaser with `e` advanced by `phase_turns * 4`.

        `qubit_index` is accepted for protocol compatibility but is not used.
        """
        return GoodQuditPhaser(self.e + phase_turns * 4)

    def _resolve_parameters_(self, resolver, recursive) -> 'GoodQuditPhaser':
        """Returns a new phaser whose `e` is `resolver.value_of(e, recursive)`."""
        return GoodQuditPhaser(resolver.value_of(self.e, recursive))
57
58
class BadPhaser:
    """Single-qubit test value whose `_phase_by_` is expected to fail the check.

    The top-right unitary entry uses the exponent `-(e * 2)` while the
    bottom-left entry uses `e`. This mismatch is deliberate: the test below
    expects the consistency assertion to reject this class.

    Args:
        e: Exponent applied to `1j` when building the unitary.
    """

    def __init__(self, e):
        self.e = e

    def _unitary_(self) -> np.ndarray:
        """Returns the 2x2 anti-diagonal matrix determined by `self.e`."""
        # Intentionally doubled exponent; do not "fix" this.
        upper_right = 1j ** -(self.e * 2)
        lower_left = 1j ** self.e
        return np.array([[0, upper_right], [lower_left, 0]])

    def _phase_by_(self, phase_turns: float, qubit_index: int) -> 'BadPhaser':
        """Returns a new phaser with `e` advanced by `phase_turns * 4`.

        `qubit_index` is accepted for protocol compatibility but is not used.
        """
        return BadPhaser(self.e + phase_turns * 4)

    def _resolve_parameters_(self, resolver, recursive) -> 'BadPhaser':
        """Returns a new phaser whose `e` is `resolver.value_of(e, recursive)`."""
        return BadPhaser(resolver.value_of(self.e, recursive))
71
72
class NotPhaser:
    """Test value with a unitary whose `_phase_by_` returns `NotImplemented`.

    The test below expects the consistency assertion to pass vacuously for
    this class.
    """

    def _unitary_(self) -> np.ndarray:
        """Returns the 2x2 matrix `[[0, 1], [1, 0]]`."""
        return np.array([[0, 1], [1, 0]])

    def _phase_by_(self, phase_turns: float, qubit_index: int):
        """Always returns `NotImplemented`; both arguments are ignored."""
        return NotImplemented
79
80
class SemiBadPhaser:
    """Two-qubit test value that is inconsistent only on qubit index 1.

    The unitary is the Kronecker product of a `GoodPhaser` unitary (built
    from `e[0]`) and a `BadPhaser` unitary (built from `e[1]`). The test
    below expects the consistency assertion to fail for index #1.

    Args:
        e: Sequence of two exponents, one per qubit.
    """

    def __init__(self, e):
        self.e = e

    def _unitary_(self) -> np.ndarray:
        """Returns `kron(unitary(GoodPhaser(e[0])), unitary(BadPhaser(e[1])))`."""
        good_part = cirq.unitary(GoodPhaser(self.e[0]))
        bad_part = cirq.unitary(BadPhaser(self.e[1]))
        return np.kron(good_part, bad_part)

    def _phase_by_(self, phase_turns: float, qubit_index: int) -> 'SemiBadPhaser':
        """Returns a new phaser with `e[qubit_index]` advanced by `phase_turns * 4`.

        The exponents are copied first, so `self.e` is left unmodified.
        """
        shifted = list(self.e)
        shifted[qubit_index] += phase_turns * 4
        return SemiBadPhaser(shifted)

    def _resolve_parameters_(self, resolver, recursive) -> 'SemiBadPhaser':
        """Returns a new phaser with each exponent passed through `resolver.value_of`."""
        return SemiBadPhaser([resolver.value_of(val, recursive) for val in self.e])
97
98
def test_assert_phase_by_is_consistent_with_unitary():
    """Checks the assertion helper against consistent and inconsistent values.

    Consistent qubit and qudit phasers must pass. A phaser with a mismatched
    unitary must fail at index #0, and a two-qubit phaser that is only wrong
    on its second qubit must fail at index #1. A value whose `_phase_by_`
    returns `NotImplemented` must pass.
    """
    cirq.testing.assert_phase_by_is_consistent_with_unitary(GoodPhaser(0.5))

    cirq.testing.assert_phase_by_is_consistent_with_unitary(GoodQuditPhaser(0.5))

    with pytest.raises(AssertionError, match='Phased unitary was incorrect for index #0'):
        cirq.testing.assert_phase_by_is_consistent_with_unitary(BadPhaser(0.5))

    with pytest.raises(AssertionError, match='Phased unitary was incorrect for index #1'):
        cirq.testing.assert_phase_by_is_consistent_with_unitary(SemiBadPhaser([0.5, 0.25]))

    # Vacuous success.
    cirq.testing.assert_phase_by_is_consistent_with_unitary(NotPhaser())
112