1# Copyright 2018 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Utility methods for diagonalizing matrices."""
16
17from typing import Tuple, Callable, List
18
19import numpy as np
20
21from cirq.linalg import combinators, predicates, tolerance
22
23
def diagonalize_real_symmetric_matrix(
    matrix: np.ndarray, *, rtol: float = 1e-5, atol: float = 1e-8, check_preconditions: bool = True
) -> np.ndarray:
    """Computes an orthogonal eigenbasis of a real symmetric matrix.

    Args:
        matrix: The real symmetric matrix to diagonalize.
        rtol: Relative tolerance for the symmetry check.
        atol: Absolute tolerance for the symmetry check.
        check_preconditions: When true (the default), reject inputs that have
            a non-zero imaginary part or are not symmetric within tolerance.

    Returns:
        An orthogonal matrix P whose columns are eigenvectors of `matrix`, so
        that P.T @ matrix @ P is diagonal.

    Raises:
        ValueError: The input is not real symmetric.
    """
    if check_preconditions:
        has_imaginary_part = np.any(np.imag(matrix) != 0)
        # The symmetry check only runs when the matrix is real (short-circuit).
        if has_imaginary_part or not predicates.is_hermitian(matrix, rtol=rtol, atol=atol):
            raise ValueError('Input must be real and symmetric.')

    # eigh returns (eigenvalues, eigenvectors); only the eigenvector matrix is needed.
    eigenvectors = np.linalg.eigh(matrix)[1]
    return eigenvectors
51
52
def _contiguous_groups(
    length: int, comparator: Callable[[int, int], bool]
) -> List[Tuple[int, int]]:
    """Partitions range(length) into runs of approximately-equal items.

    Each run is anchored at its first index. A later index joins the run while
    comparator(anchor, index) is true. Membership is therefore judged against
    the run's first element, not against its immediate predecessor.

    Args:
        length: Number of indices to partition.
        comparator: Called as comparator(anchor, index). Returns True when the
            item at `index` is approximately equal to the item at `anchor`.

    Returns:
        A list of (inclusive_start, exclusive_end) pairs covering
        range(length) in order. The list is empty when length <= 0.
    """
    groups: List[Tuple[int, int]] = []
    anchor = 0
    for index in range(1, length):
        if not comparator(anchor, index):
            # `index` starts a new run; close the current one.
            groups.append((anchor, index))
            anchor = index
    if length > 0:
        # The final run always extends to the end of the range.
        groups.append((anchor, length))
    return groups
