1import itertools
2from collections import defaultdict, Counter
3from typing import Tuple, Union, FrozenSet, Dict, List, Optional
4from functools import singledispatch
5from itertools import accumulate
7from sympy import Trace, MatrixExpr, Transpose, DiagMatrix, Mul, ZeroMatrix, hadamard_product, S
8from sympy.combinatorics.permutations import _af_invert, Permutation
9from sympy.matrices.common import MatrixCommon
10from sympy.matrices.expressions.applyfunc import ElementwiseApplyFunction
11from sympy.tensor.array.expressions.array_expressions import PermuteDims, ArrayDiagonal, \
12    ArrayTensorProduct, OneArray, get_rank, _get_subrank, ZeroArray, ArrayContraction, \
13    ArrayAdd, _CodegenArrayAbstract, get_shape, ArrayElementwiseApplyFunc, _ArrayExpr, _EditArrayContraction, _ArgE
14from sympy.tensor.array.expressions.utils import _get_mapping_from_subranks
def _get_candidate_for_matmul_from_contraction(scan_indices: List[Optional[int]], remaining_args: List[_ArgE]) -> Tuple[Optional[_ArgE], bool, int]:
    """Find a matrix argument that can be matrix-multiplied with the current one.

    Parameters
    ==========

    scan_indices : contraction indices of the current argument that are
        eligible for a matrix product (``None`` entries are ignored).
    remaining_args : the editor arguments that follow the current one.

    Returns
    =======

    ``(candidate, transpose, candidate_index)``: the matching argument (or
    ``None`` when no unambiguous match exists), whether the shared index sits
    in the candidate's second slot, and the shared index (``-1`` if none was
    selected).
    """
    wanted = [idx for idx in scan_indices if idx is not None]
    if not wanted:
        return None, False, -1

    found: Optional[_ArgE] = None
    found_index: int = -1
    is_transposed: bool = False
    for other in remaining_args:
        if not isinstance(other.element, MatrixExpr):
            continue
        for idx in wanted:
            # Once an index has been chosen, only repetitions of that index matter.
            if found_index not in (-1, idx):
                continue
            if idx not in other.indices:
                continue
            if set(other.indices) == {idx} or found is not None:
                # Either `other` contracts the index with itself, or the index
                # appears in more than two arguments: no plain matrix product.
                found = None
                break
            found = other
            found_index = idx
            is_transposed = idx == other.indices[1]
    return found, is_transposed, found_index
def _insert_candidate_into_editor(editor: _EditArrayContraction, arg_with_ind: _ArgE, candidate: _ArgE, transpose1: bool, transpose2: bool):
    """Multiply ``arg_with_ind`` by ``candidate`` and drop ``candidate`` from the editor.

    ``transpose1``/``transpose2`` select whether the left/right factor is
    transposed before the product.

    Returns ``(new_arge, other_index)``. ``new_arge`` is a fresh ``_ArgE``
    holding the product, with no indices assigned yet. ``other_index`` is the
    index of ``candidate`` that is not consumed by the product: its first
    index when transposed, otherwise its second.
    """
    left = Transpose(arg_with_ind.element) if transpose1 else arg_with_ind.element
    if transpose2:
        right, other_index = Transpose(candidate.element), candidate.indices[0]
    else:
        right, other_index = candidate.element, candidate.indices[1]
    product = left * right
    editor.args_with_ind.remove(candidate)
    return _ArgE(product), other_index
