1 //===- VectorToROCDL.cpp - Vector to ROCDL lowering passes ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to generate ROCDLIR operations for higher-level
10 // Vector operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Conversion/VectorToROCDL/VectorToROCDL.h"
15 
16 #include "../PassDetail.h"
17 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
18 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
19 #include "mlir/Dialect/GPU/GPUDialect.h"
20 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
21 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
22 #include "mlir/Dialect/StandardOps/IR/Ops.h"
23 #include "mlir/Dialect/Vector/VectorOps.h"
24 #include "mlir/Pass/Pass.h"
25 #include "mlir/Transforms/DialectConversion.h"
26 
27 using namespace mlir;
28 using namespace mlir::vector;
29 
replaceTransferOpWithMubuf(ConversionPatternRewriter & rewriter,ArrayRef<Value> operands,LLVMTypeConverter & typeConverter,Location loc,TransferReadOp xferOp,Type & vecTy,Value & dwordConfig,Value & vindex,Value & offsetSizeInBytes,Value & glc,Value & slc)30 static LogicalResult replaceTransferOpWithMubuf(
31     ConversionPatternRewriter &rewriter, ArrayRef<Value> operands,
32     LLVMTypeConverter &typeConverter, Location loc, TransferReadOp xferOp,
33     Type &vecTy, Value &dwordConfig, Value &vindex, Value &offsetSizeInBytes,
34     Value &glc, Value &slc) {
35   rewriter.replaceOpWithNewOp<ROCDL::MubufLoadOp>(
36       xferOp, vecTy, dwordConfig, vindex, offsetSizeInBytes, glc, slc);
37   return success();
38 }
39 
replaceTransferOpWithMubuf(ConversionPatternRewriter & rewriter,ArrayRef<Value> operands,LLVMTypeConverter & typeConverter,Location loc,TransferWriteOp xferOp,Type & vecTy,Value & dwordConfig,Value & vindex,Value & offsetSizeInBytes,Value & glc,Value & slc)40 static LogicalResult replaceTransferOpWithMubuf(
41     ConversionPatternRewriter &rewriter, ArrayRef<Value> operands,
42     LLVMTypeConverter &typeConverter, Location loc, TransferWriteOp xferOp,
43     Type &vecTy, Value &dwordConfig, Value &vindex, Value &offsetSizeInBytes,
44     Value &glc, Value &slc) {
45   auto adaptor = TransferWriteOpAdaptor(operands);
46   rewriter.replaceOpWithNewOp<ROCDL::MubufStoreOp>(xferOp, adaptor.vector(),
47                                                    dwordConfig, vindex,
48                                                    offsetSizeInBytes, glc, slc);
49   return success();
50 }
51 
52 namespace {
53 /// Conversion pattern that converts a 1-D vector transfer read/write.
54 /// Note that this conversion pass only converts vector x2 or x4 f32
55 /// types. For unsupported cases, they will fall back to the vector to
56 /// llvm conversion pattern.
57 template <typename ConcreteOp>
58 class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
59 public:
60   using ConvertOpToLLVMPattern<ConcreteOp>::ConvertOpToLLVMPattern;
61 
62   LogicalResult
matchAndRewrite(ConcreteOp xferOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const63   matchAndRewrite(ConcreteOp xferOp, ArrayRef<Value> operands,
64                   ConversionPatternRewriter &rewriter) const override {
65     typename ConcreteOp::Adaptor adaptor(operands);
66 
67     if (xferOp.getVectorType().getRank() > 1 ||
68         llvm::size(xferOp.indices()) == 0)
69       return failure();
70 
71     if (!xferOp.permutation_map().isMinorIdentity())
72       return failure();
73 
74     // Have it handled in vector->llvm conversion pass.
75     if (!xferOp.isMaskedDim(0))
76       return failure();
77 
78     auto toLLVMTy = [&](Type t) {
79       return this->getTypeConverter()->convertType(t);
80     };
81     auto vecTy = toLLVMTy(xferOp.getVectorType());
82     unsigned vecWidth = LLVM::getVectorNumElements(vecTy).getFixedValue();
83     Location loc = xferOp->getLoc();
84 
85     // The backend result vector scalarization have trouble scalarize
86     // <1 x ty> result, exclude the x1 width from the lowering.
87     if (vecWidth != 2 && vecWidth != 4)
88       return failure();
89 
90     // Obtain dataPtr and elementType from the memref.
91     auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
92     if (!memRefType)
93       return failure();
94     // MUBUF instruction operate only on addresspace 0(unified) or 1(global)
95     // In case of 3(LDS): fall back to vector->llvm pass
96     // In case of 5(VGPR): wrong
97     if ((memRefType.getMemorySpace() != 0) &&
98         (memRefType.getMemorySpace() != 1))
99       return failure();
100 
101     // Note that the dataPtr starts at the offset address specified by
102     // indices, so no need to calculate offset size in bytes again in
103     // the MUBUF instruction.
104     Value dataPtr = this->getStridedElementPtr(
105         loc, memRefType, adaptor.source(), adaptor.indices(), rewriter);
106 
107     // 1. Create and fill a <4 x i32> dwordConfig with:
108     //    1st two elements holding the address of dataPtr.
109     //    3rd element: -1.
110     //    4th element: 0x27000.
111     SmallVector<int32_t, 4> constConfigAttr{0, 0, -1, 0x27000};
112     Type i32Ty = rewriter.getIntegerType(32);
113     VectorType i32Vecx4 = VectorType::get(4, i32Ty);
114     Value constConfig = rewriter.create<LLVM::ConstantOp>(
115         loc, toLLVMTy(i32Vecx4),
116         DenseElementsAttr::get(i32Vecx4, ArrayRef<int32_t>(constConfigAttr)));
117 
118     // Treat first two element of <4 x i32> as i64, and save the dataPtr
119     // to it.
120     Type i64Ty = rewriter.getIntegerType(64);
121     Value i64x2Ty = rewriter.create<LLVM::BitcastOp>(
122         loc, LLVM::getFixedVectorType(toLLVMTy(i64Ty), 2), constConfig);
123     Value dataPtrAsI64 = rewriter.create<LLVM::PtrToIntOp>(
124         loc, toLLVMTy(i64Ty).template cast<Type>(), dataPtr);
125     Value zero = this->createIndexConstant(rewriter, loc, 0);
126     Value dwordConfig = rewriter.create<LLVM::InsertElementOp>(
127         loc, LLVM::getFixedVectorType(toLLVMTy(i64Ty), 2), i64x2Ty,
128         dataPtrAsI64, zero);
129     dwordConfig =
130         rewriter.create<LLVM::BitcastOp>(loc, toLLVMTy(i32Vecx4), dwordConfig);
131 
132     // 2. Rewrite op as a buffer read or write.
133     Value int1False = rewriter.create<LLVM::ConstantOp>(
134         loc, toLLVMTy(rewriter.getIntegerType(1)),
135         rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0));
136     Value int32Zero = rewriter.create<LLVM::ConstantOp>(
137         loc, toLLVMTy(i32Ty),
138         rewriter.getIntegerAttr(rewriter.getIntegerType(32), 0));
139     return replaceTransferOpWithMubuf(
140         rewriter, operands, *this->getTypeConverter(), loc, xferOp, vecTy,
141         dwordConfig, int32Zero, int32Zero, int1False, int1False);
142   }
143 };
144 } // end anonymous namespace
145 
populateVectorToROCDLConversionPatterns(LLVMTypeConverter & converter,OwningRewritePatternList & patterns)146 void mlir::populateVectorToROCDLConversionPatterns(
147     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
148   patterns.insert<VectorTransferConversion<TransferReadOp>,
149                   VectorTransferConversion<TransferWriteOp>>(converter);
150 }
151 
152 namespace {
153 struct LowerVectorToROCDLPass
154     : public ConvertVectorToROCDLBase<LowerVectorToROCDLPass> {
155   void runOnOperation() override;
156 };
157 } // namespace
158 
runOnOperation()159 void LowerVectorToROCDLPass::runOnOperation() {
160   LLVMTypeConverter converter(&getContext());
161   OwningRewritePatternList patterns;
162 
163   populateVectorToROCDLConversionPatterns(converter, patterns);
164   populateStdToLLVMConversionPatterns(converter, patterns);
165 
166   LLVMConversionTarget target(getContext());
167   target.addLegalDialect<ROCDL::ROCDLDialect>();
168 
169   if (failed(
170           applyPartialConversion(getOperation(), target, std::move(patterns))))
171     signalPassFailure();
172 }
173 
174 std::unique_ptr<OperationPass<ModuleOp>>
createConvertVectorToROCDLPass()175 mlir::createConvertVectorToROCDLPass() {
176   return std::make_unique<LowerVectorToROCDLPass>();
177 }
178