"""Tests for cirq's single-qubit Pauli eigenstates and their product states."""

import numpy as np
import pytest

import cirq


def test_name():
    """Each Pauli state prints as its sign followed by its Pauli axis."""
    names = [str(state) for state in cirq.PAULI_STATES]
    assert names == [
        '+X',
        '-X',
        '+Y',
        '-Y',
        '+Z',
        '-Z',
    ]


def test_repr():
    """``repr`` of every Pauli state round-trips through ``eval``."""
    for o in cirq.PAULI_STATES:
        # eval is acceptable here: its input is the repr of a library constant,
        # not external data.
        assert o == eval(repr(o))


def test_equality():
    """Pauli states compare by value and hash consistently."""
    assert cirq.KET_PLUS == cirq.KET_PLUS
    assert cirq.KET_PLUS != cirq.KET_MINUS
    assert cirq.KET_PLUS != cirq.KET_ZERO

    assert hash(cirq.KET_PLUS) == hash(cirq.KET_PLUS)


def test_basis_construction():
    """``gate.basis[+1/-1]`` over X, Y, Z reproduces ``cirq.PAULI_STATES`` in order."""
    states = [gate.basis[e_val] for gate in (cirq.X, cirq.Y, cirq.Z) for e_val in (+1, -1)]

    assert states == cirq.PAULI_STATES


def test_stabilized():
    """Each state is an eigenvector of its stabilizer gate with the reported eigenvalue."""
    for state in cirq.PAULI_STATES:
        val, gate = state.stabilized_by()
        matrix = cirq.unitary(gate)
        vec = state.state_vector()

        np.testing.assert_allclose(matrix @ vec, val * vec)


def test_projector():
    """Single-qubit projectors for the Z and X eigenstates have the expected matrices."""
    np.testing.assert_equal(cirq.KET_ZERO.projector(), [[1, 0], [0, 0]])
    np.testing.assert_equal(cirq.KET_ONE.projector(), [[0, 0], [0, 1]])
    np.testing.assert_allclose(cirq.KET_PLUS.projector(), np.array([[1, 1], [1, 1]]) / 2)
    np.testing.assert_allclose(cirq.KET_MINUS.projector(), np.array([[1, -1], [-1, 1]]) / 2)


def test_projector_2():
    """Basis-state projectors agree with the gate's eigen-component projectors.

    Index 0 of ``_eigen_components()`` is paired with eigenvalue +1 and
    index 1 with eigenvalue -1.
    """
    for gate in [cirq.X, cirq.Y, cirq.Z]:
        for eigen_index, eigenvalue in enumerate((+1, -1)):
            np.testing.assert_allclose(
                gate.basis[eigenvalue].projector(), gate._eigen_components()[eigen_index][1]
            )


def test_oneq_state():
    """Applying a Pauli state to a qubit yields a printable, qubit-sensitive object."""
    q0, q1 = cirq.LineQubit.range(2)
    st0 = cirq.KET_PLUS(q0)
    assert str(st0) == '+X(0)'

    st1 = cirq.KET_PLUS(q1)
    assert st0 != st1

    assert st0 == cirq.KET_PLUS.on(q0)


def test_product_state():
    """Multiplying states builds a product state; reusing a qubit is rejected."""
    q0, q1, q2 = cirq.LineQubit.range(3)

    plus0 = cirq.KET_PLUS(q0)
    plus1 = cirq.KET_PLUS(q1)

    ps = plus0 * plus1
    assert str(plus0) == "+X(0)"
    assert str(plus1) == "+X(1)"
    assert str(ps) == "+X(0) * +X(1)"

    ps *= cirq.KET_ONE(q2)
    assert str(ps) == "+X(0) * +X(1) * -Z(2)"

    # `match` is applied with re.search, so no leading `.*` is needed; the `.`
    # in `cirq.LineQubit` is escaped so it only matches a literal dot.
    with pytest.raises(
        ValueError, match=r'both contain factors for these qubits: \[cirq\.LineQubit\(2\)\]'
    ):
        # Re-use q2
        ps *= cirq.KET_PLUS(q2)

    ps2 = eval(repr(ps))
    assert ps == ps2


def test_product_state_2():
    """Multiplying a product state by a non-state operand raises ValueError."""
    q0, q1 = cirq.LineQubit.range(2)

    with pytest.raises(ValueError):
        # No coefficient
        _ = cirq.KET_PLUS(q0) * cirq.KET_PLUS(q1) * -1
    with pytest.raises(ValueError):
        # Not a state
        _ = cirq.KET_PLUS(q0) * cirq.KET_PLUS(q1) * cirq.KET_ZERO


def test_product_qubits():
    """A product state lists its qubits in order and supports lookup by qubit."""
    q0, q1, q2 = cirq.LineQubit.range(3)
    ps = cirq.KET_PLUS(q0) * cirq.KET_PLUS(q1) * cirq.KET_ZERO(q2)
    assert ps.qubits == [q0, q1, q2]
    assert ps[q0] == cirq.KET_PLUS


