1"""
2Implementation of optimized einsum.
3
4"""
5import itertools
6import operator
7
8from numpy.core.multiarray import c_einsum
9from numpy.core.numeric import asanyarray, tensordot
10from numpy.core.overrides import array_function_dispatch
11
# Names exported as the public API of this module.
__all__ = ['einsum', 'einsum_path']

# Characters accepted as subscript labels: the 26 lowercase ASCII letters
# followed by the 26 uppercase ones. In the sublist calling format an integer
# label is translated to the letter at that position of this string.
einsum_symbols = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
# The same labels as a set, used to find labels not yet taken by a subscript.
einsum_symbols_set = set(einsum_symbols)
16
17
def _flop_count(idx_contraction, inner, num_terms, size_dictionary):
    """
    Estimate how many FLOPS a single contraction needs.

    Parameters
    ----------
    idx_contraction : iterable
        Every index that takes part in the contraction.
    inner : bool
        Truthy when the contraction sums over at least one index.
    num_terms : int
        How many operands are contracted together.
    size_dictionary : dict
        Maps each index found in ``idx_contraction`` to its length.

    Returns
    -------
    flop_count : int
        Estimated number of FLOPS needed for the contraction.

    Examples
    --------

    >>> _flop_count('abc', False, 1, {'a': 2, 'b':3, 'c':5})
    30

    >>> _flop_count('abc', True, 2, {'a': 2, 'b':3, 'c':5})
    60

    """
    # One operation per extra operand (never fewer than one), plus one more
    # when a summation has to be carried out as well.
    multiplier = max(1, num_terms - 1) + (1 if inner else 0)
    return _compute_size_by_dict(idx_contraction, size_dictionary) * multiplier
55
def _compute_size_by_dict(indices, idx_dict):
    """
    Multiply together the sizes of all the given indices.

    Parameters
    ----------
    indices : iterable
        Index labels to look up; a repeated label is counted every time it
        occurs.
    idx_dict : dictionary
        Maps each index label to its size.

    Returns
    -------
    ret : int
        Product of ``idx_dict[i]`` over every ``i`` in ``indices``; ``1`` when
        ``indices`` is empty.

    Examples
    --------
    >>> _compute_size_by_dict('abbc', {'a': 2, 'b':3, 'c':5})
    90

    """
    total = 1
    for label in indices:
        total = total * idx_dict[label]
    return total
83
84
def _find_contraction(positions, input_sets, output_set):
    """
    Work out what contracting the terms at ``positions`` produces.

    Parameters
    ----------
    positions : iterable
        Integer positions of the terms being contracted.
    input_sets : list
        List of sets that represent the lhs side of the einsum subscript
    output_set : set
        Set that represents the rhs side of the overall einsum subscript

    Returns
    -------
    new_result : set
        Indices carried by the tensor that the contraction produces.
    remaining : list
        The sets that were not contracted, in their original order, followed
        by ``new_result`` as the last element.
    idx_removed : set
        Indices that disappear from the expression with this contraction.
    idx_contraction : set
        Union of all indices on the contracted terms.

    Examples
    --------

    # A simple dot product test case
    >>> pos = (0, 1)
    >>> isets = [set('ab'), set('bc')]
    >>> oset = set('ac')
    >>> _find_contraction(pos, isets, oset)
    ({'a', 'c'}, [{'a', 'c'}], {'b'}, {'a', 'b', 'c'})

    # A more complex case with additional terms in the contraction
    >>> pos = (0, 2)
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set('ac')
    >>> _find_contraction(pos, isets, oset)
    ({'a', 'c'}, [{'a', 'c'}, {'a', 'c'}], {'b', 'd'}, {'a', 'b', 'c', 'd'})
    """
    # Split the terms into those being contracted and those left untouched.
    contracted, remaining = [], []
    for position, index_set in enumerate(input_sets):
        bucket = contracted if position in positions else remaining
        bucket.append(index_set)

    idx_contract = set().union(*contracted)
    # Indices still needed afterwards: by the output or by an untouched term.
    idx_remain = output_set.union(*remaining)

    new_result = idx_remain & idx_contract
    idx_removed = idx_contract - new_result
    remaining.append(new_result)

    return new_result, remaining, idx_removed, idx_contract
143
144
def _optimal_path(input_sets, output_set, idx_dict, memory_limit):
    """
    Computes all possible pair contractions, sieves the results based
    on ``memory_limit`` and returns the lowest cost path. This algorithm
    scales factorial with respect to the elements in the list ``input_sets``.

    Parameters
    ----------
    input_sets : list
        List of sets that represent the lhs side of the einsum subscript
    output_set : set
        Set that represents the rhs side of the overall einsum subscript
    idx_dict : dictionary
        Dictionary of index sizes
    memory_limit : int
        The maximum number of elements in a temporary array

    Returns
    -------
    path : list
        The optimal contraction order within the memory limit constraint.
        When no pair contraction fits within ``memory_limit`` at some step,
        the cheapest partial path found so far is returned with one final
        entry that contracts all remaining terms at once. A single operand
        yields ``[(0,)]``, the same as ``_greedy_path``.

    Examples
    --------
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set()
    >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
    >>> _optimal_path(isets, oset, idx_sizes, 5000)
    [(0, 2), (0, 1)]
    """
    num_inputs = len(input_sets)

    # A lone operand has no pairs to explore; without this guard the search
    # below would return an empty path instead of a single einsum call.
    if num_inputs == 1:
        return [(0,)]

    # Every candidate is (cost so far, contractions so far, remaining sets).
    full_results = [(0, [], input_sets)]
    for iteration in range(num_inputs - 1):
        # Each completed step replaces two terms by one.
        num_remaining = num_inputs - iteration
        iter_results = []

        # Extend every candidate with every possible pair contraction
        for cost, positions, remaining in full_results:
            for con in itertools.combinations(range(num_remaining), 2):

                # Find the contraction
                cont = _find_contraction(con, remaining, output_set)
                new_result, new_input_sets, idx_removed, idx_contract = cont

                # Sieve the results based on memory_limit
                if _compute_size_by_dict(new_result, idx_dict) > memory_limit:
                    continue

                # Build (total_cost, positions, indices_remaining)
                total_cost = cost + _flop_count(idx_contract, idx_removed,
                                                len(con), idx_dict)
                iter_results.append(
                    (total_cost, positions + [con], new_input_sets))

        # Nothing fits in memory: finish the cheapest path found so far with
        # one contraction over everything that is left. A new list is built
        # so the candidate stored in ``full_results`` is not modified.
        if not iter_results:
            best_path = min(full_results, key=lambda x: x[0])[1]
            return best_path + [tuple(range(num_remaining))]

        full_results = iter_results

    # ``full_results`` is never empty here, so a minimum always exists.
    return min(full_results, key=lambda x: x[0])[1]
