1"""Utilities to deal with sympy.Matrix, numpy and scipy.sparse."""
2
3from sympy import MatrixBase, I, Expr, Integer
4from sympy.matrices import eye, zeros
5from sympy.external import import_module
6
# Public names exported by ``from ... import *``.
__all__ = [
    # Base classes usable in isinstance checks (placeholders when the
    # corresponding library is not installed).
    'numpy_ndarray',
    'scipy_sparse_matrix',
    # Format-specific conversions.
    'sympy_to_numpy',
    'sympy_to_scipy_sparse',
    'numpy_to_sympy',
    'scipy_sparse_to_sympy',
    # Format-agnostic helpers.
    'flatten_scalar',
    'matrix_dagger',
    'to_sympy',
    'to_numpy',
    'to_scipy_sparse',
    'matrix_tensor_product',
    'matrix_zeros',
]
22
# Conditionally define the base classes for numpy and scipy.sparse arrays
# for use in isinstance tests. When a library is not installed, an empty
# placeholder class is defined instead, so isinstance checks against it
# simply return False for ordinary objects.

np = import_module('numpy')
if not np:
    class numpy_ndarray:
        pass
else:
    numpy_ndarray = np.ndarray  # type: ignore

scipy = import_module('scipy', import_kwargs={'fromlist': ['sparse']})
if not scipy:
    class scipy_sparse_matrix:
        pass
    sparse = None
else:
    sparse = scipy.sparse
    # ``scipy.sparse.spmatrix`` is the public base class of the sparse matrix
    # types. The private locations ``scipy.sparse.base`` and
    # ``scipy.sparse.sparse`` that were probed here previously are deprecated
    # or removed in recent scipy releases. When neither existed,
    # ``scipy_sparse_matrix`` was left undefined, causing NameErrors later.
    scipy_sparse_matrix = sparse.spmatrix  # type: ignore
47
48
def sympy_to_numpy(m, **options):
    """Convert a sympy Matrix/complex number to a numpy matrix or scalar.

    A sympy matrix becomes an ``np.matrix`` of the requested ``dtype``
    option (``'complex'`` by default). A numeric sympy scalar (a Number, a
    NumberSymbol, or ``I``) becomes a Python ``complex``.

    Raises ``ImportError`` when numpy is unavailable and ``TypeError`` for
    any other input.
    """
    if not np:
        raise ImportError
    dtype = options.get('dtype', 'complex')
    if isinstance(m, MatrixBase):
        return np.matrix(m.tolist(), dtype=dtype)
    is_numeric_scalar = isinstance(m, Expr) and (
        m.is_Number or m.is_NumberSymbol or m == I)
    if is_numeric_scalar:
        return complex(m)
    raise TypeError('Expected MatrixBase or complex scalar, got: %r' % m)
60
61
def sympy_to_scipy_sparse(m, **options):
    """Convert a sympy Matrix/complex number to a scipy.sparse matrix or scalar.

    A sympy matrix is converted to a CSR ``scipy.sparse`` matrix whose
    entries have the requested ``dtype`` option (``'complex'`` by default).
    A numeric sympy scalar (a Number, a NumberSymbol, or ``I``) becomes a
    Python ``complex``.

    Raises ``ImportError`` when numpy or scipy is unavailable and
    ``TypeError`` for any other input.
    """
    if not np or not sparse:
        raise ImportError('sympy_to_scipy_sparse requires numpy and scipy')
    dtype = options.get('dtype', 'complex')
    if isinstance(m, MatrixBase):
        # Build a dense numpy matrix with the target dtype, then convert it
        # to CSR format.
        return sparse.csr_matrix(np.matrix(m.tolist(), dtype=dtype))
    elif isinstance(m, Expr):
        if m.is_Number or m.is_NumberSymbol or m == I:
            return complex(m)
    # Wrap in a 1-tuple so that tuple inputs are formatted rather than
    # breaking the %-formatting itself.
    raise TypeError('Expected MatrixBase or complex scalar, got: %r' % (m,))
73
74
def scipy_sparse_to_sympy(m, **options):
    """Convert a scipy.sparse matrix to a sympy matrix.

    The sparse matrix is densified with ``todense()`` and passed to the
    concrete ``Matrix`` class. ``MatrixBase`` is an abstract base class with
    no data constructor, so it cannot be used to build the result.
    ``options`` is accepted for interface compatibility and ignored.
    """
    from sympy.matrices import Matrix
    return Matrix(m.todense())
78
79
def numpy_to_sympy(m, **options):
    """Convert a numpy matrix to a sympy matrix.

    The array is passed to the concrete ``Matrix`` class. ``MatrixBase`` is
    an abstract base class with no data constructor, so it cannot be used
    to build the result. ``options`` is accepted for interface compatibility
    and ignored.
    """
    from sympy.matrices import Matrix
    return Matrix(m)
