1 //===- ConvertLaunchFuncToGpuRuntimeCalls.cpp - MLIR GPU lowering passes --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a pass to convert the gpu.launch_func op into a
// sequence of GPU runtime calls. As most GPU runtimes do not have a stable
// published ABI, this pass uses a slim runtime layer that builds on top of
// the public API from GPU runtime headers.
13 //
14 //===----------------------------------------------------------------------===//
15
16 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
17
18 #include "../PassDetail.h"
19 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
20 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
21 #include "mlir/Conversion/LLVMCommon/Pattern.h"
22 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
23 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
24 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
25 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
26 #include "mlir/Dialect/Async/IR/Async.h"
27 #include "mlir/Dialect/GPU/GPUDialect.h"
28 #include "mlir/Dialect/GPU/Passes.h"
29 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
30 #include "mlir/IR/Attributes.h"
31 #include "mlir/IR/Builders.h"
32 #include "mlir/IR/BuiltinOps.h"
33 #include "mlir/IR/BuiltinTypes.h"
34
35 #include "llvm/ADT/STLExtras.h"
36 #include "llvm/Support/Error.h"
37 #include "llvm/Support/FormatVariadic.h"
38
39 using namespace mlir;
40
// Suffix appended to the GPU module name to form the name of the LLVM global
// constant that stores the serialized GPU binary (CUBIN / HSACO) blob.
static constexpr const char *kGpuBinaryStorageSuffix = "_gpubin_cst";
42
43 namespace {
44
/// Pass that lowers GPU dialect operations (gpu.launch_func, gpu.alloc,
/// gpu.memcpy, ...) to LLVM dialect calls into a slim GPU runtime wrapper
/// library, together with the standard/memref/vector/async lowerings needed
/// to produce a fully LLVM-dialect module.
class GpuToLLVMConversionPass
    : public GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
public:
  GpuToLLVMConversionPass() = default;

  // Copy constructor required so pass options survive pass cloning.
  GpuToLLVMConversionPass(const GpuToLLVMConversionPass &other)
      : GpuToLLVMConversionPassBase(other) {}

  // Run the dialect converter on the module.
  void runOnOperation() override;

private:
  // Name of the gpu.module attribute that carries the serialized GPU binary
  // (e.g. "nvvm.cubin" or "rocdl.hsaco"); configurable per target runtime.
  Option<std::string> gpuBinaryAnnotation{
      *this, "gpu-binary-annotation",
      llvm::cl::desc("Annotation attribute string for GPU binary"),
      llvm::cl::init(gpu::getDefaultGpuBinaryAnnotation())};
};
62
/// Helper that declares (on first use) and calls a runtime wrapper function
/// with a fixed name and signature.
struct FunctionCallBuilder {
  FunctionCallBuilder(StringRef functionName, Type returnType,
                      ArrayRef<Type> argumentTypes)
      : functionName(functionName),
        functionType(LLVM::LLVMFunctionType::get(returnType, argumentTypes)) {}
  /// Emits a call to `functionName` with `arguments`, inserting a declaration
  /// at the end of the enclosing module if none exists yet.
  LLVM::CallOp create(Location loc, OpBuilder &builder,
                      ArrayRef<Value> arguments) const;

  StringRef functionName;
  LLVM::LLVMFunctionType functionType;
};
74
/// Base class for the GPU-op lowering patterns below. Caches the LLVM dialect
/// types and the runtime-wrapper call builders that all derived patterns
/// share.
template <typename OpTy>
class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
public:
  explicit ConvertOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToLLVMPattern<OpTy>(typeConverter) {}

