1# Copyright 2018 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from typing import AbstractSet, Iterator, Any
16
17import pytest
18import numpy as np
19import sympy
20
21import cirq
22
23
class ValidQubit(cirq.Qid):
    """Minimal two-level `cirq.Qid` used as a test fixture.

    Equality and ordering come from `_comparison_key`, which is simply the
    name given to the constructor.
    """

    def __init__(self, name):
        self._name = name

    @property
    def dimension(self):
        # Always a qubit.
        return 2

    def _comparison_key(self):
        return self._name

    def __repr__(self):
        return 'ValidQubit({!r})'.format(self._name)

    def __str__(self):
        return 'TQ_' + str(self._name)
40
41
class ValidQid(cirq.Qid):
    """Test `cirq.Qid` whose dimension is chosen by the caller.

    The dimension is checked with `validate_dimension` on construction, so
    `with_dimension` with an invalid value raises as well.
    """

    def __init__(self, name, dimension):
        self.validate_dimension(dimension)
        self._name = name
        self._dimension = dimension

    @property
    def dimension(self):
        return self._dimension

    def with_dimension(self, dimension):
        # Same name, new dimension; the constructor re-validates it.
        return ValidQid(name=self._name, dimension=dimension)

    def _comparison_key(self):
        return self._name
57
58
def test_wrapped_qid():
    """`with_dimension` wraps a qubit, and returning to dimension 2 unwraps it."""
    a = ValidQubit('a')
    b = ValidQubit('b')

    # Wrapping and unwrapping.
    assert type(a.with_dimension(3)) is not ValidQubit
    assert type(a.with_dimension(2)) is ValidQubit
    assert type(a.with_dimension(5).with_dimension(2)) is ValidQubit
    assert a.with_dimension(3).with_dimension(4) == a.with_dimension(4)
    assert a.with_dimension(3).qubit == a

    # Equality and ordering of wrapped qids.
    assert a.with_dimension(3) == a.with_dimension(3)
    assert a.with_dimension(3) < a.with_dimension(4)
    assert a.with_dimension(3) < b.with_dimension(3)
    assert a.with_dimension(4) < b.with_dimension(3)

    a3 = a.with_dimension(3)
    cirq.testing.assert_equivalent_repr(a3, global_vals={'ValidQubit': ValidQubit})
    assert str(a3) == 'TQ_a (d=3)'

    expected_json = {
        'cirq_type': '_QubitAsQid',
        'qubit': ValidQubit('zz'),
        'dimension': 3,
    }
    assert ValidQubit('zz').with_dimension(3)._json_dict_() == expected_json
80
81
def test_qid_dimension():
    """Qid dimensions are reported correctly and must be positive."""
    assert ValidQubit('a').dimension == 2
    assert ValidQubit('a').with_dimension(3).dimension == 3
    for bad_dimension in (0, -3):
        with pytest.raises(ValueError, match='Wrong qid dimension'):
            _ = ValidQubit('a').with_dimension(bad_dimension)

    qid = ValidQid('a', 3)
    assert qid.dimension == 3
    assert qid.with_dimension(2).dimension == 2
    assert qid.with_dimension(4) == ValidQid('a', 4)
    for bad_dimension in (0, -3):
        with pytest.raises(ValueError, match='Wrong qid dimension'):
            _ = qid.with_dimension(bad_dimension)
97
98
class ValiGate(cirq.Gate):
    """Two-qubit test gate whose argument validation is skipped for one qubit."""

    def _num_qubits_(self):
        return 2

    def validate_args(self, qubits):
        # A single qubit deliberately bypasses validation so tests can build
        # an operation with the "wrong" number of qubits.
        if len(qubits) != 1:
            super().validate_args(qubits)
107
108
def test_gate():
    """`Gate.on` and `Gate.__call__` validate qubit count and shape."""
    q0, q1, q2 = cirq.LineQubit.range(3)

    gate = ValiGate()
    assert cirq.num_qubits(gate) == 2

    _ = gate.on(q0, q2)
    with pytest.raises(ValueError, match='Wrong number'):
        _ = gate.on(q0, q2, q1)

    _ = gate(q0)  # ValiGate.validate_args skips the check for one qubit.
    _ = gate(q0, q2)
    with pytest.raises(ValueError, match='Wrong number'):
        _ = gate(q2, q1, q0)
    with pytest.raises(ValueError, match='Wrong shape'):
        _ = gate(q0, q1.with_dimension(3))

    # Zero controls leaves the gate untouched.
    assert gate.controlled(0) is gate
127
128
def test_op():
    """`controlled_by` is a no-op without controls and wraps the op otherwise."""
    q0, q1, q2, q3 = cirq.LineQubit.range(4)
    op = ValiGate()(q0, q1)
    assert op.controlled_by() is op

    controlled = op.controlled_by(q2, q3)
    assert controlled.sub_operation == op
    assert controlled.controls == (q2, q3)
