# Copyright 2018 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for state_vector.py"""

import itertools
import pytest

import numpy as np

import cirq
import cirq.testing


def test_state_mixin():
    """StateVectorMixin derives notation, Bloch vectors, density matrices and qid shape."""

    class TestClass(cirq.StateVectorMixin):
        def state_vector(self) -> np.ndarray:
            return np.array([0, 0, 1, 0])

    qubits = cirq.LineQubit.range(2)
    test = TestClass(qubit_map={qubits[i]: i for i in range(2)})
    assert test.dirac_notation() == '|10⟩'
    np.testing.assert_almost_equal(test.bloch_vector_of(qubits[0]), np.array([0, 0, -1]))
    np.testing.assert_almost_equal(test.density_matrix_of(qubits[0:1]), np.array([[0, 0], [0, 1]]))

    # The qid shape follows the qubit map, ordered by index, one dimension per qid.
    assert cirq.qid_shape(TestClass({qubits[i]: 1 - i for i in range(2)})) == (2, 2)
    assert cirq.qid_shape(TestClass({cirq.LineQid(i, i + 1): 2 - i for i in range(3)})) == (3, 2, 1)
    # Without a qubit map there is no shape, so the supplied default comes back.
    assert cirq.qid_shape(TestClass(), 'no shape') == 'no shape'

    # Invalid qubit maps are rejected at construction time.
    with pytest.raises(ValueError, match='Qubit index out of bounds'):
        _ = TestClass({qubits[0]: 1})
    with pytest.raises(ValueError, match='Duplicate qubit index'):
        _ = TestClass({qubits[0]: 0, qubits[1]: 0})
    with pytest.raises(ValueError, match='Duplicate qubit index'):
        _ = TestClass({qubits[0]: 1, qubits[1]: 1})
    # Index -1 is expected to be treated as a duplicate of index 1 here.
    with pytest.raises(ValueError, match='Duplicate qubit index'):
        _ = TestClass({qubits[0]: -1, qubits[1]: 1})


def test_sample_state_big_endian():
    """Sampling indices [2, 1, 0] of basis state |x⟩ yields x's bits least-significant first."""
    results = []
    for x in range(8):
        state = cirq.to_valid_state_vector(x, 3)
        sample = cirq.sample_state_vector(state, [2, 1, 0])
        results.append(sample)
    # itertools.product enumerates x's bits most-significant first; reversing
    # them matches the [2, 1, 0] index order requested above.
    expecteds = [[list(reversed(x))] for x in itertools.product([False, True], repeat=3)]
    for result, expected in zip(results, expecteds):
        np.testing.assert_equal(result, expected)


def test_sample_state_partial_indices():
    """Sampling the single index k of basis state |x⟩ yields bit (2 - k) of x."""
    for index in range(3):
        for x in range(8):
            state = cirq.to_valid_state_vector(x, 3)
            np.testing.assert_equal(
                cirq.sample_state_vector(state, [index]), [[bool(1 & (x >> (2 - index)))]]
            )


def test_sample_state_partial_indices_oder():
    """Sampling indices [2, 1] yields the two least-significant bits of x, LSB first."""
    for x in range(8):
        state = cirq.to_valid_state_vector(x, 3)
        expected = [[bool(1 & (x >> 0)), bool(1 & (x >> 1))]]
        np.testing.assert_equal(cirq.sample_state_vector(state, [2, 1]), expected)


def test_sample_state_partial_indices_all_orders():
    """Sampled bits follow the order in which the indices are requested."""
    for perm in itertools.permutations([0, 1, 2]):
        for x in range(8):
            state = cirq.to_valid_state_vector(x, 3)
            expected = [[bool(1 & (x >> (2 - p))) for p in perm]]
            np.testing.assert_equal(cirq.sample_state_vector(state, perm), expected)


def test_sample_state():
    """Samples of (|000⟩ + |010⟩)/√2 only land on the two supported outcomes."""
    state = np.zeros(8, dtype=np.complex64)
    state[0] = 1 / np.sqrt(2)
    state[2] = 1 / np.sqrt(2)
    for _ in range(10):
        sample = cirq.sample_state_vector(state, [2, 1, 0])
        assert np.array_equal(sample, [[False, False, False]]) or np.array_equal(
            sample, [[False, True, False]]
        )
    # Partial sample is correct: qubits 0 and 2 are always |0⟩ in this state.
    for _ in range(10):
        np.testing.assert_equal(cirq.sample_state_vector(state, [2]), [[False]])
        np.testing.assert_equal(cirq.sample_state_vector(state, [0]), [[False]])


