1#
2# Author: Pearu Peterson, March 2002
3#
4# w/ additions by Travis Oliphant, March 2002
5#              and Jake Vanderplas, August 2012
6
7from warnings import warn
8import numpy as np
9from numpy import atleast_1d, atleast_2d
10from .flinalg import get_flinalg_funcs
11from .lapack import get_lapack_funcs, _compute_lwork
12from .misc import LinAlgError, _datacopied, LinAlgWarning
13from .decomp import _asarray_validated
14from . import decomp, decomp_svd
15from ._solve_toeplitz import levinson
16
17__all__ = ['solve', 'solve_triangular', 'solveh_banded', 'solve_banded',
18           'solve_toeplitz', 'solve_circulant', 'inv', 'det', 'lstsq',
19           'pinv', 'pinv2', 'pinvh', 'matrix_balance', 'matmul_toeplitz']
20
21
22# Linear equations
def _solve_check(n, info, lamch=None, rcond=None):
    """Validate a LAPACK ``info`` code and optionally warn on ill-conditioning.

    Parameters
    ----------
    n : int
        Order of the matrix. Accepted for call-site symmetry; it is not
        read by this helper.
    info : int
        Status flag returned by a LAPACK routine.
    lamch : callable, optional
        Machine-parameter query. ``lamch('E')`` is used as the threshold
        below which `rcond` triggers a warning. When omitted, only `info`
        is inspected.
    rcond : float, optional
        Reciprocal condition number estimate. Only read when `lamch` is
        supplied.

    Raises
    ------
    ValueError
        If `info` is negative.
    LinAlgError
        If `info` is positive.

    Warns
    -----
    LinAlgWarning
        If `rcond` is smaller than ``lamch('E')``.
    """
    if info < 0:
        raise ValueError('LAPACK reported an illegal value in {}-th argument'
                         '.'.format(-info))
    if info > 0:
        raise LinAlgError('Matrix is singular.')

    if lamch is not None:
        threshold = lamch('E')
        if rcond < threshold:
            message = ('Ill-conditioned matrix (rcond={:.6g}): '
                       'result may not be accurate.'.format(rcond))
            # stacklevel=3 attributes the warning to the caller of the
            # function that invoked this helper.
            warn(message, LinAlgWarning, stacklevel=3)
38
39
def solve(a, b, sym_pos=False, lower=False, overwrite_a=False,
          overwrite_b=False, debug=None, check_finite=True, assume_a='gen',
          transposed=False):
    """
    Solves the linear equation set ``a * x = b`` for the unknown ``x``
    for square ``a`` matrix.

    If the data matrix is known to be a particular type then supplying the
    corresponding string to ``assume_a`` key chooses the dedicated solver.
    The available options are

    ===================  ========
     generic matrix       'gen'
     symmetric            'sym'
     hermitian            'her'
     positive definite    'pos'
    ===================  ========

    If omitted, ``'gen'`` is the default structure.

    The datatype of the arrays define which solver is called regardless
    of the values. In other words, even when the complex array entries have
    precisely zero imaginary parts, the complex solver will be called based
    on the data type of the array.

    Parameters
    ----------
    a : (N, N) array_like
        Square input data
    b : (N, NRHS) array_like
        Input data for the right hand side.
    sym_pos : bool, optional
        Assume `a` is symmetric and positive definite. This key is deprecated
        and assume_a = 'pos' keyword is recommended instead. The functionality
        is the same. It will be removed in the future.
    lower : bool, optional
        If True, only the data contained in the lower triangle of `a` is
        used. Default is to use the upper triangle. (ignored for ``'gen'``)
    overwrite_a : bool, optional
        Allow overwriting data in `a` (may enhance performance).
        Default is False.
    overwrite_b : bool, optional
        Allow overwriting data in `b` (may enhance performance).
        Default is False.
    debug : optional
        Deprecated and otherwise ignored. Passing any value other than None
        emits a `DeprecationWarning`.
    check_finite : bool, optional
        Whether to check that the input matrices contain only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.
    assume_a : str, optional
        Valid entries are explained above.
    transposed : bool, optional
        If True, ``a^T x = b`` for real matrices, raises `NotImplementedError`
        for complex matrices (only for True).

    Returns
    -------
    x : (N, NRHS) ndarray
        The solution array.

    Raises
    ------
    ValueError
        If size mismatches detected or input a is not square.
    LinAlgError
        If the matrix is singular.
    LinAlgWarning
        If an ill-conditioned input a is detected.
    NotImplementedError
        If transposed is True and input a is a complex matrix.

    Examples
    --------
    Given `a` and `b`, solve for `x`:

    >>> a = np.array([[3, 2, 0], [1, -1, 0], [0, 5, 1]])
    >>> b = np.array([2, 4, -1])
    >>> from scipy import linalg
    >>> x = linalg.solve(a, b)
    >>> x
    array([ 2., -2.,  9.])
    >>> np.dot(a, x) == b
    array([ True,  True,  True])

    Notes
    -----
    If the input b matrix is a 1-D array with N elements, when supplied
    together with an NxN input a, it is assumed as a valid column vector
    despite the apparent size mismatch. This is compatible with the
    numpy.dot() behavior and the returned result is still 1-D array.

    The generic, symmetric, Hermitian and positive definite solutions are
    obtained via calling ?GESV, ?SYSV, ?HESV, and ?POSV routines of
    LAPACK respectively.
    """
    # Flags for 1-D or N-D right-hand side
    b_is_1D = False

    a1 = atleast_2d(_asarray_validated(a, check_finite=check_finite))
    b1 = atleast_1d(_asarray_validated(b, check_finite=check_finite))
    n = a1.shape[0]

    overwrite_a = overwrite_a or _datacopied(a1, a)
    overwrite_b = overwrite_b or _datacopied(b1, b)

    # Reject anything that is not a single square matrix up front, so that
    # inputs with more than two dimensions fail with a clear message instead
    # of an error from inside the LAPACK wrappers.
    if a1.ndim != 2 or a1.shape[0] != a1.shape[1]:
        raise ValueError('Input a needs to be a square matrix.')

    if n != b1.shape[0]:
        # Last chance to catch 1x1 scalar a and 1-D b arrays
        if not (n == 1 and b1.size != 0):
            raise ValueError('Input b has to have same number of rows as '
                             'input a')

    # accommodate empty arrays
    if b1.size == 0:
        return np.asfortranarray(b1.copy())

    # regularize 1-D b arrays to 2D
    if b1.ndim == 1:
        if n == 1:
            b1 = b1[None, :]
        else:
            b1 = b1[:, None]
        b_is_1D = True

    # Backwards compatibility - old keyword.
    if sym_pos:
        assume_a = 'pos'

    if assume_a not in ('gen', 'sym', 'her', 'pos'):
        raise ValueError('{} is not a recognized matrix structure'
                         ''.format(assume_a))

    # for a real matrix, describe it as "symmetric", not "hermitian"
    # (lapack doesn't know what to do with real hermitian matrices)
    if assume_a == 'her' and not np.iscomplexobj(a1):
        assume_a = 'sym'

    # Deprecate keyword "debug"
    if debug is not None:
        warn('Use of the "debug" keyword is deprecated '
             'and this keyword will be removed in future '
             'versions of SciPy.', DeprecationWarning, stacklevel=2)

    # Get the correct lamch function.
    # The LAMCH functions only exists for S and D
    # So for complex values we have to convert to real/double.
    if a1.dtype.char in 'fF':  # single precision
        lamch = get_lapack_funcs('lamch', dtype='f')
    else:
        lamch = get_lapack_funcs('lamch', dtype='d')

    # Currently we do not have the other forms of the norm calculators
    #   lansy, lanpo, lanhe.
    # However, in any case they only reduce computations slightly...
    lange = get_lapack_funcs('lange', (a1,))

    # Since the I-norm and 1-norm are the same for symmetric matrices
    # we can collect them all in this one call
    # Note however, that when issuing 'gen' and form!='none', then
    # the I-norm should be used
    if transposed:
        trans = 1
        norm = 'I'
        if np.iscomplexobj(a1):
            raise NotImplementedError('scipy.linalg.solve can currently '
                                      'not solve a^T x = b or a^H x = b '
                                      'for complex matrices.')
    else:
        trans = 0
        norm = '1'

    # The norm must be taken before the factorization, which may overwrite a1.
    anorm = lange(norm, a1)

    # Generalized case 'gesv'
    if assume_a == 'gen':
        gecon, getrf, getrs = get_lapack_funcs(('gecon', 'getrf', 'getrs'),
                                               (a1, b1))
        lu, ipvt, info = getrf(a1, overwrite_a=overwrite_a)
        _solve_check(n, info)
        x, info = getrs(lu, ipvt, b1,
                        trans=trans, overwrite_b=overwrite_b)
        _solve_check(n, info)
        rcond, info = gecon(lu, anorm, norm=norm)
    # Hermitian case 'hesv'
    elif assume_a == 'her':
        hecon, hesv, hesv_lw = get_lapack_funcs(('hecon', 'hesv',
                                                 'hesv_lwork'), (a1, b1))
        lwork = _compute_lwork(hesv_lw, n, lower)
        lu, ipvt, x, info = hesv(a1, b1, lwork=lwork,
                                 lower=lower,
                                 overwrite_a=overwrite_a,
                                 overwrite_b=overwrite_b)
        _solve_check(n, info)
        # The estimator has to read the same triangle the factorization
        # was written to, otherwise rcond is computed from unrelated data.
        rcond, info = hecon(lu, ipvt, anorm, lower=lower)
    # Symmetric case 'sysv'
    elif assume_a == 'sym':
        sycon, sysv, sysv_lw = get_lapack_funcs(('sycon', 'sysv',
                                                 'sysv_lwork'), (a1, b1))
        lwork = _compute_lwork(sysv_lw, n, lower)
        lu, ipvt, x, info = sysv(a1, b1, lwork=lwork,
                                 lower=lower,
                                 overwrite_a=overwrite_a,
                                 overwrite_b=overwrite_b)
        _solve_check(n, info)
        # Same triangle as used by sysv (see the Hermitian branch).
        rcond, info = sycon(lu, ipvt, anorm, lower=lower)
    # Positive definite case 'posv'
    else:
        pocon, posv = get_lapack_funcs(('pocon', 'posv'),
                                       (a1, b1))
        lu, x, info = posv(a1, b1, lower=lower,
                           overwrite_a=overwrite_a,
                           overwrite_b=overwrite_b)
        _solve_check(n, info)
        # pocon selects the triangle holding the Cholesky factor via `uplo`.
        rcond, info = pocon(lu, anorm, uplo='L' if lower else 'U')

    _solve_check(n, info, lamch, rcond)

    if b_is_1D:
        x = x.ravel()

    return x
