"""Utilities to deal with sympy.Matrix, numpy and scipy.sparse."""

from sympy import MatrixBase, I, Expr, Integer
from sympy.matrices import eye, zeros, Matrix
from sympy.external import import_module

__all__ = [
    'numpy_ndarray',
    'scipy_sparse_matrix',
    'sympy_to_numpy',
    'sympy_to_scipy_sparse',
    'numpy_to_sympy',
    'scipy_sparse_to_sympy',
    'flatten_scalar',
    'matrix_dagger',
    'to_sympy',
    'to_numpy',
    'to_scipy_sparse',
    'matrix_tensor_product',
    'matrix_zeros'
]

# Conditionally define the base classes for numpy and scipy.sparse arrays
# for use in isinstance tests.  When a library is unavailable, an empty
# placeholder class is defined instead, so isinstance() checks against it
# simply return False.

np = import_module('numpy')
if not np:
    class numpy_ndarray:
        pass
else:
    numpy_ndarray = np.ndarray  # type: ignore

scipy = import_module('scipy', import_kwargs={'fromlist': ['sparse']})
if not scipy:
    class scipy_sparse_matrix:
        pass
    sparse = None
else:
    sparse = scipy.sparse
    if hasattr(sparse, 'spmatrix'):
        # Public location of the sparse-matrix base class.  Checked first so
        # the deprecated private namespaces below are never touched when the
        # public name exists.
        scipy_sparse_matrix = sparse.spmatrix  # type: ignore
    elif hasattr(sparse, 'base'):
        # Fallback for old releases exposing it under scipy.sparse.base.
        scipy_sparse_matrix = sparse.base.spmatrix  # type: ignore
    elif hasattr(sparse, 'sparse'):
        # Fallback for old releases exposing it under scipy.sparse.sparse.
        scipy_sparse_matrix = sparse.sparse.spmatrix  # type: ignore
    else:
        # Unknown layout: keep the name defined so the isinstance() checks
        # in this module do not raise NameError.
        class scipy_sparse_matrix:  # type: ignore
            pass


def sympy_to_numpy(m, **options):
    """Convert a sympy Matrix/complex number to a numpy matrix or scalar.

    Parameters
    ==========

    m : MatrixBase or Expr
        A sympy matrix, or a sympy Number, NumberSymbol or ``I``.
    dtype : keyword, optional
        dtype of the resulting numpy matrix (default ``'complex'``).

    Returns
    =======

    A ``numpy.matrix`` for matrix input, or a Python ``complex`` for
    scalar input.

    Raises
    ======

    ImportError
        If numpy is not installed.
    TypeError
        If ``m`` is neither a matrix nor a supported numeric scalar.
    """
    if not np:
        raise ImportError
    dtype = options.get('dtype', 'complex')
    if isinstance(m, MatrixBase):
        return np.matrix(m.tolist(), dtype=dtype)
    elif isinstance(m, Expr):
        if m.is_Number or m.is_NumberSymbol or m == I:
            return complex(m)
    raise TypeError('Expected MatrixBase or complex scalar, got: %r' % m)


def sympy_to_scipy_sparse(m, **options):
    """Convert a sympy Matrix/complex number to a scipy.sparse matrix or scalar.

    Parameters
    ==========

    m : MatrixBase or Expr
        A sympy matrix, or a sympy Number, NumberSymbol or ``I``.
    dtype : keyword, optional
        dtype of the resulting matrix (default ``'complex'``).

    Returns
    =======

    A ``scipy.sparse.csr_matrix`` for matrix input, or a Python
    ``complex`` for scalar input.

    Raises
    ======

    ImportError
        If numpy or scipy is not installed.
    TypeError
        If ``m`` is neither a matrix nor a supported numeric scalar.
    """
    if not np or not sparse:
        raise ImportError
    dtype = options.get('dtype', 'complex')
    if isinstance(m, MatrixBase):
        return sparse.csr_matrix(np.matrix(m.tolist(), dtype=dtype))
    elif isinstance(m, Expr):
        if m.is_Number or m.is_NumberSymbol or m == I:
            return complex(m)
    raise TypeError('Expected MatrixBase or complex scalar, got: %r' % m)


def scipy_sparse_to_sympy(m, **options):
    """Convert a scipy.sparse matrix to a sympy matrix.

    The sparse matrix is densified first, then wrapped in a concrete sympy
    ``Matrix``.  ``MatrixBase`` is the abstract base class, so it is not
    used as the constructor here.
    """
    return Matrix(m.todense())


def numpy_to_sympy(m, **options):
    """Convert a numpy matrix/array to a (concrete) sympy ``Matrix``."""
    return Matrix(m)


