1"""
2This module defines tensors with abstract index notation.
3
The abstract index notation was first formalized by Penrose.
5
6Tensor indices are formal objects, with a tensor type; there is no
7notion of index range, it is only possible to assign the dimension,
8used to trace the Kronecker delta; the dimension can be a Symbol.
9
10The Einstein summation convention is used.
11The covariant indices are indicated with a minus sign in front of the index.
12
13For instance the tensor ``t = p(a)*A(b,c)*q(-c)`` has the index ``c``
14contracted.
15
16A tensor expression ``t`` can be called; called with its
17indices in sorted order it is equal to itself:
18in the above example ``t(a, b) == t``;
19one can call ``t`` with different indices; ``t(c, d) == p(c)*A(d,a)*q(-a)``.
20
21The contracted indices are dummy indices, internally they have no name,
22the indices being represented by a graph-like structure.
23
24Tensors are put in canonical form using ``canon_bp``, which uses
25the Butler-Portugal algorithm for canonicalization using the monoterm
26symmetries of the tensors.
27
If there is an (anti)symmetric metric, the indices can be raised and
lowered when the tensor is put in canonical form.
30"""
31
32from __future__ import annotations
33
34import functools
35import typing
36from collections import defaultdict
37
38from ..combinatorics.tensor_can import (bsgs_direct_product, canonicalize,
39                                        get_symmetric_group_sgs, riemann_bsgs)
40from ..core import (Add, Basic, Integer, Rational, Symbol, Tuple, symbols,
41                    sympify)
42from ..core.sympify import CantSympify
43from ..external import import_module
44from ..matrices import Matrix, eye
45from ..utilities.decorator import doctest_depends_on
46
47
class TIDS(CantSympify):
    """
    Tensor-index data structure. This contains internal data structures about
    components of a tensor expression, its free and dummy indices.

    To create a ``TIDS`` object via the standard constructor, the required
    arguments are ``components``, ``free`` and ``dum``, described in the
    Parameters section below.

    WARNING: this class is meant as an internal representation of tensor data
    structures and should not be directly accessed by end users.

    Parameters
    ==========

    components : ``TensorHead`` objects representing the components of the tensor expression.

    free : Free indices in their internal representation.

    dum : Dummy indices in their internal representation.

    Examples
    ========

    >>> Lorentz = TensorIndexType('Lorentz', dummy_fmt='L')
    >>> m0, m1, m2, m3 = tensor_indices('m0 m1 m2 m3', Lorentz)
    >>> T = tensorhead('T', [Lorentz]*4, [[1]*4])
    >>> TIDS([T], [(m0, 0, 0), (m3, 3, 0)], [(1, 2, 0, 0)])
    TIDS([T(Lorentz,Lorentz,Lorentz,Lorentz)], [(m0, 0, 0), (m3, 3, 0)], [(1, 2, 0, 0)])

    Notes
    =====

    In short, this has created the components, free and dummy indices for
    the internal representation of a tensor T(m0, m1, -m1, m3).

    Free indices are represented as a list of triplets. The elements of
    each triplet identify a single free index and are

    1. TensorIndex object
    2. position inside the component
    3. component number

    Dummy indices are represented as a list of 4-plets. Each 4-plet stands
    for a couple of contracted indices; their original TensorIndex is not
    stored as it is no longer required. The four elements of the 4-plet
    are

    1. position inside the component of the first index.
    2. position inside the component of the second index.
    3. component number of the first index.
    4. component number of the second index.

    For couples built by ``free_dum_from_indices`` and by ``mul`` from
    ordinary free indices, the first index is the contravariant one and the
    second the covariant one.

    """

    def __init__(self, components, free, dum):
        # NOTE: ``dum`` is sorted in place (by component number, then
        # position, of its first index), so the caller's list is reordered.
        self.components = components
        self.free = free
        self.dum = dum
        # total number of index slots: one per free index, two per dummy couple
        self._ext_rank = len(self.free) + 2*len(self.dum)
        self.dum.sort(key=lambda x: (x[2], x[0]))

    def get_tensors(self):
        """
        Get a list of ``Tensor`` objects having the same ``TIDS`` if multiplied
        by one another.

        The indices returned by ``get_indices`` (with freshly named dummies)
        are split among the components according to their ranks.

        """
        indices = self.get_indices()
        tensors = []
        ind_pos = 0
        for component in self.components:
            prev_pos = ind_pos
            ind_pos += component.rank
            tensors.append(Tensor(component, indices[prev_pos:ind_pos]))
        return tensors

    def get_components_with_free_indices(self):
        """
        Get a list of components with their associated indices.

        Each entry is ``(component, free_indices)``, where ``free_indices``
        lists, in the order they appear in ``self.free``, the free-index
        triplets whose component number is the position of ``component``.
        ``self.free`` does not need to be sorted by component number.

        Examples
        ========

        >>> Lorentz = TensorIndexType('Lorentz', dummy_fmt='L')
        >>> m0, m1, m2, m3 = tensor_indices('m0 m1 m2 m3', Lorentz)
        >>> T = tensorhead('T', [Lorentz]*4, [[1]*4])
        >>> A = tensorhead('A', [Lorentz], [[1]])
        >>> t = TIDS.from_components_and_indices([T], [m0, m1, -m1, m3])
        >>> t.get_components_with_free_indices()
        [(T(Lorentz,Lorentz,Lorentz,Lorentz), [(m0, 0, 0), (m3, 3, 0)])]
        >>> t2 = (A(m0)*A(-m0))._tids
        >>> t2.get_components_with_free_indices()
        [(A(Lorentz), []), (A(Lorentz), [])]
        >>> t3 = (A(m0)*A(-m1)*A(-m0)*A(m1))._tids
        >>> t3.get_components_with_free_indices()
        [(A(Lorentz), []), (A(Lorentz), []), (A(Lorentz), []), (A(Lorentz), [])]
        >>> t4 = (A(m0)*A(m1)*A(-m0))._tids
        >>> t4.get_components_with_free_indices()
        [(A(Lorentz), []), (A(Lorentz), [(m1, 0, 1)]), (A(Lorentz), [])]
        >>> t5 = (A(m0)*A(m1)*A(m2))._tids
        >>> t5.get_components_with_free_indices()
        [(A(Lorentz), [(m0, 0, 0)]), (A(Lorentz), [(m1, 0, 1)]), (A(Lorentz), [(m2, 0, 2)])]

        """
        # Group by component number instead of assuming ``free`` is ordered
        # by component: ``from_components_and_indices`` sorts it by name.
        free_by_comp = defaultdict(list)
        for free_index in self.free:
            free_by_comp[free_index[2]].append(free_index)
        return [(comp, free_by_comp.get(i, []))
                for i, comp in enumerate(self.components)]

    @staticmethod
    def from_components_and_indices(components, indices):
        """
        Create a new ``TIDS`` object from ``components`` and ``indices``

        ``components``  ``TensorHead`` objects representing the components
                        of the tensor expression.

        ``indices``     ``TensorIndex`` objects, the indices. Contractions are
                        detected upon construction.

        Examples
        ========

        >>> Lorentz = TensorIndexType('Lorentz', dummy_fmt='L')
        >>> m0, m1, m2, m3 = tensor_indices('m0 m1 m2 m3', Lorentz)
        >>> T = tensorhead('T', [Lorentz]*4, [[1]*4])
        >>> TIDS.from_components_and_indices([T], [m0, m1, -m1, m3])
        TIDS([T(Lorentz,Lorentz,Lorentz,Lorentz)], [(m0, 0, 0), (m3, 3, 0)], [(1, 2, 0, 0)])

        In case of many components the same indices have slightly different
        indexes:

        >>> A = tensorhead('A', [Lorentz], [[1]])
        >>> TIDS.from_components_and_indices([A]*4, [m0, m1, -m1, m3])
        TIDS([A(Lorentz), A(Lorentz), A(Lorentz), A(Lorentz)], [(m0, 0, 0), (m3, 0, 3)], [(0, 0, 1, 2)])

        """
        tids = None
        cur_pos = 0
        for component in components:
            # each component consumes ``rank`` consecutive entries of ``indices``
            tids_sing = TIDS([component], *TIDS.free_dum_from_indices(*indices[cur_pos:cur_pos+component.rank]))
            if tids is None:
                tids = tids_sing
            else:
                tids *= tids_sing
            cur_pos += component.rank

        if tids is None:
            tids = TIDS([], [], [])

        tids.free.sort(key=lambda x: x[0].name)
        tids.dum.sort()

        return tids

    @staticmethod
    def free_dum_from_indices(*indices):
        """
        Convert ``indices`` into ``free``, ``dum`` for single component tensor

        ``free``     list of tuples ``(index, pos, 0)``,
                     where ``pos`` is the position of index in
                     the list of indices formed by the component tensors

        ``dum``      list of tuples ``(pos_contr, pos_cov, 0, 0)``

        Raises ``ValueError`` if an index name appears twice with the same
        variance, or more than twice in total.

        Examples
        ========

        >>> Lorentz = TensorIndexType('Lorentz', dummy_fmt='L')
        >>> m0, m1, m2, m3 = tensor_indices('m0 m1 m2 m3', Lorentz)
        >>> TIDS.free_dum_from_indices(m0, m1, -m1, m3)
        ([(m0, 0, 0), (m3, 3, 0)], [(1, 2, 0, 0)])

        """
        n = len(indices)
        if n == 1:
            return [(indices[0], 0, 0)], []

        # find the positions of the free indices and of the dummy indices
        free = [True]*n
        index_dict = {}  # (name, type) -> (is_up, slot) of the first occurrence
        paired = set()   # (name, type) keys that already form a dummy couple
        dum = []
        for i, index in enumerate(indices):
            key = (index._name, index._tensortype)
            contr = index._is_up
            if key not in index_dict:
                index_dict[key] = contr, i
                continue
            # found a pair of dummy indices
            is_contr, pos = index_dict[key]
            # check consistency: a contraction needs one upper and one lower index
            if bool(is_contr) == bool(contr):
                variance = 'contravariant' if contr else 'covariant'
                raise ValueError(f'two equal {variance} indices in slots {pos:d} and {i:d}')
            # a third occurrence would otherwise reuse slot ``pos`` in a
            # second dummy couple, corrupting ``dum``
            if key in paired:
                raise ValueError(f'index {index._name} appears more than twice (slot {i:d})')
            paired.add(key)
            free[pos] = False
            free[i] = False
            if contr:
                dum.append((i, pos, 0, 0))
            else:
                dum.append((pos, i, 0, 0))

        free = [(index, i, 0) for i, index in enumerate(indices) if free[i]]
        free.sort()
        return free, dum

    @staticmethod
    def _check_matrix_indices(f_free, g_free, nc1):
        """
        Detect contractions between the matrix indices of two factors.

        Matrix indices are special as there are only two, and observe
        anomalous substitution rules to determine contractions.

        ``f_free`` and ``g_free`` are the free-index lists of the left and
        right factor; ``nc1`` is the number of components of the left
        factor, used to shift the component numbers of ``g``. Both lists are
        sorted by (component, position) in place; contracted entries are
        removed from them, and a ``-auto_right`` entry of ``g_free`` may be
        renamed to ``auto_left``. Returns the list of new dummy 4-plets.

        """
        dum = []
        # make sure that free indices appear in the same order as in their component:
        f_free.sort(key=lambda x: (x[2], x[1]))
        g_free.sort(key=lambda x: (x[2], x[1]))
        matrix_indices_storage = {}
        transform_right_to_left = {}
        f_pop_pos = []
        g_pop_pos = []
        for free_pos, (ind, i, c) in enumerate(f_free):
            index_type = ind._tensortype
            if ind not in (index_type.auto_left, -index_type.auto_right):
                continue
            matrix_indices_storage[ind] = (free_pos, i, c)

        for free_pos, (ind, i, c) in enumerate(g_free):
            index_type = ind._tensortype
            if ind not in (index_type.auto_left, -index_type.auto_right):
                continue

            if ind == index_type.auto_left:
                if -index_type.auto_right in matrix_indices_storage:
                    other_pos, other_i, other_c = matrix_indices_storage.pop(-index_type.auto_right)
                    dum.append((other_i, i, other_c, c + nc1))
                    # mark to remove other_pos and free_pos from free:
                    g_pop_pos.append(free_pos)
                    f_pop_pos.append(other_pos)
                    continue
                if ind in matrix_indices_storage:
                    other_pos, other_i, other_c = matrix_indices_storage.pop(ind)
                    dum.append((other_i, i, other_c, c + nc1))
                    # mark to remove other_pos and free_pos from free:
                    g_pop_pos.append(free_pos)
                    f_pop_pos.append(other_pos)
                    transform_right_to_left[-index_type.auto_right] = c
                    continue

            if ind in transform_right_to_left:
                other_c = transform_right_to_left.pop(ind)
                if c == other_c:
                    g_free[free_pos] = (index_type.auto_left, i, c)

        # pop from the end so that earlier positions stay valid
        for i in sorted(f_pop_pos, reverse=True):
            f_free.pop(i)
        for i in sorted(g_pop_pos, reverse=True):
            g_free.pop(i)
        return dum

    @staticmethod
    def mul(f, g):
        """
        The algorithms performing the multiplication of two ``TIDS`` instances.

        In short, it forms a new ``TIDS`` object, joining components and indices,
        checking that abstract indices are compatible, and possibly contracting
        them.

        Examples
        ========

        >>> Lorentz = TensorIndexType('Lorentz', dummy_fmt='L')
        >>> m0, m1, m2, m3 = tensor_indices('m0 m1 m2 m3', Lorentz)
        >>> T = tensorhead('T', [Lorentz]*4, [[1]*4])
        >>> A = tensorhead('A', [Lorentz], [[1]])
        >>> tids_1 = TIDS.from_components_and_indices([T], [m0, m1, -m1, m3])
        >>> tids_2 = TIDS.from_components_and_indices([A], [m2])
        >>> tids_1 * tids_2
        TIDS([T(Lorentz,Lorentz,Lorentz,Lorentz), A(Lorentz)],
             [(m0, 0, 0), (m3, 3, 0), (m2, 0, 1)], [(1, 2, 0, 0)])

        In this case no contraction has been performed.

        >>> tids_3 = TIDS.from_components_and_indices([A], [-m3])
        >>> tids_1 * tids_3
        TIDS([T(Lorentz,Lorentz,Lorentz,Lorentz), A(Lorentz)],
             [(m0, 0, 0)], [(1, 2, 0, 0), (3, 0, 0, 1)])

        Free indices ``m3`` and ``-m3`` are identified as a contracted couple, and are
        therefore transformed into dummy indices.

        A wrong index construction (for example, trying to contract two
        contravariant indices or using indices multiple times) would result in
        an exception:

        >>> tids_4 = TIDS.from_components_and_indices([A], [m3])
        >>> # This raises an exception:
        >>> # tids_1 * tids_4

        """
        def index_up(u):
            return u if u.is_up else -u

        f_free = f.free[:]
        g_free = g.free[:]
        nc1 = len(f.components)
        dum = TIDS._check_matrix_indices(f_free, g_free, nc1)

        # find out which free indices of f and g are contracted
        free_dict1 = {i if i.is_up else -i: (pos, cpos, i) for i, pos, cpos in f_free}
        free_dict2 = {i if i.is_up else -i: (pos, cpos, i) for i, pos, cpos in g_free}
        free_names = set(free_dict1) & set(free_dict2)
        # find the new `free` and `dum`

        dum2 = [(i1, i2, c1 + nc1, c2 + nc1) for i1, i2, c1, c2 in g.dum]
        free1 = [(ind, i, c) for ind, i, c in f_free if index_up(ind) not in free_names]
        free2 = [(ind, i, c + nc1) for ind, i, c in g_free if index_up(ind) not in free_names]
        free = free1 + free2
        dum.extend(f.dum + dum2)
        for name in free_names:
            ipos1, cpos1, ind1 = free_dict1[name]
            ipos2, cpos2, ind2 = free_dict2[name]
            cpos2 += nc1
            if ind1._is_up == ind2._is_up:
                raise ValueError(f'wrong index construction {ind1}')
            if ind1._is_up:
                new_dummy = (ipos1, ipos2, cpos1, cpos2)
            else:
                new_dummy = (ipos2, ipos1, cpos2, cpos1)
            dum.append(new_dummy)
        return f.components + g.components, free, dum

    def __mul__(self, other):
        # ``mul`` returns the (components, free, dum) triple of the product
        return TIDS(*self.mul(self, other))

    def __str__(self):
        from ..printing import sstr
        return f'TIDS({sstr(self.components)}, {sstr(self.free)}, {sstr(self.dum)})'

    def sorted_components(self):
        """
        Returns a ``(TIDS, sign)`` pair whose ``TIDS`` has sorted components.

        The sorting is done taking into account the commutation group
        of the component tensors: two adjacent components are only swapped
        when ``commutes_with`` returns 0 or 1, and ``sign`` is flipped for
        every swap for which it returned 1.

        """
        from ..combinatorics.permutations import _af_invert
        cv = list(zip(self.components, range(len(self.components))))
        sign = 1
        n = len(cv) - 1
        for i in range(n):
            for j in range(n, i, -1):
                c = cv[j-1][0].commutes_with(cv[j][0])
                if c not in [0, 1]:
                    continue
                if (cv[j-1][0]._types, cv[j-1][0]._name) > \
                        (cv[j][0]._types, cv[j][0]._name):
                    cv[j-1], cv[j] = cv[j], cv[j-1]
                    if c:
                        sign = -sign

        # perm_inv[new_pos] = old_pos
        components = [x[0] for x in cv]
        perm_inv = [x[1] for x in cv]
        perm = _af_invert(perm_inv)
        free = [(ind, i, perm[c]) for ind, i, c in self.free]
        free.sort()
        dum = [(i1, i2, perm[c1], perm[c2]) for i1, i2, c1, c2 in self.dum]
        dum.sort(key=lambda x: components[x[2]].index_types[x[0]])

        return TIDS(components, free, dum), sign

    def _get_sorted_free_indices_for_canon(self):
        # free-index triplets ordered by their TensorIndex
        sorted_free = self.free[:]
        sorted_free.sort(key=lambda x: x[0])
        return sorted_free

    def _get_sorted_dum_indices_for_canon(self):
        # dummy 4-plets ordered by component number, then position, of the first index
        return sorted(self.dum, key=lambda x: (x[2], x[0]))

    def canon_args(self):
        """
        Returns ``(g, dummies, msym, v)``, the entries of ``canonicalize``

        see ``canonicalize`` in ``tensor_can.py``

        """
        # to be called after sorted_components
        from ..combinatorics.permutations import _af_new
        n = self._ext_rank
        # two trailing points (n, n+1) beyond the index slots; presumably the
        # sign points expected by ``canonicalize`` -- see tensor_can.py
        g = [None]*n + [n, n+1]
        pos = 0
        vpos = []
        components = self.components
        for t in components:
            vpos.append(pos)
            pos += t._rank
        # ordered indices: first the free indices, ordered by types
        # then the dummy indices, ordered by types and contravariant before
        # covariant
        # g[position in tensor] = position in ordered indices
        for i, (_, ipos, cpos) in enumerate(self._get_sorted_free_indices_for_canon()):
            pos = vpos[cpos] + ipos
            g[pos] = i
        pos = len(self.free)
        j = len(self.free)
        dummies = []
        prev = None
        a = []
        msym = []
        for ipos1, ipos2, cpos1, cpos2 in self._get_sorted_dum_indices_for_canon():
            pos1 = vpos[cpos1] + ipos1
            pos2 = vpos[cpos2] + ipos2
            g[pos1] = j
            g[pos2] = j + 1
            j += 2
            typ = components[cpos1].index_types[ipos1]
            if typ != prev:
                if a:
                    dummies.append(a)
                a = [pos, pos + 1]
                prev = typ
                msym.append(typ.metric_antisym)
            else:
                a.extend([pos, pos + 1])
            pos += 2
        if a:
            dummies.append(a)
        # run-length encode consecutive equal components: [[head, count], ...]
        numtyp = []
        prev = None
        for t in components:
            if t == prev:
                numtyp[-1][1] += 1
            else:
                prev = t
                numtyp.append([prev, 1])
        v = []
        for h, count in numtyp:
            if h._comm in (0, 1):
                comm = h._comm
            else:
                comm = TensorManager.get_comm(h._comm, h._comm)
            v.append((h._symmetry.base, h._symmetry.generators, count, comm))
        return _af_new(g), dummies, msym, v

    def perm2tensor(self, g, canon_bp=False):
        """
        Returns a ``TIDS`` instance corresponding to the permutation ``g``

        ``g``  permutation corresponding to the tensor in the representation
        used in canonicalization

        ``canon_bp``   if True, then ``g`` is the permutation
        corresponding to the canonical form of the tensor. This flag is
        accepted for interface compatibility; it is not used by this method.

        """
        vpos = []
        components = self.components
        pos = 0
        for t in components:
            vpos.append(pos)
            pos += t._rank
        sorted_free = [i[0] for i in self._get_sorted_free_indices_for_canon()]
        nfree = len(sorted_free)
        rank = self._ext_rank
        dum = [[None]*4 for i in range((rank - nfree)//2)]
        free = []
        icomp = -1
        for i in range(rank):
            if i in vpos:
                # components of rank 0 share their start slot with the next one
                icomp += vpos.count(i)
                pos0 = i
            ipos = i - pos0
            gi = g[i]
            if gi < nfree:
                ind = sorted_free[gi]
                free.append((ind, ipos, icomp))
            else:
                j = gi - nfree
                idum, cov = divmod(j, 2)
                if cov:
                    dum[idum][1] = ipos
                    dum[idum][3] = icomp
                else:
                    dum[idum][0] = ipos
                    dum[idum][2] = icomp
        dum = [tuple(x) for x in dum]

        return TIDS(components, free, dum)

    def get_indices(self):
        """
        Get a list of indices, creating new tensor indices to complete dummy indices.

        Each dummy couple of an index type gets a name built from the type's
        ``_dummy_fmt`` and a counter. For each type the counter starts above
        the number of any free index whose name already follows that format,
        so that generated dummy names do not clash with free index names.

        """
        components = self.components
        free = self.free
        dum = self.dum
        indices = [None]*self._ext_rank
        pos = 0
        vpos = []
        for t in components:
            vpos.append(pos)
            pos += t.rank
        cdt = defaultdict(int)
        # if the free indices have names with dummy_fmt, start with an
        # index higher than those for the dummy indices
        # to avoid name collisions
        for indx, ipos, cpos in free:
            typ = indx._tensortype
            # ``_dummy_fmt`` is expected to end with '_%d' (its last three
            # characters are stripped to get the stem). Names such as
            # '<stem>' or '<stem>_x' are not dummy-like and must not raise;
            # stems that themselves contain '_' must still be recognized.
            stem = typ._dummy_fmt[:-3] + '_'
            if indx._name.startswith(stem):
                number = indx._name[len(stem):].split('_')[0]
                if number.isdecimal():
                    cdt[typ] = max(cdt[typ], int(number) + 1)
            indices[vpos[cpos] + ipos] = indx
        for ipos1, ipos2, cpos1, cpos2 in dum:
            start1 = vpos[cpos1]
            start2 = vpos[cpos2]
            typ1 = components[cpos1].index_types[ipos1]
            assert typ1 == components[cpos2].index_types[ipos2]
            fmt = typ1._dummy_fmt
            nd = cdt[typ1]
            indices[start1 + ipos1] = TensorIndex(fmt % nd, typ1)
            indices[start2 + ipos2] = TensorIndex(fmt % nd, typ1, False)
            cdt[typ1] += 1
        return indices

    def contract_metric(self, g):
        """
        Returns new TIDS and sign.

        Every component equal to the metric ``g`` that takes part in at least
        one contraction is removed, and its indices are transferred to the
        tensors it is contracted with (or turned into a dimension factor for
        a trace ``g(i, -i)``). Metric components with no dummy index are kept.

        Sign is either 1 or -1, to correct the sign after metric contraction
        (for spinor indices); for traces it is also multiplied by the
        dimension of the index type.

        Raises ``ValueError`` if a trace is taken over an index type without
        an assigned dimension.

        """
        components = self.components
        antisym = g.index_types[0].metric_antisym
        # list of positions of the metric ``g``
        gpos = [i for i, x in enumerate(components) if x == g]
        if not gpos:
            return self, 1
        sign = 1
        dum = self.dum[:]
        free = self.free[:]
        elim = set()
        for gposx in gpos:
            if gposx in elim:
                continue
            free1 = [x for x in free if x[-1] == gposx]
            dum1 = [x for x in dum if gposx in (x[-2], x[-1])]
            if not dum1:
                continue
            elim.add(gposx)
            if len(dum1) == 2:
                if not antisym:
                    dum10, dum11 = dum1
                    if dum10[3] == gposx:
                        # the index with pos p0 and component c0 is contravariant
                        c0 = dum10[2]
                        p0 = dum10[0]
                    else:
                        # the index with pos p0 and component c0 is covariant
                        c0 = dum10[3]
                        p0 = dum10[1]
                    if dum11[3] == gposx:
                        # the index with pos p1 and component c1 is contravariant
                        c1 = dum11[2]
                        p1 = dum11[0]
                    else:
                        # the index with pos p1 and component c1 is covariant
                        c1 = dum11[3]
                        p1 = dum11[1]
                    dum.append((p0, p1, c0, c1))
                else:
                    dum10, dum11 = dum1
                    # change the sign to bring the indices of the metric to contravariant
                    # form; change the sign if dum10 has the metric index in position 0
                    if dum10[3] == gposx:
                        # the index with pos p0 and component c0 is contravariant
                        c0 = dum10[2]
                        p0 = dum10[0]
                        if dum10[1] == 1:
                            sign = -sign
                    else:
                        # the index with pos p0 and component c0 is covariant
                        c0 = dum10[3]
                        p0 = dum10[1]
                        if dum10[0] == 0:
                            sign = -sign
                    if dum11[3] == gposx:
                        # the index with pos p1 and component c1 is contravariant
                        c1 = dum11[2]
                        p1 = dum11[0]
                        sign = -sign
                    else:
                        # the index with pos p1 and component c1 is covariant
                        c1 = dum11[3]
                        p1 = dum11[1]
                    dum.append((p0, p1, c0, c1))

            elif len(dum1) == 1:
                if not antisym:
                    dp0, dp1, dc0, dc1 = dum1[0]
                    if dc0 == dc1:
                        # g(i, -i)
                        typ = g.index_types[0]
                        if typ._dim is None:
                            raise ValueError('dimension not assigned')
                        sign = sign*typ._dim

                    else:
                        # g(i0, i1)*p(-i1)
                        if dc0 == gposx:
                            p1 = dp1
                            c1 = dc1
                        else:
                            p1 = dp0
                            c1 = dc0
                        ind, _, c = free1[0]
                        free.append((ind, p1, c1))
                else:
                    dp0, dp1, dc0, dc1 = dum1[0]
                    if dc0 == dc1:
                        # g(i, -i)
                        typ = g.index_types[0]
                        if typ._dim is None:
                            raise ValueError('dimension not assigned')
                        sign = sign*typ._dim

                        if dp0 < dp1:
                            # g(i, -i) = -D with antisymmetric metric
                            sign = -sign
                    else:
                        # g(i0, i1)*p(-i1)
                        if dc0 == gposx:
                            p1 = dp1
                            c1 = dc1
                            if dp0 == 0:
                                sign = -sign
                        else:
                            p1 = dp0
                            c1 = dc0
                        ind, _, c = free1[0]
                        free.append((ind, p1, c1))
            dum = [x for x in dum if x not in dum1]
            free = [x for x in free if x not in free1]

        # renumber the surviving components after removing the eliminated metrics
        shift = 0
        shifts = [0]*len(components)
        for i in range(len(components)):
            if i in elim:
                shift += 1
                continue
            shifts[i] = shift
        free = [(ind, p, c - shifts[c]) for (ind, p, c) in free if c not in elim]
        dum = [(p0, p1, c0 - shifts[c0], c1 - shifts[c1]) for (p0, p1, c0, c1) in dum if c0 not in elim and c1 not in elim]
        components = [c for i, c in enumerate(components) if i not in elim]
        tids = TIDS(components, free, dum)
        return tids, sign
737
738
739class _TensorDataLazyEvaluator(CantSympify):
740    """
741    EXPERIMENTAL: do not rely on this class, it may change without deprecation
742    warnings in future versions of Diofant.
743
744    This object contains the logic to associate components data to a tensor
745    expression. Components data are set via the ``.data`` property of tensor
746    expressions, is stored inside this class as a mapping between the tensor
747    expression and the ``ndarray``.
748
749    Computations are executed lazily: whereas the tensor expressions can have
750    contractions, tensor products, and additions, components data are not
751    computed until they are accessed by reading the ``.data`` property
752    associated to the tensor expression.
753
754    """
755
756    _substitutions_dict: dict[typing.Any, typing.Any] = {}
757    _substitutions_dict_tensmul: dict[typing.Any, typing.Any] = {}
758
759    def __getitem__(self, key):
760        dat = self._get(key)
761        if dat is None:
762            return
763
764        numpy = import_module('numpy')
765        if not isinstance(dat, numpy.ndarray):
766            return dat
767
768        if dat.ndim == 0:
769            return dat[()]
770        elif dat.ndim == 1 and dat.size == 1:
771            return dat[0]
772        return dat
773
    def _get(self, key):
        """
        Retrieve ``data`` associated with ``key``.

        This algorithm looks into ``self._substitutions_dict`` for all
        ``TensorHead`` in the ``TensExpr`` (or just ``TensorHead`` if key is a
        TensorHead instance). It reconstructs the components data that the
        tensor expression should have by performing on components data the
        operations that correspond to the abstract tensor operations applied.

        Metric tensor is handled in a different manner: it is pre-computed in
        ``self._substitutions_dict_tensmul``.

        Returns ``None`` when no components data is available: a
        ``TensorHead`` without stored data, an expression none of whose
        factors/terms has data, or a key of a type not handled below (the
        function then simply falls off the end).

        Raises ``ValueError`` when only some of the factors of a ``TensMul``
        (or some of the terms of a ``TensAdd``) have components data.

        """
        # Data stored directly under this key takes precedence.
        if key in self._substitutions_dict:
            return self._substitutions_dict[key]

        # A bare TensorHead without stored data has no components.
        if isinstance(key, TensorHead):
            return

        if isinstance(key, Tensor):
            # special case to handle metrics. Metric tensors cannot be
            # constructed through contraction by the metric, their
            # components show if they are a matrix or its inverse.
            signature = tuple(i.is_up for i in key.get_indices())
            srch = (key.component,) + signature
            if srch in self._substitutions_dict_tensmul:
                return self._substitutions_dict_tensmul[srch]
            # Otherwise adjust the head's data to this tensor's index
            # positions (lower covariant indices, trace dummies).
            return self.data_tensmul_from_tensorhead(key, key.component)

        if isinstance(key, TensMul):
            # NOTE(review): ``split()`` presumably yields one TensMul per
            # component (only ``components[0]`` of each is used below).
            tensmul_list = key.split()
            if len(tensmul_list) == 1 and len(tensmul_list[0].components) == 1:
                # special case to handle metrics. Metric tensors cannot be
                # constructed through contraction by the metric, their
                # components show if they are a matrix or its inverse.
                signature = tuple(i.is_up for i in tensmul_list[0].get_indices())
                srch = (tensmul_list[0].components[0],) + signature
                if srch in self._substitutions_dict_tensmul:
                    return self._substitutions_dict_tensmul[srch]
            data_list = [self.data_tensmul_from_tensorhead(i, i.components[0]) for i in tensmul_list]
            if all(i is None for i in data_list):
                return
            if any(i is None for i in data_list):
                raise ValueError('Mixing tensors with associated components '
                                 'data with tensors without components data')
            # Multiply the factors' data, contracting shared indices.
            data_result, _ = self.data_product_tensors(data_list, tensmul_list)
            return data_result

        if isinstance(key, TensAdd):
            sumvar = Integer(0)
            data_list = []
            free_args_list = []
            # Collect each term's data and its free indices; non-tensor terms
            # are taken as they are and have no free indices.
            for arg in key.args:
                if isinstance(arg, TensExpr):
                    data_list.append(arg.data)
                    free_args_list.append([x[0] for x in arg.free])
                else:
                    data_list.append(arg)
                    free_args_list.append([])
            if all(i is None for i in data_list):
                return
            if any(i is None for i in data_list):
                raise ValueError('Mixing tensors with associated components '
                                 'data with tensors without components data')

            numpy = import_module('numpy')
            for data, free_args in zip(data_list, free_args_list):
                if len(free_args) < 2:
                    sumvar += data
                else:
                    # Transpose the term so that its axes follow the order of
                    # the free indices of the whole sum (``key.free_args``).
                    free_args_pos = {y: x for x, y in enumerate(free_args)}
                    axes = [free_args_pos[arg] for arg in key.free_args]
                    sumvar += numpy.transpose(data, axes)
            return sumvar
849
850    def data_tensorhead_from_tensmul(self, data, tensmul, tensorhead):
851        """
852        This method is used when assigning components data to a ``TensMul``
853        object, it converts components data to a fully contravariant ndarray,
854        which is then stored according to the ``TensorHead`` key.
855
856        """
857        if data is not None:
858            return self._correct_signature_from_indices(
859                data,
860                tensmul.get_indices(),
861                tensmul.free,
862                tensmul.dum,
863                True)
864
865    def data_tensmul_from_tensorhead(self, tensmul, tensorhead):
866        """
867        This method corrects the components data to the right signature
868        (covariant/contravariant) using the metric associated with each
869        ``TensorIndexType``.
870
871        """
872        if tensorhead.data is not None:
873            return self._correct_signature_from_indices(
874                tensorhead.data,
875                tensmul.get_indices(),
876                tensmul.free,
877                tensmul.dum)
878
879    def data_product_tensors(self, data_list, tensmul_list):
880        """
881        Given a ``data_list``, list of ``ndarray``'s and a ``tensmul_list``,
882        list of ``TensMul`` instances, compute the resulting ``ndarray``,
883        after tensor products and contractions.
884
885        """
886        def data_mul(f, g):
887            """
888            Multiplies two ``ndarray`` objects, it first calls ``TIDS.mul``,
889            then checks which indices have been contracted, and finally
890            contraction operation on data, according to the contracted indices.
891
892            """
893            data1, tensmul1 = f
894            data2, tensmul2 = g
895            components, free, dum = TIDS.mul(tensmul1, tensmul2)
896            data = _TensorDataLazyEvaluator._contract_ndarray(tensmul1.free, tensmul2.free, data1, data2)
897            # TODO: do this more efficiently... maybe by just passing an index list
898            # to .data_product_tensor(...)
899            return data, TensMul.from_TIDS(Integer(1), TIDS(components, free, dum))
900
901        return functools.reduce(data_mul, zip(data_list, tensmul_list))
902
903    def _assign_data_to_tensor_expr(self, key, data):
904        if isinstance(key, TensAdd):
905            raise ValueError('cannot assign data to TensAdd')
906        # here it is assumed that `key` is a `TensMul` instance.
907        if len(key.components) != 1:
908            raise ValueError('cannot assign data to TensMul with multiple components')
909        tensorhead = key.components[0]
910        newdata = self.data_tensorhead_from_tensmul(data, key, tensorhead)
911        return tensorhead, newdata
912
913    def _check_permutations_on_data(self, tens, data):
914        import numpy
915
916        if isinstance(tens, TensorHead):
917            rank = tens.rank
918            generators = tens.symmetry.generators
919        elif isinstance(tens, Tensor):
920            rank = tens.rank
921            generators = tens.components[0].symmetry.generators
922        elif isinstance(tens, TensorIndexType):
923            rank = tens.metric.rank
924            generators = tens.metric.symmetry.generators
925
926        # Every generator is a permutation, check that by permuting the array
927        # by that permutation, the array will be the same, except for a
928        # possible sign change if the permutation admits it.
929        for gener in generators:
930            sign_change = +1 if (gener(rank) == rank) else -1
931            data_swapped = data
932            last_data = data
933            permute_axes = list(map(gener, range(rank)))
934            # the order of a permutation is the number of times to get the
935            # identity by applying that permutation.
936            for _ in range(gener.order()-1):
937                data_swapped = numpy.transpose(data_swapped, permute_axes)
938                # if any value in the difference array is non-zero, raise an error:
939                if (last_data - sign_change*data_swapped).any():
940                    raise ValueError('Component data symmetry structure error')
941                last_data = data_swapped
942
943    def __setitem__(self, key, value):
944        """
945        Set the components data of a tensor object/expression.
946
947        Components data are transformed to the all-contravariant form and stored
948        with the corresponding ``TensorHead`` object. If a ``TensorHead`` object
949        cannot be uniquely identified, it will raise an error.
950
951        """
952        data = _TensorDataLazyEvaluator.parse_data(value)
953        self._check_permutations_on_data(key, data)
954
955        # TensorHead and TensorIndexType can be assigned data directly, while
956        # TensMul must first convert data to a fully contravariant form, and
957        # assign it to its corresponding TensorHead single component.
958        if not isinstance(key, (TensorHead, TensorIndexType)):
959            key, data = self._assign_data_to_tensor_expr(key, data)
960
961        if isinstance(key, TensorHead):
962            for dim, indextype in zip(data.shape, key.index_types):
963                if indextype.data is None:
964                    raise ValueError(f'index type {indextype} has no components data'
965                                     ' associated (needed to raise/lower index')
966                if indextype.dim is None:
967                    continue
968                if dim != indextype.dim:
969                    raise ValueError('wrong dimension of ndarray')
970        self._substitutions_dict[key] = data
971
972    def __delitem__(self, key):
973        del self._substitutions_dict[key]
974
975    def __contains__(self, key):
976        return key in self._substitutions_dict
977
978    @staticmethod
979    def _contract_ndarray(free1, free2, ndarray1, ndarray2):
980        numpy = import_module('numpy')
981
982        def ikey(x):
983            return x[2], x[1]
984
985        free1 = free1[:]
986        free2 = free2[:]
987        free1.sort(key=ikey)
988        free2.sort(key=ikey)
989        self_free = [_[0] for _ in free1]
990        axes1 = []
991        axes2 = []
992        for jpos, jindex in enumerate(free2):
993            if -jindex[0] in self_free:
994                nidx = self_free.index(-jindex[0])
995            else:
996                continue
997            axes1.append(nidx)
998            axes2.append(jpos)
999
1000        contracted_ndarray = numpy.tensordot(
1001            ndarray1,
1002            ndarray2,
1003            (axes1, axes2)
1004        )
1005        return contracted_ndarray
1006
1007    def add_metric_data(self, metric, data):
1008        """
1009        Assign data to the ``metric`` tensor. The metric tensor behaves in an
1010        anomalous way when raising and lowering indices.
1011
1012        A fully covariant metric is the inverse transpose of the fully
1013        contravariant metric (it is meant matrix inverse). If the metric is
1014        symmetric, the transpose is not necessary and mixed
1015        covariant/contravariant metrics are Kronecker deltas.
1016
1017        """
1018        # hard assignment, data should not be added to `TensorHead` for metric:
1019        # the problem with `TensorHead` is that the metric is anomalous, i.e.
1020        # raising and lowering the index means considering the metric or its
1021        # inverse, this is not the case for other tensors.
1022        self._substitutions_dict_tensmul[metric, True, True] = data
1023        inverse_transpose = self.inverse_transpose_matrix(data)
1024        # in symmetric spaces, the traspose is the same as the original matrix,
1025        # the full covariant metric tensor is the inverse transpose, so this
1026        # code will be able to handle non-symmetric metrics.
1027        self._substitutions_dict_tensmul[metric, False, False] = inverse_transpose
1028        # now mixed cases, these are identical to the unit matrix if the metric
1029        # is symmetric.
1030        m = Matrix(data)
1031        invt = Matrix(inverse_transpose)
1032        self._substitutions_dict_tensmul[metric, True, False] = m * invt
1033        self._substitutions_dict_tensmul[metric, False, True] = invt * m
1034
1035    @staticmethod
1036    def _flip_index_by_metric(data, metric, pos):
1037        numpy = import_module('numpy')
1038
1039        data = numpy.tensordot(
1040            metric,
1041            data,
1042            (1, pos))
1043        return numpy.rollaxis(data, 0, pos+1)
1044
1045    @staticmethod
1046    def inverse_matrix(ndarray):
1047        m = Matrix(ndarray).inv()
1048        return _TensorDataLazyEvaluator.parse_data(m)
1049
1050    @staticmethod
1051    def inverse_transpose_matrix(ndarray):
1052        m = Matrix(ndarray).inv().T
1053        return _TensorDataLazyEvaluator.parse_data(m)
1054
1055    @staticmethod
1056    def _correct_signature_from_indices(data, indices, free, dum, inverse=False):
1057        """
1058        Utility function to correct the values inside the components data
1059        ndarray according to whether indices are covariant or contravariant.
1060
1061        It uses the metric matrix to lower values of covariant indices.
1062
1063        """
1064        numpy = import_module('numpy')
1065        # change the ndarray values according covariantness/contravariantness of the indices
1066        # use the metric
1067        for i, indx in enumerate(indices):
1068            if not indx.is_up and not inverse:
1069                data = _TensorDataLazyEvaluator._flip_index_by_metric(data, indx._tensortype.data, i)
1070            elif not indx.is_up and inverse:
1071                data = _TensorDataLazyEvaluator._flip_index_by_metric(
1072                    data,
1073                    _TensorDataLazyEvaluator.inverse_matrix(indx._tensortype.data),
1074                    i
1075                )
1076
1077        if len(dum) > 0:
1078            # perform contractions
1079            axes1 = []
1080            axes2 = []
1081            for i, indx1 in enumerate(indices):
1082                try:
1083                    nd = indices[:i].index(-indx1)
1084                except ValueError:
1085                    continue
1086                axes1.append(nd)
1087                axes2.append(i)
1088
1089            for ax1, ax2 in zip(axes1, axes2):
1090                data = numpy.trace(data, axis1=ax1, axis2=ax2)
1091        return data
1092
1093    @staticmethod
1094    @doctest_depends_on(modules=('numpy',))
1095    def parse_data(data):
1096        """
1097        Transform ``data`` to a numpy ndarray. The parameter ``data`` may
1098        contain data in various formats, e.g. nested lists, diofant ``Matrix``,
1099        and so on.
1100
1101        Examples
1102        ========
1103
1104        >>> print(str(_TensorDataLazyEvaluator.parse_data([1, 3, -6, 12])))
1105        [1 3 -6 12]
1106
1107        >>> print(str(_TensorDataLazyEvaluator.parse_data([[1, 2], [4, 7]])))
1108        [[1 2]
1109         [4 7]]
1110
1111        """
1112        numpy = import_module('numpy')
1113
1114        if (numpy is not None) and (not isinstance(data, numpy.ndarray)):
1115            vsympify = numpy.vectorize(sympify)
1116            data = vsympify(numpy.array(data))
1117        return data
1118
1119
# Module-wide store of tensor components data; ``data`` properties such as
# ``TensorIndexType.data`` read from and write to it.
_tensor_data_substitution_dict = _TensorDataLazyEvaluator()
1121
1122
class _TensorManager:
    """
    Class to manage tensor properties.

    Notes
    =====

    Tensors belong to tensor commutation groups; each group has a label
    ``comm``; there are predefined labels:

    ``0``   tensors commuting with any other tensor

    ``1``   tensors anticommuting among themselves

    ``2``   tensors not commuting, apart from those with ``comm=0``

    Other groups can be defined using ``set_comm``; tensors in those
    groups commute with those with ``comm=0``; by default they
    do not commute with any other group.

    """

    def __init__(self):
        self._comm_init()

    def _comm_init(self):
        """Reset the commutation table to the three predefined groups."""
        # ``_comm[i][j]`` is the commutation parameter of groups ``i`` and
        # ``j``: 0 commuting, 1 anticommuting, None no commutation property.
        self._comm = [{} for i in range(3)]
        for i in range(3):
            self._comm[0][i] = 0
            self._comm[i][0] = 0
        self._comm[1][1] = 1
        self._comm[2][1] = None
        self._comm[1][2] = None
        self._comm_symbols2i = {0: 0, 1: 1, 2: 2}
        self._comm_i2symbol = {0: 0, 1: 1, 2: 2}

    @property
    def comm(self):
        """The commutation table, indexed by commutation group numbers."""
        return self._comm

    def comm_symbols2i(self, i):
        """
        Get the commutation group number corresponding to ``i``

        ``i`` can be a symbol or a number or a string

        If ``i`` is not already defined its commutation group number
        is set; a new group commutes only with the group ``0``.

        """
        if i in self._comm_symbols2i:
            return self._comm_symbols2i[i]
        n = len(self._comm)
        self._comm.append({0: 0})
        self._comm[0][n] = 0
        self._comm_symbols2i[i] = n
        self._comm_i2symbol[n] = i
        return n

    def comm_i2symbol(self, i):
        """Returns the symbol corresponding to the commutation group number."""
        return self._comm_i2symbol[i]

    def set_comm(self, i, j, c):
        """
        Set the commutation parameter ``c`` for commutation groups ``i, j``

        Parameters
        ==========

        i, j : symbols representing commutation groups

        c  :  group commutation number

        Notes
        =====

        ``i, j`` can be symbols, strings or numbers,
        apart from ``0, 1`` and ``2`` which are reserved respectively
        for commuting, anticommuting tensors and tensors not commuting
        with any other group apart from the commuting tensors.
        For the remaining cases, use this method to set the commutation rules;
        by default ``c=None``.

        The group commutation number ``c`` is assigned in correspondence
        to the group commutation symbols; it can be

        0        commuting

        1        anticommuting

        None     no commutation property

        Groups not yet known are registered as by ``comm_symbols2i``.
        Raises ``ValueError`` (before registering anything) for any other
        value of ``c``.

        Examples
        ========

        ``G`` and ``GH`` do not commute with themselves and commute with
        each other; A is commuting.

        >>> Lorentz = TensorIndexType('Lorentz')
        >>> i0, i1, i2, i3, i4 = tensor_indices('i0:5', Lorentz)
        >>> A = tensorhead('A', [Lorentz], [[1]])
        >>> G = tensorhead('G', [Lorentz], [[1]], 'Gcomm')
        >>> GH = tensorhead('GH', [Lorentz], [[1]], 'GHcomm')
        >>> TensorManager.set_comm('Gcomm', 'GHcomm', 0)
        >>> (GH(i1)*G(i0)).canon_bp()
        G(i0)*GH(i1)
        >>> (G(i1)*G(i0)).canon_bp()
        G(i1)*G(i0)
        >>> (G(i1)*A(i0)).canon_bp()
        A(i0)*G(i1)

        """
        if c not in (0, 1, None):
            raise ValueError('`c` can assume only the values 0, 1 or None')

        # Reuse the registration logic instead of duplicating it here.
        ni = self.comm_symbols2i(i)
        nj = self.comm_symbols2i(j)
        self._comm[ni][nj] = c
        self._comm[nj][ni] = c

    def set_comms(self, *args):
        """
        Set the commutation group numbers ``c`` for symbols ``i, j``

        Parameters
        ==========

        args : sequence of ``(i, j, c)``

        """
        for i, j, c in args:
            self.set_comm(i, j, c)

    def get_comm(self, i, j):
        """
        Return the commutation parameter for commutation group numbers ``i, j``

        Groups without an explicit entry commute with group ``0`` and have no
        commutation property (``None``) otherwise.

        see ``_TensorManager.set_comm``

        """
        return self._comm[i].get(j, 0 if i == 0 or j == 0 else None)

    def clear(self):
        """Clear the TensorManager."""
        self._comm_init()
1284
1285
# Global commutation-group manager; e.g. ``TensorManager.set_comm`` declares
# how tensors of user-defined commutation groups commute with each other.
TensorManager = _TensorManager()
1287
1288
@doctest_depends_on(modules=('numpy',))
class TensorIndexType(Basic):
    """
    A TensorIndexType is characterized by its name and its metric.

    Parameters
    ==========

    name : name of the tensor type

    metric : metric symmetry or metric object or ``None``

    dim : dimension, it can be a symbol or an integer or ``None``

    eps_dim : dimension of the epsilon tensor

    dummy_fmt : name of the head of dummy indices

    Attributes
    ==========

    ``name``
    ``metric_antisym``
        the antisymmetry flag of the metric (``None`` if there is no metric)
    ``metric``
        the metric tensor, or ``None`` if there is no metric
    ``delta`` : ``Kronecker delta``
    ``epsilon`` : the ``Levi-Civita epsilon`` tensor, or ``None`` when
        ``eps_dim`` is not an integer
    ``dim``
    ``eps_dim``
    ``data`` : a property to add ``ndarray`` values, to work in a specified basis.

    Notes
    =====

    The ``metric`` parameter can be:
    ``metric = False`` symmetric metric (in Riemannian geometry)

    ``metric = True`` antisymmetric metric (for spinor calculus)

    ``metric = None``  there is no metric

    ``metric`` can be an object having ``name`` and ``antisym`` attributes.


    If there is a metric the metric is used to raise and lower indices.

    In the case of antisymmetric metric, the following raising and
    lowering conventions will be adopted:

    ``psi(a) = g(a, b)*psi(-b); chi(-a) = chi(b)*g(-b, -a)``

    ``g(-a, b) = delta(-a, b); g(b, -a) = -delta(a, -b)``

    where ``delta(-a, b) = delta(b, -a)`` is the ``Kronecker delta``
    (see ``TensorIndex`` for the conventions on indices).

    If there is no metric it is not possible to raise or lower indices;
    e.g. the index of the defining representation of ``SU(N)``
    is 'covariant' and the conjugate representation is
    'contravariant'; for ``N > 2`` they are linearly independent.

    ``eps_dim`` is by default equal to ``dim``, if the latter is an integer;
    else it can be assigned (for use in naive dimensional regularization);
    if ``eps_dim`` is not an integer ``epsilon`` is ``None``.

    Names of dummy indices are built from ``dummy_fmt`` (by default the
    name of the type) followed by ``_`` and a number.

    Examples
    ========

    >>> Lorentz = TensorIndexType('Lorentz', dummy_fmt='L')
    >>> Lorentz.metric
    metric(Lorentz,Lorentz)

    Examples with metric components data added, this means it is working on a
    fixed basis:

    >>> Lorentz.data = [1, -1, -1, -1]
    >>> print(sstr(Lorentz))
    TensorIndexType(Lorentz, 0)
    >>> print(str(Lorentz.data))
    [[1 0 0 0]
    [0 -1 0 0]
    [0 0 -1 0]
    [0 0 0 -1]]

    """

    def __new__(cls, name, metric=False, dim=None, eps_dim=None,
                dummy_fmt=None):

        if isinstance(name, str):
            name = Symbol(name)
        obj = Basic.__new__(cls, name, Integer(1) if metric else Integer(0))
        obj._name = str(name)
        if not dummy_fmt:
            obj._dummy_fmt = f'{obj.name}_%d'
        else:
            obj._dummy_fmt = f'{dummy_fmt}_%d'
        if metric is None:
            obj.metric_antisym = None
            obj.metric = None
        else:
            if metric in (True, False, 0, 1):
                metric_name = 'metric'
                obj.metric_antisym = metric
            else:
                metric_name = metric.name
                obj.metric_antisym = metric.antisym
            # the metric is a rank-2 tensor, (anti)symmetric in its two slots
            sym2 = TensorSymmetry(get_symmetric_group_sgs(2, obj.metric_antisym))
            S2 = TensorType([obj]*2, sym2)
            obj.metric = S2(metric_name)
            obj.metric._matrix_behavior = True

        obj._dim = dim
        obj._delta = obj.get_kronecker_delta()
        obj._eps_dim = eps_dim if eps_dim else dim
        obj._epsilon = obj.get_epsilon()
        obj._autogenerated = []
        return obj

    @property
    def auto_right(self):
        if not hasattr(self, '_auto_right'):
            self._auto_right = TensorIndex('auto_right', self)
        return self._auto_right

    @property
    def auto_left(self):
        if not hasattr(self, '_auto_left'):
            self._auto_left = TensorIndex('auto_left', self)
        return self._auto_left

    @property
    def data(self):
        return _tensor_data_substitution_dict[self]

    @data.setter
    def data(self, data):
        # This assignment is a bit controversial, should metric components be assigned
        # to the metric only or also to the TensorIndexType object? The advantage here
        # is the ability to assign a 1D array and transform it to a 2D diagonal array.
        numpy = import_module('numpy')
        data = _TensorDataLazyEvaluator.parse_data(data)
        # Reject scalars as well as higher ranks with an explicit message
        # (a 0-d array used to fail on the shape unpacking below).
        if data.ndim not in (1, 2):
            raise ValueError('data have to be of rank 1 (diagonal metric) or 2.')
        if data.ndim == 1:
            # a 1D array is the diagonal of a diagonal metric
            dim = data.shape[0]
            newndarray = numpy.zeros((dim, dim), dtype=object)
            for i, val in enumerate(data):
                newndarray[i, i] = val
            data = newndarray
        dim1, dim2 = data.shape
        if dim1 != dim2:
            raise ValueError('Non-square matrix tensor.')
        if self.dim is not None and self.dim != dim1:
            raise ValueError('Dimension mismatch')
        _tensor_data_substitution_dict[self] = data
        _tensor_data_substitution_dict.add_metric_data(self.metric, data)
        # the mixed Kronecker delta has the identity as components
        delta = self.get_kronecker_delta()
        i1 = TensorIndex('i1', self)
        i2 = TensorIndex('i2', self)
        delta(i1, -i2).data = _TensorDataLazyEvaluator.parse_data(eye(dim1))

    @data.deleter
    def data(self):
        if self in _tensor_data_substitution_dict:
            del _tensor_data_substitution_dict[self]

    @property
    def name(self):
        return self._name

    @property
    def dim(self):
        return self._dim

    @property
    def delta(self):
        return self._delta

    @property
    def eps_dim(self):
        return self._eps_dim

    @property
    def epsilon(self):
        return self._epsilon

    def get_kronecker_delta(self):
        """Return a new symmetric rank-2 ``KD`` tensor of this index type."""
        sym2 = TensorSymmetry(get_symmetric_group_sgs(2))
        S2 = TensorType([self]*2, sym2)
        delta = S2('KD')
        delta._matrix_behavior = True
        return delta

    def get_epsilon(self):
        """
        Return the fully antisymmetric ``Eps`` tensor of rank ``eps_dim``,
        or ``None`` when ``eps_dim`` is not an ``int``.
        """
        if not isinstance(self._eps_dim, int):
            return
        sym = TensorSymmetry(get_symmetric_group_sgs(self._eps_dim, 1))
        Sdim = TensorType([self]*self._eps_dim, sym)
        epsilon = Sdim('Eps')
        return epsilon

    def __lt__(self, other):
        return self.name < other.name

    def __str__(self):
        return self.name

    __repr__ = __str__
1508
1509
@doctest_depends_on(modules=('numpy',))
class TensorIndex(Basic):
    """
    Abstract tensor index.

    Parameters
    ==========

    name : name of the index (string or ``Symbol``), or ``True`` to have a
        name generated automatically
    tensortype : ``TensorIndexType`` of the index
    is_up :  flag for contravariant index

    Attributes
    ==========

    ``name``
    ``tensortype``
    ``is_up``

    Notes
    =====

    Indices are contracted following the Einstein summation convention.

    An index is either contravariant or covariant; a covariant index is
    printed with a ``-`` in front of its name.

    Dummy indices have a name with head given by ``tensortype._dummy_fmt``

    Raises ``ValueError`` if ``name`` is of any other kind.

    Examples
    ========

    >>> Lorentz = TensorIndexType('Lorentz', dummy_fmt='L')
    >>> i = TensorIndex('i', Lorentz)
    >>> i
    i
    >>> sym1 = TensorSymmetry(*get_symmetric_group_sgs(1))
    >>> S1 = TensorType([Lorentz], sym1)
    >>> A, B = S1('A B')
    >>> A(i)*B(-i)
    A(L_0)*B(-L_0)

    Passing ``True`` as ``name`` generates the name automatically; it
    starts with the reserved character ``_`` so that it cannot clash with
    existing indices:

    >>> i0 = TensorIndex(True, Lorentz)
    >>> i0
    _i0
    >>> i1 = TensorIndex(True, Lorentz)
    >>> i1
    _i1
    >>> A(i0)*B(-i1)
    A(_i0)*B(-_i1)
    >>> A(i0)*B(-i0)
    A(L_0)*B(-L_0)

    """

    def __new__(cls, name, tensortype, is_up=True):
        if name is True:
            # numbered after the names already generated for this type
            name = f'_i{len(tensortype._autogenerated)}'
            name_symbol = Symbol(name)
            tensortype._autogenerated.append(name_symbol)
        elif isinstance(name, Symbol):
            name_symbol = name
        elif isinstance(name, str):
            name_symbol = Symbol(name)
        else:
            raise ValueError('invalid name')

        up_flag = Integer(1) if is_up else Integer(0)
        obj = Basic.__new__(cls, name_symbol, tensortype, up_flag)
        obj._name = str(name)
        obj._tensortype = tensortype
        obj._is_up = is_up
        return obj

    @property
    def name(self):
        return self._name

    @property
    def tensortype(self):
        return self._tensortype

    @property
    def is_up(self):
        return self._is_up

    def _print(self):
        return self._name if self._is_up else f'-{self._name}'

    def __lt__(self, other):
        mine = (self._tensortype, self._name)
        theirs = (other._tensortype, other._name)
        return mine < theirs

    def __neg__(self):
        return TensorIndex(self._name, self._tensortype, not self._is_up)
1613
1614
def tensor_indices(s, typ):
    """
    Return tensor indices of a given type, one for each name in ``s``.

    Parameters
    ==========

    s : string of index names, separated by commas or spaces; ranges such
        as ``'i0:5'`` are expanded as by ``symbols``

    typ : the ``TensorIndexType`` shared by all the indices

    A single ``TensorIndex`` is returned when ``s`` names exactly one
    index, otherwise a list of them.  Raises ``ValueError`` if ``s`` is
    not a string.

    Examples
    ========

    >>> Lorentz = TensorIndexType('Lorentz', dummy_fmt='L')
    >>> a, b, c, d = tensor_indices('a b c d', Lorentz)

    """
    if not isinstance(s, str):
        raise ValueError('expecting a string')

    names = [sym.name for sym in symbols(s, seq=True)]
    indices = [TensorIndex(name, typ) for name in names]
    return indices[0] if len(indices) == 1 else indices
1642
1643
@doctest_depends_on(modules=('numpy',))
class TensorSymmetry(Basic):
    """
    Monoterm symmetry of a tensor

    Parameters
    ==========

    bsgs : tuple ``(base, sgs)`` BSGS of the symmetry of the tensor, given
        either as one tuple or as two separate arguments

    Attributes
    ==========

    ``base`` : Tuple
        base of the BSGS
    ``generators`` : Tuple
        generators of the BSGS
    ``rank`` : int
        rank of the tensor; the generators act on ``rank + 2`` points, so
        it is read off the size of the first generator

    Notes
    =====

    Any monoterm symmetry of a tensor can be described by its BSGS.
    Multiterm symmetries, like the cyclic symmetry of the Riemann tensor,
    are not covered.

    Raises ``TypeError`` unless exactly one or two positional arguments
    are given.

    See Also
    ========

    diofant.combinatorics.tensor_can.get_symmetric_group_sgs

    Examples
    ========

    Define a symmetric tensor

    >>> Lorentz = TensorIndexType('Lorentz', dummy_fmt='L')
    >>> sym2 = TensorSymmetry(get_symmetric_group_sgs(2))
    >>> S2 = TensorType([Lorentz]*2, sym2)
    >>> V = S2('V')

    """

    def __new__(cls, *args, **kw_args):
        if len(args) not in (1, 2):
            raise TypeError('bsgs required, either two separate parameters or one tuple')
        base, generators = args[0] if len(args) == 1 else args

        base = base if isinstance(base, Tuple) else Tuple(*base)
        generators = generators if isinstance(generators, Tuple) else Tuple(*generators)
        return Basic.__new__(cls, base, generators, **kw_args)

    @property
    def base(self):
        return self.args[0]

    @property
    def generators(self):
        return self.args[1]

    @property
    def rank(self):
        first_generator = self.generators[0]
        return first_generator.size - 2
1714
1715
def tensorsymmetry(*args):
    """
    Return a ``TensorSymmetry`` object.

    One can represent a tensor with any monoterm slot symmetry group
    using a BSGS.

    ``args`` can be a BSGS
    ``args[0]``    base
    ``args[1]``    sgs

    Usually tensors are in (direct products of) representations
    of the symmetric group;
    ``args`` can be a list of lists representing the shapes of Young tableaux
    (each shape may be given as a list or a tuple).

    With no arguments the symmetry of a rank-0 tensor is returned.

    Notes
    =====

    For instance:
    ``[[1]]``       vector
    ``[[1]*n]``     symmetric tensor of rank ``n``
    ``[[n]]``       antisymmetric tensor of rank ``n``
    ``[[2, 2]]``    monoterm slot symmetry of the Riemann tensor
    ``[[1],[1]]``   vector*vector
    ``[[2],[1],[1]]`` (antisymmetric tensor)*vector*vector

    Notice that with the shape ``[2, 2]`` we associate only the monoterm
    symmetries of the Riemann tensor; this is an abuse of notation,
    since the shape ``[2, 2]`` corresponds usually to the irreducible
    representation characterized by the monoterm symmetries and by the
    cyclic symmetry.

    Any other shape raises ``NotImplementedError``.

    Examples
    ========

    Symmetric tensor using a Young tableau

    >>> Lorentz = TensorIndexType('Lorentz', dummy_fmt='L')
    >>> sym2 = tensorsymmetry([1, 1])
    >>> S2 = TensorType([Lorentz]*2, sym2)
    >>> V = S2('V')

    Symmetric tensor using a ``BSGS`` (base, strong generator set)

    >>> sym2 = tensorsymmetry(*get_symmetric_group_sgs(2))
    >>> S2 = TensorType([Lorentz]*2, sym2)
    >>> V = S2('V')

    """
    from ..combinatorics import Permutation

    def tableau2bsgs(a):
        # Translate one Young-tableau shape into a (base, sgs) pair.
        # The shape is normalised to a list so that tuples such as
        # ``(2, 2)`` are handled exactly like ``[2, 2]``.
        shape = list(a)
        if len(shape) == 1:
            # ``[n]``: antisymmetric tensor of rank n
            return get_symmetric_group_sgs(shape[0], 1)
        if all(x == 1 for x in shape):
            # ``[1]*n``: symmetric tensor of rank n
            return get_symmetric_group_sgs(len(shape))
        if shape == [2, 2]:
            # monoterm symmetries of the Riemann tensor only
            return riemann_bsgs
        raise NotImplementedError

    if not args:
        return TensorSymmetry(Tuple(), Tuple(Permutation(1)))

    # An explicit BSGS is recognised by its strong generators being
    # Permutation objects.
    if len(args) == 2 and isinstance(args[1][0], Permutation):
        return TensorSymmetry(args)

    # Otherwise every argument is a tableau shape; combine them as a
    # direct product of the individual symmetries.
    base, sgs = tableau2bsgs(args[0])
    for a in args[1:]:
        basex, sgsx = tableau2bsgs(a)
        base, sgs = bsgs_direct_product(base, sgs, basex, sgsx)
    return TensorSymmetry(Tuple(base, sgs))
1793
1794
@doctest_depends_on(modules=('numpy',))
class TensorType(Basic):
    """
    Class of tensor types.

    A tensor type couples the index type of every slot with the monoterm
    symmetry of the tensor; calling it creates ``TensorHead`` objects.

    Parameters
    ==========

    index_types : list of ``TensorIndexType`` of the tensor indices
    symmetry : ``TensorSymmetry`` of the tensor; its rank must equal the
        number of index types

    Attributes
    ==========

    ``index_types``
    ``symmetry``
    ``types`` : list of ``TensorIndexType`` without repetitions, sorted by name

    Examples
    ========

    Define a symmetric tensor

    >>> Lorentz = TensorIndexType('Lorentz', dummy_fmt='L')
    >>> sym2 = tensorsymmetry([1, 1])
    >>> S2 = TensorType([Lorentz]*2, sym2)
    >>> V = S2('V')

    """

    is_commutative = False

    def __new__(cls, index_types, symmetry, **kw_args):
        assert symmetry.rank == len(index_types)
        return Basic.__new__(cls, Tuple(*index_types), symmetry, **kw_args)

    @property
    def index_types(self):
        return self.args[0]

    @property
    def symmetry(self):
        return self.args[1]

    @property
    def types(self):
        distinct = set(self.index_types)
        return sorted(distinct, key=lambda t: t.name)

    def __str__(self):
        names = [str(t) for t in self.index_types]
        return f'TensorType({names})'

    def __call__(self, s, comm=0, matrix_behavior=0):
        """
        Return a TensorHead object or a list of TensorHead objects.

        ``s``  name or string of names

        ``comm``: commutation group number
        see ``_TensorManager.set_comm``

        ``matrix_behavior``: forwarded unchanged to every ``TensorHead``

        Examples
        ========

        Define symmetric tensors ``V``, ``W`` and ``G``, respectively
        commuting, anticommuting and with no commutation symmetry

        >>> Lorentz = TensorIndexType('Lorentz', dummy_fmt='L')
        >>> a, b = tensor_indices('a b', Lorentz)
        >>> sym2 = tensorsymmetry([1]*2)
        >>> S2 = TensorType([Lorentz]*2, sym2)
        >>> V = S2('V')
        >>> W = S2('W', 1)
        >>> G = S2('G', 2)
        >>> canon_bp(V(a, b)*V(-b, -a))
        V(L_0, L_1)*V(-L_0, -L_1)
        >>> canon_bp(W(a, b)*W(-b, -a))
        0

        """
        if not isinstance(s, str):
            raise ValueError('expecting a string')
        heads = [TensorHead(sym.name, self, comm, matrix_behavior=matrix_behavior)
                 for sym in symbols(s, seq=True)]
        # a single name yields a bare head rather than a one-element list
        return heads[0] if len(heads) == 1 else heads
1883
1884
def tensorhead(name, typ, sym, comm=0, matrix_behavior=0):
    """
    Function generating tensorhead(s).

    This is a shortcut that builds the ``TensorSymmetry`` and the
    ``TensorType`` and then calls the latter.

    Parameters
    ==========

    name : name or sequence of names (as in ``symbol``)

    typ :  index types

    sym :  same as ``*args`` in ``tensorsymmetry``

    comm : commutation group number
    see ``_TensorManager.set_comm``

    matrix_behavior : forwarded to the created ``TensorHead`` objects


    Examples
    ========

    >>> Lorentz = TensorIndexType('Lorentz', dummy_fmt='L')
    >>> a, b = tensor_indices('a b', Lorentz)
    >>> A = tensorhead('A', [Lorentz]*2, [[1]*2])
    >>> A(a, -b)
    A(a, -b)

    """
    tensor_type = TensorType(typ, tensorsymmetry(*sym))
    return tensor_type(name, comm, matrix_behavior=matrix_behavior)
1916
1917
@doctest_depends_on(modules=('numpy',))
class TensorHead(Basic):
    r"""
    Tensor head of the tensor

    Parameters
    ==========

    name : name of the tensor (a string or a ``Symbol``)

    typ : ``TensorType`` of the tensor

    comm : commutation group number

    matrix_behavior : if true, the last one or two indices may be omitted
        when calling the head; they are then filled with auto-matrix indices

    Attributes
    ==========

    ``name``
    ``index_types``
    ``rank``
    ``types``  :  equal to ``typ.types``
    ``symmetry`` : equal to ``typ.symmetry``
    ``comm`` : int
        commutation group

    Notes
    =====

    A ``TensorHead`` belongs to a commutation group, defined by a
    symbol or number ``comm`` (see ``_TensorManager.set_comm``);
    tensors in a commutation group have the same commutation properties;
    by default ``comm`` is ``0``, the group of the commuting tensors.

    Examples
    ========

    >>> Lorentz = TensorIndexType('Lorentz', dummy_fmt='L')
    >>> A = tensorhead('A', [Lorentz, Lorentz], [[1], [1]])

    Examples with ndarray values, the components data assigned to the
    ``TensorHead`` object are assumed to be in a fully-contravariant
    representation. In case it is necessary to assign components data which
    represents the values of a non-fully contravariant tensor, see the other
    examples.

    >>> Lorentz.data = [1, -1, -1, -1]
    >>> i0, i1 = tensor_indices('i0:2', Lorentz)
    >>> A.data = [[j+2*i for j in range(4)] for i in range(4)]

    in order to retrieve data, it is also necessary to specify abstract indices
    enclosed by round brackets, then numerical indices inside square brackets.

    >>> A(i0, i1)[0, 0]
    0
    >>> A(i0, i1)[2, 3] == 3+2*2
    True

    Notice that round brackets alone create a tensor expression instance:

    >>> A(i0, i1)
    A(i0, i1)

    To view the data, just type:

    >>> print(str(A.data))
    [[0 1 2 3]
     [2 3 4 5]
     [4 5 6 7]
     [6 7 8 9]]

    Turning to a tensor expression, covariant indices get the corresponding
    components data corrected by the metric:

    >>> print(str(A(i0, -i1).data))
    [[0 -1 -2 -3]
     [2 -3 -4 -5]
     [4 -5 -6 -7]
     [6 -7 -8 -9]]

    >>> print(str(A(-i0, -i1).data))
    [[0 -1 -2 -3]
     [-2 3 4 5]
     [-4 5 6 7]
     [-6 7 8 9]]

    while if all indices are contravariant, the ``ndarray`` remains the same

    >>> print(str(A(i0, i1).data))
    [[0 1 2 3]
     [2 3 4 5]
     [4 5 6 7]
     [6 7 8 9]]

    When all indices are contracted and components data are added to the tensor,
    accessing the data will return a scalar, no numpy object. In fact, numpy
    ndarrays are dropped to scalars if they contain only one element.

    >>> A(i0, -i0)
    A(L_0, -L_0)
    >>> A(i0, -i0).data
    -18

    It is also possible to assign components data to an indexed tensor, i.e. a
    tensor with specified covariant and contravariant components. In this
    example, the covariant components data of the Electromagnetic tensor are
    injected into `A`:

    >>> Ex, Ey, Ez, Bx, By, Bz = symbols('E_x E_y E_z B_x B_y B_z')
    >>> c = symbols('c', positive=True)

    Let's define `F`, an antisymmetric tensor, we have to assign an
    antisymmetric matrix to it, because `[[2]]` stands for the Young tableau
    representation of an antisymmetric set of two elements:

    >>> F = tensorhead('A', [Lorentz, Lorentz], [[2]])
    >>> F(-i0, -i1).data = [[0, Ex/c, Ey/c, Ez/c],
    ...                     [-Ex/c, 0, -Bz, By],
    ...                     [-Ey/c, Bz, 0, -Bx],
    ...                     [-Ez/c, -By, Bx, 0]]

    Now it is possible to retrieve the contravariant form of the Electromagnetic
    tensor:

    >>> print(str(F(i0, i1).data))
    [[0 -E_x/c -E_y/c -E_z/c]
     [E_x/c 0 -B_z B_y]
     [E_y/c B_z 0 -B_x]
     [E_z/c -B_y B_x 0]]

    and the mixed contravariant-covariant form:

    >>> print(str(F(i0, -i1).data))
    [[0 E_x/c E_y/c E_z/c]
     [E_x/c 0 B_z -B_y]
     [E_y/c -B_z 0 B_x]
     [E_z/c B_y -B_x 0]]

    To convert the numpy's ndarray to a diofant matrix, just cast:

    >>> Matrix(F.data)
    Matrix([
    [    0, -E_x/c, -E_y/c, -E_z/c],
    [E_x/c,      0,   -B_z,    B_y],
    [E_y/c,    B_z,      0,   -B_x],
    [E_z/c,   -B_y,    B_x,      0]])

    Still notice, in this last example, that accessing components data from a
    tensor without specifying the indices is equivalent to assume that all
    indices are contravariant.

    It is also possible to store symbolic components data inside a tensor, for
    example, define a four-momentum-like tensor:

    >>> P = tensorhead('P', [Lorentz], [[1]])
    >>> E, px, py, pz = symbols('E p_x p_y p_z', positive=True)
    >>> P.data = [E, px, py, pz]

    The contravariant and covariant components are, respectively:

    >>> print(str(P(i0).data))
    [E p_x p_y p_z]
    >>> print(str(P(-i0).data))
    [E -p_x -p_y -p_z]

    The contraction of a 1-index tensor by itself is usually indicated by a
    power by two:

    >>> P(i0)**2
    E**2 - p_x**2 - p_y**2 - p_z**2

    As the power by two is clearly identical to `P_\mu P^\mu`, it is possible to
    simply contract the ``TensorHead`` object, without specifying the indices

    >>> P**2
    E**2 - p_x**2 - p_y**2 - p_z**2

    """

    is_commutative = False

    def __new__(cls, name, typ, comm=0, matrix_behavior=0, **kw_args):
        if isinstance(name, str):
            name_symbol = Symbol(name)
        elif isinstance(name, Symbol):
            name_symbol = name
        else:
            raise ValueError('invalid name')

        comm2i = TensorManager.comm_symbols2i(comm)

        obj = Basic.__new__(cls, name_symbol, typ, **kw_args)

        # Cached plain-Python attributes; only ``name_symbol`` and ``typ``
        # are stored as Basic args.
        obj._matrix_behavior = matrix_behavior

        obj._name = obj.args[0].name
        obj._rank = len(obj.index_types)
        obj._types = typ.types
        obj._symmetry = typ.symmetry
        obj._comm = comm2i
        return obj

    @property
    def name(self):
        return self._name

    @property
    def rank(self):
        return self._rank

    @property
    def types(self):
        return self._types[:]

    @property
    def symmetry(self):
        return self._symmetry

    @property
    def typ(self):
        return self.args[1]

    @property
    def comm(self):
        return self._comm

    @property
    def index_types(self):
        return self.args[1].index_types[:]

    def __lt__(self, other):
        return (self.name, self.index_types) < (other.name, other.index_types)

    def commutes_with(self, other):
        """
        Returns ``0`` if ``self`` and ``other`` commute, ``1`` if they anticommute.

        Returns ``None`` if ``self`` and ``other`` neither commute nor anticommute.

        """
        r = TensorManager.get_comm(self._comm, other._comm)
        return r

    def _print(self):
        return f"{self.name}({','.join([str(x) for x in self.index_types])})"

    def _check_auto_matrix_indices_in_call(self, *indices):
        """
        Fill in auto-matrix indices for a call of this head.

        Slots passed as ``True`` (or, with ``matrix_behavior``, the last one
        or two omitted slots) receive the ``auto_left`` index of their type
        the first time that type is seen and ``-auto_right`` the second time.

        Returns the completed tuple of indices and a dict mapping each index
        type to the list of auto-matrix indices inserted for it.

        Raises ``ValueError`` if the number of indices cannot be matched to
        the number of slots.
        """
        matrix_behavior_kinds = {}

        if len(indices) != len(self.index_types):
            if not self._matrix_behavior:
                raise ValueError('wrong number of indices')

            # Take the last one or two missing
            # indices as auto-matrix indices:
            ldiff = len(self.index_types) - len(indices)
            # Only one or two slots may be left out; passing *more* indices
            # than there are slots (ldiff < 0) is also an error.
            if not 1 <= ldiff <= 2:
                raise ValueError('wrong number of indices')
            mat_ind = list(range(len(indices), len(self.index_types)))
            not_equal = True
        else:
            not_equal = False
            mat_ind = [i for i, e in enumerate(indices) if e is True]
            if mat_ind:
                not_equal = True
            indices = tuple(_ for _ in indices if _ is not True)

            for i, el in enumerate(indices):
                if not isinstance(el, TensorIndex):
                    not_equal = True
                    break
                if el._tensortype != self.index_types[i]:
                    not_equal = True
                    break

        if not_equal:
            for el in mat_ind:
                eltyp = self.index_types[el]
                if eltyp in matrix_behavior_kinds:
                    # second auto-matrix slot of this type: lower auto_right
                    elind = -self.index_types[el].auto_right
                    matrix_behavior_kinds[eltyp].append(elind)
                else:
                    # first auto-matrix slot of this type: upper auto_left
                    elind = self.index_types[el].auto_left
                    matrix_behavior_kinds[eltyp] = [elind]
                indices = indices[:el] + (elind,) + indices[el:]

        return indices, matrix_behavior_kinds

    def __call__(self, *indices, **kw_args):
        """
        Returns a tensor with indices.

        There is a special behavior in case of indices denoted by ``True``,
        they are considered auto-matrix indices, their slots are automatically
        filled, and confer to the tensor the behavior of a matrix or vector
        upon multiplication with another tensor containing auto-matrix indices
        of the same ``TensorIndexType``. This means indices get summed over the
        same way as in matrix multiplication. For matrix behavior, define two
        auto-matrix indices, for vector behavior define just one.

        Examples
        ========

        >>> Lorentz = TensorIndexType('Lorentz', dummy_fmt='L')
        >>> a, b = tensor_indices('a b', Lorentz)
        >>> A = tensorhead('A', [Lorentz]*2, [[1]*2])
        >>> t = A(a, -b)
        >>> t
        A(a, -b)

        To use the auto-matrix index behavior, just put a ``True`` on the
        desired index position.

        >>> r = A(True, True)
        >>> r
        A(auto_left, -auto_right)

        Here ``auto_left`` and ``auto_right`` are automatically generated
        tensor indices, they are only two for every ``TensorIndexType`` and
        can be assigned to just one or two indices of a given type.

        Auto-matrix indices can be assigned many times in a tensor, if indices
        are of different ``TensorIndexType``

        >>> Spinor = TensorIndexType('Spinor', dummy_fmt='S')
        >>> B = tensorhead('B', [Lorentz, Lorentz, Spinor, Spinor], [[1]*4])
        >>> s = B(True, True, True, True)
        >>> s
        B(auto_left, -auto_right, auto_left, -auto_right)

        Here, ``auto_left`` and ``auto_right`` are repeated twice, but they are
        not the same indices, as they refer to different ``TensorIndexType``s.

        Auto-matrix indices are automatically contracted upon multiplication,

        >>> r*s
        A(auto_left, L_0)*B(-L_0, -auto_right, auto_left, -auto_right)

        The multiplication algorithm has found an ``auto_right`` index in ``A``
        and an ``auto_left`` index in ``B`` referring to the same
        ``TensorIndexType`` (``Lorentz``), so they have been contracted.

        Auto-matrix indices can be accessed from the ``TensorIndexType``:

        >>> Lorentz.auto_right
        auto_right
        >>> Lorentz.auto_left
        auto_left

        There is a special case, in which the ``True`` parameter is not needed
        to declare an auto-matrix index, i.e. when the matrix behavior has been
        declared upon ``TensorHead`` construction, in that case the last one or
        two tensor indices may be omitted, so that they automatically become
        auto-matrix indices:

        >>> C = tensorhead('C', [Lorentz, Lorentz], [[1]*2], matrix_behavior=True)
        >>> C()
        C(auto_left, -auto_right)

        """
        indices, _ = self._check_auto_matrix_indices_in_call(*indices)
        tensor = Tensor._new_with_dummy_replacement(self, indices, **kw_args)
        return tensor

    def __pow__(self, other):
        if self.data is None:
            raise ValueError('No power on abstract tensors.')
        numpy = import_module('numpy')
        # one metric per slot, taken from the index types of the tensor type
        metrics = [_.data for _ in self.index_types]

        marray = self.data
        for metric in metrics:
            marray = numpy.tensordot(marray, numpy.tensordot(metric, marray, (1, 0)), (0, 0))
        pow2 = marray[()]
        return pow2 ** (Rational(1, 2) * other)

    @property
    def data(self):
        return _tensor_data_substitution_dict[self]

    @data.setter
    def data(self, data):
        _tensor_data_substitution_dict[self] = data

    @data.deleter
    def data(self):
        if self in _tensor_data_substitution_dict:
            del _tensor_data_substitution_dict[self]

    def __iter__(self):
        # iterate over the components data in flattened order
        return iter(self.data.flatten())
2311
2312
@doctest_depends_on(modules=('numpy',))
class TensExpr(Basic):
    """
    Abstract base class for tensor expressions

    Notes
    =====

    A tensor expression is an expression formed by tensors;
    currently the sums of tensors are distributed.

    A ``TensExpr`` can be a ``TensAdd`` or a ``TensMul``.

    ``TensAdd`` objects are put in canonical form using the Butler-Portugal
    algorithm for canonicalization under monoterm symmetries.

    ``TensMul`` objects are formed by products of component tensors,
    and include a coefficient, which is a Diofant expression.


    In the internal representation contracted indices are represented
    by ``(ipos1, ipos2, icomp1, icomp2)``, where ``icomp1`` is the position
    of the component tensor with contravariant index, ``ipos1`` is the
    slot which the index occupies in that component tensor.

    Contracted indices are therefore nameless in the internal representation.

    """

    _op_priority = 11.0
    is_commutative = False

    def __neg__(self):
        return self*Integer(-1)

    def __abs__(self):
        raise NotImplementedError

    def __add__(self, other):
        raise NotImplementedError

    def __radd__(self, other):
        raise NotImplementedError

    def __sub__(self, other):
        raise NotImplementedError

    def __rsub__(self, other):
        raise NotImplementedError

    def __mul__(self, other):
        raise NotImplementedError

    def __pow__(self, other):
        if self.data is None:
            raise ValueError('No power without ndarray data.')
        numpy = import_module('numpy')
        free = self.free

        marray = self.data
        # For every free index, contract the data with its copy lowered by
        # the metric of that index's type.
        for free_index in free:
            metric = free_index[0]._tensortype.data
            marray = numpy.tensordot(
                marray,
                numpy.tensordot(metric, marray, (1, 0)),
                (0, 0)
            )
        pow2 = marray[()]
        return pow2 ** (Rational(1, 2) * other)

    def __rpow__(self, other):
        raise NotImplementedError

    def __truediv__(self, other):
        raise NotImplementedError

    def __rtruediv__(self, other):
        raise NotImplementedError

    @doctest_depends_on(modules=('numpy',))
    def get_matrix(self):
        """
        Returns ndarray components data as a matrix, if components data are
        available and ndarray dimension does not exceed 2.

        A rank-1 tensor becomes a column matrix. Raises
        ``NotImplementedError`` for rank 0 or rank greater than 2.

        Examples
        ========

        >>> Lorentz = TensorIndexType('Lorentz', dummy_fmt='L')
        >>> sym2 = tensorsymmetry([1]*2)
        >>> S2 = TensorType([Lorentz]*2, sym2)
        >>> A = S2('A')

        The tensor ``A`` is symmetric in its indices, as can be deduced by the
        ``[1, 1]`` Young tableau when constructing `sym2`. One has to be
        careful to assign symmetric component data to ``A``, as the symmetry
        properties of data are currently not checked to be compatible with the
        defined tensor symmetry.

        >>> Lorentz.data = [1, -1, -1, -1]
        >>> i0, i1 = tensor_indices('i0:2', Lorentz)
        >>> A.data = [[j+i for j in range(4)] for i in range(4)]
        >>> A(i0, i1).get_matrix()
        Matrix([
        [0, 1, 2, 3],
        [1, 2, 3, 4],
        [2, 3, 4, 5],
        [3, 4, 5, 6]])

        It is possible to perform usual operation on matrices, such as the
        matrix multiplication:

        >>> A(i0, i1).get_matrix()*ones(4, 1)
        Matrix([
        [ 6],
        [10],
        [14],
        [18]])

        >>> del A.data

        """
        if not 0 < self.rank <= 2:
            raise NotImplementedError(
                'missing multidimensional reduction to matrix.')

        rows = self.data.shape[0]
        if self.rank == 2:
            columns = self.data.shape[1]
            mat_list = [[self[i, j] for j in range(columns)]
                        for i in range(rows)]
        else:
            mat_list = [self[i] for i in range(rows)]
        return Matrix(mat_list)
2455
2456
2457@doctest_depends_on(modules=('numpy',))
2458class TensAdd(TensExpr):
2459    """
2460    Sum of tensors
2461
2462    Parameters
2463    ==========
2464
2465    free_args : list of the free indices
2466
2467    Attributes
2468    ==========
2469
2470    ``args`` : tuple
2471        of addends
2472    ``rank`` : tuple
2473        rank of the tensor
2474    ``free_args`` : list
2475        of the free indices in sorted order
2476
2477    Notes
2478    =====
2479
2480    Sum of more than one tensor are put automatically in canonical form.
2481
2482    Examples
2483    ========
2484
2485    >>> Lorentz = TensorIndexType('Lorentz', dummy_fmt='L')
2486    >>> a, b = tensor_indices('a b', Lorentz)
2487    >>> p, q = tensorhead('p q', [Lorentz], [[1]])
2488    >>> t = p(a) + q(a)
2489    >>> t
2490    p(a) + q(a)
2491    >>> t(b)
2492    p(b) + q(b)
2493
2494    Examples with components data added to the tensor expression:
2495
2496    >>> Lorentz.data = [1, -1, -1, -1]
2497    >>> a, b = tensor_indices('a, b', Lorentz)
2498    >>> p.data = [2, 3, -2, 7]
2499    >>> q.data = [2, 3, -2, 7]
2500    >>> t = p(a) + q(a)
2501    >>> t
2502    p(a) + q(a)
2503    >>> t(b)
2504    p(b) + q(b)
2505
2506    The following are: 2**2 - 3**2 - 2**2 - 7**2 ==> -58
2507
2508    >>> (p(a)*p(-a)).data
2509    -58
2510    >>> p(a)**2
2511    -58
2512
2513    """
2514
2515    def __new__(cls, *args, **kw_args):
2516        args = [sympify(x) for x in args if x]
2517        args = TensAdd._tensAdd_flatten(args)
2518
2519        if not args:
2520            return Integer(0)
2521
2522        if len(args) == 1 and not isinstance(args[0], TensExpr):
2523            return args[0]
2524
2525        # replace auto-matrix indices so that they are the same in all addends
2526        args = TensAdd._tensAdd_check_automatrix(args)
2527
2528        # now check that all addends have the same indices:
2529        TensAdd._tensAdd_check(args)
2530
2531        # if TensAdd has only 1 TensMul element in its `args`:
2532        if len(args) == 1 and isinstance(args[0], TensMul):
2533            obj = Basic.__new__(cls, *args, **kw_args)
2534            return obj
2535
2536        # TODO: do not or do canonicalize by default?
2537        # Technically, one may wish to have additions of non-canonicalized
2538        # tensors. This feature should be removed in the future.
2539        # Unfortunately this would require to rewrite a lot of tests.
2540        # canonicalize all TensMul
2541        args = [canon_bp(x) for x in args if x]
2542        args = [x for x in args if x]
2543
2544        # if there are no more args (i.e. have cancelled out),
2545        # just return zero:
2546        if not args:
2547            return Integer(0)
2548
2549        if len(args) == 1:
2550            return args[0]
2551
2552        # collect canonicalized terms
2553        def sort_key(t):
2554            x = get_tids(t)
2555            return x.components, x.free, x.dum
2556        args.sort(key=sort_key)
2557        args = TensAdd._tensAdd_collect_terms(args)
2558        if not args:
2559            return Integer(0)
2560        # it there is only a component tensor return it
2561        if len(args) == 1:
2562            return args[0]
2563
2564        obj = Basic.__new__(cls, *args, **kw_args)
2565        return obj
2566
2567    @staticmethod
2568    def _tensAdd_flatten(args):
2569        # flatten TensAdd, coerce terms which are not tensors to tensors
2570
2571        if not all(isinstance(x, TensExpr) for x in args):
2572            args_expanded = []
2573            for x in args:
2574                if isinstance(x, TensAdd):
2575                    args_expanded.extend(list(x.args))
2576                else:
2577                    args_expanded.append(x)
2578            args_tensor = []
2579            args_scalar = []
2580            for x in args_expanded:
2581                if isinstance(x, TensExpr) and x.coeff:
2582                    args_tensor.append(x)
2583                if not isinstance(x, TensExpr):
2584                    args_scalar.append(x)
2585            t1 = TensMul.from_data(Add(*args_scalar), [], [], [])
2586            args = [t1] + args_tensor
2587        a = []
2588        for x in args:
2589            if isinstance(x, TensAdd):
2590                a.extend(list(x.args))
2591            else:
2592                a.append(x)
2593        args = [x for x in a if x.coeff]
2594        return args
2595
2596    @staticmethod
2597    def _tensAdd_check_automatrix(args):
2598        # check that all automatrix indices are the same.
2599
2600        # if there are no addends, just return.
2601        if not args:
2602            return args
2603
2604        # @type auto_left_types: set
2605        auto_left_types = set()
2606        auto_right_types = set()
2607        args_auto_left_types = []
2608        args_auto_right_types = []
2609        for i, arg in enumerate(args):
2610            arg_auto_left_types = set()
2611            arg_auto_right_types = set()
2612            for index in get_indices(arg):
2613                # @type index: TensorIndex
2614                if index in (index._tensortype.auto_left, -index._tensortype.auto_left):
2615                    auto_left_types.add(index._tensortype)
2616                    arg_auto_left_types.add(index._tensortype)
2617                if index in (index._tensortype.auto_right, -index._tensortype.auto_right):
2618                    auto_right_types.add(index._tensortype)
2619                    arg_auto_right_types.add(index._tensortype)
2620            args_auto_left_types.append(arg_auto_left_types)
2621            args_auto_right_types.append(arg_auto_right_types)
2622        for arg, aas_left, aas_right in zip(args, args_auto_left_types, args_auto_right_types):
2623            missing_left = auto_left_types - aas_left
2624            missing_right = auto_right_types - aas_right
2625            missing_intersection = missing_left & missing_right
2626            for j in missing_intersection:
2627                args[i] *= j.delta(j.auto_left, -j.auto_right)
2628            if missing_left != missing_right:
2629                raise ValueError('cannot determine how to add auto-matrix indices on some args')
2630
2631        return args
2632
2633    @staticmethod
2634    def _tensAdd_check(args):
2635        # check that all addends have the same free indices
2636        indices0 = {x[0] for x in get_tids(args[0]).free}
2637        list_indices = [{y[0] for y in get_tids(x).free} for x in args[1:]]
2638        if not all(x == indices0 for x in list_indices):
2639            raise ValueError('all tensors must have the same indices')
2640
2641    @staticmethod
2642    def _tensAdd_collect_terms(args):
2643        # collect TensMul terms differing at most by their coefficient
2644        a = []
2645        prev = args[0]
2646        prev_coeff = get_coeff(prev)
2647        changed = False
2648
2649        for x in args[1:]:
2650            # if x and prev have the same tensor, update the coeff of prev
2651            x_tids = get_tids(x)
2652            prev_tids = get_tids(prev)
2653            if x_tids.components == prev_tids.components \
2654                    and x_tids.free == prev_tids.free and x_tids.dum == prev_tids.dum:
2655                prev_coeff = prev_coeff + get_coeff(x)
2656                changed = True
2657                op = 0
2658            else:
2659                # x and prev are different; if not changed, prev has not
2660                # been updated; store it
2661                if not changed:
2662                    a.append(prev)
2663                else:
2664                    # get a tensor from prev with coeff=prev_coeff and store it
2665                    if prev_coeff:
2666                        t = TensMul.from_data(prev_coeff, prev_tids.components,
2667                                              prev_tids.free, prev_tids.dum)
2668                        a.append(t)
2669                # move x to prev
2670                op = 1
2671                prev = x
2672                prev_coeff = get_coeff(x)
2673                changed = False
2674        # if the case op=0 prev was not stored; store it now
2675        # in the case op=1 x was not stored; store it now (as prev)
2676        if op == 0 and prev_coeff:
2677            prev = TensMul.from_data(prev_coeff, prev_tids.components, prev_tids.free, prev_tids.dum)
2678            a.append(prev)
2679        elif op == 1:
2680            a.append(prev)
2681        return a
2682
2683    @property
2684    def rank(self):
2685        return self.args[0].rank
2686
2687    @property
2688    def free_args(self):
2689        return self.args[0].free_args
2690
2691    def __call__(self, *indices):
2692        """Returns tensor with ordered free indices replaced by ``indices``
2693
2694        Parameters
2695        ==========
2696
2697        indices
2698
2699        Examples
2700        ========
2701
2702        >>> D = Symbol('D')
2703        >>> Lorentz = TensorIndexType('Lorentz', dim=D, dummy_fmt='L')
2704        >>> i0, i1, i2, i3, i4 = tensor_indices('i0:5', Lorentz)
2705        >>> p, q = tensorhead('p q', [Lorentz], [[1]])
2706        >>> g = Lorentz.metric
2707        >>> t = p(i0)*p(i1) + g(i0, i1)*q(i2)*q(-i2)
2708        >>> t(i0, i2)
2709        metric(i0, i2)*q(L_0)*q(-L_0) + p(i0)*p(i2)
2710        >>> t(i0, i1) - t(i1, i0)
2711        0
2712
2713        """
2714        free_args = self.free_args
2715        indices = list(indices)
2716        if [x._tensortype for x in indices] != [x._tensortype for x in free_args]:
2717            raise ValueError('incompatible types')
2718        if indices == free_args:
2719            return self
2720        index_tuples = list(zip(free_args, indices))
2721        a = [x.func(*x.fun_eval(*index_tuples).args) for x in self.args]
2722        res = TensAdd(*a)
2723
2724        return res
2725
2726    def canon_bp(self):
2727        """
2728        Canonicalize using the Butler-Portugal algorithm for canonicalization
2729        under monoterm symmetries.
2730
2731        """
2732        args = [x.canon_bp() for x in self.args]
2733        res = TensAdd(*args)
2734        return res
2735
2736    def equals(self, other):
2737        other = sympify(other)
2738        if isinstance(other, TensMul) and other._coeff == 0:
2739            return all(x._coeff == 0 for x in self.args)
2740        if isinstance(other, TensExpr):
2741            if self.rank != other.rank:
2742                return False
2743        if isinstance(other, TensAdd):
2744            if set(self.args) != set(other.args):
2745                return False
2746            else:
2747                return True
2748        t = self - other
2749        if not isinstance(t, TensExpr):
2750            return t == 0
2751        else:
2752            if isinstance(t, TensMul):
2753                return t._coeff == 0
2754            else:
2755                return all(x._coeff == 0 for x in t.args)
2756
2757    def __add__(self, other):
2758        return TensAdd(self, other)
2759
2760    def __radd__(self, other):
2761        return TensAdd(other, self)
2762
2763    def __sub__(self, other):
2764        return TensAdd(self, -other)
2765
2766    def __rsub__(self, other):
2767        return TensAdd(other, -self)
2768
2769    def __mul__(self, other):
2770        return TensAdd(*(x*other for x in self.args))
2771
2772    def __rmul__(self, other):
2773        return self*other
2774
2775    def __truediv__(self, other):
2776        other = sympify(other)
2777        if isinstance(other, TensExpr):
2778            raise ValueError('cannot divide by a tensor')
2779        return TensAdd(*(x/other for x in self.args))
2780
2781    def __rtruediv__(self, other):
2782        raise ValueError('cannot divide by a tensor')
2783
2784    def __getitem__(self, item):
2785        return self.data[item]
2786
2787    def contract_delta(self, delta):
2788        args = [x.contract_delta(delta) for x in self.args]
2789        t = TensAdd(*args)
2790        return canon_bp(t)
2791
2792    def contract_metric(self, g):
2793        """
2794        Raise or lower indices with the metric ``g``
2795
2796        Parameters
2797        ==========
2798
2799        g :  metric
2800
2801        contract_all : if True, eliminate all ``g`` which are contracted
2802
2803        Notes
2804        =====
2805
2806        See Also
2807        ========
2808
2809        TensorIndexType
2810
2811        """
2812        args = [contract_metric(x, g) for x in self.args]
2813        t = TensAdd(*args)
2814        return canon_bp(t)
2815
2816    def fun_eval(self, *index_tuples):
2817        """
2818        Return a tensor with free indices substituted according to ``index_tuples``
2819
2820        Parameters
2821        ==========
2822
2823        index_types : list of tuples ``(old_index, new_index)``
2824
2825        Examples
2826        ========
2827
2828        >>> Lorentz = TensorIndexType('Lorentz', dummy_fmt='L')
2829        >>> i, j, k, l = tensor_indices('i j k l', Lorentz)
2830        >>> A, B = tensorhead('A B', [Lorentz]*2, [[1]*2])
2831        >>> t = A(i, k)*B(-k, -j) + A(i, -j)
2832        >>> t.fun_eval((i, k), (-j, l))
2833        A(k, L_0)*B(l, -L_0) + A(k, l)
2834
2835        """
2836        args = self.args
2837        args1 = []
2838        for x in args:
2839            y = x.fun_eval(*index_tuples)
2840            args1.append(y)
2841        return TensAdd(*args1)
2842
2843    def substitute_indices(self, *index_tuples):
2844        """
2845        Return a tensor with free indices substituted according to ``index_tuples``
2846
2847        Parameters
2848        ==========
2849
2850        index_types : list of tuples ``(old_index, new_index)``
2851
2852        Examples
2853        ========
2854
2855        >>> Lorentz = TensorIndexType('Lorentz', dummy_fmt='L')
2856        >>> i, j, k, l = tensor_indices('i j k l', Lorentz)
2857        >>> A, B = tensorhead('A B', [Lorentz]*2, [[1]*2])
2858        >>> t = A(i, k)*B(-k, -j)
2859        >>> t
2860        A(i, L_0)*B(-L_0, -j)
2861        >>> t.substitute_indices((i, j), (j, k))
2862        A(j, L_0)*B(-L_0, -k)
2863
2864        """
2865        args = self.args
2866        args1 = []
2867        for x in args:
2868            y = x.substitute_indices(*index_tuples)
2869            args1.append(y)
2870        return TensAdd(*args1)
2871
2872    def _print(self):
2873        a = []
2874        args = self.args
2875        for x in args:
2876            a.append(str(x))
2877        a.sort()
2878        s = ' + '.join(a)
2879        s = s.replace('+ -', '- ')
2880        return s
2881
2882    @staticmethod
2883    def from_TIDS_list(coeff, tids_list):
2884        """
2885        Given a list of coefficients and a list of ``TIDS`` objects, construct
2886        a ``TensAdd`` instance, equivalent to the one that would result from
2887        creating single instances of ``TensMul`` and then adding them.
2888
2889        Examples
2890        ========
2891
2892        >>> Lorentz = TensorIndexType('Lorentz', dummy_fmt='L')
2893        >>> i, j = tensor_indices('i j', Lorentz)
2894        >>> A, B = tensorhead('A B', [Lorentz]*2, [[1]*2])
2895        >>> ea = 3*A(i, j)
2896        >>> eb = 2*B(j, i)
2897        >>> t1 = ea._tids
2898        >>> t2 = eb._tids
2899        >>> c1 = ea.coeff
2900        >>> c2 = eb.coeff
2901        >>> TensAdd.from_TIDS_list([c1, c2], [t1, t2])
2902        2*B(i, j) + 3*A(i, j)
2903
2904        If the coefficient parameter is a scalar, then it will be applied
2905        as a coefficient on all ``TIDS`` objects.
2906
2907        >>> TensAdd.from_TIDS_list(4, [t1, t2])
2908        4*A(i, j) + 4*B(i, j)
2909
2910        """
2911        if not isinstance(coeff, (list, tuple, Tuple)):
2912            coeff = [coeff] * len(tids_list)
2913        tensmul_list = [TensMul.from_TIDS(c, t) for c, t in zip(coeff, tids_list)]
2914        return TensAdd(*tensmul_list)
2915
2916    @property
2917    def data(self):
2918        return _tensor_data_substitution_dict[self]
2919
2920
@doctest_depends_on(modules=('numpy',))
class Tensor(TensExpr):
    """
    Base tensor class, i.e. this represents a tensor, the single unit to be
    put into an expression.

    This object is usually created from a ``TensorHead``, by attaching indices
    to it. Indices preceded by a minus sign are considered covariant,
    otherwise contravariant.

    Examples
    ========

    >>> Lorentz = TensorIndexType('Lorentz', dummy_fmt='L')
    >>> mu, nu = tensor_indices('mu nu', Lorentz)
    >>> A = tensorhead('A', [Lorentz, Lorentz], [[1], [1]])
    >>> A(mu, -nu)
    A(mu, -nu)
    >>> A(mu, -mu)
    A(L_0, -L_0)

    """

    is_commutative = False

    def __new__(cls, tensor_head, indices, **kw_args):
        # ``args`` are ``(tensor_head, Tuple(*indices))``; the index data
        # structure built from them is cached on the instance.
        tids = TIDS.from_components_and_indices((tensor_head,), indices)
        obj = Basic.__new__(cls, tensor_head, Tuple(*indices), **kw_args)
        obj._tids = tids
        obj._indices = indices
        obj._is_canon_bp = kw_args.get('is_canon_bp', False)
        return obj

    @staticmethod
    def _new_with_dummy_replacement(tensor_head, indices, **kw_args):
        # Rebuild the tensor using the index list reported by ``TIDS``
        # (presumably with contracted indices renamed to dummy names).
        tids = TIDS.from_components_and_indices((tensor_head,), indices)
        indices = tids.get_indices()
        return Tensor(tensor_head, indices, **kw_args)

    @property
    def is_canon_bp(self):
        return self._is_canon_bp

    @property
    def indices(self):
        return self._indices

    @property
    def free(self):
        return self._tids.free

    @property
    def dum(self):
        return self._tids.dum

    @property
    def rank(self):
        return len(self.free)

    @property
    def free_args(self):
        return sorted(x[0] for x in self.free)

    def perm2tensor(self, g, canon_bp=False):
        """
        Returns the tensor corresponding to the permutation ``g``

        For further details, see the method in ``TIDS`` with the same name.

        """
        return perm2tensor(self, g, canon_bp)

    def canon_bp(self):
        """
        Canonicalize using the Butler-Portugal algorithm.

        Returns ``self`` if it is already flagged as canonical, and
        ``Integer(0)`` if the symmetries force the tensor to vanish.

        """
        if self._is_canon_bp:
            return self
        g, dummies, msym, v = self._tids.canon_args()
        can = canonicalize(g, dummies, msym, *v)
        if can == 0:
            return Integer(0)
        tensor = self.perm2tensor(can, True)
        return tensor

    @property
    def types(self):
        return get_tids(self).components[0]._types

    @property
    def coeff(self):
        return Integer(1)

    @property
    def component(self):
        return self.args[0]

    @property
    def components(self):
        return [self.args[0]]

    def split(self):
        return [self]

    def get_indices(self):
        """Get a list of indices, corresponding to those of the tensor."""
        return self._tids.get_indices()

    def substitute_indices(self, *index_tuples):
        return substitute_indices(self, *index_tuples)

    def __call__(self, *indices):
        """Returns tensor with ordered free indices replaced by ``indices``

        Examples
        ========

        >>> Lorentz = TensorIndexType('Lorentz', dummy_fmt='L')
        >>> i0, i1, i2, i3, i4 = tensor_indices('i0:5', Lorentz)
        >>> A = tensorhead('A', [Lorentz]*5, [[1]*5])
        >>> t = A(i2, i1, -i2, -i3, i4)
        >>> t
        A(L_0, i1, -L_0, -i3, i4)
        >>> t(i1, i2, i3)
        A(L_0, i1, -L_0, i2, i3)

        """
        free_args = self.free_args
        indices = list(indices)
        if [x._tensortype for x in indices] != [x._tensortype for x in free_args]:
            raise ValueError('incompatible types')
        if indices == free_args:
            return self
        t = self.fun_eval(*list(zip(free_args, indices)))

        # object is rebuilt in order to make sure that all contracted indices
        # get recognized as dummies, but only if there are contracted indices.
        if len({i if i.is_up else -i for i in indices}) != len(indices):
            return t.func(*t.args)
        return t

    def fun_eval(self, *index_tuples):
        """
        Return a tensor with free indices substituted according to
        ``index_tuples``, a sequence of ``(old_index, new_index)`` pairs.
        Free indices not listed are kept.
        """
        free = self.free
        free1 = []
        for j, ipos, cpos in free:
            # search j in index_tuples
            for i, v in index_tuples:
                if i == j:
                    free1.append((v, ipos, cpos))
                    break
            else:
                free1.append((j, ipos, cpos))
        return TensMul.from_data(self.coeff, self.components, free1, self.dum)

    # TODO: put this into TensExpr?
    def __iter__(self):
        # Same behaviour as ``TensMul.__iter__``: abstract tensors (no data
        # registered) cannot be iterated.
        data = self.data
        if data is None:
            raise ValueError('No iteration on abstract tensors')
        return data.flatten().__iter__()

    # TODO: put this into TensExpr?
    def __getitem__(self, item):
        return self.data[item]

    @property
    def data(self):
        return _tensor_data_substitution_dict[self]

    @data.setter
    def data(self, data):
        # TODO: check data compatibility with properties of tensor.
        _tensor_data_substitution_dict[self] = data

    def __mul__(self, other):
        if isinstance(other, TensAdd):
            return TensAdd(*[self*arg for arg in other.args])
        tmul = TensMul(self, other)
        return tmul

    def __rmul__(self, other):
        return TensMul(other, self)

    def __truediv__(self, other):
        if isinstance(other, TensExpr):
            raise ValueError('cannot divide by a tensor')
        return TensMul(self, Integer(1)/other, is_canon_bp=self.is_canon_bp)

    def __rtruediv__(self, other):
        raise ValueError('cannot divide by a tensor')

    def __add__(self, other):
        return TensAdd(self, other)

    def __radd__(self, other):
        return TensAdd(other, self)

    def __sub__(self, other):
        return TensAdd(self, -other)

    def __neg__(self):
        return TensMul(Integer(-1), self)

    def _print(self):
        indices = [str(ind) for ind in self.indices]
        component = self.component
        if component.rank > 0:
            return f"{component.name}({', '.join(indices)})"
        else:
            return f'{component.name}'

    def equals(self, other):
        """
        Return ``True`` if ``self`` and ``other`` are equal after
        canonicalization (comparing coefficient, components, free and dummy
        indices).  A ``Tensor`` is never equal to a scalar.
        """
        if other == 0:
            return self.coeff == 0
        other = sympify(other)
        if not isinstance(other, TensExpr):
            # ``components`` of a Tensor always holds its head, so it can
            # never equal a plain scalar.  An explicit return replaces an
            # ``assert`` that is stripped under ``python -O``.
            return False

        def _get_compar_comp(self):
            t = self.canon_bp()
            r = (t.coeff, tuple(t.components),
                 tuple(sorted(t.free)), tuple(sorted(t.dum)))
            return r

        return _get_compar_comp(self) == _get_compar_comp(other)

    def contract_metric(self, metric):
        tids, sign = get_tids(self).contract_metric(metric)
        return TensMul.from_TIDS(sign, tids)
3145
3146
@doctest_depends_on(modules=('numpy',))
class TensMul(TensExpr):
    """
    Product of tensors

    Parameters
    ==========

    args : factors of the product; non-tensor factors are multiplied
        into the coefficient, ``Tensor``/``TensMul`` factors are combined
        through multiplication of their ``TIDS``
    is_canon_bp : keyword, ``True`` if the product is known to be in
        canonical form

    Attributes
    ==========

    ``components`` : list of ``TensorHead`` of the component tensors
    ``types`` : list of ``TensorIndexType``, the ``_types`` of all
        component tensors concatenated in order (may contain repeats)
    ``free`` : list of ``(ind, ipos, icomp)``, see Notes
    ``dum`` : list of ``(ipos1, ipos2, icomp1, icomp2)``, see Notes
    ``ext_rank`` : int
        rank of the tensor counting the dummy indices
    ``rank`` : int
        rank of the tensor
    ``coeff`` : Expr
        Diofant coefficient of the tensor
    ``free_args`` : list
        list of the free indices in sorted order
    ``is_canon_bp`` : ``True`` if the tensor is in canonical form

    Notes
    =====

    ``args`` holds the coefficient (only when it differs from 1) followed
    by the tensors returned by ``TIDS.get_tensors()``.

    ``free`` is a list of ``(ind, ipos, icomp)``
    where ``ind`` is a free index, ``ipos`` is the slot position
    of ``ind`` in the ``icomp``-th component tensor.

    ``dum`` is a list of tuples representing dummy indices.
    ``(ipos1, ipos2, icomp1, icomp2)`` indicates that the contravariant
    dummy index is the ``ipos1``-th slot position in the ``icomp1``-th
    component tensor; the corresponding covariant index is
    in the ``ipos2`` slot position in the ``icomp2``-th component tensor.

    """

    def __new__(cls, *args, **kw_args):
        # make sure everything is sympified:
        args = [sympify(arg) for arg in args]

        # flatten:
        args = TensMul._flatten(args)

        is_canon_bp = kw_args.get('is_canon_bp', False)
        if not any(isinstance(arg, TensExpr) for arg in args):
            tids = TIDS([], [], [])
        else:
            tids_list = [arg._tids for arg in args if isinstance(arg, (Tensor, TensMul))]
            if len(tids_list) == 1:
                # a single tensor factor passes its own canonical flag on,
                # unless the keyword argument overrides it
                for arg in args:
                    if not isinstance(arg, Tensor):
                        continue
                    is_canon_bp = kw_args.get('is_canon_bp', arg._is_canon_bp)
            tids = functools.reduce(lambda a, b: a*b, tids_list)

        coeff = functools.reduce(lambda a, b: a*b, [Integer(1)] + [arg for arg in args if not isinstance(arg, TensExpr)])
        args = tids.get_tensors()
        if coeff != 1:
            args = [coeff] + args
        if len(args) == 1:
            # a lone coefficient or a lone tensor is returned unwrapped
            return args[0]

        obj = Basic.__new__(cls, *args)
        obj._types = []
        for t in tids.components:
            obj._types.extend(t._types)
        obj._tids = tids
        obj._ext_rank = len(obj._tids.free) + 2*len(obj._tids.dum)
        obj._coeff = coeff
        obj._is_canon_bp = is_canon_bp
        return obj

    @staticmethod
    def _flatten(args):
        # splice the args of nested ``TensMul`` factors into the list
        a = []
        for arg in args:
            if isinstance(arg, TensMul):
                a.extend(arg.args)
            else:
                a.append(arg)
        return a

    @staticmethod
    def from_data(coeff, components, free, dum, **kw_args):
        tids = TIDS(components, free, dum)
        return TensMul.from_TIDS(coeff, tids, **kw_args)

    @staticmethod
    def from_TIDS(coeff, tids, **kw_args):
        return TensMul(coeff, *tids.get_tensors(), **kw_args)

    @property
    def free_args(self):
        return sorted(x[0] for x in self.free)

    @property
    def components(self):
        return self._tids.components[:]

    @property
    def free(self):
        return self._tids.free[:]

    @property
    def coeff(self):
        return self._coeff

    @property
    def dum(self):
        return self._tids.dum[:]

    @property
    def rank(self):
        return len(self.free)

    @property
    def types(self):
        return self._types[:]

    def equals(self, other):
        """
        Return ``True`` if ``self`` and ``other`` are equal after
        canonicalization (comparing coefficient, components, free and dummy
        indices).  A product with components never equals a scalar.
        """
        if other == 0:
            return self.coeff == 0
        other = sympify(other)
        if not isinstance(other, TensExpr):
            # only a component-free product (a bare coefficient) can equal a
            # scalar; an explicit test replaces an ``assert`` that raised
            # here and is stripped under ``python -O``
            if self.components:
                return False
            return self._coeff == other

        def _get_compar_comp(self):
            t = self.canon_bp()
            r = (get_coeff(t), tuple(t.components),
                 tuple(sorted(t.free)), tuple(sorted(t.dum)))
            return r

        return _get_compar_comp(self) == _get_compar_comp(other)

    def get_indices(self):
        """
        Returns the list of indices of the tensor

        The indices are listed in the order in which they appear in the
        component tensors.
        The dummy indices are given a name which does not collide with
        the names of the free indices.

        Examples
        ========

        >>> Lorentz = TensorIndexType('Lorentz', dummy_fmt='L')
        >>> m0, m1, m2 = tensor_indices('m0 m1 m2', Lorentz)
        >>> g = Lorentz.metric
        >>> p, q = tensorhead('p q', [Lorentz], [[1]])
        >>> t = p(m1)*g(m0, m2)
        >>> t.get_indices()
        [m1, m0, m2]

        """
        return self._tids.get_indices()

    def split(self):
        """
        Returns a list of tensors, whose product is ``self``

        Dummy indices contracted among different tensor components
        become free indices with the same name as the one used to
        represent the dummy indices.

        Examples
        ========

        >>> Lorentz = TensorIndexType('Lorentz', dummy_fmt='L')
        >>> a, b, c, d = tensor_indices('a b c d', Lorentz)
        >>> A, B = tensorhead('A B', [Lorentz]*2, [[1]*2])
        >>> t = A(a, b)*B(-b, c)
        >>> t
        A(a, L_0)*B(-L_0, c)
        >>> t.split()
        [A(a, L_0), B(-L_0, c)]

        """
        if self.args == ():
            return [self]
        splitp = []
        res = 1
        for arg in self.args:
            if isinstance(arg, Tensor):
                # scalar factors seen so far are attached to this tensor
                splitp.append(res*arg)
                res = 1
            else:
                res *= arg
        return splitp

    def __add__(self, other):
        return TensAdd(self, other)

    def __radd__(self, other):
        return TensAdd(other, self)

    def __sub__(self, other):
        return TensAdd(self, -other)

    def __rsub__(self, other):
        return TensAdd(other, -self)

    def __mul__(self, other):
        """
        Multiply two tensors using Einstein summation convention.

        If the two tensors have an index in common, one contravariant
        and the other covariant, in their product the indices are summed

        Examples
        ========

        >>> Lorentz = TensorIndexType('Lorentz', dummy_fmt='L')
        >>> m0, m1, m2 = tensor_indices('m0 m1 m2', Lorentz)
        >>> g = Lorentz.metric
        >>> p, q = tensorhead('p q', [Lorentz], [[1]])
        >>> t1 = p(m0)
        >>> t2 = q(-m0)
        >>> t1*t2
        p(L_0)*q(-L_0)

        """
        other = sympify(other)
        if not isinstance(other, TensExpr):
            coeff = self.coeff*other
            tmul = TensMul.from_TIDS(coeff, self._tids, is_canon_bp=self._is_canon_bp)
            return tmul
        if isinstance(other, TensAdd):
            return TensAdd(*[self*x for x in other.args])

        new_tids = self._tids*other._tids
        coeff = self.coeff*other.coeff
        tmul = TensMul.from_TIDS(coeff, new_tids)
        return tmul

    def __rmul__(self, other):
        other = sympify(other)
        coeff = other*self._coeff
        if isinstance(other, TensExpr):
            return TensMul.from_TIDS(coeff, self._tids)
        # as in the scalar branch of ``__mul__``: a scalar factor keeps a
        # canonical product canonical
        tmul = TensMul.from_TIDS(coeff, self._tids, is_canon_bp=self._is_canon_bp)
        return tmul

    def __truediv__(self, other):
        other = sympify(other)
        if isinstance(other, TensExpr):
            raise ValueError('cannot divide by a tensor')
        coeff = self._coeff/other
        tmul = TensMul.from_TIDS(coeff, self._tids, is_canon_bp=self._is_canon_bp)
        return tmul

    def __getitem__(self, item):
        return self.data[item]

    def sorted_components(self):
        """
        Returns a tensor with sorted components
        calling the corresponding method in a ``TIDS`` object.

        """
        new_tids, sign = self._tids.sorted_components()
        coeff = -self.coeff if sign == -1 else self.coeff
        t = TensMul.from_TIDS(coeff, new_tids)
        return t

    def perm2tensor(self, g, canon_bp=False):
        """
        Returns the tensor corresponding to the permutation ``g``

        For further details, see the method in ``TIDS`` with the same name.

        """
        return perm2tensor(self, g, canon_bp)

    def canon_bp(self):
        """
        Canonicalize using the Butler-Portugal algorithm for canonicalization
        under monoterm symmetries.

        Examples
        ========

        >>> Lorentz = TensorIndexType('Lorentz', dummy_fmt='L')
        >>> m0, m1, m2 = tensor_indices('m0 m1 m2', Lorentz)
        >>> A = tensorhead('A', [Lorentz]*2, [[2]])
        >>> t = A(m0, -m1)*A(m1, -m0)
        >>> t.canon_bp()
        -A(L_0, L_1)*A(-L_0, -L_1)
        >>> t = A(m0, -m1)*A(m1, -m2)*A(m2, -m0)
        >>> t.canon_bp()
        0

        """
        if self._is_canon_bp:
            return self
        if not self.components:
            return self
        t = self.sorted_components()
        g, dummies, msym, v = t._tids.canon_args()
        can = canonicalize(g, dummies, msym, *v)
        if can == 0:
            return Integer(0)
        tmul = t.perm2tensor(can, True)
        return tmul

    def contract_delta(self, delta):
        t = self.contract_metric(delta)
        return t

    def contract_metric(self, g):
        """
        Raise or lower indices with the metric ``g``

        Parameters
        ==========

        g : metric

        See Also
        ========

        TensorIndexType

        Examples
        ========

        >>> Lorentz = TensorIndexType('Lorentz', dummy_fmt='L')
        >>> m0, m1, m2 = tensor_indices('m0 m1 m2', Lorentz)
        >>> g = Lorentz.metric
        >>> p, q = tensorhead('p q', [Lorentz], [[1]])
        >>> t = p(m0)*q(m1)*g(-m0, -m1)
        >>> t.canon_bp()
        metric(L_0, L_1)*p(-L_0)*q(-L_1)
        >>> t.contract_metric(g).canon_bp()
        p(L_0)*q(-L_0)

        """
        tids, sign = get_tids(self).contract_metric(g)
        res = TensMul.from_TIDS(sign*self.coeff, tids)
        return res

    def substitute_indices(self, *index_tuples):
        return substitute_indices(self, *index_tuples)

    def fun_eval(self, *index_tuples):
        """
        Return a tensor with free indices substituted according to ``index_tuples``

        ``index_tuples`` are tuples ``(old_index, new_index)``; free indices
        not listed are kept.

        Examples
        ========

        >>> Lorentz = TensorIndexType('Lorentz', dummy_fmt='L')
        >>> i, j, k, l = tensor_indices('i j k l', Lorentz)
        >>> A, B = tensorhead('A B', [Lorentz]*2, [[1]*2])
        >>> t = A(i, k)*B(-k, -j)
        >>> t
        A(i, L_0)*B(-L_0, -j)
        >>> t.fun_eval((i, k), (-j, l))
        A(k, L_0)*B(-L_0, l)

        """
        free = self.free
        free1 = []
        for j, ipos, cpos in free:
            # search j in index_tuples
            for i, v in index_tuples:
                if i == j:
                    free1.append((v, ipos, cpos))
                    break
            else:
                free1.append((j, ipos, cpos))
        return TensMul.from_data(self.coeff, self.components, free1, self.dum)

    def __call__(self, *indices):
        """Returns tensor product with ordered free indices replaced by ``indices``

        Examples
        ========

        >>> D = Symbol('D')
        >>> Lorentz = TensorIndexType('Lorentz', dim=D, dummy_fmt='L')
        >>> i0, i1, i2, i3, i4 = tensor_indices('i0:5', Lorentz)
        >>> g = Lorentz.metric
        >>> p, q = tensorhead('p q', [Lorentz], [[1]])
        >>> t = p(i0)*q(i1)*q(-i1)
        >>> t(i1)
        p(i1)*q(L_0)*q(-L_0)

        """
        free_args = self.free_args
        indices = list(indices)
        if [x._tensortype for x in indices] != [x._tensortype for x in free_args]:
            raise ValueError('incompatible types')
        if indices == free_args:
            return self
        t = self.fun_eval(*list(zip(free_args, indices)))

        # object is rebuilt in order to make sure that all contracted indices
        # get recognized as dummies, but only if there are contracted indices.
        if len({i if i.is_up else -i for i in indices}) != len(indices):
            return t.func(*t.args)
        return t

    def _print(self):
        args = self.args

        def get_str(arg):
            return str(arg) if arg.is_Atom or isinstance(arg, TensExpr) else f'({arg!s})'

        if not args:
            # no arguments is equivalent to "1", i.e. TensMul().
            # If tensors are constructed correctly, this should never occur.
            return '1'
        if self.coeff == -1:
            # expressions like "-A(a)"
            return '-'+'*'.join([get_str(arg) for arg in args[1:]])

        # prints expressions like "A(a)", "3*A(a)", "(1+x)*A(a)"
        return '*'.join([get_str(arg) for arg in self.args])

    @property
    def data(self):
        # the stored data, scaled by the coefficient; ``None`` if no data
        # is registered
        dat = _tensor_data_substitution_dict[self]
        if dat is not None:
            return self.coeff * dat
        return None

    def __iter__(self):
        # evaluate ``data`` once: each access does a lookup and a scaling
        data = self.data
        if data is None:
            raise ValueError('No iteration on abstract tensors')
        return data.flatten().__iter__()
3590
3591
def canon_bp(p):
    """Butler-Portugal canonicalization.

    Tensor expressions are delegated to their own ``canon_bp`` method; any
    other object is returned unchanged.

    """
    return p.canon_bp() if isinstance(p, TensExpr) else p
3597
3598
def tensor_mul(*a):
    """Product of tensors.

    The factors are multiplied from left to right.  With no factors at all
    the empty product ``TensMul.from_data(Integer(1), [], [], [])`` is
    returned.

    """
    if a:
        # Left fold with plain ``*`` (never ``*=``), so no operand is
        # mutated in place.
        return functools.reduce(lambda acc, factor: acc*factor, a[1:], a[0])
    return TensMul.from_data(Integer(1), [], [], [])
3607
3608
def riemann_cyclic_replace(t_r):
    """Replace Riemann tensor with an equivalent expression.

    ``R(m,n,p,q) -> 2/3*R(m,n,p,q) - 1/3*R(m,q,n,p) + 1/3*R(m,p,n,q)``

    The names ``m, n, p, q`` are the free indices of ``t_r`` ordered by the
    second field of each ``free`` entry; exactly four are expected.

    """
    ordered = sorted(t_r.free, key=lambda entry: entry[1])
    m, n, p, q = [entry[0] for entry in ordered]
    kept = Rational(2, 3)*t_r
    # R(m, q, n, p); note that ``-Rational(1, 3)*X`` is ``(-1/3)*X``.
    rotated = - Rational(1, 3)*t_r.substitute_indices((m, m), (n, q), (p, n), (q, p))
    # R(m, p, n, q)
    swapped = Rational(1, 3)*t_r.substitute_indices((m, m), (n, p), (p, n), (q, q))
    return kept + rotated + swapped
3622
3623
def riemann_cyclic(t2):
    """
    Replace each Riemann tensor with an equivalent expression
    satisfying the cyclic identity.

    This trick is discussed in the reference guide to Cadabra.

    Every factor of every term is passed through ``riemann_cyclic_replace``,
    which unpacks four free indices per factor.  The rebuilt sum is put in
    canonical form unless it is falsy (e.g. zero).

    Examples
    ========

    >>> Lorentz = TensorIndexType('Lorentz', dummy_fmt='L')
    >>> i, j, k, l = tensor_indices('i j k l', Lorentz)
    >>> R = tensorhead('R', [Lorentz]*4, [[2, 2]])
    >>> t = R(i, j, k, l)*(R(-i, -j, -k, -l) - 2*R(-i, -k, -j, -l))
    >>> riemann_cyclic(t)
    0

    """
    # A single product is treated as a one-term sum.
    terms = [t2] if isinstance(t2, (TensMul, Tensor)) else t2.args
    factor_lists = [term.split() for term in terms]
    replaced = [[riemann_cyclic_replace(factor) for factor in factors]
                for factors in factor_lists]
    products = [tensor_mul(*factors) for factors in replaced]
    total = TensAdd(*products)
    return canon_bp(total) if total else total
3654
3655
def get_indices(t):
    """Return ``t.get_indices()`` for tensor expressions, else an empty tuple."""
    return t.get_indices() if isinstance(t, TensExpr) else ()
3660
3661
def get_tids(t):
    """Return the ``TIDS`` of ``t``; non-tensor objects get an empty ``TIDS``."""
    if not isinstance(t, TensExpr):
        return TIDS([], [], [])
    return t._tids
3666
3667
def get_coeff(t):
    """Return the scalar coefficient of ``t``.

    A bare ``Tensor`` has coefficient 1, a ``TensMul`` reports its ``coeff``,
    and a non-tensor object is its own coefficient.

    Raises
    ======

    ValueError
        For any other tensor expression (e.g. a sum), which carries no
        single coefficient.

    """
    if isinstance(t, Tensor):
        return Integer(1)
    elif isinstance(t, TensMul):
        return t.coeff
    elif isinstance(t, TensExpr):
        raise ValueError('no coefficient associated to this tensor expression')
    else:
        return t
3676
3677
def contract_metric(t, g):
    """Delegate to ``t.contract_metric(g)``; non-tensor objects are returned as is."""
    return t.contract_metric(g) if isinstance(t, TensExpr) else t
3682
3683
def perm2tensor(t, g, canon_bp=False):
    """
    Returns the tensor corresponding to the permutation ``g``

    For further details, see the method in ``TIDS`` with the same name.

    Non-tensor objects are returned unchanged.  The sign is read from the
    last entry of ``g``: when it differs from ``len(g) - 1`` the coefficient
    of ``t`` is negated.

    """
    if not isinstance(t, TensExpr):
        return t
    tids = get_tids(t).perm2tensor(g, canon_bp)
    coeff = get_coeff(t)
    coeff = -coeff if g[-1] != len(g) - 1 else coeff
    return TensMul.from_TIDS(coeff, tids, is_canon_bp=canon_bp)
3699
3700
def substitute_indices(t, *index_tuples):
    """
    Return a tensor with free indices substituted according to ``index_tuples``

    ``index_tuples`` is a sequence of tuples ``(old_index, new_index)``.

    A free index of ``t`` matches ``old_index`` when both have the same name
    and the same index type, regardless of whether they are up or down; the
    first matching tuple is used.  If the free index and ``old_index`` have
    the same variance it is replaced by ``new_index``, otherwise by
    ``-new_index``.  Free indices matching no tuple are kept as they are, and
    the contracted (dummy) indices of ``t`` are not touched.  Objects that
    are not tensor expressions are returned unchanged.

    Note: this function neither raises nor lowers the indices, it only
    replaces their symbol.

    Examples
    ========

    >>> Lorentz = TensorIndexType('Lorentz', dummy_fmt='L')
    >>> i, j, k, l = tensor_indices('i j k l', Lorentz)
    >>> A, B = tensorhead('A B', [Lorentz]*2, [[1]*2])
    >>> t = A(i, k)*B(-k, -j)
    >>> t
    A(i, L_0)*B(-L_0, -j)
    >>> t.substitute_indices((i, j), (j, k))
    A(j, L_0)*B(-L_0, -k)

    """
    if not isinstance(t, TensExpr):
        return t
    free1 = []
    for j, ipos, cpos in t.free:
        replacement = j
        for old, v in index_tuples:
            # Match on name and index type only; variance is handled by
            # flipping the sign of the replacement when it differs.
            if old._name == j._name and old._tensortype == j._tensortype:
                replacement = v if old._is_up == j._is_up else -v
                break
        free1.append((replacement, ipos, cpos))

    return TensMul.from_data(t.coeff, t.components, free1, t.dum)
3739