1 use crate::cdsl::ast::{
2 Apply, BlockPool, ConstPool, DefIndex, DefPool, DummyDef, DummyExpr, Expr, PatternPosition,
3 VarIndex, VarPool,
4 };
5 use crate::cdsl::instructions::Instruction;
6 use crate::cdsl::type_inference::{infer_transform, TypeEnvironment};
7 use crate::cdsl::typevar::TypeVar;
8
9 use cranelift_entity::{entity_impl, PrimaryMap};
10
11 use std::collections::{HashMap, HashSet};
12 use std::iter::FromIterator;
13
/// An instruction transformation consists of a source and destination pattern.
///
/// Patterns are expressed in *register transfer language* as tuples of Def or Expr nodes. A
/// pattern may optionally have a sequence of TypeConstraints that further limit the set of
/// cases in which it applies.
/// NOTE(review): this struct stores no constraints itself — confirm whether they are still
/// supported or whether that sentence is a leftover from an earlier design.
///
/// The source pattern can contain only a single instruction.
pub(crate) struct Transform {
    /// The single instruction def of the source pattern, as an index into `def_pool`.
    pub src: DefIndex,
    /// The instruction defs of the destination pattern, in pattern order, as indices into
    /// `def_pool` (blocks declared in the pattern are tracked in `block_pool` instead).
    pub dst: Vec<DefIndex>,
    /// This transform's own copies of every variable named in the source or destination.
    pub var_pool: VarPool,
    /// Storage for the rewritten source and destination defs.
    pub def_pool: DefPool,
    /// Blocks declared in the destination or source pattern.
    pub block_pool: BlockPool,
    /// Constants referenced as instruction arguments in the patterns.
    pub const_pool: ConstPool,
    /// Result of running type inference over the source and destination patterns.
    pub type_env: TypeEnvironment,
}
30
/// Maps a variable name used in a pattern to this transform's own variable in its `VarPool`.
type SymbolTable = HashMap<String, VarIndex>;
32
33 impl Transform {
new(src: DummyDef, dst: Vec<DummyDef>) -> Self34 fn new(src: DummyDef, dst: Vec<DummyDef>) -> Self {
35 let mut var_pool = VarPool::new();
36 let mut def_pool = DefPool::new();
37 let mut block_pool = BlockPool::new();
38 let mut const_pool = ConstPool::new();
39
40 let mut input_vars: Vec<VarIndex> = Vec::new();
41 let mut defined_vars: Vec<VarIndex> = Vec::new();
42
43 // Maps variable names to our own Var copies.
44 let mut symbol_table: SymbolTable = SymbolTable::new();
45
46 // Rewrite variables in src and dst using our own copies.
47 let src = rewrite_def_list(
48 PatternPosition::Source,
49 vec![src],
50 &mut symbol_table,
51 &mut input_vars,
52 &mut defined_vars,
53 &mut var_pool,
54 &mut def_pool,
55 &mut block_pool,
56 &mut const_pool,
57 )[0];
58
59 let num_src_inputs = input_vars.len();
60
61 let dst = rewrite_def_list(
62 PatternPosition::Destination,
63 dst,
64 &mut symbol_table,
65 &mut input_vars,
66 &mut defined_vars,
67 &mut var_pool,
68 &mut def_pool,
69 &mut block_pool,
70 &mut const_pool,
71 );
72
73 // Sanity checks.
74 for &var_index in &input_vars {
75 assert!(
76 var_pool.get(var_index).is_input(),
77 "'{:?}' used as both input and def",
78 var_pool.get(var_index)
79 );
80 }
81 assert!(
82 input_vars.len() == num_src_inputs,
83 "extra input vars in dst pattern: {:?}",
84 input_vars
85 .iter()
86 .map(|&i| var_pool.get(i))
87 .skip(num_src_inputs)
88 .collect::<Vec<_>>()
89 );
90
91 // Perform type inference and cleanup.
92 let type_env = infer_transform(src, &dst, &def_pool, &mut var_pool).unwrap();
93
94 // Sanity check: the set of inferred free type variables should be a subset of the type
95 // variables corresponding to Vars appearing in the source pattern.
96 {
97 let free_typevars: HashSet<TypeVar> =
98 HashSet::from_iter(type_env.free_typevars(&mut var_pool));
99 let src_tvs = HashSet::from_iter(
100 input_vars
101 .clone()
102 .iter()
103 .chain(
104 defined_vars
105 .iter()
106 .filter(|&&var_index| !var_pool.get(var_index).is_temp()),
107 )
108 .map(|&var_index| var_pool.get(var_index).get_typevar())
109 .filter(|maybe_var| maybe_var.is_some())
110 .map(|var| var.unwrap()),
111 );
112 if !free_typevars.is_subset(&src_tvs) {
113 let missing_tvs = (&free_typevars - &src_tvs)
114 .iter()
115 .map(|tv| tv.name.clone())
116 .collect::<Vec<_>>()
117 .join(", ");
118 panic!("Some free vars don't appear in src: {}", missing_tvs);
119 }
120 }
121
122 for &var_index in input_vars.iter().chain(defined_vars.iter()) {
123 let var = var_pool.get_mut(var_index);
124 let canon_tv = type_env.get_equivalent(&var.get_or_create_typevar());
125 var.set_typevar(canon_tv);
126 }
127
128 Self {
129 src,
130 dst,
131 var_pool,
132 def_pool,
133 block_pool,
134 const_pool,
135 type_env,
136 }
137 }
138
verify_legalize(&self)139 fn verify_legalize(&self) {
140 let def = self.def_pool.get(self.src);
141 for &var_index in def.defined_vars.iter() {
142 let defined_var = self.var_pool.get(var_index);
143 assert!(
144 defined_var.is_output(),
145 "{:?} not defined in the destination pattern",
146 defined_var
147 );
148 }
149 }
150 }
151
152 /// Inserts, if not present, a name in the `symbol_table`. Then returns its index in the variable
153 /// pool `var_pool`. If the variable was not present in the symbol table, then add it to the list of
154 /// `defined_vars`.
var_index( name: &str, symbol_table: &mut SymbolTable, defined_vars: &mut Vec<VarIndex>, var_pool: &mut VarPool, ) -> VarIndex155 fn var_index(
156 name: &str,
157 symbol_table: &mut SymbolTable,
158 defined_vars: &mut Vec<VarIndex>,
159 var_pool: &mut VarPool,
160 ) -> VarIndex {
161 let name = name.to_string();
162 match symbol_table.get(&name) {
163 Some(&existing_var) => existing_var,
164 None => {
165 // Materialize the variable.
166 let new_var = var_pool.create(name.clone());
167 symbol_table.insert(name, new_var);
168 defined_vars.push(new_var);
169 new_var
170 }
171 }
172 }
173
174 /// Given a list of symbols defined in a Def, rewrite them to local symbols. Yield the new locals.
rewrite_defined_vars( position: PatternPosition, dummy_def: &DummyDef, def_index: DefIndex, symbol_table: &mut SymbolTable, defined_vars: &mut Vec<VarIndex>, var_pool: &mut VarPool, ) -> Vec<VarIndex>175 fn rewrite_defined_vars(
176 position: PatternPosition,
177 dummy_def: &DummyDef,
178 def_index: DefIndex,
179 symbol_table: &mut SymbolTable,
180 defined_vars: &mut Vec<VarIndex>,
181 var_pool: &mut VarPool,
182 ) -> Vec<VarIndex> {
183 let mut new_defined_vars = Vec::new();
184 for var in &dummy_def.defined_vars {
185 let own_var = var_index(&var.name, symbol_table, defined_vars, var_pool);
186 var_pool.get_mut(own_var).set_def(position, def_index);
187 new_defined_vars.push(own_var);
188 }
189 new_defined_vars
190 }
191
192 /// Find all uses of variables in `expr` and replace them with our own local symbols.
rewrite_expr( position: PatternPosition, dummy_expr: DummyExpr, symbol_table: &mut SymbolTable, input_vars: &mut Vec<VarIndex>, var_pool: &mut VarPool, const_pool: &mut ConstPool, ) -> Apply193 fn rewrite_expr(
194 position: PatternPosition,
195 dummy_expr: DummyExpr,
196 symbol_table: &mut SymbolTable,
197 input_vars: &mut Vec<VarIndex>,
198 var_pool: &mut VarPool,
199 const_pool: &mut ConstPool,
200 ) -> Apply {
201 let (apply_target, dummy_args) = if let DummyExpr::Apply(apply_target, dummy_args) = dummy_expr
202 {
203 (apply_target, dummy_args)
204 } else {
205 panic!("we only rewrite apply expressions");
206 };
207
208 assert_eq!(
209 apply_target.inst().operands_in.len(),
210 dummy_args.len(),
211 "number of arguments in instruction {} is incorrect\nexpected: {:?}",
212 apply_target.inst().name,
213 apply_target
214 .inst()
215 .operands_in
216 .iter()
217 .map(|operand| format!("{}: {}", operand.name, operand.kind.rust_type))
218 .collect::<Vec<_>>(),
219 );
220
221 let mut args = Vec::new();
222 for (i, arg) in dummy_args.into_iter().enumerate() {
223 match arg {
224 DummyExpr::Var(var) => {
225 let own_var = var_index(&var.name, symbol_table, input_vars, var_pool);
226 let var = var_pool.get(own_var);
227 assert!(
228 var.is_input() || var.get_def(position).is_some(),
229 "{:?} used as both input and def",
230 var
231 );
232 args.push(Expr::Var(own_var));
233 }
234 DummyExpr::Literal(literal) => {
235 assert!(!apply_target.inst().operands_in[i].is_value());
236 args.push(Expr::Literal(literal));
237 }
238 DummyExpr::Constant(constant) => {
239 let const_name = const_pool.insert(constant.0);
240 // Here we abuse var_index by passing an empty, immediately-dropped vector to
241 // `defined_vars`; the reason for this is that unlike the `Var` case above,
242 // constants will create a variable that is not an input variable (it is tracked
243 // instead by ConstPool).
244 let const_var = var_index(&const_name, symbol_table, &mut vec![], var_pool);
245 args.push(Expr::Var(const_var));
246 }
247 DummyExpr::Apply(..) => {
248 panic!("Recursive apply is not allowed.");
249 }
250 DummyExpr::Block(_block) => {
251 panic!("Blocks are not valid arguments.");
252 }
253 }
254 }
255
256 Apply::new(apply_target, args)
257 }
258
259 #[allow(clippy::too_many_arguments)]
rewrite_def_list( position: PatternPosition, dummy_defs: Vec<DummyDef>, symbol_table: &mut SymbolTable, input_vars: &mut Vec<VarIndex>, defined_vars: &mut Vec<VarIndex>, var_pool: &mut VarPool, def_pool: &mut DefPool, block_pool: &mut BlockPool, const_pool: &mut ConstPool, ) -> Vec<DefIndex>260 fn rewrite_def_list(
261 position: PatternPosition,
262 dummy_defs: Vec<DummyDef>,
263 symbol_table: &mut SymbolTable,
264 input_vars: &mut Vec<VarIndex>,
265 defined_vars: &mut Vec<VarIndex>,
266 var_pool: &mut VarPool,
267 def_pool: &mut DefPool,
268 block_pool: &mut BlockPool,
269 const_pool: &mut ConstPool,
270 ) -> Vec<DefIndex> {
271 let mut new_defs = Vec::new();
272 // Register variable names of new blocks first as a block name can be used to jump forward. Thus
273 // the name has to be registered first to avoid misinterpreting it as an input-var.
274 for dummy_def in dummy_defs.iter() {
275 if let DummyExpr::Block(ref var) = dummy_def.expr {
276 var_index(&var.name, symbol_table, defined_vars, var_pool);
277 }
278 }
279
280 // Iterate over the definitions and blocks, to map variables names to inputs or outputs.
281 for dummy_def in dummy_defs {
282 let def_index = def_pool.next_index();
283
284 let new_defined_vars = rewrite_defined_vars(
285 position,
286 &dummy_def,
287 def_index,
288 symbol_table,
289 defined_vars,
290 var_pool,
291 );
292 if let DummyExpr::Block(var) = dummy_def.expr {
293 let var_index = *symbol_table
294 .get(&var.name)
295 .or_else(|| {
296 panic!(
297 "Block {} was not registered during the first visit",
298 var.name
299 )
300 })
301 .unwrap();
302 var_pool.get_mut(var_index).set_def(position, def_index);
303 block_pool.create_block(var_index, def_index);
304 } else {
305 let new_apply = rewrite_expr(
306 position,
307 dummy_def.expr,
308 symbol_table,
309 input_vars,
310 var_pool,
311 const_pool,
312 );
313
314 assert!(
315 def_pool.next_index() == def_index,
316 "shouldn't have created new defs in the meanwhile"
317 );
318 assert_eq!(
319 new_apply.inst.value_results.len(),
320 new_defined_vars.len(),
321 "number of Var results in instruction is incorrect"
322 );
323
324 new_defs.push(def_pool.create_inst(new_apply, new_defined_vars));
325 }
326 }
327 new_defs
328 }
329
/// A group of related transformations.
pub(crate) struct TransformGroup {
    /// Name of the group; also the base of the Rust path returned by `rust_name`.
    pub name: &'static str,
    /// Documentation string for the group.
    pub doc: &'static str,
    /// Index of another group this one is chained with, if any.
    pub chain_with: Option<TransformGroupIndex>,
    /// Name of the ISA this group is specific to, if any (affects `rust_name`).
    pub isa_name: Option<&'static str>,
    /// This group's own index in the `TransformGroups` it was added to.
    pub id: TransformGroupIndex,

