1 //===- GlobalCombinerEmitter.cpp - Generate a combiner --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file Generate a combiner implementation for GlobalISel from a declarative
10 /// syntax
11 ///
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/ADT/SmallSet.h"
15 #include "llvm/ADT/Statistic.h"
16 #include "llvm/ADT/StringSet.h"
17 #include "llvm/Support/CommandLine.h"
18 #include "llvm/Support/Debug.h"
19 #include "llvm/Support/ScopedPrinter.h"
20 #include "llvm/Support/Timer.h"
21 #include "llvm/TableGen/Error.h"
22 #include "llvm/TableGen/StringMatcher.h"
23 #include "llvm/TableGen/TableGenBackend.h"
24 #include "CodeGenTarget.h"
25 #include "GlobalISel/CodeExpander.h"
26 #include "GlobalISel/CodeExpansions.h"
27 #include "GlobalISel/GIMatchDag.h"
28 #include "GlobalISel/GIMatchTree.h"
29 #include <cstdint>
30 
31 using namespace llvm;
32 
33 #define DEBUG_TYPE "gicombiner-emitter"
34 
35 // FIXME: Use ALWAYS_ENABLED_STATISTIC once it's available.
36 unsigned NumPatternTotal = 0;
37 STATISTIC(NumPatternTotalStatistic, "Total number of patterns");
38 
39 cl::OptionCategory
40     GICombinerEmitterCat("Options for -gen-global-isel-combiner");
41 static cl::list<std::string>
42     SelectedCombiners("combiners", cl::desc("Emit the specified combiners"),
43                       cl::cat(GICombinerEmitterCat), cl::CommaSeparated);
44 static cl::opt<bool> ShowExpansions(
45     "gicombiner-show-expansions",
46     cl::desc("Use C++ comments to indicate occurence of code expansion"),
47     cl::cat(GICombinerEmitterCat));
48 static cl::opt<bool> StopAfterParse(
49     "gicombiner-stop-after-parse",
50     cl::desc("Stop processing after parsing rules and dump state"),
51     cl::cat(GICombinerEmitterCat));
52 static cl::opt<bool> StopAfterBuild(
53     "gicombiner-stop-after-build",
54     cl::desc("Stop processing after building the match tree"),
55     cl::cat(GICombinerEmitterCat));
56 
57 namespace {
58 typedef uint64_t RuleID;
59 
60 // We're going to be referencing the same small strings quite a lot for operand
61 // names and the like. Make their lifetime management simple with a global
62 // string table.
63 StringSet<> StrTab;
64 
insertStrTab(StringRef S)65 StringRef insertStrTab(StringRef S) {
66   if (S.empty())
67     return S;
68   return StrTab.insert(S).first->first();
69 }
70 
71 class format_partition_name {
72   const GIMatchTree &Tree;
73   unsigned Idx;
74 
75 public:
format_partition_name(const GIMatchTree & Tree,unsigned Idx)76   format_partition_name(const GIMatchTree &Tree, unsigned Idx)
77       : Tree(Tree), Idx(Idx) {}
print(raw_ostream & OS) const78   void print(raw_ostream &OS) const {
79     Tree.getPartitioner()->emitPartitionName(OS, Idx);
80   }
81 };
operator <<(raw_ostream & OS,const format_partition_name & Fmt)82 raw_ostream &operator<<(raw_ostream &OS, const format_partition_name &Fmt) {
83   Fmt.print(OS);
84   return OS;
85 }
86 
87 /// Declares data that is passed from the match stage to the apply stage.
88 class MatchDataInfo {
89   /// The symbol used in the tablegen patterns
90   StringRef PatternSymbol;
91   /// The data type for the variable
92   StringRef Type;
93   /// The name of the variable as declared in the generated matcher.
94   std::string VariableName;
95 
96 public:
MatchDataInfo(StringRef PatternSymbol,StringRef Type,StringRef VariableName)97   MatchDataInfo(StringRef PatternSymbol, StringRef Type, StringRef VariableName)
98       : PatternSymbol(PatternSymbol), Type(Type), VariableName(VariableName) {}
99 
getPatternSymbol() const100   StringRef getPatternSymbol() const { return PatternSymbol; };
getType() const101   StringRef getType() const { return Type; };
getVariableName() const102   StringRef getVariableName() const { return VariableName; };
103 };
104 
105 class RootInfo {
106   StringRef PatternSymbol;
107 
108 public:
RootInfo(StringRef PatternSymbol)109   RootInfo(StringRef PatternSymbol) : PatternSymbol(PatternSymbol) {}
110 
getPatternSymbol() const111   StringRef getPatternSymbol() const { return PatternSymbol; }
112 };
113 
114 class CombineRule {
115 public:
116 
117   using const_matchdata_iterator = std::vector<MatchDataInfo>::const_iterator;
118 
119   struct VarInfo {
120     const GIMatchDagInstr *N;
121     const GIMatchDagOperand *Op;
122     const DagInit *Matcher;
123 
124   public:
VarInfo__anonc1d7ee6f0111::CombineRule::VarInfo125     VarInfo(const GIMatchDagInstr *N, const GIMatchDagOperand *Op,
126             const DagInit *Matcher)
127         : N(N), Op(Op), Matcher(Matcher) {}
128   };
129 
130 protected:
131   /// A unique ID for this rule
132   /// ID's are used for debugging and run-time disabling of rules among other
133   /// things.
134   RuleID ID;
135 
136   /// A unique ID that can be used for anonymous objects belonging to this rule.
137   /// Used to create unique names in makeNameForAnon*() without making tests
138   /// overly fragile.
139   unsigned UID = 0;
140 
141   /// The record defining this rule.
142   const Record &TheDef;
143 
144   /// The roots of a match. These are the leaves of the DAG that are closest to
145   /// the end of the function. I.e. the nodes that are encountered without
146   /// following any edges of the DAG described by the pattern as we work our way
147   /// from the bottom of the function to the top.
148   std::vector<RootInfo> Roots;
149 
150   GIMatchDag MatchDag;
151 
152   /// A block of arbitrary C++ to finish testing the match.
153   /// FIXME: This is a temporary measure until we have actual pattern matching
154   const StringInit *MatchingFixupCode = nullptr;
155 
156   /// The MatchData defined by the match stage and required by the apply stage.
157   /// This allows the plumbing of arbitrary data from C++ predicates between the
158   /// stages.
159   ///
160   /// For example, suppose you have:
161   ///   %A = <some-constant-expr>
162   ///   %0 = G_ADD %1, %A
163   /// you could define a GIMatchPredicate that walks %A, constant folds as much
164   /// as possible and returns an APInt containing the discovered constant. You
165   /// could then declare:
166   ///   def apint : GIDefMatchData<"APInt">;
167   /// add it to the rule with:
168   ///   (defs root:$root, apint:$constant)
169   /// evaluate it in the pattern with a C++ function that takes a
170   /// MachineOperand& and an APInt& with:
171   ///   (match [{MIR %root = G_ADD %0, %A }],
172   ///             (constantfold operand:$A, apint:$constant))
173   /// and finally use it in the apply stage with:
174   ///   (apply (create_operand
175   ///                [{ MachineOperand::CreateImm(${constant}.getZExtValue());
176   ///                ]}, apint:$constant),
177   ///             [{MIR %root = FOO %0, %constant }])
178   std::vector<MatchDataInfo> MatchDataDecls;
179 
180   void declareMatchData(StringRef PatternSymbol, StringRef Type,
181                         StringRef VarName);
182 
183   bool parseInstructionMatcher(const CodeGenTarget &Target, StringInit *ArgName,
184                                const Init &Arg,
185                                StringMap<std::vector<VarInfo>> &NamedEdgeDefs,
186                                StringMap<std::vector<VarInfo>> &NamedEdgeUses);
187   bool parseWipMatchOpcodeMatcher(const CodeGenTarget &Target,
188                                   StringInit *ArgName, const Init &Arg);
189 
190 public:
CombineRule(const CodeGenTarget & Target,GIMatchDagContext & Ctx,RuleID ID,const Record & R)191   CombineRule(const CodeGenTarget &Target, GIMatchDagContext &Ctx, RuleID ID,
192               const Record &R)
193       : ID(ID), TheDef(R), MatchDag(Ctx) {}
194   CombineRule(const CombineRule &) = delete;
195 
196   bool parseDefs();
197   bool parseMatcher(const CodeGenTarget &Target);
198 
getID() const199   RuleID getID() const { return ID; }
allocUID()200   unsigned allocUID() { return UID++; }
getName() const201   StringRef getName() const { return TheDef.getName(); }
getDef() const202   const Record &getDef() const { return TheDef; }
getMatchingFixupCode() const203   const StringInit *getMatchingFixupCode() const { return MatchingFixupCode; }
getNumRoots() const204   size_t getNumRoots() const { return Roots.size(); }
205 
getMatchDag()206   GIMatchDag &getMatchDag() { return MatchDag; }
getMatchDag() const207   const GIMatchDag &getMatchDag() const { return MatchDag; }
208 
209   using const_root_iterator = std::vector<RootInfo>::const_iterator;
roots_begin() const210   const_root_iterator roots_begin() const { return Roots.begin(); }
roots_end() const211   const_root_iterator roots_end() const { return Roots.end(); }
roots() const212   iterator_range<const_root_iterator> roots() const {
213     return llvm::make_range(Roots.begin(), Roots.end());
214   }
215 
matchdata_decls() const216   iterator_range<const_matchdata_iterator> matchdata_decls() const {
217     return make_range(MatchDataDecls.begin(), MatchDataDecls.end());
218   }
219 
220   /// Export expansions for this rule
declareExpansions(CodeExpansions & Expansions) const221   void declareExpansions(CodeExpansions &Expansions) const {
222     for (const auto &I : matchdata_decls())
223       Expansions.declare(I.getPatternSymbol(), I.getVariableName());
224   }
225 
226   /// The matcher will begin from the roots and will perform the match by
227   /// traversing the edges to cover the whole DAG. This function reverses DAG
228   /// edges such that everything is reachable from a root. This is part of the
229   /// preparation work for flattening the DAG into a tree.
reorientToRoots()230   void reorientToRoots() {
231     SmallSet<const GIMatchDagInstr *, 5> Roots;
232     SmallSet<const GIMatchDagInstr *, 5> Visited;
233     SmallSet<GIMatchDagEdge *, 20> EdgesRemaining;
234 
235     for (auto &I : MatchDag.roots()) {
236       Roots.insert(I);
237       Visited.insert(I);
238     }
239     for (auto &I : MatchDag.edges())
240       EdgesRemaining.insert(I);
241 
242     bool Progressed = false;
243     SmallSet<GIMatchDagEdge *, 20> EdgesToRemove;
244     while (!EdgesRemaining.empty()) {
245       for (auto *EI : EdgesRemaining) {
246         if (Visited.count(EI->getFromMI())) {
247           if (Roots.count(EI->getToMI()))
248             PrintError(TheDef.getLoc(), "One or more roots are unnecessary");
249           Visited.insert(EI->getToMI());
250           EdgesToRemove.insert(EI);
251           Progressed = true;
252         }
253       }
254       for (GIMatchDagEdge *ToRemove : EdgesToRemove)
255         EdgesRemaining.erase(ToRemove);
256       EdgesToRemove.clear();
257 
258       for (auto EI = EdgesRemaining.begin(), EE = EdgesRemaining.end();
259            EI != EE; ++EI) {
260         if (Visited.count((*EI)->getToMI())) {
261           (*EI)->reverse();
262           Visited.insert((*EI)->getToMI());
263           EdgesToRemove.insert(*EI);
264           Progressed = true;
265         }
266         for (GIMatchDagEdge *ToRemove : EdgesToRemove)
267           EdgesRemaining.erase(ToRemove);
268         EdgesToRemove.clear();
269       }
270 
271       if (!Progressed) {
272         LLVM_DEBUG(dbgs() << "No progress\n");
273         return;
274       }
275       Progressed = false;
276     }
277   }
278 };
279 
280 /// A convenience function to check that an Init refers to a specific def. This
281 /// is primarily useful for testing for defs and similar in DagInit's since
282 /// DagInit's support any type inside them.
isSpecificDef(const Init & N,StringRef Def)283 static bool isSpecificDef(const Init &N, StringRef Def) {
284   if (const DefInit *OpI = dyn_cast<DefInit>(&N))
285     if (OpI->getDef()->getName() == Def)
286       return true;
287   return false;
288 }
289 
290 /// A convenience function to check that an Init refers to a def that is a
291 /// subclass of the given class and coerce it to a def if it is. This is
292 /// primarily useful for testing for subclasses of GIMatchKind and similar in
293 /// DagInit's since DagInit's support any type inside them.
getDefOfSubClass(const Init & N,StringRef Cls)294 static Record *getDefOfSubClass(const Init &N, StringRef Cls) {
295   if (const DefInit *OpI = dyn_cast<DefInit>(&N))
296     if (OpI->getDef()->isSubClassOf(Cls))
297       return OpI->getDef();
298   return nullptr;
299 }
300 
301 /// A convenience function to check that an Init refers to a dag whose operator
302 /// is a specific def and coerce it to a dag if it is. This is primarily useful
303 /// for testing for subclasses of GIMatchKind and similar in DagInit's since
304 /// DagInit's support any type inside them.
getDagWithSpecificOperator(const Init & N,StringRef Name)305 static const DagInit *getDagWithSpecificOperator(const Init &N,
306                                                  StringRef Name) {
307   if (const DagInit *I = dyn_cast<DagInit>(&N))
308     if (I->getNumArgs() > 0)
309       if (const DefInit *OpI = dyn_cast<DefInit>(I->getOperator()))
310         if (OpI->getDef()->getName() == Name)
311           return I;
312   return nullptr;
313 }
314 
315 /// A convenience function to check that an Init refers to a dag whose operator
316 /// is a def that is a subclass of the given class and coerce it to a dag if it
317 /// is. This is primarily useful for testing for subclasses of GIMatchKind and
318 /// similar in DagInit's since DagInit's support any type inside them.
getDagWithOperatorOfSubClass(const Init & N,StringRef Cls)319 static const DagInit *getDagWithOperatorOfSubClass(const Init &N,
320                                                    StringRef Cls) {
321   if (const DagInit *I = dyn_cast<DagInit>(&N))
322     if (I->getNumArgs() > 0)
323       if (const DefInit *OpI = dyn_cast<DefInit>(I->getOperator()))
324         if (OpI->getDef()->isSubClassOf(Cls))
325           return I;
326   return nullptr;
327 }
328 
makeNameForAnonInstr(CombineRule & Rule)329 StringRef makeNameForAnonInstr(CombineRule &Rule) {
330   return insertStrTab(to_string(
331       format("__anon%" PRIu64 "_%u", Rule.getID(), Rule.allocUID())));
332 }
333 
makeDebugName(CombineRule & Rule,StringRef Name)334 StringRef makeDebugName(CombineRule &Rule, StringRef Name) {
335   return insertStrTab(Name.empty() ? makeNameForAnonInstr(Rule) : StringRef(Name));
336 }
337 
makeNameForAnonPredicate(CombineRule & Rule)338 StringRef makeNameForAnonPredicate(CombineRule &Rule) {
339   return insertStrTab(to_string(
340       format("__anonpred%" PRIu64 "_%u", Rule.getID(), Rule.allocUID())));
341 }
342 
declareMatchData(StringRef PatternSymbol,StringRef Type,StringRef VarName)343 void CombineRule::declareMatchData(StringRef PatternSymbol, StringRef Type,
344                                    StringRef VarName) {
345   MatchDataDecls.emplace_back(PatternSymbol, Type, VarName);
346 }
347 
parseDefs()348 bool CombineRule::parseDefs() {
349   DagInit *Defs = TheDef.getValueAsDag("Defs");
350 
351   if (Defs->getOperatorAsDef(TheDef.getLoc())->getName() != "defs") {
352     PrintError(TheDef.getLoc(), "Expected defs operator");
353     return false;
354   }
355 
356   for (unsigned I = 0, E = Defs->getNumArgs(); I < E; ++I) {
357     // Roots should be collected into Roots
358     if (isSpecificDef(*Defs->getArg(I), "root")) {
359       Roots.emplace_back(Defs->getArgNameStr(I));
360       continue;
361     }
362 
363     // Subclasses of GIDefMatchData should declare that this rule needs to pass
364     // data from the match stage to the apply stage, and ensure that the
365     // generated matcher has a suitable variable for it to do so.
366     if (Record *MatchDataRec =
367             getDefOfSubClass(*Defs->getArg(I), "GIDefMatchData")) {
368       declareMatchData(Defs->getArgNameStr(I),
369                        MatchDataRec->getValueAsString("Type"),
370                        llvm::to_string(llvm::format("MatchData%" PRIu64, ID)));
371       continue;
372     }
373 
374     // Otherwise emit an appropriate error message.
375     if (getDefOfSubClass(*Defs->getArg(I), "GIDefKind"))
376       PrintError(TheDef.getLoc(),
377                  "This GIDefKind not implemented in tablegen");
378     else if (getDefOfSubClass(*Defs->getArg(I), "GIDefKindWithArgs"))
379       PrintError(TheDef.getLoc(),
380                  "This GIDefKindWithArgs not implemented in tablegen");
381     else
382       PrintError(TheDef.getLoc(),
383                  "Expected a subclass of GIDefKind or a sub-dag whose "
384                  "operator is of type GIDefKindWithArgs");
385     return false;
386   }
387 
388   if (Roots.empty()) {
389     PrintError(TheDef.getLoc(), "Combine rules must have at least one root");
390     return false;
391   }
392   return true;
393 }
394 
395 // Parse an (Instruction $a:Arg1, $b:Arg2, ...) matcher. Edges are formed
396 // between matching operand names between different matchers.
parseInstructionMatcher(const CodeGenTarget & Target,StringInit * ArgName,const Init & Arg,StringMap<std::vector<VarInfo>> & NamedEdgeDefs,StringMap<std::vector<VarInfo>> & NamedEdgeUses)397 bool CombineRule::parseInstructionMatcher(
398     const CodeGenTarget &Target, StringInit *ArgName, const Init &Arg,
399     StringMap<std::vector<VarInfo>> &NamedEdgeDefs,
400     StringMap<std::vector<VarInfo>> &NamedEdgeUses) {
401   if (const DagInit *Matcher =
402           getDagWithOperatorOfSubClass(Arg, "Instruction")) {
403     auto &Instr =
404         Target.getInstruction(Matcher->getOperatorAsDef(TheDef.getLoc()));
405 
406     StringRef Name = ArgName ? ArgName->getValue() : "";
407 
408     GIMatchDagInstr *N =
409         MatchDag.addInstrNode(makeDebugName(*this, Name), insertStrTab(Name),
410                               MatchDag.getContext().makeOperandList(Instr));
411 
412     N->setOpcodeAnnotation(&Instr);
413     const auto &P = MatchDag.addPredicateNode<GIMatchDagOpcodePredicate>(
414         makeNameForAnonPredicate(*this), Instr);
415     MatchDag.addPredicateDependency(N, nullptr, P, &P->getOperandInfo()["mi"]);
416     unsigned OpIdx = 0;
417     for (const auto &NameInit : Matcher->getArgNames()) {
418       StringRef Name = insertStrTab(NameInit->getAsUnquotedString());
419       if (Name.empty())
420         continue;
421       N->assignNameToOperand(OpIdx, Name);
422 
423       // Record the endpoints of any named edges. We'll add the cartesian
424       // product of edges later.
425       const auto &InstrOperand = N->getOperandInfo()[OpIdx];
426       if (InstrOperand.isDef()) {
427         NamedEdgeDefs.try_emplace(Name);
428         NamedEdgeDefs[Name].emplace_back(N, &InstrOperand, Matcher);
429       } else {
430         NamedEdgeUses.try_emplace(Name);
431         NamedEdgeUses[Name].emplace_back(N, &InstrOperand, Matcher);
432       }
433 
434       if (InstrOperand.isDef()) {
435         if (any_of(Roots, [&](const RootInfo &X) {
436               return X.getPatternSymbol() == Name;
437             })) {
438           N->setMatchRoot();
439         }
440       }
441 
442       OpIdx++;
443     }
444 
445     return true;
446   }
447   return false;
448 }
449 
450 // Parse the wip_match_opcode placeholder that's temporarily present in lieu of
451 // implementing macros or choices between two matchers.
parseWipMatchOpcodeMatcher(const CodeGenTarget & Target,StringInit * ArgName,const Init & Arg)452 bool CombineRule::parseWipMatchOpcodeMatcher(const CodeGenTarget &Target,
453                                              StringInit *ArgName,
454                                              const Init &Arg) {
455   if (const DagInit *Matcher =
456           getDagWithSpecificOperator(Arg, "wip_match_opcode")) {
457     StringRef Name = ArgName ? ArgName->getValue() : "";
458 
459     GIMatchDagInstr *N =
460         MatchDag.addInstrNode(makeDebugName(*this, Name), insertStrTab(Name),
461                               MatchDag.getContext().makeEmptyOperandList());
462 
463     if (any_of(Roots, [&](const RootInfo &X) {
464           return ArgName && X.getPatternSymbol() == ArgName->getValue();
465         })) {
466       N->setMatchRoot();
467     }
468 
469     const auto &P = MatchDag.addPredicateNode<GIMatchDagOneOfOpcodesPredicate>(
470         makeNameForAnonPredicate(*this));
471     MatchDag.addPredicateDependency(N, nullptr, P, &P->getOperandInfo()["mi"]);
472     // Each argument is an opcode that will pass this predicate. Add them all to
473     // the predicate implementation
474     for (const auto &Arg : Matcher->getArgs()) {
475       Record *OpcodeDef = getDefOfSubClass(*Arg, "Instruction");
476       if (OpcodeDef) {
477         P->addOpcode(&Target.getInstruction(OpcodeDef));
478         continue;
479       }
480       PrintError(TheDef.getLoc(),
481                  "Arguments to wip_match_opcode must be instructions");
482       return false;
483     }
484     return true;
485   }
486   return false;
487 }
parseMatcher(const CodeGenTarget & Target)488 bool CombineRule::parseMatcher(const CodeGenTarget &Target) {
489   StringMap<std::vector<VarInfo>> NamedEdgeDefs;
490   StringMap<std::vector<VarInfo>> NamedEdgeUses;
491   DagInit *Matchers = TheDef.getValueAsDag("Match");
492 
493   if (Matchers->getOperatorAsDef(TheDef.getLoc())->getName() != "match") {
494     PrintError(TheDef.getLoc(), "Expected match operator");
495     return false;
496   }
497 
498   if (Matchers->getNumArgs() == 0) {
499     PrintError(TheDef.getLoc(), "Matcher is empty");
500     return false;
501   }
502 
503   // The match section consists of a list of matchers and predicates. Parse each
504   // one and add the equivalent GIMatchDag nodes, predicates, and edges.
505   for (unsigned I = 0; I < Matchers->getNumArgs(); ++I) {
506     if (parseInstructionMatcher(Target, Matchers->getArgName(I),
507                                 *Matchers->getArg(I), NamedEdgeDefs,
508                                 NamedEdgeUses))
509       continue;
510 
511     if (parseWipMatchOpcodeMatcher(Target, Matchers->getArgName(I),
512                                    *Matchers->getArg(I)))
513       continue;
514 
515 
516     // Parse arbitrary C++ code we have in lieu of supporting MIR matching
517     if (const StringInit *StringI = dyn_cast<StringInit>(Matchers->getArg(I))) {
518       assert(!MatchingFixupCode &&
519              "Only one block of arbitrary code is currently permitted");
520       MatchingFixupCode = StringI;
521       MatchDag.setHasPostMatchPredicate(true);
522       continue;
523     }
524 
525     PrintError(TheDef.getLoc(),
526                "Expected a subclass of GIMatchKind or a sub-dag whose "
527                "operator is either of a GIMatchKindWithArgs or Instruction");
528     PrintNote("Pattern was `" + Matchers->getArg(I)->getAsString() + "'");
529     return false;
530   }
531 
532   // Add the cartesian product of use -> def edges.
533   bool FailedToAddEdges = false;
534   for (const auto &NameAndDefs : NamedEdgeDefs) {
535     if (NameAndDefs.getValue().size() > 1) {
536       PrintError(TheDef.getLoc(),
537                  "Two different MachineInstrs cannot def the same vreg");
538       for (const auto &NameAndDefOp : NameAndDefs.getValue())
539         PrintNote("in " + to_string(*NameAndDefOp.N) + " created from " +
540                   to_string(*NameAndDefOp.Matcher) + "");
541       FailedToAddEdges = true;
542     }
543     const auto &Uses = NamedEdgeUses[NameAndDefs.getKey()];
544     for (const VarInfo &DefVar : NameAndDefs.getValue()) {
545       for (const VarInfo &UseVar : Uses) {
546         MatchDag.addEdge(insertStrTab(NameAndDefs.getKey()), UseVar.N, UseVar.Op,
547                          DefVar.N, DefVar.Op);
548       }
549     }
550   }
551   if (FailedToAddEdges)
552     return false;
553 
554   // If a variable is referenced in multiple use contexts then we need a
555   // predicate to confirm they are the same operand. We can elide this if it's
556   // also referenced in a def context and we're traversing the def-use chain
557   // from the def to the uses but we can't know which direction we're going
558   // until after reorientToRoots().
559   for (const auto &NameAndUses : NamedEdgeUses) {
560     const auto &Uses = NameAndUses.getValue();
561     if (Uses.size() > 1) {
562       const auto &LeadingVar = Uses.front();
563       for (const auto &Var : ArrayRef<VarInfo>(Uses).drop_front()) {
564         // Add a predicate for each pair until we've covered the whole
565         // equivalence set. We could test the whole set in a single predicate
566         // but that means we can't test any equivalence until all the MO's are
567         // available which can lead to wasted work matching the DAG when this
568         // predicate can already be seen to have failed.
569         //
570         // We have a similar problem due to the need to wait for a particular MO
571         // before being able to test any of them. However, that is mitigated by
572         // the order in which we build the DAG. We build from the roots outwards
573         // so by using the first recorded use in all the predicates, we are
574         // making the dependency on one of the earliest visited references in
575         // the DAG. It's not guaranteed once the generated matcher is optimized
576         // (because the factoring the common portions of rules might change the
577         // visit order) but this should mean that these predicates depend on the
578         // first MO to become available.
579         const auto &P = MatchDag.addPredicateNode<GIMatchDagSameMOPredicate>(
580             makeNameForAnonPredicate(*this));
581         MatchDag.addPredicateDependency(LeadingVar.N, LeadingVar.Op, P,
582                                         &P->getOperandInfo()["mi0"]);
583         MatchDag.addPredicateDependency(Var.N, Var.Op, P,
584                                         &P->getOperandInfo()["mi1"]);
585       }
586     }
587   }
588   return true;
589 }
590 
591 class GICombinerEmitter {
592   RecordKeeper &Records;
593   StringRef Name;
594   const CodeGenTarget &Target;
595   Record *Combiner;
596   std::vector<std::unique_ptr<CombineRule>> Rules;
597   GIMatchDagContext MatchDagCtx;
598 
599   std::unique_ptr<CombineRule> makeCombineRule(const Record &R);
600 
601   void gatherRules(std::vector<std::unique_ptr<CombineRule>> &ActiveRules,
602                    const std::vector<Record *> &&RulesAndGroups);
603 
604 public:
605   explicit GICombinerEmitter(RecordKeeper &RK, const CodeGenTarget &Target,
606                              StringRef Name, Record *Combiner);
~GICombinerEmitter()607   ~GICombinerEmitter() {}
608 
getClassName() const609   StringRef getClassName() const {
610     return Combiner->getValueAsString("Classname");
611   }
612   void run(raw_ostream &OS);
613 
614   /// Emit the name matcher (guarded by #ifndef NDEBUG) used to disable rules in
615   /// response to the generated cl::opt.
616   void emitNameMatcher(raw_ostream &OS) const;
617 
618   void generateCodeForTree(raw_ostream &OS, const GIMatchTree &Tree,
619                            StringRef Indent) const;
620 };
621 
GICombinerEmitter(RecordKeeper & RK,const CodeGenTarget & Target,StringRef Name,Record * Combiner)622 GICombinerEmitter::GICombinerEmitter(RecordKeeper &RK,
623                                      const CodeGenTarget &Target,
624                                      StringRef Name, Record *Combiner)
625     : Records(RK), Name(Name), Target(Target), Combiner(Combiner) {}
626 
emitNameMatcher(raw_ostream & OS) const627 void GICombinerEmitter::emitNameMatcher(raw_ostream &OS) const {
628   std::vector<std::pair<std::string, std::string>> Cases;
629   Cases.reserve(Rules.size());
630 
631   for (const CombineRule &EnumeratedRule : make_pointee_range(Rules)) {
632     std::string Code;
633     raw_string_ostream SS(Code);
634     SS << "return " << EnumeratedRule.getID() << ";\n";
635     Cases.push_back(
636         std::make_pair(std::string(EnumeratedRule.getName()), SS.str()));
637   }
638 
639   OS << "static Optional<uint64_t> getRuleIdxForIdentifier(StringRef "
640         "RuleIdentifier) {\n"
641      << "  uint64_t I;\n"
642      << "  // getAtInteger(...) returns false on success\n"
643      << "  bool Parsed = !RuleIdentifier.getAsInteger(0, I);\n"
644      << "  if (Parsed)\n"
645      << "    return I;\n\n"
646      << "#ifndef NDEBUG\n";
647   StringMatcher Matcher("RuleIdentifier", Cases, OS);
648   Matcher.Emit();
649   OS << "#endif // ifndef NDEBUG\n\n"
650      << "  return None;\n"
651      << "}\n";
652 }
653 
654 std::unique_ptr<CombineRule>
makeCombineRule(const Record & TheDef)655 GICombinerEmitter::makeCombineRule(const Record &TheDef) {
656   std::unique_ptr<CombineRule> Rule =
657       std::make_unique<CombineRule>(Target, MatchDagCtx, NumPatternTotal, TheDef);
658 
659   if (!Rule->parseDefs())
660     return nullptr;
661   if (!Rule->parseMatcher(Target))
662     return nullptr;
663 
664   Rule->reorientToRoots();
665 
666   LLVM_DEBUG({
667     dbgs() << "Parsed rule defs/match for '" << Rule->getName() << "'\n";
668     Rule->getMatchDag().dump();
669     Rule->getMatchDag().writeDOTGraph(dbgs(), Rule->getName());
670   });
671   if (StopAfterParse)
672     return Rule;
673 
674   // For now, don't support traversing from def to use. We'll come back to
675   // this later once we have the algorithm changes to support it.
676   bool EmittedDefToUseError = false;
677   for (const auto &E : Rule->getMatchDag().edges()) {
678     if (E->isDefToUse()) {
679       if (!EmittedDefToUseError) {
680         PrintError(
681             TheDef.getLoc(),
682             "Generated state machine cannot lookup uses from a def (yet)");
683         EmittedDefToUseError = true;
684       }
685       PrintNote("Node " + to_string(*E->getFromMI()));
686       PrintNote("Node " + to_string(*E->getToMI()));
687       PrintNote("Edge " + to_string(*E));
688     }
689   }
690   if (EmittedDefToUseError)
691     return nullptr;
692 
693   // For now, don't support multi-root rules. We'll come back to this later
694   // once we have the algorithm changes to support it.
695   if (Rule->getNumRoots() > 1) {
696     PrintError(TheDef.getLoc(), "Multi-root matches are not supported (yet)");
697     return nullptr;
698   }
699   return Rule;
700 }
701 
702 /// Recurse into GICombineGroup's and flatten the ruleset into a simple list.
gatherRules(std::vector<std::unique_ptr<CombineRule>> & ActiveRules,const std::vector<Record * > && RulesAndGroups)703 void GICombinerEmitter::gatherRules(
704     std::vector<std::unique_ptr<CombineRule>> &ActiveRules,
705     const std::vector<Record *> &&RulesAndGroups) {
706   for (Record *R : RulesAndGroups) {
707     if (R->isValueUnset("Rules")) {
708       std::unique_ptr<CombineRule> Rule = makeCombineRule(*R);
709       if (Rule == nullptr) {
710         PrintError(R->getLoc(), "Failed to parse rule");
711         continue;
712       }
713       ActiveRules.emplace_back(std::move(Rule));
714       ++NumPatternTotal;
715     } else
716       gatherRules(ActiveRules, R->getValueAsListOfDefs("Rules"));
717   }
718 }
719 
generateCodeForTree(raw_ostream & OS,const GIMatchTree & Tree,StringRef Indent) const720 void GICombinerEmitter::generateCodeForTree(raw_ostream &OS,
721                                             const GIMatchTree &Tree,
722                                             StringRef Indent) const {
723   if (Tree.getPartitioner() != nullptr) {
724     Tree.getPartitioner()->generatePartitionSelectorCode(OS, Indent);
725     for (const auto &EnumChildren : enumerate(Tree.children())) {
726       OS << Indent << "if (Partition == " << EnumChildren.index() << " /* "
727          << format_partition_name(Tree, EnumChildren.index()) << " */) {\n";
728       generateCodeForTree(OS, EnumChildren.value(), (Indent + "  ").str());
729       OS << Indent << "}\n";
730     }
731     return;
732   }
733 
734   bool AnyFullyTested = false;
735   for (const auto &Leaf : Tree.possible_leaves()) {
736     OS << Indent << "// Leaf name: " << Leaf.getName() << "\n";
737 
738     const CombineRule *Rule = Leaf.getTargetData<CombineRule>();
739     const Record &RuleDef = Rule->getDef();
740 
741     OS << Indent << "// Rule: " << RuleDef.getName() << "\n"
742        << Indent << "if (!RuleConfig->isRuleDisabled(" << Rule->getID()
743        << ")) {\n";
744 
745     CodeExpansions Expansions;
746     for (const auto &VarBinding : Leaf.var_bindings()) {
747       if (VarBinding.isInstr())
748         Expansions.declare(VarBinding.getName(),
749                            "MIs[" + to_string(VarBinding.getInstrID()) + "]");
750       else
751         Expansions.declare(VarBinding.getName(),
752                            "MIs[" + to_string(VarBinding.getInstrID()) +
753                                "]->getOperand(" +
754                                to_string(VarBinding.getOpIdx()) + ")");
755     }
756     Rule->declareExpansions(Expansions);
757 
758     DagInit *Applyer = RuleDef.getValueAsDag("Apply");
759     if (Applyer->getOperatorAsDef(RuleDef.getLoc())->getName() !=
760         "apply") {
761       PrintError(RuleDef.getLoc(), "Expected 'apply' operator in Apply DAG");
762       return;
763     }
764 
765     OS << Indent << "  if (1\n";
766 
767     // Attempt to emit code for any untested predicates left over. Note that
768     // isFullyTested() will remain false even if we succeed here and therefore
769     // combine rule elision will not be performed. This is because we do not
770     // know if there's any connection between the predicates for each leaf and
771     // therefore can't tell if one makes another unreachable. Ideally, the
772     // partitioner(s) would be sufficiently complete to prevent us from having
773     // untested predicates left over.
774     for (const GIMatchDagPredicate *Predicate : Leaf.untested_predicates()) {
775       if (Predicate->generateCheckCode(OS, (Indent + "      ").str(),
776                                        Expansions))
777         continue;
778       PrintError(RuleDef.getLoc(),
779                  "Unable to test predicate used in rule");
780       PrintNote(SMLoc(),
781                 "This indicates an incomplete implementation in tablegen");
782       Predicate->print(errs());
783       errs() << "\n";
784       OS << Indent
785          << "llvm_unreachable(\"TableGen did not emit complete code for this "
786             "path\");\n";
787       break;
788     }
789 
790     if (Rule->getMatchingFixupCode() &&
791         !Rule->getMatchingFixupCode()->getValue().empty()) {
792       // FIXME: Single-use lambda's like this are a serious compile-time
793       // performance and memory issue. It's convenient for this early stage to
794       // defer some work to successive patches but we need to eliminate this
795       // before the ruleset grows to small-moderate size. Last time, it became
796       // a big problem for low-mem systems around the 500 rule mark but by the
797       // time we grow that large we should have merged the ISel match table
798       // mechanism with the Combiner.
799       OS << Indent << "      && [&]() {\n"
800          << Indent << "      "
801          << CodeExpander(Rule->getMatchingFixupCode()->getValue(), Expansions,
802                          RuleDef.getLoc(), ShowExpansions)
803          << "\n"
804          << Indent << "      return true;\n"
805          << Indent << "  }()";
806     }
807     OS << ") {\n" << Indent << "   ";
808 
809     if (const StringInit *Code = dyn_cast<StringInit>(Applyer->getArg(0))) {
810       OS << CodeExpander(Code->getAsUnquotedString(), Expansions,
811                          RuleDef.getLoc(), ShowExpansions)
812          << "\n"
813          << Indent << "    return true;\n"
814          << Indent << "  }\n";
815     } else {
816       PrintError(RuleDef.getLoc(), "Expected apply code block");
817       return;
818     }
819 
820     OS << Indent << "}\n";
821 
822     assert(Leaf.isFullyTraversed());
823 
824     // If we didn't have any predicates left over and we're not using the
825     // trap-door we have to support arbitrary C++ code while we're migrating to
826     // the declarative style then we know that subsequent leaves are
827     // unreachable.
828     if (Leaf.isFullyTested() &&
829         (!Rule->getMatchingFixupCode() ||
830          Rule->getMatchingFixupCode()->getValue().empty())) {
831       AnyFullyTested = true;
832       OS << Indent
833          << "llvm_unreachable(\"Combine rule elision was incorrect\");\n"
834          << Indent << "return false;\n";
835     }
836   }
837   if (!AnyFullyTested)
838     OS << Indent << "return false;\n";
839 }
840 
emitAdditionalHelperMethodArguments(raw_ostream & OS,Record * Combiner)841 static void emitAdditionalHelperMethodArguments(raw_ostream &OS,
842                                                 Record *Combiner) {
843   for (Record *Arg : Combiner->getValueAsListOfDefs("AdditionalArguments"))
844     OS << ",\n    " << Arg->getValueAsString("Type")
845        << Arg->getValueAsString("Name");
846 }
847 
run(raw_ostream & OS)848 void GICombinerEmitter::run(raw_ostream &OS) {
849   Records.startTimer("Gather rules");
850   gatherRules(Rules, Combiner->getValueAsListOfDefs("Rules"));
851   if (StopAfterParse) {
852     MatchDagCtx.print(errs());
853     PrintNote(Combiner->getLoc(),
854               "Terminating due to -gicombiner-stop-after-parse");
855     return;
856   }
857   if (ErrorsPrinted)
858     PrintFatalError(Combiner->getLoc(), "Failed to parse one or more rules");
859   LLVM_DEBUG(dbgs() << "Optimizing tree for " << Rules.size() << " rules\n");
860   std::unique_ptr<GIMatchTree> Tree;
861   Records.startTimer("Optimize combiner");
862   {
863     GIMatchTreeBuilder TreeBuilder(0);
864     for (const auto &Rule : Rules) {
865       bool HadARoot = false;
866       for (const auto &Root : enumerate(Rule->getMatchDag().roots())) {
867         TreeBuilder.addLeaf(Rule->getName(), Root.index(), Rule->getMatchDag(),
868                             Rule.get());
869         HadARoot = true;
870       }
871       if (!HadARoot)
872         PrintFatalError(Rule->getDef().getLoc(), "All rules must have a root");
873     }
874 
875     Tree = TreeBuilder.run();
876   }
877   if (StopAfterBuild) {
878     Tree->writeDOTGraph(outs());
879     PrintNote(Combiner->getLoc(),
880               "Terminating due to -gicombiner-stop-after-build");
881     return;
882   }
883 
884   Records.startTimer("Emit combiner");
885   OS << "#ifdef " << Name.upper() << "_GENCOMBINERHELPER_DEPS\n"
886      << "#include \"llvm/ADT/SparseBitVector.h\"\n"
887      << "namespace llvm {\n"
888      << "extern cl::OptionCategory GICombinerOptionCategory;\n"
889      << "} // end namespace llvm\n"
890      << "#endif // ifdef " << Name.upper() << "_GENCOMBINERHELPER_DEPS\n\n";
891 
892   OS << "#ifdef " << Name.upper() << "_GENCOMBINERHELPER_H\n"
893      << "class " << getClassName() << "RuleConfig {\n"
894      << "  SparseBitVector<> DisabledRules;\n"
895      << "\n"
896      << "public:\n"
897      << "  bool parseCommandLineOption();\n"
898      << "  bool isRuleDisabled(unsigned ID) const;\n"
899      << "  bool setRuleEnabled(StringRef RuleIdentifier);\n"
900      << "  bool setRuleDisabled(StringRef RuleIdentifier);\n"
901      << "};\n"
902      << "\n"
903      << "class " << getClassName();
904   StringRef StateClass = Combiner->getValueAsString("StateClass");
905   if (!StateClass.empty())
906     OS << " : public " << StateClass;
907   OS << " {\n"
908      << "  const " << getClassName() << "RuleConfig *RuleConfig;\n"
909      << "\n"
910      << "public:\n"
911      << "  template <typename... Args>" << getClassName() << "(const "
912      << getClassName() << "RuleConfig &RuleConfig, Args &&... args) : ";
913   if (!StateClass.empty())
914     OS << StateClass << "(std::forward<Args>(args)...), ";
915   OS << "RuleConfig(&RuleConfig) {}\n"
916      << "\n"
917      << "  bool tryCombineAll(\n"
918      << "    GISelChangeObserver &Observer,\n"
919      << "    MachineInstr &MI,\n"
920      << "    MachineIRBuilder &B";
921   emitAdditionalHelperMethodArguments(OS, Combiner);
922   OS << ") const;\n";
923   OS << "};\n\n";
924 
925   emitNameMatcher(OS);
926 
927   OS << "static Optional<std::pair<uint64_t, uint64_t>> "
928         "getRuleRangeForIdentifier(StringRef RuleIdentifier) {\n"
929      << "  std::pair<StringRef, StringRef> RangePair = "
930         "RuleIdentifier.split('-');\n"
931      << "  if (!RangePair.second.empty()) {\n"
932      << "    const auto First = "
933         "getRuleIdxForIdentifier(RangePair.first);\n"
934      << "    const auto Last = "
935         "getRuleIdxForIdentifier(RangePair.second);\n"
936      << "    if (!First.hasValue() || !Last.hasValue())\n"
937      << "      return None;\n"
938      << "    if (First >= Last)\n"
939      << "      report_fatal_error(\"Beginning of range should be before "
940         "end of range\");\n"
941      << "    return {{*First, *Last + 1}};\n"
942      << "  } else if (RangePair.first == \"*\") {\n"
943      << "    return {{0, " << Rules.size() << "}};\n"
944      << "  } else {\n"
945      << "    const auto I = getRuleIdxForIdentifier(RangePair.first);\n"
946      << "    if (!I.hasValue())\n"
947      << "      return None;\n"
948      << "    return {{*I, *I + 1}};\n"
949      << "  }\n"
950      << "  return None;\n"
951      << "}\n\n";
952 
953   for (bool Enabled : {true, false}) {
954     OS << "bool " << getClassName() << "RuleConfig::setRule"
955        << (Enabled ? "Enabled" : "Disabled") << "(StringRef RuleIdentifier) {\n"
956        << "  auto MaybeRange = getRuleRangeForIdentifier(RuleIdentifier);\n"
957        << "  if (!MaybeRange.hasValue())\n"
958        << "    return false;\n"
959        << "  for (auto I = MaybeRange->first; I < MaybeRange->second; ++I)\n"
960        << "    DisabledRules." << (Enabled ? "reset" : "set") << "(I);\n"
961        << "  return true;\n"
962        << "}\n\n";
963   }
964 
965   OS << "bool " << getClassName()
966      << "RuleConfig::isRuleDisabled(unsigned RuleID) const {\n"
967      << "  return DisabledRules.test(RuleID);\n"
968      << "}\n";
969   OS << "#endif // ifdef " << Name.upper() << "_GENCOMBINERHELPER_H\n\n";
970 
971   OS << "#ifdef " << Name.upper() << "_GENCOMBINERHELPER_CPP\n"
972      << "\n"
973      << "std::vector<std::string> " << Name << "Option;\n"
974      << "cl::list<std::string> " << Name << "DisableOption(\n"
975      << "    \"" << Name.lower() << "-disable-rule\",\n"
976      << "    cl::desc(\"Disable one or more combiner rules temporarily in "
977      << "the " << Name << " pass\"),\n"
978      << "    cl::CommaSeparated,\n"
979      << "    cl::Hidden,\n"
980      << "    cl::cat(GICombinerOptionCategory),\n"
981      << "    cl::callback([](const std::string &Str) {\n"
982      << "      " << Name << "Option.push_back(Str);\n"
983      << "    }));\n"
984      << "cl::list<std::string> " << Name << "OnlyEnableOption(\n"
985      << "    \"" << Name.lower() << "-only-enable-rule\",\n"
986      << "    cl::desc(\"Disable all rules in the " << Name
987      << " pass then re-enable the specified ones\"),\n"
988      << "    cl::Hidden,\n"
989      << "    cl::cat(GICombinerOptionCategory),\n"
990      << "    cl::callback([](const std::string &CommaSeparatedArg) {\n"
991      << "      StringRef Str = CommaSeparatedArg;\n"
992      << "      " << Name << "Option.push_back(\"*\");\n"
993      << "      do {\n"
994      << "        auto X = Str.split(\",\");\n"
995      << "        " << Name << "Option.push_back((\"!\" + X.first).str());\n"
996      << "        Str = X.second;\n"
997      << "      } while (!Str.empty());\n"
998      << "    }));\n"
999      << "\n"
1000      << "bool " << getClassName() << "RuleConfig::parseCommandLineOption() {\n"
1001      << "  for (StringRef Identifier : " << Name << "Option) {\n"
1002      << "    bool Enabled = Identifier.consume_front(\"!\");\n"
1003      << "    if (Enabled && !setRuleEnabled(Identifier))\n"
1004      << "      return false;\n"
1005      << "    if (!Enabled && !setRuleDisabled(Identifier))\n"
1006      << "      return false;\n"
1007      << "  }\n"
1008      << "  return true;\n"
1009      << "}\n\n";
1010 
1011   OS << "bool " << getClassName() << "::tryCombineAll(\n"
1012      << "    GISelChangeObserver &Observer,\n"
1013      << "    MachineInstr &MI,\n"
1014      << "    MachineIRBuilder &B";
1015   emitAdditionalHelperMethodArguments(OS, Combiner);
1016   OS << ") const {\n"
1017      << "  MachineBasicBlock *MBB = MI.getParent();\n"
1018      << "  MachineFunction *MF = MBB->getParent();\n"
1019      << "  MachineRegisterInfo &MRI = MF->getRegInfo();\n"
1020      << "  SmallVector<MachineInstr *, 8> MIs = {&MI};\n\n"
1021      << "  (void)MBB; (void)MF; (void)MRI; (void)RuleConfig;\n\n";
1022 
1023   OS << "  // Match data\n";
1024   for (const auto &Rule : Rules)
1025     for (const auto &I : Rule->matchdata_decls())
1026       OS << "  " << I.getType() << " " << I.getVariableName() << ";\n";
1027   OS << "\n";
1028 
1029   OS << "  int Partition = -1;\n";
1030   generateCodeForTree(OS, *Tree, "  ");
1031   OS << "\n  return false;\n"
1032      << "}\n"
1033      << "#endif // ifdef " << Name.upper() << "_GENCOMBINERHELPER_CPP\n";
1034 }
1035 
1036 } // end anonymous namespace
1037 
1038 //===----------------------------------------------------------------------===//
1039 
1040 namespace llvm {
EmitGICombiner(RecordKeeper & RK,raw_ostream & OS)1041 void EmitGICombiner(RecordKeeper &RK, raw_ostream &OS) {
1042   CodeGenTarget Target(RK);
1043   emitSourceFileHeader("Global Combiner", OS);
1044 
1045   if (SelectedCombiners.empty())
1046     PrintFatalError("No combiners selected with -combiners");
1047   for (const auto &Combiner : SelectedCombiners) {
1048     Record *CombinerDef = RK.getDef(Combiner);
1049     if (!CombinerDef)
1050       PrintFatalError("Could not find " + Combiner);
1051     GICombinerEmitter(RK, Target, Combiner, CombinerDef).run(OS);
1052   }
1053   NumPatternTotalStatistic = NumPatternTotal;
1054 }
1055 
1056 } // namespace llvm
1057