//===- Utils.cpp ---- Misc utilities for code and data transformation -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements miscellaneous transformation routines for non-loop IR
// structures.
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/Utils.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;

// Perform the replacement in `op`.
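// For example (illustrative sketch, not the full contract; see the
// declaration in mlir/Transforms/Utils.h): with
// indexRemap = (d0) -> (d0 floordiv 32, d0 mod 32) and no extra
// indices/operands, an access 'affine.load %oldMemRef[%i]' is rewritten to
// 'affine.load %newMemRef[%i floordiv 32, %i mod 32]'.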
LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
                                             Operation *op,
                                             ArrayRef<Value> extraIndices,
                                             AffineMap indexRemap,
                                             ArrayRef<Value> extraOperands,
                                             ArrayRef<Value> symbolOperands,
                                             bool allowNonDereferencingOps) {
  unsigned newMemRefRank = newMemRef.getType().cast<MemRefType>().getRank();
  (void)newMemRefRank; // unused in opt mode
  unsigned oldMemRefRank = oldMemRef.getType().cast<MemRefType>().getRank();
  (void)oldMemRefRank; // unused in opt mode
  if (indexRemap) {
    assert(indexRemap.getNumSymbols() == symbolOperands.size() &&
           "symbolic operand count mismatch");
    assert(indexRemap.getNumInputs() ==
           extraOperands.size() + oldMemRefRank + symbolOperands.size());
    assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
  } else {
    assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
  }

  // Assert same elemental type.
  assert(oldMemRef.getType().cast<MemRefType>().getElementType() ==
         newMemRef.getType().cast<MemRefType>().getElementType());

  SmallVector<unsigned, 2> usePositions;
  for (const auto &opEntry : llvm::enumerate(op->getOperands())) {
    if (opEntry.value() == oldMemRef)
      usePositions.push_back(opEntry.index());
  }

  // If the memref doesn't appear, there is nothing to do.
  if (usePositions.empty())
    return success();

  if (usePositions.size() > 1) {
    // TODO: extend it for this case when needed (rare).
    assert(false && "multiple dereferencing uses in a single op not supported");
    return failure();
  }

  unsigned memRefOperandPos = usePositions.front();

  OpBuilder builder(op);
  // The following checks if op is dereferencing the memref and performs the
  // access index rewrites.
  auto affMapAccInterface = dyn_cast<AffineMapAccessInterface>(op);
  if (!affMapAccInterface) {
    if (!allowNonDereferencingOps) {
      // Failure: memref used in a non-dereferencing context (potentially
      // escapes); no replacement in these cases unless
      // allowNonDereferencingOps is set.
      return failure();
    }
    op->setOperand(memRefOperandPos, newMemRef);
    return success();
  }
  // Perform index rewrites for the dereferencing op and then replace the op.
  NamedAttribute oldMapAttrPair =
      affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef);
  AffineMap oldMap = oldMapAttrPair.second.cast<AffineMapAttr>().getValue();
  unsigned oldMapNumInputs = oldMap.getNumInputs();
  SmallVector<Value, 4> oldMapOperands(
      op->operand_begin() + memRefOperandPos + 1,
      op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);

  // Apply 'oldMemRefOperands = oldMap(oldMapOperands)'.
  SmallVector<Value, 4> oldMemRefOperands;
  SmallVector<Value, 4> affineApplyOps;
  oldMemRefOperands.reserve(oldMemRefRank);
  if (oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) {
    for (auto resultExpr : oldMap.getResults()) {
      auto singleResMap = AffineMap::get(oldMap.getNumDims(),
                                         oldMap.getNumSymbols(), resultExpr);
      auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
                                                oldMapOperands);
      oldMemRefOperands.push_back(afOp);
      affineApplyOps.push_back(afOp);
    }
  } else {
    oldMemRefOperands.assign(oldMapOperands.begin(), oldMapOperands.end());
  }

  // Construct new indices as a remap of the old ones if a remapping has been
  // provided. The indices of a memref come right after it, i.e., at position
  // memRefOperandPos + 1.
  SmallVector<Value, 4> remapOperands;
  remapOperands.reserve(extraOperands.size() + oldMemRefRank +
                        symbolOperands.size());
  remapOperands.append(extraOperands.begin(), extraOperands.end());
  remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end());
  remapOperands.append(symbolOperands.begin(), symbolOperands.end());

  SmallVector<Value, 4> remapOutputs;
  remapOutputs.reserve(oldMemRefRank);

  if (indexRemap &&
      indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
    // Remapped indices.
    for (auto resultExpr : indexRemap.getResults()) {
      auto singleResMap = AffineMap::get(
          indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr);
      auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
                                                remapOperands);
      remapOutputs.push_back(afOp);
      affineApplyOps.push_back(afOp);
    }
  } else {
    // No remapping specified.
    remapOutputs.assign(remapOperands.begin(), remapOperands.end());
  }

  SmallVector<Value, 4> newMapOperands;
  newMapOperands.reserve(newMemRefRank);

  // Prepend 'extraIndices' in 'newMapOperands'.
  for (Value extraIndex : extraIndices) {
    assert(extraIndex.getDefiningOp()->getNumResults() == 1 &&
           "single-result ops expected to generate these indices");
    assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) &&
           "invalid memory op index");
    newMapOperands.push_back(extraIndex);
  }

  // Append 'remapOutputs' to 'newMapOperands'.
  newMapOperands.append(remapOutputs.begin(), remapOutputs.end());

  // Create a new fully composed AffineMap for the new op to be created.
  assert(newMapOperands.size() == newMemRefRank);
  auto newMap = builder.getMultiDimIdentityMap(newMemRefRank);
  // TODO: Avoid creating/deleting temporary AffineApplyOps here.
  fullyComposeAffineMapAndOperands(&newMap, &newMapOperands);
  newMap = simplifyAffineMap(newMap);
  canonicalizeMapAndOperands(&newMap, &newMapOperands);
  // Remove any affine.apply's that became dead as a result of composition.
  for (Value value : affineApplyOps)
    if (value.use_empty())
      value.getDefiningOp()->erase();

  OperationState state(op->getLoc(), op->getName());
  // Construct the new operation using this memref.
  state.operands.reserve(op->getNumOperands() + extraIndices.size());
  // Insert the non-memref operands.
  state.operands.append(op->operand_begin(),
                        op->operand_begin() + memRefOperandPos);
  // Insert the new memref value.
  state.operands.push_back(newMemRef);

  // Insert the new memref map operands.
  state.operands.append(newMapOperands.begin(), newMapOperands.end());

  // Insert the remaining operands unmodified.
  state.operands.append(op->operand_begin() + memRefOperandPos + 1 +
                            oldMapNumInputs,
                        op->operand_end());

  // Result types don't change. Both memrefs are of the same elemental type.
  state.types.reserve(op->getNumResults());
  for (auto result : op->getResults())
    state.types.push_back(result.getType());

  // Add an attribute for 'newMap'; other attributes do not change.
  auto newMapAttr = AffineMapAttr::get(newMap);
  for (auto namedAttr : op->getAttrs()) {
    if (namedAttr.first == oldMapAttrPair.first)
      state.attributes.push_back({namedAttr.first, newMapAttr});
    else
      state.attributes.push_back(namedAttr);
  }

  // Create the new operation.
  auto *repOp = builder.createOperation(state);
  op->replaceAllUsesWith(repOp);
  op->erase();

  return success();
}

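// Note: uses of `oldMemRef` are replaced only if they pass the optional
// `domInstFilter` dominance and `postDomInstFilter` post-dominance filters
// below; see the declaration in mlir/Transforms/Utils.h for the full
// documentation.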
LogicalResult mlir::replaceAllMemRefUsesWith(
    Value oldMemRef, Value newMemRef, ArrayRef<Value> extraIndices,
    AffineMap indexRemap, ArrayRef<Value> extraOperands,
    ArrayRef<Value> symbolOperands, Operation *domInstFilter,
    Operation *postDomInstFilter, bool allowNonDereferencingOps,
    bool replaceInDeallocOp) {
  unsigned newMemRefRank = newMemRef.getType().cast<MemRefType>().getRank();
  (void)newMemRefRank; // unused in opt mode
  unsigned oldMemRefRank = oldMemRef.getType().cast<MemRefType>().getRank();
  (void)oldMemRefRank;
  if (indexRemap) {
    assert(indexRemap.getNumSymbols() == symbolOperands.size() &&
           "symbol operand count mismatch");
    assert(indexRemap.getNumInputs() ==
           extraOperands.size() + oldMemRefRank + symbolOperands.size());
    assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
  } else {
    assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
  }

  // Assert same elemental type.
  assert(oldMemRef.getType().cast<MemRefType>().getElementType() ==
         newMemRef.getType().cast<MemRefType>().getElementType());

  std::unique_ptr<DominanceInfo> domInfo;
  std::unique_ptr<PostDominanceInfo> postDomInfo;
  if (domInstFilter)
    domInfo = std::make_unique<DominanceInfo>(
        domInstFilter->getParentOfType<FuncOp>());

  if (postDomInstFilter)
    postDomInfo = std::make_unique<PostDominanceInfo>(
        postDomInstFilter->getParentOfType<FuncOp>());

  // Walk all uses of the old memref; collect ops to perform the replacement.
  // We use a DenseSet since an operation could potentially have multiple uses
  // of a memref (although rare), and the replacement later is going to erase
  // ops.
  DenseSet<Operation *> opsToReplace;
  for (auto *op : oldMemRef.getUsers()) {
    // Skip this use if it's not dominated by domInstFilter.
    if (domInstFilter && !domInfo->dominates(domInstFilter, op))
      continue;

    // Skip this use if it's not post-dominated by postDomInstFilter.
    if (postDomInstFilter && !postDomInfo->postDominates(postDomInstFilter, op))
      continue;

    // Skip dealloc's - no replacement is necessary, and a memref replacement
    // at other uses doesn't hurt these dealloc's.
    if (isa<memref::DeallocOp>(op) && !replaceInDeallocOp)
      continue;

    // Check if the memref was used in a non-dereferencing context. It is fine
    // for the memref to be used in a non-dereferencing way outside of the
    // region where this replacement is happening.
    if (!isa<AffineMapAccessInterface>(*op)) {
      if (!allowNonDereferencingOps)
        return failure();
      // Currently we support the following non-dereferencing ops as
      // candidates for replacement: dealloc, CallOp, and ReturnOp.
      // TODO: Add support for other kinds of ops.
      if (!op->hasTrait<OpTrait::MemRefsNormalizable>())
        return failure();
    }

    // We'll first collect and then replace --- since replacement erases the op
    // that has the use, and that op could be postDomFilter or domFilter itself!
    opsToReplace.insert(op);
  }

  for (auto *op : opsToReplace) {
    if (failed(replaceAllMemRefUsesWith(
            oldMemRef, newMemRef, op, extraIndices, indexRemap, extraOperands,
            symbolOperands, allowNonDereferencingOps)))
      llvm_unreachable("memref replacement guaranteed to succeed here");
  }

  return success();
}

/// Given an operation, inserts one or more single-result affine apply
/// operations, the results of which are exclusively used by this operation.
/// The operands of these newly created affine apply ops are guaranteed to be
/// loop iterators or terminal symbols of a function.
///
/// Before
///
/// affine.for %i = 0 to #map(%N)
///   %idx = affine.apply (d0) -> (d0 mod 2) (%i)
///   "send"(%idx, %A, ...)
///   "compute"(%idx)
///
/// After
///
/// affine.for %i = 0 to #map(%N)
///   %idx = affine.apply (d0) -> (d0 mod 2) (%i)
///   "send"(%idx, %A, ...)
///   %idx_ = affine.apply (d0) -> (d0 mod 2) (%i)
///   "compute"(%idx_)
///
/// This allows applying different transformations on send and compute (e.g.,
/// different shifts/delays).
///
/// Returns without creating anything (leaving `sliceOps` empty) either if
/// none of opInst's operands were the result of an affine.apply (and thus
/// there was no affine computation slice to create), or if all of the
/// affine.apply ops supplying operands to opInst have no uses besides opInst;
/// otherwise, returns the list of affine.apply operations created in the
/// output argument `sliceOps`.
void mlir::createAffineComputationSlice(
    Operation *opInst, SmallVectorImpl<AffineApplyOp> *sliceOps) {
  // Collect all operands that are results of affine apply ops.
  SmallVector<Value, 4> subOperands;
  subOperands.reserve(opInst->getNumOperands());
  for (auto operand : opInst->getOperands())
    if (isa_and_nonnull<AffineApplyOp>(operand.getDefiningOp()))
      subOperands.push_back(operand);

  // Gather the sequence of AffineApplyOps reachable from 'subOperands'.
  SmallVector<Operation *, 4> affineApplyOps;
  getReachableAffineApplyOps(subOperands, affineApplyOps);
  // Skip transforming if there are no affine maps to compose.
  if (affineApplyOps.empty())
    return;

  // Check if all uses of the affine apply ops lie only in this op, in which
  // case there would be nothing to do.
  bool localized = true;
  for (auto *op : affineApplyOps) {
    for (auto result : op->getResults()) {
      for (auto *user : result.getUsers()) {
        if (user != opInst) {
          localized = false;
          break;
        }
      }
    }
  }
  if (localized)
    return;

  OpBuilder builder(opInst);
  SmallVector<Value, 4> composedOpOperands(subOperands);
  auto composedMap = builder.getMultiDimIdentityMap(composedOpOperands.size());
  fullyComposeAffineMapAndOperands(&composedMap, &composedOpOperands);

  // Create an affine.apply for each of the map results.
  sliceOps->reserve(composedMap.getNumResults());
  for (auto resultExpr : composedMap.getResults()) {
    auto singleResMap = AffineMap::get(
        composedMap.getNumDims(), composedMap.getNumSymbols(), resultExpr);
    sliceOps->push_back(builder.create<AffineApplyOp>(
        opInst->getLoc(), singleResMap, composedOpOperands));
  }

  // Construct the new operands that include the results of the composed
  // affine apply ops above instead of the existing ones (subOperands). They
  // differ from opInst's operands only for those operands in 'subOperands',
  // which are replaced by the corresponding one from 'sliceOps'.
  SmallVector<Value, 4> newOperands(opInst->getOperands());
  for (unsigned i = 0, e = newOperands.size(); i < e; i++) {
    // Replace the subOperands from among the new operands.
    unsigned j, f;
    for (j = 0, f = subOperands.size(); j < f; j++) {
      if (newOperands[i] == subOperands[j])
        break;
    }
    if (j < subOperands.size())
      newOperands[i] = (*sliceOps)[j];
  }
  for (unsigned idx = 0, e = newOperands.size(); idx < e; idx++)
    opInst->setOperand(idx, newOperands[idx]);
}

/// Enum to classify the patterns of affine exprs in a tiled-layout map.
/// TileFloorDiv: <dim expr> floordiv <tile size>
/// TileMod: <dim expr> mod <tile size>
/// TileNone: None of the above
/// Example:
///   #tiled_2d_128x256 = affine_map<(d0, d1)
///     -> (d0 floordiv 128, d1 floordiv 256, d0 mod 128, d1 mod 256)>
///   "d0 floordiv 128" and "d1 floordiv 256" ==> TileFloorDiv
///   "d0 mod 128" and "d1 mod 256" ==> TileMod
enum TileExprPattern { TileFloorDiv, TileMod, TileNone };

/// Check if `map` is a tiled layout. In a tiled layout, specific k dimensions
/// that are floordiv'd by their respective tile sizes also appear in a mod
/// with the same tile sizes, and no other expression involves those k
/// dimensions. This function stores a vector of tuples (`tileSizePos`), each
/// holding the AffineExpr for a tile size and the positions of the
/// corresponding `floordiv` and `mod`. If `map` is not a tiled layout, an
/// empty vector is returned.
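/// Example (illustrative): for
///   #tiled_2d_128x256 = affine_map<(d0, d1)
///     -> (d0 floordiv 128, d1 floordiv 256, d0 mod 128, d1 mod 256)>
/// `tileSizePos` would hold {128, 0, 2} and {256, 1, 3}: tile size 128 with
/// its floordiv at position 0 and its mod at position 2, and tile size 256
/// with its floordiv at position 1 and its mod at position 3.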
static LogicalResult getTileSizePos(
    AffineMap map,
    SmallVectorImpl<std::tuple<AffineExpr, unsigned, unsigned>> &tileSizePos) {
  // Create `floordivExprs`, a vector of tuples holding the LHS and RHS of
  // each `floordiv` and its position in the `map` output.
  // Example: #tiled_2d_128x256 = affine_map<(d0, d1)
  //   -> (d0 floordiv 128, d1 floordiv 256, d0 mod 128, d1 mod 256)>
  // In this example, `floordivExprs` includes {d0, 128, 0} and {d1, 256, 1}.
  SmallVector<std::tuple<AffineExpr, AffineExpr, unsigned>, 4> floordivExprs;
  unsigned pos = 0;
  for (AffineExpr expr : map.getResults()) {
    if (expr.getKind() == AffineExprKind::FloorDiv) {
      AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
      if (binaryExpr.getRHS().isa<AffineConstantExpr>())
        floordivExprs.emplace_back(
            std::make_tuple(binaryExpr.getLHS(), binaryExpr.getRHS(), pos));
    }
    pos++;
  }
  // Not a tiled layout if `floordivExprs` is empty.
  if (floordivExprs.empty()) {
    tileSizePos = SmallVector<std::tuple<AffineExpr, unsigned, unsigned>>{};
    return success();
  }

  // Check if the LHS of each `floordiv` is also used in the LHS of a `mod`.
  // If not, `map` is not a tiled layout.
  for (std::tuple<AffineExpr, AffineExpr, unsigned> fexpr : floordivExprs) {
    AffineExpr floordivExprLHS = std::get<0>(fexpr);
    AffineExpr floordivExprRHS = std::get<1>(fexpr);
    unsigned floordivPos = std::get<2>(fexpr);

    // Walk each affine expr of the `map` output except `fexpr`, and check if
    // the LHS and RHS of `fexpr` are used in the LHS and RHS of a `mod`. If
    // the LHS of `fexpr` is used in any other expr, the map is not a tiled
    // layout. Examples of non-tiled layouts:
    //   affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2 floordiv 256)>
    //   affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2 mod 128)>
    //   affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2 mod 256,
    //                               d2 mod 256)>
    bool found = false;
    pos = 0;
    for (AffineExpr expr : map.getResults()) {
      bool notTiled = false;
      if (pos != floordivPos) {
        expr.walk([&](AffineExpr e) {
          if (e == floordivExprLHS) {
            if (expr.getKind() == AffineExprKind::Mod) {
              AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
              // Check if the LHS and RHS of the `mod` are the same as those
              // of the `floordiv`.
              if (floordivExprLHS == binaryExpr.getLHS() &&
                  floordivExprRHS == binaryExpr.getRHS()) {
                // Save the tile size (RHS of the `mod`) and the positions of
                // the `floordiv` and `mod` if no matching `mod` has been
                // found yet.
                if (!found) {
                  tileSizePos.emplace_back(
                      std::make_tuple(binaryExpr.getRHS(), floordivPos, pos));
                  found = true;
                } else {
                  // Not a tiled layout: multiple `mod`s with the same LHS.
                  // e.g. affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256,
                  //                                  d2 mod 256, d2 mod 256)>
                  notTiled = true;
                }
              } else {
                // Not a tiled layout: RHS of the `mod` differs from that of
                // the `floordiv`.
                // e.g. affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256,
                //                                  d2 mod 128)>
                notTiled = true;
              }
            } else {
              // Not a tiled layout: LHS is the same, but the expr is not a
              // `mod`.
              // e.g. affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256,
              //                                  d2 floordiv 256)>
              notTiled = true;
            }
          }
        });
      }
      if (notTiled) {
        tileSizePos = SmallVector<std::tuple<AffineExpr, unsigned, unsigned>>{};
        return success();
      }
      pos++;
    }
  }
  return success();
}

/// Check if the `dim` dimension of a memrefType with `layoutMap` becomes
/// dynamic after normalization. Dimensions that include dynamic dimensions in
/// the map output will become dynamic dimensions. Returns true if `dim` is a
/// dynamic dimension.
///
/// Example:
///   #map0 = affine_map<(d0, d1) -> (d0, d1 floordiv 32, d1 mod 32)>
///
///   If d1 is a dynamic dimension, the 2nd and 3rd dimensions of the map
///   output are dynamic: memref<4x?xf32, #map0> ==> memref<4x?x?xf32>
static bool
isNormalizedMemRefDynamicDim(unsigned dim, AffineMap layoutMap,
                             SmallVectorImpl<unsigned> &inMemrefTypeDynDims,
                             MLIRContext *context) {
  bool isDynamicDim = false;
  AffineExpr expr = layoutMap.getResults()[dim];
  // Check if the affine expr of this dimension includes a dynamic dimension
  // of the input memrefType.
  expr.walk([&inMemrefTypeDynDims, &isDynamicDim, &context](AffineExpr e) {
    if (e.isa<AffineDimExpr>()) {
      for (unsigned dm : inMemrefTypeDynDims) {
        if (e == getAffineDimExpr(dm, context))
          isDynamicDim = true;
      }
    }
  });
  return isDynamicDim;
}

/// Create an affine expr to calculate the dimension size for a tiled-layout
/// map.
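/// For example (tile size 32): "d0 floordiv 32" becomes "d0 ceildiv 32"
/// (TileFloorDiv), and "d0 mod 32" becomes "32" (TileMod).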
static AffineExpr createDimSizeExprForTiledLayout(AffineExpr oldMapOutput,
                                                  TileExprPattern pat) {
  // Create the map output for the patterns:
  //   "floordiv <tile size>" ==> "ceildiv <tile size>"
  //   "mod <tile size>"      ==> "<tile size>"
  AffineExpr newMapOutput;
  AffineBinaryOpExpr binaryExpr = nullptr;
  switch (pat) {
  case TileExprPattern::TileMod:
    binaryExpr = oldMapOutput.cast<AffineBinaryOpExpr>();
    newMapOutput = binaryExpr.getRHS();
    break;
  case TileExprPattern::TileFloorDiv:
    binaryExpr = oldMapOutput.cast<AffineBinaryOpExpr>();
    newMapOutput = getAffineBinaryOpExpr(
        AffineExprKind::CeilDiv, binaryExpr.getLHS(), binaryExpr.getRHS());
    break;
  default:
    newMapOutput = oldMapOutput;
  }
  return newMapOutput;
}

/// Create new maps to calculate each dimension size of `newMemRefType`, and
/// create `newDynamicSizes` from them by using AffineApplyOp.
///
/// Steps for normalizing dynamic memrefs for a tiled layout map
/// Example:
///   #map0 = affine_map<(d0, d1) -> (d0, d1 floordiv 32, d1 mod 32)>
///   %0 = dim %arg0, %c1 : memref<4x?xf32>
///   %1 = alloc(%0) : memref<4x?xf32, #map0>
///
/// (Before this function)
/// 1. Check if `map`(#map0) is a tiled layout using `getTileSizePos()`. Only
///    a single layout map is supported.
///
/// 2. Create a normalized memrefType using `isNormalizedMemRefDynamicDim()`.
///    It is memref<4x?x?xf32> in the above example.
///
/// (In this function)
/// 3. Create new maps to calculate each dimension of the normalized
///    memrefType using `createDimSizeExprForTiledLayout()`. In a tiled
///    layout, the dimension size can be calculated by replacing
///    "floordiv <tile size>" with "ceildiv <tile size>" and "mod <tile size>"
///    with "<tile size>".
///    - New maps in the above example:
///      #size0 = affine_map<(d0, d1) -> (d0)>
///      #size1 = affine_map<(d0, d1) -> (d1 ceildiv 32)>
///      #size2 = affine_map<(d0, d1) -> (32)>
///
/// 4. Create an AffineApplyOp for each map of a dynamic dimension. The
///    outputs of the AffineApplyOps are used as the dynamicSizes of the new
///    AllocOp.
///      %0 = dim %arg0, %c1 : memref<4x?xf32>
///      %c4 = constant 4 : index
///      %1 = affine.apply #size1(%c4, %0)
///      %2 = affine.apply #size2(%c4, %0)
static void createNewDynamicSizes(MemRefType oldMemRefType,
                                  MemRefType newMemRefType, AffineMap map,
                                  memref::AllocOp *allocOp, OpBuilder b,
                                  SmallVectorImpl<Value> &newDynamicSizes) {
  // Create the new inputs for the AffineApplyOps.
  SmallVector<Value, 4> inAffineApply;
  ArrayRef<int64_t> oldMemRefShape = oldMemRefType.getShape();
  unsigned dynIdx = 0;
  for (unsigned d = 0; d < oldMemRefType.getRank(); ++d) {
    if (oldMemRefShape[d] < 0) {
      // Use the dynamicSizes of allocOp for dynamic dimensions.
      inAffineApply.emplace_back(allocOp->dynamicSizes()[dynIdx]);
      dynIdx++;
    } else {
      // Create a ConstantOp for static dimensions.
      Attribute constantAttr =
          b.getIntegerAttr(b.getIndexType(), oldMemRefShape[d]);
      inAffineApply.emplace_back(
          b.create<ConstantOp>(allocOp->getLoc(), constantAttr));
    }
  }

  // Create a new map to calculate each dimension size of the new memref for
  // each original map output, but only for the dynamic dimensions of
  // `newMemRefType`.
  unsigned newDimIdx = 0;
  ArrayRef<int64_t> newMemRefShape = newMemRefType.getShape();
  SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos;
  (void)getTileSizePos(map, tileSizePos);
  for (AffineExpr expr : map.getResults()) {
    if (newMemRefShape[newDimIdx] < 0) {
      // Create new maps to calculate each dimension size of the new memref.
      enum TileExprPattern pat = TileExprPattern::TileNone;
      for (auto pos : tileSizePos) {
        if (newDimIdx == std::get<1>(pos))
          pat = TileExprPattern::TileFloorDiv;
        else if (newDimIdx == std::get<2>(pos))
          pat = TileExprPattern::TileMod;
      }
      AffineExpr newMapOutput = createDimSizeExprForTiledLayout(expr, pat);
      AffineMap newMap =
          AffineMap::get(map.getNumInputs(), map.getNumSymbols(), newMapOutput);
      Value affineApp =
          b.create<AffineApplyOp>(allocOp->getLoc(), newMap, inAffineApply);
      newDynamicSizes.emplace_back(affineApp);
    }
    newDimIdx++;
  }
}

// TODO: Currently works for static memrefs with a single layout map.
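// For example (illustrative), an alloc of
// memref<64xf32, (d0) -> (d0 floordiv 32, d0 mod 32)> is rewritten to a new
// alloc of type memref<2x32xf32>, with all accesses remapped through the old
// layout map.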
LogicalResult mlir::normalizeMemRef(memref::AllocOp *allocOp) {
  MemRefType memrefType = allocOp->getType();
  OpBuilder b(*allocOp);

  // Fetch a new memref type after normalizing the old memref to have an
  // identity map layout.
  MemRefType newMemRefType =
      normalizeMemRefType(memrefType, b, allocOp->symbolOperands().size());
  if (newMemRefType == memrefType)
    // Either memrefType already had an identity map or the map couldn't be
    // transformed to an identity map.
    return failure();

  Value oldMemRef = allocOp->getResult();

  SmallVector<Value, 4> symbolOperands(allocOp->symbolOperands());
  AffineMap layoutMap = memrefType.getAffineMaps().front();
  memref::AllocOp newAlloc;
  // Check if `layoutMap` is a tiled layout. Only a single layout map is
  // supported for normalizing dynamic memrefs.
  SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos;
  (void)getTileSizePos(layoutMap, tileSizePos);
  if (newMemRefType.getNumDynamicDims() > 0 && !tileSizePos.empty()) {
    MemRefType oldMemRefType = oldMemRef.getType().cast<MemRefType>();
    SmallVector<Value, 4> newDynamicSizes;
    createNewDynamicSizes(oldMemRefType, newMemRefType, layoutMap, allocOp, b,
                          newDynamicSizes);
    // Add the new dynamic sizes in the new AllocOp.
    newAlloc =
        b.create<memref::AllocOp>(allocOp->getLoc(), newMemRefType,
                                  newDynamicSizes, allocOp->alignmentAttr());
  } else {
    newAlloc = b.create<memref::AllocOp>(allocOp->getLoc(), newMemRefType,
                                         allocOp->alignmentAttr());
  }
  // Replace all uses of the old memref.
  if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc,
                                      /*extraIndices=*/{},
                                      /*indexRemap=*/layoutMap,
                                      /*extraOperands=*/{},
                                      /*symbolOperands=*/symbolOperands,
                                      /*domInstFilter=*/nullptr,
                                      /*postDomInstFilter=*/nullptr,
                                      /*allowNonDereferencingOps=*/true))) {
    // If it failed (due to escapes for example), bail out.
    newAlloc.erase();
    return failure();
  }
  // Replace any uses of the original alloc op and erase it. All remaining
  // uses have to be dealloc's; RAMUW above would've failed otherwise.
  assert(llvm::all_of(oldMemRef.getUsers(), [](Operation *op) {
    return isa<memref::DeallocOp>(op);
  }));
  oldMemRef.replaceAllUsesWith(newAlloc);
  allocOp->erase();
  return success();
}

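/// Normalizes `memrefType` by turning its affine layout map into an identity
/// map and adjusting the shape accordingly. For example (illustrative):
///   #tile = affine_map<(d0, d1) -> (d0 floordiv 32, d1 floordiv 32,
///                                   d0 mod 32, d1 mod 32)>
///   memref<64x128xf32, #tile> ==> memref<2x4x32x32xf32>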
MemRefType mlir::normalizeMemRefType(MemRefType memrefType, OpBuilder b,
                                     unsigned numSymbolicOperands) {
  unsigned rank = memrefType.getRank();
  if (rank == 0)
    return memrefType;

  ArrayRef<AffineMap> layoutMaps = memrefType.getAffineMaps();
  if (layoutMaps.empty() ||
      layoutMaps.front() == b.getMultiDimIdentityMap(rank)) {
    // Either no map is associated with this memref or this memref has a
    // trivial (identity) map.
    return memrefType;
  }

  // We don't do any checks for one-to-one'ness; we assume that it is
  // one-to-one.

  // Normalize only static memrefs and dynamic memrefs with a tiled-layout map
  // for now.
  // TODO: Normalize the other types of dynamic memrefs.
  SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos;
  (void)getTileSizePos(layoutMaps.front(), tileSizePos);
  if (memrefType.getNumDynamicDims() > 0 && tileSizePos.empty())
    return memrefType;

  // We have a single map that is not an identity map. Create a new memref
  // with the right shape and an identity layout map.
  ArrayRef<int64_t> shape = memrefType.getShape();
  // The FlatAffineConstraints may later on use the symbolic operands.
  FlatAffineConstraints fac(rank, numSymbolicOperands);
  SmallVector<unsigned, 4> memrefTypeDynDims;
  for (unsigned d = 0; d < rank; ++d) {
    // Use the constraint system only for static dimensions.
    if (shape[d] > 0) {
      fac.addConstantLowerBound(d, 0);
      fac.addConstantUpperBound(d, shape[d] - 1);
    } else {
      memrefTypeDynDims.emplace_back(d);
    }
  }
  // We compose this map with the original index (logical) space to derive
  // the upper bounds for the new index space.
  AffineMap layoutMap = layoutMaps.front();
  unsigned newRank = layoutMap.getNumResults();
  if (failed(fac.composeMatchingMap(layoutMap)))
    return memrefType;
  // TODO: Handle semi-affine maps.
  // Project out the old data dimensions.
  fac.projectOut(newRank, fac.getNumIds() - newRank - fac.getNumLocalIds());
  SmallVector<int64_t, 4> newShape(newRank);
  for (unsigned d = 0; d < newRank; ++d) {
    // Check whether this dimension of the normalized memrefType is dynamic.
    bool isDynDim = isNormalizedMemRefDynamicDim(d, layoutMap,
                                                 memrefTypeDynDims,
                                                 b.getContext());
    if (isDynDim) {
      // Dynamic dimensions of the new memrefType get the sentinel value -1.
      newShape[d] = -1;
    } else {
      // The lower bound for the shape is always zero.
      auto ubConst = fac.getConstantUpperBound(d);
      // For a static memref and an affine map with no symbols, this is
      // always bounded.
      assert(ubConst.hasValue() && "should always have an upper bound");
      if (ubConst.getValue() < 0)
        // This is due to an invalid map that maps to a negative space.
        return memrefType;
      newShape[d] = ubConst.getValue() + 1;
    }
  }

  // Create the new memref type after trivializing the old layout map.
  MemRefType newMemRefType =
      MemRefType::Builder(memrefType)
          .setShape(newShape)
          .setAffineMaps(b.getMultiDimIdentityMap(newRank));

  return newMemRefType;
}