"""Generate C++ source code for a DK-method circuit model.

The classes below turn the equation objects built by ``dk_simulator`` into
C++ source text (Eigen-style matrix expressions) and render it through the
templates in ``dk_templates``.  ``CodeGenerator`` is the top-level entry
point; the other classes each produce one part of the generated code
(parameter struct, nonlinear residual function and solver, spline-table
evaluation, potentiometer-dependent matrix updates, linear part, LV2 port
list).
"""
from __future__ import division
import math, sys
from collections import OrderedDict
import sympy as sp
import numpy as np
import numpy.matlib as ml
import numpy.linalg as la
import dk_templates
from codelib import *
import dk_simulator


def solver_set_defaults(solver_dict=None):
    """Return a copy of *solver_dict* with default solver options filled in.

    The caller's dict is never modified.  Keys already present are kept;
    missing ones get: method="hybr", factor=100, xtol=sqrt(machine eps),
    maxfev=2000, max_homotopy_iter=100.
    """
    if solver_dict is None:
        solver_dict = {}
    else:
        solver_dict = solver_dict.copy()
    solver_dict.setdefault("method", "hybr")
    solver_dict.setdefault("factor", 1.e2)
    solver_dict.setdefault("xtol", math.sqrt(sys.float_info.epsilon))
    solver_dict.setdefault("maxfev", 2000)
    solver_dict.setdefault("max_homotopy_iter", 100)
    return solver_dict


class Structure(object):
    """Builder for a C++ ``struct`` definition plus its inline constructor.

    Members are declaration objects (from ``codelib``) exposing ``name``,
    ``pointer`` and ``generate()``; only pointer members become constructor
    arguments.  Optional extra members: solver status pointers
    (info/nfev/fnorm) and a ``p_val`` array.
    """

    def __init__(self, name):
        self.name = name
        self.members = []
        self.nonlin_info = False   # emit info/nfev/fnorm pointer members
        self.minmax_info = None    # row count of the p_val array, or None

    def add(self, m):
        """Append the declaration *m* as a struct member."""
        self.members.append(m)

    def add_nonlin_info(self):
        """Request the ``info``, ``nfev`` and ``fnorm`` pointer members."""
        self.nonlin_info = True

    def add_minmax_info(self, rows):
        """Request a ``p_val`` pointer member with *rows* rows."""
        self.minmax_info = rows

    def get_initializer(self, kw):
        """Return the constructor argument list ``"(a, b, ...)"``.

        *kw* maps every constructor slot name (pointer members, plus
        info/nfev/fnorm and p_val when enabled) to its initializer.
        ``VariableAccess`` values are replaced by their address expression;
        other values are joined as-is (they are expected to be strings).
        Returns "" when the struct has no constructor arguments.

        Raises ValueError for names that are not slots, or when a slot is
        left without an initializer.
        """
        slots = dict((m.name, i) for i, m in enumerate(v for v in self.members if v.pointer))
        n = len(slots)
        if self.nonlin_info:
            slots['info'] = n
            slots['nfev'] = n+1
            slots['fnorm'] = n+2
            n += 3
        if self.minmax_info:
            slots['p_val'] = n
            n += 1
        args = [None]*n
        seen = set()
        for k, v in kw.items():
            try:
                i = slots[k]
            except KeyError:
                raise ValueError("unknown struct member: %s" % k)
            if isinstance(v, VariableAccess):
                args[i] = str(v.address_of())
            else:
                args[i] = v
            seen.add(k)
        missing = set(slots) - seen
        if missing:
            # sorted so the error message is deterministic
            raise ValueError("no structure initialization for: %s" % ", ".join(sorted(missing)))
        if not args:
            return ""
        return "("+", ".join(args)+")"

    def generate_lines(self):
        """Return the struct definition as a list of C++ source fragments."""
        arglist = [m.generate("_") for m in self.members if m.pointer]
        initlist = []
        for m in self.members:
            if m.pointer:
                initlist.append("%s(%s_)" % (m.name, m.name))
            else:
                initlist.append("%s()" % m.name)
        l = []
        l.append("struct %s {\n" % self.name)
        l += [" %s;\n" % m for m in self.members]
        if self.nonlin_info:
            l.append(" int *info;\n")
            l.append(" int *nfev;\n")
            l.append(" creal *fnorm;\n")
            initlist += ["info(info_)", "nfev(nfev_)", "fnorm(fnorm_)"]
            arglist += ["int *info_", "int *nfev_", "creal *fnorm_"]
        if self.minmax_info:
            p_val = MatrixDeclaration('p_val', rows=self.minmax_info, cols=1, pointer=True, array=True)
            l.append(" %s;\n" % p_val)
            initlist.append("p_val(p_val_)")
            arglist.append(p_val.generate("_"))
        if initlist:
            l.append(" inline %(name)s(%(args)s): %(init)s {}\n" % dict(
                name = self.name,
                args = ", ".join(arglist),
                init = ", ".join(initlist),
                ))
        l.append("};")
        return l

    def __str__(self):
        return "".join(self.generate_lines())


