1 //===--- HLSLExternalSemaSource.cpp - HLSL Sema Source --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
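//
// This file implements HLSLExternalSemaSource, which injects the implicit HLSL
// builtin declarations into Sema: the `hlsl` namespace (plus a `using
// namespace hlsl` directive), the `vector` alias template, the `Resource`
// record, and the lazily completed `RWBuffer` class template.
//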
10 //===----------------------------------------------------------------------===//
11 
12 #include "clang/Sema/HLSLExternalSemaSource.h"
13 #include "clang/AST/ASTContext.h"
14 #include "clang/AST/Attr.h"
15 #include "clang/AST/DeclCXX.h"
16 #include "clang/Basic/AttrKinds.h"
17 #include "clang/Basic/HLSLRuntime.h"
18 #include "clang/Sema/Lookup.h"
19 #include "clang/Sema/Sema.h"
20 #include "llvm/Frontend/HLSL/HLSLResource.h"
21 
22 #include <functional>
23 
24 using namespace clang;
25 using namespace llvm::hlsl;
26 
27 namespace {
28 
29 struct TemplateParameterListBuilder;
30 
31 struct BuiltinTypeDeclBuilder {
32   CXXRecordDecl *Record = nullptr;
33   ClassTemplateDecl *Template = nullptr;
34   ClassTemplateDecl *PrevTemplate = nullptr;
35   NamespaceDecl *HLSLNamespace = nullptr;
36   llvm::StringMap<FieldDecl *> Fields;
37 
38   BuiltinTypeDeclBuilder(CXXRecordDecl *R) : Record(R) {
39     Record->startDefinition();
40     Template = Record->getDescribedClassTemplate();
41   }
42 
43   BuiltinTypeDeclBuilder(Sema &S, NamespaceDecl *Namespace, StringRef Name)
44       : HLSLNamespace(Namespace) {
45     ASTContext &AST = S.getASTContext();
46     IdentifierInfo &II = AST.Idents.get(Name, tok::TokenKind::identifier);
47 
48     LookupResult Result(S, &II, SourceLocation(), Sema::LookupTagName);
49     CXXRecordDecl *PrevDecl = nullptr;
50     if (S.LookupQualifiedName(Result, HLSLNamespace)) {
51       NamedDecl *Found = Result.getFoundDecl();
52       if (auto *TD = dyn_cast<ClassTemplateDecl>(Found)) {
53         PrevDecl = TD->getTemplatedDecl();
54         PrevTemplate = TD;
55       } else
56         PrevDecl = dyn_cast<CXXRecordDecl>(Found);
57       assert(PrevDecl && "Unexpected lookup result type.");
58     }
59 
60     if (PrevDecl && PrevDecl->isCompleteDefinition()) {
61       Record = PrevDecl;
62       return;
63     }
64 
65     Record = CXXRecordDecl::Create(AST, TagDecl::TagKind::TTK_Class,
66                                    HLSLNamespace, SourceLocation(),
67                                    SourceLocation(), &II, PrevDecl, true);
68     Record->setImplicit(true);
69     Record->setLexicalDeclContext(HLSLNamespace);
70     Record->setHasExternalLexicalStorage();
71 
72     // Don't let anyone derive from built-in types.
73     Record->addAttr(FinalAttr::CreateImplicit(AST, SourceRange(),
74                                               FinalAttr::Keyword_final));
75   }
76 
77   ~BuiltinTypeDeclBuilder() {
78     if (HLSLNamespace && !Template && Record->getDeclContext() == HLSLNamespace)
79       HLSLNamespace->addDecl(Record);
80   }
81 
82   BuiltinTypeDeclBuilder &
83   addMemberVariable(StringRef Name, QualType Type,
84                     AccessSpecifier Access = AccessSpecifier::AS_private) {
85     if (Record->isCompleteDefinition())
86       return *this;
87     assert(Record->isBeingDefined() &&
88            "Definition must be started before adding members!");
89     ASTContext &AST = Record->getASTContext();
90 
91     IdentifierInfo &II = AST.Idents.get(Name, tok::TokenKind::identifier);
92     TypeSourceInfo *MemTySource =
93         AST.getTrivialTypeSourceInfo(Type, SourceLocation());
94     auto *Field = FieldDecl::Create(
95         AST, Record, SourceLocation(), SourceLocation(), &II, Type, MemTySource,
96         nullptr, false, InClassInitStyle::ICIS_NoInit);
97     Field->setAccess(Access);
98     Field->setImplicit(true);
99     Record->addDecl(Field);
100     Fields[Name] = Field;
101     return *this;
102   }
103 
104   BuiltinTypeDeclBuilder &
105   addHandleMember(AccessSpecifier Access = AccessSpecifier::AS_private) {
106     if (Record->isCompleteDefinition())
107       return *this;
108     QualType Ty = Record->getASTContext().VoidPtrTy;
109     if (Template) {
110       if (const auto *TTD = dyn_cast<TemplateTypeParmDecl>(
111               Template->getTemplateParameters()->getParam(0)))
112         Ty = Record->getASTContext().getPointerType(
113             QualType(TTD->getTypeForDecl(), 0));
114     }
115     return addMemberVariable("h", Ty, Access);
116   }
117 
118   BuiltinTypeDeclBuilder &
119   annotateResourceClass(HLSLResourceAttr::ResourceClass RC,
120                         HLSLResourceAttr::ResourceKind RK) {
121     if (Record->isCompleteDefinition())
122       return *this;
123     Record->addAttr(
124         HLSLResourceAttr::CreateImplicit(Record->getASTContext(), RC, RK));
125     return *this;
126   }
127 
128   static DeclRefExpr *lookupBuiltinFunction(ASTContext &AST, Sema &S,
129                                             StringRef Name) {
130     CXXScopeSpec SS;
131     IdentifierInfo &II = AST.Idents.get(Name, tok::TokenKind::identifier);
132     DeclarationNameInfo NameInfo =
133         DeclarationNameInfo(DeclarationName(&II), SourceLocation());
134     LookupResult R(S, NameInfo, Sema::LookupOrdinaryName);
135     S.LookupParsedName(R, S.getCurScope(), &SS, false);
136     assert(R.isSingleResult() &&
137            "Since this is a builtin it should always resolve!");
138     auto *VD = cast<ValueDecl>(R.getFoundDecl());
139     QualType Ty = VD->getType();
140     return DeclRefExpr::Create(AST, NestedNameSpecifierLoc(), SourceLocation(),
141                                VD, false, NameInfo, Ty, VK_PRValue);
142   }
143 
144   static Expr *emitResourceClassExpr(ASTContext &AST, ResourceClass RC) {
145     return IntegerLiteral::Create(
146         AST,
147         llvm::APInt(AST.getIntWidth(AST.UnsignedCharTy),
148                     static_cast<uint8_t>(RC)),
149         AST.UnsignedCharTy, SourceLocation());
150   }
151 
152   BuiltinTypeDeclBuilder &addDefaultHandleConstructor(Sema &S,
153                                                       ResourceClass RC) {
154     if (Record->isCompleteDefinition())
155       return *this;
156     ASTContext &AST = Record->getASTContext();
157 
158     QualType ConstructorType =
159         AST.getFunctionType(AST.VoidTy, {}, FunctionProtoType::ExtProtoInfo());
160 
161     CanQualType CanTy = Record->getTypeForDecl()->getCanonicalTypeUnqualified();
162     DeclarationName Name = AST.DeclarationNames.getCXXConstructorName(CanTy);
163     CXXConstructorDecl *Constructor = CXXConstructorDecl::Create(
164         AST, Record, SourceLocation(),
165         DeclarationNameInfo(Name, SourceLocation()), ConstructorType,
166         AST.getTrivialTypeSourceInfo(ConstructorType, SourceLocation()),
167         ExplicitSpecifier(), false, true, false,
168         ConstexprSpecKind::Unspecified);
169 
170     DeclRefExpr *Fn =
171         lookupBuiltinFunction(AST, S, "__builtin_hlsl_create_handle");
172 
173     Expr *RCExpr = emitResourceClassExpr(AST, RC);
174     Expr *Call = CallExpr::Create(AST, Fn, {RCExpr}, AST.VoidPtrTy, VK_PRValue,
175                                   SourceLocation(), FPOptionsOverride());
176 
177     CXXThisExpr *This = new (AST) CXXThisExpr(
178         SourceLocation(),
179         Constructor->getThisType().getTypePtr()->getPointeeType(), true);
180     This->setValueKind(ExprValueKind::VK_LValue);
181     Expr *Handle = MemberExpr::CreateImplicit(AST, This, false, Fields["h"],
182                                               Fields["h"]->getType(), VK_LValue,
183                                               OK_Ordinary);
184 
185     // If the handle isn't a void pointer, cast the builtin result to the
186     // correct type.
187     if (Handle->getType().getCanonicalType() != AST.VoidPtrTy) {
188       Call = CXXStaticCastExpr::Create(
189           AST, Handle->getType(), VK_PRValue, CK_Dependent, Call, nullptr,
190           AST.getTrivialTypeSourceInfo(Handle->getType(), SourceLocation()),
191           FPOptionsOverride(), SourceLocation(), SourceLocation(),
192           SourceRange());
193     }
194 
195     BinaryOperator *Assign = BinaryOperator::Create(
196         AST, Handle, Call, BO_Assign, Handle->getType(), VK_LValue, OK_Ordinary,
197         SourceLocation(), FPOptionsOverride());
198 
199     Constructor->setBody(
200         CompoundStmt::Create(AST, {Assign}, FPOptionsOverride(),
201                              SourceLocation(), SourceLocation()));
202     Constructor->setAccess(AccessSpecifier::AS_public);
203     Record->addDecl(Constructor);
204     return *this;
205   }
206 
207   BuiltinTypeDeclBuilder &addArraySubscriptOperators() {
208     if (Record->isCompleteDefinition())
209       return *this;
210     addArraySubscriptOperator(true);
211     addArraySubscriptOperator(false);
212     return *this;
213   }
214 
215   BuiltinTypeDeclBuilder &addArraySubscriptOperator(bool IsConst) {
216     if (Record->isCompleteDefinition())
217       return *this;
218     assert(Fields.count("h") > 0 &&
219            "Subscript operator must be added after the handle.");
220 
221     FieldDecl *Handle = Fields["h"];
222     ASTContext &AST = Record->getASTContext();
223 
224     assert(Handle->getType().getCanonicalType() != AST.VoidPtrTy &&
225            "Not yet supported for void pointer handles.");
226 
227     QualType ElemTy =
228         QualType(Handle->getType()->getPointeeOrArrayElementType(), 0);
229     QualType ReturnTy = ElemTy;
230 
231     FunctionProtoType::ExtProtoInfo ExtInfo;
232 
233     // Subscript operators return references to elements, const makes the
234     // reference and method const so that the underlying data is not mutable.
235     ReturnTy = AST.getLValueReferenceType(ReturnTy);
236     if (IsConst) {
237       ExtInfo.TypeQuals.addConst();
238       ReturnTy.addConst();
239     }
240 
241     QualType MethodTy =
242         AST.getFunctionType(ReturnTy, {AST.UnsignedIntTy}, ExtInfo);
243     auto *TSInfo = AST.getTrivialTypeSourceInfo(MethodTy, SourceLocation());
244     auto *MethodDecl = CXXMethodDecl::Create(
245         AST, Record, SourceLocation(),
246         DeclarationNameInfo(
247             AST.DeclarationNames.getCXXOperatorName(OO_Subscript),
248             SourceLocation()),
249         MethodTy, TSInfo, SC_None, false, false, ConstexprSpecKind::Unspecified,
250         SourceLocation());
251 
252     IdentifierInfo &II = AST.Idents.get("Idx", tok::TokenKind::identifier);
253     auto *IdxParam = ParmVarDecl::Create(
254         AST, MethodDecl->getDeclContext(), SourceLocation(), SourceLocation(),
255         &II, AST.UnsignedIntTy,
256         AST.getTrivialTypeSourceInfo(AST.UnsignedIntTy, SourceLocation()),
257         SC_None, nullptr);
258     MethodDecl->setParams({IdxParam});
259 
260     // Also add the parameter to the function prototype.
261     auto FnProtoLoc = TSInfo->getTypeLoc().getAs<FunctionProtoTypeLoc>();
262     FnProtoLoc.setParam(0, IdxParam);
263 
264     auto *This = new (AST) CXXThisExpr(
265         SourceLocation(),
266         MethodDecl->getThisType().getTypePtr()->getPointeeType(), true);
267     This->setValueKind(ExprValueKind::VK_LValue);
268     auto *HandleAccess = MemberExpr::CreateImplicit(
269         AST, This, false, Handle, Handle->getType(), VK_LValue, OK_Ordinary);
270 
271     auto *IndexExpr = DeclRefExpr::Create(
272         AST, NestedNameSpecifierLoc(), SourceLocation(), IdxParam, false,
273         DeclarationNameInfo(IdxParam->getDeclName(), SourceLocation()),
274         AST.UnsignedIntTy, VK_PRValue);
275 
276     auto *Array =
277         new (AST) ArraySubscriptExpr(HandleAccess, IndexExpr, ElemTy, VK_LValue,
278                                      OK_Ordinary, SourceLocation());
279 
280     auto *Return = ReturnStmt::Create(AST, SourceLocation(), Array, nullptr);
281 
282     MethodDecl->setBody(CompoundStmt::Create(AST, {Return}, FPOptionsOverride(),
283                                              SourceLocation(),
284                                              SourceLocation()));
285     MethodDecl->setLexicalDeclContext(Record);
286     MethodDecl->setAccess(AccessSpecifier::AS_public);
287     MethodDecl->addAttr(AlwaysInlineAttr::CreateImplicit(
288         AST, SourceRange(), AlwaysInlineAttr::CXX11_clang_always_inline));
289     Record->addDecl(MethodDecl);
290 
291     return *this;
292   }
293 
294   BuiltinTypeDeclBuilder &startDefinition() {
295     if (Record->isCompleteDefinition())
296       return *this;
297     Record->startDefinition();
298     return *this;
299   }
300 
301   BuiltinTypeDeclBuilder &completeDefinition() {
302     if (Record->isCompleteDefinition())
303       return *this;
304     assert(Record->isBeingDefined() &&
305            "Definition must be started before completing it.");
306 
307     Record->completeDefinition();
308     return *this;
309   }
310 
311   TemplateParameterListBuilder addTemplateArgumentList();
312 };
313 
314 struct TemplateParameterListBuilder {
315   BuiltinTypeDeclBuilder &Builder;
316   ASTContext &AST;
317   llvm::SmallVector<NamedDecl *> Params;
318 
319   TemplateParameterListBuilder(BuiltinTypeDeclBuilder &RB)
320       : Builder(RB), AST(RB.Record->getASTContext()) {}
321 
322   ~TemplateParameterListBuilder() { finalizeTemplateArgs(); }
323 
324   TemplateParameterListBuilder &
325   addTypeParameter(StringRef Name, QualType DefaultValue = QualType()) {
326     if (Builder.Record->isCompleteDefinition())
327       return *this;
328     unsigned Position = static_cast<unsigned>(Params.size());
329     auto *Decl = TemplateTypeParmDecl::Create(
330         AST, Builder.Record->getDeclContext(), SourceLocation(),
331         SourceLocation(), /* TemplateDepth */ 0, Position,
332         &AST.Idents.get(Name, tok::TokenKind::identifier), /* Typename */ false,
333         /* ParameterPack */ false);
334     if (!DefaultValue.isNull())
335       Decl->setDefaultArgument(AST.getTrivialTypeSourceInfo(DefaultValue));
336 
337     Params.emplace_back(Decl);
338     return *this;
339   }
340 
341   BuiltinTypeDeclBuilder &finalizeTemplateArgs() {
342     if (Params.empty())
343       return Builder;
344     auto *ParamList =
345         TemplateParameterList::Create(AST, SourceLocation(), SourceLocation(),
346                                       Params, SourceLocation(), nullptr);
347     Builder.Template = ClassTemplateDecl::Create(
348         AST, Builder.Record->getDeclContext(), SourceLocation(),
349         DeclarationName(Builder.Record->getIdentifier()), ParamList,
350         Builder.Record);
351     Builder.Record->setDescribedClassTemplate(Builder.Template);
352     Builder.Template->setImplicit(true);
353     Builder.Template->setLexicalDeclContext(Builder.Record->getDeclContext());
354     // NOTE: setPreviousDecl before addDecl so new decl replace old decl when
355     // make visible.
356     Builder.Template->setPreviousDecl(Builder.PrevTemplate);
357     Builder.Record->getDeclContext()->addDecl(Builder.Template);
358     Params.clear();
359 
360     QualType T = Builder.Template->getInjectedClassNameSpecialization();
361     T = AST.getInjectedClassNameType(Builder.Record, T);
362 
363     return Builder;
364   }
365 };
366 
367 TemplateParameterListBuilder BuiltinTypeDeclBuilder::addTemplateArgumentList() {
368   return TemplateParameterListBuilder(*this);
369 }
370 } // namespace
371 
372 HLSLExternalSemaSource::~HLSLExternalSemaSource() {}
373 
374 void HLSLExternalSemaSource::InitializeSema(Sema &S) {
375   SemaPtr = &S;
376   ASTContext &AST = SemaPtr->getASTContext();
377   // If the translation unit has external storage force external decls to load.
378   if (AST.getTranslationUnitDecl()->hasExternalLexicalStorage())
379     (void)AST.getTranslationUnitDecl()->decls_begin();
380 
381   IdentifierInfo &HLSL = AST.Idents.get("hlsl", tok::TokenKind::identifier);
382   LookupResult Result(S, &HLSL, SourceLocation(), Sema::LookupNamespaceName);
383   NamespaceDecl *PrevDecl = nullptr;
384   if (S.LookupQualifiedName(Result, AST.getTranslationUnitDecl()))
385     PrevDecl = Result.getAsSingle<NamespaceDecl>();
386   HLSLNamespace = NamespaceDecl::Create(
387       AST, AST.getTranslationUnitDecl(), /*Inline=*/false, SourceLocation(),
388       SourceLocation(), &HLSL, PrevDecl, /*Nested=*/false);
389   HLSLNamespace->setImplicit(true);
390   HLSLNamespace->setHasExternalLexicalStorage();
391   AST.getTranslationUnitDecl()->addDecl(HLSLNamespace);
392 
393   // Force external decls in the HLSL namespace to load from the PCH.
394   (void)HLSLNamespace->getCanonicalDecl()->decls_begin();
395   defineTrivialHLSLTypes();
396   forwardDeclareHLSLTypes();
397 
398   // This adds a `using namespace hlsl` directive. In DXC, we don't put HLSL's
399   // built in types inside a namespace, but we are planning to change that in
400   // the near future. In order to be source compatible older versions of HLSL
401   // will need to implicitly use the hlsl namespace. For now in clang everything
402   // will get added to the namespace, and we can remove the using directive for
403   // future language versions to match HLSL's evolution.
404   auto *UsingDecl = UsingDirectiveDecl::Create(
405       AST, AST.getTranslationUnitDecl(), SourceLocation(), SourceLocation(),
406       NestedNameSpecifierLoc(), SourceLocation(), HLSLNamespace,
407       AST.getTranslationUnitDecl());
408 
409   AST.getTranslationUnitDecl()->addDecl(UsingDecl);
410 }
411 
412 void HLSLExternalSemaSource::defineHLSLVectorAlias() {
413   ASTContext &AST = SemaPtr->getASTContext();
414 
415   llvm::SmallVector<NamedDecl *> TemplateParams;
416 
417   auto *TypeParam = TemplateTypeParmDecl::Create(
418       AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 0,
419       &AST.Idents.get("element", tok::TokenKind::identifier), false, false);
420   TypeParam->setDefaultArgument(AST.getTrivialTypeSourceInfo(AST.FloatTy));
421 
422   TemplateParams.emplace_back(TypeParam);
423 
424   auto *SizeParam = NonTypeTemplateParmDecl::Create(
425       AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 1,
426       &AST.Idents.get("element_count", tok::TokenKind::identifier), AST.IntTy,
427       false, AST.getTrivialTypeSourceInfo(AST.IntTy));
428   Expr *LiteralExpr =
429       IntegerLiteral::Create(AST, llvm::APInt(AST.getIntWidth(AST.IntTy), 4),
430                              AST.IntTy, SourceLocation());
431   SizeParam->setDefaultArgument(LiteralExpr);
432   TemplateParams.emplace_back(SizeParam);
433 
434   auto *ParamList =
435       TemplateParameterList::Create(AST, SourceLocation(), SourceLocation(),
436                                     TemplateParams, SourceLocation(), nullptr);
437 
438   IdentifierInfo &II = AST.Idents.get("vector", tok::TokenKind::identifier);
439 
440   QualType AliasType = AST.getDependentSizedExtVectorType(
441       AST.getTemplateTypeParmType(0, 0, false, TypeParam),
442       DeclRefExpr::Create(
443           AST, NestedNameSpecifierLoc(), SourceLocation(), SizeParam, false,
444           DeclarationNameInfo(SizeParam->getDeclName(), SourceLocation()),
445           AST.IntTy, VK_LValue),
446       SourceLocation());
447 
448   auto *Record = TypeAliasDecl::Create(AST, HLSLNamespace, SourceLocation(),
449                                        SourceLocation(), &II,
450                                        AST.getTrivialTypeSourceInfo(AliasType));
451   Record->setImplicit(true);
452 
453   auto *Template =
454       TypeAliasTemplateDecl::Create(AST, HLSLNamespace, SourceLocation(),
455                                     Record->getIdentifier(), ParamList, Record);
456 
457   Record->setDescribedAliasTemplate(Template);
458   Template->setImplicit(true);
459   Template->setLexicalDeclContext(Record->getDeclContext());
460   HLSLNamespace->addDecl(Template);
461 }
462 
463 void HLSLExternalSemaSource::defineTrivialHLSLTypes() {
464   defineHLSLVectorAlias();
465 
466   ResourceDecl = BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "Resource")
467                      .startDefinition()
468                      .addHandleMember(AccessSpecifier::AS_public)
469                      .completeDefinition()
470                      .Record;
471 }
472 
473 void HLSLExternalSemaSource::forwardDeclareHLSLTypes() {
474   CXXRecordDecl *Decl;
475   Decl = BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "RWBuffer")
476              .addTemplateArgumentList()
477              .addTypeParameter("element_type", SemaPtr->getASTContext().FloatTy)
478              .finalizeTemplateArgs()
479              .Record;
480   if (!Decl->isCompleteDefinition())
481     Completions.insert(
482         std::make_pair(Decl->getCanonicalDecl(),
483                        std::bind(&HLSLExternalSemaSource::completeBufferType,
484                                  this, std::placeholders::_1)));
485 }
486 
487 void HLSLExternalSemaSource::CompleteType(TagDecl *Tag) {
488   if (!isa<CXXRecordDecl>(Tag))
489     return;
490   auto Record = cast<CXXRecordDecl>(Tag);
491 
492   // If this is a specialization, we need to get the underlying templated
493   // declaration and complete that.
494   if (auto TDecl = dyn_cast<ClassTemplateSpecializationDecl>(Record))
495     Record = TDecl->getSpecializedTemplate()->getTemplatedDecl();
496   Record = Record->getCanonicalDecl();
497   auto It = Completions.find(Record);
498   if (It == Completions.end())
499     return;
500   It->second(Record);
501 }
502 
503 void HLSLExternalSemaSource::completeBufferType(CXXRecordDecl *Record) {
504   BuiltinTypeDeclBuilder(Record)
505       .addHandleMember()
506       .addDefaultHandleConstructor(*SemaPtr, ResourceClass::UAV)
507       .addArraySubscriptOperators()
508       .annotateResourceClass(HLSLResourceAttr::UAV,
509                              HLSLResourceAttr::TypedBuffer)
510       .completeDefinition();
511 }
512