1 //! Implementations of various IR traversals.
2 
3 use crate::ir::*;
4 
5 /// Perform an intra-procedural, depth-first, in-order traversal of the IR.
6 ///
7 /// * *Intra-procedural*: Only traverses IR within a function. Does not cross
8 ///   function boundaries (although it will report edges to other functions via
9 ///   `visit_function_id` calls on the visitor, so you can use this as a
10 ///   building block for making global, inter-procedural analyses).
11 ///
12 /// * *Depth-first, in-order*: Visits instructions and instruction sequences in
13 ///   the order they are defined and nested. See [Wikipedia][in-order] for
14 ///   details.
15 ///
16 /// Calls `visitor` methods for every instruction, instruction sequence, and
17 /// resource that the traversal visits.
18 ///
19 /// The traversals begins at the `start` instruction sequence and goes from
20 /// there. To traverse everything in a function, pass `func.entry_block()` as
21 /// `start`.
22 ///
23 /// This implementation is iterative — not recursive — and so it
24 /// will not blow the call stack on deeply nested Wasm (although it may still
25 /// OOM).
26 ///
27 /// [in-order]: https://en.wikipedia.org/wiki/Tree_traversal#In-order_(LNR)
28 ///
29 /// # Example
30 ///
31 /// This example counts the number of instruction sequences in a function.
32 ///
33 /// ```no_run
34 /// use walrus::LocalFunction;
35 /// use walrus::ir::*;
36 ///
37 /// #[derive(Default)]
38 /// struct CountInstructionSequences {
39 ///     count: usize,
40 /// }
41 ///
42 /// impl<'instr> Visitor<'instr> for CountInstructionSequences {
43 ///     fn start_instr_seq(&mut self, _: &'instr InstrSeq) {
44 ///         self.count += 1;
45 ///     }
46 /// }
47 ///
48 /// // Get a function from somewhere.
49 /// # let get_my_function = || unimplemented!();
50 /// let my_func: &LocalFunction = get_my_function();
51 ///
52 /// // Create our visitor.
53 /// let mut visitor = CountInstructionSequences::default();
54 ///
55 /// // Traverse everything in the function with our visitor.
56 /// dfs_in_order(&mut visitor, my_func, my_func.entry_block());
57 ///
58 /// // Use the aggregate results that `visitor` built up.
59 /// println!("The number of instruction sequences in `my_func` is {}", visitor.count);
60 /// ```
dfs_in_order<'instr>( visitor: &mut impl Visitor<'instr>, func: &'instr LocalFunction, start: InstrSeqId, )61 pub fn dfs_in_order<'instr>(
62     visitor: &mut impl Visitor<'instr>,
63     func: &'instr LocalFunction,
64     start: InstrSeqId,
65 ) {
66     // The stack of instruction sequences we still need to visit, and how far
67     // along in the instruction sequence we are.
68     let mut stack: Vec<(InstrSeqId, usize)> = vec![(start, 0)];
69 
70     'traversing_blocks: while let Some((seq_id, index)) = stack.pop() {
71         let seq = func.block(seq_id);
72 
73         if index == 0 {
74             // If the `index` is zero, then we haven't processed any
75             // instructions in this sequence yet, and it is the first time we
76             // are entering it, so let the visitor know.
77             visitor.start_instr_seq(seq);
78             seq.visit(visitor);
79         }
80 
81         'traversing_instrs: for (index, (instr, loc)) in seq.instrs.iter().enumerate().skip(index) {
82             // Visit this instruction.
83             log::trace!("dfs_in_order: visit_instr({:?})", instr);
84             visitor.visit_instr(instr, loc);
85 
86             // Visit every other resource that this instruction references,
87             // e.g. `MemoryId`s, `FunctionId`s and all that.
88             log::trace!("dfs_in_order: ({:?}).visit(..)", instr);
89             instr.visit(visitor);
90 
91             match instr {
92                 // Pause iteration through this sequence's instructions and
93                 // enqueue `seq` to be traversed next before continuing with
94                 // this one where we left off.
95                 Instr::Block(Block { seq }) | Instr::Loop(Loop { seq }) => {
96                     stack.push((seq_id, index + 1));
97                     stack.push((*seq, 0));
98                     continue 'traversing_blocks;
99                 }
100 
101                 // Pause iteration through this sequence's instructions.
102                 // Traverse the consequent and then the alternative.
103                 Instr::IfElse(IfElse {
104                     consequent,
105                     alternative,
106                 }) => {
107                     stack.push((seq_id, index + 1));
108                     stack.push((*alternative, 0));
109                     stack.push((*consequent, 0));
110                     continue 'traversing_blocks;
111                 }
112 
113                 // No other instructions define new instruction sequences, so
114                 // continue to the next instruction.
115                 _ => continue 'traversing_instrs,
116             }
117         }
118 
119         // If we made it through the whole loop above, then we processed every
120         // instruction in the sequence, and its nested sequences, so we are
121         // finished with it!
122         visitor.end_instr_seq(seq);
123     }
124 }
125 
126 /// Perform an intra-procedural, depth-first, pre-order, mutable traversal of
127 /// the IR.
128 ///
129 /// * *Intra-procedural*: Only traverses IR within a function. Does not cross
130 ///   function boundaries (although it will report edges to other functions via
131 ///   `visit_function_id` calls on the visitor, so you can use this as a
132 ///   building block for making global, inter-procedural analyses).
133 ///
134 /// * *Depth-first, pre-order*: Visits instructions and instruction sequences in
135 ///   a top-down manner, where all instructions in a parent sequences are
136 ///   visited before child sequences. See [Wikipedia][pre-order] for details.
137 ///
138 /// Calls `visitor` methods for every instruction, instruction sequence, and
139 /// resource that the traversal visits.
140 ///
141 /// The traversals begins at the `start` instruction sequence and goes from
142 /// there. To traverse everything in a function, pass `func.entry_block()` as
143 /// `start`.
144 ///
145 /// This implementation is iterative &mdash; not recursive &mdash; and so it
146 /// will not blow the call stack on deeply nested Wasm (although it may still
147 /// OOM).
148 ///
149 /// [pre-order]: https://en.wikipedia.org/wiki/Tree_traversal#Pre-order_(NLR)
150 ///
151 /// # Example
152 ///
153 /// This example walks the IR and adds one to all `i32.const`'s values.
154 ///
155 /// ```no_run
156 /// use walrus::LocalFunction;
157 /// use walrus::ir::*;
158 ///
159 /// #[derive(Default)]
160 /// struct AddOneToI32Consts;
161 ///
162 /// impl VisitorMut for AddOneToI32Consts {
163 ///     fn visit_const_mut(&mut self, c: &mut Const) {
164 ///         match &mut c.value {
165 ///             Value::I32(x) => {
166 ///                 *x += 1;
167 ///             }
168 ///             _ => {},
169 ///         }
170 ///     }
171 /// }
172 ///
173 /// // Get a function from somewhere.
174 /// # let get_my_function = || unimplemented!();
175 /// let my_func: &mut LocalFunction = get_my_function();
176 ///
177 /// // Create our visitor.
178 /// let mut visitor = AddOneToI32Consts::default();
179 ///
180 /// // Traverse and mutate everything in the function with our visitor.
181 /// dfs_pre_order_mut(&mut visitor, my_func, my_func.entry_block());
182 /// ```
dfs_pre_order_mut( visitor: &mut impl VisitorMut, func: &mut LocalFunction, start: InstrSeqId, )183 pub fn dfs_pre_order_mut(
184     visitor: &mut impl VisitorMut,
185     func: &mut LocalFunction,
186     start: InstrSeqId,
187 ) {
188     let mut stack = vec![start];
189 
190     while let Some(seq_id) = stack.pop() {
191         let seq = func.block_mut(seq_id);
192         visitor.start_instr_seq_mut(seq);
193         seq.visit_mut(visitor);
194 
195         for (instr, loc) in &mut seq.instrs {
196             visitor.visit_instr_mut(instr, loc);
197             instr.visit_mut(visitor);
198 
199             match instr {
200                 Instr::Block(Block { seq }) | Instr::Loop(Loop { seq }) => {
201                     stack.push(*seq);
202                 }
203 
204                 Instr::IfElse(IfElse {
205                     consequent,
206                     alternative,
207                 }) => {
208                     stack.push(*alternative);
209                     stack.push(*consequent);
210                 }
211 
212                 _ => {}
213             }
214         }
215 
216         visitor.end_instr_seq_mut(seq);
217     }
218 }
219 
#[cfg(test)]
mod tests {
    use super::*;

