1 //===- LinalgToLLVM.cpp - conversion from Linalg to LLVM dialect ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
10
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
13 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
14 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
15 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
16 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
17 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
18 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
19 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
20 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
21 #include "mlir/Dialect/Linalg/Passes.h"
22 #include "mlir/Dialect/SCF/SCF.h"
23 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
24 #include "mlir/IR/AffineExpr.h"
25 #include "mlir/IR/AffineMap.h"
26 #include "mlir/IR/Attributes.h"
27 #include "mlir/IR/Builders.h"
28 #include "mlir/IR/MLIRContext.h"
29 #include "mlir/IR/Module.h"
30 #include "mlir/IR/Operation.h"
31 #include "mlir/IR/PatternMatch.h"
32 #include "mlir/IR/StandardTypes.h"
33 #include "mlir/IR/Types.h"
34 #include "mlir/Support/LogicalResult.h"
35 #include "mlir/Transforms/DialectConversion.h"
36 #include "mlir/Transforms/Passes.h"
37 #include "llvm/ADT/SetVector.h"
38 #include "llvm/IR/DerivedTypes.h"
39 #include "llvm/IR/Module.h"
40 #include "llvm/IR/Type.h"
41 #include "llvm/Support/Allocator.h"
42 #include "llvm/Support/ErrorHandling.h"
43
44 using namespace mlir;
45 using namespace mlir::edsc;
46 using namespace mlir::edsc::intrinsics;
47 using namespace mlir::LLVM;
48 using namespace mlir::linalg;
49
// EDSC-compatible shorthand builders for LLVM dialect operations.
// ValueBuilder aliases yield the op's result Value; OperationBuilder aliases
// are used for ops consumed for their effects (call, store, return).
using llvm_add = ValueBuilder<LLVM::AddOp>;
using llvm_bitcast = ValueBuilder<LLVM::BitcastOp>;
using llvm_constant = ValueBuilder<LLVM::ConstantOp>;
using llvm_extractvalue = ValueBuilder<LLVM::ExtractValueOp>;
using llvm_gep = ValueBuilder<LLVM::GEPOp>;
using llvm_insertvalue = ValueBuilder<LLVM::InsertValueOp>;
using llvm_call = OperationBuilder<LLVM::CallOp>;
using llvm_icmp = ValueBuilder<LLVM::ICmpOp>;
using llvm_load = ValueBuilder<LLVM::LoadOp>;
using llvm_store = OperationBuilder<LLVM::StoreOp>;
using llvm_select = ValueBuilder<LLVM::SelectOp>;
using llvm_mul = ValueBuilder<LLVM::MulOp>;
using llvm_ptrtoint = ValueBuilder<LLVM::PtrToIntOp>;
using llvm_sub = ValueBuilder<LLVM::SubOp>;
using llvm_undef = ValueBuilder<LLVM::UndefOp>;
using llvm_urem = ValueBuilder<LLVM::URemOp>;
using llvm_alloca = ValueBuilder<LLVM::AllocaOp>;
using llvm_return = OperationBuilder<LLVM::ReturnOp>;
68
69 template <typename T>
getPtrToElementType(T containerType,LLVMTypeConverter & lowering)70 static LLVMType getPtrToElementType(T containerType,
71 LLVMTypeConverter &lowering) {
72 return lowering.convertType(containerType.getElementType())
73 .template cast<LLVMType>()
74 .getPointerTo();
75 }
76
77 /// Convert the given range descriptor type to the LLVMIR dialect.
78 /// Range descriptor contains the range bounds and the step as 64-bit integers.
79 ///
80 /// struct {
81 /// int64_t min;
82 /// int64_t max;
83 /// int64_t step;
84 /// };
convertRangeType(RangeType t,LLVMTypeConverter & converter)85 static Type convertRangeType(RangeType t, LLVMTypeConverter &converter) {
86 auto *context = t.getContext();
87 auto int64Ty = converter.convertType(IntegerType::get(64, context))
88 .cast<LLVM::LLVMType>();
89 return LLVMType::getStructTy(int64Ty, int64Ty, int64Ty);
90 }
91
92 namespace {
/// EDSC-compatible wrapper for MemRefDescriptor.
/// Forwards each accessor/mutator to MemRefDescriptor, supplying the builder
/// and location implicitly from the active edsc::ScopedContext; must only be
/// used while such a scoped context is live.
class BaseViewConversionHelper {
public:
  /// Wraps a fresh `undef` descriptor of `type` (an LLVM-converted memref
  /// descriptor type).
  BaseViewConversionHelper(Type type)
      : d(MemRefDescriptor::undef(rewriter(), loc(), type)) {}

  /// Wraps an existing descriptor value.
  BaseViewConversionHelper(Value v) : d(v) {}

  /// Wrappers around MemRefDescriptor that use EDSC builder and location.
  Value allocatedPtr() { return d.allocatedPtr(rewriter(), loc()); }
  void setAllocatedPtr(Value v) { d.setAllocatedPtr(rewriter(), loc(), v); }
  Value alignedPtr() { return d.alignedPtr(rewriter(), loc()); }
  void setAlignedPtr(Value v) { d.setAlignedPtr(rewriter(), loc(), v); }
  Value offset() { return d.offset(rewriter(), loc()); }
  void setOffset(Value v) { d.setOffset(rewriter(), loc(), v); }
  Value size(unsigned i) { return d.size(rewriter(), loc(), i); }
  void setSize(unsigned i, Value v) { d.setSize(rewriter(), loc(), i, v); }
  void setConstantSize(unsigned i, int64_t v) {
    d.setConstantSize(rewriter(), loc(), i, v);
  }
  Value stride(unsigned i) { return d.stride(rewriter(), loc(), i); }
  void setStride(unsigned i, Value v) { d.setStride(rewriter(), loc(), i, v); }
  void setConstantStride(unsigned i, int64_t v) {
    d.setConstantStride(rewriter(), loc(), i, v);
  }

  /// Implicit conversion back to the underlying descriptor Value.
  operator Value() { return d; }

private:
  // Builder and location come from the surrounding EDSC scoped context.
  OpBuilder &rewriter() { return ScopedContext::getBuilderRef(); }
  Location loc() { return ScopedContext::getLocation(); }

  MemRefDescriptor d;
};
127
128 // RangeOp creates a new range descriptor.
129 class RangeOpConversion : public ConvertToLLVMPattern {
130 public:
RangeOpConversion(MLIRContext * context,LLVMTypeConverter & lowering_)131 explicit RangeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
132 : ConvertToLLVMPattern(RangeOp::getOperationName(), context, lowering_) {}
133
134 LogicalResult
matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const135 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
136 ConversionPatternRewriter &rewriter) const override {
137 auto rangeOp = cast<RangeOp>(op);
138 auto rangeDescriptorTy =
139 convertRangeType(rangeOp.getType().cast<RangeType>(), typeConverter);
140
141 edsc::ScopedContext context(rewriter, op->getLoc());
142
143 // Fill in an aggregate value of the descriptor.
144 RangeOpAdaptor adaptor(operands);
145 Value desc = llvm_undef(rangeDescriptorTy);
146 desc = llvm_insertvalue(desc, adaptor.min(), rewriter.getI64ArrayAttr(0));
147 desc = llvm_insertvalue(desc, adaptor.max(), rewriter.getI64ArrayAttr(1));
148 desc = llvm_insertvalue(desc, adaptor.step(), rewriter.getI64ArrayAttr(2));
149 rewriter.replaceOp(op, desc);
150 return success();
151 }
152 };
153
154 // ReshapeOp creates a new view descriptor of the proper rank.
155 // For now, the only conversion supported is for target MemRef with static sizes
156 // and strides.
157 class ReshapeOpConversion : public ConvertToLLVMPattern {
158 public:
ReshapeOpConversion(MLIRContext * context,LLVMTypeConverter & lowering_)159 explicit ReshapeOpConversion(MLIRContext *context,
160 LLVMTypeConverter &lowering_)
161 : ConvertToLLVMPattern(ReshapeOp::getOperationName(), context,
162 lowering_) {}
163
164 LogicalResult
matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const165 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
166 ConversionPatternRewriter &rewriter) const override {
167 auto reshapeOp = cast<ReshapeOp>(op);
168 MemRefType dstType = reshapeOp.getResultType();
169
170 if (!dstType.hasStaticShape())
171 return failure();
172
173 int64_t offset;
174 SmallVector<int64_t, 4> strides;
175 auto res = getStridesAndOffset(dstType, strides, offset);
176 if (failed(res) || llvm::any_of(strides, [](int64_t val) {
177 return ShapedType::isDynamicStrideOrOffset(val);
178 }))
179 return failure();
180
181 edsc::ScopedContext context(rewriter, op->getLoc());
182 ReshapeOpAdaptor adaptor(operands);
183 BaseViewConversionHelper baseDesc(adaptor.src());
184 BaseViewConversionHelper desc(typeConverter.convertType(dstType));
185 desc.setAllocatedPtr(baseDesc.allocatedPtr());
186 desc.setAlignedPtr(baseDesc.alignedPtr());
187 desc.setOffset(baseDesc.offset());
188 for (auto en : llvm::enumerate(dstType.getShape()))
189 desc.setConstantSize(en.index(), en.value());
190 for (auto en : llvm::enumerate(strides))
191 desc.setConstantStride(en.index(), en.value());
192 rewriter.replaceOp(op, {desc});
193 return success();
194 }
195 };
196
/// Conversion pattern that transforms a linalg.slice op into:
///   1. An "undef" value for the ViewDescriptor.
///   2. Updates to the ViewDescriptor to introduce the data ptr, offset, size
///      and stride corresponding to the region of memory within the bounds of
///      the parent view.
/// The linalg.slice op is replaced by the resulting view descriptor value.
class SliceOpConversion : public ConvertToLLVMPattern {
public:
  explicit SliceOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
      : ConvertToLLVMPattern(SliceOp::getOperationName(), context, lowering_) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    edsc::ScopedContext context(rewriter, op->getLoc());
    SliceOpAdaptor adaptor(operands);
    // Descriptor of the base view being sliced.
    BaseViewConversionHelper baseDesc(adaptor.view());