def _support_function_tp1_recognize(contraction_indices, args):
    """Rewrite a contraction of a tensor product of ``args`` into matrix products and traces.

    Without contraction indices this is just ``_a2m_tensor_product(*args)``.
    Otherwise an ``_EditArrayContraction`` is built over
    ``ArrayContraction(ArrayTensorProduct(*args), *contraction_indices)``. Its
    matrix arguments are then repeatedly:

    - turned into a ``Trace`` when both indices of a matrix are the same index
      and ``editor.count_args_with_index`` reports it only once;
    - merged pairwise into a matrix product (``_insert_candidate_into_editor``)
      when they share an index for which ``count_args_with_index`` reports 2.
      A resulting product whose two indices coincide becomes a ``Trace``.

    The scan restarts after every rewrite and stops once a full pass changes
    nothing. The result is ``editor.to_array_contraction()``.
    """
    if len(contraction_indices) == 0:
        return _a2m_tensor_product(*args)

    ac = ArrayContraction(ArrayTensorProduct(*args), *contraction_indices)
    editor = _EditArrayContraction(ac)
    # NOTE(review): presumably records the axis order so that merges below can
    # be tracked via `track_permutation_merge` - confirm in _EditArrayContraction.
    editor.track_permutation_start()

    while True:
        # Stays True only if a full pass over the arguments rewrote nothing.
        flag_stop: bool = True
        for i, arg_with_ind in enumerate(editor.args_with_ind):
            if not isinstance(arg_with_ind.element, MatrixExpr):
                continue

            first_index = arg_with_ind.indices[0]
            second_index = arg_with_ind.indices[1]

            first_frequency = editor.count_args_with_index(first_index)
            second_frequency = editor.count_args_with_index(second_index)

            # Matrix contracted with itself only (A[i, i]): replace by its trace.
            if first_index is not None and first_frequency == 1 and first_index == second_index:
                flag_stop = False
                arg_with_ind.element = Trace(arg_with_ind.element)._normalize()
                arg_with_ind.indices = []
                break

            # Only indices shared with exactly one other argument can form A*B.
            scan_indices = []
            if first_frequency == 2:
                scan_indices.append(first_index)
            if second_frequency == 2:
                scan_indices.append(second_index)

            # Only arguments after position i are searched for a partner.
            candidate, transpose, found_index = _get_candidate_for_matmul_from_contraction(scan_indices, editor.args_with_ind[i+1:])
            if candidate is not None:
                flag_stop = False
                editor.track_permutation_merge(arg_with_ind, candidate)
                # If the shared index is the first one, the current matrix is
                # transposed so that the shared index becomes the inner one.
                transpose1 = found_index == first_index
                new_arge, other_index = _insert_candidate_into_editor(editor, arg_with_ind, candidate, transpose1, transpose)
                if found_index == first_index:
                    new_arge.indices = [second_index, other_index]
                else:
                    new_arge.indices = [first_index, other_index]
                set_indices = set(new_arge.indices)
                if len(set_indices) == 1 and set_indices != {None}:
                    # This is a trace:
                    new_arge.element = Trace(new_arge.element)._normalize()
                    new_arge.indices = []
                editor.args_with_ind[i] = new_arge
                # NOTE(review): this break appears necessary.
                # `_insert_candidate_into_editor` removed `candidate` from
                # `editor.args_with_ind`, the list being enumerated here, so
                # the scan restarts on the mutated list.
                break

        if flag_stop:
            break

    editor.refresh_indices()
    return editor.to_array_contraction()
def _array2matrix(expr):
    """Convert an array expression into an equivalent matrix expression.

    Default case: ``expr`` is returned unchanged (presumably the
    ``singledispatch`` fallback for types without a registered handler).
    """
    unchanged = expr
    return unchanged
def _(expr: ZeroArray):
    """A rank-2 ``ZeroArray`` becomes a ``ZeroMatrix``; other ranks are kept."""
    if get_rank(expr) != 2:
        return expr
    return ZeroMatrix(*expr.shape)
def _(expr: ArrayTensorProduct):
    """Convert every factor, then rebuild the product with ``_a2m_tensor_product``."""
    converted_factors = map(_array2matrix, expr.args)
    return _a2m_tensor_product(*converted_factors)
def _(expr: ArrayContraction):
    """Convert an ``ArrayContraction`` into matrix products, traces and Hadamard products.

    The contraction is first normalized: diagonals are flattened, multiple
    contractions are split, and Hadamard products are identified. Then:

    - a contraction of a tensor product is rewritten into matrix
      multiplication lines by ``_support_function_tp1_recognize``. Sums among
      the factors are first distributed so that every addend is a tensor
      product of matrices;
    - a contraction ``(0, 1)`` of a plain matrix becomes its trace. Other
      contractions of a non-array expression stay an ``ArrayContraction``;
    - anything else (e.g. a contraction of another array operation that this
      converter cannot handle) is returned unchanged instead of ``None``.
    """
    expr = expr.flatten_contraction_of_diagonal()
    expr = expr.split_multiple_contractions()
    expr = identify_hadamard_products(expr)
    if not isinstance(expr, ArrayContraction):
        return _array2matrix(expr)
    subexpr = expr.expr
    contraction_indices: Tuple[Tuple[int]] = expr.contraction_indices
    if isinstance(subexpr, ArrayTensorProduct):
        newexpr = ArrayContraction(_array2matrix(subexpr), *contraction_indices)
        contraction_indices = newexpr.contraction_indices
        if any(i > 2 for i in newexpr.subranks):
            # Some factor is not a matrix (e.g. an ArrayAdd of matrices):
            # distribute the tensor product over the sums so that each addend
            # is a tensor product of plain factors.
            factor_choices = [arg.args if isinstance(arg, ArrayAdd) else [arg] for arg in subexpr.args]
            addends = ArrayAdd(*[_a2m_tensor_product(*combo) for combo in itertools.product(*factor_choices)])
            newexpr = ArrayContraction(addends, *contraction_indices)
        if isinstance(newexpr, ArrayAdd):
            return _array2matrix(newexpr)
        assert isinstance(newexpr, ArrayContraction)
        return _support_function_tp1_recognize(contraction_indices, list(newexpr.expr.args))
    elif not isinstance(subexpr, _CodegenArrayAbstract):
        ret = _array2matrix(subexpr)
        if isinstance(ret, MatrixExpr) and expr.contraction_indices == ((0, 1),):
            return _a2m_trace(ret)
        # Previously an `assert` enforced ((0, 1),) for matrices. Any other
        # contraction of a matrix is now kept as an explicit array contraction.
        return ArrayContraction(ret, *expr.contraction_indices)
    # Contraction of an array operation not handled above: previously this
    # fell off the end and returned None. Keep the expression unchanged,
    # as the generic `_array2matrix` fallback does.
    return expr
