1 //===- LLVMDialect.cpp - LLVM IR Ops and Dialect registration -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the types and operation details for the LLVM IR dialect in
10 // MLIR, and the LLVM IR dialect. It also registers the dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14 #include "TypeDetail.h"
15 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/DialectImplementation.h"
20 #include "mlir/IR/FunctionImplementation.h"
21 #include "mlir/IR/MLIRContext.h"
22
23 #include "llvm/ADT/StringSwitch.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/AsmParser/Parser.h"
26 #include "llvm/Bitcode/BitcodeReader.h"
27 #include "llvm/Bitcode/BitcodeWriter.h"
28 #include "llvm/IR/Attributes.h"
29 #include "llvm/IR/Function.h"
30 #include "llvm/IR/Type.h"
31 #include "llvm/Support/Mutex.h"
32 #include "llvm/Support/SourceMgr.h"
33
34 #include <iostream>
35
36 using namespace mlir;
37 using namespace mlir::LLVM;
38
39 #include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc"
40
41 static constexpr const char kVolatileAttrName[] = "volatile_";
42 static constexpr const char kNonTemporalAttrName[] = "nontemporal";
43
44 #include "mlir/Dialect/LLVMIR/LLVMOpsEnums.cpp.inc"
45 #include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.cpp.inc"
46 #define GET_ATTRDEF_CLASSES
47 #include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.cpp.inc"
48
processFMFAttr(ArrayRef<NamedAttribute> attrs)49 static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) {
50 SmallVector<NamedAttribute, 8> filteredAttrs(
51 llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
52 if (attr.first == "fastmathFlags") {
53 auto defAttr = FMFAttr::get(attr.second.getContext(), {});
54 return defAttr != attr.second;
55 }
56 return true;
57 }));
58 return filteredAttrs;
59 }
60
parseLLVMOpAttrs(OpAsmParser & parser,NamedAttrList & result)61 static ParseResult parseLLVMOpAttrs(OpAsmParser &parser,
62 NamedAttrList &result) {
63 return parser.parseOptionalAttrDict(result);
64 }
65
printLLVMOpAttrs(OpAsmPrinter & printer,Operation * op,DictionaryAttr attrs)66 static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op,
67 DictionaryAttr attrs) {
68 printer.printOptionalAttrDict(processFMFAttr(attrs.getValue()));
69 }
70
71 //===----------------------------------------------------------------------===//
72 // Printing/parsing for LLVM::CmpOp.
73 //===----------------------------------------------------------------------===//
printICmpOp(OpAsmPrinter & p,ICmpOp & op)74 static void printICmpOp(OpAsmPrinter &p, ICmpOp &op) {
75 p << op.getOperationName() << " \"" << stringifyICmpPredicate(op.predicate())
76 << "\" " << op.getOperand(0) << ", " << op.getOperand(1);
77 p.printOptionalAttrDict(op->getAttrs(), {"predicate"});
78 p << " : " << op.lhs().getType();
79 }
80
printFCmpOp(OpAsmPrinter & p,FCmpOp & op)81 static void printFCmpOp(OpAsmPrinter &p, FCmpOp &op) {
82 p << op.getOperationName() << " \"" << stringifyFCmpPredicate(op.predicate())
83 << "\" " << op.getOperand(0) << ", " << op.getOperand(1);
84 p.printOptionalAttrDict(processFMFAttr(op->getAttrs()), {"predicate"});
85 p << " : " << op.lhs().getType();
86 }
87
88 // <operation> ::= `llvm.icmp` string-literal ssa-use `,` ssa-use
89 // attribute-dict? `:` type
90 // <operation> ::= `llvm.fcmp` string-literal ssa-use `,` ssa-use
91 // attribute-dict? `:` type
92 template <typename CmpPredicateType>
parseCmpOp(OpAsmParser & parser,OperationState & result)93 static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
94 Builder &builder = parser.getBuilder();
95
96 StringAttr predicateAttr;
97 OpAsmParser::OperandType lhs, rhs;
98 Type type;
99 llvm::SMLoc predicateLoc, trailingTypeLoc;
100 if (parser.getCurrentLocation(&predicateLoc) ||
101 parser.parseAttribute(predicateAttr, "predicate", result.attributes) ||
102 parser.parseOperand(lhs) || parser.parseComma() ||
103 parser.parseOperand(rhs) ||
104 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
105 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) ||
106 parser.resolveOperand(lhs, type, result.operands) ||
107 parser.resolveOperand(rhs, type, result.operands))
108 return failure();
109
110 // Replace the string attribute `predicate` with an integer attribute.
111 int64_t predicateValue = 0;
112 if (std::is_same<CmpPredicateType, ICmpPredicate>()) {
113 Optional<ICmpPredicate> predicate =
114 symbolizeICmpPredicate(predicateAttr.getValue());
115 if (!predicate)
116 return parser.emitError(predicateLoc)
117 << "'" << predicateAttr.getValue()
118 << "' is an incorrect value of the 'predicate' attribute";
119 predicateValue = static_cast<int64_t>(predicate.getValue());
120 } else {
121 Optional<FCmpPredicate> predicate =
122 symbolizeFCmpPredicate(predicateAttr.getValue());
123 if (!predicate)
124 return parser.emitError(predicateLoc)
125 << "'" << predicateAttr.getValue()
126 << "' is an incorrect value of the 'predicate' attribute";
127 predicateValue = static_cast<int64_t>(predicate.getValue());
128 }
129
130 result.attributes.set("predicate",
131 parser.getBuilder().getI64IntegerAttr(predicateValue));
132
133 // The result type is either i1 or a vector type <? x i1> if the inputs are
134 // vectors.
135 Type resultType = IntegerType::get(builder.getContext(), 1);
136 if (!isCompatibleType(type))
137 return parser.emitError(trailingTypeLoc,
138 "expected LLVM dialect-compatible type");
139 if (LLVM::isCompatibleVectorType(type)) {
140 if (type.isa<LLVM::LLVMScalableVectorType>()) {
141 resultType = LLVM::LLVMScalableVectorType::get(
142 resultType, LLVM::getVectorNumElements(type).getKnownMinValue());
143 } else {
144 resultType = LLVM::getFixedVectorType(
145 resultType, LLVM::getVectorNumElements(type).getFixedValue());
146 }
147 }
148
149 result.addTypes({resultType});
150 return success();
151 }
152
153 //===----------------------------------------------------------------------===//
154 // Printing/parsing for LLVM::AllocaOp.
155 //===----------------------------------------------------------------------===//
156
printAllocaOp(OpAsmPrinter & p,AllocaOp & op)157 static void printAllocaOp(OpAsmPrinter &p, AllocaOp &op) {
158 auto elemTy = op.getType().cast<LLVM::LLVMPointerType>().getElementType();
159
160 auto funcTy = FunctionType::get(op.getContext(), {op.arraySize().getType()},
161 {op.getType()});
162
163 p << op.getOperationName() << ' ' << op.arraySize() << " x " << elemTy;
164 if (op.alignment().hasValue() && *op.alignment() != 0)
165 p.printOptionalAttrDict(op->getAttrs());
166 else
167 p.printOptionalAttrDict(op->getAttrs(), {"alignment"});
168 p << " : " << funcTy;
169 }
170
171 // <operation> ::= `llvm.alloca` ssa-use `x` type attribute-dict?
172 // `:` type `,` type
parseAllocaOp(OpAsmParser & parser,OperationState & result)173 static ParseResult parseAllocaOp(OpAsmParser &parser, OperationState &result) {
174 OpAsmParser::OperandType arraySize;
175 Type type, elemType;
176 llvm::SMLoc trailingTypeLoc;
177 if (parser.parseOperand(arraySize) || parser.parseKeyword("x") ||
178 parser.parseType(elemType) ||
179 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
180 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
181 return failure();
182
183 Optional<NamedAttribute> alignmentAttr =
184 result.attributes.getNamed("alignment");
185 if (alignmentAttr.hasValue()) {
186 auto alignmentInt = alignmentAttr.getValue().second.dyn_cast<IntegerAttr>();
187 if (!alignmentInt)
188 return parser.emitError(parser.getNameLoc(),
189 "expected integer alignment");
190 if (alignmentInt.getValue().isNullValue())
191 result.attributes.erase("alignment");
192 }
193
194 // Extract the result type from the trailing function type.
195 auto funcType = type.dyn_cast<FunctionType>();
196 if (!funcType || funcType.getNumInputs() != 1 ||
197 funcType.getNumResults() != 1)
198 return parser.emitError(
199 trailingTypeLoc,
200 "expected trailing function type with one argument and one result");
201
202 if (parser.resolveOperand(arraySize, funcType.getInput(0), result.operands))
203 return failure();
204
205 result.addTypes({funcType.getResult(0)});
206 return success();
207 }
208
209 //===----------------------------------------------------------------------===//
210 // LLVM::BrOp
211 //===----------------------------------------------------------------------===//
212
213 Optional<MutableOperandRange>
getMutableSuccessorOperands(unsigned index)214 BrOp::getMutableSuccessorOperands(unsigned index) {
215 assert(index == 0 && "invalid successor index");
216 return destOperandsMutable();
217 }
218
219 //===----------------------------------------------------------------------===//
220 // LLVM::CondBrOp
221 //===----------------------------------------------------------------------===//
222
223 Optional<MutableOperandRange>
getMutableSuccessorOperands(unsigned index)224 CondBrOp::getMutableSuccessorOperands(unsigned index) {
225 assert(index < getNumSuccessors() && "invalid successor index");
226 return index == 0 ? trueDestOperandsMutable() : falseDestOperandsMutable();
227 }
228
229 //===----------------------------------------------------------------------===//
230 // LLVM::SwitchOp
231 //===----------------------------------------------------------------------===//
232
build(OpBuilder & builder,OperationState & result,Value value,Block * defaultDestination,ValueRange defaultOperands,ArrayRef<int32_t> caseValues,BlockRange caseDestinations,ArrayRef<ValueRange> caseOperands,ArrayRef<int32_t> branchWeights)233 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
234 Block *defaultDestination, ValueRange defaultOperands,
235 ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
236 ArrayRef<ValueRange> caseOperands,
237 ArrayRef<int32_t> branchWeights) {
238 SmallVector<Value> flattenedCaseOperands;
239 SmallVector<int32_t> caseOperandOffsets;
240 int32_t offset = 0;
241 for (ValueRange operands : caseOperands) {
242 flattenedCaseOperands.append(operands.begin(), operands.end());
243 caseOperandOffsets.push_back(offset);
244 offset += operands.size();
245 }
246 ElementsAttr caseValuesAttr;
247 if (!caseValues.empty())
248 caseValuesAttr = builder.getI32VectorAttr(caseValues);
249 ElementsAttr caseOperandOffsetsAttr;
250 if (!caseOperandOffsets.empty())
251 caseOperandOffsetsAttr = builder.getI32VectorAttr(caseOperandOffsets);
252
253 ElementsAttr weightsAttr;
254 if (!branchWeights.empty())
255 weightsAttr = builder.getI32VectorAttr(llvm::to_vector<4>(branchWeights));
256
257 build(builder, result, value, defaultOperands, flattenedCaseOperands,
258 caseValuesAttr, caseOperandOffsetsAttr, weightsAttr, defaultDestination,
259 caseDestinations);
260 }
261
262 /// <cases> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)?
263 /// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )?
264 static ParseResult
parseSwitchOpCases(OpAsmParser & parser,ElementsAttr & caseValues,SmallVectorImpl<Block * > & caseDestinations,SmallVectorImpl<OpAsmParser::OperandType> & caseOperands,SmallVectorImpl<Type> & caseOperandTypes,ElementsAttr & caseOperandOffsets)265 parseSwitchOpCases(OpAsmParser &parser, ElementsAttr &caseValues,
266 SmallVectorImpl<Block *> &caseDestinations,
267 SmallVectorImpl<OpAsmParser::OperandType> &caseOperands,
268 SmallVectorImpl<Type> &caseOperandTypes,
269 ElementsAttr &caseOperandOffsets) {
270 SmallVector<int32_t> values;
271 SmallVector<int32_t> offsets;
272 int32_t value, offset = 0;
273 do {
274 OptionalParseResult integerParseResult = parser.parseOptionalInteger(value);
275 if (values.empty() && !integerParseResult.hasValue())
276 return success();
277
278 if (!integerParseResult.hasValue() || integerParseResult.getValue())
279 return failure();
280 values.push_back(value);
281
282 Block *destination;
283 SmallVector<OpAsmParser::OperandType> operands;
284 if (parser.parseColon() || parser.parseSuccessor(destination))
285 return failure();
286 if (!parser.parseOptionalLParen()) {
287 if (parser.parseRegionArgumentList(operands) ||
288 parser.parseColonTypeList(caseOperandTypes) || parser.parseRParen())
289 return failure();
290 }
291 caseDestinations.push_back(destination);
292 caseOperands.append(operands.begin(), operands.end());
293 offsets.push_back(offset);
294 offset += operands.size();
295 } while (!parser.parseOptionalComma());
296
297 Builder &builder = parser.getBuilder();
298 caseValues = builder.getI32VectorAttr(values);
299 caseOperandOffsets = builder.getI32VectorAttr(offsets);
300
301 return success();
302 }
303
printSwitchOpCases(OpAsmPrinter & p,SwitchOp op,ElementsAttr caseValues,SuccessorRange caseDestinations,OperandRange caseOperands,TypeRange caseOperandTypes,ElementsAttr caseOperandOffsets)304 static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op,
305 ElementsAttr caseValues,
306 SuccessorRange caseDestinations,
307 OperandRange caseOperands,
308 TypeRange caseOperandTypes,
309 ElementsAttr caseOperandOffsets) {
310 if (!caseValues)
311 return;
312
313 size_t index = 0;
314 llvm::interleave(
315 llvm::zip(caseValues.cast<DenseIntElementsAttr>(), caseDestinations),
316 [&](auto i) {
317 p << " ";
318 p << std::get<0>(i).getLimitedValue();
319 p << ": ";
320 p.printSuccessorAndUseList(std::get<1>(i), op.getCaseOperands(index++));
321 },
322 [&] {
323 p << ',';
324 p.printNewline();
325 });
326 p.printNewline();
327 }
328
verify(SwitchOp op)329 static LogicalResult verify(SwitchOp op) {
330 if ((!op.case_values() && !op.caseDestinations().empty()) ||
331 (op.case_values() &&
332 op.case_values()->size() !=
333 static_cast<int64_t>(op.caseDestinations().size())))
334 return op.emitOpError("expects number of case values to match number of "
335 "case destinations");
336 if (op.branch_weights() &&
337 op.branch_weights()->size() != op.getNumSuccessors())
338 return op.emitError("expects number of branch weights to match number of "
339 "successors: ")
340 << op.branch_weights()->size() << " vs " << op.getNumSuccessors();
341 return success();
342 }
343
getCaseOperands(unsigned index)344 OperandRange SwitchOp::getCaseOperands(unsigned index) {
345 return getCaseOperandsMutable(index);
346 }
347
getCaseOperandsMutable(unsigned index)348 MutableOperandRange SwitchOp::getCaseOperandsMutable(unsigned index) {
349 MutableOperandRange caseOperands = caseOperandsMutable();
350 if (!case_operand_offsets()) {
351 assert(caseOperands.size() == 0 &&
352 "non-empty case operands must have offsets");
353 return caseOperands;
354 }
355
356 ElementsAttr offsets = case_operand_offsets().getValue();
357 assert(index < offsets.size() && "invalid case operand offset index");
358
359 int64_t begin = offsets.getValue(index).cast<IntegerAttr>().getInt();
360 int64_t end = index + 1 == offsets.size()
361 ? caseOperands.size()
362 : offsets.getValue(index + 1).cast<IntegerAttr>().getInt();
363 return caseOperandsMutable().slice(begin, end - begin);
364 }
365
366 Optional<MutableOperandRange>
getMutableSuccessorOperands(unsigned index)367 SwitchOp::getMutableSuccessorOperands(unsigned index) {
368 assert(index < getNumSuccessors() && "invalid successor index");
369 return index == 0 ? defaultOperandsMutable()
370 : getCaseOperandsMutable(index - 1);
371 }
372
373 //===----------------------------------------------------------------------===//
// Builder, printer and parser for LLVM::LoadOp.
375 //===----------------------------------------------------------------------===//
376
verifyAccessGroups(Operation * op)377 static LogicalResult verifyAccessGroups(Operation *op) {
378 if (Attribute attribute =
379 op->getAttr(LLVMDialect::getAccessGroupsAttrName())) {
380 // The attribute is already verified to be a symbol ref array attribute via
381 // a constraint in the operation definition.
382 for (SymbolRefAttr accessGroupRef :
383 attribute.cast<ArrayAttr>().getAsRange<SymbolRefAttr>()) {
384 StringRef metadataName = accessGroupRef.getRootReference();
385 auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
386 op->getParentOp(), metadataName);
387 if (!metadataOp)
388 return op->emitOpError() << "expected '" << accessGroupRef
389 << "' to reference a metadata op";
390 StringRef accessGroupName = accessGroupRef.getLeafReference();
391 Operation *accessGroupOp =
392 SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName);
393 if (!accessGroupOp)
394 return op->emitOpError() << "expected '" << accessGroupRef
395 << "' to reference an access_group op";
396 }
397 }
398 return success();
399 }
400
verify(LoadOp op)401 static LogicalResult verify(LoadOp op) {
402 return verifyAccessGroups(op.getOperation());
403 }
404
build(OpBuilder & builder,OperationState & result,Type t,Value addr,unsigned alignment,bool isVolatile,bool isNonTemporal)405 void LoadOp::build(OpBuilder &builder, OperationState &result, Type t,
406 Value addr, unsigned alignment, bool isVolatile,
407 bool isNonTemporal) {
408 result.addOperands(addr);
409 result.addTypes(t);
410 if (isVolatile)
411 result.addAttribute(kVolatileAttrName, builder.getUnitAttr());
412 if (isNonTemporal)
413 result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr());
414 if (alignment != 0)
415 result.addAttribute("alignment", builder.getI64IntegerAttr(alignment));
416 }
417
printLoadOp(OpAsmPrinter & p,LoadOp & op)418 static void printLoadOp(OpAsmPrinter &p, LoadOp &op) {
419 p << op.getOperationName() << ' ';
420 if (op.volatile_())
421 p << "volatile ";
422 p << op.addr();
423 p.printOptionalAttrDict(op->getAttrs(), {kVolatileAttrName});
424 p << " : " << op.addr().getType();
425 }
426
427 // Extract the pointee type from the LLVM pointer type wrapped in MLIR. Return
428 // the resulting type wrapped in MLIR, or nullptr on error.
getLoadStoreElementType(OpAsmParser & parser,Type type,llvm::SMLoc trailingTypeLoc)429 static Type getLoadStoreElementType(OpAsmParser &parser, Type type,
430 llvm::SMLoc trailingTypeLoc) {
431 auto llvmTy = type.dyn_cast<LLVM::LLVMPointerType>();
432 if (!llvmTy)
433 return parser.emitError(trailingTypeLoc, "expected LLVM pointer type"),
434 nullptr;
435 return llvmTy.getElementType();
436 }
437
438 // <operation> ::= `llvm.load` `volatile` ssa-use attribute-dict? `:` type
parseLoadOp(OpAsmParser & parser,OperationState & result)439 static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
440 OpAsmParser::OperandType addr;
441 Type type;
442 llvm::SMLoc trailingTypeLoc;
443
444 if (succeeded(parser.parseOptionalKeyword("volatile")))
445 result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr());
446
447 if (parser.parseOperand(addr) ||
448 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
449 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) ||
450 parser.resolveOperand(addr, type, result.operands))
451 return failure();
452
453 Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc);
454
455 result.addTypes(elemTy);
456 return success();
457 }
458
459 //===----------------------------------------------------------------------===//
460 // Builder, printer and parser for LLVM::StoreOp.
461 //===----------------------------------------------------------------------===//
462
verify(StoreOp op)463 static LogicalResult verify(StoreOp op) {
464 return verifyAccessGroups(op.getOperation());
465 }
466
build(OpBuilder & builder,OperationState & result,Value value,Value addr,unsigned alignment,bool isVolatile,bool isNonTemporal)467 void StoreOp::build(OpBuilder &builder, OperationState &result, Value value,
468 Value addr, unsigned alignment, bool isVolatile,
469 bool isNonTemporal) {
470 result.addOperands({value, addr});
471 result.addTypes({});
472 if (isVolatile)
473 result.addAttribute(kVolatileAttrName, builder.getUnitAttr());
474 if (isNonTemporal)
475 result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr());
476 if (alignment != 0)
477 result.addAttribute("alignment", builder.getI64IntegerAttr(alignment));
478 }
479
printStoreOp(OpAsmPrinter & p,StoreOp & op)480 static void printStoreOp(OpAsmPrinter &p, StoreOp &op) {
481 p << op.getOperationName() << ' ';
482 if (op.volatile_())
483 p << "volatile ";
484 p << op.value() << ", " << op.addr();
485 p.printOptionalAttrDict(op->getAttrs(), {kVolatileAttrName});
486 p << " : " << op.addr().getType();
487 }
488
489 // <operation> ::= `llvm.store` `volatile` ssa-use `,` ssa-use
490 // attribute-dict? `:` type
parseStoreOp(OpAsmParser & parser,OperationState & result)491 static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
492 OpAsmParser::OperandType addr, value;
493 Type type;
494 llvm::SMLoc trailingTypeLoc;
495
496 if (succeeded(parser.parseOptionalKeyword("volatile")))
497 result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr());
498
499 if (parser.parseOperand(value) || parser.parseComma() ||
500 parser.parseOperand(addr) ||
501 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
502 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
503 return failure();
504
505 Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc);
506 if (!elemTy)
507 return failure();
508
509 if (parser.resolveOperand(value, elemTy, result.operands) ||
510 parser.resolveOperand(addr, type, result.operands))
511 return failure();
512
513 return success();
514 }
515
516 ///===---------------------------------------------------------------------===//
517 /// LLVM::InvokeOp
518 ///===---------------------------------------------------------------------===//
519
520 Optional<MutableOperandRange>
getMutableSuccessorOperands(unsigned index)521 InvokeOp::getMutableSuccessorOperands(unsigned index) {
522 assert(index < getNumSuccessors() && "invalid successor index");
523 return index == 0 ? normalDestOperandsMutable() : unwindDestOperandsMutable();
524 }
525
verify(InvokeOp op)526 static LogicalResult verify(InvokeOp op) {
527 if (op.getNumResults() > 1)
528 return op.emitOpError("must have 0 or 1 result");
529
530 Block *unwindDest = op.unwindDest();
531 if (unwindDest->empty())
532 return op.emitError(
533 "must have at least one operation in unwind destination");
534
535 // In unwind destination, first operation must be LandingpadOp
536 if (!isa<LandingpadOp>(unwindDest->front()))
537 return op.emitError("first operation in unwind destination should be a "
538 "llvm.landingpad operation");
539
540 return success();
541 }
542
printInvokeOp(OpAsmPrinter & p,InvokeOp op)543 static void printInvokeOp(OpAsmPrinter &p, InvokeOp op) {
544 auto callee = op.callee();
545 bool isDirect = callee.hasValue();
546
547 p << op.getOperationName() << ' ';
548
549 // Either function name or pointer
550 if (isDirect)
551 p.printSymbolName(callee.getValue());
552 else
553 p << op.getOperand(0);
554
555 p << '(' << op.getOperands().drop_front(isDirect ? 0 : 1) << ')';
556 p << " to ";
557 p.printSuccessorAndUseList(op.normalDest(), op.normalDestOperands());
558 p << " unwind ";
559 p.printSuccessorAndUseList(op.unwindDest(), op.unwindDestOperands());
560
561 p.printOptionalAttrDict(op->getAttrs(),
562 {InvokeOp::getOperandSegmentSizeAttr(), "callee"});
563 p << " : ";
564 p.printFunctionalType(
565 llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1),
566 op.getResultTypes());
567 }
568
569 /// <operation> ::= `llvm.invoke` (function-id | ssa-use) `(` ssa-use-list `)`
570 /// `to` bb-id (`[` ssa-use-and-type-list `]`)?
571 /// `unwind` bb-id (`[` ssa-use-and-type-list `]`)?
572 /// attribute-dict? `:` function-type
parseInvokeOp(OpAsmParser & parser,OperationState & result)573 static ParseResult parseInvokeOp(OpAsmParser &parser, OperationState &result) {
574 SmallVector<OpAsmParser::OperandType, 8> operands;
575 FunctionType funcType;
576 SymbolRefAttr funcAttr;
577 llvm::SMLoc trailingTypeLoc;
578 Block *normalDest, *unwindDest;
579 SmallVector<Value, 4> normalOperands, unwindOperands;
580 Builder &builder = parser.getBuilder();
581
582 // Parse an operand list that will, in practice, contain 0 or 1 operand. In
583 // case of an indirect call, there will be 1 operand before `(`. In case of a
584 // direct call, there will be no operands and the parser will stop at the
585 // function identifier without complaining.
586 if (parser.parseOperandList(operands))
587 return failure();
588 bool isDirect = operands.empty();
589
590 // Optionally parse a function identifier.
591 if (isDirect && parser.parseAttribute(funcAttr, "callee", result.attributes))
592 return failure();
593
594 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
595 parser.parseKeyword("to") ||
596 parser.parseSuccessorAndUseList(normalDest, normalOperands) ||
597 parser.parseKeyword("unwind") ||
598 parser.parseSuccessorAndUseList(unwindDest, unwindOperands) ||
599 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
600 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(funcType))
601 return failure();
602
603 if (isDirect) {
604 // Make sure types match.
605 if (parser.resolveOperands(operands, funcType.getInputs(),
606 parser.getNameLoc(), result.operands))
607 return failure();
608 result.addTypes(funcType.getResults());
609 } else {
610 // Construct the LLVM IR Dialect function type that the first operand
611 // should match.
612 if (funcType.getNumResults() > 1)
613 return parser.emitError(trailingTypeLoc,
614 "expected function with 0 or 1 result");
615
616 Type llvmResultType;
617 if (funcType.getNumResults() == 0) {
618 llvmResultType = LLVM::LLVMVoidType::get(builder.getContext());
619 } else {
620 llvmResultType = funcType.getResult(0);
621 if (!isCompatibleType(llvmResultType))
622 return parser.emitError(trailingTypeLoc,
623 "expected result to have LLVM type");
624 }
625
626 SmallVector<Type, 8> argTypes;
627 argTypes.reserve(funcType.getNumInputs());
628 for (Type ty : funcType.getInputs()) {
629 if (isCompatibleType(ty))
630 argTypes.push_back(ty);
631 else
632 return parser.emitError(trailingTypeLoc,
633 "expected LLVM types as inputs");
634 }
635
636 auto llvmFuncType = LLVM::LLVMFunctionType::get(llvmResultType, argTypes);
637 auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType);
638
639 auto funcArguments = llvm::makeArrayRef(operands).drop_front();
640
641 // Make sure that the first operand (indirect callee) matches the wrapped
642 // LLVM IR function type, and that the types of the other call operands
643 // match the types of the function arguments.
644 if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) ||
645 parser.resolveOperands(funcArguments, funcType.getInputs(),
646 parser.getNameLoc(), result.operands))
647 return failure();
648
649 result.addTypes(llvmResultType);
650 }
651 result.addSuccessors({normalDest, unwindDest});
652 result.addOperands(normalOperands);
653 result.addOperands(unwindOperands);
654
655 result.addAttribute(
656 InvokeOp::getOperandSegmentSizeAttr(),
657 builder.getI32VectorAttr({static_cast<int32_t>(operands.size()),
658 static_cast<int32_t>(normalOperands.size()),
659 static_cast<int32_t>(unwindOperands.size())}));
660 return success();
661 }
662
663 ///===----------------------------------------------------------------------===//
664 /// Verifying/Printing/Parsing for LLVM::LandingpadOp.
665 ///===----------------------------------------------------------------------===//
666
verify(LandingpadOp op)667 static LogicalResult verify(LandingpadOp op) {
668 Value value;
669 if (LLVMFuncOp func = op->getParentOfType<LLVMFuncOp>()) {
670 if (!func.personality().hasValue())
671 return op.emitError(
672 "llvm.landingpad needs to be in a function with a personality");
673 }
674
675 if (!op.cleanup() && op.getOperands().empty())
676 return op.emitError("landingpad instruction expects at least one clause or "
677 "cleanup attribute");
678
679 for (unsigned idx = 0, ie = op.getNumOperands(); idx < ie; idx++) {
680 value = op.getOperand(idx);
681 bool isFilter = value.getType().isa<LLVMArrayType>();
682 if (isFilter) {
683 // FIXME: Verify filter clauses when arrays are appropriately handled
684 } else {
685 // catch - global addresses only.
686 // Bitcast ops should have global addresses as their args.
687 if (auto bcOp = value.getDefiningOp<BitcastOp>()) {
688 if (auto addrOp = bcOp.arg().getDefiningOp<AddressOfOp>())
689 continue;
690 return op.emitError("constant clauses expected")
691 .attachNote(bcOp.getLoc())
692 << "global addresses expected as operand to "
693 "bitcast used in clauses for landingpad";
694 }
695 // NullOp and AddressOfOp allowed
696 if (value.getDefiningOp<NullOp>())
697 continue;
698 if (value.getDefiningOp<AddressOfOp>())
699 continue;
700 return op.emitError("clause #")
701 << idx << " is not a known constant - null, addressof, bitcast";
702 }
703 }
704 return success();
705 }
706
printLandingpadOp(OpAsmPrinter & p,LandingpadOp & op)707 static void printLandingpadOp(OpAsmPrinter &p, LandingpadOp &op) {
708 p << op.getOperationName() << (op.cleanup() ? " cleanup " : " ");
709
710 // Clauses
711 for (auto value : op.getOperands()) {
712 // Similar to llvm - if clause is an array type then it is filter
713 // clause else catch clause
714 bool isArrayTy = value.getType().isa<LLVMArrayType>();
715 p << '(' << (isArrayTy ? "filter " : "catch ") << value << " : "
716 << value.getType() << ") ";
717 }
718
719 p.printOptionalAttrDict(op->getAttrs(), {"cleanup"});
720
721 p << ": " << op.getType();
722 }
723
724 /// <operation> ::= `llvm.landingpad` `cleanup`?
725 /// ((`catch` | `filter`) operand-type ssa-use)* attribute-dict?
parseLandingpadOp(OpAsmParser & parser,OperationState & result)726 static ParseResult parseLandingpadOp(OpAsmParser &parser,
727 OperationState &result) {
728 // Check for cleanup
729 if (succeeded(parser.parseOptionalKeyword("cleanup")))
730 result.addAttribute("cleanup", parser.getBuilder().getUnitAttr());
731
732 // Parse clauses with types
733 while (succeeded(parser.parseOptionalLParen()) &&
734 (succeeded(parser.parseOptionalKeyword("filter")) ||
735 succeeded(parser.parseOptionalKeyword("catch")))) {
736 OpAsmParser::OperandType operand;
737 Type ty;
738 if (parser.parseOperand(operand) || parser.parseColon() ||
739 parser.parseType(ty) ||
740 parser.resolveOperand(operand, ty, result.operands) ||
741 parser.parseRParen())
742 return failure();
743 }
744
745 Type type;
746 if (parser.parseColon() || parser.parseType(type))
747 return failure();
748
749 result.addTypes(type);
750 return success();
751 }
752
753 //===----------------------------------------------------------------------===//
754 // Verifying/Printing/parsing for LLVM::CallOp.
755 //===----------------------------------------------------------------------===//
756
verify(CallOp & op)757 static LogicalResult verify(CallOp &op) {
758 if (op.getNumResults() > 1)
759 return op.emitOpError("must have 0 or 1 result");
760
761 // Type for the callee, we'll get it differently depending if it is a direct
762 // or indirect call.
763 Type fnType;
764
765 bool isIndirect = false;
766
767 // If this is an indirect call, the callee attribute is missing.
768 Optional<StringRef> calleeName = op.callee();
769 if (!calleeName) {
770 isIndirect = true;
771 if (!op.getNumOperands())
772 return op.emitOpError(
773 "must have either a `callee` attribute or at least an operand");
774 auto ptrType = op.getOperand(0).getType().dyn_cast<LLVMPointerType>();
775 if (!ptrType)
776 return op.emitOpError("indirect call expects a pointer as callee: ")
777 << ptrType;
778 fnType = ptrType.getElementType();
779 } else {
780 Operation *callee = SymbolTable::lookupNearestSymbolFrom(op, *calleeName);
781 if (!callee)
782 return op.emitOpError()
783 << "'" << *calleeName
784 << "' does not reference a symbol in the current scope";
785 auto fn = dyn_cast<LLVMFuncOp>(callee);
786 if (!fn)
787 return op.emitOpError() << "'" << *calleeName
788 << "' does not reference a valid LLVM function";
789
790 fnType = fn.getType();
791 }
792
793 LLVMFunctionType funcType = fnType.dyn_cast<LLVMFunctionType>();
794 if (!funcType)
795 return op.emitOpError("callee does not have a functional type: ") << fnType;
796
797 // Verify that the operand and result types match the callee.
798
799 if (!funcType.isVarArg() &&
800 funcType.getNumParams() != (op.getNumOperands() - isIndirect))
801 return op.emitOpError()
802 << "incorrect number of operands ("
803 << (op.getNumOperands() - isIndirect)
804 << ") for callee (expecting: " << funcType.getNumParams() << ")";
805
806 if (funcType.getNumParams() > (op.getNumOperands() - isIndirect))
807 return op.emitOpError() << "incorrect number of operands ("
808 << (op.getNumOperands() - isIndirect)
809 << ") for varargs callee (expecting at least: "
810 << funcType.getNumParams() << ")";
811
812 for (unsigned i = 0, e = funcType.getNumParams(); i != e; ++i)
813 if (op.getOperand(i + isIndirect).getType() != funcType.getParamType(i))
814 return op.emitOpError() << "operand type mismatch for operand " << i
815 << ": " << op.getOperand(i + isIndirect).getType()
816 << " != " << funcType.getParamType(i);
817
818 if (op.getNumResults() &&
819 op.getResult(0).getType() != funcType.getReturnType())
820 return op.emitOpError()
821 << "result type mismatch: " << op.getResult(0).getType()
822 << " != " << funcType.getReturnType();
823
824 return success();
825 }
826
printCallOp(OpAsmPrinter & p,CallOp & op)827 static void printCallOp(OpAsmPrinter &p, CallOp &op) {
828 auto callee = op.callee();
829 bool isDirect = callee.hasValue();
830
831 // Print the direct callee if present as a function attribute, or an indirect
832 // callee (first operand) otherwise.
833 p << op.getOperationName() << ' ';
834 if (isDirect)
835 p.printSymbolName(callee.getValue());
836 else
837 p << op.getOperand(0);
838
839 auto args = op.getOperands().drop_front(isDirect ? 0 : 1);
840 p << '(' << args << ')';
841 p.printOptionalAttrDict(processFMFAttr(op->getAttrs()), {"callee"});
842
843 // Reconstruct the function MLIR function type from operand and result types.
844 p << " : "
845 << FunctionType::get(op.getContext(), args.getTypes(), op.getResultTypes());
846 }
847
848 // <operation> ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)`
849 // attribute-dict? `:` function-type
parseCallOp(OpAsmParser & parser,OperationState & result)850 static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) {
851 SmallVector<OpAsmParser::OperandType, 8> operands;
852 Type type;
853 SymbolRefAttr funcAttr;
854 llvm::SMLoc trailingTypeLoc;
855
856 // Parse an operand list that will, in practice, contain 0 or 1 operand. In
857 // case of an indirect call, there will be 1 operand before `(`. In case of a
858 // direct call, there will be no operands and the parser will stop at the
859 // function identifier without complaining.
860 if (parser.parseOperandList(operands))
861 return failure();
862 bool isDirect = operands.empty();
863
864 // Optionally parse a function identifier.
865 if (isDirect)
866 if (parser.parseAttribute(funcAttr, "callee", result.attributes))
867 return failure();
868
869 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
870 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
871 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
872 return failure();
873
874 auto funcType = type.dyn_cast<FunctionType>();
875 if (!funcType)
876 return parser.emitError(trailingTypeLoc, "expected function type");
877 if (isDirect) {
878 // Make sure types match.
879 if (parser.resolveOperands(operands, funcType.getInputs(),
880 parser.getNameLoc(), result.operands))
881 return failure();
882 result.addTypes(funcType.getResults());
883 } else {
884 // Construct the LLVM IR Dialect function type that the first operand
885 // should match.
886 if (funcType.getNumResults() > 1)
887 return parser.emitError(trailingTypeLoc,
888 "expected function with 0 or 1 result");
889
890 Builder &builder = parser.getBuilder();
891 Type llvmResultType;
892 if (funcType.getNumResults() == 0) {
893 llvmResultType = LLVM::LLVMVoidType::get(builder.getContext());
894 } else {
895 llvmResultType = funcType.getResult(0);
896 if (!isCompatibleType(llvmResultType))
897 return parser.emitError(trailingTypeLoc,
898 "expected result to have LLVM type");
899 }
900
901 SmallVector<Type, 8> argTypes;
902 argTypes.reserve(funcType.getNumInputs());
903 for (int i = 0, e = funcType.getNumInputs(); i < e; ++i) {
904 auto argType = funcType.getInput(i);
905 if (!isCompatibleType(argType))
906 return parser.emitError(trailingTypeLoc,
907 "expected LLVM types as inputs");
908 argTypes.push_back(argType);
909 }
910 auto llvmFuncType = LLVM::LLVMFunctionType::get(llvmResultType, argTypes);
911 auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType);
912
913 auto funcArguments =
914 ArrayRef<OpAsmParser::OperandType>(operands).drop_front();
915
916 // Make sure that the first operand (indirect callee) matches the wrapped
917 // LLVM IR function type, and that the types of the other call operands
918 // match the types of the function arguments.
919 if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) ||
920 parser.resolveOperands(funcArguments, funcType.getInputs(),
921 parser.getNameLoc(), result.operands))
922 return failure();
923
924 result.addTypes(llvmResultType);
925 }
926
927 return success();
928 }
929
930 //===----------------------------------------------------------------------===//
931 // Printing/parsing for LLVM::ExtractElementOp.
932 //===----------------------------------------------------------------------===//
933 // Expects vector to be of wrapped LLVM vector type and position to be of
934 // wrapped LLVM i32 type.
build(OpBuilder & b,OperationState & result,Value vector,Value position,ArrayRef<NamedAttribute> attrs)935 void LLVM::ExtractElementOp::build(OpBuilder &b, OperationState &result,
936 Value vector, Value position,
937 ArrayRef<NamedAttribute> attrs) {
938 auto vectorType = vector.getType();
939 auto llvmType = LLVM::getVectorElementType(vectorType);
940 build(b, result, llvmType, vector, position);
941 result.addAttributes(attrs);
942 }
943
printExtractElementOp(OpAsmPrinter & p,ExtractElementOp & op)944 static void printExtractElementOp(OpAsmPrinter &p, ExtractElementOp &op) {
945 p << op.getOperationName() << ' ' << op.vector() << "[" << op.position()
946 << " : " << op.position().getType() << "]";
947 p.printOptionalAttrDict(op->getAttrs());
948 p << " : " << op.vector().getType();
949 }
950
951 // <operation> ::= `llvm.extractelement` ssa-use `, ` ssa-use
952 // attribute-dict? `:` type
parseExtractElementOp(OpAsmParser & parser,OperationState & result)953 static ParseResult parseExtractElementOp(OpAsmParser &parser,
954 OperationState &result) {
955 llvm::SMLoc loc;
956 OpAsmParser::OperandType vector, position;
957 Type type, positionType;
958 if (parser.getCurrentLocation(&loc) || parser.parseOperand(vector) ||
959 parser.parseLSquare() || parser.parseOperand(position) ||
960 parser.parseColonType(positionType) || parser.parseRSquare() ||
961 parser.parseOptionalAttrDict(result.attributes) ||
962 parser.parseColonType(type) ||
963 parser.resolveOperand(vector, type, result.operands) ||
964 parser.resolveOperand(position, positionType, result.operands))
965 return failure();
966 if (!LLVM::isCompatibleVectorType(type))
967 return parser.emitError(
968 loc, "expected LLVM dialect-compatible vector type for operand #1");
969 result.addTypes(LLVM::getVectorElementType(type));
970 return success();
971 }
972
verify(ExtractElementOp op)973 static LogicalResult verify(ExtractElementOp op) {
974 Type vectorType = op.vector().getType();
975 if (!LLVM::isCompatibleVectorType(vectorType))
976 return op->emitOpError("expected LLVM dialect-compatible vector type for "
977 "operand #1, got")
978 << vectorType;
979 Type valueType = LLVM::getVectorElementType(vectorType);
980 if (valueType != op.res().getType())
981 return op.emitOpError() << "Type mismatch: extracting from " << vectorType
982 << " should produce " << valueType
983 << " but this op returns " << op.res().getType();
984 return success();
985 }
986
987 //===----------------------------------------------------------------------===//
988 // Printing/parsing for LLVM::ExtractValueOp.
989 //===----------------------------------------------------------------------===//
990
printExtractValueOp(OpAsmPrinter & p,ExtractValueOp & op)991 static void printExtractValueOp(OpAsmPrinter &p, ExtractValueOp &op) {
992 p << op.getOperationName() << ' ' << op.container() << op.position();
993 p.printOptionalAttrDict(op->getAttrs(), {"position"});
994 p << " : " << op.container().getType();
995 }
996
997 // Extract the type at `position` in the wrapped LLVM IR aggregate type
998 // `containerType`. Position is an integer array attribute where each value
999 // is a zero-based position of the element in the aggregate type. Return the
1000 // resulting type wrapped in MLIR, or nullptr on error.
getInsertExtractValueElementType(OpAsmParser & parser,Type containerType,ArrayAttr positionAttr,llvm::SMLoc attributeLoc,llvm::SMLoc typeLoc)1001 static Type getInsertExtractValueElementType(OpAsmParser &parser,
1002 Type containerType,
1003 ArrayAttr positionAttr,
1004 llvm::SMLoc attributeLoc,
1005 llvm::SMLoc typeLoc) {
1006 Type llvmType = containerType;
1007 if (!isCompatibleType(containerType))
1008 return parser.emitError(typeLoc, "expected LLVM IR Dialect type"), nullptr;
1009
1010 // Infer the element type from the structure type: iteratively step inside the
1011 // type by taking the element type, indexed by the position attribute for
1012 // structures. Check the position index before accessing, it is supposed to
1013 // be in bounds.
1014 for (Attribute subAttr : positionAttr) {
1015 auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>();
1016 if (!positionElementAttr)
1017 return parser.emitError(attributeLoc,
1018 "expected an array of integer literals"),
1019 nullptr;
1020 int position = positionElementAttr.getInt();
1021 if (auto arrayType = llvmType.dyn_cast<LLVMArrayType>()) {
1022 if (position < 0 ||
1023 static_cast<unsigned>(position) >= arrayType.getNumElements())
1024 return parser.emitError(attributeLoc, "position out of bounds"),
1025 nullptr;
1026 llvmType = arrayType.getElementType();
1027 } else if (auto structType = llvmType.dyn_cast<LLVMStructType>()) {
1028 if (position < 0 ||
1029 static_cast<unsigned>(position) >= structType.getBody().size())
1030 return parser.emitError(attributeLoc, "position out of bounds"),
1031 nullptr;
1032 llvmType = structType.getBody()[position];
1033 } else {
1034 return parser.emitError(typeLoc, "expected LLVM IR structure/array type"),
1035 nullptr;
1036 }
1037 }
1038 return llvmType;
1039 }
1040
1041 // Extract the type at `position` in the wrapped LLVM IR aggregate type
1042 // `containerType`. Returns null on failure.
getInsertExtractValueElementType(Type containerType,ArrayAttr positionAttr,Operation * op)1043 static Type getInsertExtractValueElementType(Type containerType,
1044 ArrayAttr positionAttr,
1045 Operation *op) {
1046 Type llvmType = containerType;
1047 if (!isCompatibleType(containerType)) {
1048 op->emitError("expected LLVM IR Dialect type, got ") << containerType;
1049 return {};
1050 }
1051
1052 // Infer the element type from the structure type: iteratively step inside the
1053 // type by taking the element type, indexed by the position attribute for
1054 // structures. Check the position index before accessing, it is supposed to
1055 // be in bounds.
1056 for (Attribute subAttr : positionAttr) {
1057 auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>();
1058 if (!positionElementAttr) {
1059 op->emitOpError("expected an array of integer literals, got: ")
1060 << subAttr;
1061 return {};
1062 }
1063 int position = positionElementAttr.getInt();
1064 if (auto arrayType = llvmType.dyn_cast<LLVMArrayType>()) {
1065 if (position < 0 ||
1066 static_cast<unsigned>(position) >= arrayType.getNumElements()) {
1067 op->emitOpError("position out of bounds: ") << position;
1068 return {};
1069 }
1070 llvmType = arrayType.getElementType();
1071 } else if (auto structType = llvmType.dyn_cast<LLVMStructType>()) {
1072 if (position < 0 ||
1073 static_cast<unsigned>(position) >= structType.getBody().size()) {
1074 op->emitOpError("position out of bounds") << position;
1075 return {};
1076 }
1077 llvmType = structType.getBody()[position];
1078 } else {
1079 op->emitOpError("expected LLVM IR structure/array type, got: ")
1080 << llvmType;
1081 return {};
1082 }
1083 }
1084 return llvmType;
1085 }
1086
1087 // <operation> ::= `llvm.extractvalue` ssa-use
1088 // `[` integer-literal (`,` integer-literal)* `]`
1089 // attribute-dict? `:` type
parseExtractValueOp(OpAsmParser & parser,OperationState & result)1090 static ParseResult parseExtractValueOp(OpAsmParser &parser,
1091 OperationState &result) {
1092 OpAsmParser::OperandType container;
1093 Type containerType;
1094 ArrayAttr positionAttr;
1095 llvm::SMLoc attributeLoc, trailingTypeLoc;
1096
1097 if (parser.parseOperand(container) ||
1098 parser.getCurrentLocation(&attributeLoc) ||
1099 parser.parseAttribute(positionAttr, "position", result.attributes) ||
1100 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1101 parser.getCurrentLocation(&trailingTypeLoc) ||
1102 parser.parseType(containerType) ||
1103 parser.resolveOperand(container, containerType, result.operands))
1104 return failure();
1105
1106 auto elementType = getInsertExtractValueElementType(
1107 parser, containerType, positionAttr, attributeLoc, trailingTypeLoc);
1108 if (!elementType)
1109 return failure();
1110
1111 result.addTypes(elementType);
1112 return success();
1113 }
1114
fold(ArrayRef<Attribute> operands)1115 OpFoldResult LLVM::ExtractValueOp::fold(ArrayRef<Attribute> operands) {
1116 auto insertValueOp = container().getDefiningOp<InsertValueOp>();
1117 while (insertValueOp) {
1118 if (position() == insertValueOp.position())
1119 return insertValueOp.value();
1120 insertValueOp = insertValueOp.container().getDefiningOp<InsertValueOp>();
1121 }
1122 return {};
1123 }
1124
verify(ExtractValueOp op)1125 static LogicalResult verify(ExtractValueOp op) {
1126 Type valueType = getInsertExtractValueElementType(op.container().getType(),
1127 op.positionAttr(), op);
1128 if (!valueType)
1129 return failure();
1130
1131 if (op.res().getType() != valueType)
1132 return op.emitOpError()
1133 << "Type mismatch: extracting from " << op.container().getType()
1134 << " should produce " << valueType << " but this op returns "
1135 << op.res().getType();
1136 return success();
1137 }
1138
1139 //===----------------------------------------------------------------------===//
1140 // Printing/parsing for LLVM::InsertElementOp.
1141 //===----------------------------------------------------------------------===//
1142
printInsertElementOp(OpAsmPrinter & p,InsertElementOp & op)1143 static void printInsertElementOp(OpAsmPrinter &p, InsertElementOp &op) {
1144 p << op.getOperationName() << ' ' << op.value() << ", " << op.vector() << "["
1145 << op.position() << " : " << op.position().getType() << "]";
1146 p.printOptionalAttrDict(op->getAttrs());
1147 p << " : " << op.vector().getType();
1148 }
1149
1150 // <operation> ::= `llvm.insertelement` ssa-use `,` ssa-use `,` ssa-use
1151 // attribute-dict? `:` type
parseInsertElementOp(OpAsmParser & parser,OperationState & result)1152 static ParseResult parseInsertElementOp(OpAsmParser &parser,
1153 OperationState &result) {
1154 llvm::SMLoc loc;
1155 OpAsmParser::OperandType vector, value, position;
1156 Type vectorType, positionType;
1157 if (parser.getCurrentLocation(&loc) || parser.parseOperand(value) ||
1158 parser.parseComma() || parser.parseOperand(vector) ||
1159 parser.parseLSquare() || parser.parseOperand(position) ||
1160 parser.parseColonType(positionType) || parser.parseRSquare() ||
1161 parser.parseOptionalAttrDict(result.attributes) ||
1162 parser.parseColonType(vectorType))
1163 return failure();
1164
1165 if (!LLVM::isCompatibleVectorType(vectorType))
1166 return parser.emitError(
1167 loc, "expected LLVM dialect-compatible vector type for operand #1");
1168 Type valueType = LLVM::getVectorElementType(vectorType);
1169 if (!valueType)
1170 return failure();
1171
1172 if (parser.resolveOperand(vector, vectorType, result.operands) ||
1173 parser.resolveOperand(value, valueType, result.operands) ||
1174 parser.resolveOperand(position, positionType, result.operands))
1175 return failure();
1176
1177 result.addTypes(vectorType);
1178 return success();
1179 }
1180
verify(InsertElementOp op)1181 static LogicalResult verify(InsertElementOp op) {
1182 Type valueType = LLVM::getVectorElementType(op.vector().getType());
1183 if (valueType != op.value().getType())
1184 return op.emitOpError()
1185 << "Type mismatch: cannot insert " << op.value().getType()
1186 << " into " << op.vector().getType();
1187 return success();
1188 }
1189 //===----------------------------------------------------------------------===//
1190 // Printing/parsing for LLVM::InsertValueOp.
1191 //===----------------------------------------------------------------------===//
1192
printInsertValueOp(OpAsmPrinter & p,InsertValueOp & op)1193 static void printInsertValueOp(OpAsmPrinter &p, InsertValueOp &op) {
1194 p << op.getOperationName() << ' ' << op.value() << ", " << op.container()
1195 << op.position();
1196 p.printOptionalAttrDict(op->getAttrs(), {"position"});
1197 p << " : " << op.container().getType();
1198 }
1199
1200 // <operation> ::= `llvm.insertvaluevalue` ssa-use `,` ssa-use
1201 // `[` integer-literal (`,` integer-literal)* `]`
1202 // attribute-dict? `:` type
parseInsertValueOp(OpAsmParser & parser,OperationState & result)1203 static ParseResult parseInsertValueOp(OpAsmParser &parser,
1204 OperationState &result) {
1205 OpAsmParser::OperandType container, value;
1206 Type containerType;
1207 ArrayAttr positionAttr;
1208 llvm::SMLoc attributeLoc, trailingTypeLoc;
1209
1210 if (parser.parseOperand(value) || parser.parseComma() ||
1211 parser.parseOperand(container) ||
1212 parser.getCurrentLocation(&attributeLoc) ||
1213 parser.parseAttribute(positionAttr, "position", result.attributes) ||
1214 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1215 parser.getCurrentLocation(&trailingTypeLoc) ||
1216 parser.parseType(containerType))
1217 return failure();
1218
1219 auto valueType = getInsertExtractValueElementType(
1220 parser, containerType, positionAttr, attributeLoc, trailingTypeLoc);
1221 if (!valueType)
1222 return failure();
1223
1224 if (parser.resolveOperand(container, containerType, result.operands) ||
1225 parser.resolveOperand(value, valueType, result.operands))
1226 return failure();
1227
1228 result.addTypes(containerType);
1229 return success();
1230 }
1231
verify(InsertValueOp op)1232 static LogicalResult verify(InsertValueOp op) {
1233 Type valueType = getInsertExtractValueElementType(op.container().getType(),
1234 op.positionAttr(), op);
1235 if (!valueType)
1236 return failure();
1237
1238 if (op.value().getType() != valueType)
1239 return op.emitOpError()
1240 << "Type mismatch: cannot insert " << op.value().getType()
1241 << " into " << op.container().getType();
1242
1243 return success();
1244 }
1245
1246 //===----------------------------------------------------------------------===//
1247 // Printing, parsing and verification for LLVM::ReturnOp.
1248 //===----------------------------------------------------------------------===//
1249
printReturnOp(OpAsmPrinter & p,ReturnOp op)1250 static void printReturnOp(OpAsmPrinter &p, ReturnOp op) {
1251 p << op.getOperationName();
1252 p.printOptionalAttrDict(op->getAttrs());
1253 assert(op.getNumOperands() <= 1);
1254
1255 if (op.getNumOperands() == 0)
1256 return;
1257
1258 p << ' ' << op.getOperand(0) << " : " << op.getOperand(0).getType();
1259 }
1260
1261 // <operation> ::= `llvm.return` ssa-use-list attribute-dict? `:`
1262 // type-list-no-parens
parseReturnOp(OpAsmParser & parser,OperationState & result)1263 static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &result) {
1264 SmallVector<OpAsmParser::OperandType, 1> operands;
1265 Type type;
1266
1267 if (parser.parseOperandList(operands) ||
1268 parser.parseOptionalAttrDict(result.attributes))
1269 return failure();
1270 if (operands.empty())
1271 return success();
1272
1273 if (parser.parseColonType(type) ||
1274 parser.resolveOperand(operands[0], type, result.operands))
1275 return failure();
1276 return success();
1277 }
1278
verify(ReturnOp op)1279 static LogicalResult verify(ReturnOp op) {
1280 if (op->getNumOperands() > 1)
1281 return op->emitOpError("expected at most 1 operand");
1282
1283 if (auto parent = op->getParentOfType<LLVMFuncOp>()) {
1284 Type expectedType = parent.getType().getReturnType();
1285 if (expectedType.isa<LLVMVoidType>()) {
1286 if (op->getNumOperands() == 0)
1287 return success();
1288 InFlightDiagnostic diag = op->emitOpError("expected no operands");
1289 diag.attachNote(parent->getLoc()) << "when returning from function";
1290 return diag;
1291 }
1292 if (op->getNumOperands() == 0) {
1293 if (expectedType.isa<LLVMVoidType>())
1294 return success();
1295 InFlightDiagnostic diag = op->emitOpError("expected 1 operand");
1296 diag.attachNote(parent->getLoc()) << "when returning from function";
1297 return diag;
1298 }
1299 if (expectedType != op->getOperand(0).getType()) {
1300 InFlightDiagnostic diag = op->emitOpError("mismatching result types");
1301 diag.attachNote(parent->getLoc()) << "when returning from function";
1302 return diag;
1303 }
1304 }
1305 return success();
1306 }
1307
1308 //===----------------------------------------------------------------------===//
1309 // Verifier for LLVM::AddressOfOp.
1310 //===----------------------------------------------------------------------===//
1311
1312 template <typename OpTy>
lookupSymbolInModule(Operation * parent,StringRef name)1313 static OpTy lookupSymbolInModule(Operation *parent, StringRef name) {
1314 Operation *module = parent;
1315 while (module && !satisfiesLLVMModule(module))
1316 module = module->getParentOp();
1317 assert(module && "unexpected operation outside of a module");
1318 return dyn_cast_or_null<OpTy>(
1319 mlir::SymbolTable::lookupSymbolIn(module, name));
1320 }
1321
getGlobal()1322 GlobalOp AddressOfOp::getGlobal() {
1323 return lookupSymbolInModule<LLVM::GlobalOp>((*this)->getParentOp(),
1324 global_name());
1325 }
1326
getFunction()1327 LLVMFuncOp AddressOfOp::getFunction() {
1328 return lookupSymbolInModule<LLVM::LLVMFuncOp>((*this)->getParentOp(),
1329 global_name());
1330 }
1331
verify(AddressOfOp op)1332 static LogicalResult verify(AddressOfOp op) {
1333 auto global = op.getGlobal();
1334 auto function = op.getFunction();
1335 if (!global && !function)
1336 return op.emitOpError(
1337 "must reference a global defined by 'llvm.mlir.global' or 'llvm.func'");
1338
1339 if (global &&
1340 LLVM::LLVMPointerType::get(global.getType(), global.addr_space()) !=
1341 op.getResult().getType())
1342 return op.emitOpError(
1343 "the type must be a pointer to the type of the referenced global");
1344
1345 if (function && LLVM::LLVMPointerType::get(function.getType()) !=
1346 op.getResult().getType())
1347 return op.emitOpError(
1348 "the type must be a pointer to the type of the referenced function");
1349
1350 return success();
1351 }
1352
1353 //===----------------------------------------------------------------------===//
1354 // Builder, printer and verifier for LLVM::GlobalOp.
1355 //===----------------------------------------------------------------------===//
1356
1357 /// Returns the name used for the linkage attribute. This *must* correspond to
1358 /// the name of the attribute in ODS.
getLinkageAttrName()1359 static StringRef getLinkageAttrName() { return "linkage"; }
1360
1361 /// Returns the name used for the unnamed_addr attribute. This *must* correspond
1362 /// to the name of the attribute in ODS.
getUnnamedAddrAttrName()1363 static StringRef getUnnamedAddrAttrName() { return "unnamed_addr"; }
1364
build(OpBuilder & builder,OperationState & result,Type type,bool isConstant,Linkage linkage,StringRef name,Attribute value,uint64_t alignment,unsigned addrSpace,bool dsoLocal,ArrayRef<NamedAttribute> attrs)1365 void GlobalOp::build(OpBuilder &builder, OperationState &result, Type type,
1366 bool isConstant, Linkage linkage, StringRef name,
1367 Attribute value, uint64_t alignment, unsigned addrSpace,
1368 bool dsoLocal, ArrayRef<NamedAttribute> attrs) {
1369 result.addAttribute(SymbolTable::getSymbolAttrName(),
1370 builder.getStringAttr(name));
1371 result.addAttribute("type", TypeAttr::get(type));
1372 if (isConstant)
1373 result.addAttribute("constant", builder.getUnitAttr());
1374 if (value)
1375 result.addAttribute("value", value);
1376 if (dsoLocal)
1377 result.addAttribute("dso_local", builder.getUnitAttr());
1378
1379 // Only add an alignment attribute if the "alignment" input
1380 // is different from 0. The value must also be a power of two, but
1381 // this is tested in GlobalOp::verify, not here.
1382 if (alignment != 0)
1383 result.addAttribute("alignment", builder.getI64IntegerAttr(alignment));
1384
1385 result.addAttribute(getLinkageAttrName(),
1386 builder.getI64IntegerAttr(static_cast<int64_t>(linkage)));
1387 if (addrSpace != 0)
1388 result.addAttribute("addr_space", builder.getI32IntegerAttr(addrSpace));
1389 result.attributes.append(attrs.begin(), attrs.end());
1390 result.addRegion();
1391 }
1392
printGlobalOp(OpAsmPrinter & p,GlobalOp op)1393 static void printGlobalOp(OpAsmPrinter &p, GlobalOp op) {
1394 p << op.getOperationName() << ' ' << stringifyLinkage(op.linkage()) << ' ';
1395 if (op.unnamed_addr())
1396 p << stringifyUnnamedAddr(*op.unnamed_addr()) << ' ';
1397 if (op.constant())
1398 p << "constant ";
1399 p.printSymbolName(op.sym_name());
1400 p << '(';
1401 if (auto value = op.getValueOrNull())
1402 p.printAttribute(value);
1403 p << ')';
1404 // Note that the alignment attribute is printed using the
1405 // default syntax here, even though it is an inherent attribute
1406 // (as defined in https://mlir.llvm.org/docs/LangRef/#attributes)
1407 p.printOptionalAttrDict(op->getAttrs(),
1408 {SymbolTable::getSymbolAttrName(), "type", "constant",
1409 "value", getLinkageAttrName(),
1410 getUnnamedAddrAttrName()});
1411
1412 // Print the trailing type unless it's a string global.
1413 if (op.getValueOrNull().dyn_cast_or_null<StringAttr>())
1414 return;
1415 p << " : " << op.type();
1416
1417 Region &initializer = op.getInitializerRegion();
1418 if (!initializer.empty())
1419 p.printRegion(initializer, /*printEntryBlockArgs=*/false);
1420 }
1421
1422 // Parses one of the keywords provided in the list `keywords` and returns the
1423 // position of the parsed keyword in the list. If none of the keywords from the
1424 // list is parsed, returns -1.
parseOptionalKeywordAlternative(OpAsmParser & parser,ArrayRef<StringRef> keywords)1425 static int parseOptionalKeywordAlternative(OpAsmParser &parser,
1426 ArrayRef<StringRef> keywords) {
1427 for (auto en : llvm::enumerate(keywords)) {
1428 if (succeeded(parser.parseOptionalKeyword(en.value())))
1429 return en.index();
1430 }
1431 return -1;
1432 }
1433
namespace {
// Primary template, intentionally empty: only the specializations created by
// REGISTER_ENUM_TYPE below provide members.
template <typename Ty>
struct EnumTraits {};

// Specializes EnumTraits for the enum `Ty`. The specialization forwards
// `stringify` to `stringify<Ty>` and `getMaxEnumVal` to
// `getMaxEnumValFor<Ty>`, so functions with those names must be visible for
// every registered enum.
#define REGISTER_ENUM_TYPE(Ty) \
  template <> \
  struct EnumTraits<Ty> { \
    static StringRef stringify(Ty value) { return stringify##Ty(value); } \
    static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); } \
  }

