1import itertools
2from collections import defaultdict, Counter
3from typing import Tuple, Union, FrozenSet, Dict, List, Optional
4from functools import singledispatch
5from itertools import accumulate
6
7from sympy import Trace, MatrixExpr, Transpose, DiagMatrix, Mul, ZeroMatrix, hadamard_product, S
8from sympy.combinatorics.permutations import _af_invert, Permutation
9from sympy.matrices.common import MatrixCommon
10from sympy.matrices.expressions.applyfunc import ElementwiseApplyFunction
11from sympy.tensor.array.expressions.array_expressions import PermuteDims, ArrayDiagonal, \
12    ArrayTensorProduct, OneArray, get_rank, _get_subrank, ZeroArray, ArrayContraction, \
13    ArrayAdd, _CodegenArrayAbstract, get_shape, ArrayElementwiseApplyFunc, _ArrayExpr, _EditArrayContraction, _ArgE
14from sympy.tensor.array.expressions.utils import _get_mapping_from_subranks
15
16
def _get_candidate_for_matmul_from_contraction(scan_indices: List[Optional[int]], remaining_args: List[_ArgE]) -> Tuple[Optional[_ArgE], bool, int]:
    """
    Look for a matrix argument that can be multiplied through a shared index.

    Parameters
    ==========

    scan_indices : contraction indices to look for; ``None`` entries are
        discarded before the search.
    remaining_args : arguments (objects exposing ``element`` and ``indices``)
        to search through.

    Returns
    =======

    A triple ``(candidate, transpose, candidate_index)``.  ``candidate`` is
    the selected argument or ``None``; ``transpose`` is ``True`` when the
    matching index was the second index of the selected argument;
    ``candidate_index`` is the last index that was selected, or ``-1`` if
    no index was ever selected.
    """
    usable = [idx for idx in scan_indices if idx is not None]
    if not usable:
        return None, False, -1

    match: Optional[_ArgE] = None
    match_index: int = -1
    flip: bool = False
    for other in remaining_args:
        if not isinstance(other.element, MatrixExpr):
            continue
        for idx in usable:
            # Once an index has been selected, only repetitions of that
            # same index are examined.
            if match_index != -1 and idx != match_index:
                continue
            if idx not in other.indices:
                continue
            # Reject when the index fills every slot of this argument, or
            # when another argument was already selected for it.
            if set(other.indices) == {idx} or match is not None:
                match = None
                break
            match = other
            match_index = idx
            flip = (idx == other.indices[1])
    return match, flip, match_index
48
49
def _insert_candidate_into_editor(editor: _EditArrayContraction, arg_with_ind: _ArgE, candidate: _ArgE, transpose1: bool, transpose2: bool):
    """
    Multiply ``arg_with_ind`` by ``candidate`` and drop ``candidate`` from the editor.

    ``transpose1`` / ``transpose2`` select whether the left / right factor is
    wrapped in ``Transpose`` before the product is formed.

    Returns a pair ``(new_arge, other_index)``: a fresh ``_ArgE`` holding the
    product, and the index of ``candidate`` that was not used for the product
    (its first index when ``transpose2`` is set, its second otherwise).
    """
    if transpose2:
        right = Transpose(candidate.element)
        other_index: int = candidate.indices[0]
    else:
        right = candidate.element
        other_index = candidate.indices[1]

    left = arg_with_ind.element
    if transpose1:
        left = Transpose(left)

    product = left * right
    editor.args_with_ind.remove(candidate)
    return _ArgE(product), other_index
62
63
def _support_function_tp1_recognize(contraction_indices, args):
    """
    Recognize matrix products and traces inside a contracted tensor product.

    Parameters
    ==========

    contraction_indices : index groups forwarded to ``ArrayContraction``.
    args : factors of the tensor product being contracted.

    Returns
    =======

    ``_a2m_tensor_product(*args)`` when there is nothing to contract;
    otherwise the expression produced by ``editor.to_array_contraction()``
    after repeatedly merging pairs of matrix arguments into products and
    replacing self-contracted matrices with ``Trace``.
    """
    if len(contraction_indices) == 0:
        return _a2m_tensor_product(*args)

    ac = ArrayContraction(ArrayTensorProduct(*args), *contraction_indices)
    editor = _EditArrayContraction(ac)
    # NOTE(review): presumably records the original argument order so that
    # merges can be tracked as a permutation -- confirm in _EditArrayContraction.
    editor.track_permutation_start()

    # Keep rescanning the argument list until a full pass makes no change.
    while True:
        flag_stop: bool = True
        for i, arg_with_ind in enumerate(editor.args_with_ind):
            if not isinstance(arg_with_ind.element, MatrixExpr):
                continue

            first_index = arg_with_ind.indices[0]
            second_index = arg_with_ind.indices[1]

            first_frequency = editor.count_args_with_index(first_index)
            second_frequency = editor.count_args_with_index(second_index)

            # Both slots carry the same contraction index and no other
            # argument uses it: the matrix is contracted with itself.
            if first_index is not None and first_frequency == 1 and first_index == second_index:
                flag_stop = False
                arg_with_ind.element = Trace(arg_with_ind.element)._normalize()
                arg_with_ind.indices = []
                break

            # Only indices reported with frequency 2 are considered for a
            # matrix multiplication.
            scan_indices = []
            if first_frequency == 2:
                scan_indices.append(first_index)
            if second_frequency == 2:
                scan_indices.append(second_index)

            # Search only the arguments after the current one.
            candidate, transpose, found_index = _get_candidate_for_matmul_from_contraction(scan_indices, editor.args_with_ind[i+1:])
            if candidate is not None:
                flag_stop = False
                editor.track_permutation_merge(arg_with_ind, candidate)
                # If the shared index is this matrix's first index, the matrix
                # is transposed before being used as the left factor.
                transpose1 = found_index == first_index
                new_arge, other_index = _insert_candidate_into_editor(editor, arg_with_ind, candidate, transpose1, transpose)
                # The product keeps the index of this matrix that was not
                # shared, followed by the candidate's remaining index.
                if found_index == first_index:
                    new_arge.indices = [second_index, other_index]
                else:
                    new_arge.indices = [first_index, other_index]
                set_indices = set(new_arge.indices)
                if len(set_indices) == 1 and set_indices != {None}:
                    # This is a trace:
                    new_arge.element = Trace(new_arge.element)._normalize()
                    new_arge.indices = []
                editor.args_with_ind[i] = new_arge
                # TODO: is this break necessary?
                # NOTE(review): the list being enumerated has just been
                # mutated (the candidate was removed and slot ``i`` replaced),
                # so the break restarts the scan on the updated list.
                break

        if flag_stop:
            break

    editor.refresh_indices()
    return editor.to_array_contraction()
