1 //===- GlobalCombinerEmitter.cpp - Generate a combiner --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file Generate a combiner implementation for GlobalISel from a declarative
10 /// syntax
11 ///
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/ADT/SmallSet.h"
15 #include "llvm/ADT/Statistic.h"
16 #include "llvm/ADT/StringSet.h"
17 #include "llvm/Support/CommandLine.h"
18 #include "llvm/Support/ScopedPrinter.h"
19 #include "llvm/Support/Timer.h"
20 #include "llvm/TableGen/Error.h"
21 #include "llvm/TableGen/StringMatcher.h"
22 #include "llvm/TableGen/TableGenBackend.h"
23 #include "CodeGenTarget.h"
24 #include "GlobalISel/CodeExpander.h"
25 #include "GlobalISel/CodeExpansions.h"
26 #include "GlobalISel/GIMatchDag.h"
27 #include "GlobalISel/GIMatchTree.h"
28 #include <cstdint>
29 
30 using namespace llvm;
31 
32 #define DEBUG_TYPE "gicombiner-emitter"
33 
// Number of combine rules accepted so far. gatherRules() increments it for
// every rule that parses successfully, and makeCombineRule() passes its
// current value as the RuleID of the rule being parsed.
// FIXME: Use ALWAYS_ENABLED_STATISTIC once it's available.
unsigned NumPatternTotal = 0;
// Statistic mirroring NumPatternTotal (NOTE(review): presumably assigned from
// it later in this file -- confirm).
STATISTIC(NumPatternTotalStatistic, "Total number of patterns");

// Category grouping this backend's command-line options.
cl::OptionCategory
    GICombinerEmitterCat("Options for -gen-global-isel-combiner");
// Combiners to emit, given as a comma-separated list.
static cl::list<std::string>
    SelectedCombiners("combiners", cl::desc("Emit the specified combiners"),
                      cl::cat(GICombinerEmitterCat), cl::CommaSeparated);
// When set, the generated code marks code expansions with C++ comments.
static cl::opt<bool> ShowExpansions(
    "gicombiner-show-expansions",
    cl::desc("Use C++ comments to indicate occurence of code expansion"),
    cl::cat(GICombinerEmitterCat));
// When set, makeCombineRule() returns each rule right after parsing it,
// skipping the later validation steps.
static cl::opt<bool> StopAfterParse(
    "gicombiner-stop-after-parse",
    cl::desc("Stop processing after parsing rules and dump state"),
    cl::cat(GICombinerEmitterCat));
// When set, processing stops once the match tree has been built.
static cl::opt<bool> StopAfterBuild(
    "gicombiner-stop-after-build",
    cl::desc("Stop processing after building the match tree"),
    cl::cat(GICombinerEmitterCat));
