"""
Implementation of linear algebra operations.
"""


import contextlib

from llvmlite import ir

import numpy as np
import operator

from numba.core.imputils import (lower_builtin, impl_ret_borrowed,
                                 impl_ret_new_ref, impl_ret_untracked)
from numba.core.typing import signature
from numba.core.extending import overload, register_jitable
from numba.core import types, cgutils
from numba.core.errors import TypingError
from .arrayobj import make_array, _empty_nd_impl, array_copy
from numba.np import numpy_support as np_support

ll_char = ir.IntType(8)
ll_char_p = ll_char.as_pointer()
ll_void_p = ll_char_p
ll_intc = ir.IntType(32)
ll_intc_p = ll_intc.as_pointer()
intp_t = cgutils.intp_t
ll_intp_p = intp_t.as_pointer()


# fortran int type, this needs to match the F_INT C declaration in
# _lapack.c and is present to accommodate potential future 64bit int
# based LAPACK use.
F_INT_nptype = np.int32
F_INT_nbtype = types.int32

# BLAS kinds as letters
_blas_kinds = {
    types.float32: 's',
    types.float64: 'd',
    types.complex64: 'c',
    types.complex128: 'z',
}


def get_blas_kind(dtype, func_name="<BLAS function>"):
    """
    Return the BLAS type letter ('s', 'd', 'c' or 'z') for the numba
    *dtype*.

    Raises TypeError (mentioning *func_name*) if *dtype* has no BLAS kind.
    """
    kind = _blas_kinds.get(dtype)
    if kind is None:
        raise TypeError("unsupported dtype for %s()" % (func_name,))
    return kind


def ensure_blas():
    """
    Raise ImportError unless scipy's Cython BLAS bindings are importable.
    """
    try:
        import scipy.linalg.cython_blas
    except ImportError as e:
        # Chain the original error so the real import failure stays visible.
        raise ImportError("scipy 0.16+ is required for linear algebra") from e


def ensure_lapack():
    """
    Raise ImportError unless scipy's Cython LAPACK bindings are importable.
    """
    try:
        import scipy.linalg.cython_lapack
    except ImportError as e:
        # Chain the original error so the real import failure stays visible.
        raise ImportError("scipy 0.16+ is required for linear algebra") from e


def make_constant_slot(context, builder, ty, val):
    """
    Store the constant *val* (of numba type *ty*) in a stack slot and
    return a pointer to it. The BLAS call helpers below pass scalars such
    as alpha/beta by pointer.
    """
    const = context.get_constant_generic(builder, ty, val)
    return cgutils.alloca_once_value(builder, const)


class _BLAS:
    """
    Functions to return type signatures for wrapped
    BLAS functions.

    Each classmethod returns a ``types.ExternalFunction`` bound to the
    named C wrapper symbol, specialised for *dtype*. Every wrapper returns
    a C ``int`` status.
    """

    def __init__(self):
        ensure_blas()

    @classmethod
    def numba_xxnrm2(cls, dtype):
        # The norm of a complex vector is real, so the result pointer uses
        # the underlying float type when there is one.
        rtype = getattr(dtype, "underlying_float", dtype)
        sig = types.intc(types.char,             # kind
                         types.intp,             # n
                         types.CPointer(dtype),  # x
                         types.intp,             # incx
                         types.CPointer(rtype))  # returned

        return types.ExternalFunction("numba_xxnrm2", sig)

    @classmethod
    def numba_xxgemm(cls, dtype):
        sig = types.intc(
            types.char,             # kind
            types.char,             # transa
            types.char,             # transb
            types.intp,             # m
            types.intp,             # n
            types.intp,             # k
            types.CPointer(dtype),  # alpha
            types.CPointer(dtype),  # a
            types.intp,             # lda
            types.CPointer(dtype),  # b
            types.intp,             # ldb
            types.CPointer(dtype),  # beta
            types.CPointer(dtype),  # c
            types.intp              # ldc
        )
        return types.ExternalFunction("numba_xxgemm", sig)


class _LAPACK:
    """
    Functions to return type signatures for wrapped
    LAPACK functions.

    Each classmethod returns a ``types.ExternalFunction`` bound to the
    named C wrapper symbol, specialised for *dtype*. Every wrapper returns
    a C ``int`` status.
    """

    def __init__(self):
        ensure_lapack()

    @classmethod
    def numba_xxgetrf(cls, dtype):
        sig = types.intc(types.char,                   # kind
                         types.intp,                   # m
                         types.intp,                   # n
                         types.CPointer(dtype),        # a
                         types.intp,                   # lda
                         types.CPointer(F_INT_nbtype)  # ipiv
                         )
        return types.ExternalFunction("numba_xxgetrf", sig)

    @classmethod
    def numba_ez_xxgetri(cls, dtype):
        sig = types.intc(types.char,                   # kind
                         types.intp,                   # n
                         types.CPointer(dtype),        # a
                         types.intp,                   # lda
                         types.CPointer(F_INT_nbtype)  # ipiv
                         )
        return types.ExternalFunction("numba_ez_xxgetri", sig)

    @classmethod
    def numba_ez_rgeev(cls, dtype):
        sig = types.intc(types.char,             # kind
                         types.char,             # jobvl
                         types.char,             # jobvr
                         types.intp,             # n
                         types.CPointer(dtype),  # a
                         types.intp,             # lda
                         types.CPointer(dtype),  # wr
                         types.CPointer(dtype),  # wi
                         types.CPointer(dtype),  # vl
                         types.intp,             # ldvl
                         types.CPointer(dtype),  # vr
                         types.intp              # ldvr
                         )
        return types.ExternalFunction("numba_ez_rgeev", sig)

    @classmethod
    def numba_ez_cgeev(cls, dtype):
        sig = types.intc(types.char,             # kind
                         types.char,             # jobvl
                         types.char,             # jobvr
                         types.intp,             # n
                         types.CPointer(dtype),  # a
                         types.intp,             # lda
                         types.CPointer(dtype),  # w
                         types.CPointer(dtype),  # vl
                         types.intp,             # ldvl
                         types.CPointer(dtype),  # vr
                         types.intp              # ldvr
                         )
        return types.ExternalFunction("numba_ez_cgeev", sig)

    @classmethod
    def numba_ez_xxxevd(cls, dtype):
        # Eigenvalues are written with the underlying float type when
        # dtype is complex.
        wtype = getattr(dtype, "underlying_float", dtype)
        sig = types.intc(types.char,             # kind
                         types.char,             # jobz
                         types.char,             # uplo
                         types.intp,             # n
                         types.CPointer(dtype),  # a
                         types.intp,             # lda
                         types.CPointer(wtype),  # w
                         )
        return types.ExternalFunction("numba_ez_xxxevd", sig)

    @classmethod
    def numba_xxpotrf(cls, dtype):
        sig = types.intc(types.char,             # kind
                         types.char,             # uplo
                         types.intp,             # n
                         types.CPointer(dtype),  # a
                         types.intp              # lda
                         )
        return types.ExternalFunction("numba_xxpotrf", sig)

    @classmethod
    def numba_ez_gesdd(cls, dtype):
        # Singular values are written with the underlying float type when
        # dtype is complex.
        stype = getattr(dtype, "underlying_float", dtype)
        sig = types.intc(
            types.char,             # kind
            types.char,             # jobz
            types.intp,             # m
            types.intp,             # n
            types.CPointer(dtype),  # a
            types.intp,             # lda
            types.CPointer(stype),  # s
            types.CPointer(dtype),  # u
            types.intp,             # ldu
            types.CPointer(dtype),  # vt
            types.intp              # ldvt
        )

        return types.ExternalFunction("numba_ez_gesdd", sig)

    @classmethod
    def numba_ez_geqrf(cls, dtype):
        sig = types.intc(
            types.char,             # kind
            types.intp,             # m
            types.intp,             # n
            types.CPointer(dtype),  # a
            types.intp,             # lda
            types.CPointer(dtype),  # tau
        )
        return types.ExternalFunction("numba_ez_geqrf", sig)

    @classmethod
    def numba_ez_xxgqr(cls, dtype):
        sig = types.intc(
            types.char,             # kind
            types.intp,             # m
            types.intp,             # n
            types.intp,             # k
            types.CPointer(dtype),  # a
            types.intp,             # lda
            types.CPointer(dtype),  # tau
        )
        return types.ExternalFunction("numba_ez_xxgqr", sig)

    @classmethod
    def numba_ez_gelsd(cls, dtype):
        # Singular values are written with the underlying float type when
        # dtype is complex.
        rtype = getattr(dtype, "underlying_float", dtype)
        sig = types.intc(
            types.char,                 # kind
            types.intp,                 # m
            types.intp,                 # n
            types.intp,                 # nrhs
            types.CPointer(dtype),      # a
            types.intp,                 # lda
            types.CPointer(dtype),      # b
            types.intp,                 # ldb
            types.CPointer(rtype),      # S
            types.float64,              # rcond
            types.CPointer(types.intc)  # rank
        )
        return types.ExternalFunction("numba_ez_gelsd", sig)

    @classmethod
    def numba_xgesv(cls, dtype):
        sig = types.intc(
            types.char,                    # kind
            types.intp,                    # n
            types.intp,                    # nhrs
            types.CPointer(dtype),         # a
            types.intp,                    # lda
            types.CPointer(F_INT_nbtype),  # ipiv
            types.CPointer(dtype),         # b
            types.intp                     # ldb
        )
        return types.ExternalFunction("numba_xgesv", sig)


@contextlib.contextmanager
def make_contiguous(context, builder, sig, args):
    """
    Ensure that all array arguments are contiguous, if necessary by
    copying them.
    A new (sig, args) tuple is yielded.

    Arrays already in 'C' or 'F' layout are passed through unchanged;
    other arrays are copied to 'C' layout. The temporary copies are
    decref'd once the with-body finishes.
    """
    newtys = []
    newargs = []
    copies = []
    for ty, val in zip(sig.args, args):
        if not isinstance(ty, types.Array) or ty.layout in 'CF':
            newty, newval = ty, val
        else:
            newty = ty.copy(layout='C')
            copysig = signature(newty, ty)
            newval = array_copy(context, builder, copysig, (val,))
            copies.append((newty, newval))
        newtys.append(newty)
        newargs.append(newval)
    yield signature(sig.return_type, *newtys), tuple(newargs)
    for ty, val in copies:
        context.nrt.decref(builder, ty, val)


def check_c_int(context, builder, n):
    """
    Check whether *n* fits in a C `int`.

    Emits code raising OverflowError at runtime when *n* > 2**31 - 1.
    """
    _maxint = 2**31 - 1

    def impl(n):
        if n > _maxint:
            raise OverflowError("array size too large to fit in C int")

    context.compile_internal(builder, impl,
                             signature(types.none, types.intp), (n,))


def _check_wrapper_return(context, builder, res, libname):
    """
    Emit code that aborts the process with a fatal error when *res*, the
    integer status returned by a *libname* wrapper, is non-zero.
    """
    with builder.if_then(cgutils.is_not_null(builder, res), likely=False):
        # Those errors shouldn't happen, it's easier to just abort the process
        pyapi = context.get_python_api(builder)
        pyapi.gil_ensure()
        pyapi.fatal_error("%s wrapper returned with an error" % (libname,))


def check_blas_return(context, builder, res):
    """
    Check the integer error return from one of the BLAS wrappers in
    _helperlib.c.
    """
    _check_wrapper_return(context, builder, res, "BLAS")


def check_lapack_return(context, builder, res):
    """
    Check the integer error return from one of the LAPACK wrappers in
    _helperlib.c.
    """
    _check_wrapper_return(context, builder, res, "LAPACK")


