1import operator
2from functools import reduce
3import itertools
4from itertools import accumulate
5from typing import Optional, List, Dict
6
7from sympy import Expr, ImmutableDenseNDimArray, S, Symbol, Integer, ZeroMatrix, Basic, tensorproduct, Add, permutedims, \
8    Tuple, tensordiagonal, Lambda, Dummy, Function, MatrixExpr, NDimArray, Indexed, IndexedBase, default_sort_key, \
9    tensorcontraction, diagonalize_vector, Mul
10from sympy.matrices.expressions.matexpr import MatrixElement
11from sympy.tensor.array.expressions.utils import _apply_recursively_over_nested_lists, _sort_contraction_indices, \
12    _get_mapping_from_subranks, _build_push_indices_up_func_transformation, _get_contraction_links, \
13    _build_push_indices_down_func_transformation
14from sympy.combinatorics import Permutation
15from sympy.combinatorics.permutations import _af_invert
16from sympy.core.sympify import _sympify
17
18
class _ArrayExpr(Expr):
    """Private base class shared by the symbolic array expressions in this module."""
21
22
class ArraySymbol(_ArrayExpr):
    """
    Symbol representing an array expression

    The first argument is the name (a ``str`` is converted to a ``Symbol``);
    the remaining arguments are the dimensions, each sympified.
    """

    def __new__(cls, symbol, *shape):
        name = Symbol(symbol) if isinstance(symbol, str) else symbol
        dims = [_sympify(dim) for dim in shape]
        return Expr.__new__(cls, name, *dims)

    @property
    def name(self):
        """The symbol naming this array."""
        return self._args[0]

    @property
    def shape(self):
        """Tuple of the array dimensions."""
        return self._args[1:]

    def __getitem__(self, item):
        return ArrayElement(self, item)

    def as_explicit(self):
        """
        Return an ``ImmutableDenseNDimArray`` whose entries are the
        ``ArrayElement`` objects of this symbol.

        Raises ``ValueError`` when any dimension is not an integer.
        """
        dims = self.shape
        if not all(isinstance(dim, (int, Integer)) for dim in dims):
            raise ValueError("cannot express explicit array with symbolic shape")
        index_ranges = [range(dim) for dim in dims]
        elements = [self[idx] for idx in itertools.product(*index_ranges)]
        return ImmutableDenseNDimArray(elements).reshape(*dims)
52
53
class ArrayElement(_ArrayExpr):
    """
    An element of an array.

    ``name`` is the array (a ``str`` is converted to a ``Symbol``) and
    ``indices`` is either a single index or a sequence of indices.  The
    indices are stored as a sympy ``Tuple``.

    Raises
    ======

    IndexError
        If ``name`` has a known shape and the number of indices differs
        from its rank.
    ValueError
        If an index is provably out of bounds or provably negative.
    """
    def __new__(cls, name, indices):
        if isinstance(name, str):
            name = Symbol(name)
        name = _sympify(name)
        # Accept a bare index (e.g. ``A[0]`` on a rank-1 array); previously a
        # scalar index was sympified to an Integer and ``zip`` failed on it.
        if not isinstance(indices, (tuple, list, Tuple)):
            indices = (indices,)
        # Converting to a tuple first guarantees a sympy ``Tuple`` arg even
        # when a Python list was passed.
        indices = _sympify(tuple(indices))
        shape = getattr(name, "shape", None)
        if shape is not None:
            if len(indices) != len(shape):
                raise IndexError("number of indices does not match shape of the array")
            # ``== True`` so that undecidable symbolic comparisons pass.
            if any((i >= s) == True for i, s in zip(indices, shape)):
                raise ValueError("shape is out of bounds")
        if any((i < 0) == True for i in indices):
            raise ValueError("shape contains negative values")
        obj = Expr.__new__(cls, name, indices)
        return obj

    @property
    def name(self):
        """The array this element belongs to."""
        return self._args[0]

    @property
    def indices(self):
        """Tuple of indices selecting this element."""
        return self._args[1]
78
79
class ZeroArray(_ArrayExpr):
    """
    Symbolic array of zeros. Equivalent to ``ZeroMatrix`` for matrices.

    Calling it with no dimensions returns the scalar ``S.Zero``.
    """

    def __new__(cls, *shape):
        # A rank-0 array of zeros is just the scalar zero.
        if not shape:
            return S.Zero
        return Expr.__new__(cls, *(_sympify(dim) for dim in shape))

    @property
    def shape(self):
        """Tuple of the (sympified) dimensions."""
        return self._args

    def as_explicit(self):
        """Return an explicit dense array of zeros; requires integer dimensions."""
        dims = self.shape
        if not all(dim.is_Integer for dim in dims):
            raise ValueError("Cannot return explicit form for symbolic shape.")
        return ImmutableDenseNDimArray.zeros(*dims)
100
101
class OneArray(_ArrayExpr):
    """
    Symbolic array of ones.

    Calling it with no dimensions returns the scalar ``S.One``.
    """

    def __new__(cls, *shape):
        # A rank-0 array of ones is just the scalar one.
        if not shape:
            return S.One
        return Expr.__new__(cls, *(_sympify(dim) for dim in shape))

    @property
    def shape(self):
        """Tuple of the (sympified) dimensions."""
        return self._args

    def as_explicit(self):
        """Return an explicit dense array of ones; requires integer dimensions."""
        dims = self.shape
        if not all(dim.is_Integer for dim in dims):
            raise ValueError("Cannot return explicit form for symbolic shape.")
        total = reduce(operator.mul, dims)
        ones = [S.One for _ in range(total)]
        return ImmutableDenseNDimArray(ones).reshape(*dims)
122
123
class _CodegenArrayAbstract(Basic):
    """
    Common base of the array-expression operators in this module.

    Subclasses store the ranks of their operands in ``_subranks`` and their
    own shape (or ``None`` when unknown) in ``_shape``.
    """

    @property
    def subranks(self):
        """
        Returns the ranks of the objects in the uppermost tensor product inside
        the current object.  In case no tensor products are contained, return
        the atomic ranks.

        Examples
        ========

        >>> from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct, ArrayContraction
        >>> from sympy import MatrixSymbol
        >>> M = MatrixSymbol("M", 3, 3)
        >>> N = MatrixSymbol("N", 3, 3)
        >>> P = MatrixSymbol("P", 3, 3)

        Important: do not confuse the rank of the matrix with the rank of an array.

        >>> tp = ArrayTensorProduct(M, N, P)
        >>> tp.subranks
        [2, 2, 2]

        >>> co = ArrayContraction(tp, (1, 2), (3, 4))
        >>> co.subranks
        [2, 2, 2]
        """
        # Hand out a shallow copy so callers cannot mutate the stored ranks.
        stored = self._subranks
        return stored[:]

    def subrank(self):
        """
        Total of all entries of ``subranks``.
        """
        ranks = self.subranks
        return sum(ranks)

    @property
    def shape(self):
        """Shape of the expression, or ``None`` if it is not known."""
        return self._shape