83
84
def to_sympy(m, **options):
    """Return ``m`` as a sympy object.

    sympy matrices and sympy expressions are returned unchanged. numpy
    arrays and scipy.sparse matrices are converted to sympy matrices.
    ``options`` is accepted for interface compatibility and ignored.
    Raises ``TypeError`` for any other input.
    """
    # Checked in order; the first matching type decides the conversion.
    converters = (
        (MatrixBase, lambda value: value),
        (numpy_ndarray, numpy_to_sympy),
        (scipy_sparse_matrix, scipy_sparse_to_sympy),
        (Expr, lambda value: value),
    )
    for kind, convert in converters:
        if isinstance(m, kind):
            return convert(m)
    raise TypeError('Expected sympy/numpy/scipy.sparse matrix, got: %r' % m)
96
97
def to_numpy(m, **options):
    """Return ``m`` as a numpy object.

    sympy matrices/scalars go through ``sympy_to_numpy`` with the ``dtype``
    option (``'complex'`` by default). numpy arrays are returned as-is.
    scipy.sparse matrices are densified with ``todense()``.
    Raises ``TypeError`` for any other input.
    """
    if isinstance(m, (MatrixBase, Expr)):
        return sympy_to_numpy(m, dtype=options.get('dtype', 'complex'))
    if isinstance(m, numpy_ndarray):
        return m
    if isinstance(m, scipy_sparse_matrix):
        return m.todense()
    raise TypeError('Expected sympy/numpy/scipy.sparse matrix, got: %r' % m)
108
109
def to_scipy_sparse(m, **options):
    """Return ``m`` as a scipy.sparse object.

    sympy matrices/scalars go through ``sympy_to_scipy_sparse`` with the
    ``dtype`` option (``'complex'`` by default). numpy arrays are wrapped in
    a CSR matrix. scipy.sparse matrices are returned as-is.

    Raises ``ImportError`` when a numpy array is given but scipy is missing,
    and ``TypeError`` for any other input.
    """
    if isinstance(m, (MatrixBase, Expr)):
        return sympy_to_scipy_sparse(m, dtype=options.get('dtype', 'complex'))
    if isinstance(m, numpy_ndarray):
        if not sparse:
            raise ImportError
        return sparse.csr_matrix(m)
    if isinstance(m, scipy_sparse_matrix):
        return m
    raise TypeError('Expected sympy/numpy/scipy.sparse matrix, got: %r' % m)
122
123
def flatten_scalar(e):
    """Collapse a 1x1 matrix into its single entry; pass everything else through.

    A 1x1 sympy matrix yields its (sympy) entry. A 1x1 numpy array or
    scipy.sparse matrix yields its entry converted to a Python ``complex``.
    Larger matrices and non-matrix inputs are returned unchanged.
    """
    if isinstance(e, MatrixBase) and e.shape == (1, 1):
        return e[0]
    if isinstance(e, (numpy_ndarray, scipy_sparse_matrix)) and e.shape == (1, 1):
        return complex(e[0, 0])
    return e
133
134
def matrix_dagger(e):
    """Return the conjugate transpose (dagger) of a matrix.

    sympy matrices use their ``.H`` attribute. numpy arrays and
    scipy.sparse matrices use ``conjugate().transpose()``.
    Raises ``TypeError`` for any other input.
    """
    if isinstance(e, MatrixBase):
        return e.H
    if not isinstance(e, (numpy_ndarray, scipy_sparse_matrix)):
        raise TypeError('Expected sympy/numpy/scipy.sparse matrix, got: %r' % e)
    return e.conjugate().transpose()
142
143
# TODO: Move this into sympy.matrices.
def _sympy_tensor_product(*matrices):
    """Return the Kronecker (tensor) product of a sequence of sympy matrices.

    Thin wrapper around ``matrix_kronecker_product``. Its import is deferred
    until the function is called.
    """
    from sympy.matrices.expressions.kronecker import (
        matrix_kronecker_product as kronecker)

    result = kronecker(*matrices)
    return result