262
263
def solve_triangular(a, b, trans=0, lower=False, unit_diagonal=False,
                     overwrite_b=False, debug=None, check_finite=True):
    """
    Solve the equation `a x = b` for `x`, assuming a is a triangular matrix.

    Parameters
    ----------
    a : (M, M) array_like
        A triangular matrix
    b : (M,) or (M, N) array_like
        Right-hand side matrix in `a x = b`
    lower : bool, optional
        Use only data contained in the lower triangle of `a`.
        Default is to use upper triangle.
    trans : {0, 1, 2, 'N', 'T', 'C'}, optional
        Type of system to solve:

        ========  =========
        trans     system
        ========  =========
        0 or 'N'  a x  = b
        1 or 'T'  a^T x = b
        2 or 'C'  a^H x = b
        ========  =========
    unit_diagonal : bool, optional
        If True, diagonal elements of `a` are assumed to be 1 and
        will not be referenced.
    overwrite_b : bool, optional
        Allow overwriting data in `b` (may enhance performance)
    debug : optional
        Deprecated. Any value other than None emits a `DeprecationWarning`;
        a truthy value additionally prints the effective `overwrite_b` flag.
    check_finite : bool, optional
        Whether to check that the input matrices contain only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.

    Returns
    -------
    x : (M,) or (M, N) ndarray
        Solution to the system `a x = b`.  Shape of return matches `b`.

    Raises
    ------
    ValueError
        If `a` is not a square 2-D array, if the leading dimensions of `a`
        and `b` differ, or if LAPACK reports an illegal argument.
    LinAlgError
        If `a` is singular

    Notes
    -----
    .. versionadded:: 0.9.0

    Examples
    --------
    Solve the lower triangular system a x = b, where::

             [3  0  0  0]       [4]
        a =  [2  1  0  0]   b = [2]
             [1  0  1  0]       [4]
             [1  1  1  1]       [2]

    >>> from scipy.linalg import solve_triangular
    >>> a = np.array([[3, 0, 0, 0], [2, 1, 0, 0], [1, 0, 1, 0], [1, 1, 1, 1]])
    >>> b = np.array([4, 2, 4, 2])
    >>> x = solve_triangular(a, b, lower=True)
    >>> x
    array([ 1.33333333, -0.66666667,  2.66666667, -1.33333333])
    >>> a.dot(x)  # Check the result
    array([ 4.,  2.,  4.,  2.])

    """
    # The "debug" keyword is on its way out.
    if debug is not None:
        warn('Use of the "debug" keyword is deprecated '
             'and this keyword will be removed in the future '
             'versions of SciPy.', DeprecationWarning, stacklevel=2)

    a1 = _asarray_validated(a, check_finite=check_finite)
    b1 = _asarray_validated(b, check_finite=check_finite)

    if a1.ndim != 2 or a1.shape[0] != a1.shape[1]:
        raise ValueError('expected square matrix')
    if a1.shape[0] != b1.shape[0]:
        raise ValueError('shapes of a {} and b {} are incompatible'
                         .format(a1.shape, b1.shape))

    overwrite_b = overwrite_b or _datacopied(b1, b)
    if debug:
        print('solve:overwrite_b=', overwrite_b)

    # Translate the character codes to the integer codes; anything else is
    # passed through unchanged.
    trans = {'N': 0, 'T': 1, 'C': 2}.get(trans, trans)
    trtrs, = get_lapack_funcs(('trtrs',), (a1, b1))

    if a1.flags.f_contiguous or trans == 2:
        matrix, use_lower, use_trans = a1, lower, trans
    else:
        # trtrs expects Fortran ordering, so hand it the transpose of the
        # C-ordered array and flip both the triangle and the trans flag.
        matrix, use_lower, use_trans = a1.T, not lower, not trans

    x, info = trtrs(matrix, b1, overwrite_b=overwrite_b, lower=use_lower,
                    trans=use_trans, unitdiag=unit_diagonal)

    if info > 0:
        raise LinAlgError("singular matrix: resolution failed at diagonal %d" %
                          (info-1))
    if info < 0:
        raise ValueError('illegal value in %dth argument of internal trtrs' %
                         (-info))
    return x
365
366
def solve_banded(l_and_u, ab, b, overwrite_ab=False, overwrite_b=False,
                 debug=None, check_finite=True):
    """
    Solve the equation a x = b for x, assuming a is banded matrix.

    The matrix a is stored in `ab` using the matrix diagonal ordered form::

        ab[u + i - j, j] == a[i,j]

    Example of `ab` (shape of a is (6,6), `u` =1, `l` =2)::

        *    a01  a12  a23  a34  a45
        a00  a11  a22  a33  a44  a55
        a10  a21  a32  a43  a54   *
        a20  a31  a42  a53   *    *

    Parameters
    ----------
    l_and_u : (int, int)
        Number of non-zero lower and upper diagonals, given as ``(l, u)``.
    ab : (`l` + `u` + 1, M) array_like
        Banded matrix
    b : (M,) or (M, K) array_like
        Right-hand side
    overwrite_ab : bool, optional
        Discard data in `ab` (may enhance performance)
    overwrite_b : bool, optional
        Discard data in `b` (may enhance performance)
    debug : optional
        Deprecated and otherwise ignored. Passing any value other than None
        emits a `DeprecationWarning`.
    check_finite : bool, optional
        Whether to check that the input matrices contain only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.

    Returns
    -------
    x : (M,) or (M, K) ndarray
        The solution to the system a x = b. Returned shape depends on the
        shape of `b`.

    Raises
    ------
    ValueError
        If the shapes of `ab` and `b` are incompatible, if ``l + u + 1``
        does not match the number of rows of `ab`, or if LAPACK reports an
        illegal argument.
    LinAlgError
        If LAPACK reports that the matrix is singular.

    Examples
    --------
    Solve the banded system a x = b, where::

            [5  2 -1  0  0]       [0]
            [1  4  2 -1  0]       [1]
        a = [0  1  3  2 -1]   b = [2]
            [0  0  1  2  2]       [2]
            [0  0  0  1  1]       [3]

    There is one nonzero diagonal below the main diagonal (l = 1), and
    two above (u = 2). The diagonal banded form of the matrix is::

             [*  * -1 -1 -1]
        ab = [*  2  2  2  2]
             [5  4  3  2  1]
             [1  1  1  1  *]

    >>> from scipy.linalg import solve_banded
    >>> ab = np.array([[0,  0, -1, -1, -1],
    ...                [0,  2,  2,  2,  2],
    ...                [5,  4,  3,  2,  1],
    ...                [1,  1,  1,  1,  0]])
    >>> b = np.array([0, 1, 2, 2, 3])
    >>> x = solve_banded((1, 2), ab, b)
    >>> x
    array([-2.37288136,  3.93220339, -4.        ,  4.3559322 , -1.3559322 ])

    """

    # Deprecate keyword "debug"
    if debug is not None:
        warn('Use of the "debug" keyword is deprecated '
             'and this keyword will be removed in the future '
             'versions of SciPy.', DeprecationWarning, stacklevel=2)

    a1 = _asarray_validated(ab, check_finite=check_finite, as_inexact=True)
    b1 = _asarray_validated(b, check_finite=check_finite, as_inexact=True)
    # Validate shapes.
    if a1.shape[-1] != b1.shape[0]:
        raise ValueError("shapes of ab and b are not compatible.")
    (nlower, nupper) = l_and_u
    if nlower + nupper + 1 != a1.shape[0]:
        # Report the shape of the validated array: `ab` itself may be a
        # plain sequence without a ``shape`` attribute.
        raise ValueError("invalid values for the number of lower and upper "
                         "diagonals: l+u+1 (%d) does not equal ab.shape[0] "
                         "(%d)" % (nlower + nupper + 1, a1.shape[0]))

    overwrite_b = overwrite_b or _datacopied(b1, b)
    if a1.shape[-1] == 1:
        # 1x1 system: a[0, 0] is stored at ab[u, 0] in the diagonal ordered
        # form, so divide by that entry. Promote to the common dtype first
        # so the in-place division is valid for real b with complex ab.
        common = np.result_type(a1.dtype, b1.dtype)
        b2 = b1.astype(common, copy=(not overwrite_b))
        b2 /= a1[nupper, 0]
        return b2
    if nlower == nupper == 1:
        # Tridiagonal matrix: use the dedicated solver.
        overwrite_ab = overwrite_ab or _datacopied(a1, ab)
        gtsv, = get_lapack_funcs(('gtsv',), (a1, b1))
        du = a1[0, 1:]
        d = a1[1, :]
        dl = a1[2, :-1]
        du2, d, du, x, info = gtsv(dl, d, du, b1, overwrite_ab, overwrite_ab,
                                   overwrite_ab, overwrite_b)
    else:
        gbsv, = get_lapack_funcs(('gbsv',), (a1, b1))
        # gbsv is given a work array with `nlower` extra leading rows; it is
        # a fresh copy, so it can always be overwritten.
        a2 = np.zeros((2*nlower + nupper + 1, a1.shape[1]), dtype=gbsv.dtype)
        a2[nlower:, :] = a1
        lu, piv, x, info = gbsv(nlower, nupper, a2, b1, overwrite_ab=True,
                                overwrite_b=overwrite_b)
    if info == 0:
        return x
    if info > 0:
        raise LinAlgError("singular matrix")
    raise ValueError('illegal value in %d-th argument of internal '
                     'gbsv/gtsv' % -info)
