1# Copyright 2018 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from itertools import combinations
16from string import ascii_lowercase
17from typing import Sequence, Dict, Tuple
18
19import numpy as np
20import pytest
21
22import cirq
23import cirq.testing as ct
24import cirq.contrib.acquaintance as cca
25
26
class ExampleGate(cirq.Gate):
    """Minimal test gate that is drawn with a fixed label on each of its qubits.

    Args:
        wire_symbols: One diagram label per qubit; its length determines how
            many qubits the gate acts on.
    """

    def __init__(self, wire_symbols: Sequence[str]) -> None:
        # The arity is taken from the input before the labels are frozen into a tuple.
        self._num_qubits, self._wire_symbols = len(wire_symbols), tuple(wire_symbols)

    def num_qubits(self) -> int:
        """Return the qubit count, i.e. the number of wire symbols supplied."""
        return self._num_qubits

    def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs):
        """Return the stored wire labels; ``args`` is ignored."""
        return self._wire_symbols
37
38
def test_executor_explicit():
    """Greedy execution of labelled two-qubit gates along a complete 2-acquaintance strategy.

    First checks that ``GreedyExecutionStrategy`` / ``StrategyExecutor`` reject
    invalid inputs. Then it executes the strategy on an 8-qubit line and compares
    the rewritten circuit against a fixed text diagram.
    """
    num_qubits = 8
    qubits = cirq.LineQubit.range(num_qubits)
    circuit = cca.complete_acquaintance_strategy(qubits, 2)

    # One gate per ordered pair of logical indices; (i, j) and (j, i) both use
    # the labels of the sorted pair.
    gates = {
        (i, j): ExampleGate([str(k) for k in ij])
        for ij in combinations(range(num_qubits), 2)
        for i, j in (ij, ij[::-1])
    }
    initial_mapping = {q: i for i, q in enumerate(sorted(qubits))}
    execution_strategy = cca.GreedyExecutionStrategy(gates, initial_mapping)
    executor = cca.StrategyExecutor(execution_strategy)

    # Setup lives outside each ``pytest.raises`` block so that only the call
    # under test can satisfy the expected exception.

    # Gates of mixed arity (here 1 and 2 indices) are expected to be rejected.
    bad_gates = {(0,): ExampleGate(['0']), (0, 1): ExampleGate(['0', '1'])}
    with pytest.raises(NotImplementedError):
        cca.GreedyExecutionStrategy(bad_gates, initial_mapping)

    with pytest.raises(TypeError):
        executor(cirq.Circuit())

    bad_strategy = cirq.Circuit(cirq.X(qubits[0]))
    with pytest.raises(TypeError):
        executor(bad_strategy)

    op = cirq.X(qubits[0])
    bad_strategy = cirq.Circuit(op)
    with pytest.raises(TypeError):
        executor.optimization_at(bad_strategy, 0, op)

    # Executing the strategy rewrites ``circuit`` in place (the return value is unused).
    executor(circuit)
    expected_text_diagram = """
0: ───0───1───╲0╱─────────────────1───3───╲0╱─────────────────3───5───╲0╱─────────────────5───7───╲0╱─────────────────
      │   │   │                   │   │   │                   │   │   │                   │   │   │
1: ───1───0───╱1╲───0───3───╲0╱───3───1───╱1╲───1───5───╲0╱───5───3───╱1╲───3───7───╲0╱───7───5───╱1╲───5───6───╲0╱───
                    │   │   │                   │   │   │                   │   │   │                   │   │   │
2: ───2───3───╲0╱───3───0───╱1╲───0───5───╲0╱───5───1───╱1╲───1───7───╲0╱───7───3───╱1╲───3───6───╲0╱───6───5───╱1╲───
      │   │   │                   │   │   │                   │   │   │                   │   │   │
3: ───3───2───╱1╲───2───5───╲0╱───5───0───╱1╲───0───7───╲0╱───7───1───╱1╲───1───6───╲0╱───6───3───╱1╲───3───4───╲0╱───
                    │   │   │                   │   │   │                   │   │   │                   │   │   │
4: ───4───5───╲0╱───5───2───╱1╲───2───7───╲0╱───7───0───╱1╲───0───6───╲0╱───6───1───╱1╲───1───4───╲0╱───4───3───╱1╲───
      │   │   │                   │   │   │                   │   │   │                   │   │   │
5: ───5───4───╱1╲───4───7───╲0╱───7───2───╱1╲───2───6───╲0╱───6───0───╱1╲───0───4───╲0╱───4───1───╱1╲───1───2───╲0╱───
                    │   │   │                   │   │   │                   │   │   │                   │   │   │
6: ───6───7───╲0╱───7───4───╱1╲───4───6───╲0╱───6───2───╱1╲───2───4───╲0╱───4───0───╱1╲───0───2───╲0╱───2───1───╱1╲───
      │   │   │                   │   │   │                   │   │   │                   │   │   │
7: ───7───6───╱1╲─────────────────6───4───╱1╲─────────────────4───2───╱1╲─────────────────2───0───╱1╲─────────────────
    """.strip()
    ct.assert_has_diagram(circuit, expected_text_diagram)
