1 //==- SemaRISCVVectorLookup.cpp - Name Lookup for RISC-V Vector Intrinsic -==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 //  This file implements name lookup for RISC-V vector intrinsic.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "clang/AST/ASTContext.h"
14 #include "clang/AST/Decl.h"
15 #include "clang/Basic/Builtins.h"
16 #include "clang/Basic/TargetInfo.h"
17 #include "clang/Lex/Preprocessor.h"
18 #include "clang/Sema/Lookup.h"
19 #include "clang/Sema/RISCVIntrinsicManager.h"
20 #include "clang/Sema/Sema.h"
21 #include "clang/Support/RISCVVIntrinsicUtils.h"
22 #include "llvm/ADT/SmallVector.h"
23 #include <string>
24 #include <vector>
25 
26 using namespace llvm;
27 using namespace clang;
28 using namespace clang::RISCV;
29 
namespace {

/// One concrete RVV intrinsic, i.e. a single (record, element type, LMUL,
/// masked/unmasked) combination produced while expanding the generated
/// intrinsic records.
// NOTE: RISCVIntrinsicManagerImpl::InitRVVIntrinsic builds this struct with
// aggregate initialization ({Name, OverloadName, BuiltinName, Signature}), so
// the field order below is significant.
struct RVVIntrinsicDef {
  /// Full function name with suffix, e.g. vadd_vv_i32m1.
  std::string Name;

  /// Overloaded function name, e.g. vadd.
  std::string OverloadName;

  /// Name of the clang builtin this intrinsic aliases, e.g. __builtin_rvv_vadd.
  std::string BuiltinName;

