1from __future__ import division, print_function, unicode_literals
2
3import re
4
5from dune.common.utility import isInteger
6
7from ufl import replace
8from ufl.log import UFLException
9from ufl.core.expr import Expr
10
11from dune.source.builtin import get, make_shared
12from dune.source.cplusplus import UnformattedExpression
13from dune.source.cplusplus import AccessModifier, Constructor, Declaration, Function, Method, NameSpace, Struct, TypeAlias, UnformattedBlock, Variable
14from dune.source.cplusplus import assign, construct, dereference, lambda_, nullptr, return_, this
15from dune.source.cplusplus import SourceWriter
16from dune.source.fem import declareFunctionSpace
17from dune.ufl.codegen import generateMethod
18
19from ufl.differentiation import Grad
20
21class EllipticModel:
22    version = "v1_1"
23    def __init__(self, dimDomain, dimRange, u, signature):
24        assert isInteger(dimRange)
25        self.dimDomain = dimDomain
26        self.dimRange = dimRange
27        self.trialFunction = u
28        self.init = []
29        self.vars = None
30        self.signature = signature + EllipticModel.version
31        self.field = "double"
32
33        self._constants = []
34        self._constantNames = {}
35        self._parameterNames = {}
36        self._coefficients = []
37
38        self.arg_r = Variable("RRangeType &", "result")
39        self.arg_dr = Variable("RJacobianRangeType &", "result")
40
41        self.arg_x = Variable("const Point &", "x")
42        self.arg_u = Variable("const DRangeType &", "u")
43        self.arg_du = Variable("const DJacobianRangeType &", "du")
44        self.arg_d2u = Variable("const DHessianRangeType &", "d2u")
45        self.arg_ubar = Variable("const DRangeType &", "ubar")
46        self.arg_dubar = Variable("const DJacobianRangeType &", "dubar")
47        self.arg_d2ubar = Variable("const DHessianRangeType &", "d2ubar")
48
49        self.arg_i = Variable("const IntersectionType &", "intersection")
50        self.arg_bndId = Variable("int", "bndId")
51
52        self.source = [assign(self.arg_r, construct("RRangeType", 0))]
53        self.linSource = [assign(self.arg_r, construct("RRangeType", 0))]
54        self.diffusiveFlux = [assign(self.arg_dr, construct("RJacobianRangeType", 0))]
55        self.linDiffusiveFlux = [assign(self.arg_dr, construct("RJacobianRangeType", 0))]
56        self.fluxDivergence = [assign(self.arg_r, construct("RRangeType", 0))]
57        self.alpha = [assign(self.arg_r, construct("RRangeType", 0))]
58        self.linAlpha = [assign(self.arg_r, construct("RRangeType", 0))]
59
60        self.hasDirichletBoundary = False
61        self.hasNeumanBoundary = False
62        self.isDirichletIntersection = [return_(False)]
63        self.dirichlet = [assign(self.arg_r, construct("RRangeType", 0))]
64        self.symmetric = False
65
66        self.baseName = "elliptic"
67        self.modelWrapper = "DiffusionModelWrapper< Model >"
68
69    def predefineCoefficients(self,predefined,x):
70        for coefficient, idx in self.coefficients.items():
71            for derivative in self.coefficient(idx, x):
72                predefined[coefficient] = derivative
73                coefficient = Grad(coefficient)
74        predefined.update({c: self.constant(i) for c, i in self.constants.items()})
75
76    def addCoefficient(self, dimRange, typeName, name=None, field="double"):
77        idx = len(self._coefficients)
78        self._coefficients.append({'typeName':typeName, 'dimRange': dimRange, 'name': name, 'field': field})
79        return idx
80
81    def addConstant(self, cppType, name=None, parameter=None):
82        idx = len(self._constants)
83        self._constants.append(cppType)
84        if name is not None:
85            self._constantNames[name] = idx
86        if parameter is not None:
87            self._parameterNames[parameter] = idx
88        return idx
89
90    def cppIdentifier(self,name,base,idx):
91        if re.match('^[a-zA-Z_][a-zA-Z0-9_]*$', name) is None:
92            return base+str(idx)
93        else:
94            return name
95    def cppTypeIdentifier(self,name,base,idx):
96        ret = self.cppIdentifier(name,base,idx)
97        return ret[0].upper() + ret[1:]
98    def cppVarIdentifier(self,name,base,idx):
99        ret = self.cppIdentifier(name,base,idx)
100        return ret[0].lower() + ret[1:]
101
102    def constant(self, idx):
103        return UnformattedExpression(self._constants[idx], 'constant< ' + str(idx) + ' >()')
104
105    def constant(self, idx):
106        return UnformattedExpression(self._constants[idx], 'constant< ' + str(idx) + ' >()')
107
108    def coefficient(self, idx, x):
109        coefficient = []
110        for t, n in (('RangeType', 'evaluate'), ('JacobianRangeType', 'jacobian'), ('HessianRangeType', 'hessian')):
111            result = Variable('typename std::tuple_element_t< ' + str(idx) + ', CoefficientFunctionSpaceTupleType >::' + t, 'result')
112            code = [Declaration(result),
113                    UnformattedExpression('void', 'std::get< ' + str(idx) + ' >( coefficients_ ).' + n + '( x, ' + result.name + ' )', uses=[result]),
114                    return_(result)]
115            coefficient += [lambda_(capture=[this], args=['auto x'], code=code)(x)]
116        return coefficient
117
118    @property
119    def hasCoefficients(self):
120        return bool(self._coefficients)
121
122    @property
123    def hasConstants(self):
124        return bool(self._constants)
125
126    def code(self, name=None, targs=None):
127        if targs is None:
128            targs = []
129        if name is None:
130            name = 'Model'
131        constants_ = Variable('std::tuple< ' + ', '.join('std::shared_ptr< ' + c  + ' >' for c in self._constants) + ' >', 'constants_')
132        # coefficients_ = Variable('std::tuple< ' + ', '.join(c['name'] if c['name'] is not None else 'Coefficient' + str(i) for i, c in enumerate(self._coefficients)) + ' >', 'coefficients_')
133        coefficients_ = Variable('std::tuple< ' + ', '.join(\
134                'Dune::Fem::ConstLocalFunction<' + self.cppTypeIdentifier(c['name'],"coefficient",i) + '> ' for i, c in enumerate(self._coefficients)) + ' >', 'coefficients_')
135        entity_ = Variable('const EntityType *', 'entity_')
136
137        # code = Struct(name, targs=(['class GridPart'] + ['class ' + c['name'] if c['name'] is not None else 'class Coefficient' + str(i) for i, c in enumerate(self._coefficients)] + targs))
138        code = Struct(name, targs=(['class GridPart'] + ['class ' + self.cppTypeIdentifier(c['name'],"coefficient",i) for i, c in enumerate(self._coefficients)] + targs))
139
140        code.append(TypeAlias("GridPartType", "GridPart"))
141        code.append(TypeAlias("EntityType", "typename GridPart::template Codim< 0 >::EntityType"))
142        code.append(TypeAlias("IntersectionType", "typename GridPart::IntersectionType"))
143
144        code.append(declareFunctionSpace("typename GridPartType::ctype", SourceWriter.cpp_fields(self.field), UnformattedExpression("int", "GridPartType::dimensionworld"), self.dimDomain,
145            name="DFunctionSpaceType",prefix="D",
146            dimDomainName="dimDomain", dimRangeName="dimD"
147            ))
148        code.append(declareFunctionSpace("typename GridPartType::ctype", SourceWriter.cpp_fields(self.field), UnformattedExpression("int", "GridPartType::dimensionworld"), self.dimRange,
149            name="RFunctionSpaceType",prefix="R",
150            dimDomainName=None, dimRangeName="dimR"
151            ))
152        code.append(Declaration(Variable("const int", "dimLocal"), initializer=UnformattedExpression("int", "GridPartType::dimension"), static=True))
153
154        if self.hasConstants:
155            code.append(TypeAlias("ConstantType", "typename std::tuple_element_t< i, " + constants_.cppType + " >::element_type", targs=["std::size_t i"]))
156            code.append(Declaration(Variable("const std::size_t", "numConstants"), initializer=len(self._constants), static=True))
157
158        if self.hasCoefficients:
159            code.append(TypeAlias('CoefficientType', 'std::tuple_element_t< i, ' + coefficients_.cppType + ' >', targs=['std::size_t i']))
160            # coefficientSpaces = ["Dune::Fem::FunctionSpace< DomainFieldType, " + SourceWriter.cpp_fields(c['field']) + ", dimDomain, " + str(c['dimRange']) + " >" for c in self._coefficients]
161            coefficientSpaces = ["typename CoefficientType<"+str(i)+">::FunctionSpaceType" for i,c in enumerate(self._coefficients)]
162            code.append(TypeAlias("CoefficientFunctionSpaceTupleType", "std::tuple< " + ", ".join(coefficientSpaces) + " >"))
163
164        arg_param = Variable("const Dune::Fem::ParameterReader &", "parameter")
165        args = [Declaration(arg_param, initializer=UnformattedExpression('const ParameterReader &', 'Dune::Fem::Parameter::container()'))]
166        init = None
167        if self.hasCoefficients:
168            # args = [Variable("const " + c['name'] if c['name'] is not None else "const Coefficient" + str(i) + " &", "coefficient" + str(i)) for i, c in enumerate(self._coefficients)] + args
169            args = [Variable("const " + self.cppTypeIdentifier(c['name'],"coefficient",i) + " &", "coefficient" + str(i)) for i, c in enumerate(self._coefficients)] + args
170            init = ["coefficients_(" + ",".\
171                  join("CoefficientType<"+str(i)+">"\
172                        +"(coefficient" + str(i)+")" for i, c in enumerate(self._coefficients)) + " )"]
173        constructor = Constructor(args=args, init=init)
174        constructor.append([assign(get(str(i))(constants_), make_shared(c)()) for i, c in enumerate(self._constants)])
175        for name, idx in self._parameterNames.items():
176            constructor.append(assign(dereference(get(idx)(constants_)), UnformattedExpression("auto", arg_param.name + '.getValue< ' + self._constants[idx] + ' >( "' + name + '" )', uses=[arg_param])))
177        code.append(constructor)
178
179        init = ['entity_ = &entity;']
180        init += ['std::get< ' + str(i) + ' >( ' + coefficients_.name + ').bind( entity );' for i, c in enumerate(self._coefficients)]
181        init = [UnformattedBlock(init)] + self.init + [return_(True)]
182        code.append(Method('bool', 'init', args=['const EntityType &entity'], code=init, const=True))
183
184        code.append(Method('const EntityType &', 'entity', code=return_(dereference(entity_)), const=True))
185        code.append(Method('std::string', 'name', const=True, code=return_(UnformattedExpression('const char *', '"' + name + '"'))))
186
187        code.append(TypeAlias("BoundaryIdProviderType", "Dune::Fem::BoundaryIdProvider< typename GridPartType::GridType >"))
188        code.append(Declaration(Variable("const bool", "symmetric"), initializer=self.symmetric, static=True))
189
190        code.append(Method('void', 'source', targs=['class Point'], args=[self.arg_x, self.arg_u, self.arg_du, self.arg_r], code=self.source, const=True))
191        code.append(Method('void', 'linSource', targs=['class Point'], args=[self.arg_ubar, self.arg_dubar, self.arg_x, self.arg_u, self.arg_du, self.arg_r], code=self.linSource, const=True))
192
193        code.append(Method('void', 'diffusiveFlux', targs=['class Point'], args=[self.arg_x, self.arg_u, self.arg_du, self.arg_dr], code=self.diffusiveFlux, const=True))
194        code.append(Method('void', 'linDiffusiveFlux', targs=['class Point'], args=[self.arg_ubar, self.arg_dubar, self.arg_x, self.arg_u, self.arg_du, self.arg_dr], code=self.linDiffusiveFlux, const=True))
195
196        code.append(Method('void', 'fluxDivergence', targs=['class Point'], args=[self.arg_x, self.arg_u, self.arg_du, self.arg_d2u, self.arg_r], code=self.fluxDivergence, const=True))
197
198        code.append(Method('void', 'alpha', targs=['class Point'], args=[self.arg_x, self.arg_u, self.arg_r], code=self.alpha, const=True))
199        code.append(Method('void', 'linAlpha', targs=['class Point'], args=[self.arg_ubar, self.arg_x, self.arg_u, self.arg_r], code=self.linAlpha, const=True))
200
201        code.append(Method('bool', 'hasNeumanBoundary', const=True, code=return_(self.hasNeumanBoundary)))
202
203        code.append(TypeAlias("DirichletComponentType","std::array<int,"+str(self.dimRange)+">"))
204        code.append(Method('bool', 'hasDirichletBoundary', const=True, code=return_(self.hasDirichletBoundary)))
205        code.append(Method('bool', 'isDirichletIntersection', args=[self.arg_i, 'DirichletComponentType &dirichletComponent'], code=self.isDirichletIntersection, const=True))
206        code.append(Method('void', 'dirichlet', targs=['class Point'], args=[self.arg_bndId, self.arg_x, self.arg_r], code=self.dirichlet, const=True))
207
208        if self.hasConstants:
209            code.append(Method("const ConstantType< i > &", "constant", targs=["std::size_t i"], code=return_(dereference(get("i")(constants_))), const=True))
210            code.append(Method("ConstantType< i > &", "constant", targs=["std::size_t i"], code=return_(dereference(get("i")(constants_)))))
211
212        if self.hasCoefficients:
213            code.append(Method("const CoefficientType< i > &", "coefficient", targs=["std::size_t i"], code=return_(get("i")(coefficients_)), const=True))
214            code.append(Method("CoefficientType< i > &", "coefficient", targs=["std::size_t i"], code=return_(get("i")(coefficients_))))
215
216        for n, i in self._constantNames.items():
217            t = self._constants[i]
218            code.append(Method('const ' + t + ' &', n, code=return_(dereference(get(i)(constants_))), const=True))
219            code.append(Method(t + ' &', n, code=return_(dereference(get(i)(constants_)))))
220
221        code.append(AccessModifier("private"))
222        code.append(Declaration(entity_, nullptr, mutable=True))
223        if self.hasConstants:
224            code.append(Declaration(constants_, mutable=True))
225        if self.hasCoefficients:
226            code.append(Declaration(coefficients_, mutable=True))
227        return code
228
229    #def write(self, sourceWriter, name='Model', targs=[]):
230    #    sourceWriter.emit(self.code(name=name, targs=targs))
231
232    def exportSetConstant(self, sourceWriter, modelClass='Model', wrapperClass='ModelWrapper'):
233        sourceWriter.emit(TypeAlias('ModelType', modelClass))
234        sourceWriter.openFunction('std::size_t renumberConstants', args=['pybind11::handle &obj'])
235        sourceWriter.emit('std::string id = pybind11::str( obj );')
236        sourceWriter.emit('if( pybind11::hasattr(obj,"name") ) id = pybind11::str(obj.attr("name"));')
237        for name, number in self._constantNames.items():
238            sourceWriter.emit('if (id == "' + name + '") return ' + str(number) + ';')
239        sourceWriter.emit('throw pybind11::value_error("coefficient \'" + id + "\' has not been registered");')
240        sourceWriter.closeFunction()
241
242        sourceWriter.openFunction('void setConstant', targs=['std::size_t i'], args=['ModelType &model', 'pybind11::handle value'])
243        sourceWriter.emit('model.template constant< i >() = value.template cast< typename ModelType::ConstantType< i > >();')
244        sourceWriter.closeFunction()
245
246        sourceWriter.openFunction('auto DUNE_PRIVATE defSetConstant', targs=['std::size_t... i'], args=['std::index_sequence< i... >'])
247        sourceWriter.emit(TypeAlias('Dispatch', 'std::function< void( ModelType &model, pybind11::handle ) >'))
248        sourceWriter.emit('std::array< Dispatch, sizeof...( i ) > dispatch = {{ Dispatch( setConstant< i > )... }};')
249        sourceWriter.emit('')
250        sourceWriter.emit('return [ dispatch ] ( ' + wrapperClass + ' &model, pybind11::handle coeff, pybind11::handle value ) {')
251        sourceWriter.emit('    std::size_t k = renumberConstants( coeff );')
252        sourceWriter.emit('    if( k >= dispatch.size() )')
253        sourceWriter.emit('      throw std::range_error( "No such coefficient: "+std::to_string(k)+" >= "+std::to_string(dispatch.size()) );' )
254        sourceWriter.emit('    dispatch[ k ]( model, value );')
255        sourceWriter.emit('    return k;')
256        sourceWriter.emit('  };')
257        sourceWriter.closeFunction()
258
259    def export(self, sourceWriter, modelClass='Model', wrapperClass='ModelWrapper',nameSpace=''):
260        if self.hasConstants:
261            sourceWriter.emit('cls.def( "setConstant",'+nameSpace+'::defSetConstant( std::make_index_sequence< ' + modelClass + '::numConstants >() ) );')
262        coefficients = [('Dune::FemPy::VirtualizedGridFunction< GridPart, Dune::FieldVector< ' + SourceWriter.cpp_fields(c['field']) + ', ' + str(c['dimRange']) + ' > >')
263                        if not c['typeName'].startswith("Dune::Python::SimpleGridFunction") \
264                        else c['typeName'] \
265                for c in self._coefficients]
266        sourceWriter.emit('')
267        # TODO
268        sourceWriter.emit('cls.def( pybind11::init( [] ( ' + ', '.join( [] + ['const ' + c + ' &coefficient' + str(i) for i, c in enumerate(coefficients)]) + ' ) {')
269        if self.hasCoefficients:
270            sourceWriter.emit('  return new  ' + wrapperClass + '( ' + ', '.join('coefficient' + str(i) for i, c in enumerate(coefficients)) + ' );')
271        else:
272            sourceWriter.emit('  return new  ' + wrapperClass + '();')
273        #if self.coefficients:
274        #    sourceWriter.emit('  const int size = std::tuple_size<Coefficients>::value;')
275        #    sourceWriter.emit('  auto dispatch = defSetCoefficient( std::make_index_sequence<size>() );' )
276        #    sourceWriter.emit('  std::vector<bool> coeffSet(size,false);')
277        #    sourceWriter.emit('  for (auto item : coeff) {')
278        #    sourceWriter.emit('    int k = dispatch(instance, item.first, item.second); ')
279        #    sourceWriter.emit('    coeffSet[k] = true;')
280        #    sourceWriter.emit('  }')
281        #    sourceWriter.emit('  if ( !std::all_of(coeffSet.begin(),coeffSet.end(),[](bool v){return v;}) )')
282        #    sourceWriter.emit('    throw pybind11::key_error("need to set all coefficients during construction");')
283        if self.hasCoefficients:
284            sourceWriter.emit('  }), ' + ', '.join('pybind11::keep_alive< 1, ' + str(i) + ' >()' for i, c in enumerate(coefficients, start=2)) + ' );')
285        else:
286            sourceWriter.emit('  } ) );')
287
288    def codeCoefficient(self, code, coefficients, constants):
289        """find coefficients/constants in code string and do replacements
290        """
291        for name, value in coefficients.items():
292            if not any(name == c['name'] for c in self._coefficients):
293                self.addCoefficient(value.dimRange,value._typeName, name)
294        for name, dimRange in constants.items():
295            if name not in self._constantNames:
296                self.addConstant('Dune::FieldVector< double, ' + str(dimRange) + ' >', name)
297
298        if '@const:' in code:
299            codeCst = code.split('@const:')
300            import itertools
301            for name in set([''.join(itertools.takewhile(str.isalpha, str(c))) for c in codeCst[1:]]):
302                if name not in self._constantNames:
303                    cname = '@const:' + name
304                    afterName = code.split(cname)[1:]
305                    if afterName[0][0] == '[':
306                        beforeText = [an.split(']')[0].split('[')[1] for an in afterName]
307                        dimRange = max( [int(bt) for bt in beforeText] ) + 1
308                    else:
309                        dimRange = 1
310                    self.addConstant('Dune::FieldVector< double, ' + str(dimRange) + ' >', name)
311
312        for i, c in enumerate(self._coefficients):
313            jacname = '@jac:' + c['name']
314            if jacname in code:
315                varname = 'dc' + str(i)
316                code = code.replace(jacname, varname)
317                decl = 'CoefficientJacobianRangeType< ' + str(i) + ' > ' + varname + ';'
318                if not decl in code:
319                    code = decl + '\ncoefficient< ' + str(i) + ' >().jacobian( x, ' + varname + ' );' + code
320            gfname = '@gf:' + c['name']
321            if gfname in code:
322                varname = 'c' + str(i)
323                code = code.replace(gfname, varname)
324                decl = 'CoefficientRangeType< ' + str(i) + ' > c' + str(i) + ';'
325                if not decl in code:
326                    code = decl + '\ncoefficient< ' + str(i) + ' >().evaluate( x, ' + varname + ' );' + code
327
328        for name, i in self._constantNames.items():
329            cname = '@const:' + name
330            if cname in code:
331                varname = 'cc' + str(i)
332                code = code.replace(cname, varname)
333                init = 'const ' + self._constants[i] + ' &' + varname + ' = constant< ' + str(i) + ' >();'
334                if not init in code:
335                    code = init + code
336
337    def appendCode(self, key, code, **kwargs):
338        function = getattr(self, key)
339        coef = kwargs.pop("coefficients", {})
340        const = kwargs.pop("constants", {})
341        function.append(UnformattedBlock(self.codeCoefficient(code, coef, const)))
342        setattr(self, key, function)
343
344    def generateMethod(self, code, expr, *args, **kwargs):
345        if isinstance(expr, Expr):
346            try:
347                expr = replace(expr, self._replaceCoeff)
348            except UFLException:
349                pass
350        return generateMethod(code,expr,*args,**kwargs)
351