def _(expr: ArrayDiagonal):
    """Convert an ``ArrayDiagonal``.

    The inner expression is converted and Hadamard products are identified.
    A remaining diagonal of vectors is then rewritten as a contraction with a
    ``DiagMatrix``. If nothing changed, ``expr`` is returned as is; otherwise
    conversion recurses on the rewritten expression.
    """
    rewritten = identify_hadamard_products(
        ArrayDiagonal(_array2matrix(expr.expr), *expr.diagonal_indices)
    )
    if isinstance(rewritten, ArrayDiagonal):
        rewritten = _array_diag2contr_diagmatrix(rewritten)
    return expr if rewritten == expr else _array2matrix(rewritten)
def _(expr: PermuteDims):
    """Convert a ``PermuteDims`` expression into matrix form where possible.

    - The rank-2 permutation ``[1, 0]`` becomes a matrix transpose.
    - For a permuted ``ArrayTensorProduct`` the permutation is pushed into
      the factors. Factors whose axes stay in order are kept, rank-2 factors
      whose axes are swapped are transposed, and the remaining reordering of
      the factors is kept as an outer ``PermuteDims``.
    - For a permuted ``ArrayContraction`` the contraction is converted first.
      If it becomes a tensor product of matrix multiplication lines, the
      permutation is absorbed by reordering/transposing the lines when each
      line's two axes stay together. A permutation that only moves axes of
      dimension 1 is dropped. Otherwise the permutation is kept.

    Anything else is returned unchanged.

    Raises ``NotImplementedError`` when a factor of rank > 2 of a tensor
    product has its axes permuted.
    """
    if expr.permutation.array_form == [1, 0]:
        return _a2m_transpose(_array2matrix(expr.expr))
    elif isinstance(expr.expr, ArrayTensorProduct):
        ranks = expr.expr.subranks
        inv_permutation = expr.permutation**(-1)
        newrange = [inv_permutation(i) for i in range(sum(ranks))]
        # Split the new positions into one block per tensor product factor:
        newpos = []
        counter = 0
        for rank in ranks:
            newpos.append(newrange[counter:counter+rank])
            counter += rank
        newargs = []
        newperm = []
        scalars = []
        for pos, arg in zip(newpos, expr.expr.args):
            if len(pos) == 0:
                scalars.append(_array2matrix(arg))
            elif pos == sorted(pos):
                newargs.append(_array2matrix(arg))
                newperm.extend(pos)
            elif len(pos) == 2:
                # Swapped axes of a rank-2 factor: transpose it.
                newargs.append(_a2m_transpose(_array2matrix(arg)))
                newperm.extend(reversed(pos))
            else:
                raise NotImplementedError()
        return PermuteDims(_a2m_tensor_product(*scalars, *newargs), _af_invert(newperm))
    elif isinstance(expr.expr, ArrayContraction):
        mat_mul_lines = _array2matrix(expr.expr)
        if not isinstance(mat_mul_lines, ArrayTensorProduct):
            flat_cyclic_form = [j for i in expr.permutation.cyclic_form for j in i]
            expr_shape = get_shape(expr)
            if all(expr_shape[i] == 1 for i in flat_cyclic_form):
                # Only axes of dimension 1 are moved: the permutation does
                # not affect the content and can be dropped.
                return mat_mul_lines
            # Non-trivial permutation: keep it. Previously `mat_mul_lines`
            # was returned here too, silently discarding the permutation.
            return PermuteDims(mat_mul_lines, expr.permutation)
        # TODO: this assumes that all arguments are matrices, it may not be the case:
        n_lines = len(mat_mul_lines.args)
        permutation = Permutation(2*n_lines-1)*expr.permutation
        permuted = [permutation(i) for i in range(2*n_lines)]
        args_array = [None for i in mat_mul_lines.args]
        for i in range(n_lines):
            p1 = permuted[2*i]
            p2 = permuted[2*i+1]
            if p1 // 2 != p2 // 2:
                # The two axes of this slot come from different lines: the
                # permutation cannot be absorbed into the lines.
                return PermuteDims(mat_mul_lines, permutation)
            pos = p1 // 2
            if p1 > p2:
                args_array[i] = _a2m_transpose(mat_mul_lines.args[pos])
            else:
                args_array[i] = mat_mul_lines.args[pos]
        return _a2m_tensor_product(*args_array)
    else:
        return expr
def _(expr: ArrayAdd):
    """Convert every addend, then sum them with ``_a2m_add``."""
    return _a2m_add(*map(_array2matrix, expr.args))
def _(expr: ArrayElementwiseApplyFunc):
    """Convert the argument.

    A matrix result is wrapped in ``ElementwiseApplyFunction``; any other
    result is wrapped back in ``ArrayElementwiseApplyFunc``.
    """
    converted = _array2matrix(expr.expr)
    wrapper = ElementwiseApplyFunction if isinstance(converted, MatrixExpr) else ArrayElementwiseApplyFunc
    return wrapper(expr.function, converted)
def _remove_trivial_dims(expr):
    """Drop dimensions of size 1 from ``expr``.

    Default case: nothing is removed. Returns the pair ``(expr, [])``, where
    the list holds the removed axis positions.
    """
    no_axes_removed = []
    return expr, no_axes_removed
