1from __future__ import division
2import math, sys
3from collections import OrderedDict
4import sympy as sp
5import numpy as np
6import numpy.matlib as ml
7import numpy.linalg as la
8import dk_templates
9from codelib import *
10import dk_simulator
11
def solver_set_defaults(solver_dict=None):
    """Return a copy of *solver_dict* with missing solver options filled in.

    The argument is never modified; options already present are kept.
    Defaults: method "hybr", factor 100.0, xtol sqrt(machine epsilon),
    maxfev 2000 and max_homotopy_iter 100.
    """
    defaults = (
        ("method", "hybr"),
        ("factor", 1.e2),
        ("xtol", math.sqrt(sys.float_info.epsilon)),
        ("maxfev", 2000),
        ("max_homotopy_iter", 100),
    )
    result = {} if solver_dict is None else solver_dict.copy()
    for key, value in defaults:
        result.setdefault(key, value)
    return result
23
24
class Structure(object):
    """Builder for a C++ ``struct`` declaration with an inline constructor.

    Members are declaration objects providing ``name``, ``pointer`` and
    ``generate(suffix)``; pointer members become constructor arguments.
    Optional extra pointer fields: ``info``/``nfev``/``fnorm`` (see
    add_nonlin_info) and a ``p_val`` array (see add_minmax_info).
    """

    def __init__(self, name):
        self.name = name
        self.members = []
        self.nonlin_info = False
        self.minmax_info = None

    def add(self, m):
        """Append the member declaration *m*."""
        self.members.append(m)

    def add_nonlin_info(self):
        """Add the ``int *info``, ``int *nfev`` and ``creal *fnorm`` fields."""
        self.nonlin_info = True

    def add_minmax_info(self, rows):
        """Add a ``p_val`` pointer field with *rows* rows (falsy *rows* disables it)."""
        self.minmax_info = rows

    def get_initializer(self, kw):
        """Return the C++ constructor-argument list ``"(a, b, ...)"``.

        *kw* maps every constructor-argument name (pointer members plus the
        enabled extra fields) to its initializer.  ``VariableAccess`` values
        are replaced by their address expression; other values are
        converted with ``str()``.  Returns "" when there are no arguments.

        Raises ValueError for an unknown name or for arguments left
        uninitialized (listed in sorted order).
        """
        pointer_members = [m for m in self.members if m.pointer]
        s = dict((m.name, i) for i, m in enumerate(pointer_members))
        n = len(s)
        # extra fields follow the pointer members, in generate_lines() order
        if self.nonlin_info:
            s['info'] = n
            s['nfev'] = n+1
            s['fnorm'] = n+2
            n += 3
        if self.minmax_info:
            s['p_val'] = n
            n += 1
        l = [None]*n
        s2 = set()
        for k, v in kw.items():
            try:
                i = s[k]
            except KeyError:
                raise ValueError("unknown struct member: %s" % k)
            if isinstance(v, VariableAccess):
                l[i] = str(v.address_of())
            else:
                # str() so non-string initializers don't break the join below
                l[i] = str(v)
            s2.add(k)
        missing = set(s) - s2
        if missing:
            # sorted: deterministic error message
            raise ValueError("no structure initialization for: %s" % ", ".join(sorted(missing)))
        if not l:
            return ""
        return "("+", ".join(l)+")"

    def generate_lines(self):
        """Return the struct declaration as a list of C++ source fragments."""
        arglist = [m.generate("_") for m in self.members if m.pointer]
        initlist = []
        for m in self.members:
            if m.pointer:
                initlist.append("%s(%s_)" % (m.name, m.name))
            else:
                initlist.append("%s()" % m.name)
        l = []
        l.append("struct %s {\n" % self.name)
        l += ["    %s;\n" % m for m in self.members]
        if self.nonlin_info:
            l.append("    int *info;\n")
            l.append("    int *nfev;\n")
            l.append("    creal *fnorm;\n")
            initlist += ["info(info_)", "nfev(nfev_)", "fnorm(fnorm_)"]
            arglist += ["int *info_", "int *nfev_", "creal *fnorm_"]
        if self.minmax_info:
            p_val = MatrixDeclaration('p_val', rows=self.minmax_info, cols=1, pointer=True, array=True)
            l.append("    %s;\n" % p_val)
            initlist.append("p_val(p_val_)")
            arglist.append(p_val.generate("_"))
        if initlist:
            l.append("    inline %(name)s(%(args)s): %(init)s {}\n" % dict(
                    name = self.name,
                    args = ", ".join(arglist),
                    init = ", ".join(initlist),
                    ))
        l.append("};")
        return l

    def __str__(self):
        return "".join(self.generate_lines())
105
106
class NonlinFunction(object):
    """Generate the C++ residual function of a (sub)block of the nonlinear system.

    The generated code evaluates the element equations into ``i`` (starting
    at offset ``start``) from the unknowns ``v`` and then writes the
    residual ``fvec = p + K * i - mv``, where ``mv`` holds the entries of
    ``v`` selected by ``eq.CZ`` (zero elsewhere).
    """

    def __init__(self, global_ns, neq, on):
        # global_ns: namespace receiving constant matrix definitions (e.g. K)
        # on: the matching NonlinSolver; only its input/output slices are used
        self.global_ns = global_ns
        self.neq = neq
        # without potentiometers (np == 0) K is a compile-time constant
        self.have_constant_matrices = (neq.eq.np == 0)
        self.locals = Namespace()
        self.input_slice = on.input_slice
        self.output_slice = on.output_slice

    def expr_list(self, v):
        """Return the block's equations with their variables substituted.

        A variable with global index ``idx`` is replaced by ``v[idx - off]``
        when it belongs to this block (``idx >= off``, the block start) and
        by the ``(*par.v)`` symbol with index ``idx`` otherwise.
        """
        off = 0
        f = self.neq.eq.f
        if self.neq.v_slice:
            off = self.neq.v_slice.start
            f = f[self.neq.v_slice]
        # variables of preceding blocks are addressed by their global index,
        # which can exceed the block size; size the vector for the whole system
        par_v = make_symbol_vector('(*par.v)', max(self.neq.nn, self.neq.eq.nonlin.nn))
        l = []
        for expr, vl, base in f:
            for var, idx in zip(vl, base):
                if idx >= off:
                    vv = v[idx-off]
                else:
                    vv = par_v[idx]
                expr = expr.subs(var, vv)
            l.append(expr)
        return l

    def generate(self, template):
        """Render *template* with the C++ body of the residual function."""
        neq = self.neq
        v_slice = neq.v_slice
        if v_slice:
            iblockV = self.input_slice
            oblockV = self.output_slice
            blockM = (v_slice, v_slice)
            start = self.output_slice.start
        else:
            iblockV = oblockV = None
            blockM = None
            start = 0
        v = make_symbol_vector('v', neq.nn)
        par_i = VectorAccess('i', param=True)
        l = []
        # i(start + k) = k-th element equation evaluated at v
        l += expr_list_to_ccode(str(par_i), self.expr_list(v), '(%d)', start)
        mv = MatrixDefinition('mv', rows=neq.nn, cols=1)
        Mfvec = MatrixDefinition('Mfvec', value='fvec', rows=neq.nn, cols=1)
        self.locals.add(mv)
        self.locals.add(Mfvec)
        l += self.locals.generate_lines()
        CZ = neq.eq.CZ
        K = neq.eq.K
        if v_slice:
            # restrict to this block (as expr_list does); without a slice the
            # whole vector/matrix applies -- for numpy data, indexing with
            # None would add an axis instead
            CZ = CZ[v_slice]
            if self.have_constant_matrices:
                K = K[blockM]
        l.append("mv << %s;\n" % ", ".join(
            ['v[%d]' % k if z else '0' for k, z in enumerate(CZ)]))
        if not matrix_is_identity(K):
            if self.have_constant_matrices:
                self.global_ns.add(MatrixDefinition('K', K))
                K = MatrixAccess('K')
            else:
                K = MatrixAccess('K', param=True, block=blockM, pointer=False)
            l.append("Mfvec = %(p)s + %(K)s * %(i)s - %(mv)s;\n" % dict(
                p = VectorAccess('p', param=True, block=iblockV),
                K = K,
                i = VectorAccess('i', param=True, block=oblockV),
                mv = VectorAccess('mv'),
                ))
        else:
            l.append("Mfvec = %(p)s + %(i)s - %(mv)s;\n" % dict(
                p = VectorAccess('p', param=True, block=iblockV),
                i = VectorAccess('i', param=True, block=oblockV),
                mv = VectorAccess('mv'),
                ))
        return template.render(dict(expression=join_with_indent(l)))