478
479
def solveh_banded(ab, b, overwrite_ab=False, overwrite_b=False, lower=False,
                  check_finite=True):
    """
    Solve equation a x = b. a is Hermitian positive-definite banded matrix.

    The matrix a is stored in `ab` either in lower diagonal or upper
    diagonal ordered form::

        ab[u + i - j, j] == a[i,j]        (if upper form; i <= j)
        ab[    i - j, j] == a[i,j]        (if lower form; i >= j)

    Example of `ab` (shape of a is (6, 6), `u` =2)::

        upper form:
        *   *   a02 a13 a24 a35
        *   a01 a12 a23 a34 a45
        a00 a11 a22 a33 a44 a55

        lower form:
        a00 a11 a22 a33 a44 a55
        a10 a21 a32 a43 a54 *
        a20 a31 a42 a53 *   *

    Cells marked with * are not used.

    Parameters
    ----------
    ab : (`u` + 1, M) array_like
        Banded matrix
    b : (M,) or (M, K) array_like
        Right-hand side
    overwrite_ab : bool, optional
        Discard data in `ab` (may enhance performance)
    overwrite_b : bool, optional
        Discard data in `b` (may enhance performance)
    lower : bool, optional
        Is the matrix in the lower form. (Default is upper form)
    check_finite : bool, optional
        Whether to check that the input matrices contain only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.

    Returns
    -------
    x : (M,) or (M, K) ndarray
        The solution to the system a x = b. Shape of return matches shape
        of `b`.

    Raises
    ------
    ValueError
        If the shapes of `ab` and `b` are incompatible, or if LAPACK reports
        an illegal argument.
    LinAlgError
        If LAPACK reports that a leading minor is not positive definite.

    Examples
    --------
    Solve the banded system A x = b, where::

            [ 4  2 -1  0  0  0]       [1]
            [ 2  5  2 -1  0  0]       [2]
        A = [-1  2  6  2 -1  0]   b = [2]
            [ 0 -1  2  7  2 -1]       [3]
            [ 0  0 -1  2  8  2]       [3]
            [ 0  0  0 -1  2  9]       [3]

    >>> from scipy.linalg import solveh_banded

    `ab` contains the main diagonal and the nonzero diagonals below the
    main diagonal. That is, we use the lower form:

    >>> ab = np.array([[ 4,  5,  6,  7, 8, 9],
    ...                [ 2,  2,  2,  2, 2, 0],
    ...                [-1, -1, -1, -1, 0, 0]])
    >>> b = np.array([1, 2, 2, 3, 3, 3])
    >>> x = solveh_banded(ab, b, lower=True)
    >>> x
    array([ 0.03431373,  0.45938375,  0.05602241,  0.47759104,  0.17577031,
            0.34733894])


    Solve the Hermitian banded system H x = b, where::

            [ 8   2-1j   0     0  ]        [ 1  ]
        H = [2+1j  5     1j    0  ]    b = [1+1j]
            [ 0   -1j    9   -2-1j]        [1-2j]
            [ 0    0   -2+1j   6  ]        [ 0  ]

    In this example, we put the upper diagonals in the array `hb`:

    >>> hb = np.array([[0, 2-1j, 1j, -2-1j],
    ...                [8,  5,    9,   6  ]])
    >>> b = np.array([1, 1+1j, 1-2j, 0])
    >>> x = solveh_banded(hb, b)
    >>> x
    array([ 0.07318536-0.02939412j,  0.11877624+0.17696461j,
            0.10077984-0.23035393j, -0.00479904-0.09358128j])

    """
    a1 = _asarray_validated(ab, check_finite=check_finite)
    b1 = _asarray_validated(b, check_finite=check_finite)

    # The number of columns of `ab` must match the leading dimension of `b`.
    if a1.shape[-1] != b1.shape[0]:
        raise ValueError("shapes of ab and b are not compatible.")

    overwrite_ab = overwrite_ab or _datacopied(a1, ab)
    overwrite_b = overwrite_b or _datacopied(b1, b)

    is_tridiagonal = a1.shape[0] == 2
    if not is_tridiagonal:
        pbsv, = get_lapack_funcs(('pbsv',), (a1, b1))
        c, x, info = pbsv(a1, b1, lower=lower, overwrite_ab=overwrite_ab,
                          overwrite_b=overwrite_b)
    else:
        # Two stored rows: use the tridiagonal solver, which takes the real
        # main diagonal and one off-diagonal as separate vectors.
        ptsv, = get_lapack_funcs(('ptsv',), (a1, b1))
        if lower:
            diag = a1[0, :].real
            offdiag = a1[1, :-1]
        else:
            diag = a1[1, :].real
            # The upper form stores the superdiagonal; pass its conjugate.
            offdiag = a1[0, 1:].conj()
        d, du, x, info = ptsv(diag, offdiag, b1, overwrite_ab, overwrite_ab,
                              overwrite_b)

    if info < 0:
        raise ValueError('illegal value in %dth argument of internal '
                         'pbsv' % -info)
    if info > 0:
        raise LinAlgError("%dth leading minor not positive definite" % info)
    return x
601
602
def solve_toeplitz(c_or_cr, b, check_finite=True):
    """Solve a Toeplitz system using Levinson Recursion

    The Toeplitz matrix has constant diagonals, with c as its first column
    and r as its first row. If r is not given, ``r == conjugate(c)`` is
    assumed.

    Parameters
    ----------
    c_or_cr : array_like or tuple of (array_like, array_like)
        The vector ``c``, or a tuple of arrays (``c``, ``r``). Whatever the
        actual shape of ``c``, it will be converted to a 1-D array. If not
        supplied, ``r = conjugate(c)`` is assumed; in this case, if c[0] is
        real, the Toeplitz matrix is Hermitian. r[0] is ignored; the first row
        of the Toeplitz matrix is ``[c[0], r[1:]]``. Whatever the actual shape
        of ``r``, it will be converted to a 1-D array.
    b : (M,) or (M, K) array_like
        Right-hand side in ``T x = b``.
    check_finite : bool, optional
        Whether to check that the input matrices contain only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (result entirely NaNs) if the inputs do contain infinities or NaNs.

    Returns
    -------
    x : (M,) or (M, K) ndarray
        The solution to the system ``T x = b``. Shape of return matches shape
        of `b`.

    Raises
    ------
    ValueError
        If `b` is None after argument validation.

    See Also
    --------
    toeplitz : Toeplitz matrix

    Notes
    -----
    The solution is computed using Levinson-Durbin recursion, which is faster
    than generic least-squares methods, but can be less numerically stable.

    Examples
    --------
    Solve the Toeplitz system T x = b, where::

            [ 1 -1 -2 -3]       [1]
        T = [ 3  1 -1 -2]   b = [2]
            [ 6  3  1 -1]       [2]
            [10  6  3  1]       [5]

    To specify the Toeplitz matrix, only the first column and the first
    row are needed.

    >>> c = np.array([1, 3, 6, 10])    # First column of T
    >>> r = np.array([1, -1, -2, -3])  # First row of T
    >>> b = np.array([1, 2, 2, 5])

    >>> from scipy.linalg import solve_toeplitz, toeplitz
    >>> x = solve_toeplitz((c, r), b)
    >>> x
    array([ 1.66666667, -1.        , -2.66666667,  2.33333333])

    Check the result by creating the full Toeplitz matrix and
    multiplying it by `x`.  We should get `b`.

    >>> T = toeplitz(c, r)
    >>> T.dot(x)
    array([ 1.,  2.,  2.,  5.])

    """
    # If numerical stability of this algorithm is a problem, a future
    # developer might consider implementing other O(N^2) Toeplitz solvers,
    # such as GKO (https://www.jstor.org/stable/2153371) or Bareiss.

    r, c, b, dtype, b_shape = _validate_args_for_toeplitz_ops(
        c_or_cr, b, check_finite, keep_b_shape=True)

    # The recursion consumes one flat vector of matrix values: r[1:] in
    # reverse order, followed by the whole of c.
    diag_values = np.concatenate((r[:0:-1], c))
    if b is None:
        raise ValueError('illegal value, `b` is a required argument')

    if b.ndim == 1:
        solution, _ = levinson(diag_values, np.ascontiguousarray(b))
        return solution

    # Several right-hand sides: solve column by column, then restore the
    # shape reported by the validator.
    columns = []
    for j in range(b.shape[1]):
        column, _ = levinson(diag_values, np.ascontiguousarray(b[:, j]))
        columns.append(column)
    return np.column_stack(columns).reshape(*b_shape)
