1 //===-- CFGMST.h - Minimum Spanning Tree for CFG ----------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a union-find algorithm to compute a minimum spanning
// tree for a given CFG.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef LLVM_TRANSFORMS_INSTRUMENTATION_CFGMST_H
15 #define LLVM_TRANSFORMS_INSTRUMENTATION_CFGMST_H
16 
17 #include "llvm/ADT/DenseMap.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/Analysis/BlockFrequencyInfo.h"
20 #include "llvm/Analysis/BranchProbabilityInfo.h"
21 #include "llvm/Analysis/CFG.h"
22 #include "llvm/Support/BranchProbability.h"
23 #include "llvm/Support/Debug.h"
24 #include "llvm/Support/raw_ostream.h"
25 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
26 #include <utility>
27 #include <vector>
28 
29 #define DEBUG_TYPE "cfgmst"
30 
31 using namespace llvm;
32 
33 namespace llvm {
34 
35 /// An union-find based Minimum Spanning Tree for CFG
36 ///
37 /// Implements a Union-find algorithm to compute Minimum Spanning Tree
38 /// for a given CFG.
39 template <class Edge, class BBInfo> class CFGMST {
40 public:
41   Function &F;
42 
43   // Store all the edges in CFG. It may contain some stale edges
44   // when Removed is set.
45   std::vector<std::unique_ptr<Edge>> AllEdges;
46 
47   // This map records the auxiliary information for each BB.
48   DenseMap<const BasicBlock *, std::unique_ptr<BBInfo>> BBInfos;
49 
50   // Whehter the function has an exit block with no successors.
51   // (For function with an infinite loop, this block may be absent)
52   bool ExitBlockFound = false;
53 
54   // Find the root group of the G and compress the path from G to the root.
55   BBInfo *findAndCompressGroup(BBInfo *G) {
56     if (G->Group != G)
57       G->Group = findAndCompressGroup(static_cast<BBInfo *>(G->Group));
58     return static_cast<BBInfo *>(G->Group);
59   }
60 
61   // Union BB1 and BB2 into the same group and return true.
62   // Returns false if BB1 and BB2 are already in the same group.
63   bool unionGroups(const BasicBlock *BB1, const BasicBlock *BB2) {
64     BBInfo *BB1G = findAndCompressGroup(&getBBInfo(BB1));
65     BBInfo *BB2G = findAndCompressGroup(&getBBInfo(BB2));
66 
67     if (BB1G == BB2G)
68       return false;
69 
70     // Make the smaller rank tree a direct child or the root of high rank tree.
71     if (BB1G->Rank < BB2G->Rank)
72       BB1G->Group = BB2G;
73     else {
74       BB2G->Group = BB1G;
75       // If the ranks are the same, increment root of one tree by one.
76       if (BB1G->Rank == BB2G->Rank)
77         BB1G->Rank++;
78     }
79     return true;
80   }
81 
82   // Give BB, return the auxiliary information.
83   BBInfo &getBBInfo(const BasicBlock *BB) const {
84     auto It = BBInfos.find(BB);
85     assert(It->second.get() != nullptr);
86     return *It->second.get();
87   }
88 
89   // Give BB, return the auxiliary information if it's available.
90   BBInfo *findBBInfo(const BasicBlock *BB) const {
91     auto It = BBInfos.find(BB);
92     if (It == BBInfos.end())
93       return nullptr;
94     return It->second.get();
95   }
96 
97   // Traverse the CFG using a stack. Find all the edges and assign the weight.
98   // Edges with large weight will be put into MST first so they are less likely
99   // to be instrumented.
100   void buildEdges() {
101     LLVM_DEBUG(dbgs() << "Build Edge on " << F.getName() << "\n");
102 
103     BasicBlock *Entry = &(F.getEntryBlock());
104     uint64_t EntryWeight = (BFI != nullptr ? BFI->getEntryFreq() : 2);
105     // If we want to instrument the entry count, lower the weight to 0.
106     if (InstrumentFuncEntry)
107       EntryWeight = 0;
108     Edge *EntryIncoming = nullptr, *EntryOutgoing = nullptr,
109          *ExitOutgoing = nullptr, *ExitIncoming = nullptr;
110     uint64_t MaxEntryOutWeight = 0, MaxExitOutWeight = 0, MaxExitInWeight = 0;
111 
112     // Add a fake edge to the entry.
113     EntryIncoming = &addEdge(nullptr, Entry, EntryWeight);
114     LLVM_DEBUG(dbgs() << "  Edge: from fake node to " << Entry->getName()
115                       << " w = " << EntryWeight << "\n");
116 
117     // Special handling for single BB functions.
118     if (succ_empty(Entry)) {
119       addEdge(Entry, nullptr, EntryWeight);
120       return;
121     }
122 
123     static const uint32_t CriticalEdgeMultiplier = 1000;
124 
125     for (BasicBlock &BB : F) {
126       Instruction *TI = BB.getTerminator();
127       uint64_t BBWeight =
128           (BFI != nullptr ? BFI->getBlockFreq(&BB).getFrequency() : 2);
129       uint64_t Weight = 2;
130       if (int successors = TI->getNumSuccessors()) {
131         for (int i = 0; i != successors; ++i) {
132           BasicBlock *TargetBB = TI->getSuccessor(i);
133           bool Critical = isCriticalEdge(TI, i);
134           uint64_t scaleFactor = BBWeight;
135           if (Critical) {
136             if (scaleFactor < UINT64_MAX / CriticalEdgeMultiplier)
137               scaleFactor *= CriticalEdgeMultiplier;
138             else
139               scaleFactor = UINT64_MAX;
140           }
141           if (BPI != nullptr)
142             Weight = BPI->getEdgeProbability(&BB, TargetBB).scale(scaleFactor);
143           if (Weight == 0)
144             Weight++;
145           auto *E = &addEdge(&BB, TargetBB, Weight);
146           E->IsCritical = Critical;
147           LLVM_DEBUG(dbgs() << "  Edge: from " << BB.getName() << " to "
148                             << TargetBB->getName() << "  w=" << Weight << "\n");
149 
150           // Keep track of entry/exit edges:
151           if (&BB == Entry) {
152             if (Weight > MaxEntryOutWeight) {
153               MaxEntryOutWeight = Weight;
154               EntryOutgoing = E;
155             }
156           }
157 
158           auto *TargetTI = TargetBB->getTerminator();
159           if (TargetTI && !TargetTI->getNumSuccessors()) {
160             if (Weight > MaxExitInWeight) {
161               MaxExitInWeight = Weight;
162               ExitIncoming = E;
163             }
164           }
165         }
166       } else {
167         ExitBlockFound = true;
168         Edge *ExitO = &addEdge(&BB, nullptr, BBWeight);
169         if (BBWeight > MaxExitOutWeight) {
170           MaxExitOutWeight = BBWeight;
171           ExitOutgoing = ExitO;
172         }
173         LLVM_DEBUG(dbgs() << "  Edge: from " << BB.getName() << " to fake exit"
174                           << " w = " << BBWeight << "\n");
175       }
176     }
177 
178     // Entry/exit edge adjustment heurisitic:
179     // prefer instrumenting entry edge over exit edge
180     // if possible. Those exit edges may never have a chance to be
181     // executed (for instance the program is an event handling loop)
182     // before the profile is asynchronously dumped.
183     //
184     // If EntryIncoming and ExitOutgoing has similar weight, make sure
185     // ExitOutging is selected as the min-edge. Similarly, if EntryOutgoing
186     // and ExitIncoming has similar weight, make sure ExitIncoming becomes
187     // the min-edge.
188     uint64_t EntryInWeight = EntryWeight;
189 
190     if (EntryInWeight >= MaxExitOutWeight &&
191         EntryInWeight * 2 < MaxExitOutWeight * 3) {
192       EntryIncoming->Weight = MaxExitOutWeight;
193       ExitOutgoing->Weight = EntryInWeight + 1;
194     }
195 
196     if (MaxEntryOutWeight >= MaxExitInWeight &&
197         MaxEntryOutWeight * 2 < MaxExitInWeight * 3) {
198       EntryOutgoing->Weight = MaxExitInWeight;
199       ExitIncoming->Weight = MaxEntryOutWeight + 1;
200     }
201   }
202 
203   // Sort CFG edges based on its weight.
204   void sortEdgesByWeight() {
205     llvm::stable_sort(AllEdges, [](const std::unique_ptr<Edge> &Edge1,
206                                    const std::unique_ptr<Edge> &Edge2) {
207       return Edge1->Weight > Edge2->Weight;
208     });
209   }
210 
211   // Traverse all the edges and compute the Minimum Weight Spanning Tree
212   // using union-find algorithm.
213   void computeMinimumSpanningTree() {
214     // First, put all the critical edge with landing-pad as the Dest to MST.
215     // This works around the insufficient support of critical edges split
216     // when destination BB is a landing pad.
217     for (auto &Ei : AllEdges) {
218       if (Ei->Removed)
219         continue;
220       if (Ei->IsCritical) {
221         if (Ei->DestBB && Ei->DestBB->isLandingPad()) {
222           if (unionGroups(Ei->SrcBB, Ei->DestBB))
223             Ei->InMST = true;
224         }
225       }
226     }
227 
228     for (auto &Ei : AllEdges) {
229       if (Ei->Removed)
230         continue;
231       // If we detect infinite loops, force
232       // instrumenting the entry edge:
233       if (!ExitBlockFound && Ei->SrcBB == nullptr)
234         continue;
235       if (unionGroups(Ei->SrcBB, Ei->DestBB))
236         Ei->InMST = true;
237     }
238   }
239 
240   // Dump the Debug information about the instrumentation.
241   void dumpEdges(raw_ostream &OS, const Twine &Message) const {
242     if (!Message.str().empty())
243       OS << Message << "\n";
244     OS << "  Number of Basic Blocks: " << BBInfos.size() << "\n";
245     for (auto &BI : BBInfos) {
246       const BasicBlock *BB = BI.first;
247       OS << "  BB: " << (BB == nullptr ? "FakeNode" : BB->getName()) << "  "
248          << BI.second->infoString() << "\n";
249     }
250 
251     OS << "  Number of Edges: " << AllEdges.size()
252        << " (*: Instrument, C: CriticalEdge, -: Removed)\n";
253     uint32_t Count = 0;
254     for (auto &EI : AllEdges)
255       OS << "  Edge " << Count++ << ": " << getBBInfo(EI->SrcBB).Index << "-->"
256          << getBBInfo(EI->DestBB).Index << EI->infoString() << "\n";
257   }
258 
259   // Add an edge to AllEdges with weight W.
260   Edge &addEdge(BasicBlock *Src, BasicBlock *Dest, uint64_t W) {
261     uint32_t Index = BBInfos.size();
262     auto Iter = BBInfos.end();
263     bool Inserted;
264     std::tie(Iter, Inserted) = BBInfos.insert(std::make_pair(Src, nullptr));
265     if (Inserted) {
266       // Newly inserted, update the real info.
267       Iter->second = std::move(std::make_unique<BBInfo>(Index));
268       Index++;
269     }
270     std::tie(Iter, Inserted) = BBInfos.insert(std::make_pair(Dest, nullptr));
271     if (Inserted)
272       // Newly inserted, update the real info.
273       Iter->second = std::move(std::make_unique<BBInfo>(Index));
274     AllEdges.emplace_back(new Edge(Src, Dest, W));
275     return *AllEdges.back();
276   }
277 
278   BranchProbabilityInfo *BPI;
279   BlockFrequencyInfo *BFI;
280 
281   // If function entry will be always instrumented.
282   bool InstrumentFuncEntry;
283 
284 public:
285   CFGMST(Function &Func, bool InstrumentFuncEntry_,
286          BranchProbabilityInfo *BPI_ = nullptr,
287          BlockFrequencyInfo *BFI_ = nullptr)
288       : F(Func), BPI(BPI_), BFI(BFI_),
289         InstrumentFuncEntry(InstrumentFuncEntry_) {
290     buildEdges();
291     sortEdgesByWeight();
292     computeMinimumSpanningTree();
293     if (AllEdges.size() > 1 && InstrumentFuncEntry)
294       std::iter_swap(std::move(AllEdges.begin()),
295                      std::move(AllEdges.begin() + AllEdges.size() - 1));
296   }
297 };
298 
299 } // end namespace llvm
300 
301 #undef DEBUG_TYPE // "cfgmst"
302 
303 #endif // LLVM_TRANSFORMS_INSTRUMENTATION_CFGMST_H
304