1 //===- FoldUtils.cpp ---- Fold Utilities ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines various operation fold utilities. These utilities are
10 // intended to be used by passes to unify and simplify their logic.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/Transforms/FoldUtils.h"
15
16 #include "mlir/Dialect/StandardOps/IR/Ops.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/Matchers.h"
19 #include "mlir/IR/Operation.h"
20
21 using namespace mlir;
22
23 /// Given an operation, find the parent region that folded constants should be
24 /// inserted into.
25 static Region *
getInsertionRegion(DialectInterfaceCollection<DialectFoldInterface> & interfaces,Block * insertionBlock)26 getInsertionRegion(DialectInterfaceCollection<DialectFoldInterface> &interfaces,
27 Block *insertionBlock) {
28 while (Region *region = insertionBlock->getParent()) {
29 // Insert in this region for any of the following scenarios:
30 // * The parent is unregistered, or is known to be isolated from above.
31 // * The parent is a top-level operation.
32 auto *parentOp = region->getParentOp();
33 if (parentOp->mightHaveTrait<OpTrait::IsIsolatedFromAbove>() ||
34 !parentOp->getBlock())
35 return region;
36
37 // Otherwise, check if this region is a desired insertion region.
38 auto *interface = interfaces.getInterfaceFor(parentOp);
39 if (LLVM_UNLIKELY(interface && interface->shouldMaterializeInto(region)))
40 return region;
41
42 // Traverse up the parent looking for an insertion region.
43 insertionBlock = parentOp->getBlock();
44 }
45 llvm_unreachable("expected valid insertion region");
46 }
47
48 /// A utility function used to materialize a constant for a given attribute and
49 /// type. On success, a valid constant value is returned. Otherwise, null is
50 /// returned
materializeConstant(Dialect * dialect,OpBuilder & builder,Attribute value,Type type,Location loc)51 static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder,
52 Attribute value, Type type,
53 Location loc) {
54 auto insertPt = builder.getInsertionPoint();
55 (void)insertPt;
56
57 // Ask the dialect to materialize a constant operation for this value.
58 if (auto *constOp = dialect->materializeConstant(builder, value, type, loc)) {
59 assert(insertPt == builder.getInsertionPoint());
60 assert(matchPattern(constOp, m_Constant()));
61 return constOp;
62 }
63
64 // TODO: To facilitate splitting the std dialect (PR48490), have a special
65 // case for falling back to std.constant. Eventually, we will have separate
66 // ops tensor.constant, int.constant, float.constant, etc. that live in their
67 // respective dialects, which will allow each dialect to implement the
68 // materializeConstant hook above.
69 //
70 // The special case is needed because in the interim state while we are
71 // splitting out those dialects from std, the std dialect depends on the
72 // tensor dialect, which makes it impossible for the tensor dialect to use
73 // std.constant (it would be a cyclic dependency) as part of its
74 // materializeConstant hook.
75 //
76 // If the dialect is unable to materialize a constant, check to see if the
77 // standard constant can be used.
78 if (ConstantOp::isBuildableWith(value, type))
79 return builder.create<ConstantOp>(loc, type, value);
80 return nullptr;
81 }
82
83 //===----------------------------------------------------------------------===//
84 // OperationFolder
85 //===----------------------------------------------------------------------===//
86
tryToFold(Operation * op,function_ref<void (Operation *)> processGeneratedConstants,function_ref<void (Operation *)> preReplaceAction,bool * inPlaceUpdate)87 LogicalResult OperationFolder::tryToFold(
88 Operation *op, function_ref<void(Operation *)> processGeneratedConstants,
89 function_ref<void(Operation *)> preReplaceAction, bool *inPlaceUpdate) {
90 if (inPlaceUpdate)
91 *inPlaceUpdate = false;
92
93 // If this is a unique'd constant, return failure as we know that it has
94 // already been folded.
95 if (referencedDialects.count(op))
96 return failure();
97
98 // Try to fold the operation.
99 SmallVector<Value, 8> results;
100 OpBuilder builder(op);
101 if (failed(tryToFold(builder, op, results, processGeneratedConstants)))
102 return failure();
103
104 // Check to see if the operation was just updated in place.
105 if (results.empty()) {
106 if (inPlaceUpdate)
107 *inPlaceUpdate = true;
108 return success();
109 }
110
111 // Constant folding succeeded. We will start replacing this op's uses and
112 // erase this op. Invoke the callback provided by the caller to perform any
113 // pre-replacement action.
114 if (preReplaceAction)
115 preReplaceAction(op);
116
117 // Replace all of the result values and erase the operation.
118 for (unsigned i = 0, e = results.size(); i != e; ++i)
119 op->getResult(i).replaceAllUsesWith(results[i]);
120 op->erase();
121 return success();
122 }
123
124 /// Notifies that the given constant `op` should be remove from this
125 /// OperationFolder's internal bookkeeping.
notifyRemoval(Operation * op)126 void OperationFolder::notifyRemoval(Operation *op) {
127 // Check to see if this operation is uniqued within the folder.
128 auto it = referencedDialects.find(op);
129 if (it == referencedDialects.end())
130 return;
131
132 // Get the constant value for this operation, this is the value that was used
133 // to unique the operation internally.
134 Attribute constValue;
135 matchPattern(op, m_Constant(&constValue));
136 assert(constValue);
137
138 // Get the constant map that this operation was uniqued in.
139 auto &uniquedConstants =
140 foldScopes[getInsertionRegion(interfaces, op->getBlock())];
141
142 // Erase all of the references to this operation.
143 auto type = op->getResult(0).getType();
144 for (auto *dialect : it->second)
145 uniquedConstants.erase(std::make_tuple(dialect, constValue, type));
146 referencedDialects.erase(it);
147 }
148
149 /// Clear out any constants cached inside of the folder.
clear()150 void OperationFolder::clear() {
151 foldScopes.clear();
152 referencedDialects.clear();
153 }
154
155 /// Get or create a constant using the given builder. On success this returns
156 /// the constant operation, nullptr otherwise.
getOrCreateConstant(OpBuilder & builder,Dialect * dialect,Attribute value,Type type,Location loc)157 Value OperationFolder::getOrCreateConstant(OpBuilder &builder, Dialect *dialect,
158 Attribute value, Type type,
159 Location loc) {
160 OpBuilder::InsertionGuard foldGuard(builder);
161
162 // Use the builder insertion block to find an insertion point for the
163 // constant.
164 auto *insertRegion =
165 getInsertionRegion(interfaces, builder.getInsertionBlock());
166 auto &entry = insertRegion->front();
167 builder.setInsertionPoint(&entry, entry.begin());
168
169 // Get the constant map for the insertion region of this operation.
170 auto &uniquedConstants = foldScopes[insertRegion];
171 Operation *constOp = tryGetOrCreateConstant(uniquedConstants, dialect,
172 builder, value, type, loc);
173 return constOp ? constOp->getResult(0) : Value();
174 }
175
176 /// Tries to perform folding on the given `op`. If successful, populates
177 /// `results` with the results of the folding.
tryToFold(OpBuilder & builder,Operation * op,SmallVectorImpl<Value> & results,function_ref<void (Operation *)> processGeneratedConstants)178 LogicalResult OperationFolder::tryToFold(
179 OpBuilder &builder, Operation *op, SmallVectorImpl<Value> &results,
180 function_ref<void(Operation *)> processGeneratedConstants) {
181 SmallVector<Attribute, 8> operandConstants;
182 SmallVector<OpFoldResult, 8> foldResults;
183
184 // If this is a commutative operation, move constants to be trailing operands.
185 if (op->getNumOperands() >= 2 && op->hasTrait<OpTrait::IsCommutative>()) {
186 std::stable_partition(
187 op->getOpOperands().begin(), op->getOpOperands().end(),
188 [&](OpOperand &O) { return !matchPattern(O.get(), m_Constant()); });
189 }
190
191 // Check to see if any operands to the operation is constant and whether
192 // the operation knows how to constant fold itself.
193 operandConstants.assign(op->getNumOperands(), Attribute());
194 for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
195 matchPattern(op->getOperand(i), m_Constant(&operandConstants[i]));
196
197 // Attempt to constant fold the operation.
198 if (failed(op->fold(operandConstants, foldResults)))
199 return failure();
200
201 // Check to see if the operation was just updated in place.
202 if (foldResults.empty())
203 return success();
204 assert(foldResults.size() == op->getNumResults());
205
206 // Create a builder to insert new operations into the entry block of the
207 // insertion region.
208 auto *insertRegion =
209 getInsertionRegion(interfaces, builder.getInsertionBlock());
210 auto &entry = insertRegion->front();
211 OpBuilder::InsertionGuard foldGuard(builder);
212 builder.setInsertionPoint(&entry, entry.begin());
213
214 // Get the constant map for the insertion region of this operation.
215 auto &uniquedConstants = foldScopes[insertRegion];
216
217 // Create the result constants and replace the results.
218 auto *dialect = op->getDialect();
219 for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
220 assert(!foldResults[i].isNull() && "expected valid OpFoldResult");
221
222 // Check if the result was an SSA value.
223 if (auto repl = foldResults[i].dyn_cast<Value>()) {
224 if (repl.getType() != op->getResult(i).getType())
225 return failure();
226 results.emplace_back(repl);
227 continue;
228 }
229
230 // Check to see if there is a canonicalized version of this constant.
231 auto res = op->getResult(i);
232 Attribute attrRepl = foldResults[i].get<Attribute>();
233 if (auto *constOp =
234 tryGetOrCreateConstant(uniquedConstants, dialect, builder, attrRepl,
235 res.getType(), op->getLoc())) {
236 results.push_back(constOp->getResult(0));
237 continue;
238 }
239 // If materialization fails, cleanup any operations generated for the
240 // previous results and return failure.
241 for (Operation &op : llvm::make_early_inc_range(
242 llvm::make_range(entry.begin(), builder.getInsertionPoint()))) {
243 notifyRemoval(&op);
244 op.erase();
245 }
246 return failure();
247 }
248
249 // Process any newly generated operations.
250 if (processGeneratedConstants) {
251 for (auto i = entry.begin(), e = builder.getInsertionPoint(); i != e; ++i)
252 processGeneratedConstants(&*i);
253 }
254
255 return success();
256 }
257
258 /// Try to get or create a new constant entry. On success this returns the
259 /// constant operation value, nullptr otherwise.
tryGetOrCreateConstant(ConstantMap & uniquedConstants,Dialect * dialect,OpBuilder & builder,Attribute value,Type type,Location loc)260 Operation *OperationFolder::tryGetOrCreateConstant(
261 ConstantMap &uniquedConstants, Dialect *dialect, OpBuilder &builder,
262 Attribute value, Type type, Location loc) {
263 // Check if an existing mapping already exists.
264 auto constKey = std::make_tuple(dialect, value, type);
265 Operation *&constOp = uniquedConstants[constKey];
266 if (constOp)
267 return constOp;
268
269 // If one doesn't exist, try to materialize one.
270 if (!(constOp = materializeConstant(dialect, builder, value, type, loc)))
271 return nullptr;
272
273 // Check to see if the generated constant is in the expected dialect.
274 auto *newDialect = constOp->getDialect();
275 if (newDialect == dialect) {
276 referencedDialects[constOp].push_back(dialect);
277 return constOp;
278 }
279
280 // If it isn't, then we also need to make sure that the mapping for the new
281 // dialect is valid.
282 auto newKey = std::make_tuple(newDialect, value, type);
283
284 // If an existing operation in the new dialect already exists, delete the
285 // materialized operation in favor of the existing one.
286 if (auto *existingOp = uniquedConstants.lookup(newKey)) {
287 constOp->erase();
288 referencedDialects[existingOp].push_back(dialect);
289 return constOp = existingOp;
290 }
291
292 // Otherwise, update the new dialect to the materialized operation.
293 referencedDialects[constOp].assign({dialect, newDialect});
294 auto newIt = uniquedConstants.insert({newKey, constOp});
295 return newIt.first->second;
296 }
297