120
121
@singledispatch
def _array2matrix(expr):
    """
    Default conversion handler: hand back ``expr`` untouched.

    Type-specific conversions are attached with ``_array2matrix.register``;
    any type without a registered handler ends up here.
    """
    return expr
125
126
@_array2matrix.register(ZeroArray)
def _(expr: ZeroArray):
    """Turn a rank-2 ``ZeroArray`` into a ``ZeroMatrix``; leave other ranks as they are."""
    if get_rank(expr) != 2:
        return expr
    return ZeroMatrix(*expr.shape)
133
134
@_array2matrix.register(ArrayTensorProduct)
def _(expr: ArrayTensorProduct):
    """Convert every factor of the tensor product, then rebuild the product."""
    converted_factors = [_array2matrix(factor) for factor in expr.args]
    return _a2m_tensor_product(*converted_factors)
138
139
@_array2matrix.register(ArrayContraction)
def _(expr: ArrayContraction):
    """
    Convert an ``ArrayContraction`` to matrix form where a pattern is recognized.

    The expression is first normalized (contractions of diagonals flattened,
    multiple contractions split, Hadamard products identified).  Afterwards:

    * if the result is no longer an ``ArrayContraction`` it is dispatched
      again through ``_array2matrix``;
    * a contraction of an ``ArrayTensorProduct`` is handed to
      ``_support_function_tp1_recognize``;
    * a contraction of an operand that is not an array-expression node is
      turned into a trace when the operand converts to a ``MatrixExpr``,
      otherwise the contraction is rebuilt around the converted operand;
    * any other operand is left alone and the normalized expression is
      returned unchanged.  (This case used to fall off the end of the
      function and implicitly return ``None``.)
    """
    expr = expr.flatten_contraction_of_diagonal()
    expr = expr.split_multiple_contractions()
    expr = identify_hadamard_products(expr)
    if not isinstance(expr, ArrayContraction):
        # Normalization produced a different node type: dispatch again.
        return _array2matrix(expr)
    subexpr = expr.expr
    contraction_indices: Tuple[Tuple[int]] = expr.contraction_indices
    if isinstance(subexpr, ArrayTensorProduct):
        newexpr = ArrayContraction(_array2matrix(subexpr), *contraction_indices)
        contraction_indices = newexpr.contraction_indices
        if any(i > 2 for i in newexpr.subranks):
            # Some factor has rank above 2: expand sums among the original
            # factors into a sum of tensor products before contracting.
            factors = [i.args if isinstance(i, ArrayAdd) else [i] for i in expr.expr.args]
            addends = ArrayAdd(*[_a2m_tensor_product(*j) for j in itertools.product(*factors)])
            newexpr = ArrayContraction(addends, *contraction_indices)
        if isinstance(newexpr, ArrayAdd):
            ret = _array2matrix(newexpr)
            return ret
        assert isinstance(newexpr, ArrayContraction)
        ret = _support_function_tp1_recognize(contraction_indices, list(newexpr.expr.args))
        return ret
    elif not isinstance(subexpr, _CodegenArrayAbstract):
        ret = _array2matrix(subexpr)
        if isinstance(ret, MatrixExpr):
            assert expr.contraction_indices == ((0, 1),)
            return _a2m_trace(ret)
        else:
            return ArrayContraction(ret, *expr.contraction_indices)
    # No recognizer applies to this operand: return the (normalized)
    # contraction as is, like the default handler does, rather than ``None``.
    return expr
169
170
@_array2matrix.register(ArrayDiagonal)
def _(expr: ArrayDiagonal):
    """
    Convert an ``ArrayDiagonal``.

    The operand is converted first, Hadamard products are identified, and a
    result that is still an ``ArrayDiagonal`` is passed through
    ``_array_diag2contr_diagmatrix``.  If nothing changed, ``expr`` itself is
    returned; otherwise the rewritten expression is converted again.
    """
    converted_operand = _array2matrix(expr.expr)
    rewritten = identify_hadamard_products(
        ArrayDiagonal(converted_operand, *expr.diagonal_indices)
    )
    if isinstance(rewritten, ArrayDiagonal):
        rewritten = _array_diag2contr_diagmatrix(rewritten)
    # Stop recursing once a fixed point is reached.
    return expr if expr == rewritten else _array2matrix(rewritten)