// Enums that parseOptionalLLVMKeyword is instantiated with in this file.
REGISTER_ENUM_TYPE(Linkage);
REGISTER_ENUM_TYPE(UnnamedAddr);
} // end namespace
1448
1449 template <typename EnumTy>
parseOptionalLLVMKeyword(OpAsmParser & parser,OperationState & result,StringRef name)1450 static ParseResult parseOptionalLLVMKeyword(OpAsmParser &parser,
1451 OperationState &result,
1452 StringRef name) {
1453 SmallVector<StringRef, 10> names;
1454 for (unsigned i = 0, e = getMaxEnumValForLinkage(); i <= e; ++i)
1455 names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
1456
1457 int index = parseOptionalKeywordAlternative(parser, names);
1458 if (index == -1)
1459 return failure();
1460 result.addAttribute(name, parser.getBuilder().getI64IntegerAttr(index));
1461 return success();
1462 }
1463
1464 // operation ::= `llvm.mlir.global` linkage? `constant`? `@` identifier
1465 // `(` attribute? `)` align? attribute-list? (`:` type)? region?
1466 // align ::= `align` `=` UINT64
1467 //
1468 // The type can be omitted for string attributes, in which case it will be
1469 // inferred from the value of the string as [strlen(value) x i8].
parseGlobalOp(OpAsmParser & parser,OperationState & result)1470 static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) {
1471 if (failed(parseOptionalLLVMKeyword<Linkage>(parser, result,
1472 getLinkageAttrName())))
1473 result.addAttribute(getLinkageAttrName(),
1474 parser.getBuilder().getI64IntegerAttr(
1475 static_cast<int64_t>(LLVM::Linkage::External)));
1476
1477 if (failed(parseOptionalLLVMKeyword<UnnamedAddr>(parser, result,
1478 getUnnamedAddrAttrName())))
1479 result.addAttribute(getUnnamedAddrAttrName(),
1480 parser.getBuilder().getI64IntegerAttr(
1481 static_cast<int64_t>(LLVM::UnnamedAddr::None)));
1482
1483 if (succeeded(parser.parseOptionalKeyword("constant")))
1484 result.addAttribute("constant", parser.getBuilder().getUnitAttr());
1485
1486 StringAttr name;
1487 if (parser.parseSymbolName(name, SymbolTable::getSymbolAttrName(),
1488 result.attributes) ||
1489 parser.parseLParen())
1490 return failure();
1491
1492 Attribute value;
1493 if (parser.parseOptionalRParen()) {
1494 if (parser.parseAttribute(value, "value", result.attributes) ||
1495 parser.parseRParen())
1496 return failure();
1497 }
1498
1499 SmallVector<Type, 1> types;
1500 if (parser.parseOptionalAttrDict(result.attributes) ||
1501 parser.parseOptionalColonTypeList(types))
1502 return failure();
1503
1504 if (types.size() > 1)
1505 return parser.emitError(parser.getNameLoc(), "expected zero or one type");
1506
1507 Region &initRegion = *result.addRegion();
1508 if (types.empty()) {
1509 if (auto strAttr = value.dyn_cast_or_null<StringAttr>()) {
1510 MLIRContext *context = parser.getBuilder().getContext();
1511 auto arrayType = LLVM::LLVMArrayType::get(IntegerType::get(context, 8),
1512 strAttr.getValue().size());
1513 types.push_back(arrayType);
1514 } else {
1515 return parser.emitError(parser.getNameLoc(),
1516 "type can only be omitted for string globals");
1517 }
1518 } else {
1519 OptionalParseResult parseResult =
1520 parser.parseOptionalRegion(initRegion, /*arguments=*/{},
1521 /*argTypes=*/{});
1522 if (parseResult.hasValue() && failed(*parseResult))
1523 return failure();
1524 }
1525
1526 result.addAttribute("type", TypeAttr::get(types[0]));
1527 return success();
1528 }
1529
isZeroAttribute(Attribute value)1530 static bool isZeroAttribute(Attribute value) {
1531 if (auto intValue = value.dyn_cast<IntegerAttr>())
1532 return intValue.getValue().isNullValue();
1533 if (auto fpValue = value.dyn_cast<FloatAttr>())
1534 return fpValue.getValue().isZero();
1535 if (auto splatValue = value.dyn_cast<SplatElementsAttr>())
1536 return isZeroAttribute(splatValue.getSplatValue());
1537 if (auto elementsValue = value.dyn_cast<ElementsAttr>())
1538 return llvm::all_of(elementsValue.getValues<Attribute>(), isZeroAttribute);
1539 if (auto arrayValue = value.dyn_cast<ArrayAttr>())
1540 return llvm::all_of(arrayValue.getValue(), isZeroAttribute);
1541 return false;
1542 }
1543
verify(GlobalOp op)1544 static LogicalResult verify(GlobalOp op) {
1545 if (!LLVMPointerType::isValidElementType(op.getType()))
1546 return op.emitOpError(
1547 "expects type to be a valid element type for an LLVM pointer");
1548 if (op->getParentOp() && !satisfiesLLVMModule(op->getParentOp()))
1549 return op.emitOpError("must appear at the module level");
1550
1551 if (auto strAttr = op.getValueOrNull().dyn_cast_or_null<StringAttr>()) {
1552 auto type = op.getType().dyn_cast<LLVMArrayType>();
1553 IntegerType elementType =
1554 type ? type.getElementType().dyn_cast<IntegerType>() : nullptr;
1555 if (!elementType || elementType.getWidth() != 8 ||
1556 type.getNumElements() != strAttr.getValue().size())
1557 return op.emitOpError(
1558 "requires an i8 array type of the length equal to that of the string "
1559 "attribute");
1560 }
1561
1562 if (Block *b = op.getInitializerBlock()) {
1563 ReturnOp ret = cast<ReturnOp>(b->getTerminator());
1564 if (ret.operand_type_begin() == ret.operand_type_end())
1565 return op.emitOpError("initializer region cannot return void");
1566 if (*ret.operand_type_begin() != op.getType())
1567 return op.emitOpError("initializer region type ")
1568 << *ret.operand_type_begin() << " does not match global type "
1569 << op.getType();
1570
1571 if (op.getValueOrNull())
1572 return op.emitOpError("cannot have both initializer value and region");
1573 }
1574
1575 if (op.linkage() == Linkage::Common) {
1576 if (Attribute value = op.getValueOrNull()) {
1577 if (!isZeroAttribute(value)) {
1578 return op.emitOpError()
1579 << "expected zero value for '"
1580 << stringifyLinkage(Linkage::Common) << "' linkage";
1581 }
1582 }
1583 }
1584
1585 if (op.linkage() == Linkage::Appending) {
1586 if (!op.getType().isa<LLVMArrayType>()) {
1587 return op.emitOpError()
1588 << "expected array type for '"
1589 << stringifyLinkage(Linkage::Appending) << "' linkage";
1590 }
1591 }
1592
1593 Optional<uint64_t> alignAttr = op.alignment();
1594 if (alignAttr.hasValue()) {
1595 uint64_t value = alignAttr.getValue();
1596 if (!llvm::isPowerOf2_64(value))
1597 return op->emitError() << "alignment attribute is not a power of 2";
1598 }
1599
1600 return success();
1601 }
1602
1603 //===----------------------------------------------------------------------===//
1604 // Printing/parsing for LLVM::ShuffleVectorOp.
1605 //===----------------------------------------------------------------------===//
1606 // Expects vector to be of wrapped LLVM vector type and position to be of
1607 // wrapped LLVM i32 type.
build(OpBuilder & b,OperationState & result,Value v1,Value v2,ArrayAttr mask,ArrayRef<NamedAttribute> attrs)1608 void LLVM::ShuffleVectorOp::build(OpBuilder &b, OperationState &result,
1609 Value v1, Value v2, ArrayAttr mask,
1610 ArrayRef<NamedAttribute> attrs) {
1611 auto containerType = v1.getType();
1612 auto vType = LLVM::getFixedVectorType(
1613 LLVM::getVectorElementType(containerType), mask.size());
1614 build(b, result, vType, v1, v2, mask);
1615 result.addAttributes(attrs);
1616 }
1617
printShuffleVectorOp(OpAsmPrinter & p,ShuffleVectorOp & op)1618 static void printShuffleVectorOp(OpAsmPrinter &p, ShuffleVectorOp &op) {
1619 p << op.getOperationName() << ' ' << op.v1() << ", " << op.v2() << " "
1620 << op.mask();
1621 p.printOptionalAttrDict(op->getAttrs(), {"mask"});
1622 p << " : " << op.v1().getType() << ", " << op.v2().getType();
1623 }
1624
1625 // <operation> ::= `llvm.shufflevector` ssa-use `, ` ssa-use
1626 // `[` integer-literal (`,` integer-literal)* `]`
1627 // attribute-dict? `:` type
parseShuffleVectorOp(OpAsmParser & parser,OperationState & result)1628 static ParseResult parseShuffleVectorOp(OpAsmParser &parser,
1629 OperationState &result) {
1630 llvm::SMLoc loc;
1631 OpAsmParser::OperandType v1, v2;
1632 ArrayAttr maskAttr;
1633 Type typeV1, typeV2;
1634 if (parser.getCurrentLocation(&loc) || parser.parseOperand(v1) ||
1635 parser.parseComma() || parser.parseOperand(v2) ||
1636 parser.parseAttribute(maskAttr, "mask", result.attributes) ||
1637 parser.parseOptionalAttrDict(result.attributes) ||
1638 parser.parseColonType(typeV1) || parser.parseComma() ||
1639 parser.parseType(typeV2) ||
1640 parser.resolveOperand(v1, typeV1, result.operands) ||
1641 parser.resolveOperand(v2, typeV2, result.operands))
1642 return failure();
1643 if (!LLVM::isCompatibleVectorType(typeV1))
1644 return parser.emitError(
1645 loc, "expected LLVM IR dialect vector type for operand #1");
1646 auto vType = LLVM::getFixedVectorType(LLVM::getVectorElementType(typeV1),
1647 maskAttr.size());
1648 result.addTypes(vType);
1649 return success();
1650 }
1651
1652 //===----------------------------------------------------------------------===//
1653 // Implementations for LLVM::LLVMFuncOp.
1654 //===----------------------------------------------------------------------===//
1655
1656 // Add the entry block to the function.
addEntryBlock()1657 Block *LLVMFuncOp::addEntryBlock() {
1658 assert(empty() && "function already has an entry block");
1659 assert(!isVarArg() && "unimplemented: non-external variadic functions");
1660
1661 auto *entry = new Block;
1662 push_back(entry);
1663
1664 LLVMFunctionType type = getType();
1665 for (unsigned i = 0, e = type.getNumParams(); i < e; ++i)
1666 entry->addArgument(type.getParamType(i));
1667 return entry;
1668 }
1669
build(OpBuilder & builder,OperationState & result,StringRef name,Type type,LLVM::Linkage linkage,bool dsoLocal,ArrayRef<NamedAttribute> attrs,ArrayRef<DictionaryAttr> argAttrs)1670 void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
1671 StringRef name, Type type, LLVM::Linkage linkage,
1672 bool dsoLocal, ArrayRef<NamedAttribute> attrs,
1673 ArrayRef<DictionaryAttr> argAttrs) {
1674 result.addRegion();
1675 result.addAttribute(SymbolTable::getSymbolAttrName(),
1676 builder.getStringAttr(name));
1677 result.addAttribute("type", TypeAttr::get(type));
1678 result.addAttribute(getLinkageAttrName(),
1679 builder.getI64IntegerAttr(static_cast<int64_t>(linkage)));
1680 result.attributes.append(attrs.begin(), attrs.end());
1681 if (dsoLocal)
1682 result.addAttribute("dso_local", builder.getUnitAttr());
1683 if (argAttrs.empty())
1684 return;
1685
1686 assert(type.cast<LLVMFunctionType>().getNumParams() == argAttrs.size() &&
1687 "expected as many argument attribute lists as arguments");
1688 function_like_impl::addArgAndResultAttrs(builder, result, argAttrs,
1689 /*resultAttrs=*/llvm::None);
1690 }
1691
1692 // Builds an LLVM function type from the given lists of input and output types.
1693 // Returns a null type if any of the types provided are non-LLVM types, or if
1694 // there is more than one output type.
1695 static Type
buildLLVMFunctionType(OpAsmParser & parser,llvm::SMLoc loc,ArrayRef<Type> inputs,ArrayRef<Type> outputs,function_like_impl::VariadicFlag variadicFlag)1696 buildLLVMFunctionType(OpAsmParser &parser, llvm::SMLoc loc,
1697 ArrayRef<Type> inputs, ArrayRef<Type> outputs,
1698 function_like_impl::VariadicFlag variadicFlag) {
1699 Builder &b = parser.getBuilder();
1700 if (outputs.size() > 1) {
1701 parser.emitError(loc, "failed to construct function type: expected zero or "
1702 "one function result");
1703 return {};
1704 }
1705
1706 // Convert inputs to LLVM types, exit early on error.
1707 SmallVector<Type, 4> llvmInputs;
1708 for (auto t : inputs) {
1709 if (!isCompatibleType(t)) {
1710 parser.emitError(loc, "failed to construct function type: expected LLVM "
1711 "type for function arguments");
1712 return {};
1713 }
1714 llvmInputs.push_back(t);
1715 }
1716
1717 // No output is denoted as "void" in LLVM type system.
1718 Type llvmOutput =
1719 outputs.empty() ? LLVMVoidType::get(b.getContext()) : outputs.front();
1720 if (!isCompatibleType(llvmOutput)) {
1721 parser.emitError(loc, "failed to construct function type: expected LLVM "
1722 "type for function results")
1723 << llvmOutput;
1724 return {};
1725 }
1726 return LLVMFunctionType::get(llvmOutput, llvmInputs,
1727 variadicFlag.isVariadic());
1728 }
1729
1730 // Parses an LLVM function.
1731 //
1732 // operation ::= `llvm.func` linkage? function-signature function-attributes?
1733 // function-body
1734 //
parseLLVMFuncOp(OpAsmParser & parser,OperationState & result)1735 static ParseResult parseLLVMFuncOp(OpAsmParser &parser,
1736 OperationState &result) {
1737 // Default to external linkage if no keyword is provided.
1738 if (failed(parseOptionalLLVMKeyword<Linkage>(parser, result,
1739 getLinkageAttrName())))
1740 result.addAttribute(getLinkageAttrName(),
1741 parser.getBuilder().getI64IntegerAttr(
1742 static_cast<int64_t>(LLVM::Linkage::External)));
1743
1744 StringAttr nameAttr;
1745 SmallVector<OpAsmParser::OperandType, 8> entryArgs;
1746 SmallVector<NamedAttrList, 1> argAttrs;
1747 SmallVector<NamedAttrList, 1> resultAttrs;
1748 SmallVector<Type, 8> argTypes;
1749 SmallVector<Type, 4> resultTypes;
1750 bool isVariadic;
1751
1752 auto signatureLocation = parser.getCurrentLocation();
1753 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1754 result.attributes) ||
1755 function_like_impl::parseFunctionSignature(
1756 parser, /*allowVariadic=*/true, entryArgs, argTypes, argAttrs,
1757 isVariadic, resultTypes, resultAttrs))
1758 return failure();
1759
1760 auto type =
1761 buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes,
1762 function_like_impl::VariadicFlag(isVariadic));
1763 if (!type)
1764 return failure();
1765 result.addAttribute(function_like_impl::getTypeAttrName(),
1766 TypeAttr::get(type));
1767
1768 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
1769 return failure();
1770 function_like_impl::addArgAndResultAttrs(parser.getBuilder(), result,
1771 argAttrs, resultAttrs);
1772
1773 auto *body = result.addRegion();
1774 OptionalParseResult parseResult = parser.parseOptionalRegion(
1775 *body, entryArgs, entryArgs.empty() ? ArrayRef<Type>() : argTypes);
1776 return failure(parseResult.hasValue() && failed(*parseResult));
1777 }
1778
1779 // Print the LLVMFuncOp. Collects argument and result types and passes them to
1780 // helper functions. Drops "void" result since it cannot be parsed back. Skips
1781 // the external linkage since it is the default value.
printLLVMFuncOp(OpAsmPrinter & p,LLVMFuncOp op)1782 static void printLLVMFuncOp(OpAsmPrinter &p, LLVMFuncOp op) {
1783 p << op.getOperationName() << ' ';
1784 if (op.linkage() != LLVM::Linkage::External)
1785 p << stringifyLinkage(op.linkage()) << ' ';
1786 p.printSymbolName(op.getName());
1787
1788 LLVMFunctionType fnType = op.getType();
1789 SmallVector<Type, 8> argTypes;
1790 SmallVector<Type, 1> resTypes;
1791 argTypes.reserve(fnType.getNumParams());
1792 for (unsigned i = 0, e = fnType.getNumParams(); i < e; ++i)
1793 argTypes.push_back(fnType.getParamType(i));
1794
1795 Type returnType = fnType.getReturnType();
1796 if (!returnType.isa<LLVMVoidType>())
1797 resTypes.push_back(returnType);
1798
1799 function_like_impl::printFunctionSignature(p, op, argTypes, op.isVarArg(),
1800 resTypes);
1801 function_like_impl::printFunctionAttributes(
1802 p, op, argTypes.size(), resTypes.size(), {getLinkageAttrName()});
1803
1804 // Print the body if this is not an external function.
1805 Region &body = op.body();
1806 if (!body.empty())
1807 p.printRegion(body, /*printEntryBlockArgs=*/false,
1808 /*printBlockTerminators=*/true);
1809 }
1810
1811 // Hook for OpTrait::FunctionLike, called after verifying that the 'type'
1812 // attribute is present. This can check for preconditions of the
1813 // getNumArguments hook not failing.
verifyType()1814 LogicalResult LLVMFuncOp::verifyType() {
1815 auto llvmType = getTypeAttr().getValue().dyn_cast_or_null<LLVMFunctionType>();
1816 if (!llvmType)
1817 return emitOpError("requires '" + getTypeAttrName() +
1818 "' attribute of wrapped LLVM function type");
1819
1820 return success();
1821 }
1822
1823 // Hook for OpTrait::FunctionLike, returns the number of function arguments.
1824 // Depends on the type attribute being correct as checked by verifyType
getNumFuncArguments()1825 unsigned LLVMFuncOp::getNumFuncArguments() { return getType().getNumParams(); }
1826
1827 // Hook for OpTrait::FunctionLike, returns the number of function results.
1828 // Depends on the type attribute being correct as checked by verifyType
getNumFuncResults()1829 unsigned LLVMFuncOp::getNumFuncResults() {
1830 // We model LLVM functions that return void as having zero results,
1831 // and all others as having one result.
1832 // If we modeled a void return as one result, then it would be possible to
1833 // attach an MLIR result attribute to it, and it isn't clear what semantics we
1834 // would assign to that.
1835 if (getType().getReturnType().isa<LLVMVoidType>())
1836 return 0;
1837 return 1;
1838 }
1839
1840 // Verifies LLVM- and implementation-specific properties of the LLVM func Op:
1841 // - functions don't have 'common' linkage
1842 // - external functions have 'external' or 'extern_weak' linkage;
1843 // - vararg is (currently) only supported for external functions;
1844 // - entry block arguments are of LLVM types and match the function signature.
verify(LLVMFuncOp op)1845 static LogicalResult verify(LLVMFuncOp op) {
1846 if (op.linkage() == LLVM::Linkage::Common)
1847 return op.emitOpError()
1848 << "functions cannot have '"
1849 << stringifyLinkage(LLVM::Linkage::Common) << "' linkage";
1850
1851 if (op.isExternal()) {
1852 if (op.linkage() != LLVM::Linkage::External &&
1853 op.linkage() != LLVM::Linkage::ExternWeak)
1854 return op.emitOpError()
1855 << "external functions must have '"
1856 << stringifyLinkage(LLVM::Linkage::External) << "' or '"
1857 << stringifyLinkage(LLVM::Linkage::ExternWeak) << "' linkage";
1858 return success();
1859 }
1860
1861 if (op.isVarArg())
1862 return op.emitOpError("only external functions can be variadic");
1863
1864 unsigned numArguments = op.getType().getNumParams();
1865 Block &entryBlock = op.front();
1866 for (unsigned i = 0; i < numArguments; ++i) {
1867 Type argType = entryBlock.getArgument(i).getType();
1868 if (!isCompatibleType(argType))
1869 return op.emitOpError("entry block argument #")
1870 << i << " is not of LLVM type";
1871 if (op.getType().getParamType(i) != argType)
1872 return op.emitOpError("the type of entry block argument #")
1873 << i << " does not match the function signature";
1874 }
1875
1876 return success();
1877 }
1878
1879 //===----------------------------------------------------------------------===//
1880 // Verification for LLVM::ConstantOp.
1881 //===----------------------------------------------------------------------===//
1882
verify(LLVM::ConstantOp op)1883 static LogicalResult verify(LLVM::ConstantOp op) {
1884 if (StringAttr sAttr = op.value().dyn_cast<StringAttr>()) {
1885 auto arrayType = op.getType().dyn_cast<LLVMArrayType>();
1886 if (!arrayType || arrayType.getNumElements() != sAttr.getValue().size() ||
1887 !arrayType.getElementType().isInteger(8)) {
1888 return op->emitOpError()
1889 << "expected array type of " << sAttr.getValue().size()
1890 << " i8 elements for the string constant";
1891 }
1892 return success();
1893 }
1894 if (auto structType = op.getType().dyn_cast<LLVMStructType>()) {
1895 if (structType.getBody().size() != 2 ||
1896 structType.getBody()[0] != structType.getBody()[1]) {
1897 return op.emitError() << "expected struct type with two elements of the "
1898 "same type, the type of a complex constant";
1899 }
1900
1901 auto arrayAttr = op.value().dyn_cast<ArrayAttr>();
1902 if (!arrayAttr || arrayAttr.size() != 2 ||
1903 arrayAttr[0].getType() != arrayAttr[1].getType()) {
1904 return op.emitOpError() << "expected array attribute with two elements, "
1905 "representing a complex constant";
1906 }
1907
1908 Type elementType = structType.getBody()[0];
1909 if (!elementType
1910 .isa<IntegerType, Float16Type, Float32Type, Float64Type>()) {
1911 return op.emitError()
1912 << "expected struct element types to be floating point type or "
1913 "integer type";
1914 }
1915 return success();
1916 }
1917 if (!op.value().isa<IntegerAttr, ArrayAttr, FloatAttr, ElementsAttr>())
1918 return op.emitOpError()
1919 << "only supports integer, float, string or elements attributes";
1920 return success();
1921 }
1922
1923 //===----------------------------------------------------------------------===//
1924 // Utility functions for parsing atomic ops
1925 //===----------------------------------------------------------------------===//
1926
1927 // Helper function to parse a keyword into the specified attribute named by
1928 // `attrName`. The keyword must match one of the string values defined by the
1929 // AtomicBinOp enum. The resulting I64 attribute is added to the `result`
1930 // state.
parseAtomicBinOp(OpAsmParser & parser,OperationState & result,StringRef attrName)1931 static ParseResult parseAtomicBinOp(OpAsmParser &parser, OperationState &result,
1932 StringRef attrName) {
1933 llvm::SMLoc loc;
1934 StringRef keyword;
1935 if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&keyword))
1936 return failure();
1937
1938 // Replace the keyword `keyword` with an integer attribute.
1939 auto kind = symbolizeAtomicBinOp(keyword);
1940 if (!kind) {
1941 return parser.emitError(loc)
1942 << "'" << keyword << "' is an incorrect value of the '" << attrName
1943 << "' attribute";
1944 }
1945
1946 auto value = static_cast<int64_t>(kind.getValue());
1947 auto attr = parser.getBuilder().getI64IntegerAttr(value);
1948 result.addAttribute(attrName, attr);
1949
1950 return success();
1951 }
1952
1953 // Helper function to parse a keyword into the specified attribute named by
1954 // `attrName`. The keyword must match one of the string values defined by the
1955 // AtomicOrdering enum. The resulting I64 attribute is added to the `result`
1956 // state.
parseAtomicOrdering(OpAsmParser & parser,OperationState & result,StringRef attrName)1957 static ParseResult parseAtomicOrdering(OpAsmParser &parser,
1958 OperationState &result,
1959 StringRef attrName) {
1960 llvm::SMLoc loc;
1961 StringRef ordering;
1962 if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&ordering))
1963 return failure();
1964
1965 // Replace the keyword `ordering` with an integer attribute.
1966 auto kind = symbolizeAtomicOrdering(ordering);
1967 if (!kind) {
1968 return parser.emitError(loc)
1969 << "'" << ordering << "' is an incorrect value of the '" << attrName
1970 << "' attribute";
1971 }
1972
1973 auto value = static_cast<int64_t>(kind.getValue());
1974 auto attr = parser.getBuilder().getI64IntegerAttr(value);
1975 result.addAttribute(attrName, attr);
1976
1977 return success();
1978 }
1979
1980 //===----------------------------------------------------------------------===//
1981 // Printer, parser and verifier for LLVM::AtomicRMWOp.
1982 //===----------------------------------------------------------------------===//
1983
printAtomicRMWOp(OpAsmPrinter & p,AtomicRMWOp & op)1984 static void printAtomicRMWOp(OpAsmPrinter &p, AtomicRMWOp &op) {
1985 p << op.getOperationName() << ' ' << stringifyAtomicBinOp(op.bin_op()) << ' '
1986 << op.ptr() << ", " << op.val() << ' '
1987 << stringifyAtomicOrdering(op.ordering()) << ' ';
1988 p.printOptionalAttrDict(op->getAttrs(), {"bin_op", "ordering"});
1989 p << " : " << op.res().getType();
1990 }
1991
1992 // <operation> ::= `llvm.atomicrmw` keyword ssa-use `,` ssa-use keyword
1993 // attribute-dict? `:` type
parseAtomicRMWOp(OpAsmParser & parser,OperationState & result)1994 static ParseResult parseAtomicRMWOp(OpAsmParser &parser,
1995 OperationState &result) {
1996 Type type;
1997 OpAsmParser::OperandType ptr, val;
1998 if (parseAtomicBinOp(parser, result, "bin_op") || parser.parseOperand(ptr) ||
1999 parser.parseComma() || parser.parseOperand(val) ||
2000 parseAtomicOrdering(parser, result, "ordering") ||
2001 parser.parseOptionalAttrDict(result.attributes) ||
2002 parser.parseColonType(type) ||
2003 parser.resolveOperand(ptr, LLVM::LLVMPointerType::get(type),
2004 result.operands) ||
2005 parser.resolveOperand(val, type, result.operands))
2006 return failure();
2007
2008 result.addTypes(type);
2009 return success();
2010 }
2011
verify(AtomicRMWOp op)2012 static LogicalResult verify(AtomicRMWOp op) {
2013 auto ptrType = op.ptr().getType().cast<LLVM::LLVMPointerType>();
2014 auto valType = op.val().getType();
2015 if (valType != ptrType.getElementType())
2016 return op.emitOpError("expected LLVM IR element type for operand #0 to "
2017 "match type for operand #1");
2018 auto resType = op.res().getType();
2019 if (resType != valType)
2020 return op.emitOpError(
2021 "expected LLVM IR result type to match type for operand #1");
2022 if (op.bin_op() == AtomicBinOp::fadd || op.bin_op() == AtomicBinOp::fsub) {
2023 if (!mlir::LLVM::isCompatibleFloatingPointType(valType))
2024 return op.emitOpError("expected LLVM IR floating point type");
2025 } else if (op.bin_op() == AtomicBinOp::xchg) {
2026 auto intType = valType.dyn_cast<IntegerType>();
2027 unsigned intBitWidth = intType ? intType.getWidth() : 0;
2028 if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 &&
2029 intBitWidth != 64 && !valType.isa<BFloat16Type>() &&
2030 !valType.isa<Float16Type>() && !valType.isa<Float32Type>() &&
2031 !valType.isa<Float64Type>())
2032 return op.emitOpError("unexpected LLVM IR type for 'xchg' bin_op");
2033 } else {
2034 auto intType = valType.dyn_cast<IntegerType>();
2035 unsigned intBitWidth = intType ? intType.getWidth() : 0;
2036 if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 &&
2037 intBitWidth != 64)
2038 return op.emitOpError("expected LLVM IR integer type");
2039 }
2040
2041 if (static_cast<unsigned>(op.ordering()) <
2042 static_cast<unsigned>(AtomicOrdering::monotonic))
2043 return op.emitOpError()
2044 << "expected at least '"
2045 << stringifyAtomicOrdering(AtomicOrdering::monotonic)
2046 << "' ordering";
2047
2048 return success();
2049 }
2050
2051 //===----------------------------------------------------------------------===//
2052 // Printer, parser and verifier for LLVM::AtomicCmpXchgOp.
2053 //===----------------------------------------------------------------------===//
2054
printAtomicCmpXchgOp(OpAsmPrinter & p,AtomicCmpXchgOp & op)2055 static void printAtomicCmpXchgOp(OpAsmPrinter &p, AtomicCmpXchgOp &op) {
2056 p << op.getOperationName() << ' ' << op.ptr() << ", " << op.cmp() << ", "
2057 << op.val() << ' ' << stringifyAtomicOrdering(op.success_ordering()) << ' '
2058 << stringifyAtomicOrdering(op.failure_ordering());
2059 p.printOptionalAttrDict(op->getAttrs(),
2060 {"success_ordering", "failure_ordering"});
2061 p << " : " << op.val().getType();
2062 }
2063
2064 // <operation> ::= `llvm.cmpxchg` ssa-use `,` ssa-use `,` ssa-use
2065 // keyword keyword attribute-dict? `:` type
parseAtomicCmpXchgOp(OpAsmParser & parser,OperationState & result)2066 static ParseResult parseAtomicCmpXchgOp(OpAsmParser &parser,
2067 OperationState &result) {
2068 auto &builder = parser.getBuilder();
2069 Type type;
2070 OpAsmParser::OperandType ptr, cmp, val;
2071 if (parser.parseOperand(ptr) || parser.parseComma() ||
2072 parser.parseOperand(cmp) || parser.parseComma() ||
2073 parser.parseOperand(val) ||
2074 parseAtomicOrdering(parser, result, "success_ordering") ||
2075 parseAtomicOrdering(parser, result, "failure_ordering") ||
2076 parser.parseOptionalAttrDict(result.attributes) ||
2077 parser.parseColonType(type) ||
2078 parser.resolveOperand(ptr, LLVM::LLVMPointerType::get(type),
2079 result.operands) ||
2080 parser.resolveOperand(cmp, type, result.operands) ||
2081 parser.resolveOperand(val, type, result.operands))
2082 return failure();
2083
2084 auto boolType = IntegerType::get(builder.getContext(), 1);
2085 auto resultType =
2086 LLVMStructType::getLiteral(builder.getContext(), {type, boolType});
2087 result.addTypes(resultType);
2088
2089 return success();
2090 }
2091
verify(AtomicCmpXchgOp op)2092 static LogicalResult verify(AtomicCmpXchgOp op) {
2093 auto ptrType = op.ptr().getType().cast<LLVM::LLVMPointerType>();
2094 if (!ptrType)
2095 return op.emitOpError("expected LLVM IR pointer type for operand #0");
2096 auto cmpType = op.cmp().getType();
2097 auto valType = op.val().getType();
2098 if (cmpType != ptrType.getElementType() || cmpType != valType)
2099 return op.emitOpError("expected LLVM IR element type for operand #0 to "
2100 "match type for all other operands");
2101 auto intType = valType.dyn_cast<IntegerType>();
2102 unsigned intBitWidth = intType ? intType.getWidth() : 0;
2103 if (!valType.isa<LLVMPointerType>() && intBitWidth != 8 &&
2104 intBitWidth != 16 && intBitWidth != 32 && intBitWidth != 64 &&
2105 !valType.isa<BFloat16Type>() && !valType.isa<Float16Type>() &&
2106 !valType.isa<Float32Type>() && !valType.isa<Float64Type>())
2107 return op.emitOpError("unexpected LLVM IR type");
2108 if (op.success_ordering() < AtomicOrdering::monotonic ||
2109 op.failure_ordering() < AtomicOrdering::monotonic)
2110 return op.emitOpError("ordering must be at least 'monotonic'");
2111 if (op.failure_ordering() == AtomicOrdering::release ||
2112 op.failure_ordering() == AtomicOrdering::acq_rel)
2113 return op.emitOpError("failure ordering cannot be 'release' or 'acq_rel'");
2114 return success();
2115 }
2116
2117 //===----------------------------------------------------------------------===//
2118 // Printer, parser and verifier for LLVM::FenceOp.
2119 //===----------------------------------------------------------------------===//
2120
2121 // <operation> ::= `llvm.fence` (`syncscope(`strAttr`)`)? keyword
2122 // attribute-dict?
parseFenceOp(OpAsmParser & parser,OperationState & result)2123 static ParseResult parseFenceOp(OpAsmParser &parser, OperationState &result) {
2124 StringAttr sScope;
2125 StringRef syncscopeKeyword = "syncscope";
2126 if (!failed(parser.parseOptionalKeyword(syncscopeKeyword))) {
2127 if (parser.parseLParen() ||
2128 parser.parseAttribute(sScope, syncscopeKeyword, result.attributes) ||
2129 parser.parseRParen())
2130 return failure();
2131 } else {
2132 result.addAttribute(syncscopeKeyword,
2133 parser.getBuilder().getStringAttr(""));
2134 }
2135 if (parseAtomicOrdering(parser, result, "ordering") ||
2136 parser.parseOptionalAttrDict(result.attributes))
2137 return failure();
2138 return success();
2139 }
2140
printFenceOp(OpAsmPrinter & p,FenceOp & op)2141 static void printFenceOp(OpAsmPrinter &p, FenceOp &op) {
2142 StringRef syncscopeKeyword = "syncscope";
2143 p << op.getOperationName() << ' ';
2144 if (!op->getAttr(syncscopeKeyword).cast<StringAttr>().getValue().empty())
2145 p << "syncscope(" << op->getAttr(syncscopeKeyword) << ") ";
2146 p << stringifyAtomicOrdering(op.ordering());
2147 }
2148
verify(FenceOp & op)2149 static LogicalResult verify(FenceOp &op) {
2150 if (op.ordering() == AtomicOrdering::not_atomic ||
2151 op.ordering() == AtomicOrdering::unordered ||
2152 op.ordering() == AtomicOrdering::monotonic)
2153 return op.emitOpError("can be given only acquire, release, acq_rel, "
2154 "and seq_cst orderings");
2155 return success();
2156 }
2157
2158 //===----------------------------------------------------------------------===//
2159 // LLVMDialect initialization, type parsing, and registration.
2160 //===----------------------------------------------------------------------===//
2161
// Registers the attributes, types and operations of the LLVM dialect and
// allows operations that are not registered with it.
void LLVMDialect::initialize() {
  // Attributes defined by this dialect.
  addAttributes<FMFAttr, LoopOptionsAttr>();

  // Types defined by this dialect; formatting is disabled to keep one type per
  // line.
  // clang-format off
  addTypes<LLVMVoidType,
           LLVMPPCFP128Type,
           LLVMX86MMXType,
           LLVMTokenType,
           LLVMLabelType,
           LLVMMetadataType,
           LLVMFunctionType,
           LLVMPointerType,
           LLVMFixedVectorType,
           LLVMScalableVectorType,
           LLVMArrayType,
           LLVMStructType>();
  // clang-format on
  // The template arguments are supplied by LLVMOps.cpp.inc, included with
  // GET_OP_LIST defined.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
      >();

  // Support unknown operations because not all LLVM operations are registered.
  allowUnknownOperations();
}
2187
2188 #define GET_OP_CLASSES
2189 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
2190
2191 /// Parse a type registered to this dialect.
parseType(DialectAsmParser & parser) const2192 Type LLVMDialect::parseType(DialectAsmParser &parser) const {
2193 return detail::parseType(parser);
2194 }
2195
2196 /// Print a type registered to this dialect.
printType(Type type,DialectAsmPrinter & os) const2197 void LLVMDialect::printType(Type type, DialectAsmPrinter &os) const {
2198 return detail::printType(type, os);
2199 }
2200
verifyDataLayoutString(StringRef descr,llvm::function_ref<void (const Twine &)> reportError)2201 LogicalResult LLVMDialect::verifyDataLayoutString(
2202 StringRef descr, llvm::function_ref<void(const Twine &)> reportError) {
2203 llvm::Expected<llvm::DataLayout> maybeDataLayout =
2204 llvm::DataLayout::parse(descr);
2205 if (maybeDataLayout)
2206 return success();
2207
2208 std::string message;
2209 llvm::raw_string_ostream messageStream(message);
2210 llvm::logAllUnhandledErrors(maybeDataLayout.takeError(), messageStream);
2211 reportError("invalid data layout descriptor: " + messageStream.str());
2212 return failure();
2213 }
2214
2215 /// Verify LLVM dialect attributes.
verifyOperationAttribute(Operation * op,NamedAttribute attr)2216 LogicalResult LLVMDialect::verifyOperationAttribute(Operation *op,
2217 NamedAttribute attr) {
2218 // If the `llvm.loop` attribute is present, enforce the following structure,
2219 // which the module translation can assume.
2220 if (attr.first.strref() == LLVMDialect::getLoopAttrName()) {
2221 auto loopAttr = attr.second.dyn_cast<DictionaryAttr>();
2222 if (!loopAttr)
2223 return op->emitOpError() << "expected '" << LLVMDialect::getLoopAttrName()
2224 << "' to be a dictionary attribute";
2225 Optional<NamedAttribute> parallelAccessGroup =
2226 loopAttr.getNamed(LLVMDialect::getParallelAccessAttrName());
2227 if (parallelAccessGroup.hasValue()) {
2228 auto accessGroups = parallelAccessGroup->second.dyn_cast<ArrayAttr>();
2229 if (!accessGroups)
2230 return op->emitOpError()
2231 << "expected '" << LLVMDialect::getParallelAccessAttrName()
2232 << "' to be an array attribute";
2233 for (Attribute attr : accessGroups) {
2234 auto accessGroupRef = attr.dyn_cast<SymbolRefAttr>();
2235 if (!accessGroupRef)
2236 return op->emitOpError()
2237 << "expected '" << attr << "' to be a symbol reference";
2238 StringRef metadataName = accessGroupRef.getRootReference();
2239 auto metadataOp =
2240 SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
2241 op->getParentOp(), metadataName);
2242 if (!metadataOp)
2243 return op->emitOpError()
2244 << "expected '" << attr << "' to reference a metadata op";
2245 StringRef accessGroupName = accessGroupRef.getLeafReference();
2246 Operation *accessGroupOp =
2247 SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName);
2248 if (!accessGroupOp)
2249 return op->emitOpError()
2250 << "expected '" << attr << "' to reference an access_group op";
2251 }
2252 }
2253
2254 Optional<NamedAttribute> loopOptions =
2255 loopAttr.getNamed(LLVMDialect::getLoopOptionsAttrName());
2256 if (loopOptions.hasValue() && !loopOptions->second.isa<LoopOptionsAttr>())
2257 return op->emitOpError()
2258 << "expected '" << LLVMDialect::getLoopOptionsAttrName()
2259 << "' to be a `loopopts` attribute";
2260 }
2261
2262 // If the data layout attribute is present, it must use the LLVM data layout
2263 // syntax. Try parsing it and report errors in case of failure. Users of this
2264 // attribute may assume it is well-formed and can pass it to the (asserting)
2265 // llvm::DataLayout constructor.
2266 if (attr.first.strref() != LLVM::LLVMDialect::getDataLayoutAttrName())
2267 return success();
2268 if (auto stringAttr = attr.second.dyn_cast<StringAttr>())
2269 return verifyDataLayoutString(
2270 stringAttr.getValue(),
2271 [op](const Twine &message) { op->emitOpError() << message.str(); });
2272
2273 return op->emitOpError() << "expected '"
2274 << LLVM::LLVMDialect::getDataLayoutAttrName()
2275 << "' to be a string attribute";
2276 }
2277
2278 /// Verify LLVMIR function argument attributes.
verifyRegionArgAttribute(Operation * op,unsigned regionIdx,unsigned argIdx,NamedAttribute argAttr)2279 LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op,
2280 unsigned regionIdx,
2281 unsigned argIdx,
2282 NamedAttribute argAttr) {
2283 // Check that llvm.noalias is a unit attribute.
2284 if (argAttr.first == LLVMDialect::getNoAliasAttrName() &&
2285 !argAttr.second.isa<UnitAttr>())
2286 return op->emitError()
2287 << "expected llvm.noalias argument attribute to be a unit attribute";
2288 // Check that llvm.align is an integer attribute.
2289 if (argAttr.first == LLVMDialect::getAlignAttrName() &&
2290 !argAttr.second.isa<IntegerAttr>())
2291 return op->emitError()
2292 << "llvm.align argument attribute of non integer type";
2293 return success();
2294 }
2295
2296 //===----------------------------------------------------------------------===//
2297 // Utility functions.
2298 //===----------------------------------------------------------------------===//
2299
createGlobalString(Location loc,OpBuilder & builder,StringRef name,StringRef value,LLVM::Linkage linkage)2300 Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder,
2301 StringRef name, StringRef value,
2302 LLVM::Linkage linkage) {
2303 assert(builder.getInsertionBlock() &&
2304 builder.getInsertionBlock()->getParentOp() &&
2305 "expected builder to point to a block constrained in an op");
2306 auto module =
2307 builder.getInsertionBlock()->getParentOp()->getParentOfType<ModuleOp>();
2308 assert(module && "builder points to an op outside of a module");
2309
2310 // Create the global at the entry of the module.
2311 OpBuilder moduleBuilder(module.getBodyRegion(), builder.getListener());
2312 MLIRContext *ctx = builder.getContext();
2313 auto type = LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), value.size());
2314 auto global = moduleBuilder.create<LLVM::GlobalOp>(
2315 loc, type, /*isConstant=*/true, linkage, name,
2316 builder.getStringAttr(value), /*alignment=*/0);
2317
2318 // Get the pointer to the first character in the global string.
2319 Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global);
2320 Value cst0 = builder.create<LLVM::ConstantOp>(
2321 loc, IntegerType::get(ctx, 64),
2322 builder.getIntegerAttr(builder.getIndexType(), 0));
2323 return builder.create<LLVM::GEPOp>(
2324 loc, LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)), globalPtr,
2325 ValueRange{cst0, cst0});
2326 }
2327
satisfiesLLVMModule(Operation * op)2328 bool mlir::LLVM::satisfiesLLVMModule(Operation *op) {
2329 return op->hasTrait<OpTrait::SymbolTable>() &&
2330 op->hasTrait<OpTrait::IsIsolatedFromAbove>();
2331 }
2332
2333 static constexpr const FastmathFlags FastmathFlagsList[] = {
2334 // clang-format off
2335 FastmathFlags::nnan,
2336 FastmathFlags::ninf,
2337 FastmathFlags::nsz,
2338 FastmathFlags::arcp,
2339 FastmathFlags::contract,
2340 FastmathFlags::afn,
2341 FastmathFlags::reassoc,
2342 FastmathFlags::fast,
2343 // clang-format on
2344 };
2345
print(DialectAsmPrinter & printer) const2346 void FMFAttr::print(DialectAsmPrinter &printer) const {
2347 printer << "fastmath<";
2348 auto flags = llvm::make_filter_range(FastmathFlagsList, [&](auto flag) {
2349 return bitEnumContains(this->getFlags(), flag);
2350 });
2351 llvm::interleaveComma(flags, printer,
2352 [&](auto flag) { printer << stringifyEnum(flag); });
2353 printer << ">";
2354 }
2355
parse(MLIRContext * context,DialectAsmParser & parser,Type type)2356 Attribute FMFAttr::parse(MLIRContext *context, DialectAsmParser &parser,
2357 Type type) {
2358 if (failed(parser.parseLess()))
2359 return {};
2360
2361 FastmathFlags flags = {};
2362 if (failed(parser.parseOptionalGreater())) {
2363 do {
2364 StringRef elemName;
2365 if (failed(parser.parseKeyword(&elemName)))
2366 return {};
2367
2368 auto elem = symbolizeFastmathFlags(elemName);
2369 if (!elem) {
2370 parser.emitError(parser.getNameLoc(), "Unknown fastmath flag: ")
2371 << elemName;
2372 return {};
2373 }
2374
2375 flags = flags | *elem;
2376 } while (succeeded(parser.parseOptionalComma()));
2377
2378 if (failed(parser.parseGreater()))
2379 return {};
2380 }
2381
2382 return FMFAttr::get(parser.getBuilder().getContext(), flags);
2383 }
2384
LoopOptionsAttrBuilder(LoopOptionsAttr attr)2385 LoopOptionsAttrBuilder::LoopOptionsAttrBuilder(LoopOptionsAttr attr)
2386 : options(attr.getOptions().begin(), attr.getOptions().end()) {}
2387
2388 template <typename T>
setOption(LoopOptionCase tag,Optional<T> value)2389 LoopOptionsAttrBuilder &LoopOptionsAttrBuilder::setOption(LoopOptionCase tag,
2390 Optional<T> value) {
2391 auto option = llvm::find_if(
2392 options, [tag](auto option) { return option.first == tag; });
2393 if (option != options.end()) {
2394 if (value.hasValue())
2395 option->second = *value;
2396 else
2397 options.erase(option);
2398 } else {
2399 options.push_back(LoopOptionsAttr::OptionValuePair(tag, *value));
2400 }
2401 return *this;
2402 }
2403
2404 LoopOptionsAttrBuilder &
setDisableLICM(Optional<bool> value)2405 LoopOptionsAttrBuilder::setDisableLICM(Optional<bool> value) {
2406 return setOption(LoopOptionCase::disable_licm, value);
2407 }
2408
2409 /// Set the `interleave_count` option to the provided value. If no value
2410 /// is provided the option is deleted.
2411 LoopOptionsAttrBuilder &
setInterleaveCount(Optional<uint64_t> count)2412 LoopOptionsAttrBuilder::setInterleaveCount(Optional<uint64_t> count) {
2413 return setOption(LoopOptionCase::interleave_count, count);
2414 }
2415
2416 /// Set the `disable_unroll` option to the provided value. If no value
2417 /// is provided the option is deleted.
2418 LoopOptionsAttrBuilder &
setDisableUnroll(Optional<bool> value)2419 LoopOptionsAttrBuilder::setDisableUnroll(Optional<bool> value) {
2420 return setOption(LoopOptionCase::disable_unroll, value);
2421 }
2422
2423 /// Set the `disable_pipeline` option to the provided value. If no value
2424 /// is provided the option is deleted.
2425 LoopOptionsAttrBuilder &
setDisablePipeline(Optional<bool> value)2426 LoopOptionsAttrBuilder::setDisablePipeline(Optional<bool> value) {
2427 return setOption(LoopOptionCase::disable_pipeline, value);
2428 }
2429
2430 /// Set the `pipeline_initiation_interval` option to the provided value.
2431 /// If no value is provided the option is deleted.
setPipelineInitiationInterval(Optional<uint64_t> count)2432 LoopOptionsAttrBuilder &LoopOptionsAttrBuilder::setPipelineInitiationInterval(
2433 Optional<uint64_t> count) {
2434 return setOption(LoopOptionCase::pipeline_initiation_interval, count);
2435 }
2436
2437 template <typename T>
2438 static Optional<T>
getOption(ArrayRef<std::pair<LoopOptionCase,int64_t>> options,LoopOptionCase option)2439 getOption(ArrayRef<std::pair<LoopOptionCase, int64_t>> options,
2440 LoopOptionCase option) {
2441 auto it =
2442 lower_bound(options, option, [](auto optionPair, LoopOptionCase option) {
2443 return optionPair.first < option;
2444 });
2445 if (it == options.end())
2446 return {};
2447 return static_cast<T>(it->second);
2448 }
2449
disableUnroll()2450 Optional<bool> LoopOptionsAttr::disableUnroll() {
2451 return getOption<bool>(getOptions(), LoopOptionCase::disable_unroll);
2452 }
2453
disableLICM()2454 Optional<bool> LoopOptionsAttr::disableLICM() {
2455 return getOption<bool>(getOptions(), LoopOptionCase::disable_licm);
2456 }
2457
interleaveCount()2458 Optional<int64_t> LoopOptionsAttr::interleaveCount() {
2459 return getOption<int64_t>(getOptions(), LoopOptionCase::interleave_count);
2460 }
2461
2462 /// Build the LoopOptions Attribute from a sorted array of individual options.
get(MLIRContext * context,ArrayRef<std::pair<LoopOptionCase,int64_t>> sortedOptions)2463 LoopOptionsAttr LoopOptionsAttr::get(
2464 MLIRContext *context,
2465 ArrayRef<std::pair<LoopOptionCase, int64_t>> sortedOptions) {
2466 assert(llvm::is_sorted(sortedOptions, llvm::less_first()) &&
2467 "LoopOptionsAttr ctor expects a sorted options array");
2468 return Base::get(context, sortedOptions);
2469 }
2470
2471 /// Build the LoopOptions Attribute from a sorted array of individual options.
get(MLIRContext * context,LoopOptionsAttrBuilder & optionBuilders)2472 LoopOptionsAttr LoopOptionsAttr::get(MLIRContext *context,
2473 LoopOptionsAttrBuilder &optionBuilders) {
2474 llvm::sort(optionBuilders.options, llvm::less_first());
2475 return Base::get(context, optionBuilders.options);
2476 }
2477
print(DialectAsmPrinter & printer) const2478 void LoopOptionsAttr::print(DialectAsmPrinter &printer) const {
2479 printer << getMnemonic() << "<";
2480 llvm::interleaveComma(getOptions(), printer, [&](auto option) {
2481 printer << stringifyEnum(option.first) << " = ";
2482 switch (option.first) {
2483 case LoopOptionCase::disable_licm:
2484 case LoopOptionCase::disable_unroll:
2485 case LoopOptionCase::disable_pipeline:
2486 printer << (option.second ? "true" : "false");
2487 break;
2488 case LoopOptionCase::interleave_count:
2489 case LoopOptionCase::pipeline_initiation_interval:
2490 printer << option.second;
2491 break;
2492 }
2493 });
2494 printer << ">";
2495 }
2496
parse(MLIRContext * context,DialectAsmParser & parser,Type type)2497 Attribute LoopOptionsAttr::parse(MLIRContext *context, DialectAsmParser &parser,
2498 Type type) {
2499 if (failed(parser.parseLess()))
2500 return {};
2501
2502 SmallVector<std::pair<LoopOptionCase, int64_t>> options;
2503 llvm::SmallDenseSet<LoopOptionCase> seenOptions;
2504 do {
2505 StringRef optionName;
2506 if (parser.parseKeyword(&optionName))
2507 return {};
2508
2509 auto option = symbolizeLoopOptionCase(optionName);
2510 if (!option) {
2511 parser.emitError(parser.getNameLoc(), "unknown loop option: ")
2512 << optionName;
2513 return {};
2514 }
2515 if (!seenOptions.insert(*option).second) {
2516 parser.emitError(parser.getNameLoc(), "loop option present twice");
2517 return {};
2518 }
2519 if (failed(parser.parseEqual()))
2520 return {};
2521
2522 int64_t value;
2523 switch (*option) {
2524 case LoopOptionCase::disable_licm:
2525 case LoopOptionCase::disable_unroll:
2526 case LoopOptionCase::disable_pipeline:
2527 if (succeeded(parser.parseOptionalKeyword("true")))
2528 value = 1;
2529 else if (succeeded(parser.parseOptionalKeyword("false")))
2530 value = 0;
2531 else {
2532 parser.emitError(parser.getNameLoc(),
2533 "expected boolean value 'true' or 'false'");
2534 return {};
2535 }
2536 break;
2537 case LoopOptionCase::interleave_count:
2538 case LoopOptionCase::pipeline_initiation_interval:
2539 if (failed(parser.parseInteger(value))) {
2540 parser.emitError(parser.getNameLoc(), "expected integer value");
2541 return {};
2542 }
2543 break;
2544 }
2545 options.push_back(std::make_pair(*option, value));
2546 } while (succeeded(parser.parseOptionalComma()));
2547 if (failed(parser.parseGreater()))
2548 return {};
2549
2550 llvm::sort(options, llvm::less_first());
2551 return get(parser.getBuilder().getContext(), options);
2552 }
2553
parseAttribute(DialectAsmParser & parser,Type type) const2554 Attribute LLVMDialect::parseAttribute(DialectAsmParser &parser,
2555 Type type) const {
2556 if (type) {
2557 parser.emitError(parser.getNameLoc(), "unexpected type");
2558 return {};
2559 }
2560 StringRef attrKind;
2561 if (parser.parseKeyword(&attrKind))
2562 return {};
2563 {
2564 Attribute attr;
2565 auto parseResult =
2566 generatedAttributeParser(getContext(), parser, attrKind, type, attr);
2567 if (parseResult.hasValue())
2568 return attr;
2569 }
2570 parser.emitError(parser.getNameLoc(), "unknown attribute type: ") << attrKind;
2571 return {};
2572 }
2573
printAttribute(Attribute attr,DialectAsmPrinter & os) const2574 void LLVMDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const {
2575 if (succeeded(generatedAttributePrinter(attr, os)))
2576 return;
2577 llvm_unreachable("Unknown attribute type");
2578 }
2579