1from sympy import Basic
2from sympy import S
3from sympy.core.expr import Expr
4from sympy.core.numbers import Integer
5from sympy.core.sympify import sympify
6from sympy.core.kind import Kind, NumberKind, UndefinedKind
7from sympy.core.compatibility import SYMPY_INTS
8from sympy.printing.defaults import Printable
9
10import itertools
11from collections.abc import Iterable
12
13
14class ArrayKind(Kind):
15    """
16    Kind for N-dimensional array in SymPy.
17
18    This kind represents the multidimensional array that algebraic
19    operations are defined. Basic class for this kind is ``NDimArray``,
20    but any expression representing the array can have this.
21
22    Parameters
23    ==========
24
25    element_kind : Kind
26        Kind of the element. Default is :obj:NumberKind `<sympy.core.kind.NumberKind>`,
27        which means that the array contains only numbers.
28
29    Examples
30    ========
31
32    Any instance of array class has ``ArrayKind``.
33
34    >>> from sympy import NDimArray
35    >>> NDimArray([1,2,3]).kind
36    ArrayKind(NumberKind)
37
38    Although expressions representing an array may be not instance of
39    array class, it will have ``ArrayKind`` as well.
40
41    >>> from sympy import Integral
42    >>> from sympy.tensor.array import NDimArray
43    >>> from sympy.abc import x
44    >>> intA = Integral(NDimArray([1,2,3]), x)
45    >>> isinstance(intA, NDimArray)
46    False
47    >>> intA.kind
48    ArrayKind(NumberKind)
49
50    Use ``isinstance()`` to check for ``ArrayKind` without specifying
51    the element kind. Use ``is`` with specifying the element kind.
52
53    >>> from sympy.tensor.array import ArrayKind
54    >>> from sympy.core.kind import NumberKind
55    >>> boolA = NDimArray([True, False])
56    >>> isinstance(boolA.kind, ArrayKind)
57    True
58    >>> boolA.kind is ArrayKind(NumberKind)
59    False
60
61    See Also
62    ========
63
64    shape : Function to return the shape of objects with ``MatrixKind``.
65
66    """
67    def __new__(cls, element_kind=NumberKind):
68        obj = super().__new__(cls, element_kind)
69        obj.element_kind = element_kind
70        return obj
71
72    def __repr__(self):
73        return "ArrayKind(%s)" % self.element_kind
74
75
76class NDimArray(Printable):
77    """
78
79    Examples
80    ========
81
82    Create an N-dim array of zeros:
83
84    >>> from sympy import MutableDenseNDimArray
85    >>> a = MutableDenseNDimArray.zeros(2, 3, 4)
86    >>> a
87    [[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]]
88
89    Create an N-dim array from a list;
90
91    >>> a = MutableDenseNDimArray([[2, 3], [4, 5]])
92    >>> a
93    [[2, 3], [4, 5]]
94
95    >>> b = MutableDenseNDimArray([[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]])
96    >>> b
97    [[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]]
98
99    Create an N-dim array from a flat list with dimension shape:
100
101    >>> a = MutableDenseNDimArray([1, 2, 3, 4, 5, 6], (2, 3))
102    >>> a
103    [[1, 2, 3], [4, 5, 6]]
104
105    Create an N-dim array from a matrix:
106
107    >>> from sympy import Matrix
108    >>> a = Matrix([[1,2],[3,4]])
109    >>> a
110    Matrix([
111    [1, 2],
112    [3, 4]])
113    >>> b = MutableDenseNDimArray(a)
114    >>> b
115    [[1, 2], [3, 4]]
116
117    Arithmetic operations on N-dim arrays
118
119    >>> a = MutableDenseNDimArray([1, 1, 1, 1], (2, 2))
120    >>> b = MutableDenseNDimArray([4, 4, 4, 4], (2, 2))
121    >>> c = a + b
122    >>> c
123    [[5, 5], [5, 5]]
124    >>> a - b
125    [[-3, -3], [-3, -3]]
126
127    """
128
129    _diff_wrt = True
130    is_scalar = False
131
132    def __new__(cls, iterable, shape=None, **kwargs):
133        from sympy.tensor.array import ImmutableDenseNDimArray
134        return ImmutableDenseNDimArray(iterable, shape, **kwargs)
135
136    @property
137    def kind(self):
138        elem_kinds = set(e.kind for e in self._array)
139        if len(elem_kinds) == 1:
140            elemkind, = elem_kinds
141        else:
142            elemkind = UndefinedKind
143        return ArrayKind(elemkind)
144
145    def _parse_index(self, index):
146        if isinstance(index, (SYMPY_INTS, Integer)):
147            raise ValueError("Only a tuple index is accepted")
148
149        if self._loop_size == 0:
150            raise ValueError("Index not valide with an empty array")
151
152        if len(index) != self._rank:
153            raise ValueError('Wrong number of array axes')
154
155        real_index = 0
156        # check if input index can exist in current indexing
157        for i in range(self._rank):
158            if (index[i] >= self.shape[i]) or (index[i] < -self.shape[i]):
159                raise ValueError('Index ' + str(index) + ' out of border')
160            if index[i] < 0:
161                real_index += 1
162            real_index = real_index*self.shape[i] + index[i]
163
164        return real_index
165
166    def _get_tuple_index(self, integer_index):
167        index = []
168        for i, sh in enumerate(reversed(self.shape)):
169            index.append(integer_index % sh)
170            integer_index //= sh
171        index.reverse()
172        return tuple(index)
173
174    def _check_symbolic_index(self, index):
175        # Check if any index is symbolic:
176        tuple_index = (index if isinstance(index, tuple) else (index,))
177        if any([(isinstance(i, Expr) and (not i.is_number)) for i in tuple_index]):
178            for i, nth_dim in zip(tuple_index, self.shape):
179                if ((i < 0) == True) or ((i >= nth_dim) == True):
180                    raise ValueError("index out of range")
181            from sympy.tensor import Indexed
182            return Indexed(self, *tuple_index)
183        return None
184
185    def _setter_iterable_check(self, value):
186        from sympy.matrices.matrices import MatrixBase
187        if isinstance(value, (Iterable, MatrixBase, NDimArray)):
188            raise NotImplementedError
189
190    @classmethod
191    def _scan_iterable_shape(cls, iterable):
192        def f(pointer):
193            if not isinstance(pointer, Iterable):
194                return [pointer], ()
195
196            result = []
197            elems, shapes = zip(*[f(i) for i in pointer])
198            if len(set(shapes)) != 1:
199                raise ValueError("could not determine shape unambiguously")
200            for i in elems:
201                result.extend(i)
202            return result, (len(shapes),)+shapes[0]
203
204        return f(iterable)
205
206    @classmethod
207    def _handle_ndarray_creation_inputs(cls, iterable=None, shape=None, **kwargs):
208        from sympy.matrices.matrices import MatrixBase
209        from sympy.tensor.array import SparseNDimArray
210        from sympy import Dict, Tuple
211
212        if shape is None:
213            if iterable is None:
214                shape = ()
215                iterable = ()
216            # Construction of a sparse array from a sparse array
217            elif isinstance(iterable, SparseNDimArray):
218                return iterable._shape, iterable._sparse_array
219
220            # Construct N-dim array from another N-dim array:
221            elif isinstance(iterable, NDimArray):
222                shape = iterable.shape
223
224            # Construct N-dim array from an iterable (numpy arrays included):
225            elif isinstance(iterable, Iterable):
226                iterable, shape = cls._scan_iterable_shape(iterable)
227
228            # Construct N-dim array from a Matrix:
229            elif isinstance(iterable, MatrixBase):
230                shape = iterable.shape
231
232            else:
233                shape = ()
234                iterable = (iterable,)
235
236        if isinstance(iterable, (Dict, dict)) and shape is not None:
237            new_dict = iterable.copy()
238            for k, v in new_dict.items():
239                if isinstance(k, (tuple, Tuple)):
240                    new_key = 0
241                    for i, idx in enumerate(k):
242                        new_key = new_key * shape[i] + idx
243                    iterable[new_key] = iterable[k]
244                    del iterable[k]
245
246        if isinstance(shape, (SYMPY_INTS, Integer)):
247            shape = (shape,)
248
249        if any([not isinstance(dim, (SYMPY_INTS, Integer)) for dim in shape]):
250            raise TypeError("Shape should contain integers only.")
251
252        return tuple(shape), iterable
253
254    def __len__(self):
255        """Overload common function len(). Returns number of elements in array.
256
257        Examples
258        ========
259
260        >>> from sympy import MutableDenseNDimArray
261        >>> a = MutableDenseNDimArray.zeros(3, 3)
262        >>> a
263        [[0, 0, 0], [0, 0, 0], [0, 0, 0]]
264        >>> len(a)
265        9
266
267        """
268        return self._loop_size
269
270    @property
271    def shape(self):
272        """
273        Returns array shape (dimension).
274
275        Examples
276        ========
277
278        >>> from sympy import MutableDenseNDimArray
279        >>> a = MutableDenseNDimArray.zeros(3, 3)
280        >>> a.shape
281        (3, 3)
282
283        """
284        return self._shape
285
286    def rank(self):
287        """
288        Returns rank of array.
289
290        Examples
291        ========
292
293        >>> from sympy import MutableDenseNDimArray
294        >>> a = MutableDenseNDimArray.zeros(3,4,5,6,3)
295        >>> a.rank()
296        5
297
298        """
299        return self._rank
300
301    def diff(self, *args, **kwargs):
302        """
303        Calculate the derivative of each element in the array.
304
305        Examples
306        ========
307
308        >>> from sympy import ImmutableDenseNDimArray
309        >>> from sympy.abc import x, y
310        >>> M = ImmutableDenseNDimArray([[x, y], [1, x*y]])
311        >>> M.diff(x)
312        [[1, 0], [0, y]]
313
314        """
315        from sympy.tensor.array.array_derivatives import ArrayDerivative
316        kwargs.setdefault('evaluate', True)
317        return ArrayDerivative(self.as_immutable(), *args, **kwargs)
318
319    def _eval_derivative(self, base):
320        # Types are (base: scalar, self: array)
321        return self.applyfunc(lambda x: base.diff(x))
322
323    def _eval_derivative_n_times(self, s, n):
324        return Basic._eval_derivative_n_times(self, s, n)
325
326    def applyfunc(self, f):
327        """Apply a function to each element of the N-dim array.
328
329        Examples
330        ========
331
332        >>> from sympy import ImmutableDenseNDimArray
333        >>> m = ImmutableDenseNDimArray([i*2+j for i in range(2) for j in range(2)], (2, 2))
334        >>> m
335        [[0, 1], [2, 3]]
336        >>> m.applyfunc(lambda i: 2*i)
337        [[0, 2], [4, 6]]
338        """
339        from sympy.tensor.array import SparseNDimArray
340        from sympy.tensor.array.arrayop import Flatten
341
342        if isinstance(self, SparseNDimArray) and f(S.Zero) == 0:
343            return type(self)({k: f(v) for k, v in self._sparse_array.items() if f(v) != 0}, self.shape)
344
345        return type(self)(map(f, Flatten(self)), self.shape)
346
347    def _sympystr(self, printer):
348        def f(sh, shape_left, i, j):
349            if len(shape_left) == 1:
350                return "["+", ".join([printer._print(self[self._get_tuple_index(e)]) for e in range(i, j)])+"]"
351
352            sh //= shape_left[0]
353            return "[" + ", ".join([f(sh, shape_left[1:], i+e*sh, i+(e+1)*sh) for e in range(shape_left[0])]) + "]" # + "\n"*len(shape_left)
354
355        if self.rank() == 0:
356            return printer._print(self[()])
357
358        return f(self._loop_size, self.shape, 0, self._loop_size)
359
360    def tolist(self):
361        """
362        Converting MutableDenseNDimArray to one-dim list
363
364        Examples
365        ========
366
367        >>> from sympy import MutableDenseNDimArray
368        >>> a = MutableDenseNDimArray([1, 2, 3, 4], (2, 2))
369        >>> a
370        [[1, 2], [3, 4]]
371        >>> b = a.tolist()
372        >>> b
373        [[1, 2], [3, 4]]
374        """
375
376        def f(sh, shape_left, i, j):
377            if len(shape_left) == 1:
378                return [self[self._get_tuple_index(e)] for e in range(i, j)]
379            result = []
380            sh //= shape_left[0]
381            for e in range(shape_left[0]):
382                result.append(f(sh, shape_left[1:], i+e*sh, i+(e+1)*sh))
383            return result
384
385        return f(self._loop_size, self.shape, 0, self._loop_size)
386
387    def __add__(self, other):
388        from sympy.tensor.array.arrayop import Flatten
389
390        if not isinstance(other, NDimArray):
391            return NotImplemented
392
393        if self.shape != other.shape:
394            raise ValueError("array shape mismatch")
395        result_list = [i+j for i,j in zip(Flatten(self), Flatten(other))]
396
397        return type(self)(result_list, self.shape)
398
399    def __sub__(self, other):
400        from sympy.tensor.array.arrayop import Flatten
401
402        if not isinstance(other, NDimArray):
403            return NotImplemented
404
405        if self.shape != other.shape:
406            raise ValueError("array shape mismatch")
407        result_list = [i-j for i,j in zip(Flatten(self), Flatten(other))]
408
409        return type(self)(result_list, self.shape)
410
411    def __mul__(self, other):
412        from sympy.matrices.matrices import MatrixBase
413        from sympy.tensor.array import SparseNDimArray
414        from sympy.tensor.array.arrayop import Flatten
415
416        if isinstance(other, (Iterable, NDimArray, MatrixBase)):
417            raise ValueError("scalar expected, use tensorproduct(...) for tensorial product")
418
419        other = sympify(other)
420        if isinstance(self, SparseNDimArray):
421            if other.is_zero:
422                return type(self)({}, self.shape)
423            return type(self)({k: other*v for (k, v) in self._sparse_array.items()}, self.shape)
424
425        result_list = [i*other for i in Flatten(self)]
426        return type(self)(result_list, self.shape)
427
428    def __rmul__(self, other):
429        from sympy.matrices.matrices import MatrixBase
430        from sympy.tensor.array import SparseNDimArray
431        from sympy.tensor.array.arrayop import Flatten
432
433        if isinstance(other, (Iterable, NDimArray, MatrixBase)):
434            raise ValueError("scalar expected, use tensorproduct(...) for tensorial product")
435
436        other = sympify(other)
437        if isinstance(self, SparseNDimArray):
438            if other.is_zero:
439                return type(self)({}, self.shape)
440            return type(self)({k: other*v for (k, v) in self._sparse_array.items()}, self.shape)
441
442        result_list = [other*i for i in Flatten(self)]
443        return type(self)(result_list, self.shape)
444
445    def __truediv__(self, other):
446        from sympy.matrices.matrices import MatrixBase
447        from sympy.tensor.array import SparseNDimArray
448        from sympy.tensor.array.arrayop import Flatten
449
450        if isinstance(other, (Iterable, NDimArray, MatrixBase)):
451            raise ValueError("scalar expected")
452
453        other = sympify(other)
454        if isinstance(self, SparseNDimArray) and other != S.Zero:
455            return type(self)({k: v/other for (k, v) in self._sparse_array.items()}, self.shape)
456
457        result_list = [i/other for i in Flatten(self)]
458        return type(self)(result_list, self.shape)
459
460    def __rtruediv__(self, other):
461        raise NotImplementedError('unsupported operation on NDimArray')
462
463    def __neg__(self):
464        from sympy.tensor.array import SparseNDimArray
465        from sympy.tensor.array.arrayop import Flatten
466
467        if isinstance(self, SparseNDimArray):
468            return type(self)({k: -v for (k, v) in self._sparse_array.items()}, self.shape)
469
470        result_list = [-i for i in Flatten(self)]
471        return type(self)(result_list, self.shape)
472
473    def __iter__(self):
474        def iterator():
475            if self._shape:
476                for i in range(self._shape[0]):
477                    yield self[i]
478            else:
479                yield self[()]
480
481        return iterator()
482
483    def __eq__(self, other):
484        """
485        NDimArray instances can be compared to each other.
486        Instances equal if they have same shape and data.
487
488        Examples
489        ========
490
491        >>> from sympy import MutableDenseNDimArray
492        >>> a = MutableDenseNDimArray.zeros(2, 3)
493        >>> b = MutableDenseNDimArray.zeros(2, 3)
494        >>> a == b
495        True
496        >>> c = a.reshape(3, 2)
497        >>> c == b
498        False
499        >>> a[0,0] = 1
500        >>> b[0,0] = 2
501        >>> a == b
502        False
503        """
504        from sympy.tensor.array import SparseNDimArray
505        if not isinstance(other, NDimArray):
506            return False
507
508        if not self.shape == other.shape:
509            return False
510
511        if isinstance(self, SparseNDimArray) and isinstance(other, SparseNDimArray):
512            return dict(self._sparse_array) == dict(other._sparse_array)
513
514        return list(self) == list(other)
515
516    def __ne__(self, other):
517        return not self == other
518
519    def _eval_transpose(self):
520        if self.rank() != 2:
521            raise ValueError("array rank not 2")
522        from .arrayop import permutedims
523        return permutedims(self, (1, 0))
524
525    def transpose(self):
526        return self._eval_transpose()
527
528    def _eval_conjugate(self):
529        from sympy.tensor.array.arrayop import Flatten
530
531        return self.func([i.conjugate() for i in Flatten(self)], self.shape)
532
533    def conjugate(self):
534        return self._eval_conjugate()
535
536    def _eval_adjoint(self):
537        return self.transpose().conjugate()
538
539    def adjoint(self):
540        return self._eval_adjoint()
541
542    def _slice_expand(self, s, dim):
543        if not isinstance(s, slice):
544                return (s,)
545        start, stop, step = s.indices(dim)
546        return [start + i*step for i in range((stop-start)//step)]
547
548    def _get_slice_data_for_array_access(self, index):
549        sl_factors = [self._slice_expand(i, dim) for (i, dim) in zip(index, self.shape)]
550        eindices = itertools.product(*sl_factors)
551        return sl_factors, eindices
552
553    def _get_slice_data_for_array_assignment(self, index, value):
554        if not isinstance(value, NDimArray):
555            value = type(self)(value)
556        sl_factors, eindices = self._get_slice_data_for_array_access(index)
557        slice_offsets = [min(i) if isinstance(i, list) else None for i in sl_factors]
558        # TODO: add checks for dimensions for `value`?
559        return value, eindices, slice_offsets
560
561    @classmethod
562    def _check_special_bounds(cls, flat_list, shape):
563        if shape == () and len(flat_list) != 1:
564            raise ValueError("arrays without shape need one scalar value")
565        if shape == (0,) and len(flat_list) > 0:
566            raise ValueError("if array shape is (0,) there cannot be elements")
567
568    def _check_index_for_getitem(self, index):
569        if isinstance(index, (SYMPY_INTS, Integer, slice)):
570            index = (index, )
571
572        if len(index) < self.rank():
573            index = tuple([i for i in index] + \
574                          [slice(None) for i in range(len(index), self.rank())])
575
576        if len(index) > self.rank():
577            raise ValueError('Dimension of index greater than rank of array')
578
579        return index
580
581
582class ImmutableNDimArray(NDimArray, Basic):
583    _op_priority = 11.0
584
585    def __hash__(self):
586        return Basic.__hash__(self)
587
588    def as_immutable(self):
589        return self
590
591    def as_mutable(self):
592        raise NotImplementedError("abstract method")
593