1 //===- ConvertLaunchFuncToVulkanCalls.cpp - MLIR Vulkan conversion passes -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a pass to convert a vulkan launch call into a sequence
// of Vulkan runtime calls. The Vulkan runtime API surface is huge, so currently
// we don't expose separate external functions in IR for each of them; instead
// we expose a few external functions to wrapper libraries that manage the
// Vulkan runtime.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "../PassDetail.h"
18 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/IR/Attributes.h"
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/BuiltinOps.h"
23 
24 #include "llvm/ADT/SmallString.h"
25 #include "llvm/Support/FormatVariadic.h"
26 
27 using namespace mlir;
28 
// Callee name of the C-interface launch wrapper; calls to it are the ones
// rewritten into the sequence of Vulkan runtime calls.
static constexpr const char *kCInterfaceVulkanLaunch =
    "_mlir_ciface_vulkanLaunch";
// Names of the Vulkan runtime wrapper functions that this pass calls and
// declares in the module.
static constexpr const char *kDeinitVulkan = "deinitVulkan";
static constexpr const char *kRunOnVulkan = "runOnVulkan";
static constexpr const char *kInitVulkan = "initVulkan";
static constexpr const char *kSetBinaryShader = "setBinaryShader";
static constexpr const char *kSetEntryPoint = "setEntryPoint";
static constexpr const char *kSetNumWorkGroups = "setNumWorkGroups";
// Symbol name of the internal LLVM global holding the SPIR-V binary. It is a
// fixed name, so only one such global can exist per module.
static constexpr const char *kSPIRVBinary = "SPIRV_BIN";
// String attributes on the `vulkanLaunch` call that carry the SPIR-V binary
// and the name of the shader entry point.
static constexpr const char *kSPIRVBlobAttrName = "spirv_blob";
static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point";
// Callee name of the launch call from which the SPIR-V attributes are read.
static constexpr const char *kVulkanLaunch = "vulkanLaunch";
41 
42 namespace {
43 
44 /// A pass to convert vulkan launch call op into a sequence of Vulkan
45 /// runtime calls in the following order:
46 ///
47 /// * initVulkan           -- initializes vulkan runtime
48 /// * bindMemRef           -- binds memref
49 /// * setBinaryShader      -- sets the binary shader data
50 /// * setEntryPoint        -- sets the entry point name
51 /// * setNumWorkGroups     -- sets the number of a local workgroups
52 /// * runOnVulkan          -- runs vulkan runtime
53 /// * deinitVulkan         -- deinitializes vulkan runtime
54 ///
55 class VulkanLaunchFuncToVulkanCallsPass
56     : public ConvertVulkanLaunchFuncToVulkanCallsBase<
57           VulkanLaunchFuncToVulkanCallsPass> {
58 private:
initializeCachedTypes()59   void initializeCachedTypes() {
60     llvmFloatType = Float32Type::get(&getContext());
61     llvmVoidType = LLVM::LLVMVoidType::get(&getContext());
62     llvmPointerType =
63         LLVM::LLVMPointerType::get(IntegerType::get(&getContext(), 8));
64     llvmInt32Type = IntegerType::get(&getContext(), 32);
65     llvmInt64Type = IntegerType::get(&getContext(), 64);
66   }
67 
getMemRefType(uint32_t rank,Type elemenType)68   Type getMemRefType(uint32_t rank, Type elemenType) {
69     // According to the MLIR doc memref argument is converted into a
70     // pointer-to-struct argument of type:
71     // template <typename Elem, size_t Rank>
72     // struct {
73     //   Elem *allocated;
74     //   Elem *aligned;
75     //   int64_t offset;
76     //   int64_t sizes[Rank]; // omitted when rank == 0
77     //   int64_t strides[Rank]; // omitted when rank == 0
78     // };
79     auto llvmPtrToElementType = LLVM::LLVMPointerType::get(elemenType);
80     auto llvmArrayRankElementSizeType =
81         LLVM::LLVMArrayType::get(getInt64Type(), rank);
82 
83     // Create a type
84     // `!llvm<"{ `element-type`*, `element-type`*, i64,
85     // [`rank` x i64], [`rank` x i64]}">`.
86     return LLVM::LLVMStructType::getLiteral(
87         &getContext(),
88         {llvmPtrToElementType, llvmPtrToElementType, getInt64Type(),
89          llvmArrayRankElementSizeType, llvmArrayRankElementSizeType});
90   }
91 
getVoidType()92   Type getVoidType() { return llvmVoidType; }
getPointerType()93   Type getPointerType() { return llvmPointerType; }
getInt32Type()94   Type getInt32Type() { return llvmInt32Type; }
getInt64Type()95   Type getInt64Type() { return llvmInt64Type; }
96 
97   /// Creates an LLVM global for the given `name`.
98   Value createEntryPointNameConstant(StringRef name, Location loc,
99                                      OpBuilder &builder);
100 
101   /// Declares all needed runtime functions.
102   void declareVulkanFunctions(Location loc);
103 
104   /// Checks whether the given LLVM::CallOp is a vulkan launch call op.
isVulkanLaunchCallOp(LLVM::CallOp callOp)105   bool isVulkanLaunchCallOp(LLVM::CallOp callOp) {
106     return (callOp.callee() && callOp.callee().getValue() == kVulkanLaunch &&
107             callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands);
108   }
109 
110   /// Checks whether the given LLVM::CallOp is a "ci_face" vulkan launch call
111   /// op.
isCInterfaceVulkanLaunchCallOp(LLVM::CallOp callOp)112   bool isCInterfaceVulkanLaunchCallOp(LLVM::CallOp callOp) {
113     return (callOp.callee() &&
114             callOp.callee().getValue() == kCInterfaceVulkanLaunch &&
115             callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands);
116   }
117 
118   /// Translates the given `vulkanLaunchCallOp` to the sequence of Vulkan
119   /// runtime calls.
120   void translateVulkanLaunchCall(LLVM::CallOp vulkanLaunchCallOp);
121 
122   /// Creates call to `bindMemRef` for each memref operand.
123   void createBindMemRefCalls(LLVM::CallOp vulkanLaunchCallOp,
124                              Value vulkanRuntime);
125 
126   /// Collects SPIRV attributes from the given `vulkanLaunchCallOp`.
127   void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp);
128 
129   /// Deduces a rank and element type from the given 'ptrToMemRefDescriptor`.
130   LogicalResult deduceMemRefRankAndType(Value ptrToMemRefDescriptor,
131                                         uint32_t &rank, Type &type);
132 
133   /// Returns a string representation from the given `type`.
stringifyType(Type type)134   StringRef stringifyType(Type type) {
135     if (type.isa<Float32Type>())
136       return "Float";
137     if (type.isa<Float16Type>())
138       return "Half";
139     if (auto intType = type.dyn_cast<IntegerType>()) {
140       if (intType.getWidth() == 32)
141         return "Int32";
142       if (intType.getWidth() == 16)
143         return "Int16";
144       if (intType.getWidth() == 8)
145         return "Int8";
146     }
147 
148     llvm_unreachable("unsupported type");
149   }
150 
151 public:
152   void runOnOperation() override;
153 
154 private:
155   Type llvmFloatType;
156   Type llvmVoidType;
157   Type llvmPointerType;
158   Type llvmInt32Type;
159   Type llvmInt64Type;
160 
161   // TODO: Use an associative array to support multiple vulkan launch calls.
162   std::pair<StringAttr, StringAttr> spirvAttributes;
163   /// The number of vulkan launch configuration operands, placed at the leading
164   /// positions of the operand list.
165   static constexpr unsigned kVulkanLaunchNumConfigOperands = 3;
166 };
167 
168 } // anonymous namespace
169 
runOnOperation()170 void VulkanLaunchFuncToVulkanCallsPass::runOnOperation() {
171   initializeCachedTypes();
172 
173   // Collect SPIR-V attributes such as `spirv_blob` and
174   // `spirv_entry_point_name`.
175   getOperation().walk([this](LLVM::CallOp op) {
176     if (isVulkanLaunchCallOp(op))
177       collectSPIRVAttributes(op);
178   });
179 
180   // Convert vulkan launch call op into a sequence of Vulkan runtime calls.
181   getOperation().walk([this](LLVM::CallOp op) {
182     if (isCInterfaceVulkanLaunchCallOp(op))
183       translateVulkanLaunchCall(op);
184   });
185 }
186 
collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp)187 void VulkanLaunchFuncToVulkanCallsPass::collectSPIRVAttributes(
188     LLVM::CallOp vulkanLaunchCallOp) {
189   // Check that `kSPIRVBinary` and `kSPIRVEntryPoint` are present in attributes
190   // for the given vulkan launch call.
191   auto spirvBlobAttr =
192       vulkanLaunchCallOp->getAttrOfType<StringAttr>(kSPIRVBlobAttrName);
193   if (!spirvBlobAttr) {
194     vulkanLaunchCallOp.emitError()
195         << "missing " << kSPIRVBlobAttrName << " attribute";
196     return signalPassFailure();
197   }
198 
199   auto spirvEntryPointNameAttr =
200       vulkanLaunchCallOp->getAttrOfType<StringAttr>(kSPIRVEntryPointAttrName);
201   if (!spirvEntryPointNameAttr) {
202     vulkanLaunchCallOp.emitError()
203         << "missing " << kSPIRVEntryPointAttrName << " attribute";
204     return signalPassFailure();
205   }
206 
207   spirvAttributes = std::make_pair(spirvBlobAttr, spirvEntryPointNameAttr);
208 }
209 
createBindMemRefCalls(LLVM::CallOp cInterfaceVulkanLaunchCallOp,Value vulkanRuntime)210 void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
211     LLVM::CallOp cInterfaceVulkanLaunchCallOp, Value vulkanRuntime) {
212   if (cInterfaceVulkanLaunchCallOp.getNumOperands() ==
213       kVulkanLaunchNumConfigOperands)
214     return;
215   OpBuilder builder(cInterfaceVulkanLaunchCallOp);
216   Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
217 
218   // Create LLVM constant for the descriptor set index.
219   // Bind all memrefs to the `0` descriptor set, the same way as `GPUToSPIRV`
220   // pass does.
221   Value descriptorSet = builder.create<LLVM::ConstantOp>(
222       loc, getInt32Type(), builder.getI32IntegerAttr(0));
223 
224   for (auto en :
225        llvm::enumerate(cInterfaceVulkanLaunchCallOp.getOperands().drop_front(
226            kVulkanLaunchNumConfigOperands))) {
227     // Create LLVM constant for the descriptor binding index.
228     Value descriptorBinding = builder.create<LLVM::ConstantOp>(
229         loc, getInt32Type(), builder.getI32IntegerAttr(en.index()));
230 
231     auto ptrToMemRefDescriptor = en.value();
232     uint32_t rank = 0;
233     Type type;
234     if (failed(deduceMemRefRankAndType(ptrToMemRefDescriptor, rank, type))) {
235       cInterfaceVulkanLaunchCallOp.emitError()
236           << "invalid memref descriptor " << ptrToMemRefDescriptor.getType();
237       return signalPassFailure();
238     }
239 
240     auto symbolName =
241         llvm::formatv("bindMemRef{0}D{1}", rank, stringifyType(type)).str();
242     // Special case for fp16 type. Since it is not a supported type in C we use
243     // int16_t and bitcast the descriptor.
244     if (type.isa<Float16Type>()) {
245       auto memRefTy = getMemRefType(rank, IntegerType::get(&getContext(), 16));
246       ptrToMemRefDescriptor = builder.create<LLVM::BitcastOp>(
247           loc, LLVM::LLVMPointerType::get(memRefTy), ptrToMemRefDescriptor);
248     }
249     // Create call to `bindMemRef`.
250     builder.create<LLVM::CallOp>(
251         loc, TypeRange{getVoidType()},
252         builder.getSymbolRefAttr(
253             StringRef(symbolName.data(), symbolName.size())),
254         ValueRange{vulkanRuntime, descriptorSet, descriptorBinding,
255                    ptrToMemRefDescriptor});
256   }
257 }
258 
deduceMemRefRankAndType(Value ptrToMemRefDescriptor,uint32_t & rank,Type & type)259 LogicalResult VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRankAndType(
260     Value ptrToMemRefDescriptor, uint32_t &rank, Type &type) {
261   auto llvmPtrDescriptorTy =
262       ptrToMemRefDescriptor.getType().dyn_cast<LLVM::LLVMPointerType>();
263   if (!llvmPtrDescriptorTy)
264     return failure();
265 
266   auto llvmDescriptorTy =
267       llvmPtrDescriptorTy.getElementType().dyn_cast<LLVM::LLVMStructType>();
268   // template <typename Elem, size_t Rank>
269   // struct {
270   //   Elem *allocated;
271   //   Elem *aligned;
272   //   int64_t offset;
273   //   int64_t sizes[Rank]; // omitted when rank == 0
274   //   int64_t strides[Rank]; // omitted when rank == 0
275   // };
276   if (!llvmDescriptorTy)
277     return failure();
278 
279   type = llvmDescriptorTy.getBody()[0]
280              .cast<LLVM::LLVMPointerType>()
281              .getElementType();
282   if (llvmDescriptorTy.getBody().size() == 3) {
283     rank = 0;
284     return success();
285   }
286   rank = llvmDescriptorTy.getBody()[3]
287              .cast<LLVM::LLVMArrayType>()
288              .getNumElements();
289   return success();
290 }
291 
declareVulkanFunctions(Location loc)292 void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
293   ModuleOp module = getOperation();
294   auto builder = OpBuilder::atBlockEnd(module.getBody());
295 
296   if (!module.lookupSymbol(kSetEntryPoint)) {
297     builder.create<LLVM::LLVMFuncOp>(
298         loc, kSetEntryPoint,
299         LLVM::LLVMFunctionType::get(getVoidType(),
300                                     {getPointerType(), getPointerType()}));
301   }
302 
303   if (!module.lookupSymbol(kSetNumWorkGroups)) {
304     builder.create<LLVM::LLVMFuncOp>(
305         loc, kSetNumWorkGroups,
306         LLVM::LLVMFunctionType::get(getVoidType(),
307                                     {getPointerType(), getInt64Type(),
308                                      getInt64Type(), getInt64Type()}));
309   }
310 
311   if (!module.lookupSymbol(kSetBinaryShader)) {
312     builder.create<LLVM::LLVMFuncOp>(
313         loc, kSetBinaryShader,
314         LLVM::LLVMFunctionType::get(
315             getVoidType(),
316             {getPointerType(), getPointerType(), getInt32Type()}));
317   }
318 
319   if (!module.lookupSymbol(kRunOnVulkan)) {
320     builder.create<LLVM::LLVMFuncOp>(
321         loc, kRunOnVulkan,
322         LLVM::LLVMFunctionType::get(getVoidType(), {getPointerType()}));
323   }
324 
325   for (unsigned i = 1; i <= 3; i++) {
326     SmallVector<Type, 5> types{
327         Float32Type::get(&getContext()), IntegerType::get(&getContext(), 32),
328         IntegerType::get(&getContext(), 16), IntegerType::get(&getContext(), 8),
329         Float16Type::get(&getContext())};
330     for (auto type : types) {
331       std::string fnName = "bindMemRef" + std::to_string(i) + "D" +
332                            std::string(stringifyType(type));
333       if (type.isa<Float16Type>())
334         type = IntegerType::get(&getContext(), 16);
335       if (!module.lookupSymbol(fnName)) {
336         auto fnType = LLVM::LLVMFunctionType::get(
337             getVoidType(),
338             {getPointerType(), getInt32Type(), getInt32Type(),
339              LLVM::LLVMPointerType::get(getMemRefType(i, type))},
340             /*isVarArg=*/false);
341         builder.create<LLVM::LLVMFuncOp>(loc, fnName, fnType);
342       }
343     }
344   }
345 
346   if (!module.lookupSymbol(kInitVulkan)) {
347     builder.create<LLVM::LLVMFuncOp>(
348         loc, kInitVulkan, LLVM::LLVMFunctionType::get(getPointerType(), {}));
349   }
350 
351   if (!module.lookupSymbol(kDeinitVulkan)) {
352     builder.create<LLVM::LLVMFuncOp>(
353         loc, kDeinitVulkan,
354         LLVM::LLVMFunctionType::get(getVoidType(), {getPointerType()}));
355   }
356 }
357 
createEntryPointNameConstant(StringRef name,Location loc,OpBuilder & builder)358 Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant(
359     StringRef name, Location loc, OpBuilder &builder) {
360   SmallString<16> shaderName(name.begin(), name.end());
361   // Append `\0` to follow C style string given that LLVM::createGlobalString()
362   // won't handle this directly for us.
363   shaderName.push_back('\0');
364 
365   std::string entryPointGlobalName = (name + "_spv_entry_point_name").str();
366   return LLVM::createGlobalString(loc, builder, entryPointGlobalName,
367                                   shaderName, LLVM::Linkage::Internal);
368 }
369 
translateVulkanLaunchCall(LLVM::CallOp cInterfaceVulkanLaunchCallOp)370 void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
371     LLVM::CallOp cInterfaceVulkanLaunchCallOp) {
372   OpBuilder builder(cInterfaceVulkanLaunchCallOp);
373   Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
374   // Create call to `initVulkan`.
375   auto initVulkanCall = builder.create<LLVM::CallOp>(
376       loc, TypeRange{getPointerType()}, builder.getSymbolRefAttr(kInitVulkan),
377       ValueRange{});
378   // The result of `initVulkan` function is a pointer to Vulkan runtime, we
379   // need to pass that pointer to each Vulkan runtime call.
380   auto vulkanRuntime = initVulkanCall.getResult(0);
381 
382   // Create LLVM global with SPIR-V binary data, so we can pass a pointer with
383   // that data to runtime call.
384   Value ptrToSPIRVBinary = LLVM::createGlobalString(
385       loc, builder, kSPIRVBinary, spirvAttributes.first.getValue(),
386       LLVM::Linkage::Internal);
387 
388   // Create LLVM constant for the size of SPIR-V binary shader.
389   Value binarySize = builder.create<LLVM::ConstantOp>(
390       loc, getInt32Type(),
391       builder.getI32IntegerAttr(spirvAttributes.first.getValue().size()));
392 
393   // Create call to `bindMemRef` for each memref operand.
394   createBindMemRefCalls(cInterfaceVulkanLaunchCallOp, vulkanRuntime);
395 
396   // Create call to `setBinaryShader` runtime function with the given pointer to
397   // SPIR-V binary and binary size.
398   builder.create<LLVM::CallOp>(
399       loc, TypeRange{getVoidType()}, builder.getSymbolRefAttr(kSetBinaryShader),
400       ValueRange{vulkanRuntime, ptrToSPIRVBinary, binarySize});
401   // Create LLVM global with entry point name.
402   Value entryPointName = createEntryPointNameConstant(
403       spirvAttributes.second.getValue(), loc, builder);
404   // Create call to `setEntryPoint` runtime function with the given pointer to
405   // entry point name.
406   builder.create<LLVM::CallOp>(loc, TypeRange{getVoidType()},
407                                builder.getSymbolRefAttr(kSetEntryPoint),
408                                ValueRange{vulkanRuntime, entryPointName});
409 
410   // Create number of local workgroup for each dimension.
411   builder.create<LLVM::CallOp>(
412       loc, TypeRange{getVoidType()},
413       builder.getSymbolRefAttr(kSetNumWorkGroups),
414       ValueRange{vulkanRuntime, cInterfaceVulkanLaunchCallOp.getOperand(0),
415                  cInterfaceVulkanLaunchCallOp.getOperand(1),
416                  cInterfaceVulkanLaunchCallOp.getOperand(2)});
417 
418   // Create call to `runOnVulkan` runtime function.
419   builder.create<LLVM::CallOp>(loc, TypeRange{getVoidType()},
420                                builder.getSymbolRefAttr(kRunOnVulkan),
421                                ValueRange{vulkanRuntime});
422 
423   // Create call to 'deinitVulkan' runtime function.
424   builder.create<LLVM::CallOp>(loc, TypeRange{getVoidType()},
425                                builder.getSymbolRefAttr(kDeinitVulkan),
426                                ValueRange{vulkanRuntime});
427 
428   // Declare runtime functions.
429   declareVulkanFunctions(loc);
430 
431   cInterfaceVulkanLaunchCallOp.erase();
432 }
433 
434 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
createConvertVulkanLaunchFuncToVulkanCallsPass()435 mlir::createConvertVulkanLaunchFuncToVulkanCallsPass() {
436   return std::make_unique<VulkanLaunchFuncToVulkanCallsPass>();
437 }
438