1"""QR decomposition functions."""
2import numpy
3
4# Local imports
5from .lapack import get_lapack_funcs
6from .misc import _datacopied
7
# Public API of this module (the names exported by ``import *``).
__all__ = ['qr', 'qr_multiply', 'rq']
9
10
def safecall(f, name, *args, **kwargs):
    """Invoke the LAPACK wrapper `f`, sizing the workspace if needed.

    When ``lwork`` is absent, ``None`` or ``-1``, `f` is first called with
    ``lwork=-1`` (a workspace-size query); the real part of the first entry
    of the second-to-last output (LAPACK's work array) is then used as
    ``lwork`` for the actual call.

    Parameters
    ----------
    f : callable
        Wrapped LAPACK routine whose last two outputs are ``(work, info)``.
    name : str
        Routine name, used only in the error message.
    *args, **kwargs
        Forwarded to `f`.

    Returns
    -------
    tuple
        All outputs of `f` except the trailing ``(work, info)`` pair.

    Raises
    ------
    ValueError
        If ``info`` is negative, i.e. LAPACK flagged an illegal argument.
    """
    if kwargs.get("lwork") in (None, -1):
        # Workspace query first, then reuse the reported optimal size.
        kwargs['lwork'] = -1
        query = f(*args, **kwargs)
        optimal = query[-2][0]
        kwargs['lwork'] = optimal.real.astype(numpy.int_)

    result = f(*args, **kwargs)
    info = result[-1]
    if info < 0:
        raise ValueError("illegal value in %dth argument of internal %s"
                         % (-info, name))
    return result[:-2]
