1 //===- TypeParser.cpp - MLIR Type Parser Implementation -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the parser for the MLIR Types.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Parser.h"
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/TensorEncoding.h"
17 
18 using namespace mlir;
19 using namespace mlir::detail;
20 
21 /// Optionally parse a type.
parseOptionalType(Type & type)22 OptionalParseResult Parser::parseOptionalType(Type &type) {
23   // There are many different starting tokens for a type, check them here.
24   switch (getToken().getKind()) {
25   case Token::l_paren:
26   case Token::kw_memref:
27   case Token::kw_tensor:
28   case Token::kw_complex:
29   case Token::kw_tuple:
30   case Token::kw_vector:
31   case Token::inttype:
32   case Token::kw_bf16:
33   case Token::kw_f16:
34   case Token::kw_f32:
35   case Token::kw_f64:
36   case Token::kw_index:
37   case Token::kw_none:
38   case Token::exclamation_identifier:
39     return failure(!(type = parseType()));
40 
41   default:
42     return llvm::None;
43   }
44 }
45 
46 /// Parse an arbitrary type.
47 ///
48 ///   type ::= function-type
49 ///          | non-function-type
50 ///
parseType()51 Type Parser::parseType() {
52   if (getToken().is(Token::l_paren))
53     return parseFunctionType();
54   return parseNonFunctionType();
55 }
56 
57 /// Parse a function result type.
58 ///
59 ///   function-result-type ::= type-list-parens
60 ///                          | non-function-type
61 ///
parseFunctionResultTypes(SmallVectorImpl<Type> & elements)62 ParseResult Parser::parseFunctionResultTypes(SmallVectorImpl<Type> &elements) {
63   if (getToken().is(Token::l_paren))
64     return parseTypeListParens(elements);
65 
66   Type t = parseNonFunctionType();
67   if (!t)
68     return failure();
69   elements.push_back(t);
70   return success();
71 }
72 
73 /// Parse a list of types without an enclosing parenthesis.  The list must have
74 /// at least one member.
75 ///
76 ///   type-list-no-parens ::=  type (`,` type)*
77 ///
parseTypeListNoParens(SmallVectorImpl<Type> & elements)78 ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type> &elements) {
79   auto parseElt = [&]() -> ParseResult {
80     auto elt = parseType();
81     elements.push_back(elt);
82     return elt ? success() : failure();
83   };
84 
85   return parseCommaSeparatedList(parseElt);
86 }
87 
88 /// Parse a parenthesized list of types.
89 ///
90 ///   type-list-parens ::= `(` `)`
91 ///                      | `(` type-list-no-parens `)`
92 ///
parseTypeListParens(SmallVectorImpl<Type> & elements)93 ParseResult Parser::parseTypeListParens(SmallVectorImpl<Type> &elements) {
94   if (parseToken(Token::l_paren, "expected '('"))
95     return failure();
96 
97   // Handle empty lists.
98   if (getToken().is(Token::r_paren))
99     return consumeToken(), success();
100 
101   if (parseTypeListNoParens(elements) ||
102       parseToken(Token::r_paren, "expected ')'"))
103     return failure();
104   return success();
105 }
106 
107 /// Parse a complex type.
108 ///
109 ///   complex-type ::= `complex` `<` type `>`
110 ///
parseComplexType()111 Type Parser::parseComplexType() {
112   consumeToken(Token::kw_complex);
113 
114   // Parse the '<'.
115   if (parseToken(Token::less, "expected '<' in complex type"))
116     return nullptr;
117 
118   llvm::SMLoc elementTypeLoc = getToken().getLoc();
119   auto elementType = parseType();
120   if (!elementType ||
121       parseToken(Token::greater, "expected '>' in complex type"))
122     return nullptr;
123   if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>())
124     return emitError(elementTypeLoc, "invalid element type for complex"),
125            nullptr;
126 
127   return ComplexType::get(elementType);
128 }
129 
130 /// Parse a function type.
131 ///
132 ///   function-type ::= type-list-parens `->` function-result-type
133 ///
parseFunctionType()134 Type Parser::parseFunctionType() {
135   assert(getToken().is(Token::l_paren));
136 
137   SmallVector<Type, 4> arguments, results;
138   if (parseTypeListParens(arguments) ||
139       parseToken(Token::arrow, "expected '->' in function type") ||
140       parseFunctionResultTypes(results))
141     return nullptr;
142 
143   return builder.getFunctionType(arguments, results);
144 }
145 
146 /// Parse the offset and strides from a strided layout specification.
147 ///
148 ///   strided-layout ::= `offset:` dimension `,` `strides: ` stride-list
149 ///
parseStridedLayout(int64_t & offset,SmallVectorImpl<int64_t> & strides)150 ParseResult Parser::parseStridedLayout(int64_t &offset,
151                                        SmallVectorImpl<int64_t> &strides) {
152   // Parse offset.
153   consumeToken(Token::kw_offset);
154   if (!consumeIf(Token::colon))
155     return emitError("expected colon after `offset` keyword");
156   auto maybeOffset = getToken().getUnsignedIntegerValue();
157   bool question = getToken().is(Token::question);
158   if (!maybeOffset && !question)
159     return emitError("invalid offset");
160   offset = maybeOffset ? static_cast<int64_t>(maybeOffset.getValue())
161                        : MemRefType::getDynamicStrideOrOffset();
162   consumeToken();
163 
164   if (!consumeIf(Token::comma))
165     return emitError("expected comma after offset value");
166 
167   // Parse stride list.
168   if (parseToken(Token::kw_strides,
169                  "expected `strides` keyword after offset specification") ||
170 
171       parseToken(Token::colon, "expected colon after `strides` keyword") ||
172       parseStrideList(strides))
173     return failure();
174   return success();
175 }
176 
177 /// Parse a memref type.
178 ///
179 ///   memref-type ::= ranked-memref-type | unranked-memref-type
180 ///
181 ///   ranked-memref-type ::= `memref` `<` dimension-list-ranked type
182 ///                          (`,` layout-specification)? (`,` memory-space)? `>`
183 ///
184 ///   unranked-memref-type ::= `memref` `<*x` type (`,` memory-space)? `>`
185 ///
186 ///   stride-list ::= `[` (dimension (`,` dimension)*)? `]`
187 ///   strided-layout ::= `offset:` dimension `,` `strides: ` stride-list
188 ///   semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map
189 ///   layout-specification ::= semi-affine-map-composition | strided-layout
190 ///   memory-space ::= integer-literal /* | TODO: address-space-id */
191 ///
parseMemRefType()192 Type Parser::parseMemRefType() {
193   llvm::SMLoc loc = getToken().getLoc();
194   consumeToken(Token::kw_memref);
195 
196   if (parseToken(Token::less, "expected '<' in memref type"))
197     return nullptr;
198 
199   bool isUnranked;
200   SmallVector<int64_t, 4> dimensions;
201 
202   if (consumeIf(Token::star)) {
203     // This is an unranked memref type.
204     isUnranked = true;
205     if (parseXInDimensionList())
206       return nullptr;
207 
208   } else {
209     isUnranked = false;
210     if (parseDimensionListRanked(dimensions))
211       return nullptr;
212   }
213 
214   // Parse the element type.
215   auto typeLoc = getToken().getLoc();
216   auto elementType = parseType();
217   if (!elementType)
218     return nullptr;
219 
220   // Check that memref is formed from allowed types.
221   if (!BaseMemRefType::isValidElementType(elementType))
222     return emitError(typeLoc, "invalid memref element type"), nullptr;
223 
224   // Parse semi-affine-map-composition.
225   SmallVector<AffineMap, 2> affineMapComposition;
226   Attribute memorySpace;
227   unsigned numDims = dimensions.size();
228 
229   auto parseElt = [&]() -> ParseResult {
230     AffineMap map;
231     llvm::SMLoc mapLoc = getToken().getLoc();
232 
233     // Check for AffineMap as offset/strides.
234     if (getToken().is(Token::kw_offset)) {
235       int64_t offset;
236       SmallVector<int64_t, 4> strides;
237       if (failed(parseStridedLayout(offset, strides)))
238         return failure();
239       // Construct strided affine map.
240       map = makeStridedLinearLayoutMap(strides, offset, state.context);
241     } else {
242       // Either it is AffineMapAttr or memory space attribute.
243       Attribute attr = parseAttribute();
244       if (!attr)
245         return failure();
246 
247       if (AffineMapAttr affineMapAttr = attr.dyn_cast<AffineMapAttr>()) {
248         map = affineMapAttr.getValue();
249       } else if (memorySpace) {
250         return emitError("multiple memory spaces specified in memref type");
251       } else {
252         memorySpace = attr;
253         return success();
254       }
255     }
256 
257     if (isUnranked)
258       return emitError("cannot have affine map for unranked memref type");
259     if (memorySpace)
260       return emitError("expected memory space to be last in memref type");
261 
262     if (map.getNumDims() != numDims) {
263       size_t i = affineMapComposition.size();
264       return emitError(mapLoc, "memref affine map dimension mismatch between ")
265              << (i == 0 ? Twine("memref rank") : "affine map " + Twine(i))
266              << " and affine map" << i + 1 << ": " << numDims
267              << " != " << map.getNumDims();
268     }
269     numDims = map.getNumResults();
270     affineMapComposition.push_back(map);
271     return success();
272   };
273 
274   // Parse a list of mappings and address space if present.
275   if (!consumeIf(Token::greater)) {
276     // Parse comma separated list of affine maps, followed by memory space.
277     if (parseToken(Token::comma, "expected ',' or '>' in memref type") ||
278         parseCommaSeparatedListUntil(Token::greater, parseElt,
279                                      /*allowEmptyList=*/false)) {
280       return nullptr;
281     }
282   }
283 
284   if (isUnranked)
285     return getChecked<UnrankedMemRefType>(loc, elementType, memorySpace);
286 
287   return getChecked<MemRefType>(loc, dimensions, elementType,
288                                 affineMapComposition, memorySpace);
289 }
290 
291 /// Parse any type except the function type.
292 ///
293 ///   non-function-type ::= integer-type
294 ///                       | index-type
295 ///                       | float-type
296 ///                       | extended-type
297 ///                       | vector-type
298 ///                       | tensor-type
299 ///                       | memref-type
300 ///                       | complex-type
301 ///                       | tuple-type
302 ///                       | none-type
303 ///
304 ///   index-type ::= `index`
305 ///   float-type ::= `f16` | `bf16` | `f32` | `f64` | `f80` | `f128`
306 ///   none-type ::= `none`
307 ///
parseNonFunctionType()308 Type Parser::parseNonFunctionType() {
309   switch (getToken().getKind()) {
310   default:
311     return (emitError("expected non-function type"), nullptr);
312   case Token::kw_memref:
313     return parseMemRefType();
314   case Token::kw_tensor:
315     return parseTensorType();
316   case Token::kw_complex:
317     return parseComplexType();
318   case Token::kw_tuple:
319     return parseTupleType();
320   case Token::kw_vector:
321     return parseVectorType();
322   // integer-type
323   case Token::inttype: {
324     auto width = getToken().getIntTypeBitwidth();
325     if (!width.hasValue())
326       return (emitError("invalid integer width"), nullptr);
327     if (width.getValue() > IntegerType::kMaxWidth) {
328       emitError(getToken().getLoc(), "integer bitwidth is limited to ")
329           << IntegerType::kMaxWidth << " bits";
330       return nullptr;
331     }
332 
333     IntegerType::SignednessSemantics signSemantics = IntegerType::Signless;
334     if (Optional<bool> signedness = getToken().getIntTypeSignedness())
335       signSemantics = *signedness ? IntegerType::Signed : IntegerType::Unsigned;
336 
337     consumeToken(Token::inttype);
338     return IntegerType::get(getContext(), width.getValue(), signSemantics);
339   }
340 
341   // float-type
342   case Token::kw_bf16:
343     consumeToken(Token::kw_bf16);
344     return builder.getBF16Type();
345   case Token::kw_f16:
346     consumeToken(Token::kw_f16);
347     return builder.getF16Type();
348   case Token::kw_f32:
349     consumeToken(Token::kw_f32);
350     return builder.getF32Type();
351   case Token::kw_f64:
352     consumeToken(Token::kw_f64);
353     return builder.getF64Type();
354   case Token::kw_f80:
355     consumeToken(Token::kw_f80);
356     return builder.getF80Type();
357   case Token::kw_f128:
358     consumeToken(Token::kw_f128);
359     return builder.getF128Type();
360 
361   // index-type
362   case Token::kw_index:
363     consumeToken(Token::kw_index);
364     return builder.getIndexType();
365 
366   // none-type
367   case Token::kw_none:
368     consumeToken(Token::kw_none);
369     return builder.getNoneType();
370 
371   // extended type
372   case Token::exclamation_identifier:
373     return parseExtendedType();
374   }
375 }
376 
377 /// Parse a tensor type.
378 ///
379 ///   tensor-type ::= `tensor` `<` dimension-list type `>`
380 ///   dimension-list ::= dimension-list-ranked | `*x`
381 ///
parseTensorType()382 Type Parser::parseTensorType() {
383   consumeToken(Token::kw_tensor);
384 
385   if (parseToken(Token::less, "expected '<' in tensor type"))
386     return nullptr;
387 
388   bool isUnranked;
389   SmallVector<int64_t, 4> dimensions;
390 
391   if (consumeIf(Token::star)) {
392     // This is an unranked tensor type.
393     isUnranked = true;
394 
395     if (parseXInDimensionList())
396       return nullptr;
397 
398   } else {
399     isUnranked = false;
400     if (parseDimensionListRanked(dimensions))
401       return nullptr;
402   }
403 
404   // Parse the element type.
405   auto elementTypeLoc = getToken().getLoc();
406   auto elementType = parseType();
407 
408   // Parse an optional encoding attribute.
409   Attribute encoding;
410   if (consumeIf(Token::comma)) {
411     encoding = parseAttribute();
412     if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>()) {
413       if (failed(v.verifyEncoding(dimensions, elementType,
414                                   [&] { return emitError(); })))
415         return nullptr;
416     }
417   }
418 
419   if (!elementType || parseToken(Token::greater, "expected '>' in tensor type"))
420     return nullptr;
421   if (!TensorType::isValidElementType(elementType))
422     return emitError(elementTypeLoc, "invalid tensor element type"), nullptr;
423 
424   if (isUnranked) {
425     if (encoding)
426       return emitError("cannot apply encoding to unranked tensor"), nullptr;
427     return UnrankedTensorType::get(elementType);
428   }
429   return RankedTensorType::get(dimensions, elementType, encoding);
430 }
431 
432 /// Parse a tuple type.
433 ///
434 ///   tuple-type ::= `tuple` `<` (type (`,` type)*)? `>`
435 ///
parseTupleType()436 Type Parser::parseTupleType() {
437   consumeToken(Token::kw_tuple);
438 
439   // Parse the '<'.
440   if (parseToken(Token::less, "expected '<' in tuple type"))
441     return nullptr;
442 
443   // Check for an empty tuple by directly parsing '>'.
444   if (consumeIf(Token::greater))
445     return TupleType::get(getContext());
446 
447   // Parse the element types and the '>'.
448   SmallVector<Type, 4> types;
449   if (parseTypeListNoParens(types) ||
450       parseToken(Token::greater, "expected '>' in tuple type"))
451     return nullptr;
452 
453   return TupleType::get(getContext(), types);
454 }
455 
456 /// Parse a vector type.
457 ///
458 ///   vector-type ::= `vector` `<` non-empty-static-dimension-list type `>`
459 ///   non-empty-static-dimension-list ::= decimal-literal `x`
460 ///                                       static-dimension-list
461 ///   static-dimension-list ::= (decimal-literal `x`)*
462 ///
parseVectorType()463 VectorType Parser::parseVectorType() {
464   consumeToken(Token::kw_vector);
465 
466   if (parseToken(Token::less, "expected '<' in vector type"))
467     return nullptr;
468 
469   SmallVector<int64_t, 4> dimensions;
470   if (parseDimensionListRanked(dimensions, /*allowDynamic=*/false))
471     return nullptr;
472   if (dimensions.empty())
473     return (emitError("expected dimension size in vector type"), nullptr);
474   if (any_of(dimensions, [](int64_t i) { return i <= 0; }))
475     return emitError(getToken().getLoc(),
476                      "vector types must have positive constant sizes"),
477            nullptr;
478 
479   // Parse the element type.
480   auto typeLoc = getToken().getLoc();
481   auto elementType = parseType();
482   if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
483     return nullptr;
484   if (!VectorType::isValidElementType(elementType))
485     return emitError(typeLoc, "vector elements must be int/index/float type"),
486            nullptr;
487 
488   return VectorType::get(dimensions, elementType);
489 }
490 
491 /// Parse a dimension list of a tensor or memref type.  This populates the
492 /// dimension list, using -1 for the `?` dimensions if `allowDynamic` is set and
493 /// errors out on `?` otherwise.
494 ///
495 ///   dimension-list-ranked ::= (dimension `x`)*
496 ///   dimension ::= `?` | decimal-literal
497 ///
498 /// When `allowDynamic` is not set, this is used to parse:
499 ///
500 ///   static-dimension-list ::= (decimal-literal `x`)*
501 ParseResult
parseDimensionListRanked(SmallVectorImpl<int64_t> & dimensions,bool allowDynamic)502 Parser::parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions,
503                                  bool allowDynamic) {
504   while (getToken().isAny(Token::integer, Token::question)) {
505     if (consumeIf(Token::question)) {
506       if (!allowDynamic)
507         return emitError("expected static shape");
508       dimensions.push_back(-1);
509     } else {
510       // Hexadecimal integer literals (starting with `0x`) are not allowed in
511       // aggregate type declarations.  Therefore, `0xf32` should be processed as
512       // a sequence of separate elements `0`, `x`, `f32`.
513       if (getTokenSpelling().size() > 1 && getTokenSpelling()[1] == 'x') {
514         // We can get here only if the token is an integer literal.  Hexadecimal
515         // integer literals can only start with `0x` (`1x` wouldn't lex as a
516         // literal, just `1` would, at which point we don't get into this
517         // branch).
518         assert(getTokenSpelling()[0] == '0' && "invalid integer literal");
519         dimensions.push_back(0);
520         state.lex.resetPointer(getTokenSpelling().data() + 1);
521         consumeToken();
522       } else {
523         // Make sure this integer value is in bound and valid.
524         auto dimension = getToken().getUnsignedIntegerValue();
525         if (!dimension.hasValue())
526           return emitError("invalid dimension");
527         dimensions.push_back((int64_t)dimension.getValue());
528         consumeToken(Token::integer);
529       }
530     }
531 
532     // Make sure we have an 'x' or something like 'xbf32'.
533     if (parseXInDimensionList())
534       return failure();
535   }
536 
537   return success();
538 }
539 
540 /// Parse an 'x' token in a dimension list, handling the case where the x is
541 /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the next
542 /// token.
parseXInDimensionList()543 ParseResult Parser::parseXInDimensionList() {
544   if (getToken().isNot(Token::bare_identifier) || getTokenSpelling()[0] != 'x')
545     return emitError("expected 'x' in dimension list");
546 
547   // If we had a prefix of 'x', lex the next token immediately after the 'x'.
548   if (getTokenSpelling().size() != 1)
549     state.lex.resetPointer(getTokenSpelling().data() + 1);
550 
551   // Consume the 'x'.
552   consumeToken(Token::bare_identifier);
553 
554   return success();
555 }
556 
557 // Parse a comma-separated list of dimensions, possibly empty:
558 //   stride-list ::= `[` (dimension (`,` dimension)*)? `]`
parseStrideList(SmallVectorImpl<int64_t> & dimensions)559 ParseResult Parser::parseStrideList(SmallVectorImpl<int64_t> &dimensions) {
560   return parseCommaSeparatedList(
561       Delimiter::Square,
562       [&]() -> ParseResult {
563         if (consumeIf(Token::question)) {
564           dimensions.push_back(MemRefType::getDynamicStrideOrOffset());
565         } else {
566           // This must be an integer value.
567           int64_t val;
568           if (getToken().getSpelling().getAsInteger(10, val))
569             return emitError("invalid integer value: ")
570                    << getToken().getSpelling();
571           // Make sure it is not the one value for `?`.
572           if (ShapedType::isDynamic(val))
573             return emitError("invalid integer value: ")
574                    << getToken().getSpelling()
575                    << ", use `?` to specify a dynamic dimension";
576 
577           if (val == 0)
578             return emitError("invalid memref stride");
579 
580           dimensions.push_back(val);
581           consumeToken(Token::integer);
582         }
583         return success();
584       },
585       " in stride list");
586 }
587