1 //===-- SPIRVPreLegalizer.cpp - prepare IR for legalization -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// The pass prepares IR for legalization: it assigns SPIR-V types to registers
// and removes intrinsics which held these types during IR translation.
// Also it processes constants and registers them in GR to avoid duplication.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #include "SPIRV.h"
16 #include "SPIRVSubtarget.h"
17 #include "SPIRVUtils.h"
18 #include "llvm/ADT/PostOrderIterator.h"
19 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
20 #include "llvm/IR/Attributes.h"
21 #include "llvm/IR/Constants.h"
22 #include "llvm/IR/DebugInfoMetadata.h"
23 #include "llvm/IR/IntrinsicsSPIRV.h"
24 #include "llvm/Target/TargetIntrinsicInfo.h"
25
26 #define DEBUG_TYPE "spirv-prelegalizer"
27
28 using namespace llvm;
29
30 namespace {
31 class SPIRVPreLegalizer : public MachineFunctionPass {
32 public:
33 static char ID;
SPIRVPreLegalizer()34 SPIRVPreLegalizer() : MachineFunctionPass(ID) {
35 initializeSPIRVPreLegalizerPass(*PassRegistry::getPassRegistry());
36 }
37 bool runOnMachineFunction(MachineFunction &MF) override;
38 };
39 } // namespace
40
addConstantsToTrack(MachineFunction & MF,SPIRVGlobalRegistry * GR)41 static void addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR) {
42 MachineRegisterInfo &MRI = MF.getRegInfo();
43 DenseMap<MachineInstr *, Register> RegsAlreadyAddedToDT;
44 SmallVector<MachineInstr *, 10> ToErase, ToEraseComposites;
45 for (MachineBasicBlock &MBB : MF) {
46 for (MachineInstr &MI : MBB) {
47 if (!isSpvIntrinsic(MI, Intrinsic::spv_track_constant))
48 continue;
49 ToErase.push_back(&MI);
50 auto *Const =
51 cast<Constant>(cast<ConstantAsMetadata>(
52 MI.getOperand(3).getMetadata()->getOperand(0))
53 ->getValue());
54 if (auto *GV = dyn_cast<GlobalValue>(Const)) {
55 Register Reg = GR->find(GV, &MF);
56 if (!Reg.isValid())
57 GR->add(GV, &MF, MI.getOperand(2).getReg());
58 else
59 RegsAlreadyAddedToDT[&MI] = Reg;
60 } else {
61 Register Reg = GR->find(Const, &MF);
62 if (!Reg.isValid()) {
63 if (auto *ConstVec = dyn_cast<ConstantDataVector>(Const)) {
64 auto *BuildVec = MRI.getVRegDef(MI.getOperand(2).getReg());
65 assert(BuildVec &&
66 BuildVec->getOpcode() == TargetOpcode::G_BUILD_VECTOR);
67 for (unsigned i = 0; i < ConstVec->getNumElements(); ++i)
68 GR->add(ConstVec->getElementAsConstant(i), &MF,
69 BuildVec->getOperand(1 + i).getReg());
70 }
71 GR->add(Const, &MF, MI.getOperand(2).getReg());
72 } else {
73 RegsAlreadyAddedToDT[&MI] = Reg;
74 // This MI is unused and will be removed. If the MI uses
75 // const_composite, it will be unused and should be removed too.
76 assert(MI.getOperand(2).isReg() && "Reg operand is expected");
77 MachineInstr *SrcMI = MRI.getVRegDef(MI.getOperand(2).getReg());
78 if (SrcMI && isSpvIntrinsic(*SrcMI, Intrinsic::spv_const_composite))
79 ToEraseComposites.push_back(SrcMI);
80 }
81 }
82 }
83 }
84 for (MachineInstr *MI : ToErase) {
85 Register Reg = MI->getOperand(2).getReg();
86 if (RegsAlreadyAddedToDT.contains(MI))
87 Reg = RegsAlreadyAddedToDT[MI];
88 auto *RC = MRI.getRegClassOrNull(MI->getOperand(0).getReg());
89 if (!MRI.getRegClassOrNull(Reg) && RC)
90 MRI.setRegClass(Reg, RC);
91 MRI.replaceRegWith(MI->getOperand(0).getReg(), Reg);
92 MI->eraseFromParent();
93 }
94 for (MachineInstr *MI : ToEraseComposites)
95 MI->eraseFromParent();
96 }
97
foldConstantsIntoIntrinsics(MachineFunction & MF)98 static void foldConstantsIntoIntrinsics(MachineFunction &MF) {
99 SmallVector<MachineInstr *, 10> ToErase;
100 MachineRegisterInfo &MRI = MF.getRegInfo();
101 const unsigned AssignNameOperandShift = 2;
102 for (MachineBasicBlock &MBB : MF) {
103 for (MachineInstr &MI : MBB) {
104 if (!isSpvIntrinsic(MI, Intrinsic::spv_assign_name))
105 continue;
106 unsigned NumOp = MI.getNumExplicitDefs() + AssignNameOperandShift;
107 while (MI.getOperand(NumOp).isReg()) {
108 MachineOperand &MOp = MI.getOperand(NumOp);
109 MachineInstr *ConstMI = MRI.getVRegDef(MOp.getReg());
110 assert(ConstMI->getOpcode() == TargetOpcode::G_CONSTANT);
111 MI.removeOperand(NumOp);
112 MI.addOperand(MachineOperand::CreateImm(
113 ConstMI->getOperand(1).getCImm()->getZExtValue()));
114 if (MRI.use_empty(ConstMI->getOperand(0).getReg()))
115 ToErase.push_back(ConstMI);
116 }
117 }
118 }
119 for (MachineInstr *MI : ToErase)
120 MI->eraseFromParent();
121 }
122
insertBitcasts(MachineFunction & MF,SPIRVGlobalRegistry * GR,MachineIRBuilder MIB)123 static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
124 MachineIRBuilder MIB) {
125 SmallVector<MachineInstr *, 10> ToErase;
126 for (MachineBasicBlock &MBB : MF) {
127 for (MachineInstr &MI : MBB) {
128 if (!isSpvIntrinsic(MI, Intrinsic::spv_bitcast) &&
129 !isSpvIntrinsic(MI, Intrinsic::spv_ptrcast))
130 continue;
131 assert(MI.getOperand(2).isReg());
132 MIB.setInsertPt(*MI.getParent(), MI);
133 ToErase.push_back(&MI);
134 if (isSpvIntrinsic(MI, Intrinsic::spv_bitcast)) {
135 MIB.buildBitcast(MI.getOperand(0).getReg(), MI.getOperand(2).getReg());
136 continue;
137 }
138 Register Def = MI.getOperand(0).getReg();
139 Register Source = MI.getOperand(2).getReg();
140 SPIRVType *BaseTy = GR->getOrCreateSPIRVType(
141 getMDOperandAsType(MI.getOperand(3).getMetadata(), 0), MIB);
142 SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
143 BaseTy, MI, *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(),
144 addressSpaceToStorageClass(MI.getOperand(4).getImm()));
145
146 // If the bitcast would be redundant, replace all uses with the source
147 // register.
148 if (GR->getSPIRVTypeForVReg(Source) == AssignedPtrType) {
149 MIB.getMRI()->replaceRegWith(Def, Source);
150 } else {
151 GR->assignSPIRVTypeToVReg(AssignedPtrType, Def, MF);
152 MIB.buildBitcast(Def, Source);
153 }
154 }
155 }
156 for (MachineInstr *MI : ToErase)
157 MI->eraseFromParent();
158 }
159
160 // Translating GV, IRTranslator sometimes generates following IR:
161 // %1 = G_GLOBAL_VALUE
162 // %2 = COPY %1
163 // %3 = G_ADDRSPACE_CAST %2
164 // New registers have no SPIRVType and no register class info.
165 //
166 // Set SPIRVType for GV, propagate it from GV to other instructions,
167 // also set register classes.
propagateSPIRVType(MachineInstr * MI,SPIRVGlobalRegistry * GR,MachineRegisterInfo & MRI,MachineIRBuilder & MIB)168 static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
169 MachineRegisterInfo &MRI,
170 MachineIRBuilder &MIB) {
171 SPIRVType *SpirvTy = nullptr;
172 assert(MI && "Machine instr is expected");
173 if (MI->getOperand(0).isReg()) {
174 Register Reg = MI->getOperand(0).getReg();
175 SpirvTy = GR->getSPIRVTypeForVReg(Reg);
176 if (!SpirvTy) {
177 switch (MI->getOpcode()) {
178 case TargetOpcode::G_CONSTANT: {
179 MIB.setInsertPt(*MI->getParent(), MI);
180 Type *Ty = MI->getOperand(1).getCImm()->getType();
181 SpirvTy = GR->getOrCreateSPIRVType(Ty, MIB);
182 break;
183 }
184 case TargetOpcode::G_GLOBAL_VALUE: {
185 MIB.setInsertPt(*MI->getParent(), MI);
186 Type *Ty = MI->getOperand(1).getGlobal()->getType();
187 SpirvTy = GR->getOrCreateSPIRVType(Ty, MIB);
188 break;
189 }
190 case TargetOpcode::G_TRUNC:
191 case TargetOpcode::G_ADDRSPACE_CAST:
192 case TargetOpcode::G_PTR_ADD:
193 case TargetOpcode::COPY: {
194 MachineOperand &Op = MI->getOperand(1);
195 MachineInstr *Def = Op.isReg() ? MRI.getVRegDef(Op.getReg()) : nullptr;
196 if (Def)
197 SpirvTy = propagateSPIRVType(Def, GR, MRI, MIB);
198 break;
199 }
200 default:
201 break;
202 }
203 if (SpirvTy)
204 GR->assignSPIRVTypeToVReg(SpirvTy, Reg, MIB.getMF());
205 if (!MRI.getRegClassOrNull(Reg))
206 MRI.setRegClass(Reg, &SPIRV::IDRegClass);
207 }
208 }
209 return SpirvTy;
210 }
211
212 // Insert ASSIGN_TYPE instuction between Reg and its definition, set NewReg as
213 // a dst of the definition, assign SPIRVType to both registers. If SpirvTy is
214 // provided, use it as SPIRVType in ASSIGN_TYPE, otherwise create it from Ty.
215 // It's used also in SPIRVBuiltins.cpp.
216 // TODO: maybe move to SPIRVUtils.
217 namespace llvm {
insertAssignInstr(Register Reg,Type * Ty,SPIRVType * SpirvTy,SPIRVGlobalRegistry * GR,MachineIRBuilder & MIB,MachineRegisterInfo & MRI)218 Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
219 SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB,
220 MachineRegisterInfo &MRI) {
221 MachineInstr *Def = MRI.getVRegDef(Reg);
222 assert((Ty || SpirvTy) && "Either LLVM or SPIRV type is expected.");
223 MIB.setInsertPt(*Def->getParent(),
224 (Def->getNextNode() ? Def->getNextNode()->getIterator()
225 : Def->getParent()->end()));
226 Register NewReg = MRI.createGenericVirtualRegister(MRI.getType(Reg));
227 if (auto *RC = MRI.getRegClassOrNull(Reg)) {
228 MRI.setRegClass(NewReg, RC);
229 } else {
230 MRI.setRegClass(NewReg, &SPIRV::IDRegClass);
231 MRI.setRegClass(Reg, &SPIRV::IDRegClass);
232 }
233 SpirvTy = SpirvTy ? SpirvTy : GR->getOrCreateSPIRVType(Ty, MIB);
234 GR->assignSPIRVTypeToVReg(SpirvTy, Reg, MIB.getMF());
235 // This is to make it convenient for Legalizer to get the SPIRVType
236 // when processing the actual MI (i.e. not pseudo one).
237 GR->assignSPIRVTypeToVReg(SpirvTy, NewReg, MIB.getMF());
238 // Copy MIFlags from Def to ASSIGN_TYPE instruction. It's required to keep
239 // the flags after instruction selection.
240 const uint32_t Flags = Def->getFlags();
241 MIB.buildInstr(SPIRV::ASSIGN_TYPE)
242 .addDef(Reg)
243 .addUse(NewReg)
244 .addUse(GR->getSPIRVTypeID(SpirvTy))
245 .setMIFlags(Flags);
246 Def->getOperand(0).setReg(NewReg);
247 return NewReg;
248 }
249 } // namespace llvm
250
generateAssignInstrs(MachineFunction & MF,SPIRVGlobalRegistry * GR,MachineIRBuilder MIB)251 static void generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
252 MachineIRBuilder MIB) {
253 MachineRegisterInfo &MRI = MF.getRegInfo();
254 SmallVector<MachineInstr *, 10> ToErase;
255
256 for (MachineBasicBlock *MBB : post_order(&MF)) {
257 if (MBB->empty())
258 continue;
259
260 bool ReachedBegin = false;
261 for (auto MII = std::prev(MBB->end()), Begin = MBB->begin();
262 !ReachedBegin;) {
263 MachineInstr &MI = *MII;
264
265 if (isSpvIntrinsic(MI, Intrinsic::spv_assign_ptr_type)) {
266 Register Reg = MI.getOperand(1).getReg();
267 MIB.setInsertPt(*MI.getParent(), MI.getIterator());
268 SPIRVType *BaseTy = GR->getOrCreateSPIRVType(
269 getMDOperandAsType(MI.getOperand(2).getMetadata(), 0), MIB);
270 SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
271 BaseTy, MI, *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(),
272 addressSpaceToStorageClass(MI.getOperand(3).getImm()));
273 MachineInstr *Def = MRI.getVRegDef(Reg);
274 assert(Def && "Expecting an instruction that defines the register");
275 insertAssignInstr(Reg, nullptr, AssignedPtrType, GR, MIB,
276 MF.getRegInfo());
277 ToErase.push_back(&MI);
278 } else if (isSpvIntrinsic(MI, Intrinsic::spv_assign_type)) {
279 Register Reg = MI.getOperand(1).getReg();
280 Type *Ty = getMDOperandAsType(MI.getOperand(2).getMetadata(), 0);
281 MachineInstr *Def = MRI.getVRegDef(Reg);
282 assert(Def && "Expecting an instruction that defines the register");
283 // G_GLOBAL_VALUE already has type info.
284 if (Def->getOpcode() != TargetOpcode::G_GLOBAL_VALUE)
285 insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MF.getRegInfo());
286 ToErase.push_back(&MI);
287 } else if (MI.getOpcode() == TargetOpcode::G_CONSTANT ||
288 MI.getOpcode() == TargetOpcode::G_FCONSTANT ||
289 MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR) {
290 // %rc = G_CONSTANT ty Val
291 // ===>
292 // %cty = OpType* ty
293 // %rctmp = G_CONSTANT ty Val
294 // %rc = ASSIGN_TYPE %rctmp, %cty
295 Register Reg = MI.getOperand(0).getReg();
296 if (MRI.hasOneUse(Reg)) {
297 MachineInstr &UseMI = *MRI.use_instr_begin(Reg);
298 if (isSpvIntrinsic(UseMI, Intrinsic::spv_assign_type) ||
299 isSpvIntrinsic(UseMI, Intrinsic::spv_assign_name))
300 continue;
301 }
302 Type *Ty = nullptr;
303 if (MI.getOpcode() == TargetOpcode::G_CONSTANT)
304 Ty = MI.getOperand(1).getCImm()->getType();
305 else if (MI.getOpcode() == TargetOpcode::G_FCONSTANT)
306 Ty = MI.getOperand(1).getFPImm()->getType();
307 else {
308 assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR);
309 Type *ElemTy = nullptr;
310 MachineInstr *ElemMI = MRI.getVRegDef(MI.getOperand(1).getReg());
311 assert(ElemMI);
312
313 if (ElemMI->getOpcode() == TargetOpcode::G_CONSTANT)
314 ElemTy = ElemMI->getOperand(1).getCImm()->getType();
315 else if (ElemMI->getOpcode() == TargetOpcode::G_FCONSTANT)
316 ElemTy = ElemMI->getOperand(1).getFPImm()->getType();
317 else
318 llvm_unreachable("Unexpected opcode");
319 unsigned NumElts =
320 MI.getNumExplicitOperands() - MI.getNumExplicitDefs();
321 Ty = VectorType::get(ElemTy, NumElts, false);
322 }
323 insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MRI);
324 } else if (MI.getOpcode() == TargetOpcode::G_TRUNC ||
325 MI.getOpcode() == TargetOpcode::G_GLOBAL_VALUE ||
326 MI.getOpcode() == TargetOpcode::COPY ||
327 MI.getOpcode() == TargetOpcode::G_ADDRSPACE_CAST) {
328 propagateSPIRVType(&MI, GR, MRI, MIB);
329 }
330
331 if (MII == Begin)
332 ReachedBegin = true;
333 else
334 --MII;
335 }
336 }
337 for (MachineInstr *MI : ToErase)
338 MI->eraseFromParent();
339 }
340
341 static std::pair<Register, unsigned>
createNewIdReg(Register ValReg,unsigned Opcode,MachineRegisterInfo & MRI,const SPIRVGlobalRegistry & GR)342 createNewIdReg(Register ValReg, unsigned Opcode, MachineRegisterInfo &MRI,
343 const SPIRVGlobalRegistry &GR) {
344 LLT NewT = LLT::scalar(32);
345 SPIRVType *SpvType = GR.getSPIRVTypeForVReg(ValReg);
346 assert(SpvType && "VReg is expected to have SPIRV type");
347 bool IsFloat = SpvType->getOpcode() == SPIRV::OpTypeFloat;
348 bool IsVectorFloat =
349 SpvType->getOpcode() == SPIRV::OpTypeVector &&
350 GR.getSPIRVTypeForVReg(SpvType->getOperand(1).getReg())->getOpcode() ==
351 SPIRV::OpTypeFloat;
352 IsFloat |= IsVectorFloat;
353 auto GetIdOp = IsFloat ? SPIRV::GET_fID : SPIRV::GET_ID;
354 auto DstClass = IsFloat ? &SPIRV::fIDRegClass : &SPIRV::IDRegClass;
355 if (MRI.getType(ValReg).isPointer()) {
356 NewT = LLT::pointer(0, 32);
357 GetIdOp = SPIRV::GET_pID;
358 DstClass = &SPIRV::pIDRegClass;
359 } else if (MRI.getType(ValReg).isVector()) {
360 NewT = LLT::fixed_vector(2, NewT);
361 GetIdOp = IsFloat ? SPIRV::GET_vfID : SPIRV::GET_vID;
362 DstClass = IsFloat ? &SPIRV::vfIDRegClass : &SPIRV::vIDRegClass;
363 }
364 Register IdReg = MRI.createGenericVirtualRegister(NewT);
365 MRI.setRegClass(IdReg, DstClass);
366 return {IdReg, GetIdOp};
367 }
368
processInstr(MachineInstr & MI,MachineIRBuilder & MIB,MachineRegisterInfo & MRI,SPIRVGlobalRegistry * GR)369 static void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
370 MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR) {
371 unsigned Opc = MI.getOpcode();
372 assert(MI.getNumDefs() > 0 && MRI.hasOneUse(MI.getOperand(0).getReg()));
373 MachineInstr &AssignTypeInst =
374 *(MRI.use_instr_begin(MI.getOperand(0).getReg()));
375 auto NewReg = createNewIdReg(MI.getOperand(0).getReg(), Opc, MRI, *GR).first;
376 AssignTypeInst.getOperand(1).setReg(NewReg);
377 MI.getOperand(0).setReg(NewReg);
378 MIB.setInsertPt(*MI.getParent(),
379 (MI.getNextNode() ? MI.getNextNode()->getIterator()
380 : MI.getParent()->end()));
381 for (auto &Op : MI.operands()) {
382 if (!Op.isReg() || Op.isDef())
383 continue;
384 auto IdOpInfo = createNewIdReg(Op.getReg(), Opc, MRI, *GR);
385 MIB.buildInstr(IdOpInfo.second).addDef(IdOpInfo.first).addUse(Op.getReg());
386 Op.setReg(IdOpInfo.first);
387 }
388 }
389
390 // Defined in SPIRVLegalizerInfo.cpp.
391 extern bool isTypeFoldingSupported(unsigned Opcode);
392
processInstrsWithTypeFolding(MachineFunction & MF,SPIRVGlobalRegistry * GR,MachineIRBuilder MIB)393 static void processInstrsWithTypeFolding(MachineFunction &MF,
394 SPIRVGlobalRegistry *GR,
395 MachineIRBuilder MIB) {
396 MachineRegisterInfo &MRI = MF.getRegInfo();
397 for (MachineBasicBlock &MBB : MF) {
398 for (MachineInstr &MI : MBB) {
399 if (isTypeFoldingSupported(MI.getOpcode()))
400 processInstr(MI, MIB, MRI, GR);
401 }
402 }
403 for (MachineBasicBlock &MBB : MF) {
404 for (MachineInstr &MI : MBB) {
405 // We need to rewrite dst types for ASSIGN_TYPE instrs to be able
406 // to perform tblgen'erated selection and we can't do that on Legalizer
407 // as it operates on gMIR only.
408 if (MI.getOpcode() != SPIRV::ASSIGN_TYPE)
409 continue;
410 Register SrcReg = MI.getOperand(1).getReg();
411 unsigned Opcode = MRI.getVRegDef(SrcReg)->getOpcode();
412 if (!isTypeFoldingSupported(Opcode))
413 continue;
414 Register DstReg = MI.getOperand(0).getReg();
415 if (MRI.getType(DstReg).isVector())
416 MRI.setRegClass(DstReg, &SPIRV::IDRegClass);
417 // Don't need to reset type of register holding constant and used in
418 // G_ADDRSPACE_CAST, since it braaks legalizer.
419 if (Opcode == TargetOpcode::G_CONSTANT && MRI.hasOneUse(DstReg)) {
420 MachineInstr &UseMI = *MRI.use_instr_begin(DstReg);
421 if (UseMI.getOpcode() == TargetOpcode::G_ADDRSPACE_CAST)
422 continue;
423 }
424 MRI.setType(DstReg, LLT::scalar(32));
425 }
426 }
427 }
428
processSwitches(MachineFunction & MF,SPIRVGlobalRegistry * GR,MachineIRBuilder MIB)429 static void processSwitches(MachineFunction &MF, SPIRVGlobalRegistry *GR,
430 MachineIRBuilder MIB) {
431 // Before IRTranslator pass, calls to spv_switch intrinsic are inserted before
432 // each switch instruction. IRTranslator lowers switches to G_ICMP + G_BRCOND
433 // + G_BR triples. A switch with two cases may be transformed to this MIR
434 // sequence:
435 //
436 // intrinsic(@llvm.spv.switch), %CmpReg, %Const0, %Const1
437 // %Dst0 = G_ICMP intpred(eq), %CmpReg, %Const0
438 // G_BRCOND %Dst0, %bb.2
439 // G_BR %bb.5
440 // bb.5.entry:
441 // %Dst1 = G_ICMP intpred(eq), %CmpReg, %Const1
442 // G_BRCOND %Dst1, %bb.3
443 // G_BR %bb.4
444 // bb.2.sw.bb:
445 // ...
446 // bb.3.sw.bb1:
447 // ...
448 // bb.4.sw.epilog:
449 // ...
450 //
451 // Sometimes (in case of range-compare switches), additional G_SUBs
452 // instructions are inserted before G_ICMPs. Those need to be additionally
453 // processed.
454 //
455 // This function modifies spv_switch call's operands to include destination
456 // MBBs (default and for each constant value).
457 //
458 // At the end, the function removes redundant [G_SUB] + G_ICMP + G_BRCOND +
459 // G_BR sequences.
460
461 MachineRegisterInfo &MRI = MF.getRegInfo();
462
463 // Collect spv_switches and G_ICMPs across all MBBs in MF.
464 std::vector<MachineInstr *> RelevantInsts;
465
466 // Collect redundant MIs from [G_SUB] + G_ICMP + G_BRCOND + G_BR sequences.
467 // After updating spv_switches, the instructions can be removed.
468 std::vector<MachineInstr *> PostUpdateArtifacts;
469
470 // Temporary set of compare registers. G_SUBs and G_ICMPs relating to
471 // spv_switch use these registers.
472 DenseSet<Register> CompareRegs;
473 for (MachineBasicBlock &MBB : MF) {
474 for (MachineInstr &MI : MBB) {
475 // Calls to spv_switch intrinsics representing IR switches.
476 if (isSpvIntrinsic(MI, Intrinsic::spv_switch)) {
477 assert(MI.getOperand(1).isReg());
478 CompareRegs.insert(MI.getOperand(1).getReg());
479 RelevantInsts.push_back(&MI);
480 }
481
482 // G_SUBs coming from range-compare switch lowering. G_SUBs are found
483 // after spv_switch but before G_ICMP.
484 if (MI.getOpcode() == TargetOpcode::G_SUB && MI.getOperand(1).isReg() &&
485 CompareRegs.contains(MI.getOperand(1).getReg())) {
486 assert(MI.getOperand(0).isReg() && MI.getOperand(1).isReg());
487 Register Dst = MI.getOperand(0).getReg();
488 CompareRegs.insert(Dst);
489 PostUpdateArtifacts.push_back(&MI);
490 }
491
492 // G_ICMPs relating to switches.
493 if (MI.getOpcode() == TargetOpcode::G_ICMP && MI.getOperand(2).isReg() &&
494 CompareRegs.contains(MI.getOperand(2).getReg())) {
495 Register Dst = MI.getOperand(0).getReg();
496 RelevantInsts.push_back(&MI);
497 PostUpdateArtifacts.push_back(&MI);
498 MachineInstr *CBr = MRI.use_begin(Dst)->getParent();
499 assert(CBr->getOpcode() == SPIRV::G_BRCOND);
500 PostUpdateArtifacts.push_back(CBr);
501 MachineInstr *Br = CBr->getNextNode();
502 assert(Br->getOpcode() == SPIRV::G_BR);
503 PostUpdateArtifacts.push_back(Br);
504 }
505 }
506 }
507
508 // Update each spv_switch with destination MBBs.
509 for (auto i = RelevantInsts.begin(); i != RelevantInsts.end(); i++) {
510 if (!isSpvIntrinsic(**i, Intrinsic::spv_switch))
511 continue;
512
513 // Currently considered spv_switch.
514 MachineInstr *Switch = *i;
515 // Set the first successor as default MBB to support empty switches.
516 MachineBasicBlock *DefaultMBB = *Switch->getParent()->succ_begin();
517 // Container for mapping values to MMBs.
518 SmallDenseMap<uint64_t, MachineBasicBlock *> ValuesToMBBs;
519
520 // Walk all G_ICMPs to collect ValuesToMBBs. Start at currently considered
521 // spv_switch (i) and break at any spv_switch with the same compare
522 // register (indicating we are back at the same scope).
523 Register CompareReg = Switch->getOperand(1).getReg();
524 for (auto j = i + 1; j != RelevantInsts.end(); j++) {
525 if (isSpvIntrinsic(**j, Intrinsic::spv_switch) &&
526 (*j)->getOperand(1).getReg() == CompareReg)
527 break;
528
529 if (!((*j)->getOpcode() == TargetOpcode::G_ICMP &&
530 (*j)->getOperand(2).getReg() == CompareReg))
531 continue;
532
533 MachineInstr *ICMP = *j;
534 Register Dst = ICMP->getOperand(0).getReg();
535 MachineOperand &PredOp = ICMP->getOperand(1);
536 const auto CC = static_cast<CmpInst::Predicate>(PredOp.getPredicate());
537 assert((CC == CmpInst::ICMP_EQ || CC == CmpInst::ICMP_ULE) &&
538 MRI.hasOneUse(Dst) && MRI.hasOneDef(CompareReg));
539 uint64_t Value = getIConstVal(ICMP->getOperand(3).getReg(), &MRI);
540 MachineInstr *CBr = MRI.use_begin(Dst)->getParent();
541 assert(CBr->getOpcode() == SPIRV::G_BRCOND && CBr->getOperand(1).isMBB());
542 MachineBasicBlock *MBB = CBr->getOperand(1).getMBB();
543
544 // Map switch case Value to target MBB.
545 ValuesToMBBs[Value] = MBB;
546
547 // Add target MBB as successor to the switch's MBB.
548 Switch->getParent()->addSuccessor(MBB);
549
550 // The next MI is always G_BR to either the next case or the default.
551 MachineInstr *NextMI = CBr->getNextNode();
552 assert(NextMI->getOpcode() == SPIRV::G_BR &&
553 NextMI->getOperand(0).isMBB());
554 MachineBasicBlock *NextMBB = NextMI->getOperand(0).getMBB();
555 // Default MBB does not begin with G_ICMP using spv_switch compare
556 // register.
557 if (NextMBB->front().getOpcode() != SPIRV::G_ICMP ||
558 (NextMBB->front().getOperand(2).isReg() &&
559 NextMBB->front().getOperand(2).getReg() != CompareReg)) {
560 // Set default MBB and add it as successor to the switch's MBB.
561 DefaultMBB = NextMBB;
562 Switch->getParent()->addSuccessor(DefaultMBB);
563 }
564 }
565
566 // Modify considered spv_switch operands using collected Values and
567 // MBBs.
568 SmallVector<const ConstantInt *, 3> Values;
569 SmallVector<MachineBasicBlock *, 3> MBBs;
570 for (unsigned k = 2; k < Switch->getNumExplicitOperands(); k++) {
571 Register CReg = Switch->getOperand(k).getReg();
572 uint64_t Val = getIConstVal(CReg, &MRI);
573 MachineInstr *ConstInstr = getDefInstrMaybeConstant(CReg, &MRI);
574 if (!ValuesToMBBs[Val])
575 continue;
576
577 Values.push_back(ConstInstr->getOperand(1).getCImm());
578 MBBs.push_back(ValuesToMBBs[Val]);
579 }
580
581 for (unsigned k = Switch->getNumExplicitOperands() - 1; k > 1; k--)
582 Switch->removeOperand(k);
583
584 Switch->addOperand(MachineOperand::CreateMBB(DefaultMBB));
585 for (unsigned k = 0; k < Values.size(); k++) {
586 Switch->addOperand(MachineOperand::CreateCImm(Values[k]));
587 Switch->addOperand(MachineOperand::CreateMBB(MBBs[k]));
588 }
589 }
590
591 for (MachineInstr *MI : PostUpdateArtifacts) {
592 MachineBasicBlock *ParentMBB = MI->getParent();
593 MI->eraseFromParent();
594 // If G_ICMP + G_BRCOND + G_BR were the only MIs in MBB, erase this MBB. It
595 // can be safely assumed, there are no breaks or phis directing into this
596 // MBB. However, we need to remove this MBB from the CFG graph. MBBs must be
597 // erased top-down.
598 if (ParentMBB->empty()) {
599 while (!ParentMBB->pred_empty())
600 (*ParentMBB->pred_begin())->removeSuccessor(ParentMBB);
601
602 while (!ParentMBB->succ_empty())
603 ParentMBB->removeSuccessor(ParentMBB->succ_begin());
604
605 ParentMBB->eraseFromParent();
606 }
607 }
608 }
609
isImplicitFallthrough(MachineBasicBlock & MBB)610 static bool isImplicitFallthrough(MachineBasicBlock &MBB) {
611 if (MBB.empty())
612 return true;
613
614 // Branching SPIR-V intrinsics are not detected by this generic method.
615 // Thus, we can only trust negative result.
616 if (!MBB.canFallThrough())
617 return false;
618
619 // Otherwise, we must manually check if we have a SPIR-V intrinsic which
620 // prevent an implicit fallthrough.
621 for (MachineBasicBlock::reverse_iterator It = MBB.rbegin(), E = MBB.rend();
622 It != E; ++It) {
623 if (isSpvIntrinsic(*It, Intrinsic::spv_switch))
624 return false;
625 }
626 return true;
627 }
628
removeImplicitFallthroughs(MachineFunction & MF,MachineIRBuilder MIB)629 static void removeImplicitFallthroughs(MachineFunction &MF,
630 MachineIRBuilder MIB) {
631 // It is valid for MachineBasicBlocks to not finish with a branch instruction.
632 // In such cases, they will simply fallthrough their immediate successor.
633 for (MachineBasicBlock &MBB : MF) {
634 if (!isImplicitFallthrough(MBB))
635 continue;
636
637 assert(std::distance(MBB.successors().begin(), MBB.successors().end()) ==
638 1);
639 MIB.setInsertPt(MBB, MBB.end());
640 MIB.buildBr(**MBB.successors().begin());
641 }
642 }
643
runOnMachineFunction(MachineFunction & MF)644 bool SPIRVPreLegalizer::runOnMachineFunction(MachineFunction &MF) {
645 // Initialize the type registry.
646 const SPIRVSubtarget &ST = MF.getSubtarget<SPIRVSubtarget>();
647 SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
648 GR->setCurrentFunc(MF);
649 MachineIRBuilder MIB(MF);
650 addConstantsToTrack(MF, GR);
651 foldConstantsIntoIntrinsics(MF);
652 insertBitcasts(MF, GR, MIB);
653 generateAssignInstrs(MF, GR, MIB);
654 processSwitches(MF, GR, MIB);
655 processInstrsWithTypeFolding(MF, GR, MIB);
656 removeImplicitFallthroughs(MF, MIB);
657
658 return true;
659 }
660
661 INITIALIZE_PASS(SPIRVPreLegalizer, DEBUG_TYPE, "SPIRV pre legalizer", false,
662 false)
663
664 char SPIRVPreLegalizer::ID = 0;
665
createSPIRVPreLegalizerPass()666 FunctionPass *llvm::createSPIRVPreLegalizerPass() {
667 return new SPIRVPreLegalizer();
668 }
669