24
25
def qr(a, overwrite_a=False, lwork=None, mode='full', pivoting=False,
       check_finite=True):
    """
    Compute QR decomposition of a matrix.

    Calculate the decomposition ``A = Q R`` where Q is unitary/orthogonal
    and R upper triangular.

    Parameters
    ----------
    a : (M, N) array_like
        Matrix to be decomposed
    overwrite_a : bool, optional
        Whether data in `a` is overwritten (may improve performance if
        `overwrite_a` is set to True by reusing the existing input data
        structure rather than creating a new one.)
    lwork : int, optional
        Work array size, lwork >= a.shape[1]. If None or -1, an optimal size
        is computed. With ``pivoting=True`` it is only used when forming Q;
        the pivoted factorization itself always uses a queried optimal size.
    mode : {'full', 'r', 'economic', 'raw'}, optional
        Determines what information is to be returned: either both Q and R
        ('full', default), only R ('r') or both Q and R but computed in
        economy-size ('economic', see Notes). The final option 'raw'
        (added in SciPy 0.11) makes the function return the pair
        (Q, TAU) in the internal format used by LAPACK, followed by R.
        The legacy value 'qr' is accepted as an alias of 'full'.
    pivoting : bool, optional
        Whether or not factorization should include pivoting for rank-revealing
        qr decomposition. If pivoting, compute the decomposition
        ``A P = Q R`` as above, but where P is chosen such that the diagonal
        of R is non-increasing.
    check_finite : bool, optional
        Whether to check that the input matrix contains only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.

    Returns
    -------
    Q : float or complex ndarray
        Of shape (M, M), or (M, K) for ``mode='economic'``. Not returned
        if ``mode='r'``.
    R : float or complex ndarray
        Of shape (M, N), or (K, N) for ``mode='economic'``. ``K = min(M, N)``.
    P : int ndarray
        Of shape (N,) for ``pivoting=True``. Not returned if
        ``pivoting=False``.

    Raises
    ------
    ValueError
        If `mode` is not recognized, `a` is not two-dimensional, `a`
        contains infinities or NaNs while `check_finite` is True, or the
        LAPACK routine reports an illegal argument.

    Notes
    -----
    This is an interface to the LAPACK routines dgeqrf, zgeqrf,
    dorgqr, zungqr, dgeqp3, and zgeqp3.

    If ``mode=economic``, the shapes of Q and R are (M, K) and (K, N) instead
    of (M,M) and (M,N), with ``K=min(M,N)``.

    The result is always a tuple, even for ``mode='r'``: it is then
    ``(R,)`` or ``(R, P)``. For ``mode='raw'`` the result is
    ``((QR, TAU), R)`` or ``((QR, TAU), R, P)``, where ``QR`` and ``TAU``
    are the raw outputs of the LAPACK factorization routine.

    Examples
    --------
    >>> from scipy import linalg
    >>> rng = np.random.default_rng()
    >>> a = rng.standard_normal((9, 6))

    >>> q, r = linalg.qr(a)
    >>> np.allclose(a, np.dot(q, r))
    True
    >>> q.shape, r.shape
    ((9, 9), (9, 6))

    >>> r2 = linalg.qr(a, mode='r')
    >>> np.allclose(r, r2)
    True

    >>> q3, r3 = linalg.qr(a, mode='economic')
    >>> q3.shape, r3.shape
    ((9, 6), (6, 6))

    >>> q4, r4, p4 = linalg.qr(a, pivoting=True)
    >>> d = np.abs(np.diag(r4))
    >>> np.all(d[1:] <= d[:-1])
    True
    >>> np.allclose(a[:, p4], np.dot(q4, r4))
    True
    >>> q4.shape, r4.shape, p4.shape
    ((9, 9), (9, 6), (6,))

    >>> q5, r5, p5 = linalg.qr(a, mode='economic', pivoting=True)
    >>> q5.shape, r5.shape, p5.shape
    ((9, 6), (6, 6), (6,))

    """
    # 'qr' was the old default, equivalent to 'full'. Neither 'full' nor
    # 'qr' are used below.
    # 'raw' is used internally by qr_multiply
    if mode not in ['full', 'qr', 'r', 'economic', 'raw']:
        raise ValueError("Mode argument should be one of ['full', 'r', "
                         "'economic', 'raw']")

    if check_finite:
        a1 = numpy.asarray_chkfinite(a)
    else:
        a1 = numpy.asarray(a)
    if a1.ndim != 2:
        raise ValueError("expected a 2-D array")
    M, N = a1.shape
    # `_datacopied` presumably reports whether `a1` is a fresh copy of `a`;
    # if so, nobody else sees it and it may be overwritten.
    overwrite_a = overwrite_a or _datacopied(a1, a)

    # Factorize. R is read from the upper triangle of the packed output;
    # the packed output together with `tau` is what ?orgqr/?ungqr consume.
    if pivoting:
        geqp3, = get_lapack_funcs(('geqp3',), (a1,))
        # `lwork` is not forwarded here, so safecall always performs a
        # workspace query for geqp3.
        qr_factor, jpvt, tau = safecall(geqp3, "geqp3", a1,
                                        overwrite_a=overwrite_a)
        jpvt -= 1  # geqp3 returns a 1-based index array, so subtract 1
    else:
        geqrf, = get_lapack_funcs(('geqrf',), (a1,))
        qr_factor, tau = safecall(geqrf, "geqrf", a1, lwork=lwork,
                                  overwrite_a=overwrite_a)

    # Economy-size (and raw) R keeps only the first N rows when M >= N.
    if mode not in ['economic', 'raw'] or M < N:
        R = numpy.triu(qr_factor)
    else:
        R = numpy.triu(qr_factor[:N, :])

    if pivoting:
        Rj = R, jpvt
    else:
        Rj = R,

    if mode == 'r':
        return Rj
    elif mode == 'raw':
        return ((qr_factor, tau),) + Rj

    gor_un_gqr, = get_lapack_funcs(('orgqr',), (qr_factor,))

    if M < N:
        # Q is M x M; only the first M columns of the packed factor are used.
        Q, = safecall(gor_un_gqr, "gorgqr/gungqr", qr_factor[:, :M], tau,
                      lwork=lwork, overwrite_a=1)
    elif mode == 'economic':
        Q, = safecall(gor_un_gqr, "gorgqr/gungqr", qr_factor, tau,
                      lwork=lwork, overwrite_a=1)
    else:
        # Full Q is M x M: embed the M x N packed factor in a square buffer
        # whose remaining columns are overwritten by ?orgqr/?ungqr.
        qqr = numpy.empty((M, M), dtype=qr_factor.dtype)
        qqr[:, :N] = qr_factor
        Q, = safecall(gor_un_gqr, "gorgqr/gungqr", qqr, tau, lwork=lwork,
                      overwrite_a=1)

    return (Q,) + Rj