    /// A visitor that records a short label for every callback it receives,
    /// so tests can compare the traversal order against an expected trace.
    #[derive(Default)]
    struct TestVisitor {
        visits: Vec<String>,
    }

    impl TestVisitor {
        /// Append one label to the recorded trace.
        fn record(&mut self, label: impl ToString) {
            self.visits.push(label.to_string());
        }
    }

    impl<'a> Visitor<'a> for TestVisitor {
        fn start_instr_seq(&mut self, _: &'a InstrSeq) {
            self.record("start");
        }

        fn end_instr_seq(&mut self, _: &'a InstrSeq) {
            self.record("end");
        }

        fn visit_const(&mut self, c: &Const) {
            if let Value::I32(x) = c.value {
                self.record(x);
            } else {
                unreachable!();
            }
        }

        fn visit_drop(&mut self, _: &Drop) {
            self.record("drop");
        }

        fn visit_block(&mut self, _: &Block) {
            self.record("block");
        }

        fn visit_if_else(&mut self, _: &IfElse) {
            self.record("if-else");
        }
    }

    impl VisitorMut for TestVisitor {
        fn start_instr_seq_mut(&mut self, _: &mut InstrSeq) {
            self.record("start");
        }

        fn end_instr_seq_mut(&mut self, _: &mut InstrSeq) {
            self.record("end");
        }