    auto sliceOp = cast<SliceOp>(op);
    auto memRefType = sliceOp.getBaseViewType();
    auto int64Ty = typeConverter.convertType(rewriter.getIntegerType(64))
                       .cast<LLVM::LLVMType>();

    // Descriptor of the result view (possibly of lower rank: scalar-indexed
    // dimensions are projected away below).
    BaseViewConversionHelper desc(
        typeConverter.convertType(sliceOp.getShapedType()));

    // TODO: extract sizes and emit asserts.
    SmallVector<Value, 4> strides(memRefType.getRank());
    for (int i = 0, e = memRefType.getRank(); i < e; ++i)
      strides[i] = baseDesc.stride(i);

    // Helper producing the position attribute for llvm.extractvalue.
    auto pos = [&rewriter](ArrayRef<int64_t> values) {
      return rewriter.getI64ArrayAttr(values);
    };

    // Compute base offset: accumulate min(indexing) * stride over all dims.
    // For a range indexing, min is field 0 of the range descriptor; for a
    // scalar indexing, the converted operand itself is the index.
    Value baseOffset = baseDesc.offset();
    for (int i = 0, e = memRefType.getRank(); i < e; ++i) {
      Value indexing = adaptor.indexings()[i];
      Value min = indexing;
      if (sliceOp.indexing(i).getType().isa<RangeType>())
        min = llvm_extractvalue(int64Ty, indexing, pos(0));
      baseOffset = llvm_add(baseOffset, llvm_mul(min, strides[i]));
    }