691
692
def _get_axis_len(aname, a, axis):
    """Return the length of `a` along `axis`, accepting negative indices.

    Parameters
    ----------
    aname : str
        Short name of the array; it is only used to build the error message
        (``"'<aname>axis' entry is out of bounds"``).
    a : ndarray
        Array whose shape is queried.
    axis : int
        Axis index. Negative values count from the last axis.

    Returns
    -------
    int
        ``a.shape[axis]`` after normalizing a negative `axis`.

    Raises
    ------
    ValueError
        If `axis` does not refer to an existing axis of `a`.
    """
    ndim = a.ndim
    normalized = axis + ndim if axis < 0 else axis
    if normalized < 0 or normalized >= ndim:
        raise ValueError("'%saxis' entry is out of bounds" % (aname,))
    return a.shape[normalized]
700
701
702def solve_circulant(c, b, singular='raise', tol=None,
703                    caxis=-1, baxis=0, outaxis=0):
704    """Solve C x = b for x, where C is a circulant matrix.
705
706    `C` is the circulant matrix associated with the vector `c`.
707
708    The system is solved by doing division in Fourier space. The
709    calculation is::
710
711        x = ifft(fft(b) / fft(c))
712
713    where `fft` and `ifft` are the fast Fourier transform and its inverse,
714    respectively. For a large vector `c`, this is *much* faster than
715    solving the system with the full circulant matrix.
716
717    Parameters
718    ----------
719    c : array_like
720        The coefficients of the circulant matrix.
721    b : array_like
722        Right-hand side matrix in ``a x = b``.
723    singular : str, optional
724        This argument controls how a near singular circulant matrix is
725        handled.  If `singular` is "raise" and the circulant matrix is
726        near singular, a `LinAlgError` is raised. If `singular` is
727        "lstsq", the least squares solution is returned. Default is "raise".
728    tol : float, optional
729        If any eigenvalue of the circulant matrix has an absolute value
730        that is less than or equal to `tol`, the matrix is considered to be
731        near singular. If not given, `tol` is set to::
732
733            tol = abs_eigs.max() * abs_eigs.size * np.finfo(np.float64).eps
734
735        where `abs_eigs` is the array of absolute values of the eigenvalues
736        of the circulant matrix.
737    caxis : int
738        When `c` has dimension greater than 1, it is viewed as a collection
739        of circulant vectors. In this case, `caxis` is the axis of `c` that
740        holds the vectors of circulant coefficients.
741    baxis : int
742        When `b` has dimension greater than 1, it is viewed as a collection
743        of vectors. In this case, `baxis` is the axis of `b` that holds the
744        right-hand side vectors.
745    outaxis : int
746        When `c` or `b` are multidimensional, the value returned by
747        `solve_circulant` is multidimensional. In this case, `outaxis` is
748        the axis of the result that holds the solution vectors.
749
750    Returns
751    -------
752    x : ndarray
753        Solution to the system ``C x = b``.
754
755    Raises
756    ------
757    LinAlgError
758        If the circulant matrix associated with `c` is near singular.
759
760    See Also
761    --------
762    circulant : circulant matrix
763
764    Notes
765    -----
766    For a 1-D vector `c` with length `m`, and an array `b`
767    with shape ``(m, ...)``,
768
769        solve_circulant(c, b)
770
771    returns the same result as
772
773        solve(circulant(c), b)
774
775    where `solve` and `circulant` are from `scipy.linalg`.
776
777    .. versionadded:: 0.16.0
778
779    Examples
780    --------
781    >>> from scipy.linalg import solve_circulant, solve, circulant, lstsq
782
783    >>> c = np.array([2, 2, 4])
784    >>> b = np.array([1, 2, 3])
785    >>> solve_circulant(c, b)
786    array([ 0.75, -0.25,  0.25])
787
788    Compare that result to solving the system with `scipy.linalg.solve`:
789
790    >>> solve(circulant(c), b)
791    array([ 0.75, -0.25,  0.25])
792
793    A singular example:
794
795    >>> c = np.array([1, 1, 0, 0])
796    >>> b = np.array([1, 2, 3, 4])
797
798    Calling ``solve_circulant(c, b)`` will raise a `LinAlgError`.  For the
799    least square solution, use the option ``singular='lstsq'``:
800
801    >>> solve_circulant(c, b, singular='lstsq')
802    array([ 0.25,  1.25,  2.25,  1.25])
803
804    Compare to `scipy.linalg.lstsq`:
805
806    >>> x, resid, rnk, s = lstsq(circulant(c), b)
807    >>> x
808    array([ 0.25,  1.25,  2.25,  1.25])
809
810    A broadcasting example:
811
812    Suppose we have the vectors of two circulant matrices stored in an array
813    with shape (2, 5), and three `b` vectors stored in an array with shape
814    (3, 5).  For example,
815
816    >>> c = np.array([[1.5, 2, 3, 0, 0], [1, 1, 4, 3, 2]])
817    >>> b = np.arange(15).reshape(-1, 5)
818
819    We want to solve all combinations of circulant matrices and `b` vectors,
820    with the result stored in an array with shape (2, 3, 5). When we
821    disregard the axes of `c` and `b` that hold the vectors of coefficients,
822    the shapes of the collections are (2,) and (3,), respectively, which are
823    not compatible for broadcasting. To have a broadcast result with shape
824    (2, 3), we add a trivial dimension to `c`: ``c[:, np.newaxis, :]`` has
825    shape (2, 1, 5). The last dimension holds the coefficients of the
826    circulant matrices, so when we call `solve_circulant`, we can use the
827    default ``caxis=-1``. The coefficients of the `b` vectors are in the last
828    dimension of the array `b`, so we use ``baxis=-1``. If we use the
829    default `outaxis`, the result will have shape (5, 2, 3), so we'll use
830    ``outaxis=-1`` to put the solution vectors in the last dimension.
831
832    >>> x = solve_circulant(c[:, np.newaxis, :], b, baxis=-1, outaxis=-1)
833    >>> x.shape
834    (2, 3, 5)
835    >>> np.set_printoptions(precision=3)  # For compact output of numbers.
836    >>> x
837    array([[[-0.118,  0.22 ,  1.277, -0.142,  0.302],
838            [ 0.651,  0.989,  2.046,  0.627,  1.072],
839            [ 1.42 ,  1.758,  2.816,  1.396,  1.841]],
840           [[ 0.401,  0.304,  0.694, -0.867,  0.377],
841            [ 0.856,  0.758,  1.149, -0.412,  0.831],
842            [ 1.31 ,  1.213,  1.603,  0.042,  1.286]]])
843
844    Check by solving one pair of `c` and `b` vectors (cf. ``x[1, 1, :]``):
845
846    >>> solve_circulant(c[1], b[1, :])
847    array([ 0.856,  0.758,  1.149, -0.412,  0.831])
848
849    """
850    c = np.atleast_1d(c)
851    nc = _get_axis_len("c", c, caxis)
852    b = np.atleast_1d(b)
853    nb = _get_axis_len("b", b, baxis)
854    if nc != nb:
855        raise ValueError('Shapes of c {} and b {} are incompatible'
856                         .format(c.shape, b.shape))
857
858    fc = np.fft.fft(np.rollaxis(c, caxis, c.ndim), axis=-1)
859    abs_fc = np.abs(fc)
860    if tol is None:
861        # This is the same tolerance as used in np.linalg.matrix_rank.
862        tol = abs_fc.max(axis=-1) * nc * np.finfo(np.float64).eps
863        if tol.shape != ():
864            tol.shape = tol.shape + (1,)
865        else:
866            tol = np.atleast_1d(tol)
867
868    near_zeros = abs_fc <= tol
869    is_near_singular = np.any(near_zeros)
870    if is_near_singular:
871        if singular == 'raise':
872            raise LinAlgError("near singular circulant matrix.")
873        else:
874            # Replace the small values with 1 to avoid errors in the
875            # division fb/fc below.
876            fc[near_zeros] = 1
877
878    fb = np.fft.fft(np.rollaxis(b, baxis, b.ndim), axis=-1)
879
880    q = fb / fc
881
882    if is_near_singular:
883        # `near_zeros` is a boolean array, same shape as `c`, that is
884        # True where `fc` is (near) zero. `q` is the broadcasted result
885        # of fb / fc, so to set the values of `q` to 0 where `fc` is near
886        # zero, we use a mask that is the broadcast result of an array
887        # of True values shaped like `b` with `near_zeros`.
888        mask = np.ones_like(b, dtype=bool) & near_zeros
889        q[mask] = 0
890
891    x = np.fft.ifft(q, axis=-1)
892    if not (np.iscomplexobj(c) or np.iscomplexobj(b)):
893        x = x.real
894    if outaxis != -1:
895        x = np.rollaxis(x, -1, outaxis)
896    return x
897
898
899# matrix inversion
def inv(a, overwrite_a=False, check_finite=True):
    """
    Compute the inverse of a matrix.

    Parameters
    ----------
    a : array_like
        Square matrix to be inverted.
    overwrite_a : bool, optional
        Discard data in `a` (may improve performance). Default is False.
    check_finite : bool, optional
        Whether to check that the input matrix contains only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.

    Returns
    -------
    ainv : ndarray
        Inverse of the matrix `a`.

    Raises
    ------
    LinAlgError
        If `a` is singular.
    ValueError
        If `a` is not square, or not 2D.

    Notes
    -----
    The matrix is first factorized with the LAPACK routine ``?getrf`` and
    the inverse is then formed from the factors with ``?getri``.

    Examples
    --------
    >>> from scipy import linalg
    >>> a = np.array([[1., 2.], [3., 4.]])
    >>> linalg.inv(a)
    array([[-2. ,  1. ],
           [ 1.5, -0.5]])
    >>> np.dot(a, linalg.inv(a))
    array([[ 1.,  0.],
           [ 0.,  1.]])

    """
    a1 = _asarray_validated(a, check_finite=check_finite)
    if a1.ndim != 2 or a1.shape[0] != a1.shape[1]:
        raise ValueError('expected square matrix')

    # NOTE(review): `_datacopied` presumably reports that `a1` is already a
    # private copy of `a`, in which case overwriting it is harmless.
    overwrite_a = overwrite_a or _datacopied(a1, a)

    getrf, getri, getri_lwork = get_lapack_funcs(
        ('getrf', 'getri', 'getri_lwork'), (a1,))

    lu, piv, info = getrf(a1, overwrite_a=overwrite_a)
    if not info:
        # The workspace size is padded by 1%. Historically this avoided a
        # segfault seen when benchmarking a 500x500 inverse, which looked
        # like a bug in LAPACK's ?getri when called with the exact size.
        lwork = int(1.01 * _compute_lwork(getri_lwork, a1.shape[0]))
        inv_a, info = getri(lu, piv, lwork=lwork, overwrite_lu=1)

    # `info` now holds the status of whichever routine ran last.
    if info < 0:
        raise ValueError('illegal value in %d-th argument of internal '
                         'getrf|getri' % -info)
    if info > 0:
        raise LinAlgError("singular matrix")
    return inv_a
