1 //===- LinalgToLLVM.cpp - conversion from Linalg to LLVM dialect ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
10
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
13 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
14 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
15 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
16 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
17 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
18 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
19 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
20 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
21 #include "mlir/Dialect/Linalg/Passes.h"
22 #include "mlir/Dialect/SCF/SCF.h"
23 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
24 #include "mlir/IR/AffineExpr.h"
25 #include "mlir/IR/AffineMap.h"
26 #include "mlir/IR/Attributes.h"
27 #include "mlir/IR/Builders.h"
28 #include "mlir/IR/BuiltinOps.h"
29 #include "mlir/IR/BuiltinTypes.h"
30 #include "mlir/IR/MLIRContext.h"
31 #include "mlir/IR/Operation.h"
32 #include "mlir/IR/PatternMatch.h"
33 #include "mlir/IR/Types.h"
34 #include "mlir/Support/LogicalResult.h"
35 #include "mlir/Transforms/DialectConversion.h"
36 #include "mlir/Transforms/Passes.h"
37 #include "llvm/ADT/SetVector.h"
38 #include "llvm/IR/DerivedTypes.h"
39 #include "llvm/IR/Module.h"
40 #include "llvm/IR/Type.h"
41 #include "llvm/Support/Allocator.h"
42 #include "llvm/Support/ErrorHandling.h"
43
44 using namespace mlir;
45 using namespace mlir::edsc;
46 using namespace mlir::edsc::intrinsics;
47 using namespace mlir::LLVM;
48 using namespace mlir::linalg;
49
50 using llvm_add = ValueBuilder<LLVM::AddOp>;
51 using llvm_bitcast = ValueBuilder<LLVM::BitcastOp>;
52 using llvm_constant = ValueBuilder<LLVM::ConstantOp>;
53 using llvm_extractvalue = ValueBuilder<LLVM::ExtractValueOp>;
54 using llvm_gep = ValueBuilder<LLVM::GEPOp>;
55 using llvm_insertvalue = ValueBuilder<LLVM::InsertValueOp>;
56 using llvm_call = OperationBuilder<LLVM::CallOp>;
57 using llvm_icmp = ValueBuilder<LLVM::ICmpOp>;
58 using llvm_load = ValueBuilder<LLVM::LoadOp>;
59 using llvm_store = OperationBuilder<LLVM::StoreOp>;
60 using llvm_select = ValueBuilder<LLVM::SelectOp>;
61 using llvm_mul = ValueBuilder<LLVM::MulOp>;
62 using llvm_ptrtoint = ValueBuilder<LLVM::PtrToIntOp>;
63 using llvm_sub = ValueBuilder<LLVM::SubOp>;
64 using llvm_undef = ValueBuilder<LLVM::UndefOp>;
65 using llvm_urem = ValueBuilder<LLVM::URemOp>;
66 using llvm_alloca = ValueBuilder<LLVM::AllocaOp>;
67 using llvm_return = OperationBuilder<LLVM::ReturnOp>;
68
69 template <typename T>
getPtrToElementType(T containerType,LLVMTypeConverter & lowering)70 static Type getPtrToElementType(T containerType, LLVMTypeConverter &lowering) {
71 return LLVMPointerType::get(
72 lowering.convertType(containerType.getElementType()));
73 }
74
75 /// Convert the given range descriptor type to the LLVMIR dialect.
76 /// Range descriptor contains the range bounds and the step as 64-bit integers.
77 ///
78 /// struct {
79 /// int64_t min;
80 /// int64_t max;
81 /// int64_t step;
82 /// };
convertRangeType(RangeType t,LLVMTypeConverter & converter)83 static Type convertRangeType(RangeType t, LLVMTypeConverter &converter) {
84 auto *context = t.getContext();
85 auto int64Ty = converter.convertType(IntegerType::get(context, 64));
86 return LLVMStructType::getLiteral(context, {int64Ty, int64Ty, int64Ty});
87 }
88
89 namespace {
90 /// EDSC-compatible wrapper for MemRefDescriptor.
91 class BaseViewConversionHelper {
92 public:
BaseViewConversionHelper(Type type)93 BaseViewConversionHelper(Type type)
94 : d(MemRefDescriptor::undef(rewriter(), loc(), type)) {}
95
BaseViewConversionHelper(Value v)96 BaseViewConversionHelper(Value v) : d(v) {}
97
98 /// Wrappers around MemRefDescriptor that use EDSC builder and location.
allocatedPtr()99 Value allocatedPtr() { return d.allocatedPtr(rewriter(), loc()); }
setAllocatedPtr(Value v)100 void setAllocatedPtr(Value v) { d.setAllocatedPtr(rewriter(), loc(), v); }
alignedPtr()101 Value alignedPtr() { return d.alignedPtr(rewriter(), loc()); }
setAlignedPtr(Value v)102 void setAlignedPtr(Value v) { d.setAlignedPtr(rewriter(), loc(), v); }
offset()103 Value offset() { return d.offset(rewriter(), loc()); }
setOffset(Value v)104 void setOffset(Value v) { d.setOffset(rewriter(), loc(), v); }
size(unsigned i)105 Value size(unsigned i) { return d.size(rewriter(), loc(), i); }
setSize(unsigned i,Value v)106 void setSize(unsigned i, Value v) { d.setSize(rewriter(), loc(), i, v); }
setConstantSize(unsigned i,int64_t v)107 void setConstantSize(unsigned i, int64_t v) {
108 d.setConstantSize(rewriter(), loc(), i, v);
109 }
stride(unsigned i)110 Value stride(unsigned i) { return d.stride(rewriter(), loc(), i); }
setStride(unsigned i,Value v)111 void setStride(unsigned i, Value v) { d.setStride(rewriter(), loc(), i, v); }
setConstantStride(unsigned i,int64_t v)112 void setConstantStride(unsigned i, int64_t v) {
113 d.setConstantStride(rewriter(), loc(), i, v);
114 }
115
operator Value()116 operator Value() { return d; }
117
118 private:
rewriter()119 OpBuilder &rewriter() { return ScopedContext::getBuilderRef(); }
loc()120 Location loc() { return ScopedContext::getLocation(); }
121
122 MemRefDescriptor d;
123 };
124
125 // RangeOp creates a new range descriptor.
126 class RangeOpConversion : public ConvertOpToLLVMPattern<RangeOp> {
127 public:
128 using ConvertOpToLLVMPattern<RangeOp>::ConvertOpToLLVMPattern;
129
130 LogicalResult
matchAndRewrite(RangeOp rangeOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const131 matchAndRewrite(RangeOp rangeOp, ArrayRef<Value> operands,
132 ConversionPatternRewriter &rewriter) const override {
133 auto rangeDescriptorTy = convertRangeType(
134 rangeOp.getType().cast<RangeType>(), *getTypeConverter());
135
136 edsc::ScopedContext context(rewriter, rangeOp->getLoc());
137
138 // Fill in an aggregate value of the descriptor.
139 RangeOpAdaptor adaptor(operands);
140 Value desc = llvm_undef(rangeDescriptorTy);
141 desc = llvm_insertvalue(desc, adaptor.min(), rewriter.getI64ArrayAttr(0));
142 desc = llvm_insertvalue(desc, adaptor.max(), rewriter.getI64ArrayAttr(1));
143 desc = llvm_insertvalue(desc, adaptor.step(), rewriter.getI64ArrayAttr(2));
144 rewriter.replaceOp(rangeOp, desc);
145 return success();
146 }
147 };
148
149 // ReshapeOp creates a new view descriptor of the proper rank.
150 // For now, the only conversion supported is for target MemRef with static sizes
151 // and strides.
152 class ReshapeOpConversion : public ConvertOpToLLVMPattern<ReshapeOp> {
153 public:
154 using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
155
156 LogicalResult
matchAndRewrite(ReshapeOp reshapeOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const157 matchAndRewrite(ReshapeOp reshapeOp, ArrayRef<Value> operands,
158 ConversionPatternRewriter &rewriter) const override {
159 MemRefType dstType = reshapeOp.getResultType();
160
161 if (!dstType.hasStaticShape())
162 return failure();
163
164 int64_t offset;
165 SmallVector<int64_t, 4> strides;
166 auto res = getStridesAndOffset(dstType, strides, offset);
167 if (failed(res) || llvm::any_of(strides, [](int64_t val) {
168 return ShapedType::isDynamicStrideOrOffset(val);
169 }))
170 return failure();
171
172 edsc::ScopedContext context(rewriter, reshapeOp->getLoc());
173 ReshapeOpAdaptor adaptor(operands);
174 BaseViewConversionHelper baseDesc(adaptor.src());
175 BaseViewConversionHelper desc(typeConverter->convertType(dstType));
176 desc.setAllocatedPtr(baseDesc.allocatedPtr());
177 desc.setAlignedPtr(baseDesc.alignedPtr());
178 desc.setOffset(baseDesc.offset());
179 for (auto en : llvm::enumerate(dstType.getShape()))
180 desc.setConstantSize(en.index(), en.value());
181 for (auto en : llvm::enumerate(strides))
182 desc.setConstantStride(en.index(), en.value());
183 rewriter.replaceOp(reshapeOp, {desc});
184 return success();
185 }
186 };
187
188 /// Conversion pattern that transforms a linalg.slice op into:
189 /// 1. An "undef" value for the ViewDescriptor.
190 /// 2. Updates to the ViewDescriptor to introduce the data ptr, offset, size
191 /// and stride corresponding to the region of memory within the bounds of
192 /// the parent view.
193 /// The linalg.slice op is replaced by the alloca'ed pointer.
194 class SliceOpConversion : public ConvertOpToLLVMPattern<SliceOp> {
195 public:
196 using ConvertOpToLLVMPattern<SliceOp>::ConvertOpToLLVMPattern;
197
198 LogicalResult
matchAndRewrite(SliceOp sliceOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const199 matchAndRewrite(SliceOp sliceOp, ArrayRef<Value> operands,
200 ConversionPatternRewriter &rewriter) const override {
201 edsc::ScopedContext context(rewriter, sliceOp->getLoc());
202 SliceOpAdaptor adaptor(operands);
203 BaseViewConversionHelper baseDesc(adaptor.view());
204
205 auto memRefType = sliceOp.getBaseViewType();
206 auto int64Ty = typeConverter->convertType(rewriter.getIntegerType(64));
207
208 BaseViewConversionHelper desc(
209 typeConverter->convertType(sliceOp.getShapedType()));
210
211 // TODO: extract sizes and emit asserts.
212 SmallVector<Value, 4> strides(memRefType.getRank());
213 for (int i = 0, e = memRefType.getRank(); i < e; ++i)
214 strides[i] = baseDesc.stride(i);
215
216 auto pos = [&rewriter](ArrayRef<int64_t> values) {
217 return rewriter.getI64ArrayAttr(values);
218 };
219
220 // Compute base offset.
221 Value baseOffset = baseDesc.offset();
222 for (int i = 0, e = memRefType.getRank(); i < e; ++i) {
223 Value indexing = adaptor.indexings()[i];
224 Value min = indexing;
225 if (sliceOp.indexing(i).getType().isa<RangeType>())
226 min = llvm_extractvalue(int64Ty, indexing, pos(0));
227 baseOffset = llvm_add(baseOffset, llvm_mul(min, strides[i]));
228 }
229
230 // Insert the base and aligned pointers.
231 desc.setAllocatedPtr(baseDesc.allocatedPtr());
232 desc.setAlignedPtr(baseDesc.alignedPtr());
233
234 // Insert base offset.
235 desc.setOffset(baseOffset);
236
237 // Corner case, no sizes or strides: early return the descriptor.
238 if (sliceOp.getShapedType().getRank() == 0)
239 return rewriter.replaceOp(sliceOp, {desc}), success();
240
241 Value zero = llvm_constant(
242 int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
243 // Compute and insert view sizes (max - min along the range) and strides.
244 // Skip the non-range operands as they will be projected away from the view.
245 int numNewDims = 0;
246 for (auto en : llvm::enumerate(sliceOp.indexings())) {
247 Value indexing = en.value();
248 if (indexing.getType().isa<RangeType>()) {
249 int rank = en.index();
250 Value rangeDescriptor = adaptor.indexings()[rank];
251 Value min = llvm_extractvalue(int64Ty, rangeDescriptor, pos(0));
252 Value max = llvm_extractvalue(int64Ty, rangeDescriptor, pos(1));
253 Value step = llvm_extractvalue(int64Ty, rangeDescriptor, pos(2));
254 Value baseSize = baseDesc.size(rank);
255
256 // Bound upper by base view upper bound.
257 max = llvm_select(llvm_icmp(ICmpPredicate::slt, max, baseSize), max,
258 baseSize);
259 Value size = llvm_sub(max, min);
260 // Bound lower by zero.
261 size =
262 llvm_select(llvm_icmp(ICmpPredicate::slt, size, zero), zero, size);
263 Value stride = llvm_mul(strides[rank], step);
264 desc.setSize(numNewDims, size);
265 desc.setStride(numNewDims, stride);
266 ++numNewDims;
267 }
268 }
269
270 rewriter.replaceOp(sliceOp, {desc});
271 return success();
272 }
273 };
274
275 // YieldOp produces and LLVM::ReturnOp.
276 class YieldOpConversion : public ConvertOpToLLVMPattern<linalg::YieldOp> {
277 public:
278 using ConvertOpToLLVMPattern<linalg::YieldOp>::ConvertOpToLLVMPattern;
279
280 LogicalResult
matchAndRewrite(linalg::YieldOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const281 matchAndRewrite(linalg::YieldOp op, ArrayRef<Value> operands,
282 ConversionPatternRewriter &rewriter) const override {
283 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands);
284 return success();
285 }
286 };
287 } // namespace
288
289 /// Populate the given list with patterns that convert from Linalg to LLVM.
populateLinalgToLLVMConversionPatterns(LLVMTypeConverter & converter,OwningRewritePatternList & patterns)290 void mlir::populateLinalgToLLVMConversionPatterns(
291 LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
292 patterns.insert<RangeOpConversion, ReshapeOpConversion, SliceOpConversion,
293 YieldOpConversion>(converter);
294
295 // Populate the type conversions for the linalg types.
296 converter.addConversion(
297 [&](RangeType type) { return convertRangeType(type, converter); });
298 }
299
300 namespace {
301 struct ConvertLinalgToLLVMPass
302 : public ConvertLinalgToLLVMBase<ConvertLinalgToLLVMPass> {
303 void runOnOperation() override;
304 };
305 } // namespace
306
runOnOperation()307 void ConvertLinalgToLLVMPass::runOnOperation() {
308 auto module = getOperation();
309
310 // Convert to the LLVM IR dialect using the converter defined above.
311 OwningRewritePatternList patterns;
312 LLVMTypeConverter converter(&getContext());
313 populateAffineToStdConversionPatterns(patterns, &getContext());
314 populateLoopToStdConversionPatterns(patterns, &getContext());
315 populateStdToLLVMConversionPatterns(converter, patterns);
316 populateVectorToSCFConversionPatterns(patterns, &getContext());
317 populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
318 populateVectorToLLVMConversionPatterns(converter, patterns);
319 populateLinalgToLLVMConversionPatterns(converter, patterns);
320
321 LLVMConversionTarget target(getContext());
322 target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
323 if (failed(applyFullConversion(module, target, std::move(patterns))))
324 signalPassFailure();
325 }
326
createConvertLinalgToLLVMPass()327 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertLinalgToLLVMPass() {
328 return std::make_unique<ConvertLinalgToLLVMPass>();
329 }
330