1 //==- SemaRISCVVectorLookup.cpp - Name Lookup for RISC-V Vector Intrinsic -==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 //  This file implements name lookup for RISC-V vector intrinsic.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "clang/AST/ASTContext.h"
#include "clang/AST/Decl.h"
#include "clang/Basic/Builtins.h"
#include "clang/Basic/TargetInfo.h"
#include "clang/Lex/Preprocessor.h"
#include "clang/Sema/Lookup.h"
#include "clang/Sema/RISCVIntrinsicManager.h"
#include "clang/Sema/Sema.h"
#include "clang/Support/RISCVVIntrinsicUtils.h"
#include "llvm/ADT/SmallVector.h"
#include <optional>
#include <string>
#include <utility>
#include <vector>
26 
27 using namespace llvm;
28 using namespace clang;
29 using namespace clang::RISCV;
30 
namespace {

/// Function definition of a RVV intrinsic: the user-visible names, the
/// clang builtin it aliases, and its fully-typed signature.
///
/// NOTE: this struct is brace-initialized positionally
/// ({Name, OverloadName, BuiltinName, Signature}) when intrinsics are
/// registered, so the member order below must not change.
struct RVVIntrinsicDef {
  /// Full function name with suffix, e.g. vadd_vv_i32m1.
  std::string Name;

  /// Overloaded function name, e.g. vadd.
  std::string OverloadName;

  /// Mapping to which clang built-in function, e.g. __builtin_rvv_vadd.
  std::string BuiltinName;

