1 //===-- X86FixupVectorConstants.cpp - optimize constant generation -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file examines all full size vector constant pool loads and attempts to
10 // replace them with smaller constant pool entries, including:
11 // * Converting AVX512 memory-fold instructions to their broadcast-fold form
12 // * Broadcasting of full width loads.
13 // * TODO: Sign/Zero extension of full width loads.
14 //
15 //===----------------------------------------------------------------------===//
16
17 #include "X86.h"
18 #include "X86InstrFoldTables.h"
19 #include "X86InstrInfo.h"
20 #include "X86Subtarget.h"
21 #include "llvm/ADT/Statistic.h"
22 #include "llvm/CodeGen/MachineConstantPool.h"
23
24 using namespace llvm;
25
26 #define DEBUG_TYPE "x86-fixup-vector-constants"
27
28 STATISTIC(NumInstChanges, "Number of instructions changes");
29
30 namespace {
31 class X86FixupVectorConstantsPass : public MachineFunctionPass {
32 public:
33 static char ID;
34
X86FixupVectorConstantsPass()35 X86FixupVectorConstantsPass() : MachineFunctionPass(ID) {}
36
getPassName() const37 StringRef getPassName() const override {
38 return "X86 Fixup Vector Constants";
39 }
40
41 bool runOnMachineFunction(MachineFunction &MF) override;
42 bool processInstruction(MachineFunction &MF, MachineBasicBlock &MBB,
43 MachineInstr &MI);
44
45 // This pass runs after regalloc and doesn't support VReg operands.
getRequiredProperties() const46 MachineFunctionProperties getRequiredProperties() const override {
47 return MachineFunctionProperties().set(
48 MachineFunctionProperties::Property::NoVRegs);
49 }
50
51 private:
52 const X86InstrInfo *TII = nullptr;
53 const X86Subtarget *ST = nullptr;
54 const MCSchedModel *SM = nullptr;
55 };
56 } // end anonymous namespace
57
58 char X86FixupVectorConstantsPass::ID = 0;
59
INITIALIZE_PASS(X86FixupVectorConstantsPass,DEBUG_TYPE,DEBUG_TYPE,false,false)60 INITIALIZE_PASS(X86FixupVectorConstantsPass, DEBUG_TYPE, DEBUG_TYPE, false, false)
61
62 FunctionPass *llvm::createX86FixupVectorConstants() {
63 return new X86FixupVectorConstantsPass();
64 }
65
66 // Attempt to extract the full width of bits data from the constant.
extractConstantBits(const Constant * C)67 static std::optional<APInt> extractConstantBits(const Constant *C) {
68 unsigned NumBits = C->getType()->getPrimitiveSizeInBits();
69
70 if (auto *CInt = dyn_cast<ConstantInt>(C))
71 return CInt->getValue();
72
73 if (auto *CFP = dyn_cast<ConstantFP>(C))
74 return CFP->getValue().bitcastToAPInt();
75
76 if (auto *CV = dyn_cast<ConstantVector>(C)) {
77 if (auto *CVSplat = CV->getSplatValue(/*AllowUndefs*/ true)) {
78 if (std::optional<APInt> Bits = extractConstantBits(CVSplat)) {
79 assert((NumBits % Bits->getBitWidth()) == 0 && "Illegal splat");
80 return APInt::getSplat(NumBits, *Bits);
81 }
82 }
83 }
84
85 if (auto *CDS = dyn_cast<ConstantDataSequential>(C)) {
86 bool IsInteger = CDS->getElementType()->isIntegerTy();
87 bool IsFloat = CDS->getElementType()->isHalfTy() ||
88 CDS->getElementType()->isBFloatTy() ||
89 CDS->getElementType()->isFloatTy() ||
90 CDS->getElementType()->isDoubleTy();
91 if (IsInteger || IsFloat) {
92 APInt Bits = APInt::getZero(NumBits);
93 unsigned EltBits = CDS->getElementType()->getPrimitiveSizeInBits();
94 for (unsigned I = 0, E = CDS->getNumElements(); I != E; ++I) {
95 if (IsInteger)
96 Bits.insertBits(CDS->getElementAsAPInt(I), I * EltBits);
97 else
98 Bits.insertBits(CDS->getElementAsAPFloat(I).bitcastToAPInt(),
99 I * EltBits);
100 }
101 return Bits;
102 }
103 }
104
105 return std::nullopt;
106 }
107
108 // Attempt to compute the splat width of bits data by normalizing the splat to
109 // remove undefs.
getSplatableConstant(const Constant * C,unsigned SplatBitWidth)110 static std::optional<APInt> getSplatableConstant(const Constant *C,
111 unsigned SplatBitWidth) {
112 const Type *Ty = C->getType();
113 assert((Ty->getPrimitiveSizeInBits() % SplatBitWidth) == 0 &&
114 "Illegal splat width");
115
116 if (std::optional<APInt> Bits = extractConstantBits(C))
117 if (Bits->isSplat(SplatBitWidth))
118 return Bits->trunc(SplatBitWidth);
119
120 // Detect general splats with undefs.
121 // TODO: Do we need to handle NumEltsBits > SplatBitWidth splitting?
122 if (auto *CV = dyn_cast<ConstantVector>(C)) {
123 unsigned NumOps = CV->getNumOperands();
124 unsigned NumEltsBits = Ty->getScalarSizeInBits();
125 unsigned NumScaleOps = SplatBitWidth / NumEltsBits;
126 if ((SplatBitWidth % NumEltsBits) == 0) {
127 // Collect the elements and ensure that within the repeated splat sequence
128 // they either match or are undef.
129 SmallVector<Constant *, 16> Sequence(NumScaleOps, nullptr);
130 for (unsigned Idx = 0; Idx != NumOps; ++Idx) {
131 if (Constant *Elt = CV->getAggregateElement(Idx)) {
132 if (isa<UndefValue>(Elt))
133 continue;
134 unsigned SplatIdx = Idx % NumScaleOps;
135 if (!Sequence[SplatIdx] || Sequence[SplatIdx] == Elt) {
136 Sequence[SplatIdx] = Elt;
137 continue;
138 }
139 }
140 return std::nullopt;
141 }
142 // Extract the constant bits forming the splat and insert into the bits
143 // data, leave undef as zero.
144 APInt SplatBits = APInt::getZero(SplatBitWidth);
145 for (unsigned I = 0; I != NumScaleOps; ++I) {
146 if (!Sequence[I])
147 continue;
148 if (std::optional<APInt> Bits = extractConstantBits(Sequence[I])) {
149 SplatBits.insertBits(*Bits, I * Bits->getBitWidth());
150 continue;
151 }
152 return std::nullopt;
153 }
154 return SplatBits;
155 }
156 }
157
158 return std::nullopt;
159 }
160
161 // Split raw bits into a constant vector of elements of a specific bit width.
162 // NOTE: We don't always bother converting to scalars if the vector length is 1.
rebuildConstant(LLVMContext & Ctx,Type * SclTy,const APInt & Bits,unsigned NumSclBits)163 static Constant *rebuildConstant(LLVMContext &Ctx, Type *SclTy,
164 const APInt &Bits, unsigned NumSclBits) {
165 unsigned BitWidth = Bits.getBitWidth();
166
167 if (NumSclBits == 8) {
168 SmallVector<uint8_t> RawBits;
169 for (unsigned I = 0; I != BitWidth; I += 8)
170 RawBits.push_back(Bits.extractBits(8, I).getZExtValue());
171 return ConstantDataVector::get(Ctx, RawBits);
172 }
173
174 if (NumSclBits == 16) {
175 SmallVector<uint16_t> RawBits;
176 for (unsigned I = 0; I != BitWidth; I += 16)
177 RawBits.push_back(Bits.extractBits(16, I).getZExtValue());
178 if (SclTy->is16bitFPTy())
179 return ConstantDataVector::getFP(SclTy, RawBits);
180 return ConstantDataVector::get(Ctx, RawBits);
181 }
182
183 if (NumSclBits == 32) {
184 SmallVector<uint32_t> RawBits;
185 for (unsigned I = 0; I != BitWidth; I += 32)
186 RawBits.push_back(Bits.extractBits(32, I).getZExtValue());
187 if (SclTy->isFloatTy())
188 return ConstantDataVector::getFP(SclTy, RawBits);
189 return ConstantDataVector::get(Ctx, RawBits);
190 }
191
192 assert(NumSclBits == 64 && "Unhandled vector element width");
193
194 SmallVector<uint64_t> RawBits;
195 for (unsigned I = 0; I != BitWidth; I += 64)
196 RawBits.push_back(Bits.extractBits(64, I).getZExtValue());
197 if (SclTy->isDoubleTy())
198 return ConstantDataVector::getFP(SclTy, RawBits);
199 return ConstantDataVector::get(Ctx, RawBits);
200 }
201
202 // Attempt to rebuild a normalized splat vector constant of the requested splat
203 // width, built up of potentially smaller scalar values.
rebuildSplatableConstant(const Constant * C,unsigned SplatBitWidth)204 static Constant *rebuildSplatableConstant(const Constant *C,
205 unsigned SplatBitWidth) {
206 std::optional<APInt> Splat = getSplatableConstant(C, SplatBitWidth);
207 if (!Splat)
208 return nullptr;
209
210 // Determine scalar size to use for the constant splat vector, clamping as we
211 // might have found a splat smaller than the original constant data.
212 const Type *OriginalType = C->getType();
213 Type *SclTy = OriginalType->getScalarType();
214 unsigned NumSclBits = SclTy->getPrimitiveSizeInBits();
215 NumSclBits = std::min<unsigned>(NumSclBits, SplatBitWidth);
216
217 // Fallback to i64 / double.
218 NumSclBits = (NumSclBits == 8 || NumSclBits == 16 || NumSclBits == 32)
219 ? NumSclBits
220 : 64;
221
222 // Extract per-element bits.
223 return rebuildConstant(OriginalType->getContext(), SclTy, *Splat, NumSclBits);
224 }
225
processInstruction(MachineFunction & MF,MachineBasicBlock & MBB,MachineInstr & MI)226 bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
227 MachineBasicBlock &MBB,
228 MachineInstr &MI) {
229 unsigned Opc = MI.getOpcode();
230 MachineConstantPool *CP = MI.getParent()->getParent()->getConstantPool();
231 bool HasAVX2 = ST->hasAVX2();
232 bool HasDQI = ST->hasDQI();
233 bool HasBWI = ST->hasBWI();
234 bool HasVLX = ST->hasVLX();
235
236 auto ConvertToBroadcast = [&](unsigned OpBcst256, unsigned OpBcst128,
237 unsigned OpBcst64, unsigned OpBcst32,
238 unsigned OpBcst16, unsigned OpBcst8,
239 unsigned OperandNo) {
240 assert(MI.getNumOperands() >= (OperandNo + X86::AddrNumOperands) &&
241 "Unexpected number of operands!");
242
243 if (auto *C = X86::getConstantFromPool(MI, OperandNo)) {
244 // Attempt to detect a suitable splat from increasing splat widths.
245 std::pair<unsigned, unsigned> Broadcasts[] = {
246 {8, OpBcst8}, {16, OpBcst16}, {32, OpBcst32},
247 {64, OpBcst64}, {128, OpBcst128}, {256, OpBcst256},
248 };
249 for (auto [BitWidth, OpBcst] : Broadcasts) {
250 if (OpBcst) {
251 // Construct a suitable splat constant and adjust the MI to
252 // use the new constant pool entry.
253 if (Constant *NewCst = rebuildSplatableConstant(C, BitWidth)) {
254 unsigned NewCPI =
255 CP->getConstantPoolIndex(NewCst, Align(BitWidth / 8));
256 MI.setDesc(TII->get(OpBcst));
257 MI.getOperand(OperandNo + X86::AddrDisp).setIndex(NewCPI);
258 return true;
259 }
260 }
261 }
262 }
263 return false;
264 };
265
266 // Attempt to convert full width vector loads into broadcast loads.
267 switch (Opc) {
268 /* FP Loads */
269 case X86::MOVAPDrm:
270 case X86::MOVAPSrm:
271 case X86::MOVUPDrm:
272 case X86::MOVUPSrm:
273 // TODO: SSE3 MOVDDUP Handling
274 return false;
275 case X86::VMOVAPDrm:
276 case X86::VMOVAPSrm:
277 case X86::VMOVUPDrm:
278 case X86::VMOVUPSrm:
279 return ConvertToBroadcast(0, 0, X86::VMOVDDUPrm, X86::VBROADCASTSSrm, 0, 0,
280 1);
281 case X86::VMOVAPDYrm:
282 case X86::VMOVAPSYrm:
283 case X86::VMOVUPDYrm:
284 case X86::VMOVUPSYrm:
285 return ConvertToBroadcast(0, X86::VBROADCASTF128rm, X86::VBROADCASTSDYrm,
286 X86::VBROADCASTSSYrm, 0, 0, 1);
287 case X86::VMOVAPDZ128rm:
288 case X86::VMOVAPSZ128rm:
289 case X86::VMOVUPDZ128rm:
290 case X86::VMOVUPSZ128rm:
291 return ConvertToBroadcast(0, 0, X86::VMOVDDUPZ128rm,
292 X86::VBROADCASTSSZ128rm, 0, 0, 1);
293 case X86::VMOVAPDZ256rm:
294 case X86::VMOVAPSZ256rm:
295 case X86::VMOVUPDZ256rm:
296 case X86::VMOVUPSZ256rm:
297 return ConvertToBroadcast(0, X86::VBROADCASTF32X4Z256rm,
298 X86::VBROADCASTSDZ256rm, X86::VBROADCASTSSZ256rm,
299 0, 0, 1);
300 case X86::VMOVAPDZrm:
301 case X86::VMOVAPSZrm:
302 case X86::VMOVUPDZrm:
303 case X86::VMOVUPSZrm:
304 return ConvertToBroadcast(X86::VBROADCASTF64X4rm, X86::VBROADCASTF32X4rm,
305 X86::VBROADCASTSDZrm, X86::VBROADCASTSSZrm, 0, 0,
306 1);
307 /* Integer Loads */
308 case X86::VMOVDQArm:
309 case X86::VMOVDQUrm:
310 return ConvertToBroadcast(
311 0, 0, HasAVX2 ? X86::VPBROADCASTQrm : X86::VMOVDDUPrm,
312 HasAVX2 ? X86::VPBROADCASTDrm : X86::VBROADCASTSSrm,
313 HasAVX2 ? X86::VPBROADCASTWrm : 0, HasAVX2 ? X86::VPBROADCASTBrm : 0,
314 1);
315 case X86::VMOVDQAYrm:
316 case X86::VMOVDQUYrm:
317 return ConvertToBroadcast(
318 0, HasAVX2 ? X86::VBROADCASTI128rm : X86::VBROADCASTF128rm,
319 HasAVX2 ? X86::VPBROADCASTQYrm : X86::VBROADCASTSDYrm,
320 HasAVX2 ? X86::VPBROADCASTDYrm : X86::VBROADCASTSSYrm,
321 HasAVX2 ? X86::VPBROADCASTWYrm : 0, HasAVX2 ? X86::VPBROADCASTBYrm : 0,
322 1);
323 case X86::VMOVDQA32Z128rm:
324 case X86::VMOVDQA64Z128rm:
325 case X86::VMOVDQU32Z128rm:
326 case X86::VMOVDQU64Z128rm:
327 return ConvertToBroadcast(0, 0, X86::VPBROADCASTQZ128rm,
328 X86::VPBROADCASTDZ128rm,
329 HasBWI ? X86::VPBROADCASTWZ128rm : 0,
330 HasBWI ? X86::VPBROADCASTBZ128rm : 0, 1);
331 case X86::VMOVDQA32Z256rm:
332 case X86::VMOVDQA64Z256rm:
333 case X86::VMOVDQU32Z256rm:
334 case X86::VMOVDQU64Z256rm:
335 return ConvertToBroadcast(0, X86::VBROADCASTI32X4Z256rm,
336 X86::VPBROADCASTQZ256rm, X86::VPBROADCASTDZ256rm,
337 HasBWI ? X86::VPBROADCASTWZ256rm : 0,
338 HasBWI ? X86::VPBROADCASTBZ256rm : 0, 1);
339 case X86::VMOVDQA32Zrm:
340 case X86::VMOVDQA64Zrm:
341 case X86::VMOVDQU32Zrm:
342 case X86::VMOVDQU64Zrm:
343 return ConvertToBroadcast(X86::VBROADCASTI64X4rm, X86::VBROADCASTI32X4rm,
344 X86::VPBROADCASTQZrm, X86::VPBROADCASTDZrm,
345 HasBWI ? X86::VPBROADCASTWZrm : 0,
346 HasBWI ? X86::VPBROADCASTBZrm : 0, 1);
347 }
348
349 auto ConvertToBroadcastAVX512 = [&](unsigned OpSrc32, unsigned OpSrc64) {
350 unsigned OpBcst32 = 0, OpBcst64 = 0;
351 unsigned OpNoBcst32 = 0, OpNoBcst64 = 0;
352 if (OpSrc32) {
353 if (const X86FoldTableEntry *Mem2Bcst =
354 llvm::lookupBroadcastFoldTable(OpSrc32, 32)) {
355 OpBcst32 = Mem2Bcst->DstOp;
356 OpNoBcst32 = Mem2Bcst->Flags & TB_INDEX_MASK;
357 }
358 }
359 if (OpSrc64) {
360 if (const X86FoldTableEntry *Mem2Bcst =
361 llvm::lookupBroadcastFoldTable(OpSrc64, 64)) {
362 OpBcst64 = Mem2Bcst->DstOp;
363 OpNoBcst64 = Mem2Bcst->Flags & TB_INDEX_MASK;
364 }
365 }
366 assert(((OpBcst32 == 0) || (OpBcst64 == 0) || (OpNoBcst32 == OpNoBcst64)) &&
367 "OperandNo mismatch");
368
369 if (OpBcst32 || OpBcst64) {
370 unsigned OpNo = OpBcst32 == 0 ? OpNoBcst64 : OpNoBcst32;
371 return ConvertToBroadcast(0, 0, OpBcst64, OpBcst32, 0, 0, OpNo);
372 }
373 return false;
374 };
375
376 // Attempt to find a AVX512 mapping from a full width memory-fold instruction
377 // to a broadcast-fold instruction variant.
378 if ((MI.getDesc().TSFlags & X86II::EncodingMask) == X86II::EVEX)
379 return ConvertToBroadcastAVX512(Opc, Opc);
380
381 // Reverse the X86InstrInfo::setExecutionDomainCustom EVEX->VEX logic
382 // conversion to see if we can convert to a broadcasted (integer) logic op.
383 if (HasVLX && !HasDQI) {
384 unsigned OpSrc32 = 0, OpSrc64 = 0;
385 switch (Opc) {
386 case X86::VANDPDrm:
387 case X86::VANDPSrm:
388 case X86::VPANDrm:
389 OpSrc32 = X86 ::VPANDDZ128rm;
390 OpSrc64 = X86 ::VPANDQZ128rm;
391 break;
392 case X86::VANDPDYrm:
393 case X86::VANDPSYrm:
394 case X86::VPANDYrm:
395 OpSrc32 = X86 ::VPANDDZ256rm;
396 OpSrc64 = X86 ::VPANDQZ256rm;
397 break;
398 case X86::VANDNPDrm:
399 case X86::VANDNPSrm:
400 case X86::VPANDNrm:
401 OpSrc32 = X86 ::VPANDNDZ128rm;
402 OpSrc64 = X86 ::VPANDNQZ128rm;
403 break;
404 case X86::VANDNPDYrm:
405 case X86::VANDNPSYrm:
406 case X86::VPANDNYrm:
407 OpSrc32 = X86 ::VPANDNDZ256rm;
408 OpSrc64 = X86 ::VPANDNQZ256rm;
409 break;
410 case X86::VORPDrm:
411 case X86::VORPSrm:
412 case X86::VPORrm:
413 OpSrc32 = X86 ::VPORDZ128rm;
414 OpSrc64 = X86 ::VPORQZ128rm;
415 break;
416 case X86::VORPDYrm:
417 case X86::VORPSYrm:
418 case X86::VPORYrm:
419 OpSrc32 = X86 ::VPORDZ256rm;
420 OpSrc64 = X86 ::VPORQZ256rm;
421 break;
422 case X86::VXORPDrm:
423 case X86::VXORPSrm:
424 case X86::VPXORrm:
425 OpSrc32 = X86 ::VPXORDZ128rm;
426 OpSrc64 = X86 ::VPXORQZ128rm;
427 break;
428 case X86::VXORPDYrm:
429 case X86::VXORPSYrm:
430 case X86::VPXORYrm:
431 OpSrc32 = X86 ::VPXORDZ256rm;
432 OpSrc64 = X86 ::VPXORQZ256rm;
433 break;
434 }
435 if (OpSrc32 || OpSrc64)
436 return ConvertToBroadcastAVX512(OpSrc32, OpSrc64);
437 }
438
439 return false;
440 }
441
runOnMachineFunction(MachineFunction & MF)442 bool X86FixupVectorConstantsPass::runOnMachineFunction(MachineFunction &MF) {
443 LLVM_DEBUG(dbgs() << "Start X86FixupVectorConstants\n";);
444 bool Changed = false;
445 ST = &MF.getSubtarget<X86Subtarget>();
446 TII = ST->getInstrInfo();
447 SM = &ST->getSchedModel();
448
449 for (MachineBasicBlock &MBB : MF) {
450 for (MachineInstr &MI : MBB) {
451 if (processInstruction(MF, MBB, MI)) {
452 ++NumInstChanges;
453 Changed = true;
454 }
455 }
456 }
457 LLVM_DEBUG(dbgs() << "End X86FixupVectorConstants\n";);
458 return Changed;
459 }
460