"""
Implementation of optimized einsum.

"""
import itertools
import operator

from numpy.core.multiarray import c_einsum
from numpy.core.numeric import asanyarray, tensordot
from numpy.core.overrides import array_function_dispatch

__all__ = ['einsum', 'einsum_path']

einsum_symbols = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
einsum_symbols_set = set(einsum_symbols)


def _flop_count(idx_contraction, inner, num_terms, size_dictionary):
    """
    Computes the number of FLOPS in the contraction.

    Parameters
    ----------
    idx_contraction : iterable
        The indices involved in the contraction
    inner : bool
        Does this contraction require an inner product?  Only the truth
        value is used; callers in this module pass the set of removed
        indices, which is truthy exactly when something is summed over.
    num_terms : int
        The number of terms in a contraction
    size_dictionary : dict
        The size of each of the indices in idx_contraction

    Returns
    -------
    flop_count : int
        The total number of FLOPS required for the contraction.

    Examples
    --------

    >>> _flop_count('abc', False, 1, {'a': 2, 'b':3, 'c':5})
    30

    >>> _flop_count('abc', True, 2, {'a': 2, 'b':3, 'c':5})
    60

    """

    overall_size = _compute_size_by_dict(idx_contraction, size_dictionary)
    op_factor = max(1, num_terms - 1)
    if inner:
        op_factor += 1

    return overall_size * op_factor


def _compute_size_by_dict(indices, idx_dict):
    """
    Computes the product of the elements in indices based on the dictionary
    idx_dict.

    Parameters
    ----------
    indices : iterable
        Indices to base the product on.
    idx_dict : dictionary
        Dictionary of index sizes

    Returns
    -------
    ret : int
        The resulting product.

    Examples
    --------
    >>> _compute_size_by_dict('abbc', {'a': 2, 'b':3, 'c':5})
    90

    """
    ret = 1
    for i in indices:
        ret *= idx_dict[i]
    return ret


def _find_contraction(positions, input_sets, output_set):
    """
    Finds the contraction for a given set of input and output sets.

    Parameters
    ----------
    positions : iterable
        Integer positions of terms used in the contraction.
    input_sets : list
        List of sets that represent the lhs side of the einsum subscript
    output_set : set
        Set that represents the rhs side of the overall einsum subscript

    Returns
    -------
    new_result : set
        The indices of the resulting contraction
    remaining : list
        List of sets that have not been contracted, the new set is appended to
        the end of this list
    idx_removed : set
        Indices removed from the entire contraction
    idx_contraction : set
        The indices used in the current contraction

    Examples
    --------

    # A simple dot product test case
    >>> pos = (0, 1)
    >>> isets = [set('ab'), set('bc')]
    >>> oset = set('ac')
    >>> _find_contraction(pos, isets, oset)
    ({'a', 'c'}, [{'a', 'c'}], {'b'}, {'a', 'b', 'c'})

    # A more complex case with additional terms in the contraction
    >>> pos = (0, 2)
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set('ac')
    >>> _find_contraction(pos, isets, oset)
    ({'a', 'c'}, [{'a', 'c'}, {'a', 'c'}], {'b', 'd'}, {'a', 'b', 'c', 'd'})
    """

    idx_contract = set()
    idx_remain = output_set.copy()
    remaining = []
    for ind, value in enumerate(input_sets):
        if ind in positions:
            idx_contract |= value
        else:
            remaining.append(value)
            idx_remain |= value

    # An index survives the contraction only if something else (another
    # operand or the final output) still needs it.
    new_result = idx_remain & idx_contract
    idx_removed = (idx_contract - new_result)
    remaining.append(new_result)

    return (new_result, remaining, idx_removed, idx_contract)


def _optimal_path(input_sets, output_set, idx_dict, memory_limit):
    """
    Computes all possible pair contractions, sieves the results based
    on ``memory_limit`` and returns the lowest cost path. This algorithm
    scales factorial with respect to the elements in the list ``input_sets``.

    Parameters
    ----------
    input_sets : list
        List of sets that represent the lhs side of the einsum subscript
    output_set : set
        Set that represents the rhs side of the overall einsum subscript
    idx_dict : dictionary
        Dictionary of index sizes
    memory_limit : int
        The maximum number of elements in a temporary array

    Returns
    -------
    path : list
        The optimal contraction order within the memory limit constraint.

    Examples
    --------
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set()
    >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
    >>> _optimal_path(isets, oset, idx_sizes, 5000)
    [(0, 2), (0, 1)]
    """

    # Each candidate is (total_cost, positions_so_far, remaining_input_sets)
    full_results = [(0, [], input_sets)]
    for iteration in range(len(input_sets) - 1):
        iter_results = []
        num_remaining = len(input_sets) - iteration

        # Compute all unique pairs
        for curr in full_results:
            cost, positions, remaining = curr
            for con in itertools.combinations(range(num_remaining), 2):

                # Find the contraction
                cont = _find_contraction(con, remaining, output_set)
                new_result, new_input_sets, idx_removed, idx_contract = cont

                # Sieve the results based on memory_limit
                new_size = _compute_size_by_dict(new_result, idx_dict)
                if new_size > memory_limit:
                    continue

                # Build (total_cost, positions, indices_remaining)
                total_cost = cost + _flop_count(idx_contract, idx_removed,
                                                len(con), idx_dict)
                new_pos = positions + [con]
                iter_results.append((total_cost, new_pos, new_input_sets))

        # Update combinatorial list, if we did not find anything return best
        # path + remaining contractions
        if iter_results:
            full_results = iter_results
        else:
            path = min(full_results, key=lambda x: x[0])[1]
            return path + [tuple(range(num_remaining))]

    # ``full_results`` always holds at least the initial candidate, so the
    # cheapest entry is well defined here.
    path = min(full_results, key=lambda x: x[0])[1]
    return path


