xref: /openbsd/gnu/llvm/llvm/lib/IR/ProfDataUtils.cpp (revision d415bd75)
1*d415bd75Srobert //===- ProfDataUtils.cpp - Utility functions for MD_prof Metadata ---------===//
2*d415bd75Srobert //
3*d415bd75Srobert // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*d415bd75Srobert // See https://llvm.org/LICENSE.txt for license information.
5*d415bd75Srobert // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*d415bd75Srobert //
7*d415bd75Srobert //===----------------------------------------------------------------------===//
8*d415bd75Srobert //
9*d415bd75Srobert // This file implements utilities for working with Profiling Metadata.
10*d415bd75Srobert //
11*d415bd75Srobert //===----------------------------------------------------------------------===//
12*d415bd75Srobert 
13*d415bd75Srobert #include "llvm/IR/ProfDataUtils.h"
14*d415bd75Srobert #include "llvm/ADT/SmallVector.h"
15*d415bd75Srobert #include "llvm/ADT/Twine.h"
16*d415bd75Srobert #include "llvm/IR/Constants.h"
17*d415bd75Srobert #include "llvm/IR/Function.h"
18*d415bd75Srobert #include "llvm/IR/Instructions.h"
19*d415bd75Srobert #include "llvm/IR/LLVMContext.h"
20*d415bd75Srobert #include "llvm/IR/Metadata.h"
21*d415bd75Srobert #include "llvm/Support/BranchProbability.h"
22*d415bd75Srobert #include "llvm/Support/CommandLine.h"
23*d415bd75Srobert 
24*d415bd75Srobert using namespace llvm;
25*d415bd75Srobert 
26*d415bd75Srobert namespace {
27*d415bd75Srobert 
28*d415bd75Srobert // MD_prof nodes have the following layout
29*d415bd75Srobert //
30*d415bd75Srobert // In general:
31*d415bd75Srobert // { String name,         Array of i32   }
32*d415bd75Srobert //
33*d415bd75Srobert // In terms of Types:
34*d415bd75Srobert // { MDString,            [i32, i32, ...]}
35*d415bd75Srobert //
36*d415bd75Srobert // Concretely for Branch Weights
37*d415bd75Srobert // { "branch_weights",    [i32 1, i32 10000]}
38*d415bd75Srobert //
// We maintain some constants here to ensure that we access the branch weights
// correctly, and so that the accessors can be updated in one place if the
// layout changes in the future.
41*d415bd75Srobert 
// Operand index of the first weight; operand 0 holds the MDString name.
constexpr unsigned WeightsIdx{1};

// Minimum operand count for an MD_prof node carrying branch weights: the name
// operand plus at least two weights.
constexpr unsigned MinBWOps{3};
47*d415bd75Srobert 
extractWeights(const MDNode * ProfileData,SmallVectorImpl<uint32_t> & Weights)48*d415bd75Srobert bool extractWeights(const MDNode *ProfileData,
49*d415bd75Srobert                     SmallVectorImpl<uint32_t> &Weights) {
50*d415bd75Srobert   // Assume preconditions are already met (i.e. this is valid metadata)
51*d415bd75Srobert   assert(ProfileData && "ProfileData was nullptr in extractWeights");
52*d415bd75Srobert   unsigned NOps = ProfileData->getNumOperands();
53*d415bd75Srobert 
54*d415bd75Srobert   assert(WeightsIdx < NOps && "Weights Index must be less than NOps.");
55*d415bd75Srobert   Weights.resize(NOps - WeightsIdx);
56*d415bd75Srobert 
57*d415bd75Srobert   for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) {
58*d415bd75Srobert     ConstantInt *Weight =
59*d415bd75Srobert         mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
60*d415bd75Srobert     assert(Weight && "Malformed branch_weight in MD_prof node");
61*d415bd75Srobert     assert(Weight->getValue().getActiveBits() <= 32 &&
62*d415bd75Srobert            "Too many bits for uint32_t");
63*d415bd75Srobert     Weights[Idx - WeightsIdx] = Weight->getZExtValue();
64*d415bd75Srobert   }
65*d415bd75Srobert   return true;
66*d415bd75Srobert }
67*d415bd75Srobert 
68*d415bd75Srobert // We may want to add support for other MD_prof types, so provide an abstraction
69*d415bd75Srobert // for checking the metadata type.
isTargetMD(const MDNode * ProfData,const char * Name,unsigned MinOps)70*d415bd75Srobert bool isTargetMD(const MDNode *ProfData, const char *Name, unsigned MinOps) {
71*d415bd75Srobert   // TODO: This routine may be simplified if MD_prof used an enum instead of a
72*d415bd75Srobert   // string to differentiate the types of MD_prof nodes.
73*d415bd75Srobert   if (!ProfData || !Name || MinOps < 2)
74*d415bd75Srobert     return false;
75*d415bd75Srobert 
76*d415bd75Srobert   unsigned NOps = ProfData->getNumOperands();
77*d415bd75Srobert   if (NOps < MinOps)
78*d415bd75Srobert     return false;
79*d415bd75Srobert 
80*d415bd75Srobert   auto *ProfDataName = dyn_cast<MDString>(ProfData->getOperand(0));
81*d415bd75Srobert   if (!ProfDataName)
82*d415bd75Srobert     return false;
83*d415bd75Srobert 
84*d415bd75Srobert   return ProfDataName->getString().equals(Name);
85*d415bd75Srobert }
86*d415bd75Srobert 
87*d415bd75Srobert } // namespace
88*d415bd75Srobert 
89*d415bd75Srobert namespace llvm {
90*d415bd75Srobert 
hasProfMD(const Instruction & I)91*d415bd75Srobert bool hasProfMD(const Instruction &I) {
92*d415bd75Srobert   return nullptr != I.getMetadata(LLVMContext::MD_prof);
93*d415bd75Srobert }
94*d415bd75Srobert 
isBranchWeightMD(const MDNode * ProfileData)95*d415bd75Srobert bool isBranchWeightMD(const MDNode *ProfileData) {
96*d415bd75Srobert   return isTargetMD(ProfileData, "branch_weights", MinBWOps);
97*d415bd75Srobert }
98*d415bd75Srobert 
hasBranchWeightMD(const Instruction & I)99*d415bd75Srobert bool hasBranchWeightMD(const Instruction &I) {
100*d415bd75Srobert   auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
101*d415bd75Srobert   return isBranchWeightMD(ProfileData);
102*d415bd75Srobert }
103*d415bd75Srobert 
hasValidBranchWeightMD(const Instruction & I)104*d415bd75Srobert bool hasValidBranchWeightMD(const Instruction &I) {
105*d415bd75Srobert   return getValidBranchWeightMDNode(I);
106*d415bd75Srobert }
107*d415bd75Srobert 
getBranchWeightMDNode(const Instruction & I)108*d415bd75Srobert MDNode *getBranchWeightMDNode(const Instruction &I) {
109*d415bd75Srobert   auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
110*d415bd75Srobert   if (!isBranchWeightMD(ProfileData))
111*d415bd75Srobert     return nullptr;
112*d415bd75Srobert   return ProfileData;
113*d415bd75Srobert }
114*d415bd75Srobert 
getValidBranchWeightMDNode(const Instruction & I)115*d415bd75Srobert MDNode *getValidBranchWeightMDNode(const Instruction &I) {
116*d415bd75Srobert   auto *ProfileData = getBranchWeightMDNode(I);
117*d415bd75Srobert   if (ProfileData && ProfileData->getNumOperands() == 1 + I.getNumSuccessors())
118*d415bd75Srobert     return ProfileData;
119*d415bd75Srobert   return nullptr;
120*d415bd75Srobert }
121*d415bd75Srobert 
extractBranchWeights(const MDNode * ProfileData,SmallVectorImpl<uint32_t> & Weights)122*d415bd75Srobert bool extractBranchWeights(const MDNode *ProfileData,
123*d415bd75Srobert                           SmallVectorImpl<uint32_t> &Weights) {
124*d415bd75Srobert   if (!isBranchWeightMD(ProfileData))
125*d415bd75Srobert     return false;
126*d415bd75Srobert   return extractWeights(ProfileData, Weights);
127*d415bd75Srobert }
128*d415bd75Srobert 
extractBranchWeights(const Instruction & I,SmallVectorImpl<uint32_t> & Weights)129*d415bd75Srobert bool extractBranchWeights(const Instruction &I,
130*d415bd75Srobert                           SmallVectorImpl<uint32_t> &Weights) {
131*d415bd75Srobert   auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
132*d415bd75Srobert   return extractBranchWeights(ProfileData, Weights);
133*d415bd75Srobert }
134*d415bd75Srobert 
extractBranchWeights(const Instruction & I,uint64_t & TrueVal,uint64_t & FalseVal)135*d415bd75Srobert bool extractBranchWeights(const Instruction &I, uint64_t &TrueVal,
136*d415bd75Srobert                           uint64_t &FalseVal) {
137*d415bd75Srobert   assert((I.getOpcode() == Instruction::Br ||
138*d415bd75Srobert           I.getOpcode() == Instruction::Select) &&
139*d415bd75Srobert          "Looking for branch weights on something besides branch, select, or "
140*d415bd75Srobert          "switch");
141*d415bd75Srobert 
142*d415bd75Srobert   SmallVector<uint32_t, 2> Weights;
143*d415bd75Srobert   auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
144*d415bd75Srobert   if (!extractBranchWeights(ProfileData, Weights))
145*d415bd75Srobert     return false;
146*d415bd75Srobert 
147*d415bd75Srobert   if (Weights.size() > 2)
148*d415bd75Srobert     return false;
149*d415bd75Srobert 
150*d415bd75Srobert   TrueVal = Weights[0];
151*d415bd75Srobert   FalseVal = Weights[1];
152*d415bd75Srobert   return true;
153*d415bd75Srobert }
154*d415bd75Srobert 
extractProfTotalWeight(const MDNode * ProfileData,uint64_t & TotalVal)155*d415bd75Srobert bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalVal) {
156*d415bd75Srobert   TotalVal = 0;
157*d415bd75Srobert   if (!ProfileData)
158*d415bd75Srobert     return false;
159*d415bd75Srobert 
160*d415bd75Srobert   auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0));
161*d415bd75Srobert   if (!ProfDataName)
162*d415bd75Srobert     return false;
163*d415bd75Srobert 
164*d415bd75Srobert   if (ProfDataName->getString().equals("branch_weights")) {
165*d415bd75Srobert     for (unsigned Idx = 1; Idx < ProfileData->getNumOperands(); Idx++) {
166*d415bd75Srobert       auto *V = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
167*d415bd75Srobert       assert(V && "Malformed branch_weight in MD_prof node");
168*d415bd75Srobert       TotalVal += V->getValue().getZExtValue();
169*d415bd75Srobert     }
170*d415bd75Srobert     return true;
171*d415bd75Srobert   }
172*d415bd75Srobert 
173*d415bd75Srobert   if (ProfDataName->getString().equals("VP") &&
174*d415bd75Srobert       ProfileData->getNumOperands() > 3) {
175*d415bd75Srobert     TotalVal = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(2))
176*d415bd75Srobert                    ->getValue()
177*d415bd75Srobert                    .getZExtValue();
178*d415bd75Srobert     return true;
179*d415bd75Srobert   }
180*d415bd75Srobert   return false;
181*d415bd75Srobert }
182*d415bd75Srobert 
extractProfTotalWeight(const Instruction & I,uint64_t & TotalVal)183*d415bd75Srobert bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) {
184*d415bd75Srobert   return extractProfTotalWeight(I.getMetadata(LLVMContext::MD_prof), TotalVal);
185*d415bd75Srobert }
186*d415bd75Srobert 
187*d415bd75Srobert } // namespace llvm
188