def _(expr: ArrayTensorProduct):
    """Drop trivial (size-1) dimensions from a tensor product.

    Recognize expressions like ``[x, y]`` with shape ``(k, 1, k, 1)`` as
    ``x*y.T``. The matrix expression has to be equivalent to the tensor
    product of the matrices with trivial dimensions (i.e. dim=1) dropped,
    which amounts to adding contractions over those trivial dimensions.

    Only dimensions of size 1 are ever reported as removed. Identity matrices
    other than ``Identity(1)`` are kept as regular factors.

    Returns
    =======

    ``(newexpr, removed)``, where ``removed`` is the sorted list of axis
    positions of ``expr`` that have been dropped.
    """
    removed = []
    newargs = []
    cumul = list(accumulate([0] + [get_rank(arg) for arg in expr.args]))
    # Invariant: when ``pending`` is not None, ``newargs[-1]`` is a vector (a
    # matrix with one dimension equal to 1). Its other dimension is
    # ``pending`` and it comes from ``expr.args[prev_i]``. Only such an
    # *adjacent* vector may be merged with the next vector; merging across
    # other factors would change the axis order of the tensor product.
    pending = None
    prev_i = None
    for i, arg in enumerate(expr.args):
        current_range = list(range(cumul[i], cumul[i+1]))
        if isinstance(arg, OneArray):
            # NOTE(review): all axes of a OneArray are dropped regardless of
            # their size - confirm callers only produce OneArray(1) here.
            removed.extend(current_range)
            continue
        if not isinstance(arg, (MatrixExpr, MatrixCommon)):
            rarg, rem = _remove_trivial_dims(arg)
            # `rem` is relative to the axes of `arg`: shift it into the axis
            # numbering of the whole tensor product.
            removed.extend(cumul[i] + j for j in rem)
            newargs.append(rarg)
            pending = None
            continue
        elif getattr(arg, "is_Identity", False) and arg.shape == (1, 1):
            # Identity matrices of shape (1, 1) are equivalent to scalar 1.
            removed.extend(current_range)
            continue
        elif arg.shape == (1, 1):
            arg, _removed = _remove_trivial_dims(arg)
            # Matrix is equivalent to scalar:
            if len(newargs) == 0:
                newargs.append(arg)
                pending = None
            elif 1 in get_shape(newargs[-1]):
                # Absorb the 1x1 factor into the previous vector; the
                # previous vector keeps its shape, so `pending` stays valid.
                if newargs[-1].shape[1] == 1:
                    newargs[-1] = newargs[-1]*arg
                else:
                    newargs[-1] = arg*newargs[-1]
                removed.extend(current_range)
            else:
                newargs.append(arg)
                pending = None
        elif 1 in arg.shape:
            k = [dim for dim in arg.shape if dim != 1][0]
            if pending is None:
                pending = k
                prev_i = i
                newargs.append(arg)
            elif pending == k:
                # Two adjacent vectors of the same length: outer product.
                prev = newargs[-1]
                if prev.shape[0] == 1:
                    d1 = cumul[prev_i]
                    prev = _a2m_transpose(prev)
                else:
                    d1 = cumul[prev_i] + 1
                if arg.shape[1] == 1:
                    d2 = cumul[i] + 1
                    arg = _a2m_transpose(arg)
                else:
                    d2 = cumul[i]
                newargs[-1] = prev*arg
                pending = None
                removed.extend([d1, d2])
            else:
                newargs.append(arg)
                pending = k
                prev_i = i
        else:
            # Generic matrix (including identities of size != 1).
            newargs.append(arg)
            pending = None
    return _a2m_tensor_product(*newargs), sorted(removed)
def _(expr: ArrayAdd):
    """Remove trivial dimensions from every addend.

    The reduction is applied only when all addends dropped exactly the same
    axes. Otherwise ``expr`` is returned untouched with nothing removed.
    """
    simplified_args, removed_per_arg = zip(*(_remove_trivial_dims(arg) for arg in expr.args))
    distinct_removals = {tuple(r) for r in removed_per_arg}
    if len(distinct_removals) == 1:
        return _a2m_add(*simplified_args), removed_per_arg[0]
    return expr, []
def _(expr: PermuteDims):
    """Remove trivial dimensions below a permutation.

    The permutation is restricted to the surviving axes and renumbered. The
    removed axis positions are mapped through the inverse permutation.
    """
    subexpr, subremoved = _remove_trivial_dims(expr.expr)
    array_form = expr.permutation.array_form
    inverse_form = _af_invert(array_form)
    # removed_upto[j] = number of removed axes at positions <= j
    removed_upto = list(accumulate(1 if axis in subremoved else 0 for axis in range(len(array_form))))
    removed_outer = sorted(inverse_form[axis] for axis in subremoved)
    reduced_perm = [target - removed_upto[target] for target in array_form if target not in subremoved]
    # TODO: check if subremoved should be permuted as well...
    newexpr = PermuteDims(subexpr, reduced_perm)
    if newexpr != expr:
        newexpr = _array2matrix(newexpr)
    return newexpr, removed_outer