def _parse_possible_contraction(positions, input_sets, output_set, idx_dict,
                                memory_limit, path_cost, naive_cost):
    """Compute the cost (removed size + flops) and resultant indices for
    performing the contraction specified by ``positions``.

    Parameters
    ----------
    positions : tuple of int
        The locations of the proposed tensors to contract.
    input_sets : list of sets
        The indices found on each tensors.
    output_set : set
        The output indices of the expression.
    idx_dict : dict
        Mapping of each index to its size.
    memory_limit : int
        The total allowed size for an intermediary tensor.
    path_cost : int
        The contraction cost so far.
    naive_cost : int
        The cost of the unoptimized expression.

    Returns
    -------
    cost : (int, int)
        A tuple containing the size of any indices removed, and the flop cost.
    positions : tuple of int
        The locations of the proposed tensors to contract.
    new_input_sets : list of sets
        The resulting new list of indices if this proposed contraction is
        performed.

    Notes
    -----
    ``None`` is returned instead of the list above when the contraction is
    rejected, either because the intermediate exceeds ``memory_limit`` or
    because the accumulated cost would exceed ``naive_cost``.

    """

    # Find the contraction
    contract = _find_contraction(positions, input_sets, output_set)
    idx_result, new_input_sets, idx_removed, idx_contract = contract

    # Sieve the results based on memory_limit
    new_size = _compute_size_by_dict(idx_result, idx_dict)
    if new_size > memory_limit:
        return None

    # Build sort tuple
    old_sizes = (_compute_size_by_dict(input_sets[p], idx_dict)
                 for p in positions)
    removed_size = sum(old_sizes) - new_size

    # NB: removed_size used to be just the size of any removed indices i.e.:
    #     helpers.compute_size_by_dict(idx_removed, idx_dict)
    cost = _flop_count(idx_contract, idx_removed, len(positions), idx_dict)
    sort = (-removed_size, cost)

    # Sieve based on total cost as well
    if (path_cost + cost) > naive_cost:
        return None

    # Add contraction to possible choices
    return [sort, positions, new_input_sets]


def _update_other_results(results, best):
    """Update the positions and provisional input_sets of ``results`` based on
    performing the contraction result ``best``. Remove any involving the tensors
    contracted.

    Parameters
    ----------
    results : list
        List of contraction results produced by ``_parse_possible_contraction``.
    best : list
        The best contraction of ``results`` i.e. the one that will be performed.

    Returns
    -------
    mod_results : list
        The list of modified results, updated with outcome of ``best`` contraction.

    Notes
    -----
    The ``con_sets`` lists stored inside ``results`` are modified in place.
    """

    best_con = best[1]
    bx, by = best_con
    mod_results = []

    for cost, (x, y), con_sets in results:

        # Ignore results involving tensors just contracted
        if x in best_con or y in best_con:
            continue

        # Update the input_sets; x and y are already absent from con_sets,
        # so the positions of bx/by shift down by one for each of them that
        # sat in front.
        del con_sets[by - int(by > x) - int(by > y)]
        del con_sets[bx - int(bx > x) - int(bx > y)]
        con_sets.insert(-1, best[2][-1])

        # Update the position indices
        mod_con = x - int(x > bx) - int(x > by), y - int(y > bx) - int(y > by)
        mod_results.append((cost, mod_con, con_sets))

    return mod_results


def _greedy_path(input_sets, output_set, idx_dict, memory_limit):
    """
    Finds the path by contracting the best pair until the input list is
    exhausted. The best pair is found by minimizing the tuple
    ``(-prod(indices_removed), cost)``.  What this amounts to is prioritizing
    matrix multiplication or inner product operations, then Hadamard like
    operations, and finally outer operations. Outer products are limited by
    ``memory_limit``. This algorithm scales cubically with respect to the
    number of elements in the list ``input_sets``.

    Parameters
    ----------
    input_sets : list
        List of sets that represent the lhs side of the einsum subscript
    output_set : set
        Set that represents the rhs side of the overall einsum subscript
    idx_dict : dictionary
        Dictionary of index sizes
    memory_limit : int
        The maximum number of elements in a temporary array

    Returns
    -------
    path : list
        The greedy contraction order within the memory limit constraint.

    Examples
    --------
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set()
    >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
    >>> _greedy_path(isets, oset, idx_sizes, 5000)
    [(0, 2), (0, 1)]
    """

    # Handle trivial cases that leaked through
    if len(input_sets) == 1:
        return [(0,)]
    elif len(input_sets) == 2:
        return [(0, 1)]

    # Build up a naive cost
    contract = _find_contraction(range(len(input_sets)), input_sets, output_set)
    idx_result, new_input_sets, idx_removed, idx_contract = contract
    naive_cost = _flop_count(idx_contract, idx_removed, len(input_sets), idx_dict)

    # Initially iterate over all pairs
    comb_iter = itertools.combinations(range(len(input_sets)), 2)
    known_contractions = []

    path_cost = 0
    path = []

    for iteration in range(len(input_sets) - 1):

        # Iterate over all pairs on first step, only previously found pairs
        # on subsequent steps
        for positions in comb_iter:

            # Always initially ignore outer products
            if input_sets[positions[0]].isdisjoint(input_sets[positions[1]]):
                continue

            result = _parse_possible_contraction(positions, input_sets,
                                                 output_set, idx_dict,
                                                 memory_limit, path_cost,
                                                 naive_cost)
            if result is not None:
                known_contractions.append(result)

        # If we do not have a inner contraction, rescan pairs including
        # outer products
        if len(known_contractions) == 0:

            # Then check the outer products
            for positions in itertools.combinations(range(len(input_sets)), 2):
                result = _parse_possible_contraction(positions, input_sets,
                                                     output_set, idx_dict,
                                                     memory_limit, path_cost,
                                                     naive_cost)
                if result is not None:
                    known_contractions.append(result)

            # If we still did not find any remaining contractions, default
            # back to einsum like behavior
            if len(known_contractions) == 0:
                path.append(tuple(range(len(input_sets))))
                break

        # Sort based on first index
        best = min(known_contractions, key=lambda x: x[0])

        # Now propagate as many unused contractions as possible to next
        # iteration
        known_contractions = _update_other_results(known_contractions, best)

        # Next iteration only compute contractions with the new tensor
        # All other contractions have been accounted for
        input_sets = best[2]
        new_tensor_pos = len(input_sets) - 1
        comb_iter = ((i, new_tensor_pos) for i in range(new_tensor_pos))

        # Update path and total cost
        path.append(best[1])
        path_cost += best[0][1]

    return path


