1# Copyright 2018 The ops Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from typing import Tuple, cast
16
17from cirq import ops, circuits
18from cirq.contrib.paulistring.convert_gate_set import converted_gate_set
19
20
def clifford_optimized_circuit(circuit: circuits.Circuit, atol: float = 1e-8) -> circuits.Circuit:
    """Optimizes a circuit by moving and merging single-qubit Clifford gates.

    The circuit is first converted (via ``converted_gate_set`` with
    ``no_clifford_gates=False``) to a gate set of SingleQubitCliffordGates,
    CZs and other gates that are left alone. The flattened operation list is
    then scanned once:

    * Each single-qubit Clifford is decomposed into Pauli rotations. Each
      rotation is pushed forward through later operations and then does one of
      the following:
        - merges into a later single-qubit Clifford;
        - passes over a CZ (for half turns) or over a PauliStringPhasor;
        - is re-inserted further along the circuit;
        - stays behind if it cannot be moved.
    * Each CZ (exponent 1) is cancelled against an identical earlier CZ, as
      long as every operation in between that touches either qubit is itself
      a CZ with exponent 1.

    Args:
        circuit: The circuit to optimize.
        atol: Absolute tolerance forwarded to ``converted_gate_set``.

    Returns:
        A new circuit built from the optimized operation list using
        ``InsertStrategy.EARLIEST``.

    Raises:
        NotImplementedError: If a moved Pauli string ends up with a
            coefficient other than +1 or -1.
    """
    # Convert to a circuit with SingleQubitCliffordGates,
    # CZs and other ignored gates
    c_cliff = converted_gate_set(circuit, no_clifford_gates=False, atol=atol)

    # Flat, mutable list of operations; the helpers below edit it in place.
    all_ops = list(c_cliff.all_operations())

    def find_merge_point(
        start_i: int,
        string_op: ops.PauliStringPhasor,
        stop_at_cz: bool,
    ) -> Tuple[int, ops.PauliStringPhasor, int]:
        """Scans forward from ``start_i`` for where ``string_op`` can be merged.

        Args:
            start_i: Index in ``all_ops`` of the operation being moved; the
                scan starts at ``start_i + 1``.
            string_op: The rotation being pushed forward.
            stop_at_cz: If True, stop at the first CZPowGate on the rotation's
                qubits instead of passing over it.

        Returns:
            A tuple ``(merge_i, merge_op, num_passed_over)`` with these parts:
                - ``merge_i`` is the index at which the rotation should be
                  merged or inserted.
                - ``merge_op`` is ``string_op`` as it looks after being passed
                  over the operations before ``merge_i``. It is only recorded
                  while its Pauli string has length 1.
                - ``num_passed_over`` counts every operation that was passed
                  over or skipped during the scan.
        """
        STOP = 0
        CONTINUE = 1
        SKIP = 2

        def continue_condition(
            op: ops.Operation, current_string: ops.PauliStringPhasor
        ) -> int:
            """Classifies how the scan should treat ``op``.

            Returns CONTINUE to pass the rotation over ``op``, SKIP to step
            past ``op`` without changing the rotation, or STOP to end the scan.
            """
            if isinstance(op.gate, ops.SingleQubitCliffordGate):
                # A single-qubit rotation stops at a Clifford so it can be
                # merged into it; a multi-qubit string keeps moving.
                return CONTINUE if len(current_string.pauli_string) != 1 else STOP
            if isinstance(op.gate, ops.CZPowGate):
                return STOP if stop_at_cz else CONTINUE
            if (
                isinstance(op, ops.PauliStringPhasor)
                and len(op.qubits) == 1
                and (op.pauli_string[op.qubits[0]] == current_string.pauli_string[op.qubits[0]])
            ):
                # A single-qubit rotation about the same Pauli on that qubit
                # commutes with the current string, so it is stepped over.
                return SKIP
            return STOP

        modified_op = string_op
        furthest_op = string_op
        furthest_i = start_i + 1
        num_passed_over = 0
        for i in range(start_i + 1, len(all_ops)):
            op = all_ops[i]
            if not set(op.qubits) & set(modified_op.qubits):
                # No qubits in common
                continue
            cont_cond = continue_condition(op, modified_op)
            if cont_cond == STOP:
                if len(modified_op.pauli_string) == 1:
                    # Merge with (or insert just before) the blocking op.
                    furthest_op = modified_op
                    furthest_i = i
                break
            if cont_cond == CONTINUE:
                modified_op = modified_op.pass_operations_over([op], after_to_before=True)
            num_passed_over += 1
            if len(modified_op.pauli_string) == 1:
                # Remember the last position at which the rotation is still
                # single-qubit; that is the furthest point it can be merged.
                furthest_op = modified_op
                furthest_i = i + 1

        return furthest_i, furthest_op, num_passed_over

    def try_merge_clifford(cliff_op: ops.GateOperation, start_i: int) -> bool:
        """Pushes the rotations of the Clifford at ``start_i`` forward.

        Returns:
            True if every rotation was moved. In that case the Clifford is
            popped from ``all_ops``. Otherwise ``all_ops[start_i]`` is replaced
            by the part that could not be moved, and False is returned.

        Raises:
            NotImplementedError: If a moved Pauli string has a coefficient
                other than +1 or -1.
        """
        (orig_qubit,) = cliff_op.qubits
        # Accumulates the rotations that could not be moved forward.
        remaining_cliff_gate = ops.SingleQubitCliffordGate.I
        # NOTE(review): rotations are processed in reverse decomposition order,
        # presumably so the one applied last is pushed first — confirm against
        # SingleQubitCliffordGate.decompose_rotation.
        for pauli, quarter_turns in reversed(
            cast(ops.SingleQubitCliffordGate, cliff_op.gate).decompose_rotation()
        ):
            # Re-express this rotation as seen through the rotations that
            # stayed behind (the Pauli axis may change and its sign may flip).
            trans = remaining_cliff_gate.transform(pauli)
            pauli = trans.to
            quarter_turns *= -1 if trans.flip else 1
            string_op = ops.PauliStringPhasor(
                ops.PauliString(pauli(cliff_op.qubits[0])), exponent_neg=quarter_turns / 2
            )

            # Half turns (quarter_turns == 2) stop at a CZ so they can be
            # handled by the CZ branch below.
            merge_i, merge_op, num_passed = find_merge_point(start_i, string_op, quarter_turns == 2)
            assert merge_i > start_i
            assert len(merge_op.pauli_string) == 1, 'PauliString length != 1'

            qubit, pauli = next(iter(merge_op.pauli_string.items()))
            quarter_turns = round(merge_op.exponent_relative * 2)
            if merge_op.pauli_string.coefficient not in [1, -1]:
                # TODO: Add support for more general phases.
                # Github issue: https://github.com/quantumlib/Cirq/issues/2962
                # Legacy coverage ignore, we need test code that hits this.
                # coverage: ignore
                raise NotImplementedError(
                    'Only +1/-1 pauli string coefficients currently supported'
                )
            # Fold a -1 coefficient into the rotation direction.
            quarter_turns *= int(merge_op.pauli_string.coefficient.real)
            quarter_turns %= 4
            part_cliff_gate = ops.SingleQubitCliffordGate.from_quarter_turns(pauli, quarter_turns)

            other_op = all_ops[merge_i] if merge_i < len(all_ops) else None
            if other_op is not None and qubit not in set(other_op.qubits):
                other_op = None

            if isinstance(other_op, ops.GateOperation) and isinstance(
                other_op.gate, ops.SingleQubitCliffordGate
            ):
                # Merge with another SingleQubitCliffordGate
                new_op = part_cliff_gate.merged_with(other_op.gate)(qubit)
                all_ops[merge_i] = new_op
            elif (
                isinstance(other_op, ops.GateOperation)
                and isinstance(other_op.gate, ops.CZPowGate)
                and other_op.gate.exponent == 1
                and quarter_turns == 2
            ):
                # Pass whole Pauli gate over CZ, possibly adding a Z gate
                if pauli != ops.pauli_gates.Z:
                    # For a two-qubit CZ, index - 1 selects the other qubit
                    # (0 -> -1, i.e. the last one; 1 -> 0).
                    other_qubit = other_op.qubits[other_op.qubits.index(qubit) - 1]
                    all_ops.insert(merge_i + 1, ops.SingleQubitCliffordGate.Z(other_qubit))
                all_ops.insert(merge_i + 1, part_cliff_gate(qubit))
            elif isinstance(other_op, ops.PauliStringPhasor):
                # Pass over a non-Clifford gate
                mod_op = other_op.pass_operations_over([part_cliff_gate(qubit)])
                all_ops[merge_i] = mod_op
                all_ops.insert(merge_i + 1, part_cliff_gate(qubit))
            elif merge_i > start_i + 1 and num_passed > 0:
                # Moved Clifford through the circuit but nothing to merge
                all_ops.insert(merge_i, part_cliff_gate(qubit))
            else:
                # Couldn't move Clifford
                remaining_cliff_gate = remaining_cliff_gate.merged_with(part_cliff_gate)

        if remaining_cliff_gate == ops.SingleQubitCliffordGate.I:
            all_ops.pop(start_i)
            return True
        all_ops[start_i] = remaining_cliff_gate(orig_qubit)
        return False

    def try_merge_cz(cz_op: ops.GateOperation, start_i: int) -> int:
        """Cancels ``cz_op`` against an identical earlier CZ, if possible.

        Scans backward from ``start_i``. Operations on other qubits are
        ignored, and other CZs (exponent 1) are passed through. Any other
        operation touching the CZ's qubits ends the search.

        Returns the number of operations removed at or before start_i.
        """
        for i in reversed(range(start_i)):
            op = all_ops[i]
            if not set(cz_op.qubits) & set(op.qubits):
                # Don't share qubits
                # Keep looking
                continue
            elif not (
                isinstance(op, ops.GateOperation)
                and isinstance(op.gate, ops.CZPowGate)
                and op.gate.exponent == 1
            ):
                # Not a CZ gate
                return 0
            elif cz_op == op:
                # Cancel two CZ gates
                all_ops.pop(start_i)
                all_ops.pop(i)
                return 2
            else:
                # Two CZ gates that share one qubit
                # Pass through and keep looking
                continue  # coverage: ignore
                # The above line is covered by test_remove_staggered_czs but the
                # coverage checker disagrees.
        return 0

    i = 0
    while i < len(all_ops):
        op = all_ops[i]
        if isinstance(op, ops.GateOperation) and isinstance(op.gate, ops.SingleQubitCliffordGate):
            if try_merge_clifford(op, i):
                # The Clifford at i was removed, so the next op now sits at i.
                i -= 1
        elif (
            isinstance(op, ops.GateOperation)
            and isinstance(op.gate, ops.CZPowGate)
            and op.gate.exponent == 1
        ):
            num_rm = try_merge_cz(op, i)
            # Step back past the removed ops so scanning resumes at an
            # earlier, still-valid index.
            i -= num_rm
        i += 1

    return circuits.Circuit(all_ops, strategy=circuits.InsertStrategy.EARLIEST)
190