//===- ConvertLaunchFuncToGpuRuntimeCalls.cpp - MLIR GPU lowering passes --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert the gpu.launch_func op into a
// sequence of GPU runtime calls. As most GPU runtimes do not have a stable
// published ABI, this pass uses a slim runtime layer built on top of the
// public API from GPU runtime headers.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"

#include "../PassDetail.h"
#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/GPU/Passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;

static constexpr const char *kGpuBinaryStorageSuffix = "_gpubin_cst";

namespace {

class GpuToLLVMConversionPass
    : public GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
public:
  GpuToLLVMConversionPass() = default;

  GpuToLLVMConversionPass(const GpuToLLVMConversionPass &other)
      : GpuToLLVMConversionPassBase(other) {}

  // Run the dialect converter on the module.
  void runOnOperation() override;

private:
  Option<std::string> gpuBinaryAnnotation{
      *this, "gpu-binary-annotation",
      llvm::cl::desc("Annotation attribute string for GPU binary"),
      llvm::cl::init(gpu::getDefaultGpuBinaryAnnotation())};
};

struct FunctionCallBuilder {
  FunctionCallBuilder(StringRef functionName, Type returnType,
                      ArrayRef<Type> argumentTypes)
      : functionName(functionName),
        functionType(LLVM::LLVMFunctionType::get(returnType, argumentTypes)) {}
  LLVM::CallOp create(Location loc, OpBuilder &builder,
                      ArrayRef<Value> arguments) const;

  StringRef functionName;
  LLVM::LLVMFunctionType functionType;
};
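
// For illustration, a builder declared below as
//
//   FunctionCallBuilder streamCreateCallBuilder = {
//       "mgpuStreamCreate", llvmPointerType /* void *stream */, {}};
//
// emits on its first create() both the declaration of the runtime function
// (if not already present) and the call, roughly:
//
//   llvm.func @mgpuStreamCreate() -> !llvm.ptr<i8>
//   %0 = llvm.call @mgpuStreamCreate() : () -> !llvm.ptr<i8>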

template <typename OpTy>
class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
public:
  explicit ConvertOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToLLVMPattern<OpTy>(typeConverter) {}

protected:
  Value getNumElements(ConversionPatternRewriter &rewriter, Location loc,
                       MemRefType type, MemRefDescriptor desc) const {
    return type.hasStaticShape()
               ? ConvertToLLVMPattern::createIndexConstant(
                     rewriter, loc, type.getNumElements())
               // For identity maps (verified by caller), the number of
               // elements is stride[0] * size[0].
               : rewriter.create<LLVM::MulOp>(loc,
                                              desc.stride(rewriter, loc, 0),
                                              desc.size(rewriter, loc, 0));
  }

  MLIRContext *context = &this->getTypeConverter()->getContext();

  Type llvmVoidType = LLVM::LLVMVoidType::get(context);
  Type llvmPointerType =
      LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
  Type llvmPointerPointerType = LLVM::LLVMPointerType::get(llvmPointerType);
  Type llvmInt8Type = IntegerType::get(context, 8);
  Type llvmInt32Type = IntegerType::get(context, 32);
  Type llvmInt64Type = IntegerType::get(context, 64);
  Type llvmIntPtrType = IntegerType::get(
      context, this->getTypeConverter()->getPointerBitwidth(0));

  FunctionCallBuilder moduleLoadCallBuilder = {
      "mgpuModuleLoad",
      llvmPointerType /* void *module */,
      {llvmPointerType /* void *cubin */}};
  FunctionCallBuilder moduleUnloadCallBuilder = {
      "mgpuModuleUnload", llvmVoidType, {llvmPointerType /* void *module */}};
  FunctionCallBuilder moduleGetFunctionCallBuilder = {
      "mgpuModuleGetFunction",
      llvmPointerType /* void *function */,
      {
          llvmPointerType, /* void *module */
          llvmPointerType  /* char *name   */
      }};
  FunctionCallBuilder launchKernelCallBuilder = {
      "mgpuLaunchKernel",
      llvmVoidType,
      {
          llvmPointerType,        /* void* f */
          llvmIntPtrType,         /* intptr_t gridXDim */
          llvmIntPtrType,         /* intptr_t gridYDim */
          llvmIntPtrType,         /* intptr_t gridZDim */
          llvmIntPtrType,         /* intptr_t blockXDim */
          llvmIntPtrType,         /* intptr_t blockYDim */
          llvmIntPtrType,         /* intptr_t blockZDim */
          llvmInt32Type,          /* unsigned int sharedMemBytes */
          llvmPointerType,        /* void *hstream */
          llvmPointerPointerType, /* void **kernelParams */
          llvmPointerPointerType  /* void **extra */
      }};
  FunctionCallBuilder streamCreateCallBuilder = {
      "mgpuStreamCreate", llvmPointerType /* void *stream */, {}};
  FunctionCallBuilder streamDestroyCallBuilder = {
      "mgpuStreamDestroy", llvmVoidType, {llvmPointerType /* void *stream */}};
  FunctionCallBuilder streamSynchronizeCallBuilder = {
      "mgpuStreamSynchronize",
      llvmVoidType,
      {llvmPointerType /* void *stream */}};
  FunctionCallBuilder streamWaitEventCallBuilder = {
      "mgpuStreamWaitEvent",
      llvmVoidType,
      {llvmPointerType /* void *stream */, llvmPointerType /* void *event */}};
  FunctionCallBuilder eventCreateCallBuilder = {
      "mgpuEventCreate", llvmPointerType /* void *event */, {}};
  FunctionCallBuilder eventDestroyCallBuilder = {
      "mgpuEventDestroy", llvmVoidType, {llvmPointerType /* void *event */}};
  FunctionCallBuilder eventSynchronizeCallBuilder = {
      "mgpuEventSynchronize",
      llvmVoidType,
      {llvmPointerType /* void *event */}};
  FunctionCallBuilder eventRecordCallBuilder = {
      "mgpuEventRecord",
      llvmVoidType,
      {llvmPointerType /* void *event */, llvmPointerType /* void *stream */}};
  FunctionCallBuilder hostRegisterCallBuilder = {
      "mgpuMemHostRegisterMemRef",
      llvmVoidType,
      {llvmIntPtrType /* intptr_t rank */,
       llvmPointerType /* void *memrefDesc */,
       llvmIntPtrType /* intptr_t elementSizeBytes */}};
  FunctionCallBuilder allocCallBuilder = {
      "mgpuMemAlloc",
      llvmPointerType /* void * */,
      {llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder deallocCallBuilder = {
      "mgpuMemFree",
      llvmVoidType,
      {llvmPointerType /* void *ptr */, llvmPointerType /* void *stream */}};
  FunctionCallBuilder memcpyCallBuilder = {
      "mgpuMemcpy",
      llvmVoidType,
      {llvmPointerType /* void *dst */, llvmPointerType /* void *src */,
       llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder memsetCallBuilder = {
      "mgpuMemset32",
      llvmVoidType,
      {llvmPointerType /* void *dst */, llvmInt32Type /* unsigned int value */,
       llvmIntPtrType /* intptr_t count */,
       llvmPointerType /* void *stream */}};
};

/// A rewrite pattern to convert gpu.host_register operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertHostRegisterOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> {
public:
  ConvertHostRegisterOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
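
// For illustration, an op such as
//
//   gpu.host_register %buffer : memref<*xf32>
//
// lowers, roughly, to a call
//
//   mgpuMemHostRegisterMemRef(rank, memrefDesc, /*elementSizeBytes=*/4)
//
// where the promoted memref descriptor is passed by pointer.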

/// A rewrite pattern to convert gpu.alloc operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertAllocOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp> {
public:
  ConvertAllocOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::AllocOp allocOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
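
// For illustration (a sketch; the exact IR depends on the surrounding
// lowering), an op such as
//
//   %memref, %t1 = gpu.alloc async [%t0] (%size) : memref<?xf32>
//
// lowers to a call mgpuMemAlloc(sizeBytes, stream), where the stream is the
// lowered form of %t0, followed by the construction of a memref descriptor
// around the returned pointer.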

/// A rewrite pattern to convert gpu.dealloc operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertDeallocOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp> {
public:
  ConvertDeallocOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::DeallocOp deallocOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
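
// Continuing the example above, the matching
//
//   %t2 = gpu.dealloc async [%t1] %memref : memref<?xf32>
//
// lowers to mgpuMemFree(ptr, stream), with the stream forwarded as the
// resulting token.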

/// A rewrite pattern to convert the !gpu.async.token operands of async.yield
/// operations into events recorded on the corresponding streams (see the
/// matchAndRewrite implementation below).
class ConvertAsyncYieldToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<async::YieldOp> {
public:
  ConvertAsyncYieldToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<async::YieldOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(async::YieldOp yieldOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
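
// Schematically, a token yielded out of an async.execute region, e.g.
//
//   async.execute {
//     ...
//     async.yield %t : !gpu.async.token
//   }
//
// is rewritten so that an event is created and recorded on the stream %t was
// lowered to (mgpuEventCreate, mgpuEventRecord), the event is yielded instead,
// and the now-unused stream is destroyed.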

/// A rewrite pattern to convert gpu.wait operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertWaitOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
public:
  ConvertWaitOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
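
// For example, a synchronous
//
//   gpu.wait [%t0]
//
// whose operand was lowered to a stream becomes mgpuStreamSynchronize(stream)
// followed by mgpuStreamDestroy(stream); an event operand instead becomes
// mgpuEventSynchronize(event) followed by mgpuEventDestroy(event).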

/// A rewrite pattern to convert gpu.wait async operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertWaitAsyncOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
public:
  ConvertWaitAsyncOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
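
// For example, an asynchronous
//
//   %t1 = gpu.wait async [%t0]
//
// creates a fresh stream (mgpuStreamCreate) that waits on an event recorded
// on each dependency (mgpuStreamWaitEvent) before the events are destroyed;
// %t1 then lowers to the new stream.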

/// A rewrite pattern to convert gpu.launch_func operations into a sequence of
/// GPU runtime calls. Currently it supports CUDA and ROCm (HIP).
///
/// In essence, a gpu.launch_func operation gets compiled into the following
/// sequence of runtime calls:
///
/// * moduleLoad        -- loads the module given the cubin / hsaco data
/// * moduleGetFunction -- gets a handle to the actual kernel function
/// * streamCreate      -- initializes a new compute stream on GPU
/// * launchKernel      -- launches the kernel on a stream
/// * streamSynchronize -- waits for operations on the stream to finish
///
/// Intermediate data structures are allocated on the stack.
class ConvertLaunchFuncOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
public:
  ConvertLaunchFuncOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter,
                                             StringRef gpuBinaryAnnotation)
      : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
        gpuBinaryAnnotation(gpuBinaryAnnotation) {}

private:
  Value generateParamsArray(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
                            OpBuilder &builder) const;
  Value generateKernelNameConstant(StringRef moduleName, StringRef name,
                                   Location loc, OpBuilder &builder) const;

  LogicalResult
  matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;

  llvm::SmallString<32> gpuBinaryAnnotation;
};
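
// A typical op handled by this pattern looks roughly like the following
// sketch, assuming a kernel module @kernels that carries the GPU binary
// annotation:
//
//   gpu.launch_func @kernels::@my_kernel
//       blocks in (%gx, %gy, %gz) threads in (%bx, %by, %bz)
//       args(%arg0 : !llvm.ptr<f32>, %arg1 : i32)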

class EraseGpuModuleOpPattern : public OpRewritePattern<gpu::GPUModuleOp> {
  using OpRewritePattern<gpu::GPUModuleOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(gpu::GPUModuleOp op,
                                PatternRewriter &rewriter) const override {
    // GPU kernel modules are no longer necessary since we embed the CUBIN or
    // HSACO data in a global constant.
    rewriter.eraseOp(op);
    return success();
  }
};

/// A rewrite pattern to convert gpu.memcpy operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertMemcpyOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp> {
public:
  ConvertMemcpyOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
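
// For example, an op such as
//
//   %t1 = gpu.memcpy async [%t0] %dst, %src : memref<?xf32>, memref<?xf32>
//
// lowers to mgpuMemcpy(dst, src, sizeBytes, stream), with the byte size
// computed from the source memref descriptor.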

/// A rewrite pattern to convert gpu.memset operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertMemsetOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp> {
public:
  ConvertMemsetOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::MemsetOp memsetOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
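
// For example, an op such as
//
//   %t1 = gpu.memset async [%t0] %dst, %value : memref<?xf32>, f32
//
// lowers to mgpuMemset32(dst, value, count, stream); the value must be a
// 32-bit scalar and is bitcast to i32.
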
} // namespace

void GpuToLLVMConversionPass::runOnOperation() {
  LLVMTypeConverter converter(&getContext());
  RewritePatternSet patterns(&getContext());
  LLVMConversionTarget target(getContext());

  target.addIllegalDialect<gpu::GPUDialect>();

  populateVectorToLLVMConversionPatterns(converter, patterns);
  populateMemRefToLLVMConversionPatterns(converter, patterns);
  populateStdToLLVMConversionPatterns(converter, patterns);
  populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
                                                    target);
  populateGpuToLLVMConversionPatterns(converter, patterns, gpuBinaryAnnotation);

  if (failed(
          applyPartialConversion(getOperation(), target, std::move(patterns))))
    signalPassFailure();
}
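
// This pass usually runs near the end of a GPU compilation pipeline, once the
// kernel modules carry the serialized binary annotation, e.g. (assuming the
// default pass registration):
//
//   mlir-opt input.mlir --gpu-kernel-outlining ... --gpu-to-llvm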

LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
                                         ArrayRef<Value> arguments) const {
  auto module = builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
  auto function = [&] {
    if (auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(functionName))
      return function;
    return OpBuilder::atBlockEnd(module.getBody())
        .create<LLVM::LLVMFuncOp>(loc, functionName, functionType);
  }();
  return builder.create<LLVM::CallOp>(loc, function, arguments);
}

// Returns success if all operands are of LLVM type.
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
                                     ConversionPatternRewriter &rewriter) {
  if (!llvm::all_of(operands, [](Value value) {
        return LLVM::isCompatibleType(value.getType());
      }))
    return rewriter.notifyMatchFailure(
        op, "Cannot convert if operands aren't of LLVM type.");
  return success();
}

static LogicalResult
isAsyncWithOneDependency(ConversionPatternRewriter &rewriter,
                         gpu::AsyncOpInterface op) {
  if (op.getAsyncDependencies().size() != 1)
    return rewriter.notifyMatchFailure(
        op, "Can only convert with exactly one async dependency.");

  if (!op.getAsyncToken())
    return rewriter.notifyMatchFailure(op, "Can convert only async version.");

  return success();
}

LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto *op = hostRegisterOp.getOperation();
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
    return failure();

  Location loc = op->getLoc();

  auto memRefType = hostRegisterOp.value().getType();
  auto elementType = memRefType.cast<UnrankedMemRefType>().getElementType();
  auto elementSize = getSizeInBytes(loc, elementType, rewriter);

  auto arguments = getTypeConverter()->promoteOperands(
      loc, op->getOperands(), adaptor.getOperands(), rewriter);
  arguments.push_back(elementSize);
  hostRegisterCallBuilder.create(loc, rewriter, arguments);

  rewriter.eraseOp(op);
  return success();
}

LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::AllocOp allocOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  MemRefType memRefType = allocOp.getType();

  if (failed(areAllLLVMTypes(allocOp, adaptor.getOperands(), rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType) ||
      failed(isAsyncWithOneDependency(rewriter, allocOp)))
    return failure();

  auto loc = allocOp.getLoc();

  // Get shape of the memref as values: static sizes are constant
  // values and dynamic sizes are passed to 'alloc' as operands.
  SmallVector<Value, 4> shape;
  SmallVector<Value, 4> strides;
  Value sizeBytes;
  getMemRefDescriptorSizes(loc, memRefType, adaptor.dynamicSizes(), rewriter,
                           shape, strides, sizeBytes);

  // Allocate the underlying buffer and store a pointer to it in the MemRef
  // descriptor.
  Type elementPtrType = this->getElementPtrType(memRefType);
  auto stream = adaptor.asyncDependencies().front();
  Value allocatedPtr =
      allocCallBuilder.create(loc, rewriter, {sizeBytes, stream}).getResult(0);
  allocatedPtr =
      rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, allocatedPtr);

  // No alignment.
  Value alignedPtr = allocatedPtr;

  // Create the MemRef descriptor.
  auto memRefDescriptor = this->createMemRefDescriptor(
      loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter);

  rewriter.replaceOp(allocOp, {memRefDescriptor, stream});

  return success();
}

LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::DeallocOp deallocOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(deallocOp, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, deallocOp)))
    return failure();

  Location loc = deallocOp.getLoc();

  Value pointer =
      MemRefDescriptor(adaptor.memref()).allocatedPtr(rewriter, loc);
  auto casted = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pointer);
  Value stream = adaptor.asyncDependencies().front();
  deallocCallBuilder.create(loc, rewriter, {casted, stream});

  rewriter.replaceOp(deallocOp, {stream});
  return success();
}

static bool isGpuAsyncTokenType(Value value) {
  return value.getType().isa<gpu::AsyncTokenType>();
}

// Converts !gpu.async.token operands of `async.yield` to runtime calls. The
// !gpu.async.token values are lowered to streams within the async.execute
// region, but are passed between regions as events. For each !gpu.async.token
// operand, we create an event and record it on the stream.
LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
    async::YieldOp yieldOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (llvm::none_of(yieldOp.operands(), isGpuAsyncTokenType))
    return rewriter.notifyMatchFailure(yieldOp, "no gpu async token operand");

  Location loc = yieldOp.getLoc();
  SmallVector<Value, 4> newOperands(adaptor.getOperands());
  llvm::SmallDenseSet<Value> streams;
  for (auto &operand : yieldOp->getOpOperands()) {
    if (!isGpuAsyncTokenType(operand.get()))
      continue;
    auto idx = operand.getOperandNumber();
    auto stream = adaptor.getOperands()[idx];
    auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
    eventRecordCallBuilder.create(loc, rewriter, {event, stream});
    newOperands[idx] = event;
    streams.insert(stream);
  }
  for (auto stream : streams)
    streamDestroyCallBuilder.create(loc, rewriter, {stream});

  rewriter.updateRootInPlace(yieldOp,
                             [&] { yieldOp->setOperands(newOperands); });
  return success();
}

// Returns whether `value` is the result of an LLVM::CallOp to `functionName`.
static bool isDefinedByCallTo(Value value, StringRef functionName) {
  assert(value.getType().isa<LLVM::LLVMPointerType>());
  if (auto defOp = value.getDefiningOp<LLVM::CallOp>())
    return defOp.callee()->equals(functionName);
  return false;
}

// Converts `gpu.wait` to runtime calls. The converted op synchronizes the host
// with the stream/event operands. The operands are destroyed; that is, the
// lowering assumes they are not used afterwards or elsewhere. Otherwise we
// will get a runtime error. Eventually, we should guarantee this property.
LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::WaitOp waitOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (waitOp.asyncToken())
    return rewriter.notifyMatchFailure(waitOp, "Cannot convert async op.");

  Location loc = waitOp.getLoc();

  for (auto operand : adaptor.getOperands()) {
    if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
      // The converted operand's definition created a stream.
      streamSynchronizeCallBuilder.create(loc, rewriter, {operand});
      streamDestroyCallBuilder.create(loc, rewriter, {operand});
    } else {
      // Otherwise the converted operand is an event. This assumes that we use
      // events in control flow code as well.
      eventSynchronizeCallBuilder.create(loc, rewriter, {operand});
      eventDestroyCallBuilder.create(loc, rewriter, {operand});
    }
  }

  rewriter.eraseOp(waitOp);
  return success();
}

// Converts `gpu.wait async` to runtime calls. The converted op creates a new
// stream that is synchronized with the stream/event operands. The operands
// are destroyed; that is, the lowering assumes they are not used afterwards
// or elsewhere. Otherwise we will get a runtime error. Eventually, we should
// guarantee this property.
LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::WaitOp waitOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (!waitOp.asyncToken())
    return rewriter.notifyMatchFailure(waitOp, "Can only convert async op.");

  Location loc = waitOp.getLoc();

  auto insertionPoint = rewriter.saveInsertionPoint();
  SmallVector<Value, 1> events;
  for (auto pair :
       llvm::zip(waitOp.asyncDependencies(), adaptor.getOperands())) {
    auto operand = std::get<1>(pair);
    if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
      // The converted operand's definition created a stream. Insert an event
      // into the stream just after the last use of the original token operand.
      auto *defOp = std::get<0>(pair).getDefiningOp();
      rewriter.setInsertionPointAfter(defOp);
      auto event =
          eventCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
      eventRecordCallBuilder.create(loc, rewriter, {event, operand});
      events.push_back(event);
    } else {
      // Otherwise the converted operand is an event. This assumes that we use
      // events in control flow code as well.
      events.push_back(operand);
    }
  }
  rewriter.restoreInsertionPoint(insertionPoint);
  auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
  for (auto event : events)
    streamWaitEventCallBuilder.create(loc, rewriter, {stream, event});
  for (auto event : events)
    eventDestroyCallBuilder.create(loc, rewriter, {event});
  rewriter.replaceOp(waitOp, {stream});

  return success();
}

// Creates a struct containing all kernel parameters on the stack and returns
// an array of type-erased pointers to the fields of the struct. The array can
// then be passed to the CUDA / ROCm (HIP) kernel launch calls.
// The generated code is essentially as follows:
//
// %struct = alloca(sizeof(struct { Parameters... }))
// %array = alloca(NumParameters * sizeof(void *))
// for (i : [0, NumParameters))
//   %fieldPtr = llvm.getelementptr %struct[0, i]
//   llvm.store parameters[i], %fieldPtr
//   %elementPtr = llvm.getelementptr %array[i]
//   llvm.store %fieldPtr, %elementPtr
// return %array
Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateParamsArray(
    gpu::LaunchFuncOp launchOp, OpAdaptor adaptor, OpBuilder &builder) const {
  auto loc = launchOp.getLoc();
  auto numKernelOperands = launchOp.getNumKernelOperands();
  auto arguments = getTypeConverter()->promoteOperands(
      loc, launchOp.getOperands().take_back(numKernelOperands),
      adaptor.getOperands().take_back(numKernelOperands), builder);
  auto numArguments = arguments.size();
  SmallVector<Type, 4> argumentTypes;
  argumentTypes.reserve(numArguments);
  for (auto argument : arguments)
    argumentTypes.push_back(argument.getType());
  auto structType = LLVM::LLVMStructType::getNewIdentified(context, StringRef(),
                                                           argumentTypes);
  auto one = builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                              builder.getI32IntegerAttr(1));
  auto structPtr = builder.create<LLVM::AllocaOp>(
      loc, LLVM::LLVMPointerType::get(structType), one, /*alignment=*/0);
  auto arraySize = builder.create<LLVM::ConstantOp>(
      loc, llvmInt32Type, builder.getI32IntegerAttr(numArguments));
  auto arrayPtr = builder.create<LLVM::AllocaOp>(loc, llvmPointerPointerType,
                                                 arraySize, /*alignment=*/0);
  auto zero = builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                               builder.getI32IntegerAttr(0));
  for (auto en : llvm::enumerate(arguments)) {
    auto index = builder.create<LLVM::ConstantOp>(
        loc, llvmInt32Type, builder.getI32IntegerAttr(en.index()));
    auto fieldPtr = builder.create<LLVM::GEPOp>(
        loc, LLVM::LLVMPointerType::get(argumentTypes[en.index()]), structPtr,
        ArrayRef<Value>{zero, index.getResult()});
    builder.create<LLVM::StoreOp>(loc, en.value(), fieldPtr);
    auto elementPtr = builder.create<LLVM::GEPOp>(loc, llvmPointerPointerType,
                                                  arrayPtr, index.getResult());
    auto casted =
        builder.create<LLVM::BitcastOp>(loc, llvmPointerType, fieldPtr);
    builder.create<LLVM::StoreOp>(loc, casted, elementPtr);
  }
  return arrayPtr;
}

// Generates an LLVM IR dialect global that contains the name of the given
// kernel function as a C string, and returns a pointer to its beginning.
// The code is essentially:
//
// llvm.global constant @kernel_name("function_name\00")
// func(...) {
//   %0 = llvm.addressof @kernel_name
//   %1 = llvm.constant (0 : index)
//   %2 = llvm.getelementptr %0[%1, %1] : !llvm<"i8*">
// }
Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateKernelNameConstant(
    StringRef moduleName, StringRef name, Location loc,
    OpBuilder &builder) const {
  // Make sure the trailing zero is included in the constant.
  std::vector<char> kernelName(name.begin(), name.end());
  kernelName.push_back('\0');

  std::string globalName =
      std::string(llvm::formatv("{0}_{1}_kernel_name", moduleName, name));
  return LLVM::createGlobalString(
      loc, builder, globalName, StringRef(kernelName.data(), kernelName.size()),
      LLVM::Linkage::Internal);
}

// Emits LLVM IR to launch a kernel function. Expects the module that contains
// the compiled kernel function as a cubin in the 'nvvm.cubin' attribute, or a
// hsaco in the 'rocdl.hsaco' attribute of the kernel function in the IR.
//
// %0 = call %binarygetter
// %1 = call %moduleLoad(%0)
// %2 = <see generateKernelNameConstant>
// %3 = call %moduleGetFunction(%1, %2)
// %4 = call %streamCreate()
// %5 = <see generateParamsArray>
// call %launchKernel(%3, <launchOp operands 0..5>, 0, %4, %5, nullptr)
// call %streamSynchronize(%4)
// call %streamDestroy(%4)
// call %moduleUnload(%1)
//
// If the op is async, the stream corresponds to the (single) async dependency
// as well as the async token the op produces.
LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(launchOp, adaptor.getOperands(), rewriter)))
    return failure();

  if (launchOp.asyncDependencies().size() > 1)
    return rewriter.notifyMatchFailure(
        launchOp, "Cannot convert with more than one async dependency.");

  // Fail when the synchronous version of the op has async dependencies. The
  // lowering destroys the stream, and we do not want to check that there is no
  // use of the stream after this op.
  if (!launchOp.asyncToken() && !launchOp.asyncDependencies().empty())
    return rewriter.notifyMatchFailure(
        launchOp, "Cannot convert non-async op with async dependencies.");

  Location loc = launchOp.getLoc();

  // Create an LLVM global with CUBIN extracted from the kernel annotation and
  // obtain a pointer to the first byte in it.
  auto kernelModule = SymbolTable::lookupNearestSymbolFrom<gpu::GPUModuleOp>(
      launchOp, launchOp.getKernelModuleName());
  assert(kernelModule && "expected a kernel module");

  auto binaryAttr =
      kernelModule->getAttrOfType<StringAttr>(gpuBinaryAnnotation);
  if (!binaryAttr) {
    kernelModule.emitOpError()
        << "missing " << gpuBinaryAnnotation << " attribute";
    return failure();
  }

  SmallString<128> nameBuffer(kernelModule.getName());
  nameBuffer.append(kGpuBinaryStorageSuffix);
  Value data =
      LLVM::createGlobalString(loc, rewriter, nameBuffer.str(),
                               binaryAttr.getValue(), LLVM::Linkage::Internal);

  auto module = moduleLoadCallBuilder.create(loc, rewriter, data);
  // Get the function from the module. The name corresponds to the name of
  // the kernel function.
  auto kernelName = generateKernelNameConstant(
      launchOp.getKernelModuleName().getValue(),
      launchOp.getKernelName().getValue(), loc, rewriter);
  auto function = moduleGetFunctionCallBuilder.create(
      loc, rewriter, {module.getResult(0), kernelName});
  auto zero = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                                rewriter.getI32IntegerAttr(0));
  Value stream =
      adaptor.asyncDependencies().empty()
          ? streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0)
          : adaptor.asyncDependencies().front();
  // Create array of pointers to kernel arguments.
  auto kernelParams = generateParamsArray(launchOp, adaptor, rewriter);
  auto nullpointer = rewriter.create<LLVM::NullOp>(loc, llvmPointerPointerType);
  Value dynamicSharedMemorySize = launchOp.dynamicSharedMemorySize()
                                      ? launchOp.dynamicSharedMemorySize()
                                      : zero;
  launchKernelCallBuilder.create(
      loc, rewriter,
      {function.getResult(0), adaptor.gridSizeX(), adaptor.gridSizeY(),
       adaptor.gridSizeZ(), adaptor.blockSizeX(), adaptor.blockSizeY(),
       adaptor.blockSizeZ(), dynamicSharedMemorySize, stream, kernelParams,
       /*extra=*/nullpointer});

  if (launchOp.asyncToken()) {
    // Async launch: make dependent ops use the same stream.
    rewriter.replaceOp(launchOp, {stream});
  } else {
    // Synchronize with host and destroy stream. This must be the stream
    // created above (with no other uses) because we check that the synchronous
    // version does not have any async dependencies.
    streamSynchronizeCallBuilder.create(loc, rewriter, stream);
    streamDestroyCallBuilder.create(loc, rewriter, stream);
    rewriter.eraseOp(launchOp);
  }
  moduleUnloadCallBuilder.create(loc, rewriter, module.getResult(0));

  return success();
}

LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto memRefType = memcpyOp.src().getType().cast<MemRefType>();

  if (failed(areAllLLVMTypes(memcpyOp, adaptor.getOperands(), rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType) ||
      failed(isAsyncWithOneDependency(rewriter, memcpyOp)))
    return failure();

  auto loc = memcpyOp.getLoc();

  MemRefDescriptor srcDesc(adaptor.src());
  Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc);

  Type elementPtrType = getElementPtrType(memRefType);
  Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
  Value gepPtr = rewriter.create<LLVM::GEPOp>(
      loc, elementPtrType, ArrayRef<Value>{nullPtr, numElements});
  auto sizeBytes =
      rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);

  auto src = rewriter.create<LLVM::BitcastOp>(
      loc, llvmPointerType, srcDesc.alignedPtr(rewriter, loc));
  auto dst = rewriter.create<LLVM::BitcastOp>(
      loc, llvmPointerType,
      MemRefDescriptor(adaptor.dst()).alignedPtr(rewriter, loc));

  auto stream = adaptor.asyncDependencies().front();
  memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream});

  rewriter.replaceOp(memcpyOp, {stream});

  return success();
}

LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::MemsetOp memsetOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto memRefType = memsetOp.dst().getType().cast<MemRefType>();

  if (failed(areAllLLVMTypes(memsetOp, adaptor.getOperands(), rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType) ||
      failed(isAsyncWithOneDependency(rewriter, memsetOp)))
    return failure();

  auto loc = memsetOp.getLoc();

  Type valueType = adaptor.value().getType();
  if (!valueType.isIntOrFloat() || valueType.getIntOrFloatBitWidth() != 32) {
    return rewriter.notifyMatchFailure(memsetOp,
                                       "value must be a 32 bit scalar");
  }

  MemRefDescriptor dstDesc(adaptor.dst());
  Value numElements = getNumElements(rewriter, loc, memRefType, dstDesc);

  auto value =
      rewriter.create<LLVM::BitcastOp>(loc, llvmInt32Type, adaptor.value());
  auto dst = rewriter.create<LLVM::BitcastOp>(
      loc, llvmPointerType, dstDesc.alignedPtr(rewriter, loc));

  auto stream = adaptor.asyncDependencies().front();
  memsetCallBuilder.create(loc, rewriter, {dst, value, numElements, stream});

  rewriter.replaceOp(memsetOp, {stream});
  return success();
}

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
mlir::createGpuToLLVMConversionPass() {
  return std::make_unique<GpuToLLVMConversionPass>();
}

void mlir::populateGpuToLLVMConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
    StringRef gpuBinaryAnnotation) {
  converter.addConversion(
      [context = &converter.getContext()](gpu::AsyncTokenType type) -> Type {
        return LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
      });
  patterns.add<ConvertAllocOpToGpuRuntimeCallPattern,
               ConvertDeallocOpToGpuRuntimeCallPattern,
               ConvertHostRegisterOpToGpuRuntimeCallPattern,
               ConvertMemcpyOpToGpuRuntimeCallPattern,
               ConvertMemsetOpToGpuRuntimeCallPattern,
               ConvertWaitAsyncOpToGpuRuntimeCallPattern,
               ConvertWaitOpToGpuRuntimeCallPattern,
               ConvertAsyncYieldToGpuRuntimeCallPattern>(converter);
  patterns.add<ConvertLaunchFuncOpToGpuRuntimeCallPattern>(converter,
                                                           gpuBinaryAnnotation);
  patterns.add<EraseGpuModuleOpPattern>(&converter.getContext());
}