1 //===- AttributeParser.cpp - MLIR Attribute Parser Implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements the parser for the MLIR Attributes.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Parser.h"
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Dialect.h"
17 #include "mlir/IR/IntegerSet.h"
18 #include "mlir/Parser/AsmParserState.h"
19 #include "llvm/ADT/StringExtras.h"
20 #include "llvm/Support/Endian.h"
21 
22 using namespace mlir;
23 using namespace mlir::detail;
24 
/// Parse an arbitrary attribute. A non-null `type` is forwarded to the literal
/// (numeric/string), elements (dense/opaque/sparse) and extended-attribute
/// sub-parsers; nested attributes inside arrays and dictionaries are parsed
/// without it. A null Attribute is returned on failure.
///
///  attribute-value ::= `unit`
///                    | bool-literal
///                    | integer-literal (`:` (index-type | integer-type))?
///                    | float-literal (`:` float-type)?
///                    | string-literal (`:` type)?
///                    | type
///                    | `[` (attribute-value (`,` attribute-value)*)? `]`
///                    | `{` (attribute-entry (`,` attribute-entry)*)? `}`
///                    | symbol-ref-id (`::` symbol-ref-id)*
///                    | `dense` `<` attribute-value `>` `:`
///                      (tensor-type | vector-type)
///                    | `sparse` `<` attribute-value `,` attribute-value `>`
///                      `:` (tensor-type | vector-type)
///                    | `opaque` `<` dialect-namespace  `,` hex-string-literal
///                      `>` `:` (tensor-type | vector-type)
///                    | `affine_map` `<` affine-map `>`
///                    | `affine_set` `<` integer-set `>`
///                    | `loc` `(` location `)`
///                    | extended-attribute
///
/// Integer and float literals may be preceded by `-`. An integer literal may
/// also be given a float type (see parseDecOrHexAttr).
///
Attribute Parser::parseAttribute(Type type) {
  // Dispatch on the first token of the attribute; anything not matched by a
  // dedicated case is parsed as a type and wrapped in a TypeAttr.
  switch (getToken().getKind()) {
  // Parse an AffineMap or IntegerSet attribute.
  case Token::kw_affine_map: {
    consumeToken(Token::kw_affine_map);

    AffineMap map;
    if (parseToken(Token::less, "expected '<' in affine map") ||
        parseAffineMapReference(map) ||
        parseToken(Token::greater, "expected '>' in affine map"))
      return Attribute();
    return AffineMapAttr::get(map);
  }
  case Token::kw_affine_set: {
    consumeToken(Token::kw_affine_set);

    IntegerSet set;
    if (parseToken(Token::less, "expected '<' in integer set") ||
        parseIntegerSetReference(set) ||
        parseToken(Token::greater, "expected '>' in integer set"))
      return Attribute();
    return IntegerSetAttr::get(set);
  }

  // Parse an array attribute. Elements are parsed without the outer `type`.
  case Token::l_square: {
    consumeToken(Token::l_square);

    SmallVector<Attribute, 4> elements;
    auto parseElt = [&]() -> ParseResult {
      elements.push_back(parseAttribute());
      return elements.back() ? success() : failure();
    };

    if (parseCommaSeparatedListUntil(Token::r_square, parseElt))
      return nullptr;
    return builder.getArrayAttr(elements);
  }

  // Parse a boolean attribute.
  case Token::kw_false:
    consumeToken(Token::kw_false);
    return builder.getBoolAttr(false);
  case Token::kw_true:
    consumeToken(Token::kw_true);
    return builder.getBoolAttr(true);

  // Parse a dense elements attribute.
  case Token::kw_dense:
    return parseDenseElementsAttr(type);

  // Parse a dictionary attribute.
  case Token::l_brace: {
    NamedAttrList elements;
    if (parseAttributeDict(elements))
      return nullptr;
    return elements.getDictionary(getContext());
  }

  // Parse an extended attribute, i.e. alias or dialect attribute.
  case Token::hash_identifier:
    return parseExtendedAttr(type);

  // Parse floating point and integer attributes.
  case Token::floatliteral:
    return parseFloatAttr(type, /*isNegative=*/false);
  case Token::integer:
    return parseDecOrHexAttr(type, /*isNegative=*/false);
  case Token::minus: {
    // The '-' is consumed here and passed down as `isNegative`, so the literal
    // parsers can apply the sign once the final type is known.
    consumeToken(Token::minus);
    if (getToken().is(Token::integer))
      return parseDecOrHexAttr(type, /*isNegative=*/true);
    if (getToken().is(Token::floatliteral))
      return parseFloatAttr(type, /*isNegative=*/true);

    return (emitError("expected constant integer or floating point value"),
            nullptr);
  }

  // Parse a location attribute.
  case Token::kw_loc: {
    consumeToken(Token::kw_loc);

    LocationAttr locAttr;
    if (parseToken(Token::l_paren, "expected '(' in inline location") ||
        parseLocationInstance(locAttr) ||
        parseToken(Token::r_paren, "expected ')' in inline location"))
      return Attribute();
    return locAttr;
  }

  // Parse an opaque elements attribute.
  case Token::kw_opaque:
    return parseOpaqueElementsAttr(type);

  // Parse a sparse elements attribute.
  case Token::kw_sparse:
    return parseSparseElementsAttr(type);

  // Parse a string attribute.
  case Token::string: {
    auto val = getToken().getStringValue();
    consumeToken(Token::string);
    // Parse the optional trailing colon type if one wasn't explicitly provided.
    // Note that `type` is assigned inside the condition.
    if (!type && consumeIf(Token::colon) && !(type = parseType()))
      return Attribute();

    return type ? StringAttr::get(val, type)
                : StringAttr::get(getContext(), val);
  }

  // Parse a symbol reference attribute.
  case Token::at_identifier: {
    // When populating the parser state, this is a list of locations for all of
    // the nested references.
    SmallVector<llvm::SMRange> referenceLocations;
    if (state.asmState)
      referenceLocations.push_back(getToken().getLocRange());

    // Parse the top-level reference.
    std::string nameStr = getToken().getSymbolReference();
    consumeToken(Token::at_identifier);

    // Parse any nested references.
    std::vector<FlatSymbolRefAttr> nestedRefs;
    while (getToken().is(Token::colon)) {
      // Check for the '::' prefix. A lone ':' is not part of this attribute:
      // rewind the lexer to the start of that ':' and re-lex it, so the ':'
      // becomes the current token again for the caller.
      const char *curPointer = getToken().getLoc().getPointer();
      consumeToken(Token::colon);
      if (!consumeIf(Token::colon)) {
        state.lex.resetPointer(curPointer);
        consumeToken();
        break;
      }
      // Parse the reference itself.
      auto curLoc = getToken().getLoc();
      if (getToken().isNot(Token::at_identifier)) {
        emitError(curLoc, "expected nested symbol reference identifier");
        return Attribute();
      }

      // If we are populating the assembly state, add the location for this
      // reference.
      if (state.asmState)
        referenceLocations.push_back(getToken().getLocRange());

      // NOTE(review): this `nameStr` shadows the top-level one; it is only
      // used within this iteration.
      std::string nameStr = getToken().getSymbolReference();
      consumeToken(Token::at_identifier);
      nestedRefs.push_back(SymbolRefAttr::get(getContext(), nameStr));
    }
    SymbolRefAttr symbolRefAttr = builder.getSymbolRefAttr(nameStr, nestedRefs);

    // If we are populating the assembly state, record this symbol reference.
    if (state.asmState)
      state.asmState->addUses(symbolRefAttr, referenceLocations);
    return symbolRefAttr;
  }

  // Parse a 'unit' attribute.
  case Token::kw_unit:
    consumeToken(Token::kw_unit);
    return builder.getUnitAttr();

  default:
    // Parse a type attribute.
    // NOTE(review): this local `type` shadows the parameter, which is not
    // used on this path.
    if (Type type = parseType())
      return TypeAttr::get(type);
    return nullptr;
  }
}
214 
215 /// Parse an optional attribute with the provided type.
216 OptionalParseResult Parser::parseOptionalAttribute(Attribute &attribute,
217                                                    Type type) {
218   switch (getToken().getKind()) {
219   case Token::at_identifier:
220   case Token::floatliteral:
221   case Token::integer:
222   case Token::hash_identifier:
223   case Token::kw_affine_map:
224   case Token::kw_affine_set:
225   case Token::kw_dense:
226   case Token::kw_false:
227   case Token::kw_loc:
228   case Token::kw_opaque:
229   case Token::kw_sparse:
230   case Token::kw_true:
231   case Token::kw_unit:
232   case Token::l_brace:
233   case Token::l_square:
234   case Token::minus:
235   case Token::string:
236     attribute = parseAttribute(type);
237     return success(attribute != nullptr);
238 
239   default:
240     // Parse an optional type attribute.
241     Type type;
242     OptionalParseResult result = parseOptionalType(type);
243     if (result.hasValue() && succeeded(*result))
244       attribute = TypeAttr::get(type);
245     return result;
246   }
247 }
248 OptionalParseResult Parser::parseOptionalAttribute(ArrayAttr &attribute,
249                                                    Type type) {
250   return parseOptionalAttributeWithToken(Token::l_square, attribute, type);
251 }
252 OptionalParseResult Parser::parseOptionalAttribute(StringAttr &attribute,
253                                                    Type type) {
254   return parseOptionalAttributeWithToken(Token::string, attribute, type);
255 }
256 
257 /// Attribute dictionary.
258 ///
259 ///   attribute-dict ::= `{` `}`
260 ///                    | `{` attribute-entry (`,` attribute-entry)* `}`
261 ///   attribute-entry ::= (bare-id | string-literal) `=` attribute-value
262 ///
263 ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
264   if (parseToken(Token::l_brace, "expected '{' in attribute dictionary"))
265     return failure();
266 
267   llvm::SmallDenseSet<Identifier> seenKeys;
268   auto parseElt = [&]() -> ParseResult {
269     // The name of an attribute can either be a bare identifier, or a string.
270     Optional<Identifier> nameId;
271     if (getToken().is(Token::string))
272       nameId = builder.getIdentifier(getToken().getStringValue());
273     else if (getToken().isAny(Token::bare_identifier, Token::inttype) ||
274              getToken().isKeyword())
275       nameId = builder.getIdentifier(getTokenSpelling());
276     else
277       return emitError("expected attribute name");
278     if (!seenKeys.insert(*nameId).second)
279       return emitError("duplicate key '")
280              << *nameId << "' in dictionary attribute";
281     consumeToken();
282 
283     // Lazy load a dialect in the context if there is a possible namespace.
284     auto splitName = nameId->strref().split('.');
285     if (!splitName.second.empty())
286       getContext()->getOrLoadDialect(splitName.first);
287 
288     // Try to parse the '=' for the attribute value.
289     if (!consumeIf(Token::equal)) {
290       // If there is no '=', we treat this as a unit attribute.
291       attributes.push_back({*nameId, builder.getUnitAttr()});
292       return success();
293     }
294 
295     auto attr = parseAttribute();
296     if (!attr)
297       return failure();
298     attributes.push_back({*nameId, attr});
299     return success();
300   };
301 
302   if (parseCommaSeparatedListUntil(Token::r_brace, parseElt))
303     return failure();
304 
305   return success();
306 }
307 
308 /// Parse a float attribute.
309 Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
310   auto val = getToken().getFloatingPointValue();
311   if (!val.hasValue())
312     return (emitError("floating point value too large for attribute"), nullptr);
313   consumeToken(Token::floatliteral);
314   if (!type) {
315     // Default to F64 when no type is specified.
316     if (!consumeIf(Token::colon))
317       type = builder.getF64Type();
318     else if (!(type = parseType()))
319       return nullptr;
320   }
321   if (!type.isa<FloatType>())
322     return (emitError("floating point value not valid for specified type"),
323             nullptr);
324   return FloatAttr::get(type, isNegative ? -val.getValue() : val.getValue());
325 }
326 
327 /// Construct an APint from a parsed value, a known attribute type and
328 /// sign.
329 static Optional<APInt> buildAttributeAPInt(Type type, bool isNegative,
330                                            StringRef spelling) {
331   // Parse the integer value into an APInt that is big enough to hold the value.
332   APInt result;
333   bool isHex = spelling.size() > 1 && spelling[1] == 'x';
334   if (spelling.getAsInteger(isHex ? 0 : 10, result))
335     return llvm::None;
336 
337   // Extend or truncate the bitwidth to the right size.
338   unsigned width = type.isIndex() ? IndexType::kInternalStorageBitWidth
339                                   : type.getIntOrFloatBitWidth();
340 
341   // APInt cannot hold a zero bit value.
342   if (width == 0)
343     return llvm::None;
344 
345   if (width > result.getBitWidth()) {
346     result = result.zext(width);
347   } else if (width < result.getBitWidth()) {
348     // The parser can return an unnecessarily wide result with leading zeros.
349     // This isn't a problem, but truncating off bits is bad.
350     if (result.countLeadingZeros() < result.getBitWidth() - width)
351       return llvm::None;
352 
353     result = result.trunc(width);
354   }
355 
356   if (isNegative) {
357     // The value is negative, we have an overflow if the sign bit is not set
358     // in the negated apInt.
359     result.negate();
360     if (!result.isSignBitSet())
361       return llvm::None;
362   } else if ((type.isSignedInteger() || type.isIndex()) &&
363              result.isSignBitSet()) {
364     // The value is a positive signed integer or index,
365     // we have an overflow if the sign bit is set.
366     return llvm::None;
367   }
368 
369   return result;
370 }
371 
372 /// Parse a decimal or a hexadecimal literal, which can be either an integer
373 /// or a float attribute.
374 Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
375   Token tok = getToken();
376   StringRef spelling = tok.getSpelling();
377   llvm::SMLoc loc = tok.getLoc();
378 
379   consumeToken(Token::integer);
380   if (!type) {
381     // Default to i64 if not type is specified.
382     if (!consumeIf(Token::colon))
383       type = builder.getIntegerType(64);
384     else if (!(type = parseType()))
385       return nullptr;
386   }
387 
388   if (auto floatType = type.dyn_cast<FloatType>()) {
389     Optional<APFloat> result;
390     if (failed(parseFloatFromIntegerLiteral(result, tok, isNegative,
391                                             floatType.getFloatSemantics(),
392                                             floatType.getWidth())))
393       return Attribute();
394     return FloatAttr::get(floatType, *result);
395   }
396 
397   if (!type.isa<IntegerType, IndexType>())
398     return emitError(loc, "integer literal not valid for specified type"),
399            nullptr;
400 
401   if (isNegative && type.isUnsignedInteger()) {
402     emitError(loc,
403               "negative integer literal not valid for unsigned integer type");
404     return nullptr;
405   }
406 
407   Optional<APInt> apInt = buildAttributeAPInt(type, isNegative, spelling);
408   if (!apInt)
409     return emitError(loc, "integer constant out of range for attribute"),
410            nullptr;
411   return builder.getIntegerAttr(type, *apInt);
412 }
413 
414 //===----------------------------------------------------------------------===//
415 // TensorLiteralParser
416 //===----------------------------------------------------------------------===//
417 
418 /// Parse elements values stored within a hex string. On success, the values are
419 /// stored into 'result'.
420 static ParseResult parseElementAttrHexValues(Parser &parser, Token tok,
421                                              std::string &result) {
422   if (Optional<std::string> value = tok.getHexStringValue()) {
423     result = std::move(*value);
424     return success();
425   }
426   return parser.emitError(
427       tok.getLoc(), "expected string containing hex digits starting with `0x`");
428 }
429 
430 namespace {
431 /// This class implements a parser for TensorLiterals. A tensor literal is
432 /// either a single element (e.g, 5) or a multi-dimensional list of elements
433 /// (e.g., [[5, 5]]).
434 class TensorLiteralParser {
435 public:
436   TensorLiteralParser(Parser &p) : p(p) {}
437 
438   /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser
439   /// may also parse a tensor literal that is store as a hex string.
440   ParseResult parse(bool allowHex);
441 
442   /// Build a dense attribute instance with the parsed elements and the given
443   /// shaped type.
444   DenseElementsAttr getAttr(llvm::SMLoc loc, ShapedType type);
445 
446   ArrayRef<int64_t> getShape() const { return shape; }
447 
448 private:
449   /// Get the parsed elements for an integer attribute.
450   ParseResult getIntAttrElements(llvm::SMLoc loc, Type eltTy,
451                                  std::vector<APInt> &intValues);
452 
453   /// Get the parsed elements for a float attribute.
454   ParseResult getFloatAttrElements(llvm::SMLoc loc, FloatType eltTy,
455                                    std::vector<APFloat> &floatValues);
456 
457   /// Build a Dense String attribute for the given type.
458   DenseElementsAttr getStringAttr(llvm::SMLoc loc, ShapedType type, Type eltTy);
459 
460   /// Build a Dense attribute with hex data for the given type.
461   DenseElementsAttr getHexAttr(llvm::SMLoc loc, ShapedType type);
462 
463   /// Parse a single element, returning failure if it isn't a valid element
464   /// literal. For example:
465   /// parseElement(1) -> Success, 1
466   /// parseElement([1]) -> Failure
467   ParseResult parseElement();
468 
469   /// Parse a list of either lists or elements, returning the dimensions of the
470   /// parsed sub-tensors in dims. For example:
471   ///   parseList([1, 2, 3]) -> Success, [3]
472   ///   parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
473   ///   parseList([[1, 2], 3]) -> Failure
474   ///   parseList([[1, [2, 3]], [4, [5]]]) -> Failure
475   ParseResult parseList(SmallVectorImpl<int64_t> &dims);
476 
477   /// Parse a literal that was printed as a hex string.
478   ParseResult parseHexElements();
479 
480   Parser &p;
481 
482   /// The shape inferred from the parsed elements.
483   SmallVector<int64_t, 4> shape;
484 
485   /// Storage used when parsing elements, this is a pair of <is_negated, token>.
486   std::vector<std::pair<bool, Token>> storage;
487 
488   /// Storage used when parsing elements that were stored as hex values.
489   Optional<Token> hexStorage;
490 };
491 } // end anonymous namespace
492 
493 /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser
494 /// may also parse a tensor literal that is store as a hex string.
495 ParseResult TensorLiteralParser::parse(bool allowHex) {
496   // If hex is allowed, check for a string literal.
497   if (allowHex && p.getToken().is(Token::string)) {
498     hexStorage = p.getToken();
499     p.consumeToken(Token::string);
500     return success();
501   }
502   // Otherwise, parse a list or an individual element.
503   if (p.getToken().is(Token::l_square))
504     return parseList(shape);
505   return parseElement();
506 }
507 
508 /// Build a dense attribute instance with the parsed elements and the given
509 /// shaped type.
510 DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc,
511                                                ShapedType type) {
512   Type eltType = type.getElementType();
513 
514   // Check to see if we parse the literal from a hex string.
515   if (hexStorage.hasValue() &&
516       (eltType.isIntOrFloat() || eltType.isa<ComplexType>()))
517     return getHexAttr(loc, type);
518 
519   // Check that the parsed storage size has the same number of elements to the
520   // type, or is a known splat.
521   if (!shape.empty() && getShape() != type.getShape()) {
522     p.emitError(loc) << "inferred shape of elements literal ([" << getShape()
523                      << "]) does not match type ([" << type.getShape() << "])";
524     return nullptr;
525   }
526 
527   // Handle the case where no elements were parsed.
528   if (!hexStorage.hasValue() && storage.empty() && type.getNumElements()) {
529     p.emitError(loc) << "parsed zero elements, but type (" << type
530                      << ") expected at least 1";
531     return nullptr;
532   }
533 
534   // Handle complex types in the specific element type cases below.
535   bool isComplex = false;
536   if (ComplexType complexTy = eltType.dyn_cast<ComplexType>()) {
537     eltType = complexTy.getElementType();
538     isComplex = true;
539   }
540 
541   // Handle integer and index types.
542   if (eltType.isIntOrIndex()) {
543     std::vector<APInt> intValues;
544     if (failed(getIntAttrElements(loc, eltType, intValues)))
545       return nullptr;
546     if (isComplex) {
547       // If this is a complex, treat the parsed values as complex values.
548       auto complexData = llvm::makeArrayRef(
549           reinterpret_cast<std::complex<APInt> *>(intValues.data()),
550           intValues.size() / 2);
551       return DenseElementsAttr::get(type, complexData);
552     }
553     return DenseElementsAttr::get(type, intValues);
554   }
555   // Handle floating point types.
556   if (FloatType floatTy = eltType.dyn_cast<FloatType>()) {
557     std::vector<APFloat> floatValues;
558     if (failed(getFloatAttrElements(loc, floatTy, floatValues)))
559       return nullptr;
560     if (isComplex) {
561       // If this is a complex, treat the parsed values as complex values.
562       auto complexData = llvm::makeArrayRef(
563           reinterpret_cast<std::complex<APFloat> *>(floatValues.data()),
564           floatValues.size() / 2);
565       return DenseElementsAttr::get(type, complexData);
566     }
567     return DenseElementsAttr::get(type, floatValues);
568   }
569 
570   // Other types are assumed to be string representations.
571   return getStringAttr(loc, type, type.getElementType());
572 }
573 
574 /// Build a Dense Integer attribute for the given type.
575 ParseResult
576 TensorLiteralParser::getIntAttrElements(llvm::SMLoc loc, Type eltTy,
577                                         std::vector<APInt> &intValues) {
578   intValues.reserve(storage.size());
579   bool isUintType = eltTy.isUnsignedInteger();
580   for (const auto &signAndToken : storage) {
581     bool isNegative = signAndToken.first;
582     const Token &token = signAndToken.second;
583     auto tokenLoc = token.getLoc();
584 
585     if (isNegative && isUintType) {
586       return p.emitError(tokenLoc)
587              << "expected unsigned integer elements, but parsed negative value";
588     }
589 
590     // Check to see if floating point values were parsed.
591     if (token.is(Token::floatliteral)) {
592       return p.emitError(tokenLoc)
593              << "expected integer elements, but parsed floating-point";
594     }
595 
596     assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) &&
597            "unexpected token type");
598     if (token.isAny(Token::kw_true, Token::kw_false)) {
599       if (!eltTy.isInteger(1)) {
600         return p.emitError(tokenLoc)
601                << "expected i1 type for 'true' or 'false' values";
602       }
603       APInt apInt(1, token.is(Token::kw_true), /*isSigned=*/false);
604       intValues.push_back(apInt);
605       continue;
606     }
607 
608     // Create APInt values for each element with the correct bitwidth.
609     Optional<APInt> apInt =
610         buildAttributeAPInt(eltTy, isNegative, token.getSpelling());
611     if (!apInt)
612       return p.emitError(tokenLoc, "integer constant out of range for type");
613     intValues.push_back(*apInt);
614   }
615   return success();
616 }
617 
618 /// Build a Dense Float attribute for the given type.
619 ParseResult
620 TensorLiteralParser::getFloatAttrElements(llvm::SMLoc loc, FloatType eltTy,
621                                           std::vector<APFloat> &floatValues) {
622   floatValues.reserve(storage.size());
623   for (const auto &signAndToken : storage) {
624     bool isNegative = signAndToken.first;
625     const Token &token = signAndToken.second;
626 
627     // Handle hexadecimal float literals.
628     if (token.is(Token::integer) && token.getSpelling().startswith("0x")) {
629       Optional<APFloat> result;
630       if (failed(p.parseFloatFromIntegerLiteral(result, token, isNegative,
631                                                 eltTy.getFloatSemantics(),
632                                                 eltTy.getWidth())))
633         return failure();
634 
635       floatValues.push_back(*result);
636       continue;
637     }
638 
639     // Check to see if any decimal integers or booleans were parsed.
640     if (!token.is(Token::floatliteral))
641       return p.emitError()
642              << "expected floating-point elements, but parsed integer";
643 
644     // Build the float values from tokens.
645     auto val = token.getFloatingPointValue();
646     if (!val.hasValue())
647       return p.emitError("floating point value too large for attribute");
648 
649     APFloat apVal(isNegative ? -*val : *val);
650     if (!eltTy.isF64()) {
651       bool unused;
652       apVal.convert(eltTy.getFloatSemantics(), APFloat::rmNearestTiesToEven,
653                     &unused);
654     }
655     floatValues.push_back(apVal);
656   }
657   return success();
658 }
659 
660 /// Build a Dense String attribute for the given type.
661 DenseElementsAttr TensorLiteralParser::getStringAttr(llvm::SMLoc loc,
662                                                      ShapedType type,
663                                                      Type eltTy) {
664   if (hexStorage.hasValue()) {
665     auto stringValue = hexStorage.getValue().getStringValue();
666     return DenseStringElementsAttr::get(type, {stringValue});
667   }
668 
669   std::vector<std::string> stringValues;
670   std::vector<StringRef> stringRefValues;
671   stringValues.reserve(storage.size());
672   stringRefValues.reserve(storage.size());
673 
674   for (auto val : storage) {
675     stringValues.push_back(val.second.getStringValue());
676     stringRefValues.push_back(stringValues.back());
677   }
678 
679   return DenseStringElementsAttr::get(type, stringRefValues);
680 }
681 
682 /// Build a Dense attribute with hex data for the given type.
683 DenseElementsAttr TensorLiteralParser::getHexAttr(llvm::SMLoc loc,
684                                                   ShapedType type) {
685   Type elementType = type.getElementType();
686   if (!elementType.isIntOrIndexOrFloat() && !elementType.isa<ComplexType>()) {
687     p.emitError(loc)
688         << "expected floating-point, integer, or complex element type, got "
689         << elementType;
690     return nullptr;
691   }
692 
693   std::string data;
694   if (parseElementAttrHexValues(p, hexStorage.getValue(), data))
695     return nullptr;
696 
697   ArrayRef<char> rawData(data.data(), data.size());
698   bool detectedSplat = false;
699   if (!DenseElementsAttr::isValidRawBuffer(type, rawData, detectedSplat)) {
700     p.emitError(loc) << "elements hex data size is invalid for provided type: "
701                      << type;
702     return nullptr;
703   }
704 
705   if (llvm::support::endian::system_endianness() ==
706       llvm::support::endianness::big) {
707     // Convert endianess in big-endian(BE) machines. `rawData` is
708     // little-endian(LE) because HEX in raw data of dense element attribute
709     // is always LE format. It is converted into BE here to be used in BE
710     // machines.
711     SmallVector<char, 64> outDataVec(rawData.size());
712     MutableArrayRef<char> convRawData(outDataVec);
713     DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
714         rawData, convRawData, type);
715     return DenseElementsAttr::getFromRawBuffer(type, convRawData,
716                                                detectedSplat);
717   }
718 
719   return DenseElementsAttr::getFromRawBuffer(type, rawData, detectedSplat);
720 }
721 
722 ParseResult TensorLiteralParser::parseElement() {
723   switch (p.getToken().getKind()) {
724   // Parse a boolean element.
725   case Token::kw_true:
726   case Token::kw_false:
727   case Token::floatliteral:
728   case Token::integer:
729     storage.emplace_back(/*isNegative=*/false, p.getToken());
730     p.consumeToken();
731     break;
732 
733   // Parse a signed integer or a negative floating-point element.
734   case Token::minus:
735     p.consumeToken(Token::minus);
736     if (!p.getToken().isAny(Token::floatliteral, Token::integer))
737       return p.emitError("expected integer or floating point literal");
738     storage.emplace_back(/*isNegative=*/true, p.getToken());
739     p.consumeToken();
740     break;
741 
742   case Token::string:
743     storage.emplace_back(/*isNegative=*/false, p.getToken());
744     p.consumeToken();
745     break;
746 
747   // Parse a complex element of the form '(' element ',' element ')'.
748   case Token::l_paren:
749     p.consumeToken(Token::l_paren);
750     if (parseElement() ||
751         p.parseToken(Token::comma, "expected ',' between complex elements") ||
752         parseElement() ||
753         p.parseToken(Token::r_paren, "expected ')' after complex elements"))
754       return failure();
755     break;
756 
757   default:
758     return p.emitError("expected element literal of primitive type");
759   }
760 
761   return success();
762 }
763 
764 /// Parse a list of either lists or elements, returning the dimensions of the
765 /// parsed sub-tensors in dims. For example:
766 ///   parseList([1, 2, 3]) -> Success, [3]
767 ///   parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
768 ///   parseList([[1, 2], 3]) -> Failure
769 ///   parseList([[1, [2, 3]], [4, [5]]]) -> Failure
770 ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) {
771   p.consumeToken(Token::l_square);
772 
773   auto checkDims = [&](const SmallVectorImpl<int64_t> &prevDims,
774                        const SmallVectorImpl<int64_t> &newDims) -> ParseResult {
775     if (prevDims == newDims)
776       return success();
777     return p.emitError("tensor literal is invalid; ranks are not consistent "
778                        "between elements");
779   };
780 
781   bool first = true;
782   SmallVector<int64_t, 4> newDims;
783   unsigned size = 0;
784   auto parseCommaSeparatedList = [&]() -> ParseResult {
785     SmallVector<int64_t, 4> thisDims;
786     if (p.getToken().getKind() == Token::l_square) {
787       if (parseList(thisDims))
788         return failure();
789     } else if (parseElement()) {
790       return failure();
791     }
792     ++size;
793     if (!first)
794       return checkDims(newDims, thisDims);
795     newDims = thisDims;
796     first = false;
797     return success();
798   };
799   if (p.parseCommaSeparatedListUntil(Token::r_square, parseCommaSeparatedList))
800     return failure();
801 
802   // Return the sublists' dimensions with 'size' prepended.
803   dims.clear();
804   dims.push_back(size);
805   dims.append(newDims.begin(), newDims.end());
806   return success();
807 }
808 
809 //===----------------------------------------------------------------------===//
810 // ElementsAttr Parser
811 //===----------------------------------------------------------------------===//
812 
813 /// Parse a dense elements attribute.
814 Attribute Parser::parseDenseElementsAttr(Type attrType) {
815   auto attribLoc = getToken().getLoc();
816   consumeToken(Token::kw_dense);
817   if (parseToken(Token::less, "expected '<' after 'dense'"))
818     return nullptr;
819 
820   // Parse the literal data if necessary.
821   TensorLiteralParser literalParser(*this);
822   if (!consumeIf(Token::greater)) {
823     if (literalParser.parse(/*allowHex=*/true) ||
824         parseToken(Token::greater, "expected '>'"))
825       return nullptr;
826   }
827 
828   // If the type is specified `parseElementsLiteralType` will not parse a type.
829   // Use the attribute location as the location for error reporting in that
830   // case.
831   auto loc = attrType ? attribLoc : getToken().getLoc();
832   auto type = parseElementsLiteralType(attrType);
833   if (!type)
834     return nullptr;
835   return literalParser.getAttr(loc, type);
836 }
837 
838 /// Parse an opaque elements attribute.
839 Attribute Parser::parseOpaqueElementsAttr(Type attrType) {
840   consumeToken(Token::kw_opaque);
841   if (parseToken(Token::less, "expected '<' after 'opaque'"))
842     return nullptr;
843 
844   if (getToken().isNot(Token::string))
845     return (emitError("expected dialect namespace"), nullptr);
846 
847   std::string name = getToken().getStringValue();
848   consumeToken(Token::string);
849 
850   if (parseToken(Token::comma, "expected ','"))
851     return nullptr;
852 
853   Token hexTok = getToken();
854   if (parseToken(Token::string, "elements hex string should start with '0x'") ||
855       parseToken(Token::greater, "expected '>'"))
856     return nullptr;
857   auto type = parseElementsLiteralType(attrType);
858   if (!type)
859     return nullptr;
860 
861   std::string data;
862   if (parseElementAttrHexValues(*this, hexTok, data))
863     return nullptr;
864   return OpaqueElementsAttr::get(builder.getIdentifier(name), type, data);
865 }
866 
867 /// Shaped type for elements attribute.
868 ///
869 ///   elements-literal-type ::= vector-type | ranked-tensor-type
870 ///
871 /// This method also checks the type has static shape.
872 ShapedType Parser::parseElementsLiteralType(Type type) {
873   // If the user didn't provide a type, parse the colon type for the literal.
874   if (!type) {
875     if (parseToken(Token::colon, "expected ':'"))
876       return nullptr;
877     if (!(type = parseType()))
878       return nullptr;
879   }
880 
881   if (!type.isa<RankedTensorType, VectorType>()) {
882     emitError("elements literal must be a ranked tensor or vector type");
883     return nullptr;
884   }
885 
886   auto sType = type.cast<ShapedType>();
887   if (!sType.hasStaticShape())
888     return (emitError("elements literal type must have static shape"), nullptr);
889 
890   return sType;
891 }
892 
893 /// Parse a sparse elements attribute.
894 Attribute Parser::parseSparseElementsAttr(Type attrType) {
895   consumeToken(Token::kw_sparse);
896   if (parseToken(Token::less, "Expected '<' after 'sparse'"))
897     return nullptr;
898 
899   // Check for the case where all elements are sparse. The indices are
900   // represented by a 2-dimensional shape where the second dimension is the rank
901   // of the type.
902   Type indiceEltType = builder.getIntegerType(64);
903   if (consumeIf(Token::greater)) {
904     ShapedType type = parseElementsLiteralType(attrType);
905     if (!type)
906       return nullptr;
907 
908     // Construct the sparse elements attr using zero element indice/value
909     // attributes.
910     ShapedType indicesType =
911         RankedTensorType::get({0, type.getRank()}, indiceEltType);
912     ShapedType valuesType = RankedTensorType::get({0}, type.getElementType());
913     return SparseElementsAttr::get(
914         type, DenseElementsAttr::get(indicesType, ArrayRef<Attribute>()),
915         DenseElementsAttr::get(valuesType, ArrayRef<Attribute>()));
916   }
917 
918   /// Parse the indices. We don't allow hex values here as we may need to use
919   /// the inferred shape.
920   auto indicesLoc = getToken().getLoc();
921   TensorLiteralParser indiceParser(*this);
922   if (indiceParser.parse(/*allowHex=*/false))
923     return nullptr;
924 
925   if (parseToken(Token::comma, "expected ','"))
926     return nullptr;
927 
928   /// Parse the values.
929   auto valuesLoc = getToken().getLoc();
930   TensorLiteralParser valuesParser(*this);
931   if (valuesParser.parse(/*allowHex=*/true))
932     return nullptr;
933 
934   if (parseToken(Token::greater, "expected '>'"))
935     return nullptr;
936 
937   auto type = parseElementsLiteralType(attrType);
938   if (!type)
939     return nullptr;
940 
941   // If the indices are a splat, i.e. the literal parser parsed an element and
942   // not a list, we set the shape explicitly. The indices are represented by a
943   // 2-dimensional shape where the second dimension is the rank of the type.
944   // Given that the parsed indices is a splat, we know that we only have one
945   // indice and thus one for the first dimension.
946   ShapedType indicesType;
947   if (indiceParser.getShape().empty()) {
948     indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType);
949   } else {
950     // Otherwise, set the shape to the one parsed by the literal parser.
951     indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType);
952   }
953   auto indices = indiceParser.getAttr(indicesLoc, indicesType);
954 
955   // If the values are a splat, set the shape explicitly based on the number of
956   // indices. The number of indices is encoded in the first dimension of the
957   // indice shape type.
958   auto valuesEltType = type.getElementType();
959   ShapedType valuesType =
960       valuesParser.getShape().empty()
961           ? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType)
962           : RankedTensorType::get(valuesParser.getShape(), valuesEltType);
963   auto values = valuesParser.getAttr(valuesLoc, valuesType);
964 
965   /// Sanity check.
966   if (valuesType.getRank() != 1)
967     return (emitError("expected 1-d tensor for values"), nullptr);
968 
969   auto sameShape = (indicesType.getRank() == 1) ||
970                    (type.getRank() == indicesType.getDimSize(1));
971   auto sameElementNum = indicesType.getDimSize(0) == valuesType.getDimSize(0);
972   if (!sameShape || !sameElementNum) {
973     emitError() << "expected shape ([" << type.getShape()
974                 << "]); inferred shape of indices literal (["
975                 << indicesType.getShape()
976                 << "]); inferred shape of values literal (["
977                 << valuesType.getShape() << "])";
978     return nullptr;
979   }
980 
981   // Build the sparse elements attribute by the indices and values.
982   return SparseElementsAttr::get(type, indices, values);
983 }
984