1# Copyright 2018 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import pytest
16
17import numpy as np
18
19import cirq
20from cirq.testing.circuit_compare import (
21    _assert_apply_unitary_works_when_axes_transposed,
22)
23
24
def test_sensitive_to_phase():
    """A tiny Z rotation fails against the empty circuit at atol=0 but passes at atol=0.01."""
    check = cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent
    qubit = cirq.NamedQubit('q')

    # A circuit holding one empty moment matches the empty circuit exactly.
    check(actual=cirq.Circuit([cirq.Moment([])]), reference=cirq.Circuit(), atol=0)

    # Zero tolerance: the small phase must be reported.
    with pytest.raises(AssertionError):
        slightly_phased = cirq.Circuit([cirq.Moment([cirq.Z(qubit) ** 0.0001])])
        check(actual=slightly_phased, reference=cirq.Circuit(), atol=0)

    # Loose tolerance: the same small phase is accepted.
    slightly_phased = cirq.Circuit([cirq.Moment([cirq.Z(qubit) ** 0.0001])])
    check(actual=slightly_phased, reference=cirq.Circuit(), atol=0.01)
40
41
def test_sensitive_to_measurement_but_not_measured_phase():
    """Measurements are significant, but phase gates ahead of a measured qubit are not."""
    check = cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent

    def measured(*targets, before=None):
        # Circuit with an optional moment of `before` operations followed by
        # one moment that measures `targets` together.
        moments = [] if before is None else [cirq.Moment(before)]
        moments.append(cirq.Moment([cirq.measure(*targets)]))
        return cirq.Circuit(moments)

    q = cirq.NamedQubit('q')

    # A lone measurement is not equivalent to the empty circuit.
    with pytest.raises(AssertionError):
        check(actual=measured(q), reference=cirq.Circuit(), atol=1e-8)

    # A Z immediately before the measurement is accepted.
    check(actual=measured(q), reference=measured(q, before=[cirq.Z(q)]), atol=1e-8)

    a, b = cirq.LineQubit.range(2)

    check(actual=measured(a, b), reference=measured(a, b, before=[cirq.Z(a)]), atol=1e-8)

    check(actual=measured(a), reference=measured(a, before=[cirq.Z(a)]), atol=1e-8)

    check(
        actual=measured(a, b),
        reference=measured(a, b, before=[cirq.T(a), cirq.S(b)]),
        atol=1e-8,
    )

    # Same gates, but only `a` is measured: the S on `b` is now rejected.
    with pytest.raises(AssertionError):
        check(
            actual=measured(a),
            reference=measured(a, before=[cirq.T(a), cirq.S(b)]),
            atol=1e-8,
        )

    check(actual=measured(a, b), reference=measured(a, b, before=[cirq.CZ(a, b)]), atol=1e-8)
118
119
def test_sensitive_to_measurement_toggle():
    """An X or an inverted measurement alone is detected; both together are accepted."""
    check = cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent
    q = cirq.NamedQubit('q')

    def plain_measurement():
        # Fresh single-moment circuit that just measures `q`.
        return cirq.Circuit([cirq.Moment([cirq.measure(q)])])

    # Flipping the qubit before measuring changes the outcome.
    with pytest.raises(AssertionError):
        flipped = cirq.Circuit(
            [
                cirq.Moment([cirq.X(q)]),
                cirq.Moment([cirq.measure(q)]),
            ]
        )
        check(actual=plain_measurement(), reference=flipped, atol=1e-8)

    # Inverting the reported result changes the outcome as well.
    with pytest.raises(AssertionError):
        inverted = cirq.Circuit([cirq.Moment([cirq.measure(q, invert_mask=(True,))])])
        check(actual=plain_measurement(), reference=inverted, atol=1e-8)

    # Flip plus inversion together is accepted as equivalent.
    flipped_and_inverted = cirq.Circuit(
        [
            cirq.Moment([cirq.X(q)]),
            cirq.Moment([cirq.measure(q, invert_mask=(True,))]),
        ]
    )
    check(actual=plain_measurement(), reference=flipped_and_inverted, atol=1e-8)
156
157
def test_measuring_qubits():
    """Which qubits are measured matters; how the measurements are grouped or ordered does not."""
    check = cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent
    a, b = cirq.LineQubit.range(2)

    # Measuring different qubits is reported as a difference.
    with pytest.raises(AssertionError):
        only_a = cirq.Circuit([cirq.Moment([cirq.measure(a)])])
        only_b = cirq.Circuit([cirq.Moment([cirq.measure(b)])])
        check(actual=only_a, reference=only_b, atol=1e-8)

    # Reordering the measured qubits, with the invert mask reordered to match.
    a_then_b = cirq.Circuit([cirq.Moment([cirq.measure(a, b, invert_mask=(True,))])])
    b_then_a = cirq.Circuit([cirq.Moment([cirq.measure(b, a, invert_mask=(False, True))])])
    check(actual=a_then_b, reference=b_then_a, atol=1e-8)

    # Two separate measurements versus one joint measurement.
    separate = cirq.Circuit(
        [
            cirq.Moment([cirq.measure(a)]),
            cirq.Moment([cirq.measure(b)]),
        ]
    )
    joint = cirq.Circuit([cirq.Moment([cirq.measure(a, b)])])
    check(actual=separate, reference=joint, atol=1e-8)
