1 //! Fold operations on constants at compile time.
2 #![allow(clippy::float_arithmetic)]
3 
4 use cranelift_codegen::{
5     cursor::{Cursor, FuncCursor},
6     ir::{self, dfg::ValueDef, InstBuilder},
7 };
8 // use rustc_apfloat::{
9 //     ieee::{Double, Single},
10 //     Float,
11 // };
12 
/// A constant operand recovered from the instruction that defines it.
#[derive(Debug, Clone, Copy, PartialEq)]
enum ConstImm {
    /// Value of a `bconst`.
    Bool(bool),
    /// Raw 64-bit immediate of an `iconst`.
    I64(i64),
    // `Ieee32` and `Ieee64` will eventually be replaced with `Single` and `Double` from the
    // `rustc_apfloat` crate.
    /// Value of an `f32const`.
    Ieee32(f32),
    /// Value of an `f64const`.
    Ieee64(f64),
}

impl ConstImm {
    /// Returns the integer payload.
    ///
    /// # Panics
    ///
    /// Panics if `self` is not `ConstImm::I64`.
    fn unwrap_i64(self) -> i64 {
        if let Self::I64(imm) = self {
            imm
        } else {
            panic!("self did not contain an `i64`.")
        }
    }

    /// Interprets the constant as a branch condition: a `Bool` is itself, and an `I64` is
    /// true when non-zero.
    ///
    /// # Panics
    ///
    /// Panics for the floating-point variants, which have no truthiness.
    fn evaluate_truthiness(self) -> bool {
        match self {
            Self::Bool(b) => b,
            Self::I64(imm) => imm != 0,
            _ => panic!(
                "Only a `ConstImm::Bool` and `ConstImm::I64` can be evaluated for \"truthiness\""
            ),
        }
    }
}
39 
40 /// Fold operations on constants.
41 ///
42 /// It's important to note that this will not remove unused constants. It's
43 /// assumed that the DCE pass will take care of them.
fold_constants(func: &mut ir::Function)44 pub fn fold_constants(func: &mut ir::Function) {
45     let mut pos = FuncCursor::new(func);
46 
47     while let Some(_block) = pos.next_block() {
48         while let Some(inst) = pos.next_inst() {
49             use self::ir::InstructionData::*;
50             match pos.func.dfg[inst] {
51                 Binary { opcode, args } => {
52                     fold_binary(&mut pos.func.dfg, inst, opcode, args);
53                 }
54                 Unary { opcode, arg } => {
55                     fold_unary(&mut pos.func.dfg, inst, opcode, arg);
56                 }
57                 Branch { opcode, .. } => {
58                     fold_branch(&mut pos, inst, opcode);
59                 }
60                 _ => {}
61             }
62         }
63     }
64 }
65 
resolve_value_to_imm(dfg: &ir::DataFlowGraph, value: ir::Value) -> Option<ConstImm>66 fn resolve_value_to_imm(dfg: &ir::DataFlowGraph, value: ir::Value) -> Option<ConstImm> {
67     let original = dfg.resolve_aliases(value);
68 
69     let inst = match dfg.value_def(original) {
70         ValueDef::Result(inst, _) => inst,
71         ValueDef::Param(_, _) => return None,
72     };
73 
74     use self::ir::{InstructionData::*, Opcode::*};
75     match dfg[inst] {
76         UnaryImm {
77             opcode: Iconst,
78             imm,
79         } => Some(ConstImm::I64(imm.into())),
80         UnaryIeee32 {
81             opcode: F32const,
82             imm,
83         } => {
84             // See https://doc.rust-lang.org/std/primitive.f32.html#method.from_bits for caveats.
85             let ieee_f32 = f32::from_bits(imm.bits());
86             Some(ConstImm::Ieee32(ieee_f32))
87         }
88         UnaryIeee64 {
89             opcode: F64const,
90             imm,
91         } => {
92             // See https://doc.rust-lang.org/std/primitive.f32.html#method.from_bits for caveats.
93             let ieee_f64 = f64::from_bits(imm.bits());
94             Some(ConstImm::Ieee64(ieee_f64))
95         }
96         UnaryBool {
97             opcode: Bconst,
98             imm,
99         } => Some(ConstImm::Bool(imm)),
100         _ => None,
101     }
102 }
103 
evaluate_binary(opcode: ir::Opcode, imm0: ConstImm, imm1: ConstImm) -> Option<ConstImm>104 fn evaluate_binary(opcode: ir::Opcode, imm0: ConstImm, imm1: ConstImm) -> Option<ConstImm> {
105     use core::num::Wrapping;
106 
107     match opcode {
108         ir::Opcode::Iadd => {
109             let imm0 = Wrapping(imm0.unwrap_i64());
110             let imm1 = Wrapping(imm1.unwrap_i64());
111             Some(ConstImm::I64((imm0 + imm1).0))
112         }
113         ir::Opcode::Isub => {
114             let imm0 = Wrapping(imm0.unwrap_i64());
115             let imm1 = Wrapping(imm1.unwrap_i64());
116             Some(ConstImm::I64((imm0 - imm1).0))
117         }
118         ir::Opcode::Imul => {
119             let imm0 = Wrapping(imm0.unwrap_i64());
120             let imm1 = Wrapping(imm1.unwrap_i64());
121             Some(ConstImm::I64((imm0 * imm1).0))
122         }
123         ir::Opcode::Udiv => {
124             let imm0 = Wrapping(imm0.unwrap_i64());
125             let imm1 = Wrapping(imm1.unwrap_i64());
126             if imm1.0 == 0 {
127                 panic!("Cannot divide by a zero.")
128             }
129             Some(ConstImm::I64((imm0 / imm1).0))
130         }
131         ir::Opcode::Fadd => match (imm0, imm1) {
132             (ConstImm::Ieee32(imm0), ConstImm::Ieee32(imm1)) => Some(ConstImm::Ieee32(imm0 + imm1)),
133             (ConstImm::Ieee64(imm0), ConstImm::Ieee64(imm1)) => Some(ConstImm::Ieee64(imm0 + imm1)),
134             _ => unreachable!(),
135         },
136         ir::Opcode::Fsub => match (imm0, imm1) {
137             (ConstImm::Ieee32(imm0), ConstImm::Ieee32(imm1)) => Some(ConstImm::Ieee32(imm0 - imm1)),
138             (ConstImm::Ieee64(imm0), ConstImm::Ieee64(imm1)) => Some(ConstImm::Ieee64(imm0 - imm1)),
139             _ => unreachable!(),
140         },
141         ir::Opcode::Fmul => match (imm0, imm1) {
142             (ConstImm::Ieee32(imm0), ConstImm::Ieee32(imm1)) => Some(ConstImm::Ieee32(imm0 * imm1)),
143             (ConstImm::Ieee64(imm0), ConstImm::Ieee64(imm1)) => Some(ConstImm::Ieee64(imm0 * imm1)),
144             _ => unreachable!(),
145         },
146         ir::Opcode::Fdiv => match (imm0, imm1) {
147             (ConstImm::Ieee32(imm0), ConstImm::Ieee32(imm1)) => Some(ConstImm::Ieee32(imm0 / imm1)),
148             (ConstImm::Ieee64(imm0), ConstImm::Ieee64(imm1)) => Some(ConstImm::Ieee64(imm0 / imm1)),
149             _ => unreachable!(),
150         },
151         _ => None,
152     }
153 }
154 
evaluate_unary(opcode: ir::Opcode, imm: ConstImm) -> Option<ConstImm>155 fn evaluate_unary(opcode: ir::Opcode, imm: ConstImm) -> Option<ConstImm> {
156     match opcode {
157         ir::Opcode::Fneg => match imm {
158             ConstImm::Ieee32(imm) => Some(ConstImm::Ieee32(-imm)),
159             ConstImm::Ieee64(imm) => Some(ConstImm::Ieee64(-imm)),
160             _ => unreachable!(),
161         },
162         ir::Opcode::Fabs => match imm {
163             ConstImm::Ieee32(imm) => Some(ConstImm::Ieee32(imm.abs())),
164             ConstImm::Ieee64(imm) => Some(ConstImm::Ieee64(imm.abs())),
165             _ => unreachable!(),
166         },
167         _ => None,
168     }
169 }
170 
replace_inst(dfg: &mut ir::DataFlowGraph, inst: ir::Inst, const_imm: ConstImm)171 fn replace_inst(dfg: &mut ir::DataFlowGraph, inst: ir::Inst, const_imm: ConstImm) {
172     use self::ConstImm::*;
173     match const_imm {
174         I64(imm) => {
175             let typevar = dfg.ctrl_typevar(inst);
176             dfg.replace(inst).iconst(typevar, imm);
177         }
178         Ieee32(imm) => {
179             dfg.replace(inst)
180                 .f32const(ir::immediates::Ieee32::with_bits(imm.to_bits()));
181         }
182         Ieee64(imm) => {
183             dfg.replace(inst)
184                 .f64const(ir::immediates::Ieee64::with_bits(imm.to_bits()));
185         }
186         Bool(imm) => {
187             let typevar = dfg.ctrl_typevar(inst);
188             dfg.replace(inst).bconst(typevar, imm);
189         }
190     }
191 }
192 
193 /// Fold a binary instruction.
fold_binary( dfg: &mut ir::DataFlowGraph, inst: ir::Inst, opcode: ir::Opcode, args: [ir::Value; 2], )194 fn fold_binary(
195     dfg: &mut ir::DataFlowGraph,
196     inst: ir::Inst,
197     opcode: ir::Opcode,
198     args: [ir::Value; 2],
199 ) {
200     let (imm0, imm1) = if let (Some(imm0), Some(imm1)) = (
201         resolve_value_to_imm(dfg, args[0]),
202         resolve_value_to_imm(dfg, args[1]),
203     ) {
204         (imm0, imm1)
205     } else {
206         return;
207     };
208 
209     if let Some(const_imm) = evaluate_binary(opcode, imm0, imm1) {
210         replace_inst(dfg, inst, const_imm);
211     }
212 }
213 
214 /// Fold a unary instruction.
fold_unary(dfg: &mut ir::DataFlowGraph, inst: ir::Inst, opcode: ir::Opcode, arg: ir::Value)215 fn fold_unary(dfg: &mut ir::DataFlowGraph, inst: ir::Inst, opcode: ir::Opcode, arg: ir::Value) {
216     let imm = if let Some(imm) = resolve_value_to_imm(dfg, arg) {
217         imm
218     } else {
219         return;
220     };
221 
222     if let Some(const_imm) = evaluate_unary(opcode, imm) {
223         replace_inst(dfg, inst, const_imm);
224     }
225 }
226 
fold_branch(pos: &mut FuncCursor, inst: ir::Inst, opcode: ir::Opcode)227 fn fold_branch(pos: &mut FuncCursor, inst: ir::Inst, opcode: ir::Opcode) {
228     let (cond, block, args) = {
229         let values = pos.func.dfg.inst_args(inst);
230         let inst_data = &pos.func.dfg[inst];
231         (
232             match resolve_value_to_imm(&pos.func.dfg, values[0]) {
233                 Some(imm) => imm,
234                 None => return,
235             },
236             inst_data.branch_destination().unwrap(),
237             values[1..].to_vec(),
238         )
239     };
240 
241     let truthiness = cond.evaluate_truthiness();
242     let branch_if_zero = match opcode {
243         ir::Opcode::Brz => true,
244         ir::Opcode::Brnz => false,
245         _ => unreachable!(),
246     };
247 
248     if (branch_if_zero && !truthiness) || (!branch_if_zero && truthiness) {
249         pos.func.dfg.replace(inst).jump(block, &args);
250         // remove the rest of the block to avoid verifier errors
251         while let Some(next_inst) = pos.func.layout.next_inst(inst) {
252             pos.func.layout.remove_inst(next_inst);
253         }
254     } else {
255         pos.remove_inst_and_step_back();
256     }
257 }
258