163
164
class ArrayTensorProduct(_CodegenArrayAbstract):
    r"""
    Class to represent the tensor product of array-like objects.

    The constructor canonicalizes its arguments: nested tensor products are
    flattened, nested ``PermuteDims``, ``ArrayContraction`` and
    ``ArrayDiagonal`` arguments are lifted above the product, a single
    argument is returned unchanged, and any zero argument turns the whole
    product into a ``ZeroArray``.
    """

    def __new__(cls, *args):
        args = [_sympify(arg) for arg in args]
        args = cls._flatten(args)
        ranks = [get_rank(arg) for arg in args]

        # Check if there are nested permutation and lift them up:
        # each inner permutation's cycles are shifted by the number of axes
        # that precede its argument, and the inner PermuteDims is replaced by
        # its wrapped expression.
        permutation_cycles = []
        for i, arg in enumerate(args):
            if not isinstance(arg, PermuteDims):
                continue
            permutation_cycles.extend([[k + sum(ranks[:i]) for k in j] for j in arg.permutation.cyclic_form])
            args[i] = arg.expr
        if permutation_cycles:
            # ``Permutation(sum(ranks)-1)`` is the identity of full size, so the
            # product has the size of the whole tensor product.
            return PermuteDims(ArrayTensorProduct(*args), Permutation(sum(ranks)-1)*Permutation(permutation_cycles))

        if len(args) == 1:
            return args[0]

        # If any object is a ZeroArray, return a ZeroArray:
        if any(isinstance(arg, (ZeroArray, ZeroMatrix)) for arg in args):
            shapes = reduce(operator.add, [get_shape(i) for i in args], ())
            return ZeroArray(*shapes)

        # If there are contraction objects inside, transform the whole
        # expression into `ArrayContraction`:
        contractions = {i: arg for i, arg in enumerate(args) if isinstance(arg, ArrayContraction)}
        if contractions:
            # Offsets are computed on the *uncontracted* ranks, because the
            # contraction indices refer to axes of each contraction's inner expr.
            ranks = [_get_subrank(arg) if isinstance(arg, ArrayContraction) else get_rank(arg) for arg in args]
            cumulative_ranks = list(accumulate([0] + ranks))[:-1]
            tp = cls(*[arg.expr if isinstance(arg, ArrayContraction) else arg for arg in args])
            contraction_indices = [tuple(cumulative_ranks[i] + k for k in j) for i, arg in contractions.items() for j in arg.contraction_indices]
            return ArrayContraction(tp, *contraction_indices)

        # If there are diagonal objects inside, pull a single ArrayDiagonal
        # above the product and wrap it in a PermuteDims.
        diagonals = {i: arg for i, arg in enumerate(args) if isinstance(arg, ArrayDiagonal)}
        if diagonals:
            # ``permutation`` collects, in argument order, the axis positions of
            # each argument's non-diagonal axes; ``last_perm`` collects the
            # positions of the diagonal axes of the ArrayDiagonal arguments,
            # which are appended at the end.
            permutation = []
            last_perm = []
            ranks = [get_rank(arg) for arg in args]
            cumulative_ranks = list(accumulate([0] + ranks))[:-1]
            for i, arg in enumerate(args):
                if isinstance(arg, ArrayDiagonal):
                    # i1: number of non-diagonal axes, i2: number of diagonal axes.
                    i1 = get_rank(arg) - len(arg.diagonal_indices)
                    i2 = len(arg.diagonal_indices)
                    permutation.extend([cumulative_ranks[i] + j for j in range(i1)])
                    last_perm.extend([cumulative_ranks[i] + j for j in range(i1, i1 + i2)])
                else:
                    permutation.extend([cumulative_ranks[i] + j for j in range(get_rank(arg))])
            permutation.extend(last_perm)
            tp = cls(*[arg.expr if isinstance(arg, ArrayDiagonal) else arg for arg in args])
            # Diagonal indices refer to axes of each ArrayDiagonal's inner expr,
            # so offsets are computed on the un-diagonalized ranks.
            ranks2 = [_get_subrank(arg) if isinstance(arg, ArrayDiagonal) else get_rank(arg) for arg in args]
            cumulative_ranks2 = list(accumulate([0] + ranks2))[:-1]
            diagonal_indices = [tuple(cumulative_ranks2[i] + k for k in j) for i, arg in diagonals.items() for j in arg.diagonal_indices]
            return PermuteDims(ArrayDiagonal(tp, *diagonal_indices), permutation)

        obj = Basic.__new__(cls, *args)
        obj._subranks = ranks
        shapes = [get_shape(i) for i in args]

        # The product's shape is unknown if any factor's shape is unknown.
        if any(i is None for i in shapes):
            obj._shape = None
        else:
            obj._shape = tuple(j for i in shapes for j in i)
        return obj

    @classmethod
    def _flatten(cls, args):
        # Splice the args of nested tensor products into one flat list.
        args = [i for arg in args for i in (arg.args if isinstance(arg, cls) else [arg])]
        return args

    def as_explicit(self):
        # Arguments without ``as_explicit`` are passed to ``tensorproduct`` as-is.
        return tensorproduct(*[arg.as_explicit() if hasattr(arg, "as_explicit") else arg for arg in self.args])
241
242
class ArrayAdd(_CodegenArrayAbstract):
    r"""
    Class for elementwise array additions.

    All summands must have the same rank and, where known, the same shape;
    otherwise ``ValueError`` is raised.  Nested ``ArrayAdd`` objects are
    flattened and ``ZeroArray``/``ZeroMatrix`` summands are dropped.  If a
    single summand remains it is returned unchanged; if none remain a
    ``ZeroArray`` of the common shape is returned.
    """

    def __new__(cls, *args):
        args = [_sympify(arg) for arg in args]
        ranks = [get_rank(arg) for arg in args]
        ranks = list(set(ranks))
        if len(ranks) != 1:
            raise ValueError("summing arrays of different ranks")
        # ``get_shape`` is used, as elsewhere in this module, so that summands
        # without a ``shape`` attribute do not raise AttributeError.  A
        # ``None`` entry means that summand's shape is unknown.
        shapes = [get_shape(arg) for arg in args]
        if len({i for i in shapes if i is not None}) > 1:
            raise ValueError("mismatching shapes in addition")

        # Flatten:
        args = cls._flatten_args(args)

        args = [arg for arg in args if not isinstance(arg, (ZeroArray, ZeroMatrix))]
        if len(args) == 0:
            # Every summand was zero.  Building a ZeroArray needs a concrete
            # shape, so an unknown shape anywhere cannot be handled.
            if any(i is None for i in shapes):
                raise NotImplementedError("cannot handle addition of ZeroMatrix/ZeroArray and undefined shape object")
            return ZeroArray(*shapes[0])
        elif len(args) == 1:
            return args[0]

        obj = Basic.__new__(cls, *args)
        obj._subranks = ranks
        if any(i is None for i in shapes):
            obj._shape = None
        else:
            obj._shape = shapes[0]
        return obj

    @classmethod
    def _flatten_args(cls, args):
        """Splice the summands of nested ``ArrayAdd`` objects into one list."""
        new_args = []
        for arg in args:
            if isinstance(arg, ArrayAdd):
                new_args.extend(arg.args)
            else:
                new_args.append(arg)
        return new_args

    def as_explicit(self):
        return Add.fromiter([arg.as_explicit() for arg in self.args])
