1 //===-- X86FixupVectorConstants.cpp - optimize constant generation  -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file examines all full size vector constant pool loads and attempts to
10 // replace them with smaller constant pool entries, including:
11 // * Converting AVX512 memory-fold instructions to their broadcast-fold form
12 // * Broadcasting of full width loads.
13 // * TODO: Sign/Zero extension of full width loads.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "X86.h"
18 #include "X86InstrFoldTables.h"
19 #include "X86InstrInfo.h"
20 #include "X86Subtarget.h"
21 #include "llvm/ADT/Statistic.h"
22 #include "llvm/CodeGen/MachineConstantPool.h"
23 
24 using namespace llvm;
25 
26 #define DEBUG_TYPE "x86-fixup-vector-constants"
27 
28 STATISTIC(NumInstChanges, "Number of instructions changes");
29 
30 namespace {
31 class X86FixupVectorConstantsPass : public MachineFunctionPass {
32 public:
33   static char ID;
34 
35   X86FixupVectorConstantsPass() : MachineFunctionPass(ID) {}
36 
37   StringRef getPassName() const override {
38     return "X86 Fixup Vector Constants";
39   }
40 
41   bool runOnMachineFunction(MachineFunction &MF) override;
42   bool processInstruction(MachineFunction &MF, MachineBasicBlock &MBB,
43                           MachineInstr &MI);
44 
45   // This pass runs after regalloc and doesn't support VReg operands.
46   MachineFunctionProperties getRequiredProperties() const override {
47     return MachineFunctionProperties().set(
48         MachineFunctionProperties::Property::NoVRegs);
49   }
50 
51 private:
52   const X86InstrInfo *TII = nullptr;
53   const X86Subtarget *ST = nullptr;
54   const MCSchedModel *SM = nullptr;
55 };
56 } // end anonymous namespace
57 
58 char X86FixupVectorConstantsPass::ID = 0;
59 
60 INITIALIZE_PASS(X86FixupVectorConstantsPass, DEBUG_TYPE, DEBUG_TYPE, false, false)
61 
62 FunctionPass *llvm::createX86FixupVectorConstants() {
63   return new X86FixupVectorConstantsPass();
64 }
65 
66 // Attempt to extract the full width of bits data from the constant.
67 static std::optional<APInt> extractConstantBits(const Constant *C) {
68   unsigned NumBits = C->getType()->getPrimitiveSizeInBits();
69 
70   if (auto *CInt = dyn_cast<ConstantInt>(C))
71     return CInt->getValue();
72 
73   if (auto *CFP = dyn_cast<ConstantFP>(C))
74     return CFP->getValue().bitcastToAPInt();
75 
76   if (auto *CV = dyn_cast<ConstantVector>(C)) {
77     if (auto *CVSplat = CV->getSplatValue(/*AllowUndefs*/ true)) {
78       if (std::optional<APInt> Bits = extractConstantBits(CVSplat)) {
79         assert((NumBits % Bits->getBitWidth()) == 0 && "Illegal splat");
80         return APInt::getSplat(NumBits, *Bits);
81       }
82     }
83   }
84 
85   if (auto *CDS = dyn_cast<ConstantDataSequential>(C)) {
86     bool IsInteger = CDS->getElementType()->isIntegerTy();
87     bool IsFloat = CDS->getElementType()->isHalfTy() ||
88                    CDS->getElementType()->isBFloatTy() ||
89                    CDS->getElementType()->isFloatTy() ||
90                    CDS->getElementType()->isDoubleTy();
91     if (IsInteger || IsFloat) {
92       APInt Bits = APInt::getZero(NumBits);
93       unsigned EltBits = CDS->getElementType()->getPrimitiveSizeInBits();
94       for (unsigned I = 0, E = CDS->getNumElements(); I != E; ++I) {
95         if (IsInteger)
96           Bits.insertBits(CDS->getElementAsAPInt(I), I * EltBits);
97         else
98           Bits.insertBits(CDS->getElementAsAPFloat(I).bitcastToAPInt(),
99                           I * EltBits);
100       }
101       return Bits;
102     }
103   }
104 
105   return std::nullopt;
106 }
107 
108 // Attempt to compute the splat width of bits data by normalizing the splat to
109 // remove undefs.
110 static std::optional<APInt> getSplatableConstant(const Constant *C,
111                                                  unsigned SplatBitWidth) {
112   const Type *Ty = C->getType();
113   assert((Ty->getPrimitiveSizeInBits() % SplatBitWidth) == 0 &&
114          "Illegal splat width");
115 
116   if (std::optional<APInt> Bits = extractConstantBits(C))
117     if (Bits->isSplat(SplatBitWidth))
118       return Bits->trunc(SplatBitWidth);
119 
120   // Detect general splats with undefs.
121   // TODO: Do we need to handle NumEltsBits > SplatBitWidth splitting?
122   if (auto *CV = dyn_cast<ConstantVector>(C)) {
123     unsigned NumOps = CV->getNumOperands();
124     unsigned NumEltsBits = Ty->getScalarSizeInBits();
125     unsigned NumScaleOps = SplatBitWidth / NumEltsBits;
126     if ((SplatBitWidth % NumEltsBits) == 0) {
127       // Collect the elements and ensure that within the repeated splat sequence
128       // they either match or are undef.
129       SmallVector<Constant *, 16> Sequence(NumScaleOps, nullptr);
130       for (unsigned Idx = 0; Idx != NumOps; ++Idx) {
131         if (Constant *Elt = CV->getAggregateElement(Idx)) {
132           if (isa<UndefValue>(Elt))
133             continue;
134           unsigned SplatIdx = Idx % NumScaleOps;
135           if (!Sequence[SplatIdx] || Sequence[SplatIdx] == Elt) {
136             Sequence[SplatIdx] = Elt;
137             continue;
138           }
139         }
140         return std::nullopt;
141       }
142       // Extract the constant bits forming the splat and insert into the bits
143       // data, leave undef as zero.
144       APInt SplatBits = APInt::getZero(SplatBitWidth);
145       for (unsigned I = 0; I != NumScaleOps; ++I) {
146         if (!Sequence[I])
147           continue;
148         if (std::optional<APInt> Bits = extractConstantBits(Sequence[I])) {
149           SplatBits.insertBits(*Bits, I * Bits->getBitWidth());
150           continue;
151         }
152         return std::nullopt;
153       }
154       return SplatBits;
155     }
156   }
157 
158   return std::nullopt;
159 }
160 
161 // Split raw bits into a constant vector of elements of a specific bit width.
162 // NOTE: We don't always bother converting to scalars if the vector length is 1.
163 static Constant *rebuildConstant(LLVMContext &Ctx, Type *SclTy,
164                                  const APInt &Bits, unsigned NumSclBits) {
165   unsigned BitWidth = Bits.getBitWidth();
166 
167   if (NumSclBits == 8) {
168     SmallVector<uint8_t> RawBits;
169     for (unsigned I = 0; I != BitWidth; I += 8)
170       RawBits.push_back(Bits.extractBits(8, I).getZExtValue());
171     return ConstantDataVector::get(Ctx, RawBits);
172   }
173 
174   if (NumSclBits == 16) {
175     SmallVector<uint16_t> RawBits;
176     for (unsigned I = 0; I != BitWidth; I += 16)
177       RawBits.push_back(Bits.extractBits(16, I).getZExtValue());
178     if (SclTy->is16bitFPTy())
179       return ConstantDataVector::getFP(SclTy, RawBits);
180     return ConstantDataVector::get(Ctx, RawBits);
181   }
182 
183   if (NumSclBits == 32) {
184     SmallVector<uint32_t> RawBits;
185     for (unsigned I = 0; I != BitWidth; I += 32)
186       RawBits.push_back(Bits.extractBits(32, I).getZExtValue());
187     if (SclTy->isFloatTy())
188       return ConstantDataVector::getFP(SclTy, RawBits);
189     return ConstantDataVector::get(Ctx, RawBits);
190   }
191 
192   assert(NumSclBits == 64 && "Unhandled vector element width");
193 
194   SmallVector<uint64_t> RawBits;
195   for (unsigned I = 0; I != BitWidth; I += 64)
196     RawBits.push_back(Bits.extractBits(64, I).getZExtValue());
197   if (SclTy->isDoubleTy())
198     return ConstantDataVector::getFP(SclTy, RawBits);
199   return ConstantDataVector::get(Ctx, RawBits);
200 }
201 
202 // Attempt to rebuild a normalized splat vector constant of the requested splat
203 // width, built up of potentially smaller scalar values.
204 static Constant *rebuildSplatableConstant(const Constant *C,
205                                           unsigned SplatBitWidth) {
206   std::optional<APInt> Splat = getSplatableConstant(C, SplatBitWidth);
207   if (!Splat)
208     return nullptr;
209 
210   // Determine scalar size to use for the constant splat vector, clamping as we
211   // might have found a splat smaller than the original constant data.
212   const Type *OriginalType = C->getType();
213   Type *SclTy = OriginalType->getScalarType();
214   unsigned NumSclBits = SclTy->getPrimitiveSizeInBits();
215   NumSclBits = std::min<unsigned>(NumSclBits, SplatBitWidth);
216 
217   // Fallback to i64 / double.
218   NumSclBits = (NumSclBits == 8 || NumSclBits == 16 || NumSclBits == 32)
219                    ? NumSclBits
220                    : 64;
221 
222   // Extract per-element bits.
223   return rebuildConstant(OriginalType->getContext(), SclTy, *Splat, NumSclBits);
224 }
225 
// Try to shrink the constant-pool data read by MI. Full-width vector loads
// become broadcast loads. AVX512 EVEX memory-fold instructions, and VEX logic
// ops when VLX is available without DQI, become their broadcast-fold variants.
// On success only MI's opcode and constant-pool index are changed (MI is
// rewritten in place, never erased) and true is returned.
// NOTE(review): MF and MBB are currently unused; the constant pool is reached
// through MI's parent function instead.
bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
                                                     MachineBasicBlock &MBB,
                                                     MachineInstr &MI) {
  unsigned Opc = MI.getOpcode();
  MachineConstantPool *CP = MI.getParent()->getParent()->getConstantPool();
  bool HasAVX2 = ST->hasAVX2();
  bool HasDQI = ST->hasDQI();
  bool HasBWI = ST->hasBWI();
  bool HasVLX = ST->hasVLX();

  // Rewrite MI to use a broadcast opcode if the constant referenced by the
  // memory operand starting at OperandNo is a splat. The OpBcstN arguments
  // name the broadcast opcode for an N-bit splat; 0 means "no broadcast of
  // that width is available". Widths are tried smallest first, so the
  // narrowest matching splat wins. Returns true if MI was rewritten.
  auto ConvertToBroadcast = [&](unsigned OpBcst256, unsigned OpBcst128,
                                unsigned OpBcst64, unsigned OpBcst32,
                                unsigned OpBcst16, unsigned OpBcst8,
                                unsigned OperandNo) {
    assert(MI.getNumOperands() >= (OperandNo + X86::AddrNumOperands) &&
           "Unexpected number of operands!");

    if (auto *C = X86::getConstantFromPool(MI, OperandNo)) {
      // Attempt to detect a suitable splat from increasing splat widths.
      std::pair<unsigned, unsigned> Broadcasts[] = {
          {8, OpBcst8},   {16, OpBcst16},   {32, OpBcst32},
          {64, OpBcst64}, {128, OpBcst128}, {256, OpBcst256},
      };
      for (auto [BitWidth, OpBcst] : Broadcasts) {
        if (OpBcst) {
          // Construct a suitable splat constant and adjust the MI to
          // use the new constant pool entry.
          if (Constant *NewCst = rebuildSplatableConstant(C, BitWidth)) {
            // The new pool entry is aligned to the broadcast width in bytes.
            unsigned NewCPI =
                CP->getConstantPoolIndex(NewCst, Align(BitWidth / 8));
            MI.setDesc(TII->get(OpBcst));
            MI.getOperand(OperandNo + X86::AddrDisp).setIndex(NewCPI);
            return true;
          }
        }
      }
    }
    return false;
  };

  // Attempt to convert full width vector loads into broadcast loads.
  // For all of these loads the memory operand starts at operand index 1.
  switch (Opc) {
  /* FP Loads */
  case X86::MOVAPDrm:
  case X86::MOVAPSrm:
  case X86::MOVUPDrm:
  case X86::MOVUPSrm:
    // TODO: SSE3 MOVDDUP Handling
    return false;
  case X86::VMOVAPDrm:
  case X86::VMOVAPSrm:
  case X86::VMOVUPDrm:
  case X86::VMOVUPSrm:
    return ConvertToBroadcast(0, 0, X86::VMOVDDUPrm, X86::VBROADCASTSSrm, 0, 0,
                              1);
  case X86::VMOVAPDYrm:
  case X86::VMOVAPSYrm:
  case X86::VMOVUPDYrm:
  case X86::VMOVUPSYrm:
    return ConvertToBroadcast(0, X86::VBROADCASTF128rm, X86::VBROADCASTSDYrm,
                              X86::VBROADCASTSSYrm, 0, 0, 1);
  case X86::VMOVAPDZ128rm:
  case X86::VMOVAPSZ128rm:
  case X86::VMOVUPDZ128rm:
  case X86::VMOVUPSZ128rm:
    return ConvertToBroadcast(0, 0, X86::VMOVDDUPZ128rm,
                              X86::VBROADCASTSSZ128rm, 0, 0, 1);
  case X86::VMOVAPDZ256rm:
  case X86::VMOVAPSZ256rm:
  case X86::VMOVUPDZ256rm:
  case X86::VMOVUPSZ256rm:
    return ConvertToBroadcast(0, X86::VBROADCASTF32X4Z256rm,
                              X86::VBROADCASTSDZ256rm, X86::VBROADCASTSSZ256rm,
                              0, 0, 1);
  case X86::VMOVAPDZrm:
  case X86::VMOVAPSZrm:
  case X86::VMOVUPDZrm:
  case X86::VMOVUPSZrm:
    return ConvertToBroadcast(X86::VBROADCASTF64X4rm, X86::VBROADCASTF32X4rm,
                              X86::VBROADCASTSDZrm, X86::VBROADCASTSSZrm, 0, 0,
                              1);
    /* Integer Loads */
    // VEX integer loads: use AVX2 integer broadcasts when available, else
    // the FP broadcasts (no 8/16-bit broadcast without AVX2).
  case X86::VMOVDQArm:
  case X86::VMOVDQUrm:
    return ConvertToBroadcast(
        0, 0, HasAVX2 ? X86::VPBROADCASTQrm : X86::VMOVDDUPrm,
        HasAVX2 ? X86::VPBROADCASTDrm : X86::VBROADCASTSSrm,
        HasAVX2 ? X86::VPBROADCASTWrm : 0, HasAVX2 ? X86::VPBROADCASTBrm : 0,
        1);
  case X86::VMOVDQAYrm:
  case X86::VMOVDQUYrm:
    return ConvertToBroadcast(
        0, HasAVX2 ? X86::VBROADCASTI128rm : X86::VBROADCASTF128rm,
        HasAVX2 ? X86::VPBROADCASTQYrm : X86::VBROADCASTSDYrm,
        HasAVX2 ? X86::VPBROADCASTDYrm : X86::VBROADCASTSSYrm,
        HasAVX2 ? X86::VPBROADCASTWYrm : 0, HasAVX2 ? X86::VPBROADCASTBYrm : 0,
        1);
    // EVEX integer loads: 8/16-bit broadcasts are only used with BWI.
  case X86::VMOVDQA32Z128rm:
  case X86::VMOVDQA64Z128rm:
  case X86::VMOVDQU32Z128rm:
  case X86::VMOVDQU64Z128rm:
    return ConvertToBroadcast(0, 0, X86::VPBROADCASTQZ128rm,
                              X86::VPBROADCASTDZ128rm,
                              HasBWI ? X86::VPBROADCASTWZ128rm : 0,
                              HasBWI ? X86::VPBROADCASTBZ128rm : 0, 1);
  case X86::VMOVDQA32Z256rm:
  case X86::VMOVDQA64Z256rm:
  case X86::VMOVDQU32Z256rm:
  case X86::VMOVDQU64Z256rm:
    return ConvertToBroadcast(0, X86::VBROADCASTI32X4Z256rm,
                              X86::VPBROADCASTQZ256rm, X86::VPBROADCASTDZ256rm,
                              HasBWI ? X86::VPBROADCASTWZ256rm : 0,
                              HasBWI ? X86::VPBROADCASTBZ256rm : 0, 1);
  case X86::VMOVDQA32Zrm:
  case X86::VMOVDQA64Zrm:
  case X86::VMOVDQU32Zrm:
  case X86::VMOVDQU64Zrm:
    return ConvertToBroadcast(X86::VBROADCASTI64X4rm, X86::VBROADCASTI32X4rm,
                              X86::VPBROADCASTQZrm, X86::VPBROADCASTDZrm,
                              HasBWI ? X86::VPBROADCASTWZrm : 0,
                              HasBWI ? X86::VPBROADCASTBZrm : 0, 1);
  }

  // Look up the 32-bit and 64-bit broadcast-fold variants of the given
  // memory-fold opcodes (0 = not requested). The TB_INDEX_MASK bits of the
  // fold-table flags are used as the memory operand index. If either variant
  // exists, hand off to ConvertToBroadcast with only the 32/64-bit slots
  // populated.
  auto ConvertToBroadcastAVX512 = [&](unsigned OpSrc32, unsigned OpSrc64) {
    unsigned OpBcst32 = 0, OpBcst64 = 0;
    unsigned OpNoBcst32 = 0, OpNoBcst64 = 0;
    if (OpSrc32) {
      if (const X86FoldTableEntry *Mem2Bcst =
              llvm::lookupBroadcastFoldTable(OpSrc32, 32)) {
        OpBcst32 = Mem2Bcst->DstOp;
        OpNoBcst32 = Mem2Bcst->Flags & TB_INDEX_MASK;
      }
    }
    if (OpSrc64) {
      if (const X86FoldTableEntry *Mem2Bcst =
              llvm::lookupBroadcastFoldTable(OpSrc64, 64)) {
        OpBcst64 = Mem2Bcst->DstOp;
        OpNoBcst64 = Mem2Bcst->Flags & TB_INDEX_MASK;
      }
    }
    // Both variants (when both exist) must fold the same operand.
    assert(((OpBcst32 == 0) || (OpBcst64 == 0) || (OpNoBcst32 == OpNoBcst64)) &&
           "OperandNo mismatch");

    if (OpBcst32 || OpBcst64) {
      unsigned OpNo = OpBcst32 == 0 ? OpNoBcst64 : OpNoBcst32;
      return ConvertToBroadcast(0, 0, OpBcst64, OpBcst32, 0, 0, OpNo);
    }
    return false;
  };

  // Attempt to find a AVX512 mapping from a full width memory-fold instruction
  // to a broadcast-fold instruction variant.
  if ((MI.getDesc().TSFlags & X86II::EncodingMask) == X86II::EVEX)
    return ConvertToBroadcastAVX512(Opc, Opc);

  // Reverse the X86InstrInfo::setExecutionDomainCustom EVEX->VEX logic
  // conversion to see if we can convert to a broadcasted (integer) logic op.
  // Each VEX FP/integer logic op maps to its EVEX D (32-bit) and Q (64-bit)
  // integer forms of the same vector length.
  if (HasVLX && !HasDQI) {
    unsigned OpSrc32 = 0, OpSrc64 = 0;
    switch (Opc) {
    case X86::VANDPDrm:
    case X86::VANDPSrm:
    case X86::VPANDrm:
      OpSrc32 = X86 ::VPANDDZ128rm;
      OpSrc64 = X86 ::VPANDQZ128rm;
      break;
    case X86::VANDPDYrm:
    case X86::VANDPSYrm:
    case X86::VPANDYrm:
      OpSrc32 = X86 ::VPANDDZ256rm;
      OpSrc64 = X86 ::VPANDQZ256rm;
      break;
    case X86::VANDNPDrm:
    case X86::VANDNPSrm:
    case X86::VPANDNrm:
      OpSrc32 = X86 ::VPANDNDZ128rm;
      OpSrc64 = X86 ::VPANDNQZ128rm;
      break;
    case X86::VANDNPDYrm:
    case X86::VANDNPSYrm:
    case X86::VPANDNYrm:
      OpSrc32 = X86 ::VPANDNDZ256rm;
      OpSrc64 = X86 ::VPANDNQZ256rm;
      break;
    case X86::VORPDrm:
    case X86::VORPSrm:
    case X86::VPORrm:
      OpSrc32 = X86 ::VPORDZ128rm;
      OpSrc64 = X86 ::VPORQZ128rm;
      break;
    case X86::VORPDYrm:
    case X86::VORPSYrm:
    case X86::VPORYrm:
      OpSrc32 = X86 ::VPORDZ256rm;
      OpSrc64 = X86 ::VPORQZ256rm;
      break;
    case X86::VXORPDrm:
    case X86::VXORPSrm:
    case X86::VPXORrm:
      OpSrc32 = X86 ::VPXORDZ128rm;
      OpSrc64 = X86 ::VPXORQZ128rm;
      break;
    case X86::VXORPDYrm:
    case X86::VXORPSYrm:
    case X86::VPXORYrm:
      OpSrc32 = X86 ::VPXORDZ256rm;
      OpSrc64 = X86 ::VPXORQZ256rm;
      break;
    }
    if (OpSrc32 || OpSrc64)
      return ConvertToBroadcastAVX512(OpSrc32, OpSrc64);
  }

  return false;
}
441 
442 bool X86FixupVectorConstantsPass::runOnMachineFunction(MachineFunction &MF) {
443   LLVM_DEBUG(dbgs() << "Start X86FixupVectorConstants\n";);
444   bool Changed = false;
445   ST = &MF.getSubtarget<X86Subtarget>();
446   TII = ST->getInstrInfo();
447   SM = &ST->getSchedModel();
448 
449   for (MachineBasicBlock &MBB : MF) {
450     for (MachineInstr &MI : MBB) {
451       if (processInstruction(MF, MBB, MI)) {
452         ++NumInstChanges;
453         Changed = true;
454       }
455     }
456   }
457   LLVM_DEBUG(dbgs() << "End X86FixupVectorConstants\n";);
458   return Changed;
459 }
460