75
76
def diagonalize_real_symmetric_and_sorted_diagonal_matrices(
    symmetric_matrix: np.ndarray,
    diagonal_matrix: np.ndarray,
    *,
    rtol: float = 1e-5,
    atol: float = 1e-8,
    check_preconditions: bool = True,
) -> np.ndarray:
    """Returns an orthogonal matrix that diagonalizes both given matrices.

    The given matrices must commute.
    Guarantees that the sorted diagonal matrix is not permuted by the
    diagonalization (except for nearly-equal values).

    Args:
        symmetric_matrix: A real symmetric matrix.
        diagonal_matrix: A real diagonal matrix with entries along the diagonal
            sorted into descending order.
        rtol: Relative numeric error threshold. Also used, together with
            `atol`, to decide which diagonal entries are close enough to be
            treated as a single block.
        atol: Absolute numeric error threshold. Also used, together with
            `rtol`, to group nearly-equal diagonal entries.
        check_preconditions: If set, verifies that the input matrices commute
            and are respectively symmetric and diagonal descending.

    Returns:
        An orthogonal matrix P such that P.T @ symmetric_matrix @ P is diagonal
        and P.T @ diagonal_matrix @ P = diagonal_matrix (up to tolerance).

    Raises:
        ValueError: Matrices don't meet preconditions (e.g. not symmetric).
    """

    # Verify preconditions.
    if check_preconditions:
        if np.any(np.imag(symmetric_matrix)) or not predicates.is_hermitian(
            symmetric_matrix, rtol=rtol, atol=atol
        ):
            raise ValueError('symmetric_matrix must be real symmetric.')
        # Only the diagonal entries are compared for ordering. is_diagonal
        # tolerates off-diagonal entries up to atol, and comparing those
        # elementwise could spuriously reject a valid input.
        if (
            not predicates.is_diagonal(diagonal_matrix, atol=atol)
            or np.any(np.imag(diagonal_matrix))
            or np.any(np.diff(np.real(np.diag(diagonal_matrix))) > 0)
        ):
            raise ValueError('diagonal_matrix must be real diagonal descending.')
        if not predicates.matrix_commutes(diagonal_matrix, symmetric_matrix, rtol=rtol, atol=atol):
            raise ValueError('Given matrices must commute.')

    def similar_singular(i, j):
        # Forward both tolerances so the caller's atol governs grouping,
        # instead of silently falling back to NumPy's default atol.
        return bool(
            np.isclose(diagonal_matrix[i, i], diagonal_matrix[j, j], rtol=rtol, atol=atol)
        )

    # Because the symmetric matrix commutes with the diagonal singulars matrix,
    # the symmetric matrix should be block-diagonal with a block boundary
    # wherever the singular values happen to change. So we can use the singular
    # values to extract blocks that can be independently diagonalized.
    ranges = _contiguous_groups(diagonal_matrix.shape[0], similar_singular)

    # Build the overall diagonalization by diagonalizing each block.
    p = np.zeros(symmetric_matrix.shape, dtype=np.float64)
    for start, end in ranges:
        block = symmetric_matrix[start:end, start:end]
        p[start:end, start:end] = diagonalize_real_symmetric_matrix(
            block, rtol=rtol, atol=atol, check_preconditions=False
        )

    return p
141
142
def _svd_handling_empty(mat):
    """Full SVD of a 2-D matrix that also accepts matrices with no entries.

    Args:
        mat: A 2-D array.

    Returns:
        A tuple (u, s, vh) as produced by np.linalg.svd. When `mat` has zero
        rows or zero columns, u and vh are identity matrices of size
        rows x rows and cols x cols (dtype of `mat`), and s is empty.
    """
    rows, cols = mat.shape[0], mat.shape[1]
    if rows == 0 or cols == 0:
        # Short-circuit empty input rather than relying on np.linalg.svd for it.
        # The full SVD of an m x 0 (or 0 x n) matrix has identity factors and
        # no singular values. Sizing the factors from the real dimensions,
        # rather than always 0 x 0, keeps them composable with the non-empty
        # side; for a 0 x 0 input this is still a pair of 0 x 0 matrices.
        return np.eye(rows, dtype=mat.dtype), np.array([]), np.eye(cols, dtype=mat.dtype)

    return np.linalg.svd(mat)