  /// Function signature, first element is return type.
  RVVTypes Signature;
};

/// The set of intrinsics that share one overloaded name.
struct RVVOverloadIntrinsicDef {
  // Indexes of RISCVIntrinsicManagerImpl::IntrinsicList, in the order the
  // intrinsics were registered.
  SmallVector<size_t, 8> Indexes;
};

} // namespace
54 
// Flat table of prototype descriptors, filled in from
// riscv_vector_builtin_sema.inc (DECL_SIGNATURE_TABLE section). Intrinsic
// records refer to slices of it by (index, length) pairs, which are turned
// into ArrayRefs by ProtoSeq2ArrayRef.
static const PrototypeDescriptor RVVSignatureTable[] = {
#define DECL_SIGNATURE_TABLE
#include "clang/Basic/riscv_vector_builtin_sema.inc"
#undef DECL_SIGNATURE_TABLE
};
60 
// One record per RVV intrinsic definition, filled in from
// riscv_vector_builtin_sema.inc (DECL_INTRINSIC_RECORDS section). Each record
// is later expanded over element types, LMULs, masking and policies by
// RISCVIntrinsicManagerImpl::InitIntrinsicList.
static const RVVIntrinsicRecord RVVIntrinsicRecords[] = {
#define DECL_INTRINSIC_RECORDS
#include "clang/Basic/riscv_vector_builtin_sema.inc"
#undef DECL_INTRINSIC_RECORDS
};
66 
67 // Get subsequence of signature table.
ProtoSeq2ArrayRef(uint16_t Index,uint8_t Length)68 static ArrayRef<PrototypeDescriptor> ProtoSeq2ArrayRef(uint16_t Index,
69                                                        uint8_t Length) {
70   return ArrayRef(&RVVSignatureTable[Index], Length);
71 }
72 
RVVType2Qual(ASTContext & Context,const RVVType * Type)73 static QualType RVVType2Qual(ASTContext &Context, const RVVType *Type) {
74   QualType QT;
75   switch (Type->getScalarType()) {
76   case ScalarTypeKind::Void:
77     QT = Context.VoidTy;
78     break;
79   case ScalarTypeKind::Size_t:
80     QT = Context.getSizeType();
81     break;
82   case ScalarTypeKind::Ptrdiff_t:
83     QT = Context.getPointerDiffType();
84     break;
85   case ScalarTypeKind::UnsignedLong:
86     QT = Context.UnsignedLongTy;
87     break;
88   case ScalarTypeKind::SignedLong:
89     QT = Context.LongTy;
90     break;
91   case ScalarTypeKind::Boolean:
92     QT = Context.BoolTy;
93     break;
94   case ScalarTypeKind::SignedInteger:
95     QT = Context.getIntTypeForBitwidth(Type->getElementBitwidth(), true);
96     break;
97   case ScalarTypeKind::UnsignedInteger:
98     QT = Context.getIntTypeForBitwidth(Type->getElementBitwidth(), false);
99     break;
100   case ScalarTypeKind::Float:
101     switch (Type->getElementBitwidth()) {
102     case 64:
103       QT = Context.DoubleTy;
104       break;
105     case 32:
106       QT = Context.FloatTy;
107       break;
108     case 16:
109       QT = Context.Float16Ty;
110       break;
111     default:
112       llvm_unreachable("Unsupported floating point width.");
113     }
114     break;
115   case Invalid:
116     llvm_unreachable("Unhandled type.");
117   }
118   if (Type->isVector())
119     QT = Context.getScalableVectorType(QT, *Type->getScale());
120 
121   if (Type->isConstant())
122     QT = Context.getConstType(QT);
123 
124   // Transform the type to a pointer as the last step, if necessary.
125   if (Type->isPointer())
126     QT = Context.getPointerType(QT);
127 
128   return QT;
129 }
130 
131 namespace {
132 class RISCVIntrinsicManagerImpl : public sema::RISCVIntrinsicManager {
133 private:
134   Sema &S;
135   ASTContext &Context;
136   RVVTypeCache TypeCache;
137 
138   // List of all RVV intrinsic.
139   std::vector<RVVIntrinsicDef> IntrinsicList;
140   // Mapping function name to index of IntrinsicList.
141   StringMap<size_t> Intrinsics;
142   // Mapping function name to RVVOverloadIntrinsicDef.
143   StringMap<RVVOverloadIntrinsicDef> OverloadIntrinsics;
144 
145   // Create IntrinsicList
146   void InitIntrinsicList();
147 
148   // Create RVVIntrinsicDef.
149   void InitRVVIntrinsic(const RVVIntrinsicRecord &Record, StringRef SuffixStr,
150                         StringRef OverloadedSuffixStr, bool IsMask,
151                         RVVTypes &Types, bool HasPolicy, Policy PolicyAttrs);
152 
153   // Create FunctionDecl for a vector intrinsic.
154   void CreateRVVIntrinsicDecl(LookupResult &LR, IdentifierInfo *II,
155                               Preprocessor &PP, unsigned Index,
156                               bool IsOverload);
157 
158 public:
RISCVIntrinsicManagerImpl(clang::Sema & S)159   RISCVIntrinsicManagerImpl(clang::Sema &S) : S(S), Context(S.Context) {
160     InitIntrinsicList();
161   }
162 
163   // Create RISC-V vector intrinsic and insert into symbol table if found, and
164   // return true, otherwise return false.
165   bool CreateIntrinsicIfFound(LookupResult &LR, IdentifierInfo *II,
166                               Preprocessor &PP) override;
167 };
168 } // namespace
169 
InitIntrinsicList()170 void RISCVIntrinsicManagerImpl::InitIntrinsicList() {
171   const TargetInfo &TI = Context.getTargetInfo();
172   bool HasVectorFloat32 = TI.hasFeature("zve32f");
173   bool HasVectorFloat64 = TI.hasFeature("zve64d");
174   bool HasZvfh = TI.hasFeature("experimental-zvfh");
175   bool HasRV64 = TI.hasFeature("64bit");
176   bool HasFullMultiply = TI.hasFeature("v");
177 
178   // Construction of RVVIntrinsicRecords need to sync with createRVVIntrinsics
179   // in RISCVVEmitter.cpp.
180   for (auto &Record : RVVIntrinsicRecords) {
181     // Create Intrinsics for each type and LMUL.
182     BasicType BaseType = BasicType::Unknown;
183     ArrayRef<PrototypeDescriptor> BasicProtoSeq =
184         ProtoSeq2ArrayRef(Record.PrototypeIndex, Record.PrototypeLength);
185     ArrayRef<PrototypeDescriptor> SuffixProto =
186         ProtoSeq2ArrayRef(Record.SuffixIndex, Record.SuffixLength);
187     ArrayRef<PrototypeDescriptor> OverloadedSuffixProto = ProtoSeq2ArrayRef(
188         Record.OverloadedSuffixIndex, Record.OverloadedSuffixSize);
189 
190     PolicyScheme UnMaskedPolicyScheme =
191         static_cast<PolicyScheme>(Record.UnMaskedPolicyScheme);
192     PolicyScheme MaskedPolicyScheme =
193         static_cast<PolicyScheme>(Record.MaskedPolicyScheme);
194 
195     const Policy DefaultPolicy;
196 
197     llvm::SmallVector<PrototypeDescriptor> ProtoSeq =
198         RVVIntrinsic::computeBuiltinTypes(BasicProtoSeq, /*IsMasked=*/false,
199                                           /*HasMaskedOffOperand=*/false,
200                                           Record.HasVL, Record.NF,
201                                           UnMaskedPolicyScheme, DefaultPolicy);
202 
203     llvm::SmallVector<PrototypeDescriptor> ProtoMaskSeq =
204         RVVIntrinsic::computeBuiltinTypes(
205             BasicProtoSeq, /*IsMasked=*/true, Record.HasMaskedOffOperand,
206             Record.HasVL, Record.NF, MaskedPolicyScheme, DefaultPolicy);
207 
208     bool UnMaskedHasPolicy = UnMaskedPolicyScheme != PolicyScheme::SchemeNone;
209     bool MaskedHasPolicy = MaskedPolicyScheme != PolicyScheme::SchemeNone;
210     SmallVector<Policy> SupportedUnMaskedPolicies =
211         RVVIntrinsic::getSupportedUnMaskedPolicies();
212     SmallVector<Policy> SupportedMaskedPolicies =
213         RVVIntrinsic::getSupportedMaskedPolicies(Record.HasTailPolicy,
214                                                  Record.HasMaskPolicy);
215 
216     for (unsigned int TypeRangeMaskShift = 0;
217          TypeRangeMaskShift <= static_cast<unsigned int>(BasicType::MaxOffset);
218          ++TypeRangeMaskShift) {
219       unsigned int BaseTypeI = 1 << TypeRangeMaskShift;
220       BaseType = static_cast<BasicType>(BaseTypeI);
221 
222       if ((BaseTypeI & Record.TypeRangeMask) != BaseTypeI)
223         continue;
224 
225       // Check requirement.
226       if (BaseType == BasicType::Float16 && !HasZvfh)
227         continue;
228 
229       if (BaseType == BasicType::Float32 && !HasVectorFloat32)
230         continue;
231 
232       if (BaseType == BasicType::Float64 && !HasVectorFloat64)
233         continue;
234 
235       if (((Record.RequiredExtensions & RVV_REQ_RV64) == RVV_REQ_RV64) &&
236           !HasRV64)
237         continue;
238 
239       if ((BaseType == BasicType::Int64) &&
240           ((Record.RequiredExtensions & RVV_REQ_FullMultiply) ==
241            RVV_REQ_FullMultiply) &&
242           !HasFullMultiply)
243         continue;
244 
245       // Expanded with different LMUL.
246       for (int Log2LMUL = -3; Log2LMUL <= 3; Log2LMUL++) {
247         if (!(Record.Log2LMULMask & (1 << (Log2LMUL + 3))))
248           continue;
249 
250         std::optional<RVVTypes> Types =
251             TypeCache.computeTypes(BaseType, Log2LMUL, Record.NF, ProtoSeq);
252 
253         // Ignored to create new intrinsic if there are any illegal types.
254         if (!Types.has_value())
255           continue;
256 
257         std::string SuffixStr = RVVIntrinsic::getSuffixStr(
258             TypeCache, BaseType, Log2LMUL, SuffixProto);
259         std::string OverloadedSuffixStr = RVVIntrinsic::getSuffixStr(
260             TypeCache, BaseType, Log2LMUL, OverloadedSuffixProto);
261 
262         // Create non-masked intrinsic.
263         InitRVVIntrinsic(Record, SuffixStr, OverloadedSuffixStr, false, *Types,
264                          UnMaskedHasPolicy, DefaultPolicy);
265 
266         // Create non-masked policy intrinsic.
267         if (Record.UnMaskedPolicyScheme != PolicyScheme::SchemeNone) {
268           for (auto P : SupportedUnMaskedPolicies) {
269             llvm::SmallVector<PrototypeDescriptor> PolicyPrototype =
270                 RVVIntrinsic::computeBuiltinTypes(
271                     BasicProtoSeq, /*IsMasked=*/false,
272                     /*HasMaskedOffOperand=*/false, Record.HasVL, Record.NF,
273                     UnMaskedPolicyScheme, P);
274             std::optional<RVVTypes> PolicyTypes = TypeCache.computeTypes(
275                 BaseType, Log2LMUL, Record.NF, PolicyPrototype);
276             InitRVVIntrinsic(Record, SuffixStr, OverloadedSuffixStr,
277                              /*IsMask=*/false, *PolicyTypes, UnMaskedHasPolicy,
278                              P);
279           }
280         }
281         if (!Record.HasMasked)
282           continue;
283         // Create masked intrinsic.
284         std::optional<RVVTypes> MaskTypes =
285             TypeCache.computeTypes(BaseType, Log2LMUL, Record.NF, ProtoMaskSeq);
286         InitRVVIntrinsic(Record, SuffixStr, OverloadedSuffixStr, true,
287                          *MaskTypes, MaskedHasPolicy, DefaultPolicy);
288         if (Record.MaskedPolicyScheme == PolicyScheme::SchemeNone)
289           continue;
290         // Create masked policy intrinsic.
291         for (auto P : SupportedMaskedPolicies) {
292           llvm::SmallVector<PrototypeDescriptor> PolicyPrototype =
293               RVVIntrinsic::computeBuiltinTypes(
294                   BasicProtoSeq, /*IsMasked=*/true, Record.HasMaskedOffOperand,
295                   Record.HasVL, Record.NF, MaskedPolicyScheme, P);
296           std::optional<RVVTypes> PolicyTypes = TypeCache.computeTypes(
297               BaseType, Log2LMUL, Record.NF, PolicyPrototype);
298           InitRVVIntrinsic(Record, SuffixStr, OverloadedSuffixStr,
299                            /*IsMask=*/true, *PolicyTypes, MaskedHasPolicy, P);
300         }
301       } // End for different LMUL
302     } // End for different TypeRange
303   }
304 }
305 
306 // Compute name and signatures for intrinsic with practical types.
InitRVVIntrinsic(const RVVIntrinsicRecord & Record,StringRef SuffixStr,StringRef OverloadedSuffixStr,bool IsMasked,RVVTypes & Signature,bool HasPolicy,Policy PolicyAttrs)307 void RISCVIntrinsicManagerImpl::InitRVVIntrinsic(
308     const RVVIntrinsicRecord &Record, StringRef SuffixStr,
309     StringRef OverloadedSuffixStr, bool IsMasked, RVVTypes &Signature,
310     bool HasPolicy, Policy PolicyAttrs) {
311   // Function name, e.g. vadd_vv_i32m1.
312   std::string Name = Record.Name;
313   if (!SuffixStr.empty())
314     Name += "_" + SuffixStr.str();
315 
316   // Overloaded function name, e.g. vadd.
317   std::string OverloadedName;
318   if (!Record.OverloadedName)
319     OverloadedName = StringRef(Record.Name).split("_").first.str();
320   else
321     OverloadedName = Record.OverloadedName;
322   if (!OverloadedSuffixStr.empty())
323     OverloadedName += "_" + OverloadedSuffixStr.str();
324 
325   // clang built-in function name, e.g. __builtin_rvv_vadd.
326   std::string BuiltinName = "__builtin_rvv_" + std::string(Record.Name);
327 
328   RVVIntrinsic::updateNamesAndPolicy(IsMasked, HasPolicy, Name, BuiltinName,
329                                      OverloadedName, PolicyAttrs);
330 
331   // Put into IntrinsicList.
332   size_t Index = IntrinsicList.size();
333   IntrinsicList.push_back({Name, OverloadedName, BuiltinName, Signature});
334 
335   // Creating mapping to Intrinsics.
336   Intrinsics.insert({Name, Index});
337 
338   // Get the RVVOverloadIntrinsicDef.
339   RVVOverloadIntrinsicDef &OverloadIntrinsicDef =
340       OverloadIntrinsics[OverloadedName];
341 
342   // And added the index.
343   OverloadIntrinsicDef.Indexes.push_back(Index);
344 }
345 
CreateRVVIntrinsicDecl(LookupResult & LR,IdentifierInfo * II,Preprocessor & PP,unsigned Index,bool IsOverload)346 void RISCVIntrinsicManagerImpl::CreateRVVIntrinsicDecl(LookupResult &LR,
347                                                        IdentifierInfo *II,
348                                                        Preprocessor &PP,
349                                                        unsigned Index,
350                                                        bool IsOverload) {
351   ASTContext &Context = S.Context;
352   RVVIntrinsicDef &IDef = IntrinsicList[Index];
353   RVVTypes Sigs = IDef.Signature;
354   size_t SigLength = Sigs.size();
355   RVVType *ReturnType = Sigs[0];
356   QualType RetType = RVVType2Qual(Context, ReturnType);
357   SmallVector<QualType, 8> ArgTypes;
358   QualType BuiltinFuncType;
359 
360   // Skip return type, and convert RVVType to QualType for arguments.
361   for (size_t i = 1; i < SigLength; ++i)
362     ArgTypes.push_back(RVVType2Qual(Context, Sigs[i]));
363 
364   FunctionProtoType::ExtProtoInfo PI(
365       Context.getDefaultCallingConvention(false, false, true));
366 
367   PI.Variadic = false;
368 
369   SourceLocation Loc = LR.getNameLoc();
370   BuiltinFuncType = Context.getFunctionType(RetType, ArgTypes, PI);
371   DeclContext *Parent = Context.getTranslationUnitDecl();
372 
373   FunctionDecl *RVVIntrinsicDecl = FunctionDecl::Create(
374       Context, Parent, Loc, Loc, II, BuiltinFuncType, /*TInfo=*/nullptr,
375       SC_Extern, S.getCurFPFeatures().isFPConstrained(),
376       /*isInlineSpecified*/ false,
377       /*hasWrittenPrototype*/ true);
378 
379   // Create Decl objects for each parameter, adding them to the
380   // FunctionDecl.
381   const auto *FP = cast<FunctionProtoType>(BuiltinFuncType);
382   SmallVector<ParmVarDecl *, 8> ParmList;
383   for (unsigned IParm = 0, E = FP->getNumParams(); IParm != E; ++IParm) {
384     ParmVarDecl *Parm =
385         ParmVarDecl::Create(Context, RVVIntrinsicDecl, Loc, Loc, nullptr,
386                             FP->getParamType(IParm), nullptr, SC_None, nullptr);
387     Parm->setScopeInfo(0, IParm);
388     ParmList.push_back(Parm);
389   }
390   RVVIntrinsicDecl->setParams(ParmList);
391 
392   // Add function attributes.
393   if (IsOverload)
394     RVVIntrinsicDecl->addAttr(OverloadableAttr::CreateImplicit(Context));
395 
396   // Setup alias to __builtin_rvv_*
397   IdentifierInfo &IntrinsicII = PP.getIdentifierTable().get(IDef.BuiltinName);
398   RVVIntrinsicDecl->addAttr(
399       BuiltinAliasAttr::CreateImplicit(S.Context, &IntrinsicII));
400 
401   // Add to symbol table.
402   LR.addDecl(RVVIntrinsicDecl);
403 }
404 
CreateIntrinsicIfFound(LookupResult & LR,IdentifierInfo * II,Preprocessor & PP)405 bool RISCVIntrinsicManagerImpl::CreateIntrinsicIfFound(LookupResult &LR,
406                                                        IdentifierInfo *II,
407                                                        Preprocessor &PP) {
408   StringRef Name = II->getName();
409 
410   // Lookup the function name from the overload intrinsics first.
411   auto OvIItr = OverloadIntrinsics.find(Name);
412   if (OvIItr != OverloadIntrinsics.end()) {
413     const RVVOverloadIntrinsicDef &OvIntrinsicDef = OvIItr->second;
414     for (auto Index : OvIntrinsicDef.Indexes)
415       CreateRVVIntrinsicDecl(LR, II, PP, Index,
416                              /*IsOverload*/ true);
417 
418     // If we added overloads, need to resolve the lookup result.
419     LR.resolveKind();
420     return true;
421   }
422 
423   // Lookup the function name from the intrinsics.
424   auto Itr = Intrinsics.find(Name);
425   if (Itr != Intrinsics.end()) {
426     CreateRVVIntrinsicDecl(LR, II, PP, Itr->second,
427                            /*IsOverload*/ false);
428     return true;
429   }
430 
431   // It's not an RVV intrinsics.
432   return false;
433 }
434 
435 namespace clang {
436 std::unique_ptr<clang::sema::RISCVIntrinsicManager>
CreateRISCVIntrinsicManager(Sema & S)437 CreateRISCVIntrinsicManager(Sema &S) {
438   return std::make_unique<RISCVIntrinsicManagerImpl>(S);
439 }
440 } // namespace clang
441