def test_sample_empty_state():
    """A zero-qubit state sampled on no indices gives a single empty sample."""
    state = np.array([1.0])
    np.testing.assert_almost_equal(cirq.sample_state_vector(state, []), np.zeros(shape=(1, 0)))


def test_sample_no_repetitions():
    """repetitions=0 yields an empty array of shape (0, len(indices))."""
    state = cirq.to_valid_state_vector(0, 3)
    np.testing.assert_almost_equal(
        cirq.sample_state_vector(state, [1], repetitions=0), np.zeros(shape=(0, 1))
    )
    np.testing.assert_almost_equal(
        cirq.sample_state_vector(state, [1, 2], repetitions=0), np.zeros(shape=(0, 2))
    )


def test_sample_state_repetitions():
    """Every repetition of a basis-state sample gives the same bits."""
    for perm in itertools.permutations([0, 1, 2]):
        for x in range(8):
            state = cirq.to_valid_state_vector(x, 3)
            expected = [[bool(1 & (x >> (2 - p))) for p in perm]] * 3

            result = cirq.sample_state_vector(state, perm, repetitions=3)
            np.testing.assert_equal(result, expected)


def test_sample_state_seed():
    """Sampling is reproducible from an int seed or an equally seeded RandomState."""
    state = np.ones(2) / np.sqrt(2)
    expected = [
        [False],
        [True],
        [False],
        [True],
        [True],
        [False],
        [False],
        [True],
        [True],
        [True],
    ]

    samples = cirq.sample_state_vector(state, [0], repetitions=10, seed=1234)
    assert np.array_equal(samples, expected)

    samples = cirq.sample_state_vector(state, [0], repetitions=10, seed=np.random.RandomState(1234))
    assert np.array_equal(samples, expected)


def test_sample_state_negative_repetitions():
    """Negative repetitions are rejected with an error mentioning the value."""
    state = cirq.to_valid_state_vector(0, 3)
    with pytest.raises(ValueError, match='-1'):
        cirq.sample_state_vector(state, [1], repetitions=-1)


def test_sample_state_not_power_of_two():
    """State vectors whose length is not a power of two are rejected."""
    with pytest.raises(ValueError, match='3'):
        cirq.sample_state_vector(np.array([1, 0, 0]), [1])
    with pytest.raises(ValueError, match='5'):
        cirq.sample_state_vector(np.array([0, 1, 0, 0, 0]), [1])


def test_sample_state_index_out_of_range():
    """Out-of-range indices raise IndexError mentioning the offending index."""
    state = cirq.to_valid_state_vector(0, 3)
    with pytest.raises(IndexError, match='-2'):
        cirq.sample_state_vector(state, [-2])
    with pytest.raises(IndexError, match='3'):
        cirq.sample_state_vector(state, [3])


def test_sample_no_indices():
    """Sampling no indices gives a single empty sample."""
    state = cirq.to_valid_state_vector(0, 3)
    np.testing.assert_almost_equal(cirq.sample_state_vector(state, []), np.zeros(shape=(1, 0)))


def test_sample_no_indices_repetitions():
    """Sampling no indices with repetitions=2 gives two empty samples."""
    state = cirq.to_valid_state_vector(0, 3)
    np.testing.assert_almost_equal(
        cirq.sample_state_vector(state, [], repetitions=2), np.zeros(shape=(2, 0))
    )


def test_measure_state_computational_basis():
    """Measuring a basis state returns its bits and leaves the state unchanged."""
    results = []
    for x in range(8):
        initial_state = cirq.to_valid_state_vector(x, 3)
        bits, state = cirq.measure_state_vector(initial_state, [2, 1, 0])
        results.append(bits)
        np.testing.assert_almost_equal(state, initial_state)
    expected = [list(reversed(x)) for x in itertools.product([False, True], repeat=3)]
    assert results == expected


def test_measure_state_reshape():
    """Measurement also accepts the state shaped as a (2, 2, 2) tensor."""
    results = []
    for x in range(8):
        initial_state = np.reshape(cirq.to_valid_state_vector(x, 3), [2] * 3)
        bits, state = cirq.measure_state_vector(initial_state, [2, 1, 0])
        results.append(bits)
        np.testing.assert_almost_equal(state, initial_state)
    expected = [list(reversed(x)) for x in itertools.product([False, True], repeat=3)]
    assert results == expected


def test_measure_state_partial_indices():
    """Measuring the single index k of basis state |x⟩ returns bit (2 - k) of x."""
    for index in range(3):
        for x in range(8):
            initial_state = cirq.to_valid_state_vector(x, 3)
            bits, state = cirq.measure_state_vector(initial_state, [index])
            np.testing.assert_almost_equal(state, initial_state)
            assert bits == [bool(1 & (x >> (2 - index)))]