137
138
def test_op_validate():
    """`Operation.validate_args` checks shape, qid count and duplicates."""
    single = cirq.X(cirq.LineQid(0, 2))
    pair = cirq.CNOT(*cirq.LineQid.range(2, dimension=2))

    # Valid: same number of qids with the same dimensions.
    single.validate_args([cirq.LineQid(1, 2)])
    pair.validate_args(cirq.LineQid.range(1, 3, dimension=2))

    with pytest.raises(ValueError, match='Wrong shape'):
        single.validate_args([cirq.LineQid(1, 9)])
    with pytest.raises(ValueError, match='Wrong number'):
        single.validate_args([cirq.LineQid(1, 2), cirq.LineQid(2, 2)])
    with pytest.raises(ValueError, match='Duplicate'):
        pair.validate_args([cirq.LineQid(1, 2), cirq.LineQid(1, 2)])
150
151
def test_default_validation_and_inverse():
    """Gates get qubit-count validation and a decomposition-based inverse."""

    class TestGate(cirq.Gate):
        def _num_qubits_(self):
            return 2

        def _decompose_(self, qubits):
            q0, q1 = qubits
            yield cirq.Z(q0)
            yield cirq.S(q1)
            yield cirq.X(q0)

        def __eq__(self, other):
            return isinstance(other, TestGate)

        def __repr__(self):
            return 'TestGate()'

    q0, q1 = cirq.LineQubit.range(2)

    with pytest.raises(ValueError, match='number of qubits'):
        TestGate().on(q0)

    forward = TestGate().on(q0, q1)
    backward = forward ** -1
    assert backward ** -1 == forward
    assert forward ** -1 == backward
    # The inverse decomposes into the inverted operations in reverse order.
    assert cirq.decompose(backward) == [cirq.X(q0), cirq.S(q1) ** -1, cirq.Z(q0)]
    cirq.testing.assert_allclose_up_to_global_phase(
        cirq.unitary(backward), cirq.unitary(forward).conj().T, atol=1e-8
    )

    cirq.testing.assert_implements_consistent_protocols(
        backward, local_vals={'TestGate': TestGate}
    )
184
185
def test_default_inverse():
    """A gate that decomposes into invertible gates gets an inverse by default."""

    class TestGate(cirq.Gate):
        def _num_qubits_(self):
            return 3

        def _decompose_(self, qubits):
            return (cirq.X ** 0.1).on_each(*qubits)

    gate = TestGate()
    assert cirq.inverse(gate, None) is not None
    cirq.testing.assert_has_consistent_qid_shape(cirq.inverse(gate))

    op = gate.on(*cirq.LineQubit.range(3))
    cirq.testing.assert_has_consistent_qid_shape(cirq.inverse(op))
199
200
def test_no_inverse_if_not_unitary():
    """No default inverse exists when the decomposition is a non-unitary channel."""

    class TestGate(cirq.Gate):
        def _num_qubits_(self):
            return 1

        def _decompose_(self, qubits):
            damping = cirq.amplitude_damp(0.5)
            return damping.on(qubits[0])

    inverse = cirq.inverse(TestGate(), None)
    assert inverse is None
210
211
def test_default_qudit_inverse():
    """The default inverse keeps the original mixed-dimension qid shape."""

    class TestGate(cirq.Gate):
        def _qid_shape_(self):
            return (1, 2, 3)

        def _decompose_(self, qubits):
            target = qubits[1]  # the only qubit (dimension 2) in the shape
            return (cirq.X ** 0.1).on(target)

    inverse = cirq.inverse(TestGate(), None)
    assert cirq.qid_shape(inverse) == (1, 2, 3)
    cirq.testing.assert_has_consistent_qid_shape(cirq.inverse(TestGate()))
222
223
@pytest.mark.parametrize(
    'expression, expected_result',
    [
        (cirq.X * 2, 2 * cirq.X),
        (cirq.Y * 2, cirq.Y + cirq.Y),
        (cirq.Z - cirq.Z + cirq.Z, cirq.Z.wrap_in_linear_combination()),
        (1j * cirq.S * 1j, -cirq.S),
        (cirq.CZ * 1, cirq.CZ / 1),
        (-cirq.CSWAP * 1j, cirq.CSWAP / 1j),
        (cirq.TOFFOLI * 0.5, cirq.TOFFOLI / 2),
    ],
)
def test_gate_algebra(expression, expected_result):
    """Scalar arithmetic on gates matches the equivalent linear combination."""
    assert expression == expected_result