175
176
def qr_multiply(a, c, mode='right', pivoting=False, conjugate=False,
                overwrite_a=False, overwrite_c=False):
    """
    Calculate the QR decomposition and multiply Q with a matrix.

    Calculate the decomposition ``A = Q R`` where Q is unitary/orthogonal
    and R upper triangular. Multiply Q with a vector or a matrix c.

    Parameters
    ----------
    a : (M, N), array_like
        Input array
    c : array_like
        Input array to be multiplied by ``q``.
    mode : {'left', 'right'}, optional
        ``Q @ c`` is returned if mode is 'left', ``c @ Q`` is returned if
        mode is 'right'.
        The shape of c must be appropriate for the matrix multiplications,
        if mode is 'left', ``min(a.shape) == c.shape[0]``,
        if mode is 'right', ``a.shape[0] == c.shape[1]``.
    pivoting : bool, optional
        Whether or not factorization should include pivoting for rank-revealing
        qr decomposition, see the documentation of qr.
    conjugate : bool, optional
        Whether Q should be complex-conjugated. This might be faster
        than explicit conjugation.
    overwrite_a : bool, optional
        Whether data in a is overwritten (may improve performance)
    overwrite_c : bool, optional
        Whether data in c is overwritten (may improve performance).
        If this is used, c must be big enough to keep the result,
        i.e. ``c.shape[0]`` = ``a.shape[0]`` if mode is 'left'.

    Returns
    -------
    CQ : ndarray
        The product of ``Q`` and ``c``.
    R : (K, N), ndarray
        R array of the resulting QR factorization where ``K = min(M, N)``.
    P : (N,) ndarray
        Integer pivot array. Only returned when ``pivoting=True``.

    Raises
    ------
    ValueError
        If `mode` is not 'left' or 'right', if `c` contains infinities or
        NaNs, or if the shapes of `a` and `c` are incompatible. Errors
        raised by `qr` (e.g. for non-finite `a`) propagate unchanged.

    Notes
    -----
    This is an interface to the LAPACK routines ``?GEQRF``, ``?ORMQR``,
    ``?UNMQR``, and ``?GEQP3``.

    A 1-D `c` is treated as a row vector for ``mode='right'`` and as a
    column vector for ``mode='left'``; the product is then returned 1-D.

    .. versionadded:: 0.11.0

    Examples
    --------
    >>> from scipy.linalg import qr_multiply, qr
    >>> A = np.array([[1, 3, 3], [2, 3, 2], [2, 3, 3], [1, 3, 2]])
    >>> qc, r1, piv1 = qr_multiply(A, 2*np.eye(4), pivoting=1)
    >>> qc
    array([[-1.,  1., -1.],
           [-1., -1.,  1.],
           [-1., -1., -1.],
           [-1.,  1.,  1.]])
    >>> r1
    array([[-6., -3., -5.            ],
           [ 0., -1., -1.11022302e-16],
           [ 0.,  0., -1.            ]])
    >>> piv1
    array([1, 0, 2], dtype=int32)
    >>> q2, r2, piv2 = qr(A, mode='economic', pivoting=1)
    >>> np.allclose(2*q2 - qc, np.zeros((4, 3)))
    True

    """
    if mode not in ['left', 'right']:
        raise ValueError("Mode argument can only be 'left' or 'right' but "
                         "not '{}'".format(mode))
    c = numpy.asarray_chkfinite(c)
    # Promote a 1-D c to 2-D: a row for 'right', a column for 'left'.
    # `onedim` remembers to flatten the result again at the end.
    if c.ndim < 2:
        onedim = True
        c = numpy.atleast_2d(c)
        if mode == "left":
            c = c.T
    else:
        onedim = False

    a = numpy.atleast_2d(numpy.asarray(a))  # chkfinite done in qr
    M, N = a.shape

    # For 'left', c needs min(M, N) rows; with overwrite_c the expression
    # evaluates to M, because the result is written into c itself.
    if mode == 'left':
        if c.shape[0] != min(M, N + overwrite_c*(M-N)):
            raise ValueError('Array shapes are not compatible for Q @ c'
                             ' operation: {} vs {}'.format(a.shape, c.shape))
    else:
        if M != c.shape[1]:
            raise ValueError('Array shapes are not compatible for c @ Q'
                             ' operation: {} vs {}'.format(c.shape, a.shape))

    # 'raw' mode yields ((packed factor, tau), R[, P]); Q is never formed.
    raw = qr(a, overwrite_a, None, "raw", pivoting)
    Q, tau = raw[0]

    # Transpose flavour for the LAPACK call: plain transpose for real
    # types, conjugate transpose for complex ones.
    gor_un_mqr, = get_lapack_funcs(('ormqr',), (Q,))
    if gor_un_mqr.typecode in ('s', 'd'):
        trans = "T"
    else:
        trans = "C"

    # Only the first min(M, N) columns of the packed factor are passed on.
    Q = Q[:, :min(M, N)]
    if M > N and mode == "left" and not overwrite_c:
        # c has only N rows here: zero-pad it to M rows in a fresh
        # Fortran-ordered buffer, which may then be overwritten freely.
        # With `conjugate`, the padded copy is stored transposed and Q is
        # applied from the right with `trans`; the result is transposed
        # back below.
        if conjugate:
            cc = numpy.zeros((c.shape[1], M), dtype=c.dtype, order="F")
            cc[:, :N] = c.T
        else:
            cc = numpy.zeros((M, c.shape[1]), dtype=c.dtype, order="F")
            cc[:N, :] = c
            trans = "N"
        if conjugate:
            lr = "R"
        else:
            lr = "L"
        overwrite_c = True
    # NOTE: precedence is `(C_CONTIGUOUS and trans == "T") or conjugate`.
    # Operate on c.T with the side swapped; for real Q with C-contiguous c,
    # c.T is Fortran-contiguous (presumably avoiding a copy in the wrapper).
    elif c.flags["C_CONTIGUOUS"] and trans == "T" or conjugate:
        cc = c.T
        if mode == "left":
            lr = "R"
        else:
            lr = "L"
    else:
        # Direct product, no transposition.
        trans = "N"
        cc = c
        if mode == "left":
            lr = "L"
        else:
            lr = "R"
    cQ, = safecall(gor_un_mqr, "gormqr/gunmqr", lr, trans, Q, tau, cc,
                   overwrite_c=overwrite_c)
    # Undo the transposition used to build the operand above.
    if trans != "N":
        cQ = cQ.T
    # For 'right', keep only the first min(M, N) columns of c @ Q.
    if mode == "right":
        cQ = cQ[:, :min(M, N)]
    if onedim:
        cQ = cQ.ravel()

    return (cQ,) + raw[1:]
