1import operator 2from functools import reduce 3import itertools 4from itertools import accumulate 5from typing import Optional, List, Dict 6 7from sympy import Expr, ImmutableDenseNDimArray, S, Symbol, Integer, ZeroMatrix, Basic, tensorproduct, Add, permutedims, \ 8 Tuple, tensordiagonal, Lambda, Dummy, Function, MatrixExpr, NDimArray, Indexed, IndexedBase, default_sort_key, \ 9 tensorcontraction, diagonalize_vector, Mul 10from sympy.matrices.expressions.matexpr import MatrixElement 11from sympy.tensor.array.expressions.utils import _apply_recursively_over_nested_lists, _sort_contraction_indices, \ 12 _get_mapping_from_subranks, _build_push_indices_up_func_transformation, _get_contraction_links, \ 13 _build_push_indices_down_func_transformation 14from sympy.combinatorics import Permutation 15from sympy.combinatorics.permutations import _af_invert 16from sympy.core.sympify import _sympify 17 18 19class _ArrayExpr(Expr): 20 pass 21 22 23class ArraySymbol(_ArrayExpr): 24 """ 25 Symbol representing an array expression 26 """ 27 28 def __new__(cls, symbol, *shape): 29 if isinstance(symbol, str): 30 symbol = Symbol(symbol) 31 # symbol = _sympify(symbol) 32 shape = map(_sympify, shape) 33 obj = Expr.__new__(cls, symbol, *shape) 34 return obj 35 36 @property 37 def name(self): 38 return self._args[0] 39 40 @property 41 def shape(self): 42 return self._args[1:] 43 44 def __getitem__(self, item): 45 return ArrayElement(self, item) 46 47 def as_explicit(self): 48 if any(not isinstance(i, (int, Integer)) for i in self.shape): 49 raise ValueError("cannot express explicit array with symbolic shape") 50 data = [self[i] for i in itertools.product(*[range(j) for j in self.shape])] 51 return ImmutableDenseNDimArray(data).reshape(*self.shape) 52 53 54class ArrayElement(_ArrayExpr): 55 """ 56 An element of an array. 
57 """ 58 def __new__(cls, name, indices): 59 if isinstance(name, str): 60 name = Symbol(name) 61 name = _sympify(name) 62 indices = _sympify(indices) 63 if hasattr(name, "shape"): 64 if any([(i >= s) == True for i, s in zip(indices, name.shape)]): 65 raise ValueError("shape is out of bounds") 66 if any([(i < 0) == True for i in indices]): 67 raise ValueError("shape contains negative values") 68 obj = Expr.__new__(cls, name, indices) 69 return obj 70 71 @property 72 def name(self): 73 return self._args[0] 74 75 @property 76 def indices(self): 77 return self._args[1] 78 79 80class ZeroArray(_ArrayExpr): 81 """ 82 Symbolic array of zeros. Equivalent to ``ZeroMatrix`` for matrices. 83 """ 84 85 def __new__(cls, *shape): 86 if len(shape) == 0: 87 return S.Zero 88 shape = map(_sympify, shape) 89 obj = Expr.__new__(cls, *shape) 90 return obj 91 92 @property 93 def shape(self): 94 return self._args 95 96 def as_explicit(self): 97 if any(not i.is_Integer for i in self.shape): 98 raise ValueError("Cannot return explicit form for symbolic shape.") 99 return ImmutableDenseNDimArray.zeros(*self.shape) 100 101 102class OneArray(_ArrayExpr): 103 """ 104 Symbolic array of ones. 105 """ 106 107 def __new__(cls, *shape): 108 if len(shape) == 0: 109 return S.One 110 shape = map(_sympify, shape) 111 obj = Expr.__new__(cls, *shape) 112 return obj 113 114 @property 115 def shape(self): 116 return self._args 117 118 def as_explicit(self): 119 if any(not i.is_Integer for i in self.shape): 120 raise ValueError("Cannot return explicit form for symbolic shape.") 121 return ImmutableDenseNDimArray([S.One for i in range(reduce(operator.mul, self.shape))]).reshape(*self.shape) 122 123 124class _CodegenArrayAbstract(Basic): 125 126 @property 127 def subranks(self): 128 """ 129 Returns the ranks of the objects in the uppermost tensor product inside 130 the current object. In case no tensor products are contained, return 131 the atomic ranks. 
132 133 Examples 134 ======== 135 136 >>> from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct, ArrayContraction 137 >>> from sympy import MatrixSymbol 138 >>> M = MatrixSymbol("M", 3, 3) 139 >>> N = MatrixSymbol("N", 3, 3) 140 >>> P = MatrixSymbol("P", 3, 3) 141 142 Important: do not confuse the rank of the matrix with the rank of an array. 143 144 >>> tp = ArrayTensorProduct(M, N, P) 145 >>> tp.subranks 146 [2, 2, 2] 147 148 >>> co = ArrayContraction(tp, (1, 2), (3, 4)) 149 >>> co.subranks 150 [2, 2, 2] 151 """ 152 return self._subranks[:] 153 154 def subrank(self): 155 """ 156 The sum of ``subranks``. 157 """ 158 return sum(self.subranks) 159 160 @property 161 def shape(self): 162 return self._shape 163 164 165class ArrayTensorProduct(_CodegenArrayAbstract): 166 r""" 167 Class to represent the tensor product of array-like objects. 168 """ 169 170 def __new__(cls, *args): 171 args = [_sympify(arg) for arg in args] 172 args = cls._flatten(args) 173 ranks = [get_rank(arg) for arg in args] 174 175 # Check if there are nested permutation and lift them up: 176 permutation_cycles = [] 177 for i, arg in enumerate(args): 178 if not isinstance(arg, PermuteDims): 179 continue 180 permutation_cycles.extend([[k + sum(ranks[:i]) for k in j] for j in arg.permutation.cyclic_form]) 181 args[i] = arg.expr 182 if permutation_cycles: 183 return PermuteDims(ArrayTensorProduct(*args), Permutation(sum(ranks)-1)*Permutation(permutation_cycles)) 184 185 if len(args) == 1: 186 return args[0] 187 188 # If any object is a ZeroArray, return a ZeroArray: 189 if any(isinstance(arg, (ZeroArray, ZeroMatrix)) for arg in args): 190 shapes = reduce(operator.add, [get_shape(i) for i in args], ()) 191 return ZeroArray(*shapes) 192 193 # If there are contraction objects inside, transform the whole 194 # expression into `ArrayContraction`: 195 contractions = {i: arg for i, arg in enumerate(args) if isinstance(arg, ArrayContraction)} 196 if contractions: 197 ranks = 
[_get_subrank(arg) if isinstance(arg, ArrayContraction) else get_rank(arg) for arg in args] 198 cumulative_ranks = list(accumulate([0] + ranks))[:-1] 199 tp = cls(*[arg.expr if isinstance(arg, ArrayContraction) else arg for arg in args]) 200 contraction_indices = [tuple(cumulative_ranks[i] + k for k in j) for i, arg in contractions.items() for j in arg.contraction_indices] 201 return ArrayContraction(tp, *contraction_indices) 202 203 diagonals = {i: arg for i, arg in enumerate(args) if isinstance(arg, ArrayDiagonal)} 204 if diagonals: 205 permutation = [] 206 last_perm = [] 207 ranks = [get_rank(arg) for arg in args] 208 cumulative_ranks = list(accumulate([0] + ranks))[:-1] 209 for i, arg in enumerate(args): 210 if isinstance(arg, ArrayDiagonal): 211 i1 = get_rank(arg) - len(arg.diagonal_indices) 212 i2 = len(arg.diagonal_indices) 213 permutation.extend([cumulative_ranks[i] + j for j in range(i1)]) 214 last_perm.extend([cumulative_ranks[i] + j for j in range(i1, i1 + i2)]) 215 else: 216 permutation.extend([cumulative_ranks[i] + j for j in range(get_rank(arg))]) 217 permutation.extend(last_perm) 218 tp = cls(*[arg.expr if isinstance(arg, ArrayDiagonal) else arg for arg in args]) 219 ranks2 = [_get_subrank(arg) if isinstance(arg, ArrayDiagonal) else get_rank(arg) for arg in args] 220 cumulative_ranks2 = list(accumulate([0] + ranks2))[:-1] 221 diagonal_indices = [tuple(cumulative_ranks2[i] + k for k in j) for i, arg in diagonals.items() for j in arg.diagonal_indices] 222 return PermuteDims(ArrayDiagonal(tp, *diagonal_indices), permutation) 223 224 obj = Basic.__new__(cls, *args) 225 obj._subranks = ranks 226 shapes = [get_shape(i) for i in args] 227 228 if any(i is None for i in shapes): 229 obj._shape = None 230 else: 231 obj._shape = tuple(j for i in shapes for j in i) 232 return obj 233 234 @classmethod 235 def _flatten(cls, args): 236 args = [i for arg in args for i in (arg.args if isinstance(arg, cls) else [arg])] 237 return args 238 239 def as_explicit(self): 
240 return tensorproduct(*[arg.as_explicit() if hasattr(arg, "as_explicit") else arg for arg in self.args]) 241 242 243class ArrayAdd(_CodegenArrayAbstract): 244 r""" 245 Class for elementwise array additions. 246 """ 247 248 def __new__(cls, *args): 249 args = [_sympify(arg) for arg in args] 250 ranks = [get_rank(arg) for arg in args] 251 ranks = list(set(ranks)) 252 if len(ranks) != 1: 253 raise ValueError("summing arrays of different ranks") 254 shapes = [arg.shape for arg in args] 255 if len({i for i in shapes if i is not None}) > 1: 256 raise ValueError("mismatching shapes in addition") 257 258 # Flatten: 259 args = cls._flatten_args(args) 260 261 args = [arg for arg in args if not isinstance(arg, (ZeroArray, ZeroMatrix))] 262 if len(args) == 0: 263 if any(i for i in shapes if i is None): 264 raise NotImplementedError("cannot handle addition of ZeroMatrix/ZeroArray and undefined shape object") 265 return ZeroArray(*shapes[0]) 266 elif len(args) == 1: 267 return args[0] 268 269 obj = Basic.__new__(cls, *args) 270 obj._subranks = ranks 271 if any(i is None for i in shapes): 272 obj._shape = None 273 else: 274 obj._shape = shapes[0] 275 return obj 276 277 @classmethod 278 def _flatten_args(cls, args): 279 new_args = [] 280 for arg in args: 281 if isinstance(arg, ArrayAdd): 282 new_args.extend(arg.args) 283 else: 284 new_args.append(arg) 285 return new_args 286 287 def as_explicit(self): 288 return Add.fromiter([arg.as_explicit() for arg in self.args]) 289 290 291class PermuteDims(_CodegenArrayAbstract): 292 r""" 293 Class to represent permutation of axes of arrays. 
294 295 Examples 296 ======== 297 298 >>> from sympy.tensor.array.expressions.array_expressions import PermuteDims 299 >>> from sympy import MatrixSymbol 300 >>> M = MatrixSymbol("M", 3, 3) 301 >>> cg = PermuteDims(M, [1, 0]) 302 303 The object ``cg`` represents the transposition of ``M``, as the permutation 304 ``[1, 0]`` will act on its indices by switching them: 305 306 `M_{ij} \Rightarrow M_{ji}` 307 308 This is evident when transforming back to matrix form: 309 310 >>> from sympy.tensor.array.expressions.conv_array_to_matrix import convert_array_to_matrix 311 >>> convert_array_to_matrix(cg) 312 M.T 313 314 >>> N = MatrixSymbol("N", 3, 2) 315 >>> cg = PermuteDims(N, [1, 0]) 316 >>> cg.shape 317 (2, 3) 318 319 Permutations of tensor products are simplified in order to achieve a 320 standard form: 321 322 >>> from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct 323 >>> M = MatrixSymbol("M", 4, 5) 324 >>> tp = ArrayTensorProduct(M, N) 325 >>> tp.shape 326 (4, 5, 3, 2) 327 >>> perm1 = PermuteDims(tp, [2, 3, 1, 0]) 328 329 The args ``(M, N)`` have been sorted and the permutation has been 330 simplified, the expression is equivalent: 331 332 >>> perm1.expr.args 333 (N, M) 334 >>> perm1.shape 335 (3, 2, 5, 4) 336 >>> perm1.permutation 337 (2 3) 338 339 The permutation in its array form has been simplified from 340 ``[2, 3, 1, 0]`` to ``[0, 1, 3, 2]``, as the arguments of the tensor 341 product `M` and `N` have been switched: 342 343 >>> perm1.permutation.array_form 344 [0, 1, 3, 2] 345 346 We can nest a second permutation: 347 348 >>> perm2 = PermuteDims(perm1, [1, 0, 2, 3]) 349 >>> perm2.shape 350 (2, 3, 5, 4) 351 >>> perm2.permutation.array_form 352 [1, 0, 3, 2] 353 """ 354 355 def __new__(cls, expr, permutation, nest_permutation=True): 356 from sympy.combinatorics import Permutation 357 expr = _sympify(expr) 358 permutation = Permutation(permutation) 359 permutation_size = permutation.size 360 expr_rank = get_rank(expr) 361 if 
permutation_size != expr_rank: 362 raise ValueError("Permutation size must be the length of the shape of expr") 363 if isinstance(expr, PermuteDims): 364 subexpr = expr.expr 365 subperm = expr.permutation 366 permutation = permutation * subperm 367 expr = subexpr 368 if isinstance(expr, ArrayContraction): 369 expr, permutation = cls._handle_nested_contraction(expr, permutation) 370 if isinstance(expr, ArrayTensorProduct): 371 expr, permutation = cls._sort_components(expr, permutation) 372 if isinstance(expr, (ZeroArray, ZeroMatrix)): 373 return ZeroArray(*[expr.shape[i] for i in permutation.array_form]) 374 plist = permutation.array_form 375 if plist == sorted(plist): 376 return expr 377 obj = Basic.__new__(cls, expr, permutation) 378 obj._subranks = [get_rank(expr)] 379 shape = expr.shape 380 if shape is None: 381 obj._shape = None 382 else: 383 obj._shape = tuple(shape[permutation(i)] for i in range(len(shape))) 384 return obj 385 386 @property 387 def expr(self): 388 return self.args[0] 389 390 @property 391 def permutation(self): 392 return self.args[1] 393 394 @classmethod 395 def _sort_components(cls, expr, permutation): 396 # Get the permutation in its image-form: 397 perm_image_form = _af_invert(permutation.array_form) 398 args = list(expr.args) 399 # Starting index global position for every arg: 400 cumul = list(accumulate([0] + expr.subranks)) 401 # Split `perm_image_form` into a list of list corresponding to the indices 402 # of every argument: 403 perm_image_form_in_components = [perm_image_form[cumul[i]:cumul[i+1]] for i in range(len(args))] 404 # Create an index, target-position-key array: 405 ps = [(i, sorted(comp)) for i, comp in enumerate(perm_image_form_in_components)] 406 # Sort the array according to the target-position-key: 407 # In this way, we define a canonical way to sort the arguments according 408 # to the permutation. 409 ps.sort(key=lambda x: x[1]) 410 # Read the inverse-permutation (i.e. 
image-form) of the args: 411 perm_args_image_form = [i[0] for i in ps] 412 # Apply the args-permutation to the `args`: 413 args_sorted = [args[i] for i in perm_args_image_form] 414 # Apply the args-permutation to the array-form of the permutation of the axes (of `expr`): 415 perm_image_form_sorted_args = [perm_image_form_in_components[i] for i in perm_args_image_form] 416 new_permutation = Permutation(_af_invert([j for i in perm_image_form_sorted_args for j in i])) 417 return ArrayTensorProduct(*args_sorted), new_permutation 418 419 @classmethod 420 def _handle_nested_contraction(cls, expr, permutation): 421 if not isinstance(expr, ArrayContraction): 422 return expr, permutation 423 if not isinstance(expr.expr, ArrayTensorProduct): 424 return expr, permutation 425 args = expr.expr.args 426 subranks = [get_rank(arg) for arg in expr.expr.args] 427 428 contraction_indices = expr.contraction_indices 429 contraction_indices_flat = [j for i in contraction_indices for j in i] 430 cumul = list(accumulate([0] + subranks)) 431 432 # Spread the permutation in its array form across the args in the corresponding 433 # tensor-product arguments with free indices: 434 permutation_array_blocks_up = [] 435 image_form = _af_invert(permutation.array_form) 436 counter = 0 437 for i, e in enumerate(subranks): 438 current = [] 439 for j in range(cumul[i], cumul[i+1]): 440 if j in contraction_indices_flat: 441 continue 442 current.append(image_form[counter]) 443 counter += 1 444 permutation_array_blocks_up.append(current) 445 446 # Get the map of axis repositioning for every argument of tensor-product: 447 index_blocks = [[j for j in range(cumul[i], cumul[i+1])] for i, e in enumerate(expr.subranks)] 448 index_blocks_up = expr._push_indices_up(expr.contraction_indices, index_blocks) 449 inverse_permutation = permutation**(-1) 450 index_blocks_up_permuted = [[inverse_permutation(j) for j in i if j is not None] for i in index_blocks_up] 451 452 # Sorting key is a list of tuple, first element 
is the index of `args`, second element of 453 # the tuple is the sorting key to sort `args` of the tensor product: 454 sorting_keys = list(enumerate(index_blocks_up_permuted)) 455 sorting_keys.sort(key=lambda x: x[1]) 456 457 # Now we can get the permutation acting on the args in its image-form: 458 new_perm_image_form = [i[0] for i in sorting_keys] 459 # Apply the args-level permutation to various elements: 460 new_index_blocks = [index_blocks[i] for i in new_perm_image_form] 461 new_index_perm_array_form = _af_invert([j for i in new_index_blocks for j in i]) 462 new_args = [args[i] for i in new_perm_image_form] 463 new_contraction_indices = [tuple(new_index_perm_array_form[j] for j in i) for i in contraction_indices] 464 new_expr = ArrayContraction(ArrayTensorProduct(*new_args), *new_contraction_indices) 465 new_permutation = Permutation(_af_invert([j for i in [permutation_array_blocks_up[k] for k in new_perm_image_form] for j in i])) 466 return new_expr, new_permutation 467 468 @classmethod 469 def _check_permutation_mapping(cls, expr, permutation): 470 subranks = expr.subranks 471 index2arg = [i for i, arg in enumerate(expr.args) for j in range(expr.subranks[i])] 472 permuted_indices = [permutation(i) for i in range(expr.subrank())] 473 new_args = list(expr.args) 474 arg_candidate_index = index2arg[permuted_indices[0]] 475 current_indices = [] 476 new_permutation = [] 477 inserted_arg_cand_indices = set([]) 478 for i, idx in enumerate(permuted_indices): 479 if index2arg[idx] != arg_candidate_index: 480 new_permutation.extend(current_indices) 481 current_indices = [] 482 arg_candidate_index = index2arg[idx] 483 current_indices.append(idx) 484 arg_candidate_rank = subranks[arg_candidate_index] 485 if len(current_indices) == arg_candidate_rank: 486 new_permutation.extend(sorted(current_indices)) 487 local_current_indices = [j - min(current_indices) for j in current_indices] 488 i1 = index2arg[i] 489 new_args[i1] = PermuteDims(new_args[i1], 
Permutation(local_current_indices)) 490 inserted_arg_cand_indices.add(arg_candidate_index) 491 current_indices = [] 492 new_permutation.extend(current_indices) 493 494 # TODO: swap args positions in order to simplify the expression: 495 # TODO: this should be in a function 496 args_positions = list(range(len(new_args))) 497 # Get possible shifts: 498 maps = {} 499 cumulative_subranks = [0] + list(accumulate(subranks)) 500 for i in range(0, len(subranks)): 501 s = set([index2arg[new_permutation[j]] for j in range(cumulative_subranks[i], cumulative_subranks[i+1])]) 502 if len(s) != 1: 503 continue 504 elem = next(iter(s)) 505 if i != elem: 506 maps[i] = elem 507 508 # Find cycles in the map: 509 lines = [] 510 current_line = [] 511 while maps: 512 if len(current_line) == 0: 513 k, v = maps.popitem() 514 current_line.append(k) 515 else: 516 k = current_line[-1] 517 if k not in maps: 518 current_line = [] 519 continue 520 v = maps.pop(k) 521 if v in current_line: 522 lines.append(current_line) 523 current_line = [] 524 continue 525 current_line.append(v) 526 for line in lines: 527 for i, e in enumerate(line): 528 args_positions[line[(i + 1) % len(line)]] = e 529 530 # TODO: function in order to permute the args: 531 permutation_blocks = [[new_permutation[cumulative_subranks[i] + j] for j in range(e)] for i, e in enumerate(subranks)] 532 new_args = [new_args[i] for i in args_positions] 533 new_permutation_blocks = [permutation_blocks[i] for i in args_positions] 534 new_permutation2 = [j for i in new_permutation_blocks for j in i] 535 return ArrayTensorProduct(*new_args), Permutation(new_permutation2) # **(-1) 536 537 @classmethod 538 def _check_if_there_are_closed_cycles(cls, expr, permutation): 539 args = list(expr.args) 540 subranks = expr.subranks 541 cyclic_form = permutation.cyclic_form 542 cumulative_subranks = [0] + list(accumulate(subranks)) 543 cyclic_min = [min(i) for i in cyclic_form] 544 cyclic_max = [max(i) for i in cyclic_form] 545 cyclic_keep = [] 546 for 
i, cycle in enumerate(cyclic_form): 547 flag = True 548 for j in range(0, len(cumulative_subranks) - 1): 549 if cyclic_min[i] >= cumulative_subranks[j] and cyclic_max[i] < cumulative_subranks[j+1]: 550 # Found a sinkable cycle. 551 args[j] = PermuteDims(args[j], Permutation([[k - cumulative_subranks[j] for k in cyclic_form[i]]])) 552 flag = False 553 break 554 if flag: 555 cyclic_keep.append(cyclic_form[i]) 556 return ArrayTensorProduct(*args), Permutation(cyclic_keep, size=permutation.size) 557 558 def nest_permutation(self): 559 r""" 560 DEPRECATED. 561 """ 562 ret = self._nest_permutation(self.expr, self.permutation) 563 if ret is None: 564 return self 565 return ret 566 567 @classmethod 568 def _nest_permutation(cls, expr, permutation): 569 if isinstance(expr, ArrayTensorProduct): 570 return PermuteDims(*cls._check_if_there_are_closed_cycles(expr, permutation)) 571 elif isinstance(expr, ArrayContraction): 572 # Invert tree hierarchy: put the contraction above. 573 cycles = permutation.cyclic_form 574 newcycles = ArrayContraction._convert_outer_indices_to_inner_indices(expr, *cycles) 575 newpermutation = Permutation(newcycles) 576 new_contr_indices = [tuple(newpermutation(j) for j in i) for i in expr.contraction_indices] 577 return ArrayContraction(PermuteDims(expr.expr, newpermutation), *new_contr_indices) 578 elif isinstance(expr, ArrayAdd): 579 return ArrayAdd(*[PermuteDims(arg, permutation) for arg in expr.args]) 580 return None 581 582 def as_explicit(self): 583 return permutedims(self.expr.as_explicit(), self.permutation) 584 585 586class ArrayDiagonal(_CodegenArrayAbstract): 587 r""" 588 Class to represent the diagonal operator. 
589 590 Explanation 591 =========== 592 593 In a 2-dimensional array it returns the diagonal, this looks like the 594 operation: 595 596 `A_{ij} \rightarrow A_{ii}` 597 598 The diagonal over axes 1 and 2 (the second and third) of the tensor product 599 of two 2-dimensional arrays `A \otimes B` is 600 601 `\Big[ A_{ab} B_{cd} \Big]_{abcd} \rightarrow \Big[ A_{ai} B_{id} \Big]_{adi}` 602 603 In this last example the array expression has been reduced from 604 4-dimensional to 3-dimensional. Notice that no contraction has occurred, 605 rather there is a new index `i` for the diagonal, contraction would have 606 reduced the array to 2 dimensions. 607 608 Notice that the diagonalized out dimensions are added as new dimensions at 609 the end of the indices. 610 """ 611 612 def __new__(cls, expr, *diagonal_indices): 613 expr = _sympify(expr) 614 diagonal_indices = [Tuple(*sorted(i)) for i in diagonal_indices] 615 if isinstance(expr, ArrayAdd): 616 return ArrayAdd(*[ArrayDiagonal(arg, *diagonal_indices) for arg in expr.args]) 617 if isinstance(expr, ArrayDiagonal): 618 return cls._flatten(expr, *diagonal_indices) 619 if isinstance(expr, PermuteDims): 620 return cls._handle_nested_permutedims_in_diag(expr, *diagonal_indices) 621 shape = get_shape(expr) 622 if shape is not None: 623 cls._validate(expr, *diagonal_indices) 624 # Get new shape: 625 positions, shape = cls._get_positions_shape(shape, diagonal_indices) 626 else: 627 positions = None 628 if len(diagonal_indices) == 0: 629 return expr 630 if isinstance(expr, (ZeroArray, ZeroMatrix)): 631 return ZeroArray(*shape) 632 obj = Basic.__new__(cls, expr, *diagonal_indices) 633 obj._positions = positions 634 obj._subranks = _get_subranks(expr) 635 obj._shape = shape 636 return obj 637 638 @staticmethod 639 def _validate(expr, *diagonal_indices): 640 # Check that no diagonalization happens on indices with mismatched 641 # dimensions: 642 shape = get_shape(expr) 643 for i in diagonal_indices: 644 if len({shape[j] for j in i}) 
!= 1: 645 raise ValueError("diagonalizing indices of different dimensions") 646 if len(i) <= 1: 647 raise ValueError("need at least two axes to diagonalize") 648 649 @staticmethod 650 def _remove_trivial_dimensions(shape, *diagonal_indices): 651 return [tuple(j for j in i) for i in diagonal_indices if shape[i[0]] != 1] 652 653 @property 654 def expr(self): 655 return self.args[0] 656 657 @property 658 def diagonal_indices(self): 659 return self.args[1:] 660 661 @staticmethod 662 def _flatten(expr, *outer_diagonal_indices): 663 inner_diagonal_indices = expr.diagonal_indices 664 all_inner = [j for i in inner_diagonal_indices for j in i] 665 all_inner.sort() 666 # TODO: add API for total rank and cumulative rank: 667 total_rank = _get_subrank(expr) 668 inner_rank = len(all_inner) 669 outer_rank = total_rank - inner_rank 670 shifts = [0 for i in range(outer_rank)] 671 counter = 0 672 pointer = 0 673 for i in range(outer_rank): 674 while pointer < inner_rank and counter >= all_inner[pointer]: 675 counter += 1 676 pointer += 1 677 shifts[i] += pointer 678 counter += 1 679 outer_diagonal_indices = tuple(tuple(shifts[j] + j for j in i) for i in outer_diagonal_indices) 680 diagonal_indices = inner_diagonal_indices + outer_diagonal_indices 681 return ArrayDiagonal(expr.expr, *diagonal_indices) 682 683 @classmethod 684 def _handle_nested_permutedims_in_diag(cls, expr: PermuteDims, *diagonal_indices): 685 back_diagonal_indices = [[expr.permutation(j) for j in i] for i in diagonal_indices] 686 nondiag = [i for i in range(get_rank(expr)) if not any(i in j for j in diagonal_indices)] 687 back_nondiag = [expr.permutation(i) for i in nondiag] 688 remap = {e: i for i, e in enumerate(sorted(back_nondiag))} 689 new_permutation1 = [remap[i] for i in back_nondiag] 690 shift = len(new_permutation1) 691 diag_block_perm = [i + shift for i in range(len(back_diagonal_indices))] 692 new_permutation = new_permutation1 + diag_block_perm 693 return PermuteDims( 694 ArrayDiagonal( 695 expr.expr, 
696 *back_diagonal_indices 697 ), 698 new_permutation 699 ) 700 701 def _push_indices_down_nonstatic(self, indices): 702 transform = lambda x: self._positions[x] if x < len(self._positions) else None 703 return _apply_recursively_over_nested_lists(transform, indices) 704 705 def _push_indices_up_nonstatic(self, indices): 706 707 def transform(x): 708 for i, e in enumerate(self._positions): 709 if (isinstance(e, int) and x == e) or (isinstance(e, tuple) and x in e): 710 return i 711 712 return _apply_recursively_over_nested_lists(transform, indices) 713 714 @classmethod 715 def _push_indices_down(cls, diagonal_indices, indices, rank): 716 positions, shape = cls._get_positions_shape(range(rank), diagonal_indices) 717 transform = lambda x: positions[x] if x < len(positions) else None 718 return _apply_recursively_over_nested_lists(transform, indices) 719 720 @classmethod 721 def _push_indices_up(cls, diagonal_indices, indices, rank): 722 positions, shape = cls._get_positions_shape(range(rank), diagonal_indices) 723 724 def transform(x): 725 for i, e in enumerate(positions): 726 if (isinstance(e, int) and x == e) or (isinstance(e, (tuple, Tuple)) and (x in e)): 727 return i 728 729 return _apply_recursively_over_nested_lists(transform, indices) 730 731 @classmethod 732 def _get_positions_shape(cls, shape, diagonal_indices): 733 data1 = tuple((i, shp) for i, shp in enumerate(shape) if not any(i in j for j in diagonal_indices)) 734 pos1, shp1 = zip(*data1) if data1 else ((), ()) 735 data2 = tuple((i, shape[i[0]]) for i in diagonal_indices) 736 pos2, shp2 = zip(*data2) if data2 else ((), ()) 737 positions = pos1 + pos2 738 shape = shp1 + shp2 739 return positions, shape 740 741 def as_explicit(self): 742 return tensordiagonal(self.expr.as_explicit(), *self.diagonal_indices) 743 744 745class ArrayElementwiseApplyFunc(_CodegenArrayAbstract): 746 747 def __new__(cls, function, element): 748 749 if not isinstance(function, Lambda): 750 d = Dummy('d') 751 function = Lambda(d, 
function(d)) 752 753 obj = _CodegenArrayAbstract.__new__(cls, function, element) 754 obj._subranks = _get_subranks(element) 755 return obj 756 757 @property 758 def function(self): 759 return self.args[0] 760 761 @property 762 def expr(self): 763 return self.args[1] 764 765 @property 766 def shape(self): 767 return self.expr.shape 768 769 def _get_function_fdiff(self): 770 d = Dummy("d") 771 function = self.function(d) 772 fdiff = function.diff(d) 773 if isinstance(fdiff, Function): 774 fdiff = type(fdiff) 775 else: 776 fdiff = Lambda(d, fdiff) 777 return fdiff 778 779 780class ArrayContraction(_CodegenArrayAbstract): 781 r""" 782 This class is meant to represent contractions of arrays in a form easily 783 processable by the code printers. 784 """ 785 786 def __new__(cls, expr, *contraction_indices, **kwargs): 787 contraction_indices = _sort_contraction_indices(contraction_indices) 788 expr = _sympify(expr) 789 790 if len(contraction_indices) == 0: 791 return expr 792 793 if isinstance(expr, ArrayContraction): 794 return cls._flatten(expr, *contraction_indices) 795 796 if isinstance(expr, (ZeroArray, ZeroMatrix)): 797 contraction_indices_flat = [j for i in contraction_indices for j in i] 798 shape = [e for i, e in enumerate(expr.shape) if i not in contraction_indices_flat] 799 return ZeroArray(*shape) 800 801 if isinstance(expr, PermuteDims): 802 return cls._handle_nested_permute_dims(expr, *contraction_indices) 803 804 if isinstance(expr, ArrayTensorProduct): 805 expr, contraction_indices = cls._sort_fully_contracted_args(expr, contraction_indices) 806 expr, contraction_indices = cls._lower_contraction_to_addends(expr, contraction_indices) 807 if len(contraction_indices) == 0: 808 return expr 809 810 if isinstance(expr, ArrayDiagonal): 811 return cls._handle_nested_diagonal(expr, *contraction_indices) 812 813 if isinstance(expr, ArrayAdd): 814 return ArrayAdd(*[ArrayContraction(i, *contraction_indices) for i in expr.args]) 815 816 obj = Basic.__new__(cls, expr, 
*contraction_indices) 817 obj._subranks = _get_subranks(expr) 818 obj._mapping = _get_mapping_from_subranks(obj._subranks) 819 820 free_indices_to_position = {i: i for i in range(sum(obj._subranks)) if all([i not in cind for cind in contraction_indices])} 821 obj._free_indices_to_position = free_indices_to_position 822 823 shape = expr.shape 824 cls._validate(expr, *contraction_indices) 825 if shape: 826 shape = tuple(shp for i, shp in enumerate(shape) if not any(i in j for j in contraction_indices)) 827 obj._shape = shape 828 return obj 829 830 def __mul__(self, other): 831 if other == 1: 832 return self 833 else: 834 raise NotImplementedError("Product of N-dim arrays is not uniquely defined. Use another method.") 835 836 def __rmul__(self, other): 837 if other == 1: 838 return self 839 else: 840 raise NotImplementedError("Product of N-dim arrays is not uniquely defined. Use another method.") 841 842 @staticmethod 843 def _validate(expr, *contraction_indices): 844 shape = expr.shape 845 if shape is None: 846 return 847 848 # Check that no contraction happens when the shape is mismatched: 849 for i in contraction_indices: 850 if len({shape[j] for j in i if shape[j] != -1}) != 1: 851 raise ValueError("contracting indices of different dimensions") 852 853 @classmethod 854 def _push_indices_down(cls, contraction_indices, indices): 855 flattened_contraction_indices = [j for i in contraction_indices for j in i] 856 flattened_contraction_indices.sort() 857 transform = _build_push_indices_down_func_transformation(flattened_contraction_indices) 858 return _apply_recursively_over_nested_lists(transform, indices) 859 860 @classmethod 861 def _push_indices_up(cls, contraction_indices, indices): 862 flattened_contraction_indices = [j for i in contraction_indices for j in i] 863 flattened_contraction_indices.sort() 864 transform = _build_push_indices_up_func_transformation(flattened_contraction_indices) 865 return _apply_recursively_over_nested_lists(transform, indices) 866 867 
# -- ArrayContraction (continued): helpers that restructure contractions --
    @classmethod
    def _lower_contraction_to_addends(cls, expr, contraction_indices):
        """
        Push contraction groups that lie entirely inside one ``ArrayAdd``
        argument of a tensor product down into that argument.

        Parameters
        ==========

        expr : the array expression being contracted.
        contraction_indices : contraction groups, given as absolute index
            positions of ``expr``.

        Returns
        =======

        A pair ``(new_expr, remaining_contraction_indices)``.

        If ``expr`` is not an ``ArrayTensorProduct``, or no group could be
        pushed down, the inputs are returned unchanged. Otherwise every
        argument of the tensor product is passed through
        ``ArrayContraction(arg, *groups)`` with the groups pushed into it
        (possibly none). The groups left at this level are renumbered to
        account for the index positions that were consumed.

        Raises
        ======

        NotImplementedError
            If ``expr`` itself is an ``ArrayAdd``.
        """
        if isinstance(expr, ArrayAdd):
            raise NotImplementedError()
        if not isinstance(expr, ArrayTensorProduct):
            return expr, contraction_indices
        subranks = expr.subranks
        # cumranks[j] is the absolute position of the first index of argument j.
        cumranks = list(accumulate([0] + subranks))
        contraction_indices_remaining = []
        # One list of (argument-relative) contraction groups per argument.
        contraction_indices_args = [[] for i in expr.args]
        # Absolute positions consumed by groups that were pushed down.
        backshift = set([])
        for i, contraction_group in enumerate(contraction_indices):
            for j in range(len(expr.args)):
                # Only ArrayAdd arguments receive pushed-down contractions.
                if not isinstance(expr.args[j], ArrayAdd):
                    continue
                if all(cumranks[j] <= k < cumranks[j+1] for k in contraction_group):
                    contraction_indices_args[j].append([k - cumranks[j] for k in contraction_group])
                    backshift.update(contraction_group)
                    break
            else:
                # No ArrayAdd argument contains the whole group: keep it here.
                contraction_indices_remaining.append(contraction_group)
        if len(contraction_indices_remaining) == len(contraction_indices):
            # Nothing was pushed down.
            return expr, contraction_indices
        total_rank = get_rank(expr)
        # shifts[p] counts the consumed positions at or before p. Subtracting
        # it renumbers the surviving positions after the consumed ones are removed.
        shifts = list(accumulate([1 if i in backshift else 0 for i in range(total_rank)]))
        contraction_indices_remaining = [Tuple.fromiter(j - shifts[j] for j in i) for i in contraction_indices_remaining]
        ret = ArrayTensorProduct(*[
            ArrayContraction(arg, *contr) for arg, contr in zip(expr.args, contraction_indices_args)
        ])
        return ret, contraction_indices_remaining

    def split_multiple_contractions(self):
        """
        Recognize multiple contractions and attempt to rewrite them as paired-contractions.

        This allows some contractions involving more than two indices to be
        rewritten as multiple contractions involving two indices, thus allowing
        the expression to be rewritten as a matrix multiplication line.

        Examples:

        * `A_ij b_j0 C_jk` ===> `A*DiagMatrix(b)*C`

        Care for:
        - matrix being diagonalized (i.e. `A_ii`)
        - vectors being diagonalized (i.e. `a_i0`)

        Multiple contractions can be split into matrix multiplications if
        not more than two arguments are non-diagonals or non-vectors.
        Vectors get diagonalized while diagonal matrices remain diagonal.
        The non-diagonal matrices can be at the beginning or at the end
        of the final matrix multiplication line.
        """
        from sympy import ask, Q

        # Editable view of this contraction. Its index slots are rewired in
        # place below and turned back into an ArrayContraction at the end.
        editor = _EditArrayContraction(self)

        contraction_indices = self.contraction_indices
        if isinstance(self.expr, ArrayTensorProduct):
            args = list(self.expr.args)
        else:
            args = [self.expr]
        # TODO: unify API, best location in ArrayTensorProduct
        subranks = [get_rank(i) for i in args]
        # TODO: unify API
        mapping = _get_mapping_from_subranks(subranks)
        # (argument position, relative index) -> absolute index position
        reverse_mapping = {v: k for k, v in mapping.items()}

        for indl, links in enumerate(contraction_indices):
            # Only groups joining more than two index slots need splitting.
            if len(links) <= 2:
                continue

            positions = editor.get_mapping_for_index(indl)

            # Also consider the case of diagonal matrices being contracted:
            current_dimension = self.expr.shape[links[0]]

            # Lists of (editor argument, relative index) pairs for this group:
            not_vectors: List[tuple] = []
            vectors: List[tuple] = []
            for arg_ind, rel_ind in positions:
                mat = args[arg_ind]
                # NOTE(review): ``1-rel_ind`` presumes every argument in the
                # group has rank 2.
                other_arg_pos = 1-rel_ind
                other_arg_abs = reverse_mapping[arg_ind, other_arg_pos]
                arg = editor.args_with_ind[arg_ind]
                # The argument must stay a "non-vector" in any of these cases:
                # - it has no dimension of size 1 and is not known to be diagonal;
                # - the contracted dimension is 1 but the argument is not 1x1;
                # - its other index belongs to another contraction group.
                if (((1 not in mat.shape) and (not ask(Q.diagonal(mat)))) or
                    ((current_dimension == 1) is True and mat.shape != (1, 1)) or
                    any([other_arg_abs in l for li, l in enumerate(contraction_indices) if li != indl])
                ):
                    not_vectors.append((arg, rel_ind))
                else:
                    vectors.append((arg, rel_ind))
            if len(not_vectors) > 2:
                # If more than two arguments in the multiple contraction are
                # non-vectors and non-diagonal matrices, we cannot find a way
                # to split this contraction into a matrix multiplication line:
                continue
            # Three cases to handle:
            # - zero non-vectors
            # - one non-vector
            # - two non-vectors
            for v, rel_ind in vectors:
                v.element = diagonalize_vector(v.element)
            # Build the chain with non-vectors at its ends and the
            # (now diagonal) vectors in between.
            vectors_to_loop = not_vectors[:1] + vectors + not_vectors[1:]
            first_not_vector, rel_ind = vectors_to_loop[0]
            new_index = first_not_vector.indices[rel_ind]

            # Each middle link takes the previous contraction number on its
            # contracted slot and a freshly allocated one on its other slot.
            for v, rel_ind in vectors_to_loop[1:-1]:
                v.indices[rel_ind] = new_index
                new_index = editor.get_new_contraction_index()
                assert v.indices.index(None) == 1 - rel_ind
                v.indices[v.indices.index(None)] = new_index

            last_vec, rel_ind = vectors_to_loop[-1]
            last_vec.indices[rel_ind] = new_index
        return editor.to_array_contraction()

    def flatten_contraction_of_diagonal(self):
        """
        Merge diagonal groups into the contraction groups they share an index with.

        If ``self.expr`` is not an ``ArrayDiagonal``, ``self`` is returned
        unchanged. Otherwise the result is a new ``ArrayContraction`` of an
        ``ArrayDiagonal`` over ``self.expr.expr``:

        - every diagonal group that shares an index with a contraction group
          is absorbed into that contraction group;
        - the remaining diagonal groups are kept on the inner diagonal.
        """
        if not isinstance(self.expr, ArrayDiagonal):
            return self
        # Presumably re-expresses the contraction indices in the index
        # numbering of the expression wrapped by the diagonal — TODO confirm
        # against ArrayDiagonal._push_indices_down.
        contraction_down = self.expr._push_indices_down(self.expr.diagonal_indices, self.contraction_indices)
        new_contraction_indices = []
        diagonal_indices = self.expr.diagonal_indices[:]
        for i in contraction_down:
            contraction_group = list(i)
            for j in i:
                # Absorb every still-unused diagonal group containing j.
                diagonal_with = [k for k in diagonal_indices if j in k]
                contraction_group.extend([l for k in diagonal_with for l in k])
                diagonal_indices = [k for k in diagonal_indices if k not in diagonal_with]
            new_contraction_indices.append(sorted(set(contraction_group)))

        new_contraction_indices = ArrayDiagonal._push_indices_up(diagonal_indices, new_contraction_indices)
        return ArrayContraction(
            ArrayDiagonal(
                self.expr.expr,
                *diagonal_indices
            ),
            *new_contraction_indices
        )

    @staticmethod
    def _get_free_indices_to_position_map(free_indices, contraction_indices):
        """
        Map each free index to its absolute position.

        Free indices are assigned, in order, the lowest positions not used by
        any group in ``contraction_indices``.
        """
        free_indices_to_position = {}
        flattened_contraction_indices = [j for i in contraction_indices for j in i]
        counter = 0
        for ind in free_indices:
            # Skip positions occupied by contracted indices.
            while counter in flattened_contraction_indices:
                counter += 1
            free_indices_to_position[ind] = counter
            counter += 1
        return free_indices_to_position

    @staticmethod
    def _get_index_shifts(expr):
        """
        Get the mapping of indices at the positions before the contraction
        occurs.

        Examples
        ========

        >>> from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct
        >>> from sympy.tensor.array.expressions.array_expressions import ArrayContraction
        >>> from sympy import MatrixSymbol
        >>> M = MatrixSymbol("M", 3, 3)
        >>> N = MatrixSymbol("N", 3, 3)
        >>> cg = ArrayContraction(ArrayTensorProduct(M, N), [1, 2])
        >>> cg._get_index_shifts(cg)
        [0, 2]

        Indeed, ``cg`` after the contraction has two dimensions, 0 and 1. They
        need to be shifted by 0 and 2 to get the corresponding positions before
        the contraction (that is, 0 and 3).
        """
        inner_contraction_indices = expr.contraction_indices
        all_inner = [j for i in inner_contraction_indices for j in i]
        all_inner.sort()
        # TODO: add API for total rank and cumulative rank:
        total_rank = _get_subrank(expr)
        inner_rank = len(all_inner)
        outer_rank = total_rank - inner_rank
        shifts = [0 for i in range(outer_rank)]
        counter = 0
        pointer = 0
        for i in range(outer_rank):
            # Step over contracted positions. `pointer` counts how many
            # contracted positions precede the i-th free index.
            while pointer < inner_rank and counter >= all_inner[pointer]:
                counter += 1
                pointer += 1
            shifts[i] += pointer
            counter += 1
        return shifts

    @staticmethod
    def _convert_outer_indices_to_inner_indices(expr, *outer_contraction_indices):
        """
        Translate contraction groups expressed on the contracted ``expr`` into
        positions of ``expr.expr``.

        Each index is offset by the shift returned by ``_get_index_shifts``.
        """
        shifts = ArrayContraction._get_index_shifts(expr)
        outer_contraction_indices = tuple(tuple(shifts[j] + j for j in i) for i in outer_contraction_indices)
        return outer_contraction_indices

    @staticmethod
    def _flatten(expr, *outer_contraction_indices):
        """
        Combine an outer contraction applied on top of the ``ArrayContraction``
        ``expr`` with ``expr``'s own contraction.

        The result is a single ``ArrayContraction`` over ``expr.expr``.
        """
        inner_contraction_indices = expr.contraction_indices
        outer_contraction_indices = ArrayContraction._convert_outer_indices_to_inner_indices(expr, *outer_contraction_indices)
        contraction_indices = inner_contraction_indices + outer_contraction_indices
        return ArrayContraction(expr.expr, *contraction_indices)

    @classmethod
    def _handle_nested_permute_dims(cls, expr, *contraction_indices):
        """
        Rewrite a contraction of the ``PermuteDims`` ``expr`` as a
        ``PermuteDims`` of a contraction.

        Contraction indices are mapped through the permutation. The
        permutation is then restricted to the uncontracted indices and
        renumbered with ``_push_indices_up``.
        """
        permutation = expr.permutation
        plist = permutation.array_form
        new_contraction_indices = [tuple(permutation(j) for j in i) for i in contraction_indices]
        new_plist = [i for i in plist if all(i not in j for j in new_contraction_indices)]
        new_plist = cls._push_indices_up(new_contraction_indices, new_plist)
        return PermuteDims(
            ArrayContraction(expr.expr, *new_contraction_indices),
            Permutation(new_plist)
        )

    @classmethod
    def _handle_nested_diagonal(cls, expr: 'ArrayDiagonal', *contraction_indices):
        """
        Rewrite a contraction of the ``ArrayDiagonal`` ``expr`` as an
        ``ArrayDiagonal`` of an ``ArrayContraction``.

        A diagonal group that shares an index with a contraction group is
        absorbed into that contraction group. The other diagonal groups are
        kept and renumbered with ``_push_indices_up``.
        """
        diagonal_indices = list(expr.diagonal_indices)
        down_contraction_indices = expr._push_indices_down(expr.diagonal_indices, contraction_indices, get_rank(expr.expr))
        # Flatten diagonally contracted indices (entries may be tuples, which
        # presumably stand for a whole diagonal group — TODO confirm):
        down_contraction_indices = [[k for j in i for k in (j if isinstance(j, (tuple, Tuple)) else [j])] for i in down_contraction_indices]
        new_contraction_indices = []
        for contr_indgrp in down_contraction_indices:
            ind = contr_indgrp[:]
            for j, diag_indgrp in enumerate(diagonal_indices):
                # None marks a diagonal group already absorbed.
                if diag_indgrp is None:
                    continue
                if any(i in diag_indgrp for i in contr_indgrp):
                    ind.extend(diag_indgrp)
                    diagonal_indices[j] = None
            new_contraction_indices.append(sorted(set(ind)))

        new_diagonal_indices_down = [i for i in diagonal_indices if i is not None]
        new_diagonal_indices = ArrayContraction._push_indices_up(new_contraction_indices, new_diagonal_indices_down)
        return ArrayDiagonal(
            ArrayContraction(expr.expr, *new_contraction_indices),
            *new_diagonal_indices
        )

    @classmethod
    def _sort_fully_contracted_args(cls, expr, contraction_indices):
        """
        Move fully contracted arguments of the tensor product ``expr`` to the front.

        Fully contracted arguments (every index slot appears in
        ``contraction_indices``) come first, sorted by ``default_sort_key``.
        The other arguments keep their relative order, because ``sorted`` is
        stable. Contraction indices are remapped to the new ordering and
        re-sorted.

        Returns ``(new_tensor_product, new_contraction_indices)``. If
        ``expr.shape`` is None, the inputs are returned unchanged.
        ``expr`` is assumed to expose ``subranks`` and ``args`` like an
        ``ArrayTensorProduct``.
        """
        if expr.shape is None:
            return expr, contraction_indices
        cumul = list(accumulate([0] + expr.subranks))
        # Absolute index positions belonging to each argument.
        index_blocks = [list(range(cumul[i], cumul[i+1])) for i in range(len(expr.args))]
        contraction_indices_flat = {j for i in contraction_indices for j in i}
        fully_contracted = [all(j in contraction_indices_flat for j in range(cumul[i], cumul[i+1])) for i, arg in enumerate(expr.args)]
        new_pos = sorted(range(len(expr.args)), key=lambda x: (0, default_sort_key(expr.args[x])) if fully_contracted[x] else (1,))
        new_args = [expr.args[i] for i in new_pos]
        # new_index_blocks_flat[new position] = old position. Its inverse maps
        # an old position to its new position.
        new_index_blocks_flat = [j for i in new_pos for j in index_blocks[i]]
        index_permutation_array_form = _af_invert(new_index_blocks_flat)
        new_contraction_indices = [tuple(index_permutation_array_form[j] for j in i) for i in contraction_indices]
        new_contraction_indices = _sort_contraction_indices(new_contraction_indices)
        return ArrayTensorProduct(*new_args), new_contraction_indices

    def _get_contraction_tuples(self):
        r"""
        Return tuples containing the argument index and position within the
        argument of the index position.

        Examples
        ========

        >>> from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct
        >>> from sympy import MatrixSymbol
        >>> from sympy.abc import N
        >>> from sympy.tensor.array.expressions.array_expressions import ArrayContraction
        >>> A = MatrixSymbol("A", N, N)
        >>> B = MatrixSymbol("B", N, N)

        >>> cg = ArrayContraction(ArrayTensorProduct(A, B), (1, 2))
        >>> cg._get_contraction_tuples()
        [[(0, 1), (1, 0)]]

        Notes
        =====

        Here the contraction pair `(1, 2)` is transformed into `(0, 1)` and
        `(1, 0)`. The pair `(1, 2)` means that the 2nd and 3rd indices of the
        tensor product `A\otimes B` are contracted; the two tuples identify
        the same indices in a different notation. `(0, 1)` is the second
        index (1) of the first argument (i.e. 0 or `A`). `(1, 0)` is the first
        index (i.e. 0) of the second argument (i.e. 1 or `B`).
        """
        mapping = self._mapping
        return [[mapping[j] for j in i] for i in self.contraction_indices]

    @staticmethod
    def _contraction_tuples_to_contraction_indices(expr, contraction_tuples):
        """
        Inverse of ``_get_contraction_tuples``.

        Converts ``(argument position, relative index)`` pairs back into
        absolute index positions of ``expr``, using the cumulative ranks of
        its arguments.
        """
        # TODO: check that `expr` has `.subranks`:
        ranks = expr.subranks
        cumulative_ranks = [0] + list(accumulate(ranks))
        return [tuple(cumulative_ranks[j]+k for j, k in i) for i in contraction_tuples]

    @property
    def free_indices(self):
        """Return a shallow copy of ``self._free_indices``."""
        # NOTE(review): ``_free_indices`` is not assigned in the visible part
        # of ``__new__``; confirm it is set before this property is used.
        return self._free_indices[:]

    @property
    def free_indices_to_position(self):
        """Return a copy of the free-index -> position dict built in ``__new__``."""
        return dict(self._free_indices_to_position)

    @property
    def expr(self):
        """The array expression being contracted (first argument)."""
        return self.args[0]

    @property
    def contraction_indices(self):
        """The contraction groups (all arguments after the first)."""
        return self.args[1:]

    def _contraction_indices_to_components(self):
        """
        Map every absolute index position of the tensor product to a pair
        ``(argument position, relative index)``.

        All index positions are mapped, contracted or not.

        Raises ``NotImplementedError`` if ``self.expr`` is not an
        ``ArrayTensorProduct``.
        """
        expr = self.expr
        if not isinstance(expr, ArrayTensorProduct):
            raise NotImplementedError("only for contractions of tensor products")
        ranks = expr.subranks
        mapping = {}
        counter = 0
        for i, rank in enumerate(ranks):
            for j in range(rank):
                mapping[counter] = (i, j)
                counter += 1
        return mapping

    def sort_args_by_name(self):
        """
        Sort arguments in the tensor product so that their order is lexicographical.

        Examples
        ========

        >>> from sympy.tensor.array.expressions.conv_matrix_to_array import convert_matrix_to_array
        >>> from sympy import MatrixSymbol
        >>> from sympy.abc import N
        >>> A = MatrixSymbol("A", N, N)
        >>> B = MatrixSymbol("B", N, N)
        >>> C = MatrixSymbol("C", N, N)
        >>> D = MatrixSymbol("D", N, N)

        >>> cg = convert_matrix_to_array(C*D*A*B)
        >>> cg
        ArrayContraction(ArrayTensorProduct(A, D, C, B), (0, 3), (1, 6), (2, 5))
        >>> cg.sort_args_by_name()
        ArrayContraction(ArrayTensorProduct(A, D, B, C), (0, 3), (1, 4), (2, 7))
        """
        # NOTE(review): the example result is not in plain A, B, C, D order.
        # Presumably the ArrayContraction constructor moves the fully
        # contracted arguments (A, D) to the front, cf.
        # _sort_fully_contracted_args — confirm.
        expr = self.expr
        if not isinstance(expr, ArrayTensorProduct):
            return self
        args = expr.args
        sorted_data = sorted(enumerate(args), key=lambda x: default_sort_key(x[1]))
        pos_sorted, args_sorted = zip(*sorted_data)
        # old argument position -> new argument position
        reordering_map = {i: pos_sorted.index(i) for i, arg in enumerate(args)}
        contraction_tuples = self._get_contraction_tuples()
        contraction_tuples = [[(reordering_map[j], k) for j, k in i] for i in contraction_tuples]
        c_tp = ArrayTensorProduct(*args_sorted)
        new_contr_indices = self._contraction_tuples_to_contraction_indices(
                c_tp,
                contraction_tuples
        )
        return ArrayContraction(c_tp, *new_contr_indices)

    def _get_contraction_links(self):
        r"""
        Returns a dictionary of links between arguments in the tensor product
        being contracted.

        See the example for an explanation of the values.

        Examples
        ========

        >>> from sympy import MatrixSymbol
        >>> from sympy.abc import N
        >>> from sympy.tensor.array.expressions.conv_matrix_to_array import convert_matrix_to_array
        >>> A = MatrixSymbol("A", N, N)
        >>> B = MatrixSymbol("B", N, N)
        >>> C = MatrixSymbol("C", N, N)
        >>> D = MatrixSymbol("D", N, N)

        Matrix multiplications are pairwise contractions between neighboring
        matrices:

        `A_{ij} B_{jk} C_{kl} D_{lm}`

        >>> cg = convert_matrix_to_array(A*B*C*D)
        >>> cg
        ArrayContraction(ArrayTensorProduct(B, C, A, D), (0, 5), (1, 2), (3, 6))

        >>> cg._get_contraction_links()
        {0: {0: (2, 1), 1: (1, 0)}, 1: {0: (0, 1), 1: (3, 0)}, 2: {1: (0, 0)}, 3: {0: (1, 1)}}

        This dictionary is interpreted as follows. Note that the arguments of
        the tensor product are ordered `B, C, A, D`.

        The argument in position 0 (i.e. matrix `B`) has its first index
        (i.e. 0) contracted to `(2, 1)`. That is the second index slot of the
        argument in position 2 (matrix `A`), and this is the contraction
        provided by the index `j`. Its second index (i.e. 1) is contracted to
        `(1, 0)`, the first index slot of the argument in position 1 (matrix
        `C`), which is the index `k`.

        The argument in position 1 (that is, matrix `C`) has two contractions,
        provided by the indices `k` and `l`. These are its first and second
        indices (0 and 1 in the sub-dict), with links `(0, 1)` and `(3, 0)`
        respectively. `(0, 1)` is the index slot 1 (the 2nd) of the argument
        in position 0 (that is, `B_{\cdot k}`), and so on.
        """
        # The first returned value (the argument list) is not needed here.
        args, dlinks = _get_contraction_links([self], self.subranks, *self.contraction_indices)
        return dlinks

    def as_explicit(self):
        """Apply ``tensorcontraction`` to the explicit form of ``self.expr``."""
        return tensorcontraction(self.expr.as_explicit(), *self.contraction_indices)


class _ArgE:
    """
    The ``_ArgE`` object contains references to the array expression
    (``.element``) and a list containing the information about index
    contractions (``.indices``).

    Index contractions are numbered and contracted indices show the number of
    the contraction. Uncontracted indices have ``None`` value.

    For example:
    ``_ArgE(M, [None, 3])``
    This object means that expression ``M`` is part of an array contraction
    and has two indices, the first is not contracted (value ``None``),
    the second index is contracted to the 4th (i.e. number ``3``) group of the
    array contraction object.
    """
    def __init__(self, element, indices: Optional[List[Optional[int]]] = None):
        self.element = element
        if indices is None:
            # Default: one uncontracted slot per index, as counted by get_rank.
            self.indices: List[Optional[int]] = [None for i in range(get_rank(element))]
        else:
            # The given list is stored as-is (not copied).
            self.indices: List[Optional[int]] = indices

    def __str__(self):
        return "_ArgE(%s, %s)" % (self.element, self.indices)

    __repr__ = __str__


class _IndPos:
    """
    Index position, requiring two integers in the constructor:

    - arg: the position of the argument in the tensor product,
    - rel: the relative position of the index inside the argument.
    """
    def __init__(self, arg: int, rel: int):
        self.arg = arg
        self.rel = rel

    def __str__(self):
        return "_IndPos(%i, %i)" % (self.arg, self.rel)

    __repr__ = __str__

    def __iter__(self):
        # Allows unpacking, e.g. ``arg_ind, rel_ind = ind_pos``.
        yield from [self.arg, self.rel]


class _EditArrayContraction:
    """
    Utility class to help manipulate array contraction objects.

    This class takes as input an ``ArrayContraction`` object and turns it into
    an editable object.

    The field ``args_with_ind`` of this class is a list of ``_ArgE`` objects
    which can be used to easily edit the contraction structure of the
    expression.

    Once editing is finished, the ``ArrayContraction`` object may be recreated
    by calling the ``.to_array_contraction()`` method.
    """

    def __init__(self, array_contraction: Optional[ArrayContraction]):
        """
        Build the editable representation of ``array_contraction``.

        Passing ``None`` yields an empty editor with no arguments and no
        contraction groups.
        """
        if array_contraction is None:
            self.args_with_ind: List[_ArgE] = []
            self.number_of_contraction_indices: int = 0
            self._track_permutation: Optional[List[List[int]]] = None
            return
        expr = array_contraction.expr
        if isinstance(expr, ArrayTensorProduct):
            args = list(expr.args)
        else:
            args = [expr]
        args_with_ind: List[_ArgE] = [_ArgE(arg) for arg in args]
        mapping = _get_mapping_from_subranks(array_contraction.subranks)
        # Tag every contracted slot with the number of its contraction group.
        for i, contraction_tuple in enumerate(array_contraction.contraction_indices):
            for j in contraction_tuple:
                arg_pos, rel_pos = mapping[j]
                args_with_ind[arg_pos].indices[rel_pos] = i
        self.args_with_ind: List[_ArgE] = args_with_ind
        self.number_of_contraction_indices: int = len(array_contraction.contraction_indices)
        self._track_permutation: Optional[List[List[int]]] = None

    def insert_after(self, arg: _ArgE, new_arg: _ArgE):
        """
        Insert ``new_arg`` right after ``arg`` in ``args_with_ind``.

        ``list.index`` raises ``ValueError`` if ``arg`` is not present.
        """
        pos = self.args_with_ind.index(arg)
        self.args_with_ind.insert(pos + 1, new_arg)

    def get_new_contraction_index(self):
        """Allocate a new contraction group number and return it."""
        self.number_of_contraction_indices += 1
        return self.number_of_contraction_indices - 1

    def refresh_indices(self):
        """
        Renumber the contraction groups in use to ``0 .. k-1``.

        Relative order is preserved and unused group numbers are dropped.
        ``number_of_contraction_indices`` is updated to ``k``.
        """
        updates: Dict[int, int] = {}
        for arg_with_ind in self.args_with_ind:
            updates.update({i: -1 for i in arg_with_ind.indices if i is not None})
        for i, e in enumerate(sorted(updates)):
            updates[e] = i
        self.number_of_contraction_indices: int = len(updates)
        for arg_with_ind in self.args_with_ind:
            arg_with_ind.indices = [updates.get(i, None) for i in arg_with_ind.indices]

    def merge_scalars(self):
        """
        Remove index-less arguments and multiply them into the first remaining
        argument via ``_a2m_tensor_product``.

        If no argument remains, their product becomes the only argument. The
        product is formed even when there are no scalars (an empty ``Mul``,
        presumably 1).
        """
        scalars = []
        for arg_with_ind in self.args_with_ind:
            if len(arg_with_ind.indices) == 0:
                scalars.append(arg_with_ind)
        for i in scalars:
            self.args_with_ind.remove(i)
        scalar = Mul.fromiter([i.element for i in scalars])
        if len(self.args_with_ind) == 0:
            self.args_with_ind.append(_ArgE(scalar))
        else:
            from sympy.tensor.array.expressions.conv_array_to_matrix import _a2m_tensor_product
            self.args_with_ind[0].element = _a2m_tensor_product(scalar, self.args_with_ind[0].element)

    def to_array_contraction(self):
        """
        Rebuild an ``ArrayContraction`` from the current editor state.

        Mutates the editor first: scalars are merged and group numbers are
        refreshed. If permutation tracking was started, the result is wrapped
        in ``PermuteDims`` using the inverse of the flattened tracked
        positions.
        """
        self.merge_scalars()
        self.refresh_indices()
        args = [arg.element for arg in self.args_with_ind]
        contraction_indices = self.get_contraction_indices()
        expr = ArrayContraction(ArrayTensorProduct(*args), *contraction_indices)
        if self._track_permutation is not None:
            permutation = _af_invert([j for i in self._track_permutation for j in i])
            expr = PermuteDims(expr, permutation)
        return expr

    def get_contraction_indices(self) -> List[List[int]]:
        """
        For each contraction group number, return the absolute positions of
        its index slots.

        Positions are counted over the concatenated slots of all arguments.
        """
        contraction_indices: List[List[int]] = [[] for i in range(self.number_of_contraction_indices)]
        current_position: int = 0
        for i, arg_with_ind in enumerate(self.args_with_ind):
            for j in arg_with_ind.indices:
                if j is not None:
                    contraction_indices[j].append(current_position)
                # Every slot, contracted or free, advances the position.
                current_position += 1
        return contraction_indices

    def get_mapping_for_index(self, ind) -> List[_IndPos]:
        """
        Return the ``_IndPos`` of every index slot in contraction group ``ind``.

        Raises ``ValueError`` if ``ind >= number_of_contraction_indices``.
        """
        if ind >= self.number_of_contraction_indices:
            raise ValueError("index value exceeding the index range")
        positions: List[_IndPos] = []
        for i, arg_with_ind in enumerate(self.args_with_ind):
            for j, arg_ind in enumerate(arg_with_ind.indices):
                if ind == arg_ind:
                    positions.append(_IndPos(i, j))
        return positions

    def get_contraction_indices_to_ind_rel_pos(self) -> List[List[_IndPos]]:
        """For every contraction group, list the ``_IndPos`` of its index slots."""
        contraction_indices: List[List[_IndPos]] = [[] for i in range(self.number_of_contraction_indices)]
        for i, arg_with_ind in enumerate(self.args_with_ind):
            for j, ind in enumerate(arg_with_ind.indices):
                if ind is not None:
                    contraction_indices[ind].append(_IndPos(i, j))
        return contraction_indices

def count_args_with_index(self, index: int) -> int: 1433 """ 1434 Count the number of arguments that have the given index. 1435 """ 1436 counter: int = 0 1437 for arg_with_ind in self.args_with_ind: 1438 if index in arg_with_ind.indices: 1439 counter += 1 1440 return counter 1441 1442 def track_permutation_start(self): 1443 self._track_permutation = [] 1444 counter: int = 0 1445 for arg_with_ind in self.args_with_ind: 1446 perm = [] 1447 for i in arg_with_ind.indices: 1448 if i is not None: 1449 continue 1450 perm.append(counter) 1451 counter += 1 1452 self._track_permutation.append(perm) 1453 1454 def track_permutation_merge(self, destination: _ArgE, from_element: _ArgE): 1455 index_destination = self.args_with_ind.index(destination) 1456 index_element = self.args_with_ind.index(from_element) 1457 self._track_permutation[index_destination].extend(self._track_permutation[index_element]) 1458 self._track_permutation.pop(index_element) 1459 1460 1461def get_rank(expr): 1462 if isinstance(expr, (MatrixExpr, MatrixElement)): 1463 return 2 1464 if isinstance(expr, _CodegenArrayAbstract): 1465 return len(expr.shape) 1466 if isinstance(expr, NDimArray): 1467 return expr.rank() 1468 if isinstance(expr, Indexed): 1469 return expr.rank 1470 if isinstance(expr, IndexedBase): 1471 shape = expr.shape 1472 if shape is None: 1473 return -1 1474 else: 1475 return len(shape) 1476 if hasattr(expr, "shape"): 1477 return len(expr.shape) 1478 return 0 1479 1480 1481def _get_subrank(expr): 1482 if isinstance(expr, _CodegenArrayAbstract): 1483 return expr.subrank() 1484 return get_rank(expr) 1485 1486 1487def _get_subranks(expr): 1488 if isinstance(expr, _CodegenArrayAbstract): 1489 return expr.subranks 1490 else: 1491 return [get_rank(expr)] 1492 1493 1494def get_shape(expr): 1495 if hasattr(expr, "shape"): 1496 return expr.shape 1497 return () 1498 1499 1500def nest_permutation(expr): 1501 if isinstance(expr, PermuteDims): 1502 return expr.nest_permutation() 1503 else: 1504 return expr 
1505