//===- RISCVVEmitter.cpp - Generate riscv_vector.h for use with clang -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This tablegen backend is responsible for emitting riscv_vector.h, which
// includes a declaration and definition of each intrinsic function specified
// in https://github.com/riscv/rvv-intrinsic-doc. The same intrinsic records
// are also used to emit the builtin declarations (RVVEmitter::createBuiltins)
// and the builtin-to-LLVM-intrinsic codegen cases (RVVEmitter::createCodeGen).
//
// See also the documentation in include/clang/Basic/riscv_vector.td.
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/Twine.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#include <numeric>

using namespace llvm;
// Single-character element type code taken from a TypeRange string
// ('c', 's', 'i', 'l', 'x', 'f', 'd'; see RVVType::applyBasicType).
using BasicType = char;
// Vector scale (element-count multiplier). None marks an illegal scale.
using VScaleVal = Optional<unsigned>;

namespace {

// LMUL (vector register group multiplier) stored as its base-2 logarithm:
// Log2LMUL == -3 is mf8, 0 is m1 and 3 is m8.
struct LMULType {
  int Log2LMUL;
  // Asserts that Log2LMUL is within [-3, 3].
  LMULType(int Log2LMUL);
  // Return the C/C++ string representation of LMUL ("mf8" ... "m8").
  std::string str() const;
  // Return the scale of a vector with ElementBitwidth-bit elements at this
  // LMUL, or None when that scale would be less than 1.
  Optional<unsigned> getScale(unsigned ElementBitwidth) const;
  // Multiply LMUL by 2^Log2LMUL (adds to the exponent). Unlike the
  // constructor, this does not range-check the result.
  void MulLog2LMUL(int Log2LMUL);
  // Multiply LMUL by RHS, which must be a power of two. No range check.
  LMULType &operator*=(uint32_t RHS);
};

// Compact representation of an RVV type that may be valid or invalid; the
// Valid flag records the result of verifyType().
class RVVType {
  enum ScalarTypeKind : uint32_t {
    Void,
    Size_t,
    Ptrdiff_t,
    UnsignedLong,
    SignedLong,
    Boolean,
    SignedInteger,
    UnsignedInteger,
    Float,
    Invalid,
  };
  // Element type code this type was built from.
  BasicType BT;
  ScalarTypeKind ScalarType = Invalid;
  LMULType LMUL;
  // Set by the 'P' transformer.
  bool IsPointer = false;
  // Set by the 'K' transformer: the operand is an integer that must be a
  // constant expression. Encoded with an "I" prefix in the builtin string.
  bool IsImmediate = false;
  // Const qualifier for pointer to const object or object of const type.
  // Set by the 'C' transformer.
  bool IsConstant = false;
  unsigned ElementBitwidth = 0;
  // 0 for scalar types, the vector scale for vector types, None when the
  // requested LMUL/element width combination has no legal scale.
  VScaleVal Scale = 0;
  // Result of verifyType(), computed by the constructor.
  bool Valid;

  // Cached string forms; see the corresponding getters/init* functions.
  std::string BuiltinStr;
  std::string ClangBuiltinStr;
  std::string Str;
  std::string ShortStr;

public:
  // NOTE(review): BasicType() is '\0', which applyBasicType() rejects with
  // PrintFatalError, so this constructor appears to exist only to make
  // RVVType default-constructible - confirm it is never actually invoked.
  RVVType() : RVVType(BasicType(), 0, StringRef()) {}
  // Build the type for element code BT at LMUL 2^Log2LMUL, then apply the
  // prototype modifier string and validate the result.
  RVVType(BasicType BT, int Log2LMUL, StringRef prototype);

  // Return the string representation of a type, which is an encoded string for
  // passing to the BUILTIN() macro in Builtins.def.
  const std::string &getBuiltinStr() const { return BuiltinStr; }

  // Return the clang builtin type for RVV vector type which are used in the
  // riscv_vector.h header file.
  const std::string &getClangBuiltinStr() const { return ClangBuiltinStr; }

  // Return the C/C++ string representation of a type for use in the
  // riscv_vector.h header file.
  const std::string &getTypeStr() const { return Str; }

  // Return the short name of a type for C/C++ name suffix.
  const std::string &getShortStr() {
    // Not all types are used in short name, so compute the short name by
    // demanded.
    if (ShortStr.empty())
      initShortStr();
    return ShortStr;
  }

  bool isValid() const { return Valid; }
  bool isScalar() const { return Scale.hasValue() && Scale.getValue() == 0; }
  bool isVector() const { return Scale.hasValue() && Scale.getValue() != 0; }
  bool isFloat() const { return ScalarType == ScalarTypeKind::Float; }
  bool isSignedInteger() const {
    return ScalarType == ScalarTypeKind::SignedInteger;
  }
  bool isFloatVector(unsigned Width) const {
    return isVector() && isFloat() && ElementBitwidth == Width;
  }
  bool isFloat(unsigned Width) const {
    return isFloat() && ElementBitwidth == Width;
  }

private:
  // Return whether this RVV type is legal; the constructor stores the result
  // in Valid.
  bool verifyType() const;

  // Creates a type based on basic types of TypeRange (sets ScalarType and
  // ElementBitwidth from BT).
  void applyBasicType();

  // Applies a prototype modifier to the current type. The result maybe an
  // invalid type.
  void applyModifier(StringRef prototype);

  // Compute and record a string for legal type.
  void initBuiltinStr();
  // Compute and record a builtin RVV vector type string.
  void initClangBuiltinStr();
  // Compute and record a type string for used in the header.
  void initTypeStr();
  // Compute and record a short name of a type for C/C++ name suffix.
  void initShortStr();
};

using RVVTypePtr = RVVType *;
using RVVTypes = std::vector<RVVTypePtr>;

// Bit mask of the extensions an intrinsic requires; the header groups
// intrinsics under preprocessor guards by this value. Bit 0 is unused
// (Basic == 0 means no extra extension).
enum RISCVExtension : uint8_t {
  Basic = 0,
  F = 1 << 1,
  D = 1 << 2,
  Zfh = 1 << 3,
  Zvamo = 1 << 4,
  Zvlsseg = 1 << 5,
};

// TODO refactor RVVIntrinsic class design after support all intrinsic
// combination. This represents an instantiation of an intrinsic with a
// particular type and prototype
class RVVIntrinsic {

private:
  std::string Name; // Builtin name
  std::string MangledName;
  std::string IRName;
  bool HasSideEffects;
  bool IsMask;
  bool HasMaskedOffOperand;
  bool HasVL;
  bool HasNoMaskedOverloaded;
  bool HasAutoDef; // There is an automatic definition in the header
  std::string ManualCodegen;
  RVVTypePtr OutputType; // Builtin output type
  RVVTypes InputTypes; // Builtin input types
  // The types we use to obtain the specific LLVM intrinsic. They are index of
  // InputTypes. -1 means the return type.
  std::vector<int64_t> IntrinsicTypes;
  // Bit mask of RISCVExtension values required by this intrinsic.
  uint8_t RISCVExtensions = 0;
  // Number of fields (segment count); 1 for non-segment intrinsics.
  unsigned NF = 1;

public:
  RVVIntrinsic(StringRef Name, StringRef Suffix, StringRef MangledName,
               StringRef MangledSuffix, StringRef IRName, bool HasSideEffects,
               bool IsMask, bool HasMaskedOffOperand, bool HasVL,
               bool HasNoMaskedOverloaded, bool HasAutoDef,
               StringRef ManualCodegen, const RVVTypes &Types,
               const std::vector<int64_t> &IntrinsicTypes,
               StringRef RequiredExtension, unsigned NF);
  ~RVVIntrinsic() = default;