    // Insert the base and aligned pointers.
    desc.setAllocatedPtr(baseDesc.allocatedPtr());
    desc.setAlignedPtr(baseDesc.alignedPtr());

    // Insert base offset.
    desc.setOffset(baseOffset);

    // Corner case, no sizes or strides: early return the descriptor.
    if (sliceOp.getShapedType().getRank() == 0)
      return rewriter.replaceOp(op, {desc}), success();

    Value zero = llvm_constant(
        int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
    // Compute and insert view sizes (max - min along the range) and strides.
    // Skip the non-range operands as they will be projected away from the view.
    int numNewDims = 0;
    for (auto en : llvm::enumerate(sliceOp.indexings())) {
      Value indexing = en.value();
      if (indexing.getType().isa<RangeType>()) {
        int rank = en.index();
        Value rangeDescriptor = adaptor.indexings()[rank];
        // Range descriptor layout: {min, max, step} at fields 0, 1, 2.
        Value min = llvm_extractvalue(int64Ty, rangeDescriptor, pos(0));
        Value max = llvm_extractvalue(int64Ty, rangeDescriptor, pos(1));
        Value step = llvm_extractvalue(int64Ty, rangeDescriptor, pos(2));
        Value baseSize = baseDesc.size(rank);

        // Bound upper by base view upper bound.
        max = llvm_select(llvm_icmp(ICmpPredicate::slt, max, baseSize), max,
                          baseSize);
        Value size = llvm_sub(max, min);
        // Bound lower by zero.
        size =
            llvm_select(llvm_icmp(ICmpPredicate::slt, size, zero), zero, size);
        // The result stride is the base stride scaled by the range step.
        Value stride = llvm_mul(strides[rank], step);
        desc.setSize(numNewDims, size);
        desc.setStride(numNewDims, stride);
        ++numNewDims;
      }
    }

    rewriter.replaceOp(op, {desc});
    return success();
  }
};
286
287 // YieldOp produces and LLVM::ReturnOp.
288 class YieldOpConversion : public ConvertToLLVMPattern {
289 public:
YieldOpConversion(MLIRContext * context,LLVMTypeConverter & lowering_)290 explicit YieldOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
291 : ConvertToLLVMPattern(linalg::YieldOp::getOperationName(), context,
292 lowering_) {}
293
294 LogicalResult
matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const295 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
296 ConversionPatternRewriter &rewriter) const override {
297 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands);
298 return success();
299 }
300 };
301 } // namespace
302
303 /// Populate the given list with patterns that convert from Linalg to LLVM.
populateLinalgToLLVMConversionPatterns(LLVMTypeConverter & converter,OwningRewritePatternList & patterns,MLIRContext * ctx)304 void mlir::populateLinalgToLLVMConversionPatterns(
305 LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
306 MLIRContext *ctx) {
307 patterns.insert<RangeOpConversion, ReshapeOpConversion, SliceOpConversion,
308 YieldOpConversion>(ctx, converter);
309
310 // Populate the type conversions for the linalg types.
311 converter.addConversion(
312 [&](RangeType type) { return convertRangeType(type, converter); });
313 }
314
namespace {
/// Module pass implementing the Linalg-to-LLVM lowering; the actual work is
/// in runOnOperation() below.
struct ConvertLinalgToLLVMPass
    : public ConvertLinalgToLLVMBase<ConvertLinalgToLLVMPass> {
  void runOnOperation() override;
};
} // namespace
321
runOnOperation()322 void ConvertLinalgToLLVMPass::runOnOperation() {
323 auto module = getOperation();
324
325 // Convert to the LLVM IR dialect using the converter defined above.
326 OwningRewritePatternList patterns;
327 LLVMTypeConverter converter(&getContext());
328 populateAffineToStdConversionPatterns(patterns, &getContext());
329 populateLoopToStdConversionPatterns(patterns, &getContext());
330 populateStdToLLVMConversionPatterns(converter, patterns);
331 populateVectorToSCFConversionPatterns(patterns, &getContext());
332 populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
333 populateVectorToLLVMConversionPatterns(converter, patterns);
334 populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext());
335
336 LLVMConversionTarget target(getContext());
337 target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
338 if (failed(applyFullConversion(module, target, std::move(patterns))))
339 signalPassFailure();
340 }
341
/// Factory for the Linalg-to-LLVM lowering pass (a ModuleOp pass).
std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertLinalgToLLVMPass() {
  return std::make_unique<ConvertLinalgToLLVMPass>();
}
345