238
239
def test_gate_shape():
    """Gate shape may come from `_qid_shape_`, `_num_qubits_` or `num_qubits`."""

    class ShapeGate(cirq.Gate):
        def _qid_shape_(self):
            return (1, 2, 3, 4)

    class QubitGate(cirq.Gate):
        def _num_qubits_(self):
            return 3

    class DeprecatedGate(cirq.Gate):
        def num_qubits(self):
            return 3

    cases = [
        (ShapeGate(), (1, 2, 3, 4)),
        (QubitGate(), (2, 2, 2)),
        (DeprecatedGate(), (2, 2, 2)),
    ]
    for gate, expected_shape in cases:
        expected_count = len(expected_shape)
        assert cirq.qid_shape(gate) == expected_shape
        assert cirq.num_qubits(gate) == expected_count
        assert gate.num_qubits() == expected_count
267
268
def test_gate_shape_protocol():
    """Checks the interplay of the `_num_qubits_` and `_qid_shape_` alternatives.

    Only needed while both methods are supported side by side; it can be
    removed once the deprecated `num_qubits` method is removed.
    """

    class NotImplementedGate1(cirq.Gate):
        def _num_qubits_(self):
            return NotImplemented

        def _qid_shape_(self):
            return NotImplemented

    class NotImplementedGate2(cirq.Gate):
        def _num_qubits_(self):
            return NotImplemented

    class NotImplementedGate3(cirq.Gate):
        def _qid_shape_(self):
            return NotImplemented

    class ShapeGate(cirq.Gate):
        def _num_qubits_(self):
            return NotImplemented

        def _qid_shape_(self):
            return (1, 2, 3)

    class QubitGate(cirq.Gate):
        def _num_qubits_(self):
            return 2

        def _qid_shape_(self):
            return NotImplemented

    # With no usable size information every size query fails.
    for gate_cls in (NotImplementedGate1, NotImplementedGate2, NotImplementedGate3):
        with pytest.raises(TypeError, match='returned NotImplemented'):
            cirq.qid_shape(gate_cls())
        with pytest.raises(TypeError, match='returned NotImplemented'):
            cirq.num_qubits(gate_cls())
        with pytest.raises(TypeError, match='returned NotImplemented'):
            _ = gate_cls().num_qubits()  # Deprecated

    # Whichever method gives a real answer is used for both queries.
    assert cirq.qid_shape(ShapeGate()) == (1, 2, 3)
    assert cirq.num_qubits(ShapeGate()) == 3
    assert ShapeGate().num_qubits() == 3  # Deprecated
    assert cirq.qid_shape(QubitGate()) == (2, 2)
    assert cirq.num_qubits(QubitGate()) == 2
    assert QubitGate().num_qubits() == 2  # Deprecated
327
328
def test_operation_shape():
    """Operation shape defaults to qubits but honors `_num_qubits_`/`_qid_shape_`."""

    class FixedQids(cirq.Operation):
        def with_qubits(self, *new_qids):
            raise NotImplementedError  # coverage: ignore

    class QubitOp(FixedQids):
        @property
        def qubits(self):
            return cirq.LineQubit.range(2)

    class NumQubitOp(FixedQids):
        @property
        def qubits(self):
            return cirq.LineQubit.range(3)

        def _num_qubits_(self):
            return 3

    class ShapeOp(FixedQids):
        @property
        def qubits(self):
            return cirq.LineQubit.range(4)

        def _qid_shape_(self):
            return (1, 2, 3, 4)

    cases = (
        (QubitOp(), (2, 2)),
        (NumQubitOp(), (2, 2, 2)),
        (ShapeOp(), (1, 2, 3, 4)),
    )
    for op, expected_shape in cases:
        assert len(op.qubits) == len(expected_shape)
        assert cirq.qid_shape(op) == expected_shape
        assert cirq.num_qubits(op) == len(expected_shape)
369
370
def test_gate_json_dict():
    """A gate without custom serialization emits only its `cirq_type`."""
    # CSWAP is chosen because eigen gates define their own `_json_dict_`.
    expected = {'cirq_type': 'CSwapGate'}
    assert cirq.CSWAP._json_dict_() == expected
376
377
def test_inverse_composite_diagram_info():
    """An inverted composite gate reuses the original's diagram info, if any."""

    class Gate(cirq.Gate):
        def _decompose_(self, qubits):
            return cirq.S.on(qubits[0])

        def num_qubits(self) -> int:
            return 1

    # No diagram info on the original gate means none on its inverse.
    inverse_plain = cirq.inverse(Gate())
    assert cirq.circuit_diagram_info(inverse_plain, default=None) is None

    class Gate2(cirq.Gate):
        def _decompose_(self, qubits):
            return cirq.S.on(qubits[0])

        def num_qubits(self) -> int:
            return 1

        def _circuit_diagram_info_(self, args):
            return 's!'

    # The inverse shows the original symbol with exponent -1.
    inverse_labeled = cirq.inverse(Gate2())
    expected_info = cirq.CircuitDiagramInfo(wire_symbols=('s!',), exponent=-1)
    assert cirq.circuit_diagram_info(inverse_labeled) == expected_info
