1 //===- OpImplementation.h - Classes for implementing Op types ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file contains classes used by the implementation details of Op types.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_IR_OPIMPLEMENTATION_H
14 #define MLIR_IR_OPIMPLEMENTATION_H
15 
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/DialectInterface.h"
18 #include "mlir/IR/OpDefinition.h"
19 #include "llvm/ADT/Twine.h"
20 #include "llvm/Support/SMLoc.h"
21 #include "llvm/Support/raw_ostream.h"
22 
23 namespace mlir {
24 
25 class Builder;
26 
27 //===----------------------------------------------------------------------===//
28 // OpAsmPrinter
29 //===----------------------------------------------------------------------===//
30 
31 /// This is a pure-virtual base class that exposes the asmprinter hooks
32 /// necessary to implement a custom print() method.
33 class OpAsmPrinter {
34 public:
OpAsmPrinter()35   OpAsmPrinter() {}
36   virtual ~OpAsmPrinter();
37   virtual raw_ostream &getStream() const = 0;
38 
39   /// Print a newline and indent the printer to the start of the current
40   /// operation.
41   virtual void printNewline() = 0;
42 
43   /// Print implementations for various things an operation contains.
44   virtual void printOperand(Value value) = 0;
45   virtual void printOperand(Value value, raw_ostream &os) = 0;
46 
47   /// Print a comma separated list of operands.
48   template <typename ContainerType>
printOperands(const ContainerType & container)49   void printOperands(const ContainerType &container) {
50     printOperands(container.begin(), container.end());
51   }
52 
53   /// Print a comma separated list of operands.
54   template <typename IteratorType>
printOperands(IteratorType it,IteratorType end)55   void printOperands(IteratorType it, IteratorType end) {
56     if (it == end)
57       return;
58     printOperand(*it);
59     for (++it; it != end; ++it) {
60       getStream() << ", ";
61       printOperand(*it);
62     }
63   }
64   virtual void printType(Type type) = 0;
65   virtual void printAttribute(Attribute attr) = 0;
66 
67   /// Print the given attribute without its type. The corresponding parser must
68   /// provide a valid type for the attribute.
69   virtual void printAttributeWithoutType(Attribute attr) = 0;
70 
71   /// Print the given successor.
72   virtual void printSuccessor(Block *successor) = 0;
73 
74   /// Print the successor and its operands.
75   virtual void printSuccessorAndUseList(Block *successor,
76                                         ValueRange succOperands) = 0;
77 
78   /// If the specified operation has attributes, print out an attribute
79   /// dictionary with their values.  elidedAttrs allows the client to ignore
80   /// specific well known attributes, commonly used if the attribute value is
81   /// printed some other way (like as a fixed operand).
82   virtual void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
83                                      ArrayRef<StringRef> elidedAttrs = {}) = 0;
84 
85   /// If the specified operation has attributes, print out an attribute
86   /// dictionary prefixed with 'attributes'.
87   virtual void
88   printOptionalAttrDictWithKeyword(ArrayRef<NamedAttribute> attrs,
89                                    ArrayRef<StringRef> elidedAttrs = {}) = 0;
90 
91   /// Print the entire operation with the default generic assembly form.
92   virtual void printGenericOp(Operation *op) = 0;
93 
94   /// Prints a region.
95   virtual void printRegion(Region &blocks, bool printEntryBlockArgs = true,
96                            bool printBlockTerminators = true) = 0;
97 
98   /// Renumber the arguments for the specified region to the same names as the
99   /// SSA values in namesToUse.  This may only be used for IsolatedFromAbove
100   /// operations.  If any entry in namesToUse is null, the corresponding
101   /// argument name is left alone.
102   virtual void shadowRegionArgs(Region &region, ValueRange namesToUse) = 0;
103 
104   /// Prints an affine map of SSA ids, where SSA id names are used in place
105   /// of dims/symbols.
106   /// Operand values must come from single-result sources, and be valid
107   /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
108   virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
109                                       ValueRange operands) = 0;
110 
111   /// Print an optional arrow followed by a type list.
112   template <typename TypeRange>
printOptionalArrowTypeList(TypeRange && types)113   void printOptionalArrowTypeList(TypeRange &&types) {
114     if (types.begin() != types.end())
115       printArrowTypeList(types);
116   }
117   template <typename TypeRange>
printArrowTypeList(TypeRange && types)118   void printArrowTypeList(TypeRange &&types) {
119     auto &os = getStream() << " -> ";
120 
121     bool wrapped = !llvm::hasSingleElement(types) ||
122                    (*types.begin()).template isa<FunctionType>();
123     if (wrapped)
124       os << '(';
125     llvm::interleaveComma(types, *this);
126     if (wrapped)
127       os << ')';
128   }
129 
130   /// Print the complete type of an operation in functional form.
131   void printFunctionalType(Operation *op);
132 
133   /// Print the two given type ranges in a functional form.
134   template <typename InputRangeT, typename ResultRangeT>
printFunctionalType(InputRangeT && inputs,ResultRangeT && results)135   void printFunctionalType(InputRangeT &&inputs, ResultRangeT &&results) {
136     auto &os = getStream();
137     os << '(';
138     llvm::interleaveComma(inputs, *this);
139     os << ')';
140     printArrowTypeList(results);
141   }
142 
143   /// Print the given string as a symbol reference, i.e. a form representable by
144   /// a SymbolRefAttr. A symbol reference is represented as a string prefixed
145   /// with '@'. The reference is surrounded with ""'s and escaped if it has any
146   /// special or non-printable characters in it.
147   virtual void printSymbolName(StringRef symbolRef) = 0;
148 
149 private:
150   OpAsmPrinter(const OpAsmPrinter &) = delete;
151   void operator=(const OpAsmPrinter &) = delete;
152 };
153 
154 // Make the implementations convenient to use.
155 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Value value) {
156   p.printOperand(value);
157   return p;
158 }
159 
160 template <typename T,
161           typename std::enable_if<std::is_convertible<T &, ValueRange>::value &&
162                                       !std::is_convertible<T &, Value &>::value,
163                                   T>::type * = nullptr>
164 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &values) {
165   p.printOperands(values);
166   return p;
167 }
168 
169 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Type type) {
170   p.printType(type);
171   return p;
172 }
173 
174 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Attribute attr) {
175   p.printAttribute(attr);
176   return p;
177 }
178 
179 // Support printing anything that isn't convertible to one of the above types,
180 // even if it isn't exactly one of them.  For example, we want to print
181 // FunctionType with the Type version above, not have it match this.
182 template <typename T, typename std::enable_if<
183                           !std::is_convertible<T &, Value &>::value &&
184                               !std::is_convertible<T &, Type &>::value &&
185                               !std::is_convertible<T &, Attribute &>::value &&
186                               !std::is_convertible<T &, ValueRange>::value &&
187                               !llvm::is_one_of<T, bool>::value,
188                           T>::type * = nullptr>
189 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &other) {
190   p.getStream() << other;
191   return p;
192 }
193 
194 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, bool value) {
195   return p << (value ? StringRef("true") : "false");
196 }
197 
198 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Block *value) {
199   p.printSuccessor(value);
200   return p;
201 }
202 
203 template <typename ValueRangeT>
204 inline OpAsmPrinter &operator<<(OpAsmPrinter &p,
205                                 const ValueTypeRange<ValueRangeT> &types) {
206   llvm::interleaveComma(types, p);
207   return p;
208 }
209 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const TypeRange &types) {
210   llvm::interleaveComma(types, p);
211   return p;
212 }
213 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, ArrayRef<Type> types) {
214   llvm::interleaveComma(types, p);
215   return p;
216 }
217 
218 //===----------------------------------------------------------------------===//
219 // OpAsmParser
220 //===----------------------------------------------------------------------===//
221 
222 /// The OpAsmParser has methods for interacting with the asm parser: parsing
223 /// things from it, emitting errors etc.  It has an intentionally high-level API
224 /// that is designed to reduce/constrain syntax innovation in individual
225 /// operations.
226 ///
227 /// For example, consider an op like this:
228 ///
229 ///    %x = load %p[%1, %2] : memref<...>
230 ///
231 /// The "%x = load" tokens are already parsed and therefore invisible to the
232 /// custom op parser.  This can be supported by calling `parseOperandList` to
233 /// parse the %p, then calling `parseOperandList` with a `SquareDelimiter` to
234 /// parse the indices, then calling `parseColonTypeList` to parse the result
235 /// type.
236 ///
237 class OpAsmParser {
238 public:
239   virtual ~OpAsmParser();
240 
241   /// Emit a diagnostic at the specified location and return failure.
242   virtual InFlightDiagnostic emitError(llvm::SMLoc loc,
243                                        const Twine &message = {}) = 0;
244 
245   /// Return a builder which provides useful access to MLIRContext, global
246   /// objects like types and attributes.
247   virtual Builder &getBuilder() const = 0;
248 
249   /// Get the location of the next token and store it into the argument.  This
250   /// always succeeds.
251   virtual llvm::SMLoc getCurrentLocation() = 0;
getCurrentLocation(llvm::SMLoc * loc)252   ParseResult getCurrentLocation(llvm::SMLoc *loc) {
253     *loc = getCurrentLocation();
254     return success();
255   }
256 
257   /// Return the name of the specified result in the specified syntax, as well
258   /// as the sub-element in the name.  It returns an empty string and ~0U for
259   /// invalid result numbers.  For example, in this operation:
260   ///
261   ///  %x, %y:2, %z = foo.op
262   ///
263   ///    getResultName(0) == {"x", 0 }
264   ///    getResultName(1) == {"y", 0 }
265   ///    getResultName(2) == {"y", 1 }
266   ///    getResultName(3) == {"z", 0 }
267   ///    getResultName(4) == {"", ~0U }
268   virtual std::pair<StringRef, unsigned>
269   getResultName(unsigned resultNo) const = 0;
270 
271   /// Return the number of declared SSA results.  This returns 4 for the foo.op
272   /// example in the comment for `getResultName`.
273   virtual size_t getNumResults() const = 0;
274 
275   /// Return the location of the original name token.
276   virtual llvm::SMLoc getNameLoc() const = 0;
277 
278   // These methods emit an error and return failure or success. This allows
279   // these to be chained together into a linear sequence of || expressions in
280   // many cases.
281 
282   /// Parse an operation in its generic form.
283   /// The parsed operation is parsed in the current context and inserted in the
284   /// provided block and insertion point. The results produced by this operation
285   /// aren't mapped to any named value in the parser. Returns nullptr on
286   /// failure.
287   virtual Operation *parseGenericOperation(Block *insertBlock,
288                                            Block::iterator insertPt) = 0;
289 
290   //===--------------------------------------------------------------------===//
291   // Token Parsing
292   //===--------------------------------------------------------------------===//
293 
294   /// Parse a '->' token.
295   virtual ParseResult parseArrow() = 0;
296 
  /// Parse a '->' token if present.
298   virtual ParseResult parseOptionalArrow() = 0;
299 
300   /// Parse a `{` token.
301   virtual ParseResult parseLBrace() = 0;
302 
303   /// Parse a `{` token if present.
304   virtual ParseResult parseOptionalLBrace() = 0;
305 
306   /// Parse a `}` token.
307   virtual ParseResult parseRBrace() = 0;
308 
309   /// Parse a `}` token if present.
310   virtual ParseResult parseOptionalRBrace() = 0;
311 
312   /// Parse a `:` token.
313   virtual ParseResult parseColon() = 0;
314 
315   /// Parse a `:` token if present.
316   virtual ParseResult parseOptionalColon() = 0;
317 
318   /// Parse a `,` token.
319   virtual ParseResult parseComma() = 0;
320 
321   /// Parse a `,` token if present.
322   virtual ParseResult parseOptionalComma() = 0;
323 
324   /// Parse a `=` token.
325   virtual ParseResult parseEqual() = 0;
326 
327   /// Parse a `=` token if present.
328   virtual ParseResult parseOptionalEqual() = 0;
329 
330   /// Parse a '<' token.
331   virtual ParseResult parseLess() = 0;
332 
333   /// Parse a '<' token if present.
334   virtual ParseResult parseOptionalLess() = 0;
335 
336   /// Parse a '>' token.
337   virtual ParseResult parseGreater() = 0;
338 
339   /// Parse a '>' token if present.
340   virtual ParseResult parseOptionalGreater() = 0;
341 
342   /// Parse a '?' token.
343   virtual ParseResult parseQuestion() = 0;
344 
345   /// Parse a '?' token if present.
346   virtual ParseResult parseOptionalQuestion() = 0;
347 
348   /// Parse a '+' token.
349   virtual ParseResult parsePlus() = 0;
350 
351   /// Parse a '+' token if present.
352   virtual ParseResult parseOptionalPlus() = 0;
353 
354   /// Parse a '*' token.
355   virtual ParseResult parseStar() = 0;
356 
357   /// Parse a '*' token if present.
358   virtual ParseResult parseOptionalStar() = 0;
359 
360   /// Parse a given keyword.
361   ParseResult parseKeyword(StringRef keyword, const Twine &msg = "") {
362     auto loc = getCurrentLocation();
363     if (parseOptionalKeyword(keyword))
364       return emitError(loc, "expected '") << keyword << "'" << msg;
365     return success();
366   }
367 
368   /// Parse a keyword into 'keyword'.
parseKeyword(StringRef * keyword)369   ParseResult parseKeyword(StringRef *keyword) {
370     auto loc = getCurrentLocation();
371     if (parseOptionalKeyword(keyword))
372       return emitError(loc, "expected valid keyword");
373     return success();
374   }
375 
376   /// Parse the given keyword if present.
377   virtual ParseResult parseOptionalKeyword(StringRef keyword) = 0;
378 
379   /// Parse a keyword, if present, into 'keyword'.
380   virtual ParseResult parseOptionalKeyword(StringRef *keyword) = 0;
381 
382   /// Parse a keyword, if present, and if one of the 'allowedValues',
383   /// into 'keyword'
384   virtual ParseResult
385   parseOptionalKeyword(StringRef *keyword,
386                        ArrayRef<StringRef> allowedValues) = 0;
387 
388   /// Parse a `(` token.
389   virtual ParseResult parseLParen() = 0;
390 
391   /// Parse a `(` token if present.
392   virtual ParseResult parseOptionalLParen() = 0;
393 
394   /// Parse a `)` token.
395   virtual ParseResult parseRParen() = 0;
396 
397   /// Parse a `)` token if present.
398   virtual ParseResult parseOptionalRParen() = 0;
399 
400   /// Parse a `[` token.
401   virtual ParseResult parseLSquare() = 0;
402 
403   /// Parse a `[` token if present.
404   virtual ParseResult parseOptionalLSquare() = 0;
405 
406   /// Parse a `]` token.
407   virtual ParseResult parseRSquare() = 0;
408 
409   /// Parse a `]` token if present.
410   virtual ParseResult parseOptionalRSquare() = 0;
411 
  /// Parse a `...` token if present.
413   virtual ParseResult parseOptionalEllipsis() = 0;
414 
415   /// Parse an integer value from the stream.
416   template <typename IntT>
parseInteger(IntT & result)417   ParseResult parseInteger(IntT &result) {
418     auto loc = getCurrentLocation();
419     OptionalParseResult parseResult = parseOptionalInteger(result);
420     if (!parseResult.hasValue())
421       return emitError(loc, "expected integer value");
422     return *parseResult;
423   }
424 
425   /// Parse an optional integer value from the stream.
426   virtual OptionalParseResult parseOptionalInteger(uint64_t &result) = 0;
427 
428   template <typename IntT>
parseOptionalInteger(IntT & result)429   OptionalParseResult parseOptionalInteger(IntT &result) {
430     auto loc = getCurrentLocation();
431 
432     // Parse the unsigned variant.
433     uint64_t uintResult;
434     OptionalParseResult parseResult = parseOptionalInteger(uintResult);
435     if (!parseResult.hasValue() || failed(*parseResult))
436       return parseResult;
437 
438     // Try to convert to the provided integer type.
439     result = IntT(uintResult);
440     if (uint64_t(result) != uintResult)
441       return emitError(loc, "integer value too large");
442     return success();
443   }
444 
445   //===--------------------------------------------------------------------===//
446   // Attribute Parsing
447   //===--------------------------------------------------------------------===//
448 
449   /// Parse an arbitrary attribute of a given type and return it in result.
450   virtual ParseResult parseAttribute(Attribute &result, Type type = {}) = 0;
451 
452   /// Parse an attribute of a specific kind and type.
453   template <typename AttrType>
454   ParseResult parseAttribute(AttrType &result, Type type = {}) {
455     llvm::SMLoc loc = getCurrentLocation();
456 
457     // Parse any kind of attribute.
458     Attribute attr;
459     if (parseAttribute(attr, type))
460       return failure();
461 
462     // Check for the right kind of attribute.
463     if (!(result = attr.dyn_cast<AttrType>()))
464       return emitError(loc, "invalid kind of attribute specified");
465 
466     return success();
467   }
468 
469   /// Parse an arbitrary attribute and return it in result.  This also adds the
470   /// attribute to the specified attribute list with the specified name.
parseAttribute(Attribute & result,StringRef attrName,NamedAttrList & attrs)471   ParseResult parseAttribute(Attribute &result, StringRef attrName,
472                              NamedAttrList &attrs) {
473     return parseAttribute(result, Type(), attrName, attrs);
474   }
475 
476   /// Parse an attribute of a specific kind and type.
477   template <typename AttrType>
parseAttribute(AttrType & result,StringRef attrName,NamedAttrList & attrs)478   ParseResult parseAttribute(AttrType &result, StringRef attrName,
479                              NamedAttrList &attrs) {
480     return parseAttribute(result, Type(), attrName, attrs);
481   }
482 
483   /// Parse an optional attribute.
484   virtual OptionalParseResult parseOptionalAttribute(Attribute &result,
485                                                      Type type,
486                                                      StringRef attrName,
487                                                      NamedAttrList &attrs) = 0;
488   template <typename AttrT>
parseOptionalAttribute(AttrT & result,StringRef attrName,NamedAttrList & attrs)489   OptionalParseResult parseOptionalAttribute(AttrT &result, StringRef attrName,
490                                              NamedAttrList &attrs) {
491     return parseOptionalAttribute(result, Type(), attrName, attrs);
492   }
493 
494   /// Specialized variants of `parseOptionalAttribute` that remove potential
495   /// ambiguities in syntax.
496   virtual OptionalParseResult parseOptionalAttribute(ArrayAttr &result,
497                                                      Type type,
498                                                      StringRef attrName,
499                                                      NamedAttrList &attrs) = 0;
500   virtual OptionalParseResult parseOptionalAttribute(StringAttr &result,
501                                                      Type type,
502                                                      StringRef attrName,
503                                                      NamedAttrList &attrs) = 0;
504 
505   /// Parse an arbitrary attribute of a given type and return it in result. This
506   /// also adds the attribute to the specified attribute list with the specified
507   /// name.
508   template <typename AttrType>
parseAttribute(AttrType & result,Type type,StringRef attrName,NamedAttrList & attrs)509   ParseResult parseAttribute(AttrType &result, Type type, StringRef attrName,
510                              NamedAttrList &attrs) {
511     llvm::SMLoc loc = getCurrentLocation();
512 
513     // Parse any kind of attribute.
514     Attribute attr;
515     if (parseAttribute(attr, type))
516       return failure();
517 
518     // Check for the right kind of attribute.
519     result = attr.dyn_cast<AttrType>();
520     if (!result)
521       return emitError(loc, "invalid kind of attribute specified");
522 
523     attrs.append(attrName, result);
524     return success();
525   }
526 
527   /// Parse a named dictionary into 'result' if it is present.
528   virtual ParseResult parseOptionalAttrDict(NamedAttrList &result) = 0;
529 
530   /// Parse a named dictionary into 'result' if the `attributes` keyword is
531   /// present.
532   virtual ParseResult
533   parseOptionalAttrDictWithKeyword(NamedAttrList &result) = 0;
534 
535   /// Parse an affine map instance into 'map'.
536   virtual ParseResult parseAffineMap(AffineMap &map) = 0;
537 
538   /// Parse an integer set instance into 'set'.
539   virtual ParseResult printIntegerSet(IntegerSet &set) = 0;
540 
541   //===--------------------------------------------------------------------===//
542   // Identifier Parsing
543   //===--------------------------------------------------------------------===//
544 
545   /// Parse an @-identifier and store it (without the '@' symbol) in a string
546   /// attribute named 'attrName'.
parseSymbolName(StringAttr & result,StringRef attrName,NamedAttrList & attrs)547   ParseResult parseSymbolName(StringAttr &result, StringRef attrName,
548                               NamedAttrList &attrs) {
549     if (failed(parseOptionalSymbolName(result, attrName, attrs)))
550       return emitError(getCurrentLocation())
551              << "expected valid '@'-identifier for symbol name";
552     return success();
553   }
554 
555   /// Parse an optional @-identifier and store it (without the '@' symbol) in a
556   /// string attribute named 'attrName'.
557   virtual ParseResult parseOptionalSymbolName(StringAttr &result,
558                                               StringRef attrName,
559                                               NamedAttrList &attrs) = 0;
560 
561   //===--------------------------------------------------------------------===//
562   // Operand Parsing
563   //===--------------------------------------------------------------------===//
564 
565   /// This is the representation of an operand reference.
566   struct OperandType {
567     llvm::SMLoc location; // Location of the token.
568     StringRef name;       // Value name, e.g. %42 or %abc
569     unsigned number;      // Number, e.g. 12 for an operand like %xyz#12
570   };
571 
572   /// Parse a single operand.
573   virtual ParseResult parseOperand(OperandType &result) = 0;
574 
575   /// Parse a single operand if present.
576   virtual OptionalParseResult parseOptionalOperand(OperandType &result) = 0;
577 
578   /// These are the supported delimiters around operand lists and region
579   /// argument lists, used by parseOperandList and parseRegionArgumentList.
  enum class Delimiter {
    /// Zero or more operands with no surrounding delimiters.
    None,
    /// Parentheses surrounding zero or more operands.
    Paren,
    /// Square brackets surrounding zero or more operands.
    Square,
    /// Either parentheses surrounding zero or more operands, or no
    /// parentheses at all.
    OptionalParen,
    /// Either square brackets surrounding zero or more operands, or no
    /// brackets at all.
    OptionalSquare,
  };
592 
593   /// Parse zero or more SSA comma-separated operand references with a specified
594   /// surrounding delimiter, and an optional required operand count.
595   virtual ParseResult
596   parseOperandList(SmallVectorImpl<OperandType> &result,
597                    int requiredOperandCount = -1,
598                    Delimiter delimiter = Delimiter::None) = 0;
parseOperandList(SmallVectorImpl<OperandType> & result,Delimiter delimiter)599   ParseResult parseOperandList(SmallVectorImpl<OperandType> &result,
600                                Delimiter delimiter) {
601     return parseOperandList(result, /*requiredOperandCount=*/-1, delimiter);
602   }
603 
  /// Parse zero or more trailing SSA comma-separated operand references with
  /// a specified surrounding delimiter, and an optional required operand
  /// count. A leading comma is expected before the operands.
607   virtual ParseResult
608   parseTrailingOperandList(SmallVectorImpl<OperandType> &result,
609                            int requiredOperandCount = -1,
610                            Delimiter delimiter = Delimiter::None) = 0;
parseTrailingOperandList(SmallVectorImpl<OperandType> & result,Delimiter delimiter)611   ParseResult parseTrailingOperandList(SmallVectorImpl<OperandType> &result,
612                                        Delimiter delimiter) {
613     return parseTrailingOperandList(result, /*requiredOperandCount=*/-1,
614                                     delimiter);
615   }
616 
617   /// Resolve an operand to an SSA value, emitting an error on failure.
618   virtual ParseResult resolveOperand(const OperandType &operand, Type type,
619                                      SmallVectorImpl<Value> &result) = 0;
620 
621   /// Resolve a list of operands to SSA values, emitting an error on failure, or
622   /// appending the results to the list on success. This method should be used
623   /// when all operands have the same type.
resolveOperands(ArrayRef<OperandType> operands,Type type,SmallVectorImpl<Value> & result)624   ParseResult resolveOperands(ArrayRef<OperandType> operands, Type type,
625                               SmallVectorImpl<Value> &result) {
626     for (auto elt : operands)
627       if (resolveOperand(elt, type, result))
628         return failure();
629     return success();
630   }
631 
632   /// Resolve a list of operands and a list of operand types to SSA values,
633   /// emitting an error and returning failure, or appending the results
634   /// to the list on success.
resolveOperands(ArrayRef<OperandType> operands,ArrayRef<Type> types,llvm::SMLoc loc,SmallVectorImpl<Value> & result)635   ParseResult resolveOperands(ArrayRef<OperandType> operands,
636                               ArrayRef<Type> types, llvm::SMLoc loc,
637                               SmallVectorImpl<Value> &result) {
638     if (operands.size() != types.size())
639       return emitError(loc)
640              << operands.size() << " operands present, but expected "
641              << types.size();
642 
643     for (unsigned i = 0, e = operands.size(); i != e; ++i)
644       if (resolveOperand(operands[i], types[i], result))
645         return failure();
646     return success();
647   }
  /// Resolve `operands` against the single type `type`, reporting a count
  /// mismatch at `loc`. `type` is wrapped in a one-element ArrayRef<Type>, and
  /// the resolveOperands overload that receives it compares the operand count
  /// with the type count, so `operands` must contain exactly one element here.
  /// To resolve many operands that all share one type, use the overload that
  /// takes no location.
  template <typename Operands>
  ParseResult resolveOperands(Operands &&operands, Type type, llvm::SMLoc loc,
                              SmallVectorImpl<Value> &result) {
    return resolveOperands(std::forward<Operands>(operands),
                           ArrayRef<Type>(type), loc, result);
  }
654   template <typename Operands, typename Types>
655   std::enable_if_t<!std::is_convertible<Types, Type>::value, ParseResult>
resolveOperands(Operands && operands,Types && types,llvm::SMLoc loc,SmallVectorImpl<Value> & result)656   resolveOperands(Operands &&operands, Types &&types, llvm::SMLoc loc,
657                   SmallVectorImpl<Value> &result) {
658     size_t operandSize = std::distance(operands.begin(), operands.end());
659     size_t typeSize = std::distance(types.begin(), types.end());
660     if (operandSize != typeSize)
661       return emitError(loc)
662              << operandSize << " operands present, but expected " << typeSize;
663 
664     for (auto it : llvm::zip(operands, types))
665       if (resolveOperand(std::get<0>(it), std::get<1>(it), result))
666         return failure();
667     return success();
668   }
669 
670   /// Parses an affine map attribute where dims and symbols are SSA operands.
671   /// Operand values must come from single-result sources, and be valid
672   /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
673   virtual ParseResult
674   parseAffineMapOfSSAIds(SmallVectorImpl<OperandType> &operands, Attribute &map,
675                          StringRef attrName, NamedAttrList &attrs,
676                          Delimiter delimiter = Delimiter::Square) = 0;
677 
678   //===--------------------------------------------------------------------===//
679   // Region Parsing
680   //===--------------------------------------------------------------------===//
681 
682   /// Parses a region. Any parsed blocks are appended to 'region' and must be
683   /// moved to the op regions after the op is created. The first block of the
684   /// region takes 'arguments' of types 'argTypes'. If 'enableNameShadowing' is
685   /// set to true, the argument names are allowed to shadow the names of other
686   /// existing SSA values defined above the region scope. 'enableNameShadowing'
687   /// can only be set to true for regions attached to operations that are
  /// 'IsolatedFromAbove'.
689   virtual ParseResult parseRegion(Region &region,
690                                   ArrayRef<OperandType> arguments = {},
691                                   ArrayRef<Type> argTypes = {},
692                                   bool enableNameShadowing = false) = 0;
693 
694   /// Parses a region if present.
695   virtual OptionalParseResult
696   parseOptionalRegion(Region &region, ArrayRef<OperandType> arguments = {},
697                       ArrayRef<Type> argTypes = {},
698                       bool enableNameShadowing = false) = 0;
699 
700   /// Parses a region if present. If the region is present, a new region is
701   /// allocated and placed in `region`. If no region is present or on failure,
702   /// `region` remains untouched.
703   virtual OptionalParseResult parseOptionalRegion(
704       std::unique_ptr<Region> &region, ArrayRef<OperandType> arguments = {},
705       ArrayRef<Type> argTypes = {}, bool enableNameShadowing = false) = 0;
706 
707   /// Parse a region argument, this argument is resolved when calling
708   /// 'parseRegion'.
709   virtual ParseResult parseRegionArgument(OperandType &argument) = 0;
710 
711   /// Parse zero or more region arguments with a specified surrounding
712   /// delimiter, and an optional required argument count. Region arguments
713   /// define new values; so this also checks if values with the same names have
714   /// not been defined yet.
715   virtual ParseResult
716   parseRegionArgumentList(SmallVectorImpl<OperandType> &result,
717                           int requiredOperandCount = -1,
718                           Delimiter delimiter = Delimiter::None) = 0;
719   virtual ParseResult
parseRegionArgumentList(SmallVectorImpl<OperandType> & result,Delimiter delimiter)720   parseRegionArgumentList(SmallVectorImpl<OperandType> &result,
721                           Delimiter delimiter) {
722     return parseRegionArgumentList(result, /*requiredOperandCount=*/-1,
723                                    delimiter);
724   }
725 
  /// Parse a region argument into `argument` if one is present.
  /// NOTE(review): unlike the other "optional" hooks in this class, this one
  /// returns a plain ParseResult rather than an OptionalParseResult, so
  /// "absent" cannot be distinguished from "parsed" via the return value
  /// alone. Confirm the intended contract with the implementation.
  virtual ParseResult parseOptionalRegionArgument(OperandType &argument) = 0;
728 
729   //===--------------------------------------------------------------------===//
730   // Successor Parsing
731   //===--------------------------------------------------------------------===//
732 
  /// Parse a single operation successor, storing the destination block in
  /// `dest`.
  virtual ParseResult parseSuccessor(Block *&dest) = 0;

  /// Parse an operation successor if one is present, storing it in `dest`.
  /// The result has no value when no successor is present.
  virtual OptionalParseResult parseOptionalSuccessor(Block *&dest) = 0;

  /// Parse a single operation successor together with its operand list. The
  /// destination block is stored in `dest`, and the successor's operands are
  /// returned through `operands` (whether they are appended or replace the
  /// contents is up to the implementation -- TODO confirm).
  virtual ParseResult
  parseSuccessorAndUseList(Block *&dest, SmallVectorImpl<Value> &operands) = 0;
742 
743   //===--------------------------------------------------------------------===//
744   // Type Parsing
745   //===--------------------------------------------------------------------===//
746 
  /// Parse a type of any kind into `result`.
  virtual ParseResult parseType(Type &result) = 0;

  /// Parse a type into `result` if one is present. The result has no value
  /// when no type is present.
  virtual OptionalParseResult parseOptionalType(Type &result) = 0;
752 
753   /// Parse a type of a specific type.
754   template <typename TypeT>
parseType(TypeT & result)755   ParseResult parseType(TypeT &result) {
756     llvm::SMLoc loc = getCurrentLocation();
757 
758     // Parse any kind of type.
759     Type type;
760     if (parseType(type))
761       return failure();
762 
763     // Check for the right kind of attribute.
764     result = type.dyn_cast<TypeT>();
765     if (!result)
766       return emitError(loc, "invalid kind of type specified");
767 
768     return success();
769   }
770 
771   /// Parse a type list.
parseTypeList(SmallVectorImpl<Type> & result)772   ParseResult parseTypeList(SmallVectorImpl<Type> &result) {
773     do {
774       Type type;
775       if (parseType(type))
776         return failure();
777       result.push_back(type);
778     } while (succeeded(parseOptionalComma()));
779     return success();
780   }
781 
  /// Parse an arrow (`->`) followed by a type list, collecting the types into
  /// `result`.
  virtual ParseResult parseArrowTypeList(SmallVectorImpl<Type> &result) = 0;

  /// Parse an arrow followed by a type list if the arrow is present, collecting
  /// the types into `result`.
  virtual ParseResult
  parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) = 0;

  /// Parse a colon followed by a type of any kind into `result`.
  virtual ParseResult parseColonType(Type &result) = 0;
