1 /*
2  * Copyright 2017 WebAssembly Community Group participants
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef wasm_ir_module_h
18 #define wasm_ir_module_h
19 
20 #include "ir/find_all.h"
21 #include "ir/manipulation.h"
22 #include "ir/properties.h"
23 #include "pass.h"
24 #include "support/unique_deferring_queue.h"
25 #include "wasm.h"
26 
27 namespace wasm {
28 
29 namespace ModuleUtils {
30 
copyFunction(Function * func,Module & out)31 inline Function* copyFunction(Function* func, Module& out) {
32   auto* ret = new Function();
33   ret->name = func->name;
34   ret->sig = func->sig;
35   ret->vars = func->vars;
36   ret->localNames = func->localNames;
37   ret->localIndices = func->localIndices;
38   ret->debugLocations = func->debugLocations;
39   ret->body = ExpressionManipulator::copy(func->body, out);
40   ret->module = func->module;
41   ret->base = func->base;
42   // TODO: copy Stack IR
43   assert(!func->stackIR);
44   out.addFunction(ret);
45   return ret;
46 }
47 
copyGlobal(Global * global,Module & out)48 inline Global* copyGlobal(Global* global, Module& out) {
49   auto* ret = new Global();
50   ret->name = global->name;
51   ret->type = global->type;
52   ret->mutable_ = global->mutable_;
53   ret->module = global->module;
54   ret->base = global->base;
55   if (global->imported()) {
56     ret->init = nullptr;
57   } else {
58     ret->init = ExpressionManipulator::copy(global->init, out);
59   }
60   out.addGlobal(ret);
61   return ret;
62 }
63 
copyEvent(Event * event,Module & out)64 inline Event* copyEvent(Event* event, Module& out) {
65   auto* ret = new Event();
66   ret->name = event->name;
67   ret->attribute = event->attribute;
68   ret->sig = event->sig;
69   out.addEvent(ret);
70   return ret;
71 }
72 
copyModule(const Module & in,Module & out)73 inline void copyModule(const Module& in, Module& out) {
74   // we use names throughout, not raw pointers, so simple copying is fine
75   // for everything *but* expressions
76   for (auto& curr : in.exports) {
77     out.addExport(new Export(*curr));
78   }
79   for (auto& curr : in.functions) {
80     copyFunction(curr.get(), out);
81   }
82   for (auto& curr : in.globals) {
83     copyGlobal(curr.get(), out);
84   }
85   for (auto& curr : in.events) {
86     copyEvent(curr.get(), out);
87   }
88   out.table = in.table;
89   for (auto& segment : out.table.segments) {
90     segment.offset = ExpressionManipulator::copy(segment.offset, out);
91   }
92   out.memory = in.memory;
93   for (auto& segment : out.memory.segments) {
94     segment.offset = ExpressionManipulator::copy(segment.offset, out);
95   }
96   out.start = in.start;
97   out.userSections = in.userSections;
98   out.debugInfoFileNames = in.debugInfoFileNames;
99 }
100 
// Resets |wasm| in place: the existing Module object is destroyed and a new
// Module is constructed in the very same storage. The address of the Module
// object itself is therefore unchanged.
inline void clearModule(Module& wasm) {
  // Explicitly run the destructor on the current contents...
  wasm.~Module();
  // ...and then construct a fresh Module at the same address (placement new).
  new (&wasm) Module;
}
105 
106 // Renaming
107 
108 // Rename functions along with all their uses.
109 // Note that for this to work the functions themselves don't necessarily need
110 // to exist.  For example, it is possible to remove a given function and then
111 // call this redirect all of its uses.
renameFunctions(Module & wasm,T & map)112 template<typename T> inline void renameFunctions(Module& wasm, T& map) {
113   // Update the function itself.
114   for (auto& pair : map) {
115     if (Function* F = wasm.getFunctionOrNull(pair.first)) {
116       assert(!wasm.getFunctionOrNull(pair.second) || F->name == pair.second);
117       F->name = pair.second;
118     }
119   }
120   wasm.updateMaps();
121   // Update other global things.
122   auto maybeUpdate = [&](Name& name) {
123     auto iter = map.find(name);
124     if (iter != map.end()) {
125       name = iter->second;
126     }
127   };
128   maybeUpdate(wasm.start);
129   for (auto& segment : wasm.table.segments) {
130     for (auto& name : segment.data) {
131       maybeUpdate(name);
132     }
133   }
134   for (auto& exp : wasm.exports) {
135     if (exp->kind == ExternalKind::Function) {
136       maybeUpdate(exp->value);
137     }
138   }
139   // Update call instructions.
140   for (auto& func : wasm.functions) {
141     // TODO: parallelize
142     if (!func->imported()) {
143       FindAll<Call> calls(func->body);
144       for (auto* call : calls.list) {
145         maybeUpdate(call->target);
146       }
147     }
148   }
149 }
150 
renameFunction(Module & wasm,Name oldName,Name newName)151 inline void renameFunction(Module& wasm, Name oldName, Name newName) {
152   std::map<Name, Name> map;
153   map[oldName] = newName;
154   renameFunctions(wasm, map);
155 }
156 
157 // Convenient iteration over imported/non-imported module elements
158 
iterImportedMemories(Module & wasm,T visitor)159 template<typename T> inline void iterImportedMemories(Module& wasm, T visitor) {
160   if (wasm.memory.exists && wasm.memory.imported()) {
161     visitor(&wasm.memory);
162   }
163 }
164 
iterDefinedMemories(Module & wasm,T visitor)165 template<typename T> inline void iterDefinedMemories(Module& wasm, T visitor) {
166   if (wasm.memory.exists && !wasm.memory.imported()) {
167     visitor(&wasm.memory);
168   }
169 }
170 
iterImportedTables(Module & wasm,T visitor)171 template<typename T> inline void iterImportedTables(Module& wasm, T visitor) {
172   if (wasm.table.exists && wasm.table.imported()) {
173     visitor(&wasm.table);
174   }
175 }
176 
iterDefinedTables(Module & wasm,T visitor)177 template<typename T> inline void iterDefinedTables(Module& wasm, T visitor) {
178   if (wasm.table.exists && !wasm.table.imported()) {
179     visitor(&wasm.table);
180   }
181 }
182 
iterImportedGlobals(Module & wasm,T visitor)183 template<typename T> inline void iterImportedGlobals(Module& wasm, T visitor) {
184   for (auto& import : wasm.globals) {
185     if (import->imported()) {
186       visitor(import.get());
187     }
188   }
189 }
190 
iterDefinedGlobals(Module & wasm,T visitor)191 template<typename T> inline void iterDefinedGlobals(Module& wasm, T visitor) {
192   for (auto& import : wasm.globals) {
193     if (!import->imported()) {
194       visitor(import.get());
195     }
196   }
197 }
198 
199 template<typename T>
iterImportedFunctions(Module & wasm,T visitor)200 inline void iterImportedFunctions(Module& wasm, T visitor) {
201   for (auto& import : wasm.functions) {
202     if (import->imported()) {
203       visitor(import.get());
204     }
205   }
206 }
207 
iterDefinedFunctions(Module & wasm,T visitor)208 template<typename T> inline void iterDefinedFunctions(Module& wasm, T visitor) {
209   for (auto& import : wasm.functions) {
210     if (!import->imported()) {
211       visitor(import.get());
212     }
213   }
214 }
215 
iterImportedEvents(Module & wasm,T visitor)216 template<typename T> inline void iterImportedEvents(Module& wasm, T visitor) {
217   for (auto& import : wasm.events) {
218     if (import->imported()) {
219       visitor(import.get());
220     }
221   }
222 }
223 
iterDefinedEvents(Module & wasm,T visitor)224 template<typename T> inline void iterDefinedEvents(Module& wasm, T visitor) {
225   for (auto& import : wasm.events) {
226     if (!import->imported()) {
227       visitor(import.get());
228     }
229   }
230 }
231 
// Calls |visitor| on every imported element of |wasm|, one kind after the
// other: the memory, the table, globals, functions and finally events.
//
// The visitor receives a pointer to the element, so it must be callable with
// each of the pointer types involved (for example, a generic lambda).
template<typename T> inline void iterImports(Module& wasm, T visitor) {
  // Note that |visitor| is passed by value to each helper, so each of them
  // works on its own copy of it.
  iterImportedMemories(wasm, visitor);
  iterImportedTables(wasm, visitor);
  iterImportedGlobals(wasm, visitor);
  iterImportedFunctions(wasm, visitor);
  iterImportedEvents(wasm, visitor);
}
239 
240 // Helper class for performing an operation on all the functions in the module,
241 // in parallel, with an Info object for each one that can contain results of
242 // some computation that the operation performs.
243 // The operation performend should not modify the wasm module in any way.
244 // TODO: enforce this
245 template<typename T> struct ParallelFunctionAnalysis {
246   Module& wasm;
247 
248   typedef std::map<Function*, T> Map;
249   Map map;
250 
251   typedef std::function<void(Function*, T&)> Func;
252 
ParallelFunctionAnalysisParallelFunctionAnalysis253   ParallelFunctionAnalysis(Module& wasm, Func work) : wasm(wasm) {
254     // Fill in map, as we operate on it in parallel (each function to its own
255     // entry).
256     for (auto& func : wasm.functions) {
257       map[func.get()];
258     }
259 
260     // Run on the imports first. TODO: parallelize this too
261     for (auto& func : wasm.functions) {
262       if (func->imported()) {
263         work(func.get(), map[func.get()]);
264       }
265     }
266 
267     struct Mapper : public WalkerPass<PostWalker<Mapper>> {
268       bool isFunctionParallel() override { return true; }
269       bool modifiesBinaryenIR() override { return false; }
270 
271       Mapper(Module& module, Map& map, Func work)
272         : module(module), map(map), work(work) {}
273 
274       Mapper* create() override { return new Mapper(module, map, work); }
275 
276       void doWalkFunction(Function* curr) {
277         assert(map.count(curr));
278         work(curr, map[curr]);
279       }
280 
281     private:
282       Module& module;
283       Map& map;
284       Func work;
285     };
286 
287     PassRunner runner(&wasm);
288     Mapper(wasm, map, work).run(&runner, &wasm);
289   }
290 };
291 
292 // Helper class for analyzing the call graph.
293 //
294 // Provides hooks for running some initial calculation on each function (which
295 // is done in parallel), writing to a FunctionInfo structure for each function.
296 // Then you can call propagateBack() to propagate a property of interest to the
297 // calling functions, transitively.
298 //
299 // For example, if some functions are known to call an import "foo", then you
300 // can use this to find which functions call something that might eventually
301 // reach foo, by initially marking the direct callers as "calling foo" and
302 // propagating that backwards.
303 template<typename T> struct CallGraphPropertyAnalysis {
304   Module& wasm;
305 
306   // The basic information for each function about whom it calls and who is
307   // called by it.
308   struct FunctionInfo {
309     std::set<Function*> callsTo;
310     std::set<Function*> calledBy;
311     bool hasIndirectCall = false;
312   };
313 
314   typedef std::map<Function*, T> Map;
315   Map map;
316 
317   typedef std::function<void(Function*, T&)> Func;
318 
CallGraphPropertyAnalysisCallGraphPropertyAnalysis319   CallGraphPropertyAnalysis(Module& wasm, Func work) : wasm(wasm) {
320     ParallelFunctionAnalysis<T> analysis(wasm, [&](Function* func, T& info) {
321       work(func, info);
322       if (func->imported()) {
323         return;
324       }
325       struct Mapper : public PostWalker<Mapper> {
326         Mapper(Module* module, T& info, Func work)
327           : module(module), info(info), work(work) {}
328 
329         void visitCall(Call* curr) {
330           info.callsTo.insert(module->getFunction(curr->target));
331         }
332 
333         void visitCallIndirect(CallIndirect* curr) {
334           info.hasIndirectCall = true;
335         }
336 
337       private:
338         Module* module;
339         T& info;
340         Func work;
341       } mapper(&wasm, info, work);
342       mapper.walk(func->body);
343     });
344 
345     map.swap(analysis.map);
346 
347     // Find what is called by what.
348     for (auto& pair : map) {
349       auto* func = pair.first;
350       auto& info = pair.second;
351       for (auto* target : info.callsTo) {
352         map[target].calledBy.insert(func);
353       }
354     }
355   }
356 
357   enum IndirectCalls { IgnoreIndirectCalls, IndirectCallsHaveProperty };
358 
359   // Propagate a property from a function to those that call it.
360   //
361   // hasProperty() - Check if the property is present.
362   // canHaveProperty() - Check if the property could be present.
363   // addProperty() - Adds the property. This receives a second parameter which
364   //                 is the function due to which we are adding the property.
propagateBackCallGraphPropertyAnalysis365   void propagateBack(std::function<bool(const T&)> hasProperty,
366                      std::function<bool(const T&)> canHaveProperty,
367                      std::function<void(T&, Function*)> addProperty,
368                      IndirectCalls indirectCalls) {
369     // The work queue contains items we just learned can change the state.
370     UniqueDeferredQueue<Function*> work;
371     for (auto& func : wasm.functions) {
372       if (hasProperty(map[func.get()]) ||
373           (indirectCalls == IndirectCallsHaveProperty &&
374            map[func.get()].hasIndirectCall)) {
375         addProperty(map[func.get()], func.get());
376         work.push(func.get());
377       }
378     }
379     while (!work.empty()) {
380       auto* func = work.pop();
381       for (auto* caller : map[func].calledBy) {
382         // If we don't already have the property, and we are not forbidden
383         // from getting it, then it propagates back to us now.
384         if (!hasProperty(map[caller]) && canHaveProperty(map[caller])) {
385           addProperty(map[caller], func);
386           work.push(caller);
387         }
388       }
389     }
390   }
391 };
392 
393 // Helper function for collecting the type signatures used in a module
394 //
395 // Used when emitting or printing a module to give signatures canonical
396 // indices. Signatures are sorted in order of decreasing frequency to minize the
397 // size of their collective encoding. Both a vector mapping indices to
398 // signatures and a map mapping signatures to indices are produced.
399 inline void
collectSignatures(Module & wasm,std::vector<Signature> & signatures,std::unordered_map<Signature,Index> & sigIndices)400 collectSignatures(Module& wasm,
401                   std::vector<Signature>& signatures,
402                   std::unordered_map<Signature, Index>& sigIndices) {
403   using Counts = std::unordered_map<Signature, size_t>;
404 
405   // Collect the signature use counts for a single function
406   auto updateCounts = [&](Function* func, Counts& counts) {
407     if (func->imported()) {
408       return;
409     }
410     struct TypeCounter
411       : PostWalker<TypeCounter, UnifiedExpressionVisitor<TypeCounter>> {
412       Counts& counts;
413 
414       TypeCounter(Counts& counts) : counts(counts) {}
415       void visitExpression(Expression* curr) {
416         if (auto* call = curr->dynCast<CallIndirect>()) {
417           counts[call->sig]++;
418         } else if (Properties::isControlFlowStructure(curr)) {
419           // TODO: Allow control flow to have input types as well
420           if (curr->type.isTuple()) {
421             counts[Signature(Type::none, curr->type)]++;
422           }
423         }
424       }
425     };
426     TypeCounter(counts).walk(func->body);
427   };
428 
429   ModuleUtils::ParallelFunctionAnalysis<Counts> analysis(wasm, updateCounts);
430 
431   // Collect all the counts.
432   Counts counts;
433   for (auto& curr : wasm.functions) {
434     counts[curr->sig]++;
435   }
436   for (auto& curr : wasm.events) {
437     counts[curr->sig]++;
438   }
439   for (auto& pair : analysis.map) {
440     Counts& functionCounts = pair.second;
441     for (auto& innerPair : functionCounts) {
442       counts[innerPair.first] += innerPair.second;
443     }
444   }
445   std::vector<std::pair<Signature, size_t>> sorted(counts.begin(),
446                                                    counts.end());
447   std::sort(sorted.begin(), sorted.end(), [&](auto a, auto b) {
448     // order by frequency then simplicity
449     if (a.second != b.second) {
450       return a.second > b.second;
451     }
452     return a.first < b.first;
453   });
454   for (Index i = 0; i < sorted.size(); ++i) {
455     sigIndices[sorted[i].first] = i;
456     signatures.push_back(sorted[i].first);
457   }
458 }
459 
460 } // namespace ModuleUtils
461 
462 } // namespace wasm
463 
464 #endif // wasm_ir_module_h
465