def call_xxdot(context, builder, conjugate, dtype,
               n, a_data, b_data, out_data):
    """
    Call the BLAS vector * vector product function for the given arguments.

    *conjugate* is a Python bool baked into the call as a char flag; the
    scalar result is written through *out_data*.
    """
    fnty = ir.FunctionType(ir.IntType(32),
                           [ll_char, ll_char, intp_t,  # kind, conjugate, n
                            ll_void_p, ll_void_p, ll_void_p,  # a, b, out
                            ])
    fn = builder.module.get_or_insert_function(fnty, name="numba_xxdot")

    kind = get_blas_kind(dtype)
    kind_val = ir.Constant(ll_char, ord(kind))
    conjugate = ir.Constant(ll_char, int(conjugate))

    res = builder.call(fn, (kind_val, conjugate, n,
                            builder.bitcast(a_data, ll_void_p),
                            builder.bitcast(b_data, ll_void_p),
                            builder.bitcast(out_data, ll_void_p)))
    check_blas_return(context, builder, res)


def call_xxgemv(context, builder, do_trans,
                m_type, m_shapes, m_data, v_data, out_data):
    """
    Call the BLAS matrix * vector product function for the given arguments.

    The matrix is passed in column-major terms: an 'F' array of shape
    (m, n) is used as-is, while any other (C-contiguous) array is viewed
    as its column-major transpose, so its dimensions are swapped. The
    caller picks *do_trans* to account for that.
    """
    fnty = ir.FunctionType(ir.IntType(32),
                           [ll_char, ll_char,                 # kind, trans
                            intp_t, intp_t,                   # m, n
                            ll_void_p, ll_void_p, intp_t,     # alpha, a, lda
                            ll_void_p, ll_void_p, ll_void_p,  # x, beta, y
                            ])
    fn = builder.module.get_or_insert_function(fnty, name="numba_xxgemv")

    dtype = m_type.dtype
    alpha = make_constant_slot(context, builder, dtype, 1.0)
    beta = make_constant_slot(context, builder, dtype, 0.0)

    if m_type.layout == 'F':
        m, n = m_shapes
        lda = m_shapes[0]
    else:
        n, m = m_shapes
        lda = m_shapes[1]

    kind = get_blas_kind(dtype)
    kind_val = ir.Constant(ll_char, ord(kind))
    trans = ir.Constant(ll_char, ord('t') if do_trans else ord('n'))

    res = builder.call(fn, (kind_val, trans, m, n,
                            builder.bitcast(alpha, ll_void_p),
                            builder.bitcast(m_data, ll_void_p), lda,
                            builder.bitcast(v_data, ll_void_p),
                            builder.bitcast(beta, ll_void_p),
                            builder.bitcast(out_data, ll_void_p)))
    check_blas_return(context, builder, res)


def call_xxgemm(context, builder,
                x_type, x_shapes, x_data,
                y_type, y_shapes, y_data,
                out_type, out_shapes, out_data):
    """
    Call the BLAS matrix * matrix product function for the given arguments.

    out = x @ y is computed in column-major terms as out^T = y^T @ x^T,
    which is why y is passed as BLAS's "a" operand, x as "b", and the
    dimensions are passed as (n, m, k).
    """
    fnty = ir.FunctionType(ir.IntType(32),
                           [ll_char,                       # kind
                            ll_char, ll_char,              # transa, transb
                            intp_t, intp_t, intp_t,        # m, n, k
                            ll_void_p, ll_void_p, intp_t,  # alpha, a, lda
                            ll_void_p, intp_t, ll_void_p,  # b, ldb, beta
                            ll_void_p, intp_t,             # c, ldc
                            ])
    fn = builder.module.get_or_insert_function(fnty, name="numba_xxgemm")

    m, k = x_shapes
    _k, n = y_shapes
    dtype = x_type.dtype
    alpha = make_constant_slot(context, builder, dtype, 1.0)
    beta = make_constant_slot(context, builder, dtype, 0.0)

    trans = ir.Constant(ll_char, ord('t'))
    notrans = ir.Constant(ll_char, ord('n'))

    def get_array_param(ty, shapes, data):
        return (
            # Transpose if layout different from result's
            notrans if ty.layout == out_type.layout else trans,
            # Size of the inner dimension in physical array order
            shapes[1] if ty.layout == 'C' else shapes[0],
            # The data pointer, unit-less
            builder.bitcast(data, ll_void_p),
        )

    transa, lda, data_a = get_array_param(y_type, y_shapes, y_data)
    transb, ldb, data_b = get_array_param(x_type, x_shapes, x_data)
    _, ldc, data_c = get_array_param(out_type, out_shapes, out_data)

    kind = get_blas_kind(dtype)
    kind_val = ir.Constant(ll_char, ord(kind))

    res = builder.call(fn, (kind_val, transa, transb, n, m, k,
                            builder.bitcast(alpha, ll_void_p), data_a, lda,
                            data_b, ldb, builder.bitcast(beta, ll_void_p),
                            data_c, ldc))
    check_blas_return(context, builder, res)


def dot_2_mm(context, builder, sig, args):
    """
    np.dot(matrix, matrix)
    """
    def dot_impl(a, b):
        m, k = a.shape
        _k, n = b.shape
        if k == 0:
            return np.zeros((m, n), a.dtype)
        out = np.empty((m, n), a.dtype)
        return np.dot(a, b, out)

    res = context.compile_internal(builder, dot_impl, sig, args)
    return impl_ret_new_ref(context, builder, sig.return_type, res)


def dot_2_vm(context, builder, sig, args):
    """
    np.dot(vector, matrix)
    """
    def dot_impl(a, b):
        m, = a.shape
        _m, n = b.shape
        if m == 0:
            return np.zeros((n, ), a.dtype)
        out = np.empty((n, ), a.dtype)
        return np.dot(a, b, out)

    res = context.compile_internal(builder, dot_impl, sig, args)
    return impl_ret_new_ref(context, builder, sig.return_type, res)


def dot_2_mv(context, builder, sig, args):
    """
    np.dot(matrix, vector)
    """
    def dot_impl(a, b):
        m, n = a.shape
        _n, = b.shape
        if n == 0:
            return np.zeros((m, ), a.dtype)
        out = np.empty((m, ), a.dtype)
        return np.dot(a, b, out)

    res = context.compile_internal(builder, dot_impl, sig, args)
    return impl_ret_new_ref(context, builder, sig.return_type, res)


def dot_2_vv(context, builder, sig, args, conjugate=False):
    """
    np.dot(vector, vector)
    np.vdot(vector, vector)

    The scalar result is computed by BLAS into a stack slot and loaded.
    """
    aty, bty = sig.args
    dtype = sig.return_type
    a = make_array(aty)(context, builder, args[0])
    b = make_array(bty)(context, builder, args[1])
    n, = cgutils.unpack_tuple(builder, a.shape)

    def check_args(a, b):
        m, = a.shape
        n, = b.shape
        if m != n:
            raise ValueError("incompatible array sizes for np.dot(a, b) "
                             "(vector * vector)")

    context.compile_internal(builder, check_args,
                             signature(types.none, *sig.args), args)
    check_c_int(context, builder, n)

    out = cgutils.alloca_once(builder, context.get_value_type(dtype))
    call_xxdot(context, builder, conjugate, dtype, n, a.data, b.data, out)
    return builder.load(out)


@lower_builtin(np.dot, types.Array, types.Array)
def dot_2(context, builder, sig, args):
    """
    np.dot(a, b)
    a @ b

    Dispatches on the operands' dimensionality after making them
    contiguous.
    """
    ensure_blas()

    with make_contiguous(context, builder, sig, args) as (sig, args):
        ndims = [x.ndim for x in sig.args[:2]]
        if ndims == [2, 2]:
            return dot_2_mm(context, builder, sig, args)
        elif ndims == [2, 1]:
            return dot_2_mv(context, builder, sig, args)
        elif ndims == [1, 2]:
            return dot_2_vm(context, builder, sig, args)
        elif ndims == [1, 1]:
            return dot_2_vv(context, builder, sig, args)
        else:
            assert 0


lower_builtin(operator.matmul, types.Array, types.Array)(dot_2)


@lower_builtin(np.vdot, types.Array, types.Array)
def vdot(context, builder, sig, args):
    """
    np.vdot(a, b)
    """
    ensure_blas()

    with make_contiguous(context, builder, sig, args) as (sig, args):
        return dot_2_vv(context, builder, sig, args, conjugate=True)


def dot_3_vm_check_args(a, b, out):
    """Validate shapes for np.dot(vector, matrix, out)."""
    m, = a.shape
    _m, n = b.shape
    if m != _m:
        raise ValueError("incompatible array sizes for "
                         "np.dot(a, b) (vector * matrix)")
    if out.shape != (n,):
        raise ValueError("incompatible output array size for "
                         "np.dot(a, b, out) (vector * matrix)")


def dot_3_mv_check_args(a, b, out):
    """Validate shapes for np.dot(matrix, vector, out)."""
    m, _n = a.shape
    n, = b.shape
    if n != _n:
        raise ValueError("incompatible array sizes for np.dot(a, b) "
                         "(matrix * vector)")
    if out.shape != (m,):
        raise ValueError("incompatible output array size for "
                         "np.dot(a, b, out) (matrix * vector)")


def dot_3_vm(context, builder, sig, args):
    """
    np.dot(vector, matrix, out)
    np.dot(matrix, vector, out)

    If either the vector or the output is empty, *out* is zero-filled
    instead of calling BLAS.
    """
    xty, yty, outty = sig.args
    assert outty == sig.return_type

    x = make_array(xty)(context, builder, args[0])
    y = make_array(yty)(context, builder, args[1])
    out = make_array(outty)(context, builder, args[2])
    x_shapes = cgutils.unpack_tuple(builder, x.shape)
    y_shapes = cgutils.unpack_tuple(builder, y.shape)
    if xty.ndim < yty.ndim:
        # Vector * matrix
        # Asked for x * y, we will compute y.T * x
        mty = yty
        m_shapes = y_shapes
        v_len = x_shapes[0]
        # Length of the output vector (number of matrix columns)
        out_len = m_shapes[1]
        do_trans = yty.layout == 'F'
        m_data, v_data = y.data, x.data
        check_args = dot_3_vm_check_args
    else:
        # Matrix * vector
        # We will compute x * y
        mty = xty
        m_shapes = x_shapes
        v_len = y_shapes[0]
        # Length of the output vector (number of matrix rows)
        out_len = m_shapes[0]
        do_trans = xty.layout == 'C'
        m_data, v_data = x.data, y.data
        check_args = dot_3_mv_check_args

    context.compile_internal(builder, check_args,
                             signature(types.none, *sig.args), args)
    for val in m_shapes:
        check_c_int(context, builder, val)

    zero = context.get_constant(types.intp, 0)
    both_empty = builder.icmp_signed('==', v_len, zero)
    out_empty = builder.icmp_signed('==', out_len, zero)
    is_empty = builder.or_(both_empty, out_empty)
    with builder.if_else(is_empty, likely=False) as (empty, nonempty):
        with empty:
            cgutils.memset(builder, out.data,
                           builder.mul(out.itemsize, out.nitems), 0)
        with nonempty:
            call_xxgemv(context, builder, do_trans, mty, m_shapes, m_data,
                        v_data, out.data)

    return impl_ret_borrowed(context, builder, sig.return_type,
                             out._getvalue())


