1 /*========================== begin_copyright_notice ============================
2
3 Copyright (C) 2021 Intel Corporation
4
5 SPDX-License-Identifier: MIT
6
7 ============================= end_copyright_notice ===========================*/
8
9 #include "vc/GenXOpts/Utils/KernelInfo.h"
10 #include "llvmWrapper/IR/Function.h"
11
12 namespace llvm {
13 namespace genx {
14
findNode(const Function & F,StringRef KernelsMDName,unsigned KernelRefOp,unsigned MustExceed)15 static MDNode *findNode(const Function &F, StringRef KernelsMDName,
16 unsigned KernelRefOp, unsigned MustExceed) {
17 NamedMDNode *Named = F.getParent()->getNamedMetadata(KernelsMDName);
18 // It's expected that in any case internal and external metadata nodes have
19 // already been created by createInternalMD() or vc-intrinsics.
20 if (!Named)
21 return nullptr;
22 auto Res = std::find_if(
23 Named->op_begin(), Named->op_end(),
24 [&F, KernelRefOp, MustExceed](const MDNode *InternalMD) {
25 return InternalMD->getNumOperands() >= MustExceed &&
26 &F == getValueAsMetadata(InternalMD->getOperand(KernelRefOp));
27 });
28 return Res != Named->op_end() ? *Res : nullptr;
29 }
30
findInternalNode(const Function & F)31 static MDNode *findInternalNode(const Function &F) {
32 return findNode(F, FunctionMD::GenXKernelInternal,
33 internal::KernelMDOp::FunctionRef,
34 internal::KernelMDOp::Last);
35 }
36
findExternalNode(const Function & F)37 static MDNode *findExternalNode(const Function &F) {
38 return findNode(F, FunctionMD::GenXKernels, KernelMDOp::FunctionRef,
39 KernelMDOp::ArgTypeDescs);
40 }
41
42 namespace internal {
createInternalMD(Function & F)43 void createInternalMD(Function &F) {
44 IGC_ASSERT_MESSAGE(!findInternalNode(F),
45 "Internal node has already been created!");
46
47 auto &Ctx = F.getContext();
48
49 // Create nullptr values by default.
50 SmallVector<Metadata *, internal::KernelMDOp::Last> KernelInternalMD(
51 internal::KernelMDOp::Last, nullptr);
52 KernelInternalMD[internal::KernelMDOp::FunctionRef] =
53 ValueAsMetadata::get(&F);
54
55 MDNode *InternalNode = MDNode::get(Ctx, KernelInternalMD);
56 NamedMDNode *KernelMDs =
57 F.getParent()->getOrInsertNamedMetadata(FunctionMD::GenXKernelInternal);
58 KernelMDs->addOperand(InternalNode);
59 }
60
replaceInternalFunctionRef(const Function & From,Function & To)61 void replaceInternalFunctionRef(const Function &From, Function &To) {
62 MDNode *Node = findInternalNode(From);
63 IGC_ASSERT_MESSAGE(Node, "Replacement was called for non existing in kernel "
64 "internal metadata function");
65 Node->replaceOperandWith(internal::KernelMDOp::FunctionRef,
66 ValueAsMetadata::get(&To));
67 }
68 } // namespace internal
69
replaceFunctionRefMD(const Function & From,Function & To)70 void replaceFunctionRefMD(const Function &From, Function &To) {
71 Module *M = To.getParent();
72 NamedMDNode *Named = M->getNamedMetadata(FunctionMD::GenXKernels);
73 IGC_ASSERT(Named);
74
75 auto Res =
76 std::find_if(Named->op_begin(), Named->op_end(), [&From](MDNode *Node) {
77 auto *NodeVal =
78 cast<ValueAsMetadata>(Node->getOperand(KernelMDOp::FunctionRef))
79 ->getValue();
80 auto *F = cast<Function>(NodeVal);
81 return &From == F;
82 });
83 IGC_ASSERT_MESSAGE(Res != Named->op_end(),
84 "Cannot find MD for 'From' function");
85
86 MDNode *FromNode = *Res;
87 FromNode->replaceOperandWith(KernelMDOp::FunctionRef,
88 ValueAsMetadata::get(&To));
89
90 internal::replaceInternalFunctionRef(From, To);
91 }
92
93 template <typename RetTy = unsigned>
extractConstantIntMD(const MDOperand & Op)94 static RetTy extractConstantIntMD(const MDOperand &Op) {
95 const auto *V = getValueAsMetadata<ConstantInt>(Op);
96 IGC_ASSERT_MESSAGE(V, "Unexpected null value in metadata");
97 return static_cast<RetTy>(V->getZExtValue());
98 }
99
100 template <typename Cont>
extractConstantsFromMDNode(const MDNode * N,Cont & C)101 static void extractConstantsFromMDNode(const MDNode *N, Cont &C) {
102 if (!N)
103 return;
104 using ValTy = typename Cont::value_type;
105 std::transform(
106 N->op_begin(), N->op_end(), std::back_inserter(C),
107 [](const MDOperand &Op) { return extractConstantIntMD<ValTy>(Op); });
108 }
109
110 static ImplicitLinearizationInfo
extractImplicitLinearizationArg(const Function & F,const MDOperand & ImplicitArg)111 extractImplicitLinearizationArg(const Function &F,
112 const MDOperand &ImplicitArg) {
113 auto *MD = cast<MDNode>(ImplicitArg.get());
114 IGC_ASSERT(MD->getNumOperands() == internal::LinearizationMDOp::Last);
115 Constant *ArgNoValue =
116 cast<ConstantAsMetadata>(
117 MD->getOperand(internal::LinearizationMDOp::Argument).get())
118 ->getValue();
119 unsigned ArgNo = cast<ConstantInt>(ArgNoValue)->getZExtValue();
120 Argument *Arg = IGCLLVM::getArg(F, ArgNo);
121 auto *OffsetMD = cast<ConstantAsMetadata>(
122 MD->getOperand(internal::LinearizationMDOp::Offset).get());
123 return ImplicitLinearizationInfo{Arg,
124 cast<ConstantInt>(OffsetMD->getValue())};
125 }
126
127 static ArgToImplicitLinearization::value_type
extractArgLinearization(const Function & F,const MDOperand & MDOp)128 extractArgLinearization(const Function &F, const MDOperand &MDOp) {
129 auto *ArgLinearizationMD = cast<MDNode>(MDOp.get());
130 IGC_ASSERT(ArgLinearizationMD->getNumOperands() ==
131 internal::ArgLinearizationMDOp::Last);
132 Constant *ExplicitArgNo =
133 cast<ConstantAsMetadata>(
134 ArgLinearizationMD
135 ->getOperand(internal::ArgLinearizationMDOp::Explicit)
136 .get())
137 ->getValue();
138 Argument *ExplicitArg =
139 IGCLLVM::getArg(F, cast<ConstantInt>(ExplicitArgNo)->getZExtValue());
140 auto *LinMD = cast<MDNode>(
141 ArgLinearizationMD
142 ->getOperand(internal::ArgLinearizationMDOp::Linearization)
143 .get());
144 LinearizedArgInfo Info;
145 std::transform(LinMD->op_begin(), LinMD->op_end(), std::back_inserter(Info),
146 [&F](const MDOperand &ImplicitArg) {
147 return extractImplicitLinearizationArg(F, ImplicitArg);
148 });
149 return std::make_pair(ExplicitArg, std::move(Info));
150 }
151
152 static ArgToImplicitLinearization
extractLinearizationMD(const Function & F,const MDNode * LinearizationNode)153 extractLinearizationMD(const Function &F, const MDNode *LinearizationNode) {
154 IGC_ASSERT(LinearizationNode);
155 ArgToImplicitLinearization Linearization;
156 std::transform(
157 LinearizationNode->op_begin(), LinearizationNode->op_end(),
158 std::inserter(Linearization, Linearization.end()),
159 [&F](const MDOperand &MDOp) { return extractArgLinearization(F, MDOp); });
160 return Linearization;
161 }
162
KernelMetadata(const Function * F)163 KernelMetadata::KernelMetadata(const Function *F) {
164 if (!genx::isKernel(F))
165 return;
166
167 ExternalNode = findExternalNode(*F);
168 if (!ExternalNode)
169 return;
170
171 // ExternalNode is the metadata node for F, and it has the required number of
172 // operands.
173 this->F = F;
174 IsKernel = true;
175 if (MDString *MDS =
176 dyn_cast<MDString>(ExternalNode->getOperand(KernelMDOp::Name)))
177 Name = MDS->getString();
178 if (ConstantInt *Sz = getValueAsMetadata<ConstantInt>(
179 ExternalNode->getOperand(KernelMDOp::SLMSize)))
180 SLMSize = Sz->getZExtValue();
181 // Build the argument kinds and offsets arrays that should correspond to the
182 // function arguments (both explicit and implicit)
183 MDNode *KindsNode =
184 dyn_cast<MDNode>(ExternalNode->getOperand(KernelMDOp::ArgKinds));
185 MDNode *OffsetsNode =
186 dyn_cast<MDNode>(ExternalNode->getOperand(KernelMDOp::ArgOffsets));
187 MDNode *InputOutputKinds =
188 dyn_cast<MDNode>(ExternalNode->getOperand(KernelMDOp::ArgIOKinds));
189 MDNode *ArgDescNode =
190 dyn_cast<MDNode>(ExternalNode->getOperand(KernelMDOp::ArgTypeDescs));
191
192 MDNode *IndexesNode = nullptr;
193 MDNode *OffsetInArgsNode = nullptr;
194 MDNode *LinearizationNode = nullptr;
195 MDNode *BTIndicesNode = nullptr;
196 InternalNode = findInternalNode(*F);
197 IGC_ASSERT_MESSAGE(InternalNode,
198 "Internal node is expected to have already been created!");
199
200 IndexesNode = cast_or_null<MDNode>(
201 InternalNode->getOperand(internal::KernelMDOp::ArgIndexes));
202 OffsetInArgsNode = cast_or_null<MDNode>(
203 InternalNode->getOperand(internal::KernelMDOp::OffsetInArgs));
204 LinearizationNode = cast_or_null<MDNode>(
205 InternalNode->getOperand(internal::KernelMDOp::LinearizationArgs));
206 BTIndicesNode = cast_or_null<MDNode>(
207 InternalNode->getOperand(internal::KernelMDOp::BTIndices));
208
209 IGC_ASSERT(KindsNode);
210
211 // These should have the same number of operands if they exist.
212 IGC_ASSERT(!OffsetsNode ||
213 KindsNode->getNumOperands() == OffsetsNode->getNumOperands());
214 IGC_ASSERT(!OffsetInArgsNode ||
215 KindsNode->getNumOperands() == OffsetInArgsNode->getNumOperands());
216 IGC_ASSERT(!IndexesNode ||
217 KindsNode->getNumOperands() == IndexesNode->getNumOperands());
218 IGC_ASSERT(!BTIndicesNode ||
219 KindsNode->getNumOperands() == BTIndicesNode->getNumOperands());
220
221 extractConstantsFromMDNode(KindsNode, ArgKinds);
222 extractConstantsFromMDNode(OffsetsNode, ArgOffsets);
223 extractConstantsFromMDNode(OffsetInArgsNode, OffsetInArgs);
224 extractConstantsFromMDNode(IndexesNode, ArgIndexes);
225 extractConstantsFromMDNode(BTIndicesNode, BTIs);
226
227 IGC_ASSERT(InputOutputKinds);
228 IGC_ASSERT(KindsNode->getNumOperands() >= InputOutputKinds->getNumOperands());
229 extractConstantsFromMDNode(InputOutputKinds, ArgIOKinds);
230
231 IGC_ASSERT(ArgDescNode);
232 for (unsigned i = 0, e = ArgDescNode->getNumOperands(); i < e; ++i) {
233 MDString *MDS = dyn_cast<MDString>(ArgDescNode->getOperand(i));
234 IGC_ASSERT(MDS);
235 ArgTypeDescs.push_back(MDS->getString());
236 }
237 if (LinearizationNode)
238 Linearization = extractLinearizationMD(*F, LinearizationNode);
239 }
240
createArgLinearizationMD(const ImplicitLinearizationInfo & Info)241 static MDNode *createArgLinearizationMD(const ImplicitLinearizationInfo &Info) {
242 auto &Ctx = Info.Arg->getContext();
243 auto *I32Ty = Type::getInt32Ty(Ctx);
244 Metadata *ArgMD =
245 ConstantAsMetadata::get(ConstantInt::get(I32Ty, Info.Arg->getArgNo()));
246 Metadata *OffsetMD = ConstantAsMetadata::get(Info.Offset);
247 return MDNode::get(Ctx, {ArgMD, OffsetMD});
248 }
249
updateLinearizationMD(ArgToImplicitLinearization && Lin)250 void KernelMetadata::updateLinearizationMD(ArgToImplicitLinearization &&Lin) {
251 Linearization = std::move(Lin);
252
253 std::vector<Metadata *> LinMDs;
254 LinMDs.reserve(Linearization.size());
255 auto &Ctx = F->getContext();
256 for (const auto &ArgLin : Linearization) {
257 std::vector<Metadata *> ArgLinMDs;
258 ArgLinMDs.reserve(ArgLin.second.size());
259 std::transform(ArgLin.second.begin(), ArgLin.second.end(),
260 std::back_inserter(ArgLinMDs), createArgLinearizationMD);
261 auto *I32Ty = Type::getInt32Ty(Ctx);
262 Metadata *ExplicitArgMD = ConstantAsMetadata::get(
263 ConstantInt::get(I32Ty, ArgLin.first->getArgNo()));
264 Metadata *ExplicitArgLinMD = MDNode::get(Ctx, ArgLinMDs);
265 LinMDs.push_back(MDNode::get(Ctx, {ExplicitArgMD, ExplicitArgLinMD}));
266 }
267 InternalNode->replaceOperandWith(internal::KernelMDOp::LinearizationArgs,
268 MDNode::get(Ctx, LinMDs));
269 }
270
271 template <typename InputIt>
updateArgsMD(InputIt Begin,InputIt End,MDNode * Node,unsigned NodeOpNo) const272 void KernelMetadata::updateArgsMD(InputIt Begin, InputIt End, MDNode *Node,
273 unsigned NodeOpNo) const {
274 IGC_ASSERT(F);
275 IGC_ASSERT(Node);
276 IGC_ASSERT_MESSAGE(std::distance(Begin, End) == getNumArgs(),
277 "Mismatch between metadata for kernel and number of args");
278 IGC_ASSERT(Node->getNumOperands() > NodeOpNo);
279 auto &Ctx = F->getContext();
280 auto *I32Ty = Type::getInt32Ty(Ctx);
281 SmallVector<Metadata *, 8> NewMD;
282 std::transform(Begin, End, std::back_inserter(NewMD), [I32Ty](auto Value) {
283 return ValueAsMetadata::getConstant(ConstantInt::get(I32Ty, Value));
284 });
285 MDNode *NewNode = MDNode::get(Ctx, NewMD);
286 Node->replaceOperandWith(NodeOpNo, NewNode);
287 }
288
updateArgOffsetsMD(SmallVectorImpl<unsigned> && Offsets)289 void KernelMetadata::updateArgOffsetsMD(SmallVectorImpl<unsigned> &&Offsets) {
290 ArgOffsets = std::move(Offsets);
291 updateArgsMD(ArgOffsets.begin(), ArgOffsets.end(), ExternalNode,
292 KernelMDOp::ArgOffsets);
293 }
updateArgKindsMD(SmallVectorImpl<unsigned> && Kinds)294 void KernelMetadata::updateArgKindsMD(SmallVectorImpl<unsigned> &&Kinds) {
295 ArgKinds = std::move(Kinds);
296 updateArgsMD(ArgKinds.begin(), ArgKinds.end(), ExternalNode,
297 KernelMDOp::ArgKinds);
298 }
updateArgIndexesMD(SmallVectorImpl<unsigned> && Indexes)299 void KernelMetadata::updateArgIndexesMD(SmallVectorImpl<unsigned> &&Indexes) {
300 ArgIndexes = std::move(Indexes);
301 updateArgsMD(ArgIndexes.begin(), ArgIndexes.end(), InternalNode,
302 internal::KernelMDOp::ArgIndexes);
303 }
updateOffsetInArgsMD(SmallVectorImpl<unsigned> && Offsets)304 void KernelMetadata::updateOffsetInArgsMD(SmallVectorImpl<unsigned> &&Offsets) {
305 OffsetInArgs = std::move(Offsets);
306 updateArgsMD(OffsetInArgs.begin(), OffsetInArgs.end(), InternalNode,
307 internal::KernelMDOp::OffsetInArgs);
308 }
updateBTIndicesMD(std::vector<int> && BTIndices)309 void KernelMetadata::updateBTIndicesMD(std::vector<int> &&BTIndices) {
310 BTIs = std::move(BTIndices);
311 updateArgsMD(BTIs.begin(), BTIs.end(), InternalNode,
312 internal::KernelMDOp::BTIndices);
313 }
314
315 } // namespace genx
316 } // namespace llvm
317