1 //===- RISCVVEmitter.cpp - Generate riscv_vector.h for use with clang -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This tablegen backend is responsible for emitting riscv_vector.h, which
// includes a declaration and definition of each intrinsic function specified
// in https://github.com/riscv/rvv-intrinsic-doc.
12 //
13 // See also the documentation in include/clang/Basic/riscv_vector.td.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "llvm/ADT/ArrayRef.h"
18 #include "llvm/ADT/SmallSet.h"
19 #include "llvm/ADT/StringExtras.h"
20 #include "llvm/ADT/StringMap.h"
21 #include "llvm/ADT/StringSet.h"
22 #include "llvm/ADT/Twine.h"
23 #include "llvm/TableGen/Error.h"
24 #include "llvm/TableGen/Record.h"
25 #include <numeric>
26 
27 using namespace llvm;
// Single-character code selecting the base element type; it is decoded by
// RVVType::applyBasicType() ('c', 's', 'i', 'l', 'x', 'f' or 'd').
using BasicType = char;
// Vector scale of a type. None marks an illegal scale (see
// LMULType::getScale()); RVVType treats a value of 0 as "scalar".
using VScaleVal = Optional<unsigned>;
30 
31 namespace {
32 
33 // Exponential LMUL
34 struct LMULType {
35   int Log2LMUL;
36   LMULType(int Log2LMUL);
37   // Return the C/C++ string representation of LMUL
38   std::string str() const;
39   Optional<unsigned> getScale(unsigned ElementBitwidth) const;
40   void MulLog2LMUL(int Log2LMUL);
41   LMULType &operator*=(uint32_t RHS);
42 };
43 
44 // This class is compact representation of a valid and invalid RVVType.
45 class RVVType {
46   enum ScalarTypeKind : uint32_t {
47     Void,
48     Size_t,
49     Ptrdiff_t,
50     UnsignedLong,
51     SignedLong,
52     Boolean,
53     SignedInteger,
54     UnsignedInteger,
55     Float,
56     Invalid,
57   };
58   BasicType BT;
59   ScalarTypeKind ScalarType = Invalid;
60   LMULType LMUL;
61   bool IsPointer = false;
62   // IsConstant indices are "int", but have the constant expression.
63   bool IsImmediate = false;
64   // Const qualifier for pointer to const object or object of const type.
65   bool IsConstant = false;
66   unsigned ElementBitwidth = 0;
67   VScaleVal Scale = 0;
68   bool Valid;
69 
70   std::string BuiltinStr;
71   std::string ClangBuiltinStr;
72   std::string Str;
73   std::string ShortStr;
74 
75 public:
RVVType()76   RVVType() : RVVType(BasicType(), 0, StringRef()) {}
77   RVVType(BasicType BT, int Log2LMUL, StringRef prototype);
78 
79   // Return the string representation of a type, which is an encoded string for
80   // passing to the BUILTIN() macro in Builtins.def.
getBuiltinStr() const81   const std::string &getBuiltinStr() const { return BuiltinStr; }
82 
83   // Return the clang buitlin type for RVV vector type which are used in the
84   // riscv_vector.h header file.
getClangBuiltinStr() const85   const std::string &getClangBuiltinStr() const { return ClangBuiltinStr; }
86 
87   // Return the C/C++ string representation of a type for use in the
88   // riscv_vector.h header file.
getTypeStr() const89   const std::string &getTypeStr() const { return Str; }
90 
91   // Return the short name of a type for C/C++ name suffix.
getShortStr()92   const std::string &getShortStr() {
93     // Not all types are used in short name, so compute the short name by
94     // demanded.
95     if (ShortStr.empty())
96       initShortStr();
97     return ShortStr;
98   }
99 
isValid() const100   bool isValid() const { return Valid; }
isScalar() const101   bool isScalar() const { return Scale.hasValue() && Scale.getValue() == 0; }
isVector() const102   bool isVector() const { return Scale.hasValue() && Scale.getValue() != 0; }
isFloat() const103   bool isFloat() const { return ScalarType == ScalarTypeKind::Float; }
isSignedInteger() const104   bool isSignedInteger() const {
105     return ScalarType == ScalarTypeKind::SignedInteger;
106   }
isFloatVector(unsigned Width) const107   bool isFloatVector(unsigned Width) const {
108     return isVector() && isFloat() && ElementBitwidth == Width;
109   }
isFloat(unsigned Width) const110   bool isFloat(unsigned Width) const {
111     return isFloat() && ElementBitwidth == Width;
112   }
113 
114 private:
115   // Verify RVV vector type and set Valid.
116   bool verifyType() const;
117 
118   // Creates a type based on basic types of TypeRange
119   void applyBasicType();
120 
121   // Applies a prototype modifier to the current type. The result maybe an
122   // invalid type.
123   void applyModifier(StringRef prototype);
124 
125   // Compute and record a string for legal type.
126   void initBuiltinStr();
127   // Compute and record a builtin RVV vector type string.
128   void initClangBuiltinStr();
129   // Compute and record a type string for used in the header.
130   void initTypeStr();
131   // Compute and record a short name of a type for C/C++ name suffix.
132   void initShortStr();
133 };
134 
// Pointer to an RVVType. NOTE(review): presumably non-owning and pointing at
// an entry of RVVEmitter::LegalTypes -- confirm against computeType().
using RVVTypePtr = RVVType *;
// Sequence of type pointers; the RVVIntrinsic constructor treats element 0 as
// the output type and the remaining elements as the input types.
using RVVTypes = std::vector<RVVTypePtr>;
137 
// Bit flags naming the RISC-V extensions an intrinsic depends on. Flags are
// OR-ed together into a uint8_t mask (see RVVIntrinsic::RISCVExtensions).
// Bit 0 is not assigned to any extension.
enum RISCVExtension : uint8_t {
  Basic = 0x00,
  F = 0x02,       // bit 1
  D = 0x04,       // bit 2
  Zfh = 0x08,     // bit 3
  Zvamo = 0x10,   // bit 4
  Zvlsseg = 0x20, // bit 5
};
146 
// TODO: refactor the RVVIntrinsic class design once all intrinsic
// combinations are supported. This represents an instantiation of an
// intrinsic with a particular type and prototype.
150 class RVVIntrinsic {
151 
152 private:
153   std::string Name; // Builtin name
154   std::string MangledName;
155   std::string IRName;
156   bool HasSideEffects;
157   bool IsMask;
158   bool HasMaskedOffOperand;
159   bool HasVL;
160   bool HasPolicy;
161   bool IsMaskPolicyIntrinsic;
162   bool HasNoMaskedOverloaded;
163   bool HasAutoDef; // There is automiatic definition in header
164   std::string ManualCodegen;
165   RVVTypePtr OutputType; // Builtin output type
166   RVVTypes InputTypes;   // Builtin input types
167   // The types we use to obtain the specific LLVM intrinsic. They are index of
168   // InputTypes. -1 means the return type.
169   std::vector<int64_t> IntrinsicTypes;
170   uint8_t RISCVExtensions = 0;
171   unsigned NF = 1;
172 
173 public:
174   RVVIntrinsic(StringRef Name, StringRef Suffix, StringRef MangledName,
175                StringRef MangledSuffix, StringRef IRName, bool HasSideEffects,
176                bool IsMask, bool HasMaskedOffOperand, bool HasVL,
177                bool HasPolicy, bool IsMaskPolicyIntrinsic,
178                bool HasNoMaskedOverloaded, bool HasAutoDef,
179                StringRef ManualCodegen, const RVVTypes &Types,
180                const std::vector<int64_t> &IntrinsicTypes,
181                StringRef RequiredExtension, unsigned NF);
182   ~RVVIntrinsic() = default;
183 
getName() const184   StringRef getName() const { return Name; }
getMangledName() const185   StringRef getMangledName() const { return MangledName; }
hasSideEffects() const186   bool hasSideEffects() const { return HasSideEffects; }
hasMaskedOffOperand() const187   bool hasMaskedOffOperand() const { return HasMaskedOffOperand; }
hasVL() const188   bool hasVL() const { return HasVL; }
hasPolicy() const189   bool hasPolicy() const { return HasPolicy; }
isMaskPolicyIntrinsic() const190   bool isMaskPolicyIntrinsic() const { return IsMaskPolicyIntrinsic; }
hasNoMaskedOverloaded() const191   bool hasNoMaskedOverloaded() const { return HasNoMaskedOverloaded; }
hasManualCodegen() const192   bool hasManualCodegen() const { return !ManualCodegen.empty(); }
hasAutoDef() const193   bool hasAutoDef() const { return HasAutoDef; }
isMask() const194   bool isMask() const { return IsMask; }
getIRName() const195   StringRef getIRName() const { return IRName; }
getManualCodegen() const196   StringRef getManualCodegen() const { return ManualCodegen; }
getRISCVExtensions() const197   uint8_t getRISCVExtensions() const { return RISCVExtensions; }
getNF() const198   unsigned getNF() const { return NF; }
199 
200   // Return the type string for a BUILTIN() macro in Builtins.def.
201   std::string getBuiltinTypeStr() const;
202 
203   // Emit the code block for switch body in EmitRISCVBuiltinExpr, it should
204   // init the RVVIntrinsic ID and IntrinsicTypes.
205   void emitCodeGenSwitchBody(raw_ostream &o) const;
206 
207   // Emit the macros for mapping C/C++ intrinsic function to builtin functions.
208   void emitIntrinsicMacro(raw_ostream &o) const;
209 
210   // Emit the mangled function definition.
211   void emitMangledFuncDef(raw_ostream &o) const;
212 };
213 
214 class RVVEmitter {
215 private:
216   RecordKeeper &Records;
217   std::string HeaderCode;
218   // Concat BasicType, LMUL and Proto as key
219   StringMap<RVVType> LegalTypes;
220   StringSet<> IllegalTypes;
221 
222 public:
RVVEmitter(RecordKeeper & R)223   RVVEmitter(RecordKeeper &R) : Records(R) {}
224 
225   /// Emit riscv_vector.h
226   void createHeader(raw_ostream &o);
227 
228   /// Emit all the __builtin prototypes and code needed by Sema.
229   void createBuiltins(raw_ostream &o);
230 
231   /// Emit all the information needed to map builtin -> LLVM IR intrinsic.
232   void createCodeGen(raw_ostream &o);
233 
234   std::string getSuffixStr(char Type, int Log2LMUL, StringRef Prototypes);
235 
236 private:
237   /// Create all intrinsics and add them to \p Out
238   void createRVVIntrinsics(std::vector<std::unique_ptr<RVVIntrinsic>> &Out);
239   /// Create Headers and add them to \p Out
240   void createRVVHeaders(raw_ostream &OS);
241   /// Compute output and input types by applying different config (basic type
242   /// and LMUL with type transformers). It also record result of type in legal
243   /// or illegal set to avoid compute the  same config again. The result maybe
244   /// have illegal RVVType.
245   Optional<RVVTypes> computeTypes(BasicType BT, int Log2LMUL, unsigned NF,
246                                   ArrayRef<std::string> PrototypeSeq);
247   Optional<RVVTypePtr> computeType(BasicType BT, int Log2LMUL, StringRef Proto);
248 
249   /// Emit Acrh predecessor definitions and body, assume the element of Defs are
250   /// sorted by extension.
251   void emitArchMacroAndBody(
252       std::vector<std::unique_ptr<RVVIntrinsic>> &Defs, raw_ostream &o,
253       std::function<void(raw_ostream &, const RVVIntrinsic &)>);
254 
255   // Emit the architecture preprocessor definitions. Return true when emits
256   // non-empty string.
257   bool emitExtDefStr(uint8_t Extensions, raw_ostream &o);
258   // Slice Prototypes string into sub prototype string and process each sub
259   // prototype string individually in the Handler.
260   void parsePrototypes(StringRef Prototypes,
261                        std::function<void(StringRef)> Handler);
262 };
263 
264 } // namespace
265 
266 //===----------------------------------------------------------------------===//
267 // Type implementation
268 //===----------------------------------------------------------------------===//
269 
LMULType(int NewLog2LMUL)270 LMULType::LMULType(int NewLog2LMUL) {
271   // Check Log2LMUL is -3, -2, -1, 0, 1, 2, 3
272   assert(NewLog2LMUL <= 3 && NewLog2LMUL >= -3 && "Bad LMUL number!");
273   Log2LMUL = NewLog2LMUL;
274 }
275 
str() const276 std::string LMULType::str() const {
277   if (Log2LMUL < 0)
278     return "mf" + utostr(1ULL << (-Log2LMUL));
279   return "m" + utostr(1ULL << Log2LMUL);
280 }
281 
getScale(unsigned ElementBitwidth) const282 VScaleVal LMULType::getScale(unsigned ElementBitwidth) const {
283   int Log2ScaleResult = 0;
284   switch (ElementBitwidth) {
285   default:
286     break;
287   case 8:
288     Log2ScaleResult = Log2LMUL + 3;
289     break;
290   case 16:
291     Log2ScaleResult = Log2LMUL + 2;
292     break;
293   case 32:
294     Log2ScaleResult = Log2LMUL + 1;
295     break;
296   case 64:
297     Log2ScaleResult = Log2LMUL;
298     break;
299   }
300   // Illegal vscale result would be less than 1
301   if (Log2ScaleResult < 0)
302     return None;
303   return 1 << Log2ScaleResult;
304 }
305 
MulLog2LMUL(int log2LMUL)306 void LMULType::MulLog2LMUL(int log2LMUL) { Log2LMUL += log2LMUL; }
307 
operator *=(uint32_t RHS)308 LMULType &LMULType::operator*=(uint32_t RHS) {
309   assert(isPowerOf2_32(RHS));
310   this->Log2LMUL = this->Log2LMUL + Log2_32(RHS);
311   return *this;
312 }
313 
RVVType(BasicType BT,int Log2LMUL,StringRef prototype)314 RVVType::RVVType(BasicType BT, int Log2LMUL, StringRef prototype)
315     : BT(BT), LMUL(LMULType(Log2LMUL)) {
316   applyBasicType();
317   applyModifier(prototype);
318   Valid = verifyType();
319   if (Valid) {
320     initBuiltinStr();
321     initTypeStr();
322     if (isVector()) {
323       initClangBuiltinStr();
324     }
325   }
326 }
327 
328 // clang-format off
// Boolean types are encoded by the ratio n = SEW/LMUL.
330 // SEW/LMUL | 1         | 2         | 4         | 8        | 16        | 32        | 64
331 // c type   | vbool64_t | vbool32_t | vbool16_t | vbool8_t | vbool4_t  | vbool2_t  | vbool1_t
332 // IR type  | nxv1i1    | nxv2i1    | nxv4i1    | nxv8i1   | nxv16i1   | nxv32i1   | nxv64i1
333 
334 // type\lmul | 1/8    | 1/4      | 1/2     | 1       | 2        | 4        | 8
335 // --------  |------  | -------- | ------- | ------- | -------- | -------- | --------
336 // i64       | N/A    | N/A      | N/A     | nxv1i64 | nxv2i64  | nxv4i64  | nxv8i64
337 // i32       | N/A    | N/A      | nxv1i32 | nxv2i32 | nxv4i32  | nxv8i32  | nxv16i32
338 // i16       | N/A    | nxv1i16  | nxv2i16 | nxv4i16 | nxv8i16  | nxv16i16 | nxv32i16
339 // i8        | nxv1i8 | nxv2i8   | nxv4i8  | nxv8i8  | nxv16i8  | nxv32i8  | nxv64i8
340 // double    | N/A    | N/A      | N/A     | nxv1f64 | nxv2f64  | nxv4f64  | nxv8f64
341 // float     | N/A    | N/A      | nxv1f32 | nxv2f32 | nxv4f32  | nxv8f32  | nxv16f32
342 // half      | N/A    | nxv1f16  | nxv2f16 | nxv4f16 | nxv8f16  | nxv16f16 | nxv32f16
343 // clang-format on
344 
verifyType() const345 bool RVVType::verifyType() const {
346   if (ScalarType == Invalid)
347     return false;
348   if (isScalar())
349     return true;
350   if (!Scale.hasValue())
351     return false;
352   if (isFloat() && ElementBitwidth == 8)
353     return false;
354   unsigned V = Scale.getValue();
355   switch (ElementBitwidth) {
356   case 1:
357   case 8:
358     // Check Scale is 1,2,4,8,16,32,64
359     return (V <= 64 && isPowerOf2_32(V));
360   case 16:
361     // Check Scale is 1,2,4,8,16,32
362     return (V <= 32 && isPowerOf2_32(V));
363   case 32:
364     // Check Scale is 1,2,4,8,16
365     return (V <= 16 && isPowerOf2_32(V));
366   case 64:
367     // Check Scale is 1,2,4,8
368     return (V <= 8 && isPowerOf2_32(V));
369   }
370   return false;
371 }
372 
initBuiltinStr()373 void RVVType::initBuiltinStr() {
374   assert(isValid() && "RVVType is invalid");
375   switch (ScalarType) {
376   case ScalarTypeKind::Void:
377     BuiltinStr = "v";
378     return;
379   case ScalarTypeKind::Size_t:
380     BuiltinStr = "z";
381     if (IsImmediate)
382       BuiltinStr = "I" + BuiltinStr;
383     if (IsPointer)
384       BuiltinStr += "*";
385     return;
386   case ScalarTypeKind::Ptrdiff_t:
387     BuiltinStr = "Y";
388     return;
389   case ScalarTypeKind::UnsignedLong:
390     BuiltinStr = "ULi";
391     return;
392   case ScalarTypeKind::SignedLong:
393     BuiltinStr = "Li";
394     return;
395   case ScalarTypeKind::Boolean:
396     assert(ElementBitwidth == 1);
397     BuiltinStr += "b";
398     break;
399   case ScalarTypeKind::SignedInteger:
400   case ScalarTypeKind::UnsignedInteger:
401     switch (ElementBitwidth) {
402     case 8:
403       BuiltinStr += "c";
404       break;
405     case 16:
406       BuiltinStr += "s";
407       break;
408     case 32:
409       BuiltinStr += "i";
410       break;
411     case 64:
412       BuiltinStr += "Wi";
413       break;
414     default:
415       llvm_unreachable("Unhandled ElementBitwidth!");
416     }
417     if (isSignedInteger())
418       BuiltinStr = "S" + BuiltinStr;
419     else
420       BuiltinStr = "U" + BuiltinStr;
421     break;
422   case ScalarTypeKind::Float:
423     switch (ElementBitwidth) {
424     case 16:
425       BuiltinStr += "x";
426       break;
427     case 32:
428       BuiltinStr += "f";
429       break;
430     case 64:
431       BuiltinStr += "d";
432       break;
433     default:
434       llvm_unreachable("Unhandled ElementBitwidth!");
435     }
436     break;
437   default:
438     llvm_unreachable("ScalarType is invalid!");
439   }
440   if (IsImmediate)
441     BuiltinStr = "I" + BuiltinStr;
442   if (isScalar()) {
443     if (IsConstant)
444       BuiltinStr += "C";
445     if (IsPointer)
446       BuiltinStr += "*";
447     return;
448   }
449   BuiltinStr = "q" + utostr(Scale.getValue()) + BuiltinStr;
450   // Pointer to vector types. Defined for Zvlsseg load intrinsics.
451   // Zvlsseg load intrinsics have pointer type arguments to store the loaded
452   // vector values.
453   if (IsPointer)
454     BuiltinStr += "*";
455 }
456 
initClangBuiltinStr()457 void RVVType::initClangBuiltinStr() {
458   assert(isValid() && "RVVType is invalid");
459   assert(isVector() && "Handle Vector type only");
460 
461   ClangBuiltinStr = "__rvv_";
462   switch (ScalarType) {
463   case ScalarTypeKind::Boolean:
464     ClangBuiltinStr += "bool" + utostr(64 / Scale.getValue()) + "_t";
465     return;
466   case ScalarTypeKind::Float:
467     ClangBuiltinStr += "float";
468     break;
469   case ScalarTypeKind::SignedInteger:
470     ClangBuiltinStr += "int";
471     break;
472   case ScalarTypeKind::UnsignedInteger:
473     ClangBuiltinStr += "uint";
474     break;
475   default:
476     llvm_unreachable("ScalarTypeKind is invalid");
477   }
478   ClangBuiltinStr += utostr(ElementBitwidth) + LMUL.str() + "_t";
479 }
480 
initTypeStr()481 void RVVType::initTypeStr() {
482   assert(isValid() && "RVVType is invalid");
483 
484   if (IsConstant)
485     Str += "const ";
486 
487   auto getTypeString = [&](StringRef TypeStr) {
488     if (isScalar())
489       return Twine(TypeStr + Twine(ElementBitwidth) + "_t").str();
490     return Twine("v" + TypeStr + Twine(ElementBitwidth) + LMUL.str() + "_t")
491         .str();
492   };
493 
494   switch (ScalarType) {
495   case ScalarTypeKind::Void:
496     Str = "void";
497     return;
498   case ScalarTypeKind::Size_t:
499     Str = "size_t";
500     if (IsPointer)
501       Str += " *";
502     return;
503   case ScalarTypeKind::Ptrdiff_t:
504     Str = "ptrdiff_t";
505     return;
506   case ScalarTypeKind::UnsignedLong:
507     Str = "unsigned long";
508     return;
509   case ScalarTypeKind::SignedLong:
510     Str = "long";
511     return;
512   case ScalarTypeKind::Boolean:
513     if (isScalar())
514       Str += "bool";
515     else
516       // Vector bool is special case, the formulate is
517       // `vbool<N>_t = MVT::nxv<64/N>i1` ex. vbool16_t = MVT::4i1
518       Str += "vbool" + utostr(64 / Scale.getValue()) + "_t";
519     break;
520   case ScalarTypeKind::Float:
521     if (isScalar()) {
522       if (ElementBitwidth == 64)
523         Str += "double";
524       else if (ElementBitwidth == 32)
525         Str += "float";
526       else if (ElementBitwidth == 16)
527         Str += "_Float16";
528       else
529         llvm_unreachable("Unhandled floating type.");
530     } else
531       Str += getTypeString("float");
532     break;
533   case ScalarTypeKind::SignedInteger:
534     Str += getTypeString("int");
535     break;
536   case ScalarTypeKind::UnsignedInteger:
537     Str += getTypeString("uint");
538     break;
539   default:
540     llvm_unreachable("ScalarType is invalid!");
541   }
542   if (IsPointer)
543     Str += " *";
544 }
545 
initShortStr()546 void RVVType::initShortStr() {
547   switch (ScalarType) {
548   case ScalarTypeKind::Boolean:
549     assert(isVector());
550     ShortStr = "b" + utostr(64 / Scale.getValue());
551     return;
552   case ScalarTypeKind::Float:
553     ShortStr = "f" + utostr(ElementBitwidth);
554     break;
555   case ScalarTypeKind::SignedInteger:
556     ShortStr = "i" + utostr(ElementBitwidth);
557     break;
558   case ScalarTypeKind::UnsignedInteger:
559     ShortStr = "u" + utostr(ElementBitwidth);
560     break;
561   default:
562     PrintFatalError("Unhandled case!");
563   }
564   if (isVector())
565     ShortStr += LMUL.str();
566 }
567 
applyBasicType()568 void RVVType::applyBasicType() {
569   switch (BT) {
570   case 'c':
571     ElementBitwidth = 8;
572     ScalarType = ScalarTypeKind::SignedInteger;
573     break;
574   case 's':
575     ElementBitwidth = 16;
576     ScalarType = ScalarTypeKind::SignedInteger;
577     break;
578   case 'i':
579     ElementBitwidth = 32;
580     ScalarType = ScalarTypeKind::SignedInteger;
581     break;
582   case 'l':
583     ElementBitwidth = 64;
584     ScalarType = ScalarTypeKind::SignedInteger;
585     break;
586   case 'x':
587     ElementBitwidth = 16;
588     ScalarType = ScalarTypeKind::Float;
589     break;
590   case 'f':
591     ElementBitwidth = 32;
592     ScalarType = ScalarTypeKind::Float;
593     break;
594   case 'd':
595     ElementBitwidth = 64;
596     ScalarType = ScalarTypeKind::Float;
597     break;
598   default:
599     PrintFatalError("Unhandled type code!");
600   }
601   assert(ElementBitwidth != 0 && "Bad element bitwidth!");
602 }
603 
applyModifier(StringRef Transformer)604 void RVVType::applyModifier(StringRef Transformer) {
605   if (Transformer.empty())
606     return;
607   // Handle primitive type transformer
608   auto PType = Transformer.back();
609   switch (PType) {
610   case 'e':
611     Scale = 0;
612     break;
613   case 'v':
614     Scale = LMUL.getScale(ElementBitwidth);
615     break;
616   case 'w':
617     ElementBitwidth *= 2;
618     LMUL *= 2;
619     Scale = LMUL.getScale(ElementBitwidth);
620     break;
621   case 'q':
622     ElementBitwidth *= 4;
623     LMUL *= 4;
624     Scale = LMUL.getScale(ElementBitwidth);
625     break;
626   case 'o':
627     ElementBitwidth *= 8;
628     LMUL *= 8;
629     Scale = LMUL.getScale(ElementBitwidth);
630     break;
631   case 'm':
632     ScalarType = ScalarTypeKind::Boolean;
633     Scale = LMUL.getScale(ElementBitwidth);
634     ElementBitwidth = 1;
635     break;
636   case '0':
637     ScalarType = ScalarTypeKind::Void;
638     break;
639   case 'z':
640     ScalarType = ScalarTypeKind::Size_t;
641     break;
642   case 't':
643     ScalarType = ScalarTypeKind::Ptrdiff_t;
644     break;
645   case 'u':
646     ScalarType = ScalarTypeKind::UnsignedLong;
647     break;
648   case 'l':
649     ScalarType = ScalarTypeKind::SignedLong;
650     break;
651   default:
652     PrintFatalError("Illegal primitive type transformers!");
653   }
654   Transformer = Transformer.drop_back();
655 
656   // Extract and compute complex type transformer. It can only appear one time.
657   if (Transformer.startswith("(")) {
658     size_t Idx = Transformer.find(')');
659     assert(Idx != StringRef::npos);
660     StringRef ComplexType = Transformer.slice(1, Idx);
661     Transformer = Transformer.drop_front(Idx + 1);
662     assert(Transformer.find('(') == StringRef::npos &&
663            "Only allow one complex type transformer");
664 
665     auto UpdateAndCheckComplexProto = [&]() {
666       Scale = LMUL.getScale(ElementBitwidth);
667       const StringRef VectorPrototypes("vwqom");
668       if (!VectorPrototypes.contains(PType))
669         PrintFatalError("Complex type transformer only supports vector type!");
670       if (Transformer.find_first_of("PCKWS") != StringRef::npos)
671         PrintFatalError(
672             "Illegal type transformer for Complex type transformer");
673     };
674     auto ComputeFixedLog2LMUL =
675         [&](StringRef Value,
676             std::function<bool(const int32_t &, const int32_t &)> Compare) {
677           int32_t Log2LMUL;
678           Value.getAsInteger(10, Log2LMUL);
679           if (!Compare(Log2LMUL, LMUL.Log2LMUL)) {
680             ScalarType = Invalid;
681             return false;
682           }
683           // Update new LMUL
684           LMUL = LMULType(Log2LMUL);
685           UpdateAndCheckComplexProto();
686           return true;
687         };
688     auto ComplexTT = ComplexType.split(":");
689     if (ComplexTT.first == "Log2EEW") {
690       uint32_t Log2EEW;
691       ComplexTT.second.getAsInteger(10, Log2EEW);
692       // update new elmul = (eew/sew) * lmul
693       LMUL.MulLog2LMUL(Log2EEW - Log2_32(ElementBitwidth));
694       // update new eew
695       ElementBitwidth = 1 << Log2EEW;
696       ScalarType = ScalarTypeKind::SignedInteger;
697       UpdateAndCheckComplexProto();
698     } else if (ComplexTT.first == "FixedSEW") {
699       uint32_t NewSEW;
700       ComplexTT.second.getAsInteger(10, NewSEW);
701       // Set invalid type if src and dst SEW are same.
702       if (ElementBitwidth == NewSEW) {
703         ScalarType = Invalid;
704         return;
705       }
706       // Update new SEW
707       ElementBitwidth = NewSEW;
708       UpdateAndCheckComplexProto();
709     } else if (ComplexTT.first == "LFixedLog2LMUL") {
710       // New LMUL should be larger than old
711       if (!ComputeFixedLog2LMUL(ComplexTT.second, std::greater<int32_t>()))
712         return;
713     } else if (ComplexTT.first == "SFixedLog2LMUL") {
714       // New LMUL should be smaller than old
715       if (!ComputeFixedLog2LMUL(ComplexTT.second, std::less<int32_t>()))
716         return;
717     } else {
718       PrintFatalError("Illegal complex type transformers!");
719     }
720   }
721 
722   // Compute the remain type transformers
723   for (char I : Transformer) {
724     switch (I) {
725     case 'P':
726       if (IsConstant)
727         PrintFatalError("'P' transformer cannot be used after 'C'");
728       if (IsPointer)
729         PrintFatalError("'P' transformer cannot be used twice");
730       IsPointer = true;
731       break;
732     case 'C':
733       if (IsConstant)
734         PrintFatalError("'C' transformer cannot be used twice");
735       IsConstant = true;
736       break;
737     case 'K':
738       IsImmediate = true;
739       break;
740     case 'U':
741       ScalarType = ScalarTypeKind::UnsignedInteger;
742       break;
743     case 'I':
744       ScalarType = ScalarTypeKind::SignedInteger;
745       break;
746     case 'F':
747       ScalarType = ScalarTypeKind::Float;
748       break;
749     case 'S':
750       LMUL = LMULType(0);
751       // Update ElementBitwidth need to update Scale too.
752       Scale = LMUL.getScale(ElementBitwidth);
753       break;
754     default:
755       PrintFatalError("Illegal non-primitive type transformer!");
756     }
757   }
758 }
759 
760 //===----------------------------------------------------------------------===//
761 // RVVIntrinsic implementation
762 //===----------------------------------------------------------------------===//
RVVIntrinsic(StringRef NewName,StringRef Suffix,StringRef NewMangledName,StringRef MangledSuffix,StringRef IRName,bool HasSideEffects,bool IsMask,bool HasMaskedOffOperand,bool HasVL,bool HasPolicy,bool IsMaskPolicyIntrinsic,bool HasNoMaskedOverloaded,bool HasAutoDef,StringRef ManualCodegen,const RVVTypes & OutInTypes,const std::vector<int64_t> & NewIntrinsicTypes,StringRef RequiredExtension,unsigned NF)763 RVVIntrinsic::RVVIntrinsic(StringRef NewName, StringRef Suffix,
764                            StringRef NewMangledName, StringRef MangledSuffix,
765                            StringRef IRName, bool HasSideEffects, bool IsMask,
766                            bool HasMaskedOffOperand, bool HasVL, bool HasPolicy,
767                            bool IsMaskPolicyIntrinsic,
768                            bool HasNoMaskedOverloaded, bool HasAutoDef,
769                            StringRef ManualCodegen, const RVVTypes &OutInTypes,
770                            const std::vector<int64_t> &NewIntrinsicTypes,
771                            StringRef RequiredExtension, unsigned NF)
772     : IRName(IRName), HasSideEffects(HasSideEffects), IsMask(IsMask),
773       HasMaskedOffOperand(HasMaskedOffOperand), HasVL(HasVL),
774       HasPolicy(HasPolicy), IsMaskPolicyIntrinsic(IsMaskPolicyIntrinsic),
775       HasNoMaskedOverloaded(HasNoMaskedOverloaded), HasAutoDef(HasAutoDef),
776       ManualCodegen(ManualCodegen.str()), NF(NF) {
777 
778   // Init Name and MangledName
779   Name = NewName.str();
780   if (NewMangledName.empty())
781     MangledName = NewName.split("_").first.str();
782   else
783     MangledName = NewMangledName.str();
784   if (!Suffix.empty())
785     Name += "_" + Suffix.str();
786   if (!MangledSuffix.empty())
787     MangledName += "_" + MangledSuffix.str();
788   if (IsMask) {
789     Name += "_m";
790     if (IsMaskPolicyIntrinsic)
791       Name += "t";
792   }
793   // Init RISC-V extensions
794   for (const auto &T : OutInTypes) {
795     if (T->isFloatVector(16) || T->isFloat(16))
796       RISCVExtensions |= RISCVExtension::Zfh;
797     else if (T->isFloatVector(32) || T->isFloat(32))
798       RISCVExtensions |= RISCVExtension::F;
799     else if (T->isFloatVector(64) || T->isFloat(64))
800       RISCVExtensions |= RISCVExtension::D;
801   }
802   if (RequiredExtension == "Zvamo")
803     RISCVExtensions |= RISCVExtension::Zvamo;
804   if (RequiredExtension == "Zvlsseg")
805     RISCVExtensions |= RISCVExtension::Zvlsseg;
806 
807   // Init OutputType and InputTypes
808   OutputType = OutInTypes[0];
809   InputTypes.assign(OutInTypes.begin() + 1, OutInTypes.end());
810 
811   // IntrinsicTypes is nonmasked version index. Need to update it
812   // if there is maskedoff operand (It is always in first operand).
813   IntrinsicTypes = NewIntrinsicTypes;
814   if (IsMask && HasMaskedOffOperand) {
815     for (auto &I : IntrinsicTypes) {
816       if (I >= 0)
817         I += NF;
818     }
819   }
820 }
821 
getBuiltinTypeStr() const822 std::string RVVIntrinsic::getBuiltinTypeStr() const {
823   std::string S;
824   S += OutputType->getBuiltinStr();
825   for (const auto &T : InputTypes) {
826     S += T->getBuiltinStr();
827   }
828   return S;
829 }
830 
emitCodeGenSwitchBody(raw_ostream & OS) const831 void RVVIntrinsic::emitCodeGenSwitchBody(raw_ostream &OS) const {
832   if (!getIRName().empty())
833     OS << "  ID = Intrinsic::riscv_" + getIRName() + ";\n";
834   if (NF >= 2)
835     OS << "  NF = " + utostr(getNF()) + ";\n";
836   if (hasManualCodegen()) {
837     OS << ManualCodegen;
838     OS << "break;\n";
839     return;
840   }
841 
842   if (isMask()) {
843     if (hasVL()) {
844       if (hasPolicy()) {
845         if (isMaskPolicyIntrinsic())
846           OS << "  std::rotate(Ops.begin(), Ops.begin() + 1, Ops.end() - 2);\n";
847         else {
848           OS << "  std::rotate(Ops.begin(), Ops.begin() + 1, Ops.end() - 1);\n";
849           OS << "  Ops.push_back(ConstantInt::get(Ops.back()->getType(), "
850                 "TAIL_AGNOSTIC));\n";
851         }
852       } else
853         OS << "  std::rotate(Ops.begin(), Ops.begin() + 1, Ops.end() - 1);\n";
854     } else {
855       OS << "  std::rotate(Ops.begin(), Ops.begin() + 1, Ops.end());\n";
856     }
857   }
858 
859   OS << "  IntrinsicTypes = {";
860   ListSeparator LS;
861   for (const auto &Idx : IntrinsicTypes) {
862     if (Idx == -1)
863       OS << LS << "ResultType";
864     else
865       OS << LS << "Ops[" << Idx << "]->getType()";
866   }
867 
868   // VL could be i64 or i32, need to encode it in IntrinsicTypes. VL is
869   // always last operand.
870   if (hasVL())
871     OS << ", Ops.back()->getType()";
872   OS << "};\n";
873   OS << "  break;\n";
874 }
875 
emitIntrinsicMacro(raw_ostream & OS) const876 void RVVIntrinsic::emitIntrinsicMacro(raw_ostream &OS) const {
877   OS << "#define " << getName() << "(";
878   if (!InputTypes.empty()) {
879     ListSeparator LS;
880     for (unsigned i = 0, e = InputTypes.size(); i != e; ++i)
881       OS << LS << "op" << i;
882   }
883   OS << ") \\\n";
884   OS << "__builtin_rvv_" << getName() << "(";
885   if (!InputTypes.empty()) {
886     ListSeparator LS;
887     for (unsigned i = 0, e = InputTypes.size(); i != e; ++i)
888       OS << LS << "(" << InputTypes[i]->getTypeStr() << ")(op" << i << ")";
889   }
890   OS << ")\n";
891 }
892 
emitMangledFuncDef(raw_ostream & OS) const893 void RVVIntrinsic::emitMangledFuncDef(raw_ostream &OS) const {
894   OS << "__attribute__((clang_builtin_alias(";
895   OS << "__builtin_rvv_" << getName() << ")))\n";
896   OS << OutputType->getTypeStr() << " " << getMangledName() << "(";
897   // Emit function arguments
898   if (!InputTypes.empty()) {
899     ListSeparator LS;
900     for (unsigned i = 0; i < InputTypes.size(); ++i)
901       OS << LS << InputTypes[i]->getTypeStr() << " op" << i;
902   }
903   OS << ");\n\n";
904 }
905 
906 //===----------------------------------------------------------------------===//
907 // RVVEmitter implementation
908 //===----------------------------------------------------------------------===//
createHeader(raw_ostream & OS)909 void RVVEmitter::createHeader(raw_ostream &OS) {
910 
911   OS << "/*===---- riscv_vector.h - RISC-V V-extension RVVIntrinsics "
912         "-------------------===\n"
913         " *\n"
914         " *\n"
915         " * Part of the LLVM Project, under the Apache License v2.0 with LLVM "
916         "Exceptions.\n"
917         " * See https://llvm.org/LICENSE.txt for license information.\n"
918         " * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n"
919         " *\n"
920         " *===-----------------------------------------------------------------"
921         "------===\n"
922         " */\n\n";
923 
924   OS << "#ifndef __RISCV_VECTOR_H\n";
925   OS << "#define __RISCV_VECTOR_H\n\n";
926 
927   OS << "#include <stdint.h>\n";
928   OS << "#include <stddef.h>\n\n";
929 
930   OS << "#ifndef __riscv_vector\n";
931   OS << "#error \"Vector intrinsics require the vector extension.\"\n";
932   OS << "#endif\n\n";
933 
934   OS << "#ifdef __cplusplus\n";
935   OS << "extern \"C\" {\n";
936   OS << "#endif\n\n";
937 
938   createRVVHeaders(OS);
939 
940   std::vector<std::unique_ptr<RVVIntrinsic>> Defs;
941   createRVVIntrinsics(Defs);
942 
943   // Print header code
944   if (!HeaderCode.empty()) {
945     OS << HeaderCode;
946   }
947 
948   auto printType = [&](auto T) {
949     OS << "typedef " << T->getClangBuiltinStr() << " " << T->getTypeStr()
950        << ";\n";
951   };
952 
953   constexpr int Log2LMULs[] = {-3, -2, -1, 0, 1, 2, 3};
954   // Print RVV boolean types.
955   for (int Log2LMUL : Log2LMULs) {
956     auto T = computeType('c', Log2LMUL, "m");
957     if (T.hasValue())
958       printType(T.getValue());
959   }
960   // Print RVV int/float types.
961   for (char I : StringRef("csil")) {
962     for (int Log2LMUL : Log2LMULs) {
963       auto T = computeType(I, Log2LMUL, "v");
964       if (T.hasValue()) {
965         printType(T.getValue());
966         auto UT = computeType(I, Log2LMUL, "Uv");
967         printType(UT.getValue());
968       }
969     }
970   }
971   OS << "#if defined(__riscv_zfh)\n";
972   for (int Log2LMUL : Log2LMULs) {
973     auto T = computeType('x', Log2LMUL, "v");
974     if (T.hasValue())
975       printType(T.getValue());
976   }
977   OS << "#endif\n";
978 
979   OS << "#if defined(__riscv_f)\n";
980   for (int Log2LMUL : Log2LMULs) {
981     auto T = computeType('f', Log2LMUL, "v");
982     if (T.hasValue())
983       printType(T.getValue());
984   }
985   OS << "#endif\n";
986 
987   OS << "#if defined(__riscv_d)\n";
988   for (int Log2LMUL : Log2LMULs) {
989     auto T = computeType('d', Log2LMUL, "v");
990     if (T.hasValue())
991       printType(T.getValue());
992   }
993   OS << "#endif\n\n";
994 
995   // The same extension include in the same arch guard marco.
996   std::stable_sort(Defs.begin(), Defs.end(),
997                    [](const std::unique_ptr<RVVIntrinsic> &A,
998                       const std::unique_ptr<RVVIntrinsic> &B) {
999                      return A->getRISCVExtensions() < B->getRISCVExtensions();
1000                    });
1001 
1002   // Print intrinsic functions with macro
1003   emitArchMacroAndBody(Defs, OS, [](raw_ostream &OS, const RVVIntrinsic &Inst) {
1004     Inst.emitIntrinsicMacro(OS);
1005   });
1006 
1007   OS << "#define __riscv_v_intrinsic_overloading 1\n";
1008 
1009   // Print Overloaded APIs
1010   OS << "#define __rvv_overloaded static inline "
1011         "__attribute__((__always_inline__, __nodebug__, __overloadable__))\n";
1012 
1013   emitArchMacroAndBody(Defs, OS, [](raw_ostream &OS, const RVVIntrinsic &Inst) {
1014     if (!Inst.isMask() && !Inst.hasNoMaskedOverloaded())
1015       return;
1016     OS << "__rvv_overloaded ";
1017     Inst.emitMangledFuncDef(OS);
1018   });
1019 
1020   OS << "\n#ifdef __cplusplus\n";
1021   OS << "}\n";
1022   OS << "#endif // __cplusplus\n";
1023   OS << "#endif // __RISCV_VECTOR_H\n";
1024 }
1025 
createBuiltins(raw_ostream & OS)1026 void RVVEmitter::createBuiltins(raw_ostream &OS) {
1027   std::vector<std::unique_ptr<RVVIntrinsic>> Defs;
1028   createRVVIntrinsics(Defs);
1029 
1030   OS << "#if defined(TARGET_BUILTIN) && !defined(RISCVV_BUILTIN)\n";
1031   OS << "#define RISCVV_BUILTIN(ID, TYPE, ATTRS) TARGET_BUILTIN(ID, TYPE, "
1032         "ATTRS, \"experimental-v\")\n";
1033   OS << "#endif\n";
1034   for (auto &Def : Defs) {
1035     OS << "RISCVV_BUILTIN(__builtin_rvv_" << Def->getName() << ",\""
1036        << Def->getBuiltinTypeStr() << "\", ";
1037     if (!Def->hasSideEffects())
1038       OS << "\"n\")\n";
1039     else
1040       OS << "\"\")\n";
1041   }
1042   OS << "#undef RISCVV_BUILTIN\n";
1043 }
1044 
createCodeGen(raw_ostream & OS)1045 void RVVEmitter::createCodeGen(raw_ostream &OS) {
1046   std::vector<std::unique_ptr<RVVIntrinsic>> Defs;
1047   createRVVIntrinsics(Defs);
1048   // IR name could be empty, use the stable sort preserves the relative order.
1049   std::stable_sort(Defs.begin(), Defs.end(),
1050                    [](const std::unique_ptr<RVVIntrinsic> &A,
1051                       const std::unique_ptr<RVVIntrinsic> &B) {
1052                      int Cmp = A->getIRName().compare(B->getIRName());
1053                      if (Cmp != 0)
1054                        return Cmp < 0;
1055                      // Some mask intrinsics use the same IRName as unmasked.
1056                      // Sort the unmasked intrinsics first.
1057                      if (A->isMask() != B->isMask())
1058                        return A->isMask() < B->isMask();
1059                      // _m and _mt intrinsics use the same IRName.
1060                      // Sort the _m intrinsics first.
1061                      return A->isMaskPolicyIntrinsic() <
1062                             B->isMaskPolicyIntrinsic();
1063                    });
1064   // Print switch body when the ir name or ManualCodegen changes from previous
1065   // iteration.
1066   RVVIntrinsic *PrevDef = Defs.begin()->get();
1067   for (auto &Def : Defs) {
1068     StringRef CurIRName = Def->getIRName();
1069     if (CurIRName != PrevDef->getIRName() ||
1070         (Def->getManualCodegen() != PrevDef->getManualCodegen()) ||
1071         (Def->isMaskPolicyIntrinsic() != PrevDef->isMaskPolicyIntrinsic())) {
1072       PrevDef->emitCodeGenSwitchBody(OS);
1073     }
1074     PrevDef = Def.get();
1075     OS << "case RISCV::BI__builtin_rvv_" << Def->getName() << ":\n";
1076   }
1077   Defs.back()->emitCodeGenSwitchBody(OS);
1078   OS << "\n";
1079 }
1080 
parsePrototypes(StringRef Prototypes,std::function<void (StringRef)> Handler)1081 void RVVEmitter::parsePrototypes(StringRef Prototypes,
1082                                  std::function<void(StringRef)> Handler) {
1083   const StringRef Primaries("evwqom0ztul");
1084   while (!Prototypes.empty()) {
1085     size_t Idx = 0;
1086     // Skip over complex prototype because it could contain primitive type
1087     // character.
1088     if (Prototypes[0] == '(')
1089       Idx = Prototypes.find_first_of(')');
1090     Idx = Prototypes.find_first_of(Primaries, Idx);
1091     assert(Idx != StringRef::npos);
1092     Handler(Prototypes.slice(0, Idx + 1));
1093     Prototypes = Prototypes.drop_front(Idx + 1);
1094   }
1095 }
1096 
getSuffixStr(char Type,int Log2LMUL,StringRef Prototypes)1097 std::string RVVEmitter::getSuffixStr(char Type, int Log2LMUL,
1098                                      StringRef Prototypes) {
1099   SmallVector<std::string> SuffixStrs;
1100   parsePrototypes(Prototypes, [&](StringRef Proto) {
1101     auto T = computeType(Type, Log2LMUL, Proto);
1102     SuffixStrs.push_back(T.getValue()->getShortStr());
1103   });
1104   return join(SuffixStrs, "_");
1105 }
1106 
createRVVIntrinsics(std::vector<std::unique_ptr<RVVIntrinsic>> & Out)1107 void RVVEmitter::createRVVIntrinsics(
1108     std::vector<std::unique_ptr<RVVIntrinsic>> &Out) {
1109   std::vector<Record *> RV = Records.getAllDerivedDefinitions("RVVBuiltin");
1110   for (auto *R : RV) {
1111     StringRef Name = R->getValueAsString("Name");
1112     StringRef SuffixProto = R->getValueAsString("Suffix");
1113     StringRef MangledName = R->getValueAsString("MangledName");
1114     StringRef MangledSuffixProto = R->getValueAsString("MangledSuffix");
1115     StringRef Prototypes = R->getValueAsString("Prototype");
1116     StringRef TypeRange = R->getValueAsString("TypeRange");
1117     bool HasMask = R->getValueAsBit("HasMask");
1118     bool HasMaskedOffOperand = R->getValueAsBit("HasMaskedOffOperand");
1119     bool HasVL = R->getValueAsBit("HasVL");
1120     bool HasPolicy = R->getValueAsBit("HasPolicy");
1121     bool HasNoMaskedOverloaded = R->getValueAsBit("HasNoMaskedOverloaded");
1122     bool HasSideEffects = R->getValueAsBit("HasSideEffects");
1123     std::vector<int64_t> Log2LMULList = R->getValueAsListOfInts("Log2LMUL");
1124     StringRef ManualCodegen = R->getValueAsString("ManualCodegen");
1125     StringRef ManualCodegenMask = R->getValueAsString("ManualCodegenMask");
1126     std::vector<int64_t> IntrinsicTypes =
1127         R->getValueAsListOfInts("IntrinsicTypes");
1128     StringRef RequiredExtension = R->getValueAsString("RequiredExtension");
1129     StringRef IRName = R->getValueAsString("IRName");
1130     StringRef IRNameMask = R->getValueAsString("IRNameMask");
1131     unsigned NF = R->getValueAsInt("NF");
1132 
1133     StringRef HeaderCodeStr = R->getValueAsString("HeaderCode");
1134     bool HasAutoDef = HeaderCodeStr.empty();
1135     if (!HeaderCodeStr.empty()) {
1136       HeaderCode += HeaderCodeStr.str();
1137     }
1138     // Parse prototype and create a list of primitive type with transformers
1139     // (operand) in ProtoSeq. ProtoSeq[0] is output operand.
1140     SmallVector<std::string> ProtoSeq;
1141     parsePrototypes(Prototypes, [&ProtoSeq](StringRef Proto) {
1142       ProtoSeq.push_back(Proto.str());
1143     });
1144 
1145     // Compute Builtin types
1146     SmallVector<std::string> ProtoMaskSeq = ProtoSeq;
1147     if (HasMask) {
1148       // If HasMaskedOffOperand, insert result type as first input operand.
1149       if (HasMaskedOffOperand) {
1150         if (NF == 1) {
1151           ProtoMaskSeq.insert(ProtoMaskSeq.begin() + 1, ProtoSeq[0]);
1152         } else {
1153           // Convert
1154           // (void, op0 address, op1 address, ...)
1155           // to
1156           // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
1157           for (unsigned I = 0; I < NF; ++I)
1158             ProtoMaskSeq.insert(
1159                 ProtoMaskSeq.begin() + NF + 1,
1160                 ProtoSeq[1].substr(1)); // Use substr(1) to skip '*'
1161         }
1162       }
1163       if (HasMaskedOffOperand && NF > 1) {
1164         // Convert
1165         // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
1166         // to
1167         // (void, op0 address, op1 address, ..., mask, maskedoff0, maskedoff1,
1168         // ...)
1169         ProtoMaskSeq.insert(ProtoMaskSeq.begin() + NF + 1, "m");
1170       } else {
1171         // If HasMask, insert 'm' as first input operand.
1172         ProtoMaskSeq.insert(ProtoMaskSeq.begin() + 1, "m");
1173       }
1174     }
1175     // If HasVL, append 'z' to the operand list.
1176     if (HasVL) {
1177       ProtoSeq.push_back("z");
1178       ProtoMaskSeq.push_back("z");
1179     }
1180 
1181     SmallVector<std::string> ProtoMaskPolicySeq = ProtoMaskSeq;
1182     if (HasPolicy) {
1183       ProtoMaskPolicySeq.push_back("Kz");
1184     }
1185 
1186     // Create Intrinsics for each type and LMUL.
1187     for (char I : TypeRange) {
1188       for (int Log2LMUL : Log2LMULList) {
1189         Optional<RVVTypes> Types = computeTypes(I, Log2LMUL, NF, ProtoSeq);
1190         // Ignored to create new intrinsic if there are any illegal types.
1191         if (!Types.hasValue())
1192           continue;
1193 
1194         auto SuffixStr = getSuffixStr(I, Log2LMUL, SuffixProto);
1195         auto MangledSuffixStr = getSuffixStr(I, Log2LMUL, MangledSuffixProto);
1196         // Create a non-mask intrinsic
1197         Out.push_back(std::make_unique<RVVIntrinsic>(
1198             Name, SuffixStr, MangledName, MangledSuffixStr, IRName,
1199             HasSideEffects, /*IsMask=*/false, /*HasMaskedOffOperand=*/false,
1200             HasVL, HasPolicy, /*IsMaskPolicyIntrinsic*/ false,
1201             HasNoMaskedOverloaded, HasAutoDef, ManualCodegen, Types.getValue(),
1202             IntrinsicTypes, RequiredExtension, NF));
1203         if (HasMask) {
1204           // Create a mask intrinsic
1205           Optional<RVVTypes> MaskTypes =
1206               computeTypes(I, Log2LMUL, NF, ProtoMaskSeq);
1207           Out.push_back(std::make_unique<RVVIntrinsic>(
1208               Name, SuffixStr, MangledName, MangledSuffixStr, IRNameMask,
1209               HasSideEffects, /*IsMask=*/true, HasMaskedOffOperand, HasVL,
1210               HasPolicy, /*IsMaskPolicyIntrinsic=*/false, HasNoMaskedOverloaded,
1211               HasAutoDef, ManualCodegenMask, MaskTypes.getValue(),
1212               IntrinsicTypes, RequiredExtension, NF));
1213           if (HasPolicy) {
1214             MaskTypes = computeTypes(I, Log2LMUL, NF, ProtoMaskPolicySeq);
1215             Out.push_back(std::make_unique<RVVIntrinsic>(
1216                 Name, SuffixStr, MangledName, MangledSuffixStr, IRNameMask,
1217                 HasSideEffects, /*IsMask=*/true, HasMaskedOffOperand, HasVL,
1218                 HasPolicy, /*IsMaskPolicyIntrinsic=*/true,
1219                 HasNoMaskedOverloaded, HasAutoDef, ManualCodegenMask,
1220                 MaskTypes.getValue(), IntrinsicTypes, RequiredExtension, NF));
1221           }
1222         }
1223       } // end for Log2LMULList
1224     }   // end for TypeRange
1225   }
1226 }
1227 
createRVVHeaders(raw_ostream & OS)1228 void RVVEmitter::createRVVHeaders(raw_ostream &OS) {
1229   std::vector<Record *> RVVHeaders =
1230       Records.getAllDerivedDefinitions("RVVHeader");
1231   for (auto *R : RVVHeaders) {
1232     StringRef HeaderCodeStr = R->getValueAsString("HeaderCode");
1233     OS << HeaderCodeStr.str();
1234   }
1235 }
1236 
1237 Optional<RVVTypes>
computeTypes(BasicType BT,int Log2LMUL,unsigned NF,ArrayRef<std::string> PrototypeSeq)1238 RVVEmitter::computeTypes(BasicType BT, int Log2LMUL, unsigned NF,
1239                          ArrayRef<std::string> PrototypeSeq) {
1240   // LMUL x NF must be less than or equal to 8.
1241   if ((Log2LMUL >= 1) && (1 << Log2LMUL) * NF > 8)
1242     return llvm::None;
1243 
1244   RVVTypes Types;
1245   for (const std::string &Proto : PrototypeSeq) {
1246     auto T = computeType(BT, Log2LMUL, Proto);
1247     if (!T.hasValue())
1248       return llvm::None;
1249     // Record legal type index
1250     Types.push_back(T.getValue());
1251   }
1252   return Types;
1253 }
1254 
computeType(BasicType BT,int Log2LMUL,StringRef Proto)1255 Optional<RVVTypePtr> RVVEmitter::computeType(BasicType BT, int Log2LMUL,
1256                                              StringRef Proto) {
1257   std::string Idx = Twine(Twine(BT) + Twine(Log2LMUL) + Proto).str();
1258   // Search first
1259   auto It = LegalTypes.find(Idx);
1260   if (It != LegalTypes.end())
1261     return &(It->second);
1262   if (IllegalTypes.count(Idx))
1263     return llvm::None;
1264   // Compute type and record the result.
1265   RVVType T(BT, Log2LMUL, Proto);
1266   if (T.isValid()) {
1267     // Record legal type index and value.
1268     LegalTypes.insert({Idx, T});
1269     return &(LegalTypes[Idx]);
1270   }
1271   // Record illegal type index.
1272   IllegalTypes.insert(Idx);
1273   return llvm::None;
1274 }
1275 
emitArchMacroAndBody(std::vector<std::unique_ptr<RVVIntrinsic>> & Defs,raw_ostream & OS,std::function<void (raw_ostream &,const RVVIntrinsic &)> PrintBody)1276 void RVVEmitter::emitArchMacroAndBody(
1277     std::vector<std::unique_ptr<RVVIntrinsic>> &Defs, raw_ostream &OS,
1278     std::function<void(raw_ostream &, const RVVIntrinsic &)> PrintBody) {
1279   uint8_t PrevExt = (*Defs.begin())->getRISCVExtensions();
1280   bool NeedEndif = emitExtDefStr(PrevExt, OS);
1281   for (auto &Def : Defs) {
1282     uint8_t CurExt = Def->getRISCVExtensions();
1283     if (CurExt != PrevExt) {
1284       if (NeedEndif)
1285         OS << "#endif\n\n";
1286       NeedEndif = emitExtDefStr(CurExt, OS);
1287       PrevExt = CurExt;
1288     }
1289     if (Def->hasAutoDef())
1290       PrintBody(OS, *Def);
1291   }
1292   if (NeedEndif)
1293     OS << "#endif\n\n";
1294 }
1295 
emitExtDefStr(uint8_t Extents,raw_ostream & OS)1296 bool RVVEmitter::emitExtDefStr(uint8_t Extents, raw_ostream &OS) {
1297   if (Extents == RISCVExtension::Basic)
1298     return false;
1299   OS << "#if ";
1300   ListSeparator LS(" && ");
1301   if (Extents & RISCVExtension::F)
1302     OS << LS << "defined(__riscv_f)";
1303   if (Extents & RISCVExtension::D)
1304     OS << LS << "defined(__riscv_d)";
1305   if (Extents & RISCVExtension::Zfh)
1306     OS << LS << "defined(__riscv_zfh)";
1307   if (Extents & RISCVExtension::Zvamo)
1308     OS << LS << "defined(__riscv_zvamo)";
1309   if (Extents & RISCVExtension::Zvlsseg)
1310     OS << LS << "defined(__riscv_zvlsseg)";
1311   OS << "\n";
1312   return true;
1313 }
1314 
1315 namespace clang {
EmitRVVHeader(RecordKeeper & Records,raw_ostream & OS)1316 void EmitRVVHeader(RecordKeeper &Records, raw_ostream &OS) {
1317   RVVEmitter(Records).createHeader(OS);
1318 }
1319 
EmitRVVBuiltins(RecordKeeper & Records,raw_ostream & OS)1320 void EmitRVVBuiltins(RecordKeeper &Records, raw_ostream &OS) {
1321   RVVEmitter(Records).createBuiltins(OS);
1322 }
1323 
EmitRVVBuiltinCG(RecordKeeper & Records,raw_ostream & OS)1324 void EmitRVVBuiltinCG(RecordKeeper &Records, raw_ostream &OS) {
1325   RVVEmitter(Records).createCodeGen(OS);
1326 }
1327 
1328 } // End namespace clang
1329