# Copyright 2019 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

import numpy as np

import cirq


class NoMethod:
    """Implements none of the mixture or unitary protocol methods."""


class ReturnsNotImplemented:
    """Defers both mixture protocol methods by returning NotImplemented."""

    def _mixture_(self):
        return NotImplemented

    def _has_mixture_(self):
        return NotImplemented


class ReturnsValidTuple(cirq.SupportsMixture):
    """Returns a mixture whose probabilities sum to 1 and reports having one."""

    def _mixture_(self):
        return ((0.4, 'a'), (0.6, 'b'))

    def _has_mixture_(self):
        return True


class ReturnsNonnormalizedTuple:
    """Returns a mixture whose probabilities sum to 0.8 instead of 1."""

    def _mixture_(self):
        return ((0.4, 'a'), (0.4, 'b'))


class ReturnsNegativeProbability:
    """Returns a mixture that contains a negative probability."""

    def _mixture_(self):
        return ((0.4, 'a'), (-0.4, 'b'))


class ReturnsGreaterThanUnityProbability:
    """Returns a mixture that contains a probability greater than 1."""

    def _mixture_(self):
        return ((1.2, 'a'), (0.4, 'b'))


class ReturnsMixtureButNoHasMixture:
    """Defines `_mixture_` only; `cirq.has_mixture` is expected to detect it."""

    def _mixture_(self):
        return ((0.4, 'a'), (0.6, 'b'))


class ReturnsUnitary:
    """Supports only the unitary protocol.

    The tests expect `cirq.mixture` to treat it as a single-element mixture
    with probability 1.0 whose value is the returned matrix. (The matrix is
    np.ones((2, 2)), which is not actually unitary; it is only used as a
    recognizable value to compare against.)
    """

    def _unitary_(self):
        return np.ones((2, 2))

    def _has_unitary_(self):
        return True


class ReturnsNotImplementedUnitary:
    """Defers both unitary protocol methods by returning NotImplemented."""

    def _unitary_(self):
        return NotImplemented

    def _has_unitary_(self):
        return NotImplemented


@pytest.mark.parametrize(
    'val,mixture',
    (
        (ReturnsValidTuple(), ((0.4, 'a'), (0.6, 'b'))),
        (ReturnsNonnormalizedTuple(), ((0.4, 'a'), (0.4, 'b'))),
        (ReturnsUnitary(), ((1.0, np.ones((2, 2))),)),
    ),
)
def test_objects_with_mixture(val, mixture):
    """cirq.mixture returns the object's (probability, value) pairs."""
    expected_probs, expected_values = zip(*mixture)

    probs, values = zip(*cirq.mixture(val))
    np.testing.assert_almost_equal(probs, expected_probs)
    np.testing.assert_equal(values, expected_values)

    # A supplied default must be ignored when the object provides a mixture.
    probs, values = zip(*cirq.mixture(val, ((0.3, 'a'), (0.7, 'b'))))
    np.testing.assert_almost_equal(probs, expected_probs)
    np.testing.assert_equal(values, expected_values)


@pytest.mark.parametrize(
    'val', (NoMethod(), ReturnsNotImplemented(), ReturnsNotImplementedUnitary())
)
def test_objects_with_no_mixture(val):
    """Without a mixture, cirq.mixture raises unless a default is given."""
    with pytest.raises(TypeError, match="mixture"):
        _ = cirq.mixture(val)
    # Whatever default is passed (including None / NotImplemented) is returned.
    assert cirq.mixture(val, None) is None
    assert cirq.mixture(val, NotImplemented) is NotImplemented
    default = ((0.4, 'a'), (0.6, 'b'))
    assert cirq.mixture(val, default) == default


def test_has_mixture():
    """cirq.has_mixture honors _has_mixture_, _mixture_ and the unitary protocol."""
    assert cirq.has_mixture(ReturnsValidTuple())
    assert not cirq.has_mixture(ReturnsNotImplemented())
    assert cirq.has_mixture(ReturnsMixtureButNoHasMixture())
    assert cirq.has_mixture(ReturnsUnitary())
    assert not cirq.has_mixture(ReturnsNotImplementedUnitary())


def test_valid_mixture():
    """A normalized mixture with probabilities in [0, 1] passes validation."""
    cirq.validate_mixture(ReturnsValidTuple())


@pytest.mark.parametrize(
    'val,message',
    (
        (ReturnsNonnormalizedTuple(), '1.0'),
        (ReturnsNegativeProbability(), 'less than 0'),
        (ReturnsGreaterThanUnityProbability(), 'greater than 1'),
    ),
)
def test_invalid_mixture(val, message):
    """Invalid mixtures are rejected with a ValueError naming the problem."""
    with pytest.raises(ValueError, match=message):
        cirq.validate_mixture(val)


def test_missing_mixture():
    """Validating an object with no _mixture_ method raises TypeError."""
    # Pass an instance, as every other test does, rather than the class object.
    with pytest.raises(TypeError, match='_mixture_'):
        cirq.validate_mixture(NoMethod())