1import itertools 2import warnings 3 4import numpy as np 5from numpy import (arange, array, dot, zeros, identity, conjugate, transpose, 6 float32) 7import numpy.linalg as linalg 8from numpy.random import random 9 10from numpy.testing import (assert_equal, assert_almost_equal, assert_, 11 assert_array_almost_equal, assert_allclose, 12 assert_array_equal, suppress_warnings) 13import pytest 14from pytest import raises as assert_raises 15 16from scipy.linalg import (solve, inv, det, lstsq, pinv, pinv2, pinvh, norm, 17 solve_banded, solveh_banded, solve_triangular, 18 solve_circulant, circulant, LinAlgError, block_diag, 19 matrix_balance, qr, LinAlgWarning) 20 21from scipy.linalg._testutils import assert_no_overwrite 22from scipy._lib._testutils import check_free_memory 23from scipy.linalg.blas import HAS_ILP64 24 25REAL_DTYPES = (np.float32, np.float64, np.longdouble) 26COMPLEX_DTYPES = (np.complex64, np.complex128, np.clongdouble) 27DTYPES = REAL_DTYPES + COMPLEX_DTYPES 28 29 30def _eps_cast(dtyp): 31 """Get the epsilon for dtype, possibly downcast to BLAS types.""" 32 dt = dtyp 33 if dt == np.longdouble: 34 dt = np.float64 35 elif dt == np.clongdouble: 36 dt = np.complex128 37 return np.finfo(dt).eps 38 39 40class TestSolveBanded: 41 42 def test_real(self): 43 a = array([[1.0, 20, 0, 0], 44 [-30, 4, 6, 0], 45 [2, 1, 20, 2], 46 [0, -1, 7, 14]]) 47 ab = array([[0.0, 20, 6, 2], 48 [1, 4, 20, 14], 49 [-30, 1, 7, 0], 50 [2, -1, 0, 0]]) 51 l, u = 2, 1 52 b4 = array([10.0, 0.0, 2.0, 14.0]) 53 b4by1 = b4.reshape(-1, 1) 54 b4by2 = array([[2, 1], 55 [-30, 4], 56 [2, 3], 57 [1, 3]]) 58 b4by4 = array([[1, 0, 0, 0], 59 [0, 0, 0, 1], 60 [0, 1, 0, 0], 61 [0, 1, 0, 0]]) 62 for b in [b4, b4by1, b4by2, b4by4]: 63 x = solve_banded((l, u), ab, b) 64 assert_array_almost_equal(dot(a, x), b) 65 66 def test_complex(self): 67 a = array([[1.0, 20, 0, 0], 68 [-30, 4, 6, 0], 69 [2j, 1, 20, 2j], 70 [0, -1, 7, 14]]) 71 ab = array([[0.0, 20, 6, 2j], 72 [1, 4, 20, 14], 73 [-30, 1, 7, 0], 74 [2j, -1, 0, 0]]) 75 l, u = 2, 1 76 b4 = array([10.0, 0.0, 2.0, 14.0j]) 77 b4by1 = b4.reshape(-1, 1) 78 b4by2 = array([[2, 1], 79 [-30, 4], 80 [2, 3], 81 [1, 3]]) 82 b4by4 = array([[1, 0, 0, 0], 83 [0, 0, 0, 1j], 84 [0, 1, 0, 0], 85 [0, 1, 0, 0]]) 86 for b in [b4, b4by1, b4by2, b4by4]: 87 x = solve_banded((l, u), ab, b) 88 assert_array_almost_equal(dot(a, x), b) 89 90 def test_tridiag_real(self): 91 ab = array([[0.0, 20, 6, 2], 92 [1, 4, 20, 14], 93 [-30, 1, 7, 0]]) 94 a = np.diag(ab[0, 1:], 1) + np.diag(ab[1, :], 0) + np.diag( 95 ab[2, :-1], -1) 96 b4 = array([10.0, 0.0, 2.0, 14.0]) 97 b4by1 = b4.reshape(-1, 1) 98 b4by2 = array([[2, 1], 99 [-30, 4], 100 [2, 3], 101 [1, 3]]) 102 b4by4 = array([[1, 0, 0, 0], 103 [0, 0, 0, 1], 104 [0, 1, 0, 0], 105 [0, 1, 0, 0]]) 106 for b in [b4, b4by1, b4by2, b4by4]: 107 x = solve_banded((1, 1), ab, b) 108 assert_array_almost_equal(dot(a, x), b) 109 110 def test_tridiag_complex(self): 111 ab = array([[0.0, 20, 6, 2j], 112 [1, 4, 20, 14], 113 [-30, 1, 7, 0]]) 114 a = np.diag(ab[0, 1:], 1) + np.diag(ab[1, :], 0) + np.diag( 115 ab[2, :-1], -1) 116 b4 = array([10.0, 0.0, 2.0, 14.0j]) 117 b4by1 = b4.reshape(-1, 1) 118 b4by2 = array([[2, 1], 119 [-30, 4], 120 [2, 3], 121 [1, 3]]) 122 b4by4 = array([[1, 0, 0, 0], 123 [0, 0, 0, 1], 124 [0, 1, 0, 0], 125 [0, 1, 0, 0]]) 126 for b in [b4, b4by1, b4by2, b4by4]: 127 x = solve_banded((1, 1), ab, b) 128 assert_array_almost_equal(dot(a, x), b) 129 130 def test_check_finite(self): 131 a = array([[1.0, 20, 0, 0], 132 [-30, 4, 6, 0], 133 [2, 1, 20, 2], 134 [0, -1, 7, 14]]) 135 ab = array([[0.0, 20, 6, 2], 136 [1, 4, 20, 14], 137 [-30, 1, 7, 0], 138 [2, -1, 0, 0]]) 139 l, u = 2, 1 140 b4 = array([10.0, 0.0, 2.0, 14.0]) 141 x = solve_banded((l, u), ab, b4, check_finite=False) 142 assert_array_almost_equal(dot(a, x), b4) 143 144 def test_bad_shape(self): 145 ab = array([[0.0, 20, 6, 2], 146 [1, 4, 20, 14], 147 [-30, 1, 7, 0], 148 [2, -1, 0, 0]]) 149 l, u = 2, 1 150 bad = array([1.0, 2.0, 3.0, 4.0]).reshape(-1, 4) 151 assert_raises(ValueError, solve_banded, (l, u), ab, bad) 152 assert_raises(ValueError, solve_banded, (l, u), ab, [1.0, 2.0]) 153 154 # Values of (l,u) are not compatible with ab. 155 assert_raises(ValueError, solve_banded, (1, 1), ab, [1.0, 2.0]) 156 157 def test_1x1(self): 158 b = array([[1., 2., 3.]]) 159 x = solve_banded((1, 1), [[0], [2], [0]], b) 160 assert_array_equal(x, [[0.5, 1.0, 1.5]]) 161 assert_equal(x.dtype, np.dtype('f8')) 162 assert_array_equal(b, [[1.0, 2.0, 3.0]]) 163 164 def test_native_list_arguments(self): 165 a = [[1.0, 20, 0, 0], 166 [-30, 4, 6, 0], 167 [2, 1, 20, 2], 168 [0, -1, 7, 14]] 169 ab = [[0.0, 20, 6, 2], 170 [1, 4, 20, 14], 171 [-30, 1, 7, 0], 172 [2, -1, 0, 0]] 173 l, u = 2, 1 174 b = [10.0, 0.0, 2.0, 14.0] 175 x = solve_banded((l, u), ab, b) 176 assert_array_almost_equal(dot(a, x), b) 177 178 179class TestSolveHBanded: 180 181 def test_01_upper(self): 182 # Solve 183 # [ 4 1 2 0] [1] 184 # [ 1 4 1 2] X = [4] 185 # [ 2 1 4 1] [1] 186 # [ 0 2 1 4] [2] 187 # with the RHS as a 1D array. 188 ab = array([[0.0, 0.0, 2.0, 2.0], 189 [-99, 1.0, 1.0, 1.0], 190 [4.0, 4.0, 4.0, 4.0]]) 191 b = array([1.0, 4.0, 1.0, 2.0]) 192 x = solveh_banded(ab, b) 193 assert_array_almost_equal(x, [0.0, 1.0, 0.0, 0.0]) 194 195 def test_02_upper(self): 196 # Solve 197 # [ 4 1 2 0] [1 6] 198 # [ 1 4 1 2] X = [4 2] 199 # [ 2 1 4 1] [1 6] 200 # [ 0 2 1 4] [2 1] 201 # 202 ab = array([[0.0, 0.0, 2.0, 2.0], 203 [-99, 1.0, 1.0, 1.0], 204 [4.0, 4.0, 4.0, 4.0]]) 205 b = array([[1.0, 6.0], 206 [4.0, 2.0], 207 [1.0, 6.0], 208 [2.0, 1.0]]) 209 x = solveh_banded(ab, b) 210 expected = array([[0.0, 1.0], 211 [1.0, 0.0], 212 [0.0, 1.0], 213 [0.0, 0.0]]) 214 assert_array_almost_equal(x, expected) 215 216 def test_03_upper(self): 217 # Solve 218 # [ 4 1 2 0] [1] 219 # [ 1 4 1 2] X = [4] 220 # [ 2 1 4 1] [1] 221 # [ 0 2 1 4] [2] 222 # with the RHS as a 2D array with shape (3,1). 223 ab = array([[0.0, 0.0, 2.0, 2.0], 224 [-99, 1.0, 1.0, 1.0], 225 [4.0, 4.0, 4.0, 4.0]]) 226 b = array([1.0, 4.0, 1.0, 2.0]).reshape(-1, 1) 227 x = solveh_banded(ab, b) 228 assert_array_almost_equal(x, array([0., 1., 0., 0.]).reshape(-1, 1)) 229 230 def test_01_lower(self): 231 # Solve 232 # [ 4 1 2 0] [1] 233 # [ 1 4 1 2] X = [4] 234 # [ 2 1 4 1] [1] 235 # [ 0 2 1 4] [2] 236 # 237 ab = array([[4.0, 4.0, 4.0, 4.0], 238 [1.0, 1.0, 1.0, -99], 239 [2.0, 2.0, 0.0, 0.0]]) 240 b = array([1.0, 4.0, 1.0, 2.0]) 241 x = solveh_banded(ab, b, lower=True) 242 assert_array_almost_equal(x, [0.0, 1.0, 0.0, 0.0]) 243 244 def test_02_lower(self): 245 # Solve 246 # [ 4 1 2 0] [1 6] 247 # [ 1 4 1 2] X = [4 2] 248 # [ 2 1 4 1] [1 6] 249 # [ 0 2 1 4] [2 1] 250 # 251 ab = array([[4.0, 4.0, 4.0, 4.0], 252 [1.0, 1.0, 1.0, -99], 253 [2.0, 2.0, 0.0, 0.0]]) 254 b = array([[1.0, 6.0], 255 [4.0, 2.0], 256 [1.0, 6.0], 257 [2.0, 1.0]]) 258 x = solveh_banded(ab, b, lower=True) 259 expected = array([[0.0, 1.0], 260 [1.0, 0.0], 261 [0.0, 1.0], 262 [0.0, 0.0]]) 263 assert_array_almost_equal(x, expected) 264 265 def test_01_float32(self): 266 # Solve 267 # [ 4 1 2 0] [1] 268 # [ 1 4 1 2] X = [4] 269 # [ 2 1 4 1] [1] 270 # [ 0 2 1 4] [2] 271 # 272 ab = array([[0.0, 0.0, 2.0, 2.0], 273 [-99, 1.0, 1.0, 1.0], 274 [4.0, 4.0, 4.0, 4.0]], dtype=float32) 275 b = array([1.0, 4.0, 1.0, 2.0], dtype=float32) 276 x = solveh_banded(ab, b) 277 assert_array_almost_equal(x, [0.0, 1.0, 0.0, 0.0]) 278 279 def test_02_float32(self): 280 # Solve 281 # [ 4 1 2 0] [1 6] 282 # [ 1 4 1 2] X = [4 2] 283 # [ 2 1 4 1] [1 6] 284 # [ 0 2 1 4] [2 1] 285 # 286 ab = array([[0.0, 0.0, 2.0, 2.0], 287 [-99, 1.0, 1.0, 1.0], 288 [4.0, 4.0, 4.0, 4.0]], dtype=float32) 289 b = array([[1.0, 6.0], 290 [4.0, 2.0], 291 [1.0, 6.0], 292 [2.0, 1.0]], dtype=float32) 293 x = solveh_banded(ab, b) 294 expected = array([[0.0, 1.0], 295 [1.0, 0.0], 296 [0.0, 1.0], 297 [0.0, 0.0]]) 298 assert_array_almost_equal(x, expected) 299 300 def test_01_complex(self): 301 # Solve 302 # [ 4 -j 2 0] [2-j] 303 # [ j 4 -j 2] X = [4-j] 304 # [ 2 j 4 -j] [4+j] 305 # [ 0 2 j 4] [2+j] 306 # 307 ab = array([[0.0, 0.0, 2.0, 2.0], 308 [-99, -1.0j, -1.0j, -1.0j], 309 [4.0, 4.0, 4.0, 4.0]]) 310 b = array([2-1.0j, 4.0-1j, 4+1j, 2+1j]) 311 x = solveh_banded(ab, b) 312 assert_array_almost_equal(x, [0.0, 1.0, 1.0, 0.0]) 313 314 def test_02_complex(self): 315 # Solve 316 # [ 4 -j 2 0] [2-j 2+4j] 317 # [ j 4 -j 2] X = [4-j -1-j] 318 # [ 2 j 4 -j] [4+j 4+2j] 319 # [ 0 2 j 4] [2+j j] 320 # 321 ab = array([[0.0, 0.0, 2.0, 2.0], 322 [-99, -1.0j, -1.0j, -1.0j], 323 [4.0, 4.0, 4.0, 4.0]]) 324 b = array([[2-1j, 2+4j], 325 [4.0-1j, -1-1j], 326 [4.0+1j, 4+2j], 327 [2+1j, 1j]]) 328 x = solveh_banded(ab, b) 329 expected = array([[0.0, 1.0j], 330 [1.0, 0.0], 331 [1.0, 1.0], 332 [0.0, 0.0]]) 333 assert_array_almost_equal(x, expected) 334 335 def test_tridiag_01_upper(self): 336 # Solve 337 # [ 4 1 0] [1] 338 # [ 1 4 1] X = [4] 339 # [ 0 1 4] [1] 340 # with the RHS as a 1D array. 341 ab = array([[-99, 1.0, 1.0], [4.0, 4.0, 4.0]]) 342 b = array([1.0, 4.0, 1.0]) 343 x = solveh_banded(ab, b) 344 assert_array_almost_equal(x, [0.0, 1.0, 0.0]) 345 346 def test_tridiag_02_upper(self): 347 # Solve 348 # [ 4 1 0] [1 4] 349 # [ 1 4 1] X = [4 2] 350 # [ 0 1 4] [1 4] 351 # 352 ab = array([[-99, 1.0, 1.0], 353 [4.0, 4.0, 4.0]]) 354 b = array([[1.0, 4.0], 355 [4.0, 2.0], 356 [1.0, 4.0]]) 357 x = solveh_banded(ab, b) 358 expected = array([[0.0, 1.0], 359 [1.0, 0.0], 360 [0.0, 1.0]]) 361 assert_array_almost_equal(x, expected) 362 363 def test_tridiag_03_upper(self): 364 # Solve 365 # [ 4 1 0] [1] 366 # [ 1 4 1] X = [4] 367 # [ 0 1 4] [1] 368 # with the RHS as a 2D array with shape (3,1). 369 ab = array([[-99, 1.0, 1.0], [4.0, 4.0, 4.0]]) 370 b = array([1.0, 4.0, 1.0]).reshape(-1, 1) 371 x = solveh_banded(ab, b) 372 assert_array_almost_equal(x, array([0.0, 1.0, 0.0]).reshape(-1, 1)) 373 374 def test_tridiag_01_lower(self): 375 # Solve 376 # [ 4 1 0] [1] 377 # [ 1 4 1] X = [4] 378 # [ 0 1 4] [1] 379 # 380 ab = array([[4.0, 4.0, 4.0], 381 [1.0, 1.0, -99]]) 382 b = array([1.0, 4.0, 1.0]) 383 x = solveh_banded(ab, b, lower=True) 384 assert_array_almost_equal(x, [0.0, 1.0, 0.0]) 385 386 def test_tridiag_02_lower(self): 387 # Solve 388 # [ 4 1 0] [1 4] 389 # [ 1 4 1] X = [4 2] 390 # [ 0 1 4] [1 4] 391 # 392 ab = array([[4.0, 4.0, 4.0], 393 [1.0, 1.0, -99]]) 394 b = array([[1.0, 4.0], 395 [4.0, 2.0], 396 [1.0, 4.0]]) 397 x = solveh_banded(ab, b, lower=True) 398 expected = array([[0.0, 1.0], 399 [1.0, 0.0], 400 [0.0, 1.0]]) 401 assert_array_almost_equal(x, expected) 402 403 def test_tridiag_01_float32(self): 404 # Solve 405 # [ 4 1 0] [1] 406 # [ 1 4 1] X = [4] 407 # [ 0 1 4] [1] 408 # 409 ab = array([[-99, 1.0, 1.0], [4.0, 4.0, 4.0]], dtype=float32) 410 b = array([1.0, 4.0, 1.0], dtype=float32) 411 x = solveh_banded(ab, b) 412 assert_array_almost_equal(x, [0.0, 1.0, 0.0]) 413 414 def test_tridiag_02_float32(self): 415 # Solve 416 # [ 4 1 0] [1 4] 417 # [ 1 4 1] X = [4 2] 418 # [ 0 1 4] [1 4] 419 # 420 ab = array([[-99, 1.0, 1.0], 421 [4.0, 4.0, 4.0]], dtype=float32) 422 b = array([[1.0, 4.0], 423 [4.0, 2.0], 424 [1.0, 4.0]], dtype=float32) 425 x = solveh_banded(ab, b) 426 expected = array([[0.0, 1.0], 427 [1.0, 0.0], 428 [0.0, 1.0]]) 429 assert_array_almost_equal(x, expected) 430 431 def test_tridiag_01_complex(self): 432 # Solve 433 # [ 4 -j 0] [ -j] 434 # [ j 4 -j] X = [4-j] 435 # [ 0 j 4] [4+j] 436 # 437 ab = array([[-99, -1.0j, -1.0j], [4.0, 4.0, 4.0]]) 438 b = array([-1.0j, 4.0-1j, 4+1j]) 439 x = solveh_banded(ab, b) 440 assert_array_almost_equal(x, [0.0, 1.0, 1.0]) 441 442 def test_tridiag_02_complex(self): 443 # Solve 444 # [ 4 -j 0] [ -j 4j] 445 # [ j 4 -j] X = [4-j -1-j] 446 # [ 0 j 4] [4+j 4 ] 447 # 448 ab = array([[-99, -1.0j, -1.0j], 449 [4.0, 4.0, 4.0]]) 450 b = array([[-1j, 4.0j], 451 [4.0-1j, -1.0-1j], 452 [4.0+1j, 4.0]]) 453 x = solveh_banded(ab, b) 454 expected = array([[0.0, 1.0j], 455 [1.0, 0.0], 456 [1.0, 1.0]]) 457 assert_array_almost_equal(x, expected) 458 459 def test_check_finite(self): 460 # Solve 461 # [ 4 1 0] [1] 462 # [ 1 4 1] X = [4] 463 # [ 0 1 4] [1] 464 # with the RHS as a 1D array. 465 ab = array([[-99, 1.0, 1.0], [4.0, 4.0, 4.0]]) 466 b = array([1.0, 4.0, 1.0]) 467 x = solveh_banded(ab, b, check_finite=False) 468 assert_array_almost_equal(x, [0.0, 1.0, 0.0]) 469 470 def test_bad_shapes(self): 471 ab = array([[-99, 1.0, 1.0], 472 [4.0, 4.0, 4.0]]) 473 b = array([[1.0, 4.0], 474 [4.0, 2.0]]) 475 assert_raises(ValueError, solveh_banded, ab, b) 476 assert_raises(ValueError, solveh_banded, ab, [1.0, 2.0]) 477 assert_raises(ValueError, solveh_banded, ab, [1.0]) 478 479 def test_1x1(self): 480 x = solveh_banded([[1]], [[1, 2, 3]]) 481 assert_array_equal(x, [[1.0, 2.0, 3.0]]) 482 assert_equal(x.dtype, np.dtype('f8')) 483 484 def test_native_list_arguments(self): 485 # Same as test_01_upper, using python's native list. 486 ab = [[0.0, 0.0, 2.0, 2.0], 487 [-99, 1.0, 1.0, 1.0], 488 [4.0, 4.0, 4.0, 4.0]] 489 b = [1.0, 4.0, 1.0, 2.0] 490 x = solveh_banded(ab, b) 491 assert_array_almost_equal(x, [0.0, 1.0, 0.0, 0.0]) 492 493 494class TestSolve: 495 def setup_method(self): 496 np.random.seed(1234) 497 498 def test_20Feb04_bug(self): 499 a = [[1, 1], [1.0, 0]] # ok 500 x0 = solve(a, [1, 0j]) 501 assert_array_almost_equal(dot(a, x0), [1, 0]) 502 503 # gives failure with clapack.zgesv(..,rowmajor=0) 504 a = [[1, 1], [1.2, 0]] 505 b = [1, 0j] 506 x0 = solve(a, b) 507 assert_array_almost_equal(dot(a, x0), [1, 0]) 508 509 def test_simple(self): 510 a = [[1, 20], [-30, 4]] 511 for b in ([[1, 0], [0, 1]], 512 [1, 0], 513 [[2, 1], [-30, 4]] 514 ): 515 x = solve(a, b) 516 assert_array_almost_equal(dot(a, x), b) 517 518 def test_simple_complex(self): 519 a = array([[5, 2], [2j, 4]], 'D') 520 for b in ([1j, 0], 521 [[1j, 1j], [0, 2]], 522 [1, 0j], 523 array([1, 0], 'D'), 524 ): 525 x = solve(a, b) 526 assert_array_almost_equal(dot(a, x), b) 527 528 def test_simple_pos(self): 529 a = [[2, 3], [3, 5]] 530 for lower in [0, 1]: 531 for b in ([[1, 0], [0, 1]], 532 [1, 0] 533 ): 534 x = solve(a, b, assume_a='pos', lower=lower) 535 assert_array_almost_equal(dot(a, x), b) 536 537 def test_simple_pos_complexb(self): 538 a = [[5, 2], [2, 4]] 539 for b in ([1j, 0], 540 [[1j, 1j], [0, 2]], 541 ): 542 x = solve(a, b, assume_a='pos') 543 assert_array_almost_equal(dot(a, x), b) 544 545 def test_simple_sym(self): 546 a = [[2, 3], [3, -5]] 547 for lower in [0, 1]: 548 for b in ([[1, 0], [0, 1]], 549 [1, 0] 550 ): 551 x = solve(a, b, assume_a='sym', lower=lower) 552 assert_array_almost_equal(dot(a, x), b) 553 554 def test_simple_sym_complexb(self): 555 a = [[5, 2], [2, -4]] 556 for b in ([1j, 0], 557 [[1j, 1j],[0, 2]] 558 ): 559 x = solve(a, b, assume_a='sym') 560 assert_array_almost_equal(dot(a, x), b) 561 562 def test_simple_sym_complex(self): 563 a = [[5, 2+1j], [2+1j, -4]] 564 for b in ([1j, 0], 565 [1, 0], 566 [[1j, 1j], [0, 2]] 567 ): 568 x = solve(a, b, assume_a='sym') 569 assert_array_almost_equal(dot(a, x), b) 570 571 def test_simple_her_actuallysym(self): 572 a = [[2, 3], [3, -5]] 573 for lower in [0, 1]: 574 for b in ([[1, 0], [0, 1]], 575 [1, 0], 576 [1j, 0], 577 ): 578 x = solve(a, b, assume_a='her', lower=lower) 579 assert_array_almost_equal(dot(a, x), b) 580 581 582 def test_simple_her(self): 583 a = [[5, 2+1j], [2-1j, -4]] 584 for b in ([1j, 0], 585 [1, 0], 586 [[1j, 1j], [0, 2]] 587 ): 588 x = solve(a, b, assume_a='her') 589 assert_array_almost_equal(dot(a, x), b) 590 591 592 593 def test_nils_20Feb04(self): 594 n = 2 595 A = random([n, n])+random([n, n])*1j 596 X = zeros((n, n), 'D') 597 Ainv = inv(A) 598 R = identity(n)+identity(n)*0j 599 for i in arange(0, n): 600 r = R[:, i] 601 X[:, i] = solve(A, r) 602 assert_array_almost_equal(X, Ainv) 603 604 def test_random(self): 605 606 n = 20 607 a = random([n, n]) 608 for i in range(n): 609 a[i, i] = 20*(.1+a[i, i]) 610 for i in range(4): 611 b = random([n, 3]) 612 x = solve(a, b) 613 assert_array_almost_equal(dot(a, x), b) 614 615 def test_random_complex(self): 616 n = 20 617 a = random([n, n]) + 1j * random([n, n]) 618 for i in range(n): 619 a[i, i] = 20*(.1+a[i, i]) 620 for i in range(2): 621 b = random([n, 3]) 622 x = solve(a, b) 623 assert_array_almost_equal(dot(a, x), b) 624 625 def test_random_sym(self): 626 n = 20 627 a = random([n, n]) 628 for i in range(n): 629 a[i, i] = abs(20*(.1+a[i, i])) 630 for j in range(i): 631 a[i, j] = a[j, i] 632 for i in range(4): 633 b = random([n]) 634 x = solve(a, b, sym_pos=1) 635 assert_array_almost_equal(dot(a, x), b) 636 637 def test_random_sym_complex(self): 638 n = 20 639 a = random([n, n]) 640 a = a + 1j*random([n, n]) 641 for i in range(n): 642 a[i, i] = abs(20*(.1+a[i, i])) 643 for j in range(i): 644 a[i, j] = conjugate(a[j, i]) 645 b = random([n])+2j*random([n]) 646 for i in range(2): 647 x = solve(a, b, sym_pos=1) 648 assert_array_almost_equal(dot(a, x), b) 649 650 def test_check_finite(self): 651 a = [[1, 20], [-30, 4]] 652 for b in ([[1, 0], [0, 1]], [1, 0], 653 [[2, 1], [-30, 4]]): 654 x = solve(a, b, check_finite=False) 655 assert_array_almost_equal(dot(a, x), b) 656 657 def test_scalar_a_and_1D_b(self): 658 a = 1 659 b = [1, 2, 3] 660 x = solve(a, b) 661 assert_array_almost_equal(x.ravel(), b) 662 assert_(x.shape == (3,), 'Scalar_a_1D_b test returned wrong shape') 663 664 def test_simple2(self): 665 a = np.array([[1.80, 2.88, 2.05, -0.89], 666 [525.00, -295.00, -95.00, -380.00], 667 [1.58, -2.69, -2.90, -1.04], 668 [-1.11, -0.66, -0.59, 0.80]]) 669 670 b = np.array([[9.52, 18.47], 671 [2435.00, 225.00], 672 [0.77, -13.28], 673 [-6.22, -6.21]]) 674 675 x = solve(a, b) 676 assert_array_almost_equal(x, np.array([[1., -1, 3, -5], 677 [3, 2, 4, 1]]).T) 678 679 def test_simple_complex2(self): 680 a = np.array([[-1.34+2.55j, 0.28+3.17j, -6.39-2.20j, 0.72-0.92j], 681 [-1.70-14.10j, 33.10-1.50j, -1.50+13.40j, 12.90+13.80j], 682 [-3.29-2.39j, -1.91+4.42j, -0.14-1.35j, 1.72+1.35j], 683 [2.41+0.39j, -0.56+1.47j, -0.83-0.69j, -1.96+0.67j]]) 684 685 b = np.array([[26.26+51.78j, 31.32-6.70j], 686 [64.30-86.80j, 158.60-14.20j], 687 [-5.75+25.31j, -2.15+30.19j], 688 [1.16+2.57j, -2.56+7.55j]]) 689 690 x = solve(a, b) 691 assert_array_almost_equal(x, np. array([[1+1.j, -1-2.j], 692 [2-3.j, 5+1.j], 693 [-4-5.j, -3+4.j], 694 [6.j, 2-3.j]])) 695 696 def test_hermitian(self): 697 # An upper triangular matrix will be used for hermitian matrix a 698 a = np.array([[-1.84, 0.11-0.11j, -1.78-1.18j, 3.91-1.50j], 699 [0, -4.63, -1.84+0.03j, 2.21+0.21j], 700 [0, 0, -8.87, 1.58-0.90j], 701 [0, 0, 0, -1.36]]) 702 b = np.array([[2.98-10.18j, 28.68-39.89j], 703 [-9.58+3.88j, -24.79-8.40j], 704 [-0.77-16.05j, 4.23-70.02j], 705 [7.79+5.48j, -35.39+18.01j]]) 706 res = np.array([[2.+1j, -8+6j], 707 [3.-2j, 7-2j], 708 [-1+2j, -1+5j], 709 [1.-1j, 3-4j]]) 710 x = solve(a, b, assume_a='her') 711 assert_array_almost_equal(x, res) 712 # Also conjugate a and test for lower triangular data 713 x = solve(a.conj().T, b, assume_a='her', lower=True) 714 assert_array_almost_equal(x, res) 715 716 def test_pos_and_sym(self): 717 A = np.arange(1, 10).reshape(3, 3) 718 x = solve(np.tril(A)/9, np.ones(3), assume_a='pos') 719 assert_array_almost_equal(x, [9., 1.8, 1.]) 720 x = solve(np.tril(A)/9, np.ones(3), assume_a='sym') 721 assert_array_almost_equal(x, [9., 1.8, 1.]) 722 723 def test_singularity(self): 724 a = np.array([[1, 0, 0, 0, 0, 0, 1, 0, 1], 725 [1, 1, 1, 0, 0, 0, 1, 0, 1], 726 [0, 1, 1, 0, 0, 0, 1, 0, 1], 727 [1, 0, 1, 1, 1, 1, 0, 0, 0], 728 [1, 0, 1, 1, 1, 1, 0, 0, 0], 729 [1, 0, 1, 1, 1, 1, 0, 0, 0], 730 [1, 0, 1, 1, 1, 1, 0, 0, 0], 731 [1, 1, 1, 1, 1, 1, 1, 1, 1], 732 [1, 1, 1, 1, 1, 1, 1, 1, 1]]) 733 b = np.arange(9)[:, None] 734 assert_raises(LinAlgError, solve, a, b) 735 736 def test_ill_condition_warning(self): 737 a = np.array([[1, 1], [1+1e-16, 1-1e-16]]) 738 b = np.ones(2) 739 with warnings.catch_warnings(): 740 warnings.simplefilter('error') 741 assert_raises(LinAlgWarning, solve, a, b) 742 743 def test_empty_rhs(self): 744 a = np.eye(2) 745 b = [[], []] 746 x = solve(a, b) 747 assert_(x.size == 0, 'Returned array is not empty') 748 assert_(x.shape == (2, 0), 'Returned empty array shape is wrong') 749 750 def test_multiple_rhs(self): 751 a = np.eye(2) 752 b = np.random.rand(2, 3, 4) 753 x = solve(a, b) 754 assert_array_almost_equal(x, b) 755 756 def test_transposed_keyword(self): 757 A = np.arange(9).reshape(3, 3) + 1 758 x = solve(np.tril(A)/9, np.ones(3), transposed=True) 759 assert_array_almost_equal(x, [1.2, 0.2, 1]) 760 x = solve(np.tril(A)/9, np.ones(3), transposed=False) 761 assert_array_almost_equal(x, [9, -5.4, -1.2]) 762 763 def test_transposed_notimplemented(self): 764 a = np.eye(3).astype(complex) 765 with assert_raises(NotImplementedError): 766 solve(a, a, transposed=True) 767 768 def test_nonsquare_a(self): 769 assert_raises(ValueError, solve, [1, 2], 1) 770 771 def test_size_mismatch_with_1D_b(self): 772 assert_array_almost_equal(solve(np.eye(3), np.ones(3)), np.ones(3)) 773 assert_raises(ValueError, solve, np.eye(3), np.ones(4)) 774 775 def test_assume_a_keyword(self): 776 assert_raises(ValueError, solve, 1, 1, assume_a='zxcv') 777 778 @pytest.mark.skip(reason="Failure on OS X (gh-7500), " 779 "crash on Windows (gh-8064)") 780 def test_all_type_size_routine_combinations(self): 781 sizes = [10, 100] 782 assume_as = ['gen', 'sym', 'pos', 'her'] 783 dtypes = [np.float32, np.float64, np.complex64, np.complex128] 784 for size, assume_a, dtype in itertools.product(sizes, assume_as, 785 dtypes): 786 is_complex = dtype in (np.complex64, np.complex128) 787 if assume_a == 'her' and not is_complex: 788 continue 789 790 err_msg = ("Failed for size: {}, assume_a: {}," 791 "dtype: {}".format(size, assume_a, dtype)) 792 793 a = np.random.randn(size, size).astype(dtype) 794 b = np.random.randn(size).astype(dtype) 795 if is_complex: 796 a = a + (1j*np.random.randn(size, size)).astype(dtype) 797 798 if assume_a == 'sym': # Can still be complex but only symmetric 799 a = a + a.T 800 elif assume_a == 'her': # Handle hermitian matrices here instead 801 a = a + a.T.conj() 802 elif assume_a == 'pos': 803 a = a.conj().T.dot(a) + 0.1*np.eye(size) 804 805 tol = 1e-12 if dtype in (np.float64, np.complex128) else 1e-6 806 807 if assume_a in ['gen', 'sym', 'her']: 808 # We revert the tolerance from before 809 # 4b4a6e7c34fa4060533db38f9a819b98fa81476c 810 if dtype in (np.float32, np.complex64): 811 tol *= 10 812 813 x = solve(a, b, assume_a=assume_a) 814 assert_allclose(a.dot(x), b, 815 atol=tol * size, 816 rtol=tol * size, 817 err_msg=err_msg) 818 819 if assume_a == 'sym' and dtype not in (np.complex64, 820 np.complex128): 821 x = solve(a, b, assume_a=assume_a, transposed=True) 822 assert_allclose(a.dot(x), b, 823 atol=tol * size, 824 rtol=tol * size, 825 err_msg=err_msg) 826 827 828class TestSolveTriangular: 829 830 def test_simple(self): 831 """ 832 solve_triangular on a simple 2x2 matrix. 833 """ 834 A = array([[1, 0], [1, 2]]) 835 b = [1, 1] 836 sol = solve_triangular(A, b, lower=True) 837 assert_array_almost_equal(sol, [1, 0]) 838 839 # check that it works also for non-contiguous matrices 840 sol = solve_triangular(A.T, b, lower=False) 841 assert_array_almost_equal(sol, [.5, .5]) 842 843 # and that it gives the same result as trans=1 844 sol = solve_triangular(A, b, lower=True, trans=1) 845 assert_array_almost_equal(sol, [.5, .5]) 846 847 b = identity(2) 848 sol = solve_triangular(A, b, lower=True, trans=1) 849 assert_array_almost_equal(sol, [[1., -.5], [0, 0.5]]) 850 851 def test_simple_complex(self): 852 """ 853 solve_triangular on a simple 2x2 complex matrix 854 """ 855 A = array([[1+1j, 0], [1j, 2]]) 856 b = identity(2) 857 sol = solve_triangular(A, b, lower=True, trans=1) 858 assert_array_almost_equal(sol, [[.5-.5j, -.25-.25j], [0, 0.5]]) 859 860 # check other option combinations with complex rhs 861 b = np.diag([1+1j, 1+2j]) 862 sol = solve_triangular(A, b, lower=True, trans=0) 863 assert_array_almost_equal(sol, [[1, 0], [-0.5j, 0.5+1j]]) 864 865 sol = solve_triangular(A, b, lower=True, trans=1) 866 assert_array_almost_equal(sol, [[1, 0.25-0.75j], [0, 0.5+1j]]) 867 868 sol = solve_triangular(A, b, lower=True, trans=2) 869 assert_array_almost_equal(sol, [[1j, -0.75-0.25j], [0, 0.5+1j]]) 870 871 sol = solve_triangular(A.T, b, lower=False, trans=0) 872 assert_array_almost_equal(sol, [[1, 0.25-0.75j], [0, 0.5+1j]]) 873 874 sol = solve_triangular(A.T, b, lower=False, trans=1) 875 assert_array_almost_equal(sol, [[1, 0], [-0.5j, 0.5+1j]]) 876 877 sol = solve_triangular(A.T, b, lower=False, trans=2) 878 assert_array_almost_equal(sol, [[1j, 0], [-0.5, 0.5+1j]]) 879 880 def test_check_finite(self): 881 """ 882 solve_triangular on a simple 2x2 matrix. 883 """ 884 A = array([[1, 0], [1, 2]]) 885 b = [1, 1] 886 sol = solve_triangular(A, b, lower=True, check_finite=False) 887 assert_array_almost_equal(sol, [1, 0]) 888 889 890class TestInv: 891 def setup_method(self): 892 np.random.seed(1234) 893 894 def test_simple(self): 895 a = [[1, 2], [3, 4]] 896 a_inv = inv(a) 897 assert_array_almost_equal(dot(a, a_inv), np.eye(2)) 898 a = [[1, 2, 3], [4, 5, 6], [7, 8, 10]] 899 a_inv = inv(a) 900 assert_array_almost_equal(dot(a, a_inv), np.eye(3)) 901 902 def test_random(self): 903 n = 20 904 for i in range(4): 905 a = random([n, n]) 906 for i in range(n): 907 a[i, i] = 20*(.1+a[i, i]) 908 a_inv = inv(a) 909 assert_array_almost_equal(dot(a, a_inv), 910 identity(n)) 911 912 def test_simple_complex(self): 913 a = [[1, 2], [3, 4j]] 914 a_inv = inv(a) 915 assert_array_almost_equal(dot(a, a_inv), [[1, 0], [0, 1]]) 916 917 def test_random_complex(self): 918 n = 20 919 for i in range(4): 920 a = random([n, n])+2j*random([n, n]) 921 for i in range(n): 922 a[i, i] = 20*(.1+a[i, i]) 923 a_inv = inv(a) 924 assert_array_almost_equal(dot(a, a_inv), 925 identity(n)) 926 927 def test_check_finite(self): 928 a = [[1, 2], [3, 4]] 929 a_inv = inv(a, check_finite=False) 930 assert_array_almost_equal(dot(a, a_inv), [[1, 0], [0, 1]]) 931 932 933class TestDet: 934 def setup_method(self): 935 np.random.seed(1234) 936 937 def test_simple(self): 938 a = [[1, 2], [3, 4]] 939 a_det = det(a) 940 assert_almost_equal(a_det, -2.0) 941 942 def test_simple_complex(self): 943 a = [[1, 2], [3, 4j]] 944 a_det = det(a) 945 assert_almost_equal(a_det, -6+4j) 946 947 def test_random(self): 948 basic_det = linalg.det 949 n = 20 950 for i in range(4): 951 a = random([n, n]) 952 d1 = det(a) 953 d2 = basic_det(a) 954 assert_almost_equal(d1, d2) 955 956 def test_random_complex(self): 957 basic_det = linalg.det 958 n = 20 959 for i in range(4): 960 a = random([n, n]) + 2j*random([n, n]) 961 d1 = det(a) 962 d2 = basic_det(a) 963 assert_allclose(d1, d2, rtol=1e-13) 964 965 def test_check_finite(self): 966 a = [[1, 2], [3, 4]] 967 a_det = det(a, check_finite=False) 968 assert_almost_equal(a_det, -2.0) 969 970 971def direct_lstsq(a, b, cmplx=0): 972 at = transpose(a) 973 if cmplx: 974 at = conjugate(at) 975 a1 = dot(at, a) 976 b1 = dot(at, b) 977 return solve(a1, b1) 978 979 980class TestLstsq: 981 982 lapack_drivers = ('gelsd', 'gelss', 'gelsy', None) 983 984 def setup_method(self): 985 np.random.seed(1234) 986 987 def test_simple_exact(self): 988 for dtype in REAL_DTYPES: 989 a = np.array([[1, 20], [-30, 4]], dtype=dtype) 990 for lapack_driver in TestLstsq.lapack_drivers: 991 for overwrite in (True, False): 992 for bt in (((1, 0), (0, 1)), (1, 0), 993 ((2, 1), (-30, 4))): 994 # Store values in case they are overwritten 995 # later 996 a1 = a.copy() 997 b = np.array(bt, dtype=dtype) 998 b1 = b.copy() 999 out = lstsq(a1, b1, 1000 lapack_driver=lapack_driver, 1001 overwrite_a=overwrite, 1002 overwrite_b=overwrite) 1003 x = out[0] 1004 r = out[2] 1005 assert_(r == 2, 1006 'expected efficient rank 2, got %s' % r) 1007 assert_allclose(dot(a, x), b, 1008 atol=25 * _eps_cast(a1.dtype), 1009 rtol=25 * _eps_cast(a1.dtype), 1010 err_msg="driver: %s" % lapack_driver) 1011 1012 def test_simple_overdet(self): 1013 for dtype in REAL_DTYPES: 1014 a = np.array([[1, 2], [4, 5], [3, 4]], dtype=dtype) 1015 b = np.array([1, 2, 3], dtype=dtype) 1016 for lapack_driver in TestLstsq.lapack_drivers: 1017 for overwrite in (True, False): 1018 # Store values in case they are overwritten later 1019 a1 = a.copy() 1020 b1 = b.copy() 1021 out = lstsq(a1, b1, lapack_driver=lapack_driver, 1022 overwrite_a=overwrite, 1023 overwrite_b=overwrite) 1024 x = out[0] 1025 if lapack_driver == 'gelsy': 1026 residuals = np.sum((b - a.dot(x))**2) 1027 else: 1028 residuals = out[1] 1029 r = out[2] 1030 assert_(r == 2, 'expected efficient rank 2, got %s' % r) 1031 assert_allclose(abs((dot(a, x) - b)**2).sum(axis=0), 1032 residuals, 1033 rtol=25 * _eps_cast(a1.dtype), 1034 atol=25 * _eps_cast(a1.dtype), 1035 err_msg="driver: %s" % lapack_driver) 1036 assert_allclose(x, (-0.428571428571429, 0.85714285714285), 1037 rtol=25 * _eps_cast(a1.dtype), 1038 atol=25 * _eps_cast(a1.dtype), 1039 err_msg="driver: %s" % lapack_driver) 1040 1041 def test_simple_overdet_complex(self): 1042 for dtype in COMPLEX_DTYPES: 1043 a = np.array([[1+2j, 2], [4, 5], [3, 4]], dtype=dtype) 1044 b = np.array([1, 2+4j, 3], dtype=dtype) 1045 for lapack_driver in TestLstsq.lapack_drivers: 1046 for overwrite in (True, False): 1047 # Store values in case they are overwritten later 1048 a1 = a.copy() 1049 b1 = b.copy() 1050 out = lstsq(a1, b1, lapack_driver=lapack_driver, 1051 overwrite_a=overwrite, 1052 overwrite_b=overwrite) 1053 1054 x = out[0] 1055 if lapack_driver == 'gelsy': 1056 res = b - a.dot(x) 1057 residuals = np.sum(res * res.conj()) 1058 else: 1059 residuals = out[1] 1060 r = out[2] 1061 assert_(r == 2, 'expected efficient rank 2, got %s' % r) 1062 assert_allclose(abs((dot(a, x) - b)**2).sum(axis=0), 1063 residuals, 1064 rtol=25 * _eps_cast(a1.dtype), 1065 atol=25 * _eps_cast(a1.dtype), 1066 err_msg="driver: %s" % lapack_driver) 1067 assert_allclose( 1068 x, (-0.4831460674157303 + 0.258426966292135j, 1069 0.921348314606741 + 0.292134831460674j), 1070 rtol=25 * _eps_cast(a1.dtype), 1071 atol=25 * _eps_cast(a1.dtype), 1072 err_msg="driver: %s" % lapack_driver) 1073 1074 def test_simple_underdet(self): 1075 for dtype in REAL_DTYPES: 1076 a = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype) 1077 b = np.array([1, 2], dtype=dtype) 1078 for lapack_driver in TestLstsq.lapack_drivers: 1079 for overwrite in (True, False): 1080 # Store values in case they are overwritten later 1081 a1 = a.copy() 1082 b1 = b.copy() 1083 out = lstsq(a1, b1, lapack_driver=lapack_driver, 1084 overwrite_a=overwrite, 1085 overwrite_b=overwrite) 1086 1087 x = out[0] 1088 r = out[2] 1089 assert_(r == 2, 'expected efficient rank 2, got %s' % r) 1090 assert_allclose(x, (-0.055555555555555, 0.111111111111111, 1091 0.277777777777777), 1092 rtol=25 * _eps_cast(a1.dtype), 1093 atol=25 * _eps_cast(a1.dtype), 1094 err_msg="driver: %s" % lapack_driver) 1095 1096 def test_random_exact(self): 1097 for dtype in REAL_DTYPES: 1098 for n in (20, 200): 1099 for lapack_driver in TestLstsq.lapack_drivers: 1100 for overwrite in (True, False): 1101 a = np.asarray(random([n, n]), dtype=dtype) 1102 for i in range(n): 1103 a[i, i] = 20 * (0.1 + a[i, i]) 1104 for i in range(4): 1105 b = np.asarray(random([n, 3]), dtype=dtype) 1106 # Store values in case they are overwritten later 1107 a1 = a.copy() 1108 b1 = b.copy() 1109 out = lstsq(a1, b1, 1110 lapack_driver=lapack_driver, 1111 overwrite_a=overwrite, 1112 overwrite_b=overwrite) 1113 x = out[0] 1114 r = out[2] 1115 assert_(r == n, 'expected efficient rank %s, ' 1116 'got %s' % (n, r)) 1117 if dtype is np.float32: 1118 assert_allclose( 1119 dot(a, x), b, 1120 rtol=500 * _eps_cast(a1.dtype), 1121 atol=500 * _eps_cast(a1.dtype), 1122 err_msg="driver: %s" % lapack_driver) 1123 else: 1124 assert_allclose( 1125 dot(a, x), b, 1126 rtol=1000 * _eps_cast(a1.dtype), 1127 atol=1000 * _eps_cast(a1.dtype), 1128 err_msg="driver: %s" % lapack_driver) 1129 1130 def test_random_complex_exact(self): 1131 for dtype in COMPLEX_DTYPES: 1132 for n in (20, 200): 1133 for lapack_driver in TestLstsq.lapack_drivers: 1134 for overwrite in (True, False): 1135 a = np.asarray(random([n, n]) + 1j*random([n, n]), 1136 dtype=dtype) 1137 for i in range(n): 1138 a[i, i] = 20 * (0.1 + a[i, i]) 1139 for i in range(2): 1140 b = np.asarray(random([n, 3]), dtype=dtype) 1141 # Store values in case they are overwritten later 1142 a1 = a.copy() 1143 b1 = b.copy() 1144 out = lstsq(a1, b1, lapack_driver=lapack_driver, 1145 overwrite_a=overwrite, 1146 overwrite_b=overwrite) 1147 x = out[0] 1148 r = out[2] 1149 assert_(r == n, 'expected efficient rank %s, ' 1150 'got %s' % (n, r)) 1151 if dtype is np.complex64: 1152 assert_allclose( 1153 dot(a, x), b, 1154 rtol=400 * _eps_cast(a1.dtype), 1155 atol=400 * _eps_cast(a1.dtype), 1156 err_msg="driver: %s" % lapack_driver) 1157 else: 1158 assert_allclose( 1159 dot(a, x), b, 1160 rtol=1000 * _eps_cast(a1.dtype), 1161 atol=1000 * _eps_cast(a1.dtype), 1162 err_msg="driver: %s" % lapack_driver) 1163 1164 def test_random_overdet(self): 1165 for dtype in REAL_DTYPES: 1166 for (n, m) in ((20, 15), (200, 2)): 1167 for lapack_driver in TestLstsq.lapack_drivers: 1168 for overwrite in (True, False): 1169 a = np.asarray(random([n, m]), dtype=dtype) 1170 for i in range(m): 1171 a[i, i] = 20 * (0.1 + a[i, i]) 1172 for i in range(4): 1173 b = np.asarray(random([n, 3]), dtype=dtype) 1174 # Store values in case they are overwritten later 1175 a1 = a.copy() 1176 b1 = b.copy() 1177 out = lstsq(a1, b1, 1178 lapack_driver=lapack_driver, 1179 overwrite_a=overwrite, 1180 overwrite_b=overwrite) 1181 x = out[0] 1182 r = out[2] 1183 assert_(r == m, 'expected efficient rank %s, ' 1184 'got %s' % (m, r)) 1185 assert_allclose( 1186 x, direct_lstsq(a, b, cmplx=0), 1187 rtol=25 * _eps_cast(a1.dtype), 1188 atol=25 * _eps_cast(a1.dtype), 1189 err_msg="driver: %s" % lapack_driver) 1190 1191 def test_random_complex_overdet(self): 1192 for dtype in COMPLEX_DTYPES: 1193 for (n, m) in ((20, 15), (200, 2)): 1194 for lapack_driver in TestLstsq.lapack_drivers: 1195 for overwrite in (True, False): 1196 a = np.asarray(random([n, m]) + 1j*random([n, m]), 1197 dtype=dtype) 1198 for i in range(m): 1199 a[i, i] = 20 * (0.1 + a[i, i]) 1200 for i in range(2): 1201 b = np.asarray(random([n, 3]), dtype=dtype) 1202 # Store values in case they are overwritten 1203 # later 1204 a1 = a.copy() 1205 b1 = b.copy() 1206 out = lstsq(a1, b1, 1207 lapack_driver=lapack_driver, 1208 overwrite_a=overwrite, 1209 overwrite_b=overwrite) 1210 x = out[0] 1211 r = out[2] 1212 assert_(r == m, 'expected efficient rank %s, ' 1213 'got %s' % (m, r)) 1214 assert_allclose( 1215 x, direct_lstsq(a, b, cmplx=1), 1216 rtol=25 * _eps_cast(a1.dtype), 1217 atol=25 * _eps_cast(a1.dtype), 1218 err_msg="driver: %s" % lapack_driver) 1219 1220 def test_check_finite(self): 1221 with suppress_warnings() as sup: 1222 # On (some) OSX this tests triggers a warning (gh-7538) 1223 sup.filter(RuntimeWarning, 1224 "internal gelsd driver lwork query error,.*" 1225 "Falling back to 'gelss' driver.") 1226 1227 at = np.array(((1, 20), (-30, 4))) 1228 for dtype, bt, lapack_driver, overwrite, check_finite in \ 1229 itertools.product(REAL_DTYPES, 1230 (((1, 0), (0, 1)), (1, 0), ((2, 1), (-30, 4))), 1231 TestLstsq.lapack_drivers, 1232 (True, False), 1233 (True, False)): 1234 1235 a = at.astype(dtype) 1236 b = np.array(bt, dtype=dtype) 1237 # Store values in case they are overwritten 1238 # later 1239 a1 = a.copy() 1240 b1 = b.copy() 1241 out = lstsq(a1, b1, lapack_driver=lapack_driver, 1242 check_finite=check_finite, overwrite_a=overwrite, 1243 overwrite_b=overwrite) 1244 x = out[0] 1245 r = out[2] 1246 assert_(r == 2, 'expected efficient rank 2, got %s' % r) 1247 assert_allclose(dot(a, x), b, 1248 rtol=25 * _eps_cast(a.dtype), 1249 atol=25 * _eps_cast(a.dtype), 1250 err_msg="driver: %s" % lapack_driver) 1251 1252 def test_zero_size(self): 1253 for a_shape, b_shape in (((0, 2), (0,)), 1254 ((0, 4), (0, 2)), 1255 ((4, 0), (4,)), 1256 ((4, 0), (4, 2))): 1257 b = np.ones(b_shape) 1258 x, residues, rank, s = lstsq(np.zeros(a_shape), b) 1259 assert_equal(x, np.zeros((a_shape[1],) + b_shape[1:])) 1260 residues_should_be = (np.empty((0,)) if a_shape[1] 1261 else np.linalg.norm(b, axis=0)**2) 1262 assert_equal(residues, residues_should_be) 1263 assert_(rank == 0, 'expected rank 0') 1264 assert_equal(s, np.empty((0,))) 1265 1266 1267@pytest.mark.filterwarnings('ignore::DeprecationWarning') 1268class TestPinv: 1269 def setup_method(self): 1270 np.random.seed(1234) 1271 1272 def test_simple_real(self): 1273 a = array([[1, 2, 3], [4, 5, 6], [7, 8, 10]], dtype=float) 1274 a_pinv = pinv(a) 1275 assert_array_almost_equal(dot(a, a_pinv), np.eye(3)) 1276 a_pinv = pinv2(a) 1277 assert_array_almost_equal(dot(a, a_pinv), np.eye(3)) 1278 1279 def test_simple_complex(self): 1280 a = (array([[1, 2, 3], [4, 5, 6], [7, 8, 10]], 1281 dtype=float) + 1j * array([[10, 8, 7], [6, 5, 4], [3, 2, 1]], 1282 dtype=float)) 1283 a_pinv = pinv(a) 1284 assert_array_almost_equal(dot(a, a_pinv), np.eye(3)) 1285 a_pinv = pinv2(a) 1286 assert_array_almost_equal(dot(a, a_pinv), np.eye(3)) 1287 1288 def test_simple_singular(self): 1289 a = array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=float) 1290 a_pinv = pinv(a) 1291 a_pinv2 = pinv2(a) 1292 assert_array_almost_equal(a_pinv, a_pinv2) 1293 1294 def test_simple_cols(self): 1295 a = array([[1, 2, 3], [4, 5, 6]], dtype=float) 1296 a_pinv = pinv(a) 1297 a_pinv2 = pinv2(a) 1298 assert_array_almost_equal(a_pinv, a_pinv2) 1299 1300 def test_simple_rows(self): 1301 a = array([[1, 2], [3, 4], [5, 6]], dtype=float) 1302 a_pinv = pinv(a) 1303 a_pinv2 = pinv2(a) 1304 assert_array_almost_equal(a_pinv, a_pinv2) 1305 1306 def test_check_finite(self): 1307 a = array([[1, 2, 3], [4, 5, 6.], [7, 8, 10]]) 1308 a_pinv = pinv(a, check_finite=False) 1309 assert_array_almost_equal(dot(a, a_pinv), np.eye(3)) 1310 a_pinv = pinv2(a, check_finite=False) 1311 assert_array_almost_equal(dot(a, a_pinv), np.eye(3)) 1312 1313 def test_native_list_argument(self): 1314 a = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] 1315 a_pinv = pinv(a) 1316 a_pinv2 = pinv2(a) 1317 assert_array_almost_equal(a_pinv, a_pinv2) 1318 1319 def test_atol_rtol(self): 1320 n = 12 1321 # get a random ortho matrix for shuffling 1322 q, _ = qr(np.random.rand(n, n)) 1323 a_m = np.arange(35.0).reshape(7,5) 1324 a = a_m.copy() 1325 a[0,0] = 0.001 1326 atol = 1e-5 1327 rtol = 0.05 1328 # svds of a_m is ~ [116.906, 4.234, tiny, tiny, tiny] 1329 # svds of a is ~ [116.906, 4.234, 4.62959e-04, tiny, tiny] 1330 # Just abs cutoff such that we arrive at a_modified 1331 a_p = pinv(a_m, atol=atol, rtol=0.) 1332 adiff1 = a @ a_p @ a - a 1333 adiff2 = a_m @ a_p @ a_m - a_m 1334 # Now adiff1 should be around atol value while adiff2 should be 1335 # relatively tiny 1336 assert_allclose(np.linalg.norm(adiff1), 5e-4, atol=5.e-4) 1337 assert_allclose(np.linalg.norm(adiff2), 5e-14, atol=5.e-14) 1338 1339 # Now do the same but remove another sv ~4.234 via rtol 1340 a_p = pinv(a_m, atol=atol, rtol=rtol) 1341 adiff1 = a @ a_p @ a - a 1342 adiff2 = a_m @ a_p @ a_m - a_m 1343 assert_allclose(np.linalg.norm(adiff1), 4.233, rtol=0.01) 1344 assert_allclose(np.linalg.norm(adiff2), 4.233, rtol=0.01) 1345 1346 1347@pytest.mark.filterwarnings('ignore::DeprecationWarning') 1348class TestPinvSymmetric: 1349 1350 def setup_method(self): 1351 np.random.seed(1234) 1352 1353 def test_simple_real(self): 1354 a = array([[1, 2, 3], [4, 5, 6], [7, 8, 10]], dtype=float) 1355 a = np.dot(a, a.T) 1356 a_pinv = pinvh(a) 1357 assert_array_almost_equal(np.dot(a, a_pinv), np.eye(3)) 1358 1359 def test_nonpositive(self): 1360 a = array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=float) 1361 a = np.dot(a, a.T) 1362 u, s, vt = np.linalg.svd(a) 1363 s[0] *= -1 1364 a = np.dot(u * s, vt) # a is now symmetric non-positive and singular 1365 a_pinv = pinv2(a) 1366 a_pinvh = pinvh(a) 1367 assert_array_almost_equal(a_pinv, a_pinvh) 1368 1369 def test_simple_complex(self): 1370 a = (array([[1, 2, 3], [4, 5, 6], [7, 8, 10]], 1371 dtype=float) + 1j * array([[10, 8, 7], [6, 5, 4], [3, 2, 1]], 1372 dtype=float)) 1373 a = np.dot(a, a.conj().T) 1374 a_pinv = pinvh(a) 1375 assert_array_almost_equal(np.dot(a, a_pinv), np.eye(3)) 1376 1377 def test_native_list_argument(self): 1378 a = array([[1, 2, 3], [4, 5, 6], [7, 8, 10]], dtype=float) 1379 a = np.dot(a, a.T) 1380 a_pinv = pinvh(a.tolist()) 1381 assert_array_almost_equal(np.dot(a, a_pinv), np.eye(3)) 1382 1383 def test_atol_rtol(self): 1384 n = 12 1385 # get a random ortho matrix for shuffling 1386 q, _ = qr(np.random.rand(n, n)) 1387 a = np.diag([4, 3, 2, 1, 0.99e-4, 0.99e-5] + [0.99e-6]*(n-6)) 1388 a = q.T @ a @ q 1389 a_m = np.diag([4, 3, 2, 1, 0.99e-4, 0.] + [0.]*(n-6)) 1390 a_m = q.T @ a_m @ q 1391 atol = 1e-5 1392 rtol = (4.01e-4 - 4e-5)/4 1393 # Just abs cutoff such that we arrive at a_modified 1394 a_p = pinvh(a, atol=atol, rtol=0.) 1395 adiff1 = a @ a_p @ a - a 1396 adiff2 = a_m @ a_p @ a_m - a_m 1397 # Now adiff1 should dance around atol value since truncation 1398 # while adiff2 should be relatively tiny 1399 assert_allclose(norm(adiff1), atol, rtol=0.1) 1400 assert_allclose(norm(adiff2), 1e-12, atol=1e-11) 1401 1402 # Now do the same but through rtol cancelling atol value 1403 a_p = pinvh(a, atol=atol, rtol=rtol) 1404 adiff1 = a @ a_p @ a - a 1405 adiff2 = a_m @ a_p @ a_m - a_m 1406 # adiff1 and adiff2 should be elevated to ~1e-4 due to mismatch 1407 assert_allclose(norm(adiff1), 1e-4, rtol=0.1) 1408 assert_allclose(norm(adiff2), 1e-4, rtol=0.1) 1409 1410 1411@pytest.mark.filterwarnings('ignore::DeprecationWarning') 1412@pytest.mark.parametrize('scale', (1e-20, 1., 1e20)) 1413@pytest.mark.parametrize('pinv_', (pinv, pinvh, pinv2)) 1414def test_auto_rcond(scale, pinv_): 1415 x = np.array([[1, 0], [0, 1e-10]]) * scale 1416 expected = np.diag(1. / np.diag(x)) 1417 x_inv = pinv_(x) 1418 assert_allclose(x_inv, expected) 1419 1420 1421class TestVectorNorms: 1422 1423 def test_types(self): 1424 for dtype in np.typecodes['AllFloat']: 1425 x = np.array([1, 2, 3], dtype=dtype) 1426 tol = max(1e-15, np.finfo(dtype).eps.real * 20) 1427 assert_allclose(norm(x), np.sqrt(14), rtol=tol) 1428 assert_allclose(norm(x, 2), np.sqrt(14), rtol=tol) 1429 1430 for dtype in np.typecodes['Complex']: 1431 x = np.array([1j, 2j, 3j], dtype=dtype) 1432 tol = max(1e-15, np.finfo(dtype).eps.real * 20) 1433 assert_allclose(norm(x), np.sqrt(14), rtol=tol) 1434 assert_allclose(norm(x, 2), np.sqrt(14), rtol=tol) 1435 1436 def test_overflow(self): 1437 # unlike numpy's norm, this one is 1438 # safer on overflow 1439 a = array([1e20], dtype=float32) 1440 assert_almost_equal(norm(a), a) 1441 1442 def test_stable(self): 1443 # more stable than numpy's norm 1444 a = array([1e4] + [1]*10000, dtype=float32) 1445 try: 1446 # snrm in double precision; we obtain the same as for float64 1447 # -- large atol needed due to varying blas implementations 1448 assert_allclose(norm(a) - 1e4, 0.5, atol=1e-2) 1449 except AssertionError: 1450 # snrm implemented in single precision, == np.linalg.norm result 1451 msg = ": Result should equal either 0.0 or 0.5 (depending on " \ 1452 "implementation of snrm2)." 1453 assert_almost_equal(norm(a) - 1e4, 0.0, err_msg=msg) 1454 1455 def test_zero_norm(self): 1456 assert_equal(norm([1, 0, 3], 0), 2) 1457 assert_equal(norm([1, 2, 3], 0), 3) 1458 1459 def test_axis_kwd(self): 1460 a = np.array([[[2, 1], [3, 4]]] * 2, 'd') 1461 assert_allclose(norm(a, axis=1), [[3.60555128, 4.12310563]] * 2) 1462 assert_allclose(norm(a, 1, axis=1), [[5.] * 2] * 2) 1463 1464 def test_keepdims_kwd(self): 1465 a = np.array([[[2, 1], [3, 4]]] * 2, 'd') 1466 b = norm(a, axis=1, keepdims=True) 1467 assert_allclose(b, [[[3.60555128, 4.12310563]]] * 2) 1468 assert_(b.shape == (2, 1, 2)) 1469 assert_allclose(norm(a, 1, axis=2, keepdims=True), [[[3.], [7.]]] * 2) 1470 1471 @pytest.mark.skipif(not HAS_ILP64, reason="64-bit BLAS required") 1472 def test_large_vector(self): 1473 check_free_memory(free_mb=17000) 1474 x = np.zeros([2**31], dtype=np.float64) 1475 x[-1] = 1 1476 res = norm(x) 1477 del x 1478 assert_allclose(res, 1.0) 1479 1480 1481class TestMatrixNorms: 1482 1483 def test_matrix_norms(self): 1484 # Not all of these are matrix norms in the most technical sense. 1485 np.random.seed(1234) 1486 for n, m in (1, 1), (1, 3), (3, 1), (4, 4), (4, 5), (5, 4): 1487 for t in np.single, np.double, np.csingle, np.cdouble, np.int64: 1488 A = 10 * np.random.randn(n, m).astype(t) 1489 if np.issubdtype(A.dtype, np.complexfloating): 1490 A = (A + 10j * np.random.randn(n, m)).astype(t) 1491 t_high = np.cdouble 1492 else: 1493 t_high = np.double 1494 for order in (None, 'fro', 1, -1, 2, -2, np.inf, -np.inf): 1495 actual = norm(A, ord=order) 1496 desired = np.linalg.norm(A, ord=order) 1497 # SciPy may return higher precision matrix norms. 1498 # This is a consequence of using LAPACK. 1499 if not np.allclose(actual, desired): 1500 desired = np.linalg.norm(A.astype(t_high), ord=order) 1501 assert_allclose(actual, desired) 1502 1503 def test_axis_kwd(self): 1504 a = np.array([[[2, 1], [3, 4]]] * 2, 'd') 1505 b = norm(a, ord=np.inf, axis=(1, 0)) 1506 c = norm(np.swapaxes(a, 0, 1), ord=np.inf, axis=(0, 1)) 1507 d = norm(a, ord=1, axis=(0, 1)) 1508 assert_allclose(b, c) 1509 assert_allclose(c, d) 1510 assert_allclose(b, d) 1511 assert_(b.shape == c.shape == d.shape) 1512 b = norm(a, ord=1, axis=(1, 0)) 1513 c = norm(np.swapaxes(a, 0, 1), ord=1, axis=(0, 1)) 1514 d = norm(a, ord=np.inf, axis=(0, 1)) 1515 assert_allclose(b, c) 1516 assert_allclose(c, d) 1517 assert_allclose(b, d) 1518 assert_(b.shape == c.shape == d.shape) 1519 1520 def test_keepdims_kwd(self): 1521 a = np.arange(120, dtype='d').reshape(2, 3, 4, 5) 1522 b = norm(a, ord=np.inf, axis=(1, 0), keepdims=True) 1523 c = norm(a, ord=1, axis=(0, 1), keepdims=True) 1524 assert_allclose(b, c) 1525 assert_(b.shape == c.shape) 1526 1527 1528class TestOverwrite: 1529 def test_solve(self): 1530 assert_no_overwrite(solve, [(3, 3), (3,)]) 1531 1532 def test_solve_triangular(self): 1533 assert_no_overwrite(solve_triangular, [(3, 3), (3,)]) 1534 1535 def test_solve_banded(self): 1536 assert_no_overwrite(lambda ab, b: solve_banded((2, 1), ab, b), 1537 [(4, 6), (6,)]) 1538 1539 def test_solveh_banded(self): 1540 assert_no_overwrite(solveh_banded, [(2, 6), (6,)]) 1541 1542 def test_inv(self): 1543 assert_no_overwrite(inv, [(3, 3)]) 1544 1545 def test_det(self): 1546 assert_no_overwrite(det, [(3, 3)]) 1547 1548 def test_lstsq(self): 1549 assert_no_overwrite(lstsq, [(3, 2), (3,)]) 1550 1551 def test_pinv(self): 1552 assert_no_overwrite(pinv, [(3, 3)]) 1553 1554 @pytest.mark.filterwarnings('ignore::DeprecationWarning') 1555 def test_pinv2(self): 1556 assert_no_overwrite(pinv2, [(3, 3)]) 1557 1558 def test_pinvh(self): 1559 assert_no_overwrite(pinvh, [(3, 3)]) 1560 1561 1562class TestSolveCirculant: 1563 1564 def test_basic1(self): 1565 c = np.array([1, 2, 3, 5]) 1566 b = np.array([1, -1, 1, 0]) 1567 x = solve_circulant(c, b) 1568 y = solve(circulant(c), b) 1569 assert_allclose(x, y) 1570 1571 def test_basic2(self): 1572 # b is a 2-d matrix. 1573 c = np.array([1, 2, -3, -5]) 1574 b = np.arange(12).reshape(4, 3) 1575 x = solve_circulant(c, b) 1576 y = solve(circulant(c), b) 1577 assert_allclose(x, y) 1578 1579 def test_basic3(self): 1580 # b is a 3-d matrix. 1581 c = np.array([1, 2, -3, -5]) 1582 b = np.arange(24).reshape(4, 3, 2) 1583 x = solve_circulant(c, b) 1584 y = solve(circulant(c), b) 1585 assert_allclose(x, y) 1586 1587 def test_complex(self): 1588 # Complex b and c 1589 c = np.array([1+2j, -3, 4j, 5]) 1590 b = np.arange(8).reshape(4, 2) + 0.5j 1591 x = solve_circulant(c, b) 1592 y = solve(circulant(c), b) 1593 assert_allclose(x, y) 1594 1595 def test_random_b_and_c(self): 1596 # Random b and c 1597 np.random.seed(54321) 1598 c = np.random.randn(50) 1599 b = np.random.randn(50) 1600 x = solve_circulant(c, b) 1601 y = solve(circulant(c), b) 1602 assert_allclose(x, y) 1603 1604 def test_singular(self): 1605 # c gives a singular circulant matrix. 1606 c = np.array([1, 1, 0, 0]) 1607 b = np.array([1, 2, 3, 4]) 1608 x = solve_circulant(c, b, singular='lstsq') 1609 y, res, rnk, s = lstsq(circulant(c), b) 1610 assert_allclose(x, y) 1611 assert_raises(LinAlgError, solve_circulant, x, y) 1612 1613 def test_axis_args(self): 1614 # Test use of caxis, baxis and outaxis. 1615 1616 # c has shape (2, 1, 4) 1617 c = np.array([[[-1, 2.5, 3, 3.5]], [[1, 6, 6, 6.5]]]) 1618 1619 # b has shape (3, 4) 1620 b = np.array([[0, 0, 1, 1], [1, 1, 0, 0], [1, -1, 0, 0]]) 1621 1622 x = solve_circulant(c, b, baxis=1) 1623 assert_equal(x.shape, (4, 2, 3)) 1624 expected = np.empty_like(x) 1625 expected[:, 0, :] = solve(circulant(c[0]), b.T) 1626 expected[:, 1, :] = solve(circulant(c[1]), b.T) 1627 assert_allclose(x, expected) 1628 1629 x = solve_circulant(c, b, baxis=1, outaxis=-1) 1630 assert_equal(x.shape, (2, 3, 4)) 1631 assert_allclose(np.rollaxis(x, -1), expected) 1632 1633 # np.swapaxes(c, 1, 2) has shape (2, 4, 1); b.T has shape (4, 3). 1634 x = solve_circulant(np.swapaxes(c, 1, 2), b.T, caxis=1) 1635 assert_equal(x.shape, (4, 2, 3)) 1636 assert_allclose(x, expected) 1637 1638 def test_native_list_arguments(self): 1639 # Same as test_basic1 using python's native list. 1640 c = [1, 2, 3, 5] 1641 b = [1, -1, 1, 0] 1642 x = solve_circulant(c, b) 1643 y = solve(circulant(c), b) 1644 assert_allclose(x, y) 1645 1646 1647class TestMatrix_Balance: 1648 1649 def test_string_arg(self): 1650 assert_raises(ValueError, matrix_balance, 'Some string for fail') 1651 1652 def test_infnan_arg(self): 1653 assert_raises(ValueError, matrix_balance, 1654 np.array([[1, 2], [3, np.inf]])) 1655 assert_raises(ValueError, matrix_balance, 1656 np.array([[1, 2], [3, np.nan]])) 1657 1658 def test_scaling(self): 1659 _, y = matrix_balance(np.array([[1000, 1], [1000, 0]])) 1660 # Pre/post LAPACK 3.5.0 gives the same result up to an offset 1661 # since in each case col norm is x1000 greater and 1662 # 1000 / 32 ~= 1 * 32 hence balanced with 2 ** 5. 1663 assert_allclose(int(np.diff(np.log2(np.diag(y)))), 5) 1664 1665 def test_scaling_order(self): 1666 A = np.array([[1, 0, 1e-4], [1, 1, 1e-2], [1e4, 1e2, 1]]) 1667 x, y = matrix_balance(A) 1668 assert_allclose(solve(y, A).dot(y), x) 1669 1670 def test_separate(self): 1671 _, (y, z) = matrix_balance(np.array([[1000, 1], [1000, 0]]), 1672 separate=1) 1673 assert_equal(int(np.diff(np.log2(y))), 5) 1674 assert_allclose(z, np.arange(2)) 1675 1676 def test_permutation(self): 1677 A = block_diag(np.ones((2, 2)), np.tril(np.ones((2, 2))), 1678 np.ones((3, 3))) 1679 x, (y, z) = matrix_balance(A, separate=1) 1680 assert_allclose(y, np.ones_like(y)) 1681 assert_allclose(z, np.array([0, 1, 6, 5, 4, 3, 2])) 1682 1683 def test_perm_and_scaling(self): 1684 # Matrix with its diagonal removed 1685 cases = ( # Case 0 1686 np.array([[0., 0., 0., 0., 0.000002], 1687 [0., 0., 0., 0., 0.], 1688 [2., 2., 0., 0., 0.], 1689 [2., 2., 0., 0., 0.], 1690 [0., 0., 0.000002, 0., 0.]]), 1691 # Case 1 user reported GH-7258 1692 np.array([[-0.5, 0., 0., 0.], 1693 [0., -1., 0., 0.], 1694 [1., 0., -0.5, 0.], 1695 [0., 1., 0., -1.]]), 1696 # Case 2 user reported GH-7258 1697 np.array([[-3., 0., 1., 0.], 1698 [-1., -1., -0., 1.], 1699 [-3., -0., -0., 0.], 1700 [-1., -0., 1., -1.]]) 1701 ) 1702 1703 for A in cases: 1704 x, y = matrix_balance(A) 1705 x, (s, p) = matrix_balance(A, separate=1) 1706 ip = np.empty_like(p) 1707 ip[p] = np.arange(A.shape[0]) 1708 assert_allclose(y, np.diag(s)[ip, :]) 1709 assert_allclose(solve(y, A).dot(y), x) 1710