  /// Function signature; element 0 is the return type, the remaining
  /// elements are the parameter types in order.
  RVVTypes Signature;
};

/// The set of concrete intrinsics that share one overloaded name.
struct RVVOverloadIntrinsicDef {
  // Indexes into RISCVIntrinsicManagerImpl::IntrinsicList.
  SmallVector<size_t, 8> Indexes;
};

} // namespace
53 
// Flattened pool of prototype descriptors shared by all intrinsic records.
// Records refer to their prototype, suffix and overloaded-suffix sequences as
// (index, length) slices of this table (see ProtoSeq2ArrayRef). The contents
// come from the DECL_SIGNATURE_TABLE section of the generated
// riscv_vector_builtin_sema.inc.
static const PrototypeDescriptor RVVSignatureTable[] = {
#define DECL_SIGNATURE_TABLE
#include "clang/Basic/riscv_vector_builtin_sema.inc"
#undef DECL_SIGNATURE_TABLE
};
59 
// Compact intrinsic records from the DECL_INTRINSIC_RECORDS section of the
// generated riscv_vector_builtin_sema.inc. Each record is expanded by
// RISCVIntrinsicManagerImpl::InitIntrinsicList over the element types in its
// TypeRangeMask and the LMULs in its Log2LMULMask.
static const RVVIntrinsicRecord RVVIntrinsicRecords[] = {
#define DECL_INTRINSIC_RECORDS
#include "clang/Basic/riscv_vector_builtin_sema.inc"
#undef DECL_INTRINSIC_RECORDS
};
65 
66 // Get subsequence of signature table.
67 static ArrayRef<PrototypeDescriptor> ProtoSeq2ArrayRef(uint16_t Index,
68                                                        uint8_t Length) {
69   return makeArrayRef(&RVVSignatureTable[Index], Length);
70 }
71 
72 static QualType RVVType2Qual(ASTContext &Context, const RVVType *Type) {
73   QualType QT;
74   switch (Type->getScalarType()) {
75   case ScalarTypeKind::Void:
76     QT = Context.VoidTy;
77     break;
78   case ScalarTypeKind::Size_t:
79     QT = Context.getSizeType();
80     break;
81   case ScalarTypeKind::Ptrdiff_t:
82     QT = Context.getPointerDiffType();
83     break;
84   case ScalarTypeKind::UnsignedLong:
85     QT = Context.UnsignedLongTy;
86     break;
87   case ScalarTypeKind::SignedLong:
88     QT = Context.LongTy;
89     break;
90   case ScalarTypeKind::Boolean:
91     QT = Context.BoolTy;
92     break;
93   case ScalarTypeKind::SignedInteger:
94     QT = Context.getIntTypeForBitwidth(Type->getElementBitwidth(), true);
95     break;
96   case ScalarTypeKind::UnsignedInteger:
97     QT = Context.getIntTypeForBitwidth(Type->getElementBitwidth(), false);
98     break;
99   case ScalarTypeKind::Float:
100     switch (Type->getElementBitwidth()) {
101     case 64:
102       QT = Context.DoubleTy;
103       break;
104     case 32:
105       QT = Context.FloatTy;
106       break;
107     case 16:
108       QT = Context.Float16Ty;
109       break;
110     default:
111       llvm_unreachable("Unsupported floating point width.");
112     }
113     break;
114   case Invalid:
115     llvm_unreachable("Unhandled type.");
116   }
117   if (Type->isVector())
118     QT = Context.getScalableVectorType(QT, Type->getScale().getValue());
119 
120   if (Type->isConstant())
121     QT = Context.getConstType(QT);
122 
123   // Transform the type to a pointer as the last step, if necessary.
124   if (Type->isPointer())
125     QT = Context.getPointerType(QT);
126 
127   return QT;
128 }
129 
130 namespace {
131 class RISCVIntrinsicManagerImpl : public sema::RISCVIntrinsicManager {
132 private:
133   Sema &S;
134   ASTContext &Context;
135 
136   // List of all RVV intrinsic.
137   std::vector<RVVIntrinsicDef> IntrinsicList;
138   // Mapping function name to index of IntrinsicList.
139   StringMap<size_t> Intrinsics;
140   // Mapping function name to RVVOverloadIntrinsicDef.
141   StringMap<RVVOverloadIntrinsicDef> OverloadIntrinsics;
142 
143   // Create IntrinsicList
144   void InitIntrinsicList();
145 
146   // Create RVVIntrinsicDef.
147   void InitRVVIntrinsic(const RVVIntrinsicRecord &Record, StringRef SuffixStr,
148                         StringRef OverloadedSuffixStr, bool IsMask,
149                         RVVTypes &Types);
150 
151   // Create FunctionDecl for a vector intrinsic.
152   void CreateRVVIntrinsicDecl(LookupResult &LR, IdentifierInfo *II,
153                               Preprocessor &PP, unsigned Index,
154                               bool IsOverload);
155 
156 public:
157   RISCVIntrinsicManagerImpl(clang::Sema &S) : S(S), Context(S.Context) {
158     InitIntrinsicList();
159   }
160 
161   // Create RISC-V vector intrinsic and insert into symbol table if found, and
162   // return true, otherwise return false.
163   bool CreateIntrinsicIfFound(LookupResult &LR, IdentifierInfo *II,
164                               Preprocessor &PP) override;
165 };
166 } // namespace
167 
168 void RISCVIntrinsicManagerImpl::InitIntrinsicList() {
169   const TargetInfo &TI = Context.getTargetInfo();
170   bool HasVectorFloat32 = TI.hasFeature("zve32f");
171   bool HasVectorFloat64 = TI.hasFeature("zve64d");
172   bool HasZvfh = TI.hasFeature("experimental-zvfh");
173   bool HasRV64 = TI.hasFeature("64bit");
174   bool HasFullMultiply = TI.hasFeature("v");
175 
176   // Construction of RVVIntrinsicRecords need to sync with createRVVIntrinsics
177   // in RISCVVEmitter.cpp.
178   for (auto &Record : RVVIntrinsicRecords) {
179     // Create Intrinsics for each type and LMUL.
180     BasicType BaseType = BasicType::Unknown;
181     ArrayRef<PrototypeDescriptor> BasicProtoSeq =
182         ProtoSeq2ArrayRef(Record.PrototypeIndex, Record.PrototypeLength);
183     ArrayRef<PrototypeDescriptor> SuffixProto =
184         ProtoSeq2ArrayRef(Record.SuffixIndex, Record.SuffixLength);
185     ArrayRef<PrototypeDescriptor> OverloadedSuffixProto = ProtoSeq2ArrayRef(
186         Record.OverloadedSuffixIndex, Record.OverloadedSuffixSize);
187 
188     llvm::SmallVector<PrototypeDescriptor> ProtoSeq =
189         RVVIntrinsic::computeBuiltinTypes(BasicProtoSeq, /*IsMasked=*/false,
190                                           /*HasMaskedOffOperand=*/false,
191                                           Record.HasVL, Record.NF);
192 
193     llvm::SmallVector<PrototypeDescriptor> ProtoMaskSeq =
194         RVVIntrinsic::computeBuiltinTypes(BasicProtoSeq, /*IsMasked=*/true,
195                                           Record.HasMaskedOffOperand,
196                                           Record.HasVL, Record.NF);
197 
198     for (unsigned int TypeRangeMaskShift = 0;
199          TypeRangeMaskShift <= static_cast<unsigned int>(BasicType::MaxOffset);
200          ++TypeRangeMaskShift) {
201       unsigned int BaseTypeI = 1 << TypeRangeMaskShift;
202       BaseType = static_cast<BasicType>(BaseTypeI);
203 
204       if ((BaseTypeI & Record.TypeRangeMask) != BaseTypeI)
205         continue;
206 
207       // Check requirement.
208       if (BaseType == BasicType::Float16 && !HasZvfh)
209         continue;
210 
211       if (BaseType == BasicType::Float32 && !HasVectorFloat32)
212         continue;
213 
214       if (BaseType == BasicType::Float64 && !HasVectorFloat64)
215         continue;
216 
217       if (((Record.RequiredExtensions & RVV_REQ_RV64) == RVV_REQ_RV64) &&
218           !HasRV64)
219         continue;
220 
221       if ((BaseType == BasicType::Int64) &&
222           ((Record.RequiredExtensions & RVV_REQ_FullMultiply) ==
223            RVV_REQ_FullMultiply) &&
224           !HasFullMultiply)
225         continue;
226 
227       // Expanded with different LMUL.
228       for (int Log2LMUL = -3; Log2LMUL <= 3; Log2LMUL++) {
229         if (!(Record.Log2LMULMask & (1 << (Log2LMUL + 3))))
230           continue;
231 
232         Optional<RVVTypes> Types =
233             RVVType::computeTypes(BaseType, Log2LMUL, Record.NF, ProtoSeq);
234 
235         // Ignored to create new intrinsic if there are any illegal types.
236         if (!Types.hasValue())
237           continue;
238 
239         std::string SuffixStr =
240             RVVIntrinsic::getSuffixStr(BaseType, Log2LMUL, SuffixProto);
241         std::string OverloadedSuffixStr = RVVIntrinsic::getSuffixStr(
242             BaseType, Log2LMUL, OverloadedSuffixProto);
243 
244         // Create non-masked intrinsic.
245         InitRVVIntrinsic(Record, SuffixStr, OverloadedSuffixStr, false, *Types);
246 
247         if (Record.HasMasked) {
248           // Create masked intrinsic.
249           Optional<RVVTypes> MaskTypes = RVVType::computeTypes(
250               BaseType, Log2LMUL, Record.NF, ProtoMaskSeq);
251 
252           InitRVVIntrinsic(Record, SuffixStr, OverloadedSuffixStr, true,
253                            *MaskTypes);
254         }
255       }
256     }
257   }
258 }
259 
260 // Compute name and signatures for intrinsic with practical types.
261 void RISCVIntrinsicManagerImpl::InitRVVIntrinsic(
262     const RVVIntrinsicRecord &Record, StringRef SuffixStr,
263     StringRef OverloadedSuffixStr, bool IsMask, RVVTypes &Signature) {
264   // Function name, e.g. vadd_vv_i32m1.
265   std::string Name = Record.Name;
266   if (!SuffixStr.empty())
267     Name += "_" + SuffixStr.str();
268 
269   if (IsMask)
270     Name += "_m";
271 
272   // Overloaded function name, e.g. vadd.
273   std::string OverloadedName;
274   if (!Record.OverloadedName)
275     OverloadedName = StringRef(Record.Name).split("_").first.str();
276   else
277     OverloadedName = Record.OverloadedName;
278   if (!OverloadedSuffixStr.empty())
279     OverloadedName += "_" + OverloadedSuffixStr.str();
280 
281   // clang built-in function name, e.g. __builtin_rvv_vadd.
282   std::string BuiltinName = "__builtin_rvv_" + std::string(Record.Name);
283   if (IsMask)
284     BuiltinName += "_m";
285 
286   // Put into IntrinsicList.
287   size_t Index = IntrinsicList.size();
288   IntrinsicList.push_back({Name, OverloadedName, BuiltinName, Signature});
289 
290   // Creating mapping to Intrinsics.
291   Intrinsics.insert({Name, Index});
292 
293   // Get the RVVOverloadIntrinsicDef.
294   RVVOverloadIntrinsicDef &OverloadIntrinsicDef =
295       OverloadIntrinsics[OverloadedName];
296 
297   // And added the index.
298   OverloadIntrinsicDef.Indexes.push_back(Index);
299 }
300 
301 void RISCVIntrinsicManagerImpl::CreateRVVIntrinsicDecl(LookupResult &LR,
302                                                        IdentifierInfo *II,
303                                                        Preprocessor &PP,
304                                                        unsigned Index,
305                                                        bool IsOverload) {
306   ASTContext &Context = S.Context;
307   RVVIntrinsicDef &IDef = IntrinsicList[Index];
308   RVVTypes Sigs = IDef.Signature;
309   size_t SigLength = Sigs.size();
310   RVVType *ReturnType = Sigs[0];
311   QualType RetType = RVVType2Qual(Context, ReturnType);
312   SmallVector<QualType, 8> ArgTypes;
313   QualType BuiltinFuncType;
314 
315   // Skip return type, and convert RVVType to QualType for arguments.
316   for (size_t i = 1; i < SigLength; ++i)
317     ArgTypes.push_back(RVVType2Qual(Context, Sigs[i]));
318 
319   FunctionProtoType::ExtProtoInfo PI(
320       Context.getDefaultCallingConvention(false, false, true));
321 
322   PI.Variadic = false;
323 
324   SourceLocation Loc = LR.getNameLoc();
325   BuiltinFuncType = Context.getFunctionType(RetType, ArgTypes, PI);
326   DeclContext *Parent = Context.getTranslationUnitDecl();
327 
328   FunctionDecl *RVVIntrinsicDecl = FunctionDecl::Create(
329       Context, Parent, Loc, Loc, II, BuiltinFuncType, /*TInfo=*/nullptr,
330       SC_Extern, S.getCurFPFeatures().isFPConstrained(),
331       /*isInlineSpecified*/ false,
332       /*hasWrittenPrototype*/ true);
333 
334   // Create Decl objects for each parameter, adding them to the
335   // FunctionDecl.
336   const auto *FP = cast<FunctionProtoType>(BuiltinFuncType);
337   SmallVector<ParmVarDecl *, 8> ParmList;
338   for (unsigned IParm = 0, E = FP->getNumParams(); IParm != E; ++IParm) {
339     ParmVarDecl *Parm =
340         ParmVarDecl::Create(Context, RVVIntrinsicDecl, Loc, Loc, nullptr,
341                             FP->getParamType(IParm), nullptr, SC_None, nullptr);
342     Parm->setScopeInfo(0, IParm);
343     ParmList.push_back(Parm);
344   }
345   RVVIntrinsicDecl->setParams(ParmList);
346 
347   // Add function attributes.
348   if (IsOverload)
349     RVVIntrinsicDecl->addAttr(OverloadableAttr::CreateImplicit(Context));
350 
351   // Setup alias to __builtin_rvv_*
352   IdentifierInfo &IntrinsicII = PP.getIdentifierTable().get(IDef.BuiltinName);
353   RVVIntrinsicDecl->addAttr(
354       BuiltinAliasAttr::CreateImplicit(S.Context, &IntrinsicII));
355 
356   // Add to symbol table.
357   LR.addDecl(RVVIntrinsicDecl);
358 }
359 
360 bool RISCVIntrinsicManagerImpl::CreateIntrinsicIfFound(LookupResult &LR,
361                                                        IdentifierInfo *II,
362                                                        Preprocessor &PP) {
363   StringRef Name = II->getName();
364 
365   // Lookup the function name from the overload intrinsics first.
366   auto OvIItr = OverloadIntrinsics.find(Name);
367   if (OvIItr != OverloadIntrinsics.end()) {
368     const RVVOverloadIntrinsicDef &OvIntrinsicDef = OvIItr->second;
369     for (auto Index : OvIntrinsicDef.Indexes)
370       CreateRVVIntrinsicDecl(LR, II, PP, Index,
371                              /*IsOverload*/ true);
372 
373     // If we added overloads, need to resolve the lookup result.
374     LR.resolveKind();
375     return true;
376   }
377 
378   // Lookup the function name from the intrinsics.
379   auto Itr = Intrinsics.find(Name);
380   if (Itr != Intrinsics.end()) {
381     CreateRVVIntrinsicDecl(LR, II, PP, Itr->second,
382                            /*IsOverload*/ false);
383     return true;
384   }
385 
386   // It's not an RVV intrinsics.
387   return false;
388 }
389 
390 namespace clang {
391 std::unique_ptr<clang::sema::RISCVIntrinsicManager>
392 CreateRISCVIntrinsicManager(Sema &S) {
393   return std::make_unique<RISCVIntrinsicManagerImpl>(S);
394 }
395 } // namespace clang
396