791 
792   /// Parse a colon followed by a type of a specific kind, e.g. a FunctionType.
793   template <typename TypeType>
parseColonType(TypeType & result)794   ParseResult parseColonType(TypeType &result) {
795     llvm::SMLoc loc = getCurrentLocation();
796 
797     // Parse any kind of type.
798     Type type;
799     if (parseColonType(type))
800       return failure();
801 
802     // Check for the right kind of attribute.
803     result = type.dyn_cast<TypeType>();
804     if (!result)
805       return emitError(loc, "invalid kind of type specified");
806 
807     return success();
808   }
809 
  /// Parse a colon followed by a type list, collecting the types into
  /// `result`. The list must contain at least one type.
  virtual ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) = 0;

  /// Parse a colon followed by a type list if the colon is present, collecting
  /// the types into `result`. When present, the list must contain at least one
  /// type.
  virtual ParseResult
  parseOptionalColonTypeList(SmallVectorImpl<Type> &result) = 0;
817 
818   /// Parse a list of assignments of the form
819   ///   (%x1 = %y1, %x2 = %y2, ...)
parseAssignmentList(SmallVectorImpl<OperandType> & lhs,SmallVectorImpl<OperandType> & rhs)820   ParseResult parseAssignmentList(SmallVectorImpl<OperandType> &lhs,
821                                   SmallVectorImpl<OperandType> &rhs) {
822     OptionalParseResult result = parseOptionalAssignmentList(lhs, rhs);
823     if (!result.hasValue())
824       return emitError(getCurrentLocation(), "expected '('");
825     return result.getValue();
826   }
827 
  /// Parse a list of assignments of the form (%x1 = %y1, %x2 = %y2, ...) if
  /// one is present. The result has no value when the list is absent.
  /// Presumably the left-hand operands are collected into `lhs` and the
  /// right-hand operands into `rhs` (TODO confirm against the implementation).
  virtual OptionalParseResult
  parseOptionalAssignmentList(SmallVectorImpl<OperandType> &lhs,
                              SmallVectorImpl<OperandType> &rhs) = 0;