403
404
def test_tagged_operation_equality():
    """Tagged operations are equal only when operation and tags both match."""
    eq = cirq.testing.EqualsTester()
    qubit = cirq.GridQubit(1, 1)
    x_op = cirq.X(qubit)
    y_op = cirq.Y(qubit)

    eq.add_equality_group(x_op)
    for op, tag in ((x_op, 'tag1'), (y_op, 'tag1'), (x_op, 'tag2')):
        eq.add_equality_group(op.with_tags(tag), cirq.TaggedOperation(op, tag))
    # Adding tags in one call or in chained calls gives the same operation.
    eq.add_equality_group(
        x_op.with_tags('tag1', 'tag2'),
        x_op.with_tags('tag1').with_tags('tag2'),
        cirq.TaggedOperation(x_op, 'tag1', 'tag2'),
    )
420
421
def test_tagged_operation():
    """A tagged operation exposes its tags plus the underlying op's data."""
    q1 = cirq.GridQubit(1, 1)
    q2 = cirq.GridQubit(2, 2)
    op = cirq.X(q1).with_tags('tag1')

    assert repr(op) == "cirq.TaggedOperation(cirq.X(cirq.GridQubit(1, 1)), 'tag1')"
    assert op.qubits == (q1,)
    assert op.tags == ('tag1',)
    assert op.gate == cirq.X

    # Moving to new qubits keeps the tags.
    moved = op.with_qubits(q2)
    assert moved == cirq.X(q2).with_tags('tag1')
    assert moved.qubits == (q2,)

    assert not cirq.is_measurement(op)
435
436
def test_with_tags_returns_same_instance_if_possible():
    """`with_tags()` with no tags returns the very same object."""
    plain = cirq.X(cirq.GridQubit(1, 1))
    tagged = plain.with_tags('foo')
    for op in (plain, tagged):
        assert op.with_tags() is op
443
444
def test_tagged_measurement():
    """Measurement-ness and key remapping survive tagging."""
    phase_op = cirq.GlobalPhaseOperation(coefficient=-1.0).with_tags('tag0')
    assert not cirq.is_measurement(phase_op)

    qubit = cirq.LineQubit(0)
    measure_op = cirq.measure(qubit, key='m').with_tags('tag')
    assert cirq.is_measurement(measure_op)

    remapped = cirq.with_measurement_key_mapping(measure_op, {'m': 'k'})
    assert remapped.tags == ('tag',)
    assert cirq.is_measurement(remapped)
    assert cirq.measurement_key_names(remapped) == {'k'}

    # A mapping that does not mention the key leaves the operation unchanged.
    assert cirq.with_measurement_key_mapping(measure_op, {'x': 'k'}) == measure_op
457
458
def test_cannot_remap_non_measurement_gate():
    """Key remapping is unsupported on a tagged non-measurement operation."""
    tagged_x = cirq.X(cirq.LineQubit(0)).with_tags('tag')
    result = cirq.with_measurement_key_mapping(tagged_x, {'m': 'k'})
    assert result is NotImplemented
464
465
def test_circuit_diagram():
    """Tags are shown in diagrams by default and can be suppressed.

    Covers `cirq.circuit_diagram_info` on a tagged operation and full circuit
    text diagrams, for both string and non-string tags.
    """

    class TaggyTag:
        """Tag with a custom repr function to test circuit diagrams."""

        def __repr__(self):
            return 'TaggyTag()'

    h = cirq.H(cirq.GridQubit(1, 1))
    tagged_h = h.with_tags('tag1')
    non_string_tag_h = h.with_tags(TaggyTag())

    # By default the tag list is appended to the wire symbol.
    expected = cirq.CircuitDiagramInfo(
        wire_symbols=("H['tag1']",),
        exponent=1.0,
        connected=True,
        exponent_qubit_index=None,
        auto_exponent_parens=True,
    )
    # NOTE(review): the final positional argument appears to disable tag
    # display (the tagged info then equals the untagged info) — presumably
    # `include_tags`; confirm against CircuitDiagramInfoArgs.
    args = cirq.CircuitDiagramInfoArgs(None, None, None, None, None, False)
    assert cirq.circuit_diagram_info(tagged_h) == expected
    assert cirq.circuit_diagram_info(tagged_h, args) == cirq.circuit_diagram_info(h)

    c = cirq.Circuit(tagged_h)
    diagram_with_tags = "(1, 1): ───H['tag1']───"
    diagram_without_tags = "(1, 1): ───H───"
    assert str(cirq.Circuit(tagged_h)) == diagram_with_tags
    assert c.to_text_diagram() == diagram_with_tags
    assert c.to_text_diagram(include_tags=False) == diagram_without_tags

    # Non-string tags are rendered using their repr.
    c = cirq.Circuit(non_string_tag_h)
    diagram_with_non_string_tag = "(1, 1): ───H[TaggyTag()]───"
    assert c.to_text_diagram() == diagram_with_non_string_tag
    assert c.to_text_diagram(include_tags=False) == diagram_without_tags