    /// Maps instruction camel-case names (`Instruction::camel_name`) to custom legalization
    /// function names.
    pub custom_legalizes: HashMap<String, &'static str>,
    /// The legalization patterns belonging to this group.
    pub transforms: Vec<Transform>,
}
342
343 impl TransformGroup {
rust_name(&self) -> String344 pub fn rust_name(&self) -> String {
345 match self.isa_name {
346 Some(_) => {
347 // This is a function in the same module as the LEGALIZE_ACTIONS table referring to
348 // it.
349 self.name.to_string()
350 }
351 None => format!("crate::legalizer::{}", self.name),
352 }
353 }
354 }
355
/// Index of a `TransformGroup` inside a `TransformGroups` collection.
#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub(crate) struct TransformGroupIndex(u32);
entity_impl!(TransformGroupIndex);
359
/// Accumulates the contents of a `TransformGroup` before it is added to a `TransformGroups`.
pub(crate) struct TransformGroupBuilder {
    name: &'static str,
    doc: &'static str,
    chain_with: Option<TransformGroupIndex>,
    isa_name: Option<&'static str>,
    /// Custom legalization function names, keyed by instruction camel-case name.
    pub custom_legalizes: HashMap<String, &'static str>,
    /// Legalization patterns added so far.
    pub transforms: Vec<Transform>,
}
368
369 impl TransformGroupBuilder {
new(name: &'static str, doc: &'static str) -> Self370 pub fn new(name: &'static str, doc: &'static str) -> Self {
371 Self {
372 name,
373 doc,
374 chain_with: None,
375 isa_name: None,
376 custom_legalizes: HashMap::new(),
377 transforms: Vec::new(),
378 }
379 }
380
chain_with(mut self, next_id: TransformGroupIndex) -> Self381 pub fn chain_with(mut self, next_id: TransformGroupIndex) -> Self {
382 assert!(self.chain_with.is_none());
383 self.chain_with = Some(next_id);
384 self
385 }
386
isa(mut self, isa_name: &'static str) -> Self387 pub fn isa(mut self, isa_name: &'static str) -> Self {
388 assert!(self.isa_name.is_none());
389 self.isa_name = Some(isa_name);
390 self
391 }
392
393 /// Add a custom legalization action for `inst`.
394 ///
395 /// The `func_name` parameter is the fully qualified name of a Rust function which takes the
396 /// same arguments as the `isa::Legalize` actions.
397 ///
398 /// The custom function will be called to legalize `inst` and any return value is ignored.
custom_legalize(&mut self, inst: &Instruction, func_name: &'static str)399 pub fn custom_legalize(&mut self, inst: &Instruction, func_name: &'static str) {
400 assert!(
401 self.custom_legalizes
402 .insert(inst.camel_name.clone(), func_name)
403 .is_none(),
404 "custom legalization action for {} inserted twice",
405 inst.name
406 );
407 }
408
409 /// Add a legalization pattern to this group.
legalize(&mut self, src: DummyDef, dst: Vec<DummyDef>)410 pub fn legalize(&mut self, src: DummyDef, dst: Vec<DummyDef>) {
411 let transform = Transform::new(src, dst);
412 transform.verify_legalize();
413 self.transforms.push(transform);
414 }
415
build_and_add_to(self, owner: &mut TransformGroups) -> TransformGroupIndex416 pub fn build_and_add_to(self, owner: &mut TransformGroups) -> TransformGroupIndex {
417 let next_id = owner.next_key();
418 owner.add(TransformGroup {
419 name: self.name,
420 doc: self.doc,
421 isa_name: self.isa_name,
422 id: next_id,
423 chain_with: self.chain_with,
424 custom_legalizes: self.custom_legalizes,
425 transforms: self.transforms,
426 })
427 }
428 }
429
/// All transform groups, indexed by `TransformGroupIndex`; `add` rejects duplicate names.
pub(crate) struct TransformGroups {
    groups: PrimaryMap<TransformGroupIndex, TransformGroup>,
}
433
434 impl TransformGroups {
new() -> Self435 pub fn new() -> Self {
436 Self {
437 groups: PrimaryMap::new(),
438 }
439 }
add(&mut self, new_group: TransformGroup) -> TransformGroupIndex440 pub fn add(&mut self, new_group: TransformGroup) -> TransformGroupIndex {
441 for group in self.groups.values() {
442 assert!(
443 group.name != new_group.name,
444 "trying to insert {} for the second time",
445 new_group.name
446 );
447 }
448 self.groups.push(new_group)
449 }
get(&self, id: TransformGroupIndex) -> &TransformGroup450 pub fn get(&self, id: TransformGroupIndex) -> &TransformGroup {
451 &self.groups[id]
452 }
next_key(&self) -> TransformGroupIndex453 fn next_key(&self) -> TransformGroupIndex {
454 self.groups.next_key()
455 }
by_name(&self, name: &'static str) -> &TransformGroup456 pub fn by_name(&self, name: &'static str) -> &TransformGroup {
457 for group in self.groups.values() {
458 if group.name == name {
459 return group;
460 }
461 }
462 panic!("transform group with name {} not found", name);
463 }
464 }
465
/// Registering a second custom legalization for the same instruction must panic.
#[test]
// Pin the message so the test cannot pass because of an unrelated panic during setup.
#[should_panic(expected = "inserted twice")]
fn test_double_custom_legalization() {
    use crate::cdsl::formats::InstructionFormatBuilder;
    use crate::cdsl::instructions::{AllInstructions, InstructionBuilder, InstructionGroupBuilder};

    let nullary = InstructionFormatBuilder::new("nullary").build();

    let mut dummy_all = AllInstructions::new();
    let mut inst_group = InstructionGroupBuilder::new(&mut dummy_all);
    inst_group.push(InstructionBuilder::new("dummy", "doc", &nullary));

    let inst_group = inst_group.build();
    let dummy_inst = inst_group.by_name("dummy");

    let mut transform_group = TransformGroupBuilder::new("test", "doc");
    transform_group.custom_legalize(&dummy_inst, "custom 1");
    transform_group.custom_legalize(&dummy_inst, "custom 2");
}
485