1 //===- LLVMDialect.cpp - LLVM IR Ops and Dialect registration -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the types and operation details for the LLVM IR dialect in
10 // MLIR, and the LLVM IR dialect.  It also registers the dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14 #include "TypeDetail.h"
15 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/DialectImplementation.h"
20 #include "mlir/IR/FunctionImplementation.h"
21 #include "mlir/IR/MLIRContext.h"
22 
23 #include "llvm/ADT/StringSwitch.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/AsmParser/Parser.h"
26 #include "llvm/Bitcode/BitcodeReader.h"
27 #include "llvm/Bitcode/BitcodeWriter.h"
28 #include "llvm/IR/Attributes.h"
29 #include "llvm/IR/Function.h"
30 #include "llvm/IR/Type.h"
31 #include "llvm/Support/Mutex.h"
32 #include "llvm/Support/SourceMgr.h"
33 
34 #include <iostream>
35 
36 using namespace mlir;
37 using namespace mlir::LLVM;
38 
39 #include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc"
40 
// Attribute names shared by the load/store helpers in this file.
// `kVolatileAttrName` is attached as a unit attribute by the LoadOp/StoreOp
// builders and parsers and is elided by their printers (which print the
// `volatile` keyword instead). `kNonTemporalAttrName` is attached as a unit
// attribute by the LoadOp/StoreOp builders.
static constexpr const char kVolatileAttrName[] = "volatile_";
static constexpr const char kNonTemporalAttrName[] = "nontemporal";
43 
44 #include "mlir/Dialect/LLVMIR/LLVMOpsEnums.cpp.inc"
45 #include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.cpp.inc"
46 #define GET_ATTRDEF_CLASSES
47 #include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.cpp.inc"
48 
processFMFAttr(ArrayRef<NamedAttribute> attrs)49 static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) {
50   SmallVector<NamedAttribute, 8> filteredAttrs(
51       llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
52         if (attr.first == "fastmathFlags") {
53           auto defAttr = FMFAttr::get(attr.second.getContext(), {});
54           return defAttr != attr.second;
55         }
56         return true;
57       }));
58   return filteredAttrs;
59 }
60 
parseLLVMOpAttrs(OpAsmParser & parser,NamedAttrList & result)61 static ParseResult parseLLVMOpAttrs(OpAsmParser &parser,
62                                     NamedAttrList &result) {
63   return parser.parseOptionalAttrDict(result);
64 }
65 
printLLVMOpAttrs(OpAsmPrinter & printer,Operation * op,DictionaryAttr attrs)66 static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op,
67                              DictionaryAttr attrs) {
68   printer.printOptionalAttrDict(processFMFAttr(attrs.getValue()));
69 }
70 
71 //===----------------------------------------------------------------------===//
72 // Printing/parsing for LLVM::CmpOp.
73 //===----------------------------------------------------------------------===//
printICmpOp(OpAsmPrinter & p,ICmpOp & op)74 static void printICmpOp(OpAsmPrinter &p, ICmpOp &op) {
75   p << op.getOperationName() << " \"" << stringifyICmpPredicate(op.predicate())
76     << "\" " << op.getOperand(0) << ", " << op.getOperand(1);
77   p.printOptionalAttrDict(op->getAttrs(), {"predicate"});
78   p << " : " << op.lhs().getType();
79 }
80 
printFCmpOp(OpAsmPrinter & p,FCmpOp & op)81 static void printFCmpOp(OpAsmPrinter &p, FCmpOp &op) {
82   p << op.getOperationName() << " \"" << stringifyFCmpPredicate(op.predicate())
83     << "\" " << op.getOperand(0) << ", " << op.getOperand(1);
84   p.printOptionalAttrDict(processFMFAttr(op->getAttrs()), {"predicate"});
85   p << " : " << op.lhs().getType();
86 }
87 
88 // <operation> ::= `llvm.icmp` string-literal ssa-use `,` ssa-use
89 //                 attribute-dict? `:` type
90 // <operation> ::= `llvm.fcmp` string-literal ssa-use `,` ssa-use
91 //                 attribute-dict? `:` type
92 template <typename CmpPredicateType>
parseCmpOp(OpAsmParser & parser,OperationState & result)93 static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
94   Builder &builder = parser.getBuilder();
95 
96   StringAttr predicateAttr;
97   OpAsmParser::OperandType lhs, rhs;
98   Type type;
99   llvm::SMLoc predicateLoc, trailingTypeLoc;
100   if (parser.getCurrentLocation(&predicateLoc) ||
101       parser.parseAttribute(predicateAttr, "predicate", result.attributes) ||
102       parser.parseOperand(lhs) || parser.parseComma() ||
103       parser.parseOperand(rhs) ||
104       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
105       parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) ||
106       parser.resolveOperand(lhs, type, result.operands) ||
107       parser.resolveOperand(rhs, type, result.operands))
108     return failure();
109 
110   // Replace the string attribute `predicate` with an integer attribute.
111   int64_t predicateValue = 0;
112   if (std::is_same<CmpPredicateType, ICmpPredicate>()) {
113     Optional<ICmpPredicate> predicate =
114         symbolizeICmpPredicate(predicateAttr.getValue());
115     if (!predicate)
116       return parser.emitError(predicateLoc)
117              << "'" << predicateAttr.getValue()
118              << "' is an incorrect value of the 'predicate' attribute";
119     predicateValue = static_cast<int64_t>(predicate.getValue());
120   } else {
121     Optional<FCmpPredicate> predicate =
122         symbolizeFCmpPredicate(predicateAttr.getValue());
123     if (!predicate)
124       return parser.emitError(predicateLoc)
125              << "'" << predicateAttr.getValue()
126              << "' is an incorrect value of the 'predicate' attribute";
127     predicateValue = static_cast<int64_t>(predicate.getValue());
128   }
129 
130   result.attributes.set("predicate",
131                         parser.getBuilder().getI64IntegerAttr(predicateValue));
132 
133   // The result type is either i1 or a vector type <? x i1> if the inputs are
134   // vectors.
135   Type resultType = IntegerType::get(builder.getContext(), 1);
136   if (!isCompatibleType(type))
137     return parser.emitError(trailingTypeLoc,
138                             "expected LLVM dialect-compatible type");
139   if (LLVM::isCompatibleVectorType(type)) {
140     if (type.isa<LLVM::LLVMScalableVectorType>()) {
141       resultType = LLVM::LLVMScalableVectorType::get(
142           resultType, LLVM::getVectorNumElements(type).getKnownMinValue());
143     } else {
144       resultType = LLVM::getFixedVectorType(
145           resultType, LLVM::getVectorNumElements(type).getFixedValue());
146     }
147   }
148 
149   result.addTypes({resultType});
150   return success();
151 }
152 
153 //===----------------------------------------------------------------------===//
154 // Printing/parsing for LLVM::AllocaOp.
155 //===----------------------------------------------------------------------===//
156 
printAllocaOp(OpAsmPrinter & p,AllocaOp & op)157 static void printAllocaOp(OpAsmPrinter &p, AllocaOp &op) {
158   auto elemTy = op.getType().cast<LLVM::LLVMPointerType>().getElementType();
159 
160   auto funcTy = FunctionType::get(op.getContext(), {op.arraySize().getType()},
161                                   {op.getType()});
162 
163   p << op.getOperationName() << ' ' << op.arraySize() << " x " << elemTy;
164   if (op.alignment().hasValue() && *op.alignment() != 0)
165     p.printOptionalAttrDict(op->getAttrs());
166   else
167     p.printOptionalAttrDict(op->getAttrs(), {"alignment"});
168   p << " : " << funcTy;
169 }
170 
171 // <operation> ::= `llvm.alloca` ssa-use `x` type attribute-dict?
172 //                 `:` type `,` type
parseAllocaOp(OpAsmParser & parser,OperationState & result)173 static ParseResult parseAllocaOp(OpAsmParser &parser, OperationState &result) {
174   OpAsmParser::OperandType arraySize;
175   Type type, elemType;
176   llvm::SMLoc trailingTypeLoc;
177   if (parser.parseOperand(arraySize) || parser.parseKeyword("x") ||
178       parser.parseType(elemType) ||
179       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
180       parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
181     return failure();
182 
183   Optional<NamedAttribute> alignmentAttr =
184       result.attributes.getNamed("alignment");
185   if (alignmentAttr.hasValue()) {
186     auto alignmentInt = alignmentAttr.getValue().second.dyn_cast<IntegerAttr>();
187     if (!alignmentInt)
188       return parser.emitError(parser.getNameLoc(),
189                               "expected integer alignment");
190     if (alignmentInt.getValue().isNullValue())
191       result.attributes.erase("alignment");
192   }
193 
194   // Extract the result type from the trailing function type.
195   auto funcType = type.dyn_cast<FunctionType>();
196   if (!funcType || funcType.getNumInputs() != 1 ||
197       funcType.getNumResults() != 1)
198     return parser.emitError(
199         trailingTypeLoc,
200         "expected trailing function type with one argument and one result");
201 
202   if (parser.resolveOperand(arraySize, funcType.getInput(0), result.operands))
203     return failure();
204 
205   result.addTypes({funcType.getResult(0)});
206   return success();
207 }
208 
209 //===----------------------------------------------------------------------===//
210 // LLVM::BrOp
211 //===----------------------------------------------------------------------===//
212 
213 Optional<MutableOperandRange>
getMutableSuccessorOperands(unsigned index)214 BrOp::getMutableSuccessorOperands(unsigned index) {
215   assert(index == 0 && "invalid successor index");
216   return destOperandsMutable();
217 }
218 
219 //===----------------------------------------------------------------------===//
220 // LLVM::CondBrOp
221 //===----------------------------------------------------------------------===//
222 
223 Optional<MutableOperandRange>
getMutableSuccessorOperands(unsigned index)224 CondBrOp::getMutableSuccessorOperands(unsigned index) {
225   assert(index < getNumSuccessors() && "invalid successor index");
226   return index == 0 ? trueDestOperandsMutable() : falseDestOperandsMutable();
227 }
228 
229 //===----------------------------------------------------------------------===//
230 // LLVM::SwitchOp
231 //===----------------------------------------------------------------------===//
232 
build(OpBuilder & builder,OperationState & result,Value value,Block * defaultDestination,ValueRange defaultOperands,ArrayRef<int32_t> caseValues,BlockRange caseDestinations,ArrayRef<ValueRange> caseOperands,ArrayRef<int32_t> branchWeights)233 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
234                      Block *defaultDestination, ValueRange defaultOperands,
235                      ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
236                      ArrayRef<ValueRange> caseOperands,
237                      ArrayRef<int32_t> branchWeights) {
238   SmallVector<Value> flattenedCaseOperands;
239   SmallVector<int32_t> caseOperandOffsets;
240   int32_t offset = 0;
241   for (ValueRange operands : caseOperands) {
242     flattenedCaseOperands.append(operands.begin(), operands.end());
243     caseOperandOffsets.push_back(offset);
244     offset += operands.size();
245   }
246   ElementsAttr caseValuesAttr;
247   if (!caseValues.empty())
248     caseValuesAttr = builder.getI32VectorAttr(caseValues);
249   ElementsAttr caseOperandOffsetsAttr;
250   if (!caseOperandOffsets.empty())
251     caseOperandOffsetsAttr = builder.getI32VectorAttr(caseOperandOffsets);
252 
253   ElementsAttr weightsAttr;
254   if (!branchWeights.empty())
255     weightsAttr = builder.getI32VectorAttr(llvm::to_vector<4>(branchWeights));
256 
257   build(builder, result, value, defaultOperands, flattenedCaseOperands,
258         caseValuesAttr, caseOperandOffsetsAttr, weightsAttr, defaultDestination,
259         caseDestinations);
260 }
261 
262 /// <cases> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)?
263 ///             ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )?
264 static ParseResult
parseSwitchOpCases(OpAsmParser & parser,ElementsAttr & caseValues,SmallVectorImpl<Block * > & caseDestinations,SmallVectorImpl<OpAsmParser::OperandType> & caseOperands,SmallVectorImpl<Type> & caseOperandTypes,ElementsAttr & caseOperandOffsets)265 parseSwitchOpCases(OpAsmParser &parser, ElementsAttr &caseValues,
266                    SmallVectorImpl<Block *> &caseDestinations,
267                    SmallVectorImpl<OpAsmParser::OperandType> &caseOperands,
268                    SmallVectorImpl<Type> &caseOperandTypes,
269                    ElementsAttr &caseOperandOffsets) {
270   SmallVector<int32_t> values;
271   SmallVector<int32_t> offsets;
272   int32_t value, offset = 0;
273   do {
274     OptionalParseResult integerParseResult = parser.parseOptionalInteger(value);
275     if (values.empty() && !integerParseResult.hasValue())
276       return success();
277 
278     if (!integerParseResult.hasValue() || integerParseResult.getValue())
279       return failure();
280     values.push_back(value);
281 
282     Block *destination;
283     SmallVector<OpAsmParser::OperandType> operands;
284     if (parser.parseColon() || parser.parseSuccessor(destination))
285       return failure();
286     if (!parser.parseOptionalLParen()) {
287       if (parser.parseRegionArgumentList(operands) ||
288           parser.parseColonTypeList(caseOperandTypes) || parser.parseRParen())
289         return failure();
290     }
291     caseDestinations.push_back(destination);
292     caseOperands.append(operands.begin(), operands.end());
293     offsets.push_back(offset);
294     offset += operands.size();
295   } while (!parser.parseOptionalComma());
296 
297   Builder &builder = parser.getBuilder();
298   caseValues = builder.getI32VectorAttr(values);
299   caseOperandOffsets = builder.getI32VectorAttr(offsets);
300 
301   return success();
302 }
303 
printSwitchOpCases(OpAsmPrinter & p,SwitchOp op,ElementsAttr caseValues,SuccessorRange caseDestinations,OperandRange caseOperands,TypeRange caseOperandTypes,ElementsAttr caseOperandOffsets)304 static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op,
305                                ElementsAttr caseValues,
306                                SuccessorRange caseDestinations,
307                                OperandRange caseOperands,
308                                TypeRange caseOperandTypes,
309                                ElementsAttr caseOperandOffsets) {
310   if (!caseValues)
311     return;
312 
313   size_t index = 0;
314   llvm::interleave(
315       llvm::zip(caseValues.cast<DenseIntElementsAttr>(), caseDestinations),
316       [&](auto i) {
317         p << "  ";
318         p << std::get<0>(i).getLimitedValue();
319         p << ": ";
320         p.printSuccessorAndUseList(std::get<1>(i), op.getCaseOperands(index++));
321       },
322       [&] {
323         p << ',';
324         p.printNewline();
325       });
326   p.printNewline();
327 }
328 
verify(SwitchOp op)329 static LogicalResult verify(SwitchOp op) {
330   if ((!op.case_values() && !op.caseDestinations().empty()) ||
331       (op.case_values() &&
332        op.case_values()->size() !=
333            static_cast<int64_t>(op.caseDestinations().size())))
334     return op.emitOpError("expects number of case values to match number of "
335                           "case destinations");
336   if (op.branch_weights() &&
337       op.branch_weights()->size() != op.getNumSuccessors())
338     return op.emitError("expects number of branch weights to match number of "
339                         "successors: ")
340            << op.branch_weights()->size() << " vs " << op.getNumSuccessors();
341   return success();
342 }
343 
getCaseOperands(unsigned index)344 OperandRange SwitchOp::getCaseOperands(unsigned index) {
345   return getCaseOperandsMutable(index);
346 }
347 
getCaseOperandsMutable(unsigned index)348 MutableOperandRange SwitchOp::getCaseOperandsMutable(unsigned index) {
349   MutableOperandRange caseOperands = caseOperandsMutable();
350   if (!case_operand_offsets()) {
351     assert(caseOperands.size() == 0 &&
352            "non-empty case operands must have offsets");
353     return caseOperands;
354   }
355 
356   ElementsAttr offsets = case_operand_offsets().getValue();
357   assert(index < offsets.size() && "invalid case operand offset index");
358 
359   int64_t begin = offsets.getValue(index).cast<IntegerAttr>().getInt();
360   int64_t end = index + 1 == offsets.size()
361                     ? caseOperands.size()
362                     : offsets.getValue(index + 1).cast<IntegerAttr>().getInt();
363   return caseOperandsMutable().slice(begin, end - begin);
364 }
365 
366 Optional<MutableOperandRange>
getMutableSuccessorOperands(unsigned index)367 SwitchOp::getMutableSuccessorOperands(unsigned index) {
368   assert(index < getNumSuccessors() && "invalid successor index");
369   return index == 0 ? defaultOperandsMutable()
370                     : getCaseOperandsMutable(index - 1);
371 }
372 
373 //===----------------------------------------------------------------------===//
// Builder, printer and parser for LLVM::LoadOp.
375 //===----------------------------------------------------------------------===//
376 
verifyAccessGroups(Operation * op)377 static LogicalResult verifyAccessGroups(Operation *op) {
378   if (Attribute attribute =
379           op->getAttr(LLVMDialect::getAccessGroupsAttrName())) {
380     // The attribute is already verified to be a symbol ref array attribute via
381     // a constraint in the operation definition.
382     for (SymbolRefAttr accessGroupRef :
383          attribute.cast<ArrayAttr>().getAsRange<SymbolRefAttr>()) {
384       StringRef metadataName = accessGroupRef.getRootReference();
385       auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
386           op->getParentOp(), metadataName);
387       if (!metadataOp)
388         return op->emitOpError() << "expected '" << accessGroupRef
389                                  << "' to reference a metadata op";
390       StringRef accessGroupName = accessGroupRef.getLeafReference();
391       Operation *accessGroupOp =
392           SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName);
393       if (!accessGroupOp)
394         return op->emitOpError() << "expected '" << accessGroupRef
395                                  << "' to reference an access_group op";
396     }
397   }
398   return success();
399 }
400 
verify(LoadOp op)401 static LogicalResult verify(LoadOp op) {
402   return verifyAccessGroups(op.getOperation());
403 }
404 
build(OpBuilder & builder,OperationState & result,Type t,Value addr,unsigned alignment,bool isVolatile,bool isNonTemporal)405 void LoadOp::build(OpBuilder &builder, OperationState &result, Type t,
406                    Value addr, unsigned alignment, bool isVolatile,
407                    bool isNonTemporal) {
408   result.addOperands(addr);
409   result.addTypes(t);
410   if (isVolatile)
411     result.addAttribute(kVolatileAttrName, builder.getUnitAttr());
412   if (isNonTemporal)
413     result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr());
414   if (alignment != 0)
415     result.addAttribute("alignment", builder.getI64IntegerAttr(alignment));
416 }
417 
printLoadOp(OpAsmPrinter & p,LoadOp & op)418 static void printLoadOp(OpAsmPrinter &p, LoadOp &op) {
419   p << op.getOperationName() << ' ';
420   if (op.volatile_())
421     p << "volatile ";
422   p << op.addr();
423   p.printOptionalAttrDict(op->getAttrs(), {kVolatileAttrName});
424   p << " : " << op.addr().getType();
425 }
426 
427 // Extract the pointee type from the LLVM pointer type wrapped in MLIR.  Return
428 // the resulting type wrapped in MLIR, or nullptr on error.
getLoadStoreElementType(OpAsmParser & parser,Type type,llvm::SMLoc trailingTypeLoc)429 static Type getLoadStoreElementType(OpAsmParser &parser, Type type,
430                                     llvm::SMLoc trailingTypeLoc) {
431   auto llvmTy = type.dyn_cast<LLVM::LLVMPointerType>();
432   if (!llvmTy)
433     return parser.emitError(trailingTypeLoc, "expected LLVM pointer type"),
434            nullptr;
435   return llvmTy.getElementType();
436 }
437 
438 // <operation> ::= `llvm.load` `volatile` ssa-use attribute-dict? `:` type
parseLoadOp(OpAsmParser & parser,OperationState & result)439 static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
440   OpAsmParser::OperandType addr;
441   Type type;
442   llvm::SMLoc trailingTypeLoc;
443 
444   if (succeeded(parser.parseOptionalKeyword("volatile")))
445     result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr());
446 
447   if (parser.parseOperand(addr) ||
448       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
449       parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) ||
450       parser.resolveOperand(addr, type, result.operands))
451     return failure();
452 
453   Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc);
454 
455   result.addTypes(elemTy);
456   return success();
457 }
458 
459 //===----------------------------------------------------------------------===//
460 // Builder, printer and parser for LLVM::StoreOp.
461 //===----------------------------------------------------------------------===//
462 
verify(StoreOp op)463 static LogicalResult verify(StoreOp op) {
464   return verifyAccessGroups(op.getOperation());
465 }
466 
build(OpBuilder & builder,OperationState & result,Value value,Value addr,unsigned alignment,bool isVolatile,bool isNonTemporal)467 void StoreOp::build(OpBuilder &builder, OperationState &result, Value value,
468                     Value addr, unsigned alignment, bool isVolatile,
469                     bool isNonTemporal) {
470   result.addOperands({value, addr});
471   result.addTypes({});
472   if (isVolatile)
473     result.addAttribute(kVolatileAttrName, builder.getUnitAttr());
474   if (isNonTemporal)
475     result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr());
476   if (alignment != 0)
477     result.addAttribute("alignment", builder.getI64IntegerAttr(alignment));
478 }
479 
printStoreOp(OpAsmPrinter & p,StoreOp & op)480 static void printStoreOp(OpAsmPrinter &p, StoreOp &op) {
481   p << op.getOperationName() << ' ';
482   if (op.volatile_())
483     p << "volatile ";
484   p << op.value() << ", " << op.addr();
485   p.printOptionalAttrDict(op->getAttrs(), {kVolatileAttrName});
486   p << " : " << op.addr().getType();
487 }
488 
489 // <operation> ::= `llvm.store` `volatile` ssa-use `,` ssa-use
490 //                 attribute-dict? `:` type
parseStoreOp(OpAsmParser & parser,OperationState & result)491 static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
492   OpAsmParser::OperandType addr, value;
493   Type type;
494   llvm::SMLoc trailingTypeLoc;
495 
496   if (succeeded(parser.parseOptionalKeyword("volatile")))
497     result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr());
498 
499   if (parser.parseOperand(value) || parser.parseComma() ||
500       parser.parseOperand(addr) ||
501       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
502       parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
503     return failure();
504 
505   Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc);
506   if (!elemTy)
507     return failure();
508 
509   if (parser.resolveOperand(value, elemTy, result.operands) ||
510       parser.resolveOperand(addr, type, result.operands))
511     return failure();
512 
513   return success();
514 }
515 
516 ///===---------------------------------------------------------------------===//
517 /// LLVM::InvokeOp
518 ///===---------------------------------------------------------------------===//
519 
520 Optional<MutableOperandRange>
getMutableSuccessorOperands(unsigned index)521 InvokeOp::getMutableSuccessorOperands(unsigned index) {
522   assert(index < getNumSuccessors() && "invalid successor index");
523   return index == 0 ? normalDestOperandsMutable() : unwindDestOperandsMutable();
524 }
525 
verify(InvokeOp op)526 static LogicalResult verify(InvokeOp op) {
527   if (op.getNumResults() > 1)
528     return op.emitOpError("must have 0 or 1 result");
529 
530   Block *unwindDest = op.unwindDest();
531   if (unwindDest->empty())
532     return op.emitError(
533         "must have at least one operation in unwind destination");
534 
535   // In unwind destination, first operation must be LandingpadOp
536   if (!isa<LandingpadOp>(unwindDest->front()))
537     return op.emitError("first operation in unwind destination should be a "
538                         "llvm.landingpad operation");
539 
540   return success();
541 }
542 
printInvokeOp(OpAsmPrinter & p,InvokeOp op)543 static void printInvokeOp(OpAsmPrinter &p, InvokeOp op) {
544   auto callee = op.callee();
545   bool isDirect = callee.hasValue();
546 
547   p << op.getOperationName() << ' ';
548 
549   // Either function name or pointer
550   if (isDirect)
551     p.printSymbolName(callee.getValue());
552   else
553     p << op.getOperand(0);
554 
555   p << '(' << op.getOperands().drop_front(isDirect ? 0 : 1) << ')';
556   p << " to ";
557   p.printSuccessorAndUseList(op.normalDest(), op.normalDestOperands());
558   p << " unwind ";
559   p.printSuccessorAndUseList(op.unwindDest(), op.unwindDestOperands());
560 
561   p.printOptionalAttrDict(op->getAttrs(),
562                           {InvokeOp::getOperandSegmentSizeAttr(), "callee"});
563   p << " : ";
564   p.printFunctionalType(
565       llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1),
566       op.getResultTypes());
567 }
568 
569 /// <operation> ::= `llvm.invoke` (function-id | ssa-use) `(` ssa-use-list `)`
570 ///                  `to` bb-id (`[` ssa-use-and-type-list `]`)?
571 ///                  `unwind` bb-id (`[` ssa-use-and-type-list `]`)?
572 ///                  attribute-dict? `:` function-type
parseInvokeOp(OpAsmParser & parser,OperationState & result)573 static ParseResult parseInvokeOp(OpAsmParser &parser, OperationState &result) {
574   SmallVector<OpAsmParser::OperandType, 8> operands;
575   FunctionType funcType;
576   SymbolRefAttr funcAttr;
577   llvm::SMLoc trailingTypeLoc;
578   Block *normalDest, *unwindDest;
579   SmallVector<Value, 4> normalOperands, unwindOperands;
580   Builder &builder = parser.getBuilder();
581 
582   // Parse an operand list that will, in practice, contain 0 or 1 operand.  In
583   // case of an indirect call, there will be 1 operand before `(`.  In case of a
584   // direct call, there will be no operands and the parser will stop at the
585   // function identifier without complaining.
586   if (parser.parseOperandList(operands))
587     return failure();
588   bool isDirect = operands.empty();
589 
590   // Optionally parse a function identifier.
591   if (isDirect && parser.parseAttribute(funcAttr, "callee", result.attributes))
592     return failure();
593 
594   if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
595       parser.parseKeyword("to") ||
596       parser.parseSuccessorAndUseList(normalDest, normalOperands) ||
597       parser.parseKeyword("unwind") ||
598       parser.parseSuccessorAndUseList(unwindDest, unwindOperands) ||
599       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
600       parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(funcType))
601     return failure();
602 
603   if (isDirect) {
604     // Make sure types match.
605     if (parser.resolveOperands(operands, funcType.getInputs(),
606                                parser.getNameLoc(), result.operands))
607       return failure();
608     result.addTypes(funcType.getResults());
609   } else {
610     // Construct the LLVM IR Dialect function type that the first operand
611     // should match.
612     if (funcType.getNumResults() > 1)
613       return parser.emitError(trailingTypeLoc,
614                               "expected function with 0 or 1 result");
615 
616     Type llvmResultType;
617     if (funcType.getNumResults() == 0) {
618       llvmResultType = LLVM::LLVMVoidType::get(builder.getContext());
619     } else {
620       llvmResultType = funcType.getResult(0);
621       if (!isCompatibleType(llvmResultType))
622         return parser.emitError(trailingTypeLoc,
623                                 "expected result to have LLVM type");
624     }
625 
626     SmallVector<Type, 8> argTypes;
627     argTypes.reserve(funcType.getNumInputs());
628     for (Type ty : funcType.getInputs()) {
629       if (isCompatibleType(ty))
630         argTypes.push_back(ty);
631       else
632         return parser.emitError(trailingTypeLoc,
633                                 "expected LLVM types as inputs");
634     }
635 
636     auto llvmFuncType = LLVM::LLVMFunctionType::get(llvmResultType, argTypes);
637     auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType);
638 
639     auto funcArguments = llvm::makeArrayRef(operands).drop_front();
640 
641     // Make sure that the first operand (indirect callee) matches the wrapped
642     // LLVM IR function type, and that the types of the other call operands
643     // match the types of the function arguments.
644     if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) ||
645         parser.resolveOperands(funcArguments, funcType.getInputs(),
646                                parser.getNameLoc(), result.operands))
647       return failure();
648 
649     result.addTypes(llvmResultType);
650   }
651   result.addSuccessors({normalDest, unwindDest});
652   result.addOperands(normalOperands);
653   result.addOperands(unwindOperands);
654 
655   result.addAttribute(
656       InvokeOp::getOperandSegmentSizeAttr(),
657       builder.getI32VectorAttr({static_cast<int32_t>(operands.size()),
658                                 static_cast<int32_t>(normalOperands.size()),
659                                 static_cast<int32_t>(unwindOperands.size())}));
660   return success();
661 }
662 
663 ///===----------------------------------------------------------------------===//
664 /// Verifying/Printing/Parsing for LLVM::LandingpadOp.
665 ///===----------------------------------------------------------------------===//
666 
verify(LandingpadOp op)667 static LogicalResult verify(LandingpadOp op) {
668   Value value;
669   if (LLVMFuncOp func = op->getParentOfType<LLVMFuncOp>()) {
670     if (!func.personality().hasValue())
671       return op.emitError(
672           "llvm.landingpad needs to be in a function with a personality");
673   }
674 
675   if (!op.cleanup() && op.getOperands().empty())
676     return op.emitError("landingpad instruction expects at least one clause or "
677                         "cleanup attribute");
678 
679   for (unsigned idx = 0, ie = op.getNumOperands(); idx < ie; idx++) {
680     value = op.getOperand(idx);
681     bool isFilter = value.getType().isa<LLVMArrayType>();
682     if (isFilter) {
683       // FIXME: Verify filter clauses when arrays are appropriately handled
684     } else {
685       // catch - global addresses only.
686       // Bitcast ops should have global addresses as their args.
687       if (auto bcOp = value.getDefiningOp<BitcastOp>()) {
688         if (auto addrOp = bcOp.arg().getDefiningOp<AddressOfOp>())
689           continue;
690         return op.emitError("constant clauses expected")
691                    .attachNote(bcOp.getLoc())
692                << "global addresses expected as operand to "
693                   "bitcast used in clauses for landingpad";
694       }
695       // NullOp and AddressOfOp allowed
696       if (value.getDefiningOp<NullOp>())
697         continue;
698       if (value.getDefiningOp<AddressOfOp>())
699         continue;
700       return op.emitError("clause #")
701              << idx << " is not a known constant - null, addressof, bitcast";
702     }
703   }
704   return success();
705 }
706 
printLandingpadOp(OpAsmPrinter & p,LandingpadOp & op)707 static void printLandingpadOp(OpAsmPrinter &p, LandingpadOp &op) {
708   p << op.getOperationName() << (op.cleanup() ? " cleanup " : " ");
709 
710   // Clauses
711   for (auto value : op.getOperands()) {
712     // Similar to llvm - if clause is an array type then it is filter
713     // clause else catch clause
714     bool isArrayTy = value.getType().isa<LLVMArrayType>();
715     p << '(' << (isArrayTy ? "filter " : "catch ") << value << " : "
716       << value.getType() << ") ";
717   }
718 
719   p.printOptionalAttrDict(op->getAttrs(), {"cleanup"});
720 
721   p << ": " << op.getType();
722 }
723 
724 /// <operation> ::= `llvm.landingpad` `cleanup`?
725 ///                 ((`catch` | `filter`) operand-type ssa-use)* attribute-dict?
parseLandingpadOp(OpAsmParser & parser,OperationState & result)726 static ParseResult parseLandingpadOp(OpAsmParser &parser,
727                                      OperationState &result) {
728   // Check for cleanup
729   if (succeeded(parser.parseOptionalKeyword("cleanup")))
730     result.addAttribute("cleanup", parser.getBuilder().getUnitAttr());
731 
732   // Parse clauses with types
733   while (succeeded(parser.parseOptionalLParen()) &&
734          (succeeded(parser.parseOptionalKeyword("filter")) ||
735           succeeded(parser.parseOptionalKeyword("catch")))) {
736     OpAsmParser::OperandType operand;
737     Type ty;
738     if (parser.parseOperand(operand) || parser.parseColon() ||
739         parser.parseType(ty) ||
740         parser.resolveOperand(operand, ty, result.operands) ||
741         parser.parseRParen())
742       return failure();
743   }
744 
745   Type type;
746   if (parser.parseColon() || parser.parseType(type))
747     return failure();
748 
749   result.addTypes(type);
750   return success();
751 }
752 
753 //===----------------------------------------------------------------------===//
754 // Verifying/Printing/parsing for LLVM::CallOp.
755 //===----------------------------------------------------------------------===//
756 
verify(CallOp & op)757 static LogicalResult verify(CallOp &op) {
758   if (op.getNumResults() > 1)
759     return op.emitOpError("must have 0 or 1 result");
760 
761   // Type for the callee, we'll get it differently depending if it is a direct
762   // or indirect call.
763   Type fnType;
764 
765   bool isIndirect = false;
766 
767   // If this is an indirect call, the callee attribute is missing.
768   Optional<StringRef> calleeName = op.callee();
769   if (!calleeName) {
770     isIndirect = true;
771     if (!op.getNumOperands())
772       return op.emitOpError(
773           "must have either a `callee` attribute or at least an operand");
774     auto ptrType = op.getOperand(0).getType().dyn_cast<LLVMPointerType>();
775     if (!ptrType)
776       return op.emitOpError("indirect call expects a pointer as callee: ")
777              << ptrType;
778     fnType = ptrType.getElementType();
779   } else {
780     Operation *callee = SymbolTable::lookupNearestSymbolFrom(op, *calleeName);
781     if (!callee)
782       return op.emitOpError()
783              << "'" << *calleeName
784              << "' does not reference a symbol in the current scope";
785     auto fn = dyn_cast<LLVMFuncOp>(callee);
786     if (!fn)
787       return op.emitOpError() << "'" << *calleeName
788                               << "' does not reference a valid LLVM function";
789 
790     fnType = fn.getType();
791   }
792 
793   LLVMFunctionType funcType = fnType.dyn_cast<LLVMFunctionType>();
794   if (!funcType)
795     return op.emitOpError("callee does not have a functional type: ") << fnType;
796 
797   // Verify that the operand and result types match the callee.
798 
799   if (!funcType.isVarArg() &&
800       funcType.getNumParams() != (op.getNumOperands() - isIndirect))
801     return op.emitOpError()
802            << "incorrect number of operands ("
803            << (op.getNumOperands() - isIndirect)
804            << ") for callee (expecting: " << funcType.getNumParams() << ")";
805 
806   if (funcType.getNumParams() > (op.getNumOperands() - isIndirect))
807     return op.emitOpError() << "incorrect number of operands ("
808                             << (op.getNumOperands() - isIndirect)
809                             << ") for varargs callee (expecting at least: "
810                             << funcType.getNumParams() << ")";
811 
812   for (unsigned i = 0, e = funcType.getNumParams(); i != e; ++i)
813     if (op.getOperand(i + isIndirect).getType() != funcType.getParamType(i))
814       return op.emitOpError() << "operand type mismatch for operand " << i
815                               << ": " << op.getOperand(i + isIndirect).getType()
816                               << " != " << funcType.getParamType(i);
817 
818   if (op.getNumResults() &&
819       op.getResult(0).getType() != funcType.getReturnType())
820     return op.emitOpError()
821            << "result type mismatch: " << op.getResult(0).getType()
822            << " != " << funcType.getReturnType();
823 
824   return success();
825 }
826 
printCallOp(OpAsmPrinter & p,CallOp & op)827 static void printCallOp(OpAsmPrinter &p, CallOp &op) {
828   auto callee = op.callee();
829   bool isDirect = callee.hasValue();
830 
831   // Print the direct callee if present as a function attribute, or an indirect
832   // callee (first operand) otherwise.
833   p << op.getOperationName() << ' ';
834   if (isDirect)
835     p.printSymbolName(callee.getValue());
836   else
837     p << op.getOperand(0);
838 
839   auto args = op.getOperands().drop_front(isDirect ? 0 : 1);
840   p << '(' << args << ')';
841   p.printOptionalAttrDict(processFMFAttr(op->getAttrs()), {"callee"});
842 
843   // Reconstruct the function MLIR function type from operand and result types.
844   p << " : "
845     << FunctionType::get(op.getContext(), args.getTypes(), op.getResultTypes());
846 }
847 
848 // <operation> ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)`
849 //                 attribute-dict? `:` function-type
parseCallOp(OpAsmParser & parser,OperationState & result)850 static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) {
851   SmallVector<OpAsmParser::OperandType, 8> operands;
852   Type type;
853   SymbolRefAttr funcAttr;
854   llvm::SMLoc trailingTypeLoc;
855 
856   // Parse an operand list that will, in practice, contain 0 or 1 operand.  In
857   // case of an indirect call, there will be 1 operand before `(`.  In case of a
858   // direct call, there will be no operands and the parser will stop at the
859   // function identifier without complaining.
860   if (parser.parseOperandList(operands))
861     return failure();
862   bool isDirect = operands.empty();
863 
864   // Optionally parse a function identifier.
865   if (isDirect)
866     if (parser.parseAttribute(funcAttr, "callee", result.attributes))
867       return failure();
868 
869   if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
870       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
871       parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
872     return failure();
873 
874   auto funcType = type.dyn_cast<FunctionType>();
875   if (!funcType)
876     return parser.emitError(trailingTypeLoc, "expected function type");
877   if (isDirect) {
878     // Make sure types match.
879     if (parser.resolveOperands(operands, funcType.getInputs(),
880                                parser.getNameLoc(), result.operands))
881       return failure();
882     result.addTypes(funcType.getResults());
883   } else {
884     // Construct the LLVM IR Dialect function type that the first operand
885     // should match.
886     if (funcType.getNumResults() > 1)
887       return parser.emitError(trailingTypeLoc,
888                               "expected function with 0 or 1 result");
889 
890     Builder &builder = parser.getBuilder();
891     Type llvmResultType;
892     if (funcType.getNumResults() == 0) {
893       llvmResultType = LLVM::LLVMVoidType::get(builder.getContext());
894     } else {
895       llvmResultType = funcType.getResult(0);
896       if (!isCompatibleType(llvmResultType))
897         return parser.emitError(trailingTypeLoc,
898                                 "expected result to have LLVM type");
899     }
900 
901     SmallVector<Type, 8> argTypes;
902     argTypes.reserve(funcType.getNumInputs());
903     for (int i = 0, e = funcType.getNumInputs(); i < e; ++i) {
904       auto argType = funcType.getInput(i);
905       if (!isCompatibleType(argType))
906         return parser.emitError(trailingTypeLoc,
907                                 "expected LLVM types as inputs");
908       argTypes.push_back(argType);
909     }
910     auto llvmFuncType = LLVM::LLVMFunctionType::get(llvmResultType, argTypes);
911     auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType);
912 
913     auto funcArguments =
914         ArrayRef<OpAsmParser::OperandType>(operands).drop_front();
915 
916     // Make sure that the first operand (indirect callee) matches the wrapped
917     // LLVM IR function type, and that the types of the other call operands
918     // match the types of the function arguments.
919     if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) ||
920         parser.resolveOperands(funcArguments, funcType.getInputs(),
921                                parser.getNameLoc(), result.operands))
922       return failure();
923 
924     result.addTypes(llvmResultType);
925   }
926 
927   return success();
928 }
929 
930 //===----------------------------------------------------------------------===//
931 // Printing/parsing for LLVM::ExtractElementOp.
932 //===----------------------------------------------------------------------===//
933 // Expects vector to be of wrapped LLVM vector type and position to be of
934 // wrapped LLVM i32 type.
build(OpBuilder & b,OperationState & result,Value vector,Value position,ArrayRef<NamedAttribute> attrs)935 void LLVM::ExtractElementOp::build(OpBuilder &b, OperationState &result,
936                                    Value vector, Value position,
937                                    ArrayRef<NamedAttribute> attrs) {
938   auto vectorType = vector.getType();
939   auto llvmType = LLVM::getVectorElementType(vectorType);
940   build(b, result, llvmType, vector, position);
941   result.addAttributes(attrs);
942 }
943 
printExtractElementOp(OpAsmPrinter & p,ExtractElementOp & op)944 static void printExtractElementOp(OpAsmPrinter &p, ExtractElementOp &op) {
945   p << op.getOperationName() << ' ' << op.vector() << "[" << op.position()
946     << " : " << op.position().getType() << "]";
947   p.printOptionalAttrDict(op->getAttrs());
948   p << " : " << op.vector().getType();
949 }
950 
951 // <operation> ::= `llvm.extractelement` ssa-use `, ` ssa-use
952 //                 attribute-dict? `:` type
parseExtractElementOp(OpAsmParser & parser,OperationState & result)953 static ParseResult parseExtractElementOp(OpAsmParser &parser,
954                                          OperationState &result) {
955   llvm::SMLoc loc;
956   OpAsmParser::OperandType vector, position;
957   Type type, positionType;
958   if (parser.getCurrentLocation(&loc) || parser.parseOperand(vector) ||
959       parser.parseLSquare() || parser.parseOperand(position) ||
960       parser.parseColonType(positionType) || parser.parseRSquare() ||
961       parser.parseOptionalAttrDict(result.attributes) ||
962       parser.parseColonType(type) ||
963       parser.resolveOperand(vector, type, result.operands) ||
964       parser.resolveOperand(position, positionType, result.operands))
965     return failure();
966   if (!LLVM::isCompatibleVectorType(type))
967     return parser.emitError(
968         loc, "expected LLVM dialect-compatible vector type for operand #1");
969   result.addTypes(LLVM::getVectorElementType(type));
970   return success();
971 }
972 
verify(ExtractElementOp op)973 static LogicalResult verify(ExtractElementOp op) {
974   Type vectorType = op.vector().getType();
975   if (!LLVM::isCompatibleVectorType(vectorType))
976     return op->emitOpError("expected LLVM dialect-compatible vector type for "
977                            "operand #1, got")
978            << vectorType;
979   Type valueType = LLVM::getVectorElementType(vectorType);
980   if (valueType != op.res().getType())
981     return op.emitOpError() << "Type mismatch: extracting from " << vectorType
982                             << " should produce " << valueType
983                             << " but this op returns " << op.res().getType();
984   return success();
985 }
986 
987 //===----------------------------------------------------------------------===//
988 // Printing/parsing for LLVM::ExtractValueOp.
989 //===----------------------------------------------------------------------===//
990 
printExtractValueOp(OpAsmPrinter & p,ExtractValueOp & op)991 static void printExtractValueOp(OpAsmPrinter &p, ExtractValueOp &op) {
992   p << op.getOperationName() << ' ' << op.container() << op.position();
993   p.printOptionalAttrDict(op->getAttrs(), {"position"});
994   p << " : " << op.container().getType();
995 }
996 
997 // Extract the type at `position` in the wrapped LLVM IR aggregate type
998 // `containerType`.  Position is an integer array attribute where each value
999 // is a zero-based position of the element in the aggregate type.  Return the
1000 // resulting type wrapped in MLIR, or nullptr on error.
getInsertExtractValueElementType(OpAsmParser & parser,Type containerType,ArrayAttr positionAttr,llvm::SMLoc attributeLoc,llvm::SMLoc typeLoc)1001 static Type getInsertExtractValueElementType(OpAsmParser &parser,
1002                                              Type containerType,
1003                                              ArrayAttr positionAttr,
1004                                              llvm::SMLoc attributeLoc,
1005                                              llvm::SMLoc typeLoc) {
1006   Type llvmType = containerType;
1007   if (!isCompatibleType(containerType))
1008     return parser.emitError(typeLoc, "expected LLVM IR Dialect type"), nullptr;
1009 
1010   // Infer the element type from the structure type: iteratively step inside the
1011   // type by taking the element type, indexed by the position attribute for
1012   // structures.  Check the position index before accessing, it is supposed to
1013   // be in bounds.
1014   for (Attribute subAttr : positionAttr) {
1015     auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>();
1016     if (!positionElementAttr)
1017       return parser.emitError(attributeLoc,
1018                               "expected an array of integer literals"),
1019              nullptr;
1020     int position = positionElementAttr.getInt();
1021     if (auto arrayType = llvmType.dyn_cast<LLVMArrayType>()) {
1022       if (position < 0 ||
1023           static_cast<unsigned>(position) >= arrayType.getNumElements())
1024         return parser.emitError(attributeLoc, "position out of bounds"),
1025                nullptr;
1026       llvmType = arrayType.getElementType();
1027     } else if (auto structType = llvmType.dyn_cast<LLVMStructType>()) {
1028       if (position < 0 ||
1029           static_cast<unsigned>(position) >= structType.getBody().size())
1030         return parser.emitError(attributeLoc, "position out of bounds"),
1031                nullptr;
1032       llvmType = structType.getBody()[position];
1033     } else {
1034       return parser.emitError(typeLoc, "expected LLVM IR structure/array type"),
1035              nullptr;
1036     }
1037   }
1038   return llvmType;
1039 }
1040 
1041 // Extract the type at `position` in the wrapped LLVM IR aggregate type
1042 // `containerType`. Returns null on failure.
getInsertExtractValueElementType(Type containerType,ArrayAttr positionAttr,Operation * op)1043 static Type getInsertExtractValueElementType(Type containerType,
1044                                              ArrayAttr positionAttr,
1045                                              Operation *op) {
1046   Type llvmType = containerType;
1047   if (!isCompatibleType(containerType)) {
1048     op->emitError("expected LLVM IR Dialect type, got ") << containerType;
1049     return {};
1050   }
1051 
1052   // Infer the element type from the structure type: iteratively step inside the
1053   // type by taking the element type, indexed by the position attribute for
1054   // structures.  Check the position index before accessing, it is supposed to
1055   // be in bounds.
1056   for (Attribute subAttr : positionAttr) {
1057     auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>();
1058     if (!positionElementAttr) {
1059       op->emitOpError("expected an array of integer literals, got: ")
1060           << subAttr;
1061       return {};
1062     }
1063     int position = positionElementAttr.getInt();
1064     if (auto arrayType = llvmType.dyn_cast<LLVMArrayType>()) {
1065       if (position < 0 ||
1066           static_cast<unsigned>(position) >= arrayType.getNumElements()) {
1067         op->emitOpError("position out of bounds: ") << position;
1068         return {};
1069       }
1070       llvmType = arrayType.getElementType();
1071     } else if (auto structType = llvmType.dyn_cast<LLVMStructType>()) {
1072       if (position < 0 ||
1073           static_cast<unsigned>(position) >= structType.getBody().size()) {
1074         op->emitOpError("position out of bounds") << position;
1075         return {};
1076       }
1077       llvmType = structType.getBody()[position];
1078     } else {
1079       op->emitOpError("expected LLVM IR structure/array type, got: ")
1080           << llvmType;
1081       return {};
1082     }
1083   }
1084   return llvmType;
1085 }
1086 
1087 // <operation> ::= `llvm.extractvalue` ssa-use
1088 //                 `[` integer-literal (`,` integer-literal)* `]`
1089 //                 attribute-dict? `:` type
parseExtractValueOp(OpAsmParser & parser,OperationState & result)1090 static ParseResult parseExtractValueOp(OpAsmParser &parser,
1091                                        OperationState &result) {
1092   OpAsmParser::OperandType container;
1093   Type containerType;
1094   ArrayAttr positionAttr;
1095   llvm::SMLoc attributeLoc, trailingTypeLoc;
1096 
1097   if (parser.parseOperand(container) ||
1098       parser.getCurrentLocation(&attributeLoc) ||
1099       parser.parseAttribute(positionAttr, "position", result.attributes) ||
1100       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1101       parser.getCurrentLocation(&trailingTypeLoc) ||
1102       parser.parseType(containerType) ||
1103       parser.resolveOperand(container, containerType, result.operands))
1104     return failure();
1105 
1106   auto elementType = getInsertExtractValueElementType(
1107       parser, containerType, positionAttr, attributeLoc, trailingTypeLoc);
1108   if (!elementType)
1109     return failure();
1110 
1111   result.addTypes(elementType);
1112   return success();
1113 }
1114 
fold(ArrayRef<Attribute> operands)1115 OpFoldResult LLVM::ExtractValueOp::fold(ArrayRef<Attribute> operands) {
1116   auto insertValueOp = container().getDefiningOp<InsertValueOp>();
1117   while (insertValueOp) {
1118     if (position() == insertValueOp.position())
1119       return insertValueOp.value();
1120     insertValueOp = insertValueOp.container().getDefiningOp<InsertValueOp>();
1121   }
1122   return {};
1123 }
1124 
verify(ExtractValueOp op)1125 static LogicalResult verify(ExtractValueOp op) {
1126   Type valueType = getInsertExtractValueElementType(op.container().getType(),
1127                                                     op.positionAttr(), op);
1128   if (!valueType)
1129     return failure();
1130 
1131   if (op.res().getType() != valueType)
1132     return op.emitOpError()
1133            << "Type mismatch: extracting from " << op.container().getType()
1134            << " should produce " << valueType << " but this op returns "
1135            << op.res().getType();
1136   return success();
1137 }
1138 
1139 //===----------------------------------------------------------------------===//
1140 // Printing/parsing for LLVM::InsertElementOp.
1141 //===----------------------------------------------------------------------===//
1142 
printInsertElementOp(OpAsmPrinter & p,InsertElementOp & op)1143 static void printInsertElementOp(OpAsmPrinter &p, InsertElementOp &op) {
1144   p << op.getOperationName() << ' ' << op.value() << ", " << op.vector() << "["
1145     << op.position() << " : " << op.position().getType() << "]";
1146   p.printOptionalAttrDict(op->getAttrs());
1147   p << " : " << op.vector().getType();
1148 }
1149 
1150 // <operation> ::= `llvm.insertelement` ssa-use `,` ssa-use `,` ssa-use
1151 //                 attribute-dict? `:` type
parseInsertElementOp(OpAsmParser & parser,OperationState & result)1152 static ParseResult parseInsertElementOp(OpAsmParser &parser,
1153                                         OperationState &result) {
1154   llvm::SMLoc loc;
1155   OpAsmParser::OperandType vector, value, position;
1156   Type vectorType, positionType;
1157   if (parser.getCurrentLocation(&loc) || parser.parseOperand(value) ||
1158       parser.parseComma() || parser.parseOperand(vector) ||
1159       parser.parseLSquare() || parser.parseOperand(position) ||
1160       parser.parseColonType(positionType) || parser.parseRSquare() ||
1161       parser.parseOptionalAttrDict(result.attributes) ||
1162       parser.parseColonType(vectorType))
1163     return failure();
1164 
1165   if (!LLVM::isCompatibleVectorType(vectorType))
1166     return parser.emitError(
1167         loc, "expected LLVM dialect-compatible vector type for operand #1");
1168   Type valueType = LLVM::getVectorElementType(vectorType);
1169   if (!valueType)
1170     return failure();
1171 
1172   if (parser.resolveOperand(vector, vectorType, result.operands) ||
1173       parser.resolveOperand(value, valueType, result.operands) ||
1174       parser.resolveOperand(position, positionType, result.operands))
1175     return failure();
1176 
1177   result.addTypes(vectorType);
1178   return success();
1179 }
1180 
verify(InsertElementOp op)1181 static LogicalResult verify(InsertElementOp op) {
1182   Type valueType = LLVM::getVectorElementType(op.vector().getType());
1183   if (valueType != op.value().getType())
1184     return op.emitOpError()
1185            << "Type mismatch: cannot insert " << op.value().getType()
1186            << " into " << op.vector().getType();
1187   return success();
1188 }
1189 //===----------------------------------------------------------------------===//
1190 // Printing/parsing for LLVM::InsertValueOp.
1191 //===----------------------------------------------------------------------===//
1192 
printInsertValueOp(OpAsmPrinter & p,InsertValueOp & op)1193 static void printInsertValueOp(OpAsmPrinter &p, InsertValueOp &op) {
1194   p << op.getOperationName() << ' ' << op.value() << ", " << op.container()
1195     << op.position();
1196   p.printOptionalAttrDict(op->getAttrs(), {"position"});
1197   p << " : " << op.container().getType();
1198 }
1199 
1200 // <operation> ::= `llvm.insertvaluevalue` ssa-use `,` ssa-use
1201 //                 `[` integer-literal (`,` integer-literal)* `]`
1202 //                 attribute-dict? `:` type
parseInsertValueOp(OpAsmParser & parser,OperationState & result)1203 static ParseResult parseInsertValueOp(OpAsmParser &parser,
1204                                       OperationState &result) {
1205   OpAsmParser::OperandType container, value;
1206   Type containerType;
1207   ArrayAttr positionAttr;
1208   llvm::SMLoc attributeLoc, trailingTypeLoc;
1209 
1210   if (parser.parseOperand(value) || parser.parseComma() ||
1211       parser.parseOperand(container) ||
1212       parser.getCurrentLocation(&attributeLoc) ||
1213       parser.parseAttribute(positionAttr, "position", result.attributes) ||
1214       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1215       parser.getCurrentLocation(&trailingTypeLoc) ||
1216       parser.parseType(containerType))
1217     return failure();
1218 
1219   auto valueType = getInsertExtractValueElementType(
1220       parser, containerType, positionAttr, attributeLoc, trailingTypeLoc);
1221   if (!valueType)
1222     return failure();
1223 
1224   if (parser.resolveOperand(container, containerType, result.operands) ||
1225       parser.resolveOperand(value, valueType, result.operands))
1226     return failure();
1227 
1228   result.addTypes(containerType);
1229   return success();
1230 }
1231 
verify(InsertValueOp op)1232 static LogicalResult verify(InsertValueOp op) {
1233   Type valueType = getInsertExtractValueElementType(op.container().getType(),
1234                                                     op.positionAttr(), op);
1235   if (!valueType)
1236     return failure();
1237 
1238   if (op.value().getType() != valueType)
1239     return op.emitOpError()
1240            << "Type mismatch: cannot insert " << op.value().getType()
1241            << " into " << op.container().getType();
1242 
1243   return success();
1244 }
1245 
1246 //===----------------------------------------------------------------------===//
1247 // Printing, parsing and verification for LLVM::ReturnOp.
1248 //===----------------------------------------------------------------------===//
1249 
printReturnOp(OpAsmPrinter & p,ReturnOp op)1250 static void printReturnOp(OpAsmPrinter &p, ReturnOp op) {
1251   p << op.getOperationName();
1252   p.printOptionalAttrDict(op->getAttrs());
1253   assert(op.getNumOperands() <= 1);
1254 
1255   if (op.getNumOperands() == 0)
1256     return;
1257 
1258   p << ' ' << op.getOperand(0) << " : " << op.getOperand(0).getType();
1259 }
1260 
1261 // <operation> ::= `llvm.return` ssa-use-list attribute-dict? `:`
1262 //                 type-list-no-parens
parseReturnOp(OpAsmParser & parser,OperationState & result)1263 static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &result) {
1264   SmallVector<OpAsmParser::OperandType, 1> operands;
1265   Type type;
1266 
1267   if (parser.parseOperandList(operands) ||
1268       parser.parseOptionalAttrDict(result.attributes))
1269     return failure();
1270   if (operands.empty())
1271     return success();
1272 
1273   if (parser.parseColonType(type) ||
1274       parser.resolveOperand(operands[0], type, result.operands))
1275     return failure();
1276   return success();
1277 }
1278 
verify(ReturnOp op)1279 static LogicalResult verify(ReturnOp op) {
1280   if (op->getNumOperands() > 1)
1281     return op->emitOpError("expected at most 1 operand");
1282 
1283   if (auto parent = op->getParentOfType<LLVMFuncOp>()) {
1284     Type expectedType = parent.getType().getReturnType();
1285     if (expectedType.isa<LLVMVoidType>()) {
1286       if (op->getNumOperands() == 0)
1287         return success();
1288       InFlightDiagnostic diag = op->emitOpError("expected no operands");
1289       diag.attachNote(parent->getLoc()) << "when returning from function";
1290       return diag;
1291     }
1292     if (op->getNumOperands() == 0) {
1293       if (expectedType.isa<LLVMVoidType>())
1294         return success();
1295       InFlightDiagnostic diag = op->emitOpError("expected 1 operand");
1296       diag.attachNote(parent->getLoc()) << "when returning from function";
1297       return diag;
1298     }
1299     if (expectedType != op->getOperand(0).getType()) {
1300       InFlightDiagnostic diag = op->emitOpError("mismatching result types");
1301       diag.attachNote(parent->getLoc()) << "when returning from function";
1302       return diag;
1303     }
1304   }
1305   return success();
1306 }
1307 
1308 //===----------------------------------------------------------------------===//
1309 // Verifier for LLVM::AddressOfOp.
1310 //===----------------------------------------------------------------------===//
1311 
1312 template <typename OpTy>
lookupSymbolInModule(Operation * parent,StringRef name)1313 static OpTy lookupSymbolInModule(Operation *parent, StringRef name) {
1314   Operation *module = parent;
1315   while (module && !satisfiesLLVMModule(module))
1316     module = module->getParentOp();
1317   assert(module && "unexpected operation outside of a module");
1318   return dyn_cast_or_null<OpTy>(
1319       mlir::SymbolTable::lookupSymbolIn(module, name));
1320 }
1321 
getGlobal()1322 GlobalOp AddressOfOp::getGlobal() {
1323   return lookupSymbolInModule<LLVM::GlobalOp>((*this)->getParentOp(),
1324                                               global_name());
1325 }
1326 
getFunction()1327 LLVMFuncOp AddressOfOp::getFunction() {
1328   return lookupSymbolInModule<LLVM::LLVMFuncOp>((*this)->getParentOp(),
1329                                                 global_name());
1330 }
1331 
verify(AddressOfOp op)1332 static LogicalResult verify(AddressOfOp op) {
1333   auto global = op.getGlobal();
1334   auto function = op.getFunction();
1335   if (!global && !function)
1336     return op.emitOpError(
1337         "must reference a global defined by 'llvm.mlir.global' or 'llvm.func'");
1338 
1339   if (global &&
1340       LLVM::LLVMPointerType::get(global.getType(), global.addr_space()) !=
1341           op.getResult().getType())
1342     return op.emitOpError(
1343         "the type must be a pointer to the type of the referenced global");
1344 
1345   if (function && LLVM::LLVMPointerType::get(function.getType()) !=
1346                       op.getResult().getType())
1347     return op.emitOpError(
1348         "the type must be a pointer to the type of the referenced function");
1349 
1350   return success();
1351 }
1352 
1353 //===----------------------------------------------------------------------===//
1354 // Builder, printer and verifier for LLVM::GlobalOp.
1355 //===----------------------------------------------------------------------===//
1356 
1357 /// Returns the name used for the linkage attribute. This *must* correspond to
1358 /// the name of the attribute in ODS.
getLinkageAttrName()1359 static StringRef getLinkageAttrName() { return "linkage"; }
1360 
1361 /// Returns the name used for the unnamed_addr attribute. This *must* correspond
1362 /// to the name of the attribute in ODS.
getUnnamedAddrAttrName()1363 static StringRef getUnnamedAddrAttrName() { return "unnamed_addr"; }
1364 
build(OpBuilder & builder,OperationState & result,Type type,bool isConstant,Linkage linkage,StringRef name,Attribute value,uint64_t alignment,unsigned addrSpace,bool dsoLocal,ArrayRef<NamedAttribute> attrs)1365 void GlobalOp::build(OpBuilder &builder, OperationState &result, Type type,
1366                      bool isConstant, Linkage linkage, StringRef name,
1367                      Attribute value, uint64_t alignment, unsigned addrSpace,
1368                      bool dsoLocal, ArrayRef<NamedAttribute> attrs) {
1369   result.addAttribute(SymbolTable::getSymbolAttrName(),
1370                       builder.getStringAttr(name));
1371   result.addAttribute("type", TypeAttr::get(type));
1372   if (isConstant)
1373     result.addAttribute("constant", builder.getUnitAttr());
1374   if (value)
1375     result.addAttribute("value", value);
1376   if (dsoLocal)
1377     result.addAttribute("dso_local", builder.getUnitAttr());
1378 
1379   // Only add an alignment attribute if the "alignment" input
1380   // is different from 0. The value must also be a power of two, but
1381   // this is tested in GlobalOp::verify, not here.
1382   if (alignment != 0)
1383     result.addAttribute("alignment", builder.getI64IntegerAttr(alignment));
1384 
1385   result.addAttribute(getLinkageAttrName(),
1386                       builder.getI64IntegerAttr(static_cast<int64_t>(linkage)));
1387   if (addrSpace != 0)
1388     result.addAttribute("addr_space", builder.getI32IntegerAttr(addrSpace));
1389   result.attributes.append(attrs.begin(), attrs.end());
1390   result.addRegion();
1391 }
1392 
printGlobalOp(OpAsmPrinter & p,GlobalOp op)1393 static void printGlobalOp(OpAsmPrinter &p, GlobalOp op) {
1394   p << op.getOperationName() << ' ' << stringifyLinkage(op.linkage()) << ' ';
1395   if (op.unnamed_addr())
1396     p << stringifyUnnamedAddr(*op.unnamed_addr()) << ' ';
1397   if (op.constant())
1398     p << "constant ";
1399   p.printSymbolName(op.sym_name());
1400   p << '(';
1401   if (auto value = op.getValueOrNull())
1402     p.printAttribute(value);
1403   p << ')';
1404   // Note that the alignment attribute is printed using the
1405   // default syntax here, even though it is an inherent attribute
1406   // (as defined in https://mlir.llvm.org/docs/LangRef/#attributes)
1407   p.printOptionalAttrDict(op->getAttrs(),
1408                           {SymbolTable::getSymbolAttrName(), "type", "constant",
1409                            "value", getLinkageAttrName(),
1410                            getUnnamedAddrAttrName()});
1411 
1412   // Print the trailing type unless it's a string global.
1413   if (op.getValueOrNull().dyn_cast_or_null<StringAttr>())
1414     return;
1415   p << " : " << op.type();
1416 
1417   Region &initializer = op.getInitializerRegion();
1418   if (!initializer.empty())
1419     p.printRegion(initializer, /*printEntryBlockArgs=*/false);
1420 }
1421 
1422 // Parses one of the keywords provided in the list `keywords` and returns the
1423 // position of the parsed keyword in the list. If none of the keywords from the
1424 // list is parsed, returns -1.
parseOptionalKeywordAlternative(OpAsmParser & parser,ArrayRef<StringRef> keywords)1425 static int parseOptionalKeywordAlternative(OpAsmParser &parser,
1426                                            ArrayRef<StringRef> keywords) {
1427   for (auto en : llvm::enumerate(keywords)) {
1428     if (succeeded(parser.parseOptionalKeyword(en.value())))
1429       return en.index();
1430   }
1431   return -1;
1432 }
1433 
1434 namespace {
1435 template <typename Ty>
1436 struct EnumTraits {};
1437 
1438 #define REGISTER_ENUM_TYPE(Ty)                                                 \
1439   template <>                                                                  \
1440   struct EnumTraits<Ty> {                                                      \
1441     static StringRef stringify(Ty value) { return stringify##Ty(value); }      \
1442     static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); }         \
1443   }
1444 
1445 REGISTER_ENUM_TYPE(Linkage);
1446 REGISTER_ENUM_TYPE(UnnamedAddr);
1447 } // end namespace
1448 
1449 template <typename EnumTy>
parseOptionalLLVMKeyword(OpAsmParser & parser,OperationState & result,StringRef name)1450 static ParseResult parseOptionalLLVMKeyword(OpAsmParser &parser,
1451                                             OperationState &result,
1452                                             StringRef name) {
1453   SmallVector<StringRef, 10> names;
1454   for (unsigned i = 0, e = getMaxEnumValForLinkage(); i <= e; ++i)
1455     names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
1456 
1457   int index = parseOptionalKeywordAlternative(parser, names);
1458   if (index == -1)
1459     return failure();
1460   result.addAttribute(name, parser.getBuilder().getI64IntegerAttr(index));
1461   return success();
1462 }
1463 
1464 // operation ::= `llvm.mlir.global` linkage? `constant`? `@` identifier
1465 //               `(` attribute? `)` align? attribute-list? (`:` type)? region?
1466 // align     ::= `align` `=` UINT64
1467 //
1468 // The type can be omitted for string attributes, in which case it will be
1469 // inferred from the value of the string as [strlen(value) x i8].
parseGlobalOp(OpAsmParser & parser,OperationState & result)1470 static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) {
1471   if (failed(parseOptionalLLVMKeyword<Linkage>(parser, result,
1472                                                getLinkageAttrName())))
1473     result.addAttribute(getLinkageAttrName(),
1474                         parser.getBuilder().getI64IntegerAttr(
1475                             static_cast<int64_t>(LLVM::Linkage::External)));
1476 
1477   if (failed(parseOptionalLLVMKeyword<UnnamedAddr>(parser, result,
1478                                                    getUnnamedAddrAttrName())))
1479     result.addAttribute(getUnnamedAddrAttrName(),
1480                         parser.getBuilder().getI64IntegerAttr(
1481                             static_cast<int64_t>(LLVM::UnnamedAddr::None)));
1482 
1483   if (succeeded(parser.parseOptionalKeyword("constant")))
1484     result.addAttribute("constant", parser.getBuilder().getUnitAttr());
1485 
1486   StringAttr name;
1487   if (parser.parseSymbolName(name, SymbolTable::getSymbolAttrName(),
1488                              result.attributes) ||
1489       parser.parseLParen())
1490     return failure();
1491 
1492   Attribute value;
1493   if (parser.parseOptionalRParen()) {
1494     if (parser.parseAttribute(value, "value", result.attributes) ||
1495         parser.parseRParen())
1496       return failure();
1497   }
1498 
1499   SmallVector<Type, 1> types;
1500   if (parser.parseOptionalAttrDict(result.attributes) ||
1501       parser.parseOptionalColonTypeList(types))
1502     return failure();
1503 
1504   if (types.size() > 1)
1505     return parser.emitError(parser.getNameLoc(), "expected zero or one type");
1506 
1507   Region &initRegion = *result.addRegion();
1508   if (types.empty()) {
1509     if (auto strAttr = value.dyn_cast_or_null<StringAttr>()) {
1510       MLIRContext *context = parser.getBuilder().getContext();
1511       auto arrayType = LLVM::LLVMArrayType::get(IntegerType::get(context, 8),
1512                                                 strAttr.getValue().size());
1513       types.push_back(arrayType);
1514     } else {
1515       return parser.emitError(parser.getNameLoc(),
1516                               "type can only be omitted for string globals");
1517     }
1518   } else {
1519     OptionalParseResult parseResult =
1520         parser.parseOptionalRegion(initRegion, /*arguments=*/{},
1521                                    /*argTypes=*/{});
1522     if (parseResult.hasValue() && failed(*parseResult))
1523       return failure();
1524   }
1525 
1526   result.addAttribute("type", TypeAttr::get(types[0]));
1527   return success();
1528 }
1529 
isZeroAttribute(Attribute value)1530 static bool isZeroAttribute(Attribute value) {
1531   if (auto intValue = value.dyn_cast<IntegerAttr>())
1532     return intValue.getValue().isNullValue();
1533   if (auto fpValue = value.dyn_cast<FloatAttr>())
1534     return fpValue.getValue().isZero();
1535   if (auto splatValue = value.dyn_cast<SplatElementsAttr>())
1536     return isZeroAttribute(splatValue.getSplatValue());
1537   if (auto elementsValue = value.dyn_cast<ElementsAttr>())
1538     return llvm::all_of(elementsValue.getValues<Attribute>(), isZeroAttribute);
1539   if (auto arrayValue = value.dyn_cast<ArrayAttr>())
1540     return llvm::all_of(arrayValue.getValue(), isZeroAttribute);
1541   return false;
1542 }
1543 
verify(GlobalOp op)1544 static LogicalResult verify(GlobalOp op) {
1545   if (!LLVMPointerType::isValidElementType(op.getType()))
1546     return op.emitOpError(
1547         "expects type to be a valid element type for an LLVM pointer");
1548   if (op->getParentOp() && !satisfiesLLVMModule(op->getParentOp()))
1549     return op.emitOpError("must appear at the module level");
1550 
1551   if (auto strAttr = op.getValueOrNull().dyn_cast_or_null<StringAttr>()) {
1552     auto type = op.getType().dyn_cast<LLVMArrayType>();
1553     IntegerType elementType =
1554         type ? type.getElementType().dyn_cast<IntegerType>() : nullptr;
1555     if (!elementType || elementType.getWidth() != 8 ||
1556         type.getNumElements() != strAttr.getValue().size())
1557       return op.emitOpError(
1558           "requires an i8 array type of the length equal to that of the string "
1559           "attribute");
1560   }
1561 
1562   if (Block *b = op.getInitializerBlock()) {
1563     ReturnOp ret = cast<ReturnOp>(b->getTerminator());
1564     if (ret.operand_type_begin() == ret.operand_type_end())
1565       return op.emitOpError("initializer region cannot return void");
1566     if (*ret.operand_type_begin() != op.getType())
1567       return op.emitOpError("initializer region type ")
1568              << *ret.operand_type_begin() << " does not match global type "
1569              << op.getType();
1570 
1571     if (op.getValueOrNull())
1572       return op.emitOpError("cannot have both initializer value and region");
1573   }
1574 
1575   if (op.linkage() == Linkage::Common) {
1576     if (Attribute value = op.getValueOrNull()) {
1577       if (!isZeroAttribute(value)) {
1578         return op.emitOpError()
1579                << "expected zero value for '"
1580                << stringifyLinkage(Linkage::Common) << "' linkage";
1581       }
1582     }
1583   }
1584 
1585   if (op.linkage() == Linkage::Appending) {
1586     if (!op.getType().isa<LLVMArrayType>()) {
1587       return op.emitOpError()
1588              << "expected array type for '"
1589              << stringifyLinkage(Linkage::Appending) << "' linkage";
1590     }
1591   }
1592 
1593   Optional<uint64_t> alignAttr = op.alignment();
1594   if (alignAttr.hasValue()) {
1595     uint64_t value = alignAttr.getValue();
1596     if (!llvm::isPowerOf2_64(value))
1597       return op->emitError() << "alignment attribute is not a power of 2";
1598   }
1599 
1600   return success();
1601 }
1602 
1603 //===----------------------------------------------------------------------===//
1604 // Printing/parsing for LLVM::ShuffleVectorOp.
1605 //===----------------------------------------------------------------------===//
1606 // Expects vector to be of wrapped LLVM vector type and position to be of
1607 // wrapped LLVM i32 type.
build(OpBuilder & b,OperationState & result,Value v1,Value v2,ArrayAttr mask,ArrayRef<NamedAttribute> attrs)1608 void LLVM::ShuffleVectorOp::build(OpBuilder &b, OperationState &result,
1609                                   Value v1, Value v2, ArrayAttr mask,
1610                                   ArrayRef<NamedAttribute> attrs) {
1611   auto containerType = v1.getType();
1612   auto vType = LLVM::getFixedVectorType(
1613       LLVM::getVectorElementType(containerType), mask.size());
1614   build(b, result, vType, v1, v2, mask);
1615   result.addAttributes(attrs);
1616 }
1617 
printShuffleVectorOp(OpAsmPrinter & p,ShuffleVectorOp & op)1618 static void printShuffleVectorOp(OpAsmPrinter &p, ShuffleVectorOp &op) {
1619   p << op.getOperationName() << ' ' << op.v1() << ", " << op.v2() << " "
1620     << op.mask();
1621   p.printOptionalAttrDict(op->getAttrs(), {"mask"});
1622   p << " : " << op.v1().getType() << ", " << op.v2().getType();
1623 }
1624 
1625 // <operation> ::= `llvm.shufflevector` ssa-use `, ` ssa-use
1626 //                 `[` integer-literal (`,` integer-literal)* `]`
1627 //                 attribute-dict? `:` type
parseShuffleVectorOp(OpAsmParser & parser,OperationState & result)1628 static ParseResult parseShuffleVectorOp(OpAsmParser &parser,
1629                                         OperationState &result) {
1630   llvm::SMLoc loc;
1631   OpAsmParser::OperandType v1, v2;
1632   ArrayAttr maskAttr;
1633   Type typeV1, typeV2;
1634   if (parser.getCurrentLocation(&loc) || parser.parseOperand(v1) ||
1635       parser.parseComma() || parser.parseOperand(v2) ||
1636       parser.parseAttribute(maskAttr, "mask", result.attributes) ||
1637       parser.parseOptionalAttrDict(result.attributes) ||
1638       parser.parseColonType(typeV1) || parser.parseComma() ||
1639       parser.parseType(typeV2) ||
1640       parser.resolveOperand(v1, typeV1, result.operands) ||
1641       parser.resolveOperand(v2, typeV2, result.operands))
1642     return failure();
1643   if (!LLVM::isCompatibleVectorType(typeV1))
1644     return parser.emitError(
1645         loc, "expected LLVM IR dialect vector type for operand #1");
1646   auto vType = LLVM::getFixedVectorType(LLVM::getVectorElementType(typeV1),
1647                                         maskAttr.size());
1648   result.addTypes(vType);
1649   return success();
1650 }
1651 
1652 //===----------------------------------------------------------------------===//
1653 // Implementations for LLVM::LLVMFuncOp.
1654 //===----------------------------------------------------------------------===//
1655 
1656 // Add the entry block to the function.
addEntryBlock()1657 Block *LLVMFuncOp::addEntryBlock() {
1658   assert(empty() && "function already has an entry block");
1659   assert(!isVarArg() && "unimplemented: non-external variadic functions");
1660 
1661   auto *entry = new Block;
1662   push_back(entry);
1663 
1664   LLVMFunctionType type = getType();
1665   for (unsigned i = 0, e = type.getNumParams(); i < e; ++i)
1666     entry->addArgument(type.getParamType(i));
1667   return entry;
1668 }
1669 
build(OpBuilder & builder,OperationState & result,StringRef name,Type type,LLVM::Linkage linkage,bool dsoLocal,ArrayRef<NamedAttribute> attrs,ArrayRef<DictionaryAttr> argAttrs)1670 void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
1671                        StringRef name, Type type, LLVM::Linkage linkage,
1672                        bool dsoLocal, ArrayRef<NamedAttribute> attrs,
1673                        ArrayRef<DictionaryAttr> argAttrs) {
1674   result.addRegion();
1675   result.addAttribute(SymbolTable::getSymbolAttrName(),
1676                       builder.getStringAttr(name));
1677   result.addAttribute("type", TypeAttr::get(type));
1678   result.addAttribute(getLinkageAttrName(),
1679                       builder.getI64IntegerAttr(static_cast<int64_t>(linkage)));
1680   result.attributes.append(attrs.begin(), attrs.end());
1681   if (dsoLocal)
1682     result.addAttribute("dso_local", builder.getUnitAttr());
1683   if (argAttrs.empty())
1684     return;
1685 
1686   assert(type.cast<LLVMFunctionType>().getNumParams() == argAttrs.size() &&
1687          "expected as many argument attribute lists as arguments");
1688   function_like_impl::addArgAndResultAttrs(builder, result, argAttrs,
1689                                            /*resultAttrs=*/llvm::None);
1690 }
1691 
1692 // Builds an LLVM function type from the given lists of input and output types.
1693 // Returns a null type if any of the types provided are non-LLVM types, or if
1694 // there is more than one output type.
1695 static Type
buildLLVMFunctionType(OpAsmParser & parser,llvm::SMLoc loc,ArrayRef<Type> inputs,ArrayRef<Type> outputs,function_like_impl::VariadicFlag variadicFlag)1696 buildLLVMFunctionType(OpAsmParser &parser, llvm::SMLoc loc,
1697                       ArrayRef<Type> inputs, ArrayRef<Type> outputs,
1698                       function_like_impl::VariadicFlag variadicFlag) {
1699   Builder &b = parser.getBuilder();
1700   if (outputs.size() > 1) {
1701     parser.emitError(loc, "failed to construct function type: expected zero or "
1702                           "one function result");
1703     return {};
1704   }
1705 
1706   // Convert inputs to LLVM types, exit early on error.
1707   SmallVector<Type, 4> llvmInputs;
1708   for (auto t : inputs) {
1709     if (!isCompatibleType(t)) {
1710       parser.emitError(loc, "failed to construct function type: expected LLVM "
1711                             "type for function arguments");
1712       return {};
1713     }
1714     llvmInputs.push_back(t);
1715   }
1716 
1717   // No output is denoted as "void" in LLVM type system.
1718   Type llvmOutput =
1719       outputs.empty() ? LLVMVoidType::get(b.getContext()) : outputs.front();
1720   if (!isCompatibleType(llvmOutput)) {
1721     parser.emitError(loc, "failed to construct function type: expected LLVM "
1722                           "type for function results")
1723         << llvmOutput;
1724     return {};
1725   }
1726   return LLVMFunctionType::get(llvmOutput, llvmInputs,
1727                                variadicFlag.isVariadic());
1728 }
1729 
1730 // Parses an LLVM function.
1731 //
1732 // operation ::= `llvm.func` linkage? function-signature function-attributes?
1733 //               function-body
1734 //
parseLLVMFuncOp(OpAsmParser & parser,OperationState & result)1735 static ParseResult parseLLVMFuncOp(OpAsmParser &parser,
1736                                    OperationState &result) {
1737   // Default to external linkage if no keyword is provided.
1738   if (failed(parseOptionalLLVMKeyword<Linkage>(parser, result,
1739                                                getLinkageAttrName())))
1740     result.addAttribute(getLinkageAttrName(),
1741                         parser.getBuilder().getI64IntegerAttr(
1742                             static_cast<int64_t>(LLVM::Linkage::External)));
1743 
1744   StringAttr nameAttr;
1745   SmallVector<OpAsmParser::OperandType, 8> entryArgs;
1746   SmallVector<NamedAttrList, 1> argAttrs;
1747   SmallVector<NamedAttrList, 1> resultAttrs;
1748   SmallVector<Type, 8> argTypes;
1749   SmallVector<Type, 4> resultTypes;
1750   bool isVariadic;
1751 
1752   auto signatureLocation = parser.getCurrentLocation();
1753   if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1754                              result.attributes) ||
1755       function_like_impl::parseFunctionSignature(
1756           parser, /*allowVariadic=*/true, entryArgs, argTypes, argAttrs,
1757           isVariadic, resultTypes, resultAttrs))
1758     return failure();
1759 
1760   auto type =
1761       buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes,
1762                             function_like_impl::VariadicFlag(isVariadic));
1763   if (!type)
1764     return failure();
1765   result.addAttribute(function_like_impl::getTypeAttrName(),
1766                       TypeAttr::get(type));
1767 
1768   if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
1769     return failure();
1770   function_like_impl::addArgAndResultAttrs(parser.getBuilder(), result,
1771                                            argAttrs, resultAttrs);
1772 
1773   auto *body = result.addRegion();
1774   OptionalParseResult parseResult = parser.parseOptionalRegion(
1775       *body, entryArgs, entryArgs.empty() ? ArrayRef<Type>() : argTypes);
1776   return failure(parseResult.hasValue() && failed(*parseResult));
1777 }
1778 
1779 // Print the LLVMFuncOp. Collects argument and result types and passes them to
1780 // helper functions. Drops "void" result since it cannot be parsed back. Skips
1781 // the external linkage since it is the default value.
printLLVMFuncOp(OpAsmPrinter & p,LLVMFuncOp op)1782 static void printLLVMFuncOp(OpAsmPrinter &p, LLVMFuncOp op) {
1783   p << op.getOperationName() << ' ';
1784   if (op.linkage() != LLVM::Linkage::External)
1785     p << stringifyLinkage(op.linkage()) << ' ';
1786   p.printSymbolName(op.getName());
1787 
1788   LLVMFunctionType fnType = op.getType();
1789   SmallVector<Type, 8> argTypes;
1790   SmallVector<Type, 1> resTypes;
1791   argTypes.reserve(fnType.getNumParams());
1792   for (unsigned i = 0, e = fnType.getNumParams(); i < e; ++i)
1793     argTypes.push_back(fnType.getParamType(i));
1794 
1795   Type returnType = fnType.getReturnType();
1796   if (!returnType.isa<LLVMVoidType>())
1797     resTypes.push_back(returnType);
1798 
1799   function_like_impl::printFunctionSignature(p, op, argTypes, op.isVarArg(),
1800                                              resTypes);
1801   function_like_impl::printFunctionAttributes(
1802       p, op, argTypes.size(), resTypes.size(), {getLinkageAttrName()});
1803 
1804   // Print the body if this is not an external function.
1805   Region &body = op.body();
1806   if (!body.empty())
1807     p.printRegion(body, /*printEntryBlockArgs=*/false,
1808                   /*printBlockTerminators=*/true);
1809 }
1810 
1811 // Hook for OpTrait::FunctionLike, called after verifying that the 'type'
1812 // attribute is present.  This can check for preconditions of the
1813 // getNumArguments hook not failing.
verifyType()1814 LogicalResult LLVMFuncOp::verifyType() {
1815   auto llvmType = getTypeAttr().getValue().dyn_cast_or_null<LLVMFunctionType>();
1816   if (!llvmType)
1817     return emitOpError("requires '" + getTypeAttrName() +
1818                        "' attribute of wrapped LLVM function type");
1819 
1820   return success();
1821 }
1822 
1823 // Hook for OpTrait::FunctionLike, returns the number of function arguments.
1824 // Depends on the type attribute being correct as checked by verifyType
getNumFuncArguments()1825 unsigned LLVMFuncOp::getNumFuncArguments() { return getType().getNumParams(); }
1826 
1827 // Hook for OpTrait::FunctionLike, returns the number of function results.
1828 // Depends on the type attribute being correct as checked by verifyType
getNumFuncResults()1829 unsigned LLVMFuncOp::getNumFuncResults() {
1830   // We model LLVM functions that return void as having zero results,
1831   // and all others as having one result.
1832   // If we modeled a void return as one result, then it would be possible to
1833   // attach an MLIR result attribute to it, and it isn't clear what semantics we
1834   // would assign to that.
1835   if (getType().getReturnType().isa<LLVMVoidType>())
1836     return 0;
1837   return 1;
1838 }
1839 
1840 // Verifies LLVM- and implementation-specific properties of the LLVM func Op:
1841 // - functions don't have 'common' linkage
1842 // - external functions have 'external' or 'extern_weak' linkage;
1843 // - vararg is (currently) only supported for external functions;
1844 // - entry block arguments are of LLVM types and match the function signature.
verify(LLVMFuncOp op)1845 static LogicalResult verify(LLVMFuncOp op) {
1846   if (op.linkage() == LLVM::Linkage::Common)
1847     return op.emitOpError()
1848            << "functions cannot have '"
1849            << stringifyLinkage(LLVM::Linkage::Common) << "' linkage";
1850 
1851   if (op.isExternal()) {
1852     if (op.linkage() != LLVM::Linkage::External &&
1853         op.linkage() != LLVM::Linkage::ExternWeak)
1854       return op.emitOpError()
1855              << "external functions must have '"
1856              << stringifyLinkage(LLVM::Linkage::External) << "' or '"
1857              << stringifyLinkage(LLVM::Linkage::ExternWeak) << "' linkage";
1858     return success();
1859   }
1860 
1861   if (op.isVarArg())
1862     return op.emitOpError("only external functions can be variadic");
1863 
1864   unsigned numArguments = op.getType().getNumParams();
1865   Block &entryBlock = op.front();
1866   for (unsigned i = 0; i < numArguments; ++i) {
1867     Type argType = entryBlock.getArgument(i).getType();
1868     if (!isCompatibleType(argType))
1869       return op.emitOpError("entry block argument #")
1870              << i << " is not of LLVM type";
1871     if (op.getType().getParamType(i) != argType)
1872       return op.emitOpError("the type of entry block argument #")
1873              << i << " does not match the function signature";
1874   }
1875 
1876   return success();
1877 }
1878 
1879 //===----------------------------------------------------------------------===//
1880 // Verification for LLVM::ConstantOp.
1881 //===----------------------------------------------------------------------===//
1882 
verify(LLVM::ConstantOp op)1883 static LogicalResult verify(LLVM::ConstantOp op) {
1884   if (StringAttr sAttr = op.value().dyn_cast<StringAttr>()) {
1885     auto arrayType = op.getType().dyn_cast<LLVMArrayType>();
1886     if (!arrayType || arrayType.getNumElements() != sAttr.getValue().size() ||
1887         !arrayType.getElementType().isInteger(8)) {
1888       return op->emitOpError()
1889              << "expected array type of " << sAttr.getValue().size()
1890              << " i8 elements for the string constant";
1891     }
1892     return success();
1893   }
1894   if (auto structType = op.getType().dyn_cast<LLVMStructType>()) {
1895     if (structType.getBody().size() != 2 ||
1896         structType.getBody()[0] != structType.getBody()[1]) {
1897       return op.emitError() << "expected struct type with two elements of the "
1898                                "same type, the type of a complex constant";
1899     }
1900 
1901     auto arrayAttr = op.value().dyn_cast<ArrayAttr>();
1902     if (!arrayAttr || arrayAttr.size() != 2 ||
1903         arrayAttr[0].getType() != arrayAttr[1].getType()) {
1904       return op.emitOpError() << "expected array attribute with two elements, "
1905                                  "representing a complex constant";
1906     }
1907 
1908     Type elementType = structType.getBody()[0];
1909     if (!elementType
1910              .isa<IntegerType, Float16Type, Float32Type, Float64Type>()) {
1911       return op.emitError()
1912              << "expected struct element types to be floating point type or "
1913                 "integer type";
1914     }
1915     return success();
1916   }
1917   if (!op.value().isa<IntegerAttr, ArrayAttr, FloatAttr, ElementsAttr>())
1918     return op.emitOpError()
1919            << "only supports integer, float, string or elements attributes";
1920   return success();
1921 }
1922 
1923 //===----------------------------------------------------------------------===//
1924 // Utility functions for parsing atomic ops
1925 //===----------------------------------------------------------------------===//
1926 
1927 // Helper function to parse a keyword into the specified attribute named by
1928 // `attrName`. The keyword must match one of the string values defined by the
1929 // AtomicBinOp enum. The resulting I64 attribute is added to the `result`
1930 // state.
parseAtomicBinOp(OpAsmParser & parser,OperationState & result,StringRef attrName)1931 static ParseResult parseAtomicBinOp(OpAsmParser &parser, OperationState &result,
1932                                     StringRef attrName) {
1933   llvm::SMLoc loc;
1934   StringRef keyword;
1935   if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&keyword))
1936     return failure();
1937 
1938   // Replace the keyword `keyword` with an integer attribute.
1939   auto kind = symbolizeAtomicBinOp(keyword);
1940   if (!kind) {
1941     return parser.emitError(loc)
1942            << "'" << keyword << "' is an incorrect value of the '" << attrName
1943            << "' attribute";
1944   }
1945 
1946   auto value = static_cast<int64_t>(kind.getValue());
1947   auto attr = parser.getBuilder().getI64IntegerAttr(value);
1948   result.addAttribute(attrName, attr);
1949 
1950   return success();
1951 }
1952 
1953 // Helper function to parse a keyword into the specified attribute named by
1954 // `attrName`. The keyword must match one of the string values defined by the
1955 // AtomicOrdering enum. The resulting I64 attribute is added to the `result`
1956 // state.
parseAtomicOrdering(OpAsmParser & parser,OperationState & result,StringRef attrName)1957 static ParseResult parseAtomicOrdering(OpAsmParser &parser,
1958                                        OperationState &result,
1959                                        StringRef attrName) {
1960   llvm::SMLoc loc;
1961   StringRef ordering;
1962   if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&ordering))
1963     return failure();
1964 
1965   // Replace the keyword `ordering` with an integer attribute.
1966   auto kind = symbolizeAtomicOrdering(ordering);
1967   if (!kind) {
1968     return parser.emitError(loc)
1969            << "'" << ordering << "' is an incorrect value of the '" << attrName
1970            << "' attribute";
1971   }
1972 
1973   auto value = static_cast<int64_t>(kind.getValue());
1974   auto attr = parser.getBuilder().getI64IntegerAttr(value);
1975   result.addAttribute(attrName, attr);
1976 
1977   return success();
1978 }
1979 
1980 //===----------------------------------------------------------------------===//
1981 // Printer, parser and verifier for LLVM::AtomicRMWOp.
1982 //===----------------------------------------------------------------------===//
1983 
printAtomicRMWOp(OpAsmPrinter & p,AtomicRMWOp & op)1984 static void printAtomicRMWOp(OpAsmPrinter &p, AtomicRMWOp &op) {
1985   p << op.getOperationName() << ' ' << stringifyAtomicBinOp(op.bin_op()) << ' '
1986     << op.ptr() << ", " << op.val() << ' '
1987     << stringifyAtomicOrdering(op.ordering()) << ' ';
1988   p.printOptionalAttrDict(op->getAttrs(), {"bin_op", "ordering"});
1989   p << " : " << op.res().getType();
1990 }
1991 
1992 // <operation> ::= `llvm.atomicrmw` keyword ssa-use `,` ssa-use keyword
1993 //                 attribute-dict? `:` type
parseAtomicRMWOp(OpAsmParser & parser,OperationState & result)1994 static ParseResult parseAtomicRMWOp(OpAsmParser &parser,
1995                                     OperationState &result) {
1996   Type type;
1997   OpAsmParser::OperandType ptr, val;
1998   if (parseAtomicBinOp(parser, result, "bin_op") || parser.parseOperand(ptr) ||
1999       parser.parseComma() || parser.parseOperand(val) ||
2000       parseAtomicOrdering(parser, result, "ordering") ||
2001       parser.parseOptionalAttrDict(result.attributes) ||
2002       parser.parseColonType(type) ||
2003       parser.resolveOperand(ptr, LLVM::LLVMPointerType::get(type),
2004                             result.operands) ||
2005       parser.resolveOperand(val, type, result.operands))
2006     return failure();
2007 
2008   result.addTypes(type);
2009   return success();
2010 }
2011 
verify(AtomicRMWOp op)2012 static LogicalResult verify(AtomicRMWOp op) {
2013   auto ptrType = op.ptr().getType().cast<LLVM::LLVMPointerType>();
2014   auto valType = op.val().getType();
2015   if (valType != ptrType.getElementType())
2016     return op.emitOpError("expected LLVM IR element type for operand #0 to "
2017                           "match type for operand #1");
2018   auto resType = op.res().getType();
2019   if (resType != valType)
2020     return op.emitOpError(
2021         "expected LLVM IR result type to match type for operand #1");
2022   if (op.bin_op() == AtomicBinOp::fadd || op.bin_op() == AtomicBinOp::fsub) {
2023     if (!mlir::LLVM::isCompatibleFloatingPointType(valType))
2024       return op.emitOpError("expected LLVM IR floating point type");
2025   } else if (op.bin_op() == AtomicBinOp::xchg) {
2026     auto intType = valType.dyn_cast<IntegerType>();
2027     unsigned intBitWidth = intType ? intType.getWidth() : 0;
2028     if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 &&
2029         intBitWidth != 64 && !valType.isa<BFloat16Type>() &&
2030         !valType.isa<Float16Type>() && !valType.isa<Float32Type>() &&
2031         !valType.isa<Float64Type>())
2032       return op.emitOpError("unexpected LLVM IR type for 'xchg' bin_op");
2033   } else {
2034     auto intType = valType.dyn_cast<IntegerType>();
2035     unsigned intBitWidth = intType ? intType.getWidth() : 0;
2036     if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 &&
2037         intBitWidth != 64)
2038       return op.emitOpError("expected LLVM IR integer type");
2039   }
2040 
2041   if (static_cast<unsigned>(op.ordering()) <
2042       static_cast<unsigned>(AtomicOrdering::monotonic))
2043     return op.emitOpError()
2044            << "expected at least '"
2045            << stringifyAtomicOrdering(AtomicOrdering::monotonic)
2046            << "' ordering";
2047 
2048   return success();
2049 }
2050 
2051 //===----------------------------------------------------------------------===//
2052 // Printer, parser and verifier for LLVM::AtomicCmpXchgOp.
2053 //===----------------------------------------------------------------------===//
2054 
printAtomicCmpXchgOp(OpAsmPrinter & p,AtomicCmpXchgOp & op)2055 static void printAtomicCmpXchgOp(OpAsmPrinter &p, AtomicCmpXchgOp &op) {
2056   p << op.getOperationName() << ' ' << op.ptr() << ", " << op.cmp() << ", "
2057     << op.val() << ' ' << stringifyAtomicOrdering(op.success_ordering()) << ' '
2058     << stringifyAtomicOrdering(op.failure_ordering());
2059   p.printOptionalAttrDict(op->getAttrs(),
2060                           {"success_ordering", "failure_ordering"});
2061   p << " : " << op.val().getType();
2062 }
2063 
2064 // <operation> ::= `llvm.cmpxchg` ssa-use `,` ssa-use `,` ssa-use
2065 //                 keyword keyword attribute-dict? `:` type
parseAtomicCmpXchgOp(OpAsmParser & parser,OperationState & result)2066 static ParseResult parseAtomicCmpXchgOp(OpAsmParser &parser,
2067                                         OperationState &result) {
2068   auto &builder = parser.getBuilder();
2069   Type type;
2070   OpAsmParser::OperandType ptr, cmp, val;
2071   if (parser.parseOperand(ptr) || parser.parseComma() ||
2072       parser.parseOperand(cmp) || parser.parseComma() ||
2073       parser.parseOperand(val) ||
2074       parseAtomicOrdering(parser, result, "success_ordering") ||
2075       parseAtomicOrdering(parser, result, "failure_ordering") ||
2076       parser.parseOptionalAttrDict(result.attributes) ||
2077       parser.parseColonType(type) ||
2078       parser.resolveOperand(ptr, LLVM::LLVMPointerType::get(type),
2079                             result.operands) ||
2080       parser.resolveOperand(cmp, type, result.operands) ||
2081       parser.resolveOperand(val, type, result.operands))
2082     return failure();
2083 
2084   auto boolType = IntegerType::get(builder.getContext(), 1);
2085   auto resultType =
2086       LLVMStructType::getLiteral(builder.getContext(), {type, boolType});
2087   result.addTypes(resultType);
2088 
2089   return success();
2090 }
2091 
verify(AtomicCmpXchgOp op)2092 static LogicalResult verify(AtomicCmpXchgOp op) {
2093   auto ptrType = op.ptr().getType().cast<LLVM::LLVMPointerType>();
2094   if (!ptrType)
2095     return op.emitOpError("expected LLVM IR pointer type for operand #0");
2096   auto cmpType = op.cmp().getType();
2097   auto valType = op.val().getType();
2098   if (cmpType != ptrType.getElementType() || cmpType != valType)
2099     return op.emitOpError("expected LLVM IR element type for operand #0 to "
2100                           "match type for all other operands");
2101   auto intType = valType.dyn_cast<IntegerType>();
2102   unsigned intBitWidth = intType ? intType.getWidth() : 0;
2103   if (!valType.isa<LLVMPointerType>() && intBitWidth != 8 &&
2104       intBitWidth != 16 && intBitWidth != 32 && intBitWidth != 64 &&
2105       !valType.isa<BFloat16Type>() && !valType.isa<Float16Type>() &&
2106       !valType.isa<Float32Type>() && !valType.isa<Float64Type>())
2107     return op.emitOpError("unexpected LLVM IR type");
2108   if (op.success_ordering() < AtomicOrdering::monotonic ||
2109       op.failure_ordering() < AtomicOrdering::monotonic)
2110     return op.emitOpError("ordering must be at least 'monotonic'");
2111   if (op.failure_ordering() == AtomicOrdering::release ||
2112       op.failure_ordering() == AtomicOrdering::acq_rel)
2113     return op.emitOpError("failure ordering cannot be 'release' or 'acq_rel'");
2114   return success();
2115 }
2116 
2117 //===----------------------------------------------------------------------===//
2118 // Printer, parser and verifier for LLVM::FenceOp.
2119 //===----------------------------------------------------------------------===//
2120 
2121 // <operation> ::= `llvm.fence` (`syncscope(`strAttr`)`)? keyword
2122 // attribute-dict?
parseFenceOp(OpAsmParser & parser,OperationState & result)2123 static ParseResult parseFenceOp(OpAsmParser &parser, OperationState &result) {
2124   StringAttr sScope;
2125   StringRef syncscopeKeyword = "syncscope";
2126   if (!failed(parser.parseOptionalKeyword(syncscopeKeyword))) {
2127     if (parser.parseLParen() ||
2128         parser.parseAttribute(sScope, syncscopeKeyword, result.attributes) ||
2129         parser.parseRParen())
2130       return failure();
2131   } else {
2132     result.addAttribute(syncscopeKeyword,
2133                         parser.getBuilder().getStringAttr(""));
2134   }
2135   if (parseAtomicOrdering(parser, result, "ordering") ||
2136       parser.parseOptionalAttrDict(result.attributes))
2137     return failure();
2138   return success();
2139 }
2140 
printFenceOp(OpAsmPrinter & p,FenceOp & op)2141 static void printFenceOp(OpAsmPrinter &p, FenceOp &op) {
2142   StringRef syncscopeKeyword = "syncscope";
2143   p << op.getOperationName() << ' ';
2144   if (!op->getAttr(syncscopeKeyword).cast<StringAttr>().getValue().empty())
2145     p << "syncscope(" << op->getAttr(syncscopeKeyword) << ") ";
2146   p << stringifyAtomicOrdering(op.ordering());
2147 }
2148 
verify(FenceOp & op)2149 static LogicalResult verify(FenceOp &op) {
2150   if (op.ordering() == AtomicOrdering::not_atomic ||
2151       op.ordering() == AtomicOrdering::unordered ||
2152       op.ordering() == AtomicOrdering::monotonic)
2153     return op.emitOpError("can be given only acquire, release, acq_rel, "
2154                           "and seq_cst orderings");
2155   return success();
2156 }
2157 
2158 //===----------------------------------------------------------------------===//
2159 // LLVMDialect initialization, type parsing, and registration.
2160 //===----------------------------------------------------------------------===//
2161 
/// Registers the dialect's attributes, types and operations, and allows
/// operations that are not registered with the dialect.
void LLVMDialect::initialize() {
  // Attributes with dialect-specific syntax.
  addAttributes<FMFAttr, LoopOptionsAttr>();

  // clang-format off
  addTypes<LLVMVoidType,
           LLVMPPCFP128Type,
           LLVMX86MMXType,
           LLVMTokenType,
           LLVMLabelType,
           LLVMMetadataType,
           LLVMFunctionType,
           LLVMPointerType,
           LLVMFixedVectorType,
           LLVMScalableVectorType,
           LLVMArrayType,
           LLVMStructType>();
  // clang-format on
  // The operation list is expanded from the generated include below.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
      >();

  // Support unknown operations because not all LLVM operations are registered.
  allowUnknownOperations();
}
2187 
2188 #define GET_OP_CLASSES
2189 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
2190 
/// Parse a type registered to this dialect. The work is delegated to
/// `detail::parseType`, whose result is returned unchanged.
Type LLVMDialect::parseType(DialectAsmParser &parser) const {
  return detail::parseType(parser);
}
2195 
2196 /// Print a type registered to this dialect.
printType(Type type,DialectAsmPrinter & os) const2197 void LLVMDialect::printType(Type type, DialectAsmPrinter &os) const {
2198   return detail::printType(type, os);
2199 }
2200 
verifyDataLayoutString(StringRef descr,llvm::function_ref<void (const Twine &)> reportError)2201 LogicalResult LLVMDialect::verifyDataLayoutString(
2202     StringRef descr, llvm::function_ref<void(const Twine &)> reportError) {
2203   llvm::Expected<llvm::DataLayout> maybeDataLayout =
2204       llvm::DataLayout::parse(descr);
2205   if (maybeDataLayout)
2206     return success();
2207 
2208   std::string message;
2209   llvm::raw_string_ostream messageStream(message);
2210   llvm::logAllUnhandledErrors(maybeDataLayout.takeError(), messageStream);
2211   reportError("invalid data layout descriptor: " + messageStream.str());
2212   return failure();
2213 }
2214 
2215 /// Verify LLVM dialect attributes.
verifyOperationAttribute(Operation * op,NamedAttribute attr)2216 LogicalResult LLVMDialect::verifyOperationAttribute(Operation *op,
2217                                                     NamedAttribute attr) {
2218   // If the `llvm.loop` attribute is present, enforce the following structure,
2219   // which the module translation can assume.
2220   if (attr.first.strref() == LLVMDialect::getLoopAttrName()) {
2221     auto loopAttr = attr.second.dyn_cast<DictionaryAttr>();
2222     if (!loopAttr)
2223       return op->emitOpError() << "expected '" << LLVMDialect::getLoopAttrName()
2224                                << "' to be a dictionary attribute";
2225     Optional<NamedAttribute> parallelAccessGroup =
2226         loopAttr.getNamed(LLVMDialect::getParallelAccessAttrName());
2227     if (parallelAccessGroup.hasValue()) {
2228       auto accessGroups = parallelAccessGroup->second.dyn_cast<ArrayAttr>();
2229       if (!accessGroups)
2230         return op->emitOpError()
2231                << "expected '" << LLVMDialect::getParallelAccessAttrName()
2232                << "' to be an array attribute";
2233       for (Attribute attr : accessGroups) {
2234         auto accessGroupRef = attr.dyn_cast<SymbolRefAttr>();
2235         if (!accessGroupRef)
2236           return op->emitOpError()
2237                  << "expected '" << attr << "' to be a symbol reference";
2238         StringRef metadataName = accessGroupRef.getRootReference();
2239         auto metadataOp =
2240             SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
2241                 op->getParentOp(), metadataName);
2242         if (!metadataOp)
2243           return op->emitOpError()
2244                  << "expected '" << attr << "' to reference a metadata op";
2245         StringRef accessGroupName = accessGroupRef.getLeafReference();
2246         Operation *accessGroupOp =
2247             SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName);
2248         if (!accessGroupOp)
2249           return op->emitOpError()
2250                  << "expected '" << attr << "' to reference an access_group op";
2251       }
2252     }
2253 
2254     Optional<NamedAttribute> loopOptions =
2255         loopAttr.getNamed(LLVMDialect::getLoopOptionsAttrName());
2256     if (loopOptions.hasValue() && !loopOptions->second.isa<LoopOptionsAttr>())
2257       return op->emitOpError()
2258              << "expected '" << LLVMDialect::getLoopOptionsAttrName()
2259              << "' to be a `loopopts` attribute";
2260   }
2261 
2262   // If the data layout attribute is present, it must use the LLVM data layout
2263   // syntax. Try parsing it and report errors in case of failure. Users of this
2264   // attribute may assume it is well-formed and can pass it to the (asserting)
2265   // llvm::DataLayout constructor.
2266   if (attr.first.strref() != LLVM::LLVMDialect::getDataLayoutAttrName())
2267     return success();
2268   if (auto stringAttr = attr.second.dyn_cast<StringAttr>())
2269     return verifyDataLayoutString(
2270         stringAttr.getValue(),
2271         [op](const Twine &message) { op->emitOpError() << message.str(); });
2272 
2273   return op->emitOpError() << "expected '"
2274                            << LLVM::LLVMDialect::getDataLayoutAttrName()
2275                            << "' to be a string attribute";
2276 }
2277 
2278 /// Verify LLVMIR function argument attributes.
verifyRegionArgAttribute(Operation * op,unsigned regionIdx,unsigned argIdx,NamedAttribute argAttr)2279 LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op,
2280                                                     unsigned regionIdx,
2281                                                     unsigned argIdx,
2282                                                     NamedAttribute argAttr) {
2283   // Check that llvm.noalias is a unit attribute.
2284   if (argAttr.first == LLVMDialect::getNoAliasAttrName() &&
2285       !argAttr.second.isa<UnitAttr>())
2286     return op->emitError()
2287            << "expected llvm.noalias argument attribute to be a unit attribute";
2288   // Check that llvm.align is an integer attribute.
2289   if (argAttr.first == LLVMDialect::getAlignAttrName() &&
2290       !argAttr.second.isa<IntegerAttr>())
2291     return op->emitError()
2292            << "llvm.align argument attribute of non integer type";
2293   return success();
2294 }
2295 
2296 //===----------------------------------------------------------------------===//
2297 // Utility functions.
2298 //===----------------------------------------------------------------------===//
2299 
createGlobalString(Location loc,OpBuilder & builder,StringRef name,StringRef value,LLVM::Linkage linkage)2300 Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder,
2301                                      StringRef name, StringRef value,
2302                                      LLVM::Linkage linkage) {
2303   assert(builder.getInsertionBlock() &&
2304          builder.getInsertionBlock()->getParentOp() &&
2305          "expected builder to point to a block constrained in an op");
2306   auto module =
2307       builder.getInsertionBlock()->getParentOp()->getParentOfType<ModuleOp>();
2308   assert(module && "builder points to an op outside of a module");
2309 
2310   // Create the global at the entry of the module.
2311   OpBuilder moduleBuilder(module.getBodyRegion(), builder.getListener());
2312   MLIRContext *ctx = builder.getContext();
2313   auto type = LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), value.size());
2314   auto global = moduleBuilder.create<LLVM::GlobalOp>(
2315       loc, type, /*isConstant=*/true, linkage, name,
2316       builder.getStringAttr(value), /*alignment=*/0);
2317 
2318   // Get the pointer to the first character in the global string.
2319   Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global);
2320   Value cst0 = builder.create<LLVM::ConstantOp>(
2321       loc, IntegerType::get(ctx, 64),
2322       builder.getIntegerAttr(builder.getIndexType(), 0));
2323   return builder.create<LLVM::GEPOp>(
2324       loc, LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)), globalPtr,
2325       ValueRange{cst0, cst0});
2326 }
2327 
satisfiesLLVMModule(Operation * op)2328 bool mlir::LLVM::satisfiesLLVMModule(Operation *op) {
2329   return op->hasTrait<OpTrait::SymbolTable>() &&
2330          op->hasTrait<OpTrait::IsIsolatedFromAbove>();
2331 }
2332 
/// Fastmath flags in the order FMFAttr::print tests them against the
/// attribute's flag set and emits them.
static constexpr const FastmathFlags FastmathFlagsList[] = {
    // clang-format off
    FastmathFlags::nnan,
    FastmathFlags::ninf,
    FastmathFlags::nsz,
    FastmathFlags::arcp,
    FastmathFlags::contract,
    FastmathFlags::afn,
    FastmathFlags::reassoc,
    FastmathFlags::fast,
    // clang-format on
};
2345 
print(DialectAsmPrinter & printer) const2346 void FMFAttr::print(DialectAsmPrinter &printer) const {
2347   printer << "fastmath<";
2348   auto flags = llvm::make_filter_range(FastmathFlagsList, [&](auto flag) {
2349     return bitEnumContains(this->getFlags(), flag);
2350   });
2351   llvm::interleaveComma(flags, printer,
2352                         [&](auto flag) { printer << stringifyEnum(flag); });
2353   printer << ">";
2354 }
2355 
parse(MLIRContext * context,DialectAsmParser & parser,Type type)2356 Attribute FMFAttr::parse(MLIRContext *context, DialectAsmParser &parser,
2357                          Type type) {
2358   if (failed(parser.parseLess()))
2359     return {};
2360 
2361   FastmathFlags flags = {};
2362   if (failed(parser.parseOptionalGreater())) {
2363     do {
2364       StringRef elemName;
2365       if (failed(parser.parseKeyword(&elemName)))
2366         return {};
2367 
2368       auto elem = symbolizeFastmathFlags(elemName);
2369       if (!elem) {
2370         parser.emitError(parser.getNameLoc(), "Unknown fastmath flag: ")
2371             << elemName;
2372         return {};
2373       }
2374 
2375       flags = flags | *elem;
2376     } while (succeeded(parser.parseOptionalComma()));
2377 
2378     if (failed(parser.parseGreater()))
2379       return {};
2380   }
2381 
2382   return FMFAttr::get(parser.getBuilder().getContext(), flags);
2383 }
2384 
/// Construct a builder whose option list starts as a copy of the options
/// stored in `attr`.
LoopOptionsAttrBuilder::LoopOptionsAttrBuilder(LoopOptionsAttr attr)
    : options(attr.getOptions().begin(), attr.getOptions().end()) {}
