1 /*========================== begin_copyright_notice ============================
2 
3 Copyright (C) 2021 Intel Corporation
4 
5 SPDX-License-Identifier: MIT
6 
7 ============================= end_copyright_notice ===========================*/
8 
9 #include "vc/GenXOpts/Utils/KernelInfo.h"
10 #include "llvmWrapper/IR/Function.h"
11 
12 namespace llvm {
13 namespace genx {
14 
findNode(const Function & F,StringRef KernelsMDName,unsigned KernelRefOp,unsigned MustExceed)15 static MDNode *findNode(const Function &F, StringRef KernelsMDName,
16                         unsigned KernelRefOp, unsigned MustExceed) {
17   NamedMDNode *Named = F.getParent()->getNamedMetadata(KernelsMDName);
18   // It's expected that in any case internal and external metadata nodes have
19   // already been created by createInternalMD() or vc-intrinsics.
20   if (!Named)
21     return nullptr;
22   auto Res = std::find_if(
23       Named->op_begin(), Named->op_end(),
24       [&F, KernelRefOp, MustExceed](const MDNode *InternalMD) {
25         return InternalMD->getNumOperands() >= MustExceed &&
26                &F == getValueAsMetadata(InternalMD->getOperand(KernelRefOp));
27       });
28   return Res != Named->op_end() ? *Res : nullptr;
29 }
30 
findInternalNode(const Function & F)31 static MDNode *findInternalNode(const Function &F) {
32   return findNode(F, FunctionMD::GenXKernelInternal,
33                   internal::KernelMDOp::FunctionRef,
34                   internal::KernelMDOp::Last);
35 }
36 
findExternalNode(const Function & F)37 static MDNode *findExternalNode(const Function &F) {
38   return findNode(F, FunctionMD::GenXKernels, KernelMDOp::FunctionRef,
39                   KernelMDOp::ArgTypeDescs);
40 }
41 
42 namespace internal {
createInternalMD(Function & F)43 void createInternalMD(Function &F) {
44   IGC_ASSERT_MESSAGE(!findInternalNode(F),
45                      "Internal node has already been created!");
46 
47   auto &Ctx = F.getContext();
48 
49   // Create nullptr values by default.
50   SmallVector<Metadata *, internal::KernelMDOp::Last> KernelInternalMD(
51       internal::KernelMDOp::Last, nullptr);
52   KernelInternalMD[internal::KernelMDOp::FunctionRef] =
53       ValueAsMetadata::get(&F);
54 
55   MDNode *InternalNode = MDNode::get(Ctx, KernelInternalMD);
56   NamedMDNode *KernelMDs =
57       F.getParent()->getOrInsertNamedMetadata(FunctionMD::GenXKernelInternal);
58   KernelMDs->addOperand(InternalNode);
59 }
60 
replaceInternalFunctionRef(const Function & From,Function & To)61 void replaceInternalFunctionRef(const Function &From, Function &To) {
62   MDNode *Node = findInternalNode(From);
63   IGC_ASSERT_MESSAGE(Node, "Replacement was called for non existing in kernel "
64                            "internal metadata function");
65   Node->replaceOperandWith(internal::KernelMDOp::FunctionRef,
66                            ValueAsMetadata::get(&To));
67 }
68 } // namespace internal
69 
replaceFunctionRefMD(const Function & From,Function & To)70 void replaceFunctionRefMD(const Function &From, Function &To) {
71   Module *M = To.getParent();
72   NamedMDNode *Named = M->getNamedMetadata(FunctionMD::GenXKernels);
73   IGC_ASSERT(Named);
74 
75   auto Res =
76       std::find_if(Named->op_begin(), Named->op_end(), [&From](MDNode *Node) {
77         auto *NodeVal =
78             cast<ValueAsMetadata>(Node->getOperand(KernelMDOp::FunctionRef))
79                 ->getValue();
80         auto *F = cast<Function>(NodeVal);
81         return &From == F;
82       });
83   IGC_ASSERT_MESSAGE(Res != Named->op_end(),
84                      "Cannot find MD for 'From' function");
85 
86   MDNode *FromNode = *Res;
87   FromNode->replaceOperandWith(KernelMDOp::FunctionRef,
88                                ValueAsMetadata::get(&To));
89 
90   internal::replaceInternalFunctionRef(From, To);
91 }
92 
93 template <typename RetTy = unsigned>
extractConstantIntMD(const MDOperand & Op)94 static RetTy extractConstantIntMD(const MDOperand &Op) {
95   const auto *V = getValueAsMetadata<ConstantInt>(Op);
96   IGC_ASSERT_MESSAGE(V, "Unexpected null value in metadata");
97   return static_cast<RetTy>(V->getZExtValue());
98 }
99 
100 template <typename Cont>
extractConstantsFromMDNode(const MDNode * N,Cont & C)101 static void extractConstantsFromMDNode(const MDNode *N, Cont &C) {
102   if (!N)
103     return;
104   using ValTy = typename Cont::value_type;
105   std::transform(
106       N->op_begin(), N->op_end(), std::back_inserter(C),
107       [](const MDOperand &Op) { return extractConstantIntMD<ValTy>(Op); });
108 }
109 
110 static ImplicitLinearizationInfo
extractImplicitLinearizationArg(const Function & F,const MDOperand & ImplicitArg)111 extractImplicitLinearizationArg(const Function &F,
112                                 const MDOperand &ImplicitArg) {
113   auto *MD = cast<MDNode>(ImplicitArg.get());
114   IGC_ASSERT(MD->getNumOperands() == internal::LinearizationMDOp::Last);
115   Constant *ArgNoValue =
116       cast<ConstantAsMetadata>(
117           MD->getOperand(internal::LinearizationMDOp::Argument).get())
118           ->getValue();
119   unsigned ArgNo = cast<ConstantInt>(ArgNoValue)->getZExtValue();
120   Argument *Arg = IGCLLVM::getArg(F, ArgNo);
121   auto *OffsetMD = cast<ConstantAsMetadata>(
122       MD->getOperand(internal::LinearizationMDOp::Offset).get());
123   return ImplicitLinearizationInfo{Arg,
124                                    cast<ConstantInt>(OffsetMD->getValue())};
125 }
126 
127 static ArgToImplicitLinearization::value_type
extractArgLinearization(const Function & F,const MDOperand & MDOp)128 extractArgLinearization(const Function &F, const MDOperand &MDOp) {
129   auto *ArgLinearizationMD = cast<MDNode>(MDOp.get());
130   IGC_ASSERT(ArgLinearizationMD->getNumOperands() ==
131              internal::ArgLinearizationMDOp::Last);
132   Constant *ExplicitArgNo =
133       cast<ConstantAsMetadata>(
134           ArgLinearizationMD
135               ->getOperand(internal::ArgLinearizationMDOp::Explicit)
136               .get())
137           ->getValue();
138   Argument *ExplicitArg =
139       IGCLLVM::getArg(F, cast<ConstantInt>(ExplicitArgNo)->getZExtValue());
140   auto *LinMD = cast<MDNode>(
141       ArgLinearizationMD
142           ->getOperand(internal::ArgLinearizationMDOp::Linearization)
143           .get());
144   LinearizedArgInfo Info;
145   std::transform(LinMD->op_begin(), LinMD->op_end(), std::back_inserter(Info),
146                  [&F](const MDOperand &ImplicitArg) {
147                    return extractImplicitLinearizationArg(F, ImplicitArg);
148                  });
149   return std::make_pair(ExplicitArg, std::move(Info));
150 }
151 
152 static ArgToImplicitLinearization
extractLinearizationMD(const Function & F,const MDNode * LinearizationNode)153 extractLinearizationMD(const Function &F, const MDNode *LinearizationNode) {
154   IGC_ASSERT(LinearizationNode);
155   ArgToImplicitLinearization Linearization;
156   std::transform(
157       LinearizationNode->op_begin(), LinearizationNode->op_end(),
158       std::inserter(Linearization, Linearization.end()),
159       [&F](const MDOperand &MDOp) { return extractArgLinearization(F, MDOp); });
160   return Linearization;
161 }
162 
// Builds a view of the kernel metadata attached to F.
//
// The data comes from two named metadata lists:
//  * FunctionMD::GenXKernels (the "external" node, located by
//    findExternalNode): kernel name, SLM size, argument kinds, argument
//    offsets, argument IO kinds and argument type descriptors.
//  * FunctionMD::GenXKernelInternal (the "internal" node, located by
//    findInternalNode): argument indexes, offsets in args, argument
//    linearization and BT indices.
//
// The constructor returns early if genx::isKernel(F) is false, or if no
// external node references F (ExternalNode is null in that case). On early
// return, neither this->F nor IsKernel is assigned here. Their defaults are
// declared elsewhere; TODO confirm in the header.
KernelMetadata::KernelMetadata(const Function *F) {
  if (!genx::isKernel(F))
    return;

  ExternalNode = findExternalNode(*F);
  if (!ExternalNode)
    return;

  // ExternalNode is the metadata node for F, and it has the required number of
  // operands.
  // NOTE: the reads below go up to operand KernelMDOp::ArgTypeDescs, so the
  // node returned by findExternalNode must have more than that many operands.
  this->F = F;
  IsKernel = true;
  // Name and SLM size are only taken when the operand has the expected
  // metadata type (MDString and ConstantInt respectively).
  if (MDString *MDS =
          dyn_cast<MDString>(ExternalNode->getOperand(KernelMDOp::Name)))
    Name = MDS->getString();
  if (ConstantInt *Sz = getValueAsMetadata<ConstantInt>(
          ExternalNode->getOperand(KernelMDOp::SLMSize)))
    SLMSize = Sz->getZExtValue();
  // Build the argument kinds and offsets arrays that should correspond to the
  // function arguments (both explicit and implicit)
  MDNode *KindsNode =
      dyn_cast<MDNode>(ExternalNode->getOperand(KernelMDOp::ArgKinds));
  MDNode *OffsetsNode =
      dyn_cast<MDNode>(ExternalNode->getOperand(KernelMDOp::ArgOffsets));
  MDNode *InputOutputKinds =
      dyn_cast<MDNode>(ExternalNode->getOperand(KernelMDOp::ArgIOKinds));
  MDNode *ArgDescNode =
      dyn_cast<MDNode>(ExternalNode->getOperand(KernelMDOp::ArgTypeDescs));

  // Internal-node operands may be null. createInternalMD initializes every
  // slot except FunctionRef to nullptr, hence the cast_or_null below.
  MDNode *IndexesNode = nullptr;
  MDNode *OffsetInArgsNode = nullptr;
  MDNode *LinearizationNode = nullptr;
  MDNode *BTIndicesNode = nullptr;
  InternalNode = findInternalNode(*F);
  IGC_ASSERT_MESSAGE(InternalNode,
                     "Internal node is expected to have already been created!");

  IndexesNode = cast_or_null<MDNode>(
      InternalNode->getOperand(internal::KernelMDOp::ArgIndexes));
  OffsetInArgsNode = cast_or_null<MDNode>(
      InternalNode->getOperand(internal::KernelMDOp::OffsetInArgs));
  LinearizationNode = cast_or_null<MDNode>(
      InternalNode->getOperand(internal::KernelMDOp::LinearizationArgs));
  BTIndicesNode = cast_or_null<MDNode>(
      InternalNode->getOperand(internal::KernelMDOp::BTIndices));

  // Argument kinds are mandatory; the other per-argument lists are checked
  // against its length.
  IGC_ASSERT(KindsNode);

  // These should have the same number of operands if they exist.
  IGC_ASSERT(!OffsetsNode ||
             KindsNode->getNumOperands() == OffsetsNode->getNumOperands());
  IGC_ASSERT(!OffsetInArgsNode ||
             KindsNode->getNumOperands() == OffsetInArgsNode->getNumOperands());
  IGC_ASSERT(!IndexesNode ||
             KindsNode->getNumOperands() == IndexesNode->getNumOperands());
  IGC_ASSERT(!BTIndicesNode ||
             KindsNode->getNumOperands() == BTIndicesNode->getNumOperands());

  // extractConstantsFromMDNode does nothing for a null node, so an absent
  // optional list leaves the corresponding member unchanged.
  extractConstantsFromMDNode(KindsNode, ArgKinds);
  extractConstantsFromMDNode(OffsetsNode, ArgOffsets);
  extractConstantsFromMDNode(OffsetInArgsNode, OffsetInArgs);
  extractConstantsFromMDNode(IndexesNode, ArgIndexes);
  extractConstantsFromMDNode(BTIndicesNode, BTIs);

  // IO kinds are required, but the list may have fewer entries than the
  // argument kinds list.
  IGC_ASSERT(InputOutputKinds);
  IGC_ASSERT(KindsNode->getNumOperands() >= InputOutputKinds->getNumOperands());
  extractConstantsFromMDNode(InputOutputKinds, ArgIOKinds);

  // Each type descriptor entry is expected to be an MDString.
  IGC_ASSERT(ArgDescNode);
  for (unsigned i = 0, e = ArgDescNode->getNumOperands(); i < e; ++i) {
    MDString *MDS = dyn_cast<MDString>(ArgDescNode->getOperand(i));
    IGC_ASSERT(MDS);
    ArgTypeDescs.push_back(MDS->getString());
  }
  // Linearization info is optional and is decoded only when present.
  if (LinearizationNode)
    Linearization = extractLinearizationMD(*F, LinearizationNode);
}
240 
createArgLinearizationMD(const ImplicitLinearizationInfo & Info)241 static MDNode *createArgLinearizationMD(const ImplicitLinearizationInfo &Info) {
242   auto &Ctx = Info.Arg->getContext();
243   auto *I32Ty = Type::getInt32Ty(Ctx);
244   Metadata *ArgMD =
245       ConstantAsMetadata::get(ConstantInt::get(I32Ty, Info.Arg->getArgNo()));
246   Metadata *OffsetMD = ConstantAsMetadata::get(Info.Offset);
247   return MDNode::get(Ctx, {ArgMD, OffsetMD});
248 }
249 
updateLinearizationMD(ArgToImplicitLinearization && Lin)250 void KernelMetadata::updateLinearizationMD(ArgToImplicitLinearization &&Lin) {
251   Linearization = std::move(Lin);
252 
253   std::vector<Metadata *> LinMDs;
254   LinMDs.reserve(Linearization.size());
255   auto &Ctx = F->getContext();
256   for (const auto &ArgLin : Linearization) {
257     std::vector<Metadata *> ArgLinMDs;
258     ArgLinMDs.reserve(ArgLin.second.size());
259     std::transform(ArgLin.second.begin(), ArgLin.second.end(),
260                    std::back_inserter(ArgLinMDs), createArgLinearizationMD);
261     auto *I32Ty = Type::getInt32Ty(Ctx);
262     Metadata *ExplicitArgMD = ConstantAsMetadata::get(
263         ConstantInt::get(I32Ty, ArgLin.first->getArgNo()));
264     Metadata *ExplicitArgLinMD = MDNode::get(Ctx, ArgLinMDs);
265     LinMDs.push_back(MDNode::get(Ctx, {ExplicitArgMD, ExplicitArgLinMD}));
266   }
267   InternalNode->replaceOperandWith(internal::KernelMDOp::LinearizationArgs,
268                                    MDNode::get(Ctx, LinMDs));
269 }
270 
271 template <typename InputIt>
updateArgsMD(InputIt Begin,InputIt End,MDNode * Node,unsigned NodeOpNo) const272 void KernelMetadata::updateArgsMD(InputIt Begin, InputIt End, MDNode *Node,
273                                   unsigned NodeOpNo) const {
274   IGC_ASSERT(F);
275   IGC_ASSERT(Node);
276   IGC_ASSERT_MESSAGE(std::distance(Begin, End) == getNumArgs(),
277                      "Mismatch between metadata for kernel and number of args");
278   IGC_ASSERT(Node->getNumOperands() > NodeOpNo);
279   auto &Ctx = F->getContext();
280   auto *I32Ty = Type::getInt32Ty(Ctx);
281   SmallVector<Metadata *, 8> NewMD;
282   std::transform(Begin, End, std::back_inserter(NewMD), [I32Ty](auto Value) {
283     return ValueAsMetadata::getConstant(ConstantInt::get(I32Ty, Value));
284   });
285   MDNode *NewNode = MDNode::get(Ctx, NewMD);
286   Node->replaceOperandWith(NodeOpNo, NewNode);
287 }
288 
updateArgOffsetsMD(SmallVectorImpl<unsigned> && Offsets)289 void KernelMetadata::updateArgOffsetsMD(SmallVectorImpl<unsigned> &&Offsets) {
290   ArgOffsets = std::move(Offsets);
291   updateArgsMD(ArgOffsets.begin(), ArgOffsets.end(), ExternalNode,
292                KernelMDOp::ArgOffsets);
293 }
updateArgKindsMD(SmallVectorImpl<unsigned> && Kinds)294 void KernelMetadata::updateArgKindsMD(SmallVectorImpl<unsigned> &&Kinds) {
295   ArgKinds = std::move(Kinds);
296   updateArgsMD(ArgKinds.begin(), ArgKinds.end(), ExternalNode,
297                KernelMDOp::ArgKinds);
298 }
updateArgIndexesMD(SmallVectorImpl<unsigned> && Indexes)299 void KernelMetadata::updateArgIndexesMD(SmallVectorImpl<unsigned> &&Indexes) {
300   ArgIndexes = std::move(Indexes);
301   updateArgsMD(ArgIndexes.begin(), ArgIndexes.end(), InternalNode,
302                internal::KernelMDOp::ArgIndexes);
303 }
updateOffsetInArgsMD(SmallVectorImpl<unsigned> && Offsets)304 void KernelMetadata::updateOffsetInArgsMD(SmallVectorImpl<unsigned> &&Offsets) {
305   OffsetInArgs = std::move(Offsets);
306   updateArgsMD(OffsetInArgs.begin(), OffsetInArgs.end(), InternalNode,
307                internal::KernelMDOp::OffsetInArgs);
308 }
updateBTIndicesMD(std::vector<int> && BTIndices)309 void KernelMetadata::updateBTIndicesMD(std::vector<int> &&BTIndices) {
310   BTIs = std::move(BTIndices);
311   updateArgsMD(BTIs.begin(), BTIs.end(), InternalNode,
312                internal::KernelMDOp::BTIndices);
313 }
314 
315 } // namespace genx
316 } // namespace llvm
317