973
974
975# Determinant
976
def det(a, overwrite_a=False, check_finite=True):
    """
    Compute the determinant of a matrix

    The determinant of a square matrix is a value derived arithmetically
    from the coefficients of the matrix.

    The determinant for a 3x3 matrix, for example, is computed as follows::

        a    b    c
        d    e    f = A
        g    h    i

        det(A) = a*e*i + b*f*g + c*d*h - c*e*g - b*d*i - a*f*h

    Parameters
    ----------
    a : (M, M) array_like
        A square matrix.
    overwrite_a : bool, optional
        Allow overwriting data in a (may enhance performance).
    check_finite : bool, optional
        Whether to check that the input matrix contains only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.

    Returns
    -------
    det : float or complex
        Determinant of `a`.

    Raises
    ------
    ValueError
        If `a` is not a square 2-D array, or if the internal routine reports
        an illegal argument.

    Notes
    -----
    The determinant is computed via LU factorization, LAPACK routine z/dgetrf.

    Examples
    --------
    >>> from scipy import linalg
    >>> a = np.array([[1,2,3], [4,5,6], [7,8,9]])
    >>> linalg.det(a)
    0.0
    >>> a = np.array([[0,2,3], [4,5,6], [7,8,9]])
    >>> linalg.det(a)
    3.0

    """
    a1 = _asarray_validated(a, check_finite=check_finite)
    is_square = a1.ndim == 2 and a1.shape[0] == a1.shape[1]
    if not is_square:
        raise ValueError('expected square matrix')

    # NOTE(review): `_datacopied` presumably reports that `a1` is already a
    # private copy of `a`, so the routine may clobber it.
    may_clobber = overwrite_a or _datacopied(a1, a)
    fdet, = get_flinalg_funcs(('det',), (a1,))
    a_det, info = fdet(a1, overwrite_a=may_clobber)

    # Only a negative status is treated as an error here.
    if info < 0:
        raise ValueError('illegal value in %d-th argument of internal '
                         'det.getrf' % -info)
    return a_det
1033
1034
1035# Linear Least Squares
def lstsq(a, b, cond=None, overwrite_a=False, overwrite_b=False,
          check_finite=True, lapack_driver=None):
    """
    Compute least-squares solution to equation Ax = b.

    Compute a vector x such that the 2-norm ``|b - A x|`` is minimized.

    Parameters
    ----------
    a : (M, N) array_like
        Left-hand side array
    b : (M,) or (M, K) array_like
        Right hand side array
    cond : float, optional
        Cutoff for 'small' singular values; used to determine effective
        rank of a. Singular values smaller than
        ``cond * largest_singular_value`` are considered zero. If omitted,
        the machine epsilon of the working precision is used.
    overwrite_a : bool, optional
        Discard data in `a` (may enhance performance). Default is False.
    overwrite_b : bool, optional
        Discard data in `b` (may enhance performance). Default is False.
    check_finite : bool, optional
        Whether to check that the input matrices contain only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.
    lapack_driver : str, optional
        Which LAPACK driver is used to solve the least-squares problem.
        Options are ``'gelsd'``, ``'gelsy'``, ``'gelss'``. Default
        (``'gelsd'``) is a good choice.  However, ``'gelsy'`` can be slightly
        faster on many problems.  ``'gelss'`` was used historically.  It is
        generally slow but uses less memory.

        .. versionadded:: 0.17.0

    Returns
    -------
    x : (N,) or (N, K) ndarray
        Least-squares solution.  Return shape matches shape of `b`.
    residues : (K,) ndarray or float
        Square of the 2-norm for each column in ``b - a x``, if ``M > N`` and
        ``rank(a) == N`` (returns a scalar if `b` is 1-D). Otherwise a
        (0,)-shaped array is returned.
    rank : int
        Effective rank of `a`.
    s : (min(M, N),) ndarray or None
        Singular values of `a`. The condition number of a is
        ``abs(s[0] / s[-1])``.

    Raises
    ------
    LinAlgError
        If computation does not converge.

    ValueError
        When parameters are not compatible.

    See Also
    --------
    scipy.optimize.nnls : linear least squares with non-negativity constraint

    Notes
    -----
    When ``'gelsy'`` is used as a driver, `residues` is set to a (0,)-shaped
    array and `s` is always ``None``.

    For a zero-sized problem (``M == 0`` or ``N == 0``) LAPACK is not called;
    a zero solution, rank 0 and an empty array of singular values are
    returned.

    Examples
    --------
    >>> from scipy.linalg import lstsq
    >>> import matplotlib.pyplot as plt

    Suppose we have the following data:

    >>> x = np.array([1, 2.5, 3.5, 4, 5, 7, 8.5])
    >>> y = np.array([0.3, 1.1, 1.5, 2.0, 3.2, 6.6, 8.6])

    We want to fit a quadratic polynomial of the form ``y = a + b*x**2``
    to this data.  We first form the "design matrix" M, with a constant
    column of 1s and a column containing ``x**2``:

    >>> M = x[:, np.newaxis]**[0, 2]
    >>> M
    array([[  1.  ,   1.  ],
           [  1.  ,   6.25],
           [  1.  ,  12.25],
           [  1.  ,  16.  ],
           [  1.  ,  25.  ],
           [  1.  ,  49.  ],
           [  1.  ,  72.25]])

    We want to find the least-squares solution to ``M.dot(p) = y``,
    where ``p`` is a vector with length 2 that holds the parameters
    ``a`` and ``b``.

    >>> p, res, rnk, s = lstsq(M, y)
    >>> p
    array([ 0.20925829,  0.12013861])

    Plot the data and the fitted curve.

    >>> plt.plot(x, y, 'o', label='data')
    >>> xx = np.linspace(0, 9, 101)
    >>> yy = p[0] + p[1]*xx**2
    >>> plt.plot(xx, yy, label='least squares fit, $y = a + bx^2$')
    >>> plt.xlabel('x')
    >>> plt.ylabel('y')
    >>> plt.legend(framealpha=1, shadow=True)
    >>> plt.grid(alpha=0.25)
    >>> plt.show()

    """
    a1 = _asarray_validated(a, check_finite=check_finite)
    b1 = _asarray_validated(b, check_finite=check_finite)
    if len(a1.shape) != 2:
        raise ValueError('Input array a should be 2D')
    m, n = a1.shape
    nrhs = b1.shape[1] if len(b1.shape) == 2 else 1
    if m != b1.shape[0]:
        raise ValueError('Shape mismatch: a and b should have the same number'
                         ' of rows ({} != {}).'.format(m, b1.shape[0]))
    if m == 0 or n == 0:  # Zero-sized problem, confuses LAPACK
        x = np.zeros((n,) + b1.shape[1:], dtype=np.common_type(a1, b1))
        if n == 0:
            # No unknowns: the residual is simply |b|**2 per column.
            residues = np.linalg.norm(b1, axis=0)**2
        else:
            residues = np.empty((0,))
        return x, residues, 0, np.empty((0,))

    driver = lapack_driver
    if driver is None:
        driver = lstsq.default_lapack_driver
    if driver not in ('gelsd', 'gelsy', 'gelss'):
        raise ValueError('LAPACK driver "%s" is not found' % driver)

    lapack_func, lapack_lwork = get_lapack_funcs((driver,
                                                 '%s_lwork' % driver),
                                                 (a1, b1))
    real_data = lapack_func.dtype.kind == 'f'

    if m < n:
        # need to extend b matrix as it will be filled with
        # a larger solution matrix
        if len(b1.shape) == 2:
            b2 = np.zeros((n, nrhs), dtype=lapack_func.dtype)
            b2[:m, :] = b1
        else:
            b2 = np.zeros(n, dtype=lapack_func.dtype)
            b2[:m] = b1
        b1 = b2

    overwrite_a = overwrite_a or _datacopied(a1, a)
    overwrite_b = overwrite_b or _datacopied(b1, b)

    if cond is None:
        cond = np.finfo(lapack_func.dtype).eps

    if driver in ('gelss', 'gelsd'):
        if driver == 'gelss':
            lwork = _compute_lwork(lapack_lwork, m, n, nrhs, cond)
            v, x, s, rank, work, info = lapack_func(a1, b1, cond, lwork,
                                                    overwrite_a=overwrite_a,
                                                    overwrite_b=overwrite_b)

        elif driver == 'gelsd':
            if real_data:
                lwork, iwork = _compute_lwork(lapack_lwork, m, n, nrhs, cond)
                x, s, rank, info = lapack_func(a1, b1, lwork,
                                               iwork, cond, False, False)
            else:  # complex data
                lwork, rwork, iwork = _compute_lwork(lapack_lwork, m, n,
                                                     nrhs, cond)
                x, s, rank, info = lapack_func(a1, b1, lwork, rwork, iwork,
                                               cond, False, False)
        if info > 0:
            raise LinAlgError("SVD did not converge in Linear Least Squares")
        if info < 0:
            # Report the driver actually used; `lapack_driver` may be None
            # when the default driver was selected.
            raise ValueError('illegal value in %d-th argument of internal %s'
                             % (-info, driver))
        resids = np.asarray([], dtype=x.dtype)
        if m > n:
            # The routine returns the solution in the first n rows; the
            # remaining rows are used for the residuals at full column rank.
            x1 = x[:n]
            if rank == n:
                resids = np.sum(np.abs(x[n:])**2, axis=0)
            x = x1
        return x, resids, rank, s

    elif driver == 'gelsy':
        lwork = _compute_lwork(lapack_lwork, m, n, nrhs, cond)
        jptv = np.zeros((a1.shape[1], 1), dtype=np.int32)
        v, x, j, rank, info = lapack_func(a1, b1, jptv, cond,
                                          lwork, False, False)
        if info < 0:
            raise ValueError("illegal value in %d-th argument of internal "
                             "gelsy" % -info)
        if m > n:
            x = x[:n]
        # gelsy provides neither residuals nor singular values.
        return x, np.array([], x.dtype), rank, None