def _(expr: ArrayContraction):
    """Remove trivial dimensions below a contraction.

    Removed axes are dropped from the contraction groups, and groups that
    become empty are discarded. The surviving axes are renumbered. Only
    removed axes that were free (not contracted) are reported, after being
    shifted to the axis numbering of the contraction result.
    """
    newexpr, inner_removed = _remove_trivial_dims(expr.expr)
    # removed_upto[j] = number of removed axes at positions <= j
    removed_upto = list(accumulate(1 if axis in inner_removed else 0 for axis in range(get_rank(expr.expr))))
    surviving_groups = []
    for group in expr.contraction_indices:
        kept = tuple(j for j in group if j not in inner_removed)
        if kept:
            surviving_groups.append(tuple(j - removed_upto[j] for j in kept))
    contracted_axes = {j for group in expr.contraction_indices for j in group}
    free_removed = [j for j in inner_removed if j not in contracted_axes]
    free_removed = ArrayContraction._push_indices_up(expr.contraction_indices, free_removed)
    return ArrayContraction(newexpr, *surviving_groups), list(free_removed)
def _(expr: ArrayDiagonal):
    """Remove trivial dimensions below a diagonal.

    Removed axes are dropped from the diagonal groups and the rest are
    renumbered. Groups left with a single axis lost their partner to removal
    and no longer need diagonalization. Removed positions are reported sorted
    and without duplicates.
    """
    newexpr, inner_removed = _remove_trivial_dims(expr.expr)
    inner_rank = get_rank(expr.expr)
    # removed_before[j] = number of removed axes strictly before position j
    removed_before = list(accumulate([0] + [1 if axis in inner_removed else 0 for axis in range(inner_rank)]))
    shifted_groups = []
    for group in expr.diagonal_indices:
        shifted = tuple(j - removed_before[j] for j in group if j not in inner_removed)
        if len(shifted) > 1:
            shifted_groups.append(shifted)
    outer_removed = ArrayDiagonal._push_indices_up(expr.diagonal_indices, inner_removed, inner_rank)
    return ArrayDiagonal(newexpr, *shifted_groups), sorted(set(outer_removed))
def _(expr: ElementwiseApplyFunction):
    """Remove trivial dimensions below an elementwise function application.

    A 1x1 argument is treated as a scalar: the function is applied directly
    and both of its axes are reported as removed.

    Always returns a ``(newexpr, removed)`` pair, like every other
    ``_remove_trivial_dims`` implementation. The non-scalar branch used to
    return a bare expression, which breaks callers that unpack the pair.
    """
    subexpr, removed = _remove_trivial_dims(expr.expr)
    if subexpr.shape == (1, 1):
        # TODO: move this to ElementwiseApplyFunction
        return expr.function(subexpr), removed + [0, 1]
    return ElementwiseApplyFunction(expr.function, subexpr), removed
def _(expr: ArrayElementwiseApplyFunc):
    """Remove trivial dimensions from the argument and re-wrap it with the same function."""
    inner, dropped_axes = _remove_trivial_dims(expr.expr)
    rewrapped = ArrayElementwiseApplyFunc(expr.function, inner)
    return rewrapped, dropped_axes
def convert_array_to_matrix(expr):
    r"""
    Recognize matrix expressions in codegen objects.

    The array expression is first converted into matrix products, traces,
    transposes and Hadamard products where possible. Dimensions of size 1
    are then dropped.

    If more than one line of matrix multiplications is detected, the lines
    are returned as separate factors of an ``ArrayTensorProduct`` (see the
    last example).

    Examples
    ========

    >>> from sympy.tensor.array.expressions.conv_indexed_to_array import convert_indexed_to_array
    >>> from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct
    >>> from sympy import MatrixSymbol, Sum
    >>> from sympy.abc import i, j, k, l, N
    >>> from sympy.tensor.array.expressions.array_expressions import ArrayContraction
    >>> from sympy.tensor.array.expressions.conv_matrix_to_array import convert_matrix_to_array
    >>> from sympy.tensor.array.expressions.conv_array_to_matrix import convert_array_to_matrix
    >>> A = MatrixSymbol("A", N, N)
    >>> B = MatrixSymbol("B", N, N)
    >>> C = MatrixSymbol("C", N, N)
    >>> D = MatrixSymbol("D", N, N)

    >>> expr = Sum(A[i, j]*B[j, k], (j, 0, N-1))
    >>> cg = convert_indexed_to_array(expr)
    >>> convert_array_to_matrix(cg)
    A*B
    >>> cg = convert_indexed_to_array(expr, first_indices=[k])
    >>> convert_array_to_matrix(cg)
    B.T*A.T

    Transposition is detected:

    >>> expr = Sum(A[j, i]*B[j, k], (j, 0, N-1))
    >>> cg = convert_indexed_to_array(expr)
    >>> convert_array_to_matrix(cg)
    A.T*B
    >>> cg = convert_indexed_to_array(expr, first_indices=[k])
    >>> convert_array_to_matrix(cg)
    B.T*A

    Detect the trace:

    >>> expr = Sum(A[i, i], (i, 0, N-1))
    >>> cg = convert_indexed_to_array(expr)
    >>> convert_array_to_matrix(cg)
    Trace(A)

    Recognize some more complex traces:

    >>> expr = Sum(A[i, j]*B[j, i], (i, 0, N-1), (j, 0, N-1))
    >>> cg = convert_indexed_to_array(expr)
    >>> convert_array_to_matrix(cg)
    Trace(A*B)

    More complicated expressions:

    >>> expr = Sum(A[i, j]*B[k, j]*A[l, k], (j, 0, N-1), (k, 0, N-1))
    >>> cg = convert_indexed_to_array(expr)
    >>> convert_array_to_matrix(cg)
    A*B.T*A.T

    Expressions constructed from matrix expressions do not contain literal
    indices, the positions of free indices are returned instead:

    >>> expr = A*B
    >>> cg = convert_matrix_to_array(expr)
    >>> convert_array_to_matrix(cg)
    A*B

    If more than one line of matrix multiplications is detected, return
    separate matrix multiplication factors embedded in a tensor product object:

    >>> cg = ArrayContraction(ArrayTensorProduct(A, B, C, D), (1, 2), (5, 6))
    >>> convert_array_to_matrix(cg)
    ArrayTensorProduct(A*B, C*D)

    The two lines have free indices at axes 0, 3 and 4, 7, respectively.
    """
    matrix_form = _array2matrix(expr)
    matrix_form, _dropped_axes = _remove_trivial_dims(matrix_form)
    return matrix_form
