1import numpy as np
2import pytest
3
4import cirq
5
6
def test_name():
    """Each Pauli eigenstate stringifies as a signed Pauli-axis label."""
    expected = ['+X', '-X', '+Y', '-Y', '+Z', '-Z']
    actual = list(map(str, cirq.PAULI_STATES))
    assert actual == expected
17
18
def test_repr():
    """repr() of every Pauli eigenstate round-trips through eval()."""
    for pauli_state in cirq.PAULI_STATES:
        round_tripped = eval(repr(pauli_state))
        assert pauli_state == round_tripped
22
23
def test_equality():
    """Pauli eigenstates compare and hash by value."""
    plus = cirq.KET_PLUS
    assert plus == cirq.KET_PLUS
    # A different eigenstate, whether on the same axis or another, is unequal.
    for other in (cirq.KET_MINUS, cirq.KET_ZERO):
        assert plus != other

    assert hash(plus) == hash(cirq.KET_PLUS)
30
31
def test_basis_construction():
    """gate.basis lookups enumerate PAULI_STATES in X, Y, Z order, +1 before -1."""
    states = [
        pauli.basis[eigenvalue]
        for pauli in (cirq.X, cirq.Y, cirq.Z)
        for eigenvalue in (+1, -1)
    ]
    assert states == cirq.PAULI_STATES
39
40
def test_stabilized():
    """Each state is an eigenvector of its stabilizer with the reported eigenvalue."""
    for state in cirq.PAULI_STATES:
        eigenvalue, stabilizer = state.stabilized_by()
        stabilizer_matrix = cirq.unitary(stabilizer)
        vector = state.state_vector()
        # U |psi> == lambda |psi> for the (lambda, U) pair returned above.
        np.testing.assert_allclose(stabilizer_matrix @ vector, eigenvalue * vector)
48
49
def test_projector():
    """Single-qubit projectors for the Z- and X-basis eigenstates."""
    # Computational-basis projectors are compared exactly.
    z_basis_cases = (
        (cirq.KET_ZERO, [[1, 0], [0, 0]]),
        (cirq.KET_ONE, [[0, 0], [0, 1]]),
    )
    for ket, expected in z_basis_cases:
        np.testing.assert_equal(ket.projector(), expected)

    # X-basis projectors are floating point, so compare with assert_allclose.
    x_basis_cases = (
        (cirq.KET_PLUS, np.array([[1, 1], [1, 1]]) / 2),
        (cirq.KET_MINUS, np.array([[1, -1], [-1, 1]]) / 2),
    )
    for ket, expected in x_basis_cases:
        np.testing.assert_allclose(ket.projector(), expected)
55
56
def test_projector_2():
    """Basis-state projectors agree with each Pauli gate's eigencomponents.

    Eigencomponent index 0 is paired with eigenvalue +1 and index 1 with -1.
    """
    for gate in (cirq.X, cirq.Y, cirq.Z):
        for eigen_index, eigenvalue in enumerate((+1, -1)):
            actual_projector = gate.basis[eigenvalue].projector()
            expected_projector = gate._eigen_components()[eigen_index][1]
            np.testing.assert_allclose(actual_projector, expected_projector)
64
65
def test_oneq_state():
    """Applying a Pauli eigenstate to a qubit yields a qubit-tagged state."""
    q0, q1 = cirq.LineQubit.range(2)
    on_q0 = cirq.KET_PLUS(q0)
    assert str(on_q0) == '+X(0)'

    # The same eigenstate on a different qubit is not equal.
    assert on_q0 != cirq.KET_PLUS(q1)

    # Calling the state is equivalent to .on(qubit).
    assert on_q0 == cirq.KET_PLUS.on(q0)
75
76
def test_product_state():
    """Build, stringify, extend, reject overlap, and repr-round-trip a ProductState."""
    q0, q1, q2 = cirq.LineQubit.range(3)

    plus0 = cirq.KET_PLUS(q0)
    plus1 = cirq.KET_PLUS(q1)

    ps = plus0 * plus1
    assert str(plus0) == "+X(0)"
    assert str(plus1) == "+X(1)"
    assert str(ps) == "+X(0) * +X(1)"

    ps *= cirq.KET_ONE(q2)
    assert str(ps) == "+X(0) * +X(1) * -Z(2)"

    # Re-using q2 must be rejected. The '.' in "cirq.LineQubit" is escaped so the
    # pattern matches the literal repr rather than any character; `match=` uses
    # re.search, so no leading '.*' is needed.
    with pytest.raises(
        ValueError, match=r'both contain factors for these qubits: \[cirq\.LineQubit\(2\)\]'
    ):
        ps *= cirq.KET_PLUS(q2)
    # The rejected multiplication must leave ps untouched.
    assert str(ps) == "+X(0) * +X(1) * -Z(2)"

    ps2 = eval(repr(ps))
    assert ps == ps2
98
99
def test_product_state_2():
    """Multiplying a ProductState by a non-ProductState raises ValueError."""
    q0, q1 = cirq.LineQubit.range(2)

    # Neither a scalar coefficient (-1) nor a state not yet applied to a qubit
    # (cirq.KET_ZERO) is an accepted right-hand operand.
    for bad_operand in (-1, cirq.KET_ZERO):
        with pytest.raises(ValueError):
            _ = cirq.KET_PLUS(q0) * cirq.KET_PLUS(q1) * bad_operand