289
290
class PermuteDims(_CodegenArrayAbstract):
    r"""
    Class to represent permutation of axes of arrays.

    Examples
    ========

    >>> from sympy.tensor.array.expressions.array_expressions import PermuteDims
    >>> from sympy import MatrixSymbol
    >>> M = MatrixSymbol("M", 3, 3)
    >>> cg = PermuteDims(M, [1, 0])

    The object ``cg`` represents the transposition of ``M``, as the permutation
    ``[1, 0]`` will act on its indices by switching them:

    `M_{ij} \Rightarrow M_{ji}`

    This is evident when transforming back to matrix form:

    >>> from sympy.tensor.array.expressions.conv_array_to_matrix import convert_array_to_matrix
    >>> convert_array_to_matrix(cg)
    M.T

    >>> N = MatrixSymbol("N", 3, 2)
    >>> cg = PermuteDims(N, [1, 0])
    >>> cg.shape
    (2, 3)

    Permutations of tensor products are simplified in order to achieve a
    standard form:

    >>> from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct
    >>> M = MatrixSymbol("M", 4, 5)
    >>> tp = ArrayTensorProduct(M, N)
    >>> tp.shape
    (4, 5, 3, 2)
    >>> perm1 = PermuteDims(tp, [2, 3, 1, 0])

    The args ``(M, N)`` have been sorted and the permutation has been
    simplified, the expression is equivalent:

    >>> perm1.expr.args
    (N, M)
    >>> perm1.shape
    (3, 2, 5, 4)
    >>> perm1.permutation
    (2 3)

    The permutation in its array form has been simplified from
    ``[2, 3, 1, 0]`` to ``[0, 1, 3, 2]``, as the arguments of the tensor
    product `M` and `N` have been switched:

    >>> perm1.permutation.array_form
    [0, 1, 3, 2]

    We can nest a second permutation:

    >>> perm2 = PermuteDims(perm1, [1, 0, 2, 3])
    >>> perm2.shape
    (2, 3, 5, 4)
    >>> perm2.permutation.array_form
    [1, 0, 3, 2]
    """

    def __new__(cls, expr, permutation, nest_permutation=True):
        # NOTE: ``nest_permutation`` is accepted for compatibility but is not
        # read anywhere in this constructor.
        from sympy.combinatorics import Permutation
        expr = _sympify(expr)
        permutation = Permutation(permutation)
        permutation_size = permutation.size
        expr_rank = get_rank(expr)
        if permutation_size != expr_rank:
            raise ValueError("Permutation size must be the length of the shape of expr")
        # Collapse a nested PermuteDims into a single, composed permutation.
        if isinstance(expr, PermuteDims):
            subexpr = expr.expr
            subperm = expr.permutation
            permutation = permutation * subperm
            expr = subexpr
        # Canonicalize the argument order of a nested contraction / product,
        # adjusting the permutation accordingly.
        if isinstance(expr, ArrayContraction):
            expr, permutation = cls._handle_nested_contraction(expr, permutation)
        if isinstance(expr, ArrayTensorProduct):
            expr, permutation = cls._sort_components(expr, permutation)
        if isinstance(expr, (ZeroArray, ZeroMatrix)):
            return ZeroArray(*[expr.shape[i] for i in permutation.array_form])
        # Identity permutation: nothing to wrap.
        plist = permutation.array_form
        if plist == sorted(plist):
            return expr
        obj = Basic.__new__(cls, expr, permutation)
        obj._subranks = [get_rank(expr)]
        shape = expr.shape
        if shape is None:
            obj._shape = None
        else:
            # Axis i of the result has the length of axis permutation(i) of expr.
            obj._shape = tuple(shape[permutation(i)] for i in range(len(shape)))
        return obj

    @property
    def expr(self):
        """The expression whose axes are permuted."""
        return self.args[0]

    @property
    def permutation(self):
        """The ``Permutation`` acting on the axes of ``expr``."""
        return self.args[1]

    @classmethod
    def _sort_components(cls, expr, permutation):
        """
        Reorder the args of the tensor product ``expr`` into a canonical order
        determined by ``permutation`` and return the new product together with
        the correspondingly adjusted permutation.
        """
        # Get the permutation in its image-form:
        perm_image_form = _af_invert(permutation.array_form)
        args = list(expr.args)
        # Starting index global position for every arg:
        cumul = list(accumulate([0] + expr.subranks))
        # Split `perm_image_form` into a list of list corresponding to the indices
        # of every argument:
        perm_image_form_in_components = [perm_image_form[cumul[i]:cumul[i+1]] for i in range(len(args))]
        # Create an index, target-position-key array:
        ps = [(i, sorted(comp)) for i, comp in enumerate(perm_image_form_in_components)]
        # Sort the array according to the target-position-key:
        # In this way, we define a canonical way to sort the arguments according
        # to the permutation.
        ps.sort(key=lambda x: x[1])
        # Read the inverse-permutation (i.e. image-form) of the args:
        perm_args_image_form = [i[0] for i in ps]
        # Apply the args-permutation to the `args`:
        args_sorted = [args[i] for i in perm_args_image_form]
        # Apply the args-permutation to the array-form of the permutation of the axes (of `expr`):
        perm_image_form_sorted_args = [perm_image_form_in_components[i] for i in perm_args_image_form]
        new_permutation = Permutation(_af_invert([j for i in perm_image_form_sorted_args for j in i]))
        return ArrayTensorProduct(*args_sorted), new_permutation

    @classmethod
    def _handle_nested_contraction(cls, expr, permutation):
        """
        For a contraction of a tensor product, reorder the product's args in a
        canonical order derived from ``permutation`` and return the rebuilt
        contraction with the adjusted permutation.  Other inputs are returned
        unchanged.
        """
        if not isinstance(expr, ArrayContraction):
            return expr, permutation
        if not isinstance(expr.expr, ArrayTensorProduct):
            return expr, permutation
        args = expr.expr.args
        subranks = [get_rank(arg) for arg in expr.expr.args]

        contraction_indices = expr.contraction_indices
        contraction_indices_flat = [j for i in contraction_indices for j in i]
        cumul = list(accumulate([0] + subranks))

        # Spread the permutation in its array form across the args in the corresponding
        # tensor-product arguments with free indices:
        permutation_array_blocks_up = []
        image_form = _af_invert(permutation.array_form)
        counter = 0
        for i, e in enumerate(subranks):
            current = []
            for j in range(cumul[i], cumul[i+1]):
                # Contracted axes do not appear in the permutation.
                if j in contraction_indices_flat:
                    continue
                current.append(image_form[counter])
                counter += 1
            permutation_array_blocks_up.append(current)

        # Get the map of axis repositioning for every argument of tensor-product:
        index_blocks = [[j for j in range(cumul[i], cumul[i+1])] for i, e in enumerate(expr.subranks)]
        index_blocks_up = expr._push_indices_up(expr.contraction_indices, index_blocks)
        inverse_permutation = permutation**(-1)
        index_blocks_up_permuted = [[inverse_permutation(j) for j in i if j is not None] for i in index_blocks_up]

        # Sorting key is a list of tuple, first element is the index of `args`, second element of
        # the tuple is the sorting key to sort `args` of the tensor product:
        sorting_keys = list(enumerate(index_blocks_up_permuted))
        sorting_keys.sort(key=lambda x: x[1])

        # Now we can get the permutation acting on the args in its image-form:
        new_perm_image_form = [i[0] for i in sorting_keys]
        # Apply the args-level permutation to various elements:
        new_index_blocks = [index_blocks[i] for i in new_perm_image_form]
        new_index_perm_array_form = _af_invert([j for i in new_index_blocks for j in i])
        new_args = [args[i] for i in new_perm_image_form]
        new_contraction_indices = [tuple(new_index_perm_array_form[j] for j in i) for i in contraction_indices]
        new_expr = ArrayContraction(ArrayTensorProduct(*new_args), *new_contraction_indices)
        new_permutation = Permutation(_af_invert([j for i in [permutation_array_blocks_up[k] for k in new_perm_image_form] for j in i]))
        return new_expr, new_permutation

    @classmethod
    def _check_permutation_mapping(cls, expr, permutation):
        """
        Try to push parts of ``permutation`` that act entirely within one
        argument of the tensor product ``expr`` down into that argument, and
        swap argument positions where whole blocks map onto other arguments.
        Returns the new tensor product and the remaining permutation.
        """
        subranks = expr.subranks
        # index2arg[k] is the argument that owns global axis k.
        index2arg = [i for i, arg in enumerate(expr.args) for j in range(expr.subranks[i])]
        permuted_indices = [permutation(i) for i in range(expr.subrank())]
        new_args = list(expr.args)
        arg_candidate_index = index2arg[permuted_indices[0]]
        current_indices = []
        new_permutation = []
        inserted_arg_cand_indices = set([])
        for i, idx in enumerate(permuted_indices):
            if index2arg[idx] != arg_candidate_index:
                new_permutation.extend(current_indices)
                current_indices = []
                arg_candidate_index = index2arg[idx]
            current_indices.append(idx)
            arg_candidate_rank = subranks[arg_candidate_index]
            if len(current_indices) == arg_candidate_rank:
                new_permutation.extend(sorted(current_indices))
                local_current_indices = [j - min(current_indices) for j in current_indices]
                # NOTE(review): this indexes ``index2arg`` with the loop position
                # ``i`` rather than using ``arg_candidate_index``; confirm intended.
                i1 = index2arg[i]
                new_args[i1] = PermuteDims(new_args[i1], Permutation(local_current_indices))
                inserted_arg_cand_indices.add(arg_candidate_index)
                current_indices = []
        new_permutation.extend(current_indices)

        # TODO: swap args positions in order to simplify the expression:
        # TODO: this should be in a function
        args_positions = list(range(len(new_args)))
        # Get possible shifts:
        maps = {}
        cumulative_subranks = [0] + list(accumulate(subranks))
        for i in range(0, len(subranks)):
            s = set([index2arg[new_permutation[j]] for j in range(cumulative_subranks[i], cumulative_subranks[i+1])])
            if len(s) != 1:
                continue
            elem = next(iter(s))
            if i != elem:
                maps[i] = elem

        # Find cycles in the map:
        lines = []
        current_line = []
        while maps:
            if len(current_line) == 0:
                k, v = maps.popitem()
                current_line.append(k)
            else:
                k = current_line[-1]
                if k not in maps:
                    current_line = []
                    continue
                v = maps.pop(k)
            if v in current_line:
                lines.append(current_line)
                current_line = []
                continue
            current_line.append(v)
        for line in lines:
            for i, e in enumerate(line):
                args_positions[line[(i + 1) % len(line)]] = e

        # TODO: function in order to permute the args:
        permutation_blocks = [[new_permutation[cumulative_subranks[i] + j] for j in range(e)] for i, e in enumerate(subranks)]
        new_args = [new_args[i] for i in args_positions]
        new_permutation_blocks = [permutation_blocks[i] for i in args_positions]
        new_permutation2 = [j for i in new_permutation_blocks for j in i]
        return ArrayTensorProduct(*new_args), Permutation(new_permutation2)  # **(-1)

    @classmethod
    def _check_if_there_are_closed_cycles(cls, expr, permutation):
        """
        Move every cycle of ``permutation`` that lies entirely within the axis
        range of a single argument of the tensor product ``expr`` down into that
        argument as a ``PermuteDims``; return the new product and the
        permutation made of the remaining cycles.
        """
        args = list(expr.args)
        subranks = expr.subranks
        cyclic_form = permutation.cyclic_form
        cumulative_subranks = [0] + list(accumulate(subranks))
        cyclic_min = [min(i) for i in cyclic_form]
        cyclic_max = [max(i) for i in cyclic_form]
        cyclic_keep = []
        for i, cycle in enumerate(cyclic_form):
            flag = True
            for j in range(0, len(cumulative_subranks) - 1):
                if cyclic_min[i] >= cumulative_subranks[j] and cyclic_max[i] < cumulative_subranks[j+1]:
                    # Found a sinkable cycle.
                    args[j] = PermuteDims(args[j], Permutation([[k - cumulative_subranks[j] for k in cyclic_form[i]]]))
                    flag = False
                    break
            if flag:
                cyclic_keep.append(cyclic_form[i])
        return ArrayTensorProduct(*args), Permutation(cyclic_keep, size=permutation.size)

    def nest_permutation(self):
        r"""
        DEPRECATED.

        Return the result of ``_nest_permutation`` on this object's expression
        and permutation, or ``self`` when that returns ``None``.
        """
        ret = self._nest_permutation(self.expr, self.permutation)
        if ret is None:
            return self
        return ret

    @classmethod
    def _nest_permutation(cls, expr, permutation):
        # Push the permutation one level down into ``expr`` where possible;
        # returns ``None`` for expression types not handled here.
        if isinstance(expr, ArrayTensorProduct):
            return PermuteDims(*cls._check_if_there_are_closed_cycles(expr, permutation))
        elif isinstance(expr, ArrayContraction):
            # Invert tree hierarchy: put the contraction above.
            cycles = permutation.cyclic_form
            newcycles = ArrayContraction._convert_outer_indices_to_inner_indices(expr, *cycles)
            newpermutation = Permutation(newcycles)
            new_contr_indices = [tuple(newpermutation(j) for j in i) for i in expr.contraction_indices]
            return ArrayContraction(PermuteDims(expr.expr, newpermutation), *new_contr_indices)
        elif isinstance(expr, ArrayAdd):
            # Permutation distributes over addition.
            return ArrayAdd(*[PermuteDims(arg, permutation) for arg in expr.args])
        return None

    def as_explicit(self):
        return permutedims(self.expr.as_explicit(), self.permutation)
