1 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
2 // Copyright (c) Lawrence Livermore National Security, LLC and other Ascent
3 // Project developers. See top-level LICENSE AND COPYRIGHT files for dates and
4 // other details. No copyright assignment is required to contribute to Ascent.
5 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
6 
7 //-----------------------------------------------------------------------------
8 ///
9 /// file: ascent_jit_fusion.cpp
10 ///
11 //-----------------------------------------------------------------------------
12 
13 #include "ascent_jit_fusion.hpp"
14 #include "ascent_blueprint_topologies.hpp"
15 #include "ascent_blueprint_architect.hpp"
16 #include <ascent_logging.hpp>
17 
18 //-----------------------------------------------------------------------------
19 // -- begin ascent:: --
20 //-----------------------------------------------------------------------------
21 namespace ascent
22 {
23 
24 //-----------------------------------------------------------------------------
25 // -- begin ascent::runtime --
26 //-----------------------------------------------------------------------------
27 namespace runtime
28 {
29 
30 //-----------------------------------------------------------------------------
31 // -- begin ascent::runtime::expressions--
32 //-----------------------------------------------------------------------------
33 namespace expressions
34 {
35 
36 
37 
JitableFusion(const conduit::Node & params,const std::vector<const Jitable * > & input_jitables,const std::vector<const Kernel * > & input_kernels,const std::string & filter_name,const conduit::Node & dataset,const int dom_idx,const bool not_fused,Jitable & out_jitable,Kernel & out_kernel)38 JitableFusion::JitableFusion(
39     const conduit::Node &params,
40     const std::vector<const Jitable *> &input_jitables,
41     const std::vector<const Kernel *> &input_kernels,
42     const std::string &filter_name,
43     const conduit::Node &dataset,
44     const int dom_idx,
45     const bool not_fused,
46     Jitable &out_jitable,
47     Kernel &out_kernel)
48     : params(params), input_jitables(input_jitables),
49       input_kernels(input_kernels), filter_name(filter_name), dataset(dataset),
50       dom_idx(dom_idx), not_fused(not_fused), out_jitable(out_jitable),
51       out_kernel(out_kernel), inputs(params["inputs"]),
52       domain(dataset.child(dom_idx))
53 {
54 }
55 
56 void
binary_op()57 JitableFusion::binary_op()
58 {
59   if(not_fused)
60   {
61     const int lhs_port = inputs["lhs/port"].to_int32();
62     const int rhs_port = inputs["rhs/port"].to_int32();
63     const Kernel &lhs_kernel = *input_kernels[lhs_port];
64     const Kernel &rhs_kernel = *input_kernels[rhs_port];
65     // union the field/mesh vars
66     out_kernel.fuse_kernel(lhs_kernel);
67     out_kernel.fuse_kernel(rhs_kernel);
68     const std::string lhs_expr = lhs_kernel.expr;
69     const std::string rhs_expr = rhs_kernel.expr;
70     const std::string &op_str = params["op_string"].as_string();
71     if(lhs_kernel.num_components == 1 && rhs_kernel.num_components == 1)
72     {
73       // scalar ops
74       if(op_str == "not")
75       {
76         out_kernel.expr = "!(" + rhs_expr + ")";
77       }
78       else
79       {
80         std::string occa_op_str;
81         if(op_str == "and")
82         {
83           occa_op_str = "&&";
84         }
85         else if(op_str == "or")
86         {
87           occa_op_str = "||";
88         }
89         else
90         {
91           occa_op_str = op_str;
92         }
93         out_kernel.expr = "(" + lhs_expr + " " + op_str + " " + rhs_expr + ")";
94       }
95       out_kernel.num_components = 1;
96     }
97     else
98     {
99       // vector ops
100       bool error = false;
101       if(lhs_kernel.num_components == rhs_kernel.num_components)
102       {
103         if(op_str == "+")
104         {
105           MathCode().vector_add(out_kernel.for_body,
106                                 lhs_expr,
107                                 rhs_expr,
108                                 filter_name,
109                                 lhs_kernel.num_components);
110 
111           out_kernel.num_components = lhs_kernel.num_components;
112         }
113         else if(op_str == "-")
114         {
115           MathCode().vector_subtract(out_kernel.for_body,
116                                      lhs_expr,
117                                      rhs_expr,
118                                      filter_name,
119                                      lhs_kernel.num_components);
120 
121           out_kernel.num_components = lhs_kernel.num_components;
122         }
123         else if(op_str == "*")
124         {
125           MathCode().dot_product(out_kernel.for_body,
126                                  lhs_expr,
127                                  rhs_expr,
128                                  filter_name,
129                                  lhs_kernel.num_components);
130 
131           out_kernel.num_components = 1;
132         }
133         else
134         {
135           error = true;
136         }
137         out_kernel.expr = filter_name;
138       }
139       else
140       {
141         error = true;
142       }
143       if(error)
144       {
145         ASCENT_ERROR("Unsupported binary_op: (field with "
146                      << lhs_kernel.num_components << " components) " << op_str
147                      << " (field with " << rhs_kernel.num_components
148                      << " components).");
149       }
150     }
151   }
152   else
153   {
154     // kernel of this type has already been fused, do nothing on a
155     // per-domain basis
156   }
157 }
158 
159 void
builtin_functions(const std::string & function_name)160 JitableFusion::builtin_functions(const std::string &function_name)
161 {
162   if(not_fused)
163   {
164     out_kernel.expr = function_name + "(";
165     const int num_inputs = inputs.number_of_children();
166     for(int i = 0; i < num_inputs; ++i)
167     {
168       const int port_num = inputs.child(i)["port"].to_int32();
169       const Kernel &inp_kernel = *input_kernels[port_num];
170       if(inp_kernel.num_components > 1)
171       {
172         ASCENT_ERROR("Built-in function '"
173                      << function_name
174                      << "' does not support vector fields with "
175                      << inp_kernel.num_components << " components.");
176       }
177       const std::string &inp_expr = inp_kernel.expr;
178       out_kernel.fuse_kernel(inp_kernel);
179       if(i != 0)
180       {
181         out_kernel.expr += ", ";
182       }
183       out_kernel.expr += inp_expr;
184     }
185     out_kernel.expr += ")";
186     out_kernel.num_components = 1;
187   }
188 }
189 
190 bool
available_component(const std::string & axis,const int num_axes)191 available_component(const std::string &axis, const int num_axes)
192 {
193   // if a field has only 1 component it doesn't have .x .y .z
194   if((axis == "x" && num_axes >= 2) || (axis == "y" && num_axes >= 2) ||
195      (axis == "z" && num_axes >= 3))
196   {
197     return true;
198   }
199   ASCENT_ERROR("Derived field with "
200                << num_axes << " components does not have component '" << axis
201                << "'.");
202   return false;
203 }
204 
205 bool
available_axis(const std::string & axis,const int num_axes,const std::string & topo_name)206 available_axis(const std::string &axis,
207                const int num_axes,
208                const std::string &topo_name)
209 {
210   if(((axis == "x" || axis == "dx") && num_axes >= 1) ||
211      ((axis == "y" || axis == "dy") && num_axes >= 2) ||
212      ((axis == "z" || axis == "dz") && num_axes >= 3))
213   {
214     return true;
215   }
216   ASCENT_ERROR("Topology '" << topo_name << "' with " << num_axes
217                             << " dimensions does not have axis '" << axis
218                             << "'.");
219   return false;
220 }
221 
222 void
topo_attrs(const conduit::Node & obj,const std::string & name)223 JitableFusion::topo_attrs(const conduit::Node &obj, const std::string &name)
224 {
225   const std::string &topo_name = obj["value"].as_string();
226   std::unique_ptr<Topology> topo = topologyFactory(topo_name, domain);
227   if(obj.has_path("attr"))
228   {
229     if(not_fused)
230     {
231       const conduit::Node &assoc = obj["attr"].child(0);
232       TopologyCode topo_code =
233           TopologyCode(topo_name, domain, out_jitable.arrays[dom_idx]);
234       if(assoc.name() == "cell")
235       {
236         // x, y, z
237         if(is_xyz(name) && available_axis(name, topo->num_dims, topo_name))
238         {
239           topo_code.element_coord(
240               out_kernel.for_body, name, "", topo_name + "_cell_" + name);
241           out_kernel.expr = topo_name + "_cell_" + name;
242         }
243         // dx, dy, dz
244         else if(name[0] == 'd' && is_xyz(std::string(1, name[1])) &&
245                 available_axis(name, topo->num_dims, topo_name))
246         {
247           if(topo->topo_type == "uniform")
248           {
249             out_kernel.expr = topo_name + "_spacing_" + name;
250           }
251           else if(topo->topo_type == "rectilinear")
252           {
253             topo_code.dxdydz(out_kernel.for_body);
254             out_kernel.expr = topo_name + "_" + name;
255           }
256           else
257           {
258             ASCENT_ERROR("Can only get dx, dy, dz for uniform or rectilinear "
259                          "topologies, not topologies of type '"
260                          << topo->topo_type << "'.");
261           }
262         }
263         else if(name == "volume")
264         {
265           if(topo->num_dims != 3)
266           {
267             ASCENT_ERROR("Cell volume is only defined for topologies with 3 "
268                          "dimensions. The specified topology '"
269                          << topo->topo_name << "' has " << topo->num_dims
270                          << " dimensions.");
271           }
272           topo_code.volume(out_kernel.for_body);
273           out_kernel.expr = topo_name + "_volume";
274         }
275         else if(name == "area")
276         {
277           if(topo->num_dims < 2)
278           {
279             ASCENT_ERROR("Cell area is only defined for topologies at most 2 "
280                          "dimensions. The specified topology '"
281                          << topo->topo_name << "' has " << topo->num_dims
282                          << " dimensions.");
283           }
284           topo_code.area(out_kernel.for_body);
285           out_kernel.expr = topo_name + "_area";
286         }
287         else if(name == "surface_area")
288         {
289           if(topo->num_dims != 3)
290           {
291             ASCENT_ERROR(
292                 "Cell surface area is only defined for topologies with 3 "
293                 "dimensions. The specified topology '"
294                 << topo->topo_name << "' has " << topo->num_dims
295                 << " dimensions.");
296           }
297           topo_code.surface_area(out_kernel.for_body);
298           out_kernel.expr = topo_name + "_surface_area";
299         }
300         else if(name == "id")
301         {
302           out_kernel.expr = "item";
303         }
304         else
305         {
306           ASCENT_ERROR("Could not find attribute '"
307                        << name << "' of topo.cell at runtime.");
308         }
309       }
310       else if(assoc.name() == "vertex")
311       {
312         if(is_xyz(name) && available_axis(name, topo->num_dims, topo_name))
313         {
314           topo_code.vertex_coord(
315               out_kernel.for_body, name, "", topo_name + "_vertex_" + name);
316           out_kernel.expr = topo_name + "_vertex_" + name;
317         }
318         else if(name == "id")
319         {
320           out_kernel.expr = "item";
321         }
322         else
323         {
324           ASCENT_ERROR("Could not find attribute '"
325                        << name << "' of topo.vertex at runtime.");
326         }
327       }
328       else
329       {
330         ASCENT_ERROR("Could not find attribute '" << assoc.name()
331                                                   << "' of topo at runtime.");
332       }
333     }
334   }
335   else
336   {
337     if(name == "cell")
338     {
339       if(topo->topo_type == "points")
340       {
341         ASCENT_ERROR("Point topology '" << topo_name
342                                         << "' has no cell attributes.");
343       }
344       out_jitable.dom_info.child(dom_idx)["entries"] = topo->get_num_cells();
345       out_jitable.association = "element";
346     }
347     else
348     {
349       out_jitable.dom_info.child(dom_idx)["entries"] = topo->get_num_points();
350       out_jitable.association = "vertex";
351     }
352     out_jitable.obj = obj;
353     out_jitable.obj["attr/" + name];
354   }
355 }
356 
357 void
expr_dot()358 JitableFusion::expr_dot()
359 {
360   const int obj_port = inputs["obj/port"].as_int32();
361   const Kernel &obj_kernel = *input_kernels[obj_port];
362   const conduit::Node &obj = input_jitables[obj_port]->obj;
363   const std::string &name = params["name"].as_string();
364   // derived fields or trivial fields
365   if(!obj.has_path("type") || obj["type"].as_string() == "field")
366   {
367     if(is_xyz(name) && available_component(name, obj_kernel.num_components))
368     {
369       out_kernel.expr =
370           obj_kernel.expr + "[" + std::to_string(name[0] - 'x') + "]";
371       out_kernel.num_components = 1;
372     }
373     else
374     {
375 
376       ASCENT_ERROR("Could not find attribute '" << name
377                                                 << "' of field at runtime.");
378     }
379   }
380   else if(obj["type"].as_string() == "topo")
381   {
382     // needs to run for every domain not just every kernel type to
383     // populate entries
384     topo_attrs(obj, name);
385     // for now all topology attributes have one component :)
386     out_kernel.num_components = 1;
387   }
388   else
389   {
390     ASCENT_ERROR("JIT: Unknown obj:\n" << obj.to_yaml());
391   }
392   if(not_fused)
393   {
394     out_kernel.fuse_kernel(obj_kernel);
395   }
396 }
397 
398 void
expr_if()399 JitableFusion::expr_if()
400 {
401   if(not_fused)
402   {
403     const int condition_port = inputs["condition/port"].as_int32();
404     const int if_port = inputs["if/port"].as_int32();
405     const int else_port = inputs["else/port"].as_int32();
406     const Kernel &condition_kernel = *input_kernels[condition_port];
407     const Kernel &if_kernel = *input_kernels[if_port];
408     const Kernel &else_kernel = *input_kernels[else_port];
409     out_kernel.functions.insert(condition_kernel.functions);
410     out_kernel.functions.insert(if_kernel.functions);
411     out_kernel.functions.insert(else_kernel.functions);
412     out_kernel.kernel_body.insert(condition_kernel.kernel_body);
413     out_kernel.kernel_body.insert(if_kernel.kernel_body);
414     out_kernel.kernel_body.insert(else_kernel.kernel_body);
415     const std::string cond_name = filter_name + "_cond";
416     const std::string res_name = filter_name + "_res";
417 
418     out_kernel.for_body.insert(condition_kernel.for_body);
419     out_kernel.for_body.insert(
420         condition_kernel.generate_output(cond_name, true));
421 
422     InsertionOrderedSet<std::string> if_else;
423     if_else.insert("double " + res_name + ";\n");
424     if_else.insert("if(" + cond_name + ")\n{\n");
425     if_else.insert(if_kernel.for_body.accumulate() +
426                    if_kernel.generate_output(res_name, false));
427     if_else.insert("}\nelse\n{\n");
428     if_else.insert(else_kernel.for_body.accumulate() +
429                    else_kernel.generate_output(res_name, false));
430     if_else.insert("}\n");
431 
432     out_kernel.for_body.insert(if_else.accumulate());
433     out_kernel.expr = res_name;
434     if(if_kernel.num_components != else_kernel.num_components)
435     {
436       ASCENT_ERROR("Jitable if-else: The if-branch results in "
437                    << if_kernel.num_components
438                    << " components and the else-branch results in "
439                    << else_kernel.num_components
440                    << " but they must have the same number of components.");
441     }
442     out_kernel.num_components = if_kernel.num_components;
443   }
444 }
445 
446 void
derived_field()447 JitableFusion::derived_field()
448 {
449   // setting association and topology should run once for Jitable not for
450   // each domain, but won't hurt
451   if(inputs.has_path("assoc"))
452   {
453     const conduit::Node &string_obj =
454         input_jitables[inputs["assoc/port"].as_int32()]->obj;
455     const std::string &new_association = string_obj["value"].as_string();
456     if(new_association != "vertex" && new_association != "element")
457     {
458       ASCENT_ERROR("derived_field: Unknown association '"
459                    << new_association
460                    << "'. Known associations are 'vertex' and 'element'.");
461     }
462     out_jitable.association = new_association;
463   }
464   if(inputs.has_path("topo"))
465   {
466     const conduit::Node &string_obj =
467         input_jitables[inputs["topo/port"].as_int32()]->obj;
468     const std::string &new_topology = string_obj["value"].as_string();
469     // We repeat this check because we pass the topology name as a
470     // string. If we pass a topo object it will get packed which is
471     // unnecessary if we only want to output on the topology.
472     if(!has_topology(dataset, new_topology))
473     {
474       std::set<std::string> topo_names = topology_names(dataset);
475       std::string res;
476       for(auto &name : topo_names)
477       {
478         res += name + " ";
479       }
480       ASCENT_ERROR(": dataset does not contain topology '"
481                    << new_topology << "'"
482                    << " known = " <<res);
483     }
484     if(!out_jitable.association.empty() && out_jitable.association != "none")
485     {
486       // update entries
487       std::unique_ptr<Topology> topo = topologyFactory(new_topology, domain);
488       conduit::Node &cur_dom_info = out_jitable.dom_info.child(dom_idx);
489       int new_entries = 0;
490       if(out_jitable.association == "vertex")
491       {
492         new_entries = topo->get_num_points();
493       }
494       else if(out_jitable.association == "element")
495       {
496         new_entries = topo->get_num_cells();
497       }
498       // ensure entries doesn't change if it's already defined
499       if(cur_dom_info.has_child("entries"))
500       {
501         const int cur_entries = cur_dom_info["entries"].to_int32();
502         if(new_entries != cur_entries)
503         {
504           ASCENT_ERROR(
505               "derived_field: cannot put a derived field with "
506               << cur_entries << " entries as a " << out_jitable.association
507               << "-associated derived field on the topology '" << new_topology
508               << "' since the resulting field would need to have "
509               << new_entries << " entries.");
510         }
511       }
512       else
513       {
514         cur_dom_info["entries"] = new_entries;
515       }
516     }
517     out_jitable.topology = new_topology;
518   }
519   if(not_fused)
520   {
521     const int arg1_port = inputs["arg1/port"].as_int32();
522     const Kernel &arg1_kernel = *input_kernels[arg1_port];
523     out_kernel.fuse_kernel(arg1_kernel);
524     out_kernel.expr = arg1_kernel.expr;
525     out_kernel.num_components = arg1_kernel.num_components;
526   }
527 }
528 
529 // generate a temporary field on the device, used by things like gradient
530 // which have data dependencies that require the entire field to be present.
531 // essentially wrap field_kernel in a for loop
532 void
temporary_field(const Kernel & field_kernel,const std::string & field_name)533 JitableFusion::temporary_field(const Kernel &field_kernel,
534                                   const std::string &field_name)
535 {
536   // pass the value of entries for the temporary field
537   const auto entries =
538       out_jitable.dom_info.child(dom_idx)["entries"].to_int64();
539   const std::string entries_name = filter_name + "_inp_entries";
540   out_jitable.dom_info.child(dom_idx)["args/" + entries_name] = entries;
541 
542   // we will need to allocate a temporary array so make a schema for it and
543   // put it in the array map
544   // TODO for now temporary fields are interleaved
545   conduit::Schema s;
546   schemaFactory("interleaved",
547                 conduit::DataType::FLOAT64_ID,
548                 entries,
549                 field_kernel.num_components,
550                 s);
551   // The array will have to be allocated but doesn't point to any data so we
552   // won't put it in args but it will still be passed in
553   out_jitable.arrays[dom_idx].array_map.insert(
554       std::make_pair(field_name, SchemaBool(s, false)));
555   if(not_fused)
556   {
557     // not a regular kernel_fuse because we have to generate a for-loop and add
558     // it to kernel_body instead of fusing for_body
559     out_kernel.functions.insert(field_kernel.functions);
560     out_kernel.kernel_body.insert(field_kernel.kernel_body);
561     out_kernel.kernel_body.insert(field_kernel.generate_loop(
562         field_name, out_jitable.arrays[dom_idx], entries_name));
563   }
564 }
565 
566 std::string
possible_temporary(const int field_port)567 JitableFusion::possible_temporary(const int field_port)
568 {
569   const Jitable &field_jitable = *input_jitables[field_port];
570   const Kernel &field_kernel = *input_kernels[field_port];
571   const conduit::Node &obj = field_jitable.obj;
572   std::string field_name;
573   if(obj.has_path("value"))
574   {
575     field_name = obj["value"].as_string();
576     out_kernel.fuse_kernel(field_kernel);
577   }
578   else
579   {
580     field_name = filter_name + "_inp";
581     temporary_field(field_kernel, field_name);
582   }
583   return field_name;
584 }
585 
586 void
gradient(const int field_port,const int component)587 JitableFusion::gradient(const int field_port, const int component)
588 {
589   const Kernel &field_kernel = *input_kernels[field_port];
590 
591   if(component == -1 && field_kernel.num_components > 1)
592   {
593     ASCENT_ERROR("gradient is only supported on scalar fields but a field with "
594                  << field_kernel.num_components << " components was given.");
595   }
596 
597   // association and topology should be the same for out_jitable and
598   // field_jitable because we have already fused jitables at this point
599   if(out_jitable.topology.empty() || out_jitable.topology == "none")
600   {
601     ASCENT_ERROR("Could not take the gradient of the derived field because the "
602                  "associated topology could not be determined.");
603   }
604 
605   if(out_jitable.association.empty() || out_jitable.association == "none")
606   {
607     ASCENT_ERROR("Could not take the gradient of the derived field "
608                  "because the association could not be determined.");
609   }
610 
611   std::unique_ptr<Topology> topo =
612       topologyFactory(out_jitable.topology, domain);
613   std::string field_name = possible_temporary(field_port);
614 
615   if((topo->topo_type == "structured" || topo->topo_type == "unstructured") &&
616      out_jitable.association == "vertex")
617   {
618     // this does a vertex to cell gradient so update entries
619     conduit::Node &n_entries = out_jitable.dom_info.child(dom_idx)["entries"];
620     n_entries = topo->get_num_cells();
621     out_jitable.association = "element";
622   }
623 
624   if(not_fused)
625   {
626     const auto topo_code = std::make_shared<const TopologyCode>(
627         topo->topo_name, domain, out_jitable.arrays[dom_idx]);
628     FieldCode field_code = FieldCode(field_name,
629                                      out_jitable.association,
630                                      topo_code,
631                                      out_jitable.arrays[dom_idx],
632                                      1,
633                                      component);
634     field_code.gradient(out_kernel.for_body);
635     out_kernel.expr = field_name +
636                       (component == -1 ? "" : "_" + std::to_string(component)) +
637                       "_gradient";
638     out_kernel.num_components = 3;
639   }
640 }
641 
642 void
gradient()643 JitableFusion::gradient()
644 {
645   const int field_port = inputs["field/port"].as_int32();
646   gradient(field_port, -1);
647 }
648 
649 void
curl()650 JitableFusion::curl()
651 {
652   const int field_port = inputs["field/port"].as_int32();
653   const Kernel &field_kernel = *input_kernels[field_port];
654   if(field_kernel.num_components < 2)
655   {
656     ASCENT_ERROR("Vorticity is only implemented for fields with at least 2 "
657                  "components. The input field has "
658                  << field_kernel.num_components << ".");
659   }
660   const std::string field_name = possible_temporary(field_port);
661   // calling gradient here reuses the logic to update entries and association
662   for(int i = 0; i < field_kernel.num_components; ++i)
663   {
664     gradient(field_port, i);
665   }
666   // TODO make it easier to construct FieldCode
667   if(not_fused)
668   {
669     const auto topo_code = std::make_shared<const TopologyCode>(
670         out_jitable.topology, domain, out_jitable.arrays[dom_idx]);
671     FieldCode field_code = FieldCode(field_name,
672                                      out_jitable.association,
673                                      topo_code,
674                                      out_jitable.arrays[dom_idx],
675                                      field_kernel.num_components,
676                                      -1);
677     field_code.curl(out_kernel.for_body);
678     out_kernel.expr = field_name + "_curl";
679     out_kernel.num_components = field_kernel.num_components;
680   }
681 }
682 
683 void
recenter()684 JitableFusion::recenter()
685 {
686   if(not_fused)
687   {
688     const int field_port = inputs["field/port"].as_int32();
689     const Kernel &field_kernel = *input_kernels[field_port];
690 
691     std::string mode;
692     if(inputs.has_path("mode"))
693     {
694       const int mode_port = inputs["mode/port"].as_int32();
695       const Jitable &mode_jitable = *input_jitables[mode_port];
696       mode = mode_jitable.obj["value"].as_string();
697       if(mode != "toggle" && mode != "vertex" && mode != "element")
698       {
699         ASCENT_ERROR("recenter: Unknown mode '"
700                      << mode
701                      << "'. Known modes are 'toggle', 'vertex', 'element'.");
702       }
703       if(out_jitable.association == mode)
704       {
705         ASCENT_ERROR("Recenter: The field is already "
706                      << out_jitable.association
707                      << " associated, redundant recenter.");
708       }
709     }
710     else
711     {
712       mode = "toggle";
713     }
714     std::string target_association;
715     if(mode == "toggle")
716     {
717       if(out_jitable.association == "vertex")
718       {
719         target_association = "element";
720       }
721       else
722       {
723         target_association = "vertex";
724       }
725     }
726     else
727     {
728       target_association = mode;
729     }
730 
731     const std::string field_name = possible_temporary(field_port);
732 
733     // update entries and association
734     conduit::Node &n_entries = out_jitable.dom_info.child(dom_idx)["entries"];
735     std::unique_ptr<Topology> topo =
736         topologyFactory(out_jitable.topology, domain);
737     if(target_association == "vertex")
738     {
739       n_entries = topo->get_num_points();
740     }
741     else
742     {
743       n_entries = topo->get_num_cells();
744     }
745     out_jitable.association = target_association;
746 
747     const auto topo_code = std::make_shared<const TopologyCode>(
748         out_jitable.topology, domain, out_jitable.arrays[dom_idx]);
749     FieldCode field_code = FieldCode(field_name,
750                                      out_jitable.association,
751                                      topo_code,
752                                      out_jitable.arrays[dom_idx],
753                                      field_kernel.num_components,
754                                      -1);
755     const std::string res_name = field_name + "_recenter_" + target_association;
756     field_code.recenter(out_kernel.for_body, target_association, res_name);
757     out_kernel.expr = res_name;
758     out_kernel.num_components = field_kernel.num_components;
759   }
760 }
761 
762 void
magnitude()763 JitableFusion::magnitude()
764 {
765   if(not_fused)
766   {
767     const int vector_port = inputs["vector/port"].as_int32();
768     const Kernel &vector_kernel = *input_kernels[vector_port];
769     if(vector_kernel.num_components <= 1)
770     {
771       ASCENT_ERROR("Cannot take the magnitude of a vector with "
772                    << vector_kernel.num_components << " components.");
773     }
774     out_kernel.fuse_kernel(vector_kernel);
775     MathCode().magnitude(out_kernel.for_body,
776                          vector_kernel.expr,
777                          filter_name,
778                          vector_kernel.num_components);
779     out_kernel.expr = filter_name;
780     out_kernel.num_components = 1;
781   }
782 }
783 
784 void
vector()785 JitableFusion::vector()
786 {
787   const int arg1_port = inputs["arg1/port"].to_int32();
788   const int arg2_port = inputs["arg2/port"].to_int32();
789   const int arg3_port = inputs["arg3/port"].to_int32();
790   const Jitable &arg1_jitable = *input_jitables[arg1_port];
791   const Jitable &arg2_jitable = *input_jitables[arg2_port];
792   const Jitable &arg3_jitable = *input_jitables[arg3_port];
793   // if all the inputs to the vector are "trivial" fields then we don't need
794   // to regenerate the vector in a separate for-loop
795   if(arg1_jitable.obj.has_path("type") && arg2_jitable.obj.has_path("type") &&
796      arg3_jitable.obj.has_path("type"))
797   {
798     // We construct a fake schema with the input arrays as the components
799     const std::string &arg1_field = arg1_jitable.obj["value"].as_string();
800     const std::string &arg2_field = arg2_jitable.obj["value"].as_string();
801     const std::string &arg3_field = arg3_jitable.obj["value"].as_string();
802     std::unordered_map<std::string, SchemaBool> &array_map =
803         out_jitable.arrays[dom_idx].array_map;
804     conduit::Schema s;
805     s[arg1_field].set(array_map.at(arg1_field).schema);
806     s[arg2_field].set(array_map.at(arg2_field).schema);
807     s[arg3_field].set(array_map.at(arg3_field).schema);
808     array_map.insert(std::make_pair(filter_name, SchemaBool(s, true)));
809     out_jitable.obj["value"] = filter_name;
810     out_jitable.obj["type"] = "field";
811   }
812 
813   if(not_fused)
814   {
815     const Kernel &arg1_kernel = *input_kernels[arg1_port];
816     const Kernel &arg2_kernel = *input_kernels[arg2_port];
817     const Kernel &arg3_kernel = *input_kernels[arg3_port];
818     if(arg1_kernel.num_components != 1 || arg2_kernel.num_components != 1 ||
819        arg2_kernel.num_components != 1)
820     {
821       ASCENT_ERROR("Vector arguments must all have exactly one component.");
822     }
823     out_kernel.fuse_kernel(arg1_kernel);
824     out_kernel.fuse_kernel(arg2_kernel);
825     out_kernel.fuse_kernel(arg3_kernel);
826     const std::string arg1_expr = arg1_kernel.expr;
827     const std::string arg2_expr = arg2_kernel.expr;
828     const std::string arg3_expr = arg3_kernel.expr;
829     out_kernel.for_body.insert({"double " + filter_name + "[3];\n",
830                                 filter_name + "[0] = " + arg1_expr + ";\n",
831                                 filter_name + "[1] = " + arg2_expr + ";\n",
832                                 filter_name + "[2] = " + arg3_expr + ";\n"});
833     out_kernel.expr = filter_name;
834     out_kernel.num_components = 3;
835   }
836 }
837 
838 void
binning_value(const conduit::Node & binning)839 JitableFusion::binning_value(const conduit::Node &binning)
840 {
841   // bin lookup functions
842   // clang-format off
843   const std::string rectilinear_bin =
844     "int\n"
845     "rectilinear_bin(const double value,\n"
846     "                const double *const bins_begin,\n"
847     "                const int len,\n"
848     "                const bool clamp)\n"
849     "{\n"
850       // implements std::upper_bound
851       "int mid;\n"
852       "int low = 0;\n"
853       "int high = len;\n"
854       "while(low < high)\n"
855       "{\n"
856         "mid = (low + high) / 2;\n"
857         "if(value >= bins_begin[mid])\n"
858         "{\n"
859           "low = mid + 1;\n"
860         "}\n"
861         "else\n"
862         "{\n"
863           "high = mid;\n"
864         "}\n"
865       "}\n"
866 
867       "if(clamp)\n"
868       "{\n"
869         "if(low <= 0)\n"
870         "{\n"
871           "return 0;\n"
872         "}\n"
873         "else if(low >= len)\n"
874         "{\n"
875           "return len - 2;\n"
876         "}\n"
877       "}\n"
878       "else if(low <= 0 || low >= len)\n"
879       "{\n"
880         "return -1;\n"
881       "}\n"
882       "return low - 1;\n"
883     "}\n\n";
884   const std::string uniform_bin =
885     "int\n"
886     "uniform_bin(const double value,\n"
887     "            const double min_val,\n"
888     "            const double max_val,\n"
889     "            const int num_bins,\n"
890     "            const bool clamp)\n"
891     "{\n"
892       "const double inv_delta = num_bins / (max_val - min_val);\n"
893       "const int bin_index = (int)((value - min_val) * inv_delta);\n"
894       "if(clamp)\n"
895       "{\n"
896         "if(bin_index < 0)\n"
897         "{\n"
898           "return 0;\n"
899         "}\n"
900         "else if(bin_index >= num_bins)\n"
901         "{\n"
902           "return num_bins - 1;\n"
903         "}\n"
904       "}\n"
905       "else if(bin_index < 0 || bin_index >= num_bins)\n"
906       "{\n"
907         "return -1;\n"
908       "}\n"
909       "return bin_index;\n"
910     "}\n\n";
911   // clang-format on
912   //---------------------------------------------------------------------------
913 
914   // assume the necessary fields have been packed and are present in all
915   // domains
916 
917   // get the passed association
918   std::string assoc_str_;
919   if(inputs.has_path("assoc"))
920   {
921     const conduit::Node &assoc_obj =
922         input_jitables[inputs["assoc/port"].as_int32()]->obj;
923     assoc_str_ = assoc_obj["value"].as_string();
924   }
925 
926   const conduit::Node &bin_axes = binning["attrs/bin_axes/value"];
927   std::vector<std::string> axis_names = bin_axes.child_names();
928 
929   // set/verify out_jitable.topology and out_jitable.association
930   const conduit::Node &topo_and_assoc =
931       final_topo_and_assoc(dataset, bin_axes, out_jitable.topology, assoc_str_);
932   std::string assoc_str = topo_and_assoc["assoc_str"].as_string();
933   if(assoc_str.empty())
934   {
935     // use the association from the binning
936     assoc_str = binning["attrs/association/value"].as_string();
937   }
938   out_jitable.association = assoc_str;
939 
940   // set entries based on out_jitable.topology and out_jitable.association
941   std::unique_ptr<Topology> topo =
942       topologyFactory(out_jitable.topology, domain);
943   conduit::Node &n_entries = out_jitable.dom_info.child(dom_idx)["entries"];
944   if(out_jitable.association == "vertex")
945   {
946     n_entries = topo->get_num_points();
947   }
948   else if(out_jitable.association == "element")
949   {
950     n_entries = topo->get_num_cells();
951   }
952 
953   if(not_fused)
954   {
955     const TopologyCode topo_code =
956         TopologyCode(topo->topo_name, domain, out_jitable.arrays[dom_idx]);
957     const std::string &binning_name = inputs["binning/filter_name"].as_string();
958     InsertionOrderedSet<std::string> &code = out_kernel.for_body;
959     const int num_axes = bin_axes.number_of_children();
960     bool used_uniform = false;
961     bool used_rectilinear = false;
962 
963     code.insert("int " + filter_name + "_home = 0;\n");
964     code.insert("int " + filter_name + "_stride = 1;\n");
965     for(int axis_index = 0; axis_index < num_axes; ++axis_index)
966     {
967       const conduit::Node &axis = bin_axes.child(axis_index);
968       const std::string axis_name = axis.name();
969       const std::string axis_prefix = binning_name + "_" + axis_name + "_";
970       // find the value associated with the axis for the current item
971       std::string axis_value;
972       if(domain.has_path("fields/" + axis_name))
973       {
974         axis_value = axis_name + "_item";
975         code.insert("const double " + axis_value + " = " +
976                     out_jitable.arrays[dom_idx].index(axis_name, "item") +
977                     ";\n");
978       }
979       else if(is_xyz(axis_name))
980       {
981         if(out_jitable.association == "vertex")
982         {
983           axis_value = topo->topo_name + "_vertex_" + axis_name;
984           topo_code.vertex_coord(code, axis_name, "", axis_value);
985         }
986         else if(out_jitable.association == "element")
987         {
988           axis_value = topo->topo_name + "_cell_" + axis_name;
989           topo_code.element_coord(code, axis_name, "", axis_value);
990         }
991       }
992 
993       size_t stride_multiplier;
994       if(axis.has_path("num_bins"))
995       {
996         // uniform axis
997         stride_multiplier = bin_axes.child(axis_index)["num_bins"].as_int32();
998 
999         // find the value's index in the axis
1000         if(!used_uniform)
1001         {
1002           used_uniform = true;
1003           out_kernel.functions.insert(uniform_bin);
1004         }
1005         code.insert("int " + axis_prefix + "bin_index = uniform_bin(" +
1006                     axis_value + ", " + axis_prefix + "min_val, " +
1007                     axis_prefix + "max_val, " + axis_prefix + "num_bins, " +
1008                     axis_prefix + "clamp);\n");
1009       }
1010       else
1011       {
1012         // rectilinear axis
1013         stride_multiplier =
1014             bin_axes.child(axis_index)["bins"].dtype().number_of_elements() - 1;
1015 
1016         // find the value's index in the axis
1017         if(!used_rectilinear)
1018         {
1019           used_rectilinear = true;
1020           out_kernel.functions.insert(rectilinear_bin);
1021         }
1022         code.insert("int " + axis_prefix + "bin_index = rectilinear_bin(" +
1023                     axis_value + ", " + axis_prefix + "bins, " + axis_prefix +
1024                     "bins_len, " + axis_prefix + "clamp);\n");
1025       }
1026 
1027       // update the current item's home
1028       code.insert("if(" + axis_prefix + "bin_index != -1 && " + filter_name +
1029                   "_home != -1)\n{\n" + filter_name + "_home += " +
1030                   axis_prefix + "bin_index * " + filter_name + "_stride;\n}\n");
1031       // update stride
1032       code.insert(filter_name +
1033                   "_stride *= " + std::to_string(stride_multiplier) + ";\n");
1034     }
1035 
1036     // get the value at home
1037     std::string default_value;
1038     if(inputs.has_path("default_value"))
1039     {
1040       default_value = inputs["default_value/filter_name"].as_string();
1041     }
1042     else
1043     {
1044       default_value = "0";
1045     }
1046     code.insert({"double " + filter_name + ";\n",
1047                  "if(" + filter_name + "_home != -1)\n{\n" + filter_name +
1048                      " = " + binning_name + "_value[" + filter_name +
1049                      "_home];\n}\nelse\n{\n" + filter_name + " = " +
1050                      default_value + ";\n}\n"});
1051     out_kernel.expr = filter_name;
1052     out_kernel.num_components = 1;
1053   }
1054 }
1055 
1056 void
rand()1057 JitableFusion::rand()
1058 {
1059   out_jitable.dom_info.child(dom_idx)["args/" + filter_name + "_seed"] =
1060       time(nullptr);
1061   if(not_fused)
1062   {
1063     // clang-format off
1064     const std::string halton =
1065       "double rand(int i)\n"
1066 			"{\n"
1067 				"const int b = 2;\n"
1068 				"double f = 1;\n"
1069 				"double r = 0;\n"
1070 				"while(i > 0)\n"
1071 				"{\n"
1072 					"f = f / b;\n"
1073 					"r = r + f * (i \% b);\n"
1074 					"i = i / b;\n"
1075 				"}\n"
1076 				"return r;\n"
1077 			"}\n\n";
1078     // clang-format on
1079     out_kernel.functions.insert(halton);
1080     out_kernel.expr = "rand(item + " + filter_name + "_seed)";
1081     out_kernel.num_components = 1;
1082   }
1083 }
1084 
1085 //-----------------------------------------------------------------------------
1086 };
1087 //-----------------------------------------------------------------------------
1088 // -- end ascent::runtime::expressions--
1089 //-----------------------------------------------------------------------------
1090 
1091 //-----------------------------------------------------------------------------
1092 };
1093 //-----------------------------------------------------------------------------
1094 // -- end ascent::runtime --
1095 //-----------------------------------------------------------------------------
1096 
1097 //-----------------------------------------------------------------------------
1098 };
1099 //-----------------------------------------------------------------------------
1100 // -- end ascent:: --
1101 //-----------------------------------------------------------------------------
1102