1# Copyright 2018 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14import pytest
15import numpy as np
16
17import cirq
18
19
def test_dot():
    """cirq.dot multiplies scalars or matrices and copies a lone matrix argument."""
    # Scalar inputs: one value is returned unchanged, two values are multiplied.
    assert cirq.dot(2) == 2
    assert cirq.dot(2.5, 2.5) == 6.25

    m1 = np.array([[1, 2], [3, 4]])
    m2 = np.array([[5, 6], [7, 8]])

    # A single matrix comes back equal in value but as a distinct object.
    lone = cirq.dot(m1)
    assert lone is not m1
    np.testing.assert_allclose(lone, m1, atol=1e-8)

    # Several matrices must agree with nested np.dot calls.
    m1_m2 = np.dot(m1, m2)
    np.testing.assert_allclose(cirq.dot(m1, m2), m1_m2, atol=1e-8)
    np.testing.assert_allclose(cirq.dot(m1, m2, m1), np.dot(m1_m2, m1), atol=1e-8)

    # Calling with no arguments at all is rejected.
    with pytest.raises(ValueError):
        cirq.dot()
34
35
def test_kron_multiplies_sizes():
    """cirq.kron multiplies factor sizes and honours the shape_len argument."""
    # (actual shape, expected shape). Without shape_len a 1-D vector comes back
    # as a (1, 2) row; with no factors the result rank equals shape_len.
    shape_checks = [
        (cirq.kron(np.array([1, 2])).shape, (1, 2)),
        (cirq.kron(np.array([1, 2]), shape_len=1).shape, (2,)),
        (cirq.kron(np.array([1, 2]), np.array([3, 4, 5]), shape_len=1).shape, (6,)),
        (cirq.kron(shape_len=0).shape, ()),
        (cirq.kron(shape_len=1).shape, (1,)),
        (cirq.kron(shape_len=2).shape, (1, 1)),
    ]
    for actual_shape, expected_shape in shape_checks:
        assert actual_shape == expected_shape

    # (actual, expected) value comparisons; identities tensor into larger identities.
    value_checks = [
        (cirq.kron(1j, np.array([2, 3])), np.array([2j, 3j])),
        (cirq.kron(), np.eye(1)),
        (cirq.kron(np.eye(1)), np.eye(1)),
        (cirq.kron(np.eye(2)), np.eye(2)),
        (cirq.kron(np.eye(1), np.eye(1)), np.eye(1)),
        (cirq.kron(np.eye(1), np.eye(2)), np.eye(2)),
        (cirq.kron(np.eye(2), np.eye(3)), np.eye(6)),
        (cirq.kron(np.eye(2), np.eye(3), np.eye(4)), np.eye(24)),
    ]
    for actual, expected in value_checks:
        assert np.allclose(actual, expected)
52
53
def test_kron_spreads_values():
    """Each entry of the left factor scales a full copy of the right factor."""
    u = np.array([[2, 3], [5, 7]])
    identity = np.eye(2)

    # (factors, expected Kronecker product as nested lists)
    cases = (
        ((identity, u), [[2, 3, 0, 0], [5, 7, 0, 0], [0, 0, 2, 3], [0, 0, 5, 7]]),
        ((u, identity), [[2, 0, 3, 0], [0, 2, 0, 3], [5, 0, 7, 0], [0, 5, 0, 7]]),
        ((u, u), [[4, 6, 6, 9], [10, 14, 15, 21], [10, 15, 14, 21], [25, 35, 35, 49]]),
    )
    for factors, expected in cases:
        assert np.allclose(cirq.kron(*factors), np.array(expected))
69
70
def test_acts_like_kron_multiplies_sizes():
    """Without control tags, cirq.kron_with_controls behaves like cirq.kron."""
    # np.testing.assert_allclose (as used in test_dot) rejects shape mismatches,
    # whereas a bare np.allclose broadcasts and would, e.g., accept a (1,)
    # result against np.eye(1).
    empty_product = cirq.kron_with_controls()
    assert empty_product.shape == (1, 1)
    np.testing.assert_allclose(empty_product, np.eye(1), atol=1e-8)
    np.testing.assert_allclose(
        cirq.kron_with_controls(np.eye(2), np.eye(3), np.eye(4)), np.eye(24), atol=1e-8
    )

    u = np.array([[2, 3], [5, 7]])
    np.testing.assert_allclose(
        cirq.kron_with_controls(u, u),
        np.array([[4, 6, 6, 9], [10, 14, 15, 21], [10, 15, 14, 21], [25, 35, 35, 49]]),
        atol=1e-8,
    )
