1# Copyright 2018 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import numpy as np
16import cirq
17
18
def test_trace_distance_bound():
    """Exercise ``cirq.trace_distance_bound`` on real gates and protocol stubs.

    Three properties are checked:
      * a ``cirq.MatrixGate`` built from a gate's unitary reports, up to float
        tolerance (``np.isclose``), the same bound as the gate it came from;
      * objects with no ``_trace_distance_bound_`` method, or whose method
        returns ``NotImplemented``, yield 1.0;
      * bounds supplied by ``_trace_distance_bound_`` are returned as-is when
        they are at most 1.0, and come back as 1.0 when they exceed it.
    """

    class NoMethod:
        pass

    class ReturnsNotImplemented:
        def _trace_distance_bound_(self):
            return NotImplemented

    class ReturnsTwo:
        def _trace_distance_bound_(self) -> float:
            return 2.0

    class ReturnsConstant:
        # Stub whose reported bound is whatever value it was constructed with.
        def __init__(self, bound):
            self.bound = bound

        def _trace_distance_bound_(self) -> float:
            return self.bound

    # Build every (matrix-based gate, original gate) pair up front, then
    # compare the bound each one reports.
    native_gates = (cirq.X, cirq.CX, cirq.CX ** 0.5)
    gate_pairs = [(cirq.MatrixGate(cirq.unitary(gate)), gate) for gate in native_gates]
    for matrix_gate, native_gate in gate_pairs:
        assert np.isclose(
            cirq.trace_distance_bound(matrix_gate),
            cirq.trace_distance_bound(native_gate),
        )

    # Stub objects paired with the exact value trace_distance_bound must give.
    expected_bounds = [
        (NoMethod(), 1.0),
        (ReturnsNotImplemented(), 1.0),
        (ReturnsTwo(), 1.0),
        (ReturnsConstant(0.1), 0.1),
        (ReturnsConstant(0.5), 0.5),
        (ReturnsConstant(1.0), 1.0),
        (ReturnsConstant(2.0), 1.0),
    ]
    for stub, expected in expected_bounds:
        assert cirq.trace_distance_bound(stub) == expected
52