499
500
def test_circuit_diagram_tagged_global_phase():
    """Tags on global phase operations appear on the 'global phase' line."""
    # Tests global phase operation
    q = cirq.NamedQubit('a')
    global_phase = cirq.GlobalPhaseOperation(coefficient=-1.0).with_tags('tag0')

    # Just global phase in a circuit. The operation itself yields no diagram
    # info (the default is returned); the circuit still renders the phase.
    assert cirq.circuit_diagram_info(global_phase, default='default') == 'default'
    cirq.testing.assert_has_diagram(
        cirq.Circuit(global_phase), "\n\nglobal phase:   π['tag0']", use_unicode_characters=True
    )
    cirq.testing.assert_has_diagram(
        cirq.Circuit(global_phase),
        "\n\nglobal phase:   π",
        use_unicode_characters=True,
        include_tags=False,
    )

    expected = cirq.CircuitDiagramInfo(
        wire_symbols=(),
        exponent=1.0,
        connected=True,
        exponent_qubit_index=None,
        auto_exponent_parens=True,
    )

    # Operation with no qubits and returns diagram info with no wire symbols
    class NoWireSymbols(cirq.GlobalPhaseOperation):
        def _circuit_diagram_info_(
            self, args: 'cirq.CircuitDiagramInfoArgs'
        ) -> 'cirq.CircuitDiagramInfo':
            return expected

    no_wire_symbol_op = NoWireSymbols(coefficient=-1.0).with_tags('tag0')
    assert cirq.circuit_diagram_info(no_wire_symbol_op, default='default') == expected
    cirq.testing.assert_has_diagram(
        cirq.Circuit(no_wire_symbol_op),
        "\n\nglobal phase:   π['tag0']",
        use_unicode_characters=True,
    )

    # Two global phases in one moment: 1j * 1j = -1, drawn as π, with the
    # tags of both operations listed together.
    tag1 = cirq.GlobalPhaseOperation(coefficient=1j).with_tags('tag1')
    tag2 = cirq.GlobalPhaseOperation(coefficient=1j).with_tags('tag2')
    c = cirq.Circuit([cirq.X(q), tag1, tag2])
    cirq.testing.assert_has_diagram(
        c,
        """\
a: ─────────────X───────────────────

global phase:   π['tag1', 'tag2']""",
        use_unicode_characters=True,
        precision=2,
    )

    # Two moments with global phase, one with another tagged gate. Each
    # moment's phase (1j, drawn as 0.5π) keeps its own tag.
    c = cirq.Circuit([cirq.X(q).with_tags('x_tag'), tag1])
    c.append(cirq.Moment([cirq.X(q), tag2]))
    cirq.testing.assert_has_diagram(
        c,
        """\
a: ─────────────X['x_tag']─────X──────────────

global phase:   0.5π['tag1']   0.5π['tag2']
""",
        use_unicode_characters=True,
        include_tags=True,
    )
568
569
def test_circuit_diagram_no_circuit_diagram():
    """Gates without diagram info are drawn using their repr, tags included."""

    class NoCircuitDiagram(cirq.Gate):
        def num_qubits(self) -> int:
            return 1

        def __repr__(self):
            return 'guess-i-will-repr'

    op = NoCircuitDiagram()(cirq.GridQubit(1, 1))

    plain_diagram = cirq.Circuit(op).to_text_diagram()
    assert plain_diagram == "(1, 1): ───guess-i-will-repr───"

    tagged_diagram = cirq.Circuit(op.with_tags('taggy')).to_text_diagram()
    assert tagged_diagram == "(1, 1): ───guess-i-will-repr['taggy']───"
