1 use crate::cdsl::ast::{
2     Apply, BlockPool, ConstPool, DefIndex, DefPool, DummyDef, DummyExpr, Expr, PatternPosition,
3     VarIndex, VarPool,
4 };
5 use crate::cdsl::instructions::Instruction;
6 use crate::cdsl::type_inference::{infer_transform, TypeEnvironment};
7 use crate::cdsl::typevar::TypeVar;
8 
9 use cranelift_entity::{entity_impl, PrimaryMap};
10 
11 use std::collections::{HashMap, HashSet};
12 use std::iter::FromIterator;
13 
/// An instruction transformation consists of a source and destination pattern.
///
/// Patterns are expressed in *register transfer language* as tuples of Def or Expr nodes. A
/// pattern may optionally have a sequence of TypeConstraints, that additionally limit the set of
/// cases when it applies.
///
/// The source pattern can contain only a single instruction.
pub(crate) struct Transform {
    /// Index, in `def_pool`, of the single definition making up the source pattern.
    pub src: DefIndex,
    /// Indices, in `def_pool`, of the instruction definitions making up the destination pattern.
    pub dst: Vec<DefIndex>,
    /// Pool owning the transform-local copies of every variable used by the patterns.
    pub var_pool: VarPool,
    /// Pool owning the definitions created while rewriting the patterns.
    pub def_pool: DefPool,
    /// Pool recording the blocks declared by the patterns.
    pub block_pool: BlockPool,
    /// Pool recording the constants referenced by the patterns.
    pub const_pool: ConstPool,
    /// Type environment produced by running type inference over both patterns.
    pub type_env: TypeEnvironment,
}
30 
/// Maps a variable name to the index of the transform-local copy of that variable.
type SymbolTable = HashMap<String, VarIndex>;
32 
33 impl Transform {
new(src: DummyDef, dst: Vec<DummyDef>) -> Self34     fn new(src: DummyDef, dst: Vec<DummyDef>) -> Self {
35         let mut var_pool = VarPool::new();
36         let mut def_pool = DefPool::new();
37         let mut block_pool = BlockPool::new();
38         let mut const_pool = ConstPool::new();
39 
40         let mut input_vars: Vec<VarIndex> = Vec::new();
41         let mut defined_vars: Vec<VarIndex> = Vec::new();
42 
43         // Maps variable names to our own Var copies.
44         let mut symbol_table: SymbolTable = SymbolTable::new();
45 
46         // Rewrite variables in src and dst using our own copies.
47         let src = rewrite_def_list(
48             PatternPosition::Source,
49             vec![src],
50             &mut symbol_table,
51             &mut input_vars,
52             &mut defined_vars,
53             &mut var_pool,
54             &mut def_pool,
55             &mut block_pool,
56             &mut const_pool,
57         )[0];
58 
59         let num_src_inputs = input_vars.len();
60 
61         let dst = rewrite_def_list(
62             PatternPosition::Destination,
63             dst,
64             &mut symbol_table,
65             &mut input_vars,
66             &mut defined_vars,
67             &mut var_pool,
68             &mut def_pool,
69             &mut block_pool,
70             &mut const_pool,
71         );
72 
73         // Sanity checks.
74         for &var_index in &input_vars {
75             assert!(
76                 var_pool.get(var_index).is_input(),
77                 "'{:?}' used as both input and def",
78                 var_pool.get(var_index)
79             );
80         }
81         assert!(
82             input_vars.len() == num_src_inputs,
83             "extra input vars in dst pattern: {:?}",
84             input_vars
85                 .iter()
86                 .map(|&i| var_pool.get(i))
87                 .skip(num_src_inputs)
88                 .collect::<Vec<_>>()
89         );
90 
91         // Perform type inference and cleanup.
92         let type_env = infer_transform(src, &dst, &def_pool, &mut var_pool).unwrap();
93 
94         // Sanity check: the set of inferred free type variables should be a subset of the type
95         // variables corresponding to Vars appearing in the source pattern.
96         {
97             let free_typevars: HashSet<TypeVar> =
98                 HashSet::from_iter(type_env.free_typevars(&mut var_pool));
99             let src_tvs = HashSet::from_iter(
100                 input_vars
101                     .clone()
102                     .iter()
103                     .chain(
104                         defined_vars
105                             .iter()
106                             .filter(|&&var_index| !var_pool.get(var_index).is_temp()),
107                     )
108                     .map(|&var_index| var_pool.get(var_index).get_typevar())
109                     .filter(|maybe_var| maybe_var.is_some())
110                     .map(|var| var.unwrap()),
111             );
112             if !free_typevars.is_subset(&src_tvs) {
113                 let missing_tvs = (&free_typevars - &src_tvs)
114                     .iter()
115                     .map(|tv| tv.name.clone())
116                     .collect::<Vec<_>>()
117                     .join(", ");
118                 panic!("Some free vars don't appear in src: {}", missing_tvs);
119             }
120         }
121 
122         for &var_index in input_vars.iter().chain(defined_vars.iter()) {
123             let var = var_pool.get_mut(var_index);
124             let canon_tv = type_env.get_equivalent(&var.get_or_create_typevar());
125             var.set_typevar(canon_tv);
126         }
127 
128         Self {
129             src,
130             dst,
131             var_pool,
132             def_pool,
133             block_pool,
134             const_pool,
135             type_env,
136         }
137     }
138 
verify_legalize(&self)139     fn verify_legalize(&self) {
140         let def = self.def_pool.get(self.src);
141         for &var_index in def.defined_vars.iter() {
142             let defined_var = self.var_pool.get(var_index);
143             assert!(
144                 defined_var.is_output(),
145                 "{:?} not defined in the destination pattern",
146                 defined_var
147             );
148         }
149     }
150 }
151 
152 /// Inserts, if not present, a name in the `symbol_table`. Then returns its index in the variable
153 /// pool `var_pool`. If the variable was not present in the symbol table, then add it to the list of
154 /// `defined_vars`.
var_index( name: &str, symbol_table: &mut SymbolTable, defined_vars: &mut Vec<VarIndex>, var_pool: &mut VarPool, ) -> VarIndex155 fn var_index(
156     name: &str,
157     symbol_table: &mut SymbolTable,
158     defined_vars: &mut Vec<VarIndex>,
159     var_pool: &mut VarPool,
160 ) -> VarIndex {
161     let name = name.to_string();
162     match symbol_table.get(&name) {
163         Some(&existing_var) => existing_var,
164         None => {
165             // Materialize the variable.
166             let new_var = var_pool.create(name.clone());
167             symbol_table.insert(name, new_var);
168             defined_vars.push(new_var);
169             new_var
170         }
171     }
172 }
173 
174 /// Given a list of symbols defined in a Def, rewrite them to local symbols. Yield the new locals.
rewrite_defined_vars( position: PatternPosition, dummy_def: &DummyDef, def_index: DefIndex, symbol_table: &mut SymbolTable, defined_vars: &mut Vec<VarIndex>, var_pool: &mut VarPool, ) -> Vec<VarIndex>175 fn rewrite_defined_vars(
176     position: PatternPosition,
177     dummy_def: &DummyDef,
178     def_index: DefIndex,
179     symbol_table: &mut SymbolTable,
180     defined_vars: &mut Vec<VarIndex>,
181     var_pool: &mut VarPool,
182 ) -> Vec<VarIndex> {
183     let mut new_defined_vars = Vec::new();
184     for var in &dummy_def.defined_vars {
185         let own_var = var_index(&var.name, symbol_table, defined_vars, var_pool);
186         var_pool.get_mut(own_var).set_def(position, def_index);
187         new_defined_vars.push(own_var);
188     }
189     new_defined_vars
190 }
191 
192 /// Find all uses of variables in `expr` and replace them with our own local symbols.
rewrite_expr( position: PatternPosition, dummy_expr: DummyExpr, symbol_table: &mut SymbolTable, input_vars: &mut Vec<VarIndex>, var_pool: &mut VarPool, const_pool: &mut ConstPool, ) -> Apply193 fn rewrite_expr(
194     position: PatternPosition,
195     dummy_expr: DummyExpr,
196     symbol_table: &mut SymbolTable,
197     input_vars: &mut Vec<VarIndex>,
198     var_pool: &mut VarPool,
199     const_pool: &mut ConstPool,
200 ) -> Apply {
201     let (apply_target, dummy_args) = if let DummyExpr::Apply(apply_target, dummy_args) = dummy_expr
202     {
203         (apply_target, dummy_args)
204     } else {
205         panic!("we only rewrite apply expressions");
206     };
207 
208     assert_eq!(
209         apply_target.inst().operands_in.len(),
210         dummy_args.len(),
211         "number of arguments in instruction {} is incorrect\nexpected: {:?}",
212         apply_target.inst().name,
213         apply_target
214             .inst()
215             .operands_in
216             .iter()
217             .map(|operand| format!("{}: {}", operand.name, operand.kind.rust_type))
218             .collect::<Vec<_>>(),
219     );
220 
221     let mut args = Vec::new();
222     for (i, arg) in dummy_args.into_iter().enumerate() {
223         match arg {
224             DummyExpr::Var(var) => {
225                 let own_var = var_index(&var.name, symbol_table, input_vars, var_pool);
226                 let var = var_pool.get(own_var);
227                 assert!(
228                     var.is_input() || var.get_def(position).is_some(),
229                     "{:?} used as both input and def",
230                     var
231                 );
232                 args.push(Expr::Var(own_var));
233             }
234             DummyExpr::Literal(literal) => {
235                 assert!(!apply_target.inst().operands_in[i].is_value());
236                 args.push(Expr::Literal(literal));
237             }
238             DummyExpr::Constant(constant) => {
239                 let const_name = const_pool.insert(constant.0);
240                 // Here we abuse var_index by passing an empty, immediately-dropped vector to
241                 // `defined_vars`; the reason for this is that unlike the `Var` case above,
242                 // constants will create a variable that is not an input variable (it is tracked
243                 // instead by ConstPool).
244                 let const_var = var_index(&const_name, symbol_table, &mut vec![], var_pool);
245                 args.push(Expr::Var(const_var));
246             }
247             DummyExpr::Apply(..) => {
248                 panic!("Recursive apply is not allowed.");
249             }
250             DummyExpr::Block(_block) => {
251                 panic!("Blocks are not valid arguments.");
252             }
253         }
254     }
255 
256     Apply::new(apply_target, args)
257 }
258 
259 #[allow(clippy::too_many_arguments)]
rewrite_def_list( position: PatternPosition, dummy_defs: Vec<DummyDef>, symbol_table: &mut SymbolTable, input_vars: &mut Vec<VarIndex>, defined_vars: &mut Vec<VarIndex>, var_pool: &mut VarPool, def_pool: &mut DefPool, block_pool: &mut BlockPool, const_pool: &mut ConstPool, ) -> Vec<DefIndex>260 fn rewrite_def_list(
261     position: PatternPosition,
262     dummy_defs: Vec<DummyDef>,
263     symbol_table: &mut SymbolTable,
264     input_vars: &mut Vec<VarIndex>,
265     defined_vars: &mut Vec<VarIndex>,
266     var_pool: &mut VarPool,
267     def_pool: &mut DefPool,
268     block_pool: &mut BlockPool,
269     const_pool: &mut ConstPool,
270 ) -> Vec<DefIndex> {
271     let mut new_defs = Vec::new();
272     // Register variable names of new blocks first as a block name can be used to jump forward. Thus
273     // the name has to be registered first to avoid misinterpreting it as an input-var.
274     for dummy_def in dummy_defs.iter() {
275         if let DummyExpr::Block(ref var) = dummy_def.expr {
276             var_index(&var.name, symbol_table, defined_vars, var_pool);
277         }
278     }
279 
280     // Iterate over the definitions and blocks, to map variables names to inputs or outputs.
281     for dummy_def in dummy_defs {
282         let def_index = def_pool.next_index();
283 
284         let new_defined_vars = rewrite_defined_vars(
285             position,
286             &dummy_def,
287             def_index,
288             symbol_table,
289             defined_vars,
290             var_pool,
291         );
292         if let DummyExpr::Block(var) = dummy_def.expr {
293             let var_index = *symbol_table
294                 .get(&var.name)
295                 .or_else(|| {
296                     panic!(
297                         "Block {} was not registered during the first visit",
298                         var.name
299                     )
300                 })
301                 .unwrap();
302             var_pool.get_mut(var_index).set_def(position, def_index);
303             block_pool.create_block(var_index, def_index);
304         } else {
305             let new_apply = rewrite_expr(
306                 position,
307                 dummy_def.expr,
308                 symbol_table,
309                 input_vars,
310                 var_pool,
311                 const_pool,
312             );
313 
314             assert!(
315                 def_pool.next_index() == def_index,
316                 "shouldn't have created new defs in the meanwhile"
317             );
318             assert_eq!(
319                 new_apply.inst.value_results.len(),
320                 new_defined_vars.len(),
321                 "number of Var results in instruction is incorrect"
322             );
323 
324             new_defs.push(def_pool.create_inst(new_apply, new_defined_vars));
325         }
326     }
327     new_defs
328 }
329 
/// A group of related transformations.
pub(crate) struct TransformGroup {
    /// Name of the group; must be unique within a `TransformGroups` collection.
    pub name: &'static str,
    /// Documentation string attached to the group.
    pub doc: &'static str,
    /// Optional index of another group this one is chained with.
    pub chain_with: Option<TransformGroupIndex>,
    /// Name of the ISA the group belongs to, if any; affects the result of `rust_name`.
    pub isa_name: Option<&'static str>,
    /// Index of this group in the `TransformGroups` collection it was added to.
    pub id: TransformGroupIndex,