def dot_3_mm(context, builder, sig, args):
    """
    np.dot(matrix, matrix, out)

    Zero-fills *out* when any dimension is zero; otherwise degenerates to
    dot/gemv when an operand is really a vector, and uses gemm for the
    general case.
    """
    xty, yty, outty = sig.args
    assert outty == sig.return_type
    dtype = xty.dtype

    x = make_array(xty)(context, builder, args[0])
    y = make_array(yty)(context, builder, args[1])
    out = make_array(outty)(context, builder, args[2])
    x_shapes = cgutils.unpack_tuple(builder, x.shape)
    y_shapes = cgutils.unpack_tuple(builder, y.shape)
    out_shapes = cgutils.unpack_tuple(builder, out.shape)
    m, k = x_shapes
    _k, n = y_shapes

    # The only case Numpy supports
    assert outty.layout == 'C'

    def check_args(a, b, out):
        m, k = a.shape
        _k, n = b.shape
        if k != _k:
            raise ValueError("incompatible array sizes for np.dot(a, b) "
                             "(matrix * matrix)")
        if out.shape != (m, n):
            raise ValueError("incompatible output array size for "
                             "np.dot(a, b, out) (matrix * matrix)")

    context.compile_internal(builder, check_args,
                             signature(types.none, *sig.args), args)

    check_c_int(context, builder, m)
    check_c_int(context, builder, k)
    check_c_int(context, builder, n)

    x_data = x.data
    y_data = y.data
    out_data = out.data

    # If eliminated dimension is zero, set all entries to zero and return
    zero = context.get_constant(types.intp, 0)
    both_empty = builder.icmp_signed('==', k, zero)
    x_empty = builder.icmp_signed('==', m, zero)
    y_empty = builder.icmp_signed('==', n, zero)
    is_empty = builder.or_(both_empty, builder.or_(x_empty, y_empty))
    with builder.if_else(is_empty, likely=False) as (empty, nonempty):
        with empty:
            cgutils.memset(builder, out.data,
                           builder.mul(out.itemsize, out.nitems), 0)
        with nonempty:
            # Check if any of the operands is really a 1-d vector represented
            # as a (1, k) or (k, 1) 2-d array.  In those cases, it is pessimal
            # to call the generic matrix * matrix product BLAS function.
            one = context.get_constant(types.intp, 1)
            is_left_vec = builder.icmp_signed('==', m, one)
            is_right_vec = builder.icmp_signed('==', n, one)

            with builder.if_else(is_right_vec) as (r_vec, r_mat):
                with r_vec:
                    with builder.if_else(is_left_vec) as (v_v, m_v):
                        with v_v:
                            # V * V
                            call_xxdot(context, builder, False, dtype,
                                       k, x_data, y_data, out_data)
                        with m_v:
                            # M * V
                            do_trans = xty.layout == outty.layout
                            call_xxgemv(context, builder, do_trans,
                                        xty, x_shapes, x_data, y_data,
                                        out_data)
                with r_mat:
                    with builder.if_else(is_left_vec) as (v_m, m_m):
                        with v_m:
                            # V * M
                            do_trans = yty.layout != outty.layout
                            call_xxgemv(context, builder, do_trans,
                                        yty, y_shapes, y_data, x_data,
                                        out_data)
                        with m_m:
                            # M * M
                            call_xxgemm(context, builder,
                                        xty, x_shapes, x_data,
                                        yty, y_shapes, y_data,
                                        outty, out_shapes, out_data)

    return impl_ret_borrowed(context, builder, sig.return_type,
                             out._getvalue())


@lower_builtin(np.dot, types.Array, types.Array,
               types.Array)
def dot_3(context, builder, sig, args):
    """
    np.dot(a, b, out)
    """
    ensure_blas()

    with make_contiguous(context, builder, sig, args) as (sig, args):
        ndims = set(x.ndim for x in sig.args[:2])
        if ndims == set([2]):
            return dot_3_mm(context, builder, sig, args)
        elif ndims == set([1, 2]):
            return dot_3_vm(context, builder, sig, args)
        else:
            assert 0


# External C hook, called below when a LAPACK wrapper reports a negative
# status.
fatal_error_sig = types.intc()
fatal_error_func = types.ExternalFunction("numba_fatal_error", fatal_error_sig)


@register_jitable
def _check_finite_matrix(a):
    """Raise LinAlgError if any element of *a* is inf or NaN."""
    for v in np.nditer(a):
        if not np.isfinite(v.item()):
            raise np.linalg.LinAlgError(
                "Array must not contain infs or NaNs.")


def _check_linalg_matrix(a, func_name, la_prefix=True):
    """
    Typing-time check that *a* (optionally wrapped in types.Optional) is a
    2-D float or complex array; raises TypingError otherwise.
    """
    # la_prefix is present as some functions, e.g. np.trace()
    # are documented under "linear algebra" but aren't in the
    # module
    prefix = "np.linalg" if la_prefix else "np"
    interp = (prefix, func_name)
    # Unpack optional type
    if isinstance(a, types.Optional):
        a = a.type
    if not isinstance(a, types.Array):
        msg = "%s.%s() only supported for array types" % interp
        raise TypingError(msg, highlighting=False)
    if not a.ndim == 2:
        msg = "%s.%s() only supported on 2-D arrays." % interp
        raise TypingError(msg, highlighting=False)
    if not isinstance(a.dtype, (types.Float, types.Complex)):
        msg = ("%s.%s() only supported on "
               "float and complex arrays." % interp)
        raise TypingError(msg, highlighting=False)


def _check_homogeneous_types(func_name, *tys):
    """
    Typing-time check that all array types in *tys* share one dtype;
    raises TypingError otherwise.
    """
    # The varargs are named ``tys`` so they don't shadow the ``types``
    # module imported at the top of this file.
    t0 = tys[0].dtype
    for t in tys[1:]:
        if t.dtype != t0:
            msg = ("np.linalg.%s() only supports inputs that have "
                   "homogeneous dtypes." % func_name)
            raise TypingError(msg, highlighting=False)


def _copy_to_fortran_order():
    """Stub; its jitted implementation is ol_copy_to_fortran_order."""
    pass


@overload(_copy_to_fortran_order)
def ol_copy_to_fortran_order(a):
    # This function copies the array 'a' into a new array with fortran order.
    # This exists because the copy routines don't take order flags yet.
    F_layout = a.layout == 'F'
    A_layout = a.layout == 'A'

    def impl(a):
        if F_layout:
            # it's F ordered at compile time, just copy
            acpy = np.copy(a)
        elif A_layout:
            # decide based on runtime value
            flag_f = a.flags.f_contiguous
            if flag_f:
                # it's already F ordered, so copy but in a round about way to
                # ensure that the copy is also F ordered
                acpy = np.copy(a.T).T
            else:
                # it's something else ordered, so let asfortranarray deal with
                # copying and making it fortran ordered
                acpy = np.asfortranarray(a)
        else:
            # it's C ordered at compile time, asfortranarray it.
            acpy = np.asfortranarray(a)
        return acpy
    return impl


@register_jitable
def _inv_err_handler(r):
    """
    Handle the status *r* of a getrf/getri wrapper: negative aborts the
    process, positive raises LinAlgError (singular matrix).
    """
    if r != 0:
        if r < 0:
            fatal_error_func()
            assert 0   # unreachable
        if r > 0:
            raise np.linalg.LinAlgError(
                "Matrix is singular to machine precision.")


@register_jitable
def _dummy_liveness_func(a):
    """pass a list of variables to be preserved through dead code elimination"""
    return a[0]


@overload(np.linalg.inv)
def inv_impl(a):
    """Overload of np.linalg.inv via LU factorisation (getrf + getri)."""
    ensure_lapack()

    _check_linalg_matrix(a, "inv")

    numba_xxgetrf = _LAPACK().numba_xxgetrf(a.dtype)

    numba_xxgetri = _LAPACK().numba_ez_xxgetri(a.dtype)

    kind = ord(get_blas_kind(a.dtype, "inv"))

    def inv_impl(a):
        n = a.shape[-1]
        if a.shape[-2] != n:
            msg = "Last 2 dimensions of the array must be square."
            raise np.linalg.LinAlgError(msg)

        _check_finite_matrix(a)

        acpy = _copy_to_fortran_order(a)

        if n == 0:
            return acpy

        ipiv = np.empty(n, dtype=F_INT_nptype)

        r = numba_xxgetrf(kind, n, n, acpy.ctypes, n, ipiv.ctypes)
        _inv_err_handler(r)

        r = numba_xxgetri(kind, n, acpy.ctypes, n, ipiv.ctypes)
        _inv_err_handler(r)

        # help liveness analysis
        _dummy_liveness_func([acpy.size, ipiv.size])
        return acpy

    return inv_impl


@register_jitable
def _handle_err_maybe_convergence_problem(r):
    """
    Handle a LAPACK wrapper status *r*: negative aborts the process,
    positive raises ValueError (failure to converge).
    """
    if r != 0:
        if r < 0:
            fatal_error_func()
            assert 0   # unreachable
        if r > 0:
            raise ValueError("Internal algorithm failed to converge.")


def _check_linalg_1_or_2d_matrix(a, func_name, la_prefix=True):
    """
    Typing-time check that *a* is a 1-D or 2-D float or complex array;
    raises TypingError otherwise.
    """
    # la_prefix is present as some functions, e.g. np.trace()
    # are documented under "linear algebra" but aren't in the
    # module
    prefix = "np.linalg" if la_prefix else "np"
    interp = (prefix, func_name)
    # checks that a matrix is 1 or 2D
    # highlighting=False matches _check_linalg_matrix's error style.
    if not isinstance(a, types.Array):
        raise TypingError("%s.%s() only supported for array types "
                          % interp, highlighting=False)
    if not a.ndim <= 2:
        raise TypingError("%s.%s() only supported on 1 and 2-D arrays "
                          % interp, highlighting=False)
    if not isinstance(a.dtype, (types.Float, types.Complex)):
        raise TypingError("%s.%s() only supported on "
                          "float and complex arrays." % interp,
                          highlighting=False)


@overload(np.linalg.cholesky)
def cho_impl(a):
    """Overload of np.linalg.cholesky via the potrf wrapper."""
    ensure_lapack()

    _check_linalg_matrix(a, "cholesky")

    numba_xxpotrf = _LAPACK().numba_xxpotrf(a.dtype)

    kind = ord(get_blas_kind(a.dtype, "cholesky"))
    UP = ord('U')
    LO = ord('L')

    def cho_impl(a):
        n = a.shape[-1]
        if a.shape[-2] != n:
            msg = "Last 2 dimensions of the array must be square."
            raise np.linalg.LinAlgError(msg)

        # The output is allocated in C order
        out = a.copy()

        if n == 0:
            return out

        # Pass UP since xxpotrf() operates in F order
        # The semantics ensure this works fine
        # (out is really its Hermitian in F order, but UP instructs
        #  xxpotrf to compute the Hermitian of the upper triangle
        #  => they cancel each other)
        r = numba_xxpotrf(kind, UP, n, out.ctypes, n)
        if r != 0:
            if r < 0:
                fatal_error_func()
                assert 0   # unreachable
            if r > 0:
                raise np.linalg.LinAlgError(
                    "Matrix is not positive definite.")
        # Zero out upper triangle, in F order
        for col in range(n):
            out[:col, col] = 0
        return out

    return cho_impl


@overload(np.linalg.eig)
def eig_impl(a):
    """
    Overload of np.linalg.eig; selects the real (rgeev) or complex (cgeev)
    implementation from a.dtype at typing time.
    """
    ensure_lapack()

    _check_linalg_matrix(a, "eig")

    numba_ez_rgeev = _LAPACK().numba_ez_rgeev(a.dtype)
    numba_ez_cgeev = _LAPACK().numba_ez_cgeev(a.dtype)

    kind = ord(get_blas_kind(a.dtype, "eig"))

    JOBVL = ord('N')
    JOBVR = ord('V')

    def real_eig_impl(a):
        """
        eig() implementation for real arrays.
        """
        n = a.shape[-1]
        if a.shape[-2] != n:
            msg = "Last 2 dimensions of the array must be square."
            raise np.linalg.LinAlgError(msg)

        _check_finite_matrix(a)

        acpy = _copy_to_fortran_order(a)

        ldvl = 1
        ldvr = n
        wr = np.empty(n, dtype=a.dtype)
        wi = np.empty(n, dtype=a.dtype)
        vl = np.empty((n, ldvl), dtype=a.dtype)
        vr = np.empty((n, ldvr), dtype=a.dtype)

        if n == 0:
            return (wr, vr.T)

        r = numba_ez_rgeev(kind,
                           JOBVL,
                           JOBVR,
                           n,
                           acpy.ctypes,
                           n,
                           wr.ctypes,
                           wi.ctypes,
                           vl.ctypes,
                           ldvl,
                           vr.ctypes,
                           ldvr)
        _handle_err_maybe_convergence_problem(r)

        # By design numba does not support dynamic return types, however,
        # Numpy does. Numpy uses this ability in the case of returning
        # eigenvalues/vectors of a real matrix. The return type of
        # np.linalg.eig(), when operating on a matrix in real space
        # depends on the values present in the matrix itself (recalling
        # that eigenvalues are the roots of the characteristic polynomial
        # of the system matrix, which will by construction depend on the
        # values present in the system matrix). As numba cannot handle
        # the case of a runtime decision based domain change relative to
        # the input type, if it is required numba raises as below.
        if np.any(wi):
            raise ValueError(
                "eig() argument must not cause a domain change.")

        # put these in to help with liveness analysis,
        # `.ctypes` doesn't keep the vars alive
        _dummy_liveness_func([acpy.size, vl.size, vr.size, wr.size, wi.size])
        return (wr, vr.T)

    def cmplx_eig_impl(a):
        """
        eig() implementation for complex arrays.
        """
        n = a.shape[-1]
        if a.shape[-2] != n:
            msg = "Last 2 dimensions of the array must be square."
            raise np.linalg.LinAlgError(msg)

        _check_finite_matrix(a)

        acpy = _copy_to_fortran_order(a)

        ldvl = 1
        ldvr = n
        w = np.empty(n, dtype=a.dtype)
        vl = np.empty((n, ldvl), dtype=a.dtype)
        vr = np.empty((n, ldvr), dtype=a.dtype)

        if n == 0:
            return (w, vr.T)

        r = numba_ez_cgeev(kind,
                           JOBVL,
                           JOBVR,
                           n,
                           acpy.ctypes,
                           n,
                           w.ctypes,
                           vl.ctypes,
                           ldvl,
                           vr.ctypes,
                           ldvr)
        _handle_err_maybe_convergence_problem(r)

        # put these in to help with liveness analysis,
        # `.ctypes` doesn't keep the vars alive
        _dummy_liveness_func([acpy.size, vl.size, vr.size, w.size])
        return (w, vr.T)

    if isinstance(a.dtype, types.scalars.Complex):
        return cmplx_eig_impl
    else:
        return real_eig_impl


