import itertools
from collections import defaultdict, Counter
from typing import Tuple, Union, FrozenSet, Dict, List, Optional
from functools import singledispatch
from itertools import accumulate

from sympy import Trace, MatrixExpr, Transpose, DiagMatrix, Mul, ZeroMatrix, hadamard_product, S
from sympy.combinatorics.permutations import _af_invert, Permutation
from sympy.matrices.common import MatrixCommon
from sympy.matrices.expressions.applyfunc import ElementwiseApplyFunction
from sympy.tensor.array.expressions.array_expressions import PermuteDims, ArrayDiagonal, \
    ArrayTensorProduct, OneArray, get_rank, _get_subrank, ZeroArray, ArrayContraction, \
    ArrayAdd, _CodegenArrayAbstract, get_shape, ArrayElementwiseApplyFunc, _ArrayExpr, _EditArrayContraction, _ArgE
from sympy.tensor.array.expressions.utils import _get_mapping_from_subranks


def _get_candidate_for_matmul_from_contraction(scan_indices: List[Optional[int]], remaining_args: List[_ArgE]) -> Tuple[Optional[_ArgE], bool, int]:
    """
    Look for a matrix argument that shares one of ``scan_indices``.

    ``scan_indices`` are contraction indices of the argument currently being
    examined (``None`` entries are discarded); ``remaining_args`` are the
    editor arguments following it. Non-matrix arguments are skipped.

    Returns ``(candidate, transpose, candidate_index)``:

    - ``candidate``: the argument to multiply with, or ``None``. It is reset
      to ``None`` when the selected index is found on both axes of one
      argument or on a further argument.
    - ``transpose``: ``True`` if the shared index is the candidate's second
      index, i.e. the candidate must be transposed before multiplication.
    - ``candidate_index``: the shared index, or ``-1`` if none was matched.

    NOTE(review): after a reset the outer loop keeps scanning, so a later
    argument may still be picked; the only caller passes indices occurring
    on exactly two arguments.
    """
    scan_indices = [i for i in scan_indices if i is not None]
    if len(scan_indices) == 0:
        return None, False, -1

    transpose: bool = False
    candidate: Optional[_ArgE] = None
    candidate_index: int = -1
    for arg_with_ind2 in remaining_args:
        if not isinstance(arg_with_ind2.element, MatrixExpr):
            continue
        for index in scan_indices:
            if candidate_index != -1 and candidate_index != index:
                # A candidate index has already been selected, check
                # repetitions only for that index:
                continue
            if index in arg_with_ind2.indices:
                if set(arg_with_ind2.indices) == {index}:
                    # Index repeated twice in arg_with_ind2
                    candidate = None
                    break
                if candidate is None:
                    candidate = arg_with_ind2
                    candidate_index = index
                    transpose = (index == arg_with_ind2.indices[1])
                else:
                    # Index repeated more than twice, break
                    candidate = None
                    break
    return candidate, transpose, candidate_index


def _insert_candidate_into_editor(editor: _EditArrayContraction, arg_with_ind: _ArgE, candidate: _ArgE, transpose1: bool, transpose2: bool):
    """
    Build the matrix product of ``arg_with_ind`` and ``candidate``.

    ``transpose1`` / ``transpose2`` select whether the first / second factor
    is transposed. ``candidate`` is removed from ``editor.args_with_ind``;
    ``arg_with_ind`` stays in the editor and is replaced by the caller.

    Returns the new ``_ArgE`` wrapping the product (its indices are assigned
    by the caller) and the index of ``candidate`` left free in the product.
    """
    other = candidate.element
    other_index: int
    if transpose2:
        other = Transpose(other)
        other_index = candidate.indices[0]
    else:
        other_index = candidate.indices[1]
    new_element = (Transpose(arg_with_ind.element) if transpose1 else arg_with_ind.element) * other
    editor.args_with_ind.remove(candidate)
    new_arge = _ArgE(new_element)
    return new_arge, other_index