180
181
@_array2matrix.register(PermuteDims)
def _(expr: PermuteDims):
    """
    Convert a ``PermuteDims`` expression.

    * The permutation ``[1, 0]`` becomes a transposition of the converted
      operand.
    * For an ``ArrayTensorProduct`` operand the permutation is pushed into
      the factors where possible (a factor whose two axes are swapped is
      transposed) and the remaining permutation is applied on the outside.
      Raises ``NotImplementedError`` for a factor whose axes are reordered
      in a way that is not a plain swap of two axes.
    * For an ``ArrayContraction`` operand the operand is converted first.  If
      the result is not an ``ArrayTensorProduct``, the permutation is dropped
      only when every axis it moves has dimension 1; otherwise the converted
      operand is wrapped in ``PermuteDims`` again.  (Previously the
      permutation was discarded in both cases.)
    * Anything else is returned unchanged.
    """
    if expr.permutation.array_form == [1, 0]:
        return _a2m_transpose(_array2matrix(expr.expr))
    elif isinstance(expr.expr, ArrayTensorProduct):
        ranks = expr.expr.subranks
        inv_permutation = expr.permutation**(-1)
        newrange = [inv_permutation(i) for i in range(sum(ranks))]
        # Split the inverse-permuted axis positions per factor.
        newpos = []
        counter = 0
        for rank in ranks:
            newpos.append(newrange[counter:counter+rank])
            counter += rank
        newargs = []
        newperm = []
        scalars = []
        for pos, arg in zip(newpos, expr.expr.args):
            if len(pos) == 0:
                scalars.append(_array2matrix(arg))
            elif pos == sorted(pos):
                newargs.append(_array2matrix(arg))
                newperm.extend(pos)
            elif len(pos) == 2:
                # The two axes of this factor are swapped: transpose it.
                newargs.append(_a2m_transpose(_array2matrix(arg)))
                newperm.extend(reversed(pos))
            else:
                raise NotImplementedError()
        return PermuteDims(_a2m_tensor_product(*scalars, *newargs), _af_invert(newperm))
    elif isinstance(expr.expr, ArrayContraction):
        mat_mul_lines = _array2matrix(expr.expr)
        if not isinstance(mat_mul_lines, ArrayTensorProduct):
            flat_cyclic_form = [j for i in expr.permutation.cyclic_form for j in i]
            expr_shape = get_shape(expr)
            if all(expr_shape[i] == 1 for i in flat_cyclic_form):
                # Only axes of dimension 1 are moved: the permutation has no
                # effect on the components and can be dropped.
                return mat_mul_lines
            # The permutation moves non-trivial axes and must be kept.
            return PermuteDims(mat_mul_lines, expr.permutation)
        # TODO: this assumes that all arguments are matrices, it may not be the case:
        permutation = Permutation(2*len(mat_mul_lines.args)-1)*expr.permutation
        permuted = [permutation(i) for i in range(2*len(mat_mul_lines.args))]
        args_array = [None] * len(mat_mul_lines.args)
        for i in range(len(mat_mul_lines.args)):
            p1 = permuted[2*i]
            p2 = permuted[2*i+1]
            if p1 // 2 != p2 // 2:
                # The two axes come from different factors: the permutation
                # cannot be expressed by reordering/transposing factors.
                return PermuteDims(mat_mul_lines, permutation)
            pos = p1 // 2
            if p1 > p2:
                args_array[i] = _a2m_transpose(mat_mul_lines.args[pos])
            else:
                args_array[i] = mat_mul_lines.args[pos]
        return _a2m_tensor_product(*args_array)
    else:
        return expr
236
237
@_array2matrix.register(ArrayAdd)
def _(expr: ArrayAdd):
    """Convert each addend and sum the results with ``_a2m_add``."""
    converted_terms = []
    for term in expr.args:
        converted_terms.append(_array2matrix(term))
    return _a2m_add(*converted_terms)
242
243
@_array2matrix.register(ArrayElementwiseApplyFunc)
def _(expr: ArrayElementwiseApplyFunc):
    """
    Convert the operand of an element-wise function application.

    The matrix wrapper ``ElementwiseApplyFunction`` is used when the converted
    operand is a ``MatrixExpr``; otherwise the array wrapper is rebuilt.
    """
    operand = _array2matrix(expr.expr)
    wrapper = ElementwiseApplyFunction if isinstance(operand, MatrixExpr) else ArrayElementwiseApplyFunc
    return wrapper(expr.function, operand)
251
252
@singledispatch
def _remove_trivial_dims(expr):
    """
    Default handler: nothing is removed.

    Returns the pair ``(expr, [])``, i.e. the expression unchanged together
    with an empty list of removed axis positions.  Type-specific handlers are
    attached with ``_remove_trivial_dims.register``.
    """
    no_axes_removed = []
    return expr, no_axes_removed
