1 //===-- AffinePromotion.cpp -----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "flang/Optimizer/Dialect/FIRDialect.h"
11 #include "flang/Optimizer/Dialect/FIROps.h"
12 #include "flang/Optimizer/Dialect/FIRType.h"
13 #include "flang/Optimizer/Transforms/Passes.h"
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/SCF/SCF.h"
16 #include "mlir/Dialect/StandardOps/IR/Ops.h"
17 #include "mlir/IR/BuiltinAttributes.h"
18 #include "mlir/IR/IntegerSet.h"
19 #include "mlir/IR/Visitors.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 #include "llvm/ADT/DenseMap.h"
22 #include "llvm/ADT/Optional.h"
23 #include "llvm/Support/Debug.h"
24 
25 #define DEBUG_TYPE "flang-affine-promotion"
26 
27 using namespace fir;
28 
29 namespace {
30 struct AffineLoopAnalysis;
31 struct AffineIfAnalysis;
32 
33 /// Stores analysis objects for all loops and if operations inside a function
34 ///  these analysis are used twice, first for marking operations for rewrite and
35 ///  second when doing rewrite.
36 struct AffineFunctionAnalysis {
AffineFunctionAnalysis__anond53ad5d60111::AffineFunctionAnalysis37   explicit AffineFunctionAnalysis(mlir::FuncOp funcOp) {
38     for (fir::DoLoopOp op : funcOp.getOps<fir::DoLoopOp>())
39       loopAnalysisMap.try_emplace(op, op, *this);
40   }
41 
42   AffineLoopAnalysis getChildLoopAnalysis(fir::DoLoopOp op) const;
43 
44   AffineIfAnalysis getChildIfAnalysis(fir::IfOp op) const;
45 
46   llvm::DenseMap<mlir::Operation *, AffineLoopAnalysis> loopAnalysisMap;
47   llvm::DenseMap<mlir::Operation *, AffineIfAnalysis> ifAnalysisMap;
48 };
49 } // namespace
50 
analyzeCoordinate(mlir::Value coordinate,mlir::Operation * op)51 static bool analyzeCoordinate(mlir::Value coordinate, mlir::Operation *op) {
52   if (auto blockArg = coordinate.dyn_cast<mlir::BlockArgument>()) {
53     if (isa<fir::DoLoopOp>(blockArg.getOwner()->getParentOp()))
54       return true;
55     LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: array coordinate is not a "
56                                "loop induction variable (owner not loopOp)\n";
57                op->dump());
58     return false;
59   }
60   LLVM_DEBUG(
61       llvm::dbgs() << "AffineLoopAnalysis: array coordinate is not a loop "
62                       "induction variable (not a block argument)\n";
63       op->dump(); coordinate.getDefiningOp()->dump());
64   return false;
65 }
66 
67 namespace {
68 struct AffineLoopAnalysis {
69   AffineLoopAnalysis() = default;
70 
AffineLoopAnalysis__anond53ad5d60211::AffineLoopAnalysis71   explicit AffineLoopAnalysis(fir::DoLoopOp op, AffineFunctionAnalysis &afa)
72       : legality(analyzeLoop(op, afa)) {}
73 
canPromoteToAffine__anond53ad5d60211::AffineLoopAnalysis74   bool canPromoteToAffine() { return legality; }
75 
76 private:
analyzeBody__anond53ad5d60211::AffineLoopAnalysis77   bool analyzeBody(fir::DoLoopOp loopOperation,
78                    AffineFunctionAnalysis &functionAnalysis) {
79     for (auto loopOp : loopOperation.getOps<fir::DoLoopOp>()) {
80       auto analysis = functionAnalysis.loopAnalysisMap
81                           .try_emplace(loopOp, loopOp, functionAnalysis)
82                           .first->getSecond();
83       if (!analysis.canPromoteToAffine())
84         return false;
85     }
86     for (auto ifOp : loopOperation.getOps<fir::IfOp>())
87       functionAnalysis.ifAnalysisMap.try_emplace(ifOp, ifOp, functionAnalysis);
88     return true;
89   }
90 
analyzeLoop__anond53ad5d60211::AffineLoopAnalysis91   bool analyzeLoop(fir::DoLoopOp loopOperation,
92                    AffineFunctionAnalysis &functionAnalysis) {
93     LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: \n"; loopOperation.dump(););
94     return analyzeMemoryAccess(loopOperation) &&
95            analyzeBody(loopOperation, functionAnalysis);
96   }
97 
analyzeReference__anond53ad5d60211::AffineLoopAnalysis98   bool analyzeReference(mlir::Value memref, mlir::Operation *op) {
99     if (auto acoOp = memref.getDefiningOp<ArrayCoorOp>()) {
100       if (acoOp.memref().getType().isa<fir::BoxType>()) {
101         // TODO: Look if and how fir.box can be promoted to affine.
102         LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: cannot promote loop, "
103                                    "array memory operation uses fir.box\n";
104                    op->dump(); acoOp.dump(););
105         return false;
106       }
107       bool canPromote = true;
108       for (auto coordinate : acoOp.indices())
109         canPromote = canPromote && analyzeCoordinate(coordinate, op);
110       return canPromote;
111     }
112     if (auto coOp = memref.getDefiningOp<CoordinateOp>()) {
113       LLVM_DEBUG(llvm::dbgs()
114                      << "AffineLoopAnalysis: cannot promote loop, "
115                         "array memory operation uses non ArrayCoorOp\n";
116                  op->dump(); coOp.dump(););
117 
118       return false;
119     }
120     LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: unknown type of memory "
121                                "reference for array load\n";
122                op->dump(););
123     return false;
124   }
125 
analyzeMemoryAccess__anond53ad5d60211::AffineLoopAnalysis126   bool analyzeMemoryAccess(fir::DoLoopOp loopOperation) {
127     for (auto loadOp : loopOperation.getOps<fir::LoadOp>())
128       if (!analyzeReference(loadOp.memref(), loadOp))
129         return false;
130     for (auto storeOp : loopOperation.getOps<fir::StoreOp>())
131       if (!analyzeReference(storeOp.memref(), storeOp))
132         return false;
133     return true;
134   }
135 
136   bool legality{};
137 };
138 } // namespace
139 
140 AffineLoopAnalysis
getChildLoopAnalysis(fir::DoLoopOp op) const141 AffineFunctionAnalysis::getChildLoopAnalysis(fir::DoLoopOp op) const {
142   auto it = loopAnalysisMap.find_as(op);
143   if (it == loopAnalysisMap.end()) {
144     LLVM_DEBUG(llvm::dbgs() << "AffineFunctionAnalysis: not computed for:\n";
145                op.dump(););
146     op.emitError("error in fetching loop analysis in AffineFunctionAnalysis\n");
147     return {};
148   }
149   return it->getSecond();
150 }
151 
152 namespace {
153 /// Calculates arguments for creating an IntegerSet. symCount, dimCount are the
154 /// final number of symbols and dimensions of the affine map. Integer set if
155 /// possible is in Optional IntegerSet.
156 struct AffineIfCondition {
157   using MaybeAffineExpr = llvm::Optional<mlir::AffineExpr>;
158 
AffineIfCondition__anond53ad5d60311::AffineIfCondition159   explicit AffineIfCondition(mlir::Value fc) : firCondition(fc) {
160     if (auto condDef = firCondition.getDefiningOp<mlir::CmpIOp>())
161       fromCmpIOp(condDef);
162   }
163 
hasIntegerSet__anond53ad5d60311::AffineIfCondition164   bool hasIntegerSet() const { return integerSet.hasValue(); }
165 
getIntegerSet__anond53ad5d60311::AffineIfCondition166   mlir::IntegerSet getIntegerSet() const {
167     assert(hasIntegerSet() && "integer set is missing");
168     return integerSet.getValue();
169   }
170 
getAffineArgs__anond53ad5d60311::AffineIfCondition171   mlir::ValueRange getAffineArgs() const { return affineArgs; }
172 
173 private:
affineBinaryOp__anond53ad5d60311::AffineIfCondition174   MaybeAffineExpr affineBinaryOp(mlir::AffineExprKind kind, mlir::Value lhs,
175                                  mlir::Value rhs) {
176     return affineBinaryOp(kind, toAffineExpr(lhs), toAffineExpr(rhs));
177   }
178 
affineBinaryOp__anond53ad5d60311::AffineIfCondition179   MaybeAffineExpr affineBinaryOp(mlir::AffineExprKind kind, MaybeAffineExpr lhs,
180                                  MaybeAffineExpr rhs) {
181     if (lhs.hasValue() && rhs.hasValue())
182       return mlir::getAffineBinaryOpExpr(kind, lhs.getValue(), rhs.getValue());
183     return {};
184   }
185 
toAffineExpr__anond53ad5d60311::AffineIfCondition186   MaybeAffineExpr toAffineExpr(MaybeAffineExpr e) { return e; }
187 
toAffineExpr__anond53ad5d60311::AffineIfCondition188   MaybeAffineExpr toAffineExpr(int64_t value) {
189     return {mlir::getAffineConstantExpr(value, firCondition.getContext())};
190   }
191 
192   /// Returns an AffineExpr if it is a result of operations that can be done
193   /// in an affine expression, this includes -, +, *, rem, constant.
194   /// block arguments of a loopOp or forOp are used as dimensions
toAffineExpr__anond53ad5d60311::AffineIfCondition195   MaybeAffineExpr toAffineExpr(mlir::Value value) {
196     if (auto op = value.getDefiningOp<mlir::SubIOp>())
197       return affineBinaryOp(mlir::AffineExprKind::Add, toAffineExpr(op.lhs()),
198                             affineBinaryOp(mlir::AffineExprKind::Mul,
199                                            toAffineExpr(op.rhs()),
200                                            toAffineExpr(-1)));
201     if (auto op = value.getDefiningOp<mlir::AddIOp>())
202       return affineBinaryOp(mlir::AffineExprKind::Add, op.lhs(), op.rhs());
203     if (auto op = value.getDefiningOp<mlir::MulIOp>())
204       return affineBinaryOp(mlir::AffineExprKind::Mul, op.lhs(), op.rhs());
205     if (auto op = value.getDefiningOp<mlir::UnsignedRemIOp>())
206       return affineBinaryOp(mlir::AffineExprKind::Mod, op.lhs(), op.rhs());
207     if (auto op = value.getDefiningOp<mlir::ConstantOp>())
208       if (auto intConstant = op.getValue().dyn_cast<IntegerAttr>())
209         return toAffineExpr(intConstant.getInt());
210     if (auto blockArg = value.dyn_cast<mlir::BlockArgument>()) {
211       affineArgs.push_back(value);
212       if (isa<fir::DoLoopOp>(blockArg.getOwner()->getParentOp()) ||
213           isa<mlir::AffineForOp>(blockArg.getOwner()->getParentOp()))
214         return {mlir::getAffineDimExpr(dimCount++, value.getContext())};
215       return {mlir::getAffineSymbolExpr(symCount++, value.getContext())};
216     }
217     return {};
218   }
219 
fromCmpIOp__anond53ad5d60311::AffineIfCondition220   void fromCmpIOp(mlir::CmpIOp cmpOp) {
221     auto lhsAffine = toAffineExpr(cmpOp.lhs());
222     auto rhsAffine = toAffineExpr(cmpOp.rhs());
223     if (!lhsAffine.hasValue() || !rhsAffine.hasValue())
224       return;
225     auto constraintPair = constraint(
226         cmpOp.predicate(), rhsAffine.getValue() - lhsAffine.getValue());
227     if (!constraintPair)
228       return;
229     integerSet = mlir::IntegerSet::get(dimCount, symCount,
230                                        {constraintPair.getValue().first},
231                                        {constraintPair.getValue().second});
232     return;
233   }
234 
235   llvm::Optional<std::pair<AffineExpr, bool>>
constraint__anond53ad5d60311::AffineIfCondition236   constraint(mlir::CmpIPredicate predicate, mlir::AffineExpr basic) {
237     switch (predicate) {
238     case mlir::CmpIPredicate::slt:
239       return {std::make_pair(basic - 1, false)};
240     case mlir::CmpIPredicate::sle:
241       return {std::make_pair(basic, false)};
242     case mlir::CmpIPredicate::sgt:
243       return {std::make_pair(1 - basic, false)};
244     case mlir::CmpIPredicate::sge:
245       return {std::make_pair(0 - basic, false)};
246     case mlir::CmpIPredicate::eq:
247       return {std::make_pair(basic, true)};
248     default:
249       return {};
250     }
251   }
252 
253   llvm::SmallVector<mlir::Value> affineArgs;
254   llvm::Optional<mlir::IntegerSet> integerSet;
255   mlir::Value firCondition;
256   unsigned symCount{0u};
257   unsigned dimCount{0u};
258 };
259 } // namespace
260 
261 namespace {
262 /// Analysis for affine promotion of fir.if
263 struct AffineIfAnalysis {
264   AffineIfAnalysis() = default;
265 
AffineIfAnalysis__anond53ad5d60411::AffineIfAnalysis266   explicit AffineIfAnalysis(fir::IfOp op, AffineFunctionAnalysis &afa)
267       : legality(analyzeIf(op, afa)) {}
268 
canPromoteToAffine__anond53ad5d60411::AffineIfAnalysis269   bool canPromoteToAffine() { return legality; }
270 
271 private:
analyzeIf__anond53ad5d60411::AffineIfAnalysis272   bool analyzeIf(fir::IfOp op, AffineFunctionAnalysis &afa) {
273     if (op.getNumResults() == 0)
274       return true;
275     LLVM_DEBUG(llvm::dbgs()
276                    << "AffineIfAnalysis: not promoting as op has results\n";);
277     return false;
278   }
279 
280   bool legality{};
281 };
282 } // namespace
283 
284 AffineIfAnalysis
getChildIfAnalysis(fir::IfOp op) const285 AffineFunctionAnalysis::getChildIfAnalysis(fir::IfOp op) const {
286   auto it = ifAnalysisMap.find_as(op);
287   if (it == ifAnalysisMap.end()) {
288     LLVM_DEBUG(llvm::dbgs() << "AffineFunctionAnalysis: not computed for:\n";
289                op.dump(););
290     op.emitError("error in fetching if analysis in AffineFunctionAnalysis\n");
291     return {};
292   }
293   return it->getSecond();
294 }
295 
296 /// AffineMap rewriting fir.array_coor operation to affine apply,
297 /// %dim = fir.gendim %lowerBound, %upperBound, %stride
298 /// %a = fir.array_coor %arr(%dim) %i
299 /// returning affineMap = affine_map<(i)[lb, ub, st] -> (i*st - lb)>
createArrayIndexAffineMap(unsigned dimensions,MLIRContext * context)300 static mlir::AffineMap createArrayIndexAffineMap(unsigned dimensions,
301                                                  MLIRContext *context) {
302   auto index = mlir::getAffineConstantExpr(0, context);
303   auto accuExtent = mlir::getAffineConstantExpr(1, context);
304   for (unsigned i = 0; i < dimensions; ++i) {
305     mlir::AffineExpr idx = mlir::getAffineDimExpr(i, context),
306                      lowerBound = mlir::getAffineSymbolExpr(i * 3, context),
307                      currentExtent =
308                          mlir::getAffineSymbolExpr(i * 3 + 1, context),
309                      stride = mlir::getAffineSymbolExpr(i * 3 + 2, context),
310                      currentPart = (idx * stride - lowerBound) * accuExtent;
311     index = currentPart + index;
312     accuExtent = accuExtent * currentExtent;
313   }
314   return mlir::AffineMap::get(dimensions, dimensions * 3, index);
315 }
316 
constantIntegerLike(const mlir::Value value)317 static Optional<int64_t> constantIntegerLike(const mlir::Value value) {
318   if (auto definition = value.getDefiningOp<ConstantOp>())
319     if (auto stepAttr = definition.getValue().dyn_cast<IntegerAttr>())
320       return stepAttr.getInt();
321   return {};
322 }
323 
coordinateArrayElement(fir::ArrayCoorOp op)324 static mlir::Type coordinateArrayElement(fir::ArrayCoorOp op) {
325   if (auto refType = op.memref().getType().dyn_cast_or_null<ReferenceType>()) {
326     if (auto seqType = refType.getEleTy().dyn_cast_or_null<SequenceType>()) {
327       return seqType.getEleTy();
328     }
329   }
330   op.emitError(
331       "AffineLoopConversion: array type in coordinate operation not valid\n");
332   return mlir::Type();
333 }
334 
populateIndexArgs(fir::ArrayCoorOp acoOp,fir::ShapeOp shape,SmallVectorImpl<mlir::Value> & indexArgs,mlir::PatternRewriter & rewriter)335 static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::ShapeOp shape,
336                               SmallVectorImpl<mlir::Value> &indexArgs,
337                               mlir::PatternRewriter &rewriter) {
338   auto one = rewriter.create<mlir::ConstantOp>(
339       acoOp.getLoc(), rewriter.getIndexType(), rewriter.getIndexAttr(1));
340   auto extents = shape.extents();
341   for (auto i = extents.begin(); i < extents.end(); i++) {
342     indexArgs.push_back(one);
343     indexArgs.push_back(*i);
344     indexArgs.push_back(one);
345   }
346 }
347 
populateIndexArgs(fir::ArrayCoorOp acoOp,fir::ShapeShiftOp shape,SmallVectorImpl<mlir::Value> & indexArgs,mlir::PatternRewriter & rewriter)348 static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::ShapeShiftOp shape,
349                               SmallVectorImpl<mlir::Value> &indexArgs,
350                               mlir::PatternRewriter &rewriter) {
351   auto one = rewriter.create<mlir::ConstantOp>(
352       acoOp.getLoc(), rewriter.getIndexType(), rewriter.getIndexAttr(1));
353   auto extents = shape.pairs();
354   for (auto i = extents.begin(); i < extents.end();) {
355     indexArgs.push_back(*i++);
356     indexArgs.push_back(*i++);
357     indexArgs.push_back(one);
358   }
359 }
360 
populateIndexArgs(fir::ArrayCoorOp acoOp,fir::SliceOp slice,SmallVectorImpl<mlir::Value> & indexArgs,mlir::PatternRewriter & rewriter)361 static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::SliceOp slice,
362                               SmallVectorImpl<mlir::Value> &indexArgs,
363                               mlir::PatternRewriter &rewriter) {
364   auto extents = slice.triples();
365   for (auto i = extents.begin(); i < extents.end();) {
366     indexArgs.push_back(*i++);
367     indexArgs.push_back(*i++);
368     indexArgs.push_back(*i++);
369   }
370 }
371 
populateIndexArgs(fir::ArrayCoorOp acoOp,SmallVectorImpl<mlir::Value> & indexArgs,mlir::PatternRewriter & rewriter)372 static void populateIndexArgs(fir::ArrayCoorOp acoOp,
373                               SmallVectorImpl<mlir::Value> &indexArgs,
374                               mlir::PatternRewriter &rewriter) {
375   if (auto shape = acoOp.shape().getDefiningOp<ShapeOp>())
376     return populateIndexArgs(acoOp, shape, indexArgs, rewriter);
377   if (auto shapeShift = acoOp.shape().getDefiningOp<ShapeShiftOp>())
378     return populateIndexArgs(acoOp, shapeShift, indexArgs, rewriter);
379   if (auto slice = acoOp.shape().getDefiningOp<SliceOp>())
380     return populateIndexArgs(acoOp, slice, indexArgs, rewriter);
381   return;
382 }
383 
384 /// Returns affine.apply and fir.convert from array_coor and gendims
385 static std::pair<mlir::AffineApplyOp, fir::ConvertOp>
createAffineOps(mlir::Value arrayRef,mlir::PatternRewriter & rewriter)386 createAffineOps(mlir::Value arrayRef, mlir::PatternRewriter &rewriter) {
387   auto acoOp = arrayRef.getDefiningOp<ArrayCoorOp>();
388   auto affineMap =
389       createArrayIndexAffineMap(acoOp.indices().size(), acoOp.getContext());
390   SmallVector<mlir::Value> indexArgs;
391   indexArgs.append(acoOp.indices().begin(), acoOp.indices().end());
392 
393   populateIndexArgs(acoOp, indexArgs, rewriter);
394 
395   auto affineApply = rewriter.create<mlir::AffineApplyOp>(acoOp.getLoc(),
396                                                           affineMap, indexArgs);
397   auto arrayElementType = coordinateArrayElement(acoOp);
398   auto newType = mlir::MemRefType::get({-1}, arrayElementType);
399   auto arrayConvert =
400       rewriter.create<fir::ConvertOp>(acoOp.getLoc(), newType, acoOp.memref());
401   return std::make_pair(affineApply, arrayConvert);
402 }
403 
rewriteLoad(fir::LoadOp loadOp,mlir::PatternRewriter & rewriter)404 static void rewriteLoad(fir::LoadOp loadOp, mlir::PatternRewriter &rewriter) {
405   rewriter.setInsertionPoint(loadOp);
406   auto affineOps = createAffineOps(loadOp.memref(), rewriter);
407   rewriter.replaceOpWithNewOp<mlir::AffineLoadOp>(
408       loadOp, affineOps.second.getResult(), affineOps.first.getResult());
409 }
410 
rewriteStore(fir::StoreOp storeOp,mlir::PatternRewriter & rewriter)411 static void rewriteStore(fir::StoreOp storeOp,
412                          mlir::PatternRewriter &rewriter) {
413   rewriter.setInsertionPoint(storeOp);
414   auto affineOps = createAffineOps(storeOp.memref(), rewriter);
415   rewriter.replaceOpWithNewOp<mlir::AffineStoreOp>(storeOp, storeOp.value(),
416                                                    affineOps.second.getResult(),
417                                                    affineOps.first.getResult());
418 }
419 
rewriteMemoryOps(Block * block,mlir::PatternRewriter & rewriter)420 static void rewriteMemoryOps(Block *block, mlir::PatternRewriter &rewriter) {
421   for (auto &bodyOp : block->getOperations()) {
422     if (isa<fir::LoadOp>(bodyOp))
423       rewriteLoad(cast<fir::LoadOp>(bodyOp), rewriter);
424     if (isa<fir::StoreOp>(bodyOp))
425       rewriteStore(cast<fir::StoreOp>(bodyOp), rewriter);
426   }
427 }
428 
429 namespace {
430 /// Convert `fir.do_loop` to `affine.for`, creates fir.convert for arrays to
431 /// memref, rewrites array_coor to affine.apply with affine_map. Rewrites fir
432 /// loads and stores to affine.
433 class AffineLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> {
434 public:
435   using OpRewritePattern::OpRewritePattern;
AffineLoopConversion(mlir::MLIRContext * context,AffineFunctionAnalysis & afa)436   AffineLoopConversion(mlir::MLIRContext *context, AffineFunctionAnalysis &afa)
437       : OpRewritePattern(context), functionAnalysis(afa) {}
438 
439   mlir::LogicalResult
matchAndRewrite(fir::DoLoopOp loop,mlir::PatternRewriter & rewriter) const440   matchAndRewrite(fir::DoLoopOp loop,
441                   mlir::PatternRewriter &rewriter) const override {
442     LLVM_DEBUG(llvm::dbgs() << "AffineLoopConversion: rewriting loop:\n";
443                loop.dump(););
444     LLVM_ATTRIBUTE_UNUSED auto loopAnalysis =
445         functionAnalysis.getChildLoopAnalysis(loop);
446     auto &loopOps = loop.getBody()->getOperations();
447     auto loopAndIndex = createAffineFor(loop, rewriter);
448     auto affineFor = loopAndIndex.first;
449     auto inductionVar = loopAndIndex.second;
450 
451     rewriter.startRootUpdate(affineFor.getOperation());
452     affineFor.getBody()->getOperations().splice(
453         std::prev(affineFor.getBody()->end()), loopOps, loopOps.begin(),
454         std::prev(loopOps.end()));
455     rewriter.finalizeRootUpdate(affineFor.getOperation());
456 
457     rewriter.startRootUpdate(loop.getOperation());
458     loop.getInductionVar().replaceAllUsesWith(inductionVar);
459     rewriter.finalizeRootUpdate(loop.getOperation());
460 
461     rewriteMemoryOps(affineFor.getBody(), rewriter);
462 
463     LLVM_DEBUG(llvm::dbgs() << "AffineLoopConversion: loop rewriten to:\n";
464                affineFor.dump(););
465     rewriter.replaceOp(loop, affineFor.getOperation()->getResults());
466     return success();
467   }
468 
469 private:
470   std::pair<mlir::AffineForOp, mlir::Value>
createAffineFor(fir::DoLoopOp op,mlir::PatternRewriter & rewriter) const471   createAffineFor(fir::DoLoopOp op, mlir::PatternRewriter &rewriter) const {
472     if (auto constantStep = constantIntegerLike(op.step()))
473       if (constantStep.getValue() > 0)
474         return positiveConstantStep(op, constantStep.getValue(), rewriter);
475     return genericBounds(op, rewriter);
476   }
477 
478   // when step for the loop is positive compile time constant
479   std::pair<mlir::AffineForOp, mlir::Value>
positiveConstantStep(fir::DoLoopOp op,int64_t step,mlir::PatternRewriter & rewriter) const480   positiveConstantStep(fir::DoLoopOp op, int64_t step,
481                        mlir::PatternRewriter &rewriter) const {
482     auto affineFor = rewriter.create<mlir::AffineForOp>(
483         op.getLoc(), ValueRange(op.lowerBound()),
484         mlir::AffineMap::get(0, 1,
485                              mlir::getAffineSymbolExpr(0, op.getContext())),
486         ValueRange(op.upperBound()),
487         mlir::AffineMap::get(0, 1,
488                              1 + mlir::getAffineSymbolExpr(0, op.getContext())),
489         step);
490     return std::make_pair(affineFor, affineFor.getInductionVar());
491   }
492 
493   std::pair<mlir::AffineForOp, mlir::Value>
genericBounds(fir::DoLoopOp op,mlir::PatternRewriter & rewriter) const494   genericBounds(fir::DoLoopOp op, mlir::PatternRewriter &rewriter) const {
495     auto lowerBound = mlir::getAffineSymbolExpr(0, op.getContext());
496     auto upperBound = mlir::getAffineSymbolExpr(1, op.getContext());
497     auto step = mlir::getAffineSymbolExpr(2, op.getContext());
498     mlir::AffineMap upperBoundMap = mlir::AffineMap::get(
499         0, 3, (upperBound - lowerBound + step).floorDiv(step));
500     auto genericUpperBound = rewriter.create<mlir::AffineApplyOp>(
501         op.getLoc(), upperBoundMap,
502         ValueRange({op.lowerBound(), op.upperBound(), op.step()}));
503     auto actualIndexMap = mlir::AffineMap::get(
504         1, 2,
505         (lowerBound + mlir::getAffineDimExpr(0, op.getContext())) *
506             mlir::getAffineSymbolExpr(1, op.getContext()));
507 
508     auto affineFor = rewriter.create<mlir::AffineForOp>(
509         op.getLoc(), ValueRange(),
510         AffineMap::getConstantMap(0, op.getContext()),
511         genericUpperBound.getResult(),
512         mlir::AffineMap::get(0, 1,
513                              1 + mlir::getAffineSymbolExpr(0, op.getContext())),
514         1);
515     rewriter.setInsertionPointToStart(affineFor.getBody());
516     auto actualIndex = rewriter.create<mlir::AffineApplyOp>(
517         op.getLoc(), actualIndexMap,
518         ValueRange({affineFor.getInductionVar(), op.lowerBound(), op.step()}));
519     return std::make_pair(affineFor, actualIndex.getResult());
520   }
521 
522   AffineFunctionAnalysis &functionAnalysis;
523 };
524 
525 /// Convert `fir.if` to `affine.if`.
526 class AffineIfConversion : public mlir::OpRewritePattern<fir::IfOp> {
527 public:
528   using OpRewritePattern::OpRewritePattern;
AffineIfConversion(mlir::MLIRContext * context,AffineFunctionAnalysis & afa)529   AffineIfConversion(mlir::MLIRContext *context, AffineFunctionAnalysis &afa)
530       : OpRewritePattern(context) {}
531   mlir::LogicalResult
matchAndRewrite(fir::IfOp op,mlir::PatternRewriter & rewriter) const532   matchAndRewrite(fir::IfOp op,
533                   mlir::PatternRewriter &rewriter) const override {
534     LLVM_DEBUG(llvm::dbgs() << "AffineIfConversion: rewriting if:\n";
535                op.dump(););
536     auto &ifOps = op.thenRegion().front().getOperations();
537     auto affineCondition = AffineIfCondition(op.condition());
538     if (!affineCondition.hasIntegerSet()) {
539       LLVM_DEBUG(
540           llvm::dbgs()
541               << "AffineIfConversion: couldn't calculate affine condition\n";);
542       return failure();
543     }
544     auto affineIf = rewriter.create<mlir::AffineIfOp>(
545         op.getLoc(), affineCondition.getIntegerSet(),
546         affineCondition.getAffineArgs(), !op.elseRegion().empty());
547     rewriter.startRootUpdate(affineIf);
548     affineIf.getThenBlock()->getOperations().splice(
549         std::prev(affineIf.getThenBlock()->end()), ifOps, ifOps.begin(),
550         std::prev(ifOps.end()));
551     if (!op.elseRegion().empty()) {
552       auto &otherOps = op.elseRegion().front().getOperations();
553       affineIf.getElseBlock()->getOperations().splice(
554           std::prev(affineIf.getElseBlock()->end()), otherOps, otherOps.begin(),
555           std::prev(otherOps.end()));
556     }
557     rewriter.finalizeRootUpdate(affineIf);
558     rewriteMemoryOps(affineIf.getBody(), rewriter);
559 
560     LLVM_DEBUG(llvm::dbgs() << "AffineIfConversion: if converted to:\n";
561                affineIf.dump(););
562     rewriter.replaceOp(op, affineIf.getOperation()->getResults());
563     return success();
564   }
565 };
566 
567 /// Promote fir.do_loop and fir.if to affine.for and affine.if, in the cases
568 /// where such a promotion is possible.
569 class AffineDialectPromotion
570     : public AffineDialectPromotionBase<AffineDialectPromotion> {
571 public:
runOnFunction()572   void runOnFunction() override {
573 
574     auto *context = &getContext();
575     auto function = getFunction();
576     markAllAnalysesPreserved();
577     auto functionAnalysis = AffineFunctionAnalysis(function);
578     mlir::OwningRewritePatternList patterns(context);
579     patterns.insert<AffineIfConversion>(context, functionAnalysis);
580     patterns.insert<AffineLoopConversion>(context, functionAnalysis);
581     mlir::ConversionTarget target = *context;
582     target.addLegalDialect<mlir::AffineDialect, FIROpsDialect,
583                            mlir::scf::SCFDialect, mlir::StandardOpsDialect>();
584     target.addDynamicallyLegalOp<IfOp>([&functionAnalysis](fir::IfOp op) {
585       return !(functionAnalysis.getChildIfAnalysis(op).canPromoteToAffine());
586     });
587     target.addDynamicallyLegalOp<DoLoopOp>([&functionAnalysis](
588                                                fir::DoLoopOp op) {
589       return !(functionAnalysis.getChildLoopAnalysis(op).canPromoteToAffine());
590     });
591 
592     LLVM_DEBUG(llvm::dbgs()
593                    << "AffineDialectPromotion: running promotion on: \n";
594                function.print(llvm::dbgs()););
595     // apply the patterns
596     if (mlir::failed(mlir::applyPartialConversion(function, target,
597                                                   std::move(patterns)))) {
598       mlir::emitError(mlir::UnknownLoc::get(context),
599                       "error in converting to affine dialect\n");
600       signalPassFailure();
601     }
602   }
603 };
604 } // namespace
605 
606 /// Convert FIR loop constructs to the Affine dialect
createPromoteToAffinePass()607 std::unique_ptr<mlir::Pass> fir::createPromoteToAffinePass() {
608   return std::make_unique<AffineDialectPromotion>();
609 }
610