214
def _parse_possible_contraction(positions, input_sets, output_set, idx_dict, memory_limit, path_cost, naive_cost):
    """Evaluate the contraction proposed by ``positions`` and return its sort
    key together with the index sets that would result from it.

    Parameters
    ----------
    positions : tuple of int
        The locations of the proposed tensors to contract.
    input_sets : list of sets
        The indices found on each tensors.
    output_set : set
        The output indices of the expression.
    idx_dict : dict
        Mapping of each index to its size.
    memory_limit : int
        The total allowed size for an intermediary tensor.
    path_cost : int
        The contraction cost so far.
    naive_cost : int
        The cost of the unoptimized expression.

    Returns
    -------
    candidate : list or None
        ``None`` when the resulting tensor is larger than ``memory_limit`` or
        when ``path_cost`` plus the cost of this contraction is greater than
        ``naive_cost``. Otherwise a three element list:

        * ``(-removed_size, cost)`` where ``removed_size`` is the summed size
          of the contracted tensors minus the size of the result, and ``cost``
          is the flop count of the contraction,
        * ``positions`` as passed in,
        * the list of index sets that remains after the contraction.

    """
    idx_result, new_input_sets, idx_removed, idx_contract = _find_contraction(
        positions, input_sets, output_set)

    # Reject anything whose result does not fit within the memory limit
    new_size = _compute_size_by_dict(idx_result, idx_dict)
    if new_size > memory_limit:
        return None

    # Total size of the tensors consumed by this contraction
    consumed_size = sum(_compute_size_by_dict(input_sets[p], idx_dict)
                        for p in positions)
    cost = _flop_count(idx_contract, idx_removed, len(positions), idx_dict)

    # Reject anything that would push the path above the naive cost
    if (path_cost + cost) > naive_cost:
        return None

    # ``new_size - consumed_size`` is the negated amount of data removed, so
    # taking the minimum of this key prefers the largest reduction first.
    return [(new_size - consumed_size, cost), positions, new_input_sets]
271
272
def _update_other_results(results, best):
    """Carry candidate contractions over to the state reached after ``best``
    has been performed, dropping those that used a tensor ``best`` consumed.

    Parameters
    ----------
    results : list
        List of contraction results produced by ``_parse_possible_contraction``.
    best : list
        The best contraction of ``results`` i.e. the one that will be performed.

    Returns
    -------
    mod_results : list
        Tuples ``(cost, positions, input_sets)`` for the surviving candidates,
        with positions renumbered and the index set lists (modified in place)
        adjusted for the outcome of ``best``.
    """
    contracted = best[1]
    first, second = contracted
    survivors = []

    for cost, pair, con_sets in results:
        x, y = pair

        # A candidate that needs a tensor which no longer exists is dropped
        if x in contracted or y in contracted:
            continue

        # ``con_sets`` already lacks the tensors at x and y, so each tensor
        # consumed by ``best`` sits one slot lower for every one of x and y
        # that came before it. The later position is deleted first.
        for gone in (second, first):
            del con_sets[gone - int(gone > x) - int(gone > y)]
        # The tensor produced by ``best`` goes just before this candidate's
        # own result, which stays last.
        con_sets.insert(-1, best[2][-1])

        # Renumber the candidate's positions for the two removed tensors
        shifted = tuple(p - int(p > first) - int(p > second) for p in pair)
        survivors.append((cost, shifted, con_sets))

    return survivors
311
def _greedy_path(input_sets, output_set, idx_dict, memory_limit):
    """
    Finds the path by contracting the best pair until the input list is
    exhausted. The best pair is found by minimizing the tuple
    ``(-removed_size, cost)`` built by ``_parse_possible_contraction``.  What
    this amounts to is prioritizing matrix multiplication or inner product
    operations, then Hadamard like operations, and finally outer operations.
    Outer products are limited by ``memory_limit``. This algorithm scales
    cubically with respect to the number of elements in the list
    ``input_sets``.

    Parameters
    ----------
    input_sets : list
        List of sets that represent the lhs side of the einsum subscript
    output_set : set
        Set that represents the rhs side of the overall einsum subscript
    idx_dict : dictionary
        Dictionary of index sizes
    memory_limit : int
        The maximum number of elements in a temporary array

    Returns
    -------
    path : list
        The greedy contraction order within the memory limit constraint.
        If at some step no pair contraction is acceptable, the path ends with
        a single entry contracting all remaining terms at once.

    Examples
    --------
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set()
    >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
    >>> _greedy_path(isets, oset, idx_sizes, 5000)
    [(0, 2), (0, 1)]
    """

    # Handle trivial cases that leaked through
    if len(input_sets) == 1:
        return [(0,)]
    elif len(input_sets) == 2:
        return [(0, 1)]

    # Build up a naive cost: the cost of contracting every term in one step.
    # Candidates that would push the path above it are rejected later on.
    contract = _find_contraction(range(len(input_sets)), input_sets, output_set)
    idx_result, new_input_sets, idx_removed, idx_contract = contract
    naive_cost = _flop_count(idx_contract, idx_removed, len(input_sets), idx_dict)

    # Initially iterate over all pairs
    comb_iter = itertools.combinations(range(len(input_sets)), 2)
    # Candidate contractions; the unused ones are carried between iterations
    known_contractions = []

    path_cost = 0
    path = []

    for iteration in range(len(input_sets) - 1):

        # All pairs on the first step; on later steps only the pairs that
        # involve the newest tensor (older pairs live in known_contractions)
        for positions in comb_iter:

            # Always initially ignore outer products
            if input_sets[positions[0]].isdisjoint(input_sets[positions[1]]):
                continue

            result = _parse_possible_contraction(positions, input_sets, output_set, idx_dict, memory_limit, path_cost,
                                                 naive_cost)
            if result is not None:
                known_contractions.append(result)

        # If we do not have an inner contraction, rescan pairs including outer products
        if len(known_contractions) == 0:

            # Then check the outer products
            for positions in itertools.combinations(range(len(input_sets)), 2):
                result = _parse_possible_contraction(positions, input_sets, output_set, idx_dict, memory_limit,
                                                     path_cost, naive_cost)
                if result is not None:
                    known_contractions.append(result)

            # If we still did not find any remaining contractions, default back to einsum like behavior
            if len(known_contractions) == 0:
                path.append(tuple(range(len(input_sets))))
                break

        # Pick the candidate with the smallest sort key (first element)
        best = min(known_contractions, key=lambda x: x[0])

        # Now propagate as many unused contractions as possible to next iteration
        known_contractions = _update_other_results(known_contractions, best)

        # Next iteration only compute contractions with the new tensor
        # All other contractions have been accounted for
        input_sets = best[2]
        # The tensor just produced is the last entry of the new input_sets
        new_tensor_pos = len(input_sets) - 1
        comb_iter = ((i, new_tensor_pos) for i in range(new_tensor_pos))

        # Update path and total cost (best[0] is the (-removed_size, cost) key)
        path.append(best[1])
        path_cost += best[0][1]

    return path