protected:
  /// Returns the total number of elements of the memref described by `desc`
  /// as an LLVM value suitable for runtime calls.
  Value getNumElements(ConversionPatternRewriter &rewriter, Location loc,
                       MemRefType type, MemRefDescriptor desc) const {
    return type.hasStaticShape()
               ? ConvertToLLVMPattern::createIndexConstant(
                     rewriter, loc, type.getNumElements())
               // For identity maps (verified by caller), the number of
               // elements is stride[0] * size[0].
               : rewriter.create<LLVM::MulOp>(loc,
                                              desc.stride(rewriter, loc, 0),
                                              desc.size(rewriter, loc, 0));
  }

  MLIRContext *context = &this->getTypeConverter()->getContext();

  // Frequently used LLVM dialect types, cached once per pattern.
  Type llvmVoidType = LLVM::LLVMVoidType::get(context);
  Type llvmPointerType =
      LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
  Type llvmPointerPointerType = LLVM::LLVMPointerType::get(llvmPointerType);
  Type llvmInt8Type = IntegerType::get(context, 8);
  Type llvmInt32Type = IntegerType::get(context, 32);
  Type llvmInt64Type = IntegerType::get(context, 64);
  // Integer type wide enough to hold a pointer on the target.
  Type llvmIntPtrType = IntegerType::get(
      context, this->getTypeConverter()->getPointerBitwidth(0));

  // Call builders for the "mgpu" runtime wrapper API. Each entry records the
  // wrapper's name and LLVM-level signature.
  FunctionCallBuilder moduleLoadCallBuilder = {
      "mgpuModuleLoad",
      llvmPointerType /* void *module */,
      {llvmPointerType /* void *cubin */}};
  FunctionCallBuilder moduleUnloadCallBuilder = {
      "mgpuModuleUnload", llvmVoidType, {llvmPointerType /* void *module */}};
  FunctionCallBuilder moduleGetFunctionCallBuilder = {
      "mgpuModuleGetFunction",
      llvmPointerType /* void *function */,
      {
          llvmPointerType, /* void *module */
          llvmPointerType  /* char *name   */
      }};
  FunctionCallBuilder launchKernelCallBuilder = {
      "mgpuLaunchKernel",
      llvmVoidType,
      {
          llvmPointerType,        /* void* f */
          llvmIntPtrType,         /* intptr_t gridXDim */
          llvmIntPtrType,         /* intptr_t gridyDim */
          llvmIntPtrType,         /* intptr_t gridZDim */
          llvmIntPtrType,         /* intptr_t blockXDim */
          llvmIntPtrType,         /* intptr_t blockYDim */
          llvmIntPtrType,         /* intptr_t blockZDim */
          llvmInt32Type,          /* unsigned int sharedMemBytes */
          llvmPointerType,        /* void *hstream */
          llvmPointerPointerType, /* void **kernelParams */
          llvmPointerPointerType  /* void **extra */
      }};
  FunctionCallBuilder streamCreateCallBuilder = {
      "mgpuStreamCreate", llvmPointerType /* void *stream */, {}};
  FunctionCallBuilder streamDestroyCallBuilder = {
      "mgpuStreamDestroy", llvmVoidType, {llvmPointerType /* void *stream */}};
  FunctionCallBuilder streamSynchronizeCallBuilder = {
      "mgpuStreamSynchronize",
      llvmVoidType,
      {llvmPointerType /* void *stream */}};
  FunctionCallBuilder streamWaitEventCallBuilder = {
      "mgpuStreamWaitEvent",
      llvmVoidType,
      {llvmPointerType /* void *stream */, llvmPointerType /* void *event */}};
  FunctionCallBuilder eventCreateCallBuilder = {
      "mgpuEventCreate", llvmPointerType /* void *event */, {}};
  FunctionCallBuilder eventDestroyCallBuilder = {
      "mgpuEventDestroy", llvmVoidType, {llvmPointerType /* void *event */}};
  FunctionCallBuilder eventSynchronizeCallBuilder = {
      "mgpuEventSynchronize",
      llvmVoidType,
      {llvmPointerType /* void *event */}};
  FunctionCallBuilder eventRecordCallBuilder = {
      "mgpuEventRecord",
      llvmVoidType,
      {llvmPointerType /* void *event */, llvmPointerType /* void *stream */}};
  FunctionCallBuilder hostRegisterCallBuilder = {
      "mgpuMemHostRegisterMemRef",
      llvmVoidType,
      {llvmIntPtrType /* intptr_t rank */,
       llvmPointerType /* void *memrefDesc */,
       llvmIntPtrType /* intptr_t elementSizeBytes */}};
  FunctionCallBuilder allocCallBuilder = {
      "mgpuMemAlloc",
      llvmPointerType /* void * */,
      {llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder deallocCallBuilder = {
      "mgpuMemFree",
      llvmVoidType,
      {llvmPointerType /* void *ptr */, llvmPointerType /* void *stream */}};
  FunctionCallBuilder memcpyCallBuilder = {
      "mgpuMemcpy",
      llvmVoidType,
      {llvmPointerType /* void *dst */, llvmPointerType /* void *src */,
       llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder memsetCallBuilder = {
      "mgpuMemset32",
      llvmVoidType,
      {llvmPointerType /* void *dst */, llvmInt32Type /* unsigned int value */,
       llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */}};
};
187
/// A rewrite pattern to convert gpu.host_register operations into a GPU
/// runtime call. Currently it supports CUDA and ROCm (HIP).
class ConvertHostRegisterOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> {
public:
  ConvertHostRegisterOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
201
/// A rewrite pattern to convert gpu.alloc operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertAllocOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp> {
public:
  ConvertAllocOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::AllocOp allocOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
215
/// A rewrite pattern to convert gpu.dealloc operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertDeallocOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp> {
public:
  ConvertDeallocOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::DeallocOp deallocOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
229
/// A rewrite pattern that converts !gpu.async.token operands of async.yield
/// ops into events recorded on the corresponding streams (see the
/// matchAndRewrite implementation for details).
class ConvertAsyncYieldToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<async::YieldOp> {
public:
  ConvertAsyncYieldToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<async::YieldOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(async::YieldOp yieldOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
241
/// A rewrite pattern to convert (synchronous) gpu.wait operations into GPU
/// runtime calls. Currently it supports CUDA and ROCm (HIP).
class ConvertWaitOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
public:
  ConvertWaitOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
255
/// A rewrite pattern to convert gpu.wait async operations into GPU runtime
/// calls. Currently it supports CUDA and ROCm (HIP).
class ConvertWaitAsyncOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
public:
  ConvertWaitAsyncOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
269
/// A rewrite pattern to convert gpu.launch_func operations into a sequence of
/// GPU runtime calls. Currently it supports CUDA and ROCm (HIP).
///
/// In essence, a gpu.launch_func operation gets compiled into the following
/// sequence of runtime calls:
///
/// * moduleLoad        -- loads the module given the cubin / hsaco data
/// * moduleGetFunction -- gets a handle to the actual kernel function
/// * getStreamHelper   -- initializes a new compute stream on GPU
/// * launchKernel      -- launches the kernel on a stream
/// * streamSynchronize -- waits for operations on the stream to finish
///
/// Intermediate data structures are allocated on the stack.
class ConvertLaunchFuncOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
public:
  ConvertLaunchFuncOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter,
                                             StringRef gpuBinaryAnnotation)
      : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
        gpuBinaryAnnotation(gpuBinaryAnnotation) {}

private:
  // Builds the stack-allocated array of type-erased kernel parameter pointers.
  Value generateParamsArray(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
                            OpBuilder &builder) const;
  // Emits a global C-string constant holding the kernel's name.
  Value generateKernelNameConstant(StringRef moduleName, StringRef name,
                                   Location loc, OpBuilder &builder) const;

  LogicalResult
  matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;

  // Attribute name on the gpu.module that carries the serialized binary.
  llvm::SmallString<32> gpuBinaryAnnotation;
};
303
/// Erases gpu.module ops: once the kernel binary has been embedded as a
/// global string constant, the kernel module itself is dead.
class EraseGpuModuleOpPattern : public OpRewritePattern<gpu::GPUModuleOp> {
  using OpRewritePattern<gpu::GPUModuleOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(gpu::GPUModuleOp op,
                                PatternRewriter &rewriter) const override {
    // GPU kernel modules are no longer necessary since we have a global
    // constant with the CUBIN, or HSACO data.
    rewriter.eraseOp(op);
    return success();
  }
};
315
/// A rewrite pattern to convert gpu.memcpy operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertMemcpyOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp> {
public:
  ConvertMemcpyOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
329
/// A rewrite pattern to convert gpu.memset operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertMemsetOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp> {
public:
  ConvertMemsetOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::MemsetOp memsetOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
343 } // namespace
344
runOnOperation()345 void GpuToLLVMConversionPass::runOnOperation() {
346 LLVMTypeConverter converter(&getContext());
347 RewritePatternSet patterns(&getContext());
348 LLVMConversionTarget target(getContext());
349
350 target.addIllegalDialect<gpu::GPUDialect>();
351
352 populateVectorToLLVMConversionPatterns(converter, patterns);
353 populateMemRefToLLVMConversionPatterns(converter, patterns);
354 populateStdToLLVMConversionPatterns(converter, patterns);
355 populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
356 target);
357 populateGpuToLLVMConversionPatterns(converter, patterns, gpuBinaryAnnotation);
358
359 if (failed(
360 applyPartialConversion(getOperation(), target, std::move(patterns))))
361 signalPassFailure();
362 }
363
create(Location loc,OpBuilder & builder,ArrayRef<Value> arguments) const364 LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
365 ArrayRef<Value> arguments) const {
366 auto module = builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
367 auto function = [&] {
368 if (auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(functionName))
369 return function;
370 return OpBuilder::atBlockEnd(module.getBody())
371 .create<LLVM::LLVMFuncOp>(loc, functionName, functionType);
372 }();
373 return builder.create<LLVM::CallOp>(loc, function, arguments);
374 }
375
376 // Returns whether all operands are of LLVM type.
areAllLLVMTypes(Operation * op,ValueRange operands,ConversionPatternRewriter & rewriter)377 static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
378 ConversionPatternRewriter &rewriter) {
379 if (!llvm::all_of(operands, [](Value value) {
380 return LLVM::isCompatibleType(value.getType());
381 }))
382 return rewriter.notifyMatchFailure(
383 op, "Cannot convert if operands aren't of LLVM type.");
384 return success();
385 }
386
387 static LogicalResult
isAsyncWithOneDependency(ConversionPatternRewriter & rewriter,gpu::AsyncOpInterface op)388 isAsyncWithOneDependency(ConversionPatternRewriter &rewriter,
389 gpu::AsyncOpInterface op) {
390 if (op.getAsyncDependencies().size() != 1)
391 return rewriter.notifyMatchFailure(
392 op, "Can only convert with exactly one async dependency.");
393
394 if (!op.getAsyncToken())
395 return rewriter.notifyMatchFailure(op, "Can convert only async version.");
396
397 return success();
398 }
399
matchAndRewrite(gpu::HostRegisterOp hostRegisterOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const400 LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
401 gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
402 ConversionPatternRewriter &rewriter) const {
403 auto *op = hostRegisterOp.getOperation();
404 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
405 return failure();
406
407 Location loc = op->getLoc();
408
409 auto memRefType = hostRegisterOp.value().getType();
410 auto elementType = memRefType.cast<UnrankedMemRefType>().getElementType();
411 auto elementSize = getSizeInBytes(loc, elementType, rewriter);
412
413 auto arguments = getTypeConverter()->promoteOperands(
414 loc, op->getOperands(), adaptor.getOperands(), rewriter);
415 arguments.push_back(elementSize);
416 hostRegisterCallBuilder.create(loc, rewriter, arguments);
417
418 rewriter.eraseOp(op);
419 return success();
420 }
421
matchAndRewrite(gpu::AllocOp allocOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const422 LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
423 gpu::AllocOp allocOp, OpAdaptor adaptor,
424 ConversionPatternRewriter &rewriter) const {
425 MemRefType memRefType = allocOp.getType();
426
427 if (failed(areAllLLVMTypes(allocOp, adaptor.getOperands(), rewriter)) ||
428 !isConvertibleAndHasIdentityMaps(memRefType) ||
429 failed(isAsyncWithOneDependency(rewriter, allocOp)))
430 return failure();
431
432 auto loc = allocOp.getLoc();
433
434 // Get shape of the memref as values: static sizes are constant
435 // values and dynamic sizes are passed to 'alloc' as operands.
436 SmallVector<Value, 4> shape;
437 SmallVector<Value, 4> strides;
438 Value sizeBytes;
439 getMemRefDescriptorSizes(loc, memRefType, adaptor.dynamicSizes(), rewriter,
440 shape, strides, sizeBytes);
441
442 // Allocate the underlying buffer and store a pointer to it in the MemRef
443 // descriptor.
444 Type elementPtrType = this->getElementPtrType(memRefType);
445 auto stream = adaptor.asyncDependencies().front();
446 Value allocatedPtr =
447 allocCallBuilder.create(loc, rewriter, {sizeBytes, stream}).getResult(0);
448 allocatedPtr =
449 rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, allocatedPtr);
450
451 // No alignment.
452 Value alignedPtr = allocatedPtr;
453
454 // Create the MemRef descriptor.
455 auto memRefDescriptor = this->createMemRefDescriptor(
456 loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter);
457
458 rewriter.replaceOp(allocOp, {memRefDescriptor, stream});
459
460 return success();
461 }
462
matchAndRewrite(gpu::DeallocOp deallocOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const463 LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
464 gpu::DeallocOp deallocOp, OpAdaptor adaptor,
465 ConversionPatternRewriter &rewriter) const {
466 if (failed(areAllLLVMTypes(deallocOp, adaptor.getOperands(), rewriter)) ||
467 failed(isAsyncWithOneDependency(rewriter, deallocOp)))
468 return failure();
469
470 Location loc = deallocOp.getLoc();
471
472 Value pointer =
473 MemRefDescriptor(adaptor.memref()).allocatedPtr(rewriter, loc);
474 auto casted = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pointer);
475 Value stream = adaptor.asyncDependencies().front();
476 deallocCallBuilder.create(loc, rewriter, {casted, stream});
477
478 rewriter.replaceOp(deallocOp, {stream});
479 return success();
480 }
481
isGpuAsyncTokenType(Value value)482 static bool isGpuAsyncTokenType(Value value) {
483 return value.getType().isa<gpu::AsyncTokenType>();
484 }
485
486 // Converts !gpu.async.token operands of `async.yield` to runtime calls. The
487 // !gpu.async.token are lowered to stream within the async.execute region, but
488 // are passed as events between them. For each !gpu.async.token operand, we
489 // create an event and record it on the stream.
matchAndRewrite(async::YieldOp yieldOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const490 LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
491 async::YieldOp yieldOp, OpAdaptor adaptor,
492 ConversionPatternRewriter &rewriter) const {
493 if (llvm::none_of(yieldOp.operands(), isGpuAsyncTokenType))
494 return rewriter.notifyMatchFailure(yieldOp, "no gpu async token operand");
495
496 Location loc = yieldOp.getLoc();
497 SmallVector<Value, 4> newOperands(adaptor.getOperands());
498 llvm::SmallDenseSet<Value> streams;
499 for (auto &operand : yieldOp->getOpOperands()) {
500 if (!isGpuAsyncTokenType(operand.get()))
501 continue;
502 auto idx = operand.getOperandNumber();
503 auto stream = adaptor.getOperands()[idx];
504 auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
505 eventRecordCallBuilder.create(loc, rewriter, {event, stream});
506 newOperands[idx] = event;
507 streams.insert(stream);
508 }
509 for (auto stream : streams)
510 streamDestroyCallBuilder.create(loc, rewriter, {stream});
511
512 rewriter.updateRootInPlace(yieldOp,
513 [&] { yieldOp->setOperands(newOperands); });
514 return success();
515 }
516
517 // Returns whether `value` is the result of an LLVM::CallOp to `functionName`.
isDefinedByCallTo(Value value,StringRef functionName)518 static bool isDefinedByCallTo(Value value, StringRef functionName) {
519 assert(value.getType().isa<LLVM::LLVMPointerType>());
520 if (auto defOp = value.getDefiningOp<LLVM::CallOp>())
521 return defOp.callee()->equals(functionName);
522 return false;
523 }
524
525 // Converts `gpu.wait` to runtime calls. The converted op synchronizes the host
526 // with the stream/event operands. The operands are destroyed. That is, it
527 // assumes that it is not used afterwards or elsewhere. Otherwise we will get a
528 // runtime error. Eventually, we should guarantee this property.
matchAndRewrite(gpu::WaitOp waitOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const529 LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
530 gpu::WaitOp waitOp, OpAdaptor adaptor,
531 ConversionPatternRewriter &rewriter) const {
532 if (waitOp.asyncToken())
533 return rewriter.notifyMatchFailure(waitOp, "Cannot convert async op.");
534
535 Location loc = waitOp.getLoc();
536
537 for (auto operand : adaptor.getOperands()) {
538 if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
539 // The converted operand's definition created a stream.
540 streamSynchronizeCallBuilder.create(loc, rewriter, {operand});
541 streamDestroyCallBuilder.create(loc, rewriter, {operand});
542 } else {
543 // Otherwise the converted operand is an event. This assumes that we use
544 // events in control flow code as well.
545 eventSynchronizeCallBuilder.create(loc, rewriter, {operand});
546 eventDestroyCallBuilder.create(loc, rewriter, {operand});
547 }
548 }
549
550 rewriter.eraseOp(waitOp);
551 return success();
552 }
553
// Converts `gpu.wait async` to runtime calls. The converted op creates a new
// stream that is synchronized with stream/event operands. The operands are
// destroyed. That is, it assumes that it is not used afterwards or elsewhere.
// Otherwise we will get a runtime error. Eventually, we should guarantee this
// property.
LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::WaitOp waitOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (!waitOp.asyncToken())
    return rewriter.notifyMatchFailure(waitOp, "Can only convert async op.");

  Location loc = waitOp.getLoc();

  // Event-record calls are inserted next to each dependency's defining op,
  // so remember the current point to come back to afterwards.
  auto insertionPoint = rewriter.saveInsertionPoint();
  SmallVector<Value, 1> events;
  for (auto pair :
       llvm::zip(waitOp.asyncDependencies(), adaptor.getOperands())) {
    auto operand = std::get<1>(pair);
    if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
      // The converted operand's definition created a stream. Insert an event
      // into the stream just after the last use of the original token operand.
      auto *defOp = std::get<0>(pair).getDefiningOp();
      rewriter.setInsertionPointAfter(defOp);
      auto event =
          eventCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
      eventRecordCallBuilder.create(loc, rewriter, {event, operand});
      events.push_back(event);
    } else {
      // Otherwise the converted operand is an event. This assumes that we use
      // events in control flow code as well.
      events.push_back(operand);
    }
  }
  rewriter.restoreInsertionPoint(insertionPoint);
  // Create the result stream and make it wait on every collected event, then
  // release the events (the wait has already been enqueued).
  auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
  for (auto event : events)
    streamWaitEventCallBuilder.create(loc, rewriter, {stream, event});
  for (auto event : events)
    eventDestroyCallBuilder.create(loc, rewriter, {event});
  rewriter.replaceOp(waitOp, {stream});

  return success();
}
597
// Creates a struct containing all kernel parameters on the stack and returns
// an array of type-erased pointers to the fields of the struct. The array can
// then be passed to the CUDA / ROCm (HIP) kernel launch calls.
// The generated code is essentially as follows:
//
// %struct = alloca(sizeof(struct { Parameters... }))
// %array = alloca(NumParameters * sizeof(void *))
// for (i : [0, NumParameters))
//   %fieldPtr = llvm.getelementptr %struct[0, i]
//   llvm.store parameters[i], %fieldPtr
//   %elementPtr = llvm.getelementptr %array[i]
//   llvm.store %fieldPtr, %elementPtr
// return %array
Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateParamsArray(
    gpu::LaunchFuncOp launchOp, OpAdaptor adaptor, OpBuilder &builder) const {
  auto loc = launchOp.getLoc();
  // Only the trailing operands are kernel arguments (the leading ones are
  // grid/block sizes and async dependencies). Promote them to memory where
  // the calling convention requires it.
  auto numKernelOperands = launchOp.getNumKernelOperands();
  auto arguments = getTypeConverter()->promoteOperands(
      loc, launchOp.getOperands().take_back(numKernelOperands),
      adaptor.getOperands().take_back(numKernelOperands), builder);
  auto numArguments = arguments.size();
  SmallVector<Type, 4> argumentTypes;
  argumentTypes.reserve(numArguments);
  for (auto argument : arguments)
    argumentTypes.push_back(argument.getType());
  // Anonymous identified struct holding one field per kernel argument.
  auto structType = LLVM::LLVMStructType::getNewIdentified(context, StringRef(),
                                                           argumentTypes);
  auto one = builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                              builder.getI32IntegerAttr(1));
  auto structPtr = builder.create<LLVM::AllocaOp>(
      loc, LLVM::LLVMPointerType::get(structType), one, /*alignment=*/0);
  auto arraySize = builder.create<LLVM::ConstantOp>(
      loc, llvmInt32Type, builder.getI32IntegerAttr(numArguments));
  auto arrayPtr = builder.create<LLVM::AllocaOp>(loc, llvmPointerPointerType,
                                                 arraySize, /*alignment=*/0);
  auto zero = builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                               builder.getI32IntegerAttr(0));
  // Store each argument into its struct field, then store a type-erased
  // pointer to that field into the i-th slot of the array.
  for (auto en : llvm::enumerate(arguments)) {
    auto index = builder.create<LLVM::ConstantOp>(
        loc, llvmInt32Type, builder.getI32IntegerAttr(en.index()));
    auto fieldPtr = builder.create<LLVM::GEPOp>(
        loc, LLVM::LLVMPointerType::get(argumentTypes[en.index()]), structPtr,
        ArrayRef<Value>{zero, index.getResult()});
    builder.create<LLVM::StoreOp>(loc, en.value(), fieldPtr);
    auto elementPtr = builder.create<LLVM::GEPOp>(loc, llvmPointerPointerType,
                                                  arrayPtr, index.getResult());
    auto casted =
        builder.create<LLVM::BitcastOp>(loc, llvmPointerType, fieldPtr);
    builder.create<LLVM::StoreOp>(loc, casted, elementPtr);
  }
  return arrayPtr;
}
650
651 // Generates an LLVM IR dialect global that contains the name of the given
652 // kernel function as a C string, and returns a pointer to its beginning.
653 // The code is essentially:
654 //
655 // llvm.global constant @kernel_name("function_name\00")
656 // func(...) {
657 // %0 = llvm.addressof @kernel_name
658 // %1 = llvm.constant (0 : index)
659 // %2 = llvm.getelementptr %0[%1, %1] : !llvm<"i8*">
660 // }
generateKernelNameConstant(StringRef moduleName,StringRef name,Location loc,OpBuilder & builder) const661 Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateKernelNameConstant(
662 StringRef moduleName, StringRef name, Location loc,
663 OpBuilder &builder) const {
664 // Make sure the trailing zero is included in the constant.
665 std::vector<char> kernelName(name.begin(), name.end());
666 kernelName.push_back('\0');
667
668 std::string globalName =
669 std::string(llvm::formatv("{0}_{1}_kernel_name", moduleName, name));
670 return LLVM::createGlobalString(
671 loc, builder, globalName, StringRef(kernelName.data(), kernelName.size()),
672 LLVM::Linkage::Internal);
673 }
674
// Emits LLVM IR to launch a kernel function. Expects the module that contains
// the compiled kernel function as a cubin in the 'nvvm.cubin' attribute, or a
// hsaco in the 'rocdl.hsaco' attribute of the kernel function in the IR.
//
// %0 = call %binarygetter
// %1 = call %moduleLoad(%0)
// %2 = <see generateKernelNameConstant>
// %3 = call %moduleGetFunction(%1, %2)
// %4 = call %streamCreate()
// %5 = <see generateParamsArray>
// call %launchKernel(%3, <launchOp operands 0..5>, 0, %4, %5, nullptr)
// call %streamSynchronize(%4)
// call %streamDestroy(%4)
// call %moduleUnload(%1)
//
// If the op is async, the stream corresponds to the (single) async dependency
// as well as the async token the op produces.
LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // All operands must already have been converted to LLVM dialect types.
  if (failed(areAllLLVMTypes(launchOp, adaptor.getOperands(), rewriter)))
    return failure();

  // A single async dependency maps directly onto the stream; more than one
  // cannot be expressed with one stream argument.
  if (launchOp.asyncDependencies().size() > 1)
    return rewriter.notifyMatchFailure(
        launchOp, "Cannot convert with more than one async dependency.");

  // Fail when the synchronous version of the op has async dependencies. The
  // lowering destroys the stream, and we do not want to check that there is no
  // use of the stream after this op.
  if (!launchOp.asyncToken() && !launchOp.asyncDependencies().empty())
    return rewriter.notifyMatchFailure(
        launchOp, "Cannot convert non-async op with async dependencies.");

  Location loc = launchOp.getLoc();

  // Create an LLVM global with CUBIN extracted from the kernel annotation and
  // obtain a pointer to the first byte in it.
  auto kernelModule = SymbolTable::lookupNearestSymbolFrom<gpu::GPUModuleOp>(
      launchOp, launchOp.getKernelModuleName());
  assert(kernelModule && "expected a kernel module");

  // The device binary was attached by an earlier pass under the (configurable)
  // annotation name, e.g. 'nvvm.cubin' or 'rocdl.hsaco'.
  auto binaryAttr =
      kernelModule->getAttrOfType<StringAttr>(gpuBinaryAnnotation);
  if (!binaryAttr) {
    kernelModule.emitOpError()
        << "missing " << gpuBinaryAnnotation << " attribute";
    return failure();
  }

  // Embed the binary as a constant named "<kernel-module-name>_gpubin_cst".
  SmallString<128> nameBuffer(kernelModule.getName());
  nameBuffer.append(kGpuBinaryStorageSuffix);
  Value data =
      LLVM::createGlobalString(loc, rewriter, nameBuffer.str(),
                               binaryAttr.getValue(), LLVM::Linkage::Internal);

  auto module = moduleLoadCallBuilder.create(loc, rewriter, data);
  // Get the function from the module. The name corresponds to the name of
  // the kernel function.
  auto kernelName = generateKernelNameConstant(
      launchOp.getKernelModuleName().getValue(),
      launchOp.getKernelName().getValue(), loc, rewriter);
  auto function = moduleGetFunctionCallBuilder.create(
      loc, rewriter, {module.getResult(0), kernelName});
  auto zero = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                                rewriter.getI32IntegerAttr(0));
  // Reuse the stream of the single async dependency if present; otherwise
  // create a fresh stream (destroyed again below in the synchronous case).
  Value stream =
      adaptor.asyncDependencies().empty()
          ? streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0)
          : adaptor.asyncDependencies().front();
  // Create array of pointers to kernel arguments.
  auto kernelParams = generateParamsArray(launchOp, adaptor, rewriter);
  auto nullpointer = rewriter.create<LLVM::NullOp>(loc, llvmPointerPointerType);
  // Default to zero bytes of dynamic shared memory when none was requested.
  Value dynamicSharedMemorySize = launchOp.dynamicSharedMemorySize()
                                      ? launchOp.dynamicSharedMemorySize()
                                      : zero;
  launchKernelCallBuilder.create(
      loc, rewriter,
      {function.getResult(0), adaptor.gridSizeX(), adaptor.gridSizeY(),
       adaptor.gridSizeZ(), adaptor.blockSizeX(), adaptor.blockSizeY(),
       adaptor.blockSizeZ(), dynamicSharedMemorySize, stream, kernelParams,
       /*extra=*/nullpointer});

  if (launchOp.asyncToken()) {
    // Async launch: make dependent ops use the same stream.
    rewriter.replaceOp(launchOp, {stream});
  } else {
    // Synchronize with host and destroy stream. This must be the stream created
    // above (with no other uses) because we check that the synchronous version
    // does not have any async dependencies.
    streamSynchronizeCallBuilder.create(loc, rewriter, stream);
    streamDestroyCallBuilder.create(loc, rewriter, stream);
    rewriter.eraseOp(launchOp);
  }
  // NOTE(review): the unload call is emitted unconditionally right after the
  // launch; in the async case the kernel may still be executing on the stream
  // at that point -- confirm the runtime layer tolerates (or defers) unloading
  // an in-use module.
  moduleUnloadCallBuilder.create(loc, rewriter, module.getResult(0));

  return success();
}
773
matchAndRewrite(gpu::MemcpyOp memcpyOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const774 LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
775 gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
776 ConversionPatternRewriter &rewriter) const {
777 auto memRefType = memcpyOp.src().getType().cast<MemRefType>();
778
779 if (failed(areAllLLVMTypes(memcpyOp, adaptor.getOperands(), rewriter)) ||
780 !isConvertibleAndHasIdentityMaps(memRefType) ||
781 failed(isAsyncWithOneDependency(rewriter, memcpyOp)))
782 return failure();
783
784 auto loc = memcpyOp.getLoc();
785
786 MemRefDescriptor srcDesc(adaptor.src());
787 Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc);
788
789 Type elementPtrType = getElementPtrType(memRefType);
790 Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
791 Value gepPtr = rewriter.create<LLVM::GEPOp>(
792 loc, elementPtrType, ArrayRef<Value>{nullPtr, numElements});
793 auto sizeBytes =
794 rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
795
796 auto src = rewriter.create<LLVM::BitcastOp>(
797 loc, llvmPointerType, srcDesc.alignedPtr(rewriter, loc));
798 auto dst = rewriter.create<LLVM::BitcastOp>(
799 loc, llvmPointerType,
800 MemRefDescriptor(adaptor.dst()).alignedPtr(rewriter, loc));
801
802 auto stream = adaptor.asyncDependencies().front();
803 memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream});
804
805 rewriter.replaceOp(memcpyOp, {stream});
806
807 return success();
808 }
809
matchAndRewrite(gpu::MemsetOp memsetOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const810 LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
811 gpu::MemsetOp memsetOp, OpAdaptor adaptor,
812 ConversionPatternRewriter &rewriter) const {
813 auto memRefType = memsetOp.dst().getType().cast<MemRefType>();
814
815 if (failed(areAllLLVMTypes(memsetOp, adaptor.getOperands(), rewriter)) ||
816 !isConvertibleAndHasIdentityMaps(memRefType) ||
817 failed(isAsyncWithOneDependency(rewriter, memsetOp)))
818 return failure();
819
820 auto loc = memsetOp.getLoc();
821
822 Type valueType = adaptor.value().getType();
823 if (!valueType.isIntOrFloat() || valueType.getIntOrFloatBitWidth() != 32) {
824 return rewriter.notifyMatchFailure(memsetOp,
825 "value must be a 32 bit scalar");
826 }
827
828 MemRefDescriptor dstDesc(adaptor.dst());
829 Value numElements = getNumElements(rewriter, loc, memRefType, dstDesc);
830
831 auto value =
832 rewriter.create<LLVM::BitcastOp>(loc, llvmInt32Type, adaptor.value());
833 auto dst = rewriter.create<LLVM::BitcastOp>(
834 loc, llvmPointerType, dstDesc.alignedPtr(rewriter, loc));
835
836 auto stream = adaptor.asyncDependencies().front();
837 memsetCallBuilder.create(loc, rewriter, {dst, value, numElements, stream});
838
839 rewriter.replaceOp(memsetOp, {stream});
840 return success();
841 }
842
843 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
createGpuToLLVMConversionPass()844 mlir::createGpuToLLVMConversionPass() {
845 return std::make_unique<GpuToLLVMConversionPass>();
846 }
847
populateGpuToLLVMConversionPatterns(LLVMTypeConverter & converter,OwningRewritePatternList & patterns,StringRef gpuBinaryAnnotation)848 void mlir::populateGpuToLLVMConversionPatterns(
849 LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
850 StringRef gpuBinaryAnnotation) {
851 converter.addConversion(
852 [context = &converter.getContext()](gpu::AsyncTokenType type) -> Type {
853 return LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
854 });
855 patterns.add<ConvertAllocOpToGpuRuntimeCallPattern,
856 ConvertDeallocOpToGpuRuntimeCallPattern,
857 ConvertHostRegisterOpToGpuRuntimeCallPattern,
858 ConvertMemcpyOpToGpuRuntimeCallPattern,
859 ConvertMemsetOpToGpuRuntimeCallPattern,
860 ConvertWaitAsyncOpToGpuRuntimeCallPattern,
861 ConvertWaitOpToGpuRuntimeCallPattern,
862 ConvertAsyncYieldToGpuRuntimeCallPattern>(converter);
863 patterns.add<ConvertLaunchFuncOpToGpuRuntimeCallPattern>(converter,
864 gpuBinaryAnnotation);
865 patterns.add<EraseGpuModuleOpPattern>(&converter.getContext());
866 }
867