def test_measure_state_partial_indices_order():
    """Measuring indices [2, 1] returns the two least-significant bits of x, LSB first."""
    for x in range(8):
        initial_state = cirq.to_valid_state_vector(x, 3)
        bits, state = cirq.measure_state_vector(initial_state, [2, 1])
        np.testing.assert_almost_equal(state, initial_state)
        assert bits == [bool(1 & (x >> 0)), bool(1 & (x >> 1))]


def test_measure_state_partial_indices_all_orders():
    """Measured bits follow the order in which the indices are requested."""
    for perm in itertools.permutations([0, 1, 2]):
        for x in range(8):
            initial_state = cirq.to_valid_state_vector(x, 3)
            bits, state = cirq.measure_state_vector(initial_state, perm)
            np.testing.assert_almost_equal(state, initial_state)
            assert bits == [bool(1 & (x >> (2 - p))) for p in perm]


def test_measure_state_collapse():
    """Measuring all qubits collapses a superposition; measuring definite qubits does not."""
    initial_state = np.zeros(8, dtype=np.complex64)
    initial_state[0] = 1 / np.sqrt(2)
    initial_state[2] = 1 / np.sqrt(2)
    for _ in range(10):
        bits, state = cirq.measure_state_vector(initial_state, [2, 1, 0])
        assert bits in [[False, False, False], [False, True, False]]
        expected = np.zeros(8, dtype=np.complex64)
        expected[2 if bits[1] else 0] = 1.0
        np.testing.assert_almost_equal(state, expected)
        assert state is not initial_state

    # Partial sample is correct: qubits 0 and 2 are always |0⟩, so no collapse.
    for _ in range(10):
        bits, state = cirq.measure_state_vector(initial_state, [2])
        np.testing.assert_almost_equal(state, initial_state)
        assert bits == [False]

        bits, state = cirq.measure_state_vector(initial_state, [0])
        np.testing.assert_almost_equal(state, initial_state)
        assert bits == [False]


def test_measure_state_seed():
    """Measurement outcomes and collapsed states are reproducible from a seed."""
    n = 10
    initial_state = np.ones(2 ** n) / 2 ** (n / 2)
    expected_bits = [False, False, True, True, False, False, False, True, False, False]

    bits, state1 = cirq.measure_state_vector(initial_state, range(n), seed=1234)
    np.testing.assert_equal(bits, expected_bits)

    bits, state2 = cirq.measure_state_vector(
        initial_state, range(n), seed=np.random.RandomState(1234)
    )
    np.testing.assert_equal(bits, expected_bits)

    np.testing.assert_allclose(state1, state2)


def test_measure_state_out_is_state():
    """With out=state, the collapse is performed in place on the input."""
    initial_state = np.zeros(8, dtype=np.complex64)
    initial_state[0] = 1 / np.sqrt(2)
    initial_state[2] = 1 / np.sqrt(2)
    bits, state = cirq.measure_state_vector(initial_state, [2, 1, 0], out=initial_state)
    expected = np.zeros(8, dtype=np.complex64)
    expected[2 if bits[1] else 0] = 1.0
    np.testing.assert_array_almost_equal(initial_state, expected)
    assert state is initial_state


def test_measure_state_out_is_not_state():
    """With a separate out buffer, the collapse is written there and the input is untouched."""
    initial_state = np.zeros(8, dtype=np.complex64)
    initial_state[0] = 1 / np.sqrt(2)
    initial_state[2] = 1 / np.sqrt(2)
    original = initial_state.copy()
    out = np.zeros_like(initial_state)
    bits, state = cirq.measure_state_vector(initial_state, [2, 1, 0], out=out)
    assert out is not initial_state
    assert out is state
    # The identity checks alone pass even if `out` is never written, so also
    # check its contents and that the input vector was not collapsed.
    expected = np.zeros(8, dtype=np.complex64)
    expected[2 if bits[1] else 0] = 1.0
    np.testing.assert_array_almost_equal(out, expected)
    np.testing.assert_array_equal(initial_state, original)


def test_measure_state_not_power_of_two():
    """State vectors whose length is not a power of two are rejected."""
    with pytest.raises(ValueError, match='3'):
        cirq.measure_state_vector(np.array([1, 0, 0]), [1])
    with pytest.raises(ValueError, match='5'):
        cirq.measure_state_vector(np.array([0, 1, 0, 0, 0]), [1])