class NonlinFunction(object):
    """Generate the residual function of a non-partitioned nonlinear block.

    The rendered code evaluates the element expressions into ``i`` and
    writes the residual ``fvec = p + K*i - mv``, where ``mv`` holds the
    entries of ``v`` selected by ``CZ`` (zeros elsewhere).
    """

    def __init__(self, global_ns, neq, on):
        self.global_ns = global_ns
        self.neq = neq
        self.have_constant_matrices = (neq.eq.np == 0)
        self.locals = Namespace()
        # p/i slices as decided by the matching NonlinSolver
        self.input_slice = on.input_slice
        self.output_slice = on.output_slice

    def expr_list(self, v):
        """Return this block's element expressions with variables substituted.

        Global variable indices at or above the block's v offset map to
        entries of *v*; indices below it map to ``(*par.v)``.
        """
        par_v = make_symbol_vector('(*par.v)', self.neq.nn)
        l = []
        f = self.neq.eq.f
        off = 0
        if self.neq.v_slice:
            off = self.neq.v_slice.start
            f = f[self.neq.v_slice]
        for expr, vl, base in f:
            for var, idx in zip(vl, base):
                if idx >= off:
                    vv = v[idx-off]
                else:
                    vv = par_v[idx]
                expr = expr.subs(var, vv)
            l.append(expr)
        return l

    def generate(self, template):
        """Render *template* with the residual-function body as ``expression``.

        May add the constant matrix ``K`` to ``self.global_ns``.
        """
        if self.neq.v_slice:
            iblockV = self.input_slice
            oblockV = self.output_slice
            blockM = (self.neq.v_slice, self.neq.v_slice)
            start = self.output_slice.start
        else:
            iblockV = oblockV = None
            blockM = None
            start = 0
        v = make_symbol_vector('v', self.neq.nn)
        i = VectorAccess('i', param=True)
        l = []
        l += expr_list_to_ccode(str(i), self.expr_list(v), '(%d)', start)
        mv = MatrixDefinition('mv', rows=self.neq.nn, cols=1)
        Mfvec = MatrixDefinition('Mfvec', value='fvec', rows=self.neq.nn, cols=1)
        self.locals.add(mv)
        self.locals.add(Mfvec)
        l += self.locals.generate_lines()
        # NOTE(review): with v_slice None this indexes CZ[None] / K[None];
        # presumably v_slice is always set when that matters — confirm.
        CZ = self.neq.eq.CZ[self.neq.v_slice]
        K = self.neq.eq.K
        if self.have_constant_matrices:
            K = K[blockM]
        l.append("mv << %s;\n" % ", ".join(
            ['v[%d]' % k if z else '0' for k, z in enumerate(CZ)]))
        if not matrix_is_identity(K):
            if self.have_constant_matrices:
                self.global_ns.add(MatrixDefinition('K', K))
                K = MatrixAccess('K')
            else:
                K = MatrixAccess('K', param=True, block=blockM, pointer=False)
            l.append("Mfvec = %(p)s + %(K)s * %(i)s - %(mv)s;\n" % dict(
                p = VectorAccess('p', param=True, block=iblockV),
                K = K,
                i = VectorAccess('i', param=True, block=oblockV),
                mv = VectorAccess('mv'),
                ))
        else:
            l.append("Mfvec = %(p)s + %(i)s - %(mv)s;" % dict(
                p = VectorAccess('p', param=True, block=iblockV),
                i = VectorAccess('i', param=True, block=oblockV),
                mv = VectorAccess('mv'),
                ))
        return template.render(dict(expression=join_with_indent(l)))


