1 //===- GPUToSPIRV.cpp - GPU to SPIR-V Patterns ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert GPU dialect to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
14 #include "mlir/Dialect/GPU/GPUDialect.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
18 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
19 #include "mlir/IR/BuiltinOps.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 #include "llvm/ADT/StringSwitch.h"
22
23 using namespace mlir;
24
/// Prefix prepended to a gpu.module's name when it is lowered to a spv.module,
/// so that the new module's symbol does not conflict with the original one.
static constexpr char kSPIRVModule[] = "__spv__";
26
27 namespace {
28 /// Pattern lowering GPU block/thread size/id to loading SPIR-V invocation
29 /// builtin variables.
30 template <typename SourceOp, spirv::BuiltIn builtin>
31 class LaunchConfigConversion : public OpConversionPattern<SourceOp> {
32 public:
33 using OpConversionPattern<SourceOp>::OpConversionPattern;
34
35 LogicalResult
36 matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
37 ConversionPatternRewriter &rewriter) const override;
38 };
39
40 /// Pattern lowering subgroup size/id to loading SPIR-V invocation
41 /// builtin variables.
42 template <typename SourceOp, spirv::BuiltIn builtin>
43 class SingleDimLaunchConfigConversion : public OpConversionPattern<SourceOp> {
44 public:
45 using OpConversionPattern<SourceOp>::OpConversionPattern;
46
47 LogicalResult
48 matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
49 ConversionPatternRewriter &rewriter) const override;
50 };
51
52 /// This is separate because in Vulkan workgroup size is exposed to shaders via
53 /// a constant with WorkgroupSize decoration. So here we cannot generate a
54 /// builtin variable; instead the information in the `spv.entry_point_abi`
55 /// attribute on the surrounding FuncOp is used to replace the gpu::BlockDimOp.
56 class WorkGroupSizeConversion : public OpConversionPattern<gpu::BlockDimOp> {
57 public:
58 using OpConversionPattern<gpu::BlockDimOp>::OpConversionPattern;
59
60 LogicalResult
61 matchAndRewrite(gpu::BlockDimOp op, ArrayRef<Value> operands,
62 ConversionPatternRewriter &rewriter) const override;
63 };
64
65 /// Pattern to convert a kernel function in GPU dialect within a spv.module.
66 class GPUFuncOpConversion final : public OpConversionPattern<gpu::GPUFuncOp> {
67 public:
68 using OpConversionPattern<gpu::GPUFuncOp>::OpConversionPattern;
69
70 LogicalResult
71 matchAndRewrite(gpu::GPUFuncOp funcOp, ArrayRef<Value> operands,
72 ConversionPatternRewriter &rewriter) const override;
73
74 private:
75 SmallVector<int32_t, 3> workGroupSizeAsInt32;
76 };
77
78 /// Pattern to convert a gpu.module to a spv.module.
79 class GPUModuleConversion final : public OpConversionPattern<gpu::GPUModuleOp> {
80 public:
81 using OpConversionPattern<gpu::GPUModuleOp>::OpConversionPattern;
82
83 LogicalResult
84 matchAndRewrite(gpu::GPUModuleOp moduleOp, ArrayRef<Value> operands,
85 ConversionPatternRewriter &rewriter) const override;
86 };
87
88 class GPUModuleEndConversion final
89 : public OpConversionPattern<gpu::ModuleEndOp> {
90 public:
91 using OpConversionPattern::OpConversionPattern;
92
93 LogicalResult
matchAndRewrite(gpu::ModuleEndOp endOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const94 matchAndRewrite(gpu::ModuleEndOp endOp, ArrayRef<Value> operands,
95 ConversionPatternRewriter &rewriter) const override {
96 rewriter.eraseOp(endOp);
97 return success();
98 }
99 };
100
101 /// Pattern to convert a gpu.return into a SPIR-V return.
102 // TODO: This can go to DRR when GPU return has operands.
103 class GPUReturnOpConversion final : public OpConversionPattern<gpu::ReturnOp> {
104 public:
105 using OpConversionPattern<gpu::ReturnOp>::OpConversionPattern;
106
107 LogicalResult
108 matchAndRewrite(gpu::ReturnOp returnOp, ArrayRef<Value> operands,
109 ConversionPatternRewriter &rewriter) const override;
110 };
111
112 } // namespace
113
114 //===----------------------------------------------------------------------===//
115 // Builtins.
116 //===----------------------------------------------------------------------===//
117
getLaunchConfigIndex(Operation * op)118 static Optional<int32_t> getLaunchConfigIndex(Operation *op) {
119 auto dimAttr = op->getAttrOfType<StringAttr>("dimension");
120 if (!dimAttr)
121 return llvm::None;
122
123 return llvm::StringSwitch<Optional<int32_t>>(dimAttr.getValue())
124 .Case("x", 0)
125 .Case("y", 1)
126 .Case("z", 2)
127 .Default(llvm::None);
128 }
129
130 template <typename SourceOp, spirv::BuiltIn builtin>
matchAndRewrite(SourceOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const131 LogicalResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
132 SourceOp op, ArrayRef<Value> operands,
133 ConversionPatternRewriter &rewriter) const {
134 auto index = getLaunchConfigIndex(op);
135 if (!index)
136 return failure();
137
138 // SPIR-V invocation builtin variables are a vector of type <3xi32>
139 auto spirvBuiltin = spirv::getBuiltinVariableValue(op, builtin, rewriter);
140 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
141 op, rewriter.getIntegerType(32), spirvBuiltin,
142 rewriter.getI32ArrayAttr({index.getValue()}));
143 return success();
144 }
145
146 template <typename SourceOp, spirv::BuiltIn builtin>
147 LogicalResult
matchAndRewrite(SourceOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const148 SingleDimLaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
149 SourceOp op, ArrayRef<Value> operands,
150 ConversionPatternRewriter &rewriter) const {
151 auto spirvBuiltin = spirv::getBuiltinVariableValue(op, builtin, rewriter);
152 rewriter.replaceOp(op, spirvBuiltin);
153 return success();
154 }
155
matchAndRewrite(gpu::BlockDimOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const156 LogicalResult WorkGroupSizeConversion::matchAndRewrite(
157 gpu::BlockDimOp op, ArrayRef<Value> operands,
158 ConversionPatternRewriter &rewriter) const {
159 auto index = getLaunchConfigIndex(op);
160 if (!index)
161 return failure();
162
163 auto workGroupSizeAttr = spirv::lookupLocalWorkGroupSize(op);
164 auto val = workGroupSizeAttr.getValue<int32_t>(index.getValue());
165 auto convertedType =
166 getTypeConverter()->convertType(op.getResult().getType());
167 if (!convertedType)
168 return failure();
169 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
170 op, convertedType, IntegerAttr::get(convertedType, val));
171 return success();
172 }
173
174 //===----------------------------------------------------------------------===//
175 // GPUFuncOp
176 //===----------------------------------------------------------------------===//
177
178 // Legalizes a GPU function as an entry SPIR-V function.
179 static spirv::FuncOp
lowerAsEntryFunction(gpu::GPUFuncOp funcOp,TypeConverter & typeConverter,ConversionPatternRewriter & rewriter,spirv::EntryPointABIAttr entryPointInfo,ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo)180 lowerAsEntryFunction(gpu::GPUFuncOp funcOp, TypeConverter &typeConverter,
181 ConversionPatternRewriter &rewriter,
182 spirv::EntryPointABIAttr entryPointInfo,
183 ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo) {
184 auto fnType = funcOp.getType();
185 if (fnType.getNumResults()) {
186 funcOp.emitError("SPIR-V lowering only supports entry functions"
187 "with no return values right now");
188 return nullptr;
189 }
190 if (!argABIInfo.empty() && fnType.getNumInputs() != argABIInfo.size()) {
191 funcOp.emitError(
192 "lowering as entry functions requires ABI info for all arguments "
193 "or none of them");
194 return nullptr;
195 }
196 // Update the signature to valid SPIR-V types and add the ABI
197 // attributes. These will be "materialized" by using the
198 // LowerABIAttributesPass.
199 TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
200 {
201 for (auto argType : enumerate(funcOp.getType().getInputs())) {
202 auto convertedType = typeConverter.convertType(argType.value());
203 signatureConverter.addInputs(argType.index(), convertedType);
204 }
205 }
206 auto newFuncOp = rewriter.create<spirv::FuncOp>(
207 funcOp.getLoc(), funcOp.getName(),
208 rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
209 llvm::None));
210 for (const auto &namedAttr : funcOp->getAttrs()) {
211 if (namedAttr.first == function_like_impl::getTypeAttrName() ||
212 namedAttr.first == SymbolTable::getSymbolAttrName())
213 continue;
214 newFuncOp->setAttr(namedAttr.first, namedAttr.second);
215 }
216
217 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
218 newFuncOp.end());
219 if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
220 &signatureConverter)))
221 return nullptr;
222 rewriter.eraseOp(funcOp);
223
224 // Set the attributes for argument and the function.
225 StringRef argABIAttrName = spirv::getInterfaceVarABIAttrName();
226 for (auto argIndex : llvm::seq<unsigned>(0, argABIInfo.size())) {
227 newFuncOp.setArgAttr(argIndex, argABIAttrName, argABIInfo[argIndex]);
228 }
229 newFuncOp->setAttr(spirv::getEntryPointABIAttrName(), entryPointInfo);
230
231 return newFuncOp;
232 }
233
234 /// Populates `argABI` with spv.interface_var_abi attributes for lowering
235 /// gpu.func to spv.func if no arguments have the attributes set
236 /// already. Returns failure if any argument has the ABI attribute set already.
237 static LogicalResult
getDefaultABIAttrs(MLIRContext * context,gpu::GPUFuncOp funcOp,SmallVectorImpl<spirv::InterfaceVarABIAttr> & argABI)238 getDefaultABIAttrs(MLIRContext *context, gpu::GPUFuncOp funcOp,
239 SmallVectorImpl<spirv::InterfaceVarABIAttr> &argABI) {
240 spirv::TargetEnvAttr targetEnv = spirv::lookupTargetEnvOrDefault(funcOp);
241 if (!spirv::needsInterfaceVarABIAttrs(targetEnv))
242 return success();
243
244 for (auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
245 if (funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
246 argIndex, spirv::getInterfaceVarABIAttrName()))
247 return failure();
248 // Vulkan's interface variable requirements needs scalars to be wrapped in a
249 // struct. The struct held in storage buffer.
250 Optional<spirv::StorageClass> sc;
251 if (funcOp.getArgument(argIndex).getType().isIntOrIndexOrFloat())
252 sc = spirv::StorageClass::StorageBuffer;
253 argABI.push_back(spirv::getInterfaceVarABIAttr(0, argIndex, sc, context));
254 }
255 return success();
256 }
257
matchAndRewrite(gpu::GPUFuncOp funcOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const258 LogicalResult GPUFuncOpConversion::matchAndRewrite(
259 gpu::GPUFuncOp funcOp, ArrayRef<Value> operands,
260 ConversionPatternRewriter &rewriter) const {
261 if (!gpu::GPUDialect::isKernel(funcOp))
262 return failure();
263
264 SmallVector<spirv::InterfaceVarABIAttr, 4> argABI;
265 if (failed(getDefaultABIAttrs(rewriter.getContext(), funcOp, argABI))) {
266 argABI.clear();
267 for (auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
268 // If the ABI is already specified, use it.
269 auto abiAttr = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
270 argIndex, spirv::getInterfaceVarABIAttrName());
271 if (!abiAttr) {
272 funcOp.emitRemark(
273 "match failure: missing 'spv.interface_var_abi' attribute at "
274 "argument ")
275 << argIndex;
276 return failure();
277 }
278 argABI.push_back(abiAttr);
279 }
280 }
281
282 auto entryPointAttr = spirv::lookupEntryPointABI(funcOp);
283 if (!entryPointAttr) {
284 funcOp.emitRemark("match failure: missing 'spv.entry_point_abi' attribute");
285 return failure();
286 }
287 spirv::FuncOp newFuncOp = lowerAsEntryFunction(
288 funcOp, *getTypeConverter(), rewriter, entryPointAttr, argABI);
289 if (!newFuncOp)
290 return failure();
291 newFuncOp->removeAttr(Identifier::get(
292 gpu::GPUDialect::getKernelFuncAttrName(), rewriter.getContext()));
293 return success();
294 }
295
296 //===----------------------------------------------------------------------===//
297 // ModuleOp with gpu.module.
298 //===----------------------------------------------------------------------===//
299
matchAndRewrite(gpu::GPUModuleOp moduleOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const300 LogicalResult GPUModuleConversion::matchAndRewrite(
301 gpu::GPUModuleOp moduleOp, ArrayRef<Value> operands,
302 ConversionPatternRewriter &rewriter) const {
303 spirv::TargetEnvAttr targetEnv = spirv::lookupTargetEnvOrDefault(moduleOp);
304 spirv::AddressingModel addressingModel = spirv::getAddressingModel(targetEnv);
305 FailureOr<spirv::MemoryModel> memoryModel = spirv::getMemoryModel(targetEnv);
306 if (failed(memoryModel))
307 return moduleOp.emitRemark("match failure: could not selected memory model "
308 "based on 'spv.target_env'");
309
310 // Add a keyword to the module name to avoid symbolic conflict.
311 std::string spvModuleName = (kSPIRVModule + moduleOp.getName()).str();
312 auto spvModule = rewriter.create<spirv::ModuleOp>(
313 moduleOp.getLoc(), addressingModel, memoryModel.getValue(),
314 StringRef(spvModuleName));
315
316 // Move the region from the module op into the SPIR-V module.
317 Region &spvModuleRegion = spvModule.getRegion();
318 rewriter.inlineRegionBefore(moduleOp.body(), spvModuleRegion,
319 spvModuleRegion.begin());
320 // The spv.module build method adds a block. Remove that.
321 rewriter.eraseBlock(&spvModuleRegion.back());
322 rewriter.eraseOp(moduleOp);
323 return success();
324 }
325
326 //===----------------------------------------------------------------------===//
327 // GPU return inside kernel functions to SPIR-V return.
328 //===----------------------------------------------------------------------===//
329
matchAndRewrite(gpu::ReturnOp returnOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const330 LogicalResult GPUReturnOpConversion::matchAndRewrite(
331 gpu::ReturnOp returnOp, ArrayRef<Value> operands,
332 ConversionPatternRewriter &rewriter) const {
333 if (!operands.empty())
334 return failure();
335
336 rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
337 return success();
338 }
339
340 //===----------------------------------------------------------------------===//
341 // GPU To SPIRV Patterns.
342 //===----------------------------------------------------------------------===//
343
populateGPUToSPIRVPatterns(SPIRVTypeConverter & typeConverter,RewritePatternSet & patterns)344 void mlir::populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
345 RewritePatternSet &patterns) {
346 patterns.add<
347 GPUFuncOpConversion, GPUModuleConversion, GPUModuleEndConversion,
348 GPUReturnOpConversion,
349 LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
350 LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
351 LaunchConfigConversion<gpu::ThreadIdOp,
352 spirv::BuiltIn::LocalInvocationId>,
353 SingleDimLaunchConfigConversion<gpu::SubgroupIdOp,
354 spirv::BuiltIn::SubgroupId>,
355 SingleDimLaunchConfigConversion<gpu::NumSubgroupsOp,
356 spirv::BuiltIn::NumSubgroups>,
357 SingleDimLaunchConfigConversion<gpu::SubgroupSizeOp,
358 spirv::BuiltIn::SubgroupSize>,
359 WorkGroupSizeConversion>(typeConverter, patterns.getContext());
360 }
361