@overload(np.linalg.eigvals)
def eigvals_impl(a):
    """
    Overload of np.linalg.eigvals; selects the real (rgeev) or complex
    (cgeev) implementation from a.dtype at typing time. No eigenvectors
    are requested (JOBVL = JOBVR = 'N').
    """
    ensure_lapack()

    _check_linalg_matrix(a, "eigvals")

    numba_ez_rgeev = _LAPACK().numba_ez_rgeev(a.dtype)
    numba_ez_cgeev = _LAPACK().numba_ez_cgeev(a.dtype)

    kind = ord(get_blas_kind(a.dtype, "eigvals"))

    JOBVL = ord('N')
    JOBVR = ord('N')

    def real_eigvals_impl(a):
        """
        eigvals() implementation for real arrays.
        """
        n = a.shape[-1]
        if a.shape[-2] != n:
            msg = "Last 2 dimensions of the array must be square."
            raise np.linalg.LinAlgError(msg)

        _check_finite_matrix(a)

        acpy = _copy_to_fortran_order(a)

        ldvl = 1
        ldvr = 1
        wr = np.empty(n, dtype=a.dtype)

        if n == 0:
            return wr

        wi = np.empty(n, dtype=a.dtype)

        # not referenced but need setting for MKL null check
        vl = np.empty((1), dtype=a.dtype)
        vr = np.empty((1), dtype=a.dtype)

        r = numba_ez_rgeev(kind,
                           JOBVL,
                           JOBVR,
                           n,
                           acpy.ctypes,
                           n,
                           wr.ctypes,
                           wi.ctypes,
                           vl.ctypes,
                           ldvl,
                           vr.ctypes,
                           ldvr)
        _handle_err_maybe_convergence_problem(r)

        # By design numba does not support dynamic return types, however,
        # Numpy does. Numpy uses this ability in the case of returning
        # eigenvalues/vectors of a real matrix. The return type of
        # np.linalg.eigvals(), when operating on a matrix in real space
        # depends on the values present in the matrix itself (recalling
        # that eigenvalues are the roots of the characteristic polynomial
        # of the system matrix, which will by construction depend on the
        # values present in the system matrix). As numba cannot handle
        # the case of a runtime decision based domain change relative to
        # the input type, if it is required numba raises as below.
        if np.any(wi):
            raise ValueError(
                "eigvals() argument must not cause a domain change.")

        # put these in to help with liveness analysis,
        # `.ctypes` doesn't keep the vars alive
        _dummy_liveness_func([acpy.size, vl.size, vr.size, wr.size, wi.size])
        return wr

    def cmplx_eigvals_impl(a):
        """
        eigvals() implementation for complex arrays.
        """
        n = a.shape[-1]
        if a.shape[-2] != n:
            msg = "Last 2 dimensions of the array must be square."
            raise np.linalg.LinAlgError(msg)

        _check_finite_matrix(a)

        acpy = _copy_to_fortran_order(a)

        ldvl = 1
        ldvr = 1
        w = np.empty(n, dtype=a.dtype)

        if n == 0:
            return w

        vl = np.empty((1), dtype=a.dtype)
        vr = np.empty((1), dtype=a.dtype)

        r = numba_ez_cgeev(kind,
                           JOBVL,
                           JOBVR,
                           n,
                           acpy.ctypes,
                           n,
                           w.ctypes,
                           vl.ctypes,
                           ldvl,
                           vr.ctypes,
                           ldvr)
        _handle_err_maybe_convergence_problem(r)

        # put these in to help with liveness analysis,
        # `.ctypes` doesn't keep the vars alive
        _dummy_liveness_func([acpy.size, vl.size, vr.size, w.size])
        return w

    if isinstance(a.dtype, types.scalars.Complex):
        return cmplx_eigvals_impl
    else:
        return real_eigvals_impl


@overload(np.linalg.eigh)
def eigh_impl(a):
    """
    Overload of np.linalg.eigh via the xxxevd wrapper, using the lower
    triangle (UPLO = 'L') and computing eigenvectors (JOBZ = 'V').
    """
    ensure_lapack()

    _check_linalg_matrix(a, "eigh")

    # convert typing floats to numpy floats for use in the impl
    w_type = getattr(a.dtype, "underlying_float", a.dtype)
    w_dtype = np_support.as_dtype(w_type)

    numba_ez_xxxevd = _LAPACK().numba_ez_xxxevd(a.dtype)

    kind = ord(get_blas_kind(a.dtype, "eigh"))

    JOBZ = ord('V')
    UPLO = ord('L')

    def eigh_impl(a):
        n = a.shape[-1]

        if a.shape[-2] != n:
            msg = "Last 2 dimensions of the array must be square."
            raise np.linalg.LinAlgError(msg)

        _check_finite_matrix(a)

        acpy = _copy_to_fortran_order(a)

        w = np.empty(n, dtype=w_dtype)

        if n == 0:
            return (w, acpy)

        r = numba_ez_xxxevd(kind,  # kind
                            JOBZ,  # jobz
                            UPLO,  # uplo
                            n,  # n
                            acpy.ctypes,  # a
                            n,  # lda
                            w.ctypes  # w
                            )
        _handle_err_maybe_convergence_problem(r)

        # help liveness analysis
        _dummy_liveness_func([acpy.size, w.size])
        return (w, acpy)

    return eigh_impl


@overload(np.linalg.eigvalsh)
def eigvalsh_impl(a):
    ensure_lapack()

    _check_linalg_matrix(a, "eigvalsh")

    # convert typing floats to numpy floats for use in the impl
    w_type = getattr(a.dtype, "underlying_float", a.dtype)
    w_dtype = np_support.as_dtype(w_type)

    numba_ez_xxxevd = _LAPACK().numba_ez_xxxevd(a.dtype)

    kind = ord(get_blas_kind(a.dtype, "eigvalsh"))

    JOBZ = ord('N')
    UPLO = ord('L')

    def eigvalsh_impl(a):
        n = a.shape[-1]

        if a.shape[-2] != n:
            msg = "Last 2 dimensions of the array must be square."