def _array_diag2contr_diagmatrix(expr: ArrayDiagonal):
    """Rewrite diagonals involving vectors as contractions with ``DiagMatrix``.

    Only applies when ``expr.expr`` is an ``ArrayTensorProduct``; any other
    ``expr`` is returned unchanged. The rewrite targets each diagonal group
    of exactly two axes where both axes belong to rank-2 factors and one of
    those factors has size 1 along its other axis (a vector). For such a
    group:

    - ``DiagMatrix(vector)`` is appended to the factor list;
    - the group is replaced by a contraction between the other factor's axis
      and the appended ``DiagMatrix``;
    - the vector's slot is replaced by ``OneArray`` of the vector's size-1
      dimension.

    Other groups are kept as diagonals. The result is
    ``ArrayDiagonal(ArrayContraction(ArrayTensorProduct(*args), ...), ...)``.
    """
    if isinstance(expr.expr, ArrayTensorProduct):
        args = list(expr.expr.args)
        diag_indices = list(expr.diagonal_indices)
        # Maps an absolute axis to (factor position, axis within that factor),
        # as unpacked below.
        mapping = _get_mapping_from_subranks([_get_subrank(arg) for arg in args])
        tuple_links = [[mapping[j] for j in i] for i in diag_indices]
        contr_indices = []
        # NOTE(review): `total_rank` is updated below but never read.
        total_rank = get_rank(expr)
        # replaced[p] is True once args[p] has been swapped for a OneArray.
        replaced = [False for arg in args]
        for i, (abs_pos, rel_pos) in enumerate(zip(diag_indices, tuple_links)):
            if len(abs_pos) != 2:
                continue
            (pos1_outer, pos1_inner), (pos2_outer, pos2_inner) = rel_pos
            arg1 = args[pos1_outer]
            arg2 = args[pos2_outer]
            if get_rank(arg1) != 2 or get_rank(arg2) != 2:
                # A factor already replaced by a (rank-1) OneArray: its
                # diagonal group is dropped.
                # NOTE(review): presumably because the contraction with the
                # DiagMatrix already accounts for it - confirm.
                if replaced[pos1_outer]:
                    diag_indices[i] = None
                if replaced[pos2_outer]:
                    diag_indices[i] = None
                continue
            # The axis of each factor that is *not* part of the diagonal:
            pos1_in2 = 1 - pos1_inner
            pos2_in2 = 1 - pos2_inner
            if arg1.shape[pos1_in2] == 1:
                # arg1 is a vector: contract arg2 with DiagMatrix(arg1).
                darg1 = DiagMatrix(arg1)
                args.append(darg1)
                contr_indices.append(((pos2_outer, pos2_inner), (len(args)-1, pos1_inner)))
                total_rank += 1
                diag_indices[i] = None
                args[pos1_outer] = OneArray(arg1.shape[pos1_in2])
                replaced[pos1_outer] = True
            elif arg2.shape[pos2_in2] == 1:
                # arg2 is a vector: contract arg1 with DiagMatrix(arg2).
                darg2 = DiagMatrix(arg2)
                args.append(darg2)
                contr_indices.append(((pos1_outer, pos1_inner), (len(args)-1, pos2_inner)))
                total_rank += 1
                diag_indices[i] = None
                args[pos2_outer] = OneArray(arg2.shape[pos2_in2])
                replaced[pos2_outer] = True
        # NOTE(review): the surviving diagonal groups keep the axis positions
        # of the original tensor product, although factors may have changed
        # rank (OneArray) and new factors/contractions were added - verify
        # these positions still address the intended axes.
        diag_indices_new = [i for i in diag_indices if i is not None]
        # Convert (factor, axis-in-factor) pairs to absolute axis positions
        # in the new tensor product:
        cumul = list(accumulate([0] + [get_rank(arg) for arg in args]))
        contr_indices2 = [tuple(cumul[a] + b for a, b in i) for i in contr_indices]
        tc = ArrayContraction(
            ArrayTensorProduct(*args), *contr_indices2
        )
        td = ArrayDiagonal(tc, *diag_indices_new)
        return td
    return expr