def _support_function_tp1_recognize(contraction_indices, args):
    """
    Recognize matrix products and traces in a contraction of ``args``.

    Builds ``ArrayContraction(ArrayTensorProduct(*args), *contraction_indices)``.
    It then repeatedly does two things until nothing changes:

    - turns a matrix whose two axes are contracted with each other into a
      ``Trace``;
    - merges pairs of matrices sharing a contracted index into a product.

    Without contraction indices the arguments are just combined with
    ``_a2m_tensor_product``.
    """
    if len(contraction_indices) == 0:
        return _a2m_tensor_product(*args)

    ac = ArrayContraction(ArrayTensorProduct(*args), *contraction_indices)
    editor = _EditArrayContraction(ac)
    editor.track_permutation_start()

    while True:
        flag_stop: bool = True
        for i, arg_with_ind in enumerate(editor.args_with_ind):
            if not isinstance(arg_with_ind.element, MatrixExpr):
                continue

            first_index = arg_with_ind.indices[0]
            second_index = arg_with_ind.indices[1]

            first_frequency = editor.count_args_with_index(first_index)
            second_frequency = editor.count_args_with_index(second_index)

            # Both axes carry the same index and no other argument uses it:
            # the matrix is contracted with itself, i.e. a trace.
            if first_index is not None and first_frequency == 1 and first_index == second_index:
                flag_stop = False
                arg_with_ind.element = Trace(arg_with_ind.element)._normalize()
                arg_with_ind.indices = []
                break

            scan_indices = []
            if first_frequency == 2:
                scan_indices.append(first_index)
            if second_frequency == 2:
                scan_indices.append(second_index)

            candidate, transpose, found_index = _get_candidate_for_matmul_from_contraction(scan_indices, editor.args_with_ind[i+1:])
            if candidate is not None:
                flag_stop = False
                editor.track_permutation_merge(arg_with_ind, candidate)
                # If the shared index is our first axis, we must be
                # transposed so that it becomes the inner dimension.
                transpose1 = found_index == first_index
                new_arge, other_index = _insert_candidate_into_editor(editor, arg_with_ind, candidate, transpose1, transpose)
                if found_index == first_index:
                    new_arge.indices = [second_index, other_index]
                else:
                    new_arge.indices = [first_index, other_index]
                set_indices = set(new_arge.indices)
                if len(set_indices) == 1 and set_indices != {None}:
                    # This is a trace:
                    new_arge.element = Trace(new_arge.element)._normalize()
                    new_arge.indices = []
                # ``candidate`` came after position ``i``, so removing it
                # did not shift ``i``:
                editor.args_with_ind[i] = new_arge
                # TODO: is this break necessary?
                break

        if flag_stop:
            break

    editor.refresh_indices()
    return editor.to_array_contraction()


@singledispatch
def _array2matrix(expr):
    """
    Convert an array expression into a matrix expression where possible.

    Generic fallback: objects without a registered handler are returned
    unchanged.
    """
    return expr


@_array2matrix.register(ZeroArray)
def _(expr: ZeroArray):
    """Rank-2 zero arrays become ``ZeroMatrix``; other ranks are kept."""
    if get_rank(expr) == 2:
        return ZeroMatrix(*expr.shape)
    else:
        return expr


@_array2matrix.register(ArrayTensorProduct)
def _(expr: ArrayTensorProduct):
    """Convert every factor, then recombine them with ``_a2m_tensor_product``."""
    return _a2m_tensor_product(*[_array2matrix(arg) for arg in expr.args])


@_array2matrix.register(ArrayContraction)
def _(expr: ArrayContraction):
    """
    Recognize matrix products, traces and Hadamard products in a contraction.

    If the contracted object is an array expression of a kind not handled
    below, the (rewritten) contraction is returned unchanged.
    """
    expr = expr.flatten_contraction_of_diagonal()
    expr = expr.split_multiple_contractions()
    expr = identify_hadamard_products(expr)
    if not isinstance(expr, ArrayContraction):
        return _array2matrix(expr)
    subexpr = expr.expr
    contraction_indices: Tuple[Tuple[int]] = expr.contraction_indices
    if isinstance(subexpr, ArrayTensorProduct):
        newexpr = ArrayContraction(_array2matrix(subexpr), *contraction_indices)
        contraction_indices = newexpr.contraction_indices
        if any(i > 2 for i in newexpr.subranks):
            # Some factor is not a matrix: distribute additions over the
            # tensor product so each addend can be recognized separately.
            addends = ArrayAdd(*[_a2m_tensor_product(*j) for j in itertools.product(*[i.args if isinstance(i,
                                                                ArrayAdd) else [i] for i in expr.expr.args])])
            newexpr = ArrayContraction(addends, *contraction_indices)
        if isinstance(newexpr, ArrayAdd):
            ret = _array2matrix(newexpr)
            return ret
        assert isinstance(newexpr, ArrayContraction)
        ret = _support_function_tp1_recognize(contraction_indices, list(newexpr.expr.args))
        return ret
    elif not isinstance(subexpr, _CodegenArrayAbstract):
        ret = _array2matrix(subexpr)
        if isinstance(ret, MatrixExpr):
            assert expr.contraction_indices == ((0, 1),)
            return _a2m_trace(ret)
        else:
            return ArrayContraction(ret, *expr.contraction_indices)
    # Any other array-expression argument is not handled here: return the
    # contraction as rewritten above rather than falling through to ``None``.
    return expr


