1 /*
2  * This program is free software; you can redistribute it and/or
3  * modify it under the terms of the GNU General Public License
4  * as published by the Free Software Foundation; either version 2
5  * of the License, or (at your option) any later version.
6  *
7  * This program is distributed in the hope that it will be useful,
8  * but WITHOUT ANY WARRANTY; without even the implied warranty of
9  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10  * GNU General Public License for more details.
11  *
12  * You should have received a copy of the GNU General Public License
13  * along with this program; if not, write to the Free Software Foundation,
14  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
15  */
16 
17 /** \file
18  * \ingroup fn
19  */
20 
21 /* Used to check if two multi-functions have the exact same type. */
22 #include <typeinfo>
23 
24 #include "FN_multi_function_builder.hh"
25 #include "FN_multi_function_network_evaluation.hh"
26 #include "FN_multi_function_network_optimization.hh"
27 
28 #include "BLI_disjoint_set.hh"
29 #include "BLI_ghash.h"
30 #include "BLI_map.hh"
31 #include "BLI_multi_value_map.hh"
32 #include "BLI_rand.h"
33 #include "BLI_stack.hh"
34 
35 namespace blender::fn::mf_network_optimization {
36 
37 /* -------------------------------------------------------------------- */
38 /** \name Utility functions to find nodes in a network.
39  *
40  * \{ */
41 
/**
 * Assign #new_value to #tag and report whether the stored value was different before.
 *
 * \return true when #tag changed, false when it already held #new_value.
 */
static bool set_tag_and_check_if_modified(bool &tag, bool new_value)
{
  const bool was_modified = (tag != new_value);
  tag = new_value;
  return was_modified;
}
51 
mask_nodes_to_the_left(MFNetwork & network,Span<MFNode * > nodes)52 static Array<bool> mask_nodes_to_the_left(MFNetwork &network, Span<MFNode *> nodes)
53 {
54   Array<bool> is_to_the_left(network.node_id_amount(), false);
55   Stack<MFNode *> nodes_to_check;
56 
57   for (MFNode *node : nodes) {
58     is_to_the_left[node->id()] = true;
59     nodes_to_check.push(node);
60   }
61 
62   while (!nodes_to_check.is_empty()) {
63     MFNode &node = *nodes_to_check.pop();
64 
65     for (MFInputSocket *input_socket : node.inputs()) {
66       MFOutputSocket *origin = input_socket->origin();
67       if (origin != nullptr) {
68         MFNode &origin_node = origin->node();
69         if (set_tag_and_check_if_modified(is_to_the_left[origin_node.id()], true)) {
70           nodes_to_check.push(&origin_node);
71         }
72       }
73     }
74   }
75 
76   return is_to_the_left;
77 }
78 
mask_nodes_to_the_right(MFNetwork & network,Span<MFNode * > nodes)79 static Array<bool> mask_nodes_to_the_right(MFNetwork &network, Span<MFNode *> nodes)
80 {
81   Array<bool> is_to_the_right(network.node_id_amount(), false);
82   Stack<MFNode *> nodes_to_check;
83 
84   for (MFNode *node : nodes) {
85     is_to_the_right[node->id()] = true;
86     nodes_to_check.push(node);
87   }
88 
89   while (!nodes_to_check.is_empty()) {
90     MFNode &node = *nodes_to_check.pop();
91 
92     for (MFOutputSocket *output_socket : node.outputs()) {
93       for (MFInputSocket *target_socket : output_socket->targets()) {
94         MFNode &target_node = target_socket->node();
95         if (set_tag_and_check_if_modified(is_to_the_right[target_node.id()], true)) {
96           nodes_to_check.push(&target_node);
97         }
98       }
99     }
100   }
101 
102   return is_to_the_right;
103 }
104 
find_nodes_based_on_mask(MFNetwork & network,Span<bool> id_mask,bool mask_value)105 static Vector<MFNode *> find_nodes_based_on_mask(MFNetwork &network,
106                                                  Span<bool> id_mask,
107                                                  bool mask_value)
108 {
109   Vector<MFNode *> nodes;
110   for (int id : id_mask.index_range()) {
111     if (id_mask[id] == mask_value) {
112       MFNode *node = network.node_or_null_by_id(id);
113       if (node != nullptr) {
114         nodes.append(node);
115       }
116     }
117   }
118   return nodes;
119 }
120 
121 /** \} */
122 
123 /* -------------------------------------------------------------------- */
124 /** \name Dead Node Removal
125  *
126  * \{ */
127 
128 /**
129  * Unused nodes are all those nodes that no dummy node depends upon.
130  */
dead_node_removal(MFNetwork & network)131 void dead_node_removal(MFNetwork &network)
132 {
133   Array<bool> node_is_used_mask = mask_nodes_to_the_left(network,
134                                                          network.dummy_nodes().cast<MFNode *>());
135   Vector<MFNode *> nodes_to_remove = find_nodes_based_on_mask(network, node_is_used_mask, false);
136   network.remove(nodes_to_remove);
137 }
138 
139 /** \} */
140 
141 /* -------------------------------------------------------------------- */
142 /** \name Constant Folding
143  *
144  * \{ */
145 
function_node_can_be_constant(MFFunctionNode * node)146 static bool function_node_can_be_constant(MFFunctionNode *node)
147 {
148   if (node->has_unlinked_inputs()) {
149     return false;
150   }
151   if (node->function().depends_on_context()) {
152     return false;
153   }
154   return true;
155 }
156 
find_non_constant_nodes(MFNetwork & network)157 static Vector<MFNode *> find_non_constant_nodes(MFNetwork &network)
158 {
159   Vector<MFNode *> non_constant_nodes;
160   non_constant_nodes.extend(network.dummy_nodes().cast<MFNode *>());
161 
162   for (MFFunctionNode *node : network.function_nodes()) {
163     if (!function_node_can_be_constant(node)) {
164       non_constant_nodes.append(node);
165     }
166   }
167   return non_constant_nodes;
168 }
169 
output_has_non_constant_target_node(MFOutputSocket * output_socket,Span<bool> is_not_constant_mask)170 static bool output_has_non_constant_target_node(MFOutputSocket *output_socket,
171                                                 Span<bool> is_not_constant_mask)
172 {
173   for (MFInputSocket *target_socket : output_socket->targets()) {
174     MFNode &target_node = target_socket->node();
175     bool target_is_not_constant = is_not_constant_mask[target_node.id()];
176     if (target_is_not_constant) {
177       return true;
178     }
179   }
180   return false;
181 }
182 
try_find_dummy_target_socket(MFOutputSocket * output_socket)183 static MFInputSocket *try_find_dummy_target_socket(MFOutputSocket *output_socket)
184 {
185   for (MFInputSocket *target_socket : output_socket->targets()) {
186     if (target_socket->node().is_dummy()) {
187       return target_socket;
188     }
189   }
190   return nullptr;
191 }
192 
find_constant_inputs_to_fold(MFNetwork & network,Vector<MFDummyNode * > & r_temporary_nodes)193 static Vector<MFInputSocket *> find_constant_inputs_to_fold(
194     MFNetwork &network, Vector<MFDummyNode *> &r_temporary_nodes)
195 {
196   Vector<MFNode *> non_constant_nodes = find_non_constant_nodes(network);
197   Array<bool> is_not_constant_mask = mask_nodes_to_the_right(network, non_constant_nodes);
198   Vector<MFNode *> constant_nodes = find_nodes_based_on_mask(network, is_not_constant_mask, false);
199 
200   Vector<MFInputSocket *> sockets_to_compute;
201   for (MFNode *node : constant_nodes) {
202     if (node->inputs().size() == 0) {
203       continue;
204     }
205 
206     for (MFOutputSocket *output_socket : node->outputs()) {
207       MFDataType data_type = output_socket->data_type();
208       if (output_has_non_constant_target_node(output_socket, is_not_constant_mask)) {
209         MFInputSocket *dummy_target = try_find_dummy_target_socket(output_socket);
210         if (dummy_target == nullptr) {
211           dummy_target = &network.add_output("Dummy", data_type);
212           network.add_link(*output_socket, *dummy_target);
213           r_temporary_nodes.append(&dummy_target->node().as_dummy());
214         }
215 
216         sockets_to_compute.append(dummy_target);
217       }
218     }
219   }
220   return sockets_to_compute;
221 }
222 
prepare_params_for_constant_folding(const MultiFunction & network_fn,MFParamsBuilder & params,ResourceCollector & resources)223 static void prepare_params_for_constant_folding(const MultiFunction &network_fn,
224                                                 MFParamsBuilder &params,
225                                                 ResourceCollector &resources)
226 {
227   for (int param_index : network_fn.param_indices()) {
228     MFParamType param_type = network_fn.param_type(param_index);
229     MFDataType data_type = param_type.data_type();
230 
231     switch (data_type.category()) {
232       case MFDataType::Single: {
233         /* Allocates memory for a single constant folded value. */
234         const CPPType &cpp_type = data_type.single_type();
235         void *buffer = resources.linear_allocator().allocate(cpp_type.size(),
236                                                              cpp_type.alignment());
237         GMutableSpan array{cpp_type, buffer, 1};
238         params.add_uninitialized_single_output(array);
239         break;
240       }
241       case MFDataType::Vector: {
242         /* Allocates memory for a constant folded vector. */
243         const CPPType &cpp_type = data_type.vector_base_type();
244         GVectorArray &vector_array = resources.construct<GVectorArray>(AT, cpp_type, 1);
245         params.add_vector_output(vector_array);
246         break;
247       }
248     }
249   }
250 }
251 
add_constant_folded_sockets(const MultiFunction & network_fn,MFParamsBuilder & params,ResourceCollector & resources,MFNetwork & network)252 static Array<MFOutputSocket *> add_constant_folded_sockets(const MultiFunction &network_fn,
253                                                            MFParamsBuilder &params,
254                                                            ResourceCollector &resources,
255                                                            MFNetwork &network)
256 {
257   Array<MFOutputSocket *> folded_sockets{network_fn.param_indices().size(), nullptr};
258 
259   for (int param_index : network_fn.param_indices()) {
260     MFParamType param_type = network_fn.param_type(param_index);
261     MFDataType data_type = param_type.data_type();
262 
263     const MultiFunction *constant_fn = nullptr;
264 
265     switch (data_type.category()) {
266       case MFDataType::Single: {
267         const CPPType &cpp_type = data_type.single_type();
268         GMutableSpan array = params.computed_array(param_index);
269         void *buffer = array.data();
270         resources.add(buffer, array.type().destruct_cb(), AT);
271 
272         constant_fn = &resources.construct<CustomMF_GenericConstant>(AT, cpp_type, buffer);
273         break;
274       }
275       case MFDataType::Vector: {
276         GVectorArray &vector_array = params.computed_vector_array(param_index);
277         GSpan array = vector_array[0];
278         constant_fn = &resources.construct<CustomMF_GenericConstantArray>(AT, array);
279         break;
280       }
281     }
282 
283     MFFunctionNode &folded_node = network.add_function(*constant_fn);
284     folded_sockets[param_index] = &folded_node.output(0);
285   }
286   return folded_sockets;
287 }
288 
compute_constant_sockets_and_add_folded_nodes(MFNetwork & network,Span<const MFInputSocket * > sockets_to_compute,ResourceCollector & resources)289 static Array<MFOutputSocket *> compute_constant_sockets_and_add_folded_nodes(
290     MFNetwork &network,
291     Span<const MFInputSocket *> sockets_to_compute,
292     ResourceCollector &resources)
293 {
294   MFNetworkEvaluator network_fn{{}, sockets_to_compute};
295 
296   MFContextBuilder context;
297   MFParamsBuilder params{network_fn, 1};
298   prepare_params_for_constant_folding(network_fn, params, resources);
299   network_fn.call({0}, params, context);
300   return add_constant_folded_sockets(network_fn, params, resources, network);
301 }
302 
303 class MyClass {
304   MFDummyNode node;
305 };
306 
307 /**
308  * Find function nodes that always output the same value and replace those with constant nodes.
309  */
constant_folding(MFNetwork & network,ResourceCollector & resources)310 void constant_folding(MFNetwork &network, ResourceCollector &resources)
311 {
312   Vector<MFDummyNode *> temporary_nodes;
313   Vector<MFInputSocket *> inputs_to_fold = find_constant_inputs_to_fold(network, temporary_nodes);
314   if (inputs_to_fold.size() == 0) {
315     return;
316   }
317 
318   Array<MFOutputSocket *> folded_sockets = compute_constant_sockets_and_add_folded_nodes(
319       network, inputs_to_fold, resources);
320 
321   for (int i : inputs_to_fold.index_range()) {
322     MFOutputSocket &original_socket = *inputs_to_fold[i]->origin();
323     network.relink(original_socket, *folded_sockets[i]);
324   }
325 
326   network.remove(temporary_nodes.as_span().cast<MFNode *>());
327 }
328 
329 /** \} */
330 
331 /* -------------------------------------------------------------------- */
332 /** \name Common Sub-network Elimination
333  *
334  * \{ */
335 
compute_node_hash(MFFunctionNode & node,RNG * rng,Span<uint64_t> node_hashes)336 static uint64_t compute_node_hash(MFFunctionNode &node, RNG *rng, Span<uint64_t> node_hashes)
337 {
338   if (node.function().depends_on_context()) {
339     return BLI_rng_get_uint(rng);
340   }
341   if (node.has_unlinked_inputs()) {
342     return BLI_rng_get_uint(rng);
343   }
344 
345   uint64_t combined_inputs_hash = 394659347u;
346   for (MFInputSocket *input_socket : node.inputs()) {
347     MFOutputSocket *origin_socket = input_socket->origin();
348     uint64_t input_hash = BLI_ghashutil_combine_hash(node_hashes[origin_socket->node().id()],
349                                                      origin_socket->index());
350     combined_inputs_hash = BLI_ghashutil_combine_hash(combined_inputs_hash, input_hash);
351   }
352 
353   uint64_t function_hash = node.function().hash();
354   uint64_t node_hash = BLI_ghashutil_combine_hash(combined_inputs_hash, function_hash);
355   return node_hash;
356 }
357 
358 /**
359  * Produces a hash for every node. Two nodes with the same hash should have a high probability of
360  * outputting the same values.
361  */
compute_node_hashes(MFNetwork & network)362 static Array<uint64_t> compute_node_hashes(MFNetwork &network)
363 {
364   RNG *rng = BLI_rng_new(0);
365   Array<uint64_t> node_hashes(network.node_id_amount());
366   Array<bool> node_is_hashed(network.node_id_amount(), false);
367 
368   /* No dummy nodes are not assumed to output the same values. */
369   for (MFDummyNode *node : network.dummy_nodes()) {
370     uint64_t node_hash = BLI_rng_get_uint(rng);
371     node_hashes[node->id()] = node_hash;
372     node_is_hashed[node->id()] = true;
373   }
374 
375   Stack<MFFunctionNode *> nodes_to_check;
376   nodes_to_check.push_multiple(network.function_nodes());
377 
378   while (!nodes_to_check.is_empty()) {
379     MFFunctionNode &node = *nodes_to_check.peek();
380     if (node_is_hashed[node.id()]) {
381       nodes_to_check.pop();
382       continue;
383     }
384 
385     /* Make sure that origin nodes are hashed first. */
386     bool all_dependencies_ready = true;
387     for (MFInputSocket *input_socket : node.inputs()) {
388       MFOutputSocket *origin_socket = input_socket->origin();
389       if (origin_socket != nullptr) {
390         MFNode &origin_node = origin_socket->node();
391         if (!node_is_hashed[origin_node.id()]) {
392           all_dependencies_ready = false;
393           nodes_to_check.push(&origin_node.as_function());
394         }
395       }
396     }
397     if (!all_dependencies_ready) {
398       continue;
399     }
400 
401     uint64_t node_hash = compute_node_hash(node, rng, node_hashes);
402     node_hashes[node.id()] = node_hash;
403     node_is_hashed[node.id()] = true;
404     nodes_to_check.pop();
405   }
406 
407   BLI_rng_free(rng);
408   return node_hashes;
409 }
410 
group_nodes_by_hash(MFNetwork & network,Span<uint64_t> node_hashes)411 static MultiValueMap<uint64_t, MFNode *> group_nodes_by_hash(MFNetwork &network,
412                                                              Span<uint64_t> node_hashes)
413 {
414   MultiValueMap<uint64_t, MFNode *> nodes_by_hash;
415   for (int id : IndexRange(network.node_id_amount())) {
416     MFNode *node = network.node_or_null_by_id(id);
417     if (node != nullptr) {
418       uint64_t node_hash = node_hashes[id];
419       nodes_by_hash.add(node_hash, node);
420     }
421   }
422   return nodes_by_hash;
423 }
424 
functions_are_equal(const MultiFunction & a,const MultiFunction & b)425 static bool functions_are_equal(const MultiFunction &a, const MultiFunction &b)
426 {
427   if (&a == &b) {
428     return true;
429   }
430   if (typeid(a) == typeid(b)) {
431     return a.equals(b);
432   }
433   return false;
434 }
435 
nodes_output_same_values(DisjointSet & cache,const MFNode & a,const MFNode & b)436 static bool nodes_output_same_values(DisjointSet &cache, const MFNode &a, const MFNode &b)
437 {
438   if (cache.in_same_set(a.id(), b.id())) {
439     return true;
440   }
441 
442   if (a.is_dummy() || b.is_dummy()) {
443     return false;
444   }
445   if (!functions_are_equal(a.as_function().function(), b.as_function().function())) {
446     return false;
447   }
448   for (int i : a.inputs().index_range()) {
449     const MFOutputSocket *origin_a = a.input(i).origin();
450     const MFOutputSocket *origin_b = b.input(i).origin();
451     if (origin_a == nullptr || origin_b == nullptr) {
452       return false;
453     }
454     if (!nodes_output_same_values(cache, origin_a->node(), origin_b->node())) {
455       return false;
456     }
457   }
458 
459   cache.join(a.id(), b.id());
460   return true;
461 }
462 
relink_duplicate_nodes(MFNetwork & network,MultiValueMap<uint64_t,MFNode * > & nodes_by_hash)463 static void relink_duplicate_nodes(MFNetwork &network,
464                                    MultiValueMap<uint64_t, MFNode *> &nodes_by_hash)
465 {
466   DisjointSet same_node_cache{network.node_id_amount()};
467 
468   for (Span<MFNode *> nodes_with_same_hash : nodes_by_hash.values()) {
469     if (nodes_with_same_hash.size() <= 1) {
470       continue;
471     }
472 
473     Vector<MFNode *, 16> nodes_to_check = nodes_with_same_hash;
474     while (nodes_to_check.size() >= 2) {
475       Vector<MFNode *, 16> remaining_nodes;
476 
477       MFNode &deduplicated_node = *nodes_to_check[0];
478       for (MFNode *node : nodes_to_check.as_span().drop_front(1)) {
479         /* This is true with fairly high probability, but hash collisions can happen. So we have to
480          * check if the node actually output the same values. */
481         if (nodes_output_same_values(same_node_cache, deduplicated_node, *node)) {
482           for (int i : deduplicated_node.outputs().index_range()) {
483             network.relink(node->output(i), deduplicated_node.output(i));
484           }
485         }
486         else {
487           remaining_nodes.append(node);
488         }
489       }
490       nodes_to_check = std::move(remaining_nodes);
491     }
492   }
493 }
494 
495 /**
496  * Tries to detect duplicate sub-networks and eliminates them. This can help quite a lot when node
497  * groups were used to create the network.
498  */
common_subnetwork_elimination(MFNetwork & network)499 void common_subnetwork_elimination(MFNetwork &network)
500 {
501   Array<uint64_t> node_hashes = compute_node_hashes(network);
502   MultiValueMap<uint64_t, MFNode *> nodes_by_hash = group_nodes_by_hash(network, node_hashes);
503   relink_duplicate_nodes(network, nodes_by_hash);
504 }
505 
506 /** \} */
507 
508 }  // namespace blender::fn::mf_network_optimization
509