1import warnings
2from typing import Sequence, Union, List, Tuple, Dict, Optional
3
4import numpy as np
5import quimb
6import quimb.tensor as qtn
7
8import cirq
9
10
11# coverage: ignore
12def _get_quimb_version():
13    """Returns the quimb version and parsed (major,minor) numbers if possible.
14    Returns:
15        a tuple of ((major, minor), version string)
16    """
17    version = quimb.__version__
18    try:
19        return tuple(int(x) for x in version.split('.')), version
20    except:
21        return (0, 0), version
22
23
# Computed once at import time as (parsed version tuple, raw version string).
# The parsed tuple is compared against minimum versions to choose code paths,
# e.g. the `(1, 3)` check in `tensor_expectation_value`.
QUIMB_VERSION: Tuple[Tuple[int, ...], str] = _get_quimb_version()
25
26
# pylint: disable=missing-raises-doc
def circuit_to_tensors(
    circuit: cirq.Circuit,
    qubits: Optional[Sequence[cirq.Qid]] = None,
    initial_state: Union[int, None] = 0,
) -> Tuple[List[qtn.Tensor], Dict['cirq.Qid', int], None]:
    """Given a circuit, construct a tensor network representation.

    Indices are named "i{i}_q{x}" where i is a time index and x is a
    qubit index.

    Args:
        circuit: The circuit containing operations that implement the
            cirq.unitary() protocol.
        qubits: A list of qubits in the circuit. Defaults to the sorted
            qubits of `circuit`.
        initial_state: Either `0` corresponding to the |0..0> state, in
            which case the tensor network will represent the final
            state vector; or `None` in which case the starting indices
            will be left open and the tensor network will represent the
            circuit unitary.
    Returns:
        tensors: A list of quimb Tensor objects
        qubit_frontier: A mapping from qubit to time index at the end of
            the circuit. This can be used to deduce the names of the free
            tensor indices.
        positions: Currently None. May be changed in the future to return
            a suitable mapping for tn.graph()'s `fix` argument. Currently,
            `fix=None` will draw the resulting tensor network using a spring
            layout.
    Raises:
        ValueError: If `initial_state` is neither `0` nor `None`, or if the
            circuit contains an operation that has no unitary.
    """
    if qubits is None:
        qubits = sorted(circuit.all_qubits())  # coverage: ignore

    qubit_frontier = {q: 0 for q in qubits}
    positions = None
    tensors: List[qtn.Tensor] = []

    if initial_state == 0:
        for q in qubits:
            tensors.append(qtn.Tensor(data=quimb.up().squeeze(), inds=(f'i0_q{q}',), tags={'Q0'}))
    elif initial_state is None:
        # no input tensors, return a network representing the unitary
        pass
    else:
        raise ValueError("Right now, only |0> or `None` initial states are supported.")

    for moment in circuit.moments:
        for op in moment.operations:
            # Ask the operation itself via the public protocol. `op.gate` is None
            # for non-gate operations, and not every unitary gate implements the
            # private `_has_unitary_`. An explicit raise (not `assert`) also
            # survives `python -O`.
            if not cirq.has_unitary(op):
                raise ValueError(f"Operation {op!r} has no unitary; cannot convert it to a tensor.")
            # Each operation consumes the current index of every qubit it acts
            # on and produces a fresh one at the next time index.
            start_inds = [f'i{qubit_frontier[q]}_q{q}' for q in op.qubits]
            for q in op.qubits:
                qubit_frontier[q] += 1
            end_inds = [f'i{qubit_frontier[q]}_q{q}' for q in op.qubits]

            # Split the (2^n, 2^n) matrix into one axis per qubit: output axes
            # first, then input axes, matching `end_inds + start_inds`.
            U = cirq.unitary(op).reshape((2,) * 2 * len(op.qubits))
            t = qtn.Tensor(data=U, inds=end_inds + start_inds, tags={f'Q{len(op.qubits)}'})
            tensors.append(t)

    return tensors, qubit_frontier, positions
87
88
# pylint: enable=missing-raises-doc
def tensor_state_vector(
    circuit: cirq.Circuit, qubits: Optional[Sequence[cirq.Qid]] = None
) -> np.ndarray:
    """Contract the tensor network of `circuit` into its final state vector.

    The network starts every qubit in |0> (the `circuit_to_tensors` default).
    The result is therefore the state produced by applying `circuit` to |0..0>.

    Args:
        circuit: The circuit to simulate. Every operation must have a unitary.
        qubits: Qubit ordering of the output. Defaults to the sorted qubits
            of `circuit`.

    Returns:
        The dense state vector, with qubits ordered as in `qubits`.
    """
    ordered_qubits = sorted(circuit.all_qubits()) if qubits is None else qubits

    tensors, frontier, _ = circuit_to_tensors(circuit=circuit, qubits=ordered_qubits)
    network = qtn.TensorNetwork(tensors)
    # The last index written on each qubit is the open output index.
    output_inds = tuple(f'i{frontier[qubit]}_q{qubit}' for qubit in ordered_qubits)
    network.contract(inplace=True)
    return network.to_dense(output_inds)