109
110
def test_product_qubits():
    """ProductState lists its qubits in order and supports per-qubit indexing."""
    q0, q1, q2 = cirq.LineQubit.range(3)
    product = cirq.KET_PLUS(q0) * cirq.KET_PLUS(q1) * cirq.KET_ZERO(q2)

    expected_qubits = [q0, q1, q2]
    assert product.qubits == expected_qubits
    assert product[q0] == cirq.KET_PLUS
116
117
def test_product_iter():
    """Iterating a ProductState yields (qubit, state) pairs; len counts factors."""
    q0, q1, q2 = cirq.LineQubit.range(3)
    product = cirq.KET_PLUS(q0) * cirq.KET_PLUS(q1) * cirq.KET_ZERO(q2)

    expected_pairs = list(
        zip([q0, q1, q2], [cirq.KET_PLUS, cirq.KET_PLUS, cirq.KET_ZERO])
    )
    assert list(product) == expected_pairs
    assert len(product) == 3
129
130
def test_product_state_equality():
    """Single- and multi-qubit product states compare and hash by value."""
    q0, q1, q2 = cirq.LineQubit.range(3)

    # Single-qubit states: equal only for the same eigenstate on the same qubit.
    assert cirq.KET_PLUS(q0) == cirq.KET_PLUS(q0)
    assert cirq.KET_PLUS(q0) != cirq.KET_PLUS(q1)
    assert cirq.KET_PLUS(q0) != cirq.KET_MINUS(q0)

    def plus_then_minus(minus_qubit):
        # Fresh |+>_q0 (x) |->_minus_qubit product on every call.
        return cirq.KET_PLUS(q0) * cirq.KET_MINUS(minus_qubit)

    # Multi-qubit products: equality and hashing follow the qubit assignment.
    assert plus_then_minus(q1) == plus_then_minus(q1)
    assert plus_then_minus(q1) != plus_then_minus(q2)
    assert hash(plus_then_minus(q1)) == hash(plus_then_minus(q1))
    assert hash(plus_then_minus(q1)) != hash(plus_then_minus(q2))

    # A state never equals its own string form.
    assert cirq.KET_PLUS(q0) != '+X(0)'
148
149
def test_tp_state_vector():
    """state_vector() honours an explicit qubit_order."""
    q0, q1 = cirq.LineQubit.range(2)
    # Each case: (state, vector in default order, vector with qubit_order=(q1, q0)).
    cases = [
        (cirq.KET_ZERO(q0) * cirq.KET_ZERO(q1), [1, 0, 0, 0], [1, 0, 0, 0]),
        # |0>_q0 |1>_q1 sits at index 1 by default but index 2 once the order is swapped.
        (cirq.KET_ZERO(q0) * cirq.KET_ONE(q1), [0, 1, 0, 0], [0, 0, 1, 0]),
    ]
    for state, default_vec, swapped_vec in cases:
        np.testing.assert_equal(state.state_vector(), default_vec)
        np.testing.assert_equal(state.state_vector(qubit_order=(q1, q0)), swapped_vec)
159
160
def test_tp_initial_state():
    """A ProductState is accepted as initial_state by final_state_vector."""
    q0, q1 = cirq.LineQubit.range(2)

    # Reference: start from the default initial state and flip q1 with X.
    reference_circuit = cirq.Circuit([cirq.I.on_each(q0, q1), cirq.X(q1)])
    expected = cirq.final_state_vector(reference_circuit)

    # Same target state, supplied directly as the initial state of an identity circuit.
    initial = cirq.KET_ZERO(q0) * cirq.KET_ONE(q1)
    identity_circuit = cirq.Circuit([cirq.I.on_each(q0, q1)])
    actual = cirq.final_state_vector(identity_circuit, initial_state=initial)

    np.testing.assert_allclose(expected, actual)
169
170
def test_tp_projector():
    """Two-qubit product-state projectors match simulated density matrices."""
    q0, q1 = cirq.LineQubit.range(2)
    # Each case: (product state, circuit preparing it, absolute tolerance).
    # atol=0 is assert_allclose's default. The H/rx cases use 1e-7, presumably to
    # absorb floating-point rounding from those gates.
    cases = [
        (
            cirq.KET_ZERO(q0) * cirq.KET_ZERO(q1),
            cirq.Circuit(cirq.I.on_each(q0, q1)),
            0,
        ),
        (
            cirq.KET_ZERO(q0) * cirq.KET_ONE(q1),
            cirq.Circuit([cirq.I.on_each(q0, q1), cirq.X(q1)]),
            0,
        ),
        (
            cirq.KET_PLUS(q0) * cirq.KET_PLUS(q1),
            cirq.Circuit([cirq.H.on_each(q0, q1)]),
            1e-7,
        ),
        (
            cirq.KET_PLUS(q0) * cirq.KET_MINUS(q1),
            cirq.Circuit([cirq.H.on_each(q0, q1), cirq.Z(q1)]),
            1e-7,
        ),
        (
            cirq.KET_IMAG(q0) * cirq.KET_IMAG(q1),
            cirq.Circuit(cirq.rx(-np.pi / 2).on_each(q0, q1)),
            1e-7,
        ),
        (
            cirq.KET_IMAG(q0) * cirq.KET_MINUS_IMAG(q1),
            cirq.Circuit(cirq.rx(-np.pi / 2)(q0), cirq.rx(np.pi / 2)(q1)),
            1e-7,
        ),
    ]
    for product_state, circuit, atol in cases:
        projector = product_state.projector()
        rho = cirq.final_density_matrix(circuit)
        np.testing.assert_allclose(rho, projector, atol=atol)
202