def _a2m_mul(*args):
    """Multiply ``args`` as matrices.

    With only non-array arguments this is an evaluated ``MatMul``. If any
    argument is an array expression, the product is expressed as a tensor
    product contracted over consecutive inner axes: ``(1, 2), (3, 4), ...``.
    """
    if any(isinstance(arg, _CodegenArrayAbstract) for arg in args):
        inner_pairs = [(2*i - 1, 2*i) for i in range(1, len(args))]
        return ArrayContraction(ArrayTensorProduct(*args), *inner_pairs)
    from sympy import MatMul
    return MatMul(*args).doit()
def _a2m_tensor_product(*args):
    """Build a tensor product, folding scalar factors together.

    Non-array arguments are multiplied into a single scalar. Without any
    array arguments that scalar is returned. Otherwise a non-unit scalar is
    multiplied into the first factor when that factor is a matrix/array
    expression, or prepended as a separate factor when it is a codegen array
    object.
    """
    array_types = (MatrixExpr, _ArrayExpr, _CodegenArrayAbstract)
    arrays = [arg for arg in args if isinstance(arg, array_types)]
    scalar = Mul.fromiter(arg for arg in args if not isinstance(arg, array_types))
    if not arrays:
        return scalar
    if scalar != 1:
        if isinstance(arrays[0], _CodegenArrayAbstract):
            arrays.insert(0, scalar)
        else:
            arrays[0] = arrays[0] * scalar
    return ArrayTensorProduct(*arrays)
def _a2m_add(*args):
    """Sum ``args``: ``ArrayAdd`` if any is an array expression, else an evaluated ``MatAdd``."""
    if any(isinstance(arg, _CodegenArrayAbstract) for arg in args):
        return ArrayAdd(*args)
    from sympy import MatAdd
    return MatAdd(*args).doit()
def _a2m_trace(arg):
    """Trace of ``arg``: a ``(0, 1)`` contraction for array expressions, else ``Trace``."""
    if not isinstance(arg, _CodegenArrayAbstract):
        return Trace(arg)
    return ArrayContraction(arg, (0, 1))
def _a2m_transpose(arg):
    """Transpose ``arg``.

    Array expressions become ``PermuteDims(arg, [1, 0])``; anything else
    becomes an evaluated matrix ``Transpose``.
    """
    if not isinstance(arg, _CodegenArrayAbstract):
        return Transpose(arg).doit()
    return PermuteDims(arg, [1, 0])
