"""QR decomposition functions."""
import numpy

# Local imports
from .lapack import get_lapack_funcs
from .misc import _datacopied

__all__ = ['qr', 'qr_multiply', 'rq']


def safecall(f, name, *args, **kwargs):
    """Call a LAPACK routine, determining lwork automatically and handling
    error return values.

    Parameters
    ----------
    f : callable
        LAPACK wrapper to call. Its return tuple is expected to end with
        ``(..., work, info)``: ``work[0]`` is read as the optimal workspace
        size after a query, and ``info`` is the integer status code.
    name : str
        Routine name, used only in the error message.
    *args, **kwargs
        Forwarded to `f`. If ``lwork`` is missing, None or -1, `f` is first
        called with ``lwork=-1`` (workspace query) and then called again with
        the size it reported.

    Returns
    -------
    tuple
        The return values of `f` with the trailing ``work`` and ``info``
        items removed.

    Raises
    ------
    ValueError
        If the returned status code is negative.
    """
    lwork = kwargs.get("lwork", None)
    if lwork in (None, -1):
        # Workspace query: the routine reports the optimal size in work[0].
        kwargs['lwork'] = -1
        ret = f(*args, **kwargs)
        kwargs['lwork'] = ret[-2][0].real.astype(numpy.int_)
    ret = f(*args, **kwargs)
    if ret[-1] < 0:
        raise ValueError("illegal value in %dth argument of internal %s"
                         % (-ret[-1], name))
    return ret[:-2]