583
584
def test_tagged_operation_forwards_protocols():
    """The results of all protocols applied to an operation with a tag should
    be equivalent to the result without tags.
    """
    q1 = cirq.GridQubit(1, 1)
    q2 = cirq.GridQubit(1, 2)
    h = cirq.H(q1)
    tag = 'tag1'
    tagged_h = cirq.H(q1).with_tags(tag)

    # Unitary, decomposition, Pauli expansion and Kraus protocols.
    np.testing.assert_equal(cirq.unitary(tagged_h), cirq.unitary(h))
    assert cirq.has_unitary(tagged_h)
    assert cirq.decompose(tagged_h) == cirq.decompose(h)
    assert cirq.pauli_expansion(tagged_h) == cirq.pauli_expansion(h)
    assert cirq.equal_up_to_global_phase(h, tagged_h)
    assert np.isclose(cirq.kraus(h), cirq.kraus(tagged_h)).all()

    # Measurement-key protocols.
    assert cirq.measurement_key_name(cirq.measure(q1, key='blah').with_tags(tag)) == 'blah'
    assert cirq.measurement_key_obj(
        cirq.measure(q1, key='blah').with_tags(tag)
    ) == cirq.MeasurementKey('blah')

    # Parameter protocols: the resolved operation keeps its tag.
    parameterized_op = cirq.XPowGate(exponent=sympy.Symbol('t'))(q1).with_tags(tag)
    assert cirq.is_parameterized(parameterized_op)
    resolver = cirq.study.ParamResolver({'t': 0.25})
    assert cirq.resolve_parameters(parameterized_op, resolver) == cirq.XPowGate(exponent=0.25)(
        q1
    ).with_tags(tag)
    assert cirq.resolve_parameters_once(parameterized_op, resolver) == cirq.XPowGate(exponent=0.25)(
        q1
    ).with_tags(tag)

    # Arithmetic and phasing. `**` yields an operation equal to the untagged
    # YPowGate operation, i.e. the tag is not carried over.
    y = cirq.Y(q1)
    tagged_y = cirq.Y(q1).with_tags(tag)
    assert tagged_y ** 0.5 == cirq.YPowGate(exponent=0.5)(q1)
    assert tagged_y * 2 == (y * 2)
    assert 3 * tagged_y == (3 * y)
    assert cirq.phase_by(y, 0.125, 0) == cirq.phase_by(tagged_y, 0.125, 0)
    # Adding controls produces a plain (untagged) controlled operation.
    controlled_y = tagged_y.controlled_by(q2)
    assert controlled_y.qubits == (
        q2,
        q1,
    )
    assert isinstance(controlled_y, cirq.Operation)
    assert not isinstance(controlled_y, cirq.TaggedOperation)

    # Commutation works with the tag on either or both sides.
    clifford_x = cirq.SingleQubitCliffordGate.X(q1)
    tagged_x = cirq.SingleQubitCliffordGate.X(q1).with_tags(tag)
    assert cirq.commutes(clifford_x, clifford_x)
    assert cirq.commutes(tagged_x, clifford_x)
    assert cirq.commutes(clifford_x, tagged_x)
    assert cirq.commutes(tagged_x, tagged_x)

    assert cirq.trace_distance_bound(y ** 0.001) == cirq.trace_distance_bound(
        (y ** 0.001).with_tags(tag)
    )

    # Mixture / Kraus protocols on a noisy channel.
    flip = cirq.bit_flip(0.5)(q1)
    tagged_flip = cirq.bit_flip(0.5)(q1).with_tags(tag)
    assert cirq.has_mixture(tagged_flip)
    assert cirq.has_kraus(tagged_flip)

    # Compare the (probability, matrix) pairs of the mixture element by element.
    flip_mixture = cirq.mixture(flip)
    tagged_mixture = cirq.mixture(tagged_flip)
    assert len(tagged_mixture) == 2
    assert len(tagged_mixture[0]) == 2
    assert len(tagged_mixture[1]) == 2
    assert tagged_mixture[0][0] == flip_mixture[0][0]
    assert np.isclose(tagged_mixture[0][1], flip_mixture[0][1]).all()
    assert tagged_mixture[1][0] == flip_mixture[1][0]
    assert np.isclose(tagged_mixture[1][1], flip_mixture[1][1]).all()

    # QASM output is unaffected by tags.
    qubit_map = {q1: 'q1'}
    qasm_args = cirq.QasmArgs(qubit_id_map=qubit_map)
    assert cirq.qasm(h, args=qasm_args) == cirq.qasm(tagged_h, args=qasm_args)

    cirq.testing.assert_has_consistent_apply_unitary(tagged_h)
662
663
class ParameterizableTag:
    """Operation tag wrapping a possibly-symbolic value.

    Implements the parameter protocols (`_is_parameterized_`,
    `_parameter_names_`, `_resolve_parameters_`) by delegating to the wrapped
    value, so tests can check that tagged operations forward them to tags.
    """

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        # Compare by wrapped value. For foreign types (e.g. a plain string tag)
        # defer to Python instead of failing with AttributeError on `.value`.
        if not isinstance(other, ParameterizableTag):
            return NotImplemented
        return self.value == other.value

    def _is_parameterized_(self) -> bool:
        return cirq.is_parameterized(self.value)

    def _parameter_names_(self) -> AbstractSet[str]:
        return cirq.parameter_names(self.value)

    def _resolve_parameters_(
        self, resolver: 'cirq.ParamResolver', recursive: bool
    ) -> 'ParameterizableTag':
        # Returns a new tag; the original is left untouched.
        return ParameterizableTag(cirq.resolve_parameters(self.value, resolver, recursive))
681
682
@pytest.mark.parametrize('resolve_fn', [cirq.resolve_parameters, cirq.resolve_parameters_once])
def test_tagged_operation_resolves_parameterized_tags(resolve_fn):
    """Parameters carried by tags are reported and resolved with the operation."""
    qubit = cirq.GridQubit(0, 0)
    symbolic_tag = ParameterizableTag(sympy.Symbol('t'))
    assert cirq.is_parameterized(symbolic_tag)
    assert cirq.parameter_names(symbolic_tag) == {'t'}

    op = cirq.Z(qubit).with_tags(symbolic_tag)
    assert cirq.is_parameterized(op)
    assert cirq.parameter_names(op) == {'t'}

    resolved = resolve_fn(op, {'t': 10})
    assert resolved == cirq.Z(qubit).with_tags(ParameterizableTag(10))
    assert not cirq.is_parameterized(resolved)
    assert cirq.parameter_names(resolved) == set()
