1 //===- RISCVVEmitter.cpp - Generate riscv_vector.h for use with clang -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This tablegen backend is responsible for emitting riscv_vector.h, which
// includes a declaration and definition of each intrinsic function specified
// in https://github.com/riscv/rvv-intrinsic-doc.
12 //
13 // See also the documentation in include/clang/Basic/riscv_vector.td.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "llvm/ADT/ArrayRef.h"
18 #include "llvm/ADT/SmallSet.h"
19 #include "llvm/ADT/StringExtras.h"
20 #include "llvm/ADT/StringMap.h"
21 #include "llvm/ADT/StringSet.h"
22 #include "llvm/ADT/Twine.h"
23 #include "llvm/TableGen/Error.h"
24 #include "llvm/TableGen/Record.h"
25 #include <numeric>
26 
using namespace llvm;
// Single-character code naming an element type; the accepted codes are the
// cases handled by RVVType::applyBasicType ('c', 's', 'i', 'l', 'x', 'f', 'd').
using BasicType = char;
// Scale value produced by LMULType::getScale. None marks a combination whose
// computed log2 scale would be negative.
using VScaleVal = Optional<unsigned>;
30 
31 namespace {
32 
33 // Exponential LMUL
34 struct LMULType {
35   int Log2LMUL;
36   LMULType(int Log2LMUL);
37   // Return the C/C++ string representation of LMUL
38   std::string str() const;
39   Optional<unsigned> getScale(unsigned ElementBitwidth) const;
40   void MulLog2LMUL(int Log2LMUL);
41   LMULType &operator*=(uint32_t RHS);
42 };
43 
44 // This class is compact representation of a valid and invalid RVVType.
45 class RVVType {
46   enum ScalarTypeKind : uint32_t {
47     Void,
48     Size_t,
49     Ptrdiff_t,
50     UnsignedLong,
51     SignedLong,
52     Boolean,
53     SignedInteger,
54     UnsignedInteger,
55     Float,
56     Invalid,
57   };
58   BasicType BT;
59   ScalarTypeKind ScalarType = Invalid;
60   LMULType LMUL;
61   bool IsPointer = false;
62   // IsConstant indices are "int", but have the constant expression.
63   bool IsImmediate = false;
64   // Const qualifier for pointer to const object or object of const type.
65   bool IsConstant = false;
66   unsigned ElementBitwidth = 0;
67   VScaleVal Scale = 0;
68   bool Valid;
69 
70   std::string BuiltinStr;
71   std::string ClangBuiltinStr;
72   std::string Str;
73   std::string ShortStr;
74 
75 public:
RVVType()76   RVVType() : RVVType(BasicType(), 0, StringRef()) {}
77   RVVType(BasicType BT, int Log2LMUL, StringRef prototype);
78 
79   // Return the string representation of a type, which is an encoded string for
80   // passing to the BUILTIN() macro in Builtins.def.
getBuiltinStr() const81   const std::string &getBuiltinStr() const { return BuiltinStr; }
82 
83   // Return the clang buitlin type for RVV vector type which are used in the
84   // riscv_vector.h header file.
getClangBuiltinStr() const85   const std::string &getClangBuiltinStr() const { return ClangBuiltinStr; }
86 
87   // Return the C/C++ string representation of a type for use in the
88   // riscv_vector.h header file.
getTypeStr() const89   const std::string &getTypeStr() const { return Str; }
90 
91   // Return the short name of a type for C/C++ name suffix.
getShortStr()92   const std::string &getShortStr() {
93     // Not all types are used in short name, so compute the short name by
94     // demanded.
95     if (ShortStr.empty())
96       initShortStr();
97     return ShortStr;
98   }
99 
isValid() const100   bool isValid() const { return Valid; }
isScalar() const101   bool isScalar() const { return Scale.hasValue() && Scale.getValue() == 0; }
isVector() const102   bool isVector() const { return Scale.hasValue() && Scale.getValue() != 0; }
isFloat() const103   bool isFloat() const { return ScalarType == ScalarTypeKind::Float; }
isSignedInteger() const104   bool isSignedInteger() const {
105     return ScalarType == ScalarTypeKind::SignedInteger;
106   }
isFloatVector(unsigned Width) const107   bool isFloatVector(unsigned Width) const {
108     return isVector() && isFloat() && ElementBitwidth == Width;
109   }
isFloat(unsigned Width) const110   bool isFloat(unsigned Width) const {
111     return isFloat() && ElementBitwidth == Width;
112   }
113 
114 private:
115   // Verify RVV vector type and set Valid.
116   bool verifyType() const;
117 
118   // Creates a type based on basic types of TypeRange
119   void applyBasicType();
120 
121   // Applies a prototype modifier to the current type. The result maybe an
122   // invalid type.
123   void applyModifier(StringRef prototype);
124 
125   // Compute and record a string for legal type.
126   void initBuiltinStr();
127   // Compute and record a builtin RVV vector type string.
128   void initClangBuiltinStr();
129   // Compute and record a type string for used in the header.
130   void initTypeStr();
131   // Compute and record a short name of a type for C/C++ name suffix.
132   void initShortStr();
133 };
134 
// Plain pointer to an RVVType.
// NOTE(review): presumably non-owning, pointing at entries of
// RVVEmitter::LegalTypes — confirm against computeType.
using RVVTypePtr = RVVType *;
// Sequence of type pointers; RVVIntrinsic treats element 0 as the output type
// and the remaining elements as the input types.
using RVVTypes = std::vector<RVVTypePtr>;
137 
// Bit flags that RVVIntrinsic ORs together into its RISCVExtensions mask.
// Bit 0 is not assigned to any extension; the first flag starts at bit 1.
enum RISCVExtension : uint8_t {
  Basic = 0x00,
  F = 0x02,
  D = 0x04,
  Zfh = 0x08,
  Zvamo = 0x10,
  Zvlsseg = 0x20,
};
146 
147 // TODO refactor RVVIntrinsic class design after support all intrinsic
148 // combination. This represents an instantiation of an intrinsic with a
149 // particular type and prototype
150 class RVVIntrinsic {
151 
152 private:
153   std::string Name; // Builtin name
154   std::string MangledName;
155   std::string IRName;
156   bool HasSideEffects;
157   bool IsMask;
158   bool HasMaskedOffOperand;
159   bool HasVL;
160   bool HasNoMaskedOverloaded;
161   bool HasAutoDef; // There is automiatic definition in header
162   std::string ManualCodegen;
163   RVVTypePtr OutputType; // Builtin output type
164   RVVTypes InputTypes;   // Builtin input types
165   // The types we use to obtain the specific LLVM intrinsic. They are index of
166   // InputTypes. -1 means the return type.
167   std::vector<int64_t> IntrinsicTypes;
168   uint8_t RISCVExtensions = 0;
169   unsigned NF = 1;
170 
171 public:
172   RVVIntrinsic(StringRef Name, StringRef Suffix, StringRef MangledName,
173                StringRef MangledSuffix, StringRef IRName, bool HasSideEffects,
174                bool IsMask, bool HasMaskedOffOperand, bool HasVL,
175                bool HasNoMaskedOverloaded, bool HasAutoDef,
176                StringRef ManualCodegen, const RVVTypes &Types,
177                const std::vector<int64_t> &IntrinsicTypes,
178                StringRef RequiredExtension, unsigned NF);
179   ~RVVIntrinsic() = default;
180 
getName() const181   StringRef getName() const { return Name; }
getMangledName() const182   StringRef getMangledName() const { return MangledName; }
hasSideEffects() const183   bool hasSideEffects() const { return HasSideEffects; }
hasMaskedOffOperand() const184   bool hasMaskedOffOperand() const { return HasMaskedOffOperand; }
hasVL() const185   bool hasVL() const { return HasVL; }
hasNoMaskedOverloaded() const186   bool hasNoMaskedOverloaded() const { return HasNoMaskedOverloaded; }
hasManualCodegen() const187   bool hasManualCodegen() const { return !ManualCodegen.empty(); }
hasAutoDef() const188   bool hasAutoDef() const { return HasAutoDef; }
isMask() const189   bool isMask() const { return IsMask; }
getIRName() const190   StringRef getIRName() const { return IRName; }
getManualCodegen() const191   StringRef getManualCodegen() const { return ManualCodegen; }
getRISCVExtensions() const192   uint8_t getRISCVExtensions() const { return RISCVExtensions; }
getNF() const193   unsigned getNF() const { return NF; }
194 
195   // Return the type string for a BUILTIN() macro in Builtins.def.
196   std::string getBuiltinTypeStr() const;
197 
198   // Emit the code block for switch body in EmitRISCVBuiltinExpr, it should
199   // init the RVVIntrinsic ID and IntrinsicTypes.
200   void emitCodeGenSwitchBody(raw_ostream &o) const;
201 
202   // Emit the macros for mapping C/C++ intrinsic function to builtin functions.
203   void emitIntrinsicMacro(raw_ostream &o) const;
204 
205   // Emit the mangled function definition.
206   void emitMangledFuncDef(raw_ostream &o) const;
207 };
208 
209 class RVVEmitter {
210 private:
211   RecordKeeper &Records;
212   std::string HeaderCode;
213   // Concat BasicType, LMUL and Proto as key
214   StringMap<RVVType> LegalTypes;
215   StringSet<> IllegalTypes;
216 
217 public:
RVVEmitter(RecordKeeper & R)218   RVVEmitter(RecordKeeper &R) : Records(R) {}
219 
220   /// Emit riscv_vector.h
221   void createHeader(raw_ostream &o);
222 
223   /// Emit all the __builtin prototypes and code needed by Sema.
224   void createBuiltins(raw_ostream &o);
225 
226   /// Emit all the information needed to map builtin -> LLVM IR intrinsic.
227   void createCodeGen(raw_ostream &o);
228 
229   std::string getSuffixStr(char Type, int Log2LMUL, StringRef Prototypes);
230 
231 private:
232   /// Create all intrinsics and add them to \p Out
233   void createRVVIntrinsics(std::vector<std::unique_ptr<RVVIntrinsic>> &Out);
234   /// Compute output and input types by applying different config (basic type
235   /// and LMUL with type transformers). It also record result of type in legal
236   /// or illegal set to avoid compute the  same config again. The result maybe
237   /// have illegal RVVType.
238   Optional<RVVTypes> computeTypes(BasicType BT, int Log2LMUL, unsigned NF,
239                                   ArrayRef<std::string> PrototypeSeq);
240   Optional<RVVTypePtr> computeType(BasicType BT, int Log2LMUL, StringRef Proto);
241 
242   /// Emit Acrh predecessor definitions and body, assume the element of Defs are
243   /// sorted by extension.
244   void emitArchMacroAndBody(
245       std::vector<std::unique_ptr<RVVIntrinsic>> &Defs, raw_ostream &o,
246       std::function<void(raw_ostream &, const RVVIntrinsic &)>);
247 
248   // Emit the architecture preprocessor definitions. Return true when emits
249   // non-empty string.
250   bool emitExtDefStr(uint8_t Extensions, raw_ostream &o);
251   // Slice Prototypes string into sub prototype string and process each sub
252   // prototype string individually in the Handler.
253   void parsePrototypes(StringRef Prototypes,
254                        std::function<void(StringRef)> Handler);
255 };
256 
257 } // namespace
258 
259 //===----------------------------------------------------------------------===//
260 // Type implementation
261 //===----------------------------------------------------------------------===//
262 
LMULType(int NewLog2LMUL)263 LMULType::LMULType(int NewLog2LMUL) {
264   // Check Log2LMUL is -3, -2, -1, 0, 1, 2, 3
265   assert(NewLog2LMUL <= 3 && NewLog2LMUL >= -3 && "Bad LMUL number!");
266   Log2LMUL = NewLog2LMUL;
267 }
268 
str() const269 std::string LMULType::str() const {
270   if (Log2LMUL < 0)
271     return "mf" + utostr(1ULL << (-Log2LMUL));
272   return "m" + utostr(1ULL << Log2LMUL);
273 }
274 
getScale(unsigned ElementBitwidth) const275 VScaleVal LMULType::getScale(unsigned ElementBitwidth) const {
276   int Log2ScaleResult = 0;
277   switch (ElementBitwidth) {
278   default:
279     break;
280   case 8:
281     Log2ScaleResult = Log2LMUL + 3;
282     break;
283   case 16:
284     Log2ScaleResult = Log2LMUL + 2;
285     break;
286   case 32:
287     Log2ScaleResult = Log2LMUL + 1;
288     break;
289   case 64:
290     Log2ScaleResult = Log2LMUL;
291     break;
292   }
293   // Illegal vscale result would be less than 1
294   if (Log2ScaleResult < 0)
295     return None;
296   return 1 << Log2ScaleResult;
297 }
298 
MulLog2LMUL(int log2LMUL)299 void LMULType::MulLog2LMUL(int log2LMUL) { Log2LMUL += log2LMUL; }
300 
operator *=(uint32_t RHS)301 LMULType &LMULType::operator*=(uint32_t RHS) {
302   assert(isPowerOf2_32(RHS));
303   this->Log2LMUL = this->Log2LMUL + Log2_32(RHS);
304   return *this;
305 }
306 
RVVType(BasicType BT,int Log2LMUL,StringRef prototype)307 RVVType::RVVType(BasicType BT, int Log2LMUL, StringRef prototype)
308     : BT(BT), LMUL(LMULType(Log2LMUL)) {
309   applyBasicType();
310   applyModifier(prototype);
311   Valid = verifyType();
312   if (Valid) {
313     initBuiltinStr();
314     initTypeStr();
315     if (isVector()) {
316       initClangBuiltinStr();
317     }
318   }
319 }
320 
321 // clang-format off
// Boolean types are encoded by the ratio n = SEW/LMUL.
323 // SEW/LMUL | 1         | 2         | 4         | 8        | 16        | 32        | 64
324 // c type   | vbool64_t | vbool32_t | vbool16_t | vbool8_t | vbool4_t  | vbool2_t  | vbool1_t
325 // IR type  | nxv1i1    | nxv2i1    | nxv4i1    | nxv8i1   | nxv16i1   | nxv32i1   | nxv64i1
326 
327 // type\lmul | 1/8    | 1/4      | 1/2     | 1       | 2        | 4        | 8
328 // --------  |------  | -------- | ------- | ------- | -------- | -------- | --------
329 // i64       | N/A    | N/A      | N/A     | nxv1i64 | nxv2i64  | nxv4i64  | nxv8i64
330 // i32       | N/A    | N/A      | nxv1i32 | nxv2i32 | nxv4i32  | nxv8i32  | nxv16i32
331 // i16       | N/A    | nxv1i16  | nxv2i16 | nxv4i16 | nxv8i16  | nxv16i16 | nxv32i16
332 // i8        | nxv1i8 | nxv2i8   | nxv4i8  | nxv8i8  | nxv16i8  | nxv32i8  | nxv64i8
333 // double    | N/A    | N/A      | N/A     | nxv1f64 | nxv2f64  | nxv4f64  | nxv8f64
334 // float     | N/A    | N/A      | nxv1f32 | nxv2f32 | nxv4f32  | nxv8f32  | nxv16f32
335 // half      | N/A    | nxv1f16  | nxv2f16 | nxv4f16 | nxv8f16  | nxv16f16 | nxv32f16
336 // clang-format on
337 
verifyType() const338 bool RVVType::verifyType() const {
339   if (ScalarType == Invalid)
340     return false;
341   if (isScalar())
342     return true;
343   if (!Scale.hasValue())
344     return false;
345   if (isFloat() && ElementBitwidth == 8)
346     return false;
347   unsigned V = Scale.getValue();
348   switch (ElementBitwidth) {
349   case 1:
350   case 8:
351     // Check Scale is 1,2,4,8,16,32,64
352     return (V <= 64 && isPowerOf2_32(V));
353   case 16:
354     // Check Scale is 1,2,4,8,16,32
355     return (V <= 32 && isPowerOf2_32(V));
356   case 32:
357     // Check Scale is 1,2,4,8,16
358     return (V <= 16 && isPowerOf2_32(V));
359   case 64:
360     // Check Scale is 1,2,4,8
361     return (V <= 8 && isPowerOf2_32(V));
362   }
363   return false;
364 }
365 
initBuiltinStr()366 void RVVType::initBuiltinStr() {
367   assert(isValid() && "RVVType is invalid");
368   switch (ScalarType) {
369   case ScalarTypeKind::Void:
370     BuiltinStr = "v";
371     return;
372   case ScalarTypeKind::Size_t:
373     BuiltinStr = "z";
374     if (IsImmediate)
375       BuiltinStr = "I" + BuiltinStr;
376     if (IsPointer)
377       BuiltinStr += "*";
378     return;
379   case ScalarTypeKind::Ptrdiff_t:
380     BuiltinStr = "Y";
381     return;
382   case ScalarTypeKind::UnsignedLong:
383     BuiltinStr = "ULi";
384     return;
385   case ScalarTypeKind::SignedLong:
386     BuiltinStr = "Li";
387     return;
388   case ScalarTypeKind::Boolean:
389     assert(ElementBitwidth == 1);
390     BuiltinStr += "b";
391     break;
392   case ScalarTypeKind::SignedInteger:
393   case ScalarTypeKind::UnsignedInteger:
394     switch (ElementBitwidth) {
395     case 8:
396       BuiltinStr += "c";
397       break;
398     case 16:
399       BuiltinStr += "s";
400       break;
401     case 32:
402       BuiltinStr += "i";
403       break;
404     case 64:
405       BuiltinStr += "Wi";
406       break;
407     default:
408       llvm_unreachable("Unhandled ElementBitwidth!");
409     }
410     if (isSignedInteger())
411       BuiltinStr = "S" + BuiltinStr;
412     else
413       BuiltinStr = "U" + BuiltinStr;
414     break;
415   case ScalarTypeKind::Float:
416     switch (ElementBitwidth) {
417     case 16:
418       BuiltinStr += "x";
419       break;
420     case 32:
421       BuiltinStr += "f";
422       break;
423     case 64:
424       BuiltinStr += "d";
425       break;
426     default:
427       llvm_unreachable("Unhandled ElementBitwidth!");
428     }
429     break;
430   default:
431     llvm_unreachable("ScalarType is invalid!");
432   }
433   if (IsImmediate)
434     BuiltinStr = "I" + BuiltinStr;
435   if (isScalar()) {
436     if (IsConstant)
437       BuiltinStr += "C";
438     if (IsPointer)
439       BuiltinStr += "*";
440     return;
441   }
442   BuiltinStr = "q" + utostr(Scale.getValue()) + BuiltinStr;
443   // Pointer to vector types. Defined for Zvlsseg load intrinsics.
444   // Zvlsseg load intrinsics have pointer type arguments to store the loaded
445   // vector values.
446   if (IsPointer)
447     BuiltinStr += "*";
448 }
449 
initClangBuiltinStr()450 void RVVType::initClangBuiltinStr() {
451   assert(isValid() && "RVVType is invalid");
452   assert(isVector() && "Handle Vector type only");
453 
454   ClangBuiltinStr = "__rvv_";
455   switch (ScalarType) {
456   case ScalarTypeKind::Boolean:
457     ClangBuiltinStr += "bool" + utostr(64 / Scale.getValue()) + "_t";
458     return;
459   case ScalarTypeKind::Float:
460     ClangBuiltinStr += "float";
461     break;
462   case ScalarTypeKind::SignedInteger:
463     ClangBuiltinStr += "int";
464     break;
465   case ScalarTypeKind::UnsignedInteger:
466     ClangBuiltinStr += "uint";
467     break;
468   default:
469     llvm_unreachable("ScalarTypeKind is invalid");
470   }
471   ClangBuiltinStr += utostr(ElementBitwidth) + LMUL.str() + "_t";
472 }
473 
initTypeStr()474 void RVVType::initTypeStr() {
475   assert(isValid() && "RVVType is invalid");
476 
477   if (IsConstant)
478     Str += "const ";
479 
480   auto getTypeString = [&](StringRef TypeStr) {
481     if (isScalar())
482       return Twine(TypeStr + Twine(ElementBitwidth) + "_t").str();
483     return Twine("v" + TypeStr + Twine(ElementBitwidth) + LMUL.str() + "_t")
484         .str();
485   };
486 
487   switch (ScalarType) {
488   case ScalarTypeKind::Void:
489     Str = "void";
490     return;
491   case ScalarTypeKind::Size_t:
492     Str = "size_t";
493     if (IsPointer)
494       Str += " *";
495     return;
496   case ScalarTypeKind::Ptrdiff_t:
497     Str = "ptrdiff_t";
498     return;
499   case ScalarTypeKind::UnsignedLong:
500     Str = "unsigned long";
501     return;
502   case ScalarTypeKind::SignedLong:
503     Str = "long";
504     return;
505   case ScalarTypeKind::Boolean:
506     if (isScalar())
507       Str += "bool";
508     else
509       // Vector bool is special case, the formulate is
510       // `vbool<N>_t = MVT::nxv<64/N>i1` ex. vbool16_t = MVT::4i1
511       Str += "vbool" + utostr(64 / Scale.getValue()) + "_t";
512     break;
513   case ScalarTypeKind::Float:
514     if (isScalar()) {
515       if (ElementBitwidth == 64)
516         Str += "double";
517       else if (ElementBitwidth == 32)
518         Str += "float";
519       else if (ElementBitwidth == 16)
520         Str += "_Float16";
521       else
522         llvm_unreachable("Unhandled floating type.");
523     } else
524       Str += getTypeString("float");
525     break;
526   case ScalarTypeKind::SignedInteger:
527     Str += getTypeString("int");
528     break;
529   case ScalarTypeKind::UnsignedInteger:
530     Str += getTypeString("uint");
531     break;
532   default:
533     llvm_unreachable("ScalarType is invalid!");
534   }
535   if (IsPointer)
536     Str += " *";
537 }
538 
initShortStr()539 void RVVType::initShortStr() {
540   switch (ScalarType) {
541   case ScalarTypeKind::Boolean:
542     assert(isVector());
543     ShortStr = "b" + utostr(64 / Scale.getValue());
544     return;
545   case ScalarTypeKind::Float:
546     ShortStr = "f" + utostr(ElementBitwidth);
547     break;
548   case ScalarTypeKind::SignedInteger:
549     ShortStr = "i" + utostr(ElementBitwidth);
550     break;
551   case ScalarTypeKind::UnsignedInteger:
552     ShortStr = "u" + utostr(ElementBitwidth);
553     break;
554   default:
555     PrintFatalError("Unhandled case!");
556   }
557   if (isVector())
558     ShortStr += LMUL.str();
559 }
560 
applyBasicType()561 void RVVType::applyBasicType() {
562   switch (BT) {
563   case 'c':
564     ElementBitwidth = 8;
565     ScalarType = ScalarTypeKind::SignedInteger;
566     break;
567   case 's':
568     ElementBitwidth = 16;
569     ScalarType = ScalarTypeKind::SignedInteger;
570     break;
571   case 'i':
572     ElementBitwidth = 32;
573     ScalarType = ScalarTypeKind::SignedInteger;
574     break;
575   case 'l':
576     ElementBitwidth = 64;
577     ScalarType = ScalarTypeKind::SignedInteger;
578     break;
579   case 'x':
580     ElementBitwidth = 16;
581     ScalarType = ScalarTypeKind::Float;
582     break;
583   case 'f':
584     ElementBitwidth = 32;
585     ScalarType = ScalarTypeKind::Float;
586     break;
587   case 'd':
588     ElementBitwidth = 64;
589     ScalarType = ScalarTypeKind::Float;
590     break;
591   default:
592     PrintFatalError("Unhandled type code!");
593   }
594   assert(ElementBitwidth != 0 && "Bad element bitwidth!");
595 }
596 
applyModifier(StringRef Transformer)597 void RVVType::applyModifier(StringRef Transformer) {
598   if (Transformer.empty())
599     return;
600   // Handle primitive type transformer
601   auto PType = Transformer.back();
602   switch (PType) {
603   case 'e':
604     Scale = 0;
605     break;
606   case 'v':
607     Scale = LMUL.getScale(ElementBitwidth);
608     break;
609   case 'w':
610     ElementBitwidth *= 2;
611     LMUL *= 2;
612     Scale = LMUL.getScale(ElementBitwidth);
613     break;
614   case 'q':
615     ElementBitwidth *= 4;
616     LMUL *= 4;
617     Scale = LMUL.getScale(ElementBitwidth);
618     break;
619   case 'o':
620     ElementBitwidth *= 8;
621     LMUL *= 8;
622     Scale = LMUL.getScale(ElementBitwidth);
623     break;
624   case 'm':
625     ScalarType = ScalarTypeKind::Boolean;
626     Scale = LMUL.getScale(ElementBitwidth);
627     ElementBitwidth = 1;
628     break;
629   case '0':
630     ScalarType = ScalarTypeKind::Void;
631     break;
632   case 'z':
633     ScalarType = ScalarTypeKind::Size_t;
634     break;
635   case 't':
636     ScalarType = ScalarTypeKind::Ptrdiff_t;
637     break;
638   case 'u':
639     ScalarType = ScalarTypeKind::UnsignedLong;
640     break;
641   case 'l':
642     ScalarType = ScalarTypeKind::SignedLong;
643     break;
644   default:
645     PrintFatalError("Illegal primitive type transformers!");
646   }
647   Transformer = Transformer.drop_back();
648 
649   // Extract and compute complex type transformer. It can only appear one time.
650   if (Transformer.startswith("(")) {
651     size_t Idx = Transformer.find(')');
652     assert(Idx != StringRef::npos);
653     StringRef ComplexType = Transformer.slice(1, Idx);
654     Transformer = Transformer.drop_front(Idx + 1);
655     assert(Transformer.find('(') == StringRef::npos &&
656            "Only allow one complex type transformer");
657 
658     auto UpdateAndCheckComplexProto = [&]() {
659       Scale = LMUL.getScale(ElementBitwidth);
660       const StringRef VectorPrototypes("vwqom");
661       if (!VectorPrototypes.contains(PType))
662         PrintFatalError("Complex type transformer only supports vector type!");
663       if (Transformer.find_first_of("PCKWS") != StringRef::npos)
664         PrintFatalError(
665             "Illegal type transformer for Complex type transformer");
666     };
667     auto ComputeFixedLog2LMUL =
668         [&](StringRef Value,
669             std::function<bool(const int32_t &, const int32_t &)> Compare) {
670           int32_t Log2LMUL;
671           Value.getAsInteger(10, Log2LMUL);
672           if (!Compare(Log2LMUL, LMUL.Log2LMUL)) {
673             ScalarType = Invalid;
674             return false;
675           }
676           // Update new LMUL
677           LMUL = LMULType(Log2LMUL);
678           UpdateAndCheckComplexProto();
679           return true;
680         };
681     auto ComplexTT = ComplexType.split(":");
682     if (ComplexTT.first == "Log2EEW") {
683       uint32_t Log2EEW;
684       ComplexTT.second.getAsInteger(10, Log2EEW);
685       // update new elmul = (eew/sew) * lmul
686       LMUL.MulLog2LMUL(Log2EEW - Log2_32(ElementBitwidth));
687       // update new eew
688       ElementBitwidth = 1 << Log2EEW;
689       ScalarType = ScalarTypeKind::SignedInteger;
690       UpdateAndCheckComplexProto();
691     } else if (ComplexTT.first == "FixedSEW") {
692       uint32_t NewSEW;
693       ComplexTT.second.getAsInteger(10, NewSEW);
694       // Set invalid type if src and dst SEW are same.
695       if (ElementBitwidth == NewSEW) {
696         ScalarType = Invalid;
697         return;
698       }
699       // Update new SEW
700       ElementBitwidth = NewSEW;
701       UpdateAndCheckComplexProto();
702     } else if (ComplexTT.first == "LFixedLog2LMUL") {
703       // New LMUL should be larger than old
704       if (!ComputeFixedLog2LMUL(ComplexTT.second, std::greater<int32_t>()))
705         return;
706     } else if (ComplexTT.first == "SFixedLog2LMUL") {
707       // New LMUL should be smaller than old
708       if (!ComputeFixedLog2LMUL(ComplexTT.second, std::less<int32_t>()))
709         return;
710     } else {
711       PrintFatalError("Illegal complex type transformers!");
712     }
713   }
714 
715   // Compute the remain type transformers
716   for (char I : Transformer) {
717     switch (I) {
718     case 'P':
719       if (IsConstant)
720         PrintFatalError("'P' transformer cannot be used after 'C'");
721       if (IsPointer)
722         PrintFatalError("'P' transformer cannot be used twice");
723       IsPointer = true;
724       break;
725     case 'C':
726       if (IsConstant)
727         PrintFatalError("'C' transformer cannot be used twice");
728       IsConstant = true;
729       break;
730     case 'K':
731       IsImmediate = true;
732       break;
733     case 'U':
734       ScalarType = ScalarTypeKind::UnsignedInteger;
735       break;
736     case 'I':
737       ScalarType = ScalarTypeKind::SignedInteger;
738       break;
739     case 'F':
740       ScalarType = ScalarTypeKind::Float;
741       break;
742     case 'S':
743       LMUL = LMULType(0);
744       // Update ElementBitwidth need to update Scale too.
745       Scale = LMUL.getScale(ElementBitwidth);
746       break;
747     default:
748       PrintFatalError("Illegal non-primitive type transformer!");
749     }
750   }
751 }
752 
753 //===----------------------------------------------------------------------===//
754 // RVVIntrinsic implementation
755 //===----------------------------------------------------------------------===//
RVVIntrinsic(StringRef NewName,StringRef Suffix,StringRef NewMangledName,StringRef MangledSuffix,StringRef IRName,bool HasSideEffects,bool IsMask,bool HasMaskedOffOperand,bool HasVL,bool HasNoMaskedOverloaded,bool HasAutoDef,StringRef ManualCodegen,const RVVTypes & OutInTypes,const std::vector<int64_t> & NewIntrinsicTypes,StringRef RequiredExtension,unsigned NF)756 RVVIntrinsic::RVVIntrinsic(StringRef NewName, StringRef Suffix,
757                            StringRef NewMangledName, StringRef MangledSuffix,
758                            StringRef IRName, bool HasSideEffects, bool IsMask,
759                            bool HasMaskedOffOperand, bool HasVL,
760                            bool HasNoMaskedOverloaded, bool HasAutoDef,
761                            StringRef ManualCodegen, const RVVTypes &OutInTypes,
762                            const std::vector<int64_t> &NewIntrinsicTypes,
763                            StringRef RequiredExtension, unsigned NF)
764     : IRName(IRName), HasSideEffects(HasSideEffects), IsMask(IsMask),
765       HasMaskedOffOperand(HasMaskedOffOperand), HasVL(HasVL),
766       HasNoMaskedOverloaded(HasNoMaskedOverloaded), HasAutoDef(HasAutoDef),
767       ManualCodegen(ManualCodegen.str()), NF(NF) {
768 
769   // Init Name and MangledName
770   Name = NewName.str();
771   if (NewMangledName.empty())
772     MangledName = NewName.split("_").first.str();
773   else
774     MangledName = NewMangledName.str();
775   if (!Suffix.empty())
776     Name += "_" + Suffix.str();
777   if (!MangledSuffix.empty())
778     MangledName += "_" + MangledSuffix.str();
779   if (IsMask) {
780     Name += "_m";
781   }
782   // Init RISC-V extensions
783   for (const auto &T : OutInTypes) {
784     if (T->isFloatVector(16) || T->isFloat(16))
785       RISCVExtensions |= RISCVExtension::Zfh;
786     else if (T->isFloatVector(32) || T->isFloat(32))
787       RISCVExtensions |= RISCVExtension::F;
788     else if (T->isFloatVector(64) || T->isFloat(64))
789       RISCVExtensions |= RISCVExtension::D;
790   }
791   if (RequiredExtension == "Zvamo")
792     RISCVExtensions |= RISCVExtension::Zvamo;
793   if (RequiredExtension == "Zvlsseg")
794     RISCVExtensions |= RISCVExtension::Zvlsseg;
795 
796   // Init OutputType and InputTypes
797   OutputType = OutInTypes[0];
798   InputTypes.assign(OutInTypes.begin() + 1, OutInTypes.end());
799 
800   // IntrinsicTypes is nonmasked version index. Need to update it
801   // if there is maskedoff operand (It is always in first operand).
802   IntrinsicTypes = NewIntrinsicTypes;
803   if (IsMask && HasMaskedOffOperand) {
804     for (auto &I : IntrinsicTypes) {
805       if (I >= 0)
806         I += NF;
807     }
808   }
809 }
810 
getBuiltinTypeStr() const811 std::string RVVIntrinsic::getBuiltinTypeStr() const {
812   std::string S;
813   S += OutputType->getBuiltinStr();
814   for (const auto &T : InputTypes) {
815     S += T->getBuiltinStr();
816   }
817   return S;
818 }
819 
emitCodeGenSwitchBody(raw_ostream & OS) const820 void RVVIntrinsic::emitCodeGenSwitchBody(raw_ostream &OS) const {
821   if (!getIRName().empty())
822     OS << "  ID = Intrinsic::riscv_" + getIRName() + ";\n";
823   if (NF >= 2)
824     OS << "  NF = " + utostr(getNF()) + ";\n";
825   if (hasManualCodegen()) {
826     OS << ManualCodegen;
827     OS << "break;\n";
828     return;
829   }
830 
831   if (isMask()) {
832     if (hasVL()) {
833       OS << "  std::rotate(Ops.begin(), Ops.begin() + 1, Ops.end() - 1);\n";
834     } else {
835       OS << "  std::rotate(Ops.begin(), Ops.begin() + 1, Ops.end());\n";
836     }
837   }
838 
839   OS << "  IntrinsicTypes = {";
840   ListSeparator LS;
841   for (const auto &Idx : IntrinsicTypes) {
842     if (Idx == -1)
843       OS << LS << "ResultType";
844     else
845       OS << LS << "Ops[" << Idx << "]->getType()";
846   }
847 
848   // VL could be i64 or i32, need to encode it in IntrinsicTypes. VL is
849   // always last operand.
850   if (hasVL())
851     OS << ", Ops.back()->getType()";
852   OS << "};\n";
853   OS << "  break;\n";
854 }
855 
emitIntrinsicMacro(raw_ostream & OS) const856 void RVVIntrinsic::emitIntrinsicMacro(raw_ostream &OS) const {
857   OS << "#define " << getName() << "(";
858   if (!InputTypes.empty()) {
859     ListSeparator LS;
860     for (unsigned i = 0, e = InputTypes.size(); i != e; ++i)
861       OS << LS << "op" << i;
862   }
863   OS << ") \\\n";
864   OS << "__builtin_rvv_" << getName() << "(";
865   if (!InputTypes.empty()) {
866     ListSeparator LS;
867     for (unsigned i = 0, e = InputTypes.size(); i != e; ++i)
868       OS << LS << "(" << InputTypes[i]->getTypeStr() << ")(op" << i << ")";
869   }
870   OS << ")\n";
871 }
872 
emitMangledFuncDef(raw_ostream & OS) const873 void RVVIntrinsic::emitMangledFuncDef(raw_ostream &OS) const {
874   OS << "__attribute__((clang_builtin_alias(";
875   OS << "__builtin_rvv_" << getName() << ")))\n";
876   OS << OutputType->getTypeStr() << " " << getMangledName() << "(";
877   // Emit function arguments
878   if (!InputTypes.empty()) {
879     ListSeparator LS;
880     for (unsigned i = 0; i < InputTypes.size(); ++i)
881       OS << LS << InputTypes[i]->getTypeStr() << " op" << i;
882   }
883   OS << ");\n\n";
884 }
885 
886 //===----------------------------------------------------------------------===//
887 // RVVEmitter implementation
888 //===----------------------------------------------------------------------===//
createHeader(raw_ostream & OS)889 void RVVEmitter::createHeader(raw_ostream &OS) {
890 
891   OS << "/*===---- riscv_vector.h - RISC-V V-extension RVVIntrinsics "
892         "-------------------===\n"
893         " *\n"
894         " *\n"
895         " * Part of the LLVM Project, under the Apache License v2.0 with LLVM "
896         "Exceptions.\n"
897         " * See https://llvm.org/LICENSE.txt for license information.\n"
898         " * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n"
899         " *\n"
900         " *===-----------------------------------------------------------------"
901         "------===\n"
902         " */\n\n";
903 
904   OS << "#ifndef __RISCV_VECTOR_H\n";
905   OS << "#define __RISCV_VECTOR_H\n\n";
906 
907   OS << "#include <stdint.h>\n";
908   OS << "#include <stddef.h>\n\n";
909 
910   OS << "#ifndef __riscv_vector\n";
911   OS << "#error \"Vector intrinsics require the vector extension.\"\n";
912   OS << "#endif\n\n";
913 
914   OS << "#ifdef __cplusplus\n";
915   OS << "extern \"C\" {\n";
916   OS << "#endif\n\n";
917 
918   std::vector<std::unique_ptr<RVVIntrinsic>> Defs;
919   createRVVIntrinsics(Defs);
920 
921   // Print header code
922   if (!HeaderCode.empty()) {
923     OS << HeaderCode;
924   }
925 
926   auto printType = [&](auto T) {
927     OS << "typedef " << T->getClangBuiltinStr() << " " << T->getTypeStr()
928        << ";\n";
929   };
930 
931   constexpr int Log2LMULs[] = {-3, -2, -1, 0, 1, 2, 3};
932   // Print RVV boolean types.
933   for (int Log2LMUL : Log2LMULs) {
934     auto T = computeType('c', Log2LMUL, "m");
935     if (T.hasValue())
936       printType(T.getValue());
937   }
938   // Print RVV int/float types.
939   for (char I : StringRef("csil")) {
940     for (int Log2LMUL : Log2LMULs) {
941       auto T = computeType(I, Log2LMUL, "v");
942       if (T.hasValue()) {
943         printType(T.getValue());
944         auto UT = computeType(I, Log2LMUL, "Uv");
945         printType(UT.getValue());
946       }
947     }
948   }
949   OS << "#if defined(__riscv_zfh)\n";
950   for (int Log2LMUL : Log2LMULs) {
951     auto T = computeType('x', Log2LMUL, "v");
952     if (T.hasValue())
953       printType(T.getValue());
954   }
955   OS << "#endif\n";
956 
957   OS << "#if defined(__riscv_f)\n";
958   for (int Log2LMUL : Log2LMULs) {
959     auto T = computeType('f', Log2LMUL, "v");
960     if (T.hasValue())
961       printType(T.getValue());
962   }
963   OS << "#endif\n";
964 
965   OS << "#if defined(__riscv_d)\n";
966   for (int Log2LMUL : Log2LMULs) {
967     auto T = computeType('d', Log2LMUL, "v");
968     if (T.hasValue())
969       printType(T.getValue());
970   }
971   OS << "#endif\n\n";
972 
973   // The same extension include in the same arch guard marco.
974   std::stable_sort(Defs.begin(), Defs.end(),
975                    [](const std::unique_ptr<RVVIntrinsic> &A,
976                       const std::unique_ptr<RVVIntrinsic> &B) {
977                      return A->getRISCVExtensions() < B->getRISCVExtensions();
978                    });
979 
980   // Print intrinsic functions with macro
981   emitArchMacroAndBody(Defs, OS, [](raw_ostream &OS, const RVVIntrinsic &Inst) {
982     Inst.emitIntrinsicMacro(OS);
983   });
984 
985   OS << "#define __riscv_v_intrinsic_overloading 1\n";
986 
987   // Print Overloaded APIs
988   OS << "#define __rvv_overloaded static inline "
989         "__attribute__((__always_inline__, __nodebug__, __overloadable__))\n";
990 
991   emitArchMacroAndBody(Defs, OS, [](raw_ostream &OS, const RVVIntrinsic &Inst) {
992     if (!Inst.isMask() && !Inst.hasNoMaskedOverloaded())
993       return;
994     OS << "__rvv_overloaded ";
995     Inst.emitMangledFuncDef(OS);
996   });
997 
998   OS << "\n#ifdef __cplusplus\n";
999   OS << "}\n";
1000   OS << "#endif // __riscv_vector\n";
1001   OS << "#endif // __RISCV_VECTOR_H\n";
1002 }
1003 
createBuiltins(raw_ostream & OS)1004 void RVVEmitter::createBuiltins(raw_ostream &OS) {
1005   std::vector<std::unique_ptr<RVVIntrinsic>> Defs;
1006   createRVVIntrinsics(Defs);
1007 
1008   OS << "#if defined(TARGET_BUILTIN) && !defined(RISCVV_BUILTIN)\n";
1009   OS << "#define RISCVV_BUILTIN(ID, TYPE, ATTRS) TARGET_BUILTIN(ID, TYPE, "
1010         "ATTRS, \"experimental-v\")\n";
1011   OS << "#endif\n";
1012   for (auto &Def : Defs) {
1013     OS << "RISCVV_BUILTIN(__builtin_rvv_" << Def->getName() << ",\""
1014        << Def->getBuiltinTypeStr() << "\", ";
1015     if (!Def->hasSideEffects())
1016       OS << "\"n\")\n";
1017     else
1018       OS << "\"\")\n";
1019   }
1020   OS << "#undef RISCVV_BUILTIN\n";
1021 }
1022 
createCodeGen(raw_ostream & OS)1023 void RVVEmitter::createCodeGen(raw_ostream &OS) {
1024   std::vector<std::unique_ptr<RVVIntrinsic>> Defs;
1025   createRVVIntrinsics(Defs);
1026   // IR name could be empty, use the stable sort preserves the relative order.
1027   std::stable_sort(Defs.begin(), Defs.end(),
1028                    [](const std::unique_ptr<RVVIntrinsic> &A,
1029                       const std::unique_ptr<RVVIntrinsic> &B) {
1030                      return A->getIRName() < B->getIRName();
1031                    });
1032   // Print switch body when the ir name or ManualCodegen changes from previous
1033   // iteration.
1034   RVVIntrinsic *PrevDef = Defs.begin()->get();
1035   for (auto &Def : Defs) {
1036     StringRef CurIRName = Def->getIRName();
1037     if (CurIRName != PrevDef->getIRName() ||
1038         (Def->getManualCodegen() != PrevDef->getManualCodegen())) {
1039       PrevDef->emitCodeGenSwitchBody(OS);
1040     }
1041     PrevDef = Def.get();
1042     OS << "case RISCV::BI__builtin_rvv_" << Def->getName() << ":\n";
1043   }
1044   Defs.back()->emitCodeGenSwitchBody(OS);
1045   OS << "\n";
1046 }
1047 
parsePrototypes(StringRef Prototypes,std::function<void (StringRef)> Handler)1048 void RVVEmitter::parsePrototypes(StringRef Prototypes,
1049                                  std::function<void(StringRef)> Handler) {
1050   const StringRef Primaries("evwqom0ztul");
1051   while (!Prototypes.empty()) {
1052     size_t Idx = 0;
1053     // Skip over complex prototype because it could contain primitive type
1054     // character.
1055     if (Prototypes[0] == '(')
1056       Idx = Prototypes.find_first_of(')');
1057     Idx = Prototypes.find_first_of(Primaries, Idx);
1058     assert(Idx != StringRef::npos);
1059     Handler(Prototypes.slice(0, Idx + 1));
1060     Prototypes = Prototypes.drop_front(Idx + 1);
1061   }
1062 }
1063 
getSuffixStr(char Type,int Log2LMUL,StringRef Prototypes)1064 std::string RVVEmitter::getSuffixStr(char Type, int Log2LMUL,
1065                                      StringRef Prototypes) {
1066   SmallVector<std::string> SuffixStrs;
1067   parsePrototypes(Prototypes, [&](StringRef Proto) {
1068     auto T = computeType(Type, Log2LMUL, Proto);
1069     SuffixStrs.push_back(T.getValue()->getShortStr());
1070   });
1071   return join(SuffixStrs, "_");
1072 }
1073 
createRVVIntrinsics(std::vector<std::unique_ptr<RVVIntrinsic>> & Out)1074 void RVVEmitter::createRVVIntrinsics(
1075     std::vector<std::unique_ptr<RVVIntrinsic>> &Out) {
1076   std::vector<Record *> RV = Records.getAllDerivedDefinitions("RVVBuiltin");
1077   for (auto *R : RV) {
1078     StringRef Name = R->getValueAsString("Name");
1079     StringRef SuffixProto = R->getValueAsString("Suffix");
1080     StringRef MangledName = R->getValueAsString("MangledName");
1081     StringRef MangledSuffixProto = R->getValueAsString("MangledSuffix");
1082     StringRef Prototypes = R->getValueAsString("Prototype");
1083     StringRef TypeRange = R->getValueAsString("TypeRange");
1084     bool HasMask = R->getValueAsBit("HasMask");
1085     bool HasMaskedOffOperand = R->getValueAsBit("HasMaskedOffOperand");
1086     bool HasVL = R->getValueAsBit("HasVL");
1087     bool HasNoMaskedOverloaded = R->getValueAsBit("HasNoMaskedOverloaded");
1088     bool HasSideEffects = R->getValueAsBit("HasSideEffects");
1089     std::vector<int64_t> Log2LMULList = R->getValueAsListOfInts("Log2LMUL");
1090     StringRef ManualCodegen = R->getValueAsString("ManualCodegen");
1091     StringRef ManualCodegenMask = R->getValueAsString("ManualCodegenMask");
1092     std::vector<int64_t> IntrinsicTypes =
1093         R->getValueAsListOfInts("IntrinsicTypes");
1094     StringRef RequiredExtension = R->getValueAsString("RequiredExtension");
1095     StringRef IRName = R->getValueAsString("IRName");
1096     StringRef IRNameMask = R->getValueAsString("IRNameMask");
1097     unsigned NF = R->getValueAsInt("NF");
1098 
1099     StringRef HeaderCodeStr = R->getValueAsString("HeaderCode");
1100     bool HasAutoDef = HeaderCodeStr.empty();
1101     if (!HeaderCodeStr.empty()) {
1102       HeaderCode += HeaderCodeStr.str();
1103     }
1104     // Parse prototype and create a list of primitive type with transformers
1105     // (operand) in ProtoSeq. ProtoSeq[0] is output operand.
1106     SmallVector<std::string> ProtoSeq;
1107     parsePrototypes(Prototypes, [&ProtoSeq](StringRef Proto) {
1108       ProtoSeq.push_back(Proto.str());
1109     });
1110 
1111     // Compute Builtin types
1112     SmallVector<std::string> ProtoMaskSeq = ProtoSeq;
1113     if (HasMask) {
1114       // If HasMaskedOffOperand, insert result type as first input operand.
1115       if (HasMaskedOffOperand) {
1116         if (NF == 1) {
1117           ProtoMaskSeq.insert(ProtoMaskSeq.begin() + 1, ProtoSeq[0]);
1118         } else {
1119           // Convert
1120           // (void, op0 address, op1 address, ...)
1121           // to
1122           // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
1123           for (unsigned I = 0; I < NF; ++I)
1124             ProtoMaskSeq.insert(
1125                 ProtoMaskSeq.begin() + NF + 1,
1126                 ProtoSeq[1].substr(1)); // Use substr(1) to skip '*'
1127         }
1128       }
1129       if (HasMaskedOffOperand && NF > 1) {
1130         // Convert
1131         // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
1132         // to
1133         // (void, op0 address, op1 address, ..., mask, maskedoff0, maskedoff1,
1134         // ...)
1135         ProtoMaskSeq.insert(ProtoMaskSeq.begin() + NF + 1, "m");
1136       } else {
1137         // If HasMask, insert 'm' as first input operand.
1138         ProtoMaskSeq.insert(ProtoMaskSeq.begin() + 1, "m");
1139       }
1140     }
1141     // If HasVL, append 'z' to last operand
1142     if (HasVL) {
1143       ProtoSeq.push_back("z");
1144       ProtoMaskSeq.push_back("z");
1145     }
1146 
1147     // Create Intrinsics for each type and LMUL.
1148     for (char I : TypeRange) {
1149       for (int Log2LMUL : Log2LMULList) {
1150         Optional<RVVTypes> Types = computeTypes(I, Log2LMUL, NF, ProtoSeq);
1151         // Ignored to create new intrinsic if there are any illegal types.
1152         if (!Types.hasValue())
1153           continue;
1154 
1155         auto SuffixStr = getSuffixStr(I, Log2LMUL, SuffixProto);
1156         auto MangledSuffixStr = getSuffixStr(I, Log2LMUL, MangledSuffixProto);
1157         // Create a non-mask intrinsic
1158         Out.push_back(std::make_unique<RVVIntrinsic>(
1159             Name, SuffixStr, MangledName, MangledSuffixStr, IRName,
1160             HasSideEffects, /*IsMask=*/false, /*HasMaskedOffOperand=*/false,
1161             HasVL, HasNoMaskedOverloaded, HasAutoDef, ManualCodegen,
1162             Types.getValue(), IntrinsicTypes, RequiredExtension, NF));
1163         if (HasMask) {
1164           // Create a mask intrinsic
1165           Optional<RVVTypes> MaskTypes =
1166               computeTypes(I, Log2LMUL, NF, ProtoMaskSeq);
1167           Out.push_back(std::make_unique<RVVIntrinsic>(
1168               Name, SuffixStr, MangledName, MangledSuffixStr, IRNameMask,
1169               HasSideEffects, /*IsMask=*/true, HasMaskedOffOperand, HasVL,
1170               HasNoMaskedOverloaded, HasAutoDef, ManualCodegenMask,
1171               MaskTypes.getValue(), IntrinsicTypes, RequiredExtension, NF));
1172         }
1173       } // end for Log2LMULList
1174     }   // end for TypeRange
1175   }
1176 }
1177 
1178 Optional<RVVTypes>
computeTypes(BasicType BT,int Log2LMUL,unsigned NF,ArrayRef<std::string> PrototypeSeq)1179 RVVEmitter::computeTypes(BasicType BT, int Log2LMUL, unsigned NF,
1180                          ArrayRef<std::string> PrototypeSeq) {
1181   // LMUL x NF must be less than or equal to 8.
1182   if ((Log2LMUL >= 1) && (1 << Log2LMUL) * NF > 8)
1183     return llvm::None;
1184 
1185   RVVTypes Types;
1186   for (const std::string &Proto : PrototypeSeq) {
1187     auto T = computeType(BT, Log2LMUL, Proto);
1188     if (!T.hasValue())
1189       return llvm::None;
1190     // Record legal type index
1191     Types.push_back(T.getValue());
1192   }
1193   return Types;
1194 }
1195 
computeType(BasicType BT,int Log2LMUL,StringRef Proto)1196 Optional<RVVTypePtr> RVVEmitter::computeType(BasicType BT, int Log2LMUL,
1197                                              StringRef Proto) {
1198   std::string Idx = Twine(Twine(BT) + Twine(Log2LMUL) + Proto).str();
1199   // Search first
1200   auto It = LegalTypes.find(Idx);
1201   if (It != LegalTypes.end())
1202     return &(It->second);
1203   if (IllegalTypes.count(Idx))
1204     return llvm::None;
1205   // Compute type and record the result.
1206   RVVType T(BT, Log2LMUL, Proto);
1207   if (T.isValid()) {
1208     // Record legal type index and value.
1209     LegalTypes.insert({Idx, T});
1210     return &(LegalTypes[Idx]);
1211   }
1212   // Record illegal type index.
1213   IllegalTypes.insert(Idx);
1214   return llvm::None;
1215 }
1216 
emitArchMacroAndBody(std::vector<std::unique_ptr<RVVIntrinsic>> & Defs,raw_ostream & OS,std::function<void (raw_ostream &,const RVVIntrinsic &)> PrintBody)1217 void RVVEmitter::emitArchMacroAndBody(
1218     std::vector<std::unique_ptr<RVVIntrinsic>> &Defs, raw_ostream &OS,
1219     std::function<void(raw_ostream &, const RVVIntrinsic &)> PrintBody) {
1220   uint8_t PrevExt = (*Defs.begin())->getRISCVExtensions();
1221   bool NeedEndif = emitExtDefStr(PrevExt, OS);
1222   for (auto &Def : Defs) {
1223     uint8_t CurExt = Def->getRISCVExtensions();
1224     if (CurExt != PrevExt) {
1225       if (NeedEndif)
1226         OS << "#endif\n\n";
1227       NeedEndif = emitExtDefStr(CurExt, OS);
1228       PrevExt = CurExt;
1229     }
1230     if (Def->hasAutoDef())
1231       PrintBody(OS, *Def);
1232   }
1233   if (NeedEndif)
1234     OS << "#endif\n\n";
1235 }
1236 
emitExtDefStr(uint8_t Extents,raw_ostream & OS)1237 bool RVVEmitter::emitExtDefStr(uint8_t Extents, raw_ostream &OS) {
1238   if (Extents == RISCVExtension::Basic)
1239     return false;
1240   OS << "#if ";
1241   ListSeparator LS(" && ");
1242   if (Extents & RISCVExtension::F)
1243     OS << LS << "defined(__riscv_f)";
1244   if (Extents & RISCVExtension::D)
1245     OS << LS << "defined(__riscv_d)";
1246   if (Extents & RISCVExtension::Zfh)
1247     OS << LS << "defined(__riscv_zfh)";
1248   if (Extents & RISCVExtension::Zvamo)
1249     OS << LS << "defined(__riscv_zvamo)";
1250   if (Extents & RISCVExtension::Zvlsseg)
1251     OS << LS << "defined(__riscv_zvlsseg)";
1252   OS << "\n";
1253   return true;
1254 }
1255 
1256 namespace clang {
EmitRVVHeader(RecordKeeper & Records,raw_ostream & OS)1257 void EmitRVVHeader(RecordKeeper &Records, raw_ostream &OS) {
1258   RVVEmitter(Records).createHeader(OS);
1259 }
1260 
EmitRVVBuiltins(RecordKeeper & Records,raw_ostream & OS)1261 void EmitRVVBuiltins(RecordKeeper &Records, raw_ostream &OS) {
1262   RVVEmitter(Records).createBuiltins(OS);
1263 }
1264 
EmitRVVBuiltinCG(RecordKeeper & Records,raw_ostream & OS)1265 void EmitRVVBuiltinCG(RecordKeeper &Records, raw_ostream &OS) {
1266   RVVEmitter(Records).createCodeGen(OS);
1267 }
1268 
1269 } // End namespace clang
1270