@_array2matrix.register(ArrayDiagonal)
def _(expr: ArrayDiagonal):
    """
    Recognize Hadamard products and ``DiagMatrix`` patterns in a diagonal.

    Recurses until the expression stops changing.
    """
    pexpr = ArrayDiagonal(_array2matrix(expr.expr), *expr.diagonal_indices)
    pexpr = identify_hadamard_products(pexpr)
    if isinstance(pexpr, ArrayDiagonal):
        pexpr = _array_diag2contr_diagmatrix(pexpr)
    if expr == pexpr:
        return expr
    return _array2matrix(pexpr)


@_array2matrix.register(PermuteDims)
def _(expr: PermuteDims):
    """
    Convert axis permutations into transpositions where possible.

    Three cases are handled:

    - a plain ``[1, 0]`` permutation becomes a transpose;
    - a permuted tensor product has each factor transposed as needed;
    - a permuted contraction is converted first, then the permutation is
      re-applied.

    A permutation that cannot be absorbed is kept as ``PermuteDims``.
    """
    if expr.permutation.array_form == [1, 0]:
        return _a2m_transpose(_array2matrix(expr.expr))
    elif isinstance(expr.expr, ArrayTensorProduct):
        ranks = expr.expr.subranks
        inv_permutation = expr.permutation**(-1)
        newrange = [inv_permutation(i) for i in range(sum(ranks))]
        # Split the permuted positions by the factor they belong to:
        newpos = []
        counter = 0
        for rank in ranks:
            newpos.append(newrange[counter:counter+rank])
            counter += rank
        newargs = []
        newperm = []
        scalars = []
        for pos, arg in zip(newpos, expr.expr.args):
            if len(pos) == 0:
                scalars.append(_array2matrix(arg))
            elif pos == sorted(pos):
                newargs.append(_array2matrix(arg))
                newperm.extend(pos)
            elif len(pos) == 2:
                # The two axes of a matrix are swapped: transpose it.
                newargs.append(_a2m_transpose(_array2matrix(arg)))
                newperm.extend(reversed(pos))
            else:
                raise NotImplementedError()
        return PermuteDims(_a2m_tensor_product(*scalars, *newargs), _af_invert(newperm))
    elif isinstance(expr.expr, ArrayContraction):
        mat_mul_lines = _array2matrix(expr.expr)
        if not isinstance(mat_mul_lines, ArrayTensorProduct):
            flat_cyclic_form = [j for i in expr.permutation.cyclic_form for j in i]
            expr_shape = get_shape(expr)
            if all(expr_shape[i] == 1 for i in flat_cyclic_form):
                # Only axes of size 1 are moved: the permutation is trivial.
                return mat_mul_lines
            # Otherwise the permutation must be preserved, not discarded.
            return PermuteDims(mat_mul_lines, expr.permutation)
        # TODO: this assumes that all arguments are matrices, it may not be the case:
        permutation = Permutation(2*len(mat_mul_lines.args)-1)*expr.permutation
        permuted = [permutation(i) for i in range(2*len(mat_mul_lines.args))]
        args_array = [None] * len(mat_mul_lines.args)
        for i in range(len(mat_mul_lines.args)):
            p1 = permuted[2*i]
            p2 = permuted[2*i+1]
            if p1 // 2 != p2 // 2:
                # Axes of different matrices are mixed: cannot be expressed
                # with transpositions only.
                return PermuteDims(mat_mul_lines, permutation)
            pos = p1 // 2
            if p1 > p2:
                args_array[i] = _a2m_transpose(mat_mul_lines.args[pos])
            else:
                args_array[i] = mat_mul_lines.args[pos]
        return _a2m_tensor_product(*args_array)
    else:
        return expr


@_array2matrix.register(ArrayAdd)
def _(expr: ArrayAdd):
    """Convert every addend and recombine them with ``_a2m_add``."""
    addends = [_array2matrix(arg) for arg in expr.args]
    return _a2m_add(*addends)


@_array2matrix.register(ArrayElementwiseApplyFunc)
def _(expr: ArrayElementwiseApplyFunc):
    """Use ``ElementwiseApplyFunction`` when the argument becomes a matrix."""
    subexpr = _array2matrix(expr.expr)
    if isinstance(subexpr, MatrixExpr):
        return ElementwiseApplyFunction(expr.function, subexpr)
    else:
        return ArrayElementwiseApplyFunc(expr.function, subexpr)