696
697
def test_inverse_composite_standards():
    """The default inverse of a composite gate passes the standard protocol checks."""

    @cirq.value_equality
    class Gate(cirq.Gate):
        def __repr__(self):
            return 'C()'

        def num_qubits(self) -> int:
            return 1

        def _value_equality_values_(self):
            return ()

        def _has_unitary_(self):
            return True

        def _decompose_(self, qubits):
            return cirq.S.on(qubits[0])

    inverse_gate = cirq.inverse(Gate())
    # The repr is 'C()', so `C` must resolve to this class when it is evaluated.
    cirq.testing.assert_implements_consistent_protocols(inverse_gate, global_vals={'C': Gate})
719
720
def test_tagged_act_on():
    """`cirq.act_on` works on a tagged op exactly when the untagged op supports it."""
    from cirq.protocols.act_on_protocol_test import DummyActOnArgs

    class YesActOn(cirq.Gate):
        def _num_qubits_(self) -> int:
            return 1

        def _act_on_(self, args, qubits):
            return True

    class NoActOn(cirq.Gate):
        def _num_qubits_(self) -> int:
            return 1

        def _act_on_(self, args, qubits):
            return NotImplemented

    class MissingActOn(cirq.Operation):
        def with_qubits(self, *new_qubits):
            raise NotImplementedError()

        @property
        def qubits(self):
            pass

    qubit = cirq.LineQubit(1)
    args = DummyActOnArgs()

    cirq.act_on(YesActOn()(qubit).with_tags("test"), args)
    # Declining (NotImplemented) and missing `_act_on_` both fail the same way.
    for unsupported in (NoActOn()(qubit), MissingActOn()):
        with pytest.raises(TypeError, match="Failed to act"):
            cirq.act_on(unsupported.with_tags("test"), args)
753
754
def test_single_qubit_gate_validates_on_each():
    """`SingleQubitGate.on_each` accepts qubits given either as varargs or as a
    single iterable, and rejects non-qubit targets in both forms."""

    class Dummy(cirq.SingleQubitGate):
        def matrix(self):
            pass

    g = Dummy()
    assert g.num_qubits() == 1

    test_qubits = [cirq.NamedQubit(str(i)) for i in range(3)]

    _ = g.on_each(*test_qubits)
    _ = g.on_each(test_qubits)

    test_non_qubits = [str(i) for i in range(3)]
    with pytest.raises(ValueError):
        _ = g.on_each(*test_non_qubits)
    # Mirror the single-iterable call above: the list form must also be
    # rejected. (This previously repeated the varargs call verbatim, leaving
    # the iterable path untested.)
    with pytest.raises(ValueError):
        _ = g.on_each(test_non_qubits)
773
774
def test_on_each():
    """`on_each` on a single-qubit gate accepts qubits, lists, nested lists and
    iterators, and rejects non-qubit targets."""

    class CustomGate(cirq.SingleQubitGate):
        pass

    a = cirq.NamedQubit('a')
    b = cirq.NamedQubit('b')
    c = CustomGate()

    for targets, expected in [
        ((), []),
        ((a,), [c(a)]),
        ((a, b), [c(a), c(b)]),
        ((b, a), [c(b), c(a)]),
    ]:
        # Targets may be passed as separate arguments or as one list.
        assert c.on_each(*targets) == expected
        assert c.on_each(list(targets)) == expected

    # Nested iterables are flattened in order.
    assert c.on_each([a, [b, a], b]) == [c(a), c(b), c(a), c(b)]

    for bad_target in ('abcd', ['abcd'], [a, 'abcd']):
        with pytest.raises(ValueError):
            c.on_each(bad_target)

    qubit_iterator = (q for q in [a, b, a, b])
    assert isinstance(qubit_iterator, Iterator)
    assert c.on_each(qubit_iterator) == [c(a), c(b), c(a), c(b)]
