1# Copyright 2018 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Utility methods for transforming matrices or vectors."""
16
17from typing import Tuple, Optional, Sequence, List, Union
18
19import numpy as np
20
21from cirq import protocols
22from cirq.linalg import predicates
23
# Sentinel used as the default of the `default` parameter of
# `sub_state_vector`, letting that function tell whether the caller supplied a
# fallback value. It is an `np.ndarray` so that the parameter keeps a precise
# `np.ndarray` type annotation. Comparisons against it use `is`, so an
# unrelated empty array passed by a user is never mistaken for the sentinel.
RaiseValueErrorIfNotProvided: np.ndarray = np.empty(0)
30
31
def reflection_matrix_pow(reflection_matrix: np.ndarray, exponent: float) -> np.ndarray:
    """Raises a matrix with two opposing eigenvalues to a power.

    The matrix is assumed to have only the eigenvalues +x and -x for some
    complex unit x. The matrix is split into its +x and -x eigencomponents,
    and each component is raised to the power separately.

    Args:
        reflection_matrix: The matrix to raise to a power.
        exponent: The power to raise the matrix to.

    Returns:
        The given matrix raised to the given power.
    """

    # The eigenvalues are x and -x for some complex unit x, so the matrix
    # squares to x**2 times the identity and entry (0, 0) of its square is
    # x**2. Take the square root in the complex domain: for a real matrix with
    # eigenvalues +/-i the squared phase is a negative real number, and a real
    # square root of it would be nan.
    squared_phase = np.dot(reflection_matrix[:, 0], reflection_matrix[0, :])
    phase = complex(np.sqrt(complex(squared_phase)))

    # Extract +x and -x eigencomponents of the matrix (each scaled by x).
    i = np.eye(reflection_matrix.shape[0]) * phase
    pos_part = (i + reflection_matrix) * 0.5
    neg_part = (i - reflection_matrix) * 0.5

    # Raise the matrix to a power by raising its eigencomponents to that power.
    # The components already carry one factor of x, hence `exponent - 1`.
    pos_factor = phase ** (exponent - 1)
    neg_factor = pos_factor * complex(-1) ** exponent
    pos_part_raised = pos_factor * pos_part
    neg_part_raised = neg_part * neg_factor
    return pos_part_raised + neg_part_raised
