1 //=== AArch64PostLegalizerLowering.cpp --------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// Post-legalization lowering for instructions.
11 ///
12 /// This is used to offload pattern matching from the selector.
13 ///
14 /// For example, this combiner will notice that a G_SHUFFLE_VECTOR is actually
15 /// a G_ZIP, G_UZP, etc.
16 ///
17 /// General optimization combines should be handled by either the
18 /// AArch64PostLegalizerCombiner or the AArch64PreLegalizerCombiner.
19 ///
20 //===----------------------------------------------------------------------===//
21
22 #include "AArch64GlobalISelUtils.h"
23 #include "AArch64Subtarget.h"
24 #include "AArch64TargetMachine.h"
25 #include "GISel/AArch64LegalizerInfo.h"
26 #include "MCTargetDesc/AArch64MCTargetDesc.h"
27 #include "TargetInfo/AArch64TargetInfo.h"
28 #include "Utils/AArch64BaseInfo.h"
29 #include "llvm/CodeGen/GlobalISel/Combiner.h"
30 #include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
31 #include "llvm/CodeGen/GlobalISel/CombinerInfo.h"
32 #include "llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h"
33 #include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
34 #include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
35 #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
36 #include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
37 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
38 #include "llvm/CodeGen/GlobalISel/Utils.h"
39 #include "llvm/CodeGen/MachineFunctionPass.h"
40 #include "llvm/CodeGen/MachineInstrBuilder.h"
41 #include "llvm/CodeGen/MachineRegisterInfo.h"
42 #include "llvm/CodeGen/TargetOpcodes.h"
43 #include "llvm/CodeGen/TargetPassConfig.h"
44 #include "llvm/IR/InstrTypes.h"
45 #include "llvm/InitializePasses.h"
46 #include "llvm/Support/Debug.h"
47 #include "llvm/Support/ErrorHandling.h"
48 #include <optional>
49
50 #define GET_GICOMBINER_DEPS
51 #include "AArch64GenPostLegalizeGILowering.inc"
52 #undef GET_GICOMBINER_DEPS
53
54 #define DEBUG_TYPE "aarch64-postlegalizer-lowering"
55
56 using namespace llvm;
57 using namespace MIPatternMatch;
58 using namespace AArch64GISelUtils;
59
60 namespace {
61
62 #define GET_GICOMBINER_TYPES
63 #include "AArch64GenPostLegalizeGILowering.inc"
64 #undef GET_GICOMBINER_TYPES
65
66 /// Represents a pseudo instruction which replaces a G_SHUFFLE_VECTOR.
67 ///
68 /// Used for matching target-supported shuffles before codegen.
69 struct ShuffleVectorPseudo {
70 unsigned Opc; ///< Opcode for the instruction. (E.g. G_ZIP1)
71 Register Dst; ///< Destination register.
72 SmallVector<SrcOp, 2> SrcOps; ///< Source registers.
73 ShuffleVectorPseudo(unsigned Opc, Register Dst,
74 std::initializer_list<SrcOp> SrcOps)
75 : Opc(Opc), Dst(Dst), SrcOps(SrcOps) {}
76 ShuffleVectorPseudo() = default;
77 };
78
79 /// Check if a vector shuffle corresponds to a REV instruction with the
80 /// specified blocksize.
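///
/// For example, with EltSize = 32 and BlockSize = 64 on a <4 x s32> shuffle,
/// each 64-bit block holds two elements, so the expected G_REV64 mask is
/// <1, 0, 3, 2>.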
81 bool isREVMask(ArrayRef<int> M, unsigned EltSize, unsigned NumElts,
82 unsigned BlockSize) {
83 assert((BlockSize == 16 || BlockSize == 32 || BlockSize == 64) &&
84 "Only possible block sizes for REV are: 16, 32, 64");
85 assert(EltSize != 64 && "EltSize cannot be 64 for REV mask.");
86
87 unsigned BlockElts = M[0] + 1;
88
89 // If the first shuffle index is UNDEF, be optimistic.
90 if (M[0] < 0)
91 BlockElts = BlockSize / EltSize;
92
93 if (BlockSize <= EltSize || BlockSize != BlockElts * EltSize)
94 return false;
95
96 for (unsigned i = 0; i < NumElts; ++i) {
97 // Ignore undef indices.
98 if (M[i] < 0)
99 continue;
100 if (static_cast<unsigned>(M[i]) !=
101 (i - i % BlockElts) + (BlockElts - 1 - i % BlockElts))
102 return false;
103 }
104
105 return true;
106 }
107
108 /// Determines if \p M is a shuffle vector mask for a TRN of \p NumElts.
109 /// Whether or not G_TRN1 or G_TRN2 should be used is stored in \p WhichResult.
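///
/// For example, with NumElts = 4, the G_TRN1 mask is <0, 4, 2, 6> and the
/// G_TRN2 mask is <1, 5, 3, 7> (undef indices are tolerated).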
110 bool isTRNMask(ArrayRef<int> M, unsigned NumElts, unsigned &WhichResult) {
111 if (NumElts % 2 != 0)
112 return false;
113 WhichResult = (M[0] == 0 ? 0 : 1);
114 for (unsigned i = 0; i < NumElts; i += 2) {
115 if ((M[i] >= 0 && static_cast<unsigned>(M[i]) != i + WhichResult) ||
116 (M[i + 1] >= 0 &&
117 static_cast<unsigned>(M[i + 1]) != i + NumElts + WhichResult))
118 return false;
119 }
120 return true;
121 }
122
123 /// Check if a G_EXT instruction can handle a shuffle mask \p M when the vector
124 /// sources of the shuffle are different.
125 std::optional<std::pair<bool, uint64_t>> getExtMask(ArrayRef<int> M,
126 unsigned NumElts) {
127 // Look for the first non-undef element.
128 auto FirstRealElt = find_if(M, [](int Elt) { return Elt >= 0; });
129 if (FirstRealElt == M.end())
130 return std::nullopt;
131
132 // Use APInt to handle overflow when calculating expected element.
133 unsigned MaskBits = APInt(32, NumElts * 2).logBase2();
134 APInt ExpectedElt = APInt(MaskBits, *FirstRealElt + 1);
135
136 // The following shuffle indices must be the successive elements after the
137 // first real element.
138 if (any_of(
139 make_range(std::next(FirstRealElt), M.end()),
140 [&ExpectedElt](int Elt) { return Elt != ExpectedElt++ && Elt >= 0; }))
141 return std::nullopt;
142
143 // The index of an EXT is the first element if it is not UNDEF.
144 // Watch out for the beginning UNDEFs. The EXT index should be the expected
145 // value of the first element. E.g.
146 // <-1, -1, 3, ...> is treated as <1, 2, 3, ...>.
147 // <-1, -1, 0, 1, ...> is treated as <2*NumElts-2, 2*NumElts-1, 0, 1, ...>.
148 // ExpectedElt is the last mask index plus 1.
149 uint64_t Imm = ExpectedElt.getZExtValue();
150 bool ReverseExt = false;
151
152 // There are two different cases that require reversing the input vectors.
153 // For example, for vector <4 x i32> we have the following cases,
154 // Case 1: shufflevector(<4 x i32>,<4 x i32>,<-1, -1, -1, 0>)
155 // Case 2: shufflevector(<4 x i32>,<4 x i32>,<-1, -1, 7, 0>)
156 // For both cases, we finally use mask <5, 6, 7, 0>, which requires
157 // to reverse two input vectors.
158 if (Imm < NumElts)
159 ReverseExt = true;
160 else
161 Imm -= NumElts;
162 return std::make_pair(ReverseExt, Imm);
163 }
164
165 /// Determines if \p M is a shuffle vector mask for a UZP of \p NumElts.
166 /// Whether or not G_UZP1 or G_UZP2 should be used is stored in \p WhichResult.
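///
/// For example, with NumElts = 4, the G_UZP1 mask is <0, 2, 4, 6> and the
/// G_UZP2 mask is <1, 3, 5, 7> (undef indices are tolerated).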
167 bool isUZPMask(ArrayRef<int> M, unsigned NumElts, unsigned &WhichResult) {
168 WhichResult = (M[0] == 0 ? 0 : 1);
169 for (unsigned i = 0; i != NumElts; ++i) {
170 // Skip undef indices.
171 if (M[i] < 0)
172 continue;
173 if (static_cast<unsigned>(M[i]) != 2 * i + WhichResult)
174 return false;
175 }
176 return true;
177 }
178
179 /// \return true if \p M is a zip mask for a shuffle vector of \p NumElts.
180 /// Whether or not G_ZIP1 or G_ZIP2 should be used is stored in \p WhichResult.
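///
/// For example, with NumElts = 4, the G_ZIP1 mask is <0, 4, 1, 5> and the
/// G_ZIP2 mask is <2, 6, 3, 7> (undef indices are tolerated).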
181 bool isZipMask(ArrayRef<int> M, unsigned NumElts, unsigned &WhichResult) {
182 if (NumElts % 2 != 0)
183 return false;
184
185 // 0 means use ZIP1, 1 means use ZIP2.
186 WhichResult = (M[0] == 0 ? 0 : 1);
187 unsigned Idx = WhichResult * NumElts / 2;
188 for (unsigned i = 0; i != NumElts; i += 2) {
189 if ((M[i] >= 0 && static_cast<unsigned>(M[i]) != Idx) ||
190 (M[i + 1] >= 0 && static_cast<unsigned>(M[i + 1]) != Idx + NumElts))
191 return false;
192 Idx += 1;
193 }
194 return true;
195 }
196
197 /// Helper function for matchINS.
198 ///
199 /// \returns a value when \p M is an ins mask for \p NumInputElements.
200 ///
201 /// First element of the returned pair is true when the produced
202 /// G_INSERT_VECTOR_ELT destination should be the LHS of the G_SHUFFLE_VECTOR.
203 ///
204 /// Second element is the destination lane for the G_INSERT_VECTOR_ELT.
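///
/// For example, with NumInputElements = 4, the mask <0, 1, 6, 3> matches the
/// LHS everywhere except lane 2, so this returns (true, 2): lane 2 of the
/// result is filled from element 6 - 4 = 2 of the RHS.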
205 std::optional<std::pair<bool, int>> isINSMask(ArrayRef<int> M,
206 int NumInputElements) {
207 if (M.size() != static_cast<size_t>(NumInputElements))
208 return std::nullopt;
209 int NumLHSMatch = 0, NumRHSMatch = 0;
210 int LastLHSMismatch = -1, LastRHSMismatch = -1;
211 for (int Idx = 0; Idx < NumInputElements; ++Idx) {
212 if (M[Idx] == -1) {
213 ++NumLHSMatch;
214 ++NumRHSMatch;
215 continue;
216 }
217 M[Idx] == Idx ? ++NumLHSMatch : LastLHSMismatch = Idx;
218 M[Idx] == Idx + NumInputElements ? ++NumRHSMatch : LastRHSMismatch = Idx;
219 }
220 const int NumNeededToMatch = NumInputElements - 1;
221 if (NumLHSMatch == NumNeededToMatch)
222 return std::make_pair(true, LastLHSMismatch);
223 if (NumRHSMatch == NumNeededToMatch)
224 return std::make_pair(false, LastRHSMismatch);
225 return std::nullopt;
226 }
227
228 /// \return true if a G_SHUFFLE_VECTOR instruction \p MI can be replaced with a
229 /// G_REV instruction. Stores the appropriate pseudo in \p MatchInfo.
230 bool matchREV(MachineInstr &MI, MachineRegisterInfo &MRI,
231 ShuffleVectorPseudo &MatchInfo) {
232 assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
233 ArrayRef<int> ShuffleMask = MI.getOperand(3).getShuffleMask();
234 Register Dst = MI.getOperand(0).getReg();
235 Register Src = MI.getOperand(1).getReg();
236 LLT Ty = MRI.getType(Dst);
237 unsigned EltSize = Ty.getScalarSizeInBits();
238
239 // Element size for a rev cannot be 64.
240 if (EltSize == 64)
241 return false;
242
243 unsigned NumElts = Ty.getNumElements();
244
245 // Try to produce G_REV64
246 if (isREVMask(ShuffleMask, EltSize, NumElts, 64)) {
247 MatchInfo = ShuffleVectorPseudo(AArch64::G_REV64, Dst, {Src});
248 return true;
249 }
250
251 // TODO: Produce G_REV32 and G_REV16 once we have proper legalization support.
252 // This should be identical to above, but with a constant 32 and constant
253 // 16.
254 return false;
255 }
256
257 /// \return true if a G_SHUFFLE_VECTOR instruction \p MI can be replaced with
258 /// a G_TRN1 or G_TRN2 instruction.
259 bool matchTRN(MachineInstr &MI, MachineRegisterInfo &MRI,
260 ShuffleVectorPseudo &MatchInfo) {
261 assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
262 unsigned WhichResult;
263 ArrayRef<int> ShuffleMask = MI.getOperand(3).getShuffleMask();
264 Register Dst = MI.getOperand(0).getReg();
265 unsigned NumElts = MRI.getType(Dst).getNumElements();
266 if (!isTRNMask(ShuffleMask, NumElts, WhichResult))
267 return false;
268 unsigned Opc = (WhichResult == 0) ? AArch64::G_TRN1 : AArch64::G_TRN2;
269 Register V1 = MI.getOperand(1).getReg();
270 Register V2 = MI.getOperand(2).getReg();
271 MatchInfo = ShuffleVectorPseudo(Opc, Dst, {V1, V2});
272 return true;
273 }
274
275 /// \return true if a G_SHUFFLE_VECTOR instruction \p MI can be replaced with
276 /// a G_UZP1 or G_UZP2 instruction.
277 ///
278 /// \param [in] MI - The shuffle vector instruction.
279 /// \param [out] MatchInfo - Either G_UZP1 or G_UZP2 on success.
280 bool matchUZP(MachineInstr &MI, MachineRegisterInfo &MRI,
281 ShuffleVectorPseudo &MatchInfo) {
282 assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
283 unsigned WhichResult;
284 ArrayRef<int> ShuffleMask = MI.getOperand(3).getShuffleMask();
285 Register Dst = MI.getOperand(0).getReg();
286 unsigned NumElts = MRI.getType(Dst).getNumElements();
287 if (!isUZPMask(ShuffleMask, NumElts, WhichResult))
288 return false;
289 unsigned Opc = (WhichResult == 0) ? AArch64::G_UZP1 : AArch64::G_UZP2;
290 Register V1 = MI.getOperand(1).getReg();
291 Register V2 = MI.getOperand(2).getReg();
292 MatchInfo = ShuffleVectorPseudo(Opc, Dst, {V1, V2});
293 return true;
294 }
295
296 bool matchZip(MachineInstr &MI, MachineRegisterInfo &MRI,
297 ShuffleVectorPseudo &MatchInfo) {
298 assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
299 unsigned WhichResult;
300 ArrayRef<int> ShuffleMask = MI.getOperand(3).getShuffleMask();
301 Register Dst = MI.getOperand(0).getReg();
302 unsigned NumElts = MRI.getType(Dst).getNumElements();
303 if (!isZipMask(ShuffleMask, NumElts, WhichResult))
304 return false;
305 unsigned Opc = (WhichResult == 0) ? AArch64::G_ZIP1 : AArch64::G_ZIP2;
306 Register V1 = MI.getOperand(1).getReg();
307 Register V2 = MI.getOperand(2).getReg();
308 MatchInfo = ShuffleVectorPseudo(Opc, Dst, {V1, V2});
309 return true;
310 }
311
312 /// Helper function for matchDup.
313 bool matchDupFromInsertVectorElt(int Lane, MachineInstr &MI,
314 MachineRegisterInfo &MRI,
315 ShuffleVectorPseudo &MatchInfo) {
316 if (Lane != 0)
317 return false;
318
319 // Try to match a vector splat operation into a dup instruction.
320 // We're looking for this pattern:
321 //
322 // %scalar:gpr(s64) = COPY $x0
323 // %undef:fpr(<2 x s64>) = G_IMPLICIT_DEF
324 // %cst0:gpr(s32) = G_CONSTANT i32 0
325 // %zerovec:fpr(<2 x s32>) = G_BUILD_VECTOR %cst0(s32), %cst0(s32)
326 // %ins:fpr(<2 x s64>) = G_INSERT_VECTOR_ELT %undef, %scalar(s64), %cst0(s32)
327 // %splat:fpr(<2 x s64>) = G_SHUFFLE_VECTOR %ins(<2 x s64>), %undef,
328 // %zerovec(<2 x s32>)
329 //
330 // ...into:
331 // %splat = G_DUP %scalar
332
333 // Begin matching the insert.
334 auto *InsMI = getOpcodeDef(TargetOpcode::G_INSERT_VECTOR_ELT,
335 MI.getOperand(1).getReg(), MRI);
336 if (!InsMI)
337 return false;
338 // Match the undef vector operand.
339 if (!getOpcodeDef(TargetOpcode::G_IMPLICIT_DEF, InsMI->getOperand(1).getReg(),
340 MRI))
341 return false;
342
343 // Match the index constant 0.
344 if (!mi_match(InsMI->getOperand(3).getReg(), MRI, m_ZeroInt()))
345 return false;
346
347 MatchInfo = ShuffleVectorPseudo(AArch64::G_DUP, MI.getOperand(0).getReg(),
348 {InsMI->getOperand(2).getReg()});
349 return true;
350 }
351
352 /// Helper function for matchDup.
353 bool matchDupFromBuildVector(int Lane, MachineInstr &MI,
354 MachineRegisterInfo &MRI,
355 ShuffleVectorPseudo &MatchInfo) {
356 assert(Lane >= 0 && "Expected positive lane?");
357 // Test if the LHS is a BUILD_VECTOR. If it is, then we can just reference the
358 // lane's definition directly.
359 auto *BuildVecMI = getOpcodeDef(TargetOpcode::G_BUILD_VECTOR,
360 MI.getOperand(1).getReg(), MRI);
361 if (!BuildVecMI)
362 return false;
363 Register Reg = BuildVecMI->getOperand(Lane + 1).getReg();
364 MatchInfo =
365 ShuffleVectorPseudo(AArch64::G_DUP, MI.getOperand(0).getReg(), {Reg});
366 return true;
367 }
368
369 bool matchDup(MachineInstr &MI, MachineRegisterInfo &MRI,
370 ShuffleVectorPseudo &MatchInfo) {
371 assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
372 auto MaybeLane = getSplatIndex(MI);
373 if (!MaybeLane)
374 return false;
375 int Lane = *MaybeLane;
376 // If this is an undef splat, generate it via "just" vdup, if possible.
377 if (Lane < 0)
378 Lane = 0;
379 if (matchDupFromInsertVectorElt(Lane, MI, MRI, MatchInfo))
380 return true;
381 if (matchDupFromBuildVector(Lane, MI, MRI, MatchInfo))
382 return true;
383 return false;
384 }
385
386 // Check if an EXT instruction can handle the shuffle mask when the vector
387 // sources of the shuffle are the same.
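//
// For example, for a <4 x s32> shuffle, the mask <1, 2, 3, 0> rotates the
// single source by one element; matchEXT below then emits a G_EXT of the
// source with itself and a byte immediate of 1 * 4 = 4.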
388 bool isSingletonExtMask(ArrayRef<int> M, LLT Ty) {
389 unsigned NumElts = Ty.getNumElements();
390
391 // Assume that the first shuffle index is not UNDEF. Fail if it is.
392 if (M[0] < 0)
393 return false;
394
395 // If this is a VEXT shuffle, the immediate value is the index of the first
396 // element. The other shuffle indices must be the successive elements after
397 // the first one.
398 unsigned ExpectedElt = M[0];
399 for (unsigned I = 1; I < NumElts; ++I) {
400 // Increment the expected index. If it wraps around, just follow it
401 // back to index zero and keep going.
402 ++ExpectedElt;
403 if (ExpectedElt == NumElts)
404 ExpectedElt = 0;
405
406 if (M[I] < 0)
407 continue; // Ignore UNDEF indices.
408 if (ExpectedElt != static_cast<unsigned>(M[I]))
409 return false;
410 }
411
412 return true;
413 }
414
415 bool matchEXT(MachineInstr &MI, MachineRegisterInfo &MRI,
416 ShuffleVectorPseudo &MatchInfo) {
417 assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
418 Register Dst = MI.getOperand(0).getReg();
419 LLT DstTy = MRI.getType(Dst);
420 Register V1 = MI.getOperand(1).getReg();
421 Register V2 = MI.getOperand(2).getReg();
422 auto Mask = MI.getOperand(3).getShuffleMask();
423 uint64_t Imm;
424 auto ExtInfo = getExtMask(Mask, DstTy.getNumElements());
425 uint64_t ExtFactor = MRI.getType(V1).getScalarSizeInBits() / 8;
426
427 if (!ExtInfo) {
428 if (!getOpcodeDef<GImplicitDef>(V2, MRI) ||
429 !isSingletonExtMask(Mask, DstTy))
430 return false;
431
432 Imm = Mask[0] * ExtFactor;
433 MatchInfo = ShuffleVectorPseudo(AArch64::G_EXT, Dst, {V1, V1, Imm});
434 return true;
435 }
436 bool ReverseExt;
437 std::tie(ReverseExt, Imm) = *ExtInfo;
438 if (ReverseExt)
439 std::swap(V1, V2);
440 Imm *= ExtFactor;
441 MatchInfo = ShuffleVectorPseudo(AArch64::G_EXT, Dst, {V1, V2, Imm});
442 return true;
443 }
444
445 /// Replace a G_SHUFFLE_VECTOR instruction with a pseudo.
446 /// \p Opc is the opcode to use. \p MI is the G_SHUFFLE_VECTOR.
447 void applyShuffleVectorPseudo(MachineInstr &MI,
448 ShuffleVectorPseudo &MatchInfo) {
449 MachineIRBuilder MIRBuilder(MI);
450 MIRBuilder.buildInstr(MatchInfo.Opc, {MatchInfo.Dst}, MatchInfo.SrcOps);
451 MI.eraseFromParent();
452 }
453
454 /// Replace a G_SHUFFLE_VECTOR instruction with G_EXT.
455 /// Special-cased because the constant operand must be emitted as a G_CONSTANT
456 /// for the imported tablegen patterns to work.
457 void applyEXT(MachineInstr &MI, ShuffleVectorPseudo &MatchInfo) {
458 MachineIRBuilder MIRBuilder(MI);
459 if (MatchInfo.SrcOps[2].getImm() == 0)
460 MIRBuilder.buildCopy(MatchInfo.Dst, MatchInfo.SrcOps[0]);
461 else {
462 // Tablegen patterns expect an i32 G_CONSTANT as the final op.
463 auto Cst =
464 MIRBuilder.buildConstant(LLT::scalar(32), MatchInfo.SrcOps[2].getImm());
465 MIRBuilder.buildInstr(MatchInfo.Opc, {MatchInfo.Dst},
466 {MatchInfo.SrcOps[0], MatchInfo.SrcOps[1], Cst});
467 }
468 MI.eraseFromParent();
469 }
470
471 /// Match a G_SHUFFLE_VECTOR with a mask which corresponds to a
472 /// G_INSERT_VECTOR_ELT and G_EXTRACT_VECTOR_ELT pair.
473 ///
474 /// e.g.
475 /// %shuf = G_SHUFFLE_VECTOR %left, %right, shufflemask(0, 0)
476 ///
477 /// Can be represented as
478 ///
479 /// %extract = G_EXTRACT_VECTOR_ELT %left, 0
480 /// %ins = G_INSERT_VECTOR_ELT %left, %extract, 1
481 ///
482 bool matchINS(MachineInstr &MI, MachineRegisterInfo &MRI,
483 std::tuple<Register, int, Register, int> &MatchInfo) {
484 assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
485 ArrayRef<int> ShuffleMask = MI.getOperand(3).getShuffleMask();
486 Register Dst = MI.getOperand(0).getReg();
487 int NumElts = MRI.getType(Dst).getNumElements();
488 auto DstIsLeftAndDstLane = isINSMask(ShuffleMask, NumElts);
489 if (!DstIsLeftAndDstLane)
490 return false;
491 bool DstIsLeft;
492 int DstLane;
493 std::tie(DstIsLeft, DstLane) = *DstIsLeftAndDstLane;
494 Register Left = MI.getOperand(1).getReg();
495 Register Right = MI.getOperand(2).getReg();
496 Register DstVec = DstIsLeft ? Left : Right;
497 Register SrcVec = Left;
498
499 int SrcLane = ShuffleMask[DstLane];
500 if (SrcLane >= NumElts) {
501 SrcVec = Right;
502 SrcLane -= NumElts;
503 }
504
505 MatchInfo = std::make_tuple(DstVec, DstLane, SrcVec, SrcLane);
506 return true;
507 }
508
509 void applyINS(MachineInstr &MI, MachineRegisterInfo &MRI,
510 MachineIRBuilder &Builder,
511 std::tuple<Register, int, Register, int> &MatchInfo) {
512 Builder.setInstrAndDebugLoc(MI);
513 Register Dst = MI.getOperand(0).getReg();
514 auto ScalarTy = MRI.getType(Dst).getElementType();
515 Register DstVec, SrcVec;
516 int DstLane, SrcLane;
517 std::tie(DstVec, DstLane, SrcVec, SrcLane) = MatchInfo;
518 auto SrcCst = Builder.buildConstant(LLT::scalar(64), SrcLane);
519 auto Extract = Builder.buildExtractVectorElement(ScalarTy, SrcVec, SrcCst);
520 auto DstCst = Builder.buildConstant(LLT::scalar(64), DstLane);
521 Builder.buildInsertVectorElement(Dst, DstVec, Extract, DstCst);
522 MI.eraseFromParent();
523 }
524
525 /// isVShiftRImm - Check if this is a valid vector splat for the immediate
526 /// operand of a vector shift right operation. The value must be in the range:
527 /// 1 <= Value <= ElementBits for a right shift.
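///
/// For example, for a <4 x s32> shift, a splat shift amount of 32 is accepted
/// (a right shift by the full element width is representable), while 0 or 33
/// is rejected.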
528 bool isVShiftRImm(Register Reg, MachineRegisterInfo &MRI, LLT Ty,
529 int64_t &Cnt) {
530 assert(Ty.isVector() && "vector shift count is not a vector type");
531 MachineInstr *MI = MRI.getVRegDef(Reg);
532 auto Cst = getAArch64VectorSplatScalar(*MI, MRI);
533 if (!Cst)
534 return false;
535 Cnt = *Cst;
536 int64_t ElementBits = Ty.getScalarSizeInBits();
537 return Cnt >= 1 && Cnt <= ElementBits;
538 }
539
540 /// Match a vector G_ASHR or G_LSHR with a valid immediate shift.
541 bool matchVAshrLshrImm(MachineInstr &MI, MachineRegisterInfo &MRI,
542 int64_t &Imm) {
543 assert(MI.getOpcode() == TargetOpcode::G_ASHR ||
544 MI.getOpcode() == TargetOpcode::G_LSHR);
545 LLT Ty = MRI.getType(MI.getOperand(1).getReg());
546 if (!Ty.isVector())
547 return false;
548 return isVShiftRImm(MI.getOperand(2).getReg(), MRI, Ty, Imm);
549 }
550
551 void applyVAshrLshrImm(MachineInstr &MI, MachineRegisterInfo &MRI,
552 int64_t &Imm) {
553 unsigned Opc = MI.getOpcode();
554 assert(Opc == TargetOpcode::G_ASHR || Opc == TargetOpcode::G_LSHR);
555 unsigned NewOpc =
556 Opc == TargetOpcode::G_ASHR ? AArch64::G_VASHR : AArch64::G_VLSHR;
557 MachineIRBuilder MIB(MI);
558 auto ImmDef = MIB.buildConstant(LLT::scalar(32), Imm);
559 MIB.buildInstr(NewOpc, {MI.getOperand(0)}, {MI.getOperand(1), ImmDef});
560 MI.eraseFromParent();
561 }
562
563 /// Determine if it is possible to modify the \p RHS and predicate \p P of a
564 /// G_ICMP instruction such that the right-hand side is an arithmetic immediate.
565 ///
566 /// \returns A pair containing the updated immediate and predicate which may
567 /// be used to optimize the instruction.
568 ///
569 /// \note This assumes that the comparison has been legalized.
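///
/// For example, 4097 is not a legal arithmetic immediate, but "x slt 4097"
/// can be rewritten as "x sle 4096", and 4096 = 0x1 << 12 is encodable.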
570 std::optional<std::pair<uint64_t, CmpInst::Predicate>>
571 tryAdjustICmpImmAndPred(Register RHS, CmpInst::Predicate P,
572 const MachineRegisterInfo &MRI) {
573 const auto &Ty = MRI.getType(RHS);
574 if (Ty.isVector())
575 return std::nullopt;
576 unsigned Size = Ty.getSizeInBits();
577 assert((Size == 32 || Size == 64) && "Expected 32 or 64 bit compare only?");
578
579 // If the RHS is not a constant, or the RHS is already a valid arithmetic
580 // immediate, then there is nothing to change.
581 auto ValAndVReg = getIConstantVRegValWithLookThrough(RHS, MRI);
582 if (!ValAndVReg)
583 return std::nullopt;
584 uint64_t C = ValAndVReg->Value.getZExtValue();
585 if (isLegalArithImmed(C))
586 return std::nullopt;
587
588 // We have a non-arithmetic immediate. Check if adjusting the immediate and
589 // adjusting the predicate will result in a legal arithmetic immediate.
590 switch (P) {
591 default:
592 return std::nullopt;
593 case CmpInst::ICMP_SLT:
594 case CmpInst::ICMP_SGE:
595 // Check for
596 //
597 // x slt c => x sle c - 1
598 // x sge c => x sgt c - 1
599 //
600 // When c is not the smallest possible negative number.
601 if ((Size == 64 && static_cast<int64_t>(C) == INT64_MIN) ||
602 (Size == 32 && static_cast<int32_t>(C) == INT32_MIN))
603 return std::nullopt;
604 P = (P == CmpInst::ICMP_SLT) ? CmpInst::ICMP_SLE : CmpInst::ICMP_SGT;
605 C -= 1;
606 break;
607 case CmpInst::ICMP_ULT:
608 case CmpInst::ICMP_UGE:
609 // Check for
610 //
611 // x ult c => x ule c - 1
612 // x uge c => x ugt c - 1
613 //
614 // When c is not zero.
615 if (C == 0)
616 return std::nullopt;
617 P = (P == CmpInst::ICMP_ULT) ? CmpInst::ICMP_ULE : CmpInst::ICMP_UGT;
618 C -= 1;
619 break;
620 case CmpInst::ICMP_SLE:
621 case CmpInst::ICMP_SGT:
622 // Check for
623 //
624 // x sle c => x slt c + 1
625 // x sgt c => x sge c + 1
626 //
627 // When c is not the largest possible signed integer.
628 if ((Size == 32 && static_cast<int32_t>(C) == INT32_MAX) ||
629 (Size == 64 && static_cast<int64_t>(C) == INT64_MAX))
630 return std::nullopt;
631 P = (P == CmpInst::ICMP_SLE) ? CmpInst::ICMP_SLT : CmpInst::ICMP_SGE;
632 C += 1;
633 break;
634 case CmpInst::ICMP_ULE:
635 case CmpInst::ICMP_UGT:
636 // Check for
637 //
638 // x ule c => x ult c + 1
639 // x ugt c => x uge c + 1
640 //
641 // When c is not the largest possible unsigned integer.
642 if ((Size == 32 && static_cast<uint32_t>(C) == UINT32_MAX) ||
643 (Size == 64 && C == UINT64_MAX))
644 return std::nullopt;
645 P = (P == CmpInst::ICMP_ULE) ? CmpInst::ICMP_ULT : CmpInst::ICMP_UGE;
646 C += 1;
647 break;
648 }
649
650 // Check if the new constant is valid, and return the updated constant and
651 // predicate if it is.
652 if (Size == 32)
653 C = static_cast<uint32_t>(C);
654 if (!isLegalArithImmed(C))
655 return std::nullopt;
656 return {{C, P}};
657 }
658
659 /// Determine whether or not it is possible to update the RHS and predicate of
660 /// a G_ICMP instruction such that the RHS will be selected as an arithmetic
661 /// immediate.
662 ///
663 /// \p MI - The G_ICMP instruction
664 /// \p MatchInfo - The new RHS immediate and predicate on success
665 ///
666 /// See tryAdjustICmpImmAndPred for valid transformations.
667 bool matchAdjustICmpImmAndPred(
668 MachineInstr &MI, const MachineRegisterInfo &MRI,
669 std::pair<uint64_t, CmpInst::Predicate> &MatchInfo) {
670 assert(MI.getOpcode() == TargetOpcode::G_ICMP);
671 Register RHS = MI.getOperand(3).getReg();
672 auto Pred = static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
673 if (auto MaybeNewImmAndPred = tryAdjustICmpImmAndPred(RHS, Pred, MRI)) {
674 MatchInfo = *MaybeNewImmAndPred;
675 return true;
676 }
677 return false;
678 }
679
680 void applyAdjustICmpImmAndPred(
681 MachineInstr &MI, std::pair<uint64_t, CmpInst::Predicate> &MatchInfo,
682 MachineIRBuilder &MIB, GISelChangeObserver &Observer) {
683 MIB.setInstrAndDebugLoc(MI);
684 MachineOperand &RHS = MI.getOperand(3);
685 MachineRegisterInfo &MRI = *MIB.getMRI();
686 auto Cst = MIB.buildConstant(MRI.cloneVirtualRegister(RHS.getReg()),
687 MatchInfo.first);
688 Observer.changingInstr(MI);
689 RHS.setReg(Cst->getOperand(0).getReg());
690 MI.getOperand(1).setPredicate(MatchInfo.second);
691 Observer.changedInstr(MI);
692 }
693
694 bool matchDupLane(MachineInstr &MI, MachineRegisterInfo &MRI,
695 std::pair<unsigned, int> &MatchInfo) {
696 assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
697 Register Src1Reg = MI.getOperand(1).getReg();
698 const LLT SrcTy = MRI.getType(Src1Reg);
699 const LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
700
701 auto LaneIdx = getSplatIndex(MI);
702 if (!LaneIdx)
703 return false;
704
705 // The lane idx should be within the first source vector.
706 if (*LaneIdx >= SrcTy.getNumElements())
707 return false;
708
709 if (DstTy != SrcTy)
710 return false;
711
712 LLT ScalarTy = SrcTy.getElementType();
713 unsigned ScalarSize = ScalarTy.getSizeInBits();
714
715 unsigned Opc = 0;
716 switch (SrcTy.getNumElements()) {
717 case 2:
718 if (ScalarSize == 64)
719 Opc = AArch64::G_DUPLANE64;
720 else if (ScalarSize == 32)
721 Opc = AArch64::G_DUPLANE32;
722 break;
723 case 4:
724 if (ScalarSize == 32)
725 Opc = AArch64::G_DUPLANE32;
726 else if (ScalarSize == 16)
727 Opc = AArch64::G_DUPLANE16;
728 break;
729 case 8:
730 if (ScalarSize == 8)
731 Opc = AArch64::G_DUPLANE8;
732 else if (ScalarSize == 16)
733 Opc = AArch64::G_DUPLANE16;
734 break;
735 case 16:
736 if (ScalarSize == 8)
737 Opc = AArch64::G_DUPLANE8;
738 break;
739 default:
740 break;
741 }
742 if (!Opc)
743 return false;
744
745 MatchInfo.first = Opc;
746 MatchInfo.second = *LaneIdx;
747 return true;
748 }
749
750 void applyDupLane(MachineInstr &MI, MachineRegisterInfo &MRI,
751 MachineIRBuilder &B, std::pair<unsigned, int> &MatchInfo) {
752 assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
753 Register Src1Reg = MI.getOperand(1).getReg();
754 const LLT SrcTy = MRI.getType(Src1Reg);
755
756 B.setInstrAndDebugLoc(MI);
757 auto Lane = B.buildConstant(LLT::scalar(64), MatchInfo.second);
758
759 Register DupSrc = MI.getOperand(1).getReg();
760 // For types like <2 x s32>, we can use G_DUPLANE32, with a <4 x s32> source.
761 // To do this, we can use a G_CONCAT_VECTORS to do the widening.
762 if (SrcTy.getSizeInBits() == 64) {
763 auto Undef = B.buildUndef(SrcTy);
764 DupSrc = B.buildConcatVectors(SrcTy.multiplyElements(2),
765 {Src1Reg, Undef.getReg(0)})
766 .getReg(0);
767 }
768 B.buildInstr(MatchInfo.first, {MI.getOperand(0).getReg()}, {DupSrc, Lane});
769 MI.eraseFromParent();
770 }
771
772 bool matchScalarizeVectorUnmerge(MachineInstr &MI, MachineRegisterInfo &MRI) {
773 auto &Unmerge = cast<GUnmerge>(MI);
774 Register Src1Reg = Unmerge.getReg(Unmerge.getNumOperands() - 1);
775 const LLT SrcTy = MRI.getType(Src1Reg);
776 return SrcTy.isVector() && !SrcTy.isScalable() &&
777 Unmerge.getNumOperands() == (unsigned)SrcTy.getNumElements() + 1;
778 }
779
780 void applyScalarizeVectorUnmerge(MachineInstr &MI, MachineRegisterInfo &MRI,
781 MachineIRBuilder &B) {
782 auto &Unmerge = cast<GUnmerge>(MI);
783 Register Src1Reg = Unmerge.getReg(Unmerge.getNumOperands() - 1);
784 const LLT SrcTy = MRI.getType(Src1Reg);
785 assert((SrcTy.isVector() && !SrcTy.isScalable()) &&
786 "Expected a fixed length vector");
787
788 for (int I = 0; I < SrcTy.getNumElements(); ++I)
789 B.buildExtractVectorElementConstant(Unmerge.getReg(I), Src1Reg, I);
790 MI.eraseFromParent();
791 }
792
793 bool matchBuildVectorToDup(MachineInstr &MI, MachineRegisterInfo &MRI) {
794 assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR);
795 auto Splat = getAArch64VectorSplat(MI, MRI);
796 if (!Splat)
797 return false;
798 if (Splat->isReg())
799 return true;
800 // Later, during selection, we'll try to match imported patterns using
801 // immAllOnesV and immAllZerosV. These require G_BUILD_VECTOR. Don't lower
802 // G_BUILD_VECTORs which could match those patterns.
803 int64_t Cst = Splat->getCst();
804 return (Cst != 0 && Cst != -1);
805 }
806
807 void applyBuildVectorToDup(MachineInstr &MI, MachineRegisterInfo &MRI,
808 MachineIRBuilder &B) {
809 B.setInstrAndDebugLoc(MI);
810 B.buildInstr(AArch64::G_DUP, {MI.getOperand(0).getReg()},
811 {MI.getOperand(1).getReg()});
812 MI.eraseFromParent();
813 }
814
815 /// \returns how many instructions would be saved by folding a G_ICMP's shift
816 /// and/or extension operations.
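///
/// For example, if \p CmpOp is a single-use G_SHL by a constant <= 4 whose
/// source is a supported extend, both the extend and the shift can fold into
/// the compare, saving two instructions.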
817 unsigned getCmpOperandFoldingProfit(Register CmpOp, MachineRegisterInfo &MRI) {
818 // No instructions to save if there's more than one use or no uses.
819 if (!MRI.hasOneNonDBGUse(CmpOp))
820 return 0;
821
822 // FIXME: This is duplicated with the selector. (See: selectShiftedRegister)
823 auto IsSupportedExtend = [&](const MachineInstr &MI) {
824 if (MI.getOpcode() == TargetOpcode::G_SEXT_INREG)
825 return true;
826 if (MI.getOpcode() != TargetOpcode::G_AND)
827 return false;
828 auto ValAndVReg =
829 getIConstantVRegValWithLookThrough(MI.getOperand(2).getReg(), MRI);
830 if (!ValAndVReg)
831 return false;
832 uint64_t Mask = ValAndVReg->Value.getZExtValue();
833 return (Mask == 0xFF || Mask == 0xFFFF || Mask == 0xFFFFFFFF);
834 };
835
836 MachineInstr *Def = getDefIgnoringCopies(CmpOp, MRI);
837 if (IsSupportedExtend(*Def))
838 return 1;
839
840 unsigned Opc = Def->getOpcode();
841 if (Opc != TargetOpcode::G_SHL && Opc != TargetOpcode::G_ASHR &&
842 Opc != TargetOpcode::G_LSHR)
843 return 0;
844
845 auto MaybeShiftAmt =
846 getIConstantVRegValWithLookThrough(Def->getOperand(2).getReg(), MRI);
847 if (!MaybeShiftAmt)
848 return 0;
849 uint64_t ShiftAmt = MaybeShiftAmt->Value.getZExtValue();
850 MachineInstr *ShiftLHS =
851 getDefIgnoringCopies(Def->getOperand(1).getReg(), MRI);
852
853 // Check if we can fold an extend and a shift.
854 // FIXME: This is duplicated with the selector. (See:
855 // selectArithExtendedRegister)
856 if (IsSupportedExtend(*ShiftLHS))
857 return (ShiftAmt <= 4) ? 2 : 1;
858
859 LLT Ty = MRI.getType(Def->getOperand(0).getReg());
860 if (Ty.isVector())
861 return 0;
862 unsigned ShiftSize = Ty.getSizeInBits();
863 if ((ShiftSize == 32 && ShiftAmt <= 31) ||
864 (ShiftSize == 64 && ShiftAmt <= 63))
865 return 1;
866 return 0;
867 }
868
869 /// \returns true if it would be profitable to swap the LHS and RHS of a G_ICMP
870 /// instruction \p MI.
871 bool trySwapICmpOperands(MachineInstr &MI, MachineRegisterInfo &MRI) {
872 assert(MI.getOpcode() == TargetOpcode::G_ICMP);
873 // Swap the operands if it would introduce a profitable folding opportunity.
874 // (e.g. a shift + extend).
875 //
876 // For example:
877 // lsl w13, w11, #1
878 // cmp w13, w12
879 // can be turned into:
880 // cmp w12, w11, lsl #1
881
882 // Don't swap if there's a constant on the RHS, because we know we can fold
883 // that.
884 Register RHS = MI.getOperand(3).getReg();
885 auto RHSCst = getIConstantVRegValWithLookThrough(RHS, MRI);
886 if (RHSCst && isLegalArithImmed(RHSCst->Value.getSExtValue()))
887 return false;
888
889 Register LHS = MI.getOperand(2).getReg();
890 auto Pred = static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
891 auto GetRegForProfit = [&](Register Reg) {
892 MachineInstr *Def = getDefIgnoringCopies(Reg, MRI);
893 return isCMN(Def, Pred, MRI) ? Def->getOperand(2).getReg() : Reg;
894 };
895
896 // Don't have a constant on the RHS. If we swap the LHS and RHS of the
897 // compare, would we be able to fold more instructions?
898 Register TheLHS = GetRegForProfit(LHS);
899 Register TheRHS = GetRegForProfit(RHS);
900
901 // If the LHS is more likely to give us a folding opportunity, then swap the
902 // LHS and RHS.
903 return (getCmpOperandFoldingProfit(TheLHS, MRI) >
904 getCmpOperandFoldingProfit(TheRHS, MRI));
905 }
906
907 void applySwapICmpOperands(MachineInstr &MI, GISelChangeObserver &Observer) {
908 auto Pred = static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
909 Register LHS = MI.getOperand(2).getReg();
910 Register RHS = MI.getOperand(3).getReg();
911 Observer.changingInstr(MI);
912 MI.getOperand(1).setPredicate(CmpInst::getSwappedPredicate(Pred));
913 MI.getOperand(2).setReg(RHS);
914 MI.getOperand(3).setReg(LHS);
915 Observer.changedInstr(MI);
916 }
917
918 /// \returns a function which builds a vector floating point compare instruction
919 /// for a condition code \p CC.
920 /// \param [in] IsZero - True if the comparison is against 0.
921 /// \param [in] NoNans - True if the target has NoNansFPMath.
922 std::function<Register(MachineIRBuilder &)>
923 getVectorFCMP(AArch64CC::CondCode CC, Register LHS, Register RHS, bool IsZero,
924 bool NoNans, MachineRegisterInfo &MRI) {
925 LLT DstTy = MRI.getType(LHS);
926 assert(DstTy.isVector() && "Expected vector types only?");
927 assert(DstTy == MRI.getType(RHS) && "Src and Dst types must match!");
928 switch (CC) {
929 default:
930 llvm_unreachable("Unexpected condition code!");
931 case AArch64CC::NE:
932 return [LHS, RHS, IsZero, DstTy](MachineIRBuilder &MIB) {
933 auto FCmp = IsZero
934 ? MIB.buildInstr(AArch64::G_FCMEQZ, {DstTy}, {LHS})
935 : MIB.buildInstr(AArch64::G_FCMEQ, {DstTy}, {LHS, RHS});
936 return MIB.buildNot(DstTy, FCmp).getReg(0);
937 };
938 case AArch64CC::EQ:
939 return [LHS, RHS, IsZero, DstTy](MachineIRBuilder &MIB) {
940 return IsZero
941 ? MIB.buildInstr(AArch64::G_FCMEQZ, {DstTy}, {LHS}).getReg(0)
942 : MIB.buildInstr(AArch64::G_FCMEQ, {DstTy}, {LHS, RHS})
943 .getReg(0);
944 };
945 case AArch64CC::GE:
946 return [LHS, RHS, IsZero, DstTy](MachineIRBuilder &MIB) {
947 return IsZero
948 ? MIB.buildInstr(AArch64::G_FCMGEZ, {DstTy}, {LHS}).getReg(0)
949 : MIB.buildInstr(AArch64::G_FCMGE, {DstTy}, {LHS, RHS})
950 .getReg(0);
951 };
952 case AArch64CC::GT:
953 return [LHS, RHS, IsZero, DstTy](MachineIRBuilder &MIB) {
954 return IsZero
955 ? MIB.buildInstr(AArch64::G_FCMGTZ, {DstTy}, {LHS}).getReg(0)
956 : MIB.buildInstr(AArch64::G_FCMGT, {DstTy}, {LHS, RHS})
957 .getReg(0);
958 };
959 case AArch64CC::LS:
960 return [LHS, RHS, IsZero, DstTy](MachineIRBuilder &MIB) {
961 return IsZero
962 ? MIB.buildInstr(AArch64::G_FCMLEZ, {DstTy}, {LHS}).getReg(0)
963 : MIB.buildInstr(AArch64::G_FCMGE, {DstTy}, {RHS, LHS})
964 .getReg(0);
965 };
966 case AArch64CC::MI:
967 return [LHS, RHS, IsZero, DstTy](MachineIRBuilder &MIB) {
968 return IsZero
969 ? MIB.buildInstr(AArch64::G_FCMLTZ, {DstTy}, {LHS}).getReg(0)
970 : MIB.buildInstr(AArch64::G_FCMGT, {DstTy}, {RHS, LHS})
971 .getReg(0);
972 };
973 }
974 }
975
976 /// Try to lower a vector G_FCMP \p MI into an AArch64-specific pseudo.
977 bool matchLowerVectorFCMP(MachineInstr &MI, MachineRegisterInfo &MRI,
978 MachineIRBuilder &MIB) {
979 assert(MI.getOpcode() == TargetOpcode::G_FCMP);
980 const auto &ST = MI.getMF()->getSubtarget<AArch64Subtarget>();
981
982 Register Dst = MI.getOperand(0).getReg();
983 LLT DstTy = MRI.getType(Dst);
984 if (!DstTy.isVector() || !ST.hasNEON())
985 return false;
986 Register LHS = MI.getOperand(2).getReg();
987 unsigned EltSize = MRI.getType(LHS).getScalarSizeInBits();
988 if (EltSize == 16 && !ST.hasFullFP16())
989 return false;
990 if (EltSize != 16 && EltSize != 32 && EltSize != 64)
991 return false;
992
993 return true;
994 }
995
996 /// Try to lower a vector G_FCMP \p MI into an AArch64-specific pseudo.
997 void applyLowerVectorFCMP(MachineInstr &MI, MachineRegisterInfo &MRI,
998 MachineIRBuilder &MIB) {
999 assert(MI.getOpcode() == TargetOpcode::G_FCMP);
1000 const auto &ST = MI.getMF()->getSubtarget<AArch64Subtarget>();
1001
1002 const auto &CmpMI = cast<GFCmp>(MI);
1003
1004 Register Dst = CmpMI.getReg(0);
1005 CmpInst::Predicate Pred = CmpMI.getCond();
1006 Register LHS = CmpMI.getLHSReg();
1007 Register RHS = CmpMI.getRHSReg();
1008
1009 LLT DstTy = MRI.getType(Dst);
1010
1011 auto Splat = getAArch64VectorSplat(*MRI.getVRegDef(RHS), MRI);
1012
1013 // Compares against 0 have special target-specific pseudos.
1014 bool IsZero = Splat && Splat->isCst() && Splat->getCst() == 0;
1015
1016 bool Invert = false;
1017 AArch64CC::CondCode CC, CC2 = AArch64CC::AL;
1018 if (Pred == CmpInst::Predicate::FCMP_ORD && IsZero) {
1019 // The special case "fcmp ord %a, 0" is the canonical check that LHS isn't
1020 // NaN, so equivalent to a == a and doesn't need the two comparisons an
1021 // "ord" normally would.
1022 RHS = LHS;
1023 IsZero = false;
1024 CC = AArch64CC::EQ;
1025 } else
1026 changeVectorFCMPPredToAArch64CC(Pred, CC, CC2, Invert);
1027
1028 // Instead of having an apply function, just build here to simplify things.
1029 MIB.setInstrAndDebugLoc(MI);
1030
1031 const bool NoNans =
1032 ST.getTargetLowering()->getTargetMachine().Options.NoNaNsFPMath;
1033
1034 auto Cmp = getVectorFCMP(CC, LHS, RHS, IsZero, NoNans, MRI);
1035 Register CmpRes;
1036 if (CC2 == AArch64CC::AL)
1037 CmpRes = Cmp(MIB);
1038 else {
1039 auto Cmp2 = getVectorFCMP(CC2, LHS, RHS, IsZero, NoNans, MRI);
1040 auto Cmp2Dst = Cmp2(MIB);
1041 auto Cmp1Dst = Cmp(MIB);
1042 CmpRes = MIB.buildOr(DstTy, Cmp1Dst, Cmp2Dst).getReg(0);
1043 }
1044 if (Invert)
1045 CmpRes = MIB.buildNot(DstTy, CmpRes).getReg(0);
1046 MRI.replaceRegWith(Dst, CmpRes);
1047 MI.eraseFromParent();
1048 }
1049
1050 bool matchFormTruncstore(MachineInstr &MI, MachineRegisterInfo &MRI,
1051 Register &SrcReg) {
1052 assert(MI.getOpcode() == TargetOpcode::G_STORE);
1053 Register DstReg = MI.getOperand(0).getReg();
1054 if (MRI.getType(DstReg).isVector())
1055 return false;
1056 // Match a store of a truncate.
1057 if (!mi_match(DstReg, MRI, m_GTrunc(m_Reg(SrcReg))))
1058 return false;
1059 // Only form truncstores for value types of max 64b.
1060 return MRI.getType(SrcReg).getSizeInBits() <= 64;
1061 }
1062
1063 void applyFormTruncstore(MachineInstr &MI, MachineRegisterInfo &MRI,
1064 MachineIRBuilder &B, GISelChangeObserver &Observer,
1065 Register &SrcReg) {
1066 assert(MI.getOpcode() == TargetOpcode::G_STORE);
1067 Observer.changingInstr(MI);
1068 MI.getOperand(0).setReg(SrcReg);
1069 Observer.changedInstr(MI);
1070 }
1071
1072 // Lower vector G_SEXT_INREG back to shifts for selection. We allowed them to
1073 // form in the first place for combine opportunities, so any remaining ones
1074 // at this stage need to be lowered back.
1075 bool matchVectorSextInReg(MachineInstr &MI, MachineRegisterInfo &MRI) {
1076 assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
1077 Register DstReg = MI.getOperand(0).getReg();
1078 LLT DstTy = MRI.getType(DstReg);
1079 return DstTy.isVector();
1080 }
1081
1082 void applyVectorSextInReg(MachineInstr &MI, MachineRegisterInfo &MRI,
1083 MachineIRBuilder &B, GISelChangeObserver &Observer) {
1084 assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
1085 B.setInstrAndDebugLoc(MI);
1086 LegalizerHelper Helper(*MI.getMF(), Observer, B);
1087 Helper.lower(MI, 0, /* Unused hint type */ LLT());
1088 }
1089
1090 /// Combine <N x t>, unused = unmerge(G_EXT <2*N x t> v, undef, N)
1091 /// => unused, <N x t> = unmerge v
1092 bool matchUnmergeExtToUnmerge(MachineInstr &MI, MachineRegisterInfo &MRI,
1093 Register &MatchInfo) {
1094 auto &Unmerge = cast<GUnmerge>(MI);
1095 if (Unmerge.getNumDefs() != 2)
1096 return false;
1097 if (!MRI.use_nodbg_empty(Unmerge.getReg(1)))
1098 return false;
1099
1100 LLT DstTy = MRI.getType(Unmerge.getReg(0));
1101 if (!DstTy.isVector())
1102 return false;
1103
1104 MachineInstr *Ext = getOpcodeDef(AArch64::G_EXT, Unmerge.getSourceReg(), MRI);
1105 if (!Ext)
1106 return false;
1107
1108 Register ExtSrc1 = Ext->getOperand(1).getReg();
1109 Register ExtSrc2 = Ext->getOperand(2).getReg();
1110 auto LowestVal =
1111 getIConstantVRegValWithLookThrough(Ext->getOperand(3).getReg(), MRI);
1112 if (!LowestVal || LowestVal->Value.getZExtValue() != DstTy.getSizeInBytes())
1113 return false;
1114
1115 if (!getOpcodeDef<GImplicitDef>(ExtSrc2, MRI))
1116 return false;
1117
1118 MatchInfo = ExtSrc1;
1119 return true;
1120 }
1121
1122 void applyUnmergeExtToUnmerge(MachineInstr &MI, MachineRegisterInfo &MRI,
1123 MachineIRBuilder &B,
1124 GISelChangeObserver &Observer, Register &SrcReg) {
1125 Observer.changingInstr(MI);
1126 // Swap dst registers.
1127 Register Dst1 = MI.getOperand(0).getReg();
1128 MI.getOperand(0).setReg(MI.getOperand(1).getReg());
1129 MI.getOperand(1).setReg(Dst1);
1130 MI.getOperand(2).setReg(SrcReg);
1131 Observer.changedInstr(MI);
1132 }
1133
1134 // Match mul({z/s}ext, {z/s}ext) => {u/s}mull, OR
1135 // match v2s64 mul instructions, which will then be scalarised later on.
1136 // Doing these two matches in one function ensures that the order of matching
1137 // will always be the same.
1138 // Try lowering MUL to MULL before trying to scalarize if needed.
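//
// For example (a sketch of the first pattern):
//   %a:_(<4 x s32>) = G_ZEXT %x:_(<4 x s16>)
//   %b:_(<4 x s32>) = G_ZEXT %y:_(<4 x s16>)
//   %m:_(<4 x s32>) = G_MUL %a, %b
// becomes %m:_(<4 x s32>) = G_UMULL %x, %y in applyExtMulToMULL below.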
1139 bool matchExtMulToMULL(MachineInstr &MI, MachineRegisterInfo &MRI) {
1140 // Get the instructions that defined the source operand
1141 LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
1142 MachineInstr *I1 = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
1143 MachineInstr *I2 = getDefIgnoringCopies(MI.getOperand(2).getReg(), MRI);
1144
1145 if (DstTy.isVector()) {
1146 // If the source operands were EXTENDED before, then {U/S}MULL can be used
1147 unsigned I1Opc = I1->getOpcode();
1148 unsigned I2Opc = I2->getOpcode();
1149 if (((I1Opc == TargetOpcode::G_ZEXT && I2Opc == TargetOpcode::G_ZEXT) ||
1150 (I1Opc == TargetOpcode::G_SEXT && I2Opc == TargetOpcode::G_SEXT)) &&
1151 (MRI.getType(I1->getOperand(0).getReg()).getScalarSizeInBits() ==
1152 MRI.getType(I1->getOperand(1).getReg()).getScalarSizeInBits() * 2) &&
1153 (MRI.getType(I2->getOperand(0).getReg()).getScalarSizeInBits() ==
1154 MRI.getType(I2->getOperand(1).getReg()).getScalarSizeInBits() * 2)) {
1155 return true;
1156 }
1157 // If result type is v2s64, scalarise the instruction
1158 else if (DstTy == LLT::fixed_vector(2, 64)) {
1159 return true;
1160 }
1161 }
1162 return false;
1163 }
1164
1165 void applyExtMulToMULL(MachineInstr &MI, MachineRegisterInfo &MRI,
1166 MachineIRBuilder &B, GISelChangeObserver &Observer) {
1167 assert(MI.getOpcode() == TargetOpcode::G_MUL &&
1168 "Expected a G_MUL instruction");
1169
1170 // Get the instructions that defined the source operand
1171 LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
1172 MachineInstr *I1 = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
1173 MachineInstr *I2 = getDefIgnoringCopies(MI.getOperand(2).getReg(), MRI);
1174
1175 // If the source operands were EXTENDED before, then {U/S}MULL can be used
1176 unsigned I1Opc = I1->getOpcode();
1177 unsigned I2Opc = I2->getOpcode();
1178 if (((I1Opc == TargetOpcode::G_ZEXT && I2Opc == TargetOpcode::G_ZEXT) ||
1179 (I1Opc == TargetOpcode::G_SEXT && I2Opc == TargetOpcode::G_SEXT)) &&
1180 (MRI.getType(I1->getOperand(0).getReg()).getScalarSizeInBits() ==
1181 MRI.getType(I1->getOperand(1).getReg()).getScalarSizeInBits() * 2) &&
1182 (MRI.getType(I2->getOperand(0).getReg()).getScalarSizeInBits() ==
1183 MRI.getType(I2->getOperand(1).getReg()).getScalarSizeInBits() * 2)) {
1184
1185 B.setInstrAndDebugLoc(MI);
1186 B.buildInstr(I1->getOpcode() == TargetOpcode::G_ZEXT ? AArch64::G_UMULL
1187 : AArch64::G_SMULL,
1188 {MI.getOperand(0).getReg()},
1189 {I1->getOperand(1).getReg(), I2->getOperand(1).getReg()});
1190 MI.eraseFromParent();
1191 }
1192 // If result type is v2s64, scalarise the instruction
1193 else if (DstTy == LLT::fixed_vector(2, 64)) {
1194 LegalizerHelper Helper(*MI.getMF(), Observer, B);
1195 B.setInstrAndDebugLoc(MI);
1196 Helper.fewerElementsVector(
1197 MI, 0,
1198 DstTy.changeElementCount(
1199 DstTy.getElementCount().divideCoefficientBy(2)));
1200 }
1201 }
1202
1203 class AArch64PostLegalizerLoweringImpl : public Combiner {
1204 protected:
1205 // TODO: Make CombinerHelper methods const.
1206 mutable CombinerHelper Helper;
1207 const AArch64PostLegalizerLoweringImplRuleConfig &RuleConfig;
1208 const AArch64Subtarget &STI;
1209
1210 public:
1211 AArch64PostLegalizerLoweringImpl(
1212 MachineFunction &MF, CombinerInfo &CInfo, const TargetPassConfig *TPC,
1213 GISelCSEInfo *CSEInfo,
1214 const AArch64PostLegalizerLoweringImplRuleConfig &RuleConfig,
1215 const AArch64Subtarget &STI);
1216
1217 static const char *getName() { return "AArch64PostLegalizerLowering"; }
1218
1219 bool tryCombineAll(MachineInstr &I) const override;
1220
1221 private:
1222 #define GET_GICOMBINER_CLASS_MEMBERS
1223 #include "AArch64GenPostLegalizeGILowering.inc"
1224 #undef GET_GICOMBINER_CLASS_MEMBERS
1225 };
1226
1227 #define GET_GICOMBINER_IMPL
1228 #include "AArch64GenPostLegalizeGILowering.inc"
1229 #undef GET_GICOMBINER_IMPL
1230
1231 AArch64PostLegalizerLoweringImpl::AArch64PostLegalizerLoweringImpl(
1232 MachineFunction &MF, CombinerInfo &CInfo, const TargetPassConfig *TPC,
1233 GISelCSEInfo *CSEInfo,
1234 const AArch64PostLegalizerLoweringImplRuleConfig &RuleConfig,
1235 const AArch64Subtarget &STI)
1236 : Combiner(MF, CInfo, TPC, /*KB*/ nullptr, CSEInfo),
1237 Helper(Observer, B, /*IsPreLegalize*/ true), RuleConfig(RuleConfig),
1238 STI(STI),
1239 #define GET_GICOMBINER_CONSTRUCTOR_INITS
1240 #include "AArch64GenPostLegalizeGILowering.inc"
1241 #undef GET_GICOMBINER_CONSTRUCTOR_INITS
1242 {
1243 }
1244
1245 class AArch64PostLegalizerLowering : public MachineFunctionPass {
1246 public:
1247 static char ID;
1248
1249 AArch64PostLegalizerLowering();
1250
1251 StringRef getPassName() const override {
1252 return "AArch64PostLegalizerLowering";
1253 }
1254
1255 bool runOnMachineFunction(MachineFunction &MF) override;
1256 void getAnalysisUsage(AnalysisUsage &AU) const override;
1257
1258 private:
1259 AArch64PostLegalizerLoweringImplRuleConfig RuleConfig;
1260 };
1261 } // end anonymous namespace
1262
1263 void AArch64PostLegalizerLowering::getAnalysisUsage(AnalysisUsage &AU) const {
1264 AU.addRequired<TargetPassConfig>();
1265 AU.setPreservesCFG();
1266 getSelectionDAGFallbackAnalysisUsage(AU);
1267 MachineFunctionPass::getAnalysisUsage(AU);
1268 }
1269
1270 AArch64PostLegalizerLowering::AArch64PostLegalizerLowering()
1271 : MachineFunctionPass(ID) {
1272 initializeAArch64PostLegalizerLoweringPass(*PassRegistry::getPassRegistry());
1273
1274 if (!RuleConfig.parseCommandLineOption())
1275 report_fatal_error("Invalid rule identifier");
1276 }
1277
1278 bool AArch64PostLegalizerLowering::runOnMachineFunction(MachineFunction &MF) {
1279 if (MF.getProperties().hasProperty(
1280 MachineFunctionProperties::Property::FailedISel))
1281 return false;
1282 assert(MF.getProperties().hasProperty(
1283 MachineFunctionProperties::Property::Legalized) &&
1284 "Expected a legalized function?");
1285 auto *TPC = &getAnalysis<TargetPassConfig>();
1286 const Function &F = MF.getFunction();
1287
1288 const AArch64Subtarget &ST = MF.getSubtarget<AArch64Subtarget>();
1289 CombinerInfo CInfo(/*AllowIllegalOps*/ true, /*ShouldLegalizeIllegal*/ false,
1290 /*LegalizerInfo*/ nullptr, /*OptEnabled=*/true,
1291 F.hasOptSize(), F.hasMinSize());
1292 AArch64PostLegalizerLoweringImpl Impl(MF, CInfo, TPC, /*CSEInfo*/ nullptr,
1293 RuleConfig, ST);
1294 return Impl.combineMachineInstrs();
1295 }
1296
1297 char AArch64PostLegalizerLowering::ID = 0;
1298 INITIALIZE_PASS_BEGIN(AArch64PostLegalizerLowering, DEBUG_TYPE,
1299 "Lower AArch64 MachineInstrs after legalization", false,
1300 false)
1301 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
1302 INITIALIZE_PASS_END(AArch64PostLegalizerLowering, DEBUG_TYPE,
1303 "Lower AArch64 MachineInstrs after legalization", false,
1304 false)
1305
1306 namespace llvm {
1307 FunctionPass *createAArch64PostLegalizerLowering() {
1308 return new AArch64PostLegalizerLowering();
1309 }
1310 } // end namespace llvm
1311