1 //===- TypeParser.cpp - MLIR Type Parser Implementation -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the parser for the MLIR Types.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "Parser.h"
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/TensorEncoding.h"
17
18 using namespace mlir;
19 using namespace mlir::detail;
20
21 /// Optionally parse a type.
parseOptionalType(Type & type)22 OptionalParseResult Parser::parseOptionalType(Type &type) {
23 // There are many different starting tokens for a type, check them here.
24 switch (getToken().getKind()) {
25 case Token::l_paren:
26 case Token::kw_memref:
27 case Token::kw_tensor:
28 case Token::kw_complex:
29 case Token::kw_tuple:
30 case Token::kw_vector:
31 case Token::inttype:
32 case Token::kw_bf16:
33 case Token::kw_f16:
34 case Token::kw_f32:
35 case Token::kw_f64:
36 case Token::kw_index:
37 case Token::kw_none:
38 case Token::exclamation_identifier:
39 return failure(!(type = parseType()));
40
41 default:
42 return llvm::None;
43 }
44 }
45
46 /// Parse an arbitrary type.
47 ///
48 /// type ::= function-type
49 /// | non-function-type
50 ///
parseType()51 Type Parser::parseType() {
52 if (getToken().is(Token::l_paren))
53 return parseFunctionType();
54 return parseNonFunctionType();
55 }
56
57 /// Parse a function result type.
58 ///
59 /// function-result-type ::= type-list-parens
60 /// | non-function-type
61 ///
parseFunctionResultTypes(SmallVectorImpl<Type> & elements)62 ParseResult Parser::parseFunctionResultTypes(SmallVectorImpl<Type> &elements) {
63 if (getToken().is(Token::l_paren))
64 return parseTypeListParens(elements);
65
66 Type t = parseNonFunctionType();
67 if (!t)
68 return failure();
69 elements.push_back(t);
70 return success();
71 }
72
73 /// Parse a list of types without an enclosing parenthesis. The list must have
74 /// at least one member.
75 ///
76 /// type-list-no-parens ::= type (`,` type)*
77 ///
parseTypeListNoParens(SmallVectorImpl<Type> & elements)78 ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type> &elements) {
79 auto parseElt = [&]() -> ParseResult {
80 auto elt = parseType();
81 elements.push_back(elt);
82 return elt ? success() : failure();
83 };
84
85 return parseCommaSeparatedList(parseElt);
86 }
87
88 /// Parse a parenthesized list of types.
89 ///
90 /// type-list-parens ::= `(` `)`
91 /// | `(` type-list-no-parens `)`
92 ///
parseTypeListParens(SmallVectorImpl<Type> & elements)93 ParseResult Parser::parseTypeListParens(SmallVectorImpl<Type> &elements) {
94 if (parseToken(Token::l_paren, "expected '('"))
95 return failure();
96
97 // Handle empty lists.
98 if (getToken().is(Token::r_paren))
99 return consumeToken(), success();
100
101 if (parseTypeListNoParens(elements) ||
102 parseToken(Token::r_paren, "expected ')'"))
103 return failure();
104 return success();
105 }
106
107 /// Parse a complex type.
108 ///
109 /// complex-type ::= `complex` `<` type `>`
110 ///
parseComplexType()111 Type Parser::parseComplexType() {
112 consumeToken(Token::kw_complex);
113
114 // Parse the '<'.
115 if (parseToken(Token::less, "expected '<' in complex type"))
116 return nullptr;
117
118 llvm::SMLoc elementTypeLoc = getToken().getLoc();
119 auto elementType = parseType();
120 if (!elementType ||
121 parseToken(Token::greater, "expected '>' in complex type"))
122 return nullptr;
123 if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>())
124 return emitError(elementTypeLoc, "invalid element type for complex"),
125 nullptr;
126
127 return ComplexType::get(elementType);
128 }
129
130 /// Parse a function type.
131 ///
132 /// function-type ::= type-list-parens `->` function-result-type
133 ///
parseFunctionType()134 Type Parser::parseFunctionType() {
135 assert(getToken().is(Token::l_paren));
136
137 SmallVector<Type, 4> arguments, results;
138 if (parseTypeListParens(arguments) ||
139 parseToken(Token::arrow, "expected '->' in function type") ||
140 parseFunctionResultTypes(results))
141 return nullptr;
142
143 return builder.getFunctionType(arguments, results);
144 }
145
146 /// Parse the offset and strides from a strided layout specification.
147 ///
148 /// strided-layout ::= `offset:` dimension `,` `strides: ` stride-list
149 ///
parseStridedLayout(int64_t & offset,SmallVectorImpl<int64_t> & strides)150 ParseResult Parser::parseStridedLayout(int64_t &offset,
151 SmallVectorImpl<int64_t> &strides) {
152 // Parse offset.
153 consumeToken(Token::kw_offset);
154 if (!consumeIf(Token::colon))
155 return emitError("expected colon after `offset` keyword");
156 auto maybeOffset = getToken().getUnsignedIntegerValue();
157 bool question = getToken().is(Token::question);
158 if (!maybeOffset && !question)
159 return emitError("invalid offset");
160 offset = maybeOffset ? static_cast<int64_t>(maybeOffset.getValue())
161 : MemRefType::getDynamicStrideOrOffset();
162 consumeToken();
163
164 if (!consumeIf(Token::comma))
165 return emitError("expected comma after offset value");
166
167 // Parse stride list.
168 if (parseToken(Token::kw_strides,
169 "expected `strides` keyword after offset specification") ||
170
171 parseToken(Token::colon, "expected colon after `strides` keyword") ||
172 parseStrideList(strides))
173 return failure();
174 return success();
175 }
176
177 /// Parse a memref type.
178 ///
179 /// memref-type ::= ranked-memref-type | unranked-memref-type
180 ///
181 /// ranked-memref-type ::= `memref` `<` dimension-list-ranked type
182 /// (`,` layout-specification)? (`,` memory-space)? `>`
183 ///
184 /// unranked-memref-type ::= `memref` `<*x` type (`,` memory-space)? `>`
185 ///
186 /// stride-list ::= `[` (dimension (`,` dimension)*)? `]`
187 /// strided-layout ::= `offset:` dimension `,` `strides: ` stride-list
188 /// semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map
189 /// layout-specification ::= semi-affine-map-composition | strided-layout
190 /// memory-space ::= integer-literal /* | TODO: address-space-id */
191 ///
parseMemRefType()192 Type Parser::parseMemRefType() {
193 llvm::SMLoc loc = getToken().getLoc();
194 consumeToken(Token::kw_memref);
195
196 if (parseToken(Token::less, "expected '<' in memref type"))
197 return nullptr;
198
199 bool isUnranked;
200 SmallVector<int64_t, 4> dimensions;
201
202 if (consumeIf(Token::star)) {
203 // This is an unranked memref type.
204 isUnranked = true;
205 if (parseXInDimensionList())
206 return nullptr;
207
208 } else {
209 isUnranked = false;
210 if (parseDimensionListRanked(dimensions))
211 return nullptr;
212 }
213
214 // Parse the element type.
215 auto typeLoc = getToken().getLoc();
216 auto elementType = parseType();
217 if (!elementType)
218 return nullptr;
219
220 // Check that memref is formed from allowed types.
221 if (!BaseMemRefType::isValidElementType(elementType))
222 return emitError(typeLoc, "invalid memref element type"), nullptr;
223
224 // Parse semi-affine-map-composition.
225 SmallVector<AffineMap, 2> affineMapComposition;
226 Attribute memorySpace;
227 unsigned numDims = dimensions.size();
228
229 auto parseElt = [&]() -> ParseResult {
230 AffineMap map;
231 llvm::SMLoc mapLoc = getToken().getLoc();
232
233 // Check for AffineMap as offset/strides.
234 if (getToken().is(Token::kw_offset)) {
235 int64_t offset;
236 SmallVector<int64_t, 4> strides;
237 if (failed(parseStridedLayout(offset, strides)))
238 return failure();
239 // Construct strided affine map.
240 map = makeStridedLinearLayoutMap(strides, offset, state.context);
241 } else {
242 // Either it is AffineMapAttr or memory space attribute.
243 Attribute attr = parseAttribute();
244 if (!attr)
245 return failure();
246
247 if (AffineMapAttr affineMapAttr = attr.dyn_cast<AffineMapAttr>()) {
248 map = affineMapAttr.getValue();
249 } else if (memorySpace) {
250 return emitError("multiple memory spaces specified in memref type");
251 } else {
252 memorySpace = attr;
253 return success();
254 }
255 }
256
257 if (isUnranked)
258 return emitError("cannot have affine map for unranked memref type");
259 if (memorySpace)
260 return emitError("expected memory space to be last in memref type");
261
262 if (map.getNumDims() != numDims) {
263 size_t i = affineMapComposition.size();
264 return emitError(mapLoc, "memref affine map dimension mismatch between ")
265 << (i == 0 ? Twine("memref rank") : "affine map " + Twine(i))
266 << " and affine map" << i + 1 << ": " << numDims
267 << " != " << map.getNumDims();
268 }
269 numDims = map.getNumResults();
270 affineMapComposition.push_back(map);
271 return success();
272 };
273
274 // Parse a list of mappings and address space if present.
275 if (!consumeIf(Token::greater)) {
276 // Parse comma separated list of affine maps, followed by memory space.
277 if (parseToken(Token::comma, "expected ',' or '>' in memref type") ||
278 parseCommaSeparatedListUntil(Token::greater, parseElt,
279 /*allowEmptyList=*/false)) {
280 return nullptr;
281 }
282 }
283
284 if (isUnranked)
285 return getChecked<UnrankedMemRefType>(loc, elementType, memorySpace);
286
287 return getChecked<MemRefType>(loc, dimensions, elementType,
288 affineMapComposition, memorySpace);
289 }
290
291 /// Parse any type except the function type.
292 ///
293 /// non-function-type ::= integer-type
294 /// | index-type
295 /// | float-type
296 /// | extended-type
297 /// | vector-type
298 /// | tensor-type
299 /// | memref-type
300 /// | complex-type
301 /// | tuple-type
302 /// | none-type
303 ///
304 /// index-type ::= `index`
305 /// float-type ::= `f16` | `bf16` | `f32` | `f64` | `f80` | `f128`
306 /// none-type ::= `none`
307 ///
parseNonFunctionType()308 Type Parser::parseNonFunctionType() {
309 switch (getToken().getKind()) {
310 default:
311 return (emitError("expected non-function type"), nullptr);
312 case Token::kw_memref:
313 return parseMemRefType();
314 case Token::kw_tensor:
315 return parseTensorType();
316 case Token::kw_complex:
317 return parseComplexType();
318 case Token::kw_tuple:
319 return parseTupleType();
320 case Token::kw_vector:
321 return parseVectorType();
322 // integer-type
323 case Token::inttype: {
324 auto width = getToken().getIntTypeBitwidth();
325 if (!width.hasValue())
326 return (emitError("invalid integer width"), nullptr);
327 if (width.getValue() > IntegerType::kMaxWidth) {
328 emitError(getToken().getLoc(), "integer bitwidth is limited to ")
329 << IntegerType::kMaxWidth << " bits";
330 return nullptr;
331 }
332
333 IntegerType::SignednessSemantics signSemantics = IntegerType::Signless;
334 if (Optional<bool> signedness = getToken().getIntTypeSignedness())
335 signSemantics = *signedness ? IntegerType::Signed : IntegerType::Unsigned;
336
337 consumeToken(Token::inttype);
338 return IntegerType::get(getContext(), width.getValue(), signSemantics);
339 }
340
341 // float-type
342 case Token::kw_bf16:
343 consumeToken(Token::kw_bf16);
344 return builder.getBF16Type();
345 case Token::kw_f16:
346 consumeToken(Token::kw_f16);
347 return builder.getF16Type();
348 case Token::kw_f32:
349 consumeToken(Token::kw_f32);
350 return builder.getF32Type();
351 case Token::kw_f64:
352 consumeToken(Token::kw_f64);
353 return builder.getF64Type();
354 case Token::kw_f80:
355 consumeToken(Token::kw_f80);
356 return builder.getF80Type();
357 case Token::kw_f128:
358 consumeToken(Token::kw_f128);
359 return builder.getF128Type();
360
361 // index-type
362 case Token::kw_index:
363 consumeToken(Token::kw_index);
364 return builder.getIndexType();
365
366 // none-type
367 case Token::kw_none:
368 consumeToken(Token::kw_none);
369 return builder.getNoneType();
370
371 // extended type
372 case Token::exclamation_identifier:
373 return parseExtendedType();
374 }
375 }
376
377 /// Parse a tensor type.
378 ///
379 /// tensor-type ::= `tensor` `<` dimension-list type `>`
380 /// dimension-list ::= dimension-list-ranked | `*x`
381 ///
parseTensorType()382 Type Parser::parseTensorType() {
383 consumeToken(Token::kw_tensor);
384
385 if (parseToken(Token::less, "expected '<' in tensor type"))
386 return nullptr;
387
388 bool isUnranked;
389 SmallVector<int64_t, 4> dimensions;
390
391 if (consumeIf(Token::star)) {
392 // This is an unranked tensor type.
393 isUnranked = true;
394
395 if (parseXInDimensionList())
396 return nullptr;
397
398 } else {
399 isUnranked = false;
400 if (parseDimensionListRanked(dimensions))
401 return nullptr;
402 }
403
404 // Parse the element type.
405 auto elementTypeLoc = getToken().getLoc();
406 auto elementType = parseType();
407
408 // Parse an optional encoding attribute.
409 Attribute encoding;
410 if (consumeIf(Token::comma)) {
411 encoding = parseAttribute();
412 if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>()) {
413 if (failed(v.verifyEncoding(dimensions, elementType,
414 [&] { return emitError(); })))
415 return nullptr;
416 }
417 }
418
419 if (!elementType || parseToken(Token::greater, "expected '>' in tensor type"))
420 return nullptr;
421 if (!TensorType::isValidElementType(elementType))
422 return emitError(elementTypeLoc, "invalid tensor element type"), nullptr;
423
424 if (isUnranked) {
425 if (encoding)
426 return emitError("cannot apply encoding to unranked tensor"), nullptr;
427 return UnrankedTensorType::get(elementType);
428 }
429 return RankedTensorType::get(dimensions, elementType, encoding);
430 }
431
432 /// Parse a tuple type.
433 ///
434 /// tuple-type ::= `tuple` `<` (type (`,` type)*)? `>`
435 ///
parseTupleType()436 Type Parser::parseTupleType() {
437 consumeToken(Token::kw_tuple);
438
439 // Parse the '<'.
440 if (parseToken(Token::less, "expected '<' in tuple type"))
441 return nullptr;
442
443 // Check for an empty tuple by directly parsing '>'.
444 if (consumeIf(Token::greater))
445 return TupleType::get(getContext());
446
447 // Parse the element types and the '>'.
448 SmallVector<Type, 4> types;
449 if (parseTypeListNoParens(types) ||
450 parseToken(Token::greater, "expected '>' in tuple type"))
451 return nullptr;
452
453 return TupleType::get(getContext(), types);
454 }
455
456 /// Parse a vector type.
457 ///
458 /// vector-type ::= `vector` `<` non-empty-static-dimension-list type `>`
459 /// non-empty-static-dimension-list ::= decimal-literal `x`
460 /// static-dimension-list
461 /// static-dimension-list ::= (decimal-literal `x`)*
462 ///
parseVectorType()463 VectorType Parser::parseVectorType() {
464 consumeToken(Token::kw_vector);
465
466 if (parseToken(Token::less, "expected '<' in vector type"))
467 return nullptr;
468
469 SmallVector<int64_t, 4> dimensions;
470 if (parseDimensionListRanked(dimensions, /*allowDynamic=*/false))
471 return nullptr;
472 if (dimensions.empty())
473 return (emitError("expected dimension size in vector type"), nullptr);
474 if (any_of(dimensions, [](int64_t i) { return i <= 0; }))
475 return emitError(getToken().getLoc(),
476 "vector types must have positive constant sizes"),
477 nullptr;
478
479 // Parse the element type.
480 auto typeLoc = getToken().getLoc();
481 auto elementType = parseType();
482 if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
483 return nullptr;
484 if (!VectorType::isValidElementType(elementType))
485 return emitError(typeLoc, "vector elements must be int/index/float type"),
486 nullptr;
487
488 return VectorType::get(dimensions, elementType);
489 }
490
491 /// Parse a dimension list of a tensor or memref type. This populates the
492 /// dimension list, using -1 for the `?` dimensions if `allowDynamic` is set and
493 /// errors out on `?` otherwise.
494 ///
495 /// dimension-list-ranked ::= (dimension `x`)*
496 /// dimension ::= `?` | decimal-literal
497 ///
498 /// When `allowDynamic` is not set, this is used to parse:
499 ///
500 /// static-dimension-list ::= (decimal-literal `x`)*
501 ParseResult
parseDimensionListRanked(SmallVectorImpl<int64_t> & dimensions,bool allowDynamic)502 Parser::parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions,
503 bool allowDynamic) {
504 while (getToken().isAny(Token::integer, Token::question)) {
505 if (consumeIf(Token::question)) {
506 if (!allowDynamic)
507 return emitError("expected static shape");
508 dimensions.push_back(-1);
509 } else {
510 // Hexadecimal integer literals (starting with `0x`) are not allowed in
511 // aggregate type declarations. Therefore, `0xf32` should be processed as
512 // a sequence of separate elements `0`, `x`, `f32`.
513 if (getTokenSpelling().size() > 1 && getTokenSpelling()[1] == 'x') {
514 // We can get here only if the token is an integer literal. Hexadecimal
515 // integer literals can only start with `0x` (`1x` wouldn't lex as a
516 // literal, just `1` would, at which point we don't get into this
517 // branch).
518 assert(getTokenSpelling()[0] == '0' && "invalid integer literal");
519 dimensions.push_back(0);
520 state.lex.resetPointer(getTokenSpelling().data() + 1);
521 consumeToken();
522 } else {
523 // Make sure this integer value is in bound and valid.
524 auto dimension = getToken().getUnsignedIntegerValue();
525 if (!dimension.hasValue())
526 return emitError("invalid dimension");
527 dimensions.push_back((int64_t)dimension.getValue());
528 consumeToken(Token::integer);
529 }
530 }
531
532 // Make sure we have an 'x' or something like 'xbf32'.
533 if (parseXInDimensionList())
534 return failure();
535 }
536
537 return success();
538 }
539
540 /// Parse an 'x' token in a dimension list, handling the case where the x is
541 /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the next
542 /// token.
parseXInDimensionList()543 ParseResult Parser::parseXInDimensionList() {
544 if (getToken().isNot(Token::bare_identifier) || getTokenSpelling()[0] != 'x')
545 return emitError("expected 'x' in dimension list");
546
547 // If we had a prefix of 'x', lex the next token immediately after the 'x'.
548 if (getTokenSpelling().size() != 1)
549 state.lex.resetPointer(getTokenSpelling().data() + 1);
550
551 // Consume the 'x'.
552 consumeToken(Token::bare_identifier);
553
554 return success();
555 }
556
557 // Parse a comma-separated list of dimensions, possibly empty:
558 // stride-list ::= `[` (dimension (`,` dimension)*)? `]`
parseStrideList(SmallVectorImpl<int64_t> & dimensions)559 ParseResult Parser::parseStrideList(SmallVectorImpl<int64_t> &dimensions) {
560 return parseCommaSeparatedList(
561 Delimiter::Square,
562 [&]() -> ParseResult {
563 if (consumeIf(Token::question)) {
564 dimensions.push_back(MemRefType::getDynamicStrideOrOffset());
565 } else {
566 // This must be an integer value.
567 int64_t val;
568 if (getToken().getSpelling().getAsInteger(10, val))
569 return emitError("invalid integer value: ")
570 << getToken().getSpelling();
571 // Make sure it is not the one value for `?`.
572 if (ShapedType::isDynamic(val))
573 return emitError("invalid integer value: ")
574 << getToken().getSpelling()
575 << ", use `?` to specify a dynamic dimension";
576
577 if (val == 0)
578 return emitError("invalid memref stride");
579
580 dimensions.push_back(val);
581 consumeToken(Token::integer);
582 }
583 return success();
584 },
585 " in stride list");
586 }
587