1 //===- RISCVVEmitter.cpp - Generate riscv_vector.h for use with clang -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This tablegen backend is responsible for emitting riscv_vector.h which
10 // includes a declaration and definition of each intrinsic functions specified
11 // in https://github.com/riscv/rvv-intrinsic-doc.
12 //
13 // See also the documentation in include/clang/Basic/riscv_vector.td.
14 //
15 //===----------------------------------------------------------------------===//
16
17 #include "llvm/ADT/ArrayRef.h"
18 #include "llvm/ADT/SmallSet.h"
19 #include "llvm/ADT/StringExtras.h"
20 #include "llvm/ADT/StringMap.h"
21 #include "llvm/ADT/StringSet.h"
22 #include "llvm/ADT/Twine.h"
23 #include "llvm/TableGen/Error.h"
24 #include "llvm/TableGen/Record.h"
25 #include <numeric>
26
27 using namespace llvm;
28 using BasicType = char;
29 using VScaleVal = Optional<unsigned>;
30
31 namespace {
32
33 // Exponential LMUL
34 struct LMULType {
35 int Log2LMUL;
36 LMULType(int Log2LMUL);
37 // Return the C/C++ string representation of LMUL
38 std::string str() const;
39 Optional<unsigned> getScale(unsigned ElementBitwidth) const;
40 void MulLog2LMUL(int Log2LMUL);
41 LMULType &operator*=(uint32_t RHS);
42 };
43
44 // This class is compact representation of a valid and invalid RVVType.
45 class RVVType {
46 enum ScalarTypeKind : uint32_t {
47 Void,
48 Size_t,
49 Ptrdiff_t,
50 UnsignedLong,
51 SignedLong,
52 Boolean,
53 SignedInteger,
54 UnsignedInteger,
55 Float,
56 Invalid,
57 };
58 BasicType BT;
59 ScalarTypeKind ScalarType = Invalid;
60 LMULType LMUL;
61 bool IsPointer = false;
62 // IsConstant indices are "int", but have the constant expression.
63 bool IsImmediate = false;
64 // Const qualifier for pointer to const object or object of const type.
65 bool IsConstant = false;
66 unsigned ElementBitwidth = 0;
67 VScaleVal Scale = 0;
68 bool Valid;
69
70 std::string BuiltinStr;
71 std::string ClangBuiltinStr;
72 std::string Str;
73 std::string ShortStr;
74
75 public:
RVVType()76 RVVType() : RVVType(BasicType(), 0, StringRef()) {}
77 RVVType(BasicType BT, int Log2LMUL, StringRef prototype);
78
79 // Return the string representation of a type, which is an encoded string for
80 // passing to the BUILTIN() macro in Builtins.def.
getBuiltinStr() const81 const std::string &getBuiltinStr() const { return BuiltinStr; }
82
83 // Return the clang buitlin type for RVV vector type which are used in the
84 // riscv_vector.h header file.
getClangBuiltinStr() const85 const std::string &getClangBuiltinStr() const { return ClangBuiltinStr; }
86
87 // Return the C/C++ string representation of a type for use in the
88 // riscv_vector.h header file.
getTypeStr() const89 const std::string &getTypeStr() const { return Str; }
90
91 // Return the short name of a type for C/C++ name suffix.
getShortStr()92 const std::string &getShortStr() {
93 // Not all types are used in short name, so compute the short name by
94 // demanded.
95 if (ShortStr.empty())
96 initShortStr();
97 return ShortStr;
98 }
99
isValid() const100 bool isValid() const { return Valid; }
isScalar() const101 bool isScalar() const { return Scale.hasValue() && Scale.getValue() == 0; }
isVector() const102 bool isVector() const { return Scale.hasValue() && Scale.getValue() != 0; }
isFloat() const103 bool isFloat() const { return ScalarType == ScalarTypeKind::Float; }
isSignedInteger() const104 bool isSignedInteger() const {
105 return ScalarType == ScalarTypeKind::SignedInteger;
106 }
isFloatVector(unsigned Width) const107 bool isFloatVector(unsigned Width) const {
108 return isVector() && isFloat() && ElementBitwidth == Width;
109 }
isFloat(unsigned Width) const110 bool isFloat(unsigned Width) const {
111 return isFloat() && ElementBitwidth == Width;
112 }
113
114 private:
115 // Verify RVV vector type and set Valid.
116 bool verifyType() const;
117
118 // Creates a type based on basic types of TypeRange
119 void applyBasicType();
120
121 // Applies a prototype modifier to the current type. The result maybe an
122 // invalid type.
123 void applyModifier(StringRef prototype);
124
125 // Compute and record a string for legal type.
126 void initBuiltinStr();
127 // Compute and record a builtin RVV vector type string.
128 void initClangBuiltinStr();
129 // Compute and record a type string for used in the header.
130 void initTypeStr();
131 // Compute and record a short name of a type for C/C++ name suffix.
132 void initShortStr();
133 };
134
135 using RVVTypePtr = RVVType *;
136 using RVVTypes = std::vector<RVVTypePtr>;
137
138 enum RISCVExtension : uint8_t {
139 Basic = 0,
140 F = 1 << 1,
141 D = 1 << 2,
142 Zfh = 1 << 3,
143 Zvamo = 1 << 4,
144 Zvlsseg = 1 << 5,
145 };
146
147 // TODO refactor RVVIntrinsic class design after support all intrinsic
148 // combination. This represents an instantiation of an intrinsic with a
149 // particular type and prototype
150 class RVVIntrinsic {
151
152 private:
153 std::string Name; // Builtin name
154 std::string MangledName;
155 std::string IRName;
156 bool HasSideEffects;
157 bool IsMask;
158 bool HasMaskedOffOperand;
159 bool HasVL;
160 bool HasPolicy;
161 bool IsMaskPolicyIntrinsic;
162 bool HasNoMaskedOverloaded;
163 bool HasAutoDef; // There is automiatic definition in header
164 std::string ManualCodegen;
165 RVVTypePtr OutputType; // Builtin output type
166 RVVTypes InputTypes; // Builtin input types
167 // The types we use to obtain the specific LLVM intrinsic. They are index of
168 // InputTypes. -1 means the return type.
169 std::vector<int64_t> IntrinsicTypes;
170 uint8_t RISCVExtensions = 0;
171 unsigned NF = 1;
172
173 public:
174 RVVIntrinsic(StringRef Name, StringRef Suffix, StringRef MangledName,
175 StringRef MangledSuffix, StringRef IRName, bool HasSideEffects,
176 bool IsMask, bool HasMaskedOffOperand, bool HasVL,
177 bool HasPolicy, bool IsMaskPolicyIntrinsic,
178 bool HasNoMaskedOverloaded, bool HasAutoDef,
179 StringRef ManualCodegen, const RVVTypes &Types,
180 const std::vector<int64_t> &IntrinsicTypes,
181 StringRef RequiredExtension, unsigned NF);
182 ~RVVIntrinsic() = default;
183
getName() const184 StringRef getName() const { return Name; }
getMangledName() const185 StringRef getMangledName() const { return MangledName; }
hasSideEffects() const186 bool hasSideEffects() const { return HasSideEffects; }
hasMaskedOffOperand() const187 bool hasMaskedOffOperand() const { return HasMaskedOffOperand; }
hasVL() const188 bool hasVL() const { return HasVL; }
hasPolicy() const189 bool hasPolicy() const { return HasPolicy; }
isMaskPolicyIntrinsic() const190 bool isMaskPolicyIntrinsic() const { return IsMaskPolicyIntrinsic; }
hasNoMaskedOverloaded() const191 bool hasNoMaskedOverloaded() const { return HasNoMaskedOverloaded; }
hasManualCodegen() const192 bool hasManualCodegen() const { return !ManualCodegen.empty(); }
hasAutoDef() const193 bool hasAutoDef() const { return HasAutoDef; }
isMask() const194 bool isMask() const { return IsMask; }
getIRName() const195 StringRef getIRName() const { return IRName; }
getManualCodegen() const196 StringRef getManualCodegen() const { return ManualCodegen; }
getRISCVExtensions() const197 uint8_t getRISCVExtensions() const { return RISCVExtensions; }
getNF() const198 unsigned getNF() const { return NF; }
199
200 // Return the type string for a BUILTIN() macro in Builtins.def.
201 std::string getBuiltinTypeStr() const;
202
203 // Emit the code block for switch body in EmitRISCVBuiltinExpr, it should
204 // init the RVVIntrinsic ID and IntrinsicTypes.
205 void emitCodeGenSwitchBody(raw_ostream &o) const;
206
207 // Emit the macros for mapping C/C++ intrinsic function to builtin functions.
208 void emitIntrinsicMacro(raw_ostream &o) const;
209
210 // Emit the mangled function definition.
211 void emitMangledFuncDef(raw_ostream &o) const;
212 };
213
214 class RVVEmitter {
215 private:
216 RecordKeeper &Records;
217 std::string HeaderCode;
218 // Concat BasicType, LMUL and Proto as key
219 StringMap<RVVType> LegalTypes;
220 StringSet<> IllegalTypes;
221
222 public:
RVVEmitter(RecordKeeper & R)223 RVVEmitter(RecordKeeper &R) : Records(R) {}
224
225 /// Emit riscv_vector.h
226 void createHeader(raw_ostream &o);
227
228 /// Emit all the __builtin prototypes and code needed by Sema.
229 void createBuiltins(raw_ostream &o);
230
231 /// Emit all the information needed to map builtin -> LLVM IR intrinsic.
232 void createCodeGen(raw_ostream &o);
233
234 std::string getSuffixStr(char Type, int Log2LMUL, StringRef Prototypes);
235
236 private:
237 /// Create all intrinsics and add them to \p Out
238 void createRVVIntrinsics(std::vector<std::unique_ptr<RVVIntrinsic>> &Out);
239 /// Create Headers and add them to \p Out
240 void createRVVHeaders(raw_ostream &OS);
241 /// Compute output and input types by applying different config (basic type
242 /// and LMUL with type transformers). It also record result of type in legal
243 /// or illegal set to avoid compute the same config again. The result maybe
244 /// have illegal RVVType.
245 Optional<RVVTypes> computeTypes(BasicType BT, int Log2LMUL, unsigned NF,
246 ArrayRef<std::string> PrototypeSeq);
247 Optional<RVVTypePtr> computeType(BasicType BT, int Log2LMUL, StringRef Proto);
248
249 /// Emit Acrh predecessor definitions and body, assume the element of Defs are
250 /// sorted by extension.
251 void emitArchMacroAndBody(
252 std::vector<std::unique_ptr<RVVIntrinsic>> &Defs, raw_ostream &o,
253 std::function<void(raw_ostream &, const RVVIntrinsic &)>);
254
255 // Emit the architecture preprocessor definitions. Return true when emits
256 // non-empty string.
257 bool emitExtDefStr(uint8_t Extensions, raw_ostream &o);
258 // Slice Prototypes string into sub prototype string and process each sub
259 // prototype string individually in the Handler.
260 void parsePrototypes(StringRef Prototypes,
261 std::function<void(StringRef)> Handler);
262 };
263
264 } // namespace
265
266 //===----------------------------------------------------------------------===//
267 // Type implementation
268 //===----------------------------------------------------------------------===//
269
LMULType(int NewLog2LMUL)270 LMULType::LMULType(int NewLog2LMUL) {
271 // Check Log2LMUL is -3, -2, -1, 0, 1, 2, 3
272 assert(NewLog2LMUL <= 3 && NewLog2LMUL >= -3 && "Bad LMUL number!");
273 Log2LMUL = NewLog2LMUL;
274 }
275
str() const276 std::string LMULType::str() const {
277 if (Log2LMUL < 0)
278 return "mf" + utostr(1ULL << (-Log2LMUL));
279 return "m" + utostr(1ULL << Log2LMUL);
280 }
281
getScale(unsigned ElementBitwidth) const282 VScaleVal LMULType::getScale(unsigned ElementBitwidth) const {
283 int Log2ScaleResult = 0;
284 switch (ElementBitwidth) {
285 default:
286 break;
287 case 8:
288 Log2ScaleResult = Log2LMUL + 3;
289 break;
290 case 16:
291 Log2ScaleResult = Log2LMUL + 2;
292 break;
293 case 32:
294 Log2ScaleResult = Log2LMUL + 1;
295 break;
296 case 64:
297 Log2ScaleResult = Log2LMUL;
298 break;
299 }
300 // Illegal vscale result would be less than 1
301 if (Log2ScaleResult < 0)
302 return None;
303 return 1 << Log2ScaleResult;
304 }
305
MulLog2LMUL(int log2LMUL)306 void LMULType::MulLog2LMUL(int log2LMUL) { Log2LMUL += log2LMUL; }
307
operator *=(uint32_t RHS)308 LMULType &LMULType::operator*=(uint32_t RHS) {
309 assert(isPowerOf2_32(RHS));
310 this->Log2LMUL = this->Log2LMUL + Log2_32(RHS);
311 return *this;
312 }
313
RVVType(BasicType BT,int Log2LMUL,StringRef prototype)314 RVVType::RVVType(BasicType BT, int Log2LMUL, StringRef prototype)
315 : BT(BT), LMUL(LMULType(Log2LMUL)) {
316 applyBasicType();
317 applyModifier(prototype);
318 Valid = verifyType();
319 if (Valid) {
320 initBuiltinStr();
321 initTypeStr();
322 if (isVector()) {
323 initClangBuiltinStr();
324 }
325 }
326 }
327
328 // clang-format off
// Boolean types are encoded by the ratio n (SEW/LMUL)
330 // SEW/LMUL | 1 | 2 | 4 | 8 | 16 | 32 | 64
331 // c type | vbool64_t | vbool32_t | vbool16_t | vbool8_t | vbool4_t | vbool2_t | vbool1_t
332 // IR type | nxv1i1 | nxv2i1 | nxv4i1 | nxv8i1 | nxv16i1 | nxv32i1 | nxv64i1
333
334 // type\lmul | 1/8 | 1/4 | 1/2 | 1 | 2 | 4 | 8
335 // -------- |------ | -------- | ------- | ------- | -------- | -------- | --------
336 // i64 | N/A | N/A | N/A | nxv1i64 | nxv2i64 | nxv4i64 | nxv8i64
337 // i32 | N/A | N/A | nxv1i32 | nxv2i32 | nxv4i32 | nxv8i32 | nxv16i32
338 // i16 | N/A | nxv1i16 | nxv2i16 | nxv4i16 | nxv8i16 | nxv16i16 | nxv32i16
339 // i8 | nxv1i8 | nxv2i8 | nxv4i8 | nxv8i8 | nxv16i8 | nxv32i8 | nxv64i8
340 // double | N/A | N/A | N/A | nxv1f64 | nxv2f64 | nxv4f64 | nxv8f64
341 // float | N/A | N/A | nxv1f32 | nxv2f32 | nxv4f32 | nxv8f32 | nxv16f32
342 // half | N/A | nxv1f16 | nxv2f16 | nxv4f16 | nxv8f16 | nxv16f16 | nxv32f16
343 // clang-format on
344
verifyType() const345 bool RVVType::verifyType() const {
346 if (ScalarType == Invalid)
347 return false;
348 if (isScalar())
349 return true;
350 if (!Scale.hasValue())
351 return false;
352 if (isFloat() && ElementBitwidth == 8)
353 return false;
354 unsigned V = Scale.getValue();
355 switch (ElementBitwidth) {
356 case 1:
357 case 8:
358 // Check Scale is 1,2,4,8,16,32,64
359 return (V <= 64 && isPowerOf2_32(V));
360 case 16:
361 // Check Scale is 1,2,4,8,16,32
362 return (V <= 32 && isPowerOf2_32(V));
363 case 32:
364 // Check Scale is 1,2,4,8,16
365 return (V <= 16 && isPowerOf2_32(V));
366 case 64:
367 // Check Scale is 1,2,4,8
368 return (V <= 8 && isPowerOf2_32(V));
369 }
370 return false;
371 }
372
initBuiltinStr()373 void RVVType::initBuiltinStr() {
374 assert(isValid() && "RVVType is invalid");
375 switch (ScalarType) {
376 case ScalarTypeKind::Void:
377 BuiltinStr = "v";
378 return;
379 case ScalarTypeKind::Size_t:
380 BuiltinStr = "z";
381 if (IsImmediate)
382 BuiltinStr = "I" + BuiltinStr;
383 if (IsPointer)
384 BuiltinStr += "*";
385 return;
386 case ScalarTypeKind::Ptrdiff_t:
387 BuiltinStr = "Y";
388 return;
389 case ScalarTypeKind::UnsignedLong:
390 BuiltinStr = "ULi";
391 return;
392 case ScalarTypeKind::SignedLong:
393 BuiltinStr = "Li";
394 return;
395 case ScalarTypeKind::Boolean:
396 assert(ElementBitwidth == 1);
397 BuiltinStr += "b";
398 break;
399 case ScalarTypeKind::SignedInteger:
400 case ScalarTypeKind::UnsignedInteger:
401 switch (ElementBitwidth) {
402 case 8:
403 BuiltinStr += "c";
404 break;
405 case 16:
406 BuiltinStr += "s";
407 break;
408 case 32:
409 BuiltinStr += "i";
410 break;
411 case 64:
412 BuiltinStr += "Wi";
413 break;
414 default:
415 llvm_unreachable("Unhandled ElementBitwidth!");
416 }
417 if (isSignedInteger())
418 BuiltinStr = "S" + BuiltinStr;
419 else
420 BuiltinStr = "U" + BuiltinStr;
421 break;
422 case ScalarTypeKind::Float:
423 switch (ElementBitwidth) {
424 case 16:
425 BuiltinStr += "x";
426 break;
427 case 32:
428 BuiltinStr += "f";
429 break;
430 case 64:
431 BuiltinStr += "d";
432 break;
433 default:
434 llvm_unreachable("Unhandled ElementBitwidth!");
435 }
436 break;
437 default:
438 llvm_unreachable("ScalarType is invalid!");
439 }
440 if (IsImmediate)
441 BuiltinStr = "I" + BuiltinStr;
442 if (isScalar()) {
443 if (IsConstant)
444 BuiltinStr += "C";
445 if (IsPointer)
446 BuiltinStr += "*";
447 return;
448 }
449 BuiltinStr = "q" + utostr(Scale.getValue()) + BuiltinStr;
450 // Pointer to vector types. Defined for Zvlsseg load intrinsics.
451 // Zvlsseg load intrinsics have pointer type arguments to store the loaded
452 // vector values.
453 if (IsPointer)
454 BuiltinStr += "*";
455 }
456
initClangBuiltinStr()457 void RVVType::initClangBuiltinStr() {
458 assert(isValid() && "RVVType is invalid");
459 assert(isVector() && "Handle Vector type only");
460
461 ClangBuiltinStr = "__rvv_";
462 switch (ScalarType) {
463 case ScalarTypeKind::Boolean:
464 ClangBuiltinStr += "bool" + utostr(64 / Scale.getValue()) + "_t";
465 return;
466 case ScalarTypeKind::Float:
467 ClangBuiltinStr += "float";
468 break;
469 case ScalarTypeKind::SignedInteger:
470 ClangBuiltinStr += "int";
471 break;
472 case ScalarTypeKind::UnsignedInteger:
473 ClangBuiltinStr += "uint";
474 break;
475 default:
476 llvm_unreachable("ScalarTypeKind is invalid");
477 }
478 ClangBuiltinStr += utostr(ElementBitwidth) + LMUL.str() + "_t";
479 }
480
initTypeStr()481 void RVVType::initTypeStr() {
482 assert(isValid() && "RVVType is invalid");
483
484 if (IsConstant)
485 Str += "const ";
486
487 auto getTypeString = [&](StringRef TypeStr) {
488 if (isScalar())
489 return Twine(TypeStr + Twine(ElementBitwidth) + "_t").str();
490 return Twine("v" + TypeStr + Twine(ElementBitwidth) + LMUL.str() + "_t")
491 .str();
492 };
493
494 switch (ScalarType) {
495 case ScalarTypeKind::Void:
496 Str = "void";
497 return;
498 case ScalarTypeKind::Size_t:
499 Str = "size_t";
500 if (IsPointer)
501 Str += " *";
502 return;
503 case ScalarTypeKind::Ptrdiff_t:
504 Str = "ptrdiff_t";
505 return;
506 case ScalarTypeKind::UnsignedLong:
507 Str = "unsigned long";
508 return;
509 case ScalarTypeKind::SignedLong:
510 Str = "long";
511 return;
512 case ScalarTypeKind::Boolean:
513 if (isScalar())
514 Str += "bool";
515 else
516 // Vector bool is special case, the formulate is
517 // `vbool<N>_t = MVT::nxv<64/N>i1` ex. vbool16_t = MVT::4i1
518 Str += "vbool" + utostr(64 / Scale.getValue()) + "_t";
519 break;
520 case ScalarTypeKind::Float:
521 if (isScalar()) {
522 if (ElementBitwidth == 64)
523 Str += "double";
524 else if (ElementBitwidth == 32)
525 Str += "float";
526 else if (ElementBitwidth == 16)
527 Str += "_Float16";
528 else
529 llvm_unreachable("Unhandled floating type.");
530 } else
531 Str += getTypeString("float");
532 break;
533 case ScalarTypeKind::SignedInteger:
534 Str += getTypeString("int");
535 break;
536 case ScalarTypeKind::UnsignedInteger:
537 Str += getTypeString("uint");
538 break;
539 default:
540 llvm_unreachable("ScalarType is invalid!");
541 }
542 if (IsPointer)
543 Str += " *";
544 }
545
initShortStr()546 void RVVType::initShortStr() {
547 switch (ScalarType) {
548 case ScalarTypeKind::Boolean:
549 assert(isVector());
550 ShortStr = "b" + utostr(64 / Scale.getValue());
551 return;
552 case ScalarTypeKind::Float:
553 ShortStr = "f" + utostr(ElementBitwidth);
554 break;
555 case ScalarTypeKind::SignedInteger:
556 ShortStr = "i" + utostr(ElementBitwidth);
557 break;
558 case ScalarTypeKind::UnsignedInteger:
559 ShortStr = "u" + utostr(ElementBitwidth);
560 break;
561 default:
562 PrintFatalError("Unhandled case!");
563 }
564 if (isVector())
565 ShortStr += LMUL.str();
566 }
567
applyBasicType()568 void RVVType::applyBasicType() {
569 switch (BT) {
570 case 'c':
571 ElementBitwidth = 8;
572 ScalarType = ScalarTypeKind::SignedInteger;
573 break;
574 case 's':
575 ElementBitwidth = 16;
576 ScalarType = ScalarTypeKind::SignedInteger;
577 break;
578 case 'i':
579 ElementBitwidth = 32;
580 ScalarType = ScalarTypeKind::SignedInteger;
581 break;
582 case 'l':
583 ElementBitwidth = 64;
584 ScalarType = ScalarTypeKind::SignedInteger;
585 break;
586 case 'x':
587 ElementBitwidth = 16;
588 ScalarType = ScalarTypeKind::Float;
589 break;
590 case 'f':
591 ElementBitwidth = 32;
592 ScalarType = ScalarTypeKind::Float;
593 break;
594 case 'd':
595 ElementBitwidth = 64;
596 ScalarType = ScalarTypeKind::Float;
597 break;
598 default:
599 PrintFatalError("Unhandled type code!");
600 }
601 assert(ElementBitwidth != 0 && "Bad element bitwidth!");
602 }
603
applyModifier(StringRef Transformer)604 void RVVType::applyModifier(StringRef Transformer) {
605 if (Transformer.empty())
606 return;
607 // Handle primitive type transformer
608 auto PType = Transformer.back();
609 switch (PType) {
610 case 'e':
611 Scale = 0;
612 break;
613 case 'v':
614 Scale = LMUL.getScale(ElementBitwidth);
615 break;
616 case 'w':
617 ElementBitwidth *= 2;
618 LMUL *= 2;
619 Scale = LMUL.getScale(ElementBitwidth);
620 break;
621 case 'q':
622 ElementBitwidth *= 4;
623 LMUL *= 4;
624 Scale = LMUL.getScale(ElementBitwidth);
625 break;
626 case 'o':
627 ElementBitwidth *= 8;
628 LMUL *= 8;
629 Scale = LMUL.getScale(ElementBitwidth);
630 break;
631 case 'm':
632 ScalarType = ScalarTypeKind::Boolean;
633 Scale = LMUL.getScale(ElementBitwidth);
634 ElementBitwidth = 1;
635 break;
636 case '0':
637 ScalarType = ScalarTypeKind::Void;
638 break;
639 case 'z':
640 ScalarType = ScalarTypeKind::Size_t;
641 break;
642 case 't':
643 ScalarType = ScalarTypeKind::Ptrdiff_t;
644 break;
645 case 'u':
646 ScalarType = ScalarTypeKind::UnsignedLong;
647 break;
648 case 'l':
649 ScalarType = ScalarTypeKind::SignedLong;
650 break;
651 default:
652 PrintFatalError("Illegal primitive type transformers!");
653 }
654 Transformer = Transformer.drop_back();
655
656 // Extract and compute complex type transformer. It can only appear one time.
657 if (Transformer.startswith("(")) {
658 size_t Idx = Transformer.find(')');
659 assert(Idx != StringRef::npos);
660 StringRef ComplexType = Transformer.slice(1, Idx);
661 Transformer = Transformer.drop_front(Idx + 1);
662 assert(Transformer.find('(') == StringRef::npos &&
663 "Only allow one complex type transformer");
664
665 auto UpdateAndCheckComplexProto = [&]() {
666 Scale = LMUL.getScale(ElementBitwidth);
667 const StringRef VectorPrototypes("vwqom");
668 if (!VectorPrototypes.contains(PType))
669 PrintFatalError("Complex type transformer only supports vector type!");
670 if (Transformer.find_first_of("PCKWS") != StringRef::npos)
671 PrintFatalError(
672 "Illegal type transformer for Complex type transformer");
673 };
674 auto ComputeFixedLog2LMUL =
675 [&](StringRef Value,
676 std::function<bool(const int32_t &, const int32_t &)> Compare) {
677 int32_t Log2LMUL;
678 Value.getAsInteger(10, Log2LMUL);
679 if (!Compare(Log2LMUL, LMUL.Log2LMUL)) {
680 ScalarType = Invalid;
681 return false;
682 }
683 // Update new LMUL
684 LMUL = LMULType(Log2LMUL);
685 UpdateAndCheckComplexProto();
686 return true;
687 };
688 auto ComplexTT = ComplexType.split(":");
689 if (ComplexTT.first == "Log2EEW") {
690 uint32_t Log2EEW;
691 ComplexTT.second.getAsInteger(10, Log2EEW);
692 // update new elmul = (eew/sew) * lmul
693 LMUL.MulLog2LMUL(Log2EEW - Log2_32(ElementBitwidth));
694 // update new eew
695 ElementBitwidth = 1 << Log2EEW;
696 ScalarType = ScalarTypeKind::SignedInteger;
697 UpdateAndCheckComplexProto();
698 } else if (ComplexTT.first == "FixedSEW") {
699 uint32_t NewSEW;
700 ComplexTT.second.getAsInteger(10, NewSEW);
701 // Set invalid type if src and dst SEW are same.
702 if (ElementBitwidth == NewSEW) {
703 ScalarType = Invalid;
704 return;
705 }
706 // Update new SEW
707 ElementBitwidth = NewSEW;
708 UpdateAndCheckComplexProto();
709 } else if (ComplexTT.first == "LFixedLog2LMUL") {
710 // New LMUL should be larger than old
711 if (!ComputeFixedLog2LMUL(ComplexTT.second, std::greater<int32_t>()))
712 return;
713 } else if (ComplexTT.first == "SFixedLog2LMUL") {
714 // New LMUL should be smaller than old
715 if (!ComputeFixedLog2LMUL(ComplexTT.second, std::less<int32_t>()))
716 return;
717 } else {
718 PrintFatalError("Illegal complex type transformers!");
719 }
720 }
721
722 // Compute the remain type transformers
723 for (char I : Transformer) {
724 switch (I) {
725 case 'P':
726 if (IsConstant)
727 PrintFatalError("'P' transformer cannot be used after 'C'");
728 if (IsPointer)
729 PrintFatalError("'P' transformer cannot be used twice");
730 IsPointer = true;
731 break;
732 case 'C':
733 if (IsConstant)
734 PrintFatalError("'C' transformer cannot be used twice");
735 IsConstant = true;
736 break;
737 case 'K':
738 IsImmediate = true;
739 break;
740 case 'U':
741 ScalarType = ScalarTypeKind::UnsignedInteger;
742 break;
743 case 'I':
744 ScalarType = ScalarTypeKind::SignedInteger;
745 break;
746 case 'F':
747 ScalarType = ScalarTypeKind::Float;
748 break;
749 case 'S':
750 LMUL = LMULType(0);
751 // Update ElementBitwidth need to update Scale too.
752 Scale = LMUL.getScale(ElementBitwidth);
753 break;
754 default:
755 PrintFatalError("Illegal non-primitive type transformer!");
756 }
757 }
758 }
759
760 //===----------------------------------------------------------------------===//
761 // RVVIntrinsic implementation
762 //===----------------------------------------------------------------------===//
RVVIntrinsic(StringRef NewName,StringRef Suffix,StringRef NewMangledName,StringRef MangledSuffix,StringRef IRName,bool HasSideEffects,bool IsMask,bool HasMaskedOffOperand,bool HasVL,bool HasPolicy,bool IsMaskPolicyIntrinsic,bool HasNoMaskedOverloaded,bool HasAutoDef,StringRef ManualCodegen,const RVVTypes & OutInTypes,const std::vector<int64_t> & NewIntrinsicTypes,StringRef RequiredExtension,unsigned NF)763 RVVIntrinsic::RVVIntrinsic(StringRef NewName, StringRef Suffix,
764 StringRef NewMangledName, StringRef MangledSuffix,
765 StringRef IRName, bool HasSideEffects, bool IsMask,
766 bool HasMaskedOffOperand, bool HasVL, bool HasPolicy,
767 bool IsMaskPolicyIntrinsic,
768 bool HasNoMaskedOverloaded, bool HasAutoDef,
769 StringRef ManualCodegen, const RVVTypes &OutInTypes,
770 const std::vector<int64_t> &NewIntrinsicTypes,
771 StringRef RequiredExtension, unsigned NF)
772 : IRName(IRName), HasSideEffects(HasSideEffects), IsMask(IsMask),
773 HasMaskedOffOperand(HasMaskedOffOperand), HasVL(HasVL),
774 HasPolicy(HasPolicy), IsMaskPolicyIntrinsic(IsMaskPolicyIntrinsic),
775 HasNoMaskedOverloaded(HasNoMaskedOverloaded), HasAutoDef(HasAutoDef),
776 ManualCodegen(ManualCodegen.str()), NF(NF) {
777
778 // Init Name and MangledName
779 Name = NewName.str();
780 if (NewMangledName.empty())
781 MangledName = NewName.split("_").first.str();
782 else
783 MangledName = NewMangledName.str();
784 if (!Suffix.empty())
785 Name += "_" + Suffix.str();
786 if (!MangledSuffix.empty())
787 MangledName += "_" + MangledSuffix.str();
788 if (IsMask) {
789 Name += "_m";
790 if (IsMaskPolicyIntrinsic)
791 Name += "t";
792 }
793 // Init RISC-V extensions
794 for (const auto &T : OutInTypes) {
795 if (T->isFloatVector(16) || T->isFloat(16))
796 RISCVExtensions |= RISCVExtension::Zfh;
797 else if (T->isFloatVector(32) || T->isFloat(32))
798 RISCVExtensions |= RISCVExtension::F;
799 else if (T->isFloatVector(64) || T->isFloat(64))
800 RISCVExtensions |= RISCVExtension::D;
801 }
802 if (RequiredExtension == "Zvamo")
803 RISCVExtensions |= RISCVExtension::Zvamo;
804 if (RequiredExtension == "Zvlsseg")
805 RISCVExtensions |= RISCVExtension::Zvlsseg;
806
807 // Init OutputType and InputTypes
808 OutputType = OutInTypes[0];
809 InputTypes.assign(OutInTypes.begin() + 1, OutInTypes.end());
810
811 // IntrinsicTypes is nonmasked version index. Need to update it
812 // if there is maskedoff operand (It is always in first operand).
813 IntrinsicTypes = NewIntrinsicTypes;
814 if (IsMask && HasMaskedOffOperand) {
815 for (auto &I : IntrinsicTypes) {
816 if (I >= 0)
817 I += NF;
818 }
819 }
820 }
821
getBuiltinTypeStr() const822 std::string RVVIntrinsic::getBuiltinTypeStr() const {
823 std::string S;
824 S += OutputType->getBuiltinStr();
825 for (const auto &T : InputTypes) {
826 S += T->getBuiltinStr();
827 }
828 return S;
829 }
830
emitCodeGenSwitchBody(raw_ostream & OS) const831 void RVVIntrinsic::emitCodeGenSwitchBody(raw_ostream &OS) const {
832 if (!getIRName().empty())
833 OS << " ID = Intrinsic::riscv_" + getIRName() + ";\n";
834 if (NF >= 2)
835 OS << " NF = " + utostr(getNF()) + ";\n";
836 if (hasManualCodegen()) {
837 OS << ManualCodegen;
838 OS << "break;\n";
839 return;
840 }
841
842 if (isMask()) {
843 if (hasVL()) {
844 if (hasPolicy()) {
845 if (isMaskPolicyIntrinsic())
846 OS << " std::rotate(Ops.begin(), Ops.begin() + 1, Ops.end() - 2);\n";
847 else {
848 OS << " std::rotate(Ops.begin(), Ops.begin() + 1, Ops.end() - 1);\n";
849 OS << " Ops.push_back(ConstantInt::get(Ops.back()->getType(), "
850 "TAIL_AGNOSTIC));\n";
851 }
852 } else
853 OS << " std::rotate(Ops.begin(), Ops.begin() + 1, Ops.end() - 1);\n";
854 } else {
855 OS << " std::rotate(Ops.begin(), Ops.begin() + 1, Ops.end());\n";
856 }
857 }
858
859 OS << " IntrinsicTypes = {";
860 ListSeparator LS;
861 for (const auto &Idx : IntrinsicTypes) {
862 if (Idx == -1)
863 OS << LS << "ResultType";
864 else
865 OS << LS << "Ops[" << Idx << "]->getType()";
866 }
867
868 // VL could be i64 or i32, need to encode it in IntrinsicTypes. VL is
869 // always last operand.
870 if (hasVL())
871 OS << ", Ops.back()->getType()";
872 OS << "};\n";
873 OS << " break;\n";
874 }
875
emitIntrinsicMacro(raw_ostream & OS) const876 void RVVIntrinsic::emitIntrinsicMacro(raw_ostream &OS) const {
877 OS << "#define " << getName() << "(";
878 if (!InputTypes.empty()) {
879 ListSeparator LS;
880 for (unsigned i = 0, e = InputTypes.size(); i != e; ++i)
881 OS << LS << "op" << i;
882 }
883 OS << ") \\\n";
884 OS << "__builtin_rvv_" << getName() << "(";
885 if (!InputTypes.empty()) {
886 ListSeparator LS;
887 for (unsigned i = 0, e = InputTypes.size(); i != e; ++i)
888 OS << LS << "(" << InputTypes[i]->getTypeStr() << ")(op" << i << ")";
889 }
890 OS << ")\n";
891 }
892
emitMangledFuncDef(raw_ostream & OS) const893 void RVVIntrinsic::emitMangledFuncDef(raw_ostream &OS) const {
894 OS << "__attribute__((clang_builtin_alias(";
895 OS << "__builtin_rvv_" << getName() << ")))\n";
896 OS << OutputType->getTypeStr() << " " << getMangledName() << "(";
897 // Emit function arguments
898 if (!InputTypes.empty()) {
899 ListSeparator LS;
900 for (unsigned i = 0; i < InputTypes.size(); ++i)
901 OS << LS << InputTypes[i]->getTypeStr() << " op" << i;
902 }
903 OS << ");\n\n";
904 }
905
906 //===----------------------------------------------------------------------===//
907 // RVVEmitter implementation
908 //===----------------------------------------------------------------------===//
createHeader(raw_ostream & OS)909 void RVVEmitter::createHeader(raw_ostream &OS) {
910
911 OS << "/*===---- riscv_vector.h - RISC-V V-extension RVVIntrinsics "
912 "-------------------===\n"
913 " *\n"
914 " *\n"
915 " * Part of the LLVM Project, under the Apache License v2.0 with LLVM "
916 "Exceptions.\n"
917 " * See https://llvm.org/LICENSE.txt for license information.\n"
918 " * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n"
919 " *\n"
920 " *===-----------------------------------------------------------------"
921 "------===\n"
922 " */\n\n";
923
924 OS << "#ifndef __RISCV_VECTOR_H\n";
925 OS << "#define __RISCV_VECTOR_H\n\n";
926
927 OS << "#include <stdint.h>\n";
928 OS << "#include <stddef.h>\n\n";
929
930 OS << "#ifndef __riscv_vector\n";
931 OS << "#error \"Vector intrinsics require the vector extension.\"\n";
932 OS << "#endif\n\n";
933
934 OS << "#ifdef __cplusplus\n";
935 OS << "extern \"C\" {\n";
936 OS << "#endif\n\n";
937
938 createRVVHeaders(OS);
939
940 std::vector<std::unique_ptr<RVVIntrinsic>> Defs;
941 createRVVIntrinsics(Defs);
942
943 // Print header code
944 if (!HeaderCode.empty()) {
945 OS << HeaderCode;
946 }
947
948 auto printType = [&](auto T) {
949 OS << "typedef " << T->getClangBuiltinStr() << " " << T->getTypeStr()
950 << ";\n";
951 };
952
953 constexpr int Log2LMULs[] = {-3, -2, -1, 0, 1, 2, 3};
954 // Print RVV boolean types.
955 for (int Log2LMUL : Log2LMULs) {
956 auto T = computeType('c', Log2LMUL, "m");
957 if (T.hasValue())
958 printType(T.getValue());
959 }
960 // Print RVV int/float types.
961 for (char I : StringRef("csil")) {
962 for (int Log2LMUL : Log2LMULs) {
963 auto T = computeType(I, Log2LMUL, "v");
964 if (T.hasValue()) {
965 printType(T.getValue());
966 auto UT = computeType(I, Log2LMUL, "Uv");
967 printType(UT.getValue());
968 }
969 }
970 }
971 OS << "#if defined(__riscv_zfh)\n";
972 for (int Log2LMUL : Log2LMULs) {
973 auto T = computeType('x', Log2LMUL, "v");
974 if (T.hasValue())
975 printType(T.getValue());
976 }
977 OS << "#endif\n";
978
979 OS << "#if defined(__riscv_f)\n";
980 for (int Log2LMUL : Log2LMULs) {
981 auto T = computeType('f', Log2LMUL, "v");
982 if (T.hasValue())
983 printType(T.getValue());
984 }
985 OS << "#endif\n";
986
987 OS << "#if defined(__riscv_d)\n";
988 for (int Log2LMUL : Log2LMULs) {
989 auto T = computeType('d', Log2LMUL, "v");
990 if (T.hasValue())
991 printType(T.getValue());
992 }
993 OS << "#endif\n\n";
994
995 // The same extension include in the same arch guard marco.
996 std::stable_sort(Defs.begin(), Defs.end(),
997 [](const std::unique_ptr<RVVIntrinsic> &A,
998 const std::unique_ptr<RVVIntrinsic> &B) {
999 return A->getRISCVExtensions() < B->getRISCVExtensions();
1000 });
1001
1002 // Print intrinsic functions with macro
1003 emitArchMacroAndBody(Defs, OS, [](raw_ostream &OS, const RVVIntrinsic &Inst) {
1004 Inst.emitIntrinsicMacro(OS);
1005 });
1006
1007 OS << "#define __riscv_v_intrinsic_overloading 1\n";
1008
1009 // Print Overloaded APIs
1010 OS << "#define __rvv_overloaded static inline "
1011 "__attribute__((__always_inline__, __nodebug__, __overloadable__))\n";
1012
1013 emitArchMacroAndBody(Defs, OS, [](raw_ostream &OS, const RVVIntrinsic &Inst) {
1014 if (!Inst.isMask() && !Inst.hasNoMaskedOverloaded())
1015 return;
1016 OS << "__rvv_overloaded ";
1017 Inst.emitMangledFuncDef(OS);
1018 });
1019
1020 OS << "\n#ifdef __cplusplus\n";
1021 OS << "}\n";
1022 OS << "#endif // __cplusplus\n";
1023 OS << "#endif // __RISCV_VECTOR_H\n";
1024 }
1025
createBuiltins(raw_ostream & OS)1026 void RVVEmitter::createBuiltins(raw_ostream &OS) {
1027 std::vector<std::unique_ptr<RVVIntrinsic>> Defs;
1028 createRVVIntrinsics(Defs);
1029
1030 OS << "#if defined(TARGET_BUILTIN) && !defined(RISCVV_BUILTIN)\n";
1031 OS << "#define RISCVV_BUILTIN(ID, TYPE, ATTRS) TARGET_BUILTIN(ID, TYPE, "
1032 "ATTRS, \"experimental-v\")\n";
1033 OS << "#endif\n";
1034 for (auto &Def : Defs) {
1035 OS << "RISCVV_BUILTIN(__builtin_rvv_" << Def->getName() << ",\""
1036 << Def->getBuiltinTypeStr() << "\", ";
1037 if (!Def->hasSideEffects())
1038 OS << "\"n\")\n";
1039 else
1040 OS << "\"\")\n";
1041 }
1042 OS << "#undef RISCVV_BUILTIN\n";
1043 }
1044
createCodeGen(raw_ostream & OS)1045 void RVVEmitter::createCodeGen(raw_ostream &OS) {
1046 std::vector<std::unique_ptr<RVVIntrinsic>> Defs;
1047 createRVVIntrinsics(Defs);
1048 // IR name could be empty, use the stable sort preserves the relative order.
1049 std::stable_sort(Defs.begin(), Defs.end(),
1050 [](const std::unique_ptr<RVVIntrinsic> &A,
1051 const std::unique_ptr<RVVIntrinsic> &B) {
1052 int Cmp = A->getIRName().compare(B->getIRName());
1053 if (Cmp != 0)
1054 return Cmp < 0;
1055 // Some mask intrinsics use the same IRName as unmasked.
1056 // Sort the unmasked intrinsics first.
1057 if (A->isMask() != B->isMask())
1058 return A->isMask() < B->isMask();
1059 // _m and _mt intrinsics use the same IRName.
1060 // Sort the _m intrinsics first.
1061 return A->isMaskPolicyIntrinsic() <
1062 B->isMaskPolicyIntrinsic();
1063 });
1064 // Print switch body when the ir name or ManualCodegen changes from previous
1065 // iteration.
1066 RVVIntrinsic *PrevDef = Defs.begin()->get();
1067 for (auto &Def : Defs) {
1068 StringRef CurIRName = Def->getIRName();
1069 if (CurIRName != PrevDef->getIRName() ||
1070 (Def->getManualCodegen() != PrevDef->getManualCodegen()) ||
1071 (Def->isMaskPolicyIntrinsic() != PrevDef->isMaskPolicyIntrinsic())) {
1072 PrevDef->emitCodeGenSwitchBody(OS);
1073 }
1074 PrevDef = Def.get();
1075 OS << "case RISCV::BI__builtin_rvv_" << Def->getName() << ":\n";
1076 }
1077 Defs.back()->emitCodeGenSwitchBody(OS);
1078 OS << "\n";
1079 }
1080
parsePrototypes(StringRef Prototypes,std::function<void (StringRef)> Handler)1081 void RVVEmitter::parsePrototypes(StringRef Prototypes,
1082 std::function<void(StringRef)> Handler) {
1083 const StringRef Primaries("evwqom0ztul");
1084 while (!Prototypes.empty()) {
1085 size_t Idx = 0;
1086 // Skip over complex prototype because it could contain primitive type
1087 // character.
1088 if (Prototypes[0] == '(')
1089 Idx = Prototypes.find_first_of(')');
1090 Idx = Prototypes.find_first_of(Primaries, Idx);
1091 assert(Idx != StringRef::npos);
1092 Handler(Prototypes.slice(0, Idx + 1));
1093 Prototypes = Prototypes.drop_front(Idx + 1);
1094 }
1095 }
1096
getSuffixStr(char Type,int Log2LMUL,StringRef Prototypes)1097 std::string RVVEmitter::getSuffixStr(char Type, int Log2LMUL,
1098 StringRef Prototypes) {
1099 SmallVector<std::string> SuffixStrs;
1100 parsePrototypes(Prototypes, [&](StringRef Proto) {
1101 auto T = computeType(Type, Log2LMUL, Proto);
1102 SuffixStrs.push_back(T.getValue()->getShortStr());
1103 });
1104 return join(SuffixStrs, "_");
1105 }
1106
createRVVIntrinsics(std::vector<std::unique_ptr<RVVIntrinsic>> & Out)1107 void RVVEmitter::createRVVIntrinsics(
1108 std::vector<std::unique_ptr<RVVIntrinsic>> &Out) {
1109 std::vector<Record *> RV = Records.getAllDerivedDefinitions("RVVBuiltin");
1110 for (auto *R : RV) {
1111 StringRef Name = R->getValueAsString("Name");
1112 StringRef SuffixProto = R->getValueAsString("Suffix");
1113 StringRef MangledName = R->getValueAsString("MangledName");
1114 StringRef MangledSuffixProto = R->getValueAsString("MangledSuffix");
1115 StringRef Prototypes = R->getValueAsString("Prototype");
1116 StringRef TypeRange = R->getValueAsString("TypeRange");
1117 bool HasMask = R->getValueAsBit("HasMask");
1118 bool HasMaskedOffOperand = R->getValueAsBit("HasMaskedOffOperand");
1119 bool HasVL = R->getValueAsBit("HasVL");
1120 bool HasPolicy = R->getValueAsBit("HasPolicy");
1121 bool HasNoMaskedOverloaded = R->getValueAsBit("HasNoMaskedOverloaded");
1122 bool HasSideEffects = R->getValueAsBit("HasSideEffects");
1123 std::vector<int64_t> Log2LMULList = R->getValueAsListOfInts("Log2LMUL");
1124 StringRef ManualCodegen = R->getValueAsString("ManualCodegen");
1125 StringRef ManualCodegenMask = R->getValueAsString("ManualCodegenMask");
1126 std::vector<int64_t> IntrinsicTypes =
1127 R->getValueAsListOfInts("IntrinsicTypes");
1128 StringRef RequiredExtension = R->getValueAsString("RequiredExtension");
1129 StringRef IRName = R->getValueAsString("IRName");
1130 StringRef IRNameMask = R->getValueAsString("IRNameMask");
1131 unsigned NF = R->getValueAsInt("NF");
1132
1133 StringRef HeaderCodeStr = R->getValueAsString("HeaderCode");
1134 bool HasAutoDef = HeaderCodeStr.empty();
1135 if (!HeaderCodeStr.empty()) {
1136 HeaderCode += HeaderCodeStr.str();
1137 }
1138 // Parse prototype and create a list of primitive type with transformers
1139 // (operand) in ProtoSeq. ProtoSeq[0] is output operand.
1140 SmallVector<std::string> ProtoSeq;
1141 parsePrototypes(Prototypes, [&ProtoSeq](StringRef Proto) {
1142 ProtoSeq.push_back(Proto.str());
1143 });
1144
1145 // Compute Builtin types
1146 SmallVector<std::string> ProtoMaskSeq = ProtoSeq;
1147 if (HasMask) {
1148 // If HasMaskedOffOperand, insert result type as first input operand.
1149 if (HasMaskedOffOperand) {
1150 if (NF == 1) {
1151 ProtoMaskSeq.insert(ProtoMaskSeq.begin() + 1, ProtoSeq[0]);
1152 } else {
1153 // Convert
1154 // (void, op0 address, op1 address, ...)
1155 // to
1156 // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
1157 for (unsigned I = 0; I < NF; ++I)
1158 ProtoMaskSeq.insert(
1159 ProtoMaskSeq.begin() + NF + 1,
1160 ProtoSeq[1].substr(1)); // Use substr(1) to skip '*'
1161 }
1162 }
1163 if (HasMaskedOffOperand && NF > 1) {
1164 // Convert
1165 // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
1166 // to
1167 // (void, op0 address, op1 address, ..., mask, maskedoff0, maskedoff1,
1168 // ...)
1169 ProtoMaskSeq.insert(ProtoMaskSeq.begin() + NF + 1, "m");
1170 } else {
1171 // If HasMask, insert 'm' as first input operand.
1172 ProtoMaskSeq.insert(ProtoMaskSeq.begin() + 1, "m");
1173 }
1174 }
1175 // If HasVL, append 'z' to the operand list.
1176 if (HasVL) {
1177 ProtoSeq.push_back("z");
1178 ProtoMaskSeq.push_back("z");
1179 }
1180
1181 SmallVector<std::string> ProtoMaskPolicySeq = ProtoMaskSeq;
1182 if (HasPolicy) {
1183 ProtoMaskPolicySeq.push_back("Kz");
1184 }
1185
1186 // Create Intrinsics for each type and LMUL.
1187 for (char I : TypeRange) {
1188 for (int Log2LMUL : Log2LMULList) {
1189 Optional<RVVTypes> Types = computeTypes(I, Log2LMUL, NF, ProtoSeq);
1190 // Ignored to create new intrinsic if there are any illegal types.
1191 if (!Types.hasValue())
1192 continue;
1193
1194 auto SuffixStr = getSuffixStr(I, Log2LMUL, SuffixProto);
1195 auto MangledSuffixStr = getSuffixStr(I, Log2LMUL, MangledSuffixProto);
1196 // Create a non-mask intrinsic
1197 Out.push_back(std::make_unique<RVVIntrinsic>(
1198 Name, SuffixStr, MangledName, MangledSuffixStr, IRName,
1199 HasSideEffects, /*IsMask=*/false, /*HasMaskedOffOperand=*/false,
1200 HasVL, HasPolicy, /*IsMaskPolicyIntrinsic*/ false,
1201 HasNoMaskedOverloaded, HasAutoDef, ManualCodegen, Types.getValue(),
1202 IntrinsicTypes, RequiredExtension, NF));
1203 if (HasMask) {
1204 // Create a mask intrinsic
1205 Optional<RVVTypes> MaskTypes =
1206 computeTypes(I, Log2LMUL, NF, ProtoMaskSeq);
1207 Out.push_back(std::make_unique<RVVIntrinsic>(
1208 Name, SuffixStr, MangledName, MangledSuffixStr, IRNameMask,
1209 HasSideEffects, /*IsMask=*/true, HasMaskedOffOperand, HasVL,
1210 HasPolicy, /*IsMaskPolicyIntrinsic=*/false, HasNoMaskedOverloaded,
1211 HasAutoDef, ManualCodegenMask, MaskTypes.getValue(),
1212 IntrinsicTypes, RequiredExtension, NF));
1213 if (HasPolicy) {
1214 MaskTypes = computeTypes(I, Log2LMUL, NF, ProtoMaskPolicySeq);
1215 Out.push_back(std::make_unique<RVVIntrinsic>(
1216 Name, SuffixStr, MangledName, MangledSuffixStr, IRNameMask,
1217 HasSideEffects, /*IsMask=*/true, HasMaskedOffOperand, HasVL,
1218 HasPolicy, /*IsMaskPolicyIntrinsic=*/true,
1219 HasNoMaskedOverloaded, HasAutoDef, ManualCodegenMask,
1220 MaskTypes.getValue(), IntrinsicTypes, RequiredExtension, NF));
1221 }
1222 }
1223 } // end for Log2LMULList
1224 } // end for TypeRange
1225 }
1226 }
1227
createRVVHeaders(raw_ostream & OS)1228 void RVVEmitter::createRVVHeaders(raw_ostream &OS) {
1229 std::vector<Record *> RVVHeaders =
1230 Records.getAllDerivedDefinitions("RVVHeader");
1231 for (auto *R : RVVHeaders) {
1232 StringRef HeaderCodeStr = R->getValueAsString("HeaderCode");
1233 OS << HeaderCodeStr.str();
1234 }
1235 }
1236
1237 Optional<RVVTypes>
computeTypes(BasicType BT,int Log2LMUL,unsigned NF,ArrayRef<std::string> PrototypeSeq)1238 RVVEmitter::computeTypes(BasicType BT, int Log2LMUL, unsigned NF,
1239 ArrayRef<std::string> PrototypeSeq) {
1240 // LMUL x NF must be less than or equal to 8.
1241 if ((Log2LMUL >= 1) && (1 << Log2LMUL) * NF > 8)
1242 return llvm::None;
1243
1244 RVVTypes Types;
1245 for (const std::string &Proto : PrototypeSeq) {
1246 auto T = computeType(BT, Log2LMUL, Proto);
1247 if (!T.hasValue())
1248 return llvm::None;
1249 // Record legal type index
1250 Types.push_back(T.getValue());
1251 }
1252 return Types;
1253 }
1254
computeType(BasicType BT,int Log2LMUL,StringRef Proto)1255 Optional<RVVTypePtr> RVVEmitter::computeType(BasicType BT, int Log2LMUL,
1256 StringRef Proto) {
1257 std::string Idx = Twine(Twine(BT) + Twine(Log2LMUL) + Proto).str();
1258 // Search first
1259 auto It = LegalTypes.find(Idx);
1260 if (It != LegalTypes.end())
1261 return &(It->second);
1262 if (IllegalTypes.count(Idx))
1263 return llvm::None;
1264 // Compute type and record the result.
1265 RVVType T(BT, Log2LMUL, Proto);
1266 if (T.isValid()) {
1267 // Record legal type index and value.
1268 LegalTypes.insert({Idx, T});
1269 return &(LegalTypes[Idx]);
1270 }
1271 // Record illegal type index.
1272 IllegalTypes.insert(Idx);
1273 return llvm::None;
1274 }
1275
emitArchMacroAndBody(std::vector<std::unique_ptr<RVVIntrinsic>> & Defs,raw_ostream & OS,std::function<void (raw_ostream &,const RVVIntrinsic &)> PrintBody)1276 void RVVEmitter::emitArchMacroAndBody(
1277 std::vector<std::unique_ptr<RVVIntrinsic>> &Defs, raw_ostream &OS,
1278 std::function<void(raw_ostream &, const RVVIntrinsic &)> PrintBody) {
1279 uint8_t PrevExt = (*Defs.begin())->getRISCVExtensions();
1280 bool NeedEndif = emitExtDefStr(PrevExt, OS);
1281 for (auto &Def : Defs) {
1282 uint8_t CurExt = Def->getRISCVExtensions();
1283 if (CurExt != PrevExt) {
1284 if (NeedEndif)
1285 OS << "#endif\n\n";
1286 NeedEndif = emitExtDefStr(CurExt, OS);
1287 PrevExt = CurExt;
1288 }
1289 if (Def->hasAutoDef())
1290 PrintBody(OS, *Def);
1291 }
1292 if (NeedEndif)
1293 OS << "#endif\n\n";
1294 }
1295
emitExtDefStr(uint8_t Extents,raw_ostream & OS)1296 bool RVVEmitter::emitExtDefStr(uint8_t Extents, raw_ostream &OS) {
1297 if (Extents == RISCVExtension::Basic)
1298 return false;
1299 OS << "#if ";
1300 ListSeparator LS(" && ");
1301 if (Extents & RISCVExtension::F)
1302 OS << LS << "defined(__riscv_f)";
1303 if (Extents & RISCVExtension::D)
1304 OS << LS << "defined(__riscv_d)";
1305 if (Extents & RISCVExtension::Zfh)
1306 OS << LS << "defined(__riscv_zfh)";
1307 if (Extents & RISCVExtension::Zvamo)
1308 OS << LS << "defined(__riscv_zvamo)";
1309 if (Extents & RISCVExtension::Zvlsseg)
1310 OS << LS << "defined(__riscv_zvlsseg)";
1311 OS << "\n";
1312 return true;
1313 }
1314
1315 namespace clang {
EmitRVVHeader(RecordKeeper & Records,raw_ostream & OS)1316 void EmitRVVHeader(RecordKeeper &Records, raw_ostream &OS) {
1317 RVVEmitter(Records).createHeader(OS);
1318 }
1319
EmitRVVBuiltins(RecordKeeper & Records,raw_ostream & OS)1320 void EmitRVVBuiltins(RecordKeeper &Records, raw_ostream &OS) {
1321 RVVEmitter(Records).createBuiltins(OS);
1322 }
1323
EmitRVVBuiltinCG(RecordKeeper & Records,raw_ostream & OS)1324 void EmitRVVBuiltinCG(RecordKeeper &Records, raw_ostream &OS) {
1325 RVVEmitter(Records).createCodeGen(OS);
1326 }
1327
1328 } // End namespace clang
1329