@singledispatch
def _remove_trivial_dims(expr):
    """
    Drop dimensions of size 1 from ``expr`` where possible.

    Returns ``(new_expr, removed)`` where ``removed`` lists the dropped axes.
    Generic fallback: nothing is removed.
    """
    return expr, []


@_remove_trivial_dims.register(ArrayTensorProduct)
def _(expr: ArrayTensorProduct):
    # Recognize expressions like [x, y] with shape (k, 1, k, 1) as `x*y.T`.
    # The matrix expression has to be equivalent to the tensor product of the
    # matrices, with trivial dimensions (i.e. dim=1) dropped.
    # That is, add contractions over trivial dimensions:

    removed = []
    newargs = []
    cumul = list(accumulate([0] + [get_rank(arg) for arg in expr.args]))
    # ``pending`` is the non-trivial size of ``newargs[-1]`` when that entry
    # may still be merged with a following vector; ``prev_i`` is its
    # position in ``expr.args``.
    pending = None
    prev_i = None
    for i, arg in enumerate(expr.args):
        current_range = list(range(cumul[i], cumul[i+1]))
        if isinstance(arg, OneArray):
            removed.extend(current_range)
            continue
        if not isinstance(arg, (MatrixExpr, MatrixCommon)):
            rarg, rem = _remove_trivial_dims(arg)
            removed.extend(rem)
            newargs.append(rarg)
            # ``newargs[-1]`` is no longer the factor that set ``pending``,
            # so a later vector must not be merged with it:
            pending = None
            continue
        elif getattr(arg, "is_Identity", False):
            if arg.shape == (1, 1):
                # Ignore identity matrices of shape (1, 1) - they are equivalent to scalar 1.
                removed.extend(current_range)
                continue
            k = arg.shape[0]
            if pending == k:
                # NOTE(review): an identity matching the pending size is
                # dropped entirely; confirm this is only reached where it is
                # equivalent.
                removed.extend(current_range)
                continue
            newargs.append(arg)
            pending = k
            prev_i = i
        elif arg.shape == (1, 1):
            arg, _ = _remove_trivial_dims(arg)
            # Matrix is equivalent to scalar:
            if len(newargs) == 0:
                newargs.append(arg)
            elif 1 in get_shape(newargs[-1]):
                # Absorb the scalar-like factor into the previous vector
                # (its shape, hence ``pending``, is unchanged):
                if newargs[-1].shape[1] == 1:
                    newargs[-1] = newargs[-1]*arg
                else:
                    newargs[-1] = arg*newargs[-1]
                removed.extend(current_range)
            else:
                newargs.append(arg)
                # ``newargs[-1]`` is now this (1, 1) factor, not the one that
                # set ``pending``:
                pending = None
        elif 1 in arg.shape:
            k = [d for d in arg.shape if d != 1][0]
            if pending is None:
                pending = k
                prev_i = i
                newargs.append(arg)
            elif pending == k:
                prev = newargs[-1]
                if prev.is_Identity:
                    # Replace the pending identity with this vector.
                    removed.extend([cumul[prev_i], cumul[prev_i]+1])
                    newargs[-1] = arg
                    prev_i = i
                    continue
                # Merge two vectors into an outer product, dropping their
                # trivial axes d1 and d2:
                if prev.shape[0] == 1:
                    d1 = cumul[prev_i]
                    prev = _a2m_transpose(prev)
                else:
                    d1 = cumul[prev_i] + 1
                if arg.shape[1] == 1:
                    d2 = cumul[i] + 1
                    arg = _a2m_transpose(arg)
                else:
                    d2 = cumul[i]
                newargs[-1] = prev*arg
                pending = None
                removed.extend([d1, d2])
            else:
                newargs.append(arg)
                pending = k
                prev_i = i
        else:
            newargs.append(arg)
            pending = None
    return _a2m_tensor_product(*newargs), sorted(removed)


@_remove_trivial_dims.register(ArrayAdd)
def _(expr: ArrayAdd):
    """Simplify addends; only valid if all of them drop the same axes."""
    rec = [_remove_trivial_dims(arg) for arg in expr.args]
    newargs, removed = zip(*rec)
    if len(set(map(tuple, removed))) != 1:
        return expr, []
    return _a2m_add(*newargs), removed[0]