def test_product_iter():
    """Iterating a product state yields (qubit, state) pairs; len counts factors."""
    q0, q1, q2 = cirq.LineQubit.range(3)
    ps = cirq.KET_PLUS(q0) * cirq.KET_PLUS(q1) * cirq.KET_ZERO(q2)

    should_be = [
        (q0, cirq.KET_PLUS),
        (q1, cirq.KET_PLUS),
        (q2, cirq.KET_ZERO),
    ]
    assert list(ps) == should_be
    assert len(ps) == 3


def test_product_state_equality():
    """Equality and hashing of product states depend on both qubits and states."""
    q0, q1, q2 = cirq.LineQubit.range(3)

    assert cirq.KET_PLUS(q0) == cirq.KET_PLUS(q0)
    assert cirq.KET_PLUS(q0) != cirq.KET_PLUS(q1)
    assert cirq.KET_PLUS(q0) != cirq.KET_MINUS(q0)

    assert cirq.KET_PLUS(q0) * cirq.KET_MINUS(q1) == cirq.KET_PLUS(q0) * cirq.KET_MINUS(q1)
    assert cirq.KET_PLUS(q0) * cirq.KET_MINUS(q1) != cirq.KET_PLUS(q0) * cirq.KET_MINUS(q2)

    assert hash(cirq.KET_PLUS(q0) * cirq.KET_MINUS(q1)) == hash(
        cirq.KET_PLUS(q0) * cirq.KET_MINUS(q1)
    )
    assert hash(cirq.KET_PLUS(q0) * cirq.KET_MINUS(q1)) != hash(
        cirq.KET_PLUS(q0) * cirq.KET_MINUS(q2)
    )
    # Comparison against an unrelated type must not compare equal to its str form.
    assert cirq.KET_PLUS(q0) != '+X(0)'


def test_tp_state_vector():
    """Product-state vectors honour the requested qubit order."""
    q0, q1 = cirq.LineQubit.range(2)
    s00 = cirq.KET_ZERO(q0) * cirq.KET_ZERO(q1)
    np.testing.assert_equal(s00.state_vector(), [1, 0, 0, 0])
    np.testing.assert_equal(s00.state_vector(qubit_order=(q1, q0)), [1, 0, 0, 0])

    s01 = cirq.KET_ZERO(q0) * cirq.KET_ONE(q1)
    np.testing.assert_equal(s01.state_vector(), [0, 1, 0, 0])
    # Listing q1 first moves its |1> amplitude from index 0b01 to index 0b10.
    np.testing.assert_equal(s01.state_vector(qubit_order=(q1, q0)), [0, 0, 1, 0])


def test_tp_initial_state():
    """A product state used as ``initial_state`` matches preparing it with gates."""
    q0, q1 = cirq.LineQubit.range(2)
    # Identity gates ensure both qubits appear in each circuit.
    psi1 = cirq.final_state_vector(cirq.Circuit([cirq.I.on_each(q0, q1), cirq.X(q1)]))

    s01 = cirq.KET_ZERO(q0) * cirq.KET_ONE(q1)
    psi2 = cirq.final_state_vector(cirq.Circuit([cirq.I.on_each(q0, q1)]), initial_state=s01)

    np.testing.assert_allclose(psi1, psi2)


def test_tp_projector():
    """Product-state projectors equal the density matrices of preparing circuits."""
    q0, q1 = cirq.LineQubit.range(2)
    p00 = (cirq.KET_ZERO(q0) * cirq.KET_ZERO(q1)).projector()
    rho = cirq.final_density_matrix(cirq.Circuit(cirq.I.on_each(q0, q1)))
    np.testing.assert_allclose(rho, p00)

    p01 = (cirq.KET_ZERO(q0) * cirq.KET_ONE(q1)).projector()
    rho = cirq.final_density_matrix(cirq.Circuit([cirq.I.on_each(q0, q1), cirq.X(q1)]))
    np.testing.assert_allclose(rho, p01)

    ppp = (cirq.KET_PLUS(q0) * cirq.KET_PLUS(q1)).projector()
    rho = cirq.final_density_matrix(cirq.Circuit(cirq.H.on_each(q0, q1)))
    # Gate-based preparations involve irrational amplitudes, hence the atol.
    np.testing.assert_allclose(rho, ppp, atol=1e-7)

    ppm = (cirq.KET_PLUS(q0) * cirq.KET_MINUS(q1)).projector()
    rho = cirq.final_density_matrix(cirq.Circuit([cirq.H.on_each(q0, q1), cirq.Z(q1)]))
    np.testing.assert_allclose(rho, ppm, atol=1e-7)

    # rx(-pi/2) on |0> is expected to prepare the +Y eigenstate (KET_IMAG).
    pii = (cirq.KET_IMAG(q0) * cirq.KET_IMAG(q1)).projector()
    rho = cirq.final_density_matrix(cirq.Circuit(cirq.rx(-np.pi / 2).on_each(q0, q1)))
    np.testing.assert_allclose(rho, pii, atol=1e-7)

    # rx(+pi/2) on |0> is expected to prepare the -Y eigenstate (KET_MINUS_IMAG).
    pij = (cirq.KET_IMAG(q0) * cirq.KET_MINUS_IMAG(q1)).projector()
    rho = cirq.final_density_matrix(cirq.Circuit(cirq.rx(-np.pi / 2)(q0), cirq.rx(np.pi / 2)(q1)))
    np.testing.assert_allclose(rho, pij, atol=1e-7)