def _can_dot(inputs, result, idx_removed):
    """
    Checks if we can use BLAS (np.tensordot) call and its beneficial to do so.

    Parameters
    ----------
    inputs : list of str
        Specifies the subscripts for summation.
    result : str
        Resulting summation.
    idx_removed : set
        Indices that are removed in the summation


    Returns
    -------
    type : bool
        Returns true if BLAS should and can be used, else False

    Notes
    -----
    If the operations is BLAS level 1 or 2 and is not already aligned
    we default back to einsum as the memory movement to copy is more
    costly than the operation itself.


    Examples
    --------

    # Standard GEMM operation
    >>> _can_dot(['ij', 'jk'], 'ik', set('j'))
    True

    # Can use the standard BLAS, but requires odd data movement
    >>> _can_dot(['ijj', 'jk'], 'ik', set('j'))
    False

    # DDOT where the memory is not aligned
    >>> _can_dot(['ijk', 'ikj'], '', set('ijk'))
    False

    """

    # All `dot` calls remove indices
    if len(idx_removed) == 0:
        return False

    # BLAS can only handle two operands
    if len(inputs) != 2:
        return False

    input_left, input_right = inputs

    for c in set(input_left + input_right):
        # can't deal with repeated indices on same input or more than 2 total
        nl, nr = input_left.count(c), input_right.count(c)
        if (nl > 1) or (nr > 1) or (nl + nr > 2):
            return False

        # can't do implicit summation or dimension collapse e.g.
        #     "ab,bc->c" (implicitly sum over 'a')
        #     "ab,ca->ca" (take diagonal of 'a')
        if nl + nr - 1 == int(c in result):
            return False

    # Build a few temporaries
    set_left = set(input_left)
    set_right = set(input_right)
    keep_left = set_left - idx_removed
    keep_right = set_right - idx_removed
    rs = len(idx_removed)

    # At this point we are a DOT, GEMV, or GEMM operation

    # Handle inner products

    # DDOT with aligned data
    if input_left == input_right:
        return True

    # DDOT without aligned data (better to use einsum)
    if set_left == set_right:
        return False

    # Handle the 4 possible (aligned) GEMV or GEMM cases

    # GEMM or GEMV no transpose
    if input_left[-rs:] == input_right[:rs]:
        return True

    # GEMM or GEMV transpose both
    if input_left[:rs] == input_right[-rs:]:
        return True

    # GEMM or GEMV transpose right
    if input_left[-rs:] == input_right[-rs:]:
        return True

    # GEMM or GEMV transpose left
    if input_left[:rs] == input_right[:rs]:
        return True

    # Einsum is faster than GEMV if we have to copy data
    if not keep_left or not keep_right:
        return False

    # We are a matrix-matrix product, but we need to copy data
    return True


def _sublist_to_subscripts(sublist):
    """Translate one integer/Ellipsis subscript list into a label string.

    Each integer is mapped to the label at that position of
    ``einsum_symbols`` and each ``Ellipsis`` to ``"..."``.  Raises
    ``TypeError`` for entries that are neither integers nor ``Ellipsis``.
    """
    labels = []
    for s in sublist:
        if s is Ellipsis:
            labels.append("...")
        else:
            try:
                s = operator.index(s)
            except TypeError as e:
                raise TypeError("For this input type lists must contain "
                                "either int or Ellipsis") from e
            labels.append(einsum_symbols[s])
    return "".join(labels)


