1 //===- SPIRVLegalizerInfo.cpp --- SPIR-V Legalization Rules ------*- C++ -*-==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements the GlobalISel legalization rules (SPIRVLegalizerInfo) for SPIR-V.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "SPIRVLegalizerInfo.h"
14 #include "SPIRV.h"
15 #include "SPIRVGlobalRegistry.h"
16 #include "SPIRVSubtarget.h"
17 #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
18 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
19 #include "llvm/CodeGen/MachineInstr.h"
20 #include "llvm/CodeGen/MachineRegisterInfo.h"
21 #include "llvm/CodeGen/TargetOpcodes.h"
22 
23 using namespace llvm;
24 using namespace llvm::LegalizeActions;
25 using namespace llvm::LegalityPredicates;
26 
27 static const std::set<unsigned> TypeFoldingSupportingOpcs = {
28     TargetOpcode::G_ADD,
29     TargetOpcode::G_FADD,
30     TargetOpcode::G_SUB,
31     TargetOpcode::G_FSUB,
32     TargetOpcode::G_MUL,
33     TargetOpcode::G_FMUL,
34     TargetOpcode::G_SDIV,
35     TargetOpcode::G_UDIV,
36     TargetOpcode::G_FDIV,
37     TargetOpcode::G_SREM,
38     TargetOpcode::G_UREM,
39     TargetOpcode::G_FREM,
40     TargetOpcode::G_FNEG,
41     TargetOpcode::G_CONSTANT,
42     TargetOpcode::G_FCONSTANT,
43     TargetOpcode::G_AND,
44     TargetOpcode::G_OR,
45     TargetOpcode::G_XOR,
46     TargetOpcode::G_SHL,
47     TargetOpcode::G_ASHR,
48     TargetOpcode::G_LSHR,
49     TargetOpcode::G_SELECT,
50     TargetOpcode::G_EXTRACT_VECTOR_ELT,
51 };
52 
53 bool isTypeFoldingSupported(unsigned Opcode) {
54   return TypeFoldingSupportingOpcs.count(Opcode) > 0;
55 }
56 
57 SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
58   using namespace TargetOpcode;
59 
60   this->ST = &ST;
61   GR = ST.getSPIRVGlobalRegistry();
62 
63   const LLT s1 = LLT::scalar(1);
64   const LLT s8 = LLT::scalar(8);
65   const LLT s16 = LLT::scalar(16);
66   const LLT s32 = LLT::scalar(32);
67   const LLT s64 = LLT::scalar(64);
68 
69   const LLT v16s64 = LLT::fixed_vector(16, 64);
70   const LLT v16s32 = LLT::fixed_vector(16, 32);
71   const LLT v16s16 = LLT::fixed_vector(16, 16);
72   const LLT v16s8 = LLT::fixed_vector(16, 8);
73   const LLT v16s1 = LLT::fixed_vector(16, 1);
74 
75   const LLT v8s64 = LLT::fixed_vector(8, 64);
76   const LLT v8s32 = LLT::fixed_vector(8, 32);
77   const LLT v8s16 = LLT::fixed_vector(8, 16);
78   const LLT v8s8 = LLT::fixed_vector(8, 8);
79   const LLT v8s1 = LLT::fixed_vector(8, 1);
80 
81   const LLT v4s64 = LLT::fixed_vector(4, 64);
82   const LLT v4s32 = LLT::fixed_vector(4, 32);
83   const LLT v4s16 = LLT::fixed_vector(4, 16);
84   const LLT v4s8 = LLT::fixed_vector(4, 8);
85   const LLT v4s1 = LLT::fixed_vector(4, 1);
86 
87   const LLT v3s64 = LLT::fixed_vector(3, 64);
88   const LLT v3s32 = LLT::fixed_vector(3, 32);
89   const LLT v3s16 = LLT::fixed_vector(3, 16);
90   const LLT v3s8 = LLT::fixed_vector(3, 8);
91   const LLT v3s1 = LLT::fixed_vector(3, 1);
92 
93   const LLT v2s64 = LLT::fixed_vector(2, 64);
94   const LLT v2s32 = LLT::fixed_vector(2, 32);
95   const LLT v2s16 = LLT::fixed_vector(2, 16);
96   const LLT v2s8 = LLT::fixed_vector(2, 8);
97   const LLT v2s1 = LLT::fixed_vector(2, 1);
98 
99   const unsigned PSize = ST.getPointerSize();
100   const LLT p0 = LLT::pointer(0, PSize); // Function
101   const LLT p1 = LLT::pointer(1, PSize); // CrossWorkgroup
102   const LLT p2 = LLT::pointer(2, PSize); // UniformConstant
103   const LLT p3 = LLT::pointer(3, PSize); // Workgroup
104   const LLT p4 = LLT::pointer(4, PSize); // Generic
105   const LLT p5 = LLT::pointer(5, PSize); // Input
106 
107   // TODO: remove copy-pasting here by using concatenation in some way.
108   auto allPtrsScalarsAndVectors = {
109       p0,    p1,    p2,    p3,    p4,    p5,    s1,     s8,     s16,
110       s32,   s64,   v2s1,  v2s8,  v2s16, v2s32, v2s64,  v3s1,   v3s8,
111       v3s16, v3s32, v3s64, v4s1,  v4s8,  v4s16, v4s32,  v4s64,  v8s1,
112       v8s8,  v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
113 
114   auto allScalarsAndVectors = {
115       s1,   s8,   s16,   s32,   s64,   v2s1,  v2s8,  v2s16,  v2s32,  v2s64,
116       v3s1, v3s8, v3s16, v3s32, v3s64, v4s1,  v4s8,  v4s16,  v4s32,  v4s64,
117       v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
118 
119   auto allIntScalarsAndVectors = {s8,    s16,   s32,   s64,    v2s8,   v2s16,
120                                   v2s32, v2s64, v3s8,  v3s16,  v3s32,  v3s64,
121                                   v4s8,  v4s16, v4s32, v4s64,  v8s8,   v8s16,
122                                   v8s32, v8s64, v16s8, v16s16, v16s32, v16s64};
123 
124   auto allBoolScalarsAndVectors = {s1, v2s1, v3s1, v4s1, v8s1, v16s1};
125 
126   auto allIntScalars = {s8, s16, s32, s64};
127 
128   auto allFloatScalarsAndVectors = {
129       s16,   s32,   s64,   v2s16, v2s32, v2s64, v3s16,  v3s32,  v3s64,
130       v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64};
131 
132   auto allFloatAndIntScalars = allIntScalars;
133 
134   auto allPtrs = {p0, p1, p2, p3, p4, p5};
135   auto allWritablePtrs = {p0, p1, p3, p4};
136 
137   for (auto Opc : TypeFoldingSupportingOpcs)
138     getActionDefinitionsBuilder(Opc).custom();
139 
140   getActionDefinitionsBuilder(G_GLOBAL_VALUE).alwaysLegal();
141 
142   // TODO: add proper rules for vectors legalization.
143   getActionDefinitionsBuilder({G_BUILD_VECTOR, G_SHUFFLE_VECTOR}).alwaysLegal();
144 
145   getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE})
146       .legalIf(all(typeInSet(0, allWritablePtrs), typeInSet(1, allPtrs)));
147 
148   getActionDefinitionsBuilder(G_ADDRSPACE_CAST)
149       .legalForCartesianProduct(allPtrs, allPtrs);
150 
151   getActionDefinitionsBuilder({G_LOAD, G_STORE}).legalIf(typeInSet(1, allPtrs));
152 
153   getActionDefinitionsBuilder(G_BITREVERSE).legalFor(allFloatScalarsAndVectors);
154 
155   getActionDefinitionsBuilder(G_FMA).legalFor(allFloatScalarsAndVectors);
156 
157   getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI})
158       .legalForCartesianProduct(allIntScalarsAndVectors,
159                                 allFloatScalarsAndVectors);
160 
161   getActionDefinitionsBuilder({G_SITOFP, G_UITOFP})
162       .legalForCartesianProduct(allFloatScalarsAndVectors,
163                                 allScalarsAndVectors);
164 
165   getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX, G_ABS})
166       .legalFor(allIntScalarsAndVectors);
167 
168   getActionDefinitionsBuilder(G_CTPOP).legalForCartesianProduct(
169       allIntScalarsAndVectors, allIntScalarsAndVectors);
170 
171   getActionDefinitionsBuilder(G_PHI).legalFor(allPtrsScalarsAndVectors);
172 
173   getActionDefinitionsBuilder(G_BITCAST).legalIf(all(
174       typeInSet(0, allPtrsScalarsAndVectors),
175       typeInSet(1, allPtrsScalarsAndVectors),
176       LegalityPredicate(([=](const LegalityQuery &Query) {
177         return Query.Types[0].getSizeInBits() == Query.Types[1].getSizeInBits();
178       }))));
179 
180   getActionDefinitionsBuilder(G_IMPLICIT_DEF).alwaysLegal();
181 
182   getActionDefinitionsBuilder(G_INTTOPTR)
183       .legalForCartesianProduct(allPtrs, allIntScalars);
184   getActionDefinitionsBuilder(G_PTRTOINT)
185       .legalForCartesianProduct(allIntScalars, allPtrs);
186   getActionDefinitionsBuilder(G_PTR_ADD).legalForCartesianProduct(
187       allPtrs, allIntScalars);
188 
189   // ST.canDirectlyComparePointers() for pointer args is supported in
190   // legalizeCustom().
191   getActionDefinitionsBuilder(G_ICMP).customIf(
192       all(typeInSet(0, allBoolScalarsAndVectors),
193           typeInSet(1, allPtrsScalarsAndVectors)));
194 
195   getActionDefinitionsBuilder(G_FCMP).legalIf(
196       all(typeInSet(0, allBoolScalarsAndVectors),
197           typeInSet(1, allFloatScalarsAndVectors)));
198 
199   getActionDefinitionsBuilder({G_ATOMICRMW_OR, G_ATOMICRMW_ADD, G_ATOMICRMW_AND,
200                                G_ATOMICRMW_MAX, G_ATOMICRMW_MIN,
201                                G_ATOMICRMW_SUB, G_ATOMICRMW_XOR,
202                                G_ATOMICRMW_UMAX, G_ATOMICRMW_UMIN})
203       .legalForCartesianProduct(allIntScalars, allWritablePtrs);
204 
205   getActionDefinitionsBuilder(G_ATOMICRMW_XCHG)
206       .legalForCartesianProduct(allFloatAndIntScalars, allWritablePtrs);
207 
208   getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG_WITH_SUCCESS).lower();
209   // TODO: add proper legalization rules.
210   getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG).alwaysLegal();
211 
212   getActionDefinitionsBuilder({G_UADDO, G_USUBO, G_SMULO, G_UMULO})
213       .alwaysLegal();
214 
215   // Extensions.
216   getActionDefinitionsBuilder({G_TRUNC, G_ZEXT, G_SEXT, G_ANYEXT})
217       .legalForCartesianProduct(allScalarsAndVectors);
218 
219   // FP conversions.
220   getActionDefinitionsBuilder({G_FPTRUNC, G_FPEXT})
221       .legalForCartesianProduct(allFloatScalarsAndVectors);
222 
223   // Pointer-handling.
224   getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0});
225 
226   // Control-flow.
227   getActionDefinitionsBuilder(G_BRCOND).legalFor({s1});
228 
229   getActionDefinitionsBuilder({G_FPOW,
230                                G_FEXP,
231                                G_FEXP2,
232                                G_FLOG,
233                                G_FLOG2,
234                                G_FABS,
235                                G_FMINNUM,
236                                G_FMAXNUM,
237                                G_FCEIL,
238                                G_FCOS,
239                                G_FSIN,
240                                G_FSQRT,
241                                G_FFLOOR,
242                                G_FRINT,
243                                G_FNEARBYINT,
244                                G_INTRINSIC_ROUND,
245                                G_INTRINSIC_TRUNC,
246                                G_FMINIMUM,
247                                G_FMAXIMUM,
248                                G_INTRINSIC_ROUNDEVEN})
249       .legalFor(allFloatScalarsAndVectors);
250 
251   getActionDefinitionsBuilder(G_FCOPYSIGN)
252       .legalForCartesianProduct(allFloatScalarsAndVectors,
253                                 allFloatScalarsAndVectors);
254 
255   getActionDefinitionsBuilder(G_FPOWI).legalForCartesianProduct(
256       allFloatScalarsAndVectors, allIntScalarsAndVectors);
257 
258   getLegacyLegalizerInfo().computeTables();
259   verify(*ST.getInstrInfo());
260 }
261 
262 static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpirvType,
263                                 LegalizerHelper &Helper,
264                                 MachineRegisterInfo &MRI,
265                                 SPIRVGlobalRegistry *GR) {
266   Register ConvReg = MRI.createGenericVirtualRegister(ConvTy);
267   GR->assignSPIRVTypeToVReg(SpirvType, ConvReg, Helper.MIRBuilder.getMF());
268   Helper.MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT)
269       .addDef(ConvReg)
270       .addUse(Reg);
271   return ConvReg;
272 }
273 
274 bool SPIRVLegalizerInfo::legalizeCustom(LegalizerHelper &Helper,
275                                         MachineInstr &MI) const {
276   auto Opc = MI.getOpcode();
277   MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
278   if (!isTypeFoldingSupported(Opc)) {
279     assert(Opc == TargetOpcode::G_ICMP);
280     assert(GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
281     auto &Op0 = MI.getOperand(2);
282     auto &Op1 = MI.getOperand(3);
283     Register Reg0 = Op0.getReg();
284     Register Reg1 = Op1.getReg();
285     CmpInst::Predicate Cond =
286         static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
287     if ((!ST->canDirectlyComparePointers() ||
288          (Cond != CmpInst::ICMP_EQ && Cond != CmpInst::ICMP_NE)) &&
289         MRI.getType(Reg0).isPointer() && MRI.getType(Reg1).isPointer()) {
290       LLT ConvT = LLT::scalar(ST->getPointerSize());
291       Type *LLVMTy = IntegerType::get(MI.getMF()->getFunction().getContext(),
292                                       ST->getPointerSize());
293       SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(LLVMTy, Helper.MIRBuilder);
294       Op0.setReg(convertPtrToInt(Reg0, ConvT, SpirvTy, Helper, MRI, GR));
295       Op1.setReg(convertPtrToInt(Reg1, ConvT, SpirvTy, Helper, MRI, GR));
296     }
297     return true;
298   }
299   // TODO: implement legalization for other opcodes.
300   return true;
301 }
302