1from sympy import Symbol, Number, sympify
2from sympy import MutableDenseNDimArray, S
3from sympy.tensor.tensor import (Tensor, TensExpr, TensAdd, TensMul,
4                                 TensorIndex)
5
6
class PartialDerivative(TensExpr):
    """
    Partial derivative for tensor expressions.

    The first argument is the expression being differentiated, the
    remaining arguments are the variables to differentiate with respect to.
    Nested ``PartialDerivative`` objects are flattened on construction: the
    variables of the inner derivative are placed before the new ones.

    Examples
    ========

    >>> from sympy.tensor.tensor import TensorIndexType, TensorHead
    >>> from sympy.tensor.toperators import PartialDerivative
    >>> from sympy import symbols
    >>> L = TensorIndexType("L")
    >>> A = TensorHead("A", [L])
    >>> i, j = symbols("i j")

    >>> expr = PartialDerivative(A(i), A(j))
    >>> expr
    PartialDerivative(A(i), A(j))

    The ``PartialDerivative`` object behaves like a tensorial expression:

    >>> expr.get_indices()
    [i, -j]

    Indices can be contracted:

    >>> expr = PartialDerivative(A(i), A(i))
    >>> expr
    PartialDerivative(A(L_0), A(L_0))
    >>> expr.get_indices()
    [L_0, -L_0]
    """

    def __new__(cls, expr, *variables):
        """
        Build the derivative of ``expr`` with respect to ``variables``.

        ``expr`` is sympified, and indices shared between the expression
        and the variables are contracted before the object is created.
        """
        # Flatten: a derivative of a derivative becomes a single derivative
        # whose variables are the inner ones followed by the outer ones.
        if isinstance(expr, PartialDerivative):
            variables = expr.variables + variables
            expr = expr.expr

        args, indices, free, dum = cls._contract_indices_for_derivative(
            S(expr), variables)

        obj = TensExpr.__new__(cls, *args)

        # Cache the index structure computed by the contraction.
        obj._indices = indices
        obj._free = free
        obj._dum = dum
        return obj

    @property
    def coeff(self):
        """Coefficient of the expression; always ``S.One`` for this class."""
        return S.One

    @property
    def nocoeff(self):
        """The expression without its coefficient; this is ``self``."""
        return self

    @classmethod
    def _contract_indices_for_derivative(cls, expr, variables):
        """
        Contract the indices of ``expr`` with those of ``variables``.

        The free indices of every ``Tensor`` variable are negated before
        calling ``TensMul._tensMul_contract_indices`` and negated again on
        the returned arguments, so that the variables keep their original
        valence in the stored arguments.

        Variables that are neither ``Tensor`` nor ``Symbol`` instances are
        not forwarded to the contraction, i.e. they are dropped.

        Returns the tuple ``(args, indices, free, dum)``.
        """
        variables_opposite_valence = []

        for variable in variables:
            if isinstance(variable, Tensor):
                flipped = {k: -k for k in variable.get_free_indices()}
                variables_opposite_valence.append(variable.xreplace(flipped))
            elif isinstance(variable, Symbol):
                variables_opposite_valence.append(variable)

        args, indices, free, dum = TensMul._tensMul_contract_indices(
            [expr] + variables_opposite_valence, replace_indices=True)

        # args[0] is the differentiated expression; only the variables
        # (positions 1 and up) get their valence flipped back.
        for pos in range(1, len(args)):
            arg = args[pos]
            if isinstance(arg, Tensor):
                args[pos] = arg.xreplace(
                    {k: -k for k in arg.get_free_indices()})

        return args, indices, free, dum

    def doit(self, **hints):
        """
        Return a new object of the same type with the indices contracted.

        ``hints`` is accepted for compatibility with the keyword arguments
        of the generic ``doit`` protocol (for instance ``deep=False``); the
        hints are not used by this method.
        """
        args, indices, free, dum = self._contract_indices_for_derivative(
            self.expr, self.variables)

        obj = self.func(*args)
        obj._indices = indices
        obj._free = free
        obj._dum = dum

        return obj

    def _expand_partial_derivative(self):
        """
        Distribute the derivative over sums and products.

        * Returns ``S.Zero`` when the contracted expression has no free
          symbols.
        * For a ``TensAdd`` the derivative is applied to every term.
        * For a ``TensMul`` and a single variable the product rule is
          applied; factors that sympify to a ``Number`` are not
          differentiated.
        * For a ``TensMul`` and several variables the derivative is applied
          one variable at a time.
        * Any other expression is returned as the re-contracted derivative.
        """
        args, indices, free, dum = self._contract_indices_for_derivative(
            self.expr, self.variables)

        obj = self.func(*args)
        obj._indices = indices
        obj._free = free
        obj._dum = dum

        if not args[0].free_symbols:
            return S.Zero

        inner = obj.expr

        if isinstance(inner, TensAdd):
            # take care of sums of multi PDs
            return inner.func(*[
                self.func(term, *obj.variables)._expand_partial_derivative()
                for term in inner.args])

        if isinstance(inner, TensMul):
            # take care of products of multi PDs
            if len(obj.variables) == 1:
                # derivative with respect to single variable
                factors = list(inner.args)
                terms = []
                for pos, factor in enumerate(factors):
                    if isinstance(sympify(factor), Number):
                        # a number coefficient is not considered for
                        # expansion of PartialDerivative
                        continue
                    d = self.func(
                        factor, *obj.variables)._expand_partial_derivative()
                    terms.append(
                        TensMul(*(factors[:pos] + [d] + factors[pos + 1:])))
                return TensAdd.fromiter(terms)

            # derivative with respect to multiple variables
            # decompose:
            # partial(expr, (u, v))
            # = partial(partial(expr, u).doit(), v).doit()
            result = inner  # init with expr
            for v in obj.variables:
                # then throw PD on it
                result = self.func(result, v)._expand_partial_derivative()
            return result

        return obj

    def _perform_derivative(self):
        """
        Evaluate the derivative, one variable at a time.

        While the intermediate result is a ``TensExpr`` its
        ``_eval_partial_derivative`` is used. Otherwise
        ``_eval_derivative`` is used if the variable has a true
        ``_diff_wrt``, and the result becomes ``S.Zero`` if it has not.
        """
        result = self.expr
        for v in self.variables:
            if isinstance(result, TensExpr):
                result = result._eval_partial_derivative(v)
            elif v._diff_wrt:
                result = result._eval_derivative(v)
            else:
                result = S.Zero
        return result

    def get_indices(self):
        """Return the indices stored when the object was built."""
        return self._indices

    def get_free_indices(self):
        """
        Return the free indices.

        The stored ``_free`` entries are sorted by their second component
        and the first component of each entry is returned.
        """
        ordered = sorted(self._free, key=lambda entry: entry[1])
        return [entry[0] for entry in ordered]

    def _replace_indices(self, repl):
        """
        Return a new derivative with indices replaced according to ``repl``.

        ``repl`` is applied as given to the expression. For the variables
        both keys and values are negated first.
        """
        expr = self.expr.xreplace(repl)
        mirrored = {-k: -v for k, v in repl.items()}
        variables = [v.xreplace(mirrored) for v in self.variables]
        return self.func(expr, *variables)

    @property
    def expr(self):
        """The expression being differentiated (first argument)."""
        return self.args[0]

    @property
    def variables(self):
        """The differentiation variables (all arguments after the first)."""
        return self.args[1:]

    def _extract_data(self, replacement_dict):
        """
        Compute the component array of the derivative.

        ``replacement_dict`` is passed unchanged to ``_extract_data`` of the
        expression and of every variable. For each variable the current
        array is differentiated with ``derive_by_array``, the coefficients
        split off from the variable components are divided out, and the new
        index is either contracted with a matching index of the expression
        or appended to the index list.

        Returns the tuple ``(indices, array)``.
        """
        from .array import derive_by_array, tensorcontraction
        indices, array = self.expr._extract_data(replacement_dict)
        for variable in self.variables:
            var_indices, var_array = variable._extract_data(replacement_dict)
            var_indices = [-i for i in var_indices]
            # Split every component of the variable into coefficient and
            # remaining factor; the derivative is taken w.r.t. the latter.
            coeff_array, var_array = zip(*[i.as_coeff_Mul() for i in var_array])
            array = derive_by_array(array, var_array)
            array = array.as_mutable()  # type: MutableDenseNDimArray
            varindex = var_indices[0]  # type: TensorIndex
            # Remove coefficients of base vector:
            # the code treats axis 0 of ``array`` as the derivative axis.
            coeff_index = [0] + [slice(None)] * len(indices)
            for i, coeff in enumerate(coeff_array):
                coeff_index[0] = i
                array[tuple(coeff_index)] /= coeff
            if -varindex in indices:
                # The derivative index matches an index of the expression:
                # contract axis 0 with that index (shifted by the new axis).
                pos = indices.index(-varindex)
                array = tensorcontraction(array, (0, pos+1))
                indices.pop(pos)
            else:
                # NOTE(review): the new index is appended at the END of
                # ``indices`` while the code above treats axis 0 as the
                # derivative axis -- confirm that callers do not assume
                # ``indices`` and the array axes are in the same order.
                indices.append(varindex)
        return indices, array