  StringRef getName() const { return Name; }
  StringRef getMangledName() const { return MangledName; }
  bool hasSideEffects() const { return HasSideEffects; }
  bool hasMaskedOffOperand() const { return HasMaskedOffOperand; }
  bool hasVL() const { return HasVL; }
  bool hasNoMaskedOverloaded() const { return HasNoMaskedOverloaded; }
  bool hasManualCodegen() const { return !ManualCodegen.empty(); }
  bool hasAutoDef() const { return HasAutoDef; }
  bool isMask() const { return IsMask; }
  StringRef getIRName() const { return IRName; }
  StringRef getManualCodegen() const { return ManualCodegen; }
  uint8_t getRISCVExtensions() const { return RISCVExtensions; }
  unsigned getNF() const { return NF; }

  // Return the type string for a BUILTIN() macro in Builtins.def.
  std::string getBuiltinTypeStr() const;

  // Emit the code block for switch body in EmitRISCVBuiltinExpr, it should
  // init the RVVIntrinsic ID and IntrinsicTypes.
  void emitCodeGenSwitchBody(raw_ostream &o) const;

  // Emit the macros for mapping C/C++ intrinsic function to builtin functions.
  void emitIntrinsicMacro(raw_ostream &o) const;

  // Emit the mangled function definition.
  void emitMangledFuncDef(raw_ostream &o) const;
};

class RVVEmitter {
private:
  RecordKeeper &Records;
  // Concatenation of the HeaderCode fields of all RVVBuiltin records.
  std::string HeaderCode;
  // Concat BasicType, LMUL and Proto as key
  StringMap<RVVType> LegalTypes;
  StringSet<> IllegalTypes;

public:
  RVVEmitter(RecordKeeper &R) : Records(R) {}

  /// Emit riscv_vector.h
  void createHeader(raw_ostream &o);

  /// Emit all the __builtin prototypes and code needed by Sema.
  void createBuiltins(raw_ostream &o);

  /// Emit all the information needed to map builtin -> LLVM IR intrinsic.
  void createCodeGen(raw_ostream &o);

  /// Return the "_"-joined short names of the types described by Prototypes
  /// for element code Type at LMUL 2^Log2LMUL.
  std::string getSuffixStr(char Type, int Log2LMUL, StringRef Prototypes);

private:
  /// Create all intrinsics and add them to \p Out
  void createRVVIntrinsics(std::vector<std::unique_ptr<RVVIntrinsic>> &Out);
  /// Compute output and input types by applying different config (basic type
  /// and LMUL with type transformers). It also record result of type in legal
  /// or illegal set to avoid compute the same config again. The result maybe
  /// have illegal RVVType.
  Optional<RVVTypes> computeTypes(BasicType BT, int Log2LMUL, unsigned NF,
                                  ArrayRef<std::string> PrototypeSeq);
  Optional<RVVTypePtr> computeType(BasicType BT, int Log2LMUL, StringRef Proto);

  /// Emit Arch preprocessor definitions and body, assume the element of Defs
  /// are sorted by extension.
  void emitArchMacroAndBody(
      std::vector<std::unique_ptr<RVVIntrinsic>> &Defs, raw_ostream &o,
      std::function<void(raw_ostream &, const RVVIntrinsic &)>);

  // Emit the architecture preprocessor definitions. Return true when emits
  // non-empty string.
  bool emitExtDefStr(uint8_t Extensions, raw_ostream &o);
  // Slice Prototypes string into sub prototype string and process each sub
  // prototype string individually in the Handler.
  void parsePrototypes(StringRef Prototypes,
                       std::function<void(StringRef)> Handler);
};

} // namespace

//===----------------------------------------------------------------------===//
// Type implementation
//===----------------------------------------------------------------------===//

LMULType::LMULType(int NewLog2LMUL) {
  // Check Log2LMUL is -3, -2, -1, 0, 1, 2, 3
  assert(NewLog2LMUL <= 3 && NewLog2LMUL >= -3 && "Bad LMUL number!");
  Log2LMUL = NewLog2LMUL;
}

// Fractional LMULs print as "mf<1/LMUL>", integral ones as "m<LMUL>".
std::string LMULType::str() const {
  if (Log2LMUL < 0)
    return "mf" + utostr(1ULL << (-Log2LMUL));
  return "m" + utostr(1ULL << Log2LMUL);
}

// Scale = LMUL * 64 / ElementBitwidth, computed in log2 space. Element widths
// other than 8/16/32/64 leave the log2 result at 0 (scale 1); verifyType()
// rejects such types separately.
VScaleVal LMULType::getScale(unsigned ElementBitwidth) const {
  int Log2ScaleResult = 0;
  switch (ElementBitwidth) {
  default:
    break;
  case 8:
    Log2ScaleResult = Log2LMUL + 3;
    break;
  case 16:
    Log2ScaleResult = Log2LMUL + 2;
    break;
  case 32:
    Log2ScaleResult = Log2LMUL + 1;
    break;
  case 64:
    Log2ScaleResult = Log2LMUL;
    break;
  }
  // Illegal vscale result would be less than 1
  if (Log2ScaleResult < 0)
    return None;
  return 1 << Log2ScaleResult;
}

void LMULType::MulLog2LMUL(int log2LMUL) { Log2LMUL += log2LMUL; }

LMULType &LMULType::operator*=(uint32_t RHS) {
  assert(isPowerOf2_32(RHS));
  this->Log2LMUL = this->Log2LMUL + Log2_32(RHS);
  return *this;
}

// Apply the basic type and the prototype modifier, validate, and precompute
// the builtin and header strings for valid types (the clang builtin type
// string only exists for vector types).
RVVType::RVVType(BasicType BT, int Log2LMUL, StringRef prototype)
    : BT(BT), LMUL(LMULType(Log2LMUL)) {
  applyBasicType();
  applyModifier(prototype);
  Valid = verifyType();
  if (Valid) {
    initBuiltinStr();
    initTypeStr();
    if (isVector()) {
      initClangBuiltinStr();
    }
  }
}

