1 //===- AttributeParser.cpp - MLIR Attribute Parser Implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements the parser for the MLIR Attributes.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Parser.h"
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Dialect.h"
17 #include "mlir/IR/IntegerSet.h"
18 #include "mlir/Parser/AsmParserState.h"
19 #include "llvm/ADT/StringExtras.h"
20 #include "llvm/Support/Endian.h"
21 
22 using namespace mlir;
23 using namespace mlir::detail;
24 
25 /// Parse an arbitrary attribute.
26 ///
27 ///  attribute-value ::= `unit`
28 ///                    | bool-literal
29 ///                    | integer-literal (`:` (index-type | integer-type))?
30 ///                    | float-literal (`:` float-type)?
31 ///                    | string-literal (`:` type)?
32 ///                    | type
33 ///                    | `[` (attribute-value (`,` attribute-value)*)? `]`
34 ///                    | `{` (attribute-entry (`,` attribute-entry)*)? `}`
35 ///                    | symbol-ref-id (`::` symbol-ref-id)*
36 ///                    | `dense` `<` attribute-value `>` `:`
37 ///                      (tensor-type | vector-type)
38 ///                    | `sparse` `<` attribute-value `,` attribute-value `>`
39 ///                      `:` (tensor-type | vector-type)
40 ///                    | `opaque` `<` dialect-namespace  `,` hex-string-literal
41 ///                      `>` `:` (tensor-type | vector-type)
42 ///                    | extended-attribute
43 ///
parseAttribute(Type type)44 Attribute Parser::parseAttribute(Type type) {
45   switch (getToken().getKind()) {
46   // Parse an AffineMap or IntegerSet attribute.
47   case Token::kw_affine_map: {
48     consumeToken(Token::kw_affine_map);
49 
50     AffineMap map;
51     if (parseToken(Token::less, "expected '<' in affine map") ||
52         parseAffineMapReference(map) ||
53         parseToken(Token::greater, "expected '>' in affine map"))
54       return Attribute();
55     return AffineMapAttr::get(map);
56   }
57   case Token::kw_affine_set: {
58     consumeToken(Token::kw_affine_set);
59 
60     IntegerSet set;
61     if (parseToken(Token::less, "expected '<' in integer set") ||
62         parseIntegerSetReference(set) ||
63         parseToken(Token::greater, "expected '>' in integer set"))
64       return Attribute();
65     return IntegerSetAttr::get(set);
66   }
67 
68   // Parse an array attribute.
69   case Token::l_square: {
70     SmallVector<Attribute, 4> elements;
71     auto parseElt = [&]() -> ParseResult {
72       elements.push_back(parseAttribute());
73       return elements.back() ? success() : failure();
74     };
75 
76     if (parseCommaSeparatedList(Delimiter::Square, parseElt))
77       return nullptr;
78     return builder.getArrayAttr(elements);
79   }
80 
81   // Parse a boolean attribute.
82   case Token::kw_false:
83     consumeToken(Token::kw_false);
84     return builder.getBoolAttr(false);
85   case Token::kw_true:
86     consumeToken(Token::kw_true);
87     return builder.getBoolAttr(true);
88 
89   // Parse a dense elements attribute.
90   case Token::kw_dense:
91     return parseDenseElementsAttr(type);
92 
93   // Parse a dictionary attribute.
94   case Token::l_brace: {
95     NamedAttrList elements;
96     if (parseAttributeDict(elements))
97       return nullptr;
98     return elements.getDictionary(getContext());
99   }
100 
101   // Parse an extended attribute, i.e. alias or dialect attribute.
102   case Token::hash_identifier:
103     return parseExtendedAttr(type);
104 
105   // Parse floating point and integer attributes.
106   case Token::floatliteral:
107     return parseFloatAttr(type, /*isNegative=*/false);
108   case Token::integer:
109     return parseDecOrHexAttr(type, /*isNegative=*/false);
110   case Token::minus: {
111     consumeToken(Token::minus);
112     if (getToken().is(Token::integer))
113       return parseDecOrHexAttr(type, /*isNegative=*/true);
114     if (getToken().is(Token::floatliteral))
115       return parseFloatAttr(type, /*isNegative=*/true);
116 
117     return (emitError("expected constant integer or floating point value"),
118             nullptr);
119   }
120 
121   // Parse a location attribute.
122   case Token::kw_loc: {
123     consumeToken(Token::kw_loc);
124 
125     LocationAttr locAttr;
126     if (parseToken(Token::l_paren, "expected '(' in inline location") ||
127         parseLocationInstance(locAttr) ||
128         parseToken(Token::r_paren, "expected ')' in inline location"))
129       return Attribute();
130     return locAttr;
131   }
132 
133   // Parse an opaque elements attribute.
134   case Token::kw_opaque:
135     return parseOpaqueElementsAttr(type);
136 
137   // Parse a sparse elements attribute.
138   case Token::kw_sparse:
139     return parseSparseElementsAttr(type);
140 
141   // Parse a string attribute.
142   case Token::string: {
143     auto val = getToken().getStringValue();
144     consumeToken(Token::string);
145     // Parse the optional trailing colon type if one wasn't explicitly provided.
146     if (!type && consumeIf(Token::colon) && !(type = parseType()))
147       return Attribute();
148 
149     return type ? StringAttr::get(val, type)
150                 : StringAttr::get(getContext(), val);
151   }
152 
153   // Parse a symbol reference attribute.
154   case Token::at_identifier: {
155     // When populating the parser state, this is a list of locations for all of
156     // the nested references.
157     SmallVector<llvm::SMRange> referenceLocations;
158     if (state.asmState)
159       referenceLocations.push_back(getToken().getLocRange());
160 
161     // Parse the top-level reference.
162     std::string nameStr = getToken().getSymbolReference();
163     consumeToken(Token::at_identifier);
164 
165     // Parse any nested references.
166     std::vector<FlatSymbolRefAttr> nestedRefs;
167     while (getToken().is(Token::colon)) {
168       // Check for the '::' prefix.
169       const char *curPointer = getToken().getLoc().getPointer();
170       consumeToken(Token::colon);
171       if (!consumeIf(Token::colon)) {
172         state.lex.resetPointer(curPointer);
173         consumeToken();
174         break;
175       }
176       // Parse the reference itself.
177       auto curLoc = getToken().getLoc();
178       if (getToken().isNot(Token::at_identifier)) {
179         emitError(curLoc, "expected nested symbol reference identifier");
180         return Attribute();
181       }
182 
183       // If we are populating the assembly state, add the location for this
184       // reference.
185       if (state.asmState)
186         referenceLocations.push_back(getToken().getLocRange());
187 
188       std::string nameStr = getToken().getSymbolReference();
189       consumeToken(Token::at_identifier);
190       nestedRefs.push_back(SymbolRefAttr::get(getContext(), nameStr));
191     }
192     SymbolRefAttr symbolRefAttr =
193         SymbolRefAttr::get(getContext(), nameStr, nestedRefs);
194 
195     // If we are populating the assembly state, record this symbol reference.
196     if (state.asmState)
197       state.asmState->addUses(symbolRefAttr, referenceLocations);
198     return symbolRefAttr;
199   }
200 
201   // Parse a 'unit' attribute.
202   case Token::kw_unit:
203     consumeToken(Token::kw_unit);
204     return builder.getUnitAttr();
205 
206   default:
207     // Parse a type attribute.
208     if (Type type = parseType())
209       return TypeAttr::get(type);
210     return nullptr;
211   }
212 }
213 
214 /// Parse an optional attribute with the provided type.
parseOptionalAttribute(Attribute & attribute,Type type)215 OptionalParseResult Parser::parseOptionalAttribute(Attribute &attribute,
216                                                    Type type) {
217   switch (getToken().getKind()) {
218   case Token::at_identifier:
219   case Token::floatliteral:
220   case Token::integer:
221   case Token::hash_identifier:
222   case Token::kw_affine_map:
223   case Token::kw_affine_set:
224   case Token::kw_dense:
225   case Token::kw_false:
226   case Token::kw_loc:
227   case Token::kw_opaque:
228   case Token::kw_sparse:
229   case Token::kw_true:
230   case Token::kw_unit:
231   case Token::l_brace:
232   case Token::l_square:
233   case Token::minus:
234   case Token::string:
235     attribute = parseAttribute(type);
236     return success(attribute != nullptr);
237 
238   default:
239     // Parse an optional type attribute.
240     Type type;
241     OptionalParseResult result = parseOptionalType(type);
242     if (result.hasValue() && succeeded(*result))
243       attribute = TypeAttr::get(type);
244     return result;
245   }
246 }
parseOptionalAttribute(ArrayAttr & attribute,Type type)247 OptionalParseResult Parser::parseOptionalAttribute(ArrayAttr &attribute,
248                                                    Type type) {
249   return parseOptionalAttributeWithToken(Token::l_square, attribute, type);
250 }
parseOptionalAttribute(StringAttr & attribute,Type type)251 OptionalParseResult Parser::parseOptionalAttribute(StringAttr &attribute,
252                                                    Type type) {
253   return parseOptionalAttributeWithToken(Token::string, attribute, type);
254 }
255 
256 /// Attribute dictionary.
257 ///
258 ///   attribute-dict ::= `{` `}`
259 ///                    | `{` attribute-entry (`,` attribute-entry)* `}`
260 ///   attribute-entry ::= (bare-id | string-literal) `=` attribute-value
261 ///
parseAttributeDict(NamedAttrList & attributes)262 ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
263   llvm::SmallDenseSet<Identifier> seenKeys;
264   auto parseElt = [&]() -> ParseResult {
265     // The name of an attribute can either be a bare identifier, or a string.
266     Optional<Identifier> nameId;
267     if (getToken().is(Token::string))
268       nameId = builder.getIdentifier(getToken().getStringValue());
269     else if (getToken().isAny(Token::bare_identifier, Token::inttype) ||
270              getToken().isKeyword())
271       nameId = builder.getIdentifier(getTokenSpelling());
272     else
273       return emitError("expected attribute name");
274     if (!seenKeys.insert(*nameId).second)
275       return emitError("duplicate key '")
276              << *nameId << "' in dictionary attribute";
277     consumeToken();
278 
279     // Lazy load a dialect in the context if there is a possible namespace.
280     auto splitName = nameId->strref().split('.');
281     if (!splitName.second.empty())
282       getContext()->getOrLoadDialect(splitName.first);
283 
284     // Try to parse the '=' for the attribute value.
285     if (!consumeIf(Token::equal)) {
286       // If there is no '=', we treat this as a unit attribute.
287       attributes.push_back({*nameId, builder.getUnitAttr()});
288       return success();
289     }
290 
291     auto attr = parseAttribute();
292     if (!attr)
293       return failure();
294     attributes.push_back({*nameId, attr});
295     return success();
296   };
297 
298   if (parseCommaSeparatedList(Delimiter::Braces, parseElt,
299                               " in attribute dictionary"))
300     return failure();
301 
302   return success();
303 }
304 
305 /// Parse a float attribute.
parseFloatAttr(Type type,bool isNegative)306 Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
307   auto val = getToken().getFloatingPointValue();
308   if (!val.hasValue())
309     return (emitError("floating point value too large for attribute"), nullptr);
310   consumeToken(Token::floatliteral);
311   if (!type) {
312     // Default to F64 when no type is specified.
313     if (!consumeIf(Token::colon))
314       type = builder.getF64Type();
315     else if (!(type = parseType()))
316       return nullptr;
317   }
318   if (!type.isa<FloatType>())
319     return (emitError("floating point value not valid for specified type"),
320             nullptr);
321   return FloatAttr::get(type, isNegative ? -val.getValue() : val.getValue());
322 }
323 
324 /// Construct an APint from a parsed value, a known attribute type and
325 /// sign.
buildAttributeAPInt(Type type,bool isNegative,StringRef spelling)326 static Optional<APInt> buildAttributeAPInt(Type type, bool isNegative,
327                                            StringRef spelling) {
328   // Parse the integer value into an APInt that is big enough to hold the value.
329   APInt result;
330   bool isHex = spelling.size() > 1 && spelling[1] == 'x';
331   if (spelling.getAsInteger(isHex ? 0 : 10, result))
332     return llvm::None;
333 
334   // Extend or truncate the bitwidth to the right size.
335   unsigned width = type.isIndex() ? IndexType::kInternalStorageBitWidth
336                                   : type.getIntOrFloatBitWidth();
337 
338   // APInt cannot hold a zero bit value.
339   if (width == 0)
340     return llvm::None;
341 
342   if (width > result.getBitWidth()) {
343     result = result.zext(width);
344   } else if (width < result.getBitWidth()) {
345     // The parser can return an unnecessarily wide result with leading zeros.
346     // This isn't a problem, but truncating off bits is bad.
347     if (result.countLeadingZeros() < result.getBitWidth() - width)
348       return llvm::None;
349 
350     result = result.trunc(width);
351   }
352 
353   if (isNegative) {
354     // The value is negative, we have an overflow if the sign bit is not set
355     // in the negated apInt.
356     result.negate();
357     if (!result.isSignBitSet())
358       return llvm::None;
359   } else if ((type.isSignedInteger() || type.isIndex()) &&
360              result.isSignBitSet()) {
361     // The value is a positive signed integer or index,
362     // we have an overflow if the sign bit is set.
363     return llvm::None;
364   }
365 
366   return result;
367 }
368 
369 /// Parse a decimal or a hexadecimal literal, which can be either an integer
370 /// or a float attribute.
parseDecOrHexAttr(Type type,bool isNegative)371 Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
372   Token tok = getToken();
373   StringRef spelling = tok.getSpelling();
374   llvm::SMLoc loc = tok.getLoc();
375 
376   consumeToken(Token::integer);
377   if (!type) {
378     // Default to i64 if not type is specified.
379     if (!consumeIf(Token::colon))
380       type = builder.getIntegerType(64);
381     else if (!(type = parseType()))
382       return nullptr;
383   }
384 
385   if (auto floatType = type.dyn_cast<FloatType>()) {
386     Optional<APFloat> result;
387     if (failed(parseFloatFromIntegerLiteral(result, tok, isNegative,
388                                             floatType.getFloatSemantics(),
389                                             floatType.getWidth())))
390       return Attribute();
391     return FloatAttr::get(floatType, *result);
392   }
393 
394   if (!type.isa<IntegerType, IndexType>())
395     return emitError(loc, "integer literal not valid for specified type"),
396            nullptr;
397 
398   if (isNegative && type.isUnsignedInteger()) {
399     emitError(loc,
400               "negative integer literal not valid for unsigned integer type");
401     return nullptr;
402   }
403 
404   Optional<APInt> apInt = buildAttributeAPInt(type, isNegative, spelling);
405   if (!apInt)
406     return emitError(loc, "integer constant out of range for attribute"),
407            nullptr;
408   return builder.getIntegerAttr(type, *apInt);
409 }
410 
411 //===----------------------------------------------------------------------===//
412 // TensorLiteralParser
413 //===----------------------------------------------------------------------===//
414 
415 /// Parse elements values stored within a hex string. On success, the values are
416 /// stored into 'result'.
parseElementAttrHexValues(Parser & parser,Token tok,std::string & result)417 static ParseResult parseElementAttrHexValues(Parser &parser, Token tok,
418                                              std::string &result) {
419   if (Optional<std::string> value = tok.getHexStringValue()) {
420     result = std::move(*value);
421     return success();
422   }
423   return parser.emitError(
424       tok.getLoc(), "expected string containing hex digits starting with `0x`");
425 }
426 
427 namespace {
428 /// This class implements a parser for TensorLiterals. A tensor literal is
429 /// either a single element (e.g, 5) or a multi-dimensional list of elements
430 /// (e.g., [[5, 5]]).
431 class TensorLiteralParser {
432 public:
TensorLiteralParser(Parser & p)433   TensorLiteralParser(Parser &p) : p(p) {}
434 
435   /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser
436   /// may also parse a tensor literal that is store as a hex string.
437   ParseResult parse(bool allowHex);
438 
439   /// Build a dense attribute instance with the parsed elements and the given
440   /// shaped type.
441   DenseElementsAttr getAttr(llvm::SMLoc loc, ShapedType type);
442 
getShape() const443   ArrayRef<int64_t> getShape() const { return shape; }
444 
445 private:
446   /// Get the parsed elements for an integer attribute.
447   ParseResult getIntAttrElements(llvm::SMLoc loc, Type eltTy,
448                                  std::vector<APInt> &intValues);
449 
450   /// Get the parsed elements for a float attribute.
451   ParseResult getFloatAttrElements(llvm::SMLoc loc, FloatType eltTy,
452                                    std::vector<APFloat> &floatValues);
453 
454   /// Build a Dense String attribute for the given type.
455   DenseElementsAttr getStringAttr(llvm::SMLoc loc, ShapedType type, Type eltTy);
456 
457   /// Build a Dense attribute with hex data for the given type.
458   DenseElementsAttr getHexAttr(llvm::SMLoc loc, ShapedType type);
459 
460   /// Parse a single element, returning failure if it isn't a valid element
461   /// literal. For example:
462   /// parseElement(1) -> Success, 1
463   /// parseElement([1]) -> Failure
464   ParseResult parseElement();
465 
466   /// Parse a list of either lists or elements, returning the dimensions of the
467   /// parsed sub-tensors in dims. For example:
468   ///   parseList([1, 2, 3]) -> Success, [3]
469   ///   parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
470   ///   parseList([[1, 2], 3]) -> Failure
471   ///   parseList([[1, [2, 3]], [4, [5]]]) -> Failure
472   ParseResult parseList(SmallVectorImpl<int64_t> &dims);
473 
474   /// Parse a literal that was printed as a hex string.
475   ParseResult parseHexElements();
476 
477   Parser &p;
478 
479   /// The shape inferred from the parsed elements.
480   SmallVector<int64_t, 4> shape;
481 
482   /// Storage used when parsing elements, this is a pair of <is_negated, token>.
483   std::vector<std::pair<bool, Token>> storage;
484 
485   /// Storage used when parsing elements that were stored as hex values.
486   Optional<Token> hexStorage;
487 };
488 } // end anonymous namespace
489 
490 /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser
491 /// may also parse a tensor literal that is store as a hex string.
parse(bool allowHex)492 ParseResult TensorLiteralParser::parse(bool allowHex) {
493   // If hex is allowed, check for a string literal.
494   if (allowHex && p.getToken().is(Token::string)) {
495     hexStorage = p.getToken();
496     p.consumeToken(Token::string);
497     return success();
498   }
499   // Otherwise, parse a list or an individual element.
500   if (p.getToken().is(Token::l_square))
501     return parseList(shape);
502   return parseElement();
503 }
504 
505 /// Build a dense attribute instance with the parsed elements and the given
506 /// shaped type.
getAttr(llvm::SMLoc loc,ShapedType type)507 DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc,
508                                                ShapedType type) {
509   Type eltType = type.getElementType();
510 
511   // Check to see if we parse the literal from a hex string.
512   if (hexStorage.hasValue() &&
513       (eltType.isIntOrIndexOrFloat() || eltType.isa<ComplexType>()))
514     return getHexAttr(loc, type);
515 
516   // Check that the parsed storage size has the same number of elements to the
517   // type, or is a known splat.
518   if (!shape.empty() && getShape() != type.getShape()) {
519     p.emitError(loc) << "inferred shape of elements literal ([" << getShape()
520                      << "]) does not match type ([" << type.getShape() << "])";
521     return nullptr;
522   }
523 
524   // Handle the case where no elements were parsed.
525   if (!hexStorage.hasValue() && storage.empty() && type.getNumElements()) {
526     p.emitError(loc) << "parsed zero elements, but type (" << type
527                      << ") expected at least 1";
528     return nullptr;
529   }
530 
531   // Handle complex types in the specific element type cases below.
532   bool isComplex = false;
533   if (ComplexType complexTy = eltType.dyn_cast<ComplexType>()) {
534     eltType = complexTy.getElementType();
535     isComplex = true;
536   }
537 
538   // Handle integer and index types.
539   if (eltType.isIntOrIndex()) {
540     std::vector<APInt> intValues;
541     if (failed(getIntAttrElements(loc, eltType, intValues)))
542       return nullptr;
543     if (isComplex) {
544       // If this is a complex, treat the parsed values as complex values.
545       auto complexData = llvm::makeArrayRef(
546           reinterpret_cast<std::complex<APInt> *>(intValues.data()),
547           intValues.size() / 2);
548       return DenseElementsAttr::get(type, complexData);
549     }
550     return DenseElementsAttr::get(type, intValues);
551   }
552   // Handle floating point types.
553   if (FloatType floatTy = eltType.dyn_cast<FloatType>()) {
554     std::vector<APFloat> floatValues;
555     if (failed(getFloatAttrElements(loc, floatTy, floatValues)))
556       return nullptr;
557     if (isComplex) {
558       // If this is a complex, treat the parsed values as complex values.
559       auto complexData = llvm::makeArrayRef(
560           reinterpret_cast<std::complex<APFloat> *>(floatValues.data()),
561           floatValues.size() / 2);
562       return DenseElementsAttr::get(type, complexData);
563     }
564     return DenseElementsAttr::get(type, floatValues);
565   }
566 
567   // Other types are assumed to be string representations.
568   return getStringAttr(loc, type, type.getElementType());
569 }
570 
571 /// Build a Dense Integer attribute for the given type.
572 ParseResult
getIntAttrElements(llvm::SMLoc loc,Type eltTy,std::vector<APInt> & intValues)573 TensorLiteralParser::getIntAttrElements(llvm::SMLoc loc, Type eltTy,
574                                         std::vector<APInt> &intValues) {
575   intValues.reserve(storage.size());
576   bool isUintType = eltTy.isUnsignedInteger();
577   for (const auto &signAndToken : storage) {
578     bool isNegative = signAndToken.first;
579     const Token &token = signAndToken.second;
580     auto tokenLoc = token.getLoc();
581 
582     if (isNegative && isUintType) {
583       return p.emitError(tokenLoc)
584              << "expected unsigned integer elements, but parsed negative value";
585     }
586 
587     // Check to see if floating point values were parsed.
588     if (token.is(Token::floatliteral)) {
589       return p.emitError(tokenLoc)
590              << "expected integer elements, but parsed floating-point";
591     }
592 
593     assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) &&
594            "unexpected token type");
595     if (token.isAny(Token::kw_true, Token::kw_false)) {
596       if (!eltTy.isInteger(1)) {
597         return p.emitError(tokenLoc)
598                << "expected i1 type for 'true' or 'false' values";
599       }
600       APInt apInt(1, token.is(Token::kw_true), /*isSigned=*/false);
601       intValues.push_back(apInt);
602       continue;
603     }
604 
605     // Create APInt values for each element with the correct bitwidth.
606     Optional<APInt> apInt =
607         buildAttributeAPInt(eltTy, isNegative, token.getSpelling());
608     if (!apInt)
609       return p.emitError(tokenLoc, "integer constant out of range for type");
610     intValues.push_back(*apInt);
611   }
612   return success();
613 }
614 
615 /// Build a Dense Float attribute for the given type.
616 ParseResult
getFloatAttrElements(llvm::SMLoc loc,FloatType eltTy,std::vector<APFloat> & floatValues)617 TensorLiteralParser::getFloatAttrElements(llvm::SMLoc loc, FloatType eltTy,
618                                           std::vector<APFloat> &floatValues) {
619   floatValues.reserve(storage.size());
620   for (const auto &signAndToken : storage) {
621     bool isNegative = signAndToken.first;
622     const Token &token = signAndToken.second;
623 
624     // Handle hexadecimal float literals.
625     if (token.is(Token::integer) && token.getSpelling().startswith("0x")) {
626       Optional<APFloat> result;
627       if (failed(p.parseFloatFromIntegerLiteral(result, token, isNegative,
628                                                 eltTy.getFloatSemantics(),
629                                                 eltTy.getWidth())))
630         return failure();
631 
632       floatValues.push_back(*result);
633       continue;
634     }
635 
636     // Check to see if any decimal integers or booleans were parsed.
637     if (!token.is(Token::floatliteral))
638       return p.emitError()
639              << "expected floating-point elements, but parsed integer";
640 
641     // Build the float values from tokens.
642     auto val = token.getFloatingPointValue();
643     if (!val.hasValue())
644       return p.emitError("floating point value too large for attribute");
645 
646     APFloat apVal(isNegative ? -*val : *val);
647     if (!eltTy.isF64()) {
648       bool unused;
649       apVal.convert(eltTy.getFloatSemantics(), APFloat::rmNearestTiesToEven,
650                     &unused);
651     }
652     floatValues.push_back(apVal);
653   }
654   return success();
655 }
656 
657 /// Build a Dense String attribute for the given type.
getStringAttr(llvm::SMLoc loc,ShapedType type,Type eltTy)658 DenseElementsAttr TensorLiteralParser::getStringAttr(llvm::SMLoc loc,
659                                                      ShapedType type,
660                                                      Type eltTy) {
661   if (hexStorage.hasValue()) {
662     auto stringValue = hexStorage.getValue().getStringValue();
663     return DenseStringElementsAttr::get(type, {stringValue});
664   }
665 
666   std::vector<std::string> stringValues;
667   std::vector<StringRef> stringRefValues;
668   stringValues.reserve(storage.size());
669   stringRefValues.reserve(storage.size());
670 
671   for (auto val : storage) {
672     stringValues.push_back(val.second.getStringValue());
673     stringRefValues.push_back(stringValues.back());
674   }
675 
676   return DenseStringElementsAttr::get(type, stringRefValues);
677 }
678 
679 /// Build a Dense attribute with hex data for the given type.
getHexAttr(llvm::SMLoc loc,ShapedType type)680 DenseElementsAttr TensorLiteralParser::getHexAttr(llvm::SMLoc loc,
681                                                   ShapedType type) {
682   Type elementType = type.getElementType();
683   if (!elementType.isIntOrIndexOrFloat() && !elementType.isa<ComplexType>()) {
684     p.emitError(loc)
685         << "expected floating-point, integer, or complex element type, got "
686         << elementType;
687     return nullptr;
688   }
689 
690   std::string data;
691   if (parseElementAttrHexValues(p, hexStorage.getValue(), data))
692     return nullptr;
693 
694   ArrayRef<char> rawData(data.data(), data.size());
695   bool detectedSplat = false;
696   if (!DenseElementsAttr::isValidRawBuffer(type, rawData, detectedSplat)) {
697     p.emitError(loc) << "elements hex data size is invalid for provided type: "
698                      << type;
699     return nullptr;
700   }
701 
702   if (llvm::support::endian::system_endianness() ==
703       llvm::support::endianness::big) {
704     // Convert endianess in big-endian(BE) machines. `rawData` is
705     // little-endian(LE) because HEX in raw data of dense element attribute
706     // is always LE format. It is converted into BE here to be used in BE
707     // machines.
708     SmallVector<char, 64> outDataVec(rawData.size());
709     MutableArrayRef<char> convRawData(outDataVec);
710     DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
711         rawData, convRawData, type);
712     return DenseElementsAttr::getFromRawBuffer(type, convRawData,
713                                                detectedSplat);
714   }
715 
716   return DenseElementsAttr::getFromRawBuffer(type, rawData, detectedSplat);
717 }
718 
parseElement()719 ParseResult TensorLiteralParser::parseElement() {
720   switch (p.getToken().getKind()) {
721   // Parse a boolean element.
722   case Token::kw_true:
723   case Token::kw_false:
724   case Token::floatliteral:
725   case Token::integer:
726     storage.emplace_back(/*isNegative=*/false, p.getToken());
727     p.consumeToken();
728     break;
729 
730   // Parse a signed integer or a negative floating-point element.
731   case Token::minus:
732     p.consumeToken(Token::minus);
733     if (!p.getToken().isAny(Token::floatliteral, Token::integer))
734       return p.emitError("expected integer or floating point literal");
735     storage.emplace_back(/*isNegative=*/true, p.getToken());
736     p.consumeToken();
737     break;
738 
739   case Token::string:
740     storage.emplace_back(/*isNegative=*/false, p.getToken());
741     p.consumeToken();
742     break;
743 
744   // Parse a complex element of the form '(' element ',' element ')'.
745   case Token::l_paren:
746     p.consumeToken(Token::l_paren);
747     if (parseElement() ||
748         p.parseToken(Token::comma, "expected ',' between complex elements") ||
749         parseElement() ||
750         p.parseToken(Token::r_paren, "expected ')' after complex elements"))
751       return failure();
752     break;
753 
754   default:
755     return p.emitError("expected element literal of primitive type");
756   }
757 
758   return success();
759 }
760 
761 /// Parse a list of either lists or elements, returning the dimensions of the
762 /// parsed sub-tensors in dims. For example:
763 ///   parseList([1, 2, 3]) -> Success, [3]
764 ///   parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
765 ///   parseList([[1, 2], 3]) -> Failure
766 ///   parseList([[1, [2, 3]], [4, [5]]]) -> Failure
parseList(SmallVectorImpl<int64_t> & dims)767 ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) {
768   auto checkDims = [&](const SmallVectorImpl<int64_t> &prevDims,
769                        const SmallVectorImpl<int64_t> &newDims) -> ParseResult {
770     if (prevDims == newDims)
771       return success();
772     return p.emitError("tensor literal is invalid; ranks are not consistent "
773                        "between elements");
774   };
775 
776   bool first = true;
777   SmallVector<int64_t, 4> newDims;
778   unsigned size = 0;
779   auto parseOneElement = [&]() -> ParseResult {
780     SmallVector<int64_t, 4> thisDims;
781     if (p.getToken().getKind() == Token::l_square) {
782       if (parseList(thisDims))
783         return failure();
784     } else if (parseElement()) {
785       return failure();
786     }
787     ++size;
788     if (!first)
789       return checkDims(newDims, thisDims);
790     newDims = thisDims;
791     first = false;
792     return success();
793   };
794   if (p.parseCommaSeparatedList(Parser::Delimiter::Square, parseOneElement))
795     return failure();
796 
797   // Return the sublists' dimensions with 'size' prepended.
798   dims.clear();
799   dims.push_back(size);
800   dims.append(newDims.begin(), newDims.end());
801   return success();
802 }
803 
804 //===----------------------------------------------------------------------===//
805 // ElementsAttr Parser
806 //===----------------------------------------------------------------------===//
807 
808 /// Parse a dense elements attribute.
parseDenseElementsAttr(Type attrType)809 Attribute Parser::parseDenseElementsAttr(Type attrType) {
810   auto attribLoc = getToken().getLoc();
811   consumeToken(Token::kw_dense);
812   if (parseToken(Token::less, "expected '<' after 'dense'"))
813     return nullptr;
814 
815   // Parse the literal data if necessary.
816   TensorLiteralParser literalParser(*this);
817   if (!consumeIf(Token::greater)) {
818     if (literalParser.parse(/*allowHex=*/true) ||
819         parseToken(Token::greater, "expected '>'"))
820       return nullptr;
821   }
822 
823   // If the type is specified `parseElementsLiteralType` will not parse a type.
824   // Use the attribute location as the location for error reporting in that
825   // case.
826   auto loc = attrType ? attribLoc : getToken().getLoc();
827   auto type = parseElementsLiteralType(attrType);
828   if (!type)
829     return nullptr;
830   return literalParser.getAttr(loc, type);
831 }
832 
833 /// Parse an opaque elements attribute.
parseOpaqueElementsAttr(Type attrType)834 Attribute Parser::parseOpaqueElementsAttr(Type attrType) {
835   consumeToken(Token::kw_opaque);
836   if (parseToken(Token::less, "expected '<' after 'opaque'"))
837     return nullptr;
838 
839   if (getToken().isNot(Token::string))
840     return (emitError("expected dialect namespace"), nullptr);
841 
842   std::string name = getToken().getStringValue();
843   consumeToken(Token::string);
844 
845   if (parseToken(Token::comma, "expected ','"))
846     return nullptr;
847 
848   Token hexTok = getToken();
849   if (parseToken(Token::string, "elements hex string should start with '0x'") ||
850       parseToken(Token::greater, "expected '>'"))
851     return nullptr;
852   auto type = parseElementsLiteralType(attrType);
853   if (!type)
854     return nullptr;
855 
856   std::string data;
857   if (parseElementAttrHexValues(*this, hexTok, data))
858     return nullptr;
859   return OpaqueElementsAttr::get(builder.getIdentifier(name), type, data);
860 }
861 
862 /// Shaped type for elements attribute.
863 ///
864 ///   elements-literal-type ::= vector-type | ranked-tensor-type
865 ///
866 /// This method also checks the type has static shape.
parseElementsLiteralType(Type type)867 ShapedType Parser::parseElementsLiteralType(Type type) {
868   // If the user didn't provide a type, parse the colon type for the literal.
869   if (!type) {
870     if (parseToken(Token::colon, "expected ':'"))
871       return nullptr;
872     if (!(type = parseType()))
873       return nullptr;
874   }
875 
876   if (!type.isa<RankedTensorType, VectorType>()) {
877     emitError("elements literal must be a ranked tensor or vector type");
878     return nullptr;
879   }
880 
881   auto sType = type.cast<ShapedType>();
882   if (!sType.hasStaticShape())
883     return (emitError("elements literal type must have static shape"), nullptr);
884 
885   return sType;
886 }
887 
888 /// Parse a sparse elements attribute.
parseSparseElementsAttr(Type attrType)889 Attribute Parser::parseSparseElementsAttr(Type attrType) {
890   llvm::SMLoc loc = getToken().getLoc();
891   consumeToken(Token::kw_sparse);
892   if (parseToken(Token::less, "Expected '<' after 'sparse'"))
893     return nullptr;
894 
895   // Check for the case where all elements are sparse. The indices are
896   // represented by a 2-dimensional shape where the second dimension is the rank
897   // of the type.
898   Type indiceEltType = builder.getIntegerType(64);
899   if (consumeIf(Token::greater)) {
900     ShapedType type = parseElementsLiteralType(attrType);
901     if (!type)
902       return nullptr;
903 
904     // Construct the sparse elements attr using zero element indice/value
905     // attributes.
906     ShapedType indicesType =
907         RankedTensorType::get({0, type.getRank()}, indiceEltType);
908     ShapedType valuesType = RankedTensorType::get({0}, type.getElementType());
909     return getChecked<SparseElementsAttr>(
910         loc, type, DenseElementsAttr::get(indicesType, ArrayRef<Attribute>()),
911         DenseElementsAttr::get(valuesType, ArrayRef<Attribute>()));
912   }
913 
914   /// Parse the indices. We don't allow hex values here as we may need to use
915   /// the inferred shape.
916   auto indicesLoc = getToken().getLoc();
917   TensorLiteralParser indiceParser(*this);
918   if (indiceParser.parse(/*allowHex=*/false))
919     return nullptr;
920 
921   if (parseToken(Token::comma, "expected ','"))
922     return nullptr;
923 
924   /// Parse the values.
925   auto valuesLoc = getToken().getLoc();
926   TensorLiteralParser valuesParser(*this);
927   if (valuesParser.parse(/*allowHex=*/true))
928     return nullptr;
929 
930   if (parseToken(Token::greater, "expected '>'"))
931     return nullptr;
932 
933   auto type = parseElementsLiteralType(attrType);
934   if (!type)
935     return nullptr;
936 
937   // If the indices are a splat, i.e. the literal parser parsed an element and
938   // not a list, we set the shape explicitly. The indices are represented by a
939   // 2-dimensional shape where the second dimension is the rank of the type.
940   // Given that the parsed indices is a splat, we know that we only have one
941   // indice and thus one for the first dimension.
942   ShapedType indicesType;
943   if (indiceParser.getShape().empty()) {
944     indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType);
945   } else {
946     // Otherwise, set the shape to the one parsed by the literal parser.
947     indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType);
948   }
949   auto indices = indiceParser.getAttr(indicesLoc, indicesType);
950 
951   // If the values are a splat, set the shape explicitly based on the number of
952   // indices. The number of indices is encoded in the first dimension of the
953   // indice shape type.
954   auto valuesEltType = type.getElementType();
955   ShapedType valuesType =
956       valuesParser.getShape().empty()
957           ? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType)
958           : RankedTensorType::get(valuesParser.getShape(), valuesEltType);
959   auto values = valuesParser.getAttr(valuesLoc, valuesType);
960 
961   // Build the sparse elements attribute by the indices and values.
962   return getChecked<SparseElementsAttr>(loc, type, indices, values);
963 }
964