def to_sympy(m, **options):
    """Convert a numpy/scipy.sparse matrix to a sympy matrix.

    sympy matrices and sympy expressions are returned unchanged.

    Raises
    ======

    TypeError
        If ``m`` is not a sympy, numpy or scipy.sparse object.
    """
    if isinstance(m, MatrixBase):
        return m
    elif isinstance(m, numpy_ndarray):
        return numpy_to_sympy(m)
    elif isinstance(m, scipy_sparse_matrix):
        return scipy_sparse_to_sympy(m)
    elif isinstance(m, Expr):
        return m
    raise TypeError('Expected sympy/numpy/scipy.sparse matrix, got: %r' % m)


def to_numpy(m, **options):
    """Convert a sympy/scipy.sparse matrix to a numpy matrix.

    numpy arrays are returned unchanged.  sympy scalars are converted to a
    Python ``complex`` (see ``sympy_to_numpy``).  The ``dtype`` keyword
    (default ``'complex'``) is forwarded for sympy input only.

    Raises
    ======

    TypeError
        If ``m`` is not a sympy, numpy or scipy.sparse object.
    """
    dtype = options.get('dtype', 'complex')
    if isinstance(m, (MatrixBase, Expr)):
        return sympy_to_numpy(m, dtype=dtype)
    elif isinstance(m, numpy_ndarray):
        return m
    elif isinstance(m, scipy_sparse_matrix):
        return m.todense()
    raise TypeError('Expected sympy/numpy/scipy.sparse matrix, got: %r' % m)


def to_scipy_sparse(m, **options):
    """Convert a sympy/numpy matrix to a scipy.sparse matrix.

    scipy.sparse matrices are returned unchanged.  numpy input becomes a
    ``csr_matrix``.  The ``dtype`` keyword (default ``'complex'``) is
    forwarded for sympy input only.

    Raises
    ======

    ImportError
        If scipy is required but not installed.
    TypeError
        If ``m`` is not a sympy, numpy or scipy.sparse object.
    """
    dtype = options.get('dtype', 'complex')
    if isinstance(m, (MatrixBase, Expr)):
        return sympy_to_scipy_sparse(m, dtype=dtype)
    elif isinstance(m, numpy_ndarray):
        if not sparse:
            raise ImportError
        return sparse.csr_matrix(m)
    elif isinstance(m, scipy_sparse_matrix):
        return m
    raise TypeError('Expected sympy/numpy/scipy.sparse matrix, got: %r' % m)


def flatten_scalar(e):
    """Flatten a 1x1 matrix to a scalar, return larger matrices unchanged.

    A 1x1 sympy matrix yields its single sympy element.  A 1x1 numpy or
    scipy.sparse matrix yields a Python ``complex``.  Any other input is
    returned as-is.
    """
    if isinstance(e, MatrixBase):
        if e.shape == (1, 1):
            e = e[0]
    if isinstance(e, (numpy_ndarray, scipy_sparse_matrix)):
        if e.shape == (1, 1):
            e = complex(e[0, 0])
    return e


def matrix_dagger(e):
    """Return the dagger (conjugate transpose) of a sympy/numpy/scipy.sparse matrix.

    Raises
    ======

    TypeError
        If ``e`` is not one of the supported matrix types.
    """
    if isinstance(e, MatrixBase):
        return e.H
    elif isinstance(e, (numpy_ndarray, scipy_sparse_matrix)):
        return e.conjugate().transpose()
    raise TypeError('Expected sympy/numpy/scipy.sparse matrix, got: %r' % e)


# TODO: Move this into sympy.matricies.
def _sympy_tensor_product(*matrices):
    """Compute the kronecker product of a sequence of sympy Matrices."""
    # Imported lazily, as in the original implementation.
    from sympy.matrices.expressions.kronecker import matrix_kronecker_product

    return matrix_kronecker_product(*matrices)


def _numpy_tensor_product(*product):
    """numpy version of tensor product of multiple arguments."""
    if not np:
        raise ImportError
    answer = product[0]
    for item in product[1:]:
        answer = np.kron(answer, item)
    return answer


def _scipy_sparse_tensor_product(*product):
    """scipy.sparse version of tensor product of multiple arguments."""
    if not sparse:
        raise ImportError
    answer = product[0]
    for item in product[1:]:
        answer = sparse.kron(answer, item)
    # The final matrices will just be multiplied, so csr is a good final
    # sparse format.
    return sparse.csr_matrix(answer)