256
257
@_remove_trivial_dims.register(ArrayTensorProduct)
def _(expr: ArrayTensorProduct):
    """
    Drop trivial (dimension 1) axes from a tensor product.

    Returns a pair ``(new_expr, removed)`` where ``new_expr`` is built with
    ``_a2m_tensor_product`` from the surviving/merged factors and ``removed``
    is the sorted list of axis positions (relative to ``expr``) that were
    dropped.
    """
    # Recognize expressions like [x, y] with shape (k, 1, k, 1) as `x*y.T`.
    # The matrix expression has to be equivalent to the tensor product of the
    # matrices, with trivial dimensions (i.e. dim=1) dropped.
    # That is, add contractions over trivial dimensions:

    removed = []
    newargs = []
    # cumul[i] is the absolute position of the first axis of the i-th factor.
    cumul = list(accumulate([0] + [get_rank(arg) for arg in expr.args]))
    # ``pending`` holds the non-trivial dimension of the most recent
    # vector-like or identity factor that may still be paired with a later
    # factor; ``prev_i`` is the position of that factor in ``expr.args``.
    pending = None
    prev_i = None
    for i, arg in enumerate(expr.args):
        current_range = list(range(cumul[i], cumul[i+1]))
        if isinstance(arg, OneArray):
            # All axes of a OneArray factor are dropped.
            removed.extend(current_range)
            continue
        if not isinstance(arg, (MatrixExpr, MatrixCommon)):
            # Not a matrix: recurse and collect whatever the handler removed.
            rarg, rem = _remove_trivial_dims(arg)
            removed.extend(rem)
            newargs.append(rarg)
            continue
        elif getattr(arg, "is_Identity", False):
            if arg.shape == (1, 1):
                # Ignore identity matrices of shape (1, 1) - they are equivalent to scalar 1.
                removed.extend(current_range)
                continue
            k = arg.shape[0]
            if pending == k:
                # A factor with the same dimension is already pending:
                # this identity is dropped.
                removed.extend(current_range)
                continue
            elif pending is None:
                newargs.append(arg)
                pending = k
                prev_i = i
            else:
                # A factor of a different dimension was pending: this
                # identity replaces it as the pending one.
                pending = k
                prev_i = i
                newargs.append(arg)
        elif arg.shape == (1, 1):
            arg, _ = _remove_trivial_dims(arg)
            # Matrix is equivalent to scalar:
            if len(newargs) == 0:
                newargs.append(arg)
            elif 1 in get_shape(newargs[-1]):
                # Absorb the (1, 1) factor into the previous vector-like
                # factor, multiplying on the side whose dimension is 1.
                if newargs[-1].shape[1] == 1:
                    newargs[-1] = newargs[-1]*arg
                else:
                    newargs[-1] = arg*newargs[-1]
                removed.extend(current_range)
            else:
                newargs.append(arg)
        elif 1 in arg.shape:
            # Vector-like matrix: k is its first dimension different from 1.
            k = [i for i in arg.shape if i != 1][0]
            if pending is None:
                pending = k
                prev_i = i
                newargs.append(arg)
            elif pending == k:
                prev = newargs[-1]
                if prev.is_Identity:
                    # The pending factor is an identity: its two axes are
                    # dropped and the vector takes its place.
                    removed.extend([cumul[prev_i], cumul[prev_i]+1])
                    newargs[-1] = arg
                    prev_i = i
                    continue
                # Orient the previous factor as a column and the current one
                # as a row, recording which trivial axis of each is dropped.
                if prev.shape[0] == 1:
                    d1 = cumul[prev_i]
                    prev = _a2m_transpose(prev)
                else:
                    d1 = cumul[prev_i] + 1
                if arg.shape[1] == 1:
                    d2 = cumul[i] + 1
                    arg = _a2m_transpose(arg)
                else:
                    d2 = cumul[i]
                newargs[-1] = prev*arg
                pending = None
                removed.extend([d1, d2])
            else:
                newargs.append(arg)
                pending = k
                prev_i = i
        else:
            # Matrix without trivial dimensions: kept as is, and any pending
            # pairing is cancelled.
            newargs.append(arg)
            pending = None
    return _a2m_tensor_product(*newargs), sorted(removed)
345
346
@_remove_trivial_dims.register(ArrayAdd)
def _(expr: ArrayAdd):
    """
    Remove trivial dimensions from a sum.

    The reduction is applied only when every addend reports the same removed
    axes; otherwise ``(expr, [])`` is returned.
    """
    results = [_remove_trivial_dims(term) for term in expr.args]
    reduced_terms, removed_per_term = zip(*results)
    distinct_removals = {tuple(axes) for axes in removed_per_term}
    if len(distinct_removals) == 1:
        return _a2m_add(*reduced_terms), removed_per_term[0]
    return expr, []
354
355
@_remove_trivial_dims.register(PermuteDims)
def _(expr: PermuteDims):
    """
    Remove trivial dimensions underneath a permutation.

    The operand is reduced first; the permutation is then restricted to the
    surviving axes and renumbered.  Returns the new expression (converted
    again with ``_array2matrix`` if it differs from ``expr``) together with
    the sorted removed positions mapped through the inverse permutation.
    """
    reduced, dropped = _remove_trivial_dims(expr.expr)
    forward = expr.permutation.array_form
    backward = _af_invert(expr.permutation.array_form)
    # offsets[a] counts the dropped axes at positions <= a.
    offsets = list(accumulate(int(axis in dropped) for axis in range(len(forward))))
    dropped_outer = [backward[axis] for axis in dropped]
    restricted = [axis - offsets[axis] for axis in forward if axis not in dropped]
    # TODO: check if subremoved should be permuted as well...
    result = PermuteDims(reduced, restricted)
    if result != expr:
        result = _array2matrix(result)
    return result, sorted(dropped_outer)