184
185
@pytest.mark.parametrize(
    'circuit', [cirq.testing.random_circuit(cirq.LineQubit.range(2), 4, 0.5) for _ in range(5)]
)
def test_random_same_matrix(circuit):
    """A random circuit is equivalent to a single MatrixGate built from its own unitary."""
    check = cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent
    q0, q1 = cirq.LineQubit.range(2)

    # Collapse the whole circuit into one two-qubit matrix gate.
    matrix = circuit.unitary(qubits_that_should_be_present=[q0, q1])
    collapsed = cirq.Circuit(cirq.MatrixGate(matrix).on(q0, q1))

    check(actual=circuit, reference=collapsed, atol=1e-8)

    # Appending the same terminal measurement to both keeps them equivalent.
    circuit.append(cirq.measure(q0))
    collapsed.append(cirq.measure(q0))
    check(actual=circuit, reference=collapsed, atol=1e-8)
200
201
def test_correct_qubit_ordering():
    """A Z on the measured qubit may be dropped; a Z on the unmeasured qubit may not."""
    check = cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent
    a, b = cirq.LineQubit.range(2)

    # Dropping Z(b), where `b` is the measured qubit, is accepted.
    both_phased = cirq.Circuit(cirq.Z(a), cirq.Z(b), cirq.measure(b))
    without_z_on_b = cirq.Circuit(cirq.Z(a), cirq.measure(b))
    check(actual=both_phased, reference=without_z_on_b, atol=1e-8)

    # Dropping Z(a), where `a` is not measured, is rejected.
    with pytest.raises(AssertionError):
        both_phased = cirq.Circuit(cirq.Z(a), cirq.Z(b), cirq.measure(b))
        without_z_on_a = cirq.Circuit(cirq.Z(b), cirq.measure(b))
        check(actual=both_phased, reference=without_z_on_a, atol=1e-8)
216
217
def test_known_old_failure():
    """Fixed-parameter case: extra Z powers before a joint measurement are still equivalent."""
    a, b = cirq.LineQubit.range(2)
    phased_x = cirq.PhasedXPowGate(exponent=0.61351656, phase_exponent=0.8034575038876517)

    observed = cirq.Circuit(
        phased_x.on(b),
        cirq.measure(a, b),
    )
    # Same circuit with additional Z-power gates on both measured qubits.
    expected = cirq.Circuit(
        phased_x.on(b),
        cirq.Z(a) ** 0.5,
        cirq.Z(b) ** 0.1,
        cirq.measure(a, b),
    )

    cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(
        actual=observed, reference=expected, atol=1e-8
    )
233
234
def test_assert_same_circuits():
    """assert_same_circuits accepts identical circuits and reports the first differing moment."""
    same = cirq.testing.assert_same_circuits
    a, b = cirq.LineQubit.range(2)

    same(cirq.Circuit(cirq.H(a)), cirq.Circuit(cirq.H(a)))

    # One moment versus none: the difference is reported at moment 0.
    with pytest.raises(AssertionError) as caught:
        same(cirq.Circuit(cirq.H(a)), cirq.Circuit())
    message = caught.value.args[0]
    assert 'differing moment:\n0\n' in message

    # Identical first moment, different second moment: reported at moment 1.
    with pytest.raises(AssertionError) as caught:
        same(
            cirq.Circuit(cirq.H(a), cirq.H(a)),
            cirq.Circuit(cirq.H(a), cirq.CZ(a, b)),
        )
    message = caught.value.args[0]
    assert 'differing moment:\n1\n' in message

    # CNOT and ControlledGate(X) are rejected as not the same circuit.
    with pytest.raises(AssertionError):
        same(
            cirq.Circuit(cirq.CNOT(a, b)),
            cirq.Circuit(cirq.ControlledGate(cirq.X).on(a, b)),
        )