151
152
153def _numpy_tensor_product(*product):
154    """numpy version of tensor product of multiple arguments."""
155    if not np:
156        raise ImportError
157    answer = product[0]
158    for item in product[1:]:
159        answer = np.kron(answer, item)
160    return answer
161
162
163def _scipy_sparse_tensor_product(*product):
164    """scipy.sparse version of tensor product of multiple arguments."""
165    if not sparse:
166        raise ImportError
167    answer = product[0]
168    for item in product[1:]:
169        answer = sparse.kron(answer, item)
170    # The final matrices will just be multiplied, so csr is a good final
171    # sparse format.
172    return sparse.csr_matrix(answer)
173
174
def matrix_tensor_product(*product):
    """Compute the matrix tensor product of sympy/numpy/scipy.sparse matrices.

    The backend is chosen from the type of the first argument:

    - sympy matrices use ``_sympy_tensor_product``;
    - numpy arrays use ``_numpy_tensor_product``;
    - scipy.sparse matrices use ``_scipy_sparse_tensor_product``.

    All arguments are passed to the selected helper.

    Raises ``TypeError`` when the first argument is none of these types,
    consistent with the other dispatch functions in this module, instead of
    silently returning ``None``. Calling with no arguments raises
    ``IndexError``.
    """
    first = product[0]
    if isinstance(first, MatrixBase):
        return _sympy_tensor_product(*product)
    elif isinstance(first, numpy_ndarray):
        return _numpy_tensor_product(*product)
    elif isinstance(first, scipy_sparse_matrix):
        return _scipy_sparse_tensor_product(*product)
    raise TypeError(
        'Expected sympy/numpy/scipy.sparse matrix, got: %r' % (first,))
183
184
185def _numpy_eye(n):
186    """numpy version of complex eye."""
187    if not np:
188        raise ImportError
189    return np.matrix(np.eye(n, dtype='complex'))
190
191
192def _scipy_sparse_eye(n):
193    """scipy.sparse version of complex eye."""
194    if not sparse:
195        raise ImportError
196    return sparse.eye(n, n, dtype='complex')
197
198
def matrix_eye(n, **options):
    """Return an ``n`` x ``n`` identity matrix in the requested format.

    The ``format`` option selects the representation: ``'sympy'`` (default),
    ``'numpy'`` or ``'scipy.sparse'``. Raises ``NotImplementedError`` for any
    other format.
    """
    requested = options.get('format', 'sympy')
    # Lambdas defer name lookup until the chosen builder is actually called.
    builders = {
        'sympy': lambda: eye(n),
        'numpy': lambda: _numpy_eye(n),
        'scipy.sparse': lambda: _scipy_sparse_eye(n),
    }
    build = builders.get(requested)
    if build is None:
        raise NotImplementedError('Invalid format: %r' % requested)
    return build()
209
210
211def _numpy_zeros(m, n, **options):
212    """numpy version of zeros."""
213    dtype = options.get('dtype', 'float64')
214    if not np:
215        raise ImportError
216    return np.zeros((m, n), dtype=dtype)
217
218
219def _scipy_sparse_zeros(m, n, **options):
220    """scipy.sparse version of zeros."""
221    spmatrix = options.get('spmatrix', 'csr')
222    dtype = options.get('dtype', 'float64')
223    if not sparse:
224        raise ImportError
225    if spmatrix == 'lil':
226        return sparse.lil_matrix((m, n), dtype=dtype)
227    elif spmatrix == 'csr':
228        return sparse.csr_matrix((m, n), dtype=dtype)
229
230
def matrix_zeros(m, n, **options):
    """Get a zeros matrix for a given format.

    The ``format`` option selects the representation: ``'sympy'`` (default),
    ``'numpy'`` or ``'scipy.sparse'``. For the numpy and scipy.sparse
    formats, all options are forwarded to the format-specific helper.

    Raises ``NotImplementedError`` for any other format.
    """
    format = options.get('format', 'sympy')
    if format == 'sympy':
        return zeros(m, n)
    elif format == 'numpy':
        return _numpy_zeros(m, n, **options)
    elif format == 'scipy.sparse':
        return _scipy_sparse_zeros(m, n, **options)
    raise NotImplementedError('Invalid format: %r' % (format,))
241
242
243def _numpy_matrix_to_zero(e):
244    """Convert a numpy zero matrix to the zero scalar."""
245    if not np:
246        raise ImportError
247    test = np.zeros_like(e)
248    if np.allclose(e, test):
249        return 0.0
250    else:
251        return e
252
253
254def _scipy_sparse_matrix_to_zero(e):
255    """Convert a scipy.sparse zero matrix to the zero scalar."""
256    if not np:
257        raise ImportError
258    edense = e.todense()
259    test = np.zeros_like(edense)
260    if np.allclose(edense, test):
261        return 0.0
262    else:
263        return e
264
265
def matrix_to_zero(e):
    """Replace a zero matrix by the scalar zero; return other input as-is.

    A sympy matrix equal to ``zeros(*e.shape)`` becomes ``Integer(0)``.
    numpy arrays and scipy.sparse matrices are delegated to
    ``_numpy_matrix_to_zero`` / ``_scipy_sparse_matrix_to_zero``, which
    return ``0.0`` for matrices close to zero. Any other input is returned
    unchanged.
    """
    if isinstance(e, MatrixBase):
        return Integer(0) if zeros(*e.shape) == e else e
    if isinstance(e, numpy_ndarray):
        return _numpy_matrix_to_zero(e)
    if isinstance(e, scipy_sparse_matrix):
        return _scipy_sparse_matrix_to_zero(e)
    return e
276