1 //===-- FIROps.cpp --------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "flang/Optimizer/Dialect/FIROps.h"
14 #include "flang/Optimizer/Dialect/FIRAttr.h"
15 #include "flang/Optimizer/Dialect/FIROpsSupport.h"
16 #include "flang/Optimizer/Dialect/FIRType.h"
17 #include "mlir/Dialect/CommonFolders.h"
18 #include "mlir/Dialect/StandardOps/IR/Ops.h"
19 #include "mlir/IR/BuiltinOps.h"
20 #include "mlir/IR/Diagnostics.h"
21 #include "mlir/IR/Matchers.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "llvm/ADT/StringSwitch.h"
24 #include "llvm/ADT/TypeSwitch.h"
25
26 using namespace fir;
27
28 /// Return true if a sequence type is of some incomplete size or a record type
29 /// is malformed or contains an incomplete sequence type. An incomplete sequence
30 /// type is one with more unknown extents in the type than have been provided
31 /// via `dynamicExtents`. Sequence types with an unknown rank are incomplete by
32 /// definition.
verifyInType(mlir::Type inType,llvm::SmallVectorImpl<llvm::StringRef> & visited,unsigned dynamicExtents=0)33 static bool verifyInType(mlir::Type inType,
34 llvm::SmallVectorImpl<llvm::StringRef> &visited,
35 unsigned dynamicExtents = 0) {
36 if (auto st = inType.dyn_cast<fir::SequenceType>()) {
37 auto shape = st.getShape();
38 if (shape.size() == 0)
39 return true;
40 for (std::size_t i = 0, end{shape.size()}; i < end; ++i) {
41 if (shape[i] != fir::SequenceType::getUnknownExtent())
42 continue;
43 if (dynamicExtents-- == 0)
44 return true;
45 }
46 } else if (auto rt = inType.dyn_cast<fir::RecordType>()) {
47 // don't recurse if we're already visiting this one
48 if (llvm::is_contained(visited, rt.getName()))
49 return false;
50 // keep track of record types currently being visited
51 visited.push_back(rt.getName());
52 for (auto &field : rt.getTypeList())
53 if (verifyInType(field.second, visited))
54 return true;
55 visited.pop_back();
56 } else if (auto rt = inType.dyn_cast<fir::PointerType>()) {
57 return verifyInType(rt.getEleTy(), visited);
58 }
59 return false;
60 }
61
verifyRecordLenParams(mlir::Type inType,unsigned numLenParams)62 static bool verifyRecordLenParams(mlir::Type inType, unsigned numLenParams) {
63 if (numLenParams > 0) {
64 if (auto rt = inType.dyn_cast<fir::RecordType>())
65 return numLenParams != rt.getNumLenParams();
66 return true;
67 }
68 return false;
69 }
70
71 //===----------------------------------------------------------------------===//
72 // AllocaOp
73 //===----------------------------------------------------------------------===//
74
getAllocatedType()75 mlir::Type fir::AllocaOp::getAllocatedType() {
76 return getType().cast<ReferenceType>().getEleTy();
77 }
78
79 /// Create a legal memory reference as return type
wrapResultType(mlir::Type intype)80 mlir::Type fir::AllocaOp::wrapResultType(mlir::Type intype) {
81 // FIR semantics: memory references to memory references are disallowed
82 if (intype.isa<ReferenceType>())
83 return {};
84 return ReferenceType::get(intype);
85 }
86
getRefTy(mlir::Type ty)87 mlir::Type fir::AllocaOp::getRefTy(mlir::Type ty) {
88 return ReferenceType::get(ty);
89 }
90
91 //===----------------------------------------------------------------------===//
92 // AllocMemOp
93 //===----------------------------------------------------------------------===//
94
getAllocatedType()95 mlir::Type fir::AllocMemOp::getAllocatedType() {
96 return getType().cast<HeapType>().getEleTy();
97 }
98
getRefTy(mlir::Type ty)99 mlir::Type fir::AllocMemOp::getRefTy(mlir::Type ty) {
100 return HeapType::get(ty);
101 }
102
103 /// Create a legal heap reference as return type
wrapResultType(mlir::Type intype)104 mlir::Type fir::AllocMemOp::wrapResultType(mlir::Type intype) {
105 // Fortran semantics: C852 an entity cannot be both ALLOCATABLE and POINTER
106 // 8.5.3 note 1 prohibits ALLOCATABLE procedures as well
107 // FIR semantics: one may not allocate a memory reference value
108 if (intype.isa<ReferenceType>() || intype.isa<HeapType>() ||
109 intype.isa<PointerType>() || intype.isa<FunctionType>())
110 return {};
111 return HeapType::get(intype);
112 }
113
114 //===----------------------------------------------------------------------===//
115 // ArrayCoorOp
116 //===----------------------------------------------------------------------===//
117
verify(fir::ArrayCoorOp op)118 static mlir::LogicalResult verify(fir::ArrayCoorOp op) {
119 auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(op.memref().getType());
120 auto arrTy = eleTy.dyn_cast<fir::SequenceType>();
121 if (!arrTy)
122 return op.emitOpError("must be a reference to an array");
123 auto arrDim = arrTy.getDimension();
124
125 if (auto shapeOp = op.shape()) {
126 auto shapeTy = shapeOp.getType();
127 unsigned shapeTyRank = 0;
128 if (auto s = shapeTy.dyn_cast<fir::ShapeType>()) {
129 shapeTyRank = s.getRank();
130 } else if (auto ss = shapeTy.dyn_cast<fir::ShapeShiftType>()) {
131 shapeTyRank = ss.getRank();
132 } else {
133 auto s = shapeTy.cast<fir::ShiftType>();
134 shapeTyRank = s.getRank();
135 if (!op.memref().getType().isa<fir::BoxType>())
136 return op.emitOpError("shift can only be provided with fir.box memref");
137 }
138 if (arrDim && arrDim != shapeTyRank)
139 return op.emitOpError("rank of dimension mismatched");
140 if (shapeTyRank != op.indices().size())
141 return op.emitOpError("number of indices do not match dim rank");
142 }
143
144 if (auto sliceOp = op.slice())
145 if (auto sliceTy = sliceOp.getType().dyn_cast<fir::SliceType>())
146 if (sliceTy.getRank() != arrDim)
147 return op.emitOpError("rank of dimension in slice mismatched");
148
149 return mlir::success();
150 }
151
152 //===----------------------------------------------------------------------===//
153 // ArrayLoadOp
154 //===----------------------------------------------------------------------===//
155
getExtents()156 std::vector<mlir::Value> fir::ArrayLoadOp::getExtents() {
157 if (auto sh = shape())
158 if (auto *op = sh.getDefiningOp()) {
159 if (auto shOp = dyn_cast<fir::ShapeOp>(op))
160 return shOp.getExtents();
161 return cast<fir::ShapeShiftOp>(op).getExtents();
162 }
163 return {};
164 }
165
verify(fir::ArrayLoadOp op)166 static mlir::LogicalResult verify(fir::ArrayLoadOp op) {
167 auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(op.memref().getType());
168 auto arrTy = eleTy.dyn_cast<fir::SequenceType>();
169 if (!arrTy)
170 return op.emitOpError("must be a reference to an array");
171 auto arrDim = arrTy.getDimension();
172
173 if (auto shapeOp = op.shape()) {
174 auto shapeTy = shapeOp.getType();
175 unsigned shapeTyRank = 0;
176 if (auto s = shapeTy.dyn_cast<fir::ShapeType>()) {
177 shapeTyRank = s.getRank();
178 } else if (auto ss = shapeTy.dyn_cast<fir::ShapeShiftType>()) {
179 shapeTyRank = ss.getRank();
180 } else {
181 auto s = shapeTy.cast<fir::ShiftType>();
182 shapeTyRank = s.getRank();
183 if (!op.memref().getType().isa<fir::BoxType>())
184 return op.emitOpError("shift can only be provided with fir.box memref");
185 }
186 if (arrDim && arrDim != shapeTyRank)
187 return op.emitOpError("rank of dimension mismatched");
188 }
189
190 if (auto sliceOp = op.slice())
191 if (auto sliceTy = sliceOp.getType().dyn_cast<fir::SliceType>())
192 if (sliceTy.getRank() != arrDim)
193 return op.emitOpError("rank of dimension in slice mismatched");
194
195 return mlir::success();
196 }
197
198 //===----------------------------------------------------------------------===//
199 // BoxAddrOp
200 //===----------------------------------------------------------------------===//
201
fold(llvm::ArrayRef<mlir::Attribute> opnds)202 mlir::OpFoldResult fir::BoxAddrOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) {
203 if (auto v = val().getDefiningOp()) {
204 if (auto box = dyn_cast<fir::EmboxOp>(v))
205 return box.memref();
206 if (auto box = dyn_cast<fir::EmboxCharOp>(v))
207 return box.memref();
208 }
209 return {};
210 }
211
212 //===----------------------------------------------------------------------===//
213 // BoxCharLenOp
214 //===----------------------------------------------------------------------===//
215
216 mlir::OpFoldResult
fold(llvm::ArrayRef<mlir::Attribute> opnds)217 fir::BoxCharLenOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) {
218 if (auto v = val().getDefiningOp()) {
219 if (auto box = dyn_cast<fir::EmboxCharOp>(v))
220 return box.len();
221 }
222 return {};
223 }
224
225 //===----------------------------------------------------------------------===//
226 // BoxDimsOp
227 //===----------------------------------------------------------------------===//
228
229 /// Get the result types packed in a tuple tuple
getTupleType()230 mlir::Type fir::BoxDimsOp::getTupleType() {
231 // note: triple, but 4 is nearest power of 2
232 llvm::SmallVector<mlir::Type, 4> triple{
233 getResult(0).getType(), getResult(1).getType(), getResult(2).getType()};
234 return mlir::TupleType::get(getContext(), triple);
235 }
236
237 //===----------------------------------------------------------------------===//
238 // CallOp
239 //===----------------------------------------------------------------------===//
240
getFunctionType()241 mlir::FunctionType fir::CallOp::getFunctionType() {
242 return mlir::FunctionType::get(getContext(), getOperandTypes(),
243 getResultTypes());
244 }
245
printCallOp(mlir::OpAsmPrinter & p,fir::CallOp & op)246 static void printCallOp(mlir::OpAsmPrinter &p, fir::CallOp &op) {
247 auto callee = op.callee();
248 bool isDirect = callee.hasValue();
249 p << op.getOperationName() << ' ';
250 if (isDirect)
251 p << callee.getValue();
252 else
253 p << op.getOperand(0);
254 p << '(' << op->getOperands().drop_front(isDirect ? 0 : 1) << ')';
255 p.printOptionalAttrDict(op->getAttrs(), {"callee"});
256 auto resultTypes{op.getResultTypes()};
257 llvm::SmallVector<Type, 8> argTypes(
258 llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1));
259 p << " : " << FunctionType::get(op.getContext(), argTypes, resultTypes);
260 }
261
parseCallOp(mlir::OpAsmParser & parser,mlir::OperationState & result)262 static mlir::ParseResult parseCallOp(mlir::OpAsmParser &parser,
263 mlir::OperationState &result) {
264 llvm::SmallVector<mlir::OpAsmParser::OperandType, 8> operands;
265 if (parser.parseOperandList(operands))
266 return mlir::failure();
267
268 mlir::NamedAttrList attrs;
269 mlir::SymbolRefAttr funcAttr;
270 bool isDirect = operands.empty();
271 if (isDirect)
272 if (parser.parseAttribute(funcAttr, "callee", attrs))
273 return mlir::failure();
274
275 Type type;
276 if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::Paren) ||
277 parser.parseOptionalAttrDict(attrs) || parser.parseColon() ||
278 parser.parseType(type))
279 return mlir::failure();
280
281 auto funcType = type.dyn_cast<mlir::FunctionType>();
282 if (!funcType)
283 return parser.emitError(parser.getNameLoc(), "expected function type");
284 if (isDirect) {
285 if (parser.resolveOperands(operands, funcType.getInputs(),
286 parser.getNameLoc(), result.operands))
287 return mlir::failure();
288 } else {
289 auto funcArgs =
290 llvm::ArrayRef<mlir::OpAsmParser::OperandType>(operands).drop_front();
291 if (parser.resolveOperand(operands[0], funcType, result.operands) ||
292 parser.resolveOperands(funcArgs, funcType.getInputs(),
293 parser.getNameLoc(), result.operands))
294 return mlir::failure();
295 }
296 result.addTypes(funcType.getResults());
297 result.attributes = attrs;
298 return mlir::success();
299 }
300
301 //===----------------------------------------------------------------------===//
302 // CmpfOp
303 //===----------------------------------------------------------------------===//
304
305 // Note: getCmpFPredicateNames() is inline static in StandardOps/IR/Ops.cpp
getPredicateByName(llvm::StringRef name)306 mlir::CmpFPredicate fir::CmpfOp::getPredicateByName(llvm::StringRef name) {
307 auto pred = mlir::symbolizeCmpFPredicate(name);
308 assert(pred.hasValue() && "invalid predicate name");
309 return pred.getValue();
310 }
311
buildCmpFOp(OpBuilder & builder,OperationState & result,CmpFPredicate predicate,Value lhs,Value rhs)312 void fir::buildCmpFOp(OpBuilder &builder, OperationState &result,
313 CmpFPredicate predicate, Value lhs, Value rhs) {
314 result.addOperands({lhs, rhs});
315 result.types.push_back(builder.getI1Type());
316 result.addAttribute(
317 CmpfOp::getPredicateAttrName(),
318 builder.getI64IntegerAttr(static_cast<int64_t>(predicate)));
319 }
320
321 template <typename OPTY>
printCmpOp(OpAsmPrinter & p,OPTY op)322 static void printCmpOp(OpAsmPrinter &p, OPTY op) {
323 p << op.getOperationName() << ' ';
324 auto predSym = mlir::symbolizeCmpFPredicate(
325 op->template getAttrOfType<mlir::IntegerAttr>(
326 OPTY::getPredicateAttrName())
327 .getInt());
328 assert(predSym.hasValue() && "invalid symbol value for predicate");
329 p << '"' << mlir::stringifyCmpFPredicate(predSym.getValue()) << '"' << ", ";
330 p.printOperand(op.lhs());
331 p << ", ";
332 p.printOperand(op.rhs());
333 p.printOptionalAttrDict(op->getAttrs(),
334 /*elidedAttrs=*/{OPTY::getPredicateAttrName()});
335 p << " : " << op.lhs().getType();
336 }
337
printCmpfOp(OpAsmPrinter & p,CmpfOp op)338 static void printCmpfOp(OpAsmPrinter &p, CmpfOp op) { printCmpOp(p, op); }
339
340 template <typename OPTY>
parseCmpOp(mlir::OpAsmParser & parser,mlir::OperationState & result)341 static mlir::ParseResult parseCmpOp(mlir::OpAsmParser &parser,
342 mlir::OperationState &result) {
343 llvm::SmallVector<mlir::OpAsmParser::OperandType, 2> ops;
344 mlir::NamedAttrList attrs;
345 mlir::Attribute predicateNameAttr;
346 mlir::Type type;
347 if (parser.parseAttribute(predicateNameAttr, OPTY::getPredicateAttrName(),
348 attrs) ||
349 parser.parseComma() || parser.parseOperandList(ops, 2) ||
350 parser.parseOptionalAttrDict(attrs) || parser.parseColonType(type) ||
351 parser.resolveOperands(ops, type, result.operands))
352 return failure();
353
354 if (!predicateNameAttr.isa<mlir::StringAttr>())
355 return parser.emitError(parser.getNameLoc(),
356 "expected string comparison predicate attribute");
357
358 // Rewrite string attribute to an enum value.
359 llvm::StringRef predicateName =
360 predicateNameAttr.cast<mlir::StringAttr>().getValue();
361 auto predicate = fir::CmpfOp::getPredicateByName(predicateName);
362 auto builder = parser.getBuilder();
363 mlir::Type i1Type = builder.getI1Type();
364 attrs.set(OPTY::getPredicateAttrName(),
365 builder.getI64IntegerAttr(static_cast<int64_t>(predicate)));
366 result.attributes = attrs;
367 result.addTypes({i1Type});
368 return success();
369 }
370
parseCmpfOp(mlir::OpAsmParser & parser,mlir::OperationState & result)371 mlir::ParseResult fir::parseCmpfOp(mlir::OpAsmParser &parser,
372 mlir::OperationState &result) {
373 return parseCmpOp<fir::CmpfOp>(parser, result);
374 }
375
376 //===----------------------------------------------------------------------===//
377 // CmpcOp
378 //===----------------------------------------------------------------------===//
379
buildCmpCOp(OpBuilder & builder,OperationState & result,CmpFPredicate predicate,Value lhs,Value rhs)380 void fir::buildCmpCOp(OpBuilder &builder, OperationState &result,
381 CmpFPredicate predicate, Value lhs, Value rhs) {
382 result.addOperands({lhs, rhs});
383 result.types.push_back(builder.getI1Type());
384 result.addAttribute(
385 fir::CmpcOp::getPredicateAttrName(),
386 builder.getI64IntegerAttr(static_cast<int64_t>(predicate)));
387 }
388
printCmpcOp(OpAsmPrinter & p,fir::CmpcOp op)389 static void printCmpcOp(OpAsmPrinter &p, fir::CmpcOp op) { printCmpOp(p, op); }
390
parseCmpcOp(mlir::OpAsmParser & parser,mlir::OperationState & result)391 mlir::ParseResult fir::parseCmpcOp(mlir::OpAsmParser &parser,
392 mlir::OperationState &result) {
393 return parseCmpOp<fir::CmpcOp>(parser, result);
394 }
395
396 //===----------------------------------------------------------------------===//
397 // ConvertOp
398 //===----------------------------------------------------------------------===//
399
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)400 void fir::ConvertOp::getCanonicalizationPatterns(
401 OwningRewritePatternList &results, MLIRContext *context) {
402 }
403
fold(llvm::ArrayRef<mlir::Attribute> opnds)404 mlir::OpFoldResult fir::ConvertOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) {
405 if (value().getType() == getType())
406 return value();
407 if (matchPattern(value(), m_Op<fir::ConvertOp>())) {
408 auto inner = cast<fir::ConvertOp>(value().getDefiningOp());
409 // (convert (convert 'a : logical -> i1) : i1 -> logical) ==> forward 'a
410 if (auto toTy = getType().dyn_cast<fir::LogicalType>())
411 if (auto fromTy = inner.value().getType().dyn_cast<fir::LogicalType>())
412 if (inner.getType().isa<mlir::IntegerType>() && (toTy == fromTy))
413 return inner.value();
414 // (convert (convert 'a : i1 -> logical) : logical -> i1) ==> forward 'a
415 if (auto toTy = getType().dyn_cast<mlir::IntegerType>())
416 if (auto fromTy = inner.value().getType().dyn_cast<mlir::IntegerType>())
417 if (inner.getType().isa<fir::LogicalType>() && (toTy == fromTy) &&
418 (fromTy.getWidth() == 1))
419 return inner.value();
420 }
421 return {};
422 }
423
isIntegerCompatible(mlir::Type ty)424 bool fir::ConvertOp::isIntegerCompatible(mlir::Type ty) {
425 return ty.isa<mlir::IntegerType>() || ty.isa<mlir::IndexType>() ||
426 ty.isa<fir::IntegerType>() || ty.isa<fir::LogicalType>();
427 }
428
isFloatCompatible(mlir::Type ty)429 bool fir::ConvertOp::isFloatCompatible(mlir::Type ty) {
430 return ty.isa<mlir::FloatType>() || ty.isa<fir::RealType>();
431 }
432
isPointerCompatible(mlir::Type ty)433 bool fir::ConvertOp::isPointerCompatible(mlir::Type ty) {
434 return ty.isa<fir::ReferenceType>() || ty.isa<fir::PointerType>() ||
435 ty.isa<fir::HeapType>() || ty.isa<mlir::MemRefType>() ||
436 ty.isa<mlir::FunctionType>() || ty.isa<fir::TypeDescType>();
437 }
438
439 //===----------------------------------------------------------------------===//
440 // CoordinateOp
441 //===----------------------------------------------------------------------===//
442
print(mlir::OpAsmPrinter & p,fir::CoordinateOp op)443 static void print(mlir::OpAsmPrinter &p, fir::CoordinateOp op) {
444 p << op.getOperationName() << ' ' << op.ref() << ", " << op.coor();
445 p.printOptionalAttrDict(op->getAttrs(), /*elideAttrs=*/{"baseType"});
446 p << " : ";
447 p.printFunctionalType(op.getOperandTypes(), op->getResultTypes());
448 }
449
parseCoordinateCustom(mlir::OpAsmParser & parser,mlir::OperationState & result)450 static mlir::ParseResult parseCoordinateCustom(mlir::OpAsmParser &parser,
451 mlir::OperationState &result) {
452 mlir::OpAsmParser::OperandType memref;
453 if (parser.parseOperand(memref) || parser.parseComma())
454 return mlir::failure();
455 llvm::SmallVector<mlir::OpAsmParser::OperandType, 8> coorOperands;
456 if (parser.parseOperandList(coorOperands))
457 return mlir::failure();
458 llvm::SmallVector<mlir::OpAsmParser::OperandType, 16> allOperands;
459 allOperands.push_back(memref);
460 allOperands.append(coorOperands.begin(), coorOperands.end());
461 mlir::FunctionType funcTy;
462 auto loc = parser.getCurrentLocation();
463 if (parser.parseOptionalAttrDict(result.attributes) ||
464 parser.parseColonType(funcTy) ||
465 parser.resolveOperands(allOperands, funcTy.getInputs(), loc,
466 result.operands))
467 return failure();
468 parser.addTypesToList(funcTy.getResults(), result.types);
469 result.addAttribute("baseType", mlir::TypeAttr::get(funcTy.getInput(0)));
470 return mlir::success();
471 }
472
verify(fir::CoordinateOp op)473 static mlir::LogicalResult verify(fir::CoordinateOp op) {
474 auto refTy = op.ref().getType();
475 if (fir::isa_ref_type(refTy)) {
476 auto eleTy = fir::dyn_cast_ptrEleTy(refTy);
477 if (auto arrTy = eleTy.dyn_cast<fir::SequenceType>()) {
478 if (arrTy.hasUnknownShape())
479 return op.emitOpError("cannot find coordinate in unknown shape");
480 if (arrTy.getConstantRows() < arrTy.getDimension() - 1)
481 return op.emitOpError("cannot find coordinate with unknown extents");
482 }
483 if (!(fir::isa_aggregate(eleTy) || fir::isa_complex(eleTy) ||
484 fir::isa_char_string(eleTy)))
485 return op.emitOpError("cannot apply coordinate_of to this type");
486 }
487 // Recovering a LEN type parameter only makes sense from a boxed value. For a
488 // bare reference, the LEN type parameters must be passed as additional
489 // arguments to `op`.
490 for (auto co : op.coor())
491 if (dyn_cast_or_null<fir::LenParamIndexOp>(co.getDefiningOp())) {
492 if (op.getNumOperands() != 2)
493 return op.emitOpError("len_param_index must be last argument");
494 if (!op.ref().getType().isa<BoxType>())
495 return op.emitOpError("len_param_index must be used on box type");
496 }
497 return mlir::success();
498 }
499
500 //===----------------------------------------------------------------------===//
501 // DispatchOp
502 //===----------------------------------------------------------------------===//
503
getFunctionType()504 mlir::FunctionType fir::DispatchOp::getFunctionType() {
505 return mlir::FunctionType::get(getContext(), getOperandTypes(),
506 getResultTypes());
507 }
508
509 //===----------------------------------------------------------------------===//
510 // DispatchTableOp
511 //===----------------------------------------------------------------------===//
512
appendTableEntry(mlir::Operation * op)513 void fir::DispatchTableOp::appendTableEntry(mlir::Operation *op) {
514 assert(mlir::isa<fir::DTEntryOp>(*op) && "operation must be a DTEntryOp");
515 auto &block = getBlock();
516 block.getOperations().insert(block.end(), op);
517 }
518
519 //===----------------------------------------------------------------------===//
520 // EmboxOp
521 //===----------------------------------------------------------------------===//
522
verify(fir::EmboxOp op)523 static mlir::LogicalResult verify(fir::EmboxOp op) {
524 auto eleTy = fir::dyn_cast_ptrEleTy(op.memref().getType());
525 bool isArray = false;
526 if (auto seqTy = eleTy.dyn_cast<fir::SequenceType>()) {
527 eleTy = seqTy.getEleTy();
528 isArray = true;
529 }
530 if (op.hasLenParams()) {
531 auto lenPs = op.numLenParams();
532 if (auto rt = eleTy.dyn_cast<fir::RecordType>()) {
533 if (lenPs != rt.getNumLenParams())
534 return op.emitOpError("number of LEN params does not correspond"
535 " to the !fir.type type");
536 } else if (auto strTy = eleTy.dyn_cast<fir::CharacterType>()) {
537 if (strTy.getLen() != fir::CharacterType::unknownLen())
538 return op.emitOpError("CHARACTER already has static LEN");
539 } else {
540 return op.emitOpError("LEN parameters require CHARACTER or derived type");
541 }
542 for (auto lp : op.lenParams())
543 if (!fir::isa_integer(lp.getType()))
544 return op.emitOpError("LEN parameters must be integral type");
545 }
546 if (op.getShape() && !isArray)
547 return op.emitOpError("shape must not be provided for a scalar");
548 if (op.getSlice() && !isArray)
549 return op.emitOpError("slice must not be provided for a scalar");
550 return mlir::success();
551 }
552
553 //===----------------------------------------------------------------------===//
554 // GenTypeDescOp
555 //===----------------------------------------------------------------------===//
556
build(OpBuilder &,OperationState & result,mlir::TypeAttr inty)557 void fir::GenTypeDescOp::build(OpBuilder &, OperationState &result,
558 mlir::TypeAttr inty) {
559 result.addAttribute("in_type", inty);
560 result.addTypes(TypeDescType::get(inty.getValue()));
561 }
562
563 //===----------------------------------------------------------------------===//
564 // GlobalOp
565 //===----------------------------------------------------------------------===//
566
parseGlobalOp(OpAsmParser & parser,OperationState & result)567 static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) {
568 // Parse the optional linkage
569 llvm::StringRef linkage;
570 auto &builder = parser.getBuilder();
571 if (mlir::succeeded(parser.parseOptionalKeyword(&linkage))) {
572 if (fir::GlobalOp::verifyValidLinkage(linkage))
573 return mlir::failure();
574 mlir::StringAttr linkAttr = builder.getStringAttr(linkage);
575 result.addAttribute(fir::GlobalOp::linkageAttrName(), linkAttr);
576 }
577
578 // Parse the name as a symbol reference attribute.
579 mlir::SymbolRefAttr nameAttr;
580 if (parser.parseAttribute(nameAttr, fir::GlobalOp::symbolAttrName(),
581 result.attributes))
582 return mlir::failure();
583 result.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
584 builder.getStringAttr(nameAttr.getRootReference()));
585
586 bool simpleInitializer = false;
587 if (mlir::succeeded(parser.parseOptionalLParen())) {
588 Attribute attr;
589 if (parser.parseAttribute(attr, "initVal", result.attributes) ||
590 parser.parseRParen())
591 return mlir::failure();
592 simpleInitializer = true;
593 }
594
595 if (succeeded(parser.parseOptionalKeyword("constant"))) {
596 // if "constant" keyword then mark this as a constant, not a variable
597 result.addAttribute("constant", builder.getUnitAttr());
598 }
599
600 mlir::Type globalType;
601 if (parser.parseColonType(globalType))
602 return mlir::failure();
603
604 result.addAttribute(fir::GlobalOp::typeAttrName(result.name),
605 mlir::TypeAttr::get(globalType));
606
607 if (simpleInitializer) {
608 result.addRegion();
609 } else {
610 // Parse the optional initializer body.
611 auto parseResult = parser.parseOptionalRegion(
612 *result.addRegion(), /*arguments=*/llvm::None, /*argTypes=*/llvm::None);
613 if (parseResult.hasValue() && mlir::failed(*parseResult))
614 return mlir::failure();
615 }
616
617 return mlir::success();
618 }
619
appendInitialValue(mlir::Operation * op)620 void fir::GlobalOp::appendInitialValue(mlir::Operation *op) {
621 getBlock().getOperations().push_back(op);
622 }
623
build(mlir::OpBuilder & builder,OperationState & result,StringRef name,bool isConstant,Type type,Attribute initialVal,StringAttr linkage,ArrayRef<NamedAttribute> attrs)624 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result,
625 StringRef name, bool isConstant, Type type,
626 Attribute initialVal, StringAttr linkage,
627 ArrayRef<NamedAttribute> attrs) {
628 result.addRegion();
629 result.addAttribute(typeAttrName(result.name), mlir::TypeAttr::get(type));
630 result.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
631 builder.getStringAttr(name));
632 result.addAttribute(symbolAttrName(), builder.getSymbolRefAttr(name));
633 if (isConstant)
634 result.addAttribute(constantAttrName(result.name), builder.getUnitAttr());
635 if (initialVal)
636 result.addAttribute(initValAttrName(result.name), initialVal);
637 if (linkage)
638 result.addAttribute(linkageAttrName(), linkage);
639 result.attributes.append(attrs.begin(), attrs.end());
640 }
641
build(mlir::OpBuilder & builder,OperationState & result,StringRef name,Type type,Attribute initialVal,StringAttr linkage,ArrayRef<NamedAttribute> attrs)642 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result,
643 StringRef name, Type type, Attribute initialVal,
644 StringAttr linkage, ArrayRef<NamedAttribute> attrs) {
645 build(builder, result, name, /*isConstant=*/false, type, {}, linkage, attrs);
646 }
647
build(mlir::OpBuilder & builder,OperationState & result,StringRef name,bool isConstant,Type type,StringAttr linkage,ArrayRef<NamedAttribute> attrs)648 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result,
649 StringRef name, bool isConstant, Type type,
650 StringAttr linkage, ArrayRef<NamedAttribute> attrs) {
651 build(builder, result, name, isConstant, type, {}, linkage, attrs);
652 }
653
build(mlir::OpBuilder & builder,OperationState & result,StringRef name,Type type,StringAttr linkage,ArrayRef<NamedAttribute> attrs)654 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result,
655 StringRef name, Type type, StringAttr linkage,
656 ArrayRef<NamedAttribute> attrs) {
657 build(builder, result, name, /*isConstant=*/false, type, {}, linkage, attrs);
658 }
659
build(mlir::OpBuilder & builder,OperationState & result,StringRef name,bool isConstant,Type type,ArrayRef<NamedAttribute> attrs)660 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result,
661 StringRef name, bool isConstant, Type type,
662 ArrayRef<NamedAttribute> attrs) {
663 build(builder, result, name, isConstant, type, StringAttr{}, attrs);
664 }
665
build(mlir::OpBuilder & builder,OperationState & result,StringRef name,Type type,ArrayRef<NamedAttribute> attrs)666 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result,
667 StringRef name, Type type,
668 ArrayRef<NamedAttribute> attrs) {
669 build(builder, result, name, /*isConstant=*/false, type, attrs);
670 }
671
verifyValidLinkage(StringRef linkage)672 mlir::ParseResult fir::GlobalOp::verifyValidLinkage(StringRef linkage) {
673 // Supporting only a subset of the LLVM linkage types for now
674 static const char *validNames[] = {"common", "internal", "linkonce", "weak"};
675 return mlir::success(llvm::is_contained(validNames, linkage));
676 }
677
678 //===----------------------------------------------------------------------===//
679 // InsertValueOp
680 //===----------------------------------------------------------------------===//
681
checkIsIntegerConstant(mlir::Value v,int64_t conVal)682 static bool checkIsIntegerConstant(mlir::Value v, int64_t conVal) {
683 if (auto c = dyn_cast_or_null<mlir::ConstantOp>(v.getDefiningOp())) {
684 auto attr = c.getValue();
685 if (auto iattr = attr.dyn_cast<mlir::IntegerAttr>())
686 return iattr.getInt() == conVal;
687 }
688 return false;
689 }
isZero(mlir::Value v)690 static bool isZero(mlir::Value v) { return checkIsIntegerConstant(v, 0); }
isOne(mlir::Value v)691 static bool isOne(mlir::Value v) { return checkIsIntegerConstant(v, 1); }
692
693 // Undo some complex patterns created in the front-end and turn them back into
694 // complex ops.
695 template <typename FltOp, typename CpxOp>
696 struct UndoComplexPattern : public mlir::RewritePattern {
UndoComplexPatternUndoComplexPattern697 UndoComplexPattern(mlir::MLIRContext *ctx)
698 : mlir::RewritePattern("fir.insert_value", 2, ctx) {}
699
700 mlir::LogicalResult
matchAndRewriteUndoComplexPattern701 matchAndRewrite(mlir::Operation *op,
702 mlir::PatternRewriter &rewriter) const override {
703 auto insval = dyn_cast_or_null<fir::InsertValueOp>(op);
704 if (!insval || !insval.getType().isa<fir::ComplexType>())
705 return mlir::failure();
706 auto insval2 =
707 dyn_cast_or_null<fir::InsertValueOp>(insval.adt().getDefiningOp());
708 if (!insval2 || !isa<fir::UndefOp>(insval2.adt().getDefiningOp()))
709 return mlir::failure();
710 auto binf = dyn_cast_or_null<FltOp>(insval.val().getDefiningOp());
711 auto binf2 = dyn_cast_or_null<FltOp>(insval2.val().getDefiningOp());
712 if (!binf || !binf2 || insval.coor().size() != 1 ||
713 !isOne(insval.coor()[0]) || insval2.coor().size() != 1 ||
714 !isZero(insval2.coor()[0]))
715 return mlir::failure();
716 auto eai =
717 dyn_cast_or_null<fir::ExtractValueOp>(binf.lhs().getDefiningOp());
718 auto ebi =
719 dyn_cast_or_null<fir::ExtractValueOp>(binf.rhs().getDefiningOp());
720 auto ear =
721 dyn_cast_or_null<fir::ExtractValueOp>(binf2.lhs().getDefiningOp());
722 auto ebr =
723 dyn_cast_or_null<fir::ExtractValueOp>(binf2.rhs().getDefiningOp());
724 if (!eai || !ebi || !ear || !ebr || ear.adt() != eai.adt() ||
725 ebr.adt() != ebi.adt() || eai.coor().size() != 1 ||
726 !isOne(eai.coor()[0]) || ebi.coor().size() != 1 ||
727 !isOne(ebi.coor()[0]) || ear.coor().size() != 1 ||
728 !isZero(ear.coor()[0]) || ebr.coor().size() != 1 ||
729 !isZero(ebr.coor()[0]))
730 return mlir::failure();
731 rewriter.replaceOpWithNewOp<CpxOp>(op, ear.adt(), ebr.adt());
732 return mlir::success();
733 }
734 };
735
getCanonicalizationPatterns(mlir::OwningRewritePatternList & results,mlir::MLIRContext * context)736 void fir::InsertValueOp::getCanonicalizationPatterns(
737 mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
738 results.insert<UndoComplexPattern<mlir::AddFOp, fir::AddcOp>,
739 UndoComplexPattern<mlir::SubFOp, fir::SubcOp>>(context);
740 }
741
742 //===----------------------------------------------------------------------===//
743 // IterWhileOp
744 //===----------------------------------------------------------------------===//
745
build(mlir::OpBuilder & builder,mlir::OperationState & result,mlir::Value lb,mlir::Value ub,mlir::Value step,mlir::Value iterate,bool finalCountValue,mlir::ValueRange iterArgs,llvm::ArrayRef<mlir::NamedAttribute> attributes)746 void fir::IterWhileOp::build(mlir::OpBuilder &builder,
747 mlir::OperationState &result, mlir::Value lb,
748 mlir::Value ub, mlir::Value step,
749 mlir::Value iterate, bool finalCountValue,
750 mlir::ValueRange iterArgs,
751 llvm::ArrayRef<mlir::NamedAttribute> attributes) {
752 result.addOperands({lb, ub, step, iterate});
753 if (finalCountValue) {
754 result.addTypes(builder.getIndexType());
755 result.addAttribute(finalValueAttrName(result.name), builder.getUnitAttr());
756 }
757 result.addTypes(iterate.getType());
758 result.addOperands(iterArgs);
759 for (auto v : iterArgs)
760 result.addTypes(v.getType());
761 mlir::Region *bodyRegion = result.addRegion();
762 bodyRegion->push_back(new Block{});
763 bodyRegion->front().addArgument(builder.getIndexType());
764 bodyRegion->front().addArgument(iterate.getType());
765 bodyRegion->front().addArguments(iterArgs.getTypes());
766 result.addAttributes(attributes);
767 }
768
parseIterWhileOp(mlir::OpAsmParser & parser,mlir::OperationState & result)769 static mlir::ParseResult parseIterWhileOp(mlir::OpAsmParser &parser,
770 mlir::OperationState &result) {
771 auto &builder = parser.getBuilder();
772 mlir::OpAsmParser::OperandType inductionVariable, lb, ub, step;
773 if (parser.parseLParen() || parser.parseRegionArgument(inductionVariable) ||
774 parser.parseEqual())
775 return mlir::failure();
776
777 // Parse loop bounds.
778 auto indexType = builder.getIndexType();
779 auto i1Type = builder.getIntegerType(1);
780 if (parser.parseOperand(lb) ||
781 parser.resolveOperand(lb, indexType, result.operands) ||
782 parser.parseKeyword("to") || parser.parseOperand(ub) ||
783 parser.resolveOperand(ub, indexType, result.operands) ||
784 parser.parseKeyword("step") || parser.parseOperand(step) ||
785 parser.parseRParen() ||
786 parser.resolveOperand(step, indexType, result.operands))
787 return mlir::failure();
788
789 mlir::OpAsmParser::OperandType iterateVar, iterateInput;
790 if (parser.parseKeyword("and") || parser.parseLParen() ||
791 parser.parseRegionArgument(iterateVar) || parser.parseEqual() ||
792 parser.parseOperand(iterateInput) || parser.parseRParen() ||
793 parser.resolveOperand(iterateInput, i1Type, result.operands))
794 return mlir::failure();
795
796 // Parse the initial iteration arguments.
797 llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> regionArgs;
798 auto prependCount = false;
799
800 // Induction variable.
801 regionArgs.push_back(inductionVariable);
802 regionArgs.push_back(iterateVar);
803
804 if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
805 llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> operands;
806 llvm::SmallVector<mlir::Type, 4> regionTypes;
807 // Parse assignment list and results type list.
808 if (parser.parseAssignmentList(regionArgs, operands) ||
809 parser.parseArrowTypeList(regionTypes))
810 return failure();
811 if (regionTypes.size() == operands.size() + 2)
812 prependCount = true;
813 llvm::ArrayRef<mlir::Type> resTypes = regionTypes;
814 resTypes = prependCount ? resTypes.drop_front(2) : resTypes;
815 // Resolve input operands.
816 for (auto operand_type : llvm::zip(operands, resTypes))
817 if (parser.resolveOperand(std::get<0>(operand_type),
818 std::get<1>(operand_type), result.operands))
819 return failure();
820 if (prependCount) {
821 result.addTypes(regionTypes);
822 } else {
823 result.addTypes(i1Type);
824 result.addTypes(resTypes);
825 }
826 } else if (succeeded(parser.parseOptionalArrow())) {
827 llvm::SmallVector<mlir::Type, 4> typeList;
828 if (parser.parseLParen() || parser.parseTypeList(typeList) ||
829 parser.parseRParen())
830 return failure();
831 // Type list must be "(index, i1)".
832 if (typeList.size() != 2 || !typeList[0].isa<mlir::IndexType>() ||
833 !typeList[1].isSignlessInteger(1))
834 return failure();
835 result.addTypes(typeList);
836 prependCount = true;
837 } else {
838 result.addTypes(i1Type);
839 }
840
841 if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
842 return mlir::failure();
843
844 llvm::SmallVector<mlir::Type, 4> argTypes;
845 // Induction variable (hidden)
846 if (prependCount)
847 result.addAttribute(IterWhileOp::finalValueAttrName(result.name),
848 builder.getUnitAttr());
849 else
850 argTypes.push_back(indexType);
851 // Loop carried variables (including iterate)
852 argTypes.append(result.types.begin(), result.types.end());
853 // Parse the body region.
854 auto *body = result.addRegion();
855 if (regionArgs.size() != argTypes.size())
856 return parser.emitError(
857 parser.getNameLoc(),
858 "mismatch in number of loop-carried values and defined values");
859
860 if (parser.parseRegion(*body, regionArgs, argTypes))
861 return failure();
862
863 fir::IterWhileOp::ensureTerminator(*body, builder, result.location);
864
865 return mlir::success();
866 }
867
verify(fir::IterWhileOp op)868 static mlir::LogicalResult verify(fir::IterWhileOp op) {
869 // Check that the body defines as single block argument for the induction
870 // variable.
871 auto *body = op.getBody();
872 if (!body->getArgument(1).getType().isInteger(1))
873 return op.emitOpError(
874 "expected body second argument to be an index argument for "
875 "the induction variable");
876 if (!body->getArgument(0).getType().isIndex())
877 return op.emitOpError(
878 "expected body first argument to be an index argument for "
879 "the induction variable");
880
881 auto opNumResults = op.getNumResults();
882 if (op.finalValue()) {
883 // Result type must be "(index, i1, ...)".
884 if (!op.getResult(0).getType().isa<mlir::IndexType>())
885 return op.emitOpError("result #0 expected to be index");
886 if (!op.getResult(1).getType().isSignlessInteger(1))
887 return op.emitOpError("result #1 expected to be i1");
888 opNumResults--;
889 } else {
890 // iterate_while always returns the early exit induction value.
891 // Result type must be "(i1, ...)"
892 if (!op.getResult(0).getType().isSignlessInteger(1))
893 return op.emitOpError("result #0 expected to be i1");
894 }
895 if (opNumResults == 0)
896 return mlir::failure();
897 if (op.getNumIterOperands() != opNumResults)
898 return op.emitOpError(
899 "mismatch in number of loop-carried values and defined values");
900 if (op.getNumRegionIterArgs() != opNumResults)
901 return op.emitOpError(
902 "mismatch in number of basic block args and defined values");
903 auto iterOperands = op.getIterOperands();
904 auto iterArgs = op.getRegionIterArgs();
905 auto opResults =
906 op.finalValue() ? op.getResults().drop_front() : op.getResults();
907 unsigned i = 0;
908 for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) {
909 if (std::get<0>(e).getType() != std::get<2>(e).getType())
910 return op.emitOpError() << "types mismatch between " << i
911 << "th iter operand and defined value";
912 if (std::get<1>(e).getType() != std::get<2>(e).getType())
913 return op.emitOpError() << "types mismatch between " << i
914 << "th iter region arg and defined value";
915
916 i++;
917 }
918 return mlir::success();
919 }
920
print(mlir::OpAsmPrinter & p,fir::IterWhileOp op)921 static void print(mlir::OpAsmPrinter &p, fir::IterWhileOp op) {
922 p << fir::IterWhileOp::getOperationName() << " (" << op.getInductionVar()
923 << " = " << op.lowerBound() << " to " << op.upperBound() << " step "
924 << op.step() << ") and (";
925 assert(op.hasIterOperands());
926 auto regionArgs = op.getRegionIterArgs();
927 auto operands = op.getIterOperands();
928 p << regionArgs.front() << " = " << *operands.begin() << ")";
929 if (regionArgs.size() > 1) {
930 p << " iter_args(";
931 llvm::interleaveComma(
932 llvm::zip(regionArgs.drop_front(), operands.drop_front()), p,
933 [&](auto it) { p << std::get<0>(it) << " = " << std::get<1>(it); });
934 p << ") -> (";
935 llvm::interleaveComma(
936 llvm::drop_begin(op.getResultTypes(), op.finalValue() ? 0 : 1), p);
937 p << ")";
938 } else if (op.finalValue()) {
939 p << " -> (" << op.getResultTypes() << ')';
940 }
941 p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"finalValue"});
942 p.printRegion(op.region(), /*printEntryBlockArgs=*/false,
943 /*printBlockTerminators=*/true);
944 }
945
getLoopBody()946 mlir::Region &fir::IterWhileOp::getLoopBody() { return region(); }
947
isDefinedOutsideOfLoop(mlir::Value value)948 bool fir::IterWhileOp::isDefinedOutsideOfLoop(mlir::Value value) {
949 return !region().isAncestor(value.getParentRegion());
950 }
951
952 mlir::LogicalResult
moveOutOfLoop(llvm::ArrayRef<mlir::Operation * > ops)953 fir::IterWhileOp::moveOutOfLoop(llvm::ArrayRef<mlir::Operation *> ops) {
954 for (auto op : ops)
955 op->moveBefore(*this);
956 return success();
957 }
958
iterArgToBlockArg(mlir::Value iterArg)959 mlir::BlockArgument fir::IterWhileOp::iterArgToBlockArg(mlir::Value iterArg) {
960 for (auto i : llvm::enumerate(initArgs()))
961 if (iterArg == i.value())
962 return region().front().getArgument(i.index() + 1);
963 return {};
964 }
965
resultToSourceOps(llvm::SmallVectorImpl<mlir::Value> & results,unsigned resultNum)966 void fir::IterWhileOp::resultToSourceOps(
967 llvm::SmallVectorImpl<mlir::Value> &results, unsigned resultNum) {
968 auto oper = finalValue() ? resultNum + 1 : resultNum;
969 auto *term = region().front().getTerminator();
970 if (oper < term->getNumOperands())
971 results.push_back(term->getOperand(oper));
972 }
973
blockArgToSourceOp(unsigned blockArgNum)974 mlir::Value fir::IterWhileOp::blockArgToSourceOp(unsigned blockArgNum) {
975 if (blockArgNum > 0 && blockArgNum <= initArgs().size())
976 return initArgs()[blockArgNum - 1];
977 return {};
978 }
979
980 //===----------------------------------------------------------------------===//
981 // LoadOp
982 //===----------------------------------------------------------------------===//
983
984 /// Get the element type of a reference like type; otherwise null
elementTypeOf(mlir::Type ref)985 static mlir::Type elementTypeOf(mlir::Type ref) {
986 return llvm::TypeSwitch<mlir::Type, mlir::Type>(ref)
987 .Case<ReferenceType, PointerType, HeapType>(
988 [](auto type) { return type.getEleTy(); })
989 .Default([](mlir::Type) { return mlir::Type{}; });
990 }
991
getElementOf(mlir::Type & ele,mlir::Type ref)992 mlir::ParseResult fir::LoadOp::getElementOf(mlir::Type &ele, mlir::Type ref) {
993 if ((ele = elementTypeOf(ref)))
994 return mlir::success();
995 return mlir::failure();
996 }
997
998 //===----------------------------------------------------------------------===//
999 // DoLoopOp
1000 //===----------------------------------------------------------------------===//
1001
build(mlir::OpBuilder & builder,mlir::OperationState & result,mlir::Value lb,mlir::Value ub,mlir::Value step,bool unordered,bool finalCountValue,mlir::ValueRange iterArgs,llvm::ArrayRef<mlir::NamedAttribute> attributes)1002 void fir::DoLoopOp::build(mlir::OpBuilder &builder,
1003 mlir::OperationState &result, mlir::Value lb,
1004 mlir::Value ub, mlir::Value step, bool unordered,
1005 bool finalCountValue, mlir::ValueRange iterArgs,
1006 llvm::ArrayRef<mlir::NamedAttribute> attributes) {
1007 result.addOperands({lb, ub, step});
1008 result.addOperands(iterArgs);
1009 if (finalCountValue) {
1010 result.addTypes(builder.getIndexType());
1011 result.addAttribute(finalValueAttrName(result.name), builder.getUnitAttr());
1012 }
1013 for (auto v : iterArgs)
1014 result.addTypes(v.getType());
1015 mlir::Region *bodyRegion = result.addRegion();
1016 bodyRegion->push_back(new Block{});
1017 if (iterArgs.empty() && !finalCountValue)
1018 DoLoopOp::ensureTerminator(*bodyRegion, builder, result.location);
1019 bodyRegion->front().addArgument(builder.getIndexType());
1020 bodyRegion->front().addArguments(iterArgs.getTypes());
1021 if (unordered)
1022 result.addAttribute(unorderedAttrName(result.name), builder.getUnitAttr());
1023 result.addAttributes(attributes);
1024 }
1025
parseDoLoopOp(mlir::OpAsmParser & parser,mlir::OperationState & result)1026 static mlir::ParseResult parseDoLoopOp(mlir::OpAsmParser &parser,
1027 mlir::OperationState &result) {
1028 auto &builder = parser.getBuilder();
1029 mlir::OpAsmParser::OperandType inductionVariable, lb, ub, step;
1030 // Parse the induction variable followed by '='.
1031 if (parser.parseRegionArgument(inductionVariable) || parser.parseEqual())
1032 return mlir::failure();
1033
1034 // Parse loop bounds.
1035 auto indexType = builder.getIndexType();
1036 if (parser.parseOperand(lb) ||
1037 parser.resolveOperand(lb, indexType, result.operands) ||
1038 parser.parseKeyword("to") || parser.parseOperand(ub) ||
1039 parser.resolveOperand(ub, indexType, result.operands) ||
1040 parser.parseKeyword("step") || parser.parseOperand(step) ||
1041 parser.resolveOperand(step, indexType, result.operands))
1042 return failure();
1043
1044 if (mlir::succeeded(parser.parseOptionalKeyword("unordered")))
1045 result.addAttribute("unordered", builder.getUnitAttr());
1046
1047 // Parse the optional initial iteration arguments.
1048 llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> regionArgs, operands;
1049 llvm::SmallVector<mlir::Type, 4> argTypes;
1050 auto prependCount = false;
1051 regionArgs.push_back(inductionVariable);
1052
1053 if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
1054 // Parse assignment list and results type list.
1055 if (parser.parseAssignmentList(regionArgs, operands) ||
1056 parser.parseArrowTypeList(result.types))
1057 return failure();
1058 if (result.types.size() == operands.size() + 1)
1059 prependCount = true;
1060 // Resolve input operands.
1061 llvm::ArrayRef<mlir::Type> resTypes = result.types;
1062 for (auto operand_type :
1063 llvm::zip(operands, prependCount ? resTypes.drop_front() : resTypes))
1064 if (parser.resolveOperand(std::get<0>(operand_type),
1065 std::get<1>(operand_type), result.operands))
1066 return failure();
1067 } else if (succeeded(parser.parseOptionalArrow())) {
1068 if (parser.parseKeyword("index"))
1069 return failure();
1070 result.types.push_back(indexType);
1071 prependCount = true;
1072 }
1073
1074 if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1075 return mlir::failure();
1076
1077 // Induction variable.
1078 if (prependCount)
1079 result.addAttribute(DoLoopOp::finalValueAttrName(result.name),
1080 builder.getUnitAttr());
1081 else
1082 argTypes.push_back(indexType);
1083 // Loop carried variables
1084 argTypes.append(result.types.begin(), result.types.end());
1085 // Parse the body region.
1086 auto *body = result.addRegion();
1087 if (regionArgs.size() != argTypes.size())
1088 return parser.emitError(
1089 parser.getNameLoc(),
1090 "mismatch in number of loop-carried values and defined values");
1091
1092 if (parser.parseRegion(*body, regionArgs, argTypes))
1093 return failure();
1094
1095 DoLoopOp::ensureTerminator(*body, builder, result.location);
1096
1097 return mlir::success();
1098 }
1099
getForInductionVarOwner(mlir::Value val)1100 fir::DoLoopOp fir::getForInductionVarOwner(mlir::Value val) {
1101 auto ivArg = val.dyn_cast<mlir::BlockArgument>();
1102 if (!ivArg)
1103 return {};
1104 assert(ivArg.getOwner() && "unlinked block argument");
1105 auto *containingInst = ivArg.getOwner()->getParentOp();
1106 return dyn_cast_or_null<fir::DoLoopOp>(containingInst);
1107 }
1108
1109 // Lifted from loop.loop
verify(fir::DoLoopOp op)1110 static mlir::LogicalResult verify(fir::DoLoopOp op) {
1111 // Check that the body defines as single block argument for the induction
1112 // variable.
1113 auto *body = op.getBody();
1114 if (!body->getArgument(0).getType().isIndex())
1115 return op.emitOpError(
1116 "expected body first argument to be an index argument for "
1117 "the induction variable");
1118
1119 auto opNumResults = op.getNumResults();
1120 if (opNumResults == 0)
1121 return success();
1122
1123 if (op.finalValue()) {
1124 if (op.unordered())
1125 return op.emitOpError("unordered loop has no final value");
1126 opNumResults--;
1127 }
1128 if (op.getNumIterOperands() != opNumResults)
1129 return op.emitOpError(
1130 "mismatch in number of loop-carried values and defined values");
1131 if (op.getNumRegionIterArgs() != opNumResults)
1132 return op.emitOpError(
1133 "mismatch in number of basic block args and defined values");
1134 auto iterOperands = op.getIterOperands();
1135 auto iterArgs = op.getRegionIterArgs();
1136 auto opResults =
1137 op.finalValue() ? op.getResults().drop_front() : op.getResults();
1138 unsigned i = 0;
1139 for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) {
1140 if (std::get<0>(e).getType() != std::get<2>(e).getType())
1141 return op.emitOpError() << "types mismatch between " << i
1142 << "th iter operand and defined value";
1143 if (std::get<1>(e).getType() != std::get<2>(e).getType())
1144 return op.emitOpError() << "types mismatch between " << i
1145 << "th iter region arg and defined value";
1146
1147 i++;
1148 }
1149 return success();
1150 }
1151
print(mlir::OpAsmPrinter & p,fir::DoLoopOp op)1152 static void print(mlir::OpAsmPrinter &p, fir::DoLoopOp op) {
1153 bool printBlockTerminators = false;
1154 p << fir::DoLoopOp::getOperationName() << ' ' << op.getInductionVar() << " = "
1155 << op.lowerBound() << " to " << op.upperBound() << " step " << op.step();
1156 if (op.unordered())
1157 p << " unordered";
1158 if (op.hasIterOperands()) {
1159 p << " iter_args(";
1160 auto regionArgs = op.getRegionIterArgs();
1161 auto operands = op.getIterOperands();
1162 llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
1163 p << std::get<0>(it) << " = " << std::get<1>(it);
1164 });
1165 p << ") -> (" << op.getResultTypes() << ')';
1166 printBlockTerminators = true;
1167 } else if (op.finalValue()) {
1168 p << " -> " << op.getResultTypes();
1169 printBlockTerminators = true;
1170 }
1171 p.printOptionalAttrDictWithKeyword(op->getAttrs(),
1172 {"unordered", "finalValue"});
1173 p.printRegion(op.region(), /*printEntryBlockArgs=*/false,
1174 printBlockTerminators);
1175 }
1176
getLoopBody()1177 mlir::Region &fir::DoLoopOp::getLoopBody() { return region(); }
1178
isDefinedOutsideOfLoop(mlir::Value value)1179 bool fir::DoLoopOp::isDefinedOutsideOfLoop(mlir::Value value) {
1180 return !region().isAncestor(value.getParentRegion());
1181 }
1182
1183 mlir::LogicalResult
moveOutOfLoop(llvm::ArrayRef<mlir::Operation * > ops)1184 fir::DoLoopOp::moveOutOfLoop(llvm::ArrayRef<mlir::Operation *> ops) {
1185 for (auto op : ops)
1186 op->moveBefore(*this);
1187 return success();
1188 }
1189
1190 /// Translate a value passed as an iter_arg to the corresponding block
1191 /// argument in the body of the loop.
iterArgToBlockArg(mlir::Value iterArg)1192 mlir::BlockArgument fir::DoLoopOp::iterArgToBlockArg(mlir::Value iterArg) {
1193 for (auto i : llvm::enumerate(initArgs()))
1194 if (iterArg == i.value())
1195 return region().front().getArgument(i.index() + 1);
1196 return {};
1197 }
1198
1199 /// Translate the result vector (by index number) to the corresponding value
1200 /// to the `fir.result` Op.
resultToSourceOps(llvm::SmallVectorImpl<mlir::Value> & results,unsigned resultNum)1201 void fir::DoLoopOp::resultToSourceOps(
1202 llvm::SmallVectorImpl<mlir::Value> &results, unsigned resultNum) {
1203 auto oper = finalValue() ? resultNum + 1 : resultNum;
1204 auto *term = region().front().getTerminator();
1205 if (oper < term->getNumOperands())
1206 results.push_back(term->getOperand(oper));
1207 }
1208
1209 /// Translate the block argument (by index number) to the corresponding value
1210 /// passed as an iter_arg to the parent DoLoopOp.
blockArgToSourceOp(unsigned blockArgNum)1211 mlir::Value fir::DoLoopOp::blockArgToSourceOp(unsigned blockArgNum) {
1212 if (blockArgNum > 0 && blockArgNum <= initArgs().size())
1213 return initArgs()[blockArgNum - 1];
1214 return {};
1215 }
1216
1217 //===----------------------------------------------------------------------===//
1218 // ReboxOp
1219 //===----------------------------------------------------------------------===//
1220
1221 /// Get the scalar type related to a fir.box type.
1222 /// Example: return f32 for !fir.box<!fir.heap<!fir.array<?x?xf32>>.
getBoxScalarEleTy(mlir::Type boxTy)1223 static mlir::Type getBoxScalarEleTy(mlir::Type boxTy) {
1224 auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(boxTy);
1225 if (auto seqTy = eleTy.dyn_cast<fir::SequenceType>())
1226 return seqTy.getEleTy();
1227 return eleTy;
1228 }
1229
1230 /// Get the rank from a !fir.box type
getBoxRank(mlir::Type boxTy)1231 static unsigned getBoxRank(mlir::Type boxTy) {
1232 auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(boxTy);
1233 if (auto seqTy = eleTy.dyn_cast<fir::SequenceType>())
1234 return seqTy.getDimension();
1235 return 0;
1236 }
1237
verify(fir::ReboxOp op)1238 static mlir::LogicalResult verify(fir::ReboxOp op) {
1239 auto inputBoxTy = op.box().getType();
1240 if (fir::isa_unknown_size_box(inputBoxTy))
1241 return op.emitOpError("box operand must not have unknown rank or type");
1242 auto outBoxTy = op.getType();
1243 if (fir::isa_unknown_size_box(outBoxTy))
1244 return op.emitOpError("result type must not have unknown rank or type");
1245 auto inputRank = getBoxRank(inputBoxTy);
1246 auto inputEleTy = getBoxScalarEleTy(inputBoxTy);
1247 auto outRank = getBoxRank(outBoxTy);
1248 auto outEleTy = getBoxScalarEleTy(outBoxTy);
1249
1250 if (auto slice = op.slice()) {
1251 // Slicing case
1252 if (slice.getType().cast<fir::SliceType>().getRank() != inputRank)
1253 return op.emitOpError("slice operand rank must match box operand rank");
1254 if (auto shape = op.shape()) {
1255 if (auto shiftTy = shape.getType().dyn_cast<fir::ShiftType>()) {
1256 if (shiftTy.getRank() != inputRank)
1257 return op.emitOpError("shape operand and input box ranks must match "
1258 "when there is a slice");
1259 } else {
1260 return op.emitOpError("shape operand must absent or be a fir.shift "
1261 "when there is a slice");
1262 }
1263 }
1264 if (auto sliceOp = slice.getDefiningOp()) {
1265 auto slicedRank = mlir::cast<fir::SliceOp>(sliceOp).getOutRank();
1266 if (slicedRank != outRank)
1267 return op.emitOpError("result type rank and rank after applying slice "
1268 "operand must match");
1269 }
1270 } else {
1271 // Reshaping case
1272 unsigned shapeRank = inputRank;
1273 if (auto shape = op.shape()) {
1274 auto ty = shape.getType();
1275 if (auto shapeTy = ty.dyn_cast<fir::ShapeType>()) {
1276 shapeRank = shapeTy.getRank();
1277 } else if (auto shapeShiftTy = ty.dyn_cast<fir::ShapeShiftType>()) {
1278 shapeRank = shapeShiftTy.getRank();
1279 } else {
1280 auto shiftTy = ty.cast<fir::ShiftType>();
1281 shapeRank = shiftTy.getRank();
1282 if (shapeRank != inputRank)
1283 return op.emitOpError("shape operand and input box ranks must match "
1284 "when the shape is a fir.shift");
1285 }
1286 }
1287 if (shapeRank != outRank)
1288 return op.emitOpError("result type and shape operand ranks must match");
1289 }
1290
1291 if (inputEleTy != outEleTy)
1292 // TODO: check that outBoxTy is a parent type of inputBoxTy for derived
1293 // types.
1294 if (!inputEleTy.isa<fir::RecordType>())
1295 return op.emitOpError(
1296 "op input and output element types must match for intrinsic types");
1297 return mlir::success();
1298 }
1299
1300 //===----------------------------------------------------------------------===//
1301 // ResultOp
1302 //===----------------------------------------------------------------------===//
1303
verify(fir::ResultOp op)1304 static mlir::LogicalResult verify(fir::ResultOp op) {
1305 auto *parentOp = op->getParentOp();
1306 auto results = parentOp->getResults();
1307 auto operands = op->getOperands();
1308
1309 if (parentOp->getNumResults() != op.getNumOperands())
1310 return op.emitOpError() << "parent of result must have same arity";
1311 for (auto e : llvm::zip(results, operands))
1312 if (std::get<0>(e).getType() != std::get<1>(e).getType())
1313 return op.emitOpError()
1314 << "types mismatch between result op and its parent";
1315 return success();
1316 }
1317
1318 //===----------------------------------------------------------------------===//
1319 // SelectOp
1320 //===----------------------------------------------------------------------===//
1321
getCompareOffsetAttr()1322 static constexpr llvm::StringRef getCompareOffsetAttr() {
1323 return "compare_operand_offsets";
1324 }
1325
getTargetOffsetAttr()1326 static constexpr llvm::StringRef getTargetOffsetAttr() {
1327 return "target_operand_offsets";
1328 }
1329
1330 template <typename A, typename... AdditionalArgs>
getSubOperands(unsigned pos,A allArgs,mlir::DenseIntElementsAttr ranges,AdditionalArgs &&...additionalArgs)1331 static A getSubOperands(unsigned pos, A allArgs,
1332 mlir::DenseIntElementsAttr ranges,
1333 AdditionalArgs &&... additionalArgs) {
1334 unsigned start = 0;
1335 for (unsigned i = 0; i < pos; ++i)
1336 start += (*(ranges.begin() + i)).getZExtValue();
1337 return allArgs.slice(start, (*(ranges.begin() + pos)).getZExtValue(),
1338 std::forward<AdditionalArgs>(additionalArgs)...);
1339 }
1340
1341 static mlir::MutableOperandRange
getMutableSuccessorOperands(unsigned pos,mlir::MutableOperandRange operands,StringRef offsetAttr)1342 getMutableSuccessorOperands(unsigned pos, mlir::MutableOperandRange operands,
1343 StringRef offsetAttr) {
1344 Operation *owner = operands.getOwner();
1345 NamedAttribute targetOffsetAttr =
1346 *owner->getAttrDictionary().getNamed(offsetAttr);
1347 return getSubOperands(
1348 pos, operands, targetOffsetAttr.second.cast<DenseIntElementsAttr>(),
1349 mlir::MutableOperandRange::OperandSegment(pos, targetOffsetAttr));
1350 }
1351
denseElementsSize(mlir::DenseIntElementsAttr attr)1352 static unsigned denseElementsSize(mlir::DenseIntElementsAttr attr) {
1353 return attr.getNumElements();
1354 }
1355
getCompareOperands(unsigned)1356 llvm::Optional<mlir::OperandRange> fir::SelectOp::getCompareOperands(unsigned) {
1357 return {};
1358 }
1359
1360 llvm::Optional<llvm::ArrayRef<mlir::Value>>
getCompareOperands(llvm::ArrayRef<mlir::Value>,unsigned)1361 fir::SelectOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) {
1362 return {};
1363 }
1364
1365 llvm::Optional<mlir::MutableOperandRange>
getMutableSuccessorOperands(unsigned oper)1366 fir::SelectOp::getMutableSuccessorOperands(unsigned oper) {
1367 return ::getMutableSuccessorOperands(oper, targetArgsMutable(),
1368 getTargetOffsetAttr());
1369 }
1370
1371 llvm::Optional<llvm::ArrayRef<mlir::Value>>
getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands,unsigned oper)1372 fir::SelectOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands,
1373 unsigned oper) {
1374 auto a =
1375 (*this)->getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr());
1376 auto segments = (*this)->getAttrOfType<mlir::DenseIntElementsAttr>(
1377 getOperandSegmentSizeAttr());
1378 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)};
1379 }
1380
targetOffsetSize()1381 unsigned fir::SelectOp::targetOffsetSize() {
1382 return denseElementsSize((*this)->getAttrOfType<mlir::DenseIntElementsAttr>(
1383 getTargetOffsetAttr()));
1384 }
1385
1386 //===----------------------------------------------------------------------===//
1387 // SelectCaseOp
1388 //===----------------------------------------------------------------------===//
1389
1390 llvm::Optional<mlir::OperandRange>
getCompareOperands(unsigned cond)1391 fir::SelectCaseOp::getCompareOperands(unsigned cond) {
1392 auto a = (*this)->getAttrOfType<mlir::DenseIntElementsAttr>(
1393 getCompareOffsetAttr());
1394 return {getSubOperands(cond, compareArgs(), a)};
1395 }
1396
1397 llvm::Optional<llvm::ArrayRef<mlir::Value>>
getCompareOperands(llvm::ArrayRef<mlir::Value> operands,unsigned cond)1398 fir::SelectCaseOp::getCompareOperands(llvm::ArrayRef<mlir::Value> operands,
1399 unsigned cond) {
1400 auto a = (*this)->getAttrOfType<mlir::DenseIntElementsAttr>(
1401 getCompareOffsetAttr());
1402 auto segments = (*this)->getAttrOfType<mlir::DenseIntElementsAttr>(
1403 getOperandSegmentSizeAttr());
1404 return {getSubOperands(cond, getSubOperands(1, operands, segments), a)};
1405 }
1406
1407 llvm::Optional<mlir::MutableOperandRange>
getMutableSuccessorOperands(unsigned oper)1408 fir::SelectCaseOp::getMutableSuccessorOperands(unsigned oper) {
1409 return ::getMutableSuccessorOperands(oper, targetArgsMutable(),
1410 getTargetOffsetAttr());
1411 }
1412
1413 llvm::Optional<llvm::ArrayRef<mlir::Value>>
getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands,unsigned oper)1414 fir::SelectCaseOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands,
1415 unsigned oper) {
1416 auto a =
1417 (*this)->getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr());
1418 auto segments = (*this)->getAttrOfType<mlir::DenseIntElementsAttr>(
1419 getOperandSegmentSizeAttr());
1420 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)};
1421 }
1422
// Parser for the fir.select_case op.
//
// After the selector (parsed by parseSelector, defined elsewhere in this
// file), each case is parsed as
//   case-attr `,` [operand `,` [operand `,`]] successor-and-args
// with cases separated by `,` and the list closed by `]`. A unit attribute
// takes no operands, a fir::ClosedIntervalAttr takes two, and any other case
// attribute takes one. Three attributes record how the flat operand list
// splits up:
//  - the operand segment sizes {1, #compare operands, #target operands};
//  - the per-case compare operand counts;
//  - the per-successor target operand counts.
static mlir::ParseResult parseSelectCase(mlir::OpAsmParser &parser,
                                         mlir::OperationState &result) {
  mlir::OpAsmParser::OperandType selector;
  mlir::Type type;
  if (parseSelector(parser, result, selector, type))
    return mlir::failure();

  llvm::SmallVector<mlir::Attribute, 8> attrs;
  llvm::SmallVector<mlir::OpAsmParser::OperandType, 8> opers;
  llvm::SmallVector<mlir::Block *, 8> dests;
  llvm::SmallVector<llvm::SmallVector<mlir::Value, 8>, 8> destArgs;
  // Number of compare operands used by each case, and their running total.
  llvm::SmallVector<int32_t, 8> argOffs;
  int32_t offSize = 0;
  while (true) {
    mlir::Attribute attr;
    mlir::Block *dest;
    llvm::SmallVector<mlir::Value, 8> destArg;
    // Scratch list: parseAttribute stores the attribute under the name "a"
    // here; the value is kept separately in `attrs`.
    mlir::NamedAttrList temp;
    // isValidCaseAttr is used as a ParseResult here (truthy means the case
    // attribute is rejected). NOTE(review): confirm against its definition.
    if (parser.parseAttribute(attr, "a", temp) || isValidCaseAttr(attr) ||
        parser.parseComma())
      return mlir::failure();
    attrs.push_back(attr);
    if (attr.dyn_cast_or_null<mlir::UnitAttr>()) {
      argOffs.push_back(0);
    } else if (attr.dyn_cast_or_null<fir::ClosedIntervalAttr>()) {
      // A closed interval takes two bound operands.
      mlir::OpAsmParser::OperandType oper1;
      mlir::OpAsmParser::OperandType oper2;
      if (parser.parseOperand(oper1) || parser.parseComma() ||
          parser.parseOperand(oper2) || parser.parseComma())
        return mlir::failure();
      opers.push_back(oper1);
      opers.push_back(oper2);
      argOffs.push_back(2);
      offSize += 2;
    } else {
      mlir::OpAsmParser::OperandType oper;
      if (parser.parseOperand(oper) || parser.parseComma())
        return mlir::failure();
      opers.push_back(oper);
      argOffs.push_back(1);
      ++offSize;
    }
    if (parser.parseSuccessorAndUseList(dest, destArg))
      return mlir::failure();
    dests.push_back(dest);
    destArgs.push_back(destArg);
    // `]` ends the case list; otherwise another `,`-separated case follows.
    if (mlir::succeeded(parser.parseOptionalRSquare()))
      break;
    if (parser.parseComma())
      return mlir::failure();
  }
  result.addAttribute(fir::SelectCaseOp::getCasesAttr(),
                      parser.getBuilder().getArrayAttr(attrs));
  // All compare operands are resolved with the selector's type.
  if (parser.resolveOperands(opers, type, result.operands))
    return mlir::failure();
  // Successor operands follow the compare operands in the flat operand list.
  llvm::SmallVector<int32_t, 8> targOffs;
  int32_t toffSize = 0;
  const auto count = dests.size();
  for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) {
    result.addSuccessors(dests[i]);
    result.addOperands(destArgs[i]);
    auto argSize = destArgs[i].size();
    targOffs.push_back(argSize);
    toffSize += argSize;
  }
  auto &bld = parser.getBuilder();
  result.addAttribute(fir::SelectCaseOp::getOperandSegmentSizeAttr(),
                      bld.getI32VectorAttr({1, offSize, toffSize}));
  result.addAttribute(getCompareOffsetAttr(), bld.getI32VectorAttr(argOffs));
  result.addAttribute(getTargetOffsetAttr(), bld.getI32VectorAttr(targOffs));
  return mlir::success();
}
1496
compareOffsetSize()1497 unsigned fir::SelectCaseOp::compareOffsetSize() {
1498 return denseElementsSize((*this)->getAttrOfType<mlir::DenseIntElementsAttr>(
1499 getCompareOffsetAttr()));
1500 }
1501
targetOffsetSize()1502 unsigned fir::SelectCaseOp::targetOffsetSize() {
1503 return denseElementsSize((*this)->getAttrOfType<mlir::DenseIntElementsAttr>(
1504 getTargetOffsetAttr()));
1505 }
1506
build(mlir::OpBuilder & builder,mlir::OperationState & result,mlir::Value selector,llvm::ArrayRef<mlir::Attribute> compareAttrs,llvm::ArrayRef<mlir::ValueRange> cmpOperands,llvm::ArrayRef<mlir::Block * > destinations,llvm::ArrayRef<mlir::ValueRange> destOperands,llvm::ArrayRef<mlir::NamedAttribute> attributes)1507 void fir::SelectCaseOp::build(mlir::OpBuilder &builder,
1508 mlir::OperationState &result,
1509 mlir::Value selector,
1510 llvm::ArrayRef<mlir::Attribute> compareAttrs,
1511 llvm::ArrayRef<mlir::ValueRange> cmpOperands,
1512 llvm::ArrayRef<mlir::Block *> destinations,
1513 llvm::ArrayRef<mlir::ValueRange> destOperands,
1514 llvm::ArrayRef<mlir::NamedAttribute> attributes) {
1515 result.addOperands(selector);
1516 result.addAttribute(getCasesAttr(), builder.getArrayAttr(compareAttrs));
1517 llvm::SmallVector<int32_t, 8> operOffs;
1518 int32_t operSize = 0;
1519 for (auto attr : compareAttrs) {
1520 if (attr.isa<fir::ClosedIntervalAttr>()) {
1521 operOffs.push_back(2);
1522 operSize += 2;
1523 } else if (attr.isa<mlir::UnitAttr>()) {
1524 operOffs.push_back(0);
1525 } else {
1526 operOffs.push_back(1);
1527 ++operSize;
1528 }
1529 }
1530 for (auto ops : cmpOperands)
1531 result.addOperands(ops);
1532 result.addAttribute(getCompareOffsetAttr(),
1533 builder.getI32VectorAttr(operOffs));
1534 const auto count = destinations.size();
1535 for (auto d : destinations)
1536 result.addSuccessors(d);
1537 const auto opCount = destOperands.size();
1538 llvm::SmallVector<int32_t, 8> argOffs;
1539 int32_t sumArgs = 0;
1540 for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) {
1541 if (i < opCount) {
1542 result.addOperands(destOperands[i]);
1543 const auto argSz = destOperands[i].size();
1544 argOffs.push_back(argSz);
1545 sumArgs += argSz;
1546 } else {
1547 argOffs.push_back(0);
1548 }
1549 }
1550 result.addAttribute(getOperandSegmentSizeAttr(),
1551 builder.getI32VectorAttr({1, operSize, sumArgs}));
1552 result.addAttribute(getTargetOffsetAttr(), builder.getI32VectorAttr(argOffs));
1553 result.addAttributes(attributes);
1554 }
1555
1556 /// This builder has a slightly simplified interface in that the list of
1557 /// operands need not be partitioned by the builder. Instead the operands are
1558 /// partitioned here, before being passed to the default builder. This
1559 /// partitioning is unchecked, so can go awry on bad input.
build(mlir::OpBuilder & builder,mlir::OperationState & result,mlir::Value selector,llvm::ArrayRef<mlir::Attribute> compareAttrs,llvm::ArrayRef<mlir::Value> cmpOpList,llvm::ArrayRef<mlir::Block * > destinations,llvm::ArrayRef<mlir::ValueRange> destOperands,llvm::ArrayRef<mlir::NamedAttribute> attributes)1560 void fir::SelectCaseOp::build(mlir::OpBuilder &builder,
1561 mlir::OperationState &result,
1562 mlir::Value selector,
1563 llvm::ArrayRef<mlir::Attribute> compareAttrs,
1564 llvm::ArrayRef<mlir::Value> cmpOpList,
1565 llvm::ArrayRef<mlir::Block *> destinations,
1566 llvm::ArrayRef<mlir::ValueRange> destOperands,
1567 llvm::ArrayRef<mlir::NamedAttribute> attributes) {
1568 llvm::SmallVector<mlir::ValueRange, 16> cmpOpers;
1569 auto iter = cmpOpList.begin();
1570 for (auto &attr : compareAttrs) {
1571 if (attr.isa<fir::ClosedIntervalAttr>()) {
1572 cmpOpers.push_back(mlir::ValueRange({iter, iter + 2}));
1573 iter += 2;
1574 } else if (attr.isa<UnitAttr>()) {
1575 cmpOpers.push_back(mlir::ValueRange{});
1576 } else {
1577 cmpOpers.push_back(mlir::ValueRange({iter, iter + 1}));
1578 ++iter;
1579 }
1580 }
1581 build(builder, result, selector, compareAttrs, cmpOpers, destinations,
1582 destOperands, attributes);
1583 }
1584
1585 //===----------------------------------------------------------------------===//
1586 // SelectRankOp
1587 //===----------------------------------------------------------------------===//
1588
1589 llvm::Optional<mlir::OperandRange>
getCompareOperands(unsigned)1590 fir::SelectRankOp::getCompareOperands(unsigned) {
1591 return {};
1592 }
1593
1594 llvm::Optional<llvm::ArrayRef<mlir::Value>>
getCompareOperands(llvm::ArrayRef<mlir::Value>,unsigned)1595 fir::SelectRankOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) {
1596 return {};
1597 }
1598
1599 llvm::Optional<mlir::MutableOperandRange>
getMutableSuccessorOperands(unsigned oper)1600 fir::SelectRankOp::getMutableSuccessorOperands(unsigned oper) {
1601 return ::getMutableSuccessorOperands(oper, targetArgsMutable(),
1602 getTargetOffsetAttr());
1603 }
1604
1605 llvm::Optional<llvm::ArrayRef<mlir::Value>>
getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands,unsigned oper)1606 fir::SelectRankOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands,
1607 unsigned oper) {
1608 auto a =
1609 (*this)->getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr());
1610 auto segments = (*this)->getAttrOfType<mlir::DenseIntElementsAttr>(
1611 getOperandSegmentSizeAttr());
1612 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)};
1613 }
1614
targetOffsetSize()1615 unsigned fir::SelectRankOp::targetOffsetSize() {
1616 return denseElementsSize((*this)->getAttrOfType<mlir::DenseIntElementsAttr>(
1617 getTargetOffsetAttr()));
1618 }
1619
1620 //===----------------------------------------------------------------------===//
1621 // SelectTypeOp
1622 //===----------------------------------------------------------------------===//
1623
1624 llvm::Optional<mlir::OperandRange>
getCompareOperands(unsigned)1625 fir::SelectTypeOp::getCompareOperands(unsigned) {
1626 return {};
1627 }
1628
1629 llvm::Optional<llvm::ArrayRef<mlir::Value>>
getCompareOperands(llvm::ArrayRef<mlir::Value>,unsigned)1630 fir::SelectTypeOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) {
1631 return {};
1632 }
1633
1634 llvm::Optional<mlir::MutableOperandRange>
getMutableSuccessorOperands(unsigned oper)1635 fir::SelectTypeOp::getMutableSuccessorOperands(unsigned oper) {
1636 return ::getMutableSuccessorOperands(oper, targetArgsMutable(),
1637 getTargetOffsetAttr());
1638 }
1639
1640 llvm::Optional<llvm::ArrayRef<mlir::Value>>
getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands,unsigned oper)1641 fir::SelectTypeOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands,
1642 unsigned oper) {
1643 auto a =
1644 (*this)->getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr());
1645 auto segments = (*this)->getAttrOfType<mlir::DenseIntElementsAttr>(
1646 getOperandSegmentSizeAttr());
1647 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)};
1648 }
1649
parseSelectType(OpAsmParser & parser,OperationState & result)1650 static ParseResult parseSelectType(OpAsmParser &parser,
1651 OperationState &result) {
1652 mlir::OpAsmParser::OperandType selector;
1653 mlir::Type type;
1654 if (parseSelector(parser, result, selector, type))
1655 return mlir::failure();
1656
1657 llvm::SmallVector<mlir::Attribute, 8> attrs;
1658 llvm::SmallVector<mlir::Block *, 8> dests;
1659 llvm::SmallVector<llvm::SmallVector<mlir::Value, 8>, 8> destArgs;
1660 while (true) {
1661 mlir::Attribute attr;
1662 mlir::Block *dest;
1663 llvm::SmallVector<mlir::Value, 8> destArg;
1664 mlir::NamedAttrList temp;
1665 if (parser.parseAttribute(attr, "a", temp) || parser.parseComma() ||
1666 parser.parseSuccessorAndUseList(dest, destArg))
1667 return mlir::failure();
1668 attrs.push_back(attr);
1669 dests.push_back(dest);
1670 destArgs.push_back(destArg);
1671 if (mlir::succeeded(parser.parseOptionalRSquare()))
1672 break;
1673 if (parser.parseComma())
1674 return mlir::failure();
1675 }
1676 auto &bld = parser.getBuilder();
1677 result.addAttribute(fir::SelectTypeOp::getCasesAttr(),
1678 bld.getArrayAttr(attrs));
1679 llvm::SmallVector<int32_t, 8> argOffs;
1680 int32_t offSize = 0;
1681 const auto count = dests.size();
1682 for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) {
1683 result.addSuccessors(dests[i]);
1684 result.addOperands(destArgs[i]);
1685 auto argSize = destArgs[i].size();
1686 argOffs.push_back(argSize);
1687 offSize += argSize;
1688 }
1689 result.addAttribute(fir::SelectTypeOp::getOperandSegmentSizeAttr(),
1690 bld.getI32VectorAttr({1, 0, offSize}));
1691 result.addAttribute(getTargetOffsetAttr(), bld.getI32VectorAttr(argOffs));
1692 return mlir::success();
1693 }
1694
targetOffsetSize()1695 unsigned fir::SelectTypeOp::targetOffsetSize() {
1696 return denseElementsSize((*this)->getAttrOfType<mlir::DenseIntElementsAttr>(
1697 getTargetOffsetAttr()));
1698 }
1699
1700 //===----------------------------------------------------------------------===//
1701 // SliceOp
1702 //===----------------------------------------------------------------------===//
1703
1704 /// Return the output rank of a slice op. The output rank must be between 1 and
1705 /// the rank of the array being sliced (inclusive).
getOutputRank(mlir::ValueRange triples)1706 unsigned fir::SliceOp::getOutputRank(mlir::ValueRange triples) {
1707 unsigned rank = 0;
1708 if (!triples.empty()) {
1709 for (unsigned i = 1, end = triples.size(); i < end; i += 3) {
1710 auto op = triples[i].getDefiningOp();
1711 if (!mlir::isa_and_nonnull<fir::UndefOp>(op))
1712 ++rank;
1713 }
1714 assert(rank > 0);
1715 }
1716 return rank;
1717 }
1718
1719 //===----------------------------------------------------------------------===//
1720 // StoreOp
1721 //===----------------------------------------------------------------------===//
1722
elementType(mlir::Type refType)1723 mlir::Type fir::StoreOp::elementType(mlir::Type refType) {
1724 if (auto ref = refType.dyn_cast<ReferenceType>())
1725 return ref.getEleTy();
1726 if (auto ref = refType.dyn_cast<PointerType>())
1727 return ref.getEleTy();
1728 if (auto ref = refType.dyn_cast<HeapType>())
1729 return ref.getEleTy();
1730 return {};
1731 }
1732
1733 //===----------------------------------------------------------------------===//
1734 // StringLitOp
1735 //===----------------------------------------------------------------------===//
1736
isWideValue()1737 bool fir::StringLitOp::isWideValue() {
1738 auto eleTy = getType().cast<fir::SequenceType>().getEleTy();
1739 return eleTy.cast<fir::CharacterType>().getFKind() != 1;
1740 }
1741
1742 //===----------------------------------------------------------------------===//
1743 // IfOp
1744 //===----------------------------------------------------------------------===//
1745
build(mlir::OpBuilder & builder,OperationState & result,mlir::Value cond,bool withElseRegion)1746 void fir::IfOp::build(mlir::OpBuilder &builder, OperationState &result,
1747 mlir::Value cond, bool withElseRegion) {
1748 build(builder, result, llvm::None, cond, withElseRegion);
1749 }
1750
build(mlir::OpBuilder & builder,OperationState & result,mlir::TypeRange resultTypes,mlir::Value cond,bool withElseRegion)1751 void fir::IfOp::build(mlir::OpBuilder &builder, OperationState &result,
1752 mlir::TypeRange resultTypes, mlir::Value cond,
1753 bool withElseRegion) {
1754 result.addOperands(cond);
1755 result.addTypes(resultTypes);
1756
1757 mlir::Region *thenRegion = result.addRegion();
1758 thenRegion->push_back(new mlir::Block());
1759 if (resultTypes.empty())
1760 IfOp::ensureTerminator(*thenRegion, builder, result.location);
1761
1762 mlir::Region *elseRegion = result.addRegion();
1763 if (withElseRegion) {
1764 elseRegion->push_back(new mlir::Block());
1765 if (resultTypes.empty())
1766 IfOp::ensureTerminator(*elseRegion, builder, result.location);
1767 }
1768 }
1769
parseIfOp(OpAsmParser & parser,OperationState & result)1770 static mlir::ParseResult parseIfOp(OpAsmParser &parser,
1771 OperationState &result) {
1772 result.regions.reserve(2);
1773 mlir::Region *thenRegion = result.addRegion();
1774 mlir::Region *elseRegion = result.addRegion();
1775
1776 auto &builder = parser.getBuilder();
1777 OpAsmParser::OperandType cond;
1778 mlir::Type i1Type = builder.getIntegerType(1);
1779 if (parser.parseOperand(cond) ||
1780 parser.resolveOperand(cond, i1Type, result.operands))
1781 return mlir::failure();
1782
1783 if (parser.parseOptionalArrowTypeList(result.types))
1784 return mlir::failure();
1785
1786 if (parser.parseRegion(*thenRegion, {}, {}))
1787 return mlir::failure();
1788 IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
1789
1790 if (mlir::succeeded(parser.parseOptionalKeyword("else"))) {
1791 if (parser.parseRegion(*elseRegion, {}, {}))
1792 return mlir::failure();
1793 IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
1794 }
1795
1796 // Parse the optional attribute list.
1797 if (parser.parseOptionalAttrDict(result.attributes))
1798 return mlir::failure();
1799 return mlir::success();
1800 }
1801
verify(fir::IfOp op)1802 static LogicalResult verify(fir::IfOp op) {
1803 if (op.getNumResults() != 0 && op.elseRegion().empty())
1804 return op.emitOpError("must have an else block if defining values");
1805
1806 return mlir::success();
1807 }
1808
print(mlir::OpAsmPrinter & p,fir::IfOp op)1809 static void print(mlir::OpAsmPrinter &p, fir::IfOp op) {
1810 bool printBlockTerminators = false;
1811 p << fir::IfOp::getOperationName() << ' ' << op.condition();
1812 if (!op.results().empty()) {
1813 p << " -> (" << op.getResultTypes() << ')';
1814 printBlockTerminators = true;
1815 }
1816 p.printRegion(op.thenRegion(), /*printEntryBlockArgs=*/false,
1817 printBlockTerminators);
1818
1819 // Print the 'else' regions if it exists and has a block.
1820 auto &otherReg = op.elseRegion();
1821 if (!otherReg.empty()) {
1822 p << " else";
1823 p.printRegion(otherReg, /*printEntryBlockArgs=*/false,
1824 printBlockTerminators);
1825 }
1826 p.printOptionalAttrDict(op->getAttrs());
1827 }
1828
resultToSourceOps(llvm::SmallVectorImpl<mlir::Value> & results,unsigned resultNum)1829 void fir::IfOp::resultToSourceOps(llvm::SmallVectorImpl<mlir::Value> &results,
1830 unsigned resultNum) {
1831 auto *term = thenRegion().front().getTerminator();
1832 if (resultNum < term->getNumOperands())
1833 results.push_back(term->getOperand(resultNum));
1834 term = elseRegion().front().getTerminator();
1835 if (resultNum < term->getNumOperands())
1836 results.push_back(term->getOperand(resultNum));
1837 }
1838
1839 //===----------------------------------------------------------------------===//
1840
isValidCaseAttr(mlir::Attribute attr)1841 mlir::ParseResult fir::isValidCaseAttr(mlir::Attribute attr) {
1842 if (attr.dyn_cast_or_null<mlir::UnitAttr>() ||
1843 attr.dyn_cast_or_null<ClosedIntervalAttr>() ||
1844 attr.dyn_cast_or_null<PointIntervalAttr>() ||
1845 attr.dyn_cast_or_null<LowerBoundAttr>() ||
1846 attr.dyn_cast_or_null<UpperBoundAttr>())
1847 return mlir::success();
1848 return mlir::failure();
1849 }
1850
getCaseArgumentOffset(llvm::ArrayRef<mlir::Attribute> cases,unsigned dest)1851 unsigned fir::getCaseArgumentOffset(llvm::ArrayRef<mlir::Attribute> cases,
1852 unsigned dest) {
1853 unsigned o = 0;
1854 for (unsigned i = 0; i < dest; ++i) {
1855 auto &attr = cases[i];
1856 if (!attr.dyn_cast_or_null<mlir::UnitAttr>()) {
1857 ++o;
1858 if (attr.dyn_cast_or_null<ClosedIntervalAttr>())
1859 ++o;
1860 }
1861 }
1862 return o;
1863 }
1864
parseSelector(mlir::OpAsmParser & parser,mlir::OperationState & result,mlir::OpAsmParser::OperandType & selector,mlir::Type & type)1865 mlir::ParseResult fir::parseSelector(mlir::OpAsmParser &parser,
1866 mlir::OperationState &result,
1867 mlir::OpAsmParser::OperandType &selector,
1868 mlir::Type &type) {
1869 if (parser.parseOperand(selector) || parser.parseColonType(type) ||
1870 parser.resolveOperand(selector, type, result.operands) ||
1871 parser.parseLSquare())
1872 return mlir::failure();
1873 return mlir::success();
1874 }
1875
1876 /// Generic pretty-printer of a binary operation
printBinaryOp(Operation * op,OpAsmPrinter & p)1877 static void printBinaryOp(Operation *op, OpAsmPrinter &p) {
1878 assert(op->getNumOperands() == 2 && "binary op must have two operands");
1879 assert(op->getNumResults() == 1 && "binary op must have one result");
1880
1881 p << op->getName() << ' ' << op->getOperand(0) << ", " << op->getOperand(1);
1882 p.printOptionalAttrDict(op->getAttrs());
1883 p << " : " << op->getResult(0).getType();
1884 }
1885
1886 /// Generic pretty-printer of an unary operation
printUnaryOp(Operation * op,OpAsmPrinter & p)1887 static void printUnaryOp(Operation *op, OpAsmPrinter &p) {
1888 assert(op->getNumOperands() == 1 && "unary op must have one operand");
1889 assert(op->getNumResults() == 1 && "unary op must have one result");
1890
1891 p << op->getName() << ' ' << op->getOperand(0);
1892 p.printOptionalAttrDict(op->getAttrs());
1893 p << " : " << op->getResult(0).getType();
1894 }
1895
isReferenceLike(mlir::Type type)1896 bool fir::isReferenceLike(mlir::Type type) {
1897 return type.isa<fir::ReferenceType>() || type.isa<fir::HeapType>() ||
1898 type.isa<fir::PointerType>();
1899 }
1900
createFuncOp(mlir::Location loc,mlir::ModuleOp module,StringRef name,mlir::FunctionType type,llvm::ArrayRef<mlir::NamedAttribute> attrs)1901 mlir::FuncOp fir::createFuncOp(mlir::Location loc, mlir::ModuleOp module,
1902 StringRef name, mlir::FunctionType type,
1903 llvm::ArrayRef<mlir::NamedAttribute> attrs) {
1904 if (auto f = module.lookupSymbol<mlir::FuncOp>(name))
1905 return f;
1906 mlir::OpBuilder modBuilder(module.getBodyRegion());
1907 modBuilder.setInsertionPoint(module.getBody()->getTerminator());
1908 auto result = modBuilder.create<mlir::FuncOp>(loc, name, type, attrs);
1909 result.setVisibility(mlir::SymbolTable::Visibility::Private);
1910 return result;
1911 }
1912
createGlobalOp(mlir::Location loc,mlir::ModuleOp module,StringRef name,mlir::Type type,llvm::ArrayRef<mlir::NamedAttribute> attrs)1913 fir::GlobalOp fir::createGlobalOp(mlir::Location loc, mlir::ModuleOp module,
1914 StringRef name, mlir::Type type,
1915 llvm::ArrayRef<mlir::NamedAttribute> attrs) {
1916 if (auto g = module.lookupSymbol<fir::GlobalOp>(name))
1917 return g;
1918 mlir::OpBuilder modBuilder(module.getBodyRegion());
1919 auto result = modBuilder.create<fir::GlobalOp>(loc, name, type, attrs);
1920 result.setVisibility(mlir::SymbolTable::Visibility::Private);
1921 return result;
1922 }
1923
valueHasFirAttribute(mlir::Value value,llvm::StringRef attributeName)1924 bool fir::valueHasFirAttribute(mlir::Value value,
1925 llvm::StringRef attributeName) {
1926 // If this is a fir.box that was loaded, the fir attributes will be on the
1927 // related fir.ref<fir.box> creation.
1928 if (value.getType().isa<fir::BoxType>())
1929 if (auto definingOp = value.getDefiningOp())
1930 if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(definingOp))
1931 value = loadOp.memref();
1932 // If this is a function argument, look in the argument attributes.
1933 if (auto blockArg = value.dyn_cast<mlir::BlockArgument>()) {
1934 if (blockArg.getOwner() && blockArg.getOwner()->isEntryBlock())
1935 if (auto funcOp =
1936 mlir::dyn_cast<mlir::FuncOp>(blockArg.getOwner()->getParentOp()))
1937 if (funcOp.getArgAttr(blockArg.getArgNumber(), attributeName))
1938 return true;
1939 return false;
1940 }
1941
1942 if (auto definingOp = value.getDefiningOp()) {
1943 // If this is an allocated value, look at the allocation attributes.
1944 if (mlir::isa<fir::AllocMemOp>(definingOp) ||
1945 mlir::isa<AllocaOp>(definingOp))
1946 return definingOp->hasAttr(attributeName);
1947 // If this is an imported global, look at AddrOfOp and GlobalOp attributes.
1948 // Both operations are looked at because use/host associated variable (the
1949 // AddrOfOp) can have ASYNCHRONOUS/VOLATILE attributes even if the ultimate
1950 // entity (the globalOp) does not have them.
1951 if (auto addressOfOp = mlir::dyn_cast<fir::AddrOfOp>(definingOp)) {
1952 if (addressOfOp->hasAttr(attributeName))
1953 return true;
1954 if (auto module = definingOp->getParentOfType<mlir::ModuleOp>())
1955 if (auto globalOp =
1956 module.lookupSymbol<fir::GlobalOp>(addressOfOp.symbol()))
1957 return globalOp->hasAttr(attributeName);
1958 }
1959 }
1960 // TODO: Construct associated entities attributes. Decide where the fir
1961 // attributes must be placed/looked for in this case.
1962 return false;
1963 }
1964
1965 // Tablegen operators
1966
1967 #define GET_OP_CLASSES
1968 #include "flang/Optimizer/Dialect/FIROps.cpp.inc"
1969