1236
1237
1238lstsq.default_lapack_driver = 'gelsd'
1239
1240
def pinv(a, atol=None, rtol=None, return_rank=False, check_finite=True,
         cond=None, rcond=None):
    """
    Compute the (Moore-Penrose) pseudo-inverse of a matrix.

    Calculate a generalized inverse of a matrix using its
    singular-value decomposition ``U @ S @ V`` in the economy mode and picking
    up only the columns/rows that are associated with significant singular
    values.

    If ``s`` is the maximum singular value of ``a``, then the
    significance cut-off value is determined by ``atol + rtol * s``. Any
    singular value below this value is assumed insignificant.

    Parameters
    ----------
    a : (M, N) array_like
        Matrix to be pseudo-inverted.
    atol : float, optional
        Absolute threshold term, default value is 0.

        .. versionadded:: 1.7.0

    rtol : float, optional
        Relative threshold term, default value is ``max(M, N) * eps`` where
        ``eps`` is the machine precision value of the datatype of ``a``.

        .. versionadded:: 1.7.0

    return_rank : bool, optional
        If True, return the effective rank of the matrix.
    check_finite : bool, optional
        Whether to check that the input matrix contains only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.
    cond, rcond : float, optional
        In older versions, these values were meant to be used as ``atol`` with
        ``rtol=0``. If both were given ``rcond`` overwrote ``cond`` and hence
        the code was not correct. Thus using these are strongly discouraged and
        the tolerances above are recommended instead. In fact, if provided,
        atol, rtol takes precedence over these keywords.

        .. versionchanged:: 1.7.0
            Deprecated in favor of ``rtol`` and ``atol`` parameters above and
            will be removed in future versions of SciPy.

        .. versionchanged:: 1.3.0
            Previously the default cutoff value was just ``eps*f`` where ``f``
            was ``1e3`` for single precision and ``1e6`` for double precision.

    Returns
    -------
    B : (N, M) ndarray
        The pseudo-inverse of matrix `a`.
    rank : int
        The effective rank of the matrix. Returned if `return_rank` is True.

    Raises
    ------
    LinAlgError
        If SVD computation does not converge.
    ValueError
        If `atol` or `rtol` is negative.

    Examples
    --------
    >>> from scipy import linalg
    >>> rng = np.random.default_rng()
    >>> a = rng.standard_normal((9, 6))
    >>> B = linalg.pinv(a)
    >>> np.allclose(a, a @ B @ a)
    True
    >>> np.allclose(B, B @ a @ B)
    True

    """
    a = _asarray_validated(a, check_finite=check_finite)
    u, s, vh = decomp_svd.svd(a, full_matrices=False, check_finite=False)
    real_char = u.dtype.char.lower()
    largest = np.max(s)

    # Deprecated keywords: `rcond` wins over `cond` when both are truthy.
    legacy_tol = rcond or cond
    if legacy_tol:
        warn('Use of the "cond" and "rcond" keywords are deprecated and '
             'will be removed in future versions of SciPy. Use "atol" and '
             '"rtol" keywords instead', DeprecationWarning, stacklevel=2)
        # The legacy value only takes effect when neither new tolerance
        # has been given.
        if atol is None and rtol is None:
            atol, rtol = legacy_tol, 0.

    if atol is None:
        atol = 0.
    if rtol is None:
        rtol = max(a.shape) * np.finfo(real_char).eps

    if atol < 0. or rtol < 0.:
        raise ValueError("atol and rtol values must be positive.")

    cutoff = atol + largest * rtol
    rank = np.sum(s > cutoff)

    # Keep the leading `rank` singular triplets and invert the retained
    # singular values by scaling the columns of U.
    scaled_u = u[:, :rank] / s[:rank]
    B = (scaled_u @ vh[:rank]).conj().T

    return (B, rank) if return_rank else B
1347
1348
def pinv2(a, cond=None, rcond=None, return_rank=False, check_finite=True):
    """
    Compute the (Moore-Penrose) pseudo-inverse of a matrix.

    `scipy.linalg.pinv2` is deprecated since SciPy 1.7.0, use
    `scipy.linalg.pinv` instead for better tolerance control.

    Calculate a generalized inverse of a matrix using its
    singular-value decomposition and including all 'large' singular
    values.

    Parameters
    ----------
    a : (M, N) array_like
        Matrix to be pseudo-inverted.
    cond, rcond : float or None
        Cutoff for 'small' singular values; singular values smaller than this
        value are considered as zero. If both are omitted, the default value
        ``max(M,N)*largest_singular_value*eps`` is used where ``eps`` is the
        machine precision value of the datatype of ``a``.

        .. versionchanged:: 1.3.0
            Previously the default cutoff value was just ``eps*f`` where ``f``
            was ``1e3`` for single precision and ``1e6`` for double precision.

    return_rank : bool, optional
        If True, return the effective rank of the matrix.
    check_finite : bool, optional
        Whether to check that the input matrix contains only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.

    Returns
    -------
    B : (N, M) ndarray
        The pseudo-inverse of matrix `a`.
    rank : int
        The effective rank of the matrix. Returned if `return_rank` is True.

    Raises
    ------
    LinAlgError
        If SVD computation does not converge.

    """
    # SciPy 1.7.0 2021-04-10
    warn('scipy.linalg.pinv2 is deprecated since SciPy 1.7.0, use '
         'scipy.linalg.pinv instead', DeprecationWarning, stacklevel=2)

    # `rcond`, when given, takes priority over `cond`; the chosen value is
    # handed to `pinv` as its absolute tolerance.
    cutoff = cond if rcond is None else rcond
    return pinv(a=a, atol=cutoff, rtol=None, return_rank=return_rank,
                check_finite=check_finite)