369
370
@_remove_trivial_dims.register(ArrayContraction)
def _(expr: ArrayContraction):
    """
    Remove trivial dimensions underneath a contraction.

    Dropped axes are taken out of the contraction groups (groups that become
    empty disappear) and the remaining axes are renumbered.  Only dropped
    axes that were not contracted are reported, after being passed through
    ``ArrayContraction._push_indices_up``.
    """
    reduced, dropped = _remove_trivial_dims(expr.expr)
    # offsets[a] counts the dropped axes at positions <= a.
    offsets = list(accumulate(int(axis in dropped) for axis in range(get_rank(expr.expr))))

    surviving_groups = []
    for group in expr.contraction_indices:
        kept = tuple(axis - offsets[axis] for axis in group if axis not in dropped)
        # Remove possible empty tuples "()":
        if kept:
            surviving_groups.append(kept)

    contracted_axes = {axis for group in expr.contraction_indices for axis in group}
    free_dropped = [axis for axis in dropped if axis not in contracted_axes]
    # Shift removed:
    pushed = ArrayContraction._push_indices_up(expr.contraction_indices, free_dropped)
    return ArrayContraction(reduced, *surviving_groups), list(pushed)
384
385
@_remove_trivial_dims.register(ArrayDiagonal)
def _(expr: ArrayDiagonal):
    """
    Remove trivial dimensions underneath a diagonalization.

    Dropped axes are taken out of the diagonal groups and the remaining axes
    renumbered; groups left with fewer than two axes are discarded.  The
    removed positions are passed through ``ArrayDiagonal._push_indices_up``,
    de-duplicated and sorted.
    """
    reduced, dropped = _remove_trivial_dims(expr.expr)
    rank = get_rank(expr.expr)
    # offsets[a] counts the dropped axes at positions strictly below a.
    offsets = list(accumulate([0] + [int(axis in dropped) for axis in range(rank)]))

    remaining_groups = []
    for group in expr.diagonal_indices:
        kept = tuple(axis - offsets[axis] for axis in group if axis not in dropped)
        # If there are single axes to diagonalize remaining, it means that their
        # corresponding dimension has been removed, they no longer need diagonalization:
        if len(kept) > 1:
            remaining_groups.append(kept)

    pushed = ArrayDiagonal._push_indices_up(expr.diagonal_indices, dropped, rank)
    return ArrayDiagonal(reduced, *remaining_groups), sorted(set(pushed))
399
400
@_remove_trivial_dims.register(ElementwiseApplyFunction)
def _(expr: ElementwiseApplyFunction):
    """
    Remove trivial dimensions from an element-wise function application.

    Returns a pair ``(new_expr, removed)`` like every other handler of
    ``_remove_trivial_dims``.  When the reduced operand has shape ``(1, 1)``
    the function is applied directly to it and axes 0 and 1 are reported as
    removed in addition to those removed from the operand; otherwise the
    wrapper is rebuilt around the reduced operand and the operand's removed
    axes are passed on.
    """
    subexpr, removed = _remove_trivial_dims(expr.expr)
    if subexpr.shape == (1, 1):
        # TODO: move this to ElementwiseApplyFunction
        return expr.function(subexpr), removed + [0, 1]
    # Callers unpack two values, so a pair must be returned here too (the
    # bare expression returned previously broke that unpacking).
    return ElementwiseApplyFunction(expr.function, subexpr), removed
408
409
@_remove_trivial_dims.register(ArrayElementwiseApplyFunc)
def _(expr: ArrayElementwiseApplyFunc):
    """Reduce the operand and rebuild the element-wise application around it."""
    operand, dropped = _remove_trivial_dims(expr.expr)
    rebuilt = ArrayElementwiseApplyFunc(expr.function, operand)
    return rebuilt, dropped
