1 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
2 // Copyright (c) Lawrence Livermore National Security, LLC and other Ascent
3 // Project developers. See top-level LICENSE AND COPYRIGHT files for dates and
4 // other details. No copyright assignment is required to contribute to Ascent.
5 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
6
7 //-----------------------------------------------------------------------------
8 ///
9 /// file: ascent_jit_fusion.cpp
10 ///
11 //-----------------------------------------------------------------------------
12
13 #include "ascent_jit_fusion.hpp"
14 #include "ascent_blueprint_topologies.hpp"
15 #include "ascent_blueprint_architect.hpp"
16 #include <ascent_logging.hpp>
17
18 //-----------------------------------------------------------------------------
19 // -- begin ascent:: --
20 //-----------------------------------------------------------------------------
21 namespace ascent
22 {
23
24 //-----------------------------------------------------------------------------
25 // -- begin ascent::runtime --
26 //-----------------------------------------------------------------------------
27 namespace runtime
28 {
29
30 //-----------------------------------------------------------------------------
31 // -- begin ascent::runtime::expressions--
32 //-----------------------------------------------------------------------------
33 namespace expressions
34 {
35
36
37
JitableFusion(const conduit::Node & params,const std::vector<const Jitable * > & input_jitables,const std::vector<const Kernel * > & input_kernels,const std::string & filter_name,const conduit::Node & dataset,const int dom_idx,const bool not_fused,Jitable & out_jitable,Kernel & out_kernel)38 JitableFusion::JitableFusion(
39 const conduit::Node ¶ms,
40 const std::vector<const Jitable *> &input_jitables,
41 const std::vector<const Kernel *> &input_kernels,
42 const std::string &filter_name,
43 const conduit::Node &dataset,
44 const int dom_idx,
45 const bool not_fused,
46 Jitable &out_jitable,
47 Kernel &out_kernel)
48 : params(params), input_jitables(input_jitables),
49 input_kernels(input_kernels), filter_name(filter_name), dataset(dataset),
50 dom_idx(dom_idx), not_fused(not_fused), out_jitable(out_jitable),
51 out_kernel(out_kernel), inputs(params["inputs"]),
52 domain(dataset.child(dom_idx))
53 {
54 }
55
56 void
binary_op()57 JitableFusion::binary_op()
58 {
59 if(not_fused)
60 {
61 const int lhs_port = inputs["lhs/port"].to_int32();
62 const int rhs_port = inputs["rhs/port"].to_int32();
63 const Kernel &lhs_kernel = *input_kernels[lhs_port];
64 const Kernel &rhs_kernel = *input_kernels[rhs_port];
65 // union the field/mesh vars
66 out_kernel.fuse_kernel(lhs_kernel);
67 out_kernel.fuse_kernel(rhs_kernel);
68 const std::string lhs_expr = lhs_kernel.expr;
69 const std::string rhs_expr = rhs_kernel.expr;
70 const std::string &op_str = params["op_string"].as_string();
71 if(lhs_kernel.num_components == 1 && rhs_kernel.num_components == 1)
72 {
73 // scalar ops
74 if(op_str == "not")
75 {
76 out_kernel.expr = "!(" + rhs_expr + ")";
77 }
78 else
79 {
80 std::string occa_op_str;
81 if(op_str == "and")
82 {
83 occa_op_str = "&&";
84 }
85 else if(op_str == "or")
86 {
87 occa_op_str = "||";
88 }
89 else
90 {
91 occa_op_str = op_str;
92 }
93 out_kernel.expr = "(" + lhs_expr + " " + op_str + " " + rhs_expr + ")";
94 }
95 out_kernel.num_components = 1;
96 }
97 else
98 {
99 // vector ops
100 bool error = false;
101 if(lhs_kernel.num_components == rhs_kernel.num_components)
102 {
103 if(op_str == "+")
104 {
105 MathCode().vector_add(out_kernel.for_body,
106 lhs_expr,
107 rhs_expr,
108 filter_name,
109 lhs_kernel.num_components);
110
111 out_kernel.num_components = lhs_kernel.num_components;
112 }
113 else if(op_str == "-")
114 {
115 MathCode().vector_subtract(out_kernel.for_body,
116 lhs_expr,
117 rhs_expr,
118 filter_name,
119 lhs_kernel.num_components);
120
121 out_kernel.num_components = lhs_kernel.num_components;
122 }
123 else if(op_str == "*")
124 {
125 MathCode().dot_product(out_kernel.for_body,
126 lhs_expr,
127 rhs_expr,
128 filter_name,
129 lhs_kernel.num_components);
130
131 out_kernel.num_components = 1;
132 }
133 else
134 {
135 error = true;
136 }
137 out_kernel.expr = filter_name;
138 }
139 else
140 {
141 error = true;
142 }
143 if(error)
144 {
145 ASCENT_ERROR("Unsupported binary_op: (field with "
146 << lhs_kernel.num_components << " components) " << op_str
147 << " (field with " << rhs_kernel.num_components
148 << " components).");
149 }
150 }
151 }
152 else
153 {
154 // kernel of this type has already been fused, do nothing on a
155 // per-domain basis
156 }
157 }
158
159 void
builtin_functions(const std::string & function_name)160 JitableFusion::builtin_functions(const std::string &function_name)
161 {
162 if(not_fused)
163 {
164 out_kernel.expr = function_name + "(";
165 const int num_inputs = inputs.number_of_children();
166 for(int i = 0; i < num_inputs; ++i)
167 {
168 const int port_num = inputs.child(i)["port"].to_int32();
169 const Kernel &inp_kernel = *input_kernels[port_num];
170 if(inp_kernel.num_components > 1)
171 {
172 ASCENT_ERROR("Built-in function '"
173 << function_name
174 << "' does not support vector fields with "
175 << inp_kernel.num_components << " components.");
176 }
177 const std::string &inp_expr = inp_kernel.expr;
178 out_kernel.fuse_kernel(inp_kernel);
179 if(i != 0)
180 {
181 out_kernel.expr += ", ";
182 }
183 out_kernel.expr += inp_expr;
184 }
185 out_kernel.expr += ")";
186 out_kernel.num_components = 1;
187 }
188 }
189
190 bool
available_component(const std::string & axis,const int num_axes)191 available_component(const std::string &axis, const int num_axes)
192 {
193 // if a field has only 1 component it doesn't have .x .y .z
194 if((axis == "x" && num_axes >= 2) || (axis == "y" && num_axes >= 2) ||
195 (axis == "z" && num_axes >= 3))
196 {
197 return true;
198 }
199 ASCENT_ERROR("Derived field with "
200 << num_axes << " components does not have component '" << axis
201 << "'.");
202 return false;
203 }
204
205 bool
available_axis(const std::string & axis,const int num_axes,const std::string & topo_name)206 available_axis(const std::string &axis,
207 const int num_axes,
208 const std::string &topo_name)
209 {
210 if(((axis == "x" || axis == "dx") && num_axes >= 1) ||
211 ((axis == "y" || axis == "dy") && num_axes >= 2) ||
212 ((axis == "z" || axis == "dz") && num_axes >= 3))
213 {
214 return true;
215 }
216 ASCENT_ERROR("Topology '" << topo_name << "' with " << num_axes
217 << " dimensions does not have axis '" << axis
218 << "'.");
219 return false;
220 }
221
222 void
topo_attrs(const conduit::Node & obj,const std::string & name)223 JitableFusion::topo_attrs(const conduit::Node &obj, const std::string &name)
224 {
225 const std::string &topo_name = obj["value"].as_string();
226 std::unique_ptr<Topology> topo = topologyFactory(topo_name, domain);
227 if(obj.has_path("attr"))
228 {
229 if(not_fused)
230 {
231 const conduit::Node &assoc = obj["attr"].child(0);
232 TopologyCode topo_code =
233 TopologyCode(topo_name, domain, out_jitable.arrays[dom_idx]);
234 if(assoc.name() == "cell")
235 {
236 // x, y, z
237 if(is_xyz(name) && available_axis(name, topo->num_dims, topo_name))
238 {
239 topo_code.element_coord(
240 out_kernel.for_body, name, "", topo_name + "_cell_" + name);
241 out_kernel.expr = topo_name + "_cell_" + name;
242 }
243 // dx, dy, dz
244 else if(name[0] == 'd' && is_xyz(std::string(1, name[1])) &&
245 available_axis(name, topo->num_dims, topo_name))
246 {
247 if(topo->topo_type == "uniform")
248 {
249 out_kernel.expr = topo_name + "_spacing_" + name;
250 }
251 else if(topo->topo_type == "rectilinear")
252 {
253 topo_code.dxdydz(out_kernel.for_body);
254 out_kernel.expr = topo_name + "_" + name;
255 }
256 else
257 {
258 ASCENT_ERROR("Can only get dx, dy, dz for uniform or rectilinear "
259 "topologies, not topologies of type '"
260 << topo->topo_type << "'.");
261 }
262 }
263 else if(name == "volume")
264 {
265 if(topo->num_dims != 3)
266 {
267 ASCENT_ERROR("Cell volume is only defined for topologies with 3 "
268 "dimensions. The specified topology '"
269 << topo->topo_name << "' has " << topo->num_dims
270 << " dimensions.");
271 }
272 topo_code.volume(out_kernel.for_body);
273 out_kernel.expr = topo_name + "_volume";
274 }
275 else if(name == "area")
276 {
277 if(topo->num_dims < 2)
278 {
279 ASCENT_ERROR("Cell area is only defined for topologies at most 2 "
280 "dimensions. The specified topology '"
281 << topo->topo_name << "' has " << topo->num_dims
282 << " dimensions.");
283 }
284 topo_code.area(out_kernel.for_body);
285 out_kernel.expr = topo_name + "_area";
286 }
287 else if(name == "surface_area")
288 {
289 if(topo->num_dims != 3)
290 {
291 ASCENT_ERROR(
292 "Cell surface area is only defined for topologies with 3 "
293 "dimensions. The specified topology '"
294 << topo->topo_name << "' has " << topo->num_dims
295 << " dimensions.");
296 }
297 topo_code.surface_area(out_kernel.for_body);
298 out_kernel.expr = topo_name + "_surface_area";
299 }
300 else if(name == "id")
301 {
302 out_kernel.expr = "item";
303 }
304 else
305 {
306 ASCENT_ERROR("Could not find attribute '"
307 << name << "' of topo.cell at runtime.");
308 }
309 }
310 else if(assoc.name() == "vertex")
311 {
312 if(is_xyz(name) && available_axis(name, topo->num_dims, topo_name))
313 {
314 topo_code.vertex_coord(
315 out_kernel.for_body, name, "", topo_name + "_vertex_" + name);
316 out_kernel.expr = topo_name + "_vertex_" + name;
317 }
318 else if(name == "id")
319 {
320 out_kernel.expr = "item";
321 }
322 else
323 {
324 ASCENT_ERROR("Could not find attribute '"
325 << name << "' of topo.vertex at runtime.");
326 }
327 }
328 else
329 {
330 ASCENT_ERROR("Could not find attribute '" << assoc.name()
331 << "' of topo at runtime.");
332 }
333 }
334 }
335 else
336 {
337 if(name == "cell")
338 {
339 if(topo->topo_type == "points")
340 {
341 ASCENT_ERROR("Point topology '" << topo_name
342 << "' has no cell attributes.");
343 }
344 out_jitable.dom_info.child(dom_idx)["entries"] = topo->get_num_cells();
345 out_jitable.association = "element";
346 }
347 else
348 {
349 out_jitable.dom_info.child(dom_idx)["entries"] = topo->get_num_points();
350 out_jitable.association = "vertex";
351 }
352 out_jitable.obj = obj;
353 out_jitable.obj["attr/" + name];
354 }
355 }
356
357 void
expr_dot()358 JitableFusion::expr_dot()
359 {
360 const int obj_port = inputs["obj/port"].as_int32();
361 const Kernel &obj_kernel = *input_kernels[obj_port];
362 const conduit::Node &obj = input_jitables[obj_port]->obj;
363 const std::string &name = params["name"].as_string();
364 // derived fields or trivial fields
365 if(!obj.has_path("type") || obj["type"].as_string() == "field")
366 {
367 if(is_xyz(name) && available_component(name, obj_kernel.num_components))
368 {
369 out_kernel.expr =
370 obj_kernel.expr + "[" + std::to_string(name[0] - 'x') + "]";
371 out_kernel.num_components = 1;
372 }
373 else
374 {
375
376 ASCENT_ERROR("Could not find attribute '" << name
377 << "' of field at runtime.");
378 }
379 }
380 else if(obj["type"].as_string() == "topo")
381 {
382 // needs to run for every domain not just every kernel type to
383 // populate entries
384 topo_attrs(obj, name);
385 // for now all topology attributes have one component :)
386 out_kernel.num_components = 1;
387 }
388 else
389 {
390 ASCENT_ERROR("JIT: Unknown obj:\n" << obj.to_yaml());
391 }
392 if(not_fused)
393 {
394 out_kernel.fuse_kernel(obj_kernel);
395 }
396 }
397
398 void
expr_if()399 JitableFusion::expr_if()
400 {
401 if(not_fused)
402 {
403 const int condition_port = inputs["condition/port"].as_int32();
404 const int if_port = inputs["if/port"].as_int32();
405 const int else_port = inputs["else/port"].as_int32();
406 const Kernel &condition_kernel = *input_kernels[condition_port];
407 const Kernel &if_kernel = *input_kernels[if_port];
408 const Kernel &else_kernel = *input_kernels[else_port];
409 out_kernel.functions.insert(condition_kernel.functions);
410 out_kernel.functions.insert(if_kernel.functions);
411 out_kernel.functions.insert(else_kernel.functions);
412 out_kernel.kernel_body.insert(condition_kernel.kernel_body);
413 out_kernel.kernel_body.insert(if_kernel.kernel_body);
414 out_kernel.kernel_body.insert(else_kernel.kernel_body);
415 const std::string cond_name = filter_name + "_cond";
416 const std::string res_name = filter_name + "_res";
417
418 out_kernel.for_body.insert(condition_kernel.for_body);
419 out_kernel.for_body.insert(
420 condition_kernel.generate_output(cond_name, true));
421
422 InsertionOrderedSet<std::string> if_else;
423 if_else.insert("double " + res_name + ";\n");
424 if_else.insert("if(" + cond_name + ")\n{\n");
425 if_else.insert(if_kernel.for_body.accumulate() +
426 if_kernel.generate_output(res_name, false));
427 if_else.insert("}\nelse\n{\n");
428 if_else.insert(else_kernel.for_body.accumulate() +
429 else_kernel.generate_output(res_name, false));
430 if_else.insert("}\n");
431
432 out_kernel.for_body.insert(if_else.accumulate());
433 out_kernel.expr = res_name;
434 if(if_kernel.num_components != else_kernel.num_components)
435 {
436 ASCENT_ERROR("Jitable if-else: The if-branch results in "
437 << if_kernel.num_components
438 << " components and the else-branch results in "
439 << else_kernel.num_components
440 << " but they must have the same number of components.");
441 }
442 out_kernel.num_components = if_kernel.num_components;
443 }
444 }
445
446 void
derived_field()447 JitableFusion::derived_field()
448 {
449 // setting association and topology should run once for Jitable not for
450 // each domain, but won't hurt
451 if(inputs.has_path("assoc"))
452 {
453 const conduit::Node &string_obj =
454 input_jitables[inputs["assoc/port"].as_int32()]->obj;
455 const std::string &new_association = string_obj["value"].as_string();
456 if(new_association != "vertex" && new_association != "element")
457 {
458 ASCENT_ERROR("derived_field: Unknown association '"
459 << new_association
460 << "'. Known associations are 'vertex' and 'element'.");
461 }
462 out_jitable.association = new_association;
463 }
464 if(inputs.has_path("topo"))
465 {
466 const conduit::Node &string_obj =
467 input_jitables[inputs["topo/port"].as_int32()]->obj;
468 const std::string &new_topology = string_obj["value"].as_string();
469 // We repeat this check because we pass the topology name as a
470 // string. If we pass a topo object it will get packed which is
471 // unnecessary if we only want to output on the topology.
472 if(!has_topology(dataset, new_topology))
473 {
474 std::set<std::string> topo_names = topology_names(dataset);
475 std::string res;
476 for(auto &name : topo_names)
477 {
478 res += name + " ";
479 }
480 ASCENT_ERROR(": dataset does not contain topology '"
481 << new_topology << "'"
482 << " known = " <<res);
483 }
484 if(!out_jitable.association.empty() && out_jitable.association != "none")
485 {
486 // update entries
487 std::unique_ptr<Topology> topo = topologyFactory(new_topology, domain);
488 conduit::Node &cur_dom_info = out_jitable.dom_info.child(dom_idx);
489 int new_entries = 0;
490 if(out_jitable.association == "vertex")
491 {
492 new_entries = topo->get_num_points();
493 }
494 else if(out_jitable.association == "element")
495 {
496 new_entries = topo->get_num_cells();
497 }
498 // ensure entries doesn't change if it's already defined
499 if(cur_dom_info.has_child("entries"))
500 {
501 const int cur_entries = cur_dom_info["entries"].to_int32();
502 if(new_entries != cur_entries)
503 {
504 ASCENT_ERROR(
505 "derived_field: cannot put a derived field with "
506 << cur_entries << " entries as a " << out_jitable.association
507 << "-associated derived field on the topology '" << new_topology
508 << "' since the resulting field would need to have "
509 << new_entries << " entries.");
510 }
511 }
512 else
513 {
514 cur_dom_info["entries"] = new_entries;
515 }
516 }
517 out_jitable.topology = new_topology;
518 }
519 if(not_fused)
520 {
521 const int arg1_port = inputs["arg1/port"].as_int32();
522 const Kernel &arg1_kernel = *input_kernels[arg1_port];
523 out_kernel.fuse_kernel(arg1_kernel);
524 out_kernel.expr = arg1_kernel.expr;
525 out_kernel.num_components = arg1_kernel.num_components;
526 }
527 }
528
529 // generate a temporary field on the device, used by things like gradient
530 // which have data dependencies that require the entire field to be present.
531 // essentially wrap field_kernel in a for loop
532 void
temporary_field(const Kernel & field_kernel,const std::string & field_name)533 JitableFusion::temporary_field(const Kernel &field_kernel,
534 const std::string &field_name)
535 {
536 // pass the value of entries for the temporary field
537 const auto entries =
538 out_jitable.dom_info.child(dom_idx)["entries"].to_int64();
539 const std::string entries_name = filter_name + "_inp_entries";
540 out_jitable.dom_info.child(dom_idx)["args/" + entries_name] = entries;
541
542 // we will need to allocate a temporary array so make a schema for it and
543 // put it in the array map
544 // TODO for now temporary fields are interleaved
545 conduit::Schema s;
546 schemaFactory("interleaved",
547 conduit::DataType::FLOAT64_ID,
548 entries,
549 field_kernel.num_components,
550 s);
551 // The array will have to be allocated but doesn't point to any data so we
552 // won't put it in args but it will still be passed in
553 out_jitable.arrays[dom_idx].array_map.insert(
554 std::make_pair(field_name, SchemaBool(s, false)));
555 if(not_fused)
556 {
557 // not a regular kernel_fuse because we have to generate a for-loop and add
558 // it to kernel_body instead of fusing for_body
559 out_kernel.functions.insert(field_kernel.functions);
560 out_kernel.kernel_body.insert(field_kernel.kernel_body);
561 out_kernel.kernel_body.insert(field_kernel.generate_loop(
562 field_name, out_jitable.arrays[dom_idx], entries_name));
563 }
564 }
565
566 std::string
possible_temporary(const int field_port)567 JitableFusion::possible_temporary(const int field_port)
568 {
569 const Jitable &field_jitable = *input_jitables[field_port];
570 const Kernel &field_kernel = *input_kernels[field_port];
571 const conduit::Node &obj = field_jitable.obj;
572 std::string field_name;
573 if(obj.has_path("value"))
574 {
575 field_name = obj["value"].as_string();
576 out_kernel.fuse_kernel(field_kernel);
577 }
578 else
579 {
580 field_name = filter_name + "_inp";
581 temporary_field(field_kernel, field_name);
582 }
583 return field_name;
584 }
585
586 void
gradient(const int field_port,const int component)587 JitableFusion::gradient(const int field_port, const int component)
588 {
589 const Kernel &field_kernel = *input_kernels[field_port];
590
591 if(component == -1 && field_kernel.num_components > 1)
592 {
593 ASCENT_ERROR("gradient is only supported on scalar fields but a field with "
594 << field_kernel.num_components << " components was given.");
595 }
596
597 // association and topology should be the same for out_jitable and
598 // field_jitable because we have already fused jitables at this point
599 if(out_jitable.topology.empty() || out_jitable.topology == "none")
600 {
601 ASCENT_ERROR("Could not take the gradient of the derived field because the "
602 "associated topology could not be determined.");
603 }
604
605 if(out_jitable.association.empty() || out_jitable.association == "none")
606 {
607 ASCENT_ERROR("Could not take the gradient of the derived field "
608 "because the association could not be determined.");
609 }
610
611 std::unique_ptr<Topology> topo =
612 topologyFactory(out_jitable.topology, domain);
613 std::string field_name = possible_temporary(field_port);
614
615 if((topo->topo_type == "structured" || topo->topo_type == "unstructured") &&
616 out_jitable.association == "vertex")
617 {
618 // this does a vertex to cell gradient so update entries
619 conduit::Node &n_entries = out_jitable.dom_info.child(dom_idx)["entries"];
620 n_entries = topo->get_num_cells();
621 out_jitable.association = "element";
622 }
623
624 if(not_fused)
625 {
626 const auto topo_code = std::make_shared<const TopologyCode>(
627 topo->topo_name, domain, out_jitable.arrays[dom_idx]);
628 FieldCode field_code = FieldCode(field_name,
629 out_jitable.association,
630 topo_code,
631 out_jitable.arrays[dom_idx],
632 1,
633 component);
634 field_code.gradient(out_kernel.for_body);
635 out_kernel.expr = field_name +
636 (component == -1 ? "" : "_" + std::to_string(component)) +
637 "_gradient";
638 out_kernel.num_components = 3;
639 }
640 }
641
642 void
gradient()643 JitableFusion::gradient()
644 {
645 const int field_port = inputs["field/port"].as_int32();
646 gradient(field_port, -1);
647 }
648
649 void
curl()650 JitableFusion::curl()
651 {
652 const int field_port = inputs["field/port"].as_int32();
653 const Kernel &field_kernel = *input_kernels[field_port];
654 if(field_kernel.num_components < 2)
655 {
656 ASCENT_ERROR("Vorticity is only implemented for fields with at least 2 "
657 "components. The input field has "
658 << field_kernel.num_components << ".");
659 }
660 const std::string field_name = possible_temporary(field_port);
661 // calling gradient here reuses the logic to update entries and association
662 for(int i = 0; i < field_kernel.num_components; ++i)
663 {
664 gradient(field_port, i);
665 }
666 // TODO make it easier to construct FieldCode
667 if(not_fused)
668 {
669 const auto topo_code = std::make_shared<const TopologyCode>(
670 out_jitable.topology, domain, out_jitable.arrays[dom_idx]);
671 FieldCode field_code = FieldCode(field_name,
672 out_jitable.association,
673 topo_code,
674 out_jitable.arrays[dom_idx],
675 field_kernel.num_components,
676 -1);
677 field_code.curl(out_kernel.for_body);
678 out_kernel.expr = field_name + "_curl";
679 out_kernel.num_components = field_kernel.num_components;
680 }
681 }
682
683 void
recenter()684 JitableFusion::recenter()
685 {
686 if(not_fused)
687 {
688 const int field_port = inputs["field/port"].as_int32();
689 const Kernel &field_kernel = *input_kernels[field_port];
690
691 std::string mode;
692 if(inputs.has_path("mode"))
693 {
694 const int mode_port = inputs["mode/port"].as_int32();
695 const Jitable &mode_jitable = *input_jitables[mode_port];
696 mode = mode_jitable.obj["value"].as_string();
697 if(mode != "toggle" && mode != "vertex" && mode != "element")
698 {
699 ASCENT_ERROR("recenter: Unknown mode '"
700 << mode
701 << "'. Known modes are 'toggle', 'vertex', 'element'.");
702 }
703 if(out_jitable.association == mode)
704 {
705 ASCENT_ERROR("Recenter: The field is already "
706 << out_jitable.association
707 << " associated, redundant recenter.");
708 }
709 }
710 else
711 {
712 mode = "toggle";
713 }
714 std::string target_association;
715 if(mode == "toggle")
716 {
717 if(out_jitable.association == "vertex")
718 {
719 target_association = "element";
720 }
721 else
722 {
723 target_association = "vertex";
724 }
725 }
726 else
727 {
728 target_association = mode;
729 }
730
731 const std::string field_name = possible_temporary(field_port);
732
733 // update entries and association
734 conduit::Node &n_entries = out_jitable.dom_info.child(dom_idx)["entries"];
735 std::unique_ptr<Topology> topo =
736 topologyFactory(out_jitable.topology, domain);
737 if(target_association == "vertex")
738 {
739 n_entries = topo->get_num_points();
740 }
741 else
742 {
743 n_entries = topo->get_num_cells();
744 }
745 out_jitable.association = target_association;
746
747 const auto topo_code = std::make_shared<const TopologyCode>(
748 out_jitable.topology, domain, out_jitable.arrays[dom_idx]);
749 FieldCode field_code = FieldCode(field_name,
750 out_jitable.association,
751 topo_code,
752 out_jitable.arrays[dom_idx],
753 field_kernel.num_components,
754 -1);
755 const std::string res_name = field_name + "_recenter_" + target_association;
756 field_code.recenter(out_kernel.for_body, target_association, res_name);
757 out_kernel.expr = res_name;
758 out_kernel.num_components = field_kernel.num_components;
759 }
760 }
761
762 void
magnitude()763 JitableFusion::magnitude()
764 {
765 if(not_fused)
766 {
767 const int vector_port = inputs["vector/port"].as_int32();
768 const Kernel &vector_kernel = *input_kernels[vector_port];
769 if(vector_kernel.num_components <= 1)
770 {
771 ASCENT_ERROR("Cannot take the magnitude of a vector with "
772 << vector_kernel.num_components << " components.");
773 }
774 out_kernel.fuse_kernel(vector_kernel);
775 MathCode().magnitude(out_kernel.for_body,
776 vector_kernel.expr,
777 filter_name,
778 vector_kernel.num_components);
779 out_kernel.expr = filter_name;
780 out_kernel.num_components = 1;
781 }
782 }
783
784 void
vector()785 JitableFusion::vector()
786 {
787 const int arg1_port = inputs["arg1/port"].to_int32();
788 const int arg2_port = inputs["arg2/port"].to_int32();
789 const int arg3_port = inputs["arg3/port"].to_int32();
790 const Jitable &arg1_jitable = *input_jitables[arg1_port];
791 const Jitable &arg2_jitable = *input_jitables[arg2_port];
792 const Jitable &arg3_jitable = *input_jitables[arg3_port];
793 // if all the inputs to the vector are "trivial" fields then we don't need
794 // to regenerate the vector in a separate for-loop
795 if(arg1_jitable.obj.has_path("type") && arg2_jitable.obj.has_path("type") &&
796 arg3_jitable.obj.has_path("type"))
797 {
798 // We construct a fake schema with the input arrays as the components
799 const std::string &arg1_field = arg1_jitable.obj["value"].as_string();
800 const std::string &arg2_field = arg2_jitable.obj["value"].as_string();
801 const std::string &arg3_field = arg3_jitable.obj["value"].as_string();
802 std::unordered_map<std::string, SchemaBool> &array_map =
803 out_jitable.arrays[dom_idx].array_map;
804 conduit::Schema s;
805 s[arg1_field].set(array_map.at(arg1_field).schema);
806 s[arg2_field].set(array_map.at(arg2_field).schema);
807 s[arg3_field].set(array_map.at(arg3_field).schema);
808 array_map.insert(std::make_pair(filter_name, SchemaBool(s, true)));
809 out_jitable.obj["value"] = filter_name;
810 out_jitable.obj["type"] = "field";
811 }
812
813 if(not_fused)
814 {
815 const Kernel &arg1_kernel = *input_kernels[arg1_port];
816 const Kernel &arg2_kernel = *input_kernels[arg2_port];
817 const Kernel &arg3_kernel = *input_kernels[arg3_port];
818 if(arg1_kernel.num_components != 1 || arg2_kernel.num_components != 1 ||
819 arg2_kernel.num_components != 1)
820 {
821 ASCENT_ERROR("Vector arguments must all have exactly one component.");
822 }
823 out_kernel.fuse_kernel(arg1_kernel);
824 out_kernel.fuse_kernel(arg2_kernel);
825 out_kernel.fuse_kernel(arg3_kernel);
826 const std::string arg1_expr = arg1_kernel.expr;
827 const std::string arg2_expr = arg2_kernel.expr;
828 const std::string arg3_expr = arg3_kernel.expr;
829 out_kernel.for_body.insert({"double " + filter_name + "[3];\n",
830 filter_name + "[0] = " + arg1_expr + ";\n",
831 filter_name + "[1] = " + arg2_expr + ";\n",
832 filter_name + "[2] = " + arg3_expr + ";\n"});
833 out_kernel.expr = filter_name;
834 out_kernel.num_components = 3;
835 }
836 }
837
838 void
binning_value(const conduit::Node & binning)839 JitableFusion::binning_value(const conduit::Node &binning)
840 {
841 // bin lookup functions
842 // clang-format off
843 const std::string rectilinear_bin =
844 "int\n"
845 "rectilinear_bin(const double value,\n"
846 " const double *const bins_begin,\n"
847 " const int len,\n"
848 " const bool clamp)\n"
849 "{\n"
850 // implements std::upper_bound
851 "int mid;\n"
852 "int low = 0;\n"
853 "int high = len;\n"
854 "while(low < high)\n"
855 "{\n"
856 "mid = (low + high) / 2;\n"
857 "if(value >= bins_begin[mid])\n"
858 "{\n"
859 "low = mid + 1;\n"
860 "}\n"
861 "else\n"
862 "{\n"
863 "high = mid;\n"
864 "}\n"
865 "}\n"
866
867 "if(clamp)\n"
868 "{\n"
869 "if(low <= 0)\n"
870 "{\n"
871 "return 0;\n"
872 "}\n"
873 "else if(low >= len)\n"
874 "{\n"
875 "return len - 2;\n"
876 "}\n"
877 "}\n"
878 "else if(low <= 0 || low >= len)\n"
879 "{\n"
880 "return -1;\n"
881 "}\n"
882 "return low - 1;\n"
883 "}\n\n";
884 const std::string uniform_bin =
885 "int\n"
886 "uniform_bin(const double value,\n"
887 " const double min_val,\n"
888 " const double max_val,\n"
889 " const int num_bins,\n"
890 " const bool clamp)\n"
891 "{\n"
892 "const double inv_delta = num_bins / (max_val - min_val);\n"
893 "const int bin_index = (int)((value - min_val) * inv_delta);\n"
894 "if(clamp)\n"
895 "{\n"
896 "if(bin_index < 0)\n"
897 "{\n"
898 "return 0;\n"
899 "}\n"
900 "else if(bin_index >= num_bins)\n"
901 "{\n"
902 "return num_bins - 1;\n"
903 "}\n"
904 "}\n"
905 "else if(bin_index < 0 || bin_index >= num_bins)\n"
906 "{\n"
907 "return -1;\n"
908 "}\n"
909 "return bin_index;\n"
910 "}\n\n";
911 // clang-format on
912 //---------------------------------------------------------------------------
913
914 // assume the necessary fields have been packed and are present in all
915 // domains
916
917 // get the passed association
918 std::string assoc_str_;
919 if(inputs.has_path("assoc"))
920 {
921 const conduit::Node &assoc_obj =
922 input_jitables[inputs["assoc/port"].as_int32()]->obj;
923 assoc_str_ = assoc_obj["value"].as_string();
924 }
925
926 const conduit::Node &bin_axes = binning["attrs/bin_axes/value"];
927 std::vector<std::string> axis_names = bin_axes.child_names();
928
929 // set/verify out_jitable.topology and out_jitable.association
930 const conduit::Node &topo_and_assoc =
931 final_topo_and_assoc(dataset, bin_axes, out_jitable.topology, assoc_str_);
932 std::string assoc_str = topo_and_assoc["assoc_str"].as_string();
933 if(assoc_str.empty())
934 {
935 // use the association from the binning
936 assoc_str = binning["attrs/association/value"].as_string();
937 }
938 out_jitable.association = assoc_str;
939
940 // set entries based on out_jitable.topology and out_jitable.association
941 std::unique_ptr<Topology> topo =
942 topologyFactory(out_jitable.topology, domain);
943 conduit::Node &n_entries = out_jitable.dom_info.child(dom_idx)["entries"];
944 if(out_jitable.association == "vertex")
945 {
946 n_entries = topo->get_num_points();
947 }
948 else if(out_jitable.association == "element")
949 {
950 n_entries = topo->get_num_cells();
951 }
952
953 if(not_fused)
954 {
955 const TopologyCode topo_code =
956 TopologyCode(topo->topo_name, domain, out_jitable.arrays[dom_idx]);
957 const std::string &binning_name = inputs["binning/filter_name"].as_string();
958 InsertionOrderedSet<std::string> &code = out_kernel.for_body;
959 const int num_axes = bin_axes.number_of_children();
960 bool used_uniform = false;
961 bool used_rectilinear = false;
962
963 code.insert("int " + filter_name + "_home = 0;\n");
964 code.insert("int " + filter_name + "_stride = 1;\n");
965 for(int axis_index = 0; axis_index < num_axes; ++axis_index)
966 {
967 const conduit::Node &axis = bin_axes.child(axis_index);
968 const std::string axis_name = axis.name();
969 const std::string axis_prefix = binning_name + "_" + axis_name + "_";
970 // find the value associated with the axis for the current item
971 std::string axis_value;
972 if(domain.has_path("fields/" + axis_name))
973 {
974 axis_value = axis_name + "_item";
975 code.insert("const double " + axis_value + " = " +
976 out_jitable.arrays[dom_idx].index(axis_name, "item") +
977 ";\n");
978 }
979 else if(is_xyz(axis_name))
980 {
981 if(out_jitable.association == "vertex")
982 {
983 axis_value = topo->topo_name + "_vertex_" + axis_name;
984 topo_code.vertex_coord(code, axis_name, "", axis_value);
985 }
986 else if(out_jitable.association == "element")
987 {
988 axis_value = topo->topo_name + "_cell_" + axis_name;
989 topo_code.element_coord(code, axis_name, "", axis_value);
990 }
991 }
992
993 size_t stride_multiplier;
994 if(axis.has_path("num_bins"))
995 {
996 // uniform axis
997 stride_multiplier = bin_axes.child(axis_index)["num_bins"].as_int32();
998
999 // find the value's index in the axis
1000 if(!used_uniform)
1001 {
1002 used_uniform = true;
1003 out_kernel.functions.insert(uniform_bin);
1004 }
1005 code.insert("int " + axis_prefix + "bin_index = uniform_bin(" +
1006 axis_value + ", " + axis_prefix + "min_val, " +
1007 axis_prefix + "max_val, " + axis_prefix + "num_bins, " +
1008 axis_prefix + "clamp);\n");
1009 }
1010 else
1011 {
1012 // rectilinear axis
1013 stride_multiplier =
1014 bin_axes.child(axis_index)["bins"].dtype().number_of_elements() - 1;
1015
1016 // find the value's index in the axis
1017 if(!used_rectilinear)
1018 {
1019 used_rectilinear = true;
1020 out_kernel.functions.insert(rectilinear_bin);
1021 }
1022 code.insert("int " + axis_prefix + "bin_index = rectilinear_bin(" +
1023 axis_value + ", " + axis_prefix + "bins, " + axis_prefix +
1024 "bins_len, " + axis_prefix + "clamp);\n");
1025 }
1026
1027 // update the current item's home
1028 code.insert("if(" + axis_prefix + "bin_index != -1 && " + filter_name +
1029 "_home != -1)\n{\n" + filter_name + "_home += " +
1030 axis_prefix + "bin_index * " + filter_name + "_stride;\n}\n");
1031 // update stride
1032 code.insert(filter_name +
1033 "_stride *= " + std::to_string(stride_multiplier) + ";\n");
1034 }
1035
1036 // get the value at home
1037 std::string default_value;
1038 if(inputs.has_path("default_value"))
1039 {
1040 default_value = inputs["default_value/filter_name"].as_string();
1041 }
1042 else
1043 {
1044 default_value = "0";
1045 }
1046 code.insert({"double " + filter_name + ";\n",
1047 "if(" + filter_name + "_home != -1)\n{\n" + filter_name +
1048 " = " + binning_name + "_value[" + filter_name +
1049 "_home];\n}\nelse\n{\n" + filter_name + " = " +
1050 default_value + ";\n}\n"});
1051 out_kernel.expr = filter_name;
1052 out_kernel.num_components = 1;
1053 }
1054 }
1055
1056 void
rand()1057 JitableFusion::rand()
1058 {
1059 out_jitable.dom_info.child(dom_idx)["args/" + filter_name + "_seed"] =
1060 time(nullptr);
1061 if(not_fused)
1062 {
1063 // clang-format off
1064 const std::string halton =
1065 "double rand(int i)\n"
1066 "{\n"
1067 "const int b = 2;\n"
1068 "double f = 1;\n"
1069 "double r = 0;\n"
1070 "while(i > 0)\n"
1071 "{\n"
1072 "f = f / b;\n"
1073 "r = r + f * (i \% b);\n"
1074 "i = i / b;\n"
1075 "}\n"
1076 "return r;\n"
1077 "}\n\n";
1078 // clang-format on
1079 out_kernel.functions.insert(halton);
1080 out_kernel.expr = "rand(item + " + filter_name + "_seed)";
1081 out_kernel.num_components = 1;
1082 }
1083 }
1084
1085 //-----------------------------------------------------------------------------
1086 };
1087 //-----------------------------------------------------------------------------
1088 // -- end ascent::runtime::expressions--
1089 //-----------------------------------------------------------------------------
1090
1091 //-----------------------------------------------------------------------------
1092 };
1093 //-----------------------------------------------------------------------------
1094 // -- end ascent::runtime --
1095 //-----------------------------------------------------------------------------
1096
1097 //-----------------------------------------------------------------------------
1098 };
1099 //-----------------------------------------------------------------------------
1100 // -- end ascent:: --
1101 //-----------------------------------------------------------------------------
1102