1 //===- Patterns.cpp --------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "Patterns.h"
10 #include "../CodeGenInstruction.h"
11 #include "CXXPredicates.h"
12 #include "CodeExpander.h"
13 #include "CodeExpansions.h"
14 #include "llvm/ADT/StringSet.h"
15 #include "llvm/Support/Debug.h"
16 #include "llvm/Support/raw_ostream.h"
17 #include "llvm/TableGen/Error.h"
18 #include "llvm/TableGen/Record.h"
19 
20 namespace llvm {
21 namespace gi {
22 
23 //===- PatternType --------------------------------------------------------===//
24 
get(ArrayRef<SMLoc> DiagLoc,const Record * R,Twine DiagCtx)25 std::optional<PatternType> PatternType::get(ArrayRef<SMLoc> DiagLoc,
26                                             const Record *R, Twine DiagCtx) {
27   assert(R);
28   if (R->isSubClassOf("ValueType")) {
29     PatternType PT(PT_ValueType);
30     PT.Data.Def = R;
31     return PT;
32   }
33 
34   if (R->isSubClassOf(TypeOfClassName)) {
35     auto RawOpName = R->getValueAsString("OpName");
36     if (!RawOpName.starts_with("$")) {
37       PrintError(DiagLoc, DiagCtx + ": invalid operand name format '" +
38                               RawOpName + "' in " + TypeOfClassName +
39                               ": expected '$' followed by an operand name");
40       return std::nullopt;
41     }
42 
43     PatternType PT(PT_TypeOf);
44     PT.Data.Str = RawOpName.drop_front(1);
45     return PT;
46   }
47 
48   PrintError(DiagLoc, DiagCtx + ": unknown type '" + R->getName() + "'");
49   return std::nullopt;
50 }
51 
getTypeOf(StringRef OpName)52 PatternType PatternType::getTypeOf(StringRef OpName) {
53   PatternType PT(PT_TypeOf);
54   PT.Data.Str = OpName;
55   return PT;
56 }
57 
getTypeOfOpName() const58 StringRef PatternType::getTypeOfOpName() const {
59   assert(isTypeOf());
60   return Data.Str;
61 }
62 
getLLTRecord() const63 const Record *PatternType::getLLTRecord() const {
64   assert(isLLT());
65   return Data.Def;
66 }
67 
operator ==(const PatternType & Other) const68 bool PatternType::operator==(const PatternType &Other) const {
69   if (Kind != Other.Kind)
70     return false;
71 
72   switch (Kind) {
73   case PT_None:
74     return true;
75   case PT_ValueType:
76     return Data.Def == Other.Data.Def;
77   case PT_TypeOf:
78     return Data.Str == Other.Data.Str;
79   }
80 
81   llvm_unreachable("Unknown Type Kind");
82 }
83 
str() const84 std::string PatternType::str() const {
85   switch (Kind) {
86   case PT_None:
87     return "";
88   case PT_ValueType:
89     return Data.Def->getName().str();
90   case PT_TypeOf:
91     return (TypeOfClassName + "<$" + getTypeOfOpName() + ">").str();
92   }
93 
94   llvm_unreachable("Unknown type!");
95 }
96 
97 //===- Pattern ------------------------------------------------------------===//
98 
dump() const99 void Pattern::dump() const { return print(dbgs()); }
100 
getKindName() const101 const char *Pattern::getKindName() const {
102   switch (Kind) {
103   case K_AnyOpcode:
104     return "AnyOpcodePattern";
105   case K_CXX:
106     return "CXXPattern";
107   case K_CodeGenInstruction:
108     return "CodeGenInstructionPattern";
109   case K_PatFrag:
110     return "PatFragPattern";
111   case K_Builtin:
112     return "BuiltinPattern";
113   }
114 
115   llvm_unreachable("unknown pattern kind!");
116 }
117 
printImpl(raw_ostream & OS,bool PrintName,function_ref<void ()> ContentPrinter) const118 void Pattern::printImpl(raw_ostream &OS, bool PrintName,
119                         function_ref<void()> ContentPrinter) const {
120   OS << "(" << getKindName() << " ";
121   if (PrintName)
122     OS << "name:" << getName() << " ";
123   ContentPrinter();
124   OS << ")";
125 }
126 
127 //===- AnyOpcodePattern ---------------------------------------------------===//
128 
print(raw_ostream & OS,bool PrintName) const129 void AnyOpcodePattern::print(raw_ostream &OS, bool PrintName) const {
130   printImpl(OS, PrintName, [&OS, this]() {
131     OS << "["
132        << join(map_range(Insts,
133                          [](const auto *I) { return I->TheDef->getName(); }),
134                ", ")
135        << "]";
136   });
137 }
138 
139 //===- CXXPattern ---------------------------------------------------------===//
140 
CXXPattern(const StringInit & Code,StringRef Name)141 CXXPattern::CXXPattern(const StringInit &Code, StringRef Name)
142     : CXXPattern(Code.getAsUnquotedString(), Name) {}
143 
144 const CXXPredicateCode &
expandCode(const CodeExpansions & CE,ArrayRef<SMLoc> Locs,function_ref<void (raw_ostream &)> AddComment) const145 CXXPattern::expandCode(const CodeExpansions &CE, ArrayRef<SMLoc> Locs,
146                        function_ref<void(raw_ostream &)> AddComment) const {
147   std::string Result;
148   raw_string_ostream OS(Result);
149 
150   if (AddComment)
151     AddComment(OS);
152 
153   CodeExpander Expander(RawCode, CE, Locs, /*ShowExpansions*/ false);
154   Expander.emit(OS);
155   if (IsApply)
156     return CXXPredicateCode::getApplyCode(std::move(Result));
157   return CXXPredicateCode::getMatchCode(std::move(Result));
158 }
159 
print(raw_ostream & OS,bool PrintName) const160 void CXXPattern::print(raw_ostream &OS, bool PrintName) const {
161   printImpl(OS, PrintName, [&OS, this] {
162     OS << (IsApply ? "apply" : "match") << " code:\"";
163     printEscapedString(getRawCode(), OS);
164     OS << "\"";
165   });
166 }
167 
168 //===- InstructionOperand -------------------------------------------------===//
169 
describe() const170 std::string InstructionOperand::describe() const {
171   if (!hasImmValue())
172     return "MachineOperand $" + getOperandName().str() + "";
173   std::string Str = "imm " + std::to_string(getImmValue());
174   if (isNamedImmediate())
175     Str += ":$" + getOperandName().str() + "";
176   return Str;
177 }
178 
print(raw_ostream & OS) const179 void InstructionOperand::print(raw_ostream &OS) const {
180   if (isDef())
181     OS << "<def>";
182 
183   bool NeedsColon = true;
184   if (Type) {
185     if (hasImmValue())
186       OS << "(" << Type.str() << " " << getImmValue() << ")";
187     else
188       OS << Type.str();
189   } else if (hasImmValue())
190     OS << getImmValue();
191   else
192     NeedsColon = false;
193 
194   if (isNamedOperand())
195     OS << (NeedsColon ? ":" : "") << "$" << getOperandName();
196 }
197 
dump() const198 void InstructionOperand::dump() const { return print(dbgs()); }
199 
200 //===- InstructionPattern -------------------------------------------------===//
201 
diagnoseAllSpecialTypes(ArrayRef<SMLoc> Loc,Twine Msg) const202 bool InstructionPattern::diagnoseAllSpecialTypes(ArrayRef<SMLoc> Loc,
203                                                  Twine Msg) const {
204   bool HasDiag = false;
205   for (const auto &[Idx, Op] : enumerate(operands())) {
206     if (Op.getType().isSpecial()) {
207       PrintError(Loc, Msg);
208       PrintNote(Loc, "operand " + Twine(Idx) + " of '" + getName() +
209                          "' has type '" + Op.getType().str() + "'");
210       HasDiag = true;
211     }
212   }
213   return HasDiag;
214 }
215 
reportUnreachable(ArrayRef<SMLoc> Locs) const216 void InstructionPattern::reportUnreachable(ArrayRef<SMLoc> Locs) const {
217   PrintError(Locs, "pattern '" + getName() + "' ('" + getInstName() +
218                        "') is unreachable from the pattern root!");
219 }
220 
checkSemantics(ArrayRef<SMLoc> Loc)221 bool InstructionPattern::checkSemantics(ArrayRef<SMLoc> Loc) {
222   unsigned NumExpectedOperands = getNumInstOperands();
223 
224   if (isVariadic()) {
225     if (Operands.size() < NumExpectedOperands) {
226       PrintError(Loc, +"'" + getInstName() + "' expected at least " +
227                           Twine(NumExpectedOperands) + " operands, got " +
228                           Twine(Operands.size()));
229       return false;
230     }
231   } else if (NumExpectedOperands != Operands.size()) {
232     PrintError(Loc, +"'" + getInstName() + "' expected " +
233                         Twine(NumExpectedOperands) + " operands, got " +
234                         Twine(Operands.size()));
235     return false;
236   }
237 
238   unsigned OpIdx = 0;
239   unsigned NumDefs = getNumInstDefs();
240   for (auto &Op : Operands)
241     Op.setIsDef(OpIdx++ < NumDefs);
242 
243   return true;
244 }
245 
print(raw_ostream & OS,bool PrintName) const246 void InstructionPattern::print(raw_ostream &OS, bool PrintName) const {
247   printImpl(OS, PrintName, [&OS, this] {
248     OS << getInstName() << " operands:[";
249     StringRef Sep;
250     for (const auto &Op : Operands) {
251       OS << Sep;
252       Op.print(OS);
253       Sep = ", ";
254     }
255     OS << "]";
256 
257     printExtras(OS);
258   });
259 }
260 
261 //===- OperandTable -------------------------------------------------------===//
262 
addPattern(InstructionPattern * P,function_ref<void (StringRef)> DiagnoseRedef)263 bool OperandTable::addPattern(InstructionPattern *P,
264                               function_ref<void(StringRef)> DiagnoseRedef) {
265   for (const auto &Op : P->named_operands()) {
266     StringRef OpName = Op.getOperandName();
267 
268     // We always create an entry in the OperandTable, even for uses.
269     // Uses of operands that don't have a def (= live-ins) will remain with a
270     // nullptr as the Def.
271     //
272     // This allows us tell whether an operand exists in a pattern or not. If
273     // there is no entry for it, it doesn't exist, if there is an entry, it's
274     // used/def'd at least once.
275     auto &Def = Table[OpName];
276 
277     if (!Op.isDef())
278       continue;
279 
280     if (Def) {
281       DiagnoseRedef(OpName);
282       return false;
283     }
284 
285     Def = P;
286   }
287 
288   return true;
289 }
290 
print(raw_ostream & OS,StringRef Name,StringRef Indent) const291 void OperandTable::print(raw_ostream &OS, StringRef Name,
292                          StringRef Indent) const {
293   OS << Indent << "(OperandTable ";
294   if (!Name.empty())
295     OS << Name << " ";
296   if (Table.empty()) {
297     OS << "<empty>)\n";
298     return;
299   }
300 
301   SmallVector<StringRef, 0> Keys(Table.keys());
302   sort(Keys);
303 
304   OS << '\n';
305   for (const auto &Key : Keys) {
306     const auto *Def = Table.at(Key);
307     OS << Indent << "  " << Key << " -> "
308        << (Def ? Def->getName() : "<live-in>") << '\n';
309   }
310   OS << Indent << ")\n";
311 }
312 
dump() const313 void OperandTable::dump() const { print(dbgs()); }
314 
315 //===- MIFlagsInfo --------------------------------------------------------===//
316 
addSetFlag(const Record * R)317 void MIFlagsInfo::addSetFlag(const Record *R) {
318   SetF.insert(R->getValueAsString("EnumName"));
319 }
320 
addUnsetFlag(const Record * R)321 void MIFlagsInfo::addUnsetFlag(const Record *R) {
322   UnsetF.insert(R->getValueAsString("EnumName"));
323 }
324 
addCopyFlag(StringRef InstName)325 void MIFlagsInfo::addCopyFlag(StringRef InstName) { CopyF.insert(InstName); }
326 
327 //===- CodeGenInstructionPattern ------------------------------------------===//
328 
is(StringRef OpcodeName) const329 bool CodeGenInstructionPattern::is(StringRef OpcodeName) const {
330   return I.TheDef->getName() == OpcodeName;
331 }
332 
isVariadic() const333 bool CodeGenInstructionPattern::isVariadic() const {
334   return I.Operands.isVariadic;
335 }
336 
hasVariadicDefs() const337 bool CodeGenInstructionPattern::hasVariadicDefs() const {
338   // Note: we cannot use variadicOpsAreDefs, it's not set for
339   // GenericInstructions.
340   if (!isVariadic())
341     return false;
342 
343   if (I.variadicOpsAreDefs)
344     return true;
345 
346   DagInit *OutOps = I.TheDef->getValueAsDag("OutOperandList");
347   if (OutOps->arg_empty())
348     return false;
349 
350   auto *LastArgTy = dyn_cast<DefInit>(OutOps->getArg(OutOps->arg_size() - 1));
351   return LastArgTy && LastArgTy->getDef()->getName() == "variable_ops";
352 }
353 
getNumInstDefs() const354 unsigned CodeGenInstructionPattern::getNumInstDefs() const {
355   if (!isVariadic() || !hasVariadicDefs())
356     return I.Operands.NumDefs;
357   unsigned NumOuts = I.Operands.size() - I.Operands.NumDefs;
358   assert(Operands.size() > NumOuts);
359   return std::max<unsigned>(I.Operands.NumDefs, Operands.size() - NumOuts);
360 }
361 
getNumInstOperands() const362 unsigned CodeGenInstructionPattern::getNumInstOperands() const {
363   unsigned NumCGIOps = I.Operands.size();
364   return isVariadic() ? std::max<unsigned>(NumCGIOps, Operands.size())
365                       : NumCGIOps;
366 }
367 
getOrCreateMIFlagsInfo()368 MIFlagsInfo &CodeGenInstructionPattern::getOrCreateMIFlagsInfo() {
369   if (!FI)
370     FI = std::make_unique<MIFlagsInfo>();
371   return *FI;
372 }
373 
getInstName() const374 StringRef CodeGenInstructionPattern::getInstName() const {
375   return I.TheDef->getName();
376 }
377 
printExtras(raw_ostream & OS) const378 void CodeGenInstructionPattern::printExtras(raw_ostream &OS) const {
379   if (!FI)
380     return;
381 
382   OS << " (MIFlags";
383   if (!FI->set_flags().empty())
384     OS << " (set " << join(FI->set_flags(), ", ") << ")";
385   if (!FI->unset_flags().empty())
386     OS << " (unset " << join(FI->unset_flags(), ", ") << ")";
387   if (!FI->copy_flags().empty())
388     OS << " (copy " << join(FI->copy_flags(), ", ") << ")";
389   OS << ')';
390 }
391 
392 //===- OperandTypeChecker -------------------------------------------------===//
393 
check(InstructionPattern & P,std::function<bool (const PatternType &)> VerifyTypeOfOperand)394 bool OperandTypeChecker::check(
395     InstructionPattern &P,
396     std::function<bool(const PatternType &)> VerifyTypeOfOperand) {
397   Pats.push_back(&P);
398 
399   for (auto &Op : P.operands()) {
400     const auto Ty = Op.getType();
401     if (!Ty)
402       continue;
403 
404     if (Ty.isTypeOf() && !VerifyTypeOfOperand(Ty))
405       return false;
406 
407     if (!Op.isNamedOperand())
408       continue;
409 
410     StringRef OpName = Op.getOperandName();
411     auto &Info = Types[OpName];
412     if (!Info.Type) {
413       Info.Type = Ty;
414       Info.PrintTypeSrcNote = [this, OpName, Ty, &P]() {
415         PrintSeenWithTypeIn(P, OpName, Ty);
416       };
417       continue;
418     }
419 
420     if (Info.Type != Ty) {
421       PrintError(DiagLoc, "conflicting types for operand '" +
422                               Op.getOperandName() + "': '" + Info.Type.str() +
423                               "' vs '" + Ty.str() + "'");
424       PrintSeenWithTypeIn(P, OpName, Ty);
425       Info.PrintTypeSrcNote();
426       return false;
427     }
428   }
429 
430   return true;
431 }
432 
propagateTypes()433 void OperandTypeChecker::propagateTypes() {
434   for (auto *Pat : Pats) {
435     for (auto &Op : Pat->named_operands()) {
436       if (auto &Info = Types[Op.getOperandName()]; Info.Type)
437         Op.setType(Info.Type);
438     }
439   }
440 }
441 
PrintSeenWithTypeIn(InstructionPattern & P,StringRef OpName,PatternType Ty) const442 void OperandTypeChecker::PrintSeenWithTypeIn(InstructionPattern &P,
443                                              StringRef OpName,
444                                              PatternType Ty) const {
445   PrintNote(DiagLoc, "'" + OpName + "' seen with type '" + Ty.str() + "' in '" +
446                          P.getName() + "'");
447 }
448 
getParamKindStr(ParamKind OK)449 StringRef PatFrag::getParamKindStr(ParamKind OK) {
450   switch (OK) {
451   case PK_Root:
452     return "root";
453   case PK_MachineOperand:
454     return "machine_operand";
455   case PK_Imm:
456     return "imm";
457   }
458 
459   llvm_unreachable("Unknown operand kind!");
460 }
461 
462 //===- PatFrag -----------------------------------------------------------===//
463 
PatFrag(const Record & Def)464 PatFrag::PatFrag(const Record &Def) : Def(Def) {
465   assert(Def.isSubClassOf(ClassName));
466 }
467 
getName() const468 StringRef PatFrag::getName() const { return Def.getName(); }
469 
getLoc() const470 ArrayRef<SMLoc> PatFrag::getLoc() const { return Def.getLoc(); }
471 
addInParam(StringRef Name,ParamKind Kind)472 void PatFrag::addInParam(StringRef Name, ParamKind Kind) {
473   Params.emplace_back(Param{Name, Kind});
474 }
475 
in_params() const476 iterator_range<PatFrag::ParamIt> PatFrag::in_params() const {
477   return {Params.begin() + NumOutParams, Params.end()};
478 }
479 
addOutParam(StringRef Name,ParamKind Kind)480 void PatFrag::addOutParam(StringRef Name, ParamKind Kind) {
481   assert(NumOutParams == Params.size() &&
482          "Adding out-param after an in-param!");
483   Params.emplace_back(Param{Name, Kind});
484   ++NumOutParams;
485 }
486 
out_params() const487 iterator_range<PatFrag::ParamIt> PatFrag::out_params() const {
488   return {Params.begin(), Params.begin() + NumOutParams};
489 }
490 
num_roots() const491 unsigned PatFrag::num_roots() const {
492   return count_if(out_params(),
493                   [&](const auto &P) { return P.Kind == PK_Root; });
494 }
495 
getParamIdx(StringRef Name) const496 unsigned PatFrag::getParamIdx(StringRef Name) const {
497   for (const auto &[Idx, Op] : enumerate(Params)) {
498     if (Op.Name == Name)
499       return Idx;
500   }
501 
502   return -1;
503 }
504 
checkSemantics()505 bool PatFrag::checkSemantics() {
506   for (const auto &Alt : Alts) {
507     for (const auto &Pat : Alt.Pats) {
508       switch (Pat->getKind()) {
509       case Pattern::K_AnyOpcode:
510         PrintError("wip_match_opcode cannot be used in " + ClassName);
511         return false;
512       case Pattern::K_Builtin:
513         PrintError("Builtin instructions cannot be used in " + ClassName);
514         return false;
515       case Pattern::K_CXX:
516         continue;
517       case Pattern::K_CodeGenInstruction:
518         if (cast<CodeGenInstructionPattern>(Pat.get())->diagnoseAllSpecialTypes(
519                 Def.getLoc(), PatternType::SpecialTyClassName +
520                                   " is not supported in " + ClassName))
521           return false;
522         continue;
523       case Pattern::K_PatFrag:
524         // TODO: It's just that the emitter doesn't handle it but technically
525         // there is no reason why we can't. We just have to be careful with
526         // operand mappings, it could get complex.
527         PrintError("nested " + ClassName + " are not supported");
528         return false;
529       }
530     }
531   }
532 
533   StringSet<> SeenOps;
534   for (const auto &Op : in_params()) {
535     if (SeenOps.count(Op.Name)) {
536       PrintError("duplicate parameter '" + Op.Name + "'");
537       return false;
538     }
539 
540     // Check this operand is NOT defined in any alternative's patterns.
541     for (const auto &Alt : Alts) {
542       if (Alt.OpTable.lookup(Op.Name).Def) {
543         PrintError("input parameter '" + Op.Name + "' cannot be redefined!");
544         return false;
545       }
546     }
547 
548     if (Op.Kind == PK_Root) {
549       PrintError("input parameterr '" + Op.Name + "' cannot be a root!");
550       return false;
551     }
552 
553     SeenOps.insert(Op.Name);
554   }
555 
556   for (const auto &Op : out_params()) {
557     if (Op.Kind != PK_Root && Op.Kind != PK_MachineOperand) {
558       PrintError("output parameter '" + Op.Name +
559                  "' must be 'root' or 'gi_mo'");
560       return false;
561     }
562 
563     if (SeenOps.count(Op.Name)) {
564       PrintError("duplicate parameter '" + Op.Name + "'");
565       return false;
566     }
567 
568     // Check this operand is defined in all alternative's patterns.
569     for (const auto &Alt : Alts) {
570       const auto *OpDef = Alt.OpTable.getDef(Op.Name);
571       if (!OpDef) {
572         PrintError("output parameter '" + Op.Name +
573                    "' must be defined by all alternative patterns in '" +
574                    Def.getName() + "'");
575         return false;
576       }
577 
578       if (Op.Kind == PK_Root && OpDef->getNumInstDefs() != 1) {
579         // The instruction that defines the root must have a single def.
580         // Otherwise we'd need to support multiple roots and it gets messy.
581         //
582         // e.g. this is not supported:
583         //   (pattern (G_UNMERGE_VALUES $x, $root, $vec))
584         PrintError("all instructions that define root '" + Op.Name + "' in '" +
585                    Def.getName() + "' can only have a single output operand");
586         return false;
587       }
588     }
589 
590     SeenOps.insert(Op.Name);
591   }
592 
593   if (num_out_params() != 0 && num_roots() == 0) {
594     PrintError(ClassName + " must have one root in its 'out' operands");
595     return false;
596   }
597 
598   if (num_roots() > 1) {
599     PrintError(ClassName + " can only have one root");
600     return false;
601   }
602 
603   // TODO: find unused params
604 
605   const auto CheckTypeOf = [&](const PatternType &) -> bool {
606     llvm_unreachable("GITypeOf should have been rejected earlier!");
607   };
608 
609   // Now, typecheck all alternatives.
610   for (auto &Alt : Alts) {
611     OperandTypeChecker OTC(Def.getLoc());
612     for (auto &Pat : Alt.Pats) {
613       if (auto *IP = dyn_cast<InstructionPattern>(Pat.get())) {
614         if (!OTC.check(*IP, CheckTypeOf))
615           return false;
616       }
617     }
618     OTC.propagateTypes();
619   }
620 
621   return true;
622 }
623 
handleUnboundInParam(StringRef ParamName,StringRef ArgName,ArrayRef<SMLoc> DiagLoc) const624 bool PatFrag::handleUnboundInParam(StringRef ParamName, StringRef ArgName,
625                                    ArrayRef<SMLoc> DiagLoc) const {
626   // The parameter must be a live-in of all alternatives for this to work.
627   // Otherwise, we risk having unbound parameters being used (= crashes).
628   //
629   // Examples:
630   //
631   // in (ins $y), (patterns (G_FNEG $dst, $y), "return matchFnegOp(${y})")
632   //    even if $y is unbound, we'll lazily bind it when emitting the G_FNEG.
633   //
634   // in (ins $y), (patterns "return matchFnegOp(${y})")
635   //    if $y is unbound when this fragment is emitted, C++ code expansion will
636   //    fail.
637   for (const auto &Alt : Alts) {
638     auto &OT = Alt.OpTable;
639     if (!OT.lookup(ParamName).Found) {
640       llvm::PrintError(DiagLoc, "operand '" + ArgName + "' (for parameter '" +
641                                     ParamName + "' of '" + getName() +
642                                     "') cannot be unbound");
643       PrintNote(
644           DiagLoc,
645           "one or more alternatives of '" + getName() + "' do not bind '" +
646               ParamName +
647               "' to an instruction operand; either use a bound operand or "
648               "ensure '" +
649               Def.getName() + "' binds '" + ParamName +
650               "' in all alternatives");
651       return false;
652     }
653   }
654 
655   return true;
656 }
657 
buildOperandsTables()658 bool PatFrag::buildOperandsTables() {
659   // enumerate(...) doesn't seem to allow lvalues so we need to count the old
660   // way.
661   unsigned Idx = 0;
662 
663   const auto DiagnoseRedef = [this, &Idx](StringRef OpName) {
664     PrintError("Operand '" + OpName +
665                "' is defined multiple times in patterns of alternative #" +
666                std::to_string(Idx));
667   };
668 
669   for (auto &Alt : Alts) {
670     for (auto &Pat : Alt.Pats) {
671       auto *IP = dyn_cast<InstructionPattern>(Pat.get());
672       if (!IP)
673         continue;
674 
675       if (!Alt.OpTable.addPattern(IP, DiagnoseRedef))
676         return false;
677     }
678 
679     ++Idx;
680   }
681 
682   return true;
683 }
684 
print(raw_ostream & OS,StringRef Indent) const685 void PatFrag::print(raw_ostream &OS, StringRef Indent) const {
686   OS << Indent << "(PatFrag name:" << getName() << '\n';
687   if (!in_params().empty()) {
688     OS << Indent << "  (ins ";
689     printParamsList(OS, in_params());
690     OS << ")\n";
691   }
692 
693   if (!out_params().empty()) {
694     OS << Indent << "  (outs ";
695     printParamsList(OS, out_params());
696     OS << ")\n";
697   }
698 
699   // TODO: Dump OperandTable as well.
700   OS << Indent << "  (alternatives [\n";
701   for (const auto &Alt : Alts) {
702     OS << Indent << "    [\n";
703     for (const auto &Pat : Alt.Pats) {
704       OS << Indent << "      ";
705       Pat->print(OS, /*PrintName=*/true);
706       OS << ",\n";
707     }
708     OS << Indent << "    ],\n";
709   }
710   OS << Indent << "  ])\n";
711 
712   OS << Indent << ')';
713 }
714 
dump() const715 void PatFrag::dump() const { print(dbgs()); }
716 
printParamsList(raw_ostream & OS,iterator_range<ParamIt> Params)717 void PatFrag::printParamsList(raw_ostream &OS, iterator_range<ParamIt> Params) {
718   OS << '['
719      << join(map_range(Params,
720                        [](auto &O) {
721                          return (O.Name + ":" + getParamKindStr(O.Kind)).str();
722                        }),
723              ", ")
724      << ']';
725 }
726 
PrintError(Twine Msg) const727 void PatFrag::PrintError(Twine Msg) const { llvm::PrintError(&Def, Msg); }
728 
getApplyDefsNeeded() const729 ArrayRef<InstructionOperand> PatFragPattern::getApplyDefsNeeded() const {
730   assert(PF.num_roots() == 1);
731   // Only roots need to be redef.
732   for (auto [Idx, Param] : enumerate(PF.out_params())) {
733     if (Param.Kind == PatFrag::PK_Root)
734       return getOperand(Idx);
735   }
736   llvm_unreachable("root not found!");
737 }
738 
739 //===- PatFragPattern -----------------------------------------------------===//
740 
checkSemantics(ArrayRef<SMLoc> DiagLoc)741 bool PatFragPattern::checkSemantics(ArrayRef<SMLoc> DiagLoc) {
742   if (!InstructionPattern::checkSemantics(DiagLoc))
743     return false;
744 
745   for (const auto &[Idx, Op] : enumerate(Operands)) {
746     switch (PF.getParam(Idx).Kind) {
747     case PatFrag::PK_Imm:
748       if (!Op.hasImmValue()) {
749         PrintError(DiagLoc, "expected operand " + std::to_string(Idx) +
750                                 " of '" + getInstName() +
751                                 "' to be an immediate; got " + Op.describe());
752         return false;
753       }
754       if (Op.isNamedImmediate()) {
755         PrintError(DiagLoc, "operand " + std::to_string(Idx) + " of '" +
756                                 getInstName() +
757                                 "' cannot be a named immediate");
758         return false;
759       }
760       break;
761     case PatFrag::PK_Root:
762     case PatFrag::PK_MachineOperand:
763       if (!Op.isNamedOperand() || Op.isNamedImmediate()) {
764         PrintError(DiagLoc, "expected operand " + std::to_string(Idx) +
765                                 " of '" + getInstName() +
766                                 "' to be a MachineOperand; got " +
767                                 Op.describe());
768         return false;
769       }
770       break;
771     }
772   }
773 
774   return true;
775 }
776 
mapInputCodeExpansions(const CodeExpansions & ParentCEs,CodeExpansions & PatFragCEs,ArrayRef<SMLoc> DiagLoc) const777 bool PatFragPattern::mapInputCodeExpansions(const CodeExpansions &ParentCEs,
778                                             CodeExpansions &PatFragCEs,
779                                             ArrayRef<SMLoc> DiagLoc) const {
780   for (const auto &[Idx, Op] : enumerate(operands())) {
781     StringRef ParamName = PF.getParam(Idx).Name;
782 
783     // Operands to a PFP can only be named, or be an immediate, but not a named
784     // immediate.
785     assert(!Op.isNamedImmediate());
786 
787     if (Op.isNamedOperand()) {
788       StringRef ArgName = Op.getOperandName();
789       // Map it only if it's been defined.
790       auto It = ParentCEs.find(ArgName);
791       if (It == ParentCEs.end()) {
792         if (!PF.handleUnboundInParam(ParamName, ArgName, DiagLoc))
793           return false;
794       } else
795         PatFragCEs.declare(ParamName, It->second);
796       continue;
797     }
798 
799     if (Op.hasImmValue()) {
800       PatFragCEs.declare(ParamName, std::to_string(Op.getImmValue()));
801       continue;
802     }
803 
804     llvm_unreachable("Unknown Operand Type!");
805   }
806 
807   return true;
808 }
809 
810 //===- BuiltinPattern -----------------------------------------------------===//
811 
getBuiltinInfo(const Record & Def)812 BuiltinPattern::BuiltinInfo BuiltinPattern::getBuiltinInfo(const Record &Def) {
813   assert(Def.isSubClassOf(ClassName));
814 
815   StringRef Name = Def.getName();
816   for (const auto &KBI : KnownBuiltins) {
817     if (KBI.DefName == Name)
818       return KBI;
819   }
820 
821   PrintFatalError(Def.getLoc(),
822                   "Unimplemented " + ClassName + " def '" + Name + "'");
823 }
824 
checkSemantics(ArrayRef<SMLoc> Loc)825 bool BuiltinPattern::checkSemantics(ArrayRef<SMLoc> Loc) {
826   if (!InstructionPattern::checkSemantics(Loc))
827     return false;
828 
829   // For now all builtins just take names, no immediates.
830   for (const auto &[Idx, Op] : enumerate(operands())) {
831     if (!Op.isNamedOperand() || Op.isNamedImmediate()) {
832       PrintError(Loc, "expected operand " + std::to_string(Idx) + " of '" +
833                           getInstName() + "' to be a name");
834       return false;
835     }
836   }
837 
838   return true;
839 }
840 
841 } // namespace gi
842 } // namespace llvm
843