1 //===- DialectSymbolParser.cpp - MLIR Dialect Symbol Parser  --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the parser for the dialect symbols, such as extended
10 // attributes and types.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "AsmParserImpl.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Dialect.h"
17 #include "mlir/IR/DialectImplementation.h"
18 #include "llvm/Support/SourceMgr.h"
19 
20 using namespace mlir;
21 using namespace mlir::detail;
22 using llvm::MemoryBuffer;
23 using llvm::SMLoc;
24 using llvm::SourceMgr;
25 
26 namespace {
27 /// This class provides the main implementation of the DialectAsmParser that
28 /// allows for dialects to parse attributes and types. This allows for dialect
29 /// hooking into the main MLIR parsing logic.
30 class CustomDialectAsmParser : public AsmParserImpl<DialectAsmParser> {
31 public:
CustomDialectAsmParser(StringRef fullSpec,Parser & parser)32   CustomDialectAsmParser(StringRef fullSpec, Parser &parser)
33       : AsmParserImpl<DialectAsmParser>(parser.getToken().getLoc(), parser),
34         fullSpec(fullSpec) {}
~CustomDialectAsmParser()35   ~CustomDialectAsmParser() override {}
36 
37   /// Returns the full specification of the symbol being parsed. This allows
38   /// for using a separate parser if necessary.
getFullSymbolSpec() const39   StringRef getFullSymbolSpec() const override { return fullSpec; }
40 
41 private:
42   /// The full symbol specification.
43   StringRef fullSpec;
44 };
45 } // namespace
46 
47 /// Parse the body of a pretty dialect symbol, which starts and ends with <>'s,
48 /// and may be recursive.  Return with the 'prettyName' StringRef encompassing
49 /// the entire pretty name.
50 ///
51 ///   pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>'
52 ///   pretty-dialect-sym-contents ::= pretty-dialect-sym-body
53 ///                                  | '(' pretty-dialect-sym-contents+ ')'
54 ///                                  | '[' pretty-dialect-sym-contents+ ']'
55 ///                                  | '{' pretty-dialect-sym-contents+ '}'
56 ///                                  | '[^[<({>\])}\0]+'
57 ///
parsePrettyDialectSymbolName(StringRef & prettyName)58 ParseResult Parser::parsePrettyDialectSymbolName(StringRef &prettyName) {
59   // Pretty symbol names are a relatively unstructured format that contains a
60   // series of properly nested punctuation, with anything else in the middle.
61   // Scan ahead to find it and consume it if successful, otherwise emit an
62   // error.
63   auto *curPtr = getTokenSpelling().data();
64 
65   SmallVector<char, 8> nestedPunctuation;
66 
67   // Scan over the nested punctuation, bailing out on error and consuming until
68   // we find the end.  We know that we're currently looking at the '<', so we
69   // can go until we find the matching '>' character.
70   assert(*curPtr == '<');
71   do {
72     char c = *curPtr++;
73     switch (c) {
74     case '\0':
75       // This also handles the EOF case.
76       return emitError("unexpected nul or EOF in pretty dialect name");
77     case '<':
78     case '[':
79     case '(':
80     case '{':
81       nestedPunctuation.push_back(c);
82       continue;
83 
84     case '-':
85       // The sequence `->` is treated as special token.
86       if (*curPtr == '>')
87         ++curPtr;
88       continue;
89 
90     case '>':
91       if (nestedPunctuation.pop_back_val() != '<')
92         return emitError("unbalanced '>' character in pretty dialect name");
93       break;
94     case ']':
95       if (nestedPunctuation.pop_back_val() != '[')
96         return emitError("unbalanced ']' character in pretty dialect name");
97       break;
98     case ')':
99       if (nestedPunctuation.pop_back_val() != '(')
100         return emitError("unbalanced ')' character in pretty dialect name");
101       break;
102     case '}':
103       if (nestedPunctuation.pop_back_val() != '{')
104         return emitError("unbalanced '}' character in pretty dialect name");
105       break;
106 
107     default:
108       continue;
109     }
110   } while (!nestedPunctuation.empty());
111 
112   // Ok, we succeeded, remember where we stopped, reset the lexer to know it is
113   // consuming all this stuff, and return.
114   state.lex.resetPointer(curPtr);
115 
116   unsigned length = curPtr - prettyName.begin();
117   prettyName = StringRef(prettyName.begin(), length);
118   consumeToken();
119   return success();
120 }
121 
122 /// Parse an extended dialect symbol.
123 template <typename Symbol, typename SymbolAliasMap, typename CreateFn>
parseExtendedSymbol(Parser & p,Token::Kind identifierTok,SymbolAliasMap & aliases,CreateFn && createSymbol)124 static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok,
125                                   SymbolAliasMap &aliases,
126                                   CreateFn &&createSymbol) {
127   // Parse the dialect namespace.
128   StringRef identifier = p.getTokenSpelling().drop_front();
129   auto loc = p.getToken().getLoc();
130   p.consumeToken(identifierTok);
131 
132   // If there is no '<' token following this, and if the typename contains no
133   // dot, then we are parsing a symbol alias.
134   if (p.getToken().isNot(Token::less) && !identifier.contains('.')) {
135     // Check for an alias for this type.
136     auto aliasIt = aliases.find(identifier);
137     if (aliasIt == aliases.end())
138       return (p.emitError("undefined symbol alias id '" + identifier + "'"),
139               nullptr);
140     return aliasIt->second;
141   }
142 
143   // Otherwise, we are parsing a dialect-specific symbol.  If the name contains
144   // a dot, then this is the "pretty" form.  If not, it is the verbose form that
145   // looks like <"...">.
146   std::string symbolData;
147   auto dialectName = identifier;
148 
149   // Handle the verbose form, where "identifier" is a simple dialect name.
150   if (!identifier.contains('.')) {
151     // Consume the '<'.
152     if (p.parseToken(Token::less, "expected '<' in dialect type"))
153       return nullptr;
154 
155     // Parse the symbol specific data.
156     if (p.getToken().isNot(Token::string))
157       return (p.emitError("expected string literal data in dialect symbol"),
158               nullptr);
159     symbolData = p.getToken().getStringValue();
160     loc = llvm::SMLoc::getFromPointer(p.getToken().getLoc().getPointer() + 1);
161     p.consumeToken(Token::string);
162 
163     // Consume the '>'.
164     if (p.parseToken(Token::greater, "expected '>' in dialect symbol"))
165       return nullptr;
166   } else {
167     // Ok, the dialect name is the part of the identifier before the dot, the
168     // part after the dot is the dialect's symbol, or the start thereof.
169     auto dotHalves = identifier.split('.');
170     dialectName = dotHalves.first;
171     auto prettyName = dotHalves.second;
172     loc = llvm::SMLoc::getFromPointer(prettyName.data());
173 
174     // If the dialect's symbol is followed immediately by a <, then lex the body
175     // of it into prettyName.
176     if (p.getToken().is(Token::less) &&
177         prettyName.bytes_end() == p.getTokenSpelling().bytes_begin()) {
178       if (p.parsePrettyDialectSymbolName(prettyName))
179         return nullptr;
180     }
181 
182     symbolData = prettyName.str();
183   }
184 
185   // Record the name location of the type remapped to the top level buffer.
186   llvm::SMLoc locInTopLevelBuffer = p.remapLocationToTopLevelBuffer(loc);
187   p.getState().symbols.nestedParserLocs.push_back(locInTopLevelBuffer);
188 
189   // Call into the provided symbol construction function.
190   Symbol sym = createSymbol(dialectName, symbolData, loc);
191 
192   // Pop the last parser location.
193   p.getState().symbols.nestedParserLocs.pop_back();
194   return sym;
195 }
196 
197 /// Parses a symbol, of type 'T', and returns it if parsing was successful. If
198 /// parsing failed, nullptr is returned. The number of bytes read from the input
199 /// string is returned in 'numRead'.
200 template <typename T, typename ParserFn>
parseSymbol(StringRef inputStr,MLIRContext * context,SymbolState & symbolState,ParserFn && parserFn,size_t * numRead=nullptr)201 static T parseSymbol(StringRef inputStr, MLIRContext *context,
202                      SymbolState &symbolState, ParserFn &&parserFn,
203                      size_t *numRead = nullptr) {
204   SourceMgr sourceMgr;
205   auto memBuffer = MemoryBuffer::getMemBuffer(
206       inputStr, /*BufferName=*/"<mlir_parser_buffer>",
207       /*RequiresNullTerminator=*/false);
208   sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
209   ParserState state(sourceMgr, context, symbolState, /*asmState=*/nullptr);
210   Parser parser(state);
211 
212   Token startTok = parser.getToken();
213   T symbol = parserFn(parser);
214   if (!symbol)
215     return T();
216 
217   // If 'numRead' is valid, then provide the number of bytes that were read.
218   Token endTok = parser.getToken();
219   if (numRead) {
220     *numRead = static_cast<size_t>(endTok.getLoc().getPointer() -
221                                    startTok.getLoc().getPointer());
222 
223     // Otherwise, ensure that all of the tokens were parsed.
224   } else if (startTok.getLoc() != endTok.getLoc() && endTok.isNot(Token::eof)) {
225     parser.emitError(endTok.getLoc(), "encountered unexpected token");
226     return T();
227   }
228   return symbol;
229 }
230 
231 /// Parse an extended attribute.
232 ///
233 ///   extended-attribute ::= (dialect-attribute | attribute-alias)
234 ///   dialect-attribute  ::= `#` dialect-namespace `<` `"` attr-data `"` `>`
235 ///   dialect-attribute  ::= `#` alias-name pretty-dialect-sym-body?
236 ///   attribute-alias    ::= `#` alias-name
237 ///
parseExtendedAttr(Type type)238 Attribute Parser::parseExtendedAttr(Type type) {
239   Attribute attr = parseExtendedSymbol<Attribute>(
240       *this, Token::hash_identifier, state.symbols.attributeAliasDefinitions,
241       [&](StringRef dialectName, StringRef symbolData,
242           llvm::SMLoc loc) -> Attribute {
243         // Parse an optional trailing colon type.
244         Type attrType = type;
245         if (consumeIf(Token::colon) && !(attrType = parseType()))
246           return Attribute();
247 
248         // If we found a registered dialect, then ask it to parse the attribute.
249         if (Dialect *dialect =
250                 builder.getContext()->getOrLoadDialect(dialectName)) {
251           return parseSymbol<Attribute>(
252               symbolData, state.context, state.symbols, [&](Parser &parser) {
253                 CustomDialectAsmParser customParser(symbolData, parser);
254                 return dialect->parseAttribute(customParser, attrType);
255               });
256         }
257 
258         // Otherwise, form a new opaque attribute.
259         return OpaqueAttr::getChecked(
260             [&] { return emitError(loc); },
261             Identifier::get(dialectName, state.context), symbolData,
262             attrType ? attrType : NoneType::get(state.context));
263       });
264 
265   // Ensure that the attribute has the same type as requested.
266   if (attr && type && attr.getType() != type) {
267     emitError("attribute type different than expected: expected ")
268         << type << ", but got " << attr.getType();
269     return nullptr;
270   }
271   return attr;
272 }
273 
274 /// Parse an extended type.
275 ///
276 ///   extended-type ::= (dialect-type | type-alias)
277 ///   dialect-type  ::= `!` dialect-namespace `<` `"` type-data `"` `>`
278 ///   dialect-type  ::= `!` alias-name pretty-dialect-attribute-body?
279 ///   type-alias    ::= `!` alias-name
280 ///
parseExtendedType()281 Type Parser::parseExtendedType() {
282   return parseExtendedSymbol<Type>(
283       *this, Token::exclamation_identifier, state.symbols.typeAliasDefinitions,
284       [&](StringRef dialectName, StringRef symbolData,
285           llvm::SMLoc loc) -> Type {
286         // If we found a registered dialect, then ask it to parse the type.
287         auto *dialect = state.context->getOrLoadDialect(dialectName);
288 
289         if (dialect) {
290           return parseSymbol<Type>(
291               symbolData, state.context, state.symbols, [&](Parser &parser) {
292                 CustomDialectAsmParser customParser(symbolData, parser);
293                 return dialect->parseType(customParser);
294               });
295         }
296 
297         // Otherwise, form a new opaque type.
298         return OpaqueType::getChecked(
299             [&] { return emitError(loc); },
300             Identifier::get(dialectName, state.context), symbolData);
301       });
302 }
303 
304 //===----------------------------------------------------------------------===//
305 // mlir::parseAttribute/parseType
306 //===----------------------------------------------------------------------===//
307 
308 /// Parses a symbol, of type 'T', and returns it if parsing was successful. If
309 /// parsing failed, nullptr is returned. The number of bytes read from the input
310 /// string is returned in 'numRead'.
311 template <typename T, typename ParserFn>
parseSymbol(StringRef inputStr,MLIRContext * context,size_t & numRead,ParserFn && parserFn)312 static T parseSymbol(StringRef inputStr, MLIRContext *context, size_t &numRead,
313                      ParserFn &&parserFn) {
314   SymbolState aliasState;
315   return parseSymbol<T>(
316       inputStr, context, aliasState,
317       [&](Parser &parser) {
318         SourceMgrDiagnosticHandler handler(
319             const_cast<llvm::SourceMgr &>(parser.getSourceMgr()),
320             parser.getContext());
321         return parserFn(parser);
322       },
323       &numRead);
324 }
325 
parseAttribute(StringRef attrStr,MLIRContext * context)326 Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context) {
327   size_t numRead = 0;
328   return parseAttribute(attrStr, context, numRead);
329 }
parseAttribute(StringRef attrStr,Type type)330 Attribute mlir::parseAttribute(StringRef attrStr, Type type) {
331   size_t numRead = 0;
332   return parseAttribute(attrStr, type, numRead);
333 }
334 
parseAttribute(StringRef attrStr,MLIRContext * context,size_t & numRead)335 Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context,
336                                size_t &numRead) {
337   return parseSymbol<Attribute>(attrStr, context, numRead, [](Parser &parser) {
338     return parser.parseAttribute();
339   });
340 }
parseAttribute(StringRef attrStr,Type type,size_t & numRead)341 Attribute mlir::parseAttribute(StringRef attrStr, Type type, size_t &numRead) {
342   return parseSymbol<Attribute>(
343       attrStr, type.getContext(), numRead,
344       [type](Parser &parser) { return parser.parseAttribute(type); });
345 }
346 
parseType(StringRef typeStr,MLIRContext * context)347 Type mlir::parseType(StringRef typeStr, MLIRContext *context) {
348   size_t numRead = 0;
349   return parseType(typeStr, context, numRead);
350 }
351 
parseType(StringRef typeStr,MLIRContext * context,size_t & numRead)352 Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t &numRead) {
353   return parseSymbol<Type>(typeStr, context, numRead,
354                            [](Parser &parser) { return parser.parseType(); });
355 }
356