584
585
class ArrayDiagonal(_CodegenArrayAbstract):
    r"""
    Class to represent the diagonal operator.

    Explanation
    ===========

    In a 2-dimensional array it returns the diagonal, this looks like the
    operation:

    `A_{ij} \rightarrow A_{ii}`

    The diagonal over axes 1 and 2 (the second and third) of the tensor product
    of two 2-dimensional arrays `A \otimes B` is

    `\Big[ A_{ab} B_{cd} \Big]_{abcd} \rightarrow \Big[ A_{ai} B_{id} \Big]_{adi}`

    In this last example the array expression has been reduced from
    4-dimensional to 3-dimensional. Notice that no contraction has occurred,
    rather there is a new index `i` for the diagonal, contraction would have
    reduced the array to 2 dimensions.

    Notice that the diagonalized out dimensions are added as new dimensions at
    the end of the indices.
    """

    def __new__(cls, expr, *diagonal_indices):
        expr = _sympify(expr)
        # Normalize every group of diagonal axes to a sorted sympy ``Tuple``.
        diagonal_indices = [Tuple(*sorted(i)) for i in diagonal_indices]
        # The diagonal distributes over addition:
        if isinstance(expr, ArrayAdd):
            return ArrayAdd(*[ArrayDiagonal(arg, *diagonal_indices) for arg in expr.args])
        # Merge nested diagonals into a single ArrayDiagonal:
        if isinstance(expr, ArrayDiagonal):
            return cls._flatten(expr, *diagonal_indices)
        # Move the diagonal below a nested PermuteDims:
        if isinstance(expr, PermuteDims):
            return cls._handle_nested_permutedims_in_diag(expr, *diagonal_indices)
        shape = get_shape(expr)
        if shape is not None:
            cls._validate(expr, *diagonal_indices)
            # Get new shape:
            positions, shape = cls._get_positions_shape(shape, diagonal_indices)
        else:
            positions = None
        if len(diagonal_indices) == 0:
            return expr
        if isinstance(expr, (ZeroArray, ZeroMatrix)):
            return ZeroArray(*shape)
        obj = Basic.__new__(cls, expr, *diagonal_indices)
        obj._positions = positions
        obj._subranks = _get_subranks(expr)
        obj._shape = shape
        return obj

    @staticmethod
    def _validate(expr, *diagonal_indices):
        """
        Raise ``ValueError`` for a malformed group of diagonal axes: each group
        must contain at least two axes, all of the same dimension.
        """
        shape = get_shape(expr)
        for i in diagonal_indices:
            # Check the group size first so that empty or single-axis groups
            # get the accurate error message (an empty group would otherwise
            # be reported as having mismatched dimensions).
            if len(i) <= 1:
                raise ValueError("need at least two axes to diagonalize")
            if len({shape[j] for j in i}) != 1:
                raise ValueError("diagonalizing indices of different dimensions")

    @staticmethod
    def _remove_trivial_dimensions(shape, *diagonal_indices):
        """Drop the diagonal groups whose (common) dimension is 1."""
        return [tuple(j for j in i) for i in diagonal_indices if shape[i[0]] != 1]

    @property
    def expr(self):
        """The expression being diagonalized."""
        return self.args[0]

    @property
    def diagonal_indices(self):
        """Tuple of groups of axes of ``expr`` that are diagonalized together."""
        return self.args[1:]

    @staticmethod
    def _flatten(expr, *outer_diagonal_indices):
        """
        Merge ``outer_diagonal_indices`` (axes of the ArrayDiagonal ``expr``)
        with the inner diagonal indices of ``expr`` into one ArrayDiagonal over
        ``expr.expr``.
        """
        inner_diagonal_indices = expr.diagonal_indices
        all_inner = [j for i in inner_diagonal_indices for j in i]
        all_inner.sort()
        # TODO: add API for total rank and cumulative rank:
        total_rank = _get_subrank(expr)
        inner_rank = len(all_inner)
        outer_rank = total_rank - inner_rank
        # shifts[i]: how many inner diagonalized axes precede the i-th
        # non-diagonal axis, i.e. the offset to map it back onto expr.expr.
        shifts = [0 for i in range(outer_rank)]
        counter = 0
        pointer = 0
        for i in range(outer_rank):
            while pointer < inner_rank and counter >= all_inner[pointer]:
                counter += 1
                pointer += 1
            shifts[i] += pointer
            counter += 1
        outer_diagonal_indices = tuple(tuple(shifts[j] + j for j in i) for i in outer_diagonal_indices)
        diagonal_indices = inner_diagonal_indices + outer_diagonal_indices
        return ArrayDiagonal(expr.expr, *diagonal_indices)

    @classmethod
    def _handle_nested_permutedims_in_diag(cls, expr: PermuteDims, *diagonal_indices):
        """
        Rewrite ``ArrayDiagonal(PermuteDims(E, p), ...)`` as
        ``PermuteDims(ArrayDiagonal(E, ...), q)``.
        """
        back_diagonal_indices = [[expr.permutation(j) for j in i] for i in diagonal_indices]
        nondiag = [i for i in range(get_rank(expr)) if not any(i in j for j in diagonal_indices)]
        back_nondiag = [expr.permutation(i) for i in nondiag]
        remap = {e: i for i, e in enumerate(sorted(back_nondiag))}
        new_permutation1 = [remap[i] for i in back_nondiag]
        shift = len(new_permutation1)
        # Diagonal axes are appended at the end and keep their order.
        diag_block_perm = [i + shift for i in range(len(back_diagonal_indices))]
        new_permutation = new_permutation1 + diag_block_perm
        return PermuteDims(
            ArrayDiagonal(
                expr.expr,
                *back_diagonal_indices
            ),
            new_permutation
        )

    def _push_indices_down_nonstatic(self, indices):
        # Output axis x -> corresponding entry of ``_positions`` (an axis of
        # ``expr`` or a group of diagonalized axes), or None if out of range.
        transform = lambda x: self._positions[x] if x < len(self._positions) else None
        return _apply_recursively_over_nested_lists(transform, indices)

    def _push_indices_up_nonstatic(self, indices):

        # Axis x of ``expr`` -> output axis that carries it (None if not found).
        # Diagonal groups in ``_positions`` are sympy ``Tuple`` objects (built
        # in ``__new__``), which are not Python tuples, so both are accepted.
        def transform(x):
            for i, e in enumerate(self._positions):
                if (isinstance(e, int) and x == e) or (isinstance(e, (tuple, Tuple)) and x in e):
                    return i

        return _apply_recursively_over_nested_lists(transform, indices)

    @classmethod
    def _push_indices_down(cls, diagonal_indices, indices, rank):
        positions, shape = cls._get_positions_shape(range(rank), diagonal_indices)
        transform = lambda x: positions[x] if x < len(positions) else None
        return _apply_recursively_over_nested_lists(transform, indices)

    @classmethod
    def _push_indices_up(cls, diagonal_indices, indices, rank):
        positions, shape = cls._get_positions_shape(range(rank), diagonal_indices)

        def transform(x):
            for i, e in enumerate(positions):
                if (isinstance(e, int) and x == e) or (isinstance(e, (tuple, Tuple)) and (x in e)):
                    return i

        return _apply_recursively_over_nested_lists(transform, indices)

    @classmethod
    def _get_positions_shape(cls, shape, diagonal_indices):
        """
        Return ``(positions, shape)`` of the diagonalized expression.

        ``positions`` lists the non-diagonal axis indices (in order) followed
        by the diagonal index groups; ``shape`` lists the matching dimensions
        (each group contributes the dimension of its first axis).
        """
        data1 = tuple((i, shp) for i, shp in enumerate(shape) if not any(i in j for j in diagonal_indices))
        pos1, shp1 = zip(*data1) if data1 else ((), ())
        data2 = tuple((i, shape[i[0]]) for i in diagonal_indices)
        pos2, shp2 = zip(*data2) if data2 else ((), ())
        positions = pos1 + pos2
        shape = shp1 + shp2
        return positions, shape

    def as_explicit(self):
        return tensordiagonal(self.expr.as_explicit(), *self.diagonal_indices)