831 
832   /// Parse a keyword followed by a type.
parseKeywordType(const char * keyword,Type & result)833   ParseResult parseKeywordType(const char *keyword, Type &result) {
834     return failure(parseKeyword(keyword) || parseType(result));
835   }
836 
837   /// Add the specified type to the end of the specified type list and return
838   /// success.  This is a helper designed to allow parse methods to be simple
839   /// and chain through || operators.
addTypeToList(Type type,SmallVectorImpl<Type> & result)840   ParseResult addTypeToList(Type type, SmallVectorImpl<Type> &result) {
841     result.push_back(type);
842     return success();
843   }
844 
845   /// Add the specified types to the end of the specified type list and return
846   /// success.  This is a helper designed to allow parse methods to be simple
847   /// and chain through || operators.
addTypesToList(ArrayRef<Type> types,SmallVectorImpl<Type> & result)848   ParseResult addTypesToList(ArrayRef<Type> types,
849                              SmallVectorImpl<Type> &result) {
850     result.append(types.begin(), types.end());
851     return success();
852   }
853 
private:
  /// Shared implementation for operand lists and region argument lists: parses
  /// an operand list when `isOperandList` is true, otherwise a region argument
  /// list, collecting the entries into `result`. `requiredOperandCount` and
  /// `delimiter` presumably carry the same meaning as in
  /// parseRegionArgumentList (TODO confirm against the definition).
  ParseResult parseOperandOrRegionArgList(SmallVectorImpl<OperandType> &result,
                                          bool isOperandList,
                                          int requiredOperandCount,
                                          Delimiter delimiter);