2387 
2388 template <typename T>
setOption(LoopOptionCase tag,Optional<T> value)2389 LoopOptionsAttrBuilder &LoopOptionsAttrBuilder::setOption(LoopOptionCase tag,
2390                                                           Optional<T> value) {
2391   auto option = llvm::find_if(
2392       options, [tag](auto option) { return option.first == tag; });
2393   if (option != options.end()) {
2394     if (value.hasValue())
2395       option->second = *value;
2396     else
2397       options.erase(option);
2398   } else {
2399     options.push_back(LoopOptionsAttr::OptionValuePair(tag, *value));
2400   }
2401   return *this;
2402 }
2403 
/// Set the `disable_licm` option to the provided value by forwarding to
/// setOption.
LoopOptionsAttrBuilder &
LoopOptionsAttrBuilder::setDisableLICM(Optional<bool> value) {
  return setOption(LoopOptionCase::disable_licm, value);
}
2408 
/// Set the `interleave_count` option to the provided value. If no value
/// is provided, an existing entry for the option is deleted. Forwards to
/// setOption and returns its result.
LoopOptionsAttrBuilder &
LoopOptionsAttrBuilder::setInterleaveCount(Optional<uint64_t> count) {
  return setOption(LoopOptionCase::interleave_count, count);
}
2415 
/// Set the `disable_unroll` option to the provided value. If no value
/// is provided, an existing entry for the option is deleted. Forwards to
/// setOption and returns its result.
LoopOptionsAttrBuilder &
LoopOptionsAttrBuilder::setDisableUnroll(Optional<bool> value) {
  return setOption(LoopOptionCase::disable_unroll, value);
}
2422 
/// Set the `disable_pipeline` option to the provided value. If no value
/// is provided, an existing entry for the option is deleted. Forwards to
/// setOption and returns its result.
LoopOptionsAttrBuilder &
LoopOptionsAttrBuilder::setDisablePipeline(Optional<bool> value) {
  return setOption(LoopOptionCase::disable_pipeline, value);
}
2429 
/// Set the `pipeline_initiation_interval` option to the provided value.
/// If no value is provided, an existing entry for the option is deleted.
/// Forwards to setOption and returns its result.
LoopOptionsAttrBuilder &LoopOptionsAttrBuilder::setPipelineInitiationInterval(
    Optional<uint64_t> count) {
  return setOption(LoopOptionCase::pipeline_initiation_interval, count);
}
2436 
2437 template <typename T>
2438 static Optional<T>
getOption(ArrayRef<std::pair<LoopOptionCase,int64_t>> options,LoopOptionCase option)2439 getOption(ArrayRef<std::pair<LoopOptionCase, int64_t>> options,
2440           LoopOptionCase option) {
2441   auto it =
2442       lower_bound(options, option, [](auto optionPair, LoopOptionCase option) {
2443         return optionPair.first < option;
2444       });
2445   if (it == options.end())
2446     return {};
2447   return static_cast<T>(it->second);
2448 }
2449 
/// Look up the `disable_unroll` entry of this attribute's options through
/// getOption and return the result as an optional bool.
Optional<bool> LoopOptionsAttr::disableUnroll() {
  return getOption<bool>(getOptions(), LoopOptionCase::disable_unroll);
}
2453 
/// Look up the `disable_licm` entry of this attribute's options through
/// getOption and return the result as an optional bool.
Optional<bool> LoopOptionsAttr::disableLICM() {
  return getOption<bool>(getOptions(), LoopOptionCase::disable_licm);
}
2457 
/// Look up the `interleave_count` entry of this attribute's options through
/// getOption and return the result as an optional 64-bit integer.
Optional<int64_t> LoopOptionsAttr::interleaveCount() {
  return getOption<int64_t>(getOptions(), LoopOptionCase::interleave_count);
}
2461 
/// Build the LoopOptions Attribute from a sorted array of individual options.
/// The array must already be sorted by option case; this is only checked by
/// an assertion.
LoopOptionsAttr LoopOptionsAttr::get(
    MLIRContext *context,
    ArrayRef<std::pair<LoopOptionCase, int64_t>> sortedOptions) {
  assert(llvm::is_sorted(sortedOptions, llvm::less_first()) &&
         "LoopOptionsAttr ctor expects a sorted options array");
  return Base::get(context, sortedOptions);
}
2470 
/// Build the LoopOptions Attribute from the options accumulated in a builder.
/// The builder's option list is sorted in place by option case before the
/// attribute is created, so the builder is modified by this call.
LoopOptionsAttr LoopOptionsAttr::get(MLIRContext *context,
                                     LoopOptionsAttrBuilder &optionBuilders) {
  llvm::sort(optionBuilders.options, llvm::less_first());
  return Base::get(context, optionBuilders.options);
}
2477 
print(DialectAsmPrinter & printer) const2478 void LoopOptionsAttr::print(DialectAsmPrinter &printer) const {
2479   printer << getMnemonic() << "<";
2480   llvm::interleaveComma(getOptions(), printer, [&](auto option) {
2481     printer << stringifyEnum(option.first) << " = ";
2482     switch (option.first) {
2483     case LoopOptionCase::disable_licm:
2484     case LoopOptionCase::disable_unroll:
2485     case LoopOptionCase::disable_pipeline:
2486       printer << (option.second ? "true" : "false");
2487       break;
2488     case LoopOptionCase::interleave_count:
2489     case LoopOptionCase::pipeline_initiation_interval:
2490       printer << option.second;
2491       break;
2492     }
2493   });
2494   printer << ">";
2495 }
2496 
parse(MLIRContext * context,DialectAsmParser & parser,Type type)2497 Attribute LoopOptionsAttr::parse(MLIRContext *context, DialectAsmParser &parser,
2498                                  Type type) {
2499   if (failed(parser.parseLess()))
2500     return {};
2501 
2502   SmallVector<std::pair<LoopOptionCase, int64_t>> options;
2503   llvm::SmallDenseSet<LoopOptionCase> seenOptions;
2504   do {
2505     StringRef optionName;
2506     if (parser.parseKeyword(&optionName))
2507       return {};
2508 
2509     auto option = symbolizeLoopOptionCase(optionName);
2510     if (!option) {
2511       parser.emitError(parser.getNameLoc(), "unknown loop option: ")
2512           << optionName;
2513       return {};
2514     }
2515     if (!seenOptions.insert(*option).second) {
2516       parser.emitError(parser.getNameLoc(), "loop option present twice");
2517       return {};
2518     }
2519     if (failed(parser.parseEqual()))
2520       return {};
2521 
2522     int64_t value;
2523     switch (*option) {
2524     case LoopOptionCase::disable_licm:
2525     case LoopOptionCase::disable_unroll:
2526     case LoopOptionCase::disable_pipeline:
2527       if (succeeded(parser.parseOptionalKeyword("true")))
2528         value = 1;
2529       else if (succeeded(parser.parseOptionalKeyword("false")))
2530         value = 0;
2531       else {
2532         parser.emitError(parser.getNameLoc(),
2533                          "expected boolean value 'true' or 'false'");
2534         return {};
2535       }
2536       break;
2537     case LoopOptionCase::interleave_count:
2538     case LoopOptionCase::pipeline_initiation_interval:
2539       if (failed(parser.parseInteger(value))) {
2540         parser.emitError(parser.getNameLoc(), "expected integer value");
2541         return {};
2542       }
2543       break;
2544     }
2545     options.push_back(std::make_pair(*option, value));
2546   } while (succeeded(parser.parseOptionalComma()));
2547   if (failed(parser.parseGreater()))
2548     return {};
2549 
2550   llvm::sort(options, llvm::less_first());
2551   return get(parser.getBuilder().getContext(), options);
2552 }
2553 
parseAttribute(DialectAsmParser & parser,Type type) const2554 Attribute LLVMDialect::parseAttribute(DialectAsmParser &parser,
2555                                       Type type) const {
2556   if (type) {
2557     parser.emitError(parser.getNameLoc(), "unexpected type");
2558     return {};
2559   }
2560   StringRef attrKind;
2561   if (parser.parseKeyword(&attrKind))
2562     return {};
2563   {
2564     Attribute attr;
2565     auto parseResult =
2566         generatedAttributeParser(getContext(), parser, attrKind, type, attr);
2567     if (parseResult.hasValue())
2568       return attr;
2569   }
2570   parser.emitError(parser.getNameLoc(), "unknown attribute type: ") << attrKind;
2571   return {};
2572 }
2573 
printAttribute(Attribute attr,DialectAsmPrinter & os) const2574 void LLVMDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const {
2575   if (succeeded(generatedAttributePrinter(attr, os)))
2576     return;
2577   llvm_unreachable("Unknown attribute type");
2578 }
2579