743
744
class ArrayElementwiseApplyFunc(_CodegenArrayAbstract):
    """
    Apply a scalar function to every element of an array expression.

    ``function`` may be a ``Lambda``; any other callable is turned into one
    by applying it to a fresh ``Dummy``.  ``element`` is the array expression
    the function is applied to.
    """

    def __new__(cls, function, element):

        if not isinstance(function, Lambda):
            d = Dummy('d')
            function = Lambda(d, function(d))

        obj = _CodegenArrayAbstract.__new__(cls, function, element)
        obj._subranks = _get_subranks(element)
        return obj

    @property
    def function(self):
        """The ``Lambda`` applied elementwise."""
        return self.args[0]

    @property
    def expr(self):
        """The array expression the function is applied to."""
        return self.args[1]

    @property
    def shape(self):
        # Elementwise application does not change the shape.
        return self.expr.shape

    def _get_function_fdiff(self):
        """
        Return the derivative of ``function``: the bare function class when
        the derivative is that function applied to exactly the variable (e.g.
        ``cos`` for ``sin``), otherwise a ``Lambda``.
        """
        d = Dummy("d")
        function = self.function(d)
        fdiff = function.diff(d)
        # Collapsing to ``type(fdiff)`` is only valid when the argument is the
        # variable itself; for e.g. ``cos(d + 1)`` the bare ``cos`` would drop
        # the inner ``d + 1``.
        if isinstance(fdiff, Function) and fdiff.args == (d,):
            fdiff = type(fdiff)
        else:
            fdiff = Lambda(d, fdiff)
        return fdiff
