import re
import warnings
from typing import Sequence, Union, List, Tuple, Dict, Optional

import numpy as np
import quimb
import quimb.tensor as qtn

import cirq


# coverage: ignore
def _get_quimb_version():
    """Returns the quimb version and parsed version numbers if possible.

    Only the leading numeric release components are parsed, so development or
    local versions such as ``1.4.0.dev12`` or ``1.4.0+5.gabc`` still yield
    ``(1, 4, 0)`` instead of falling back to ``(0, 0)``. Plain versions such as
    ``1.4.0`` parse to ``(1, 4, 0)``.

    Returns:
        a tuple of (version number tuple, version string). The number tuple is
        ``(0, 0)`` if no leading numeric component could be parsed.
    """
    version = quimb.__version__
    numbers = []
    for component in str(version).split('.'):
        match = re.match(r'[0-9]+', component)
        if match is None:
            break
        numbers.append(int(match.group()))
        if match.end() != len(component):
            # A suffix such as 'rc1' or '+5' ends the numeric release segment.
            break
    if not numbers:
        return (0, 0), version
    return tuple(numbers), version


QUIMB_VERSION = _get_quimb_version()


def circuit_to_tensors(
    circuit: cirq.Circuit,
    qubits: Optional[Sequence[cirq.Qid]] = None,
    initial_state: Union[int, None] = 0,
) -> Tuple[List[qtn.Tensor], Dict['cirq.Qid', int], None]:
    """Given a circuit, construct a tensor network representation.

    Indices are named "i{i}_q{x}" where i is a time index and x is a
    qubit index.

    Args:
        circuit: The circuit containing operations that implement the
            cirq.unitary() protocol.
        qubits: A list of qubits in the circuit. Defaults to the sorted
            qubits of `circuit`.
        initial_state: Either `0` corresponding to the |0..0> state, in
            which case the tensor network will represent the final
            state vector; or `None` in which case the starting indices
            will be left open and the tensor network will represent the
            circuit unitary.

    Returns:
        tensors: A list of quimb Tensor objects
        qubit_frontier: A mapping from qubit to time index at the end of
            the circuit. This can be used to deduce the names of the free
            tensor indices.
        positions: Currently None. May be changed in the future to return
            a suitable mapping for tn.graph()'s `fix` argument. Currently,
            `fix=None` will draw the resulting tensor network using a spring
            layout.

    Raises:
        ValueError: If `initial_state` is neither `0` nor `None`, or if the
            circuit contains an operation that does not have a unitary.
    """
    if qubits is None:
        qubits = sorted(circuit.all_qubits())  # coverage: ignore

    # Every qubit starts at time index 0; each gate acting on it advances it.
    qubit_frontier = {q: 0 for q in qubits}
    positions = None
    tensors: List[qtn.Tensor] = []

    if initial_state == 0:
        # One |0> ket per qubit, attached to that qubit's time-0 index.
        tensors.extend(
            qtn.Tensor(data=quimb.up().squeeze(), inds=(f'i0_q{q}',), tags={'Q0'})
            for q in qubits
        )
    elif initial_state is None:
        # no input tensors, return a network representing the unitary
        pass
    else:
        raise ValueError("Right now, only |0> or `None` initial states are supported.")

    for moment in circuit.moments:
        for op in moment.operations:
            # Query the unitary protocol on the operation itself: not every
            # operation has a gate, and not every gate defines _has_unitary_.
            if not cirq.has_unitary(op):
                raise ValueError(f"Operation {op!r} does not have a unitary.")
            start_inds = [f'i{qubit_frontier[q]}_q{q}' for q in op.qubits]
            for q in op.qubits:
                qubit_frontier[q] += 1
            end_inds = [f'i{qubit_frontier[q]}_q{q}' for q in op.qubits]

            # Split the unitary into one output and one input leg per qubit.
            # Assumes every qubit is 2-dimensional.
            U = cirq.unitary(op).reshape((2,) * 2 * len(op.qubits))
            # Output (row) legs come first, then input (column) legs.
            t = qtn.Tensor(data=U, inds=end_inds + start_inds, tags={f'Q{len(op.qubits)}'})
            tensors.append(t)

    return tensors, qubit_frontier, positions


def tensor_state_vector(
    circuit: cirq.Circuit, qubits: Optional[Sequence[cirq.Qid]] = None
) -> np.ndarray:
    """Given a circuit contract a tensor network into a final state vector.

    Args:
        circuit: The circuit to simulate, starting from |0..0>.
        qubits: Qubit ordering for the output. Defaults to the sorted qubits
            of `circuit`.

    Returns:
        The dense final state, with open indices ordered following `qubits`.
    """
    if qubits is None:
        qubits = sorted(circuit.all_qubits())

    tensors, qubit_frontier, _ = circuit_to_tensors(circuit=circuit, qubits=qubits)
    tn = qtn.TensorNetwork(tensors)
    # The open indices are each qubit's last time index.
    f_inds = tuple(f'i{qubit_frontier[q]}_q{q}' for q in qubits)
    tn.contract(inplace=True)
    return tn.to_dense(f_inds)


