# Copyright 2018 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utility methods for transforming matrices or vectors."""

from typing import Tuple, Optional, Sequence, List, Union

import numpy as np

from cirq import protocols
from cirq.linalg import predicates

# This is a special indicator value used by the `sub_state_vector` method to
# determine whether or not the caller provided a 'default' argument. It must be
# of type np.ndarray to ensure the method has the correct type signature in that
# case. It is checked for using `is`, so it won't have a false positive if the
# user provides a different np.array([]) value.
RaiseValueErrorIfNotProvided: np.ndarray = np.array([])


def reflection_matrix_pow(reflection_matrix: np.ndarray, exponent: float):
    """Raises a matrix with two opposing eigenvalues to a power.

    Args:
        reflection_matrix: The matrix to raise to a power.
        exponent: The power to raise the matrix to.

    Returns:
        The given matrix raised to the given power.
    """

    # The eigenvalues are x and -x for some complex unit x. Determine x.
    squared_phase = np.dot(reflection_matrix[:, 0], reflection_matrix[0, :])
    phase = complex(np.sqrt(squared_phase))

    # Extract +x and -x eigencomponents of the matrix.
    i = np.eye(reflection_matrix.shape[0]) * phase
    pos_part = (i + reflection_matrix) * 0.5
    neg_part = (i - reflection_matrix) * 0.5

    # Raise the matrix to a power by raising its eigencomponents to that power.
    pos_factor = phase ** (exponent - 1)
    neg_factor = pos_factor * complex(-1) ** exponent
    pos_part_raised = pos_factor * pos_part
    neg_part_raised = neg_part * neg_factor
    return pos_part_raised + neg_part_raised