55 
56 namespace {
57 typedef uint64_t RuleID;
58 
59 // We're going to be referencing the same small strings quite a lot for operand
60 // names and the like. Make their lifetime management simple with a global
61 // string table.
62 StringSet<> StrTab;
63 
insertStrTab(StringRef S)64 StringRef insertStrTab(StringRef S) {
65   if (S.empty())
66     return S;
67   return StrTab.insert(S).first->first();
68 }
69 
70 class format_partition_name {
71   const GIMatchTree &Tree;
72   unsigned Idx;
73 
74 public:
format_partition_name(const GIMatchTree & Tree,unsigned Idx)75   format_partition_name(const GIMatchTree &Tree, unsigned Idx)
76       : Tree(Tree), Idx(Idx) {}
print(raw_ostream & OS) const77   void print(raw_ostream &OS) const {
78     Tree.getPartitioner()->emitPartitionName(OS, Idx);
79   }
80 };
operator <<(raw_ostream & OS,const format_partition_name & Fmt)81 raw_ostream &operator<<(raw_ostream &OS, const format_partition_name &Fmt) {
82   Fmt.print(OS);
83   return OS;
84 }
85 
86 /// Declares data that is passed from the match stage to the apply stage.
87 class MatchDataInfo {
88   /// The symbol used in the tablegen patterns
89   StringRef PatternSymbol;
90   /// The data type for the variable
91   StringRef Type;
92   /// The name of the variable as declared in the generated matcher.
93   std::string VariableName;
94 
95 public:
MatchDataInfo(StringRef PatternSymbol,StringRef Type,StringRef VariableName)96   MatchDataInfo(StringRef PatternSymbol, StringRef Type, StringRef VariableName)
97       : PatternSymbol(PatternSymbol), Type(Type), VariableName(VariableName) {}
98 
getPatternSymbol() const99   StringRef getPatternSymbol() const { return PatternSymbol; };
getType() const100   StringRef getType() const { return Type; };
getVariableName() const101   StringRef getVariableName() const { return VariableName; };
102 };
103 
104 class RootInfo {
105   StringRef PatternSymbol;
106 
107 public:
RootInfo(StringRef PatternSymbol)108   RootInfo(StringRef PatternSymbol) : PatternSymbol(PatternSymbol) {}
109 
getPatternSymbol() const110   StringRef getPatternSymbol() const { return PatternSymbol; }
111 };
112 
113 class CombineRule {
114 public:
115 
116   using const_matchdata_iterator = std::vector<MatchDataInfo>::const_iterator;
117 
118   struct VarInfo {
119     const GIMatchDagInstr *N;
120     const GIMatchDagOperand *Op;
121     const DagInit *Matcher;
122 
123   public:
VarInfo__anon1c8f329d0111::CombineRule::VarInfo124     VarInfo(const GIMatchDagInstr *N, const GIMatchDagOperand *Op,
125             const DagInit *Matcher)
126         : N(N), Op(Op), Matcher(Matcher) {}
127   };
128 
129 protected:
130   /// A unique ID for this rule
131   /// ID's are used for debugging and run-time disabling of rules among other
132   /// things.
133   RuleID ID;
134 
135   /// A unique ID that can be used for anonymous objects belonging to this rule.
136   /// Used to create unique names in makeNameForAnon*() without making tests
137   /// overly fragile.
138   unsigned UID = 0;
139 
140   /// The record defining this rule.
141   const Record &TheDef;
142 
143   /// The roots of a match. These are the leaves of the DAG that are closest to
144   /// the end of the function. I.e. the nodes that are encountered without
145   /// following any edges of the DAG described by the pattern as we work our way
146   /// from the bottom of the function to the top.
147   std::vector<RootInfo> Roots;
148 
149   GIMatchDag MatchDag;
150 
151   /// A block of arbitrary C++ to finish testing the match.
152   /// FIXME: This is a temporary measure until we have actual pattern matching
153   const CodeInit *MatchingFixupCode = nullptr;
154 
155   /// The MatchData defined by the match stage and required by the apply stage.
156   /// This allows the plumbing of arbitrary data from C++ predicates between the
157   /// stages.
158   ///
159   /// For example, suppose you have:
160   ///   %A = <some-constant-expr>
161   ///   %0 = G_ADD %1, %A
162   /// you could define a GIMatchPredicate that walks %A, constant folds as much
163   /// as possible and returns an APInt containing the discovered constant. You
164   /// could then declare:
165   ///   def apint : GIDefMatchData<"APInt">;
166   /// add it to the rule with:
167   ///   (defs root:$root, apint:$constant)
168   /// evaluate it in the pattern with a C++ function that takes a
169   /// MachineOperand& and an APInt& with:
170   ///   (match [{MIR %root = G_ADD %0, %A }],
171   ///             (constantfold operand:$A, apint:$constant))
172   /// and finally use it in the apply stage with:
173   ///   (apply (create_operand
174   ///                [{ MachineOperand::CreateImm(${constant}.getZExtValue());
175   ///                ]}, apint:$constant),
176   ///             [{MIR %root = FOO %0, %constant }])
177   std::vector<MatchDataInfo> MatchDataDecls;
178 
179   void declareMatchData(StringRef PatternSymbol, StringRef Type,
180                         StringRef VarName);
181 
182   bool parseInstructionMatcher(const CodeGenTarget &Target, StringInit *ArgName,
183                                const Init &Arg,
184                                StringMap<std::vector<VarInfo>> &NamedEdgeDefs,
185                                StringMap<std::vector<VarInfo>> &NamedEdgeUses);
186   bool parseWipMatchOpcodeMatcher(const CodeGenTarget &Target,
187                                   StringInit *ArgName, const Init &Arg);
188 
189 public:
CombineRule(const CodeGenTarget & Target,GIMatchDagContext & Ctx,RuleID ID,const Record & R)190   CombineRule(const CodeGenTarget &Target, GIMatchDagContext &Ctx, RuleID ID,
191               const Record &R)
192       : ID(ID), TheDef(R), MatchDag(Ctx) {}
193   CombineRule(const CombineRule &) = delete;
194 
195   bool parseDefs();
196   bool parseMatcher(const CodeGenTarget &Target);
197 
getID() const198   RuleID getID() const { return ID; }
allocUID()199   unsigned allocUID() { return UID++; }
getName() const200   StringRef getName() const { return TheDef.getName(); }
getDef() const201   const Record &getDef() const { return TheDef; }
getMatchingFixupCode() const202   const CodeInit *getMatchingFixupCode() const { return MatchingFixupCode; }
getNumRoots() const203   size_t getNumRoots() const { return Roots.size(); }
204 
getMatchDag()205   GIMatchDag &getMatchDag() { return MatchDag; }
getMatchDag() const206   const GIMatchDag &getMatchDag() const { return MatchDag; }
207 
208   using const_root_iterator = std::vector<RootInfo>::const_iterator;
roots_begin() const209   const_root_iterator roots_begin() const { return Roots.begin(); }
roots_end() const210   const_root_iterator roots_end() const { return Roots.end(); }
roots() const211   iterator_range<const_root_iterator> roots() const {
212     return llvm::make_range(Roots.begin(), Roots.end());
213   }
214 
matchdata_decls() const215   iterator_range<const_matchdata_iterator> matchdata_decls() const {
216     return make_range(MatchDataDecls.begin(), MatchDataDecls.end());
217   }
218 
219   /// Export expansions for this rule
declareExpansions(CodeExpansions & Expansions) const220   void declareExpansions(CodeExpansions &Expansions) const {
221     for (const auto &I : matchdata_decls())
222       Expansions.declare(I.getPatternSymbol(), I.getVariableName());
223   }
224 
225   /// The matcher will begin from the roots and will perform the match by
226   /// traversing the edges to cover the whole DAG. This function reverses DAG
227   /// edges such that everything is reachable from a root. This is part of the
228   /// preparation work for flattening the DAG into a tree.
reorientToRoots()229   void reorientToRoots() {
230     SmallSet<const GIMatchDagInstr *, 5> Roots;
231     SmallSet<const GIMatchDagInstr *, 5> Visited;
232     SmallSet<GIMatchDagEdge *, 20> EdgesRemaining;
233 
234     for (auto &I : MatchDag.roots()) {
235       Roots.insert(I);
236       Visited.insert(I);
237     }
238     for (auto &I : MatchDag.edges())
239       EdgesRemaining.insert(I);
240 
241     bool Progressed = false;
242     SmallSet<GIMatchDagEdge *, 20> EdgesToRemove;
243     while (!EdgesRemaining.empty()) {
244       for (auto EI = EdgesRemaining.begin(), EE = EdgesRemaining.end();
245            EI != EE; ++EI) {
246         if (Visited.count((*EI)->getFromMI())) {
247           if (Roots.count((*EI)->getToMI()))
248             PrintError(TheDef.getLoc(), "One or more roots are unnecessary");
249           Visited.insert((*EI)->getToMI());
250           EdgesToRemove.insert(*EI);
251           Progressed = true;
252         }
253       }
254       for (GIMatchDagEdge *ToRemove : EdgesToRemove)
255         EdgesRemaining.erase(ToRemove);
256       EdgesToRemove.clear();
257 
258       for (auto EI = EdgesRemaining.begin(), EE = EdgesRemaining.end();
259            EI != EE; ++EI) {
260         if (Visited.count((*EI)->getToMI())) {
261           (*EI)->reverse();
262           Visited.insert((*EI)->getToMI());
263           EdgesToRemove.insert(*EI);
264           Progressed = true;
265         }
266         for (GIMatchDagEdge *ToRemove : EdgesToRemove)
267           EdgesRemaining.erase(ToRemove);
268         EdgesToRemove.clear();
269       }
270 
271       if (!Progressed) {
272         LLVM_DEBUG(dbgs() << "No progress\n");
273         return;
274       }
275       Progressed = false;
276     }
277   }
278 };
279 
280 /// A convenience function to check that an Init refers to a specific def. This
281 /// is primarily useful for testing for defs and similar in DagInit's since
282 /// DagInit's support any type inside them.
isSpecificDef(const Init & N,StringRef Def)283 static bool isSpecificDef(const Init &N, StringRef Def) {
284   if (const DefInit *OpI = dyn_cast<DefInit>(&N))
285     if (OpI->getDef()->getName() == Def)
286       return true;
287   return false;
288 }
289 
290 /// A convenience function to check that an Init refers to a def that is a
291 /// subclass of the given class and coerce it to a def if it is. This is
292 /// primarily useful for testing for subclasses of GIMatchKind and similar in
293 /// DagInit's since DagInit's support any type inside them.
getDefOfSubClass(const Init & N,StringRef Cls)294 static Record *getDefOfSubClass(const Init &N, StringRef Cls) {
295   if (const DefInit *OpI = dyn_cast<DefInit>(&N))
296     if (OpI->getDef()->isSubClassOf(Cls))
297       return OpI->getDef();
298   return nullptr;
299 }
300 
301 /// A convenience function to check that an Init refers to a dag whose operator
302 /// is a specific def and coerce it to a dag if it is. This is primarily useful
303 /// for testing for subclasses of GIMatchKind and similar in DagInit's since
304 /// DagInit's support any type inside them.
getDagWithSpecificOperator(const Init & N,StringRef Name)305 static const DagInit *getDagWithSpecificOperator(const Init &N,
306                                                  StringRef Name) {
307   if (const DagInit *I = dyn_cast<DagInit>(&N))
308     if (I->getNumArgs() > 0)
309       if (const DefInit *OpI = dyn_cast<DefInit>(I->getOperator()))
310         if (OpI->getDef()->getName() == Name)
311           return I;
312   return nullptr;
313 }
314 
315 /// A convenience function to check that an Init refers to a dag whose operator
316 /// is a def that is a subclass of the given class and coerce it to a dag if it
317 /// is. This is primarily useful for testing for subclasses of GIMatchKind and
318 /// similar in DagInit's since DagInit's support any type inside them.
getDagWithOperatorOfSubClass(const Init & N,StringRef Cls)319 static const DagInit *getDagWithOperatorOfSubClass(const Init &N,
320                                                    StringRef Cls) {
321   if (const DagInit *I = dyn_cast<DagInit>(&N))
322     if (I->getNumArgs() > 0)
323       if (const DefInit *OpI = dyn_cast<DefInit>(I->getOperator()))
324         if (OpI->getDef()->isSubClassOf(Cls))
325           return I;
326   return nullptr;
327 }
328 
makeNameForAnonInstr(CombineRule & Rule)329 StringRef makeNameForAnonInstr(CombineRule &Rule) {
330   return insertStrTab(to_string(
331       format("__anon%" PRIu64 "_%u", Rule.getID(), Rule.allocUID())));
332 }
333 
makeDebugName(CombineRule & Rule,StringRef Name)334 StringRef makeDebugName(CombineRule &Rule, StringRef Name) {
335   return insertStrTab(Name.empty() ? makeNameForAnonInstr(Rule) : StringRef(Name));
336 }
337 
makeNameForAnonPredicate(CombineRule & Rule)338 StringRef makeNameForAnonPredicate(CombineRule &Rule) {
339   return insertStrTab(to_string(
340       format("__anonpred%" PRIu64 "_%u", Rule.getID(), Rule.allocUID())));
341 }
342 
declareMatchData(StringRef PatternSymbol,StringRef Type,StringRef VarName)343 void CombineRule::declareMatchData(StringRef PatternSymbol, StringRef Type,
344                                    StringRef VarName) {
345   MatchDataDecls.emplace_back(PatternSymbol, Type, VarName);
346 }
347 
parseDefs()348 bool CombineRule::parseDefs() {
349   NamedRegionTimer T("parseDefs", "Time spent parsing the defs", "Rule Parsing",
350                      "Time spent on rule parsing", TimeRegions);
351   DagInit *Defs = TheDef.getValueAsDag("Defs");
352 
353   if (Defs->getOperatorAsDef(TheDef.getLoc())->getName() != "defs") {
354     PrintError(TheDef.getLoc(), "Expected defs operator");
355     return false;
356   }
357 
358   for (unsigned I = 0, E = Defs->getNumArgs(); I < E; ++I) {
359     // Roots should be collected into Roots
360     if (isSpecificDef(*Defs->getArg(I), "root")) {
361       Roots.emplace_back(Defs->getArgNameStr(I));
362       continue;
363     }
364 
365     // Subclasses of GIDefMatchData should declare that this rule needs to pass
366     // data from the match stage to the apply stage, and ensure that the
367     // generated matcher has a suitable variable for it to do so.
368     if (Record *MatchDataRec =
369             getDefOfSubClass(*Defs->getArg(I), "GIDefMatchData")) {
370       declareMatchData(Defs->getArgNameStr(I),
371                        MatchDataRec->getValueAsString("Type"),
372                        llvm::to_string(llvm::format("MatchData%" PRIu64, ID)));
373       continue;
374     }
375 
376     // Otherwise emit an appropriate error message.
377     if (getDefOfSubClass(*Defs->getArg(I), "GIDefKind"))
378       PrintError(TheDef.getLoc(),
379                  "This GIDefKind not implemented in tablegen");
380     else if (getDefOfSubClass(*Defs->getArg(I), "GIDefKindWithArgs"))
381       PrintError(TheDef.getLoc(),
382                  "This GIDefKindWithArgs not implemented in tablegen");
383     else
384       PrintError(TheDef.getLoc(),
385                  "Expected a subclass of GIDefKind or a sub-dag whose "
386                  "operator is of type GIDefKindWithArgs");
387     return false;
388   }
389 
390   if (Roots.empty()) {
391     PrintError(TheDef.getLoc(), "Combine rules must have at least one root");
392     return false;
393   }
394   return true;
395 }
396 
397 // Parse an (Instruction $a:Arg1, $b:Arg2, ...) matcher. Edges are formed
398 // between matching operand names between different matchers.
parseInstructionMatcher(const CodeGenTarget & Target,StringInit * ArgName,const Init & Arg,StringMap<std::vector<VarInfo>> & NamedEdgeDefs,StringMap<std::vector<VarInfo>> & NamedEdgeUses)399 bool CombineRule::parseInstructionMatcher(
400     const CodeGenTarget &Target, StringInit *ArgName, const Init &Arg,
401     StringMap<std::vector<VarInfo>> &NamedEdgeDefs,
402     StringMap<std::vector<VarInfo>> &NamedEdgeUses) {
403   if (const DagInit *Matcher =
404           getDagWithOperatorOfSubClass(Arg, "Instruction")) {
405     auto &Instr =
406         Target.getInstruction(Matcher->getOperatorAsDef(TheDef.getLoc()));
407 
408     StringRef Name = ArgName ? ArgName->getValue() : "";
409 
410     GIMatchDagInstr *N =
411         MatchDag.addInstrNode(makeDebugName(*this, Name), insertStrTab(Name),
412                               MatchDag.getContext().makeOperandList(Instr));
413 
414     N->setOpcodeAnnotation(&Instr);
415     const auto &P = MatchDag.addPredicateNode<GIMatchDagOpcodePredicate>(
416         makeNameForAnonPredicate(*this), Instr);
417     MatchDag.addPredicateDependency(N, nullptr, P, &P->getOperandInfo()["mi"]);
418     unsigned OpIdx = 0;
419     for (const auto &NameInit : Matcher->getArgNames()) {
420       StringRef Name = insertStrTab(NameInit->getAsUnquotedString());
421       if (Name.empty())
422         continue;
423       N->assignNameToOperand(OpIdx, Name);
424 
425       // Record the endpoints of any named edges. We'll add the cartesian
426       // product of edges later.
427       const auto &InstrOperand = N->getOperandInfo()[OpIdx];
428       if (InstrOperand.isDef()) {
429         NamedEdgeDefs.try_emplace(Name);
430         NamedEdgeDefs[Name].emplace_back(N, &InstrOperand, Matcher);
431       } else {
432         NamedEdgeUses.try_emplace(Name);
433         NamedEdgeUses[Name].emplace_back(N, &InstrOperand, Matcher);
434       }
435 
436       if (InstrOperand.isDef()) {
437         if (find_if(Roots, [&](const RootInfo &X) {
438               return X.getPatternSymbol() == Name;
439             }) != Roots.end()) {
440           N->setMatchRoot();
441         }
442       }
443 
444       OpIdx++;
445     }
446 
447     return true;
448   }
449   return false;
450 }
451 
452 // Parse the wip_match_opcode placeholder that's temporarily present in lieu of
453 // implementing macros or choices between two matchers.
parseWipMatchOpcodeMatcher(const CodeGenTarget & Target,StringInit * ArgName,const Init & Arg)454 bool CombineRule::parseWipMatchOpcodeMatcher(const CodeGenTarget &Target,
455                                              StringInit *ArgName,
456                                              const Init &Arg) {
457   if (const DagInit *Matcher =
458           getDagWithSpecificOperator(Arg, "wip_match_opcode")) {
459     StringRef Name = ArgName ? ArgName->getValue() : "";
460 
461     GIMatchDagInstr *N =
462         MatchDag.addInstrNode(makeDebugName(*this, Name), insertStrTab(Name),
463                               MatchDag.getContext().makeEmptyOperandList());
464 
465     if (find_if(Roots, [&](const RootInfo &X) {
466           return ArgName && X.getPatternSymbol() == ArgName->getValue();
467         }) != Roots.end()) {
468       N->setMatchRoot();
469     }
470 
471     const auto &P = MatchDag.addPredicateNode<GIMatchDagOneOfOpcodesPredicate>(
472         makeNameForAnonPredicate(*this));
473     MatchDag.addPredicateDependency(N, nullptr, P, &P->getOperandInfo()["mi"]);
474     // Each argument is an opcode that will pass this predicate. Add them all to
475     // the predicate implementation
476     for (const auto &Arg : Matcher->getArgs()) {
477       Record *OpcodeDef = getDefOfSubClass(*Arg, "Instruction");
478       if (OpcodeDef) {
479         P->addOpcode(&Target.getInstruction(OpcodeDef));
480         continue;
481       }
482       PrintError(TheDef.getLoc(),
483                  "Arguments to wip_match_opcode must be instructions");
484       return false;
485     }
486     return true;
487   }
488   return false;
489 }
parseMatcher(const CodeGenTarget & Target)490 bool CombineRule::parseMatcher(const CodeGenTarget &Target) {
491   NamedRegionTimer T("parseMatcher", "Time spent parsing the matcher",
492                      "Rule Parsing", "Time spent on rule parsing", TimeRegions);
493   StringMap<std::vector<VarInfo>> NamedEdgeDefs;
494   StringMap<std::vector<VarInfo>> NamedEdgeUses;
495   DagInit *Matchers = TheDef.getValueAsDag("Match");
496 
497   if (Matchers->getOperatorAsDef(TheDef.getLoc())->getName() != "match") {
498     PrintError(TheDef.getLoc(), "Expected match operator");
499     return false;
500   }
501 
502   if (Matchers->getNumArgs() == 0) {
503     PrintError(TheDef.getLoc(), "Matcher is empty");
504     return false;
505   }
506 
507   // The match section consists of a list of matchers and predicates. Parse each
508   // one and add the equivalent GIMatchDag nodes, predicates, and edges.
509   for (unsigned I = 0; I < Matchers->getNumArgs(); ++I) {
510     if (parseInstructionMatcher(Target, Matchers->getArgName(I),
511                                 *Matchers->getArg(I), NamedEdgeDefs,
512                                 NamedEdgeUses))
513       continue;
514 
515     if (parseWipMatchOpcodeMatcher(Target, Matchers->getArgName(I),
516                                    *Matchers->getArg(I)))
517       continue;
518 
519 
520     // Parse arbitrary C++ code we have in lieu of supporting MIR matching
521     if (const CodeInit *CodeI = dyn_cast<CodeInit>(Matchers->getArg(I))) {
522       assert(!MatchingFixupCode &&
523              "Only one block of arbitrary code is currently permitted");
524       MatchingFixupCode = CodeI;
525       MatchDag.setHasPostMatchPredicate(true);
526       continue;
527     }
528 
529     PrintError(TheDef.getLoc(),
530                "Expected a subclass of GIMatchKind or a sub-dag whose "
531                "operator is either of a GIMatchKindWithArgs or Instruction");
532     PrintNote("Pattern was `" + Matchers->getArg(I)->getAsString() + "'");
533     return false;
534   }
535 
536   // Add the cartesian product of use -> def edges.
537   bool FailedToAddEdges = false;
538   for (const auto &NameAndDefs : NamedEdgeDefs) {
539     if (NameAndDefs.getValue().size() > 1) {
540       PrintError(TheDef.getLoc(),
541                  "Two different MachineInstrs cannot def the same vreg");
542       for (const auto &NameAndDefOp : NameAndDefs.getValue())
543         PrintNote("in " + to_string(*NameAndDefOp.N) + " created from " +
544                   to_string(*NameAndDefOp.Matcher) + "");
545       FailedToAddEdges = true;
546     }
547     const auto &Uses = NamedEdgeUses[NameAndDefs.getKey()];
548     for (const VarInfo &DefVar : NameAndDefs.getValue()) {
549       for (const VarInfo &UseVar : Uses) {
550         MatchDag.addEdge(insertStrTab(NameAndDefs.getKey()), UseVar.N, UseVar.Op,
551                          DefVar.N, DefVar.Op);
552       }
553     }
554   }
555   if (FailedToAddEdges)
556     return false;
557 
558   // If a variable is referenced in multiple use contexts then we need a
559   // predicate to confirm they are the same operand. We can elide this if it's
560   // also referenced in a def context and we're traversing the def-use chain
561   // from the def to the uses but we can't know which direction we're going
562   // until after reorientToRoots().
563   for (const auto &NameAndUses : NamedEdgeUses) {
564     const auto &Uses = NameAndUses.getValue();
565     if (Uses.size() > 1) {
566       const auto &LeadingVar = Uses.front();
567       for (const auto &Var : ArrayRef<VarInfo>(Uses).drop_front()) {
568         // Add a predicate for each pair until we've covered the whole
569         // equivalence set. We could test the whole set in a single predicate
570         // but that means we can't test any equivalence until all the MO's are
571         // available which can lead to wasted work matching the DAG when this
572         // predicate can already be seen to have failed.
573         //
574         // We have a similar problem due to the need to wait for a particular MO
575         // before being able to test any of them. However, that is mitigated by
576         // the order in which we build the DAG. We build from the roots outwards
577         // so by using the first recorded use in all the predicates, we are
578         // making the dependency on one of the earliest visited references in
579         // the DAG. It's not guaranteed once the generated matcher is optimized
580         // (because the factoring the common portions of rules might change the
581         // visit order) but this should mean that these predicates depend on the
582         // first MO to become available.
583         const auto &P = MatchDag.addPredicateNode<GIMatchDagSameMOPredicate>(
584             makeNameForAnonPredicate(*this));
585         MatchDag.addPredicateDependency(LeadingVar.N, LeadingVar.Op, P,
586                                         &P->getOperandInfo()["mi0"]);
587         MatchDag.addPredicateDependency(Var.N, Var.Op, P,
588                                         &P->getOperandInfo()["mi1"]);
589       }
590     }
591   }
592   return true;
593 }
594 
595 class GICombinerEmitter {
596   StringRef Name;
597   const CodeGenTarget &Target;
598   Record *Combiner;
599   std::vector<std::unique_ptr<CombineRule>> Rules;
600   GIMatchDagContext MatchDagCtx;
601 
602   std::unique_ptr<CombineRule> makeCombineRule(const Record &R);
603 
604   void gatherRules(std::vector<std::unique_ptr<CombineRule>> &ActiveRules,
605                    const std::vector<Record *> &&RulesAndGroups);
606 
607 public:
608   explicit GICombinerEmitter(RecordKeeper &RK, const CodeGenTarget &Target,
609                              StringRef Name, Record *Combiner);
~GICombinerEmitter()610   ~GICombinerEmitter() {}
611 
getClassName() const612   StringRef getClassName() const {
613     return Combiner->getValueAsString("Classname");
614   }
615   void run(raw_ostream &OS);
616 
617   /// Emit the name matcher (guarded by #ifndef NDEBUG) used to disable rules in
618   /// response to the generated cl::opt.
619   void emitNameMatcher(raw_ostream &OS) const;
620 
621   void generateDeclarationsCodeForTree(raw_ostream &OS, const GIMatchTree &Tree) const;
622   void generateCodeForTree(raw_ostream &OS, const GIMatchTree &Tree,
623                            StringRef Indent) const;
624 };
625 
GICombinerEmitter(RecordKeeper & RK,const CodeGenTarget & Target,StringRef Name,Record * Combiner)626 GICombinerEmitter::GICombinerEmitter(RecordKeeper &RK,
627                                      const CodeGenTarget &Target,
628                                      StringRef Name, Record *Combiner)
629     : Name(Name), Target(Target), Combiner(Combiner) {}
630 
emitNameMatcher(raw_ostream & OS) const631 void GICombinerEmitter::emitNameMatcher(raw_ostream &OS) const {
632   std::vector<std::pair<std::string, std::string>> Cases;
633   Cases.reserve(Rules.size());
634 
635   for (const CombineRule &EnumeratedRule : make_pointee_range(Rules)) {
636     std::string Code;
637     raw_string_ostream SS(Code);
638     SS << "return " << EnumeratedRule.getID() << ";\n";
639     Cases.push_back(std::make_pair(EnumeratedRule.getName(), SS.str()));
640   }
641 
642   OS << "static Optional<uint64_t> getRuleIdxForIdentifier(StringRef "
643         "RuleIdentifier) {\n"
644      << "  uint64_t I;\n"
645      << "  // getAtInteger(...) returns false on success\n"
646      << "  bool Parsed = !RuleIdentifier.getAsInteger(0, I);\n"
647      << "  if (Parsed)\n"
648      << "    return I;\n\n"
649      << "#ifndef NDEBUG\n";
650   StringMatcher Matcher("RuleIdentifier", Cases, OS);
651   Matcher.Emit();
652   OS << "#endif // ifndef NDEBUG\n\n"
653      << "  return None;\n"
654      << "}\n";
655 }
656 
657 std::unique_ptr<CombineRule>
makeCombineRule(const Record & TheDef)658 GICombinerEmitter::makeCombineRule(const Record &TheDef) {
659   std::unique_ptr<CombineRule> Rule =
660       std::make_unique<CombineRule>(Target, MatchDagCtx, NumPatternTotal, TheDef);
661 
662   if (!Rule->parseDefs())
663     return nullptr;
664   if (!Rule->parseMatcher(Target))
665     return nullptr;
666 
667   Rule->reorientToRoots();
668 
669   LLVM_DEBUG({
670     dbgs() << "Parsed rule defs/match for '" << Rule->getName() << "'\n";
671     Rule->getMatchDag().dump();
672     Rule->getMatchDag().writeDOTGraph(dbgs(), Rule->getName());
673   });
674   if (StopAfterParse)
675     return Rule;
676 
677   // For now, don't support traversing from def to use. We'll come back to
678   // this later once we have the algorithm changes to support it.
679   bool EmittedDefToUseError = false;
680   for (const auto &E : Rule->getMatchDag().edges()) {
681     if (E->isDefToUse()) {
682       if (!EmittedDefToUseError) {
683         PrintError(
684             TheDef.getLoc(),
685             "Generated state machine cannot lookup uses from a def (yet)");
686         EmittedDefToUseError = true;
687       }
688       PrintNote("Node " + to_string(*E->getFromMI()));
689       PrintNote("Node " + to_string(*E->getToMI()));
690       PrintNote("Edge " + to_string(*E));
691     }
692   }
693   if (EmittedDefToUseError)
694     return nullptr;
695 
696   // For now, don't support multi-root rules. We'll come back to this later
697   // once we have the algorithm changes to support it.
698   if (Rule->getNumRoots() > 1) {
699     PrintError(TheDef.getLoc(), "Multi-root matches are not supported (yet)");
700     return nullptr;
701   }
702   return Rule;
703 }
704 
705 /// Recurse into GICombineGroup's and flatten the ruleset into a simple list.
gatherRules(std::vector<std::unique_ptr<CombineRule>> & ActiveRules,const std::vector<Record * > && RulesAndGroups)706 void GICombinerEmitter::gatherRules(
707     std::vector<std::unique_ptr<CombineRule>> &ActiveRules,
708     const std::vector<Record *> &&RulesAndGroups) {
709   for (Record *R : RulesAndGroups) {
710     if (R->isValueUnset("Rules")) {
711       std::unique_ptr<CombineRule> Rule = makeCombineRule(*R);
712       if (Rule == nullptr) {
713         PrintError(R->getLoc(), "Failed to parse rule");
714         continue;
715       }
716       ActiveRules.emplace_back(std::move(Rule));
717       ++NumPatternTotal;
718     } else
719       gatherRules(ActiveRules, R->getValueAsListOfDefs("Rules"));
720   }
721 }
722 
generateCodeForTree(raw_ostream & OS,const GIMatchTree & Tree,StringRef Indent) const723 void GICombinerEmitter::generateCodeForTree(raw_ostream &OS,
724                                             const GIMatchTree &Tree,
725                                             StringRef Indent) const {
726   if (Tree.getPartitioner() != nullptr) {
727     Tree.getPartitioner()->generatePartitionSelectorCode(OS, Indent);
728     for (const auto &EnumChildren : enumerate(Tree.children())) {
729       OS << Indent << "if (Partition == " << EnumChildren.index() << " /* "
730          << format_partition_name(Tree, EnumChildren.index()) << " */) {\n";
731       generateCodeForTree(OS, EnumChildren.value(), (Indent + "  ").str());
732       OS << Indent << "}\n";
733     }
734     return;
735   }
736 
737   bool AnyFullyTested = false;
738   for (const auto &Leaf : Tree.possible_leaves()) {
739     OS << Indent << "// Leaf name: " << Leaf.getName() << "\n";
740 
741     const CombineRule *Rule = Leaf.getTargetData<CombineRule>();
742     const Record &RuleDef = Rule->getDef();
743 
744     OS << Indent << "// Rule: " << RuleDef.getName() << "\n"
745        << Indent << "if (!isRuleDisabled(" << Rule->getID() << ")) {\n";
746 
747     CodeExpansions Expansions;
748     for (const auto &VarBinding : Leaf.var_bindings()) {
749       if (VarBinding.isInstr())
750         Expansions.declare(VarBinding.getName(),
751                            "MIs[" + to_string(VarBinding.getInstrID()) + "]");
752       else
753         Expansions.declare(VarBinding.getName(),
754                            "MIs[" + to_string(VarBinding.getInstrID()) +
755                                "]->getOperand(" +
756                                to_string(VarBinding.getOpIdx()) + ")");
757     }
758     Rule->declareExpansions(Expansions);
759 
760     DagInit *Applyer = RuleDef.getValueAsDag("Apply");
761     if (Applyer->getOperatorAsDef(RuleDef.getLoc())->getName() !=
762         "apply") {
763       PrintError(RuleDef.getLoc(), "Expected apply operator");
764       return;
765     }
766 
767     OS << Indent << "  if (1\n";
768 
769     // Attempt to emit code for any untested predicates left over. Note that
770     // isFullyTested() will remain false even if we succeed here and therefore
771     // combine rule elision will not be performed. This is because we do not
772     // know if there's any connection between the predicates for each leaf and
773     // therefore can't tell if one makes another unreachable. Ideally, the
774     // partitioner(s) would be sufficiently complete to prevent us from having
775     // untested predicates left over.
776     for (const GIMatchDagPredicate *Predicate : Leaf.untested_predicates()) {
777       if (Predicate->generateCheckCode(OS, (Indent + "      ").str(),
778                                        Expansions))
779         continue;
780       PrintError(RuleDef.getLoc(),
781                  "Unable to test predicate used in rule");
782       PrintNote(SMLoc(),
783                 "This indicates an incomplete implementation in tablegen");
784       Predicate->print(errs());
785       errs() << "\n";
786       OS << Indent
787          << "llvm_unreachable(\"TableGen did not emit complete code for this "
788             "path\");\n";
789       break;
790     }
791 
792     if (Rule->getMatchingFixupCode() &&
793         !Rule->getMatchingFixupCode()->getValue().empty()) {
794       // FIXME: Single-use lambda's like this are a serious compile-time
795       // performance and memory issue. It's convenient for this early stage to
796       // defer some work to successive patches but we need to eliminate this
797       // before the ruleset grows to small-moderate size. Last time, it became
798       // a big problem for low-mem systems around the 500 rule mark but by the
799       // time we grow that large we should have merged the ISel match table
800       // mechanism with the Combiner.
801       OS << Indent << "      && [&]() {\n"
802          << Indent << "      "
803          << CodeExpander(Rule->getMatchingFixupCode()->getValue(), Expansions,
804                          Rule->getMatchingFixupCode()->getLoc(), ShowExpansions)
805          << "\n"
806          << Indent << "      return true;\n"
807          << Indent << "  }()";
808     }
809     OS << ") {\n" << Indent << "   ";
810 
811     if (const CodeInit *Code = dyn_cast<CodeInit>(Applyer->getArg(0))) {
812       OS << CodeExpander(Code->getAsUnquotedString(), Expansions,
813                          Code->getLoc(), ShowExpansions)
814          << "\n"
815          << Indent << "    return true;\n"
816          << Indent << "  }\n";
817     } else {
818       PrintError(RuleDef.getLoc(), "Expected apply code block");
819       return;
820     }
821 
822     OS << Indent << "}\n";
823 
824     assert(Leaf.isFullyTraversed());
825 
826     // If we didn't have any predicates left over and we're not using the
827     // trap-door we have to support arbitrary C++ code while we're migrating to
828     // the declarative style then we know that subsequent leaves are
829     // unreachable.
830     if (Leaf.isFullyTested() &&
831         (!Rule->getMatchingFixupCode() ||
832          Rule->getMatchingFixupCode()->getValue().empty())) {
833       AnyFullyTested = true;
834       OS << Indent
835          << "llvm_unreachable(\"Combine rule elision was incorrect\");\n"
836          << Indent << "return false;\n";
837     }
838   }
839   if (!AnyFullyTested)
840     OS << Indent << "return false;\n";
841 }
842 
run(raw_ostream & OS)843 void GICombinerEmitter::run(raw_ostream &OS) {
844   gatherRules(Rules, Combiner->getValueAsListOfDefs("Rules"));
845   if (StopAfterParse) {
846     MatchDagCtx.print(errs());
847     PrintNote(Combiner->getLoc(),
848               "Terminating due to -gicombiner-stop-after-parse");
849     return;
850   }
851   if (ErrorsPrinted)
852     PrintFatalError(Combiner->getLoc(), "Failed to parse one or more rules");
853   LLVM_DEBUG(dbgs() << "Optimizing tree for " << Rules.size() << " rules\n");
854   std::unique_ptr<GIMatchTree> Tree;
855   {
856     NamedRegionTimer T("Optimize", "Time spent optimizing the combiner",
857                        "Code Generation", "Time spent generating code",
858                        TimeRegions);
859 
860     GIMatchTreeBuilder TreeBuilder(0);
861     for (const auto &Rule : Rules) {
862       bool HadARoot = false;
863       for (const auto &Root : enumerate(Rule->getMatchDag().roots())) {
864         TreeBuilder.addLeaf(Rule->getName(), Root.index(), Rule->getMatchDag(),
865                             Rule.get());
866         HadARoot = true;
867       }
868       if (!HadARoot)
869         PrintFatalError(Rule->getDef().getLoc(), "All rules must have a root");
870     }
871 
872     Tree = TreeBuilder.run();
873   }
874   if (StopAfterBuild) {
875     Tree->writeDOTGraph(outs());
876     PrintNote(Combiner->getLoc(),
877               "Terminating due to -gicombiner-stop-after-build");
878     return;
879   }
880 
881   NamedRegionTimer T("Emit", "Time spent emitting the combiner",
882                      "Code Generation", "Time spent generating code",
883                      TimeRegions);
884   OS << "#ifdef " << Name.upper() << "_GENCOMBINERHELPER_DEPS\n"
885      << "#include \"llvm/ADT/SparseBitVector.h\"\n"
886      << "namespace llvm {\n"
887      << "extern cl::OptionCategory GICombinerOptionCategory;\n"
888      << "} // end namespace llvm\n"
889      << "#endif // ifdef " << Name.upper() << "_GENCOMBINERHELPER_DEPS\n\n";
890 
891   OS << "#ifdef " << Name.upper() << "_GENCOMBINERHELPER_H\n"
892      << "class " << getClassName() << " {\n"
893      << "  SparseBitVector<> DisabledRules;\n"
894      << "\n"
895      << "public:\n"
896      << "  bool parseCommandLineOption();\n"
897      << "  bool isRuleDisabled(unsigned ID) const;\n"
898      << "  bool setRuleDisabled(StringRef RuleIdentifier);\n"
899      << "\n"
900      << "  bool tryCombineAll(\n"
901      << "    GISelChangeObserver &Observer,\n"
902      << "    MachineInstr &MI,\n"
903      << "    MachineIRBuilder &B,\n"
904      << "    CombinerHelper &Helper) const;\n"
905      << "};\n\n";
906 
907   emitNameMatcher(OS);
908 
909   OS << "bool " << getClassName()
910      << "::setRuleDisabled(StringRef RuleIdentifier) {\n"
911      << "  std::pair<StringRef, StringRef> RangePair = "
912         "RuleIdentifier.split('-');\n"
913      << "  if (!RangePair.second.empty()) {\n"
914      << "    const auto First = getRuleIdxForIdentifier(RangePair.first);\n"
915      << "    const auto Last = getRuleIdxForIdentifier(RangePair.second);\n"
916      << "    if (!First.hasValue() || !Last.hasValue())\n"
917      << "      return false;\n"
918      << "    if (First >= Last)\n"
919      << "      report_fatal_error(\"Beginning of range should be before end of "
920         "range\");\n"
921      << "    for (auto I = First.getValue(); I < Last.getValue(); ++I)\n"
922      << "      DisabledRules.set(I);\n"
923      << "    return true;\n"
924      << "  } else {\n"
925      << "    const auto I = getRuleIdxForIdentifier(RangePair.first);\n"
926      << "    if (!I.hasValue())\n"
927      << "      return false;\n"
928      << "    DisabledRules.set(I.getValue());\n"
929      << "    return true;\n"
930      << "  }\n"
931      << "  return false;\n"
932      << "}\n";
933 
934   OS << "bool " << getClassName()
935      << "::isRuleDisabled(unsigned RuleID) const {\n"
936      << "  return DisabledRules.test(RuleID);\n"
937      << "}\n";
938   OS << "#endif // ifdef " << Name.upper() << "_GENCOMBINERHELPER_H\n\n";
939 
940   OS << "#ifdef " << Name.upper() << "_GENCOMBINERHELPER_CPP\n"
941      << "\n"
942      << "cl::list<std::string> " << Name << "Option(\n"
943      << "    \"" << Name.lower() << "-disable-rule\",\n"
944      << "    cl::desc(\"Disable one or more combiner rules temporarily in "
945      << "the " << Name << " pass\"),\n"
946      << "    cl::CommaSeparated,\n"
947      << "    cl::Hidden,\n"
948      << "    cl::cat(GICombinerOptionCategory));\n"
949      << "\n"
950      << "bool " << getClassName() << "::parseCommandLineOption() {\n"
951      << "  for (const auto &Identifier : " << Name << "Option)\n"
952      << "    if (!setRuleDisabled(Identifier))\n"
953      << "      return false;\n"
954      << "  return true;\n"
955      << "}\n\n";
956 
957   OS << "bool " << getClassName() << "::tryCombineAll(\n"
958      << "    GISelChangeObserver &Observer,\n"
959      << "    MachineInstr &MI,\n"
960      << "    MachineIRBuilder &B,\n"
961      << "    CombinerHelper &Helper) const {\n"
962      << "  MachineBasicBlock *MBB = MI.getParent();\n"
963      << "  MachineFunction *MF = MBB->getParent();\n"
964      << "  MachineRegisterInfo &MRI = MF->getRegInfo();\n"
965      << "  SmallVector<MachineInstr *, 8> MIs = { &MI };\n\n"
966      << "  (void)MBB; (void)MF; (void)MRI;\n\n";
967 
968   OS << "  // Match data\n";
969   for (const auto &Rule : Rules)
970     for (const auto &I : Rule->matchdata_decls())
971       OS << "  " << I.getType() << " " << I.getVariableName() << ";\n";
972   OS << "\n";
973 
974   OS << "  int Partition = -1;\n";
975   generateCodeForTree(OS, *Tree, "  ");
976   OS << "\n  return false;\n"
977      << "}\n"
978      << "#endif // ifdef " << Name.upper() << "_GENCOMBINERHELPER_CPP\n";
979 }
980 
981 } // end anonymous namespace
982 
983 //===----------------------------------------------------------------------===//
984 
985 namespace llvm {
EmitGICombiner(RecordKeeper & RK,raw_ostream & OS)986 void EmitGICombiner(RecordKeeper &RK, raw_ostream &OS) {
987   CodeGenTarget Target(RK);
988   emitSourceFileHeader("Global Combiner", OS);
989 
990   if (SelectedCombiners.empty())
991     PrintFatalError("No combiners selected with -combiners");
992   for (const auto &Combiner : SelectedCombiners) {
993     Record *CombinerDef = RK.getDef(Combiner);
994     if (!CombinerDef)
995       PrintFatalError("Could not find " + Combiner);
996     GICombinerEmitter(RK, Target, Combiner, CombinerDef).run(OS);
997   }
998   NumPatternTotalStatistic = NumPatternTotal;
999 }
1000 
1001 } // namespace llvm
1002