1# Copyright 2019 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import pytest
16
17import numpy as np
18
19import cirq
20
21
class NoMethod:
    """Fixture that defines none of the mixture or unitary protocol methods."""
24
25
class ReturnsNotImplemented:
    """Fixture whose mixture protocol methods both answer `NotImplemented`."""

    def _has_mixture_(self):
        """Decline to say whether a mixture exists."""
        return NotImplemented

    def _mixture_(self):
        """Decline to provide a mixture."""
        return NotImplemented
32
33
class ReturnsValidTuple(cirq.SupportsMixture):
    """Fixture reporting a mixture whose probabilities (0.4, 0.6) sum to one."""

    def _has_mixture_(self):
        """Report that a mixture is available."""
        return True

    def _mixture_(self):
        """Return the two-component mixture as (probability, value) pairs."""
        first = (0.4, 'a')
        second = (0.6, 'b')
        return (first, second)
40
41
class ReturnsNonnormalizedTuple:
    """Fixture whose mixture probabilities (0.4 and 0.4) do not sum to one."""

    def _mixture_(self):
        """Return (probability, value) pairs with total probability 0.8."""
        probabilities = (0.4, 0.4)
        labels = ('a', 'b')
        return tuple(zip(probabilities, labels))
45
46
class ReturnsNegativeProbability:
    """Fixture whose mixture contains a negative probability (-0.4)."""

    def _mixture_(self):
        """Return (probability, value) pairs where the second weight is negative."""
        positive_part = (0.4, 'a')
        negative_part = (-0.4, 'b')
        return (positive_part, negative_part)
50
51
class ReturnsGreaterThanUnityProbability:
    """Fixture whose mixture contains a probability above one (1.2)."""

    def _mixture_(self):
        """Return (probability, value) pairs where the first weight exceeds 1."""
        oversized = (1.2, 'a')
        ordinary = (0.4, 'b')
        return (oversized, ordinary)
55
56
class ReturnsMixtureButNoHasMixture:
    """Fixture that defines `_mixture_` but deliberately omits `_has_mixture_`."""

    def _mixture_(self):
        """Return a two-component mixture with probabilities 0.4 and 0.6."""
        weights = (0.4, 0.6)
        outcomes = ('a', 'b')
        return tuple(zip(weights, outcomes))
60
61
class ReturnsUnitary:
    """Fixture exposing the unitary protocol methods instead of `_mixture_`."""

    def _has_unitary_(self):
        """Report that a matrix is available from `_unitary_`."""
        return True

    def _unitary_(self):
        """Return a 2x2 array filled with ones."""
        matrix_shape = (2, 2)
        return np.ones(shape=matrix_shape)
68
69
class ReturnsNotImplementedUnitary:
    """Fixture whose unitary protocol methods both answer `NotImplemented`."""

    def _has_unitary_(self):
        """Decline to say whether a unitary exists."""
        return NotImplemented

    def _unitary_(self):
        """Decline to provide a unitary."""
        return NotImplemented
76
77
@pytest.mark.parametrize(
    'val,mixture',
    (
        (ReturnsValidTuple(), ((0.4, 'a'), (0.6, 'b'))),
        (ReturnsNonnormalizedTuple(), ((0.4, 'a'), (0.4, 'b'))),
        (ReturnsUnitary(), ((1.0, np.ones((2, 2))),)),
    ),
)
def test_objects_with_mixture(val, mixture):
    """Check `cirq.mixture` yields the expected pairs, with or without a default.

    The value's own mixture must be returned both when no default is given
    and when an unrelated default is supplied.
    """
    want_probabilities, want_components = zip(*mixture)
    unused_default = ((0.3, 'a'), (0.7, 'b'))

    # First pass: no default argument. Second pass: a default that must be ignored.
    for extra_args in ((), (unused_default,)):
        got_probabilities, got_components = zip(*cirq.mixture(val, *extra_args))
        np.testing.assert_almost_equal(got_probabilities, want_probabilities)
        np.testing.assert_equal(got_components, want_components)
95
96
@pytest.mark.parametrize(
    'val', (NoMethod(), ReturnsNotImplemented(), ReturnsNotImplementedUnitary())
)
def test_objects_with_no_mixture(val):
    """Check `cirq.mixture` raises without a default and otherwise returns it."""
    with pytest.raises(TypeError, match="mixture"):
        _ = cirq.mixture(val)

    # Sentinel defaults must come back as the very same object.
    for sentinel in (None, NotImplemented):
        assert cirq.mixture(val, sentinel) is sentinel

    fallback = ((0.4, 'a'), (0.6, 'b'))
    returned = cirq.mixture(val, fallback)
    assert returned == fallback
107
108
def test_has_mixture():
    """Check `cirq.has_mixture` against each fixture's expected truthiness."""
    expectations = (
        (ReturnsValidTuple(), True),
        (ReturnsNotImplemented(), False),
        (ReturnsMixtureButNoHasMixture(), True),
        (ReturnsUnitary(), True),
        (ReturnsNotImplementedUnitary(), False),
    )
    for candidate, should_have_mixture in expectations:
        reported = bool(cirq.has_mixture(candidate))
        assert reported is should_have_mixture, type(candidate).__name__
115
116
def test_valid_mixture():
    """Check `cirq.validate_mixture` accepts a mixture with weights 0.4 and 0.6."""
    well_formed = ReturnsValidTuple()
    cirq.validate_mixture(well_formed)
119
120
@pytest.mark.parametrize(
    'val,message',
    [
        (ReturnsNonnormalizedTuple(), '1.0'),
        (ReturnsNegativeProbability(), 'less than 0'),
        (ReturnsGreaterThanUnityProbability(), 'greater than 1'),
    ],
)
def test_invalid_mixture(val, message):
    """Check `cirq.validate_mixture` raises ValueError matching `message`."""
    expect_rejection = pytest.raises(ValueError, match=message)
    with expect_rejection:
        cirq.validate_mixture(val)
132
133
def test_missing_mixture():
    """Check `cirq.validate_mixture` raises TypeError when `_mixture_` is absent.

    An instance of `NoMethod` is validated (not the class object itself),
    matching how the other tests in this module pass fixture instances to
    the mixture protocol functions.
    """
    with pytest.raises(TypeError, match='_mixture_'):
        cirq.validate_mixture(NoMethod())
137