1 //===--- AMDGPUIGroupLP.cpp - AMDGPU IGroupLP  ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // \file This file defines a set of schedule DAG mutations that can be used to
10 // override default scheduler behavior to enforce specific scheduling patterns.
// They should be used in cases where runtime performance considerations, such
// as inter-wavefront interactions, mean that compile-time heuristics cannot
13 // predict the optimal instruction ordering, or in kernels where optimum
14 // instruction scheduling is important enough to warrant manual intervention.
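//
// As a rough illustration (a sketch, not tied to any particular test), these
// mutations are typically requested via the AMDGPU scheduling intrinsics,
// which are lowered to the SCHED_BARRIER, SCHED_GROUP_BARRIER and IGLP_OPT
// pseudo instructions handled in this file:
//
//   ; allow only ALU instructions to be scheduled across this point
//   call void @llvm.amdgcn.sched.barrier(i32 1)
//   ; request a group of up to 2 MFMA instructions (mask 0x8) in sync group 0
//   call void @llvm.amdgcn.sched.group.barrier(i32 8, i32 2, i32 0)
//   ; select a canned strategy by ID (see IGLPStrategyID below)
//   call void @llvm.amdgcn.iglp.opt(i32 0)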
15 //
16 //===----------------------------------------------------------------------===//
17 
18 #include "AMDGPUIGroupLP.h"
19 #include "AMDGPUTargetMachine.h"
20 #include "MCTargetDesc/AMDGPUMCTargetDesc.h"
21 #include "SIInstrInfo.h"
22 #include "SIMachineFunctionInfo.h"
23 #include "llvm/ADT/BitmaskEnum.h"
24 #include "llvm/ADT/DenseMap.h"
25 #include "llvm/CodeGen/MachineScheduler.h"
26 #include "llvm/CodeGen/TargetOpcodes.h"
27 
28 using namespace llvm;
29 
30 #define DEBUG_TYPE "igrouplp"
31 
32 namespace {
33 
34 static cl::opt<bool> EnableExactSolver(
35     "amdgpu-igrouplp-exact-solver", cl::Hidden,
36     cl::desc("Whether to use the exponential time solver to fit "
37              "the instructions to the pipeline as closely as "
38              "possible."),
39     cl::init(false));
40 
41 static cl::opt<unsigned> CutoffForExact(
42     "amdgpu-igrouplp-exact-solver-cutoff", cl::init(0), cl::Hidden,
43     cl::desc("The maximum number of scheduling group conflicts "
44              "which we attempt to solve with the exponential time "
             "exact solver. Problem sizes greater than this will "
             "be solved by the less accurate greedy algorithm. Selecting "
             "solver by size is superseded by manually selecting "
             "the solver (e.g. by amdgpu-igrouplp-exact-solver)."));
49 
50 static cl::opt<uint64_t> MaxBranchesExplored(
51     "amdgpu-igrouplp-exact-solver-max-branches", cl::init(0), cl::Hidden,
    cl::desc("The number of branches that we are willing to explore with "
             "the exact algorithm before giving up."));
54 
55 static cl::opt<bool> UseCostHeur(
56     "amdgpu-igrouplp-exact-solver-cost-heur", cl::init(true), cl::Hidden,
57     cl::desc("Whether to use the cost heuristic to make choices as we "
58              "traverse the search space using the exact solver. Defaulted "
59              "to on, and if turned off, we will use the node order -- "
60              "attempting to put the later nodes in the later sched groups. "
61              "Experimentally, results are mixed, so this should be set on a "
62              "case-by-case basis."));
63 
// Components of the mask that determines which instruction types may be
// classified into a SchedGroup.
66 enum class SchedGroupMask {
67   NONE = 0u,
68   ALU = 1u << 0,
69   VALU = 1u << 1,
70   SALU = 1u << 2,
71   MFMA = 1u << 3,
72   VMEM = 1u << 4,
73   VMEM_READ = 1u << 5,
74   VMEM_WRITE = 1u << 6,
75   DS = 1u << 7,
76   DS_READ = 1u << 8,
77   DS_WRITE = 1u << 9,
78   ALL = ALU | VALU | SALU | MFMA | VMEM | VMEM_READ | VMEM_WRITE | DS |
79         DS_READ | DS_WRITE,
80   LLVM_MARK_AS_BITMASK_ENUM(/* LargestFlag = */ ALL)
81 };
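
// As a minimal illustration of how these bits compose (values derived from the
// enum above, not from any particular workload): a SchedGroup created with a
// mask of (VALU | SALU) == 0x6 accepts both vector and scalar ALU
// instructions, while a mask of VMEM_READ == 0x20 restricts the group to
// vector-memory loads. SCHED_GROUP_BARRIER encodes the same bit layout in its
// mask operand.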
82 
83 class SchedGroup;
84 
// The InstructionRule class implements a filter which determines whether or
// not an SU maps to a given SchedGroup. It contains complementary data
// structures (e.g. Cache) to help those filters.
88 class InstructionRule {
89 protected:
90   const SIInstrInfo *TII;
91   unsigned SGID;
92   // A cache made available to the Filter to store SUnits for subsequent
93   // invocations of the Filter
94   std::optional<SmallVector<SUnit *, 4>> Cache;
95 
96 public:
97   virtual bool
98   apply(const SUnit *, const ArrayRef<SUnit *>,
99         SmallVectorImpl<SchedGroup> &) {
100     return true;
101   };
102 
103   InstructionRule(const SIInstrInfo *TII, unsigned SGID,
104                   bool NeedsCache = false)
105       : TII(TII), SGID(SGID) {
106     if (NeedsCache) {
107       Cache = SmallVector<SUnit *, 4>();
108     }
109   }
110 
111   virtual ~InstructionRule() = default;
112 };
113 
114 typedef DenseMap<SUnit *, SmallVector<int, 4>> SUnitsToCandidateSGsMap;
115 
// Classify instructions into groups to enable fine-tuned control over the
117 // scheduler. These groups may be more specific than current SchedModel
118 // instruction classes.
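// For example (illustrative only), a SchedGroup constructed as
// SchedGroup(SchedGroupMask::MFMA, 1, /*SyncID=*/0, DAG, TII) collects at most
// one MFMA/WMMA instruction, and only has ordering edges added against other
// SchedGroups that were created with SyncID 0.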
119 class SchedGroup {
120 private:
121   // Mask that defines which instruction types can be classified into this
122   // SchedGroup. The instruction types correspond to the mask from SCHED_BARRIER
123   // and SCHED_GROUP_BARRIER.
124   SchedGroupMask SGMask;
125 
126   // Maximum number of SUnits that can be added to this group.
127   std::optional<unsigned> MaxSize;
128 
129   // SchedGroups will only synchronize with other SchedGroups that have the same
130   // SyncID.
131   int SyncID = 0;
132 
133   // SGID is used to map instructions to candidate SchedGroups
134   unsigned SGID;
135 
136   // The different rules each instruction in this SchedGroup must conform to
137   SmallVector<std::shared_ptr<InstructionRule>, 4> Rules;
138 
139   // Count of the number of created SchedGroups, used to initialize SGID.
140   static unsigned NumSchedGroups;
141 
142   const SIInstrInfo *TII;
143 
  // Try to add an edge from SU A to SU B.
145   bool tryAddEdge(SUnit *A, SUnit *B);
146 
147   // Use SGMask to determine whether we can classify MI as a member of this
148   // SchedGroup object.
149   bool canAddMI(const MachineInstr &MI) const;
150 
151 public:
152   // Collection of SUnits that are classified as members of this group.
153   SmallVector<SUnit *, 32> Collection;
154 
155   ScheduleDAGInstrs *DAG;
156 
157   // Returns true if SU can be added to this SchedGroup.
158   bool canAddSU(SUnit &SU) const;
159 
  // Add DAG dependencies between all SUnits in this SchedGroup and SU. If
  // MakePred is true, SU will be a predecessor of the SUnits in this
  // SchedGroup, otherwise SU will be a successor.
163   void link(SUnit &SU, bool MakePred = false);
164 
  // Add DAG dependencies, track which edges are added, and return the count
  // of missed edges.
167   int link(SUnit &SU, bool MakePred,
168            std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges);
169 
  // Add DAG dependencies between all SUnits in this SchedGroup and SU.
  // Use the predicate to determine whether SU should be a predecessor (P =
  // true) or a successor (P = false) of this SchedGroup.
173   void link(SUnit &SU, function_ref<bool(const SUnit *A, const SUnit *B)> P);
174 
175   // Add DAG dependencies such that SUnits in this group shall be ordered
176   // before SUnits in OtherGroup.
177   void link(SchedGroup &OtherGroup);
178 
179   // Returns true if no more instructions may be added to this group.
180   bool isFull() const { return MaxSize && Collection.size() >= *MaxSize; }
181 
182   // Append a constraint that SUs must meet in order to fit into this
183   // SchedGroup. Since many rules involve the relationship between a SchedGroup
184   // and the SUnits in other SchedGroups, rules are checked at Pipeline Solve
185   // time (rather than SchedGroup init time.)
186   void addRule(std::shared_ptr<InstructionRule> NewRule) {
187     Rules.push_back(NewRule);
188   }
189 
190   // Returns true if the SU matches all rules
191   bool allowedByRules(const SUnit *SU,
192                       SmallVectorImpl<SchedGroup> &SyncPipe) const {
193     if (Rules.empty())
194       return true;
195     for (size_t I = 0; I < Rules.size(); I++) {
196       auto TheRule = Rules[I].get();
197       if (!TheRule->apply(SU, Collection, SyncPipe)) {
198         return false;
199       }
200     }
201     return true;
202   }
203 
204   // Add SU to the SchedGroup.
205   void add(SUnit &SU) {
206     LLVM_DEBUG(dbgs() << "For SchedGroup with mask "
207                       << format_hex((int)SGMask, 10, true) << " adding "
208                       << *SU.getInstr());
209     Collection.push_back(&SU);
210   }
211 
212   // Remove last element in the SchedGroup
213   void pop() { Collection.pop_back(); }
214 
215   // Identify and add all relevant SUs from the DAG to this SchedGroup.
216   void initSchedGroup();
217 
218   // Add instructions to the SchedGroup bottom up starting from RIter.
219   // PipelineInstrs is a set of instructions that should not be added to the
220   // SchedGroup even when the other conditions for adding it are satisfied.
221   // RIter will be added to the SchedGroup as well, and dependencies will be
222   // added so that RIter will always be scheduled at the end of the group.
223   void initSchedGroup(std::vector<SUnit>::reverse_iterator RIter,
224                       SUnitsToCandidateSGsMap &SyncedInstrs);
225 
226   void initSchedGroup(SUnitsToCandidateSGsMap &SyncedInstrs);
227 
228   int getSyncID() { return SyncID; }
229 
230   int getSGID() { return SGID; }
231 
232   SchedGroupMask getMask() { return SGMask; }
233 
234   SchedGroup(SchedGroupMask SGMask, std::optional<unsigned> MaxSize,
235              ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
236       : SGMask(SGMask), MaxSize(MaxSize), TII(TII), DAG(DAG) {
237     SGID = NumSchedGroups++;
238   }
239 
240   SchedGroup(SchedGroupMask SGMask, std::optional<unsigned> MaxSize, int SyncID,
241              ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
242       : SGMask(SGMask), MaxSize(MaxSize), SyncID(SyncID), TII(TII), DAG(DAG) {
243     SGID = NumSchedGroups++;
244   }
245 };
246 
247 // Remove all existing edges from a SCHED_BARRIER or SCHED_GROUP_BARRIER.
248 static void resetEdges(SUnit &SU, ScheduleDAGInstrs *DAG) {
249   assert(SU.getInstr()->getOpcode() == AMDGPU::SCHED_BARRIER ||
250          SU.getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER ||
251          SU.getInstr()->getOpcode() == AMDGPU::IGLP_OPT);
252 
253   while (!SU.Preds.empty())
254     for (auto &P : SU.Preds)
255       SU.removePred(P);
256 
257   while (!SU.Succs.empty())
258     for (auto &S : SU.Succs)
259       for (auto &SP : S.getSUnit()->Preds)
260         if (SP.getSUnit() == &SU)
261           S.getSUnit()->removePred(SP);
262 }
263 
264 typedef std::pair<SUnit *, SmallVector<int, 4>> SUToCandSGsPair;
265 typedef SmallVector<SUToCandSGsPair, 4> SUsToCandSGsVec;
266 
267 // The PipelineSolver is used to assign SUnits to SchedGroups in a pipeline
268 // in non-trivial cases. For example, if the requested pipeline is
269 // {VMEM_READ, VALU, MFMA, VMEM_READ} and we encounter a VMEM_READ instruction
270 // in the DAG, then we will have an instruction that can not be trivially
271 // assigned to a SchedGroup. The PipelineSolver class implements two algorithms
272 // to find a good solution to the pipeline -- a greedy algorithm and an exact
273 // algorithm. The exact algorithm has an exponential time complexity and should
274 // only be used for small sized problems or medium sized problems where an exact
275 // solution is highly desired.
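// As a sketch of the conflict being solved (numbers are illustrative): with
// the pipeline {VMEM_READ, VALU, MFMA, VMEM_READ}, each VMEM_READ SU carries
// the candidate SchedGroup IDs of both VMEM_READ groups. The solver picks one
// assignment per conflicted SU, scoring each choice by the number of ordering
// edges that could not be added (see MissPenalty / BestCost below), and keeps
// the lowest-cost pipeline found.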
276 class PipelineSolver {
277   ScheduleDAGMI *DAG;
278 
279   // Instructions that can be assigned to multiple SchedGroups
280   DenseMap<int, SUnitsToCandidateSGsMap> SyncedInstrs;
281   SmallVector<SUsToCandSGsVec, 4> PipelineInstrs;
282   DenseMap<int, SmallVector<SchedGroup, 4>> SyncedSchedGroups;
283   // The current working pipeline
284   SmallVector<SmallVector<SchedGroup, 4>, 4> CurrPipeline;
285   // The pipeline that has the best solution found so far
286   SmallVector<SmallVector<SchedGroup, 4>, 4> BestPipeline;
287 
288   // Whether or not we actually have any SyncedInstrs to try to solve.
289   bool NeedsSolver = false;
290 
291   // Compute an estimate of the size of search tree -- the true size is
292   // the product of each conflictedInst.Matches.size() across all SyncPipelines
293   unsigned computeProblemSize();
294 
295   // The cost penalty of not assigning a SU to a SchedGroup
296   int MissPenalty = 0;
297 
298   // Costs in terms of the number of edges we are unable to add
299   int BestCost = -1;
300   int CurrCost = 0;
301 
302   // Index pointing to the conflicting instruction that is currently being
303   // fitted
304   int CurrConflInstNo = 0;
305   // Index to the pipeline that is currently being fitted
306   int CurrSyncGroupIdx = 0;
  // The first non-trivial pipeline
308   int BeginSyncGroupIdx = 0;
309 
310   // How many branches we have explored
311   uint64_t BranchesExplored = 0;
312 
313   // The direction in which we process the candidate SchedGroups per SU
  bool IsBottomUp = true;
315 
316   // Update indices to fit next conflicting instruction
317   void advancePosition();
318   // Recede indices to attempt to find better fit for previous conflicting
319   // instruction
320   void retreatPosition();
321 
322   // The exponential time algorithm which finds the provably best fit
323   bool solveExact();
324   // The polynomial time algorithm which attempts to find a good fit
325   bool solveGreedy();
326   // Find the best SchedGroup for the current SU using the heuristic given all
327   // current information. One step in the greedy algorithm. Templated against
328   // the SchedGroup iterator (either reverse or forward).
329   template <typename T>
330   void greedyFind(std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges, T I,
331                   T E);
332   // Whether or not the current solution is optimal
333   bool checkOptimal();
  // Populate the ready list, prioritizing fewest missed edges first.
335   // Templated against the SchedGroup iterator (either reverse or forward).
336   template <typename T>
337   void populateReadyList(SmallVectorImpl<std::pair<int, int>> &ReadyList, T I,
338                          T E);
339   // Add edges corresponding to the SchedGroups as assigned by solver
340   void makePipeline();
341   // Link the SchedGroups in the best found pipeline.
  // Templated against the SchedGroup iterator (either reverse or forward).
343   template <typename T> void linkSchedGroups(T I, T E);
344   // Add the edges from the SU to the other SchedGroups in pipeline, and
345   // return the number of edges missed.
346   int addEdges(SmallVectorImpl<SchedGroup> &SyncPipeline, SUnit *SU, int SGID,
347                std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges);
348   /// Link the pipeline as if \p SU was in the SchedGroup with ID \p SGID. It
349   /// returns the cost (in terms of missed pipeline edges), and tracks the edges
350   /// added in \p AddedEdges
351   template <typename T>
352   int linkSUnit(SUnit *SU, int SGID,
353                 std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges, T I, T E);
354   /// Remove the edges passed via \p AddedEdges
355   void removeEdges(const std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges);
356   // Convert the passed in maps to arrays for bidirectional iterators
357   void convertSyncMapsToArrays();
358 
359   void reset();
360 
361 public:
  // Invoke the solver to map instructions to instruction groups. A heuristic
  // and command-line options determine whether to use the exact or the greedy
  // algorithm.
364   void solve();
365 
366   PipelineSolver(DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
367                  DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
                 ScheduleDAGMI *DAG, bool IsBottomUp = true)
369       : DAG(DAG), SyncedInstrs(SyncedInstrs),
370         SyncedSchedGroups(SyncedSchedGroups), IsBottomUp(IsBottomUp) {
371 
372     for (auto &PipelineInstrs : SyncedInstrs) {
373       if (PipelineInstrs.second.size() > 0) {
374         NeedsSolver = true;
375         break;
376       }
377     }
378 
379     if (!NeedsSolver)
380       return;
381 
382     convertSyncMapsToArrays();
383 
384     CurrPipeline = BestPipeline;
385 
386     while (static_cast<size_t>(BeginSyncGroupIdx) < PipelineInstrs.size() &&
387            PipelineInstrs[BeginSyncGroupIdx].size() == 0)
388       ++BeginSyncGroupIdx;
389 
390     if (static_cast<size_t>(BeginSyncGroupIdx) >= PipelineInstrs.size())
391       return;
392   }
393 };
394 
395 void PipelineSolver::reset() {
396 
397   for (auto &SyncPipeline : CurrPipeline) {
398     for (auto &SG : SyncPipeline) {
399       SmallVector<SUnit *, 32> TempCollection = SG.Collection;
400       SG.Collection.clear();
401       auto SchedBarr = llvm::find_if(TempCollection, [](SUnit *SU) {
402         return SU->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER;
403       });
404       if (SchedBarr != TempCollection.end())
405         SG.Collection.push_back(*SchedBarr);
406     }
407   }
408 
409   CurrSyncGroupIdx = BeginSyncGroupIdx;
410   CurrConflInstNo = 0;
411   CurrCost = 0;
412 }
413 
414 void PipelineSolver::convertSyncMapsToArrays() {
415   for (auto &SyncPipe : SyncedSchedGroups) {
416     BestPipeline.insert(BestPipeline.begin(), SyncPipe.second);
417   }
418 
419   int PipelineIDx = SyncedInstrs.size() - 1;
420   PipelineInstrs.resize(SyncedInstrs.size());
421   for (auto &SyncInstrMap : SyncedInstrs) {
422     for (auto &SUsToCandSGs : SyncInstrMap.second) {
423       if (PipelineInstrs[PipelineIDx].size() == 0) {
424         PipelineInstrs[PipelineIDx].push_back(
425             std::pair(SUsToCandSGs.first, SUsToCandSGs.second));
426         continue;
427       }
428       auto SortPosition = PipelineInstrs[PipelineIDx].begin();
429       // Insert them in sorted order -- this allows for good parsing order in
430       // the greedy algorithm
431       while (SortPosition != PipelineInstrs[PipelineIDx].end() &&
432              SUsToCandSGs.first->NodeNum > SortPosition->first->NodeNum)
433         ++SortPosition;
434       PipelineInstrs[PipelineIDx].insert(
435           SortPosition, std::pair(SUsToCandSGs.first, SUsToCandSGs.second));
436     }
437     --PipelineIDx;
438   }
439 }
440 
441 template <typename T> void PipelineSolver::linkSchedGroups(T I, T E) {
442   for (; I != E; ++I) {
443     auto &GroupA = *I;
444     for (auto J = std::next(I); J != E; ++J) {
445       auto &GroupB = *J;
446       GroupA.link(GroupB);
447     }
448   }
449 }
450 
451 void PipelineSolver::makePipeline() {
452   // Preserve the order of barrier for subsequent SchedGroupBarrier mutations
453   for (auto &SyncPipeline : BestPipeline) {
454     LLVM_DEBUG(dbgs() << "Printing SchedGroups\n");
455     for (auto &SG : SyncPipeline) {
456       LLVM_DEBUG(dbgs() << "SchedGroup with SGID " << SG.getSGID()
457                         << " has: \n");
458       SUnit *SGBarr = nullptr;
459       for (auto &SU : SG.Collection) {
460         if (SU->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER)
461           SGBarr = SU;
462         LLVM_DEBUG(dbgs() << "SU(" << SU->NodeNum << ")\n");
463       }
464       // Command line requested IGroupLP doesn't have SGBarr
465       if (!SGBarr)
466         continue;
467       resetEdges(*SGBarr, DAG);
468       SG.link(*SGBarr, false);
469     }
470   }
471 
472   for (auto &SyncPipeline : BestPipeline) {
473     IsBottomUp ? linkSchedGroups(SyncPipeline.rbegin(), SyncPipeline.rend())
474                : linkSchedGroups(SyncPipeline.begin(), SyncPipeline.end());
475   }
476 }
477 
478 template <typename T>
479 int PipelineSolver::linkSUnit(
480     SUnit *SU, int SGID, std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges,
481     T I, T E) {
482   bool MakePred = false;
483   int AddedCost = 0;
484   for (; I < E; ++I) {
485     if (I->getSGID() == SGID) {
486       MakePred = true;
487       continue;
488     }
489     auto Group = *I;
490     AddedCost += Group.link(*SU, MakePred, AddedEdges);
491     assert(AddedCost >= 0);
492   }
493   return AddedCost;
494 }
495 
496 int PipelineSolver::addEdges(
497     SmallVectorImpl<SchedGroup> &SyncPipeline, SUnit *SU, int SGID,
498     std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges) {
499 
500   // For IsBottomUp, the first SchedGroup in SyncPipeline contains the
501   // instructions that are the ultimate successors in the resultant mutation.
502   // Therefore, in such a configuration, the SchedGroups occurring before the
503   // candidate SGID are successors of the candidate SchedGroup, thus the current
504   // SU should be linked as a predecessor to SUs in those SchedGroups. The
  // opposite is true if !IsBottomUp. IsBottomUp is used for pipelines built
  // from multiple SCHED_GROUP_BARRIERs, and for IGLP_OPT strategies that
  // declare their SchedGroups bottom-up (i.e. in reverse order).
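  // For example (illustrative), with IsBottomUp and SyncPipeline = [G0, G1,
  // G2], assigning SU to G1 links SU as a successor of the SUs already in G2
  // and as a predecessor of the SUs in G0.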
508   return IsBottomUp ? linkSUnit(SU, SGID, AddedEdges, SyncPipeline.rbegin(),
509                                 SyncPipeline.rend())
510                     : linkSUnit(SU, SGID, AddedEdges, SyncPipeline.begin(),
511                                 SyncPipeline.end());
512 }
513 
514 void PipelineSolver::removeEdges(
515     const std::vector<std::pair<SUnit *, SUnit *>> &EdgesToRemove) {
516   // Only remove the edges that we have added when testing
517   // the fit.
518   for (auto &PredSuccPair : EdgesToRemove) {
519     SUnit *Pred = PredSuccPair.first;
520     SUnit *Succ = PredSuccPair.second;
521 
522     auto Match = llvm::find_if(
523         Succ->Preds, [&Pred](SDep &P) { return P.getSUnit() == Pred; });
524     if (Match != Succ->Preds.end()) {
525       assert(Match->isArtificial());
526       Succ->removePred(*Match);
527     }
528   }
529 }
530 
531 void PipelineSolver::advancePosition() {
532   ++CurrConflInstNo;
533 
534   if (static_cast<size_t>(CurrConflInstNo) >=
535       PipelineInstrs[CurrSyncGroupIdx].size()) {
536     CurrConflInstNo = 0;
537     ++CurrSyncGroupIdx;
538     // Advance to next non-trivial pipeline
539     while (static_cast<size_t>(CurrSyncGroupIdx) < PipelineInstrs.size() &&
540            PipelineInstrs[CurrSyncGroupIdx].size() == 0)
541       ++CurrSyncGroupIdx;
542   }
543 }
544 
545 void PipelineSolver::retreatPosition() {
546   assert(CurrConflInstNo >= 0);
547   assert(CurrSyncGroupIdx >= 0);
548 
549   if (CurrConflInstNo > 0) {
550     --CurrConflInstNo;
551     return;
552   }
553 
554   if (CurrConflInstNo == 0) {
555     // If we return to the starting position, we have explored
556     // the entire tree
557     if (CurrSyncGroupIdx == BeginSyncGroupIdx)
558       return;
559 
560     --CurrSyncGroupIdx;
561     // Go to previous non-trivial pipeline
562     while (PipelineInstrs[CurrSyncGroupIdx].size() == 0)
563       --CurrSyncGroupIdx;
564 
565     CurrConflInstNo = PipelineInstrs[CurrSyncGroupIdx].size() - 1;
566   }
567 }
568 
569 bool PipelineSolver::checkOptimal() {
570   if (static_cast<size_t>(CurrSyncGroupIdx) == PipelineInstrs.size()) {
571     if (BestCost == -1 || CurrCost < BestCost) {
572       BestPipeline = CurrPipeline;
573       BestCost = CurrCost;
574       LLVM_DEBUG(dbgs() << "Found Fit with cost " << BestCost << "\n");
575     }
576     assert(BestCost >= 0);
577   }
578 
579   bool DoneExploring = false;
580   if (MaxBranchesExplored > 0 && BranchesExplored >= MaxBranchesExplored)
581     DoneExploring = true;
582 
583   return (DoneExploring || BestCost == 0);
584 }
585 
586 template <typename T>
587 void PipelineSolver::populateReadyList(
588     SmallVectorImpl<std::pair<int, int>> &ReadyList, T I, T E) {
589   SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];
590   auto SyncPipeline = CurrPipeline[CurrSyncGroupIdx];
591   assert(CurrSU.second.size() >= 1);
592 
593   for (; I != E; ++I) {
594     std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;
595     int CandSGID = *I;
596     SchedGroup *Match = llvm::find_if(SyncPipeline, [CandSGID](SchedGroup &SG) {
597       return SG.getSGID() == CandSGID;
598     });
599     assert(Match);
600 
601     if (UseCostHeur) {
602       if (Match->isFull()) {
603         ReadyList.push_back(std::pair(*I, MissPenalty));
604         continue;
605       }
606 
607       int TempCost = addEdges(SyncPipeline, CurrSU.first, CandSGID, AddedEdges);
608       ReadyList.push_back(std::pair(*I, TempCost));
609       removeEdges(AddedEdges);
610     } else
611       ReadyList.push_back(std::pair(*I, -1));
612   }
613 
614   if (UseCostHeur) {
615     std::sort(ReadyList.begin(), ReadyList.end(),
616               [](std::pair<int, int> A, std::pair<int, int> B) {
617                 return A.second < B.second;
618               });
619   }
620 
621   assert(ReadyList.size() == CurrSU.second.size());
622 }
623 
624 bool PipelineSolver::solveExact() {
625   if (checkOptimal())
626     return true;
627 
628   if (static_cast<size_t>(CurrSyncGroupIdx) == PipelineInstrs.size())
629     return false;
630 
631   assert(static_cast<size_t>(CurrSyncGroupIdx) < PipelineInstrs.size());
632   assert(static_cast<size_t>(CurrConflInstNo) <
633          PipelineInstrs[CurrSyncGroupIdx].size());
634   SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];
635   LLVM_DEBUG(dbgs() << "Fitting SU(" << CurrSU.first->NodeNum
636                     << ") in Pipeline # " << CurrSyncGroupIdx << "\n");
637 
638   // SchedGroup -> Cost pairs
639   SmallVector<std::pair<int, int>, 4> ReadyList;
640   // Prioritize the candidate sched groups in terms of lowest cost first
641   IsBottomUp ? populateReadyList(ReadyList, CurrSU.second.rbegin(),
642                                  CurrSU.second.rend())
643              : populateReadyList(ReadyList, CurrSU.second.begin(),
644                                  CurrSU.second.end());
645 
646   auto I = ReadyList.begin();
647   auto E = ReadyList.end();
648   for (; I != E; ++I) {
649     // If we are trying SGs in least cost order, and the current SG is cost
650     // infeasible, then all subsequent SGs will also be cost infeasible, so we
651     // can prune.
652     if (BestCost != -1 && (CurrCost + I->second > BestCost))
653       return false;
654 
655     int CandSGID = I->first;
656     int AddedCost = 0;
657     std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;
658     auto &SyncPipeline = CurrPipeline[CurrSyncGroupIdx];
659     SchedGroup *Match;
660     for (auto &SG : SyncPipeline) {
661       if (SG.getSGID() == CandSGID)
662         Match = &SG;
663     }
664 
665     if (Match->isFull())
666       continue;
667 
668     if (!Match->allowedByRules(CurrSU.first, SyncPipeline))
669       continue;
670 
    LLVM_DEBUG(dbgs() << "Assigning to SchedGroup with Mask "
                      << (int)Match->getMask() << " and ID " << CandSGID
                      << "\n");
674     Match->add(*CurrSU.first);
675     AddedCost = addEdges(SyncPipeline, CurrSU.first, CandSGID, AddedEdges);
676     LLVM_DEBUG(dbgs() << "Cost of Assignment: " << AddedCost << "\n");
677     CurrCost += AddedCost;
678     advancePosition();
679     ++BranchesExplored;
680     bool FinishedExploring = false;
    // Only recurse if the cost after adding edges is still below the best
    // known solution; otherwise backtrack.
683     if (CurrCost < BestCost || BestCost == -1) {
684       if (solveExact()) {
685         FinishedExploring = BestCost != 0;
686         if (!FinishedExploring)
687           return true;
688       }
689     }
690 
691     retreatPosition();
692     CurrCost -= AddedCost;
693     removeEdges(AddedEdges);
694     Match->pop();
695     CurrPipeline[CurrSyncGroupIdx] = SyncPipeline;
696     if (FinishedExploring)
697       return true;
698   }
699 
700   // Try the pipeline where the current instruction is omitted
701   // Potentially if we omit a problematic instruction from the pipeline,
702   // all the other instructions can nicely fit.
703   CurrCost += MissPenalty;
704   advancePosition();
705 
706   LLVM_DEBUG(dbgs() << "NOT Assigned (" << CurrSU.first->NodeNum << ")\n");
707 
708   bool FinishedExploring = false;
709   if (CurrCost < BestCost || BestCost == -1) {
710     if (solveExact()) {
      FinishedExploring = BestCost != 0;
712       if (!FinishedExploring)
713         return true;
714     }
715   }
716 
717   retreatPosition();
718   CurrCost -= MissPenalty;
719   return FinishedExploring;
720 }
721 
722 template <typename T>
723 void PipelineSolver::greedyFind(
724     std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges, T I, T E) {
725   SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];
726   int BestNodeCost = -1;
727   int TempCost;
728   SchedGroup *BestGroup = nullptr;
729   int BestGroupID = -1;
730   auto &SyncPipeline = CurrPipeline[CurrSyncGroupIdx];
731   LLVM_DEBUG(dbgs() << "Fitting SU(" << CurrSU.first->NodeNum
732                     << ") in Pipeline # " << CurrSyncGroupIdx << "\n");
733 
734   // Since we have added the potential SchedGroups from bottom up, but
735   // traversed the DAG from top down, parse over the groups from last to
736   // first. If we fail to do this for the greedy algorithm, the solution will
737   // likely not be good in more complex cases.
738   for (; I != E; ++I) {
739     std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;
740     int CandSGID = *I;
741     SchedGroup *Match = llvm::find_if(SyncPipeline, [CandSGID](SchedGroup &SG) {
742       return SG.getSGID() == CandSGID;
743     });
744     assert(Match);
745 
746     LLVM_DEBUG(dbgs() << "Trying SGID # " << CandSGID << " with Mask "
747                       << (int)Match->getMask() << "\n");
748 
749     if (Match->isFull()) {
750       LLVM_DEBUG(dbgs() << "SGID # " << CandSGID << " is full\n");
751       continue;
752     }
753     if (!Match->allowedByRules(CurrSU.first, SyncPipeline)) {
754       LLVM_DEBUG(dbgs() << "SGID # " << CandSGID << " has conflicting rule\n");
755       continue;
756     }
757     TempCost = addEdges(SyncPipeline, CurrSU.first, CandSGID, AddedEdges);
758     LLVM_DEBUG(dbgs() << "Cost of Group " << TempCost << "\n");
759     if (TempCost < BestNodeCost || BestNodeCost == -1) {
760       BestGroup = Match;
761       BestNodeCost = TempCost;
762       BestGroupID = CandSGID;
763     }
764     removeEdges(AddedEdges);
765     if (BestNodeCost == 0)
766       break;
767   }
768 
769   if (BestGroupID != -1) {
770     BestGroup->add(*CurrSU.first);
771     addEdges(SyncPipeline, CurrSU.first, BestGroupID, AddedEdges);
    LLVM_DEBUG(dbgs() << "Best Group has ID: " << BestGroupID << " and Mask "
                      << (int)BestGroup->getMask() << "\n");
774     BestCost += TempCost;
775   } else
776     BestCost += MissPenalty;
777 
778   CurrPipeline[CurrSyncGroupIdx] = SyncPipeline;
779 }
780 
781 bool PipelineSolver::solveGreedy() {
782   BestCost = 0;
783   std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;
784 
785   while (static_cast<size_t>(CurrSyncGroupIdx) < PipelineInstrs.size()) {
786     SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];
787     IsBottomUp
788         ? greedyFind(AddedEdges, CurrSU.second.rbegin(), CurrSU.second.rend())
789         : greedyFind(AddedEdges, CurrSU.second.begin(), CurrSU.second.end());
790     advancePosition();
791   }
792   BestPipeline = CurrPipeline;
793   removeEdges(AddedEdges);
794   return false;
795 }
796 
797 unsigned PipelineSolver::computeProblemSize() {
798   unsigned ProblemSize = 0;
799   for (auto &PipeConflicts : PipelineInstrs) {
800     ProblemSize += PipeConflicts.size();
801   }
802 
803   return ProblemSize;
804 }
805 
806 void PipelineSolver::solve() {
807   if (!NeedsSolver)
808     return;
809 
810   unsigned ProblemSize = computeProblemSize();
811   assert(ProblemSize > 0);
812 
813   bool BelowCutoff = (CutoffForExact > 0) && ProblemSize <= CutoffForExact;
814   MissPenalty = (ProblemSize / 2) + 1;
815 
816   LLVM_DEBUG(DAG->dump());
817   if (EnableExactSolver || BelowCutoff) {
818     LLVM_DEBUG(dbgs() << "Starting Greedy pipeline solver\n");
819     solveGreedy();
820     reset();
821     LLVM_DEBUG(dbgs() << "Greedy produced best cost of " << BestCost << "\n");
822     if (BestCost > 0) {
823       LLVM_DEBUG(dbgs() << "Starting EXACT pipeline solver\n");
824       solveExact();
825       LLVM_DEBUG(dbgs() << "Exact produced best cost of " << BestCost << "\n");
826     }
827   } else { // Use the Greedy Algorithm by default
828     LLVM_DEBUG(dbgs() << "Starting GREEDY pipeline solver\n");
829     solveGreedy();
830   }
831 
832   makePipeline();
833   LLVM_DEBUG(dbgs() << "After applying mutation\n");
834   LLVM_DEBUG(DAG->dump());
835 }
836 
837 enum IGLPStrategyID : int {
838   MFMASmallGemmOptID = 0,
839   MFMASmallGemmSingleWaveOptID = 1,
840 };
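
// Note (illustrative): the strategy ID comes from the immediate operand of the
// IGLP_OPT pseudo (produced by the llvm.amdgcn.iglp.opt intrinsic) and is
// passed to createIGLPStrategy below; e.g. iglp_opt(1) would select
// MFMASmallGemmSingleWaveOpt. See initIGLPOpt for how the operand is read.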
841 
// Implement an IGLP scheduling strategy.
843 class IGLPStrategy {
844 protected:
845   ScheduleDAGInstrs *DAG;
846 
847   const SIInstrInfo *TII;
848 
849 public:
850   /// Add SchedGroups to \p SyncedSchedGroups to implement this Strategy.
851   virtual void applyIGLPStrategy(
852       DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
853       DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
854       bool IsReentry) = 0;
855 
856   // Returns true if this strategy should be applied to a ScheduleDAG.
857   virtual bool shouldApplyStrategy(ScheduleDAGInstrs *DAG) = 0;
858 
  bool IsBottomUp = true;
860 
861   IGLPStrategy(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
862       : DAG(DAG), TII(TII) {}
863 
864   virtual ~IGLPStrategy() = default;
865 };
866 
867 class MFMASmallGemmOpt final : public IGLPStrategy {
868 private:
869 public:
870   void applyIGLPStrategy(
871       DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
872       DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
873       bool IsReentry) override;
874 
875   bool shouldApplyStrategy(ScheduleDAGInstrs *DAG) override { return true; }
876 
877   MFMASmallGemmOpt(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
878       : IGLPStrategy(DAG, TII) {
    IsBottomUp = true;
880   }
881 };
882 
883 void MFMASmallGemmOpt::applyIGLPStrategy(
884     DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
885     DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
886     bool IsReentry) {
887   // Count the number of MFMA instructions.
888   unsigned MFMACount = 0;
889   for (const MachineInstr &I : *DAG)
890     if (TII->isMFMAorWMMA(I))
891       ++MFMACount;
892 
893   const unsigned PipelineSyncID = 0;
894   SchedGroup *SG = nullptr;
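  // Request (MFMACount * 3) repetitions of { DS(up to 2), MFMA(1) } in a
  // single sync pipeline so that DS traffic is spread between the MFMAs.
  // E.g. (illustrative) with 2 MFMAs this asks for 6 interleaved DS/MFMA
  // pairs; groups that find no matching instructions are simply left empty.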
895   for (unsigned I = 0; I < MFMACount * 3; ++I) {
896     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
897         SchedGroupMask::DS, 2, PipelineSyncID, DAG, TII);
898     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
899 
900     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
901         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
902     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
903   }
904 }
905 
906 class MFMASmallGemmSingleWaveOpt final : public IGLPStrategy {
907 private:
  // Whether the DS_READ is a predecessor of the first four MFMAs in the region
909   class EnablesInitialMFMA final : public InstructionRule {
910   public:
911     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
912                SmallVectorImpl<SchedGroup> &SyncPipe) override {
913       if (!SyncPipe.size())
914         return false;
915       int MFMAsFound = 0;
916       if (!Cache->size()) {
917         for (auto &Elt : SyncPipe[0].DAG->SUnits) {
918           if (TII->isMFMAorWMMA(*Elt.getInstr())) {
919             ++MFMAsFound;
920             if (MFMAsFound > 4)
921               break;
922             Cache->push_back(&Elt);
923           }
924         }
925       }
926 
927       assert(Cache->size());
928       auto DAG = SyncPipe[0].DAG;
929       for (auto &Elt : *Cache) {
930         if (DAG->IsReachable(Elt, const_cast<SUnit *>(SU)))
931           return true;
932       }
933       return false;
934     }
935 
936     EnablesInitialMFMA(const SIInstrInfo *TII, unsigned SGID,
937                        bool NeedsCache = false)
938         : InstructionRule(TII, SGID, NeedsCache) {}
939   };
940 
941   // Whether the MI is a V_PERM and is a predecessor of a common DS_WRITE
942   class IsPermForDSW final : public InstructionRule {
943   public:
944     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
945                SmallVectorImpl<SchedGroup> &SyncPipe) override {
946       auto MI = SU->getInstr();
947       if (MI->getOpcode() != AMDGPU::V_PERM_B32_e64)
948         return false;
949 
950       bool FitsInGroup = false;
951       // Does the VALU have a DS_WRITE successor
952       if (!Collection.size()) {
953         for (auto &Succ : SU->Succs) {
954           SUnit *SuccUnit = Succ.getSUnit();
955           if (TII->isDS(*SuccUnit->getInstr()) &&
956               SuccUnit->getInstr()->mayStore()) {
957             Cache->push_back(SuccUnit);
958             FitsInGroup = true;
959           }
960         }
961         return FitsInGroup;
962       }
963 
964       assert(Cache->size());
965 
      // Does the VALU have a DS_WRITE successor that is the same as another
      // VALU already in the group? The V_PERMs will all share one DS_WRITE
      // successor.
968       return llvm::any_of(*Cache, [&SU](SUnit *Elt) {
969         return llvm::any_of(SU->Succs, [&Elt](const SDep &ThisSucc) {
970           return ThisSucc.getSUnit() == Elt;
971         });
972       });
973     }
974 
975     IsPermForDSW(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
976         : InstructionRule(TII, SGID, NeedsCache) {}
977   };
978 
  // Whether the SU is a successor of any element in the previous SchedGroup
980   class IsSuccOfPrevGroup final : public InstructionRule {
981   public:
982     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
983                SmallVectorImpl<SchedGroup> &SyncPipe) override {
984       SchedGroup *OtherGroup = nullptr;
985       for (auto &PipeSG : SyncPipe) {
986         if ((unsigned)PipeSG.getSGID() == SGID - 1) {
987           OtherGroup = &PipeSG;
988         }
989       }
990 
991       if (!OtherGroup)
992         return false;
993       if (!OtherGroup->Collection.size())
994         return true;
995 
      // Does the previous VALU have this DS_WRITE as a successor?
997       return (std::any_of(OtherGroup->Collection.begin(),
998                           OtherGroup->Collection.end(), [&SU](SUnit *Elt) {
999                             return std::any_of(Elt->Succs.begin(),
1000                                                Elt->Succs.end(),
1001                                                [&SU](SDep &Succ) {
1002                                                  return Succ.getSUnit() == SU;
1003                                                });
1004                           }));
1005     }
1006     IsSuccOfPrevGroup(const SIInstrInfo *TII, unsigned SGID,
1007                       bool NeedsCache = false)
1008         : InstructionRule(TII, SGID, NeedsCache) {}
1009   };
1010 
  // Whether adding this SU keeps the combined load width of the group within
  // 128 bits
1012   class VMEMSize final : public InstructionRule {
1013   public:
1014     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1015                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1016       auto MI = SU->getInstr();
1017       if (MI->getOpcode() == TargetOpcode::BUNDLE)
1018         return false;
1019       if (!Collection.size())
1020         return true;
1021 
1022       int NumBits = 0;
1023 
1024       auto TRI = TII->getRegisterInfo();
1025       auto &MRI = MI->getParent()->getParent()->getRegInfo();
1026       for (auto &Elt : Collection) {
1027         auto Op = Elt->getInstr()->getOperand(0);
1028         auto Size =
1029             TRI.getRegSizeInBits(*TRI.getRegClassForOperandReg(MRI, Op));
1030         NumBits += Size;
1031       }
1032 
1033       if (NumBits < 128) {
1034         assert(TII->isVMEM(*MI) && MI->mayLoad());
1035         if (NumBits + TRI.getRegSizeInBits(*TRI.getRegClassForOperandReg(
1036                           MRI, MI->getOperand(0))) <=
1037             128)
1038           return true;
1039       }
1040 
1041       return false;
1042     }
1043 
1044     VMEMSize(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
1045         : InstructionRule(TII, SGID, NeedsCache) {}
1046   };
1047 
1048   /// Whether the SU shares a V_PERM predecessor with any SU in the SchedGroup
1049   /// that is \p Distance steps away
1050   class SharesPredWithPrevNthGroup final : public InstructionRule {
1051   private:
1052     unsigned Distance = 1;
1053 
1054   public:
1055     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1056                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1057       SchedGroup *OtherGroup = nullptr;
1058       if (!SyncPipe.size())
1059         return false;
1060 
1061       if (!Cache->size()) {
1062 
1063         for (auto &PipeSG : SyncPipe) {
1064           if ((unsigned)PipeSG.getSGID() == SGID - Distance) {
1065             OtherGroup = &PipeSG;
1066           }
1067         }
1068 
1069         if (!OtherGroup)
1070           return false;
1071         if (!OtherGroup->Collection.size())
1072           return true;
1073 
1074         for (auto &OtherEle : OtherGroup->Collection) {
1075           for (auto &Pred : OtherEle->Preds) {
1076             if (Pred.getSUnit()->getInstr()->getOpcode() ==
1077                 AMDGPU::V_PERM_B32_e64)
1078               Cache->push_back(Pred.getSUnit());
1079           }
1080         }
1081 
1082         // If the other group has no PERM preds, then this group won't share any
1083         if (!Cache->size())
1084           return false;
1085       }
1086 
1087       auto DAG = SyncPipe[0].DAG;
1088       // Does the previous DS_WRITE share a V_PERM predecessor with this
1089       // VMEM_READ
1090       return llvm::any_of(*Cache, [&SU, &DAG](SUnit *Elt) {
1091         return DAG->IsReachable(const_cast<SUnit *>(SU), Elt);
1092       });
1093     }
1094     SharesPredWithPrevNthGroup(unsigned Distance, const SIInstrInfo *TII,
1095                                unsigned SGID, bool NeedsCache = false)
1096         : InstructionRule(TII, SGID, NeedsCache), Distance(Distance) {}
1097   };
1098 
1099 public:
1100   void applyIGLPStrategy(
1101       DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
1102       DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
1103       bool IsReentry) override;
1104 
1105   bool shouldApplyStrategy(ScheduleDAGInstrs *DAG) override { return true; }
1106 
1107   MFMASmallGemmSingleWaveOpt(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
1108       : IGLPStrategy(DAG, TII) {
    IsBottomUp = false;
1110   }
1111 };
1112 
1113 static unsigned DSWCount = 0;
1114 static unsigned DSWWithPermCount = 0;
1115 static unsigned DSWWithSharedVMEMCount = 0;
1116 
1117 void MFMASmallGemmSingleWaveOpt::applyIGLPStrategy(
1118     DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
1119     DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
1120     bool IsReentry) {
1121   unsigned MFMACount = 0;
1122   unsigned DSRCount = 0;
1123 
1124   assert((IsReentry || (DSWCount == 0 && DSWWithPermCount == 0 &&
1125                         DSWWithSharedVMEMCount == 0)) &&
1126          "DSWCounters should be zero in pre-RA scheduling!");
1127   SmallVector<SUnit *, 6> DSWithPerms;
1128   for (auto &SU : DAG->SUnits) {
1129     auto I = SU.getInstr();
1130     if (TII->isMFMAorWMMA(*I))
1131       ++MFMACount;
1132     else if (TII->isDS(*I)) {
1133       if (I->mayLoad())
1134         ++DSRCount;
1135       else if (I->mayStore() && !IsReentry) {
1136         ++DSWCount;
1137         for (auto Pred : SU.Preds) {
1138           if (Pred.getSUnit()->getInstr()->getOpcode() ==
1139               AMDGPU::V_PERM_B32_e64) {
1140             DSWithPerms.push_back(&SU);
1141             break;
1142           }
1143         }
1144       }
1145     }
1146   }
1147 
1148   if (!IsReentry) {
1149     DSWWithPermCount = DSWithPerms.size();
1150     auto I = DSWithPerms.begin();
1151     auto E = DSWithPerms.end();
1152 
1153     // Get the count of DS_WRITES with V_PERM predecessors which
1154     // have loop carried dependencies (WAR) on the same VMEM_READs.
1155     // We consider partial overlap as a miss -- in other words,
1156     // for a given DS_W, we only consider another DS_W as matching
1157     // if there is a corresponding (in terms of the VMEM_R it uses) V_PERM pred
1158     // for every V_PERM pred of this DS_W.
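    // For example (illustrative): if DS_W0 and DS_W1 each have V_PERM preds
    // whose VMEM_READ successors all map back to the other DS_W, the pair is
    // counted (DSWWithSharedVMEMCount += 2). If even one V_PERM pred of DS_W1
    // feeds a VMEM_READ not already associated with DS_W0, the pair is treated
    // as a miss and not counted.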
1159     DenseMap<MachineInstr *, SUnit *> VMEMLookup;
1160     SmallVector<SUnit *, 6> Counted;
1161     for (; I != E; I++) {
1162       SUnit *Cand = nullptr;
1163       bool MissedAny = false;
1164       for (auto &Pred : (*I)->Preds) {
1165         if (Pred.getSUnit()->getInstr()->getOpcode() != AMDGPU::V_PERM_B32_e64)
1166           continue;
1167 
1168         if (Cand && llvm::is_contained(Counted, Cand))
1169           break;
1170 
1171         for (auto &Succ : Pred.getSUnit()->Succs) {
1172           auto MI = Succ.getSUnit()->getInstr();
1173           if (!TII->isVMEM(*MI) || !MI->mayLoad())
1174             continue;
1175 
1176           if (MissedAny || !VMEMLookup.size()) {
1177             MissedAny = true;
1178             VMEMLookup[MI] = *I;
1179             continue;
1180           }
1181 
1182           if (!VMEMLookup.contains(MI)) {
1183             MissedAny = true;
1184             VMEMLookup[MI] = *I;
1185             continue;
1186           }
1187 
1188           Cand = VMEMLookup[MI];
1189           if (llvm::is_contained(Counted, Cand)) {
1190             MissedAny = true;
1191             break;
1192           }
1193         }
1194       }
1195       if (!MissedAny && Cand) {
1196         DSWWithSharedVMEMCount += 2;
1197         Counted.push_back(Cand);
1198         Counted.push_back(*I);
1199       }
1200     }
1201   }
1202 
1203   assert(DSWWithSharedVMEMCount <= DSWWithPermCount);
1204   SchedGroup *SG;
1205   unsigned PipelineSyncID = 0;
1206   // For kernels with V_PERM, there are enough VALU to mix in between MFMAs
1207   if (DSWWithPermCount) {
1208     for (unsigned I = 0; I < MFMACount; I++) {
1209       SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1210           SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
1211       SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1212 
1213       SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1214           SchedGroupMask::VALU, 2, PipelineSyncID, DAG, TII);
1215       SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1216     }
1217   }
1218 
1219   PipelineSyncID = 1;
  // Phase 1: Break up DS_READ and MFMA clusters.
  // First schedule DS_READs to make the initial MFMAs ready, then interleave
  // MFMAs with DS_READ prefetches.
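  // In summary (of the code below, assuming DSRCount >= 4): DS_READ(4) gated
  // by EnablesInitialMFMA, MFMA(1), then (DSRCount - 4) repetitions of
  // { DS_READ(1), MFMA(1) }.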
1223 
1224   // Make ready initial MFMA
1225   SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1226       SchedGroupMask::DS_READ, 4, PipelineSyncID, DAG, TII);
1227   SG->addRule(std::make_shared<EnablesInitialMFMA>(TII, SG->getSGID(), true));
1228   SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1229 
1230   SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1231       SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
1232   SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1233 
1234   // Interleave MFMA with DS_READ prefetch
1235   for (unsigned I = 0; I < DSRCount - 4; ++I) {
1236     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1237         SchedGroupMask::DS_READ, 1, PipelineSyncID, DAG, TII);
1238     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1239 
1240     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1241         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
1242     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1243   }
1244 
1245   // Phase 2a: Loop carried dependency with V_PERM
1246   // Schedule VPerm & DS_WRITE as closely as possible to the VMEM_READ they
1247   // depend on. Interleave MFMA to keep XDL unit busy throughout.
1248   for (unsigned I = 0; I < DSWWithPermCount - DSWWithSharedVMEMCount; ++I) {
1249     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1250         SchedGroupMask::VALU, 4, PipelineSyncID, DAG, TII);
1251     SG->addRule(std::make_shared<IsPermForDSW>(TII, SG->getSGID(), true));
1252     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1253 
1254     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1255         SchedGroupMask::DS_WRITE, 1, PipelineSyncID, DAG, TII);
1256     SG->addRule(std::make_shared<IsSuccOfPrevGroup>(TII, SG->getSGID(), false));
1257     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1258 
1259     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1260         SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);
1261     SG->addRule(std::make_shared<SharesPredWithPrevNthGroup>(
1262         1, TII, SG->getSGID(), true));
1263     SG->addRule(std::make_shared<VMEMSize>(TII, SG->getSGID(), false));
1264     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1265 
1266     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1267         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
1268     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1269 
1270     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1271         SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);
1272     SG->addRule(std::make_shared<SharesPredWithPrevNthGroup>(
1273         3, TII, SG->getSGID(), true));
1274     SG->addRule(std::make_shared<VMEMSize>(TII, SG->getSGID(), false));
1275     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1276 
1277     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1278         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
1279     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1280   }
1281 
1282   // Phase 2b: Loop carried dependency without V_PERM
1283   // Schedule DS_WRITE as closely as possible to the VMEM_READ they depend on.
1284   // Interleave MFMA to keep XDL unit busy throughout.
1285   for (unsigned I = 0; I < DSWCount - DSWWithPermCount; I++) {
1286     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1287         SchedGroupMask::DS_WRITE, 1, PipelineSyncID, DAG, TII);
1288     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1289 
1290     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1291         SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);
1292     SG->addRule(std::make_shared<VMEMSize>(TII, SG->getSGID(), false));
1293     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1294 
1295     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1296         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
1297     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1298   }
1299 
  // Phase 2c: Loop carried dependency with V_PERM, where the VMEM_READs are
  // ultimately used by two DS_WRITEs.
1302   // Schedule VPerm & DS_WRITE as closely as possible to the VMEM_READ they
1303   // depend on. Interleave MFMA to keep XDL unit busy throughout.
1304 
1305   for (unsigned I = 0; I < DSWWithSharedVMEMCount; ++I) {
1306     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1307         SchedGroupMask::VALU, 4, PipelineSyncID, DAG, TII);
1308     SG->addRule(std::make_shared<IsPermForDSW>(TII, SG->getSGID(), true));
1309     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1310 
1311     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1312         SchedGroupMask::DS_WRITE, 1, PipelineSyncID, DAG, TII);
1313     SG->addRule(std::make_shared<IsSuccOfPrevGroup>(TII, SG->getSGID(), false));
1314     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1315 
1316     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1317         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
1318     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1319 
1320     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1321         SchedGroupMask::VALU, 4, PipelineSyncID, DAG, TII);
1322     SG->addRule(std::make_shared<IsPermForDSW>(TII, SG->getSGID(), true));
1323     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1324 
1325     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1326         SchedGroupMask::DS_WRITE, 1, PipelineSyncID, DAG, TII);
1327     SG->addRule(std::make_shared<IsSuccOfPrevGroup>(TII, SG->getSGID(), false));
1328     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1329 
1330     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1331         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
1332     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1333 
1334     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1335         SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);
1336     SG->addRule(std::make_shared<SharesPredWithPrevNthGroup>(
1337         2, TII, SG->getSGID(), true));
1338     SG->addRule(std::make_shared<VMEMSize>(TII, SG->getSGID(), false));
1339     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1340 
1341     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1342         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
1343     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1344 
1345     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1346         SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);
1347     SG->addRule(std::make_shared<SharesPredWithPrevNthGroup>(
1348         4, TII, SG->getSGID(), true));
1349     SG->addRule(std::make_shared<VMEMSize>(TII, SG->getSGID(), false));
1350     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1351 
1352     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1353         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
1354     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1355   }
1356 }
1357 
1358 static std::unique_ptr<IGLPStrategy>
1359 createIGLPStrategy(IGLPStrategyID ID, ScheduleDAGInstrs *DAG,
1360                    const SIInstrInfo *TII) {
1361   switch (ID) {
1362   case MFMASmallGemmOptID:
1363     return std::make_unique<MFMASmallGemmOpt>(DAG, TII);
1364   case MFMASmallGemmSingleWaveOptID:
1365     return std::make_unique<MFMASmallGemmSingleWaveOpt>(DAG, TII);
1366   }
1367 
1368   llvm_unreachable("Unknown IGLPStrategyID");
1369 }
1370 
1371 class IGroupLPDAGMutation : public ScheduleDAGMutation {
1372 private:
1373   const SIInstrInfo *TII;
1374 
1375   ScheduleDAGMI *DAG;
1376 
1377   // Organize lists of SchedGroups by their SyncID. SchedGroups /
1378   // SCHED_GROUP_BARRIERs with different SyncIDs will have no edges added
  // between them.
1380   DenseMap<int, SmallVector<SchedGroup, 4>> SyncedSchedGroups;
1381 
1382   // Used to track instructions that can be mapped to multiple sched groups
1383   DenseMap<int, SUnitsToCandidateSGsMap> SyncedInstrs;
1384 
1385   // Add DAG edges that enforce SCHED_BARRIER ordering.
1386   void addSchedBarrierEdges(SUnit &SU);
1387 
1388   // Use a SCHED_BARRIER's mask to identify instruction SchedGroups that should
  // not be reordered across the SCHED_BARRIER. This is used for the base
1390   // SCHED_BARRIER, and not SCHED_GROUP_BARRIER. The difference is that
1391   // SCHED_BARRIER will always block all instructions that can be classified
1392   // into a particular SchedClass, whereas SCHED_GROUP_BARRIER has a fixed size
1393   // and may only synchronize with some SchedGroups. Returns the inverse of
1394   // Mask. SCHED_BARRIER's mask describes which instruction types should be
1395   // allowed to be scheduled across it. Invert the mask to get the
1396   // SchedGroupMask of instructions that should be barred.
1397   SchedGroupMask invertSchedBarrierMask(SchedGroupMask Mask) const;
1398 
1399   // Create SchedGroups for a SCHED_GROUP_BARRIER.
1400   void initSchedGroupBarrierPipelineStage(
1401       std::vector<SUnit>::reverse_iterator RIter);
1402 
1403   void initIGLPOpt(SUnit &SU);
1404 
1405 public:
1406   void apply(ScheduleDAGInstrs *DAGInstrs) override;
1407 
1408   // The order in which the PipelineSolver should process the candidate
1409   // SchedGroup for a PipelineInstr. BOTTOM_UP will try to add SUs to the last
1410   // created SchedGroup first, and will consider that as the ultimate
1411   // predecessor group when linking. TOP_DOWN instead links and processes the
1412   // first created SchedGroup first.
  bool IsBottomUp = true;
1414 
1415   // Whether or not this is a reentry into the IGroupLPDAGMutation.
1416   bool IsReentry = false;
1417 
1418   IGroupLPDAGMutation() = default;
1419   IGroupLPDAGMutation(bool IsReentry) : IsReentry(IsReentry) {}
1420 };
1421 
1422 unsigned SchedGroup::NumSchedGroups = 0;
1423 
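// Try to add an artificial edge that makes B depend on A; returns false if the
// edge cannot be added without creating a cycle in the DAG.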
bool SchedGroup::tryAddEdge(SUnit *A, SUnit *B) {
  if (A != B && DAG->canAddEdge(B, A)) {
    DAG->addEdge(B, SDep(A, SDep::Artificial));
    return true;
  }
  return false;
}

bool SchedGroup::canAddMI(const MachineInstr &MI) const {
  bool Result = false;
  if (MI.isMetaInstruction())
    Result = false;

  else if (((SGMask & SchedGroupMask::ALU) != SchedGroupMask::NONE) &&
           (TII->isVALU(MI) || TII->isMFMAorWMMA(MI) || TII->isSALU(MI)))
    Result = true;

  else if (((SGMask & SchedGroupMask::VALU) != SchedGroupMask::NONE) &&
           TII->isVALU(MI) && !TII->isMFMAorWMMA(MI))
    Result = true;

  else if (((SGMask & SchedGroupMask::SALU) != SchedGroupMask::NONE) &&
           TII->isSALU(MI))
    Result = true;

  else if (((SGMask & SchedGroupMask::MFMA) != SchedGroupMask::NONE) &&
           TII->isMFMAorWMMA(MI))
    Result = true;

  else if (((SGMask & SchedGroupMask::VMEM) != SchedGroupMask::NONE) &&
           (TII->isVMEM(MI) || (TII->isFLAT(MI) && !TII->isDS(MI))))
    Result = true;

  else if (((SGMask & SchedGroupMask::VMEM_READ) != SchedGroupMask::NONE) &&
           MI.mayLoad() &&
           (TII->isVMEM(MI) || (TII->isFLAT(MI) && !TII->isDS(MI))))
    Result = true;

  else if (((SGMask & SchedGroupMask::VMEM_WRITE) != SchedGroupMask::NONE) &&
           MI.mayStore() &&
           (TII->isVMEM(MI) || (TII->isFLAT(MI) && !TII->isDS(MI))))
    Result = true;

  else if (((SGMask & SchedGroupMask::DS) != SchedGroupMask::NONE) &&
           TII->isDS(MI))
    Result = true;

  else if (((SGMask & SchedGroupMask::DS_READ) != SchedGroupMask::NONE) &&
           MI.mayLoad() && TII->isDS(MI))
    Result = true;

  else if (((SGMask & SchedGroupMask::DS_WRITE) != SchedGroupMask::NONE) &&
           MI.mayStore() && TII->isDS(MI))
    Result = true;

  LLVM_DEBUG(
      dbgs() << "For SchedGroup with mask " << format_hex((int)SGMask, 10, true)
             << (Result ? " could classify " : " unable to classify ") << MI);

  return Result;
}

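// Link SU to every instruction already in this SchedGroup: SU becomes a
// predecessor of each member when MakePred is set, otherwise a successor.
// Returns the number of edges that could not be added.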
int SchedGroup::link(SUnit &SU, bool MakePred,
                     std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges) {
  int MissedEdges = 0;
  for (auto *A : Collection) {
    SUnit *B = &SU;
    if (A == B || A->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER)
      continue;
    if (MakePred)
      std::swap(A, B);

    if (DAG->IsReachable(B, A))
      continue;

    // tryAddEdge returns false if there is a dependency that makes adding
    // the A->B edge impossible; otherwise it returns true.
    bool Added = tryAddEdge(A, B);
    if (Added)
      AddedEdges.push_back(std::pair(A, B));
    else
      ++MissedEdges;
  }

  return MissedEdges;
}

void SchedGroup::link(SUnit &SU, bool MakePred) {
  for (auto *A : Collection) {
    SUnit *B = &SU;
    if (A->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER)
      continue;
    if (MakePred)
      std::swap(A, B);

    tryAddEdge(A, B);
  }
}

void SchedGroup::link(SUnit &SU,
                      function_ref<bool(const SUnit *A, const SUnit *B)> P) {
  for (auto *A : Collection) {
    SUnit *B = &SU;
    if (P(A, B))
      std::swap(A, B);

    tryAddEdge(A, B);
  }
}

void SchedGroup::link(SchedGroup &OtherGroup) {
  for (auto *B : OtherGroup.Collection)
    link(*B);
}

bool SchedGroup::canAddSU(SUnit &SU) const {
  MachineInstr &MI = *SU.getInstr();
  if (MI.getOpcode() != TargetOpcode::BUNDLE)
    return canAddMI(MI);

  // Special case for bundled MIs.
  const MachineBasicBlock *MBB = MI.getParent();
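  // B starts at the first instruction after the BUNDLE header; E is then
  // advanced past every instruction that is still bundled with its
  // predecessor, so [B, E) covers the bundle's members.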
  MachineBasicBlock::instr_iterator B = MI.getIterator(), E = ++B;
  while (E != MBB->end() && E->isBundledWithPred())
    ++E;

  // Return true if all of the bundled MIs can be added to this group.
  return std::all_of(B, E, [this](MachineInstr &MI) { return canAddMI(MI); });
}

void SchedGroup::initSchedGroup() {
  for (auto &SU : DAG->SUnits) {
    if (isFull())
      break;

    if (canAddSU(SU))
      add(SU);
  }
}

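// Starting from the SCHED_GROUP_BARRIER at RIter, scan upward through the
// region and record every SU that matches this group's mask as a candidate;
// the barrier itself is added to the group and MaxSize is grown by one to make
// room for it.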
void SchedGroup::initSchedGroup(std::vector<SUnit>::reverse_iterator RIter,
                                SUnitsToCandidateSGsMap &SyncedInstrs) {
  SUnit &InitSU = *RIter;
  for (auto E = DAG->SUnits.rend(); RIter != E; ++RIter) {
    auto &SU = *RIter;
    if (isFull())
      break;

    if (canAddSU(SU))
      SyncedInstrs[&SU].push_back(SGID);
  }

  add(InitSU);
  assert(MaxSize);
  (*MaxSize)++;
}

void SchedGroup::initSchedGroup(SUnitsToCandidateSGsMap &SyncedInstrs) {
  auto I = DAG->SUnits.rbegin();
  auto E = DAG->SUnits.rend();
  for (; I != E; ++I) {
    auto &SU = *I;
    if (isFull())
      break;

    if (canAddSU(SU))
      SyncedInstrs[&SU].push_back(SGID);
  }
}

void IGroupLPDAGMutation::apply(ScheduleDAGInstrs *DAGInstrs) {
  const TargetSchedModel *TSchedModel = DAGInstrs->getSchedModel();
  if (!TSchedModel || DAGInstrs->SUnits.empty())
    return;

  LLVM_DEBUG(dbgs() << "Applying IGroupLPDAGMutation...\n");
  const GCNSubtarget &ST = DAGInstrs->MF.getSubtarget<GCNSubtarget>();
  TII = ST.getInstrInfo();
  DAG = static_cast<ScheduleDAGMI *>(DAGInstrs);
  SyncedSchedGroups.clear();
  SyncedInstrs.clear();
  bool foundSB = false;
  bool foundIGLP = false;
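  // Walk the region bottom-up; a SCHED_GROUP_BARRIER forms its group from the
  // instructions that precede it in program order.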
  for (auto R = DAG->SUnits.rbegin(), E = DAG->SUnits.rend(); R != E; ++R) {
    unsigned Opc = R->getInstr()->getOpcode();
    // SCHED_[GROUP_]BARRIER and IGLP are mutually exclusive.
    if (Opc == AMDGPU::SCHED_BARRIER) {
      addSchedBarrierEdges(*R);
      foundSB = true;
    } else if (Opc == AMDGPU::SCHED_GROUP_BARRIER) {
      initSchedGroupBarrierPipelineStage(R);
      foundSB = true;
    } else if (Opc == AMDGPU::IGLP_OPT) {
      resetEdges(*R, DAG);
      if (!foundSB && !foundIGLP)
        initIGLPOpt(*R);
      foundIGLP = true;
    }
  }

  if (foundSB || foundIGLP) {
    PipelineSolver PS(SyncedSchedGroups, SyncedInstrs, DAG, IsBottomUp);
    // PipelineSolver performs the mutation by adding the edges it
    // determined to be the best.
    PS.solve();
  }
}

void IGroupLPDAGMutation::addSchedBarrierEdges(SUnit &SchedBarrier) {
  MachineInstr &MI = *SchedBarrier.getInstr();
  assert(MI.getOpcode() == AMDGPU::SCHED_BARRIER);
  // Remove all existing edges from the SCHED_BARRIER that were added due to the
  // instruction having side effects.
  resetEdges(SchedBarrier, DAG);
  auto InvertedMask =
      invertSchedBarrierMask((SchedGroupMask)MI.getOperand(0).getImm());
  SchedGroup SG(InvertedMask, std::nullopt, DAG, TII);
  SG.initSchedGroup();
  // Preserve original instruction ordering relative to the SCHED_BARRIER.
  SG.link(
      SchedBarrier,
      (function_ref<bool(const SUnit *A, const SUnit *B)>)[](
          const SUnit *A, const SUnit *B) { return A->NodeNum > B->NodeNum; });
}

SchedGroupMask
IGroupLPDAGMutation::invertSchedBarrierMask(SchedGroupMask Mask) const {
  // Invert mask and erase bits for types of instructions that are implied to be
  // allowed past the SCHED_BARRIER.
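  // For example, a SCHED_BARRIER mask that only lets DS_READ cross initially
  // yields an inverted mask that bars DS and DS_WRITE as well; the rules below
  // then clear the DS bit, since barring all DS instructions would contradict
  // letting DS reads through.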
  SchedGroupMask InvertedMask = ~Mask;

  // ALU implies VALU, SALU, MFMA.
  if ((InvertedMask & SchedGroupMask::ALU) == SchedGroupMask::NONE)
    InvertedMask &=
        ~SchedGroupMask::VALU & ~SchedGroupMask::SALU & ~SchedGroupMask::MFMA;
  // VALU, SALU, MFMA implies ALU.
  else if ((InvertedMask & SchedGroupMask::VALU) == SchedGroupMask::NONE ||
           (InvertedMask & SchedGroupMask::SALU) == SchedGroupMask::NONE ||
           (InvertedMask & SchedGroupMask::MFMA) == SchedGroupMask::NONE)
    InvertedMask &= ~SchedGroupMask::ALU;

  // VMEM implies VMEM_READ, VMEM_WRITE.
  if ((InvertedMask & SchedGroupMask::VMEM) == SchedGroupMask::NONE)
    InvertedMask &= ~SchedGroupMask::VMEM_READ & ~SchedGroupMask::VMEM_WRITE;
  // VMEM_READ, VMEM_WRITE implies VMEM.
  else if ((InvertedMask & SchedGroupMask::VMEM_READ) == SchedGroupMask::NONE ||
           (InvertedMask & SchedGroupMask::VMEM_WRITE) == SchedGroupMask::NONE)
    InvertedMask &= ~SchedGroupMask::VMEM;

  // DS implies DS_READ, DS_WRITE.
  if ((InvertedMask & SchedGroupMask::DS) == SchedGroupMask::NONE)
    InvertedMask &= ~SchedGroupMask::DS_READ & ~SchedGroupMask::DS_WRITE;
  // DS_READ, DS_WRITE implies DS.
  else if ((InvertedMask & SchedGroupMask::DS_READ) == SchedGroupMask::NONE ||
           (InvertedMask & SchedGroupMask::DS_WRITE) == SchedGroupMask::NONE)
    InvertedMask &= ~SchedGroupMask::DS;

  return InvertedMask;
}

void IGroupLPDAGMutation::initSchedGroupBarrierPipelineStage(
    std::vector<SUnit>::reverse_iterator RIter) {
  // Remove all existing edges from the SCHED_GROUP_BARRIER that were added due
  // to the instruction having side effects.
  resetEdges(*RIter, DAG);
  MachineInstr &SGB = *RIter->getInstr();
  assert(SGB.getOpcode() == AMDGPU::SCHED_GROUP_BARRIER);
  int32_t SGMask = SGB.getOperand(0).getImm();
  int32_t Size = SGB.getOperand(1).getImm();
  int32_t SyncID = SGB.getOperand(2).getImm();
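  // For example, a SCHED_GROUP_BARRIER with operands (8, 2, 0) describes a
  // group of two MFMA instructions (mask 0x8 = MFMA) in synchronization
  // pipeline 0.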

  auto &SG = SyncedSchedGroups[SyncID].emplace_back((SchedGroupMask)SGMask,
                                                    Size, SyncID, DAG, TII);

  SG.initSchedGroup(RIter, SyncedInstrs[SG.getSyncID()]);
}

void IGroupLPDAGMutation::initIGLPOpt(SUnit &SU) {
  IGLPStrategyID StrategyID =
      (IGLPStrategyID)SU.getInstr()->getOperand(0).getImm();
  auto S = createIGLPStrategy(StrategyID, DAG, TII);
  if (S->shouldApplyStrategy(DAG)) {
    IsBottomUp = S->IsBottomUp;
    S->applyIGLPStrategy(SyncedInstrs, SyncedSchedGroups, IsReentry);
  }
}

} // namespace

namespace llvm {

/// \p IsReentry specifies whether or not this is a reentry into the
/// IGroupLPDAGMutation. Since there may be multiple scheduling passes on the
/// same scheduling region (e.g. pre- and post-RA scheduling / multiple
/// scheduling "phases"), we can reenter this mutation framework more than once
/// for a given region.
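///
/// A typical caller registers this mutation with the scheduling DAG, e.g.
/// (illustrative, not taken from this file):
///   DAG->addMutation(createIGroupLPDAGMutation(/*IsReentry=*/false));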
std::unique_ptr<ScheduleDAGMutation> createIGroupLPDAGMutation(bool IsReentry) {
  return std::make_unique<IGroupLPDAGMutation>(IsReentry);
}

} // end namespace llvm