1 //==- SemaRISCVVectorLookup.cpp - Name Lookup for RISC-V Vector Intrinsic -==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements name lookup for RISC-V vector intrinsics.
10 //
11 //===----------------------------------------------------------------------===//
12
#include "clang/AST/ASTContext.h"
#include "clang/AST/Decl.h"
#include "clang/Basic/Builtins.h"
#include "clang/Basic/TargetInfo.h"
#include "clang/Lex/Preprocessor.h"
#include "clang/Sema/Lookup.h"
#include "clang/Sema/RISCVIntrinsicManager.h"
#include "clang/Sema/Sema.h"
#include "clang/Support/RISCVVIntrinsicUtils.h"
#include "llvm/ADT/SmallVector.h"
#include <cassert>
#include <iterator>
#include <optional>
#include <string>
#include <utility>
#include <vector>
26
27 using namespace llvm;
28 using namespace clang;
29 using namespace clang::RISCV;
30
31 namespace {
32
33 // Function definition of a RVV intrinsic.
34 struct RVVIntrinsicDef {
35 /// Full function name with suffix, e.g. vadd_vv_i32m1.
36 std::string Name;
37
38 /// Overloaded function name, e.g. vadd.
39 std::string OverloadName;
40
41 /// Mapping to which clang built-in function, e.g. __builtin_rvv_vadd.
42 std::string BuiltinName;
43
44 /// Function signature, first element is return type.
45 RVVTypes Signature;
46 };
47
48 struct RVVOverloadIntrinsicDef {
49 // Indexes of RISCVIntrinsicManagerImpl::IntrinsicList.
50 SmallVector<size_t, 8> Indexes;
51 };
52
53 } // namespace
54
// Flat table of prototype descriptors. Its contents come from
// riscv_vector_builtin_sema.inc, selected by DECL_SIGNATURE_TABLE. Intrinsic
// records refer to sub-ranges of it by (index, length), see ProtoSeq2ArrayRef.
static const PrototypeDescriptor RVVSignatureTable[] = {
#define DECL_SIGNATURE_TABLE
#include "clang/Basic/riscv_vector_builtin_sema.inc"
#undef DECL_SIGNATURE_TABLE
};