    /// Maps Instruction camel_case names to custom legalization functions names.
    pub custom_legalizes: HashMap<String, &'static str>,
    /// Legalization patterns belonging to this group.
    pub transforms: Vec<Transform>,
}
342 
343 impl TransformGroup {
rust_name(&self) -> String344     pub fn rust_name(&self) -> String {
345         match self.isa_name {
346             Some(_) => {
347                 // This is a function in the same module as the LEGALIZE_ACTIONS table referring to
348                 // it.
349                 self.name.to_string()
350             }
351             None => format!("crate::legalizer::{}", self.name),
352         }
353     }
354 }
355 
/// Index identifying a `TransformGroup` stored in a `TransformGroups` collection.
#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub(crate) struct TransformGroupIndex(u32);
// NOTE(review): presumably provides the entity-reference impls needed to use this type as a
// `PrimaryMap` key; confirm against `cranelift_entity`.
entity_impl!(TransformGroupIndex);
359 
/// Builder accumulating the contents of a `TransformGroup` before it is added to a
/// `TransformGroups` collection.
pub(crate) struct TransformGroupBuilder {
    /// Name of the group being built.
    name: &'static str,
    /// Documentation string of the group being built.
    doc: &'static str,
    /// Group to chain with, if one was set through `chain_with`.
    chain_with: Option<TransformGroupIndex>,
    /// ISA name, if one was set through `isa`.
    isa_name: Option<&'static str>,
    /// Maps Instruction camel_case names to custom legalization function names.
    pub custom_legalizes: HashMap<String, &'static str>,
    /// Legalization patterns added so far.
    pub transforms: Vec<Transform>,
}
368 
369 impl TransformGroupBuilder {
new(name: &'static str, doc: &'static str) -> Self370     pub fn new(name: &'static str, doc: &'static str) -> Self {
371         Self {
372             name,
373             doc,
374             chain_with: None,
375             isa_name: None,
376             custom_legalizes: HashMap::new(),
377             transforms: Vec::new(),
378         }
379     }
380 
chain_with(mut self, next_id: TransformGroupIndex) -> Self381     pub fn chain_with(mut self, next_id: TransformGroupIndex) -> Self {
382         assert!(self.chain_with.is_none());
383         self.chain_with = Some(next_id);
384         self
385     }
386 
isa(mut self, isa_name: &'static str) -> Self387     pub fn isa(mut self, isa_name: &'static str) -> Self {
388         assert!(self.isa_name.is_none());
389         self.isa_name = Some(isa_name);
390         self
391     }
392 
393     /// Add a custom legalization action for `inst`.
394     ///
395     /// The `func_name` parameter is the fully qualified name of a Rust function which takes the
396     /// same arguments as the `isa::Legalize` actions.
397     ///
398     /// The custom function will be called to legalize `inst` and any return value is ignored.
custom_legalize(&mut self, inst: &Instruction, func_name: &'static str)399     pub fn custom_legalize(&mut self, inst: &Instruction, func_name: &'static str) {
400         assert!(
401             self.custom_legalizes
402                 .insert(inst.camel_name.clone(), func_name)
403                 .is_none(),
404             "custom legalization action for {} inserted twice",
405             inst.name
406         );
407     }
408 
409     /// Add a legalization pattern to this group.
legalize(&mut self, src: DummyDef, dst: Vec<DummyDef>)410     pub fn legalize(&mut self, src: DummyDef, dst: Vec<DummyDef>) {
411         let transform = Transform::new(src, dst);
412         transform.verify_legalize();
413         self.transforms.push(transform);
414     }
415 
build_and_add_to(self, owner: &mut TransformGroups) -> TransformGroupIndex416     pub fn build_and_add_to(self, owner: &mut TransformGroups) -> TransformGroupIndex {
417         let next_id = owner.next_key();
418         owner.add(TransformGroup {
419             name: self.name,
420             doc: self.doc,
421             isa_name: self.isa_name,
422             id: next_id,
423             chain_with: self.chain_with,
424             custom_legalizes: self.custom_legalizes,
425             transforms: self.transforms,
426         })
427     }
428 }
429 
/// Collection of all the transform groups, addressed by `TransformGroupIndex`.
pub(crate) struct TransformGroups {
    /// Storage for the groups, keyed by their index.
    groups: PrimaryMap<TransformGroupIndex, TransformGroup>,
}
433 
434 impl TransformGroups {
new() -> Self435     pub fn new() -> Self {
436         Self {
437             groups: PrimaryMap::new(),
438         }
439     }
add(&mut self, new_group: TransformGroup) -> TransformGroupIndex440     pub fn add(&mut self, new_group: TransformGroup) -> TransformGroupIndex {
441         for group in self.groups.values() {
442             assert!(
443                 group.name != new_group.name,
444                 "trying to insert {} for the second time",
445                 new_group.name
446             );
447         }
448         self.groups.push(new_group)
449     }
get(&self, id: TransformGroupIndex) -> &TransformGroup450     pub fn get(&self, id: TransformGroupIndex) -> &TransformGroup {
451         &self.groups[id]
452     }
next_key(&self) -> TransformGroupIndex453     fn next_key(&self) -> TransformGroupIndex {
454         self.groups.next_key()
455     }
by_name(&self, name: &'static str) -> &TransformGroup456     pub fn by_name(&self, name: &'static str) -> &TransformGroup {
457         for group in self.groups.values() {
458             if group.name == name {
459                 return group;
460             }
461         }
462         panic!("transform group with name {} not found", name);
463     }
464 }
465 
/// Registering two custom legalizations for the same instruction must be rejected.
///
/// The expected panic message is pinned so that an unrelated panic in the setup code cannot
/// make this test pass by accident.
#[test]
#[should_panic(expected = "inserted twice")]
fn test_double_custom_legalization() {
    use crate::cdsl::formats::InstructionFormatBuilder;
    use crate::cdsl::instructions::{AllInstructions, InstructionBuilder, InstructionGroupBuilder};

    let nullary = InstructionFormatBuilder::new("nullary").build();

    let mut dummy_all = AllInstructions::new();
    let mut inst_group = InstructionGroupBuilder::new(&mut dummy_all);
    inst_group.push(InstructionBuilder::new("dummy", "doc", &nullary));

    let inst_group = inst_group.build();
    let dummy_inst = inst_group.by_name("dummy");

    let mut transform_group = TransformGroupBuilder::new("test", "doc");
    transform_group.custom_legalize(&dummy_inst, "custom 1");
    // The second registration for the same instruction is the one expected to panic.
    transform_group.custom_legalize(&dummy_inst, "custom 2");
}
485