411
412
def _can_dot(inputs, result, idx_removed):
    """
    Decide whether a contraction can be handed to BLAS (np.tensordot) and
    whether doing so is worthwhile.

    Parameters
    ----------
    inputs : list of str
        Specifies the subscripts for summation.
    result : str
        Resulting summation.
    idx_removed : set
        Indices that are removed in the summation


    Returns
    -------
    type : bool
        Returns true if BLAS should and can be used, else False

    Notes
    -----
    If the operation is BLAS level 1 or 2 and is not already aligned
    we default back to einsum as the memory movement to copy is more
    costly than the operation itself.


    Examples
    --------

    # Standard GEMM operation
    >>> _can_dot(['ij', 'jk'], 'ik', set('j'))
    True

    # Can use the standard BLAS, but requires odd data movement
    >>> _can_dot(['ijj', 'jk'], 'ik', set('j'))
    False

    # DDOT where the memory is not aligned
    >>> _can_dot(['ijk', 'ikj'], '', set('ijk'))
    False

    """
    # A `dot` call always removes indices and takes exactly two operands
    if len(idx_removed) == 0 or len(inputs) != 2:
        return False

    lhs, rhs = inputs

    for label in set(lhs + rhs):
        count_l = lhs.count(label)
        count_r = rhs.count(label)
        total = count_l + count_r

        # no repeated index within one operand, nor more than two overall
        if count_l > 1 or count_r > 1 or total > 2:
            return False

        # can't do implicit summation or dimension collapse e.g.
        #     "ab,bc->c" (implicitly sum over 'a')
        #     "ab,ca->ca" (take diagonal of 'a')
        if total - 1 == int(label in result):
            return False

    labels_l = set(lhs)
    labels_r = set(rhs)
    kept_l = labels_l - idx_removed
    kept_r = labels_r - idx_removed
    n_removed = len(idx_removed)

    # From here on the operation is a DOT, GEMV, or GEMM

    # Inner product over identically ordered operands: aligned DDOT
    if lhs == rhs:
        return True

    # Inner product with differing index order: einsum is the better choice
    if labels_l == labels_r:
        return False

    # Aligned GEMV/GEMM: the removed indices form a matching leading or
    # trailing run on both operands (no transpose, transpose both,
    # transpose right, transpose left).
    head_l, tail_l = lhs[:n_removed], lhs[-n_removed:]
    head_r, tail_r = rhs[:n_removed], rhs[-n_removed:]
    if (tail_l == head_r or head_l == tail_r or
            tail_l == tail_r or head_l == head_r):
        return True

    # Data has to be copied: that still pays off for a matrix-matrix product
    # (both operands keep indices), but einsum is faster than a copying GEMV.
    return bool(kept_l) and bool(kept_r)
521
522
def _sublist_to_subscripts(sublist):
    """Translate one sublist-format term (ints and Ellipsis) into letters."""
    labels = []
    for s in sublist:
        if s is Ellipsis:
            labels.append("...")
            continue
        try:
            s = operator.index(s)
        except TypeError as e:
            raise TypeError("For this input type lists must contain "
                            "either int or Ellipsis") from e
        # A plain lookup would let negative labels wrap around to the end of
        # the symbol table and fail with a bare IndexError for large ones.
        if not 0 <= s < len(einsum_symbols):
            raise ValueError("Subscript %d is not within the valid range "
                             "[0, %d)." % (s, len(einsum_symbols)))
        labels.append(einsum_symbols[s])
    return "".join(labels)


