//===- Utils.cpp ---- Misc utilities for code and data transformation -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements miscellaneous transformation routines for non-loop IR
// structures.
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/Utils.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
// Perform the replacement in `op`.
LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
                                             Operation *op,
                                             ArrayRef<Value> extraIndices,
                                             AffineMap indexRemap,
                                             ArrayRef<Value> extraOperands,
                                             ArrayRef<Value> symbolOperands,
                                             bool allowNonDereferencingOps) {
  unsigned newMemRefRank = newMemRef.getType().cast<MemRefType>().getRank();
  (void)newMemRefRank; // unused in opt mode
  unsigned oldMemRefRank = oldMemRef.getType().cast<MemRefType>().getRank();
  (void)oldMemRefRank; // unused in opt mode
  if (indexRemap) {
    assert(indexRemap.getNumSymbols() == symbolOperands.size() &&
           "symbolic operand count mismatch");
    assert(indexRemap.getNumInputs() ==
           extraOperands.size() + oldMemRefRank + symbolOperands.size());
    assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
  } else {
    assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
  }

  // Assert same elemental type.
  assert(oldMemRef.getType().cast<MemRefType>().getElementType() ==
         newMemRef.getType().cast<MemRefType>().getElementType());

  SmallVector<unsigned, 2> usePositions;
  for (const auto &opEntry : llvm::enumerate(op->getOperands())) {
    if (opEntry.value() == oldMemRef)
      usePositions.push_back(opEntry.index());
  }

  // If the memref doesn't appear, nothing to do.
  if (usePositions.empty())
    return success();

  if (usePositions.size() > 1) {
    // TODO: extend this for the multiple-use case when needed (rare).
    assert(false && "multiple dereferencing uses in a single op not supported");
    return failure();
  }

  unsigned memRefOperandPos = usePositions.front();

  OpBuilder builder(op);
  // The following checks if op is dereferencing the memref and performs the
  // access index rewrites.
  auto affMapAccInterface = dyn_cast<AffineMapAccessInterface>(op);
  if (!affMapAccInterface) {
    if (!allowNonDereferencingOps) {
      // Failure: memref used in a non-dereferencing context (potentially
      // escapes); no replacement in these cases unless
      // allowNonDereferencingOps is set.
      return failure();
    }
    op->setOperand(memRefOperandPos, newMemRef);
    return success();
  }
  // Perform index rewrites for the dereferencing op and then replace the op.
  NamedAttribute oldMapAttrPair =
      affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef);
  AffineMap oldMap = oldMapAttrPair.second.cast<AffineMapAttr>().getValue();
  unsigned oldMapNumInputs = oldMap.getNumInputs();
  SmallVector<Value, 4> oldMapOperands(
      op->operand_begin() + memRefOperandPos + 1,
      op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);

  // Apply 'oldMemRefOperands = oldMap(oldMapOperands)'.
  SmallVector<Value, 4> oldMemRefOperands;
  SmallVector<Value, 4> affineApplyOps;
  oldMemRefOperands.reserve(oldMemRefRank);
  if (oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) {
    for (auto resultExpr : oldMap.getResults()) {
      auto singleResMap = AffineMap::get(oldMap.getNumDims(),
                                         oldMap.getNumSymbols(), resultExpr);
      auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
                                                oldMapOperands);
      oldMemRefOperands.push_back(afOp);
      affineApplyOps.push_back(afOp);
    }
  } else {
    oldMemRefOperands.assign(oldMapOperands.begin(), oldMapOperands.end());
  }

  // Construct new indices as a remap of the old ones if a remapping has been
  // provided. The indices of a memref come right after it, i.e.,
  // at position memRefOperandPos + 1.
  SmallVector<Value, 4> remapOperands;
  remapOperands.reserve(extraOperands.size() + oldMemRefRank +
                        symbolOperands.size());
  remapOperands.append(extraOperands.begin(), extraOperands.end());
  remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end());
  remapOperands.append(symbolOperands.begin(), symbolOperands.end());

  SmallVector<Value, 4> remapOutputs;
  remapOutputs.reserve(oldMemRefRank);

  if (indexRemap &&
      indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
    // Remapped indices.
    for (auto resultExpr : indexRemap.getResults()) {
      auto singleResMap = AffineMap::get(
          indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr);
      auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
                                                remapOperands);
      remapOutputs.push_back(afOp);
      affineApplyOps.push_back(afOp);
    }
  } else {
    // No remapping specified.
    remapOutputs.assign(remapOperands.begin(), remapOperands.end());
  }

  SmallVector<Value, 4> newMapOperands;
  newMapOperands.reserve(newMemRefRank);

  // Prepend 'extraIndices' in 'newMapOperands'.
  for (Value extraIndex : extraIndices) {
    assert(extraIndex.getDefiningOp()->getNumResults() == 1 &&
           "single-result ops expected to generate these indices");
    assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) &&
           "invalid memory op index");
    newMapOperands.push_back(extraIndex);
  }

  // Append 'remapOutputs' to 'newMapOperands'.
  newMapOperands.append(remapOutputs.begin(), remapOutputs.end());

  // Create a new fully composed AffineMap for the new op to be created.
  assert(newMapOperands.size() == newMemRefRank);
  auto newMap = builder.getMultiDimIdentityMap(newMemRefRank);
  // TODO: Avoid creating/deleting temporary AffineApplyOps here.
  fullyComposeAffineMapAndOperands(&newMap, &newMapOperands);
  newMap = simplifyAffineMap(newMap);
  canonicalizeMapAndOperands(&newMap, &newMapOperands);
  // Remove any affine.apply's that became dead as a result of composition.
  for (Value value : affineApplyOps)
    if (value.use_empty())
      value.getDefiningOp()->erase();

  OperationState state(op->getLoc(), op->getName());
  // Construct the new operation using this memref.
  state.operands.reserve(op->getNumOperands() + extraIndices.size());
  // Insert the non-memref operands.
  state.operands.append(op->operand_begin(),
                        op->operand_begin() + memRefOperandPos);
  // Insert the new memref value.
  state.operands.push_back(newMemRef);

  // Insert the new memref map operands.
  state.operands.append(newMapOperands.begin(), newMapOperands.end());

  // Insert the remaining operands unmodified.
  state.operands.append(op->operand_begin() + memRefOperandPos + 1 +
                            oldMapNumInputs,
                        op->operand_end());

  // Result types don't change. Both memrefs are of the same elemental type.
  state.types.reserve(op->getNumResults());
  for (auto result : op->getResults())
    state.types.push_back(result.getType());

  // Add the attribute for 'newMap'; other attributes do not change.
  auto newMapAttr = AffineMapAttr::get(newMap);
  for (auto namedAttr : op->getAttrs()) {
    if (namedAttr.first == oldMapAttrPair.first)
      state.attributes.push_back({namedAttr.first, newMapAttr});
    else
      state.attributes.push_back(namedAttr);
  }

  // Create the new operation.
  auto *repOp = builder.createOperation(state);
  op->replaceAllUsesWith(repOp);
  op->erase();

  return success();
}
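
// A minimal illustrative sketch (not part of this file's logic): remapping a
// 1-d memref access onto a 2-d tiled buffer. `loadOp`, `oldMemRef`, and
// `newMemRef` are hypothetical values; the remap assumes a
// memref<64xf32> -> memref<4x16xf32> tiling.
//
//   MLIRContext *ctx = loadOp->getContext();
//   AffineExpr d0 = getAffineDimExpr(0, ctx);
//   // indexRemap: (d0) -> (d0 floordiv 16, d0 mod 16)
//   AffineMap indexRemap = AffineMap::get(
//       /*dimCount=*/1, /*symbolCount=*/0, {d0.floorDiv(16), d0 % 16}, ctx);
//   if (failed(replaceAllMemRefUsesWith(oldMemRef, newMemRef, loadOp,
//                                       /*extraIndices=*/{}, indexRemap,
//                                       /*extraOperands=*/{},
//                                       /*symbolOperands=*/{},
//                                       /*allowNonDereferencingOps=*/false)))
//     ...; // e.g., bail out on escaping uses.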

LogicalResult mlir::replaceAllMemRefUsesWith(
    Value oldMemRef, Value newMemRef, ArrayRef<Value> extraIndices,
    AffineMap indexRemap, ArrayRef<Value> extraOperands,
    ArrayRef<Value> symbolOperands, Operation *domInstFilter,
    Operation *postDomInstFilter, bool allowNonDereferencingOps,
    bool replaceInDeallocOp) {
  unsigned newMemRefRank = newMemRef.getType().cast<MemRefType>().getRank();
  (void)newMemRefRank; // unused in opt mode
  unsigned oldMemRefRank = oldMemRef.getType().cast<MemRefType>().getRank();
  (void)oldMemRefRank; // unused in opt mode
  if (indexRemap) {
    assert(indexRemap.getNumSymbols() == symbolOperands.size() &&
           "symbol operand count mismatch");
    assert(indexRemap.getNumInputs() ==
           extraOperands.size() + oldMemRefRank + symbolOperands.size());
    assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
  } else {
    assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
  }

  // Assert same elemental type.
  assert(oldMemRef.getType().cast<MemRefType>().getElementType() ==
         newMemRef.getType().cast<MemRefType>().getElementType());

  std::unique_ptr<DominanceInfo> domInfo;
  std::unique_ptr<PostDominanceInfo> postDomInfo;
  if (domInstFilter)
    domInfo = std::make_unique<DominanceInfo>(
        domInstFilter->getParentOfType<FuncOp>());

  if (postDomInstFilter)
    postDomInfo = std::make_unique<PostDominanceInfo>(
        postDomInstFilter->getParentOfType<FuncOp>());

  // Walk all uses of the old memref; collect the ops to perform replacement
  // on. We use a DenseSet since an operation could potentially have multiple
  // uses of a memref (although rare), and the replacement later is going to
  // erase ops.
  DenseSet<Operation *> opsToReplace;
  for (auto *op : oldMemRef.getUsers()) {
    // Skip this use if it's not dominated by domInstFilter.
    if (domInstFilter && !domInfo->dominates(domInstFilter, op))
      continue;

    // Skip this use if it's not post-dominated by postDomInstFilter.
    if (postDomInstFilter && !postDomInfo->postDominates(postDomInstFilter, op))
      continue;

    // Skip dealloc's - no replacement is necessary, and a memref replacement
    // at other uses doesn't hurt these dealloc's.
    if (isa<memref::DeallocOp>(op) && !replaceInDeallocOp)
      continue;

    // Check if the memref was used in a non-dereferencing context. It is fine
    // for the memref to be used in a non-dereferencing way outside of the
    // region where this replacement is happening.
    if (!isa<AffineMapAccessInterface>(*op)) {
      if (!allowNonDereferencingOps)
        return failure();
      // Currently we support the following non-dereferencing ops as
      // candidates for replacement: Dealloc, CallOp, and ReturnOp.
      // TODO: Add support for other kinds of ops.
      if (!op->hasTrait<OpTrait::MemRefsNormalizable>())
        return failure();
    }

    // Collect first and replace later, since replacement erases the op that
    // has the use, and that op could be the postDomFilter or domFilter itself!
    opsToReplace.insert(op);
  }

  for (auto *op : opsToReplace) {
    if (failed(replaceAllMemRefUsesWith(
            oldMemRef, newMemRef, op, extraIndices, indexRemap, extraOperands,
            symbolOperands, allowNonDereferencingOps)))
      llvm_unreachable("memref replacement guaranteed to succeed here");
  }

  return success();
}
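
// Note: a typical client of the dominance filters is affine data-copy
// generation, which (as a sketch) passes the copy-in op as `domInstFilter`
// and the copy-out op as `postDomInstFilter`, so that only the accesses
// sandwiched between the two copies are rewritten to use the fast buffer.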

/// Given an operation, inserts one or more single-result affine apply
/// operations, results of which are exclusively used by this operation. The
/// operands of these newly created affine apply ops are guaranteed to be loop
/// iterators or terminal symbols of a function.
///
/// Before
///
/// affine.for %i = 0 to #map(%N)
///   %idx = affine.apply (d0) -> (d0 mod 2) (%i)
///   "send"(%idx, %A, ...)
///   "compute"(%idx)
///
/// After
///
/// affine.for %i = 0 to #map(%N)
///   %idx = affine.apply (d0) -> (d0 mod 2) (%i)
///   "send"(%idx, %A, ...)
///   %idx_ = affine.apply (d0) -> (d0 mod 2) (%i)
///   "compute"(%idx_)
///
/// This allows applying different transformations on send and compute (e.g.,
/// different shifts/delays).
///
/// Leaves `sliceOps` empty either if none of opInst's operands were the result
/// of an affine.apply (and thus there was no affine computation slice to
/// create), or if all the affine.apply ops supplying operands to opInst had no
/// uses besides opInst; otherwise, returns the list of affine.apply operations
/// created in the output argument `sliceOps`.
void mlir::createAffineComputationSlice(
    Operation *opInst, SmallVectorImpl<AffineApplyOp> *sliceOps) {
  // Collect all operands that are results of affine apply ops.
  SmallVector<Value, 4> subOperands;
  subOperands.reserve(opInst->getNumOperands());
  for (auto operand : opInst->getOperands())
    if (isa_and_nonnull<AffineApplyOp>(operand.getDefiningOp()))
      subOperands.push_back(operand);

  // Gather the sequence of AffineApplyOps reachable from 'subOperands'.
  SmallVector<Operation *, 4> affineApplyOps;
  getReachableAffineApplyOps(subOperands, affineApplyOps);
  // Skip transforming if there are no affine maps to compose.
  if (affineApplyOps.empty())
    return;

  // Check if all uses of the affine apply ops' results lie only in this op, in
  // which case there would be nothing to do.
  bool localized = true;
  for (auto *op : affineApplyOps) {
    for (auto result : op->getResults()) {
      for (auto *user : result.getUsers()) {
        if (user != opInst) {
          localized = false;
          break;
        }
      }
    }
  }
  if (localized)
    return;

  OpBuilder builder(opInst);
  SmallVector<Value, 4> composedOpOperands(subOperands);
  auto composedMap = builder.getMultiDimIdentityMap(composedOpOperands.size());
  fullyComposeAffineMapAndOperands(&composedMap, &composedOpOperands);

  // Create an affine.apply for each of the map results.
  sliceOps->reserve(composedMap.getNumResults());
  for (auto resultExpr : composedMap.getResults()) {
    auto singleResMap = AffineMap::get(composedMap.getNumDims(),
                                       composedMap.getNumSymbols(), resultExpr);
    sliceOps->push_back(builder.create<AffineApplyOp>(
        opInst->getLoc(), singleResMap, composedOpOperands));
  }

  // Construct the new operands, which include the results from the composed
  // affine apply ops above instead of the existing ones (subOperands). So,
  // they differ from opInst's operands only for those operands in
  // 'subOperands', which are replaced by the corresponding one from
  // 'sliceOps'.
  SmallVector<Value, 4> newOperands(opInst->getOperands());
  for (unsigned i = 0, e = newOperands.size(); i < e; i++) {
    // Replace the subOperands from among the new operands.
    unsigned j, f;
    for (j = 0, f = subOperands.size(); j < f; j++) {
      if (newOperands[i] == subOperands[j])
        break;
    }
    if (j < subOperands.size())
      newOperands[i] = (*sliceOps)[j];
  }
  for (unsigned idx = 0, e = newOperands.size(); idx < e; idx++)
    opInst->setOperand(idx, newOperands[idx]);
}
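
// Illustrative use (hypothetical `dmaStartOp`): isolate the index computation
// feeding a DMA so it can later be shifted independently of other users of
// the same affine.apply results:
//
//   SmallVector<AffineApplyOp, 4> sliceOps;
//   createAffineComputationSlice(dmaStartOp, &sliceOps);
//   for (AffineApplyOp applyOp : sliceOps) {
//     // ... shift/delay applyOp as needed ...
//   }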

/// Enum to set patterns of affine expr in tiled-layout map.
/// TileFloorDiv: <dim expr> floordiv <tile size>
/// TileMod: <dim expr> mod <tile size>
/// TileNone: None of the above
/// Example:
/// #tiled_2d_128x256 = affine_map<(d0, d1)
///            -> (d0 floordiv 128, d1 floordiv 256, d0 mod 128, d1 mod 256)>
/// "d0 floordiv 128" and "d1 floordiv 256" ==> TileFloorDiv
/// "d0 mod 128" and "d1 mod 256" ==> TileMod
enum TileExprPattern { TileFloorDiv, TileMod, TileNone };

/// Check if `map` is a tiled layout. In a tiled layout, specific k dimensions
/// that are floordiv'ed by their respective tile sizes also appear in a mod
/// with the same tile sizes, and no other expression involves those k
/// dimensions. This function stores a vector of tuples (`tileSizePos`) holding
/// the AffineExpr for the tile size and the positions of the corresponding
/// `floordiv` and `mod`. If `map` is not a tiled layout, an empty vector is
/// returned.
static LogicalResult getTileSizePos(
    AffineMap map,
    SmallVectorImpl<std::tuple<AffineExpr, unsigned, unsigned>> &tileSizePos) {
  // Create `floordivExprs`, a vector of tuples holding the LHS and RHS of each
  // `floordiv` and its position in the `map` output.
  // Example: #tiled_2d_128x256 = affine_map<(d0, d1)
  //     -> (d0 floordiv 128, d1 floordiv 256, d0 mod 128, d1 mod 256)>
  // In this example, `floordivExprs` includes {d0, 128, 0} and {d1, 256, 1}.
  SmallVector<std::tuple<AffineExpr, AffineExpr, unsigned>, 4> floordivExprs;
  unsigned pos = 0;
  for (AffineExpr expr : map.getResults()) {
    if (expr.getKind() == AffineExprKind::FloorDiv) {
      AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
      if (binaryExpr.getRHS().isa<AffineConstantExpr>())
        floordivExprs.emplace_back(
            std::make_tuple(binaryExpr.getLHS(), binaryExpr.getRHS(), pos));
    }
    pos++;
  }
  // Not a tiled layout if `floordivExprs` is empty.
  if (floordivExprs.empty()) {
    tileSizePos = SmallVector<std::tuple<AffineExpr, unsigned, unsigned>>{};
    return success();
  }

  // Check if the LHS of each `floordiv` is used in the LHS of a matching
  // `mod`. If not, `map` is not a tiled layout.
  for (std::tuple<AffineExpr, AffineExpr, unsigned> fexpr : floordivExprs) {
    AffineExpr floordivExprLHS = std::get<0>(fexpr);
    AffineExpr floordivExprRHS = std::get<1>(fexpr);
    unsigned floordivPos = std::get<2>(fexpr);

    // Walk the affine exprs of the `map` output except `fexpr`, and check if
    // the LHS and RHS of `fexpr` are used in the LHS and RHS of a `mod`. If
    // the LHS of `fexpr` is used in any other expr, the map is not a tiled
    // layout. Examples of non-tiled layouts:
    //   affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2 floordiv 256)>
    //   affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2 mod 128)>
    //   affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2 mod 256, d2 mod
    //   256)>
    bool found = false;
    pos = 0;
    for (AffineExpr expr : map.getResults()) {
      bool notTiled = false;
      if (pos != floordivPos) {
        expr.walk([&](AffineExpr e) {
          if (e == floordivExprLHS) {
            if (expr.getKind() == AffineExprKind::Mod) {
              AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
              // Check if the LHS and RHS of the `mod` are the same as those of
              // the floordiv.
              if (floordivExprLHS == binaryExpr.getLHS() &&
                  floordivExprRHS == binaryExpr.getRHS()) {
                // Save the tile size (RHS of `mod`) and the positions of the
                // `floordiv` and `mod` if a matching `mod` has not been found
                // yet.
                if (!found) {
                  tileSizePos.emplace_back(
                      std::make_tuple(binaryExpr.getRHS(), floordivPos, pos));
                  found = true;
                } else {
                  // Non-tiled layout: multiple `mod`s with the same LHS.
                  // e.g. affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256,
                  // d2 mod 256, d2 mod 256)>
                  notTiled = true;
                }
              } else {
                // Non-tiled layout: RHS of `mod` differs from that of the
                // `floordiv`.
                // e.g. affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256,
                // d2 mod 128)>
                notTiled = true;
              }
            } else {
              // Non-tiled layout: LHS is the same, but the expr is not a
              // `mod`.
              // e.g. affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256,
              // d2 floordiv 256)>
              notTiled = true;
            }
          }
        });
      }
      if (notTiled) {
        tileSizePos = SmallVector<std::tuple<AffineExpr, unsigned, unsigned>>{};
        return success();
      }
      pos++;
    }
  }
  return success();
}
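
// For the running example
//   #tiled_2d_128x256 = affine_map<(d0, d1)
//              -> (d0 floordiv 128, d1 floordiv 256, d0 mod 128, d1 mod 256)>
// this fills `tileSizePos` with the tuples {128, 0, 2} and {256, 1, 3}: the
// tile-size expr, the position of the `floordiv` result, and the position of
// the matching `mod` result.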

/// Check if the `dim` dimension of a memrefType with `layoutMap` becomes
/// dynamic after normalization. Dimensions that include dynamic dimensions in
/// the map output will become dynamic dimensions. Return true if `dim` is a
/// dynamic dimension.
///
/// Example:
/// #map0 = affine_map<(d0, d1) -> (d0, d1 floordiv 32, d1 mod 32)>
///
/// If d1 is a dynamic dimension, the 2nd and 3rd dimensions of the map output
/// are dynamic: memref<4x?xf32, #map0>  ==>  memref<4x?x?xf32>
static bool
isNormalizedMemRefDynamicDim(unsigned dim, AffineMap layoutMap,
                             SmallVectorImpl<unsigned> &inMemrefTypeDynDims,
                             MLIRContext *context) {
  bool isDynamicDim = false;
  AffineExpr expr = layoutMap.getResults()[dim];
  // Check if the affine expr of this dimension includes any dynamic dimension
  // of the input memrefType.
  expr.walk([&inMemrefTypeDynDims, &isDynamicDim, &context](AffineExpr e) {
    if (e.isa<AffineDimExpr>()) {
      for (unsigned dm : inMemrefTypeDynDims) {
        if (e == getAffineDimExpr(dm, context))
          isDynamicDim = true;
      }
    }
  });
  return isDynamicDim;
}

/// Create an affine expr to calculate the dimension size for a tiled-layout
/// map.
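/// For example (illustrative), with a tile size of 32, "d1 floordiv 32"
/// yields the inter-tile extent "d1 ceildiv 32", and "d1 mod 32" yields the
/// constant intra-tile extent "32".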
static AffineExpr createDimSizeExprForTiledLayout(AffineExpr oldMapOutput,
                                                  TileExprPattern pat) {
  // Create the map output for the patterns.
  // "floordiv <tile size>" ==> "ceildiv <tile size>"
  // "mod <tile size>" ==> "<tile size>"
  AffineExpr newMapOutput;
  AffineBinaryOpExpr binaryExpr = nullptr;
  switch (pat) {
  case TileExprPattern::TileMod:
    binaryExpr = oldMapOutput.cast<AffineBinaryOpExpr>();
    newMapOutput = binaryExpr.getRHS();
    break;
  case TileExprPattern::TileFloorDiv:
    binaryExpr = oldMapOutput.cast<AffineBinaryOpExpr>();
    newMapOutput = getAffineBinaryOpExpr(
        AffineExprKind::CeilDiv, binaryExpr.getLHS(), binaryExpr.getRHS());
    break;
  default:
    newMapOutput = oldMapOutput;
  }
  return newMapOutput;
}

/// Create new maps to calculate each dimension size of `newMemRefType`, and
/// create `newDynamicSizes` from them by using AffineApplyOp.
///
/// Steps for normalizing dynamic memrefs for a tiled layout map
/// Example:
///    #map0 = affine_map<(d0, d1) -> (d0, d1 floordiv 32, d1 mod 32)>
///    %0 = dim %arg0, %c1 : memref<4x?xf32>
///    %1 = alloc(%0) : memref<4x?xf32, #map0>
///
/// (Before this function)
/// 1. Check if `map`(#map0) is a tiled layout using `getTileSizePos()`. Only a
/// single layout map is supported.
///
/// 2. Create the normalized memrefType using `isNormalizedMemRefDynamicDim()`.
/// It is memref<4x?x?xf32> in the above example.
///
/// (In this function)
/// 3. Create new maps to calculate each dimension of the normalized memrefType
/// using `createDimSizeExprForTiledLayout()`. In the tiled layout, the
/// dimension size can be calculated by replacing "floordiv <tile size>" with
/// "ceildiv <tile size>" and "mod <tile size>" with "<tile size>".
/// - New maps in the above example
///   #map0 = affine_map<(d0, d1) -> (d0)>
///   #map1 = affine_map<(d0, d1) -> (d1 ceildiv 32)>
///   #map2 = affine_map<(d0, d1) -> (32)>
///
/// 4. Create AffineApplyOps to apply the new maps. The outputs of the
/// AffineApplyOps are used in the dynamicSizes of the new AllocOp.
///   %0 = dim %arg0, %c1 : memref<4x?xf32>
///   %c4 = constant 4 : index
///   %1 = affine.apply #map1(%c4, %0)
///   %2 = affine.apply #map2(%c4, %0)
static void createNewDynamicSizes(MemRefType oldMemRefType,
                                  MemRefType newMemRefType, AffineMap map,
                                  memref::AllocOp *allocOp, OpBuilder b,
                                  SmallVectorImpl<Value> &newDynamicSizes) {
  // Create the new inputs for the AffineApplyOps.
  SmallVector<Value, 4> inAffineApply;
  ArrayRef<int64_t> oldMemRefShape = oldMemRefType.getShape();
  unsigned dynIdx = 0;
  for (unsigned d = 0; d < oldMemRefType.getRank(); ++d) {
    if (oldMemRefShape[d] < 0) {
      // Use the dynamicSizes of the allocOp for dynamic dimensions.
      inAffineApply.emplace_back(allocOp->dynamicSizes()[dynIdx]);
      dynIdx++;
    } else {
      // Create a ConstantOp for static dimensions.
      Attribute constantAttr =
          b.getIntegerAttr(b.getIndexType(), oldMemRefShape[d]);
      inAffineApply.emplace_back(
          b.create<ConstantOp>(allocOp->getLoc(), constantAttr));
    }
  }

  // Create a new map to calculate each dimension size of the new memref for
  // each original map output, but only for the dynamic dimensions of
  // `newMemRefType`.
  unsigned newDimIdx = 0;
  ArrayRef<int64_t> newMemRefShape = newMemRefType.getShape();
  SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos;
  (void)getTileSizePos(map, tileSizePos);
  for (AffineExpr expr : map.getResults()) {
    if (newMemRefShape[newDimIdx] < 0) {
      // Create a new map to calculate this dimension size of the new memref.
      enum TileExprPattern pat = TileExprPattern::TileNone;
      for (auto pos : tileSizePos) {
        if (newDimIdx == std::get<1>(pos))
          pat = TileExprPattern::TileFloorDiv;
        else if (newDimIdx == std::get<2>(pos))
          pat = TileExprPattern::TileMod;
      }
      AffineExpr newMapOutput = createDimSizeExprForTiledLayout(expr, pat);
      AffineMap newMap =
          AffineMap::get(map.getNumInputs(), map.getNumSymbols(), newMapOutput);
      Value affineApp =
          b.create<AffineApplyOp>(allocOp->getLoc(), newMap, inAffineApply);
      newDynamicSizes.emplace_back(affineApp);
    }
    newDimIdx++;
  }
}

// TODO: Currently works for static memrefs with a single layout map.
LogicalResult mlir::normalizeMemRef(memref::AllocOp *allocOp) {
  MemRefType memrefType = allocOp->getType();
  OpBuilder b(*allocOp);

  // Fetch a new memref type after normalizing the old memref to have an
  // identity map layout.
  MemRefType newMemRefType =
      normalizeMemRefType(memrefType, b, allocOp->symbolOperands().size());
  if (newMemRefType == memrefType)
    // Either memrefType already had an identity map or the map couldn't be
    // transformed to an identity map.
    return failure();

  Value oldMemRef = allocOp->getResult();

  SmallVector<Value, 4> symbolOperands(allocOp->symbolOperands());
  AffineMap layoutMap = memrefType.getAffineMaps().front();
  memref::AllocOp newAlloc;
  // Check if `layoutMap` is a tiled layout. Only a single layout map is
  // supported for normalizing dynamic memrefs.
  SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos;
  (void)getTileSizePos(layoutMap, tileSizePos);
  if (newMemRefType.getNumDynamicDims() > 0 && !tileSizePos.empty()) {
    MemRefType oldMemRefType = oldMemRef.getType().cast<MemRefType>();
    SmallVector<Value, 4> newDynamicSizes;
    createNewDynamicSizes(oldMemRefType, newMemRefType, layoutMap, allocOp, b,
                          newDynamicSizes);
    // Add the new dynamic sizes to the new AllocOp.
    newAlloc =
        b.create<memref::AllocOp>(allocOp->getLoc(), newMemRefType,
                                  newDynamicSizes, allocOp->alignmentAttr());
  } else {
    newAlloc = b.create<memref::AllocOp>(allocOp->getLoc(), newMemRefType,
                                         allocOp->alignmentAttr());
  }
  // Replace all uses of the old memref.
  if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc,
                                      /*extraIndices=*/{},
                                      /*indexRemap=*/layoutMap,
                                      /*extraOperands=*/{},
                                      /*symbolOperands=*/symbolOperands,
                                      /*domInstFilter=*/nullptr,
                                      /*postDomInstFilter=*/nullptr,
                                      /*allowNonDereferencingOps=*/true))) {
    // If it failed (due to escapes for example), bail out.
    newAlloc.erase();
    return failure();
  }
  // Replace any remaining uses of the original alloc op and erase it. All
  // remaining uses have to be dealloc's; the replaceAllMemRefUsesWith call
  // above would've failed otherwise.
  assert(llvm::all_of(oldMemRef.getUsers(), [](Operation *op) {
    return isa<memref::DeallocOp>(op);
  }));
  oldMemRef.replaceAllUsesWith(newAlloc);
  allocOp->erase();
  return success();
}
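
// Illustrative effect (assuming the tiled layout
// #map0 = affine_map<(d0, d1) -> (d0, d1 floordiv 32, d1 mod 32)>):
//   %1 = memref.alloc(%0) : memref<4x?xf32, #map0>
// is rewritten to an identity-layout alloc whose dynamic sizes are recomputed
// via affine.apply (ceildiv for the inter-tile dim, the tile size itself for
// the intra-tile dim):
//   %1 = memref.alloc(%t, %s) : memref<4x?x?xf32>
// with all dereferencing uses of the old memref remapped through #map0.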

MemRefType mlir::normalizeMemRefType(MemRefType memrefType, OpBuilder b,
                                     unsigned numSymbolicOperands) {
  unsigned rank = memrefType.getRank();
  if (rank == 0)
    return memrefType;

  ArrayRef<AffineMap> layoutMaps = memrefType.getAffineMaps();
  if (layoutMaps.empty() ||
      layoutMaps.front() == b.getMultiDimIdentityMap(rank)) {
    // Either no map is associated with this memref or this memref has
    // a trivial (identity) map.
    return memrefType;
  }

  // We don't do any checks for one-to-one'ness; we assume that it is
  // one-to-one.

  // Normalize only static memrefs and dynamic memrefs with a tiled-layout map
  // for now.
  // TODO: Normalize the other types of dynamic memrefs.
  SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos;
  (void)getTileSizePos(layoutMaps.front(), tileSizePos);
  if (memrefType.getNumDynamicDims() > 0 && tileSizePos.empty())
    return memrefType;

  // We have a single map that is not an identity map. Create a new memref
  // with the right shape and an identity layout map.
  ArrayRef<int64_t> shape = memrefType.getShape();
  // FlatAffineConstraints may later on use symbolicOperands.
  FlatAffineConstraints fac(rank, numSymbolicOperands);
  SmallVector<unsigned, 4> memrefTypeDynDims;
  for (unsigned d = 0; d < rank; ++d) {
    // Use the constraint system only for static dimensions.
    if (shape[d] > 0) {
      fac.addConstantLowerBound(d, 0);
      fac.addConstantUpperBound(d, shape[d] - 1);
    } else {
      memrefTypeDynDims.emplace_back(d);
    }
  }
  // We compose this map with the original index (logical) space to derive
  // the upper bounds for the new index space.
  AffineMap layoutMap = layoutMaps.front();
  unsigned newRank = layoutMap.getNumResults();
  if (failed(fac.composeMatchingMap(layoutMap)))
    return memrefType;
  // TODO: Handle semi-affine maps.
  // Project out the old data dimensions.
  fac.projectOut(newRank, fac.getNumIds() - newRank - fac.getNumLocalIds());
  SmallVector<int64_t, 4> newShape(newRank);
  for (unsigned d = 0; d < newRank; ++d) {
    // Check if this dimension of the normalized memrefType is dynamic.
    bool isDynDim = isNormalizedMemRefDynamicDim(
        d, layoutMap, memrefTypeDynDims, b.getContext());
    if (isDynDim) {
      // A dynamic dimension of the new memrefType gets the value -1.
      newShape[d] = -1;
    } else {
      // The lower bound for the shape is always zero.
      auto ubConst = fac.getConstantUpperBound(d);
      // For a static memref and an affine map with no symbols, this is
      // always bounded.
      assert(ubConst.hasValue() && "should always have an upper bound");
      if (ubConst.getValue() < 0)
        // This is due to an invalid map that maps to a negative space.
        return memrefType;
      // A static dimension's size is its tightest upper bound plus one.
      newShape[d] = ubConst.getValue() + 1;
    }
  }

  // Create the new memref type after trivializing the old layout map.
  MemRefType newMemRefType =
      MemRefType::Builder(memrefType)
          .setShape(newShape)
          .setAffineMaps(b.getMultiDimIdentityMap(newRank));

  return newMemRefType;
}
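
// Worked example: memref<64xf32, affine_map<(d0) -> (d0 floordiv 16,
// d0 mod 16)>> normalizes to memref<4x16xf32>; composing the layout map with
// 0 <= d0 <= 63 bounds the new dimensions to [0, 3] and [0, 15], so each new
// shape entry is its upper bound plus one.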