1 //===- ConvertLaunchFuncToVulkanCalls.cpp - MLIR Vulkan conversion passes -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a pass to convert a vulkan launch call into a sequence
// of Vulkan runtime calls. The Vulkan runtime API surface is huge, so currently
// we don't expose separate external functions in IR for each of them; instead
// we expose a few external functions to wrapper libraries which manage the
// Vulkan runtime.
14 //
15 //===----------------------------------------------------------------------===//
16
17 #include "../PassDetail.h"
18 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/IR/Attributes.h"
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/BuiltinOps.h"
23
24 #include "llvm/ADT/SmallString.h"
25 #include "llvm/Support/FormatVariadic.h"
26
27 using namespace mlir;
28
// Callee names identifying the launch calls this pass looks at: the plain
// `vulkanLaunch` call carries the SPIR-V attributes, and its C-interface
// counterpart is the call that gets rewritten into runtime calls.
static constexpr const char *kCInterfaceVulkanLaunch =
    "_mlir_ciface_vulkanLaunch";
static constexpr const char *kVulkanLaunch = "vulkanLaunch";

// Vulkan runtime wrapper functions called by the generated code.
static constexpr const char *kInitVulkan = "initVulkan";
static constexpr const char *kSetBinaryShader = "setBinaryShader";
static constexpr const char *kSetEntryPoint = "setEntryPoint";
static constexpr const char *kSetNumWorkGroups = "setNumWorkGroups";
static constexpr const char *kRunOnVulkan = "runOnVulkan";
static constexpr const char *kDeinitVulkan = "deinitVulkan";