@_remove_trivial_dims.register(PermuteDims)
def _(expr: PermuteDims):
    """Drop trivial axes of the permuted object and rebuild the permutation."""
    subexpr, subremoved = _remove_trivial_dims(expr.expr)
    p = expr.permutation.array_form
    pinv = _af_invert(expr.permutation.array_form)
    # shift[e]: number of removed axes up to and including axis ``e``.
    shift = list(accumulate([1 if i in subremoved else 0 for i in range(len(p))]))
    # Positions of the removed axes in the permuted output:
    premoved = [pinv[i] for i in subremoved]
    p2 = [e - shift[e] for e in p if e not in subremoved]
    # TODO: check if subremoved should be permuted as well...
    newexpr = PermuteDims(subexpr, p2)
    if newexpr != expr:
        newexpr = _array2matrix(newexpr)
    return newexpr, sorted(premoved)


@_remove_trivial_dims.register(ArrayContraction)
def _(expr: ArrayContraction):
    """Drop trivial axes of the contracted object and re-index contractions."""
    newexpr, removed = _remove_trivial_dims(expr.expr)
    shifts = list(accumulate([1 if i in removed else 0 for i in range(get_rank(expr.expr))]))
    new_contraction_indices = [tuple(j for j in i if j not in removed) for i in expr.contraction_indices]
    # Remove possible empty tuples "()":
    new_contraction_indices = [i for i in new_contraction_indices if len(i) > 0]
    contraction_indices_flat = [j for i in expr.contraction_indices for j in i]
    # Contracted axes are not free in the result, so they are not reported:
    removed = [i for i in removed if i not in contraction_indices_flat]
    new_contraction_indices = [tuple(j - shifts[j] for j in i) for i in new_contraction_indices]
    # Shift removed:
    removed = ArrayContraction._push_indices_up(expr.contraction_indices, removed)
    return ArrayContraction(newexpr, *new_contraction_indices), list(removed)
@_remove_trivial_dims.register(ArrayDiagonal)
def _(expr: ArrayDiagonal):
    """Drop trivial axes of the diagonalized object and re-index diagonals."""
    newexpr, removed = _remove_trivial_dims(expr.expr)
    # shifts[j]: number of removed axes strictly before axis ``j``.
    shifts = list(accumulate([0] + [1 if i in removed else 0 for i in range(get_rank(expr.expr))]))
    new_diag_indices = [tuple(j for j in i if j not in removed) for i in expr.diagonal_indices]
    new_diag_indices = [tuple(j - shifts[j] for j in i) for i in new_diag_indices]
    rank = get_rank(expr.expr)
    removed = ArrayDiagonal._push_indices_up(expr.diagonal_indices, removed, rank)
    removed = sorted(set(removed))
    # If there are single axes to diagonalize remaining, it means that their
    # corresponding dimension has been removed, they no longer need diagonalization:
    new_diag_indices = [i for i in new_diag_indices if len(i) > 1]
    return ArrayDiagonal(newexpr, *new_diag_indices), removed


@_remove_trivial_dims.register(ElementwiseApplyFunction)
def _(expr: ElementwiseApplyFunction):
    """
    Drop trivial axes of the argument of an elementwise function.

    Like every ``_remove_trivial_dims`` handler this returns a
    ``(new_expr, removed)`` pair.
    """
    subexpr, removed = _remove_trivial_dims(expr.expr)
    if subexpr.shape == (1, 1):
        # TODO: move this to ElementwiseApplyFunction
        return expr.function(subexpr), removed + [0, 1]
    # Elementwise application keeps the argument's shape, so the axes
    # removed from the argument are exactly those removed from the result.
    return ElementwiseApplyFunction(expr.function, subexpr), removed


@_remove_trivial_dims.register(ArrayElementwiseApplyFunc)
def _(expr: ArrayElementwiseApplyFunc):
    """Drop trivial axes of the argument of an elementwise array function."""
    subexpr, removed = _remove_trivial_dims(expr.expr)
    return ArrayElementwiseApplyFunc(expr.function, subexpr), removed