1402
1403
def pinvh(a, atol=None, rtol=None, lower=True, return_rank=False,
          check_finite=True, cond=None, rcond=None):
    """
    Compute the (Moore-Penrose) pseudo-inverse of a Hermitian matrix.

    Calculate a generalized inverse of a complex Hermitian/real symmetric
    matrix using its eigenvalue decomposition and including all eigenvalues
    with 'large' absolute value.

    If ``s`` is the largest eigenvalue magnitude of ``a``, eigenvalues whose
    absolute value does not exceed ``atol + rtol * s`` are discarded.

    Parameters
    ----------
    a : (N, N) array_like
        Real symmetric or complex Hermitian matrix to be pseudo-inverted
    atol : float, optional
        Absolute threshold term, default value is 0.

        .. versionadded:: 1.7.0

    rtol : float, optional
        Relative threshold term, default value is ``N * eps`` where
        ``eps`` is the machine precision value of the datatype of ``a``.

        .. versionadded:: 1.7.0

    lower : bool, optional
        Whether the pertinent array data is taken from the lower or upper
        triangle of `a`. (Default: lower)
    return_rank : bool, optional
        If True, return the effective rank of the matrix.
    check_finite : bool, optional
        Whether to check that the input matrix contains only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.
    cond, rcond : float, optional
        In older versions, these values were meant to be used as ``atol`` with
        ``rtol=0``. If both were given ``rcond`` overwrote ``cond`` and hence
        the code was not correct. Thus using these are strongly discouraged and
        the tolerances above are recommended instead.  In fact, if provided,
        atol, rtol takes precedence over these keywords.

        .. versionchanged:: 1.7.0
            Deprecated in favor of ``rtol`` and ``atol`` parameters above and
            will be removed in future versions of SciPy.

        .. versionchanged:: 1.3.0
            Previously the default cutoff value was just ``eps*f`` where ``f``
            was ``1e3`` for single precision and ``1e6`` for double precision.

    Returns
    -------
    B : (N, N) ndarray
        The pseudo-inverse of matrix `a`.
    rank : int
        The effective rank of the matrix.  Returned if `return_rank` is True.

    Raises
    ------
    LinAlgError
        If eigenvalue algorithm does not converge.
    ValueError
        If `atol` or `rtol` is negative.

    Examples
    --------
    >>> from scipy.linalg import pinvh
    >>> rng = np.random.default_rng()
    >>> a = rng.standard_normal((9, 6))
    >>> a = np.dot(a, a.T)
    >>> B = pinvh(a)
    >>> np.allclose(a, a @ B @ a)
    True
    >>> np.allclose(B, B @ a @ B)
    True

    """
    a = _asarray_validated(a, check_finite=check_finite)
    s, u = decomp.eigh(a, lower=lower, check_finite=False)
    real_char = u.dtype.char.lower()
    largest = np.max(np.abs(s))

    # Deprecated keywords: `rcond` wins over `cond` when both are truthy.
    legacy_tol = rcond or cond
    if legacy_tol:
        warn('Use of the "cond" and "rcond" keywords are deprecated and '
             'will be removed in future versions of SciPy. Use "atol" and '
             '"rtol" keywords instead', DeprecationWarning, stacklevel=2)
        # The legacy value only takes effect when neither new tolerance
        # has been given.
        if atol is None and rtol is None:
            atol, rtol = legacy_tol, 0.

    if atol is None:
        atol = 0.
    if rtol is None:
        rtol = max(a.shape) * np.finfo(real_char).eps

    if atol < 0. or rtol < 0.:
        raise ValueError("atol and rtol values must be positive.")

    cutoff = atol + largest * rtol
    keep = abs(s) > cutoff

    # Invert only the retained eigenvalues and rebuild from the matching
    # eigenvectors: B = V diag(1/s) V^H.
    inv_eigs = 1.0 / s[keep]
    vecs = u[:, keep]
    B = (vecs * inv_eigs) @ vecs.conj().T

    return (B, len(inv_eigs)) if return_rank else B
1510
1511
def matrix_balance(A, permute=True, scale=True, separate=False,
                   overwrite_a=False):
    """
    Compute a diagonal similarity transformation for row/column balancing.

    The balancing tries to equalize the row and column 1-norms by applying
    a similarity transformation such that the magnitude variation of the
    matrix entries is reflected to the scaling matrices.

    Moreover, if enabled, the matrix is first permuted to isolate the upper
    triangular parts of the matrix and, again if scaling is also enabled,
    only the remaining subblocks are subjected to scaling.

    The balanced matrix satisfies the following equality

    .. math::

                        B = T^{-1} A T

    The scaling coefficients are approximated to the nearest power of 2
    to avoid round-off errors.

    Parameters
    ----------
    A : (n, n) array_like
        Square data matrix for the balancing.
    permute : bool, optional
        The selector to define whether permutation of A is also performed
        prior to scaling.
    scale : bool, optional
        The selector to turn on and off the scaling. If False, the matrix
        will not be scaled.
    separate : bool, optional
        This switches from returning a full matrix of the transformation
        to a tuple of two separate 1-D permutation and scaling arrays.
    overwrite_a : bool, optional
        This is passed to xGEBAL directly. Essentially, overwrites the result
        to the data. It might increase the space efficiency. See LAPACK manual
        for details. This is False by default.

    Returns
    -------
    B : (n, n) ndarray
        Balanced matrix
    T : (n, n) ndarray
        A possibly permuted diagonal matrix whose nonzero entries are
        integer powers of 2 to avoid numerical truncation errors.
    scale, perm : (n,) ndarray
        If ``separate`` keyword is set to True then instead of the array
        ``T`` above, the scaling and the permutation vectors are given
        separately as a tuple without allocating the full array ``T``.

    Raises
    ------
    ValueError
        If `A` is not square, or if xGEBAL reports an illegal argument.

    Notes
    -----

    This algorithm is particularly useful for eigenvalue and matrix
    decompositions and in many cases it is already called by various
    LAPACK routines.

    The algorithm is based on the well-known technique of [1]_ and has
    been modified to account for special cases. See [2]_ for details
    which have been implemented since LAPACK v3.5.0. Before this version
    there are corner cases where balancing can actually worsen the
    conditioning. See [3]_ for such examples.

    The code is a wrapper around LAPACK's xGEBAL routine family for matrix
    balancing.

    .. versionadded:: 0.19.0

    Examples
    --------
    >>> from scipy import linalg
    >>> x = np.array([[1,2,0], [9,1,0.01], [1,2,10*np.pi]])

    >>> y, permscale = linalg.matrix_balance(x)
    >>> np.abs(x).sum(axis=0) / np.abs(x).sum(axis=1)
    array([ 3.66666667,  0.4995005 ,  0.91312162])

    >>> np.abs(y).sum(axis=0) / np.abs(y).sum(axis=1)
    array([ 1.2       ,  1.27041742,  0.92658316])  # may vary

    >>> permscale  # only powers of 2 (0.5 == 2^(-1))
    array([[  0.5,   0. ,  0. ],  # may vary
           [  0. ,   1. ,  0. ],
           [  0. ,   0. ,  1. ]])

    References
    ----------
    .. [1] B.N. Parlett and C. Reinsch, "Balancing a Matrix for
       Calculation of Eigenvalues and Eigenvectors", Numerische Mathematik,
       Vol.13(4), 1969, :doi:`10.1007/BF02165404`

    .. [2] R. James, J. Langou, B.R. Lowery, "On matrix balancing and
       eigenvector computation", 2014, :arxiv:`1401.5766`

    .. [3] D.S. Watkins. A case where balancing is harmful.
       Electron. Trans. Numer. Anal, Vol.23, 2006.

    """

    A = np.atleast_2d(_asarray_validated(A, check_finite=True))

    rows_match_cols = np.equal(*A.shape)
    if not rows_match_cols:
        raise ValueError('The data matrix for balancing should be square.')

    gebal = get_lapack_funcs('gebal', (A,))
    B, lo, hi, ps, info = gebal(A, scale=scale, permute=permute,
                                overwrite_a=overwrite_a)

    if info < 0:
        raise ValueError('xGEBAL exited with the internal error '
                         '"illegal value in argument number {}.". See '
                         'LAPACK documentation for the xGEBAL error codes.'
                         ''.format(-info))

    # `ps` mixes two kinds of data: entries lo..hi are scaling factors,
    # everything outside that window encodes row/column interchanges.
    scaling = np.ones_like(ps, dtype=float)
    scaling[lo:hi+1] = ps[lo:hi+1]

    # gebal uses 1-indexing
    ps = ps.astype(int, copy=False) - 1
    n = A.shape[0]
    perm = np.arange(n)

    # LAPACK permutes with the ordering n --> hi, then 0--> lo, so the
    # swaps must be replayed in exactly that order.
    if hi < n:
        for offset, target in enumerate(ps[hi+1:][::-1], 1):
            source = n - offset
            if source != target:
                perm[[target, source]] = perm[[source, target]]

    if lo > 0:
        for source, target in enumerate(ps[:lo]):
            if source != target:
                perm[[target, source]] = perm[[source, target]]

    if separate:
        return B, (scaling, perm)

    # get the inverse permutation
    inverse_perm = np.empty_like(perm)
    inverse_perm[perm] = np.arange(n)

    return B, np.diag(scaling)[inverse_perm, :]