88
89
def random_diagonal_gates(
    num_qubits: int, acquaintance_size: int
) -> Dict[Tuple[cirq.Qid, ...], cirq.Gate]:
    """Assign a random diagonal gate to every ``acquaintance_size``-subset of a qubit line.

    Args:
        num_qubits: Length of the ``cirq.LineQubit`` line the subsets are drawn from.
        acquaintance_size: Number of qubits in each subset (and acted on by each gate).

    Returns:
        A dict keyed by qubit tuples in ``itertools.combinations`` order. Each value
        is a ``cirq.DiagonalGate`` built from ``2**acquaintance_size`` values drawn
        with ``np.random.random``. No seed is set here, so the gates differ between runs.
    """
    dimension = 2**acquaintance_size
    line = cirq.LineQubit.range(num_qubits)
    gates = {}
    for subset in combinations(line, acquaintance_size):
        gates[subset] = cirq.DiagonalGate(np.random.random(dimension))
    return gates
98
99
@pytest.mark.parametrize(
    'num_qubits, acquaintance_size, gates',
    [
        (n_qubits, size, random_diagonal_gates(n_qubits, size))
        for size, n_qubits in (
            [(2, n) for n in range(2, 9)] + [(3, n) for n in range(3, 8)] + [(4, 4), (4, 6), (5, 5)]
        )
        # Two independently drawn random gate sets per (size, n_qubits) pair.
        for _ in range(2)
    ],
)
def test_executor_random(
    num_qubits: int, acquaintance_size: int, gates: Dict[Tuple[cirq.Qid, ...], cirq.Gate]
):
    """Greedy execution along a complete strategy must reproduce direct application.

    The reference unitary applies every gate straight to its own qubits. The
    executed strategy circuit is followed by a ``LinearPermutationGate`` derived
    from the returned final mapping before the two unitaries are compared.
    """
    line = cirq.LineQubit.range(num_qubits)
    strategy_circuit = cca.complete_acquaintance_strategy(line, acquaintance_size)

    reference = cirq.Circuit([gate(*targets) for targets, gate in gates.items()])
    desired = reference.unitary()

    identity_mapping = {qubit: qubit for qubit in line}
    greedy = cca.GreedyExecutionStrategy(gates, identity_mapping)
    final_mapping = greedy(strategy_circuit)

    # Translate the final qubit->qubit mapping into a line-index permutation.
    index_permutation = {}
    for source, target in final_mapping.items():
        index_permutation[source.x] = target.x
    strategy_circuit.append(cca.LinearPermutationGate(num_qubits, index_permutation)(*line))
    actual = strategy_circuit.unitary()

    np.testing.assert_allclose(actual=actual, desired=desired, verbose=True)
126
127
def test_acquaintance_operation():
    """AcquaintanceOperation keeps its logical indices and labels each wire with one."""
    n = 5
    physical_qubits = tuple(cirq.LineQubit.range(n))
    named_indices = tuple(cirq.NamedQubit(letter) for letter in ascii_lowercase[:n])
    integer_indices = tuple(range(n))

    # Three physical qubits paired with four logical indices must be rejected.
    with pytest.raises(ValueError):
        cca.AcquaintanceOperation(physical_qubits[:3], integer_indices[:4])

    for indices in (named_indices, integer_indices):
        op = cca.AcquaintanceOperation(physical_qubits, indices)
        assert op.logical_indices == indices
        assert op.qubits == physical_qubits
        labels = tuple(f'({index})' for index in indices)
        expected_info = cirq.CircuitDiagramInfo(wire_symbols=labels)
        assert cirq.circuit_diagram_info(op) == expected_info
141