//===- Target/DirectX/PointerTypeAnalysis.cpp - PointerType analysis ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Analysis pass to assign types to opaque pointers.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PointerTypeAnalysis.h"
14 #include "llvm/IR/Constants.h"
15 #include "llvm/IR/Instructions.h"
16 
17 using namespace llvm;
18 using namespace llvm::dxil;
19 
20 namespace {
21 
22 // Classifies the type of the value passed in by walking the value's users to
23 // find a typed instruction to materialize a type from.
24 Type *classifyPointerType(const Value *V, PointerTypeMap &Map) {
25   assert(V->getType()->isPointerTy() &&
26          "classifyPointerType called with non-pointer");
27   auto It = Map.find(V);
28   if (It != Map.end())
29     return It->second;
30 
31   Type *PointeeTy = nullptr;
32   if (auto *Inst = dyn_cast<GetElementPtrInst>(V)) {
33     if (!Inst->getResultElementType()->isPointerTy())
34       PointeeTy = Inst->getResultElementType();
35   } else if (auto *Inst = dyn_cast<AllocaInst>(V)) {
36     PointeeTy = Inst->getAllocatedType();
37   } else if (auto *GV = dyn_cast<GlobalVariable>(V)) {
38     PointeeTy = GV->getValueType();
39   }
40 
41   for (const auto *User : V->users()) {
42     Type *NewPointeeTy = nullptr;
43     if (const auto *Inst = dyn_cast<LoadInst>(User)) {
44       NewPointeeTy = Inst->getType();
45     } else if (const auto *Inst = dyn_cast<StoreInst>(User)) {
46       NewPointeeTy = Inst->getValueOperand()->getType();
47       // When store value is ptr type, cannot get more type info.
48       if (NewPointeeTy->isPointerTy())
49         continue;
50     } else if (const auto *Inst = dyn_cast<GetElementPtrInst>(User)) {
51       NewPointeeTy = Inst->getSourceElementType();
52     }
53     if (NewPointeeTy) {
54       // HLSL doesn't support pointers, so it is unlikely to get more than one
55       // or two levels of indirection in the IR. Because of this, recursion is
56       // pretty safe.
57       if (NewPointeeTy->isPointerTy()) {
58         PointeeTy = classifyPointerType(User, Map);
59         break;
60       }
61       if (!PointeeTy)
62         PointeeTy = NewPointeeTy;
63       else if (PointeeTy != NewPointeeTy)
64         PointeeTy = Type::getInt8Ty(V->getContext());
65     }
66   }
67   // If we were unable to determine the pointee type, set to i8
68   if (!PointeeTy)
69     PointeeTy = Type::getInt8Ty(V->getContext());
70   auto *TypedPtrTy =
71       TypedPointerType::get(PointeeTy, V->getType()->getPointerAddressSpace());
72 
73   Map[V] = TypedPtrTy;
74   return TypedPtrTy;
75 }
76 
77 // This function constructs a function type accepting typed pointers. It only
78 // handles function arguments and return types, and assigns the function type to
79 // the function's value in the type map.
80 Type *classifyFunctionType(const Function &F, PointerTypeMap &Map) {
81   auto It = Map.find(&F);
82   if (It != Map.end())
83     return It->second;
84 
85   SmallVector<Type *, 8> NewArgs;
86   Type *RetTy = F.getReturnType();
87   LLVMContext &Ctx = F.getContext();
88   if (RetTy->isPointerTy()) {
89     RetTy = nullptr;
90     for (const auto &B : F) {
91       const auto *RetInst = dyn_cast_or_null<ReturnInst>(B.getTerminator());
92       if (!RetInst)
93         continue;
94 
95       Type *NewRetTy = classifyPointerType(RetInst->getReturnValue(), Map);
96       if (!RetTy)
97         RetTy = NewRetTy;
98       else if (RetTy != NewRetTy)
99         RetTy = TypedPointerType::get(
100             Type::getInt8Ty(Ctx), F.getReturnType()->getPointerAddressSpace());
101     }
102     // For function decl.
103     if (!RetTy)
104       RetTy = TypedPointerType::get(
105           Type::getInt8Ty(Ctx), F.getReturnType()->getPointerAddressSpace());
106   }
107   for (auto &A : F.args()) {
108     Type *ArgTy = A.getType();
109     if (ArgTy->isPointerTy())
110       ArgTy = classifyPointerType(&A, Map);
111     NewArgs.push_back(ArgTy);
112   }
113   auto *TypedPtrTy =
114       TypedPointerType::get(FunctionType::get(RetTy, NewArgs, false), 0);
115   Map[&F] = TypedPtrTy;
116   return TypedPtrTy;
117 }
118 } // anonymous namespace
119 
120 static Type *classifyConstantWithOpaquePtr(const Constant *C,
121                                            PointerTypeMap &Map) {
122   // FIXME: support ConstantPointerNull which could map to more than one
123   // TypedPointerType.
124   // See https://github.com/llvm/llvm-project/issues/57942.
125   if (isa<ConstantPointerNull>(C))
126     return TypedPointerType::get(Type::getInt8Ty(C->getContext()),
127                                  C->getType()->getPointerAddressSpace());
128 
129   // Skip ConstantData which cannot have opaque ptr.
130   if (isa<ConstantData>(C))
131     return C->getType();
132 
133   auto It = Map.find(C);
134   if (It != Map.end())
135     return It->second;
136 
137   if (const auto *F = dyn_cast<Function>(C))
138     return classifyFunctionType(*F, Map);
139 
140   Type *Ty = C->getType();
141   Type *TargetTy = nullptr;
142   if (auto *CS = dyn_cast<ConstantStruct>(C)) {
143     SmallVector<Type *> EltTys;
144     for (unsigned int I = 0; I < CS->getNumOperands(); ++I) {
145       const Constant *Elt = C->getAggregateElement(I);
146       Type *EltTy = classifyConstantWithOpaquePtr(Elt, Map);
147       EltTys.emplace_back(EltTy);
148     }
149     TargetTy = StructType::get(C->getContext(), EltTys);
150   } else if (auto *CA = dyn_cast<ConstantAggregate>(C)) {
151 
152     Type *TargetEltTy = nullptr;
153     for (auto &Elt : CA->operands()) {
154       Type *EltTy = classifyConstantWithOpaquePtr(cast<Constant>(&Elt), Map);
155       assert(TargetEltTy == EltTy || TargetEltTy == nullptr);
156       TargetEltTy = EltTy;
157     }
158 
159     if (auto *AT = dyn_cast<ArrayType>(Ty)) {
160       TargetTy = ArrayType::get(TargetEltTy, AT->getNumElements());
161     } else {
162       // Not struct, not array, must be vector here.
163       auto *VT = cast<VectorType>(Ty);
164       TargetTy = VectorType::get(TargetEltTy, VT);
165     }
166   }
167   // Must have a target ty when map.
168   assert(TargetTy && "PointerTypeAnalyisis failed to identify target type");
169 
170   // Same type, no need to map.
171   if (TargetTy == Ty)
172     return Ty;
173 
174   Map[C] = TargetTy;
175   return TargetTy;
176 }
177 
178 static void classifyGlobalCtorPointerType(const GlobalVariable &GV,
179                                           PointerTypeMap &Map) {
180   const auto *CA = cast<ConstantArray>(GV.getInitializer());
181   // Type for global ctor should be array of { i32, void ()*, i8* }.
182   Type *CtorArrayTy = classifyConstantWithOpaquePtr(CA, Map);
183 
184   // Map the global type.
185   Map[&GV] = TypedPointerType::get(CtorArrayTy,
186                                    GV.getType()->getPointerAddressSpace());
187 }
188 
189 PointerTypeMap PointerTypeAnalysis::run(const Module &M) {
190   PointerTypeMap Map;
191   for (auto &G : M.globals()) {
192     if (G.getType()->isPointerTy())
193       classifyPointerType(&G, Map);
194     if (G.getName() == "llvm.global_ctors")
195       classifyGlobalCtorPointerType(G, Map);
196   }
197 
198   for (auto &F : M) {
199     classifyFunctionType(F, Map);
200 
201     for (const auto &B : F) {
202       for (const auto &I : B) {
203         if (I.getType()->isPointerTy())
204           classifyPointerType(&I, Map);
205       }
206     }
207   }
208   return Map;
209 }
210