80
81
def test_supports_controls():
    """cirq.CONTROL_TAG factors act as controls; unsatisfied blocks become identity."""
    u = np.array([[2, 3], [5, 7]])

    # Shape-strict comparisons: np.allclose would silently broadcast mismatches.
    # A lone control tag collapses to the 2x2 identity.
    np.testing.assert_allclose(
        cirq.kron_with_controls(cirq.CONTROL_TAG), np.array([[1, 0], [0, 1]]), atol=1e-8
    )
    # Control on the first factor: u occupies only the lower-right block.
    np.testing.assert_allclose(
        cirq.kron_with_controls(cirq.CONTROL_TAG, u),
        np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 2, 3], [0, 0, 5, 7]]),
        atol=1e-8,
    )
    # Control on the second factor: u's entries land on the odd rows/columns.
    np.testing.assert_allclose(
        cirq.kron_with_controls(u, cirq.CONTROL_TAG),
        np.array([[1, 0, 0, 0], [0, 2, 0, 3], [0, 0, 1, 0], [0, 5, 0, 7]]),
        atol=1e-8,
    )
    # Two controls: u applies only when both are satisfied, i.e. in the last
    # 2x2 block; everything else is identity.
    expected_ccu = np.eye(8)
    expected_ccu[6:, 6:] = u
    np.testing.assert_allclose(
        cirq.kron_with_controls(cirq.CONTROL_TAG, cirq.CONTROL_TAG, u), expected_ccu, atol=1e-8
    )
93
94
def test_block_diag():
    """cirq.block_diag places square blocks on the diagonal and zero-fills the rest."""
    # np.allclose(x, np.zeros((0, 0))) is vacuous: it also passes for a scalar
    # or a (1, 1) result because of broadcasting. Check the shape explicitly
    # and compare exactly with assert_array_equal, which enforces shapes.
    empty = cirq.block_diag()
    assert empty.shape == (0, 0)
    np.testing.assert_array_equal(empty, np.zeros((0, 0)))

    np.testing.assert_array_equal(
        cirq.block_diag(np.array([[1, 2], [3, 4]])), np.array([[1, 2], [3, 4]])
    )

    np.testing.assert_array_equal(
        cirq.block_diag(np.array([[1, 2], [3, 4]]), np.array([[4, 5, 6], [7, 8, 9], [10, 11, 12]])),
        np.array(
            [[1, 2, 0, 0, 0], [3, 4, 0, 0, 0], [0, 0, 4, 5, 6], [0, 0, 7, 8, 9], [0, 0, 10, 11, 12]]
        ),
    )

    # Non-square blocks cannot sit on the diagonal and must be rejected.
    with pytest.raises(ValueError):
        cirq.block_diag(np.array([[1, 2, 3], [4, 5, 6]]))
106
107
def test_block_diag_dtype():
    """cirq.block_diag's result dtype combines the dtypes of its blocks.

    With no blocks at all the result is complex128.
    """
    assert cirq.block_diag().dtype == np.complex128

    # (dtypes of the 1x1 input blocks, expected result dtype)
    cases = [
        ((np.int8,), np.int8),
        ((np.float32, np.float32), np.float32),
        ((np.float64, np.float64), np.float64),
        ((np.float32, np.float64), np.float64),
        ((np.float32, np.complex64), np.complex64),
        ((int, np.complex128), np.complex128),
    ]
    for input_dtypes, expected_dtype in cases:
        # Block k holds the single value k + 1 (i.e. 1, 2, ...).
        blocks = [np.array([[k + 1]], dtype=d) for k, d in enumerate(input_dtypes)]
        assert cirq.block_diag(*blocks).dtype == expected_dtype
139