414
415
def convert_array_to_matrix(expr):
    r"""
    Rewrite an array expression in terms of matrix expressions where possible.

    The expression is first converted with ``_array2matrix`` and trivial
    dimensions are then removed with ``_remove_trivial_dims``.

    When several independent lines of matrix multiplication are detected,
    they are kept as separate factors of a tensor product (see the last
    example).

    Examples
    ========

    >>> from sympy.tensor.array.expressions.conv_indexed_to_array import convert_indexed_to_array
    >>> from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct
    >>> from sympy import MatrixSymbol, Sum
    >>> from sympy.abc import i, j, k, l, N
    >>> from sympy.tensor.array.expressions.array_expressions import ArrayContraction
    >>> from sympy.tensor.array.expressions.conv_matrix_to_array import convert_matrix_to_array
    >>> from sympy.tensor.array.expressions.conv_array_to_matrix import convert_array_to_matrix
    >>> A = MatrixSymbol("A", N, N)
    >>> B = MatrixSymbol("B", N, N)
    >>> C = MatrixSymbol("C", N, N)
    >>> D = MatrixSymbol("D", N, N)

    >>> expr = Sum(A[i, j]*B[j, k], (j, 0, N-1))
    >>> cg = convert_indexed_to_array(expr)
    >>> convert_array_to_matrix(cg)
    A*B
    >>> cg = convert_indexed_to_array(expr, first_indices=[k])
    >>> convert_array_to_matrix(cg)
    B.T*A.T

    Transposition is detected:

    >>> expr = Sum(A[j, i]*B[j, k], (j, 0, N-1))
    >>> cg = convert_indexed_to_array(expr)
    >>> convert_array_to_matrix(cg)
    A.T*B
    >>> cg = convert_indexed_to_array(expr, first_indices=[k])
    >>> convert_array_to_matrix(cg)
    B.T*A

    Detect the trace:

    >>> expr = Sum(A[i, i], (i, 0, N-1))
    >>> cg = convert_indexed_to_array(expr)
    >>> convert_array_to_matrix(cg)
    Trace(A)

    Recognize some more complex traces:

    >>> expr = Sum(A[i, j]*B[j, i], (i, 0, N-1), (j, 0, N-1))
    >>> cg = convert_indexed_to_array(expr)
    >>> convert_array_to_matrix(cg)
    Trace(A*B)

    More complicated expressions:

    >>> expr = Sum(A[i, j]*B[k, j]*A[l, k], (j, 0, N-1), (k, 0, N-1))
    >>> cg = convert_indexed_to_array(expr)
    >>> convert_array_to_matrix(cg)
    A*B.T*A.T

    Expressions constructed from matrix expressions do not contain literal
    indices, the positions of free indices are returned instead:

    >>> expr = A*B
    >>> cg = convert_matrix_to_array(expr)
    >>> convert_array_to_matrix(cg)
    A*B

    If more than one line of matrix multiplications is detected, return
    separate matrix multiplication factors embedded in a tensor product object:

    >>> cg = ArrayContraction(ArrayTensorProduct(A, B, C, D), (1, 2), (5, 6))
    >>> convert_array_to_matrix(cg)
    ArrayTensorProduct(A*B, C*D)

    The two lines have free indices at axes 0, 3 and 4, 7, respectively.
    """
    matrix_form = _array2matrix(expr)
    # The list of removed axes is not part of the public result.
    simplified, _ = _remove_trivial_dims(matrix_form)
    return simplified
497
498
def _array_diag2contr_diagmatrix(expr: ArrayDiagonal):
    """
    Rewrite two-axis diagonals involving a vector-like matrix using ``DiagMatrix``.

    Only acts when the operand of ``expr`` is an ``ArrayTensorProduct``;
    otherwise ``expr`` is returned unchanged.  For every diagonal group made
    of exactly two axes that belong to rank-2 factors, if one of the factors
    has dimension 1 along its other axis, that factor is replaced by a
    ``OneArray`` placeholder, a ``DiagMatrix`` of it is appended to the
    factors, and the diagonal is replaced by a contraction between the other
    factor's axis and the appended ``DiagMatrix``.

    Returns an ``ArrayDiagonal`` (over the remaining diagonal groups) of the
    ``ArrayContraction`` built from the updated factors.
    """
    if isinstance(expr.expr, ArrayTensorProduct):
        args = list(expr.expr.args)
        diag_indices = list(expr.diagonal_indices)
        # mapping[j] is unpacked below as (factor position, axis position
        # inside that factor) for the absolute axis j.
        mapping = _get_mapping_from_subranks([_get_subrank(arg) for arg in args])
        tuple_links = [[mapping[j] for j in i] for i in diag_indices]
        contr_indices = []
        # NOTE(review): total_rank is updated below but never read in this
        # function.
        total_rank = get_rank(expr)
        # replaced[p] is True once the factor at position p has been swapped
        # for a OneArray placeholder.
        replaced = [False for arg in args]
        for i, (abs_pos, rel_pos) in enumerate(zip(diag_indices, tuple_links)):
            if len(abs_pos) != 2:
                continue
            (pos1_outer, pos1_inner), (pos2_outer, pos2_inner) = rel_pos
            arg1 = args[pos1_outer]
            arg2 = args[pos2_outer]
            if get_rank(arg1) != 2 or get_rank(arg2) != 2:
                # One of the factors is not rank 2 (this includes factors
                # already replaced by a placeholder): the diagonal group is
                # discarded if either factor was replaced, otherwise kept.
                if replaced[pos1_outer]:
                    diag_indices[i] = None
                if replaced[pos2_outer]:
                    diag_indices[i] = None
                continue
            # Position of the *other* axis inside each rank-2 factor.
            pos1_in2 = 1 - pos1_inner
            pos2_in2 = 1 - pos2_inner
            if arg1.shape[pos1_in2] == 1:
                darg1 = DiagMatrix(arg1)
                args.append(darg1)
                # Contract the second factor's axis with the appended
                # DiagMatrix (referenced by its position at the end of args).
                contr_indices.append(((pos2_outer, pos2_inner), (len(args)-1, pos1_inner)))
                total_rank += 1
                diag_indices[i] = None
                args[pos1_outer] = OneArray(arg1.shape[pos1_in2])
                replaced[pos1_outer] = True
            elif arg2.shape[pos2_in2] == 1:
                darg2 = DiagMatrix(arg2)
                args.append(darg2)
                contr_indices.append(((pos1_outer, pos1_inner), (len(args)-1, pos2_inner)))
                total_rank += 1
                diag_indices[i] = None
                args[pos2_outer] = OneArray(arg2.shape[pos2_in2])
                replaced[pos2_outer] = True
        diag_indices_new = [i for i in diag_indices if i is not None]
        # Convert (factor position, inner axis) pairs into absolute axes of
        # the updated tensor product.
        cumul = list(accumulate([0] + [get_rank(arg) for arg in args]))
        contr_indices2 = [tuple(cumul[a] + b for a, b in i) for i in contr_indices]
        tc = ArrayContraction(
            ArrayTensorProduct(*args), *contr_indices2
        )
        td = ArrayDiagonal(tc, *diag_indices_new)
        return td
    return expr