322
323
def rq(a, overwrite_a=False, lwork=None, mode='full', check_finite=True):
    """
    Compute RQ decomposition of a matrix.

    Calculate the decomposition ``A = R Q`` where Q is unitary/orthogonal
    and R upper triangular.

    Parameters
    ----------
    a : (M, N) array_like
        Matrix to be decomposed
    overwrite_a : bool, optional
        Whether data in a is overwritten (may improve performance)
    lwork : int, optional
        Work array size, lwork >= a.shape[1]. If None or -1, an optimal size
        is computed.
    mode : {'full', 'r', 'economic'}, optional
        Determines what information is to be returned: either both Q and R
        ('full', default), only R ('r') or both Q and R but computed in
        economy-size ('economic', see Notes).
    check_finite : bool, optional
        Whether to check that the input matrix contains only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.

    Returns
    -------
    R : float or complex ndarray
        Of shape (M, N) or (M, K) for ``mode='economic'``. ``K = min(M, N)``.
        For ``mode='r'`` R alone is returned (not wrapped in a tuple).
    Q : float or complex ndarray
        Of shape (N, N) or (K, N) for ``mode='economic'``. Not returned
        if ``mode='r'``.

    Raises
    ------
    ValueError
        If `mode` is not recognized, `a` is not two-dimensional, `a`
        contains infinities or NaNs while `check_finite` is True, or the
        LAPACK routine reports an illegal argument.

    Notes
    -----
    This is an interface to the LAPACK routines sgerqf, dgerqf, cgerqf, zgerqf,
    sorgrq, dorgrq, cungrq and zungrq.

    If ``mode=economic``, the shapes of Q and R are (K, N) and (M, K) instead
    of (N,N) and (M,N), with ``K=min(M,N)``.

    Examples
    --------
    >>> from scipy import linalg
    >>> rng = np.random.default_rng()
    >>> a = rng.standard_normal((6, 9))
    >>> r, q = linalg.rq(a)
    >>> np.allclose(a, r @ q)
    True
    >>> r.shape, q.shape
    ((6, 9), (9, 9))
    >>> r2 = linalg.rq(a, mode='r')
    >>> np.allclose(r, r2)
    True
    >>> r3, q3 = linalg.rq(a, mode='economic')
    >>> r3.shape, q3.shape
    ((6, 6), (6, 9))

    """
    if mode not in ['full', 'r', 'economic']:
        raise ValueError(
            "Mode argument should be one of ['full', 'r', 'economic']")

    if check_finite:
        a1 = numpy.asarray_chkfinite(a)
    else:
        a1 = numpy.asarray(a)
    if a1.ndim != 2:
        raise ValueError('expected matrix')
    M, N = a1.shape
    # `_datacopied` presumably reports whether `a1` is a fresh copy of `a`;
    # if so, nobody else sees it and it may be overwritten.
    overwrite_a = overwrite_a or _datacopied(a1, a)

    gerqf, = get_lapack_funcs(('gerqf',), (a1,))
    rq_factor, tau = safecall(gerqf, 'gerqf', a1, lwork=lwork,
                              overwrite_a=overwrite_a)

    # The slices below use explicit non-negative offsets (``x[N - M:]``)
    # rather than negative ones (``x[-M:]``): for a zero-sized dimension
    # ``x[-0:]`` selects the whole axis instead of nothing.
    if mode != 'economic' or N < M:
        # Keep entries on and above diagonal N - M of the gerqf output.
        R = numpy.triu(rq_factor, N - M)
    else:
        # Economy-size R (N >= M): the trailing M columns.
        R = numpy.triu(rq_factor[:, N - M:])

    if mode == 'r':
        return R

    gor_un_grq, = get_lapack_funcs(('orgrq',), (rq_factor,))

    if N < M:
        # Q is N x N; only the last N rows of the packed factor are used.
        Q, = safecall(gor_un_grq, "gorgrq/gungrq", rq_factor[M - N:], tau,
                      lwork=lwork, overwrite_a=1)
    elif mode == 'economic':
        Q, = safecall(gor_un_grq, "gorgrq/gungrq", rq_factor, tau,
                      lwork=lwork, overwrite_a=1)
    else:
        # Full Q is N x N: place the M x N packed factor in the last M rows
        # of a square buffer whose remaining rows ?orgrq/?ungrq overwrites.
        rq1 = numpy.empty((N, N), dtype=rq_factor.dtype)
        rq1[N - M:] = rq_factor
        Q, = safecall(gor_un_grq, "gorgrq/gungrq", rq1, tau, lwork=lwork,
                      overwrite_a=1)

    return R, Q
427