def test_measure_state_index_out_of_range():
    """Measuring an index outside the register raises IndexError mentioning it."""
    three_qubit_state = cirq.to_valid_state_vector(0, 3)
    for bad_index, pattern in ((-2, '-2'), (3, '3')):
        with pytest.raises(IndexError, match=pattern):
            cirq.measure_state_vector(three_qubit_state, [bad_index])


def test_measure_state_no_indices():
    """Measuring no qubits returns no bits and an unchanged state."""
    start = cirq.to_valid_state_vector(0, 3)
    measured_bits, result = cirq.measure_state_vector(start, [])
    assert measured_bits == []
    np.testing.assert_almost_equal(result, start)


def test_measure_state_no_indices_out_is_state():
    """Measuring no qubits with out=state hands back the very same array."""
    start = cirq.to_valid_state_vector(0, 3)
    measured_bits, result = cirq.measure_state_vector(start, [], out=start)
    assert measured_bits == []
    np.testing.assert_almost_equal(result, start)
    assert result is start


def test_measure_state_no_indices_out_is_not_state():
    """Measuring no qubits into a separate buffer copies the state into it."""
    start = cirq.to_valid_state_vector(0, 3)
    buffer = np.zeros_like(start)
    measured_bits, result = cirq.measure_state_vector(start, [], out=buffer)
    assert measured_bits == []
    np.testing.assert_almost_equal(result, start)
    assert result is buffer
    assert buffer is not start


def test_measure_state_empty_state():
    """A zero-qubit state can be measured on no indices."""
    start = np.array([1.0])
    measured_bits, result = cirq.measure_state_vector(start, [])
    assert measured_bits == []
    np.testing.assert_almost_equal(result, start)


class BasicStateVector(cirq.StateVectorMixin):
    """Fixed two-qubit state |01⟩ used by the step-result tests below."""

    def state_vector(self) -> np.ndarray:
        return np.array([0, 1, 0, 0])


def test_step_result_pretty_state():
    """dirac_notation renders the fixed state as a ket."""
    assert BasicStateVector().dirac_notation() == '|01⟩'


def test_step_result_density_matrix():
    """density_matrix_of honours the requested qubit order and subset."""
    q0, q1 = cirq.LineQubit.range(2)
    step_result = BasicStateVector({q0: 0, q1: 1})

    # |01⟩⟨01| in the map's own order, both explicitly and with no argument.
    full_rho = np.diag([0, 1, 0, 0])
    np.testing.assert_array_almost_equal(full_rho, step_result.density_matrix_of([q0, q1]))
    np.testing.assert_array_almost_equal(full_rho, step_result.density_matrix_of())

    # Requesting the qubits swapped reads the same state as |10⟩.
    np.testing.assert_array_almost_equal(
        np.diag([0, 0, 1, 0]), step_result.density_matrix_of([q1, q0])
    )

    # Keeping only q1 leaves it in |1⟩.
    np.testing.assert_array_almost_equal(np.diag([0, 1]), step_result.density_matrix_of([q1]))


def test_step_result_density_matrix_invalid():
    """Unknown qubits or strings raise KeyError; an int argument raises TypeError."""
    q0, q1 = cirq.LineQubit.range(2)
    step_result = BasicStateVector({q0: 0})

    bad_requests = [([q1], KeyError), ('junk', KeyError), (0, TypeError)]
    for request, error_type in bad_requests:
        with pytest.raises(error_type):
            step_result.density_matrix_of(request)


def test_step_result_bloch_vector():
    """In |01⟩, q0 sits at +Z and q1 at -Z on the Bloch sphere."""
    q0, q1 = cirq.LineQubit.range(2)
    step_result = BasicStateVector({q0: 0, q1: 1})
    for qubit, expected in ((q1, [0, 0, -1]), (q0, [0, 0, 1])):
        np.testing.assert_array_almost_equal(
            np.array(expected), step_result.bloch_vector_of(qubit)
        )


def test_factor_validation():
    """factor_state_vector accepts a product state and rejects an entangled one."""
    q0, q1 = cirq.LineQubit.range(2)
    factor = cirq.linalg.transformations.factor_state_vector
    args = cirq.Simulator()._create_act_on_args(0, qubits=[q0, q1])

    # After H on q0 alone the qubits are not entangled, so either axis factors out.
    args.apply_operation(cirq.H(q0))
    tensor = args.create_merged_state().target_tensor
    factor(tensor, [0])
    factor(tensor, [1], atol=1e-2)

    # CNOT entangles the qubits, so factoring out either axis must fail.
    args.apply_operation(cirq.CNOT(q0, q1))
    tensor = args.create_merged_state().target_tensor
    for axes in ([0], [1]):
        with pytest.raises(ValueError, match='factor'):
            factor(tensor, axes)