class NonlinFunctionCC(object):
    """Generate the residual function of a PartitionedNonlinEquations block.

    Variables before ``neq.cc_slice.start`` are handled by the sub-blocks,
    either by evaluating spline tables from *extra_sources* or by calling
    each sub-block's ``nonlin_solve``.  The residual ``fvec`` is formed only
    for the variables from ``neq.cc_slice.start`` onwards.
    """

    def __init__(self, global_ns, neq, extra_sources):
        self.global_ns = global_ns
        self.neq = neq
        self.extra_sources = extra_sources

    def generate(self, template):
        """Render *template* with the residual-function body as ``expression``.

        May add ``Ku``, ``Kl`` and shape-transform matrices to
        ``self.global_ns``.
        """
        neq = self.neq
        have_constant_matrices = (neq.eq.np == 0)
        end = neq.cc_slice.start
        cc = neq.cc_slice.stop - neq.cc_slice.start
        p0_slice = slice(neq.p_slice.start, neq.p_slice.stop-cc)
        p1_slice = slice(neq.p_slice.stop-cc, neq.p_slice.stop)
        i0_slice = slice(neq.i_slice.start, neq.i_slice.stop-cc)
        i1_slice = slice(neq.i_slice.stop-cc, neq.i_slice.stop)
        M0 = (slice(0, end), slice(end, neq.nn))
        M1 = (slice(end, neq.nn), slice(0, neq.nn))
        loc = Namespace()
        l = []
        loc.add(MatrixDefinition('Mv', value='v', rows=neq.nn-end, cols=1, aligned=False, const=True))
        if self.extra_sources:
            loc.add(MatrixDefinition('pt', rows=p0_slice.stop-p0_slice.start, cols=1, array=True))
            pt = VectorAccess('pt')
        else:
            loc.add(MatrixDefinition('pt', rows=neq.nn, cols=1))
            pt = VectorAccess('pt', block=p0_slice)
        l += loc.generate_lines()
        if have_constant_matrices:
            self.global_ns.add(MatrixDefinition('Ku', neq.Ku))
            Ku = MatrixAccess('Ku')
        else:
            Ku = MatrixAccess('K', param=True, block=M0, pointer=False)
        # pt = p(sub-block part) + Ku * Mv
        l.append("%s = %s + %s * %s;\n" % (
            pt,
            VectorAccess('p', param=True, block=p0_slice),
            Ku,
            VectorAccess('Mv')))
        if self.extra_sources:
            # sub-blocks are evaluated from spline tables
            shape_transform = self.extra_sources.get("shape_transform", ())
            for mat in shape_transform:
                self.global_ns.add(mat)
            if shape_transform:
                def nth(j):
                    # index j within each sub-block, as a global v index
                    idx = []
                    for bl in neq.subblocks:
                        ln = bl.v_slice.stop - bl.v_slice.start
                        idx.append(j)
                        j += ln
                    return idx
                def spliced(n):
                    # 0, n, 1, n+1, ..., n-1, 2n-1
                    for a in range(n):
                        yield a
                        yield a+n
                nn2 = end//2
                loc.add(MatrixDeclaration('PP1', rows=nn2, cols=1, array=True))
                l += loc.generate_lines()
                l.append("PP1 << %s;\n" % ", ".join(["pt(%d)" % i for i in nth(1)]))
                loc.add(MatrixDeclaration('PP0', rows=nn2, cols=1, array=True))
                l += loc.generate_lines()
                l.append("PP0 << %s;\n" % ", ".join(["pt(%d)" % i for i in nth(0)]))
                l += loc.generate_lines()
                l.append("pt.head<%d>() = (Spm1 * PP1 + Ssm1) * PP1 + Sam1 + ((Spm2 * PP1 + Ssm2) * PP1 + Sam2) * PP0;\n" % nn2)
                loc.add(MatrixDeclaration('res', rows=end, cols=1, array=True))
                l += loc.generate_lines()
                inp = ["&pt(%d)" % (i//2) for i in range(end)]
                outp = ["&res(%d)" % i for i in spliced(nn2)]
            else:
                inp = []
                outp = []
                for bl in neq.subblocks:
                    o = VectorAccess('i', param=True)
                    for j in range(bl.i_slice.start, bl.i_slice.stop):
                        inp.append("&pt(%d)" % bl.p_slice.start)
                        outp.append("&%s(%d)" % (o, j))
            jj = 0
            tables = self.extra_sources["tables"]
            for bl in neq.subblocks:
                for j, kn in enumerate(tables[bl.namespace].knot_data):
                    # indices of the knot vectors actually used; if any is
                    # unused, the spline input is gathered into pt2
                    ll = [i for i, v in enumerate(kn) if v.used()]
                    reorder = len(ll) != len(kn)
                    if reorder:
                        l.append("{ Array<creal, %d, 1> pt2; pt2 << %s;\n" % (len(ll), ", ".join(["pt(%d)" % (bl.p_slice.start+i) for i in ll])))
                        inpt = "&pt2(0)"
                    else:
                        inpt = inp[jj]
                    fu = "splev"
                    if kn[0].tp == 'pp':
                        fu = "splev_pp"
                    l.append("splinedata<AmpData::%s::maptype>::%s<%s>(&AmpData::%s::sd.sc[%d], %s, %s);\n"
                             % (bl.namespace, fu, ",".join([str(v.get_order()) for v in kn if v.used()]), bl.namespace, j, inpt, outp[jj]))
                    if reorder:
                        l.append("}\n")
                    jj += 1
            if shape_transform:
                l.append("pt.head<%d>() = ((Spm0 * PP1 + Ssm0) * PP1 + Sam0) * res.head<%d>();\n" % (nn2,nn2))
                l.append("pt.tail<%d>() = ((Spm0 * PP1 + Ssm0) * PP1 + Sam0) * res.tail<%d>();\n" % (nn2,nn2))
                l.append("%s << %s;\n" % (VectorAccess('i', param=True, block=i0_slice),
                                           ", ".join(["pt(%d)" % v for v in spliced(nn2)])))
        else:
            # sub-blocks are solved by their own nonlin_solve, with par.p
            # temporarily pointed at pt (restored on every exit path)
            loc.add(MatrixDefinition('pp', rows=neq.nn, cols=1, pointer=True))
            l += loc.generate_lines()
            l.append("%s = %s;\n" % (
                MatrixAccess('pp', pointer=True).address_of(),
                MatrixAccess('p', param=True).address_of()))
            restore = "%s = %s;\n" % (
                MatrixAccess('p', param=True).address_of(),
                MatrixAccess('pp', pointer=True).address_of())
            l.append("%s = %s;\n" % (
                MatrixAccess('p', param=True).address_of(),
                MatrixAccess('pt').address_of()))
            l.append("int ret;\n")
            for bl in neq.subblocks:
                l.append("ret = %s::nonlin_solve(par);\n" % bl.namespace)
                l.append("if (ret != 0) {\n")
                l.append(" "+restore)
                l.append(" return 1;\n")
                l.append("};\n")
            l.append(restore)
        l.append("%s = %s;\n" % (VectorAccess('i', param=True, block=i1_slice), VectorAccess('Mv')))
        # M1 is a non-empty tuple, so the former "if M1:" test was always
        # true and its else-branch unreachable; it has been dropped.
        if have_constant_matrices:
            self.global_ns.add(MatrixDefinition('Kl', neq.Kl))
            Kl = MatrixAccess('Kl')
        else:
            Kl = MatrixAccess('K', param=True, block=M1, pointer=False)
        Mfvec = MatrixDefinition('Mfvec', value='fvec', rows=neq.nn-end, cols=1)
        loc.add(Mfvec)
        l += loc.generate_lines()
        l.append("Mfvec = %(p)s + %(K)s * %(i)s;\n" % dict(
            p = VectorAccess('p', param=True, block=p1_slice),
            K = Kl,
            i = VectorAccess('i', param=True, block=neq.i_slice),
            ))
        return template.render(dict(expression=join_with_indent(l)))


class NonlinSolver(object):
    """Render the solver template for one nonlinear block.

    Adds an input transform (p -> U*p + Hc) when U is not the identity or
    Hc is not zero, and an output transform (i = Mi * mi) when Mi is not
    the identity.  ``__init__`` writes nn/nni/nno into *base* in place.
    """

    def __init__(self, base, glob, neq, solver_dict):
        self.base = base
        self.neq = neq
        self.solver_dict = solver_dict
        self.have_constant_matrices = (neq.eq.np == 0)
        self.global_ns = glob
        self.local_ns = Namespace()
        self.mp_is_ident = matrix_is_identity(self.neq.U)
        self.mpc_is_zero = matrix_is_zero(self.neq.Hc)
        if self.mp_is_ident and self.mpc_is_zero:
            self.input_slice = self.neq.p_slice
        else:
            self.input_slice = slice(0, self.neq.nn)
        self.mi_is_ident = matrix_is_identity(self.neq.Mi)
        if self.mi_is_ident:
            self.output_slice = self.neq.i_slice
        else:
            self.output_slice = slice(0, self.neq.nn)
        self.base["nn"] = self.neq.nn
        self.base["nni"] = self.neq.nni
        self.base["nno"] = self.neq.nno

    def p_transform(self):
        """Return code redirecting ``par.p`` to a transformed local ``p2``.

        Returns [] when no transform is needed.  Otherwise ``p2`` is set to
        ``Mp * p + Mpc`` (omitting identity/zero parts) and the statement
        restoring ``par.p`` is appended to ``self.cleanup``.
        """
        l = []
        if self.mp_is_ident and self.mpc_is_zero:
            return l
        par = not self.have_constant_matrices
        par_p = VectorAccess('p', param=True, block=self.neq.p_slice)
        if not self.mp_is_ident:
            if not par:
                self.global_ns.add(MatrixDefinition('Mp', value=self.neq.U))
            s = "%s * %s" % (MatrixAccess('Mp', param=par), par_p)
        else:
            s = "%s" % par_p
        if not self.mpc_is_zero:
            if not par:
                self.global_ns.add(MatrixDefinition('Mpc', value=self.neq.Hc))
            s += " + %s" % VectorAccess('Mpc', param=par, pointer=False)
        g_nn = self.neq.eq.nonlin.nn
        self.local_ns.add(MatrixDefinition("p_old", rows=g_nn, cols=1, pointer=True))
        p2_def = MatrixDefinition("p2", rows=g_nn, cols=1)
        self.local_ns.add(p2_def)
        p2 = VectorAccess(p2_def, block=self.input_slice)
        l += self.local_ns.generate_lines()
        l.append("%s = %s;\n" % (p2, s))
        l.append("p_old = %s;\n" % par_p.address_of())
        l.append("%s = &p2;\n" % par_p.address_of())
        self.cleanup.append("%s = p_old;\n" % par_p.address_of())
        return l

    def i_transform(self):
        """Return code mapping the solver output back through ``Mi``.

        Returns [] when Mi is the identity or the block has no outputs.
        Otherwise appends to ``self.setup`` code that points ``par.i`` at a
        local ``mi``, and returns code restoring ``par.i`` and computing
        ``i = Mi * mi``.
        """
        l = []
        if self.mi_is_ident or self.neq.nno == 0:
            return l
        g_nn = self.neq.eq.nonlin.nn
        mi_def = MatrixDefinition("mi", rows=g_nn, cols=1)
        mi = VectorAccess(mi_def, block=self.output_slice)
        i_old_def = MatrixDefinition("i_old", rows=g_nn, cols=1, pointer=True)
        par_i = VectorAccess('i', param=True, block=self.neq.i_slice)
        self.global_ns.add(MatrixDefinition("Mi", self.neq.Mi))
        self.local_ns.add(mi_def)
        self.local_ns.add(i_old_def)
        self.setup += self.local_ns.generate_lines()
        self.setup.append("i_old = %s;\n" % VectorAccess('i', param=True).address_of())
        self.setup.append("%s = &mi;\n" % VectorAccess('i', param=True).address_of())
        l.append("%s = i_old;\n" % VectorAccess('i', param=True).address_of())
        par = not self.have_constant_matrices
        l.append("%(io)s = %(Mi)s * %(ii)s;" % dict(
            io = par_i,
            Mi = MatrixAccess('Mi', param=par, pointer=False),
            ii = mi,
            ))
        return l

    def make_var_v_ref(self):
        """Return the C++ address of this block's first entry in ``par.v``."""
        start = 0
        if self.neq.v_slice:
            start = self.neq.v_slice.start
        return "&%s(%s)" % (VectorAccess("v", param=True), start)

    def generate(self, template):
        """Render *template* with base values plus transform/setup/cleanup code.

        Resets ``self.setup`` and ``self.cleanup``; ``i_transform`` runs
        before ``p_transform`` because both share ``self.local_ns``.
        """
        d = self.base.copy()
        self.setup = []
        self.cleanup = []
        d["i_transform"] = join_with_indent(self.i_transform())
        d["p_transform"] = join_with_indent(self.p_transform())
        d["setup"] = join_with_indent(self.setup)
        d["cleanup"] = join_with_indent(self.cleanup)
        d["var_v_ref"] = self.make_var_v_ref()
        d["store_p"] = "%s = %s;\n" % (
            VectorAccess('p_val', param=True, block=self.neq.p_slice),
            VectorAccess('p', param=True, block=self.neq.p_slice),
            )
        return template.render(d)


class NonlinSolverCC(NonlinSolver):
    """Solver for the ``cc_slice`` part of a partitioned block.

    Treats U and Mi as identity and sizes the solver by ``neq.cc_slice``.
    """

    def __init__(self, base, glob, neq, solver_dict):
        NonlinSolver.__init__(self, base, glob, neq, solver_dict)
        self.input_slice = self.neq.p_slice
        self.mp_is_ident = True
        self.mi_is_ident = True

    def make_var_v_ref(self):
        """Return the C++ address of ``par.v`` at ``cc_slice.start``."""
        return "&%s(%s)" % (VectorAccess("v", param=True), self.neq.cc_slice.start)

    def generate(self, template):
        """Render with nn/nni/nno overridden by the cc_slice length."""
        self.base = self.base.copy()
        self.base["nn"] = self.base["nni"] = self.base["nno"] = self.neq.cc_slice.stop - self.neq.cc_slice.start
        return NonlinSolver.generate(self, template)


class NonlinCode(object):
    """Generate the code of one nonlinear block (residual function + solver)."""

    # solver method -> (function template, solver template).
    # Note: "hybr" currently uses the same templates as "hybrCC"; "table"
    # is not listed and is handled by CodeGenerator.gen_nonlin instead.
    method_templates = {
        'hybr': (dk_templates.c_template_nonlin_func_hybrCC,
                 dk_templates.c_template_nonlin_solver_hybrCC),
        'lm': (dk_templates.c_template_nonlin_func_lm,
               dk_templates.c_template_nonlin_solver_lm),
        'hybrCC': (dk_templates.c_template_nonlin_func_hybrCC,
                   dk_templates.c_template_nonlin_solver_hybrCC),
        }

    def __init__(self, struct, neq, solver_dict, extra_sources):
        self.struct = struct
        self.neq = neq
        self.solver_dict = solver_dict
        self.extra_sources = extra_sources

    def setup(self, d):
        """Return the template values shared by function and solver.

        Reads "v0_guess" and "dev_interface" from *d*.  Also sets
        ``self.blockV``/``self.pblockV`` (slices used for par_v/par_p) and
        a fresh ``self.glob`` namespace for global matrix definitions.
        """
        neq = self.neq
        base = {}
        base["v0_guess"] = d["v0_guess"]
        base["dev_interface"] = d["dev_interface"]
        g_nonlin = neq.eq.nonlin
        base["extern_nonlin"] = not neq.subblocks or neq is g_nonlin
        base["g_nn"] = g_nonlin.nn
        base["g_nni"] = g_nonlin.nni
        base["g_nno"] = g_nonlin.nno
        base["npl"] = neq.eq.get_npl()
        nn, nni, nno = neq.nn, neq.nni, neq.nno
        base["nn"] = nn
        base["nni"] = nni
        base["nno"] = nno
        if nn != g_nonlin.nn:
            # sub-block: address its part of the global vectors
            self.blockV = neq.v_slice
            self.pblockV = neq.p_slice
            base["pblockV"] = VectorAccess.block_expr(neq.p_slice)
            base["iblockV"] = VectorAccess.block_expr(neq.i_slice)
        else:
            self.blockV = None
            if nn != nni:
                self.pblockV = slice(0, nni)
                base["pblockV"] = VectorAccess.block_expr([0, nni])
            else:
                self.pblockV = None
                base["pblockV"] = ""
            if nn != nno:
                base["iblockV"] = VectorAccess.block_expr([0, nno])
            else:
                base["iblockV"] = ""
        if self.solver_dict:
            for k, v in self.solver_dict.items():
                base["solver_"+k] = v
        ini = dict(
            p = MatrixAccess('mp'),
            i = MatrixAccess('mi'),
            v = MatrixAccess('Mv'),
            info = 'info',
            fnorm = 'fnorm',
            nfev = 'nfev',
            p_val = MatrixAccess('p_val'),
            )
        base["nonlin_mat_list"] = self.struct.get_initializer(ini)
        self.glob = Namespace()
        base["namespace"] = neq.namespace
        base["global_data_def"] = self.glob
        return base

    def generate(self, d):
        """Return the rendered C++ code for this block."""
        base = self.setup(d)
        d = base.copy()
        neq = self.neq
        func_t, solv_t = self.method_templates[self.solver_dict["method"]]
        if isinstance(neq, dk_simulator.PartitionedNonlinEquations):
            of = NonlinFunctionCC(self.glob, neq, self.extra_sources)
            on = NonlinSolverCC(base, self.glob, neq, self.solver_dict)
        else:
            on = NonlinSolver(base, self.glob, neq, self.solver_dict)
            of = NonlinFunction(self.glob, neq, on)
        if func_t is not None:
            d["fcn_def"] = of.generate(func_t)
        d["nonlin_def"] = on.generate(solv_t)
        d["par_p"] = VectorAccess("p", param=True, block=self.pblockV)
        d["par_v"] = VectorAccess("v", param=True, block=self.blockV)
        return dk_templates.c_template_nonlin_solver.render(d)


class NonlinChained(NonlinCode):
    """Generate code solving the sub-blocks of a chained system in order.

    Before each sub-block (except one starting at p index 0) its p inputs
    get ``K.block(st, 0) * i.head(st)`` added, i.e. the coupling from the
    outputs of the preceding sub-blocks.  ``par.p`` is restored on every
    exit path.
    """

    def __init__(self, s, neq):
        NonlinCode.__init__(self, s, neq, None, None)

    def generate(self, d):
        """Return the rendered chained-solver code."""
        d = self.setup(d)
        self.glob.add(MatrixDefinition('K', self.neq.eq.K))
        template = dk_templates.c_template_nonlin_chained
        l = []
        l.append("%s = %s;\n" % (
            VectorAccess('p_val', param=True, block=self.neq.p_slice),
            VectorAccess('p', param=True, block=self.neq.p_slice),
            ))
        g_nn = self.neq.eq.nonlin.nn
        local_ns = Namespace()
        local_ns.add(MatrixDefinition("p_old", rows=g_nn, cols=1, pointer=True))
        local_ns.add(MatrixDefinition("p2", rows=g_nn, cols=1))
        l += local_ns.generate_lines()
        l.append("int ret = 0;\n")
        l.append("p2 = *par.p;\n")
        l.append("p_old = par.p;\n")
        l.append("par.p = &p2;\n")
        for nlin in self.neq.subblocks:
            st = nlin.p_slice.start
            if st:
                l.append("p2.segment<%(ln)d>(%(st)d) += K.block<%(ln)d,%(st)d>(%(st)d,0) * (*par.i).head<%(st)d>();\n"
                         % dict(ln=nlin.p_slice.stop-st, st=st))
            l.append("ret = %s::nonlin_solve(par);\n" % nlin.namespace)
            l.append("if (ret != 0) {\n")
            l.append(" par.p = p_old;\n")
            l.append(" return ret;\n")
            l.append("}\n")
        l.append("par.p = p_old;\n")
        l.append("return 0;\n")
        d["chained_code"] = join_with_indent(l)
        return template.render(d)


class TableCode(object):
    """Generate code evaluating a nonlinear block from spline tables.

    The tables come from ``extra_sources["tables"][neq.namespace]``.
    """

    def __init__(self, struct, neq, extra_sources):
        self.struct = struct
        self.neq = neq
        self.extra_sources = extra_sources

    def add(self, d):
        """Return the template values for ``generate``.

        Reads "dev_interface" from *d*.  Sets ``self.blockV``,
        ``self.pblockV``, ``self.iblockV`` and a fresh ``self.glob``.
        """
        neq = self.neq
        base = {}
        base["dev_interface"] = d["dev_interface"]
        g_nonlin = neq.eq.nonlin
        base["g_nn"] = g_nonlin.nn
        base["g_nni"] = g_nonlin.nni
        base["g_nno"] = g_nonlin.nno
        base["npl"] = neq.eq.get_npl()
        nn = base["nn"] = neq.nn
        nni = base["nni"] = neq.nni
        nno = base["nno"] = neq.nno
        if neq.v_slice:
            self.blockV = neq.v_slice
            self.pblockV = neq.p_slice
            self.iblockV = neq.i_slice
            base["blockV"] = VectorAccess.block_expr(neq.v_slice)
            base["pblockV"] = VectorAccess.block_expr(neq.p_slice)
            base["iblockV"] = VectorAccess.block_expr(neq.i_slice)
        else:
            self.blockV = None
            base["blockV"] = ""
            if nn != nni:
                self.pblockV = slice(0, nni)
                base["pblockV"] = VectorAccess.block_expr([0, nni])
            else:
                self.pblockV = None
                base["pblockV"] = ""
            # self.iblockV was previously left unset on this path
            if nn != nno:
                self.iblockV = slice(0, nno)
                base["iblockV"] = VectorAccess.block_expr([0, nno])
            else:
                self.iblockV = None
                base["iblockV"] = ""
        ini = dict(
            p = MatrixAccess('mp'),
            i = MatrixAccess('mi'),
            v = MatrixAccess('Mv'),
            info = '&g_info',
            fnorm = 'fnorm',
            nfev = '&g_nfev',
            p_val = MatrixAccess('p_val'),
            )
        base["nonlin_mat_list"] = self.struct.get_initializer(ini)
        self.glob = Namespace()
        d = base.copy()
        d["namespace"] = neq.namespace
        d["global_data_def"] = self.glob
        d["par_p"] = VectorAccess("p", param=True, block=self.pblockV)
        # par_v addresses the v vector, so it takes the v slice (as in
        # NonlinCode.generate), not the i slice.
        d["par_v"] = VectorAccess("v", param=True, block=self.blockV)
        return d

    def generate(self, d):
        """Return the rendered table-evaluation code for this block."""
        d = self.add(d)
        neq = self.neq
        tables = self.extra_sources["tables"]
        l = []
        l.append("real t[AmpData::%(namespace)s::sd.m];\n" % d)
        l.append("real m[%(nni)d+%(npl)d];\n" % d)
        l.append("Map<Matrix<real, %(nni)d+%(npl)d, 1> >mp(m);\n" % d)
        l.append("mp << last_pot.cast<real>(), (*par.p)%(pblockV)s.cast<real>();\n" % d)
        for j, kn in enumerate(tables[neq.namespace].knot_data):
            # indices of the knot vectors actually used; if any is unused,
            # the spline input is gathered into pt2
            ll = [i for i, v in enumerate(kn) if v.used()]
            reorder = len(ll) != len(kn)
            # spline type taken from the first knot vector (same rule as
            # NonlinFunctionCC) instead of a loop variable leaked from the
            # scan above
            fu = "splev"
            if kn[0].tp == 'pp':
                fu = "splev_pp"
            if reorder:
                l.append("{ Array<creal, %d, 1> pt2; pt2 << %s;\n" % (len(ll), ", ".join(["mp(%d)" % i for i in ll])))
                inpt = "&pt2(0)"
            else:
                inpt = "m"
            l.append("splinedata<AmpData::%s::maptype>::%s<%s>(&AmpData::%s::sd.sc[%d], %s, &t[%d]);\n"
                     % (neq.namespace, fu, ",".join([str(v.get_order()) for v in kn if v.used()]), neq.namespace, j, inpt, j))
            if reorder:
                l.append("}\n")
        l.append("(*par.i)%(iblockV)s = Map<Matrix<real, %(nno)d, 1> >(t).cast<creal>();\n" % d)
        d["call"] = join_with_indent(l)
        return dk_templates.c_template_table.render(d)


class UpdateMatrix(object):
    """Generate code recomputing the potentiometer-dependent matrices.

    The generated code forms ``Qi = (Q + diag(Rv)).inverse()`` from the
    potentiometer expressions and updates each system matrix as
    ``M = M0 - T * U`` (T derived from Qi).
    """

    def __init__(self, glob, eq, pot, pot_list, pot_func, Pv):
        self.glob = glob
        self.eq = eq
        self.pot = pot
        self.pot_func = pot_func
        self.pot_list = pot_list
        self.Pv = Pv

    def trans_line(self, res, var, var_data, t, u, u_data):
        """Register the matrices involved and return one update statement.

        Returns ``"res = var;"`` when *t* is None, else
        ``"res = var - t * u;"``.  *u* is registered as a global matrix
        only when *u_data* is given.
        """
        if not isinstance(res, VariableAccess):
            res = MatrixAccess(res)
        if not isinstance(var, VariableAccess):
            var = MatrixAccess(var)
        self.glob.add(MatrixDefinition(var, var_data))
        self.glob.add(MatrixDefinition(res, rows=var_data.shape[0], cols=var_data.shape[1]))
        s = "%s = %s" % (res, var)
        if t is None:
            return s + ";\n"
        if u_data is not None:
            self.glob.add(MatrixDefinition(u, u_data))
        return s + " - %s * %s;\n" % (t, u)

    def generate(self, struct):
        """Return {"pot_vars", "pot", "update_pot"} and extend *struct*.

        Adds the updated matrices (Mx/Mxc when there are state variables,
        Mo/Moc always, Mp/Mpc/K when the nonlinear part has inputs) to
        *struct*.  Unknown potentiometer names are appended to
        ``self.pot_list``.
        """
        eq = self.eq
        d = {}
        pot = make_symbol_vector('pot', eq.np)
        l = []
        for (a, f), p in zip(self.pot_func, self.Pv):
            s = str(a)
            try:
                i = self.pot_list.index(s)
            except ValueError:
                self.pot_list.append(s)
                i = len(self.pot_list)-1
            expr = f.subs(a, pot[i]) * p
            l.append(expr)
        d["pot_vars"] = ",".join(['"%s"' % v for v in self.pot_list])
        d["pot"] = ",".join([str(self.pot.get(v,0.5)) for v in self.pot_list])
        nx = eq.nx
        no = eq.no
        npot = eq.np   # named so it does not shadow the numpy alias "np"
        nn = eq.nn
        loc = Namespace()
        loc.add(MatrixDeclaration("Rv", rows=npot, cols=1))
        lines = []
        lines += loc.generate_lines()
        lines += expr_list_to_ccode('Rv', l, '(%d)')
        self.glob.add(MatrixDefinition("Q", eq.Q))
        loc.add(MatrixDeclaration("Qi", rows=npot, cols=npot))
        lines += loc.generate_lines()
        lines.append("Qi = (%s + Matrix<creal, %d, %d>(Rv.asDiagonal())).inverse();\n" % (MatrixAccess('Q'), npot, npot))
        if eq.nx:
            if matrix_is_identity(eq.Uxl):
                t = "Qi"
            elif matrix_is_zero(eq.Uxl):
                t = None
            else:
                loc.add(MatrixDeclaration("Tx", rows=nx, cols=npot))
                lines += loc.generate_lines()
                self.glob.add(MatrixDefinition("Uxl", eq.Uxl))
                lines.append("Tx = %s * Qi;\n" % MatrixAccess('Uxl'))
                t = "Tx"
            Mx_mat = eq.get_Mx()
            m = MatrixDeclaration('Mx', rows=Mx_mat.shape[0], cols=Mx_mat.shape[1])
            struct.add(m)
            Mx = MatrixAccess(m, param=True)
            lines.append(self.trans_line(Mx, 'Mx', Mx_mat, t, 'UR', eq.get_UR()))
            m = MatrixDeclaration('Mxc', rows=eq.Bc.shape[0], cols=1)
            struct.add(m)
            Mxc = VectorAccess(m, param=True)
            lines.append(self.trans_line(Mxc, 'Mxc', eq.Bc, t, 'Ucv', eq.Ucv.T))
        if matrix_is_identity(eq.Uo):
            t = "Qi"
        elif matrix_is_zero(eq.Uo):
            t = None
        else:
            loc.add(MatrixDeclaration("To", rows=no, cols=npot))
            lines += loc.generate_lines()
            self.glob.add(MatrixDefinition('Uo', eq.Uo))
            lines.append("To = %s * Qi;\n" % MatrixAccess('Uo'))
            t = "To"
        Mo_mat = eq.get_Mo()
        m = MatrixDeclaration('Mo', rows=Mo_mat.shape[0], cols=Mo_mat.shape[1])
        struct.add(m)
        Mo = MatrixAccess(m, param=True)
        lines.append(self.trans_line(Mo, 'Mo', Mo_mat, t, 'UR', eq.get_UR()))
        m = MatrixDeclaration('Moc', rows=eq.Ec.shape[0], cols=1)
        struct.add(m)
        Moc = VectorAccess(m, param=True)
        lines.append(self.trans_line(Moc, 'Moc', eq.Ec, t, 'Ucv', eq.Ucv.T))
        if eq.nonlin and eq.nonlin.nni:
            if matrix_is_identity(eq.Unl):
                t = "Qi"
            elif matrix_is_zero(eq.Unl):
                t = None
            else:
                loc.add(MatrixDeclaration("Tp", rows=nn, cols=npot))
                lines += loc.generate_lines()
                self.glob.add(MatrixDefinition("Unl", eq.Unl))
                lines.append("Tp = %s * Qi;\n" % MatrixAccess('Unl'))
                t = "Tp"
            Mp_mat = eq.get_Mp()
            struct.add(MatrixDeclaration('Mp', rows=Mp_mat.shape[0], cols=Mp_mat.shape[1]))
            # NOTE(review): this uses the attribute eq.mp_cols while the K
            # update below calls eq.get_mp_cols(); confirm both agree.
            lines.append(self.trans_line(MatrixAccess('Mp', param=True, pointer=False), 'Mp', Mp_mat, t, 'UR.block<%d, %d>(0, 0)' % (eq.np, eq.mp_cols), None))
            struct.add(MatrixDeclaration('Mpc', rows=eq.nonlin.Hc.shape[0], cols=1))
            lines.append(self.trans_line(VectorAccess('Mpc', param=True, pointer=False), 'Mpc', eq.nonlin.Hc, t, 'Ucv', eq.Ucv.T))
            struct.add(MatrixDeclaration('K', rows=eq.K.shape[0], cols=eq.K.shape[1]))
            mp_cols = eq.get_mp_cols()
            lines.append(self.trans_line(MatrixAccess('K', param=True, pointer=False), 'K', eq.K, t, 'UR.block<%d, %d>(0, %d)' % (eq.np, eq.get_mx_cols()-mp_cols, mp_cols), None))
        d["update_pot"] = join_with_indent(lines)
        return d


class LinearCode(object):
    """Generate the linear update code (mp, xn, xo) of the system."""

    def __init__(self, glob, eq):
        self.glob = glob
        self.eq = eq

    def generate(self):
        """Return template values m_cols, gen_mp, gen_xn, gen_xo, gen_xo_float.

        Matrices are read from the parameter struct when the system has
        potentiometers (eq.np != 0), otherwise from global definitions.
        """
        eq = self.eq
        param = bool(eq.np)
        Mp = MatrixAccess('Mp', param=param, pointer=False)
        Mo = MatrixAccess('Mo', param=param, pointer=False)
        Moc = VectorAccess('Moc', param=param, pointer=False)
        Mx = MatrixAccess('Mx', param=param, pointer=False)
        Mxc = VectorAccess('Mxc', param=param, pointer=False)
        d = {}
        if eq.nonlin and eq.nn != eq.nonlin.nni:
            pblock = slice(0, eq.nonlin.nni)
        else:
            pblock = None
        d["m_cols"] = eq.get_mx_cols()
        d["gen_mp"] = gen_linear_combination(
            self.glob, VectorAccess('mp',block=pblock), 'dp', Mp, eq.get_Mp())
        d["gen_xn"] = gen_linear_combination(
            self.glob, 'xn', 'd', Mx, eq.get_Mx(), Mxc, eq.Bc)
        d["gen_xo"] = gen_linear_combination(
            self.glob, 'xo', 'd', Mo, eq.get_Mo(), Moc, eq.Ec)
        d["gen_xo_float"] = gen_linear_combination(
            self.glob, 'xo', 'd', Mo, eq.get_Mo(), cast="float")
        return d


class NonlinEq(object):
    """Plain record holding the sizes and matrices of a nonlinear system."""

    def __init__(self, nn, nni, nno, np, npl, K, CZ, f, U, Hc, Mi, Kn):
        self.nn = nn
        self.nni = nni
        self.nno = nno
        self.np = np
        self.npl = npl
        self.K = K
        self.CZ = CZ
        self.f = f
        self.U = U
        self.Hc = Hc
        self.Mi = Mi
        self.Kn = Kn


class LV2_Port_List(object):
    """Iterable of LV2 port descriptions for the .ttl template.

    Order: ``ni`` audio inputs, ``no`` audio outputs, then one control
    port per entry of *pot_attr* (defaults taken from *pot*, else 0.5).
    """

    def __init__(self, pot_attr, pot, eq):
        self.pot_attr = pot_attr
        self.pot = pot
        self.ni = eq.ni
        self.no = eq.no

    def port_count(self):
        """Return the total number of ports (same as ``len(self)``)."""
        return len(self)

    def __len__(self):
        return self.ni + self.no + len(self.pot_attr)

    def __iter__(self):
        idx = 0
        max_idx = len(self) - 1
        # NOTE(review): only the last port gets separator ","; if the ttl
        # template puts the separator between port entries this looks
        # inverted — confirm against dk_templates.lv2_ttl.
        for i in range(self.ni):
            yield dict(
                type_list="lv2:AudioPort , lv2:InputPort",
                index = idx,
                symbol = "in%d" % i,
                name = "In%d" % i,
                control_index = -1,
                separator = "," if idx == max_idx else "",
                )
            idx += 1
        for i in range(self.no):
            yield dict(
                type_list="lv2:AudioPort , lv2:OutputPort",
                index = idx,
                symbol = "out%d" % i,
                name = "Out%d" % i,
                control_index = -1,
                separator = "," if idx == max_idx else "",
                )
            idx += 1
        for i, row in enumerate(self.pot_attr):
            var = row[0]
            name = row[1]
            yield dict(
                type_list="lv2:InputPort , lv2:ControlPort",
                index = idx,
                symbol = var,
                name = name,
                default = self.pot.get(var, 0.5),
                minimum = 0.0,
                maximum = 1.0,
                control_index = i,
                separator = "," if idx == max_idx else "",
                )
            idx += 1


class CodeGenerator(object):
    """Top-level generator: collect all template values and render output.

    ``generate`` returns a dict of output name -> text: the C++ source
    under "c_source" and, for LV2 plugins, the manifest and plugin .ttl.
    """

    def __init__(self, eq, solver_dict, solver_params, pot, pot_list, pot_func, pot_attr, Pv, extra_sources):
        self.eq = eq
        self.solver_dict = solver_set_defaults(solver_dict)
        self.solver_params = solver_params
        self.pot = pot
        self.pot_list = pot_list
        self.pot_func = pot_func
        self.pot_attr = pot_attr
        self.Pv = Pv
        self.extra_sources = extra_sources
        self.have_constant_matrices = (len(pot_func) == 0)

    def pot_code(self):
        """Return template values for the potentiometer controls.

        ``calc_pots`` holds one C statement per pot mapping
        ``self.pots[i]`` to ``t[i]``, optionally inverted (1 - x) and/or
        with an exponential taper ``(exp(a*x) - 1) / (exp(a) - 1)``.
        """
        d = {}
        if self.pot_attr:
            d['have_master_slider'] = True
            d['master_slider_id'] = self.pot_attr[0][0]
        else:
            d['have_master_slider'] = False
        d['knob_ids'] = [t[0] for t in self.pot_attr]
        d['timecst'] = 0.01
        d['regs'] = [dict(id=vv[0],name=vv[1],desc="",varidx=i) for i, vv in enumerate(self.pot_attr)]
        ll = []
        for i, (var, name, loga, inv, expr) in enumerate(self.pot_attr):
            if loga and inv:
                ss = "t[%d] = (exp(%s * (1-self.pots[%d])) - 1) / (exp(%s) - 1);" % (i, loga, i, loga)
            elif loga:
                ss = "t[%d] = (exp(%s * self.pots[%d]) - 1) / (exp(%s) - 1);" % (i, loga, i, loga)
            elif inv:
                ss = "t[%d] = 1-self.pots[%d];" % (i, i)
            else:
                ss = "t[%d] = self.pots[%d];" % (i, i)
            ll.append(ss)
        d['calc_pots'] = "\n ".join(ll)
        return d

    @staticmethod
    def walk_eqs(nlin):
        """Yield *nlin* and, recursively, all its sub-blocks in pre-order.

        Plain ``dk_simulator.NonlinEquations`` objects are leaves.
        """
        yield nlin
        if nlin.__class__ is dk_simulator.NonlinEquations:
            return
        for bl in nlin.subblocks:
            for neq in CodeGenerator.walk_eqs(bl):
                yield neq

    def gen_nonlin(self, nonlin, s, d, complist, code):
        """Append generated code for *nonlin* (and its sub-blocks) to *code*.

        Sub-blocks are generated before their parent.  Each leaf block is
        also described in *complist*.  Per-component ``solver_params``
        (keyed by component name) may supply a custom "generator" and a
        "v0_guess"; ``d["v0_guess"]`` is overwritten for every leaf.
        """
        if isinstance(nonlin, dk_simulator.PartitionedNonlinEquations):
            for nlin in nonlin.subblocks:
                self.gen_nonlin(nlin, s, d, complist, code)
            if self.solver_dict["method"] == "table":
                # the coupling part of a table-based system is still solved
                solver = self.solver_dict.copy()
                solver["method"] = "hybrCC"
            else:
                solver = self.solver_dict
            code.append(NonlinCode(s, nonlin, solver, self.extra_sources).generate(d))
        elif isinstance(nonlin, dk_simulator.ChainedNonlinEquations):
            for nlin in nonlin.subblocks:
                self.gen_nonlin(nlin, s, d, complist, code)
            code.append(NonlinChained(s, nonlin).generate(d))
        else:
            complist.append(dict(
                name = nonlin.name,
                namespace = nonlin.namespace,
                pins = ",".join(dk_simulator.Parser.format_element(v) for v in nonlin.pins),
                v_slice = nonlin.v_slice,
                p_slice = nonlin.p_slice,
                i_slice = nonlin.i_slice,
                ))
            generator = None
            d.overwrite("v0_guess", "")
            if self.solver_params:
                try:
                    params = self.solver_params[nonlin.name]
                except (KeyError, TypeError):
                    # no entry for this component, or solver_params is not
                    # indexable by name
                    pass
                else:
                    generator = params.get("generator")
                    d.overwrite("v0_guess", params.get("v0_guess"))
            if generator:
                code.append(generator(s, nonlin, self.extra_sources, d))
            elif self.solver_dict["method"] == "table":
                code.append(TableCode(s, nonlin, self.extra_sources).generate(d))
            else:
                code.append(NonlinCode(s, nonlin, self.solver_dict, None).generate(d))

    def add_dict(self, d):
        """Fill *d* (in place) with all values for the top-level template.

        Also assigns ``namespace`` to every nonlinear block ("nonlin" for
        the top block, "nonlin_<k>" for the others).  Returns *d*.
        """
        eq = self.eq
        s = Structure("nonlin_param")
        par_ini = {}
        glob = Namespace()
        if self.have_constant_matrices:
            d["update_pot"] = ""
        else:
            d.update(UpdateMatrix(glob, eq, self.pot, self.pot_list, self.pot_func, self.Pv).generate(s))
        if eq.nonlin and eq.nonlin.nno:
            for k, v in self.solver_dict.items():
                d["solver_"+k] = v
            s.add(MatrixDeclaration('p', rows=eq.nn, cols=1, pointer=True))
            s.add(MatrixDeclaration('i', rows=eq.nn, cols=1, pointer=True))
            s.add(MatrixDeclaration('v', rows=eq.nn, cols=1, mapping=True, pointer=True, aligned=False))
            s.add_nonlin_info()
            s.add_minmax_info(eq.nonlin.nni)
            par_ini.update(dict(
                v = '&g_v',
                info = '&g_info',
                fnorm = '0',
                nfev = '&g_nfev',
                p = '0',
                i = '0',
                p_val = '0',
                ))
            ini = dict(
                p = MatrixAccess('mp'),
                i = MatrixAccess('mi'),
                v = '&g_v',
                info = '&g_info',
                fnorm = '&fnorm',
                nfev = '&g_nfev',
                p_val = MatrixAccess('p_val'),
                )
            d["nonlin_mat_list_calc"] = s.get_initializer(ini)
            for i, neq in enumerate(self.walk_eqs(eq.nonlin)):
                if i == 0:
                    neq.namespace = "nonlin"
                else:
                    neq.namespace = "nonlin_%d" % (i-1)
            complist = []
            code = []
            self.gen_nonlin(eq.nonlin, s, d, complist, code)
            d["nc"] = len(complist)
            d["components"] = complist
            d["nonlin_code"] = "".join(code)
            if eq.nn != eq.nonlin.nno:
                d["iblock"] = VectorAccess.block_expr([0, eq.nonlin.nno])
            else:
                d["iblock"] = ""
        else:
            d["nonlin_code"] = ""
            d["nc"] = 0
            d["iblock"] = ""
        d["nonlin_mat_list"] = s.get_initializer(par_ini)
        d.update(self.pot_code())
        d["struct_def"] = s
        d.update(LinearCode(glob, eq).generate())
        d["global_matrices"] = glob
        d["add_npl"] = 0
        d["DKPlugin_fields"] = ""
        d["DKPlugin_init"] = ""
        d["process_add"] = ""
        return d

    def generate(self, d):
        """Render and return {output name: text}.

        Always contains "c_source"; when ``d["plugindef"].lv2_plugin_type``
        is set, also "manifest.ttl" and "<lv2_versioned_id>.ttl".
        """
        d = self.add_dict(d)
        d["lv2_ports"] = LV2_Port_List(self.pot_attr, self.pot, self.eq)
        out = dict(c_source = dk_templates.c_template_top.render(d))
        plugindef = d["plugindef"]
        if plugindef.lv2_plugin_type:
            out["manifest.ttl"] = dk_templates.lv2_manifest.render(d)
            out["%s.ttl" % plugindef.lv2_versioned_id] = dk_templates.lv2_ttl.render(d)
        return out