1 //===- AttributeParser.cpp - MLIR Attribute Parser Implementation ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the parser for the MLIR Types. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "Parser.h" 14 #include "mlir/IR/AffineMap.h" 15 #include "mlir/IR/BuiltinTypes.h" 16 #include "mlir/IR/Dialect.h" 17 #include "mlir/IR/IntegerSet.h" 18 #include "mlir/Parser/AsmParserState.h" 19 #include "llvm/ADT/StringExtras.h" 20 #include "llvm/Support/Endian.h" 21 22 using namespace mlir; 23 using namespace mlir::detail; 24 25 /// Parse an arbitrary attribute. 26 /// 27 /// attribute-value ::= `unit` 28 /// | bool-literal 29 /// | integer-literal (`:` (index-type | integer-type))? 30 /// | float-literal (`:` float-type)? 31 /// | string-literal (`:` type)? 32 /// | type 33 /// | `[` (attribute-value (`,` attribute-value)*)? `]` 34 /// | `{` (attribute-entry (`,` attribute-entry)*)? `}` 35 /// | symbol-ref-id (`::` symbol-ref-id)* 36 /// | `dense` `<` attribute-value `>` `:` 37 /// (tensor-type | vector-type) 38 /// | `sparse` `<` attribute-value `,` attribute-value `>` 39 /// `:` (tensor-type | vector-type) 40 /// | `opaque` `<` dialect-namespace `,` hex-string-literal 41 /// `>` `:` (tensor-type | vector-type) 42 /// | extended-attribute 43 /// 44 Attribute Parser::parseAttribute(Type type) { 45 switch (getToken().getKind()) { 46 // Parse an AffineMap or IntegerSet attribute. 47 case Token::kw_affine_map: { 48 consumeToken(Token::kw_affine_map); 49 50 AffineMap map; 51 if (parseToken(Token::less, "expected '<' in affine map") || 52 parseAffineMapReference(map) || 53 parseToken(Token::greater, "expected '>' in affine map")) 54 return Attribute(); 55 return AffineMapAttr::get(map); 56 } 57 case Token::kw_affine_set: { 58 consumeToken(Token::kw_affine_set); 59 60 IntegerSet set; 61 if (parseToken(Token::less, "expected '<' in integer set") || 62 parseIntegerSetReference(set) || 63 parseToken(Token::greater, "expected '>' in integer set")) 64 return Attribute(); 65 return IntegerSetAttr::get(set); 66 } 67 68 // Parse an array attribute. 69 case Token::l_square: { 70 consumeToken(Token::l_square); 71 72 SmallVector<Attribute, 4> elements; 73 auto parseElt = [&]() -> ParseResult { 74 elements.push_back(parseAttribute()); 75 return elements.back() ? success() : failure(); 76 }; 77 78 if (parseCommaSeparatedListUntil(Token::r_square, parseElt)) 79 return nullptr; 80 return builder.getArrayAttr(elements); 81 } 82 83 // Parse a boolean attribute. 84 case Token::kw_false: 85 consumeToken(Token::kw_false); 86 return builder.getBoolAttr(false); 87 case Token::kw_true: 88 consumeToken(Token::kw_true); 89 return builder.getBoolAttr(true); 90 91 // Parse a dense elements attribute. 92 case Token::kw_dense: 93 return parseDenseElementsAttr(type); 94 95 // Parse a dictionary attribute. 96 case Token::l_brace: { 97 NamedAttrList elements; 98 if (parseAttributeDict(elements)) 99 return nullptr; 100 return elements.getDictionary(getContext()); 101 } 102 103 // Parse an extended attribute, i.e. alias or dialect attribute. 104 case Token::hash_identifier: 105 return parseExtendedAttr(type); 106 107 // Parse floating point and integer attributes. 108 case Token::floatliteral: 109 return parseFloatAttr(type, /*isNegative=*/false); 110 case Token::integer: 111 return parseDecOrHexAttr(type, /*isNegative=*/false); 112 case Token::minus: { 113 consumeToken(Token::minus); 114 if (getToken().is(Token::integer)) 115 return parseDecOrHexAttr(type, /*isNegative=*/true); 116 if (getToken().is(Token::floatliteral)) 117 return parseFloatAttr(type, /*isNegative=*/true); 118 119 return (emitError("expected constant integer or floating point value"), 120 nullptr); 121 } 122 123 // Parse a location attribute. 124 case Token::kw_loc: { 125 consumeToken(Token::kw_loc); 126 127 LocationAttr locAttr; 128 if (parseToken(Token::l_paren, "expected '(' in inline location") || 129 parseLocationInstance(locAttr) || 130 parseToken(Token::r_paren, "expected ')' in inline location")) 131 return Attribute(); 132 return locAttr; 133 } 134 135 // Parse an opaque elements attribute. 136 case Token::kw_opaque: 137 return parseOpaqueElementsAttr(type); 138 139 // Parse a sparse elements attribute. 140 case Token::kw_sparse: 141 return parseSparseElementsAttr(type); 142 143 // Parse a string attribute. 144 case Token::string: { 145 auto val = getToken().getStringValue(); 146 consumeToken(Token::string); 147 // Parse the optional trailing colon type if one wasn't explicitly provided. 148 if (!type && consumeIf(Token::colon) && !(type = parseType())) 149 return Attribute(); 150 151 return type ? StringAttr::get(val, type) 152 : StringAttr::get(getContext(), val); 153 } 154 155 // Parse a symbol reference attribute. 156 case Token::at_identifier: { 157 // When populating the parser state, this is a list of locations for all of 158 // the nested references. 159 SmallVector<llvm::SMRange> referenceLocations; 160 if (state.asmState) 161 referenceLocations.push_back(getToken().getLocRange()); 162 163 // Parse the top-level reference. 164 std::string nameStr = getToken().getSymbolReference(); 165 consumeToken(Token::at_identifier); 166 167 // Parse any nested references. 168 std::vector<FlatSymbolRefAttr> nestedRefs; 169 while (getToken().is(Token::colon)) { 170 // Check for the '::' prefix. 171 const char *curPointer = getToken().getLoc().getPointer(); 172 consumeToken(Token::colon); 173 if (!consumeIf(Token::colon)) { 174 state.lex.resetPointer(curPointer); 175 consumeToken(); 176 break; 177 } 178 // Parse the reference itself. 179 auto curLoc = getToken().getLoc(); 180 if (getToken().isNot(Token::at_identifier)) { 181 emitError(curLoc, "expected nested symbol reference identifier"); 182 return Attribute(); 183 } 184 185 // If we are populating the assembly state, add the location for this 186 // reference. 187 if (state.asmState) 188 referenceLocations.push_back(getToken().getLocRange()); 189 190 std::string nameStr = getToken().getSymbolReference(); 191 consumeToken(Token::at_identifier); 192 nestedRefs.push_back(SymbolRefAttr::get(getContext(), nameStr)); 193 } 194 SymbolRefAttr symbolRefAttr = builder.getSymbolRefAttr(nameStr, nestedRefs); 195 196 // If we are populating the assembly state, record this symbol reference. 197 if (state.asmState) 198 state.asmState->addUses(symbolRefAttr, referenceLocations); 199 return symbolRefAttr; 200 } 201 202 // Parse a 'unit' attribute. 203 case Token::kw_unit: 204 consumeToken(Token::kw_unit); 205 return builder.getUnitAttr(); 206 207 default: 208 // Parse a type attribute. 209 if (Type type = parseType()) 210 return TypeAttr::get(type); 211 return nullptr; 212 } 213 } 214 215 /// Parse an optional attribute with the provided type. 216 OptionalParseResult Parser::parseOptionalAttribute(Attribute &attribute, 217 Type type) { 218 switch (getToken().getKind()) { 219 case Token::at_identifier: 220 case Token::floatliteral: 221 case Token::integer: 222 case Token::hash_identifier: 223 case Token::kw_affine_map: 224 case Token::kw_affine_set: 225 case Token::kw_dense: 226 case Token::kw_false: 227 case Token::kw_loc: 228 case Token::kw_opaque: 229 case Token::kw_sparse: 230 case Token::kw_true: 231 case Token::kw_unit: 232 case Token::l_brace: 233 case Token::l_square: 234 case Token::minus: 235 case Token::string: 236 attribute = parseAttribute(type); 237 return success(attribute != nullptr); 238 239 default: 240 // Parse an optional type attribute. 241 Type type; 242 OptionalParseResult result = parseOptionalType(type); 243 if (result.hasValue() && succeeded(*result)) 244 attribute = TypeAttr::get(type); 245 return result; 246 } 247 } 248 OptionalParseResult Parser::parseOptionalAttribute(ArrayAttr &attribute, 249 Type type) { 250 return parseOptionalAttributeWithToken(Token::l_square, attribute, type); 251 } 252 OptionalParseResult Parser::parseOptionalAttribute(StringAttr &attribute, 253 Type type) { 254 return parseOptionalAttributeWithToken(Token::string, attribute, type); 255 } 256 257 /// Attribute dictionary. 258 /// 259 /// attribute-dict ::= `{` `}` 260 /// | `{` attribute-entry (`,` attribute-entry)* `}` 261 /// attribute-entry ::= (bare-id | string-literal) `=` attribute-value 262 /// 263 ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) { 264 if (parseToken(Token::l_brace, "expected '{' in attribute dictionary")) 265 return failure(); 266 267 llvm::SmallDenseSet<Identifier> seenKeys; 268 auto parseElt = [&]() -> ParseResult { 269 // The name of an attribute can either be a bare identifier, or a string. 270 Optional<Identifier> nameId; 271 if (getToken().is(Token::string)) 272 nameId = builder.getIdentifier(getToken().getStringValue()); 273 else if (getToken().isAny(Token::bare_identifier, Token::inttype) || 274 getToken().isKeyword()) 275 nameId = builder.getIdentifier(getTokenSpelling()); 276 else 277 return emitError("expected attribute name"); 278 if (!seenKeys.insert(*nameId).second) 279 return emitError("duplicate key '") 280 << *nameId << "' in dictionary attribute"; 281 consumeToken(); 282 283 // Lazy load a dialect in the context if there is a possible namespace. 284 auto splitName = nameId->strref().split('.'); 285 if (!splitName.second.empty()) 286 getContext()->getOrLoadDialect(splitName.first); 287 288 // Try to parse the '=' for the attribute value. 289 if (!consumeIf(Token::equal)) { 290 // If there is no '=', we treat this as a unit attribute. 291 attributes.push_back({*nameId, builder.getUnitAttr()}); 292 return success(); 293 } 294 295 auto attr = parseAttribute(); 296 if (!attr) 297 return failure(); 298 attributes.push_back({*nameId, attr}); 299 return success(); 300 }; 301 302 if (parseCommaSeparatedListUntil(Token::r_brace, parseElt)) 303 return failure(); 304 305 return success(); 306 } 307 308 /// Parse a float attribute. 309 Attribute Parser::parseFloatAttr(Type type, bool isNegative) { 310 auto val = getToken().getFloatingPointValue(); 311 if (!val.hasValue()) 312 return (emitError("floating point value too large for attribute"), nullptr); 313 consumeToken(Token::floatliteral); 314 if (!type) { 315 // Default to F64 when no type is specified. 316 if (!consumeIf(Token::colon)) 317 type = builder.getF64Type(); 318 else if (!(type = parseType())) 319 return nullptr; 320 } 321 if (!type.isa<FloatType>()) 322 return (emitError("floating point value not valid for specified type"), 323 nullptr); 324 return FloatAttr::get(type, isNegative ? -val.getValue() : val.getValue()); 325 } 326 327 /// Construct an APint from a parsed value, a known attribute type and 328 /// sign. 329 static Optional<APInt> buildAttributeAPInt(Type type, bool isNegative, 330 StringRef spelling) { 331 // Parse the integer value into an APInt that is big enough to hold the value. 332 APInt result; 333 bool isHex = spelling.size() > 1 && spelling[1] == 'x'; 334 if (spelling.getAsInteger(isHex ? 0 : 10, result)) 335 return llvm::None; 336 337 // Extend or truncate the bitwidth to the right size. 338 unsigned width = type.isIndex() ? IndexType::kInternalStorageBitWidth 339 : type.getIntOrFloatBitWidth(); 340 341 // APInt cannot hold a zero bit value. 342 if (width == 0) 343 return llvm::None; 344 345 if (width > result.getBitWidth()) { 346 result = result.zext(width); 347 } else if (width < result.getBitWidth()) { 348 // The parser can return an unnecessarily wide result with leading zeros. 349 // This isn't a problem, but truncating off bits is bad. 350 if (result.countLeadingZeros() < result.getBitWidth() - width) 351 return llvm::None; 352 353 result = result.trunc(width); 354 } 355 356 if (isNegative) { 357 // The value is negative, we have an overflow if the sign bit is not set 358 // in the negated apInt. 359 result.negate(); 360 if (!result.isSignBitSet()) 361 return llvm::None; 362 } else if ((type.isSignedInteger() || type.isIndex()) && 363 result.isSignBitSet()) { 364 // The value is a positive signed integer or index, 365 // we have an overflow if the sign bit is set. 366 return llvm::None; 367 } 368 369 return result; 370 } 371 372 /// Parse a decimal or a hexadecimal literal, which can be either an integer 373 /// or a float attribute. 374 Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) { 375 Token tok = getToken(); 376 StringRef spelling = tok.getSpelling(); 377 llvm::SMLoc loc = tok.getLoc(); 378 379 consumeToken(Token::integer); 380 if (!type) { 381 // Default to i64 if not type is specified. 382 if (!consumeIf(Token::colon)) 383 type = builder.getIntegerType(64); 384 else if (!(type = parseType())) 385 return nullptr; 386 } 387 388 if (auto floatType = type.dyn_cast<FloatType>()) { 389 Optional<APFloat> result; 390 if (failed(parseFloatFromIntegerLiteral(result, tok, isNegative, 391 floatType.getFloatSemantics(), 392 floatType.getWidth()))) 393 return Attribute(); 394 return FloatAttr::get(floatType, *result); 395 } 396 397 if (!type.isa<IntegerType, IndexType>()) 398 return emitError(loc, "integer literal not valid for specified type"), 399 nullptr; 400 401 if (isNegative && type.isUnsignedInteger()) { 402 emitError(loc, 403 "negative integer literal not valid for unsigned integer type"); 404 return nullptr; 405 } 406 407 Optional<APInt> apInt = buildAttributeAPInt(type, isNegative, spelling); 408 if (!apInt) 409 return emitError(loc, "integer constant out of range for attribute"), 410 nullptr; 411 return builder.getIntegerAttr(type, *apInt); 412 } 413 414 //===----------------------------------------------------------------------===// 415 // TensorLiteralParser 416 //===----------------------------------------------------------------------===// 417 418 /// Parse elements values stored within a hex string. On success, the values are 419 /// stored into 'result'. 420 static ParseResult parseElementAttrHexValues(Parser &parser, Token tok, 421 std::string &result) { 422 if (Optional<std::string> value = tok.getHexStringValue()) { 423 result = std::move(*value); 424 return success(); 425 } 426 return parser.emitError( 427 tok.getLoc(), "expected string containing hex digits starting with `0x`"); 428 } 429 430 namespace { 431 /// This class implements a parser for TensorLiterals. A tensor literal is 432 /// either a single element (e.g, 5) or a multi-dimensional list of elements 433 /// (e.g., [[5, 5]]). 434 class TensorLiteralParser { 435 public: 436 TensorLiteralParser(Parser &p) : p(p) {} 437 438 /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser 439 /// may also parse a tensor literal that is store as a hex string. 440 ParseResult parse(bool allowHex); 441 442 /// Build a dense attribute instance with the parsed elements and the given 443 /// shaped type. 444 DenseElementsAttr getAttr(llvm::SMLoc loc, ShapedType type); 445 446 ArrayRef<int64_t> getShape() const { return shape; } 447 448 private: 449 /// Get the parsed elements for an integer attribute. 450 ParseResult getIntAttrElements(llvm::SMLoc loc, Type eltTy, 451 std::vector<APInt> &intValues); 452 453 /// Get the parsed elements for a float attribute. 454 ParseResult getFloatAttrElements(llvm::SMLoc loc, FloatType eltTy, 455 std::vector<APFloat> &floatValues); 456 457 /// Build a Dense String attribute for the given type. 458 DenseElementsAttr getStringAttr(llvm::SMLoc loc, ShapedType type, Type eltTy); 459 460 /// Build a Dense attribute with hex data for the given type. 461 DenseElementsAttr getHexAttr(llvm::SMLoc loc, ShapedType type); 462 463 /// Parse a single element, returning failure if it isn't a valid element 464 /// literal. For example: 465 /// parseElement(1) -> Success, 1 466 /// parseElement([1]) -> Failure 467 ParseResult parseElement(); 468 469 /// Parse a list of either lists or elements, returning the dimensions of the 470 /// parsed sub-tensors in dims. For example: 471 /// parseList([1, 2, 3]) -> Success, [3] 472 /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] 473 /// parseList([[1, 2], 3]) -> Failure 474 /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure 475 ParseResult parseList(SmallVectorImpl<int64_t> &dims); 476 477 /// Parse a literal that was printed as a hex string. 478 ParseResult parseHexElements(); 479 480 Parser &p; 481 482 /// The shape inferred from the parsed elements. 483 SmallVector<int64_t, 4> shape; 484 485 /// Storage used when parsing elements, this is a pair of <is_negated, token>. 486 std::vector<std::pair<bool, Token>> storage; 487 488 /// Storage used when parsing elements that were stored as hex values. 489 Optional<Token> hexStorage; 490 }; 491 } // end anonymous namespace 492 493 /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser 494 /// may also parse a tensor literal that is store as a hex string. 495 ParseResult TensorLiteralParser::parse(bool allowHex) { 496 // If hex is allowed, check for a string literal. 497 if (allowHex && p.getToken().is(Token::string)) { 498 hexStorage = p.getToken(); 499 p.consumeToken(Token::string); 500 return success(); 501 } 502 // Otherwise, parse a list or an individual element. 503 if (p.getToken().is(Token::l_square)) 504 return parseList(shape); 505 return parseElement(); 506 } 507 508 /// Build a dense attribute instance with the parsed elements and the given 509 /// shaped type. 510 DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc, 511 ShapedType type) { 512 Type eltType = type.getElementType(); 513 514 // Check to see if we parse the literal from a hex string. 515 if (hexStorage.hasValue() && 516 (eltType.isIntOrFloat() || eltType.isa<ComplexType>())) 517 return getHexAttr(loc, type); 518 519 // Check that the parsed storage size has the same number of elements to the 520 // type, or is a known splat. 521 if (!shape.empty() && getShape() != type.getShape()) { 522 p.emitError(loc) << "inferred shape of elements literal ([" << getShape() 523 << "]) does not match type ([" << type.getShape() << "])"; 524 return nullptr; 525 } 526 527 // Handle the case where no elements were parsed. 528 if (!hexStorage.hasValue() && storage.empty() && type.getNumElements()) { 529 p.emitError(loc) << "parsed zero elements, but type (" << type 530 << ") expected at least 1"; 531 return nullptr; 532 } 533 534 // Handle complex types in the specific element type cases below. 535 bool isComplex = false; 536 if (ComplexType complexTy = eltType.dyn_cast<ComplexType>()) { 537 eltType = complexTy.getElementType(); 538 isComplex = true; 539 } 540 541 // Handle integer and index types. 542 if (eltType.isIntOrIndex()) { 543 std::vector<APInt> intValues; 544 if (failed(getIntAttrElements(loc, eltType, intValues))) 545 return nullptr; 546 if (isComplex) { 547 // If this is a complex, treat the parsed values as complex values. 548 auto complexData = llvm::makeArrayRef( 549 reinterpret_cast<std::complex<APInt> *>(intValues.data()), 550 intValues.size() / 2); 551 return DenseElementsAttr::get(type, complexData); 552 } 553 return DenseElementsAttr::get(type, intValues); 554 } 555 // Handle floating point types. 556 if (FloatType floatTy = eltType.dyn_cast<FloatType>()) { 557 std::vector<APFloat> floatValues; 558 if (failed(getFloatAttrElements(loc, floatTy, floatValues))) 559 return nullptr; 560 if (isComplex) { 561 // If this is a complex, treat the parsed values as complex values. 562 auto complexData = llvm::makeArrayRef( 563 reinterpret_cast<std::complex<APFloat> *>(floatValues.data()), 564 floatValues.size() / 2); 565 return DenseElementsAttr::get(type, complexData); 566 } 567 return DenseElementsAttr::get(type, floatValues); 568 } 569 570 // Other types are assumed to be string representations. 571 return getStringAttr(loc, type, type.getElementType()); 572 } 573 574 /// Build a Dense Integer attribute for the given type. 575 ParseResult 576 TensorLiteralParser::getIntAttrElements(llvm::SMLoc loc, Type eltTy, 577 std::vector<APInt> &intValues) { 578 intValues.reserve(storage.size()); 579 bool isUintType = eltTy.isUnsignedInteger(); 580 for (const auto &signAndToken : storage) { 581 bool isNegative = signAndToken.first; 582 const Token &token = signAndToken.second; 583 auto tokenLoc = token.getLoc(); 584 585 if (isNegative && isUintType) { 586 return p.emitError(tokenLoc) 587 << "expected unsigned integer elements, but parsed negative value"; 588 } 589 590 // Check to see if floating point values were parsed. 591 if (token.is(Token::floatliteral)) { 592 return p.emitError(tokenLoc) 593 << "expected integer elements, but parsed floating-point"; 594 } 595 596 assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) && 597 "unexpected token type"); 598 if (token.isAny(Token::kw_true, Token::kw_false)) { 599 if (!eltTy.isInteger(1)) { 600 return p.emitError(tokenLoc) 601 << "expected i1 type for 'true' or 'false' values"; 602 } 603 APInt apInt(1, token.is(Token::kw_true), /*isSigned=*/false); 604 intValues.push_back(apInt); 605 continue; 606 } 607 608 // Create APInt values for each element with the correct bitwidth. 609 Optional<APInt> apInt = 610 buildAttributeAPInt(eltTy, isNegative, token.getSpelling()); 611 if (!apInt) 612 return p.emitError(tokenLoc, "integer constant out of range for type"); 613 intValues.push_back(*apInt); 614 } 615 return success(); 616 } 617 618 /// Build a Dense Float attribute for the given type. 619 ParseResult 620 TensorLiteralParser::getFloatAttrElements(llvm::SMLoc loc, FloatType eltTy, 621 std::vector<APFloat> &floatValues) { 622 floatValues.reserve(storage.size()); 623 for (const auto &signAndToken : storage) { 624 bool isNegative = signAndToken.first; 625 const Token &token = signAndToken.second; 626 627 // Handle hexadecimal float literals. 628 if (token.is(Token::integer) && token.getSpelling().startswith("0x")) { 629 Optional<APFloat> result; 630 if (failed(p.parseFloatFromIntegerLiteral(result, token, isNegative, 631 eltTy.getFloatSemantics(), 632 eltTy.getWidth()))) 633 return failure(); 634 635 floatValues.push_back(*result); 636 continue; 637 } 638 639 // Check to see if any decimal integers or booleans were parsed. 640 if (!token.is(Token::floatliteral)) 641 return p.emitError() 642 << "expected floating-point elements, but parsed integer"; 643 644 // Build the float values from tokens. 645 auto val = token.getFloatingPointValue(); 646 if (!val.hasValue()) 647 return p.emitError("floating point value too large for attribute"); 648 649 APFloat apVal(isNegative ? -*val : *val); 650 if (!eltTy.isF64()) { 651 bool unused; 652 apVal.convert(eltTy.getFloatSemantics(), APFloat::rmNearestTiesToEven, 653 &unused); 654 } 655 floatValues.push_back(apVal); 656 } 657 return success(); 658 } 659 660 /// Build a Dense String attribute for the given type. 661 DenseElementsAttr TensorLiteralParser::getStringAttr(llvm::SMLoc loc, 662 ShapedType type, 663 Type eltTy) { 664 if (hexStorage.hasValue()) { 665 auto stringValue = hexStorage.getValue().getStringValue(); 666 return DenseStringElementsAttr::get(type, {stringValue}); 667 } 668 669 std::vector<std::string> stringValues; 670 std::vector<StringRef> stringRefValues; 671 stringValues.reserve(storage.size()); 672 stringRefValues.reserve(storage.size()); 673 674 for (auto val : storage) { 675 stringValues.push_back(val.second.getStringValue()); 676 stringRefValues.push_back(stringValues.back()); 677 } 678 679 return DenseStringElementsAttr::get(type, stringRefValues); 680 } 681 682 /// Build a Dense attribute with hex data for the given type. 683 DenseElementsAttr TensorLiteralParser::getHexAttr(llvm::SMLoc loc, 684 ShapedType type) { 685 Type elementType = type.getElementType(); 686 if (!elementType.isIntOrIndexOrFloat() && !elementType.isa<ComplexType>()) { 687 p.emitError(loc) 688 << "expected floating-point, integer, or complex element type, got " 689 << elementType; 690 return nullptr; 691 } 692 693 std::string data; 694 if (parseElementAttrHexValues(p, hexStorage.getValue(), data)) 695 return nullptr; 696 697 ArrayRef<char> rawData(data.data(), data.size()); 698 bool detectedSplat = false; 699 if (!DenseElementsAttr::isValidRawBuffer(type, rawData, detectedSplat)) { 700 p.emitError(loc) << "elements hex data size is invalid for provided type: " 701 << type; 702 return nullptr; 703 } 704 705 if (llvm::support::endian::system_endianness() == 706 llvm::support::endianness::big) { 707 // Convert endianess in big-endian(BE) machines. `rawData` is 708 // little-endian(LE) because HEX in raw data of dense element attribute 709 // is always LE format. It is converted into BE here to be used in BE 710 // machines. 711 SmallVector<char, 64> outDataVec(rawData.size()); 712 MutableArrayRef<char> convRawData(outDataVec); 713 DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( 714 rawData, convRawData, type); 715 return DenseElementsAttr::getFromRawBuffer(type, convRawData, 716 detectedSplat); 717 } 718 719 return DenseElementsAttr::getFromRawBuffer(type, rawData, detectedSplat); 720 } 721 722 ParseResult TensorLiteralParser::parseElement() { 723 switch (p.getToken().getKind()) { 724 // Parse a boolean element. 725 case Token::kw_true: 726 case Token::kw_false: 727 case Token::floatliteral: 728 case Token::integer: 729 storage.emplace_back(/*isNegative=*/false, p.getToken()); 730 p.consumeToken(); 731 break; 732 733 // Parse a signed integer or a negative floating-point element. 734 case Token::minus: 735 p.consumeToken(Token::minus); 736 if (!p.getToken().isAny(Token::floatliteral, Token::integer)) 737 return p.emitError("expected integer or floating point literal"); 738 storage.emplace_back(/*isNegative=*/true, p.getToken()); 739 p.consumeToken(); 740 break; 741 742 case Token::string: 743 storage.emplace_back(/*isNegative=*/false, p.getToken()); 744 p.consumeToken(); 745 break; 746 747 // Parse a complex element of the form '(' element ',' element ')'. 748 case Token::l_paren: 749 p.consumeToken(Token::l_paren); 750 if (parseElement() || 751 p.parseToken(Token::comma, "expected ',' between complex elements") || 752 parseElement() || 753 p.parseToken(Token::r_paren, "expected ')' after complex elements")) 754 return failure(); 755 break; 756 757 default: 758 return p.emitError("expected element literal of primitive type"); 759 } 760 761 return success(); 762 } 763 764 /// Parse a list of either lists or elements, returning the dimensions of the 765 /// parsed sub-tensors in dims. For example: 766 /// parseList([1, 2, 3]) -> Success, [3] 767 /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] 768 /// parseList([[1, 2], 3]) -> Failure 769 /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure 770 ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) { 771 p.consumeToken(Token::l_square); 772 773 auto checkDims = [&](const SmallVectorImpl<int64_t> &prevDims, 774 const SmallVectorImpl<int64_t> &newDims) -> ParseResult { 775 if (prevDims == newDims) 776 return success(); 777 return p.emitError("tensor literal is invalid; ranks are not consistent " 778 "between elements"); 779 }; 780 781 bool first = true; 782 SmallVector<int64_t, 4> newDims; 783 unsigned size = 0; 784 auto parseCommaSeparatedList = [&]() -> ParseResult { 785 SmallVector<int64_t, 4> thisDims; 786 if (p.getToken().getKind() == Token::l_square) { 787 if (parseList(thisDims)) 788 return failure(); 789 } else if (parseElement()) { 790 return failure(); 791 } 792 ++size; 793 if (!first) 794 return checkDims(newDims, thisDims); 795 newDims = thisDims; 796 first = false; 797 return success(); 798 }; 799 if (p.parseCommaSeparatedListUntil(Token::r_square, parseCommaSeparatedList)) 800 return failure(); 801 802 // Return the sublists' dimensions with 'size' prepended. 803 dims.clear(); 804 dims.push_back(size); 805 dims.append(newDims.begin(), newDims.end()); 806 return success(); 807 } 808 809 //===----------------------------------------------------------------------===// 810 // ElementsAttr Parser 811 //===----------------------------------------------------------------------===// 812 813 /// Parse a dense elements attribute. 814 Attribute Parser::parseDenseElementsAttr(Type attrType) { 815 auto attribLoc = getToken().getLoc(); 816 consumeToken(Token::kw_dense); 817 if (parseToken(Token::less, "expected '<' after 'dense'")) 818 return nullptr; 819 820 // Parse the literal data if necessary. 821 TensorLiteralParser literalParser(*this); 822 if (!consumeIf(Token::greater)) { 823 if (literalParser.parse(/*allowHex=*/true) || 824 parseToken(Token::greater, "expected '>'")) 825 return nullptr; 826 } 827 828 // If the type is specified `parseElementsLiteralType` will not parse a type. 829 // Use the attribute location as the location for error reporting in that 830 // case. 831 auto loc = attrType ? attribLoc : getToken().getLoc(); 832 auto type = parseElementsLiteralType(attrType); 833 if (!type) 834 return nullptr; 835 return literalParser.getAttr(loc, type); 836 } 837 838 /// Parse an opaque elements attribute. 839 Attribute Parser::parseOpaqueElementsAttr(Type attrType) { 840 consumeToken(Token::kw_opaque); 841 if (parseToken(Token::less, "expected '<' after 'opaque'")) 842 return nullptr; 843 844 if (getToken().isNot(Token::string)) 845 return (emitError("expected dialect namespace"), nullptr); 846 847 std::string name = getToken().getStringValue(); 848 consumeToken(Token::string); 849 850 if (parseToken(Token::comma, "expected ','")) 851 return nullptr; 852 853 Token hexTok = getToken(); 854 if (parseToken(Token::string, "elements hex string should start with '0x'") || 855 parseToken(Token::greater, "expected '>'")) 856 return nullptr; 857 auto type = parseElementsLiteralType(attrType); 858 if (!type) 859 return nullptr; 860 861 std::string data; 862 if (parseElementAttrHexValues(*this, hexTok, data)) 863 return nullptr; 864 return OpaqueElementsAttr::get(builder.getIdentifier(name), type, data); 865 } 866 867 /// Shaped type for elements attribute. 868 /// 869 /// elements-literal-type ::= vector-type | ranked-tensor-type 870 /// 871 /// This method also checks the type has static shape. 872 ShapedType Parser::parseElementsLiteralType(Type type) { 873 // If the user didn't provide a type, parse the colon type for the literal. 874 if (!type) { 875 if (parseToken(Token::colon, "expected ':'")) 876 return nullptr; 877 if (!(type = parseType())) 878 return nullptr; 879 } 880 881 if (!type.isa<RankedTensorType, VectorType>()) { 882 emitError("elements literal must be a ranked tensor or vector type"); 883 return nullptr; 884 } 885 886 auto sType = type.cast<ShapedType>(); 887 if (!sType.hasStaticShape()) 888 return (emitError("elements literal type must have static shape"), nullptr); 889 890 return sType; 891 } 892 893 /// Parse a sparse elements attribute. 894 Attribute Parser::parseSparseElementsAttr(Type attrType) { 895 consumeToken(Token::kw_sparse); 896 if (parseToken(Token::less, "Expected '<' after 'sparse'")) 897 return nullptr; 898 899 // Check for the case where all elements are sparse. The indices are 900 // represented by a 2-dimensional shape where the second dimension is the rank 901 // of the type. 902 Type indiceEltType = builder.getIntegerType(64); 903 if (consumeIf(Token::greater)) { 904 ShapedType type = parseElementsLiteralType(attrType); 905 if (!type) 906 return nullptr; 907 908 // Construct the sparse elements attr using zero element indice/value 909 // attributes. 910 ShapedType indicesType = 911 RankedTensorType::get({0, type.getRank()}, indiceEltType); 912 ShapedType valuesType = RankedTensorType::get({0}, type.getElementType()); 913 return SparseElementsAttr::get( 914 type, DenseElementsAttr::get(indicesType, ArrayRef<Attribute>()), 915 DenseElementsAttr::get(valuesType, ArrayRef<Attribute>())); 916 } 917 918 /// Parse the indices. We don't allow hex values here as we may need to use 919 /// the inferred shape. 920 auto indicesLoc = getToken().getLoc(); 921 TensorLiteralParser indiceParser(*this); 922 if (indiceParser.parse(/*allowHex=*/false)) 923 return nullptr; 924 925 if (parseToken(Token::comma, "expected ','")) 926 return nullptr; 927 928 /// Parse the values. 929 auto valuesLoc = getToken().getLoc(); 930 TensorLiteralParser valuesParser(*this); 931 if (valuesParser.parse(/*allowHex=*/true)) 932 return nullptr; 933 934 if (parseToken(Token::greater, "expected '>'")) 935 return nullptr; 936 937 auto type = parseElementsLiteralType(attrType); 938 if (!type) 939 return nullptr; 940 941 // If the indices are a splat, i.e. the literal parser parsed an element and 942 // not a list, we set the shape explicitly. The indices are represented by a 943 // 2-dimensional shape where the second dimension is the rank of the type. 944 // Given that the parsed indices is a splat, we know that we only have one 945 // indice and thus one for the first dimension. 946 ShapedType indicesType; 947 if (indiceParser.getShape().empty()) { 948 indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType); 949 } else { 950 // Otherwise, set the shape to the one parsed by the literal parser. 951 indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType); 952 } 953 auto indices = indiceParser.getAttr(indicesLoc, indicesType); 954 955 // If the values are a splat, set the shape explicitly based on the number of 956 // indices. The number of indices is encoded in the first dimension of the 957 // indice shape type. 958 auto valuesEltType = type.getElementType(); 959 ShapedType valuesType = 960 valuesParser.getShape().empty() 961 ? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType) 962 : RankedTensorType::get(valuesParser.getShape(), valuesEltType); 963 auto values = valuesParser.getAttr(valuesLoc, valuesType); 964 965 /// Sanity check. 966 if (valuesType.getRank() != 1) 967 return (emitError("expected 1-d tensor for values"), nullptr); 968 969 auto sameShape = (indicesType.getRank() == 1) || 970 (type.getRank() == indicesType.getDimSize(1)); 971 auto sameElementNum = indicesType.getDimSize(0) == valuesType.getDimSize(0); 972 if (!sameShape || !sameElementNum) { 973 emitError() << "expected shape ([" << type.getShape() 974 << "]); inferred shape of indices literal ([" 975 << indicesType.getShape() 976 << "]); inferred shape of values literal ([" 977 << valuesType.getShape() << "])"; 978 return nullptr; 979 } 980 981 // Build the sparse elements attribute by the indices and values. 982 return SparseElementsAttr::get(type, indices, values); 983 } 984