1 //===- GPUToSPIRV.cpp - GPU to SPIR-V Patterns ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert GPU dialect to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
14 #include "mlir/Dialect/GPU/GPUDialect.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
18 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
19 #include "mlir/IR/BuiltinOps.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 #include "llvm/ADT/StringSwitch.h"
22 
23 using namespace mlir;
24 
/// Prefix prepended to a gpu.module's name to build the name of the
/// spv.module that replaces it, so the two symbols do not clash.
static constexpr char kSPIRVModule[] = "__spv__";
26 
27 namespace {
28 /// Pattern lowering GPU block/thread size/id to loading SPIR-V invocation
29 /// builtin variables.
30 template <typename SourceOp, spirv::BuiltIn builtin>
31 class LaunchConfigConversion : public OpConversionPattern<SourceOp> {
32 public:
33   using OpConversionPattern<SourceOp>::OpConversionPattern;
34 
35   LogicalResult
36   matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
37                   ConversionPatternRewriter &rewriter) const override;
38 };
39 
40 /// Pattern lowering subgroup size/id to loading SPIR-V invocation
41 /// builtin variables.
42 template <typename SourceOp, spirv::BuiltIn builtin>
43 class SingleDimLaunchConfigConversion : public OpConversionPattern<SourceOp> {
44 public:
45   using OpConversionPattern<SourceOp>::OpConversionPattern;
46 
47   LogicalResult
48   matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
49                   ConversionPatternRewriter &rewriter) const override;
50 };
51 
52 /// This is separate because in Vulkan workgroup size is exposed to shaders via
53 /// a constant with WorkgroupSize decoration. So here we cannot generate a
54 /// builtin variable; instead the information in the `spv.entry_point_abi`
55 /// attribute on the surrounding FuncOp is used to replace the gpu::BlockDimOp.
56 class WorkGroupSizeConversion : public OpConversionPattern<gpu::BlockDimOp> {
57 public:
58   using OpConversionPattern<gpu::BlockDimOp>::OpConversionPattern;
59 
60   LogicalResult
61   matchAndRewrite(gpu::BlockDimOp op, ArrayRef<Value> operands,
62                   ConversionPatternRewriter &rewriter) const override;
63 };
64 
65 /// Pattern to convert a kernel function in GPU dialect within a spv.module.
66 class GPUFuncOpConversion final : public OpConversionPattern<gpu::GPUFuncOp> {
67 public:
68   using OpConversionPattern<gpu::GPUFuncOp>::OpConversionPattern;
69 
70   LogicalResult
71   matchAndRewrite(gpu::GPUFuncOp funcOp, ArrayRef<Value> operands,
72                   ConversionPatternRewriter &rewriter) const override;
73 
74 private:
75   SmallVector<int32_t, 3> workGroupSizeAsInt32;
76 };
77 
78 /// Pattern to convert a gpu.module to a spv.module.
79 class GPUModuleConversion final : public OpConversionPattern<gpu::GPUModuleOp> {
80 public:
81   using OpConversionPattern<gpu::GPUModuleOp>::OpConversionPattern;
82 
83   LogicalResult
84   matchAndRewrite(gpu::GPUModuleOp moduleOp, ArrayRef<Value> operands,
85                   ConversionPatternRewriter &rewriter) const override;
86 };
87 
88 class GPUModuleEndConversion final
89     : public OpConversionPattern<gpu::ModuleEndOp> {
90 public:
91   using OpConversionPattern::OpConversionPattern;
92 
93   LogicalResult
matchAndRewrite(gpu::ModuleEndOp endOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const94   matchAndRewrite(gpu::ModuleEndOp endOp, ArrayRef<Value> operands,
95                   ConversionPatternRewriter &rewriter) const override {
96     rewriter.eraseOp(endOp);
97     return success();
98   }
99 };
100 
101 /// Pattern to convert a gpu.return into a SPIR-V return.
102 // TODO: This can go to DRR when GPU return has operands.
103 class GPUReturnOpConversion final : public OpConversionPattern<gpu::ReturnOp> {
104 public:
105   using OpConversionPattern<gpu::ReturnOp>::OpConversionPattern;
106 
107   LogicalResult
108   matchAndRewrite(gpu::ReturnOp returnOp, ArrayRef<Value> operands,
109                   ConversionPatternRewriter &rewriter) const override;
110 };
111 
112 } // namespace
113 
114 //===----------------------------------------------------------------------===//
115 // Builtins.
116 //===----------------------------------------------------------------------===//
117 
getLaunchConfigIndex(Operation * op)118 static Optional<int32_t> getLaunchConfigIndex(Operation *op) {
119   auto dimAttr = op->getAttrOfType<StringAttr>("dimension");
120   if (!dimAttr)
121     return llvm::None;
122 
123   return llvm::StringSwitch<Optional<int32_t>>(dimAttr.getValue())
124       .Case("x", 0)
125       .Case("y", 1)
126       .Case("z", 2)
127       .Default(llvm::None);
128 }
129 
130 template <typename SourceOp, spirv::BuiltIn builtin>
matchAndRewrite(SourceOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const131 LogicalResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
132     SourceOp op, ArrayRef<Value> operands,
133     ConversionPatternRewriter &rewriter) const {
134   auto index = getLaunchConfigIndex(op);
135   if (!index)
136     return failure();
137 
138   // SPIR-V invocation builtin variables are a vector of type <3xi32>
139   auto spirvBuiltin = spirv::getBuiltinVariableValue(op, builtin, rewriter);
140   rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
141       op, rewriter.getIntegerType(32), spirvBuiltin,
142       rewriter.getI32ArrayAttr({index.getValue()}));
143   return success();
144 }
145 
146 template <typename SourceOp, spirv::BuiltIn builtin>
147 LogicalResult
matchAndRewrite(SourceOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const148 SingleDimLaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
149     SourceOp op, ArrayRef<Value> operands,
150     ConversionPatternRewriter &rewriter) const {
151   auto spirvBuiltin = spirv::getBuiltinVariableValue(op, builtin, rewriter);
152   rewriter.replaceOp(op, spirvBuiltin);
153   return success();
154 }
155 
matchAndRewrite(gpu::BlockDimOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const156 LogicalResult WorkGroupSizeConversion::matchAndRewrite(
157     gpu::BlockDimOp op, ArrayRef<Value> operands,
158     ConversionPatternRewriter &rewriter) const {
159   auto index = getLaunchConfigIndex(op);
160   if (!index)
161     return failure();
162 
163   auto workGroupSizeAttr = spirv::lookupLocalWorkGroupSize(op);
164   auto val = workGroupSizeAttr.getValue<int32_t>(index.getValue());
165   auto convertedType =
166       getTypeConverter()->convertType(op.getResult().getType());
167   if (!convertedType)
168     return failure();
169   rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
170       op, convertedType, IntegerAttr::get(convertedType, val));
171   return success();
172 }
173 
174 //===----------------------------------------------------------------------===//
175 // GPUFuncOp
176 //===----------------------------------------------------------------------===//
177 
178 // Legalizes a GPU function as an entry SPIR-V function.
179 static spirv::FuncOp
lowerAsEntryFunction(gpu::GPUFuncOp funcOp,TypeConverter & typeConverter,ConversionPatternRewriter & rewriter,spirv::EntryPointABIAttr entryPointInfo,ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo)180 lowerAsEntryFunction(gpu::GPUFuncOp funcOp, TypeConverter &typeConverter,
181                      ConversionPatternRewriter &rewriter,
182                      spirv::EntryPointABIAttr entryPointInfo,
183                      ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo) {
184   auto fnType = funcOp.getType();
185   if (fnType.getNumResults()) {
186     funcOp.emitError("SPIR-V lowering only supports entry functions"
187                      "with no return values right now");
188     return nullptr;
189   }
190   if (!argABIInfo.empty() && fnType.getNumInputs() != argABIInfo.size()) {
191     funcOp.emitError(
192         "lowering as entry functions requires ABI info for all arguments "
193         "or none of them");
194     return nullptr;
195   }
196   // Update the signature to valid SPIR-V types and add the ABI
197   // attributes. These will be "materialized" by using the
198   // LowerABIAttributesPass.
199   TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
200   {
201     for (auto argType : enumerate(funcOp.getType().getInputs())) {
202       auto convertedType = typeConverter.convertType(argType.value());
203       signatureConverter.addInputs(argType.index(), convertedType);
204     }
205   }
206   auto newFuncOp = rewriter.create<spirv::FuncOp>(
207       funcOp.getLoc(), funcOp.getName(),
208       rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
209                                llvm::None));
210   for (const auto &namedAttr : funcOp->getAttrs()) {
211     if (namedAttr.first == function_like_impl::getTypeAttrName() ||
212         namedAttr.first == SymbolTable::getSymbolAttrName())
213       continue;
214     newFuncOp->setAttr(namedAttr.first, namedAttr.second);
215   }
216 
217   rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
218                               newFuncOp.end());
219   if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
220                                          &signatureConverter)))
221     return nullptr;
222   rewriter.eraseOp(funcOp);
223 
224   // Set the attributes for argument and the function.
225   StringRef argABIAttrName = spirv::getInterfaceVarABIAttrName();
226   for (auto argIndex : llvm::seq<unsigned>(0, argABIInfo.size())) {
227     newFuncOp.setArgAttr(argIndex, argABIAttrName, argABIInfo[argIndex]);
228   }
229   newFuncOp->setAttr(spirv::getEntryPointABIAttrName(), entryPointInfo);
230 
231   return newFuncOp;
232 }
233 
234 /// Populates `argABI` with spv.interface_var_abi attributes for lowering
235 /// gpu.func to spv.func if no arguments have the attributes set
236 /// already. Returns failure if any argument has the ABI attribute set already.
237 static LogicalResult
getDefaultABIAttrs(MLIRContext * context,gpu::GPUFuncOp funcOp,SmallVectorImpl<spirv::InterfaceVarABIAttr> & argABI)238 getDefaultABIAttrs(MLIRContext *context, gpu::GPUFuncOp funcOp,
239                    SmallVectorImpl<spirv::InterfaceVarABIAttr> &argABI) {
240   spirv::TargetEnvAttr targetEnv = spirv::lookupTargetEnvOrDefault(funcOp);
241   if (!spirv::needsInterfaceVarABIAttrs(targetEnv))
242     return success();
243 
244   for (auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
245     if (funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
246             argIndex, spirv::getInterfaceVarABIAttrName()))
247       return failure();
248     // Vulkan's interface variable requirements needs scalars to be wrapped in a
249     // struct. The struct held in storage buffer.
250     Optional<spirv::StorageClass> sc;
251     if (funcOp.getArgument(argIndex).getType().isIntOrIndexOrFloat())
252       sc = spirv::StorageClass::StorageBuffer;
253     argABI.push_back(spirv::getInterfaceVarABIAttr(0, argIndex, sc, context));
254   }
255   return success();
256 }
257 
matchAndRewrite(gpu::GPUFuncOp funcOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const258 LogicalResult GPUFuncOpConversion::matchAndRewrite(
259     gpu::GPUFuncOp funcOp, ArrayRef<Value> operands,
260     ConversionPatternRewriter &rewriter) const {
261   if (!gpu::GPUDialect::isKernel(funcOp))
262     return failure();
263 
264   SmallVector<spirv::InterfaceVarABIAttr, 4> argABI;
265   if (failed(getDefaultABIAttrs(rewriter.getContext(), funcOp, argABI))) {
266     argABI.clear();
267     for (auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
268       // If the ABI is already specified, use it.
269       auto abiAttr = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
270           argIndex, spirv::getInterfaceVarABIAttrName());
271       if (!abiAttr) {
272         funcOp.emitRemark(
273             "match failure: missing 'spv.interface_var_abi' attribute at "
274             "argument ")
275             << argIndex;
276         return failure();
277       }
278       argABI.push_back(abiAttr);
279     }
280   }
281 
282   auto entryPointAttr = spirv::lookupEntryPointABI(funcOp);
283   if (!entryPointAttr) {
284     funcOp.emitRemark("match failure: missing 'spv.entry_point_abi' attribute");
285     return failure();
286   }
287   spirv::FuncOp newFuncOp = lowerAsEntryFunction(
288       funcOp, *getTypeConverter(), rewriter, entryPointAttr, argABI);
289   if (!newFuncOp)
290     return failure();
291   newFuncOp->removeAttr(Identifier::get(
292       gpu::GPUDialect::getKernelFuncAttrName(), rewriter.getContext()));
293   return success();
294 }
295 
296 //===----------------------------------------------------------------------===//
297 // ModuleOp with gpu.module.
298 //===----------------------------------------------------------------------===//
299 
matchAndRewrite(gpu::GPUModuleOp moduleOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const300 LogicalResult GPUModuleConversion::matchAndRewrite(
301     gpu::GPUModuleOp moduleOp, ArrayRef<Value> operands,
302     ConversionPatternRewriter &rewriter) const {
303   spirv::TargetEnvAttr targetEnv = spirv::lookupTargetEnvOrDefault(moduleOp);
304   spirv::AddressingModel addressingModel = spirv::getAddressingModel(targetEnv);
305   FailureOr<spirv::MemoryModel> memoryModel = spirv::getMemoryModel(targetEnv);
306   if (failed(memoryModel))
307     return moduleOp.emitRemark("match failure: could not selected memory model "
308                                "based on 'spv.target_env'");
309 
310   // Add a keyword to the module name to avoid symbolic conflict.
311   std::string spvModuleName = (kSPIRVModule + moduleOp.getName()).str();
312   auto spvModule = rewriter.create<spirv::ModuleOp>(
313       moduleOp.getLoc(), addressingModel, memoryModel.getValue(),
314       StringRef(spvModuleName));
315 
316   // Move the region from the module op into the SPIR-V module.
317   Region &spvModuleRegion = spvModule.getRegion();
318   rewriter.inlineRegionBefore(moduleOp.body(), spvModuleRegion,
319                               spvModuleRegion.begin());
320   // The spv.module build method adds a block. Remove that.
321   rewriter.eraseBlock(&spvModuleRegion.back());
322   rewriter.eraseOp(moduleOp);
323   return success();
324 }
325 
326 //===----------------------------------------------------------------------===//
327 // GPU return inside kernel functions to SPIR-V return.
328 //===----------------------------------------------------------------------===//
329 
matchAndRewrite(gpu::ReturnOp returnOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const330 LogicalResult GPUReturnOpConversion::matchAndRewrite(
331     gpu::ReturnOp returnOp, ArrayRef<Value> operands,
332     ConversionPatternRewriter &rewriter) const {
333   if (!operands.empty())
334     return failure();
335 
336   rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
337   return success();
338 }
339 
340 //===----------------------------------------------------------------------===//
341 // GPU To SPIRV Patterns.
342 //===----------------------------------------------------------------------===//
343 
populateGPUToSPIRVPatterns(SPIRVTypeConverter & typeConverter,RewritePatternSet & patterns)344 void mlir::populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
345                                       RewritePatternSet &patterns) {
346   patterns.add<
347       GPUFuncOpConversion, GPUModuleConversion, GPUModuleEndConversion,
348       GPUReturnOpConversion,
349       LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
350       LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
351       LaunchConfigConversion<gpu::ThreadIdOp,
352                              spirv::BuiltIn::LocalInvocationId>,
353       SingleDimLaunchConfigConversion<gpu::SubgroupIdOp,
354                                       spirv::BuiltIn::SubgroupId>,
355       SingleDimLaunchConfigConversion<gpu::NumSubgroupsOp,
356                                       spirv::BuiltIn::NumSubgroups>,
357       SingleDimLaunchConfigConversion<gpu::SubgroupSizeOp,
358                                       spirv::BuiltIn::SubgroupSize>,
359       WorkGroupSizeConversion>(typeConverter, patterns.getContext());
360 }
361