1 //===- Bufferize.cpp - Bufferization utilities ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Transforms/Bufferize.h"
10 #include "mlir/IR/Operation.h"
11
12 using namespace mlir;
13
14 //===----------------------------------------------------------------------===//
15 // BufferizeTypeConverter
16 //===----------------------------------------------------------------------===//
17
18 /// Registers conversions into BufferizeTypeConverter
BufferizeTypeConverter()19 BufferizeTypeConverter::BufferizeTypeConverter() {
20 // Keep all types unchanged.
21 addConversion([](Type type) { return type; });
22 // Convert RankedTensorType to MemRefType.
23 addConversion([](RankedTensorType type) -> Type {
24 return MemRefType::get(type.getShape(), type.getElementType());
25 });
26 // Convert UnrankedTensorType to UnrankedMemRefType.
27 addConversion([](UnrankedTensorType type) -> Type {
28 return UnrankedMemRefType::get(type.getElementType(), 0);
29 });
30 addSourceMaterialization([](OpBuilder &builder, TensorType type,
31 ValueRange inputs, Location loc) -> Value {
32 assert(inputs.size() == 1);
33 assert(inputs[0].getType().isa<BaseMemRefType>());
34 return builder.create<TensorLoadOp>(loc, type, inputs[0]);
35 });
36 addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
37 ValueRange inputs, Location loc) -> Value {
38 assert(inputs.size() == 1);
39 assert(inputs[0].getType().isa<TensorType>());
40 return builder.create<TensorToMemrefOp>(loc, type, inputs[0]);
41 });
42 }
43
44 /// This method tries to decompose a value of a certain type using provided
45 /// decompose callback functions. If it is unable to do so, the original value
46 /// is returned.
tryDecomposeValue(OpBuilder & builder,Location loc,Type type,Value value,SmallVectorImpl<Value> & results)47 void BufferizeTypeConverter::tryDecomposeValue(
48 OpBuilder &builder, Location loc, Type type, Value value,
49 SmallVectorImpl<Value> &results) {
50 for (auto &conversion : decomposeValueConversions)
51 if (conversion(builder, loc, type, value, results))
52 return;
53 results.push_back(value);
54 }
55
56 /// This method tries to decompose a type using provided decompose callback
57 /// functions. If it is unable to do so, the original type is returned.
tryDecomposeType(Type type,SmallVectorImpl<Type> & types)58 void BufferizeTypeConverter::tryDecomposeType(Type type,
59 SmallVectorImpl<Type> &types) {
60 for (auto &conversion : decomposeTypeConversions)
61 if (conversion(type, types))
62 return;
63 types.push_back(type);
64 }
65
populateBufferizeMaterializationLegality(ConversionTarget & target)66 void mlir::populateBufferizeMaterializationLegality(ConversionTarget &target) {
67 target.addLegalOp<TensorLoadOp, TensorToMemrefOp>();
68 }
69
70 namespace {
71 // In a finalizing bufferize conversion, we know that all tensors have been
72 // converted to memrefs, thus, this op becomes an identity.
73 class BufferizeTensorLoadOp : public OpConversionPattern<TensorLoadOp> {
74 public:
75 using OpConversionPattern::OpConversionPattern;
76 LogicalResult
matchAndRewrite(TensorLoadOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const77 matchAndRewrite(TensorLoadOp op, ArrayRef<Value> operands,
78 ConversionPatternRewriter &rewriter) const override {
79 TensorLoadOp::Adaptor adaptor(operands);
80 rewriter.replaceOp(op, adaptor.memref());
81 return success();
82 }
83 };
84 } // namespace
85
86 namespace {
87 // In a finalizing bufferize conversion, we know that all tensors have been
88 // converted to memrefs, thus, this op becomes an identity.
89 class BufferizeTensorToMemrefOp : public OpConversionPattern<TensorToMemrefOp> {
90 public:
91 using OpConversionPattern::OpConversionPattern;
92 LogicalResult
matchAndRewrite(TensorToMemrefOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const93 matchAndRewrite(TensorToMemrefOp op, ArrayRef<Value> operands,
94 ConversionPatternRewriter &rewriter) const override {
95 TensorToMemrefOp::Adaptor adaptor(operands);
96 rewriter.replaceOp(op, adaptor.tensor());
97 return success();
98 }
99 };
100 } // namespace
101
populateEliminateBufferizeMaterializationsPatterns(MLIRContext * context,BufferizeTypeConverter & typeConverter,OwningRewritePatternList & patterns)102 void mlir::populateEliminateBufferizeMaterializationsPatterns(
103 MLIRContext *context, BufferizeTypeConverter &typeConverter,
104 OwningRewritePatternList &patterns) {
105 patterns.insert<BufferizeTensorLoadOp, BufferizeTensorToMemrefOp>(
106 typeConverter, context);
107 }
108
109 //===----------------------------------------------------------------------===//
110 // BufferizeFuncOpConverter
111 //===----------------------------------------------------------------------===//
112
113 /// Performs the actual function signature rewriting step.
matchAndRewrite(mlir::FuncOp funcOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const114 LogicalResult BufferizeFuncOpConverter::matchAndRewrite(
115 mlir::FuncOp funcOp, ArrayRef<Value> operands,
116 ConversionPatternRewriter &rewriter) const {
117 auto funcType = funcOp.getType();
118
119 // Convert function arguments using the provided TypeConverter.
120 TypeConverter::SignatureConversion conversion(funcType.getNumInputs());
121 for (auto argType : llvm::enumerate(funcType.getInputs())) {
122 SmallVector<Type, 2> decomposedTypes, convertedTypes;
123 converter.tryDecomposeType(argType.value(), decomposedTypes);
124 converter.convertTypes(decomposedTypes, convertedTypes);
125 conversion.addInputs(argType.index(), convertedTypes);
126 }
127
128 // Convert the result types of the function.
129 SmallVector<Type, 2> newResultTypes;
130 newResultTypes.reserve(funcOp.getNumResults());
131 for (Type resultType : funcType.getResults()) {
132 SmallVector<Type, 2> originTypes;
133 converter.tryDecomposeType(resultType, originTypes);
134 for (auto origin : originTypes)
135 newResultTypes.push_back(converter.convertType(origin));
136 }
137
138 if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), converter,
139 &conversion)))
140 return failure();
141
142 // Update the signature of the function.
143 rewriter.updateRootInPlace(funcOp, [&] {
144 funcOp.setType(rewriter.getFunctionType(conversion.getConvertedTypes(),
145 newResultTypes));
146 });
147 return success();
148 }
149
150 //===----------------------------------------------------------------------===//
151 // BufferizeCallOpConverter
152 //===----------------------------------------------------------------------===//
153
154 /// Performs the actual rewriting step.
matchAndRewrite(CallOp callOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const155 LogicalResult BufferizeCallOpConverter::matchAndRewrite(
156 CallOp callOp, ArrayRef<Value> operands,
157 ConversionPatternRewriter &rewriter) const {
158
159 Location loc = callOp.getLoc();
160 SmallVector<Value, 2> newOperands;
161
162 // TODO: if the CallOp references a FuncOp that only has a declaration (e.g.
163 // to an externally defined symbol like an external library calls), only
164 // convert if some special attribute is set.
165 // This will allow more control of interop across ABI boundaries.
166
167 // Create the operands list of the new `CallOp`. It unpacks the decomposable
168 // values if a decompose callback function has been provided by the user.
169 for (auto operand : operands)
170 converter.tryDecomposeValue(rewriter, loc, operand.getType(), operand,
171 newOperands);
172
173 // Create the new result types for the new `CallOp` and track the indices in
174 // the new call op's results that correspond to the old call op's results.
175 SmallVector<Type, 2> newResultTypes;
176 SmallVector<SmallVector<int, 2>, 4> expandedResultIndices;
177 expandedResultIndices.resize(callOp.getNumResults());
178 for (auto result : llvm::enumerate(callOp.getResults())) {
179 SmallVector<Type, 2> originTypes;
180 converter.tryDecomposeType(result.value().getType(), originTypes);
181 auto &resultMapping = expandedResultIndices[result.index()];
182 for (Type origin : originTypes) {
183 Type converted = converter.convertType(origin);
184 newResultTypes.push_back(converted);
185 // The result value is not yet available. Its index is kept and it is
186 // replaced with the actual value of the new `CallOp` later.
187 resultMapping.push_back(newResultTypes.size() - 1);
188 }
189 }
190
191 CallOp newCallOp = rewriter.create<CallOp>(loc, callOp.getCallee(),
192 newResultTypes, newOperands);
193
194 // Build a replacing value for each result to replace its uses. If a result
195 // has multiple mapping values, it needs to be packed to a single value.
196 SmallVector<Value, 2> replacedValues;
197 replacedValues.reserve(callOp.getNumResults());
198 for (unsigned i = 0, e = callOp.getNumResults(); i < e; ++i) {
199 auto valuesToPack = llvm::to_vector<6>(
200 llvm::map_range(expandedResultIndices[i],
201 [&](int i) { return newCallOp.getResult(i); }));
202 if (valuesToPack.empty()) {
203 // No replacement is required.
204 replacedValues.push_back(nullptr);
205 } else if (valuesToPack.size() == 1) {
206 replacedValues.push_back(valuesToPack.front());
207 } else {
208 // Values need to be packed using callback function. The same callback
209 // that is used for materializeArgumentConversion is used for packing.
210 Value packed = converter.materializeArgumentConversion(
211 rewriter, loc, callOp.getType(i), valuesToPack);
212 replacedValues.push_back(packed);
213 }
214 }
215 rewriter.replaceOp(callOp, replacedValues);
216 return success();
217 }
218