def convert_array_to_matrix(expr):
    r"""
    Recognize matrix expressions in codegen objects.

    If more than one matrix multiplication line has been detected, return an
    ``ArrayTensorProduct`` whose factors are the matrix expressions.

    Examples
    ========

    >>> from sympy.tensor.array.expressions.conv_indexed_to_array import convert_indexed_to_array
    >>> from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct
    >>> from sympy import MatrixSymbol, Sum
    >>> from sympy.abc import i, j, k, l, N
    >>> from sympy.tensor.array.expressions.array_expressions import ArrayContraction
    >>> from sympy.tensor.array.expressions.conv_matrix_to_array import convert_matrix_to_array
    >>> from sympy.tensor.array.expressions.conv_array_to_matrix import convert_array_to_matrix
    >>> A = MatrixSymbol("A", N, N)
    >>> B = MatrixSymbol("B", N, N)
    >>> C = MatrixSymbol("C", N, N)
    >>> D = MatrixSymbol("D", N, N)

    >>> expr = Sum(A[i, j]*B[j, k], (j, 0, N-1))
    >>> cg = convert_indexed_to_array(expr)
    >>> convert_array_to_matrix(cg)
    A*B
    >>> cg = convert_indexed_to_array(expr, first_indices=[k])
    >>> convert_array_to_matrix(cg)
    B.T*A.T

    Transposition is detected:

    >>> expr = Sum(A[j, i]*B[j, k], (j, 0, N-1))
    >>> cg = convert_indexed_to_array(expr)
    >>> convert_array_to_matrix(cg)
    A.T*B
    >>> cg = convert_indexed_to_array(expr, first_indices=[k])
    >>> convert_array_to_matrix(cg)
    B.T*A

    Detect the trace:

    >>> expr = Sum(A[i, i], (i, 0, N-1))
    >>> cg = convert_indexed_to_array(expr)
    >>> convert_array_to_matrix(cg)
    Trace(A)

    Recognize some more complex traces:

    >>> expr = Sum(A[i, j]*B[j, i], (i, 0, N-1), (j, 0, N-1))
    >>> cg = convert_indexed_to_array(expr)
    >>> convert_array_to_matrix(cg)
    Trace(A*B)

    More complicated expressions:

    >>> expr = Sum(A[i, j]*B[k, j]*A[l, k], (j, 0, N-1), (k, 0, N-1))
    >>> cg = convert_indexed_to_array(expr)
    >>> convert_array_to_matrix(cg)
    A*B.T*A.T

    Expressions constructed from matrix expressions do not contain literal
    indices, the positions of free indices are returned instead:

    >>> expr = A*B
    >>> cg = convert_matrix_to_array(expr)
    >>> convert_array_to_matrix(cg)
    A*B

    If more than one line of matrix multiplications is detected, return
    separate matrix multiplication factors embedded in a tensor product object:

    >>> cg = ArrayContraction(ArrayTensorProduct(A, B, C, D), (1, 2), (5, 6))
    >>> convert_array_to_matrix(cg)
    ArrayTensorProduct(A*B, C*D)

    The two lines have free indices at axes 0, 3 and 4, 7, respectively.
    """
    rec = _array2matrix(expr)
    # The list of removed trivial axes is not needed by this public API.
    rec, _ = _remove_trivial_dims(rec)
    return rec


def _array_diag2contr_diagmatrix(expr: ArrayDiagonal):
    """
    Rewrite diagonals involving a vector-shaped matrix as a ``DiagMatrix``.

    For a diagonal over two axes of rank-2 factors where one factor has a
    size-1 other axis, that factor is replaced by ``OneArray`` and a
    ``DiagMatrix`` of it is appended and contracted with the other factor.
    Only diagonals of an ``ArrayTensorProduct`` are rewritten; anything else
    is returned unchanged.
    """
    if isinstance(expr.expr, ArrayTensorProduct):
        args = list(expr.expr.args)
        diag_indices = list(expr.diagonal_indices)
        mapping = _get_mapping_from_subranks([_get_subrank(arg) for arg in args])
        # For every diagonal, the (factor, axis-in-factor) pairs it links:
        tuple_links = [[mapping[j] for j in i] for i in diag_indices]
        contr_indices = []
        replaced = [False] * len(args)
        for i, (abs_pos, rel_pos) in enumerate(zip(diag_indices, tuple_links)):
            if len(abs_pos) != 2:
                continue
            (pos1_outer, pos1_inner), (pos2_outer, pos2_inner) = rel_pos
            arg1 = args[pos1_outer]
            arg2 = args[pos2_outer]
            if get_rank(arg1) != 2 or get_rank(arg2) != 2:
                # A factor already replaced by OneArray no longer carries
                # this diagonal:
                if replaced[pos1_outer]:
                    diag_indices[i] = None
                if replaced[pos2_outer]:
                    diag_indices[i] = None
                continue
            # The other axis of each matrix:
            pos1_in2 = 1 - pos1_inner
            pos2_in2 = 1 - pos2_inner
            if arg1.shape[pos1_in2] == 1:
                darg1 = DiagMatrix(arg1)
                args.append(darg1)
                contr_indices.append(((pos2_outer, pos2_inner), (len(args)-1, pos1_inner)))
                diag_indices[i] = None
                args[pos1_outer] = OneArray(arg1.shape[pos1_in2])
                replaced[pos1_outer] = True
            elif arg2.shape[pos2_in2] == 1:
                darg2 = DiagMatrix(arg2)
                args.append(darg2)
                contr_indices.append(((pos1_outer, pos1_inner), (len(args)-1, pos2_inner)))
                diag_indices[i] = None
                args[pos2_outer] = OneArray(arg2.shape[pos2_in2])
                replaced[pos2_outer] = True
        diag_indices_new = [i for i in diag_indices if i is not None]
        cumul = list(accumulate([0] + [get_rank(arg) for arg in args]))
        # Convert (factor, axis) pairs into absolute axis positions:
        contr_indices2 = [tuple(cumul[a] + b for a, b in i) for i in contr_indices]
        tc = ArrayContraction(
            ArrayTensorProduct(*args), *contr_indices2
        )
        td = ArrayDiagonal(tc, *diag_indices_new)
        return td
    return expr