262
263
def test_assert_has_diagram():
    """Checks assert_has_diagram on a matching diagram and on a mismatching one.

    The first call must pass for the CNOT circuit. The second call supplies a
    diagram whose target is drawn as ``Z`` instead of ``X`` and must raise an
    AssertionError whose message contains the actual diagram, the desired
    diagram, and a highlighted diff.

    The multi-line literals are intentionally flush-left; each one includes
    the vertical connector row drawn between the control and the target.
    """
    a, b = cirq.LineQubit.range(2)
    circuit = cirq.Circuit(cirq.CNOT(a, b))

    # Matching diagram: control on qubit 0, connector row, target on qubit 1.
    cirq.testing.assert_has_diagram(
        circuit,
        """
0: ───@───
      │
1: ───X───
""",
    )

    # Full text expected inside the assertion message for the mismatch below.
    expected_error = """Circuit's text diagram differs from the desired diagram.

Diagram of actual circuit:
0: ───@───
      │
1: ───X───

Desired text diagram:
0: ───@───
      │
1: ───Z───

Highlighted differences:
0: ───@───
      │
1: ───█───

"""

    # Mismatching diagram: only the target symbol differs (Z instead of X).
    with pytest.raises(AssertionError) as ex_info:
        cirq.testing.assert_has_diagram(
            circuit,
            """
0: ───@───
      │
1: ───Z───
""",
        )
    assert expected_error in ex_info.value.args[0]
305
306
def test_assert_has_consistent_apply_unitary():
    """Runs assert_has_consistent_apply_unitary against passing and failing implementations."""
    check = cirq.testing.assert_has_consistent_apply_unitary

    def swap_zero_and_one(args: cirq.ApplyUnitaryArgs) -> np.ndarray:
        # Copies the |1> slice of the target into the |0> slice of the buffer
        # and vice versa, then hands back the buffer.
        zero = args.subspace_index(0)
        one = args.subspace_index(1)
        args.available_buffer[zero] = args.target_tensor[one]
        args.available_buffer[one] = args.target_tensor[zero]
        return args.available_buffer

    # Returns the buffer without ever writing to it.
    class IdentityReturningUnalteredWorkspace:
        def _num_qubits_(self):
            return 1

        def _unitary_(self):
            return np.eye(2)

        def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs) -> np.ndarray:
            return args.available_buffer

    with pytest.raises(AssertionError):
        check(IdentityReturningUnalteredWorkspace())

    # Applies a swap while declaring the identity as its unitary.
    class DifferentEffect:
        def _num_qubits_(self):
            return 1

        def _unitary_(self):
            return np.eye(2, dtype=np.complex128)

        def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs) -> np.ndarray:
            return swap_zero_and_one(args)

    with pytest.raises(AssertionError):
        check(DifferentEffect())

    # Always indexes the leading axis instead of using args.subspace_index.
    class IgnoreAxisEffect:
        def _num_qubits_(self):
            return 1

        def _unitary_(self):
            return np.array([[0, 1], [1, 0]])

        def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs) -> np.ndarray:
            if args.target_tensor.shape[0] > 1:
                args.available_buffer[0] = args.target_tensor[1]
                args.available_buffer[1] = args.target_tensor[0]
            return args.available_buffer

    with pytest.raises(AssertionError, match='Not equal|acted differently'):
        check(IgnoreAxisEffect())

    # Applies a swap and declares the matching swap matrix.
    class SameEffect:
        def _num_qubits_(self):
            return 1

        def _unitary_(self):
            return np.array([[0, 1], [1, 0]])

        def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs) -> np.ndarray:
            return swap_zero_and_one(args)

    check(SameEffect())

    # Three-level cyclic shift written through the trailing axis.
    class SameQuditEffect:
        def _qid_shape_(self):
            return (3,)

        def _unitary_(self):
            return np.array([[0, 0, 1], [1, 0, 0], [0, 1, 0]])

        def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs) -> np.ndarray:
            source = args.target_tensor
            destination = args.available_buffer
            destination[..., 0] = source[..., 2]
            destination[..., 1] = source[..., 0]
            destination[..., 2] = source[..., 1]
            return destination

    check(SameQuditEffect())

    # The declared unitary ignores `power`, while _apply_unitary_ scales by it.
    class BadExponent:
        def __init__(self, power):
            self.power = power

        def __pow__(self, power):
            return BadExponent(self.power * power)

        def _num_qubits_(self):
            return 1

        def _unitary_(self):
            return np.array([[1, 0], [0, 2]])

        def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs) -> np.ndarray:
            one = args.subspace_index(1)
            args.target_tensor[one] *= self.power * 2
            return args.target_tensor

    check(BadExponent(1))

    with pytest.raises(AssertionError):
        cirq.testing.assert_has_consistent_apply_unitary_for_various_exponents(
            BadExponent(1), exponents=[1, 2]
        )

    # Has _apply_unitary_ but no _unitary_; the check passes.
    class EffectWithoutUnitary:
        def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs) -> np.ndarray:
            return args.target_tensor

        def _num_qubits_(self):
            return 1

    check(EffectWithoutUnitary())

    # Declines via NotImplemented; the check passes.
    class NoEffect:
        def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs) -> np.ndarray:
            return NotImplemented

        def _num_qubits_(self):
            return 1

    check(NoEffect())

    # Provides neither _num_qubits_ nor _qid_shape_.
    class UnknownCountEffect:
        pass

    with pytest.raises(TypeError, match="no _num_qubits_ or _qid_shape_"):
        check(UnknownCountEffect())

    # Built-in gate, then the same gate applied to a qubit.
    check(cirq.X)

    check(cirq.X.on(cirq.NamedQubit('q')))