149
150
def bidiagonalize_real_matrix_pair_with_symmetric_products(
    mat1: np.ndarray,
    mat2: np.ndarray,
    *,
    rtol: float = 1e-5,
    atol: float = 1e-8,
    check_preconditions: bool = True,
) -> Tuple[np.ndarray, np.ndarray]:
    """Simultaneously diagonalizes two real matrices with orthogonal factors.

    Requires both matrices to be real, and both mat1.T @ mat2 and
    mat1 @ mat2.T to be symmetric.

    Args:
        mat1: The first real matrix.
        mat2: The second real matrix.
        rtol: Relative numeric error threshold.
        atol: Absolute numeric error threshold. Also decides which singular
            values of mat1 are treated as zero.
        check_preconditions: When true (the default), verify that both inputs
            are real and that mat1.T @ mat2 and mat1 @ mat2.T are symmetric.
            The flag is also forwarded to the inner simultaneous
            diagonalization.

    Returns:
        A tuple (L, R) of two orthogonal matrices such that both
        L @ mat1 @ R and L @ mat2 @ R are diagonal matrices.

    Raises:
        ValueError: Matrices don't meet preconditions (e.g. not real).
    """
    if check_preconditions:
        if np.any(np.imag(mat1) != 0):
            raise ValueError('mat1 must be real.')
        if np.any(np.imag(mat2) != 0):
            raise ValueError('mat2 must be real.')
        outer_product = np.dot(mat1, mat2.T)
        if not predicates.is_hermitian(outer_product, rtol=rtol, atol=atol):
            raise ValueError('mat1 @ mat2.T must be symmetric.')
        inner_product = np.dot(mat1.T, mat2)
        if not predicates.is_hermitian(inner_product, rtol=rtol, atol=atol):
            raise ValueError('mat1.T @ mat2 must be symmetric.')

    real1 = np.real(mat1)
    real2 = np.real(mat2)

    # SVD of the first matrix; this already bi-diagonalizes it.
    u, singular_values, vh = _svd_handling_empty(real1)

    # np.linalg.svd orders singular values descending, so the negligible ones
    # form a suffix. Trim that suffix to get the numerical rank.
    rank = len(singular_values)
    while rank > 0 and tolerance.all_near_zero(singular_values[rank - 1], atol=atol):
        rank -= 1
    leading_diag = np.diag(singular_values[:rank])

    # Express the second matrix in the first matrix's singular bases.
    rotated2 = combinators.dot(u.T, real2, vh.T)

    # Block matched against mat1's non-zero singular values: diagonalize it
    # simultaneously with those singular values so mat1 stays diagonal.
    overlap_fix = diagonalize_real_symmetric_and_sorted_diagonal_matrices(
        rotated2[:rank, :rank],
        leading_diag,
        rtol=rtol,
        atol=atol,
        check_preconditions=check_preconditions,
    )

    # Block matched against mat1's (numerically) zero singular values:
    # fix it up with a plain SVD.
    null_u, _, null_vh = _svd_handling_empty(rotated2[rank:, rank:])

    # Fold both fix-ups into the initial SVD factors.
    left_fix = combinators.block_diag(overlap_fix, null_u)
    right_fix = combinators.block_diag(overlap_fix.T, null_vh)
    return np.dot(left_fix.T, u.T), np.dot(vh.T, right_fix.T)
225
226
def bidiagonalize_unitary_with_special_orthogonals(
    mat: np.ndarray, *, rtol: float = 1e-5, atol: float = 1e-8, check_preconditions: bool = True
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Finds special orthogonal matrices L, R such that L @ mat @ R is diagonal.

    Args:
        mat: A unitary matrix.
        rtol: Relative numeric error threshold.
        atol: Absolute numeric error threshold.
        check_preconditions: When true (the default), verify that `mat` is
            unitary to the given tolerances. The flag is also forwarded to the
            real-pair bidiagonalization.

    Returns:
        A triplet (L, d, R) such that L @ mat @ R = diag(d). Both L and R are
        orthogonal matrices with determinant equal to 1.

    Raises:
        ValueError: Preconditions are not met (e.g. `mat` is not unitary).
    """
    if check_preconditions and not predicates.is_unitary(mat, rtol=rtol, atol=atol):
        raise ValueError('matrix must be unitary.')

    # Write mat = A + iB with A, B real. Unitarity (mat @ mat^H = mat^H @ mat = I)
    # forces A @ B.T and A.T @ B to be symmetric, which is what the real-pair
    # bidiagonalization requires.
    left, right = bidiagonalize_real_matrix_pair_with_symmetric_products(
        np.real(mat),
        np.imag(mat),
        rtol=rtol,
        atol=atol,
        check_preconditions=check_preconditions,
    )

    # Negating one row of L (or one column of R) only negates one diagonal
    # entry of L @ mat @ R. The product stays diagonal while the determinant
    # sign flips.
    if np.linalg.det(left) < 0:
        left[0, :] = -left[0, :]
    if np.linalg.det(right) < 0:
        right[:, 0] = -right[:, 0]

    return left, np.diag(combinators.dot(left, mat, right)), right
266