778
779
class ArrayContraction(_CodegenArrayAbstract):
    r"""
    This class is meant to represent contractions of arrays in a form easily
    processable by the code printers.

    The first argument is the array expression being contracted; every
    following argument is a group of index positions of that expression
    (numbered before the contraction) which are summed over together.
    """

    def __new__(cls, expr, *contraction_indices, **kwargs):
        contraction_indices = _sort_contraction_indices(contraction_indices)
        expr = _sympify(expr)

        # Nothing to contract: return the expression unchanged.
        if len(contraction_indices) == 0:
            return expr

        # Nested contractions are merged into a single contraction.
        if isinstance(expr, ArrayContraction):
            return cls._flatten(expr, *contraction_indices)

        # Contracting a zero array gives a zero array of the remaining shape.
        if isinstance(expr, (ZeroArray, ZeroMatrix)):
            contraction_indices_flat = [j for i in contraction_indices for j in i]
            shape = [e for i, e in enumerate(expr.shape) if i not in contraction_indices_flat]
            return ZeroArray(*shape)

        # Move the permutation outside of the contraction.
        if isinstance(expr, PermuteDims):
            return cls._handle_nested_permute_dims(expr, *contraction_indices)

        if isinstance(expr, ArrayTensorProduct):
            expr, contraction_indices = cls._sort_fully_contracted_args(expr, contraction_indices)
            expr, contraction_indices = cls._lower_contraction_to_addends(expr, contraction_indices)
            if len(contraction_indices) == 0:
                return expr

        # Move the diagonal outside of the contraction.
        if isinstance(expr, ArrayDiagonal):
            return cls._handle_nested_diagonal(expr, *contraction_indices)

        # The contraction distributes over addition.
        if isinstance(expr, ArrayAdd):
            return ArrayAdd(*[ArrayContraction(i, *contraction_indices) for i in expr.args])

        obj = Basic.__new__(cls, expr, *contraction_indices)
        obj._subranks = _get_subranks(expr)
        obj._mapping = _get_mapping_from_subranks(obj._subranks)

        free_indices_to_position = {
            i: i for i in range(sum(obj._subranks))
            if all(i not in cind for cind in contraction_indices)
        }
        obj._free_indices_to_position = free_indices_to_position

        shape = expr.shape
        cls._validate(expr, *contraction_indices)
        if shape:
            shape = tuple(shp for i, shp in enumerate(shape) if not any(i in j for j in contraction_indices))
        obj._shape = shape
        return obj

    def __mul__(self, other):
        if other == 1:
            return self
        else:
            raise NotImplementedError("Product of N-dim arrays is not uniquely defined. Use another method.")

    def __rmul__(self, other):
        if other == 1:
            return self
        else:
            raise NotImplementedError("Product of N-dim arrays is not uniquely defined. Use another method.")

    @staticmethod
    def _validate(expr, *contraction_indices):
        """
        Raise ``ValueError`` if a contraction group joins axes of different
        dimensions.

        Dimensions equal to ``-1`` are left out of the comparison; a group
        whose axes all have dimension ``-1`` is therefore rejected as well.
        Nothing is checked when the shape of ``expr`` is ``None``.
        """
        shape = expr.shape
        if shape is None:
            return

        # Check that no contraction happens when the shape is mismatched:
        for i in contraction_indices:
            if len({shape[j] for j in i if shape[j] != -1}) != 1:
                raise ValueError("contracting indices of different dimensions")

    @classmethod
    def _push_indices_down(cls, contraction_indices, indices):
        """
        Translate index positions of the contracted expression into positions
        of the expression before the contraction (``indices`` may be nested).
        """
        flattened_contraction_indices = [j for i in contraction_indices for j in i]
        flattened_contraction_indices.sort()
        transform = _build_push_indices_down_func_transformation(flattened_contraction_indices)
        return _apply_recursively_over_nested_lists(transform, indices)

    @classmethod
    def _push_indices_up(cls, contraction_indices, indices):
        """
        Translate index positions of the expression before the contraction
        into positions of the contracted expression (``indices`` may be
        nested).
        """
        flattened_contraction_indices = [j for i in contraction_indices for j in i]
        flattened_contraction_indices.sort()
        transform = _build_push_indices_up_func_transformation(flattened_contraction_indices)
        return _apply_recursively_over_nested_lists(transform, indices)

    @classmethod
    def _lower_contraction_to_addends(cls, expr, contraction_indices):
        """
        Move the contraction groups that only involve indices of a single
        ``ArrayAdd`` factor of a tensor product inside that factor.

        Returns the (possibly rewritten) expression and the contraction
        groups still to be applied to it, renumbered to account for the
        indices consumed by the lowered contractions.
        """
        if isinstance(expr, ArrayAdd):
            raise NotImplementedError()
        if not isinstance(expr, ArrayTensorProduct):
            return expr, contraction_indices
        subranks = expr.subranks
        cumranks = list(accumulate([0] + subranks))
        contraction_indices_remaining = []
        contraction_indices_args = [[] for _ in expr.args]
        backshift = set()
        for contraction_group in contraction_indices:
            for j in range(len(expr.args)):
                if not isinstance(expr.args[j], ArrayAdd):
                    continue
                if all(cumranks[j] <= k < cumranks[j+1] for k in contraction_group):
                    contraction_indices_args[j].append([k - cumranks[j] for k in contraction_group])
                    backshift.update(contraction_group)
                    break
            else:
                contraction_indices_remaining.append(contraction_group)
        if len(contraction_indices_remaining) == len(contraction_indices):
            return expr, contraction_indices
        total_rank = get_rank(expr)
        # shifts[j] counts the lowered indices at positions <= j:
        shifts = list(accumulate([1 if i in backshift else 0 for i in range(total_rank)]))
        contraction_indices_remaining = [Tuple.fromiter(j - shifts[j] for j in i) for i in contraction_indices_remaining]
        ret = ArrayTensorProduct(*[
            ArrayContraction(arg, *contr) for arg, contr in zip(expr.args, contraction_indices_args)
        ])
        return ret, contraction_indices_remaining

    def split_multiple_contractions(self):
        """
        Recognize multiple contractions and attempt at rewriting them as paired-contractions.

        This allows some contractions involving more than two indices to be
        rewritten as multiple contractions involving two indices, thus allowing
        the expression to be rewritten as a matrix multiplication line.

        Examples:

        * `A_ij b_j0 C_jk` ===> `A*DiagMatrix(b)*C`

        Care for:
        - matrix being diagonalized (i.e. `A_ii`)
        - vectors being diagonalized (i.e. `a_i0`)

        Multiple contractions can be split into matrix multiplications if
        not more than two arguments are non-diagonals or non-vectors.
        Vectors get diagonalized while diagonal matrices remain diagonal.
        The non-diagonal matrices can be at the beginning or at the end
        of the final matrix multiplication line.
        """
        from sympy import ask, Q

        editor = _EditArrayContraction(self)

        contraction_indices = self.contraction_indices
        if isinstance(self.expr, ArrayTensorProduct):
            args = list(self.expr.args)
        else:
            args = [self.expr]
        # TODO: unify API, best location in ArrayTensorProduct
        subranks = [get_rank(i) for i in args]
        # TODO: unify API
        mapping = _get_mapping_from_subranks(subranks)
        reverse_mapping = {v: k for k, v in mapping.items()}

        for indl, links in enumerate(contraction_indices):
            if len(links) <= 2:
                continue

            positions = editor.get_mapping_for_index(indl)

            # Also consider the case of diagonal matrices being contracted:
            current_dimension = self.expr.shape[links[0]]

            # Pairs of (_ArgE, relative index inside that argument):
            not_vectors: List[tuple] = []
            vectors: List[tuple] = []
            for arg_ind, rel_ind in positions:
                mat = args[arg_ind]
                other_arg_pos = 1-rel_ind
                other_arg_abs = reverse_mapping[arg_ind, other_arg_pos]
                arg = editor.args_with_ind[arg_ind]
                if (((1 not in mat.shape) and (not ask(Q.diagonal(mat)))) or
                    ((current_dimension == 1) is True and mat.shape != (1, 1)) or
                    any(other_arg_abs in l for li, l in enumerate(contraction_indices) if li != indl)
                ):
                    not_vectors.append((arg, rel_ind))
                else:
                    vectors.append((arg, rel_ind))
            if len(not_vectors) > 2:
                # If more than two arguments in the multiple contraction are
                # non-vectors and non-diagonal matrices, we cannot find a way
                # to split this contraction into a matrix multiplication line:
                continue
            # Three cases to handle:
            # - zero non-vectors
            # - one non-vector
            # - two non-vectors
            for v, rel_ind in vectors:
                v.element = diagonalize_vector(v.element)
            vectors_to_loop = not_vectors[:1] + vectors + not_vectors[1:]
            first_not_vector, rel_ind = vectors_to_loop[0]
            new_index = first_not_vector.indices[rel_ind]

            for v, rel_ind in vectors_to_loop[1:-1]:
                v.indices[rel_ind] = new_index
                new_index = editor.get_new_contraction_index()
                assert v.indices.index(None) == 1 - rel_ind
                v.indices[v.indices.index(None)] = new_index

            last_vec, rel_ind = vectors_to_loop[-1]
            last_vec.indices[rel_ind] = new_index
        return editor.to_array_contraction()

    def flatten_contraction_of_diagonal(self):
        """
        Rewrite a contraction of an ``ArrayDiagonal`` so that every diagonal
        group touched by the contraction is merged into the contraction
        itself, the remaining diagonal groups staying in an inner
        ``ArrayDiagonal``.

        Returns ``self`` unchanged when the contracted expression is not an
        ``ArrayDiagonal``.
        """
        if not isinstance(self.expr, ArrayDiagonal):
            return self
        # ``ArrayDiagonal._push_indices_down``/``_push_indices_up`` need the
        # rank of the expression inside the diagonal:
        inner_rank = get_rank(self.expr.expr)
        contraction_down = self.expr._push_indices_down(
            self.expr.diagonal_indices, self.contraction_indices, inner_rank)
        new_contraction_indices = []
        diagonal_indices = self.expr.diagonal_indices[:]
        for i in contraction_down:
            # Indices pointing at a diagonal slot are pushed down to the whole
            # diagonal group (a tuple): flatten them into plain positions.
            flat_group = [k for j in i for k in (j if isinstance(j, (tuple, Tuple)) else [j])]
            contraction_group = list(flat_group)
            for j in flat_group:
                diagonal_with = [k for k in diagonal_indices if j in k]
                contraction_group.extend([l for k in diagonal_with for l in k])
                diagonal_indices = [k for k in diagonal_indices if k not in diagonal_with]
            new_contraction_indices.append(sorted(set(contraction_group)))

        new_contraction_indices = ArrayDiagonal._push_indices_up(
            diagonal_indices, new_contraction_indices, inner_rank)
        return ArrayContraction(
            ArrayDiagonal(
                self.expr.expr,
                *diagonal_indices
            ),
            *new_contraction_indices
        )

    @staticmethod
    def _get_free_indices_to_position_map(free_indices, contraction_indices):
        """
        Assign to each element of ``free_indices`` (in order) the next
        position that does not belong to any contraction group.
        """
        free_indices_to_position = {}
        flattened_contraction_indices = {j for i in contraction_indices for j in i}
        counter = 0
        for ind in free_indices:
            while counter in flattened_contraction_indices:
                counter += 1
            free_indices_to_position[ind] = counter
            counter += 1
        return free_indices_to_position

    @staticmethod
    def _get_index_shifts(expr):
        """
        Get the mapping of indices at the positions before the contraction
        occurs.

        Examples
        ========

        >>> from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct
        >>> from sympy.tensor.array.expressions.array_expressions import ArrayContraction
        >>> from sympy import MatrixSymbol
        >>> M = MatrixSymbol("M", 3, 3)
        >>> N = MatrixSymbol("N", 3, 3)
        >>> cg = ArrayContraction(ArrayTensorProduct(M, N), [1, 2])
        >>> cg._get_index_shifts(cg)
        [0, 2]

        Indeed, ``cg`` after the contraction has two dimensions, 0 and 1. They
        need to be shifted by 0 and 2 to get the corresponding positions before
        the contraction (that is, 0 and 3).
        """
        inner_contraction_indices = expr.contraction_indices
        all_inner = [j for i in inner_contraction_indices for j in i]
        all_inner.sort()
        # TODO: add API for total rank and cumulative rank:
        total_rank = _get_subrank(expr)
        inner_rank = len(all_inner)
        outer_rank = total_rank - inner_rank
        shifts = [0 for i in range(outer_rank)]
        counter = 0
        pointer = 0
        for i in range(outer_rank):
            while pointer < inner_rank and counter >= all_inner[pointer]:
                counter += 1
                pointer += 1
            shifts[i] += pointer
            counter += 1
        return shifts

    @staticmethod
    def _convert_outer_indices_to_inner_indices(expr, *outer_contraction_indices):
        """
        Express contraction groups given in terms of the positions of the
        contraction ``expr`` in terms of the positions of ``expr.expr``.
        """
        shifts = ArrayContraction._get_index_shifts(expr)
        outer_contraction_indices = tuple(tuple(shifts[j] + j for j in i) for i in outer_contraction_indices)
        return outer_contraction_indices

    @staticmethod
    def _flatten(expr, *outer_contraction_indices):
        """
        Merge contraction groups applied on top of the contraction ``expr``
        into a single contraction of ``expr.expr``.
        """
        inner_contraction_indices = expr.contraction_indices
        outer_contraction_indices = ArrayContraction._convert_outer_indices_to_inner_indices(expr, *outer_contraction_indices)
        contraction_indices = inner_contraction_indices + outer_contraction_indices
        return ArrayContraction(expr.expr, *contraction_indices)

    @classmethod
    def _handle_nested_permute_dims(cls, expr, *contraction_indices):
        """
        Rewrite ``contraction(PermuteDims(x, p))`` as
        ``PermuteDims(contraction(x), p')``.
        """
        permutation = expr.permutation
        plist = permutation.array_form
        new_contraction_indices = [tuple(permutation(j) for j in i) for i in contraction_indices]
        new_plist = [i for i in plist if all(i not in j for j in new_contraction_indices)]
        new_plist = cls._push_indices_up(new_contraction_indices, new_plist)
        return PermuteDims(
            ArrayContraction(expr.expr, *new_contraction_indices),
            Permutation(new_plist)
        )

    @classmethod
    def _handle_nested_diagonal(cls, expr: 'ArrayDiagonal', *contraction_indices):
        """
        Rewrite ``contraction(ArrayDiagonal(x, ...))`` as
        ``ArrayDiagonal(contraction(x), ...)``, absorbing into the contraction
        every diagonal group that shares an index with it.
        """
        diagonal_indices = list(expr.diagonal_indices)
        down_contraction_indices = expr._push_indices_down(expr.diagonal_indices, contraction_indices, get_rank(expr.expr))
        # Flatten diagonally contracted indices:
        down_contraction_indices = [[k for j in i for k in (j if isinstance(j, (tuple, Tuple)) else [j])] for i in down_contraction_indices]
        new_contraction_indices = []
        for contr_indgrp in down_contraction_indices:
            ind = contr_indgrp[:]
            for j, diag_indgrp in enumerate(diagonal_indices):
                if diag_indgrp is None:
                    continue
                if any(i in diag_indgrp for i in contr_indgrp):
                    ind.extend(diag_indgrp)
                    diagonal_indices[j] = None
            new_contraction_indices.append(sorted(set(ind)))

        new_diagonal_indices_down = [i for i in diagonal_indices if i is not None]
        new_diagonal_indices = ArrayContraction._push_indices_up(new_contraction_indices, new_diagonal_indices_down)
        return ArrayDiagonal(
            ArrayContraction(expr.expr, *new_contraction_indices),
            *new_diagonal_indices
        )

    @classmethod
    def _sort_fully_contracted_args(cls, expr, contraction_indices):
        """
        Move the tensor-product factors whose indices are all contracted to
        the front (ordered by ``default_sort_key``) and renumber the
        contraction groups accordingly.
        """
        if expr.shape is None:
            return expr, contraction_indices
        cumul = list(accumulate([0] + expr.subranks))
        index_blocks = [list(range(cumul[i], cumul[i+1])) for i in range(len(expr.args))]
        contraction_indices_flat = {j for i in contraction_indices for j in i}
        fully_contracted = [all(j in contraction_indices_flat for j in range(cumul[i], cumul[i+1])) for i, arg in enumerate(expr.args)]
        new_pos = sorted(range(len(expr.args)), key=lambda x: (0, default_sort_key(expr.args[x])) if fully_contracted[x] else (1,))
        new_args = [expr.args[i] for i in new_pos]
        new_index_blocks_flat = [j for i in new_pos for j in index_blocks[i]]
        index_permutation_array_form = _af_invert(new_index_blocks_flat)
        new_contraction_indices = [tuple(index_permutation_array_form[j] for j in i) for i in contraction_indices]
        new_contraction_indices = _sort_contraction_indices(new_contraction_indices)
        return ArrayTensorProduct(*new_args), new_contraction_indices

    def _get_contraction_tuples(self):
        r"""
        Return tuples containing the argument index and position within the
        argument of the index position.

        Examples
        ========

        >>> from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct
        >>> from sympy import MatrixSymbol
        >>> from sympy.abc import N
        >>> from sympy.tensor.array.expressions.array_expressions import ArrayContraction
        >>> A = MatrixSymbol("A", N, N)
        >>> B = MatrixSymbol("B", N, N)

        >>> cg = ArrayContraction(ArrayTensorProduct(A, B), (1, 2))
        >>> cg._get_contraction_tuples()
        [[(0, 1), (1, 0)]]

        Notes
        =====

        Here the contraction pair `(1, 2)`, meaning that the 2nd and 3rd
        indices of the tensor product `A\otimes B` are contracted, has been
        transformed into `(0, 1)` and `(1, 0)`, identifying the same indices
        in a different notation. `(0, 1)` is the second index (i.e. 1) of the
        first argument (i.e. 0 or `A`). `(1, 0)` is the first index (i.e. 0)
        of the second argument (i.e. 1 or `B`).
        """
        mapping = self._mapping
        return [[mapping[j] for j in i] for i in self.contraction_indices]

    @staticmethod
    def _contraction_tuples_to_contraction_indices(expr, contraction_tuples):
        """
        Inverse of ``_get_contraction_tuples``: turn ``(argument, relative
        position)`` pairs back into absolute index positions of ``expr``.
        """
        # TODO: check that `expr` has `.subranks`:
        ranks = expr.subranks
        cumulative_ranks = [0] + list(accumulate(ranks))
        return [tuple(cumulative_ranks[j]+k for j, k in i) for i in contraction_tuples]

    @property
    def free_indices(self):
        """
        Positions, in the expression before the contraction, of the indices
        that are not contracted, in increasing order.
        """
        # ``_free_indices`` was never assigned anywhere; derive the free
        # positions from the mapping built in ``__new__`` instead.
        return sorted(self._free_indices_to_position)

    @property
    def free_indices_to_position(self):
        return dict(self._free_indices_to_position)

    @property
    def expr(self):
        return self.args[0]

    @property
    def contraction_indices(self):
        return self.args[1:]

    def _contraction_indices_to_components(self):
        """
        Map every absolute index position of the contracted tensor product to
        its ``(argument, relative position)`` pair.
        """
        expr = self.expr
        if not isinstance(expr, ArrayTensorProduct):
            raise NotImplementedError("only for contractions of tensor products")
        ranks = expr.subranks
        mapping = {}
        counter = 0
        for i, rank in enumerate(ranks):
            for j in range(rank):
                mapping[counter] = (i, j)
                counter += 1
        return mapping

    def sort_args_by_name(self):
        """
        Sort arguments in the tensor product so that their order is lexicographical.

        Examples
        ========

        >>> from sympy.tensor.array.expressions.conv_matrix_to_array import convert_matrix_to_array
        >>> from sympy import MatrixSymbol
        >>> from sympy.abc import N
        >>> A = MatrixSymbol("A", N, N)
        >>> B = MatrixSymbol("B", N, N)
        >>> C = MatrixSymbol("C", N, N)
        >>> D = MatrixSymbol("D", N, N)

        >>> cg = convert_matrix_to_array(C*D*A*B)
        >>> cg
        ArrayContraction(ArrayTensorProduct(A, D, C, B), (0, 3), (1, 6), (2, 5))
        >>> cg.sort_args_by_name()
        ArrayContraction(ArrayTensorProduct(A, D, B, C), (0, 3), (1, 4), (2, 7))
        """
        expr = self.expr
        if not isinstance(expr, ArrayTensorProduct):
            return self
        args = expr.args
        sorted_data = sorted(enumerate(args), key=lambda x: default_sort_key(x[1]))
        pos_sorted, args_sorted = zip(*sorted_data)
        # Old argument position -> new argument position:
        reordering_map = {old: new for new, old in enumerate(pos_sorted)}
        contraction_tuples = self._get_contraction_tuples()
        contraction_tuples = [[(reordering_map[j], k) for j, k in i] for i in contraction_tuples]
        c_tp = ArrayTensorProduct(*args_sorted)
        new_contr_indices = self._contraction_tuples_to_contraction_indices(
                c_tp,
                contraction_tuples
        )
        return ArrayContraction(c_tp, *new_contr_indices)

    def _get_contraction_links(self):
        r"""
        Returns a dictionary of links between arguments in the tensor product
        being contracted.

        See the example for an explanation of the values.

        Examples
        ========

        >>> from sympy import MatrixSymbol
        >>> from sympy.abc import N
        >>> from sympy.tensor.array.expressions.conv_matrix_to_array import convert_matrix_to_array
        >>> A = MatrixSymbol("A", N, N)
        >>> B = MatrixSymbol("B", N, N)
        >>> C = MatrixSymbol("C", N, N)
        >>> D = MatrixSymbol("D", N, N)

        Matrix multiplications are pairwise contractions between neighboring
        matrices:

        `A_{ij} B_{jk} C_{kl} D_{lm}`

        >>> cg = convert_matrix_to_array(A*B*C*D)
        >>> cg
        ArrayContraction(ArrayTensorProduct(B, C, A, D), (0, 5), (1, 2), (3, 6))

        >>> cg._get_contraction_links()
        {0: {0: (2, 1), 1: (1, 0)}, 1: {0: (0, 1), 1: (3, 0)}, 2: {1: (0, 0)}, 3: {0: (1, 1)}}

        This dictionary is interpreted as follows: the argument in position 0
        of the tensor product (i.e. matrix `B`) has its first index (i.e. 0)
        contracted to `(2, 1)`, that is the second index slot of the argument
        in position 2 (matrix `A`); this is the contraction provided by the
        index `j`. Its second index (i.e. 1) is contracted to `(1, 0)`, the
        first index slot of the argument in position 1 (matrix `C`), which is
        the contraction provided by the index `k`.

        The argument in position 1 (that is, matrix `C`) also has two
        contractions: its first index slot is linked to `(0, 1)` (the second
        index slot of `B`, index `k`) and its second index slot to `(3, 0)`
        (the first index slot of `D`, index `l`). Matrices `A` and `D` only
        have one contracted index each.
        """
        args, dlinks = _get_contraction_links([self], self.subranks, *self.contraction_indices)
        return dlinks

    def as_explicit(self):
        return tensorcontraction(self.expr.as_explicit(), *self.contraction_indices)