// One record per RVV intrinsic definition. Its contents come from the same
// .inc file, selected by DECL_INTRINSIC_RECORDS. Each record is expanded into
// concrete intrinsics by RISCVIntrinsicManagerImpl::InitIntrinsicList.
static const RVVIntrinsicRecord RVVIntrinsicRecords[] = {
#define DECL_INTRINSIC_RECORDS
#include "clang/Basic/riscv_vector_builtin_sema.inc"
#undef DECL_INTRINSIC_RECORDS
};
66
67 // Get subsequence of signature table.
ProtoSeq2ArrayRef(uint16_t Index,uint8_t Length)68 static ArrayRef<PrototypeDescriptor> ProtoSeq2ArrayRef(uint16_t Index,
69 uint8_t Length) {
70 return ArrayRef(&RVVSignatureTable[Index], Length);
71 }
72
RVVType2Qual(ASTContext & Context,const RVVType * Type)73 static QualType RVVType2Qual(ASTContext &Context, const RVVType *Type) {
74 QualType QT;
75 switch (Type->getScalarType()) {
76 case ScalarTypeKind::Void:
77 QT = Context.VoidTy;
78 break;
79 case ScalarTypeKind::Size_t:
80 QT = Context.getSizeType();
81 break;
82 case ScalarTypeKind::Ptrdiff_t:
83 QT = Context.getPointerDiffType();
84 break;
85 case ScalarTypeKind::UnsignedLong:
86 QT = Context.UnsignedLongTy;
87 break;
88 case ScalarTypeKind::SignedLong:
89 QT = Context.LongTy;
90 break;
91 case ScalarTypeKind::Boolean:
92 QT = Context.BoolTy;
93 break;
94 case ScalarTypeKind::SignedInteger:
95 QT = Context.getIntTypeForBitwidth(Type->getElementBitwidth(), true);
96 break;
97 case ScalarTypeKind::UnsignedInteger:
98 QT = Context.getIntTypeForBitwidth(Type->getElementBitwidth(), false);
99 break;
100 case ScalarTypeKind::Float:
101 switch (Type->getElementBitwidth()) {
102 case 64:
103 QT = Context.DoubleTy;
104 break;
105 case 32:
106 QT = Context.FloatTy;
107 break;
108 case 16:
109 QT = Context.Float16Ty;
110 break;
111 default:
112 llvm_unreachable("Unsupported floating point width.");
113 }
114 break;
115 case Invalid:
116 llvm_unreachable("Unhandled type.");
117 }
118 if (Type->isVector())
119 QT = Context.getScalableVectorType(QT, *Type->getScale());
120
121 if (Type->isConstant())
122 QT = Context.getConstType(QT);
123
124 // Transform the type to a pointer as the last step, if necessary.
125 if (Type->isPointer())
126 QT = Context.getPointerType(QT);
127
128 return QT;
129 }
130
131 namespace {
132 class RISCVIntrinsicManagerImpl : public sema::RISCVIntrinsicManager {
133 private:
134 Sema &S;
135 ASTContext &Context;
136 RVVTypeCache TypeCache;
137
138 // List of all RVV intrinsic.
139 std::vector<RVVIntrinsicDef> IntrinsicList;
140 // Mapping function name to index of IntrinsicList.
141 StringMap<size_t> Intrinsics;
142 // Mapping function name to RVVOverloadIntrinsicDef.
143 StringMap<RVVOverloadIntrinsicDef> OverloadIntrinsics;
144
145 // Create IntrinsicList
146 void InitIntrinsicList();
147
148 // Create RVVIntrinsicDef.
149 void InitRVVIntrinsic(const RVVIntrinsicRecord &Record, StringRef SuffixStr,
150 StringRef OverloadedSuffixStr, bool IsMask,
151 RVVTypes &Types, bool HasPolicy, Policy PolicyAttrs);
152
153 // Create FunctionDecl for a vector intrinsic.
154 void CreateRVVIntrinsicDecl(LookupResult &LR, IdentifierInfo *II,
155 Preprocessor &PP, unsigned Index,
156 bool IsOverload);
157
158 public:
RISCVIntrinsicManagerImpl(clang::Sema & S)159 RISCVIntrinsicManagerImpl(clang::Sema &S) : S(S), Context(S.Context) {
160 InitIntrinsicList();
161 }
162
163 // Create RISC-V vector intrinsic and insert into symbol table if found, and
164 // return true, otherwise return false.
165 bool CreateIntrinsicIfFound(LookupResult &LR, IdentifierInfo *II,
166 Preprocessor &PP) override;
167 };
168 } // namespace
169
InitIntrinsicList()170 void RISCVIntrinsicManagerImpl::InitIntrinsicList() {
171 const TargetInfo &TI = Context.getTargetInfo();
172 bool HasVectorFloat32 = TI.hasFeature("zve32f");
173 bool HasVectorFloat64 = TI.hasFeature("zve64d");
174 bool HasZvfh = TI.hasFeature("experimental-zvfh");
175 bool HasRV64 = TI.hasFeature("64bit");
176 bool HasFullMultiply = TI.hasFeature("v");
177
178 // Construction of RVVIntrinsicRecords need to sync with createRVVIntrinsics
179 // in RISCVVEmitter.cpp.
180 for (auto &Record : RVVIntrinsicRecords) {
181 // Create Intrinsics for each type and LMUL.
182 BasicType BaseType = BasicType::Unknown;
183 ArrayRef<PrototypeDescriptor> BasicProtoSeq =
184 ProtoSeq2ArrayRef(Record.PrototypeIndex, Record.PrototypeLength);
185 ArrayRef<PrototypeDescriptor> SuffixProto =
186 ProtoSeq2ArrayRef(Record.SuffixIndex, Record.SuffixLength);
187 ArrayRef<PrototypeDescriptor> OverloadedSuffixProto = ProtoSeq2ArrayRef(
188 Record.OverloadedSuffixIndex, Record.OverloadedSuffixSize);
189
190 PolicyScheme UnMaskedPolicyScheme =
191 static_cast<PolicyScheme>(Record.UnMaskedPolicyScheme);
192 PolicyScheme MaskedPolicyScheme =
193 static_cast<PolicyScheme>(Record.MaskedPolicyScheme);
194
195 const Policy DefaultPolicy;
196
197 llvm::SmallVector<PrototypeDescriptor> ProtoSeq =
198 RVVIntrinsic::computeBuiltinTypes(BasicProtoSeq, /*IsMasked=*/false,
199 /*HasMaskedOffOperand=*/false,
200 Record.HasVL, Record.NF,
201 UnMaskedPolicyScheme, DefaultPolicy);
202
203 llvm::SmallVector<PrototypeDescriptor> ProtoMaskSeq =
204 RVVIntrinsic::computeBuiltinTypes(
205 BasicProtoSeq, /*IsMasked=*/true, Record.HasMaskedOffOperand,
206 Record.HasVL, Record.NF, MaskedPolicyScheme, DefaultPolicy);
207
208 bool UnMaskedHasPolicy = UnMaskedPolicyScheme != PolicyScheme::SchemeNone;
209 bool MaskedHasPolicy = MaskedPolicyScheme != PolicyScheme::SchemeNone;
210 SmallVector<Policy> SupportedUnMaskedPolicies =
211 RVVIntrinsic::getSupportedUnMaskedPolicies();
212 SmallVector<Policy> SupportedMaskedPolicies =
213 RVVIntrinsic::getSupportedMaskedPolicies(Record.HasTailPolicy,
214 Record.HasMaskPolicy);
215
216 for (unsigned int TypeRangeMaskShift = 0;
217 TypeRangeMaskShift <= static_cast<unsigned int>(BasicType::MaxOffset);
218 ++TypeRangeMaskShift) {
219 unsigned int BaseTypeI = 1 << TypeRangeMaskShift;
220 BaseType = static_cast<BasicType>(BaseTypeI);
221
222 if ((BaseTypeI & Record.TypeRangeMask) != BaseTypeI)
223 continue;
224
225 // Check requirement.
226 if (BaseType == BasicType::Float16 && !HasZvfh)
227 continue;
228
229 if (BaseType == BasicType::Float32 && !HasVectorFloat32)
230 continue;
231
232 if (BaseType == BasicType::Float64 && !HasVectorFloat64)
233 continue;
234
235 if (((Record.RequiredExtensions & RVV_REQ_RV64) == RVV_REQ_RV64) &&
236 !HasRV64)
237 continue;
238
239 if ((BaseType == BasicType::Int64) &&
240 ((Record.RequiredExtensions & RVV_REQ_FullMultiply) ==
241 RVV_REQ_FullMultiply) &&
242 !HasFullMultiply)
243 continue;
244
245 // Expanded with different LMUL.
246 for (int Log2LMUL = -3; Log2LMUL <= 3; Log2LMUL++) {
247 if (!(Record.Log2LMULMask & (1 << (Log2LMUL + 3))))
248 continue;
249
250 std::optional<RVVTypes> Types =
251 TypeCache.computeTypes(BaseType, Log2LMUL, Record.NF, ProtoSeq);
252
253 // Ignored to create new intrinsic if there are any illegal types.
254 if (!Types.has_value())
255 continue;
256
257 std::string SuffixStr = RVVIntrinsic::getSuffixStr(
258 TypeCache, BaseType, Log2LMUL, SuffixProto);
259 std::string OverloadedSuffixStr = RVVIntrinsic::getSuffixStr(
260 TypeCache, BaseType, Log2LMUL, OverloadedSuffixProto);
261
262 // Create non-masked intrinsic.
263 InitRVVIntrinsic(Record, SuffixStr, OverloadedSuffixStr, false, *Types,
264 UnMaskedHasPolicy, DefaultPolicy);
265
266 // Create non-masked policy intrinsic.
267 if (Record.UnMaskedPolicyScheme != PolicyScheme::SchemeNone) {
268 for (auto P : SupportedUnMaskedPolicies) {
269 llvm::SmallVector<PrototypeDescriptor> PolicyPrototype =
270 RVVIntrinsic::computeBuiltinTypes(
271 BasicProtoSeq, /*IsMasked=*/false,
272 /*HasMaskedOffOperand=*/false, Record.HasVL, Record.NF,
273 UnMaskedPolicyScheme, P);
274 std::optional<RVVTypes> PolicyTypes = TypeCache.computeTypes(
275 BaseType, Log2LMUL, Record.NF, PolicyPrototype);
276 InitRVVIntrinsic(Record, SuffixStr, OverloadedSuffixStr,
277 /*IsMask=*/false, *PolicyTypes, UnMaskedHasPolicy,
278 P);
279 }
280 }
281 if (!Record.HasMasked)
282 continue;
283 // Create masked intrinsic.
284 std::optional<RVVTypes> MaskTypes =
285 TypeCache.computeTypes(BaseType, Log2LMUL, Record.NF, ProtoMaskSeq);
286 InitRVVIntrinsic(Record, SuffixStr, OverloadedSuffixStr, true,
287 *MaskTypes, MaskedHasPolicy, DefaultPolicy);
288 if (Record.MaskedPolicyScheme == PolicyScheme::SchemeNone)
289 continue;
290 // Create masked policy intrinsic.
291 for (auto P : SupportedMaskedPolicies) {
292 llvm::SmallVector<PrototypeDescriptor> PolicyPrototype =
293 RVVIntrinsic::computeBuiltinTypes(
294 BasicProtoSeq, /*IsMasked=*/true, Record.HasMaskedOffOperand,
295 Record.HasVL, Record.NF, MaskedPolicyScheme, P);
296 std::optional<RVVTypes> PolicyTypes = TypeCache.computeTypes(
297 BaseType, Log2LMUL, Record.NF, PolicyPrototype);
298 InitRVVIntrinsic(Record, SuffixStr, OverloadedSuffixStr,
299 /*IsMask=*/true, *PolicyTypes, MaskedHasPolicy, P);
300 }
301 } // End for different LMUL
302 } // End for different TypeRange
303 }
304 }
305
306 // Compute name and signatures for intrinsic with practical types.
InitRVVIntrinsic(const RVVIntrinsicRecord & Record,StringRef SuffixStr,StringRef OverloadedSuffixStr,bool IsMasked,RVVTypes & Signature,bool HasPolicy,Policy PolicyAttrs)307 void RISCVIntrinsicManagerImpl::InitRVVIntrinsic(
308 const RVVIntrinsicRecord &Record, StringRef SuffixStr,
309 StringRef OverloadedSuffixStr, bool IsMasked, RVVTypes &Signature,
310 bool HasPolicy, Policy PolicyAttrs) {
311 // Function name, e.g. vadd_vv_i32m1.
312 std::string Name = Record.Name;
313 if (!SuffixStr.empty())
314 Name += "_" + SuffixStr.str();
315
316 // Overloaded function name, e.g. vadd.
317 std::string OverloadedName;
318 if (!Record.OverloadedName)
319 OverloadedName = StringRef(Record.Name).split("_").first.str();
320 else
321 OverloadedName = Record.OverloadedName;
322 if (!OverloadedSuffixStr.empty())
323 OverloadedName += "_" + OverloadedSuffixStr.str();
324
325 // clang built-in function name, e.g. __builtin_rvv_vadd.
326 std::string BuiltinName = "__builtin_rvv_" + std::string(Record.Name);
327
328 RVVIntrinsic::updateNamesAndPolicy(IsMasked, HasPolicy, Name, BuiltinName,
329 OverloadedName, PolicyAttrs);
330
331 // Put into IntrinsicList.
332 size_t Index = IntrinsicList.size();
333 IntrinsicList.push_back({Name, OverloadedName, BuiltinName, Signature});
334
335 // Creating mapping to Intrinsics.
336 Intrinsics.insert({Name, Index});
337
338 // Get the RVVOverloadIntrinsicDef.
339 RVVOverloadIntrinsicDef &OverloadIntrinsicDef =
340 OverloadIntrinsics[OverloadedName];
341
342 // And added the index.
343 OverloadIntrinsicDef.Indexes.push_back(Index);
344 }
345
CreateRVVIntrinsicDecl(LookupResult & LR,IdentifierInfo * II,Preprocessor & PP,unsigned Index,bool IsOverload)346 void RISCVIntrinsicManagerImpl::CreateRVVIntrinsicDecl(LookupResult &LR,
347 IdentifierInfo *II,
348 Preprocessor &PP,
349 unsigned Index,
350 bool IsOverload) {
351 ASTContext &Context = S.Context;
352 RVVIntrinsicDef &IDef = IntrinsicList[Index];
353 RVVTypes Sigs = IDef.Signature;
354 size_t SigLength = Sigs.size();
355 RVVType *ReturnType = Sigs[0];
356 QualType RetType = RVVType2Qual(Context, ReturnType);
357 SmallVector<QualType, 8> ArgTypes;
358 QualType BuiltinFuncType;
359
360 // Skip return type, and convert RVVType to QualType for arguments.
361 for (size_t i = 1; i < SigLength; ++i)
362 ArgTypes.push_back(RVVType2Qual(Context, Sigs[i]));
363
364 FunctionProtoType::ExtProtoInfo PI(
365 Context.getDefaultCallingConvention(false, false, true));
366
367 PI.Variadic = false;
368
369 SourceLocation Loc = LR.getNameLoc();
370 BuiltinFuncType = Context.getFunctionType(RetType, ArgTypes, PI);
371 DeclContext *Parent = Context.getTranslationUnitDecl();
372
373 FunctionDecl *RVVIntrinsicDecl = FunctionDecl::Create(
374 Context, Parent, Loc, Loc, II, BuiltinFuncType, /*TInfo=*/nullptr,
375 SC_Extern, S.getCurFPFeatures().isFPConstrained(),
376 /*isInlineSpecified*/ false,
377 /*hasWrittenPrototype*/ true);
378
379 // Create Decl objects for each parameter, adding them to the
380 // FunctionDecl.
381 const auto *FP = cast<FunctionProtoType>(BuiltinFuncType);
382 SmallVector<ParmVarDecl *, 8> ParmList;
383 for (unsigned IParm = 0, E = FP->getNumParams(); IParm != E; ++IParm) {
384 ParmVarDecl *Parm =
385 ParmVarDecl::Create(Context, RVVIntrinsicDecl, Loc, Loc, nullptr,
386 FP->getParamType(IParm), nullptr, SC_None, nullptr);
387 Parm->setScopeInfo(0, IParm);
388 ParmList.push_back(Parm);
389 }
390 RVVIntrinsicDecl->setParams(ParmList);
391
392 // Add function attributes.
393 if (IsOverload)
394 RVVIntrinsicDecl->addAttr(OverloadableAttr::CreateImplicit(Context));
395
396 // Setup alias to __builtin_rvv_*
397 IdentifierInfo &IntrinsicII = PP.getIdentifierTable().get(IDef.BuiltinName);
398 RVVIntrinsicDecl->addAttr(
399 BuiltinAliasAttr::CreateImplicit(S.Context, &IntrinsicII));
400
401 // Add to symbol table.
402 LR.addDecl(RVVIntrinsicDecl);
403 }
404
CreateIntrinsicIfFound(LookupResult & LR,IdentifierInfo * II,Preprocessor & PP)405 bool RISCVIntrinsicManagerImpl::CreateIntrinsicIfFound(LookupResult &LR,
406 IdentifierInfo *II,
407 Preprocessor &PP) {
408 StringRef Name = II->getName();
409
410 // Lookup the function name from the overload intrinsics first.
411 auto OvIItr = OverloadIntrinsics.find(Name);
412 if (OvIItr != OverloadIntrinsics.end()) {
413 const RVVOverloadIntrinsicDef &OvIntrinsicDef = OvIItr->second;
414 for (auto Index : OvIntrinsicDef.Indexes)
415 CreateRVVIntrinsicDecl(LR, II, PP, Index,
416 /*IsOverload*/ true);
417
418 // If we added overloads, need to resolve the lookup result.
419 LR.resolveKind();
420 return true;
421 }
422
423 // Lookup the function name from the intrinsics.
424 auto Itr = Intrinsics.find(Name);
425 if (Itr != Intrinsics.end()) {
426 CreateRVVIntrinsicDecl(LR, II, PP, Itr->second,
427 /*IsOverload*/ false);
428 return true;
429 }
430
431 // It's not an RVV intrinsics.
432 return false;
433 }
434
435 namespace clang {
436 std::unique_ptr<clang::sema::RISCVIntrinsicManager>
CreateRISCVIntrinsicManager(Sema & S)437 CreateRISCVIntrinsicManager(Sema &S) {
438 return std::make_unique<RISCVIntrinsicManagerImpl>(S);
439 }
440 } // namespace clang
441