1 //===- Bufferize.cpp - Bufferization utilities ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Transforms/Bufferize.h"
10 #include "mlir/IR/Operation.h"
11 
12 using namespace mlir;
13 
14 //===----------------------------------------------------------------------===//
15 // BufferizeTypeConverter
16 //===----------------------------------------------------------------------===//
17 
18 /// Registers conversions into BufferizeTypeConverter
BufferizeTypeConverter()19 BufferizeTypeConverter::BufferizeTypeConverter() {
20   // Keep all types unchanged.
21   addConversion([](Type type) { return type; });
22   // Convert RankedTensorType to MemRefType.
23   addConversion([](RankedTensorType type) -> Type {
24     return MemRefType::get(type.getShape(), type.getElementType());
25   });
26   // Convert UnrankedTensorType to UnrankedMemRefType.
27   addConversion([](UnrankedTensorType type) -> Type {
28     return UnrankedMemRefType::get(type.getElementType(), 0);
29   });
30   addSourceMaterialization([](OpBuilder &builder, TensorType type,
31                               ValueRange inputs, Location loc) -> Value {
32     assert(inputs.size() == 1);
33     assert(inputs[0].getType().isa<BaseMemRefType>());
34     return builder.create<TensorLoadOp>(loc, type, inputs[0]);
35   });
36   addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
37                               ValueRange inputs, Location loc) -> Value {
38     assert(inputs.size() == 1);
39     assert(inputs[0].getType().isa<TensorType>());
40     return builder.create<TensorToMemrefOp>(loc, type, inputs[0]);
41   });
42 }
43 
44 /// This method tries to decompose a value of a certain type using provided
45 /// decompose callback functions. If it is unable to do so, the original value
46 /// is returned.
tryDecomposeValue(OpBuilder & builder,Location loc,Type type,Value value,SmallVectorImpl<Value> & results)47 void BufferizeTypeConverter::tryDecomposeValue(
48     OpBuilder &builder, Location loc, Type type, Value value,
49     SmallVectorImpl<Value> &results) {
50   for (auto &conversion : decomposeValueConversions)
51     if (conversion(builder, loc, type, value, results))
52       return;
53   results.push_back(value);
54 }
55 
56 /// This method tries to decompose a type using provided decompose callback
57 /// functions. If it is unable to do so, the original type is returned.
tryDecomposeType(Type type,SmallVectorImpl<Type> & types)58 void BufferizeTypeConverter::tryDecomposeType(Type type,
59                                               SmallVectorImpl<Type> &types) {
60   for (auto &conversion : decomposeTypeConversions)
61     if (conversion(type, types))
62       return;
63   types.push_back(type);
64 }
65 
populateBufferizeMaterializationLegality(ConversionTarget & target)66 void mlir::populateBufferizeMaterializationLegality(ConversionTarget &target) {
67   target.addLegalOp<TensorLoadOp, TensorToMemrefOp>();
68 }
69 
70 namespace {
71 // In a finalizing bufferize conversion, we know that all tensors have been
72 // converted to memrefs, thus, this op becomes an identity.
73 class BufferizeTensorLoadOp : public OpConversionPattern<TensorLoadOp> {
74 public:
75   using OpConversionPattern::OpConversionPattern;
76   LogicalResult
matchAndRewrite(TensorLoadOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const77   matchAndRewrite(TensorLoadOp op, ArrayRef<Value> operands,
78                   ConversionPatternRewriter &rewriter) const override {
79     TensorLoadOp::Adaptor adaptor(operands);
80     rewriter.replaceOp(op, adaptor.memref());
81     return success();
82   }
83 };
84 } // namespace
85 
86 namespace {
87 // In a finalizing bufferize conversion, we know that all tensors have been
88 // converted to memrefs, thus, this op becomes an identity.
89 class BufferizeTensorToMemrefOp : public OpConversionPattern<TensorToMemrefOp> {
90 public:
91   using OpConversionPattern::OpConversionPattern;
92   LogicalResult
matchAndRewrite(TensorToMemrefOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const93   matchAndRewrite(TensorToMemrefOp op, ArrayRef<Value> operands,
94                   ConversionPatternRewriter &rewriter) const override {
95     TensorToMemrefOp::Adaptor adaptor(operands);
96     rewriter.replaceOp(op, adaptor.tensor());
97     return success();
98   }
99 };
100 } // namespace
101 
populateEliminateBufferizeMaterializationsPatterns(MLIRContext * context,BufferizeTypeConverter & typeConverter,OwningRewritePatternList & patterns)102 void mlir::populateEliminateBufferizeMaterializationsPatterns(
103     MLIRContext *context, BufferizeTypeConverter &typeConverter,
104     OwningRewritePatternList &patterns) {
105   patterns.insert<BufferizeTensorLoadOp, BufferizeTensorToMemrefOp>(
106       typeConverter, context);
107 }
108 
109 //===----------------------------------------------------------------------===//
110 // BufferizeFuncOpConverter
111 //===----------------------------------------------------------------------===//
112 
113 /// Performs the actual function signature rewriting step.
matchAndRewrite(mlir::FuncOp funcOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const114 LogicalResult BufferizeFuncOpConverter::matchAndRewrite(
115     mlir::FuncOp funcOp, ArrayRef<Value> operands,
116     ConversionPatternRewriter &rewriter) const {
117   auto funcType = funcOp.getType();
118 
119   // Convert function arguments using the provided TypeConverter.
120   TypeConverter::SignatureConversion conversion(funcType.getNumInputs());
121   for (auto argType : llvm::enumerate(funcType.getInputs())) {
122     SmallVector<Type, 2> decomposedTypes, convertedTypes;
123     converter.tryDecomposeType(argType.value(), decomposedTypes);
124     converter.convertTypes(decomposedTypes, convertedTypes);
125     conversion.addInputs(argType.index(), convertedTypes);
126   }
127 
128   // Convert the result types of the function.
129   SmallVector<Type, 2> newResultTypes;
130   newResultTypes.reserve(funcOp.getNumResults());
131   for (Type resultType : funcType.getResults()) {
132     SmallVector<Type, 2> originTypes;
133     converter.tryDecomposeType(resultType, originTypes);
134     for (auto origin : originTypes)
135       newResultTypes.push_back(converter.convertType(origin));
136   }
137 
138   if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), converter,
139                                          &conversion)))
140     return failure();
141 
142   // Update the signature of the function.
143   rewriter.updateRootInPlace(funcOp, [&] {
144     funcOp.setType(rewriter.getFunctionType(conversion.getConvertedTypes(),
145                                             newResultTypes));
146   });
147   return success();
148 }
149 
150 //===----------------------------------------------------------------------===//
151 // BufferizeCallOpConverter
152 //===----------------------------------------------------------------------===//
153 
154 /// Performs the actual rewriting step.
matchAndRewrite(CallOp callOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const155 LogicalResult BufferizeCallOpConverter::matchAndRewrite(
156     CallOp callOp, ArrayRef<Value> operands,
157     ConversionPatternRewriter &rewriter) const {
158 
159   Location loc = callOp.getLoc();
160   SmallVector<Value, 2> newOperands;
161 
162   // TODO: if the CallOp references a FuncOp that only has a declaration (e.g.
163   // to an externally defined symbol like an external library calls), only
164   // convert if some special attribute is set.
165   // This will allow more control of interop across ABI boundaries.
166 
167   // Create the operands list of the new `CallOp`. It unpacks the decomposable
168   // values if a decompose callback function has been provided by the user.
169   for (auto operand : operands)
170     converter.tryDecomposeValue(rewriter, loc, operand.getType(), operand,
171                                 newOperands);
172 
173   // Create the new result types for the new `CallOp` and track the indices in
174   // the new call op's results that correspond to the old call op's results.
175   SmallVector<Type, 2> newResultTypes;
176   SmallVector<SmallVector<int, 2>, 4> expandedResultIndices;
177   expandedResultIndices.resize(callOp.getNumResults());
178   for (auto result : llvm::enumerate(callOp.getResults())) {
179     SmallVector<Type, 2> originTypes;
180     converter.tryDecomposeType(result.value().getType(), originTypes);
181     auto &resultMapping = expandedResultIndices[result.index()];
182     for (Type origin : originTypes) {
183       Type converted = converter.convertType(origin);
184       newResultTypes.push_back(converted);
185       // The result value is not yet available. Its index is kept and it is
186       // replaced with the actual value of the new `CallOp` later.
187       resultMapping.push_back(newResultTypes.size() - 1);
188     }
189   }
190 
191   CallOp newCallOp = rewriter.create<CallOp>(loc, callOp.getCallee(),
192                                              newResultTypes, newOperands);
193 
194   // Build a replacing value for each result to replace its uses. If a result
195   // has multiple mapping values, it needs to be packed to a single value.
196   SmallVector<Value, 2> replacedValues;
197   replacedValues.reserve(callOp.getNumResults());
198   for (unsigned i = 0, e = callOp.getNumResults(); i < e; ++i) {
199     auto valuesToPack = llvm::to_vector<6>(
200         llvm::map_range(expandedResultIndices[i],
201                         [&](int i) { return newCallOp.getResult(i); }));
202     if (valuesToPack.empty()) {
203       // No replacement is required.
204       replacedValues.push_back(nullptr);
205     } else if (valuesToPack.size() == 1) {
206       replacedValues.push_back(valuesToPack.front());
207     } else {
208       // Values need to be packed using callback function. The same callback
209       // that is used for materializeArgumentConversion is used for packing.
210       Value packed = converter.materializeArgumentConversion(
211           rewriter, loc, callOp.getType(i), valuesToPack);
212       replacedValues.push_back(packed);
213     }
214   }
215   rewriter.replaceOp(callOp, replacedValues);
216   return success();
217 }
218