1274
1275
1276class _ArgE:
1277    """
1278    The ``_ArgE`` object contains references to the array expression
1279    (``.element``) and a list containing the information about index
1280    contractions (``.indices``).
1281
1282    Index contractions are numbered and contracted indices show the number of
1283    the contraction. Uncontracted indices have ``None`` value.
1284
1285    For example:
1286    ``_ArgE(M, [None, 3])``
1287    This object means that expression ``M`` is part of an array contraction
1288    and has two indices, the first is not contracted (value ``None``),
1289    the second index is contracted to the 4th (i.e. number ``3``) group of the
1290    array contraction object.
1291    """
1292    def __init__(self, element, indices: Optional[List[Optional[int]]] = None):
1293        self.element = element
1294        if indices is None:
1295            self.indices: List[Optional[int]] = [None for i in range(get_rank(element))]
1296        else:
1297            self.indices: List[Optional[int]] = indices
1298
1299    def __str__(self):
1300        return "_ArgE(%s, %s)" % (self.element, self.indices)
1301
1302    __repr__ = __str__
1303
1304
1305class _IndPos:
1306    """
1307    Index position, requiring two integers in the constructor:
1308
1309    - arg: the position of the argument in the tensor product,
1310    - rel: the relative position of the index inside the argument.
1311    """
1312    def __init__(self, arg: int, rel: int):
1313        self.arg = arg
1314        self.rel = rel
1315
1316    def __str__(self):
1317        return "_IndPos(%i, %i)" % (self.arg, self.rel)
1318
1319    __repr__ = __str__
1320
1321    def __iter__(self):
1322        yield from [self.arg, self.rel]
1323
1324
class _EditArrayContraction:
    """
    Mutable representation of an ``ArrayContraction``.

    The constructor splits the contraction into a list of ``_ArgE`` objects
    (``args_with_ind``), one per factor of the tensor product (or a single
    one when the contracted expression is not a tensor product). Each
    ``_ArgE`` records, for every index, the number of the contraction group
    it belongs to.

    After editing, ``to_array_contraction()`` rebuilds an ``ArrayContraction``
    from the current state. Passing ``None`` creates an empty editor.
    """

    def __init__(self, array_contraction: Optional[ArrayContraction]):
        self.args_with_ind: List[_ArgE] = []
        self.number_of_contraction_indices: int = 0
        # List of per-argument lists of free-index positions, or None when
        # permutation tracking has not been started.
        self._track_permutation: Optional[List[List[int]]] = None
        if array_contraction is None:
            return

        inner = array_contraction.expr
        factors = list(inner.args) if isinstance(inner, ArrayTensorProduct) else [inner]
        self.args_with_ind = [_ArgE(factor) for factor in factors]

        index_map = _get_mapping_from_subranks(array_contraction.subranks)
        groups = array_contraction.contraction_indices
        for group_number, group in enumerate(groups):
            for position in group:
                arg_pos, rel_pos = index_map[position]
                self.args_with_ind[arg_pos].indices[rel_pos] = group_number
        self.number_of_contraction_indices = len(groups)

    def insert_after(self, arg: _ArgE, new_arg: _ArgE):
        """Insert ``new_arg`` right after ``arg`` in ``args_with_ind``."""
        self.args_with_ind.insert(self.args_with_ind.index(arg) + 1, new_arg)

    def get_new_contraction_index(self):
        """Reserve and return the number of a new contraction group."""
        new_index = self.number_of_contraction_indices
        self.number_of_contraction_indices = new_index + 1
        return new_index

    def refresh_indices(self):
        """
        Renumber the contraction groups still in use as ``0, 1, 2, ...``,
        preserving their relative order, and update the group count.
        """
        used = sorted({
            ind
            for arg_with_ind in self.args_with_ind
            for ind in arg_with_ind.indices
            if ind is not None
        })
        renumbering: Dict[int, int] = {old: new for new, old in enumerate(used)}
        self.number_of_contraction_indices = len(renumbering)
        for arg_with_ind in self.args_with_ind:
            arg_with_ind.indices = [renumbering.get(ind) for ind in arg_with_ind.indices]

    def merge_scalars(self):
        """
        Remove the rank-0 arguments and multiply their product into the first
        remaining argument; if none remain, keep the product as sole argument.
        """
        scalars = [arg_with_ind for arg_with_ind in self.args_with_ind if not arg_with_ind.indices]
        for scalar_arg in scalars:
            self.args_with_ind.remove(scalar_arg)
        scalar = Mul.fromiter([scalar_arg.element for scalar_arg in scalars])
        if not self.args_with_ind:
            self.args_with_ind.append(_ArgE(scalar))
            return
        from sympy.tensor.array.expressions.conv_array_to_matrix import _a2m_tensor_product
        head = self.args_with_ind[0]
        head.element = _a2m_tensor_product(scalar, head.element)

    def to_array_contraction(self):
        """
        Build the ``ArrayContraction`` described by the current state,
        wrapped in ``PermuteDims`` when permutation tracking is active.
        """
        self.merge_scalars()
        self.refresh_indices()
        elements = [arg_with_ind.element for arg_with_ind in self.args_with_ind]
        groups = self.get_contraction_indices()
        result = ArrayContraction(ArrayTensorProduct(*elements), *groups)
        if self._track_permutation is None:
            return result
        flat_order = [j for block in self._track_permutation for j in block]
        return PermuteDims(result, _af_invert(flat_order))

    def get_contraction_indices(self) -> List[List[int]]:
        """
        Return, for every contraction group, the absolute positions (across
        all arguments) of the indices belonging to it.
        """
        groups: List[List[int]] = [[] for _ in range(self.number_of_contraction_indices)]
        flat_indices = (ind for arg_with_ind in self.args_with_ind for ind in arg_with_ind.indices)
        for position, ind in enumerate(flat_indices):
            if ind is not None:
                groups[ind].append(position)
        return groups

    def get_mapping_for_index(self, ind) -> List[_IndPos]:
        """
        Return the ``_IndPos`` of every index belonging to contraction group
        ``ind``; raise ``ValueError`` if ``ind`` is not below the group count.
        """
        if ind >= self.number_of_contraction_indices:
            raise ValueError("index value exceeding the index range")
        return [
            _IndPos(arg_pos, rel_pos)
            for arg_pos, arg_with_ind in enumerate(self.args_with_ind)
            for rel_pos, arg_ind in enumerate(arg_with_ind.indices)
            if ind == arg_ind
        ]

    def get_contraction_indices_to_ind_rel_pos(self) -> List[List[_IndPos]]:
        """
        Like ``get_contraction_indices``, but each index is described by an
        ``_IndPos`` (argument, relative position) instead of an absolute
        position.
        """
        groups: List[List[_IndPos]] = [[] for _ in range(self.number_of_contraction_indices)]
        for arg_pos, arg_with_ind in enumerate(self.args_with_ind):
            for rel_pos, ind in enumerate(arg_with_ind.indices):
                if ind is None:
                    continue
                groups[ind].append(_IndPos(arg_pos, rel_pos))
        return groups

    def count_args_with_index(self, index: int) -> int:
        """
        Count the number of arguments that have the given index.
        """
        return sum(1 for arg_with_ind in self.args_with_ind if index in arg_with_ind.indices)

    def track_permutation_start(self):
        """
        Start tracking free indices: number the free (``None``) indices
        consecutively across all arguments and store, per argument, the list
        of numbers it owns.
        """
        self._track_permutation = []
        next_free = 0
        for arg_with_ind in self.args_with_ind:
            free_count = sum(1 for ind in arg_with_ind.indices if ind is None)
            self._track_permutation.append(list(range(next_free, next_free + free_count)))
            next_free += free_count

    def track_permutation_merge(self, destination: _ArgE, from_element: _ArgE):
        """
        Append the tracked free indices of ``from_element`` to those of
        ``destination`` and drop ``from_element``'s entry.
        """
        perms = self._track_permutation
        dest_pos = self.args_with_ind.index(destination)
        src_pos = self.args_with_ind.index(from_element)
        perms[dest_pos].extend(perms[src_pos])
        perms.pop(src_pos)
