1"""
2Implementation of linear algebra operations.
3"""
4
5
6import contextlib
7
8from llvmlite import ir
9
10import numpy as np
11import operator
12
13from numba.core.imputils import (lower_builtin, impl_ret_borrowed,
14                                    impl_ret_new_ref, impl_ret_untracked)
15from numba.core.typing import signature
16from numba.core.extending import overload, register_jitable
17from numba.core import types, cgutils
18from numba.core.errors import TypingError
19from .arrayobj import make_array, _empty_nd_impl, array_copy
20from numba.np import numpy_support as np_support
21
22ll_char = ir.IntType(8)
23ll_char_p = ll_char.as_pointer()
24ll_void_p = ll_char_p
25ll_intc = ir.IntType(32)
26ll_intc_p = ll_intc.as_pointer()
27intp_t = cgutils.intp_t
28ll_intp_p = intp_t.as_pointer()


# Fortran int type. This needs to match the F_INT C declaration in
# _lapack.c and is present to accommodate potential future 64-bit int
# based LAPACK use.
F_INT_nptype = np.int32
F_INT_nbtype = types.int32

# BLAS kinds as letters
_blas_kinds = {
    types.float32: 's',
    types.float64: 'd',
    types.complex64: 'c',
    types.complex128: 'z',
}


def get_blas_kind(dtype, func_name="<BLAS function>"):
    kind = _blas_kinds.get(dtype)
    if kind is None:
        raise TypeError("unsupported dtype for %s()" % (func_name,))
    return kind
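

# Illustrative mapping only: get_blas_kind() turns a Numba scalar type into
# the single-letter precision prefix understood by the BLAS/LAPACK wrappers
# in Numba's C helper modules, e.g.
#
#     get_blas_kind(types.float32)    -> 's'
#     get_blas_kind(types.complex128) -> 'z'
#     get_blas_kind(types.int64)      -> raises TypeError (unsupported dtype)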


def ensure_blas():
    try:
        import scipy.linalg.cython_blas
    except ImportError:
        raise ImportError("scipy 0.16+ is required for linear algebra")


def ensure_lapack():
    try:
        import scipy.linalg.cython_lapack
    except ImportError:
        raise ImportError("scipy 0.16+ is required for linear algebra")


def make_constant_slot(context, builder, ty, val):
    const = context.get_constant_generic(builder, ty, val)
    return cgutils.alloca_once_value(builder, const)


class _BLAS:
    """
    Functions to return type signatures for wrapped
    BLAS functions.
    """

    def __init__(self):
        ensure_blas()

    @classmethod
    def numba_xxnrm2(cls, dtype):
        rtype = getattr(dtype, "underlying_float", dtype)
        sig = types.intc(types.char,             # kind
                         types.intp,             # n
                         types.CPointer(dtype),  # x
                         types.intp,             # incx
                         types.CPointer(rtype))  # returned

        return types.ExternalFunction("numba_xxnrm2", sig)

    @classmethod
    def numba_xxgemm(cls, dtype):
        sig = types.intc(
            types.char,             # kind
            types.char,             # transa
            types.char,             # transb
            types.intp,             # m
            types.intp,             # n
            types.intp,             # k
            types.CPointer(dtype),  # alpha
            types.CPointer(dtype),  # a
            types.intp,             # lda
            types.CPointer(dtype),  # b
            types.intp,             # ldb
            types.CPointer(dtype),  # beta
            types.CPointer(dtype),  # c
            types.intp              # ldc
        )
        return types.ExternalFunction("numba_xxgemm", sig)


class _LAPACK:
    """
    Functions to return type signatures for wrapped
    LAPACK functions.
    """

    def __init__(self):
        ensure_lapack()

    @classmethod
    def numba_xxgetrf(cls, dtype):
        sig = types.intc(types.char,                   # kind
                         types.intp,                   # m
                         types.intp,                   # n
                         types.CPointer(dtype),        # a
                         types.intp,                   # lda
                         types.CPointer(F_INT_nbtype)  # ipiv
                         )
        return types.ExternalFunction("numba_xxgetrf", sig)

    @classmethod
    def numba_ez_xxgetri(cls, dtype):
        sig = types.intc(types.char,                   # kind
                         types.intp,                   # n
                         types.CPointer(dtype),        # a
                         types.intp,                   # lda
                         types.CPointer(F_INT_nbtype)  # ipiv
                         )
        return types.ExternalFunction("numba_ez_xxgetri", sig)

    @classmethod
    def numba_ez_rgeev(cls, dtype):
        sig = types.intc(types.char,             # kind
                         types.char,             # jobvl
                         types.char,             # jobvr
                         types.intp,             # n
                         types.CPointer(dtype),  # a
                         types.intp,             # lda
                         types.CPointer(dtype),  # wr
                         types.CPointer(dtype),  # wi
                         types.CPointer(dtype),  # vl
                         types.intp,             # ldvl
                         types.CPointer(dtype),  # vr
                         types.intp              # ldvr
                         )
        return types.ExternalFunction("numba_ez_rgeev", sig)

    @classmethod
    def numba_ez_cgeev(cls, dtype):
        sig = types.intc(types.char,             # kind
                         types.char,             # jobvl
                         types.char,             # jobvr
                         types.intp,             # n
                         types.CPointer(dtype),  # a
                         types.intp,             # lda
                         types.CPointer(dtype),  # w
                         types.CPointer(dtype),  # vl
                         types.intp,             # ldvl
                         types.CPointer(dtype),  # vr
                         types.intp              # ldvr
                         )
        return types.ExternalFunction("numba_ez_cgeev", sig)

    @classmethod
    def numba_ez_xxxevd(cls, dtype):
        wtype = getattr(dtype, "underlying_float", dtype)
        sig = types.intc(types.char,             # kind
                         types.char,             # jobz
                         types.char,             # uplo
                         types.intp,             # n
                         types.CPointer(dtype),  # a
                         types.intp,             # lda
                         types.CPointer(wtype),  # w
                         )
        return types.ExternalFunction("numba_ez_xxxevd", sig)

    @classmethod
    def numba_xxpotrf(cls, dtype):
        sig = types.intc(types.char,             # kind
                         types.char,             # uplo
                         types.intp,             # n
                         types.CPointer(dtype),  # a
                         types.intp              # lda
                         )
        return types.ExternalFunction("numba_xxpotrf", sig)

    @classmethod
    def numba_ez_gesdd(cls, dtype):
        stype = getattr(dtype, "underlying_float", dtype)
        sig = types.intc(
            types.char,             # kind
            types.char,             # jobz
            types.intp,             # m
            types.intp,             # n
            types.CPointer(dtype),  # a
            types.intp,             # lda
            types.CPointer(stype),  # s
            types.CPointer(dtype),  # u
            types.intp,             # ldu
            types.CPointer(dtype),  # vt
            types.intp              # ldvt
        )

        return types.ExternalFunction("numba_ez_gesdd", sig)

    @classmethod
    def numba_ez_geqrf(cls, dtype):
        sig = types.intc(
            types.char,             # kind
            types.intp,             # m
            types.intp,             # n
            types.CPointer(dtype),  # a
            types.intp,             # lda
            types.CPointer(dtype),  # tau
        )
        return types.ExternalFunction("numba_ez_geqrf", sig)

    @classmethod
    def numba_ez_xxgqr(cls, dtype):
        sig = types.intc(
            types.char,             # kind
            types.intp,             # m
            types.intp,             # n
            types.intp,             # k
            types.CPointer(dtype),  # a
            types.intp,             # lda
            types.CPointer(dtype),  # tau
        )
        return types.ExternalFunction("numba_ez_xxgqr", sig)

    @classmethod
    def numba_ez_gelsd(cls, dtype):
        rtype = getattr(dtype, "underlying_float", dtype)
        sig = types.intc(
            types.char,                 # kind
            types.intp,                 # m
            types.intp,                 # n
            types.intp,                 # nrhs
            types.CPointer(dtype),      # a
            types.intp,                 # lda
            types.CPointer(dtype),      # b
            types.intp,                 # ldb
            types.CPointer(rtype),      # S
            types.float64,              # rcond
            types.CPointer(types.intc)  # rank
        )
        return types.ExternalFunction("numba_ez_gelsd", sig)

    @classmethod
    def numba_xgesv(cls, dtype):
        sig = types.intc(
            types.char,                    # kind
            types.intp,                    # n
            types.intp,                    # nrhs
            types.CPointer(dtype),         # a
            types.intp,                    # lda
            types.CPointer(F_INT_nbtype),  # ipiv
            types.CPointer(dtype),         # b
            types.intp                     # ldb
        )
        return types.ExternalFunction("numba_xgesv", sig)
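

# Note (illustrative): the _BLAS/_LAPACK classmethods above only build typed
# handles (types.ExternalFunction) for the C wrappers shipped with Numba;
# they do no numerical work themselves.  A typical use inside an @overload
# implementation lower down in this file looks like:
#
#     numba_xxgetrf = _LAPACK().numba_xxgetrf(a.dtype)
#     kind = ord(get_blas_kind(a.dtype, "inv"))
#     ...
#     r = numba_xxgetrf(kind, n, n, acpy.ctypes, n, ipiv.ctypes)
#
# where the call compiles to a direct call of the named external symbol
# (see inv_impl() for the real usage).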


@contextlib.contextmanager
def make_contiguous(context, builder, sig, args):
    """
    Ensure that all array arguments are contiguous, if necessary by
    copying them.
    A new (sig, args) tuple is yielded.
    """
    newtys = []
    newargs = []
    copies = []
    for ty, val in zip(sig.args, args):
        if not isinstance(ty, types.Array) or ty.layout in 'CF':
            newty, newval = ty, val
        else:
            newty = ty.copy(layout='C')
            copysig = signature(newty, ty)
            newval = array_copy(context, builder, copysig, (val,))
            copies.append((newty, newval))
        newtys.append(newty)
        newargs.append(newval)
    yield signature(sig.return_type, *newtys), tuple(newargs)
    for ty, val in copies:
        context.nrt.decref(builder, ty, val)
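

# Usage sketch (illustrative): the np.dot lowerings below wrap their work in
# this context manager so that the BLAS wrappers only ever see C or F
# contiguous operands:
#
#     with make_contiguous(context, builder, sig, args) as (sig, args):
#         ...  # call the BLAS-backed implementation with the new sig/args
#
# Temporary copies made for non-contiguous ('A' layout) arguments are
# decref'ed when the with-block exits.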


def check_c_int(context, builder, n):
    """
    Check whether *n* fits in a C `int`.
    """
    _maxint = 2**31 - 1

    def impl(n):
        if n > _maxint:
            raise OverflowError("array size too large to fit in C int")

    context.compile_internal(builder, impl,
                             signature(types.none, types.intp), (n,))


def check_blas_return(context, builder, res):
    """
    Check the integer error return from one of the BLAS wrappers in
    _helperlib.c.
    """
    with builder.if_then(cgutils.is_not_null(builder, res), likely=False):
        # Those errors shouldn't happen, it's easier to just abort the process
        pyapi = context.get_python_api(builder)
        pyapi.gil_ensure()
        pyapi.fatal_error("BLAS wrapper returned with an error")


def check_lapack_return(context, builder, res):
    """
    Check the integer error return from one of the LAPACK wrappers in
    _helperlib.c.
    """
    with builder.if_then(cgutils.is_not_null(builder, res), likely=False):
        # Those errors shouldn't happen, it's easier to just abort the process
        pyapi = context.get_python_api(builder)
        pyapi.gil_ensure()
        pyapi.fatal_error("LAPACK wrapper returned with an error")


def call_xxdot(context, builder, conjugate, dtype,
               n, a_data, b_data, out_data):
    """
    Call the BLAS vector * vector product function for the given arguments.
    """
    fnty = ir.FunctionType(ir.IntType(32),
                           [ll_char, ll_char, intp_t,    # kind, conjugate, n
                            ll_void_p, ll_void_p, ll_void_p,  # a, b, out
                            ])
    fn = builder.module.get_or_insert_function(fnty, name="numba_xxdot")

    kind = get_blas_kind(dtype)
    kind_val = ir.Constant(ll_char, ord(kind))
    conjugate = ir.Constant(ll_char, int(conjugate))

    res = builder.call(fn, (kind_val, conjugate, n,
                            builder.bitcast(a_data, ll_void_p),
                            builder.bitcast(b_data, ll_void_p),
                            builder.bitcast(out_data, ll_void_p)))
    check_blas_return(context, builder, res)


def call_xxgemv(context, builder, do_trans,
                m_type, m_shapes, m_data, v_data, out_data):
    """
    Call the BLAS matrix * vector product function for the given arguments.
    """
    fnty = ir.FunctionType(ir.IntType(32),
                           [ll_char, ll_char,                 # kind, trans
                            intp_t, intp_t,                   # m, n
                            ll_void_p, ll_void_p, intp_t,     # alpha, a, lda
                            ll_void_p, ll_void_p, ll_void_p,  # x, beta, y
                            ])
    fn = builder.module.get_or_insert_function(fnty, name="numba_xxgemv")

    dtype = m_type.dtype
    alpha = make_constant_slot(context, builder, dtype, 1.0)
    beta = make_constant_slot(context, builder, dtype, 0.0)

    if m_type.layout == 'F':
        m, n = m_shapes
        lda = m_shapes[0]
    else:
        n, m = m_shapes
        lda = m_shapes[1]

    kind = get_blas_kind(dtype)
    kind_val = ir.Constant(ll_char, ord(kind))
    trans = ir.Constant(ll_char, ord('t') if do_trans else ord('n'))

    res = builder.call(fn, (kind_val, trans, m, n,
                            builder.bitcast(alpha, ll_void_p),
                            builder.bitcast(m_data, ll_void_p), lda,
                            builder.bitcast(v_data, ll_void_p),
                            builder.bitcast(beta, ll_void_p),
                            builder.bitcast(out_data, ll_void_p)))
    check_blas_return(context, builder, res)


def call_xxgemm(context, builder,
                x_type, x_shapes, x_data,
                y_type, y_shapes, y_data,
                out_type, out_shapes, out_data):
    """
    Call the BLAS matrix * matrix product function for the given arguments.
    """
    fnty = ir.FunctionType(ir.IntType(32),
                           [ll_char,                       # kind
                            ll_char, ll_char,              # transa, transb
                            intp_t, intp_t, intp_t,        # m, n, k
                            ll_void_p, ll_void_p, intp_t,  # alpha, a, lda
                            ll_void_p, intp_t, ll_void_p,  # b, ldb, beta
                            ll_void_p, intp_t,             # c, ldc
                            ])
    fn = builder.module.get_or_insert_function(fnty, name="numba_xxgemm")

    m, k = x_shapes
    _k, n = y_shapes
    dtype = x_type.dtype
    alpha = make_constant_slot(context, builder, dtype, 1.0)
    beta = make_constant_slot(context, builder, dtype, 0.0)

    trans = ir.Constant(ll_char, ord('t'))
    notrans = ir.Constant(ll_char, ord('n'))

    def get_array_param(ty, shapes, data):
        return (
            # Transpose if layout different from result's
            notrans if ty.layout == out_type.layout else trans,
            # Size of the inner dimension in physical array order
            shapes[1] if ty.layout == 'C' else shapes[0],
            # The data pointer, unit-less
            builder.bitcast(data, ll_void_p),
        )

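    # Layout note: the C wrapper forwards to Fortran (column-major) GEMM while
    # np.dot() result arrays are C (row-major) ordered.  A row-major
    # C = A * B is, viewed column-major, C^T = B^T * A^T, so below the roles
    # of x and y are swapped ("a" <- y, "b" <- x), the dimensions are passed
    # as (n, m, k), and get_array_param() requests a transpose for any operand
    # whose layout differs from the output's.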
    transa, lda, data_a = get_array_param(y_type, y_shapes, y_data)
    transb, ldb, data_b = get_array_param(x_type, x_shapes, x_data)
    _, ldc, data_c = get_array_param(out_type, out_shapes, out_data)

    kind = get_blas_kind(dtype)
    kind_val = ir.Constant(ll_char, ord(kind))

    res = builder.call(fn, (kind_val, transa, transb, n, m, k,
                            builder.bitcast(alpha, ll_void_p), data_a, lda,
                            data_b, ldb, builder.bitcast(beta, ll_void_p),
                            data_c, ldc))
    check_blas_return(context, builder, res)


def dot_2_mm(context, builder, sig, args):
    """
    np.dot(matrix, matrix)
    """
    def dot_impl(a, b):
        m, k = a.shape
        _k, n = b.shape
        if k == 0:
            return np.zeros((m, n), a.dtype)
        out = np.empty((m, n), a.dtype)
        return np.dot(a, b, out)

    res = context.compile_internal(builder, dot_impl, sig, args)
    return impl_ret_new_ref(context, builder, sig.return_type, res)


def dot_2_vm(context, builder, sig, args):
    """
    np.dot(vector, matrix)
    """
    def dot_impl(a, b):
        m, = a.shape
        _m, n = b.shape
        if m == 0:
            return np.zeros((n, ), a.dtype)
        out = np.empty((n, ), a.dtype)
        return np.dot(a, b, out)

    res = context.compile_internal(builder, dot_impl, sig, args)
    return impl_ret_new_ref(context, builder, sig.return_type, res)


def dot_2_mv(context, builder, sig, args):
    """
    np.dot(matrix, vector)
    """
    def dot_impl(a, b):
        m, n = a.shape
        _n, = b.shape
        if n == 0:
            return np.zeros((m, ), a.dtype)
        out = np.empty((m, ), a.dtype)
        return np.dot(a, b, out)

    res = context.compile_internal(builder, dot_impl, sig, args)
    return impl_ret_new_ref(context, builder, sig.return_type, res)


def dot_2_vv(context, builder, sig, args, conjugate=False):
    """
    np.dot(vector, vector)
    np.vdot(vector, vector)
    """
    aty, bty = sig.args
    dtype = sig.return_type
    a = make_array(aty)(context, builder, args[0])
    b = make_array(bty)(context, builder, args[1])
    n, = cgutils.unpack_tuple(builder, a.shape)

    def check_args(a, b):
        m, = a.shape
        n, = b.shape
        if m != n:
            raise ValueError("incompatible array sizes for np.dot(a, b) "
                             "(vector * vector)")

    context.compile_internal(builder, check_args,
                             signature(types.none, *sig.args), args)
    check_c_int(context, builder, n)

    out = cgutils.alloca_once(builder, context.get_value_type(dtype))
    call_xxdot(context, builder, conjugate, dtype, n, a.data, b.data, out)
    return builder.load(out)


@lower_builtin(np.dot, types.Array, types.Array)
def dot_2(context, builder, sig, args):
    """
    np.dot(a, b)
    a @ b
    """
    ensure_blas()

    with make_contiguous(context, builder, sig, args) as (sig, args):
        ndims = [x.ndim for x in sig.args[:2]]
        if ndims == [2, 2]:
            return dot_2_mm(context, builder, sig, args)
        elif ndims == [2, 1]:
            return dot_2_mv(context, builder, sig, args)
        elif ndims == [1, 2]:
            return dot_2_vm(context, builder, sig, args)
        elif ndims == [1, 1]:
            return dot_2_vv(context, builder, sig, args)
        else:
            assert 0


lower_builtin(operator.matmul, types.Array, types.Array)(dot_2)
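

# Both np.dot(a, b) and the matmul operator on 1-D/2-D arrays lower through
# dot_2() above, e.g. (illustrative):
#
#     @numba.njit
#     def f(a, b):
#         return a @ b          # same lowering as np.dot(a, b)
#
# make_contiguous() inserts copies for 'A'-layout operands so the BLAS
# wrappers only ever see contiguous data.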


@lower_builtin(np.vdot, types.Array, types.Array)
def vdot(context, builder, sig, args):
    """
    np.vdot(a, b)
    """
    ensure_blas()

    with make_contiguous(context, builder, sig, args) as (sig, args):
        return dot_2_vv(context, builder, sig, args, conjugate=True)


def dot_3_vm_check_args(a, b, out):
    m, = a.shape
    _m, n = b.shape
    if m != _m:
        raise ValueError("incompatible array sizes for "
                         "np.dot(a, b) (vector * matrix)")
    if out.shape != (n,):
        raise ValueError("incompatible output array size for "
                         "np.dot(a, b, out) (vector * matrix)")


def dot_3_mv_check_args(a, b, out):
    m, _n = a.shape
    n, = b.shape
    if n != _n:
        raise ValueError("incompatible array sizes for np.dot(a, b) "
                         "(matrix * vector)")
    if out.shape != (m,):
        raise ValueError("incompatible output array size for "
                         "np.dot(a, b, out) (matrix * vector)")


def dot_3_vm(context, builder, sig, args):
    """
    np.dot(vector, matrix, out)
    np.dot(matrix, vector, out)
    """
    xty, yty, outty = sig.args
    assert outty == sig.return_type
    dtype = xty.dtype

    x = make_array(xty)(context, builder, args[0])
    y = make_array(yty)(context, builder, args[1])
    out = make_array(outty)(context, builder, args[2])
    x_shapes = cgutils.unpack_tuple(builder, x.shape)
    y_shapes = cgutils.unpack_tuple(builder, y.shape)
    out_shapes = cgutils.unpack_tuple(builder, out.shape)
    if xty.ndim < yty.ndim:
        # Vector * matrix
        # Asked for x * y, we will compute y.T * x
        mty = yty
        m_shapes = y_shapes
        v_shape = x_shapes[0]
        lda = m_shapes[1]
        do_trans = yty.layout == 'F'
        m_data, v_data = y.data, x.data
        check_args = dot_3_vm_check_args
    else:
        # Matrix * vector
        # We will compute x * y
        mty = xty
        m_shapes = x_shapes
        v_shape = y_shapes[0]
        lda = m_shapes[0]
        do_trans = xty.layout == 'C'
        m_data, v_data = x.data, y.data
        check_args = dot_3_mv_check_args

    context.compile_internal(builder, check_args,
                             signature(types.none, *sig.args), args)
    for val in m_shapes:
        check_c_int(context, builder, val)

    zero = context.get_constant(types.intp, 0)
    both_empty = builder.icmp_signed('==', v_shape, zero)
    matrix_empty = builder.icmp_signed('==', lda, zero)
    is_empty = builder.or_(both_empty, matrix_empty)
    with builder.if_else(is_empty, likely=False) as (empty, nonempty):
        with empty:
            cgutils.memset(builder, out.data,
                           builder.mul(out.itemsize, out.nitems), 0)
        with nonempty:
            call_xxgemv(context, builder, do_trans, mty, m_shapes, m_data,
                        v_data, out.data)

    return impl_ret_borrowed(context, builder, sig.return_type,
                             out._getvalue())


def dot_3_mm(context, builder, sig, args):
    """
    np.dot(matrix, matrix, out)
    """
    xty, yty, outty = sig.args
    assert outty == sig.return_type
    dtype = xty.dtype

    x = make_array(xty)(context, builder, args[0])
    y = make_array(yty)(context, builder, args[1])
    out = make_array(outty)(context, builder, args[2])
    x_shapes = cgutils.unpack_tuple(builder, x.shape)
    y_shapes = cgutils.unpack_tuple(builder, y.shape)
    out_shapes = cgutils.unpack_tuple(builder, out.shape)
    m, k = x_shapes
    _k, n = y_shapes

    # The only case Numpy supports
    assert outty.layout == 'C'

    def check_args(a, b, out):
        m, k = a.shape
        _k, n = b.shape
        if k != _k:
            raise ValueError("incompatible array sizes for np.dot(a, b) "
                             "(matrix * matrix)")
        if out.shape != (m, n):
            raise ValueError("incompatible output array size for "
                             "np.dot(a, b, out) (matrix * matrix)")

    context.compile_internal(builder, check_args,
                             signature(types.none, *sig.args), args)

    check_c_int(context, builder, m)
    check_c_int(context, builder, k)
    check_c_int(context, builder, n)

    x_data = x.data
    y_data = y.data
    out_data = out.data

    # If eliminated dimension is zero, set all entries to zero and return
    zero = context.get_constant(types.intp, 0)
    both_empty = builder.icmp_signed('==', k, zero)
    x_empty = builder.icmp_signed('==', m, zero)
    y_empty = builder.icmp_signed('==', n, zero)
    is_empty = builder.or_(both_empty, builder.or_(x_empty, y_empty))
    with builder.if_else(is_empty, likely=False) as (empty, nonempty):
        with empty:
            cgutils.memset(builder, out.data,
                           builder.mul(out.itemsize, out.nitems), 0)
        with nonempty:
            # Check if any of the operands is really a 1-d vector represented
            # as a (1, k) or (k, 1) 2-d array.  In those cases, it is pessimal
            # to call the generic matrix * matrix product BLAS function.
            one = context.get_constant(types.intp, 1)
            is_left_vec = builder.icmp_signed('==', m, one)
            is_right_vec = builder.icmp_signed('==', n, one)

            with builder.if_else(is_right_vec) as (r_vec, r_mat):
                with r_vec:
                    with builder.if_else(is_left_vec) as (v_v, m_v):
                        with v_v:
                            # V * V
                            call_xxdot(context, builder, False, dtype,
                                       k, x_data, y_data, out_data)
                        with m_v:
                            # M * V
                            do_trans = xty.layout == outty.layout
                            call_xxgemv(context, builder, do_trans,
                                        xty, x_shapes, x_data, y_data,
                                        out_data)
                with r_mat:
                    with builder.if_else(is_left_vec) as (v_m, m_m):
                        with v_m:
                            # V * M
                            do_trans = yty.layout != outty.layout
                            call_xxgemv(context, builder, do_trans,
                                        yty, y_shapes, y_data, x_data,
                                        out_data)
                        with m_m:
                            # M * M
                            call_xxgemm(context, builder,
                                        xty, x_shapes, x_data,
                                        yty, y_shapes, y_data,
                                        outty, out_shapes, out_data)

    return impl_ret_borrowed(context, builder, sig.return_type,
                             out._getvalue())


@lower_builtin(np.dot, types.Array, types.Array,
               types.Array)
def dot_3(context, builder, sig, args):
    """
    np.dot(a, b, out)
    """
    ensure_blas()

    with make_contiguous(context, builder, sig, args) as (sig, args):
        ndims = set(x.ndim for x in sig.args[:2])
        if ndims == set([2]):
            return dot_3_mm(context, builder, sig, args)
        elif ndims == set([1, 2]):
            return dot_3_vm(context, builder, sig, args)
        else:
            assert 0


fatal_error_sig = types.intc()
fatal_error_func = types.ExternalFunction("numba_fatal_error", fatal_error_sig)


@register_jitable
def _check_finite_matrix(a):
    for v in np.nditer(a):
        if not np.isfinite(v.item()):
            raise np.linalg.LinAlgError(
                "Array must not contain infs or NaNs.")


def _check_linalg_matrix(a, func_name, la_prefix=True):
    # la_prefix is present as some functions, e.g. np.trace(),
    # are documented under "linear algebra" but aren't in the
    # np.linalg module
    prefix = "np.linalg" if la_prefix else "np"
    interp = (prefix, func_name)
    # Unpack optional type
    if isinstance(a, types.Optional):
        a = a.type
    if not isinstance(a, types.Array):
        msg = "%s.%s() only supported for array types" % interp
        raise TypingError(msg, highlighting=False)
    if not a.ndim == 2:
        msg = "%s.%s() only supported on 2-D arrays." % interp
        raise TypingError(msg, highlighting=False)
    if not isinstance(a.dtype, (types.Float, types.Complex)):
        msg = "%s.%s() only supported on "\
            "float and complex arrays." % interp
        raise TypingError(msg, highlighting=False)


def _check_homogeneous_types(func_name, *types):
    t0 = types[0].dtype
    for t in types[1:]:
        if t.dtype != t0:
            msg = ("np.linalg.%s() only supports inputs that have "
                   "homogeneous dtypes." % func_name)
            raise TypingError(msg, highlighting=False)


def _copy_to_fortran_order():
    pass


@overload(_copy_to_fortran_order)
def ol_copy_to_fortran_order(a):
    # This function copies the array 'a' into a new array with fortran order.
    # This exists because the copy routines don't take order flags yet.
    F_layout = a.layout == 'F'
    A_layout = a.layout == 'A'

    def impl(a):
        if F_layout:
            # it's F ordered at compile time, just copy
            acpy = np.copy(a)
        elif A_layout:
            # decide based on runtime value
            flag_f = a.flags.f_contiguous
            if flag_f:
                # it's already F ordered, so copy but in a round about way to
                # ensure that the copy is also F ordered
                acpy = np.copy(a.T).T
            else:
                # it's something else ordered, so let asfortranarray deal with
                # copying and making it fortran ordered
                acpy = np.asfortranarray(a)
        else:
            # it's C ordered at compile time, asfortranarray it.
            acpy = np.asfortranarray(a)
        return acpy
    return impl
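

# Illustrative behaviour: from jitted code, _copy_to_fortran_order(a) always
# returns an F-contiguous copy with the same shape and values as 'a',
# whatever the compile-time layout ('C', 'F' or 'A') of 'a'.  The
# LAPACK-backed implementations below rely on this because the wrapped
# routines expect Fortran-ordered, writable scratch arrays.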


@register_jitable
def _inv_err_handler(r):
    if r != 0:
        if r < 0:
            fatal_error_func()
            assert 0   # unreachable
        if r > 0:
            raise np.linalg.LinAlgError(
                "Matrix is singular to machine precision.")


@register_jitable
def _dummy_liveness_func(a):
    """pass a list of variables to be preserved through dead code elimination"""
    return a[0]


@overload(np.linalg.inv)
def inv_impl(a):
    ensure_lapack()

    _check_linalg_matrix(a, "inv")

    numba_xxgetrf = _LAPACK().numba_xxgetrf(a.dtype)

    numba_xxgetri = _LAPACK().numba_ez_xxgetri(a.dtype)

    kind = ord(get_blas_kind(a.dtype, "inv"))

    def inv_impl(a):
        n = a.shape[-1]
        if a.shape[-2] != n:
            msg = "Last 2 dimensions of the array must be square."
            raise np.linalg.LinAlgError(msg)

        _check_finite_matrix(a)

        acpy = _copy_to_fortran_order(a)

        if n == 0:
            return acpy

        ipiv = np.empty(n, dtype=F_INT_nptype)

        r = numba_xxgetrf(kind, n, n, acpy.ctypes, n, ipiv.ctypes)
        _inv_err_handler(r)

        r = numba_xxgetri(kind, n, acpy.ctypes, n, ipiv.ctypes)
        _inv_err_handler(r)

        # help liveness analysis
        _dummy_liveness_func([acpy.size, ipiv.size])
        return acpy

    return inv_impl
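

# Example usage (illustrative):
#
#     @numba.njit
#     def f(a):
#         return np.linalg.inv(a)
#
# mirrors NumPy for square float/complex 2-D inputs; non-square or singular
# inputs raise LinAlgError at runtime, and unsupported argument types raise
# TypingError at compile time (see _check_linalg_matrix above).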


@register_jitable
def _handle_err_maybe_convergence_problem(r):
    if r != 0:
        if r < 0:
            fatal_error_func()
            assert 0   # unreachable
        if r > 0:
            raise ValueError("Internal algorithm failed to converge.")


def _check_linalg_1_or_2d_matrix(a, func_name, la_prefix=True):
    # la_prefix is present as some functions, e.g. np.trace(),
    # are documented under "linear algebra" but aren't in the
    # np.linalg module
    prefix = "np.linalg" if la_prefix else "np"
    interp = (prefix, func_name)
    # checks that a matrix is 1 or 2D
    if not isinstance(a, types.Array):
        raise TypingError("%s.%s() only supported for array types "
                          % interp)
    if not a.ndim <= 2:
        raise TypingError("%s.%s() only supported on 1 and 2-D arrays "
                          % interp)
    if not isinstance(a.dtype, (types.Float, types.Complex)):
        raise TypingError("%s.%s() only supported on "
                          "float and complex arrays." % interp)


@overload(np.linalg.cholesky)
def cho_impl(a):
    ensure_lapack()

    _check_linalg_matrix(a, "cholesky")

    numba_xxpotrf = _LAPACK().numba_xxpotrf(a.dtype)

    kind = ord(get_blas_kind(a.dtype, "cholesky"))
    UP = ord('U')
    LO = ord('L')

    def cho_impl(a):
        n = a.shape[-1]
        if a.shape[-2] != n:
            msg = "Last 2 dimensions of the array must be square."
            raise np.linalg.LinAlgError(msg)

        # The output is allocated in C order
        out = a.copy()

        if n == 0:
            return out

        # Pass UP since xxpotrf() operates in F order
        # The semantics ensure this works fine
        # (out is really its Hermitian in F order, but UP instructs
        #  xxpotrf to compute the Hermitian of the upper triangle
        #  => they cancel each other)
        r = numba_xxpotrf(kind, UP, n, out.ctypes, n)
        if r != 0:
            if r < 0:
                fatal_error_func()
                assert 0   # unreachable
            if r > 0:
                raise np.linalg.LinAlgError(
                    "Matrix is not positive definite.")
        # Zero out upper triangle, in F order
        for col in range(n):
            out[:col, col] = 0
        return out

    return cho_impl
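

# Illustrative property of cho_impl(): for a real symmetric positive definite
# 'a', the returned factor L is lower triangular and L @ L.T reconstructs 'a'
# up to rounding (L @ L.conj().T for complex Hermitian inputs), matching
# np.linalg.cholesky; non positive definite inputs raise LinAlgError.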


@overload(np.linalg.eig)
def eig_impl(a):
    ensure_lapack()

    _check_linalg_matrix(a, "eig")

    numba_ez_rgeev = _LAPACK().numba_ez_rgeev(a.dtype)
    numba_ez_cgeev = _LAPACK().numba_ez_cgeev(a.dtype)

    kind = ord(get_blas_kind(a.dtype, "eig"))

    JOBVL = ord('N')
    JOBVR = ord('V')

    def real_eig_impl(a):
        """
        eig() implementation for real arrays.
        """
        n = a.shape[-1]
        if a.shape[-2] != n:
            msg = "Last 2 dimensions of the array must be square."
            raise np.linalg.LinAlgError(msg)

        _check_finite_matrix(a)

        acpy = _copy_to_fortran_order(a)

        ldvl = 1
        ldvr = n
        wr = np.empty(n, dtype=a.dtype)
        wi = np.empty(n, dtype=a.dtype)
        vl = np.empty((n, ldvl), dtype=a.dtype)
        vr = np.empty((n, ldvr), dtype=a.dtype)

        if n == 0:
            return (wr, vr.T)

        r = numba_ez_rgeev(kind,
                           JOBVL,
                           JOBVR,
                           n,
                           acpy.ctypes,
                           n,
                           wr.ctypes,
                           wi.ctypes,
                           vl.ctypes,
                           ldvl,
                           vr.ctypes,
                           ldvr)
        _handle_err_maybe_convergence_problem(r)

        # By design numba does not support dynamic return types, however,
        # Numpy does. Numpy uses this ability in the case of returning
        # eigenvalues/vectors of a real matrix. The return type of
        # np.linalg.eig(), when operating on a matrix in real space
        # depends on the values present in the matrix itself (recalling
        # that eigenvalues are the roots of the characteristic polynomial
        # of the system matrix, which will by construction depend on the
        # values present in the system matrix). As numba cannot handle
        # the case of a runtime decision based domain change relative to
        # the input type, if it is required numba raises as below.
        if np.any(wi):
            raise ValueError(
                "eig() argument must not cause a domain change.")

        # put these in to help with liveness analysis,
        # `.ctypes` doesn't keep the vars alive
        _dummy_liveness_func([acpy.size, vl.size, vr.size, wr.size, wi.size])
        return (wr, vr.T)

    def cmplx_eig_impl(a):
        """
        eig() implementation for complex arrays.
        """
        n = a.shape[-1]
        if a.shape[-2] != n:
            msg = "Last 2 dimensions of the array must be square."
            raise np.linalg.LinAlgError(msg)

        _check_finite_matrix(a)

        acpy = _copy_to_fortran_order(a)

        ldvl = 1
        ldvr = n
        w = np.empty(n, dtype=a.dtype)
        vl = np.empty((n, ldvl), dtype=a.dtype)
        vr = np.empty((n, ldvr), dtype=a.dtype)

        if n == 0:
            return (w, vr.T)

        r = numba_ez_cgeev(kind,
                           JOBVL,
                           JOBVR,
                           n,
                           acpy.ctypes,
                           n,
                           w.ctypes,
                           vl.ctypes,
                           ldvl,
                           vr.ctypes,
                           ldvr)
        _handle_err_maybe_convergence_problem(r)

        # put these in to help with liveness analysis,
        # `.ctypes` doesn't keep the vars alive
        _dummy_liveness_func([acpy.size, vl.size, vr.size, w.size])
        return (w, vr.T)

    if isinstance(a.dtype, types.scalars.Complex):
        return cmplx_eig_impl
    else:
        return real_eig_impl
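

# Behaviour note (illustrative): because Numba cannot change the return dtype
# at runtime, real input whose eigenvalues turn out to be complex raises
# instead of promoting, unlike NumPy:
#
#     np.linalg.eig(np.array([[0., -1.], [1., 0.]]))   # fine in plain NumPy
#     # inside @numba.njit the same call raises ValueError ("domain change");
#     # casting the argument to a complex dtype first gives the NumPy result.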


@overload(np.linalg.eigvals)
def eigvals_impl(a):
    ensure_lapack()

    _check_linalg_matrix(a, "eigvals")

    numba_ez_rgeev = _LAPACK().numba_ez_rgeev(a.dtype)
    numba_ez_cgeev = _LAPACK().numba_ez_cgeev(a.dtype)

    kind = ord(get_blas_kind(a.dtype, "eigvals"))

    JOBVL = ord('N')
    JOBVR = ord('N')

    def real_eigvals_impl(a):
        """
        eigvals() implementation for real arrays.
        """
        n = a.shape[-1]
        if a.shape[-2] != n:
            msg = "Last 2 dimensions of the array must be square."
            raise np.linalg.LinAlgError(msg)

        _check_finite_matrix(a)

        acpy = _copy_to_fortran_order(a)

        ldvl = 1
        ldvr = 1
        wr = np.empty(n, dtype=a.dtype)

        if n == 0:
            return wr

        wi = np.empty(n, dtype=a.dtype)

        # not referenced but need setting for MKL null check
        vl = np.empty((1), dtype=a.dtype)
        vr = np.empty((1), dtype=a.dtype)

        r = numba_ez_rgeev(kind,
                           JOBVL,
                           JOBVR,
                           n,
                           acpy.ctypes,
                           n,
                           wr.ctypes,
                           wi.ctypes,
                           vl.ctypes,
                           ldvl,
                           vr.ctypes,
                           ldvr)
        _handle_err_maybe_convergence_problem(r)

        # By design numba does not support dynamic return types, however,
        # Numpy does. Numpy uses this ability in the case of returning
        # eigenvalues/vectors of a real matrix. The return type of
        # np.linalg.eigvals(), when operating on a matrix in real space
        # depends on the values present in the matrix itself (recalling
        # that eigenvalues are the roots of the characteristic polynomial
        # of the system matrix, which will by construction depend on the
        # values present in the system matrix). As numba cannot handle
        # the case of a runtime decision based domain change relative to
        # the input type, if it is required numba raises as below.
        if np.any(wi):
            raise ValueError(
                "eigvals() argument must not cause a domain change.")

        # put these in to help with liveness analysis,
        # `.ctypes` doesn't keep the vars alive
        _dummy_liveness_func([acpy.size, vl.size, vr.size, wr.size, wi.size])
        return wr

    def cmplx_eigvals_impl(a):
        """
        eigvals() implementation for complex arrays.
        """
        n = a.shape[-1]
        if a.shape[-2] != n:
            msg = "Last 2 dimensions of the array must be square."
            raise np.linalg.LinAlgError(msg)

        _check_finite_matrix(a)

        acpy = _copy_to_fortran_order(a)

        ldvl = 1
        ldvr = 1
        w = np.empty(n, dtype=a.dtype)

        if n == 0:
            return w

        vl = np.empty((1), dtype=a.dtype)
        vr = np.empty((1), dtype=a.dtype)

        r = numba_ez_cgeev(kind,
                           JOBVL,
                           JOBVR,
                           n,
                           acpy.ctypes,
                           n,
                           w.ctypes,
                           vl.ctypes,
                           ldvl,
                           vr.ctypes,
                           ldvr)
        _handle_err_maybe_convergence_problem(r)

        # put these in to help with liveness analysis,
        # `.ctypes` doesn't keep the vars alive
        _dummy_liveness_func([acpy.size, vl.size, vr.size, w.size])
        return w

    if isinstance(a.dtype, types.scalars.Complex):
        return cmplx_eigvals_impl
    else:
        return real_eigvals_impl


@overload(np.linalg.eigh)
def eigh_impl(a):
    ensure_lapack()

    _check_linalg_matrix(a, "eigh")

    # convert typing floats to numpy floats for use in the impl
    w_type = getattr(a.dtype, "underlying_float", a.dtype)
    w_dtype = np_support.as_dtype(w_type)

    numba_ez_xxxevd = _LAPACK().numba_ez_xxxevd(a.dtype)

    kind = ord(get_blas_kind(a.dtype, "eigh"))

    JOBZ = ord('V')
    UPLO = ord('L')

    def eigh_impl(a):
        n = a.shape[-1]

        if a.shape[-2] != n:
            msg = "Last 2 dimensions of the array must be square."
            raise np.linalg.LinAlgError(msg)

        _check_finite_matrix(a)

        acpy = _copy_to_fortran_order(a)

        w = np.empty(n, dtype=w_dtype)

        if n == 0:
            return (w, acpy)

        r = numba_ez_xxxevd(kind,  # kind
                            JOBZ,  # jobz
                            UPLO,  # uplo
                            n,  # n
                            acpy.ctypes,  # a
                            n,  # lda
                            w.ctypes  # w
                            )
        _handle_err_maybe_convergence_problem(r)

        # help liveness analysis
        _dummy_liveness_func([acpy.size, w.size])
        return (w, acpy)

    return eigh_impl


@overload(np.linalg.eigvalsh)
def eigvalsh_impl(a):
    ensure_lapack()

    _check_linalg_matrix(a, "eigvalsh")

    # convert typing floats to numpy floats for use in the impl
    w_type = getattr(a.dtype, "underlying_float", a.dtype)
    w_dtype = np_support.as_dtype(w_type)

    numba_ez_xxxevd = _LAPACK().numba_ez_xxxevd(a.dtype)

    kind = ord(get_blas_kind(a.dtype, "eigvalsh"))

    JOBZ = ord('N')
    UPLO = ord('L')

    def eigvalsh_impl(a):
        n = a.shape[-1]

        if a.shape[-2] != n:
            msg = "Last 2 dimensions of the array must be square."
            raise np.linalg.LinAlgError(msg)

        _check_finite_matrix(a)

        acpy = _copy_to_fortran_order(a)

        w = np.empty(n, dtype=w_dtype)

        if n == 0:
            return w

        r = numba_ez_xxxevd(kind,  # kind
                            JOBZ,  # jobz
                            UPLO,  # uplo
                            n,  # n
                            acpy.ctypes,  # a
                            n,  # lda
                            w.ctypes  # w
                            )
        _handle_err_maybe_convergence_problem(r)

        # help liveness analysis
        _dummy_liveness_func([acpy.size, w.size])
        return w

    return eigvalsh_impl


@overload(np.linalg.svd)
def svd_impl(a, full_matrices=1):
    ensure_lapack()

    _check_linalg_matrix(a, "svd")

    # convert typing floats to numpy floats for use in the impl
    s_type = getattr(a.dtype, "underlying_float", a.dtype)
    s_dtype = np_support.as_dtype(s_type)

    numba_ez_gesdd = _LAPACK().numba_ez_gesdd(a.dtype)

    kind = ord(get_blas_kind(a.dtype, "svd"))

    JOBZ_A = ord('A')
    JOBZ_S = ord('S')

    def svd_impl(a, full_matrices=1):
        n = a.shape[-1]
        m = a.shape[-2]

        if n == 0 or m == 0:
            raise np.linalg.LinAlgError("Arrays cannot be empty")

        _check_finite_matrix(a)

        acpy = _copy_to_fortran_order(a)

        ldu = m
        minmn = min(m, n)

        if full_matrices:
            JOBZ = JOBZ_A
            ucol = m
            ldvt = n
        else:
            JOBZ = JOBZ_S
            ucol = minmn
            ldvt = minmn

        u = np.empty((ucol, ldu), dtype=a.dtype)
        s = np.empty(minmn, dtype=s_dtype)
        vt = np.empty((n, ldvt), dtype=a.dtype)

        r = numba_ez_gesdd(
            kind,  # kind
            JOBZ,  # jobz
            m,  # m
            n,  # n
            acpy.ctypes,  # a
            m,  # lda
            s.ctypes,  # s
            u.ctypes,  # u
            ldu,  # ldu
            vt.ctypes,  # vt
            ldvt  # ldvt
        )
        _handle_err_maybe_convergence_problem(r)

        # help liveness analysis
        _dummy_liveness_func([acpy.size, vt.size, u.size, s.size])
        return (u.T, s, vt.T)

    return svd_impl
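

# Illustrative: as with NumPy, u, s, vh = np.linalg.svd(a) gives factors with
# a ~= (u * s) @ vh for full_matrices=0 (reduced factors) and
# a ~= (u[:, :s.size] * s) @ vh[:s.size, :] for the default full_matrices=1;
# empty inputs raise LinAlgError here (see the m/n == 0 check above).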


@overload(np.linalg.qr)
def qr_impl(a):
    ensure_lapack()

    _check_linalg_matrix(a, "qr")

    # Need two functions, the first computes R, storing it in the upper
    # triangle of A with the below diagonal part of A containing elementary
    # reflectors needed to construct Q. The second turns the below diagonal
    # entries of A into Q, storing Q in A (creates orthonormal columns from
    # the elementary reflectors).

    numba_ez_geqrf = _LAPACK().numba_ez_geqrf(a.dtype)
    numba_ez_xxgqr = _LAPACK().numba_ez_xxgqr(a.dtype)

    kind = ord(get_blas_kind(a.dtype, "qr"))

    def qr_impl(a):
        n = a.shape[-1]
        m = a.shape[-2]

        if n == 0 or m == 0:
            raise np.linalg.LinAlgError("Arrays cannot be empty")

        _check_finite_matrix(a)

        # copy A as it will be destroyed
        q = _copy_to_fortran_order(a)

        lda = m

        minmn = min(m, n)
        tau = np.empty((minmn), dtype=a.dtype)

        ret = numba_ez_geqrf(
            kind,  # kind
            m,  # m
            n,  # n
            q.ctypes,  # a
            m,  # lda
            tau.ctypes  # tau
        )
        if ret < 0:
            fatal_error_func()
            assert 0   # unreachable

        # pull out R, this is transposed because of Fortran
        r = np.zeros((n, minmn), dtype=a.dtype).T

        # the triangle in R
        for i in range(minmn):
            for j in range(i + 1):
                r[j, i] = q[j, i]

        # and the possible square in R
        for i in range(minmn, n):
            for j in range(minmn):
                r[j, i] = q[j, i]

        ret = numba_ez_xxgqr(
            kind,  # kind
            m,  # m
            minmn,  # n
            minmn,  # k
            q.ctypes,  # a
            m,  # lda
            tau.ctypes  # tau
        )
        _handle_err_maybe_convergence_problem(ret)

        # help liveness analysis
        _dummy_liveness_func([tau.size, q.size])
        return (q[:, :minmn], r)

    return qr_impl


# helpers and jitted specialisations required for np.linalg.lstsq
# and np.linalg.solve. These functions have "system" in their name
# as a differentiator.

def _system_copy_in_b(bcpy, b, nrhs):
    """
    Correctly copy 'b' into the 'bcpy' scratch space.
    """
    raise NotImplementedError


@overload(_system_copy_in_b)
def _system_copy_in_b_impl(bcpy, b, nrhs):
    if b.ndim == 1:
        def oneD_impl(bcpy, b, nrhs):
            bcpy[:b.shape[-1], 0] = b
        return oneD_impl
    else:
        def twoD_impl(bcpy, b, nrhs):
            bcpy[:b.shape[-2], :nrhs] = b
        return twoD_impl


def _system_compute_nrhs(b):
    """
    Compute the number of right hand sides in the system of equations
    """
    raise NotImplementedError


@overload(_system_compute_nrhs)
def _system_compute_nrhs_impl(b):
    if b.ndim == 1:
        def oneD_impl(b):
            return 1
        return oneD_impl
    else:
        def twoD_impl(b):
            return b.shape[-1]
        return twoD_impl


def _system_check_dimensionally_valid(a, b):
    """
    Check that AX=B style system input is dimensionally valid.
    """
    raise NotImplementedError


@overload(_system_check_dimensionally_valid)
def _system_check_dimensionally_valid_impl(a, b):
    ndim = b.ndim
    if ndim == 1:
        def oneD_impl(a, b):
            am = a.shape[-2]
            bm = b.shape[-1]
            if am != bm:
                raise np.linalg.LinAlgError(
                    "Incompatible array sizes, system is not "
                    "dimensionally valid.")
        return oneD_impl
    else:
        def twoD_impl(a, b):
            am = a.shape[-2]
            bm = b.shape[-2]
            if am != bm:
                raise np.linalg.LinAlgError(
                    "Incompatible array sizes, system is not "
                    "dimensionally valid.")
        return twoD_impl


def _system_check_non_empty(a, b):
    """
    Check that AX=B style system input is not empty.
    """
    raise NotImplementedError


@overload(_system_check_non_empty)
def _system_check_non_empty_impl(a, b):
    ndim = b.ndim
    if ndim == 1:
        def oneD_impl(a, b):
            am = a.shape[-2]
            an = a.shape[-1]
            bm = b.shape[-1]
            if am == 0 or bm == 0 or an == 0:
                raise np.linalg.LinAlgError('Arrays cannot be empty')
        return oneD_impl
    else:
        def twoD_impl(a, b):
            am = a.shape[-2]
            an = a.shape[-1]
            bm = b.shape[-2]
            bn = b.shape[-1]
            if am == 0 or bm == 0 or an == 0 or bn == 0:
                raise np.linalg.LinAlgError('Arrays cannot be empty')
        return twoD_impl


def _lstsq_residual(b, n, nrhs):
    """
    Compute the residual from the 'b' scratch space.
    """
    raise NotImplementedError


@overload(_lstsq_residual)
def _lstsq_residual_impl(b, n, nrhs):
    ndim = b.ndim
    dtype = b.dtype
    real_dtype = np_support.as_dtype(getattr(dtype, "underlying_float", dtype))

    if ndim == 1:
        if isinstance(dtype, (types.Complex)):
            def cmplx_impl(b, n, nrhs):
                res = np.empty((1,), dtype=real_dtype)
                res[0] = np.sum(np.abs(b[n:, 0])**2)
                return res
            return cmplx_impl
        else:
            def real_impl(b, n, nrhs):
                res = np.empty((1,), dtype=real_dtype)
                res[0] = np.sum(b[n:, 0]**2)
                return res
            return real_impl
    else:
        assert ndim == 2
        if isinstance(dtype, (types.Complex)):
            def cmplx_impl(b, n, nrhs):
                res = np.empty((nrhs), dtype=real_dtype)
                for k in range(nrhs):
                    res[k] = np.sum(np.abs(b[n:, k])**2)
                return res
            return cmplx_impl
        else:
            def real_impl(b, n, nrhs):
                res = np.empty((nrhs), dtype=real_dtype)
                for k in range(nrhs):
                    res[k] = np.sum(b[n:, k]**2)
                return res
            return real_impl


def _lstsq_solution(b, bcpy, n):
    """
    Extract 'x' (the lstsq solution) from the 'bcpy' scratch space.
    Note 'b' is only used to check the system input dimension...
    """
    raise NotImplementedError


@overload(_lstsq_solution)
def _lstsq_solution_impl(b, bcpy, n):
    if b.ndim == 1:
        def oneD_impl(b, bcpy, n):
            return bcpy.T.ravel()[:n]
        return oneD_impl
    else:
        def twoD_impl(b, bcpy, n):
            return bcpy[:n, :].copy()
        return twoD_impl
1581
1582
1583@overload(np.linalg.lstsq)
1584def lstsq_impl(a, b, rcond=-1.0):
1585    ensure_lapack()
1586
1587    _check_linalg_matrix(a, "lstsq")
1588
1589    # B can be 1D or 2D.
1590    _check_linalg_1_or_2d_matrix(b, "lstsq")
1591
1592    _check_homogeneous_types("lstsq", a, b)
1593
1594    np_dt = np_support.as_dtype(a.dtype)
1595    nb_dt = a.dtype
1596
1597    # convert typing floats to np floats for use in the impl
1598    r_type = getattr(nb_dt, "underlying_float", nb_dt)
1599    real_dtype = np_support.as_dtype(r_type)
1600
1601    # lapack solver
1602    numba_ez_gelsd = _LAPACK().numba_ez_gelsd(a.dtype)
1603
1604    kind = ord(get_blas_kind(nb_dt, "lstsq"))
1605
1606    # The following functions select specialisations based on
1607    # information around 'b', a lot of this effort is required
1608    # as 'b' can be either 1D or 2D, and then there are
1609    # some optimisations available depending on real or complex
1610    # space.
1611
1612    def lstsq_impl(a, b, rcond=-1.0):
1613        n = a.shape[-1]
1614        m = a.shape[-2]
1615        nrhs = _system_compute_nrhs(b)
1616
1617        # check the systems have no inf or NaN
1618        _check_finite_matrix(a)
1619        _check_finite_matrix(b)
1620
1621        # check the system is not empty
1622        _system_check_non_empty(a, b)
1623
1624        # check the systems are dimensionally valid
1625        _system_check_dimensionally_valid(a, b)
1626
1627        minmn = min(m, n)
1628        maxmn = max(m, n)
1629
1630        # a is destroyed on exit, copy it
1631        acpy = _copy_to_fortran_order(a)
1632
        # b is overwritten on exit with the solution, so allocate space for a copy
1634        bcpy = np.empty((nrhs, maxmn), dtype=np_dt).T
        # specialised copy-in, needed because b can be 1D or 2D
1636        _system_copy_in_b(bcpy, b, nrhs)
1637
1638        # Allocate returns
1639        s = np.empty(minmn, dtype=real_dtype)
1640        rank_ptr = np.empty(1, dtype=np.int32)
1641
1642        r = numba_ez_gelsd(
1643            kind,  # kind
1644            m,  # m
1645            n,  # n
1646            nrhs,  # nrhs
1647            acpy.ctypes,  # a
1648            m,  # lda
            bcpy.ctypes,  # b
1650            maxmn,  # ldb
1651            s.ctypes,  # s
1652            rcond,  # rcond
1653            rank_ptr.ctypes  # rank
1654        )
1655        _handle_err_maybe_convergence_problem(r)
1656
1657        # set rank to that which was computed
1658        rank = rank_ptr[0]
1659
1660        # compute residuals
1661        if rank < n or m <= n:
1662            res = np.empty((0), dtype=real_dtype)
1663        else:
1664            # this requires additional dispatch as there's a faster
1665            # impl if the result is in the real domain (no abs() required)
1666            res = _lstsq_residual(bcpy, n, nrhs)
1667
1668        # extract 'x', the solution
1669        x = _lstsq_solution(b, bcpy, n)
1670
1671        # help liveness analysis
1672        _dummy_liveness_func([acpy.size, bcpy.size, s.size, rank_ptr.size])
1673        return (x, res, rank, s[:minmn])
1674
1675    return lstsq_impl
1676
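# A minimal usage sketch of the lstsq overload above (documentation only, not
# exercised by this module); the jitted wrapper `fit` is hypothetical and
# assumes numba is importable. It shows the 4-tuple
# (solution, residuals, rank, singular values) that np.linalg.lstsq returns:
#
#     from numba import njit
#     import numpy as np
#
#     @njit
#     def fit(a, b):
#         return np.linalg.lstsq(a, b)
#
#     a = np.array([[1., 1.], [1., 2.], [1., 3.]])
#     b = np.array([1., 2., 2.])
#     x, res, rank, sv = fit(a, b)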
1677
1678def _solve_compute_return(b, bcpy):
1679    """
1680    Extract 'x' (the solution) from the 'bcpy' scratch space.
1681    Note 'b' is only used to check the system input dimension...
1682    """
1683    raise NotImplementedError
1684
1685
1686@overload(_solve_compute_return)
1687def _solve_compute_return_impl(b, bcpy):
1688    if b.ndim == 1:
1689        def oneD_impl(b, bcpy):
1690            return bcpy.T.ravel()
1691        return oneD_impl
1692    else:
1693        def twoD_impl(b, bcpy):
1694            return bcpy
1695        return twoD_impl
1696
1697
1698@overload(np.linalg.solve)
1699def solve_impl(a, b):
1700    ensure_lapack()
1701
1702    _check_linalg_matrix(a, "solve")
1703    _check_linalg_1_or_2d_matrix(b, "solve")
1704
1705    _check_homogeneous_types("solve", a, b)
1706
1707    np_dt = np_support.as_dtype(a.dtype)
1708    nb_dt = a.dtype
1709
1710    # the lapack solver
1711    numba_xgesv = _LAPACK().numba_xgesv(a.dtype)
1712
1713    kind = ord(get_blas_kind(nb_dt, "solve"))
1714
1715    def solve_impl(a, b):
1716        n = a.shape[-1]
1717        nrhs = _system_compute_nrhs(b)
1718
1719        # check the systems have no inf or NaN
1720        _check_finite_matrix(a)
1721        _check_finite_matrix(b)
1722
1723        # check the systems are dimensionally valid
1724        _system_check_dimensionally_valid(a, b)
1725
1726        # a is destroyed on exit, copy it
1727        acpy = _copy_to_fortran_order(a)
1728
        # b is overwritten on exit with the solution, so allocate space for a copy
1730        bcpy = np.empty((nrhs, n), dtype=np_dt).T
1731        if n == 0:
1732            return _solve_compute_return(b, bcpy)
1733
        # specialised copy-in, needed because b can be 1D or 2D
1735        _system_copy_in_b(bcpy, b, nrhs)
1736
1737        # allocate pivot array (needs to be fortran int size)
1738        ipiv = np.empty(n, dtype=F_INT_nptype)
1739
1740        r = numba_xgesv(
1741            kind,        # kind
1742            n,           # n
            nrhs,        # nrhs
1744            acpy.ctypes,  # a
1745            n,           # lda
1746            ipiv.ctypes,  # ipiv
1747            bcpy.ctypes,  # b
1748            n            # ldb
1749        )
1750        _inv_err_handler(r)
1751
1752        # help liveness analysis
1753        _dummy_liveness_func([acpy.size, bcpy.size, ipiv.size])
1754        return _solve_compute_return(b, bcpy)
1755
1756    return solve_impl
1757
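# A minimal usage sketch of the solve overload (documentation only); the
# jitted wrapper `solve_system` is hypothetical and assumes numba is
# importable. As checked above, `b` may be 1D or 2D:
#
#     from numba import njit
#     import numpy as np
#
#     @njit
#     def solve_system(a, b):
#         return np.linalg.solve(a, b)
#
#     a = np.array([[3., 1.], [1., 2.]])
#     b = np.array([9., 8.])
#     x = solve_system(a, b)
#     assert np.allclose(np.dot(a, x), b)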
1758
1759@overload(np.linalg.pinv)
1760def pinv_impl(a, rcond=1.e-15):
1761    ensure_lapack()
1762
1763    _check_linalg_matrix(a, "pinv")
1764
1765    # convert typing floats to numpy floats for use in the impl
1766    s_type = getattr(a.dtype, "underlying_float", a.dtype)
1767    s_dtype = np_support.as_dtype(s_type)
1768
1769    numba_ez_gesdd = _LAPACK().numba_ez_gesdd(a.dtype)
1770
1771    numba_xxgemm = _BLAS().numba_xxgemm(a.dtype)
1772
1773    kind = ord(get_blas_kind(a.dtype, "pinv"))
1774    JOB = ord('S')
1775
1776    # need conjugate transposes
1777    TRANSA = ord('C')
1778    TRANSB = ord('C')
1779
1780    # scalar constants
1781    dt = np_support.as_dtype(a.dtype)
1782    zero = np.array([0.], dtype=dt)
1783    one = np.array([1.], dtype=dt)
1784
1785    def pinv_impl(a, rcond=1.e-15):
1786
1787        # The idea is to build the pseudo-inverse via inverting the singular
1788        # value decomposition of a matrix `A`. Mathematically, this is roughly
1789        # A = U*S*V^H        [The SV decomposition of A]
1790        # A^+ = V*(S^+)*U^H  [The inverted SV decomposition of A]
1791        # where ^+ is pseudo inversion and ^H is Hermitian transpose.
1792        # As V and U are unitary, their inverses are simply their Hermitian
1793        # transpose. S has singular values on its diagonal and zero elsewhere,
1794        # it is inverted trivially by reciprocal of the diagonal values with
1795        # the exception that zero singular values remain as zero.
1796        #
        # The practical implementation can take advantage of a few things to
        # gain a modest performance increase:
1799        # * A is destroyed by the SVD algorithm from LAPACK so a copy is
1800        #   required, this memory is exactly the right size in which to return
1801        #   the pseudo-inverse and so can be reused for this purpose.
1802        # * The pseudo-inverse of S can be applied to either V or U^H, this
1803        #   then leaves a GEMM operation to compute the inverse via either:
1804        #   A^+ = (V*(S^+))*U^H
1805        #   or
1806        #   A^+ = V*((S^+)*U^H)
1807        #   however application of S^+ to V^H or U is more convenient as they
1808        #   are the result of the SVD algorithm. The application of the
1809        #   diagonal system is just a matrix multiplication which results in a
1810        #   row/column scaling (direction depending). To save effort, this
1811        #   "matrix multiplication" is applied to the smallest of U or V^H and
1812        #   only up to the point of "cut-off" (see next note) just as a direct
1813        #   scaling.
1814        # * The cut-off level for application of S^+ can be used to reduce
1815        #   total effort, this cut-off can come via rcond or may just naturally
1816        #   be present as a result of zeros in the singular values. Regardless
1817        #   there's no need to multiply by zeros in the application of S^+ to
        #   V^H or U as above. Further, the GEMM operation can be reduced in
        #   effort by noting that the possible zero block generated by the
        #   presence of zeros in S^+ has no effect apart from wasting cycles,
        #   as it is all fmadd()s where one operand is zero. The inner
        #   dimension of the GEMM operation can therefore be shrunk
        #   accordingly. (A pure NumPy sketch of this scheme is given after
        #   this overload.)
1823
1824        n = a.shape[-1]
1825        m = a.shape[-2]
1826
1827        _check_finite_matrix(a)
1828
1829        acpy = _copy_to_fortran_order(a)
1830
1831        if m == 0 or n == 0:
1832            return acpy.T.ravel().reshape(a.shape).T
1833
1834        minmn = min(m, n)
1835
1836        u = np.empty((minmn, m), dtype=a.dtype)
1837        s = np.empty(minmn, dtype=s_dtype)
1838        vt = np.empty((n, minmn), dtype=a.dtype)
1839
1840        r = numba_ez_gesdd(
1841            kind,         # kind
1842            JOB,          # job
1843            m,            # m
1844            n,            # n
1845            acpy.ctypes,  # a
1846            m,            # lda
1847            s.ctypes,     # s
1848            u.ctypes,     # u
1849            m,            # ldu
1850            vt.ctypes,    # vt
1851            minmn         # ldvt
1852        )
1853        _handle_err_maybe_convergence_problem(r)
1854
        # Invert the singular values above the threshold. Also find the index
        # just past the last such value, as this is the upper limit for the
        # application of the inverted singular values. Finding this value
        # saves multiplication by the block of zeros that would be created by
        # applying the below-threshold values to either U or V^H ahead of
        # multiplying them together. In BLAS parlance this is done simply by
        # restricting the `k` dimension to `cut_idx` in `xgemm` whilst
        # keeping the leading dimensions correct.
1863
1864        cut_at = s[0] * rcond
1865        cut_idx = 0
1866        for k in range(minmn):
1867            if s[k] > cut_at:
1868                s[k] = 1. / s[k]
1869                cut_idx = k
1870        cut_idx += 1
1871
1872        # Use cut_idx so there's no scaling by 0.
1873        if m >= n:
1874            # U is largest so apply S^+ to V^H.
1875            for i in range(n):
1876                for j in range(cut_idx):
1877                    vt[i, j] = vt[i, j] * s[j]
1878        else:
1879            # V^H is largest so apply S^+ to U.
1880            for i in range(cut_idx):
1881                s_local = s[i]
1882                for j in range(minmn):
1883                    u[i, j] = u[i, j] * s_local
1884
        # Do (V^H)^H * U^H (one of the matrices already carries the S^+
        # scaling) and write back to acpy. Note the inner dimension of
        # cut_idx takes account of the possible zero block.
        # The result can be stored in acpy: it had to be created for use in
        # the SVD, it is now redundant, and it is the right size but the
        # wrong shape.
1891
1892        r = numba_xxgemm(
1893            kind,
1894            TRANSA,       # TRANSA
1895            TRANSB,       # TRANSB
1896            n,            # M
1897            m,            # N
1898            cut_idx,      # K
1899            one.ctypes,   # ALPHA
1900            vt.ctypes,    # A
1901            minmn,        # LDA
1902            u.ctypes,     # B
1903            m,            # LDB
1904            zero.ctypes,  # BETA
1905            acpy.ctypes,  # C
1906            n             # LDC
1907        )
1908
1909        # help liveness analysis
1916        _dummy_liveness_func([acpy.size, vt.size, u.size, s.size, one.size,
1917            zero.size])
1918        return acpy.T.ravel().reshape(a.shape).T
1919
1920    return pinv_impl
1921
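# A pure NumPy sketch of the SVD-based pseudo-inverse scheme described in the
# comments inside pinv_impl above (illustrative only, assuming a non-empty
# matrix; the overload itself reuses the LAPACK scratch space and shrinks the
# GEMM using the singular-value cut-off index):
#
#     import numpy as np
#
#     def pinv_sketch(a, rcond=1.e-15):
#         u, s, vt = np.linalg.svd(a, full_matrices=False)
#         cutoff = s[0] * rcond
#         s_inv = np.where(s > cutoff, 1. / s, 0.)
#         # A^+ = V * S^+ * U^H; scaling the columns of V applies S^+
#         return (vt.conj().T * s_inv) @ u.conj().T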
1922
1923def _get_slogdet_diag_walker(a):
1924    """
    Walks the diagonal of a LUP-decomposed matrix, using the facts that
    det(A) = prod(diag(lup(A))) and that log(a) + log(b) = log(a * b).
    The returned sign is adjusted based on the values found so that
    log(value) stays in the real domain.
1930    """
1931    if isinstance(a.dtype, types.Complex):
1932        @register_jitable
1933        def cmplx_diag_walker(n, a, sgn):
1934            # walk diagonal
1935            csgn = sgn + 0.j
1936            acc = 0.
1937            for k in range(n):
1938                absel = np.abs(a[k, k])
1939                csgn = csgn * (a[k, k] / absel)
1940                acc = acc + np.log(absel)
1941            return (csgn, acc)
1942        return cmplx_diag_walker
1943    else:
1944        @register_jitable
1945        def real_diag_walker(n, a, sgn):
1946            # walk diagonal
1947            acc = 0.
1948            for k in range(n):
1949                v = a[k, k]
1950                if v < 0.:
1951                    sgn = -sgn
1952                    v = -v
1953                acc = acc + np.log(v)
1954            # sgn is a float dtype
1955            return (sgn + 0., acc)
1956        return real_diag_walker
1957
1958
1959@overload(np.linalg.slogdet)
1960def slogdet_impl(a):
1961    ensure_lapack()
1962
1963    _check_linalg_matrix(a, "slogdet")
1964
1965    numba_xxgetrf = _LAPACK().numba_xxgetrf(a.dtype)
1966
1967    kind = ord(get_blas_kind(a.dtype, "slogdet"))
1968
1969    diag_walker = _get_slogdet_diag_walker(a)
1970
1971    ONE = a.dtype(1)
1972    ZERO = getattr(a.dtype, "underlying_float", a.dtype)(0)
1973
1974    def slogdet_impl(a):
1975        n = a.shape[-1]
1976        if a.shape[-2] != n:
1977            msg = "Last 2 dimensions of the array must be square."
1978            raise np.linalg.LinAlgError(msg)
1979
1980        if n == 0:
1981            return (ONE, ZERO)
1982
1983        _check_finite_matrix(a)
1984
1985        acpy = _copy_to_fortran_order(a)
1986
1987        ipiv = np.empty(n, dtype=F_INT_nptype)
1988
1989        r = numba_xxgetrf(kind, n, n, acpy.ctypes, n, ipiv.ctypes)
1990
1991        if r > 0:
1992            # factorisation failed, return same defaults as np
1993            return (0., -np.inf)
1994        _inv_err_handler(r)  # catch input-to-lapack problem
1995
        # The following, prior to the call to diag_walker, accounts for the
        # effect of possible row permutations on the sign of the determinant.
        # This is the same idea as in numpy: see `umath_linalg.c.src`, e.g.
        # https://github.com/numpy/numpy/blob/master/numpy/linalg/umath_linalg.c.src
        # in the function `@TYPE@_slogdet_single_element`.
        # (A plain scipy/NumPy sketch of this logic follows this overload.)
2003        sgn = 1
2004        for k in range(n):
2005            sgn = sgn + (ipiv[k] != (k + 1))
2006
2007        sgn = sgn & 1
2008        if sgn == 0:
2009            sgn = -1
2010
2011        # help liveness analysis
2012        _dummy_liveness_func([ipiv.size])
2013        return diag_walker(n, acpy, sgn)
2014
2015    return slogdet_impl
2016
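# A reference sketch (real, non-singular input only) of the pivot-sign and
# diagonal-walk logic used by slogdet_impl above, written in terms of
# scipy.linalg.lu_factor; note lu_factor's piv is 0-based whereas LAPACK's
# ipiv (used above) is 1-based:
#
#     import numpy as np
#     from scipy.linalg import lu_factor
#
#     def slogdet_sketch(a):
#         lu, piv = lu_factor(a)
#         swaps = np.count_nonzero(piv != np.arange(len(piv)))
#         sgn = -1.0 if swaps & 1 else 1.0
#         d = np.diag(lu)
#         sgn *= np.prod(np.sign(d))
#         return sgn, np.sum(np.log(np.abs(d)))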
2017
2018@overload(np.linalg.det)
2019def det_impl(a):
2020
2021    ensure_lapack()
2022
2023    _check_linalg_matrix(a, "det")
2024
2025    def det_impl(a):
2026        (sgn, slogdet) = np.linalg.slogdet(a)
2027        return sgn * np.exp(slogdet)
2028
2029    return det_impl
2030
2031
2032def _compute_singular_values(a):
2033    """
2034    Compute singular values of *a*.
2035    """
2036    raise NotImplementedError
2037
2038
2039@overload(_compute_singular_values)
2040def _compute_singular_values_impl(a):
2041    """
2042    Returns a function to compute singular values of `a`
2043    """
2044    numba_ez_gesdd = _LAPACK().numba_ez_gesdd(a.dtype)
2045
2046    kind = ord(get_blas_kind(a.dtype, "svd"))
2047
2048    # Flag for "only compute `S`" to give to xgesdd
2049    JOBZ_N = ord('N')
2050
2051    nb_ret_type = getattr(a.dtype, "underlying_float", a.dtype)
2052    np_ret_type = np_support.as_dtype(nb_ret_type)
2053    np_dtype = np_support.as_dtype(a.dtype)
2054
2055    # These are not referenced in the computation but must be set
2056    # for MKL.
2057    u = np.empty((1, 1), dtype=np_dtype)
2058    vt = np.empty((1, 1), dtype=np_dtype)
2059
2060    def sv_function(a):
2061        """
2062        Computes singular values.
2063        """
        # Don't use the np.linalg.svd impl; instead, call LAPACK directly to
        # shortcut the "reconstruct singular vectors from reflectors" step
        # and just get back the singular values.
2068        n = a.shape[-1]
2069        m = a.shape[-2]
2070        if m == 0 or n == 0:
2071            raise np.linalg.LinAlgError('Arrays cannot be empty')
2072        _check_finite_matrix(a)
2073
2074        ldu = m
2075        minmn = min(m, n)
2076
2077        # need to be >=1 but aren't referenced
2078        ucol = 1
2079        ldvt = 1
2080
2081        acpy = _copy_to_fortran_order(a)
2082
        # u and vt are not referenced; however, they need to be allocated
        # (as done above) for MKL, which checks that the pointers are not
        # null.
2086        s = np.empty(minmn, dtype=np_ret_type)
2087
2088        r = numba_ez_gesdd(
2089            kind,        # kind
2090            JOBZ_N,      # jobz
2091            m,           # m
2092            n,           # n
2093            acpy.ctypes,  # a
2094            m,           # lda
2095            s.ctypes,    # s
2096            u.ctypes,    # u
2097            ldu,         # ldu
2098            vt.ctypes,   # vt
2099            ldvt         # ldvt
2100        )
2101        _handle_err_maybe_convergence_problem(r)
2102
2103        # help liveness analysis
2104        _dummy_liveness_func([acpy.size, vt.size, u.size, s.size])
2105        return s
2106
2107    return sv_function
2108
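# Up to floating point detail, _compute_singular_values(a) corresponds to the
# following NumPy call (a hedged note rather than a guarantee):
#
#     s = np.linalg.svd(a, compute_uv=False)
#
# The point of the direct LAPACK route above is to pass JOBZ='N' so that no
# singular vectors are reconstructed at all.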
2109
2110def _oneD_norm_2(a):
2111    """
2112    Compute the L2-norm of 1D-array *a*.
2113    """
2114    raise NotImplementedError
2115
2116
2117@overload(_oneD_norm_2)
2118def _oneD_norm_2_impl(a):
2119    nb_ret_type = getattr(a.dtype, "underlying_float", a.dtype)
2120    np_ret_type = np_support.as_dtype(nb_ret_type)
2121
2122    xxnrm2 = _BLAS().numba_xxnrm2(a.dtype)
2123
2124    kind = ord(get_blas_kind(a.dtype, "norm"))
2125
2126    def impl(a):
        # No `ord` argument is needed here; calls are guarded to only come
        # from cases where ord=None or ord=2.
2129        n = len(a)
2130        # Call L2-norm routine from BLAS
2131        ret = np.empty((1,), dtype=np_ret_type)
2132        jmp = int(a.strides[0] / a.itemsize)
2133        r = xxnrm2(
2134            kind,      # kind
2135            n,         # n
2136            a.ctypes,  # x
2137            jmp,       # incx
2138            ret.ctypes  # result
2139        )
2140        if r < 0:
2141            fatal_error_func()
2142            assert 0   # unreachable
2143
2144        # help liveness analysis
2147        _dummy_liveness_func([ret.size, a.size])
2148        return ret[0]
2149
2150    return impl
2151
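# For reference, the BLAS xNRM2 call above computes the same quantity as the
# following NumPy expression; the `incx` argument lets it walk strided 1D
# views without requiring a copy:
#
#     ret = np.sqrt(np.sum(np.abs(a) ** 2))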
2152
2153def _get_norm_impl(a, ord_flag):
    # This function is quite involved, as norm supports a large range of
    # values to select different norm types via the kwarg `ord`.
    # The implementation below branches on the dimension of the input
    # (1D or 2D). The default for `ord` is `None`, which requires special
    # handling in numba; this is dealt with first in each of the dimension
    # branches. Following this, the various norms are computed via code that
    # is in most cases simply a loop version of a ufunc-based implementation
    # as found in numpy.
2162
2163    # The following is common to both 1D and 2D cases.
2164    # Convert typing floats to numpy floats for use in the impl.
    # The return type is always a float; numba differs from numpy in that it
    # returns a value of precision matching the input, whereas numpy always
    # returns np.float64.
2168    nb_ret_type = getattr(a.dtype, "underlying_float", a.dtype)
2169    np_ret_type = np_support.as_dtype(nb_ret_type)
2170
2171    np_dtype = np_support.as_dtype(a.dtype)
2172
2173    xxnrm2 = _BLAS().numba_xxnrm2(a.dtype)
2174
2175    kind = ord(get_blas_kind(a.dtype, "norm"))
2176
2177    if a.ndim == 1:
2178        # 1D cases
2179
2180        # handle "ord" being "None", must be done separately
2181        if ord_flag in (None, types.none):
2182            def oneD_impl(a, ord=None):
2183                return _oneD_norm_2(a)
2184        else:
2185            def oneD_impl(a, ord=None):
2186                n = len(a)
2187
                # Shortcut to handle zero-length arrays.
                # This differs slightly from numpy in that numpy raises a
                # ValueError for kwarg ord=+/-np.inf, as the reduction
                # operations like max() and min() don't accept zero-length
                # arrays.
2194                if n == 0:
2195                    return 0.0
2196
                # Note on ord == 2:
                # This is the same as for ord==None, but because "None" has
                # to be handled specially this condition is separated out.
2201                if ord == 2:
2202                    return _oneD_norm_2(a)
2203                elif ord == np.inf:
2204                    # max(abs(a))
2205                    ret = abs(a[0])
2206                    for k in range(1, n):
2207                        val = abs(a[k])
2208                        if val > ret:
2209                            ret = val
2210                    return ret
2211
2212                elif ord == -np.inf:
2213                    # min(abs(a))
2214                    ret = abs(a[0])
2215                    for k in range(1, n):
2216                        val = abs(a[k])
2217                        if val < ret:
2218                            ret = val
2219                    return ret
2220
2221                elif ord == 0:
2222                    # sum(a != 0)
2223                    ret = 0.0
2224                    for k in range(n):
2225                        if a[k] != 0.:
2226                            ret += 1.
2227                    return ret
2228
2229                elif ord == 1:
2230                    # sum(abs(a))
2231                    ret = 0.0
2232                    for k in range(n):
2233                        ret += abs(a[k])
2234                    return ret
2235
2236                else:
2237                    # sum(abs(a)**ord)**(1./ord)
2238                    ret = 0.0
2239                    for k in range(n):
2240                        ret += abs(a[k])**ord
2241                    return ret**(1. / ord)
2242        return oneD_impl
2243
2244    elif a.ndim == 2:
2245        # 2D cases
2246
2247        # handle "ord" being "None"
2248        if ord_flag in (None, types.none):
2249            # Force `a` to be C-order, so that we can take a contiguous
2250            # 1D view.
2251            if a.layout == 'C':
2252                @register_jitable
2253                def array_prepare(a):
2254                    return a
2255            elif a.layout == 'F':
2256                @register_jitable
2257                def array_prepare(a):
2258                    # Legal since L2(a) == L2(a.T)
2259                    return a.T
2260            else:
2261                @register_jitable
2262                def array_prepare(a):
2263                    return a.copy()
2264
            # Compute the Frobenius norm. This is the elementwise L2 norm of
            # `A`, i.e. the L2-norm of A.ravel(), and so can be computed via
            # BLAS.
2267            def twoD_impl(a, ord=None):
2268                n = a.size
2269                if n == 0:
2270                    # reshape() currently doesn't support zero-sized arrays
2271                    return 0.0
2272                a_c = array_prepare(a)
2273                return _oneD_norm_2(a_c.reshape(n))
2274        else:
2275            # max value for this dtype
2276            max_val = np.finfo(np_ret_type.type).max
2277
2278            def twoD_impl(a, ord=None):
2279                n = a.shape[-1]
2280                m = a.shape[-2]
2281
                # Shortcut to handle zero-size arrays.
                # This differs slightly from numpy, which raises errors for
                # some ord values and in other cases returns zero.
2286                if a.size == 0:
2287                    return 0.0
2288
2289                if ord == np.inf:
2290                    # max of sum of abs across rows
2291                    # max(sum(abs(a)), axis=1)
2292                    global_max = 0.
2293                    for ii in range(m):
2294                        tmp = 0.
2295                        for jj in range(n):
2296                            tmp += abs(a[ii, jj])
2297                        if tmp > global_max:
2298                            global_max = tmp
2299                    return global_max
2300
2301                elif ord == -np.inf:
2302                    # min of sum of abs across rows
2303                    # min(sum(abs(a)), axis=1)
2304                    global_min = max_val
2305                    for ii in range(m):
2306                        tmp = 0.
2307                        for jj in range(n):
2308                            tmp += abs(a[ii, jj])
2309                        if tmp < global_min:
2310                            global_min = tmp
2311                    return global_min
2312                elif ord == 1:
2313                    # max of sum of abs across cols
2314                    # max(sum(abs(a)), axis=0)
2315                    global_max = 0.
2316                    for ii in range(n):
2317                        tmp = 0.
2318                        for jj in range(m):
2319                            tmp += abs(a[jj, ii])
2320                        if tmp > global_max:
2321                            global_max = tmp
2322                    return global_max
2323
2324                elif ord == -1:
2325                    # min of sum of abs across cols
2326                    # min(sum(abs(a)), axis=0)
2327                    global_min = max_val
2328                    for ii in range(n):
2329                        tmp = 0.
2330                        for jj in range(m):
2331                            tmp += abs(a[jj, ii])
2332                        if tmp < global_min:
2333                            global_min = tmp
2334                    return global_min
2335
                # Results via SVD; singular values are returned sorted in
                # descending order.
2338                elif ord == 2:
2339                    # max SV
2340                    return _compute_singular_values(a)[0]
2341                elif ord == -2:
2342                    # min SV
2343                    return _compute_singular_values(a)[-1]
2344                else:
2345                    # replicate numpy error
2346                    raise ValueError("Invalid norm order for matrices.")
2347        return twoD_impl
2348    else:
2349        assert 0  # unreachable
2350
2351
2352@overload(np.linalg.norm)
2353def norm_impl(a, ord=None):
2354    ensure_lapack()
2355
2356    _check_linalg_1_or_2d_matrix(a, "norm")
2357
2358    return _get_norm_impl(a, ord)
2359
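# A minimal usage sketch of the norm overload (documentation only); the
# jitted wrapper `norms` is hypothetical and assumes numba is importable.
# Note that, as described in _get_norm_impl, the return precision follows
# the input dtype:
#
#     from numba import njit
#     import numpy as np
#
#     @njit
#     def norms(v):
#         return (np.linalg.norm(v),          # L2 norm
#                 np.linalg.norm(v, 1),       # sum of absolute values
#                 np.linalg.norm(v, np.inf))  # max absolute value
#
#     print(norms(np.array([3., -4.])))  # -> (5.0, 7.0, 4.0)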
2360
2361@overload(np.linalg.cond)
2362def cond_impl(a, p=None):
2363    ensure_lapack()
2364
2365    _check_linalg_matrix(a, "cond")
2366
2367    def impl(a, p=None):
        # This is restructured for performance; numpy does approximately:
        # `condition = norm(a) * norm(inv(a))`
        # In the cases of `p == 2` or `p == -2`, singular values are used
        # for computing the norms. This costs numpy an SVD of `a`, then an
        # inversion of `a` and another SVD of `a`.
2373        # Below is a different approach, which also gives a more
2374        # accurate answer as there is no inversion involved.
2375        # Recall that the singular values of an inverted matrix are the
2376        # reciprocal of singular values of the original matrix.
2377        # Therefore calling `svd(a)` once yields all the information
2378        # needed about both `a` and `inv(a)` without the cost or
2379        # potential loss of accuracy incurred through inversion.
2380        # For the case of `p == 2`, the result is just the ratio of
2381        # `largest singular value/smallest singular value`, and for the
2382        # case of `p==-2` the result is simply the
2383        # `smallest singular value/largest singular value`.
2384        # As a result of this, numba accepts non-square matrices as
2385        # input when p==+/-2 as well as when p==None.
2386        if p == 2 or p == -2 or p is None:
2387            s = _compute_singular_values(a)
2388            if p == 2 or p is None:
2389                r = np.divide(s[0], s[-1])
2390            else:
2391                r = np.divide(s[-1], s[0])
2392        else:  # cases np.inf, -np.inf, 1, -1
2393            norm_a = np.linalg.norm(a, p)
2394            norm_inv_a = np.linalg.norm(np.linalg.inv(a), p)
2395            r = norm_a * norm_inv_a
        # NumPy uses a NaN mask: if the input has a NaN, it will return NaN.
        # Numba bans NaN inputs through the use of _check_finite_matrix, but
        # this check catches cases where NaN arises through floating point
        # computation.
2399        if np.isnan(r):
2400            return np.inf
2401        else:
2402            return r
2403    return impl
2404
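# For p in (None, 2, -2) the computation above reduces to a ratio of singular
# values; a plain NumPy sketch of the same idea (illustrative only):
#
#     s = np.linalg.svd(a, compute_uv=False)
#     cond_2 = s[0] / s[-1]    # matches np.linalg.cond(a, 2) for finite input
#     cond_m2 = s[-1] / s[0]   # matches np.linalg.cond(a, -2)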
2405
2406@register_jitable
2407def _get_rank_from_singular_values(sv, t):
2408    """
2409    Gets rank from singular values with cut-off at a given tolerance
2410    """
2411    rank = 0
2412    for k in range(len(sv)):
2413        if sv[k] > t:
2414            rank = rank + 1
2415        else:  # sv is ordered big->small so break on condition not met
2416            break
2417    return rank
2418
2419
2420@overload(np.linalg.matrix_rank)
2421def matrix_rank_impl(a, tol=None):
2422    """
2423    Computes rank for matrices and vectors.
    The only issue that may arise is that, because numpy uses double
    precision lapack calls whereas numba uses type-specific lapack calls,
    some singular values may differ slightly; counting the number of them
    above a tolerance may therefore lead to a different count, and hence a
    different rank, in some cases.
2429    """
2430    ensure_lapack()
2431
2432    _check_linalg_1_or_2d_matrix(a, "matrix_rank")
2433
2434    def _2d_matrix_rank_impl(a, tol):
2435
2436        # handle the tol==None case separately for type inference to work
2437        if tol in (None, types.none):
2438            nb_type = getattr(a.dtype, "underlying_float", a.dtype)
2439            np_type = np_support.as_dtype(nb_type)
2440            eps_val = np.finfo(np_type).eps
2441
2442            def _2d_tol_none_impl(a, tol=None):
2443                s = _compute_singular_values(a)
2444                # replicate numpy default tolerance calculation
2445                r = a.shape[0]
2446                c = a.shape[1]
2447                l = max(r, c)
2448                t = s[0] * l * eps_val
2449                return _get_rank_from_singular_values(s, t)
2450            return _2d_tol_none_impl
2451        else:
2452            def _2d_tol_not_none_impl(a, tol=None):
2453                s = _compute_singular_values(a)
2454                return _get_rank_from_singular_values(s, tol)
2455            return _2d_tol_not_none_impl
2456
2457    def _get_matrix_rank_impl(a, tol):
2458        ndim = a.ndim
2459        if ndim == 1:
2460            # NOTE: Technically, the numpy implementation could be argued as
2461            # incorrect for the case of a vector (1D matrix). If a tolerance
2462            # is provided and a vector with a singular value below tolerance is
2463            # encountered this should report a rank of zero, the numpy
2464            # implementation does not do this and instead elects to report that
2465            # if any value in the vector is nonzero then the rank is 1.
            # An example would be [0, 1e-15, 0, 2e-15], which numpy reports
            # as rank 1 regardless of `tol`. The singular value for this
            # vector is sqrt(5)*1e-15, so a tol of e.g. sqrt(6)*1e-15 should
            # lead to a reported rank of 0, whereas a tol of 1e-15 should
            # lead to a reported rank of 1; numpy reports 1 regardless.
            # The code below replicates the numpy behaviour.
2472            def _1d_matrix_rank_impl(a, tol=None):
2473                for k in range(len(a)):
2474                    if a[k] != 0.:
2475                        return 1
2476                return 0
2477            return _1d_matrix_rank_impl
2478        elif ndim == 2:
2479            return _2d_matrix_rank_impl(a, tol)
2480        else:
2481            assert 0  # unreachable
2482
2483    return _get_matrix_rank_impl(a, tol)
2484
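# The default tolerance used above in the 2D, tol=None case mirrors numpy's
# rule; a plain NumPy sketch of the same calculation:
#
#     s = np.linalg.svd(a, compute_uv=False)
#     tol = s[0] * max(a.shape) * np.finfo(s.dtype).eps
#     rank = int(np.sum(s > tol))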
2485
2486@overload(np.linalg.matrix_power)
2487def matrix_power_impl(a, n):
2488    """
2489    Computes matrix power. Only integer powers are supported in numpy.
2490    """
2491
2492    _check_linalg_matrix(a, "matrix_power")
2493    np_dtype = np_support.as_dtype(a.dtype)
2494
2495    nt = getattr(n, 'dtype', n)
2496    if not isinstance(nt, types.Integer):
2497        raise TypeError("Exponent must be an integer.")
2498
2499    def matrix_power_impl(a, n):
2500
2501        if n == 0:
            # This should be eye(), but eye() doesn't support the dtype
            # kwarg yet, so do it manually to save the copy required by
            # eye(a.shape[0]).astype(np_dtype).
2505            A = np.zeros(a.shape, dtype=np_dtype)
2506            for k in range(a.shape[0]):
2507                A[k, k] = 1.
2508            return A
2509
        am, an = a.shape[-2], a.shape[-1]
2511        if am != an:
2512            raise ValueError('input must be a square array')
2513
2514        # empty, return a copy
2515        if am == 0:
2516            return a.copy()
2517
        # note: to be consistent in contiguity, C order is returned, as that
        # is what dot() produces and the most common paths through
        # matrix_power will involve it. Therefore copies are made here to
        # ensure the data ordering is correct for paths not going via dot().
2523
2524        if n < 0:
2525            A = np.linalg.inv(a).copy()
2526            if n == -1:  # return now
2527                return A
2528            n = -n
2529        else:
2530            if n == 1:  # return a copy now
2531                return a.copy()
2532            A = a  # this is safe, `a` is only read
2533
2534        if n < 4:
2535            if n == 2:
2536                return np.dot(A, A)
2537            if n == 3:
2538                return np.dot(np.dot(A, A), A)
2539        else:
2540
2541            acc = A
2542            exp = n
2543
            # Initialise ret; SSA cannot see that the loop will execute, so
            # without this it appears uninitialised.
2546            ret = acc
            # A loop split and a branchless variant using an identity matrix
            # as the initial value were tried, but having a "first entry"
            # flag seems to be quicker. (A plain NumPy sketch of this loop is
            # given after the overload.)
2549            flag = True
2550            while exp != 0:
2551                if exp & 1:
2552                    if flag:
2553                        ret = acc
2554                        flag = False
2555                    else:
2556                        ret = np.dot(ret, acc)
2557                acc = np.dot(acc, acc)
2558                exp = exp >> 1
2559
2560            return ret
2561
2562    return matrix_power_impl
2563
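# A plain NumPy sketch of the exponentiation-by-squaring loop used above for
# n >= 4 (illustrative only; the overload additionally handles n <= 3,
# negative n via inv(), and the n == 0 identity case):
#
#     def matpow_sketch(A, n):
#         ret, acc, first = None, A, True
#         while n != 0:
#             if n & 1:
#                 ret = acc if first else np.dot(ret, acc)
#                 first = False
#             acc = np.dot(acc, acc)
#             n >>= 1
#         return ret
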
2564# This is documented under linalg despite not being in the module
2565
2566
2567@overload(np.trace)
2568def matrix_trace_impl(a, offset=0):
2569    """
2570    Computes the trace of an array.
2571    """
2572
2573    _check_linalg_matrix(a, "trace", la_prefix=False)
2574
2575    if not isinstance(offset, (int, types.Integer)):
2576        raise TypeError("integer argument expected, got %s" % offset)
2577
2578    def matrix_trace_impl(a, offset=0):
2579        rows, cols = a.shape
2580        k = offset
2581        if k < 0:
2582            rows = rows + k
2583        if k > 0:
2584            cols = cols - k
2585        n = max(min(rows, cols), 0)
2586        ret = 0
2587        if k >= 0:
2588            for i in range(n):
2589                ret += a[i, k + i]
2590        else:
2591            for i in range(n):
2592                ret += a[i - k, i]
2593        return ret
2594
2595    return matrix_trace_impl
2596
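# The offset handling above amounts to summing the k-th diagonal; a plain
# NumPy equivalence (hedged note):
#
#     np.trace(a, offset=k) == np.sum(np.diag(a, k))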
2597
2598def _check_scalar_or_lt_2d_mat(a, func_name, la_prefix=True):
2599    prefix = "np.linalg" if la_prefix else "np"
2600    interp = (prefix, func_name)
2601    # checks that a matrix is 1 or 2D
2602    if isinstance(a, types.Array):
2603        if not a.ndim <= 2:
2604            raise TypingError("%s.%s() only supported on 1 and 2-D arrays "
2605                              % interp, highlighting=False)
2606
2607
2608def _get_as_array(x):
2609    if not isinstance(x, types.Array):
2610        @register_jitable
2611        def asarray(x):
2612            return np.array((x,))
2613        return asarray
2614    else:
2615        @register_jitable
2616        def asarray(x):
2617            return x
2618        return asarray
2619
2620
2621def _get_outer_impl(a, b, out):
2622    a_arr = _get_as_array(a)
2623    b_arr = _get_as_array(b)
2624
2625    if out in (None, types.none):
2626        @register_jitable
2627        def outer_impl(a, b, out):
2628            aa = a_arr(a)
2629            bb = b_arr(b)
2630            return np.multiply(aa.ravel().reshape((aa.size, 1)),
2631                               bb.ravel().reshape((1, bb.size)))
2632        return outer_impl
2633    else:
2634        @register_jitable
2635        def outer_impl(a, b, out):
2636            aa = a_arr(a)
2637            bb = b_arr(b)
2638            np.multiply(aa.ravel().reshape((aa.size, 1)),
2639                        bb.ravel().reshape((1, bb.size)),
2640                        out)
2641            return out
2642        return outer_impl
2643
2644
2645@overload(np.outer)
2646def outer_impl(a, b, out=None):
2647
2648    _check_scalar_or_lt_2d_mat(a, "outer", la_prefix=False)
2649    _check_scalar_or_lt_2d_mat(b, "outer", la_prefix=False)
2650
2651    impl = _get_outer_impl(a, b, out)
2652
2653    def outer_impl(a, b, out=None):
2654        return impl(a, b, out)
2655
2656    return outer_impl
2657
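# The reshape/broadcast trick used by _get_outer_impl is the textbook outer
# product; a plain NumPy sketch for array inputs:
#
#     aa, bb = np.asarray(a).ravel(), np.asarray(b).ravel()
#     out = aa.reshape((aa.size, 1)) * bb.reshape((1, bb.size))
#     # out[i, j] == aa[i] * bb[j]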
2658
2659def _kron_normaliser_impl(x):
2660    # makes x into a 2d array
2661    if isinstance(x, types.Array):
2662        if x.layout not in ('C', 'F'):
2663            raise TypingError("np.linalg.kron only supports 'C' or 'F' layout "
                              "input arrays. Received an input of "
2665                              "layout '{}'.".format(x.layout))
2666        elif x.ndim == 2:
2667            @register_jitable
2668            def nrm_shape(x):
2669                xn = x.shape[-1]
2670                xm = x.shape[-2]
2671                return x.reshape(xm, xn)
2672            return nrm_shape
2673        else:
2674            @register_jitable
2675            def nrm_shape(x):
2676                xn = x.shape[-1]
2677                return x.reshape(1, xn)
2678            return nrm_shape
    else:  # assume it's a scalar
2680        @register_jitable
2681        def nrm_shape(x):
2682            a = np.empty((1, 1), type(x))
2683            a[0] = x
2684            return a
2685        return nrm_shape
2686
2687
2688def _kron_return(a, b):
2689    # transforms c into something that kron would return
2690    # based on the shapes of a and b
2691    a_is_arr = isinstance(a, types.Array)
2692    b_is_arr = isinstance(b, types.Array)
2693    if a_is_arr and b_is_arr:
2694        if a.ndim == 2 or b.ndim == 2:
2695            @register_jitable
2696            def ret(a, b, c):
2697                return c
2698            return ret
2699        else:
2700            @register_jitable
2701            def ret(a, b, c):
2702                return c.reshape(c.size)
2703            return ret
2704    else:  # at least one of (a, b) is a scalar
2705        if a_is_arr:
2706            @register_jitable
2707            def ret(a, b, c):
2708                return c.reshape(a.shape)
2709            return ret
2710        elif b_is_arr:
2711            @register_jitable
2712            def ret(a, b, c):
2713                return c.reshape(b.shape)
2714            return ret
2715        else:  # both scalars
2716            @register_jitable
2717            def ret(a, b, c):
2718                return c[0]
2719            return ret
2720
2721
2722@overload(np.kron)
2723def kron_impl(a, b):
2724
2725    _check_scalar_or_lt_2d_mat(a, "kron", la_prefix=False)
2726    _check_scalar_or_lt_2d_mat(b, "kron", la_prefix=False)
2727
2728    fix_a = _kron_normaliser_impl(a)
2729    fix_b = _kron_normaliser_impl(b)
2730    ret_c = _kron_return(a, b)
2731
2732    # this is fine because the ufunc for the Hadamard product
2733    # will reject differing dtypes in a and b.
2734    dt = getattr(a, 'dtype', a)
2735
2736    def kron_impl(a, b):
2737
2738        aa = fix_a(a)
2739        bb = fix_b(b)
2740
2741        am = aa.shape[-2]
2742        an = aa.shape[-1]
2743        bm = bb.shape[-2]
2744        bn = bb.shape[-1]
2745
2746        cm = am * bm
2747        cn = an * bn
2748
2749        # allocate c
2750        C = np.empty((cm, cn), dtype=dt)
2751
        # In practice this runs quicker than the more obvious
        # "each element of A multiplied by B and assigned to a block of C"
        # style algorithm.
2755
2756        # loop over rows of A
2757        for i in range(am):
            # compute the row offset into C
2759            rjmp = i * bm
2760            # loop over rows of B
2761            for k in range(bm):
                # compute the row index into C
2763                irjmp = rjmp + k
2764                # slice a given row of B
2765                slc = bb[k, :]
2766                # loop over columns of A
2767                for j in range(an):
2768                    # vectorized assignment of an element of A
2769                    # multiplied by the current row of B into
2770                    # a slice of a row of C
2771                    cjmp = j * bn
2772                    C[irjmp, cjmp:cjmp + bn] = aa[i, j] * slc
2773
2774        return ret_c(a, b, C)
2775
2776    return kron_impl
2777
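# The row/column offset arithmetic in kron_impl realises the usual Kronecker
# block identity; a plain NumPy check (illustrative, assuming 2D inputs A, B):
#
#     C = np.kron(A, B)
#     bm, bn = B.shape
#     # C[i * bm + k, j * bn + l] == A[i, j] * B[k, l]
#     assert np.allclose(C[:bm, :bn], A[0, 0] * B)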