"""Test the sparse matrix extraction functions ``find``, ``tril`` and ``triu``."""

from numpy.testing import assert_equal
from scipy.sparse import csr_matrix

import numpy as np
from scipy.sparse import extract


class TestExtract:
    """Check ``scipy.sparse.extract`` helpers against dense NumPy references."""

    def setup_method(self):
        # Row vectors, column vectors, 2x2 matrices and a 3x5 matrix (plus
        # its transpose, which is not CSR), covering dense, partly-zero and
        # all-zero patterns.
        self.cases = [
            csr_matrix([[1, 2]]),
            csr_matrix([[1, 0]]),
            csr_matrix([[0, 0]]),
            csr_matrix([[1], [2]]),
            csr_matrix([[1], [0]]),
            csr_matrix([[0], [0]]),
            csr_matrix([[1, 2], [3, 4]]),
            csr_matrix([[0, 1], [0, 0]]),
            csr_matrix([[0, 0], [1, 0]]),
            csr_matrix([[0, 0], [0, 0]]),
            csr_matrix([[1, 2, 0, 0, 3], [4, 5, 0, 6, 7], [0, 0, 8, 9, 0]]),
            csr_matrix([[1, 2, 0, 0, 3], [4, 5, 0, 6, 7], [0, 0, 8, 9, 0]]).T,
        ]

    def test_find(self):
        """The (row, col, value) triplets from ``find`` must rebuild A exactly."""
        # Previously named ``find`` (never collected by pytest) and it built
        # the matrix with the tuple order reversed; the constructor expects
        # ``(data, (row_ind, col_ind))``.
        for A in self.cases:
            I, J, V = extract.find(A)
            B = csr_matrix((V, (I, J)), shape=A.shape)
            # Compare dense forms; assert_equal cannot compare an ndarray
            # to a sparse matrix element-wise.
            assert_equal(A.toarray(), B.toarray())

    def test_tril(self):
        """``tril`` must match ``np.tril`` for diagonal offsets -3..3."""
        for A in self.cases:
            B = A.toarray()
            for k in range(-3, 4):
                assert_equal(extract.tril(A, k=k).toarray(), np.tril(B, k=k))

    def test_triu(self):
        """``triu`` must match ``np.triu`` for diagonal offsets -3..3."""
        for A in self.cases:
            B = A.toarray()
            for k in range(-3, 4):
                assert_equal(extract.triu(A, k=k).toarray(), np.triu(B, k=k))