1 //===- VectorToROCDL.cpp - Vector to ROCDL lowering passes ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to generate ROCDLIR operations for higher-level
10 // Vector operations.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/Conversion/VectorToROCDL/VectorToROCDL.h"
15
16 #include "../PassDetail.h"
17 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
18 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
19 #include "mlir/Dialect/GPU/GPUDialect.h"
20 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
21 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
22 #include "mlir/Dialect/StandardOps/IR/Ops.h"
23 #include "mlir/Dialect/Vector/VectorOps.h"
24 #include "mlir/Pass/Pass.h"
25 #include "mlir/Transforms/DialectConversion.h"
26
27 using namespace mlir;
28 using namespace mlir::vector;
29
replaceTransferOpWithMubuf(ConversionPatternRewriter & rewriter,ArrayRef<Value> operands,LLVMTypeConverter & typeConverter,Location loc,TransferReadOp xferOp,Type & vecTy,Value & dwordConfig,Value & vindex,Value & offsetSizeInBytes,Value & glc,Value & slc)30 static LogicalResult replaceTransferOpWithMubuf(
31 ConversionPatternRewriter &rewriter, ArrayRef<Value> operands,
32 LLVMTypeConverter &typeConverter, Location loc, TransferReadOp xferOp,
33 Type &vecTy, Value &dwordConfig, Value &vindex, Value &offsetSizeInBytes,
34 Value &glc, Value &slc) {
35 rewriter.replaceOpWithNewOp<ROCDL::MubufLoadOp>(
36 xferOp, vecTy, dwordConfig, vindex, offsetSizeInBytes, glc, slc);
37 return success();
38 }
39
replaceTransferOpWithMubuf(ConversionPatternRewriter & rewriter,ArrayRef<Value> operands,LLVMTypeConverter & typeConverter,Location loc,TransferWriteOp xferOp,Type & vecTy,Value & dwordConfig,Value & vindex,Value & offsetSizeInBytes,Value & glc,Value & slc)40 static LogicalResult replaceTransferOpWithMubuf(
41 ConversionPatternRewriter &rewriter, ArrayRef<Value> operands,
42 LLVMTypeConverter &typeConverter, Location loc, TransferWriteOp xferOp,
43 Type &vecTy, Value &dwordConfig, Value &vindex, Value &offsetSizeInBytes,
44 Value &glc, Value &slc) {
45 auto adaptor = TransferWriteOpAdaptor(operands);
46 rewriter.replaceOpWithNewOp<ROCDL::MubufStoreOp>(xferOp, adaptor.vector(),
47 dwordConfig, vindex,
48 offsetSizeInBytes, glc, slc);
49 return success();
50 }
51
52 namespace {
53 /// Conversion pattern that converts a 1-D vector transfer read/write.
54 /// Note that this conversion pass only converts vector x2 or x4 f32
55 /// types. For unsupported cases, they will fall back to the vector to
56 /// llvm conversion pattern.
57 template <typename ConcreteOp>
58 class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
59 public:
60 using ConvertOpToLLVMPattern<ConcreteOp>::ConvertOpToLLVMPattern;
61
62 LogicalResult
matchAndRewrite(ConcreteOp xferOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const63 matchAndRewrite(ConcreteOp xferOp, ArrayRef<Value> operands,
64 ConversionPatternRewriter &rewriter) const override {
65 typename ConcreteOp::Adaptor adaptor(operands);
66
67 if (xferOp.getVectorType().getRank() > 1 ||
68 llvm::size(xferOp.indices()) == 0)
69 return failure();
70
71 if (!xferOp.permutation_map().isMinorIdentity())
72 return failure();
73
74 // Have it handled in vector->llvm conversion pass.
75 if (!xferOp.isMaskedDim(0))
76 return failure();
77
78 auto toLLVMTy = [&](Type t) {
79 return this->getTypeConverter()->convertType(t);
80 };
81 auto vecTy = toLLVMTy(xferOp.getVectorType());
82 unsigned vecWidth = LLVM::getVectorNumElements(vecTy).getFixedValue();
83 Location loc = xferOp->getLoc();
84
85 // The backend result vector scalarization have trouble scalarize
86 // <1 x ty> result, exclude the x1 width from the lowering.
87 if (vecWidth != 2 && vecWidth != 4)
88 return failure();
89
90 // Obtain dataPtr and elementType from the memref.
91 auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
92 if (!memRefType)
93 return failure();
94 // MUBUF instruction operate only on addresspace 0(unified) or 1(global)
95 // In case of 3(LDS): fall back to vector->llvm pass
96 // In case of 5(VGPR): wrong
97 if ((memRefType.getMemorySpace() != 0) &&
98 (memRefType.getMemorySpace() != 1))
99 return failure();
100
101 // Note that the dataPtr starts at the offset address specified by
102 // indices, so no need to calculate offset size in bytes again in
103 // the MUBUF instruction.
104 Value dataPtr = this->getStridedElementPtr(
105 loc, memRefType, adaptor.source(), adaptor.indices(), rewriter);
106
107 // 1. Create and fill a <4 x i32> dwordConfig with:
108 // 1st two elements holding the address of dataPtr.
109 // 3rd element: -1.
110 // 4th element: 0x27000.
111 SmallVector<int32_t, 4> constConfigAttr{0, 0, -1, 0x27000};
112 Type i32Ty = rewriter.getIntegerType(32);
113 VectorType i32Vecx4 = VectorType::get(4, i32Ty);
114 Value constConfig = rewriter.create<LLVM::ConstantOp>(
115 loc, toLLVMTy(i32Vecx4),
116 DenseElementsAttr::get(i32Vecx4, ArrayRef<int32_t>(constConfigAttr)));
117
118 // Treat first two element of <4 x i32> as i64, and save the dataPtr
119 // to it.
120 Type i64Ty = rewriter.getIntegerType(64);
121 Value i64x2Ty = rewriter.create<LLVM::BitcastOp>(
122 loc, LLVM::getFixedVectorType(toLLVMTy(i64Ty), 2), constConfig);
123 Value dataPtrAsI64 = rewriter.create<LLVM::PtrToIntOp>(
124 loc, toLLVMTy(i64Ty).template cast<Type>(), dataPtr);
125 Value zero = this->createIndexConstant(rewriter, loc, 0);
126 Value dwordConfig = rewriter.create<LLVM::InsertElementOp>(
127 loc, LLVM::getFixedVectorType(toLLVMTy(i64Ty), 2), i64x2Ty,
128 dataPtrAsI64, zero);
129 dwordConfig =
130 rewriter.create<LLVM::BitcastOp>(loc, toLLVMTy(i32Vecx4), dwordConfig);
131
132 // 2. Rewrite op as a buffer read or write.
133 Value int1False = rewriter.create<LLVM::ConstantOp>(
134 loc, toLLVMTy(rewriter.getIntegerType(1)),
135 rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0));
136 Value int32Zero = rewriter.create<LLVM::ConstantOp>(
137 loc, toLLVMTy(i32Ty),
138 rewriter.getIntegerAttr(rewriter.getIntegerType(32), 0));
139 return replaceTransferOpWithMubuf(
140 rewriter, operands, *this->getTypeConverter(), loc, xferOp, vecTy,
141 dwordConfig, int32Zero, int32Zero, int1False, int1False);
142 }
143 };
144 } // end anonymous namespace
145
populateVectorToROCDLConversionPatterns(LLVMTypeConverter & converter,OwningRewritePatternList & patterns)146 void mlir::populateVectorToROCDLConversionPatterns(
147 LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
148 patterns.insert<VectorTransferConversion<TransferReadOp>,
149 VectorTransferConversion<TransferWriteOp>>(converter);
150 }
151
152 namespace {
153 struct LowerVectorToROCDLPass
154 : public ConvertVectorToROCDLBase<LowerVectorToROCDLPass> {
155 void runOnOperation() override;
156 };
157 } // namespace
158
runOnOperation()159 void LowerVectorToROCDLPass::runOnOperation() {
160 LLVMTypeConverter converter(&getContext());
161 OwningRewritePatternList patterns;
162
163 populateVectorToROCDLConversionPatterns(converter, patterns);
164 populateStdToLLVMConversionPatterns(converter, patterns);
165
166 LLVMConversionTarget target(getContext());
167 target.addLegalDialect<ROCDL::ROCDLDialect>();
168
169 if (failed(
170 applyPartialConversion(getOperation(), target, std::move(patterns))))
171 signalPassFailure();
172 }
173
174 std::unique_ptr<OperationPass<ModuleOp>>
createConvertVectorToROCDLPass()175 mlir::createConvertVectorToROCDLPass() {
176 return std::make_unique<LowerVectorToROCDLPass>();
177 }
178