1# Copyright 2018 The ops Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from typing import Any, Callable, Iterable, Sequence, Tuple, Union, cast, List
16
17from cirq import circuits, ops, protocols
18
19from cirq.contrib.paulistring.pauli_string_dag import (
20    pauli_string_reorder_pred,
21    pauli_string_dag_from_circuit,
22)
23
24
def _sorted_best_string_placements(
    possible_nodes: Iterable[Any],
    output_ops: Sequence[ops.Operation],
    key: Callable[[Any], ops.PauliStringPhasor] = lambda node: node.val,
) -> List[Tuple[ops.PauliStringPhasor, int, circuits.Unique[ops.PauliStringPhasor]]]:
    """Find, for every candidate node, the best place to insert its Pauli string.

    Each node's Pauli string phasor is pushed forward through ``output_ops``,
    starting in front of ``output_ops[0]``:

    * operations that share no qubit with the string are skipped;
    * ``PauliStringPhasor`` operations whose Pauli string commutes with it
      are skipped;
    * ``GateOperation``s whose gate is a ``SingleQubitCliffordGate``,
      ``PauliInteractionGate`` or ``CZPowGate`` are passed over, which
      conjugates (and may lengthen or shorten) the string;
    * any other operation sharing a qubit (e.g. a measurement) stops it.

    Candidate insertion indices are 0 and the index right after each
    conjugating operation. Among a node's candidates, the one whose Pauli
    string acts on the fewest qubits wins; ties go to the largest index.

    Args:
        possible_nodes: The candidate nodes.
        output_ops: The operations the strings are moved through.
        key: Extracts the ``PauliStringPhasor`` from a node
            (``node.val`` by default).

    Returns:
        One ``(phasor, insertion_index, node)`` tuple per node, where
        ``phasor`` is the string as conjugated by the Clifford operations in
        ``output_ops[:insertion_index]``. The list is ordered best first by
        the same criterion: fewest qubits, then largest index.
    """

    def placement_rank(placement):
        # Larger rank is better: shorter strings first, then further along.
        phasor, index, _ = placement
        return -len(phasor.pauli_string), index

    movable_gate_types = (ops.SingleQubitCliffordGate, ops.PauliInteractionGate, ops.CZPowGate)

    def best_placement(node):
        phasor = key(node)
        best = (phasor, 0, node)
        for position, op in enumerate(output_ops):
            shares_qubits = bool(set(op.qubits) & set(phasor.qubits))
            if not shares_qubits:
                continue
            commutes_with_string = isinstance(
                op, ops.PauliStringPhasor
            ) and protocols.commutes(op.pauli_string, phasor.pauli_string)
            if commutes_with_string:
                continue
            can_pass_over = isinstance(op, ops.GateOperation) and isinstance(
                op.gate, movable_gate_types
            )
            if not can_pass_over:
                # The string cannot be moved any further than this.
                break
            phasor = phasor.pass_operations_over([op], after_to_before=True)
            candidate = (phasor, position + 1, node)
            if placement_rank(candidate) > placement_rank(best):
                best = candidate
        return best

    return sorted(
        (best_placement(node) for node in possible_nodes), key=placement_rank, reverse=True
    )
65
66
def move_pauli_strings_into_circuit(
    circuit_left: Union[circuits.Circuit, circuits.CircuitDag], circuit_right: circuits.Circuit
) -> circuits.Circuit:
    """Insert the Pauli string phasors of ``circuit_left`` into ``circuit_right``.

    Every string starts in front of all operations of ``circuit_right`` and
    is pushed forward through them by ``_sorted_best_string_placements``.
    Ordering constraints between the strings come from a DAG built with
    ``pauli_string_reorder_pred`` (or ``pauli_string_dag_from_circuit``).

    Strings are placed in rounds. Each round places every node that has no
    successors left. Removing a placed node can leave a predecessor with no
    successors, and that predecessor is then placed in a later round.

    Args:
        circuit_left: A circuit or ``CircuitDag`` of Pauli string phasors.
            Presumably it contains only ``PauliStringPhasor`` operations,
            since each node value is treated as one; verify against callers.
        circuit_right: The circuit the strings are merged into.

    Returns:
        A new circuit containing the operations of ``circuit_right`` with the
        strings inserted, built with ``InsertStrategy.EARLIEST`` and
        ``circuit_right``'s device.

    Raises:
        AssertionError: If the DAG still has nodes after no node without
            successors remains, i.e. the DAG contains a cycle.
    """
    if isinstance(circuit_left, circuits.CircuitDag):
        string_dag = circuits.CircuitDag(pauli_string_reorder_pred, circuit_left)
    else:
        string_dag = pauli_string_dag_from_circuit(cast(circuits.Circuit, circuit_left))
    output_ops = list(circuit_right.all_operations())

    # Nodes that are the source of no edge, i.e. with no successor in the DAG.
    rightmost_nodes = set(string_dag.nodes()) - set(before for before, _ in string_dag.edges())

    while rightmost_nodes:
        placements = _sorted_best_string_placements(rightmost_nodes, output_ops)
        # All placement indices refer to output_ops as it is *before* this
        # round's insertions. Inserting at index i shifts everything at or
        # after i, so the strings must be inserted from the highest index
        # down. `placements` is ranked primarily by string length, which does
        # not imply that order: a short string may sit at a lower index than a
        # longer one. Re-sort by index. The sort is stable, so strings sharing
        # an index keep their ranking.
        placements = sorted(placements, key=lambda placement: placement[1], reverse=True)

        for string_op, insert_index, node in placements:
            output_ops.insert(insert_index, string_op)
            # Remove the placed node from the DAG. Any predecessor whose only
            # remaining successor was this node becomes placeable next round.
            rightmost_nodes.remove(node)
            rightmost_nodes.update(
                pred_node
                for pred_node in string_dag.predecessors(node)
                if len(string_dag.succ[pred_node]) <= 1
            )
            string_dag.remove_node(node)

    assert not string_dag.nodes(), 'There was a cycle in the CircuitDag'

    return circuits.Circuit(
        output_ops, strategy=circuits.InsertStrategy.EARLIEST, device=circuit_right.device
    )
110