1250 raise np.linalg.LinAlgError(msg) 1251 1252 _check_finite_matrix(a) 1253 1254 acpy = _copy_to_fortran_order(a) 1255 1256 w = np.empty(n, dtype=w_dtype) 1257 1258 if n == 0: 1259 return w 1260 1261 r = numba_ez_xxxevd(kind, # kind 1262 JOBZ, # jobz 1263 UPLO, # uplo 1264 n, # n 1265 acpy.ctypes, # a 1266 n, # lda 1267 w.ctypes # w 1268 ) 1269 _handle_err_maybe_convergence_problem(r) 1270 1271 # help liveness analysis 1272 _dummy_liveness_func([acpy.size, w.size]) 1273 return w 1274 1275 return eigvalsh_impl 1276 1277@overload(np.linalg.svd) 1278def svd_impl(a, full_matrices=1): 1279 ensure_lapack() 1280 1281 _check_linalg_matrix(a, "svd") 1282 1283 # convert typing floats to numpy floats for use in the impl 1284 s_type = getattr(a.dtype, "underlying_float", a.dtype) 1285 s_dtype = np_support.as_dtype(s_type) 1286 1287 numba_ez_gesdd = _LAPACK().numba_ez_gesdd(a.dtype) 1288 1289 kind = ord(get_blas_kind(a.dtype, "svd")) 1290 1291 JOBZ_A = ord('A') 1292 JOBZ_S = ord('S') 1293 1294 def svd_impl(a, full_matrices=1): 1295 n = a.shape[-1] 1296 m = a.shape[-2] 1297 1298 if n == 0 or m == 0: 1299 raise np.linalg.LinAlgError("Arrays cannot be empty") 1300 1301 _check_finite_matrix(a) 1302 1303 acpy = _copy_to_fortran_order(a) 1304 1305 ldu = m 1306 minmn = min(m, n) 1307 1308 if full_matrices: 1309 JOBZ = JOBZ_A 1310 ucol = m 1311 ldvt = n 1312 else: 1313 JOBZ = JOBZ_S 1314 ucol = minmn 1315 ldvt = minmn 1316 1317 u = np.empty((ucol, ldu), dtype=a.dtype) 1318 s = np.empty(minmn, dtype=s_dtype) 1319 vt = np.empty((n, ldvt), dtype=a.dtype) 1320 1321 r = numba_ez_gesdd( 1322 kind, # kind 1323 JOBZ, # jobz 1324 m, # m 1325 n, # n 1326 acpy.ctypes, # a 1327 m, # lda 1328 s.ctypes, # s 1329 u.ctypes, # u 1330 ldu, # ldu 1331 vt.ctypes, # vt 1332 ldvt # ldvt 1333 ) 1334 _handle_err_maybe_convergence_problem(r) 1335 1336 # help liveness analysis 1337 _dummy_liveness_func([acpy.size, vt.size, u.size, s.size]) 1338 return (u.T, s, vt.T) 1339 1340 return svd_impl 1341 1342 
@overload(np.linalg.qr)
def qr_impl(a):
    """
    Numba overload of ``np.linalg.qr`` (reduced QR factorisation).

    Two LAPACK wrappers are used. ``numba_ez_geqrf`` stores R in the upper
    triangle of a copy of ``a`` and leaves the elementary reflectors needed
    for Q below the diagonal. ``numba_ez_xxgqr`` then turns those reflectors
    into Q, in place, with orthonormal columns.
    """
    ensure_lapack()
    _check_linalg_matrix(a, "qr")

    lapack = _LAPACK()
    geqrf = lapack.numba_ez_geqrf(a.dtype)
    form_q = lapack.numba_ez_xxgqr(a.dtype)
    kind = ord(get_blas_kind(a.dtype, "qr"))

    def qr_impl(a):
        ncols = a.shape[-1]
        nrows = a.shape[-2]
        if nrows == 0 or ncols == 0:
            raise np.linalg.LinAlgError("Arrays cannot be empty")

        _check_finite_matrix(a)

        # LAPACK destroys its input, so factorise a Fortran-ordered copy.
        q = _copy_to_fortran_order(a)
        k = min(nrows, ncols)
        tau = np.empty(k, dtype=a.dtype)

        status = geqrf(kind, nrows, ncols, q.ctypes, nrows, tau.ctypes)
        if status < 0:
            fatal_error_func()
            assert 0   # unreachable

        # R has shape (k, ncols); building it as the transpose of a
        # C-ordered (ncols, k) buffer gives it Fortran layout.
        r = np.zeros((ncols, k), dtype=a.dtype).T

        # Copy R out of q before q is overwritten with Q: column `col`
        # contributes rows 0..col inside the triangle (col < k) and all k
        # rows in the possible trailing block (col >= k).
        for col in range(ncols):
            for row in range(min(col + 1, k)):
                r[row, col] = q[row, col]

        status = form_q(kind, nrows, k, k, q.ctypes, nrows, tau.ctypes)
        _handle_err_maybe_convergence_problem(status)

        # `.ctypes` does not keep arrays alive; reference them explicitly.
        _dummy_liveness_func([tau.size, q.size])
        return (q[:, :k], r)

    return qr_impl


# helpers and jitted specialisations required for np.linalg.lstsq
# and np.linalg.solve.
These functions have "system" in their name 1422# as a differentiator. 1423 1424def _system_copy_in_b(bcpy, b, nrhs): 1425 """ 1426 Correctly copy 'b' into the 'bcpy' scratch space. 1427 """ 1428 raise NotImplementedError 1429 1430 1431@overload(_system_copy_in_b) 1432def _system_copy_in_b_impl(bcpy, b, nrhs): 1433 if b.ndim == 1: 1434 def oneD_impl(bcpy, b, nrhs): 1435 bcpy[:b.shape[-1], 0] = b 1436 return oneD_impl 1437 else: 1438 def twoD_impl(bcpy, b, nrhs): 1439 bcpy[:b.shape[-2], :nrhs] = b 1440 return twoD_impl 1441 1442 1443def _system_compute_nrhs(b): 1444 """ 1445 Compute the number of right hand sides in the system of equations 1446 """ 1447 raise NotImplementedError 1448 1449 1450@overload(_system_compute_nrhs) 1451def _system_compute_nrhs_impl(b): 1452 if b.ndim == 1: 1453 def oneD_impl(b): 1454 return 1 1455 return oneD_impl 1456 else: 1457 def twoD_impl(b): 1458 return b.shape[-1] 1459 return twoD_impl 1460 1461 1462def _system_check_dimensionally_valid(a, b): 1463 """ 1464 Check that AX=B style system input is dimensionally valid. 1465 """ 1466 raise NotImplementedError 1467 1468 1469@overload(_system_check_dimensionally_valid) 1470def _system_check_dimensionally_valid_impl(a, b): 1471 ndim = b.ndim 1472 if ndim == 1: 1473 def oneD_impl(a, b): 1474 am = a.shape[-2] 1475 bm = b.shape[-1] 1476 if am != bm: 1477 raise np.linalg.LinAlgError( 1478 "Incompatible array sizes, system is not dimensionally valid.") 1479 return oneD_impl 1480 else: 1481 def twoD_impl(a, b): 1482 am = a.shape[-2] 1483 bm = b.shape[-2] 1484 if am != bm: 1485 raise np.linalg.LinAlgError( 1486 "Incompatible array sizes, system is not dimensionally valid.") 1487 return twoD_impl 1488 1489 1490def _system_check_non_empty(a, b): 1491 """ 1492 Check that AX=B style system input is not empty. 
1493 """ 1494 raise NotImplementedError 1495 1496 1497@overload(_system_check_non_empty) 1498def _system_check_non_empty_impl(a, b): 1499 ndim = b.ndim 1500 if ndim == 1: 1501 def oneD_impl(a, b): 1502 am = a.shape[-2] 1503 an = a.shape[-1] 1504 bm = b.shape[-1] 1505 if am == 0 or bm == 0 or an == 0: 1506 raise np.linalg.LinAlgError('Arrays cannot be empty') 1507 return oneD_impl 1508 else: 1509 def twoD_impl(a, b): 1510 am = a.shape[-2] 1511 an = a.shape[-1] 1512 bm = b.shape[-2] 1513 bn = b.shape[-1] 1514 if am == 0 or bm == 0 or an == 0 or bn == 0: 1515 raise np.linalg.LinAlgError('Arrays cannot be empty') 1516 return twoD_impl 1517 1518 1519def _lstsq_residual(b, n, nrhs): 1520 """ 1521 Compute the residual from the 'b' scratch space. 1522 """ 1523 raise NotImplementedError 1524 1525 1526@overload(_lstsq_residual) 1527def _lstsq_residual_impl(b, n, nrhs): 1528 ndim = b.ndim 1529 dtype = b.dtype 1530 real_dtype = np_support.as_dtype(getattr(dtype, "underlying_float", dtype)) 1531 1532 if ndim == 1: 1533 if isinstance(dtype, (types.Complex)): 1534 def cmplx_impl(b, n, nrhs): 1535 res = np.empty((1,), dtype=real_dtype) 1536 res[0] = np.sum(np.abs(b[n:, 0])**2) 1537 return res 1538 return cmplx_impl 1539 else: 1540 def real_impl(b, n, nrhs): 1541 res = np.empty((1,), dtype=real_dtype) 1542 res[0] = np.sum(b[n:, 0]**2) 1543 return res 1544 return real_impl 1545 else: 1546 assert ndim == 2 1547 if isinstance(dtype, (types.Complex)): 1548 def cmplx_impl(b, n, nrhs): 1549 res = np.empty((nrhs), dtype=real_dtype) 1550 for k in range(nrhs): 1551 res[k] = np.sum(np.abs(b[n:, k])**2) 1552 return res 1553 return cmplx_impl 1554 else: 1555 def real_impl(b, n, nrhs): 1556 res = np.empty((nrhs), dtype=real_dtype) 1557 for k in range(nrhs): 1558 res[k] = np.sum(b[n:, k]**2) 1559 return res 1560 return real_impl 1561 1562 1563def _lstsq_solution(b, bcpy, n): 1564 """ 1565 Extract 'x' (the lstsq solution) from the 'bcpy' scratch space. 
1566 Note 'b' is only used to check the system input dimension... 1567 """ 1568 raise NotImplementedError 1569 1570 1571@overload(_lstsq_solution) 1572def _lstsq_solution_impl(b, bcpy, n): 1573 if b.ndim == 1: 1574 def oneD_impl(b, bcpy, n): 1575 return bcpy.T.ravel()[:n] 1576 return oneD_impl 1577 else: 1578 def twoD_impl(b, bcpy, n): 1579 return bcpy[:n, :].copy() 1580 return twoD_impl 1581 1582 1583@overload(np.linalg.lstsq) 1584def lstsq_impl(a, b, rcond=-1.0): 1585 ensure_lapack() 1586 1587 _check_linalg_matrix(a, "lstsq") 1588 1589 # B can be 1D or 2D. 1590 _check_linalg_1_or_2d_matrix(b, "lstsq") 1591 1592 _check_homogeneous_types("lstsq", a, b) 1593 1594 np_dt = np_support.as_dtype(a.dtype) 1595 nb_dt = a.dtype 1596 1597 # convert typing floats to np floats for use in the impl 1598 r_type = getattr(nb_dt, "underlying_float", nb_dt) 1599 real_dtype = np_support.as_dtype(r_type) 1600 1601 # lapack solver 1602 numba_ez_gelsd = _LAPACK().numba_ez_gelsd(a.dtype) 1603 1604 kind = ord(get_blas_kind(nb_dt, "lstsq")) 1605 1606 # The following functions select specialisations based on 1607 # information around 'b', a lot of this effort is required 1608 # as 'b' can be either 1D or 2D, and then there are 1609 # some optimisations available depending on real or complex 1610 # space. 
1611 1612 def lstsq_impl(a, b, rcond=-1.0): 1613 n = a.shape[-1] 1614 m = a.shape[-2] 1615 nrhs = _system_compute_nrhs(b) 1616 1617 # check the systems have no inf or NaN 1618 _check_finite_matrix(a) 1619 _check_finite_matrix(b) 1620 1621 # check the system is not empty 1622 _system_check_non_empty(a, b) 1623 1624 # check the systems are dimensionally valid 1625 _system_check_dimensionally_valid(a, b) 1626 1627 minmn = min(m, n) 1628 maxmn = max(m, n) 1629 1630 # a is destroyed on exit, copy it 1631 acpy = _copy_to_fortran_order(a) 1632 1633 # b is overwritten on exit with the solution, copy allocate 1634 bcpy = np.empty((nrhs, maxmn), dtype=np_dt).T 1635 # specialised copy in due to b being 1 or 2D 1636 _system_copy_in_b(bcpy, b, nrhs) 1637 1638 # Allocate returns 1639 s = np.empty(minmn, dtype=real_dtype) 1640 rank_ptr = np.empty(1, dtype=np.int32) 1641 1642 r = numba_ez_gelsd( 1643 kind, # kind 1644 m, # m 1645 n, # n 1646 nrhs, # nrhs 1647 acpy.ctypes, # a 1648 m, # lda 1649 bcpy.ctypes, # a 1650 maxmn, # ldb 1651 s.ctypes, # s 1652 rcond, # rcond 1653 rank_ptr.ctypes # rank 1654 ) 1655 _handle_err_maybe_convergence_problem(r) 1656 1657 # set rank to that which was computed 1658 rank = rank_ptr[0] 1659 1660 # compute residuals 1661 if rank < n or m <= n: 1662 res = np.empty((0), dtype=real_dtype) 1663 else: 1664 # this requires additional dispatch as there's a faster 1665 # impl if the result is in the real domain (no abs() required) 1666 res = _lstsq_residual(bcpy, n, nrhs) 1667 1668 # extract 'x', the solution 1669 x = _lstsq_solution(b, bcpy, n) 1670 1671 # help liveness analysis 1672 _dummy_liveness_func([acpy.size, bcpy.size, s.size, rank_ptr.size]) 1673 return (x, res, rank, s[:minmn]) 1674 1675 return lstsq_impl 1676 1677 1678def _solve_compute_return(b, bcpy): 1679 """ 1680 Extract 'x' (the solution) from the 'bcpy' scratch space. 1681 Note 'b' is only used to check the system input dimension... 
1682 """ 1683 raise NotImplementedError 1684 1685 1686@overload(_solve_compute_return) 1687def _solve_compute_return_impl(b, bcpy): 1688 if b.ndim == 1: 1689 def oneD_impl(b, bcpy): 1690 return bcpy.T.ravel() 1691 return oneD_impl 1692 else: 1693 def twoD_impl(b, bcpy): 1694 return bcpy 1695 return twoD_impl 1696 1697 1698@overload(np.linalg.solve) 1699def solve_impl(a, b): 1700 ensure_lapack() 1701 1702 _check_linalg_matrix(a, "solve") 1703 _check_linalg_1_or_2d_matrix(b, "solve") 1704 1705 _check_homogeneous_types("solve", a, b) 1706 1707 np_dt = np_support.as_dtype(a.dtype) 1708 nb_dt = a.dtype 1709 1710 # the lapack solver 1711 numba_xgesv = _LAPACK().numba_xgesv(a.dtype) 1712 1713 kind = ord(get_blas_kind(nb_dt, "solve")) 1714 1715 def solve_impl(a, b): 1716 n = a.shape[-1] 1717 nrhs = _system_compute_nrhs(b) 1718 1719 # check the systems have no inf or NaN 1720 _check_finite_matrix(a) 1721 _check_finite_matrix(b) 1722 1723 # check the systems are dimensionally valid 1724 _system_check_dimensionally_valid(a, b) 1725 1726 # a is destroyed on exit, copy it 1727 acpy = _copy_to_fortran_order(a) 1728 1729 # b is overwritten on exit with the solution, copy allocate 1730 bcpy = np.empty((nrhs, n), dtype=np_dt).T 1731 if n == 0: 1732 return _solve_compute_return(b, bcpy) 1733 1734 # specialised copy in due to b being 1 or 2D 1735 _system_copy_in_b(bcpy, b, nrhs) 1736 1737 # allocate pivot array (needs to be fortran int size) 1738 ipiv = np.empty(n, dtype=F_INT_nptype) 1739 1740 r = numba_xgesv( 1741 kind, # kind 1742 n, # n 1743 nrhs, # nhrs 1744 acpy.ctypes, # a 1745 n, # lda 1746 ipiv.ctypes, # ipiv 1747 bcpy.ctypes, # b 1748 n # ldb 1749 ) 1750 _inv_err_handler(r) 1751 1752 # help liveness analysis 1753 _dummy_liveness_func([acpy.size, bcpy.size, ipiv.size]) 1754 return _solve_compute_return(b, bcpy) 1755 1756 return solve_impl 1757 1758 1759@overload(np.linalg.pinv) 1760def pinv_impl(a, rcond=1.e-15): 1761 ensure_lapack() 1762 1763 _check_linalg_matrix(a, "pinv") 
1764 1765 # convert typing floats to numpy floats for use in the impl 1766 s_type = getattr(a.dtype, "underlying_float", a.dtype) 1767 s_dtype = np_support.as_dtype(s_type) 1768 1769 numba_ez_gesdd = _LAPACK().numba_ez_gesdd(a.dtype) 1770 1771 numba_xxgemm = _BLAS().numba_xxgemm(a.dtype) 1772 1773 kind = ord(get_blas_kind(a.dtype, "pinv")) 1774 JOB = ord('S') 1775 1776 # need conjugate transposes 1777 TRANSA = ord('C') 1778 TRANSB = ord('C') 1779 1780 # scalar constants 1781 dt = np_support.as_dtype(a.dtype) 1782 zero = np.array([0.], dtype=dt) 1783 one = np.array([1.], dtype=dt) 1784 1785 def pinv_impl(a, rcond=1.e-15): 1786 1787 # The idea is to build the pseudo-inverse via inverting the singular 1788 # value decomposition of a matrix `A`. Mathematically, this is roughly 1789 # A = U*S*V^H [The SV decomposition of A] 1790 # A^+ = V*(S^+)*U^H [The inverted SV decomposition of A] 1791 # where ^+ is pseudo inversion and ^H is Hermitian transpose. 1792 # As V and U are unitary, their inverses are simply their Hermitian 1793 # transpose. S has singular values on its diagonal and zero elsewhere, 1794 # it is inverted trivially by reciprocal of the diagonal values with 1795 # the exception that zero singular values remain as zero. 1796 # 1797 # The practical implementation can take advantage of a few things to 1798 # gain a few % performance increase: 1799 # * A is destroyed by the SVD algorithm from LAPACK so a copy is 1800 # required, this memory is exactly the right size in which to return 1801 # the pseudo-inverse and so can be reused for this purpose. 1802 # * The pseudo-inverse of S can be applied to either V or U^H, this 1803 # then leaves a GEMM operation to compute the inverse via either: 1804 # A^+ = (V*(S^+))*U^H 1805 # or 1806 # A^+ = V*((S^+)*U^H) 1807 # however application of S^+ to V^H or U is more convenient as they 1808 # are the result of the SVD algorithm. 
The application of the 1809 # diagonal system is just a matrix multiplication which results in a 1810 # row/column scaling (direction depending). To save effort, this 1811 # "matrix multiplication" is applied to the smallest of U or V^H and 1812 # only up to the point of "cut-off" (see next note) just as a direct 1813 # scaling. 1814 # * The cut-off level for application of S^+ can be used to reduce 1815 # total effort, this cut-off can come via rcond or may just naturally 1816 # be present as a result of zeros in the singular values. Regardless 1817 # there's no need to multiply by zeros in the application of S^+ to 1818 # V^H or U as above. Further, the GEMM operation can be shrunk in 1819 # effort by noting that the possible zero block generated by the 1820 # presence of zeros in S^+ has no effect apart from wasting cycles as 1821 # it is all fmadd()s where one operand is zero. The inner dimension 1822 # of the GEMM operation can therefore be set as shrunk accordingly! 1823 1824 n = a.shape[-1] 1825 m = a.shape[-2] 1826 1827 _check_finite_matrix(a) 1828 1829 acpy = _copy_to_fortran_order(a) 1830 1831 if m == 0 or n == 0: 1832 return acpy.T.ravel().reshape(a.shape).T 1833 1834 minmn = min(m, n) 1835 1836 u = np.empty((minmn, m), dtype=a.dtype) 1837 s = np.empty(minmn, dtype=s_dtype) 1838 vt = np.empty((n, minmn), dtype=a.dtype) 1839 1840 r = numba_ez_gesdd( 1841 kind, # kind 1842 JOB, # job 1843 m, # m 1844 n, # n 1845 acpy.ctypes, # a 1846 m, # lda 1847 s.ctypes, # s 1848 u.ctypes, # u 1849 m, # ldu 1850 vt.ctypes, # vt 1851 minmn # ldvt 1852 ) 1853 _handle_err_maybe_convergence_problem(r) 1854 1855 # Invert singular values under threshold. Also find the index of 1856 # the threshold value as this is the upper limit for the application 1857 # of the inverted singular values. 
Finding this value saves 1858 # multiplication by a block of zeros that would be created by the 1859 # application of these values to either U or V^H ahead of multiplying 1860 # them together. This is done by simply in BLAS parlance via 1861 # restricting the `k` dimension to `cut_idx` in `xgemm` whilst keeping 1862 # the leading dimensions correct. 1863 1864 cut_at = s[0] * rcond 1865 cut_idx = 0 1866 for k in range(minmn): 1867 if s[k] > cut_at: 1868 s[k] = 1. / s[k] 1869 cut_idx = k 1870 cut_idx += 1 1871 1872 # Use cut_idx so there's no scaling by 0. 1873 if m >= n: 1874 # U is largest so apply S^+ to V^H. 1875 for i in range(n): 1876 for j in range(cut_idx): 1877 vt[i, j] = vt[i, j] * s[j] 1878 else: 1879 # V^H is largest so apply S^+ to U. 1880 for i in range(cut_idx): 1881 s_local = s[i] 1882 for j in range(minmn): 1883 u[i, j] = u[i, j] * s_local 1884 1885 # Do (v^H)^H*U^H (obviously one of the matrices includes the S^+ 1886 # scaling) and write back to acpy. Note the innner dimension of cut_idx 1887 # taking account of the possible zero block. 1888 # We can store the result in acpy, given we had to create it 1889 # for use in the SVD, and it is now redundant and the right size 1890 # but wrong shape. 
1891 1892 r = numba_xxgemm( 1893 kind, 1894 TRANSA, # TRANSA 1895 TRANSB, # TRANSB 1896 n, # M 1897 m, # N 1898 cut_idx, # K 1899 one.ctypes, # ALPHA 1900 vt.ctypes, # A 1901 minmn, # LDA 1902 u.ctypes, # B 1903 m, # LDB 1904 zero.ctypes, # BETA 1905 acpy.ctypes, # C 1906 n # LDC 1907 ) 1908 1909 # help liveness analysis 1910 #acpy.size 1911 #vt.size 1912 #u.size 1913 #s.size 1914 #one.size 1915 #zero.size 1916 _dummy_liveness_func([acpy.size, vt.size, u.size, s.size, one.size, 1917 zero.size]) 1918 return acpy.T.ravel().reshape(a.shape).T 1919 1920 return pinv_impl 1921 1922 1923def _get_slogdet_diag_walker(a): 1924 """ 1925 Walks the diag of a LUP decomposed matrix 1926 uses that det(A) = prod(diag(lup(A))) 1927 and also that log(a)+log(b) = log(a*b) 1928 The return sign is adjusted based on the values found 1929 such that the log(value) stays in the real domain. 1930 """ 1931 if isinstance(a.dtype, types.Complex): 1932 @register_jitable 1933 def cmplx_diag_walker(n, a, sgn): 1934 # walk diagonal 1935 csgn = sgn + 0.j 1936 acc = 0. 1937 for k in range(n): 1938 absel = np.abs(a[k, k]) 1939 csgn = csgn * (a[k, k] / absel) 1940 acc = acc + np.log(absel) 1941 return (csgn, acc) 1942 return cmplx_diag_walker 1943 else: 1944 @register_jitable 1945 def real_diag_walker(n, a, sgn): 1946 # walk diagonal 1947 acc = 0. 
1948 for k in range(n): 1949 v = a[k, k] 1950 if v < 0.: 1951 sgn = -sgn 1952 v = -v 1953 acc = acc + np.log(v) 1954 # sgn is a float dtype 1955 return (sgn + 0., acc) 1956 return real_diag_walker 1957 1958 1959@overload(np.linalg.slogdet) 1960def slogdet_impl(a): 1961 ensure_lapack() 1962 1963 _check_linalg_matrix(a, "slogdet") 1964 1965 numba_xxgetrf = _LAPACK().numba_xxgetrf(a.dtype) 1966 1967 kind = ord(get_blas_kind(a.dtype, "slogdet")) 1968 1969 diag_walker = _get_slogdet_diag_walker(a) 1970 1971 ONE = a.dtype(1) 1972 ZERO = getattr(a.dtype, "underlying_float", a.dtype)(0) 1973 1974 def slogdet_impl(a): 1975 n = a.shape[-1] 1976 if a.shape[-2] != n: 1977 msg = "Last 2 dimensions of the array must be square." 1978 raise np.linalg.LinAlgError(msg) 1979 1980 if n == 0: 1981 return (ONE, ZERO) 1982 1983 _check_finite_matrix(a) 1984 1985 acpy = _copy_to_fortran_order(a) 1986 1987 ipiv = np.empty(n, dtype=F_INT_nptype) 1988 1989 r = numba_xxgetrf(kind, n, n, acpy.ctypes, n, ipiv.ctypes) 1990 1991 if r > 0: 1992 # factorisation failed, return same defaults as np 1993 return (0., -np.inf) 1994 _inv_err_handler(r) # catch input-to-lapack problem 1995 1996 # The following, prior to the call to diag_walker, is present 1997 # to account for the effect of possible permutations to the 1998 # sign of the determinant. 1999 # This is the same idea as in numpy: 2000 # File name `umath_linalg.c.src` e.g. 2001 # https://github.com/numpy/numpy/blob/master/numpy/linalg/umath_linalg.c.src 2002 # in function `@TYPE@_slogdet_single_element`. 
2003 sgn = 1 2004 for k in range(n): 2005 sgn = sgn + (ipiv[k] != (k + 1)) 2006 2007 sgn = sgn & 1 2008 if sgn == 0: 2009 sgn = -1 2010 2011 # help liveness analysis 2012 _dummy_liveness_func([ipiv.size]) 2013 return diag_walker(n, acpy, sgn) 2014 2015 return slogdet_impl 2016 2017 2018@overload(np.linalg.det) 2019def det_impl(a): 2020 2021 ensure_lapack() 2022 2023 _check_linalg_matrix(a, "det") 2024 2025 def det_impl(a): 2026 (sgn, slogdet) = np.linalg.slogdet(a) 2027 return sgn * np.exp(slogdet) 2028 2029 return det_impl 2030 2031 2032def _compute_singular_values(a): 2033 """ 2034 Compute singular values of *a*. 2035 """ 2036 raise NotImplementedError 2037 2038 2039@overload(_compute_singular_values) 2040def _compute_singular_values_impl(a): 2041 """ 2042 Returns a function to compute singular values of `a` 2043 """ 2044 numba_ez_gesdd = _LAPACK().numba_ez_gesdd(a.dtype) 2045 2046 kind = ord(get_blas_kind(a.dtype, "svd")) 2047 2048 # Flag for "only compute `S`" to give to xgesdd 2049 JOBZ_N = ord('N') 2050 2051 nb_ret_type = getattr(a.dtype, "underlying_float", a.dtype) 2052 np_ret_type = np_support.as_dtype(nb_ret_type) 2053 np_dtype = np_support.as_dtype(a.dtype) 2054 2055 # These are not referenced in the computation but must be set 2056 # for MKL. 2057 u = np.empty((1, 1), dtype=np_dtype) 2058 vt = np.empty((1, 1), dtype=np_dtype) 2059 2060 def sv_function(a): 2061 """ 2062 Computes singular values. 2063 """ 2064 # Don't use the np.linalg.svd impl instead 2065 # call LAPACK to shortcut doing the "reconstruct 2066 # singular vectors from reflectors" step and just 2067 # get back the singular values. 
2068 n = a.shape[-1] 2069 m = a.shape[-2] 2070 if m == 0 or n == 0: 2071 raise np.linalg.LinAlgError('Arrays cannot be empty') 2072 _check_finite_matrix(a) 2073 2074 ldu = m 2075 minmn = min(m, n) 2076 2077 # need to be >=1 but aren't referenced 2078 ucol = 1 2079 ldvt = 1 2080 2081 acpy = _copy_to_fortran_order(a) 2082 2083 # u and vt are not referenced however need to be 2084 # allocated (as done above) for MKL as it 2085 # checks for ref is nullptr. 2086 s = np.empty(minmn, dtype=np_ret_type) 2087 2088 r = numba_ez_gesdd( 2089 kind, # kind 2090 JOBZ_N, # jobz 2091 m, # m 2092 n, # n 2093 acpy.ctypes, # a 2094 m, # lda 2095 s.ctypes, # s 2096 u.ctypes, # u 2097 ldu, # ldu 2098 vt.ctypes, # vt 2099 ldvt # ldvt 2100 ) 2101 _handle_err_maybe_convergence_problem(r) 2102 2103 # help liveness analysis 2104 _dummy_liveness_func([acpy.size, vt.size, u.size, s.size]) 2105 return s 2106 2107 return sv_function 2108 2109 2110def _oneD_norm_2(a): 2111 """ 2112 Compute the L2-norm of 1D-array *a*. 2113 """ 2114 raise NotImplementedError 2115 2116 2117@overload(_oneD_norm_2) 2118def _oneD_norm_2_impl(a): 2119 nb_ret_type = getattr(a.dtype, "underlying_float", a.dtype) 2120 np_ret_type = np_support.as_dtype(nb_ret_type) 2121 2122 xxnrm2 = _BLAS().numba_xxnrm2(a.dtype) 2123 2124 kind = ord(get_blas_kind(a.dtype, "norm")) 2125 2126 def impl(a): 2127 # Just ignore order, calls are guarded to only come 2128 # from cases where order=None or order=2. 
2129 n = len(a) 2130 # Call L2-norm routine from BLAS 2131 ret = np.empty((1,), dtype=np_ret_type) 2132 jmp = int(a.strides[0] / a.itemsize) 2133 r = xxnrm2( 2134 kind, # kind 2135 n, # n 2136 a.ctypes, # x 2137 jmp, # incx 2138 ret.ctypes # result 2139 ) 2140 if r < 0: 2141 fatal_error_func() 2142 assert 0 # unreachable 2143 2144 # help liveness analysis 2145 #ret.size 2146 #a.size 2147 _dummy_liveness_func([ret.size, a.size]) 2148 return ret[0] 2149 2150 return impl 2151 2152 2153def _get_norm_impl(a, ord_flag): 2154 # This function is quite involved as norm supports a large 2155 # range of values to select different norm types via kwarg `ord`. 2156 # The implementation below branches on dimension of the input 2157 # (1D or 2D). The default for `ord` is `None` which requires 2158 # special handling in numba, this is dealt with first in each of 2159 # the dimension branches. Following this the various norms are 2160 # computed via code that is in most cases simply a loop version 2161 # of a ufunc based version as found in numpy. 2162 2163 # The following is common to both 1D and 2D cases. 2164 # Convert typing floats to numpy floats for use in the impl. 2165 # The return type is always a float, numba differs from numpy in 2166 # that it returns an input precision specific value whereas numpy 2167 # always returns np.float64. 
2168 nb_ret_type = getattr(a.dtype, "underlying_float", a.dtype) 2169 np_ret_type = np_support.as_dtype(nb_ret_type) 2170 2171 np_dtype = np_support.as_dtype(a.dtype) 2172 2173 xxnrm2 = _BLAS().numba_xxnrm2(a.dtype) 2174 2175 kind = ord(get_blas_kind(a.dtype, "norm")) 2176 2177 if a.ndim == 1: 2178 # 1D cases 2179 2180 # handle "ord" being "None", must be done separately 2181 if ord_flag in (None, types.none): 2182 def oneD_impl(a, ord=None): 2183 return _oneD_norm_2(a) 2184 else: 2185 def oneD_impl(a, ord=None): 2186 n = len(a) 2187 2188 # Shortcut to handle zero length arrays 2189 # this differs slightly to numpy in that 2190 # numpy raises a ValueError for kwarg ord= 2191 # +/-np.inf as the reduction operations like 2192 # max() and min() don't accept zero length 2193 # arrays 2194 if n == 0: 2195 return 0.0 2196 2197 # Note: on order == 2 2198 # This is the same as for ord=="None" but because 2199 # we have to handle "None" specially this condition 2200 # is separated 2201 if ord == 2: 2202 return _oneD_norm_2(a) 2203 elif ord == np.inf: 2204 # max(abs(a)) 2205 ret = abs(a[0]) 2206 for k in range(1, n): 2207 val = abs(a[k]) 2208 if val > ret: 2209 ret = val 2210 return ret 2211 2212 elif ord == -np.inf: 2213 # min(abs(a)) 2214 ret = abs(a[0]) 2215 for k in range(1, n): 2216 val = abs(a[k]) 2217 if val < ret: 2218 ret = val 2219 return ret 2220 2221 elif ord == 0: 2222 # sum(a != 0) 2223 ret = 0.0 2224 for k in range(n): 2225 if a[k] != 0.: 2226 ret += 1. 2227 return ret 2228 2229 elif ord == 1: 2230 # sum(abs(a)) 2231 ret = 0.0 2232 for k in range(n): 2233 ret += abs(a[k]) 2234 return ret 2235 2236 else: 2237 # sum(abs(a)**ord)**(1./ord) 2238 ret = 0.0 2239 for k in range(n): 2240 ret += abs(a[k])**ord 2241 return ret**(1. / ord) 2242 return oneD_impl 2243 2244 elif a.ndim == 2: 2245 # 2D cases 2246 2247 # handle "ord" being "None" 2248 if ord_flag in (None, types.none): 2249 # Force `a` to be C-order, so that we can take a contiguous 2250 # 1D view. 
2251 if a.layout == 'C': 2252 @register_jitable 2253 def array_prepare(a): 2254 return a 2255 elif a.layout == 'F': 2256 @register_jitable 2257 def array_prepare(a): 2258 # Legal since L2(a) == L2(a.T) 2259 return a.T 2260 else: 2261 @register_jitable 2262 def array_prepare(a): 2263 return a.copy() 2264 2265 # Compute the Frobenius norm, this is the L2,2 induced norm of `A` 2266 # which is the L2-norm of A.ravel() and so can be computed via BLAS 2267 def twoD_impl(a, ord=None): 2268 n = a.size 2269 if n == 0: 2270 # reshape() currently doesn't support zero-sized arrays 2271 return 0.0 2272 a_c = array_prepare(a) 2273 return _oneD_norm_2(a_c.reshape(n)) 2274 else: 2275 # max value for this dtype 2276 max_val = np.finfo(np_ret_type.type).max 2277 2278 def twoD_impl(a, ord=None): 2279 n = a.shape[-1] 2280 m = a.shape[-2] 2281 2282 # Shortcut to handle zero size arrays 2283 # this differs slightly to numpy in that 2284 # numpy raises errors for some ord values 2285 # and in other cases returns zero. 2286 if a.size == 0: 2287 return 0.0 2288 2289 if ord == np.inf: 2290 # max of sum of abs across rows 2291 # max(sum(abs(a)), axis=1) 2292 global_max = 0. 2293 for ii in range(m): 2294 tmp = 0. 2295 for jj in range(n): 2296 tmp += abs(a[ii, jj]) 2297 if tmp > global_max: 2298 global_max = tmp 2299 return global_max 2300 2301 elif ord == -np.inf: 2302 # min of sum of abs across rows 2303 # min(sum(abs(a)), axis=1) 2304 global_min = max_val 2305 for ii in range(m): 2306 tmp = 0. 2307 for jj in range(n): 2308 tmp += abs(a[ii, jj]) 2309 if tmp < global_min: 2310 global_min = tmp 2311 return global_min 2312 elif ord == 1: 2313 # max of sum of abs across cols 2314 # max(sum(abs(a)), axis=0) 2315 global_max = 0. 2316 for ii in range(n): 2317 tmp = 0. 
2318 for jj in range(m): 2319 tmp += abs(a[jj, ii]) 2320 if tmp > global_max: 2321 global_max = tmp 2322 return global_max 2323 2324 elif ord == -1: 2325 # min of sum of abs across cols 2326 # min(sum(abs(a)), axis=0) 2327 global_min = max_val 2328 for ii in range(n): 2329 tmp = 0. 2330 for jj in range(m): 2331 tmp += abs(a[jj, ii]) 2332 if tmp < global_min: 2333 global_min = tmp 2334 return global_min 2335 2336 # Results via SVD, singular values are sorted on return 2337 # by definition. 2338 elif ord == 2: 2339 # max SV 2340 return _compute_singular_values(a)[0] 2341 elif ord == -2: 2342 # min SV 2343 return _compute_singular_values(a)[-1] 2344 else: 2345 # replicate numpy error 2346 raise ValueError("Invalid norm order for matrices.") 2347 return twoD_impl 2348 else: 2349 assert 0 # unreachable 2350 2351 2352@overload(np.linalg.norm) 2353def norm_impl(a, ord=None): 2354 ensure_lapack() 2355 2356 _check_linalg_1_or_2d_matrix(a, "norm") 2357 2358 return _get_norm_impl(a, ord) 2359 2360 2361@overload(np.linalg.cond) 2362def cond_impl(a, p=None): 2363 ensure_lapack() 2364 2365 _check_linalg_matrix(a, "cond") 2366 2367 def impl(a, p=None): 2368 # This is extracted for performance, numpy does approximately: 2369 # `condition = norm(a) * norm(inv(a))` 2370 # in the cases of `p == 2` or `p ==-2` singular values are used 2371 # for computing norms. This costs numpy an svd of `a` then an 2372 # inversion of `a` and another svd of `a`. 2373 # Below is a different approach, which also gives a more 2374 # accurate answer as there is no inversion involved. 2375 # Recall that the singular values of an inverted matrix are the 2376 # reciprocal of singular values of the original matrix. 2377 # Therefore calling `svd(a)` once yields all the information 2378 # needed about both `a` and `inv(a)` without the cost or 2379 # potential loss of accuracy incurred through inversion. 
2380 # For the case of `p == 2`, the result is just the ratio of 2381 # `largest singular value/smallest singular value`, and for the 2382 # case of `p==-2` the result is simply the 2383 # `smallest singular value/largest singular value`. 2384 # As a result of this, numba accepts non-square matrices as 2385 # input when p==+/-2 as well as when p==None. 2386 if p == 2 or p == -2 or p is None: 2387 s = _compute_singular_values(a) 2388 if p == 2 or p is None: 2389 r = np.divide(s[0], s[-1]) 2390 else: 2391 r = np.divide(s[-1], s[0]) 2392 else: # cases np.inf, -np.inf, 1, -1 2393 norm_a = np.linalg.norm(a, p) 2394 norm_inv_a = np.linalg.norm(np.linalg.inv(a), p) 2395 r = norm_a * norm_inv_a 2396 # NumPy uses a NaN mask, if the input has a NaN, it will return NaN, 2397 # Numba calls ban NaN through the use of _check_finite_matrix but this 2398 # catches cases where NaN occurs through floating point use 2399 if np.isnan(r): 2400 return np.inf 2401 else: 2402 return r 2403 return impl 2404 2405 2406@register_jitable 2407def _get_rank_from_singular_values(sv, t): 2408 """ 2409 Gets rank from singular values with cut-off at a given tolerance 2410 """ 2411 rank = 0 2412 for k in range(len(sv)): 2413 if sv[k] > t: 2414 rank = rank + 1 2415 else: # sv is ordered big->small so break on condition not met 2416 break 2417 return rank 2418 2419 2420@overload(np.linalg.matrix_rank) 2421def matrix_rank_impl(a, tol=None): 2422 """ 2423 Computes rank for matrices and vectors. 2424 The only issue that may arise is that because numpy uses double 2425 precision lapack calls whereas numba uses type specific lapack 2426 calls, some singular values may differ and therefore counting the 2427 number of them above a tolerance may lead to different counts, 2428 and therefore rank, in some cases. 
2429 """ 2430 ensure_lapack() 2431 2432 _check_linalg_1_or_2d_matrix(a, "matrix_rank") 2433 2434 def _2d_matrix_rank_impl(a, tol): 2435 2436 # handle the tol==None case separately for type inference to work 2437 if tol in (None, types.none): 2438 nb_type = getattr(a.dtype, "underlying_float", a.dtype) 2439 np_type = np_support.as_dtype(nb_type) 2440 eps_val = np.finfo(np_type).eps 2441 2442 def _2d_tol_none_impl(a, tol=None): 2443 s = _compute_singular_values(a) 2444 # replicate numpy default tolerance calculation 2445 r = a.shape[0] 2446 c = a.shape[1] 2447 l = max(r, c) 2448 t = s[0] * l * eps_val 2449 return _get_rank_from_singular_values(s, t) 2450 return _2d_tol_none_impl 2451 else: 2452 def _2d_tol_not_none_impl(a, tol=None): 2453 s = _compute_singular_values(a) 2454 return _get_rank_from_singular_values(s, tol) 2455 return _2d_tol_not_none_impl 2456 2457 def _get_matrix_rank_impl(a, tol): 2458 ndim = a.ndim 2459 if ndim == 1: 2460 # NOTE: Technically, the numpy implementation could be argued as 2461 # incorrect for the case of a vector (1D matrix). If a tolerance 2462 # is provided and a vector with a singular value below tolerance is 2463 # encountered this should report a rank of zero, the numpy 2464 # implementation does not do this and instead elects to report that 2465 # if any value in the vector is nonzero then the rank is 1. 2466 # An example would be [0, 1e-15, 0, 2e-15] which numpy reports as 2467 # rank 1 invariant of `tol`. The singular value for this vector is 2468 # obviously sqrt(5)*1e-15 and so a tol of e.g. sqrt(6)*1e-15 should 2469 # lead to a reported rank of 0 whereas a tol of 1e-15 should lead 2470 # to a reported rank of 1, numpy reports 1 regardless. 2471 # The code below replicates the numpy behaviour. 
2472 def _1d_matrix_rank_impl(a, tol=None): 2473 for k in range(len(a)): 2474 if a[k] != 0.: 2475 return 1 2476 return 0 2477 return _1d_matrix_rank_impl 2478 elif ndim == 2: 2479 return _2d_matrix_rank_impl(a, tol) 2480 else: 2481 assert 0 # unreachable 2482 2483 return _get_matrix_rank_impl(a, tol) 2484 2485 2486@overload(np.linalg.matrix_power) 2487def matrix_power_impl(a, n): 2488 """ 2489 Computes matrix power. Only integer powers are supported in numpy. 2490 """ 2491 2492 _check_linalg_matrix(a, "matrix_power") 2493 np_dtype = np_support.as_dtype(a.dtype) 2494 2495 nt = getattr(n, 'dtype', n) 2496 if not isinstance(nt, types.Integer): 2497 raise TypeError("Exponent must be an integer.") 2498 2499 def matrix_power_impl(a, n): 2500 2501 if n == 0: 2502 # this should be eye() but it doesn't support 2503 # the dtype kwarg yet so do it manually to save 2504 # the copy required by eye(a.shape[0]).asdtype() 2505 A = np.zeros(a.shape, dtype=np_dtype) 2506 for k in range(a.shape[0]): 2507 A[k, k] = 1. 2508 return A 2509 2510 am, an = a.shape[-1], a.shape[-2] 2511 if am != an: 2512 raise ValueError('input must be a square array') 2513 2514 # empty, return a copy 2515 if am == 0: 2516 return a.copy() 2517 2518 # note: to be consistent over contiguousness, C order is 2519 # returned as that is what dot() produces and the most common 2520 # paths through matrix_power will involve that. Therefore 2521 # copies are made here to ensure the data ordering is 2522 # correct for paths not going via dot(). 
2523 2524 if n < 0: 2525 A = np.linalg.inv(a).copy() 2526 if n == -1: # return now 2527 return A 2528 n = -n 2529 else: 2530 if n == 1: # return a copy now 2531 return a.copy() 2532 A = a # this is safe, `a` is only read 2533 2534 if n < 4: 2535 if n == 2: 2536 return np.dot(A, A) 2537 if n == 3: 2538 return np.dot(np.dot(A, A), A) 2539 else: 2540 2541 acc = A 2542 exp = n 2543 2544 # Initialise ret, SSA cannot see the loop will execute, without this 2545 # it appears as uninitialised. 2546 ret = acc 2547 # tried a loop split and branchless using identity matrix as 2548 # input but it seems like having a "first entry" flag is quicker 2549 flag = True 2550 while exp != 0: 2551 if exp & 1: 2552 if flag: 2553 ret = acc 2554 flag = False 2555 else: 2556 ret = np.dot(ret, acc) 2557 acc = np.dot(acc, acc) 2558 exp = exp >> 1 2559 2560 return ret 2561 2562 return matrix_power_impl 2563 2564# This is documented under linalg despite not being in the module 2565 2566 2567@overload(np.trace) 2568def matrix_trace_impl(a, offset=0): 2569 """ 2570 Computes the trace of an array. 
2571 """ 2572 2573 _check_linalg_matrix(a, "trace", la_prefix=False) 2574 2575 if not isinstance(offset, (int, types.Integer)): 2576 raise TypeError("integer argument expected, got %s" % offset) 2577 2578 def matrix_trace_impl(a, offset=0): 2579 rows, cols = a.shape 2580 k = offset 2581 if k < 0: 2582 rows = rows + k 2583 if k > 0: 2584 cols = cols - k 2585 n = max(min(rows, cols), 0) 2586 ret = 0 2587 if k >= 0: 2588 for i in range(n): 2589 ret += a[i, k + i] 2590 else: 2591 for i in range(n): 2592 ret += a[i - k, i] 2593 return ret 2594 2595 return matrix_trace_impl 2596 2597 2598def _check_scalar_or_lt_2d_mat(a, func_name, la_prefix=True): 2599 prefix = "np.linalg" if la_prefix else "np" 2600 interp = (prefix, func_name) 2601 # checks that a matrix is 1 or 2D 2602 if isinstance(a, types.Array): 2603 if not a.ndim <= 2: 2604 raise TypingError("%s.%s() only supported on 1 and 2-D arrays " 2605 % interp, highlighting=False) 2606 2607 2608def _get_as_array(x): 2609 if not isinstance(x, types.Array): 2610 @register_jitable 2611 def asarray(x): 2612 return np.array((x,)) 2613 return asarray 2614 else: 2615 @register_jitable 2616 def asarray(x): 2617 return x 2618 return asarray 2619 2620 2621def _get_outer_impl(a, b, out): 2622 a_arr = _get_as_array(a) 2623 b_arr = _get_as_array(b) 2624 2625 if out in (None, types.none): 2626 @register_jitable 2627 def outer_impl(a, b, out): 2628 aa = a_arr(a) 2629 bb = b_arr(b) 2630 return np.multiply(aa.ravel().reshape((aa.size, 1)), 2631 bb.ravel().reshape((1, bb.size))) 2632 return outer_impl 2633 else: 2634 @register_jitable 2635 def outer_impl(a, b, out): 2636 aa = a_arr(a) 2637 bb = b_arr(b) 2638 np.multiply(aa.ravel().reshape((aa.size, 1)), 2639 bb.ravel().reshape((1, bb.size)), 2640 out) 2641 return out 2642 return outer_impl 2643 2644 2645@overload(np.outer) 2646def outer_impl(a, b, out=None): 2647 2648 _check_scalar_or_lt_2d_mat(a, "outer", la_prefix=False) 2649 _check_scalar_or_lt_2d_mat(b, "outer", la_prefix=False) 2650 
2651 impl = _get_outer_impl(a, b, out) 2652 2653 def outer_impl(a, b, out=None): 2654 return impl(a, b, out) 2655 2656 return outer_impl 2657 2658 2659def _kron_normaliser_impl(x): 2660 # makes x into a 2d array 2661 if isinstance(x, types.Array): 2662 if x.layout not in ('C', 'F'): 2663 raise TypingError("np.linalg.kron only supports 'C' or 'F' layout " 2664 "input arrays. Receieved an input of " 2665 "layout '{}'.".format(x.layout)) 2666 elif x.ndim == 2: 2667 @register_jitable 2668 def nrm_shape(x): 2669 xn = x.shape[-1] 2670 xm = x.shape[-2] 2671 return x.reshape(xm, xn) 2672 return nrm_shape 2673 else: 2674 @register_jitable 2675 def nrm_shape(x): 2676 xn = x.shape[-1] 2677 return x.reshape(1, xn) 2678 return nrm_shape 2679 else: # assume its a scalar 2680 @register_jitable 2681 def nrm_shape(x): 2682 a = np.empty((1, 1), type(x)) 2683 a[0] = x 2684 return a 2685 return nrm_shape 2686 2687 2688def _kron_return(a, b): 2689 # transforms c into something that kron would return 2690 # based on the shapes of a and b 2691 a_is_arr = isinstance(a, types.Array) 2692 b_is_arr = isinstance(b, types.Array) 2693 if a_is_arr and b_is_arr: 2694 if a.ndim == 2 or b.ndim == 2: 2695 @register_jitable 2696 def ret(a, b, c): 2697 return c 2698 return ret 2699 else: 2700 @register_jitable 2701 def ret(a, b, c): 2702 return c.reshape(c.size) 2703 return ret 2704 else: # at least one of (a, b) is a scalar 2705 if a_is_arr: 2706 @register_jitable 2707 def ret(a, b, c): 2708 return c.reshape(a.shape) 2709 return ret 2710 elif b_is_arr: 2711 @register_jitable 2712 def ret(a, b, c): 2713 return c.reshape(b.shape) 2714 return ret 2715 else: # both scalars 2716 @register_jitable 2717 def ret(a, b, c): 2718 return c[0] 2719 return ret 2720 2721 2722@overload(np.kron) 2723def kron_impl(a, b): 2724 2725 _check_scalar_or_lt_2d_mat(a, "kron", la_prefix=False) 2726 _check_scalar_or_lt_2d_mat(b, "kron", la_prefix=False) 2727 2728 fix_a = _kron_normaliser_impl(a) 2729 fix_b = 
_kron_normaliser_impl(b) 2730 ret_c = _kron_return(a, b) 2731 2732 # this is fine because the ufunc for the Hadamard product 2733 # will reject differing dtypes in a and b. 2734 dt = getattr(a, 'dtype', a) 2735 2736 def kron_impl(a, b): 2737 2738 aa = fix_a(a) 2739 bb = fix_b(b) 2740 2741 am = aa.shape[-2] 2742 an = aa.shape[-1] 2743 bm = bb.shape[-2] 2744 bn = bb.shape[-1] 2745 2746 cm = am * bm 2747 cn = an * bn 2748 2749 # allocate c 2750 C = np.empty((cm, cn), dtype=dt) 2751 2752 # In practice this is runs quicker than the more obvious 2753 # `each element of A multiplied by B and assigned to 2754 # a block in C` like alg. 2755 2756 # loop over rows of A 2757 for i in range(am): 2758 # compute the column offset into C 2759 rjmp = i * bm 2760 # loop over rows of B 2761 for k in range(bm): 2762 # compute row the offset into C 2763 irjmp = rjmp + k 2764 # slice a given row of B 2765 slc = bb[k, :] 2766 # loop over columns of A 2767 for j in range(an): 2768 # vectorized assignment of an element of A 2769 # multiplied by the current row of B into 2770 # a slice of a row of C 2771 cjmp = j * bn 2772 C[irjmp, cjmp:cjmp + bn] = aa[i, j] * slc 2773 2774 return ret_c(a, b, C) 2775 2776 return kron_impl 2777