def match_global_phase(a: np.ndarray, b: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Phases the given matrices so that they agree on the phase of one entry.

    To maximize precision, the position of the largest-magnitude entry of `b`
    is used when computing the phase difference between the two matrices.

    Args:
        a: A numpy array.
        b: Another numpy array.

    Returns:
        A tuple (a', b') where a' == b' implies a == b*exp(i t) for some t.
    """

    # Not much point when they have different shapes.
    if a.shape != b.shape or a.size == 0:
        return np.copy(a), np.copy(b)

    # Find the entry with the largest magnitude in one of the matrices.
    k = max(np.ndindex(*a.shape), key=lambda t: abs(b[t]))

    def dephase(v):
        r = np.real(v)
        i = np.imag(v)

        # Avoid introducing floating point error when axis-aligned.
        if i == 0:
            return -1 if r < 0 else 1
        if r == 0:
            return 1j if i < 0 else -1j

        return np.exp(-1j * np.arctan2(i, r))

    # Zero the phase at this entry in both matrices.
    return a * dephase(a[k]), b * dephase(b[k])


def targeted_left_multiply(
    left_matrix: np.ndarray,
    right_target: np.ndarray,
    target_axes: Sequence[int],
    out: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Left-multiplies the given axes of the target tensor by the given matrix.

    Note that the matrix must have a compatible tensor structure.

    For example, if you have a 6-qubit state vector `input_state` with shape
    (2, 2, 2, 2, 2, 2), and a 2-qubit unitary operation `op` with shape
    (2, 2, 2, 2), and you want to apply `op` to the 5'th and 3'rd qubits
    within `input_state`, then the output state vector is computed as follows:

        output_state = cirq.targeted_left_multiply(op, input_state, [5, 3])

    This method also works when the right hand side is a matrix instead of a
    vector. If a unitary circuit's matrix is `old_effect`, and you append
    a CNOT(q1, q4) operation onto the circuit, where the control q1 is the qubit
    at offset 1 and the target q4 is the qubit at offset 4, then the appended
    circuit's unitary matrix is computed as follows:

        new_effect = cirq.targeted_left_multiply(
            left_matrix=cirq.unitary(cirq.CNOT).reshape((2, 2, 2, 2)),
            right_target=old_effect,
            target_axes=[1, 4])

    Args:
        left_matrix: What to left-multiply the target tensor by.
        right_target: A tensor to carefully broadcast a left-multiply over.
        target_axes: Which axes of the target are being operated on.
        out: The buffer to store the results in. If not specified or None, a new
            buffer is used. Must have the same shape as right_target.

    Returns:
        The output tensor.

    Raises:
        ValueError: If `out` is the same array object as `right_target` or
            `left_matrix` (the result cannot be written over an input).
    """
    if out is right_target or out is left_matrix:
        raise ValueError('out is right_target or out is left_matrix')

    k = len(target_axes)
    d = len(right_target.shape)
    # einsum labels: 0..k-1 are the matrix's output ("work") axes, and
    # k..k+d-1 label the target's axes.
    work_indices = tuple(range(k))
    data_indices = tuple(range(k, k + d))
    used_data_indices = tuple(data_indices[q] for q in target_axes)
    input_indices = work_indices + used_data_indices
    output_indices = list(data_indices)
    # Each targeted axis of the output is replaced by the matching work axis.
    for w, t in zip(work_indices, target_axes):
        output_indices[t] = w

    all_indices = set(input_indices + data_indices + tuple(output_indices))

    return np.einsum(
        left_matrix,
        input_indices,
        right_target,
        data_indices,
        output_indices,
        # We would prefer to omit 'optimize=' (it's faster),
        # but this is a workaround for a bug in numpy:
        #     https://github.com/numpy/numpy/issues/10926
        optimize=len(all_indices) >= 26,
        # And this is a workaround for *another* bug!
        # Supposed to be able to just say 'out=out'.
        **({'out': out} if out is not None else {}),
    )


def targeted_conjugate_about(
    tensor: np.ndarray,
    target: np.ndarray,
    indices: Sequence[int],
    conj_indices: Optional[Sequence[int]] = None,
    buffer: Optional[np.ndarray] = None,
    out: Optional[np.ndarray] = None,
) -> np.ndarray:
    r"""Conjugates the given tensor about the target tensor.

    This method computes a target tensor conjugated by another tensor.
    Here conjugate is used in the sense of conjugating by a matrix, i.e.
    A conjugated about B is $A B A^\dagger$ where $\dagger$ represents the
    conjugate transpose.

    Abstractly this computes $A \cdot B \cdot A^\dagger$ where A and B are
    multi-dimensional arrays, and instead of matrix multiplication $\cdot$
    is a contraction between the given indices (indices for first $\cdot$,
    conj_indices for second $\cdot$).

    More specifically this computes
        $\sum tensor_{i_0,...,i_{r-1},j_0,...,j_{r-1}} *
        target_{k_0,...,k_{r-1},l_0,...,l_{r-1}} *
        tensor_{m_0,...,m_{r-1},n_0,...,n_{r-1}}^*$

    where the sum is over indices where $j_s$ = $k_s$ and $s$ is in `indices`
    and $l_s$ = $m_s$ and s is in `conj_indices`.

    Args:
        tensor: The tensor that will be conjugated about the target tensor.
        target: The tensor that will receive the conjugation.
        indices: The indices which will be contracted between the tensor and
            target.
        conj_indices: The indices which will be contracted between the
            complex conjugate of the tensor and the target. If this is None,
            then these will be the values in indices plus half the number
            of dimensions of the target (`ndim`). This is the most common case
            and corresponds to the case where the target is an operator on
            a n-dimensional tensor product space (here `n` would be `ndim`).
        buffer: A buffer to store partial results in. If not specified or None,
            a new buffer is used.
        out: The buffer to store the results in. If not specified or None, a new
            buffer is used. Must have the same shape as target.

    Returns:
        The result of the conjugation.
    """
    # Only substitute the default when the caller gave nothing; an explicitly
    # passed sequence is always honored.
    if conj_indices is None:
        conj_indices = [i + target.ndim // 2 for i in indices]
    first_multiply = targeted_left_multiply(tensor, target, indices, out=buffer)
    return targeted_left_multiply(np.conjugate(tensor), first_multiply, conj_indices, out=out)


_TSliceAtom = Union[int, slice, 'ellipsis']
_TSlice = Union[_TSliceAtom, Sequence[_TSliceAtom]]


def apply_matrix_to_slices(
    target: np.ndarray,
    matrix: np.ndarray,
    slices: Sequence[_TSlice],
    *,
    out: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Left-multiplies an NxN matrix onto N slices of a numpy array.

    Example:
        The 4x4 matrix of a fractional SWAP gate can be expressed as

           [ 1       ]
           [   X**t  ]
           [       1 ]

        Where X is the 2x2 Pauli X gate and t is the power of the swap with t=1
        being a full swap. X**t is a power of the Pauli X gate's matrix.
        Applying the fractional swap is equivalent to applying a fractional X
        within the inner 2x2 subspace; the rest of the matrix is identity. This
        can be expressed using `apply_matrix_to_slices` as follows:

            def fractional_swap(target):
                assert target.shape == (4,)
                return apply_matrix_to_slices(
                    target=target,
                    matrix=cirq.unitary(cirq.X**t),
                    slices=[1, 2]
                )

    Args:
        target: The input array with slices that need to be left-multiplied.
        matrix: The linear operation to apply to the subspace defined by the
            slices.
        slices: The parts of the tensor that correspond to the "vector entries"
            that the matrix should operate on. May be integers or complicated
            multi-dimensional slices into a tensor. The slices must refer to
            non-overlapping sections of the input all with the same shape.
        out: Where to write the output. If not specified, a new numpy array is
            created, with the same shape and dtype as the target, to store the
            output.

    Returns:
        The transformed array.

    Raises:
        ValueError: If `out` is `target`, or if `matrix` does not have shape
            `(len(slices), len(slices))`.
    """
    # Validate arguments.
    if out is target:
        raise ValueError("Can't write output over the input.")
    if matrix.shape != (len(slices), len(slices)):
        raise ValueError("matrix.shape != (len(slices), len(slices))")

    # Fill in default values and prepare space.
    if out is None:
        out = np.copy(target)
    else:
        out[...] = target[...]

    # Apply operation. `out` starts as a copy of `target`, so the diagonal term
    # is applied by scaling in place and off-diagonal terms read from `target`.
    for i, s_i in enumerate(slices):
        out[s_i] *= matrix[i, i]
        for j, s_j in enumerate(slices):
            if i != j:
                out[s_i] += target[s_j] * matrix[i, j]

    return out


def partial_trace(tensor: np.ndarray, keep_indices: Sequence[int]) -> np.ndarray:
    """Takes the partial trace of a given tensor.

    The input tensor must have shape `(d_0, ..., d_{k-1}, d_0, ..., d_{k-1})`.
    The trace is done over all indices that are not in keep_indices. The
    resulting tensor has shape `(d_{i_0}, ..., d_{i_r}, d_{i_0}, ..., d_{i_r})`
    where `i_j` is the `j`th element of `keep_indices`.

    Args:
        tensor: The tensor to sum over. This tensor must have a shape
            `(d_0, ..., d_{k-1}, d_0, ..., d_{k-1})`.
        keep_indices: Which indices to not sum over. These are only the indices
            of the first half of the tensor's indices (i.e. all elements must
            be between `0` and `tensor.ndim / 2 - 1` inclusive).

    Returns:
        The partially traced tensor, with the kept axes ordered as in
        `keep_indices`.

    Raises:
        ValueError: if the tensor is not of the correct shape or the indices
            are not from the first half of valid indices for the tensor.
    """
    ndim = tensor.ndim // 2
    if not all(tensor.shape[i] == tensor.shape[i + ndim] for i in range(ndim)):
        raise ValueError(
            'Tensors must have shape (d_0,...,d_{{k-1}},d_0,...,'
            'd_{{k-1}}) but had shape ({}).'.format(tensor.shape)
        )
    # Negative indices would otherwise reach einsum as invalid labels.
    if not all(0 <= i < ndim for i in keep_indices):
        raise ValueError(
            'keep_indices were {} but must be in first half, '
            'i.e. have index less that {}.'.format(keep_indices, ndim)
        )
    keep_set = set(keep_indices)
    # einsum emits free labels in sorted order; relabeling the kept axes with
    # the sorted keep indices makes the output follow `keep_indices` order.
    keep_map = dict(zip(keep_indices, sorted(keep_indices)))
    left_indices = [keep_map[i] if i in keep_set else i for i in range(ndim)]
    # Traced axes reuse the same label on both halves, so einsum sums them.
    right_indices = [ndim + i if i in keep_set else i for i in left_indices]
    return np.einsum(tensor, left_indices + right_indices)


class EntangledStateError(ValueError):
    """Raised when a product state is expected, but an entangled state is provided."""


def partial_trace_of_state_vector_as_mixture(
    state_vector: np.ndarray, keep_indices: List[int], *, atol: Union[int, float] = 1e-8
) -> Tuple[Tuple[float, np.ndarray], ...]:
    """Returns a mixture representing a state vector with only some qubits kept.

    The input state vector must have shape `(2,) * n` or `(2 ** n)` where
    `state_vector` is expressed over n qubits. States in the output mixture will
    retain the same type of shape as the input state vector, either `(2 ** k)`
    or `(2,) * k` where k is the number of qubits kept.

    If the state vector cannot be factored into a pure state over `keep_indices`
    then eigendecomposition is used and the output mixture will not be unique.

    Args:
        state_vector: The state vector to take the partial trace over.
        keep_indices: Which indices to take the partial trace of the
            state_vector on.
        atol: The tolerance for determining that a factored state is pure.

    Returns:
        A single-component mixture in which the factored state vector has
        probability '1' if the partially traced state is pure, or else a
        mixture of the default eigendecomposition of the mixed state's
        partial trace.

    Raises:
        ValueError: if the input `state_vector` is not an array of length
            `(2 ** n)` or a tensor with a shape of `(2,) * n`
    """

    # Attempt to do efficient state factoring.
    try:
        state = sub_state_vector(
            state_vector, keep_indices, default=RaiseValueErrorIfNotProvided, atol=atol
        )
        return ((1.0, state),)
    except EntangledStateError:
        pass

    # Fall back to a (non-unique) mixture representation.
    keep_dims = 1 << len(keep_indices)
    ret_shape: Union[Tuple[int], Tuple[int, ...]]
    if state_vector.shape == (state_vector.size,):
        ret_shape = (keep_dims,)
    elif all(e == 2 for e in state_vector.shape):
        ret_shape = tuple(2 for _ in range(len(keep_indices)))

    # Build the density matrix |psi><psi| as a (2, 2) * n tensor.
    rho = np.kron(np.conj(state_vector.reshape(-1, 1)).T, state_vector.reshape(-1, 1)).reshape(
        (2, 2) * int(np.log2(state_vector.size))
    )
    keep_rho = partial_trace(rho, keep_indices).reshape((keep_dims,) * 2)
    eigvals, eigvecs = np.linalg.eigh(keep_rho)
    mixture = tuple(zip(eigvals, [vec.reshape(ret_shape) for vec in eigvecs.T]))
    return tuple([(float(p[0]), p[1]) for p in mixture if not protocols.approx_eq(p[0], 0.0)])


def sub_state_vector(
    state_vector: np.ndarray,
    keep_indices: List[int],
    *,
    default: np.ndarray = RaiseValueErrorIfNotProvided,
    atol: Union[int, float] = 1e-8,
) -> np.ndarray:
    r"""Attempts to factor a state vector into two parts and return one of them.

    The input `state_vector` must have shape ``(2,) * n`` or ``(2 ** n)`` where
    `state_vector` is expressed over n qubits. The returned array will retain
    the same type of shape as the input state vector, either ``(2 ** k)`` or
    ``(2,) * k`` where k is the number of qubits kept.

    If a state vector $|\psi\rangle$ defined on n qubits is a tensor product
    of kets like $|\psi\rangle$ = $|x\rangle \otimes |y\rangle$, and
    $|x\rangle$ is defined over the subset ``keep_indices`` of k qubits, then
    this method will factor $|\psi\rangle$ into $|x\rangle$ and $|y\rangle$ and
    return $|x\rangle$. Note that $|x\rangle$ is not unique, because scalar
    multiplication may be absorbed by any factor of a tensor product,
    $e^{i \theta} |x\rangle \otimes |y\rangle =
    |x\rangle \otimes e^{i \theta} |y\rangle$

    This method randomizes the global phase of $|x\rangle$ in order to avoid
    accidental reliance on the global phase being some specific value.

    If the provided `state_vector` cannot be factored into a pure state over
    `keep_indices`, the method will fall back to return `default`. If `default`
    is not provided, the method will fail and raise `EntangledStateError`
    (a subclass of `ValueError`).

    Args:
        state_vector: The target state_vector.
        keep_indices: Which indices to attempt to get the separable part of the
            `state_vector` on.
        default: Determines the fallback behavior when `state_vector` doesn't
            have a pure state factorization. If the factored state is not pure
            and `default` is not set, an EntangledStateError is raised. If
            default is set to a value, that value is returned.
        atol: The minimum tolerance for comparing the output state's coherence
            measure to 1.

    Returns:
        The state vector expressed over the desired subset of qubits.

    Raises:
        ValueError: if the `state_vector` is not of the correct shape, is not
            normalized, or the indices are not a valid subset of the input
            `state_vector`'s indices.
        EntangledStateError: If the result of factoring is not a pure state and
            `default` is not provided.
    """

    if not np.log2(state_vector.size).is_integer():
        raise ValueError(
            "Input state_vector of size {} does not represent a "
            "state over qubits.".format(state_vector.size)
        )

    n_qubits = int(np.log2(state_vector.size))
    keep_dims = 1 << len(keep_indices)
    ret_shape: Union[Tuple[int], Tuple[int, ...]]
    if state_vector.shape == (state_vector.size,):
        ret_shape = (keep_dims,)
        state_vector = state_vector.reshape((2,) * n_qubits)
    elif state_vector.shape == (2,) * n_qubits:
        ret_shape = tuple(2 for _ in range(len(keep_indices)))
    else:
        raise ValueError("Input state_vector must be shaped like (2 ** n,) or (2,) * n")

    if not np.isclose(np.linalg.norm(state_vector), 1):
        raise ValueError("Input state must be normalized.")
    if len(set(keep_indices)) != len(keep_indices):
        raise ValueError(f"keep_indices were {keep_indices} but must be unique.")
    # Negative indices are rejected too; they would otherwise fail later with
    # an unrelated reshape error.
    if any(not 0 <= ind < n_qubits for ind in keep_indices):
        raise ValueError(
            "keep_indices {} are an invalid subset of the input state vector.".format(keep_indices)
        )

    other_qubits = sorted(set(range(n_qubits)) - set(keep_indices))
    # One candidate sub-state per computational-basis assignment of the
    # qubits that are not kept.
    candidates = [
        state_vector[predicates.slice_for_qubits_equal_to(other_qubits, k)].reshape(keep_dims)
        for k in range(1 << len(other_qubits))
    ]
    # The coherence measure is computed using unnormalized candidates.
    best_candidate = max(candidates, key=lambda c: np.linalg.norm(c, 2))
    best_candidate = best_candidate / np.linalg.norm(best_candidate)
    left = np.conj(best_candidate.reshape((keep_dims,))).T
    coherence_measure = sum(abs(np.dot(left, c.reshape((keep_dims,)))) ** 2 for c in candidates)

    if protocols.approx_eq(coherence_measure, 1, atol=atol):
        return np.exp(2j * np.pi * np.random.random()) * best_candidate.reshape(ret_shape)

    # Method did not yield a pure state. Fall back to `default` argument.
    if default is not RaiseValueErrorIfNotProvided:
        return default

    raise EntangledStateError(
        "Input state vector could not be factored into pure state over "
        "indices {}".format(keep_indices)
    )


def to_special(u: np.ndarray) -> np.ndarray:
    """Converts a unitary matrix to a special unitary matrix.

    All unitary matrices u have |det(u)| = 1.
    Also for all d dimensional unitary matrix u, and scalar s:
        det(u * s) = det(u) * s^(d)
    To find a special unitary matrix from u:
        u * det(u)^{-1/d}

    Args:
        u: the unitary matrix
    Returns:
        the special unitary matrix
    """
    return u * (np.linalg.det(u) ** (-1 / len(u)))


def state_vector_kronecker_product(
    t1: np.ndarray,
    t2: np.ndarray,
) -> np.ndarray:
    """Merges two state vectors into a single unified state vector.

    The resulting vector's shape will be `t1.shape + t2.shape`.

    Args:
        t1: The first state vector.
        t2: The second state vector.
    Returns:
        A new state vector representing the unified state.
    """
    return np.outer(t1, t2).reshape(t1.shape + t2.shape)


def density_matrix_kronecker_product(
    t1: np.ndarray,
    t2: np.ndarray,
) -> np.ndarray:
    """Merges two density matrices into a single unified density matrix.

    The resulting matrix's shape will be `(t1.shape/2 + t2.shape/2) * 2`. In
    other words, if t1 has shape [A,B,C,A,B,C] and t2 has shape [X,Y,Z,X,Y,Z],
    the resulting matrix will have shape [A,B,C,X,Y,Z,A,B,C,X,Y,Z].

    Args:
        t1: The first density matrix.
        t2: The second density matrix.
    Returns:
        A density matrix representing the unified state.
    """
    # Outer product gives axes [A,B,C,A,B,C,X,Y,Z,X,Y,Z].
    t = state_vector_kronecker_product(t1, t2)
    t1_len = len(t1.shape)
    t1_dim = t1_len // 2
    t2_len = len(t2.shape)
    t2_dim = t2_len // 2
    shape = t1.shape[:t1_dim] + t2.shape[:t2_dim]
    # Move t2's left axes [X,Y,Z] to sit right after t1's left axes.
    return np.moveaxis(t, range(t1_len, t1_len + t2_dim), range(t1_dim, t1_dim + t2_dim)).reshape(
        shape * 2
    )


def factor_state_vector(
    t: np.ndarray,
    axes: Sequence[int],
    *,
    validate=True,
    atol=1e-07,
) -> Tuple[np.ndarray, np.ndarray]:
    """Factors a state vector into two independent state vectors.

    This function should only be called on state vectors that are known to be
    separable, such as immediately after a measurement or reset operation.
    When `validate` is False it does not verify that the provided state vector
    is indeed separable, and will return nonsense results for vectors
    representing entangled states.

    Args:
        t: The state vector to factor.
        axes: The axes to factor out.
        validate: Perform a validation that the state vector factors cleanly.
        atol: The absolute tolerance for the validation.

    Returns:
        A tuple with the `(extracted, remainder)` state vectors, where
        `extracted` means the sub-state vector which corresponds to the axes
        requested, and with the axes in the requested order, and where
        `remainder` means the sub-state vector on the remaining axes, in the
        same order as the original state vector.

    Raises:
        ValueError: If `validate` is True and the product of the two factors
            does not reproduce the input (up to global phase) within `atol`.
    """
    n_axes = len(axes)
    # Requested axes go to the front (in the requested order); the remaining
    # axes keep their original relative order.
    t1 = np.moveaxis(t, axes, range(n_axes))
    # Slice through the largest-magnitude entry to maximize precision.
    pivot = np.unravel_index(np.abs(t1).argmax(), t1.shape)
    slices1 = (slice(None),) * n_axes + pivot[n_axes:]
    slices2 = pivot[:n_axes] + (slice(None),) * (t1.ndim - n_axes)
    extracted = t1[slices1]
    extracted = extracted / np.sum(abs(extracted) ** 2) ** 0.5
    remainder = t1[slices2]
    remainder = remainder / np.sum(abs(remainder) ** 2) ** 0.5
    if validate:
        # The product is laid out in `t1`'s axis order, so compare it with `t1`
        # directly. It equals `t1` only up to the phase of the pivot entry, so
        # the comparison must ignore global phase.
        t2 = state_vector_kronecker_product(extracted, remainder)
        phased_t2, phased_t1 = match_global_phase(t2, t1)
        if not np.allclose(phased_t2, phased_t1, atol=atol):
            raise ValueError('The tensor cannot be factored by the requested axes')
    return extracted, remainder


def factor_density_matrix(
    t: np.ndarray,
    axes: Sequence[int],
    *,
    validate=True,
    atol=1e-07,
) -> Tuple[np.ndarray, np.ndarray]:
    """Factors a density matrix into two independent density matrices.

    This function should only be called on density matrices that are known to
    be separable, such as immediately after a measurement or reset operation.
    When `validate` is False it does not verify that the provided density
    matrix is indeed separable, and will return nonsense results for matrices
    representing entangled states.

    Args:
        t: The density matrix to factor.
        axes: The axes to factor out. Only the left axes should be provided.
            For example, to extract [C,A] from density matrix of shape
            [A,B,C,D,A,B,C,D], `axes` should be [2,0], and the return value
            will be two density matrices ([C,A,C,A], [B,D,B,D]).
        validate: Perform a validation that the density matrix factors cleanly.
        atol: The absolute tolerance for the validation.

    Returns:
        A tuple with the `(extracted, remainder)` density matrices, where
        `extracted` means the sub-matrix which corresponds to the axes
        requested, and with the axes in the requested order, and where
        `remainder` means the sub-matrix on the remaining axes, in the same
        order as the original density matrix.

    Raises:
        ValueError: If `validate` is True and the product of the two factors
            does not reproduce the input within `atol`.
    """
    extracted = partial_trace(t, axes)
    remaining_axes = [i for i in range(t.ndim // 2) if i not in axes]
    remainder = partial_trace(t, remaining_axes)
    if validate:
        t1 = density_matrix_kronecker_product(extracted, remainder)
        product_axes = list(axes) + remaining_axes
        # `t1` is laid out in `product_axes` order; bring the input into that
        # same order (result axis k = input axis product_axes[k]) to compare.
        t2 = transpose_density_matrix_to_axis_order(t, product_axes)
        if not np.allclose(t1, t2, atol=atol):
            raise ValueError('The tensor cannot be factored by the requested axes')
    return extracted, remainder


def transpose_state_vector_to_axis_order(t: np.ndarray, axes: Sequence[int]):
    """Transposes the axes of a state vector to a specified order.

    Axis `k` of the result is axis `axes[k]` of the input.

    Args:
        t: The state vector to transpose.
        axes: The desired axis order.
    Returns:
        The transposed state vector.
    """
    assert set(axes) == set(range(int(t.ndim))), "All axes must be provided."
    return np.moveaxis(t, axes, range(len(axes)))


def transpose_density_matrix_to_axis_order(t: np.ndarray, axes: Sequence[int]):
    """Transposes the axes of a density matrix to a specified order.

    Args:
        t: The density matrix to transpose.
        axes: The desired axis order. Only the left axes should be provided.
            For example, to transpose [A,B,C,A,B,C] to [C,B,A,C,B,A], `axes`
            should be [2,1,0].
    Returns:
        The transposed density matrix.
    """
    # Apply the same permutation to the right (column) half of the axes.
    axes = list(axes) + [i + len(axes) for i in axes]
    return transpose_state_vector_to_axis_order(t, axes)