def identify_hadamard_products(expr: Union[ArrayContraction, ArrayDiagonal]):
    """Detect Hadamard products (and traces of them) and rewrite them.

    Works on an ``_EditArrayContraction`` built from ``expr``. Arguments
    whose indices are all bound (no free ``None`` index) are grouped by the
    set of their indices. Groups of two or more matrices sharing the same two
    indices are replaced by ``hadamard_product(...)``. Fully contracted
    groups are replaced by a ``Trace`` of such a product.

    For an ``ArrayDiagonal`` input, the diagonalized axes are temporarily
    stored in the editor as negative indices ``-1 - i``, where ``i`` is the
    diagonal group number. They are decoded again at the end to rebuild the
    ``ArrayDiagonal`` and a ``PermuteDims``.

    Returns
    =======

    - for an ``ArrayContraction``: ``editor.to_array_contraction()``;
    - for an ``ArrayDiagonal`` over an ``ArrayContraction`` or
      ``ArrayTensorProduct``: ``PermuteDims(ArrayDiagonal(...), permutation)``;
    - for any other ``ArrayDiagonal``: ``expr`` unchanged.
    """
    # Maps an absolute axis to (argument position, axis within argument).
    # NOTE(review): computed from the *outer* expression's subranks; for an
    # ArrayDiagonal input, verify this matches the editor's argument layout.
    mapping = _get_mapping_from_subranks(expr.subranks)

    editor: _EditArrayContraction
    if isinstance(expr, ArrayContraction):
        editor = _EditArrayContraction(expr)
    elif isinstance(expr, ArrayDiagonal):
        if isinstance(expr.expr, ArrayContraction):
            editor = _EditArrayContraction(expr.expr)
            # Diagonal axes expressed in the axis numbering below the contraction:
            diagonalized = ArrayContraction._push_indices_down(expr.expr.contraction_indices, expr.diagonal_indices)
        elif isinstance(expr.expr, ArrayTensorProduct):
            editor = _EditArrayContraction(None)
            editor.args_with_ind = [_ArgE(arg) for i, arg in enumerate(expr.expr.args)]
            diagonalized = expr.diagonal_indices
        else:
            return expr

        # Trick: add diagonalized indices as negative indices into the editor object:
        for i, e in enumerate(diagonalized):
            for j in e:
                arg_pos, rel_pos = mapping[j]
                editor.args_with_ind[arg_pos].indices[rel_pos] = -1 - i

    # Group fully bound arguments by their index set, and count how many
    # times each index (including None and negative diagonal markers) occurs.
    map_contr_to_args: Dict[FrozenSet, List[_ArgE]] = defaultdict(list)
    map_ind_to_inds = defaultdict(int)
    for arg_with_ind in editor.args_with_ind:
        for ind in arg_with_ind.indices:
            map_ind_to_inds[ind] += 1
        if None in arg_with_ind.indices:
            continue
        map_contr_to_args[frozenset(arg_with_ind.indices)].append(arg_with_ind)

    k: FrozenSet[int]
    v: List[_ArgE]
    for k, v in map_contr_to_args.items():
        make_trace: bool = False
        if len(k) == 1 and next(iter(k)) >= 0 and sum([next(iter(k)) in i for i in map_contr_to_args]) == 1:
            # This is a trace: the arguments are fully contracted with only one
            # index, and the index isn't used anywhere else:
            make_trace = True
            first_element = S.One
        elif len(k) != 2:
            # Hadamard product only defined for matrices:
            continue
        if len(v) == 1:
            # Hadamard product with a single argument makes no sense:
            continue
        # NOTE(review): `continue` here only continues this inner `for`, so
        # the loop has no effect; presumably it was meant to skip the group.
        for ind in k:
            if map_ind_to_inds[ind] <= 2:
                # There is no other contraction, skip:
                continue

        def check_transpose(x):
            # True when the indices (negative diagonal markers decoded back to
            # their group number) are in ascending order, i.e. the element is
            # used untransposed.
            x = [i if i >= 0 else -1-i for i in x]
            return x == sorted(x)

        # Check if expression is a trace:
        if all([map_ind_to_inds[j] == len(v) and j >= 0 for j in k]) and all([j >= 0 for j in k]):
            # This is a trace
            make_trace = True
            first_element = v[0].element
            if not check_transpose(v[0].indices):
                first_element = first_element.T
            hadamard_factors = v[1:]
        else:
            hadamard_factors = v

        # This is a Hadamard product:

        hp = hadamard_product(*[i.element if check_transpose(i.indices) else Transpose(i.element) for i in hadamard_factors])
        hp_indices = v[0].indices
        if not check_transpose(hadamard_factors[0].indices):
            hp_indices = list(reversed(hp_indices))
        if make_trace:
            hp = Trace(first_element*hp.T)._normalize()
            hp_indices = []
        # Replace the whole group by the single new argument:
        editor.insert_after(v[0], _ArgE(hp, hp_indices))
        for i in v:
            editor.args_with_ind.remove(i)

    # Running count of free axes (None or diagonal markers) of the arguments
    # processed so far, i.e. the axis offset of the current argument:
    counter = 0
    # Create a collector for the new diagonal indices:
    diag_indices = defaultdict(list)

    count_index_freq = Counter()
    for arg_with_ind in editor.args_with_ind:
        count_index_freq.update(Counter(arg_with_ind.indices))

    free_index_count = count_index_freq[None]

    # Construct the inverse permutation:
    # inv_perm1 gets free axes and diagonal markers occurring once;
    # inv_perm2 gets diagonal markers occurring more than once.
    # NOTE(review): presumably restores the original axis order after
    # ArrayDiagonal moves diagonal axes - confirm.
    inv_perm1 = []
    inv_perm2 = []
    # Keep track of which diagonal indices have already been processed:
    done = set([])

    # Counter for the free (None) axes seen so far:
    counter4 = 0

    for arg_with_ind in editor.args_with_ind:
        # If some diagonalization axes have been removed, they should be
        # permuted in order to keep the permutation.
        # Add permutation here
        counter2 = 0  # counter for the indices
        for i in arg_with_ind.indices:
            if i is None:
                inv_perm1.append(counter4)
                counter2 += 1
                counter4 += 1
                continue
            if i >= 0:
                continue
            # Reconstruct the diagonal indices:
            diag_indices[-1 - i].append(counter + counter2)
            if count_index_freq[i] == 1 and i not in done:
                inv_perm1.append(free_index_count - 1 - i)
                done.add(i)
            elif i not in done:
                inv_perm2.append(free_index_count - 1 - i)
                done.add(i)
            counter2 += 1
        # Remove negative indices to restore a proper editor object:
        arg_with_ind.indices = [i if i is not None and i >= 0 else None for i in arg_with_ind.indices]
        counter += len([i for i in arg_with_ind.indices if i is None or i < 0])

    inverse_permutation = inv_perm1 + inv_perm2
    # NOTE(review): `permutation` is only used in the ArrayDiagonal branch below.
    permutation = _af_invert(inverse_permutation)

    if isinstance(expr, ArrayContraction):
        return editor.to_array_contraction()
    else:
        # Get the diagonal indices after the detection of HadamardProduct in the expression:
        diag_indices_filtered = [tuple(v) for v in diag_indices.values() if len(v) > 1]

        expr1 = editor.to_array_contraction()
        expr2 = ArrayDiagonal(expr1, *diag_indices_filtered)
        expr3 = PermuteDims(expr2, permutation)
        return expr3