1# Copyright 2018 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import pytest
16
17import cirq
18
19
class NoMethod:
    """Plain object with no ``__pow__``, so it has no way to be inverted."""
22
23
class ReturnsNotImplemented:
    """Defines ``__pow__`` but declines every exponent."""

    def __pow__(self, exponent):
        del exponent  # Ignored: no exponent is ever supported.
        return NotImplemented
27
28
class ReturnsFive:
    """Raising this object to any power yields the constant ``5``."""

    def __pow__(self, exponent) -> int:
        del exponent  # Ignored: the result does not depend on it.
        return 5
32
33
class SelfInverse:
    """Any power of this object (including ``-1``) is the object itself."""

    def __pow__(self, exponent) -> 'SelfInverse':
        del exponent  # Ignored: every power maps back to ``self``.
        return self
37
38
class ImplementsReversible:
    """Supports only the exponent ``-1``, whose result is ``6``."""

    def __pow__(self, exponent):
        if exponent == -1:
            return 6
        # Every other exponent is declined so callers can fall back.
        return NotImplemented
42
43
class IsIterable:
    """Custom (non-sequence) iterable that produces ``1`` and then ``2``."""

    def __iter__(self):
        yield from (1, 2)
48
49
@pytest.mark.parametrize(
    'val',
    [
        NoMethod(),
        'text',
        object(),
        ReturnsNotImplemented(),
        [NoMethod(), 5],
    ],
)
def test_objects_with_no_inverse(val):
    """Non-invertible values raise without a default and return it otherwise."""
    with pytest.raises(TypeError, match="isn't invertible"):
        cirq.inverse(val)
    # These sentinels must be handed back as the exact same object.
    for sentinel in (None, NotImplemented):
        assert cirq.inverse(val, sentinel) is sentinel
    assert cirq.inverse(val, 5) == 5
66
67
@pytest.mark.parametrize(
    ('val', 'inv'),
    [
        (ReturnsFive(), 5),
        (ImplementsReversible(), 6),
        (SelfInverse(),) * 2,
        (1, 1),
        (2, 0.5),
        (1j, -1j),
        ((), ()),
        ([], ()),
        ((2,), (0.5,)),
        ((1, 2), (0.5, 1)),
        ((2, (4, 8)), ((0.125, 0.25), 0.5)),
        ((2, [4, 8]), ((0.125, 0.25), 0.5)),
        (IsIterable(), (0.5, 1)),
    ],
)
def test_objects_with_inverse(val, inv):
    """Invertible values yield ``inv`` whether or not a default is supplied."""
    without_default = cirq.inverse(val)
    assert without_default == inv
    # Supplying a default must not change the result for invertible values.
    with_default = cirq.inverse(val, 0)
    assert with_default == inv
89