1 //===-- PreCGRewrite.cpp --------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "CGOps.h"
14 #include "PassDetail.h"
15 #include "flang/Optimizer/CodeGen/CodeGen.h"
16 #include "flang/Optimizer/Dialect/FIRDialect.h"
17 #include "flang/Optimizer/Dialect/FIROps.h"
18 #include "flang/Optimizer/Dialect/FIRType.h"
19 #include "flang/Optimizer/Support/FIRContext.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/Support/Debug.h"
23
24 //===----------------------------------------------------------------------===//
25 // Codegen rewrite: rewriting of subgraphs of ops
26 //===----------------------------------------------------------------------===//
27
28 using namespace fir;
29
30 #define DEBUG_TYPE "flang-codegen-rewrite"
31
populateShape(llvm::SmallVectorImpl<mlir::Value> & vec,ShapeOp shape)32 static void populateShape(llvm::SmallVectorImpl<mlir::Value> &vec,
33 ShapeOp shape) {
34 vec.append(shape.extents().begin(), shape.extents().end());
35 }
36
37 // Operands of fir.shape_shift split into two vectors.
populateShapeAndShift(llvm::SmallVectorImpl<mlir::Value> & shapeVec,llvm::SmallVectorImpl<mlir::Value> & shiftVec,ShapeShiftOp shift)38 static void populateShapeAndShift(llvm::SmallVectorImpl<mlir::Value> &shapeVec,
39 llvm::SmallVectorImpl<mlir::Value> &shiftVec,
40 ShapeShiftOp shift) {
41 auto endIter = shift.pairs().end();
42 for (auto i = shift.pairs().begin(); i != endIter;) {
43 shiftVec.push_back(*i++);
44 shapeVec.push_back(*i++);
45 }
46 }
47
populateShift(llvm::SmallVectorImpl<mlir::Value> & vec,ShiftOp shift)48 static void populateShift(llvm::SmallVectorImpl<mlir::Value> &vec,
49 ShiftOp shift) {
50 vec.append(shift.origins().begin(), shift.origins().end());
51 }
52
53 namespace {
54
55 /// Convert fir.embox to the extended form where necessary.
56 ///
57 /// The embox operation can take arguments that specify multidimensional array
58 /// properties at runtime. These properties may be shared between distinct
59 /// objects that have the same properties. Before we lower these small DAGs to
60 /// LLVM-IR, we gather all the information into a single extended operation. For
61 /// example,
62 /// ```
63 /// %1 = fir.shape_shift %4, %5 : (index, index) -> !fir.shapeshift<1>
64 /// %2 = fir.slice %6, %7, %8 : (index, index, index) -> !fir.slice<1>
65 /// %3 = fir.embox %0 (%1) [%2] : (!fir.ref<!fir.array<?xi32>>, !fir.shapeshift<1>, !fir.slice<1>) -> !fir.box<!fir.array<?xi32>>
66 /// ```
67 /// can be rewritten as
68 /// ```
69 /// %1 = fircg.ext_embox %0(%5) origin %4[%6, %7, %8] : (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index) -> !fir.box<!fir.array<?xi32>>
70 /// ```
71 class EmboxConversion : public mlir::OpRewritePattern<EmboxOp> {
72 public:
73 using OpRewritePattern::OpRewritePattern;
74
75 mlir::LogicalResult
matchAndRewrite(EmboxOp embox,mlir::PatternRewriter & rewriter) const76 matchAndRewrite(EmboxOp embox,
77 mlir::PatternRewriter &rewriter) const override {
78 auto shapeVal = embox.getShape();
79 // If the embox does not include a shape, then do not convert it
80 if (shapeVal)
81 return rewriteDynamicShape(embox, rewriter, shapeVal);
82 if (auto boxTy = embox.getType().dyn_cast<BoxType>())
83 if (auto seqTy = boxTy.getEleTy().dyn_cast<SequenceType>())
84 if (seqTy.hasConstantShape())
85 return rewriteStaticShape(embox, rewriter, seqTy);
86 return mlir::failure();
87 }
88
rewriteStaticShape(EmboxOp embox,mlir::PatternRewriter & rewriter,SequenceType seqTy) const89 mlir::LogicalResult rewriteStaticShape(EmboxOp embox,
90 mlir::PatternRewriter &rewriter,
91 SequenceType seqTy) const {
92 auto loc = embox.getLoc();
93 llvm::SmallVector<mlir::Value> shapeOpers;
94 auto idxTy = rewriter.getIndexType();
95 for (auto ext : seqTy.getShape()) {
96 auto iAttr = rewriter.getIndexAttr(ext);
97 auto extVal = rewriter.create<mlir::ConstantOp>(loc, idxTy, iAttr);
98 shapeOpers.push_back(extVal);
99 }
100 auto xbox = rewriter.create<cg::XEmboxOp>(
101 loc, embox.getType(), embox.memref(), shapeOpers, llvm::None,
102 llvm::None, llvm::None, embox.lenParams());
103 LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n');
104 rewriter.replaceOp(embox, xbox.getOperation()->getResults());
105 return mlir::success();
106 }
107
rewriteDynamicShape(EmboxOp embox,mlir::PatternRewriter & rewriter,mlir::Value shapeVal) const108 mlir::LogicalResult rewriteDynamicShape(EmboxOp embox,
109 mlir::PatternRewriter &rewriter,
110 mlir::Value shapeVal) const {
111 auto loc = embox.getLoc();
112 auto shapeOp = dyn_cast<ShapeOp>(shapeVal.getDefiningOp());
113 llvm::SmallVector<mlir::Value> shapeOpers;
114 llvm::SmallVector<mlir::Value> shiftOpers;
115 if (shapeOp) {
116 populateShape(shapeOpers, shapeOp);
117 } else {
118 auto shiftOp = dyn_cast<ShapeShiftOp>(shapeVal.getDefiningOp());
119 assert(shiftOp && "shape is neither fir.shape nor fir.shape_shift");
120 populateShapeAndShift(shapeOpers, shiftOpers, shiftOp);
121 }
122 llvm::SmallVector<mlir::Value> sliceOpers;
123 llvm::SmallVector<mlir::Value> subcompOpers;
124 if (auto s = embox.getSlice())
125 if (auto sliceOp = dyn_cast_or_null<SliceOp>(s.getDefiningOp())) {
126 sliceOpers.append(sliceOp.triples().begin(), sliceOp.triples().end());
127 subcompOpers.append(sliceOp.fields().begin(), sliceOp.fields().end());
128 }
129 auto xbox = rewriter.create<cg::XEmboxOp>(
130 loc, embox.getType(), embox.memref(), shapeOpers, shiftOpers,
131 sliceOpers, subcompOpers, embox.lenParams());
132 LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n');
133 rewriter.replaceOp(embox, xbox.getOperation()->getResults());
134 return mlir::success();
135 }
136 };
137
138 /// Convert fir.rebox to the extended form where necessary.
139 ///
140 /// For example,
141 /// ```
142 /// %5 = fir.rebox %3(%1) : (!fir.box<!fir.array<?xi32>>, !fir.shapeshift<1>) -> !fir.box<!fir.array<?xi32>>
143 /// ```
144 /// converted to
145 /// ```
146 /// %5 = fircg.ext_rebox %3(%13) origin %12 : (!fir.box<!fir.array<?xi32>>, index, index) -> !fir.box<!fir.array<?xi32>>
147 /// ```
148 class ReboxConversion : public mlir::OpRewritePattern<ReboxOp> {
149 public:
150 using OpRewritePattern::OpRewritePattern;
151
152 mlir::LogicalResult
matchAndRewrite(ReboxOp rebox,mlir::PatternRewriter & rewriter) const153 matchAndRewrite(ReboxOp rebox,
154 mlir::PatternRewriter &rewriter) const override {
155 auto loc = rebox.getLoc();
156 llvm::SmallVector<mlir::Value> shapeOpers;
157 llvm::SmallVector<mlir::Value> shiftOpers;
158 if (auto shapeVal = rebox.shape()) {
159 if (auto shapeOp = dyn_cast<ShapeOp>(shapeVal.getDefiningOp()))
160 populateShape(shapeOpers, shapeOp);
161 else if (auto shiftOp = dyn_cast<ShapeShiftOp>(shapeVal.getDefiningOp()))
162 populateShapeAndShift(shapeOpers, shiftOpers, shiftOp);
163 else if (auto shiftOp = dyn_cast<ShiftOp>(shapeVal.getDefiningOp()))
164 populateShift(shiftOpers, shiftOp);
165 else
166 return mlir::failure();
167 }
168 llvm::SmallVector<mlir::Value> sliceOpers;
169 llvm::SmallVector<mlir::Value> subcompOpers;
170 if (auto s = rebox.slice())
171 if (auto sliceOp = dyn_cast_or_null<SliceOp>(s.getDefiningOp())) {
172 sliceOpers.append(sliceOp.triples().begin(), sliceOp.triples().end());
173 subcompOpers.append(sliceOp.fields().begin(), sliceOp.fields().end());
174 }
175
176 auto xRebox = rewriter.create<cg::XReboxOp>(
177 loc, rebox.getType(), rebox.box(), shapeOpers, shiftOpers, sliceOpers,
178 subcompOpers);
179 LLVM_DEBUG(llvm::dbgs()
180 << "rewriting " << rebox << " to " << xRebox << '\n');
181 rewriter.replaceOp(rebox, xRebox.getOperation()->getResults());
182 return mlir::success();
183 }
184 };
185
186 /// Convert all fir.array_coor to the extended form.
187 ///
188 /// For example,
189 /// ```
190 /// %4 = fir.array_coor %addr (%1) [%2] %0 : (!fir.ref<!fir.array<?xi32>>, !fir.shapeshift<1>, !fir.slice<1>, index) -> !fir.ref<i32>
191 /// ```
192 /// converted to
193 /// ```
194 /// %40 = fircg.ext_array_coor %addr(%9) origin %8[%4, %5, %6<%39> : (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index, index) -> !fir.ref<i32>
195 /// ```
196 class ArrayCoorConversion : public mlir::OpRewritePattern<ArrayCoorOp> {
197 public:
198 using OpRewritePattern::OpRewritePattern;
199
200 mlir::LogicalResult
matchAndRewrite(ArrayCoorOp arrCoor,mlir::PatternRewriter & rewriter) const201 matchAndRewrite(ArrayCoorOp arrCoor,
202 mlir::PatternRewriter &rewriter) const override {
203 auto loc = arrCoor.getLoc();
204 llvm::SmallVector<mlir::Value> shapeOpers;
205 llvm::SmallVector<mlir::Value> shiftOpers;
206 if (auto shapeVal = arrCoor.shape()) {
207 if (auto shapeOp = dyn_cast<ShapeOp>(shapeVal.getDefiningOp()))
208 populateShape(shapeOpers, shapeOp);
209 else if (auto shiftOp = dyn_cast<ShapeShiftOp>(shapeVal.getDefiningOp()))
210 populateShapeAndShift(shapeOpers, shiftOpers, shiftOp);
211 else if (auto shiftOp = dyn_cast<ShiftOp>(shapeVal.getDefiningOp()))
212 populateShift(shiftOpers, shiftOp);
213 else
214 return mlir::failure();
215 }
216 llvm::SmallVector<mlir::Value> sliceOpers;
217 llvm::SmallVector<mlir::Value> subcompOpers;
218 if (auto s = arrCoor.slice())
219 if (auto sliceOp = dyn_cast_or_null<SliceOp>(s.getDefiningOp())) {
220 sliceOpers.append(sliceOp.triples().begin(), sliceOp.triples().end());
221 subcompOpers.append(sliceOp.fields().begin(), sliceOp.fields().end());
222 }
223 auto xArrCoor = rewriter.create<cg::XArrayCoorOp>(
224 loc, arrCoor.getType(), arrCoor.memref(), shapeOpers, shiftOpers,
225 sliceOpers, subcompOpers, arrCoor.indices(), arrCoor.lenParams());
226 LLVM_DEBUG(llvm::dbgs()
227 << "rewriting " << arrCoor << " to " << xArrCoor << '\n');
228 rewriter.replaceOp(arrCoor, xArrCoor.getOperation()->getResults());
229 return mlir::success();
230 }
231 };
232
233 class CodeGenRewrite : public CodeGenRewriteBase<CodeGenRewrite> {
234 public:
runOnOperation()235 void runOnOperation() override final {
236 auto op = getOperation();
237 auto &context = getContext();
238 mlir::OpBuilder rewriter(&context);
239 mlir::ConversionTarget target(context);
240 target.addLegalDialect<FIROpsDialect, FIRCodeGenDialect,
241 mlir::StandardOpsDialect>();
242 target.addIllegalOp<ArrayCoorOp>();
243 target.addIllegalOp<ReboxOp>();
244 target.addDynamicallyLegalOp<EmboxOp>([](EmboxOp embox) {
245 return !(embox.getShape() ||
246 embox.getType().cast<BoxType>().getEleTy().isa<SequenceType>());
247 });
248 mlir::OwningRewritePatternList patterns(&context);
249 patterns.insert<EmboxConversion, ArrayCoorConversion, ReboxConversion>(
250 &context);
251 if (mlir::failed(
252 mlir::applyPartialConversion(op, target, std::move(patterns)))) {
253 mlir::emitError(mlir::UnknownLoc::get(&context),
254 "error in running the pre-codegen conversions");
255 signalPassFailure();
256 }
257 }
258 };
259
260 } // namespace
261
createFirCodeGenRewritePass()262 std::unique_ptr<mlir::Pass> fir::createFirCodeGenRewritePass() {
263 return std::make_unique<CodeGenRewrite>();
264 }
265