def matrix_tensor_product(*product):
    """Compute the matrix tensor product of sympy/numpy/scipy.sparse matrices.

    The backend is chosen from the type of the first argument.

    Raises
    ======

    TypeError
        If the first argument is not a sympy, numpy or scipy.sparse matrix.
        Previously ``None`` was returned silently in this case.
    """
    if isinstance(product[0], MatrixBase):
        return _sympy_tensor_product(*product)
    elif isinstance(product[0], numpy_ndarray):
        return _numpy_tensor_product(*product)
    elif isinstance(product[0], scipy_sparse_matrix):
        return _scipy_sparse_tensor_product(*product)
    raise TypeError(
        'Expected sympy/numpy/scipy.sparse matrix, got: %r' % product[0])


def _numpy_eye(n):
    """numpy version of complex eye (returns a ``numpy.matrix``)."""
    if not np:
        raise ImportError
    return np.matrix(np.eye(n, dtype='complex'))


def _scipy_sparse_eye(n):
    """scipy.sparse version of complex eye."""
    if not sparse:
        raise ImportError
    return sparse.eye(n, n, dtype='complex')


def matrix_eye(n, **options):
    """Get an ``n x n`` identity matrix for a given format.

    The ``format`` keyword is one of ``'sympy'`` (default), ``'numpy'`` or
    ``'scipy.sparse'``.

    Raises
    ======

    NotImplementedError
        If ``format`` is not recognized.
    """
    format = options.get('format', 'sympy')
    if format == 'sympy':
        return eye(n)
    elif format == 'numpy':
        return _numpy_eye(n)
    elif format == 'scipy.sparse':
        return _scipy_sparse_eye(n)
    raise NotImplementedError('Invalid format: %r' % format)


def _numpy_zeros(m, n, **options):
    """numpy version of zeros (``dtype`` keyword, default ``'float64'``)."""
    dtype = options.get('dtype', 'float64')
    if not np:
        raise ImportError
    return np.zeros((m, n), dtype=dtype)


def _scipy_sparse_zeros(m, n, **options):
    """scipy.sparse version of zeros.

    Keywords: ``spmatrix`` selects ``'csr'`` (default) or ``'lil'``;
    ``dtype`` defaults to ``'float64'``.

    Raises
    ======

    ImportError
        If scipy is not installed.
    NotImplementedError
        If ``spmatrix`` is not ``'csr'`` or ``'lil'``.  Previously ``None``
        was returned silently in this case.
    """
    spmatrix = options.get('spmatrix', 'csr')
    dtype = options.get('dtype', 'float64')
    if not sparse:
        raise ImportError
    if spmatrix == 'lil':
        return sparse.lil_matrix((m, n), dtype=dtype)
    elif spmatrix == 'csr':
        return sparse.csr_matrix((m, n), dtype=dtype)
    raise NotImplementedError('Invalid spmatrix: %r' % spmatrix)


def matrix_zeros(m, n, **options):
    """Get an ``m x n`` zeros matrix for a given format.

    The ``format`` keyword is one of ``'sympy'`` (default), ``'numpy'`` or
    ``'scipy.sparse'``.  Remaining keywords are forwarded to the numpy and
    scipy.sparse helpers.

    Raises
    ======

    NotImplementedError
        If ``format`` is not recognized.
    """
    format = options.get('format', 'sympy')
    if format == 'sympy':
        return zeros(m, n)
    elif format == 'numpy':
        return _numpy_zeros(m, n, **options)
    elif format == 'scipy.sparse':
        return _scipy_sparse_zeros(m, n, **options)
    raise NotImplementedError('Invalid format: %r' % format)


def _numpy_matrix_to_zero(e):
    """Convert a numpy zero matrix to the zero scalar.

    The test uses ``np.allclose``, so entries within its default tolerance
    of zero count as zero.  Non-zero matrices are returned unchanged.
    """
    if not np:
        raise ImportError
    test = np.zeros_like(e)
    if np.allclose(e, test):
        return 0.0
    else:
        return e


def _scipy_sparse_matrix_to_zero(e):
    """Convert a scipy.sparse zero matrix to the zero scalar.

    The matrix is densified and compared with ``np.allclose``.  Non-zero
    matrices are returned unchanged.
    """
    if not np:
        raise ImportError
    edense = e.todense()
    test = np.zeros_like(edense)
    if np.allclose(edense, test):
        return 0.0
    else:
        return e


def matrix_to_zero(e):
    """Convert a zero matrix to the scalar zero.

    sympy zero matrices become ``Integer(0)``; numpy and scipy.sparse zero
    matrices become ``0.0``.  Anything else is returned unchanged.
    """
    if isinstance(e, MatrixBase):
        if zeros(*e.shape) == e:
            e = Integer(0)
    elif isinstance(e, numpy_ndarray):
        e = _numpy_matrix_to_zero(e)
    elif isinstance(e, scipy_sparse_matrix):
        e = _scipy_sparse_matrix_to_zero(e)
    return e