def _parse_einsum_input(operands):
    """
    A reproduction of einsum c side einsum parsing in python.

    Parameters
    ----------
    operands : sequence
        Either ``(subscripts_str, op0, op1, ...)`` or the interleaved form
        ``(op0, sublist0, op1, sublist1, ..., [sublistout])``.

    Returns
    -------
    input_strings : str
        Parsed input strings
    output_string : str
        Parsed output string
    operands : list of array_like
        The operands to use in the numpy contraction

    Raises
    ------
    ValueError
        If no operands are given, or the subscripts are malformed (invalid
        characters, more than one ``->``, bad ellipses, output labels absent
        from the inputs, or a term/operand count mismatch).
    TypeError
        If a subscript list contains something other than ints or Ellipsis.

    Examples
    --------
    The operand list is simplified to reduce printing:

    >>> np.random.seed(123)
    >>> a = np.random.rand(4, 4)
    >>> b = np.random.rand(4, 4, 4)
    >>> _parse_einsum_input(('...a,...a->...', a, b))
    ('za,xza', 'xz', [a, b]) # may vary

    >>> _parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0]))
    ('za,xza', 'xz', [a, b]) # may vary
    """

    if len(operands) == 0:
        raise ValueError("No input operands")

    if isinstance(operands[0], str):
        subscripts = operands[0].replace(" ", "")
        operands = [asanyarray(v) for v in operands[1:]]

        # Ensure all characters are valid
        for s in subscripts:
            if s in '.,->':
                continue
            if s not in einsum_symbols:
                raise ValueError("Character %s is not a valid symbol." % s)

    else:
        # Interleaved form: operand, sublist, operand, sublist, ..., [output]
        tmp_operands = list(operands)
        operand_list = []
        subscript_list = []
        for p in range(len(operands) // 2):
            operand_list.append(tmp_operands.pop(0))
            subscript_list.append(tmp_operands.pop(0))

        # A single leftover entry is the explicit output sublist
        output_list = tmp_operands[-1] if len(tmp_operands) else None
        operands = [asanyarray(v) for v in operand_list]
        subscripts = ",".join(_sublist_to_subscripts(sub)
                              for sub in subscript_list)

        if output_list is not None:
            subscripts += "->" + _sublist_to_subscripts(output_list)

    # Check for proper "->"
    if ("-" in subscripts) or (">" in subscripts):
        invalid = (subscripts.count("-") > 1) or (subscripts.count(">") > 1)
        if invalid or (subscripts.count("->") != 1):
            raise ValueError("Subscripts can only contain one '->'.")

    # Parse ellipses
    if "." in subscripts:
        used = subscripts.replace(".", "").replace(",", "").replace("->", "")
        # Labels not used explicitly are available to stand in for "..."
        unused = list(einsum_symbols_set - set(used))
        ellipse_inds = "".join(unused)
        longest = 0

        if "->" in subscripts:
            input_tmp, output_sub = subscripts.split("->")
            split_subscripts = input_tmp.split(",")
            out_sub = True
        else:
            split_subscripts = subscripts.split(',')
            out_sub = False

        for num, sub in enumerate(split_subscripts):
            if "." in sub:
                if (sub.count(".") != 3) or (sub.count("...") != 1):
                    raise ValueError("Invalid Ellipses.")

                # Take into account numerical values
                if operands[num].shape == ():
                    ellipse_count = 0
                else:
                    ellipse_count = max(operands[num].ndim, 1)
                    ellipse_count -= (len(sub) - 3)

                if ellipse_count > longest:
                    longest = ellipse_count

                if ellipse_count < 0:
                    raise ValueError("Ellipses lengths do not match.")
                elif ellipse_count == 0:
                    split_subscripts[num] = sub.replace('...', '')
                else:
                    rep_inds = ellipse_inds[-ellipse_count:]
                    split_subscripts[num] = sub.replace('...', rep_inds)

        subscripts = ",".join(split_subscripts)
        if longest == 0:
            out_ellipse = ""
        else:
            out_ellipse = ellipse_inds[-longest:]

        if out_sub:
            subscripts += "->" + output_sub.replace("...", out_ellipse)
        else:
            # Special care for outputless ellipses
            output_subscript = ""
            tmp_subscripts = subscripts.replace(",", "")
            for s in sorted(set(tmp_subscripts)):
                if s not in (einsum_symbols):
                    raise ValueError("Character %s is not a valid symbol." % s)
                if tmp_subscripts.count(s) == 1:
                    output_subscript += s
            normal_inds = ''.join(sorted(set(output_subscript) -
                                         set(out_ellipse)))

            subscripts += "->" + out_ellipse + normal_inds

    # Build output string if does not exist
    if "->" in subscripts:
        input_subscripts, output_subscript = subscripts.split("->")
    else:
        input_subscripts = subscripts
        # Build output subscripts
        tmp_subscripts = subscripts.replace(",", "")
        output_subscript = ""
        for s in sorted(set(tmp_subscripts)):
            if s not in einsum_symbols:
                raise ValueError("Character %s is not a valid symbol." % s)
            if tmp_subscripts.count(s) == 1:
                output_subscript += s

    # Make sure output subscripts are in the input
    for char in output_subscript:
        if char not in input_subscripts:
            raise ValueError("Output character %s did not appear in the input"
                             % char)

    # Make sure number operands is equivalent to the number of terms
    if len(input_subscripts.split(',')) != len(operands):
        raise ValueError("Number of einsum subscripts must be equal to the "
                         "number of operands.")

    return (input_subscripts, output_subscript, operands)


def _einsum_path_dispatcher(*operands, optimize=None, einsum_call=None):
    # NOTE: technically, we should only dispatch on array-like arguments, not
    # subscripts (given as strings). But separating operands into
    # arrays/subscripts is a little tricky/slow (given einsum's two supported
    # signatures), so as a practical shortcut we dispatch on everything.
    # Strings will be ignored for dispatching since they don't define
    # __array_function__.
    return operands


@array_function_dispatch(_einsum_path_dispatcher, module='numpy')
def einsum_path(*operands, optimize='greedy', einsum_call=False):
    """
    einsum_path(subscripts, *operands, optimize='greedy')

    Evaluates the lowest cost contraction order for an einsum expression by
    considering the creation of intermediate arrays.

    Parameters
    ----------
    subscripts : str
        Specifies the subscripts for summation.
    *operands : list of array_like
        These are the arrays for the operation.
    optimize : {bool, list, tuple, 'greedy', 'optimal'}
        Choose the type of path. If a tuple is provided, the second argument is
        assumed to be the maximum intermediate size created. If only a single
        argument is provided the largest input or output array size is used
        as a maximum intermediate size.

        * if a list is given that starts with ``einsum_path``, uses this as the
          contraction path
        * if False no optimization is taken
        * if True defaults to the 'greedy' algorithm
        * 'optimal' An algorithm that combinatorially explores all possible
          ways of contracting the listed tensors and chooses the least costly
          path. Scales exponentially with the number of terms in the
          contraction.
        * 'greedy' An algorithm that chooses the best pair contraction
          at each step. Effectively, this algorithm searches the largest inner,
          Hadamard, and then outer products at each step. Scales cubically with
          the number of terms in the contraction. Equivalent to the 'optimal'
          path for most contractions.

        Default is 'greedy'.

    Returns
    -------
    path : list of tuples
        A list representation of the einsum path.
    string_repr : str
        A printable representation of the einsum path.

    Raises
    ------
    TypeError
        If ``optimize`` is not one of the understood forms.
    ValueError
        If an operand's number of dimensions does not match its subscript,
        or if two operands disagree on the size of a shared label.
    KeyError
        If ``optimize`` names an unknown path algorithm.

    Notes
    -----
    The resulting path indicates which terms of the input contraction should be
    contracted first, the result of this contraction is then appended to the
    end of the contraction list. This list can then be iterated over until all
    intermediate contractions are complete.

    See Also
    --------
    einsum, linalg.multi_dot

    Examples
    --------

    We can begin with a chain dot example. In this case, it is optimal to
    contract the ``b`` and ``c`` tensors first as represented by the first
    element of the path ``(1, 2)``. The resulting tensor is added to the end
    of the contraction and the remaining contraction ``(0, 1)`` is then
    completed.

    >>> np.random.seed(123)
    >>> a = np.random.rand(2, 2)
    >>> b = np.random.rand(2, 5)
    >>> c = np.random.rand(5, 2)
    >>> path_info = np.einsum_path('ij,jk,kl->il', a, b, c, optimize='greedy')
    >>> print(path_info[0])
    ['einsum_path', (1, 2), (0, 1)]
    >>> print(path_info[1])
      Complete contraction:  ij,jk,kl->il # may vary
             Naive scaling:  4
         Optimized scaling:  3
          Naive FLOP count:  1.600e+02
      Optimized FLOP count:  5.600e+01
       Theoretical speedup:  2.857
      Largest intermediate:  4.000e+00 elements
    -------------------------------------------------------------------------
    scaling                  current                                remaining
    -------------------------------------------------------------------------
       3                   kl,jk->jl                                ij,jl->il
       3                   jl,ij->il                                   il->il


    A more complex index transformation example.

    >>> I = np.random.rand(10, 10, 10, 10)
    >>> C = np.random.rand(10, 10)
    >>> path_info = np.einsum_path('ea,fb,abcd,gc,hd->efgh', C, C, I, C, C,
    ...                            optimize='greedy')

    >>> print(path_info[0])
    ['einsum_path', (0, 2), (0, 3), (0, 2), (0, 1)]
    >>> print(path_info[1])
      Complete contraction:  ea,fb,abcd,gc,hd->efgh # may vary
             Naive scaling:  8
         Optimized scaling:  5
          Naive FLOP count:  8.000e+08
      Optimized FLOP count:  8.000e+05
       Theoretical speedup:  1000.000
      Largest intermediate:  1.000e+04 elements
    --------------------------------------------------------------------------
    scaling                  current                                remaining
    --------------------------------------------------------------------------
       5               abcd,ea->bcde                      fb,gc,hd,bcde->efgh
       5               bcde,fb->cdef                         gc,hd,cdef->efgh
       5               cdef,gc->defg                            hd,defg->efgh
       5               defg,hd->efgh                               efgh->efgh
    """

    # Figure out what the path really is
    path_type = optimize
    if path_type is True:
        path_type = 'greedy'
    if path_type is None:
        path_type = False

    memory_limit = None

    # No optimization or a named path algorithm
    if (path_type is False) or isinstance(path_type, str):
        pass

    # Given an explicit path
    elif len(path_type) and (path_type[0] == 'einsum_path'):
        pass

    # Path tuple with memory limit
    elif ((len(path_type) == 2) and isinstance(path_type[0], str) and
            isinstance(path_type[1], (int, float))):
        memory_limit = int(path_type[1])
        path_type = path_type[0]

    else:
        raise TypeError("Did not understand the path: %s" % str(path_type))

    # Hidden option, only einsum should call this
    einsum_call_arg = einsum_call

    # Python side parsing
    input_subscripts, output_subscript, operands = _parse_einsum_input(operands)

    # Build a few useful list and sets
    input_list = input_subscripts.split(',')
    input_sets = [set(x) for x in input_list]
    output_set = set(output_subscript)
    indices = set(input_subscripts.replace(',', ''))

    # Get length of each unique dimension and ensure all dimensions are correct
    dimension_dict = {}
    broadcast_indices = [[] for x in range(len(input_list))]
    for tnum, term in enumerate(input_list):
        sh = operands[tnum].shape
        if len(sh) != len(term):
            # Report the offending term itself (not a single character of the
            # comma-joined subscript string).
            raise ValueError("Einstein sum subscript %s does not contain the "
                             "correct number of indices for operand %d."
                             % (input_list[tnum], tnum))
        for cnum, char in enumerate(term):
            dim = sh[cnum]

            # Build out broadcast indices
            if dim == 1:
                broadcast_indices[tnum].append(char)

            if char in dimension_dict:
                # For broadcasting cases we always want the largest dim size
                if dimension_dict[char] == 1:
                    dimension_dict[char] = dim
                elif dim not in (1, dimension_dict[char]):
                    # First size is this operand's, second is the one
                    # established by earlier operands.
                    raise ValueError("Size of label '%s' for operand %d (%d) "
                                     "does not match previous terms (%d)."
                                     % (char, tnum, dim, dimension_dict[char]))
            else:
                dimension_dict[char] = dim

    # Convert broadcast inds to sets
    broadcast_indices = [set(x) for x in broadcast_indices]

    # Compute size of each input array plus the output array
    size_list = [_compute_size_by_dict(term, dimension_dict)
                 for term in input_list + [output_subscript]]
    max_size = max(size_list)

    if memory_limit is None:
        memory_arg = max_size
    else:
        memory_arg = memory_limit

    # Compute naive cost
    # This isn't quite right, need to look into exactly how einsum does this
    inner_product = (sum(len(x) for x in input_sets) - len(indices)) > 0
    naive_cost = _flop_count(indices, inner_product, len(input_list), dimension_dict)

    # Compute the path
    if (path_type is False) or (len(input_list) in [1, 2]) or (indices == output_set):
        # Nothing to be optimized, leave it to einsum
        path = [tuple(range(len(input_list)))]
    elif path_type == "greedy":
        path = _greedy_path(input_sets, output_set, dimension_dict, memory_arg)
    elif path_type == "optimal":
        path = _optimal_path(input_sets, output_set, dimension_dict, memory_arg)
    elif path_type[0] == 'einsum_path':
        # Normalise to a list so a path supplied as a tuple can be prefixed
        # with 'einsum_path' again below.
        path = list(path_type[1:])
    else:
        raise KeyError("Path name %s not found" % path_type)

    cost_list, scale_list, size_list, contraction_list = [], [], [], []

    # Build contraction tuple (positions, gemm, einsum_str, remaining)
    for cnum, contract_inds in enumerate(path):
        # Make sure we remove inds from right to left
        contract_inds = tuple(sorted(contract_inds, reverse=True))

        contract = _find_contraction(contract_inds, input_sets, output_set)
        out_inds, input_sets, idx_removed, idx_contract = contract

        cost = _flop_count(idx_contract, idx_removed, len(contract_inds), dimension_dict)
        cost_list.append(cost)
        scale_list.append(len(idx_contract))
        size_list.append(_compute_size_by_dict(out_inds, dimension_dict))

        bcast = set()
        tmp_inputs = []
        for x in contract_inds:
            tmp_inputs.append(input_list.pop(x))
            bcast |= broadcast_indices.pop(x)

        new_bcast_inds = bcast - idx_removed

        # If we're broadcasting, nix blas
        if not len(idx_removed & bcast):
            do_blas = _can_dot(tmp_inputs, out_inds, idx_removed)
        else:
            do_blas = False

        # Last contraction
        if (cnum - len(path)) == -1:
            idx_result = output_subscript
        else:
            sort_result = [(dimension_dict[ind], ind) for ind in out_inds]
            idx_result = "".join([x[1] for x in sorted(sort_result)])

        input_list.append(idx_result)
        broadcast_indices.append(new_bcast_inds)
        einsum_str = ",".join(tmp_inputs) + "->" + idx_result

        contraction = (contract_inds, idx_removed, einsum_str, input_list[:], do_blas)
        contraction_list.append(contraction)

    opt_cost = sum(cost_list) + 1

    if einsum_call_arg:
        return (operands, contraction_list)

    # Return the path along with a nice string representation
    overall_contraction = input_subscripts + "->" + output_subscript
    header = ("scaling", "current", "remaining")

    speedup = naive_cost / opt_cost
    max_i = max(size_list)

    path_print = "  Complete contraction:  %s\n" % overall_contraction
    path_print += "         Naive scaling:  %d\n" % len(indices)
    path_print += "     Optimized scaling:  %d\n" % max(scale_list)
    path_print += "      Naive FLOP count:  %.3e\n" % naive_cost
    path_print += "  Optimized FLOP count:  %.3e\n" % opt_cost
    path_print += "   Theoretical speedup:  %3.3f\n" % speedup
    path_print += "  Largest intermediate:  %.3e elements\n" % max_i
    path_print += "-" * 74 + "\n"
    path_print += "%6s %24s %40s\n" % header
    path_print += "-" * 74

    for n, contraction in enumerate(contraction_list):
        inds, idx_rm, einsum_str, remaining, blas = contraction
        remaining_str = ",".join(remaining) + "->" + output_subscript
        path_run = (scale_list[n], einsum_str, remaining_str)
        path_print += "\n%4d    %24s %40s" % path_run

    path = ['einsum_path'] + list(path)
    return (path, path_print)
def _einsum_dispatcher(*operands, out=None, optimize=None, **kwargs):
    # Arguably we dispatch on more arguments than we really should; see note
    # in _einsum_path_dispatcher for why.
    yield from (*operands, out)


# Python-level wrapper around c_einsum that adds the ``optimize`` machinery
@array_function_dispatch(_einsum_dispatcher, module='numpy')
def einsum(*operands, out=None, optimize=False, **kwargs):
    """
    einsum(subscripts, *operands, out=None, dtype=None, order='K',
           casting='safe', optimize=False)

    Evaluates the Einstein summation convention on the operands.

    Using the Einstein summation convention, many common multi-dimensional,
    linear algebraic array operations can be represented in a simple fashion.
    In *implicit* mode `einsum` computes these values.

    In *explicit* mode, `einsum` provides further flexibility to compute
    other array operations that might not be considered classical Einstein
    summation operations, by disabling, or forcing summation over specified
    subscript labels.

    See the notes and examples for clarification.

    Parameters
    ----------
    subscripts : str
        Specifies the subscripts for summation as comma separated list of
        subscript labels. An implicit (classical Einstein summation)
        calculation is performed unless the explicit indicator '->' is
        included as well as subscript labels of the precise output form.
    operands : list of array_like
        These are the arrays for the operation.
    out : ndarray, optional
        If provided, the calculation is done into this array.
    dtype : {data-type, None}, optional
        If provided, forces the calculation to use the data type specified.
        Note that you may have to also give a more liberal `casting`
        parameter to allow the conversions. Default is None.
    order : {'C', 'F', 'A', 'K'}, optional
        Controls the memory layout of the output. 'C' means it should
        be C contiguous. 'F' means it should be Fortran contiguous,
        'A' means it should be 'F' if the inputs are all 'F', 'C' otherwise.
        'K' means it should be as close to the layout of the inputs as
        is possible, including arbitrarily permuted axes.
        Default is 'K'.
    casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
        Controls what kind of data casting may occur.  Setting this to
        'unsafe' is not recommended, as it can adversely affect accumulations.

          * 'no' means the data types should not be cast at all.
          * 'equiv' means only byte-order changes are allowed.
          * 'safe' means only casts which can preserve values are allowed.
          * 'same_kind' means only safe casts or casts within a kind,
            like float64 to float32, are allowed.
          * 'unsafe' means any data conversions may be done.

        Default is 'safe'.
    optimize : {False, True, 'greedy', 'optimal'}, optional
        Controls if intermediate optimization should occur. No optimization
        will occur if False and True will default to the 'greedy' algorithm.
        Also accepts an explicit contraction list from the ``np.einsum_path``
        function. See ``np.einsum_path`` for more details. Defaults to False.

    Returns
    -------
    output : ndarray
        The calculation based on the Einstein summation convention.

    See Also
    --------
    einsum_path, dot, inner, outer, tensordot, linalg.multi_dot

    einops:
        similar verbose interface is provided by
        `einops <https://github.com/arogozhnikov/einops>`_ package to cover
        additional operations: transpose, reshape/flatten, repeat/tile,
        squeeze/unsqueeze and reductions.

    opt_einsum:
        `opt_einsum <https://optimized-einsum.readthedocs.io/en/stable/>`_
        optimizes contraction order for einsum-like expressions
        in backend-agnostic manner.

    Notes
    -----
    .. versionadded:: 1.6.0

    The Einstein summation convention can be used to compute
    many multi-dimensional, linear algebraic array operations. `einsum`
    provides a succinct way of representing these.

    A non-exhaustive list of these operations,
    which can be computed by `einsum`, is shown below along with examples:

    * Trace of an array, :py:func:`numpy.trace`.
    * Return a diagonal, :py:func:`numpy.diag`.
    * Array axis summations, :py:func:`numpy.sum`.
    * Transpositions and permutations, :py:func:`numpy.transpose`.
    * Matrix multiplication and dot product, :py:func:`numpy.matmul` :py:func:`numpy.dot`.
    * Vector inner and outer products, :py:func:`numpy.inner` :py:func:`numpy.outer`.
    * Broadcasting, element-wise and scalar multiplication, :py:func:`numpy.multiply`.
    * Tensor contractions, :py:func:`numpy.tensordot`.
    * Chained array operations, in efficient calculation order, :py:func:`numpy.einsum_path`.

    The subscripts string is a comma-separated list of subscript labels,
    where each label refers to a dimension of the corresponding operand.
    Whenever a label is repeated it is summed, so ``np.einsum('i,i', a, b)``
    is equivalent to :py:func:`np.inner(a,b) <numpy.inner>`. If a label
    appears only once, it is not summed, so ``np.einsum('i', a)`` produces a
    view of ``a`` with no changes. A further example ``np.einsum('ij,jk', a, b)``
    describes traditional matrix multiplication and is equivalent to
    :py:func:`np.matmul(a,b) <numpy.matmul>`. Repeated subscript labels in one
    operand take the diagonal. For example, ``np.einsum('ii', a)`` is equivalent
    to :py:func:`np.trace(a) <numpy.trace>`.

    In *implicit mode*, the chosen subscripts are important
    since the axes of the output are reordered alphabetically.  This
    means that ``np.einsum('ij', a)`` doesn't affect a 2D array, while
    ``np.einsum('ji', a)`` takes its transpose. Additionally,
    ``np.einsum('ij,jk', a, b)`` returns a matrix multiplication, while,
    ``np.einsum('ij,jh', a, b)`` returns the transpose of the
    multiplication since subscript 'h' precedes subscript 'i'.

    In *explicit mode* the output can be directly controlled by
    specifying output subscript labels.  This requires the
    identifier '->' as well as the list of output subscript labels.
    This feature increases the flexibility of the function since
    summing can be disabled or forced when required. The call
    ``np.einsum('i->', a)`` is like :py:func:`np.sum(a, axis=-1) <numpy.sum>`,
    and ``np.einsum('ii->i', a)`` is like :py:func:`np.diag(a) <numpy.diag>`.
    The difference is that `einsum` does not allow broadcasting by default.
    Additionally ``np.einsum('ij,jh->ih', a, b)`` directly specifies the
    order of the output subscript labels and therefore returns matrix
    multiplication, unlike the example above in implicit mode.

    To enable and control broadcasting, use an ellipsis.  Default
    NumPy-style broadcasting is done by adding an ellipsis
    to the left of each term, like ``np.einsum('...ii->...i', a)``.
    To take the trace along the first and last axes,
    you can do ``np.einsum('i...i', a)``, or to do a matrix-matrix
    product with the left-most indices instead of rightmost, one can do
    ``np.einsum('ij...,jk...->ik...', a, b)``.

    When there is only one operand, no axes are summed, and no output
    parameter is provided, a view into the operand is returned instead
    of a new array.  Thus, taking the diagonal as ``np.einsum('ii->i', a)``
    produces a view (changed in version 1.10.0).

    `einsum` also provides an alternative way to provide the subscripts
    and operands as ``einsum(op0, sublist0, op1, sublist1, ..., [sublistout])``.
    If the output shape is not provided in this format `einsum` will be
    calculated in implicit mode, otherwise it will be performed explicitly.
    The examples below have corresponding `einsum` calls with the two
    parameter methods.

    .. versionadded:: 1.10.0

    Views returned from einsum are now writeable whenever the input array
    is writeable. For example, ``np.einsum('ijk...->kji...', a)`` will now
    have the same effect as :py:func:`np.swapaxes(a, 0, 2) <numpy.swapaxes>`
    and ``np.einsum('ii->i', a)`` will return a writeable view of the diagonal
    of a 2D array.

    .. versionadded:: 1.12.0

    Added the ``optimize`` argument which will optimize the contraction order
    of an einsum expression. For a contraction with three or more operands this
    can greatly increase the computational efficiency at the cost of a larger
    memory footprint during computation.

    Typically a 'greedy' algorithm is applied which empirical tests have shown
    returns the optimal path in the majority of cases. In some cases 'optimal'
    will return the superlative path through a more expensive, exhaustive search.
    For iterative calculations it may be advisable to calculate the optimal path
    once and reuse that path by supplying it as an argument. An example is given
    below.

    See :py:func:`numpy.einsum_path` for more details.

    Examples
    --------
    >>> a = np.arange(25).reshape(5,5)
    >>> b = np.arange(5)
    >>> c = np.arange(6).reshape(2,3)

    Trace of a matrix:

    >>> np.einsum('ii', a)
    60
    >>> np.einsum(a, [0,0])
    60
    >>> np.trace(a)
    60

    Extract the diagonal (requires explicit form):

    >>> np.einsum('ii->i', a)
    array([ 0,  6, 12, 18, 24])
    >>> np.einsum(a, [0,0], [0])
    array([ 0,  6, 12, 18, 24])
    >>> np.diag(a)
    array([ 0,  6, 12, 18, 24])

    Sum over an axis (requires explicit form):

    >>> np.einsum('ij->i', a)
    array([ 10,  35,  60,  85, 110])
    >>> np.einsum(a, [0,1], [0])
    array([ 10,  35,  60,  85, 110])
    >>> np.sum(a, axis=1)
    array([ 10,  35,  60,  85, 110])

    For higher dimensional arrays summing a single axis can be done with ellipsis:

    >>> np.einsum('...j->...', a)
    array([ 10,  35,  60,  85, 110])
    >>> np.einsum(a, [Ellipsis,1], [Ellipsis])
    array([ 10,  35,  60,  85, 110])

    Compute a matrix transpose, or reorder any number of axes:

    >>> np.einsum('ji', c)
    array([[0, 3],
           [1, 4],
           [2, 5]])
    >>> np.einsum('ij->ji', c)
    array([[0, 3],
           [1, 4],
           [2, 5]])
    >>> np.einsum(c, [1,0])
    array([[0, 3],
           [1, 4],
           [2, 5]])
    >>> np.transpose(c)
    array([[0, 3],
           [1, 4],
           [2, 5]])

    Vector inner products:

    >>> np.einsum('i,i', b, b)
    30
    >>> np.einsum(b, [0], b, [0])
    30
    >>> np.inner(b,b)
    30

    Matrix vector multiplication:

    >>> np.einsum('ij,j', a, b)
    array([ 30,  80, 130, 180, 230])
    >>> np.einsum(a, [0,1], b, [1])
    array([ 30,  80, 130, 180, 230])
    >>> np.dot(a, b)
    array([ 30,  80, 130, 180, 230])
    >>> np.einsum('...j,j', a, b)
    array([ 30,  80, 130, 180, 230])

    Broadcasting and scalar multiplication:

    >>> np.einsum('..., ...', 3, c)
    array([[ 0,  3,  6],
           [ 9, 12, 15]])
    >>> np.einsum(',ij', 3, c)
    array([[ 0,  3,  6],
           [ 9, 12, 15]])
    >>> np.einsum(3, [Ellipsis], c, [Ellipsis])
    array([[ 0,  3,  6],
           [ 9, 12, 15]])
    >>> np.multiply(3, c)
    array([[ 0,  3,  6],
           [ 9, 12, 15]])

    Vector outer product:

    >>> np.einsum('i,j', np.arange(2)+1, b)
    array([[0, 1, 2, 3, 4],
           [0, 2, 4, 6, 8]])
    >>> np.einsum(np.arange(2)+1, [0], b, [1])
    array([[0, 1, 2, 3, 4],
           [0, 2, 4, 6, 8]])
    >>> np.outer(np.arange(2)+1, b)
    array([[0, 1, 2, 3, 4],
           [0, 2, 4, 6, 8]])

    Tensor contraction:

    >>> a = np.arange(60.).reshape(3,4,5)
    >>> b = np.arange(24.).reshape(4,3,2)
    >>> np.einsum('ijk,jil->kl', a, b)
    array([[4400., 4730.],
           [4532., 4874.],
           [4664., 5018.],
           [4796., 5162.],
           [4928., 5306.]])
    >>> np.einsum(a, [0,1,2], b, [1,0,3], [2,3])
    array([[4400., 4730.],
           [4532., 4874.],
           [4664., 5018.],
           [4796., 5162.],
           [4928., 5306.]])
    >>> np.tensordot(a,b, axes=([1,0],[0,1]))
    array([[4400., 4730.],
           [4532., 4874.],
           [4664., 5018.],
           [4796., 5162.],
           [4928., 5306.]])

    Writeable returned arrays (since version 1.10.0):

    >>> a = np.zeros((3, 3))
    >>> np.einsum('ii->i', a)[:] = 1
    >>> a
    array([[1., 0., 0.],
           [0., 1., 0.],
           [0., 0., 1.]])

    Example of ellipsis use:

    >>> a = np.arange(6).reshape((3,2))
    >>> b = np.arange(12).reshape((4,3))
    >>> np.einsum('ki,jk->ij', a, b)
    array([[10, 28, 46, 64],
           [13, 40, 67, 94]])
    >>> np.einsum('ki,...k->i...', a, b)
    array([[10, 28, 46, 64],
           [13, 40, 67, 94]])
    >>> np.einsum('k...,jk', a, b)
    array([[10, 28, 46, 64],
           [13, 40, 67, 94]])

    Chained array operations. For more complicated contractions, speed ups
    might be achieved by repeatedly computing a 'greedy' path or pre-computing the
    'optimal' path and repeatedly applying it, using an
    `einsum_path` insertion (since version 1.12.0). Performance improvements can be
    particularly significant with larger arrays:

    >>> a = np.ones(64).reshape(2,4,8)

    Basic `einsum`: ~1520ms  (benchmarked on 3.1GHz Intel i5.)

    >>> for iteration in range(500):
    ...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a)

    Sub-optimal `einsum` (due to repeated path calculation time): ~330ms

    >>> for iteration in range(500):
    ...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='optimal')

    Greedy `einsum` (faster optimal path approximation): ~160ms

    >>> for iteration in range(500):
    ...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='greedy')

    Optimal `einsum` (best usage pattern in some use cases): ~110ms

    >>> path = np.einsum_path('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='optimal')[0]
    >>> for iteration in range(500):
    ...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize=path)

    """
    # ``out`` needs special treatment below, so remember whether it was given
    out_given = out is not None

    # Without optimization the call is forwarded unchanged to the C routine
    if optimize is False:
        if out_given:
            kwargs['out'] = out
        return c_einsum(*operands, **kwargs)

    # Reject unexpected keywords up front to avoid a more cryptic error
    # later, without having to repeat default values here
    unknown_kwargs = [key for key in kwargs
                      if key not in ('dtype', 'order', 'casting')]
    if unknown_kwargs:
        raise TypeError("Did not understand the following kwargs: %s"
                        % unknown_kwargs)

    # Let einsum_path normalise the operands and plan the contractions
    operands, contraction_list = einsum_path(*operands, optimize=optimize,
                                             einsum_call=True)

    # The requested order only applies to the final array, so take it out of
    # the kwargs passed to intermediate calls; c_einsum allows mixed case
    output_order = kwargs.pop('order', 'K')
    if output_order.upper() == 'A':
        all_fortran = all(arr.flags.f_contiguous for arr in operands)
        output_order = 'F' if all_fortran else 'C'

    # Walk through the planned contractions one at a time
    final_step = len(contraction_list) - 1
    for step, (inds, idx_rm, einsum_str, _remaining, blas) in enumerate(
            contraction_list):
        tmp_operands = [operands.pop(x) for x in inds]

        # Only the very last contraction may write into the caller's ``out``
        handle_out = out_given and step == final_step
        if handle_out:
            kwargs["out"] = out

        if blas:
            # Eligible for tensordot; checks have already been handled
            input_str, results_index = einsum_str.split('->')
            input_left, input_right = input_str.split(',')

            # Labels that survive tensordot, in the order it emits them
            tensor_result = input_left + input_right
            for label in idx_rm:
                tensor_result = tensor_result.replace(label, "")

            # Axis positions of the summed labels in each operand
            summed = sorted(idx_rm)
            left_axes = tuple(input_left.find(label) for label in summed)
            right_axes = tuple(input_right.find(label) for label in summed)

            new_view = tensordot(*tmp_operands, axes=(left_axes, right_axes))

            # Re-run through c_einsum when the axes must be permuted or the
            # result has to land in ``out``
            if handle_out or tensor_result != results_index:
                new_view = c_einsum(tensor_result + '->' + results_index,
                                    new_view, **kwargs)
        else:
            # General case: a plain c_einsum call performs the contraction
            new_view = c_einsum(einsum_str, *tmp_operands, **kwargs)

        # Append new items and dereference what we can
        operands.append(new_view)
        del tmp_operands, new_view

    if out_given:
        return out
    return asanyarray(operands[0], order=output_order)