def _a2m_mul(*args):
    """Multiply ``args``: ``MatMul`` for matrices, else a chained contraction."""
    if all(not isinstance(i, _CodegenArrayAbstract) for i in args):
        from sympy import MatMul
        return MatMul(*args).doit()
    else:
        return ArrayContraction(
            ArrayTensorProduct(*args),
            *[(2*i-1, 2*i) for i in range(1, len(args))]
        )


def _a2m_tensor_product(*args):
    """
    Tensor product of ``args`` with scalar factors pulled together.

    Returns the product of the scalars alone if no array/matrix argument is
    given. Otherwise a scalar different from 1 is multiplied into the first
    matrix, or prepended as a factor when the first factor is an array
    expression.
    """
    scalars = []
    arrays = []
    for arg in args:
        if isinstance(arg, (MatrixExpr, _ArrayExpr, _CodegenArrayAbstract)):
            arrays.append(arg)
        else:
            scalars.append(arg)
    scalar = Mul.fromiter(scalars)
    if len(arrays) == 0:
        return scalar
    if scalar != 1:
        if isinstance(arrays[0], _CodegenArrayAbstract):
            arrays = [scalar] + arrays
        else:
            arrays[0] *= scalar
    return ArrayTensorProduct(*arrays)


def _a2m_add(*args):
    """Add ``args``: ``MatAdd`` when no array expression is involved."""
    if all(not isinstance(i, _CodegenArrayAbstract) for i in args):
        from sympy import MatAdd
        return MatAdd(*args).doit()
    else:
        return ArrayAdd(*args)


def _a2m_trace(arg):
    """Trace of ``arg``: ``Trace`` for matrices, else a (0, 1) contraction."""
    if isinstance(arg, _CodegenArrayAbstract):
        return ArrayContraction(arg, (0, 1))
    else:
        return Trace(arg)


def _a2m_transpose(arg):
    """Transpose ``arg``: ``Transpose`` for matrices, else ``PermuteDims``."""
    if isinstance(arg, _CodegenArrayAbstract):
        return PermuteDims(arg, [1, 0])
    else:
        return Transpose(arg).doit()