804
805
def test_on_each_two_qubits():
    """`on_each` on a two-qubit gate takes qubit pairs and validates each pair."""
    a = cirq.NamedQubit('a')
    b = cirq.NamedQubit('b')
    g = cirq.testing.TwoQubitGate()

    # Pairs supplied inside a single iterable.
    assert g.on_each([]) == []
    assert g.on_each([(a, b)]) == [g(a, b)]
    assert g.on_each([[a, b]]) == [g(a, b)]
    assert g.on_each([(b, a)]) == [g(b, a)]
    assert g.on_each([(a, b), (b, a)]) == [g(a, b), g(b, a)]
    assert g.on_each(zip([a, b], [b, a])) == [g(a, b), g(b, a)]
    # Pairs supplied as separate positional arguments.
    assert g.on_each() == []
    assert g.on_each((b, a)) == [g(b, a)]
    assert g.on_each((a, b), (a, b)) == [g(a, b), g(a, b)]
    assert g.on_each(*zip([a, b], [b, a])) == [g(a, b), g(b, a)]

    not_sequence = 'Inputs to multi-qubit gates must be Sequence'
    not_qids = 'All values in sequence should be Qids'
    wrong_count = 'Expected 2 qubits'
    # Each entry: (positional args to on_each, expected error, message regex).
    bad_calls = [
        ((a,), TypeError, 'object is not iterable'),
        ((a, b), ValueError, not_sequence),
        (([12],), ValueError, not_sequence),
        (([(a, b), 12],), ValueError, not_sequence),
        (([(a, b), [(a, b)]],), ValueError, not_qids),
        (([()],), ValueError, wrong_count),
        (([(a,)],), ValueError, wrong_count),
        (([(a, b, a)],), ValueError, wrong_count),
        ((zip([a, a]),), ValueError, wrong_count),
        ((zip([a, a], [b, b], [a, a]),), ValueError, wrong_count),
        (('ab',), ValueError, not_qids),
        ((('ab',),), ValueError, not_qids),
        (([('ab',)],), ValueError, not_qids),
        (([(a, 'ab')],), ValueError, not_qids),
        (([(a, 'b')],), ValueError, not_qids),
    ]
    for call_args, error_type, message in bad_calls:
        with pytest.raises(error_type, match=message):
            g.on_each(*call_args)

    pair_iterator = (pair for pair in [[a, b], [a, b]])
    assert isinstance(pair_iterator, Iterator)
    assert g.on_each(pair_iterator) == [g(a, b), g(a, b)]
855
856
def test_on_each_three_qubits():
    """`on_each` on a three-qubit gate takes qubit triples and validates each one."""
    a = cirq.NamedQubit('a')
    b = cirq.NamedQubit('b')
    c = cirq.NamedQubit('c')
    g = cirq.testing.ThreeQubitGate()

    # Triples supplied inside a single iterable.
    assert g.on_each([]) == []
    assert g.on_each([(a, b, c)]) == [g(a, b, c)]
    assert g.on_each([[a, b, c]]) == [g(a, b, c)]
    assert g.on_each([(c, b, a)]) == [g(c, b, a)]
    assert g.on_each([(a, b, c), (c, b, a)]) == [g(a, b, c), g(c, b, a)]
    assert g.on_each(zip([a, c], [b, b], [c, a])) == [g(a, b, c), g(c, b, a)]
    # Triples supplied as separate positional arguments.
    assert g.on_each() == []
    assert g.on_each((c, b, a)) == [g(c, b, a)]
    assert g.on_each((a, b, c), (c, b, a)) == [g(a, b, c), g(c, b, a)]
    assert g.on_each(*zip([a, c], [b, b], [c, a])) == [g(a, b, c), g(c, b, a)]

    not_sequence = 'Inputs to multi-qubit gates must be Sequence'
    not_qids = 'All values in sequence should be Qids'
    wrong_count = 'Expected 3 qubits'
    # Each entry: (positional args to on_each, expected error, message regex).
    bad_calls = [
        ((a,), TypeError, 'object is not iterable'),
        ((a, b, c), ValueError, not_sequence),
        (([12],), ValueError, not_sequence),
        (([(a, b, c), 12],), ValueError, not_sequence),
        (([(a, b, c), [(a, b, c)]],), ValueError, not_qids),
        (([(a,)],), ValueError, wrong_count),
        (([(a, b)],), ValueError, wrong_count),
        (([(a, b, c, a)],), ValueError, wrong_count),
        ((zip([a, a], [b, b]),), ValueError, wrong_count),
        (('abc',), ValueError, not_qids),
        ((('abc',),), ValueError, not_qids),
        (([('abc',)],), ValueError, not_qids),
        (([(a, 'abc')],), ValueError, not_qids),
        (([(a, 'bc')],), ValueError, not_qids),
    ]
    for call_args, error_type, message in bad_calls:
        with pytest.raises(error_type, match=message):
            g.on_each(*call_args)

    triple_iterator = (triple for triple in [[a, b, c], [a, b, c]])
    assert isinstance(triple_iterator, Iterator)
    assert g.on_each(triple_iterator) == [g(a, b, c), g(a, b, c)]
905
906
def test_on_each_iterable_qid():
    """A Qid that happens to be iterable is still treated as one target."""

    class QidIter(cirq.Qid):
        @property
        def dimension(self) -> int:
            return 2

        def _comparison_key(self) -> Any:
            return 1

        def __iter__(self):
            # Raising here makes the test fail if on_each tries to iterate the qid.
            raise NotImplementedError()

    expected = cirq.H.on(QidIter())
    assert cirq.H.on_each(QidIter())[0] == expected
920