# Copyright 2018 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import cirq


def test_trace_distance_bound():
    """Exercise `cirq.trace_distance_bound` on real gates and protocol stubs.

    Three situations are covered:

    * A `cirq.MatrixGate` built from a gate's unitary reports (numerically)
      the same bound as the gate it was built from.
    * Objects without a usable `_trace_distance_bound_` (no method at all, or
      one returning `NotImplemented`) get a bound of 1.0.
    * A bound returned by `_trace_distance_bound_` is reported unchanged when
      it is at most 1.0, and reported as 1.0 when it is larger.
    """

    # --- Stub objects implementing (or not) the _trace_distance_bound_ hook.

    class NoProtocol:
        pass

    class DeclinesProtocol:
        def _trace_distance_bound_(self):
            return NotImplemented

    class ReturnsTwo:
        def _trace_distance_bound_(self) -> float:
            return 2.0

    class FixedBound:
        def __init__(self, bound):
            self.bound = bound

        def _trace_distance_bound_(self) -> float:
            return self.bound

    # --- Matrix-only gates should agree with the gates they were built from.
    reference_gates = [cirq.X, cirq.CX, cirq.CX ** 0.5]
    matrix_gates = [cirq.MatrixGate(cirq.unitary(gate)) for gate in reference_gates]
    for matrix_gate, reference_gate in zip(matrix_gates, reference_gates):
        assert np.isclose(
            cirq.trace_distance_bound(matrix_gate),
            cirq.trace_distance_bound(reference_gate),
        )

    # --- (object, expected bound) pairs, checked in order. Missing or
    # NotImplemented hooks fall back to 1.0; anything above 1.0 is reported
    # as 1.0; values at or below 1.0 come back exactly as returned.
    expectations = (
        (NoProtocol(), 1.0),
        (DeclinesProtocol(), 1.0),
        (ReturnsTwo(), 1.0),
        (FixedBound(0.1), 0.1),
        (FixedBound(0.5), 0.5),
        (FixedBound(1.0), 1.0),
        (FixedBound(2.0), 1.0),
    )
    for subject, expected in expectations:
        assert cirq.trace_distance_bound(subject) == expected