861 };
862 
863 //===--------------------------------------------------------------------===//
864 // Dialect OpAsm interface.
865 //===--------------------------------------------------------------------===//
866 
/// Callback used by the OpAsm naming hooks to assign a printed name to a value.
/// It is invoked with the value and the name to use for it. It is passed to
/// 'getAsmResultNames' (to name the start of a result group of an operation)
/// and 'getAsmBlockArgumentNames' (to name entry block arguments) below.
/// llvm::function_ref is a non-owning reference to a callable, so it should not
/// be stored beyond the duration of the hook call.
using OpAsmSetValueNameFn = function_ref<void(Value, StringRef)>;
870 
871 class OpAsmDialectInterface
872     : public DialectInterface::Base<OpAsmDialectInterface> {
873 public:
OpAsmDialectInterface(Dialect * dialect)874   OpAsmDialectInterface(Dialect *dialect) : Base(dialect) {}
875 
876   /// Hooks for getting an alias identifier alias for a given symbol, that is
877   /// not necessarily a part of this dialect. The identifier is used in place of
878   /// the symbol when printing textual IR. These aliases must not contain `.` or
879   /// end with a numeric digit([0-9]+). Returns success if an alias was
880   /// provided, failure otherwise.
getAlias(Attribute attr,raw_ostream & os)881   virtual LogicalResult getAlias(Attribute attr, raw_ostream &os) const {
882     return failure();
883   }
getAlias(Type type,raw_ostream & os)884   virtual LogicalResult getAlias(Type type, raw_ostream &os) const {
885     return failure();
886   }
887 
888   /// Get a special name to use when printing the given operation. See
889   /// OpAsmInterface.td#getAsmResultNames for usage details and documentation.
getAsmResultNames(Operation * op,OpAsmSetValueNameFn setNameFn)890   virtual void getAsmResultNames(Operation *op,
891                                  OpAsmSetValueNameFn setNameFn) const {}
892 
893   /// Get a special name to use when printing the entry block arguments of the
894   /// region contained by an operation in this dialect.
getAsmBlockArgumentNames(Block * block,OpAsmSetValueNameFn setNameFn)895   virtual void getAsmBlockArgumentNames(Block *block,
896                                         OpAsmSetValueNameFn setNameFn) const {}
897 };
898 } // end namespace mlir
899 
900 //===--------------------------------------------------------------------===//
901 // Operation OpAsm interface.
902 //===--------------------------------------------------------------------===//
903 
904 /// The OpAsmOpInterface, see OpAsmInterface.td for more details.
905 #include "mlir/IR/OpAsmInterface.h.inc"
906 
907 #endif
908