// Attribute names read from the `vulkanLaunch` call, and the name of the
// global that holds the SPIR-V binary.
static constexpr const char *kSPIRVBlobAttrName = "spirv_blob";
static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point";
static constexpr const char *kSPIRVBinary = "SPIRV_BIN";
41
42 namespace {
43
44 /// A pass to convert vulkan launch call op into a sequence of Vulkan
45 /// runtime calls in the following order:
46 ///
47 /// * initVulkan -- initializes vulkan runtime
48 /// * bindMemRef -- binds memref
49 /// * setBinaryShader -- sets the binary shader data
50 /// * setEntryPoint -- sets the entry point name
51 /// * setNumWorkGroups -- sets the number of a local workgroups
52 /// * runOnVulkan -- runs vulkan runtime
53 /// * deinitVulkan -- deinitializes vulkan runtime
54 ///
55 class VulkanLaunchFuncToVulkanCallsPass
56 : public ConvertVulkanLaunchFuncToVulkanCallsBase<
57 VulkanLaunchFuncToVulkanCallsPass> {
58 private:
initializeCachedTypes()59 void initializeCachedTypes() {
60 llvmFloatType = Float32Type::get(&getContext());
61 llvmVoidType = LLVM::LLVMVoidType::get(&getContext());
62 llvmPointerType =
63 LLVM::LLVMPointerType::get(IntegerType::get(&getContext(), 8));
64 llvmInt32Type = IntegerType::get(&getContext(), 32);
65 llvmInt64Type = IntegerType::get(&getContext(), 64);
66 }
67
getMemRefType(uint32_t rank,Type elemenType)68 Type getMemRefType(uint32_t rank, Type elemenType) {
69 // According to the MLIR doc memref argument is converted into a
70 // pointer-to-struct argument of type:
71 // template <typename Elem, size_t Rank>
72 // struct {
73 // Elem *allocated;
74 // Elem *aligned;
75 // int64_t offset;
76 // int64_t sizes[Rank]; // omitted when rank == 0
77 // int64_t strides[Rank]; // omitted when rank == 0
78 // };
79 auto llvmPtrToElementType = LLVM::LLVMPointerType::get(elemenType);
80 auto llvmArrayRankElementSizeType =
81 LLVM::LLVMArrayType::get(getInt64Type(), rank);
82
83 // Create a type
84 // `!llvm<"{ `element-type`*, `element-type`*, i64,
85 // [`rank` x i64], [`rank` x i64]}">`.
86 return LLVM::LLVMStructType::getLiteral(
87 &getContext(),
88 {llvmPtrToElementType, llvmPtrToElementType, getInt64Type(),
89 llvmArrayRankElementSizeType, llvmArrayRankElementSizeType});
90 }
91
getVoidType()92 Type getVoidType() { return llvmVoidType; }
getPointerType()93 Type getPointerType() { return llvmPointerType; }
getInt32Type()94 Type getInt32Type() { return llvmInt32Type; }
getInt64Type()95 Type getInt64Type() { return llvmInt64Type; }
96
97 /// Creates an LLVM global for the given `name`.
98 Value createEntryPointNameConstant(StringRef name, Location loc,
99 OpBuilder &builder);
100
101 /// Declares all needed runtime functions.
102 void declareVulkanFunctions(Location loc);
103
104 /// Checks whether the given LLVM::CallOp is a vulkan launch call op.
isVulkanLaunchCallOp(LLVM::CallOp callOp)105 bool isVulkanLaunchCallOp(LLVM::CallOp callOp) {
106 return (callOp.callee() && callOp.callee().getValue() == kVulkanLaunch &&
107 callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands);
108 }
109
110 /// Checks whether the given LLVM::CallOp is a "ci_face" vulkan launch call
111 /// op.
isCInterfaceVulkanLaunchCallOp(LLVM::CallOp callOp)112 bool isCInterfaceVulkanLaunchCallOp(LLVM::CallOp callOp) {
113 return (callOp.callee() &&
114 callOp.callee().getValue() == kCInterfaceVulkanLaunch &&
115 callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands);
116 }
117
118 /// Translates the given `vulkanLaunchCallOp` to the sequence of Vulkan
119 /// runtime calls.
120 void translateVulkanLaunchCall(LLVM::CallOp vulkanLaunchCallOp);
121
122 /// Creates call to `bindMemRef` for each memref operand.
123 void createBindMemRefCalls(LLVM::CallOp vulkanLaunchCallOp,
124 Value vulkanRuntime);
125
126 /// Collects SPIRV attributes from the given `vulkanLaunchCallOp`.
127 void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp);
128
129 /// Deduces a rank and element type from the given 'ptrToMemRefDescriptor`.
130 LogicalResult deduceMemRefRankAndType(Value ptrToMemRefDescriptor,
131 uint32_t &rank, Type &type);
132
133 /// Returns a string representation from the given `type`.
stringifyType(Type type)134 StringRef stringifyType(Type type) {
135 if (type.isa<Float32Type>())
136 return "Float";
137 if (type.isa<Float16Type>())
138 return "Half";
139 if (auto intType = type.dyn_cast<IntegerType>()) {
140 if (intType.getWidth() == 32)
141 return "Int32";
142 if (intType.getWidth() == 16)
143 return "Int16";
144 if (intType.getWidth() == 8)
145 return "Int8";
146 }
147
148 llvm_unreachable("unsupported type");
149 }
150
151 public:
152 void runOnOperation() override;
153
154 private:
155 Type llvmFloatType;
156 Type llvmVoidType;
157 Type llvmPointerType;
158 Type llvmInt32Type;
159 Type llvmInt64Type;
160
161 // TODO: Use an associative array to support multiple vulkan launch calls.
162 std::pair<StringAttr, StringAttr> spirvAttributes;
163 /// The number of vulkan launch configuration operands, placed at the leading
164 /// positions of the operand list.
165 static constexpr unsigned kVulkanLaunchNumConfigOperands = 3;
166 };
167
168 } // anonymous namespace
169
runOnOperation()170 void VulkanLaunchFuncToVulkanCallsPass::runOnOperation() {
171 initializeCachedTypes();
172
173 // Collect SPIR-V attributes such as `spirv_blob` and
174 // `spirv_entry_point_name`.
175 getOperation().walk([this](LLVM::CallOp op) {
176 if (isVulkanLaunchCallOp(op))
177 collectSPIRVAttributes(op);
178 });
179
180 // Convert vulkan launch call op into a sequence of Vulkan runtime calls.
181 getOperation().walk([this](LLVM::CallOp op) {
182 if (isCInterfaceVulkanLaunchCallOp(op))
183 translateVulkanLaunchCall(op);
184 });
185 }
186
collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp)187 void VulkanLaunchFuncToVulkanCallsPass::collectSPIRVAttributes(
188 LLVM::CallOp vulkanLaunchCallOp) {
189 // Check that `kSPIRVBinary` and `kSPIRVEntryPoint` are present in attributes
190 // for the given vulkan launch call.
191 auto spirvBlobAttr =
192 vulkanLaunchCallOp->getAttrOfType<StringAttr>(kSPIRVBlobAttrName);
193 if (!spirvBlobAttr) {
194 vulkanLaunchCallOp.emitError()
195 << "missing " << kSPIRVBlobAttrName << " attribute";
196 return signalPassFailure();
197 }
198
199 auto spirvEntryPointNameAttr =
200 vulkanLaunchCallOp->getAttrOfType<StringAttr>(kSPIRVEntryPointAttrName);
201 if (!spirvEntryPointNameAttr) {
202 vulkanLaunchCallOp.emitError()
203 << "missing " << kSPIRVEntryPointAttrName << " attribute";
204 return signalPassFailure();
205 }
206
207 spirvAttributes = std::make_pair(spirvBlobAttr, spirvEntryPointNameAttr);
208 }
209
createBindMemRefCalls(LLVM::CallOp cInterfaceVulkanLaunchCallOp,Value vulkanRuntime)210 void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
211 LLVM::CallOp cInterfaceVulkanLaunchCallOp, Value vulkanRuntime) {
212 if (cInterfaceVulkanLaunchCallOp.getNumOperands() ==
213 kVulkanLaunchNumConfigOperands)
214 return;
215 OpBuilder builder(cInterfaceVulkanLaunchCallOp);
216 Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
217
218 // Create LLVM constant for the descriptor set index.
219 // Bind all memrefs to the `0` descriptor set, the same way as `GPUToSPIRV`
220 // pass does.
221 Value descriptorSet = builder.create<LLVM::ConstantOp>(
222 loc, getInt32Type(), builder.getI32IntegerAttr(0));
223
224 for (auto en :
225 llvm::enumerate(cInterfaceVulkanLaunchCallOp.getOperands().drop_front(
226 kVulkanLaunchNumConfigOperands))) {
227 // Create LLVM constant for the descriptor binding index.
228 Value descriptorBinding = builder.create<LLVM::ConstantOp>(
229 loc, getInt32Type(), builder.getI32IntegerAttr(en.index()));
230
231 auto ptrToMemRefDescriptor = en.value();
232 uint32_t rank = 0;
233 Type type;
234 if (failed(deduceMemRefRankAndType(ptrToMemRefDescriptor, rank, type))) {
235 cInterfaceVulkanLaunchCallOp.emitError()
236 << "invalid memref descriptor " << ptrToMemRefDescriptor.getType();
237 return signalPassFailure();
238 }
239
240 auto symbolName =
241 llvm::formatv("bindMemRef{0}D{1}", rank, stringifyType(type)).str();
242 // Special case for fp16 type. Since it is not a supported type in C we use
243 // int16_t and bitcast the descriptor.
244 if (type.isa<Float16Type>()) {
245 auto memRefTy = getMemRefType(rank, IntegerType::get(&getContext(), 16));
246 ptrToMemRefDescriptor = builder.create<LLVM::BitcastOp>(
247 loc, LLVM::LLVMPointerType::get(memRefTy), ptrToMemRefDescriptor);
248 }
249 // Create call to `bindMemRef`.
250 builder.create<LLVM::CallOp>(
251 loc, TypeRange{getVoidType()},
252 builder.getSymbolRefAttr(
253 StringRef(symbolName.data(), symbolName.size())),
254 ValueRange{vulkanRuntime, descriptorSet, descriptorBinding,
255 ptrToMemRefDescriptor});
256 }
257 }
258
deduceMemRefRankAndType(Value ptrToMemRefDescriptor,uint32_t & rank,Type & type)259 LogicalResult VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRankAndType(
260 Value ptrToMemRefDescriptor, uint32_t &rank, Type &type) {
261 auto llvmPtrDescriptorTy =
262 ptrToMemRefDescriptor.getType().dyn_cast<LLVM::LLVMPointerType>();
263 if (!llvmPtrDescriptorTy)
264 return failure();
265
266 auto llvmDescriptorTy =
267 llvmPtrDescriptorTy.getElementType().dyn_cast<LLVM::LLVMStructType>();
268 // template <typename Elem, size_t Rank>
269 // struct {
270 // Elem *allocated;
271 // Elem *aligned;
272 // int64_t offset;
273 // int64_t sizes[Rank]; // omitted when rank == 0
274 // int64_t strides[Rank]; // omitted when rank == 0
275 // };
276 if (!llvmDescriptorTy)
277 return failure();
278
279 type = llvmDescriptorTy.getBody()[0]
280 .cast<LLVM::LLVMPointerType>()
281 .getElementType();
282 if (llvmDescriptorTy.getBody().size() == 3) {
283 rank = 0;
284 return success();
285 }
286 rank = llvmDescriptorTy.getBody()[3]
287 .cast<LLVM::LLVMArrayType>()
288 .getNumElements();
289 return success();
290 }
291
declareVulkanFunctions(Location loc)292 void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
293 ModuleOp module = getOperation();
294 auto builder = OpBuilder::atBlockEnd(module.getBody());
295
296 if (!module.lookupSymbol(kSetEntryPoint)) {
297 builder.create<LLVM::LLVMFuncOp>(
298 loc, kSetEntryPoint,
299 LLVM::LLVMFunctionType::get(getVoidType(),
300 {getPointerType(), getPointerType()}));
301 }
302
303 if (!module.lookupSymbol(kSetNumWorkGroups)) {
304 builder.create<LLVM::LLVMFuncOp>(
305 loc, kSetNumWorkGroups,
306 LLVM::LLVMFunctionType::get(getVoidType(),
307 {getPointerType(), getInt64Type(),
308 getInt64Type(), getInt64Type()}));
309 }
310
311 if (!module.lookupSymbol(kSetBinaryShader)) {
312 builder.create<LLVM::LLVMFuncOp>(
313 loc, kSetBinaryShader,
314 LLVM::LLVMFunctionType::get(
315 getVoidType(),
316 {getPointerType(), getPointerType(), getInt32Type()}));
317 }
318
319 if (!module.lookupSymbol(kRunOnVulkan)) {
320 builder.create<LLVM::LLVMFuncOp>(
321 loc, kRunOnVulkan,
322 LLVM::LLVMFunctionType::get(getVoidType(), {getPointerType()}));
323 }
324
325 for (unsigned i = 1; i <= 3; i++) {
326 SmallVector<Type, 5> types{
327 Float32Type::get(&getContext()), IntegerType::get(&getContext(), 32),
328 IntegerType::get(&getContext(), 16), IntegerType::get(&getContext(), 8),
329 Float16Type::get(&getContext())};
330 for (auto type : types) {
331 std::string fnName = "bindMemRef" + std::to_string(i) + "D" +
332 std::string(stringifyType(type));
333 if (type.isa<Float16Type>())
334 type = IntegerType::get(&getContext(), 16);
335 if (!module.lookupSymbol(fnName)) {
336 auto fnType = LLVM::LLVMFunctionType::get(
337 getVoidType(),
338 {getPointerType(), getInt32Type(), getInt32Type(),
339 LLVM::LLVMPointerType::get(getMemRefType(i, type))},
340 /*isVarArg=*/false);
341 builder.create<LLVM::LLVMFuncOp>(loc, fnName, fnType);
342 }
343 }
344 }
345
346 if (!module.lookupSymbol(kInitVulkan)) {
347 builder.create<LLVM::LLVMFuncOp>(
348 loc, kInitVulkan, LLVM::LLVMFunctionType::get(getPointerType(), {}));
349 }
350
351 if (!module.lookupSymbol(kDeinitVulkan)) {
352 builder.create<LLVM::LLVMFuncOp>(
353 loc, kDeinitVulkan,
354 LLVM::LLVMFunctionType::get(getVoidType(), {getPointerType()}));
355 }
356 }
357
createEntryPointNameConstant(StringRef name,Location loc,OpBuilder & builder)358 Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant(
359 StringRef name, Location loc, OpBuilder &builder) {
360 SmallString<16> shaderName(name.begin(), name.end());
361 // Append `\0` to follow C style string given that LLVM::createGlobalString()
362 // won't handle this directly for us.
363 shaderName.push_back('\0');
364
365 std::string entryPointGlobalName = (name + "_spv_entry_point_name").str();
366 return LLVM::createGlobalString(loc, builder, entryPointGlobalName,
367 shaderName, LLVM::Linkage::Internal);
368 }
369
translateVulkanLaunchCall(LLVM::CallOp cInterfaceVulkanLaunchCallOp)370 void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
371 LLVM::CallOp cInterfaceVulkanLaunchCallOp) {
372 OpBuilder builder(cInterfaceVulkanLaunchCallOp);
373 Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
374 // Create call to `initVulkan`.
375 auto initVulkanCall = builder.create<LLVM::CallOp>(
376 loc, TypeRange{getPointerType()}, builder.getSymbolRefAttr(kInitVulkan),
377 ValueRange{});
378 // The result of `initVulkan` function is a pointer to Vulkan runtime, we
379 // need to pass that pointer to each Vulkan runtime call.
380 auto vulkanRuntime = initVulkanCall.getResult(0);
381
382 // Create LLVM global with SPIR-V binary data, so we can pass a pointer with
383 // that data to runtime call.
384 Value ptrToSPIRVBinary = LLVM::createGlobalString(
385 loc, builder, kSPIRVBinary, spirvAttributes.first.getValue(),
386 LLVM::Linkage::Internal);
387
388 // Create LLVM constant for the size of SPIR-V binary shader.
389 Value binarySize = builder.create<LLVM::ConstantOp>(
390 loc, getInt32Type(),
391 builder.getI32IntegerAttr(spirvAttributes.first.getValue().size()));
392
393 // Create call to `bindMemRef` for each memref operand.
394 createBindMemRefCalls(cInterfaceVulkanLaunchCallOp, vulkanRuntime);
395
396 // Create call to `setBinaryShader` runtime function with the given pointer to
397 // SPIR-V binary and binary size.
398 builder.create<LLVM::CallOp>(
399 loc, TypeRange{getVoidType()}, builder.getSymbolRefAttr(kSetBinaryShader),
400 ValueRange{vulkanRuntime, ptrToSPIRVBinary, binarySize});
401 // Create LLVM global with entry point name.
402 Value entryPointName = createEntryPointNameConstant(
403 spirvAttributes.second.getValue(), loc, builder);
404 // Create call to `setEntryPoint` runtime function with the given pointer to
405 // entry point name.
406 builder.create<LLVM::CallOp>(loc, TypeRange{getVoidType()},
407 builder.getSymbolRefAttr(kSetEntryPoint),
408 ValueRange{vulkanRuntime, entryPointName});
409
410 // Create number of local workgroup for each dimension.
411 builder.create<LLVM::CallOp>(
412 loc, TypeRange{getVoidType()},
413 builder.getSymbolRefAttr(kSetNumWorkGroups),
414 ValueRange{vulkanRuntime, cInterfaceVulkanLaunchCallOp.getOperand(0),
415 cInterfaceVulkanLaunchCallOp.getOperand(1),
416 cInterfaceVulkanLaunchCallOp.getOperand(2)});
417
418 // Create call to `runOnVulkan` runtime function.
419 builder.create<LLVM::CallOp>(loc, TypeRange{getVoidType()},
420 builder.getSymbolRefAttr(kRunOnVulkan),
421 ValueRange{vulkanRuntime});
422
423 // Create call to 'deinitVulkan' runtime function.
424 builder.create<LLVM::CallOp>(loc, TypeRange{getVoidType()},
425 builder.getSymbolRefAttr(kDeinitVulkan),
426 ValueRange{vulkanRuntime});
427
428 // Declare runtime functions.
429 declareVulkanFunctions(loc);
430
431 cInterfaceVulkanLaunchCallOp.erase();
432 }
433
434 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
createConvertVulkanLaunchFuncToVulkanCallsPass()435 mlir::createConvertVulkanLaunchFuncToVulkanCallsPass() {
436 return std::make_unique<VulkanLaunchFuncToVulkanCallsPass>();
437 }
438