1 //===- DialectSymbolParser.cpp - MLIR Dialect Symbol Parser --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the parser for the dialect symbols, such as extended
10 // attributes and types.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "AsmParserImpl.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Dialect.h"
17 #include "mlir/IR/DialectImplementation.h"
18 #include "llvm/Support/SourceMgr.h"
19
20 using namespace mlir;
21 using namespace mlir::detail;
22 using llvm::MemoryBuffer;
23 using llvm::SMLoc;
24 using llvm::SourceMgr;
25
26 namespace {
27 /// This class provides the main implementation of the DialectAsmParser that
28 /// allows for dialects to parse attributes and types. This allows for dialect
29 /// hooking into the main MLIR parsing logic.
30 class CustomDialectAsmParser : public AsmParserImpl<DialectAsmParser> {
31 public:
CustomDialectAsmParser(StringRef fullSpec,Parser & parser)32 CustomDialectAsmParser(StringRef fullSpec, Parser &parser)
33 : AsmParserImpl<DialectAsmParser>(parser.getToken().getLoc(), parser),
34 fullSpec(fullSpec) {}
~CustomDialectAsmParser()35 ~CustomDialectAsmParser() override {}
36
37 /// Returns the full specification of the symbol being parsed. This allows
38 /// for using a separate parser if necessary.
getFullSymbolSpec() const39 StringRef getFullSymbolSpec() const override { return fullSpec; }
40
41 private:
42 /// The full symbol specification.
43 StringRef fullSpec;
44 };
45 } // namespace
46
47 /// Parse the body of a pretty dialect symbol, which starts and ends with <>'s,
48 /// and may be recursive. Return with the 'prettyName' StringRef encompassing
49 /// the entire pretty name.
50 ///
51 /// pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>'
52 /// pretty-dialect-sym-contents ::= pretty-dialect-sym-body
53 /// | '(' pretty-dialect-sym-contents+ ')'
54 /// | '[' pretty-dialect-sym-contents+ ']'
55 /// | '{' pretty-dialect-sym-contents+ '}'
56 /// | '[^[<({>\])}\0]+'
57 ///
parsePrettyDialectSymbolName(StringRef & prettyName)58 ParseResult Parser::parsePrettyDialectSymbolName(StringRef &prettyName) {
59 // Pretty symbol names are a relatively unstructured format that contains a
60 // series of properly nested punctuation, with anything else in the middle.
61 // Scan ahead to find it and consume it if successful, otherwise emit an
62 // error.
63 auto *curPtr = getTokenSpelling().data();
64
65 SmallVector<char, 8> nestedPunctuation;
66
67 // Scan over the nested punctuation, bailing out on error and consuming until
68 // we find the end. We know that we're currently looking at the '<', so we
69 // can go until we find the matching '>' character.
70 assert(*curPtr == '<');
71 do {
72 char c = *curPtr++;
73 switch (c) {
74 case '\0':
75 // This also handles the EOF case.
76 return emitError("unexpected nul or EOF in pretty dialect name");
77 case '<':
78 case '[':
79 case '(':
80 case '{':
81 nestedPunctuation.push_back(c);
82 continue;
83
84 case '-':
85 // The sequence `->` is treated as special token.
86 if (*curPtr == '>')
87 ++curPtr;
88 continue;
89
90 case '>':
91 if (nestedPunctuation.pop_back_val() != '<')
92 return emitError("unbalanced '>' character in pretty dialect name");
93 break;
94 case ']':
95 if (nestedPunctuation.pop_back_val() != '[')
96 return emitError("unbalanced ']' character in pretty dialect name");
97 break;
98 case ')':
99 if (nestedPunctuation.pop_back_val() != '(')
100 return emitError("unbalanced ')' character in pretty dialect name");
101 break;
102 case '}':
103 if (nestedPunctuation.pop_back_val() != '{')
104 return emitError("unbalanced '}' character in pretty dialect name");
105 break;
106
107 default:
108 continue;
109 }
110 } while (!nestedPunctuation.empty());
111
112 // Ok, we succeeded, remember where we stopped, reset the lexer to know it is
113 // consuming all this stuff, and return.
114 state.lex.resetPointer(curPtr);
115
116 unsigned length = curPtr - prettyName.begin();
117 prettyName = StringRef(prettyName.begin(), length);
118 consumeToken();
119 return success();
120 }
121
122 /// Parse an extended dialect symbol.
123 template <typename Symbol, typename SymbolAliasMap, typename CreateFn>
parseExtendedSymbol(Parser & p,Token::Kind identifierTok,SymbolAliasMap & aliases,CreateFn && createSymbol)124 static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok,
125 SymbolAliasMap &aliases,
126 CreateFn &&createSymbol) {
127 // Parse the dialect namespace.
128 StringRef identifier = p.getTokenSpelling().drop_front();
129 auto loc = p.getToken().getLoc();
130 p.consumeToken(identifierTok);
131
132 // If there is no '<' token following this, and if the typename contains no
133 // dot, then we are parsing a symbol alias.
134 if (p.getToken().isNot(Token::less) && !identifier.contains('.')) {
135 // Check for an alias for this type.
136 auto aliasIt = aliases.find(identifier);
137 if (aliasIt == aliases.end())
138 return (p.emitError("undefined symbol alias id '" + identifier + "'"),
139 nullptr);
140 return aliasIt->second;
141 }
142
143 // Otherwise, we are parsing a dialect-specific symbol. If the name contains
144 // a dot, then this is the "pretty" form. If not, it is the verbose form that
145 // looks like <"...">.
146 std::string symbolData;
147 auto dialectName = identifier;
148
149 // Handle the verbose form, where "identifier" is a simple dialect name.
150 if (!identifier.contains('.')) {
151 // Consume the '<'.
152 if (p.parseToken(Token::less, "expected '<' in dialect type"))
153 return nullptr;
154
155 // Parse the symbol specific data.
156 if (p.getToken().isNot(Token::string))
157 return (p.emitError("expected string literal data in dialect symbol"),
158 nullptr);
159 symbolData = p.getToken().getStringValue();
160 loc = llvm::SMLoc::getFromPointer(p.getToken().getLoc().getPointer() + 1);
161 p.consumeToken(Token::string);
162
163 // Consume the '>'.
164 if (p.parseToken(Token::greater, "expected '>' in dialect symbol"))
165 return nullptr;
166 } else {
167 // Ok, the dialect name is the part of the identifier before the dot, the
168 // part after the dot is the dialect's symbol, or the start thereof.
169 auto dotHalves = identifier.split('.');
170 dialectName = dotHalves.first;
171 auto prettyName = dotHalves.second;
172 loc = llvm::SMLoc::getFromPointer(prettyName.data());
173
174 // If the dialect's symbol is followed immediately by a <, then lex the body
175 // of it into prettyName.
176 if (p.getToken().is(Token::less) &&
177 prettyName.bytes_end() == p.getTokenSpelling().bytes_begin()) {
178 if (p.parsePrettyDialectSymbolName(prettyName))
179 return nullptr;
180 }
181
182 symbolData = prettyName.str();
183 }
184
185 // Record the name location of the type remapped to the top level buffer.
186 llvm::SMLoc locInTopLevelBuffer = p.remapLocationToTopLevelBuffer(loc);
187 p.getState().symbols.nestedParserLocs.push_back(locInTopLevelBuffer);
188
189 // Call into the provided symbol construction function.
190 Symbol sym = createSymbol(dialectName, symbolData, loc);
191
192 // Pop the last parser location.
193 p.getState().symbols.nestedParserLocs.pop_back();
194 return sym;
195 }
196
197 /// Parses a symbol, of type 'T', and returns it if parsing was successful. If
198 /// parsing failed, nullptr is returned. The number of bytes read from the input
199 /// string is returned in 'numRead'.
200 template <typename T, typename ParserFn>
parseSymbol(StringRef inputStr,MLIRContext * context,SymbolState & symbolState,ParserFn && parserFn,size_t * numRead=nullptr)201 static T parseSymbol(StringRef inputStr, MLIRContext *context,
202 SymbolState &symbolState, ParserFn &&parserFn,
203 size_t *numRead = nullptr) {
204 SourceMgr sourceMgr;
205 auto memBuffer = MemoryBuffer::getMemBuffer(
206 inputStr, /*BufferName=*/"<mlir_parser_buffer>",
207 /*RequiresNullTerminator=*/false);
208 sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
209 ParserState state(sourceMgr, context, symbolState, /*asmState=*/nullptr);
210 Parser parser(state);
211
212 Token startTok = parser.getToken();
213 T symbol = parserFn(parser);
214 if (!symbol)
215 return T();
216
217 // If 'numRead' is valid, then provide the number of bytes that were read.
218 Token endTok = parser.getToken();
219 if (numRead) {
220 *numRead = static_cast<size_t>(endTok.getLoc().getPointer() -
221 startTok.getLoc().getPointer());
222
223 // Otherwise, ensure that all of the tokens were parsed.
224 } else if (startTok.getLoc() != endTok.getLoc() && endTok.isNot(Token::eof)) {
225 parser.emitError(endTok.getLoc(), "encountered unexpected token");
226 return T();
227 }
228 return symbol;
229 }
230
231 /// Parse an extended attribute.
232 ///
233 /// extended-attribute ::= (dialect-attribute | attribute-alias)
234 /// dialect-attribute ::= `#` dialect-namespace `<` `"` attr-data `"` `>`
235 /// dialect-attribute ::= `#` alias-name pretty-dialect-sym-body?
236 /// attribute-alias ::= `#` alias-name
237 ///
parseExtendedAttr(Type type)238 Attribute Parser::parseExtendedAttr(Type type) {
239 Attribute attr = parseExtendedSymbol<Attribute>(
240 *this, Token::hash_identifier, state.symbols.attributeAliasDefinitions,
241 [&](StringRef dialectName, StringRef symbolData,
242 llvm::SMLoc loc) -> Attribute {
243 // Parse an optional trailing colon type.
244 Type attrType = type;
245 if (consumeIf(Token::colon) && !(attrType = parseType()))
246 return Attribute();
247
248 // If we found a registered dialect, then ask it to parse the attribute.
249 if (Dialect *dialect =
250 builder.getContext()->getOrLoadDialect(dialectName)) {
251 return parseSymbol<Attribute>(
252 symbolData, state.context, state.symbols, [&](Parser &parser) {
253 CustomDialectAsmParser customParser(symbolData, parser);
254 return dialect->parseAttribute(customParser, attrType);
255 });
256 }
257
258 // Otherwise, form a new opaque attribute.
259 return OpaqueAttr::getChecked(
260 [&] { return emitError(loc); },
261 Identifier::get(dialectName, state.context), symbolData,
262 attrType ? attrType : NoneType::get(state.context));
263 });
264
265 // Ensure that the attribute has the same type as requested.
266 if (attr && type && attr.getType() != type) {
267 emitError("attribute type different than expected: expected ")
268 << type << ", but got " << attr.getType();
269 return nullptr;
270 }
271 return attr;
272 }
273
274 /// Parse an extended type.
275 ///
276 /// extended-type ::= (dialect-type | type-alias)
277 /// dialect-type ::= `!` dialect-namespace `<` `"` type-data `"` `>`
278 /// dialect-type ::= `!` alias-name pretty-dialect-attribute-body?
279 /// type-alias ::= `!` alias-name
280 ///
parseExtendedType()281 Type Parser::parseExtendedType() {
282 return parseExtendedSymbol<Type>(
283 *this, Token::exclamation_identifier, state.symbols.typeAliasDefinitions,
284 [&](StringRef dialectName, StringRef symbolData,
285 llvm::SMLoc loc) -> Type {
286 // If we found a registered dialect, then ask it to parse the type.
287 auto *dialect = state.context->getOrLoadDialect(dialectName);
288
289 if (dialect) {
290 return parseSymbol<Type>(
291 symbolData, state.context, state.symbols, [&](Parser &parser) {
292 CustomDialectAsmParser customParser(symbolData, parser);
293 return dialect->parseType(customParser);
294 });
295 }
296
297 // Otherwise, form a new opaque type.
298 return OpaqueType::getChecked(
299 [&] { return emitError(loc); },
300 Identifier::get(dialectName, state.context), symbolData);
301 });
302 }
303
304 //===----------------------------------------------------------------------===//
305 // mlir::parseAttribute/parseType
306 //===----------------------------------------------------------------------===//
307
308 /// Parses a symbol, of type 'T', and returns it if parsing was successful. If
309 /// parsing failed, nullptr is returned. The number of bytes read from the input
310 /// string is returned in 'numRead'.
311 template <typename T, typename ParserFn>
parseSymbol(StringRef inputStr,MLIRContext * context,size_t & numRead,ParserFn && parserFn)312 static T parseSymbol(StringRef inputStr, MLIRContext *context, size_t &numRead,
313 ParserFn &&parserFn) {
314 SymbolState aliasState;
315 return parseSymbol<T>(
316 inputStr, context, aliasState,
317 [&](Parser &parser) {
318 SourceMgrDiagnosticHandler handler(
319 const_cast<llvm::SourceMgr &>(parser.getSourceMgr()),
320 parser.getContext());
321 return parserFn(parser);
322 },
323 &numRead);
324 }
325
parseAttribute(StringRef attrStr,MLIRContext * context)326 Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context) {
327 size_t numRead = 0;
328 return parseAttribute(attrStr, context, numRead);
329 }
parseAttribute(StringRef attrStr,Type type)330 Attribute mlir::parseAttribute(StringRef attrStr, Type type) {
331 size_t numRead = 0;
332 return parseAttribute(attrStr, type, numRead);
333 }
334
parseAttribute(StringRef attrStr,MLIRContext * context,size_t & numRead)335 Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context,
336 size_t &numRead) {
337 return parseSymbol<Attribute>(attrStr, context, numRead, [](Parser &parser) {
338 return parser.parseAttribute();
339 });
340 }
parseAttribute(StringRef attrStr,Type type,size_t & numRead)341 Attribute mlir::parseAttribute(StringRef attrStr, Type type, size_t &numRead) {
342 return parseSymbol<Attribute>(
343 attrStr, type.getContext(), numRead,
344 [type](Parser &parser) { return parser.parseAttribute(type); });
345 }
346
parseType(StringRef typeStr,MLIRContext * context)347 Type mlir::parseType(StringRef typeStr, MLIRContext *context) {
348 size_t numRead = 0;
349 return parseType(typeStr, context, numRead);
350 }
351
parseType(StringRef typeStr,MLIRContext * context,size_t & numRead)352 Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t &numRead) {
353 return parseSymbol<Type>(typeStr, context, numRead,
354 [](Parser &parser) { return parser.parseType(); });
355 }
356