1 use crate::cdsl::ast::{Def, DefPool, VarPool};
2 use crate::cdsl::formats::FormatRegistry;
3 use crate::cdsl::isa::TargetIsa;
4 use crate::cdsl::type_inference::Constraint;
5 use crate::cdsl::typevar::{TypeSet, TypeVar};
6 use crate::cdsl::xform::{Transform, TransformGroup, TransformGroups};
7 
8 use crate::error;
9 use crate::gen_inst::gen_typesets_table;
10 use crate::srcgen::Formatter;
11 use crate::unique_table::UniqueTable;
12 
13 use std::collections::{HashMap, HashSet};
14 use std::iter::FromIterator;
15 
16 /// Given a `Def` node, emit code that extracts all the instruction fields from
17 /// `pos.func.dfg[iref]`.
18 ///
19 /// Create local variables named after the `Var` instances in `node`.
20 ///
21 /// Also create a local variable named `predicate` with the value of the evaluated instruction
22 /// predicate, or `true` if the node has no predicate.
unwrap_inst( transform: &Transform, format_registry: &FormatRegistry, fmt: &mut Formatter, ) -> bool23 fn unwrap_inst(
24     transform: &Transform,
25     format_registry: &FormatRegistry,
26     fmt: &mut Formatter,
27 ) -> bool {
28     let var_pool = &transform.var_pool;
29     let def_pool = &transform.def_pool;
30 
31     let def = def_pool.get(transform.src);
32     let apply = &def.apply;
33     let inst = &apply.inst;
34     let iform = format_registry.get(inst.format);
35 
36     fmt.comment(format!(
37         "Unwrap {}",
38         def.to_comment_string(&transform.var_pool)
39     ));
40 
41     // Extract the Var arguments.
42     let arg_names = apply
43         .args
44         .iter()
45         .map(|arg| match arg.maybe_var() {
46             Some(var_index) => var_pool.get(var_index).name,
47             None => "_",
48         })
49         .collect::<Vec<_>>()
50         .join(", ");
51 
52     fmtln!(
53         fmt,
54         "let ({}, predicate) = if let crate::ir::InstructionData::{} {{",
55         arg_names,
56         iform.name
57     );
58     fmt.indent(|fmt| {
59         // Fields are encoded directly.
60         for field in &iform.imm_fields {
61             fmtln!(fmt, "{},", field.member);
62         }
63 
64         if iform.has_value_list || iform.num_value_operands > 1 {
65             fmt.line("ref args,");
66         } else if iform.num_value_operands == 1 {
67             fmt.line("arg,");
68         }
69 
70         fmt.line("..");
71         fmt.outdented_line("} = pos.func.dfg[inst] {");
72         fmt.line("let func = &pos.func;");
73 
74         if iform.has_value_list {
75             fmt.line("let args = args.as_slice(&func.dfg.value_lists);");
76         } else if iform.num_value_operands == 1 {
77             fmt.line("let args = [arg];")
78         }
79 
80         // Generate the values for the tuple.
81         fmt.line("(");
82         fmt.indent(|fmt| {
83             for (op_num, op) in inst.operands_in.iter().enumerate() {
84                 if op.is_immediate() {
85                     let n = inst.imm_opnums.iter().position(|&i| i == op_num).unwrap();
86                     fmtln!(fmt, "{},", iform.imm_fields[n].member);
87                 } else if op.is_value() {
88                     let n = inst.value_opnums.iter().position(|&i| i == op_num).unwrap();
89                     fmtln!(fmt, "func.dfg.resolve_aliases(args[{}]),", n);
90                 } else if op.is_varargs() {
91                     let n = inst.imm_opnums.iter().chain(inst.value_opnums.iter()).max().map(|n| n + 1).unwrap_or(0);
92                     // We need to create a `Vec` here, as using a slice would result in a borrowck
93                     // error later on.
94                     fmtln!(fmt, "\
95                         args.iter().skip({}).map(|&arg| func.dfg.resolve_aliases(arg)).collect::<Vec<_>>(),\
96                     ", n);
97                 }
98             }
99 
100             // Evaluate the instruction predicate if any.
101             fmt.multi_line(
102                 &apply
103                     .inst_predicate_with_ctrl_typevar(format_registry, var_pool)
104                     .rust_predicate(),
105             );
106         });
107         fmt.line(")");
108 
109         fmt.outdented_line("} else {");
110         fmt.line(r#"unreachable!("bad instruction format")"#);
111     });
112     fmtln!(fmt, "};");
113 
114     assert_eq!(inst.operands_in.len(), apply.args.len());
115     for (i, op) in inst.operands_in.iter().enumerate() {
116         if op.is_varargs() {
117             let name = var_pool
118                 .get(apply.args[i].maybe_var().expect("vararg without name"))
119                 .name;
120 
121             // Above name is set to an `Vec` representing the varargs. However it is expected to be
122             // `&[Value]` below, so we borrow it.
123             fmtln!(fmt, "let {} = &{};", name, name);
124         }
125     }
126 
127     for &op_num in &inst.value_opnums {
128         let arg = &apply.args[op_num];
129         if let Some(var_index) = arg.maybe_var() {
130             let var = var_pool.get(var_index);
131             if var.has_free_typevar() {
132                 fmtln!(
133                     fmt,
134                     "let typeof_{} = pos.func.dfg.value_type({});",
135                     var.name,
136                     var.name
137                 );
138             }
139         }
140     }
141 
142     // If the definition creates results, detach the values and place them in locals.
143     let mut replace_inst = false;
144     if def.defined_vars.len() > 0 {
145         if def.defined_vars
146             == def_pool
147                 .get(var_pool.get(def.defined_vars[0]).dst_def.unwrap())
148                 .defined_vars
149         {
150             // Special case: The instruction replacing node defines the exact same values.
151             fmt.comment(format!(
152                 "Results handled by {}.",
153                 def_pool
154                     .get(var_pool.get(def.defined_vars[0]).dst_def.unwrap())
155                     .to_comment_string(var_pool)
156             ));
157 
158             fmt.line("let r = pos.func.dfg.inst_results(inst);");
159             for (i, &var_index) in def.defined_vars.iter().enumerate() {
160                 let var = var_pool.get(var_index);
161                 fmtln!(fmt, "let {} = &r[{}];", var.name, i);
162                 fmtln!(
163                     fmt,
164                     "let typeof_{} = pos.func.dfg.value_type(*{});",
165                     var.name,
166                     var.name
167                 );
168             }
169 
170             replace_inst = true;
171         } else {
172             // Boring case: Detach the result values, capture them in locals.
173             for &var_index in &def.defined_vars {
174                 fmtln!(fmt, "let {};", var_pool.get(var_index).name);
175             }
176 
177             fmt.line("{");
178             fmt.indent(|fmt| {
179                 fmt.line("let r = pos.func.dfg.inst_results(inst);");
180                 for i in 0..def.defined_vars.len() {
181                     let var = var_pool.get(def.defined_vars[i]);
182                     fmtln!(fmt, "{} = r[{}];", var.name, i);
183                 }
184             });
185             fmt.line("}");
186 
187             for &var_index in &def.defined_vars {
188                 let var = var_pool.get(var_index);
189                 if var.has_free_typevar() {
190                     fmtln!(
191                         fmt,
192                         "let typeof_{} = pos.func.dfg.value_type({});",
193                         var.name,
194                         var.name
195                     );
196                 }
197             }
198         }
199     }
200     replace_inst
201 }
202 
build_derived_expr(tv: &TypeVar) -> String203 fn build_derived_expr(tv: &TypeVar) -> String {
204     let base = match &tv.base {
205         Some(base) => base,
206         None => {
207             assert!(tv.name.starts_with("typeof_"));
208             return format!("Some({})", tv.name);
209         }
210     };
211     let base_expr = build_derived_expr(&base.type_var);
212     format!(
213         "{}.map(|t: crate::ir::Type| t.{}())",
214         base_expr,
215         base.derived_func.name()
216     )
217 }
218 
219 /// Emit rust code for the given check.
220 ///
221 /// The emitted code is a statement redefining the `predicate` variable like this:
222 ///     let predicate = predicate && ...
emit_runtime_typecheck<'a, 'b>( constraint: &'a Constraint, type_sets: &mut UniqueTable<'a, TypeSet>, fmt: &mut Formatter, )223 fn emit_runtime_typecheck<'a, 'b>(
224     constraint: &'a Constraint,
225     type_sets: &mut UniqueTable<'a, TypeSet>,
226     fmt: &mut Formatter,
227 ) {
228     match constraint {
229         Constraint::InTypeset(tv, ts) => {
230             let ts_index = type_sets.add(&ts);
231             fmt.comment(format!(
232                 "{} must belong to {:?}",
233                 tv.name,
234                 type_sets.get(ts_index)
235             ));
236             fmtln!(
237                 fmt,
238                 "let predicate = predicate && TYPE_SETS[{}].contains({});",
239                 ts_index,
240                 tv.name
241             );
242         }
243         Constraint::Eq(tv1, tv2) => {
244             fmtln!(
245                 fmt,
246                 "let predicate = predicate && match ({}, {}) {{",
247                 build_derived_expr(tv1),
248                 build_derived_expr(tv2)
249             );
250             fmt.indent(|fmt| {
251                 fmt.line("(Some(a), Some(b)) => a == b,");
252                 fmt.comment("On overflow, constraint doesn\'t apply");
253                 fmt.line("_ => false,");
254             });
255             fmtln!(fmt, "};");
256         }
257         Constraint::WiderOrEq(tv1, tv2) => {
258             fmtln!(
259                 fmt,
260                 "let predicate = predicate && match ({}, {}) {{",
261                 build_derived_expr(tv1),
262                 build_derived_expr(tv2)
263             );
264             fmt.indent(|fmt| {
265                 fmt.line("(Some(a), Some(b)) => a.wider_or_equal(b),");
266                 fmt.comment("On overflow, constraint doesn\'t apply");
267                 fmt.line("_ => false,");
268             });
269             fmtln!(fmt, "};");
270         }
271     }
272 }
273 
274 /// Determine if `node` represents one of the value splitting instructions: `isplit` or `vsplit.
275 /// These instructions are lowered specially by the `legalize::split` module.
is_value_split(def: &Def) -> bool276 fn is_value_split(def: &Def) -> bool {
277     let name = &def.apply.inst.name;
278     name == "isplit" || name == "vsplit"
279 }
280 
emit_dst_inst(def: &Def, def_pool: &DefPool, var_pool: &VarPool, fmt: &mut Formatter)281 fn emit_dst_inst(def: &Def, def_pool: &DefPool, var_pool: &VarPool, fmt: &mut Formatter) {
282     let defined_vars = {
283         let vars = def
284             .defined_vars
285             .iter()
286             .map(|&var_index| var_pool.get(var_index).name)
287             .collect::<Vec<_>>();
288         if vars.len() == 1 {
289             vars[0].to_string()
290         } else {
291             format!("({})", vars.join(", "))
292         }
293     };
294 
295     if is_value_split(def) {
296         // Split instructions are not emitted with the builder, but by calling special functions in
297         // the `legalizer::split` module. These functions will eliminate concat-split patterns.
298         fmt.line("let curpos = pos.position();");
299         fmt.line("let srcloc = pos.srcloc();");
300         fmtln!(
301             fmt,
302             "let {} = split::{}(pos.func, cfg, curpos, srcloc, {});",
303             defined_vars,
304             def.apply.inst.snake_name(),
305             def.apply.args[0].to_rust_code(var_pool)
306         );
307         return;
308     }
309 
310     if def.defined_vars.is_empty() {
311         // This node doesn't define any values, so just insert the new instruction.
312         fmtln!(
313             fmt,
314             "pos.ins().{};",
315             def.apply.rust_builder(&def.defined_vars, var_pool)
316         );
317         return;
318     }
319 
320     if let Some(src_def0) = var_pool.get(def.defined_vars[0]).src_def {
321         if def.defined_vars == def_pool.get(src_def0).defined_vars {
322             // The replacement instruction defines the exact same values as the source pattern.
323             // Unwrapping would have left the results intact.  Replace the whole instruction.
324             fmtln!(
325                 fmt,
326                 "let {} = pos.func.dfg.replace(inst).{};",
327                 defined_vars,
328                 def.apply.rust_builder(&def.defined_vars, var_pool)
329             );
330 
331             // We need to bump the cursor so following instructions are inserted *after* the
332             // replaced instruction.
333             fmt.line("if pos.current_inst() == Some(inst) {");
334             fmt.indent(|fmt| {
335                 fmt.line("pos.next_inst();");
336             });
337             fmt.line("}");
338             return;
339         }
340     }
341 
342     // Insert a new instruction.
343     let mut builder = format!("let {} = pos.ins()", defined_vars);
344 
345     if def.defined_vars.len() == 1 && var_pool.get(def.defined_vars[0]).is_output() {
346         // Reuse the single source result value.
347         builder = format!(
348             "{}.with_result({})",
349             builder,
350             var_pool.get(def.defined_vars[0]).to_rust_code()
351         );
352     } else if def
353         .defined_vars
354         .iter()
355         .any(|&var_index| var_pool.get(var_index).is_output())
356     {
357         // There are more than one output values that can be reused.
358         let array = def
359             .defined_vars
360             .iter()
361             .map(|&var_index| {
362                 let var = var_pool.get(var_index);
363                 if var.is_output() {
364                     format!("Some({})", var.name)
365                 } else {
366                     "None".into()
367                 }
368             })
369             .collect::<Vec<_>>()
370             .join(", ");
371         builder = format!("{}.with_results([{}])", builder, array);
372     }
373 
374     fmtln!(
375         fmt,
376         "{}.{};",
377         builder,
378         def.apply.rust_builder(&def.defined_vars, var_pool)
379     );
380 }
381 
382 /// Emit code for `transform`, assuming that the opcode of transform's root instruction
383 /// has already been matched.
384 ///
385 /// `inst: Inst` is the variable to be replaced. It is pointed to by `pos: Cursor`.
386 /// `dfg: DataFlowGraph` is available and mutable.
gen_transform<'a>( transform: &'a Transform, format_registry: &FormatRegistry, type_sets: &mut UniqueTable<'a, TypeSet>, fmt: &mut Formatter, )387 fn gen_transform<'a>(
388     transform: &'a Transform,
389     format_registry: &FormatRegistry,
390     type_sets: &mut UniqueTable<'a, TypeSet>,
391     fmt: &mut Formatter,
392 ) {
393     // Unwrap the source instruction, create local variables for the input variables.
394     let replace_inst = unwrap_inst(&transform, format_registry, fmt);
395 
396     // Emit any runtime checks; these will rebind `predicate` emitted by unwrap_inst().
397     for constraint in &transform.type_env.constraints {
398         emit_runtime_typecheck(constraint, type_sets, fmt);
399     }
400 
401     // Guard the actual expansion by `predicate`.
402     fmt.line("if predicate {");
403     fmt.indent(|fmt| {
404         // If we are adding some blocks, we need to recall the original block, such that we can
405         // recompute it.
406         if !transform.block_pool.is_empty() {
407             fmt.line("let orig_ebb = pos.current_ebb().unwrap();");
408         }
409 
410         // If we're going to delete `inst`, we need to detach its results first so they can be
411         // reattached during pattern expansion.
412         if !replace_inst {
413             fmt.line("pos.func.dfg.clear_results(inst);");
414         }
415 
416         // Emit new block creation.
417         for block in &transform.block_pool {
418             let var = transform.var_pool.get(block.name);
419             fmtln!(fmt, "let {} = pos.func.dfg.make_ebb();", var.name);
420         }
421 
422         // Emit the destination pattern.
423         for &def_index in &transform.dst {
424             if let Some(block) = transform.block_pool.get(def_index) {
425                 let var = transform.var_pool.get(block.name);
426                 fmtln!(fmt, "pos.insert_ebb({});", var.name);
427             }
428             emit_dst_inst(
429                 transform.def_pool.get(def_index),
430                 &transform.def_pool,
431                 &transform.var_pool,
432                 fmt,
433             );
434         }
435 
436         // Insert a new block after the last instruction, if needed.
437         let def_next_index = transform.def_pool.next_index();
438         if let Some(block) = transform.block_pool.get(def_next_index) {
439             let var = transform.var_pool.get(block.name);
440             fmtln!(fmt, "pos.insert_ebb({});", var.name);
441         }
442 
443         // Delete the original instruction if we didn't have an opportunity to replace it.
444         if !replace_inst {
445             fmt.line("let removed = pos.remove_inst();");
446             fmt.line("debug_assert_eq!(removed, inst);");
447         }
448 
449         if transform.block_pool.is_empty() {
450             if transform.def_pool.get(transform.src).apply.inst.is_branch {
451                 // A branch might have been legalized into multiple branches, so we need to recompute
452                 // the cfg.
453                 fmt.line("cfg.recompute_ebb(pos.func, pos.current_ebb().unwrap());");
454             }
455         } else {
456             // Update CFG for the new blocks.
457             fmt.line("cfg.recompute_ebb(pos.func, orig_ebb);");
458             for block in &transform.block_pool {
459                 let var = transform.var_pool.get(block.name);
460                 fmtln!(fmt, "cfg.recompute_ebb(pos.func, {});", var.name);
461             }
462         }
463 
464         fmt.line("return true;");
465     });
466     fmt.line("}");
467 }
468 
gen_transform_group<'a>( group: &'a TransformGroup, format_registry: &FormatRegistry, transform_groups: &TransformGroups, type_sets: &mut UniqueTable<'a, TypeSet>, fmt: &mut Formatter, )469 fn gen_transform_group<'a>(
470     group: &'a TransformGroup,
471     format_registry: &FormatRegistry,
472     transform_groups: &TransformGroups,
473     type_sets: &mut UniqueTable<'a, TypeSet>,
474     fmt: &mut Formatter,
475 ) {
476     fmt.doc_comment(group.doc);
477     fmt.line("#[allow(unused_variables,unused_assignments,non_snake_case)]");
478 
479     // Function arguments.
480     fmtln!(fmt, "pub fn {}(", group.name);
481     fmt.indent(|fmt| {
482         fmt.line("inst: crate::ir::Inst,");
483         fmt.line("func: &mut crate::ir::Function,");
484         fmt.line("cfg: &mut crate::flowgraph::ControlFlowGraph,");
485         fmt.line("isa: &dyn crate::isa::TargetIsa,");
486     });
487     fmtln!(fmt, ") -> bool {");
488 
489     // Function body.
490     fmt.indent(|fmt| {
491         fmt.line("use crate::ir::InstBuilder;");
492         fmt.line("use crate::cursor::{Cursor, FuncCursor};");
493         fmt.line("let mut pos = FuncCursor::new(func).at_inst(inst);");
494         fmt.line("pos.use_srcloc(inst);");
495 
496         // Group the transforms by opcode so we can generate a big switch.
497         // Preserve ordering.
498         let mut inst_to_transforms = HashMap::new();
499         for transform in &group.transforms {
500             let def_index = transform.src;
501             let inst = &transform.def_pool.get(def_index).apply.inst;
502             inst_to_transforms
503                 .entry(inst.camel_name.clone())
504                 .or_insert(Vec::new())
505                 .push(transform);
506         }
507 
508         let mut sorted_inst_names = Vec::from_iter(inst_to_transforms.keys());
509         sorted_inst_names.sort();
510 
511         fmt.line("{");
512         fmt.indent(|fmt| {
513             fmt.line("match pos.func.dfg[inst].opcode() {");
514             fmt.indent(|fmt| {
515                 for camel_name in sorted_inst_names {
516                     fmtln!(fmt, "ir::Opcode::{} => {{", camel_name);
517                     fmt.indent(|fmt| {
518                         for transform in inst_to_transforms.get(camel_name).unwrap() {
519                             gen_transform(transform, format_registry, type_sets, fmt);
520                         }
521                     });
522                     fmtln!(fmt, "}");
523                     fmt.empty_line();
524                 }
525 
526                 // Emit the custom transforms. The Rust compiler will complain about any overlap with
527                 // the normal transforms.
528                 let mut sorted_custom_legalizes = Vec::from_iter(&group.custom_legalizes);
529                 sorted_custom_legalizes.sort();
530                 for (inst_camel_name, func_name) in sorted_custom_legalizes {
531                     fmtln!(fmt, "ir::Opcode::{} => {{", inst_camel_name);
532                     fmt.indent(|fmt| {
533                         fmtln!(fmt, "{}(inst, pos.func, cfg, isa);", func_name);
534                         fmt.line("return true;");
535                     });
536                     fmtln!(fmt, "}");
537                     fmt.empty_line();
538                 }
539 
540                 // We'll assume there are uncovered opcodes.
541                 fmt.line("_ => {},");
542             });
543             fmt.line("}");
544         });
545         fmt.line("}");
546 
547         // If we fall through, nothing was expanded; call the chain if any.
548         match &group.chain_with {
549             Some(group_id) => fmtln!(
550                 fmt,
551                 "{}(inst, pos.func, cfg, isa)",
552                 transform_groups.get(*group_id).rust_name()
553             ),
554             None => fmt.line("false"),
555         };
556     });
557     fmtln!(fmt, "}");
558     fmt.empty_line();
559 }
560 
561 /// Generate legalization functions for `isa` and add any shared `TransformGroup`s
562 /// encountered to `shared_groups`.
563 ///
564 /// Generate `TYPE_SETS` and `LEGALIZE_ACTIONS` tables.
gen_isa( isa: &TargetIsa, format_registry: &FormatRegistry, transform_groups: &TransformGroups, shared_group_names: &mut HashSet<&'static str>, fmt: &mut Formatter, )565 fn gen_isa(
566     isa: &TargetIsa,
567     format_registry: &FormatRegistry,
568     transform_groups: &TransformGroups,
569     shared_group_names: &mut HashSet<&'static str>,
570     fmt: &mut Formatter,
571 ) {
572     let mut type_sets = UniqueTable::new();
573     for group_index in isa.transitive_transform_groups(transform_groups) {
574         let group = transform_groups.get(group_index);
575         match group.isa_name {
576             Some(isa_name) => {
577                 assert!(
578                     isa_name == isa.name,
579                     "ISA-specific legalizations must be used by the same ISA"
580                 );
581                 gen_transform_group(
582                     group,
583                     format_registry,
584                     transform_groups,
585                     &mut type_sets,
586                     fmt,
587                 );
588             }
589             None => {
590                 shared_group_names.insert(group.name);
591             }
592         }
593     }
594 
595     gen_typesets_table(&type_sets, fmt);
596 
597     let direct_groups = isa.direct_transform_groups();
598     fmtln!(
599         fmt,
600         "pub static LEGALIZE_ACTIONS: [isa::Legalize; {}] = [",
601         direct_groups.len()
602     );
603     fmt.indent(|fmt| {
604         for &group_index in direct_groups {
605             fmtln!(fmt, "{},", transform_groups.get(group_index).rust_name());
606         }
607     });
608     fmtln!(fmt, "];");
609 }
610 
611 /// Generate the legalizer files.
generate( isas: &Vec<TargetIsa>, format_registry: &FormatRegistry, transform_groups: &TransformGroups, filename_prefix: &str, out_dir: &str, ) -> Result<(), error::Error>612 pub(crate) fn generate(
613     isas: &Vec<TargetIsa>,
614     format_registry: &FormatRegistry,
615     transform_groups: &TransformGroups,
616     filename_prefix: &str,
617     out_dir: &str,
618 ) -> Result<(), error::Error> {
619     let mut shared_group_names = HashSet::new();
620 
621     for isa in isas {
622         let mut fmt = Formatter::new();
623         gen_isa(
624             isa,
625             format_registry,
626             transform_groups,
627             &mut shared_group_names,
628             &mut fmt,
629         );
630         fmt.update_file(format!("{}-{}.rs", filename_prefix, isa.name), out_dir)?;
631     }
632 
633     // Generate shared legalize groups.
634     let mut fmt = Formatter::new();
635     let mut type_sets = UniqueTable::new();
636     let mut sorted_shared_group_names = Vec::from_iter(shared_group_names);
637     sorted_shared_group_names.sort();
638     for group_name in &sorted_shared_group_names {
639         let group = transform_groups.by_name(group_name);
640         gen_transform_group(
641             group,
642             format_registry,
643             transform_groups,
644             &mut type_sets,
645             &mut fmt,
646         );
647     }
648     gen_typesets_table(&type_sets, &mut fmt);
649     fmt.update_file(format!("{}r.rs", filename_prefix), out_dir)?;
650 
651     Ok(())
652 }
653