def _parse_einsum_input(operands):
    """
    A reproduction of einsum c side einsum parsing in python.

    Parameters
    ----------
    operands : sequence
        Either ``(subscripts, op0, op1, ...)`` with ``subscripts`` a string,
        or the interleaved form ``(op0, sublist0, op1, sublist1, ...,
        [sublistout])`` where each sublist holds integer labels and/or
        ``Ellipsis``.

    Returns
    -------
    input_strings : str
        Parsed input strings
    output_string : str
        Parsed output string
    operands : list of array_like
        The operands to use in the numpy contraction

    Raises
    ------
    ValueError
        If no operands are given, a subscript character or integer label is
        not valid, ``->`` or an ellipsis is malformed, an ellipsis does not
        fit the operand it belongs to, an output label does not appear in
        the input, or the number of subscript terms differs from the number
        of operands.
    TypeError
        If a sublist entry is neither an integer nor ``Ellipsis``.

    Examples
    --------
    The operand list is simplified to reduce printing. Ellipses are expanded
    with the labels that are still unused, taken from the end of the symbol
    table:

    >>> np.random.seed(123)
    >>> a = np.random.rand(4, 4)
    >>> b = np.random.rand(4, 4, 4)
    >>> _parse_einsum_input(('...a,...a->...', a, b))
    ('Za,YZa', 'YZ', [a, b])

    >>> _parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0]))
    ('Za,YZa', 'YZ', [a, b])
    """

    if len(operands) == 0:
        raise ValueError("No input operands")

    if isinstance(operands[0], str):
        subscripts = operands[0].replace(" ", "")
        operands = [asanyarray(v) for v in operands[1:]]

        # Ensure all characters are valid
        for s in subscripts:
            if s in '.,->':
                continue
            if s not in einsum_symbols:
                raise ValueError("Character %s is not a valid symbol." % s)

    else:
        # Interleaved form: operands sit at even positions, their sublists at
        # odd positions, and a trailing unpaired item is the output sublist.
        interleaved = list(operands)
        n_pairs = len(interleaved) // 2
        operand_list = interleaved[0:2 * n_pairs:2]
        subscript_list = interleaved[1:2 * n_pairs:2]
        output_list = interleaved[-1] if len(interleaved) % 2 else None

        operands = [asanyarray(v) for v in operand_list]
        subscripts = ",".join(_sublist_to_subscripts(sub)
                              for sub in subscript_list)
        if output_list is not None:
            subscripts += "->" + _sublist_to_subscripts(output_list)

    # Check for proper "->"
    if ("-" in subscripts) or (">" in subscripts):
        invalid = (subscripts.count("-") > 1) or (subscripts.count(">") > 1)
        if invalid or (subscripts.count("->") != 1):
            raise ValueError("Subscripts can only contain one '->'.")

    # Parse ellipses
    if "." in subscripts:
        used = subscripts.replace(".", "").replace(",", "").replace("->", "")
        # Unused labels in symbol-table order, so the labels chosen for an
        # ellipsis do not depend on set iteration order.
        ellipse_inds = "".join(s for s in einsum_symbols if s not in used)
        longest = 0

        if "->" in subscripts:
            input_tmp, output_sub = subscripts.split("->")
            split_subscripts = input_tmp.split(",")
            out_sub = True
        else:
            split_subscripts = subscripts.split(',')
            out_sub = False

        # The loop below indexes ``operands`` by term number, so a count
        # mismatch has to be reported before it rather than after.
        if len(split_subscripts) != len(operands):
            raise ValueError("Number of einsum subscripts must be equal to "
                             "the number of operands.")

        for num, sub in enumerate(split_subscripts):
            if "." in sub:
                if (sub.count(".") != 3) or (sub.count("...") != 1):
                    raise ValueError("Invalid Ellipses.")

                # Take into account numerical values
                if operands[num].shape == ():
                    ellipse_count = 0
                else:
                    # Dimensions not covered by an explicit label
                    ellipse_count = max(operands[num].ndim, 1)
                    ellipse_count -= (len(sub) - 3)

                if ellipse_count > longest:
                    longest = ellipse_count

                if ellipse_count < 0:
                    raise ValueError("Ellipses lengths do not match.")
                elif ellipse_count == 0:
                    split_subscripts[num] = sub.replace('...', '')
                else:
                    # Labels are taken from the end so that ellipses of
                    # different lengths share their trailing labels.
                    rep_inds = ellipse_inds[-ellipse_count:]
                    split_subscripts[num] = sub.replace('...', rep_inds)

        subscripts = ",".join(split_subscripts)
        if longest == 0:
            out_ellipse = ""
        else:
            out_ellipse = ellipse_inds[-longest:]

        if out_sub:
            subscripts += "->" + output_sub.replace("...", out_ellipse)
        else:
            # Special care for outputless ellipses
            output_subscript = ""
            tmp_subscripts = subscripts.replace(",", "")
            for s in sorted(set(tmp_subscripts)):
                if s not in (einsum_symbols):
                    raise ValueError("Character %s is not a valid symbol." % s)
                if tmp_subscripts.count(s) == 1:
                    output_subscript += s
            normal_inds = ''.join(sorted(set(output_subscript) -
                                         set(out_ellipse)))

            # Ellipsis dimensions come first, then labels that occur once
            subscripts += "->" + out_ellipse + normal_inds

    # Build output string if does not exist
    if "->" in subscripts:
        input_subscripts, output_subscript = subscripts.split("->")
    else:
        input_subscripts = subscripts
        # Implicit output: every label that occurs exactly once, sorted
        tmp_subscripts = subscripts.replace(",", "")
        output_subscript = ""
        for s in sorted(set(tmp_subscripts)):
            if s not in einsum_symbols:
                raise ValueError("Character %s is not a valid symbol." % s)
            if tmp_subscripts.count(s) == 1:
                output_subscript += s

    # Make sure output subscripts are in the input
    for char in output_subscript:
        if char not in input_subscripts:
            raise ValueError("Output character %s did not appear in the input"
                             % char)

    # Make sure number operands is equivalent to the number of terms
    if len(input_subscripts.split(',')) != len(operands):
        raise ValueError("Number of einsum subscripts must be equal to the "
                         "number of operands.")

    return (input_subscripts, output_subscript, operands)
694
695
def _einsum_path_dispatcher(*operands, optimize=None, einsum_call=None):
    """Dispatcher handed to ``array_function_dispatch`` for ``einsum_path``.

    Returns every positional operand unchanged; ``optimize`` and
    ``einsum_call`` are accepted only to mirror the signature of
    ``einsum_path`` and are not used.
    """
    # NOTE: technically, we should only dispatch on array-like arguments, not
    # subscripts (given as strings). But separating operands into
    # arrays/subscripts is a little tricky/slow (given einsum's two supported
    # signatures), so as a practical shortcut we dispatch on everything.
    # Strings will be ignored for dispatching since they don't define
    # __array_function__.
    return operands