58
59
def match_global_phase(a: np.ndarray, b: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Phases the given matrices so that they agree on the phase of one entry.

    To maximize precision, the position of the largest-magnitude entry of `b`
    is used when computing the phase difference between the two matrices.

    Args:
        a: A numpy array.
        b: Another numpy array.

    Returns:
        A tuple (a', b') where a' == b' implies a == b*exp(i t) for some t.
        If the shapes differ or the arrays are empty, copies of the inputs are
        returned unchanged.
    """

    # Not much point when they have different shapes.
    if a.shape != b.shape or a.size == 0:
        return np.copy(a), np.copy(b)

    # Flat (C-order) position of the largest-magnitude entry of `b`.
    # np.argmax returns the first maximum in C order, the same entry a Python
    # scan over np.ndindex would pick, but without a per-entry Python call.
    pivot = int(np.argmax(np.abs(b)))

    def dephase(v):
        r = np.real(v)
        i = np.imag(v)

        # Avoid introducing floating point error when axis-aligned.
        if i == 0:
            return -1 if r < 0 else 1
        if r == 0:
            return 1j if i < 0 else -1j

        return np.exp(-1j * np.arctan2(i, r))

    # Zero the phase at this entry in both matrices. `.flat` indexes in C
    # order for any shape (including 0-d arrays).
    return a * dephase(a.flat[pivot]), b * dephase(b.flat[pivot])
96
97
98# TODO(#3388) Add documentation for Raises.
99# pylint: disable=missing-raises-doc
def targeted_left_multiply(
    left_matrix: np.ndarray,
    right_target: np.ndarray,
    target_axes: Sequence[int],
    out: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Left-multiplies the given axes of the target tensor by the given matrix.

    The matrix must be given as a tensor compatible with the targeted axes: for
    k target axes, `left_matrix` has 2k axes. Its last k axes are contracted
    against the axes of `right_target` listed in `target_axes` (in that order),
    and its first k axes take their place in the output.

    For example, given a 6-qubit state vector `input_state` of shape
    (2, 2, 2, 2, 2, 2) and a 2-qubit unitary `op` of shape (2, 2, 2, 2),
    applying `op` to the qubits at offsets 5 and 3 (in that order) is done by:

        output_state = cirq.targeted_left_multiply(op, input_state, [5, 3])

    The right-hand side may also be a matrix rather than a vector. If a unitary
    circuit's matrix is `old_effect` and a CNOT(q1, q4) is appended to the
    circuit, with control q1 at offset 1 and target q4 at offset 4, then the
    new circuit's unitary matrix is:

        new_effect = cirq.targeted_left_multiply(
            left_matrix=cirq.unitary(cirq.CNOT).reshape((2, 2, 2, 2)),
            right_target=old_effect,
            target_axes=[1, 4])

    Args:
        left_matrix: What to left-multiply the target tensor by.
        right_target: A tensor to carefully broadcast a left-multiply over.
        target_axes: Which axes of the target are being operated on.
        out: The buffer to store the results in. If not specified or None, a new
            buffer is used. Must have the same shape as right_target.

    Returns:
        The output tensor (written into `out` when a buffer is given).

    Raises:
        ValueError: If `out` is the same object as `right_target` or
            `left_matrix`.
    """
    if out is right_target or out is left_matrix:
        raise ValueError('out is right_target or out is left_matrix')

    num_targets = len(target_axes)
    num_dims = len(right_target.shape)

    # Einsum labels 0..k-1 name the matrix's output axes; labels k..k+d-1 name
    # the axes of the target tensor.
    work_labels = tuple(range(num_targets))
    data_labels = tuple(range(num_targets, num_targets + num_dims))
    contracted_labels = tuple(data_labels[axis] for axis in target_axes)
    matrix_labels = work_labels + contracted_labels

    # The output keeps the target's layout, except that each targeted axis is
    # replaced by the corresponding output axis of the matrix.
    result_labels = list(data_labels)
    for work_label, axis in zip(work_labels, target_axes):
        result_labels[axis] = work_label

    distinct_labels = set(matrix_labels) | set(data_labels) | set(result_labels)

    # Only pass `out` when it was given; ideally this would simply be
    # `out=out`, but passing `out=None` explicitly hits another numpy bug.
    extra_kwargs = {} if out is None else {'out': out}

    return np.einsum(
        left_matrix,
        matrix_labels,
        right_target,
        data_labels,
        result_labels,
        # Omitting 'optimize=' would be faster, but enabling it for many labels
        # works around a numpy bug:
        #     https://github.com/numpy/numpy/issues/10926
        optimize=len(distinct_labels) >= 26,
        **extra_kwargs,
    )
167
168
169# pylint: enable=missing-raises-doc
def targeted_conjugate_about(
    tensor: np.ndarray,
    target: np.ndarray,
    indices: Sequence[int],
    conj_indices: Optional[Sequence[int]] = None,
    buffer: Optional[np.ndarray] = None,
    out: Optional[np.ndarray] = None,
) -> np.ndarray:
    r"""Conjugates the given tensor about the target tensor.

    This method computes a target tensor conjugated by another tensor.
    Here conjugate is used in the sense of conjugating by a matrix, i.e.
    A conjugated about B is $A B A^\dagger$ where $\dagger$ represents the
    conjugate transpose.

    Abstractly this computes $A \cdot B \cdot A^\dagger$ where A and B are
    multi-dimensional arrays, and instead of matrix multiplication $\cdot$
    is a contraction between the given indices (`indices` for the first
    $\cdot$, `conj_indices` for the second $\cdot$).

    Concretely, with k = ``len(indices)`` and `tensor` having 2k axes: first
    the last k axes of `tensor` are contracted with the axes `indices` of
    `target` (the first k axes of `tensor` take their place). Then the last k
    axes of ``np.conjugate(tensor)`` are contracted in the same way with the
    axes `conj_indices` of that intermediate result.

    Args:
        tensor: The tensor that will be conjugated about the target tensor.
        target: The tensor that will receive the conjugation.
        indices: The indices which will be contracted between the tensor and
            target.
        conj_indices: The indices which will be contracted between the
            complex conjugate of the tensor and the target. If this is None,
            then these will be the values in `indices` plus half the number
            of dimensions of the target (``target.ndim // 2``). This is the
            common case where the target is an operator (e.g. a density
            matrix) whose first half of axes is acted on by `tensor` and whose
            second half is acted on by its conjugate.
        buffer: A buffer to store partial results in.  If not specified or None,
            a new buffer is used.
        out: The buffer to store the results in. If not specified or None, a new
            buffer is used. Must have the same shape as target.

    Returns:
        The result of the conjugation.
    """
    # Use an explicit None check: `conj_indices or ...` raises for numpy-array
    # arguments ("truth value of an array ... is ambiguous").
    if conj_indices is None:
        conj_indices = [i + target.ndim // 2 for i in indices]
    first_multiply = targeted_left_multiply(tensor, target, indices, out=buffer)
    return targeted_left_multiply(np.conjugate(tensor), first_multiply, conj_indices, out=out)
220
221
# A single indexing element accepted by `apply_matrix_to_slices`: an integer
# index, a `slice`, or `Ellipsis` (`...`). The builtin type of `Ellipsis` is
# named 'ellipsis' but is not available under that name, hence the string
# forward reference.
_TSliceAtom = Union[int, slice, 'ellipsis']
# A full slice: either a single atom or a sequence of atoms (one per indexed
# axis).
_TSlice = Union[_TSliceAtom, Sequence[_TSliceAtom]]
224
225
226# TODO(#3388) Add documentation for Raises.
227# pylint: disable=missing-raises-doc
def apply_matrix_to_slices(
    target: np.ndarray,
    matrix: np.ndarray,
    slices: Sequence[_TSlice],
    *,
    out: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Left-multiplies an NxN matrix onto N slices of a numpy array.

    Each slice is treated as one "vector entry": output slice i becomes the
    sum over j of ``matrix[i, j] * target[slices[j]]``. Everything outside the
    slices is copied through unchanged.

    Example:
        The 4x4 matrix of a fractional SWAP gate can be expressed as

           [ 1       ]
           [   X**t  ]
           [       1 ]

        Where X is the 2x2 Pauli X gate and t is the power of the swap with t=1
        being a full swap. X**t is a power of the Pauli X gate's matrix.
        Applying the fractional swap is equivalent to applying a fractional X
        within the inner 2x2 subspace; the rest of the matrix is identity. This
        can be expressed using `apply_matrix_to_slices` as follows:

            def fractional_swap(target):
                assert target.shape == (4,)
                return apply_matrix_to_slices(
                    target=target,
                    matrix=cirq.unitary(cirq.X**t),
                    slices=[1, 2]
                )

    Args:
        target: The input array with slices that need to be left-multiplied.
        matrix: The linear operation to apply to the subspace defined by the
            slices.
        slices: The parts of the tensor that correspond to the "vector entries"
            that the matrix should operate on. May be integers or complicated
            multi-dimensional slices into a tensor. The slices must refer to
            non-overlapping sections of the input all with the same shape.
        out: Where to write the output. If not specified, a new numpy array is
            created, with the same shape and dtype as the target, to store the
            output.

    Returns:
        The transformed array.

    Raises:
        ValueError: If `out` is `target`, or if `matrix` is not square with
            one row per slice.
    """
    if out is target:
        raise ValueError("Can't write output over the input.")
    count = len(slices)
    if matrix.shape != (count, count):
        raise ValueError("matrix.shape != (len(slices), len(slices))")

    # Start from a copy of the input so untouched regions pass through.
    if out is not None:
        out[...] = target[...]
    else:
        out = np.copy(target)

    # Reads always come from `target`, so earlier writes into `out` never
    # leak into later rows.
    for row, row_slice in enumerate(slices):
        out[row_slice] *= matrix[row, row]
        for col, col_slice in enumerate(slices):
            if col == row:
                continue
            out[row_slice] += target[col_slice] * matrix[row, col]

    return out
293
294
295# pylint: enable=missing-raises-doc
def partial_trace(tensor: np.ndarray, keep_indices: Sequence[int]) -> np.ndarray:
    """Takes the partial trace of a given tensor.

    The input tensor must have shape `(d_0, ..., d_{k-1}, d_0, ..., d_{k-1})`.
    The trace is done over all indices that are not in keep_indices. The
    resulting tensor has shape `(d_{i_0}, ..., d_{i_r}, d_{i_0}, ..., d_{i_r})`
    where `i_j` is the `j`th element of `keep_indices`.

    Args:
        tensor: The tensor to sum over. This tensor must have a shape
            `(d_0, ..., d_{k-1}, d_0, ..., d_{k-1})`.
        keep_indices: Which indices to not sum over. These are only the indices
            of the first half of the tensors indices (i.e. all elements must
            be between `0` and `tensor.ndim / 2 - 1` inclusive).

    Returns:
        The partially traced tensor. Its kept axes appear in the order given by
        `keep_indices`: first the row axes, then the matching column axes.

    Raises:
        ValueError: if the tensor is not of the correct shape (including having
            an odd number of axes) or the indices are not from the first half
            of valid indices for the tensor (negative indices are rejected).
    """
    ndim = tensor.ndim // 2
    if tensor.ndim % 2 != 0 or not all(
        tensor.shape[i] == tensor.shape[i + ndim] for i in range(ndim)
    ):
        raise ValueError(
            'Tensors must have shape (d_0,...,d_{{k-1}},d_0,...,'
            'd_{{k-1}}) but had shape ({}).'.format(tensor.shape)
        )
    if any(i < 0 for i in keep_indices):
        raise ValueError(f'keep_indices were {keep_indices} but must be non-negative.')
    if not all(i < ndim for i in keep_indices):
        raise ValueError(
            'keep_indices were {} but must be in first half, '
            'i.e. have index less that {}.'.format(keep_indices, ndim)
        )
    keep_set = set(keep_indices)
    # einsum's implicit output lists the labels that appear exactly once, in
    # ascending order. Relabel kept axis keep_indices[j] with the j-th smallest
    # kept index so the output axes follow the order of `keep_indices`.
    keep_map = dict(zip(keep_indices, sorted(keep_indices)))
    left_indices = [keep_map[i] if i in keep_set else i for i in range(ndim)]
    # Kept column axes get their row label offset by `ndim`, which keeps the
    # labels distinct and sorts them after every row label. Traced column axes
    # reuse the row label, so einsum sums over that diagonal.
    right_indices = [ndim + i if i in keep_set else i for i in left_indices]
    return np.einsum(tensor, left_indices + right_indices)
331
332
class EntangledStateError(ValueError):
    """Signals that a state expected to be a product state is entangled.

    Because it derives from `ValueError`, code that catches `ValueError` will
    also catch this error.
    """
335
336
def partial_trace_of_state_vector_as_mixture(
    state_vector: np.ndarray, keep_indices: List[int], *, atol: Union[int, float] = 1e-8
) -> Tuple[Tuple[float, np.ndarray], ...]:
    """Returns a mixture representing a state vector with only some qubits kept.

    The input state vector must have shape `(2,) * n` or `(2 ** n,)` where
    `state_vector` is expressed over n qubits. States in the output mixture
    keep the same kind of shape as the input: either `(2 ** k,)` or `(2,) * k`,
    where k is the number of qubits kept.

    When the kept qubits factor out as a pure state, that state is returned as
    a single component with probability 1. Otherwise the reduced density
    matrix is eigendecomposed, so the resulting mixture is not unique.
    Components whose eigenvalue is approximately zero are dropped.

    Args:
        state_vector: The state vector to take the partial trace over.
        keep_indices: Which indices to take the partial trace of the
            state_vector on.
        atol: The tolerance for determining that a factored state is pure.

    Returns:
        A tuple of `(probability, state)` pairs.

    Raises:
        ValueError: if the input `state_vector` is not an array of length
        `(2 ** n)` or a tensor with a shape of `(2,) * n`
    """

    # Fast path: try to factor out a pure state over the kept qubits.
    try:
        pure_part = sub_state_vector(
            state_vector, keep_indices, default=RaiseValueErrorIfNotProvided, atol=atol
        )
    except EntangledStateError:
        pass
    else:
        return ((1.0, pure_part),)

    # Slow path: eigendecompose the reduced density matrix.
    n_qubits = int(np.log2(state_vector.size))
    keep_dims = 1 << len(keep_indices)
    # Match the caller's shape convention: flat (2**k,) or tensor (2,) * k.
    # (Invalid shapes were already rejected by `sub_state_vector` above.)
    ret_shape: Union[Tuple[int], Tuple[int, ...]] = (
        (keep_dims,)
        if state_vector.shape == (state_vector.size,)
        else (2,) * len(keep_indices)
    )

    column = state_vector.reshape(-1, 1)
    rho = np.kron(np.conj(column).T, column).reshape((2, 2) * n_qubits)
    reduced_rho = partial_trace(rho, keep_indices).reshape((keep_dims, keep_dims))
    eigenvalues, eigenvectors = np.linalg.eigh(reduced_rho)
    # Columns of `eigenvectors` are the eigenvectors.
    return tuple(
        (float(weight), vector.reshape(ret_shape))
        for weight, vector in zip(eigenvalues, eigenvectors.T)
        if not protocols.approx_eq(weight, 0.0)
    )
391
392
def sub_state_vector(
    state_vector: np.ndarray,
    keep_indices: List[int],
    *,
    default: np.ndarray = RaiseValueErrorIfNotProvided,
    atol: Union[int, float] = 1e-8,
) -> np.ndarray:
    r"""Attempts to factor a state vector into two parts and return one of them.

    The input `state_vector` must have shape ``(2,) * n`` or ``(2 ** n,)`` where
    `state_vector` is expressed over n qubits. The returned array will retain
    the same type of shape as the input state vector, either ``(2 ** k,)`` or
    ``(2,) * k`` where k is the number of qubits kept.

    If a state vector $|\psi\rangle$ defined on n qubits is a tensor product
    of kets like $|\psi\rangle = |x\rangle \otimes |y\rangle$, and
    $|x\rangle$ is defined over the subset ``keep_indices`` of k qubits, then
    this method will factor $|\psi\rangle$ into $|x\rangle$ and $|y\rangle$ and
    return $|x\rangle$. Note that $|x\rangle$ is not unique, because scalar
    multiplication may be absorbed by any factor of a tensor product,
    $e^{i \theta} |x\rangle \otimes |y\rangle =
    |x\rangle \otimes e^{i \theta} |y\rangle$.

    This method randomizes the global phase of $|x\rangle$ in order to avoid
    accidental reliance on the global phase being some specific value.

    If the provided `state_vector` cannot be factored into a pure state over
    `keep_indices`, the method will fall back to return `default`. If `default`
    is not provided, the method raises `EntangledStateError` (a subclass of
    `ValueError`).

    Args:
        state_vector: The target state_vector.
        keep_indices: Which indices to attempt to get the separable part of the
            `state_vector` on.
        default: Determines the fallback behavior when `state_vector` doesn't
            have a pure state factorization. If the factored state is not pure
            and `default` is not set, an EntangledStateError is raised. If
            default is set to a value, that value is returned.
        atol: The minimum tolerance for comparing the output state's coherence
            measure to 1.

    Returns:
        The state vector expressed over the desired subset of qubits.

    Raises:
        ValueError: if the `state_vector` is not of the correct shape or not
            normalized, or the indices are not a valid (unique, in-range)
            subset of the input `state_vector`'s indices.
        EntangledStateError: If the result of factoring is not a pure state and
            `default` is not provided.
    """

    if not np.log2(state_vector.size).is_integer():
        raise ValueError(
            "Input state_vector of size {} does not represent a "
            "state over qubits.".format(state_vector.size)
        )

    n_qubits = int(np.log2(state_vector.size))
    keep_dims = 1 << len(keep_indices)
    ret_shape: Union[Tuple[int], Tuple[int, ...]]
    if state_vector.shape == (state_vector.size,):
        ret_shape = (keep_dims,)
        state_vector = state_vector.reshape((2,) * n_qubits)
    elif state_vector.shape == (2,) * n_qubits:
        ret_shape = tuple(2 for _ in range(len(keep_indices)))
    else:
        raise ValueError("Input state_vector must be shaped like (2 ** n,) or (2,) * n")

    if not np.isclose(np.linalg.norm(state_vector), 1):
        raise ValueError("Input state must be normalized.")
    if len(set(keep_indices)) != len(keep_indices):
        raise ValueError(f"keep_indices were {keep_indices} but must be unique.")
    if any(ind < 0 or ind >= n_qubits for ind in keep_indices):
        raise ValueError(
            f"keep_indices {keep_indices} are an invalid subset of the input state vector."
        )

    # One candidate per assignment of the other qubits: the (unnormalized)
    # amplitudes over the kept qubits given that assignment.
    other_qubits = sorted(set(range(n_qubits)) - set(keep_indices))
    candidates = [
        state_vector[predicates.slice_for_qubits_equal_to(other_qubits, k)].reshape(keep_dims)
        for k in range(1 << len(other_qubits))
    ]
    # The coherence measure is computed using unnormalized candidates. It is the
    # sum of |<best|c>|^2 over all candidates c. For a normalized input, by
    # Cauchy-Schwarz, it equals 1 exactly when every candidate is parallel to
    # the best one, i.e. when the state factors over `keep_indices`.
    best_candidate = max(candidates, key=lambda c: np.linalg.norm(c, 2))
    best_candidate = best_candidate / np.linalg.norm(best_candidate)
    left = np.conj(best_candidate.reshape((keep_dims,))).T
    coherence_measure = sum(abs(np.dot(left, c.reshape((keep_dims,)))) ** 2 for c in candidates)

    if protocols.approx_eq(coherence_measure, 1, atol=atol):
        return np.exp(2j * np.pi * np.random.random()) * best_candidate.reshape(ret_shape)

    # Method did not yield a pure state. Fall back to `default` argument.
    if default is not RaiseValueErrorIfNotProvided:
        return default

    raise EntangledStateError(
        "Input state vector could not be factored into pure state over "
        "indices {}".format(keep_indices)
    )
492
493
def to_special(u: np.ndarray) -> np.ndarray:
    """Converts a unitary matrix to a special unitary matrix.

    All unitary matrices u have |det(u)| = 1.
    Also for all d dimensional unitary matrix u, and scalar s:
        det(u * s) = det(u) * s^(d)
    To find a special unitary matrix from u:
        u * det(u)^{-1/d}

    Args:
        u: the unitary matrix
    Returns:
        the special unitary matrix. A real input keeps a real dtype unless its
        determinant is negative, in which case the result is complex.
    """
    det = np.linalg.det(u)
    if np.isrealobj(det) and det < 0:
        # A negative real determinant has no real d-th root: raising the real
        # numpy scalar to a fractional power would yield nan. Switch to the
        # complex domain so the principal complex root is used instead.
        det = complex(det)
    return u * (det ** (-1 / len(u)))
509
510
def state_vector_kronecker_product(
    t1: np.ndarray,
    t2: np.ndarray,
) -> np.ndarray:
    """Combines two state vectors into one joint state vector.

    Every entry of the result is the product of one entry of `t1` with one
    entry of `t2`. The axes of `t1` come first, followed by the axes of `t2`,
    so the result has shape `t1.shape + t2.shape`.

    Args:
        t1: The first state vector.
        t2: The second state vector.
    Returns:
        A new state vector representing the unified state.
    """
    joint_shape = (*t1.shape, *t2.shape)
    flat_product = np.outer(t1, t2)
    return flat_product.reshape(joint_shape)
526
527
def density_matrix_kronecker_product(
    t1: np.ndarray,
    t2: np.ndarray,
) -> np.ndarray:
    """Combines two density matrices into one joint density matrix.

    Each input stores its row axes in its first half and its column axes in
    its second half. The result keeps that layout: all row axes of `t1`, then
    all row axes of `t2`, then the column axes in the same order. For example,
    inputs of shapes [A,B,C,A,B,C] and [X,Y,Z,X,Y,Z] produce a result of shape
    [A,B,C,X,Y,Z,A,B,C,X,Y,Z].

    Args:
        t1: The first density matrix.
        t2: The second density matrix.
    Returns:
        A density matrix representing the unified state.
    """
    ndim1 = len(t1.shape)
    half1 = ndim1 // 2
    half2 = len(t2.shape) // 2

    # Elementwise outer product: all axes of t1, followed by all axes of t2.
    joint = np.outer(t1, t2).reshape(t1.shape + t2.shape)

    # t2's row axes sit right after every axis of t1; slide them forward so
    # they follow t1's row axes directly.
    reordered = np.moveaxis(
        joint, range(ndim1, ndim1 + half2), range(half1, half1 + half2)
    )
    row_shape = t1.shape[:half1] + t2.shape[:half2]
    return reordered.reshape(row_shape * 2)
553
554
555# TODO(#3388) Add documentation for Raises.
556# pylint: disable=missing-raises-doc
def factor_state_vector(
    t: np.ndarray,
    axes: Sequence[int],
    *,
    validate=True,
    atol=1e-07,
) -> Tuple[np.ndarray, np.ndarray]:
    """Factors a state vector into two independent state vectors.

    This function should only be called on state vectors that are known to be
    separable, such as immediately after a measurement or reset operation.
    When `validate` is False it does not verify that the provided state vector
    is indeed separable, and will return nonsense results for vectors
    representing entangled states.

    The factors are chosen so that their Kronecker product (with the axes moved
    back to their original positions) reproduces `t` exactly, including its
    global phase, whenever `t` is a normalized product state.

    Args:
        t: The state vector to factor.
        axes: The axes to factor out.
        validate: Perform a validation that the state vector factors cleanly.
        atol: The absolute tolerance for the validation.

    Returns:
        A tuple with the `(extracted, remainder)` state vectors, where
        `extracted` means the sub-state vector which corresponds to the axes
        requested, and with the axes in the requested order, and where
        `remainder` means the sub-state vector on the remaining axes, in the
        same order as the original state vector.

    Raises:
        ValueError: If `validate` is True and the state vector cannot be
            factored by the requested axes.
    """
    n_axes = len(axes)
    # Move the requested axes to the front, in the requested order.
    t1 = np.moveaxis(t, axes, range(n_axes))
    # Pivot on the largest-magnitude entry for numerical stability.
    pivot = np.unravel_index(np.abs(t1).argmax(), t1.shape)
    slices1 = (slice(None),) * n_axes + pivot[n_axes:]
    slices2 = pivot[:n_axes] + (slice(None),) * (t1.ndim - n_axes)
    # For t1 = x (x) y: t1[slices1] = y_p * x and t1[slices2] = x_p * y.
    extracted = t1[slices1]
    extracted = extracted / np.sum(abs(extracted) ** 2) ** 0.5
    remainder = t1[slices2]
    # Normalizing both slices leaves their product off by the phase of the
    # pivot entry x_p * y_p; divide that phase out of the remainder.
    pivot_phase = t1[pivot] / abs(t1[pivot])
    remainder = remainder / (np.sum(abs(remainder) ** 2) ** 0.5 * pivot_phase)
    if validate:
        product = np.multiply.outer(extracted, remainder)
        # Invert the initial moveaxis so the product's axes line up with `t`.
        restored = np.moveaxis(product, range(n_axes), axes)
        if not np.allclose(restored, t, atol=atol):
            raise ValueError('The tensor cannot be factored by the requested axes')
    return extracted, remainder
600
601
602# TODO(#3388) Add documentation for Raises.
def factor_density_matrix(
    t: np.ndarray,
    axes: Sequence[int],
    *,
    validate=True,
    atol=1e-07,
) -> Tuple[np.ndarray, np.ndarray]:
    """Factors a density matrix into two independent density matrices.

    This function should only be called on density matrices that are known to
    be separable, such as immediately after a measurement or reset operation.
    When `validate` is False it does not verify that the provided density
    matrix is indeed separable, and will return nonsense results for matrices
    representing entangled states.

    Args:
        t: The density matrix to factor.
        axes: The axes to factor out. Only the left axes should be provided.
            For example, to extract [C,A] from density matrix of shape
            [A,B,C,D,A,B,C,D], `axes` should be [2,0], and the return value
            will be two density matrices ([C,A,C,A], [B,D,B,D]).
        validate: Perform a validation that the density matrix factors cleanly.
        atol: The absolute tolerance for the validation.

    Returns:
        A tuple with the `(extracted, remainder)` density matrices, where
        `extracted` means the sub-matrix which corresponds to the axes
        requested, and with the axes in the requested order, and where
        `remainder` means the sub-matrix on the remaining axes, in the same
        order as the original density matrix.

    Raises:
        ValueError: If `validate` is True and the density matrix cannot be
            factored by the requested axes.
    """
    extracted = partial_trace(t, axes)
    remaining_axes = [i for i in range(t.ndim // 2) if i not in axes]
    remainder = partial_trace(t, remaining_axes)
    if validate:
        t1 = density_matrix_kronecker_product(extracted, remainder)
        product_axes = list(axes) + remaining_axes
        # Left axis k of t1 holds original axis product_axes[k]. Transposing by
        # the inverse permutation puts every axis back in its original position.
        # product_axes itself only works when it happens to be its own inverse.
        inverse_axes = np.argsort(product_axes).tolist()
        t2 = transpose_density_matrix_to_axis_order(t1, inverse_axes)
        if not np.allclose(t2, t, atol=atol):
            raise ValueError('The tensor cannot be factored by the requested axes')
    return extracted, remainder
644
645
646# pylint: enable=missing-raises-doc
def transpose_state_vector_to_axis_order(t: np.ndarray, axes: Sequence[int]):
    """Reorders the axes of a state vector.

    Axis k of the result is axis `axes[k]` of the input, so `axes` must be a
    permutation of every axis of `t`.

    Args:
        t: The state vector to transpose.
        axes: The desired axis order.
    Returns:
        The transposed state vector.
    """
    assert set(axes) == set(range(int(t.ndim))), "All axes must be provided."
    # For a full permutation this is exactly np.moveaxis(t, axes, range(n)).
    return np.transpose(t, axes)
658
659
def transpose_density_matrix_to_axis_order(t: np.ndarray, axes: Sequence[int]):
    """Reorders the axes of a density matrix.

    Only the row (left-half) axis order is given; the column axes are reordered
    the same way. For example, turning [A,B,C,A,B,C] into [C,B,A,C,B,A] uses
    `axes` = [2,1,0].

    Args:
        t: The density matrix to transpose.
        axes: The desired axis order. Only the left axes should be provided.
    Returns:
        The transposed density matrix.
    """
    row_order = list(axes)
    # Column axis i lives at position i + (number of row axes).
    full_order = row_order + [axis + len(row_order) for axis in row_order]
    assert set(full_order) == set(range(int(t.ndim))), "All axes must be provided."
    return np.moveaxis(t, full_order, range(len(full_order)))
673