def qr(a, overwrite_a=False, lwork=None, mode='full', pivoting=False,
       check_finite=True):
    """
    Compute QR decomposition of a matrix.

    Calculate the decomposition ``A = Q R`` where Q is unitary/orthogonal
    and R upper triangular.

    Parameters
    ----------
    a : (M, N) array_like
        Matrix to be decomposed
    overwrite_a : bool, optional
        Whether data in `a` is overwritten (may improve performance if
        `overwrite_a` is set to True by reusing the existing input data
        structure rather than creating a new one.)
    lwork : int, optional
        Work array size, lwork >= a.shape[1]. If None or -1, an optimal size
        is computed. With ``pivoting=True`` this value is not passed to the
        factorization routine (its workspace size is always queried); it is
        still forwarded to the routine that forms Q.
    mode : {'full', 'r', 'economic', 'raw'}, optional
        Determines what information is to be returned: either both Q and R
        ('full', default), only R ('r') or both Q and R but computed in
        economy-size ('economic', see Notes). The final option 'raw'
        (added in SciPy 0.11) makes the function return two matrices
        (Q, TAU) in the internal format used by LAPACK.
    pivoting : bool, optional
        Whether or not factorization should include pivoting for rank-revealing
        qr decomposition. If pivoting, compute the decomposition
        ``A P = Q R`` as above, but where P is chosen such that the diagonal
        of R is non-increasing.
    check_finite : bool, optional
        Whether to check that the input matrix contains only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.

    Returns
    -------
    Q : float or complex ndarray
        Of shape (M, M), or (M, K) for ``mode='economic'``. Not returned
        if ``mode='r'``.
    R : float or complex ndarray
        Of shape (M, N), or (K, N) for ``mode='economic'``. ``K = min(M, N)``.
    P : int ndarray
        Of shape (N,) for ``pivoting=True``. Not returned if
        ``pivoting=False``.

    Raises
    ------
    LinAlgError
        Raised if decomposition fails
    ValueError
        If `mode` is not one of the accepted values, if `a` is not 2-D, or
        if a LAPACK routine reports an illegal argument.

    Notes
    -----
    This is an interface to the LAPACK routines dgeqrf, zgeqrf,
    dorgqr, zungqr, dgeqp3, and zgeqp3.

    If ``mode=economic``, the shapes of Q and R are (M, K) and (K, N) instead
    of (M,M) and (M,N), with ``K=min(M,N)``.

    Examples
    --------
    >>> from scipy import linalg
    >>> rng = np.random.default_rng()
    >>> a = rng.standard_normal((9, 6))

    >>> q, r = linalg.qr(a)
    >>> np.allclose(a, np.dot(q, r))
    True
    >>> q.shape, r.shape
    ((9, 9), (9, 6))

    >>> r2 = linalg.qr(a, mode='r')
    >>> np.allclose(r, r2)
    True

    >>> q3, r3 = linalg.qr(a, mode='economic')
    >>> q3.shape, r3.shape
    ((9, 6), (6, 6))

    >>> q4, r4, p4 = linalg.qr(a, pivoting=True)
    >>> d = np.abs(np.diag(r4))
    >>> np.all(d[1:] <= d[:-1])
    True
    >>> np.allclose(a[:, p4], np.dot(q4, r4))
    True
    >>> q4.shape, r4.shape, p4.shape
    ((9, 9), (9, 6), (6,))

    >>> q5, r5, p5 = linalg.qr(a, mode='economic', pivoting=True)
    >>> q5.shape, r5.shape, p5.shape
    ((9, 6), (6, 6), (6,))

    """
    # 'qr' was the old default, equivalent to 'full'. Neither 'full' nor
    # 'qr' are used below.
    # 'raw' is used internally by qr_multiply
    if mode not in ['full', 'qr', 'r', 'economic', 'raw']:
        raise ValueError("Mode argument should be one of ['full', 'r', "
                         "'economic', 'raw']")

    if check_finite:
        a1 = numpy.asarray_chkfinite(a)
    else:
        a1 = numpy.asarray(a)
    if len(a1.shape) != 2:
        raise ValueError("expected a 2-D array")
    M, N = a1.shape
    # If the conversion above already made a copy, it is safe to overwrite.
    overwrite_a = overwrite_a or (_datacopied(a1, a))

    if pivoting:
        geqp3, = get_lapack_funcs(('geqp3',), (a1,))
        qr, jpvt, tau = safecall(geqp3, "geqp3", a1, overwrite_a=overwrite_a)
        jpvt -= 1  # geqp3 returns a 1-based index array, so subtract 1
    else:
        geqrf, = get_lapack_funcs(('geqrf',), (a1,))
        qr, tau = safecall(geqrf, "geqrf", a1, lwork=lwork,
                           overwrite_a=overwrite_a)

    # R is the upper triangle of the factored array; in the economy-size
    # modes with M >= N only the leading N rows are kept.
    if mode not in ['economic', 'raw'] or M < N:
        R = numpy.triu(qr)
    else:
        R = numpy.triu(qr[:N, :])

    if pivoting:
        Rj = R, jpvt
    else:
        Rj = R,

    if mode == 'r':
        return Rj
    elif mode == 'raw':
        return ((qr, tau),) + Rj

    gor_un_gqr, = get_lapack_funcs(('orgqr',), (qr,))

    if M < N:
        Q, = safecall(gor_un_gqr, "gorgqr/gungqr", qr[:, :M], tau,
                      lwork=lwork, overwrite_a=1)
    elif mode == 'economic':
        Q, = safecall(gor_un_gqr, "gorgqr/gungqr", qr, tau, lwork=lwork,
                      overwrite_a=1)
    else:
        # Full Q for M >= N: copy the factored data into the leading N
        # columns of an (M, M) buffer that the routine overwrites with Q.
        t = qr.dtype.char
        qqr = numpy.empty((M, M), dtype=t)
        qqr[:, :N] = qr
        Q, = safecall(gor_un_gqr, "gorgqr/gungqr", qqr, tau, lwork=lwork,
                      overwrite_a=1)

    return (Q,) + Rj