704
705
706@array_function_dispatch(_einsum_path_dispatcher, module='numpy')
707def einsum_path(*operands, optimize='greedy', einsum_call=False):
708    """
709    einsum_path(subscripts, *operands, optimize='greedy')
710
711    Evaluates the lowest cost contraction order for an einsum expression by
712    considering the creation of intermediate arrays.
713
714    Parameters
715    ----------
716    subscripts : str
717        Specifies the subscripts for summation.
718    *operands : list of array_like
719        These are the arrays for the operation.
720    optimize : {bool, list, tuple, 'greedy', 'optimal'}
721        Choose the type of path. If a tuple is provided, the second argument is
722        assumed to be the maximum intermediate size created. If only a single
723        argument is provided the largest input or output array size is used
724        as a maximum intermediate size.
725
726        * if a list is given that starts with ``einsum_path``, uses this as the
727          contraction path
728        * if False no optimization is taken
729        * if True defaults to the 'greedy' algorithm
730        * 'optimal' An algorithm that combinatorially explores all possible
          ways of contracting the listed tensors and chooses the least costly
732          path. Scales exponentially with the number of terms in the
733          contraction.
734        * 'greedy' An algorithm that chooses the best pair contraction
735          at each step. Effectively, this algorithm searches the largest inner,
736          Hadamard, and then outer products at each step. Scales cubically with
737          the number of terms in the contraction. Equivalent to the 'optimal'
738          path for most contractions.
739
740        Default is 'greedy'.
741
742    Returns
743    -------
744    path : list of tuples
745        A list representation of the einsum path.
746    string_repr : str
747        A printable representation of the einsum path.
748
749    Notes
750    -----
751    The resulting path indicates which terms of the input contraction should be
752    contracted first, the result of this contraction is then appended to the
753    end of the contraction list. This list can then be iterated over until all
754    intermediate contractions are complete.
755
756    See Also
757    --------
758    einsum, linalg.multi_dot
759
760    Examples
761    --------
762
763    We can begin with a chain dot example. In this case, it is optimal to
764    contract the ``b`` and ``c`` tensors first as represented by the first
765    element of the path ``(1, 2)``. The resulting tensor is added to the end
766    of the contraction and the remaining contraction ``(0, 1)`` is then
767    completed.
768
769    >>> np.random.seed(123)
770    >>> a = np.random.rand(2, 2)
771    >>> b = np.random.rand(2, 5)
772    >>> c = np.random.rand(5, 2)
773    >>> path_info = np.einsum_path('ij,jk,kl->il', a, b, c, optimize='greedy')
774    >>> print(path_info[0])
775    ['einsum_path', (1, 2), (0, 1)]
776    >>> print(path_info[1])
777      Complete contraction:  ij,jk,kl->il # may vary
778             Naive scaling:  4
779         Optimized scaling:  3
780          Naive FLOP count:  1.600e+02
781      Optimized FLOP count:  5.600e+01
782       Theoretical speedup:  2.857
783      Largest intermediate:  4.000e+00 elements
784    -------------------------------------------------------------------------
785    scaling                  current                                remaining
786    -------------------------------------------------------------------------
787       3                   kl,jk->jl                                ij,jl->il
788       3                   jl,ij->il                                   il->il
789
790
791    A more complex index transformation example.
792
793    >>> I = np.random.rand(10, 10, 10, 10)
794    >>> C = np.random.rand(10, 10)
795    >>> path_info = np.einsum_path('ea,fb,abcd,gc,hd->efgh', C, C, I, C, C,
796    ...                            optimize='greedy')
797
798    >>> print(path_info[0])
799    ['einsum_path', (0, 2), (0, 3), (0, 2), (0, 1)]
800    >>> print(path_info[1])
801      Complete contraction:  ea,fb,abcd,gc,hd->efgh # may vary
802             Naive scaling:  8
803         Optimized scaling:  5
804          Naive FLOP count:  8.000e+08
805      Optimized FLOP count:  8.000e+05
806       Theoretical speedup:  1000.000
807      Largest intermediate:  1.000e+04 elements
808    --------------------------------------------------------------------------
809    scaling                  current                                remaining
810    --------------------------------------------------------------------------
811       5               abcd,ea->bcde                      fb,gc,hd,bcde->efgh
812       5               bcde,fb->cdef                         gc,hd,cdef->efgh
813       5               cdef,gc->defg                            hd,defg->efgh
814       5               defg,hd->efgh                               efgh->efgh
815    """
816
817    # Figure out what the path really is
818    path_type = optimize
819    if path_type is True:
820        path_type = 'greedy'
821    if path_type is None:
822        path_type = False
823
824    memory_limit = None
825
826    # No optimization or a named path algorithm
827    if (path_type is False) or isinstance(path_type, str):
828        pass
829
830    # Given an explicit path
831    elif len(path_type) and (path_type[0] == 'einsum_path'):
832        pass
833
834    # Path tuple with memory limit
835    elif ((len(path_type) == 2) and isinstance(path_type[0], str) and
836            isinstance(path_type[1], (int, float))):
837        memory_limit = int(path_type[1])
838        path_type = path_type[0]
839
840    else:
841        raise TypeError("Did not understand the path: %s" % str(path_type))
842
843    # Hidden option, only einsum should call this
844    einsum_call_arg = einsum_call
845
846    # Python side parsing
847    input_subscripts, output_subscript, operands = _parse_einsum_input(operands)
848
849    # Build a few useful list and sets
850    input_list = input_subscripts.split(',')
851    input_sets = [set(x) for x in input_list]
852    output_set = set(output_subscript)
853    indices = set(input_subscripts.replace(',', ''))
854
855    # Get length of each unique dimension and ensure all dimensions are correct
856    dimension_dict = {}
857    broadcast_indices = [[] for x in range(len(input_list))]
858    for tnum, term in enumerate(input_list):
859        sh = operands[tnum].shape
860        if len(sh) != len(term):
861            raise ValueError("Einstein sum subscript %s does not contain the "
862                             "correct number of indices for operand %d."
863                             % (input_subscripts[tnum], tnum))
864        for cnum, char in enumerate(term):
865            dim = sh[cnum]
866
867            # Build out broadcast indices
868            if dim == 1:
869                broadcast_indices[tnum].append(char)
870
871            if char in dimension_dict.keys():
872                # For broadcasting cases we always want the largest dim size
873                if dimension_dict[char] == 1:
874                    dimension_dict[char] = dim
875                elif dim not in (1, dimension_dict[char]):
876                    raise ValueError("Size of label '%s' for operand %d (%d) "
877                                     "does not match previous terms (%d)."
878                                     % (char, tnum, dimension_dict[char], dim))
879            else:
880                dimension_dict[char] = dim
881
882    # Convert broadcast inds to sets
883    broadcast_indices = [set(x) for x in broadcast_indices]
884
885    # Compute size of each input array plus the output array
886    size_list = [_compute_size_by_dict(term, dimension_dict)
887                 for term in input_list + [output_subscript]]
888    max_size = max(size_list)
889
890    if memory_limit is None:
891        memory_arg = max_size
892    else:
893        memory_arg = memory_limit
894
895    # Compute naive cost
896    # This isn't quite right, need to look into exactly how einsum does this
897    inner_product = (sum(len(x) for x in input_sets) - len(indices)) > 0
898    naive_cost = _flop_count(indices, inner_product, len(input_list), dimension_dict)
899
900    # Compute the path
901    if (path_type is False) or (len(input_list) in [1, 2]) or (indices == output_set):
902        # Nothing to be optimized, leave it to einsum
903        path = [tuple(range(len(input_list)))]
904    elif path_type == "greedy":
905        path = _greedy_path(input_sets, output_set, dimension_dict, memory_arg)
906    elif path_type == "optimal":
907        path = _optimal_path(input_sets, output_set, dimension_dict, memory_arg)
908    elif path_type[0] == 'einsum_path':
909        path = path_type[1:]
910    else:
911        raise KeyError("Path name %s not found", path_type)
912
913    cost_list, scale_list, size_list, contraction_list = [], [], [], []
914
915    # Build contraction tuple (positions, gemm, einsum_str, remaining)
916    for cnum, contract_inds in enumerate(path):
917        # Make sure we remove inds from right to left
918        contract_inds = tuple(sorted(list(contract_inds), reverse=True))
919
920        contract = _find_contraction(contract_inds, input_sets, output_set)
921        out_inds, input_sets, idx_removed, idx_contract = contract
922
923        cost = _flop_count(idx_contract, idx_removed, len(contract_inds), dimension_dict)
924        cost_list.append(cost)
925        scale_list.append(len(idx_contract))
926        size_list.append(_compute_size_by_dict(out_inds, dimension_dict))
927
928        bcast = set()
929        tmp_inputs = []
930        for x in contract_inds:
931            tmp_inputs.append(input_list.pop(x))
932            bcast |= broadcast_indices.pop(x)
933
934        new_bcast_inds = bcast - idx_removed
935
936        # If we're broadcasting, nix blas
937        if not len(idx_removed & bcast):
938            do_blas = _can_dot(tmp_inputs, out_inds, idx_removed)
939        else:
940            do_blas = False
941
942        # Last contraction
943        if (cnum - len(path)) == -1:
944            idx_result = output_subscript
945        else:
946            sort_result = [(dimension_dict[ind], ind) for ind in out_inds]
947            idx_result = "".join([x[1] for x in sorted(sort_result)])
948
949        input_list.append(idx_result)
950        broadcast_indices.append(new_bcast_inds)
951        einsum_str = ",".join(tmp_inputs) + "->" + idx_result
952
953        contraction = (contract_inds, idx_removed, einsum_str, input_list[:], do_blas)
954        contraction_list.append(contraction)
955
956    opt_cost = sum(cost_list) + 1
957
958    if einsum_call_arg:
959        return (operands, contraction_list)
960
961    # Return the path along with a nice string representation
962    overall_contraction = input_subscripts + "->" + output_subscript
963    header = ("scaling", "current", "remaining")
964
965    speedup = naive_cost / opt_cost
966    max_i = max(size_list)
967
968    path_print  = "  Complete contraction:  %s\n" % overall_contraction
969    path_print += "         Naive scaling:  %d\n" % len(indices)
970    path_print += "     Optimized scaling:  %d\n" % max(scale_list)
971    path_print += "      Naive FLOP count:  %.3e\n" % naive_cost
972    path_print += "  Optimized FLOP count:  %.3e\n" % opt_cost
973    path_print += "   Theoretical speedup:  %3.3f\n" % speedup
974    path_print += "  Largest intermediate:  %.3e elements\n" % max_i
975    path_print += "-" * 74 + "\n"
976    path_print += "%6s %24s %40s\n" % header
977    path_print += "-" * 74
978
979    for n, contraction in enumerate(contraction_list):
980        inds, idx_rm, einsum_str, remaining, blas = contraction
981        remaining_str = ",".join(remaining) + "->" + output_subscript
982        path_run = (scale_list[n], einsum_str, remaining_str)
983        path_print += "\n%4d    %24s %40s" % path_run
984
985    path = ['einsum_path'] + path
986    return (path, path_print)
987
988
989def _einsum_dispatcher(*operands, out=None, optimize=None, **kwargs):
990    # Arguably we dispatch on more arguments that we really should; see note in
991    # _einsum_path_dispatcher for why.
992    yield from operands
993    yield out
994
995
996# Rewrite einsum to handle different cases
997@array_function_dispatch(_einsum_dispatcher, module='numpy')
def einsum(*operands, out=None, optimize=False, **kwargs):
    """
    einsum(subscripts, *operands, out=None, dtype=None, order='K',
           casting='safe', optimize=False)

    Evaluates the Einstein summation convention on the operands.

    Using the Einstein summation convention, many common multi-dimensional,
    linear algebraic array operations can be represented in a simple fashion.
    In *implicit* mode `einsum` computes these values.

    In *explicit* mode, `einsum` provides further flexibility to compute
    other array operations that might not be considered classical Einstein
    summation operations, by disabling, or forcing summation over specified
    subscript labels.

    See the notes and examples for clarification.

    Parameters
    ----------
    subscripts : str
        Specifies the subscripts for summation as comma separated list of
        subscript labels. An implicit (classical Einstein summation)
        calculation is performed unless the explicit indicator '->' is
        included as well as subscript labels of the precise output form.
    operands : list of array_like
        These are the arrays for the operation.
    out : ndarray, optional
        If provided, the calculation is done into this array.
    dtype : {data-type, None}, optional
        If provided, forces the calculation to use the data type specified.
        Note that you may have to also give a more liberal `casting`
        parameter to allow the conversions. Default is None.
    order : {'C', 'F', 'A', 'K'}, optional
        Controls the memory layout of the output. 'C' means it should
        be C contiguous. 'F' means it should be Fortran contiguous,
        'A' means it should be 'F' if the inputs are all 'F', 'C' otherwise.
        'K' means it should be as close to the layout of the inputs as
        is possible, including arbitrarily permuted axes.
        Default is 'K'.
    casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
        Controls what kind of data casting may occur.  Setting this to
        'unsafe' is not recommended, as it can adversely affect accumulations.

          * 'no' means the data types should not be cast at all.
          * 'equiv' means only byte-order changes are allowed.
          * 'safe' means only casts which can preserve values are allowed.
          * 'same_kind' means only safe casts or casts within a kind,
            like float64 to float32, are allowed.
          * 'unsafe' means any data conversions may be done.

        Default is 'safe'.
    optimize : {False, True, 'greedy', 'optimal'}, optional
        Controls if intermediate optimization should occur. No optimization
        will occur if False and True will default to the 'greedy' algorithm.
        Also accepts an explicit contraction list from the ``np.einsum_path``
        function. See ``np.einsum_path`` for more details. Defaults to False.

    Returns
    -------
    output : ndarray
        The calculation based on the Einstein summation convention.

    See Also
    --------
    einsum_path, dot, inner, outer, tensordot, linalg.multi_dot

    einops:
        similar verbose interface is provided by
        `einops <https://github.com/arogozhnikov/einops>`_ package to cover
        additional operations: transpose, reshape/flatten, repeat/tile,
        squeeze/unsqueeze and reductions.

    opt_einsum:
        `opt_einsum <https://optimized-einsum.readthedocs.io/en/stable/>`_
        optimizes contraction order for einsum-like expressions
        in backend-agnostic manner.

    Notes
    -----
    .. versionadded:: 1.6.0

    The Einstein summation convention can be used to compute
    many multi-dimensional, linear algebraic array operations. `einsum`
    provides a succinct way of representing these.

    A non-exhaustive list of these operations,
    which can be computed by `einsum`, is shown below along with examples:

    * Trace of an array, :py:func:`numpy.trace`.
    * Return a diagonal, :py:func:`numpy.diag`.
    * Array axis summations, :py:func:`numpy.sum`.
    * Transpositions and permutations, :py:func:`numpy.transpose`.
    * Matrix multiplication and dot product, :py:func:`numpy.matmul` :py:func:`numpy.dot`.
    * Vector inner and outer products, :py:func:`numpy.inner` :py:func:`numpy.outer`.
    * Broadcasting, element-wise and scalar multiplication, :py:func:`numpy.multiply`.
    * Tensor contractions, :py:func:`numpy.tensordot`.
    * Chained array operations, in efficient calculation order, :py:func:`numpy.einsum_path`.

    The subscripts string is a comma-separated list of subscript labels,
    where each label refers to a dimension of the corresponding operand.
    Whenever a label is repeated it is summed, so ``np.einsum('i,i', a, b)``
    is equivalent to :py:func:`np.inner(a,b) <numpy.inner>`. If a label
    appears only once, it is not summed, so ``np.einsum('i', a)`` produces a
    view of ``a`` with no changes. A further example ``np.einsum('ij,jk', a, b)``
    describes traditional matrix multiplication and is equivalent to
    :py:func:`np.matmul(a,b) <numpy.matmul>`. Repeated subscript labels in one
    operand take the diagonal. For example, ``np.einsum('ii', a)`` is equivalent
    to :py:func:`np.trace(a) <numpy.trace>`.

    In *implicit mode*, the chosen subscripts are important
    since the axes of the output are reordered alphabetically.  This
    means that ``np.einsum('ij', a)`` doesn't affect a 2D array, while
    ``np.einsum('ji', a)`` takes its transpose. Additionally,
    ``np.einsum('ij,jk', a, b)`` returns a matrix multiplication, while,
    ``np.einsum('ij,jh', a, b)`` returns the transpose of the
    multiplication since subscript 'h' precedes subscript 'i'.

    In *explicit mode* the output can be directly controlled by
    specifying output subscript labels.  This requires the
    identifier '->' as well as the list of output subscript labels.
    This feature increases the flexibility of the function since
    summing can be disabled or forced when required. The call
    ``np.einsum('i->', a)`` is like :py:func:`np.sum(a, axis=-1) <numpy.sum>`,
    and ``np.einsum('ii->i', a)`` is like :py:func:`np.diag(a) <numpy.diag>`.
    The difference is that `einsum` does not allow broadcasting by default.
    Additionally ``np.einsum('ij,jh->ih', a, b)`` directly specifies the
    order of the output subscript labels and therefore returns matrix
    multiplication, unlike the example above in implicit mode.

    To enable and control broadcasting, use an ellipsis.  Default
    NumPy-style broadcasting is done by adding an ellipsis
    to the left of each term, like ``np.einsum('...ii->...i', a)``.
    To take the trace along the first and last axes,
    you can do ``np.einsum('i...i', a)``, or to do a matrix-matrix
    product with the left-most indices instead of rightmost, one can do
    ``np.einsum('ij...,jk...->ik...', a, b)``.

    When there is only one operand, no axes are summed, and no output
    parameter is provided, a view into the operand is returned instead
    of a new array.  Thus, taking the diagonal as ``np.einsum('ii->i', a)``
    produces a view (changed in version 1.10.0).

    `einsum` also provides an alternative way to provide the subscripts
    and operands as ``einsum(op0, sublist0, op1, sublist1, ..., [sublistout])``.
    If the output shape is not provided in this format `einsum` will be
    calculated in implicit mode, otherwise it will be performed explicitly.
    The examples below have corresponding `einsum` calls with the two
    parameter methods.

    .. versionadded:: 1.10.0

    Views returned from einsum are now writeable whenever the input array
    is writeable. For example, ``np.einsum('ijk...->kji...', a)`` will now
    have the same effect as :py:func:`np.swapaxes(a, 0, 2) <numpy.swapaxes>`
    and ``np.einsum('ii->i', a)`` will return a writeable view of the diagonal
    of a 2D array.

    .. versionadded:: 1.12.0

    Added the ``optimize`` argument which will optimize the contraction order
    of an einsum expression. For a contraction with three or more operands this
    can greatly increase the computational efficiency at the cost of a larger
    memory footprint during computation.

    Typically a 'greedy' algorithm is applied which empirical tests have shown
    returns the optimal path in the majority of cases. In some cases 'optimal'
    will return the superlative path through a more expensive, exhaustive search.
    For iterative calculations it may be advisable to calculate the optimal path
    once and reuse that path by supplying it as an argument. An example is given
    below.

    See :py:func:`numpy.einsum_path` for more details.

    Examples
    --------
    >>> a = np.arange(25).reshape(5,5)
    >>> b = np.arange(5)
    >>> c = np.arange(6).reshape(2,3)

    Trace of a matrix:

    >>> np.einsum('ii', a)
    60
    >>> np.einsum(a, [0,0])
    60
    >>> np.trace(a)
    60

    Extract the diagonal (requires explicit form):

    >>> np.einsum('ii->i', a)
    array([ 0,  6, 12, 18, 24])
    >>> np.einsum(a, [0,0], [0])
    array([ 0,  6, 12, 18, 24])
    >>> np.diag(a)
    array([ 0,  6, 12, 18, 24])

    Sum over an axis (requires explicit form):

    >>> np.einsum('ij->i', a)
    array([ 10,  35,  60,  85, 110])
    >>> np.einsum(a, [0,1], [0])
    array([ 10,  35,  60,  85, 110])
    >>> np.sum(a, axis=1)
    array([ 10,  35,  60,  85, 110])

    For higher dimensional arrays summing a single axis can be done with ellipsis:

    >>> np.einsum('...j->...', a)
    array([ 10,  35,  60,  85, 110])
    >>> np.einsum(a, [Ellipsis,1], [Ellipsis])
    array([ 10,  35,  60,  85, 110])

    Compute a matrix transpose, or reorder any number of axes:

    >>> np.einsum('ji', c)
    array([[0, 3],
           [1, 4],
           [2, 5]])
    >>> np.einsum('ij->ji', c)
    array([[0, 3],
           [1, 4],
           [2, 5]])
    >>> np.einsum(c, [1,0])
    array([[0, 3],
           [1, 4],
           [2, 5]])
    >>> np.transpose(c)
    array([[0, 3],
           [1, 4],
           [2, 5]])

    Vector inner products:

    >>> np.einsum('i,i', b, b)
    30
    >>> np.einsum(b, [0], b, [0])
    30
    >>> np.inner(b,b)
    30

    Matrix vector multiplication:

    >>> np.einsum('ij,j', a, b)
    array([ 30,  80, 130, 180, 230])
    >>> np.einsum(a, [0,1], b, [1])
    array([ 30,  80, 130, 180, 230])
    >>> np.dot(a, b)
    array([ 30,  80, 130, 180, 230])
    >>> np.einsum('...j,j', a, b)
    array([ 30,  80, 130, 180, 230])

    Broadcasting and scalar multiplication:

    >>> np.einsum('..., ...', 3, c)
    array([[ 0,  3,  6],
           [ 9, 12, 15]])
    >>> np.einsum(',ij', 3, c)
    array([[ 0,  3,  6],
           [ 9, 12, 15]])
    >>> np.einsum(3, [Ellipsis], c, [Ellipsis])
    array([[ 0,  3,  6],
           [ 9, 12, 15]])
    >>> np.multiply(3, c)
    array([[ 0,  3,  6],
           [ 9, 12, 15]])

    Vector outer product:

    >>> np.einsum('i,j', np.arange(2)+1, b)
    array([[0, 1, 2, 3, 4],
           [0, 2, 4, 6, 8]])
    >>> np.einsum(np.arange(2)+1, [0], b, [1])
    array([[0, 1, 2, 3, 4],
           [0, 2, 4, 6, 8]])
    >>> np.outer(np.arange(2)+1, b)
    array([[0, 1, 2, 3, 4],
           [0, 2, 4, 6, 8]])

    Tensor contraction:

    >>> a = np.arange(60.).reshape(3,4,5)
    >>> b = np.arange(24.).reshape(4,3,2)
    >>> np.einsum('ijk,jil->kl', a, b)
    array([[4400., 4730.],
           [4532., 4874.],
           [4664., 5018.],
           [4796., 5162.],
           [4928., 5306.]])
    >>> np.einsum(a, [0,1,2], b, [1,0,3], [2,3])
    array([[4400., 4730.],
           [4532., 4874.],
           [4664., 5018.],
           [4796., 5162.],
           [4928., 5306.]])
    >>> np.tensordot(a,b, axes=([1,0],[0,1]))
    array([[4400., 4730.],
           [4532., 4874.],
           [4664., 5018.],
           [4796., 5162.],
           [4928., 5306.]])

    Writeable returned arrays (since version 1.10.0):

    >>> a = np.zeros((3, 3))
    >>> np.einsum('ii->i', a)[:] = 1
    >>> a
    array([[1., 0., 0.],
           [0., 1., 0.],
           [0., 0., 1.]])

    Example of ellipsis use:

    >>> a = np.arange(6).reshape((3,2))
    >>> b = np.arange(12).reshape((4,3))
    >>> np.einsum('ki,jk->ij', a, b)
    array([[10, 28, 46, 64],
           [13, 40, 67, 94]])
    >>> np.einsum('ki,...k->i...', a, b)
    array([[10, 28, 46, 64],
           [13, 40, 67, 94]])
    >>> np.einsum('k...,jk', a, b)
    array([[10, 28, 46, 64],
           [13, 40, 67, 94]])

    Chained array operations. For more complicated contractions, speed ups
    might be achieved by repeatedly computing a 'greedy' path or pre-computing the
    'optimal' path and repeatedly applying it, using an
    `einsum_path` insertion (since version 1.12.0). Performance improvements can be
    particularly significant with larger arrays:

    >>> a = np.ones(64).reshape(2,4,8)

    Basic `einsum`: ~1520ms  (benchmarked on 3.1GHz Intel i5.)

    >>> for iteration in range(500):
    ...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a)

    Sub-optimal `einsum` (due to repeated path calculation time): ~330ms

    >>> for iteration in range(500):
    ...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='optimal')

    Greedy `einsum` (faster optimal path approximation): ~160ms

    >>> for iteration in range(500):
    ...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='greedy')

    Optimal `einsum` (best usage pattern in some use cases): ~110ms

    >>> path = np.einsum_path('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='optimal')[0]
    >>> for iteration in range(500):
    ...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize=path)

    """
    # Remember whether the caller asked for the result to land in `out`
    has_out = out is not None

    # Without optimization the whole expression goes to c_einsum in one call
    if optimize is False:
        if has_out:
            kwargs['out'] = out
        return c_einsum(*operands, **kwargs)

    # Reject unexpected keywords up front to avoid a more cryptic error
    # later, without having to repeat the default values here
    accepted_kwargs = ('dtype', 'order', 'casting')
    rejected_kwargs = [key for key in kwargs if key not in accepted_kwargs]
    if rejected_kwargs:
        raise TypeError("Did not understand the following kwargs: %s"
                        % rejected_kwargs)

    # einsum_path hands back the operand list together with the list of
    # contractions to carry out, one entry per step
    operands, contraction_list = einsum_path(*operands, optimize=optimize,
                                             einsum_call=True)

    # 'order' is taken out of the kwargs forwarded to the per-step calls and
    # applied to the final array instead; c_einsum allows mixed case, hence
    # the upper() when checking for 'A'
    output_order = kwargs.pop('order', 'K')
    if output_order.upper() == 'A':
        all_fortran = all(arr.flags.f_contiguous for arr in operands)
        output_order = 'F' if all_fortran else 'C'

    # Walk through the contraction steps
    last_step = len(contraction_list) - 1
    for step, contraction in enumerate(contraction_list):
        inds, idx_rm, einsum_str, _remaining, blas = contraction
        step_operands = [operands.pop(pos) for pos in inds]

        # Only the final step writes into the caller-supplied output. The
        # tensordot call below never sees kwargs, so setting this before the
        # branch only affects the c_einsum calls.
        write_out = has_out and step == last_step
        if write_out:
            kwargs["out"] = out

        if blas:
            # Use tensordot for this step; checks have already been handled
            operand_labels, result_labels = einsum_str.split('->')
            left_labels, right_labels = operand_labels.split(',')

            # Labels of the raw tensordot result: both inputs concatenated
            # with every contracted label stripped out
            dot_labels = left_labels + right_labels
            for label in idx_rm:
                dot_labels = dot_labels.replace(label, "")

            # Axis positions of the contracted labels in each input
            contracted = sorted(idx_rm)
            left_axes = tuple(left_labels.find(label) for label in contracted)
            right_axes = tuple(right_labels.find(label) for label in contracted)

            # Contract!
            new_view = tensordot(*step_operands, axes=(left_axes, right_axes))

            # A follow-up c_einsum reorders the axes when tensordot's layout
            # differs from the requested one, and/or fills `out`
            if write_out or dot_labels != result_labels:
                new_view = c_einsum(dot_labels + '->' + result_labels,
                                    new_view, **kwargs)
        else:
            # No tensordot for this step: let c_einsum do the contraction
            new_view = c_einsum(einsum_str, *step_operands, **kwargs)

        # Append new items and dereference what we can
        operands.append(new_view)
        del step_operands, new_view

    if has_out:
        return out
    return asanyarray(operands[0], order=output_order)
1434