def tensor_unitary(
    circuit: cirq.Circuit, qubits: Optional[Sequence[cirq.Qid]] = None
) -> np.ndarray:
    """Given a circuit contract a tensor network into a dense unitary
    of the circuit.

    Args:
        circuit: The circuit whose unitary is computed.
        qubits: Qubit ordering for rows and columns. Defaults to the sorted
            qubits of `circuit`.

    Returns:
        The dense unitary: rows are the final (output) indices and columns
        the initial (input) indices, each ordered following `qubits`.
    """
    if qubits is None:
        qubits = sorted(circuit.all_qubits())

    tensors, qubit_frontier, _ = circuit_to_tensors(
        circuit=circuit, qubits=qubits, initial_state=None
    )
    tn = qtn.TensorNetwork(tensors)
    i_inds = tuple(f'i0_q{q}' for q in qubits)
    f_inds = tuple(f'i{qubit_frontier[q]}_q{q}' for q in qubits)
    tn.contract(inplace=True)
    return tn.to_dense(f_inds, i_inds)


def circuit_for_expectation_value(
    circuit: cirq.Circuit, pauli_string: cirq.PauliString
) -> cirq.Circuit:
    """Sandwich a PauliString operator between a forwards and backwards
    copy of a circuit.

    This is a circuit representation of the expectation value of an operator
    <A> = <psi|A|psi> = <0|U^dag A U|0>. You can either extract the 0..0
    amplitude of the final state vector (assuming starting from the |0..0>
    state) or extract the [0, 0] entry of the unitary matrix of this combined
    circuit.

    Args:
        circuit: The state-preparation circuit U.
        pauli_string: The operator A; its coefficient must be exactly 1.

    Returns:
        The circuit U, then one moment applying A, then U^dag.

    Raises:
        ValueError: If the coefficient of `pauli_string` is not 1.
    """
    if pauli_string.coefficient != 1:
        raise ValueError(
            f"The PauliString coefficient must be 1, got {pauli_string.coefficient}."
        )
    return cirq.Circuit(
        [
            circuit,
            cirq.Moment(gate.on(q) for q, gate in pauli_string.items()),
            cirq.inverse(circuit),
        ]
    )


def tensor_expectation_value(
    circuit: cirq.Circuit,
    pauli_string: cirq.PauliString,
    max_ram_gb: float = 16,
    tol: float = 1e-6,
) -> float:
    """Compute an expectation value for an operator and a circuit via tensor
    contraction.

    This will give up if it looks like the computation will take too much RAM.

    Args:
        circuit: The state-preparation circuit.
        pauli_string: The observable. Its coefficient must be (numerically)
            real.
        max_ram_gb: Upper bound, in GB, on the estimated size of the largest
            intermediate tensor of the contraction.
        tol: Tolerance on the magnitude of imaginary parts that must vanish.

    Returns:
        The real expectation value, scaled by the real part of the
        PauliString coefficient.

    Raises:
        ValueError: If the coefficient of `pauli_string` has an imaginary
            part of magnitude at least `tol`.
        MemoryError: If the estimated largest intermediate exceeds
            `max_ram_gb`.
    """
    coefficient = pauli_string.coefficient
    # Validate before the (potentially expensive) contraction. abs() is needed
    # so that large negative imaginary parts are rejected too.
    if abs(coefficient.imag) >= tol:
        raise ValueError(f"The PauliString coefficient must be real, got {coefficient}.")

    circuit_sand = circuit_for_expectation_value(circuit, pauli_string / coefficient)
    qubits = sorted(circuit_sand.all_qubits())

    tensors, qubit_frontier, _ = circuit_to_tensors(circuit=circuit_sand, qubits=qubits)
    # Close each qubit's final index with <0|, so the network evaluates the
    # <0..0| U^dag A U |0..0> amplitude as a scalar.
    end_bras = [
        qtn.Tensor(
            data=quimb.up().squeeze(), inds=(f'i{qubit_frontier[q]}_q{q}',), tags={'Q0', 'bra0'}
        )
        for q in qubits
    ]
    tn = qtn.TensorNetwork(tensors + end_bras)
    if QUIMB_VERSION[0] < (1, 3):
        # coverage: ignore
        warnings.warn(
            f'quimb version {QUIMB_VERSION[1]} detected. Please use '
            f'quimb>=1.3 for optimal performance in '
            '`tensor_expectation_value`. '
            'See https://github.com/quantumlib/Cirq/issues/3263'
        )
    else:
        tn.rank_simplify(inplace=True)
    path_info = tn.contract(get='path-info')
    # 128 bits (16 bytes) per complex128 element, converted to GB.
    ram_gb = path_info.largest_intermediate * 128 / 8 / 1024 / 1024 / 1024
    if ram_gb > max_ram_gb:
        raise MemoryError(f"We estimate that this contraction will take too much RAM! {ram_gb} GB")
    e_val = tn.contract(inplace=True)
    # For a Hermitian observable the result is real up to numerical noise.
    assert abs(e_val.imag) < tol, f"Expectation value has a large imaginary part: {e_val}"
    return e_val.real * coefficient.real