1 //===- AttributeParser.cpp - MLIR Attribute Parser Implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements the parser for the MLIR Attributes.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "Parser.h"
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Dialect.h"
17 #include "mlir/IR/IntegerSet.h"
18 #include "mlir/Parser/AsmParserState.h"
19 #include "llvm/ADT/StringExtras.h"
20 #include "llvm/Support/Endian.h"
21
22 using namespace mlir;
23 using namespace mlir::detail;
24
25 /// Parse an arbitrary attribute.
26 ///
27 /// attribute-value ::= `unit`
28 /// | bool-literal
29 /// | integer-literal (`:` (index-type | integer-type))?
30 /// | float-literal (`:` float-type)?
31 /// | string-literal (`:` type)?
32 /// | type
33 /// | `[` (attribute-value (`,` attribute-value)*)? `]`
34 /// | `{` (attribute-entry (`,` attribute-entry)*)? `}`
35 /// | symbol-ref-id (`::` symbol-ref-id)*
36 /// | `dense` `<` attribute-value `>` `:`
37 /// (tensor-type | vector-type)
38 /// | `sparse` `<` attribute-value `,` attribute-value `>`
39 /// `:` (tensor-type | vector-type)
40 /// | `opaque` `<` dialect-namespace `,` hex-string-literal
41 /// `>` `:` (tensor-type | vector-type)
42 /// | extended-attribute
43 ///
parseAttribute(Type type)44 Attribute Parser::parseAttribute(Type type) {
45 switch (getToken().getKind()) {
46 // Parse an AffineMap or IntegerSet attribute.
47 case Token::kw_affine_map: {
48 consumeToken(Token::kw_affine_map);
49
50 AffineMap map;
51 if (parseToken(Token::less, "expected '<' in affine map") ||
52 parseAffineMapReference(map) ||
53 parseToken(Token::greater, "expected '>' in affine map"))
54 return Attribute();
55 return AffineMapAttr::get(map);
56 }
57 case Token::kw_affine_set: {
58 consumeToken(Token::kw_affine_set);
59
60 IntegerSet set;
61 if (parseToken(Token::less, "expected '<' in integer set") ||
62 parseIntegerSetReference(set) ||
63 parseToken(Token::greater, "expected '>' in integer set"))
64 return Attribute();
65 return IntegerSetAttr::get(set);
66 }
67
68 // Parse an array attribute.
69 case Token::l_square: {
70 SmallVector<Attribute, 4> elements;
71 auto parseElt = [&]() -> ParseResult {
72 elements.push_back(parseAttribute());
73 return elements.back() ? success() : failure();
74 };
75
76 if (parseCommaSeparatedList(Delimiter::Square, parseElt))
77 return nullptr;
78 return builder.getArrayAttr(elements);
79 }
80
81 // Parse a boolean attribute.
82 case Token::kw_false:
83 consumeToken(Token::kw_false);
84 return builder.getBoolAttr(false);
85 case Token::kw_true:
86 consumeToken(Token::kw_true);
87 return builder.getBoolAttr(true);
88
89 // Parse a dense elements attribute.
90 case Token::kw_dense:
91 return parseDenseElementsAttr(type);
92
93 // Parse a dictionary attribute.
94 case Token::l_brace: {
95 NamedAttrList elements;
96 if (parseAttributeDict(elements))
97 return nullptr;
98 return elements.getDictionary(getContext());
99 }
100
101 // Parse an extended attribute, i.e. alias or dialect attribute.
102 case Token::hash_identifier:
103 return parseExtendedAttr(type);
104
105 // Parse floating point and integer attributes.
106 case Token::floatliteral:
107 return parseFloatAttr(type, /*isNegative=*/false);
108 case Token::integer:
109 return parseDecOrHexAttr(type, /*isNegative=*/false);
110 case Token::minus: {
111 consumeToken(Token::minus);
112 if (getToken().is(Token::integer))
113 return parseDecOrHexAttr(type, /*isNegative=*/true);
114 if (getToken().is(Token::floatliteral))
115 return parseFloatAttr(type, /*isNegative=*/true);
116
117 return (emitError("expected constant integer or floating point value"),
118 nullptr);
119 }
120
121 // Parse a location attribute.
122 case Token::kw_loc: {
123 consumeToken(Token::kw_loc);
124
125 LocationAttr locAttr;
126 if (parseToken(Token::l_paren, "expected '(' in inline location") ||
127 parseLocationInstance(locAttr) ||
128 parseToken(Token::r_paren, "expected ')' in inline location"))
129 return Attribute();
130 return locAttr;
131 }
132
133 // Parse an opaque elements attribute.
134 case Token::kw_opaque:
135 return parseOpaqueElementsAttr(type);
136
137 // Parse a sparse elements attribute.
138 case Token::kw_sparse:
139 return parseSparseElementsAttr(type);
140
141 // Parse a string attribute.
142 case Token::string: {
143 auto val = getToken().getStringValue();
144 consumeToken(Token::string);
145 // Parse the optional trailing colon type if one wasn't explicitly provided.
146 if (!type && consumeIf(Token::colon) && !(type = parseType()))
147 return Attribute();
148
149 return type ? StringAttr::get(val, type)
150 : StringAttr::get(getContext(), val);
151 }
152
153 // Parse a symbol reference attribute.
154 case Token::at_identifier: {
155 // When populating the parser state, this is a list of locations for all of
156 // the nested references.
157 SmallVector<llvm::SMRange> referenceLocations;
158 if (state.asmState)
159 referenceLocations.push_back(getToken().getLocRange());
160
161 // Parse the top-level reference.
162 std::string nameStr = getToken().getSymbolReference();
163 consumeToken(Token::at_identifier);
164
165 // Parse any nested references.
166 std::vector<FlatSymbolRefAttr> nestedRefs;
167 while (getToken().is(Token::colon)) {
168 // Check for the '::' prefix.
169 const char *curPointer = getToken().getLoc().getPointer();
170 consumeToken(Token::colon);
171 if (!consumeIf(Token::colon)) {
172 state.lex.resetPointer(curPointer);
173 consumeToken();
174 break;
175 }
176 // Parse the reference itself.
177 auto curLoc = getToken().getLoc();
178 if (getToken().isNot(Token::at_identifier)) {
179 emitError(curLoc, "expected nested symbol reference identifier");
180 return Attribute();
181 }
182
183 // If we are populating the assembly state, add the location for this
184 // reference.
185 if (state.asmState)
186 referenceLocations.push_back(getToken().getLocRange());
187
188 std::string nameStr = getToken().getSymbolReference();
189 consumeToken(Token::at_identifier);
190 nestedRefs.push_back(SymbolRefAttr::get(getContext(), nameStr));
191 }
192 SymbolRefAttr symbolRefAttr =
193 SymbolRefAttr::get(getContext(), nameStr, nestedRefs);
194
195 // If we are populating the assembly state, record this symbol reference.
196 if (state.asmState)
197 state.asmState->addUses(symbolRefAttr, referenceLocations);
198 return symbolRefAttr;
199 }
200
201 // Parse a 'unit' attribute.
202 case Token::kw_unit:
203 consumeToken(Token::kw_unit);
204 return builder.getUnitAttr();
205
206 default:
207 // Parse a type attribute.
208 if (Type type = parseType())
209 return TypeAttr::get(type);
210 return nullptr;
211 }
212 }
213
214 /// Parse an optional attribute with the provided type.
parseOptionalAttribute(Attribute & attribute,Type type)215 OptionalParseResult Parser::parseOptionalAttribute(Attribute &attribute,
216 Type type) {
217 switch (getToken().getKind()) {
218 case Token::at_identifier:
219 case Token::floatliteral:
220 case Token::integer:
221 case Token::hash_identifier:
222 case Token::kw_affine_map:
223 case Token::kw_affine_set:
224 case Token::kw_dense:
225 case Token::kw_false:
226 case Token::kw_loc:
227 case Token::kw_opaque:
228 case Token::kw_sparse:
229 case Token::kw_true:
230 case Token::kw_unit:
231 case Token::l_brace:
232 case Token::l_square:
233 case Token::minus:
234 case Token::string:
235 attribute = parseAttribute(type);
236 return success(attribute != nullptr);
237
238 default:
239 // Parse an optional type attribute.
240 Type type;
241 OptionalParseResult result = parseOptionalType(type);
242 if (result.hasValue() && succeeded(*result))
243 attribute = TypeAttr::get(type);
244 return result;
245 }
246 }
parseOptionalAttribute(ArrayAttr & attribute,Type type)247 OptionalParseResult Parser::parseOptionalAttribute(ArrayAttr &attribute,
248 Type type) {
249 return parseOptionalAttributeWithToken(Token::l_square, attribute, type);
250 }
parseOptionalAttribute(StringAttr & attribute,Type type)251 OptionalParseResult Parser::parseOptionalAttribute(StringAttr &attribute,
252 Type type) {
253 return parseOptionalAttributeWithToken(Token::string, attribute, type);
254 }
255
256 /// Attribute dictionary.
257 ///
258 /// attribute-dict ::= `{` `}`
259 /// | `{` attribute-entry (`,` attribute-entry)* `}`
260 /// attribute-entry ::= (bare-id | string-literal) `=` attribute-value
261 ///
parseAttributeDict(NamedAttrList & attributes)262 ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
263 llvm::SmallDenseSet<Identifier> seenKeys;
264 auto parseElt = [&]() -> ParseResult {
265 // The name of an attribute can either be a bare identifier, or a string.
266 Optional<Identifier> nameId;
267 if (getToken().is(Token::string))
268 nameId = builder.getIdentifier(getToken().getStringValue());
269 else if (getToken().isAny(Token::bare_identifier, Token::inttype) ||
270 getToken().isKeyword())
271 nameId = builder.getIdentifier(getTokenSpelling());
272 else
273 return emitError("expected attribute name");
274 if (!seenKeys.insert(*nameId).second)
275 return emitError("duplicate key '")
276 << *nameId << "' in dictionary attribute";
277 consumeToken();
278
279 // Lazy load a dialect in the context if there is a possible namespace.
280 auto splitName = nameId->strref().split('.');
281 if (!splitName.second.empty())
282 getContext()->getOrLoadDialect(splitName.first);
283
284 // Try to parse the '=' for the attribute value.
285 if (!consumeIf(Token::equal)) {
286 // If there is no '=', we treat this as a unit attribute.
287 attributes.push_back({*nameId, builder.getUnitAttr()});
288 return success();
289 }
290
291 auto attr = parseAttribute();
292 if (!attr)
293 return failure();
294 attributes.push_back({*nameId, attr});
295 return success();
296 };
297
298 if (parseCommaSeparatedList(Delimiter::Braces, parseElt,
299 " in attribute dictionary"))
300 return failure();
301
302 return success();
303 }
304
305 /// Parse a float attribute.
parseFloatAttr(Type type,bool isNegative)306 Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
307 auto val = getToken().getFloatingPointValue();
308 if (!val.hasValue())
309 return (emitError("floating point value too large for attribute"), nullptr);
310 consumeToken(Token::floatliteral);
311 if (!type) {
312 // Default to F64 when no type is specified.
313 if (!consumeIf(Token::colon))
314 type = builder.getF64Type();
315 else if (!(type = parseType()))
316 return nullptr;
317 }
318 if (!type.isa<FloatType>())
319 return (emitError("floating point value not valid for specified type"),
320 nullptr);
321 return FloatAttr::get(type, isNegative ? -val.getValue() : val.getValue());
322 }
323
324 /// Construct an APint from a parsed value, a known attribute type and
325 /// sign.
buildAttributeAPInt(Type type,bool isNegative,StringRef spelling)326 static Optional<APInt> buildAttributeAPInt(Type type, bool isNegative,
327 StringRef spelling) {
328 // Parse the integer value into an APInt that is big enough to hold the value.
329 APInt result;
330 bool isHex = spelling.size() > 1 && spelling[1] == 'x';
331 if (spelling.getAsInteger(isHex ? 0 : 10, result))
332 return llvm::None;
333
334 // Extend or truncate the bitwidth to the right size.
335 unsigned width = type.isIndex() ? IndexType::kInternalStorageBitWidth
336 : type.getIntOrFloatBitWidth();
337
338 // APInt cannot hold a zero bit value.
339 if (width == 0)
340 return llvm::None;
341
342 if (width > result.getBitWidth()) {
343 result = result.zext(width);
344 } else if (width < result.getBitWidth()) {
345 // The parser can return an unnecessarily wide result with leading zeros.
346 // This isn't a problem, but truncating off bits is bad.
347 if (result.countLeadingZeros() < result.getBitWidth() - width)
348 return llvm::None;
349
350 result = result.trunc(width);
351 }
352
353 if (isNegative) {
354 // The value is negative, we have an overflow if the sign bit is not set
355 // in the negated apInt.
356 result.negate();
357 if (!result.isSignBitSet())
358 return llvm::None;
359 } else if ((type.isSignedInteger() || type.isIndex()) &&
360 result.isSignBitSet()) {
361 // The value is a positive signed integer or index,
362 // we have an overflow if the sign bit is set.
363 return llvm::None;
364 }
365
366 return result;
367 }
368
369 /// Parse a decimal or a hexadecimal literal, which can be either an integer
370 /// or a float attribute.
parseDecOrHexAttr(Type type,bool isNegative)371 Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
372 Token tok = getToken();
373 StringRef spelling = tok.getSpelling();
374 llvm::SMLoc loc = tok.getLoc();
375
376 consumeToken(Token::integer);
377 if (!type) {
378 // Default to i64 if not type is specified.
379 if (!consumeIf(Token::colon))
380 type = builder.getIntegerType(64);
381 else if (!(type = parseType()))
382 return nullptr;
383 }
384
385 if (auto floatType = type.dyn_cast<FloatType>()) {
386 Optional<APFloat> result;
387 if (failed(parseFloatFromIntegerLiteral(result, tok, isNegative,
388 floatType.getFloatSemantics(),
389 floatType.getWidth())))
390 return Attribute();
391 return FloatAttr::get(floatType, *result);
392 }
393
394 if (!type.isa<IntegerType, IndexType>())
395 return emitError(loc, "integer literal not valid for specified type"),
396 nullptr;
397
398 if (isNegative && type.isUnsignedInteger()) {
399 emitError(loc,
400 "negative integer literal not valid for unsigned integer type");
401 return nullptr;
402 }
403
404 Optional<APInt> apInt = buildAttributeAPInt(type, isNegative, spelling);
405 if (!apInt)
406 return emitError(loc, "integer constant out of range for attribute"),
407 nullptr;
408 return builder.getIntegerAttr(type, *apInt);
409 }
410
411 //===----------------------------------------------------------------------===//
412 // TensorLiteralParser
413 //===----------------------------------------------------------------------===//
414
415 /// Parse elements values stored within a hex string. On success, the values are
416 /// stored into 'result'.
parseElementAttrHexValues(Parser & parser,Token tok,std::string & result)417 static ParseResult parseElementAttrHexValues(Parser &parser, Token tok,
418 std::string &result) {
419 if (Optional<std::string> value = tok.getHexStringValue()) {
420 result = std::move(*value);
421 return success();
422 }
423 return parser.emitError(
424 tok.getLoc(), "expected string containing hex digits starting with `0x`");
425 }
426
427 namespace {
428 /// This class implements a parser for TensorLiterals. A tensor literal is
429 /// either a single element (e.g, 5) or a multi-dimensional list of elements
430 /// (e.g., [[5, 5]]).
431 class TensorLiteralParser {
432 public:
TensorLiteralParser(Parser & p)433 TensorLiteralParser(Parser &p) : p(p) {}
434
435 /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser
436 /// may also parse a tensor literal that is store as a hex string.
437 ParseResult parse(bool allowHex);
438
439 /// Build a dense attribute instance with the parsed elements and the given
440 /// shaped type.
441 DenseElementsAttr getAttr(llvm::SMLoc loc, ShapedType type);
442
getShape() const443 ArrayRef<int64_t> getShape() const { return shape; }
444
445 private:
446 /// Get the parsed elements for an integer attribute.
447 ParseResult getIntAttrElements(llvm::SMLoc loc, Type eltTy,
448 std::vector<APInt> &intValues);
449
450 /// Get the parsed elements for a float attribute.
451 ParseResult getFloatAttrElements(llvm::SMLoc loc, FloatType eltTy,
452 std::vector<APFloat> &floatValues);
453
454 /// Build a Dense String attribute for the given type.
455 DenseElementsAttr getStringAttr(llvm::SMLoc loc, ShapedType type, Type eltTy);
456
457 /// Build a Dense attribute with hex data for the given type.
458 DenseElementsAttr getHexAttr(llvm::SMLoc loc, ShapedType type);
459
460 /// Parse a single element, returning failure if it isn't a valid element
461 /// literal. For example:
462 /// parseElement(1) -> Success, 1
463 /// parseElement([1]) -> Failure
464 ParseResult parseElement();
465
466 /// Parse a list of either lists or elements, returning the dimensions of the
467 /// parsed sub-tensors in dims. For example:
468 /// parseList([1, 2, 3]) -> Success, [3]
469 /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
470 /// parseList([[1, 2], 3]) -> Failure
471 /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure
472 ParseResult parseList(SmallVectorImpl<int64_t> &dims);
473
474 /// Parse a literal that was printed as a hex string.
475 ParseResult parseHexElements();
476
477 Parser &p;
478
479 /// The shape inferred from the parsed elements.
480 SmallVector<int64_t, 4> shape;
481
482 /// Storage used when parsing elements, this is a pair of <is_negated, token>.
483 std::vector<std::pair<bool, Token>> storage;
484
485 /// Storage used when parsing elements that were stored as hex values.
486 Optional<Token> hexStorage;
487 };
488 } // end anonymous namespace
489
490 /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser
491 /// may also parse a tensor literal that is store as a hex string.
parse(bool allowHex)492 ParseResult TensorLiteralParser::parse(bool allowHex) {
493 // If hex is allowed, check for a string literal.
494 if (allowHex && p.getToken().is(Token::string)) {
495 hexStorage = p.getToken();
496 p.consumeToken(Token::string);
497 return success();
498 }
499 // Otherwise, parse a list or an individual element.
500 if (p.getToken().is(Token::l_square))
501 return parseList(shape);
502 return parseElement();
503 }
504
505 /// Build a dense attribute instance with the parsed elements and the given
506 /// shaped type.
getAttr(llvm::SMLoc loc,ShapedType type)507 DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc,
508 ShapedType type) {
509 Type eltType = type.getElementType();
510
511 // Check to see if we parse the literal from a hex string.
512 if (hexStorage.hasValue() &&
513 (eltType.isIntOrIndexOrFloat() || eltType.isa<ComplexType>()))
514 return getHexAttr(loc, type);
515
516 // Check that the parsed storage size has the same number of elements to the
517 // type, or is a known splat.
518 if (!shape.empty() && getShape() != type.getShape()) {
519 p.emitError(loc) << "inferred shape of elements literal ([" << getShape()
520 << "]) does not match type ([" << type.getShape() << "])";
521 return nullptr;
522 }
523
524 // Handle the case where no elements were parsed.
525 if (!hexStorage.hasValue() && storage.empty() && type.getNumElements()) {
526 p.emitError(loc) << "parsed zero elements, but type (" << type
527 << ") expected at least 1";
528 return nullptr;
529 }
530
531 // Handle complex types in the specific element type cases below.
532 bool isComplex = false;
533 if (ComplexType complexTy = eltType.dyn_cast<ComplexType>()) {
534 eltType = complexTy.getElementType();
535 isComplex = true;
536 }
537
538 // Handle integer and index types.
539 if (eltType.isIntOrIndex()) {
540 std::vector<APInt> intValues;
541 if (failed(getIntAttrElements(loc, eltType, intValues)))
542 return nullptr;
543 if (isComplex) {
544 // If this is a complex, treat the parsed values as complex values.
545 auto complexData = llvm::makeArrayRef(
546 reinterpret_cast<std::complex<APInt> *>(intValues.data()),
547 intValues.size() / 2);
548 return DenseElementsAttr::get(type, complexData);
549 }
550 return DenseElementsAttr::get(type, intValues);
551 }
552 // Handle floating point types.
553 if (FloatType floatTy = eltType.dyn_cast<FloatType>()) {
554 std::vector<APFloat> floatValues;
555 if (failed(getFloatAttrElements(loc, floatTy, floatValues)))
556 return nullptr;
557 if (isComplex) {
558 // If this is a complex, treat the parsed values as complex values.
559 auto complexData = llvm::makeArrayRef(
560 reinterpret_cast<std::complex<APFloat> *>(floatValues.data()),
561 floatValues.size() / 2);
562 return DenseElementsAttr::get(type, complexData);
563 }
564 return DenseElementsAttr::get(type, floatValues);
565 }
566
567 // Other types are assumed to be string representations.
568 return getStringAttr(loc, type, type.getElementType());
569 }
570
571 /// Build a Dense Integer attribute for the given type.
572 ParseResult
getIntAttrElements(llvm::SMLoc loc,Type eltTy,std::vector<APInt> & intValues)573 TensorLiteralParser::getIntAttrElements(llvm::SMLoc loc, Type eltTy,
574 std::vector<APInt> &intValues) {
575 intValues.reserve(storage.size());
576 bool isUintType = eltTy.isUnsignedInteger();
577 for (const auto &signAndToken : storage) {
578 bool isNegative = signAndToken.first;
579 const Token &token = signAndToken.second;
580 auto tokenLoc = token.getLoc();
581
582 if (isNegative && isUintType) {
583 return p.emitError(tokenLoc)
584 << "expected unsigned integer elements, but parsed negative value";
585 }
586
587 // Check to see if floating point values were parsed.
588 if (token.is(Token::floatliteral)) {
589 return p.emitError(tokenLoc)
590 << "expected integer elements, but parsed floating-point";
591 }
592
593 assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) &&
594 "unexpected token type");
595 if (token.isAny(Token::kw_true, Token::kw_false)) {
596 if (!eltTy.isInteger(1)) {
597 return p.emitError(tokenLoc)
598 << "expected i1 type for 'true' or 'false' values";
599 }
600 APInt apInt(1, token.is(Token::kw_true), /*isSigned=*/false);
601 intValues.push_back(apInt);
602 continue;
603 }
604
605 // Create APInt values for each element with the correct bitwidth.
606 Optional<APInt> apInt =
607 buildAttributeAPInt(eltTy, isNegative, token.getSpelling());
608 if (!apInt)
609 return p.emitError(tokenLoc, "integer constant out of range for type");
610 intValues.push_back(*apInt);
611 }
612 return success();
613 }
614
615 /// Build a Dense Float attribute for the given type.
616 ParseResult
getFloatAttrElements(llvm::SMLoc loc,FloatType eltTy,std::vector<APFloat> & floatValues)617 TensorLiteralParser::getFloatAttrElements(llvm::SMLoc loc, FloatType eltTy,
618 std::vector<APFloat> &floatValues) {
619 floatValues.reserve(storage.size());
620 for (const auto &signAndToken : storage) {
621 bool isNegative = signAndToken.first;
622 const Token &token = signAndToken.second;
623
624 // Handle hexadecimal float literals.
625 if (token.is(Token::integer) && token.getSpelling().startswith("0x")) {
626 Optional<APFloat> result;
627 if (failed(p.parseFloatFromIntegerLiteral(result, token, isNegative,
628 eltTy.getFloatSemantics(),
629 eltTy.getWidth())))
630 return failure();
631
632 floatValues.push_back(*result);
633 continue;
634 }
635
636 // Check to see if any decimal integers or booleans were parsed.
637 if (!token.is(Token::floatliteral))
638 return p.emitError()
639 << "expected floating-point elements, but parsed integer";
640
641 // Build the float values from tokens.
642 auto val = token.getFloatingPointValue();
643 if (!val.hasValue())
644 return p.emitError("floating point value too large for attribute");
645
646 APFloat apVal(isNegative ? -*val : *val);
647 if (!eltTy.isF64()) {
648 bool unused;
649 apVal.convert(eltTy.getFloatSemantics(), APFloat::rmNearestTiesToEven,
650 &unused);
651 }
652 floatValues.push_back(apVal);
653 }
654 return success();
655 }
656
657 /// Build a Dense String attribute for the given type.
getStringAttr(llvm::SMLoc loc,ShapedType type,Type eltTy)658 DenseElementsAttr TensorLiteralParser::getStringAttr(llvm::SMLoc loc,
659 ShapedType type,
660 Type eltTy) {
661 if (hexStorage.hasValue()) {
662 auto stringValue = hexStorage.getValue().getStringValue();
663 return DenseStringElementsAttr::get(type, {stringValue});
664 }
665
666 std::vector<std::string> stringValues;
667 std::vector<StringRef> stringRefValues;
668 stringValues.reserve(storage.size());
669 stringRefValues.reserve(storage.size());
670
671 for (auto val : storage) {
672 stringValues.push_back(val.second.getStringValue());
673 stringRefValues.push_back(stringValues.back());
674 }
675
676 return DenseStringElementsAttr::get(type, stringRefValues);
677 }
678
679 /// Build a Dense attribute with hex data for the given type.
getHexAttr(llvm::SMLoc loc,ShapedType type)680 DenseElementsAttr TensorLiteralParser::getHexAttr(llvm::SMLoc loc,
681 ShapedType type) {
682 Type elementType = type.getElementType();
683 if (!elementType.isIntOrIndexOrFloat() && !elementType.isa<ComplexType>()) {
684 p.emitError(loc)
685 << "expected floating-point, integer, or complex element type, got "
686 << elementType;
687 return nullptr;
688 }
689
690 std::string data;
691 if (parseElementAttrHexValues(p, hexStorage.getValue(), data))
692 return nullptr;
693
694 ArrayRef<char> rawData(data.data(), data.size());
695 bool detectedSplat = false;
696 if (!DenseElementsAttr::isValidRawBuffer(type, rawData, detectedSplat)) {
697 p.emitError(loc) << "elements hex data size is invalid for provided type: "
698 << type;
699 return nullptr;
700 }
701
702 if (llvm::support::endian::system_endianness() ==
703 llvm::support::endianness::big) {
704 // Convert endianess in big-endian(BE) machines. `rawData` is
705 // little-endian(LE) because HEX in raw data of dense element attribute
706 // is always LE format. It is converted into BE here to be used in BE
707 // machines.
708 SmallVector<char, 64> outDataVec(rawData.size());
709 MutableArrayRef<char> convRawData(outDataVec);
710 DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
711 rawData, convRawData, type);
712 return DenseElementsAttr::getFromRawBuffer(type, convRawData,
713 detectedSplat);
714 }
715
716 return DenseElementsAttr::getFromRawBuffer(type, rawData, detectedSplat);
717 }
718
parseElement()719 ParseResult TensorLiteralParser::parseElement() {
720 switch (p.getToken().getKind()) {
721 // Parse a boolean element.
722 case Token::kw_true:
723 case Token::kw_false:
724 case Token::floatliteral:
725 case Token::integer:
726 storage.emplace_back(/*isNegative=*/false, p.getToken());
727 p.consumeToken();
728 break;
729
730 // Parse a signed integer or a negative floating-point element.
731 case Token::minus:
732 p.consumeToken(Token::minus);
733 if (!p.getToken().isAny(Token::floatliteral, Token::integer))
734 return p.emitError("expected integer or floating point literal");
735 storage.emplace_back(/*isNegative=*/true, p.getToken());
736 p.consumeToken();
737 break;
738
739 case Token::string:
740 storage.emplace_back(/*isNegative=*/false, p.getToken());
741 p.consumeToken();
742 break;
743
744 // Parse a complex element of the form '(' element ',' element ')'.
745 case Token::l_paren:
746 p.consumeToken(Token::l_paren);
747 if (parseElement() ||
748 p.parseToken(Token::comma, "expected ',' between complex elements") ||
749 parseElement() ||
750 p.parseToken(Token::r_paren, "expected ')' after complex elements"))
751 return failure();
752 break;
753
754 default:
755 return p.emitError("expected element literal of primitive type");
756 }
757
758 return success();
759 }
760
761 /// Parse a list of either lists or elements, returning the dimensions of the
762 /// parsed sub-tensors in dims. For example:
763 /// parseList([1, 2, 3]) -> Success, [3]
764 /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
765 /// parseList([[1, 2], 3]) -> Failure
766 /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure
parseList(SmallVectorImpl<int64_t> & dims)767 ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) {
768 auto checkDims = [&](const SmallVectorImpl<int64_t> &prevDims,
769 const SmallVectorImpl<int64_t> &newDims) -> ParseResult {
770 if (prevDims == newDims)
771 return success();
772 return p.emitError("tensor literal is invalid; ranks are not consistent "
773 "between elements");
774 };
775
776 bool first = true;
777 SmallVector<int64_t, 4> newDims;
778 unsigned size = 0;
779 auto parseOneElement = [&]() -> ParseResult {
780 SmallVector<int64_t, 4> thisDims;
781 if (p.getToken().getKind() == Token::l_square) {
782 if (parseList(thisDims))
783 return failure();
784 } else if (parseElement()) {
785 return failure();
786 }
787 ++size;
788 if (!first)
789 return checkDims(newDims, thisDims);
790 newDims = thisDims;
791 first = false;
792 return success();
793 };
794 if (p.parseCommaSeparatedList(Parser::Delimiter::Square, parseOneElement))
795 return failure();
796
797 // Return the sublists' dimensions with 'size' prepended.
798 dims.clear();
799 dims.push_back(size);
800 dims.append(newDims.begin(), newDims.end());
801 return success();
802 }
803
804 //===----------------------------------------------------------------------===//
805 // ElementsAttr Parser
806 //===----------------------------------------------------------------------===//
807
808 /// Parse a dense elements attribute.
parseDenseElementsAttr(Type attrType)809 Attribute Parser::parseDenseElementsAttr(Type attrType) {
810 auto attribLoc = getToken().getLoc();
811 consumeToken(Token::kw_dense);
812 if (parseToken(Token::less, "expected '<' after 'dense'"))
813 return nullptr;
814
815 // Parse the literal data if necessary.
816 TensorLiteralParser literalParser(*this);
817 if (!consumeIf(Token::greater)) {
818 if (literalParser.parse(/*allowHex=*/true) ||
819 parseToken(Token::greater, "expected '>'"))
820 return nullptr;
821 }
822
823 // If the type is specified `parseElementsLiteralType` will not parse a type.
824 // Use the attribute location as the location for error reporting in that
825 // case.
826 auto loc = attrType ? attribLoc : getToken().getLoc();
827 auto type = parseElementsLiteralType(attrType);
828 if (!type)
829 return nullptr;
830 return literalParser.getAttr(loc, type);
831 }
832
833 /// Parse an opaque elements attribute.
parseOpaqueElementsAttr(Type attrType)834 Attribute Parser::parseOpaqueElementsAttr(Type attrType) {
835 consumeToken(Token::kw_opaque);
836 if (parseToken(Token::less, "expected '<' after 'opaque'"))
837 return nullptr;
838
839 if (getToken().isNot(Token::string))
840 return (emitError("expected dialect namespace"), nullptr);
841
842 std::string name = getToken().getStringValue();
843 consumeToken(Token::string);
844
845 if (parseToken(Token::comma, "expected ','"))
846 return nullptr;
847
848 Token hexTok = getToken();
849 if (parseToken(Token::string, "elements hex string should start with '0x'") ||
850 parseToken(Token::greater, "expected '>'"))
851 return nullptr;
852 auto type = parseElementsLiteralType(attrType);
853 if (!type)
854 return nullptr;
855
856 std::string data;
857 if (parseElementAttrHexValues(*this, hexTok, data))
858 return nullptr;
859 return OpaqueElementsAttr::get(builder.getIdentifier(name), type, data);
860 }
861
862 /// Shaped type for elements attribute.
863 ///
864 /// elements-literal-type ::= vector-type | ranked-tensor-type
865 ///
866 /// This method also checks the type has static shape.
parseElementsLiteralType(Type type)867 ShapedType Parser::parseElementsLiteralType(Type type) {
868 // If the user didn't provide a type, parse the colon type for the literal.
869 if (!type) {
870 if (parseToken(Token::colon, "expected ':'"))
871 return nullptr;
872 if (!(type = parseType()))
873 return nullptr;
874 }
875
876 if (!type.isa<RankedTensorType, VectorType>()) {
877 emitError("elements literal must be a ranked tensor or vector type");
878 return nullptr;
879 }
880
881 auto sType = type.cast<ShapedType>();
882 if (!sType.hasStaticShape())
883 return (emitError("elements literal type must have static shape"), nullptr);
884
885 return sType;
886 }
887
888 /// Parse a sparse elements attribute.
parseSparseElementsAttr(Type attrType)889 Attribute Parser::parseSparseElementsAttr(Type attrType) {
890 llvm::SMLoc loc = getToken().getLoc();
891 consumeToken(Token::kw_sparse);
892 if (parseToken(Token::less, "Expected '<' after 'sparse'"))
893 return nullptr;
894
895 // Check for the case where all elements are sparse. The indices are
896 // represented by a 2-dimensional shape where the second dimension is the rank
897 // of the type.
898 Type indiceEltType = builder.getIntegerType(64);
899 if (consumeIf(Token::greater)) {
900 ShapedType type = parseElementsLiteralType(attrType);
901 if (!type)
902 return nullptr;
903
904 // Construct the sparse elements attr using zero element indice/value
905 // attributes.
906 ShapedType indicesType =
907 RankedTensorType::get({0, type.getRank()}, indiceEltType);
908 ShapedType valuesType = RankedTensorType::get({0}, type.getElementType());
909 return getChecked<SparseElementsAttr>(
910 loc, type, DenseElementsAttr::get(indicesType, ArrayRef<Attribute>()),
911 DenseElementsAttr::get(valuesType, ArrayRef<Attribute>()));
912 }
913
914 /// Parse the indices. We don't allow hex values here as we may need to use
915 /// the inferred shape.
916 auto indicesLoc = getToken().getLoc();
917 TensorLiteralParser indiceParser(*this);
918 if (indiceParser.parse(/*allowHex=*/false))
919 return nullptr;
920
921 if (parseToken(Token::comma, "expected ','"))
922 return nullptr;
923
924 /// Parse the values.
925 auto valuesLoc = getToken().getLoc();
926 TensorLiteralParser valuesParser(*this);
927 if (valuesParser.parse(/*allowHex=*/true))
928 return nullptr;
929
930 if (parseToken(Token::greater, "expected '>'"))
931 return nullptr;
932
933 auto type = parseElementsLiteralType(attrType);
934 if (!type)
935 return nullptr;
936
937 // If the indices are a splat, i.e. the literal parser parsed an element and
938 // not a list, we set the shape explicitly. The indices are represented by a
939 // 2-dimensional shape where the second dimension is the rank of the type.
940 // Given that the parsed indices is a splat, we know that we only have one
941 // indice and thus one for the first dimension.
942 ShapedType indicesType;
943 if (indiceParser.getShape().empty()) {
944 indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType);
945 } else {
946 // Otherwise, set the shape to the one parsed by the literal parser.
947 indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType);
948 }
949 auto indices = indiceParser.getAttr(indicesLoc, indicesType);
950
951 // If the values are a splat, set the shape explicitly based on the number of
952 // indices. The number of indices is encoded in the first dimension of the
953 // indice shape type.
954 auto valuesEltType = type.getElementType();
955 ShapedType valuesType =
956 valuesParser.getShape().empty()
957 ? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType)
958 : RankedTensorType::get(valuesParser.getShape(), valuesEltType);
959 auto values = valuesParser.getAttr(valuesLoc, valuesType);
960
961 // Build the sparse elements attribute by the indices and values.
962 return getChecked<SparseElementsAttr>(loc, type, indices, values);
963 }
964