1 //===-- AffinePromotion.cpp -----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
#include "PassDetail.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
24
25 #define DEBUG_TYPE "flang-affine-promotion"
26
27 using namespace fir;
28
29 namespace {
30 struct AffineLoopAnalysis;
31 struct AffineIfAnalysis;
32
33 /// Stores analysis objects for all loops and if operations inside a function
34 /// these analysis are used twice, first for marking operations for rewrite and
35 /// second when doing rewrite.
36 struct AffineFunctionAnalysis {
AffineFunctionAnalysis__anond53ad5d60111::AffineFunctionAnalysis37 explicit AffineFunctionAnalysis(mlir::FuncOp funcOp) {
38 for (fir::DoLoopOp op : funcOp.getOps<fir::DoLoopOp>())
39 loopAnalysisMap.try_emplace(op, op, *this);
40 }
41
42 AffineLoopAnalysis getChildLoopAnalysis(fir::DoLoopOp op) const;
43
44 AffineIfAnalysis getChildIfAnalysis(fir::IfOp op) const;
45
46 llvm::DenseMap<mlir::Operation *, AffineLoopAnalysis> loopAnalysisMap;
47 llvm::DenseMap<mlir::Operation *, AffineIfAnalysis> ifAnalysisMap;
48 };
49 } // namespace
50
analyzeCoordinate(mlir::Value coordinate,mlir::Operation * op)51 static bool analyzeCoordinate(mlir::Value coordinate, mlir::Operation *op) {
52 if (auto blockArg = coordinate.dyn_cast<mlir::BlockArgument>()) {
53 if (isa<fir::DoLoopOp>(blockArg.getOwner()->getParentOp()))
54 return true;
55 LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: array coordinate is not a "
56 "loop induction variable (owner not loopOp)\n";
57 op->dump());
58 return false;
59 }
60 LLVM_DEBUG(
61 llvm::dbgs() << "AffineLoopAnalysis: array coordinate is not a loop "
62 "induction variable (not a block argument)\n";
63 op->dump(); coordinate.getDefiningOp()->dump());
64 return false;
65 }
66
67 namespace {
68 struct AffineLoopAnalysis {
69 AffineLoopAnalysis() = default;
70
AffineLoopAnalysis__anond53ad5d60211::AffineLoopAnalysis71 explicit AffineLoopAnalysis(fir::DoLoopOp op, AffineFunctionAnalysis &afa)
72 : legality(analyzeLoop(op, afa)) {}
73
canPromoteToAffine__anond53ad5d60211::AffineLoopAnalysis74 bool canPromoteToAffine() { return legality; }
75
76 private:
analyzeBody__anond53ad5d60211::AffineLoopAnalysis77 bool analyzeBody(fir::DoLoopOp loopOperation,
78 AffineFunctionAnalysis &functionAnalysis) {
79 for (auto loopOp : loopOperation.getOps<fir::DoLoopOp>()) {
80 auto analysis = functionAnalysis.loopAnalysisMap
81 .try_emplace(loopOp, loopOp, functionAnalysis)
82 .first->getSecond();
83 if (!analysis.canPromoteToAffine())
84 return false;
85 }
86 for (auto ifOp : loopOperation.getOps<fir::IfOp>())
87 functionAnalysis.ifAnalysisMap.try_emplace(ifOp, ifOp, functionAnalysis);
88 return true;
89 }
90
analyzeLoop__anond53ad5d60211::AffineLoopAnalysis91 bool analyzeLoop(fir::DoLoopOp loopOperation,
92 AffineFunctionAnalysis &functionAnalysis) {
93 LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: \n"; loopOperation.dump(););
94 return analyzeMemoryAccess(loopOperation) &&
95 analyzeBody(loopOperation, functionAnalysis);
96 }
97
analyzeReference__anond53ad5d60211::AffineLoopAnalysis98 bool analyzeReference(mlir::Value memref, mlir::Operation *op) {
99 if (auto acoOp = memref.getDefiningOp<ArrayCoorOp>()) {
100 if (acoOp.memref().getType().isa<fir::BoxType>()) {
101 // TODO: Look if and how fir.box can be promoted to affine.
102 LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: cannot promote loop, "
103 "array memory operation uses fir.box\n";
104 op->dump(); acoOp.dump(););
105 return false;
106 }
107 bool canPromote = true;
108 for (auto coordinate : acoOp.indices())
109 canPromote = canPromote && analyzeCoordinate(coordinate, op);
110 return canPromote;
111 }
112 if (auto coOp = memref.getDefiningOp<CoordinateOp>()) {
113 LLVM_DEBUG(llvm::dbgs()
114 << "AffineLoopAnalysis: cannot promote loop, "
115 "array memory operation uses non ArrayCoorOp\n";
116 op->dump(); coOp.dump(););
117
118 return false;
119 }
120 LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: unknown type of memory "
121 "reference for array load\n";
122 op->dump(););
123 return false;
124 }
125
analyzeMemoryAccess__anond53ad5d60211::AffineLoopAnalysis126 bool analyzeMemoryAccess(fir::DoLoopOp loopOperation) {
127 for (auto loadOp : loopOperation.getOps<fir::LoadOp>())
128 if (!analyzeReference(loadOp.memref(), loadOp))
129 return false;
130 for (auto storeOp : loopOperation.getOps<fir::StoreOp>())
131 if (!analyzeReference(storeOp.memref(), storeOp))
132 return false;
133 return true;
134 }
135
136 bool legality{};
137 };
138 } // namespace
139
140 AffineLoopAnalysis
getChildLoopAnalysis(fir::DoLoopOp op) const141 AffineFunctionAnalysis::getChildLoopAnalysis(fir::DoLoopOp op) const {
142 auto it = loopAnalysisMap.find_as(op);
143 if (it == loopAnalysisMap.end()) {
144 LLVM_DEBUG(llvm::dbgs() << "AffineFunctionAnalysis: not computed for:\n";
145 op.dump(););
146 op.emitError("error in fetching loop analysis in AffineFunctionAnalysis\n");
147 return {};
148 }
149 return it->getSecond();
150 }
151
152 namespace {
153 /// Calculates arguments for creating an IntegerSet. symCount, dimCount are the
154 /// final number of symbols and dimensions of the affine map. Integer set if
155 /// possible is in Optional IntegerSet.
156 struct AffineIfCondition {
157 using MaybeAffineExpr = llvm::Optional<mlir::AffineExpr>;
158
AffineIfCondition__anond53ad5d60311::AffineIfCondition159 explicit AffineIfCondition(mlir::Value fc) : firCondition(fc) {
160 if (auto condDef = firCondition.getDefiningOp<mlir::CmpIOp>())
161 fromCmpIOp(condDef);
162 }
163
hasIntegerSet__anond53ad5d60311::AffineIfCondition164 bool hasIntegerSet() const { return integerSet.hasValue(); }
165
getIntegerSet__anond53ad5d60311::AffineIfCondition166 mlir::IntegerSet getIntegerSet() const {
167 assert(hasIntegerSet() && "integer set is missing");
168 return integerSet.getValue();
169 }
170
getAffineArgs__anond53ad5d60311::AffineIfCondition171 mlir::ValueRange getAffineArgs() const { return affineArgs; }
172
173 private:
affineBinaryOp__anond53ad5d60311::AffineIfCondition174 MaybeAffineExpr affineBinaryOp(mlir::AffineExprKind kind, mlir::Value lhs,
175 mlir::Value rhs) {
176 return affineBinaryOp(kind, toAffineExpr(lhs), toAffineExpr(rhs));
177 }
178
affineBinaryOp__anond53ad5d60311::AffineIfCondition179 MaybeAffineExpr affineBinaryOp(mlir::AffineExprKind kind, MaybeAffineExpr lhs,
180 MaybeAffineExpr rhs) {
181 if (lhs.hasValue() && rhs.hasValue())
182 return mlir::getAffineBinaryOpExpr(kind, lhs.getValue(), rhs.getValue());
183 return {};
184 }
185
toAffineExpr__anond53ad5d60311::AffineIfCondition186 MaybeAffineExpr toAffineExpr(MaybeAffineExpr e) { return e; }
187
toAffineExpr__anond53ad5d60311::AffineIfCondition188 MaybeAffineExpr toAffineExpr(int64_t value) {
189 return {mlir::getAffineConstantExpr(value, firCondition.getContext())};
190 }
191
192 /// Returns an AffineExpr if it is a result of operations that can be done
193 /// in an affine expression, this includes -, +, *, rem, constant.
194 /// block arguments of a loopOp or forOp are used as dimensions
toAffineExpr__anond53ad5d60311::AffineIfCondition195 MaybeAffineExpr toAffineExpr(mlir::Value value) {
196 if (auto op = value.getDefiningOp<mlir::SubIOp>())
197 return affineBinaryOp(mlir::AffineExprKind::Add, toAffineExpr(op.lhs()),
198 affineBinaryOp(mlir::AffineExprKind::Mul,
199 toAffineExpr(op.rhs()),
200 toAffineExpr(-1)));
201 if (auto op = value.getDefiningOp<mlir::AddIOp>())
202 return affineBinaryOp(mlir::AffineExprKind::Add, op.lhs(), op.rhs());
203 if (auto op = value.getDefiningOp<mlir::MulIOp>())
204 return affineBinaryOp(mlir::AffineExprKind::Mul, op.lhs(), op.rhs());
205 if (auto op = value.getDefiningOp<mlir::UnsignedRemIOp>())
206 return affineBinaryOp(mlir::AffineExprKind::Mod, op.lhs(), op.rhs());
207 if (auto op = value.getDefiningOp<mlir::ConstantOp>())
208 if (auto intConstant = op.getValue().dyn_cast<IntegerAttr>())
209 return toAffineExpr(intConstant.getInt());
210 if (auto blockArg = value.dyn_cast<mlir::BlockArgument>()) {
211 affineArgs.push_back(value);
212 if (isa<fir::DoLoopOp>(blockArg.getOwner()->getParentOp()) ||
213 isa<mlir::AffineForOp>(blockArg.getOwner()->getParentOp()))
214 return {mlir::getAffineDimExpr(dimCount++, value.getContext())};
215 return {mlir::getAffineSymbolExpr(symCount++, value.getContext())};
216 }
217 return {};
218 }
219
fromCmpIOp__anond53ad5d60311::AffineIfCondition220 void fromCmpIOp(mlir::CmpIOp cmpOp) {
221 auto lhsAffine = toAffineExpr(cmpOp.lhs());
222 auto rhsAffine = toAffineExpr(cmpOp.rhs());
223 if (!lhsAffine.hasValue() || !rhsAffine.hasValue())
224 return;
225 auto constraintPair = constraint(
226 cmpOp.predicate(), rhsAffine.getValue() - lhsAffine.getValue());
227 if (!constraintPair)
228 return;
229 integerSet = mlir::IntegerSet::get(dimCount, symCount,
230 {constraintPair.getValue().first},
231 {constraintPair.getValue().second});
232 return;
233 }
234
235 llvm::Optional<std::pair<AffineExpr, bool>>
constraint__anond53ad5d60311::AffineIfCondition236 constraint(mlir::CmpIPredicate predicate, mlir::AffineExpr basic) {
237 switch (predicate) {
238 case mlir::CmpIPredicate::slt:
239 return {std::make_pair(basic - 1, false)};
240 case mlir::CmpIPredicate::sle:
241 return {std::make_pair(basic, false)};
242 case mlir::CmpIPredicate::sgt:
243 return {std::make_pair(1 - basic, false)};
244 case mlir::CmpIPredicate::sge:
245 return {std::make_pair(0 - basic, false)};
246 case mlir::CmpIPredicate::eq:
247 return {std::make_pair(basic, true)};
248 default:
249 return {};
250 }
251 }
252
253 llvm::SmallVector<mlir::Value> affineArgs;
254 llvm::Optional<mlir::IntegerSet> integerSet;
255 mlir::Value firCondition;
256 unsigned symCount{0u};
257 unsigned dimCount{0u};
258 };
259 } // namespace
260
261 namespace {
262 /// Analysis for affine promotion of fir.if
263 struct AffineIfAnalysis {
264 AffineIfAnalysis() = default;
265
AffineIfAnalysis__anond53ad5d60411::AffineIfAnalysis266 explicit AffineIfAnalysis(fir::IfOp op, AffineFunctionAnalysis &afa)
267 : legality(analyzeIf(op, afa)) {}
268
canPromoteToAffine__anond53ad5d60411::AffineIfAnalysis269 bool canPromoteToAffine() { return legality; }
270
271 private:
analyzeIf__anond53ad5d60411::AffineIfAnalysis272 bool analyzeIf(fir::IfOp op, AffineFunctionAnalysis &afa) {
273 if (op.getNumResults() == 0)
274 return true;
275 LLVM_DEBUG(llvm::dbgs()
276 << "AffineIfAnalysis: not promoting as op has results\n";);
277 return false;
278 }
279
280 bool legality{};
281 };
282 } // namespace
283
284 AffineIfAnalysis
getChildIfAnalysis(fir::IfOp op) const285 AffineFunctionAnalysis::getChildIfAnalysis(fir::IfOp op) const {
286 auto it = ifAnalysisMap.find_as(op);
287 if (it == ifAnalysisMap.end()) {
288 LLVM_DEBUG(llvm::dbgs() << "AffineFunctionAnalysis: not computed for:\n";
289 op.dump(););
290 op.emitError("error in fetching if analysis in AffineFunctionAnalysis\n");
291 return {};
292 }
293 return it->getSecond();
294 }
295
296 /// AffineMap rewriting fir.array_coor operation to affine apply,
297 /// %dim = fir.gendim %lowerBound, %upperBound, %stride
298 /// %a = fir.array_coor %arr(%dim) %i
299 /// returning affineMap = affine_map<(i)[lb, ub, st] -> (i*st - lb)>
createArrayIndexAffineMap(unsigned dimensions,MLIRContext * context)300 static mlir::AffineMap createArrayIndexAffineMap(unsigned dimensions,
301 MLIRContext *context) {
302 auto index = mlir::getAffineConstantExpr(0, context);
303 auto accuExtent = mlir::getAffineConstantExpr(1, context);
304 for (unsigned i = 0; i < dimensions; ++i) {
305 mlir::AffineExpr idx = mlir::getAffineDimExpr(i, context),
306 lowerBound = mlir::getAffineSymbolExpr(i * 3, context),
307 currentExtent =
308 mlir::getAffineSymbolExpr(i * 3 + 1, context),
309 stride = mlir::getAffineSymbolExpr(i * 3 + 2, context),
310 currentPart = (idx * stride - lowerBound) * accuExtent;
311 index = currentPart + index;
312 accuExtent = accuExtent * currentExtent;
313 }
314 return mlir::AffineMap::get(dimensions, dimensions * 3, index);
315 }
316
constantIntegerLike(const mlir::Value value)317 static Optional<int64_t> constantIntegerLike(const mlir::Value value) {
318 if (auto definition = value.getDefiningOp<ConstantOp>())
319 if (auto stepAttr = definition.getValue().dyn_cast<IntegerAttr>())
320 return stepAttr.getInt();
321 return {};
322 }
323
coordinateArrayElement(fir::ArrayCoorOp op)324 static mlir::Type coordinateArrayElement(fir::ArrayCoorOp op) {
325 if (auto refType = op.memref().getType().dyn_cast_or_null<ReferenceType>()) {
326 if (auto seqType = refType.getEleTy().dyn_cast_or_null<SequenceType>()) {
327 return seqType.getEleTy();
328 }
329 }
330 op.emitError(
331 "AffineLoopConversion: array type in coordinate operation not valid\n");
332 return mlir::Type();
333 }
334
populateIndexArgs(fir::ArrayCoorOp acoOp,fir::ShapeOp shape,SmallVectorImpl<mlir::Value> & indexArgs,mlir::PatternRewriter & rewriter)335 static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::ShapeOp shape,
336 SmallVectorImpl<mlir::Value> &indexArgs,
337 mlir::PatternRewriter &rewriter) {
338 auto one = rewriter.create<mlir::ConstantOp>(
339 acoOp.getLoc(), rewriter.getIndexType(), rewriter.getIndexAttr(1));
340 auto extents = shape.extents();
341 for (auto i = extents.begin(); i < extents.end(); i++) {
342 indexArgs.push_back(one);
343 indexArgs.push_back(*i);
344 indexArgs.push_back(one);
345 }
346 }
347
populateIndexArgs(fir::ArrayCoorOp acoOp,fir::ShapeShiftOp shape,SmallVectorImpl<mlir::Value> & indexArgs,mlir::PatternRewriter & rewriter)348 static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::ShapeShiftOp shape,
349 SmallVectorImpl<mlir::Value> &indexArgs,
350 mlir::PatternRewriter &rewriter) {
351 auto one = rewriter.create<mlir::ConstantOp>(
352 acoOp.getLoc(), rewriter.getIndexType(), rewriter.getIndexAttr(1));
353 auto extents = shape.pairs();
354 for (auto i = extents.begin(); i < extents.end();) {
355 indexArgs.push_back(*i++);
356 indexArgs.push_back(*i++);
357 indexArgs.push_back(one);
358 }
359 }
360
populateIndexArgs(fir::ArrayCoorOp acoOp,fir::SliceOp slice,SmallVectorImpl<mlir::Value> & indexArgs,mlir::PatternRewriter & rewriter)361 static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::SliceOp slice,
362 SmallVectorImpl<mlir::Value> &indexArgs,
363 mlir::PatternRewriter &rewriter) {
364 auto extents = slice.triples();
365 for (auto i = extents.begin(); i < extents.end();) {
366 indexArgs.push_back(*i++);
367 indexArgs.push_back(*i++);
368 indexArgs.push_back(*i++);
369 }
370 }
371
populateIndexArgs(fir::ArrayCoorOp acoOp,SmallVectorImpl<mlir::Value> & indexArgs,mlir::PatternRewriter & rewriter)372 static void populateIndexArgs(fir::ArrayCoorOp acoOp,
373 SmallVectorImpl<mlir::Value> &indexArgs,
374 mlir::PatternRewriter &rewriter) {
375 if (auto shape = acoOp.shape().getDefiningOp<ShapeOp>())
376 return populateIndexArgs(acoOp, shape, indexArgs, rewriter);
377 if (auto shapeShift = acoOp.shape().getDefiningOp<ShapeShiftOp>())
378 return populateIndexArgs(acoOp, shapeShift, indexArgs, rewriter);
379 if (auto slice = acoOp.shape().getDefiningOp<SliceOp>())
380 return populateIndexArgs(acoOp, slice, indexArgs, rewriter);
381 return;
382 }
383
384 /// Returns affine.apply and fir.convert from array_coor and gendims
385 static std::pair<mlir::AffineApplyOp, fir::ConvertOp>
createAffineOps(mlir::Value arrayRef,mlir::PatternRewriter & rewriter)386 createAffineOps(mlir::Value arrayRef, mlir::PatternRewriter &rewriter) {
387 auto acoOp = arrayRef.getDefiningOp<ArrayCoorOp>();
388 auto affineMap =
389 createArrayIndexAffineMap(acoOp.indices().size(), acoOp.getContext());
390 SmallVector<mlir::Value> indexArgs;
391 indexArgs.append(acoOp.indices().begin(), acoOp.indices().end());
392
393 populateIndexArgs(acoOp, indexArgs, rewriter);
394
395 auto affineApply = rewriter.create<mlir::AffineApplyOp>(acoOp.getLoc(),
396 affineMap, indexArgs);
397 auto arrayElementType = coordinateArrayElement(acoOp);
398 auto newType = mlir::MemRefType::get({-1}, arrayElementType);
399 auto arrayConvert =
400 rewriter.create<fir::ConvertOp>(acoOp.getLoc(), newType, acoOp.memref());
401 return std::make_pair(affineApply, arrayConvert);
402 }
403
rewriteLoad(fir::LoadOp loadOp,mlir::PatternRewriter & rewriter)404 static void rewriteLoad(fir::LoadOp loadOp, mlir::PatternRewriter &rewriter) {
405 rewriter.setInsertionPoint(loadOp);
406 auto affineOps = createAffineOps(loadOp.memref(), rewriter);
407 rewriter.replaceOpWithNewOp<mlir::AffineLoadOp>(
408 loadOp, affineOps.second.getResult(), affineOps.first.getResult());
409 }
410
rewriteStore(fir::StoreOp storeOp,mlir::PatternRewriter & rewriter)411 static void rewriteStore(fir::StoreOp storeOp,
412 mlir::PatternRewriter &rewriter) {
413 rewriter.setInsertionPoint(storeOp);
414 auto affineOps = createAffineOps(storeOp.memref(), rewriter);
415 rewriter.replaceOpWithNewOp<mlir::AffineStoreOp>(storeOp, storeOp.value(),
416 affineOps.second.getResult(),
417 affineOps.first.getResult());
418 }
419
rewriteMemoryOps(Block * block,mlir::PatternRewriter & rewriter)420 static void rewriteMemoryOps(Block *block, mlir::PatternRewriter &rewriter) {
421 for (auto &bodyOp : block->getOperations()) {
422 if (isa<fir::LoadOp>(bodyOp))
423 rewriteLoad(cast<fir::LoadOp>(bodyOp), rewriter);
424 if (isa<fir::StoreOp>(bodyOp))
425 rewriteStore(cast<fir::StoreOp>(bodyOp), rewriter);
426 }
427 }
428
429 namespace {
430 /// Convert `fir.do_loop` to `affine.for`, creates fir.convert for arrays to
431 /// memref, rewrites array_coor to affine.apply with affine_map. Rewrites fir
432 /// loads and stores to affine.
433 class AffineLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> {
434 public:
435 using OpRewritePattern::OpRewritePattern;
AffineLoopConversion(mlir::MLIRContext * context,AffineFunctionAnalysis & afa)436 AffineLoopConversion(mlir::MLIRContext *context, AffineFunctionAnalysis &afa)
437 : OpRewritePattern(context), functionAnalysis(afa) {}
438
439 mlir::LogicalResult
matchAndRewrite(fir::DoLoopOp loop,mlir::PatternRewriter & rewriter) const440 matchAndRewrite(fir::DoLoopOp loop,
441 mlir::PatternRewriter &rewriter) const override {
442 LLVM_DEBUG(llvm::dbgs() << "AffineLoopConversion: rewriting loop:\n";
443 loop.dump(););
444 LLVM_ATTRIBUTE_UNUSED auto loopAnalysis =
445 functionAnalysis.getChildLoopAnalysis(loop);
446 auto &loopOps = loop.getBody()->getOperations();
447 auto loopAndIndex = createAffineFor(loop, rewriter);
448 auto affineFor = loopAndIndex.first;
449 auto inductionVar = loopAndIndex.second;
450
451 rewriter.startRootUpdate(affineFor.getOperation());
452 affineFor.getBody()->getOperations().splice(
453 std::prev(affineFor.getBody()->end()), loopOps, loopOps.begin(),
454 std::prev(loopOps.end()));
455 rewriter.finalizeRootUpdate(affineFor.getOperation());
456
457 rewriter.startRootUpdate(loop.getOperation());
458 loop.getInductionVar().replaceAllUsesWith(inductionVar);
459 rewriter.finalizeRootUpdate(loop.getOperation());
460
461 rewriteMemoryOps(affineFor.getBody(), rewriter);
462
463 LLVM_DEBUG(llvm::dbgs() << "AffineLoopConversion: loop rewriten to:\n";
464 affineFor.dump(););
465 rewriter.replaceOp(loop, affineFor.getOperation()->getResults());
466 return success();
467 }
468
469 private:
470 std::pair<mlir::AffineForOp, mlir::Value>
createAffineFor(fir::DoLoopOp op,mlir::PatternRewriter & rewriter) const471 createAffineFor(fir::DoLoopOp op, mlir::PatternRewriter &rewriter) const {
472 if (auto constantStep = constantIntegerLike(op.step()))
473 if (constantStep.getValue() > 0)
474 return positiveConstantStep(op, constantStep.getValue(), rewriter);
475 return genericBounds(op, rewriter);
476 }
477
478 // when step for the loop is positive compile time constant
479 std::pair<mlir::AffineForOp, mlir::Value>
positiveConstantStep(fir::DoLoopOp op,int64_t step,mlir::PatternRewriter & rewriter) const480 positiveConstantStep(fir::DoLoopOp op, int64_t step,
481 mlir::PatternRewriter &rewriter) const {
482 auto affineFor = rewriter.create<mlir::AffineForOp>(
483 op.getLoc(), ValueRange(op.lowerBound()),
484 mlir::AffineMap::get(0, 1,
485 mlir::getAffineSymbolExpr(0, op.getContext())),
486 ValueRange(op.upperBound()),
487 mlir::AffineMap::get(0, 1,
488 1 + mlir::getAffineSymbolExpr(0, op.getContext())),
489 step);
490 return std::make_pair(affineFor, affineFor.getInductionVar());
491 }
492
493 std::pair<mlir::AffineForOp, mlir::Value>
genericBounds(fir::DoLoopOp op,mlir::PatternRewriter & rewriter) const494 genericBounds(fir::DoLoopOp op, mlir::PatternRewriter &rewriter) const {
495 auto lowerBound = mlir::getAffineSymbolExpr(0, op.getContext());
496 auto upperBound = mlir::getAffineSymbolExpr(1, op.getContext());
497 auto step = mlir::getAffineSymbolExpr(2, op.getContext());
498 mlir::AffineMap upperBoundMap = mlir::AffineMap::get(
499 0, 3, (upperBound - lowerBound + step).floorDiv(step));
500 auto genericUpperBound = rewriter.create<mlir::AffineApplyOp>(
501 op.getLoc(), upperBoundMap,
502 ValueRange({op.lowerBound(), op.upperBound(), op.step()}));
503 auto actualIndexMap = mlir::AffineMap::get(
504 1, 2,
505 (lowerBound + mlir::getAffineDimExpr(0, op.getContext())) *
506 mlir::getAffineSymbolExpr(1, op.getContext()));
507
508 auto affineFor = rewriter.create<mlir::AffineForOp>(
509 op.getLoc(), ValueRange(),
510 AffineMap::getConstantMap(0, op.getContext()),
511 genericUpperBound.getResult(),
512 mlir::AffineMap::get(0, 1,
513 1 + mlir::getAffineSymbolExpr(0, op.getContext())),
514 1);
515 rewriter.setInsertionPointToStart(affineFor.getBody());
516 auto actualIndex = rewriter.create<mlir::AffineApplyOp>(
517 op.getLoc(), actualIndexMap,
518 ValueRange({affineFor.getInductionVar(), op.lowerBound(), op.step()}));
519 return std::make_pair(affineFor, actualIndex.getResult());
520 }
521
522 AffineFunctionAnalysis &functionAnalysis;
523 };
524
525 /// Convert `fir.if` to `affine.if`.
526 class AffineIfConversion : public mlir::OpRewritePattern<fir::IfOp> {
527 public:
528 using OpRewritePattern::OpRewritePattern;
AffineIfConversion(mlir::MLIRContext * context,AffineFunctionAnalysis & afa)529 AffineIfConversion(mlir::MLIRContext *context, AffineFunctionAnalysis &afa)
530 : OpRewritePattern(context) {}
531 mlir::LogicalResult
matchAndRewrite(fir::IfOp op,mlir::PatternRewriter & rewriter) const532 matchAndRewrite(fir::IfOp op,
533 mlir::PatternRewriter &rewriter) const override {
534 LLVM_DEBUG(llvm::dbgs() << "AffineIfConversion: rewriting if:\n";
535 op.dump(););
536 auto &ifOps = op.thenRegion().front().getOperations();
537 auto affineCondition = AffineIfCondition(op.condition());
538 if (!affineCondition.hasIntegerSet()) {
539 LLVM_DEBUG(
540 llvm::dbgs()
541 << "AffineIfConversion: couldn't calculate affine condition\n";);
542 return failure();
543 }
544 auto affineIf = rewriter.create<mlir::AffineIfOp>(
545 op.getLoc(), affineCondition.getIntegerSet(),
546 affineCondition.getAffineArgs(), !op.elseRegion().empty());
547 rewriter.startRootUpdate(affineIf);
548 affineIf.getThenBlock()->getOperations().splice(
549 std::prev(affineIf.getThenBlock()->end()), ifOps, ifOps.begin(),
550 std::prev(ifOps.end()));
551 if (!op.elseRegion().empty()) {
552 auto &otherOps = op.elseRegion().front().getOperations();
553 affineIf.getElseBlock()->getOperations().splice(
554 std::prev(affineIf.getElseBlock()->end()), otherOps, otherOps.begin(),
555 std::prev(otherOps.end()));
556 }
557 rewriter.finalizeRootUpdate(affineIf);
558 rewriteMemoryOps(affineIf.getBody(), rewriter);
559
560 LLVM_DEBUG(llvm::dbgs() << "AffineIfConversion: if converted to:\n";
561 affineIf.dump(););
562 rewriter.replaceOp(op, affineIf.getOperation()->getResults());
563 return success();
564 }
565 };
566
567 /// Promote fir.do_loop and fir.if to affine.for and affine.if, in the cases
568 /// where such a promotion is possible.
569 class AffineDialectPromotion
570 : public AffineDialectPromotionBase<AffineDialectPromotion> {
571 public:
runOnFunction()572 void runOnFunction() override {
573
574 auto *context = &getContext();
575 auto function = getFunction();
576 markAllAnalysesPreserved();
577 auto functionAnalysis = AffineFunctionAnalysis(function);
578 mlir::OwningRewritePatternList patterns(context);
579 patterns.insert<AffineIfConversion>(context, functionAnalysis);
580 patterns.insert<AffineLoopConversion>(context, functionAnalysis);
581 mlir::ConversionTarget target = *context;
582 target.addLegalDialect<mlir::AffineDialect, FIROpsDialect,
583 mlir::scf::SCFDialect, mlir::StandardOpsDialect>();
584 target.addDynamicallyLegalOp<IfOp>([&functionAnalysis](fir::IfOp op) {
585 return !(functionAnalysis.getChildIfAnalysis(op).canPromoteToAffine());
586 });
587 target.addDynamicallyLegalOp<DoLoopOp>([&functionAnalysis](
588 fir::DoLoopOp op) {
589 return !(functionAnalysis.getChildLoopAnalysis(op).canPromoteToAffine());
590 });
591
592 LLVM_DEBUG(llvm::dbgs()
593 << "AffineDialectPromotion: running promotion on: \n";
594 function.print(llvm::dbgs()););
595 // apply the patterns
596 if (mlir::failed(mlir::applyPartialConversion(function, target,
597 std::move(patterns)))) {
598 mlir::emitError(mlir::UnknownLoc::get(context),
599 "error in converting to affine dialect\n");
600 signalPassFailure();
601 }
602 }
603 };
604 } // namespace
605
606 /// Convert FIR loop constructs to the Affine dialect
createPromoteToAffinePass()607 std::unique_ptr<mlir::Pass> fir::createPromoteToAffinePass() {
608 return std::make_unique<AffineDialectPromotion>();
609 }
610