1 //===- GlobalISelCombinerMatchTableEmitter.cpp - --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file Generate a combiner implementation for GlobalISel from a declarative
10 /// syntax using GlobalISelMatchTable.
11 ///
/// Usually, TableGen backends use "assert is an error" as a means to report
/// invalid input. They try to diagnose common cases but don't try very hard,
/// and crashes can be common. This backend aims to behave closer to how a
/// language compiler frontend would behave: we try extra hard to diagnose
/// invalid inputs early, and any crash should be considered a bug (= a feature
/// or diagnostic is missing).
18 ///
19 /// While this can make the backend a bit more complex than it needs to be, it
20 /// pays off because MIR patterns can get complicated. Giving useful error
21 /// messages to combine writers can help boost their productivity.
22 ///
23 /// As with anything, a good balance has to be found. We also don't want to
24 /// write hundreds of lines of code to detect edge cases. In practice, crashing
25 /// very occasionally, or giving poor errors in some rare instances, is fine.
26 ///
27 //===----------------------------------------------------------------------===//
28 
29 #include "CodeGenInstruction.h"
30 #include "CodeGenTarget.h"
31 #include "GlobalISel/CXXPredicates.h"
32 #include "GlobalISel/CodeExpander.h"
33 #include "GlobalISel/CodeExpansions.h"
34 #include "GlobalISel/CombinerUtils.h"
35 #include "GlobalISel/MatchDataInfo.h"
36 #include "GlobalISel/Patterns.h"
37 #include "GlobalISelMatchTable.h"
38 #include "GlobalISelMatchTableExecutorEmitter.h"
39 #include "SubtargetFeatureInfo.h"
40 #include "llvm/ADT/APInt.h"
41 #include "llvm/ADT/EquivalenceClasses.h"
42 #include "llvm/ADT/Hashing.h"
43 #include "llvm/ADT/MapVector.h"
44 #include "llvm/ADT/SetVector.h"
45 #include "llvm/ADT/Statistic.h"
46 #include "llvm/ADT/StringSet.h"
47 #include "llvm/Support/CommandLine.h"
48 #include "llvm/Support/Debug.h"
49 #include "llvm/Support/PrettyStackTrace.h"
50 #include "llvm/Support/ScopedPrinter.h"
51 #include "llvm/TableGen/Error.h"
52 #include "llvm/TableGen/Record.h"
53 #include "llvm/TableGen/StringMatcher.h"
54 #include "llvm/TableGen/TableGenBackend.h"
55 #include <cstdint>
56 
57 using namespace llvm;
58 using namespace llvm::gi;
59 
60 #define DEBUG_TYPE "gicombiner-emitter"
61 
62 namespace {
/// Option category shared by every command-line option declared below.
cl::OptionCategory
    GICombinerEmitterCat("Options for -gen-global-isel-combiner");
/// -gicombiner-stop-after-parse: stop once the rules are parsed and dump the
/// parsed state.
cl::opt<bool> StopAfterParse(
    "gicombiner-stop-after-parse",
    cl::desc("Stop processing after parsing rules and dump state"),
    cl::cat(GICombinerEmitterCat));
/// -combiners=a,b,...: comma-separated list of the combiners to emit.
cl::list<std::string>
    SelectedCombiners("combiners", cl::desc("Emit the specified combiners"),
                      cl::cat(GICombinerEmitterCat), cl::CommaSeparated);
/// -gicombiner-debug-cxxpreds: add contextual/debug comments to all C++
/// predicates.
cl::opt<bool> DebugCXXPreds(
    "gicombiner-debug-cxxpreds",
    cl::desc("Add Contextual/Debug comments to all C++ predicates"),
    cl::cat(GICombinerEmitterCat));
/// -gicombiner-debug-typeinfer: print type inference debug logs (written to
/// errs() by CombineRuleOperandTypeChecker).
cl::opt<bool> DebugTypeInfer("gicombiner-debug-typeinfer",
                             cl::desc("Print type inference debug logs"),
                             cl::cat(GICombinerEmitterCat));

// String constants used when naming generated entities.
// NOTE(review): their exact use sites are outside this part of the file;
// presumably they are prefixes/names of emitted C++ identifiers - confirm.
constexpr StringLiteral CXXApplyPrefix = "GICXXCustomAction_CombineApply";
constexpr StringLiteral CXXPredPrefix = "GICXXPred_MI_Predicate_";
constexpr StringLiteral MIFlagsEnumClassName = "MIFlagEnum";
83 
84 //===- CodeExpansions Helpers  --------------------------------------------===//
85 
declareInstExpansion(CodeExpansions & CE,const InstructionMatcher & IM,StringRef Name)86 void declareInstExpansion(CodeExpansions &CE, const InstructionMatcher &IM,
87                           StringRef Name) {
88   CE.declare(Name, "State.MIs[" + to_string(IM.getInsnVarID()) + "]");
89 }
90 
declareInstExpansion(CodeExpansions & CE,const BuildMIAction & A,StringRef Name)91 void declareInstExpansion(CodeExpansions &CE, const BuildMIAction &A,
92                           StringRef Name) {
93   // Note: we use redeclare here because this may overwrite a matcher inst
94   // expansion.
95   CE.redeclare(Name, "OutMIs[" + to_string(A.getInsnID()) + "]");
96 }
97 
declareOperandExpansion(CodeExpansions & CE,const OperandMatcher & OM,StringRef Name)98 void declareOperandExpansion(CodeExpansions &CE, const OperandMatcher &OM,
99                              StringRef Name) {
100   CE.declare(Name, "State.MIs[" + to_string(OM.getInsnVarID()) +
101                        "]->getOperand(" + to_string(OM.getOpIdx()) + ")");
102 }
103 
declareTempRegExpansion(CodeExpansions & CE,unsigned TempRegID,StringRef Name)104 void declareTempRegExpansion(CodeExpansions &CE, unsigned TempRegID,
105                              StringRef Name) {
106   CE.declare(Name, "State.TempRegisters[" + to_string(TempRegID) + "]");
107 }
108 
109 //===- Misc. Helpers  -----------------------------------------------------===//
110 
111 /// Copies a StringRef into a static pool to preserve it.
112 /// Most Pattern classes use StringRef so we need this.
insertStrRef(StringRef S)113 StringRef insertStrRef(StringRef S) {
114   if (S.empty())
115     return {};
116 
117   static StringSet<> Pool;
118   auto [It, Inserted] = Pool.insert(S);
119   return It->getKey();
120 }
121 
/// \returns a range that yields a reference to the `.first` member of each
/// entry of \p C.
template <typename Container> auto keys(Container &&C) {
  auto GetFirst = [](auto &KV) -> auto & { return KV.first; };
  return map_range(C, GetFirst);
}
125 
/// \returns a range that yields a reference to the `.second` member of each
/// entry of \p C.
template <typename Container> auto values(Container &&C) {
  auto GetSecond = [](auto &KV) -> auto & { return KV.second; };
  return map_range(C, GetSecond);
}
129 
/// \returns the enumerator name `GICXXPred_Simple_IsRule<ID>Enabled` for the
/// rule whose ID is \p CombinerRuleID.
std::string getIsEnabledPredicateEnumName(unsigned CombinerRuleID) {
  std::string EnumName = "GICXXPred_Simple_IsRule";
  EnumName += to_string(CombinerRuleID);
  EnumName += "Enabled";
  return EnumName;
}
133 
134 //===- MatchTable Helpers  ------------------------------------------------===//
135 
getLLTCodeGen(const PatternType & PT)136 LLTCodeGen getLLTCodeGen(const PatternType &PT) {
137   return *MVTToLLT(getValueType(PT.getLLTRecord()));
138 }
139 
getLLTCodeGenOrTempType(const PatternType & PT,RuleMatcher & RM)140 LLTCodeGenOrTempType getLLTCodeGenOrTempType(const PatternType &PT,
141                                              RuleMatcher &RM) {
142   assert(!PT.isNone());
143 
144   if (PT.isLLT())
145     return getLLTCodeGen(PT);
146 
147   assert(PT.isTypeOf());
148   auto &OM = RM.getOperandMatcher(PT.getTypeOfOpName());
149   return OM.getTempTypeIdx(RM);
150 }
151 
152 //===- PrettyStackTrace Helpers  ------------------------------------------===//
153 
154 class PrettyStackTraceParse : public PrettyStackTraceEntry {
155   const Record &Def;
156 
157 public:
PrettyStackTraceParse(const Record & Def)158   PrettyStackTraceParse(const Record &Def) : Def(Def) {}
159 
print(raw_ostream & OS) const160   void print(raw_ostream &OS) const override {
161     if (Def.isSubClassOf("GICombineRule"))
162       OS << "Parsing GICombineRule '" << Def.getName() << "'";
163     else if (Def.isSubClassOf(PatFrag::ClassName))
164       OS << "Parsing " << PatFrag::ClassName << " '" << Def.getName() << "'";
165     else
166       OS << "Parsing '" << Def.getName() << "'";
167     OS << '\n';
168   }
169 };
170 
171 class PrettyStackTraceEmit : public PrettyStackTraceEntry {
172   const Record &Def;
173   const Pattern *Pat = nullptr;
174 
175 public:
PrettyStackTraceEmit(const Record & Def,const Pattern * Pat=nullptr)176   PrettyStackTraceEmit(const Record &Def, const Pattern *Pat = nullptr)
177       : Def(Def), Pat(Pat) {}
178 
print(raw_ostream & OS) const179   void print(raw_ostream &OS) const override {
180     if (Def.isSubClassOf("GICombineRule"))
181       OS << "Emitting GICombineRule '" << Def.getName() << "'";
182     else if (Def.isSubClassOf(PatFrag::ClassName))
183       OS << "Emitting " << PatFrag::ClassName << " '" << Def.getName() << "'";
184     else
185       OS << "Emitting '" << Def.getName() << "'";
186 
187     if (Pat)
188       OS << " [" << Pat->getKindName() << " '" << Pat->getName() << "']";
189     OS << '\n';
190   }
191 };
192 
193 //===- CombineRuleOperandTypeChecker --------------------------------------===//
194 
/// This is a wrapper around OperandTypeChecker specialized for Combiner Rules.
/// On top of doing the same things as OperandTypeChecker, this also attempts to
/// infer as many types as possible for temporary register defs & immediates in
/// apply patterns.
///
/// The inference is trivial and leverages the MCOI OperandTypes encoded in
/// CodeGenInstructions to infer types across patterns in a CombineRule. It's
/// thus very limited and only supports CodeGenInstructions (but that's the main
/// use case so it's fine).
///
/// We only try to infer untyped operands in apply patterns when they're temp
/// reg defs, or immediates. Inference always outputs a `TypeOf<$x>` where $x is
/// a named operand from a match pattern.
///
/// Typical use: call processMatchPattern/processApplyPattern on every pattern
/// of the rule, then call propagateAndInferTypes once.
class CombineRuleOperandTypeChecker : private OperandTypeChecker {
public:
  /// \param RuleDef      Record of the rule being checked. Its location is
  ///                     used for diagnostics and its name for debug output.
  /// \param MatchOpTable Table of the operands of the match patterns. It is
  ///                     only stored by reference, so it must outlive this
  ///                     object.
  CombineRuleOperandTypeChecker(const Record &RuleDef,
                                const OperandTable &MatchOpTable)
      : OperandTypeChecker(RuleDef.getLoc()), RuleDef(RuleDef),
        MatchOpTable(MatchOpTable) {}

  /// Records and checks a 'match' pattern.
  /// \returns the result of the underlying OperandTypeChecker check.
  bool processMatchPattern(InstructionPattern &P);

  /// Records and checks an 'apply' pattern. Additionally reports an error for
  /// any GITypeOf that does not name an operand present in the match operand
  /// table.
  /// \returns the result of the underlying OperandTypeChecker check.
  bool processApplyPattern(InstructionPattern &P);

  /// Propagates types, then perform type inference and do a second round of
  /// propagation in the apply patterns only if any types were inferred.
  void propagateAndInferTypes();

private:
  /// TypeEquivalenceClasses are groups of operands of an instruction that share
  /// a common type.
  ///
  /// e.g. [[a, b], [c, d]] means a and b have the same type, and c and
  /// d have the same type too. b/c and a/d don't have to have the same type,
  /// though.
  using TypeEquivalenceClasses = EquivalenceClasses<StringRef>;

  /// \returns true for `OPERAND_GENERIC_` 0 through 5.
  /// These are the MCOI types that can be registers. The other MCOI types are
  /// either immediates, or fancier operands used only post-ISel, so we don't
  /// care about them for combiners.
  ///
  /// The check is purely textual: the last character of \p MCOIType is dropped
  /// and the remainder must end with "OPERAND_GENERIC_". Any single trailing
  /// character is therefore accepted, and \p MCOIType is expected to be
  /// non-empty.
  static bool canMCOIOperandTypeBeARegister(StringRef MCOIType) {
    // Assume OPERAND_GENERIC_0 through 5 can be registers. The other MCOI
    // OperandTypes are either never used in gMIR, or not relevant (e.g.
    // OPERAND_GENERIC_IMM, which is definitely never a register).
    return MCOIType.drop_back(1).ends_with("OPERAND_GENERIC_");
  }

  /// Finds the "MCOI::" operand types for each operand of \p CGP.
  ///
  /// This is a bit trickier than it looks because we need to handle variadic
  /// in/outs.
  ///
  /// e.g. for
  ///   (G_BUILD_VECTOR $vec, $x, $y) ->
  ///   [MCOI::OPERAND_GENERIC_0, MCOI::OPERAND_GENERIC_1,
  ///    MCOI::OPERAND_GENERIC_1]
  ///
  /// For unknown types (which can happen in variadics where varargs types are
  /// inconsistent), a unique name is given, e.g. "unknown_type_0".
  ///
  /// \returns one type name per operand of \p CGP, in operand order.
  static std::vector<std::string>
  getMCOIOperandTypes(const CodeGenInstructionPattern &CGP);

  /// Adds the TypeEquivalenceClasses for \p P in \p OutTECs.
  /// Does nothing if \p P is not a CodeGenInstructionPattern.
  void getInstEqClasses(const InstructionPattern &P,
                        TypeEquivalenceClasses &OutTECs) const;

  /// Calls `getInstEqClasses` on all patterns of the rule to produce the whole
  /// rule's TypeEquivalenceClasses.
  TypeEquivalenceClasses getRuleEqClasses() const;

  /// Tries to infer the type of the \p ImmOpIdx -th operand of \p IP using \p
  /// TECs.
  ///
  /// This is achieved by trying to find a named operand in \p IP that shares
  /// the same type as \p ImmOpIdx, and using \ref inferNamedOperandType on that
  /// operand instead.
  ///
  /// \returns the inferred type or an empty PatternType if inference didn't
  /// succeed.
  PatternType inferImmediateType(const InstructionPattern &IP,
                                 unsigned ImmOpIdx,
                                 const TypeEquivalenceClasses &TECs) const;

  /// Looks inside \p TECs to infer \p OpName's type.
  ///
  /// \param AllowSelf if true, \p OpName itself may be chosen as the operand
  ///        the resulting type refers to.
  ///
  /// \returns the inferred type or an empty PatternType if inference didn't
  /// succeed.
  PatternType inferNamedOperandType(const InstructionPattern &IP,
                                    StringRef OpName,
                                    const TypeEquivalenceClasses &TECs,
                                    bool AllowSelf = false) const;

  /// Record of the rule being checked (diagnostic locations, debug output).
  const Record &RuleDef;
  /// Patterns registered through processMatchPattern, in call order.
  SmallVector<InstructionPattern *, 8> MatchPats;
  /// Patterns registered through processApplyPattern, in call order.
  SmallVector<InstructionPattern *, 8> ApplyPats;

  /// Operand table of the match patterns; used to tell whether an operand
  /// name refers to a matched operand.
  const OperandTable &MatchOpTable;
};
296 
processMatchPattern(InstructionPattern & P)297 bool CombineRuleOperandTypeChecker::processMatchPattern(InstructionPattern &P) {
298   MatchPats.push_back(&P);
299   return check(P, /*CheckTypeOf*/ [](const auto &) {
300     // GITypeOf in 'match' is currently always rejected by the
301     // CombineRuleBuilder after inference is done.
302     return true;
303   });
304 }
305 
processApplyPattern(InstructionPattern & P)306 bool CombineRuleOperandTypeChecker::processApplyPattern(InstructionPattern &P) {
307   ApplyPats.push_back(&P);
308   return check(P, /*CheckTypeOf*/ [&](const PatternType &Ty) {
309     // GITypeOf<"$x"> can only be used if "$x" is a matched operand.
310     const auto OpName = Ty.getTypeOfOpName();
311     if (MatchOpTable.lookup(OpName).Found)
312       return true;
313 
314     PrintError(RuleDef.getLoc(), "'" + OpName + "' ('" + Ty.str() +
315                                      "') does not refer to a matched operand!");
316     return false;
317   });
318 }
319 
propagateAndInferTypes()320 void CombineRuleOperandTypeChecker::propagateAndInferTypes() {
321   /// First step here is to propagate types using the OperandTypeChecker. That
322   /// way we ensure all uses of a given register have consistent types.
323   propagateTypes();
324 
325   /// Build the TypeEquivalenceClasses for the whole rule.
326   const TypeEquivalenceClasses TECs = getRuleEqClasses();
327 
328   /// Look at the apply patterns and find operands that need to be
329   /// inferred. We then try to find an equivalence class that they're a part of
330   /// and select the best operand to use for the `GITypeOf` type. We prioritize
331   /// defs of matched instructions because those are guaranteed to be registers.
332   bool InferredAny = false;
333   for (auto *Pat : ApplyPats) {
334     for (unsigned K = 0; K < Pat->operands_size(); ++K) {
335       auto &Op = Pat->getOperand(K);
336 
337       // We only want to take a look at untyped defs or immediates.
338       if ((!Op.isDef() && !Op.hasImmValue()) || Op.getType())
339         continue;
340 
341       // Infer defs & named immediates.
342       if (Op.isDef() || Op.isNamedImmediate()) {
343         // Check it's not a redefinition of a matched operand.
344         // In such cases, inference is not necessary because we just copy
345         // operands and don't create temporary registers.
346         if (MatchOpTable.lookup(Op.getOperandName()).Found)
347           continue;
348 
349         // Inference is needed here, so try to do it.
350         if (PatternType Ty =
351                 inferNamedOperandType(*Pat, Op.getOperandName(), TECs)) {
352           if (DebugTypeInfer)
353             errs() << "INFER: " << Op.describe() << " -> " << Ty.str() << '\n';
354           Op.setType(Ty);
355           InferredAny = true;
356         }
357 
358         continue;
359       }
360 
361       // Infer immediates
362       if (Op.hasImmValue()) {
363         if (PatternType Ty = inferImmediateType(*Pat, K, TECs)) {
364           if (DebugTypeInfer)
365             errs() << "INFER: " << Op.describe() << " -> " << Ty.str() << '\n';
366           Op.setType(Ty);
367           InferredAny = true;
368         }
369         continue;
370       }
371     }
372   }
373 
374   // If we've inferred any types, we want to propagate them across the apply
375   // patterns. Type inference only adds GITypeOf types that point to Matched
376   // operands, so we definitely don't want to propagate types into the match
377   // patterns as well, otherwise bad things happen.
378   if (InferredAny) {
379     OperandTypeChecker OTC(RuleDef.getLoc());
380     for (auto *Pat : ApplyPats) {
381       if (!OTC.check(*Pat, [&](const auto &) { return true; }))
382         PrintFatalError(RuleDef.getLoc(),
383                         "OperandTypeChecker unexpectedly failed on '" +
384                             Pat->getName() + "' during Type Inference");
385     }
386     OTC.propagateTypes();
387 
388     if (DebugTypeInfer) {
389       errs() << "Apply patterns for rule " << RuleDef.getName()
390              << " after inference:\n";
391       for (auto *Pat : ApplyPats) {
392         errs() << "  ";
393         Pat->print(errs(), /*PrintName*/ true);
394         errs() << '\n';
395       }
396       errs() << '\n';
397     }
398   }
399 }
400 
inferImmediateType(const InstructionPattern & IP,unsigned ImmOpIdx,const TypeEquivalenceClasses & TECs) const401 PatternType CombineRuleOperandTypeChecker::inferImmediateType(
402     const InstructionPattern &IP, unsigned ImmOpIdx,
403     const TypeEquivalenceClasses &TECs) const {
404   // We can only infer CGPs.
405   const auto *CGP = dyn_cast<CodeGenInstructionPattern>(&IP);
406   if (!CGP)
407     return {};
408 
409   // For CGPs, we try to infer immediates by trying to infer another named
410   // operand that shares its type.
411   //
412   // e.g.
413   //    Pattern: G_BUILD_VECTOR $x, $y, 0
414   //    MCOIs:   [MCOI::OPERAND_GENERIC_0, MCOI::OPERAND_GENERIC_1,
415   //              MCOI::OPERAND_GENERIC_1]
416   //    $y has the same type as 0, so we can infer $y and get the type 0 should
417   //    have.
418 
419   // We infer immediates by looking for a named operand that shares the same
420   // MCOI type.
421   const auto MCOITypes = getMCOIOperandTypes(*CGP);
422   StringRef ImmOpTy = MCOITypes[ImmOpIdx];
423 
424   for (const auto &[Idx, Ty] : enumerate(MCOITypes)) {
425     if (Idx != ImmOpIdx && Ty == ImmOpTy) {
426       const auto &Op = IP.getOperand(Idx);
427       if (!Op.isNamedOperand())
428         continue;
429 
430       // Named operand with the same name, try to infer that.
431       if (PatternType InferTy = inferNamedOperandType(IP, Op.getOperandName(),
432                                                       TECs, /*AllowSelf=*/true))
433         return InferTy;
434     }
435   }
436 
437   return {};
438 }
439 
inferNamedOperandType(const InstructionPattern & IP,StringRef OpName,const TypeEquivalenceClasses & TECs,bool AllowSelf) const440 PatternType CombineRuleOperandTypeChecker::inferNamedOperandType(
441     const InstructionPattern &IP, StringRef OpName,
442     const TypeEquivalenceClasses &TECs, bool AllowSelf) const {
443   // This is the simplest possible case, we just need to find a TEC that
444   // contains OpName. Look at all operands in equivalence class and try to
445   // find a suitable one. If `AllowSelf` is true, the operand itself is also
446   // considered suitable.
447 
448   // Check for a def of a matched pattern. This is guaranteed to always
449   // be a register so we can blindly use that.
450   StringRef GoodOpName;
451   for (auto It = TECs.findLeader(OpName); It != TECs.member_end(); ++It) {
452     if (!AllowSelf && *It == OpName)
453       continue;
454 
455     const auto LookupRes = MatchOpTable.lookup(*It);
456     if (LookupRes.Def) // Favor defs
457       return PatternType::getTypeOf(*It);
458 
459     // Otherwise just save this in case we don't find any def.
460     if (GoodOpName.empty() && LookupRes.Found)
461       GoodOpName = *It;
462   }
463 
464   if (!GoodOpName.empty())
465     return PatternType::getTypeOf(GoodOpName);
466 
467   // No good operand found, give up.
468   return {};
469 }
470 
getMCOIOperandTypes(const CodeGenInstructionPattern & CGP)471 std::vector<std::string> CombineRuleOperandTypeChecker::getMCOIOperandTypes(
472     const CodeGenInstructionPattern &CGP) {
473   // FIXME?: Should we cache this? We call it twice when inferring immediates.
474 
475   static unsigned UnknownTypeIdx = 0;
476 
477   std::vector<std::string> OpTypes;
478   auto &CGI = CGP.getInst();
479   Record *VarArgsTy = CGI.TheDef->isSubClassOf("GenericInstruction")
480                           ? CGI.TheDef->getValueAsOptionalDef("variadicOpsType")
481                           : nullptr;
482   std::string VarArgsTyName =
483       VarArgsTy ? ("MCOI::" + VarArgsTy->getValueAsString("OperandType")).str()
484                 : ("unknown_type_" + Twine(UnknownTypeIdx++)).str();
485 
486   // First, handle defs.
487   for (unsigned K = 0; K < CGI.Operands.NumDefs; ++K)
488     OpTypes.push_back(CGI.Operands[K].OperandType);
489 
490   // Then, handle variadic defs if there are any.
491   if (CGP.hasVariadicDefs()) {
492     for (unsigned K = CGI.Operands.NumDefs; K < CGP.getNumInstDefs(); ++K)
493       OpTypes.push_back(VarArgsTyName);
494   }
495 
496   // If we had variadic defs, the op idx in the pattern won't match the op idx
497   // in the CGI anymore.
498   int CGIOpOffset = int(CGI.Operands.NumDefs) - CGP.getNumInstDefs();
499   assert(CGP.hasVariadicDefs() ? (CGIOpOffset <= 0) : (CGIOpOffset == 0));
500 
501   // Handle all remaining use operands, including variadic ones.
502   for (unsigned K = CGP.getNumInstDefs(); K < CGP.getNumInstOperands(); ++K) {
503     unsigned CGIOpIdx = K + CGIOpOffset;
504     if (CGIOpIdx >= CGI.Operands.size()) {
505       assert(CGP.isVariadic());
506       OpTypes.push_back(VarArgsTyName);
507     } else {
508       OpTypes.push_back(CGI.Operands[CGIOpIdx].OperandType);
509     }
510   }
511 
512   assert(OpTypes.size() == CGP.operands_size());
513   return OpTypes;
514 }
515 
getInstEqClasses(const InstructionPattern & P,TypeEquivalenceClasses & OutTECs) const516 void CombineRuleOperandTypeChecker::getInstEqClasses(
517     const InstructionPattern &P, TypeEquivalenceClasses &OutTECs) const {
518   // Determine the TypeEquivalenceClasses by:
519   //    - Getting the MCOI Operand Types.
520   //    - Creating a Map of MCOI Type -> [Operand Indexes]
521   //    - Iterating over the map, filtering types we don't like, and just adding
522   //      the array of Operand Indexes to \p OutTECs.
523 
524   // We can only do this on CodeGenInstructions. Other InstructionPatterns have
525   // no type inference information associated with them.
526   // TODO: Could we add some inference information to builtins at least? e.g.
527   // ReplaceReg should always replace with a reg of the same type, for instance.
528   // Though, those patterns are often used alone so it might not be worth the
529   // trouble to infer their types.
530   auto *CGP = dyn_cast<CodeGenInstructionPattern>(&P);
531   if (!CGP)
532     return;
533 
534   const auto MCOITypes = getMCOIOperandTypes(*CGP);
535   assert(MCOITypes.size() == P.operands_size());
536 
537   DenseMap<StringRef, std::vector<unsigned>> TyToOpIdx;
538   for (const auto &[Idx, Ty] : enumerate(MCOITypes))
539     TyToOpIdx[Ty].push_back(Idx);
540 
541   if (DebugTypeInfer)
542     errs() << "\tGroups for " << P.getName() << ":\t";
543 
544   for (const auto &[Ty, Idxs] : TyToOpIdx) {
545     if (!canMCOIOperandTypeBeARegister(Ty))
546       continue;
547 
548     if (DebugTypeInfer)
549       errs() << '[';
550     StringRef Sep = "";
551 
552     // We only collect named operands.
553     StringRef Leader;
554     for (unsigned Idx : Idxs) {
555       const auto &Op = P.getOperand(Idx);
556       if (!Op.isNamedOperand())
557         continue;
558 
559       const auto OpName = Op.getOperandName();
560       if (DebugTypeInfer) {
561         errs() << Sep << OpName;
562         Sep = ", ";
563       }
564 
565       if (Leader.empty())
566         OutTECs.insert((Leader = OpName));
567       else
568         OutTECs.unionSets(Leader, OpName);
569     }
570 
571     if (DebugTypeInfer)
572       errs() << "] ";
573   }
574 
575   if (DebugTypeInfer)
576     errs() << '\n';
577 }
578 
579 CombineRuleOperandTypeChecker::TypeEquivalenceClasses
getRuleEqClasses() const580 CombineRuleOperandTypeChecker::getRuleEqClasses() const {
581   StringMap<unsigned> OpNameToEqClassIdx;
582   TypeEquivalenceClasses TECs;
583 
584   if (DebugTypeInfer)
585     errs() << "Rule Operand Type Equivalence Classes for " << RuleDef.getName()
586            << ":\n";
587 
588   for (const auto *Pat : MatchPats)
589     getInstEqClasses(*Pat, TECs);
590   for (const auto *Pat : ApplyPats)
591     getInstEqClasses(*Pat, TECs);
592 
593   if (DebugTypeInfer) {
594     errs() << "Final Type Equivalence Classes: ";
595     for (auto ClassIt = TECs.begin(); ClassIt != TECs.end(); ++ClassIt) {
596       // only print non-empty classes.
597       if (auto MembIt = TECs.member_begin(ClassIt);
598           MembIt != TECs.member_end()) {
599         errs() << '[';
600         StringRef Sep = "";
601         for (; MembIt != TECs.member_end(); ++MembIt) {
602           errs() << Sep << *MembIt;
603           Sep = ", ";
604         }
605         errs() << "] ";
606       }
607     }
608     errs() << '\n';
609   }
610 
611   return TECs;
612 }
613 
614 //===- CombineRuleBuilder -------------------------------------------------===//
615 
616 /// Parses combine rule and builds a small intermediate representation to tie
617 /// patterns together and emit RuleMatchers to match them. This may emit more
618 /// than one RuleMatcher, e.g. for `wip_match_opcode`.
619 ///
620 /// Memory management for `Pattern` objects is done through `std::unique_ptr`.
621 /// In most cases, there are two stages to a pattern's lifetime:
622 ///   - Creation in a `parse` function
623 ///     - The unique_ptr is stored in a variable, and may be destroyed if the
624 ///       pattern is found to be semantically invalid.
625 ///   - Ownership transfer into a `PatternMap`
626 ///     - Once a pattern is moved into either the map of Match or Apply
627 ///       patterns, it is known to be valid and it never moves back.
628 class CombineRuleBuilder {
629 public:
630   using PatternMap = MapVector<StringRef, std::unique_ptr<Pattern>>;
631   using PatternAlternatives = DenseMap<const Pattern *, unsigned>;
632 
CombineRuleBuilder(const CodeGenTarget & CGT,SubtargetFeatureInfoMap & SubtargetFeatures,Record & RuleDef,unsigned ID,std::vector<RuleMatcher> & OutRMs)633   CombineRuleBuilder(const CodeGenTarget &CGT,
634                      SubtargetFeatureInfoMap &SubtargetFeatures,
635                      Record &RuleDef, unsigned ID,
636                      std::vector<RuleMatcher> &OutRMs)
637       : CGT(CGT), SubtargetFeatures(SubtargetFeatures), RuleDef(RuleDef),
638         RuleID(ID), OutRMs(OutRMs) {}
639 
640   /// Parses all fields in the RuleDef record.
641   bool parseAll();
642 
643   /// Emits all RuleMatchers into the vector of RuleMatchers passed in the
644   /// constructor.
645   bool emitRuleMatchers();
646 
647   void print(raw_ostream &OS) const;
dump() const648   void dump() const { print(dbgs()); }
649 
650   /// Debug-only verification of invariants.
651 #ifndef NDEBUG
652   void verify() const;
653 #endif
654 
655 private:
getGConstant() const656   const CodeGenInstruction &getGConstant() const {
657     return CGT.getInstruction(RuleDef.getRecords().getDef("G_CONSTANT"));
658   }
659 
PrintError(Twine Msg) const660   void PrintError(Twine Msg) const { ::PrintError(&RuleDef, Msg); }
PrintWarning(Twine Msg) const661   void PrintWarning(Twine Msg) const { ::PrintWarning(RuleDef.getLoc(), Msg); }
PrintNote(Twine Msg) const662   void PrintNote(Twine Msg) const { ::PrintNote(RuleDef.getLoc(), Msg); }
663 
664   void print(raw_ostream &OS, const PatternAlternatives &Alts) const;
665 
666   bool addApplyPattern(std::unique_ptr<Pattern> Pat);
667   bool addMatchPattern(std::unique_ptr<Pattern> Pat);
668 
669   /// Adds the expansions from \see MatchDatas to \p CE.
670   void declareAllMatchDatasExpansions(CodeExpansions &CE) const;
671 
672   /// Adds a matcher \p P to \p IM, expanding its code using \p CE.
673   /// Note that the predicate is added on the last InstructionMatcher.
674   ///
675   /// \p Alts is only used if DebugCXXPreds is enabled.
676   void addCXXPredicate(RuleMatcher &M, const CodeExpansions &CE,
677                        const CXXPattern &P, const PatternAlternatives &Alts);
678 
679   /// Adds an apply \p P to \p IM, expanding its code using \p CE.
680   void addCXXAction(RuleMatcher &M, const CodeExpansions &CE,
681                     const CXXPattern &P);
682 
683   bool hasOnlyCXXApplyPatterns() const;
684   bool hasEraseRoot() const;
685 
686   // Infer machine operand types and check their consistency.
687   bool typecheckPatterns();
688 
689   /// For all PatFragPatterns, add a new entry in PatternAlternatives for each
690   /// PatternList it contains. This is multiplicative, so if we have 2
691   /// PatFrags with 3 alternatives each, we get 2*3 permutations added to
692   /// PermutationsToEmit. The "MaxPermutations" field controls how many
693   /// permutations are allowed before an error is emitted and this function
694   /// returns false. This is a simple safeguard to prevent combination of
695   /// PatFrags from generating enormous amounts of rules.
696   bool buildPermutationsToEmit();
697 
698   /// Checks additional semantics of the Patterns.
699   bool checkSemantics();
700 
  /// Creates a new RuleMatcher with some boilerplate
  /// settings/actions/predicates, and adds it to \p OutRMs.
  /// \see addFeaturePredicates too.
704   ///
705   /// \param Alts Current set of alternatives, for debug comment.
706   /// \param AdditionalComment Comment string to be added to the
707   ///        `DebugCommentAction`.
708   RuleMatcher &addRuleMatcher(const PatternAlternatives &Alts,
709                               Twine AdditionalComment = "");
710   bool addFeaturePredicates(RuleMatcher &M);
711 
712   bool findRoots();
713   bool buildRuleOperandsTable();
714 
715   bool parseDefs(const DagInit &Def);
716   bool
717   parsePatternList(const DagInit &List,
718                    function_ref<bool(std::unique_ptr<Pattern>)> ParseAction,
719                    StringRef Operator, ArrayRef<SMLoc> DiagLoc,
720                    StringRef AnonPatNamePrefix) const;
721 
722   std::unique_ptr<Pattern> parseInstructionPattern(const Init &Arg,
723                                                    StringRef PatName) const;
724   std::unique_ptr<Pattern> parseWipMatchOpcodeMatcher(const Init &Arg,
725                                                       StringRef PatName) const;
726   bool parseInstructionPatternOperand(InstructionPattern &IP,
727                                       const Init *OpInit,
728                                       const StringInit *OpName) const;
729   bool parseInstructionPatternMIFlags(InstructionPattern &IP,
730                                       const DagInit *Op) const;
731   std::unique_ptr<PatFrag> parsePatFragImpl(const Record *Def) const;
732   bool parsePatFragParamList(
733       ArrayRef<SMLoc> DiagLoc, const DagInit &OpsList,
734       function_ref<bool(StringRef, PatFrag::ParamKind)> ParseAction) const;
735   const PatFrag *parsePatFrag(const Record *Def) const;
736 
737   bool emitMatchPattern(CodeExpansions &CE, const PatternAlternatives &Alts,
738                         const InstructionPattern &IP);
739   bool emitMatchPattern(CodeExpansions &CE, const PatternAlternatives &Alts,
740                         const AnyOpcodePattern &AOP);
741 
742   bool emitPatFragMatchPattern(CodeExpansions &CE,
743                                const PatternAlternatives &Alts, RuleMatcher &RM,
744                                InstructionMatcher *IM,
745                                const PatFragPattern &PFP,
746                                DenseSet<const Pattern *> &SeenPats);
747 
748   bool emitApplyPatterns(CodeExpansions &CE, RuleMatcher &M);
749 
750   // Recursively visits InstructionPatterns from P to build up the
751   // RuleMatcher actions.
752   bool emitInstructionApplyPattern(CodeExpansions &CE, RuleMatcher &M,
753                                    const InstructionPattern &P,
754                                    DenseSet<const Pattern *> &SeenPats,
755                                    StringMap<unsigned> &OperandToTempRegID);
756 
757   bool emitCodeGenInstructionApplyImmOperand(RuleMatcher &M,
758                                              BuildMIAction &DstMI,
759                                              const CodeGenInstructionPattern &P,
760                                              const InstructionOperand &O);
761 
762   bool emitBuiltinApplyPattern(CodeExpansions &CE, RuleMatcher &M,
763                                const BuiltinPattern &P,
764                                StringMap<unsigned> &OperandToTempRegID);
765 
766   // Recursively visits CodeGenInstructionPattern from P to build up the
767   // RuleMatcher/InstructionMatcher. May create new InstructionMatchers as
768   // needed.
769   using OperandMapperFnRef =
770       function_ref<InstructionOperand(const InstructionOperand &)>;
771   using OperandDefLookupFn =
772       function_ref<const InstructionPattern *(StringRef)>;
773   bool emitCodeGenInstructionMatchPattern(
774       CodeExpansions &CE, const PatternAlternatives &Alts, RuleMatcher &M,
775       InstructionMatcher &IM, const CodeGenInstructionPattern &P,
776       DenseSet<const Pattern *> &SeenPats, OperandDefLookupFn LookupOperandDef,
__anoned09a36a0702(const auto &O) 777       OperandMapperFnRef OperandMapper = [](const auto &O) { return O; });
778 
779   const CodeGenTarget &CGT;
780   SubtargetFeatureInfoMap &SubtargetFeatures;
781   Record &RuleDef;
782   const unsigned RuleID;
783   std::vector<RuleMatcher> &OutRMs;
784 
785   // For InstructionMatcher::addOperand
786   unsigned AllocatedTemporariesBaseID = 0;
787 
788   /// The root of the pattern.
789   StringRef RootName;
790 
791   /// These maps have ownership of the actual Pattern objects.
792   /// They both map a Pattern's name to the Pattern instance.
793   PatternMap MatchPats;
794   PatternMap ApplyPats;
795 
796   /// Operand tables to tie match/apply patterns together.
797   OperandTable MatchOpTable;
798   OperandTable ApplyOpTable;
799 
800   /// Set by findRoots.
801   Pattern *MatchRoot = nullptr;
802   SmallDenseSet<InstructionPattern *, 2> ApplyRoots;
803 
804   SmallVector<MatchDataInfo, 2> MatchDatas;
805   SmallVector<PatternAlternatives, 1> PermutationsToEmit;
806 
807   // print()/debug-only members.
808   mutable SmallPtrSet<const PatFrag *, 2> SeenPatFrags;
809 };
810 
parseAll()811 bool CombineRuleBuilder::parseAll() {
812   auto StackTrace = PrettyStackTraceParse(RuleDef);
813 
814   if (!parseDefs(*RuleDef.getValueAsDag("Defs")))
815     return false;
816 
817   if (!parsePatternList(
818           *RuleDef.getValueAsDag("Match"),
819           [this](auto Pat) { return addMatchPattern(std::move(Pat)); }, "match",
820           RuleDef.getLoc(), (RuleDef.getName() + "_match").str()))
821     return false;
822 
823   if (!parsePatternList(
824           *RuleDef.getValueAsDag("Apply"),
825           [this](auto Pat) { return addApplyPattern(std::move(Pat)); }, "apply",
826           RuleDef.getLoc(), (RuleDef.getName() + "_apply").str()))
827     return false;
828 
829   if (!buildRuleOperandsTable() || !typecheckPatterns() || !findRoots() ||
830       !checkSemantics() || !buildPermutationsToEmit())
831     return false;
832   LLVM_DEBUG(verify());
833   return true;
834 }
835 
emitRuleMatchers()836 bool CombineRuleBuilder::emitRuleMatchers() {
837   auto StackTrace = PrettyStackTraceEmit(RuleDef);
838 
839   assert(MatchRoot);
840   CodeExpansions CE;
841   declareAllMatchDatasExpansions(CE);
842 
843   assert(!PermutationsToEmit.empty());
844   for (const auto &Alts : PermutationsToEmit) {
845     switch (MatchRoot->getKind()) {
846     case Pattern::K_AnyOpcode: {
847       if (!emitMatchPattern(CE, Alts, *cast<AnyOpcodePattern>(MatchRoot)))
848         return false;
849       break;
850     }
851     case Pattern::K_PatFrag:
852     case Pattern::K_Builtin:
853     case Pattern::K_CodeGenInstruction:
854       if (!emitMatchPattern(CE, Alts, *cast<InstructionPattern>(MatchRoot)))
855         return false;
856       break;
857     case Pattern::K_CXX:
858       PrintError("C++ code cannot be the root of a rule!");
859       return false;
860     default:
861       llvm_unreachable("unknown pattern kind!");
862     }
863   }
864 
865   return true;
866 }
867 
print(raw_ostream & OS) const868 void CombineRuleBuilder::print(raw_ostream &OS) const {
869   OS << "(CombineRule name:" << RuleDef.getName() << " id:" << RuleID
870      << " root:" << RootName << '\n';
871 
872   if (!MatchDatas.empty()) {
873     OS << "  (MatchDatas\n";
874     for (const auto &MD : MatchDatas) {
875       OS << "    ";
876       MD.print(OS);
877       OS << '\n';
878     }
879     OS << "  )\n";
880   }
881 
882   if (!SeenPatFrags.empty()) {
883     OS << "  (PatFrags\n";
884     for (const auto *PF : SeenPatFrags) {
885       PF->print(OS, /*Indent=*/"    ");
886       OS << '\n';
887     }
888     OS << "  )\n";
889   }
890 
891   const auto DumpPats = [&](StringRef Name, const PatternMap &Pats) {
892     OS << "  (" << Name << " ";
893     if (Pats.empty()) {
894       OS << "<empty>)\n";
895       return;
896     }
897 
898     OS << '\n';
899     for (const auto &[Name, Pat] : Pats) {
900       OS << "    ";
901       if (Pat.get() == MatchRoot)
902         OS << "<match_root>";
903       if (isa<InstructionPattern>(Pat.get()) &&
904           ApplyRoots.contains(cast<InstructionPattern>(Pat.get())))
905         OS << "<apply_root>";
906       OS << Name << ":";
907       Pat->print(OS, /*PrintName=*/false);
908       OS << '\n';
909     }
910     OS << "  )\n";
911   };
912 
913   DumpPats("MatchPats", MatchPats);
914   DumpPats("ApplyPats", ApplyPats);
915 
916   MatchOpTable.print(OS, "MatchPats", /*Indent*/ "  ");
917   ApplyOpTable.print(OS, "ApplyPats", /*Indent*/ "  ");
918 
919   if (PermutationsToEmit.size() > 1) {
920     OS << "  (PermutationsToEmit\n";
921     for (const auto &Perm : PermutationsToEmit) {
922       OS << "    ";
923       print(OS, Perm);
924       OS << ",\n";
925     }
926     OS << "  )\n";
927   }
928 
929   OS << ")\n";
930 }
931 
932 #ifndef NDEBUG
verify() const933 void CombineRuleBuilder::verify() const {
934   const auto VerifyPats = [&](const PatternMap &Pats) {
935     for (const auto &[Name, Pat] : Pats) {
936       if (!Pat)
937         PrintFatalError("null pattern in pattern map!");
938 
939       if (Name != Pat->getName()) {
940         Pat->dump();
941         PrintFatalError("Pattern name mismatch! Map name: " + Name +
942                         ", Pat name: " + Pat->getName());
943       }
944 
945       // Sanity check: the map should point to the same data as the Pattern.
946       // Both strings are allocated in the pool using insertStrRef.
947       if (Name.data() != Pat->getName().data()) {
948         dbgs() << "Map StringRef: '" << Name << "' @ "
949                << (const void *)Name.data() << '\n';
950         dbgs() << "Pat String: '" << Pat->getName() << "' @ "
951                << (const void *)Pat->getName().data() << '\n';
952         PrintFatalError("StringRef stored in the PatternMap is not referencing "
953                         "the same string as its Pattern!");
954       }
955     }
956   };
957 
958   VerifyPats(MatchPats);
959   VerifyPats(ApplyPats);
960 
961   // Check there are no wip_match_opcode patterns in the "apply" patterns.
962   if (any_of(ApplyPats,
963              [&](auto &E) { return isa<AnyOpcodePattern>(E.second.get()); })) {
964     dump();
965     PrintFatalError(
966         "illegal wip_match_opcode pattern in the 'apply' patterns!");
967   }
968 
969   // Check there are no nullptrs in ApplyRoots.
970   if (ApplyRoots.contains(nullptr)) {
971     PrintFatalError(
972         "CombineRuleBuilder's ApplyRoots set contains a null pointer!");
973   }
974 }
975 #endif
976 
print(raw_ostream & OS,const PatternAlternatives & Alts) const977 void CombineRuleBuilder::print(raw_ostream &OS,
978                                const PatternAlternatives &Alts) const {
979   SmallVector<std::string, 1> Strings(
980       map_range(Alts, [](const auto &PatAndPerm) {
981         return PatAndPerm.first->getName().str() + "[" +
982                to_string(PatAndPerm.second) + "]";
983       }));
984   // Sort so output is deterministic for tests. Otherwise it's sorted by pointer
985   // values.
986   sort(Strings);
987   OS << "[" << join(Strings, ", ") << "]";
988 }
989 
addApplyPattern(std::unique_ptr<Pattern> Pat)990 bool CombineRuleBuilder::addApplyPattern(std::unique_ptr<Pattern> Pat) {
991   StringRef Name = Pat->getName();
992   if (ApplyPats.contains(Name)) {
993     PrintError("'" + Name + "' apply pattern defined more than once!");
994     return false;
995   }
996 
997   if (isa<AnyOpcodePattern>(Pat.get())) {
998     PrintError("'" + Name +
999                "': wip_match_opcode is not supported in apply patterns");
1000     return false;
1001   }
1002 
1003   if (isa<PatFragPattern>(Pat.get())) {
1004     PrintError("'" + Name + "': using " + PatFrag::ClassName +
1005                " is not supported in apply patterns");
1006     return false;
1007   }
1008 
1009   if (auto *CXXPat = dyn_cast<CXXPattern>(Pat.get()))
1010     CXXPat->setIsApply();
1011 
1012   ApplyPats[Name] = std::move(Pat);
1013   return true;
1014 }
1015 
addMatchPattern(std::unique_ptr<Pattern> Pat)1016 bool CombineRuleBuilder::addMatchPattern(std::unique_ptr<Pattern> Pat) {
1017   StringRef Name = Pat->getName();
1018   if (MatchPats.contains(Name)) {
1019     PrintError("'" + Name + "' match pattern defined more than once!");
1020     return false;
1021   }
1022 
1023   // For now, none of the builtins can appear in 'match'.
1024   if (const auto *BP = dyn_cast<BuiltinPattern>(Pat.get())) {
1025     PrintError("'" + BP->getInstName() +
1026                "' cannot be used in a 'match' pattern");
1027     return false;
1028   }
1029 
1030   MatchPats[Name] = std::move(Pat);
1031   return true;
1032 }
1033 
declareAllMatchDatasExpansions(CodeExpansions & CE) const1034 void CombineRuleBuilder::declareAllMatchDatasExpansions(
1035     CodeExpansions &CE) const {
1036   for (const auto &MD : MatchDatas)
1037     CE.declare(MD.getPatternSymbol(), MD.getQualifiedVariableName());
1038 }
1039 
addCXXPredicate(RuleMatcher & M,const CodeExpansions & CE,const CXXPattern & P,const PatternAlternatives & Alts)1040 void CombineRuleBuilder::addCXXPredicate(RuleMatcher &M,
1041                                          const CodeExpansions &CE,
1042                                          const CXXPattern &P,
1043                                          const PatternAlternatives &Alts) {
1044   // FIXME: Hack so C++ code is executed last. May not work for more complex
1045   // patterns.
1046   auto &IM = *std::prev(M.insnmatchers().end());
1047   auto Loc = RuleDef.getLoc();
1048   const auto AddComment = [&](raw_ostream &OS) {
1049     OS << "// Pattern Alternatives: ";
1050     print(OS, Alts);
1051     OS << '\n';
1052   };
1053   const auto &ExpandedCode =
1054       DebugCXXPreds ? P.expandCode(CE, Loc, AddComment) : P.expandCode(CE, Loc);
1055   IM->addPredicate<GenericInstructionPredicateMatcher>(
1056       ExpandedCode.getEnumNameWithPrefix(CXXPredPrefix));
1057 }
1058 
addCXXAction(RuleMatcher & M,const CodeExpansions & CE,const CXXPattern & P)1059 void CombineRuleBuilder::addCXXAction(RuleMatcher &M, const CodeExpansions &CE,
1060                                       const CXXPattern &P) {
1061   const auto &ExpandedCode = P.expandCode(CE, RuleDef.getLoc());
1062   M.addAction<CustomCXXAction>(
1063       ExpandedCode.getEnumNameWithPrefix(CXXApplyPrefix));
1064 }
1065 
hasOnlyCXXApplyPatterns() const1066 bool CombineRuleBuilder::hasOnlyCXXApplyPatterns() const {
1067   return all_of(ApplyPats, [&](auto &Entry) {
1068     return isa<CXXPattern>(Entry.second.get());
1069   });
1070 }
1071 
hasEraseRoot() const1072 bool CombineRuleBuilder::hasEraseRoot() const {
1073   return any_of(ApplyPats, [&](auto &Entry) {
1074     if (const auto *BP = dyn_cast<BuiltinPattern>(Entry.second.get()))
1075       return BP->getBuiltinKind() == BI_EraseRoot;
1076     return false;
1077   });
1078 }
1079 
typecheckPatterns()1080 bool CombineRuleBuilder::typecheckPatterns() {
1081   CombineRuleOperandTypeChecker OTC(RuleDef, MatchOpTable);
1082 
1083   for (auto &Pat : values(MatchPats)) {
1084     if (auto *IP = dyn_cast<InstructionPattern>(Pat.get())) {
1085       if (!OTC.processMatchPattern(*IP))
1086         return false;
1087     }
1088   }
1089 
1090   for (auto &Pat : values(ApplyPats)) {
1091     if (auto *IP = dyn_cast<InstructionPattern>(Pat.get())) {
1092       if (!OTC.processApplyPattern(*IP))
1093         return false;
1094     }
1095   }
1096 
1097   OTC.propagateAndInferTypes();
1098 
1099   // Always check this after in case inference adds some special types to the
1100   // match patterns.
1101   for (auto &Pat : values(MatchPats)) {
1102     if (auto *IP = dyn_cast<InstructionPattern>(Pat.get())) {
1103       if (IP->diagnoseAllSpecialTypes(
1104               RuleDef.getLoc(), PatternType::SpecialTyClassName +
1105                                     " is not supported in 'match' patterns")) {
1106         return false;
1107       }
1108     }
1109   }
1110   return true;
1111 }
1112 
buildPermutationsToEmit()1113 bool CombineRuleBuilder::buildPermutationsToEmit() {
1114   PermutationsToEmit.clear();
1115 
1116   // Start with one empty set of alternatives.
1117   PermutationsToEmit.emplace_back();
1118   for (const auto &Pat : values(MatchPats)) {
1119     unsigned NumAlts = 0;
1120     // Note: technically, AnyOpcodePattern also needs permutations, but:
1121     //    - We only allow a single one of them in the root.
1122     //    - They cannot be mixed with any other pattern other than C++ code.
1123     // So we don't really need to take them into account here. We could, but
1124     // that pattern is a hack anyway and the less it's involved, the better.
1125     if (const auto *PFP = dyn_cast<PatFragPattern>(Pat.get()))
1126       NumAlts = PFP->getPatFrag().num_alternatives();
1127     else
1128       continue;
1129 
1130     // For each pattern that needs permutations, multiply the current set of
1131     // alternatives.
1132     auto CurPerms = PermutationsToEmit;
1133     PermutationsToEmit.clear();
1134 
1135     for (const auto &Perm : CurPerms) {
1136       assert(!Perm.count(Pat.get()) && "Pattern already emitted?");
1137       for (unsigned K = 0; K < NumAlts; ++K) {
1138         PatternAlternatives NewPerm = Perm;
1139         NewPerm[Pat.get()] = K;
1140         PermutationsToEmit.emplace_back(std::move(NewPerm));
1141       }
1142     }
1143   }
1144 
1145   if (int64_t MaxPerms = RuleDef.getValueAsInt("MaxPermutations");
1146       MaxPerms > 0) {
1147     if ((int64_t)PermutationsToEmit.size() > MaxPerms) {
1148       PrintError("cannot emit rule '" + RuleDef.getName() + "'; " +
1149                  Twine(PermutationsToEmit.size()) +
1150                  " permutations would be emitted, but the max is " +
1151                  Twine(MaxPerms));
1152       return false;
1153     }
1154   }
1155 
1156   // Ensure we always have a single empty entry, it simplifies the emission
1157   // logic so it doesn't need to handle the case where there are no perms.
1158   if (PermutationsToEmit.empty()) {
1159     PermutationsToEmit.emplace_back();
1160     return true;
1161   }
1162 
1163   return true;
1164 }
1165 
checkSemantics()1166 bool CombineRuleBuilder::checkSemantics() {
1167   assert(MatchRoot && "Cannot call this before findRoots()");
1168 
1169   bool UsesWipMatchOpcode = false;
1170   for (const auto &Match : MatchPats) {
1171     const auto *Pat = Match.second.get();
1172 
1173     if (const auto *CXXPat = dyn_cast<CXXPattern>(Pat)) {
1174       if (!CXXPat->getRawCode().contains("return "))
1175         PrintWarning("'match' C++ code does not seem to return!");
1176       continue;
1177     }
1178 
1179     // MIFlags in match cannot use the following syntax: (MIFlags $mi)
1180     if (const auto *CGP = dyn_cast<CodeGenInstructionPattern>(Pat)) {
1181       if (auto *FI = CGP->getMIFlagsInfo()) {
1182         if (!FI->copy_flags().empty()) {
1183           PrintError(
1184               "'match' patterns cannot refer to flags from other instructions");
1185           PrintNote("MIFlags in '" + CGP->getName() +
1186                     "' refer to: " + join(FI->copy_flags(), ", "));
1187           return false;
1188         }
1189       }
1190     }
1191 
1192     const auto *AOP = dyn_cast<AnyOpcodePattern>(Pat);
1193     if (!AOP)
1194       continue;
1195 
1196     if (UsesWipMatchOpcode) {
1197       PrintError("wip_opcode_match can only be present once");
1198       return false;
1199     }
1200 
1201     UsesWipMatchOpcode = true;
1202   }
1203 
1204   for (const auto &Apply : ApplyPats) {
1205     assert(Apply.second.get());
1206     const auto *IP = dyn_cast<InstructionPattern>(Apply.second.get());
1207     if (!IP)
1208       continue;
1209 
1210     if (UsesWipMatchOpcode) {
1211       PrintError("cannot use wip_match_opcode in combination with apply "
1212                  "instruction patterns!");
1213       return false;
1214     }
1215 
1216     // Check that the insts mentioned in copy_flags exist.
1217     if (const auto *CGP = dyn_cast<CodeGenInstructionPattern>(IP)) {
1218       if (auto *FI = CGP->getMIFlagsInfo()) {
1219         for (auto InstName : FI->copy_flags()) {
1220           auto It = MatchPats.find(InstName);
1221           if (It == MatchPats.end()) {
1222             PrintError("unknown instruction '$" + InstName +
1223                        "' referenced in MIFlags of '" + CGP->getName() + "'");
1224             return false;
1225           }
1226 
1227           if (!isa<CodeGenInstructionPattern>(It->second.get())) {
1228             PrintError(
1229                 "'$" + InstName +
1230                 "' does not refer to a CodeGenInstruction in MIFlags of '" +
1231                 CGP->getName() + "'");
1232             return false;
1233           }
1234         }
1235       }
1236     }
1237 
1238     const auto *BIP = dyn_cast<BuiltinPattern>(IP);
1239     if (!BIP)
1240       continue;
1241     StringRef Name = BIP->getInstName();
1242 
1243     // (GIEraseInst) has to be the only apply pattern, or it can not be used at
1244     // all. The root cannot have any defs either.
1245     switch (BIP->getBuiltinKind()) {
1246     case BI_EraseRoot: {
1247       if (ApplyPats.size() > 1) {
1248         PrintError(Name + " must be the only 'apply' pattern");
1249         return false;
1250       }
1251 
1252       const auto *IRoot = dyn_cast<CodeGenInstructionPattern>(MatchRoot);
1253       if (!IRoot) {
1254         PrintError(Name +
1255                    " can only be used if the root is a CodeGenInstruction");
1256         return false;
1257       }
1258 
1259       if (IRoot->getNumInstDefs() != 0) {
1260         PrintError(Name + " can only be used if on roots that do "
1261                           "not have any output operand");
1262         PrintNote("'" + IRoot->getInstName() + "' has " +
1263                   Twine(IRoot->getNumInstDefs()) + " output operands");
1264         return false;
1265       }
1266       break;
1267     }
1268     case BI_ReplaceReg: {
1269       // (GIReplaceReg can only be used on the root instruction)
1270       // TODO: When we allow rewriting non-root instructions, also allow this.
1271       StringRef OldRegName = BIP->getOperand(0).getOperandName();
1272       auto *Def = MatchOpTable.getDef(OldRegName);
1273       if (!Def) {
1274         PrintError(Name + " cannot find a matched pattern that defines '" +
1275                    OldRegName + "'");
1276         return false;
1277       }
1278       if (MatchOpTable.getDef(OldRegName) != MatchRoot) {
1279         PrintError(Name + " cannot replace '" + OldRegName +
1280                    "': this builtin can only replace a register defined by the "
1281                    "match root");
1282         return false;
1283       }
1284       break;
1285     }
1286     }
1287   }
1288 
1289   return true;
1290 }
1291 
addRuleMatcher(const PatternAlternatives & Alts,Twine AdditionalComment)1292 RuleMatcher &CombineRuleBuilder::addRuleMatcher(const PatternAlternatives &Alts,
1293                                                 Twine AdditionalComment) {
1294   auto &RM = OutRMs.emplace_back(RuleDef.getLoc());
1295   addFeaturePredicates(RM);
1296   RM.setPermanentGISelFlags(GISF_IgnoreCopies);
1297   RM.addRequiredSimplePredicate(getIsEnabledPredicateEnumName(RuleID));
1298 
1299   std::string Comment;
1300   raw_string_ostream CommentOS(Comment);
1301   CommentOS << "Combiner Rule #" << RuleID << ": " << RuleDef.getName();
1302   if (!Alts.empty()) {
1303     CommentOS << " @ ";
1304     print(CommentOS, Alts);
1305   }
1306   if (!AdditionalComment.isTriviallyEmpty())
1307     CommentOS << "; " << AdditionalComment;
1308   RM.addAction<DebugCommentAction>(Comment);
1309   return RM;
1310 }
1311 
addFeaturePredicates(RuleMatcher & M)1312 bool CombineRuleBuilder::addFeaturePredicates(RuleMatcher &M) {
1313   if (!RuleDef.getValue("Predicates"))
1314     return true;
1315 
1316   ListInit *Preds = RuleDef.getValueAsListInit("Predicates");
1317   for (Init *PI : Preds->getValues()) {
1318     DefInit *Pred = dyn_cast<DefInit>(PI);
1319     if (!Pred)
1320       continue;
1321 
1322     Record *Def = Pred->getDef();
1323     if (!Def->isSubClassOf("Predicate")) {
1324       ::PrintError(Def, "Unknown 'Predicate' Type");
1325       return false;
1326     }
1327 
1328     if (Def->getValueAsString("CondString").empty())
1329       continue;
1330 
1331     if (SubtargetFeatures.count(Def) == 0) {
1332       SubtargetFeatures.emplace(
1333           Def, SubtargetFeatureInfo(Def, SubtargetFeatures.size()));
1334     }
1335 
1336     M.addRequiredFeature(Def);
1337   }
1338 
1339   return true;
1340 }
1341 
findRoots()1342 bool CombineRuleBuilder::findRoots() {
1343   const auto Finish = [&]() {
1344     assert(MatchRoot);
1345 
1346     if (hasOnlyCXXApplyPatterns() || hasEraseRoot())
1347       return true;
1348 
1349     auto *IPRoot = dyn_cast<InstructionPattern>(MatchRoot);
1350     if (!IPRoot)
1351       return true;
1352 
1353     if (IPRoot->getNumInstDefs() == 0) {
1354       // No defs to work with -> find the root using the pattern name.
1355       auto It = ApplyPats.find(RootName);
1356       if (It == ApplyPats.end()) {
1357         PrintError("Cannot find root '" + RootName + "' in apply patterns!");
1358         return false;
1359       }
1360 
1361       auto *ApplyRoot = dyn_cast<InstructionPattern>(It->second.get());
1362       if (!ApplyRoot) {
1363         PrintError("apply pattern root '" + RootName +
1364                    "' must be an instruction pattern");
1365         return false;
1366       }
1367 
1368       ApplyRoots.insert(ApplyRoot);
1369       return true;
1370     }
1371 
1372     // Collect all redefinitions of the MatchRoot's defs and put them in
1373     // ApplyRoots.
1374     const auto DefsNeeded = IPRoot->getApplyDefsNeeded();
1375     for (auto &Op : DefsNeeded) {
1376       assert(Op.isDef() && Op.isNamedOperand());
1377       StringRef Name = Op.getOperandName();
1378 
1379       auto *ApplyRedef = ApplyOpTable.getDef(Name);
1380       if (!ApplyRedef) {
1381         PrintError("'" + Name + "' must be redefined in the 'apply' pattern");
1382         return false;
1383       }
1384 
1385       ApplyRoots.insert((InstructionPattern *)ApplyRedef);
1386     }
1387 
1388     if (auto It = ApplyPats.find(RootName); It != ApplyPats.end()) {
1389       if (find(ApplyRoots, It->second.get()) == ApplyRoots.end()) {
1390         PrintError("apply pattern '" + RootName +
1391                    "' is supposed to be a root but it does not redefine any of "
1392                    "the defs of the match root");
1393         return false;
1394       }
1395     }
1396 
1397     return true;
1398   };
1399 
1400   // Look by pattern name, e.g.
1401   //    (G_FNEG $x, $y):$root
1402   if (auto MatchPatIt = MatchPats.find(RootName);
1403       MatchPatIt != MatchPats.end()) {
1404     MatchRoot = MatchPatIt->second.get();
1405     return Finish();
1406   }
1407 
1408   // Look by def:
1409   //    (G_FNEG $root, $y)
1410   auto LookupRes = MatchOpTable.lookup(RootName);
1411   if (!LookupRes.Found) {
1412     PrintError("Cannot find root '" + RootName + "' in match patterns!");
1413     return false;
1414   }
1415 
1416   MatchRoot = LookupRes.Def;
1417   if (!MatchRoot) {
1418     PrintError("Cannot use live-in operand '" + RootName +
1419                "' as match pattern root!");
1420     return false;
1421   }
1422 
1423   return Finish();
1424 }
1425 
buildRuleOperandsTable()1426 bool CombineRuleBuilder::buildRuleOperandsTable() {
1427   const auto DiagnoseRedefMatch = [&](StringRef OpName) {
1428     PrintError("Operand '" + OpName +
1429                "' is defined multiple times in the 'match' patterns");
1430   };
1431 
1432   const auto DiagnoseRedefApply = [&](StringRef OpName) {
1433     PrintError("Operand '" + OpName +
1434                "' is defined multiple times in the 'apply' patterns");
1435   };
1436 
1437   for (auto &Pat : values(MatchPats)) {
1438     auto *IP = dyn_cast<InstructionPattern>(Pat.get());
1439     if (IP && !MatchOpTable.addPattern(IP, DiagnoseRedefMatch))
1440       return false;
1441   }
1442 
1443   for (auto &Pat : values(ApplyPats)) {
1444     auto *IP = dyn_cast<InstructionPattern>(Pat.get());
1445     if (IP && !ApplyOpTable.addPattern(IP, DiagnoseRedefApply))
1446       return false;
1447   }
1448 
1449   return true;
1450 }
1451 
parseDefs(const DagInit & Def)1452 bool CombineRuleBuilder::parseDefs(const DagInit &Def) {
1453   if (Def.getOperatorAsDef(RuleDef.getLoc())->getName() != "defs") {
1454     PrintError("Expected defs operator");
1455     return false;
1456   }
1457 
1458   SmallVector<StringRef> Roots;
1459   for (unsigned I = 0, E = Def.getNumArgs(); I < E; ++I) {
1460     if (isSpecificDef(*Def.getArg(I), "root")) {
1461       Roots.emplace_back(Def.getArgNameStr(I));
1462       continue;
1463     }
1464 
1465     // Subclasses of GIDefMatchData should declare that this rule needs to pass
1466     // data from the match stage to the apply stage, and ensure that the
1467     // generated matcher has a suitable variable for it to do so.
1468     if (Record *MatchDataRec =
1469             getDefOfSubClass(*Def.getArg(I), "GIDefMatchData")) {
1470       MatchDatas.emplace_back(Def.getArgNameStr(I),
1471                               MatchDataRec->getValueAsString("Type"));
1472       continue;
1473     }
1474 
1475     // Otherwise emit an appropriate error message.
1476     if (getDefOfSubClass(*Def.getArg(I), "GIDefKind"))
1477       PrintError("This GIDefKind not implemented in tablegen");
1478     else if (getDefOfSubClass(*Def.getArg(I), "GIDefKindWithArgs"))
1479       PrintError("This GIDefKindWithArgs not implemented in tablegen");
1480     else
1481       PrintError("Expected a subclass of GIDefKind or a sub-dag whose "
1482                  "operator is of type GIDefKindWithArgs");
1483     return false;
1484   }
1485 
1486   if (Roots.size() != 1) {
1487     PrintError("Combine rules must have exactly one root");
1488     return false;
1489   }
1490 
1491   RootName = Roots.front();
1492 
1493   // Assign variables to all MatchDatas.
1494   AssignMatchDataVariables(MatchDatas);
1495   return true;
1496 }
1497 
parsePatternList(const DagInit & List,function_ref<bool (std::unique_ptr<Pattern>)> ParseAction,StringRef Operator,ArrayRef<SMLoc> DiagLoc,StringRef AnonPatNamePrefix) const1498 bool CombineRuleBuilder::parsePatternList(
1499     const DagInit &List,
1500     function_ref<bool(std::unique_ptr<Pattern>)> ParseAction,
1501     StringRef Operator, ArrayRef<SMLoc> DiagLoc,
1502     StringRef AnonPatNamePrefix) const {
1503   if (List.getOperatorAsDef(RuleDef.getLoc())->getName() != Operator) {
1504     ::PrintError(DiagLoc, "Expected " + Operator + " operator");
1505     return false;
1506   }
1507 
1508   if (List.getNumArgs() == 0) {
1509     ::PrintError(DiagLoc, Operator + " pattern list is empty");
1510     return false;
1511   }
1512 
1513   // The match section consists of a list of matchers and predicates. Parse each
1514   // one and add the equivalent GIMatchDag nodes, predicates, and edges.
1515   for (unsigned I = 0; I < List.getNumArgs(); ++I) {
1516     Init *Arg = List.getArg(I);
1517     std::string Name = List.getArgName(I)
1518                            ? List.getArgName(I)->getValue().str()
1519                            : ("__" + AnonPatNamePrefix + "_" + Twine(I)).str();
1520 
1521     if (auto Pat = parseInstructionPattern(*Arg, Name)) {
1522       if (!ParseAction(std::move(Pat)))
1523         return false;
1524       continue;
1525     }
1526 
1527     if (auto Pat = parseWipMatchOpcodeMatcher(*Arg, Name)) {
1528       if (!ParseAction(std::move(Pat)))
1529         return false;
1530       continue;
1531     }
1532 
1533     // Parse arbitrary C++ code
1534     if (const auto *StringI = dyn_cast<StringInit>(Arg)) {
1535       auto CXXPat = std::make_unique<CXXPattern>(*StringI, insertStrRef(Name));
1536       if (!ParseAction(std::move(CXXPat)))
1537         return false;
1538       continue;
1539     }
1540 
1541     ::PrintError(DiagLoc,
1542                  "Failed to parse pattern: '" + Arg->getAsString() + "'");
1543     return false;
1544   }
1545 
1546   return true;
1547 }
1548 
1549 std::unique_ptr<Pattern>
parseInstructionPattern(const Init & Arg,StringRef Name) const1550 CombineRuleBuilder::parseInstructionPattern(const Init &Arg,
1551                                             StringRef Name) const {
1552   const DagInit *DagPat = dyn_cast<DagInit>(&Arg);
1553   if (!DagPat)
1554     return nullptr;
1555 
1556   std::unique_ptr<InstructionPattern> Pat;
1557   if (const DagInit *IP = getDagWithOperatorOfSubClass(Arg, "Instruction")) {
1558     auto &Instr = CGT.getInstruction(IP->getOperatorAsDef(RuleDef.getLoc()));
1559     Pat =
1560         std::make_unique<CodeGenInstructionPattern>(Instr, insertStrRef(Name));
1561   } else if (const DagInit *PFP =
1562                  getDagWithOperatorOfSubClass(Arg, PatFrag::ClassName)) {
1563     const Record *Def = PFP->getOperatorAsDef(RuleDef.getLoc());
1564     const PatFrag *PF = parsePatFrag(Def);
1565     if (!PF)
1566       return nullptr; // Already diagnosed by parsePatFrag
1567     Pat = std::make_unique<PatFragPattern>(*PF, insertStrRef(Name));
1568   } else if (const DagInit *BP =
1569                  getDagWithOperatorOfSubClass(Arg, BuiltinPattern::ClassName)) {
1570     Pat = std::make_unique<BuiltinPattern>(
1571         *BP->getOperatorAsDef(RuleDef.getLoc()), insertStrRef(Name));
1572   } else {
1573     return nullptr;
1574   }
1575 
1576   for (unsigned K = 0; K < DagPat->getNumArgs(); ++K) {
1577     Init *Arg = DagPat->getArg(K);
1578     if (auto *DagArg = getDagWithSpecificOperator(*Arg, "MIFlags")) {
1579       if (!parseInstructionPatternMIFlags(*Pat, DagArg))
1580         return nullptr;
1581       continue;
1582     }
1583 
1584     if (!parseInstructionPatternOperand(*Pat, Arg, DagPat->getArgName(K)))
1585       return nullptr;
1586   }
1587 
1588   if (!Pat->checkSemantics(RuleDef.getLoc()))
1589     return nullptr;
1590 
1591   return std::move(Pat);
1592 }
1593 
1594 std::unique_ptr<Pattern>
parseWipMatchOpcodeMatcher(const Init & Arg,StringRef Name) const1595 CombineRuleBuilder::parseWipMatchOpcodeMatcher(const Init &Arg,
1596                                                StringRef Name) const {
1597   const DagInit *Matcher = getDagWithSpecificOperator(Arg, "wip_match_opcode");
1598   if (!Matcher)
1599     return nullptr;
1600 
1601   if (Matcher->getNumArgs() == 0) {
1602     PrintError("Empty wip_match_opcode");
1603     return nullptr;
1604   }
1605 
1606   // Each argument is an opcode that can match.
1607   auto Result = std::make_unique<AnyOpcodePattern>(insertStrRef(Name));
1608   for (const auto &Arg : Matcher->getArgs()) {
1609     Record *OpcodeDef = getDefOfSubClass(*Arg, "Instruction");
1610     if (OpcodeDef) {
1611       Result->addOpcode(&CGT.getInstruction(OpcodeDef));
1612       continue;
1613     }
1614 
1615     PrintError("Arguments to wip_match_opcode must be instructions");
1616     return nullptr;
1617   }
1618 
1619   return std::move(Result);
1620 }
1621 
parseInstructionPatternOperand(InstructionPattern & IP,const Init * OpInit,const StringInit * OpName) const1622 bool CombineRuleBuilder::parseInstructionPatternOperand(
1623     InstructionPattern &IP, const Init *OpInit,
1624     const StringInit *OpName) const {
1625   const auto ParseErr = [&]() {
1626     PrintError("cannot parse operand '" + OpInit->getAsUnquotedString() + "' ");
1627     if (OpName)
1628       PrintNote("operand name is '" + OpName->getAsUnquotedString() + "'");
1629     return false;
1630   };
1631 
1632   // untyped immediate, e.g. 0
1633   if (const auto *IntImm = dyn_cast<IntInit>(OpInit)) {
1634     std::string Name = OpName ? OpName->getAsUnquotedString() : "";
1635     IP.addOperand(IntImm->getValue(), insertStrRef(Name), PatternType());
1636     return true;
1637   }
1638 
1639   // typed immediate, e.g. (i32 0)
1640   if (const auto *DagOp = dyn_cast<DagInit>(OpInit)) {
1641     if (DagOp->getNumArgs() != 1)
1642       return ParseErr();
1643 
1644     const Record *TyDef = DagOp->getOperatorAsDef(RuleDef.getLoc());
1645     auto ImmTy = PatternType::get(RuleDef.getLoc(), TyDef,
1646                                   "cannot parse immediate '" +
1647                                       DagOp->getAsUnquotedString() + "'");
1648     if (!ImmTy)
1649       return false;
1650 
1651     if (!IP.hasAllDefs()) {
1652       PrintError("out operand of '" + IP.getInstName() +
1653                  "' cannot be an immediate");
1654       return false;
1655     }
1656 
1657     const auto *Val = dyn_cast<IntInit>(DagOp->getArg(0));
1658     if (!Val)
1659       return ParseErr();
1660 
1661     std::string Name = OpName ? OpName->getAsUnquotedString() : "";
1662     IP.addOperand(Val->getValue(), insertStrRef(Name), *ImmTy);
1663     return true;
1664   }
1665 
1666   // Typed operand e.g. $x/$z in (G_FNEG $x, $z)
1667   if (auto *DefI = dyn_cast<DefInit>(OpInit)) {
1668     if (!OpName) {
1669       PrintError("expected an operand name after '" + OpInit->getAsString() +
1670                  "'");
1671       return false;
1672     }
1673     const Record *Def = DefI->getDef();
1674     auto Ty =
1675         PatternType::get(RuleDef.getLoc(), Def, "cannot parse operand type");
1676     if (!Ty)
1677       return false;
1678     IP.addOperand(insertStrRef(OpName->getAsUnquotedString()), *Ty);
1679     return true;
1680   }
1681 
1682   // Untyped operand e.g. $x/$z in (G_FNEG $x, $z)
1683   if (isa<UnsetInit>(OpInit)) {
1684     assert(OpName && "Unset w/ no OpName?");
1685     IP.addOperand(insertStrRef(OpName->getAsUnquotedString()), PatternType());
1686     return true;
1687   }
1688 
1689   return ParseErr();
1690 }
1691 
parseInstructionPatternMIFlags(InstructionPattern & IP,const DagInit * Op) const1692 bool CombineRuleBuilder::parseInstructionPatternMIFlags(
1693     InstructionPattern &IP, const DagInit *Op) const {
1694   auto *CGIP = dyn_cast<CodeGenInstructionPattern>(&IP);
1695   if (!CGIP) {
1696     PrintError("matching/writing MIFlags is only allowed on CodeGenInstruction "
1697                "patterns");
1698     return false;
1699   }
1700 
1701   const auto CheckFlagEnum = [&](const Record *R) {
1702     if (!R->isSubClassOf(MIFlagsEnumClassName)) {
1703       PrintError("'" + R->getName() + "' is not a subclass of '" +
1704                  MIFlagsEnumClassName + "'");
1705       return false;
1706     }
1707 
1708     return true;
1709   };
1710 
1711   if (CGIP->getMIFlagsInfo()) {
1712     PrintError("MIFlags can only be present once on an instruction");
1713     return false;
1714   }
1715 
1716   auto &FI = CGIP->getOrCreateMIFlagsInfo();
1717   for (unsigned K = 0; K < Op->getNumArgs(); ++K) {
1718     const Init *Arg = Op->getArg(K);
1719 
1720     // Match/set a flag: (MIFlags FmNoNans)
1721     if (const auto *Def = dyn_cast<DefInit>(Arg)) {
1722       const Record *R = Def->getDef();
1723       if (!CheckFlagEnum(R))
1724         return false;
1725 
1726       FI.addSetFlag(R);
1727       continue;
1728     }
1729 
1730     // Do not match a flag/unset a flag: (MIFlags (not FmNoNans))
1731     if (const DagInit *NotDag = getDagWithSpecificOperator(*Arg, "not")) {
1732       for (const Init *NotArg : NotDag->getArgs()) {
1733         const DefInit *DefArg = dyn_cast<DefInit>(NotArg);
1734         if (!DefArg) {
1735           PrintError("cannot parse '" + NotArg->getAsUnquotedString() +
1736                      "': expected a '" + MIFlagsEnumClassName + "'");
1737           return false;
1738         }
1739 
1740         const Record *R = DefArg->getDef();
1741         if (!CheckFlagEnum(R))
1742           return false;
1743 
1744         FI.addUnsetFlag(R);
1745         continue;
1746       }
1747 
1748       continue;
1749     }
1750 
1751     // Copy flags from a matched instruction: (MIFlags $mi)
1752     if (isa<UnsetInit>(Arg)) {
1753       FI.addCopyFlag(insertStrRef(Op->getArgName(K)->getAsUnquotedString()));
1754       continue;
1755     }
1756   }
1757 
1758   return true;
1759 }
1760 
1761 std::unique_ptr<PatFrag>
parsePatFragImpl(const Record * Def) const1762 CombineRuleBuilder::parsePatFragImpl(const Record *Def) const {
1763   auto StackTrace = PrettyStackTraceParse(*Def);
1764   if (!Def->isSubClassOf(PatFrag::ClassName))
1765     return nullptr;
1766 
1767   const DagInit *Ins = Def->getValueAsDag("InOperands");
1768   if (Ins->getOperatorAsDef(Def->getLoc())->getName() != "ins") {
1769     ::PrintError(Def, "expected 'ins' operator for " + PatFrag::ClassName +
1770                           " in operands list");
1771     return nullptr;
1772   }
1773 
1774   const DagInit *Outs = Def->getValueAsDag("OutOperands");
1775   if (Outs->getOperatorAsDef(Def->getLoc())->getName() != "outs") {
1776     ::PrintError(Def, "expected 'outs' operator for " + PatFrag::ClassName +
1777                           " out operands list");
1778     return nullptr;
1779   }
1780 
1781   auto Result = std::make_unique<PatFrag>(*Def);
1782   if (!parsePatFragParamList(Def->getLoc(), *Outs,
1783                              [&](StringRef Name, PatFrag::ParamKind Kind) {
1784                                Result->addOutParam(insertStrRef(Name), Kind);
1785                                return true;
1786                              }))
1787     return nullptr;
1788 
1789   if (!parsePatFragParamList(Def->getLoc(), *Ins,
1790                              [&](StringRef Name, PatFrag::ParamKind Kind) {
1791                                Result->addInParam(insertStrRef(Name), Kind);
1792                                return true;
1793                              }))
1794     return nullptr;
1795 
1796   const ListInit *Alts = Def->getValueAsListInit("Alternatives");
1797   unsigned AltIdx = 0;
1798   for (const Init *Alt : *Alts) {
1799     const auto *PatDag = dyn_cast<DagInit>(Alt);
1800     if (!PatDag) {
1801       ::PrintError(Def, "expected dag init for PatFrag pattern alternative");
1802       return nullptr;
1803     }
1804 
1805     PatFrag::Alternative &A = Result->addAlternative();
1806     const auto AddPat = [&](std::unique_ptr<Pattern> Pat) {
1807       A.Pats.push_back(std::move(Pat));
1808       return true;
1809     };
1810 
1811     if (!parsePatternList(
1812             *PatDag, AddPat, "pattern", Def->getLoc(),
1813             /*AnonPatPrefix*/
1814             (Def->getName() + "_alt" + Twine(AltIdx++) + "_pattern").str()))
1815       return nullptr;
1816   }
1817 
1818   if (!Result->buildOperandsTables() || !Result->checkSemantics())
1819     return nullptr;
1820 
1821   return Result;
1822 }
1823 
parsePatFragParamList(ArrayRef<SMLoc> DiagLoc,const DagInit & OpsList,function_ref<bool (StringRef,PatFrag::ParamKind)> ParseAction) const1824 bool CombineRuleBuilder::parsePatFragParamList(
1825     ArrayRef<SMLoc> DiagLoc, const DagInit &OpsList,
1826     function_ref<bool(StringRef, PatFrag::ParamKind)> ParseAction) const {
1827   for (unsigned K = 0; K < OpsList.getNumArgs(); ++K) {
1828     const StringInit *Name = OpsList.getArgName(K);
1829     const Init *Ty = OpsList.getArg(K);
1830 
1831     if (!Name) {
1832       ::PrintError(DiagLoc, "all operands must be named'");
1833       return false;
1834     }
1835     const std::string NameStr = Name->getAsUnquotedString();
1836 
1837     PatFrag::ParamKind OpKind;
1838     if (isSpecificDef(*Ty, "gi_imm"))
1839       OpKind = PatFrag::PK_Imm;
1840     else if (isSpecificDef(*Ty, "root"))
1841       OpKind = PatFrag::PK_Root;
1842     else if (isa<UnsetInit>(Ty) ||
1843              isSpecificDef(*Ty, "gi_mo")) // no type = gi_mo.
1844       OpKind = PatFrag::PK_MachineOperand;
1845     else {
1846       ::PrintError(
1847           DiagLoc,
1848           "'" + NameStr +
1849               "' operand type was expected to be 'root', 'gi_imm' or 'gi_mo'");
1850       return false;
1851     }
1852 
1853     if (!ParseAction(NameStr, OpKind))
1854       return false;
1855   }
1856 
1857   return true;
1858 }
1859 
parsePatFrag(const Record * Def) const1860 const PatFrag *CombineRuleBuilder::parsePatFrag(const Record *Def) const {
1861   // Cache already parsed PatFrags to avoid doing extra work.
1862   static DenseMap<const Record *, std::unique_ptr<PatFrag>> ParsedPatFrags;
1863 
1864   auto It = ParsedPatFrags.find(Def);
1865   if (It != ParsedPatFrags.end()) {
1866     SeenPatFrags.insert(It->second.get());
1867     return It->second.get();
1868   }
1869 
1870   std::unique_ptr<PatFrag> NewPatFrag = parsePatFragImpl(Def);
1871   if (!NewPatFrag) {
1872     ::PrintError(Def, "Could not parse " + PatFrag::ClassName + " '" +
1873                           Def->getName() + "'");
1874     // Put a nullptr in the map so we don't attempt parsing this again.
1875     ParsedPatFrags[Def] = nullptr;
1876     return nullptr;
1877   }
1878 
1879   const auto *Res = NewPatFrag.get();
1880   ParsedPatFrags[Def] = std::move(NewPatFrag);
1881   SeenPatFrags.insert(Res);
1882   return Res;
1883 }
1884 
emitMatchPattern(CodeExpansions & CE,const PatternAlternatives & Alts,const InstructionPattern & IP)1885 bool CombineRuleBuilder::emitMatchPattern(CodeExpansions &CE,
1886                                           const PatternAlternatives &Alts,
1887                                           const InstructionPattern &IP) {
1888   auto StackTrace = PrettyStackTraceEmit(RuleDef, &IP);
1889 
1890   auto &M = addRuleMatcher(Alts);
1891   InstructionMatcher &IM = M.addInstructionMatcher(IP.getName());
1892   declareInstExpansion(CE, IM, IP.getName());
1893 
1894   DenseSet<const Pattern *> SeenPats;
1895 
1896   const auto FindOperandDef = [&](StringRef Op) -> InstructionPattern * {
1897     return MatchOpTable.getDef(Op);
1898   };
1899 
1900   if (const auto *CGP = dyn_cast<CodeGenInstructionPattern>(&IP)) {
1901     if (!emitCodeGenInstructionMatchPattern(CE, Alts, M, IM, *CGP, SeenPats,
1902                                             FindOperandDef))
1903       return false;
1904   } else if (const auto *PFP = dyn_cast<PatFragPattern>(&IP)) {
1905     if (!PFP->getPatFrag().canBeMatchRoot()) {
1906       PrintError("cannot use '" + PFP->getInstName() + " as match root");
1907       return false;
1908     }
1909 
1910     if (!emitPatFragMatchPattern(CE, Alts, M, &IM, *PFP, SeenPats))
1911       return false;
1912   } else if (isa<BuiltinPattern>(&IP)) {
1913     llvm_unreachable("No match builtins known!");
1914   } else
1915     llvm_unreachable("Unknown kind of InstructionPattern!");
1916 
1917   // Emit remaining patterns
1918   for (auto &Pat : values(MatchPats)) {
1919     if (SeenPats.contains(Pat.get()))
1920       continue;
1921 
1922     switch (Pat->getKind()) {
1923     case Pattern::K_AnyOpcode:
1924       PrintError("wip_match_opcode can not be used with instruction patterns!");
1925       return false;
1926     case Pattern::K_PatFrag: {
1927       if (!emitPatFragMatchPattern(CE, Alts, M, /*IM*/ nullptr,
1928                                    *cast<PatFragPattern>(Pat.get()), SeenPats))
1929         return false;
1930       continue;
1931     }
1932     case Pattern::K_Builtin:
1933       PrintError("No known match builtins");
1934       return false;
1935     case Pattern::K_CodeGenInstruction:
1936       cast<InstructionPattern>(Pat.get())->reportUnreachable(RuleDef.getLoc());
1937       return false;
1938     case Pattern::K_CXX: {
1939       addCXXPredicate(M, CE, *cast<CXXPattern>(Pat.get()), Alts);
1940       continue;
1941     }
1942     default:
1943       llvm_unreachable("unknown pattern kind!");
1944     }
1945   }
1946 
1947   return emitApplyPatterns(CE, M);
1948 }
1949 
emitMatchPattern(CodeExpansions & CE,const PatternAlternatives & Alts,const AnyOpcodePattern & AOP)1950 bool CombineRuleBuilder::emitMatchPattern(CodeExpansions &CE,
1951                                           const PatternAlternatives &Alts,
1952                                           const AnyOpcodePattern &AOP) {
1953   auto StackTrace = PrettyStackTraceEmit(RuleDef, &AOP);
1954 
1955   for (const CodeGenInstruction *CGI : AOP.insts()) {
1956     auto &M = addRuleMatcher(Alts, "wip_match_opcode '" +
1957                                        CGI->TheDef->getName() + "'");
1958 
1959     InstructionMatcher &IM = M.addInstructionMatcher(AOP.getName());
1960     declareInstExpansion(CE, IM, AOP.getName());
1961     // declareInstExpansion needs to be identical, otherwise we need to create a
1962     // CodeExpansions object here instead.
1963     assert(IM.getInsnVarID() == 0);
1964 
1965     IM.addPredicate<InstructionOpcodeMatcher>(CGI);
1966 
1967     // Emit remaining patterns.
1968     for (auto &Pat : values(MatchPats)) {
1969       if (Pat.get() == &AOP)
1970         continue;
1971 
1972       switch (Pat->getKind()) {
1973       case Pattern::K_AnyOpcode:
1974         PrintError("wip_match_opcode can only be present once!");
1975         return false;
1976       case Pattern::K_PatFrag: {
1977         DenseSet<const Pattern *> SeenPats;
1978         if (!emitPatFragMatchPattern(CE, Alts, M, /*IM*/ nullptr,
1979                                      *cast<PatFragPattern>(Pat.get()),
1980                                      SeenPats))
1981           return false;
1982         continue;
1983       }
1984       case Pattern::K_Builtin:
1985         PrintError("No known match builtins");
1986         return false;
1987       case Pattern::K_CodeGenInstruction:
1988         cast<InstructionPattern>(Pat.get())->reportUnreachable(
1989             RuleDef.getLoc());
1990         return false;
1991       case Pattern::K_CXX: {
1992         addCXXPredicate(M, CE, *cast<CXXPattern>(Pat.get()), Alts);
1993         break;
1994       }
1995       default:
1996         llvm_unreachable("unknown pattern kind!");
1997       }
1998     }
1999 
2000     if (!emitApplyPatterns(CE, M))
2001       return false;
2002   }
2003 
2004   return true;
2005 }
2006 
emitPatFragMatchPattern(CodeExpansions & CE,const PatternAlternatives & Alts,RuleMatcher & RM,InstructionMatcher * IM,const PatFragPattern & PFP,DenseSet<const Pattern * > & SeenPats)2007 bool CombineRuleBuilder::emitPatFragMatchPattern(
2008     CodeExpansions &CE, const PatternAlternatives &Alts, RuleMatcher &RM,
2009     InstructionMatcher *IM, const PatFragPattern &PFP,
2010     DenseSet<const Pattern *> &SeenPats) {
2011   auto StackTrace = PrettyStackTraceEmit(RuleDef, &PFP);
2012 
2013   if (SeenPats.contains(&PFP))
2014     return true;
2015   SeenPats.insert(&PFP);
2016 
2017   const auto &PF = PFP.getPatFrag();
2018 
2019   if (!IM) {
2020     // When we don't have an IM, this means this PatFrag isn't reachable from
2021     // the root. This is only acceptable if it doesn't define anything (e.g. a
2022     // pure C++ PatFrag).
2023     if (PF.num_out_params() != 0) {
2024       PFP.reportUnreachable(RuleDef.getLoc());
2025       return false;
2026     }
2027   } else {
2028     // When an IM is provided, this is reachable from the root, and we're
2029     // expecting to have output operands.
2030     // TODO: If we want to allow for multiple roots we'll need a map of IMs
2031     // then, and emission becomes a bit more complicated.
2032     assert(PF.num_roots() == 1);
2033   }
2034 
2035   CodeExpansions PatFragCEs;
2036   if (!PFP.mapInputCodeExpansions(CE, PatFragCEs, RuleDef.getLoc()))
2037     return false;
2038 
2039   // List of {ParamName, ArgName}.
2040   // When all patterns have been emitted, find expansions in PatFragCEs named
2041   // ArgName and add their expansion to CE using ParamName as the key.
2042   SmallVector<std::pair<std::string, std::string>, 4> CEsToImport;
2043 
2044   // Map parameter names to the actual argument.
2045   const auto OperandMapper =
2046       [&](const InstructionOperand &O) -> InstructionOperand {
2047     if (!O.isNamedOperand())
2048       return O;
2049 
2050     StringRef ParamName = O.getOperandName();
2051 
2052     // Not sure what to do with those tbh. They should probably never be here.
2053     assert(!O.isNamedImmediate() && "TODO: handle named imms");
2054     unsigned PIdx = PF.getParamIdx(ParamName);
2055 
2056     // Map parameters to the argument values.
2057     if (PIdx == (unsigned)-1) {
2058       // This is a temp of the PatFragPattern, prefix the name to avoid
2059       // conflicts.
2060       return O.withNewName(
2061           insertStrRef((PFP.getName() + "." + ParamName).str()));
2062     }
2063 
2064     // The operand will be added to PatFragCEs's code expansions using the
2065     // parameter's name. If it's bound to some operand during emission of the
2066     // patterns, we'll want to add it to CE.
2067     auto ArgOp = PFP.getOperand(PIdx);
2068     if (ArgOp.isNamedOperand())
2069       CEsToImport.emplace_back(ArgOp.getOperandName().str(), ParamName);
2070 
2071     if (ArgOp.getType() && O.getType() && ArgOp.getType() != O.getType()) {
2072       StringRef PFName = PF.getName();
2073       PrintWarning("impossible type constraints: operand " + Twine(PIdx) +
2074                    " of '" + PFP.getName() + "' has type '" +
2075                    ArgOp.getType().str() + "', but '" + PFName +
2076                    "' constrains it to '" + O.getType().str() + "'");
2077       if (ArgOp.isNamedOperand())
2078         PrintNote("operand " + Twine(PIdx) + " of '" + PFP.getName() +
2079                   "' is '" + ArgOp.getOperandName() + "'");
2080       if (O.isNamedOperand())
2081         PrintNote("argument " + Twine(PIdx) + " of '" + PFName + "' is '" +
2082                   ParamName + "'");
2083     }
2084 
2085     return ArgOp;
2086   };
2087 
2088   // PatFragPatterns are only made of InstructionPatterns or CXXPatterns.
2089   // Emit instructions from the root.
2090   const auto &FragAlt = PF.getAlternative(Alts.lookup(&PFP));
2091   const auto &FragAltOT = FragAlt.OpTable;
2092   const auto LookupOperandDef =
2093       [&](StringRef Op) -> const InstructionPattern * {
2094     return FragAltOT.getDef(Op);
2095   };
2096 
2097   DenseSet<const Pattern *> PatFragSeenPats;
2098   for (const auto &[Idx, InOp] : enumerate(PF.out_params())) {
2099     if (InOp.Kind != PatFrag::PK_Root)
2100       continue;
2101 
2102     StringRef ParamName = InOp.Name;
2103     const auto *Def = FragAltOT.getDef(ParamName);
2104     assert(Def && "PatFrag::checkSemantics should have emitted an error if "
2105                   "an out operand isn't defined!");
2106     assert(isa<CodeGenInstructionPattern>(Def) &&
2107            "Nested PatFrags not supported yet");
2108 
2109     if (!emitCodeGenInstructionMatchPattern(
2110             PatFragCEs, Alts, RM, *IM, *cast<CodeGenInstructionPattern>(Def),
2111             PatFragSeenPats, LookupOperandDef, OperandMapper))
2112       return false;
2113   }
2114 
2115   // Emit leftovers.
2116   for (const auto &Pat : FragAlt.Pats) {
2117     if (PatFragSeenPats.contains(Pat.get()))
2118       continue;
2119 
2120     if (const auto *CXXPat = dyn_cast<CXXPattern>(Pat.get())) {
2121       addCXXPredicate(RM, PatFragCEs, *CXXPat, Alts);
2122       continue;
2123     }
2124 
2125     if (const auto *IP = dyn_cast<InstructionPattern>(Pat.get())) {
2126       IP->reportUnreachable(PF.getLoc());
2127       return false;
2128     }
2129 
2130     llvm_unreachable("Unexpected pattern kind in PatFrag");
2131   }
2132 
2133   for (const auto &[ParamName, ArgName] : CEsToImport) {
2134     // Note: we're find if ParamName already exists. It just means it's been
2135     // bound before, so we prefer to keep the first binding.
2136     CE.declare(ParamName, PatFragCEs.lookup(ArgName));
2137   }
2138 
2139   return true;
2140 }
2141 
emitApplyPatterns(CodeExpansions & CE,RuleMatcher & M)2142 bool CombineRuleBuilder::emitApplyPatterns(CodeExpansions &CE, RuleMatcher &M) {
2143   if (hasOnlyCXXApplyPatterns()) {
2144     for (auto &Pat : values(ApplyPats))
2145       addCXXAction(M, CE, *cast<CXXPattern>(Pat.get()));
2146     return true;
2147   }
2148 
2149   DenseSet<const Pattern *> SeenPats;
2150   StringMap<unsigned> OperandToTempRegID;
2151 
2152   for (auto *ApplyRoot : ApplyRoots) {
2153     assert(isa<InstructionPattern>(ApplyRoot) &&
2154            "Root can only be a InstructionPattern!");
2155     if (!emitInstructionApplyPattern(CE, M,
2156                                      cast<InstructionPattern>(*ApplyRoot),
2157                                      SeenPats, OperandToTempRegID))
2158       return false;
2159   }
2160 
2161   for (auto &Pat : values(ApplyPats)) {
2162     if (SeenPats.contains(Pat.get()))
2163       continue;
2164 
2165     switch (Pat->getKind()) {
2166     case Pattern::K_AnyOpcode:
2167       llvm_unreachable("Unexpected pattern in apply!");
2168     case Pattern::K_PatFrag:
2169       // TODO: We could support pure C++ PatFrags as a temporary thing.
2170       llvm_unreachable("Unexpected pattern in apply!");
2171     case Pattern::K_Builtin:
2172       if (!emitInstructionApplyPattern(CE, M, cast<BuiltinPattern>(*Pat),
2173                                        SeenPats, OperandToTempRegID))
2174         return false;
2175       break;
2176     case Pattern::K_CodeGenInstruction:
2177       cast<CodeGenInstructionPattern>(*Pat).reportUnreachable(RuleDef.getLoc());
2178       return false;
2179     case Pattern::K_CXX: {
2180       addCXXAction(M, CE, *cast<CXXPattern>(Pat.get()));
2181       continue;
2182     }
2183     default:
2184       llvm_unreachable("unknown pattern kind!");
2185     }
2186   }
2187 
2188   // Erase the root.
2189   unsigned RootInsnID =
2190       M.getInsnVarID(M.getInstructionMatcher(MatchRoot->getName()));
2191   M.addAction<EraseInstAction>(RootInsnID);
2192 
2193   return true;
2194 }
2195 
emitInstructionApplyPattern(CodeExpansions & CE,RuleMatcher & M,const InstructionPattern & P,DenseSet<const Pattern * > & SeenPats,StringMap<unsigned> & OperandToTempRegID)2196 bool CombineRuleBuilder::emitInstructionApplyPattern(
2197     CodeExpansions &CE, RuleMatcher &M, const InstructionPattern &P,
2198     DenseSet<const Pattern *> &SeenPats,
2199     StringMap<unsigned> &OperandToTempRegID) {
2200   auto StackTrace = PrettyStackTraceEmit(RuleDef, &P);
2201 
2202   if (SeenPats.contains(&P))
2203     return true;
2204 
2205   SeenPats.insert(&P);
2206 
2207   // First, render the uses.
2208   for (auto &Op : P.named_operands()) {
2209     if (Op.isDef())
2210       continue;
2211 
2212     StringRef OpName = Op.getOperandName();
2213     if (const auto *DefPat = ApplyOpTable.getDef(OpName)) {
2214       if (!emitInstructionApplyPattern(CE, M, *DefPat, SeenPats,
2215                                        OperandToTempRegID))
2216         return false;
2217     } else {
2218       // If we have no def, check this exists in the MatchRoot.
2219       if (!Op.isNamedImmediate() && !MatchOpTable.lookup(OpName).Found) {
2220         PrintError("invalid output operand '" + OpName +
2221                    "': operand is not a live-in of the match pattern, and it "
2222                    "has no definition");
2223         return false;
2224       }
2225     }
2226   }
2227 
2228   if (const auto *BP = dyn_cast<BuiltinPattern>(&P))
2229     return emitBuiltinApplyPattern(CE, M, *BP, OperandToTempRegID);
2230 
2231   if (isa<PatFragPattern>(&P))
2232     llvm_unreachable("PatFragPatterns is not supported in 'apply'!");
2233 
2234   auto &CGIP = cast<CodeGenInstructionPattern>(P);
2235 
2236   // Now render this inst.
2237   auto &DstMI =
2238       M.addAction<BuildMIAction>(M.allocateOutputInsnID(), &CGIP.getInst());
2239 
2240   for (auto &Op : P.operands()) {
2241     if (Op.isNamedImmediate()) {
2242       PrintError("invalid output operand '" + Op.getOperandName() +
2243                  "': output immediates cannot be named");
2244       PrintNote("while emitting pattern '" + P.getName() + "' (" +
2245                 P.getInstName() + ")");
2246       return false;
2247     }
2248 
2249     if (Op.hasImmValue()) {
2250       if (!emitCodeGenInstructionApplyImmOperand(M, DstMI, CGIP, Op))
2251         return false;
2252       continue;
2253     }
2254 
2255     StringRef OpName = Op.getOperandName();
2256 
2257     // Uses of operand.
2258     if (!Op.isDef()) {
2259       if (auto It = OperandToTempRegID.find(OpName);
2260           It != OperandToTempRegID.end()) {
2261         assert(!MatchOpTable.lookup(OpName).Found &&
2262                "Temp reg is also from match pattern?");
2263         DstMI.addRenderer<TempRegRenderer>(It->second);
2264       } else {
2265         // This should be a match live in or a redef of a matched instr.
2266         // If it's a use of a temporary register, then we messed up somewhere -
2267         // the previous condition should have passed.
2268         assert(MatchOpTable.lookup(OpName).Found &&
2269                !ApplyOpTable.getDef(OpName) && "Temp reg not emitted yet!");
2270         DstMI.addRenderer<CopyRenderer>(OpName);
2271       }
2272       continue;
2273     }
2274 
2275     // Determine what we're dealing with. Are we replace a matched instruction?
2276     // Creating a new one?
2277     auto OpLookupRes = MatchOpTable.lookup(OpName);
2278     if (OpLookupRes.Found) {
2279       if (OpLookupRes.isLiveIn()) {
2280         // live-in of the match pattern.
2281         PrintError("Cannot define live-in operand '" + OpName +
2282                    "' in the 'apply' pattern");
2283         return false;
2284       }
2285       assert(OpLookupRes.Def);
2286 
2287       // TODO: Handle this. We need to mutate the instr, or delete the old
2288       // one.
2289       //       Likewise, we also need to ensure we redef everything, if the
2290       //       instr has more than one def, we need to redef all or nothing.
2291       if (OpLookupRes.Def != MatchRoot) {
2292         PrintError("redefining an instruction other than the root is not "
2293                    "supported (operand '" +
2294                    OpName + "')");
2295         return false;
2296       }
2297       // redef of a match
2298       DstMI.addRenderer<CopyRenderer>(OpName);
2299       continue;
2300     }
2301 
2302     // Define a new register unique to the apply patterns (AKA a "temp"
2303     // register).
2304     unsigned TempRegID;
2305     if (auto It = OperandToTempRegID.find(OpName);
2306         It != OperandToTempRegID.end()) {
2307       TempRegID = It->second;
2308     } else {
2309       // This is a brand new register.
2310       TempRegID = M.allocateTempRegID();
2311       OperandToTempRegID[OpName] = TempRegID;
2312       const auto Ty = Op.getType();
2313       if (!Ty) {
2314         PrintError("def of a new register '" + OpName +
2315                    "' in the apply patterns must have a type");
2316         return false;
2317       }
2318 
2319       declareTempRegExpansion(CE, TempRegID, OpName);
2320       // Always insert the action at the beginning, otherwise we may end up
2321       // using the temp reg before it's available.
2322       M.insertAction<MakeTempRegisterAction>(
2323           M.actions_begin(), getLLTCodeGenOrTempType(Ty, M), TempRegID);
2324     }
2325 
2326     DstMI.addRenderer<TempRegRenderer>(TempRegID, /*IsDef=*/true);
2327   }
2328 
2329   // Render MIFlags
2330   if (const auto *FI = CGIP.getMIFlagsInfo()) {
2331     for (StringRef InstName : FI->copy_flags())
2332       DstMI.addCopiedMIFlags(M.getInstructionMatcher(InstName));
2333     for (StringRef F : FI->set_flags())
2334       DstMI.addSetMIFlags(F);
2335     for (StringRef F : FI->unset_flags())
2336       DstMI.addUnsetMIFlags(F);
2337   }
2338 
2339   // Don't allow mutating opcodes for GISel combiners. We want a more precise
2340   // handling of MIFlags so we require them to be explicitly preserved.
2341   //
2342   // TODO: We don't mutate very often, if at all in combiners, but it'd be nice
2343   // to re-enable this. We'd then need to always clear MIFlags when mutating
2344   // opcodes, and never mutate an inst that we copy flags from.
2345   // DstMI.chooseInsnToMutate(M);
2346   declareInstExpansion(CE, DstMI, P.getName());
2347 
2348   return true;
2349 }
2350 
emitCodeGenInstructionApplyImmOperand(RuleMatcher & M,BuildMIAction & DstMI,const CodeGenInstructionPattern & P,const InstructionOperand & O)2351 bool CombineRuleBuilder::emitCodeGenInstructionApplyImmOperand(
2352     RuleMatcher &M, BuildMIAction &DstMI, const CodeGenInstructionPattern &P,
2353     const InstructionOperand &O) {
2354   // If we have a type, we implicitly emit a G_CONSTANT, except for G_CONSTANT
2355   // itself where we emit a CImm.
2356   //
2357   // No type means we emit a simple imm.
2358   // G_CONSTANT is a special case and needs a CImm though so this is likely a
2359   // mistake.
2360   const bool isGConstant = P.is("G_CONSTANT");
2361   const auto Ty = O.getType();
2362   if (!Ty) {
2363     if (isGConstant) {
2364       PrintError("'G_CONSTANT' immediate must be typed!");
2365       PrintNote("while emitting pattern '" + P.getName() + "' (" +
2366                 P.getInstName() + ")");
2367       return false;
2368     }
2369 
2370     DstMI.addRenderer<ImmRenderer>(O.getImmValue());
2371     return true;
2372   }
2373 
2374   auto ImmTy = getLLTCodeGenOrTempType(Ty, M);
2375 
2376   if (isGConstant) {
2377     DstMI.addRenderer<ImmRenderer>(O.getImmValue(), ImmTy);
2378     return true;
2379   }
2380 
2381   unsigned TempRegID = M.allocateTempRegID();
2382   // Ensure MakeTempReg & the BuildConstantAction occur at the beginning.
2383   auto InsertIt = M.insertAction<MakeTempRegisterAction>(M.actions_begin(),
2384                                                          ImmTy, TempRegID);
2385   M.insertAction<BuildConstantAction>(++InsertIt, TempRegID, O.getImmValue());
2386   DstMI.addRenderer<TempRegRenderer>(TempRegID);
2387   return true;
2388 }
2389 
emitBuiltinApplyPattern(CodeExpansions & CE,RuleMatcher & M,const BuiltinPattern & P,StringMap<unsigned> & OperandToTempRegID)2390 bool CombineRuleBuilder::emitBuiltinApplyPattern(
2391     CodeExpansions &CE, RuleMatcher &M, const BuiltinPattern &P,
2392     StringMap<unsigned> &OperandToTempRegID) {
2393   const auto Error = [&](Twine Reason) {
2394     PrintError("cannot emit '" + P.getInstName() + "' builtin: " + Reason);
2395     return false;
2396   };
2397 
2398   switch (P.getBuiltinKind()) {
2399   case BI_EraseRoot: {
2400     // Root is always inst 0.
2401     M.addAction<EraseInstAction>(/*InsnID*/ 0);
2402     return true;
2403   }
2404   case BI_ReplaceReg: {
2405     StringRef Old = P.getOperand(0).getOperandName();
2406     StringRef New = P.getOperand(1).getOperandName();
2407 
2408     if (!ApplyOpTable.lookup(New).Found && !MatchOpTable.lookup(New).Found)
2409       return Error("unknown operand '" + Old + "'");
2410 
2411     auto &OldOM = M.getOperandMatcher(Old);
2412     if (auto It = OperandToTempRegID.find(New);
2413         It != OperandToTempRegID.end()) {
2414       // Replace with temp reg.
2415       M.addAction<ReplaceRegAction>(OldOM.getInsnVarID(), OldOM.getOpIdx(),
2416                                     It->second);
2417     } else {
2418       // Replace with matched reg.
2419       auto &NewOM = M.getOperandMatcher(New);
2420       M.addAction<ReplaceRegAction>(OldOM.getInsnVarID(), OldOM.getOpIdx(),
2421                                     NewOM.getInsnVarID(), NewOM.getOpIdx());
2422     }
2423     // checkSemantics should have ensured that we can only rewrite the root.
2424     // Ensure we're deleting it.
2425     assert(MatchOpTable.getDef(Old) == MatchRoot);
2426     return true;
2427   }
2428   }
2429 
2430   llvm_unreachable("Unknown BuiltinKind!");
2431 }
2432 
isLiteralImm(const InstructionPattern & P,unsigned OpIdx)2433 bool isLiteralImm(const InstructionPattern &P, unsigned OpIdx) {
2434   if (const auto *CGP = dyn_cast<CodeGenInstructionPattern>(&P)) {
2435     StringRef InstName = CGP->getInst().TheDef->getName();
2436     return (InstName == "G_CONSTANT" || InstName == "G_FCONSTANT") &&
2437            OpIdx == 1;
2438   }
2439 
2440   llvm_unreachable("TODO");
2441 }
2442 
emitCodeGenInstructionMatchPattern(CodeExpansions & CE,const PatternAlternatives & Alts,RuleMatcher & M,InstructionMatcher & IM,const CodeGenInstructionPattern & P,DenseSet<const Pattern * > & SeenPats,OperandDefLookupFn LookupOperandDef,OperandMapperFnRef OperandMapper)2443 bool CombineRuleBuilder::emitCodeGenInstructionMatchPattern(
2444     CodeExpansions &CE, const PatternAlternatives &Alts, RuleMatcher &M,
2445     InstructionMatcher &IM, const CodeGenInstructionPattern &P,
2446     DenseSet<const Pattern *> &SeenPats, OperandDefLookupFn LookupOperandDef,
2447     OperandMapperFnRef OperandMapper) {
2448   auto StackTrace = PrettyStackTraceEmit(RuleDef, &P);
2449 
2450   if (SeenPats.contains(&P))
2451     return true;
2452 
2453   SeenPats.insert(&P);
2454 
2455   IM.addPredicate<InstructionOpcodeMatcher>(&P.getInst());
2456   declareInstExpansion(CE, IM, P.getName());
2457 
2458   // Check flags if needed.
2459   if (const auto *FI = P.getMIFlagsInfo()) {
2460     assert(FI->copy_flags().empty());
2461 
2462     if (const auto &SetF = FI->set_flags(); !SetF.empty())
2463       IM.addPredicate<MIFlagsInstructionPredicateMatcher>(SetF.getArrayRef());
2464     if (const auto &UnsetF = FI->unset_flags(); !UnsetF.empty())
2465       IM.addPredicate<MIFlagsInstructionPredicateMatcher>(UnsetF.getArrayRef(),
2466                                                           /*CheckNot=*/true);
2467   }
2468 
2469   for (const auto &[Idx, OriginalO] : enumerate(P.operands())) {
2470     // Remap the operand. This is used when emitting InstructionPatterns inside
2471     // PatFrags, so it can remap them to the arguments passed to the pattern.
2472     //
2473     // We use the remapped operand to emit immediates, and for the symbolic
2474     // operand names (in IM.addOperand). CodeExpansions and OperandTable lookups
2475     // still use the original name.
2476     //
2477     // The "def" flag on the remapped operand is always ignored.
2478     auto RemappedO = OperandMapper(OriginalO);
2479     assert(RemappedO.isNamedOperand() == OriginalO.isNamedOperand() &&
2480            "Cannot remap an unnamed operand to a named one!");
2481 
2482     const auto OpName =
2483         RemappedO.isNamedOperand() ? RemappedO.getOperandName().str() : "";
2484     OperandMatcher &OM =
2485         IM.addOperand(Idx, OpName, AllocatedTemporariesBaseID++);
2486     if (!OpName.empty())
2487       declareOperandExpansion(CE, OM, OriginalO.getOperandName());
2488 
2489     // Handle immediates.
2490     if (RemappedO.hasImmValue()) {
2491       if (isLiteralImm(P, Idx))
2492         OM.addPredicate<LiteralIntOperandMatcher>(RemappedO.getImmValue());
2493       else
2494         OM.addPredicate<ConstantIntOperandMatcher>(RemappedO.getImmValue());
2495     }
2496 
2497     // Handle typed operands, but only bother to check if it hasn't been done
2498     // before.
2499     //
2500     // getOperandMatcher will always return the first OM to have been created
2501     // for that Operand. "OM" here is always a new OperandMatcher.
2502     //
2503     // Always emit a check for unnamed operands.
2504     if (OpName.empty() ||
2505         !M.getOperandMatcher(OpName).contains<LLTOperandMatcher>()) {
2506       if (const auto Ty = RemappedO.getType()) {
2507         // TODO: We could support GITypeOf here on the condition that the
2508         // OperandMatcher exists already. Though it's clunky to make this work
2509         // and isn't all that useful so it's just rejected in typecheckPatterns
2510         // at this time.
2511         assert(Ty.isLLT() && "Only LLTs are supported in match patterns!");
2512         OM.addPredicate<LLTOperandMatcher>(getLLTCodeGen(Ty));
2513       }
2514     }
2515 
2516     // Stop here if the operand is a def, or if it had no name.
2517     if (OriginalO.isDef() || !OriginalO.isNamedOperand())
2518       continue;
2519 
2520     const auto *DefPat = LookupOperandDef(OriginalO.getOperandName());
2521     if (!DefPat)
2522       continue;
2523 
2524     if (OriginalO.hasImmValue()) {
2525       assert(!OpName.empty());
2526       // This is a named immediate that also has a def, that's not okay.
2527       // e.g.
2528       //    (G_SEXT $y, (i32 0))
2529       //    (COPY $x, 42:$y)
2530       PrintError("'" + OpName +
2531                  "' is a named immediate, it cannot be defined by another "
2532                  "instruction");
2533       PrintNote("'" + OpName + "' is defined by '" + DefPat->getName() + "'");
2534       return false;
2535     }
2536 
2537     // From here we know that the operand defines an instruction, and we need to
2538     // emit it.
2539     auto InstOpM =
2540         OM.addPredicate<InstructionOperandMatcher>(M, DefPat->getName());
2541     if (!InstOpM) {
2542       // TODO: copy-pasted from GlobalISelEmitter.cpp. Is it still relevant
2543       // here?
2544       PrintError("Nested instruction '" + DefPat->getName() +
2545                  "' cannot be the same as another operand '" +
2546                  OriginalO.getOperandName() + "'");
2547       return false;
2548     }
2549 
2550     auto &IM = (*InstOpM)->getInsnMatcher();
2551     if (const auto *CGIDef = dyn_cast<CodeGenInstructionPattern>(DefPat)) {
2552       if (!emitCodeGenInstructionMatchPattern(CE, Alts, M, IM, *CGIDef,
2553                                               SeenPats, LookupOperandDef,
2554                                               OperandMapper))
2555         return false;
2556       continue;
2557     }
2558 
2559     if (const auto *PFPDef = dyn_cast<PatFragPattern>(DefPat)) {
2560       if (!emitPatFragMatchPattern(CE, Alts, M, &IM, *PFPDef, SeenPats))
2561         return false;
2562       continue;
2563     }
2564 
2565     llvm_unreachable("unknown type of InstructionPattern");
2566   }
2567 
2568   return true;
2569 }
2570 
2571 //===- GICombinerEmitter --------------------------------------------------===//
2572 
/// Main implementation class. This emits the tablegenerated output.
///
/// It collects rules, uses `CombineRuleBuilder` to parse them and accumulate
/// RuleMatchers, then takes all the necessary state/data from the various
/// static storage pools and wires them together to emit the match table &
/// associated function/data structures.
class GICombinerEmitter final : public GlobalISelMatchTableExecutorEmitter {
  // Used by run() to start the per-phase timers.
  RecordKeeper &Records;
  // Name of the combiner. It prefixes the generated command-line option
  // variables and, lower-cased, forms the option strings themselves.
  StringRef Name;
  // Target the combiner is emitted for; handed to every CombineRuleBuilder.
  const CodeGenTarget &Target;
  // Record describing the combiner. Its "Classname", "CombineAllMethodName"
  // and "Rules" values are read by this class.
  Record *Combiner;
  // ID that gatherRules() assigns to the next rule it processes.
  unsigned NextRuleID = 0;

  // List all combine rules (ID, name) imported.
  // Note that the combiner rule ID is different from the RuleMatcher ID. The
  // latter is internal to the MatchTable, the former is the canonical ID of the
  // combine rule used to disable/enable it.
  std::vector<std::pair<unsigned, std::string>> AllCombineRules;

  // Keep track of all rules we've seen so far to ensure we don't process
  // the same rule twice.
  StringSet<> RulesSeen;

  // Sorts and optimizes the given rules, then builds the match table from
  // them.
  MatchTable buildMatchTable(MutableArrayRef<RuleMatcher> Rules);

  // Emits the "<Classname>RuleConfig" struct, its member functions and the
  // command-line options used to enable/disable rules.
  void emitRuleConfigImpl(raw_ostream &OS);

  // Emits the definition of the "combine all" method of the generated class.
  void emitAdditionalImpl(raw_ostream &OS) override;

  void emitMIPredicateFns(raw_ostream &OS) override;
  // The three immediate-predicate emitters below are called with an empty
  // predicate list.
  void emitI64ImmPredicateFns(raw_ostream &OS) override;
  void emitAPFloatImmPredicateFns(raw_ostream &OS) override;
  void emitAPIntImmPredicateFns(raw_ostream &OS) override;
  // Emits the enum of "rule enabled" predicates and testSimplePredicate().
  void emitTestSimplePredicate(raw_ostream &OS) override;
  // Emits the enum of custom apply actions and runCustomAction().
  void emitRunCustomAction(raw_ostream &OS) override;

  // Emits the match data struct and the mutable member that holds it.
  void emitAdditionalTemporariesDecl(raw_ostream &OS,
                                     StringRef Indent) override;

  const CodeGenTarget &getTarget() const override { return Target; }
  // Name of the generated class, taken from the combiner's "Classname" value.
  StringRef getClassName() const override {
    return Combiner->getValueAsString("Classname");
  }

  // Name of the generated "combine all" method, taken from the combiner's
  // "CombineAllMethodName" value.
  StringRef getCombineAllMethodName() const {
    return Combiner->getValueAsString("CombineAllMethodName");
  }

  // Name of the generated rule configuration struct: the class name followed
  // by "RuleConfig".
  std::string getRuleConfigClassName() const {
    return getClassName().str() + "RuleConfig";
  }

  // Flattens rules and rule groups into Rules, recursing into groups.
  void gatherRules(std::vector<RuleMatcher> &Rules,
                   const std::vector<Record *> &&RulesAndGroups);

public:
  explicit GICombinerEmitter(RecordKeeper &RK, const CodeGenTarget &Target,
                             StringRef Name, Record *Combiner);
  ~GICombinerEmitter() {}

  // Entry point: gathers the rules and writes the whole combiner to OS.
  void run(raw_ostream &OS);
};
2635 
emitRuleConfigImpl(raw_ostream & OS)2636 void GICombinerEmitter::emitRuleConfigImpl(raw_ostream &OS) {
2637   OS << "struct " << getRuleConfigClassName() << " {\n"
2638      << "  SparseBitVector<> DisabledRules;\n\n"
2639      << "  bool isRuleEnabled(unsigned RuleID) const;\n"
2640      << "  bool parseCommandLineOption();\n"
2641      << "  bool setRuleEnabled(StringRef RuleIdentifier);\n"
2642      << "  bool setRuleDisabled(StringRef RuleIdentifier);\n"
2643      << "};\n\n";
2644 
2645   std::vector<std::pair<std::string, std::string>> Cases;
2646   Cases.reserve(AllCombineRules.size());
2647 
2648   for (const auto &[ID, Name] : AllCombineRules)
2649     Cases.emplace_back(Name, "return " + to_string(ID) + ";\n");
2650 
2651   OS << "static std::optional<uint64_t> getRuleIdxForIdentifier(StringRef "
2652         "RuleIdentifier) {\n"
2653      << "  uint64_t I;\n"
2654      << "  // getAtInteger(...) returns false on success\n"
2655      << "  bool Parsed = !RuleIdentifier.getAsInteger(0, I);\n"
2656      << "  if (Parsed)\n"
2657      << "    return I;\n\n"
2658      << "#ifndef NDEBUG\n";
2659   StringMatcher Matcher("RuleIdentifier", Cases, OS);
2660   Matcher.Emit();
2661   OS << "#endif // ifndef NDEBUG\n\n"
2662      << "  return std::nullopt;\n"
2663      << "}\n";
2664 
2665   OS << "static std::optional<std::pair<uint64_t, uint64_t>> "
2666         "getRuleRangeForIdentifier(StringRef RuleIdentifier) {\n"
2667      << "  std::pair<StringRef, StringRef> RangePair = "
2668         "RuleIdentifier.split('-');\n"
2669      << "  if (!RangePair.second.empty()) {\n"
2670      << "    const auto First = "
2671         "getRuleIdxForIdentifier(RangePair.first);\n"
2672      << "    const auto Last = "
2673         "getRuleIdxForIdentifier(RangePair.second);\n"
2674      << "    if (!First || !Last)\n"
2675      << "      return std::nullopt;\n"
2676      << "    if (First >= Last)\n"
2677      << "      report_fatal_error(\"Beginning of range should be before "
2678         "end of range\");\n"
2679      << "    return {{*First, *Last + 1}};\n"
2680      << "  }\n"
2681      << "  if (RangePair.first == \"*\") {\n"
2682      << "    return {{0, " << AllCombineRules.size() << "}};\n"
2683      << "  }\n"
2684      << "  const auto I = getRuleIdxForIdentifier(RangePair.first);\n"
2685      << "  if (!I)\n"
2686      << "    return std::nullopt;\n"
2687      << "  return {{*I, *I + 1}};\n"
2688      << "}\n\n";
2689 
2690   for (bool Enabled : {true, false}) {
2691     OS << "bool " << getRuleConfigClassName() << "::setRule"
2692        << (Enabled ? "Enabled" : "Disabled") << "(StringRef RuleIdentifier) {\n"
2693        << "  auto MaybeRange = getRuleRangeForIdentifier(RuleIdentifier);\n"
2694        << "  if (!MaybeRange)\n"
2695        << "    return false;\n"
2696        << "  for (auto I = MaybeRange->first; I < MaybeRange->second; ++I)\n"
2697        << "    DisabledRules." << (Enabled ? "reset" : "set") << "(I);\n"
2698        << "  return true;\n"
2699        << "}\n\n";
2700   }
2701 
2702   OS << "static std::vector<std::string> " << Name << "Option;\n"
2703      << "static cl::list<std::string> " << Name << "DisableOption(\n"
2704      << "    \"" << Name.lower() << "-disable-rule\",\n"
2705      << "    cl::desc(\"Disable one or more combiner rules temporarily in "
2706      << "the " << Name << " pass\"),\n"
2707      << "    cl::CommaSeparated,\n"
2708      << "    cl::Hidden,\n"
2709      << "    cl::cat(GICombinerOptionCategory),\n"
2710      << "    cl::callback([](const std::string &Str) {\n"
2711      << "      " << Name << "Option.push_back(Str);\n"
2712      << "    }));\n"
2713      << "static cl::list<std::string> " << Name << "OnlyEnableOption(\n"
2714      << "    \"" << Name.lower() << "-only-enable-rule\",\n"
2715      << "    cl::desc(\"Disable all rules in the " << Name
2716      << " pass then re-enable the specified ones\"),\n"
2717      << "    cl::Hidden,\n"
2718      << "    cl::cat(GICombinerOptionCategory),\n"
2719      << "    cl::callback([](const std::string &CommaSeparatedArg) {\n"
2720      << "      StringRef Str = CommaSeparatedArg;\n"
2721      << "      " << Name << "Option.push_back(\"*\");\n"
2722      << "      do {\n"
2723      << "        auto X = Str.split(\",\");\n"
2724      << "        " << Name << "Option.push_back((\"!\" + X.first).str());\n"
2725      << "        Str = X.second;\n"
2726      << "      } while (!Str.empty());\n"
2727      << "    }));\n"
2728      << "\n\n"
2729      << "bool " << getRuleConfigClassName()
2730      << "::isRuleEnabled(unsigned RuleID) const {\n"
2731      << "    return  !DisabledRules.test(RuleID);\n"
2732      << "}\n"
2733      << "bool " << getRuleConfigClassName() << "::parseCommandLineOption() {\n"
2734      << "  for (StringRef Identifier : " << Name << "Option) {\n"
2735      << "    bool Enabled = Identifier.consume_front(\"!\");\n"
2736      << "    if (Enabled && !setRuleEnabled(Identifier))\n"
2737      << "      return false;\n"
2738      << "    if (!Enabled && !setRuleDisabled(Identifier))\n"
2739      << "      return false;\n"
2740      << "  }\n"
2741      << "  return true;\n"
2742      << "}\n\n";
2743 }
2744 
emitAdditionalImpl(raw_ostream & OS)2745 void GICombinerEmitter::emitAdditionalImpl(raw_ostream &OS) {
2746   OS << "bool " << getClassName() << "::" << getCombineAllMethodName()
2747      << "(MachineInstr &I) const {\n"
2748      << "  const TargetSubtargetInfo &ST = MF.getSubtarget();\n"
2749      << "  const PredicateBitset AvailableFeatures = "
2750         "getAvailableFeatures();\n"
2751      << "  B.setInstrAndDebugLoc(I);\n"
2752      << "  State.MIs.clear();\n"
2753      << "  State.MIs.push_back(&I);\n"
2754      << "  " << MatchDataInfo::StructName << " = "
2755      << MatchDataInfo::StructTypeName << "();\n\n"
2756      << "  if (executeMatchTable(*this, State, ExecInfo, B"
2757      << ", getMatchTable(), *ST.getInstrInfo(), MRI, "
2758         "*MRI.getTargetRegisterInfo(), *ST.getRegBankInfo(), AvailableFeatures"
2759      << ", /*CoverageInfo*/ nullptr)) {\n"
2760      << "    return true;\n"
2761      << "  }\n\n"
2762      << "  return false;\n"
2763      << "}\n\n";
2764 }
2765 
emitMIPredicateFns(raw_ostream & OS)2766 void GICombinerEmitter::emitMIPredicateFns(raw_ostream &OS) {
2767   auto MatchCode = CXXPredicateCode::getAllMatchCode();
2768   emitMIPredicateFnsImpl<const CXXPredicateCode *>(
2769       OS, "", ArrayRef<const CXXPredicateCode *>(MatchCode),
2770       [](const CXXPredicateCode *C) -> StringRef { return C->BaseEnumName; },
2771       [](const CXXPredicateCode *C) -> StringRef { return C->Code; });
2772 }
2773 
emitI64ImmPredicateFns(raw_ostream & OS)2774 void GICombinerEmitter::emitI64ImmPredicateFns(raw_ostream &OS) {
2775   // Unused, but still needs to be called.
2776   emitImmPredicateFnsImpl<unsigned>(
2777       OS, "I64", "int64_t", {}, [](unsigned) { return ""; },
2778       [](unsigned) { return ""; });
2779 }
2780 
emitAPFloatImmPredicateFns(raw_ostream & OS)2781 void GICombinerEmitter::emitAPFloatImmPredicateFns(raw_ostream &OS) {
2782   // Unused, but still needs to be called.
2783   emitImmPredicateFnsImpl<unsigned>(
2784       OS, "APFloat", "const APFloat &", {}, [](unsigned) { return ""; },
2785       [](unsigned) { return ""; });
2786 }
2787 
emitAPIntImmPredicateFns(raw_ostream & OS)2788 void GICombinerEmitter::emitAPIntImmPredicateFns(raw_ostream &OS) {
2789   // Unused, but still needs to be called.
2790   emitImmPredicateFnsImpl<unsigned>(
2791       OS, "APInt", "const APInt &", {}, [](unsigned) { return ""; },
2792       [](unsigned) { return ""; });
2793 }
2794 
emitTestSimplePredicate(raw_ostream & OS)2795 void GICombinerEmitter::emitTestSimplePredicate(raw_ostream &OS) {
2796   if (!AllCombineRules.empty()) {
2797     OS << "enum {\n";
2798     std::string EnumeratorSeparator = " = GICXXPred_Invalid + 1,\n";
2799     // To avoid emitting a switch, we expect that all those rules are in order.
2800     // That way we can just get the RuleID from the enum by subtracting
2801     // (GICXXPred_Invalid + 1).
2802     unsigned ExpectedID = 0;
2803     (void)ExpectedID;
2804     for (const auto &ID : keys(AllCombineRules)) {
2805       assert(ExpectedID++ == ID && "combine rules are not ordered!");
2806       OS << "  " << getIsEnabledPredicateEnumName(ID) << EnumeratorSeparator;
2807       EnumeratorSeparator = ",\n";
2808     }
2809     OS << "};\n\n";
2810   }
2811 
2812   OS << "bool " << getClassName()
2813      << "::testSimplePredicate(unsigned Predicate) const {\n"
2814      << "    return RuleConfig.isRuleEnabled(Predicate - "
2815         "GICXXPred_Invalid - "
2816         "1);\n"
2817      << "}\n";
2818 }
2819 
emitRunCustomAction(raw_ostream & OS)2820 void GICombinerEmitter::emitRunCustomAction(raw_ostream &OS) {
2821   const auto ApplyCode = CXXPredicateCode::getAllApplyCode();
2822 
2823   if (!ApplyCode.empty()) {
2824     OS << "enum {\n";
2825     std::string EnumeratorSeparator = " = GICXXCustomAction_Invalid + 1,\n";
2826     for (const auto &Apply : ApplyCode) {
2827       OS << "  " << Apply->getEnumNameWithPrefix(CXXApplyPrefix)
2828          << EnumeratorSeparator;
2829       EnumeratorSeparator = ",\n";
2830     }
2831     OS << "};\n";
2832   }
2833 
2834   OS << "void " << getClassName()
2835      << "::runCustomAction(unsigned ApplyID, const MatcherState &State, "
2836         "NewMIVector &OutMIs) const "
2837         "{\n";
2838   if (!ApplyCode.empty()) {
2839     OS << "  switch(ApplyID) {\n";
2840     for (const auto &Apply : ApplyCode) {
2841       OS << "  case " << Apply->getEnumNameWithPrefix(CXXApplyPrefix) << ":{\n"
2842          << "    " << join(split(Apply->Code, '\n'), "\n    ") << '\n'
2843          << "    return;\n";
2844       OS << "  }\n";
2845     }
2846     OS << "}\n";
2847   }
2848   OS << "  llvm_unreachable(\"Unknown Apply Action\");\n"
2849      << "}\n";
2850 }
2851 
emitAdditionalTemporariesDecl(raw_ostream & OS,StringRef Indent)2852 void GICombinerEmitter::emitAdditionalTemporariesDecl(raw_ostream &OS,
2853                                                       StringRef Indent) {
2854   OS << Indent << "struct " << MatchDataInfo::StructTypeName << " {\n";
2855   for (const auto &[Type, VarNames] : AllMatchDataVars) {
2856     assert(!VarNames.empty() && "Cannot have no vars for this type!");
2857     OS << Indent << "  " << Type << " " << join(VarNames, ", ") << ";\n";
2858   }
2859   OS << Indent << "};\n"
2860      << Indent << "mutable " << MatchDataInfo::StructTypeName << " "
2861      << MatchDataInfo::StructName << ";\n\n";
2862 }
2863 
/// Stores the given inputs in the corresponding members. The constructor body
/// is empty, so nothing else is done at construction time.
///
/// NOTE(review): \p Name is kept as a StringRef, so presumably the string it
/// refers to has to outlive this object - confirm at the call site.
GICombinerEmitter::GICombinerEmitter(RecordKeeper &RK,
                                     const CodeGenTarget &Target,
                                     StringRef Name, Record *Combiner)
    : Records(RK), Name(Name), Target(Target), Combiner(Combiner) {}
2868 
2869 MatchTable
buildMatchTable(MutableArrayRef<RuleMatcher> Rules)2870 GICombinerEmitter::buildMatchTable(MutableArrayRef<RuleMatcher> Rules) {
2871   std::vector<Matcher *> InputRules;
2872   for (Matcher &Rule : Rules)
2873     InputRules.push_back(&Rule);
2874 
2875   unsigned CurrentOrdering = 0;
2876   StringMap<unsigned> OpcodeOrder;
2877   for (RuleMatcher &Rule : Rules) {
2878     const StringRef Opcode = Rule.getOpcode();
2879     assert(!Opcode.empty() && "Didn't expect an undefined opcode");
2880     if (OpcodeOrder.count(Opcode) == 0)
2881       OpcodeOrder[Opcode] = CurrentOrdering++;
2882   }
2883 
2884   llvm::stable_sort(InputRules, [&OpcodeOrder](const Matcher *A,
2885                                                const Matcher *B) {
2886     auto *L = static_cast<const RuleMatcher *>(A);
2887     auto *R = static_cast<const RuleMatcher *>(B);
2888     return std::make_tuple(OpcodeOrder[L->getOpcode()], L->getNumOperands()) <
2889            std::make_tuple(OpcodeOrder[R->getOpcode()], R->getNumOperands());
2890   });
2891 
2892   for (Matcher *Rule : InputRules)
2893     Rule->optimize();
2894 
2895   std::vector<std::unique_ptr<Matcher>> MatcherStorage;
2896   std::vector<Matcher *> OptRules =
2897       optimizeRules<GroupMatcher>(InputRules, MatcherStorage);
2898 
2899   for (Matcher *Rule : OptRules)
2900     Rule->optimize();
2901 
2902   OptRules = optimizeRules<SwitchMatcher>(OptRules, MatcherStorage);
2903 
2904   return MatchTable::buildTable(OptRules, /*WithCoverage*/ false,
2905                                 /*IsCombiner*/ true);
2906 }
2907 
2908 /// Recurse into GICombineGroup's and flatten the ruleset into a simple list.
gatherRules(std::vector<RuleMatcher> & ActiveRules,const std::vector<Record * > && RulesAndGroups)2909 void GICombinerEmitter::gatherRules(
2910     std::vector<RuleMatcher> &ActiveRules,
2911     const std::vector<Record *> &&RulesAndGroups) {
2912   for (Record *Rec : RulesAndGroups) {
2913     if (!Rec->isValueUnset("Rules")) {
2914       gatherRules(ActiveRules, Rec->getValueAsListOfDefs("Rules"));
2915       continue;
2916     }
2917 
2918     StringRef RuleName = Rec->getName();
2919     if (!RulesSeen.insert(RuleName).second) {
2920       PrintWarning(Rec->getLoc(),
2921                    "skipping rule '" + Rec->getName() +
2922                        "' because it has already been processed");
2923       continue;
2924     }
2925 
2926     AllCombineRules.emplace_back(NextRuleID, Rec->getName().str());
2927     CombineRuleBuilder CRB(Target, SubtargetFeatures, *Rec, NextRuleID++,
2928                            ActiveRules);
2929 
2930     if (!CRB.parseAll()) {
2931       assert(ErrorsPrinted && "Parsing failed without errors!");
2932       continue;
2933     }
2934 
2935     if (StopAfterParse) {
2936       CRB.print(outs());
2937       continue;
2938     }
2939 
2940     if (!CRB.emitRuleMatchers()) {
2941       assert(ErrorsPrinted && "Emission failed without errors!");
2942       continue;
2943     }
2944   }
2945 }
2946 
run(raw_ostream & OS)2947 void GICombinerEmitter::run(raw_ostream &OS) {
2948   InstructionOpcodeMatcher::initOpcodeValuesMap(Target);
2949   LLTOperandMatcher::initTypeIDValuesMap();
2950 
2951   Records.startTimer("Gather rules");
2952   std::vector<RuleMatcher> Rules;
2953   gatherRules(Rules, Combiner->getValueAsListOfDefs("Rules"));
2954   if (ErrorsPrinted)
2955     PrintFatalError(Combiner->getLoc(), "Failed to parse one or more rules");
2956 
2957   if (StopAfterParse)
2958     return;
2959 
2960   Records.startTimer("Creating Match Table");
2961   unsigned MaxTemporaries = 0;
2962   for (const auto &Rule : Rules)
2963     MaxTemporaries = std::max(MaxTemporaries, Rule.countRendererFns());
2964 
2965   llvm::stable_sort(Rules, [&](const RuleMatcher &A, const RuleMatcher &B) {
2966     if (A.isHigherPriorityThan(B)) {
2967       assert(!B.isHigherPriorityThan(A) && "Cannot be more important "
2968                                            "and less important at "
2969                                            "the same time");
2970       return true;
2971     }
2972     return false;
2973   });
2974 
2975   const MatchTable Table = buildMatchTable(Rules);
2976 
2977   Records.startTimer("Emit combiner");
2978 
2979   emitSourceFileHeader(getClassName().str() + " Combiner Match Table", OS);
2980 
2981   // Unused
2982   std::vector<StringRef> CustomRendererFns;
2983   // Unused
2984   std::vector<Record *> ComplexPredicates;
2985 
2986   SmallVector<LLTCodeGen, 16> TypeObjects;
2987   append_range(TypeObjects, KnownTypes);
2988   llvm::sort(TypeObjects);
2989 
2990   // Hack: Avoid empty declarator.
2991   if (TypeObjects.empty())
2992     TypeObjects.push_back(LLT::scalar(1));
2993 
2994   // GET_GICOMBINER_DEPS, which pulls in extra dependencies.
2995   OS << "#ifdef GET_GICOMBINER_DEPS\n"
2996      << "#include \"llvm/ADT/SparseBitVector.h\"\n"
2997      << "namespace llvm {\n"
2998      << "extern cl::OptionCategory GICombinerOptionCategory;\n"
2999      << "} // end namespace llvm\n"
3000      << "#endif // ifdef GET_GICOMBINER_DEPS\n\n";
3001 
3002   // GET_GICOMBINER_TYPES, which needs to be included before the declaration of
3003   // the class.
3004   OS << "#ifdef GET_GICOMBINER_TYPES\n";
3005   emitRuleConfigImpl(OS);
3006   OS << "#endif // ifdef GET_GICOMBINER_TYPES\n\n";
3007   emitPredicateBitset(OS, "GET_GICOMBINER_TYPES");
3008 
3009   // GET_GICOMBINER_CLASS_MEMBERS, which need to be included inside the class.
3010   emitPredicatesDecl(OS, "GET_GICOMBINER_CLASS_MEMBERS");
3011   emitTemporariesDecl(OS, "GET_GICOMBINER_CLASS_MEMBERS");
3012 
3013   // GET_GICOMBINER_IMPL, which needs to be included outside the class.
3014   emitExecutorImpl(OS, Table, TypeObjects, Rules, ComplexPredicates,
3015                    CustomRendererFns, "GET_GICOMBINER_IMPL");
3016 
3017   // GET_GICOMBINER_CONSTRUCTOR_INITS, which are in the constructor's
3018   // initializer list.
3019   emitPredicatesInit(OS, "GET_GICOMBINER_CONSTRUCTOR_INITS");
3020   emitTemporariesInit(OS, MaxTemporaries, "GET_GICOMBINER_CONSTRUCTOR_INITS");
3021 }
3022 
3023 } // end anonymous namespace
3024 
3025 //===----------------------------------------------------------------------===//
3026 
EmitGICombiner(RecordKeeper & RK,raw_ostream & OS)3027 static void EmitGICombiner(RecordKeeper &RK, raw_ostream &OS) {
3028   EnablePrettyStackTrace();
3029   CodeGenTarget Target(RK);
3030 
3031   if (SelectedCombiners.empty())
3032     PrintFatalError("No combiners selected with -combiners");
3033   for (const auto &Combiner : SelectedCombiners) {
3034     Record *CombinerDef = RK.getDef(Combiner);
3035     if (!CombinerDef)
3036       PrintFatalError("Could not find " + Combiner);
3037     GICombinerEmitter(RK, Target, Combiner, CombinerDef).run(OS);
3038   }
3039 }
3040 
// Registers EmitGICombiner with TableGen under the "gen-global-isel-combiner"
// option name, with the description given below.
static TableGen::Emitter::Opt X("gen-global-isel-combiner", EmitGICombiner,
                                "Generate GlobalISel Combiner");
3043