def qr_multiply(a, c, mode='right', pivoting=False, conjugate=False,
                overwrite_a=False, overwrite_c=False):
    """
    Calculate the QR decomposition and multiply Q with a matrix.

    Calculate the decomposition ``A = Q R`` where Q is unitary/orthogonal
    and R upper triangular. Multiply Q with a vector or a matrix c.

    Parameters
    ----------
    a : (M, N), array_like
        Input array
    c : array_like
        Input array to be multiplied by ``q``. It is always checked for
        non-finite values.
    mode : {'left', 'right'}, optional
        ``Q @ c`` is returned if mode is 'left', ``c @ Q`` is returned if
        mode is 'right'.
        The shape of c must be appropriate for the matrix multiplications,
        if mode is 'left', ``min(a.shape) == c.shape[0]``,
        if mode is 'right', ``a.shape[0] == c.shape[1]``.
    pivoting : bool, optional
        Whether or not factorization should include pivoting for rank-revealing
        qr decomposition, see the documentation of qr.
    conjugate : bool, optional
        Whether Q should be complex-conjugated. This might be faster
        than explicit conjugation.
    overwrite_a : bool, optional
        Whether data in a is overwritten (may improve performance)
    overwrite_c : bool, optional
        Whether data in c is overwritten (may improve performance).
        If this is used, c must be big enough to keep the result,
        i.e. ``c.shape[0]`` = ``a.shape[0]`` if mode is 'left'.

    Returns
    -------
    CQ : ndarray
        The product of ``Q`` and ``c``.
    R : (K, N), ndarray
        R array of the resulting QR factorization where ``K = min(M, N)``.
    P : (N,) ndarray
        Integer pivot array. Only returned when ``pivoting=True``.

    Raises
    ------
    LinAlgError
        Raised if QR decomposition fails.
    ValueError
        If `mode` is neither 'left' nor 'right', or if the shapes of `a`
        and `c` are not compatible for the requested product.

    Notes
    -----
    This is an interface to the LAPACK routines ``?GEQRF``, ``?ORMQR``,
    ``?UNMQR``, and ``?GEQP3``.

    .. versionadded:: 0.11.0

    Examples
    --------
    >>> from scipy.linalg import qr_multiply, qr
    >>> A = np.array([[1, 3, 3], [2, 3, 2], [2, 3, 3], [1, 3, 2]])
    >>> qc, r1, piv1 = qr_multiply(A, 2*np.eye(4), pivoting=1)
    >>> qc
    array([[-1.,  1., -1.],
           [-1., -1.,  1.],
           [-1., -1., -1.],
           [-1.,  1.,  1.]])
    >>> r1
    array([[-6., -3., -5.            ],
           [ 0., -1., -1.11022302e-16],
           [ 0.,  0., -1.            ]])
    >>> piv1
    array([1, 0, 2], dtype=int32)
    >>> q2, r2, piv2 = qr(A, mode='economic', pivoting=1)
    >>> np.allclose(2*q2 - qc, np.zeros((4, 3)))
    True

    """
    if mode not in ['left', 'right']:
        raise ValueError("Mode argument can only be 'left' or 'right' but "
                         "not '{}'".format(mode))
    c = numpy.asarray_chkfinite(c)
    if c.ndim < 2:
        # Promote a vector to 2-D (a column for 'left', a row for 'right')
        # and remember to flatten the result again before returning.
        onedim = True
        c = numpy.atleast_2d(c)
        if mode == "left":
            c = c.T
    else:
        onedim = False

    a = numpy.atleast_2d(numpy.asarray(a))  # chkfinite done in qr
    M, N = a.shape

    if mode == 'left':
        # Required row count is min(M, N), or M when overwrite_c is set.
        if c.shape[0] != min(M, N + overwrite_c*(M-N)):
            raise ValueError('Array shapes are not compatible for Q @ c'
                             ' operation: {} vs {}'.format(a.shape, c.shape))
    else:
        if M != c.shape[1]:
            raise ValueError('Array shapes are not compatible for c @ Q'
                             ' operation: {} vs {}'.format(c.shape, a.shape))

    raw = qr(a, overwrite_a, None, "raw", pivoting)
    Q, tau = raw[0]

    gor_un_mqr, = get_lapack_funcs(('ormqr',), (Q,))
    if gor_un_mqr.typecode in ('s', 'd'):
        trans = "T"
    else:
        trans = "C"

    Q = Q[:, :min(M, N)]
    if M > N and mode == "left" and not overwrite_c:
        # c has only N rows but the product has M: work in a zero-padded
        # Fortran-ordered buffer that the routine may overwrite.
        if conjugate:
            cc = numpy.zeros((c.shape[1], M), dtype=c.dtype, order="F")
            cc[:, :N] = c.T
        else:
            cc = numpy.zeros((M, c.shape[1]), dtype=c.dtype, order="F")
            cc[:N, :] = c
            trans = "N"
        if conjugate:
            lr = "R"
        else:
            lr = "L"
        overwrite_c = True
    elif (c.flags["C_CONTIGUOUS"] and trans == "T") or conjugate:
        # Operate on the transposed problem: pass c.T, apply Q from the
        # opposite side, and transpose the result back below.
        cc = c.T
        if mode == "left":
            lr = "R"
        else:
            lr = "L"
    else:
        trans = "N"
        cc = c
        if mode == "left":
            lr = "L"
        else:
            lr = "R"
    cQ, = safecall(gor_un_mqr, "gormqr/gunmqr", lr, trans, Q, tau, cc,
                   overwrite_c=overwrite_c)
    if trans != "N":
        cQ = cQ.T
    if mode == "right":
        cQ = cQ[:, :min(M, N)]
    if onedim:
        cQ = cQ.ravel()

    return (cQ,) + raw[1:]