102
103
def tensor_unitary(
    circuit: cirq.Circuit, qubits: Optional[Sequence[cirq.Qid]] = None
) -> np.ndarray:
    """Given a circuit contract a tensor network into a dense unitary
    of the circuit.

    Args:
        circuit: The circuit whose unitary to compute. Every operation must
            have a unitary.
        qubits: Qubit ordering of the result. Defaults to the sorted qubits
            of `circuit`. Qubits listed here that no operation acts on
            contribute an identity factor.

    Returns:
        The dense unitary. Rows are indexed by the output (final) indices and
        columns by the input indices, both ordered as in `qubits`.
    """
    if qubits is None:
        qubits = sorted(circuit.all_qubits())

    tensors, qubit_frontier, _ = circuit_to_tensors(
        circuit=circuit, qubits=qubits, initial_state=None
    )
    # With open inputs, a qubit that no operation touches has no tensor at all.
    # Its input and output index would then both be `i0_q{q}`, which breaks
    # `to_dense` below. Insert an explicit identity so the indices differ.
    for q in qubits:
        if qubit_frontier[q] == 0:
            tensors.append(
                qtn.Tensor(
                    data=np.eye(2, dtype=complex), inds=(f'i1_q{q}', f'i0_q{q}'), tags={'Q1'}
                )
            )
            qubit_frontier[q] = 1

    tn = qtn.TensorNetwork(tensors)
    i_inds = tuple(f'i0_q{q}' for q in qubits)
    f_inds = tuple(f'i{qubit_frontier[q]}_q{q}' for q in qubits)
    tn.contract(inplace=True)
    return tn.to_dense(f_inds, i_inds)
120
121
def circuit_for_expectation_value(
    circuit: cirq.Circuit, pauli_string: cirq.PauliString
) -> cirq.Circuit:
    """Sandwich a PauliString operator between a forwards and backwards
    copy of a circuit.

    This is a circuit representation of the expectation value of an operator
    <A> = <psi|A|psi> = <0|U^dag A U|0>. You can either extract the 0..0
    amplitude of the final state vector (assuming starting from the |0..0>
    state) or extract the [0, 0] entry of the unitary matrix of this combined
    circuit.

    Args:
        circuit: The circuit U preparing |psi> from |0..0>.
        pauli_string: The operator A. It must have coefficient exactly 1,
            because a coefficient cannot be represented as gates in a moment.

    Returns:
        The circuit `circuit`, then one moment applying each Pauli of
        `pauli_string`, then `cirq.inverse(circuit)`.

    Raises:
        ValueError: If `pauli_string.coefficient` is not 1.
    """
    # An explicit raise (not `assert`) so the check survives `python -O`.
    if pauli_string.coefficient != 1:
        raise ValueError(
            f"pauli_string must have coefficient 1, got {pauli_string.coefficient}."
        )
    return cirq.Circuit(
        [
            circuit,
            cirq.Moment(gate.on(q) for q, gate in pauli_string.items()),
            cirq.inverse(circuit),
        ]
    )
142
143
def tensor_expectation_value(
    circuit: cirq.Circuit,
    pauli_string: cirq.PauliString,
    max_ram_gb: float = 16,
    tol: float = 1e-6,
) -> float:
    """Compute an expectation value for an operator and a circuit via tensor
    contraction.

    This will give up if it looks like the computation will take too much RAM.

    Args:
        circuit: The circuit U preparing |psi> = U|0..0>.
        pauli_string: The observable. Its coefficient must be real, within `tol`.
        max_ram_gb: Refuse to contract if the estimated size of the largest
            intermediate tensor exceeds this many GB.
        tol: Tolerance on the magnitude of the imaginary parts of the
            coefficient and of the contracted value.

    Returns:
        The real expectation value <psi|pauli_string|psi>.

    Raises:
        ValueError: If the coefficient of `pauli_string` is not real within `tol`.
        MemoryError: If the estimated contraction memory exceeds `max_ram_gb`.
    """
    coefficient = pauli_string.coefficient
    # Validate before the (possibly expensive) contraction. Compare magnitudes:
    # a bare `imag < tol` would accept any large negative imaginary part.
    if abs(coefficient.imag) >= tol:
        raise ValueError(f"pauli_string must have a real coefficient, got {coefficient}.")
    if coefficient == 0:
        # A zero operator has zero expectation. Avoid dividing by zero below.
        return 0.0

    circuit_sand = circuit_for_expectation_value(circuit, pauli_string / coefficient)
    qubits = sorted(circuit_sand.all_qubits())

    tensors, qubit_frontier, _ = circuit_to_tensors(circuit=circuit_sand, qubits=qubits)
    # Close every output index with <0|, leaving the scalar <0|U^dag P U|0>.
    end_bras = [
        qtn.Tensor(
            data=quimb.up().squeeze(), inds=(f'i{qubit_frontier[q]}_q{q}',), tags={'Q0', 'bra0'}
        )
        for q in qubits
    ]
    tn = qtn.TensorNetwork(tensors + end_bras)
    if QUIMB_VERSION[0] < (1, 3):
        # coverage: ignore
        warnings.warn(
            f'quimb version {QUIMB_VERSION[1]} detected. Please use '
            'quimb>=1.3 for optimal performance in '
            '`tensor_expectation_value`. '
            'See https://github.com/quantumlib/Cirq/issues/3263'
        )
    else:
        tn.rank_simplify(inplace=True)
    path_info = tn.contract(get='path-info')
    # Assumes complex128 entries: 128 bits / 8 = 16 bytes per element.
    ram_gb = path_info.largest_intermediate * 128 / 8 / 1024 / 1024 / 1024
    if ram_gb > max_ram_gb:
        raise MemoryError(f"We estimate that this contraction will take too much RAM! {ram_gb} GB")
    e_val = tn.contract(inplace=True)
    # Internal sanity check: a unit-coefficient Pauli string is Hermitian, so
    # its expectation value is real up to numerical noise.
    assert abs(e_val.imag) < tol
    return e_val.real * coefficient.real
181