1# Copyright 2018 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import numpy as np
16import pytest
17
18import cirq
19
20
class NoMethod:
    """Plain object that defines none of the protocol magic methods."""
23
24
class ReturnsNotImplemented:
    """Defines ``_pauli_expansion_`` but always answers ``NotImplemented``."""

    def _pauli_expansion_(self):
        # NotImplemented is the protocol's "I can't provide this" sentinel.
        unavailable = NotImplemented
        return unavailable
28
29
class ReturnsExpansion:
    """Reports a fixed Pauli expansion supplied at construction time."""

    def __init__(self, expansion: cirq.LinearDict[str]) -> None:
        # Kept as-is (no copy); _pauli_expansion_ hands back this same object.
        self._expansion = expansion

    def _pauli_expansion_(self) -> cirq.LinearDict[str]:
        """Return the expansion passed to the constructor."""
        stored = self._expansion
        return stored
36
37
class HasUnitary:
    """Exposes a caller-supplied matrix through ``_unitary_``.

    The matrix is not validated, so non-unitary arrays are accepted too.
    """

    def __init__(self, unitary: np.ndarray) -> None:
        self._unitary = unitary

    def _unitary_(self) -> np.ndarray:
        """Return the matrix given to the constructor (same object, not a copy)."""
        matrix = self._unitary
        return matrix
44
45
class HasQuditUnitary:
    """Object with a qid shape of ``(3,)`` whose ``_unitary_`` always raises."""

    def _qid_shape_(self):
        # A single qid of dimension 3 rather than a qubit.
        shape = (3,)
        return shape

    def _unitary_(self) -> np.ndarray:
        # Deliberately unusable; the matrix is never provided.
        raise NotImplementedError
52
53
@pytest.mark.parametrize(
    'val',
    [
        NoMethod(),
        ReturnsNotImplemented(),
        HasQuditUnitary(),
        123,
        np.eye(2),
        object(),
        cirq,
    ],
)
def test_raises_no_pauli_expansion(val):
    """Values lacking a Pauli expansion yield the default, or raise without one."""
    # An explicit default is returned instead of raising.
    fallback = cirq.pauli_expansion(val, default=None)
    assert fallback is None
    # Without a default, the protocol raises a descriptive TypeError.
    with pytest.raises(TypeError, match='No Pauli expansion'):
        cirq.pauli_expansion(val)
70
71
@pytest.mark.parametrize(
    'val, expected_expansion',
    [
        (
            ReturnsExpansion(cirq.LinearDict({'X': 1, 'Y': 2, 'Z': 3})),
            cirq.LinearDict({'X': 1, 'Y': 2, 'Z': 3}),
        ),
        (HasUnitary(np.eye(2)), cirq.LinearDict({'I': 1})),
        (HasUnitary(np.array([[1, -1j], [1j, -1]])), cirq.LinearDict({'Y': 1, 'Z': 1})),
        (HasUnitary(np.array([[0.0, 1.0], [0.0, 0.0]])), cirq.LinearDict({'X': 0.5, 'Y': 0.5j})),
        (HasUnitary(np.eye(16)), cirq.LinearDict({'IIII': 1.0})),
        (cirq.H, cirq.LinearDict({'X': np.sqrt(0.5), 'Z': np.sqrt(0.5)})),
        (
            cirq.ry(np.pi / 2),
            cirq.LinearDict({'I': np.cos(np.pi / 4), 'Y': -1j * np.sin(np.pi / 4)}),
        ),
    ],
)
def test_pauli_expansion(val, expected_expansion):
    """The computed Pauli expansion matches the expected one term by term."""
    tolerance = 1e-12
    computed = cirq.pauli_expansion(val)

    # Whole-dictionary approximate comparison.
    assert cirq.approx_eq(computed, expected_expansion, atol=tolerance)

    # Both sides must name exactly the same Pauli strings...
    assert set(computed.keys()) == set(expected_expansion.keys())

    # ...and every coefficient must agree to within the tolerance.
    for pauli_string in computed.keys():
        difference = computed[pauli_string] - expected_expansion[pauli_string]
        assert np.abs(difference) < tolerance
96