179
180
class NonlinFunctionCC(object):
    """Generate the C++ residual function for a partitioned nonlinear system.

    The unknowns are split at ``end = neq.cc_slice.start``.  The trailing
    ``neq.nn - end`` entries (the cc part) are the variables ``v`` handed
    to this function.  The outputs of the leading part are computed per
    sub-block: with spline tables when *extra_sources* is given, otherwise
    by calling each sub-block's ``nonlin_solve``.
    """

    def __init__(self, global_ns, neq, extra_sources):
        # global_ns: namespace receiving constant matrix definitions (Ku, Kl, ...)
        # extra_sources: None or a dict with "tables" and optional "shape_transform"
        self.global_ns = global_ns
        self.neq = neq
        self.extra_sources = extra_sources

    def generate(self, template):
        """Render *template* with the C++ body of the partitioned residual function."""
        neq = self.neq
        # without potentiometers (np == 0) the K blocks are compile-time constants
        have_constant_matrices = (neq.eq.np == 0)
        # first index of the cc part; indices [0, end) belong to the sub-blocks
        end = neq.cc_slice.start
        cc = neq.cc_slice.stop - neq.cc_slice.start
        # split p and i into the sub-block part (0) and the trailing cc part (1)
        p0_slice = slice(neq.p_slice.start, neq.p_slice.stop-cc)
        p1_slice = slice(neq.p_slice.stop-cc, neq.p_slice.stop)
        i0_slice = slice(neq.i_slice.start, neq.i_slice.stop-cc)
        i1_slice = slice(neq.i_slice.stop-cc, neq.i_slice.stop)
        # K rows [0, end) x cols [end, nn): multiplied with Mv for the sub-block inputs
        M0 = (slice(0, end), slice(end, neq.nn))
        # K rows [end, nn) x all cols: used for the cc residual
        M1 = (slice(end, neq.nn), slice(0, neq.nn))
        loc = Namespace()
        l = []
        # Mv maps the nn - end entries of v passed to this function
        loc.add(MatrixDefinition('Mv', value='v', rows=neq.nn-end, cols=1, aligned=False, const=True))
        # pt: sub-block inputs -- a standalone array for table evaluation,
        # otherwise a full-size vector of which the p0 block is written
        if self.extra_sources:
            loc.add(MatrixDefinition('pt', rows=p0_slice.stop-p0_slice.start, cols=1, array=True))
            pt = VectorAccess('pt')
        else:
            loc.add(MatrixDefinition('pt', rows=neq.nn, cols=1))
            pt = VectorAccess('pt', block=p0_slice)
        l += loc.generate_lines()
        if have_constant_matrices:
            #self.global_ns.add(MatrixDefinition('Ku', neq.eq.K[M0]))
            self.global_ns.add(MatrixDefinition('Ku', neq.Ku))
            Ku = MatrixAccess('Ku')
        else:
            Ku = MatrixAccess('K', param=True, block=M0, pointer=False)
        # pt = p0 + Ku * Mv
        l.append("%s = %s + %s * %s;\n" % (
            pt,
            VectorAccess('p', param=True, block=p0_slice),
            Ku,
            VectorAccess('Mv')))
        if self.extra_sources:
            # matrices of the optional shape transform become global definitions
            shape_transform = self.extra_sources.get("shape_transform", ())
            for mat in shape_transform:
                self.global_ns.add(mat)
            if shape_transform:
                # indices j, j + len(v of block 0), ...: the j-th v entry of each sub-block
                def nth(j):
                    idx = []
                    for bl in neq.subblocks:
                        ln = bl.v_slice.stop - bl.v_slice.start
                        idx.append(j)
                        j += ln
                    return idx
                # yields 0, n, 1, n+1, ..., n-1, 2n-1 (interleaves two halves)
                def spliced(n):
                    for a in range(n):
                        yield a
                        yield a+n
                nn2 = end//2
                # PP1 / PP0: second / first input entry of every sub-block
                loc.add(MatrixDeclaration('PP1', rows=nn2, cols=1, array=True))
                l += loc.generate_lines()
                l.append("PP1 << %s;\n" % ", ".join(["pt(%d)" % i for i in nth(1)]))
                loc.add(MatrixDeclaration('PP0', rows=nn2, cols=1, array=True))
                l += loc.generate_lines()
                l.append("PP0 << %s;\n" % ", ".join(["pt(%d)" % i for i in nth(0)]))
                l += loc.generate_lines()
                # input transform with coefficients quadratic in PP1 (Spm*/Ssm*/Sam*
                # presumably supplied via shape_transform -- confirm)
                l.append("pt.head<%d>() = (Spm1 * PP1 + Ssm1) * PP1 + Sam1 + ((Spm2 * PP1 + Ssm2) * PP1 + Sam2) * PP0;\n" % nn2)
                loc.add(MatrixDeclaration('res', rows=end, cols=1, array=True))
                l += loc.generate_lines()
                # each transformed input feeds two consecutive evaluations;
                # results go to res in interleaved order
                inp = ["&pt(%d)" % (i//2) for i in range(end)]
                outp = ["&res(%d)" % i for i in spliced(nn2)]
            else:
                # one evaluation per sub-block output, written directly to par.i
                inp = []
                outp = []
                for bl in neq.subblocks:
                    o = VectorAccess('i', param=True)
                    for j in range(bl.i_slice.start, bl.i_slice.stop):
                        inp.append("&pt(%d)" % bl.p_slice.start)
                        outp.append("&%s(%d)" % (o, j))
            # jj walks inp/outp in step with the table entries
            # (assumes one knot_data entry per output -- TODO confirm)
            jj = 0
            tables = self.extra_sources["tables"]
            for bl in neq.subblocks:
                for j, kn in enumerate(tables[bl.namespace].knot_data):
                    # reorder: an unused input precedes a used one, so the used
                    # inputs are copied into a contiguous temporary pt2
                    reorder = False
                    unused = False
                    ll = []
                    for i, v in enumerate(kn):
                        if not v.used():
                            unused = True
                        else:
                            ll.append(i)
                            if unused:
                                reorder = True
                    if reorder:
                        l.append("{ Array<creal, %d, 1> pt2; pt2 << %s;\n" % (len(ll), ", ".join(["pt(%d)" % (bl.p_slice.start+i) for i in ll])))
                        inpt = "&pt2(0)"
                    else:
                        inpt = inp[jj]
                    # table type 'pp' of the first entry selects the splev_pp evaluator
                    fu = "splev"
                    if kn[0].tp == 'pp':
                        fu = "splev_pp"
                    l.append("splinedata<AmpData::%s::maptype>::%s<%s>(&AmpData::%s::sd.sc[%d], %s, %s);\n"
                             % (bl.namespace, fu, ",".join([str(v.get_order()) for v in kn if v.used()]), bl.namespace, j, inpt, outp[jj]))
                    if reorder:
                        l.append("}\n")
                    jj += 1
            if shape_transform:
                # output transform of both halves of res, then store interleaved into i0
                l.append("pt.head<%d>() = ((Spm0 * PP1 + Ssm0) * PP1 + Sam0) * res.head<%d>();\n" % (nn2,nn2))
                l.append("pt.tail<%d>() = ((Spm0 * PP1 + Ssm0) * PP1 + Sam0) * res.tail<%d>();\n" % (nn2,nn2))
                l.append("%s << %s;\n" % (VectorAccess('i', param=True, block=i0_slice),
                                          ", ".join(["pt(%d)" % v for v in spliced(nn2)])))
        else:
            # point par.p at pt while the sub-blocks are solved; the original
            # pointer (saved in pp) is restored on every exit path
            loc.add(MatrixDefinition('pp', rows=neq.nn, cols=1, pointer=True))
            l += loc.generate_lines()
            l.append("%s = %s;\n" % (
                MatrixAccess('pp', pointer=True).address_of(),
                MatrixAccess('p', param=True).address_of()))
            restore = "%s = %s;\n" % (
                MatrixAccess('p', param=True).address_of(),
                MatrixAccess('pp', pointer=True).address_of())
            l.append("%s = %s;\n" % (
                MatrixAccess('p', param=True).address_of(),
                MatrixAccess('pt').address_of()))
            l.append("int ret;\n")
            for bl in neq.subblocks:
                l.append("ret = %s::nonlin_solve(par);\n" % bl.namespace)
                l.append("if (ret != 0) {\n")
                l.append("    "+restore)
                l.append("    return 1;\n")
                # NOTE: emits "};" -- the extra ';' is an empty C++ statement
                l.append("};\n")
            l.append(restore)
        # the cc outputs are the variables v themselves
        l.append("%s = %s;\n" % (VectorAccess('i', param=True, block=i1_slice), VectorAccess('Mv')))
        if have_constant_matrices:
            # NOTE(review): M1 is a non-empty tuple and therefore always truthy,
            # so the else branch below is unreachable
            if M1:
                self.global_ns.add(MatrixDefinition('Kl', neq.Kl))
                Kl = MatrixAccess('Kl')
            else:
                Kl = MatrixAccess('K', block=M1)
        else:
            Kl = MatrixAccess('K', param=True, block=M1, pointer=False)
        # residual of the cc part: fvec = p1 + Kl * i
        Mfvec = MatrixDefinition('Mfvec', value='fvec', rows=neq.nn-end, cols=1)
        loc.add(Mfvec)
        l += loc.generate_lines()
        l.append("Mfvec = %(p)s + %(K)s * %(i)s;\n" % dict(
            p = VectorAccess('p', param=True, block=p1_slice),
            K = Kl,
            i = VectorAccess('i', param=True, block=neq.i_slice),
            ))
        return template.render(dict(expression=join_with_indent(l)))
327
328
class NonlinSolver(object):
    """Generate the C++ driver that runs the nonlinear solver for one block.

    Besides rendering the solver template this emits an optional transform
    of the input vector ``p`` (when ``neq.U`` is not the identity or
    ``neq.Hc`` is not zero) and of the output vector ``i`` (when ``neq.Mi``
    is not the identity).
    """

    def __init__(self, base, glob, neq, solver_dict):
        # base: template variables; "nn", "nni" and "nno" are set in place below
        # glob: namespace receiving constant global matrix definitions
        self.base = base
        self.neq = neq
        self.solver_dict = solver_dict
        # without potentiometers (np == 0) all matrices are compile-time constants
        self.have_constant_matrices = (neq.eq.np == 0)
        self.global_ns = glob
        self.local_ns = Namespace()
        self.mp_is_ident = matrix_is_identity(self.neq.U)
        self.mpc_is_zero = matrix_is_zero(self.neq.Hc)
        if self.mp_is_ident and self.mpc_is_zero:
            # p is used unchanged: read the block's own slice
            self.input_slice = self.neq.p_slice
        else:
            # transformed p goes into the first nn entries of the local p2
            self.input_slice = slice(0, self.neq.nn)
        self.mi_is_ident = matrix_is_identity(self.neq.Mi)
        if self.mi_is_ident:
            self.output_slice = self.neq.i_slice
        else:
            self.output_slice = slice(0, self.neq.nn)
        self.base["nn"] = self.neq.nn
        self.base["nni"] = self.neq.nni
        self.base["nno"] = self.neq.nno

    def p_transform(self):
        """Return C++ lines that replace ``par.p`` by a transformed copy ``p2``.

        Emits ``p2 = Mp * p (+ Mpc)`` and points ``par.p`` at ``p2``; the line
        restoring the original pointer is appended to ``self.cleanup``.
        Returns [] when no transform is needed.
        """
        l = []
        if self.mp_is_ident and self.mpc_is_zero:
            return l
        par = not self.have_constant_matrices
        par_p = VectorAccess('p', param=True, block=self.neq.p_slice)
        if not self.mp_is_ident:
            if not par:
                self.global_ns.add(MatrixDefinition('Mp', value=self.neq.U))
            s = "%s * %s" % (MatrixAccess('Mp', param=par), par_p)
        else:
            s = "%s" % par_p
        if not self.mpc_is_zero:
            if not par:
                self.global_ns.add(MatrixDefinition('Mpc', value=self.neq.Hc))
            s += " + %s" % VectorAccess('Mpc', param=par, pointer=False)
        g_nn = self.neq.eq.nonlin.nn
        p_old_def = MatrixDefinition("p_old", rows=g_nn, cols=1, pointer=True)
        self.local_ns.add(p_old_def)
        p2_def = MatrixDefinition("p2", rows=g_nn, cols=1)
        self.local_ns.add(p2_def)
        p2 = VectorAccess(p2_def, block=self.input_slice)
        l += self.local_ns.generate_lines()
        l.append("%s = %s;\n" % (p2, s))
        l.append("p_old = %s;\n" % par_p.address_of())
        l.append("%s = &p2;\n" % par_p.address_of())
        self.cleanup.append("%s = p_old;\n" % par_p.address_of())
        return l

    def i_transform(self):
        """Return C++ lines mapping the solver output back through ``Mi``.

        Side effect: appends to ``self.setup`` the lines that redirect
        ``par.i`` to the local ``mi`` before solving.  The returned lines
        restore ``par.i`` and compute ``i = Mi * mi`` for this block.
        Returns [] when ``Mi`` is the identity or the block has no outputs.
        """
        l = []
        if self.mi_is_ident or self.neq.nno == 0:
            return l
        g_nn = self.neq.eq.nonlin.nn
        mi_def = MatrixDefinition("mi", rows=g_nn, cols=1)
        mi = VectorAccess(mi_def, block=self.output_slice)
        i_old_def = MatrixDefinition("i_old", rows=g_nn, cols=1, pointer=True)
        par_i = VectorAccess('i', param=True, block=self.neq.i_slice)
        self.global_ns.add(MatrixDefinition("Mi", self.neq.Mi))
        self.local_ns.add(mi_def)
        self.local_ns.add(i_old_def)
        self.setup += self.local_ns.generate_lines()
        self.setup.append("i_old = %s;\n" % VectorAccess('i', param=True).address_of())
        self.setup.append("%s = &mi;\n" % VectorAccess('i', param=True).address_of())
        l.append("%s = i_old;\n" % VectorAccess('i', param=True).address_of())
        par = not self.have_constant_matrices
        # NOTE(review): Mi is always added as a global above, but with
        # potentiometers it is read from par; confirm that the parameter
        # struct provides an Mi member in that case.
        l.append("%(io)s = %(Mi)s * %(ii)s;\n" % dict(
                io = par_i,
                Mi = MatrixAccess('Mi', param=par, pointer=False),
                ii = mi,
                ))
        return l

    def make_var_v_ref(self):
        """Return the C++ address of the block's first unknown in ``par.v``."""
        start = 0
        if self.neq.v_slice:
            start = self.neq.v_slice.start
        return "&%s(%s)" % (VectorAccess("v", param=True), start)

    def generate(self, template):
        """Render *template* with the solver driver code.

        ``self.setup`` and ``self.cleanup`` are reset on each call;
        ``self.local_ns`` is not.
        """
        d = self.base.copy()
        self.setup = []
        self.cleanup = []
        # i_transform must run first: it fills self.setup
        d["i_transform"] = join_with_indent(self.i_transform())
        d["p_transform"] = join_with_indent(self.p_transform())
        d["setup"] = join_with_indent(self.setup)
        d["cleanup"] = join_with_indent(self.cleanup)
        d["var_v_ref"] = self.make_var_v_ref()
        d["store_p"] = "%s = %s;\n" % (
            VectorAccess('p_val', param=True, block=self.neq.p_slice),
            VectorAccess('p', param=True, block=self.neq.p_slice),
            )
        return template.render(d)
428
class NonlinSolverCC(NonlinSolver):
    """Solver driver for a partitioned system where only the cc part is iterated.

    ``U`` and ``Mi`` are treated as identities, so there is no output
    transform and at most an ``Mpc`` offset on the input.  The solver
    dimensions (nn, nni, nno) all equal the length of ``neq.cc_slice``.
    """

    def __init__(self, base, glob, neq, solver_dict):
        NonlinSolver.__init__(self, base, glob, neq, solver_dict)
        self.mi_is_ident = True
        self.mp_is_ident = True
        self.input_slice = self.neq.p_slice

    def make_var_v_ref(self):
        """Return the C++ address of the first cc variable in ``par.v``."""
        par_v = VectorAccess("v", param=True)
        return "&%s(%s)" % (par_v, self.neq.cc_slice.start)

    def generate(self, template):
        """Render *template* with solver dimensions reduced to the cc part."""
        cc_slice = self.neq.cc_slice
        n_cc = cc_slice.stop - cc_slice.start
        # copy so the dict shared with the caller keeps the full sizes
        reduced = self.base.copy()
        for key in ("nn", "nni", "nno"):
            reduced[key] = n_cc
        self.base = reduced
        return NonlinSolver.generate(self, template)
444
445
class NonlinCode(object):
    """Generate the C++ code for one nonlinear (sub)system.

    Renders the residual function (``fcn_def``) and the solver driver
    (``nonlin_def``) with the template pair selected by
    ``solver_dict["method"]`` and embeds both in
    ``dk_templates.c_template_nonlin_solver``.
    """

    # solver method -> (residual-function template, solver template);
    # "hybr" currently maps to the same templates as "hybrCC"
    method_templates = {
        'hybr': (dk_templates.c_template_nonlin_func_hybrCC,
                 dk_templates.c_template_nonlin_solver_hybrCC),
        'lm': (dk_templates.c_template_nonlin_func_lm,
               dk_templates.c_template_nonlin_solver_lm),
        'hybrCC': (dk_templates.c_template_nonlin_func_hybrCC,
                   dk_templates.c_template_nonlin_solver_hybrCC),
        }

    def __init__(self, struct, neq, solver_dict, extra_sources):
        self.struct = struct
        self.neq = neq
        self.solver_dict = solver_dict
        self.extra_sources = extra_sources

    def setup(self, d):
        """Build the template-variable dict shared by function and solver code.

        Reads "v0_guess" and "dev_interface" from *d*.  Sets ``self.blockV``,
        ``self.pblockV`` and ``self.glob`` as side effects.
        """
        neq = self.neq
        g_nonlin = neq.eq.nonlin
        base = {
            "v0_guess": d["v0_guess"],
            "dev_interface": d["dev_interface"],
            "extern_nonlin": not neq.subblocks or neq is g_nonlin,
            "g_nn": g_nonlin.nn,
            "g_nni": g_nonlin.nni,
            "g_nno": g_nonlin.nno,
            "npl": neq.eq.get_npl(),
        }
        if neq.nn != g_nonlin.nn:
            # part of a larger system: address it through its own slices
            base["nn"], base["nni"], base["nno"] = neq.nn, neq.nni, neq.nno
            self.blockV = neq.v_slice
            self.pblockV = neq.p_slice
            base["pblockV"] = VectorAccess.block_expr(neq.p_slice)
            base["iblockV"] = VectorAccess.block_expr(neq.i_slice)
        else:
            nn, nni, nno = neq.nn, neq.nni, neq.nno
            base["nn"], base["nni"], base["nno"] = nn, nni, nno
            self.blockV = None
            if nn == nni:
                self.pblockV = None
                base["pblockV"] = ""
            else:
                self.pblockV = slice(0, nni)
                base["pblockV"] = VectorAccess.block_expr([0, nni])
            if nn == nno:
                base["iblockV"] = ""
            else:
                base["iblockV"] = VectorAccess.block_expr([0, nno])
        # solver options are exposed to the templates as "solver_<option>"
        for key, value in (self.solver_dict or {}).items():
            base["solver_" + key] = value
        initializers = dict(
            p = MatrixAccess('mp'),
            i = MatrixAccess('mi'),
            v = MatrixAccess('Mv'),
            info = 'info',
            fnorm = 'fnorm',
            nfev = 'nfev',
            p_val = MatrixAccess('p_val'),
            )
        base["nonlin_mat_list"] = self.struct.get_initializer(initializers)
        self.glob = Namespace()
        base["namespace"] = neq.namespace
        base["global_data_def"] = self.glob
        return base

    def generate(self, d):
        """Return the rendered C++ code for this nonlinear (sub)system."""
        base = self.setup(d)
        # snapshot taken before the solver object updates base in place
        out = dict(base)
        neq = self.neq
        func_t, solv_t = self.method_templates[self.solver_dict["method"]]
        if isinstance(neq, dk_simulator.PartitionedNonlinEquations):
            fcn = NonlinFunctionCC(self.glob, neq, self.extra_sources)
            solver = NonlinSolverCC(base, self.glob, neq, self.solver_dict)
        else:
            solver = NonlinSolver(base, self.glob, neq, self.solver_dict)
            fcn = NonlinFunction(self.glob, neq, solver)
        if func_t is not None:
            out["fcn_def"] = fcn.generate(func_t)
        out["nonlin_def"] = solver.generate(solv_t)
        out["par_p"] = VectorAccess("p", param=True, block=self.pblockV)
        out["par_v"] = VectorAccess("v", param=True, block=self.blockV)
        return dk_templates.c_template_nonlin_solver.render(out)
532
533
class NonlinChained(NonlinCode):
    """Generate C++ code that solves the sub-blocks of a system in sequence.

    A local copy ``p2`` of ``*par.p`` is used as solver input.  Before each
    sub-block whose p slice does not start at 0, ``K.block(st, 0) *
    i.head(st)`` is added to that sub-block's segment of ``p2``; the
    outputs ``i.head(st)`` are presumably those of the preceding
    sub-blocks.  A non-zero return code of a sub-block is returned at once.
    """

    def __init__(self, s, neq):
        # a chained system has no solver options and no extra sources
        NonlinCode.__init__(self, s, neq, None, None)

    def generate(self, d):
        """Return the rendered C++ code calling each sub-block's nonlin_solve."""
        d = self.setup(d)
        neq = self.neq
        self.glob.add(MatrixDefinition('K', neq.eq.K))
        template = dk_templates.c_template_nonlin_chained
        code = ["%s = %s;\n" % (
            VectorAccess('p_val', param=True, block=neq.p_slice),
            VectorAccess('p', param=True, block=neq.p_slice),
            )]
        g_nn = neq.eq.nonlin.nn
        local_ns = Namespace()
        local_ns.add(MatrixDefinition("p_old", rows=g_nn, cols=1, pointer=True))
        local_ns.add(MatrixDefinition("p2", rows=g_nn, cols=1))
        code += local_ns.generate_lines()
        code.extend([
            "int ret = 0;\n",
            "p2 = *par.p;\n",
            "p_old = par.p;\n",
            "par.p = &p2;\n",
            ])
        for sub in neq.subblocks:
            start = sub.p_slice.start
            if start:
                code.append(
                    "p2.segment<%(ln)d>(%(st)d) += K.block<%(ln)d,%(st)d>(%(st)d,0) * (*par.i).head<%(st)d>();\n"
                    % {"ln": sub.p_slice.stop - start, "st": start})
            code.append("ret = %s::nonlin_solve(par);\n" % sub.namespace)
            code.extend([
                "if (ret != 0) {\n",
                "    par.p = p_old;\n",
                "    return ret;\n",
                "}\n",
                ])
        code.extend(["par.p = p_old;\n", "return 0;\n"])
        d["chained_code"] = join_with_indent(code)
        return template.render(d)
575
576
class TableCode(object):
    """Generate C++ code that computes a nonlinear block from spline tables.

    The block outputs ``i`` are obtained by evaluating the tables in
    ``extra_sources["tables"][neq.namespace]`` at the inputs (``last_pot``
    followed by the block's ``p``) instead of running a solver.
    """

    def __init__(self, struct, neq, extra_sources):
        self.struct = struct
        self.neq = neq
        self.extra_sources = extra_sources

    def add(self, d):
        """Return the template dict for the table code of this block.

        Reads "dev_interface" from *d*.  Sets ``self.blockV``,
        ``self.pblockV``, ``self.iblockV`` and ``self.glob`` as side effects.
        """
        neq = self.neq
        base = {}
        base["dev_interface"] = d["dev_interface"]
        g_nonlin = neq.eq.nonlin
        base["g_nn"] = g_nonlin.nn
        base["g_nni"] = g_nonlin.nni
        base["g_nno"] = g_nonlin.nno
        base["npl"] = neq.eq.get_npl()
        nn = base["nn"] = neq.nn
        nni = base["nni"] = neq.nni
        nno = base["nno"] = neq.nno
        if neq.v_slice:
            self.blockV = neq.v_slice
            self.pblockV = neq.p_slice
            self.iblockV = neq.i_slice
            base["blockV"] = VectorAccess.block_expr(neq.v_slice)
            base["pblockV"] = VectorAccess.block_expr(neq.p_slice)
            base["iblockV"] = VectorAccess.block_expr(neq.i_slice)
        else:
            self.blockV = None
            # must be set on this path too: par_v below reads it
            self.iblockV = None
            base["blockV"] = ""
            if nn != nni:
                self.pblockV = slice(0, nni)
                base["pblockV"] = VectorAccess.block_expr([0, nni])
            else:
                self.pblockV = None
                base["pblockV"] = ""
            if nn != nno:
                base["iblockV"] = VectorAccess.block_expr([0, nno])
            else:
                base["iblockV"] = ""
        ini = dict(
            p = MatrixAccess('mp'),
            i = MatrixAccess('mi'),
            v = MatrixAccess('Mv'),
            info = '&g_info',
            fnorm = 'fnorm',
            nfev = '&g_nfev',
            p_val = MatrixAccess('p_val'),
            )
        base["nonlin_mat_list"] = self.struct.get_initializer(ini)
        self.glob = Namespace()
        d = base.copy()
        d["namespace"] = neq.namespace
        d["global_data_def"] = self.glob
        d["par_p"] = VectorAccess("p", param=True, block=self.pblockV)
        # NOTE(review): par_v uses the i slice here while NonlinCode uses the
        # v slice; confirm which one the table template expects.
        d["par_v"] = VectorAccess("v", param=True, block=self.iblockV)
        return d

    def generate(self, d):
        """Return the rendered C++ code evaluating this block's spline tables."""
        d = self.add(d)
        neq = self.neq
        tables = self.extra_sources["tables"]
        l = []
        # t: table results; m/mp: table inputs (last_pot followed by p)
        l.append("real t[AmpData::%(namespace)s::sd.m];\n" % d)
        l.append("real m[%(nni)d+%(npl)d];\n" % d)
        l.append("Map<Matrix<real, %(nni)d+%(npl)d, 1> >mp(m);\n" % d)
        l.append("mp << last_pot.cast<real>(), (*par.p)%(pblockV)s.cast<real>();\n" % d)
        for j, kn in enumerate(tables[neq.namespace].knot_data):
            # reorder: an unused input precedes a used one, so the used inputs
            # are copied into a contiguous temporary pt2
            reorder = False
            unused = False
            fu = "splev"
            ll = []
            for i, v in enumerate(kn):
                if not v.used():
                    unused = True
                else:
                    ll.append(i)
                    if unused:
                        reorder = True
                    # any used 'pp' entry selects the splev_pp evaluator
                    if kn[i].tp == 'pp':
                        fu = "splev_pp"
            if reorder:
                l.append("{ Array<creal, %d, 1> pt2; pt2 << %s;\n" % (len(ll), ", ".join(["mp(%d)" % i for i in ll])))
                inpt = "&pt2(0)"
            else:
                inpt = "m"
            l.append("splinedata<AmpData::%s::maptype>::%s<%s>(&AmpData::%s::sd.sc[%d], %s, &t[%d]);\n"
                     % (neq.namespace, fu, ",".join([str(v.get_order()) for v in kn if v.used()]), neq.namespace, j, inpt, j))
            if reorder:
                l.append("}\n")
        l.append("(*par.i)%(iblockV)s = Map<Matrix<real, %(nno)d, 1> >(t).cast<creal>();\n" % d)
        d["call"] = join_with_indent(l)
        return dk_templates.c_template_table.render(d)
669
670
class UpdateMatrix(object):
    """Generate C++ code that recomputes potentiometer-dependent matrices.

    Each result is computed as ``M = M0 - T * U`` (see trans_line), where
    ``T`` is derived from ``Qi = (Q + diag(Rv))^-1`` and ``Rv`` holds the
    pot functions of ``pot_func`` evaluated at the pot values, scaled by
    ``Pv``.
    """

    def __init__(self, glob, eq, pot, pot_list, pot_func, Pv):
        self.glob = glob          # namespace receiving global matrix definitions
        self.eq = eq
        self.pot = pot            # pot name -> initial value (0.5 when missing)
        self.pot_func = pot_func  # sequence of (sympy symbol, expression) pairs
        self.pot_list = pot_list  # pot names; extended in place by generate()
        self.Pv = Pv

    def trans_line(self, res, var, var_data, t, u, u_data):
        """Return the C++ statement ``res = var - t * u;`` (``res = var;`` if *t* is None).

        *var* is defined globally from *var_data* and *res* as a global
        matrix of the same shape.  *u* is defined from *u_data* unless
        *u_data* is None, in which case *u* must already be available.
        """
        if not isinstance(res, VariableAccess):
            res = MatrixAccess(res)
        if not isinstance(var, VariableAccess):
            var = MatrixAccess(var)
        self.glob.add(MatrixDefinition(var, var_data))
        self.glob.add(MatrixDefinition(res, rows=var_data.shape[0], cols=var_data.shape[1]))
        s = "%s = %s" % (res, var)
        if t is None:
            return s + ";\n"
        if u_data is not None:
            self.glob.add(MatrixDefinition(u, u_data))
        return s + " - %s * %s;\n" % (t, u)

    def generate(self, struct):
        """Return template variables for the pot update code.

        Adds the pot-dependent matrices (Mx, Mxc, Mo, Moc and, when the
        nonlinear part has inputs, Mp, Mpc, K) as members of *struct*.
        Returns a dict with keys "pot_vars", "pot" and "update_pot".
        """
        eq = self.eq
        d = {}
        pot = make_symbol_vector('pot', eq.np)
        l = []
        for (a, f), p in zip(self.pot_func, self.Pv):
            s = str(a)
            try:
                i = self.pot_list.index(s)
            except ValueError:
                # first use of this pot: register it
                self.pot_list.append(s)
                i = len(self.pot_list)-1
            expr = f.subs(a, pot[i]) * p
            l.append(expr)
        d["pot_vars"] = ",".join(['"%s"' % v for v in self.pot_list])
        d["pot"] = ",".join([str(self.pot.get(v,0.5)) for v in self.pot_list])
        nx = eq.nx
        no = eq.no
        npot = eq.np  # not "np": that would shadow the numpy module
        nn = eq.nn
        loc = Namespace()
        loc.add(MatrixDeclaration("Rv", rows=npot, cols=1))
        lines = []
        lines += loc.generate_lines()
        lines += expr_list_to_ccode('Rv', l, '(%d)')
        self.glob.add(MatrixDefinition("Q", eq.Q))
        loc.add(MatrixDeclaration("Qi", rows=npot, cols=npot))
        lines += loc.generate_lines()
        lines.append("Qi = (%s + Matrix<creal, %d, %d>(Rv.asDiagonal())).inverse();\n" % (MatrixAccess('Q'), npot, npot))
        # whether a trans_line call has defined UR (it returns early when t is None)
        ur_defined = False
        if eq.nx:
            if matrix_is_identity(eq.Uxl):
                t = "Qi"
            elif matrix_is_zero(eq.Uxl):
                t = None
            else:
                loc.add(MatrixDeclaration("Tx", rows=nx, cols=npot))
                lines += loc.generate_lines()
                self.glob.add(MatrixDefinition("Uxl", eq.Uxl))
                lines.append("Tx = %s * Qi;\n" % MatrixAccess('Uxl'))
                t = "Tx"
            Mx_mat = eq.get_Mx()
            m = MatrixDeclaration('Mx', rows=Mx_mat.shape[0], cols=Mx_mat.shape[1])
            struct.add(m)
            Mx = MatrixAccess(m, param=True)
            lines.append(self.trans_line(Mx, 'Mx', Mx_mat, t, 'UR', eq.get_UR()))
            if t is not None:
                ur_defined = True
            m = MatrixDeclaration('Mxc', rows=eq.Bc.shape[0], cols=1)
            struct.add(m)
            Mxc = VectorAccess(m, param=True)
            lines.append(self.trans_line(Mxc, 'Mxc', eq.Bc, t, 'Ucv', eq.Ucv.T))
        if matrix_is_identity(eq.Uo):
            t  = "Qi"
        elif matrix_is_zero(eq.Uo):
            t = None
        else:
            loc.add(MatrixDeclaration("To", rows=no, cols=npot))
            lines += loc.generate_lines()
            self.glob.add(MatrixDefinition('Uo', eq.Uo))
            lines.append("To = %s * Qi;\n" % MatrixAccess('Uo'))
            t = "To"
        Mo_mat = eq.get_Mo()
        m = MatrixDeclaration('Mo', rows=Mo_mat.shape[0], cols=Mo_mat.shape[1])
        struct.add(m)
        Mo = MatrixAccess(m, param=True)
        lines.append(self.trans_line(Mo, 'Mo', Mo_mat, t, 'UR', eq.get_UR()))
        if t is not None:
            ur_defined = True
        m = MatrixDeclaration('Moc', rows=eq.Ec.shape[0], cols=1)
        struct.add(m)
        Moc = VectorAccess(m, param=True)
        lines.append(self.trans_line(Moc, 'Moc', eq.Ec, t, 'Ucv', eq.Ucv.T))
        if eq.nonlin and eq.nonlin.nni:
            if matrix_is_identity(eq.Unl):
                t = "Qi"
            elif  matrix_is_zero(eq.Unl):
                t = None
            else:
                loc.add(MatrixDeclaration("Tp", rows=nn, cols=npot))
                lines += loc.generate_lines()
                self.glob.add(MatrixDefinition("Unl", eq.Unl))
                lines.append("Tp = %s * Qi;\n" % MatrixAccess('Unl'))
                t = "Tp"
            if t is not None and not ur_defined:
                # the Mp and K lines use blocks of UR without defining it
                self.glob.add(MatrixDefinition('UR', eq.get_UR()))
            Mp_mat = eq.get_Mp()
            struct.add(MatrixDeclaration('Mp', rows=Mp_mat.shape[0], cols=Mp_mat.shape[1]))
            # NOTE(review): uses eq.mp_cols here but eq.get_mp_cols() for K below;
            # confirm both give the same column count
            lines.append(self.trans_line(MatrixAccess('Mp', param=True, pointer=False), 'Mp', Mp_mat, t, 'UR.block<%d, %d>(0, 0)' % (eq.np, eq.mp_cols), None))
            struct.add(MatrixDeclaration('Mpc', rows=eq.nonlin.Hc.shape[0], cols=1))
            lines.append(self.trans_line(VectorAccess('Mpc', param=True, pointer=False), 'Mpc', eq.nonlin.Hc, t, 'Ucv', eq.Ucv.T))
            struct.add(MatrixDeclaration('K', rows=eq.K.shape[0], cols=eq.K.shape[1]))
            mp_cols = eq.get_mp_cols()
            lines.append(self.trans_line(MatrixAccess('K', param=True, pointer=False), 'K', eq.K, t, 'UR.block<%d, %d>(0, %d)' % (eq.np, eq.get_mx_cols()-mp_cols, mp_cols), None))
        d["update_pot"] = join_with_indent(lines)
        return d
784
785
class LinearCode(object):
    """Produce template values for the linear (matrix-product) code sections.

    ``generate`` returns the snippets that compute ``mp``, ``xn`` and ``xo``
    (plus a float-cast variant of ``xo``) from the system matrices of ``eq``.
    """

    def __init__(self, glob, eq):
        # glob: namespace handed to gen_linear_combination (it may register
        # global matrix definitions there); eq: the equation system.
        self.glob = glob
        self.eq = eq

    def generate(self):
        """Return a dict with keys ``m_cols``, ``gen_mp``, ``gen_xn``,
        ``gen_xo`` and ``gen_xo_float``.
        """
        eq, glob = self.eq, self.glob
        # NOTE(review): ``param`` is bool(eq.np); presumably it selects the
        # per-instance parameter struct when the system has potentiometers.
        use_params = bool(eq.np)
        # Build the matrix/vector accessors up front (same order as before):
        # gen_xo and gen_xo_float deliberately share one Mo accessor.
        acc = {}
        for label, access_cls in (("Mp", MatrixAccess),
                                  ("Mo", MatrixAccess),
                                  ("Moc", VectorAccess),
                                  ("Mx", MatrixAccess),
                                  ("Mxc", VectorAccess)):
            acc[label] = access_cls(label, param=use_params, pointer=False)
        # When the nonlinear part has fewer inputs than eq.nn, only the
        # leading nonlin.nni rows of mp are generated.
        mp_block = (slice(0, eq.nonlin.nni)
                    if eq.nonlin and eq.nn != eq.nonlin.nni else None)
        result = {}
        result["m_cols"] = eq.get_mx_cols()
        result["gen_mp"] = gen_linear_combination(
            glob, VectorAccess('mp', block=mp_block), 'dp',
            acc["Mp"], eq.get_Mp())
        result["gen_xn"] = gen_linear_combination(
            glob, 'xn', 'd', acc["Mx"], eq.get_Mx(), acc["Mxc"], eq.Bc)
        result["gen_xo"] = gen_linear_combination(
            glob, 'xo', 'd', acc["Mo"], eq.get_Mo(), acc["Moc"], eq.Ec)
        result["gen_xo_float"] = gen_linear_combination(
            glob, 'xo', 'd', acc["Mo"], eq.get_Mo(), cast="float")
        return result
815
816
class NonlinEq(object):
    """Plain record bundling the data of one nonlinear equation block.

    The constructor stores every argument unchanged under an attribute of
    the same name; it performs no validation or computation.
    """

    def __init__(self, nn, nni, nno, np, npl, K, CZ, f, U, Hc, Mi, Kn):
        # Size-like parameters, as supplied by the caller.
        self.nn, self.nni, self.nno = nn, nni, nno
        self.np, self.npl = np, npl
        # Matrices / expressions, as supplied by the caller.
        self.K, self.CZ, self.f, self.U = K, CZ, f, U
        self.Hc, self.Mi, self.Kn = Hc, Mi, Kn
832
833
class LV2_Port_List(object):
    """Iterable description of the LV2 ports of the generated plugin.

    Port order is: ``eq.ni`` audio inputs, ``eq.no`` audio outputs, then one
    control input per row of ``pot_attr``. Iteration yields one dict of
    template values per port, with consecutive ``index`` values from 0.
    """

    def __init__(self, pot_attr, pot, eq):
        # pot_attr rows: item 0 is the control symbol, item 1 its name.
        # pot: mapping symbol -> default value (0.5 when missing).
        self.pot_attr = pot_attr
        self.pot = pot
        self.ni = eq.ni
        self.no = eq.no

    def port_count(self):
        """Return the number of ports: audio inputs + outputs + controls."""
        return self.ni + self.no + len(self.pot_attr)

    def __len__(self):
        return self.port_count()

    def __iter__(self):
        last_index = len(self) - 1

        def separator_for(port_index):
            # NOTE(review): only the final port gets ","; every other port
            # gets "". That looks inverted for Turtle object lists -- confirm
            # against the lv2_ttl template before changing it.
            return "," if port_index == last_index else ""

        port_index = 0
        audio_groups = (
            (self.ni, "lv2:AudioPort , lv2:InputPort", "in%d", "In%d"),
            (self.no, "lv2:AudioPort , lv2:OutputPort", "out%d", "Out%d"),
        )
        for count, type_list, symbol_fmt, name_fmt in audio_groups:
            for n in range(count):
                yield dict(
                    type_list=type_list,
                    index=port_index,
                    symbol=symbol_fmt % n,
                    name=name_fmt % n,
                    control_index=-1,
                    separator=separator_for(port_index),
                )
                port_index += 1
        for control_index, row in enumerate(self.pot_attr):
            symbol, label = row[0], row[1]
            yield dict(
                type_list="lv2:InputPort , lv2:ControlPort",
                index=port_index,
                symbol=symbol,
                name=label,
                default=self.pot.get(symbol, 0.5),
                minimum=0.0,
                maximum=1.0,
                control_index=control_index,
                separator=separator_for(port_index),
            )
            port_index += 1
886
887
class CodeGenerator(object):
    """Fill the template dictionary and render the generated plugin sources.

    Combines the equation system, solver settings and potentiometer data,
    delegates to the helper generators in this module (UpdateMatrix,
    NonlinCode, NonlinChained, TableCode, LinearCode) and renders the C
    source, plus LV2 manifest/TTL files when the plugin definition sets
    ``lv2_plugin_type``.
    """

    def __init__(self, eq, solver_dict, solver_params, pot, pot_list, pot_func, pot_attr, Pv, extra_sources):
        self.eq = eq
        # Copy of the caller's solver dict with default settings filled in.
        self.solver_dict = solver_set_defaults(solver_dict)
        # Optional per-component overrides, looked up by component name in
        # gen_nonlin (keys used there: "generator", "v0_guess").
        self.solver_params = solver_params
        self.pot = pot
        self.pot_list = pot_list
        self.pot_func = pot_func
        self.pot_attr = pot_attr
        self.Pv = Pv
        self.extra_sources = extra_sources
        # Without potentiometer functions no update_pot code is generated
        # (see add_dict). None is accepted as "no pot functions".
        self.have_constant_matrices = pot_func is None or len(pot_func) == 0

    def pot_code(self):
        """Return template values describing the potentiometer controls.

        Each ``pot_attr`` row is unpacked as ``(var, name, loga, inv, expr)``.
        ``calc_pots`` contains one C statement per row assigning ``t[i]``
        from ``self.pots[i]``: an exponential taper when ``loga`` is truthy,
        reversed (``1-x``) when ``inv`` is truthy.
        """
        d = {}
        if self.pot_attr:
            d['have_master_slider'] = True
            d['master_slider_id'] = self.pot_attr[0][0]
        else:
            d['have_master_slider'] = False
        d['knob_ids'] = [row[0] for row in self.pot_attr]
        d['timecst'] = 0.01
        d['regs'] = [dict(id=row[0], name=row[1], desc="", varidx=i)
                     for i, row in enumerate(self.pot_attr)]
        lines = []
        for i, (_var, _name, loga, inv, _expr) in enumerate(self.pot_attr):
            if loga and inv:
                stmt = "t[%d] = (exp(%s * (1-self.pots[%d])) - 1) / (exp(%s) - 1);" % (i, loga, i, loga)
            elif loga:
                stmt = "t[%d] = (exp(%s * self.pots[%d]) - 1) / (exp(%s) - 1);" % (i, loga, i, loga)
            elif inv:
                stmt = "t[%d] = 1-self.pots[%d];" % (i, i)
            else:
                stmt = "t[%d] = self.pots[%d];" % (i, i)
            lines.append(stmt)
        d['calc_pots'] = "\n    ".join(lines)
        return d

    @staticmethod
    def walk_eqs(nlin):
        """Yield ``nlin`` and then, depth-first, all of its sub-blocks.

        Recursion stops at objects whose class is exactly
        ``dk_simulator.NonlinEquations``; every other object is expected
        to provide ``subblocks``.
        """
        yield nlin
        if nlin.__class__ is dk_simulator.NonlinEquations:
            return
        for bl in nlin.subblocks:
            for neq in CodeGenerator.walk_eqs(bl):
                yield neq

    def gen_nonlin(self, nonlin, s, d, complist, code):
        """Append generated solver code for ``nonlin`` and its sub-blocks.

        Partitioned and chained blocks recurse into their sub-blocks first,
        so sub-block code precedes the combining block's code in ``code``.
        Leaf blocks append a component description to ``complist`` and
        their solver code to ``code``.

        :param s: the parameter ``Structure`` passed to the code generators.
        :param d: the template dict (needs an ``overwrite`` method).
        """
        if isinstance(nonlin, dk_simulator.PartitionedNonlinEquations):
            for nlin in nonlin.subblocks:
                self.gen_nonlin(nlin, s, d, complist, code)
            if self.solver_dict["method"] == "table":
                # The partitioned block's own solver uses "hybrCC" in place
                # of "table"; the shared solver dict is left untouched.
                solver = self.solver_dict.copy()
                solver["method"] = "hybrCC"
            else:
                solver = self.solver_dict
            code.append(NonlinCode(s, nonlin, solver, self.extra_sources).generate(d))
        elif isinstance(nonlin, dk_simulator.ChainedNonlinEquations):
            for nlin in nonlin.subblocks:
                self.gen_nonlin(nlin, s, d, complist, code)
            code.append(NonlinChained(s, nonlin).generate(d))
        else:
            complist.append(dict(
                name = nonlin.name,
                namespace = nonlin.namespace,
                pins = ",".join(dk_simulator.Parser.format_element(v) for v in nonlin.pins),
                v_slice = nonlin.v_slice,
                p_slice = nonlin.p_slice,
                i_slice = nonlin.i_slice,
                ))
            generator = None
            # Reset per component so a previous component's guess never leaks.
            d.overwrite("v0_guess", "")
            if self.solver_params:
                try:
                    params = self.solver_params[nonlin.name]
                except (KeyError, TypeError):
                    # No entry for this component, or solver_params is not
                    # indexable by name: keep the defaults.
                    pass
                else:
                    generator = params.get("generator")
                    # Missing key keeps the same "" default set above
                    # (previously it became None).
                    d.overwrite("v0_guess", params.get("v0_guess", ""))
            if generator:
                code.append(generator(s, nonlin, self.extra_sources, d))
            elif self.solver_dict["method"] == "table":
                code.append(TableCode(s, nonlin, self.extra_sources).generate(d))
            else:
                code.append(NonlinCode(s, nonlin, self.solver_dict, None).generate(d))

    def add_dict(self, d):
        """Fill template dict ``d`` in place with all generated values and return it.

        Builds the ``nonlin_param`` structure, matrix-update code (when the
        matrices are not constant), nonlinear solver code (when
        ``eq.nonlin.nno`` is nonzero), potentiometer values and the linear
        code sections.
        """
        eq = self.eq
        s = Structure("nonlin_param")
        par_ini = {}
        glob = Namespace()
        if self.have_constant_matrices:
            d["update_pot"] = ""
        else:
            d.update(UpdateMatrix(glob, eq, self.pot, self.pot_list, self.pot_func, self.Pv).generate(s))
        if eq.nonlin and eq.nonlin.nno:
            # Expose each solver setting to the template as "solver_<key>".
            for k, v in self.solver_dict.items():
                d["solver_"+k] = v
            s.add(MatrixDeclaration('p', rows=eq.nn, cols=1, pointer=True))
            s.add(MatrixDeclaration('i', rows=eq.nn, cols=1, pointer=True))
            s.add(MatrixDeclaration('v', rows=eq.nn, cols=1, mapping=True, pointer=True, aligned=False))
            s.add_nonlin_info()
            s.add_minmax_info(eq.nonlin.nni)
            # Two initializers: par_ini (rendered after this branch as
            # "nonlin_mat_list") and ini ("nonlin_mat_list_calc").
            par_ini.update(dict(
                v = '&g_v',
                info = '&g_info',
                fnorm = '0',
                nfev = '&g_nfev',
                p = '0',
                i = '0',
                p_val = '0',
                ))
            ini = dict(
                p = MatrixAccess('mp'),
                i = MatrixAccess('mi'),
                v = '&g_v',
                info = '&g_info',
                fnorm = '&fnorm',
                nfev = '&g_nfev',
                p_val = MatrixAccess('p_val'),
                )
            d["nonlin_mat_list_calc"] = s.get_initializer(ini)
            # The root block gets namespace "nonlin"; the remaining blocks,
            # in walk order, get "nonlin_0", "nonlin_1", ...
            for i, neq in enumerate(self.walk_eqs(eq.nonlin)):
                if i == 0:
                    neq.namespace = "nonlin"
                else:
                    neq.namespace = "nonlin_%d" % (i-1)
            complist = []
            code = []
            self.gen_nonlin(eq.nonlin, s, d, complist, code)
            d["nc"] = len(complist)
            d["components"] = complist
            d["nonlin_code"] = "".join(code)
            if eq.nn != eq.nonlin.nno:
                d["iblock"] = VectorAccess.block_expr([0, eq.nonlin.nno])
            else:
                d["iblock"] = ""
        else:
            d["nonlin_code"] = ""
            d["nc"] = 0
            d["iblock"] = ""
        d["nonlin_mat_list"] = s.get_initializer(par_ini)
        d.update(self.pot_code())
        d["struct_def"] = s
        d.update(LinearCode(glob, eq).generate())
        d["global_matrices"] = glob
        d["add_npl"] = 0
        d["DKPlugin_fields"] = ""
        d["DKPlugin_init"] = ""
        d["process_add"] = ""
        return d

    def generate(self, d):
        """Render the generated sources from template dict ``d``.

        Returns a dict mapping output names to rendered text: always
        ``"c_source"``; additionally ``"manifest.ttl"`` and
        ``"<lv2_versioned_id>.ttl"`` when ``d["plugindef"].lv2_plugin_type``
        is truthy.
        """
        d = self.add_dict(d)
        d["lv2_ports"] = LV2_Port_List(self.pot_attr, self.pot, self.eq)
        out = dict(c_source = dk_templates.c_template_top.render(d))
        plugindef = d["plugindef"]
        if plugindef.lv2_plugin_type:
            out["manifest.ttl"] = dk_templates.lv2_manifest.render(d)
            out["%s.ttl" % plugindef.lv2_versioned_id] = dk_templates.lv2_ttl.render(d)
        return out
1057