def rq(a, overwrite_a=False, lwork=None, mode='full', check_finite=True):
    """
    Compute RQ decomposition of a matrix.

    Factor the input as ``A = R Q`` with R upper triangular and Q
    unitary/orthogonal.

    Parameters
    ----------
    a : (M, N) array_like
        Matrix to be decomposed
    overwrite_a : bool, optional
        Whether data in a is overwritten (may improve performance)
    lwork : int, optional
        Work array size, lwork >= a.shape[1]. If None or -1, an optimal size
        is computed.
    mode : {'full', 'r', 'economic'}, optional
        Selects what is returned: both R and Q ('full', default), only R
        ('r'), or R and Q in economy size ('economic', see Notes).
    check_finite : bool, optional
        Whether to check that the input matrix contains only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.

    Returns
    -------
    R : float or complex ndarray
        Of shape (M, N) or (M, K) for ``mode='economic'``. ``K = min(M, N)``.
    Q : float or complex ndarray
        Of shape (N, N) or (K, N) for ``mode='economic'``. Not returned
        if ``mode='r'``.

    Raises
    ------
    LinAlgError
        If decomposition fails.

    Notes
    -----
    This is an interface to the LAPACK routines sgerqf, dgerqf, cgerqf, zgerqf,
    sorgrq, dorgrq, cungrq and zungrq.

    If ``mode=economic``, the shapes of Q and R are (K, N) and (M, K) instead
    of (N,N) and (M,N), with ``K=min(M,N)``.

    Examples
    --------
    >>> from scipy import linalg
    >>> rng = np.random.default_rng()
    >>> a = rng.standard_normal((6, 9))
    >>> r, q = linalg.rq(a)
    >>> np.allclose(a, r @ q)
    True
    >>> r.shape, q.shape
    ((6, 9), (9, 9))
    >>> r2 = linalg.rq(a, mode='r')
    >>> np.allclose(r, r2)
    True
    >>> r3, q3 = linalg.rq(a, mode='economic')
    >>> r3.shape, q3.shape
    ((6, 6), (6, 9))

    """
    if mode not in ('full', 'r', 'economic'):
        raise ValueError(
            "Mode argument should be one of ['full', 'r', 'economic']")

    to_array = numpy.asarray_chkfinite if check_finite else numpy.asarray
    a1 = to_array(a)
    if a1.ndim != 2:
        raise ValueError('expected matrix')
    n_rows, n_cols = a1.shape
    overwrite_a = overwrite_a or (_datacopied(a1, a))

    economic = mode == 'economic'
    tall = n_cols < n_rows

    factorize, = get_lapack_funcs(('gerqf',), (a1,))
    factored, tau = safecall(factorize, 'gerqf', a1, lwork=lwork,
                             overwrite_a=overwrite_a)

    # Economy-size R is the trailing square block; otherwise R keeps the
    # full shape, with the triangle taken from the offset diagonal.
    if economic and not tall:
        R = numpy.triu(factored[-n_rows:, -n_rows:])
    else:
        R = numpy.triu(factored, n_cols - n_rows)

    if mode == 'r':
        return R

    form_q, = get_lapack_funcs(('orgrq',), (factored,))

    # Pick the array handed to the Q-forming routine, then call it once.
    if tall:
        reflectors = factored[-n_cols:]
    elif economic:
        reflectors = factored
    else:
        reflectors = numpy.empty((n_cols, n_cols), dtype=factored.dtype)
        reflectors[-n_rows:] = factored

    Q, = safecall(form_q, "gorgrq/gungrq", reflectors, tau, lwork=lwork,
                  overwrite_a=1)

    return R, Q