//===- LinalgToLLVM.cpp - conversion from Linalg to LLVM dialect ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
13 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
14 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
15 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
16 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
17 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
18 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
19 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
20 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
21 #include "mlir/Dialect/Linalg/Passes.h"
22 #include "mlir/Dialect/SCF/SCF.h"
23 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
24 #include "mlir/IR/AffineExpr.h"
25 #include "mlir/IR/AffineMap.h"
26 #include "mlir/IR/Attributes.h"
27 #include "mlir/IR/Builders.h"
28 #include "mlir/IR/MLIRContext.h"
29 #include "mlir/IR/Module.h"
30 #include "mlir/IR/Operation.h"
31 #include "mlir/IR/PatternMatch.h"
32 #include "mlir/IR/StandardTypes.h"
33 #include "mlir/IR/Types.h"
34 #include "mlir/Support/LogicalResult.h"
35 #include "mlir/Transforms/DialectConversion.h"
36 #include "mlir/Transforms/Passes.h"
37 #include "llvm/ADT/SetVector.h"
38 #include "llvm/IR/DerivedTypes.h"
39 #include "llvm/IR/Module.h"
40 #include "llvm/IR/Type.h"
41 #include "llvm/Support/Allocator.h"
42 #include "llvm/Support/ErrorHandling.h"
43 
44 using namespace mlir;
45 using namespace mlir::edsc;
46 using namespace mlir::edsc::intrinsics;
47 using namespace mlir::LLVM;
48 using namespace mlir::linalg;
49 
50 using llvm_add = ValueBuilder<LLVM::AddOp>;
51 using llvm_bitcast = ValueBuilder<LLVM::BitcastOp>;
52 using llvm_constant = ValueBuilder<LLVM::ConstantOp>;
53 using llvm_extractvalue = ValueBuilder<LLVM::ExtractValueOp>;
54 using llvm_gep = ValueBuilder<LLVM::GEPOp>;
55 using llvm_insertvalue = ValueBuilder<LLVM::InsertValueOp>;
56 using llvm_call = OperationBuilder<LLVM::CallOp>;
57 using llvm_icmp = ValueBuilder<LLVM::ICmpOp>;
58 using llvm_load = ValueBuilder<LLVM::LoadOp>;
59 using llvm_store = OperationBuilder<LLVM::StoreOp>;
60 using llvm_select = ValueBuilder<LLVM::SelectOp>;
61 using llvm_mul = ValueBuilder<LLVM::MulOp>;
62 using llvm_ptrtoint = ValueBuilder<LLVM::PtrToIntOp>;
63 using llvm_sub = ValueBuilder<LLVM::SubOp>;
64 using llvm_undef = ValueBuilder<LLVM::UndefOp>;
65 using llvm_urem = ValueBuilder<LLVM::URemOp>;
66 using llvm_alloca = ValueBuilder<LLVM::AllocaOp>;
67 using llvm_return = OperationBuilder<LLVM::ReturnOp>;
68 
69 template <typename T>
getPtrToElementType(T containerType,LLVMTypeConverter & lowering)70 static LLVMType getPtrToElementType(T containerType,
71                                     LLVMTypeConverter &lowering) {
72   return lowering.convertType(containerType.getElementType())
73       .template cast<LLVMType>()
74       .getPointerTo();
75 }
76 
77 /// Convert the given range descriptor type to the LLVMIR dialect.
78 /// Range descriptor contains the range bounds and the step as 64-bit integers.
79 ///
80 /// struct {
81 ///   int64_t min;
82 ///   int64_t max;
83 ///   int64_t step;
84 /// };
convertRangeType(RangeType t,LLVMTypeConverter & converter)85 static Type convertRangeType(RangeType t, LLVMTypeConverter &converter) {
86   auto *context = t.getContext();
87   auto int64Ty = converter.convertType(IntegerType::get(64, context))
88                      .cast<LLVM::LLVMType>();
89   return LLVMType::getStructTy(int64Ty, int64Ty, int64Ty);
90 }
91 
92 namespace {
93 /// EDSC-compatible wrapper for MemRefDescriptor.
94 class BaseViewConversionHelper {
95 public:
BaseViewConversionHelper(Type type)96   BaseViewConversionHelper(Type type)
97       : d(MemRefDescriptor::undef(rewriter(), loc(), type)) {}
98 
BaseViewConversionHelper(Value v)99   BaseViewConversionHelper(Value v) : d(v) {}
100 
101   /// Wrappers around MemRefDescriptor that use EDSC builder and location.
allocatedPtr()102   Value allocatedPtr() { return d.allocatedPtr(rewriter(), loc()); }
setAllocatedPtr(Value v)103   void setAllocatedPtr(Value v) { d.setAllocatedPtr(rewriter(), loc(), v); }
alignedPtr()104   Value alignedPtr() { return d.alignedPtr(rewriter(), loc()); }
setAlignedPtr(Value v)105   void setAlignedPtr(Value v) { d.setAlignedPtr(rewriter(), loc(), v); }
offset()106   Value offset() { return d.offset(rewriter(), loc()); }
setOffset(Value v)107   void setOffset(Value v) { d.setOffset(rewriter(), loc(), v); }
size(unsigned i)108   Value size(unsigned i) { return d.size(rewriter(), loc(), i); }
setSize(unsigned i,Value v)109   void setSize(unsigned i, Value v) { d.setSize(rewriter(), loc(), i, v); }
setConstantSize(unsigned i,int64_t v)110   void setConstantSize(unsigned i, int64_t v) {
111     d.setConstantSize(rewriter(), loc(), i, v);
112   }
stride(unsigned i)113   Value stride(unsigned i) { return d.stride(rewriter(), loc(), i); }
setStride(unsigned i,Value v)114   void setStride(unsigned i, Value v) { d.setStride(rewriter(), loc(), i, v); }
setConstantStride(unsigned i,int64_t v)115   void setConstantStride(unsigned i, int64_t v) {
116     d.setConstantStride(rewriter(), loc(), i, v);
117   }
118 
operator Value()119   operator Value() { return d; }
120 
121 private:
rewriter()122   OpBuilder &rewriter() { return ScopedContext::getBuilderRef(); }
loc()123   Location loc() { return ScopedContext::getLocation(); }
124 
125   MemRefDescriptor d;
126 };
127 
128 // RangeOp creates a new range descriptor.
129 class RangeOpConversion : public ConvertToLLVMPattern {
130 public:
RangeOpConversion(MLIRContext * context,LLVMTypeConverter & lowering_)131   explicit RangeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
132       : ConvertToLLVMPattern(RangeOp::getOperationName(), context, lowering_) {}
133 
134   LogicalResult
matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const135   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
136                   ConversionPatternRewriter &rewriter) const override {
137     auto rangeOp = cast<RangeOp>(op);
138     auto rangeDescriptorTy =
139         convertRangeType(rangeOp.getType().cast<RangeType>(), typeConverter);
140 
141     edsc::ScopedContext context(rewriter, op->getLoc());
142 
143     // Fill in an aggregate value of the descriptor.
144     RangeOpAdaptor adaptor(operands);
145     Value desc = llvm_undef(rangeDescriptorTy);
146     desc = llvm_insertvalue(desc, adaptor.min(), rewriter.getI64ArrayAttr(0));
147     desc = llvm_insertvalue(desc, adaptor.max(), rewriter.getI64ArrayAttr(1));
148     desc = llvm_insertvalue(desc, adaptor.step(), rewriter.getI64ArrayAttr(2));
149     rewriter.replaceOp(op, desc);
150     return success();
151   }
152 };
153 
154 // ReshapeOp creates a new view descriptor of the proper rank.
155 // For now, the only conversion supported is for target MemRef with static sizes
156 // and strides.
157 class ReshapeOpConversion : public ConvertToLLVMPattern {
158 public:
ReshapeOpConversion(MLIRContext * context,LLVMTypeConverter & lowering_)159   explicit ReshapeOpConversion(MLIRContext *context,
160                                LLVMTypeConverter &lowering_)
161       : ConvertToLLVMPattern(ReshapeOp::getOperationName(), context,
162                              lowering_) {}
163 
164   LogicalResult
matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const165   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
166                   ConversionPatternRewriter &rewriter) const override {
167     auto reshapeOp = cast<ReshapeOp>(op);
168     MemRefType dstType = reshapeOp.getResultType();
169 
170     if (!dstType.hasStaticShape())
171       return failure();
172 
173     int64_t offset;
174     SmallVector<int64_t, 4> strides;
175     auto res = getStridesAndOffset(dstType, strides, offset);
176     if (failed(res) || llvm::any_of(strides, [](int64_t val) {
177           return ShapedType::isDynamicStrideOrOffset(val);
178         }))
179       return failure();
180 
181     edsc::ScopedContext context(rewriter, op->getLoc());
182     ReshapeOpAdaptor adaptor(operands);
183     BaseViewConversionHelper baseDesc(adaptor.src());
184     BaseViewConversionHelper desc(typeConverter.convertType(dstType));
185     desc.setAllocatedPtr(baseDesc.allocatedPtr());
186     desc.setAlignedPtr(baseDesc.alignedPtr());
187     desc.setOffset(baseDesc.offset());
188     for (auto en : llvm::enumerate(dstType.getShape()))
189       desc.setConstantSize(en.index(), en.value());
190     for (auto en : llvm::enumerate(strides))
191       desc.setConstantStride(en.index(), en.value());
192     rewriter.replaceOp(op, {desc});
193     return success();
194   }
195 };
196 
197 /// Conversion pattern that transforms a linalg.slice op into:
198 ///   1. An "undef" value for the ViewDescriptor.
199 ///   2. Updates to the ViewDescriptor to introduce the data ptr, offset, size
200 ///      and stride corresponding to the region of memory within the bounds of
201 ///      the parent view.
202 /// The linalg.slice op is replaced by the alloca'ed pointer.
203 class SliceOpConversion : public ConvertToLLVMPattern {
204 public:
SliceOpConversion(MLIRContext * context,LLVMTypeConverter & lowering_)205   explicit SliceOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
206       : ConvertToLLVMPattern(SliceOp::getOperationName(), context, lowering_) {}
207 
208   LogicalResult
matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const209   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
210                   ConversionPatternRewriter &rewriter) const override {
211     edsc::ScopedContext context(rewriter, op->getLoc());
212     SliceOpAdaptor adaptor(operands);
213     BaseViewConversionHelper baseDesc(adaptor.view());
214 
215     auto sliceOp = cast<SliceOp>(op);
216     auto memRefType = sliceOp.getBaseViewType();
217     auto int64Ty = typeConverter.convertType(rewriter.getIntegerType(64))
218                        .cast<LLVM::LLVMType>();
219 
220     BaseViewConversionHelper desc(
221         typeConverter.convertType(sliceOp.getShapedType()));
222 
223     // TODO: extract sizes and emit asserts.
224     SmallVector<Value, 4> strides(memRefType.getRank());
225     for (int i = 0, e = memRefType.getRank(); i < e; ++i)
226       strides[i] = baseDesc.stride(i);
227 
228     auto pos = [&rewriter](ArrayRef<int64_t> values) {
229       return rewriter.getI64ArrayAttr(values);
230     };
231 
232     // Compute base offset.
233     Value baseOffset = baseDesc.offset();
234     for (int i = 0, e = memRefType.getRank(); i < e; ++i) {
235       Value indexing = adaptor.indexings()[i];
236       Value min = indexing;
237       if (sliceOp.indexing(i).getType().isa<RangeType>())
238         min = llvm_extractvalue(int64Ty, indexing, pos(0));
239       baseOffset = llvm_add(baseOffset, llvm_mul(min, strides[i]));
240     }
241 
242     // Insert the base and aligned pointers.
243     desc.setAllocatedPtr(baseDesc.allocatedPtr());
244     desc.setAlignedPtr(baseDesc.alignedPtr());
245 
246     // Insert base offset.
247     desc.setOffset(baseOffset);
248 
249     // Corner case, no sizes or strides: early return the descriptor.
250     if (sliceOp.getShapedType().getRank() == 0)
251       return rewriter.replaceOp(op, {desc}), success();
252 
253     Value zero = llvm_constant(
254         int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
255     // Compute and insert view sizes (max - min along the range) and strides.
256     // Skip the non-range operands as they will be projected away from the view.
257     int numNewDims = 0;
258     for (auto en : llvm::enumerate(sliceOp.indexings())) {
259       Value indexing = en.value();
260       if (indexing.getType().isa<RangeType>()) {
261         int rank = en.index();
262         Value rangeDescriptor = adaptor.indexings()[rank];
263         Value min = llvm_extractvalue(int64Ty, rangeDescriptor, pos(0));
264         Value max = llvm_extractvalue(int64Ty, rangeDescriptor, pos(1));
265         Value step = llvm_extractvalue(int64Ty, rangeDescriptor, pos(2));
266         Value baseSize = baseDesc.size(rank);
267 
268         // Bound upper by base view upper bound.
269         max = llvm_select(llvm_icmp(ICmpPredicate::slt, max, baseSize), max,
270                           baseSize);
271         Value size = llvm_sub(max, min);
272         // Bound lower by zero.
273         size =
274             llvm_select(llvm_icmp(ICmpPredicate::slt, size, zero), zero, size);
275         Value stride = llvm_mul(strides[rank], step);
276         desc.setSize(numNewDims, size);
277         desc.setStride(numNewDims, stride);
278         ++numNewDims;
279       }
280     }
281 
282     rewriter.replaceOp(op, {desc});
283     return success();
284   }
285 };
286 
287 // YieldOp produces and LLVM::ReturnOp.
288 class YieldOpConversion : public ConvertToLLVMPattern {
289 public:
YieldOpConversion(MLIRContext * context,LLVMTypeConverter & lowering_)290   explicit YieldOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
291       : ConvertToLLVMPattern(linalg::YieldOp::getOperationName(), context,
292                              lowering_) {}
293 
294   LogicalResult
matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const295   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
296                   ConversionPatternRewriter &rewriter) const override {
297     rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands);
298     return success();
299   }
300 };
301 } // namespace
302 
303 /// Populate the given list with patterns that convert from Linalg to LLVM.
populateLinalgToLLVMConversionPatterns(LLVMTypeConverter & converter,OwningRewritePatternList & patterns,MLIRContext * ctx)304 void mlir::populateLinalgToLLVMConversionPatterns(
305     LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
306     MLIRContext *ctx) {
307   patterns.insert<RangeOpConversion, ReshapeOpConversion, SliceOpConversion,
308                   YieldOpConversion>(ctx, converter);
309 
310   // Populate the type conversions for the linalg types.
311   converter.addConversion(
312       [&](RangeType type) { return convertRangeType(type, converter); });
313 }
314 
315 namespace {
316 struct ConvertLinalgToLLVMPass
317     : public ConvertLinalgToLLVMBase<ConvertLinalgToLLVMPass> {
318   void runOnOperation() override;
319 };
320 } // namespace
321 
runOnOperation()322 void ConvertLinalgToLLVMPass::runOnOperation() {
323   auto module = getOperation();
324 
325   // Convert to the LLVM IR dialect using the converter defined above.
326   OwningRewritePatternList patterns;
327   LLVMTypeConverter converter(&getContext());
328   populateAffineToStdConversionPatterns(patterns, &getContext());
329   populateLoopToStdConversionPatterns(patterns, &getContext());
330   populateStdToLLVMConversionPatterns(converter, patterns);
331   populateVectorToSCFConversionPatterns(patterns, &getContext());
332   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
333   populateVectorToLLVMConversionPatterns(converter, patterns);
334   populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext());
335 
336   LLVMConversionTarget target(getContext());
337   target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
338   if (failed(applyFullConversion(module, target, std::move(patterns))))
339     signalPassFailure();
340 }
341 
createConvertLinalgToLLVMPass()342 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertLinalgToLLVMPass() {
343   return std::make_unique<ConvertLinalgToLLVMPass>();
344 }
345