437
438
def test_assert_has_consistent_qid_shape():
    """Runs assert_has_consistent_qid_shape over consistent and inconsistent gates and operations."""
    check = cirq.testing.assert_has_consistent_qid_shape

    class ConsistentGate(cirq.Gate):
        def _qid_shape_(self):
            return (1, 2, 3, 4)

        def _num_qubits_(self):
            return 4

    # Declares 2 qubits but a shape with 4 entries.
    class InconsistentGate(cirq.Gate):
        def _qid_shape_(self):
            return (1, 2, 3, 4)

        def _num_qubits_(self):
            return 2

    # The shape contains a zero dimension.
    class BadShapeGate(cirq.Gate):
        def _qid_shape_(self):
            return (1, 2, 0, 4)

        def _num_qubits_(self):
            return 4

    class ConsistentOp(cirq.Operation):
        @property
        def qubits(self):
            return cirq.LineQubit.range(4)

        def with_qubits(self, *qubits):
            raise NotImplementedError  # coverage: ignore

        def _qid_shape_(self):
            return 1, 2, 3, 4

        def _num_qubits_(self):
            return 4

    # The 'coverage: ignore' comments in the InconsistentOp classes are needed
    # because assert_has_consistent_qid_shape may only need to check two of
    # the three methods before finding an inconsistency and raising an error.

    # `qubits` has length 2; the other two methods say 4.
    class InconsistentOp1(cirq.Operation):
        @property
        def qubits(self):
            return cirq.LineQubit.range(2)

        def with_qubits(self, *qubits):
            raise NotImplementedError  # coverage: ignore

        def _qid_shape_(self):
            return (1, 2, 3, 4)  # coverage: ignore

        def _num_qubits_(self):
            return 4  # coverage: ignore

    # `_num_qubits_` says 2; the other two methods say 4.
    class InconsistentOp2(cirq.Operation):
        @property
        def qubits(self):
            return cirq.LineQubit.range(4)  # coverage: ignore

        def with_qubits(self, *qubits):
            raise NotImplementedError  # coverage: ignore

        def _qid_shape_(self):
            return (1, 2, 3, 4)  # coverage: ignore

        def _num_qubits_(self):
            return 2

    # `_qid_shape_` has length 2; the other two methods say 4.
    class InconsistentOp3(cirq.Operation):
        @property
        def qubits(self):
            return cirq.LineQubit.range(4)  # coverage: ignore

        def with_qubits(self, *qubits):
            raise NotImplementedError  # coverage: ignore

        def _qid_shape_(self):
            return (1, 2)

        def _num_qubits_(self):
            return 4  # coverage: ignore

    # Implements none of the relevant protocol methods.
    class NoProtocol:
        pass

    check(ConsistentGate())
    with pytest.raises(AssertionError, match='disagree'):
        check(InconsistentGate())
    with pytest.raises(AssertionError, match='positive'):
        check(BadShapeGate())
    check(ConsistentOp())
    for inconsistent_op_type in (InconsistentOp1, InconsistentOp2, InconsistentOp3):
        with pytest.raises(AssertionError, match='disagree'):
            check(inconsistent_op_type())
    check(NoProtocol())
536
537
def test_assert_apply_unitary_works_when_axes_transposed_failure():
    """An op that relies on reshape returning a view must fail the transposed-axes check."""

    class BadOp:
        def _num_qubits_(self):
            return 2

        def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs):
            # Bring the two target axes to the front and flatten them into one
            # axis of length 4.
            first, second = args.axes
            tensor = args.target_tensor
            others = [axis for axis in range(len(tensor.shape)) if axis not in (first, second)]
            total = tensor.size
            flat = tensor.transpose([first, second, *others])
            flat = flat.reshape((4, total // 4))  # Oops. Reshape might copy.

            # Multiply the four rows by 1, 1j, -1 and -1j respectively.
            flat[1, ...] *= 1j
            flat[2, ...] *= -1
            flat[3, ...] *= -1j
            return args.target_tensor

    faulty = BadOp()
    assert cirq.has_unitary(faulty)

    # The plain unitary comes out as expected.
    np.testing.assert_allclose(cirq.unitary(faulty), np.diag([1, 1j, -1, -1j]))

    # The stricter check picks axis orders at random, so repeat until one fails.
    with pytest.raises(AssertionError, match='acted differently on out-of-order axes'):
        for _ in range(100):
            _assert_apply_unitary_works_when_axes_transposed(faulty)
568