// clang-format off
// Boolean types are encoded by the ratio n = SEW/LMUL:
// SEW/LMUL | 1         | 2         | 4         | 8        | 16       | 32       | 64
// c type   | vbool64_t | vbool32_t | vbool16_t | vbool8_t | vbool4_t | vbool2_t | vbool1_t
325 // IR type | nxv1i1 | nxv2i1 | nxv4i1 | nxv8i1 | nxv16i1 | nxv32i1 | nxv64i1 326 327 // type\lmul | 1/8 | 1/4 | 1/2 | 1 | 2 | 4 | 8 328 // -------- |------ | -------- | ------- | ------- | -------- | -------- | -------- 329 // i64 | N/A | N/A | N/A | nxv1i64 | nxv2i64 | nxv4i64 | nxv8i64 330 // i32 | N/A | N/A | nxv1i32 | nxv2i32 | nxv4i32 | nxv8i32 | nxv16i32 331 // i16 | N/A | nxv1i16 | nxv2i16 | nxv4i16 | nxv8i16 | nxv16i16 | nxv32i16 332 // i8 | nxv1i8 | nxv2i8 | nxv4i8 | nxv8i8 | nxv16i8 | nxv32i8 | nxv64i8 333 // double | N/A | N/A | N/A | nxv1f64 | nxv2f64 | nxv4f64 | nxv8f64 334 // float | N/A | N/A | nxv1f32 | nxv2f32 | nxv4f32 | nxv8f32 | nxv16f32 335 // half | N/A | nxv1f16 | nxv2f16 | nxv4f16 | nxv8f16 | nxv16f16 | nxv32f16 336 // clang-format on 337 338 bool RVVType::verifyType() const { 339 if (ScalarType == Invalid) 340 return false; 341 if (isScalar()) 342 return true; 343 if (!Scale.hasValue()) 344 return false; 345 if (isFloat() && ElementBitwidth == 8) 346 return false; 347 unsigned V = Scale.getValue(); 348 switch (ElementBitwidth) { 349 case 1: 350 case 8: 351 // Check Scale is 1,2,4,8,16,32,64 352 return (V <= 64 && isPowerOf2_32(V)); 353 case 16: 354 // Check Scale is 1,2,4,8,16,32 355 return (V <= 32 && isPowerOf2_32(V)); 356 case 32: 357 // Check Scale is 1,2,4,8,16 358 return (V <= 16 && isPowerOf2_32(V)); 359 case 64: 360 // Check Scale is 1,2,4,8 361 return (V <= 8 && isPowerOf2_32(V)); 362 } 363 return false; 364 } 365 366 void RVVType::initBuiltinStr() { 367 assert(isValid() && "RVVType is invalid"); 368 switch (ScalarType) { 369 case ScalarTypeKind::Void: 370 BuiltinStr = "v"; 371 return; 372 case ScalarTypeKind::Size_t: 373 BuiltinStr = "z"; 374 if (IsImmediate) 375 BuiltinStr = "I" + BuiltinStr; 376 if (IsPointer) 377 BuiltinStr += "*"; 378 return; 379 case ScalarTypeKind::Ptrdiff_t: 380 BuiltinStr = "Y"; 381 return; 382 case ScalarTypeKind::UnsignedLong: 383 BuiltinStr = "ULi"; 384 return; 385 case ScalarTypeKind::SignedLong: 
386 BuiltinStr = "Li"; 387 return; 388 case ScalarTypeKind::Boolean: 389 assert(ElementBitwidth == 1); 390 BuiltinStr += "b"; 391 break; 392 case ScalarTypeKind::SignedInteger: 393 case ScalarTypeKind::UnsignedInteger: 394 switch (ElementBitwidth) { 395 case 8: 396 BuiltinStr += "c"; 397 break; 398 case 16: 399 BuiltinStr += "s"; 400 break; 401 case 32: 402 BuiltinStr += "i"; 403 break; 404 case 64: 405 BuiltinStr += "Wi"; 406 break; 407 default: 408 llvm_unreachable("Unhandled ElementBitwidth!"); 409 } 410 if (isSignedInteger()) 411 BuiltinStr = "S" + BuiltinStr; 412 else 413 BuiltinStr = "U" + BuiltinStr; 414 break; 415 case ScalarTypeKind::Float: 416 switch (ElementBitwidth) { 417 case 16: 418 BuiltinStr += "x"; 419 break; 420 case 32: 421 BuiltinStr += "f"; 422 break; 423 case 64: 424 BuiltinStr += "d"; 425 break; 426 default: 427 llvm_unreachable("Unhandled ElementBitwidth!"); 428 } 429 break; 430 default: 431 llvm_unreachable("ScalarType is invalid!"); 432 } 433 if (IsImmediate) 434 BuiltinStr = "I" + BuiltinStr; 435 if (isScalar()) { 436 if (IsConstant) 437 BuiltinStr += "C"; 438 if (IsPointer) 439 BuiltinStr += "*"; 440 return; 441 } 442 BuiltinStr = "q" + utostr(Scale.getValue()) + BuiltinStr; 443 // Pointer to vector types. Defined for Zvlsseg load intrinsics. 444 // Zvlsseg load intrinsics have pointer type arguments to store the loaded 445 // vector values. 
446 if (IsPointer) 447 BuiltinStr += "*"; 448 } 449 450 void RVVType::initClangBuiltinStr() { 451 assert(isValid() && "RVVType is invalid"); 452 assert(isVector() && "Handle Vector type only"); 453 454 ClangBuiltinStr = "__rvv_"; 455 switch (ScalarType) { 456 case ScalarTypeKind::Boolean: 457 ClangBuiltinStr += "bool" + utostr(64 / Scale.getValue()) + "_t"; 458 return; 459 case ScalarTypeKind::Float: 460 ClangBuiltinStr += "float"; 461 break; 462 case ScalarTypeKind::SignedInteger: 463 ClangBuiltinStr += "int"; 464 break; 465 case ScalarTypeKind::UnsignedInteger: 466 ClangBuiltinStr += "uint"; 467 break; 468 default: 469 llvm_unreachable("ScalarTypeKind is invalid"); 470 } 471 ClangBuiltinStr += utostr(ElementBitwidth) + LMUL.str() + "_t"; 472 } 473 474 void RVVType::initTypeStr() { 475 assert(isValid() && "RVVType is invalid"); 476 477 if (IsConstant) 478 Str += "const "; 479 480 auto getTypeString = [&](StringRef TypeStr) { 481 if (isScalar()) 482 return Twine(TypeStr + Twine(ElementBitwidth) + "_t").str(); 483 return Twine("v" + TypeStr + Twine(ElementBitwidth) + LMUL.str() + "_t") 484 .str(); 485 }; 486 487 switch (ScalarType) { 488 case ScalarTypeKind::Void: 489 Str = "void"; 490 return; 491 case ScalarTypeKind::Size_t: 492 Str = "size_t"; 493 if (IsPointer) 494 Str += " *"; 495 return; 496 case ScalarTypeKind::Ptrdiff_t: 497 Str = "ptrdiff_t"; 498 return; 499 case ScalarTypeKind::UnsignedLong: 500 Str = "unsigned long"; 501 return; 502 case ScalarTypeKind::SignedLong: 503 Str = "long"; 504 return; 505 case ScalarTypeKind::Boolean: 506 if (isScalar()) 507 Str += "bool"; 508 else 509 // Vector bool is special case, the formulate is 510 // `vbool<N>_t = MVT::nxv<64/N>i1` ex. 
vbool16_t = MVT::4i1 511 Str += "vbool" + utostr(64 / Scale.getValue()) + "_t"; 512 break; 513 case ScalarTypeKind::Float: 514 if (isScalar()) { 515 if (ElementBitwidth == 64) 516 Str += "double"; 517 else if (ElementBitwidth == 32) 518 Str += "float"; 519 else if (ElementBitwidth == 16) 520 Str += "_Float16"; 521 else 522 llvm_unreachable("Unhandled floating type."); 523 } else 524 Str += getTypeString("float"); 525 break; 526 case ScalarTypeKind::SignedInteger: 527 Str += getTypeString("int"); 528 break; 529 case ScalarTypeKind::UnsignedInteger: 530 Str += getTypeString("uint"); 531 break; 532 default: 533 llvm_unreachable("ScalarType is invalid!"); 534 } 535 if (IsPointer) 536 Str += " *"; 537 } 538 539 void RVVType::initShortStr() { 540 switch (ScalarType) { 541 case ScalarTypeKind::Boolean: 542 assert(isVector()); 543 ShortStr = "b" + utostr(64 / Scale.getValue()); 544 return; 545 case ScalarTypeKind::Float: 546 ShortStr = "f" + utostr(ElementBitwidth); 547 break; 548 case ScalarTypeKind::SignedInteger: 549 ShortStr = "i" + utostr(ElementBitwidth); 550 break; 551 case ScalarTypeKind::UnsignedInteger: 552 ShortStr = "u" + utostr(ElementBitwidth); 553 break; 554 default: 555 PrintFatalError("Unhandled case!"); 556 } 557 if (isVector()) 558 ShortStr += LMUL.str(); 559 } 560 561 void RVVType::applyBasicType() { 562 switch (BT) { 563 case 'c': 564 ElementBitwidth = 8; 565 ScalarType = ScalarTypeKind::SignedInteger; 566 break; 567 case 's': 568 ElementBitwidth = 16; 569 ScalarType = ScalarTypeKind::SignedInteger; 570 break; 571 case 'i': 572 ElementBitwidth = 32; 573 ScalarType = ScalarTypeKind::SignedInteger; 574 break; 575 case 'l': 576 ElementBitwidth = 64; 577 ScalarType = ScalarTypeKind::SignedInteger; 578 break; 579 case 'x': 580 ElementBitwidth = 16; 581 ScalarType = ScalarTypeKind::Float; 582 break; 583 case 'f': 584 ElementBitwidth = 32; 585 ScalarType = ScalarTypeKind::Float; 586 break; 587 case 'd': 588 ElementBitwidth = 64; 589 ScalarType = 
ScalarTypeKind::Float; 590 break; 591 default: 592 PrintFatalError("Unhandled type code!"); 593 } 594 assert(ElementBitwidth != 0 && "Bad element bitwidth!"); 595 } 596 597 void RVVType::applyModifier(StringRef Transformer) { 598 if (Transformer.empty()) 599 return; 600 // Handle primitive type transformer 601 auto PType = Transformer.back(); 602 switch (PType) { 603 case 'e': 604 Scale = 0; 605 break; 606 case 'v': 607 Scale = LMUL.getScale(ElementBitwidth); 608 break; 609 case 'w': 610 ElementBitwidth *= 2; 611 LMUL *= 2; 612 Scale = LMUL.getScale(ElementBitwidth); 613 break; 614 case 'q': 615 ElementBitwidth *= 4; 616 LMUL *= 4; 617 Scale = LMUL.getScale(ElementBitwidth); 618 break; 619 case 'o': 620 ElementBitwidth *= 8; 621 LMUL *= 8; 622 Scale = LMUL.getScale(ElementBitwidth); 623 break; 624 case 'm': 625 ScalarType = ScalarTypeKind::Boolean; 626 Scale = LMUL.getScale(ElementBitwidth); 627 ElementBitwidth = 1; 628 break; 629 case '0': 630 ScalarType = ScalarTypeKind::Void; 631 break; 632 case 'z': 633 ScalarType = ScalarTypeKind::Size_t; 634 break; 635 case 't': 636 ScalarType = ScalarTypeKind::Ptrdiff_t; 637 break; 638 case 'u': 639 ScalarType = ScalarTypeKind::UnsignedLong; 640 break; 641 case 'l': 642 ScalarType = ScalarTypeKind::SignedLong; 643 break; 644 default: 645 PrintFatalError("Illegal primitive type transformers!"); 646 } 647 Transformer = Transformer.drop_back(); 648 649 // Extract and compute complex type transformer. It can only appear one time. 
650 if (Transformer.startswith("(")) { 651 size_t Idx = Transformer.find(')'); 652 assert(Idx != StringRef::npos); 653 StringRef ComplexType = Transformer.slice(1, Idx); 654 Transformer = Transformer.drop_front(Idx + 1); 655 assert(Transformer.find('(') == StringRef::npos && 656 "Only allow one complex type transformer"); 657 658 auto UpdateAndCheckComplexProto = [&]() { 659 Scale = LMUL.getScale(ElementBitwidth); 660 const StringRef VectorPrototypes("vwqom"); 661 if (!VectorPrototypes.contains(PType)) 662 PrintFatalError("Complex type transformer only supports vector type!"); 663 if (Transformer.find_first_of("PCKWS") != StringRef::npos) 664 PrintFatalError( 665 "Illegal type transformer for Complex type transformer"); 666 }; 667 auto ComputeFixedLog2LMUL = 668 [&](StringRef Value, 669 std::function<bool(const int32_t &, const int32_t &)> Compare) { 670 int32_t Log2LMUL; 671 Value.getAsInteger(10, Log2LMUL); 672 if (!Compare(Log2LMUL, LMUL.Log2LMUL)) { 673 ScalarType = Invalid; 674 return false; 675 } 676 // Update new LMUL 677 LMUL = LMULType(Log2LMUL); 678 UpdateAndCheckComplexProto(); 679 return true; 680 }; 681 auto ComplexTT = ComplexType.split(":"); 682 if (ComplexTT.first == "Log2EEW") { 683 uint32_t Log2EEW; 684 ComplexTT.second.getAsInteger(10, Log2EEW); 685 // update new elmul = (eew/sew) * lmul 686 LMUL.MulLog2LMUL(Log2EEW - Log2_32(ElementBitwidth)); 687 // update new eew 688 ElementBitwidth = 1 << Log2EEW; 689 ScalarType = ScalarTypeKind::SignedInteger; 690 UpdateAndCheckComplexProto(); 691 } else if (ComplexTT.first == "FixedSEW") { 692 uint32_t NewSEW; 693 ComplexTT.second.getAsInteger(10, NewSEW); 694 // Set invalid type if src and dst SEW are same. 
695 if (ElementBitwidth == NewSEW) { 696 ScalarType = Invalid; 697 return; 698 } 699 // Update new SEW 700 ElementBitwidth = NewSEW; 701 UpdateAndCheckComplexProto(); 702 } else if (ComplexTT.first == "LFixedLog2LMUL") { 703 // New LMUL should be larger than old 704 if (!ComputeFixedLog2LMUL(ComplexTT.second, std::greater<int32_t>())) 705 return; 706 } else if (ComplexTT.first == "SFixedLog2LMUL") { 707 // New LMUL should be smaller than old 708 if (!ComputeFixedLog2LMUL(ComplexTT.second, std::less<int32_t>())) 709 return; 710 } else { 711 PrintFatalError("Illegal complex type transformers!"); 712 } 713 } 714 715 // Compute the remain type transformers 716 for (char I : Transformer) { 717 switch (I) { 718 case 'P': 719 if (IsConstant) 720 PrintFatalError("'P' transformer cannot be used after 'C'"); 721 if (IsPointer) 722 PrintFatalError("'P' transformer cannot be used twice"); 723 IsPointer = true; 724 break; 725 case 'C': 726 if (IsConstant) 727 PrintFatalError("'C' transformer cannot be used twice"); 728 IsConstant = true; 729 break; 730 case 'K': 731 IsImmediate = true; 732 break; 733 case 'U': 734 ScalarType = ScalarTypeKind::UnsignedInteger; 735 break; 736 case 'I': 737 ScalarType = ScalarTypeKind::SignedInteger; 738 break; 739 case 'F': 740 ScalarType = ScalarTypeKind::Float; 741 break; 742 case 'S': 743 LMUL = LMULType(0); 744 // Update ElementBitwidth need to update Scale too. 
745 Scale = LMUL.getScale(ElementBitwidth); 746 break; 747 default: 748 PrintFatalError("Illegal non-primitive type transformer!"); 749 } 750 } 751 } 752 753 //===----------------------------------------------------------------------===// 754 // RVVIntrinsic implementation 755 //===----------------------------------------------------------------------===// 756 RVVIntrinsic::RVVIntrinsic(StringRef NewName, StringRef Suffix, 757 StringRef NewMangledName, StringRef MangledSuffix, 758 StringRef IRName, bool HasSideEffects, bool IsMask, 759 bool HasMaskedOffOperand, bool HasVL, 760 bool HasNoMaskedOverloaded, bool HasAutoDef, 761 StringRef ManualCodegen, const RVVTypes &OutInTypes, 762 const std::vector<int64_t> &NewIntrinsicTypes, 763 StringRef RequiredExtension, unsigned NF) 764 : IRName(IRName), HasSideEffects(HasSideEffects), IsMask(IsMask), 765 HasMaskedOffOperand(HasMaskedOffOperand), HasVL(HasVL), 766 HasNoMaskedOverloaded(HasNoMaskedOverloaded), HasAutoDef(HasAutoDef), 767 ManualCodegen(ManualCodegen.str()), NF(NF) { 768 769 // Init Name and MangledName 770 Name = NewName.str(); 771 if (NewMangledName.empty()) 772 MangledName = NewName.split("_").first.str(); 773 else 774 MangledName = NewMangledName.str(); 775 if (!Suffix.empty()) 776 Name += "_" + Suffix.str(); 777 if (!MangledSuffix.empty()) 778 MangledName += "_" + MangledSuffix.str(); 779 if (IsMask) { 780 Name += "_m"; 781 } 782 // Init RISC-V extensions 783 for (const auto &T : OutInTypes) { 784 if (T->isFloatVector(16) || T->isFloat(16)) 785 RISCVExtensions |= RISCVExtension::Zfh; 786 else if (T->isFloatVector(32) || T->isFloat(32)) 787 RISCVExtensions |= RISCVExtension::F; 788 else if (T->isFloatVector(64) || T->isFloat(64)) 789 RISCVExtensions |= RISCVExtension::D; 790 } 791 if (RequiredExtension == "Zvamo") 792 RISCVExtensions |= RISCVExtension::Zvamo; 793 if (RequiredExtension == "Zvlsseg") 794 RISCVExtensions |= RISCVExtension::Zvlsseg; 795 796 // Init OutputType and InputTypes 797 OutputType = 
OutInTypes[0]; 798 InputTypes.assign(OutInTypes.begin() + 1, OutInTypes.end()); 799 800 // IntrinsicTypes is nonmasked version index. Need to update it 801 // if there is maskedoff operand (It is always in first operand). 802 IntrinsicTypes = NewIntrinsicTypes; 803 if (IsMask && HasMaskedOffOperand) { 804 for (auto &I : IntrinsicTypes) { 805 if (I >= 0) 806 I += NF; 807 } 808 } 809 } 810 811 std::string RVVIntrinsic::getBuiltinTypeStr() const { 812 std::string S; 813 S += OutputType->getBuiltinStr(); 814 for (const auto &T : InputTypes) { 815 S += T->getBuiltinStr(); 816 } 817 return S; 818 } 819 820 void RVVIntrinsic::emitCodeGenSwitchBody(raw_ostream &OS) const { 821 if (!getIRName().empty()) 822 OS << " ID = Intrinsic::riscv_" + getIRName() + ";\n"; 823 if (NF >= 2) 824 OS << " NF = " + utostr(getNF()) + ";\n"; 825 if (hasManualCodegen()) { 826 OS << ManualCodegen; 827 OS << "break;\n"; 828 return; 829 } 830 831 if (isMask()) { 832 if (hasVL()) { 833 OS << " std::rotate(Ops.begin(), Ops.begin() + 1, Ops.end() - 1);\n"; 834 } else { 835 OS << " std::rotate(Ops.begin(), Ops.begin() + 1, Ops.end());\n"; 836 } 837 } 838 839 OS << " IntrinsicTypes = {"; 840 ListSeparator LS; 841 for (const auto &Idx : IntrinsicTypes) { 842 if (Idx == -1) 843 OS << LS << "ResultType"; 844 else 845 OS << LS << "Ops[" << Idx << "]->getType()"; 846 } 847 848 // VL could be i64 or i32, need to encode it in IntrinsicTypes. VL is 849 // always last operand. 
850 if (hasVL()) 851 OS << ", Ops.back()->getType()"; 852 OS << "};\n"; 853 OS << " break;\n"; 854 } 855 856 void RVVIntrinsic::emitIntrinsicMacro(raw_ostream &OS) const { 857 OS << "#define " << getName() << "("; 858 if (!InputTypes.empty()) { 859 ListSeparator LS; 860 for (unsigned i = 0, e = InputTypes.size(); i != e; ++i) 861 OS << LS << "op" << i; 862 } 863 OS << ") \\\n"; 864 OS << "__builtin_rvv_" << getName() << "("; 865 if (!InputTypes.empty()) { 866 ListSeparator LS; 867 for (unsigned i = 0, e = InputTypes.size(); i != e; ++i) 868 OS << LS << "(" << InputTypes[i]->getTypeStr() << ")(op" << i << ")"; 869 } 870 OS << ")\n"; 871 } 872 873 void RVVIntrinsic::emitMangledFuncDef(raw_ostream &OS) const { 874 OS << "__attribute__((clang_builtin_alias("; 875 OS << "__builtin_rvv_" << getName() << ")))\n"; 876 OS << OutputType->getTypeStr() << " " << getMangledName() << "("; 877 // Emit function arguments 878 if (!InputTypes.empty()) { 879 ListSeparator LS; 880 for (unsigned i = 0; i < InputTypes.size(); ++i) 881 OS << LS << InputTypes[i]->getTypeStr() << " op" << i; 882 } 883 OS << ");\n\n"; 884 } 885 886 //===----------------------------------------------------------------------===// 887 // RVVEmitter implementation 888 //===----------------------------------------------------------------------===// 889 void RVVEmitter::createHeader(raw_ostream &OS) { 890 891 OS << "/*===---- riscv_vector.h - RISC-V V-extension RVVIntrinsics " 892 "-------------------===\n" 893 " *\n" 894 " *\n" 895 " * Part of the LLVM Project, under the Apache License v2.0 with LLVM " 896 "Exceptions.\n" 897 " * See https://llvm.org/LICENSE.txt for license information.\n" 898 " * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n" 899 " *\n" 900 " *===-----------------------------------------------------------------" 901 "------===\n" 902 " */\n\n"; 903 904 OS << "#ifndef __RISCV_VECTOR_H\n"; 905 OS << "#define __RISCV_VECTOR_H\n\n"; 906 907 OS << "#include <stdint.h>\n"; 908 OS << 
"#include <stddef.h>\n\n"; 909 910 OS << "#ifndef __riscv_vector\n"; 911 OS << "#error \"Vector intrinsics require the vector extension.\"\n"; 912 OS << "#endif\n\n"; 913 914 OS << "#ifdef __cplusplus\n"; 915 OS << "extern \"C\" {\n"; 916 OS << "#endif\n\n"; 917 918 std::vector<std::unique_ptr<RVVIntrinsic>> Defs; 919 createRVVIntrinsics(Defs); 920 921 // Print header code 922 if (!HeaderCode.empty()) { 923 OS << HeaderCode; 924 } 925 926 auto printType = [&](auto T) { 927 OS << "typedef " << T->getClangBuiltinStr() << " " << T->getTypeStr() 928 << ";\n"; 929 }; 930 931 constexpr int Log2LMULs[] = {-3, -2, -1, 0, 1, 2, 3}; 932 // Print RVV boolean types. 933 for (int Log2LMUL : Log2LMULs) { 934 auto T = computeType('c', Log2LMUL, "m"); 935 if (T.hasValue()) 936 printType(T.getValue()); 937 } 938 // Print RVV int/float types. 939 for (char I : StringRef("csil")) { 940 for (int Log2LMUL : Log2LMULs) { 941 auto T = computeType(I, Log2LMUL, "v"); 942 if (T.hasValue()) { 943 printType(T.getValue()); 944 auto UT = computeType(I, Log2LMUL, "Uv"); 945 printType(UT.getValue()); 946 } 947 } 948 } 949 OS << "#if defined(__riscv_zfh)\n"; 950 for (int Log2LMUL : Log2LMULs) { 951 auto T = computeType('x', Log2LMUL, "v"); 952 if (T.hasValue()) 953 printType(T.getValue()); 954 } 955 OS << "#endif\n"; 956 957 OS << "#if defined(__riscv_f)\n"; 958 for (int Log2LMUL : Log2LMULs) { 959 auto T = computeType('f', Log2LMUL, "v"); 960 if (T.hasValue()) 961 printType(T.getValue()); 962 } 963 OS << "#endif\n"; 964 965 OS << "#if defined(__riscv_d)\n"; 966 for (int Log2LMUL : Log2LMULs) { 967 auto T = computeType('d', Log2LMUL, "v"); 968 if (T.hasValue()) 969 printType(T.getValue()); 970 } 971 OS << "#endif\n\n"; 972 973 // The same extension include in the same arch guard marco. 
974 std::stable_sort(Defs.begin(), Defs.end(), 975 [](const std::unique_ptr<RVVIntrinsic> &A, 976 const std::unique_ptr<RVVIntrinsic> &B) { 977 return A->getRISCVExtensions() < B->getRISCVExtensions(); 978 }); 979 980 // Print intrinsic functions with macro 981 emitArchMacroAndBody(Defs, OS, [](raw_ostream &OS, const RVVIntrinsic &Inst) { 982 Inst.emitIntrinsicMacro(OS); 983 }); 984 985 OS << "#define __riscv_v_intrinsic_overloading 1\n"; 986 987 // Print Overloaded APIs 988 OS << "#define __rvv_overloaded static inline " 989 "__attribute__((__always_inline__, __nodebug__, __overloadable__))\n"; 990 991 emitArchMacroAndBody(Defs, OS, [](raw_ostream &OS, const RVVIntrinsic &Inst) { 992 if (!Inst.isMask() && !Inst.hasNoMaskedOverloaded()) 993 return; 994 OS << "__rvv_overloaded "; 995 Inst.emitMangledFuncDef(OS); 996 }); 997 998 OS << "\n#ifdef __cplusplus\n"; 999 OS << "}\n"; 1000 OS << "#endif // __riscv_vector\n"; 1001 OS << "#endif // __RISCV_VECTOR_H\n"; 1002 } 1003 1004 void RVVEmitter::createBuiltins(raw_ostream &OS) { 1005 std::vector<std::unique_ptr<RVVIntrinsic>> Defs; 1006 createRVVIntrinsics(Defs); 1007 1008 OS << "#if defined(TARGET_BUILTIN) && !defined(RISCVV_BUILTIN)\n"; 1009 OS << "#define RISCVV_BUILTIN(ID, TYPE, ATTRS) TARGET_BUILTIN(ID, TYPE, " 1010 "ATTRS, \"experimental-v\")\n"; 1011 OS << "#endif\n"; 1012 for (auto &Def : Defs) { 1013 OS << "RISCVV_BUILTIN(__builtin_rvv_" << Def->getName() << ",\"" 1014 << Def->getBuiltinTypeStr() << "\", "; 1015 if (!Def->hasSideEffects()) 1016 OS << "\"n\")\n"; 1017 else 1018 OS << "\"\")\n"; 1019 } 1020 OS << "#undef RISCVV_BUILTIN\n"; 1021 } 1022 1023 void RVVEmitter::createCodeGen(raw_ostream &OS) { 1024 std::vector<std::unique_ptr<RVVIntrinsic>> Defs; 1025 createRVVIntrinsics(Defs); 1026 // IR name could be empty, use the stable sort preserves the relative order. 
1027 std::stable_sort(Defs.begin(), Defs.end(), 1028 [](const std::unique_ptr<RVVIntrinsic> &A, 1029 const std::unique_ptr<RVVIntrinsic> &B) { 1030 return A->getIRName() < B->getIRName(); 1031 }); 1032 // Print switch body when the ir name or ManualCodegen changes from previous 1033 // iteration. 1034 RVVIntrinsic *PrevDef = Defs.begin()->get(); 1035 for (auto &Def : Defs) { 1036 StringRef CurIRName = Def->getIRName(); 1037 if (CurIRName != PrevDef->getIRName() || 1038 (Def->getManualCodegen() != PrevDef->getManualCodegen())) { 1039 PrevDef->emitCodeGenSwitchBody(OS); 1040 } 1041 PrevDef = Def.get(); 1042 OS << "case RISCV::BI__builtin_rvv_" << Def->getName() << ":\n"; 1043 } 1044 Defs.back()->emitCodeGenSwitchBody(OS); 1045 OS << "\n"; 1046 } 1047 1048 void RVVEmitter::parsePrototypes(StringRef Prototypes, 1049 std::function<void(StringRef)> Handler) { 1050 const StringRef Primaries("evwqom0ztul"); 1051 while (!Prototypes.empty()) { 1052 size_t Idx = 0; 1053 // Skip over complex prototype because it could contain primitive type 1054 // character. 
1055 if (Prototypes[0] == '(') 1056 Idx = Prototypes.find_first_of(')'); 1057 Idx = Prototypes.find_first_of(Primaries, Idx); 1058 assert(Idx != StringRef::npos); 1059 Handler(Prototypes.slice(0, Idx + 1)); 1060 Prototypes = Prototypes.drop_front(Idx + 1); 1061 } 1062 } 1063 1064 std::string RVVEmitter::getSuffixStr(char Type, int Log2LMUL, 1065 StringRef Prototypes) { 1066 SmallVector<std::string> SuffixStrs; 1067 parsePrototypes(Prototypes, [&](StringRef Proto) { 1068 auto T = computeType(Type, Log2LMUL, Proto); 1069 SuffixStrs.push_back(T.getValue()->getShortStr()); 1070 }); 1071 return join(SuffixStrs, "_"); 1072 } 1073 1074 void RVVEmitter::createRVVIntrinsics( 1075 std::vector<std::unique_ptr<RVVIntrinsic>> &Out) { 1076 std::vector<Record *> RV = Records.getAllDerivedDefinitions("RVVBuiltin"); 1077 for (auto *R : RV) { 1078 StringRef Name = R->getValueAsString("Name"); 1079 StringRef SuffixProto = R->getValueAsString("Suffix"); 1080 StringRef MangledName = R->getValueAsString("MangledName"); 1081 StringRef MangledSuffixProto = R->getValueAsString("MangledSuffix"); 1082 StringRef Prototypes = R->getValueAsString("Prototype"); 1083 StringRef TypeRange = R->getValueAsString("TypeRange"); 1084 bool HasMask = R->getValueAsBit("HasMask"); 1085 bool HasMaskedOffOperand = R->getValueAsBit("HasMaskedOffOperand"); 1086 bool HasVL = R->getValueAsBit("HasVL"); 1087 bool HasNoMaskedOverloaded = R->getValueAsBit("HasNoMaskedOverloaded"); 1088 bool HasSideEffects = R->getValueAsBit("HasSideEffects"); 1089 std::vector<int64_t> Log2LMULList = R->getValueAsListOfInts("Log2LMUL"); 1090 StringRef ManualCodegen = R->getValueAsString("ManualCodegen"); 1091 StringRef ManualCodegenMask = R->getValueAsString("ManualCodegenMask"); 1092 std::vector<int64_t> IntrinsicTypes = 1093 R->getValueAsListOfInts("IntrinsicTypes"); 1094 StringRef RequiredExtension = R->getValueAsString("RequiredExtension"); 1095 StringRef IRName = R->getValueAsString("IRName"); 1096 StringRef IRNameMask = 
R->getValueAsString("IRNameMask"); 1097 unsigned NF = R->getValueAsInt("NF"); 1098 1099 StringRef HeaderCodeStr = R->getValueAsString("HeaderCode"); 1100 bool HasAutoDef = HeaderCodeStr.empty(); 1101 if (!HeaderCodeStr.empty()) { 1102 HeaderCode += HeaderCodeStr.str(); 1103 } 1104 // Parse prototype and create a list of primitive type with transformers 1105 // (operand) in ProtoSeq. ProtoSeq[0] is output operand. 1106 SmallVector<std::string> ProtoSeq; 1107 parsePrototypes(Prototypes, [&ProtoSeq](StringRef Proto) { 1108 ProtoSeq.push_back(Proto.str()); 1109 }); 1110 1111 // Compute Builtin types 1112 SmallVector<std::string> ProtoMaskSeq = ProtoSeq; 1113 if (HasMask) { 1114 // If HasMaskedOffOperand, insert result type as first input operand. 1115 if (HasMaskedOffOperand) { 1116 if (NF == 1) { 1117 ProtoMaskSeq.insert(ProtoMaskSeq.begin() + 1, ProtoSeq[0]); 1118 } else { 1119 // Convert 1120 // (void, op0 address, op1 address, ...) 1121 // to 1122 // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...) 1123 for (unsigned I = 0; I < NF; ++I) 1124 ProtoMaskSeq.insert( 1125 ProtoMaskSeq.begin() + NF + 1, 1126 ProtoSeq[1].substr(1)); // Use substr(1) to skip '*' 1127 } 1128 } 1129 if (HasMaskedOffOperand && NF > 1) { 1130 // Convert 1131 // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...) 1132 // to 1133 // (void, op0 address, op1 address, ..., mask, maskedoff0, maskedoff1, 1134 // ...) 1135 ProtoMaskSeq.insert(ProtoMaskSeq.begin() + NF + 1, "m"); 1136 } else { 1137 // If HasMask, insert 'm' as first input operand. 1138 ProtoMaskSeq.insert(ProtoMaskSeq.begin() + 1, "m"); 1139 } 1140 } 1141 // If HasVL, append 'z' to last operand 1142 if (HasVL) { 1143 ProtoSeq.push_back("z"); 1144 ProtoMaskSeq.push_back("z"); 1145 } 1146 1147 // Create Intrinsics for each type and LMUL. 
1148 for (char I : TypeRange) { 1149 for (int Log2LMUL : Log2LMULList) { 1150 Optional<RVVTypes> Types = computeTypes(I, Log2LMUL, NF, ProtoSeq); 1151 // Ignored to create new intrinsic if there are any illegal types. 1152 if (!Types.hasValue()) 1153 continue; 1154 1155 auto SuffixStr = getSuffixStr(I, Log2LMUL, SuffixProto); 1156 auto MangledSuffixStr = getSuffixStr(I, Log2LMUL, MangledSuffixProto); 1157 // Create a non-mask intrinsic 1158 Out.push_back(std::make_unique<RVVIntrinsic>( 1159 Name, SuffixStr, MangledName, MangledSuffixStr, IRName, 1160 HasSideEffects, /*IsMask=*/false, /*HasMaskedOffOperand=*/false, 1161 HasVL, HasNoMaskedOverloaded, HasAutoDef, ManualCodegen, 1162 Types.getValue(), IntrinsicTypes, RequiredExtension, NF)); 1163 if (HasMask) { 1164 // Create a mask intrinsic 1165 Optional<RVVTypes> MaskTypes = 1166 computeTypes(I, Log2LMUL, NF, ProtoMaskSeq); 1167 Out.push_back(std::make_unique<RVVIntrinsic>( 1168 Name, SuffixStr, MangledName, MangledSuffixStr, IRNameMask, 1169 HasSideEffects, /*IsMask=*/true, HasMaskedOffOperand, HasVL, 1170 HasNoMaskedOverloaded, HasAutoDef, ManualCodegenMask, 1171 MaskTypes.getValue(), IntrinsicTypes, RequiredExtension, NF)); 1172 } 1173 } // end for Log2LMULList 1174 } // end for TypeRange 1175 } 1176 } 1177 1178 Optional<RVVTypes> 1179 RVVEmitter::computeTypes(BasicType BT, int Log2LMUL, unsigned NF, 1180 ArrayRef<std::string> PrototypeSeq) { 1181 // LMUL x NF must be less than or equal to 8. 
1182 if ((Log2LMUL >= 1) && (1 << Log2LMUL) * NF > 8) 1183 return llvm::None; 1184 1185 RVVTypes Types; 1186 for (const std::string &Proto : PrototypeSeq) { 1187 auto T = computeType(BT, Log2LMUL, Proto); 1188 if (!T.hasValue()) 1189 return llvm::None; 1190 // Record legal type index 1191 Types.push_back(T.getValue()); 1192 } 1193 return Types; 1194 } 1195 1196 Optional<RVVTypePtr> RVVEmitter::computeType(BasicType BT, int Log2LMUL, 1197 StringRef Proto) { 1198 std::string Idx = Twine(Twine(BT) + Twine(Log2LMUL) + Proto).str(); 1199 // Search first 1200 auto It = LegalTypes.find(Idx); 1201 if (It != LegalTypes.end()) 1202 return &(It->second); 1203 if (IllegalTypes.count(Idx)) 1204 return llvm::None; 1205 // Compute type and record the result. 1206 RVVType T(BT, Log2LMUL, Proto); 1207 if (T.isValid()) { 1208 // Record legal type index and value. 1209 LegalTypes.insert({Idx, T}); 1210 return &(LegalTypes[Idx]); 1211 } 1212 // Record illegal type index. 1213 IllegalTypes.insert(Idx); 1214 return llvm::None; 1215 } 1216 1217 void RVVEmitter::emitArchMacroAndBody( 1218 std::vector<std::unique_ptr<RVVIntrinsic>> &Defs, raw_ostream &OS, 1219 std::function<void(raw_ostream &, const RVVIntrinsic &)> PrintBody) { 1220 uint8_t PrevExt = (*Defs.begin())->getRISCVExtensions(); 1221 bool NeedEndif = emitExtDefStr(PrevExt, OS); 1222 for (auto &Def : Defs) { 1223 uint8_t CurExt = Def->getRISCVExtensions(); 1224 if (CurExt != PrevExt) { 1225 if (NeedEndif) 1226 OS << "#endif\n\n"; 1227 NeedEndif = emitExtDefStr(CurExt, OS); 1228 PrevExt = CurExt; 1229 } 1230 if (Def->hasAutoDef()) 1231 PrintBody(OS, *Def); 1232 } 1233 if (NeedEndif) 1234 OS << "#endif\n\n"; 1235 } 1236 1237 bool RVVEmitter::emitExtDefStr(uint8_t Extents, raw_ostream &OS) { 1238 if (Extents == RISCVExtension::Basic) 1239 return false; 1240 OS << "#if "; 1241 ListSeparator LS(" && "); 1242 if (Extents & RISCVExtension::F) 1243 OS << LS << "defined(__riscv_f)"; 1244 if (Extents & RISCVExtension::D) 1245 OS << LS << 
"defined(__riscv_d)"; 1246 if (Extents & RISCVExtension::Zfh) 1247 OS << LS << "defined(__riscv_zfh)"; 1248 if (Extents & RISCVExtension::Zvamo) 1249 OS << LS << "defined(__riscv_zvamo)"; 1250 if (Extents & RISCVExtension::Zvlsseg) 1251 OS << LS << "defined(__riscv_zvlsseg)"; 1252 OS << "\n"; 1253 return true; 1254 } 1255 1256 namespace clang { 1257 void EmitRVVHeader(RecordKeeper &Records, raw_ostream &OS) { 1258 RVVEmitter(Records).createHeader(OS); 1259 } 1260 1261 void EmitRVVBuiltins(RecordKeeper &Records, raw_ostream &OS) { 1262 RVVEmitter(Records).createBuiltins(OS); 1263 } 1264 1265 void EmitRVVBuiltinCG(RecordKeeper &Records, raw_ostream &OS) { 1266 RVVEmitter(Records).createCodeGen(OS); 1267 } 1268 1269 } // End namespace clang 1270