1658
1659
1660def _validate_args_for_toeplitz_ops(c_or_cr, b, check_finite, keep_b_shape,
1661                                    enforce_square=True):
1662    """Validate arguments and format inputs for toeplitz functions
1663
1664    Parameters
1665    ----------
1666    c_or_cr : array_like or tuple of (array_like, array_like)
1667        The vector ``c``, or a tuple of arrays (``c``, ``r``). Whatever the
1668        actual shape of ``c``, it will be converted to a 1-D array. If not
1669        supplied, ``r = conjugate(c)`` is assumed; in this case, if c[0] is
1670        real, the Toeplitz matrix is Hermitian. r[0] is ignored; the first row
1671        of the Toeplitz matrix is ``[c[0], r[1:]]``. Whatever the actual shape
1672        of ``r``, it will be converted to a 1-D array.
1673    b : (M,) or (M, K) array_like
1674        Right-hand side in ``T x = b``.
1675    check_finite : bool
1676        Whether to check that the input matrices contain only finite numbers.
1677        Disabling may give a performance gain, but may result in problems
1678        (result entirely NaNs) if the inputs do contain infinities or NaNs.
1679    keep_b_shape: bool
1680        Whether to convert a (M,) dimensional b into a (M, 1) dimensional
1681        matrix.
1682    enforce_square: bool, optional
1683        If True (default), this verifies that the Toeplitz matrix is square.
1684
1685    Returns
1686    -------
1687    r : array
1688        1d array corresponding to the first row of the Toeplitz matrix.
1689    c: array
1690        1d array corresponding to the first column of the Toeplitz matrix.
1691    b: array
1692        (M,), (M, 1) or (M, K) dimensional array, post validation,
1693        corresponding to ``b``.
1694    dtype: numpy datatype
1695        ``dtype`` stores the datatype of ``r``, ``c`` and ``b``. If any of
1696        ``r``, ``c`` or ``b`` are complex, ``dtype`` is ``np.complex128``,
1697        otherwise, it is ``np.float``.
1698    b_shape: tuple
1699        Shape of ``b`` after passing it through ``_asarray_validated``.
1700
1701    """
1702
1703    if isinstance(c_or_cr, tuple):
1704        c, r = c_or_cr
1705        c = _asarray_validated(c, check_finite=check_finite).ravel()
1706        r = _asarray_validated(r, check_finite=check_finite).ravel()
1707    else:
1708        c = _asarray_validated(c_or_cr, check_finite=check_finite).ravel()
1709        r = c.conjugate()
1710
1711    if b is None:
1712        raise ValueError('`b` must be an array, not None.')
1713
1714    b = _asarray_validated(b, check_finite=check_finite)
1715    b_shape = b.shape
1716
1717    is_not_square = r.shape[0] != c.shape[0]
1718    if (enforce_square and is_not_square) or b.shape[0] != r.shape[0]:
1719        raise ValueError('Incompatible dimensions.')
1720
1721    is_cmplx = np.iscomplexobj(r) or np.iscomplexobj(c) or np.iscomplexobj(b)
1722    dtype = np.complex128 if is_cmplx else np.double
1723    r, c, b = (np.asarray(i, dtype=dtype) for i in (r, c, b))
1724
1725    if b.ndim == 1 and not keep_b_shape:
1726        b = b.reshape(-1, 1)
1727    elif b.ndim != 1:
1728        b = b.reshape(b.shape[0], -1)
1729
1730    return r, c, b, dtype, b_shape
1731
1732
def matmul_toeplitz(c_or_cr, x, check_finite=False, workers=None):
    """Multiply a Toeplitz matrix by a dense matrix using the FFT

    Computes the product ``T @ x``, where ``T`` is the Toeplitz matrix
    described by its first column ``c`` and first row ``r``, without ever
    building ``T``.

    A Toeplitz matrix has constant diagonals. When ``r`` is not supplied,
    ``r == conjugate(c)`` is assumed.

    Parameters
    ----------
    c_or_cr : array_like or tuple of (array_like, array_like)
        Either the vector ``c`` alone, or the pair (``c``, ``r``). ``c`` is
        flattened to 1-D regardless of its shape, and so is ``r``. When only
        ``c`` is given, ``r = conjugate(c)`` is used; if in addition c[0] is
        real, the Toeplitz matrix is Hermitian. r[0] is ignored; the first
        row of the Toeplitz matrix is ``[c[0], r[1:]]``.
    x : (M,) or (M, K) array_like
        Matrix with which to multiply.
    check_finite : bool, optional
        Whether to check that the input matrices contain only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (result entirely NaNs) if the inputs do contain infinities or NaNs.
        Default is False.
    workers : int, optional
        Forwarded to the scipy.fft transforms. Maximum number of workers to
        use for parallel computation. If negative, the value wraps around
        from ``os.cpu_count()``. See scipy.fft.fft for more details.

    Returns
    -------
    T @ x : (M,) or (M, K) ndarray
        The result of the matrix multiplication ``T @ x``. Shape of return
        matches shape of `x`.

    See Also
    --------
    toeplitz : Toeplitz matrix
    solve_toeplitz : Solve a Toeplitz system using Levinson Recursion

    Notes
    -----
    The Toeplitz matrix is embedded in a circulant matrix, whose product
    with ``x`` is a circular convolution and can therefore be evaluated
    efficiently with the FFT.

    Since the FFT is used, integer inputs produce floating point outputs.
    NumPy's `matmul`, by contrast, preserves the data type of the input.

    The implementation is partly based on the one found in [1]_, licensed
    under the MIT license. The method itself is described in reference [2]_.
    Further reference implementations in Python are given in [3]_ and [4]_.

    .. versionadded:: 1.6.0

    References
    ----------
    .. [1] Jacob R Gardner, Geoff Pleiss, David Bindel, Kilian
       Q Weinberger, Andrew Gordon Wilson, "GPyTorch: Blackbox Matrix-Matrix
       Gaussian Process Inference with GPU Acceleration" with contributions
       from Max Balandat and Ruihan Wu. Available online:
       https://github.com/cornellius-gp/gpytorch

    .. [2] J. Demmel, P. Koev, and X. Li, "A Brief Survey of Direct Linear
       Solvers". In Z. Bai, J. Demmel, J. Dongarra, A. Ruhe, and H. van der
       Vorst, editors. Templates for the Solution of Algebraic Eigenvalue
       Problems: A Practical Guide. SIAM, Philadelphia, 2000. Available at:
       http://www.netlib.org/utk/people/JackDongarra/etemplates/node384.html

    .. [3] R. Scheibler, E. Bezzam, I. Dokmanic, Pyroomacoustics: A Python
       package for audio room simulations and array processing algorithms,
       Proc. IEEE ICASSP, Calgary, CA, 2018.
       https://github.com/LCAV/pyroomacoustics/blob/pypi-release/
       pyroomacoustics/adaptive/util.py

    .. [4] Marano S, Edwards B, Ferrari G and Fah D (2017), "Fitting
       Earthquake Spectra: Colored Noise and Incomplete Data", Bulletin of
       the Seismological Society of America., January, 2017. Vol. 107(1),
       pp. 276-291.

    Examples
    --------
    Multiply the Toeplitz matrix T with matrix x::

            [ 1 -1 -2 -3]       [1 10]
        T = [ 3  1 -1 -2]   x = [2 11]
            [ 6  3  1 -1]       [2 11]
            [10  6  3  1]       [5 19]

    Only the first column and the first row are required to describe the
    Toeplitz matrix.

    >>> c = np.array([1, 3, 6, 10])    # First column of T
    >>> r = np.array([1, -1, -2, -3])  # First row of T
    >>> x = np.array([[1, 10], [2, 11], [2, 11], [5, 19]])

    >>> from scipy.linalg import toeplitz, matmul_toeplitz
    >>> matmul_toeplitz((c, r), x)
    array([[-20., -80.],
           [ -7.,  -8.],
           [  9.,  85.],
           [ 33., 218.]])

    Verify the result against the product with the explicitly constructed
    Toeplitz matrix.

    >>> toeplitz(c, r) @ x
    array([[-20, -80],
           [ -7,  -8],
           [  9,  85],
           [ 33, 218]])

    Because the full matrix is never formed, this routine can handle very
    large Toeplitz matrices.

    >>> n = 1000000
    >>> matmul_toeplitz([1] + [0]*(n-1), np.ones(n))
    array([1., 1., 1., ..., 1., 1., 1.])

    """

    from ..fft import fft, ifft, rfft, irfft

    row, col, rhs, _, rhs_orig_shape = _validate_args_for_toeplitz_ops(
        c_or_cr, x, check_finite, keep_b_shape=False, enforce_square=False)
    _, n_rhs_cols = rhs.shape

    n_rows = len(col)
    n_cols = len(row)
    # Length of the first column of the circulant embedding built below.
    circ_len = n_rows + n_cols - 1

    # First column of the circulant matrix: ``c`` followed by ``r`` reversed,
    # leaving out r[0].
    circ_col = np.concatenate((col, row[-1:0:-1]))

    if np.iscomplexobj(circ_col) or np.iscomplexobj(rhs):
        forward, backward = fft, ifft
        backward_extra = {}
    else:
        # Real inputs; using rfft is faster. The inverse real transform
        # needs the output length spelled out.
        forward, backward = rfft, irfft
        backward_extra = {'n': circ_len}

    circ_hat = forward(circ_col, axis=0, workers=workers).reshape(-1, 1)
    rhs_hat = forward(rhs, n=circ_len, axis=0, workers=workers)

    # Only the leading rows of the circulant product belong to ``T @ x``.
    product = backward(circ_hat*rhs_hat, axis=0, workers=workers,
                       **backward_extra)[:n_rows, :]

    if len(rhs_orig_shape) == 1:
        return product.reshape(n_rows)
    return product.reshape(n_rows, n_rhs_cols)
1884