# Copyright 2018 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for cirq.dot, cirq.kron, cirq.kron_with_controls and cirq.block_diag.

Array results are checked with ``np.testing.assert_allclose`` rather than
``assert np.allclose(...)``. ``np.allclose`` broadcasts its two arguments, so a
result with the wrong shape (for example ``(1, 2)`` against ``(2,)``, or any
array broadcastable against an empty ``(0, 0)`` expectation) can still compare
"close". ``assert_allclose`` rejects shape mismatches and reports which
entries differ.
"""
import pytest
import numpy as np

import cirq


def test_dot():
    """cirq.dot multiplies its arguments left to right and copies a lone array."""
    assert cirq.dot(2) == 2
    assert cirq.dot(2.5, 2.5) == 6.25

    a = np.array([[1, 2], [3, 4]])
    b = np.array([[5, 6], [7, 8]])
    # A single array argument must come back as a new object, not an alias.
    assert cirq.dot(a) is not a
    np.testing.assert_allclose(cirq.dot(a), a, atol=1e-8)
    np.testing.assert_allclose(cirq.dot(a, b), np.dot(a, b), atol=1e-8)
    np.testing.assert_allclose(cirq.dot(a, b, a), np.dot(np.dot(a, b), a), atol=1e-8)

    # Invalid use
    with pytest.raises(ValueError):
        cirq.dot()


def test_kron_multiplies_sizes():
    """cirq.kron multiplies factor sizes and returns `shape_len` axes."""
    assert cirq.kron(np.array([1, 2])).shape == (1, 2)
    assert cirq.kron(np.array([1, 2]), shape_len=1).shape == (2,)
    assert cirq.kron(np.array([1, 2]), np.array([3, 4, 5]), shape_len=1).shape == (6,)
    assert cirq.kron(shape_len=0).shape == ()
    assert cirq.kron(shape_len=1).shape == (1,)
    assert cirq.kron(shape_len=2).shape == (1, 1)

    # With the default shape_len=2 a 1-D factor comes back as a single row
    # (see the first assertion above), so the expected value is 2-D.
    np.testing.assert_allclose(cirq.kron(1j, np.array([2, 3])), np.array([[2j, 3j]]), atol=1e-8)
    np.testing.assert_allclose(cirq.kron(), np.eye(1), atol=1e-8)
    np.testing.assert_allclose(cirq.kron(np.eye(1)), np.eye(1), atol=1e-8)
    np.testing.assert_allclose(cirq.kron(np.eye(2)), np.eye(2), atol=1e-8)
    np.testing.assert_allclose(cirq.kron(np.eye(1), np.eye(1)), np.eye(1), atol=1e-8)
    np.testing.assert_allclose(cirq.kron(np.eye(1), np.eye(2)), np.eye(2), atol=1e-8)
    np.testing.assert_allclose(cirq.kron(np.eye(2), np.eye(3)), np.eye(6), atol=1e-8)
    np.testing.assert_allclose(cirq.kron(np.eye(2), np.eye(3), np.eye(4)), np.eye(24), atol=1e-8)


def test_kron_spreads_values():
    """cirq.kron places scaled copies of the right factor in each left-factor cell."""
    u = np.array([[2, 3], [5, 7]])

    np.testing.assert_allclose(
        cirq.kron(np.eye(2), u),
        np.array([[2, 3, 0, 0], [5, 7, 0, 0], [0, 0, 2, 3], [0, 0, 5, 7]]),
        atol=1e-8,
    )

    np.testing.assert_allclose(
        cirq.kron(u, np.eye(2)),
        np.array([[2, 0, 3, 0], [0, 2, 0, 3], [5, 0, 7, 0], [0, 5, 0, 7]]),
        atol=1e-8,
    )

    np.testing.assert_allclose(
        cirq.kron(u, u),
        np.array([[4, 6, 6, 9], [10, 14, 15, 21], [10, 15, 14, 21], [25, 35, 35, 49]]),
        atol=1e-8,
    )


def test_acts_like_kron_multiplies_sizes():
    """Without control tags, cirq.kron_with_controls matches a plain Kronecker product."""
    np.testing.assert_allclose(cirq.kron_with_controls(), np.eye(1), atol=1e-8)
    np.testing.assert_allclose(
        cirq.kron_with_controls(np.eye(2), np.eye(3), np.eye(4)), np.eye(24), atol=1e-8
    )

    u = np.array([[2, 3], [5, 7]])
    np.testing.assert_allclose(
        cirq.kron_with_controls(u, u),
        np.array([[4, 6, 6, 9], [10, 14, 15, 21], [10, 15, 14, 21], [25, 35, 35, 49]]),
        atol=1e-8,
    )


def test_supports_controls():
    """cirq.CONTROL_TAG makes the other factors act only on the control's 1 branch."""
    u = np.array([[2, 3], [5, 7]])
    np.testing.assert_allclose(
        cirq.kron_with_controls(cirq.CONTROL_TAG), np.array([[1, 0], [0, 1]]), atol=1e-8
    )
    np.testing.assert_allclose(
        cirq.kron_with_controls(cirq.CONTROL_TAG, u),
        np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 2, 3], [0, 0, 5, 7]]),
        atol=1e-8,
    )
    np.testing.assert_allclose(
        cirq.kron_with_controls(u, cirq.CONTROL_TAG),
        np.array([[1, 0, 0, 0], [0, 2, 0, 3], [0, 0, 1, 0], [0, 5, 0, 7]]),
        atol=1e-8,
    )


def test_block_diag():
    """cirq.block_diag stacks square blocks along the diagonal, zero elsewhere."""
    # Pin the empty shape explicitly: comparing against an empty array with
    # np.allclose would accept any broadcastable result, e.g. a (1, 1) array.
    assert cirq.block_diag().shape == (0, 0)
    np.testing.assert_allclose(cirq.block_diag(), np.zeros((0, 0)), atol=1e-8)

    np.testing.assert_allclose(
        cirq.block_diag(np.array([[1, 2], [3, 4]])), np.array([[1, 2], [3, 4]]), atol=1e-8
    )

    np.testing.assert_allclose(
        cirq.block_diag(np.array([[1, 2], [3, 4]]), np.array([[4, 5, 6], [7, 8, 9], [10, 11, 12]])),
        np.array(
            [[1, 2, 0, 0, 0], [3, 4, 0, 0, 0], [0, 0, 4, 5, 6], [0, 0, 7, 8, 9], [0, 0, 10, 11, 12]]
        ),
        atol=1e-8,
    )


def test_block_diag_dtype():
    """cirq.block_diag's result dtype is the promotion of its blocks' dtypes."""
    # With no blocks there is nothing to promote; complex128 is the default.
    assert cirq.block_diag().dtype == np.complex128

    assert cirq.block_diag(np.array([[1]], dtype=np.int8)).dtype == np.int8

    assert (
        cirq.block_diag(np.array([[1]], dtype=np.float32), np.array([[2]], dtype=np.float32)).dtype
        == np.float32
    )

    assert (
        cirq.block_diag(np.array([[1]], dtype=np.float64), np.array([[2]], dtype=np.float64)).dtype
        == np.float64
    )

    assert (
        cirq.block_diag(np.array([[1]], dtype=np.float32), np.array([[2]], dtype=np.float64)).dtype
        == np.float64
    )

    assert (
        cirq.block_diag(
            np.array([[1]], dtype=np.float32), np.array([[2]], dtype=np.complex64)
        ).dtype
        == np.complex64
    )

    assert (
        cirq.block_diag(np.array([[1]], dtype=int), np.array([[2]], dtype=np.complex128)).dtype
        == np.complex128
    )