1 //===- SPIRVLegalizerInfo.cpp --- SPIR-V Legalization Rules ------*- C++ -*-==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements the targeting of the MachineLegalizer class for SPIR-V.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "SPIRVLegalizerInfo.h"
14 #include "SPIRV.h"
15 #include "SPIRVGlobalRegistry.h"
16 #include "SPIRVSubtarget.h"
17 #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
18 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
19 #include "llvm/CodeGen/MachineInstr.h"
20 #include "llvm/CodeGen/MachineRegisterInfo.h"
21 #include "llvm/CodeGen/TargetOpcodes.h"
22
23 using namespace llvm;
24 using namespace llvm::LegalizeActions;
25 using namespace llvm::LegalityPredicates;
26
27 static const std::set<unsigned> TypeFoldingSupportingOpcs = {
28 TargetOpcode::G_ADD,
29 TargetOpcode::G_FADD,
30 TargetOpcode::G_SUB,
31 TargetOpcode::G_FSUB,
32 TargetOpcode::G_MUL,
33 TargetOpcode::G_FMUL,
34 TargetOpcode::G_SDIV,
35 TargetOpcode::G_UDIV,
36 TargetOpcode::G_FDIV,
37 TargetOpcode::G_SREM,
38 TargetOpcode::G_UREM,
39 TargetOpcode::G_FREM,
40 TargetOpcode::G_FNEG,
41 TargetOpcode::G_CONSTANT,
42 TargetOpcode::G_FCONSTANT,
43 TargetOpcode::G_AND,
44 TargetOpcode::G_OR,
45 TargetOpcode::G_XOR,
46 TargetOpcode::G_SHL,
47 TargetOpcode::G_ASHR,
48 TargetOpcode::G_LSHR,
49 TargetOpcode::G_SELECT,
50 TargetOpcode::G_EXTRACT_VECTOR_ELT,
51 };
52
isTypeFoldingSupported(unsigned Opcode)53 bool isTypeFoldingSupported(unsigned Opcode) {
54 return TypeFoldingSupportingOpcs.count(Opcode) > 0;
55 }
56
SPIRVLegalizerInfo(const SPIRVSubtarget & ST)57 SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
58 using namespace TargetOpcode;
59
60 this->ST = &ST;
61 GR = ST.getSPIRVGlobalRegistry();
62
63 const LLT s1 = LLT::scalar(1);
64 const LLT s8 = LLT::scalar(8);
65 const LLT s16 = LLT::scalar(16);
66 const LLT s32 = LLT::scalar(32);
67 const LLT s64 = LLT::scalar(64);
68
69 const LLT v16s64 = LLT::fixed_vector(16, 64);
70 const LLT v16s32 = LLT::fixed_vector(16, 32);
71 const LLT v16s16 = LLT::fixed_vector(16, 16);
72 const LLT v16s8 = LLT::fixed_vector(16, 8);
73 const LLT v16s1 = LLT::fixed_vector(16, 1);
74
75 const LLT v8s64 = LLT::fixed_vector(8, 64);
76 const LLT v8s32 = LLT::fixed_vector(8, 32);
77 const LLT v8s16 = LLT::fixed_vector(8, 16);
78 const LLT v8s8 = LLT::fixed_vector(8, 8);
79 const LLT v8s1 = LLT::fixed_vector(8, 1);
80
81 const LLT v4s64 = LLT::fixed_vector(4, 64);
82 const LLT v4s32 = LLT::fixed_vector(4, 32);
83 const LLT v4s16 = LLT::fixed_vector(4, 16);
84 const LLT v4s8 = LLT::fixed_vector(4, 8);
85 const LLT v4s1 = LLT::fixed_vector(4, 1);
86
87 const LLT v3s64 = LLT::fixed_vector(3, 64);
88 const LLT v3s32 = LLT::fixed_vector(3, 32);
89 const LLT v3s16 = LLT::fixed_vector(3, 16);
90 const LLT v3s8 = LLT::fixed_vector(3, 8);
91 const LLT v3s1 = LLT::fixed_vector(3, 1);
92
93 const LLT v2s64 = LLT::fixed_vector(2, 64);
94 const LLT v2s32 = LLT::fixed_vector(2, 32);
95 const LLT v2s16 = LLT::fixed_vector(2, 16);
96 const LLT v2s8 = LLT::fixed_vector(2, 8);
97 const LLT v2s1 = LLT::fixed_vector(2, 1);
98
99 const unsigned PSize = ST.getPointerSize();
100 const LLT p0 = LLT::pointer(0, PSize); // Function
101 const LLT p1 = LLT::pointer(1, PSize); // CrossWorkgroup
102 const LLT p2 = LLT::pointer(2, PSize); // UniformConstant
103 const LLT p3 = LLT::pointer(3, PSize); // Workgroup
104 const LLT p4 = LLT::pointer(4, PSize); // Generic
105 const LLT p5 = LLT::pointer(5, PSize); // Input
106
107 // TODO: remove copy-pasting here by using concatenation in some way.
108 auto allPtrsScalarsAndVectors = {
109 p0, p1, p2, p3, p4, p5, s1, s8, s16,
110 s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8,
111 v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, v8s1,
112 v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
113
114 auto allScalarsAndVectors = {
115 s1, s8, s16, s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64,
116 v3s1, v3s8, v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64,
117 v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
118
119 auto allIntScalarsAndVectors = {s8, s16, s32, s64, v2s8, v2s16,
120 v2s32, v2s64, v3s8, v3s16, v3s32, v3s64,
121 v4s8, v4s16, v4s32, v4s64, v8s8, v8s16,
122 v8s32, v8s64, v16s8, v16s16, v16s32, v16s64};
123
124 auto allBoolScalarsAndVectors = {s1, v2s1, v3s1, v4s1, v8s1, v16s1};
125
126 auto allIntScalars = {s8, s16, s32, s64};
127
128 auto allFloatScalarsAndVectors = {
129 s16, s32, s64, v2s16, v2s32, v2s64, v3s16, v3s32, v3s64,
130 v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64};
131
132 auto allFloatAndIntScalars = allIntScalars;
133
134 auto allPtrs = {p0, p1, p2, p3, p4, p5};
135 auto allWritablePtrs = {p0, p1, p3, p4};
136
137 for (auto Opc : TypeFoldingSupportingOpcs)
138 getActionDefinitionsBuilder(Opc).custom();
139
140 getActionDefinitionsBuilder(G_GLOBAL_VALUE).alwaysLegal();
141
142 // TODO: add proper rules for vectors legalization.
143 getActionDefinitionsBuilder({G_BUILD_VECTOR, G_SHUFFLE_VECTOR}).alwaysLegal();
144
145 getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE})
146 .legalIf(all(typeInSet(0, allWritablePtrs), typeInSet(1, allPtrs)));
147
148 getActionDefinitionsBuilder(G_MEMSET).legalIf(
149 all(typeInSet(0, allWritablePtrs), typeInSet(1, allIntScalars)));
150
151 getActionDefinitionsBuilder(G_ADDRSPACE_CAST)
152 .legalForCartesianProduct(allPtrs, allPtrs);
153
154 getActionDefinitionsBuilder({G_LOAD, G_STORE}).legalIf(typeInSet(1, allPtrs));
155
156 getActionDefinitionsBuilder(G_BITREVERSE).legalFor(allFloatScalarsAndVectors);
157
158 getActionDefinitionsBuilder(G_FMA).legalFor(allFloatScalarsAndVectors);
159
160 getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI})
161 .legalForCartesianProduct(allIntScalarsAndVectors,
162 allFloatScalarsAndVectors);
163
164 getActionDefinitionsBuilder({G_SITOFP, G_UITOFP})
165 .legalForCartesianProduct(allFloatScalarsAndVectors,
166 allScalarsAndVectors);
167
168 getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX, G_ABS})
169 .legalFor(allIntScalarsAndVectors);
170
171 getActionDefinitionsBuilder(G_CTPOP).legalForCartesianProduct(
172 allIntScalarsAndVectors, allIntScalarsAndVectors);
173
174 getActionDefinitionsBuilder(G_PHI).legalFor(allPtrsScalarsAndVectors);
175
176 getActionDefinitionsBuilder(G_BITCAST).legalIf(all(
177 typeInSet(0, allPtrsScalarsAndVectors),
178 typeInSet(1, allPtrsScalarsAndVectors),
179 LegalityPredicate(([=](const LegalityQuery &Query) {
180 return Query.Types[0].getSizeInBits() == Query.Types[1].getSizeInBits();
181 }))));
182
183 getActionDefinitionsBuilder(G_IMPLICIT_DEF).alwaysLegal();
184
185 getActionDefinitionsBuilder(G_INTTOPTR)
186 .legalForCartesianProduct(allPtrs, allIntScalars);
187 getActionDefinitionsBuilder(G_PTRTOINT)
188 .legalForCartesianProduct(allIntScalars, allPtrs);
189 getActionDefinitionsBuilder(G_PTR_ADD).legalForCartesianProduct(
190 allPtrs, allIntScalars);
191
192 // ST.canDirectlyComparePointers() for pointer args is supported in
193 // legalizeCustom().
194 getActionDefinitionsBuilder(G_ICMP).customIf(
195 all(typeInSet(0, allBoolScalarsAndVectors),
196 typeInSet(1, allPtrsScalarsAndVectors)));
197
198 getActionDefinitionsBuilder(G_FCMP).legalIf(
199 all(typeInSet(0, allBoolScalarsAndVectors),
200 typeInSet(1, allFloatScalarsAndVectors)));
201
202 getActionDefinitionsBuilder({G_ATOMICRMW_OR, G_ATOMICRMW_ADD, G_ATOMICRMW_AND,
203 G_ATOMICRMW_MAX, G_ATOMICRMW_MIN,
204 G_ATOMICRMW_SUB, G_ATOMICRMW_XOR,
205 G_ATOMICRMW_UMAX, G_ATOMICRMW_UMIN})
206 .legalForCartesianProduct(allIntScalars, allWritablePtrs);
207
208 getActionDefinitionsBuilder(G_ATOMICRMW_XCHG)
209 .legalForCartesianProduct(allFloatAndIntScalars, allWritablePtrs);
210
211 getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG_WITH_SUCCESS).lower();
212 // TODO: add proper legalization rules.
213 getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG).alwaysLegal();
214
215 getActionDefinitionsBuilder({G_UADDO, G_USUBO, G_SMULO, G_UMULO})
216 .alwaysLegal();
217
218 // Extensions.
219 getActionDefinitionsBuilder({G_TRUNC, G_ZEXT, G_SEXT, G_ANYEXT})
220 .legalForCartesianProduct(allScalarsAndVectors);
221
222 // FP conversions.
223 getActionDefinitionsBuilder({G_FPTRUNC, G_FPEXT})
224 .legalForCartesianProduct(allFloatScalarsAndVectors);
225
226 // Pointer-handling.
227 getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0});
228
229 // Control-flow. In some cases (e.g. constants) s1 may be promoted to s32.
230 getActionDefinitionsBuilder(G_BRCOND).legalFor({s1, s32});
231
232 // TODO: Review the target OpenCL and GLSL Extended Instruction Set specs to
233 // tighten these requirements. Many of these math functions are only legal on
234 // specific bitwidths, so they are not selectable for
235 // allFloatScalarsAndVectors.
236 getActionDefinitionsBuilder({G_FPOW,
237 G_FEXP,
238 G_FEXP2,
239 G_FLOG,
240 G_FLOG2,
241 G_FLOG10,
242 G_FABS,
243 G_FMINNUM,
244 G_FMAXNUM,
245 G_FCEIL,
246 G_FCOS,
247 G_FSIN,
248 G_FSQRT,
249 G_FFLOOR,
250 G_FRINT,
251 G_FNEARBYINT,
252 G_INTRINSIC_ROUND,
253 G_INTRINSIC_TRUNC,
254 G_FMINIMUM,
255 G_FMAXIMUM,
256 G_INTRINSIC_ROUNDEVEN})
257 .legalFor(allFloatScalarsAndVectors);
258
259 getActionDefinitionsBuilder(G_FCOPYSIGN)
260 .legalForCartesianProduct(allFloatScalarsAndVectors,
261 allFloatScalarsAndVectors);
262
263 getActionDefinitionsBuilder(G_FPOWI).legalForCartesianProduct(
264 allFloatScalarsAndVectors, allIntScalarsAndVectors);
265
266 if (ST.canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {
267 getActionDefinitionsBuilder(
268 {G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTLZ, G_CTLZ_ZERO_UNDEF})
269 .legalForCartesianProduct(allIntScalarsAndVectors,
270 allIntScalarsAndVectors);
271
272 // Struct return types become a single scalar, so cannot easily legalize.
273 getActionDefinitionsBuilder({G_SMULH, G_UMULH}).alwaysLegal();
274 }
275
276 getLegacyLegalizerInfo().computeTables();
277 verify(*ST.getInstrInfo());
278 }
279
convertPtrToInt(Register Reg,LLT ConvTy,SPIRVType * SpirvType,LegalizerHelper & Helper,MachineRegisterInfo & MRI,SPIRVGlobalRegistry * GR)280 static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpirvType,
281 LegalizerHelper &Helper,
282 MachineRegisterInfo &MRI,
283 SPIRVGlobalRegistry *GR) {
284 Register ConvReg = MRI.createGenericVirtualRegister(ConvTy);
285 GR->assignSPIRVTypeToVReg(SpirvType, ConvReg, Helper.MIRBuilder.getMF());
286 Helper.MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT)
287 .addDef(ConvReg)
288 .addUse(Reg);
289 return ConvReg;
290 }
291
legalizeCustom(LegalizerHelper & Helper,MachineInstr & MI,LostDebugLocObserver & LocObserver) const292 bool SPIRVLegalizerInfo::legalizeCustom(
293 LegalizerHelper &Helper, MachineInstr &MI,
294 LostDebugLocObserver &LocObserver) const {
295 auto Opc = MI.getOpcode();
296 MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
297 if (!isTypeFoldingSupported(Opc)) {
298 assert(Opc == TargetOpcode::G_ICMP);
299 assert(GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
300 auto &Op0 = MI.getOperand(2);
301 auto &Op1 = MI.getOperand(3);
302 Register Reg0 = Op0.getReg();
303 Register Reg1 = Op1.getReg();
304 CmpInst::Predicate Cond =
305 static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
306 if ((!ST->canDirectlyComparePointers() ||
307 (Cond != CmpInst::ICMP_EQ && Cond != CmpInst::ICMP_NE)) &&
308 MRI.getType(Reg0).isPointer() && MRI.getType(Reg1).isPointer()) {
309 LLT ConvT = LLT::scalar(ST->getPointerSize());
310 Type *LLVMTy = IntegerType::get(MI.getMF()->getFunction().getContext(),
311 ST->getPointerSize());
312 SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(LLVMTy, Helper.MIRBuilder);
313 Op0.setReg(convertPtrToInt(Reg0, ConvT, SpirvTy, Helper, MRI, GR));
314 Op1.setReg(convertPtrToInt(Reg1, ConvT, SpirvTy, Helper, MRI, GR));
315 }
316 return true;
317 }
318 // TODO: implement legalization for other opcodes.
319 return true;
320 }
321