1 //===- VarLenCodeEmitterGen.cpp - CEG for variable-length insts -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // The CodeEmitterGen component for variable-length instructions.
10 //
11 // The basic CodeEmitterGen is almost exclusively designed for fixed-
12 // length instructions. A good analogy for its encoding scheme is how printf
// works: The (immutable) formatting string represents the fixed values in the
14 // encoded instruction. Placeholders (i.e. %something), on the other hand,
15 // represent encoding for instruction operands.
16 // ```
17 // printf("1101 %src 1001 %dst", <encoded value for operand `src`>,
18 //                               <encoded value for operand `dst`>);
19 // ```
20 // VarLenCodeEmitterGen in this file provides an alternative encoding scheme
21 // that works more like a C++ stream operator:
22 // ```
23 // OS << 0b1101;
24 // if (Cond)
25 //   OS << OperandEncoding0;
26 // OS << 0b1001 << OperandEncoding1;
27 // ```
// You are free to concatenate arbitrary types (and sizes) of encoding
// fragments at any bit position, giving more flexibility in defining
// encodings for variable-length instructions.
31 //
32 // In a more specific way, instruction encoding is represented by a DAG type
33 // `Inst` field. Here is an example:
34 // ```
35 // dag Inst = (descend 0b1101, (operand "$src", 4), 0b1001,
36 //                     (operand "$dst", 4));
37 // ```
38 // It represents the following instruction encoding:
39 // ```
40 // MSB                                                     LSB
41 // 1101<encoding for operand src>1001<encoding for operand dst>
42 // ```
43 // For more details about DAG operators in the above snippet, please
44 // refer to \file include/llvm/Target/Target.td.
45 //
// VarLenCodeEmitter will convert the above DAG into the same helper function
// generated by CodeEmitter, `MCCodeEmitter::getBinaryCodeForInstr` (except
// for a few details).
49 //
50 //===----------------------------------------------------------------------===//
51 
52 #include "VarLenCodeEmitterGen.h"
53 #include "CodeGenHwModes.h"
54 #include "CodeGenInstruction.h"
55 #include "CodeGenTarget.h"
56 #include "InfoByHwMode.h"
57 #include "llvm/ADT/ArrayRef.h"
58 #include "llvm/ADT/DenseMap.h"
59 #include "llvm/Support/raw_ostream.h"
60 #include "llvm/TableGen/Error.h"
61 #include "llvm/TableGen/Record.h"
62 
63 using namespace llvm;
64 
65 namespace {
66 
67 class VarLenCodeEmitterGen {
68   RecordKeeper &Records;
69 
70   DenseMap<Record *, VarLenInst> VarLenInsts;
71 
72   // Emit based values (i.e. fixed bits in the encoded instructions)
73   void emitInstructionBaseValues(
74       raw_ostream &OS,
75       ArrayRef<const CodeGenInstruction *> NumberedInstructions,
76       CodeGenTarget &Target, int HwMode = -1);
77 
78   std::string getInstructionCase(Record *R, CodeGenTarget &Target);
79   std::string getInstructionCaseForEncoding(Record *R, Record *EncodingDef,
80                                             CodeGenTarget &Target);
81 
82 public:
83   explicit VarLenCodeEmitterGen(RecordKeeper &R) : Records(R) {}
84 
85   void run(raw_ostream &OS);
86 };
87 } // end anonymous namespace
88 
89 // Get the name of custom encoder or decoder, if there is any.
90 // Returns `{encoder name, decoder name}`.
91 static std::pair<StringRef, StringRef> getCustomCoders(ArrayRef<Init *> Args) {
92   std::pair<StringRef, StringRef> Result;
93   for (const auto *Arg : Args) {
94     const auto *DI = dyn_cast<DagInit>(Arg);
95     if (!DI)
96       continue;
97     const Init *Op = DI->getOperator();
98     if (!isa<DefInit>(Op))
99       continue;
100     // syntax: `(<encoder | decoder> "function name")`
101     StringRef OpName = cast<DefInit>(Op)->getDef()->getName();
102     if (OpName != "encoder" && OpName != "decoder")
103       continue;
104     if (!DI->getNumArgs() || !isa<StringInit>(DI->getArg(0)))
105       PrintFatalError("expected '" + OpName +
106                       "' directive to be followed by a custom function name.");
107     StringRef FuncName = cast<StringInit>(DI->getArg(0))->getValue();
108     if (OpName == "encoder")
109       Result.first = FuncName;
110     else
111       Result.second = FuncName;
112   }
113   return Result;
114 }
115 
116 VarLenInst::VarLenInst(const DagInit *DI, const RecordVal *TheDef)
117     : TheDef(TheDef), NumBits(0U) {
118   buildRec(DI);
119   for (const auto &S : Segments)
120     NumBits += S.BitWidth;
121 }
122 
123 void VarLenInst::buildRec(const DagInit *DI) {
124   assert(TheDef && "The def record is nullptr ?");
125 
126   std::string Op = DI->getOperator()->getAsString();
127 
128   if (Op == "ascend" || Op == "descend") {
129     bool Reverse = Op == "descend";
130     int i = Reverse ? DI->getNumArgs() - 1 : 0;
131     int e = Reverse ? -1 : DI->getNumArgs();
132     int s = Reverse ? -1 : 1;
133     for (; i != e; i += s) {
134       const Init *Arg = DI->getArg(i);
135       if (const auto *BI = dyn_cast<BitsInit>(Arg)) {
136         if (!BI->isComplete())
137           PrintFatalError(TheDef->getLoc(),
138                           "Expecting complete bits init in `" + Op + "`");
139         Segments.push_back({BI->getNumBits(), BI});
140       } else if (const auto *BI = dyn_cast<BitInit>(Arg)) {
141         if (!BI->isConcrete())
142           PrintFatalError(TheDef->getLoc(),
143                           "Expecting concrete bit init in `" + Op + "`");
144         Segments.push_back({1, BI});
145       } else if (const auto *SubDI = dyn_cast<DagInit>(Arg)) {
146         buildRec(SubDI);
147       } else {
148         PrintFatalError(TheDef->getLoc(), "Unrecognized type of argument in `" +
149                                               Op + "`: " + Arg->getAsString());
150       }
151     }
152   } else if (Op == "operand") {
153     // (operand <operand name>, <# of bits>,
154     //          [(encoder <custom encoder>)][, (decoder <custom decoder>)])
155     if (DI->getNumArgs() < 2)
156       PrintFatalError(TheDef->getLoc(),
157                       "Expecting at least 2 arguments for `operand`");
158     HasDynamicSegment = true;
159     const Init *OperandName = DI->getArg(0), *NumBits = DI->getArg(1);
160     if (!isa<StringInit>(OperandName) || !isa<IntInit>(NumBits))
161       PrintFatalError(TheDef->getLoc(), "Invalid argument types for `operand`");
162 
163     auto NumBitsVal = cast<IntInit>(NumBits)->getValue();
164     if (NumBitsVal <= 0)
165       PrintFatalError(TheDef->getLoc(), "Invalid number of bits for `operand`");
166 
167     auto [CustomEncoder, CustomDecoder] =
168         getCustomCoders(DI->getArgs().slice(2));
169     Segments.push_back({static_cast<unsigned>(NumBitsVal), OperandName,
170                         CustomEncoder, CustomDecoder});
171   } else if (Op == "slice") {
172     // (slice <operand name>, <high / low bit>, <low / high bit>,
173     //        [(encoder <custom encoder>)][, (decoder <custom decoder>)])
174     if (DI->getNumArgs() < 3)
175       PrintFatalError(TheDef->getLoc(),
176                       "Expecting at least 3 arguments for `slice`");
177     HasDynamicSegment = true;
178     Init *OperandName = DI->getArg(0), *HiBit = DI->getArg(1),
179          *LoBit = DI->getArg(2);
180     if (!isa<StringInit>(OperandName) || !isa<IntInit>(HiBit) ||
181         !isa<IntInit>(LoBit))
182       PrintFatalError(TheDef->getLoc(), "Invalid argument types for `slice`");
183 
184     auto HiBitVal = cast<IntInit>(HiBit)->getValue(),
185          LoBitVal = cast<IntInit>(LoBit)->getValue();
186     if (HiBitVal < 0 || LoBitVal < 0)
187       PrintFatalError(TheDef->getLoc(), "Invalid bit range for `slice`");
188     bool NeedSwap = false;
189     unsigned NumBits = 0U;
190     if (HiBitVal < LoBitVal) {
191       NeedSwap = true;
192       NumBits = static_cast<unsigned>(LoBitVal - HiBitVal + 1);
193     } else {
194       NumBits = static_cast<unsigned>(HiBitVal - LoBitVal + 1);
195     }
196 
197     auto [CustomEncoder, CustomDecoder] =
198         getCustomCoders(DI->getArgs().slice(3));
199 
200     if (NeedSwap) {
201       // Normalization: Hi bit should always be the second argument.
202       Init *const NewArgs[] = {OperandName, LoBit, HiBit};
203       Segments.push_back({NumBits,
204                           DagInit::get(DI->getOperator(), nullptr, NewArgs, {}),
205                           CustomEncoder, CustomDecoder});
206     } else {
207       Segments.push_back({NumBits, DI, CustomEncoder, CustomDecoder});
208     }
209   }
210 }
211 
212 void VarLenCodeEmitterGen::run(raw_ostream &OS) {
213   CodeGenTarget Target(Records);
214   auto Insts = Records.getAllDerivedDefinitions("Instruction");
215 
216   auto NumberedInstructions = Target.getInstructionsByEnumValue();
217   const CodeGenHwModes &HWM = Target.getHwModes();
218 
219   // The set of HwModes used by instruction encodings.
220   std::set<unsigned> HwModes;
221   for (const CodeGenInstruction *CGI : NumberedInstructions) {
222     Record *R = CGI->TheDef;
223 
224     // Create the corresponding VarLenInst instance.
225     if (R->getValueAsString("Namespace") == "TargetOpcode" ||
226         R->getValueAsBit("isPseudo"))
227       continue;
228 
229     if (const RecordVal *RV = R->getValue("EncodingInfos")) {
230       if (auto *DI = dyn_cast_or_null<DefInit>(RV->getValue())) {
231         EncodingInfoByHwMode EBM(DI->getDef(), HWM);
232         for (auto &KV : EBM) {
233           HwModes.insert(KV.first);
234           Record *EncodingDef = KV.second;
235           RecordVal *RV = EncodingDef->getValue("Inst");
236           DagInit *DI = cast<DagInit>(RV->getValue());
237           VarLenInsts.insert({EncodingDef, VarLenInst(DI, RV)});
238         }
239         continue;
240       }
241     }
242     RecordVal *RV = R->getValue("Inst");
243     DagInit *DI = cast<DagInit>(RV->getValue());
244     VarLenInsts.insert({R, VarLenInst(DI, RV)});
245   }
246 
247   // Emit function declaration
248   OS << "void " << Target.getName()
249      << "MCCodeEmitter::getBinaryCodeForInstr(const MCInst &MI,\n"
250      << "    SmallVectorImpl<MCFixup> &Fixups,\n"
251      << "    APInt &Inst,\n"
252      << "    APInt &Scratch,\n"
253      << "    const MCSubtargetInfo &STI) const {\n";
254 
255   // Emit instruction base values
256   if (HwModes.empty()) {
257     emitInstructionBaseValues(OS, NumberedInstructions, Target);
258   } else {
259     for (unsigned HwMode : HwModes)
260       emitInstructionBaseValues(OS, NumberedInstructions, Target, (int)HwMode);
261   }
262 
263   if (!HwModes.empty()) {
264     OS << "  const unsigned **Index;\n";
265     OS << "  const uint64_t *InstBits;\n";
266     OS << "  unsigned HwMode = STI.getHwMode();\n";
267     OS << "  switch (HwMode) {\n";
268     OS << "  default: llvm_unreachable(\"Unknown hardware mode!\"); break;\n";
269     for (unsigned I : HwModes) {
270       OS << "  case " << I << ": InstBits = InstBits_" << HWM.getMode(I).Name
271          << "; Index = Index_" << HWM.getMode(I).Name << "; break;\n";
272     }
273     OS << "  };\n";
274   }
275 
276   // Emit helper function to retrieve base values.
277   OS << "  auto getInstBits = [&](unsigned Opcode) -> APInt {\n"
278      << "    unsigned NumBits = Index[Opcode][0];\n"
279      << "    if (!NumBits)\n"
280      << "      return APInt::getZeroWidth();\n"
281      << "    unsigned Idx = Index[Opcode][1];\n"
282      << "    ArrayRef<uint64_t> Data(&InstBits[Idx], "
283      << "APInt::getNumWords(NumBits));\n"
284      << "    return APInt(NumBits, Data);\n"
285      << "  };\n";
286 
287   // Map to accumulate all the cases.
288   std::map<std::string, std::vector<std::string>> CaseMap;
289 
290   // Construct all cases statement for each opcode
291   for (Record *R : Insts) {
292     if (R->getValueAsString("Namespace") == "TargetOpcode" ||
293         R->getValueAsBit("isPseudo"))
294       continue;
295     std::string InstName =
296         (R->getValueAsString("Namespace") + "::" + R->getName()).str();
297     std::string Case = getInstructionCase(R, Target);
298 
299     CaseMap[Case].push_back(std::move(InstName));
300   }
301 
302   // Emit initial function code
303   OS << "  const unsigned opcode = MI.getOpcode();\n"
304      << "  switch (opcode) {\n";
305 
306   // Emit each case statement
307   for (const auto &C : CaseMap) {
308     const std::string &Case = C.first;
309     const auto &InstList = C.second;
310 
311     ListSeparator LS("\n");
312     for (const auto &InstName : InstList)
313       OS << LS << "    case " << InstName << ":";
314 
315     OS << " {\n";
316     OS << Case;
317     OS << "      break;\n"
318        << "    }\n";
319   }
320   // Default case: unhandled opcode
321   OS << "  default:\n"
322      << "    std::string msg;\n"
323      << "    raw_string_ostream Msg(msg);\n"
324      << "    Msg << \"Not supported instr: \" << MI;\n"
325      << "    report_fatal_error(Msg.str().c_str());\n"
326      << "  }\n";
327   OS << "}\n\n";
328 }
329 
330 static void emitInstBits(raw_ostream &IS, raw_ostream &SS, const APInt &Bits,
331                          unsigned &Index) {
332   if (!Bits.getNumWords()) {
333     IS.indent(4) << "{/*NumBits*/0, /*Index*/0},";
334     return;
335   }
336 
337   IS.indent(4) << "{/*NumBits*/" << Bits.getBitWidth() << ", "
338                << "/*Index*/" << Index << "},";
339 
340   SS.indent(4);
341   for (unsigned I = 0; I < Bits.getNumWords(); ++I, ++Index)
342     SS << "UINT64_C(" << utostr(Bits.getRawData()[I]) << "),";
343 }
344 
345 void VarLenCodeEmitterGen::emitInstructionBaseValues(
346     raw_ostream &OS, ArrayRef<const CodeGenInstruction *> NumberedInstructions,
347     CodeGenTarget &Target, int HwMode) {
348   std::string IndexArray, StorageArray;
349   raw_string_ostream IS(IndexArray), SS(StorageArray);
350 
351   const CodeGenHwModes &HWM = Target.getHwModes();
352   if (HwMode == -1) {
353     IS << "  static const unsigned Index[][2] = {\n";
354     SS << "  static const uint64_t InstBits[] = {\n";
355   } else {
356     StringRef Name = HWM.getMode(HwMode).Name;
357     IS << "  static const unsigned Index_" << Name << "[][2] = {\n";
358     SS << "  static const uint64_t InstBits_" << Name << "[] = {\n";
359   }
360 
361   unsigned NumFixedValueWords = 0U;
362   for (const CodeGenInstruction *CGI : NumberedInstructions) {
363     Record *R = CGI->TheDef;
364 
365     if (R->getValueAsString("Namespace") == "TargetOpcode" ||
366         R->getValueAsBit("isPseudo")) {
367       IS.indent(4) << "{/*NumBits*/0, /*Index*/0},\n";
368       continue;
369     }
370 
371     Record *EncodingDef = R;
372     if (const RecordVal *RV = R->getValue("EncodingInfos")) {
373       if (auto *DI = dyn_cast_or_null<DefInit>(RV->getValue())) {
374         EncodingInfoByHwMode EBM(DI->getDef(), HWM);
375         if (EBM.hasMode(HwMode))
376           EncodingDef = EBM.get(HwMode);
377       }
378     }
379 
380     auto It = VarLenInsts.find(EncodingDef);
381     if (It == VarLenInsts.end())
382       PrintFatalError(EncodingDef, "VarLenInst not found for this record");
383     const VarLenInst &VLI = It->second;
384 
385     unsigned i = 0U, BitWidth = VLI.size();
386 
387     // Start by filling in fixed values.
388     APInt Value(BitWidth, 0);
389     auto SI = VLI.begin(), SE = VLI.end();
390     // Scan through all the segments that have fixed-bits values.
391     while (i < BitWidth && SI != SE) {
392       unsigned SegmentNumBits = SI->BitWidth;
393       if (const auto *BI = dyn_cast<BitsInit>(SI->Value)) {
394         for (unsigned Idx = 0U; Idx != SegmentNumBits; ++Idx) {
395           auto *B = cast<BitInit>(BI->getBit(Idx));
396           Value.setBitVal(i + Idx, B->getValue());
397         }
398       }
399       if (const auto *BI = dyn_cast<BitInit>(SI->Value))
400         Value.setBitVal(i, BI->getValue());
401 
402       i += SegmentNumBits;
403       ++SI;
404     }
405 
406     emitInstBits(IS, SS, Value, NumFixedValueWords);
407     IS << '\t' << "// " << R->getName() << "\n";
408     if (Value.getNumWords())
409       SS << '\t' << "// " << R->getName() << "\n";
410   }
411   IS.indent(4) << "{/*NumBits*/0, /*Index*/0}\n  };\n";
412   SS.indent(4) << "UINT64_C(0)\n  };\n";
413 
414   OS << IS.str() << SS.str();
415 }
416 
417 std::string VarLenCodeEmitterGen::getInstructionCase(Record *R,
418                                                      CodeGenTarget &Target) {
419   std::string Case;
420   if (const RecordVal *RV = R->getValue("EncodingInfos")) {
421     if (auto *DI = dyn_cast_or_null<DefInit>(RV->getValue())) {
422       const CodeGenHwModes &HWM = Target.getHwModes();
423       EncodingInfoByHwMode EBM(DI->getDef(), HWM);
424       Case += "      switch (HwMode) {\n";
425       Case += "      default: llvm_unreachable(\"Unhandled HwMode\");\n";
426       for (auto &KV : EBM) {
427         Case += "      case " + itostr(KV.first) + ": {\n";
428         Case += getInstructionCaseForEncoding(R, KV.second, Target);
429         Case += "      break;\n";
430         Case += "      }\n";
431       }
432       Case += "      }\n";
433       return Case;
434     }
435   }
436   return getInstructionCaseForEncoding(R, R, Target);
437 }
438 
439 std::string VarLenCodeEmitterGen::getInstructionCaseForEncoding(
440     Record *R, Record *EncodingDef, CodeGenTarget &Target) {
441   auto It = VarLenInsts.find(EncodingDef);
442   if (It == VarLenInsts.end())
443     PrintFatalError(EncodingDef, "Parsed encoding record not found");
444   const VarLenInst &VLI = It->second;
445   size_t BitWidth = VLI.size();
446 
447   CodeGenInstruction &CGI = Target.getInstruction(R);
448 
449   std::string Case;
450   raw_string_ostream SS(Case);
451   // Resize the scratch buffer.
452   if (BitWidth && !VLI.isFixedValueOnly())
453     SS.indent(6) << "Scratch = Scratch.zext(" << BitWidth << ");\n";
454   // Populate based value.
455   SS.indent(6) << "Inst = getInstBits(opcode);\n";
456 
457   // Process each segment in VLI.
458   size_t Offset = 0U;
459   for (const auto &ES : VLI) {
460     unsigned NumBits = ES.BitWidth;
461     const Init *Val = ES.Value;
462     // If it's a StringInit or DagInit, it's a reference to an operand
463     // or part of an operand.
464     if (isa<StringInit>(Val) || isa<DagInit>(Val)) {
465       StringRef OperandName;
466       unsigned LoBit = 0U;
467       if (const auto *SV = dyn_cast<StringInit>(Val)) {
468         OperandName = SV->getValue();
469       } else {
470         // Normalized: (slice <operand name>, <high bit>, <low bit>)
471         const auto *DV = cast<DagInit>(Val);
472         OperandName = cast<StringInit>(DV->getArg(0))->getValue();
473         LoBit = static_cast<unsigned>(cast<IntInit>(DV->getArg(2))->getValue());
474       }
475 
476       auto OpIdx = CGI.Operands.ParseOperandName(OperandName);
477       unsigned FlatOpIdx = CGI.Operands.getFlattenedOperandNumber(OpIdx);
478       StringRef CustomEncoder =
479           CGI.Operands[OpIdx.first].EncoderMethodNames[OpIdx.second];
480       if (ES.CustomEncoder.size())
481         CustomEncoder = ES.CustomEncoder;
482 
483       SS.indent(6) << "Scratch.clearAllBits();\n";
484       SS.indent(6) << "// op: " << OperandName.drop_front(1) << "\n";
485       if (CustomEncoder.empty())
486         SS.indent(6) << "getMachineOpValue(MI, MI.getOperand("
487                      << utostr(FlatOpIdx) << ")";
488       else
489         SS.indent(6) << CustomEncoder << "(MI, /*OpIdx=*/" << utostr(FlatOpIdx);
490 
491       SS << ", /*Pos=*/" << utostr(Offset) << ", Scratch, Fixups, STI);\n";
492 
493       SS.indent(6) << "Inst.insertBits("
494                    << "Scratch.extractBits(" << utostr(NumBits) << ", "
495                    << utostr(LoBit) << ")"
496                    << ", " << Offset << ");\n";
497     }
498     Offset += NumBits;
499   }
500 
501   StringRef PostEmitter = R->getValueAsString("PostEncoderMethod");
502   if (!PostEmitter.empty())
503     SS.indent(6) << "Inst = " << PostEmitter << "(MI, Inst, STI);\n";
504 
505   return Case;
506 }
507 
508 namespace llvm {
509 
510 void emitVarLenCodeEmitter(RecordKeeper &R, raw_ostream &OS) {
511   VarLenCodeEmitterGen(R).run(OS);
512 }
513 
514 } // end namespace llvm
515