547
548
def _a2m_mul(*args):
    """
    Multiply ``args`` as matrices.

    If any factor is an array-expression node, the product is expressed as a
    contraction of a tensor product over consecutive axis pairs
    ``(1, 2), (3, 4), ...``; otherwise an evaluated ``MatMul`` is returned.
    """
    if any(isinstance(factor, _CodegenArrayAbstract) for factor in args):
        adjacent_pairs = [(2*k + 1, 2*k + 2) for k in range(len(args) - 1)]
        return ArrayContraction(ArrayTensorProduct(*args), *adjacent_pairs)
    from sympy import MatMul
    return MatMul(*args).doit()
558
559
def _a2m_tensor_product(*args):
    """
    Build a tensor product, collecting scalar factors into a single coefficient.

    Arguments that are ``MatrixExpr``, ``_ArrayExpr`` or
    ``_CodegenArrayAbstract`` instances are array factors; everything else is
    multiplied together as a scalar.  With no array factor the scalar alone is
    returned.  A scalar different from 1 is either prepended as an extra
    factor (when the first array factor is an array-expression node) or
    multiplied into the first array factor.
    """
    array_types = (MatrixExpr, _ArrayExpr, _CodegenArrayAbstract)
    arrays = [factor for factor in args if isinstance(factor, array_types)]
    scalars = [factor for factor in args if not isinstance(factor, array_types)]

    scalar = Mul.fromiter(scalars)
    if not arrays:
        return scalar
    if scalar != 1:
        if isinstance(arrays[0], _CodegenArrayAbstract):
            arrays.insert(0, scalar)
        else:
            arrays[0] *= scalar
    return ArrayTensorProduct(*arrays)
577
578
def _a2m_add(*args):
    """
    Add ``args``.

    Returns an ``ArrayAdd`` if any addend is an array-expression node,
    otherwise an evaluated ``MatAdd``.
    """
    if any(isinstance(term, _CodegenArrayAbstract) for term in args):
        return ArrayAdd(*args)
    from sympy import MatAdd
    return MatAdd(*args).doit()
585
586
def _a2m_trace(arg):
    """
    Take the trace of ``arg``.

    Array-expression nodes are contracted over axes ``(0, 1)``; anything else
    is wrapped in ``Trace``.
    """
    if not isinstance(arg, _CodegenArrayAbstract):
        from sympy import Trace
        return Trace(arg)
    return ArrayContraction(arg, (0, 1))
593
594
def _a2m_transpose(arg):
    """
    Transpose ``arg``.

    Array-expression nodes get their two axes swapped with ``PermuteDims``;
    anything else goes through an evaluated ``Transpose``.
    """
    if not isinstance(arg, _CodegenArrayAbstract):
        from sympy import Transpose
        return Transpose(arg).doit()
    return PermuteDims(arg, [1, 0])