1459
1460
def get_rank(expr):
    """
    Return the number of indices of ``expr``.

    Matrix expressions and matrix elements count as rank 2. An
    ``IndexedBase`` without a shape gives ``-1``. Objects exposing a
    ``shape`` give its length, and anything else gives 0.
    """
    if isinstance(expr, (MatrixExpr, MatrixElement)):
        rank = 2
    elif isinstance(expr, _CodegenArrayAbstract):
        rank = len(expr.shape)
    elif isinstance(expr, NDimArray):
        rank = expr.rank()
    elif isinstance(expr, Indexed):
        rank = expr.rank
    elif isinstance(expr, IndexedBase):
        base_shape = expr.shape
        rank = -1 if base_shape is None else len(base_shape)
    elif hasattr(expr, "shape"):
        rank = len(expr.shape)
    else:
        rank = 0
    return rank
1479
1480
def _get_subrank(expr):
    """Total rank of the arguments of ``expr``; falls back to ``get_rank``."""
    if not isinstance(expr, _CodegenArrayAbstract):
        return get_rank(expr)
    return expr.subrank()
1485
1486
def _get_subranks(expr):
    """Ranks of the arguments of ``expr``, or ``[get_rank(expr)]`` for a leaf."""
    return expr.subranks if isinstance(expr, _CodegenArrayAbstract) else [get_rank(expr)]
1492
1493
def get_shape(expr):
    """Return ``expr.shape`` when the attribute exists, otherwise ``()``."""
    return expr.shape if hasattr(expr, "shape") else ()
1498
1499
def nest_permutation(expr):
    """Call ``nest_permutation()`` on ``PermuteDims`` objects; return other expressions unchanged."""
    if not isinstance(expr, PermuteDims):
        return expr
    return expr.nest_permutation()
1505