        fn visit_const_mut(&mut self, c: &mut Const) {
            if let Value::I32(x) = &mut c.value {
                // Record the value as seen, then bump it so a later traversal
                // can confirm the mutation was written back.
                self.record(*x);
                *x += 1;
            } else {
                unreachable!();
            }
        }

        fn visit_drop_mut(&mut self, _: &mut Drop) {
            self.record("drop");
        }

        fn visit_block_mut(&mut self, _: &mut Block) {
            self.record("block");
        }

        fn visit_if_else_mut(&mut self, _: &mut IfElse) {
            self.record("if-else");
        }
    }

    /// Build a function inside `module` whose body is shaped like:
    ///
    /// ```text
    /// i32.const 1; drop
    /// block
    ///   i32.const 2; drop
    ///   if (i32.const 3; drop) else (i32.const 4; drop)
    ///   i32.const 5; drop
    /// end
    /// i32.const 6; drop
    /// ```
    ///
    /// and return a mutable reference to it.
    fn make_test_func(module: &mut crate::Module) -> &mut LocalFunction {
        let empty_ty = module.types.add(&[], &[]);
        let mut builder = crate::FunctionBuilder::new(&mut module.types, &[], &[]);

        builder
            .func_body()
            .i32_const(1)
            .drop()
            .block(empty_ty, |outer| {
                outer
                    .i32_const(2)
                    .drop()
                    .if_else(
                        empty_ty,
                        |consequent| {
                            consequent.i32_const(3).drop();
                        },
                        |alternative| {
                            alternative.i32_const(4).drop();
                        },
                    )
                    .i32_const(5)
                    .drop();
            })
            .i32_const(6)
            .drop();

        let id = builder.finish(vec![], &mut module.funcs);
        module.funcs.get_mut(id).kind.unwrap_local_mut()
    }

    /// Assert that `visitor` recorded exactly `expected`, in order.
    fn assert_visits(visitor: &TestVisitor, expected: &[&str]) {
        let expected: Vec<String> = expected.iter().map(|s| s.to_string()).collect();
        assert_eq!(visitor.visits, expected);
    }

    #[test]
    fn dfs_in_order() {
        let mut module = crate::Module::default();
        let func = make_test_func(&mut module);

        let mut visitor = TestVisitor::default();
        crate::ir::dfs_in_order(&mut visitor, func, func.entry_block());

        assert_visits(
            &visitor,
            &[
                "start", "1", "drop", "block", "start", "2", "drop", "if-else", "start", "3",
                "drop", "end", "start", "4", "drop", "end", "5", "drop", "end", "6", "drop",
                "end",
            ],
        );
    }

    #[test]
    fn dfs_pre_order_mut() {
        let mut module = crate::Module::default();
        let func = make_test_func(&mut module);

        let mut visitor = TestVisitor::default();
        crate::ir::dfs_pre_order_mut(&mut visitor, func, func.entry_block());

        let per_sequence: [&[&str]; 4] = [
            // function entry
            &["start", "1", "drop", "block", "6", "drop", "end"],
            // block
            &["start", "2", "drop", "if-else", "5", "drop", "end"],
            // consequent
            &["start", "3", "drop", "end"],
            // alternative
            &["start", "4", "drop", "end"],
        ];
        assert_visits(&visitor, &per_sequence.concat());

        // A follow-up read-only traversal must observe every constant
        // incremented by one, proving the mutations took effect.
        visitor.visits.clear();
        crate::ir::dfs_in_order(&mut visitor, func, func.entry_block());

        assert_visits(
            &visitor,
            &[
                "start", "2", "drop", "block", "start", "3", "drop", "if-else", "start", "4",
                "drop", "end", "start", "5", "drop", "end", "6", "drop", "end", "7", "drop",
                "end",
            ],
        );
    }
}
386