601
602
def identify_hadamard_products(expr: Union[ArrayContraction, ArrayDiagonal]):
    """
    Replace groups of arguments sharing the same indices with Hadamard products.

    Arguments whose index sets coincide are grouped; a group of two or more
    arguments over a pair of indices is replaced by a single
    ``hadamard_product`` argument (transposing factors whose indices are in
    the opposite order), or by a ``Trace`` when the group is recognized as
    fully contracted.

    Returns
    =======

    * for an ``ArrayContraction`` input: ``editor.to_array_contraction()``;
    * for an ``ArrayDiagonal`` input whose operand is an ``ArrayContraction``
      or ``ArrayTensorProduct``: a ``PermuteDims`` of an ``ArrayDiagonal`` of
      the edited contraction;
    * for an ``ArrayDiagonal`` input with any other operand: ``expr``
      unchanged.

    NOTE(review): ``expr.subranks`` is read before the type of ``expr`` is
    checked, and ``editor`` is only bound for the two types above -- callers
    are presumably expected to pass one of those types; confirm.
    """
    mapping = _get_mapping_from_subranks(expr.subranks)

    editor: _EditArrayContraction
    if isinstance(expr, ArrayContraction):
        editor = _EditArrayContraction(expr)
    elif isinstance(expr, ArrayDiagonal):
        if isinstance(expr.expr, ArrayContraction):
            editor = _EditArrayContraction(expr.expr)
            diagonalized = ArrayContraction._push_indices_down(expr.expr.contraction_indices, expr.diagonal_indices)
        elif isinstance(expr.expr, ArrayTensorProduct):
            editor = _EditArrayContraction(None)
            editor.args_with_ind = [_ArgE(arg) for i, arg in enumerate(expr.expr.args)]
            diagonalized = expr.diagonal_indices
        else:
            return expr

        # Trick: add diagonalized indices as negative indices into the editor object:
        # (the i-th diagonal group is encoded as index -1 - i)
        for i, e in enumerate(diagonalized):
            for j in e:
                arg_pos, rel_pos = mapping[j]
                editor.args_with_ind[arg_pos].indices[rel_pos] = -1 - i

    # Group arguments by the set of indices they carry (arguments with a free
    # index, i.e. ``None``, are not grouped) and count how often each index
    # value occurs over all arguments.
    map_contr_to_args: Dict[FrozenSet, List[_ArgE]] = defaultdict(list)
    map_ind_to_inds = defaultdict(int)
    for arg_with_ind in editor.args_with_ind:
        for ind in arg_with_ind.indices:
            map_ind_to_inds[ind] += 1
        if None in arg_with_ind.indices:
            continue
        map_contr_to_args[frozenset(arg_with_ind.indices)].append(arg_with_ind)

    k: FrozenSet[int]
    v: List[_ArgE]
    for k, v in map_contr_to_args.items():
        make_trace: bool = False
        if len(k) == 1 and next(iter(k)) >= 0 and sum([next(iter(k)) in i for i in map_contr_to_args]) == 1:
            # This is a trace: the arguments are fully contracted with only one
            # index, and the index isn't used anywhere else:
            make_trace = True
            first_element = S.One
        elif len(k) != 2:
            # Hadamard product only defined for matrices:
            continue
        if len(v) == 1:
            # Hadamard product with a single argument makes no sense:
            continue
        # NOTE(review): the ``continue`` below only affects this inner loop,
        # and nothing follows it in the loop body, so this loop currently has
        # no effect on which groups are processed.
        for ind in k:
            if map_ind_to_inds[ind] <= 2:
                # There is no other contraction, skip:
                continue

        def check_transpose(x):
            # True when the indices are in ascending order once negative
            # (diagonal) indices are decoded back with -1 - i.
            x = [i if i >= 0 else -1-i for i in x]
            return x == sorted(x)

        # Check if expression is a trace:
        if all([map_ind_to_inds[j] == len(v) and j >= 0 for j in k]) and all([j >= 0 for j in k]):
            # This is a trace
            make_trace = True
            first_element = v[0].element
            if not check_transpose(v[0].indices):
                first_element = first_element.T
            hadamard_factors = v[1:]
        else:
            hadamard_factors = v

        # This is a Hadamard product:

        hp = hadamard_product(*[i.element if check_transpose(i.indices) else Transpose(i.element) for i in hadamard_factors])
        hp_indices = v[0].indices
        if not check_transpose(hadamard_factors[0].indices):
            hp_indices = list(reversed(hp_indices))
        if make_trace:
            hp = Trace(first_element*hp.T)._normalize()
            hp_indices = []
        # Put the merged argument next to the first member of the group and
        # remove all original members.
        editor.insert_after(v[0], _ArgE(hp, hp_indices))
        for i in v:
            editor.args_with_ind.remove(i)

    # Count the ranks of the arguments:
    counter = 0
    # Create a collector for the new diagonal indices:
    diag_indices = defaultdict(list)

    count_index_freq = Counter()
    for arg_with_ind in editor.args_with_ind:
        count_index_freq.update(Counter(arg_with_ind.indices))

    free_index_count = count_index_freq[None]

    # Construct the inverse permutation:
    inv_perm1 = []
    inv_perm2 = []
    # Keep track of which diagonal indices have already been processed:
    done = set([])

    # Counter for the diagonal indices:
    counter4 = 0

    for arg_with_ind in editor.args_with_ind:
        # If some diagonalization axes have been removed, they should be
        # permuted in order to keep the permutation.
        # Add permutation here
        counter2 = 0  # counter for the indices
        for i in arg_with_ind.indices:
            if i is None:
                inv_perm1.append(counter4)
                counter2 += 1
                counter4 += 1
                continue
            if i >= 0:
                # Contraction index: not part of the output axes.
                continue
            # Reconstruct the diagonal indices:
            diag_indices[-1 - i].append(counter + counter2)
            # Diagonal indices that occur only once go to the first part of
            # the inverse permutation, the others to the second part.
            if count_index_freq[i] == 1 and i not in done:
                inv_perm1.append(free_index_count - 1 - i)
                done.add(i)
            elif i not in done:
                inv_perm2.append(free_index_count - 1 - i)
                done.add(i)
            counter2 += 1
        # Remove negative indices to restore a proper editor object:
        arg_with_ind.indices = [i if i is not None and i >= 0 else None for i in arg_with_ind.indices]
        counter += len([i for i in arg_with_ind.indices if i is None or i < 0])

    inverse_permutation = inv_perm1 + inv_perm2
    permutation = _af_invert(inverse_permutation)

    if isinstance(expr, ArrayContraction):
        return editor.to_array_contraction()
    else:
        # Get the diagonal indices after the detection of HadamardProduct in the expression:
        diag_indices_filtered = [tuple(v) for v in diag_indices.values() if len(v) > 1]

        expr1 = editor.to_array_contraction()
        expr2 = ArrayDiagonal(expr1, *diag_indices_filtered)
        expr3 = PermuteDims(expr2, permutation)
        return expr3
742