def identify_hadamard_products(expr: Union[ArrayContraction, ArrayDiagonal]):
    """
    Replace groups of matrices sharing the same index pair by a Hadamard product.

    Two kinds of grouping are recognized: matrices whose indices are both
    contracted or diagonalized together, and fully contracted groups, which
    become a ``Trace`` of the Hadamard product.

    For an ``ArrayDiagonal``, the diagonal axes are temporarily encoded as
    negative indices in the editor; the diagonal and axis order are rebuilt
    afterwards. Objects that are neither ``ArrayContraction`` nor
    ``ArrayDiagonal`` are returned unchanged.
    """
    if not isinstance(expr, (ArrayContraction, ArrayDiagonal)):
        # Nothing to rewrite (e.g. an upstream constructor already returned
        # a different type); avoid failing on ``subranks``/``editor`` below.
        return expr

    mapping = _get_mapping_from_subranks(expr.subranks)

    editor: _EditArrayContraction
    if isinstance(expr, ArrayContraction):
        editor = _EditArrayContraction(expr)
    elif isinstance(expr, ArrayDiagonal):
        if isinstance(expr.expr, ArrayContraction):
            editor = _EditArrayContraction(expr.expr)
            diagonalized = ArrayContraction._push_indices_down(expr.expr.contraction_indices, expr.diagonal_indices)
        elif isinstance(expr.expr, ArrayTensorProduct):
            editor = _EditArrayContraction(None)
            editor.args_with_ind = [_ArgE(arg) for arg in expr.expr.args]
            diagonalized = expr.diagonal_indices
        else:
            return expr

        # Trick: add diagonalized indices as negative indices into the editor object:
        for i, e in enumerate(diagonalized):
            for j in e:
                arg_pos, rel_pos = mapping[j]
                editor.args_with_ind[arg_pos].indices[rel_pos] = -1 - i

    # Group fully-indexed arguments by their set of indices, and count how
    # many times each index occurs overall:
    map_contr_to_args: Dict[FrozenSet, List[_ArgE]] = defaultdict(list)
    map_ind_to_inds = defaultdict(int)
    for arg_with_ind in editor.args_with_ind:
        for ind in arg_with_ind.indices:
            map_ind_to_inds[ind] += 1
        if None in arg_with_ind.indices:
            continue
        map_contr_to_args[frozenset(arg_with_ind.indices)].append(arg_with_ind)

    def check_transpose(x):
        # True if the (decoded) indices appear in increasing order, i.e. the
        # argument does not need to be transposed:
        x = [i if i >= 0 else -1-i for i in x]
        return x == sorted(x)

    k: FrozenSet[int]
    v: List[_ArgE]
    for k, v in map_contr_to_args.items():
        make_trace: bool = False
        if len(k) == 1 and next(iter(k)) >= 0 and sum([next(iter(k)) in i for i in map_contr_to_args]) == 1:
            # This is a trace: the arguments are fully contracted with only one
            # index, and the index isn't used anywhere else:
            make_trace = True
            first_element = S.One
        elif len(k) != 2:
            # Hadamard product only defined for matrices:
            continue
        if len(v) == 1:
            # Hadamard product with a single argument makes no sense:
            continue
        # Note: index pairs shared by exactly two arguments are deliberately
        # not skipped; they yield traces (check below) or two-factor
        # Hadamard products.

        # Check if expression is a trace:
        if all([map_ind_to_inds[j] == len(v) and j >= 0 for j in k]) and all([j >= 0 for j in k]):
            # This is a trace
            make_trace = True
            first_element = v[0].element
            if not check_transpose(v[0].indices):
                first_element = first_element.T
            hadamard_factors = v[1:]
        else:
            hadamard_factors = v

        # This is a Hadamard product:

        hp = hadamard_product(*[i.element if check_transpose(i.indices) else Transpose(i.element) for i in hadamard_factors])
        hp_indices = v[0].indices
        if not check_transpose(hadamard_factors[0].indices):
            hp_indices = list(reversed(hp_indices))
        if make_trace:
            hp = Trace(first_element*hp.T)._normalize()
            hp_indices = []
        editor.insert_after(v[0], _ArgE(hp, hp_indices))
        for i in v:
            editor.args_with_ind.remove(i)

    # Count the ranks of the arguments:
    counter = 0
    # Create a collector for the new diagonal indices:
    diag_indices = defaultdict(list)

    count_index_freq = Counter()
    for arg_with_ind in editor.args_with_ind:
        count_index_freq.update(Counter(arg_with_ind.indices))

    free_index_count = count_index_freq[None]

    # Construct the inverse permutation:
    inv_perm1 = []
    inv_perm2 = []
    # Keep track of which diagonal indices have already been processed:
    done = set()

    # Counter for the diagonal indices:
    counter4 = 0

    for arg_with_ind in editor.args_with_ind:
        # If some diagonalization axes have been removed, they should be
        # permuted in order to keep the permutation.
        # Add permutation here
        counter2 = 0  # counter for the indices
        for i in arg_with_ind.indices:
            if i is None:
                inv_perm1.append(counter4)
                counter2 += 1
                counter4 += 1
                continue
            if i >= 0:
                continue
            # Reconstruct the diagonal indices:
            diag_indices[-1 - i].append(counter + counter2)
            if count_index_freq[i] == 1 and i not in done:
                inv_perm1.append(free_index_count - 1 - i)
                done.add(i)
            elif i not in done:
                inv_perm2.append(free_index_count - 1 - i)
                done.add(i)
            counter2 += 1
        # Remove negative indices to restore a proper editor object:
        arg_with_ind.indices = [i if i is not None and i >= 0 else None for i in arg_with_ind.indices]
        counter += len([i for i in arg_with_ind.indices if i is None or i < 0])

    inverse_permutation = inv_perm1 + inv_perm2
    permutation = _af_invert(inverse_permutation)

    if isinstance(expr, ArrayContraction):
        return editor.to_array_contraction()
    else:
        # Get the diagonal indices after the detection of HadamardProduct in the expression:
        diag_indices_filtered = [tuple(v) for v in diag_indices.values() if len(v) > 1]

        expr1 = editor.to_array_contraction()
        expr2 = ArrayDiagonal(expr1, *diag_indices_filtered)
        expr3 = PermuteDims(expr2, permutation)
        return expr3