//===- SelectionDAG.cpp - Implement the SelectionDAG data structures ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This implements the SelectionDAG class.
//
//===----------------------------------------------------------------------===//

#include "llvm/CodeGen/SelectionDAG.h"
#include "SDNodeDbgValue.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/FoldingSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Triple.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/MemoryLocation.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/CodeGen/Analysis.h"
#include "llvm/CodeGen/FunctionLoweringInfo.h"
#include "llvm/CodeGen/ISDOpcodes.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/CodeGen/MachineConstantPool.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineMemOperand.h"
#include "llvm/CodeGen/RuntimeLibcalls.h"
#include "llvm/CodeGen/SelectionDAGAddressAnalysis.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
#include "llvm/CodeGen/SelectionDAGTargetInfo.h"
#include "llvm/CodeGen/TargetFrameLowering.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/DebugLoc.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CodeGen.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Support/MachineValueType.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/Mutex.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/Transforms/Utils/SizeOpts.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <limits>
#include <set>
#include <string>
#include <utility>
#include <vector>

using namespace llvm;

/// makeVTList - Return an instance of the SDVTList struct initialized with the
/// specified members.
static SDVTList makeVTList(const EVT *VTs, unsigned NumVTs) {
  SDVTList Res = {VTs, NumVTs};
  return Res;
}

// Default null implementations of the callbacks.
void SelectionDAG::DAGUpdateListener::NodeDeleted(SDNode*, SDNode*) {}
void SelectionDAG::DAGUpdateListener::NodeUpdated(SDNode*) {}
void SelectionDAG::DAGUpdateListener::NodeInserted(SDNode *) {}

void SelectionDAG::DAGNodeDeletedListener::anchor() {}
void SelectionDAG::DAGNodeInsertedListener::anchor() {}

#define DEBUG_TYPE "selectiondag"

static cl::opt<bool> EnableMemCpyDAGOpt("enable-memcpy-dag-opt",
       cl::Hidden, cl::init(true),
       cl::desc("Gang up loads and stores generated by inlining of memcpy"));

static cl::opt<int> MaxLdStGlue("ldstmemcpy-glue-max",
       cl::desc("Number limit for gluing ld/st of memcpy."),
       cl::Hidden, cl::init(0));

static void NewSDValueDbgMsg(SDValue V, StringRef Msg, SelectionDAG *G) {
  LLVM_DEBUG(dbgs() << Msg; V.getNode()->dump(G););
}

//===----------------------------------------------------------------------===//
//                              ConstantFPSDNode Class
//===----------------------------------------------------------------------===//

/// isExactlyValue - We don't rely on operator== working on double values, as
/// it returns true for things that are clearly not equal, like -0.0 and 0.0.
/// As such, this method can be used to do an exact bit-for-bit comparison of
/// two floating point values.
bool ConstantFPSDNode::isExactlyValue(const APFloat& V) const {
  return getValueAPF().bitwiseIsEqual(V);
}

bool ConstantFPSDNode::isValueValidForType(EVT VT,
                                           const APFloat& Val) {
  assert(VT.isFloatingPoint() && "Can only convert between FP types");

  // convert modifies in place, so make a copy.
  APFloat Val2 = APFloat(Val);
  bool losesInfo;
  (void) Val2.convert(SelectionDAG::EVTToAPFloatSemantics(VT),
                      APFloat::rmNearestTiesToEven,
                      &losesInfo);
  return !losesInfo;
}

//===----------------------------------------------------------------------===//
//                              ISD Namespace
//===----------------------------------------------------------------------===//

bool ISD::isConstantSplatVector(const SDNode *N, APInt &SplatVal) {
  if (N->getOpcode() == ISD::SPLAT_VECTOR) {
    unsigned EltSize =
        N->getValueType(0).getVectorElementType().getSizeInBits();
    if (auto *Op0 = dyn_cast<ConstantSDNode>(N->getOperand(0))) {
      SplatVal = Op0->getAPIntValue().trunc(EltSize);
      return true;
    }
    if (auto *Op0 = dyn_cast<ConstantFPSDNode>(N->getOperand(0))) {
      SplatVal = Op0->getValueAPF().bitcastToAPInt().trunc(EltSize);
      return true;
    }
  }

  auto *BV = dyn_cast<BuildVectorSDNode>(N);
  if (!BV)
    return false;

  APInt SplatUndef;
  unsigned SplatBitSize;
  bool HasUndefs;
  unsigned EltSize = N->getValueType(0).getVectorElementType().getSizeInBits();
  return BV->isConstantSplat(SplatVal, SplatUndef, SplatBitSize, HasUndefs,
                             EltSize) &&
         EltSize == SplatBitSize;
}

// FIXME: AllOnes and AllZeros duplicate a lot of code. Could these be
// specializations of the more general isConstantSplatVector()?

bool ISD::isConstantSplatVectorAllOnes(const SDNode *N, bool BuildVectorOnly) {
  // Look through a bit convert.
  while (N->getOpcode() == ISD::BITCAST)
    N = N->getOperand(0).getNode();

  if (!BuildVectorOnly && N->getOpcode() == ISD::SPLAT_VECTOR) {
    APInt SplatVal;
    return isConstantSplatVector(N, SplatVal) && SplatVal.isAllOnes();
  }

  if (N->getOpcode() != ISD::BUILD_VECTOR) return false;

  unsigned i = 0, e = N->getNumOperands();

  // Skip over all of the undef values.
  while (i != e && N->getOperand(i).isUndef())
    ++i;

  // Do not accept an all-undef vector.
  if (i == e) return false;

  // Do not accept build_vectors that aren't all constants or which have non-~0
  // elements. We have to be a bit careful here, as the type of the constant
  // may not be the same as the type of the vector elements due to type
  // legalization (the elements are promoted to a legal type for the target and
  // a vector of a type may be legal when the base element type is not).
  // We only want to check enough bits to cover the vector elements, because
  // we care if the resultant vector is all ones, not whether the individual
  // constants are.
  SDValue NotZero = N->getOperand(i);
  unsigned EltSize = N->getValueType(0).getScalarSizeInBits();
  if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(NotZero)) {
    if (CN->getAPIntValue().countTrailingOnes() < EltSize)
      return false;
  } else if (ConstantFPSDNode *CFPN = dyn_cast<ConstantFPSDNode>(NotZero)) {
    if (CFPN->getValueAPF().bitcastToAPInt().countTrailingOnes() < EltSize)
      return false;
  } else
    return false;

  // Okay, we have at least one ~0 value, check to see if the rest match or are
  // undefs. Even with the above element type twiddling, this should be OK, as
  // the same type legalization should have applied to all the elements.
  for (++i; i != e; ++i)
    if (N->getOperand(i) != NotZero && !N->getOperand(i).isUndef())
      return false;
  return true;
}

bool ISD::isConstantSplatVectorAllZeros(const SDNode *N, bool BuildVectorOnly) {
  // Look through a bit convert.
  while (N->getOpcode() == ISD::BITCAST)
    N = N->getOperand(0).getNode();

  if (!BuildVectorOnly && N->getOpcode() == ISD::SPLAT_VECTOR) {
    APInt SplatVal;
    return isConstantSplatVector(N, SplatVal) && SplatVal.isZero();
  }

  if (N->getOpcode() != ISD::BUILD_VECTOR) return false;

  bool IsAllUndef = true;
  for (const SDValue &Op : N->op_values()) {
    if (Op.isUndef())
      continue;
    IsAllUndef = false;
    // Do not accept build_vectors that aren't all constants or which have non-0
    // elements. We have to be a bit careful here, as the type of the constant
    // may not be the same as the type of the vector elements due to type
    // legalization (the elements are promoted to a legal type for the target
    // and a vector of a type may be legal when the base element type is not).
    // We only want to check enough bits to cover the vector elements, because
    // we care if the resultant vector is all zeros, not whether the individual
    // constants are.
    unsigned EltSize = N->getValueType(0).getScalarSizeInBits();
    if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(Op)) {
      if (CN->getAPIntValue().countTrailingZeros() < EltSize)
        return false;
    } else if (ConstantFPSDNode *CFPN = dyn_cast<ConstantFPSDNode>(Op)) {
      if (CFPN->getValueAPF().bitcastToAPInt().countTrailingZeros() < EltSize)
        return false;
    } else
      return false;
  }

  // Do not accept an all-undef vector.
  if (IsAllUndef)
    return false;
  return true;
}

bool ISD::isBuildVectorAllOnes(const SDNode *N) {
  return isConstantSplatVectorAllOnes(N, /*BuildVectorOnly*/ true);
}

bool ISD::isBuildVectorAllZeros(const SDNode *N) {
  return isConstantSplatVectorAllZeros(N, /*BuildVectorOnly*/ true);
}

bool ISD::isBuildVectorOfConstantSDNodes(const SDNode *N) {
  if (N->getOpcode() != ISD::BUILD_VECTOR)
    return false;

  for (const SDValue &Op : N->op_values()) {
    if (Op.isUndef())
      continue;
    if (!isa<ConstantSDNode>(Op))
      return false;
  }
  return true;
}

bool ISD::isBuildVectorOfConstantFPSDNodes(const SDNode *N) {
  if (N->getOpcode() != ISD::BUILD_VECTOR)
    return false;

  for (const SDValue &Op : N->op_values()) {
    if (Op.isUndef())
      continue;
    if (!isa<ConstantFPSDNode>(Op))
      return false;
  }
  return true;
}

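/// Return true if the elements of vector node \p N can be truncated to
/// \p NewEltSize bits without changing any value; \p Signed selects sign-
/// versus zero-extension semantics for the round-trip check. Handles
/// ZERO_EXTEND/SIGN_EXTEND sources and constant BUILD_VECTORs.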
bool ISD::isVectorShrinkable(const SDNode *N, unsigned NewEltSize,
                             bool Signed) {
  assert(N->getValueType(0).isVector() && "Expected a vector!");

  unsigned EltSize = N->getValueType(0).getScalarSizeInBits();
  if (EltSize <= NewEltSize)
    return false;

  if (N->getOpcode() == ISD::ZERO_EXTEND) {
    return (N->getOperand(0).getValueType().getScalarSizeInBits() <=
            NewEltSize) &&
           !Signed;
  }
  if (N->getOpcode() == ISD::SIGN_EXTEND) {
    return (N->getOperand(0).getValueType().getScalarSizeInBits() <=
            NewEltSize) &&
           Signed;
  }
  if (N->getOpcode() != ISD::BUILD_VECTOR)
    return false;

  for (const SDValue &Op : N->op_values()) {
    if (Op.isUndef())
      continue;
    if (!isa<ConstantSDNode>(Op))
      return false;

    APInt C = cast<ConstantSDNode>(Op)->getAPIntValue().trunc(EltSize);
    if (Signed && C.trunc(NewEltSize).sext(EltSize) != C)
      return false;
    if (!Signed && C.trunc(NewEltSize).zext(EltSize) != C)
      return false;
  }

  return true;
}

bool ISD::allOperandsUndef(const SDNode *N) {
  // Return false if the node has no operands.
  // This is "logically inconsistent" with the definition of "all" but
  // is probably the desired behavior.
  if (N->getNumOperands() == 0)
    return false;
  return all_of(N->op_values(), [](SDValue Op) { return Op.isUndef(); });
}

bool ISD::isFreezeUndef(const SDNode *N) {
  return N->getOpcode() == ISD::FREEZE && N->getOperand(0).isUndef();
}

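/// Apply \p Match to \p Op if it is a scalar ConstantSDNode, or to every
/// element of a constant BUILD_VECTOR or SPLAT_VECTOR. When \p AllowUndefs is
/// set, undef elements are matched as a null ConstantSDNode; anything else
/// that is not a constant of the vector's scalar type fails the match.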
bool ISD::matchUnaryPredicate(SDValue Op,
                              std::function<bool(ConstantSDNode *)> Match,
                              bool AllowUndefs) {
  // FIXME: Add support for scalar UNDEF cases?
  if (auto *Cst = dyn_cast<ConstantSDNode>(Op))
    return Match(Cst);

  // FIXME: Add support for vector UNDEF cases?
  if (ISD::BUILD_VECTOR != Op.getOpcode() &&
      ISD::SPLAT_VECTOR != Op.getOpcode())
    return false;

  EVT SVT = Op.getValueType().getScalarType();
  for (unsigned i = 0, e = Op.getNumOperands(); i != e; ++i) {
    if (AllowUndefs && Op.getOperand(i).isUndef()) {
      if (!Match(nullptr))
        return false;
      continue;
    }

    auto *Cst = dyn_cast<ConstantSDNode>(Op.getOperand(i));
    if (!Cst || Cst->getValueType(0) != SVT || !Match(Cst))
      return false;
  }
  return true;
}

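/// Pairwise form of the above: apply \p Match to matching scalar constants,
/// or to corresponding elements of two constant BUILD_VECTOR or SPLAT_VECTOR
/// nodes that share the same opcode.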
bool ISD::matchBinaryPredicate(
    SDValue LHS, SDValue RHS,
    std::function<bool(ConstantSDNode *, ConstantSDNode *)> Match,
    bool AllowUndefs, bool AllowTypeMismatch) {
  if (!AllowTypeMismatch && LHS.getValueType() != RHS.getValueType())
    return false;

  // TODO: Add support for scalar UNDEF cases?
  if (auto *LHSCst = dyn_cast<ConstantSDNode>(LHS))
    if (auto *RHSCst = dyn_cast<ConstantSDNode>(RHS))
      return Match(LHSCst, RHSCst);

  // TODO: Add support for vector UNDEF cases?
  if (LHS.getOpcode() != RHS.getOpcode() ||
      (LHS.getOpcode() != ISD::BUILD_VECTOR &&
       LHS.getOpcode() != ISD::SPLAT_VECTOR))
    return false;

  EVT SVT = LHS.getValueType().getScalarType();
  for (unsigned i = 0, e = LHS.getNumOperands(); i != e; ++i) {
    SDValue LHSOp = LHS.getOperand(i);
    SDValue RHSOp = RHS.getOperand(i);
    bool LHSUndef = AllowUndefs && LHSOp.isUndef();
    bool RHSUndef = AllowUndefs && RHSOp.isUndef();
    auto *LHSCst = dyn_cast<ConstantSDNode>(LHSOp);
    auto *RHSCst = dyn_cast<ConstantSDNode>(RHSOp);
    if ((!LHSCst && !LHSUndef) || (!RHSCst && !RHSUndef))
      return false;
    if (!AllowTypeMismatch && (LHSOp.getValueType() != SVT ||
                               LHSOp.getValueType() != RHSOp.getValueType()))
      return false;
    if (!Match(LHSCst, RHSCst))
      return false;
  }
  return true;
}

ISD::NodeType ISD::getVecReduceBaseOpcode(unsigned VecReduceOpcode) {
  switch (VecReduceOpcode) {
  default:
    llvm_unreachable("Expected VECREDUCE opcode");
  case ISD::VECREDUCE_FADD:
  case ISD::VECREDUCE_SEQ_FADD:
  case ISD::VP_REDUCE_FADD:
  case ISD::VP_REDUCE_SEQ_FADD:
    return ISD::FADD;
  case ISD::VECREDUCE_FMUL:
  case ISD::VECREDUCE_SEQ_FMUL:
  case ISD::VP_REDUCE_FMUL:
  case ISD::VP_REDUCE_SEQ_FMUL:
    return ISD::FMUL;
  case ISD::VECREDUCE_ADD:
  case ISD::VP_REDUCE_ADD:
    return ISD::ADD;
  case ISD::VECREDUCE_MUL:
  case ISD::VP_REDUCE_MUL:
    return ISD::MUL;
  case ISD::VECREDUCE_AND:
  case ISD::VP_REDUCE_AND:
    return ISD::AND;
  case ISD::VECREDUCE_OR:
  case ISD::VP_REDUCE_OR:
    return ISD::OR;
  case ISD::VECREDUCE_XOR:
  case ISD::VP_REDUCE_XOR:
    return ISD::XOR;
  case ISD::VECREDUCE_SMAX:
  case ISD::VP_REDUCE_SMAX:
    return ISD::SMAX;
  case ISD::VECREDUCE_SMIN:
  case ISD::VP_REDUCE_SMIN:
    return ISD::SMIN;
  case ISD::VECREDUCE_UMAX:
  case ISD::VP_REDUCE_UMAX:
    return ISD::UMAX;
  case ISD::VECREDUCE_UMIN:
  case ISD::VP_REDUCE_UMIN:
    return ISD::UMIN;
  case ISD::VECREDUCE_FMAX:
  case ISD::VP_REDUCE_FMAX:
    return ISD::FMAXNUM;
  case ISD::VECREDUCE_FMIN:
  case ISD::VP_REDUCE_FMIN:
    return ISD::FMINNUM;
  }
}

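// The VP predicates below are generated from the VP intrinsic registry:
// including VPIntrinsics.def with the BEGIN/END_REGISTER_VP_SDNODE and
// VP_PROPERTY_* macros defined expands into one case label per VP opcode.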
bool ISD::isVPOpcode(unsigned Opcode) {
  switch (Opcode) {
  default:
    return false;
#define BEGIN_REGISTER_VP_SDNODE(VPSD, ...)                                    \
  case ISD::VPSD:                                                              \
    return true;
#include "llvm/IR/VPIntrinsics.def"
  }
}

bool ISD::isVPBinaryOp(unsigned Opcode) {
  switch (Opcode) {
  default:
    break;
#define BEGIN_REGISTER_VP_SDNODE(VPSD, ...) case ISD::VPSD:
#define VP_PROPERTY_BINARYOP return true;
#define END_REGISTER_VP_SDNODE(VPSD) break;
#include "llvm/IR/VPIntrinsics.def"
  }
  return false;
}

bool ISD::isVPReduction(unsigned Opcode) {
  switch (Opcode) {
  default:
    break;
#define BEGIN_REGISTER_VP_SDNODE(VPSD, ...) case ISD::VPSD:
#define VP_PROPERTY_REDUCTION(STARTPOS, ...) return true;
#define END_REGISTER_VP_SDNODE(VPSD) break;
#include "llvm/IR/VPIntrinsics.def"
  }
  return false;
}

/// The operand position of the vector mask.
std::optional<unsigned> ISD::getVPMaskIdx(unsigned Opcode) {
  switch (Opcode) {
  default:
    return std::nullopt;
#define BEGIN_REGISTER_VP_SDNODE(VPSD, LEGALPOS, TDNAME, MASKPOS, ...)         \
  case ISD::VPSD:                                                              \
    return MASKPOS;
#include "llvm/IR/VPIntrinsics.def"
  }
}

/// The operand position of the explicit vector length parameter.
std::optional<unsigned> ISD::getVPExplicitVectorLengthIdx(unsigned Opcode) {
  switch (Opcode) {
  default:
    return std::nullopt;
#define BEGIN_REGISTER_VP_SDNODE(VPSD, LEGALPOS, TDNAME, MASKPOS, EVLPOS)      \
  case ISD::VPSD:                                                              \
    return EVLPOS;
#include "llvm/IR/VPIntrinsics.def"
  }
}

ISD::NodeType ISD::getExtForLoadExtType(bool IsFP, ISD::LoadExtType ExtType) {
  switch (ExtType) {
  case ISD::EXTLOAD:
    return IsFP ? ISD::FP_EXTEND : ISD::ANY_EXTEND;
  case ISD::SEXTLOAD:
    return ISD::SIGN_EXTEND;
  case ISD::ZEXTLOAD:
    return ISD::ZERO_EXTEND;
  default:
    break;
  }

  llvm_unreachable("Invalid LoadExtType");
}

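// The condition-code transforms below rely on the bit encoding of
// ISD::CondCode (see ISDOpcodes.h): bit 0 is the E (equal) bit, bit 1 the
// G (greater) bit, bit 2 the L (less) bit, bit 3 the U (true on unordered)
// bit, and bit 4 marks the integer codes that do not care about ordering.
// For example, swapping the operands of SETOLT (L set) yields SETOGT (G set).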
ISD::CondCode ISD::getSetCCSwappedOperands(ISD::CondCode Operation) {
  // To perform this operation, we just need to swap the L and G bits of the
  // operation.
  unsigned OldL = (Operation >> 2) & 1;
  unsigned OldG = (Operation >> 1) & 1;
  return ISD::CondCode((Operation & ~6) |  // Keep the N, U, E bits
                       (OldL << 1) |       // New G bit
                       (OldG << 2));       // New L bit.
}

static ISD::CondCode getSetCCInverseImpl(ISD::CondCode Op, bool isIntegerLike) {
  unsigned Operation = Op;
  if (isIntegerLike)
    Operation ^= 7;   // Flip L, G, E bits, but not U.
  else
    Operation ^= 15;  // Flip all of the condition bits.

  if (Operation > ISD::SETTRUE2)
    Operation &= ~8;  // Don't let N and U bits get set.

  return ISD::CondCode(Operation);
}

ISD::CondCode ISD::getSetCCInverse(ISD::CondCode Op, EVT Type) {
  return getSetCCInverseImpl(Op, Type.isInteger());
}

ISD::CondCode ISD::GlobalISel::getSetCCInverse(ISD::CondCode Op,
                                               bool isIntegerLike) {
  return getSetCCInverseImpl(Op, isIntegerLike);
}

/// For an integer comparison, return 1 if the comparison is a signed operation
/// and 2 if it is an unsigned comparison. Return zero if the operation does
/// not depend on the sign of the input (setne and seteq).
static int isSignedOp(ISD::CondCode Opcode) {
  switch (Opcode) {
  default: llvm_unreachable("Illegal integer setcc operation!");
  case ISD::SETEQ:
  case ISD::SETNE: return 0;
  case ISD::SETLT:
  case ISD::SETLE:
  case ISD::SETGT:
  case ISD::SETGE: return 1;
  case ISD::SETULT:
  case ISD::SETULE:
  case ISD::SETUGT:
  case ISD::SETUGE: return 2;
  }
}

ISD::CondCode ISD::getSetCCOrOperation(ISD::CondCode Op1, ISD::CondCode Op2,
                                       EVT Type) {
  bool IsInteger = Type.isInteger();
  if (IsInteger && (isSignedOp(Op1) | isSignedOp(Op2)) == 3)
    // Cannot fold a signed integer setcc with an unsigned integer setcc.
    return ISD::SETCC_INVALID;

  unsigned Op = Op1 | Op2;  // Combine all of the condition bits.

  // If the N and U bits get set, then the resultant comparison DOES suddenly
  // care about orderedness, and it is true when ordered.
  if (Op > ISD::SETTRUE2)
    Op &= ~16;     // Clear the U bit if the N bit is set.

  // Canonicalize illegal integer setcc's.
  if (IsInteger && Op == ISD::SETUNE)  // e.g. SETUGT | SETULT
    Op = ISD::SETNE;

  return ISD::CondCode(Op);
}

ISD::CondCode ISD::getSetCCAndOperation(ISD::CondCode Op1, ISD::CondCode Op2,
                                        EVT Type) {
  bool IsInteger = Type.isInteger();
  if (IsInteger && (isSignedOp(Op1) | isSignedOp(Op2)) == 3)
    // Cannot fold a signed setcc with an unsigned setcc.
    return ISD::SETCC_INVALID;

  // Combine all of the condition bits.
  ISD::CondCode Result = ISD::CondCode(Op1 & Op2);

  // Canonicalize illegal integer setcc's.
  if (IsInteger) {
    switch (Result) {
    default: break;
    case ISD::SETUO : Result = ISD::SETFALSE; break;  // SETUGT & SETULT
    case ISD::SETOEQ:                                 // SETEQ  & SETU[LG]E
    case ISD::SETUEQ: Result = ISD::SETEQ   ; break;  // SETUGE & SETULE
    case ISD::SETOLT: Result = ISD::SETULT  ; break;  // SETULT & SETNE
    case ISD::SETOGT: Result = ISD::SETUGT  ; break;  // SETUGT & SETNE
    }
  }

  return Result;
}

//===----------------------------------------------------------------------===//
//                           SDNode Profile Support
//===----------------------------------------------------------------------===//

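// The FoldingSetNodeID assembled by the routines below (opcode, interned
// value-type list pointer, operand <node, result-number> pairs, plus any
// node-specific payload added by AddNodeIDCustom) is the key under which
// nodes are uniqued in the CSE map.
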
/// AddNodeIDOpcode - Add the node opcode to the NodeID data.
static void AddNodeIDOpcode(FoldingSetNodeID &ID, unsigned OpC)  {
  ID.AddInteger(OpC);
}

/// AddNodeIDValueTypes - Value type lists are intern'd so we can represent them
/// solely with their pointer.
static void AddNodeIDValueTypes(FoldingSetNodeID &ID, SDVTList VTList) {
  ID.AddPointer(VTList.VTs);
}

/// AddNodeIDOperands - Various routines for adding operands to the NodeID data.
static void AddNodeIDOperands(FoldingSetNodeID &ID,
                              ArrayRef<SDValue> Ops) {
  for (const auto &Op : Ops) {
    ID.AddPointer(Op.getNode());
    ID.AddInteger(Op.getResNo());
  }
}

/// AddNodeIDOperands - Various routines for adding operands to the NodeID data.
static void AddNodeIDOperands(FoldingSetNodeID &ID,
                              ArrayRef<SDUse> Ops) {
  for (const auto &Op : Ops) {
    ID.AddPointer(Op.getNode());
    ID.AddInteger(Op.getResNo());
  }
}

static void AddNodeIDNode(FoldingSetNodeID &ID, unsigned OpC,
                          SDVTList VTList, ArrayRef<SDValue> OpList) {
  AddNodeIDOpcode(ID, OpC);
  AddNodeIDValueTypes(ID, VTList);
  AddNodeIDOperands(ID, OpList);
}

/// If this is an SDNode with special info, add this info to the NodeID data.
static void AddNodeIDCustom(FoldingSetNodeID &ID, const SDNode *N) {
  switch (N->getOpcode()) {
  case ISD::TargetExternalSymbol:
  case ISD::ExternalSymbol:
  case ISD::MCSymbol:
    llvm_unreachable("Should only be used on nodes with operands");
  default: break;  // Normal nodes don't need extra info.
  case ISD::TargetConstant:
  case ISD::Constant: {
    const ConstantSDNode *C = cast<ConstantSDNode>(N);
    ID.AddPointer(C->getConstantIntValue());
    ID.AddBoolean(C->isOpaque());
    break;
  }
  case ISD::TargetConstantFP:
  case ISD::ConstantFP:
    ID.AddPointer(cast<ConstantFPSDNode>(N)->getConstantFPValue());
    break;
  case ISD::TargetGlobalAddress:
  case ISD::GlobalAddress:
  case ISD::TargetGlobalTLSAddress:
  case ISD::GlobalTLSAddress: {
    const GlobalAddressSDNode *GA = cast<GlobalAddressSDNode>(N);
    ID.AddPointer(GA->getGlobal());
    ID.AddInteger(GA->getOffset());
    ID.AddInteger(GA->getTargetFlags());
    break;
  }
  case ISD::BasicBlock:
    ID.AddPointer(cast<BasicBlockSDNode>(N)->getBasicBlock());
    break;
  case ISD::Register:
    ID.AddInteger(cast<RegisterSDNode>(N)->getReg());
    break;
  case ISD::RegisterMask:
    ID.AddPointer(cast<RegisterMaskSDNode>(N)->getRegMask());
    break;
  case ISD::SRCVALUE:
    ID.AddPointer(cast<SrcValueSDNode>(N)->getValue());
    break;
  case ISD::FrameIndex:
  case ISD::TargetFrameIndex:
    ID.AddInteger(cast<FrameIndexSDNode>(N)->getIndex());
    break;
  case ISD::LIFETIME_START:
  case ISD::LIFETIME_END:
    if (cast<LifetimeSDNode>(N)->hasOffset()) {
      ID.AddInteger(cast<LifetimeSDNode>(N)->getSize());
      ID.AddInteger(cast<LifetimeSDNode>(N)->getOffset());
    }
    break;
  case ISD::PSEUDO_PROBE:
    ID.AddInteger(cast<PseudoProbeSDNode>(N)->getGuid());
    ID.AddInteger(cast<PseudoProbeSDNode>(N)->getIndex());
    ID.AddInteger(cast<PseudoProbeSDNode>(N)->getAttributes());
    break;
  case ISD::JumpTable:
  case ISD::TargetJumpTable:
    ID.AddInteger(cast<JumpTableSDNode>(N)->getIndex());
    ID.AddInteger(cast<JumpTableSDNode>(N)->getTargetFlags());
    break;
  case ISD::ConstantPool:
  case ISD::TargetConstantPool: {
    const ConstantPoolSDNode *CP = cast<ConstantPoolSDNode>(N);
    ID.AddInteger(CP->getAlign().value());
    ID.AddInteger(CP->getOffset());
    if (CP->isMachineConstantPoolEntry())
      CP->getMachineCPVal()->addSelectionDAGCSEId(ID);
    else
      ID.AddPointer(CP->getConstVal());
    ID.AddInteger(CP->getTargetFlags());
    break;
  }
  case ISD::TargetIndex: {
    const TargetIndexSDNode *TI = cast<TargetIndexSDNode>(N);
    ID.AddInteger(TI->getIndex());
    ID.AddInteger(TI->getOffset());
    ID.AddInteger(TI->getTargetFlags());
    break;
  }
  case ISD::LOAD: {
    const LoadSDNode *LD = cast<LoadSDNode>(N);
    ID.AddInteger(LD->getMemoryVT().getRawBits());
    ID.AddInteger(LD->getRawSubclassData());
    ID.AddInteger(LD->getPointerInfo().getAddrSpace());
    ID.AddInteger(LD->getMemOperand()->getFlags());
    break;
  }
  case ISD::STORE: {
    const StoreSDNode *ST = cast<StoreSDNode>(N);
    ID.AddInteger(ST->getMemoryVT().getRawBits());
    ID.AddInteger(ST->getRawSubclassData());
    ID.AddInteger(ST->getPointerInfo().getAddrSpace());
    ID.AddInteger(ST->getMemOperand()->getFlags());
    break;
  }
  case ISD::VP_LOAD: {
    const VPLoadSDNode *ELD = cast<VPLoadSDNode>(N);
    ID.AddInteger(ELD->getMemoryVT().getRawBits());
    ID.AddInteger(ELD->getRawSubclassData());
    ID.AddInteger(ELD->getPointerInfo().getAddrSpace());
    ID.AddInteger(ELD->getMemOperand()->getFlags());
    break;
  }
  case ISD::VP_STORE: {
    const VPStoreSDNode *EST = cast<VPStoreSDNode>(N);
    ID.AddInteger(EST->getMemoryVT().getRawBits());
    ID.AddInteger(EST->getRawSubclassData());
    ID.AddInteger(EST->getPointerInfo().getAddrSpace());
    ID.AddInteger(EST->getMemOperand()->getFlags());
    break;
  }
  case ISD::EXPERIMENTAL_VP_STRIDED_LOAD: {
    const VPStridedLoadSDNode *SLD = cast<VPStridedLoadSDNode>(N);
    ID.AddInteger(SLD->getMemoryVT().getRawBits());
    ID.AddInteger(SLD->getRawSubclassData());
    ID.AddInteger(SLD->getPointerInfo().getAddrSpace());
    break;
  }
  case ISD::EXPERIMENTAL_VP_STRIDED_STORE: {
    const VPStridedStoreSDNode *SST = cast<VPStridedStoreSDNode>(N);
    ID.AddInteger(SST->getMemoryVT().getRawBits());
    ID.AddInteger(SST->getRawSubclassData());
    ID.AddInteger(SST->getPointerInfo().getAddrSpace());
    break;
  }
  case ISD::VP_GATHER: {
    const VPGatherSDNode *EG = cast<VPGatherSDNode>(N);
    ID.AddInteger(EG->getMemoryVT().getRawBits());
    ID.AddInteger(EG->getRawSubclassData());
    ID.AddInteger(EG->getPointerInfo().getAddrSpace());
    ID.AddInteger(EG->getMemOperand()->getFlags());
    break;
  }
  case ISD::VP_SCATTER: {
    const VPScatterSDNode *ES = cast<VPScatterSDNode>(N);
    ID.AddInteger(ES->getMemoryVT().getRawBits());
    ID.AddInteger(ES->getRawSubclassData());
    ID.AddInteger(ES->getPointerInfo().getAddrSpace());
    ID.AddInteger(ES->getMemOperand()->getFlags());
    break;
  }
  case ISD::MLOAD: {
    const MaskedLoadSDNode *MLD = cast<MaskedLoadSDNode>(N);
    ID.AddInteger(MLD->getMemoryVT().getRawBits());
    ID.AddInteger(MLD->getRawSubclassData());
    ID.AddInteger(MLD->getPointerInfo().getAddrSpace());
    ID.AddInteger(MLD->getMemOperand()->getFlags());
    break;
  }
  case ISD::MSTORE: {
    const MaskedStoreSDNode *MST = cast<MaskedStoreSDNode>(N);
    ID.AddInteger(MST->getMemoryVT().getRawBits());
    ID.AddInteger(MST->getRawSubclassData());
    ID.AddInteger(MST->getPointerInfo().getAddrSpace());
    ID.AddInteger(MST->getMemOperand()->getFlags());
    break;
  }
  case ISD::MGATHER: {
    const MaskedGatherSDNode *MG = cast<MaskedGatherSDNode>(N);
    ID.AddInteger(MG->getMemoryVT().getRawBits());
    ID.AddInteger(MG->getRawSubclassData());
    ID.AddInteger(MG->getPointerInfo().getAddrSpace());
    ID.AddInteger(MG->getMemOperand()->getFlags());
    break;
  }
  case ISD::MSCATTER: {
    const MaskedScatterSDNode *MS = cast<MaskedScatterSDNode>(N);
    ID.AddInteger(MS->getMemoryVT().getRawBits());
    ID.AddInteger(MS->getRawSubclassData());
    ID.AddInteger(MS->getPointerInfo().getAddrSpace());
    ID.AddInteger(MS->getMemOperand()->getFlags());
    break;
  }
  case ISD::ATOMIC_CMP_SWAP:
  case ISD::ATOMIC_CMP_SWAP_WITH_SUCCESS:
  case ISD::ATOMIC_SWAP:
  case ISD::ATOMIC_LOAD_ADD:
  case ISD::ATOMIC_LOAD_SUB:
  case ISD::ATOMIC_LOAD_AND:
  case ISD::ATOMIC_LOAD_CLR:
  case ISD::ATOMIC_LOAD_OR:
  case ISD::ATOMIC_LOAD_XOR:
  case ISD::ATOMIC_LOAD_NAND:
  case ISD::ATOMIC_LOAD_MIN:
  case ISD::ATOMIC_LOAD_MAX:
  case ISD::ATOMIC_LOAD_UMIN:
  case ISD::ATOMIC_LOAD_UMAX:
  case ISD::ATOMIC_LOAD:
  case ISD::ATOMIC_STORE: {
    const AtomicSDNode *AT = cast<AtomicSDNode>(N);
    ID.AddInteger(AT->getMemoryVT().getRawBits());
    ID.AddInteger(AT->getRawSubclassData());
    ID.AddInteger(AT->getPointerInfo().getAddrSpace());
    ID.AddInteger(AT->getMemOperand()->getFlags());
    break;
  }
  case ISD::PREFETCH: {
    const MemSDNode *PF = cast<MemSDNode>(N);
    ID.AddInteger(PF->getPointerInfo().getAddrSpace());
    ID.AddInteger(PF->getMemOperand()->getFlags());
    break;
  }
  case ISD::VECTOR_SHUFFLE: {
    const ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(N);
    for (unsigned i = 0, e = N->getValueType(0).getVectorNumElements();
         i != e; ++i)
      ID.AddInteger(SVN->getMaskElt(i));
    break;
  }
  case ISD::TargetBlockAddress:
  case ISD::BlockAddress: {
    const BlockAddressSDNode *BA = cast<BlockAddressSDNode>(N);
    ID.AddPointer(BA->getBlockAddress());
    ID.AddInteger(BA->getOffset());
    ID.AddInteger(BA->getTargetFlags());
    break;
  }
  case ISD::AssertAlign:
    ID.AddInteger(cast<AssertAlignSDNode>(N)->getAlign().value());
    break;
  } // end switch (N->getOpcode())

  // Target specific memory nodes could also have address spaces and flags
  // to check.
  if (N->isTargetMemoryOpcode()) {
    const MemSDNode *MN = cast<MemSDNode>(N);
    ID.AddInteger(MN->getPointerInfo().getAddrSpace());
    ID.AddInteger(MN->getMemOperand()->getFlags());
  }
}

/// AddNodeIDNode - Generic routine for adding a node's info to the NodeID
/// data.
static void AddNodeIDNode(FoldingSetNodeID &ID, const SDNode *N) {
  AddNodeIDOpcode(ID, N->getOpcode());
  // Add the return value info.
  AddNodeIDValueTypes(ID, N->getVTList());
  // Add the operand info.
  AddNodeIDOperands(ID, N->ops());

  // Handle SDNode leaves with special info.
  AddNodeIDCustom(ID, N);
}

//===----------------------------------------------------------------------===//
//                              SelectionDAG Class
//===----------------------------------------------------------------------===//

/// doNotCSE - Return true if CSE should not be performed for this node.
static bool doNotCSE(SDNode *N) {
  if (N->getValueType(0) == MVT::Glue)
    return true; // Never CSE anything that produces a flag.

  switch (N->getOpcode()) {
  default: break;
  case ISD::HANDLENODE:
  case ISD::EH_LABEL:
    return true;   // Never CSE these nodes.
  }

  // Check that remaining values produced are not flags.
  for (unsigned i = 1, e = N->getNumValues(); i != e; ++i)
    if (N->getValueType(i) == MVT::Glue)
      return true; // Never CSE anything that produces a flag.

  return false;
}

/// RemoveDeadNodes - This method deletes all unreachable nodes in the
/// SelectionDAG.
void SelectionDAG::RemoveDeadNodes() {
  // Create a dummy node (which is not added to allnodes) that adds a reference
  // to the root node, preventing it from being deleted.
  HandleSDNode Dummy(getRoot());

  SmallVector<SDNode*, 128> DeadNodes;

  // Add all obviously-dead nodes to the DeadNodes worklist.
  for (SDNode &Node : allnodes())
    if (Node.use_empty())
      DeadNodes.push_back(&Node);

  RemoveDeadNodes(DeadNodes);

  // If the root changed (e.g. it was a dead load), update the root.
  setRoot(Dummy.getValue());
}

/// RemoveDeadNodes - This method deletes the unreachable nodes in the
/// given list, and any nodes that become unreachable as a result.
void SelectionDAG::RemoveDeadNodes(SmallVectorImpl<SDNode *> &DeadNodes) {

  // Process the worklist, deleting the nodes and adding their uses to the
  // worklist.
  while (!DeadNodes.empty()) {
    SDNode *N = DeadNodes.pop_back_val();
    // Skip to next node if we've already managed to delete the node. This could
    // happen if replacing a node causes a node previously added to the worklist
    // to be deleted.
    if (N->getOpcode() == ISD::DELETED_NODE)
      continue;

    for (DAGUpdateListener *DUL = UpdateListeners; DUL; DUL = DUL->Next)
      DUL->NodeDeleted(N, nullptr);

    // Take the node out of the appropriate CSE map.
    RemoveNodeFromCSEMaps(N);

    // Next, brutally remove the operand list.  This is safe to do, as there are
    // no cycles in the graph.
    for (SDNode::op_iterator I = N->op_begin(), E = N->op_end(); I != E; ) {
      SDUse &Use = *I++;
      SDNode *Operand = Use.getNode();
      Use.set(SDValue());

      // Now that we removed this operand, see if there are no uses of it left.
      if (Operand->use_empty())
        DeadNodes.push_back(Operand);
    }

    DeallocateNode(N);
  }
}

void SelectionDAG::RemoveDeadNode(SDNode *N){
  SmallVector<SDNode*, 16> DeadNodes(1, N);

  // Create a dummy node that adds a reference to the root node, preventing
  // it from being deleted.  (This matters if the root is an operand of the
  // dead node.)
  HandleSDNode Dummy(getRoot());

  RemoveDeadNodes(DeadNodes);
}

void SelectionDAG::DeleteNode(SDNode *N) {
  // First take this out of the appropriate CSE map.
  RemoveNodeFromCSEMaps(N);

  // Finally, remove uses due to operands of this node, remove from the
  // AllNodes list, and delete the node.
  DeleteNodeNotInCSEMaps(N);
}

void SelectionDAG::DeleteNodeNotInCSEMaps(SDNode *N) {
  assert(N->getIterator() != AllNodes.begin() &&
         "Cannot delete the entry node!");
  assert(N->use_empty() && "Cannot delete a node that is not dead!");

  // Drop all of the operands and decrement used node's use counts.
  N->DropOperands();

  DeallocateNode(N);
}

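/// Record a debug value. Byval parameter values are kept in a separate list;
/// the value is also indexed by every SDNode it refers to, so that it can be
/// invalidated when one of those nodes is deleted.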
void SDDbgInfo::add(SDDbgValue *V, bool isParameter) {
  assert(!(V->isVariadic() && isParameter));
  if (isParameter)
    ByvalParmDbgValues.push_back(V);
  else
    DbgValues.push_back(V);
  for (const SDNode *Node : V->getSDNodes())
    if (Node)
      DbgValMap[Node].push_back(V);
}

void SDDbgInfo::erase(const SDNode *Node) {
  DbgValMapType::iterator I = DbgValMap.find(Node);
  if (I == DbgValMap.end())
    return;
  for (auto &Val: I->second)
    Val->setIsInvalidated();
  DbgValMap.erase(I);
}

void SelectionDAG::DeallocateNode(SDNode *N) {
  // If we have operands, deallocate them.
  removeOperands(N);

  NodeAllocator.Deallocate(AllNodes.remove(N));

  // Set the opcode to DELETED_NODE to help catch bugs when node
  // memory is reallocated.
  // FIXME: There are places in SDag that have grown a dependency on the opcode
  // value in the released node.
  __asan_unpoison_memory_region(&N->NodeType, sizeof(N->NodeType));
  N->NodeType = ISD::DELETED_NODE;

  // If any of the SDDbgValue nodes refer to this SDNode, invalidate
  // them and forget about that node.
  DbgInfo->erase(N);

  // Invalidate extra info.
  SDEI.erase(N);
}

#ifndef NDEBUG
/// VerifySDNode - Check the given SDNode.  Aborts if it is invalid.
static void VerifySDNode(SDNode *N) {
  switch (N->getOpcode()) {
  default:
    break;
  case ISD::BUILD_PAIR: {
    EVT VT = N->getValueType(0);
    assert(N->getNumValues() == 1 && "Too many results!");
    assert(!VT.isVector() && (VT.isInteger() || VT.isFloatingPoint()) &&
           "Wrong return type!");
    assert(N->getNumOperands() == 2 && "Wrong number of operands!");
    assert(N->getOperand(0).getValueType() == N->getOperand(1).getValueType() &&
           "Mismatched operand types!");
    assert(N->getOperand(0).getValueType().isInteger() == VT.isInteger() &&
           "Wrong operand type!");
    assert(VT.getSizeInBits() == 2 * N->getOperand(0).getValueSizeInBits() &&
           "Wrong return type size");
    break;
  }
  case ISD::BUILD_VECTOR: {
    assert(N->getNumValues() == 1 && "Too many results!");
    assert(N->getValueType(0).isVector() && "Wrong return type!");
    assert(N->getNumOperands() == N->getValueType(0).getVectorNumElements() &&
           "Wrong number of operands!");
    EVT EltVT = N->getValueType(0).getVectorElementType();
    for (const SDUse &Op : N->ops()) {
      assert((Op.getValueType() == EltVT ||
              (EltVT.isInteger() && Op.getValueType().isInteger() &&
               EltVT.bitsLE(Op.getValueType()))) &&
             "Wrong operand type!");
      assert(Op.getValueType() == N->getOperand(0).getValueType() &&
             "Operands must all have the same type");
    }
    break;
  }
  }
}
#endif // NDEBUG

/// Insert a newly allocated node into the DAG.
///
/// Handles insertion into the all nodes list and CSE map, as well as
/// verification and other common operations when a new node is allocated.
void SelectionDAG::InsertNode(SDNode *N) {
  AllNodes.push_back(N);
#ifndef NDEBUG
  N->PersistentId = NextPersistentId++;
  VerifySDNode(N);
#endif
  for (DAGUpdateListener *DUL = UpdateListeners; DUL; DUL = DUL->Next)
    DUL->NodeInserted(N);
}

/// RemoveNodeFromCSEMaps - Take the specified node out of the CSE map that
/// corresponds to it.  This is useful when we're about to delete or repurpose
/// the node.  We don't want future requests for structurally identical nodes
/// to return N anymore.
bool SelectionDAG::RemoveNodeFromCSEMaps(SDNode *N) {
  bool Erased = false;
  switch (N->getOpcode()) {
  case ISD::HANDLENODE: return false;  // noop.
  case ISD::CONDCODE:
    assert(CondCodeNodes[cast<CondCodeSDNode>(N)->get()] &&
           "Cond code doesn't exist!");
    Erased = CondCodeNodes[cast<CondCodeSDNode>(N)->get()] != nullptr;
    CondCodeNodes[cast<CondCodeSDNode>(N)->get()] = nullptr;
    break;
  case ISD::ExternalSymbol:
    Erased = ExternalSymbols.erase(cast<ExternalSymbolSDNode>(N)->getSymbol());
    break;
  case ISD::TargetExternalSymbol: {
    ExternalSymbolSDNode *ESN = cast<ExternalSymbolSDNode>(N);
    Erased = TargetExternalSymbols.erase(std::pair<std::string, unsigned>(
        ESN->getSymbol(), ESN->getTargetFlags()));
    break;
  }
  case ISD::MCSymbol: {
    auto *MCSN = cast<MCSymbolSDNode>(N);
    Erased = MCSymbols.erase(MCSN->getMCSymbol());
    break;
  }
  case ISD::VALUETYPE: {
    EVT VT = cast<VTSDNode>(N)->getVT();
    if (VT.isExtended()) {
      Erased = ExtendedValueTypeNodes.erase(VT);
    } else {
      Erased = ValueTypeNodes[VT.getSimpleVT().SimpleTy] != nullptr;
      ValueTypeNodes[VT.getSimpleVT().SimpleTy] = nullptr;
    }
    break;
  }
  default:
    // Remove it from the CSE Map.
    assert(N->getOpcode() != ISD::DELETED_NODE && "DELETED_NODE in CSEMap!");
    assert(N->getOpcode() != ISD::EntryToken && "EntryToken in CSEMap!");
    Erased = CSEMap.RemoveNode(N);
    break;
  }
#ifndef NDEBUG
  // Verify that the node was actually in one of the CSE maps, unless it has a
  // flag result (which cannot be CSE'd) or is one of the special cases that are
  // not subject to CSE.
  if (!Erased && N->getValueType(N->getNumValues()-1) != MVT::Glue &&
      !N->isMachineOpcode() && !doNotCSE(N)) {
    N->dump(this);
    dbgs() << "\n";
    llvm_unreachable("Node is not in map!");
  }
#endif
  return Erased;
}

/// AddModifiedNodeToCSEMaps - The specified node has been removed from the CSE
/// maps and modified in place. Add it back to the CSE maps, unless an identical
/// node already exists, in which case transfer all its users to the existing
/// node. This transfer can potentially trigger recursive merging.
void
SelectionDAG::AddModifiedNodeToCSEMaps(SDNode *N) {
  // For node types that aren't CSE'd, just act as if no identical node
  // already exists.
  if (!doNotCSE(N)) {
    SDNode *Existing = CSEMap.GetOrInsertNode(N);
    if (Existing != N) {
      // If there was already an existing matching node, use ReplaceAllUsesWith
      // to replace the dead one with the existing one.  This can cause
      // recursive merging of other unrelated nodes down the line.
      ReplaceAllUsesWith(N, Existing);

      // N is now dead. Inform the listeners and delete it.
      for (DAGUpdateListener *DUL = UpdateListeners; DUL; DUL = DUL->Next)
        DUL->NodeDeleted(N, Existing);
      DeleteNodeNotInCSEMaps(N);
      return;
    }
  }

  // If the node doesn't already exist, we updated it.  Inform listeners.
  for (DAGUpdateListener *DUL = UpdateListeners; DUL; DUL = DUL->Next)
    DUL->NodeUpdated(N);
}

/// FindModifiedNodeSlot - Find a slot for the specified node if its operands
/// were replaced with those specified.  If this node is never memoized,
/// return null, otherwise return a pointer to the slot it would take.  If a
/// node already exists with these operands, the slot will be non-null.
SDNode *SelectionDAG::FindModifiedNodeSlot(SDNode *N, SDValue Op,
                                           void *&InsertPos) {
  if (doNotCSE(N))
    return nullptr;

  SDValue Ops[] = { Op };
  FoldingSetNodeID ID;
  AddNodeIDNode(ID, N->getOpcode(), N->getVTList(), Ops);
  AddNodeIDCustom(ID, N);
  SDNode *Node = FindNodeOrInsertPos(ID, SDLoc(N), InsertPos);
  if (Node)
    Node->intersectFlagsWith(N->getFlags());
  return Node;
}

/// FindModifiedNodeSlot - Find a slot for the specified node if its operands
/// were replaced with those specified.  If this node is never memoized,
/// return null, otherwise return a pointer to the slot it would take.  If a
/// node already exists with these operands, the slot will be non-null.
SDNode *SelectionDAG::FindModifiedNodeSlot(SDNode *N,
                                           SDValue Op1, SDValue Op2,
                                           void *&InsertPos) {
  if (doNotCSE(N))
    return nullptr;

  SDValue Ops[] = { Op1, Op2 };
  FoldingSetNodeID ID;
  AddNodeIDNode(ID, N->getOpcode(), N->getVTList(), Ops);
  AddNodeIDCustom(ID, N);
  SDNode *Node = FindNodeOrInsertPos(ID, SDLoc(N), InsertPos);
  if (Node)
    Node->intersectFlagsWith(N->getFlags());
  return Node;
}

/// FindModifiedNodeSlot - Find a slot for the specified node if its operands
/// were replaced with those specified.  If this node is never memoized,
/// return null, otherwise return a pointer to the slot it would take.  If a
/// node already exists with these operands, the slot will be non-null.
SDNode *SelectionDAG::FindModifiedNodeSlot(SDNode *N, ArrayRef<SDValue> Ops,
                                           void *&InsertPos) {
  if (doNotCSE(N))
    return nullptr;

  FoldingSetNodeID ID;
  AddNodeIDNode(ID, N->getOpcode(), N->getVTList(), Ops);
  AddNodeIDCustom(ID, N);
  SDNode *Node = FindNodeOrInsertPos(ID, SDLoc(N), InsertPos);
  if (Node)
    Node->intersectFlagsWith(N->getFlags());
  return Node;
}

Align SelectionDAG::getEVTAlign(EVT VT) const {
  Type *Ty = VT == MVT::iPTR ?
                   PointerType::get(Type::getInt8Ty(*getContext()), 0) :
                   VT.getTypeForEVT(*getContext());

  return getDataLayout().getABITypeAlign(Ty);
}

// EntryNode could meaningfully have debug info if we can find it...
SelectionDAG::SelectionDAG(const TargetMachine &tm, CodeGenOpt::Level OL)
    : TM(tm), OptLevel(OL),
      EntryNode(ISD::EntryToken, 0, DebugLoc(), getVTList(MVT::Other, MVT::Glue)),
      Root(getEntryNode()) {
  InsertNode(&EntryNode);
  DbgInfo = new SDDbgInfo();
}

void SelectionDAG::init(MachineFunction &NewMF,
                        OptimizationRemarkEmitter &NewORE, Pass *PassPtr,
                        const TargetLibraryInfo *LibraryInfo,
                        LegacyDivergenceAnalysis *Divergence,
                        ProfileSummaryInfo *PSIin, BlockFrequencyInfo *BFIin,
                        FunctionVarLocs const *VarLocs) {
  MF = &NewMF;
  SDAGISelPass = PassPtr;
  ORE = &NewORE;
  TLI = getSubtarget().getTargetLowering();
  TSI = getSubtarget().getSelectionDAGInfo();
  LibInfo = LibraryInfo;
  Context = &MF->getFunction().getContext();
  DA = Divergence;
  PSI = PSIin;
  BFI = BFIin;
  FnVarLocs = VarLocs;
}

SelectionDAG::~SelectionDAG() {
  assert(!UpdateListeners && "Dangling registered DAGUpdateListeners");
  allnodes_clear();
  OperandRecycler.clear(OperandAllocator);
  delete DbgInfo;
}

bool SelectionDAG::shouldOptForSize() const {
  return MF->getFunction().hasOptSize() ||
      llvm::shouldOptimizeForSize(FLI->MBB->getBasicBlock(), PSI, BFI);
}

void SelectionDAG::allnodes_clear() {
  assert(&*AllNodes.begin() == &EntryNode);
  AllNodes.remove(AllNodes.begin());
  while (!AllNodes.empty())
    DeallocateNode(&AllNodes.front());
#ifndef NDEBUG
  NextPersistentId = 0;
#endif
}

SDNode *SelectionDAG::FindNodeOrInsertPos(const FoldingSetNodeID &ID,
                                          void *&InsertPos) {
  SDNode *N = CSEMap.FindNodeOrInsertPos(ID, InsertPos);
  if (N) {
    switch (N->getOpcode()) {
    default: break;
    case ISD::Constant:
    case ISD::ConstantFP:
      llvm_unreachable("Querying for Constant and ConstantFP nodes requires "
                       "debug location.  Use another overload.");
    }
  }
  return N;
}

SDNode *SelectionDAG::FindNodeOrInsertPos(const FoldingSetNodeID &ID,
                                          const SDLoc &DL, void *&InsertPos) {
  SDNode *N = CSEMap.FindNodeOrInsertPos(ID, InsertPos);
  if (N) {
    switch (N->getOpcode()) {
    case ISD::Constant:
    case ISD::ConstantFP:
      // Erase debug location from the node if the node is used at several
      // different places. Do not propagate one location to all uses as it
      // will cause a worse single stepping debugging experience.
      if (N->getDebugLoc() != DL.getDebugLoc())
        N->setDebugLoc(DebugLoc());
      break;
    default:
      // When the node's point of use is located earlier in the instruction
      // sequence than its prior point of use, update its debug info to the
      // earlier location.
      if (DL.getIROrder() && DL.getIROrder() < N->getIROrder())
        N->setDebugLoc(DL.getDebugLoc());
      break;
    }
  }
  return N;
}

void SelectionDAG::clear() {
  allnodes_clear();
  OperandRecycler.clear(OperandAllocator);
  OperandAllocator.Reset();
  CSEMap.clear();

  ExtendedValueTypeNodes.clear();
  ExternalSymbols.clear();
  TargetExternalSymbols.clear();
  MCSymbols.clear();
  SDEI.clear();
  std::fill(CondCodeNodes.begin(), CondCodeNodes.end(),
            static_cast<CondCodeSDNode*>(nullptr));
  std::fill(ValueTypeNodes.begin(), ValueTypeNodes.end(),
            static_cast<SDNode*>(nullptr));

  EntryNode.UseList = nullptr;
  InsertNode(&EntryNode);
  Root = getEntryNode();
  DbgInfo->clear();
}

SDValue SelectionDAG::getFPExtendOrRound(SDValue Op, const SDLoc &DL, EVT VT) {
  return VT.bitsGT(Op.getValueType())
             ? getNode(ISD::FP_EXTEND, DL, VT, Op)
             : getNode(ISD::FP_ROUND, DL, VT, Op,
                       getIntPtrConstant(0, DL, /*isTarget=*/true));
}

std::pair<SDValue, SDValue>
SelectionDAG::getStrictFPExtendOrRound(SDValue Op, SDValue Chain,
                                       const SDLoc &DL, EVT VT) {
  assert(!VT.bitsEq(Op.getValueType()) &&
         "Strict no-op FP extend/round not allowed.");
  SDValue Res =
      VT.bitsGT(Op.getValueType())
          ? getNode(ISD::STRICT_FP_EXTEND, DL, {VT, MVT::Other}, {Chain, Op})
          : getNode(ISD::STRICT_FP_ROUND, DL, {VT, MVT::Other},
                    {Chain, Op, getIntPtrConstant(0, DL)});

  return std::pair<SDValue, SDValue>(Res, SDValue(Res.getNode(), 1));
}

SDValue SelectionDAG::getAnyExtOrTrunc(SDValue Op, const SDLoc &DL, EVT VT) {
  return VT.bitsGT(Op.getValueType()) ?
    getNode(ISD::ANY_EXTEND, DL, VT, Op) :
    getNode(ISD::TRUNCATE, DL, VT, Op);
}

SDValue SelectionDAG::getSExtOrTrunc(SDValue Op, const SDLoc &DL, EVT VT) {
  return VT.bitsGT(Op.getValueType()) ?
    getNode(ISD::SIGN_EXTEND, DL, VT, Op) :
    getNode(ISD::TRUNCATE, DL, VT, Op);
}

SDValue SelectionDAG::getZExtOrTrunc(SDValue Op, const SDLoc &DL, EVT VT) {
  return VT.bitsGT(Op.getValueType()) ?
    getNode(ISD::ZERO_EXTEND, DL, VT, Op) :
    getNode(ISD::TRUNCATE, DL, VT, Op);
}

SDValue SelectionDAG::getBoolExtOrTrunc(SDValue Op, const SDLoc &SL, EVT VT,
                                        EVT OpVT) {
  if (VT.bitsLE(Op.getValueType()))
    return getNode(ISD::TRUNCATE, SL, VT, Op);

  TargetLowering::BooleanContent BType = TLI->getBooleanContents(OpVT);
  return getNode(TLI->getExtendForContent(BType), SL, VT, Op);
}

SDValue SelectionDAG::getZeroExtendInReg(SDValue Op, const SDLoc &DL, EVT VT) {
  EVT OpVT = Op.getValueType();
  assert(VT.isInteger() && OpVT.isInteger() &&
         "Cannot getZeroExtendInReg FP types");
  assert(VT.isVector() == OpVT.isVector() &&
         "getZeroExtendInReg type should be vector iff the operand "
         "type is vector!");
  assert((!VT.isVector() ||
          VT.getVectorElementCount() == OpVT.getVectorElementCount()) &&
         "Vector element counts must match in getZeroExtendInReg");
  assert(VT.bitsLE(OpVT) && "Not extending!");
  if (OpVT == VT)
    return Op;
1449   APInt Imm = APInt::getLowBitsSet(OpVT.getScalarSizeInBits(),
1450                                    VT.getScalarSizeInBits());
1451   return getNode(ISD::AND, DL, OpVT, Op, getConstant(Imm, DL, OpVT));
1452 }
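
// For example (illustrative; DAG and V are hypothetical): zero-extending the
// low 8 bits in place within an i32 value masks with 0xFF:
//   SDValue M = DAG.getZeroExtendInReg(V /*i32*/, DL, MVT::i8);
//   // emits: (and V, (i32 255))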

SDValue SelectionDAG::getPtrExtOrTrunc(SDValue Op, const SDLoc &DL, EVT VT) {
  // Only unsigned pointer semantics are supported right now. In the future
  // this might delegate to TLI to check pointer signedness.
  return getZExtOrTrunc(Op, DL, VT);
}

SDValue SelectionDAG::getPtrExtendInReg(SDValue Op, const SDLoc &DL, EVT VT) {
  // Only unsigned pointer semantics are supported right now. In the future
  // this might delegate to TLI to check pointer signedness.
  return getZeroExtendInReg(Op, DL, VT);
}

SDValue SelectionDAG::getNegative(SDValue Val, const SDLoc &DL, EVT VT) {
  return getNode(ISD::SUB, DL, VT, getConstant(0, DL, VT), Val);
}

/// getNOT - Create a bitwise NOT operation as (XOR Val, -1).
SDValue SelectionDAG::getNOT(const SDLoc &DL, SDValue Val, EVT VT) {
  return getNode(ISD::XOR, DL, VT, Val, getAllOnesConstant(DL, VT));
}

SDValue SelectionDAG::getLogicalNOT(const SDLoc &DL, SDValue Val, EVT VT) {
  SDValue TrueValue = getBoolConstant(true, DL, VT, VT);
  return getNode(ISD::XOR, DL, VT, Val, TrueValue);
}

SDValue SelectionDAG::getVPLogicalNOT(const SDLoc &DL, SDValue Val,
                                      SDValue Mask, SDValue EVL, EVT VT) {
  SDValue TrueValue = getBoolConstant(true, DL, VT, VT);
  return getNode(ISD::VP_XOR, DL, VT, Val, TrueValue, Mask, EVL);
}

SDValue SelectionDAG::getVPPtrExtOrTrunc(const SDLoc &DL, EVT VT, SDValue Op,
                                         SDValue Mask, SDValue EVL) {
  return getVPZExtOrTrunc(DL, VT, Op, Mask, EVL);
}

SDValue SelectionDAG::getVPZExtOrTrunc(const SDLoc &DL, EVT VT, SDValue Op,
                                       SDValue Mask, SDValue EVL) {
  if (VT.bitsGT(Op.getValueType()))
    return getNode(ISD::VP_ZERO_EXTEND, DL, VT, Op, Mask, EVL);
  if (VT.bitsLT(Op.getValueType()))
    return getNode(ISD::VP_TRUNCATE, DL, VT, Op, Mask, EVL);
  return Op;
}

SDValue SelectionDAG::getBoolConstant(bool V, const SDLoc &DL, EVT VT,
                                      EVT OpVT) {
  if (!V)
    return getConstant(0, DL, VT);

  switch (TLI->getBooleanContents(OpVT)) {
  case TargetLowering::ZeroOrOneBooleanContent:
  case TargetLowering::UndefinedBooleanContent:
    return getConstant(1, DL, VT);
  case TargetLowering::ZeroOrNegativeOneBooleanContent:
    return getAllOnesConstant(DL, VT);
  }
  llvm_unreachable("Unexpected boolean content enum!");
}
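
// Illustrative mapping (depends on the target's getBooleanContents(OpVT)):
//   ZeroOrOneBooleanContent / UndefinedBooleanContent -> true lowers to 1
//   ZeroOrNegativeOneBooleanContent                   -> true lowers to -1
// false always lowers to 0, regardless of the boolean content.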

SDValue SelectionDAG::getConstant(uint64_t Val, const SDLoc &DL, EVT VT,
                                  bool isT, bool isO) {
  EVT EltVT = VT.getScalarType();
  assert((EltVT.getSizeInBits() >= 64 ||
          (uint64_t)((int64_t)Val >> EltVT.getSizeInBits()) + 1 < 2) &&
         "getConstant with a uint64_t value that doesn't fit in the type!");
  return getConstant(APInt(EltVT.getSizeInBits(), Val), DL, VT, isT, isO);
}

SDValue SelectionDAG::getConstant(const APInt &Val, const SDLoc &DL, EVT VT,
                                  bool isT, bool isO) {
  return getConstant(*ConstantInt::get(*Context, Val), DL, VT, isT, isO);
}

SDValue SelectionDAG::getConstant(const ConstantInt &Val, const SDLoc &DL,
                                  EVT VT, bool isT, bool isO) {
  assert(VT.isInteger() && "Cannot create FP integer constant!");

  EVT EltVT = VT.getScalarType();
  const ConstantInt *Elt = &Val;

  // In some cases the vector type is legal but the element type is illegal and
  // needs to be promoted, for example v8i8 on ARM.  In this case, promote the
  // inserted value (the type does not need to match the vector element type).
  // Any extra bits introduced will be truncated away.
  if (VT.isVector() && TLI->getTypeAction(*getContext(), EltVT) ==
                           TargetLowering::TypePromoteInteger) {
    EltVT = TLI->getTypeToTransformTo(*getContext(), EltVT);
    APInt NewVal = Elt->getValue().zextOrTrunc(EltVT.getSizeInBits());
    Elt = ConstantInt::get(*getContext(), NewVal);
  }
  // In other cases the element type is illegal and needs to be expanded, for
  // example v2i64 on MIPS32. In this case, find the nearest legal type, split
  // the value into n parts and use a vector type with n-times the elements.
  // Then bitcast to the type requested.
  // Legalizing constants too early makes the DAGCombiner's job harder, so we
  // only legalize if the DAG tells us we must produce legal types.
  else if (NewNodesMustHaveLegalTypes && VT.isVector() &&
           TLI->getTypeAction(*getContext(), EltVT) ==
               TargetLowering::TypeExpandInteger) {
    const APInt &NewVal = Elt->getValue();
    EVT ViaEltVT = TLI->getTypeToTransformTo(*getContext(), EltVT);
    unsigned ViaEltSizeInBits = ViaEltVT.getSizeInBits();

    // For scalable vectors, try to use a SPLAT_VECTOR_PARTS node.
    if (VT.isScalableVector()) {
      assert(EltVT.getSizeInBits() % ViaEltSizeInBits == 0 &&
             "Can only handle an even split!");
      unsigned Parts = EltVT.getSizeInBits() / ViaEltSizeInBits;

      SmallVector<SDValue, 2> ScalarParts;
      for (unsigned i = 0; i != Parts; ++i)
        ScalarParts.push_back(getConstant(
            NewVal.extractBits(ViaEltSizeInBits, i * ViaEltSizeInBits), DL,
            ViaEltVT, isT, isO));

      return getNode(ISD::SPLAT_VECTOR_PARTS, DL, VT, ScalarParts);
    }

    unsigned ViaVecNumElts = VT.getSizeInBits() / ViaEltSizeInBits;
    EVT ViaVecVT = EVT::getVectorVT(*getContext(), ViaEltVT, ViaVecNumElts);

    // Check that the temporary vector is the correct size. If this fails then
    // getTypeToTransformTo() probably returned a type whose size (in bits)
    // isn't a power-of-2 factor of the requested type size.
    assert(ViaVecVT.getSizeInBits() == VT.getSizeInBits());

    SmallVector<SDValue, 2> EltParts;
    for (unsigned i = 0; i < ViaVecNumElts / VT.getVectorNumElements(); ++i)
      EltParts.push_back(getConstant(
          NewVal.extractBits(ViaEltSizeInBits, i * ViaEltSizeInBits), DL,
          ViaEltVT, isT, isO));

    // EltParts is currently in little-endian order. If we actually want
    // big-endian order then reverse it now.
    if (getDataLayout().isBigEndian())
      std::reverse(EltParts.begin(), EltParts.end());

    // The elements must be reversed when the element order is different from
    // the endianness of the elements (because the BITCAST is itself a
    // vector shuffle in this situation). However, we do not need any code to
    // perform this reversal because getConstant() is producing a vector
    // splat.
    // This situation occurs in MIPS MSA.

    SmallVector<SDValue, 8> Ops;
    for (unsigned i = 0, e = VT.getVectorNumElements(); i != e; ++i)
      llvm::append_range(Ops, EltParts);

    SDValue V =
        getNode(ISD::BITCAST, DL, VT, getBuildVector(ViaVecVT, DL, Ops));
    return V;
  }

  assert(Elt->getBitWidth() == EltVT.getSizeInBits() &&
         "APInt size does not match type size!");
  unsigned Opc = isT ? ISD::TargetConstant : ISD::Constant;
  FoldingSetNodeID ID;
  AddNodeIDNode(ID, Opc, getVTList(EltVT), std::nullopt);
  ID.AddPointer(Elt);
  ID.AddBoolean(isO);
  void *IP = nullptr;
  SDNode *N = nullptr;
  if ((N = FindNodeOrInsertPos(ID, DL, IP)))
    if (!VT.isVector())
      return SDValue(N, 0);

  if (!N) {
    N = newSDNode<ConstantSDNode>(isT, isO, Elt, EltVT);
    CSEMap.InsertNode(N, IP);
    InsertNode(N);
    NewSDValueDbgMsg(SDValue(N, 0), "Creating constant: ", this);
  }

  SDValue Result(N, 0);
  if (VT.isVector())
    Result = getSplat(VT, DL, Result);
  return Result;
}
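
// Illustrative sketch of the expansion path above: on a 32-bit target where
// i64 is expanded (e.g. MIPS32), a v2i64 splat of a constant K becomes a
// v4i32 BUILD_VECTOR of K's 32-bit halves, bitcast back to the requested
// type:
//   getConstant(K, DL, MVT::v2i64)
//     -> (v2i64 (bitcast (build_vector lo(K), hi(K), lo(K), hi(K))))
// with the half order reversed first on big-endian targets.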

SDValue SelectionDAG::getIntPtrConstant(uint64_t Val, const SDLoc &DL,
                                        bool isTarget) {
  return getConstant(Val, DL, TLI->getPointerTy(getDataLayout()), isTarget);
}

SDValue SelectionDAG::getShiftAmountConstant(uint64_t Val, EVT VT,
                                             const SDLoc &DL, bool LegalTypes) {
  assert(VT.isInteger() && "Shift amount is not an integer type!");
  EVT ShiftVT = TLI->getShiftAmountTy(VT, getDataLayout(), LegalTypes);
  return getConstant(Val, DL, ShiftVT);
}

SDValue SelectionDAG::getVectorIdxConstant(uint64_t Val, const SDLoc &DL,
                                           bool isTarget) {
  return getConstant(Val, DL, TLI->getVectorIdxTy(getDataLayout()), isTarget);
}

SDValue SelectionDAG::getConstantFP(const APFloat &V, const SDLoc &DL, EVT VT,
                                    bool isTarget) {
  return getConstantFP(*ConstantFP::get(*getContext(), V), DL, VT, isTarget);
}

SDValue SelectionDAG::getConstantFP(const ConstantFP &V, const SDLoc &DL,
                                    EVT VT, bool isTarget) {
  assert(VT.isFloatingPoint() && "Cannot create integer FP constant!");

  EVT EltVT = VT.getScalarType();

  // Do the map lookup using the actual bit pattern for the floating point
  // value, so that we don't have problems with 0.0 comparing equal to -0.0, and
  // we don't have issues with SNANs.
  unsigned Opc = isTarget ? ISD::TargetConstantFP : ISD::ConstantFP;
  FoldingSetNodeID ID;
  AddNodeIDNode(ID, Opc, getVTList(EltVT), std::nullopt);
  ID.AddPointer(&V);
  void *IP = nullptr;
  SDNode *N = nullptr;
  if ((N = FindNodeOrInsertPos(ID, DL, IP)))
    if (!VT.isVector())
      return SDValue(N, 0);

  if (!N) {
    N = newSDNode<ConstantFPSDNode>(isTarget, &V, EltVT);
    CSEMap.InsertNode(N, IP);
    InsertNode(N);
  }

  SDValue Result(N, 0);
  if (VT.isVector())
    Result = getSplat(VT, DL, Result);
  NewSDValueDbgMsg(Result, "Creating fp constant: ", this);
  return Result;
}

SDValue SelectionDAG::getConstantFP(double Val, const SDLoc &DL, EVT VT,
                                    bool isTarget) {
  EVT EltVT = VT.getScalarType();
  if (EltVT == MVT::f32)
    return getConstantFP(APFloat((float)Val), DL, VT, isTarget);
  if (EltVT == MVT::f64)
    return getConstantFP(APFloat(Val), DL, VT, isTarget);
  if (EltVT == MVT::f80 || EltVT == MVT::f128 || EltVT == MVT::ppcf128 ||
      EltVT == MVT::f16 || EltVT == MVT::bf16) {
    bool Ignored;
    APFloat APF = APFloat(Val);
    APF.convert(EVTToAPFloatSemantics(EltVT), APFloat::rmNearestTiesToEven,
                &Ignored);
    return getConstantFP(APF, DL, VT, isTarget);
  }
  llvm_unreachable("Unsupported type in getConstantFP");
}
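
// For example (illustrative): getConstantFP(0.5, DL, MVT::f16) converts the
// double 0.5 to half precision with rmNearestTiesToEven before creating the
// node; 0.5 is exactly representable in f16, so no precision is lost here.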

SDValue SelectionDAG::getGlobalAddress(const GlobalValue *GV, const SDLoc &DL,
                                       EVT VT, int64_t Offset, bool isTargetGA,
                                       unsigned TargetFlags) {
  assert((TargetFlags == 0 || isTargetGA) &&
         "Cannot set target flags on target-independent globals");

  // Truncate (with sign-extension) the offset value to the pointer size.
  unsigned BitWidth = getDataLayout().getPointerTypeSizeInBits(GV->getType());
  if (BitWidth < 64)
    Offset = SignExtend64(Offset, BitWidth);

  unsigned Opc;
  if (GV->isThreadLocal())
    Opc = isTargetGA ? ISD::TargetGlobalTLSAddress : ISD::GlobalTLSAddress;
  else
    Opc = isTargetGA ? ISD::TargetGlobalAddress : ISD::GlobalAddress;

  FoldingSetNodeID ID;
  AddNodeIDNode(ID, Opc, getVTList(VT), std::nullopt);
  ID.AddPointer(GV);
  ID.AddInteger(Offset);
  ID.AddInteger(TargetFlags);
  void *IP = nullptr;
  if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP))
    return SDValue(E, 0);

  auto *N = newSDNode<GlobalAddressSDNode>(
      Opc, DL.getIROrder(), DL.getDebugLoc(), GV, VT, Offset, TargetFlags);
  CSEMap.InsertNode(N, IP);
  InsertNode(N);
  return SDValue(N, 0);
}
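
// Illustrative example of the offset canonicalization above: with 32-bit
// pointers, an offset of 0xFFFFFFFF is sign-extended to -1, so requests for
// GV+0xFFFFFFFF and GV-1 fold to the same GlobalAddress node in the CSE map.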

SDValue SelectionDAG::getFrameIndex(int FI, EVT VT, bool isTarget) {
  unsigned Opc = isTarget ? ISD::TargetFrameIndex : ISD::FrameIndex;
  FoldingSetNodeID ID;
  AddNodeIDNode(ID, Opc, getVTList(VT), std::nullopt);
  ID.AddInteger(FI);
  void *IP = nullptr;
  if (SDNode *E = FindNodeOrInsertPos(ID, IP))
    return SDValue(E, 0);

  auto *N = newSDNode<FrameIndexSDNode>(FI, VT, isTarget);
  CSEMap.InsertNode(N, IP);
  InsertNode(N);
  return SDValue(N, 0);
}

SDValue SelectionDAG::getJumpTable(int JTI, EVT VT, bool isTarget,
                                   unsigned TargetFlags) {
  assert((TargetFlags == 0 || isTarget) &&
         "Cannot set target flags on target-independent jump tables");
  unsigned Opc = isTarget ? ISD::TargetJumpTable : ISD::JumpTable;
  FoldingSetNodeID ID;
  AddNodeIDNode(ID, Opc, getVTList(VT), std::nullopt);
  ID.AddInteger(JTI);
  ID.AddInteger(TargetFlags);
  void *IP = nullptr;
  if (SDNode *E = FindNodeOrInsertPos(ID, IP))
    return SDValue(E, 0);

  auto *N = newSDNode<JumpTableSDNode>(JTI, VT, isTarget, TargetFlags);
  CSEMap.InsertNode(N, IP);
  InsertNode(N);
  return SDValue(N, 0);
}

SDValue SelectionDAG::getConstantPool(const Constant *C, EVT VT,
                                      MaybeAlign Alignment, int Offset,
                                      bool isTarget, unsigned TargetFlags) {
  assert((TargetFlags == 0 || isTarget) &&
         "Cannot set target flags on target-independent constant pools");
  if (!Alignment)
    Alignment = shouldOptForSize()
                    ? getDataLayout().getABITypeAlign(C->getType())
                    : getDataLayout().getPrefTypeAlign(C->getType());
  unsigned Opc = isTarget ? ISD::TargetConstantPool : ISD::ConstantPool;
  FoldingSetNodeID ID;
  AddNodeIDNode(ID, Opc, getVTList(VT), std::nullopt);
  ID.AddInteger(Alignment->value());
  ID.AddInteger(Offset);
  ID.AddPointer(C);
  ID.AddInteger(TargetFlags);
  void *IP = nullptr;
  if (SDNode *E = FindNodeOrInsertPos(ID, IP))
    return SDValue(E, 0);

  auto *N = newSDNode<ConstantPoolSDNode>(isTarget, C, VT, Offset, *Alignment,
                                          TargetFlags);
  CSEMap.InsertNode(N, IP);
  InsertNode(N);
  SDValue V = SDValue(N, 0);
  NewSDValueDbgMsg(V, "Creating new constant pool: ", this);
  return V;
}

SDValue SelectionDAG::getConstantPool(MachineConstantPoolValue *C, EVT VT,
                                      MaybeAlign Alignment, int Offset,
                                      bool isTarget, unsigned TargetFlags) {
  assert((TargetFlags == 0 || isTarget) &&
         "Cannot set target flags on target-independent constant pools");
  if (!Alignment)
    Alignment = getDataLayout().getPrefTypeAlign(C->getType());
  unsigned Opc = isTarget ? ISD::TargetConstantPool : ISD::ConstantPool;
  FoldingSetNodeID ID;
  AddNodeIDNode(ID, Opc, getVTList(VT), std::nullopt);
  ID.AddInteger(Alignment->value());
  ID.AddInteger(Offset);
  C->addSelectionDAGCSEId(ID);
  ID.AddInteger(TargetFlags);
  void *IP = nullptr;
  if (SDNode *E = FindNodeOrInsertPos(ID, IP))
    return SDValue(E, 0);

  auto *N = newSDNode<ConstantPoolSDNode>(isTarget, C, VT, Offset, *Alignment,
                                          TargetFlags);
  CSEMap.InsertNode(N, IP);
  InsertNode(N);
  return SDValue(N, 0);
}

SDValue SelectionDAG::getTargetIndex(int Index, EVT VT, int64_t Offset,
                                     unsigned TargetFlags) {
  FoldingSetNodeID ID;
  AddNodeIDNode(ID, ISD::TargetIndex, getVTList(VT), std::nullopt);
  ID.AddInteger(Index);
  ID.AddInteger(Offset);
  ID.AddInteger(TargetFlags);
  void *IP = nullptr;
  if (SDNode *E = FindNodeOrInsertPos(ID, IP))
    return SDValue(E, 0);

  auto *N = newSDNode<TargetIndexSDNode>(Index, VT, Offset, TargetFlags);
  CSEMap.InsertNode(N, IP);
  InsertNode(N);
  return SDValue(N, 0);
}

SDValue SelectionDAG::getBasicBlock(MachineBasicBlock *MBB) {
  FoldingSetNodeID ID;
  AddNodeIDNode(ID, ISD::BasicBlock, getVTList(MVT::Other), std::nullopt);
  ID.AddPointer(MBB);
  void *IP = nullptr;
  if (SDNode *E = FindNodeOrInsertPos(ID, IP))
    return SDValue(E, 0);

  auto *N = newSDNode<BasicBlockSDNode>(MBB);
  CSEMap.InsertNode(N, IP);
  InsertNode(N);
  return SDValue(N, 0);
}

SDValue SelectionDAG::getValueType(EVT VT) {
  if (VT.isSimple() && (unsigned)VT.getSimpleVT().SimpleTy >=
      ValueTypeNodes.size())
    ValueTypeNodes.resize(VT.getSimpleVT().SimpleTy+1);

  SDNode *&N = VT.isExtended() ?
    ExtendedValueTypeNodes[VT] : ValueTypeNodes[VT.getSimpleVT().SimpleTy];

  if (N) return SDValue(N, 0);
  N = newSDNode<VTSDNode>(VT);
  InsertNode(N);
  return SDValue(N, 0);
}

SDValue SelectionDAG::getExternalSymbol(const char *Sym, EVT VT) {
  SDNode *&N = ExternalSymbols[Sym];
  if (N) return SDValue(N, 0);
  N = newSDNode<ExternalSymbolSDNode>(false, Sym, 0, VT);
  InsertNode(N);
  return SDValue(N, 0);
}

SDValue SelectionDAG::getMCSymbol(MCSymbol *Sym, EVT VT) {
  SDNode *&N = MCSymbols[Sym];
  if (N)
    return SDValue(N, 0);
  N = newSDNode<MCSymbolSDNode>(Sym, VT);
  InsertNode(N);
  return SDValue(N, 0);
}

SDValue SelectionDAG::getTargetExternalSymbol(const char *Sym, EVT VT,
                                              unsigned TargetFlags) {
  SDNode *&N =
      TargetExternalSymbols[std::pair<std::string, unsigned>(Sym, TargetFlags)];
  if (N) return SDValue(N, 0);
  N = newSDNode<ExternalSymbolSDNode>(true, Sym, TargetFlags, VT);
  InsertNode(N);
  return SDValue(N, 0);
}

SDValue SelectionDAG::getCondCode(ISD::CondCode Cond) {
  if ((unsigned)Cond >= CondCodeNodes.size())
    CondCodeNodes.resize(Cond+1);

  if (!CondCodeNodes[Cond]) {
    auto *N = newSDNode<CondCodeSDNode>(Cond);
    CondCodeNodes[Cond] = N;
    InsertNode(N);
  }

  return SDValue(CondCodeNodes[Cond], 0);
}

SDValue SelectionDAG::getStepVector(const SDLoc &DL, EVT ResVT) {
  APInt One(ResVT.getScalarSizeInBits(), 1);
  return getStepVector(DL, ResVT, One);
}

SDValue SelectionDAG::getStepVector(const SDLoc &DL, EVT ResVT, APInt StepVal) {
  assert(ResVT.getScalarSizeInBits() == StepVal.getBitWidth());
  if (ResVT.isScalableVector())
    return getNode(
        ISD::STEP_VECTOR, DL, ResVT,
        getTargetConstant(StepVal, DL, ResVT.getVectorElementType()));

  SmallVector<SDValue, 16> OpsStepConstants;
  for (uint64_t i = 0; i < ResVT.getVectorNumElements(); i++)
    OpsStepConstants.push_back(
        getConstant(StepVal * i, DL, ResVT.getVectorElementType()));
  return getBuildVector(ResVT, DL, OpsStepConstants);
}
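
// For example (illustrative): for a fixed-length v4i32 result,
//   getStepVector(DL, MVT::v4i32)              -> (build_vector 0, 1, 2, 3)
//   getStepVector(DL, MVT::v4i32, APInt(32,2)) -> (build_vector 0, 2, 4, 6)
// whereas for a scalable type the same request yields a single STEP_VECTOR
// node carrying the step as a target constant.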

/// Swaps the values of N1 and N2. Swaps all indices in the shuffle mask M that
/// point at N1 to point at N2 and indices that point at N2 to point at N1.
static void commuteShuffle(SDValue &N1, SDValue &N2, MutableArrayRef<int> M) {
  std::swap(N1, N2);
  ShuffleVectorSDNode::commuteMask(M);
}

SDValue SelectionDAG::getVectorShuffle(EVT VT, const SDLoc &dl, SDValue N1,
                                       SDValue N2, ArrayRef<int> Mask) {
  assert(VT.getVectorNumElements() == Mask.size() &&
         "Must have the same number of vector elements as mask elements!");
  assert(VT == N1.getValueType() && VT == N2.getValueType() &&
         "Invalid VECTOR_SHUFFLE");

  // Canonicalize shuffle undef, undef -> undef
  if (N1.isUndef() && N2.isUndef())
    return getUNDEF(VT);

  // Validate that all indices in Mask are within the range of the elements
  // input to the shuffle.
  int NElts = Mask.size();
  assert(llvm::all_of(Mask,
                      [&](int M) { return M < (NElts * 2) && M >= -1; }) &&
         "Index out of range");

  // Copy the mask so we can do any needed cleanup.
  SmallVector<int, 8> MaskVec(Mask);

  // Canonicalize shuffle v, v -> v, undef
  if (N1 == N2) {
    N2 = getUNDEF(VT);
    for (int i = 0; i != NElts; ++i)
      if (MaskVec[i] >= NElts) MaskVec[i] -= NElts;
  }

  // Canonicalize shuffle undef, v -> v, undef.  Commute the shuffle mask.
  if (N1.isUndef())
    commuteShuffle(N1, N2, MaskVec);

  if (TLI->hasVectorBlend()) {
    // If shuffling a splat, try to blend the splat instead. We do this here so
    // that even when this arises during lowering we don't have to re-handle it.
    auto BlendSplat = [&](BuildVectorSDNode *BV, int Offset) {
      BitVector UndefElements;
      SDValue Splat = BV->getSplatValue(&UndefElements);
      if (!Splat)
        return;

      for (int i = 0; i < NElts; ++i) {
        if (MaskVec[i] < Offset || MaskVec[i] >= (Offset + NElts))
          continue;

        // If this input comes from undef, mark it as such.
        if (UndefElements[MaskVec[i] - Offset]) {
          MaskVec[i] = -1;
          continue;
        }

        // If we can blend a non-undef lane, use that instead.
        if (!UndefElements[i])
          MaskVec[i] = i + Offset;
      }
    };
    if (auto *N1BV = dyn_cast<BuildVectorSDNode>(N1))
      BlendSplat(N1BV, 0);
    if (auto *N2BV = dyn_cast<BuildVectorSDNode>(N2))
      BlendSplat(N2BV, NElts);
  }

  // Canonicalize all indices into lhs -> shuffle lhs, undef
  // Canonicalize all indices into rhs -> shuffle rhs, undef
  bool AllLHS = true, AllRHS = true;
  bool N2Undef = N2.isUndef();
  for (int i = 0; i != NElts; ++i) {
    if (MaskVec[i] >= NElts) {
      if (N2Undef)
        MaskVec[i] = -1;
      else
        AllLHS = false;
    } else if (MaskVec[i] >= 0) {
      AllRHS = false;
    }
  }
  if (AllLHS && AllRHS)
    return getUNDEF(VT);
  if (AllLHS && !N2Undef)
    N2 = getUNDEF(VT);
  if (AllRHS) {
    N1 = getUNDEF(VT);
    commuteShuffle(N1, N2, MaskVec);
  }
  // Reset our undef status after accounting for the mask.
  N2Undef = N2.isUndef();
  // Re-check whether both sides ended up undef.
  if (N1.isUndef() && N2Undef)
    return getUNDEF(VT);

  // If this is an identity shuffle, return the LHS operand.
  bool Identity = true, AllSame = true;
  for (int i = 0; i != NElts; ++i) {
    if (MaskVec[i] >= 0 && MaskVec[i] != i) Identity = false;
    if (MaskVec[i] != MaskVec[0]) AllSame = false;
  }
  if (Identity && NElts)
    return N1;

  // Shuffling a constant splat doesn't change the result.
  if (N2Undef) {
    SDValue V = N1;

    // Look through any bitcasts. We check that these don't change the number
    // (and size) of elements and only change their types.
    while (V.getOpcode() == ISD::BITCAST)
      V = V->getOperand(0);

    // A splat should always show up as a build vector node.
    if (auto *BV = dyn_cast<BuildVectorSDNode>(V)) {
      BitVector UndefElements;
      SDValue Splat = BV->getSplatValue(&UndefElements);
      // If this is a splat of an undef, shuffling it is also undef.
      if (Splat && Splat.isUndef())
        return getUNDEF(VT);

      bool SameNumElts =
          V.getValueType().getVectorNumElements() == VT.getVectorNumElements();

      // The shuffle can only be skipped if there is a splatted value and the
      // shuffle does not rearrange any undef lanes.
      if (Splat && UndefElements.none()) {
        // Splat of <x, x, ..., x>, return <x, x, ..., x>, provided that the
        // element counts match or the splatted value is a zero constant.
        if (SameNumElts)
          return N1;
        if (auto *C = dyn_cast<ConstantSDNode>(Splat))
          if (C->isZero())
            return N1;
      }

      // If the shuffle itself creates a splat, build the vector directly.
      if (AllSame && SameNumElts) {
        EVT BuildVT = BV->getValueType(0);
        const SDValue &Splatted = BV->getOperand(MaskVec[0]);
        SDValue NewBV = getSplatBuildVector(BuildVT, dl, Splatted);

        // We may have jumped through bitcasts, so the type of the
        // BUILD_VECTOR may not match the type of the shuffle.
        if (BuildVT != VT)
          NewBV = getNode(ISD::BITCAST, dl, VT, NewBV);
        return NewBV;
      }
    }
  }

  FoldingSetNodeID ID;
  SDValue Ops[2] = { N1, N2 };
  AddNodeIDNode(ID, ISD::VECTOR_SHUFFLE, getVTList(VT), Ops);
  for (int i = 0; i != NElts; ++i)
    ID.AddInteger(MaskVec[i]);

  void* IP = nullptr;
  if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP))
    return SDValue(E, 0);

  // Allocate the mask array for the node out of the BumpPtrAllocator, since
  // SDNode doesn't have access to it.  This memory will be "leaked" when
  // the node is deallocated, but recovered when the NodeAllocator is released.
  int *MaskAlloc = OperandAllocator.Allocate<int>(NElts);
  llvm::copy(MaskVec, MaskAlloc);

  auto *N = newSDNode<ShuffleVectorSDNode>(VT, dl.getIROrder(),
                                           dl.getDebugLoc(), MaskAlloc);
  createOperands(N, Ops);

  CSEMap.InsertNode(N, IP);
  InsertNode(N);
  SDValue V = SDValue(N, 0);
  NewSDValueDbgMsg(V, "Creating new node: ", this);
  return V;
}
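
// Illustrative examples of the canonicalizations above (fixed v4i32):
//   shuffle A, A, <4,5,1,2>     becomes  shuffle A, undef, <0,1,1,2>
//   shuffle undef, B, <4,5,6,7> becomes  shuffle B, undef, <0,1,2,3>,
// and the latter is then recognized as an identity shuffle and folded to B.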

SDValue SelectionDAG::getCommutedVectorShuffle(const ShuffleVectorSDNode &SV) {
  EVT VT = SV.getValueType(0);
  SmallVector<int, 8> MaskVec(SV.getMask());
  ShuffleVectorSDNode::commuteMask(MaskVec);

  SDValue Op0 = SV.getOperand(0);
  SDValue Op1 = SV.getOperand(1);
  return getVectorShuffle(VT, SDLoc(&SV), Op1, Op0, MaskVec);
}

SDValue SelectionDAG::getRegister(unsigned RegNo, EVT VT) {
  FoldingSetNodeID ID;
  AddNodeIDNode(ID, ISD::Register, getVTList(VT), std::nullopt);
  ID.AddInteger(RegNo);
  void *IP = nullptr;
  if (SDNode *E = FindNodeOrInsertPos(ID, IP))
    return SDValue(E, 0);

  auto *N = newSDNode<RegisterSDNode>(RegNo, VT);
  N->SDNodeBits.IsDivergent = TLI->isSDNodeSourceOfDivergence(N, FLI, DA);
  CSEMap.InsertNode(N, IP);
  InsertNode(N);
  return SDValue(N, 0);
}

SDValue SelectionDAG::getRegisterMask(const uint32_t *RegMask) {
  FoldingSetNodeID ID;
  AddNodeIDNode(ID, ISD::RegisterMask, getVTList(MVT::Untyped), std::nullopt);
  ID.AddPointer(RegMask);
  void *IP = nullptr;
  if (SDNode *E = FindNodeOrInsertPos(ID, IP))
    return SDValue(E, 0);

  auto *N = newSDNode<RegisterMaskSDNode>(RegMask);
  CSEMap.InsertNode(N, IP);
  InsertNode(N);
  return SDValue(N, 0);
}

SDValue SelectionDAG::getEHLabel(const SDLoc &dl, SDValue Root,
                                 MCSymbol *Label) {
  return getLabelNode(ISD::EH_LABEL, dl, Root, Label);
}

SDValue SelectionDAG::getLabelNode(unsigned Opcode, const SDLoc &dl,
                                   SDValue Root, MCSymbol *Label) {
  FoldingSetNodeID ID;
  SDValue Ops[] = { Root };
  AddNodeIDNode(ID, Opcode, getVTList(MVT::Other), Ops);
  ID.AddPointer(Label);
  void *IP = nullptr;
  if (SDNode *E = FindNodeOrInsertPos(ID, IP))
    return SDValue(E, 0);

  auto *N =
      newSDNode<LabelSDNode>(Opcode, dl.getIROrder(), dl.getDebugLoc(), Label);
  createOperands(N, Ops);

  CSEMap.InsertNode(N, IP);
  InsertNode(N);
  return SDValue(N, 0);
}

SDValue SelectionDAG::getBlockAddress(const BlockAddress *BA, EVT VT,
                                      int64_t Offset, bool isTarget,
                                      unsigned TargetFlags) {
  unsigned Opc = isTarget ? ISD::TargetBlockAddress : ISD::BlockAddress;

  FoldingSetNodeID ID;
  AddNodeIDNode(ID, Opc, getVTList(VT), std::nullopt);
  ID.AddPointer(BA);
  ID.AddInteger(Offset);
  ID.AddInteger(TargetFlags);
  void *IP = nullptr;
  if (SDNode *E = FindNodeOrInsertPos(ID, IP))
    return SDValue(E, 0);

  auto *N = newSDNode<BlockAddressSDNode>(Opc, VT, BA, Offset, TargetFlags);
  CSEMap.InsertNode(N, IP);
  InsertNode(N);
  return SDValue(N, 0);
}

SDValue SelectionDAG::getSrcValue(const Value *V) {
  FoldingSetNodeID ID;
  AddNodeIDNode(ID, ISD::SRCVALUE, getVTList(MVT::Other), std::nullopt);
  ID.AddPointer(V);

  void *IP = nullptr;
  if (SDNode *E = FindNodeOrInsertPos(ID, IP))
    return SDValue(E, 0);

  auto *N = newSDNode<SrcValueSDNode>(V);
  CSEMap.InsertNode(N, IP);
  InsertNode(N);
  return SDValue(N, 0);
}

SDValue SelectionDAG::getMDNode(const MDNode *MD) {
  FoldingSetNodeID ID;
  AddNodeIDNode(ID, ISD::MDNODE_SDNODE, getVTList(MVT::Other), std::nullopt);
  ID.AddPointer(MD);

  void *IP = nullptr;
  if (SDNode *E = FindNodeOrInsertPos(ID, IP))
    return SDValue(E, 0);

  auto *N = newSDNode<MDNodeSDNode>(MD);
  CSEMap.InsertNode(N, IP);
  InsertNode(N);
  return SDValue(N, 0);
}

SDValue SelectionDAG::getBitcast(EVT VT, SDValue V) {
  if (VT == V.getValueType())
    return V;

  return getNode(ISD::BITCAST, SDLoc(V), VT, V);
}

SDValue SelectionDAG::getAddrSpaceCast(const SDLoc &dl, EVT VT, SDValue Ptr,
                                       unsigned SrcAS, unsigned DestAS) {
  SDValue Ops[] = {Ptr};
  FoldingSetNodeID ID;
  AddNodeIDNode(ID, ISD::ADDRSPACECAST, getVTList(VT), Ops);
  ID.AddInteger(SrcAS);
  ID.AddInteger(DestAS);

  void *IP = nullptr;
  if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP))
    return SDValue(E, 0);

  auto *N = newSDNode<AddrSpaceCastSDNode>(dl.getIROrder(), dl.getDebugLoc(),
                                           VT, SrcAS, DestAS);
  createOperands(N, Ops);

  CSEMap.InsertNode(N, IP);
  InsertNode(N);
  return SDValue(N, 0);
}

SDValue SelectionDAG::getFreeze(SDValue V) {
  return getNode(ISD::FREEZE, SDLoc(V), V.getValueType(), V);
}

/// getShiftAmountOperand - Return the specified value casted to
/// the target's desired shift amount type.
SDValue SelectionDAG::getShiftAmountOperand(EVT LHSTy, SDValue Op) {
  EVT OpTy = Op.getValueType();
  EVT ShTy = TLI->getShiftAmountTy(LHSTy, getDataLayout());
  if (OpTy == ShTy || OpTy.isVector()) return Op;

  return getZExtOrTrunc(Op, SDLoc(Op), ShTy);
}

SDValue SelectionDAG::expandVAArg(SDNode *Node) {
  SDLoc dl(Node);
  const TargetLowering &TLI = getTargetLoweringInfo();
  const Value *V = cast<SrcValueSDNode>(Node->getOperand(2))->getValue();
  EVT VT = Node->getValueType(0);
  SDValue Tmp1 = Node->getOperand(0);
  SDValue Tmp2 = Node->getOperand(1);
  const MaybeAlign MA(Node->getConstantOperandVal(3));

  SDValue VAListLoad = getLoad(TLI.getPointerTy(getDataLayout()), dl, Tmp1,
                               Tmp2, MachinePointerInfo(V));
  SDValue VAList = VAListLoad;

  if (MA && *MA > TLI.getMinStackArgumentAlignment()) {
    VAList = getNode(ISD::ADD, dl, VAList.getValueType(), VAList,
                     getConstant(MA->value() - 1, dl, VAList.getValueType()));

    VAList =
        getNode(ISD::AND, dl, VAList.getValueType(), VAList,
                getConstant(-(int64_t)MA->value(), dl, VAList.getValueType()));
  }

  // Increment the pointer, VAList, to the next vaarg.
  Tmp1 = getNode(ISD::ADD, dl, VAList.getValueType(), VAList,
                 getConstant(getDataLayout().getTypeAllocSize(
                                               VT.getTypeForEVT(*getContext())),
                             dl, VAList.getValueType()));
  // Store the incremented VAList to the legalized pointer.
  Tmp1 =
      getStore(VAListLoad.getValue(1), dl, Tmp1, Tmp2, MachinePointerInfo(V));
  // Load the actual argument out of the pointer VAList.
  return getLoad(VT, dl, Tmp1, VAList, MachinePointerInfo());
}
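
// Illustrative sketch of the overaligned va_arg case above: with MA == 16
// (and a smaller minimum stack argument alignment), the list pointer is
// rounded up before the load:
//   VAList = (VAList + 15) & ~15   // add MA-1, then mask with -MA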

SDValue SelectionDAG::expandVACopy(SDNode *Node) {
  SDLoc dl(Node);
  const TargetLowering &TLI = getTargetLoweringInfo();
  // This defaults to loading a pointer from the input and storing it to the
  // output, returning the chain.
  const Value *VD = cast<SrcValueSDNode>(Node->getOperand(3))->getValue();
  const Value *VS = cast<SrcValueSDNode>(Node->getOperand(4))->getValue();
  SDValue Tmp1 =
      getLoad(TLI.getPointerTy(getDataLayout()), dl, Node->getOperand(0),
              Node->getOperand(2), MachinePointerInfo(VS));
  return getStore(Tmp1.getValue(1), dl, Tmp1, Node->getOperand(1),
                  MachinePointerInfo(VD));
}

Align SelectionDAG::getReducedAlign(EVT VT, bool UseABI) {
  const DataLayout &DL = getDataLayout();
  Type *Ty = VT.getTypeForEVT(*getContext());
  Align RedAlign = UseABI ? DL.getABITypeAlign(Ty) : DL.getPrefTypeAlign(Ty);

  if (TLI->isTypeLegal(VT) || !VT.isVector())
    return RedAlign;

  const TargetFrameLowering *TFI = MF->getSubtarget().getFrameLowering();
  const Align StackAlign = TFI->getStackAlign();

  // See if we can choose a smaller ABI alignment in cases where it's an
  // illegal vector type that will get broken down.
  if (RedAlign > StackAlign) {
    EVT IntermediateVT;
    MVT RegisterVT;
    unsigned NumIntermediates;
    TLI->getVectorTypeBreakdown(*getContext(), VT, IntermediateVT,
                                NumIntermediates, RegisterVT);
    Ty = IntermediateVT.getTypeForEVT(*getContext());
    Align RedAlign2 = UseABI ? DL.getABITypeAlign(Ty) : DL.getPrefTypeAlign(Ty);
    if (RedAlign2 < RedAlign)
      RedAlign = RedAlign2;
  }

  return RedAlign;
}

SDValue SelectionDAG::CreateStackTemporary(TypeSize Bytes, Align Alignment) {
  MachineFrameInfo &MFI = MF->getFrameInfo();
  const TargetFrameLowering *TFI = MF->getSubtarget().getFrameLowering();
  int StackID = 0;
  if (Bytes.isScalable())
    StackID = TFI->getStackIDForScalableVectors();
  // The stack id gives an indication of whether the object is scalable or
  // not, so it's safe to pass in the minimum size here.
  int FrameIdx = MFI.CreateStackObject(Bytes.getKnownMinValue(), Alignment,
                                       false, nullptr, StackID);
  return getFrameIndex(FrameIdx, TLI->getFrameIndexTy(getDataLayout()));
}

SDValue SelectionDAG::CreateStackTemporary(EVT VT, unsigned minAlign) {
  Type *Ty = VT.getTypeForEVT(*getContext());
  Align StackAlign =
      std::max(getDataLayout().getPrefTypeAlign(Ty), Align(minAlign));
  return CreateStackTemporary(VT.getStoreSize(), StackAlign);
}

SDValue SelectionDAG::CreateStackTemporary(EVT VT1, EVT VT2) {
  TypeSize VT1Size = VT1.getStoreSize();
  TypeSize VT2Size = VT2.getStoreSize();
  assert(VT1Size.isScalable() == VT2Size.isScalable() &&
         "Don't know how to choose the maximum size when creating a stack "
         "temporary");
  TypeSize Bytes = VT1Size.getKnownMinValue() > VT2Size.getKnownMinValue()
                       ? VT1Size
                       : VT2Size;

  Type *Ty1 = VT1.getTypeForEVT(*getContext());
  Type *Ty2 = VT2.getTypeForEVT(*getContext());
  const DataLayout &DL = getDataLayout();
  Align Align = std::max(DL.getPrefTypeAlign(Ty1), DL.getPrefTypeAlign(Ty2));
  return CreateStackTemporary(Bytes, Align);
}
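
// For example (illustrative): CreateStackTemporary(MVT::i64, MVT::v4i32)
// allocates a 16-byte slot (the larger of the two store sizes), aligned to
// the stricter of the two preferred alignments, so either value can be
// spilled into it.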

SDValue SelectionDAG::FoldSetCC(EVT VT, SDValue N1, SDValue N2,
                                ISD::CondCode Cond, const SDLoc &dl) {
  EVT OpVT = N1.getValueType();

  // These setcc operations always fold.
  switch (Cond) {
  default: break;
  case ISD::SETFALSE:
  case ISD::SETFALSE2: return getBoolConstant(false, dl, VT, OpVT);
  case ISD::SETTRUE:
  case ISD::SETTRUE2: return getBoolConstant(true, dl, VT, OpVT);

  case ISD::SETOEQ:
  case ISD::SETOGT:
  case ISD::SETOGE:
  case ISD::SETOLT:
  case ISD::SETOLE:
  case ISD::SETONE:
  case ISD::SETO:
  case ISD::SETUO:
  case ISD::SETUEQ:
  case ISD::SETUNE:
    assert(!OpVT.isInteger() && "Illegal setcc for integer!");
    break;
  }

  if (OpVT.isInteger()) {
    // For EQ and NE, we can always pick a value for the undef to make the
    // predicate pass or fail, so we can return undef.
    // Matches behavior in llvm::ConstantFoldCompareInstruction.
    // icmp eq/ne X, undef -> undef.
    if ((N1.isUndef() || N2.isUndef()) &&
        (Cond == ISD::SETEQ || Cond == ISD::SETNE))
      return getUNDEF(VT);

    // If both operands are undef, we can return undef for int comparison.
    // icmp undef, undef -> undef.
    if (N1.isUndef() && N2.isUndef())
      return getUNDEF(VT);

    // icmp X, X -> true/false
    // icmp X, undef -> true/false because undef could be X.
    if (N1 == N2)
      return getBoolConstant(ISD::isTrueWhenEqual(Cond), dl, VT, OpVT);
  }

  if (ConstantSDNode *N2C = dyn_cast<ConstantSDNode>(N2)) {
    const APInt &C2 = N2C->getAPIntValue();
    if (ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1)) {
      const APInt &C1 = N1C->getAPIntValue();

      return getBoolConstant(ICmpInst::compare(C1, C2, getICmpCondCode(Cond)),
                             dl, VT, OpVT);
    }
  }

  auto *N1CFP = dyn_cast<ConstantFPSDNode>(N1);
  auto *N2CFP = dyn_cast<ConstantFPSDNode>(N2);

  if (N1CFP && N2CFP) {
    APFloat::cmpResult R = N1CFP->getValueAPF().compare(N2CFP->getValueAPF());
    switch (Cond) {
    default: break;
    case ISD::SETEQ:  if (R==APFloat::cmpUnordered)
                        return getUNDEF(VT);
                      [[fallthrough]];
    case ISD::SETOEQ: return getBoolConstant(R==APFloat::cmpEqual, dl, VT,
                                             OpVT);
    case ISD::SETNE:  if (R==APFloat::cmpUnordered)
                        return getUNDEF(VT);
                      [[fallthrough]];
    case ISD::SETONE: return getBoolConstant(R==APFloat::cmpGreaterThan ||
                                             R==APFloat::cmpLessThan, dl, VT,
                                             OpVT);
    case ISD::SETLT:  if (R==APFloat::cmpUnordered)
                        return getUNDEF(VT);
                      [[fallthrough]];
    case ISD::SETOLT: return getBoolConstant(R==APFloat::cmpLessThan, dl, VT,
                                             OpVT);
    case ISD::SETGT:  if (R==APFloat::cmpUnordered)
                        return getUNDEF(VT);
                      [[fallthrough]];
    case ISD::SETOGT: return getBoolConstant(R==APFloat::cmpGreaterThan, dl,
                                             VT, OpVT);
    case ISD::SETLE:  if (R==APFloat::cmpUnordered)
                        return getUNDEF(VT);
                      [[fallthrough]];
    case ISD::SETOLE: return getBoolConstant(R==APFloat::cmpLessThan ||
                                             R==APFloat::cmpEqual, dl, VT,
                                             OpVT);
    case ISD::SETGE:  if (R==APFloat::cmpUnordered)
                        return getUNDEF(VT);
                      [[fallthrough]];
    case ISD::SETOGE: return getBoolConstant(R==APFloat::cmpGreaterThan ||
                                             R==APFloat::cmpEqual, dl, VT,
                                             OpVT);
    case ISD::SETO:   return getBoolConstant(R!=APFloat::cmpUnordered, dl, VT,
                                             OpVT);
    case ISD::SETUO:  return getBoolConstant(R==APFloat::cmpUnordered, dl, VT,
                                             OpVT);
    case ISD::SETUEQ: return getBoolConstant(R==APFloat::cmpUnordered ||
                                             R==APFloat::cmpEqual, dl, VT,
                                             OpVT);
    case ISD::SETUNE: return getBoolConstant(R!=APFloat::cmpEqual, dl, VT,
                                             OpVT);
    case ISD::SETULT: return getBoolConstant(R==APFloat::cmpUnordered ||
                                             R==APFloat::cmpLessThan, dl, VT,
                                             OpVT);
    case ISD::SETUGT: return getBoolConstant(R==APFloat::cmpGreaterThan ||
                                             R==APFloat::cmpUnordered, dl, VT,
                                             OpVT);
    case ISD::SETULE: return getBoolConstant(R!=APFloat::cmpGreaterThan, dl,
                                             VT, OpVT);
    case ISD::SETUGE: return getBoolConstant(R!=APFloat::cmpLessThan, dl, VT,
                                             OpVT);
    }
  } else if (N1CFP && OpVT.isSimple() && !N2.isUndef()) {
    // Ensure that the constant occurs on the RHS.
    ISD::CondCode SwappedCond = ISD::getSetCCSwappedOperands(Cond);
    if (!TLI->isCondCodeLegal(SwappedCond, OpVT.getSimpleVT()))
      return SDValue();
    return getSetCC(dl, VT, N2, N1, SwappedCond);
  } else if ((N2CFP && N2CFP->getValueAPF().isNaN()) ||
             (OpVT.isFloatingPoint() && (N1.isUndef() || N2.isUndef()))) {
    // If an operand is known to be a nan (or an undef that could be a nan),
    // we can fold it.
    // Choosing NaN for the undef will always make unordered comparisons
    // succeed and ordered comparisons fail.
    // Matches behavior in llvm::ConstantFoldCompareInstruction.
    switch (ISD::getUnorderedFlavor(Cond)) {
    default:
      llvm_unreachable("Unknown flavor!");
    case 0: // Known false.
      return getBoolConstant(false, dl, VT, OpVT);
    case 1: // Known true.
      return getBoolConstant(true, dl, VT, OpVT);
    case 2: // Undefined.
      return getUNDEF(VT);
    }
  }

  // Could not fold it.
  return SDValue();
}
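
// Illustrative folds performed above (no target hooks involved):
//   FoldSetCC(VT, X, X, ISD::SETEQ, dl)    -> true for integer X
//   FoldSetCC(VT, C1, C2, ISD::SETULT, dl) -> ICmpInst::compare(C1, C2, ult)
//   FoldSetCC(VT, X, NaN, ISD::SETO, dl)   -> false (ordered test with NaN)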

/// SignBitIsZero - Return true if the sign bit of Op is known to be zero.  We
/// use this predicate to simplify operations downstream.
bool SelectionDAG::SignBitIsZero(SDValue Op, unsigned Depth) const {
  unsigned BitWidth = Op.getScalarValueSizeInBits();
  return MaskedValueIsZero(Op, APInt::getSignMask(BitWidth), Depth);
}

/// MaskedValueIsZero - Return true if 'V & Mask' is known to be zero.  We use
/// this predicate to simplify operations downstream.  Mask is known to be zero
/// for bits that V cannot have.
bool SelectionDAG::MaskedValueIsZero(SDValue V, const APInt &Mask,
                                     unsigned Depth) const {
  return Mask.isSubsetOf(computeKnownBits(V, Depth).Zero);
}

/// MaskedValueIsZero - Return true if 'V & Mask' is known to be zero in
/// DemandedElts.  We use this predicate to simplify operations downstream.
/// Mask is known to be zero for bits that V cannot have.
bool SelectionDAG::MaskedValueIsZero(SDValue V, const APInt &Mask,
                                     const APInt &DemandedElts,
                                     unsigned Depth) const {
  return Mask.isSubsetOf(computeKnownBits(V, DemandedElts, Depth).Zero);
}

/// MaskedVectorIsZero - Return true if 'Op' is known to be zero in
/// DemandedElts.  We use this predicate to simplify operations downstream.
bool SelectionDAG::MaskedVectorIsZero(SDValue V, const APInt &DemandedElts,
                                      unsigned Depth /* = 0 */) const {
  return computeKnownBits(V, DemandedElts, Depth).isZero();
}

/// MaskedValueIsAllOnes - Return true if '(Op & Mask) == Mask'.
bool SelectionDAG::MaskedValueIsAllOnes(SDValue V, const APInt &Mask,
                                        unsigned Depth) const {
  return Mask.isSubsetOf(computeKnownBits(V, Depth).One);
}
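
// For example (illustrative): if V = (and X, 0xFF) with X of type i32, then
// computeKnownBits(V).Zero covers bits 8..31, so
//   MaskedValueIsZero(V, APInt::getHighBitsSet(32, 24)) == true
// and SignBitIsZero(V) == true as well.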

APInt SelectionDAG::computeVectorKnownZeroElements(SDValue Op,
                                                   const APInt &DemandedElts,
                                                   unsigned Depth) const {
  EVT VT = Op.getValueType();
  assert(VT.isVector() && !VT.isScalableVector() && "Only for fixed vectors!");

  unsigned NumElts = VT.getVectorNumElements();
  assert(DemandedElts.getBitWidth() == NumElts && "Unexpected demanded mask.");

  APInt KnownZeroElements = APInt::getNullValue(NumElts);
  for (unsigned EltIdx = 0; EltIdx != NumElts; ++EltIdx) {
    if (!DemandedElts[EltIdx])
      continue; // Don't query elements that are not demanded.
    APInt Mask = APInt::getOneBitSet(NumElts, EltIdx);
    if (MaskedVectorIsZero(Op, Mask, Depth))
      KnownZeroElements.setBit(EltIdx);
  }
  return KnownZeroElements;
}
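
// Illustrative example: for a v4i32 (build_vector 0, X, 0, Y) with all four
// lanes demanded, the result has bits 0 and 2 set, i.e. exactly the lanes
// whose value is known to be the zero constant.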

/// isSplatValue - Return true if the vector V has the same value
/// across all DemandedElts. For scalable vectors, we don't know the
/// number of lanes at compile time.  Instead, we use a 1 bit APInt
/// to represent a conservative value for all lanes; that is, that
/// one bit value is implicitly splatted across all lanes.
bool SelectionDAG::isSplatValue(SDValue V, const APInt &DemandedElts,
                                APInt &UndefElts, unsigned Depth) const {
  unsigned Opcode = V.getOpcode();
  EVT VT = V.getValueType();
  assert(VT.isVector() && "Vector type expected");
  assert((!VT.isScalableVector() || DemandedElts.getBitWidth() == 1) &&
         "scalable demanded bits are ignored");

  if (!DemandedElts)
    return false; // No demanded elts, better to assume we don't know anything.

  if (Depth >= MaxRecursionDepth)
    return false; // Limit search depth.

  // Deal with some common cases here that work for both fixed and scalable
  // vector types.
  switch (Opcode) {
  case ISD::SPLAT_VECTOR:
    UndefElts = V.getOperand(0).isUndef()
                    ? APInt::getAllOnes(DemandedElts.getBitWidth())
                    : APInt(DemandedElts.getBitWidth(), 0);
    return true;
  case ISD::ADD:
  case ISD::SUB:
  case ISD::AND:
  case ISD::XOR:
  case ISD::OR: {
    APInt UndefLHS, UndefRHS;
    SDValue LHS = V.getOperand(0);
    SDValue RHS = V.getOperand(1);
    if (isSplatValue(LHS, DemandedElts, UndefLHS, Depth + 1) &&
        isSplatValue(RHS, DemandedElts, UndefRHS, Depth + 1)) {
      UndefElts = UndefLHS | UndefRHS;
      return true;
    }
    return false;
  }
  case ISD::ABS:
  case ISD::TRUNCATE:
  case ISD::SIGN_EXTEND:
  case ISD::ZERO_EXTEND:
    return isSplatValue(V.getOperand(0), DemandedElts, UndefElts, Depth + 1);
  default:
    if (Opcode >= ISD::BUILTIN_OP_END || Opcode == ISD::INTRINSIC_WO_CHAIN ||
        Opcode == ISD::INTRINSIC_W_CHAIN || Opcode == ISD::INTRINSIC_VOID)
      return TLI->isSplatValueForTargetNode(V, DemandedElts, UndefElts, *this,
                                            Depth);
    break;
  }

  // We don't support other cases than those above for scalable vectors at
  // the moment.
  if (VT.isScalableVector())
    return false;

  unsigned NumElts = VT.getVectorNumElements();
  assert(NumElts == DemandedElts.getBitWidth() && "Vector size mismatch");
  UndefElts = APInt::getZero(NumElts);

2645   switch (Opcode) {
2646   case ISD::BUILD_VECTOR: {
2647     SDValue Scl;
2648     for (unsigned i = 0; i != NumElts; ++i) {
2649       SDValue Op = V.getOperand(i);
2650       if (Op.isUndef()) {
2651         UndefElts.setBit(i);
2652         continue;
2653       }
2654       if (!DemandedElts[i])
2655         continue;
2656       if (Scl && Scl != Op)
2657         return false;
2658       Scl = Op;
2659     }
2660     return true;
2661   }
2662   case ISD::VECTOR_SHUFFLE: {
2663     // Check if this is a shuffle node doing a splat or a shuffle of a splat.
2664     APInt DemandedLHS = APInt::getNullValue(NumElts);
2665     APInt DemandedRHS = APInt::getNullValue(NumElts);
2666     ArrayRef<int> Mask = cast<ShuffleVectorSDNode>(V)->getMask();
2667     for (int i = 0; i != (int)NumElts; ++i) {
2668       int M = Mask[i];
2669       if (M < 0) {
2670         UndefElts.setBit(i);
2671         continue;
2672       }
2673       if (!DemandedElts[i])
2674         continue;
2675       if (M < (int)NumElts)
2676         DemandedLHS.setBit(M);
2677       else
2678         DemandedRHS.setBit(M - NumElts);
2679     }
2680 
2681     // If we aren't demanding either op, assume there's no splat.
2682     // If we are demanding both ops, assume there's no splat.
2683     if ((DemandedLHS.isZero() && DemandedRHS.isZero()) ||
2684         (!DemandedLHS.isZero() && !DemandedRHS.isZero()))
2685       return false;
2686 
    // See if the demanded elts of the source op are a splat, or we only
    // demand one element, which should always be a splat.
    // TODO: Handle source op splats with undefs.
    auto CheckSplatSrc = [&](SDValue Src, const APInt &SrcElts) {
      APInt SrcUndefs;
      return (SrcElts.countPopulation() == 1) ||
             (isSplatValue(Src, SrcElts, SrcUndefs, Depth + 1) &&
              (SrcElts & SrcUndefs).isZero());
    };
    if (!DemandedLHS.isZero())
      return CheckSplatSrc(V.getOperand(0), DemandedLHS);
    return CheckSplatSrc(V.getOperand(1), DemandedRHS);
  }
  case ISD::EXTRACT_SUBVECTOR: {
    // Offset the demanded elts by the subvector index.
    SDValue Src = V.getOperand(0);
    // We don't support scalable vectors at the moment.
    if (Src.getValueType().isScalableVector())
      return false;
    uint64_t Idx = V.getConstantOperandVal(1);
    unsigned NumSrcElts = Src.getValueType().getVectorNumElements();
    APInt UndefSrcElts;
    APInt DemandedSrcElts = DemandedElts.zext(NumSrcElts).shl(Idx);
    if (isSplatValue(Src, DemandedSrcElts, UndefSrcElts, Depth + 1)) {
      UndefElts = UndefSrcElts.extractBits(NumElts, Idx);
      return true;
    }
    break;
  }
  case ISD::ANY_EXTEND_VECTOR_INREG:
  case ISD::SIGN_EXTEND_VECTOR_INREG:
  case ISD::ZERO_EXTEND_VECTOR_INREG: {
    // Widen the demanded elts by the src element count.
    SDValue Src = V.getOperand(0);
    // We don't support scalable vectors at the moment.
    if (Src.getValueType().isScalableVector())
      return false;
    unsigned NumSrcElts = Src.getValueType().getVectorNumElements();
    APInt UndefSrcElts;
    APInt DemandedSrcElts = DemandedElts.zext(NumSrcElts);
    if (isSplatValue(Src, DemandedSrcElts, UndefSrcElts, Depth + 1)) {
      UndefElts = UndefSrcElts.trunc(NumElts);
      return true;
    }
    break;
  }
  case ISD::BITCAST: {
    SDValue Src = V.getOperand(0);
    EVT SrcVT = Src.getValueType();
    unsigned SrcBitWidth = SrcVT.getScalarSizeInBits();
    unsigned BitWidth = VT.getScalarSizeInBits();

    // Ignore bitcasts from unsupported types.
    // TODO: Add fp support?
    if (!SrcVT.isVector() || !SrcVT.isInteger() || !VT.isInteger())
      break;

    // Bitcast 'small element' vector to 'large element' vector.
    if ((BitWidth % SrcBitWidth) == 0) {
      // See if each sub element is a splat.
      unsigned Scale = BitWidth / SrcBitWidth;
      unsigned NumSrcElts = SrcVT.getVectorNumElements();
      APInt ScaledDemandedElts =
          APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
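      // For example, for a v2i64 <- v4i32 bitcast, Scale == 2: iteration
      // I == 0 checks that the even i32 sub-elements form a splat and
      // I == 1 checks the odd ones.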
      for (unsigned I = 0; I != Scale; ++I) {
        APInt SubUndefElts;
        APInt SubDemandedElt = APInt::getOneBitSet(Scale, I);
        APInt SubDemandedElts = APInt::getSplat(NumSrcElts, SubDemandedElt);
        SubDemandedElts &= ScaledDemandedElts;
        if (!isSplatValue(Src, SubDemandedElts, SubUndefElts, Depth + 1))
          return false;
        // TODO: Add support for merging sub undef elements.
        if (!SubUndefElts.isZero())
          return false;
      }
      return true;
    }
    break;
  }
  }

  return false;
}

/// Helper wrapper for the main isSplatValue function.
bool SelectionDAG::isSplatValue(SDValue V, bool AllowUndefs) const {
  EVT VT = V.getValueType();
  assert(VT.isVector() && "Vector type expected");

  APInt UndefElts;
  // Since the number of lanes in a scalable vector is unknown at compile time,
  // we track one bit which is implicitly broadcast to all lanes.  This means
  // that all lanes in a scalable vector are considered demanded.
  APInt DemandedElts
    = APInt::getAllOnes(VT.isScalableVector() ? 1 : VT.getVectorNumElements());
  return isSplatValue(V, DemandedElts, UndefElts) &&
         (AllowUndefs || !UndefElts);
}

SDValue SelectionDAG::getSplatSourceVector(SDValue V, int &SplatIdx) {
  V = peekThroughExtractSubvectors(V);

  EVT VT = V.getValueType();
  unsigned Opcode = V.getOpcode();
  switch (Opcode) {
  default: {
    APInt UndefElts;
    // Since the number of lanes in a scalable vector is unknown at compile
    // time, we track one bit which is implicitly broadcast to all lanes.
    // This means that all lanes in a scalable vector are considered demanded.
    APInt DemandedElts
      = APInt::getAllOnes(VT.isScalableVector() ? 1 : VT.getVectorNumElements());

    if (isSplatValue(V, DemandedElts, UndefElts)) {
      if (VT.isScalableVector()) {
        // DemandedElts and UndefElts are ignored for scalable vectors, since
        // the only supported cases are SPLAT_VECTOR nodes.
        SplatIdx = 0;
      } else {
        // Handle case where all demanded elements are UNDEF.
        if (DemandedElts.isSubsetOf(UndefElts)) {
          SplatIdx = 0;
          return getUNDEF(VT);
        }
        SplatIdx = (UndefElts & DemandedElts).countTrailingOnes();
      }
      return V;
    }
    break;
  }
  case ISD::SPLAT_VECTOR:
    SplatIdx = 0;
    return V;
  case ISD::VECTOR_SHUFFLE: {
    assert(!VT.isScalableVector());
    // Check if this is a shuffle node doing a splat.
    // TODO - remove this and rely purely on SelectionDAG::isSplatValue,
    // getTargetVShiftNode currently struggles without the splat source.
    auto *SVN = cast<ShuffleVectorSDNode>(V);
    if (!SVN->isSplat())
      break;
    int Idx = SVN->getSplatIndex();
    int NumElts = V.getValueType().getVectorNumElements();
    SplatIdx = Idx % NumElts;
    return V.getOperand(Idx / NumElts);
  }
  }

  return SDValue();
}

SDValue SelectionDAG::getSplatValue(SDValue V, bool LegalTypes) {
  int SplatIdx;
  if (SDValue SrcVector = getSplatSourceVector(V, SplatIdx)) {
    EVT SVT = SrcVector.getValueType().getScalarType();
    EVT LegalSVT = SVT;
    if (LegalTypes && !TLI->isTypeLegal(SVT)) {
      if (!SVT.isInteger())
        return SDValue();
      LegalSVT = TLI->getTypeToTransformTo(*getContext(), LegalSVT);
      if (LegalSVT.bitsLT(SVT))
        return SDValue();
    }
    return getNode(ISD::EXTRACT_VECTOR_ELT, SDLoc(V), LegalSVT, SrcVector,
                   getVectorIdxConstant(SplatIdx, SDLoc(V)));
  }
  return SDValue();
}

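/// If the shift amount of \p V (an SHL, SRL or SRA node) is a constant, or a
/// splat constant across the demanded elements, that is strictly smaller than
/// the bitwidth, return that amount, otherwise return null.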
const APInt *
SelectionDAG::getValidShiftAmountConstant(SDValue V,
                                          const APInt &DemandedElts) const {
  assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
          V.getOpcode() == ISD::SRA) &&
         "Unknown shift node");
  unsigned BitWidth = V.getScalarValueSizeInBits();
  if (ConstantSDNode *SA = isConstOrConstSplat(V.getOperand(1), DemandedElts)) {
    // Shifting more than the bitwidth is not valid.
    const APInt &ShAmt = SA->getAPIntValue();
    if (ShAmt.ult(BitWidth))
      return &ShAmt;
  }
  return nullptr;
}

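/// Like getValidShiftAmountConstant, but also handles non-uniform vector
/// shift amounts: if every demanded element is a constant in-range shift
/// amount, return the smallest such amount, otherwise return null.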
const APInt *SelectionDAG::getValidMinimumShiftAmountConstant(
    SDValue V, const APInt &DemandedElts) const {
  assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
          V.getOpcode() == ISD::SRA) &&
         "Unknown shift node");
  if (const APInt *ValidAmt = getValidShiftAmountConstant(V, DemandedElts))
    return ValidAmt;
  unsigned BitWidth = V.getScalarValueSizeInBits();
  auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1));
  if (!BV)
    return nullptr;
  const APInt *MinShAmt = nullptr;
  for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
    if (!DemandedElts[i])
      continue;
    auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
    if (!SA)
      return nullptr;
    // Shifting more than the bitwidth is not valid.
    const APInt &ShAmt = SA->getAPIntValue();
    if (ShAmt.uge(BitWidth))
      return nullptr;
    if (MinShAmt && MinShAmt->ule(ShAmt))
      continue;
    MinShAmt = &ShAmt;
  }
  return MinShAmt;
}

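/// Like getValidMinimumShiftAmountConstant, but returning the largest valid
/// per-element shift amount instead of the smallest.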
const APInt *SelectionDAG::getValidMaximumShiftAmountConstant(
    SDValue V, const APInt &DemandedElts) const {
  assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
          V.getOpcode() == ISD::SRA) &&
         "Unknown shift node");
  if (const APInt *ValidAmt = getValidShiftAmountConstant(V, DemandedElts))
    return ValidAmt;
  unsigned BitWidth = V.getScalarValueSizeInBits();
  auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1));
  if (!BV)
    return nullptr;
  const APInt *MaxShAmt = nullptr;
  for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
    if (!DemandedElts[i])
      continue;
    auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
    if (!SA)
      return nullptr;
    // Shifting more than the bitwidth is not valid.
    const APInt &ShAmt = SA->getAPIntValue();
    if (ShAmt.uge(BitWidth))
      return nullptr;
    if (MaxShAmt && MaxShAmt->uge(ShAmt))
      continue;
    MaxShAmt = &ShAmt;
  }
  return MaxShAmt;
}

/// Determine which bits of Op are known to be either zero or one and return
/// them in Known. For vectors, the known bits are those that are shared by
/// every vector element.
KnownBits SelectionDAG::computeKnownBits(SDValue Op, unsigned Depth) const {
  EVT VT = Op.getValueType();

  // Since the number of lanes in a scalable vector is unknown at compile time,
  // we track one bit which is implicitly broadcast to all lanes.  This means
  // that all lanes in a scalable vector are considered demanded.
  APInt DemandedElts = VT.isFixedLengthVector()
                           ? APInt::getAllOnes(VT.getVectorNumElements())
                           : APInt(1, 1);
  return computeKnownBits(Op, DemandedElts, Depth);
}

/// Determine which bits of Op are known to be either zero or one and return
/// them in Known. The DemandedElts argument allows us to only collect the
/// known bits that are shared by the requested vector elements.
KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
                                         unsigned Depth) const {
  unsigned BitWidth = Op.getScalarValueSizeInBits();

  KnownBits Known(BitWidth);   // Don't know anything.

  if (auto *C = dyn_cast<ConstantSDNode>(Op)) {
    // We know all of the bits for a constant!
    return KnownBits::makeConstant(C->getAPIntValue());
  }
  if (auto *C = dyn_cast<ConstantFPSDNode>(Op)) {
    // We know all of the bits for a constant fp!
    return KnownBits::makeConstant(C->getValueAPF().bitcastToAPInt());
  }

  if (Depth >= MaxRecursionDepth)
    return Known;  // Limit search depth.

  KnownBits Known2;
  unsigned NumElts = DemandedElts.getBitWidth();
  assert((!Op.getValueType().isFixedLengthVector() ||
          NumElts == Op.getValueType().getVectorNumElements()) &&
         "Unexpected vector size");

  if (!DemandedElts)
    return Known;  // No demanded elts, better to assume we don't know anything.

  unsigned Opcode = Op.getOpcode();
  switch (Opcode) {
  case ISD::MERGE_VALUES:
    return computeKnownBits(Op.getOperand(Op.getResNo()), DemandedElts,
                            Depth + 1);
  case ISD::SPLAT_VECTOR: {
    SDValue SrcOp = Op.getOperand(0);
    assert(SrcOp.getValueSizeInBits() >= BitWidth &&
           "Expected SPLAT_VECTOR implicit truncation");
    // Implicitly truncate the bits to match the official semantics of
    // SPLAT_VECTOR.
    Known = computeKnownBits(SrcOp, Depth + 1).trunc(BitWidth);
    break;
  }
  case ISD::BUILD_VECTOR:
    assert(!Op.getValueType().isScalableVector());
    // Collect the known bits that are shared by every demanded vector element.
    Known.Zero.setAllBits(); Known.One.setAllBits();
    for (unsigned i = 0, e = Op.getNumOperands(); i != e; ++i) {
      if (!DemandedElts[i])
        continue;

      SDValue SrcOp = Op.getOperand(i);
      Known2 = computeKnownBits(SrcOp, Depth + 1);

      // BUILD_VECTOR can implicitly truncate sources; we must handle this.
      if (SrcOp.getValueSizeInBits() != BitWidth) {
        assert(SrcOp.getValueSizeInBits() > BitWidth &&
               "Expected BUILD_VECTOR implicit truncation");
        Known2 = Known2.trunc(BitWidth);
      }

      // Known bits are the values that are shared by every demanded element.
      Known = KnownBits::commonBits(Known, Known2);

      // If we don't know any bits, early out.
      if (Known.isUnknown())
        break;
    }
    break;
  case ISD::VECTOR_SHUFFLE: {
    assert(!Op.getValueType().isScalableVector());
    // Collect the known bits that are shared by every vector element
    // referenced by the shuffle.
    APInt DemandedLHS, DemandedRHS;
    const ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Op);
    assert(NumElts == SVN->getMask().size() && "Unexpected vector size");
    if (!getShuffleDemandedElts(NumElts, SVN->getMask(), DemandedElts,
                                DemandedLHS, DemandedRHS))
      break;

    // Known bits are the values that are shared by every demanded element.
    Known.Zero.setAllBits(); Known.One.setAllBits();
    if (!!DemandedLHS) {
      SDValue LHS = Op.getOperand(0);
      Known2 = computeKnownBits(LHS, DemandedLHS, Depth + 1);
      Known = KnownBits::commonBits(Known, Known2);
    }
    // If we don't know any bits, early out.
    if (Known.isUnknown())
      break;
    if (!!DemandedRHS) {
      SDValue RHS = Op.getOperand(1);
      Known2 = computeKnownBits(RHS, DemandedRHS, Depth + 1);
      Known = KnownBits::commonBits(Known, Known2);
    }
    break;
  }
  case ISD::CONCAT_VECTORS: {
    if (Op.getValueType().isScalableVector())
      break;
    // Split DemandedElts and test each of the demanded subvectors.
    Known.Zero.setAllBits(); Known.One.setAllBits();
    EVT SubVectorVT = Op.getOperand(0).getValueType();
    unsigned NumSubVectorElts = SubVectorVT.getVectorNumElements();
    unsigned NumSubVectors = Op.getNumOperands();
    for (unsigned i = 0; i != NumSubVectors; ++i) {
      APInt DemandedSub =
          DemandedElts.extractBits(NumSubVectorElts, i * NumSubVectorElts);
      if (!!DemandedSub) {
        SDValue Sub = Op.getOperand(i);
        Known2 = computeKnownBits(Sub, DemandedSub, Depth + 1);
        Known = KnownBits::commonBits(Known, Known2);
      }
      // If we don't know any bits, early out.
      if (Known.isUnknown())
        break;
    }
    break;
  }
  case ISD::INSERT_SUBVECTOR: {
    if (Op.getValueType().isScalableVector())
      break;
    // Demand any elements from the subvector and the remainder from the src
    // it's inserted into.
    SDValue Src = Op.getOperand(0);
    SDValue Sub = Op.getOperand(1);
    uint64_t Idx = Op.getConstantOperandVal(2);
    unsigned NumSubElts = Sub.getValueType().getVectorNumElements();
    APInt DemandedSubElts = DemandedElts.extractBits(NumSubElts, Idx);
    APInt DemandedSrcElts = DemandedElts;
    DemandedSrcElts.insertBits(APInt::getZero(NumSubElts), Idx);

    Known.One.setAllBits();
    Known.Zero.setAllBits();
    if (!!DemandedSubElts) {
      Known = computeKnownBits(Sub, DemandedSubElts, Depth + 1);
      if (Known.isUnknown())
        break; // early-out.
    }
    if (!!DemandedSrcElts) {
      Known2 = computeKnownBits(Src, DemandedSrcElts, Depth + 1);
      Known = KnownBits::commonBits(Known, Known2);
    }
    break;
  }
  case ISD::EXTRACT_SUBVECTOR: {
    // Offset the demanded elts by the subvector index.
    SDValue Src = Op.getOperand(0);
    // Bail until we can represent demanded elements for scalable vectors.
    if (Op.getValueType().isScalableVector() ||
        Src.getValueType().isScalableVector())
      break;
    uint64_t Idx = Op.getConstantOperandVal(1);
    unsigned NumSrcElts = Src.getValueType().getVectorNumElements();
    APInt DemandedSrcElts = DemandedElts.zext(NumSrcElts).shl(Idx);
    Known = computeKnownBits(Src, DemandedSrcElts, Depth + 1);
    break;
  }
  case ISD::SCALAR_TO_VECTOR: {
    if (Op.getValueType().isScalableVector())
      break;
    // We know as much about scalar_to_vector as we know about its source,
    // which becomes the first element of an otherwise unknown vector.
    if (DemandedElts != 1)
      break;

    SDValue N0 = Op.getOperand(0);
    Known = computeKnownBits(N0, Depth + 1);
    if (N0.getValueSizeInBits() != BitWidth)
      Known = Known.trunc(BitWidth);

    break;
  }
  case ISD::BITCAST: {
    if (Op.getValueType().isScalableVector())
      break;

    SDValue N0 = Op.getOperand(0);
    EVT SubVT = N0.getValueType();
    unsigned SubBitWidth = SubVT.getScalarSizeInBits();

    // Ignore bitcasts from unsupported types.
    if (!(SubVT.isInteger() || SubVT.isFloatingPoint()))
      break;

    // Fast handling of 'identity' bitcasts.
    if (BitWidth == SubBitWidth) {
      Known = computeKnownBits(N0, DemandedElts, Depth + 1);
      break;
    }

    bool IsLE = getDataLayout().isLittleEndian();

    // Bitcast 'small element' vector to 'large element' scalar/vector.
    if ((BitWidth % SubBitWidth) == 0) {
      assert(N0.getValueType().isVector() && "Expected bitcast from vector");

      // Collect known bits for the (larger) output by collecting the known
      // bits from each set of sub elements and shift these into place.
      // We need to separately call computeKnownBits for each set of
      // sub elements as the knownbits for each is likely to be different.
      unsigned SubScale = BitWidth / SubBitWidth;
      APInt SubDemandedElts(NumElts * SubScale, 0);
      for (unsigned i = 0; i != NumElts; ++i)
        if (DemandedElts[i])
          SubDemandedElts.setBit(i * SubScale);

      for (unsigned i = 0; i != SubScale; ++i) {
        Known2 = computeKnownBits(N0, SubDemandedElts.shl(i), Depth + 1);
        unsigned Shifts = IsLE ? i : SubScale - 1 - i;
        Known.insertBits(Known2, SubBitWidth * Shifts);
      }
    }

    // Bitcast 'large element' scalar/vector to 'small element' vector.
    if ((SubBitWidth % BitWidth) == 0) {
      assert(Op.getValueType().isVector() && "Expected bitcast to vector");

      // Collect known bits for the (smaller) output by collecting the known
      // bits from the overlapping larger input elements and extracting the
      // sub sections we actually care about.
      unsigned SubScale = SubBitWidth / BitWidth;
      APInt SubDemandedElts =
          APIntOps::ScaleBitMask(DemandedElts, NumElts / SubScale);
      Known2 = computeKnownBits(N0, SubDemandedElts, Depth + 1);

      Known.Zero.setAllBits(); Known.One.setAllBits();
      for (unsigned i = 0; i != NumElts; ++i)
        if (DemandedElts[i]) {
          unsigned Shifts = IsLE ? i : NumElts - 1 - i;
          unsigned Offset = (Shifts % SubScale) * BitWidth;
          Known = KnownBits::commonBits(Known,
                                        Known2.extractBits(BitWidth, Offset));
          // If we don't know any bits, early out.
          if (Known.isUnknown())
            break;
        }
    }
    break;
  }
  case ISD::AND:
    Known = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
    Known2 = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);

    Known &= Known2;
    break;
  case ISD::OR:
    Known = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
    Known2 = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);

    Known |= Known2;
    break;
  case ISD::XOR:
    Known = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
    Known2 = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);

    Known ^= Known2;
    break;
  case ISD::MUL: {
    Known = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
    Known2 = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
    bool SelfMultiply = Op.getOperand(0) == Op.getOperand(1);
    // TODO: SelfMultiply can be poison, but not undef.
    if (SelfMultiply)
      SelfMultiply &= isGuaranteedNotToBeUndefOrPoison(
          Op.getOperand(0), DemandedElts, false, Depth + 1);
    Known = KnownBits::mul(Known, Known2, SelfMultiply);

    // If the multiplication is known not to overflow, the product of a number
    // with itself is non-negative. Only do this if we didn't already compute
    // the opposite value for the sign bit.
    if (Op->getFlags().hasNoSignedWrap() &&
        Op.getOperand(0) == Op.getOperand(1) &&
        !Known.isNegative())
      Known.makeNonNegative();
    break;
  }
  case ISD::MULHU: {
    Known = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
    Known2 = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
    Known = KnownBits::mulhu(Known, Known2);
    break;
  }
  case ISD::MULHS: {
    Known = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
    Known2 = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
    Known = KnownBits::mulhs(Known, Known2);
    break;
  }
  case ISD::UMUL_LOHI: {
    assert((Op.getResNo() == 0 || Op.getResNo() == 1) && "Unknown result");
    Known = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
    Known2 = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
    bool SelfMultiply = Op.getOperand(0) == Op.getOperand(1);
    if (Op.getResNo() == 0)
      Known = KnownBits::mul(Known, Known2, SelfMultiply);
    else
      Known = KnownBits::mulhu(Known, Known2);
    break;
  }
  case ISD::SMUL_LOHI: {
    assert((Op.getResNo() == 0 || Op.getResNo() == 1) && "Unknown result");
    Known = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
    Known2 = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
    bool SelfMultiply = Op.getOperand(0) == Op.getOperand(1);
    if (Op.getResNo() == 0)
      Known = KnownBits::mul(Known, Known2, SelfMultiply);
    else
      Known = KnownBits::mulhs(Known, Known2);
    break;
  }
  case ISD::AVGCEILU: {
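    // avgceilu(X, Y) = (X + Y + 1) >> 1. Compute the sum in BitWidth + 1 bits
    // so it cannot wrap, add a carry-in of one, and take the top BitWidth
    // bits of the result.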
    Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
    Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
    Known = Known.zext(BitWidth + 1);
    Known2 = Known2.zext(BitWidth + 1);
    KnownBits One = KnownBits::makeConstant(APInt(1, 1));
    Known = KnownBits::computeForAddCarry(Known, Known2, One);
    Known = Known.extractBits(BitWidth, 1);
    break;
  }
  case ISD::SELECT:
  case ISD::VSELECT:
    Known = computeKnownBits(Op.getOperand(2), DemandedElts, Depth+1);
    // If we don't know any bits, early out.
    if (Known.isUnknown())
      break;
    Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth+1);

    // Only known if known in both the LHS and RHS.
    Known = KnownBits::commonBits(Known, Known2);
    break;
  case ISD::SELECT_CC:
    Known = computeKnownBits(Op.getOperand(3), DemandedElts, Depth+1);
    // If we don't know any bits, early out.
    if (Known.isUnknown())
      break;
    Known2 = computeKnownBits(Op.getOperand(2), DemandedElts, Depth+1);

    // Only known if known in both the LHS and RHS.
    Known = KnownBits::commonBits(Known, Known2);
    break;
  case ISD::SMULO:
  case ISD::UMULO:
    if (Op.getResNo() != 1)
      break;
    // The boolean result conforms to getBooleanContents.
    // If we know the result of a setcc has the top bits zero, use this info.
    // We know that we have an integer-based boolean since these operations
    // are only available for integers.
    if (TLI->getBooleanContents(Op.getValueType().isVector(), false) ==
            TargetLowering::ZeroOrOneBooleanContent &&
        BitWidth > 1)
      Known.Zero.setBitsFrom(1);
    break;
  case ISD::SETCC:
  case ISD::SETCCCARRY:
  case ISD::STRICT_FSETCC:
  case ISD::STRICT_FSETCCS: {
    unsigned OpNo = Op->isStrictFPOpcode() ? 1 : 0;
    // If we know the result of a setcc has the top bits zero, use this info.
    if (TLI->getBooleanContents(Op.getOperand(OpNo).getValueType()) ==
            TargetLowering::ZeroOrOneBooleanContent &&
        BitWidth > 1)
      Known.Zero.setBitsFrom(1);
    break;
  }
  case ISD::SHL:
    Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
    Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
    Known = KnownBits::shl(Known, Known2);

    // Minimum shift low bits are known zero.
    if (const APInt *ShMinAmt =
            getValidMinimumShiftAmountConstant(Op, DemandedElts))
      Known.Zero.setLowBits(ShMinAmt->getZExtValue());
    break;
  case ISD::SRL:
    Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
    Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
    Known = KnownBits::lshr(Known, Known2);

    // Minimum shift high bits are known zero.
    if (const APInt *ShMinAmt =
            getValidMinimumShiftAmountConstant(Op, DemandedElts))
      Known.Zero.setHighBits(ShMinAmt->getZExtValue());
    break;
  case ISD::SRA:
    Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
    Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
    Known = KnownBits::ashr(Known, Known2);
    // TODO: Add minimum shift high known sign bits.
    break;
  case ISD::FSHL:
  case ISD::FSHR:
    if (ConstantSDNode *C =
            isConstOrConstSplat(Op.getOperand(2), DemandedElts)) {
      unsigned Amt = C->getAPIntValue().urem(BitWidth);

      // For fshl, 0-shift returns the 1st arg.
      // For fshr, 0-shift returns the 2nd arg.
      if (Amt == 0) {
        Known = computeKnownBits(Op.getOperand(Opcode == ISD::FSHL ? 0 : 1),
                                 DemandedElts, Depth + 1);
        break;
      }

      // fshl: (X << (Z % BW)) | (Y >> (BW - (Z % BW)))
      // fshr: (X << (BW - (Z % BW))) | (Y >> (Z % BW))
      Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
      Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
      if (Opcode == ISD::FSHL) {
        Known.One <<= Amt;
        Known.Zero <<= Amt;
        Known2.One.lshrInPlace(BitWidth - Amt);
        Known2.Zero.lshrInPlace(BitWidth - Amt);
      } else {
        Known.One <<= BitWidth - Amt;
        Known.Zero <<= BitWidth - Amt;
        Known2.One.lshrInPlace(Amt);
        Known2.Zero.lshrInPlace(Amt);
      }
      Known.One |= Known2.One;
      Known.Zero |= Known2.Zero;
    }
    break;
  case ISD::SHL_PARTS:
  case ISD::SRA_PARTS:
  case ISD::SRL_PARTS: {
    assert((Op.getResNo() == 0 || Op.getResNo() == 1) && "Unknown result");

    // Collect lo/hi source values and concatenate.
    unsigned LoBits = Op.getOperand(0).getScalarValueSizeInBits();
    unsigned HiBits = Op.getOperand(1).getScalarValueSizeInBits();
    Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
    Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
    Known = Known2.concat(Known);

    // Collect shift amount.
    Known2 = computeKnownBits(Op.getOperand(2), DemandedElts, Depth + 1);

    if (Opcode == ISD::SHL_PARTS)
      Known = KnownBits::shl(Known, Known2);
    else if (Opcode == ISD::SRA_PARTS)
      Known = KnownBits::ashr(Known, Known2);
    else // if (Opcode == ISD::SRL_PARTS)
      Known = KnownBits::lshr(Known, Known2);

    // TODO: Minimum shift low/high bits are known zero.

    if (Op.getResNo() == 0)
      Known = Known.extractBits(LoBits, 0);
    else
      Known = Known.extractBits(HiBits, LoBits);
    break;
  }
  case ISD::SIGN_EXTEND_INREG: {
    Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
    EVT EVT = cast<VTSDNode>(Op.getOperand(1))->getVT();
    Known = Known.sextInReg(EVT.getScalarSizeInBits());
    break;
  }
  case ISD::CTTZ:
  case ISD::CTTZ_ZERO_UNDEF: {
    Known2 = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
    // If we have a known 1, its position is our upper bound.
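    // The result is at most PossibleTZ, which fits in bit_width(PossibleTZ)
    // bits, so all bits above that are known to be zero.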
    unsigned PossibleTZ = Known2.countMaxTrailingZeros();
    unsigned LowBits = llvm::bit_width(PossibleTZ);
    Known.Zero.setBitsFrom(LowBits);
    break;
  }
  case ISD::CTLZ:
  case ISD::CTLZ_ZERO_UNDEF: {
    Known2 = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
    // If we have a known 1, its position is our upper bound.
    unsigned PossibleLZ = Known2.countMaxLeadingZeros();
    unsigned LowBits = llvm::bit_width(PossibleLZ);
    Known.Zero.setBitsFrom(LowBits);
    break;
  }
  case ISD::CTPOP: {
    Known2 = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
    // If we know some of the bits are zero, they can't be one.
    unsigned PossibleOnes = Known2.countMaxPopulation();
    Known.Zero.setBitsFrom(llvm::bit_width(PossibleOnes));
    break;
  }
  case ISD::PARITY: {
    // Parity returns 0 everywhere but the LSB.
    Known.Zero.setBitsFrom(1);
    break;
  }
  case ISD::LOAD: {
    LoadSDNode *LD = cast<LoadSDNode>(Op);
    const Constant *Cst = TLI->getTargetConstantFromLoad(LD);
    if (ISD::isNON_EXTLoad(LD) && Cst) {
      // Determine any common known bits from the loaded constant pool value.
      Type *CstTy = Cst->getType();
      if ((NumElts * BitWidth) == CstTy->getPrimitiveSizeInBits() &&
          !Op.getValueType().isScalableVector()) {
        // If it's a vector splat, then we can (quickly) reuse the scalar path.
        // NOTE: We assume all elements match and none are UNDEF.
        if (CstTy->isVectorTy()) {
          if (const Constant *Splat = Cst->getSplatValue()) {
            Cst = Splat;
            CstTy = Cst->getType();
          }
        }
        // TODO - do we need to handle different bitwidths?
        if (CstTy->isVectorTy() && BitWidth == CstTy->getScalarSizeInBits()) {
          // Iterate across all vector elements finding common known bits.
          Known.One.setAllBits();
          Known.Zero.setAllBits();
          for (unsigned i = 0; i != NumElts; ++i) {
            if (!DemandedElts[i])
              continue;
            if (Constant *Elt = Cst->getAggregateElement(i)) {
              if (auto *CInt = dyn_cast<ConstantInt>(Elt)) {
                const APInt &Value = CInt->getValue();
                Known.One &= Value;
                Known.Zero &= ~Value;
                continue;
              }
              if (auto *CFP = dyn_cast<ConstantFP>(Elt)) {
                APInt Value = CFP->getValueAPF().bitcastToAPInt();
                Known.One &= Value;
                Known.Zero &= ~Value;
                continue;
              }
            }
            Known.One.clearAllBits();
            Known.Zero.clearAllBits();
            break;
          }
        } else if (BitWidth == CstTy->getPrimitiveSizeInBits()) {
          if (auto *CInt = dyn_cast<ConstantInt>(Cst)) {
            Known = KnownBits::makeConstant(CInt->getValue());
          } else if (auto *CFP = dyn_cast<ConstantFP>(Cst)) {
            Known =
                KnownBits::makeConstant(CFP->getValueAPF().bitcastToAPInt());
          }
        }
      }
    } else if (ISD::isZEXTLoad(Op.getNode()) && Op.getResNo() == 0) {
      // If this is a ZEXTLoad and we are looking at the loaded value.
      EVT VT = LD->getMemoryVT();
      unsigned MemBits = VT.getScalarSizeInBits();
      Known.Zero.setBitsFrom(MemBits);
    } else if (const MDNode *Ranges = LD->getRanges()) {
      EVT VT = LD->getValueType(0);

      // TODO: Handle extending loads.
      if (LD->getExtensionType() == ISD::NON_EXTLOAD) {
        if (VT.isVector()) {
          // Handle truncation to the first demanded element.
          // TODO: Figure out which demanded elements are covered.
          if (DemandedElts != 1 || !getDataLayout().isLittleEndian())
            break;

          // Handle the case where a load has a vector type, but scalar memory
          // with an attached range.
          EVT MemVT = LD->getMemoryVT();
          KnownBits KnownFull(MemVT.getSizeInBits());

          computeKnownBitsFromRangeMetadata(*Ranges, KnownFull);
          Known = KnownFull.trunc(BitWidth);
        } else
          computeKnownBitsFromRangeMetadata(*Ranges, Known);
      }
    }
    break;
  }
  case ISD::ZERO_EXTEND_VECTOR_INREG: {
    if (Op.getValueType().isScalableVector())
      break;
    EVT InVT = Op.getOperand(0).getValueType();
    APInt InDemandedElts = DemandedElts.zext(InVT.getVectorNumElements());
    Known = computeKnownBits(Op.getOperand(0), InDemandedElts, Depth + 1);
    Known = Known.zext(BitWidth);
    break;
  }
  case ISD::ZERO_EXTEND: {
    Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
    Known = Known.zext(BitWidth);
    break;
  }
  case ISD::SIGN_EXTEND_VECTOR_INREG: {
    if (Op.getValueType().isScalableVector())
      break;
    EVT InVT = Op.getOperand(0).getValueType();
    APInt InDemandedElts = DemandedElts.zext(InVT.getVectorNumElements());
    Known = computeKnownBits(Op.getOperand(0), InDemandedElts, Depth + 1);
    // If the sign bit is known to be zero or one, then sext will extend
    // it to the top bits, else it will just zext.
    Known = Known.sext(BitWidth);
    break;
  }
  case ISD::SIGN_EXTEND: {
    Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
    // If the sign bit is known to be zero or one, then sext will extend
    // it to the top bits, else it will just zext.
    Known = Known.sext(BitWidth);
    break;
  }
  case ISD::ANY_EXTEND_VECTOR_INREG: {
    if (Op.getValueType().isScalableVector())
      break;
    EVT InVT = Op.getOperand(0).getValueType();
    APInt InDemandedElts = DemandedElts.zext(InVT.getVectorNumElements());
    Known = computeKnownBits(Op.getOperand(0), InDemandedElts, Depth + 1);
    Known = Known.anyext(BitWidth);
    break;
  }
  case ISD::ANY_EXTEND: {
    Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
    Known = Known.anyext(BitWidth);
    break;
  }
  case ISD::TRUNCATE: {
    Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
    Known = Known.trunc(BitWidth);
    break;
  }
  case ISD::AssertZext: {
    EVT VT = cast<VTSDNode>(Op.getOperand(1))->getVT();
    APInt InMask = APInt::getLowBitsSet(BitWidth, VT.getSizeInBits());
    Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
    Known.Zero |= (~InMask);
    Known.One  &= (~Known.Zero);
    break;
  }
  case ISD::AssertAlign: {
    unsigned LogOfAlign = Log2(cast<AssertAlignSDNode>(Op)->getAlign());
    assert(LogOfAlign != 0);

    // TODO: Should use maximum with source.
    // If a node is guaranteed to be aligned, set low zero bits accordingly as
    // well as clearing one bits.
    Known.Zero.setLowBits(LogOfAlign);
    Known.One.clearLowBits(LogOfAlign);
    break;
  }
  case ISD::FGETSIGN:
    // All bits are zero except the low bit.
    Known.Zero.setBitsFrom(1);
    break;
  case ISD::USUBO:
  case ISD::SSUBO:
  case ISD::SUBCARRY:
  case ISD::SSUBO_CARRY:
    if (Op.getResNo() == 1) {
      // If we know the result of a setcc has the top bits zero, use this info.
      if (TLI->getBooleanContents(Op.getOperand(0).getValueType()) ==
              TargetLowering::ZeroOrOneBooleanContent &&
          BitWidth > 1)
        Known.Zero.setBitsFrom(1);
      break;
    }
    [[fallthrough]];
  case ISD::SUB:
  case ISD::SUBC: {
    assert(Op.getResNo() == 0 &&
           "We only compute knownbits for the difference here.");

    // TODO: Compute influence of the carry operand.
    if (Opcode == ISD::SUBCARRY || Opcode == ISD::SSUBO_CARRY)
      break;

    Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
    Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
    Known = KnownBits::computeForAddSub(/* Add */ false, /* NSW */ false,
                                        Known, Known2);
    break;
  }
  case ISD::UADDO:
  case ISD::SADDO:
  case ISD::ADDCARRY:
  case ISD::SADDO_CARRY:
    if (Op.getResNo() == 1) {
      // If we know the result of a setcc has the top bits zero, use this info.
      if (TLI->getBooleanContents(Op.getOperand(0).getValueType()) ==
              TargetLowering::ZeroOrOneBooleanContent &&
          BitWidth > 1)
        Known.Zero.setBitsFrom(1);
      break;
    }
    [[fallthrough]];
  case ISD::ADD:
  case ISD::ADDC:
  case ISD::ADDE: {
    assert(Op.getResNo() == 0 && "We only compute knownbits for the sum here.");

    // With ADDE and ADDCARRY, a carry bit may be added in.
    KnownBits Carry(1);
    if (Opcode == ISD::ADDE)
      // Can't track carry from glue, set carry to unknown.
      Carry.resetAll();
    else if (Opcode == ISD::ADDCARRY || Opcode == ISD::SADDO_CARRY)
      // TODO: Compute known bits for the carry operand. Not sure if it is worth
      // the trouble (how often will we find a known carry bit). And I haven't
      // tested this very much yet, but something like this might work:
      //   Carry = computeKnownBits(Op.getOperand(2), DemandedElts, Depth + 1);
      //   Carry = Carry.zextOrTrunc(1, false);
      Carry.resetAll();
    else
      Carry.setAllZero();

    Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
    Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
    Known = KnownBits::computeForAddCarry(Known, Known2, Carry);
    break;
  }
  case ISD::UDIV: {
    Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
    Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
    Known = KnownBits::udiv(Known, Known2);
    break;
  }
  case ISD::SREM: {
    Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
    Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
    Known = KnownBits::srem(Known, Known2);
    break;
  }
  case ISD::UREM: {
    Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
    Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
    Known = KnownBits::urem(Known, Known2);
    break;
  }
  case ISD::EXTRACT_ELEMENT: {
    Known = computeKnownBits(Op.getOperand(0), Depth+1);
    const unsigned Index = Op.getConstantOperandVal(1);
    const unsigned EltBitWidth = Op.getValueSizeInBits();

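    // EXTRACT_ELEMENT selects the Index'th EltBitWidth-sized chunk of the
    // wider operand, with index 0 at the least significant end, so drop the
    // low Index * EltBitWidth bits and then truncate to EltBitWidth bits.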
    // Remove the low part of the known bits mask.
    Known.Zero = Known.Zero.getHiBits(Known.getBitWidth() - Index * EltBitWidth);
    Known.One = Known.One.getHiBits(Known.getBitWidth() - Index * EltBitWidth);

    // Remove the high part of the known bits mask.
    Known = Known.trunc(EltBitWidth);
    break;
  }
  case ISD::EXTRACT_VECTOR_ELT: {
    SDValue InVec = Op.getOperand(0);
    SDValue EltNo = Op.getOperand(1);
    EVT VecVT = InVec.getValueType();
    // computeKnownBits not yet implemented for scalable vectors.
    if (VecVT.isScalableVector())
      break;
    const unsigned EltBitWidth = VecVT.getScalarSizeInBits();
    const unsigned NumSrcElts = VecVT.getVectorNumElements();

    // If BitWidth > EltBitWidth the value is any-extended, so we do not know
    // anything about the extended bits.
    if (BitWidth > EltBitWidth)
      Known = Known.trunc(EltBitWidth);

    // If we know the element index, just demand that vector element, else for
    // an unknown element index, ignore DemandedElts and demand them all.
    APInt DemandedSrcElts = APInt::getAllOnes(NumSrcElts);
    auto *ConstEltNo = dyn_cast<ConstantSDNode>(EltNo);
    if (ConstEltNo && ConstEltNo->getAPIntValue().ult(NumSrcElts))
      DemandedSrcElts =
          APInt::getOneBitSet(NumSrcElts, ConstEltNo->getZExtValue());

    Known = computeKnownBits(InVec, DemandedSrcElts, Depth + 1);
    if (BitWidth > EltBitWidth)
      Known = Known.anyext(BitWidth);
    break;
  }
  case ISD::INSERT_VECTOR_ELT: {
    if (Op.getValueType().isScalableVector())
      break;

    // If we know the element index, split the demand between the
    // source vector and the inserted element, otherwise assume we need
    // the original demanded vector elements and the value.
    SDValue InVec = Op.getOperand(0);
    SDValue InVal = Op.getOperand(1);
    SDValue EltNo = Op.getOperand(2);
    bool DemandedVal = true;
    APInt DemandedVecElts = DemandedElts;
    auto *CEltNo = dyn_cast<ConstantSDNode>(EltNo);
    if (CEltNo && CEltNo->getAPIntValue().ult(NumElts)) {
      unsigned EltIdx = CEltNo->getZExtValue();
      DemandedVal = !!DemandedElts[EltIdx];
      DemandedVecElts.clearBit(EltIdx);
    }
    Known.One.setAllBits();
    Known.Zero.setAllBits();
    if (DemandedVal) {
      Known2 = computeKnownBits(InVal, Depth + 1);
      Known = KnownBits::commonBits(Known, Known2.zextOrTrunc(BitWidth));
    }
    if (!!DemandedVecElts) {
      Known2 = computeKnownBits(InVec, DemandedVecElts, Depth + 1);
      Known = KnownBits::commonBits(Known, Known2);
    }
    break;
  }
  case ISD::BITREVERSE: {
    Known2 = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
    Known = Known2.reverseBits();
    break;
  }
  case ISD::BSWAP: {
    Known2 = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
    Known = Known2.byteSwap();
    break;
  }
  case ISD::ABS: {
    Known2 = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
    Known = Known2.abs();
    break;
  }
  case ISD::USUBSAT: {
    // The result of usubsat will never be larger than the LHS.
    Known2 = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
    Known.Zero.setHighBits(Known2.countMinLeadingZeros());
    break;
  }
  case ISD::UMIN: {
    Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
    Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
    Known = KnownBits::umin(Known, Known2);
    break;
  }
  case ISD::UMAX: {
    Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
    Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
    Known = KnownBits::umax(Known, Known2);
    break;
  }
  case ISD::SMIN:
  case ISD::SMAX: {
    // If we have a clamp pattern, we know that the number of sign bits will be
    // the minimum of the clamp min/max range.
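    // For example, smax(smin(X, 255), 0) on i32 clamps X to [0, 255], so the
    // upper 24 bits of the result are known to be zero.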
    bool IsMax = (Opcode == ISD::SMAX);
    ConstantSDNode *CstLow = nullptr, *CstHigh = nullptr;
    if ((CstLow = isConstOrConstSplat(Op.getOperand(1), DemandedElts)))
      if (Op.getOperand(0).getOpcode() == (IsMax ? ISD::SMIN : ISD::SMAX))
        CstHigh =
            isConstOrConstSplat(Op.getOperand(0).getOperand(1), DemandedElts);
    if (CstLow && CstHigh) {
      if (!IsMax)
        std::swap(CstLow, CstHigh);

      const APInt &ValueLow = CstLow->getAPIntValue();
      const APInt &ValueHigh = CstHigh->getAPIntValue();
      if (ValueLow.sle(ValueHigh)) {
        unsigned LowSignBits = ValueLow.getNumSignBits();
        unsigned HighSignBits = ValueHigh.getNumSignBits();
        unsigned MinSignBits = std::min(LowSignBits, HighSignBits);
        if (ValueLow.isNegative() && ValueHigh.isNegative()) {
          Known.One.setHighBits(MinSignBits);
          break;
        }
        if (ValueLow.isNonNegative() && ValueHigh.isNonNegative()) {
          Known.Zero.setHighBits(MinSignBits);
          break;
        }
      }
    }

    Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
    Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
    if (IsMax)
      Known = KnownBits::smax(Known, Known2);
    else
      Known = KnownBits::smin(Known, Known2);

    // For SMAX, if CstLow is non-negative we know the result will be
    // non-negative and thus all sign bits are 0.
    // TODO: There's an equivalent of this for smin with negative constant for
    // known ones.
    if (IsMax && CstLow) {
      const APInt &ValueLow = CstLow->getAPIntValue();
      if (ValueLow.isNonNegative()) {
        unsigned SignBits = ComputeNumSignBits(Op.getOperand(0), Depth + 1);
        Known.Zero.setHighBits(std::min(SignBits, ValueLow.getNumSignBits()));
      }
    }

    break;
  }
  case ISD::FP_TO_UINT_SAT: {
    // FP_TO_UINT_SAT produces an unsigned value that fits in the saturating VT.
    EVT VT = cast<VTSDNode>(Op.getOperand(1))->getVT();
    Known.Zero |= APInt::getBitsSetFrom(BitWidth, VT.getScalarSizeInBits());
    break;
  }
  case ISD::ATOMIC_CMP_SWAP_WITH_SUCCESS:
    if (Op.getResNo() == 1) {
      // The boolean result conforms to getBooleanContents.
      // If we know the result of a setcc has the top bits zero, use this info.
      // We know that we have an integer-based boolean since these operations
      // are only available for integers.
      if (TLI->getBooleanContents(Op.getValueType().isVector(), false) ==
              TargetLowering::ZeroOrOneBooleanContent &&
          BitWidth > 1)
        Known.Zero.setBitsFrom(1);
      break;
    }
    [[fallthrough]];
  case ISD::ATOMIC_CMP_SWAP:
  case ISD::ATOMIC_SWAP:
  case ISD::ATOMIC_LOAD_ADD:
  case ISD::ATOMIC_LOAD_SUB:
  case ISD::ATOMIC_LOAD_AND:
  case ISD::ATOMIC_LOAD_CLR:
  case ISD::ATOMIC_LOAD_OR:
  case ISD::ATOMIC_LOAD_XOR:
  case ISD::ATOMIC_LOAD_NAND:
  case ISD::ATOMIC_LOAD_MIN:
  case ISD::ATOMIC_LOAD_MAX:
  case ISD::ATOMIC_LOAD_UMIN:
  case ISD::ATOMIC_LOAD_UMAX:
  case ISD::ATOMIC_LOAD: {
    unsigned MemBits =
        cast<AtomicSDNode>(Op)->getMemoryVT().getScalarSizeInBits();
    // If we are looking at the loaded value.
    if (Op.getResNo() == 0) {
      if (TLI->getExtendForAtomicOps() == ISD::ZERO_EXTEND)
        Known.Zero.setBitsFrom(MemBits);
    }
    break;
  }
  case ISD::FrameIndex:
  case ISD::TargetFrameIndex:
    TLI->computeKnownBitsForFrameIndex(cast<FrameIndexSDNode>(Op)->getIndex(),
                                       Known, getMachineFunction());
    break;

  default:
    if (Opcode < ISD::BUILTIN_OP_END)
      break;
    [[fallthrough]];
  case ISD::INTRINSIC_WO_CHAIN:
  case ISD::INTRINSIC_W_CHAIN:
  case ISD::INTRINSIC_VOID:
    // TODO: Probably okay to remove after audit; here to reduce change size
    // in initial enablement patch for scalable vectors.
    if (Op.getValueType().isScalableVector())
      break;

    // Allow the target to implement this method for its nodes.
    TLI->computeKnownBitsForTargetNode(Op, Known, DemandedElts, *this, Depth);
    break;
  }

  assert(!Known.hasConflict() && "Bits known to be one AND zero?");
  return Known;
}

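/// Estimate whether computing N0 + N1 can overflow (treating the operands as
/// unsigned). OFK_Never is returned only when the known bits of the operands
/// prove that the sum cannot wrap; otherwise OFK_Sometime is returned.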
SelectionDAG::OverflowKind SelectionDAG::computeOverflowKind(SDValue N0,
                                                             SDValue N1) const {
  // X + 0 never overflows.
  if (isNullConstant(N1))
    return OFK_Never;

  KnownBits N1Known = computeKnownBits(N1);
  if (N1Known.Zero.getBoolValue()) {
    KnownBits N0Known = computeKnownBits(N0);

    bool overflow;
    (void)N0Known.getMaxValue().uadd_ov(N1Known.getMaxValue(), overflow);
    if (!overflow)
      return OFK_Never;
  }

  // mulhi + 1 never overflows: the high half of a full unsigned multiply is
  // at most 2^BitWidth - 2.
  if (N0.getOpcode() == ISD::UMUL_LOHI && N0.getResNo() == 1 &&
      (N1Known.getMaxValue() & 0x01) == N1Known.getMaxValue())
    return OFK_Never;

  if (N1.getOpcode() == ISD::UMUL_LOHI && N1.getResNo() == 1) {
    KnownBits N0Known = computeKnownBits(N0);

    if ((N0Known.getMaxValue() & 0x01) == N0Known.getMaxValue())
      return OFK_Never;
  }

  return OFK_Sometime;
}

bool SelectionDAG::isKnownToBeAPowerOfTwo(SDValue Val) const {
  EVT OpVT = Val.getValueType();
  unsigned BitWidth = OpVT.getScalarSizeInBits();

  // Is the constant a known power of 2?
  if (ConstantSDNode *Const = dyn_cast<ConstantSDNode>(Val))
    return Const->getAPIntValue().zextOrTrunc(BitWidth).isPowerOf2();

  // A left-shift of a constant one will have exactly one bit set because
  // shifting the bit off the end is undefined.
  if (Val.getOpcode() == ISD::SHL) {
    auto *C = isConstOrConstSplat(Val.getOperand(0));
    if (C && C->getAPIntValue() == 1)
      return true;
  }

  // Similarly, a logical right-shift of a constant sign-bit will have exactly
  // one bit set.
  if (Val.getOpcode() == ISD::SRL) {
    auto *C = isConstOrConstSplat(Val.getOperand(0));
    if (C && C->getAPIntValue().isSignMask())
      return true;
  }

  // Are all operands of a build vector constant powers of two?
  if (Val.getOpcode() == ISD::BUILD_VECTOR)
    if (llvm::all_of(Val->ops(), [BitWidth](SDValue E) {
          if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(E))
            return C->getAPIntValue().zextOrTrunc(BitWidth).isPowerOf2();
          return false;
        }))
      return true;

  // Is the operand of a splat vector a constant power of two?
  if (Val.getOpcode() == ISD::SPLAT_VECTOR)
    if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val->getOperand(0)))
      if (C->getAPIntValue().zextOrTrunc(BitWidth).isPowerOf2())
        return true;

  // vscale(power-of-two) is a power-of-two for some targets.
  if (Val.getOpcode() == ISD::VSCALE &&
      getTargetLoweringInfo().isVScaleKnownToBeAPowerOfTwo() &&
      isKnownToBeAPowerOfTwo(Val.getOperand(0)))
    return true;

  // More could be done here, though the above checks are enough
  // to handle some common cases.

  // Fall back to computeKnownBits to catch other known cases.
  KnownBits Known = computeKnownBits(Val);
3981   return (Known.countMaxPopulation() == 1) && (Known.countMinPopulation() == 1);
3982 }
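// Illustrative cases for isKnownToBeAPowerOfTwo above (examples assumed, not
// exhaustive): (shl 1, x) and (srl 0x80000000, x) on i32 each keep exactly one
// bit set, as does a BUILD_VECTOR whose demanded constants are all powers of
// two, e.g. <2, 8, 32, 128>. When no structural pattern matches, the KnownBits
// fallback still succeeds whenever at most one bit can possibly be set and at
// least one bit must be set, i.e. exactly one set bit.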
3983 
3984 unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, unsigned Depth) const {
3985   EVT VT = Op.getValueType();
3986 
3987   // Since the number of lanes in a scalable vector is unknown at compile time,
3988   // we track one bit which is implicitly broadcast to all lanes.  This means
3989   // that all lanes in a scalable vector are considered demanded.
3990   APInt DemandedElts = VT.isFixedLengthVector()
3991                            ? APInt::getAllOnes(VT.getVectorNumElements())
3992                            : APInt(1, 1);
3993   return ComputeNumSignBits(Op, DemandedElts, Depth);
3994 }
3995 
3996 unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
3997                                           unsigned Depth) const {
3998   EVT VT = Op.getValueType();
3999   assert((VT.isInteger() || VT.isFloatingPoint()) && "Invalid VT!");
4000   unsigned VTBits = VT.getScalarSizeInBits();
4001   unsigned NumElts = DemandedElts.getBitWidth();
4002   unsigned Tmp, Tmp2;
4003   unsigned FirstAnswer = 1;
4004 
4005   if (auto *C = dyn_cast<ConstantSDNode>(Op)) {
4006     const APInt &Val = C->getAPIntValue();
4007     return Val.getNumSignBits();
4008   }
4009 
4010   if (Depth >= MaxRecursionDepth)
4011     return 1;  // Limit search depth.
4012 
4013   if (!DemandedElts)
4014     return 1;  // No demanded elts, better to assume we don't know anything.
4015 
4016   unsigned Opcode = Op.getOpcode();
4017   switch (Opcode) {
4018   default: break;
4019   case ISD::AssertSext:
4020     Tmp = cast<VTSDNode>(Op.getOperand(1))->getVT().getSizeInBits();
4021     return VTBits-Tmp+1;
4022   case ISD::AssertZext:
4023     Tmp = cast<VTSDNode>(Op.getOperand(1))->getVT().getSizeInBits();
4024     return VTBits-Tmp;
4025   case ISD::MERGE_VALUES:
4026     return ComputeNumSignBits(Op.getOperand(Op.getResNo()), DemandedElts,
4027                               Depth + 1);
4028   case ISD::SPLAT_VECTOR: {
4029     // Check if the sign bits of source go down as far as the truncated value.
4030     unsigned NumSrcBits = Op.getOperand(0).getValueSizeInBits();
4031     unsigned NumSrcSignBits = ComputeNumSignBits(Op.getOperand(0), Depth + 1);
4032     if (NumSrcSignBits > (NumSrcBits - VTBits))
4033       return NumSrcSignBits - (NumSrcBits - VTBits);
4034     break;
4035   }
4036   case ISD::BUILD_VECTOR:
4037     assert(!VT.isScalableVector());
4038     Tmp = VTBits;
4039     for (unsigned i = 0, e = Op.getNumOperands(); (i < e) && (Tmp > 1); ++i) {
4040       if (!DemandedElts[i])
4041         continue;
4042 
4043       SDValue SrcOp = Op.getOperand(i);
4044       Tmp2 = ComputeNumSignBits(SrcOp, Depth + 1);
4045 
4046       // BUILD_VECTOR can implicitly truncate sources, we must handle this.
4047       if (SrcOp.getValueSizeInBits() != VTBits) {
4048         assert(SrcOp.getValueSizeInBits() > VTBits &&
4049                "Expected BUILD_VECTOR implicit truncation");
4050         unsigned ExtraBits = SrcOp.getValueSizeInBits() - VTBits;
4051         Tmp2 = (Tmp2 > ExtraBits ? Tmp2 - ExtraBits : 1);
4052       }
4053       Tmp = std::min(Tmp, Tmp2);
4054     }
4055     return Tmp;
4056 
4057   case ISD::VECTOR_SHUFFLE: {
4058     // Collect the minimum number of sign bits that are shared by every vector
4059     // element referenced by the shuffle.
4060     APInt DemandedLHS, DemandedRHS;
4061     const ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Op);
4062     assert(NumElts == SVN->getMask().size() && "Unexpected vector size");
4063     if (!getShuffleDemandedElts(NumElts, SVN->getMask(), DemandedElts,
4064                                 DemandedLHS, DemandedRHS))
4065       return 1;
4066 
4067     Tmp = std::numeric_limits<unsigned>::max();
4068     if (!!DemandedLHS)
4069       Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedLHS, Depth + 1);
4070     if (!!DemandedRHS) {
4071       Tmp2 = ComputeNumSignBits(Op.getOperand(1), DemandedRHS, Depth + 1);
4072       Tmp = std::min(Tmp, Tmp2);
4073     }
4074     // If we don't know anything, early out and try computeKnownBits fall-back.
4075     if (Tmp == 1)
4076       break;
4077     assert(Tmp <= VTBits && "Failed to determine minimum sign bits");
4078     return Tmp;
4079   }
4080 
4081   case ISD::BITCAST: {
4082     if (VT.isScalableVector())
4083       break;
4084     SDValue N0 = Op.getOperand(0);
4085     EVT SrcVT = N0.getValueType();
4086     unsigned SrcBits = SrcVT.getScalarSizeInBits();
4087 
4088     // Ignore bitcasts from unsupported types.
4089     if (!(SrcVT.isInteger() || SrcVT.isFloatingPoint()))
4090       break;
4091 
4092     // Fast handling of 'identity' bitcasts.
4093     if (VTBits == SrcBits)
4094       return ComputeNumSignBits(N0, DemandedElts, Depth + 1);
4095 
4096     bool IsLE = getDataLayout().isLittleEndian();
4097 
4098     // Bitcast 'large element' scalar/vector to 'small element' vector.
4099     if ((SrcBits % VTBits) == 0) {
4100       assert(VT.isVector() && "Expected bitcast to vector");
4101 
4102       unsigned Scale = SrcBits / VTBits;
4103       APInt SrcDemandedElts =
4104           APIntOps::ScaleBitMask(DemandedElts, NumElts / Scale);
4105 
4106       // Fast case - sign splat can be simply split across the small elements.
4107       Tmp = ComputeNumSignBits(N0, SrcDemandedElts, Depth + 1);
4108       if (Tmp == SrcBits)
4109         return VTBits;
4110 
4111       // Slow case - determine how far the sign extends into each sub-element.
4112       Tmp2 = VTBits;
4113       for (unsigned i = 0; i != NumElts; ++i)
4114         if (DemandedElts[i]) {
4115           unsigned SubOffset = i % Scale;
4116           SubOffset = (IsLE ? ((Scale - 1) - SubOffset) : SubOffset);
4117           SubOffset = SubOffset * VTBits;
4118           if (Tmp <= SubOffset)
4119             return 1;
4120           Tmp2 = std::min(Tmp2, Tmp - SubOffset);
4121         }
4122       return Tmp2;
4123     }
4124     break;
4125   }
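  // Worked example for the ISD::BITCAST case above (illustrative numbers):
  // bitcasting a v2i64 whose elements are full sign splats (Tmp == SrcBits ==
  // 64) to v4i32 makes every narrow lane a sign splat too, so the fast path
  // returns VTBits == 32. With only Tmp == 33 source sign bits, the slow path
  // runs instead: on a little-endian target the high lane of each pair
  // (SubOffset == 0) keeps min(32, 33) == 32 sign bits, while the low lane
  // (SubOffset == 32) keeps only 33 - 32 == 1, and the minimum over the
  // demanded lanes is returned.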
4126 
4127   case ISD::FP_TO_SINT_SAT:
4128     // FP_TO_SINT_SAT produces a signed value that fits in the saturating VT.
4129     Tmp = cast<VTSDNode>(Op.getOperand(1))->getVT().getScalarSizeInBits();
4130     return VTBits - Tmp + 1;
4131   case ISD::SIGN_EXTEND:
4132     Tmp = VTBits - Op.getOperand(0).getScalarValueSizeInBits();
4133     return ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth+1) + Tmp;
4134   case ISD::SIGN_EXTEND_INREG:
4135     // Max of the input and what this extends.
4136     Tmp = cast<VTSDNode>(Op.getOperand(1))->getVT().getScalarSizeInBits();
4137     Tmp = VTBits-Tmp+1;
4138     Tmp2 = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth+1);
4139     return std::max(Tmp, Tmp2);
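  // Worked example for SIGN_EXTEND_INREG above (values assumed): for
  // sign_extend_inreg(X, i8) on i32, the node alone guarantees
  // 32 - 8 + 1 == 25 sign bits; if X is independently known to carry 30 sign
  // bits, the recursive query wins and std::max(25, 30) == 30 is returned.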
4140   case ISD::SIGN_EXTEND_VECTOR_INREG: {
4141     if (VT.isScalableVector())
4142       break;
4143     SDValue Src = Op.getOperand(0);
4144     EVT SrcVT = Src.getValueType();
4145     APInt DemandedSrcElts = DemandedElts.zext(SrcVT.getVectorNumElements());
4146     Tmp = VTBits - SrcVT.getScalarSizeInBits();
4147     return ComputeNumSignBits(Src, DemandedSrcElts, Depth+1) + Tmp;
4148   }
4149   case ISD::SRA:
4150     Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
4151     // SRA X, C -> adds C sign bits.
4152     if (const APInt *ShAmt =
4153             getValidMinimumShiftAmountConstant(Op, DemandedElts))
4154       Tmp = std::min<uint64_t>(Tmp + ShAmt->getZExtValue(), VTBits);
4155     return Tmp;
4156   case ISD::SHL:
4157     if (const APInt *ShAmt =
4158             getValidMaximumShiftAmountConstant(Op, DemandedElts)) {
4159       // shl destroys sign bits, ensure it doesn't shift out all sign bits.
4160       Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
4161       if (ShAmt->ult(Tmp))
4162         return Tmp - ShAmt->getZExtValue();
4163     }
4164     break;
4165   case ISD::AND:
4166   case ISD::OR:
4167   case ISD::XOR:    // NOT is handled here.
4168     // Logical binary ops preserve the number of sign bits at the worst.
4169     Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth+1);
4170     if (Tmp != 1) {
4171       Tmp2 = ComputeNumSignBits(Op.getOperand(1), DemandedElts, Depth+1);
4172       FirstAnswer = std::min(Tmp, Tmp2);
4173       // We computed what we know about the sign bits as our first
4174       // answer. Now proceed to the generic code that uses
4175       // computeKnownBits, and pick whichever answer is better.
4176     }
4177     break;
4178 
4179   case ISD::SELECT:
4180   case ISD::VSELECT:
4181     Tmp = ComputeNumSignBits(Op.getOperand(1), DemandedElts, Depth+1);
4182     if (Tmp == 1) return 1;  // Early out.
4183     Tmp2 = ComputeNumSignBits(Op.getOperand(2), DemandedElts, Depth+1);
4184     return std::min(Tmp, Tmp2);
4185   case ISD::SELECT_CC:
4186     Tmp = ComputeNumSignBits(Op.getOperand(2), DemandedElts, Depth+1);
4187     if (Tmp == 1) return 1;  // Early out.
4188     Tmp2 = ComputeNumSignBits(Op.getOperand(3), DemandedElts, Depth+1);
4189     return std::min(Tmp, Tmp2);
4190 
4191   case ISD::SMIN:
4192   case ISD::SMAX: {
4193     // If we have a clamp pattern, we know that the number of sign bits will be
4194     // the minimum of the clamp min/max range.
4195     bool IsMax = (Opcode == ISD::SMAX);
4196     ConstantSDNode *CstLow = nullptr, *CstHigh = nullptr;
4197     if ((CstLow = isConstOrConstSplat(Op.getOperand(1), DemandedElts)))
4198       if (Op.getOperand(0).getOpcode() == (IsMax ? ISD::SMIN : ISD::SMAX))
4199         CstHigh =
4200             isConstOrConstSplat(Op.getOperand(0).getOperand(1), DemandedElts);
4201     if (CstLow && CstHigh) {
4202       if (!IsMax)
4203         std::swap(CstLow, CstHigh);
4204       if (CstLow->getAPIntValue().sle(CstHigh->getAPIntValue())) {
4205         Tmp = CstLow->getAPIntValue().getNumSignBits();
4206         Tmp2 = CstHigh->getAPIntValue().getNumSignBits();
4207         return std::min(Tmp, Tmp2);
4208       }
4209     }
4210 
4211     // Fallback - just get the minimum number of sign bits of the operands.
4212     Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
4213     if (Tmp == 1)
4214       return 1;  // Early out.
4215     Tmp2 = ComputeNumSignBits(Op.getOperand(1), DemandedElts, Depth + 1);
4216     return std::min(Tmp, Tmp2);
4217   }
4218   case ISD::UMIN:
4219   case ISD::UMAX:
4220     Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
4221     if (Tmp == 1)
4222       return 1;  // Early out.
4223     Tmp2 = ComputeNumSignBits(Op.getOperand(1), DemandedElts, Depth + 1);
4224     return std::min(Tmp, Tmp2);
4225   case ISD::SADDO:
4226   case ISD::UADDO:
4227   case ISD::SADDO_CARRY:
4228   case ISD::ADDCARRY:
4229   case ISD::SSUBO:
4230   case ISD::USUBO:
4231   case ISD::SSUBO_CARRY:
4232   case ISD::SUBCARRY:
4233   case ISD::SMULO:
4234   case ISD::UMULO:
4235     if (Op.getResNo() != 1)
4236       break;
4237     // The boolean result conforms to getBooleanContents.  Fall through.
4238     // If setcc returns 0/-1, all bits are sign bits.
4239     // We know that we have an integer-based boolean since these operations
4240     // are only available for integer.
4241     if (TLI->getBooleanContents(VT.isVector(), false) ==
4242         TargetLowering::ZeroOrNegativeOneBooleanContent)
4243       return VTBits;
4244     break;
4245   case ISD::SETCC:
4246   case ISD::SETCCCARRY:
4247   case ISD::STRICT_FSETCC:
4248   case ISD::STRICT_FSETCCS: {
4249     unsigned OpNo = Op->isStrictFPOpcode() ? 1 : 0;
4250     // If setcc returns 0/-1, all bits are sign bits.
4251     if (TLI->getBooleanContents(Op.getOperand(OpNo).getValueType()) ==
4252         TargetLowering::ZeroOrNegativeOneBooleanContent)
4253       return VTBits;
4254     break;
4255   }
4256   case ISD::ROTL:
4257   case ISD::ROTR:
4258     Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
4259 
4260     // If we're rotating a 0/-1 value, then it stays a 0/-1 value.
4261     if (Tmp == VTBits)
4262       return VTBits;
4263 
4264     if (ConstantSDNode *C =
4265             isConstOrConstSplat(Op.getOperand(1), DemandedElts)) {
4266       unsigned RotAmt = C->getAPIntValue().urem(VTBits);
4267 
4268       // Handle rotate right by N like a rotate left by 32-N.
4269       if (Opcode == ISD::ROTR)
4270         RotAmt = (VTBits - RotAmt) % VTBits;
4271 
4272       // If we aren't rotating out all of the known-in sign bits, return the
4273       // number that are left.  This handles rotl(sext(x), 1) for example.
4274       if (Tmp > (RotAmt + 1)) return (Tmp - RotAmt);
4275     }
4276     break;
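  // Worked example for the ISD::ROTL/ROTR case above (illustrative): for
  // rotl(sext(i8 X to i32), 1) the input carries 25 sign bits and RotAmt == 1,
  // so 25 > RotAmt + 1 and 25 - 1 == 24 sign bits survive. A rotr of the same
  // value by 1 becomes rotl by 31, which would rotate out the entire sign run,
  // so it falls through to the generic known-bits code instead.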
4277   case ISD::ADD:
4278   case ISD::ADDC:
4279     // Add can have at most one carry bit.  Thus we know that the output
4280     // is, at worst, one more bit than the inputs.
4281     Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
4282     if (Tmp == 1) return 1; // Early out.
4283 
4284     // Special case decrementing a value (ADD X, -1):
4285     if (ConstantSDNode *CRHS =
4286             isConstOrConstSplat(Op.getOperand(1), DemandedElts))
4287       if (CRHS->isAllOnes()) {
4288         KnownBits Known =
4289             computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
4290 
4291         // If the input is known to be 0 or 1, the output is 0/-1, which is all
4292         // sign bits set.
4293         if ((Known.Zero | 1).isAllOnes())
4294           return VTBits;
4295 
4296         // If we are subtracting one from a positive number, there is no carry
4297         // out of the result.
4298         if (Known.isNonNegative())
4299           return Tmp;
4300       }
4301 
4302     Tmp2 = ComputeNumSignBits(Op.getOperand(1), DemandedElts, Depth + 1);
4303     if (Tmp2 == 1) return 1; // Early out.
4304     return std::min(Tmp, Tmp2) - 1;
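  // Worked example for the ISD::ADD case above (numbers assumed): adding two
  // i32 values that each carry at least 12 sign bits can at worst spend one of
  // them on the carry, leaving min(12, 12) - 1 == 11. The decrement special
  // case is stronger: for (add X, -1) with X known to be 0 or 1, the result is
  // 0 or -1, i.e. all 32 bits are sign bits.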
4305   case ISD::SUB:
4306     Tmp2 = ComputeNumSignBits(Op.getOperand(1), DemandedElts, Depth + 1);
4307     if (Tmp2 == 1) return 1; // Early out.
4308 
4309     // Handle NEG.
4310     if (ConstantSDNode *CLHS =
4311             isConstOrConstSplat(Op.getOperand(0), DemandedElts))
4312       if (CLHS->isZero()) {
4313         KnownBits Known =
4314             computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
4315         // If the input is known to be 0 or 1, the output is 0/-1, which is all
4316         // sign bits set.
4317         if ((Known.Zero | 1).isAllOnes())
4318           return VTBits;
4319 
4320         // If the input is known to be positive (the sign bit is known clear),
4321         // the output of the NEG has the same number of sign bits as the input.
4322         if (Known.isNonNegative())
4323           return Tmp2;
4324 
4325         // Otherwise, we treat this like a SUB.
4326       }
4327 
4328     // Sub can have at most one carry bit.  Thus we know that the output
4329     // is, at worst, one more bit than the inputs.
4330     Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
4331     if (Tmp == 1) return 1; // Early out.
4332     return std::min(Tmp, Tmp2) - 1;
4333   case ISD::MUL: {
4334     // The output of the Mul can be at most twice the valid bits in the inputs.
4335     unsigned SignBitsOp0 = ComputeNumSignBits(Op.getOperand(0), Depth + 1);
4336     if (SignBitsOp0 == 1)
4337       break;
4338     unsigned SignBitsOp1 = ComputeNumSignBits(Op.getOperand(1), Depth + 1);
4339     if (SignBitsOp1 == 1)
4340       break;
4341     unsigned OutValidBits =
4342         (VTBits - SignBitsOp0 + 1) + (VTBits - SignBitsOp1 + 1);
4343     return OutValidBits > VTBits ? 1 : VTBits - OutValidBits + 1;
4344   }
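  // Worked example for the ISD::MUL case above (numbers assumed): with i32
  // operands carrying 20 and 18 sign bits, the product needs at most
  // (32 - 20 + 1) + (32 - 18 + 1) == 28 valid bits, so 32 - 28 + 1 == 5 sign
  // bits remain; had OutValidBits exceeded 32, only a single sign bit could be
  // claimed.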
4345   case ISD::SREM:
4346     // The sign bit is the LHS's sign bit, except when the result of the
4347     // remainder is zero. The magnitude of the result should be less than or
4348     // equal to the magnitude of the LHS. Therefore, the result should have
4349     // at least as many sign bits as the left hand side.
4350     return ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
4351   case ISD::TRUNCATE: {
4352     // Check if the sign bits of source go down as far as the truncated value.
4353     unsigned NumSrcBits = Op.getOperand(0).getScalarValueSizeInBits();
4354     unsigned NumSrcSignBits = ComputeNumSignBits(Op.getOperand(0), Depth + 1);
4355     if (NumSrcSignBits > (NumSrcBits - VTBits))
4356       return NumSrcSignBits - (NumSrcBits - VTBits);
4357     break;
4358   }
4359   case ISD::EXTRACT_ELEMENT: {
4360     if (VT.isScalableVector())
4361       break;
4362     const int KnownSign = ComputeNumSignBits(Op.getOperand(0), Depth+1);
4363     const int BitWidth = Op.getValueSizeInBits();
4364     const int Items = Op.getOperand(0).getValueSizeInBits() / BitWidth;
4365 
4366     // Get the reverse index (starting from 1); the Op1 value indexes elements
4367     // from the little end, while the sign starts at the big end.
4368     const int rIndex = Items - 1 - Op.getConstantOperandVal(1);
4369 
4370     // If the sign portion ends within our element, the subtraction gives the
4371     // correct result; otherwise it goes negative or exceeds the bit width.
4372     return std::clamp(KnownSign - rIndex * BitWidth, 0, BitWidth);
4373   }
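  // Worked example for the ISD::EXTRACT_ELEMENT case above (illustrative):
  // splitting an i64 with 40 known sign bits into i32 halves gives Items == 2.
  // Extracting element 1 (the big end) yields rIndex == 0 and
  // clamp(40 - 0 * 32, 0, 32) == 32 sign bits; element 0 (the little end)
  // yields rIndex == 1 and clamp(40 - 32, 0, 32) == 8.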
4374   case ISD::INSERT_VECTOR_ELT: {
4375     if (VT.isScalableVector())
4376       break;
4377     // If we know the element index, split the demand between the
4378     // source vector and the inserted element, otherwise assume we need
4379     // the original demanded vector elements and the value.
4380     SDValue InVec = Op.getOperand(0);
4381     SDValue InVal = Op.getOperand(1);
4382     SDValue EltNo = Op.getOperand(2);
4383     bool DemandedVal = true;
4384     APInt DemandedVecElts = DemandedElts;
4385     auto *CEltNo = dyn_cast<ConstantSDNode>(EltNo);
4386     if (CEltNo && CEltNo->getAPIntValue().ult(NumElts)) {
4387       unsigned EltIdx = CEltNo->getZExtValue();
4388       DemandedVal = !!DemandedElts[EltIdx];
4389       DemandedVecElts.clearBit(EltIdx);
4390     }
4391     Tmp = std::numeric_limits<unsigned>::max();
4392     if (DemandedVal) {
4393       // TODO - handle implicit truncation of inserted elements.
4394       if (InVal.getScalarValueSizeInBits() != VTBits)
4395         break;
4396       Tmp2 = ComputeNumSignBits(InVal, Depth + 1);
4397       Tmp = std::min(Tmp, Tmp2);
4398     }
4399     if (!!DemandedVecElts) {
4400       Tmp2 = ComputeNumSignBits(InVec, DemandedVecElts, Depth + 1);
4401       Tmp = std::min(Tmp, Tmp2);
4402     }
4403     assert(Tmp <= VTBits && "Failed to determine minimum sign bits");
4404     return Tmp;
4405   }
4406   case ISD::EXTRACT_VECTOR_ELT: {
4407     assert(!VT.isScalableVector());
4408     SDValue InVec = Op.getOperand(0);
4409     SDValue EltNo = Op.getOperand(1);
4410     EVT VecVT = InVec.getValueType();
4411     // ComputeNumSignBits not yet implemented for scalable vectors.
4412     if (VecVT.isScalableVector())
4413       break;
4414     const unsigned BitWidth = Op.getValueSizeInBits();
4415     const unsigned EltBitWidth = Op.getOperand(0).getScalarValueSizeInBits();
4416     const unsigned NumSrcElts = VecVT.getVectorNumElements();
4417 
4418     // If BitWidth > EltBitWidth the value is any-extended, and we do not know
4419     // anything about sign bits. But if the sizes match we can derive knowledge
4420     // about sign bits from the vector operand.
4421     if (BitWidth != EltBitWidth)
4422       break;
4423 
4424     // If we know the element index, just demand that vector element, else for
4425     // an unknown element index, ignore DemandedElts and demand them all.
4426     APInt DemandedSrcElts = APInt::getAllOnes(NumSrcElts);
4427     auto *ConstEltNo = dyn_cast<ConstantSDNode>(EltNo);
4428     if (ConstEltNo && ConstEltNo->getAPIntValue().ult(NumSrcElts))
4429       DemandedSrcElts =
4430           APInt::getOneBitSet(NumSrcElts, ConstEltNo->getZExtValue());
4431 
4432     return ComputeNumSignBits(InVec, DemandedSrcElts, Depth + 1);
4433   }
4434   case ISD::EXTRACT_SUBVECTOR: {
4435     // Offset the demanded elts by the subvector index.
4436     SDValue Src = Op.getOperand(0);
4437     // Bail until we can represent demanded elements for scalable vectors.
4438     if (Src.getValueType().isScalableVector())
4439       break;
4440     uint64_t Idx = Op.getConstantOperandVal(1);
4441     unsigned NumSrcElts = Src.getValueType().getVectorNumElements();
4442     APInt DemandedSrcElts = DemandedElts.zext(NumSrcElts).shl(Idx);
4443     return ComputeNumSignBits(Src, DemandedSrcElts, Depth + 1);
4444   }
4445   case ISD::CONCAT_VECTORS: {
4446     if (VT.isScalableVector())
4447       break;
4448     // Determine the minimum number of sign bits across all demanded
4449     // elts of the input vectors. Early out if the result is already 1.
4450     Tmp = std::numeric_limits<unsigned>::max();
4451     EVT SubVectorVT = Op.getOperand(0).getValueType();
4452     unsigned NumSubVectorElts = SubVectorVT.getVectorNumElements();
4453     unsigned NumSubVectors = Op.getNumOperands();
4454     for (unsigned i = 0; (i < NumSubVectors) && (Tmp > 1); ++i) {
4455       APInt DemandedSub =
4456           DemandedElts.extractBits(NumSubVectorElts, i * NumSubVectorElts);
4457       if (!DemandedSub)
4458         continue;
4459       Tmp2 = ComputeNumSignBits(Op.getOperand(i), DemandedSub, Depth + 1);
4460       Tmp = std::min(Tmp, Tmp2);
4461     }
4462     assert(Tmp <= VTBits && "Failed to determine minimum sign bits");
4463     return Tmp;
4464   }
4465   case ISD::INSERT_SUBVECTOR: {
4466     if (VT.isScalableVector())
4467       break;
4468     // Demand any elements from the subvector and the remainder from the src it
4469     // is inserted into.
4470     SDValue Src = Op.getOperand(0);
4471     SDValue Sub = Op.getOperand(1);
4472     uint64_t Idx = Op.getConstantOperandVal(2);
4473     unsigned NumSubElts = Sub.getValueType().getVectorNumElements();
4474     APInt DemandedSubElts = DemandedElts.extractBits(NumSubElts, Idx);
4475     APInt DemandedSrcElts = DemandedElts;
4476     DemandedSrcElts.insertBits(APInt::getZero(NumSubElts), Idx);
4477 
4478     Tmp = std::numeric_limits<unsigned>::max();
4479     if (!!DemandedSubElts) {
4480       Tmp = ComputeNumSignBits(Sub, DemandedSubElts, Depth + 1);
4481       if (Tmp == 1)
4482         return 1; // early-out
4483     }
4484     if (!!DemandedSrcElts) {
4485       Tmp2 = ComputeNumSignBits(Src, DemandedSrcElts, Depth + 1);
4486       Tmp = std::min(Tmp, Tmp2);
4487     }
4488     assert(Tmp <= VTBits && "Failed to determine minimum sign bits");
4489     return Tmp;
4490   }
4491   case ISD::LOAD: {
4492     LoadSDNode *LD = cast<LoadSDNode>(Op);
4493     if (const MDNode *Ranges = LD->getRanges()) {
4494       if (DemandedElts != 1)
4495         break;
4496 
4497       ConstantRange CR = getConstantRangeFromMetadata(*Ranges);
4498       if (VTBits > CR.getBitWidth()) {
4499         switch (LD->getExtensionType()) {
4500         case ISD::SEXTLOAD:
4501           CR = CR.signExtend(VTBits);
4502           break;
4503         case ISD::ZEXTLOAD:
4504           CR = CR.zeroExtend(VTBits);
4505           break;
4506         default:
4507           break;
4508         }
4509       }
4510 
4511       if (VTBits != CR.getBitWidth())
4512         break;
4513       return std::min(CR.getSignedMin().getNumSignBits(),
4514                       CR.getSignedMax().getNumSignBits());
4515     }
4516 
4517     break;
4518   }
4519   case ISD::ATOMIC_CMP_SWAP:
4520   case ISD::ATOMIC_CMP_SWAP_WITH_SUCCESS:
4521   case ISD::ATOMIC_SWAP:
4522   case ISD::ATOMIC_LOAD_ADD:
4523   case ISD::ATOMIC_LOAD_SUB:
4524   case ISD::ATOMIC_LOAD_AND:
4525   case ISD::ATOMIC_LOAD_CLR:
4526   case ISD::ATOMIC_LOAD_OR:
4527   case ISD::ATOMIC_LOAD_XOR:
4528   case ISD::ATOMIC_LOAD_NAND:
4529   case ISD::ATOMIC_LOAD_MIN:
4530   case ISD::ATOMIC_LOAD_MAX:
4531   case ISD::ATOMIC_LOAD_UMIN:
4532   case ISD::ATOMIC_LOAD_UMAX:
4533   case ISD::ATOMIC_LOAD: {
4534     Tmp = cast<AtomicSDNode>(Op)->getMemoryVT().getScalarSizeInBits();
4535     // If we are looking at the loaded value.
4536     if (Op.getResNo() == 0) {
4537       if (Tmp == VTBits)
4538         return 1; // early-out
4539       if (TLI->getExtendForAtomicOps() == ISD::SIGN_EXTEND)
4540         return VTBits - Tmp + 1;
4541       if (TLI->getExtendForAtomicOps() == ISD::ZERO_EXTEND)
4542         return VTBits - Tmp;
4543     }
4544     break;
4545   }
4546   }
4547 
4548   // If we are looking at the loaded value of the SDNode.
4549   if (Op.getResNo() == 0) {
4550     // Handle LOADX separately here. The EXTLOAD case will fall through.
4551     if (LoadSDNode *LD = dyn_cast<LoadSDNode>(Op)) {
4552       unsigned ExtType = LD->getExtensionType();
4553       switch (ExtType) {
4554       default: break;
4555       case ISD::SEXTLOAD: // e.g. i16->i32 = '17' bits known.
4556         Tmp = LD->getMemoryVT().getScalarSizeInBits();
4557         return VTBits - Tmp + 1;
4558       case ISD::ZEXTLOAD: // e.g. i16->i32 = '16' bits known.
4559         Tmp = LD->getMemoryVT().getScalarSizeInBits();
4560         return VTBits - Tmp;
4561       case ISD::NON_EXTLOAD:
4562         if (const Constant *Cst = TLI->getTargetConstantFromLoad(LD)) {
4563           // We only need to handle vectors - computeKnownBits should handle
4564           // scalar cases.
4565           Type *CstTy = Cst->getType();
4566           if (CstTy->isVectorTy() && !VT.isScalableVector() &&
4567               (NumElts * VTBits) == CstTy->getPrimitiveSizeInBits() &&
4568               VTBits == CstTy->getScalarSizeInBits()) {
4569             Tmp = VTBits;
4570             for (unsigned i = 0; i != NumElts; ++i) {
4571               if (!DemandedElts[i])
4572                 continue;
4573               if (Constant *Elt = Cst->getAggregateElement(i)) {
4574                 if (auto *CInt = dyn_cast<ConstantInt>(Elt)) {
4575                   const APInt &Value = CInt->getValue();
4576                   Tmp = std::min(Tmp, Value.getNumSignBits());
4577                   continue;
4578                 }
4579                 if (auto *CFP = dyn_cast<ConstantFP>(Elt)) {
4580                   APInt Value = CFP->getValueAPF().bitcastToAPInt();
4581                   Tmp = std::min(Tmp, Value.getNumSignBits());
4582                   continue;
4583                 }
4584               }
4585               // Unknown type. Conservatively assume no bits match sign bit.
4586               return 1;
4587             }
4588             return Tmp;
4589           }
4590         }
4591         break;
4592       }
4593     }
4594   }
4595 
4596   // Allow the target to implement this method for its nodes.
4597   if (Opcode >= ISD::BUILTIN_OP_END ||
4598       Opcode == ISD::INTRINSIC_WO_CHAIN ||
4599       Opcode == ISD::INTRINSIC_W_CHAIN ||
4600       Opcode == ISD::INTRINSIC_VOID) {
4601     // TODO: This can probably be removed once target code is audited.  This
4602     // is here purely to reduce patch size and review complexity.
4603     if (!VT.isScalableVector()) {
4604       unsigned NumBits =
4605         TLI->ComputeNumSignBitsForTargetNode(Op, DemandedElts, *this, Depth);
4606       if (NumBits > 1)
4607         FirstAnswer = std::max(FirstAnswer, NumBits);
4608     }
4609   }
4610 
4611   // Finally, if we can prove that the top bits of the result are 0's or 1's,
4612   // use this information.
4613   KnownBits Known = computeKnownBits(Op, DemandedElts, Depth);
4614   return std::max(FirstAnswer, Known.countMinSignBits());
4615 }
4616 
4617 unsigned SelectionDAG::ComputeMaxSignificantBits(SDValue Op,
4618                                                  unsigned Depth) const {
4619   unsigned SignBits = ComputeNumSignBits(Op, Depth);
4620   return Op.getScalarValueSizeInBits() - SignBits + 1;
4621 }
4622 
4623 unsigned SelectionDAG::ComputeMaxSignificantBits(SDValue Op,
4624                                                  const APInt &DemandedElts,
4625                                                  unsigned Depth) const {
4626   unsigned SignBits = ComputeNumSignBits(Op, DemandedElts, Depth);
4627   return Op.getScalarValueSizeInBits() - SignBits + 1;
4628 }
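// Worked example for ComputeMaxSignificantBits above (illustrative): an i32
// value with 25 known sign bits, such as the constant -100 (0xFFFFFF9C),
// needs only 32 - 25 + 1 == 8 significant bits; the remaining 24 bits are
// redundant copies of the sign that a truncate/sign-extend round trip would
// preserve.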
4629 
4630 bool SelectionDAG::isGuaranteedNotToBeUndefOrPoison(SDValue Op, bool PoisonOnly,
4631                                                     unsigned Depth) const {
4632   // Early out for FREEZE.
4633   if (Op.getOpcode() == ISD::FREEZE)
4634     return true;
4635 
4636   // TODO: Assume we don't know anything for now.
4637   EVT VT = Op.getValueType();
4638   if (VT.isScalableVector())
4639     return false;
4640 
4641   APInt DemandedElts = VT.isVector()
4642                            ? APInt::getAllOnes(VT.getVectorNumElements())
4643                            : APInt(1, 1);
4644   return isGuaranteedNotToBeUndefOrPoison(Op, DemandedElts, PoisonOnly, Depth);
4645 }
4646 
4647 bool SelectionDAG::isGuaranteedNotToBeUndefOrPoison(SDValue Op,
4648                                                     const APInt &DemandedElts,
4649                                                     bool PoisonOnly,
4650                                                     unsigned Depth) const {
4651   unsigned Opcode = Op.getOpcode();
4652 
4653   // Early out for FREEZE.
4654   if (Opcode == ISD::FREEZE)
4655     return true;
4656 
4657   if (Depth >= MaxRecursionDepth)
4658     return false; // Limit search depth.
4659 
4660   if (isIntOrFPConstant(Op))
4661     return true;
4662 
4663   switch (Opcode) {
4664   case ISD::VALUETYPE:
4665   case ISD::FrameIndex:
4666   case ISD::TargetFrameIndex:
4667     return true;
4668 
4669   case ISD::UNDEF:
4670     return PoisonOnly;
4671 
4672   case ISD::BUILD_VECTOR:
4673     // NOTE: BUILD_VECTOR has implicit truncation of wider scalar elements -
4674     // this shouldn't affect the result.
4675     for (unsigned i = 0, e = Op.getNumOperands(); i < e; ++i) {
4676       if (!DemandedElts[i])
4677         continue;
4678       if (!isGuaranteedNotToBeUndefOrPoison(Op.getOperand(i), PoisonOnly,
4679                                             Depth + 1))
4680         return false;
4681     }
4682     return true;
4683 
4684     // TODO: Search for noundef attributes from library functions.
4685 
4686     // TODO: Pointers dereferenced by ISD::LOAD/STORE ops are noundef.
4687 
4688   default:
4689     // Allow the target to implement this method for its nodes.
4690     if (Opcode >= ISD::BUILTIN_OP_END || Opcode == ISD::INTRINSIC_WO_CHAIN ||
4691         Opcode == ISD::INTRINSIC_W_CHAIN || Opcode == ISD::INTRINSIC_VOID)
4692       return TLI->isGuaranteedNotToBeUndefOrPoisonForTargetNode(
4693           Op, DemandedElts, *this, PoisonOnly, Depth);
4694     break;
4695   }
4696 
4697   // If Op can't create undef/poison and none of its operands are undef/poison
4698   // then Op is never undef/poison.
4699   // NOTE: TargetNodes should handle this in themselves in
4700   // isGuaranteedNotToBeUndefOrPoisonForTargetNode.
4701   return !canCreateUndefOrPoison(Op, PoisonOnly, /*ConsiderFlags*/ true,
4702                                  Depth) &&
4703          all_of(Op->ops(), [&](SDValue V) {
4704            return isGuaranteedNotToBeUndefOrPoison(V, PoisonOnly, Depth + 1);
4705          });
4706 }
4707 
4708 bool SelectionDAG::canCreateUndefOrPoison(SDValue Op, bool PoisonOnly,
4709                                           bool ConsiderFlags,
4710                                           unsigned Depth) const {
4711   // TODO: Assume we don't know anything for now.
4712   EVT VT = Op.getValueType();
4713   if (VT.isScalableVector())
4714     return true;
4715 
4716   APInt DemandedElts = VT.isVector()
4717                            ? APInt::getAllOnes(VT.getVectorNumElements())
4718                            : APInt(1, 1);
4719   return canCreateUndefOrPoison(Op, DemandedElts, PoisonOnly, ConsiderFlags,
4720                                 Depth);
4721 }
4722 
4723 bool SelectionDAG::canCreateUndefOrPoison(SDValue Op, const APInt &DemandedElts,
4724                                           bool PoisonOnly, bool ConsiderFlags,
4725                                           unsigned Depth) const {
4726   // TODO: Assume we don't know anything for now.
4727   EVT VT = Op.getValueType();
4728   if (VT.isScalableVector())
4729     return true;
4730 
4731   unsigned Opcode = Op.getOpcode();
4732   switch (Opcode) {
4733   case ISD::AssertSext:
4734   case ISD::AssertZext:
4735   case ISD::FREEZE:
4736   case ISD::INSERT_SUBVECTOR:
4737   case ISD::AND:
4738   case ISD::OR:
4739   case ISD::XOR:
4740   case ISD::ROTL:
4741   case ISD::ROTR:
4742   case ISD::FSHL:
4743   case ISD::FSHR:
4744   case ISD::BSWAP:
4745   case ISD::CTPOP:
4746   case ISD::BITREVERSE:
4747   case ISD::PARITY:
4748   case ISD::SIGN_EXTEND:
4749   case ISD::ZERO_EXTEND:
4750   case ISD::TRUNCATE:
4751   case ISD::SIGN_EXTEND_INREG:
4752   case ISD::SIGN_EXTEND_VECTOR_INREG:
4753   case ISD::ZERO_EXTEND_VECTOR_INREG:
4754   case ISD::BITCAST:
4755   case ISD::BUILD_VECTOR:
4756     return false;
4757 
4758   case ISD::ADD:
4759   case ISD::SUB:
4760   case ISD::MUL:
4761     // Matches hasPoisonGeneratingFlags().
4762     return ConsiderFlags && (Op->getFlags().hasNoSignedWrap() ||
4763                              Op->getFlags().hasNoUnsignedWrap());
4764 
4765   case ISD::SHL:
4766     // If the max shift amount isn't in range, then the shift can create poison.
4767     if (!getValidMaximumShiftAmountConstant(Op, DemandedElts))
4768       return true;
4769 
4770     // Matches hasPoisonGeneratingFlags().
4771     return ConsiderFlags && (Op->getFlags().hasNoSignedWrap() ||
4772                              Op->getFlags().hasNoUnsignedWrap());
4773 
4774   default:
4775     // Allow the target to implement this method for its nodes.
4776     if (Opcode >= ISD::BUILTIN_OP_END || Opcode == ISD::INTRINSIC_WO_CHAIN ||
4777         Opcode == ISD::INTRINSIC_W_CHAIN || Opcode == ISD::INTRINSIC_VOID)
4778       return TLI->canCreateUndefOrPoisonForTargetNode(
4779           Op, DemandedElts, *this, PoisonOnly, ConsiderFlags, Depth);
4780     break;
4781   }
4782 
4783   // Be conservative and return true.
4784   return true;
4785 }
4786 
4787 bool SelectionDAG::isBaseWithConstantOffset(SDValue Op) const {
4788   if ((Op.getOpcode() != ISD::ADD && Op.getOpcode() != ISD::OR) ||
4789       !isa<ConstantSDNode>(Op.getOperand(1)))
4790     return false;
4791 
4792   if (Op.getOpcode() == ISD::OR &&
4793       !MaskedValueIsZero(Op.getOperand(0), Op.getConstantOperandAPInt(1)))
4794     return false;
4795 
4796   return true;
4797 }
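// Worked example for isBaseWithConstantOffset above (illustrative): both
// (add X, 16) and (or X, 3) qualify, the latter only when MaskedValueIsZero
// proves the low two bits of X are clear (e.g. X is a 4-byte aligned
// address), because then the OR behaves exactly like an ADD of the constant
// offset.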
4798 
4799 bool SelectionDAG::isKnownNeverNaN(SDValue Op, bool SNaN, unsigned Depth) const {
4800   // If we're told that NaNs won't happen, assume they won't.
4801   if (getTarget().Options.NoNaNsFPMath || Op->getFlags().hasNoNaNs())
4802     return true;
4803 
4804   if (Depth >= MaxRecursionDepth)
4805     return false; // Limit search depth.
4806 
4807   // If the value is a constant, we can obviously see if it is a NaN or not.
4808   if (const ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(Op)) {
4809     return !C->getValueAPF().isNaN() ||
4810            (SNaN && !C->getValueAPF().isSignaling());
4811   }
4812 
4813   unsigned Opcode = Op.getOpcode();
4814   switch (Opcode) {
4815   case ISD::FADD:
4816   case ISD::FSUB:
4817   case ISD::FMUL:
4818   case ISD::FDIV:
4819   case ISD::FREM:
4820   case ISD::FSIN:
4821   case ISD::FCOS:
4822   case ISD::FMA:
4823   case ISD::FMAD: {
4824     if (SNaN)
4825       return true;
4826     // TODO: Need isKnownNeverInfinity
4827     return false;
4828   }
4829   case ISD::FCANONICALIZE:
4830   case ISD::FEXP:
4831   case ISD::FEXP2:
4832   case ISD::FTRUNC:
4833   case ISD::FFLOOR:
4834   case ISD::FCEIL:
4835   case ISD::FROUND:
4836   case ISD::FROUNDEVEN:
4837   case ISD::FRINT:
4838   case ISD::FNEARBYINT: {
4839     if (SNaN)
4840       return true;
4841     return isKnownNeverNaN(Op.getOperand(0), SNaN, Depth + 1);
4842   }
4843   case ISD::FABS:
4844   case ISD::FNEG:
4845   case ISD::FCOPYSIGN: {
4846     return isKnownNeverNaN(Op.getOperand(0), SNaN, Depth + 1);
4847   }
4848   case ISD::SELECT:
4849     return isKnownNeverNaN(Op.getOperand(1), SNaN, Depth + 1) &&
4850            isKnownNeverNaN(Op.getOperand(2), SNaN, Depth + 1);
4851   case ISD::FP_EXTEND:
4852   case ISD::FP_ROUND: {
4853     if (SNaN)
4854       return true;
4855     return isKnownNeverNaN(Op.getOperand(0), SNaN, Depth + 1);
4856   }
4857   case ISD::SINT_TO_FP:
4858   case ISD::UINT_TO_FP:
4859     return true;
4860   case ISD::FSQRT: // Need to know the operand is positive.
4861   case ISD::FLOG:
4862   case ISD::FLOG2:
4863   case ISD::FLOG10:
4864   case ISD::FPOWI:
4865   case ISD::FPOW: {
4866     if (SNaN)
4867       return true;
4868     // TODO: Refine on operand
4869     return false;
4870   }
4871   case ISD::FMINNUM:
4872   case ISD::FMAXNUM: {
4873     // Only one needs to be known not-NaN, since it will be returned if the
4874     // other ends up being one.
4875     return isKnownNeverNaN(Op.getOperand(0), SNaN, Depth + 1) ||
4876            isKnownNeverNaN(Op.getOperand(1), SNaN, Depth + 1);
4877   }
4878   case ISD::FMINNUM_IEEE:
4879   case ISD::FMAXNUM_IEEE: {
4880     if (SNaN)
4881       return true;
4882     // This can return a NaN if either operand is an sNaN, or if both operands
4883     // are NaN.
4884     return (isKnownNeverNaN(Op.getOperand(0), false, Depth + 1) &&
4885             isKnownNeverSNaN(Op.getOperand(1), Depth + 1)) ||
4886            (isKnownNeverNaN(Op.getOperand(1), false, Depth + 1) &&
4887             isKnownNeverSNaN(Op.getOperand(0), Depth + 1));
4888   }
4889   case ISD::FMINIMUM:
4890   case ISD::FMAXIMUM: {
4891     // TODO: Does this quiet or return the original NaN as-is?
4892     return isKnownNeverNaN(Op.getOperand(0), SNaN, Depth + 1) &&
4893            isKnownNeverNaN(Op.getOperand(1), SNaN, Depth + 1);
4894   }
4895   case ISD::EXTRACT_VECTOR_ELT: {
4896     return isKnownNeverNaN(Op.getOperand(0), SNaN, Depth + 1);
4897   }
4898   case ISD::BUILD_VECTOR: {
4899     for (const SDValue &Opnd : Op->ops())
4900       if (!isKnownNeverNaN(Opnd, SNaN, Depth + 1))
4901         return false;
4902     return true;
4903   }
4904   default:
4905     if (Opcode >= ISD::BUILTIN_OP_END ||
4906         Opcode == ISD::INTRINSIC_WO_CHAIN ||
4907         Opcode == ISD::INTRINSIC_W_CHAIN ||
4908         Opcode == ISD::INTRINSIC_VOID) {
4909       return TLI->isKnownNeverNaNForTargetNode(Op, *this, SNaN, Depth);
4910     }
4911 
4912     return false;
4913   }
4914 }
4915 
4916 bool SelectionDAG::isKnownNeverZeroFloat(SDValue Op) const {
4917   assert(Op.getValueType().isFloatingPoint() &&
4918          "Floating point type expected");
4919 
4920   // If the value is a constant, we can obviously see if it is a zero or not.
4921   // TODO: Add BuildVector support.
4922   if (const ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(Op))
4923     return !C->isZero();
4924   return false;
4925 }
4926 
4927 bool SelectionDAG::isKnownNeverZero(SDValue Op) const {
4928   assert(!Op.getValueType().isFloatingPoint() &&
4929          "Floating point types unsupported - use isKnownNeverZeroFloat");
4930 
4931   // If the value is a constant, we can obviously see if it is a zero or not.
4932   if (ISD::matchUnaryPredicate(Op,
4933                                [](ConstantSDNode *C) { return !C->isZero(); }))
4934     return true;
4935 
4936   // TODO: Recognize more cases here.
4937   switch (Op.getOpcode()) {
4938   default: break;
4939   case ISD::OR:
4940     if (isKnownNeverZero(Op.getOperand(1)) ||
4941         isKnownNeverZero(Op.getOperand(0)))
4942       return true;
4943     break;
4944   }
4945 
4946   return false;
4947 }
4948 
4949 bool SelectionDAG::isEqualTo(SDValue A, SDValue B) const {
4950   // Check the obvious case.
4951   if (A == B) return true;
4952 
4953   // Check for negative and positive zero.
4954   if (const ConstantFPSDNode *CA = dyn_cast<ConstantFPSDNode>(A))
4955     if (const ConstantFPSDNode *CB = dyn_cast<ConstantFPSDNode>(B))
4956       if (CA->isZero() && CB->isZero()) return true;
4957 
4958   // Otherwise they may not be equal.
4959   return false;
4960 }
4961 
4962 // Only bits set in Mask must be negated; other bits may be arbitrary.
4963 SDValue llvm::getBitwiseNotOperand(SDValue V, SDValue Mask, bool AllowUndefs) {
4964   if (isBitwiseNot(V, AllowUndefs))
4965     return V.getOperand(0);
4966 
4967   // Handle any_extend (not (truncate X)) pattern, where Mask only sets
4968   // bits in the non-extended part.
4969   ConstantSDNode *MaskC = isConstOrConstSplat(Mask);
4970   if (!MaskC || V.getOpcode() != ISD::ANY_EXTEND)
4971     return SDValue();
4972   SDValue ExtArg = V.getOperand(0);
4973   if (ExtArg.getScalarValueSizeInBits() >=
4974           MaskC->getAPIntValue().getActiveBits() &&
4975       isBitwiseNot(ExtArg, AllowUndefs) &&
4976       ExtArg.getOperand(0).getOpcode() == ISD::TRUNCATE &&
4977       ExtArg.getOperand(0).getOperand(0).getValueType() == V.getValueType())
4978     return ExtArg.getOperand(0).getOperand(0);
4979   return SDValue();
4980 }
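// Illustrative match for getBitwiseNotOperand above (types assumed): given
// i64 V = any_extend(xor (trunc X to i32), -1) and a Mask whose active bits
// fit in 32 bits, the helper looks through the extend/truncate pair and
// returns X: only the masked low bits must be negated, and the extended high
// bits are allowed to be arbitrary.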
4981 
4982 static bool haveNoCommonBitsSetCommutative(SDValue A, SDValue B) {
4983   // Match masked merge pattern (X & ~M) op (Y & M)
4984   // Including degenerate case (X & ~M) op M
4985   auto MatchNoCommonBitsPattern = [&](SDValue Not, SDValue Mask,
4986                                       SDValue Other) {
4987     if (SDValue NotOperand =
4988             getBitwiseNotOperand(Not, Mask, /* AllowUndefs */ true)) {
4989       if (Other == NotOperand)
4990         return true;
4991       if (Other->getOpcode() == ISD::AND)
4992         return NotOperand == Other->getOperand(0) ||
4993                NotOperand == Other->getOperand(1);
4994     }
4995     return false;
4996   };
4997   if (A->getOpcode() == ISD::AND)
4998     return MatchNoCommonBitsPattern(A->getOperand(0), A->getOperand(1), B) ||
4999            MatchNoCommonBitsPattern(A->getOperand(1), A->getOperand(0), B);
5000   return false;
5001 }
5002 
5003 // FIXME: unify with llvm::haveNoCommonBitsSet.
5004 bool SelectionDAG::haveNoCommonBitsSet(SDValue A, SDValue B) const {
5005   assert(A.getValueType() == B.getValueType() &&
5006          "Values must have the same type");
5007   if (haveNoCommonBitsSetCommutative(A, B) ||
5008       haveNoCommonBitsSetCommutative(B, A))
5009     return true;
5010   return KnownBits::haveNoCommonBitsSet(computeKnownBits(A),
5011                                         computeKnownBits(B));
5012 }
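// Worked example for haveNoCommonBitsSet above (illustrative): the masked
// merge operands (X & ~M) and (Y & M) can never share a set bit, including
// the degenerate (X & ~M) versus M form, so callers may e.g. rewrite an
// ISD::ADD of the two halves as a disjoint ISD::OR and vice versa.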
5013 
5014 static SDValue FoldSTEP_VECTOR(const SDLoc &DL, EVT VT, SDValue Step,
5015                                SelectionDAG &DAG) {
5016   if (cast<ConstantSDNode>(Step)->isZero())
5017     return DAG.getConstant(0, DL, VT);
5018 
5019   return SDValue();
5020 }
5021 
5022 static SDValue FoldBUILD_VECTOR(const SDLoc &DL, EVT VT,
5023                                 ArrayRef<SDValue> Ops,
5024                                 SelectionDAG &DAG) {
5025   int NumOps = Ops.size();
5026   assert(NumOps != 0 && "Can't build an empty vector!");
5027   assert(!VT.isScalableVector() &&
5028          "BUILD_VECTOR cannot be used with scalable types");
5029   assert(VT.getVectorNumElements() == (unsigned)NumOps &&
5030          "Incorrect element count in BUILD_VECTOR!");
5031 
5032   // BUILD_VECTOR of UNDEFs is UNDEF.
5033   if (llvm::all_of(Ops, [](SDValue Op) { return Op.isUndef(); }))
5034     return DAG.getUNDEF(VT);
5035 
5036   // BUILD_VECTOR of sequential extracts from the same vector and type is an identity.
5037   SDValue IdentitySrc;
5038   bool IsIdentity = true;
5039   for (int i = 0; i != NumOps; ++i) {
5040     if (Ops[i].getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
5041         Ops[i].getOperand(0).getValueType() != VT ||
5042         (IdentitySrc && Ops[i].getOperand(0) != IdentitySrc) ||
5043         !isa<ConstantSDNode>(Ops[i].getOperand(1)) ||
5044         cast<ConstantSDNode>(Ops[i].getOperand(1))->getAPIntValue() != i) {
5045       IsIdentity = false;
5046       break;
5047     }
5048     IdentitySrc = Ops[i].getOperand(0);
5049   }
5050   if (IsIdentity)
5051     return IdentitySrc;
5052 
5053   return SDValue();
5054 }
5055 
5056 /// Try to simplify vector concatenation to an input value, undef, or build
5057 /// vector.
5058 static SDValue foldCONCAT_VECTORS(const SDLoc &DL, EVT VT,
5059                                   ArrayRef<SDValue> Ops,
5060                                   SelectionDAG &DAG) {
5061   assert(!Ops.empty() && "Can't concatenate an empty list of vectors!");
5062   assert(llvm::all_of(Ops,
5063                       [Ops](SDValue Op) {
5064                         return Ops[0].getValueType() == Op.getValueType();
5065                       }) &&
5066          "Concatenation of vectors with inconsistent value types!");
5067   assert((Ops[0].getValueType().getVectorElementCount() * Ops.size()) ==
5068              VT.getVectorElementCount() &&
5069          "Incorrect element count in vector concatenation!");
5070 
5071   if (Ops.size() == 1)
5072     return Ops[0];
5073 
5074   // Concat of UNDEFs is UNDEF.
5075   if (llvm::all_of(Ops, [](SDValue Op) { return Op.isUndef(); }))
5076     return DAG.getUNDEF(VT);
5077 
5078   // Scan the operands and look for extract operations from a single source
5079   // that correspond to insertion at the same location via this concatenation:
5080   // concat (extract X, 0*subvec_elts), (extract X, 1*subvec_elts), ...
5081   SDValue IdentitySrc;
5082   bool IsIdentity = true;
5083   for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
5084     SDValue Op = Ops[i];
5085     unsigned IdentityIndex = i * Op.getValueType().getVectorMinNumElements();
5086     if (Op.getOpcode() != ISD::EXTRACT_SUBVECTOR ||
5087         Op.getOperand(0).getValueType() != VT ||
5088         (IdentitySrc && Op.getOperand(0) != IdentitySrc) ||
5089         Op.getConstantOperandVal(1) != IdentityIndex) {
5090       IsIdentity = false;
5091       break;
5092     }
5093     assert((!IdentitySrc || IdentitySrc == Op.getOperand(0)) &&
5094            "Unexpected identity source vector for concat of extracts");
5095     IdentitySrc = Op.getOperand(0);
5096   }
5097   if (IsIdentity) {
5098     assert(IdentitySrc && "Failed to set source vector of extracts");
5099     return IdentitySrc;
5100   }
5101 
5102   // The code below this point is only designed to work for fixed width
5103   // vectors, so we bail out for now.
5104   if (VT.isScalableVector())
5105     return SDValue();
5106 
5107   // A CONCAT_VECTOR with all UNDEF/BUILD_VECTOR operands can be
5108   // simplified to one big BUILD_VECTOR.
5109   // FIXME: Add support for SCALAR_TO_VECTOR as well.
5110   EVT SVT = VT.getScalarType();
5111   SmallVector<SDValue, 16> Elts;
5112   for (SDValue Op : Ops) {
5113     EVT OpVT = Op.getValueType();
5114     if (Op.isUndef())
5115       Elts.append(OpVT.getVectorNumElements(), DAG.getUNDEF(SVT));
5116     else if (Op.getOpcode() == ISD::BUILD_VECTOR)
5117       Elts.append(Op->op_begin(), Op->op_end());
5118     else
5119       return SDValue();
5120   }
5121 
5122   // BUILD_VECTOR requires all inputs to be of the same type; find the
5123   // maximum type and extend them all.
5124   for (SDValue Op : Elts)
5125     SVT = (SVT.bitsLT(Op.getValueType()) ? Op.getValueType() : SVT);
5126 
5127   if (SVT.bitsGT(VT.getScalarType())) {
5128     for (SDValue &Op : Elts) {
5129       if (Op.isUndef())
5130         Op = DAG.getUNDEF(SVT);
5131       else
5132         Op = DAG.getTargetLoweringInfo().isZExtFree(Op.getValueType(), SVT)
5133                  ? DAG.getZExtOrTrunc(Op, DL, SVT)
5134                  : DAG.getSExtOrTrunc(Op, DL, SVT);
5135     }
5136   }
5137 
5138   SDValue V = DAG.getBuildVector(VT, DL, Elts);
5139   NewSDValueDbgMsg(V, "New node fold concat vectors: ", &DAG);
5140   return V;
5141 }
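// Worked example for foldCONCAT_VECTORS above (types assumed): with X of type
// v8i32, concat (extract_subvector X, 0), (extract_subvector X, 4) re-creates
// X exactly and folds to it, while a concat whose operands are all
// BUILD_VECTOR/UNDEF collapses into one wide BUILD_VECTOR, extending
// mismatched scalars to the widest element type first.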
5142 
5143 /// Gets or creates the specified node.
5144 SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT) {
5145   FoldingSetNodeID ID;
5146   AddNodeIDNode(ID, Opcode, getVTList(VT), std::nullopt);
5147   void *IP = nullptr;
5148   if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP))
5149     return SDValue(E, 0);
5150 
5151   auto *N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(),
5152                               getVTList(VT));
5153   CSEMap.InsertNode(N, IP);
5154 
5155   InsertNode(N);
5156   SDValue V = SDValue(N, 0);
5157   NewSDValueDbgMsg(V, "Creating new node: ", this);
5158   return V;
5159 }
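// Illustrative CSE behavior for getNode above (a sketch, opcode arbitrary):
// nodes are keyed into CSEMap through their FoldingSetNodeID, so two getNode
// calls with the same opcode and value type return the very same SDNode*
// instead of allocating a duplicate; this sharing is what keeps the
// SelectionDAG a DAG rather than a tree.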
5160 
5161 SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
5162                               SDValue Operand) {
5163   SDNodeFlags Flags;
5164   if (Inserter)
5165     Flags = Inserter->getFlags();
5166   return getNode(Opcode, DL, VT, Operand, Flags);
5167 }
5168 
5169 SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
5170                               SDValue Operand, const SDNodeFlags Flags) {
5171   assert(Operand.getOpcode() != ISD::DELETED_NODE &&
5172          "Operand is DELETED_NODE!");
5173   // Constant fold unary operations with an integer constant operand. Even
5174   // opaque constants will be folded, because the folding of unary operations
5175   // doesn't create new constants with different values. Nevertheless, the
5176   // opaque flag is preserved during folding to prevent future folding with
5177   // other constants.
5178   if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Operand)) {
5179     const APInt &Val = C->getAPIntValue();
5180     switch (Opcode) {
5181     default: break;
5182     case ISD::SIGN_EXTEND:
5183       return getConstant(Val.sextOrTrunc(VT.getSizeInBits()), DL, VT,
5184                          C->isTargetOpcode(), C->isOpaque());
5185     case ISD::TRUNCATE:
5186       if (C->isOpaque())
5187         break;
5188       [[fallthrough]];
5189     case ISD::ZERO_EXTEND:
5190       return getConstant(Val.zextOrTrunc(VT.getSizeInBits()), DL, VT,
5191                          C->isTargetOpcode(), C->isOpaque());
5192     case ISD::ANY_EXTEND:
5193       // Some targets like RISCV prefer to sign extend some types.
5194       if (TLI->isSExtCheaperThanZExt(Operand.getValueType(), VT))
5195         return getConstant(Val.sextOrTrunc(VT.getSizeInBits()), DL, VT,
5196                            C->isTargetOpcode(), C->isOpaque());
5197       return getConstant(Val.zextOrTrunc(VT.getSizeInBits()), DL, VT,
5198                          C->isTargetOpcode(), C->isOpaque());
5199     case ISD::UINT_TO_FP:
5200     case ISD::SINT_TO_FP: {
5201       APFloat apf(EVTToAPFloatSemantics(VT),
5202                   APInt::getZero(VT.getSizeInBits()));
5203       (void)apf.convertFromAPInt(Val,
5204                                  Opcode==ISD::SINT_TO_FP,
5205                                  APFloat::rmNearestTiesToEven);
5206       return getConstantFP(apf, DL, VT);
5207     }
5208     case ISD::BITCAST:
5209       if (VT == MVT::f16 && C->getValueType(0) == MVT::i16)
5210         return getConstantFP(APFloat(APFloat::IEEEhalf(), Val), DL, VT);
5211       if (VT == MVT::f32 && C->getValueType(0) == MVT::i32)
5212         return getConstantFP(APFloat(APFloat::IEEEsingle(), Val), DL, VT);
5213       if (VT == MVT::f64 && C->getValueType(0) == MVT::i64)
5214         return getConstantFP(APFloat(APFloat::IEEEdouble(), Val), DL, VT);
5215       if (VT == MVT::f128 && C->getValueType(0) == MVT::i128)
5216         return getConstantFP(APFloat(APFloat::IEEEquad(), Val), DL, VT);
5217       break;
5218     case ISD::ABS:
5219       return getConstant(Val.abs(), DL, VT, C->isTargetOpcode(),
5220                          C->isOpaque());
5221     case ISD::BITREVERSE:
      return getConstant(Val.reverseBits(), DL, VT, C->isTargetOpcode(),
                         C->isOpaque());
    case ISD::BSWAP:
      return getConstant(Val.byteSwap(), DL, VT, C->isTargetOpcode(),
                         C->isOpaque());
    case ISD::CTPOP:
      return getConstant(Val.countPopulation(), DL, VT, C->isTargetOpcode(),
                         C->isOpaque());
    case ISD::CTLZ:
    case ISD::CTLZ_ZERO_UNDEF:
      return getConstant(Val.countLeadingZeros(), DL, VT, C->isTargetOpcode(),
                         C->isOpaque());
    case ISD::CTTZ:
    case ISD::CTTZ_ZERO_UNDEF:
      return getConstant(Val.countTrailingZeros(), DL, VT, C->isTargetOpcode(),
                         C->isOpaque());
    case ISD::FP16_TO_FP:
    case ISD::BF16_TO_FP: {
      bool Ignored;
      APFloat FPV(Opcode == ISD::FP16_TO_FP ? APFloat::IEEEhalf()
                                            : APFloat::BFloat(),
                  (Val.getBitWidth() == 16) ? Val : Val.trunc(16));

      // This can return overflow, underflow, or inexact; we don't care.
      // FIXME need to be more flexible about rounding mode.
      (void)FPV.convert(EVTToAPFloatSemantics(VT),
                        APFloat::rmNearestTiesToEven, &Ignored);
      return getConstantFP(FPV, DL, VT);
    }
    case ISD::STEP_VECTOR: {
      if (SDValue V = FoldSTEP_VECTOR(DL, VT, Operand, *this))
        return V;
      break;
    }
    }
  }

  // Constant fold unary operations with a floating point constant operand.
  if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(Operand)) {
    APFloat V = C->getValueAPF();    // make copy
    switch (Opcode) {
    case ISD::FNEG:
      V.changeSign();
      return getConstantFP(V, DL, VT);
    case ISD::FABS:
      V.clearSign();
      return getConstantFP(V, DL, VT);
    case ISD::FCEIL: {
      APFloat::opStatus fs = V.roundToIntegral(APFloat::rmTowardPositive);
      if (fs == APFloat::opOK || fs == APFloat::opInexact)
        return getConstantFP(V, DL, VT);
      break;
    }
    case ISD::FTRUNC: {
      APFloat::opStatus fs = V.roundToIntegral(APFloat::rmTowardZero);
      if (fs == APFloat::opOK || fs == APFloat::opInexact)
        return getConstantFP(V, DL, VT);
      break;
    }
    case ISD::FFLOOR: {
      APFloat::opStatus fs = V.roundToIntegral(APFloat::rmTowardNegative);
      if (fs == APFloat::opOK || fs == APFloat::opInexact)
        return getConstantFP(V, DL, VT);
      break;
    }
    case ISD::FP_EXTEND: {
      bool ignored;
      // This can return overflow, underflow, or inexact; we don't care.
      // FIXME need to be more flexible about rounding mode.
      (void)V.convert(EVTToAPFloatSemantics(VT),
                      APFloat::rmNearestTiesToEven, &ignored);
      return getConstantFP(V, DL, VT);
    }
    case ISD::FP_TO_SINT:
    case ISD::FP_TO_UINT: {
      bool ignored;
      APSInt IntVal(VT.getSizeInBits(), Opcode == ISD::FP_TO_UINT);
      // FIXME need to be more flexible about rounding mode.
      APFloat::opStatus s =
          V.convertToInteger(IntVal, APFloat::rmTowardZero, &ignored);
      if (s == APFloat::opInvalidOp) // inexact is OK, in fact usual
        break;
      return getConstant(IntVal, DL, VT);
    }
    case ISD::BITCAST:
      if (VT == MVT::i16 && C->getValueType(0) == MVT::f16)
        return getConstant((uint16_t)V.bitcastToAPInt().getZExtValue(), DL, VT);
      if (VT == MVT::i16 && C->getValueType(0) == MVT::bf16)
        return getConstant((uint16_t)V.bitcastToAPInt().getZExtValue(), DL, VT);
      if (VT == MVT::i32 && C->getValueType(0) == MVT::f32)
        return getConstant((uint32_t)V.bitcastToAPInt().getZExtValue(), DL, VT);
      if (VT == MVT::i64 && C->getValueType(0) == MVT::f64)
        return getConstant(V.bitcastToAPInt().getZExtValue(), DL, VT);
      break;
    case ISD::FP_TO_FP16:
    case ISD::FP_TO_BF16: {
      bool Ignored;
      // This can return overflow, underflow, or inexact; we don't care.
      // FIXME need to be more flexible about rounding mode.
      (void)V.convert(Opcode == ISD::FP_TO_FP16 ? APFloat::IEEEhalf()
                                                : APFloat::BFloat(),
                      APFloat::rmNearestTiesToEven, &Ignored);
      return getConstant(V.bitcastToAPInt().getZExtValue(), DL, VT);
    }
    }
  }

  // Constant fold unary operations with a vector integer or float operand.
  switch (Opcode) {
  default:
    // FIXME: Entirely reasonable to perform folding of other unary
    // operations here as the need arises.
    break;
  case ISD::FNEG:
  case ISD::FABS:
  case ISD::FCEIL:
  case ISD::FTRUNC:
  case ISD::FFLOOR:
  case ISD::FP_EXTEND:
  case ISD::FP_TO_SINT:
  case ISD::FP_TO_UINT:
  case ISD::TRUNCATE:
  case ISD::ANY_EXTEND:
  case ISD::ZERO_EXTEND:
  case ISD::SIGN_EXTEND:
  case ISD::UINT_TO_FP:
  case ISD::SINT_TO_FP:
  case ISD::ABS:
  case ISD::BITREVERSE:
  case ISD::BSWAP:
  case ISD::CTLZ:
  case ISD::CTLZ_ZERO_UNDEF:
  case ISD::CTTZ:
  case ISD::CTTZ_ZERO_UNDEF:
  case ISD::CTPOP: {
    SDValue Ops[] = {Operand};
    if (SDValue Fold = FoldConstantArithmetic(Opcode, DL, VT, Ops))
      return Fold;
  }
  }

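  // The remaining folds key off the operand's opcode rather than a constant
  // value.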
  unsigned OpOpcode = Operand.getNode()->getOpcode();
  switch (Opcode) {
  case ISD::STEP_VECTOR:
    assert(VT.isScalableVector() &&
           "STEP_VECTOR can only be used with scalable types");
    assert(OpOpcode == ISD::TargetConstant &&
           VT.getVectorElementType() == Operand.getValueType() &&
           "Unexpected step operand");
    break;
  case ISD::FREEZE:
    assert(VT == Operand.getValueType() && "Unexpected VT!");
    if (isGuaranteedNotToBeUndefOrPoison(Operand, /*PoisonOnly*/ false,
                                         /*Depth*/ 1))
      return Operand;
    break;
  case ISD::TokenFactor:
  case ISD::MERGE_VALUES:
  case ISD::CONCAT_VECTORS:
    return Operand;         // Factor, merge or concat of one node?  No need.
  case ISD::BUILD_VECTOR: {
    // Attempt to simplify BUILD_VECTOR.
    SDValue Ops[] = {Operand};
    if (SDValue V = FoldBUILD_VECTOR(DL, VT, Ops, *this))
      return V;
    break;
  }
  case ISD::FP_ROUND: llvm_unreachable("Invalid method to make FP_ROUND node");
  case ISD::FP_EXTEND:
    assert(VT.isFloatingPoint() &&
           Operand.getValueType().isFloatingPoint() && "Invalid FP cast!");
    if (Operand.getValueType() == VT) return Operand;  // noop conversion.
    assert((!VT.isVector() ||
            VT.getVectorElementCount() ==
            Operand.getValueType().getVectorElementCount()) &&
           "Vector element count mismatch!");
    assert(Operand.getValueType().bitsLT(VT) &&
           "Invalid fpext node, dst < src!");
    if (Operand.isUndef())
      return getUNDEF(VT);
    break;
  case ISD::FP_TO_SINT:
  case ISD::FP_TO_UINT:
    if (Operand.isUndef())
      return getUNDEF(VT);
    break;
  case ISD::SINT_TO_FP:
  case ISD::UINT_TO_FP:
    // [us]itofp(undef) = 0, because the result value is bounded.
    if (Operand.isUndef())
      return getConstantFP(0.0, DL, VT);
    break;
  case ISD::SIGN_EXTEND:
    assert(VT.isInteger() && Operand.getValueType().isInteger() &&
           "Invalid SIGN_EXTEND!");
    assert(VT.isVector() == Operand.getValueType().isVector() &&
           "SIGN_EXTEND result type should be vector iff the operand "
           "type is vector!");
    if (Operand.getValueType() == VT) return Operand;   // noop extension
    assert((!VT.isVector() ||
            VT.getVectorElementCount() ==
                Operand.getValueType().getVectorElementCount()) &&
           "Vector element count mismatch!");
    assert(Operand.getValueType().bitsLT(VT) &&
           "Invalid sext node, dst < src!");
    if (OpOpcode == ISD::SIGN_EXTEND || OpOpcode == ISD::ZERO_EXTEND)
      return getNode(OpOpcode, DL, VT, Operand.getOperand(0));
    if (OpOpcode == ISD::UNDEF)
      // sext(undef) = 0, because the top bits will all be the same.
      return getConstant(0, DL, VT);
    break;
  case ISD::ZERO_EXTEND:
    assert(VT.isInteger() && Operand.getValueType().isInteger() &&
           "Invalid ZERO_EXTEND!");
    assert(VT.isVector() == Operand.getValueType().isVector() &&
           "ZERO_EXTEND result type should be vector iff the operand "
           "type is vector!");
    if (Operand.getValueType() == VT) return Operand;   // noop extension
    assert((!VT.isVector() ||
            VT.getVectorElementCount() ==
                Operand.getValueType().getVectorElementCount()) &&
           "Vector element count mismatch!");
    assert(Operand.getValueType().bitsLT(VT) &&
           "Invalid zext node, dst < src!");
    if (OpOpcode == ISD::ZERO_EXTEND)   // (zext (zext x)) -> (zext x)
      return getNode(ISD::ZERO_EXTEND, DL, VT, Operand.getOperand(0));
    if (OpOpcode == ISD::UNDEF)
      // zext(undef) = 0, because the top bits will be zero.
      return getConstant(0, DL, VT);
    break;
  case ISD::ANY_EXTEND:
    assert(VT.isInteger() && Operand.getValueType().isInteger() &&
           "Invalid ANY_EXTEND!");
    assert(VT.isVector() == Operand.getValueType().isVector() &&
           "ANY_EXTEND result type should be vector iff the operand "
           "type is vector!");
    if (Operand.getValueType() == VT) return Operand;   // noop extension
    assert((!VT.isVector() ||
            VT.getVectorElementCount() ==
                Operand.getValueType().getVectorElementCount()) &&
           "Vector element count mismatch!");
    assert(Operand.getValueType().bitsLT(VT) &&
           "Invalid anyext node, dst < src!");

    if (OpOpcode == ISD::ZERO_EXTEND || OpOpcode == ISD::SIGN_EXTEND ||
        OpOpcode == ISD::ANY_EXTEND)
      // (ext (zext x)) -> (zext x)  and  (ext (sext x)) -> (sext x)
      return getNode(OpOpcode, DL, VT, Operand.getOperand(0));
    if (OpOpcode == ISD::UNDEF)
      return getUNDEF(VT);

    // (ext (trunc x)) -> x
    if (OpOpcode == ISD::TRUNCATE) {
      SDValue OpOp = Operand.getOperand(0);
      if (OpOp.getValueType() == VT) {
        transferDbgValues(Operand, OpOp);
        return OpOp;
      }
    }
    break;
  case ISD::TRUNCATE:
    assert(VT.isInteger() && Operand.getValueType().isInteger() &&
           "Invalid TRUNCATE!");
    assert(VT.isVector() == Operand.getValueType().isVector() &&
           "TRUNCATE result type should be vector iff the operand "
           "type is vector!");
    if (Operand.getValueType() == VT) return Operand;   // noop truncate
    assert((!VT.isVector() ||
            VT.getVectorElementCount() ==
                Operand.getValueType().getVectorElementCount()) &&
           "Vector element count mismatch!");
    assert(Operand.getValueType().bitsGT(VT) &&
           "Invalid truncate node, src < dst!");
    if (OpOpcode == ISD::TRUNCATE)
      return getNode(ISD::TRUNCATE, DL, VT, Operand.getOperand(0));
    if (OpOpcode == ISD::ZERO_EXTEND || OpOpcode == ISD::SIGN_EXTEND ||
        OpOpcode == ISD::ANY_EXTEND) {
      // If the source is smaller than the dest, we still need an extend.
      if (Operand.getOperand(0).getValueType().getScalarType()
            .bitsLT(VT.getScalarType()))
        return getNode(OpOpcode, DL, VT, Operand.getOperand(0));
      if (Operand.getOperand(0).getValueType().bitsGT(VT))
        return getNode(ISD::TRUNCATE, DL, VT, Operand.getOperand(0));
      return Operand.getOperand(0);
    }
    if (OpOpcode == ISD::UNDEF)
      return getUNDEF(VT);
    if (OpOpcode == ISD::VSCALE && !NewNodesMustHaveLegalTypes)
      return getVScale(DL, VT, Operand.getConstantOperandAPInt(0));
    break;
  case ISD::ANY_EXTEND_VECTOR_INREG:
  case ISD::ZERO_EXTEND_VECTOR_INREG:
  case ISD::SIGN_EXTEND_VECTOR_INREG:
    assert(VT.isVector() && "This DAG node is restricted to vector types.");
    assert(Operand.getValueType().bitsLE(VT) &&
           "The input must be the same size or smaller than the result.");
    assert(VT.getVectorMinNumElements() <
               Operand.getValueType().getVectorMinNumElements() &&
           "The destination vector type must have fewer lanes than the input.");
    break;
  case ISD::ABS:
    assert(VT.isInteger() && VT == Operand.getValueType() &&
           "Invalid ABS!");
    if (OpOpcode == ISD::UNDEF)
      return getConstant(0, DL, VT);
    break;
  case ISD::BSWAP:
    assert(VT.isInteger() && VT == Operand.getValueType() &&
           "Invalid BSWAP!");
    assert((VT.getScalarSizeInBits() % 16 == 0) &&
           "BSWAP types must be a multiple of 16 bits!");
    if (OpOpcode == ISD::UNDEF)
      return getUNDEF(VT);
    // bswap(bswap(X)) -> X.
    if (OpOpcode == ISD::BSWAP)
      return Operand.getOperand(0);
    break;
  case ISD::BITREVERSE:
    assert(VT.isInteger() && VT == Operand.getValueType() &&
           "Invalid BITREVERSE!");
    if (OpOpcode == ISD::UNDEF)
      return getUNDEF(VT);
    break;
  case ISD::BITCAST:
    assert(VT.getSizeInBits() == Operand.getValueSizeInBits() &&
           "Cannot BITCAST between types of different sizes!");
    if (VT == Operand.getValueType()) return Operand;  // noop conversion.
    if (OpOpcode == ISD::BITCAST)  // bitconv(bitconv(x)) -> bitconv(x)
      return getNode(ISD::BITCAST, DL, VT, Operand.getOperand(0));
    if (OpOpcode == ISD::UNDEF)
      return getUNDEF(VT);
    break;
  case ISD::SCALAR_TO_VECTOR:
    assert(VT.isVector() && !Operand.getValueType().isVector() &&
           (VT.getVectorElementType() == Operand.getValueType() ||
            (VT.getVectorElementType().isInteger() &&
             Operand.getValueType().isInteger() &&
             VT.getVectorElementType().bitsLE(Operand.getValueType()))) &&
           "Illegal SCALAR_TO_VECTOR node!");
    if (OpOpcode == ISD::UNDEF)
      return getUNDEF(VT);
    // scalar_to_vector(extract_vector_elt V, 0) -> V, top bits are undefined.
    if (OpOpcode == ISD::EXTRACT_VECTOR_ELT &&
        isa<ConstantSDNode>(Operand.getOperand(1)) &&
        Operand.getConstantOperandVal(1) == 0 &&
        Operand.getOperand(0).getValueType() == VT)
      return Operand.getOperand(0);
    break;
  case ISD::FNEG:
    // Negation of an unknown bag of bits is still completely undefined.
    if (OpOpcode == ISD::UNDEF)
      return getUNDEF(VT);

    if (OpOpcode == ISD::FNEG)  // --X -> X
      return Operand.getOperand(0);
    break;
  case ISD::FABS:
    if (OpOpcode == ISD::FNEG)  // abs(-X) -> abs(X)
      return getNode(ISD::FABS, DL, VT, Operand.getOperand(0));
    break;
  case ISD::VSCALE:
    assert(VT == Operand.getValueType() && "Unexpected VT!");
    break;
  case ISD::CTPOP:
    if (Operand.getValueType().getScalarType() == MVT::i1)
      return Operand;
    break;
  case ISD::CTLZ:
  case ISD::CTTZ:
    if (Operand.getValueType().getScalarType() == MVT::i1)
      return getNOT(DL, Operand, Operand.getValueType());
    break;
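  // i1 vector reductions fold to plain logic ops: ADD is XOR (parity),
  // SMIN/UMAX are OR (true if any lane is true), and SMAX/UMIN are AND
  // (true only if all lanes are true).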
  case ISD::VECREDUCE_ADD:
    if (Operand.getValueType().getScalarType() == MVT::i1)
      return getNode(ISD::VECREDUCE_XOR, DL, VT, Operand);
    break;
  case ISD::VECREDUCE_SMIN:
  case ISD::VECREDUCE_UMAX:
    if (Operand.getValueType().getScalarType() == MVT::i1)
      return getNode(ISD::VECREDUCE_OR, DL, VT, Operand);
    break;
  case ISD::VECREDUCE_SMAX:
  case ISD::VECREDUCE_UMIN:
    if (Operand.getValueType().getScalarType() == MVT::i1)
      return getNode(ISD::VECREDUCE_AND, DL, VT, Operand);
    break;
  }

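  // Memoize this node if possible.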
  SDNode *N;
  SDVTList VTs = getVTList(VT);
  SDValue Ops[] = {Operand};
  if (VT != MVT::Glue) { // Don't CSE flag producing nodes
    FoldingSetNodeID ID;
    AddNodeIDNode(ID, Opcode, VTs, Ops);
    void *IP = nullptr;
    if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) {
      E->intersectFlagsWith(Flags);
      return SDValue(E, 0);
    }

    N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
    N->setFlags(Flags);
    createOperands(N, Ops);
    CSEMap.InsertNode(N, IP);
  } else {
    N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
    createOperands(N, Ops);
  }

  InsertNode(N);
  SDValue V = SDValue(N, 0);
  NewSDValueDbgMsg(V, "Creating new node: ", this);
  return V;
}

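/// Attempt to constant fold a binary integer operation over two APInt values.
/// Returns std::nullopt for opcodes it doesn't handle and for operations
/// whose result is undefined (e.g. division or remainder by zero). As an
/// illustration, FoldValue(ISD::ADD, APInt(8, 250), APInt(8, 10)) wraps
/// around to 4.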
static std::optional<APInt> FoldValue(unsigned Opcode, const APInt &C1,
                                      const APInt &C2) {
  switch (Opcode) {
  case ISD::ADD:  return C1 + C2;
  case ISD::SUB:  return C1 - C2;
  case ISD::MUL:  return C1 * C2;
  case ISD::AND:  return C1 & C2;
  case ISD::OR:   return C1 | C2;
  case ISD::XOR:  return C1 ^ C2;
  case ISD::SHL:  return C1 << C2;
  case ISD::SRL:  return C1.lshr(C2);
  case ISD::SRA:  return C1.ashr(C2);
  case ISD::ROTL: return C1.rotl(C2);
  case ISD::ROTR: return C1.rotr(C2);
  case ISD::SMIN: return C1.sle(C2) ? C1 : C2;
  case ISD::SMAX: return C1.sge(C2) ? C1 : C2;
  case ISD::UMIN: return C1.ule(C2) ? C1 : C2;
  case ISD::UMAX: return C1.uge(C2) ? C1 : C2;
  case ISD::SADDSAT: return C1.sadd_sat(C2);
  case ISD::UADDSAT: return C1.uadd_sat(C2);
  case ISD::SSUBSAT: return C1.ssub_sat(C2);
  case ISD::USUBSAT: return C1.usub_sat(C2);
  case ISD::SSHLSAT: return C1.sshl_sat(C2);
  case ISD::USHLSAT: return C1.ushl_sat(C2);
  case ISD::UDIV:
    if (!C2.getBoolValue())
      break;
    return C1.udiv(C2);
  case ISD::UREM:
    if (!C2.getBoolValue())
      break;
    return C1.urem(C2);
  case ISD::SDIV:
    if (!C2.getBoolValue())
      break;
    return C1.sdiv(C2);
  case ISD::SREM:
    if (!C2.getBoolValue())
      break;
    return C1.srem(C2);
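  // MULHS/MULHU: sign/zero-extend to twice the width, multiply, and return
  // the high half of the product. e.g. for i8, mulhs(-128, 2) forms the i16
  // product -256 (0xFF00) and returns its high byte, -1.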
  case ISD::MULHS: {
    unsigned FullWidth = C1.getBitWidth() * 2;
    APInt C1Ext = C1.sext(FullWidth);
    APInt C2Ext = C2.sext(FullWidth);
    return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
  }
  case ISD::MULHU: {
    unsigned FullWidth = C1.getBitWidth() * 2;
    APInt C1Ext = C1.zext(FullWidth);
    APInt C2Ext = C2.zext(FullWidth);
    return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
  }
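  // AVGFLOOR*/AVGCEIL*: widen by one bit so the addition cannot overflow,
  // then drop the low bit of the sum, adding 1 first for the rounding-up
  // (CEIL) forms. e.g. for i8, avgceilu(250, 251) forms the 9-bit sum 502
  // and returns 251.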
  case ISD::AVGFLOORS: {
    unsigned FullWidth = C1.getBitWidth() + 1;
    APInt C1Ext = C1.sext(FullWidth);
    APInt C2Ext = C2.sext(FullWidth);
    return (C1Ext + C2Ext).extractBits(C1.getBitWidth(), 1);
  }
  case ISD::AVGFLOORU: {
    unsigned FullWidth = C1.getBitWidth() + 1;
    APInt C1Ext = C1.zext(FullWidth);
    APInt C2Ext = C2.zext(FullWidth);
    return (C1Ext + C2Ext).extractBits(C1.getBitWidth(), 1);
  }
  case ISD::AVGCEILS: {
    unsigned FullWidth = C1.getBitWidth() + 1;
    APInt C1Ext = C1.sext(FullWidth);
    APInt C2Ext = C2.sext(FullWidth);
    return (C1Ext + C2Ext + 1).extractBits(C1.getBitWidth(), 1);
  }
  case ISD::AVGCEILU: {
    unsigned FullWidth = C1.getBitWidth() + 1;
    APInt C1Ext = C1.zext(FullWidth);
    APInt C2Ext = C2.zext(FullWidth);
    return (C1Ext + C2Ext + 1).extractBits(C1.getBitWidth(), 1);
  }
  }
  return std::nullopt;
}

// Handle constant folding with UNDEF.
// TODO: Handle more cases.
static std::optional<APInt> FoldValueWithUndef(unsigned Opcode, const APInt &C1,
                                               bool IsUndef1, const APInt &C2,
                                               bool IsUndef2) {
  if (!(IsUndef1 || IsUndef2))
    return FoldValue(Opcode, C1, C2);

  // Fold and(x, undef) -> 0
  // Fold mul(x, undef) -> 0
  if (Opcode == ISD::AND || Opcode == ISD::MUL)
    return APInt::getZero(C1.getBitWidth());

  return std::nullopt;
}

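/// Attempt to fold an add/sub of a global address and a constant offset into
/// a single global address node with an adjusted offset, provided the target
/// considers offset folding legal for this symbol.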
SDValue SelectionDAG::FoldSymbolOffset(unsigned Opcode, EVT VT,
                                       const GlobalAddressSDNode *GA,
                                       const SDNode *N2) {
  if (GA->getOpcode() != ISD::GlobalAddress)
    return SDValue();
  if (!TLI->isOffsetFoldingLegal(GA))
    return SDValue();
  auto *C2 = dyn_cast<ConstantSDNode>(N2);
  if (!C2)
    return SDValue();
  int64_t Offset = C2->getSExtValue();
  switch (Opcode) {
  case ISD::ADD: break;
  case ISD::SUB: Offset = -uint64_t(Offset); break;
  default: return SDValue();
  }
  return getGlobalAddress(GA->getGlobal(), SDLoc(C2), VT,
                          GA->getOffset() + uint64_t(Offset));
}

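/// Return true if applying Opcode to Ops is known to produce undef, e.g. a
/// division or remainder whose divisor (or any element of a constant divisor
/// vector) is zero or undef.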
bool SelectionDAG::isUndef(unsigned Opcode, ArrayRef<SDValue> Ops) {
  switch (Opcode) {
  case ISD::SDIV:
  case ISD::UDIV:
  case ISD::SREM:
  case ISD::UREM: {
    // If a divisor is zero/undef or any element of a divisor vector is
    // zero/undef, the whole op is undef.
    assert(Ops.size() == 2 && "Div/rem should have 2 operands");
    SDValue Divisor = Ops[1];
    if (Divisor.isUndef() || isNullConstant(Divisor))
      return true;

    return ISD::isBuildVectorOfConstantSDNodes(Divisor.getNode()) &&
           llvm::any_of(Divisor->op_values(),
                        [](SDValue V) { return V.isUndef() ||
                                        isNullConstant(V); });
    // TODO: Handle signed overflow.
  }
  // TODO: Handle oversized shifts.
  default:
    return false;
  }
}

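/// Attempt to constant fold Opcode applied to Ops, handling scalar integer
/// and FP constants, symbol+offset folds, and element-by-element folding of
/// build/splat vectors of constants. Returns an empty SDValue on failure.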
SDValue SelectionDAG::FoldConstantArithmetic(unsigned Opcode, const SDLoc &DL,
                                             EVT VT, ArrayRef<SDValue> Ops) {
  // If the opcode is a target-specific ISD node, there's nothing we can
  // do here and the operand rules may not line up with the below, so
  // bail early.
  // We can't create a scalar CONCAT_VECTORS so skip it. It will break
  // for concats involving SPLAT_VECTOR. Concats of BUILD_VECTORS are handled by
  // foldCONCAT_VECTORS in getNode before this is called.
  if (Opcode >= ISD::BUILTIN_OP_END || Opcode == ISD::CONCAT_VECTORS)
    return SDValue();

  unsigned NumOps = Ops.size();
  if (NumOps == 0)
    return SDValue();

  if (isUndef(Opcode, Ops))
    return getUNDEF(VT);

  // Handle binops special cases.
  if (NumOps == 2) {
    if (SDValue CFP = foldConstantFPMath(Opcode, DL, VT, Ops[0], Ops[1]))
      return CFP;

    if (auto *C1 = dyn_cast<ConstantSDNode>(Ops[0])) {
      if (auto *C2 = dyn_cast<ConstantSDNode>(Ops[1])) {
        if (C1->isOpaque() || C2->isOpaque())
          return SDValue();

        std::optional<APInt> FoldAttempt =
            FoldValue(Opcode, C1->getAPIntValue(), C2->getAPIntValue());
        if (!FoldAttempt)
          return SDValue();

        SDValue Folded = getConstant(*FoldAttempt, DL, VT);
        assert((!Folded || !VT.isVector()) &&
               "Can't fold vectors ops with scalar operands");
        return Folded;
      }
    }

    // fold (add Sym, c) -> Sym+c
    if (GlobalAddressSDNode *GA = dyn_cast<GlobalAddressSDNode>(Ops[0]))
      return FoldSymbolOffset(Opcode, VT, GA, Ops[1].getNode());
    if (TLI->isCommutativeBinOp(Opcode))
      if (GlobalAddressSDNode *GA = dyn_cast<GlobalAddressSDNode>(Ops[1]))
        return FoldSymbolOffset(Opcode, VT, GA, Ops[0].getNode());
  }

  // This is for vector folding only from here on.
  if (!VT.isVector())
    return SDValue();

  ElementCount NumElts = VT.getVectorElementCount();

  // See if we can fold through bitcasted integer ops.
  if (NumOps == 2 && VT.isFixedLengthVector() && VT.isInteger() &&
      Ops[0].getValueType() == VT && Ops[1].getValueType() == VT &&
      Ops[0].getOpcode() == ISD::BITCAST &&
      Ops[1].getOpcode() == ISD::BITCAST) {
    SDValue N1 = peekThroughBitcasts(Ops[0]);
    SDValue N2 = peekThroughBitcasts(Ops[1]);
    auto *BV1 = dyn_cast<BuildVectorSDNode>(N1);
    auto *BV2 = dyn_cast<BuildVectorSDNode>(N2);
    EVT BVVT = N1.getValueType();
    if (BV1 && BV2 && BVVT.isInteger() && BVVT == N2.getValueType()) {
      bool IsLE = getDataLayout().isLittleEndian();
      unsigned EltBits = VT.getScalarSizeInBits();
      SmallVector<APInt> RawBits1, RawBits2;
      BitVector UndefElts1, UndefElts2;
      if (BV1->getConstantRawBits(IsLE, EltBits, RawBits1, UndefElts1) &&
          BV2->getConstantRawBits(IsLE, EltBits, RawBits2, UndefElts2)) {
        SmallVector<APInt> RawBits;
        for (unsigned I = 0, E = NumElts.getFixedValue(); I != E; ++I) {
          std::optional<APInt> Fold = FoldValueWithUndef(
              Opcode, RawBits1[I], UndefElts1[I], RawBits2[I], UndefElts2[I]);
          if (!Fold)
            break;
          RawBits.push_back(*Fold);
        }
        if (RawBits.size() == NumElts.getFixedValue()) {
          // We have constant folded, but we need to cast this again back to
          // the original (possibly legalized) type.
          SmallVector<APInt> DstBits;
          BitVector DstUndefs;
          BuildVectorSDNode::recastRawBits(IsLE, BVVT.getScalarSizeInBits(),
                                           DstBits, RawBits, DstUndefs,
                                           BitVector(RawBits.size(), false));
          EVT BVEltVT = BV1->getOperand(0).getValueType();
          unsigned BVEltBits = BVEltVT.getSizeInBits();
          SmallVector<SDValue> Ops(DstBits.size(), getUNDEF(BVEltVT));
          for (unsigned I = 0, E = DstBits.size(); I != E; ++I) {
            if (DstUndefs[I])
              continue;
            Ops[I] = getConstant(DstBits[I].sext(BVEltBits), DL, BVEltVT);
          }
          return getBitcast(VT, getBuildVector(BVVT, DL, Ops));
        }
      }
    }
  }

  // Fold (mul step_vector(C0), C1) to (step_vector(C0 * C1)).
  //      (shl step_vector(C0), C1) -> (step_vector(C0 << C1))
  if ((Opcode == ISD::MUL || Opcode == ISD::SHL) &&
      Ops[0].getOpcode() == ISD::STEP_VECTOR) {
    APInt RHSVal;
    if (ISD::isConstantSplatVector(Ops[1].getNode(), RHSVal)) {
      APInt NewStep = Opcode == ISD::MUL
                          ? Ops[0].getConstantOperandAPInt(0) * RHSVal
                          : Ops[0].getConstantOperandAPInt(0) << RHSVal;
      return getStepVector(DL, VT, NewStep);
    }
  }

  auto IsScalarOrSameVectorSize = [NumElts](const SDValue &Op) {
    return !Op.getValueType().isVector() ||
           Op.getValueType().getVectorElementCount() == NumElts;
  };

  auto IsBuildVectorSplatVectorOrUndef = [](const SDValue &Op) {
    return Op.isUndef() || Op.getOpcode() == ISD::CONDCODE ||
           Op.getOpcode() == ISD::BUILD_VECTOR ||
           Op.getOpcode() == ISD::SPLAT_VECTOR;
  };

  // All operands must be UNDEF, a CONDCODE, or a build/splat vector, and any
  // vector operands must have the same number of elements as the result type.
  if (!llvm::all_of(Ops, IsBuildVectorSplatVectorOrUndef) ||
      !llvm::all_of(Ops, IsScalarOrSameVectorSize))
    return SDValue();

  // If we are comparing vectors, then the result needs to be an i1 boolean
  // that is then extended back to the legal result type depending on how
  // booleans are represented.
  EVT SVT = (Opcode == ISD::SETCC ? MVT::i1 : VT.getScalarType());
  ISD::NodeType ExtendCode =
      (Opcode == ISD::SETCC && SVT != VT.getScalarType())
          ? TargetLowering::getExtendForContent(TLI->getBooleanContents(VT))
          : ISD::SIGN_EXTEND;

  // Find legal integer scalar type for constant promotion and
  // ensure that its scalar size is at least as large as source.
  EVT LegalSVT = VT.getScalarType();
  if (NewNodesMustHaveLegalTypes && LegalSVT.isInteger()) {
    LegalSVT = TLI->getTypeToTransformTo(*getContext(), LegalSVT);
    if (LegalSVT.bitsLT(VT.getScalarType()))
      return SDValue();
  }

  // For scalable vector types we know we're dealing with SPLAT_VECTORs. We
  // only have one operand to check. For fixed-length vector types we may have
  // a combination of BUILD_VECTOR and SPLAT_VECTOR.
  unsigned NumVectorElts = NumElts.isScalable() ? 1 : NumElts.getFixedValue();

  // Constant fold each scalar lane separately.
  SmallVector<SDValue, 4> ScalarResults;
  for (unsigned I = 0; I != NumVectorElts; I++) {
    SmallVector<SDValue, 4> ScalarOps;
    for (SDValue Op : Ops) {
      EVT InSVT = Op.getValueType().getScalarType();
      if (Op.getOpcode() != ISD::BUILD_VECTOR &&
          Op.getOpcode() != ISD::SPLAT_VECTOR) {
        if (Op.isUndef())
          ScalarOps.push_back(getUNDEF(InSVT));
        else
          ScalarOps.push_back(Op);
        continue;
      }

      SDValue ScalarOp =
          Op.getOperand(Op.getOpcode() == ISD::SPLAT_VECTOR ? 0 : I);
      EVT ScalarVT = ScalarOp.getValueType();

      // Build vector (integer) scalar operands may need implicit
      // truncation - do this before constant folding.
      if (ScalarVT.isInteger() && ScalarVT.bitsGT(InSVT)) {
        // Don't create illegally-typed nodes unless they're constants or undef
        // - if we fail to constant fold we can't guarantee the (dead) nodes
        // we're creating will be cleaned up before being visited for
        // legalization.
        if (NewNodesMustHaveLegalTypes && !ScalarOp.isUndef() &&
            !isa<ConstantSDNode>(ScalarOp) &&
            TLI->getTypeAction(*getContext(), InSVT) !=
                TargetLowering::TypeLegal)
          return SDValue();
        ScalarOp = getNode(ISD::TRUNCATE, DL, InSVT, ScalarOp);
      }

      ScalarOps.push_back(ScalarOp);
    }

    // Constant fold the scalar operands.
    SDValue ScalarResult = getNode(Opcode, DL, SVT, ScalarOps);

    // Legalize the (integer) scalar constant if necessary.
    if (LegalSVT != SVT)
      ScalarResult = getNode(ExtendCode, DL, LegalSVT, ScalarResult);

    // Scalar folding only succeeded if the result is a constant or UNDEF.
    if (!ScalarResult.isUndef() && ScalarResult.getOpcode() != ISD::Constant &&
        ScalarResult.getOpcode() != ISD::ConstantFP)
      return SDValue();
    ScalarResults.push_back(ScalarResult);
  }

  SDValue V = NumElts.isScalable() ? getSplatVector(VT, DL, ScalarResults[0])
                                   : getBuildVector(VT, DL, ScalarResults);
  NewSDValueDbgMsg(V, "New node fold constant vector: ", this);
  return V;
}

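/// Attempt to constant fold an FP operation over constant (or constant-splat)
/// operands, including the UNDEF special cases that mirror the IR optimizer's
/// behavior.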
SDValue SelectionDAG::foldConstantFPMath(unsigned Opcode, const SDLoc &DL,
                                         EVT VT, SDValue N1, SDValue N2) {
  // TODO: We don't do any constant folding for strict FP opcodes here, but we
  //       should. That will require dealing with a potentially non-default
  //       rounding mode, checking the "opStatus" return value from the APFloat
  //       math calculations, and possibly other variations.
  ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1, /*AllowUndefs*/ false);
  ConstantFPSDNode *N2CFP = isConstOrConstSplatFP(N2, /*AllowUndefs*/ false);
  if (N1CFP && N2CFP) {
    APFloat C1 = N1CFP->getValueAPF(); // make copy
    const APFloat &C2 = N2CFP->getValueAPF();
    switch (Opcode) {
    case ISD::FADD:
      C1.add(C2, APFloat::rmNearestTiesToEven);
      return getConstantFP(C1, DL, VT);
    case ISD::FSUB:
      C1.subtract(C2, APFloat::rmNearestTiesToEven);
      return getConstantFP(C1, DL, VT);
    case ISD::FMUL:
      C1.multiply(C2, APFloat::rmNearestTiesToEven);
      return getConstantFP(C1, DL, VT);
    case ISD::FDIV:
      C1.divide(C2, APFloat::rmNearestTiesToEven);
      return getConstantFP(C1, DL, VT);
    case ISD::FREM:
      C1.mod(C2);
      return getConstantFP(C1, DL, VT);
    case ISD::FCOPYSIGN:
      C1.copySign(C2);
      return getConstantFP(C1, DL, VT);
    case ISD::FMINNUM:
      return getConstantFP(minnum(C1, C2), DL, VT);
    case ISD::FMAXNUM:
      return getConstantFP(maxnum(C1, C2), DL, VT);
    case ISD::FMINIMUM:
      return getConstantFP(minimum(C1, C2), DL, VT);
    case ISD::FMAXIMUM:
      return getConstantFP(maximum(C1, C2), DL, VT);
    default: break;
    }
  }
  if (N1CFP && Opcode == ISD::FP_ROUND) {
    APFloat C1 = N1CFP->getValueAPF();    // make copy
    bool Unused;
    // This can return overflow, underflow, or inexact; we don't care.
    // FIXME need to be more flexible about rounding mode.
    (void) C1.convert(EVTToAPFloatSemantics(VT), APFloat::rmNearestTiesToEven,
                      &Unused);
    return getConstantFP(C1, DL, VT);
  }

  switch (Opcode) {
  case ISD::FSUB:
    // -0.0 - undef --> undef (consistent with "fneg undef")
    if (ConstantFPSDNode *N1C = isConstOrConstSplatFP(N1, /*AllowUndefs*/ true))
      if (N1C && N1C->getValueAPF().isNegZero() && N2.isUndef())
        return getUNDEF(VT);
    [[fallthrough]];

  case ISD::FADD:
  case ISD::FMUL:
  case ISD::FDIV:
  case ISD::FREM:
    // If both operands are undef, the result is undef. If 1 operand is undef,
    // the result is NaN. This should match the behavior of the IR optimizer.
    if (N1.isUndef() && N2.isUndef())
      return getUNDEF(VT);
    if (N1.isUndef() || N2.isUndef())
      return getConstantFP(APFloat::getNaN(EVTToAPFloatSemantics(VT)), DL, VT);
  }
  return SDValue();
}

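/// Return an AssertAlign node recording that Val is known to be at least
/// A-aligned. An alignment of 1 carries no information, so Val is returned
/// unchanged in that case.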
SDValue SelectionDAG::getAssertAlign(const SDLoc &DL, SDValue Val, Align A) {
  assert(Val.getValueType().isInteger() && "Invalid AssertAlign!");

  // There's no need to assert on a byte-aligned pointer. All pointers are at
  // least byte aligned.
  if (A == Align(1))
    return Val;

  FoldingSetNodeID ID;
  AddNodeIDNode(ID, ISD::AssertAlign, getVTList(Val.getValueType()), {Val});
  ID.AddInteger(A.value());

  void *IP = nullptr;
  if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP))
    return SDValue(E, 0);

  auto *N = newSDNode<AssertAlignSDNode>(DL.getIROrder(), DL.getDebugLoc(),
                                         Val.getValueType(), A);
  createOperands(N, {Val});

  CSEMap.InsertNode(N, IP);
  InsertNode(N);

  SDValue V(N, 0);
  NewSDValueDbgMsg(V, "Creating new node: ", this);
  return V;
}

SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
                              SDValue N1, SDValue N2) {
  SDNodeFlags Flags;
  if (Inserter)
    Flags = Inserter->getFlags();
  return getNode(Opcode, DL, VT, N1, N2, Flags);
}

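/// For commutative binops, canonicalize the operand order: constants move to
/// the RHS and a STEP_VECTOR moves ahead of a splat, so later folds only need
/// to match a single operand order.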
void SelectionDAG::canonicalizeCommutativeBinop(unsigned Opcode, SDValue &N1,
                                                SDValue &N2) const {
  if (!TLI->isCommutativeBinOp(Opcode))
    return;

  // Canonicalize:
  //   binop(const, nonconst) -> binop(nonconst, const)
  SDNode *N1C = isConstantIntBuildVectorOrConstantInt(N1);
  SDNode *N2C = isConstantIntBuildVectorOrConstantInt(N2);
  SDNode *N1CFP = isConstantFPBuildVectorOrConstantFP(N1);
  SDNode *N2CFP = isConstantFPBuildVectorOrConstantFP(N2);
  if ((N1C && !N2C) || (N1CFP && !N2CFP))
    std::swap(N1, N2);

  // Canonicalize:
  //  binop(splat(x), step_vector) -> binop(step_vector, splat(x))
  else if (N1.getOpcode() == ISD::SPLAT_VECTOR &&
           N2.getOpcode() == ISD::STEP_VECTOR)
    std::swap(N1, N2);
}

SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
                              SDValue N1, SDValue N2, const SDNodeFlags Flags) {
  assert(N1.getOpcode() != ISD::DELETED_NODE &&
         N2.getOpcode() != ISD::DELETED_NODE &&
         "Operand is DELETED_NODE!");

  canonicalizeCommutativeBinop(Opcode, N1, N2);

  auto *N1C = dyn_cast<ConstantSDNode>(N1);
  auto *N2C = dyn_cast<ConstantSDNode>(N2);

  // Don't allow undefs in vector splats - we might be returning N2 when folding
  // to zero etc.
  ConstantSDNode *N2CV =
      isConstOrConstSplat(N2, /*AllowUndefs*/ false, /*AllowTruncation*/ true);

  switch (Opcode) {
  default: break;
  case ISD::TokenFactor:
    assert(VT == MVT::Other && N1.getValueType() == MVT::Other &&
           N2.getValueType() == MVT::Other && "Invalid token factor!");
    // Fold trivial token factors.
    if (N1.getOpcode() == ISD::EntryToken) return N2;
    if (N2.getOpcode() == ISD::EntryToken) return N1;
    if (N1 == N2) return N1;
    break;
  case ISD::BUILD_VECTOR: {
    // Attempt to simplify BUILD_VECTOR.
    SDValue Ops[] = {N1, N2};
    if (SDValue V = FoldBUILD_VECTOR(DL, VT, Ops, *this))
      return V;
    break;
  }
  case ISD::CONCAT_VECTORS: {
    SDValue Ops[] = {N1, N2};
    if (SDValue V = foldCONCAT_VECTORS(DL, VT, Ops, *this))
      return V;
    break;
  }
  case ISD::AND:
    assert(VT.isInteger() && "This operator does not apply to FP types!");
    assert(N1.getValueType() == N2.getValueType() &&
           N1.getValueType() == VT && "Binary operator types must match!");
    // (X & 0) -> 0.  This commonly occurs when legalizing i64 values, so it's
    // worth handling here.
    if (N2CV && N2CV->isZero())
      return N2;
    if (N2CV && N2CV->isAllOnes()) // X & -1 -> X
      return N1;
    break;
  case ISD::OR:
  case ISD::XOR:
  case ISD::ADD:
  case ISD::SUB:
    assert(VT.isInteger() && "This operator does not apply to FP types!");
    assert(N1.getValueType() == N2.getValueType() &&
           N1.getValueType() == VT && "Binary operator types must match!");
    // (X ^|+- 0) -> X.  This commonly occurs when legalizing i64 values, so
    // it's worth handling here.
    if (N2CV && N2CV->isZero())
      return N1;
    if ((Opcode == ISD::ADD || Opcode == ISD::SUB) && VT.isVector() &&
        VT.getVectorElementType() == MVT::i1)
      return getNode(ISD::XOR, DL, VT, N1, N2);
    break;
  case ISD::MUL:
    assert(VT.isInteger() && "This operator does not apply to FP types!");
    assert(N1.getValueType() == N2.getValueType() &&
           N1.getValueType() == VT && "Binary operator types must match!");
    if (VT.isVector() && VT.getVectorElementType() == MVT::i1)
      return getNode(ISD::AND, DL, VT, N1, N2);
    if (N2C && (N1.getOpcode() == ISD::VSCALE) && Flags.hasNoSignedWrap()) {
      const APInt &MulImm = N1->getConstantOperandAPInt(0);
      const APInt &N2CImm = N2C->getAPIntValue();
      return getVScale(DL, VT, MulImm * N2CImm);
    }
    break;
  case ISD::UDIV:
  case ISD::UREM:
  case ISD::MULHU:
  case ISD::MULHS:
  case ISD::SDIV:
  case ISD::SREM:
  case ISD::SADDSAT:
  case ISD::SSUBSAT:
  case ISD::UADDSAT:
  case ISD::USUBSAT:
    assert(VT.isInteger() && "This operator does not apply to FP types!");
    assert(N1.getValueType() == N2.getValueType() &&
           N1.getValueType() == VT && "Binary operator types must match!");
    if (VT.isVector() && VT.getVectorElementType() == MVT::i1) {
      // fold (add_sat x, y) -> (or x, y) for bool types.
      if (Opcode == ISD::SADDSAT || Opcode == ISD::UADDSAT)
        return getNode(ISD::OR, DL, VT, N1, N2);
      // fold (sub_sat x, y) -> (and x, ~y) for bool types.
      if (Opcode == ISD::SSUBSAT || Opcode == ISD::USUBSAT)
        return getNode(ISD::AND, DL, VT, N1, getNOT(DL, N2, VT));
    }
    break;
  case ISD::ABDS:
  case ISD::ABDU:
    assert(VT.isInteger() && "This operator does not apply to FP types!");
    assert(N1.getValueType() == N2.getValueType() &&
           N1.getValueType() == VT && "Binary operator types must match!");
    break;
  case ISD::SMIN:
  case ISD::UMAX:
    assert(VT.isInteger() && "This operator does not apply to FP types!");
    assert(N1.getValueType() == N2.getValueType() &&
           N1.getValueType() == VT && "Binary operator types must match!");
    if (VT.isVector() && VT.getVectorElementType() == MVT::i1)
      return getNode(ISD::OR, DL, VT, N1, N2);
    break;
  case ISD::SMAX:
  case ISD::UMIN:
    assert(VT.isInteger() && "This operator does not apply to FP types!");
    assert(N1.getValueType() == N2.getValueType() &&
           N1.getValueType() == VT && "Binary operator types must match!");
    if (VT.isVector() && VT.getVectorElementType() == MVT::i1)
      return getNode(ISD::AND, DL, VT, N1, N2);
    break;
  case ISD::FADD:
  case ISD::FSUB:
  case ISD::FMUL:
  case ISD::FDIV:
  case ISD::FREM:
    assert(VT.isFloatingPoint() && "This operator only applies to FP types!");
    assert(N1.getValueType() == N2.getValueType() &&
           N1.getValueType() == VT && "Binary operator types must match!");
    if (SDValue V = simplifyFPBinop(Opcode, N1, N2, Flags))
      return V;
    break;
  case ISD::FCOPYSIGN:   // N1 and result must match.  N1/N2 need not match.
    assert(N1.getValueType() == VT &&
           N1.getValueType().isFloatingPoint() &&
           N2.getValueType().isFloatingPoint() &&
           "Invalid FCOPYSIGN!");
    break;
  case ISD::SHL:
    if (N2C && (N1.getOpcode() == ISD::VSCALE) && Flags.hasNoSignedWrap()) {
      const APInt &MulImm = N1->getConstantOperandAPInt(0);
      const APInt &ShiftImm = N2C->getAPIntValue();
      return getVScale(DL, VT, MulImm << ShiftImm);
    }
    [[fallthrough]];
  case ISD::SRA:
  case ISD::SRL:
    if (SDValue V = simplifyShift(N1, N2))
      return V;
    [[fallthrough]];
  case ISD::ROTL:
  case ISD::ROTR:
    assert(VT == N1.getValueType() &&
           "Shift operators' return type must be the same as their first arg");
    assert(VT.isInteger() && N2.getValueType().isInteger() &&
           "Shifts only work on integers");
    assert((!VT.isVector() || VT == N2.getValueType()) &&
           "Vector shift amounts must have the same type as their first arg");
    // Verify that the shift amount VT is big enough to hold valid shift
    // amounts.  This catches things like trying to shift an i1024 value by an
    // i8, which is easy to fall into in generic code that uses
    // TLI.getShiftAmount().
    assert(N2.getValueType().getScalarSizeInBits() >=
               Log2_32_Ceil(VT.getScalarSizeInBits()) &&
           "Invalid use of small shift amount with oversized value!");

    // Always fold shifts of i1 values so the code generator doesn't need to
    // handle them.  Since we know the size of the shift has to be less than the
    // size of the value, the shift/rotate count is guaranteed to be zero.
    if (VT == MVT::i1)
      return N1;
    if (N2CV && N2CV->isZero())
      return N1;
    break;
  case ISD::FP_ROUND:
    assert(VT.isFloatingPoint() &&
           N1.getValueType().isFloatingPoint() &&
           VT.bitsLE(N1.getValueType()) &&
           N2C && (N2C->getZExtValue() == 0 || N2C->getZExtValue() == 1) &&
           "Invalid FP_ROUND!");
    if (N1.getValueType() == VT) return N1;  // noop conversion.
    break;
  case ISD::AssertSext:
  case ISD::AssertZext: {
    EVT EVT = cast<VTSDNode>(N2)->getVT();
    assert(VT == N1.getValueType() && "Not an inreg extend!");
    assert(VT.isInteger() && EVT.isInteger() &&
           "Cannot *_EXTEND_INREG FP types");
    assert(!EVT.isVector() &&
           "AssertSExt/AssertZExt type should be the vector element type "
           "rather than the vector type!");
    assert(EVT.bitsLE(VT.getScalarType()) && "Not extending!");
    if (VT.getScalarType() == EVT) return N1; // noop assertion.
    break;
  }
  case ISD::SIGN_EXTEND_INREG: {
    EVT EVT = cast<VTSDNode>(N2)->getVT();
    assert(VT == N1.getValueType() && "Not an inreg extend!");
    assert(VT.isInteger() && EVT.isInteger() &&
           "Cannot *_EXTEND_INREG FP types");
    assert(EVT.isVector() == VT.isVector() &&
           "SIGN_EXTEND_INREG type should be vector iff the operand "
           "type is vector!");
    assert((!EVT.isVector() ||
            EVT.getVectorElementCount() == VT.getVectorElementCount()) &&
           "Vector element counts must match in SIGN_EXTEND_INREG");
    assert(EVT.bitsLE(VT) && "Not extending!");
    if (EVT == VT) return N1;  // Not actually extending

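    // Sign-extend a constant in place by shifting the narrow value up to the
    // sign bit and arithmetic-shifting it back down. e.g. extending i32
    // 0x000000FF in-reg from i8 yields 0xFFFFFFFF.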
6329     auto SignExtendInReg = [&](APInt Val, llvm::EVT ConstantVT) {
6330       unsigned FromBits = EVT.getScalarSizeInBits();
6331       Val <<= Val.getBitWidth() - FromBits;
6332       Val.ashrInPlace(Val.getBitWidth() - FromBits);
6333       return getConstant(Val, DL, ConstantVT);
6334     };
6335 
6336     if (N1C) {
6337       const APInt &Val = N1C->getAPIntValue();
6338       return SignExtendInReg(Val, VT);
6339     }
6340 
6341     if (ISD::isBuildVectorOfConstantSDNodes(N1.getNode())) {
6342       SmallVector<SDValue, 8> Ops;
6343       llvm::EVT OpVT = N1.getOperand(0).getValueType();
6344       for (int i = 0, e = VT.getVectorNumElements(); i != e; ++i) {
6345         SDValue Op = N1.getOperand(i);
6346         if (Op.isUndef()) {
6347           Ops.push_back(getUNDEF(OpVT));
6348           continue;
6349         }
6350         ConstantSDNode *C = cast<ConstantSDNode>(Op);
6351         APInt Val = C->getAPIntValue();
6352         Ops.push_back(SignExtendInReg(Val, OpVT));
6353       }
6354       return getBuildVector(VT, DL, Ops);
6355     }
6356     break;
6357   }
6358   case ISD::FP_TO_SINT_SAT:
6359   case ISD::FP_TO_UINT_SAT: {
6360     assert(VT.isInteger() && cast<VTSDNode>(N2)->getVT().isInteger() &&
6361            N1.getValueType().isFloatingPoint() && "Invalid FP_TO_*INT_SAT");
6362     assert(N1.getValueType().isVector() == VT.isVector() &&
6363            "FP_TO_*INT_SAT type should be vector iff the operand type is "
6364            "vector!");
6365     assert((!VT.isVector() || VT.getVectorElementCount() ==
6366                                   N1.getValueType().getVectorElementCount()) &&
6367            "Vector element counts must match in FP_TO_*INT_SAT");
6368     assert(!cast<VTSDNode>(N2)->getVT().isVector() &&
6369            "Type to saturate to must be a scalar.");
6370     assert(cast<VTSDNode>(N2)->getVT().bitsLE(VT.getScalarType()) &&
6371            "Not extending!");
6372     break;
6373   }
6374   case ISD::EXTRACT_VECTOR_ELT:
6375     assert(VT.getSizeInBits() >= N1.getValueType().getScalarSizeInBits() &&
6376            "The result of EXTRACT_VECTOR_ELT must be at least as wide as the \
6377              element type of the vector.");
6378 
6379     // Extract from an undefined value or using an undefined index is undefined.
6380     if (N1.isUndef() || N2.isUndef())
6381       return getUNDEF(VT);
6382 
6383     // EXTRACT_VECTOR_ELT of out-of-bounds element is an UNDEF for fixed length
6384     // vectors. For scalable vectors we will provide appropriate support for
6385     // dealing with arbitrary indices.
6386     if (N2C && N1.getValueType().isFixedLengthVector() &&
6387         N2C->getAPIntValue().uge(N1.getValueType().getVectorNumElements()))
6388       return getUNDEF(VT);
6389 
6390     // EXTRACT_VECTOR_ELT of CONCAT_VECTORS is often formed while lowering is
6391     // expanding copies of large vectors from registers. This only works for
6392     // fixed length vectors, since we need to know the exact number of
6393     // elements.
6394     if (N2C && N1.getOperand(0).getValueType().isFixedLengthVector() &&
6395         N1.getOpcode() == ISD::CONCAT_VECTORS && N1.getNumOperands() > 0) {
6396       unsigned Factor =
6397         N1.getOperand(0).getValueType().getVectorNumElements();
6398       return getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT,
6399                      N1.getOperand(N2C->getZExtValue() / Factor),
6400                      getVectorIdxConstant(N2C->getZExtValue() % Factor, DL));
6401     }
6402 
6403     // EXTRACT_VECTOR_ELT of BUILD_VECTOR or SPLAT_VECTOR is often formed while
6404     // lowering is expanding large vector constants.
6405     if (N2C && (N1.getOpcode() == ISD::BUILD_VECTOR ||
6406                 N1.getOpcode() == ISD::SPLAT_VECTOR)) {
6407       assert((N1.getOpcode() != ISD::BUILD_VECTOR ||
6408               N1.getValueType().isFixedLengthVector()) &&
6409              "BUILD_VECTOR used for scalable vectors");
6410       unsigned Index =
6411           N1.getOpcode() == ISD::BUILD_VECTOR ? N2C->getZExtValue() : 0;
6412       SDValue Elt = N1.getOperand(Index);
6413 
6414       if (VT != Elt.getValueType())
6415         // If the vector element type is not legal, the BUILD_VECTOR operands
6416         // are promoted and implicitly truncated, and the result implicitly
6417         // extended. Make that explicit here.
6418         Elt = getAnyExtOrTrunc(Elt, DL, VT);
6419 
6420       return Elt;
6421     }
6422 
6423     // EXTRACT_VECTOR_ELT of INSERT_VECTOR_ELT is often formed when vector
6424     // operations are lowered to scalars.
6425     if (N1.getOpcode() == ISD::INSERT_VECTOR_ELT) {
6426       // If the indices are the same, return the inserted element else
6427       // if the indices are known different, extract the element from
6428       // the original vector.
6429       SDValue N1Op2 = N1.getOperand(2);
6430       ConstantSDNode *N1Op2C = dyn_cast<ConstantSDNode>(N1Op2);
6431 
6432       if (N1Op2C && N2C) {
6433         if (N1Op2C->getZExtValue() == N2C->getZExtValue()) {
6434           if (VT == N1.getOperand(1).getValueType())
6435             return N1.getOperand(1);
6436           if (VT.isFloatingPoint()) {
6437             assert(VT.getSizeInBits() > N1.getOperand(1).getValueType().getSizeInBits());
6438             return getFPExtendOrRound(N1.getOperand(1), DL, VT);
6439           }
6440           return getSExtOrTrunc(N1.getOperand(1), DL, VT);
6441         }
6442         return getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, N1.getOperand(0), N2);
6443       }
6444     }
6445 
6446     // EXTRACT_VECTOR_ELT of v1iX EXTRACT_SUBVECTOR could be formed
6447     // when vector types are scalarized and v1iX is legal.
6448     // vextract (v1iX extract_subvector(vNiX, Idx)) -> vextract(vNiX,Idx).
6449     // Here we are completely ignoring the extract element index (N2),
6450     // which is fine for fixed width vectors, since any index other than 0
6451     // is undefined anyway. However, this cannot be ignored for scalable
6452     // vectors - in theory we could support this, but we don't want to do this
6453     // without a profitability check.
6454     if (N1.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
6455         N1.getValueType().isFixedLengthVector() &&
6456         N1.getValueType().getVectorNumElements() == 1) {
6457       return getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, N1.getOperand(0),
6458                      N1.getOperand(1));
6459     }
6460     break;
6461   case ISD::EXTRACT_ELEMENT:
6462     assert(N2C && (unsigned)N2C->getZExtValue() < 2 && "Bad EXTRACT_ELEMENT!");
6463     assert(!N1.getValueType().isVector() && !VT.isVector() &&
6464            (N1.getValueType().isInteger() == VT.isInteger()) &&
6465            N1.getValueType() != VT &&
6466            "Wrong types for EXTRACT_ELEMENT!");
6467 
6468     // EXTRACT_ELEMENT of BUILD_PAIR is often formed while legalize is expanding
6469     // 64-bit integers into 32-bit parts.  Instead of building the extract of
6470     // the BUILD_PAIR, only to have legalize rip it apart, just do it now.
6471     if (N1.getOpcode() == ISD::BUILD_PAIR)
6472       return N1.getOperand(N2C->getZExtValue());
6473 
6474     // EXTRACT_ELEMENT of a constant int is also very common.
6475     if (N1C) {
6476       unsigned ElementSize = VT.getSizeInBits();
6477       unsigned Shift = ElementSize * N2C->getZExtValue();
6478       const APInt &Val = N1C->getAPIntValue();
6479       return getConstant(Val.extractBits(ElementSize, Shift), DL, VT);
6480     }
6481     break;
6482   case ISD::EXTRACT_SUBVECTOR: {
6483     EVT N1VT = N1.getValueType();
6484     assert(VT.isVector() && N1VT.isVector() &&
6485            "Extract subvector VTs must be vectors!");
6486     assert(VT.getVectorElementType() == N1VT.getVectorElementType() &&
6487            "Extract subvector VTs must have the same element type!");
6488     assert((VT.isFixedLengthVector() || N1VT.isScalableVector()) &&
6489            "Cannot extract a scalable vector from a fixed length vector!");
6490     assert((VT.isScalableVector() != N1VT.isScalableVector() ||
6491             VT.getVectorMinNumElements() <= N1VT.getVectorMinNumElements()) &&
6492            "Extract subvector must be from larger vector to smaller vector!");
6493     assert(N2C && "Extract subvector index must be a constant");
6494     assert((VT.isScalableVector() != N1VT.isScalableVector() ||
6495             (VT.getVectorMinNumElements() + N2C->getZExtValue()) <=
6496                 N1VT.getVectorMinNumElements()) &&
6497            "Extract subvector overflow!");
6498     assert(N2C->getAPIntValue().getBitWidth() ==
6499                TLI->getVectorIdxTy(getDataLayout()).getFixedSizeInBits() &&
6500            "Constant index for EXTRACT_SUBVECTOR has an invalid size");
6501 
6502     // Trivial extraction.
6503     if (VT == N1VT)
6504       return N1;
6505 
6506     // EXTRACT_SUBVECTOR of an UNDEF is an UNDEF.
6507     if (N1.isUndef())
6508       return getUNDEF(VT);
6509 
6510     // EXTRACT_SUBVECTOR of CONCAT_VECTOR can be simplified if the pieces of
6511     // the concat have the same type as the extract.
6512     if (N1.getOpcode() == ISD::CONCAT_VECTORS && N1.getNumOperands() > 0 &&
6513         VT == N1.getOperand(0).getValueType()) {
6514       unsigned Factor = VT.getVectorMinNumElements();
6515       return N1.getOperand(N2C->getZExtValue() / Factor);
6516     }
6517 
6518     // EXTRACT_SUBVECTOR of INSERT_SUBVECTOR is often created
6519     // during shuffle legalization.
6520     if (N1.getOpcode() == ISD::INSERT_SUBVECTOR && N2 == N1.getOperand(2) &&
6521         VT == N1.getOperand(1).getValueType())
6522       return N1.getOperand(1);
6523     break;
6524   }
6525   }
6526 
6527   // Perform trivial constant folding.
6528   if (SDValue SV = FoldConstantArithmetic(Opcode, DL, VT, {N1, N2}))
6529     return SV;
6530 
6531   // Canonicalize an UNDEF to the RHS, even over a constant.
6532   if (N1.isUndef()) {
6533     if (TLI->isCommutativeBinOp(Opcode)) {
6534       std::swap(N1, N2);
6535     } else {
6536       switch (Opcode) {
6537       case ISD::SUB:
6538         return getUNDEF(VT);     // fold op(undef, arg2) -> undef
6539       case ISD::SIGN_EXTEND_INREG:
6540       case ISD::UDIV:
6541       case ISD::SDIV:
6542       case ISD::UREM:
6543       case ISD::SREM:
6544       case ISD::SSUBSAT:
6545       case ISD::USUBSAT:
6546         return getConstant(0, DL, VT);    // fold op(undef, arg2) -> 0
6547       }
6548     }
6549   }
6550 
6551   // Fold a bunch of operators when the RHS is undef.
6552   if (N2.isUndef()) {
6553     switch (Opcode) {
6554     case ISD::XOR:
6555       if (N1.isUndef())
6556         // Handle undef ^ undef -> 0 special case. This is a common
6557         // idiom (misuse).
6558         return getConstant(0, DL, VT);
6559       [[fallthrough]];
6560     case ISD::ADD:
6561     case ISD::SUB:
6562     case ISD::UDIV:
6563     case ISD::SDIV:
6564     case ISD::UREM:
6565     case ISD::SREM:
6566       return getUNDEF(VT);       // fold op(arg1, undef) -> undef
6567     case ISD::MUL:
6568     case ISD::AND:
6569     case ISD::SSUBSAT:
6570     case ISD::USUBSAT:
6571       return getConstant(0, DL, VT);  // fold op(arg1, undef) -> 0
6572     case ISD::OR:
6573     case ISD::SADDSAT:
6574     case ISD::UADDSAT:
6575       return getAllOnesConstant(DL, VT);
6576     }
6577   }
6578 
6579   // Memoize this node if possible.
6580   SDNode *N;
6581   SDVTList VTs = getVTList(VT);
6582   SDValue Ops[] = {N1, N2};
6583   if (VT != MVT::Glue) {
6584     FoldingSetNodeID ID;
6585     AddNodeIDNode(ID, Opcode, VTs, Ops);
6586     void *IP = nullptr;
6587     if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) {
6588       E->intersectFlagsWith(Flags);
6589       return SDValue(E, 0);
6590     }
6591 
6592     N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
6593     N->setFlags(Flags);
6594     createOperands(N, Ops);
6595     CSEMap.InsertNode(N, IP);
6596   } else {
6597     N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
6598     createOperands(N, Ops);
6599   }
6600 
6601   InsertNode(N);
6602   SDValue V = SDValue(N, 0);
6603   NewSDValueDbgMsg(V, "Creating new node: ", this);
6604   return V;
6605 }
6606 
6607 SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
6608                               SDValue N1, SDValue N2, SDValue N3) {
6609   SDNodeFlags Flags;
6610   if (Inserter)
6611     Flags = Inserter->getFlags();
6612   return getNode(Opcode, DL, VT, N1, N2, N3, Flags);
6613 }
6614 
6615 SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
6616                               SDValue N1, SDValue N2, SDValue N3,
6617                               const SDNodeFlags Flags) {
6618   assert(N1.getOpcode() != ISD::DELETED_NODE &&
6619          N2.getOpcode() != ISD::DELETED_NODE &&
6620          N3.getOpcode() != ISD::DELETED_NODE &&
6621          "Operand is DELETED_NODE!");
6622   // Perform various simplifications.
6623   switch (Opcode) {
6624   case ISD::FMA: {
6625     assert(VT.isFloatingPoint() && "This operator only applies to FP types!");
6626     assert(N1.getValueType() == VT && N2.getValueType() == VT &&
6627            N3.getValueType() == VT && "FMA types must match!");
6628     ConstantFPSDNode *N1CFP = dyn_cast<ConstantFPSDNode>(N1);
6629     ConstantFPSDNode *N2CFP = dyn_cast<ConstantFPSDNode>(N2);
6630     ConstantFPSDNode *N3CFP = dyn_cast<ConstantFPSDNode>(N3);
6631     if (N1CFP && N2CFP && N3CFP) {
6632       APFloat  V1 = N1CFP->getValueAPF();
6633       const APFloat &V2 = N2CFP->getValueAPF();
6634       const APFloat &V3 = N3CFP->getValueAPF();
6635       V1.fusedMultiplyAdd(V2, V3, APFloat::rmNearestTiesToEven);
6636       return getConstantFP(V1, DL, VT);
6637     }
6638     break;
6639   }
6640   case ISD::BUILD_VECTOR: {
6641     // Attempt to simplify BUILD_VECTOR.
6642     SDValue Ops[] = {N1, N2, N3};
6643     if (SDValue V = FoldBUILD_VECTOR(DL, VT, Ops, *this))
6644       return V;
6645     break;
6646   }
6647   case ISD::CONCAT_VECTORS: {
6648     SDValue Ops[] = {N1, N2, N3};
6649     if (SDValue V = foldCONCAT_VECTORS(DL, VT, Ops, *this))
6650       return V;
6651     break;
6652   }
6653   case ISD::SETCC: {
6654     assert(VT.isInteger() && "SETCC result type must be an integer!");
6655     assert(N1.getValueType() == N2.getValueType() &&
6656            "SETCC operands must have the same type!");
6657     assert(VT.isVector() == N1.getValueType().isVector() &&
6658            "SETCC type should be vector iff the operand type is vector!");
6659     assert((!VT.isVector() || VT.getVectorElementCount() ==
6660                                   N1.getValueType().getVectorElementCount()) &&
6661            "SETCC vector element counts must match!");
6662     // Use FoldSetCC to simplify SETCC's.
6663     if (SDValue V = FoldSetCC(VT, N1, N2, cast<CondCodeSDNode>(N3)->get(), DL))
6664       return V;
6665     // Vector constant folding.
6666     SDValue Ops[] = {N1, N2, N3};
6667     if (SDValue V = FoldConstantArithmetic(Opcode, DL, VT, Ops)) {
6668       NewSDValueDbgMsg(V, "New node vector constant folding: ", this);
6669       return V;
6670     }
6671     break;
6672   }
6673   case ISD::SELECT:
6674   case ISD::VSELECT:
6675     if (SDValue V = simplifySelect(N1, N2, N3))
6676       return V;
6677     break;
6678   case ISD::VECTOR_SHUFFLE:
6679     llvm_unreachable("should use getVectorShuffle constructor!");
6680   case ISD::VECTOR_SPLICE: {
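    // A splice at offset zero returns the first input vector unchanged.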
6681     if (cast<ConstantSDNode>(N3)->isZero())
6682       return N1;
6683     break;
6684   }
6685   case ISD::INSERT_VECTOR_ELT: {
6686     ConstantSDNode *N3C = dyn_cast<ConstantSDNode>(N3);
6687     // INSERT_VECTOR_ELT into out-of-bounds element is an UNDEF, except
6688     // for scalable vectors where we will generate appropriate code to
6689     // deal with out-of-bounds cases correctly.
6690     if (N3C && N1.getValueType().isFixedLengthVector() &&
6691         N3C->getZExtValue() >= N1.getValueType().getVectorNumElements())
6692       return getUNDEF(VT);
6693 
6694     // Undefined index can be assumed out-of-bounds, so that's UNDEF too.
6695     if (N3.isUndef())
6696       return getUNDEF(VT);
6697 
6698     // If the inserted element is an UNDEF, just use the input vector.
6699     if (N2.isUndef())
6700       return N1;
6701 
6702     break;
6703   }
6704   case ISD::INSERT_SUBVECTOR: {
6705     // Inserting undef into undef is still undef.
6706     if (N1.isUndef() && N2.isUndef())
6707       return getUNDEF(VT);
6708 
6709     EVT N2VT = N2.getValueType();
6710     assert(VT == N1.getValueType() &&
6711            "Dest and insert subvector source types must match!");
6712     assert(VT.isVector() && N2VT.isVector() &&
6713            "Insert subvector VTs must be vectors!");
6714     assert(VT.getVectorElementType() == N2VT.getVectorElementType() &&
6715            "Insert subvector VTs must have the same element type!");
6716     assert((VT.isScalableVector() || N2VT.isFixedLengthVector()) &&
6717            "Cannot insert a scalable vector into a fixed length vector!");
6718     assert((VT.isScalableVector() != N2VT.isScalableVector() ||
6719             VT.getVectorMinNumElements() >= N2VT.getVectorMinNumElements()) &&
6720            "Insert subvector must be from smaller vector to larger vector!");
6721     assert(isa<ConstantSDNode>(N3) &&
6722            "Insert subvector index must be constant");
6723     assert((VT.isScalableVector() != N2VT.isScalableVector() ||
6724             (N2VT.getVectorMinNumElements() +
6725              cast<ConstantSDNode>(N3)->getZExtValue()) <=
6726                 VT.getVectorMinNumElements()) &&
6727            "Insert subvector overflow!");
6728     assert(cast<ConstantSDNode>(N3)->getAPIntValue().getBitWidth() ==
6729                TLI->getVectorIdxTy(getDataLayout()).getFixedSizeInBits() &&
6730            "Constant index for INSERT_SUBVECTOR has an invalid size");
6731 
6732     // Trivial insertion.
6733     if (VT == N2VT)
6734       return N2;
6735 
6736     // If this is an insert of an extracted vector into an undef vector, we
6737     // can just use the input to the extract.
6738     if (N1.isUndef() && N2.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
6739         N2.getOperand(1) == N3 && N2.getOperand(0).getValueType() == VT)
6740       return N2.getOperand(0);
6741     break;
6742   }
6743   case ISD::BITCAST:
6744     // Fold bit_convert nodes from a type to themselves.
6745     if (N1.getValueType() == VT)
6746       return N1;
6747     break;
6748   }
6749 
6750   // Memoize node if it doesn't produce a flag.
6751   SDNode *N;
6752   SDVTList VTs = getVTList(VT);
6753   SDValue Ops[] = {N1, N2, N3};
6754   if (VT != MVT::Glue) {
6755     FoldingSetNodeID ID;
6756     AddNodeIDNode(ID, Opcode, VTs, Ops);
6757     void *IP = nullptr;
6758     if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) {
6759       E->intersectFlagsWith(Flags);
6760       return SDValue(E, 0);
6761     }
6762 
6763     N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
6764     N->setFlags(Flags);
6765     createOperands(N, Ops);
6766     CSEMap.InsertNode(N, IP);
6767   } else {
6768     N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
6769     createOperands(N, Ops);
6770   }
6771 
6772   InsertNode(N);
6773   SDValue V = SDValue(N, 0);
6774   NewSDValueDbgMsg(V, "Creating new node: ", this);
6775   return V;
6776 }
6777 
6778 SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
6779                               SDValue N1, SDValue N2, SDValue N3, SDValue N4) {
6780   SDValue Ops[] = { N1, N2, N3, N4 };
6781   return getNode(Opcode, DL, VT, Ops);
6782 }
6783 
6784 SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
6785                               SDValue N1, SDValue N2, SDValue N3, SDValue N4,
6786                               SDValue N5) {
6787   SDValue Ops[] = { N1, N2, N3, N4, N5 };
6788   return getNode(Opcode, DL, VT, Ops);
6789 }
6790 
6791 /// getStackArgumentTokenFactor - Compute a TokenFactor to force all
6792 /// the incoming stack arguments to be loaded from the stack.
6793 SDValue SelectionDAG::getStackArgumentTokenFactor(SDValue Chain) {
6794   SmallVector<SDValue, 8> ArgChains;
6795 
6796   // Include the original chain at the beginning of the list. When this is
6797   // used by target LowerCall hooks, this helps legalize find the
6798   // CALLSEQ_BEGIN node.
6799   ArgChains.push_back(Chain);
6800 
6801   // Add a chain value for each stack argument.
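  // Incoming (fixed) stack arguments are assigned negative frame indices.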
6802   for (SDNode *U : getEntryNode().getNode()->uses())
6803     if (LoadSDNode *L = dyn_cast<LoadSDNode>(U))
6804       if (FrameIndexSDNode *FI = dyn_cast<FrameIndexSDNode>(L->getBasePtr()))
6805         if (FI->getIndex() < 0)
6806           ArgChains.push_back(SDValue(L, 1));
6807 
6808   // Build a tokenfactor for all the chains.
6809   return getNode(ISD::TokenFactor, SDLoc(Chain), MVT::Other, ArgChains);
6810 }
6811 
6812 /// getMemsetValue - Vectorized representation of the memset value
6813 /// operand.
6814 static SDValue getMemsetValue(SDValue Value, EVT VT, SelectionDAG &DAG,
6815                               const SDLoc &dl) {
6816   assert(!Value.isUndef());
6817 
6818   unsigned NumBits = VT.getScalarSizeInBits();
6819   if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Value)) {
6820     assert(C->getAPIntValue().getBitWidth() == 8);
6821     APInt Val = APInt::getSplat(NumBits, C->getAPIntValue());
6822     if (VT.isInteger()) {
6823       bool IsOpaque = VT.getSizeInBits() > 64 ||
6824           !DAG.getTargetLoweringInfo().isLegalStoreImmediate(C->getSExtValue());
6825       return DAG.getConstant(Val, dl, VT, false, IsOpaque);
6826     }
6827     return DAG.getConstantFP(APFloat(DAG.EVTToAPFloatSemantics(VT), Val), dl,
6828                              VT);
6829   }
6830 
6831   assert(Value.getValueType() == MVT::i8 && "memset with non-byte fill value?");
6832   EVT IntVT = VT.getScalarType();
6833   if (!IntVT.isInteger())
6834     IntVT = EVT::getIntegerVT(*DAG.getContext(), IntVT.getSizeInBits());
6835 
6836   Value = DAG.getNode(ISD::ZERO_EXTEND, dl, IntVT, Value);
6837   if (NumBits > 8) {
6838     // Use a multiplication with 0x010101... to extend the input to the
6839     // required length.
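    // For example, the i8 value 0xAB extended to i32 becomes
    // 0xAB * 0x01010101 = 0xABABABAB.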
6840     APInt Magic = APInt::getSplat(NumBits, APInt(8, 0x01));
6841     Value = DAG.getNode(ISD::MUL, dl, IntVT, Value,
6842                         DAG.getConstant(Magic, dl, IntVT));
6843   }
6844 
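  // For non-integer (FP) destinations, bitcast the integer pattern to the
  // scalar type; for vector destinations, splat it across every lane.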
6845   if (VT != Value.getValueType() && !VT.isInteger())
6846     Value = DAG.getBitcast(VT.getScalarType(), Value);
6847   if (VT != Value.getValueType())
6848     Value = DAG.getSplatBuildVector(VT, dl, Value);
6849 
6850   return Value;
6851 }
6852 
6853 /// getMemsetStringVal - Similar to getMemsetValue. Except this is only
6854 /// used when a memcpy is turned into a memset when the source is a constant
6855 /// string ptr.
6856 static SDValue getMemsetStringVal(EVT VT, const SDLoc &dl, SelectionDAG &DAG,
6857                                   const TargetLowering &TLI,
6858                                   const ConstantDataArraySlice &Slice) {
6859   // Handle vector with all elements zero.
6860   if (Slice.Array == nullptr) {
6861     if (VT.isInteger())
6862       return DAG.getConstant(0, dl, VT);
6863     if (VT == MVT::f32 || VT == MVT::f64 || VT == MVT::f128)
6864       return DAG.getConstantFP(0.0, dl, VT);
6865     if (VT.isVector()) {
6866       unsigned NumElts = VT.getVectorNumElements();
6867       MVT EltVT = (VT.getVectorElementType() == MVT::f32) ? MVT::i32 : MVT::i64;
6868       return DAG.getNode(ISD::BITCAST, dl, VT,
6869                          DAG.getConstant(0, dl,
6870                                          EVT::getVectorVT(*DAG.getContext(),
6871                                                           EltVT, NumElts)));
6872     }
6873     llvm_unreachable("Expected type!");
6874   }
6875 
6876   assert(!VT.isVector() && "Can't handle vector type here!");
6877   unsigned NumVTBits = VT.getSizeInBits();
6878   unsigned NumVTBytes = NumVTBits / 8;
6879   unsigned NumBytes = std::min(NumVTBytes, unsigned(Slice.Length));
6880 
6881   APInt Val(NumVTBits, 0);
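  // For example, the bytes "abcd" form the i32 value 0x64636261 on a
  // little-endian target and 0x61626364 on a big-endian one.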
6882   if (DAG.getDataLayout().isLittleEndian()) {
6883     for (unsigned i = 0; i != NumBytes; ++i)
6884       Val |= (uint64_t)(unsigned char)Slice[i] << i*8;
6885   } else {
6886     for (unsigned i = 0; i != NumBytes; ++i)
6887       Val |= (uint64_t)(unsigned char)Slice[i] << (NumVTBytes-i-1)*8;
6888   }
6889 
6890   // If the "cost" of materializing the integer immediate is less than the cost
6891   // of a load, then it is cost effective to turn the load into the immediate.
6892   Type *Ty = VT.getTypeForEVT(*DAG.getContext());
6893   if (TLI.shouldConvertConstantLoadToIntImm(Val, Ty))
6894     return DAG.getConstant(Val, dl, VT);
6895   return SDValue();
6896 }
6897 
6898 SDValue SelectionDAG::getMemBasePlusOffset(SDValue Base, TypeSize Offset,
6899                                            const SDLoc &DL,
6900                                            const SDNodeFlags Flags) {
6901   EVT VT = Base.getValueType();
6902   SDValue Index;
6903 
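  // A scalable offset is materialized as VSCALE * KnownMinValue; a fixed
  // offset becomes a plain constant.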
6904   if (Offset.isScalable())
6905     Index = getVScale(DL, Base.getValueType(),
6906                       APInt(Base.getValueSizeInBits().getFixedValue(),
6907                             Offset.getKnownMinValue()));
6908   else
6909     Index = getConstant(Offset.getFixedValue(), DL, VT);
6910 
6911   return getMemBasePlusOffset(Base, Index, DL, Flags);
6912 }
6913 
6914 SDValue SelectionDAG::getMemBasePlusOffset(SDValue Ptr, SDValue Offset,
6915                                            const SDLoc &DL,
6916                                            const SDNodeFlags Flags) {
6917   assert(Offset.getValueType().isInteger());
6918   EVT BasePtrVT = Ptr.getValueType();
6919   return getNode(ISD::ADD, DL, BasePtrVT, Ptr, Offset, Flags);
6920 }
6921 
6922 /// Returns true if memcpy source is constant data.
6923 static bool isMemSrcFromConstant(SDValue Src, ConstantDataArraySlice &Slice) {
6924   uint64_t SrcDelta = 0;
6925   GlobalAddressSDNode *G = nullptr;
6926   if (Src.getOpcode() == ISD::GlobalAddress)
6927     G = cast<GlobalAddressSDNode>(Src);
6928   else if (Src.getOpcode() == ISD::ADD &&
6929            Src.getOperand(0).getOpcode() == ISD::GlobalAddress &&
6930            Src.getOperand(1).getOpcode() == ISD::Constant) {
6931     G = cast<GlobalAddressSDNode>(Src.getOperand(0));
6932     SrcDelta = cast<ConstantSDNode>(Src.getOperand(1))->getZExtValue();
6933   }
6934   if (!G)
6935     return false;
6936 
6937   return getConstantDataArrayInfo(G->getGlobal(), Slice, 8,
6938                                   SrcDelta + G->getOffset());
6939 }
6940 
6941 static bool shouldLowerMemFuncForSize(const MachineFunction &MF,
6942                                       SelectionDAG &DAG) {
6943   // On Darwin, -Os means optimize for size without hurting performance, so
6944   // only really optimize for size when -Oz (MinSize) is used.
6945   if (MF.getTarget().getTargetTriple().isOSDarwin())
6946     return MF.getFunction().hasMinSize();
6947   return DAG.shouldOptForSize();
6948 }
6949 
6950 static void chainLoadsAndStoresForMemcpy(SelectionDAG &DAG, const SDLoc &dl,
6951                           SmallVector<SDValue, 32> &OutChains, unsigned From,
6952                           unsigned To, SmallVector<SDValue, 16> &OutLoadChains,
6953                           SmallVector<SDValue, 16> &OutStoreChains) {
6954   assert(!OutLoadChains.empty() && "Missing loads in memcpy inlining");
6955   assert(!OutStoreChains.empty() && "Missing stores in memcpy inlining");
6956   SmallVector<SDValue, 16> GluedLoadChains;
6957   for (unsigned i = From; i < To; ++i) {
6958     OutChains.push_back(OutLoadChains[i]);
6959     GluedLoadChains.push_back(OutLoadChains[i]);
6960   }
6961 
6962   // Chain for all loads.
6963   SDValue LoadToken = DAG.getNode(ISD::TokenFactor, dl, MVT::Other,
6964                                   GluedLoadChains);
6965 
6966   for (unsigned i = From; i < To; ++i) {
6967     StoreSDNode *ST = cast<StoreSDNode>(OutStoreChains[i]);
6968     SDValue NewStore = DAG.getTruncStore(LoadToken, dl, ST->getValue(),
6969                                   ST->getBasePtr(), ST->getMemoryVT(),
6970                                   ST->getMemOperand());
6971     OutChains.push_back(NewStore);
6972   }
6973 }
6974 
6975 static SDValue getMemcpyLoadsAndStores(SelectionDAG &DAG, const SDLoc &dl,
6976                                        SDValue Chain, SDValue Dst, SDValue Src,
6977                                        uint64_t Size, Align Alignment,
6978                                        bool isVol, bool AlwaysInline,
6979                                        MachinePointerInfo DstPtrInfo,
6980                                        MachinePointerInfo SrcPtrInfo,
6981                                        const AAMDNodes &AAInfo, AAResults *AA) {
6982   // Turn a memcpy of undef to nop.
6983   // FIXME: We need to honor volatile even if Src is undef.
6984   if (Src.isUndef())
6985     return Chain;
6986 
6987   // Expand memcpy to a series of load and store ops if the size operand falls
6988   // below a certain threshold.
6989   // TODO: In the AlwaysInline case, if the size is big then generate a loop
6990   // rather than a potentially huge number of loads and stores.
6991   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
6992   const DataLayout &DL = DAG.getDataLayout();
6993   LLVMContext &C = *DAG.getContext();
6994   std::vector<EVT> MemOps;
6995   bool DstAlignCanChange = false;
6996   MachineFunction &MF = DAG.getMachineFunction();
6997   MachineFrameInfo &MFI = MF.getFrameInfo();
6998   bool OptSize = shouldLowerMemFuncForSize(MF, DAG);
6999   FrameIndexSDNode *FI = dyn_cast<FrameIndexSDNode>(Dst);
7000   if (FI && !MFI.isFixedObjectIndex(FI->getIndex()))
7001     DstAlignCanChange = true;
7002   MaybeAlign SrcAlign = DAG.InferPtrAlign(Src);
7003   if (!SrcAlign || Alignment > *SrcAlign)
7004     SrcAlign = Alignment;
7005   assert(SrcAlign && "SrcAlign must be set");
7006   ConstantDataArraySlice Slice;
7007   // If marked as volatile, perform a copy even when marked as constant.
7008   bool CopyFromConstant = !isVol && isMemSrcFromConstant(Src, Slice);
7009   bool isZeroConstant = CopyFromConstant && Slice.Array == nullptr;
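  // When inlining is mandatory, there is no cap on the number of stores.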
7010   unsigned Limit = AlwaysInline ? ~0U : TLI.getMaxStoresPerMemcpy(OptSize);
7011   const MemOp Op = isZeroConstant
7012                        ? MemOp::Set(Size, DstAlignCanChange, Alignment,
7013                                     /*IsZeroMemset*/ true, isVol)
7014                        : MemOp::Copy(Size, DstAlignCanChange, Alignment,
7015                                      *SrcAlign, isVol, CopyFromConstant);
7016   if (!TLI.findOptimalMemOpLowering(
7017           MemOps, Limit, Op, DstPtrInfo.getAddrSpace(),
7018           SrcPtrInfo.getAddrSpace(), MF.getFunction().getAttributes()))
7019     return SDValue();
7020 
7021   if (DstAlignCanChange) {
7022     Type *Ty = MemOps[0].getTypeForEVT(C);
7023     Align NewAlign = DL.getABITypeAlign(Ty);
7024 
7025     // Don't promote to an alignment that would require dynamic stack
7026     // realignment which may conflict with optimizations such as tail call
7027     // optimization.
7028     const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo();
7029     if (!TRI->hasStackRealignment(MF))
7030       while (NewAlign > Alignment && DL.exceedsNaturalStackAlignment(NewAlign))
7031         NewAlign = NewAlign.previous();
7032 
7033     if (NewAlign > Alignment) {
7034       // Give the stack frame object a larger alignment if needed.
7035       if (MFI.getObjectAlign(FI->getIndex()) < NewAlign)
7036         MFI.setObjectAlignment(FI->getIndex(), NewAlign);
7037       Alignment = NewAlign;
7038     }
7039   }
7040 
7041   // Prepare AAInfo for loads/stores after lowering this memcpy.
7042   AAMDNodes NewAAInfo = AAInfo;
7043   NewAAInfo.TBAA = NewAAInfo.TBAAStruct = nullptr;
7044 
7045   const Value *SrcVal = SrcPtrInfo.V.dyn_cast<const Value *>();
7046   bool isConstant =
7047       AA && SrcVal &&
7048       AA->pointsToConstantMemory(MemoryLocation(SrcVal, Size, AAInfo));
7049 
7050   MachineMemOperand::Flags MMOFlags =
7051       isVol ? MachineMemOperand::MOVolatile : MachineMemOperand::MONone;
7052   SmallVector<SDValue, 16> OutLoadChains;
7053   SmallVector<SDValue, 16> OutStoreChains;
7054   SmallVector<SDValue, 32> OutChains;
7055   unsigned NumMemOps = MemOps.size();
7056   uint64_t SrcOff = 0, DstOff = 0;
7057   for (unsigned i = 0; i != NumMemOps; ++i) {
7058     EVT VT = MemOps[i];
7059     unsigned VTSize = VT.getSizeInBits() / 8;
7060     SDValue Value, Store;
7061 
7062     if (VTSize > Size) {
7063       // Issuing an unaligned load / store pair that overlaps with the
7064       // previous pair. Adjust the offset accordingly.
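      // For example, with 3 bytes remaining and a 4-byte type, both offsets
      // move back one byte so the final operation covers the last 3 bytes.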
7065       assert(i == NumMemOps-1 && i != 0);
7066       SrcOff -= VTSize - Size;
7067       DstOff -= VTSize - Size;
7068     }
7069 
7070     if (CopyFromConstant &&
7071         (isZeroConstant || (VT.isInteger() && !VT.isVector()))) {
7072       // It's unlikely a store of a vector immediate can be done in a single
7073       // instruction. It would require a load from a constant pool first.
7074       // We only handle zero vectors here.
7075       // FIXME: Handle other cases where store of vector immediate is done in
7076       // a single instruction.
7077       ConstantDataArraySlice SubSlice;
7078       if (SrcOff < Slice.Length) {
7079         SubSlice = Slice;
7080         SubSlice.move(SrcOff);
7081       } else {
7082         // This is an out-of-bounds access and hence UB. Pretend we read zero.
7083         SubSlice.Array = nullptr;
7084         SubSlice.Offset = 0;
7085         SubSlice.Length = VTSize;
7086       }
7087       Value = getMemsetStringVal(VT, dl, DAG, TLI, SubSlice);
7088       if (Value.getNode()) {
7089         Store = DAG.getStore(
7090             Chain, dl, Value,
7091             DAG.getMemBasePlusOffset(Dst, TypeSize::Fixed(DstOff), dl),
7092             DstPtrInfo.getWithOffset(DstOff), Alignment, MMOFlags, NewAAInfo);
7093         OutChains.push_back(Store);
7094       }
7095     }
7096 
7097     if (!Store.getNode()) {
7098       // The type might not be legal for the target.  This should only happen
7099       // if the type is smaller than a legal type, as on PPC, so the right
7100       // thing to do is generate a LoadExt/StoreTrunc pair.  These simplify
7101       // to Load/Store if NVT==VT.
7102       // FIXME: does the case above also need this?
7103       EVT NVT = TLI.getTypeToTransformTo(C, VT);
7104       assert(NVT.bitsGE(VT));
7105 
7106       bool isDereferenceable =
7107         SrcPtrInfo.getWithOffset(SrcOff).isDereferenceable(VTSize, C, DL);
7108       MachineMemOperand::Flags SrcMMOFlags = MMOFlags;
7109       if (isDereferenceable)
7110         SrcMMOFlags |= MachineMemOperand::MODereferenceable;
7111       if (isConstant)
7112         SrcMMOFlags |= MachineMemOperand::MOInvariant;
7113 
7114       Value = DAG.getExtLoad(
7115           ISD::EXTLOAD, dl, NVT, Chain,
7116           DAG.getMemBasePlusOffset(Src, TypeSize::Fixed(SrcOff), dl),
7117           SrcPtrInfo.getWithOffset(SrcOff), VT,
7118           commonAlignment(*SrcAlign, SrcOff), SrcMMOFlags, NewAAInfo);
7119       OutLoadChains.push_back(Value.getValue(1));
7120 
7121       Store = DAG.getTruncStore(
7122           Chain, dl, Value,
7123           DAG.getMemBasePlusOffset(Dst, TypeSize::Fixed(DstOff), dl),
7124           DstPtrInfo.getWithOffset(DstOff), VT, Alignment, MMOFlags, NewAAInfo);
7125       OutStoreChains.push_back(Store);
7126     }
7127     SrcOff += VTSize;
7128     DstOff += VTSize;
7129     Size -= VTSize;
7130   }
7131 
7132   unsigned GluedLdStLimit = MaxLdStGlue == 0 ?
7133                                 TLI.getMaxGluedStoresPerMemcpy() : MaxLdStGlue;
7134   unsigned NumLdStInMemcpy = OutStoreChains.size();
7135 
7136   if (NumLdStInMemcpy) {
7137     // The memcpy may have been converted to a memset if it copies constants.
7138     // In that case there are only stores and no loads, so there is nothing
7139     // to gang up.
7140     if ((GluedLdStLimit <= 1) || !EnableMemCpyDAGOpt) {
7141       // If the target does not care, just leave them as they are.
7142       for (unsigned i = 0; i < NumLdStInMemcpy; ++i) {
7143         OutChains.push_back(OutLoadChains[i]);
7144         OutChains.push_back(OutStoreChains[i]);
7145       }
7146     } else {
7147       // The number of ld/st pairs is less than or equal to the target limit.
7148       if (NumLdStInMemcpy <= GluedLdStLimit) {
7149         chainLoadsAndStoresForMemcpy(DAG, dl, OutChains, 0,
7150                                      NumLdStInMemcpy, OutLoadChains,
7151                                      OutStoreChains);
7152       } else {
7153         unsigned NumberLdChain = NumLdStInMemcpy / GluedLdStLimit;
7154         unsigned RemainingLdStInMemcpy = NumLdStInMemcpy % GluedLdStLimit;
7155         unsigned GlueIter = 0;
7156 
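        // For example, with 10 ld/st pairs and a glue limit of 4, the loop
        // below glues chunks [6,10) and [2,6); the residual [0,2) is handled
        // afterwards.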
7157         for (unsigned cnt = 0; cnt < NumberLdChain; ++cnt) {
7158           unsigned IndexFrom = NumLdStInMemcpy - GlueIter - GluedLdStLimit;
7159           unsigned IndexTo   = NumLdStInMemcpy - GlueIter;
7160 
7161           chainLoadsAndStoresForMemcpy(DAG, dl, OutChains, IndexFrom, IndexTo,
7162                                        OutLoadChains, OutStoreChains);
7163           GlueIter += GluedLdStLimit;
7164         }
7165 
7166         // Residual ld/st.
7167         if (RemainingLdStInMemcpy) {
7168           chainLoadsAndStoresForMemcpy(DAG, dl, OutChains, 0,
7169                                        RemainingLdStInMemcpy, OutLoadChains,
7170                                        OutStoreChains);
7171         }
7172       }
7173     }
7174   }
7175   return DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChains);
7176 }
7177 
7178 static SDValue getMemmoveLoadsAndStores(SelectionDAG &DAG, const SDLoc &dl,
7179                                         SDValue Chain, SDValue Dst, SDValue Src,
7180                                         uint64_t Size, Align Alignment,
7181                                         bool isVol, bool AlwaysInline,
7182                                         MachinePointerInfo DstPtrInfo,
7183                                         MachinePointerInfo SrcPtrInfo,
7184                                         const AAMDNodes &AAInfo) {
7185   // Turn a memmove of undef to nop.
7186   // FIXME: We need to honor volatile even if Src is undef.
7187   if (Src.isUndef())
7188     return Chain;
7189 
7190   // Expand memmove to a series of load and store ops if the size operand falls
7191   // below a certain threshold.
7192   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
7193   const DataLayout &DL = DAG.getDataLayout();
7194   LLVMContext &C = *DAG.getContext();
7195   std::vector<EVT> MemOps;
7196   bool DstAlignCanChange = false;
7197   MachineFunction &MF = DAG.getMachineFunction();
7198   MachineFrameInfo &MFI = MF.getFrameInfo();
7199   bool OptSize = shouldLowerMemFuncForSize(MF, DAG);
7200   FrameIndexSDNode *FI = dyn_cast<FrameIndexSDNode>(Dst);
7201   if (FI && !MFI.isFixedObjectIndex(FI->getIndex()))
7202     DstAlignCanChange = true;
7203   MaybeAlign SrcAlign = DAG.InferPtrAlign(Src);
7204   if (!SrcAlign || Alignment > *SrcAlign)
7205     SrcAlign = Alignment;
7206   assert(SrcAlign && "SrcAlign must be set");
7207   unsigned Limit = AlwaysInline ? ~0U : TLI.getMaxStoresPerMemmove(OptSize);
7208   if (!TLI.findOptimalMemOpLowering(
7209           MemOps, Limit,
7210           MemOp::Copy(Size, DstAlignCanChange, Alignment, *SrcAlign,
7211                       /*IsVolatile*/ true),
7212           DstPtrInfo.getAddrSpace(), SrcPtrInfo.getAddrSpace(),
7213           MF.getFunction().getAttributes()))
7214     return SDValue();
7215 
7216   if (DstAlignCanChange) {
7217     Type *Ty = MemOps[0].getTypeForEVT(C);
7218     Align NewAlign = DL.getABITypeAlign(Ty);
7219 
7220     // Don't promote to an alignment that would require dynamic stack
7221     // realignment which may conflict with optimizations such as tail call
7222     // optimization.
7223     const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo();
7224     if (!TRI->hasStackRealignment(MF))
7225       while (NewAlign > Alignment && DL.exceedsNaturalStackAlignment(NewAlign))
7226         NewAlign = NewAlign.previous();
7227 
7228     if (NewAlign > Alignment) {
7229       // Give the stack frame object a larger alignment if needed.
7230       if (MFI.getObjectAlign(FI->getIndex()) < NewAlign)
7231         MFI.setObjectAlignment(FI->getIndex(), NewAlign);
7232       Alignment = NewAlign;
7233     }
7234   }
7235 
7236   // Prepare AAInfo for loads/stores after lowering this memmove.
7237   AAMDNodes NewAAInfo = AAInfo;
7238   NewAAInfo.TBAA = NewAAInfo.TBAAStruct = nullptr;
7239 
7240   MachineMemOperand::Flags MMOFlags =
7241       isVol ? MachineMemOperand::MOVolatile : MachineMemOperand::MONone;
7242   uint64_t SrcOff = 0, DstOff = 0;
7243   SmallVector<SDValue, 8> LoadValues;
7244   SmallVector<SDValue, 8> LoadChains;
7245   SmallVector<SDValue, 8> OutChains;
7246   unsigned NumMemOps = MemOps.size();
7247   for (unsigned i = 0; i < NumMemOps; i++) {
7248     EVT VT = MemOps[i];
7249     unsigned VTSize = VT.getSizeInBits() / 8;
7250     SDValue Value;
7251 
7252     bool isDereferenceable =
7253       SrcPtrInfo.getWithOffset(SrcOff).isDereferenceable(VTSize, C, DL);
7254     MachineMemOperand::Flags SrcMMOFlags = MMOFlags;
7255     if (isDereferenceable)
7256       SrcMMOFlags |= MachineMemOperand::MODereferenceable;
7257 
7258     Value = DAG.getLoad(
7259         VT, dl, Chain,
7260         DAG.getMemBasePlusOffset(Src, TypeSize::Fixed(SrcOff), dl),
7261         SrcPtrInfo.getWithOffset(SrcOff), *SrcAlign, SrcMMOFlags, NewAAInfo);
7262     LoadValues.push_back(Value);
7263     LoadChains.push_back(Value.getValue(1));
7264     SrcOff += VTSize;
7265   }
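  // Joining all load chains here forces every load to complete before any of
  // the stores below, which keeps memmove correct when the regions overlap.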
7266   Chain = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, LoadChains);
7267   OutChains.clear();
7268   for (unsigned i = 0; i < NumMemOps; i++) {
7269     EVT VT = MemOps[i];
7270     unsigned VTSize = VT.getSizeInBits() / 8;
7271     SDValue Store;
7272 
7273     Store = DAG.getStore(
7274         Chain, dl, LoadValues[i],
7275         DAG.getMemBasePlusOffset(Dst, TypeSize::Fixed(DstOff), dl),
7276         DstPtrInfo.getWithOffset(DstOff), Alignment, MMOFlags, NewAAInfo);
7277     OutChains.push_back(Store);
7278     DstOff += VTSize;
7279   }
7280 
7281   return DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChains);
7282 }
7283 
7284 /// Lower the call to 'memset' intrinsic function into a series of store
7285 /// operations.
7286 ///
7287 /// \param DAG Selection DAG where lowered code is placed.
7288 /// \param dl Link to corresponding IR location.
7289 /// \param Chain Control flow dependency.
7290 /// \param Dst Pointer to destination memory location.
7291 /// \param Src Value of byte to write into the memory.
7292 /// \param Size Number of bytes to write.
7293 /// \param Alignment Alignment of the destination in bytes.
7294 /// \param isVol True if destination is volatile.
7295 /// \param AlwaysInline Makes sure no function call is generated.
7296 /// \param DstPtrInfo IR information on the memory pointer.
7297 /// \returns The new head of the control flow if lowering was successful, an
7298 /// empty SDValue otherwise.
7299 ///
7300 /// The function tries to replace 'llvm.memset' intrinsic with several store
7301 /// operations and value calculation code. This is usually profitable for small
7302 /// memory size or when the semantic requires inlining.
7303 static SDValue getMemsetStores(SelectionDAG &DAG, const SDLoc &dl,
7304                                SDValue Chain, SDValue Dst, SDValue Src,
7305                                uint64_t Size, Align Alignment, bool isVol,
7306                                bool AlwaysInline, MachinePointerInfo DstPtrInfo,
7307                                const AAMDNodes &AAInfo) {
7308   // Turn a memset of undef to nop.
7309   // FIXME: We need to honor volatile even if Src is undef.
7310   if (Src.isUndef())
7311     return Chain;
7312 
7313   // Expand memset to a series of store ops if the size operand
7314   // falls below a certain threshold.
7315   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
7316   std::vector<EVT> MemOps;
7317   bool DstAlignCanChange = false;
7318   MachineFunction &MF = DAG.getMachineFunction();
7319   MachineFrameInfo &MFI = MF.getFrameInfo();
7320   bool OptSize = shouldLowerMemFuncForSize(MF, DAG);
7321   FrameIndexSDNode *FI = dyn_cast<FrameIndexSDNode>(Dst);
7322   if (FI && !MFI.isFixedObjectIndex(FI->getIndex()))
7323     DstAlignCanChange = true;
7324   bool IsZeroVal =
7325       isa<ConstantSDNode>(Src) && cast<ConstantSDNode>(Src)->isZero();
7326   unsigned Limit = AlwaysInline ? ~0U : TLI.getMaxStoresPerMemset(OptSize);
7327 
7328   if (!TLI.findOptimalMemOpLowering(
7329           MemOps, Limit,
7330           MemOp::Set(Size, DstAlignCanChange, Alignment, IsZeroVal, isVol),
7331           DstPtrInfo.getAddrSpace(), ~0u, MF.getFunction().getAttributes()))
7332     return SDValue();
7333 
7334   if (DstAlignCanChange) {
7335     Type *Ty = MemOps[0].getTypeForEVT(*DAG.getContext());
7336     const DataLayout &DL = DAG.getDataLayout();
7337     Align NewAlign = DL.getABITypeAlign(Ty);
7338 
7339     // Don't promote to an alignment that would require dynamic stack
7340     // realignment which may conflict with optimizations such as tail call
7341     // optimization.
7342     const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo();
7343     if (!TRI->hasStackRealignment(MF))
7344       while (NewAlign > Alignment && DL.exceedsNaturalStackAlignment(NewAlign))
7345         NewAlign = NewAlign.previous();
7346 
7347     if (NewAlign > Alignment) {
7348       // Give the stack frame object a larger alignment if needed.
7349       if (MFI.getObjectAlign(FI->getIndex()) < NewAlign)
7350         MFI.setObjectAlignment(FI->getIndex(), NewAlign);
7351       Alignment = NewAlign;
7352     }
7353   }
7354 
7355   SmallVector<SDValue, 8> OutChains;
7356   uint64_t DstOff = 0;
7357   unsigned NumMemOps = MemOps.size();
7358 
7359   // Find the largest store and generate the bit pattern for it.
7360   EVT LargestVT = MemOps[0];
7361   for (unsigned i = 1; i < NumMemOps; i++)
7362     if (MemOps[i].bitsGT(LargestVT))
7363       LargestVT = MemOps[i];
7364   SDValue MemSetValue = getMemsetValue(Src, LargestVT, DAG, dl);
7365 
7366   // Prepare AAInfo for loads/stores after lowering this memset.
7367   AAMDNodes NewAAInfo = AAInfo;
7368   NewAAInfo.TBAA = NewAAInfo.TBAAStruct = nullptr;
7369 
7370   for (unsigned i = 0; i < NumMemOps; i++) {
7371     EVT VT = MemOps[i];
7372     unsigned VTSize = VT.getSizeInBits() / 8;
7373     if (VTSize > Size) {
7374       // Issuing an unaligned store that overlaps with the previous store.
7375       // Adjust the offset accordingly.
7376       assert(i == NumMemOps-1 && i != 0);
7377       DstOff -= VTSize - Size;
7378     }
7379 
7380     // If this store is smaller than the largest store see whether we can get
7381     // the smaller value for free with a truncate.
7382     SDValue Value = MemSetValue;
7383     if (VT.bitsLT(LargestVT)) {
7384       if (!LargestVT.isVector() && !VT.isVector() &&
7385           TLI.isTruncateFree(LargestVT, VT))
7386         Value = DAG.getNode(ISD::TRUNCATE, dl, VT, MemSetValue);
7387       else
7388         Value = getMemsetValue(Src, VT, DAG, dl);
7389     }
7390     assert(Value.getValueType() == VT && "Value with wrong type.");
7391     SDValue Store = DAG.getStore(
7392         Chain, dl, Value,
7393         DAG.getMemBasePlusOffset(Dst, TypeSize::Fixed(DstOff), dl),
7394         DstPtrInfo.getWithOffset(DstOff), Alignment,
7395         isVol ? MachineMemOperand::MOVolatile : MachineMemOperand::MONone,
7396         NewAAInfo);
7397     OutChains.push_back(Store);
7398     DstOff += VT.getSizeInBits() / 8;
7399     Size -= VTSize;
7400   }
7401 
7402   return DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChains);
7403 }
7404 
7405 static void checkAddrSpaceIsValidForLibcall(const TargetLowering *TLI,
7406                                             unsigned AS) {
7407   // Lowering memcpy / memset / memmove intrinsics to calls is only valid if all
7408   // pointer operands can be losslessly bitcasted to pointers of address space 0.
7409   if (AS != 0 && !TLI->getTargetMachine().isNoopAddrSpaceCast(AS, 0)) {
7410     report_fatal_error("cannot lower memory intrinsic in address space " +
7411                        Twine(AS));
7412   }
7413 }
7414 
7415 SDValue SelectionDAG::getMemcpy(SDValue Chain, const SDLoc &dl, SDValue Dst,
7416                                 SDValue Src, SDValue Size, Align Alignment,
7417                                 bool isVol, bool AlwaysInline, bool isTailCall,
7418                                 MachinePointerInfo DstPtrInfo,
7419                                 MachinePointerInfo SrcPtrInfo,
7420                                 const AAMDNodes &AAInfo, AAResults *AA) {
7421   // Check to see if we should lower the memcpy to loads and stores first.
7422   // For cases within the target-specified limits, this is the best choice.
7423   ConstantSDNode *ConstantSize = dyn_cast<ConstantSDNode>(Size);
7424   if (ConstantSize) {
7425     // Memcpy with size zero? Just return the original chain.
7426     if (ConstantSize->isZero())
7427       return Chain;
7428 
7429     SDValue Result = getMemcpyLoadsAndStores(
7430         *this, dl, Chain, Dst, Src, ConstantSize->getZExtValue(), Alignment,
7431         isVol, false, DstPtrInfo, SrcPtrInfo, AAInfo, AA);
7432     if (Result.getNode())
7433       return Result;
7434   }
7435 
7436   // Then check to see if we should lower the memcpy with target-specific
7437   // code. If the target chooses to do this, this is the next best.
7438   if (TSI) {
7439     SDValue Result = TSI->EmitTargetCodeForMemcpy(
7440         *this, dl, Chain, Dst, Src, Size, Alignment, isVol, AlwaysInline,
7441         DstPtrInfo, SrcPtrInfo);
7442     if (Result.getNode())
7443       return Result;
7444   }
7445 
7446   // If we really need inline code and the target declined to provide it,
7447   // use a (potentially long) sequence of loads and stores.
7448   if (AlwaysInline) {
7449     assert(ConstantSize && "AlwaysInline requires a constant size!");
7450     return getMemcpyLoadsAndStores(
7451         *this, dl, Chain, Dst, Src, ConstantSize->getZExtValue(), Alignment,
7452         isVol, true, DstPtrInfo, SrcPtrInfo, AAInfo, AA);
7453   }
7454 
7455   checkAddrSpaceIsValidForLibcall(TLI, DstPtrInfo.getAddrSpace());
7456   checkAddrSpaceIsValidForLibcall(TLI, SrcPtrInfo.getAddrSpace());
7457 
7458   // FIXME: If the memcpy is volatile (isVol), lowering it to a plain libc
7459   // memcpy is not guaranteed to be safe. libc memcpys aren't required to
7460   // respect volatile, so they may do things like read or write memory
7461   // beyond the given memory regions. But fixing this isn't easy, and most
7462   // people don't care.
7463 
7464   // Emit a library call.
7465   TargetLowering::ArgListTy Args;
7466   TargetLowering::ArgListEntry Entry;
7467   Entry.Ty = Type::getInt8PtrTy(*getContext());
7468   Entry.Node = Dst; Args.push_back(Entry);
7469   Entry.Node = Src; Args.push_back(Entry);
7470 
7471   Entry.Ty = getDataLayout().getIntPtrType(*getContext());
7472   Entry.Node = Size; Args.push_back(Entry);
7473   // FIXME: pass in SDLoc
7474   TargetLowering::CallLoweringInfo CLI(*this);
7475   CLI.setDebugLoc(dl)
7476       .setChain(Chain)
7477       .setLibCallee(TLI->getLibcallCallingConv(RTLIB::MEMCPY),
7478                     Dst.getValueType().getTypeForEVT(*getContext()),
7479                     getExternalSymbol(TLI->getLibcallName(RTLIB::MEMCPY),
7480                                       TLI->getPointerTy(getDataLayout())),
7481                     std::move(Args))
7482       .setDiscardResult()
7483       .setTailCall(isTailCall);
7484 
7485   std::pair<SDValue,SDValue> CallResult = TLI->LowerCallTo(CLI);
7486   return CallResult.second;
7487 }
7488 
7489 SDValue SelectionDAG::getAtomicMemcpy(SDValue Chain, const SDLoc &dl,
7490                                       SDValue Dst, SDValue Src, SDValue Size,
7491                                       Type *SizeTy, unsigned ElemSz,
7492                                       bool isTailCall,
7493                                       MachinePointerInfo DstPtrInfo,
7494                                       MachinePointerInfo SrcPtrInfo) {
7495   // Emit a library call.
7496   TargetLowering::ArgListTy Args;
7497   TargetLowering::ArgListEntry Entry;
7498   Entry.Ty = getDataLayout().getIntPtrType(*getContext());
7499   Entry.Node = Dst;
7500   Args.push_back(Entry);
7501 
7502   Entry.Node = Src;
7503   Args.push_back(Entry);
7504 
7505   Entry.Ty = SizeTy;
7506   Entry.Node = Size;
7507   Args.push_back(Entry);
7508 
7509   RTLIB::Libcall LibraryCall =
7510       RTLIB::getMEMCPY_ELEMENT_UNORDERED_ATOMIC(ElemSz);
7511   if (LibraryCall == RTLIB::UNKNOWN_LIBCALL)
7512     report_fatal_error("Unsupported element size");
7513 
7514   TargetLowering::CallLoweringInfo CLI(*this);
7515   CLI.setDebugLoc(dl)
7516       .setChain(Chain)
7517       .setLibCallee(TLI->getLibcallCallingConv(LibraryCall),
7518                     Type::getVoidTy(*getContext()),
7519                     getExternalSymbol(TLI->getLibcallName(LibraryCall),
7520                                       TLI->getPointerTy(getDataLayout())),
7521                     std::move(Args))
7522       .setDiscardResult()
7523       .setTailCall(isTailCall);
7524 
7525   std::pair<SDValue, SDValue> CallResult = TLI->LowerCallTo(CLI);
7526   return CallResult.second;
7527 }
7528 
7529 SDValue SelectionDAG::getMemmove(SDValue Chain, const SDLoc &dl, SDValue Dst,
7530                                  SDValue Src, SDValue Size, Align Alignment,
7531                                  bool isVol, bool isTailCall,
7532                                  MachinePointerInfo DstPtrInfo,
7533                                  MachinePointerInfo SrcPtrInfo,
7534                                  const AAMDNodes &AAInfo, AAResults *AA) {
7535   // Check to see if we should lower the memmove to loads and stores first.
7536   // For cases within the target-specified limits, this is the best choice.
7537   ConstantSDNode *ConstantSize = dyn_cast<ConstantSDNode>(Size);
7538   if (ConstantSize) {
7539     // Memmove with size zero? Just return the original chain.
7540     if (ConstantSize->isZero())
7541       return Chain;
7542 
7543     SDValue Result = getMemmoveLoadsAndStores(
7544         *this, dl, Chain, Dst, Src, ConstantSize->getZExtValue(), Alignment,
7545         isVol, false, DstPtrInfo, SrcPtrInfo, AAInfo);
7546     if (Result.getNode())
7547       return Result;
7548   }
7549 
7550   // Then check to see if we should lower the memmove with target-specific
7551   // code. If the target chooses to do this, this is the next best.
7552   if (TSI) {
7553     SDValue Result =
7554         TSI->EmitTargetCodeForMemmove(*this, dl, Chain, Dst, Src, Size,
7555                                       Alignment, isVol, DstPtrInfo, SrcPtrInfo);
7556     if (Result.getNode())
7557       return Result;
7558   }
7559 
7560   checkAddrSpaceIsValidForLibcall(TLI, DstPtrInfo.getAddrSpace());
7561   checkAddrSpaceIsValidForLibcall(TLI, SrcPtrInfo.getAddrSpace());
7562 
7563   // FIXME: If the memmove is volatile, lowering it to plain libc memmove may
7564   // not be safe.  See memcpy above for more details.
7565 
7566   // Emit a library call.
7567   TargetLowering::ArgListTy Args;
7568   TargetLowering::ArgListEntry Entry;
7569   Entry.Ty = Type::getInt8PtrTy(*getContext());
7570   Entry.Node = Dst; Args.push_back(Entry);
7571   Entry.Node = Src; Args.push_back(Entry);
7572 
7573   Entry.Ty = getDataLayout().getIntPtrType(*getContext());
7574   Entry.Node = Size; Args.push_back(Entry);
7575   // FIXME: pass in SDLoc
7576   TargetLowering::CallLoweringInfo CLI(*this);
7577   CLI.setDebugLoc(dl)
7578       .setChain(Chain)
7579       .setLibCallee(TLI->getLibcallCallingConv(RTLIB::MEMMOVE),
7580                     Dst.getValueType().getTypeForEVT(*getContext()),
7581                     getExternalSymbol(TLI->getLibcallName(RTLIB::MEMMOVE),
7582                                       TLI->getPointerTy(getDataLayout())),
7583                     std::move(Args))
7584       .setDiscardResult()
7585       .setTailCall(isTailCall);
7586 
7587   std::pair<SDValue,SDValue> CallResult = TLI->LowerCallTo(CLI);
7588   return CallResult.second;
7589 }
7590 
7591 SDValue SelectionDAG::getAtomicMemmove(SDValue Chain, const SDLoc &dl,
7592                                        SDValue Dst, SDValue Src, SDValue Size,
7593                                        Type *SizeTy, unsigned ElemSz,
7594                                        bool isTailCall,
7595                                        MachinePointerInfo DstPtrInfo,
7596                                        MachinePointerInfo SrcPtrInfo) {
7597   // Emit a library call.
7598   TargetLowering::ArgListTy Args;
7599   TargetLowering::ArgListEntry Entry;
7600   Entry.Ty = getDataLayout().getIntPtrType(*getContext());
7601   Entry.Node = Dst;
7602   Args.push_back(Entry);
7603 
7604   Entry.Node = Src;
7605   Args.push_back(Entry);
7606 
7607   Entry.Ty = SizeTy;
7608   Entry.Node = Size;
7609   Args.push_back(Entry);
7610 
7611   RTLIB::Libcall LibraryCall =
7612       RTLIB::getMEMMOVE_ELEMENT_UNORDERED_ATOMIC(ElemSz);
7613   if (LibraryCall == RTLIB::UNKNOWN_LIBCALL)
7614     report_fatal_error("Unsupported element size");
7615 
7616   TargetLowering::CallLoweringInfo CLI(*this);
7617   CLI.setDebugLoc(dl)
7618       .setChain(Chain)
7619       .setLibCallee(TLI->getLibcallCallingConv(LibraryCall),
7620                     Type::getVoidTy(*getContext()),
7621                     getExternalSymbol(TLI->getLibcallName(LibraryCall),
7622                                       TLI->getPointerTy(getDataLayout())),
7623                     std::move(Args))
7624       .setDiscardResult()
7625       .setTailCall(isTailCall);
7626 
7627   std::pair<SDValue, SDValue> CallResult = TLI->LowerCallTo(CLI);
7628   return CallResult.second;
7629 }
7630 
7631 SDValue SelectionDAG::getMemset(SDValue Chain, const SDLoc &dl, SDValue Dst,
7632                                 SDValue Src, SDValue Size, Align Alignment,
7633                                 bool isVol, bool AlwaysInline, bool isTailCall,
7634                                 MachinePointerInfo DstPtrInfo,
7635                                 const AAMDNodes &AAInfo) {
7636   // Check to see if we should lower the memset to stores first.
7637   // For cases within the target-specified limits, this is the best choice.
7638   ConstantSDNode *ConstantSize = dyn_cast<ConstantSDNode>(Size);
7639   if (ConstantSize) {
7640     // Memset with size zero? Just return the original chain.
7641     if (ConstantSize->isZero())
7642       return Chain;
7643 
7644     SDValue Result = getMemsetStores(*this, dl, Chain, Dst, Src,
7645                                      ConstantSize->getZExtValue(), Alignment,
7646                                      isVol, false, DstPtrInfo, AAInfo);
7647 
7648     if (Result.getNode())
7649       return Result;
7650   }
7651 
7652   // Then check to see if we should lower the memset with target-specific
7653   // code. If the target chooses to do this, this is the next best.
7654   if (TSI) {
7655     SDValue Result = TSI->EmitTargetCodeForMemset(
7656         *this, dl, Chain, Dst, Src, Size, Alignment, isVol, AlwaysInline,
        DstPtrInfo);
7657     if (Result.getNode())
7658       return Result;
7659   }
7660 
7661   // If we really need inline code and the target declined to provide it,
7662   // use a (potentially long) sequence of loads and stores.
7663   if (AlwaysInline) {
7664     assert(ConstantSize && "AlwaysInline requires a constant size!");
7665     SDValue Result = getMemsetStores(*this, dl, Chain, Dst, Src,
7666                                      ConstantSize->getZExtValue(), Alignment,
7667                                      isVol, true, DstPtrInfo, AAInfo);
7668     assert(Result &&
7669            "getMemsetStores must return a valid sequence when AlwaysInline");
7670     return Result;
7671   }
7672 
7673   checkAddrSpaceIsValidForLibcall(TLI, DstPtrInfo.getAddrSpace());
7674 
7675   // Emit a library call.
7676   auto &Ctx = *getContext();
7677   const auto& DL = getDataLayout();
7678 
7679   TargetLowering::CallLoweringInfo CLI(*this);
7680   // FIXME: pass in SDLoc
7681   CLI.setDebugLoc(dl).setChain(Chain);
7682 
7683   ConstantSDNode *ConstantSrc = dyn_cast<ConstantSDNode>(Src);
7684   const bool SrcIsZero = ConstantSrc && ConstantSrc->isZero();
7685   const char *BzeroName = getTargetLoweringInfo().getLibcallName(RTLIB::BZERO);
7686 
7687   // Helper function to create an Entry from Node and Type.
7688   const auto CreateEntry = [](SDValue Node, Type *Ty) {
7689     TargetLowering::ArgListEntry Entry;
7690     Entry.Node = Node;
7691     Entry.Ty = Ty;
7692     return Entry;
7693   };
7694 
7695   // If zeroing out and bzero is present, use it.
7696   if (SrcIsZero && BzeroName) {
7697     TargetLowering::ArgListTy Args;
7698     Args.push_back(CreateEntry(Dst, Type::getInt8PtrTy(Ctx)));
7699     Args.push_back(CreateEntry(Size, DL.getIntPtrType(Ctx)));
7700     CLI.setLibCallee(
7701         TLI->getLibcallCallingConv(RTLIB::BZERO), Type::getVoidTy(Ctx),
7702         getExternalSymbol(BzeroName, TLI->getPointerTy(DL)), std::move(Args));
7703   } else {
7704     TargetLowering::ArgListTy Args;
7705     Args.push_back(CreateEntry(Dst, Type::getInt8PtrTy(Ctx)));
7706     Args.push_back(CreateEntry(Src, Src.getValueType().getTypeForEVT(Ctx)));
7707     Args.push_back(CreateEntry(Size, DL.getIntPtrType(Ctx)));
7708     CLI.setLibCallee(TLI->getLibcallCallingConv(RTLIB::MEMSET),
7709                      Dst.getValueType().getTypeForEVT(Ctx),
7710                      getExternalSymbol(TLI->getLibcallName(RTLIB::MEMSET),
7711                                        TLI->getPointerTy(DL)),
7712                      std::move(Args));
7713   }
7714 
7715   CLI.setDiscardResult().setTailCall(isTailCall);
7716 
7717   std::pair<SDValue, SDValue> CallResult = TLI->LowerCallTo(CLI);
7718   return CallResult.second;
7719 }
7720 
7721 SDValue SelectionDAG::getAtomicMemset(SDValue Chain, const SDLoc &dl,
7722                                       SDValue Dst, SDValue Value, SDValue Size,
7723                                       Type *SizeTy, unsigned ElemSz,
7724                                       bool isTailCall,
7725                                       MachinePointerInfo DstPtrInfo) {
7726   // Emit a library call.
7727   TargetLowering::ArgListTy Args;
7728   TargetLowering::ArgListEntry Entry;
7729   Entry.Ty = getDataLayout().getIntPtrType(*getContext());
7730   Entry.Node = Dst;
7731   Args.push_back(Entry);
7732 
7733   Entry.Ty = Type::getInt8Ty(*getContext());
7734   Entry.Node = Value;
7735   Args.push_back(Entry);
7736 
7737   Entry.Ty = SizeTy;
7738   Entry.Node = Size;
7739   Args.push_back(Entry);
7740 
7741   RTLIB::Libcall LibraryCall =
7742       RTLIB::getMEMSET_ELEMENT_UNORDERED_ATOMIC(ElemSz);
7743   if (LibraryCall == RTLIB::UNKNOWN_LIBCALL)
7744     report_fatal_error("Unsupported element size");
7745 
7746   TargetLowering::CallLoweringInfo CLI(*this);
7747   CLI.setDebugLoc(dl)
7748       .setChain(Chain)
7749       .setLibCallee(TLI->getLibcallCallingConv(LibraryCall),
7750                     Type::getVoidTy(*getContext()),
7751                     getExternalSymbol(TLI->getLibcallName(LibraryCall),
7752                                       TLI->getPointerTy(getDataLayout())),
7753                     std::move(Args))
7754       .setDiscardResult()
7755       .setTailCall(isTailCall);
7756 
7757   std::pair<SDValue, SDValue> CallResult = TLI->LowerCallTo(CLI);
7758   return CallResult.second;
7759 }
7760 
7761 SDValue SelectionDAG::getAtomic(unsigned Opcode, const SDLoc &dl, EVT MemVT,
7762                                 SDVTList VTList, ArrayRef<SDValue> Ops,
7763                                 MachineMemOperand *MMO) {
  FoldingSetNodeID ID;
  ID.AddInteger(MemVT.getRawBits());
  AddNodeIDNode(ID, Opcode, VTList, Ops);
  ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
  ID.AddInteger(MMO->getFlags());
  void* IP = nullptr;
  if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
    cast<AtomicSDNode>(E)->refineAlignment(MMO);
    return SDValue(E, 0);
  }

  auto *N = newSDNode<AtomicSDNode>(Opcode, dl.getIROrder(), dl.getDebugLoc(),
                                    VTList, MemVT, MMO);
  createOperands(N, Ops);

  CSEMap.InsertNode(N, IP);
  InsertNode(N);
  return SDValue(N, 0);
}

SDValue SelectionDAG::getAtomicCmpSwap(unsigned Opcode, const SDLoc &dl,
                                       EVT MemVT, SDVTList VTs, SDValue Chain,
                                       SDValue Ptr, SDValue Cmp, SDValue Swp,
                                       MachineMemOperand *MMO) {
  assert(Opcode == ISD::ATOMIC_CMP_SWAP ||
         Opcode == ISD::ATOMIC_CMP_SWAP_WITH_SUCCESS);
  assert(Cmp.getValueType() == Swp.getValueType() && "Invalid Atomic Op Types");

  SDValue Ops[] = {Chain, Ptr, Cmp, Swp};
  return getAtomic(Opcode, dl, MemVT, VTs, Ops, MMO);
}

SDValue SelectionDAG::getAtomic(unsigned Opcode, const SDLoc &dl, EVT MemVT,
                                SDValue Chain, SDValue Ptr, SDValue Val,
                                MachineMemOperand *MMO) {
  assert((Opcode == ISD::ATOMIC_LOAD_ADD ||
          Opcode == ISD::ATOMIC_LOAD_SUB ||
          Opcode == ISD::ATOMIC_LOAD_AND ||
          Opcode == ISD::ATOMIC_LOAD_CLR ||
          Opcode == ISD::ATOMIC_LOAD_OR ||
          Opcode == ISD::ATOMIC_LOAD_XOR ||
          Opcode == ISD::ATOMIC_LOAD_NAND ||
          Opcode == ISD::ATOMIC_LOAD_MIN ||
          Opcode == ISD::ATOMIC_LOAD_MAX ||
          Opcode == ISD::ATOMIC_LOAD_UMIN ||
          Opcode == ISD::ATOMIC_LOAD_UMAX ||
          Opcode == ISD::ATOMIC_LOAD_FADD ||
          Opcode == ISD::ATOMIC_LOAD_FSUB ||
          Opcode == ISD::ATOMIC_LOAD_FMAX ||
          Opcode == ISD::ATOMIC_LOAD_FMIN ||
          Opcode == ISD::ATOMIC_LOAD_UINC_WRAP ||
          Opcode == ISD::ATOMIC_LOAD_UDEC_WRAP ||
          Opcode == ISD::ATOMIC_SWAP ||
          Opcode == ISD::ATOMIC_STORE) &&
         "Invalid Atomic Op");

  EVT VT = Val.getValueType();

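  // An atomic store produces only a chain; every other opcode accepted here
  // also yields a data result of the value's type.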
  SDVTList VTs = Opcode == ISD::ATOMIC_STORE ? getVTList(MVT::Other) :
                                               getVTList(VT, MVT::Other);
  SDValue Ops[] = {Chain, Ptr, Val};
  return getAtomic(Opcode, dl, MemVT, VTs, Ops, MMO);
}

SDValue SelectionDAG::getAtomic(unsigned Opcode, const SDLoc &dl, EVT MemVT,
                                EVT VT, SDValue Chain, SDValue Ptr,
                                MachineMemOperand *MMO) {
  assert(Opcode == ISD::ATOMIC_LOAD && "Invalid Atomic Op");

  SDVTList VTs = getVTList(VT, MVT::Other);
  SDValue Ops[] = {Chain, Ptr};
  return getAtomic(Opcode, dl, MemVT, VTs, Ops, MMO);
}

/// getMergeValues - Create a MERGE_VALUES node from the given operands.
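/// A single operand is returned unchanged; otherwise the node carries one
/// result per operand, typed to match it.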
SDValue SelectionDAG::getMergeValues(ArrayRef<SDValue> Ops, const SDLoc &dl) {
  if (Ops.size() == 1)
    return Ops[0];

  SmallVector<EVT, 4> VTs;
  VTs.reserve(Ops.size());
  for (const SDValue &Op : Ops)
    VTs.push_back(Op.getValueType());
  return getNode(ISD::MERGE_VALUES, dl, getVTList(VTs), Ops);
}

SDValue SelectionDAG::getMemIntrinsicNode(
    unsigned Opcode, const SDLoc &dl, SDVTList VTList, ArrayRef<SDValue> Ops,
    EVT MemVT, MachinePointerInfo PtrInfo, Align Alignment,
    MachineMemOperand::Flags Flags, uint64_t Size, const AAMDNodes &AAInfo) {
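  // Without an explicit size, fall back to the store size of the memory type.
  // Scalable vectors have no fixed store size, so record the size as unknown.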
  if (!Size && MemVT.isScalableVector())
    Size = MemoryLocation::UnknownSize;
  else if (!Size)
    Size = MemVT.getStoreSize();

  MachineFunction &MF = getMachineFunction();
  MachineMemOperand *MMO =
      MF.getMachineMemOperand(PtrInfo, Flags, Size, Alignment, AAInfo);

  return getMemIntrinsicNode(Opcode, dl, VTList, Ops, MemVT, MMO);
}

SDValue SelectionDAG::getMemIntrinsicNode(unsigned Opcode, const SDLoc &dl,
                                          SDVTList VTList,
                                          ArrayRef<SDValue> Ops, EVT MemVT,
                                          MachineMemOperand *MMO) {
  assert((Opcode == ISD::INTRINSIC_VOID ||
          Opcode == ISD::INTRINSIC_W_CHAIN ||
          Opcode == ISD::PREFETCH ||
          ((int)Opcode <= std::numeric_limits<int>::max() &&
           (int)Opcode >= ISD::FIRST_TARGET_MEMORY_OPCODE)) &&
         "Opcode is not a memory-accessing opcode!");

  // Memoize the node unless it returns a flag.
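  // Glue results are never CSE'd, since a glue value ties its producer to one
  // specific consumer; structurally identical glue nodes must stay distinct.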
  MemIntrinsicSDNode *N;
  if (VTList.VTs[VTList.NumVTs-1] != MVT::Glue) {
    FoldingSetNodeID ID;
    AddNodeIDNode(ID, Opcode, VTList, Ops);
    ID.AddInteger(getSyntheticNodeSubclassData<MemIntrinsicSDNode>(
        Opcode, dl.getIROrder(), VTList, MemVT, MMO));
    ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
    ID.AddInteger(MMO->getFlags());
    void *IP = nullptr;
    if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
      cast<MemIntrinsicSDNode>(E)->refineAlignment(MMO);
      return SDValue(E, 0);
    }

    N = newSDNode<MemIntrinsicSDNode>(Opcode, dl.getIROrder(), dl.getDebugLoc(),
                                      VTList, MemVT, MMO);
    createOperands(N, Ops);

    CSEMap.InsertNode(N, IP);
  } else {
    N = newSDNode<MemIntrinsicSDNode>(Opcode, dl.getIROrder(), dl.getDebugLoc(),
                                      VTList, MemVT, MMO);
    createOperands(N, Ops);
  }
  InsertNode(N);
  SDValue V(N, 0);
  NewSDValueDbgMsg(V, "Creating new node: ", this);
  return V;
}

SDValue SelectionDAG::getLifetimeNode(bool IsStart, const SDLoc &dl,
                                      SDValue Chain, int FrameIndex,
                                      int64_t Size, int64_t Offset) {
  const unsigned Opcode = IsStart ? ISD::LIFETIME_START : ISD::LIFETIME_END;
  const auto VTs = getVTList(MVT::Other);
  SDValue Ops[2] = {
      Chain,
      getFrameIndex(FrameIndex,
                    getTargetLoweringInfo().getFrameIndexTy(getDataLayout()),
                    true)};

  FoldingSetNodeID ID;
  AddNodeIDNode(ID, Opcode, VTs, Ops);
  ID.AddInteger(FrameIndex);
  ID.AddInteger(Size);
  ID.AddInteger(Offset);
  void *IP = nullptr;
  if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP))
    return SDValue(E, 0);

  LifetimeSDNode *N = newSDNode<LifetimeSDNode>(
      Opcode, dl.getIROrder(), dl.getDebugLoc(), VTs, Size, Offset);
  createOperands(N, Ops);
  CSEMap.InsertNode(N, IP);
  InsertNode(N);
  SDValue V(N, 0);
  NewSDValueDbgMsg(V, "Creating new node: ", this);
  return V;
}

SDValue SelectionDAG::getPseudoProbeNode(const SDLoc &Dl, SDValue Chain,
                                         uint64_t Guid, uint64_t Index,
                                         uint32_t Attr) {
  const unsigned Opcode = ISD::PSEUDO_PROBE;
  const auto VTs = getVTList(MVT::Other);
  SDValue Ops[] = {Chain};
  FoldingSetNodeID ID;
  AddNodeIDNode(ID, Opcode, VTs, Ops);
  ID.AddInteger(Guid);
  ID.AddInteger(Index);
  void *IP = nullptr;
  if (SDNode *E = FindNodeOrInsertPos(ID, Dl, IP))
    return SDValue(E, 0);

  auto *N = newSDNode<PseudoProbeSDNode>(
      Opcode, Dl.getIROrder(), Dl.getDebugLoc(), VTs, Guid, Index, Attr);
  createOperands(N, Ops);
  CSEMap.InsertNode(N, IP);
  InsertNode(N);
  SDValue V(N, 0);
  NewSDValueDbgMsg(V, "Creating new node: ", this);
  return V;
}

/// InferPointerInfo - If the specified ptr/offset is a frame index, infer a
/// MachinePointerInfo record from it.  This is particularly useful because the
/// code generator has many cases where it doesn't bother passing in a
/// MachinePointerInfo to getLoad or getStore when it has "FI+Cst".
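/// For example, a pointer of the form (add (FrameIndex 3), 8) is modeled as a
/// fixed-stack reference to frame index 3 at offset 8.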
static MachinePointerInfo InferPointerInfo(const MachinePointerInfo &Info,
                                           SelectionDAG &DAG, SDValue Ptr,
                                           int64_t Offset = 0) {
  // If this is FI+Offset, we can model it.
  if (const FrameIndexSDNode *FI = dyn_cast<FrameIndexSDNode>(Ptr))
    return MachinePointerInfo::getFixedStack(DAG.getMachineFunction(),
                                             FI->getIndex(), Offset);

  // If this is (FI+Offset1)+Offset2, we can model it.
  if (Ptr.getOpcode() != ISD::ADD ||
      !isa<ConstantSDNode>(Ptr.getOperand(1)) ||
      !isa<FrameIndexSDNode>(Ptr.getOperand(0)))
    return Info;

  int FI = cast<FrameIndexSDNode>(Ptr.getOperand(0))->getIndex();
  return MachinePointerInfo::getFixedStack(
      DAG.getMachineFunction(), FI,
      Offset + cast<ConstantSDNode>(Ptr.getOperand(1))->getSExtValue());
}

/// InferPointerInfo - If the specified ptr/offset is a frame index, infer a
/// MachinePointerInfo record from it.  This is particularly useful because the
/// code generator has many cases where it doesn't bother passing in a
/// MachinePointerInfo to getLoad or getStore when it has "FI+Cst".
static MachinePointerInfo InferPointerInfo(const MachinePointerInfo &Info,
                                           SelectionDAG &DAG, SDValue Ptr,
                                           SDValue OffsetOp) {
  // A constant offset folds into the inference; an undef offset is treated as
  // no offset. Any other offset cannot be modeled.
  if (ConstantSDNode *OffsetNode = dyn_cast<ConstantSDNode>(OffsetOp))
    return InferPointerInfo(Info, DAG, Ptr, OffsetNode->getSExtValue());
  if (OffsetOp.isUndef())
    return InferPointerInfo(Info, DAG, Ptr);
  return Info;
}

SDValue SelectionDAG::getLoad(ISD::MemIndexedMode AM, ISD::LoadExtType ExtType,
                              EVT VT, const SDLoc &dl, SDValue Chain,
                              SDValue Ptr, SDValue Offset,
                              MachinePointerInfo PtrInfo, EVT MemVT,
                              Align Alignment,
                              MachineMemOperand::Flags MMOFlags,
                              const AAMDNodes &AAInfo, const MDNode *Ranges) {
  assert(Chain.getValueType() == MVT::Other && "Invalid chain type");

  MMOFlags |= MachineMemOperand::MOLoad;
  assert((MMOFlags & MachineMemOperand::MOStore) == 0);
  // If we don't have a PtrInfo, infer the trivial frame index case to simplify
  // clients.
  if (PtrInfo.V.isNull())
    PtrInfo = InferPointerInfo(PtrInfo, *this, Ptr, Offset);

  uint64_t Size = MemoryLocation::getSizeOrUnknown(MemVT.getStoreSize());
  MachineFunction &MF = getMachineFunction();
  MachineMemOperand *MMO = MF.getMachineMemOperand(PtrInfo, MMOFlags, Size,
                                                   Alignment, AAInfo, Ranges);
  return getLoad(AM, ExtType, VT, dl, Chain, Ptr, Offset, MemVT, MMO);
}

SDValue SelectionDAG::getLoad(ISD::MemIndexedMode AM, ISD::LoadExtType ExtType,
                              EVT VT, const SDLoc &dl, SDValue Chain,
                              SDValue Ptr, SDValue Offset, EVT MemVT,
                              MachineMemOperand *MMO) {
  if (VT == MemVT) {
    ExtType = ISD::NON_EXTLOAD;
  } else if (ExtType == ISD::NON_EXTLOAD) {
    assert(VT == MemVT && "Non-extending load from different memory type!");
  } else {
    // Extending load.
    assert(MemVT.getScalarType().bitsLT(VT.getScalarType()) &&
           "Should only be an extending load, not truncating!");
    assert(VT.isInteger() == MemVT.isInteger() &&
           "Cannot convert from FP to Int or Int -> FP!");
    assert(VT.isVector() == MemVT.isVector() &&
           "Cannot use an ext load to convert to or from a vector!");
    assert((!VT.isVector() ||
            VT.getVectorElementCount() == MemVT.getVectorElementCount()) &&
           "Cannot use an ext load to change the number of vector elements!");
  }

  bool Indexed = AM != ISD::UNINDEXED;
  assert((Indexed || Offset.isUndef()) && "Unindexed load with an offset!");

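  // Pre- and post-indexed loads produce the updated base pointer as a second
  // result, in addition to the loaded value and the chain.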
  SDVTList VTs = Indexed ?
    getVTList(VT, Ptr.getValueType(), MVT::Other) : getVTList(VT, MVT::Other);
  SDValue Ops[] = { Chain, Ptr, Offset };
  FoldingSetNodeID ID;
  AddNodeIDNode(ID, ISD::LOAD, VTs, Ops);
  ID.AddInteger(MemVT.getRawBits());
  ID.AddInteger(getSyntheticNodeSubclassData<LoadSDNode>(
      dl.getIROrder(), VTs, AM, ExtType, MemVT, MMO));
  ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
  ID.AddInteger(MMO->getFlags());
  void *IP = nullptr;
  if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
    cast<LoadSDNode>(E)->refineAlignment(MMO);
    return SDValue(E, 0);
  }
  auto *N = newSDNode<LoadSDNode>(dl.getIROrder(), dl.getDebugLoc(), VTs, AM,
                                  ExtType, MemVT, MMO);
  createOperands(N, Ops);

  CSEMap.InsertNode(N, IP);
  InsertNode(N);
  SDValue V(N, 0);
  NewSDValueDbgMsg(V, "Creating new node: ", this);
  return V;
}

SDValue SelectionDAG::getLoad(EVT VT, const SDLoc &dl, SDValue Chain,
                              SDValue Ptr, MachinePointerInfo PtrInfo,
                              MaybeAlign Alignment,
                              MachineMemOperand::Flags MMOFlags,
                              const AAMDNodes &AAInfo, const MDNode *Ranges) {
  SDValue Undef = getUNDEF(Ptr.getValueType());
  return getLoad(ISD::UNINDEXED, ISD::NON_EXTLOAD, VT, dl, Chain, Ptr, Undef,
                 PtrInfo, VT, Alignment, MMOFlags, AAInfo, Ranges);
}

SDValue SelectionDAG::getLoad(EVT VT, const SDLoc &dl, SDValue Chain,
                              SDValue Ptr, MachineMemOperand *MMO) {
  SDValue Undef = getUNDEF(Ptr.getValueType());
  return getLoad(ISD::UNINDEXED, ISD::NON_EXTLOAD, VT, dl, Chain, Ptr, Undef,
                 VT, MMO);
}

SDValue SelectionDAG::getExtLoad(ISD::LoadExtType ExtType, const SDLoc &dl,
                                 EVT VT, SDValue Chain, SDValue Ptr,
                                 MachinePointerInfo PtrInfo, EVT MemVT,
                                 MaybeAlign Alignment,
                                 MachineMemOperand::Flags MMOFlags,
                                 const AAMDNodes &AAInfo) {
  SDValue Undef = getUNDEF(Ptr.getValueType());
  return getLoad(ISD::UNINDEXED, ExtType, VT, dl, Chain, Ptr, Undef, PtrInfo,
                 MemVT, Alignment, MMOFlags, AAInfo);
}

SDValue SelectionDAG::getExtLoad(ISD::LoadExtType ExtType, const SDLoc &dl,
                                 EVT VT, SDValue Chain, SDValue Ptr, EVT MemVT,
                                 MachineMemOperand *MMO) {
  SDValue Undef = getUNDEF(Ptr.getValueType());
  return getLoad(ISD::UNINDEXED, ExtType, VT, dl, Chain, Ptr, Undef,
                 MemVT, MMO);
}

SDValue SelectionDAG::getIndexedLoad(SDValue OrigLoad, const SDLoc &dl,
                                     SDValue Base, SDValue Offset,
                                     ISD::MemIndexedMode AM) {
  LoadSDNode *LD = cast<LoadSDNode>(OrigLoad);
  assert(LD->getOffset().isUndef() && "Load is already an indexed load!");
  // Don't propagate the invariant or dereferenceable flags.
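  // The indexed address can differ from the pointer those flags were
  // established for, so they are not guaranteed to hold for the new access.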
  auto MMOFlags =
      LD->getMemOperand()->getFlags() &
      ~(MachineMemOperand::MOInvariant | MachineMemOperand::MODereferenceable);
  return getLoad(AM, LD->getExtensionType(), OrigLoad.getValueType(), dl,
                 LD->getChain(), Base, Offset, LD->getPointerInfo(),
                 LD->getMemoryVT(), LD->getAlign(), MMOFlags, LD->getAAInfo());
}

SDValue SelectionDAG::getStore(SDValue Chain, const SDLoc &dl, SDValue Val,
                               SDValue Ptr, MachinePointerInfo PtrInfo,
                               Align Alignment,
                               MachineMemOperand::Flags MMOFlags,
                               const AAMDNodes &AAInfo) {
  assert(Chain.getValueType() == MVT::Other && "Invalid chain type");

  MMOFlags |= MachineMemOperand::MOStore;
  assert((MMOFlags & MachineMemOperand::MOLoad) == 0);

  if (PtrInfo.V.isNull())
    PtrInfo = InferPointerInfo(PtrInfo, *this, Ptr);

  MachineFunction &MF = getMachineFunction();
  uint64_t Size =
      MemoryLocation::getSizeOrUnknown(Val.getValueType().getStoreSize());
  MachineMemOperand *MMO =
      MF.getMachineMemOperand(PtrInfo, MMOFlags, Size, Alignment, AAInfo);
  return getStore(Chain, dl, Val, Ptr, MMO);
}

SDValue SelectionDAG::getStore(SDValue Chain, const SDLoc &dl, SDValue Val,
                               SDValue Ptr, MachineMemOperand *MMO) {
  assert(Chain.getValueType() == MVT::Other && "Invalid chain type");
  EVT VT = Val.getValueType();
  SDVTList VTs = getVTList(MVT::Other);
  SDValue Undef = getUNDEF(Ptr.getValueType());
  SDValue Ops[] = { Chain, Val, Ptr, Undef };
  FoldingSetNodeID ID;
  AddNodeIDNode(ID, ISD::STORE, VTs, Ops);
  ID.AddInteger(VT.getRawBits());
  ID.AddInteger(getSyntheticNodeSubclassData<StoreSDNode>(
      dl.getIROrder(), VTs, ISD::UNINDEXED, false, VT, MMO));
  ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
  ID.AddInteger(MMO->getFlags());
  void *IP = nullptr;
  if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
    cast<StoreSDNode>(E)->refineAlignment(MMO);
    return SDValue(E, 0);
  }
  auto *N = newSDNode<StoreSDNode>(dl.getIROrder(), dl.getDebugLoc(), VTs,
                                   ISD::UNINDEXED, false, VT, MMO);
  createOperands(N, Ops);

  CSEMap.InsertNode(N, IP);
  InsertNode(N);
  SDValue V(N, 0);
  NewSDValueDbgMsg(V, "Creating new node: ", this);
  return V;
}

SDValue SelectionDAG::getTruncStore(SDValue Chain, const SDLoc &dl, SDValue Val,
                                    SDValue Ptr, MachinePointerInfo PtrInfo,
                                    EVT SVT, Align Alignment,
                                    MachineMemOperand::Flags MMOFlags,
                                    const AAMDNodes &AAInfo) {
  assert(Chain.getValueType() == MVT::Other && "Invalid chain type");

  MMOFlags |= MachineMemOperand::MOStore;
  assert((MMOFlags & MachineMemOperand::MOLoad) == 0);

  if (PtrInfo.V.isNull())
    PtrInfo = InferPointerInfo(PtrInfo, *this, Ptr);

  MachineFunction &MF = getMachineFunction();
  MachineMemOperand *MMO = MF.getMachineMemOperand(
      PtrInfo, MMOFlags, MemoryLocation::getSizeOrUnknown(SVT.getStoreSize()),
      Alignment, AAInfo);
  return getTruncStore(Chain, dl, Val, Ptr, SVT, MMO);
}

SDValue SelectionDAG::getTruncStore(SDValue Chain, const SDLoc &dl, SDValue Val,
                                    SDValue Ptr, EVT SVT,
                                    MachineMemOperand *MMO) {
  EVT VT = Val.getValueType();

  assert(Chain.getValueType() == MVT::Other && "Invalid chain type");
  if (VT == SVT)
    return getStore(Chain, dl, Val, Ptr, MMO);

  assert(SVT.getScalarType().bitsLT(VT.getScalarType()) &&
         "Should only be a truncating store, not extending!");
  assert(VT.isInteger() == SVT.isInteger() && "Can't do FP-INT conversion!");
  assert(VT.isVector() == SVT.isVector() &&
         "Cannot use trunc store to convert to or from a vector!");
  assert((!VT.isVector() ||
          VT.getVectorElementCount() == SVT.getVectorElementCount()) &&
         "Cannot use trunc store to change the number of vector elements!");

  SDVTList VTs = getVTList(MVT::Other);
  SDValue Undef = getUNDEF(Ptr.getValueType());
  SDValue Ops[] = { Chain, Val, Ptr, Undef };
  FoldingSetNodeID ID;
  AddNodeIDNode(ID, ISD::STORE, VTs, Ops);
  ID.AddInteger(SVT.getRawBits());
  ID.AddInteger(getSyntheticNodeSubclassData<StoreSDNode>(
      dl.getIROrder(), VTs, ISD::UNINDEXED, true, SVT, MMO));
  ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
  ID.AddInteger(MMO->getFlags());
  void *IP = nullptr;
  if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
    cast<StoreSDNode>(E)->refineAlignment(MMO);
    return SDValue(E, 0);
  }
  auto *N = newSDNode<StoreSDNode>(dl.getIROrder(), dl.getDebugLoc(), VTs,
                                   ISD::UNINDEXED, true, SVT, MMO);
  createOperands(N, Ops);

  CSEMap.InsertNode(N, IP);
  InsertNode(N);
  SDValue V(N, 0);
  NewSDValueDbgMsg(V, "Creating new node: ", this);
  return V;
}

SDValue SelectionDAG::getIndexedStore(SDValue OrigStore, const SDLoc &dl,
                                      SDValue Base, SDValue Offset,
                                      ISD::MemIndexedMode AM) {
  StoreSDNode *ST = cast<StoreSDNode>(OrigStore);
  assert(ST->getOffset().isUndef() && "Store is already an indexed store!");
  SDVTList VTs = getVTList(Base.getValueType(), MVT::Other);
  SDValue Ops[] = { ST->getChain(), ST->getValue(), Base, Offset };
  FoldingSetNodeID ID;
  AddNodeIDNode(ID, ISD::STORE, VTs, Ops);
  ID.AddInteger(ST->getMemoryVT().getRawBits());
  ID.AddInteger(ST->getRawSubclassData());
  ID.AddInteger(ST->getPointerInfo().getAddrSpace());
  ID.AddInteger(ST->getMemOperand()->getFlags());
  void *IP = nullptr;
  if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP))
    return SDValue(E, 0);

  auto *N = newSDNode<StoreSDNode>(dl.getIROrder(), dl.getDebugLoc(), VTs, AM,
                                   ST->isTruncatingStore(), ST->getMemoryVT(),
                                   ST->getMemOperand());
  createOperands(N, Ops);

  CSEMap.InsertNode(N, IP);
  InsertNode(N);
  SDValue V(N, 0);
  NewSDValueDbgMsg(V, "Creating new node: ", this);
  return V;
}

SDValue SelectionDAG::getLoadVP(
    ISD::MemIndexedMode AM, ISD::LoadExtType ExtType, EVT VT, const SDLoc &dl,
    SDValue Chain, SDValue Ptr, SDValue Offset, SDValue Mask, SDValue EVL,
    MachinePointerInfo PtrInfo, EVT MemVT, Align Alignment,
    MachineMemOperand::Flags MMOFlags, const AAMDNodes &AAInfo,
    const MDNode *Ranges, bool IsExpanding) {
  assert(Chain.getValueType() == MVT::Other && "Invalid chain type");

  MMOFlags |= MachineMemOperand::MOLoad;
  assert((MMOFlags & MachineMemOperand::MOStore) == 0);
  // If we don't have a PtrInfo, infer the trivial frame index case to simplify
  // clients.
  if (PtrInfo.V.isNull())
    PtrInfo = InferPointerInfo(PtrInfo, *this, Ptr, Offset);

  uint64_t Size = MemoryLocation::getSizeOrUnknown(MemVT.getStoreSize());
  MachineFunction &MF = getMachineFunction();
  MachineMemOperand *MMO = MF.getMachineMemOperand(PtrInfo, MMOFlags, Size,
                                                   Alignment, AAInfo, Ranges);
  return getLoadVP(AM, ExtType, VT, dl, Chain, Ptr, Offset, Mask, EVL, MemVT,
                   MMO, IsExpanding);
}

SDValue SelectionDAG::getLoadVP(ISD::MemIndexedMode AM,
                                ISD::LoadExtType ExtType, EVT VT,
                                const SDLoc &dl, SDValue Chain, SDValue Ptr,
                                SDValue Offset, SDValue Mask, SDValue EVL,
                                EVT MemVT, MachineMemOperand *MMO,
                                bool IsExpanding) {
  bool Indexed = AM != ISD::UNINDEXED;
  assert((Indexed || Offset.isUndef()) && "Unindexed load with an offset!");

  SDVTList VTs = Indexed ? getVTList(VT, Ptr.getValueType(), MVT::Other)
                         : getVTList(VT, MVT::Other);
  SDValue Ops[] = {Chain, Ptr, Offset, Mask, EVL};
  FoldingSetNodeID ID;
  AddNodeIDNode(ID, ISD::VP_LOAD, VTs, Ops);
  ID.AddInteger(VT.getRawBits());
  ID.AddInteger(getSyntheticNodeSubclassData<VPLoadSDNode>(
      dl.getIROrder(), VTs, AM, ExtType, IsExpanding, MemVT, MMO));
  ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
  ID.AddInteger(MMO->getFlags());
  void *IP = nullptr;
  if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
    cast<VPLoadSDNode>(E)->refineAlignment(MMO);
    return SDValue(E, 0);
  }
  auto *N = newSDNode<VPLoadSDNode>(dl.getIROrder(), dl.getDebugLoc(), VTs, AM,
                                    ExtType, IsExpanding, MemVT, MMO);
  createOperands(N, Ops);

  CSEMap.InsertNode(N, IP);
  InsertNode(N);
  SDValue V(N, 0);
  NewSDValueDbgMsg(V, "Creating new node: ", this);
  return V;
}

SDValue SelectionDAG::getLoadVP(EVT VT, const SDLoc &dl, SDValue Chain,
                                SDValue Ptr, SDValue Mask, SDValue EVL,
                                MachinePointerInfo PtrInfo,
                                MaybeAlign Alignment,
                                MachineMemOperand::Flags MMOFlags,
                                const AAMDNodes &AAInfo, const MDNode *Ranges,
                                bool IsExpanding) {
  SDValue Undef = getUNDEF(Ptr.getValueType());
  return getLoadVP(ISD::UNINDEXED, ISD::NON_EXTLOAD, VT, dl, Chain, Ptr, Undef,
                   Mask, EVL, PtrInfo, VT, Alignment, MMOFlags, AAInfo, Ranges,
                   IsExpanding);
}

SDValue SelectionDAG::getLoadVP(EVT VT, const SDLoc &dl, SDValue Chain,
                                SDValue Ptr, SDValue Mask, SDValue EVL,
                                MachineMemOperand *MMO, bool IsExpanding) {
  SDValue Undef = getUNDEF(Ptr.getValueType());
  return getLoadVP(ISD::UNINDEXED, ISD::NON_EXTLOAD, VT, dl, Chain, Ptr, Undef,
                   Mask, EVL, VT, MMO, IsExpanding);
}

SDValue SelectionDAG::getExtLoadVP(ISD::LoadExtType ExtType, const SDLoc &dl,
                                   EVT VT, SDValue Chain, SDValue Ptr,
                                   SDValue Mask, SDValue EVL,
                                   MachinePointerInfo PtrInfo, EVT MemVT,
                                   MaybeAlign Alignment,
                                   MachineMemOperand::Flags MMOFlags,
                                   const AAMDNodes &AAInfo, bool IsExpanding) {
  SDValue Undef = getUNDEF(Ptr.getValueType());
  return getLoadVP(ISD::UNINDEXED, ExtType, VT, dl, Chain, Ptr, Undef, Mask,
                   EVL, PtrInfo, MemVT, Alignment, MMOFlags, AAInfo, nullptr,
                   IsExpanding);
}

SDValue SelectionDAG::getExtLoadVP(ISD::LoadExtType ExtType, const SDLoc &dl,
                                   EVT VT, SDValue Chain, SDValue Ptr,
                                   SDValue Mask, SDValue EVL, EVT MemVT,
                                   MachineMemOperand *MMO, bool IsExpanding) {
  SDValue Undef = getUNDEF(Ptr.getValueType());
  return getLoadVP(ISD::UNINDEXED, ExtType, VT, dl, Chain, Ptr, Undef, Mask,
                   EVL, MemVT, MMO, IsExpanding);
}

SDValue SelectionDAG::getIndexedLoadVP(SDValue OrigLoad, const SDLoc &dl,
                                       SDValue Base, SDValue Offset,
                                       ISD::MemIndexedMode AM) {
  auto *LD = cast<VPLoadSDNode>(OrigLoad);
  assert(LD->getOffset().isUndef() && "Load is already an indexed load!");
  // Don't propagate the invariant or dereferenceable flags.
  auto MMOFlags =
      LD->getMemOperand()->getFlags() &
      ~(MachineMemOperand::MOInvariant | MachineMemOperand::MODereferenceable);
  return getLoadVP(AM, LD->getExtensionType(), OrigLoad.getValueType(), dl,
                   LD->getChain(), Base, Offset, LD->getMask(),
                   LD->getVectorLength(), LD->getPointerInfo(),
                   LD->getMemoryVT(), LD->getAlign(), MMOFlags, LD->getAAInfo(),
                   nullptr, LD->isExpandingLoad());
}

SDValue SelectionDAG::getStoreVP(SDValue Chain, const SDLoc &dl, SDValue Val,
                                 SDValue Ptr, SDValue Offset, SDValue Mask,
                                 SDValue EVL, EVT MemVT, MachineMemOperand *MMO,
                                 ISD::MemIndexedMode AM, bool IsTruncating,
                                 bool IsCompressing) {
  assert(Chain.getValueType() == MVT::Other && "Invalid chain type");
  bool Indexed = AM != ISD::UNINDEXED;
  assert((Indexed || Offset.isUndef()) && "Unindexed vp_store with an offset!");
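  // Indexed stores return the updated base pointer; unindexed ones produce
  // only a chain.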
  SDVTList VTs = Indexed ? getVTList(Ptr.getValueType(), MVT::Other)
                         : getVTList(MVT::Other);
  SDValue Ops[] = {Chain, Val, Ptr, Offset, Mask, EVL};
  FoldingSetNodeID ID;
  AddNodeIDNode(ID, ISD::VP_STORE, VTs, Ops);
  ID.AddInteger(MemVT.getRawBits());
  ID.AddInteger(getSyntheticNodeSubclassData<VPStoreSDNode>(
      dl.getIROrder(), VTs, AM, IsTruncating, IsCompressing, MemVT, MMO));
  ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
  ID.AddInteger(MMO->getFlags());
  void *IP = nullptr;
  if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
    cast<VPStoreSDNode>(E)->refineAlignment(MMO);
    return SDValue(E, 0);
  }
  auto *N = newSDNode<VPStoreSDNode>(dl.getIROrder(), dl.getDebugLoc(), VTs, AM,
                                     IsTruncating, IsCompressing, MemVT, MMO);
  createOperands(N, Ops);

  CSEMap.InsertNode(N, IP);
  InsertNode(N);
  SDValue V(N, 0);
  NewSDValueDbgMsg(V, "Creating new node: ", this);
  return V;
}

SDValue SelectionDAG::getTruncStoreVP(SDValue Chain, const SDLoc &dl,
                                      SDValue Val, SDValue Ptr, SDValue Mask,
                                      SDValue EVL, MachinePointerInfo PtrInfo,
                                      EVT SVT, Align Alignment,
                                      MachineMemOperand::Flags MMOFlags,
                                      const AAMDNodes &AAInfo,
                                      bool IsCompressing) {
  assert(Chain.getValueType() == MVT::Other && "Invalid chain type");

  MMOFlags |= MachineMemOperand::MOStore;
  assert((MMOFlags & MachineMemOperand::MOLoad) == 0);

  if (PtrInfo.V.isNull())
    PtrInfo = InferPointerInfo(PtrInfo, *this, Ptr);

  MachineFunction &MF = getMachineFunction();
  MachineMemOperand *MMO = MF.getMachineMemOperand(
      PtrInfo, MMOFlags, MemoryLocation::getSizeOrUnknown(SVT.getStoreSize()),
      Alignment, AAInfo);
  return getTruncStoreVP(Chain, dl, Val, Ptr, Mask, EVL, SVT, MMO,
                         IsCompressing);
}

SDValue SelectionDAG::getTruncStoreVP(SDValue Chain, const SDLoc &dl,
                                      SDValue Val, SDValue Ptr, SDValue Mask,
                                      SDValue EVL, EVT SVT,
                                      MachineMemOperand *MMO,
                                      bool IsCompressing) {
  EVT VT = Val.getValueType();

  assert(Chain.getValueType() == MVT::Other && "Invalid chain type");
  if (VT == SVT)
    return getStoreVP(Chain, dl, Val, Ptr, getUNDEF(Ptr.getValueType()), Mask,
                      EVL, VT, MMO, ISD::UNINDEXED,
                      /*IsTruncating*/ false, IsCompressing);

  assert(SVT.getScalarType().bitsLT(VT.getScalarType()) &&
         "Should only be a truncating store, not extending!");
  assert(VT.isInteger() == SVT.isInteger() && "Can't do FP-INT conversion!");
  assert(VT.isVector() == SVT.isVector() &&
         "Cannot use trunc store to convert to or from a vector!");
  assert((!VT.isVector() ||
          VT.getVectorElementCount() == SVT.getVectorElementCount()) &&
         "Cannot use trunc store to change the number of vector elements!");

  SDVTList VTs = getVTList(MVT::Other);
  SDValue Undef = getUNDEF(Ptr.getValueType());
  SDValue Ops[] = {Chain, Val, Ptr, Undef, Mask, EVL};
  FoldingSetNodeID ID;
  AddNodeIDNode(ID, ISD::VP_STORE, VTs, Ops);
  ID.AddInteger(SVT.getRawBits());
  ID.AddInteger(getSyntheticNodeSubclassData<VPStoreSDNode>(
      dl.getIROrder(), VTs, ISD::UNINDEXED, true, IsCompressing, SVT, MMO));
  ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
  ID.AddInteger(MMO->getFlags());
  void *IP = nullptr;
  if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
    cast<VPStoreSDNode>(E)->refineAlignment(MMO);
    return SDValue(E, 0);
  }
  auto *N =
      newSDNode<VPStoreSDNode>(dl.getIROrder(), dl.getDebugLoc(), VTs,
                               ISD::UNINDEXED, true, IsCompressing, SVT, MMO);
  createOperands(N, Ops);

  CSEMap.InsertNode(N, IP);
  InsertNode(N);
  SDValue V(N, 0);
  NewSDValueDbgMsg(V, "Creating new node: ", this);
  return V;
}

SDValue SelectionDAG::getIndexedStoreVP(SDValue OrigStore, const SDLoc &dl,
                                        SDValue Base, SDValue Offset,
                                        ISD::MemIndexedMode AM) {
  auto *ST = cast<VPStoreSDNode>(OrigStore);
  assert(ST->getOffset().isUndef() && "Store is already an indexed store!");
  SDVTList VTs = getVTList(Base.getValueType(), MVT::Other);
  SDValue Ops[] = {ST->getChain(), ST->getValue(), Base,
                   Offset,         ST->getMask(),  ST->getVectorLength()};
  FoldingSetNodeID ID;
  AddNodeIDNode(ID, ISD::VP_STORE, VTs, Ops);
  ID.AddInteger(ST->getMemoryVT().getRawBits());
  ID.AddInteger(ST->getRawSubclassData());
  ID.AddInteger(ST->getPointerInfo().getAddrSpace());
  ID.AddInteger(ST->getMemOperand()->getFlags());
  void *IP = nullptr;
  if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP))
    return SDValue(E, 0);

  auto *N = newSDNode<VPStoreSDNode>(
      dl.getIROrder(), dl.getDebugLoc(), VTs, AM, ST->isTruncatingStore(),
      ST->isCompressingStore(), ST->getMemoryVT(), ST->getMemOperand());
  createOperands(N, Ops);

  CSEMap.InsertNode(N, IP);
  InsertNode(N);
  SDValue V(N, 0);
  NewSDValueDbgMsg(V, "Creating new node: ", this);
  return V;
}

SDValue SelectionDAG::getStridedLoadVP(
    ISD::MemIndexedMode AM, ISD::LoadExtType ExtType, EVT VT, const SDLoc &DL,
    SDValue Chain, SDValue Ptr, SDValue Offset, SDValue Stride, SDValue Mask,
    SDValue EVL, MachinePointerInfo PtrInfo, EVT MemVT, Align Alignment,
    MachineMemOperand::Flags MMOFlags, const AAMDNodes &AAInfo,
    const MDNode *Ranges, bool IsExpanding) {
  assert(Chain.getValueType() == MVT::Other && "Invalid chain type");

  MMOFlags |= MachineMemOperand::MOLoad;
  assert((MMOFlags & MachineMemOperand::MOStore) == 0);
  // If we don't have a PtrInfo, infer the trivial frame index case to simplify
  // clients.
  if (PtrInfo.V.isNull())
    PtrInfo = InferPointerInfo(PtrInfo, *this, Ptr, Offset);

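  // The region touched by a strided access depends on the runtime stride and
  // vector length, so its size is unknown at compile time.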
8541   uint64_t Size = MemoryLocation::UnknownSize;
8542   MachineFunction &MF = getMachineFunction();
8543   MachineMemOperand *MMO = MF.getMachineMemOperand(PtrInfo, MMOFlags, Size,
8544                                                    Alignment, AAInfo, Ranges);
8545   return getStridedLoadVP(AM, ExtType, VT, DL, Chain, Ptr, Offset, Stride, Mask,
8546                           EVL, MemVT, MMO, IsExpanding);
8547 }
8548 
getStridedLoadVP(ISD::MemIndexedMode AM,ISD::LoadExtType ExtType,EVT VT,const SDLoc & DL,SDValue Chain,SDValue Ptr,SDValue Offset,SDValue Stride,SDValue Mask,SDValue EVL,EVT MemVT,MachineMemOperand * MMO,bool IsExpanding)8549 SDValue SelectionDAG::getStridedLoadVP(
8550     ISD::MemIndexedMode AM, ISD::LoadExtType ExtType, EVT VT, const SDLoc &DL,
8551     SDValue Chain, SDValue Ptr, SDValue Offset, SDValue Stride, SDValue Mask,
8552     SDValue EVL, EVT MemVT, MachineMemOperand *MMO, bool IsExpanding) {
8553   bool Indexed = AM != ISD::UNINDEXED;
8554   assert((Indexed || Offset.isUndef()) && "Unindexed load with an offset!");
8555 
8556   SDValue Ops[] = {Chain, Ptr, Offset, Stride, Mask, EVL};
8557   SDVTList VTs = Indexed ? getVTList(VT, Ptr.getValueType(), MVT::Other)
8558                          : getVTList(VT, MVT::Other);
8559   FoldingSetNodeID ID;
8560   AddNodeIDNode(ID, ISD::EXPERIMENTAL_VP_STRIDED_LOAD, VTs, Ops);
8561   ID.AddInteger(VT.getRawBits());
8562   ID.AddInteger(getSyntheticNodeSubclassData<VPStridedLoadSDNode>(
8563       DL.getIROrder(), VTs, AM, ExtType, IsExpanding, MemVT, MMO));
8564   ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
8565 
8566   void *IP = nullptr;
8567   if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) {
8568     cast<VPStridedLoadSDNode>(E)->refineAlignment(MMO);
8569     return SDValue(E, 0);
8570   }
8571 
8572   auto *N =
8573       newSDNode<VPStridedLoadSDNode>(DL.getIROrder(), DL.getDebugLoc(), VTs, AM,
8574                                      ExtType, IsExpanding, MemVT, MMO);
8575   createOperands(N, Ops);
8576   CSEMap.InsertNode(N, IP);
8577   InsertNode(N);
8578   SDValue V(N, 0);
8579   NewSDValueDbgMsg(V, "Creating new node: ", this);
8580   return V;
8581 }
8582 
getStridedLoadVP(EVT VT,const SDLoc & DL,SDValue Chain,SDValue Ptr,SDValue Stride,SDValue Mask,SDValue EVL,MachinePointerInfo PtrInfo,MaybeAlign Alignment,MachineMemOperand::Flags MMOFlags,const AAMDNodes & AAInfo,const MDNode * Ranges,bool IsExpanding)8583 SDValue SelectionDAG::getStridedLoadVP(
8584     EVT VT, const SDLoc &DL, SDValue Chain, SDValue Ptr, SDValue Stride,
8585     SDValue Mask, SDValue EVL, MachinePointerInfo PtrInfo, MaybeAlign Alignment,
8586     MachineMemOperand::Flags MMOFlags, const AAMDNodes &AAInfo,
8587     const MDNode *Ranges, bool IsExpanding) {
8588   SDValue Undef = getUNDEF(Ptr.getValueType());
8589   return getStridedLoadVP(ISD::UNINDEXED, ISD::NON_EXTLOAD, VT, DL, Chain, Ptr,
8590                           Undef, Stride, Mask, EVL, PtrInfo, VT, Alignment,
8591                           MMOFlags, AAInfo, Ranges, IsExpanding);
8592 }
8593 
getStridedLoadVP(EVT VT,const SDLoc & DL,SDValue Chain,SDValue Ptr,SDValue Stride,SDValue Mask,SDValue EVL,MachineMemOperand * MMO,bool IsExpanding)8594 SDValue SelectionDAG::getStridedLoadVP(EVT VT, const SDLoc &DL, SDValue Chain,
8595                                        SDValue Ptr, SDValue Stride,
8596                                        SDValue Mask, SDValue EVL,
8597                                        MachineMemOperand *MMO,
8598                                        bool IsExpanding) {
8599   SDValue Undef = getUNDEF(Ptr.getValueType());
8600   return getStridedLoadVP(ISD::UNINDEXED, ISD::NON_EXTLOAD, VT, DL, Chain, Ptr,
8601                           Undef, Stride, Mask, EVL, VT, MMO, IsExpanding);
8602 }
8603 
getExtStridedLoadVP(ISD::LoadExtType ExtType,const SDLoc & DL,EVT VT,SDValue Chain,SDValue Ptr,SDValue Stride,SDValue Mask,SDValue EVL,MachinePointerInfo PtrInfo,EVT MemVT,MaybeAlign Alignment,MachineMemOperand::Flags MMOFlags,const AAMDNodes & AAInfo,bool IsExpanding)8604 SDValue SelectionDAG::getExtStridedLoadVP(
8605     ISD::LoadExtType ExtType, const SDLoc &DL, EVT VT, SDValue Chain,
8606     SDValue Ptr, SDValue Stride, SDValue Mask, SDValue EVL,
8607     MachinePointerInfo PtrInfo, EVT MemVT, MaybeAlign Alignment,
8608     MachineMemOperand::Flags MMOFlags, const AAMDNodes &AAInfo,
8609     bool IsExpanding) {
8610   SDValue Undef = getUNDEF(Ptr.getValueType());
8611   return getStridedLoadVP(ISD::UNINDEXED, ExtType, VT, DL, Chain, Ptr, Undef,
8612                           Stride, Mask, EVL, PtrInfo, MemVT, Alignment,
8613                           MMOFlags, AAInfo, nullptr, IsExpanding);
8614 }
8615 
getExtStridedLoadVP(ISD::LoadExtType ExtType,const SDLoc & DL,EVT VT,SDValue Chain,SDValue Ptr,SDValue Stride,SDValue Mask,SDValue EVL,EVT MemVT,MachineMemOperand * MMO,bool IsExpanding)8616 SDValue SelectionDAG::getExtStridedLoadVP(
8617     ISD::LoadExtType ExtType, const SDLoc &DL, EVT VT, SDValue Chain,
8618     SDValue Ptr, SDValue Stride, SDValue Mask, SDValue EVL, EVT MemVT,
8619     MachineMemOperand *MMO, bool IsExpanding) {
8620   SDValue Undef = getUNDEF(Ptr.getValueType());
8621   return getStridedLoadVP(ISD::UNINDEXED, ExtType, VT, DL, Chain, Ptr, Undef,
8622                           Stride, Mask, EVL, MemVT, MMO, IsExpanding);
8623 }
8624 
getIndexedStridedLoadVP(SDValue OrigLoad,const SDLoc & DL,SDValue Base,SDValue Offset,ISD::MemIndexedMode AM)8625 SDValue SelectionDAG::getIndexedStridedLoadVP(SDValue OrigLoad, const SDLoc &DL,
8626                                               SDValue Base, SDValue Offset,
8627                                               ISD::MemIndexedMode AM) {
8628   auto *SLD = cast<VPStridedLoadSDNode>(OrigLoad);
8629   assert(SLD->getOffset().isUndef() &&
8630          "Strided load is already a indexed load!");
8631   // Don't propagate the invariant or dereferenceable flags.
8632   auto MMOFlags =
8633       SLD->getMemOperand()->getFlags() &
8634       ~(MachineMemOperand::MOInvariant | MachineMemOperand::MODereferenceable);
8635   return getStridedLoadVP(
8636       AM, SLD->getExtensionType(), OrigLoad.getValueType(), DL, SLD->getChain(),
8637       Base, Offset, SLD->getStride(), SLD->getMask(), SLD->getVectorLength(),
8638       SLD->getPointerInfo(), SLD->getMemoryVT(), SLD->getAlign(), MMOFlags,
8639       SLD->getAAInfo(), nullptr, SLD->isExpandingLoad());
8640 }
8641 
getStridedStoreVP(SDValue Chain,const SDLoc & DL,SDValue Val,SDValue Ptr,SDValue Offset,SDValue Stride,SDValue Mask,SDValue EVL,EVT MemVT,MachineMemOperand * MMO,ISD::MemIndexedMode AM,bool IsTruncating,bool IsCompressing)8642 SDValue SelectionDAG::getStridedStoreVP(SDValue Chain, const SDLoc &DL,
8643                                         SDValue Val, SDValue Ptr,
8644                                         SDValue Offset, SDValue Stride,
8645                                         SDValue Mask, SDValue EVL, EVT MemVT,
8646                                         MachineMemOperand *MMO,
8647                                         ISD::MemIndexedMode AM,
8648                                         bool IsTruncating, bool IsCompressing) {
8649   assert(Chain.getValueType() == MVT::Other && "Invalid chain type");
8650   bool Indexed = AM != ISD::UNINDEXED;
8651   assert((Indexed || Offset.isUndef()) && "Unindexed vp_store with an offset!");
8652   SDVTList VTs = Indexed ? getVTList(Ptr.getValueType(), MVT::Other)
8653                          : getVTList(MVT::Other);
8654   SDValue Ops[] = {Chain, Val, Ptr, Offset, Stride, Mask, EVL};
8655   FoldingSetNodeID ID;
8656   AddNodeIDNode(ID, ISD::EXPERIMENTAL_VP_STRIDED_STORE, VTs, Ops);
8657   ID.AddInteger(MemVT.getRawBits());
8658   ID.AddInteger(getSyntheticNodeSubclassData<VPStridedStoreSDNode>(
8659       DL.getIROrder(), VTs, AM, IsTruncating, IsCompressing, MemVT, MMO));
8660   ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
8661   void *IP = nullptr;
8662   if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) {
8663     cast<VPStridedStoreSDNode>(E)->refineAlignment(MMO);
8664     return SDValue(E, 0);
8665   }
8666   auto *N = newSDNode<VPStridedStoreSDNode>(DL.getIROrder(), DL.getDebugLoc(),
8667                                             VTs, AM, IsTruncating,
8668                                             IsCompressing, MemVT, MMO);
8669   createOperands(N, Ops);
8670 
8671   CSEMap.InsertNode(N, IP);
8672   InsertNode(N);
8673   SDValue V(N, 0);
8674   NewSDValueDbgMsg(V, "Creating new node: ", this);
8675   return V;
8676 }
8677 
getTruncStridedStoreVP(SDValue Chain,const SDLoc & DL,SDValue Val,SDValue Ptr,SDValue Stride,SDValue Mask,SDValue EVL,MachinePointerInfo PtrInfo,EVT SVT,Align Alignment,MachineMemOperand::Flags MMOFlags,const AAMDNodes & AAInfo,bool IsCompressing)8678 SDValue SelectionDAG::getTruncStridedStoreVP(
8679     SDValue Chain, const SDLoc &DL, SDValue Val, SDValue Ptr, SDValue Stride,
8680     SDValue Mask, SDValue EVL, MachinePointerInfo PtrInfo, EVT SVT,
8681     Align Alignment, MachineMemOperand::Flags MMOFlags, const AAMDNodes &AAInfo,
8682     bool IsCompressing) {
8683   assert(Chain.getValueType() == MVT::Other && "Invalid chain type");
8684 
8685   MMOFlags |= MachineMemOperand::MOStore;
8686   assert((MMOFlags & MachineMemOperand::MOLoad) == 0);
8687 
8688   if (PtrInfo.V.isNull())
8689     PtrInfo = InferPointerInfo(PtrInfo, *this, Ptr);
8690 
8691   MachineFunction &MF = getMachineFunction();
8692   MachineMemOperand *MMO = MF.getMachineMemOperand(
8693       PtrInfo, MMOFlags, MemoryLocation::UnknownSize, Alignment, AAInfo);
8694   return getTruncStridedStoreVP(Chain, DL, Val, Ptr, Stride, Mask, EVL, SVT,
8695                                 MMO, IsCompressing);
8696 }
8697 
getTruncStridedStoreVP(SDValue Chain,const SDLoc & DL,SDValue Val,SDValue Ptr,SDValue Stride,SDValue Mask,SDValue EVL,EVT SVT,MachineMemOperand * MMO,bool IsCompressing)8698 SDValue SelectionDAG::getTruncStridedStoreVP(SDValue Chain, const SDLoc &DL,
8699                                              SDValue Val, SDValue Ptr,
8700                                              SDValue Stride, SDValue Mask,
8701                                              SDValue EVL, EVT SVT,
8702                                              MachineMemOperand *MMO,
8703                                              bool IsCompressing) {
8704   EVT VT = Val.getValueType();
8705 
8706   assert(Chain.getValueType() == MVT::Other && "Invalid chain type");
8707   if (VT == SVT)
8708     return getStridedStoreVP(Chain, DL, Val, Ptr, getUNDEF(Ptr.getValueType()),
8709                              Stride, Mask, EVL, VT, MMO, ISD::UNINDEXED,
8710                              /*IsTruncating*/ false, IsCompressing);
8711 
8712   assert(SVT.getScalarType().bitsLT(VT.getScalarType()) &&
8713          "Should only be a truncating store, not extending!");
8714   assert(VT.isInteger() == SVT.isInteger() && "Can't do FP-INT conversion!");
8715   assert(VT.isVector() == SVT.isVector() &&
8716          "Cannot use trunc store to convert to or from a vector!");
8717   assert((!VT.isVector() ||
8718           VT.getVectorElementCount() == SVT.getVectorElementCount()) &&
8719          "Cannot use trunc store to change the number of vector elements!");
8720 
8721   SDVTList VTs = getVTList(MVT::Other);
8722   SDValue Undef = getUNDEF(Ptr.getValueType());
8723   SDValue Ops[] = {Chain, Val, Ptr, Undef, Stride, Mask, EVL};
8724   FoldingSetNodeID ID;
8725   AddNodeIDNode(ID, ISD::EXPERIMENTAL_VP_STRIDED_STORE, VTs, Ops);
8726   ID.AddInteger(SVT.getRawBits());
8727   ID.AddInteger(getSyntheticNodeSubclassData<VPStridedStoreSDNode>(
8728       DL.getIROrder(), VTs, ISD::UNINDEXED, true, IsCompressing, SVT, MMO));
8729   ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
8730   void *IP = nullptr;
8731   if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) {
8732     cast<VPStridedStoreSDNode>(E)->refineAlignment(MMO);
8733     return SDValue(E, 0);
8734   }
8735   auto *N = newSDNode<VPStridedStoreSDNode>(DL.getIROrder(), DL.getDebugLoc(),
8736                                             VTs, ISD::UNINDEXED, true,
8737                                             IsCompressing, SVT, MMO);
8738   createOperands(N, Ops);
8739 
8740   CSEMap.InsertNode(N, IP);
8741   InsertNode(N);
8742   SDValue V(N, 0);
8743   NewSDValueDbgMsg(V, "Creating new node: ", this);
8744   return V;
8745 }
8746 
getIndexedStridedStoreVP(SDValue OrigStore,const SDLoc & DL,SDValue Base,SDValue Offset,ISD::MemIndexedMode AM)8747 SDValue SelectionDAG::getIndexedStridedStoreVP(SDValue OrigStore,
8748                                                const SDLoc &DL, SDValue Base,
8749                                                SDValue Offset,
8750                                                ISD::MemIndexedMode AM) {
8751   auto *SST = cast<VPStridedStoreSDNode>(OrigStore);
8752   assert(SST->getOffset().isUndef() &&
8753          "Strided store is already an indexed store!");
8754   SDVTList VTs = getVTList(Base.getValueType(), MVT::Other);
8755   SDValue Ops[] = {
8756       SST->getChain(), SST->getValue(),       Base, Offset, SST->getStride(),
8757       SST->getMask(),  SST->getVectorLength()};
8758   FoldingSetNodeID ID;
8759   AddNodeIDNode(ID, ISD::EXPERIMENTAL_VP_STRIDED_STORE, VTs, Ops);
8760   ID.AddInteger(SST->getMemoryVT().getRawBits());
8761   ID.AddInteger(SST->getRawSubclassData());
8762   ID.AddInteger(SST->getPointerInfo().getAddrSpace());
8763   void *IP = nullptr;
8764   if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP))
8765     return SDValue(E, 0);
8766 
8767   auto *N = newSDNode<VPStridedStoreSDNode>(
8768       DL.getIROrder(), DL.getDebugLoc(), VTs, AM, SST->isTruncatingStore(),
8769       SST->isCompressingStore(), SST->getMemoryVT(), SST->getMemOperand());
8770   createOperands(N, Ops);
8771 
8772   CSEMap.InsertNode(N, IP);
8773   InsertNode(N);
8774   SDValue V(N, 0);
8775   NewSDValueDbgMsg(V, "Creating new node: ", this);
8776   return V;
8777 }
8778 
SDValue SelectionDAG::getGatherVP(SDVTList VTs, EVT VT, const SDLoc &dl,
                                  ArrayRef<SDValue> Ops, MachineMemOperand *MMO,
                                  ISD::MemIndexType IndexType) {
  assert(Ops.size() == 6 && "Incompatible number of operands");

  FoldingSetNodeID ID;
  AddNodeIDNode(ID, ISD::VP_GATHER, VTs, Ops);
  ID.AddInteger(VT.getRawBits());
  ID.AddInteger(getSyntheticNodeSubclassData<VPGatherSDNode>(
      dl.getIROrder(), VTs, VT, MMO, IndexType));
  ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
  ID.AddInteger(MMO->getFlags());
  void *IP = nullptr;
  if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
    cast<VPGatherSDNode>(E)->refineAlignment(MMO);
    return SDValue(E, 0);
  }

  auto *N = newSDNode<VPGatherSDNode>(dl.getIROrder(), dl.getDebugLoc(), VTs,
                                      VT, MMO, IndexType);
  createOperands(N, Ops);

  assert(N->getMask().getValueType().getVectorElementCount() ==
             N->getValueType(0).getVectorElementCount() &&
         "Vector width mismatch between mask and data");
  assert(N->getIndex().getValueType().getVectorElementCount().isScalable() ==
             N->getValueType(0).getVectorElementCount().isScalable() &&
         "Scalable flags of index and data do not match");
  assert(ElementCount::isKnownGE(
             N->getIndex().getValueType().getVectorElementCount(),
             N->getValueType(0).getVectorElementCount()) &&
         "Vector width mismatch between index and data");
  assert(isa<ConstantSDNode>(N->getScale()) &&
         cast<ConstantSDNode>(N->getScale())->getAPIntValue().isPowerOf2() &&
         "Scale should be a constant power of 2");

  CSEMap.InsertNode(N, IP);
  InsertNode(N);
  SDValue V(N, 0);
  NewSDValueDbgMsg(V, "Creating new node: ", this);
  return V;
}

SDValue SelectionDAG::getScatterVP(SDVTList VTs, EVT VT, const SDLoc &dl,
                                   ArrayRef<SDValue> Ops,
                                   MachineMemOperand *MMO,
                                   ISD::MemIndexType IndexType) {
  assert(Ops.size() == 7 && "Incompatible number of operands");

  FoldingSetNodeID ID;
  AddNodeIDNode(ID, ISD::VP_SCATTER, VTs, Ops);
  ID.AddInteger(VT.getRawBits());
  ID.AddInteger(getSyntheticNodeSubclassData<VPScatterSDNode>(
      dl.getIROrder(), VTs, VT, MMO, IndexType));
  ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
  ID.AddInteger(MMO->getFlags());
  void *IP = nullptr;
  if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
    cast<VPScatterSDNode>(E)->refineAlignment(MMO);
    return SDValue(E, 0);
  }
  auto *N = newSDNode<VPScatterSDNode>(dl.getIROrder(), dl.getDebugLoc(), VTs,
                                       VT, MMO, IndexType);
  createOperands(N, Ops);

  assert(N->getMask().getValueType().getVectorElementCount() ==
             N->getValue().getValueType().getVectorElementCount() &&
         "Vector width mismatch between mask and data");
  assert(
      N->getIndex().getValueType().getVectorElementCount().isScalable() ==
          N->getValue().getValueType().getVectorElementCount().isScalable() &&
      "Scalable flags of index and data do not match");
  assert(ElementCount::isKnownGE(
             N->getIndex().getValueType().getVectorElementCount(),
             N->getValue().getValueType().getVectorElementCount()) &&
         "Vector width mismatch between index and data");
  assert(isa<ConstantSDNode>(N->getScale()) &&
         cast<ConstantSDNode>(N->getScale())->getAPIntValue().isPowerOf2() &&
         "Scale should be a constant power of 2");

  CSEMap.InsertNode(N, IP);
  InsertNode(N);
  SDValue V(N, 0);
  NewSDValueDbgMsg(V, "Creating new node: ", this);
  return V;
}

SDValue SelectionDAG::getMaskedLoad(EVT VT, const SDLoc &dl, SDValue Chain,
                                    SDValue Base, SDValue Offset, SDValue Mask,
                                    SDValue PassThru, EVT MemVT,
                                    MachineMemOperand *MMO,
                                    ISD::MemIndexedMode AM,
                                    ISD::LoadExtType ExtTy, bool isExpanding) {
  bool Indexed = AM != ISD::UNINDEXED;
  assert((Indexed || Offset.isUndef()) &&
         "Unindexed masked load with an offset!");
  SDVTList VTs = Indexed ? getVTList(VT, Base.getValueType(), MVT::Other)
                         : getVTList(VT, MVT::Other);
  SDValue Ops[] = {Chain, Base, Offset, Mask, PassThru};
  FoldingSetNodeID ID;
  AddNodeIDNode(ID, ISD::MLOAD, VTs, Ops);
  ID.AddInteger(MemVT.getRawBits());
  ID.AddInteger(getSyntheticNodeSubclassData<MaskedLoadSDNode>(
      dl.getIROrder(), VTs, AM, ExtTy, isExpanding, MemVT, MMO));
  ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
  ID.AddInteger(MMO->getFlags());
  void *IP = nullptr;
  if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
    cast<MaskedLoadSDNode>(E)->refineAlignment(MMO);
    return SDValue(E, 0);
  }
  auto *N = newSDNode<MaskedLoadSDNode>(dl.getIROrder(), dl.getDebugLoc(), VTs,
                                        AM, ExtTy, isExpanding, MemVT, MMO);
  createOperands(N, Ops);

  CSEMap.InsertNode(N, IP);
  InsertNode(N);
  SDValue V(N, 0);
  NewSDValueDbgMsg(V, "Creating new node: ", this);
  return V;
}

SDValue SelectionDAG::getIndexedMaskedLoad(SDValue OrigLoad, const SDLoc &dl,
                                           SDValue Base, SDValue Offset,
                                           ISD::MemIndexedMode AM) {
  MaskedLoadSDNode *LD = cast<MaskedLoadSDNode>(OrigLoad);
  assert(LD->getOffset().isUndef() &&
         "Masked load is already an indexed load!");
  return getMaskedLoad(OrigLoad.getValueType(), dl, LD->getChain(), Base,
                       Offset, LD->getMask(), LD->getPassThru(),
                       LD->getMemoryVT(), LD->getMemOperand(), AM,
                       LD->getExtensionType(), LD->isExpandingLoad());
}

SDValue SelectionDAG::getMaskedStore(SDValue Chain, const SDLoc &dl,
                                     SDValue Val, SDValue Base, SDValue Offset,
                                     SDValue Mask, EVT MemVT,
                                     MachineMemOperand *MMO,
                                     ISD::MemIndexedMode AM, bool IsTruncating,
                                     bool IsCompressing) {
  assert(Chain.getValueType() == MVT::Other && "Invalid chain type");
  bool Indexed = AM != ISD::UNINDEXED;
  assert((Indexed || Offset.isUndef()) &&
         "Unindexed masked store with an offset!");
  SDVTList VTs = Indexed ? getVTList(Base.getValueType(), MVT::Other)
                         : getVTList(MVT::Other);
  SDValue Ops[] = {Chain, Val, Base, Offset, Mask};
  FoldingSetNodeID ID;
  AddNodeIDNode(ID, ISD::MSTORE, VTs, Ops);
  ID.AddInteger(MemVT.getRawBits());
  ID.AddInteger(getSyntheticNodeSubclassData<MaskedStoreSDNode>(
      dl.getIROrder(), VTs, AM, IsTruncating, IsCompressing, MemVT, MMO));
  ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
  ID.AddInteger(MMO->getFlags());
  void *IP = nullptr;
  if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
    cast<MaskedStoreSDNode>(E)->refineAlignment(MMO);
    return SDValue(E, 0);
  }
  auto *N =
      newSDNode<MaskedStoreSDNode>(dl.getIROrder(), dl.getDebugLoc(), VTs, AM,
                                   IsTruncating, IsCompressing, MemVT, MMO);
  createOperands(N, Ops);

  CSEMap.InsertNode(N, IP);
  InsertNode(N);
  SDValue V(N, 0);
  NewSDValueDbgMsg(V, "Creating new node: ", this);
  return V;
}

SDValue SelectionDAG::getIndexedMaskedStore(SDValue OrigStore, const SDLoc &dl,
                                            SDValue Base, SDValue Offset,
                                            ISD::MemIndexedMode AM) {
  MaskedStoreSDNode *ST = cast<MaskedStoreSDNode>(OrigStore);
  assert(ST->getOffset().isUndef() &&
         "Masked store is already an indexed store!");
  return getMaskedStore(ST->getChain(), dl, ST->getValue(), Base, Offset,
                        ST->getMask(), ST->getMemoryVT(), ST->getMemOperand(),
                        AM, ST->isTruncatingStore(), ST->isCompressingStore());
}

SDValue SelectionDAG::getMaskedGather(SDVTList VTs, EVT MemVT, const SDLoc &dl,
                                      ArrayRef<SDValue> Ops,
                                      MachineMemOperand *MMO,
                                      ISD::MemIndexType IndexType,
                                      ISD::LoadExtType ExtTy) {
  assert(Ops.size() == 6 && "Incompatible number of operands");

  FoldingSetNodeID ID;
  AddNodeIDNode(ID, ISD::MGATHER, VTs, Ops);
  ID.AddInteger(MemVT.getRawBits());
  ID.AddInteger(getSyntheticNodeSubclassData<MaskedGatherSDNode>(
      dl.getIROrder(), VTs, MemVT, MMO, IndexType, ExtTy));
  ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
  ID.AddInteger(MMO->getFlags());
  void *IP = nullptr;
  if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
    cast<MaskedGatherSDNode>(E)->refineAlignment(MMO);
    return SDValue(E, 0);
  }

  auto *N = newSDNode<MaskedGatherSDNode>(dl.getIROrder(), dl.getDebugLoc(),
                                          VTs, MemVT, MMO, IndexType, ExtTy);
  createOperands(N, Ops);

  assert(N->getPassThru().getValueType() == N->getValueType(0) &&
         "Incompatible type of the PassThru value in MaskedGatherSDNode");
  assert(N->getMask().getValueType().getVectorElementCount() ==
             N->getValueType(0).getVectorElementCount() &&
         "Vector width mismatch between mask and data");
  assert(N->getIndex().getValueType().getVectorElementCount().isScalable() ==
             N->getValueType(0).getVectorElementCount().isScalable() &&
         "Scalable flags of index and data do not match");
  assert(ElementCount::isKnownGE(
             N->getIndex().getValueType().getVectorElementCount(),
             N->getValueType(0).getVectorElementCount()) &&
         "Vector width mismatch between index and data");
  assert(isa<ConstantSDNode>(N->getScale()) &&
         cast<ConstantSDNode>(N->getScale())->getAPIntValue().isPowerOf2() &&
         "Scale should be a constant power of 2");

  CSEMap.InsertNode(N, IP);
  InsertNode(N);
  SDValue V(N, 0);
  NewSDValueDbgMsg(V, "Creating new node: ", this);
  return V;
}

SDValue SelectionDAG::getMaskedScatter(SDVTList VTs, EVT MemVT, const SDLoc &dl,
                                       ArrayRef<SDValue> Ops,
                                       MachineMemOperand *MMO,
                                       ISD::MemIndexType IndexType,
                                       bool IsTrunc) {
  assert(Ops.size() == 6 && "Incompatible number of operands");

  FoldingSetNodeID ID;
  AddNodeIDNode(ID, ISD::MSCATTER, VTs, Ops);
  ID.AddInteger(MemVT.getRawBits());
  ID.AddInteger(getSyntheticNodeSubclassData<MaskedScatterSDNode>(
      dl.getIROrder(), VTs, MemVT, MMO, IndexType, IsTrunc));
  ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
  ID.AddInteger(MMO->getFlags());
  void *IP = nullptr;
  if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
    cast<MaskedScatterSDNode>(E)->refineAlignment(MMO);
    return SDValue(E, 0);
  }

  auto *N = newSDNode<MaskedScatterSDNode>(dl.getIROrder(), dl.getDebugLoc(),
                                           VTs, MemVT, MMO, IndexType, IsTrunc);
  createOperands(N, Ops);

  assert(N->getMask().getValueType().getVectorElementCount() ==
             N->getValue().getValueType().getVectorElementCount() &&
         "Vector width mismatch between mask and data");
  assert(
      N->getIndex().getValueType().getVectorElementCount().isScalable() ==
          N->getValue().getValueType().getVectorElementCount().isScalable() &&
      "Scalable flags of index and data do not match");
  assert(ElementCount::isKnownGE(
             N->getIndex().getValueType().getVectorElementCount(),
             N->getValue().getValueType().getVectorElementCount()) &&
         "Vector width mismatch between index and data");
  assert(isa<ConstantSDNode>(N->getScale()) &&
         cast<ConstantSDNode>(N->getScale())->getAPIntValue().isPowerOf2() &&
         "Scale should be a constant power of 2");

  CSEMap.InsertNode(N, IP);
  InsertNode(N);
  SDValue V(N, 0);
  NewSDValueDbgMsg(V, "Creating new node: ", this);
  return V;
}

SDValue SelectionDAG::simplifySelect(SDValue Cond, SDValue T, SDValue F) {
  // select undef, T, F --> T (if T is a constant), otherwise F
  // select ?, undef, F --> F
  // select ?, T, undef --> T
  if (Cond.isUndef())
    return isConstantValueOfAnyType(T) ? T : F;
  if (T.isUndef())
    return F;
  if (F.isUndef())
    return T;

  // select true, T, F --> T
  // select false, T, F --> F
  if (auto *CondC = dyn_cast<ConstantSDNode>(Cond))
    return CondC->isZero() ? F : T;

  // TODO: This should simplify VSELECT with a non-zero constant condition
  // using something like this (but check boolean contents to be complete?):
  if (ConstantSDNode *CondC = isConstOrConstSplat(Cond, /*AllowUndefs*/ false,
                                                  /*AllowTruncation*/ true))
    if (CondC->isZero())
      return F;

  // select ?, T, T --> T
  if (T == F)
    return T;

  return SDValue();
}
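
// Illustrative sketch (not part of the original source): how the folds above
// behave for a caller, assuming a live SelectionDAG `DAG`, an SDLoc `DL`, and
// values `T` and `F` already in scope. A null SDValue means "no fold applied".
//
//   SDValue CondF = DAG.getConstant(0, DL, MVT::i1);
//   SDValue R = DAG.simplifySelect(CondF, T, F); // select false, T, F --> F
//   SDValue U = DAG.getUNDEF(MVT::i1);
//   SDValue S = DAG.simplifySelect(U, T, F);     // T if T is constant, else F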

SDValue SelectionDAG::simplifyShift(SDValue X, SDValue Y) {
  // shift undef, Y --> 0 (can always assume that the undef value is 0)
  if (X.isUndef())
    return getConstant(0, SDLoc(X.getNode()), X.getValueType());
  // shift X, undef --> undef (because it may shift by the bitwidth)
  if (Y.isUndef())
    return getUNDEF(X.getValueType());

  // shift 0, Y --> 0
  // shift X, 0 --> X
  if (isNullOrNullSplat(X) || isNullOrNullSplat(Y))
    return X;

  // shift X, C >= bitwidth(X) --> undef
  // All vector elements must be too big (or undef) to avoid partial undefs.
  auto isShiftTooBig = [X](ConstantSDNode *Val) {
    return !Val || Val->getAPIntValue().uge(X.getScalarValueSizeInBits());
  };
  if (ISD::matchUnaryPredicate(Y, isShiftTooBig, true))
    return getUNDEF(X.getValueType());

  return SDValue();
}
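
// Illustrative sketch (not part of the original source), assuming a live
// SelectionDAG `DAG`, an SDLoc `DL`, and an i32 value `X` in scope:
//
//   SDValue ShAmt = DAG.getConstant(32, DL, MVT::i32);
//   SDValue R = DAG.simplifyShift(X, ShAmt); // amount >= bitwidth --> undef
//   SDValue Z = DAG.getConstant(0, DL, MVT::i32);
//   SDValue S = DAG.simplifyShift(X, Z);     // shift by zero --> X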

SDValue SelectionDAG::simplifyFPBinop(unsigned Opcode, SDValue X, SDValue Y,
                                      SDNodeFlags Flags) {
  // If this operation has 'nnan' or 'ninf' and at least one disallowed operand
  // (an undef operand can be chosen to be NaN/Inf), then the result of this
  // operation is poison. That result can be relaxed to undef.
  ConstantFPSDNode *XC = isConstOrConstSplatFP(X, /* AllowUndefs */ true);
  ConstantFPSDNode *YC = isConstOrConstSplatFP(Y, /* AllowUndefs */ true);
  bool HasNan = (XC && XC->getValueAPF().isNaN()) ||
                (YC && YC->getValueAPF().isNaN());
  bool HasInf = (XC && XC->getValueAPF().isInfinity()) ||
                (YC && YC->getValueAPF().isInfinity());

  if (Flags.hasNoNaNs() && (HasNan || X.isUndef() || Y.isUndef()))
    return getUNDEF(X.getValueType());

  if (Flags.hasNoInfs() && (HasInf || X.isUndef() || Y.isUndef()))
    return getUNDEF(X.getValueType());

  if (!YC)
    return SDValue();

  // X + -0.0 --> X
  if (Opcode == ISD::FADD)
    if (YC->getValueAPF().isNegZero())
      return X;

  // X - +0.0 --> X
  if (Opcode == ISD::FSUB)
    if (YC->getValueAPF().isPosZero())
      return X;

  // X * 1.0 --> X
  // X / 1.0 --> X
  if (Opcode == ISD::FMUL || Opcode == ISD::FDIV)
    if (YC->getValueAPF().isExactlyValue(1.0))
      return X;

  // X * 0.0 --> 0.0
  if (Opcode == ISD::FMUL && Flags.hasNoNaNs() && Flags.hasNoSignedZeros())
    if (YC->getValueAPF().isZero())
      return getConstantFP(0.0, SDLoc(Y), Y.getValueType());

  return SDValue();
}
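
// Illustrative sketch (not part of the original source), assuming a live
// SelectionDAG `DAG`, an SDLoc `DL`, an f32 value `X`, and default flags:
//
//   SDNodeFlags Flags;
//   SDValue NegZero = DAG.getConstantFP(-0.0, DL, MVT::f32);
//   SDValue R = DAG.simplifyFPBinop(ISD::FADD, X, NegZero, Flags); // --> X
//   SDValue One = DAG.getConstantFP(1.0, DL, MVT::f32);
//   SDValue S = DAG.simplifyFPBinop(ISD::FMUL, X, One, Flags);     // --> X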

SDValue SelectionDAG::getVAArg(EVT VT, const SDLoc &dl, SDValue Chain,
                               SDValue Ptr, SDValue SV, unsigned Align) {
  SDValue Ops[] = { Chain, Ptr, SV, getTargetConstant(Align, dl, MVT::i32) };
  return getNode(ISD::VAARG, dl, getVTList(VT, MVT::Other), Ops);
}

SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
                              ArrayRef<SDUse> Ops) {
  switch (Ops.size()) {
  case 0: return getNode(Opcode, DL, VT);
  case 1: return getNode(Opcode, DL, VT, static_cast<const SDValue>(Ops[0]));
  case 2: return getNode(Opcode, DL, VT, Ops[0], Ops[1]);
  case 3: return getNode(Opcode, DL, VT, Ops[0], Ops[1], Ops[2]);
  default: break;
  }

  // Copy from an SDUse array into an SDValue array for use with
  // the regular getNode logic.
  SmallVector<SDValue, 8> NewOps(Ops.begin(), Ops.end());
  return getNode(Opcode, DL, VT, NewOps);
}

SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
                              ArrayRef<SDValue> Ops) {
  SDNodeFlags Flags;
  if (Inserter)
    Flags = Inserter->getFlags();
  return getNode(Opcode, DL, VT, Ops, Flags);
}

SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
                              ArrayRef<SDValue> Ops, const SDNodeFlags Flags) {
  unsigned NumOps = Ops.size();
  switch (NumOps) {
  case 0: return getNode(Opcode, DL, VT);
  case 1: return getNode(Opcode, DL, VT, Ops[0], Flags);
  case 2: return getNode(Opcode, DL, VT, Ops[0], Ops[1], Flags);
  case 3: return getNode(Opcode, DL, VT, Ops[0], Ops[1], Ops[2], Flags);
  default: break;
  }

#ifndef NDEBUG
  for (const auto &Op : Ops)
    assert(Op.getOpcode() != ISD::DELETED_NODE &&
           "Operand is DELETED_NODE!");
#endif

  switch (Opcode) {
  default: break;
  case ISD::BUILD_VECTOR:
    // Attempt to simplify BUILD_VECTOR.
    if (SDValue V = FoldBUILD_VECTOR(DL, VT, Ops, *this))
      return V;
    break;
  case ISD::CONCAT_VECTORS:
    if (SDValue V = foldCONCAT_VECTORS(DL, VT, Ops, *this))
      return V;
    break;
  case ISD::SELECT_CC:
    assert(NumOps == 5 && "SELECT_CC takes 5 operands!");
    assert(Ops[0].getValueType() == Ops[1].getValueType() &&
           "LHS and RHS of condition must have same type!");
    assert(Ops[2].getValueType() == Ops[3].getValueType() &&
           "True and False arms of SelectCC must have same type!");
    assert(Ops[2].getValueType() == VT &&
           "select_cc node must be of same type as true and false value!");
    assert((!Ops[0].getValueType().isVector() ||
            Ops[0].getValueType().getVectorElementCount() ==
                VT.getVectorElementCount()) &&
           "Expected select_cc with vector result to have the same sized "
           "comparison type!");
    break;
  case ISD::BR_CC:
    assert(NumOps == 5 && "BR_CC takes 5 operands!");
    assert(Ops[2].getValueType() == Ops[3].getValueType() &&
           "LHS/RHS of comparison should match types!");
    break;
  case ISD::VP_ADD:
  case ISD::VP_SUB:
    // If it is a VP_ADD/VP_SUB mask operation, turn it into VP_XOR; addition
    // and subtraction modulo 2 are both XOR.
    if (VT.isVector() && VT.getVectorElementType() == MVT::i1)
      Opcode = ISD::VP_XOR;
    break;
  case ISD::VP_MUL:
    // If it is a VP_MUL mask operation, turn it into VP_AND; multiplication
    // modulo 2 is AND.
    if (VT.isVector() && VT.getVectorElementType() == MVT::i1)
      Opcode = ISD::VP_AND;
    break;
  case ISD::VP_REDUCE_MUL:
    // If it is a VP_REDUCE_MUL mask operation, turn it into VP_REDUCE_AND.
    if (VT == MVT::i1)
      Opcode = ISD::VP_REDUCE_AND;
    break;
  case ISD::VP_REDUCE_ADD:
    // If it is a VP_REDUCE_ADD mask operation, turn it into VP_REDUCE_XOR.
    if (VT == MVT::i1)
      Opcode = ISD::VP_REDUCE_XOR;
    break;
  case ISD::VP_REDUCE_SMAX:
  case ISD::VP_REDUCE_UMIN:
    // If it is a VP_REDUCE_SMAX/VP_REDUCE_UMIN mask operation, turn it into
    // VP_REDUCE_AND.
    if (VT == MVT::i1)
      Opcode = ISD::VP_REDUCE_AND;
    break;
  case ISD::VP_REDUCE_SMIN:
  case ISD::VP_REDUCE_UMAX:
    // If it is a VP_REDUCE_SMIN/VP_REDUCE_UMAX mask operation, turn it into
    // VP_REDUCE_OR.
    if (VT == MVT::i1)
      Opcode = ISD::VP_REDUCE_OR;
    break;
  }

  // Memoize nodes.
  SDNode *N;
  SDVTList VTs = getVTList(VT);

  if (VT != MVT::Glue) {
    FoldingSetNodeID ID;
    AddNodeIDNode(ID, Opcode, VTs, Ops);
    void *IP = nullptr;

    if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP))
      return SDValue(E, 0);

    N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
    createOperands(N, Ops);

    CSEMap.InsertNode(N, IP);
  } else {
    N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
    createOperands(N, Ops);
  }

  N->setFlags(Flags);
  InsertNode(N);
  SDValue V(N, 0);
  NewSDValueDbgMsg(V, "Creating new node: ", this);
  return V;
}
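
// Illustrative sketch (not part of the original source) of the i1 ("mask")
// canonicalization above, assuming a live SelectionDAG `DAG`, an SDLoc `DL`,
// an LLVMContext `Ctx`, v4i1 values `A`, `B`, `Mask`, and an explicit vector
// length `EVL` of whatever integer type the target expects:
//
//   EVT MaskVT = EVT::getVectorVT(Ctx, MVT::i1, 4);
//   SDValue R = DAG.getNode(ISD::VP_ADD, DL, MaskVT, {A, B, Mask, EVL});
//   // R->getOpcode() == ISD::VP_XOR: on i1 elements, add mod 2 is xor.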

SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL,
                              ArrayRef<EVT> ResultTys, ArrayRef<SDValue> Ops) {
  return getNode(Opcode, DL, getVTList(ResultTys), Ops);
}

SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, SDVTList VTList,
                              ArrayRef<SDValue> Ops) {
  SDNodeFlags Flags;
  if (Inserter)
    Flags = Inserter->getFlags();
  return getNode(Opcode, DL, VTList, Ops, Flags);
}

SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, SDVTList VTList,
                              ArrayRef<SDValue> Ops, const SDNodeFlags Flags) {
  if (VTList.NumVTs == 1)
    return getNode(Opcode, DL, VTList.VTs[0], Ops, Flags);

#ifndef NDEBUG
  for (const auto &Op : Ops)
    assert(Op.getOpcode() != ISD::DELETED_NODE &&
           "Operand is DELETED_NODE!");
#endif

  switch (Opcode) {
  case ISD::SADDO:
  case ISD::UADDO:
  case ISD::SSUBO:
  case ISD::USUBO: {
    assert(VTList.NumVTs == 2 && Ops.size() == 2 &&
           "Invalid add/sub overflow op!");
    assert(VTList.VTs[0].isInteger() && VTList.VTs[1].isInteger() &&
           Ops[0].getValueType() == Ops[1].getValueType() &&
           Ops[0].getValueType() == VTList.VTs[0] &&
           "Binary operator types must match!");
    SDValue N1 = Ops[0], N2 = Ops[1];
    canonicalizeCommutativeBinop(Opcode, N1, N2);

    // (X +- 0) -> X with zero-overflow.
    ConstantSDNode *N2CV = isConstOrConstSplat(N2, /*AllowUndefs*/ false,
                                               /*AllowTruncation*/ true);
    if (N2CV && N2CV->isZero()) {
      SDValue ZeroOverFlow = getConstant(0, DL, VTList.VTs[1]);
      return getNode(ISD::MERGE_VALUES, DL, VTList, {N1, ZeroOverFlow}, Flags);
    }
    break;
  }
  case ISD::SMUL_LOHI:
  case ISD::UMUL_LOHI: {
    assert(VTList.NumVTs == 2 && Ops.size() == 2 && "Invalid mul lo/hi op!");
    assert(VTList.VTs[0].isInteger() && VTList.VTs[0] == VTList.VTs[1] &&
           VTList.VTs[0] == Ops[0].getValueType() &&
           VTList.VTs[0] == Ops[1].getValueType() &&
           "Binary operator types must match!");
    break;
  }
  case ISD::STRICT_FP_EXTEND:
    assert(VTList.NumVTs == 2 && Ops.size() == 2 &&
           "Invalid STRICT_FP_EXTEND!");
    assert(VTList.VTs[0].isFloatingPoint() &&
           Ops[1].getValueType().isFloatingPoint() && "Invalid FP cast!");
    assert(VTList.VTs[0].isVector() == Ops[1].getValueType().isVector() &&
           "STRICT_FP_EXTEND result type should be vector iff the operand "
           "type is vector!");
    assert((!VTList.VTs[0].isVector() ||
            VTList.VTs[0].getVectorNumElements() ==
            Ops[1].getValueType().getVectorNumElements()) &&
           "Vector element count mismatch!");
    assert(Ops[1].getValueType().bitsLT(VTList.VTs[0]) &&
           "Invalid fpext node, dst <= src!");
    break;
  case ISD::STRICT_FP_ROUND:
    assert(VTList.NumVTs == 2 && Ops.size() == 3 && "Invalid STRICT_FP_ROUND!");
    assert(VTList.VTs[0].isVector() == Ops[1].getValueType().isVector() &&
           "STRICT_FP_ROUND result type should be vector iff the operand "
           "type is vector!");
    assert((!VTList.VTs[0].isVector() ||
            VTList.VTs[0].getVectorNumElements() ==
            Ops[1].getValueType().getVectorNumElements()) &&
           "Vector element count mismatch!");
    assert(VTList.VTs[0].isFloatingPoint() &&
           Ops[1].getValueType().isFloatingPoint() &&
           VTList.VTs[0].bitsLT(Ops[1].getValueType()) &&
           isa<ConstantSDNode>(Ops[2]) &&
           (cast<ConstantSDNode>(Ops[2])->getZExtValue() == 0 ||
            cast<ConstantSDNode>(Ops[2])->getZExtValue() == 1) &&
           "Invalid STRICT_FP_ROUND!");
    break;
#if 0
  // FIXME: figure out how to safely handle things like
  // int foo(int x) { return 1 << (x & 255); }
  // int bar() { return foo(256); }
  case ISD::SRA_PARTS:
  case ISD::SRL_PARTS:
  case ISD::SHL_PARTS:
    if (N3.getOpcode() == ISD::SIGN_EXTEND_INREG &&
        cast<VTSDNode>(N3.getOperand(1))->getVT() != MVT::i1)
      return getNode(Opcode, DL, VT, N1, N2, N3.getOperand(0));
    else if (N3.getOpcode() == ISD::AND)
      if (ConstantSDNode *AndRHS = dyn_cast<ConstantSDNode>(N3.getOperand(1))) {
        // If the and is only masking out bits that cannot affect the shift,
        // eliminate the and.
        unsigned NumBits = VT.getScalarSizeInBits()*2;
        if ((AndRHS->getValue() & (NumBits-1)) == NumBits-1)
          return getNode(Opcode, DL, VT, N1, N2, N3.getOperand(0));
      }
    break;
#endif
  }

  // Memoize the node unless it returns a flag.
  SDNode *N;
  if (VTList.VTs[VTList.NumVTs-1] != MVT::Glue) {
    FoldingSetNodeID ID;
    AddNodeIDNode(ID, Opcode, VTList, Ops);
    void *IP = nullptr;
    if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP))
      return SDValue(E, 0);

    N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTList);
    createOperands(N, Ops);
    CSEMap.InsertNode(N, IP);
  } else {
    N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTList);
    createOperands(N, Ops);
  }

  N->setFlags(Flags);
  InsertNode(N);
  SDValue V(N, 0);
  NewSDValueDbgMsg(V, "Creating new node: ", this);
  return V;
}
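
// Illustrative sketch (not part of the original source) of the zero-overflow
// fold above, assuming a live SelectionDAG `DAG`, an SDLoc `DL`, and an i32
// value `X` in scope:
//
//   SDVTList VTs = DAG.getVTList(MVT::i32, MVT::i1);
//   SDValue Zero = DAG.getConstant(0, DL, MVT::i32);
//   SDValue R = DAG.getNode(ISD::SADDO, DL, VTs, X, Zero);
//   // Folds to MERGE_VALUES(X, false): adding zero can never overflow.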

SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL,
                              SDVTList VTList) {
  return getNode(Opcode, DL, VTList, std::nullopt);
}

SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, SDVTList VTList,
                              SDValue N1) {
  SDValue Ops[] = { N1 };
  return getNode(Opcode, DL, VTList, Ops);
}

SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, SDVTList VTList,
                              SDValue N1, SDValue N2) {
  SDValue Ops[] = { N1, N2 };
  return getNode(Opcode, DL, VTList, Ops);
}

SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, SDVTList VTList,
                              SDValue N1, SDValue N2, SDValue N3) {
  SDValue Ops[] = { N1, N2, N3 };
  return getNode(Opcode, DL, VTList, Ops);
}

SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, SDVTList VTList,
                              SDValue N1, SDValue N2, SDValue N3, SDValue N4) {
  SDValue Ops[] = { N1, N2, N3, N4 };
  return getNode(Opcode, DL, VTList, Ops);
}

SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, SDVTList VTList,
                              SDValue N1, SDValue N2, SDValue N3, SDValue N4,
                              SDValue N5) {
  SDValue Ops[] = { N1, N2, N3, N4, N5 };
  return getNode(Opcode, DL, VTList, Ops);
}

SDVTList SelectionDAG::getVTList(EVT VT) {
  return makeVTList(SDNode::getValueTypeList(VT), 1);
}

SDVTList SelectionDAG::getVTList(EVT VT1, EVT VT2) {
  FoldingSetNodeID ID;
  ID.AddInteger(2U);
  ID.AddInteger(VT1.getRawBits());
  ID.AddInteger(VT2.getRawBits());

  void *IP = nullptr;
  SDVTListNode *Result = VTListMap.FindNodeOrInsertPos(ID, IP);
  if (!Result) {
    EVT *Array = Allocator.Allocate<EVT>(2);
    Array[0] = VT1;
    Array[1] = VT2;
    Result = new (Allocator) SDVTListNode(ID.Intern(Allocator), Array, 2);
    VTListMap.InsertNode(Result, IP);
  }
  return Result->getSDVTList();
}

SDVTList SelectionDAG::getVTList(EVT VT1, EVT VT2, EVT VT3) {
  FoldingSetNodeID ID;
  ID.AddInteger(3U);
  ID.AddInteger(VT1.getRawBits());
  ID.AddInteger(VT2.getRawBits());
  ID.AddInteger(VT3.getRawBits());

  void *IP = nullptr;
  SDVTListNode *Result = VTListMap.FindNodeOrInsertPos(ID, IP);
  if (!Result) {
    EVT *Array = Allocator.Allocate<EVT>(3);
    Array[0] = VT1;
    Array[1] = VT2;
    Array[2] = VT3;
    Result = new (Allocator) SDVTListNode(ID.Intern(Allocator), Array, 3);
    VTListMap.InsertNode(Result, IP);
  }
  return Result->getSDVTList();
}

SDVTList SelectionDAG::getVTList(EVT VT1, EVT VT2, EVT VT3, EVT VT4) {
  FoldingSetNodeID ID;
  ID.AddInteger(4U);
  ID.AddInteger(VT1.getRawBits());
  ID.AddInteger(VT2.getRawBits());
  ID.AddInteger(VT3.getRawBits());
  ID.AddInteger(VT4.getRawBits());

  void *IP = nullptr;
  SDVTListNode *Result = VTListMap.FindNodeOrInsertPos(ID, IP);
  if (!Result) {
    EVT *Array = Allocator.Allocate<EVT>(4);
    Array[0] = VT1;
    Array[1] = VT2;
    Array[2] = VT3;
    Array[3] = VT4;
    Result = new (Allocator) SDVTListNode(ID.Intern(Allocator), Array, 4);
    VTListMap.InsertNode(Result, IP);
  }
  return Result->getSDVTList();
}

SDVTList SelectionDAG::getVTList(ArrayRef<EVT> VTs) {
  unsigned NumVTs = VTs.size();
  FoldingSetNodeID ID;
  ID.AddInteger(NumVTs);
  for (unsigned index = 0; index < NumVTs; index++) {
    ID.AddInteger(VTs[index].getRawBits());
  }

  void *IP = nullptr;
  SDVTListNode *Result = VTListMap.FindNodeOrInsertPos(ID, IP);
  if (!Result) {
    EVT *Array = Allocator.Allocate<EVT>(NumVTs);
    llvm::copy(VTs, Array);
    Result = new (Allocator) SDVTListNode(ID.Intern(Allocator), Array, NumVTs);
    VTListMap.InsertNode(Result, IP);
  }
  return Result->getSDVTList();
}
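
// Usage note (not part of the original source): because VT lists are interned
// in VTListMap, repeated requests hand back the same backing array, so
// SDVTLists can be compared cheaply by pointer. A minimal sketch, assuming a
// live SelectionDAG `DAG`:
//
//   SDVTList A = DAG.getVTList(MVT::i32, MVT::Other);
//   SDVTList B = DAG.getVTList(MVT::i32, MVT::Other);
//   assert(A.VTs == B.VTs && "interned lists share storage");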

/// UpdateNodeOperands - *Mutate* the specified node in-place to have the
/// specified operands.  If the resultant node already exists in the DAG,
/// this does not modify the specified node, instead it returns the node that
/// already exists.  If the resultant node does not exist in the DAG, the
/// input node is returned.  As a degenerate case, if you specify the same
/// input operands as the node already has, the input node is returned.
SDNode *SelectionDAG::UpdateNodeOperands(SDNode *N, SDValue Op) {
  assert(N->getNumOperands() == 1 && "Update with wrong number of operands");

  // Check to see if there is no change.
  if (Op == N->getOperand(0)) return N;

  // See if the modified node already exists.
  void *InsertPos = nullptr;
  if (SDNode *Existing = FindModifiedNodeSlot(N, Op, InsertPos))
    return Existing;

  // Nope it doesn't.  Remove the node from its current place in the maps.
  if (InsertPos)
    if (!RemoveNodeFromCSEMaps(N))
      InsertPos = nullptr;

  // Now we update the operands.
  N->OperandList[0].set(Op);

  updateDivergence(N);
  // If this gets put into a CSE map, add it.
  if (InsertPos) CSEMap.InsertNode(N, InsertPos);
  return N;
}

SDNode *SelectionDAG::UpdateNodeOperands(SDNode *N, SDValue Op1, SDValue Op2) {
  assert(N->getNumOperands() == 2 && "Update with wrong number of operands");

  // Check to see if there is no change.
  if (Op1 == N->getOperand(0) && Op2 == N->getOperand(1))
    return N;   // No operands changed, just return the input node.

  // See if the modified node already exists.
  void *InsertPos = nullptr;
  if (SDNode *Existing = FindModifiedNodeSlot(N, Op1, Op2, InsertPos))
    return Existing;

  // Nope it doesn't.  Remove the node from its current place in the maps.
  if (InsertPos)
    if (!RemoveNodeFromCSEMaps(N))
      InsertPos = nullptr;

  // Now we update the operands.
  if (N->OperandList[0] != Op1)
    N->OperandList[0].set(Op1);
  if (N->OperandList[1] != Op2)
    N->OperandList[1].set(Op2);

  updateDivergence(N);
  // If this gets put into a CSE map, add it.
  if (InsertPos) CSEMap.InsertNode(N, InsertPos);
  return N;
}

SDNode *SelectionDAG::
UpdateNodeOperands(SDNode *N, SDValue Op1, SDValue Op2, SDValue Op3) {
  SDValue Ops[] = { Op1, Op2, Op3 };
  return UpdateNodeOperands(N, Ops);
}

SDNode *SelectionDAG::
UpdateNodeOperands(SDNode *N, SDValue Op1, SDValue Op2,
                   SDValue Op3, SDValue Op4) {
  SDValue Ops[] = { Op1, Op2, Op3, Op4 };
  return UpdateNodeOperands(N, Ops);
}

SDNode *SelectionDAG::
UpdateNodeOperands(SDNode *N, SDValue Op1, SDValue Op2,
                   SDValue Op3, SDValue Op4, SDValue Op5) {
  SDValue Ops[] = { Op1, Op2, Op3, Op4, Op5 };
  return UpdateNodeOperands(N, Ops);
}

SDNode *SelectionDAG::
UpdateNodeOperands(SDNode *N, ArrayRef<SDValue> Ops) {
  unsigned NumOps = Ops.size();
  assert(N->getNumOperands() == NumOps &&
         "Update with wrong number of operands");

  // If no operands changed just return the input node.
  if (std::equal(Ops.begin(), Ops.end(), N->op_begin()))
    return N;

  // See if the modified node already exists.
  void *InsertPos = nullptr;
  if (SDNode *Existing = FindModifiedNodeSlot(N, Ops, InsertPos))
    return Existing;

  // Nope it doesn't.  Remove the node from its current place in the maps.
  if (InsertPos)
    if (!RemoveNodeFromCSEMaps(N))
      InsertPos = nullptr;

  // Now we update the operands.
  for (unsigned i = 0; i != NumOps; ++i)
    if (N->OperandList[i] != Ops[i])
      N->OperandList[i].set(Ops[i]);

  updateDivergence(N);
  // If this gets put into a CSE map, add it.
  if (InsertPos) CSEMap.InsertNode(N, InsertPos);
  return N;
}
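
// Usage sketch (not part of the original source): the returned node, not the
// argument, is authoritative, since CSE may find a pre-existing node with the
// requested operands:
//
//   SDNode *M = DAG.UpdateNodeOperands(N, NewOp0, NewOp1);
//   if (M != N) {
//     // N was left untouched; an equivalent node already existed in the DAG.
//   }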

/// DropOperands - Release the operands and set this node to have
/// zero operands.
void SDNode::DropOperands() {
  // Unlike the code in MorphNodeTo that does this, we don't need to
  // watch for dead nodes here.
  for (op_iterator I = op_begin(), E = op_end(); I != E; ) {
    SDUse &Use = *I++;
    Use.set(SDValue());
  }
}

void SelectionDAG::setNodeMemRefs(MachineSDNode *N,
                                  ArrayRef<MachineMemOperand *> NewMemRefs) {
  if (NewMemRefs.empty()) {
    N->clearMemRefs();
    return;
  }

  // Check if we can avoid allocating by storing a single reference directly.
  if (NewMemRefs.size() == 1) {
    N->MemRefs = NewMemRefs[0];
    N->NumMemRefs = 1;
    return;
  }

  MachineMemOperand **MemRefsBuffer =
      Allocator.template Allocate<MachineMemOperand *>(NewMemRefs.size());
  llvm::copy(NewMemRefs, MemRefsBuffer);
  N->MemRefs = MemRefsBuffer;
  N->NumMemRefs = static_cast<int>(NewMemRefs.size());
}

/// SelectNodeTo - These are wrappers around MorphNodeTo that accept a
/// machine opcode.
///
SDNode *SelectionDAG::SelectNodeTo(SDNode *N, unsigned MachineOpc,
                                   EVT VT) {
  SDVTList VTs = getVTList(VT);
  return SelectNodeTo(N, MachineOpc, VTs, std::nullopt);
}

SDNode *SelectionDAG::SelectNodeTo(SDNode *N, unsigned MachineOpc,
                                   EVT VT, SDValue Op1) {
  SDVTList VTs = getVTList(VT);
  SDValue Ops[] = { Op1 };
  return SelectNodeTo(N, MachineOpc, VTs, Ops);
}

SDNode *SelectionDAG::SelectNodeTo(SDNode *N, unsigned MachineOpc,
                                   EVT VT, SDValue Op1,
                                   SDValue Op2) {
  SDVTList VTs = getVTList(VT);
  SDValue Ops[] = { Op1, Op2 };
  return SelectNodeTo(N, MachineOpc, VTs, Ops);
}

SDNode *SelectionDAG::SelectNodeTo(SDNode *N, unsigned MachineOpc,
                                   EVT VT, SDValue Op1,
                                   SDValue Op2, SDValue Op3) {
  SDVTList VTs = getVTList(VT);
  SDValue Ops[] = { Op1, Op2, Op3 };
  return SelectNodeTo(N, MachineOpc, VTs, Ops);
}

SDNode *SelectionDAG::SelectNodeTo(SDNode *N, unsigned MachineOpc,
                                   EVT VT, ArrayRef<SDValue> Ops) {
  SDVTList VTs = getVTList(VT);
  return SelectNodeTo(N, MachineOpc, VTs, Ops);
}

SDNode *SelectionDAG::SelectNodeTo(SDNode *N, unsigned MachineOpc,
                                   EVT VT1, EVT VT2, ArrayRef<SDValue> Ops) {
  SDVTList VTs = getVTList(VT1, VT2);
  return SelectNodeTo(N, MachineOpc, VTs, Ops);
}

SDNode *SelectionDAG::SelectNodeTo(SDNode *N, unsigned MachineOpc,
                                   EVT VT1, EVT VT2) {
  SDVTList VTs = getVTList(VT1, VT2);
  return SelectNodeTo(N, MachineOpc, VTs, std::nullopt);
}

SDNode *SelectionDAG::SelectNodeTo(SDNode *N, unsigned MachineOpc,
                                   EVT VT1, EVT VT2, EVT VT3,
                                   ArrayRef<SDValue> Ops) {
  SDVTList VTs = getVTList(VT1, VT2, VT3);
  return SelectNodeTo(N, MachineOpc, VTs, Ops);
}

SDNode *SelectionDAG::SelectNodeTo(SDNode *N, unsigned MachineOpc,
                                   EVT VT1, EVT VT2,
                                   SDValue Op1, SDValue Op2) {
  SDVTList VTs = getVTList(VT1, VT2);
  SDValue Ops[] = { Op1, Op2 };
  return SelectNodeTo(N, MachineOpc, VTs, Ops);
}

SDNode *SelectionDAG::SelectNodeTo(SDNode *N, unsigned MachineOpc,
                                   SDVTList VTs, ArrayRef<SDValue> Ops) {
  SDNode *New = MorphNodeTo(N, ~MachineOpc, VTs, Ops);
  // Reset the NodeID to -1.
  New->setNodeId(-1);
  if (New != N) {
    ReplaceAllUsesWith(N, New);
    RemoveDeadNode(N);
  }
  return New;
}

/// UpdateSDLocOnMergeSDNode - If the opt level is -O0 then it throws away
/// the line number information on the merged node since it is not possible to
/// preserve the information that the operation is associated with multiple
/// lines. This makes the debugger work better at -O0, where there is a higher
/// probability of having other instructions associated with that line.
///
/// For IROrder, we keep the smaller of the two.
SDNode *SelectionDAG::UpdateSDLocOnMergeSDNode(SDNode *N, const SDLoc &OLoc) {
  DebugLoc NLoc = N->getDebugLoc();
  if (NLoc && OptLevel == CodeGenOpt::None && OLoc.getDebugLoc() != NLoc) {
    N->setDebugLoc(DebugLoc());
  }
  unsigned Order = std::min(N->getIROrder(), OLoc.getIROrder());
  N->setIROrder(Order);
  return N;
}

/// MorphNodeTo - This *mutates* the specified node to have the specified
/// return type, opcode, and operands.
///
/// Note that MorphNodeTo returns the resultant node.  If there is already a
/// node of the specified opcode and operands, it returns that node instead of
/// the current one.  Note that the SDLoc need not be the same.
///
/// Using MorphNodeTo is faster than creating a new node and swapping it in
/// with ReplaceAllUsesWith both because it often avoids allocating a new
/// node, and because it doesn't require CSE recalculation for any of
/// the node's users.
///
/// However, note that MorphNodeTo recursively deletes dead nodes from the DAG.
/// As a consequence it isn't appropriate to use from within the DAG combiner
/// or the legalizer, which maintain worklists that would need to be updated
/// when deleting things.
SDNode *SelectionDAG::MorphNodeTo(SDNode *N, unsigned Opc,
                                  SDVTList VTs, ArrayRef<SDValue> Ops) {
  // If an identical node already exists, use it.
  void *IP = nullptr;
  if (VTs.VTs[VTs.NumVTs-1] != MVT::Glue) {
    FoldingSetNodeID ID;
    AddNodeIDNode(ID, Opc, VTs, Ops);
    if (SDNode *ON = FindNodeOrInsertPos(ID, SDLoc(N), IP))
      return UpdateSDLocOnMergeSDNode(ON, SDLoc(N));
  }

  if (!RemoveNodeFromCSEMaps(N))
    IP = nullptr;

  // Start the morphing.
  N->NodeType = Opc;
  N->ValueList = VTs.VTs;
  N->NumValues = VTs.NumVTs;

  // Clear the operands list, updating used nodes to remove this from their
  // use list.  Keep track of any operands that become dead as a result.
  SmallPtrSet<SDNode*, 16> DeadNodeSet;
  for (SDNode::op_iterator I = N->op_begin(), E = N->op_end(); I != E; ) {
    SDUse &Use = *I++;
    SDNode *Used = Use.getNode();
    Use.set(SDValue());
    if (Used->use_empty())
      DeadNodeSet.insert(Used);
  }

  // For MachineNode, initialize the memory references information.
  if (MachineSDNode *MN = dyn_cast<MachineSDNode>(N))
    MN->clearMemRefs();

  // Swap for an appropriately sized array from the recycler.
  removeOperands(N);
  createOperands(N, Ops);

  // Delete any nodes that are still dead after adding the uses for the
  // new operands.
  if (!DeadNodeSet.empty()) {
    SmallVector<SDNode *, 16> DeadNodes;
    for (SDNode *N : DeadNodeSet)
      if (N->use_empty())
        DeadNodes.push_back(N);
    RemoveDeadNodes(DeadNodes);
  }

  if (IP)
    CSEMap.InsertNode(N, IP);   // Memoize the new node.
  return N;
}
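
// Usage sketch (not part of the original source): when MorphNodeTo returns a
// pre-existing node, the caller must redirect uses itself, as SelectNodeTo
// above and mutateStrictFPToFP below both do:
//
//   SDNode *Res = DAG.MorphNodeTo(N, NewOpc, VTs, Ops);
//   if (Res != N) {
//     DAG.ReplaceAllUsesWith(N, Res);
//     DAG.RemoveDeadNode(N);
//   }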

SDNode* SelectionDAG::mutateStrictFPToFP(SDNode *Node) {
  unsigned OrigOpc = Node->getOpcode();
  unsigned NewOpc;
  switch (OrigOpc) {
  default:
    llvm_unreachable("mutateStrictFPToFP called with unexpected opcode!");
#define DAG_INSTRUCTION(NAME, NARG, ROUND_MODE, INTRINSIC, DAGN)               \
  case ISD::STRICT_##DAGN: NewOpc = ISD::DAGN; break;
#define CMP_INSTRUCTION(NAME, NARG, ROUND_MODE, INTRINSIC, DAGN)               \
  case ISD::STRICT_##DAGN: NewOpc = ISD::SETCC; break;
#include "llvm/IR/ConstrainedOps.def"
  }

  assert(Node->getNumValues() == 2 && "Unexpected number of results!");

  // We're taking this node out of the chain, so we need to re-link things.
  SDValue InputChain = Node->getOperand(0);
  SDValue OutputChain = SDValue(Node, 1);
  ReplaceAllUsesOfValueWith(OutputChain, InputChain);

  SmallVector<SDValue, 3> Ops;
  for (unsigned i = 1, e = Node->getNumOperands(); i != e; ++i)
    Ops.push_back(Node->getOperand(i));

  SDVTList VTs = getVTList(Node->getValueType(0));
  SDNode *Res = MorphNodeTo(Node, NewOpc, VTs, Ops);

  // MorphNodeTo can operate in two ways: if an existing node with the
  // specified operands exists, it can just return it.  Otherwise, it
  // updates the node in place to have the requested operands.
  if (Res == Node) {
    // If we updated the node in place, reset the node ID.  To the isel,
    // this should be just like a newly allocated machine node.
    Res->setNodeId(-1);
  } else {
    ReplaceAllUsesWith(Node, Res);
    RemoveDeadNode(Node);
  }

  return Res;
}
9896 
9897 /// getMachineNode - These are used for target selectors to create a new node
9898 /// with specified return type(s), MachineInstr opcode, and operands.
9899 ///
9900 /// Note that getMachineNode returns the resultant node.  If there is already a
9901 /// node of the specified opcode and operands, it returns that node instead of
9902 /// the current one.
getMachineNode(unsigned Opcode,const SDLoc & dl,EVT VT)9903 MachineSDNode *SelectionDAG::getMachineNode(unsigned Opcode, const SDLoc &dl,
9904                                             EVT VT) {
9905   SDVTList VTs = getVTList(VT);
9906   return getMachineNode(Opcode, dl, VTs, std::nullopt);
9907 }
9908 
getMachineNode(unsigned Opcode,const SDLoc & dl,EVT VT,SDValue Op1)9909 MachineSDNode *SelectionDAG::getMachineNode(unsigned Opcode, const SDLoc &dl,
9910                                             EVT VT, SDValue Op1) {
9911   SDVTList VTs = getVTList(VT);
9912   SDValue Ops[] = { Op1 };
9913   return getMachineNode(Opcode, dl, VTs, Ops);
9914 }
9915 
getMachineNode(unsigned Opcode,const SDLoc & dl,EVT VT,SDValue Op1,SDValue Op2)9916 MachineSDNode *SelectionDAG::getMachineNode(unsigned Opcode, const SDLoc &dl,
9917                                             EVT VT, SDValue Op1, SDValue Op2) {
9918   SDVTList VTs = getVTList(VT);
9919   SDValue Ops[] = { Op1, Op2 };
9920   return getMachineNode(Opcode, dl, VTs, Ops);
9921 }
9922 
getMachineNode(unsigned Opcode,const SDLoc & dl,EVT VT,SDValue Op1,SDValue Op2,SDValue Op3)9923 MachineSDNode *SelectionDAG::getMachineNode(unsigned Opcode, const SDLoc &dl,
9924                                             EVT VT, SDValue Op1, SDValue Op2,
9925                                             SDValue Op3) {
9926   SDVTList VTs = getVTList(VT);
9927   SDValue Ops[] = { Op1, Op2, Op3 };
9928   return getMachineNode(Opcode, dl, VTs, Ops);
9929 }
9930 
getMachineNode(unsigned Opcode,const SDLoc & dl,EVT VT,ArrayRef<SDValue> Ops)9931 MachineSDNode *SelectionDAG::getMachineNode(unsigned Opcode, const SDLoc &dl,
9932                                             EVT VT, ArrayRef<SDValue> Ops) {
9933   SDVTList VTs = getVTList(VT);
9934   return getMachineNode(Opcode, dl, VTs, Ops);
9935 }
9936 
getMachineNode(unsigned Opcode,const SDLoc & dl,EVT VT1,EVT VT2,SDValue Op1,SDValue Op2)9937 MachineSDNode *SelectionDAG::getMachineNode(unsigned Opcode, const SDLoc &dl,
9938                                             EVT VT1, EVT VT2, SDValue Op1,
9939                                             SDValue Op2) {
9940   SDVTList VTs = getVTList(VT1, VT2);
9941   SDValue Ops[] = { Op1, Op2 };
9942   return getMachineNode(Opcode, dl, VTs, Ops);
9943 }
9944 
getMachineNode(unsigned Opcode,const SDLoc & dl,EVT VT1,EVT VT2,SDValue Op1,SDValue Op2,SDValue Op3)9945 MachineSDNode *SelectionDAG::getMachineNode(unsigned Opcode, const SDLoc &dl,
                                            EVT VT1, EVT VT2, SDValue Op1,
                                            SDValue Op2, SDValue Op3) {
  SDVTList VTs = getVTList(VT1, VT2);
  SDValue Ops[] = { Op1, Op2, Op3 };
  return getMachineNode(Opcode, dl, VTs, Ops);
}

MachineSDNode *SelectionDAG::getMachineNode(unsigned Opcode, const SDLoc &dl,
                                            EVT VT1, EVT VT2,
                                            ArrayRef<SDValue> Ops) {
  SDVTList VTs = getVTList(VT1, VT2);
  return getMachineNode(Opcode, dl, VTs, Ops);
}

MachineSDNode *SelectionDAG::getMachineNode(unsigned Opcode, const SDLoc &dl,
                                            EVT VT1, EVT VT2, EVT VT3,
                                            SDValue Op1, SDValue Op2) {
  SDVTList VTs = getVTList(VT1, VT2, VT3);
  SDValue Ops[] = { Op1, Op2 };
  return getMachineNode(Opcode, dl, VTs, Ops);
}

MachineSDNode *SelectionDAG::getMachineNode(unsigned Opcode, const SDLoc &dl,
                                            EVT VT1, EVT VT2, EVT VT3,
                                            SDValue Op1, SDValue Op2,
                                            SDValue Op3) {
  SDVTList VTs = getVTList(VT1, VT2, VT3);
  SDValue Ops[] = { Op1, Op2, Op3 };
  return getMachineNode(Opcode, dl, VTs, Ops);
}

MachineSDNode *SelectionDAG::getMachineNode(unsigned Opcode, const SDLoc &dl,
                                            EVT VT1, EVT VT2, EVT VT3,
                                            ArrayRef<SDValue> Ops) {
  SDVTList VTs = getVTList(VT1, VT2, VT3);
  return getMachineNode(Opcode, dl, VTs, Ops);
}

MachineSDNode *SelectionDAG::getMachineNode(unsigned Opcode, const SDLoc &dl,
                                            ArrayRef<EVT> ResultTys,
                                            ArrayRef<SDValue> Ops) {
  SDVTList VTs = getVTList(ResultTys);
  return getMachineNode(Opcode, dl, VTs, Ops);
}

MachineSDNode *SelectionDAG::getMachineNode(unsigned Opcode, const SDLoc &DL,
                                            SDVTList VTs,
                                            ArrayRef<SDValue> Ops) {
  bool DoCSE = VTs.VTs[VTs.NumVTs-1] != MVT::Glue;
  MachineSDNode *N;
  void *IP = nullptr;

  if (DoCSE) {
    FoldingSetNodeID ID;
    AddNodeIDNode(ID, ~Opcode, VTs, Ops);
    IP = nullptr;
    if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) {
      return cast<MachineSDNode>(UpdateSDLocOnMergeSDNode(E, DL));
    }
  }

  // Allocate a new MachineSDNode.
  N = newSDNode<MachineSDNode>(~Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
  createOperands(N, Ops);

  if (DoCSE)
    CSEMap.InsertNode(N, IP);

  InsertNode(N);
  NewSDValueDbgMsg(SDValue(N, 0), "Creating new machine node: ", this);
  return N;
}
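
// Illustrative sketch: a target's instruction selector typically reaches the
// overloads above when it needs a pre-selected machine node. The register
// class ID below is a hypothetical placeholder for a target-defined constant:
//
//   MachineSDNode *Copy = DAG.getMachineNode(
//       TargetOpcode::COPY_TO_REGCLASS, DL, MVT::i32, Val,
//       DAG.getTargetConstant(/*RegClassID=*/0, DL, MVT::i32));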

/// getTargetExtractSubreg - A convenience function for creating
/// TargetOpcode::EXTRACT_SUBREG nodes.
SDValue SelectionDAG::getTargetExtractSubreg(int SRIdx, const SDLoc &DL, EVT VT,
                                             SDValue Operand) {
  SDValue SRIdxVal = getTargetConstant(SRIdx, DL, MVT::i32);
  SDNode *Subreg = getMachineNode(TargetOpcode::EXTRACT_SUBREG, DL,
                                  VT, Operand, SRIdxVal);
  return SDValue(Subreg, 0);
}
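
// Illustrative sketch: extracting the low half of a 64-bit value. The
// subregister index name is hypothetical; real indices come from the
// target's generated register info tables:
//
//   SDValue Lo =
//       DAG.getTargetExtractSubreg(TargetNS::sub_lo, DL, MVT::i32, Val64);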

/// getTargetInsertSubreg - A convenience function for creating
/// TargetOpcode::INSERT_SUBREG nodes.
SDValue SelectionDAG::getTargetInsertSubreg(int SRIdx, const SDLoc &DL, EVT VT,
                                            SDValue Operand, SDValue Subreg) {
  SDValue SRIdxVal = getTargetConstant(SRIdx, DL, MVT::i32);
  SDNode *Result = getMachineNode(TargetOpcode::INSERT_SUBREG, DL,
                                  VT, Operand, Subreg, SRIdxVal);
  return SDValue(Result, 0);
}
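
// Illustrative sketch (mirroring the above, with the same hypothetical
// index): inserting a 32-bit value into the low half of a 64-bit register:
//
//   SDValue New64 =
//       DAG.getTargetInsertSubreg(TargetNS::sub_lo, DL, MVT::i64, Val64, Lo);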

/// getNodeIfExists - Get the specified node if it's already available, or
/// else return NULL.
SDNode *SelectionDAG::getNodeIfExists(unsigned Opcode, SDVTList VTList,
                                      ArrayRef<SDValue> Ops) {
  SDNodeFlags Flags;
  if (Inserter)
    Flags = Inserter->getFlags();
  return getNodeIfExists(Opcode, VTList, Ops, Flags);
}

SDNode *SelectionDAG::getNodeIfExists(unsigned Opcode, SDVTList VTList,
                                      ArrayRef<SDValue> Ops,
                                      const SDNodeFlags Flags) {
  if (VTList.VTs[VTList.NumVTs - 1] != MVT::Glue) {
    FoldingSetNodeID ID;
    AddNodeIDNode(ID, Opcode, VTList, Ops);
    void *IP = nullptr;
    if (SDNode *E = FindNodeOrInsertPos(ID, SDLoc(), IP)) {
      E->intersectFlagsWith(Flags);
      return E;
    }
  }
  return nullptr;
}

/// doesNodeExist - Check if a node exists without modifying its flags.
bool SelectionDAG::doesNodeExist(unsigned Opcode, SDVTList VTList,
                                 ArrayRef<SDValue> Ops) {
  if (VTList.VTs[VTList.NumVTs - 1] != MVT::Glue) {
    FoldingSetNodeID ID;
    AddNodeIDNode(ID, Opcode, VTList, Ops);
    void *IP = nullptr;
    if (FindNodeOrInsertPos(ID, SDLoc(), IP))
      return true;
  }
  return false;
}
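
// Illustrative sketch: a combine can cheaply probe the CSE map for an
// existing node before committing to build one:
//
//   SDVTList VTs = DAG.getVTList(VT);
//   if (SDNode *E = DAG.getNodeIfExists(ISD::ADD, VTs, {X, Y}))
//     return SDValue(E, 0); // Reuse the CSE'd node instead of creating one.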

/// getDbgValue - Creates an SDDbgValue node.
///
/// SDNode
SDDbgValue *SelectionDAG::getDbgValue(DIVariable *Var, DIExpression *Expr,
                                      SDNode *N, unsigned R, bool IsIndirect,
                                      const DebugLoc &DL, unsigned O) {
  assert(cast<DILocalVariable>(Var)->isValidLocationForIntrinsic(DL) &&
         "Expected inlined-at fields to agree");
  return new (DbgInfo->getAlloc())
      SDDbgValue(DbgInfo->getAlloc(), Var, Expr, SDDbgOperand::fromNode(N, R),
                 {}, IsIndirect, DL, O,
                 /*IsVariadic=*/false);
}

/// Constant
SDDbgValue *SelectionDAG::getConstantDbgValue(DIVariable *Var,
                                              DIExpression *Expr,
                                              const Value *C,
                                              const DebugLoc &DL, unsigned O) {
  assert(cast<DILocalVariable>(Var)->isValidLocationForIntrinsic(DL) &&
         "Expected inlined-at fields to agree");
  return new (DbgInfo->getAlloc())
      SDDbgValue(DbgInfo->getAlloc(), Var, Expr, SDDbgOperand::fromConst(C), {},
                 /*IsIndirect=*/false, DL, O,
                 /*IsVariadic=*/false);
}

/// FrameIndex
SDDbgValue *SelectionDAG::getFrameIndexDbgValue(DIVariable *Var,
                                                DIExpression *Expr, unsigned FI,
                                                bool IsIndirect,
                                                const DebugLoc &DL,
                                                unsigned O) {
  assert(cast<DILocalVariable>(Var)->isValidLocationForIntrinsic(DL) &&
         "Expected inlined-at fields to agree");
  return getFrameIndexDbgValue(Var, Expr, FI, {}, IsIndirect, DL, O);
}

/// FrameIndex with dependencies
SDDbgValue *SelectionDAG::getFrameIndexDbgValue(DIVariable *Var,
                                                DIExpression *Expr, unsigned FI,
                                                ArrayRef<SDNode *> Dependencies,
                                                bool IsIndirect,
                                                const DebugLoc &DL,
                                                unsigned O) {
  assert(cast<DILocalVariable>(Var)->isValidLocationForIntrinsic(DL) &&
         "Expected inlined-at fields to agree");
  return new (DbgInfo->getAlloc())
      SDDbgValue(DbgInfo->getAlloc(), Var, Expr, SDDbgOperand::fromFrameIdx(FI),
                 Dependencies, IsIndirect, DL, O,
                 /*IsVariadic=*/false);
}

/// VReg
SDDbgValue *SelectionDAG::getVRegDbgValue(DIVariable *Var, DIExpression *Expr,
                                          unsigned VReg, bool IsIndirect,
                                          const DebugLoc &DL, unsigned O) {
  assert(cast<DILocalVariable>(Var)->isValidLocationForIntrinsic(DL) &&
         "Expected inlined-at fields to agree");
  return new (DbgInfo->getAlloc())
      SDDbgValue(DbgInfo->getAlloc(), Var, Expr, SDDbgOperand::fromVReg(VReg),
                 {}, IsIndirect, DL, O,
                 /*IsVariadic=*/false);
}

SDDbgValue *SelectionDAG::getDbgValueList(DIVariable *Var, DIExpression *Expr,
                                          ArrayRef<SDDbgOperand> Locs,
                                          ArrayRef<SDNode *> Dependencies,
                                          bool IsIndirect, const DebugLoc &DL,
                                          unsigned O, bool IsVariadic) {
  assert(cast<DILocalVariable>(Var)->isValidLocationForIntrinsic(DL) &&
         "Expected inlined-at fields to agree");
  return new (DbgInfo->getAlloc())
      SDDbgValue(DbgInfo->getAlloc(), Var, Expr, Locs, Dependencies, IsIndirect,
                 DL, O, IsVariadic);
}
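
// Illustrative sketch: a variadic debug value describing a variable whose
// value is computed from two SDNodes, assuming Var/Expr/DL are in scope:
//
//   SDDbgOperand Ops[] = {SDDbgOperand::fromNode(Lo.getNode(), Lo.getResNo()),
//                         SDDbgOperand::fromNode(Hi.getNode(), Hi.getResNo())};
//   SDDbgValue *DV = DAG.getDbgValueList(Var, Expr, Ops, /*Dependencies=*/{},
//                                        /*IsIndirect=*/false, DL, /*O=*/0,
//                                        /*IsVariadic=*/true);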

void SelectionDAG::transferDbgValues(SDValue From, SDValue To,
                                     unsigned OffsetInBits, unsigned SizeInBits,
                                     bool InvalidateDbg) {
  SDNode *FromNode = From.getNode();
  SDNode *ToNode = To.getNode();
  assert(FromNode && ToNode && "Can't modify dbg values");

  // PR35338
  // TODO: assert(From != To && "Redundant dbg value transfer");
  // TODO: assert(FromNode != ToNode && "Intranode dbg value transfer");
  if (From == To || FromNode == ToNode)
    return;

  if (!FromNode->getHasDebugValue())
    return;

  SDDbgOperand FromLocOp =
      SDDbgOperand::fromNode(From.getNode(), From.getResNo());
  SDDbgOperand ToLocOp = SDDbgOperand::fromNode(To.getNode(), To.getResNo());

  SmallVector<SDDbgValue *, 2> ClonedDVs;
  for (SDDbgValue *Dbg : GetDbgValues(FromNode)) {
    if (Dbg->isInvalidated())
      continue;

    // TODO: assert(!Dbg->isInvalidated() && "Transfer of invalid dbg value");

    // Create a new location ops vector that is equal to the old vector, but
    // with each instance of FromLocOp replaced with ToLocOp.
    bool Changed = false;
    auto NewLocOps = Dbg->copyLocationOps();
    std::replace_if(
        NewLocOps.begin(), NewLocOps.end(),
        [&Changed, FromLocOp](const SDDbgOperand &Op) {
          bool Match = Op == FromLocOp;
          Changed |= Match;
          return Match;
        },
        ToLocOp);
    // Ignore this SDDbgValue if we didn't find a matching location.
    if (!Changed)
      continue;

    DIVariable *Var = Dbg->getVariable();
    auto *Expr = Dbg->getExpression();
    // If a fragment is requested, update the expression.
    if (SizeInBits) {
      // When splitting a larger (e.g., sign-extended) value whose
      // lower bits are described with an SDDbgValue, do not attempt
      // to transfer the SDDbgValue to the upper bits.
      if (auto FI = Expr->getFragmentInfo())
        if (OffsetInBits + SizeInBits > FI->SizeInBits)
          continue;
      auto Fragment = DIExpression::createFragmentExpression(Expr, OffsetInBits,
                                                             SizeInBits);
      if (!Fragment)
        continue;
      Expr = *Fragment;
    }

    auto AdditionalDependencies = Dbg->getAdditionalDependencies();
    // Clone the SDDbgValue and move it to To.
    SDDbgValue *Clone = getDbgValueList(
        Var, Expr, NewLocOps, AdditionalDependencies, Dbg->isIndirect(),
        Dbg->getDebugLoc(), std::max(ToNode->getIROrder(), Dbg->getOrder()),
        Dbg->isVariadic());
    ClonedDVs.push_back(Clone);

    if (InvalidateDbg) {
      // Invalidate value and indicate the SDDbgValue should not be emitted.
      Dbg->setIsInvalidated();
      Dbg->setIsEmitted();
    }
  }

  for (SDDbgValue *Dbg : ClonedDVs) {
    assert(is_contained(Dbg->getSDNodes(), ToNode) &&
           "Transferred DbgValues should depend on the new SDNode");
    AddDbgValue(Dbg, false);
  }
}
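
// Illustrative sketch: when a 64-bit value is split into two 32-bit halves,
// each half can inherit the original's debug info as a fragment:
//
//   DAG.transferDbgValues(Val64, Lo, /*OffsetInBits=*/0, /*SizeInBits=*/32,
//                         /*InvalidateDbg=*/false);
//   DAG.transferDbgValues(Val64, Hi, /*OffsetInBits=*/32, /*SizeInBits=*/32);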

void SelectionDAG::salvageDebugInfo(SDNode &N) {
  if (!N.getHasDebugValue())
    return;

  SmallVector<SDDbgValue *, 2> ClonedDVs;
  for (auto *DV : GetDbgValues(&N)) {
    if (DV->isInvalidated())
      continue;
    switch (N.getOpcode()) {
    default:
      break;
    case ISD::ADD:
      SDValue N0 = N.getOperand(0);
      SDValue N1 = N.getOperand(1);
      if (!isConstantIntBuildVectorOrConstantInt(N0) &&
          isConstantIntBuildVectorOrConstantInt(N1)) {
        uint64_t Offset = N.getConstantOperandVal(1);

        // Rewrite an ADD constant node into a DIExpression. Since we are
        // performing arithmetic to compute the variable's *value* in the
        // DIExpression, we need to mark the expression with a
        // DW_OP_stack_value.
        auto *DIExpr = DV->getExpression();
        auto NewLocOps = DV->copyLocationOps();
        bool Changed = false;
        for (size_t i = 0; i < NewLocOps.size(); ++i) {
          // We're not given a ResNo to compare against because the whole
          // node is going away. We know that any ISD::ADD only has one
          // result, so we can assume any node match is using the result.
          if (NewLocOps[i].getKind() != SDDbgOperand::SDNODE ||
              NewLocOps[i].getSDNode() != &N)
            continue;
          NewLocOps[i] = SDDbgOperand::fromNode(N0.getNode(), N0.getResNo());
          SmallVector<uint64_t, 3> ExprOps;
          DIExpression::appendOffset(ExprOps, Offset);
          DIExpr = DIExpression::appendOpsToArg(DIExpr, ExprOps, i, true);
          Changed = true;
        }
        (void)Changed;
        assert(Changed && "Salvage target doesn't use N");

        auto AdditionalDependencies = DV->getAdditionalDependencies();
        SDDbgValue *Clone = getDbgValueList(DV->getVariable(), DIExpr,
                                            NewLocOps, AdditionalDependencies,
                                            DV->isIndirect(), DV->getDebugLoc(),
                                            DV->getOrder(), DV->isVariadic());
        ClonedDVs.push_back(Clone);
        DV->setIsInvalidated();
        DV->setIsEmitted();
        LLVM_DEBUG(dbgs() << "SALVAGE: Rewriting";
                   N0.getNode()->dumprFull(this);
                   dbgs() << " into " << *DIExpr << '\n');
      }
    }
  }

  for (SDDbgValue *Dbg : ClonedDVs) {
    assert(!Dbg->getSDNodes().empty() &&
           "Salvaged DbgValue should depend on a new SDNode");
    AddDbgValue(Dbg, false);
  }
}
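
// Illustrative sketch of the ISD::ADD rewrite above: given a node
// %a = add %x, 16 that carries a debug value, the location is retargeted to
// %x and the addition is folded into the expression, conceptually turning
//   DBG_VALUE %a, !DIExpression()
// into
//   DBG_VALUE %x, !DIExpression(DW_OP_plus_uconst, 16, DW_OP_stack_value)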

/// Creates an SDDbgLabel node.
SDDbgLabel *SelectionDAG::getDbgLabel(DILabel *Label,
                                      const DebugLoc &DL, unsigned O) {
  assert(cast<DILabel>(Label)->isValidLocationForIntrinsic(DL) &&
         "Expected inlined-at fields to agree");
  return new (DbgInfo->getAlloc()) SDDbgLabel(Label, DL, O);
}

namespace {

/// RAUWUpdateListener - Helper for ReplaceAllUsesWith - When the node
/// pointed to by a use iterator is deleted, increment the use iterator
/// so that it doesn't dangle.
///
class RAUWUpdateListener : public SelectionDAG::DAGUpdateListener {
  SDNode::use_iterator &UI;
  SDNode::use_iterator &UE;

  void NodeDeleted(SDNode *N, SDNode *E) override {
    // Increment the iterator as needed.
    while (UI != UE && N == *UI)
      ++UI;
  }

public:
  RAUWUpdateListener(SelectionDAG &d,
                     SDNode::use_iterator &ui,
                     SDNode::use_iterator &ue)
    : SelectionDAG::DAGUpdateListener(d), UI(ui), UE(ue) {}
};

} // end anonymous namespace

/// ReplaceAllUsesWith - Modify anything using 'From' to use 'To' instead.
/// This can cause recursive merging of nodes in the DAG.
///
/// This version assumes From has a single result value.
///
void SelectionDAG::ReplaceAllUsesWith(SDValue FromN, SDValue To) {
  SDNode *From = FromN.getNode();
  assert(From->getNumValues() == 1 && FromN.getResNo() == 0 &&
         "Cannot replace with this method!");
  assert(From != To.getNode() && "Cannot replace uses of value with self");

  // Preserve Debug Values
  transferDbgValues(FromN, To);
  // Preserve extra info.
  copyExtraInfo(From, To.getNode());

  // Iterate over all the existing uses of From. New uses will be added
  // to the beginning of the use list, which we avoid visiting.
  // This specifically avoids visiting uses of From that arise while the
  // replacement is happening, because any such uses would be the result
  // of CSE: If an existing node looks like From after one of its operands
  // is replaced by To, we don't want to replace all of its users with To
  // too. See PR3018 for more info.
  SDNode::use_iterator UI = From->use_begin(), UE = From->use_end();
  RAUWUpdateListener Listener(*this, UI, UE);
  while (UI != UE) {
    SDNode *User = *UI;

    // This node is about to morph, remove its old self from the CSE maps.
    RemoveNodeFromCSEMaps(User);

    // A user can appear in a use list multiple times, and when this
    // happens the uses are usually next to each other in the list.
    // To help reduce the number of CSE recomputations, process all
    // the uses of this user that we can find this way.
    do {
      SDUse &Use = UI.getUse();
      ++UI;
      Use.set(To);
      if (To->isDivergent() != From->isDivergent())
        updateDivergence(User);
    } while (UI != UE && *UI == User);
    // Now that we have modified User, add it back to the CSE maps.  If it
    // already exists there, recursively merge the results together.
    AddModifiedNodeToCSEMaps(User);
  }

  // If we just RAUW'd the root, take note.
  if (FromN == getRoot())
    setRoot(To);
}
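
// Illustrative sketch: a DAG combine builds the replacement value first and
// then redirects every user of the old node in one step:
//
//   SDValue NewVal = DAG.getNode(ISD::AND, DL, VT, X, MaskC);
//   DAG.ReplaceAllUsesWith(SDValue(OldN, 0), NewVal);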

/// ReplaceAllUsesWith - Modify anything using 'From' to use 'To' instead.
/// This can cause recursive merging of nodes in the DAG.
///
/// This version assumes that for each value of From, there is a
/// corresponding value in To in the same position with the same type.
///
void SelectionDAG::ReplaceAllUsesWith(SDNode *From, SDNode *To) {
#ifndef NDEBUG
  for (unsigned i = 0, e = From->getNumValues(); i != e; ++i)
    assert((!From->hasAnyUseOfValue(i) ||
            From->getValueType(i) == To->getValueType(i)) &&
           "Cannot use this version of ReplaceAllUsesWith!");
#endif

  // Handle the trivial case.
  if (From == To)
    return;

  // Preserve Debug Info. Only do this if there's a use.
  for (unsigned i = 0, e = From->getNumValues(); i != e; ++i)
    if (From->hasAnyUseOfValue(i)) {
      assert((i < To->getNumValues()) && "Invalid To location");
      transferDbgValues(SDValue(From, i), SDValue(To, i));
    }
  // Preserve extra info.
  copyExtraInfo(From, To);

  // Iterate over just the existing users of From. See the comments in
  // the ReplaceAllUsesWith above.
  SDNode::use_iterator UI = From->use_begin(), UE = From->use_end();
  RAUWUpdateListener Listener(*this, UI, UE);
  while (UI != UE) {
    SDNode *User = *UI;

    // This node is about to morph, remove its old self from the CSE maps.
    RemoveNodeFromCSEMaps(User);

    // A user can appear in a use list multiple times, and when this
    // happens the uses are usually next to each other in the list.
    // To help reduce the number of CSE recomputations, process all
    // the uses of this user that we can find this way.
    do {
      SDUse &Use = UI.getUse();
      ++UI;
      Use.setNode(To);
      if (To->isDivergent() != From->isDivergent())
        updateDivergence(User);
    } while (UI != UE && *UI == User);

    // Now that we have modified User, add it back to the CSE maps.  If it
    // already exists there, recursively merge the results together.
    AddModifiedNodeToCSEMaps(User);
  }

  // If we just RAUW'd the root, take note.
  if (From == getRoot().getNode())
    setRoot(SDValue(To, getRoot().getResNo()));
}

/// ReplaceAllUsesWith - Modify anything using 'From' to use 'To' instead.
/// This can cause recursive merging of nodes in the DAG.
///
/// This version can replace From with any result values.  To must match the
/// number and types of values returned by From.
void SelectionDAG::ReplaceAllUsesWith(SDNode *From, const SDValue *To) {
  if (From->getNumValues() == 1)  // Handle the simple case efficiently.
    return ReplaceAllUsesWith(SDValue(From, 0), To[0]);

  for (unsigned i = 0, e = From->getNumValues(); i != e; ++i) {
    // Preserve Debug Info.
    transferDbgValues(SDValue(From, i), To[i]);
    // Preserve extra info.
    copyExtraInfo(From, To[i].getNode());
  }

  // Iterate over just the existing users of From. See the comments in
  // the ReplaceAllUsesWith above.
  SDNode::use_iterator UI = From->use_begin(), UE = From->use_end();
  RAUWUpdateListener Listener(*this, UI, UE);
  while (UI != UE) {
    SDNode *User = *UI;

    // This node is about to morph, remove its old self from the CSE maps.
    RemoveNodeFromCSEMaps(User);

    // A user can appear in a use list multiple times, and when this happens the
    // uses are usually next to each other in the list.  To help reduce the
    // number of CSE and divergence recomputations, process all the uses of this
    // user that we can find this way.
    bool To_IsDivergent = false;
    do {
      SDUse &Use = UI.getUse();
      const SDValue &ToOp = To[Use.getResNo()];
      ++UI;
      Use.set(ToOp);
      To_IsDivergent |= ToOp->isDivergent();
    } while (UI != UE && *UI == User);

    if (To_IsDivergent != From->isDivergent())
      updateDivergence(User);

    // Now that we have modified User, add it back to the CSE maps.  If it
    // already exists there, recursively merge the results together.
    AddModifiedNodeToCSEMaps(User);
  }

  // If we just RAUW'd the root, take note.
  if (From == getRoot().getNode())
    setRoot(SDValue(To[getRoot().getResNo()]));
}

/// ReplaceAllUsesOfValueWith - Replace any uses of From with To, leaving
/// uses of other values produced by From.getNode() alone.  The Deleted
/// vector is handled the same way as for ReplaceAllUsesWith.
void SelectionDAG::ReplaceAllUsesOfValueWith(SDValue From, SDValue To) {
  // Handle the really simple, really trivial case efficiently.
  if (From == To) return;

  // Handle the simple, trivial case efficiently.
  if (From.getNode()->getNumValues() == 1) {
    ReplaceAllUsesWith(From, To);
    return;
  }

  // Preserve Debug Info.
  transferDbgValues(From, To);
  copyExtraInfo(From.getNode(), To.getNode());

  // Iterate over just the existing users of From. See the comments in
  // the ReplaceAllUsesWith above.
  SDNode::use_iterator UI = From.getNode()->use_begin(),
                       UE = From.getNode()->use_end();
  RAUWUpdateListener Listener(*this, UI, UE);
  while (UI != UE) {
    SDNode *User = *UI;
    bool UserRemovedFromCSEMaps = false;

    // A user can appear in a use list multiple times, and when this
    // happens the uses are usually next to each other in the list.
    // To help reduce the number of CSE recomputations, process all
    // the uses of this user that we can find this way.
    do {
      SDUse &Use = UI.getUse();

      // Skip uses of different values from the same node.
      if (Use.getResNo() != From.getResNo()) {
        ++UI;
        continue;
      }

      // If this node hasn't been modified yet, it's still in the CSE maps,
      // so remove its old self from the CSE maps.
      if (!UserRemovedFromCSEMaps) {
        RemoveNodeFromCSEMaps(User);
        UserRemovedFromCSEMaps = true;
      }

      ++UI;
      Use.set(To);
      if (To->isDivergent() != From->isDivergent())
        updateDivergence(User);
    } while (UI != UE && *UI == User);
    // We are iterating over all uses of the From node, so if a use
    // doesn't use the specific value, no changes are made.
    if (!UserRemovedFromCSEMaps)
      continue;

    // Now that we have modified User, add it back to the CSE maps.  If it
    // already exists there, recursively merge the results together.
    AddModifiedNodeToCSEMaps(User);
  }

  // If we just RAUW'd the root, take note.
  if (From == getRoot())
    setRoot(To);
}

namespace {

/// UseMemo - This class is used by SelectionDAG::ReplaceAllUsesOfValuesWith
/// to record information about a use.
struct UseMemo {
  SDNode *User;
  unsigned Index;
  SDUse *Use;
};

/// operator< - Sort Memos by User.
bool operator<(const UseMemo &L, const UseMemo &R) {
  return (intptr_t)L.User < (intptr_t)R.User;
}

/// RAUOVWUpdateListener - Helper for ReplaceAllUsesOfValuesWith - When the node
/// pointed to by a UseMemo is deleted, set the User to nullptr to indicate that
/// the node already has been taken care of recursively.
class RAUOVWUpdateListener : public SelectionDAG::DAGUpdateListener {
  SmallVector<UseMemo, 4> &Uses;

  void NodeDeleted(SDNode *N, SDNode *E) override {
    for (UseMemo &Memo : Uses)
      if (Memo.User == N)
        Memo.User = nullptr;
  }

public:
  RAUOVWUpdateListener(SelectionDAG &d, SmallVector<UseMemo, 4> &uses)
      : SelectionDAG::DAGUpdateListener(d), Uses(uses) {}
};

} // end anonymous namespace

bool SelectionDAG::calculateDivergence(SDNode *N) {
  if (TLI->isSDNodeAlwaysUniform(N)) {
    assert(!TLI->isSDNodeSourceOfDivergence(N, FLI, DA) &&
           "Conflicting divergence information!");
    return false;
  }
  if (TLI->isSDNodeSourceOfDivergence(N, FLI, DA))
    return true;
  for (const auto &Op : N->ops()) {
    if (Op.Val.getValueType() != MVT::Other && Op.getNode()->isDivergent())
      return true;
  }
  return false;
}

void SelectionDAG::updateDivergence(SDNode *N) {
  SmallVector<SDNode *, 16> Worklist(1, N);
  do {
    N = Worklist.pop_back_val();
    bool IsDivergent = calculateDivergence(N);
    if (N->SDNodeBits.IsDivergent != IsDivergent) {
      N->SDNodeBits.IsDivergent = IsDivergent;
      llvm::append_range(Worklist, N->uses());
    }
  } while (!Worklist.empty());
}

void SelectionDAG::CreateTopologicalOrder(std::vector<SDNode *> &Order) {
  DenseMap<SDNode *, unsigned> Degree;
  Order.reserve(AllNodes.size());
  for (auto &N : allnodes()) {
    unsigned NOps = N.getNumOperands();
    Degree[&N] = NOps;
    if (0 == NOps)
      Order.push_back(&N);
  }
  for (size_t I = 0; I != Order.size(); ++I) {
    SDNode *N = Order[I];
    for (auto *U : N->uses()) {
      unsigned &UnsortedOps = Degree[U];
      if (0 == --UnsortedOps)
        Order.push_back(U);
    }
  }
}
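
// The routine above is Kahn's algorithm: nodes with no remaining unvisited
// operands are appended to Order, and each appended node decrements the
// pending-operand count of its users until every node has been emitted.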

#ifndef NDEBUG
void SelectionDAG::VerifyDAGDivergence() {
  std::vector<SDNode *> TopoOrder;
  CreateTopologicalOrder(TopoOrder);
  for (auto *N : TopoOrder) {
    assert(calculateDivergence(N) == N->isDivergent() &&
           "Divergence bit inconsistency detected");
  }
}
#endif

/// ReplaceAllUsesOfValuesWith - Replace any uses of From with To, leaving
/// uses of other values produced by From.getNode() alone.  The same value
/// may appear in both the From and To list.  The Deleted vector is
/// handled the same way as for ReplaceAllUsesWith.
void SelectionDAG::ReplaceAllUsesOfValuesWith(const SDValue *From,
                                              const SDValue *To,
                                              unsigned Num) {
  // Handle the simple, trivial case efficiently.
  if (Num == 1)
    return ReplaceAllUsesOfValueWith(*From, *To);

  transferDbgValues(*From, *To);
  copyExtraInfo(From->getNode(), To->getNode());

  // Read up all the uses and make records of them. This helps
  // processing new uses that are introduced during the
  // replacement process.
  SmallVector<UseMemo, 4> Uses;
  for (unsigned i = 0; i != Num; ++i) {
    unsigned FromResNo = From[i].getResNo();
    SDNode *FromNode = From[i].getNode();
    for (SDNode::use_iterator UI = FromNode->use_begin(),
         E = FromNode->use_end(); UI != E; ++UI) {
      SDUse &Use = UI.getUse();
      if (Use.getResNo() == FromResNo) {
        UseMemo Memo = { *UI, i, &Use };
        Uses.push_back(Memo);
      }
    }
  }

  // Sort the uses, so that all the uses from a given User are together.
  llvm::sort(Uses);
  RAUOVWUpdateListener Listener(*this, Uses);

  for (unsigned UseIndex = 0, UseIndexEnd = Uses.size();
       UseIndex != UseIndexEnd; ) {
    // We know that this user uses some value of From.  If it is the right
    // value, update it.
    SDNode *User = Uses[UseIndex].User;
    // If the node has been deleted by recursive CSE updates when updating
    // another node, then just skip this entry.
    if (User == nullptr) {
      ++UseIndex;
      continue;
    }

    // This node is about to morph, remove its old self from the CSE maps.
    RemoveNodeFromCSEMaps(User);

    // The Uses array is sorted, so all the uses for a given User
    // are next to each other in the list.
    // To help reduce the number of CSE recomputations, process all
    // the uses of this user that we can find this way.
    do {
      unsigned i = Uses[UseIndex].Index;
      SDUse &Use = *Uses[UseIndex].Use;
      ++UseIndex;

      Use.set(To[i]);
    } while (UseIndex != UseIndexEnd && Uses[UseIndex].User == User);

    // Now that we have modified User, add it back to the CSE maps.  If it
    // already exists there, recursively merge the results together.
    AddModifiedNodeToCSEMaps(User);
  }
}

/// AssignTopologicalOrder - Assign a unique node id for each node in the DAG
/// based on their topological order. It returns the maximum id, leaving the
/// DAG's node list itself sorted into that order.
unsigned SelectionDAG::AssignTopologicalOrder() {
  unsigned DAGSize = 0;

  // SortedPos tracks the progress of the algorithm. Nodes before it are
  // sorted, nodes after it are unsorted. When the algorithm completes
  // it is at the end of the list.
  allnodes_iterator SortedPos = allnodes_begin();

  // Visit all the nodes. Move nodes with no operands to the front of
  // the list immediately. Annotate nodes that do have operands with their
  // operand count. Before we do this, the Node Id fields of the nodes
  // may contain arbitrary values. After, the Node Id fields for nodes
  // before SortedPos will contain the topological sort index, and the
  // Node Id fields for nodes at SortedPos and after will contain the
  // count of outstanding operands.
  for (SDNode &N : llvm::make_early_inc_range(allnodes())) {
    checkForCycles(&N, this);
    unsigned Degree = N.getNumOperands();
    if (Degree == 0) {
      // A node with no operands, add it to the result array immediately.
      N.setNodeId(DAGSize++);
      allnodes_iterator Q(&N);
      if (Q != SortedPos)
        SortedPos = AllNodes.insert(SortedPos, AllNodes.remove(Q));
      assert(SortedPos != AllNodes.end() && "Overran node list");
      ++SortedPos;
    } else {
      // Temporarily use the Node Id as scratch space for the degree count.
      N.setNodeId(Degree);
    }
  }

  // Visit all the nodes. As we iterate, move nodes into sorted order,
  // such that by the time the end is reached all nodes will be sorted.
  for (SDNode &Node : allnodes()) {
    SDNode *N = &Node;
    checkForCycles(N, this);
    // N is in sorted position, so all its uses have one less operand
    // that needs to be sorted.
    for (SDNode *P : N->uses()) {
      unsigned Degree = P->getNodeId();
      assert(Degree != 0 && "Invalid node degree");
      --Degree;
      if (Degree == 0) {
        // All of P's operands are sorted, so P may be sorted now.
        P->setNodeId(DAGSize++);
        if (P->getIterator() != SortedPos)
          SortedPos = AllNodes.insert(SortedPos, AllNodes.remove(P));
        assert(SortedPos != AllNodes.end() && "Overran node list");
        ++SortedPos;
      } else {
        // Update P's outstanding operand count.
        P->setNodeId(Degree);
      }
    }
    if (Node.getIterator() == SortedPos) {
#ifndef NDEBUG
      allnodes_iterator I(N);
      SDNode *S = &*++I;
      dbgs() << "Overran sorted position:\n";
      S->dumprFull(this); dbgs() << "\n";
      dbgs() << "Checking if this is due to cycles\n";
      checkForCycles(this, true);
#endif
      llvm_unreachable(nullptr);
    }
  }

  assert(SortedPos == AllNodes.end() &&
         "Topological sort incomplete!");
  assert(AllNodes.front().getOpcode() == ISD::EntryToken &&
         "First node in topological sort is not the entry token!");
  assert(AllNodes.front().getNodeId() == 0 &&
         "First node in topological sort has non-zero id!");
  assert(AllNodes.front().getNumOperands() == 0 &&
         "First node in topological sort has operands!");
  assert(AllNodes.back().getNodeId() == (int)DAGSize-1 &&
         "Last node in topological sort has unexpected id!");
  assert(AllNodes.back().use_empty() &&
         "Last node in topological sort has users!");
  assert(DAGSize == allnodes_size() && "Node count mismatch!");
  return DAGSize;
}

/// AddDbgValue - Add a dbg_value SDNode. If SD is non-null that means the
/// value is produced by SD.
void SelectionDAG::AddDbgValue(SDDbgValue *DB, bool isParameter) {
  for (SDNode *SD : DB->getSDNodes()) {
    if (!SD)
      continue;
    assert(DbgInfo->getSDDbgValues(SD).empty() || SD->getHasDebugValue());
    SD->setHasDebugValue(true);
  }
  DbgInfo->add(DB, isParameter);
}

void SelectionDAG::AddDbgLabel(SDDbgLabel *DB) { DbgInfo->add(DB); }

SDValue SelectionDAG::makeEquivalentMemoryOrdering(SDValue OldChain,
                                                   SDValue NewMemOpChain) {
  assert(isa<MemSDNode>(NewMemOpChain) && "Expected a memop node");
  assert(NewMemOpChain.getValueType() == MVT::Other && "Expected a token VT");
  // The new memory operation must have the same position as the old load in
  // terms of memory dependency. Create a TokenFactor for the old load and new
  // memory operation and update uses of the old load's output chain to use that
  // TokenFactor.
  if (OldChain == NewMemOpChain || OldChain.use_empty())
    return NewMemOpChain;

  SDValue TokenFactor = getNode(ISD::TokenFactor, SDLoc(OldChain), MVT::Other,
                                OldChain, NewMemOpChain);
  ReplaceAllUsesOfValueWith(OldChain, TokenFactor);
  UpdateNodeOperands(TokenFactor.getNode(), OldChain, NewMemOpChain);
  return TokenFactor;
}

SDValue SelectionDAG::makeEquivalentMemoryOrdering(LoadSDNode *OldLoad,
                                                   SDValue NewMemOp) {
  assert(isa<MemSDNode>(NewMemOp.getNode()) && "Expected a memop node");
  SDValue OldChain = SDValue(OldLoad, 1);
  SDValue NewMemOpChain = NewMemOp.getValue(1);
  return makeEquivalentMemoryOrdering(OldChain, NewMemOpChain);
}
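
// Illustrative sketch: after a combine replaces a load's value with a new
// memory operation, the new memop is chained in so that users of the old
// load's chain still observe a correct memory ordering:
//
//   SDValue NewChain = DAG.makeEquivalentMemoryOrdering(OldLoad, NewMemOp);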

SDValue SelectionDAG::getSymbolFunctionGlobalAddress(SDValue Op,
                                                     Function **OutFunction) {
  assert(isa<ExternalSymbolSDNode>(Op) && "Node should be an ExternalSymbol");

  auto *Symbol = cast<ExternalSymbolSDNode>(Op)->getSymbol();
  auto *Module = MF->getFunction().getParent();
  auto *Function = Module->getFunction(Symbol);

  if (OutFunction != nullptr)
    *OutFunction = Function;

  if (Function != nullptr) {
    auto PtrTy = TLI->getPointerTy(getDataLayout(), Function->getAddressSpace());
    return getGlobalAddress(Function, SDLoc(Op), PtrTy);
  }

  std::string ErrorStr;
  raw_string_ostream ErrorFormatter(ErrorStr);
  ErrorFormatter << "Undefined external symbol ";
  ErrorFormatter << '"' << Symbol << '"';
  report_fatal_error(Twine(ErrorFormatter.str()));
}

//===----------------------------------------------------------------------===//
//                              SDNode Class
//===----------------------------------------------------------------------===//

bool llvm::isNullConstant(SDValue V) {
  ConstantSDNode *Const = dyn_cast<ConstantSDNode>(V);
  return Const != nullptr && Const->isZero();
}

bool llvm::isNullFPConstant(SDValue V) {
  ConstantFPSDNode *Const = dyn_cast<ConstantFPSDNode>(V);
  return Const != nullptr && Const->isZero() && !Const->isNegative();
}

bool llvm::isAllOnesConstant(SDValue V) {
  ConstantSDNode *Const = dyn_cast<ConstantSDNode>(V);
  return Const != nullptr && Const->isAllOnes();
}

bool llvm::isOneConstant(SDValue V) {
  ConstantSDNode *Const = dyn_cast<ConstantSDNode>(V);
  return Const != nullptr && Const->isOne();
}

bool llvm::isMinSignedConstant(SDValue V) {
  ConstantSDNode *Const = dyn_cast<ConstantSDNode>(V);
  return Const != nullptr && Const->isMinSignedValue();
}

bool llvm::isNeutralConstant(unsigned Opcode, SDNodeFlags Flags, SDValue V,
                             unsigned OperandNo) {
  // NOTE: The cases should match with IR's ConstantExpr::getBinOpIdentity().
  // TODO: Target-specific opcodes could be added.
  if (auto *Const = isConstOrConstSplat(V)) {
    switch (Opcode) {
    case ISD::ADD:
    case ISD::OR:
    case ISD::XOR:
    case ISD::UMAX:
      return Const->isZero();
    case ISD::MUL:
      return Const->isOne();
    case ISD::AND:
    case ISD::UMIN:
      return Const->isAllOnes();
    case ISD::SMAX:
      return Const->isMinSignedValue();
    case ISD::SMIN:
      return Const->isMaxSignedValue();
    case ISD::SUB:
    case ISD::SHL:
    case ISD::SRA:
    case ISD::SRL:
      return OperandNo == 1 && Const->isZero();
    case ISD::UDIV:
    case ISD::SDIV:
      return OperandNo == 1 && Const->isOne();
    }
  } else if (auto *ConstFP = isConstOrConstSplatFP(V)) {
    switch (Opcode) {
    case ISD::FADD:
      return ConstFP->isZero() &&
             (Flags.hasNoSignedZeros() || ConstFP->isNegative());
    case ISD::FSUB:
      return OperandNo == 1 && ConstFP->isZero() &&
             (Flags.hasNoSignedZeros() || !ConstFP->isNegative());
    case ISD::FMUL:
      return ConstFP->isExactlyValue(1.0);
    case ISD::FDIV:
      return OperandNo == 1 && ConstFP->isExactlyValue(1.0);
    case ISD::FMINNUM:
    case ISD::FMAXNUM: {
      // Neutral element for fminnum is NaN, Inf or FLT_MAX, depending on FMF.
      EVT VT = V.getValueType();
      const fltSemantics &Semantics = SelectionDAG::EVTToAPFloatSemantics(VT);
      APFloat NeutralAF = !Flags.hasNoNaNs()
                              ? APFloat::getQNaN(Semantics)
                              : !Flags.hasNoInfs()
                                    ? APFloat::getInf(Semantics)
                                    : APFloat::getLargest(Semantics);
      if (Opcode == ISD::FMAXNUM)
        NeutralAF.changeSign();

      return ConstFP->isExactlyValue(NeutralAF);
    }
    }
  }
  return false;
}
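
// Illustrative examples of the identities recognized above:
//   x + 0 == x, x * 1 == x, x & ~0 == x           (either operand position)
//   x - 0 == x, x >> 0 == x, x / 1 == x           (only when OperandNo == 1)
//   fadd x, -0.0 == x; +0.0 also qualifies under nsz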

SDValue llvm::peekThroughBitcasts(SDValue V) {
  while (V.getOpcode() == ISD::BITCAST)
    V = V.getOperand(0);
  return V;
}

SDValue llvm::peekThroughOneUseBitcasts(SDValue V) {
  while (V.getOpcode() == ISD::BITCAST && V.getOperand(0).hasOneUse())
    V = V.getOperand(0);
  return V;
}

SDValue llvm::peekThroughExtractSubvectors(SDValue V) {
  while (V.getOpcode() == ISD::EXTRACT_SUBVECTOR)
    V = V.getOperand(0);
  return V;
}

bool llvm::isBitwiseNot(SDValue V, bool AllowUndefs) {
  if (V.getOpcode() != ISD::XOR)
    return false;
  V = peekThroughBitcasts(V.getOperand(1));
  unsigned NumBits = V.getScalarValueSizeInBits();
  ConstantSDNode *C =
      isConstOrConstSplat(V, AllowUndefs, /*AllowTruncation*/ true);
  return C && (C->getAPIntValue().countTrailingOnes() >= NumBits);
}

ConstantSDNode *llvm::isConstOrConstSplat(SDValue N, bool AllowUndefs,
                                          bool AllowTruncation) {
  EVT VT = N.getValueType();
  APInt DemandedElts = VT.isFixedLengthVector()
                           ? APInt::getAllOnes(VT.getVectorMinNumElements())
                           : APInt(1, 1);
  return isConstOrConstSplat(N, DemandedElts, AllowUndefs, AllowTruncation);
}

ConstantSDNode *llvm::isConstOrConstSplat(SDValue N, const APInt &DemandedElts,
                                          bool AllowUndefs,
                                          bool AllowTruncation) {
  if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(N))
    return CN;

  // SplatVectors can truncate their operands. Ignore that case here unless
  // AllowTruncation is set.
  if (N->getOpcode() == ISD::SPLAT_VECTOR) {
    EVT VecEltVT = N->getValueType(0).getVectorElementType();
    if (auto *CN = dyn_cast<ConstantSDNode>(N->getOperand(0))) {
      EVT CVT = CN->getValueType(0);
      assert(CVT.bitsGE(VecEltVT) && "Illegal splat_vector element extension");
      if (AllowTruncation || CVT == VecEltVT)
        return CN;
    }
  }

  if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(N)) {
    BitVector UndefElements;
    ConstantSDNode *CN = BV->getConstantSplatNode(DemandedElts, &UndefElements);

    // BuildVectors can truncate their operands. Ignore that case here unless
    // AllowTruncation is set.
    // TODO: Look into whether we should allow UndefElements in non-DemandedElts
    if (CN && (UndefElements.none() || AllowUndefs)) {
      EVT CVT = CN->getValueType(0);
      EVT NSVT = N.getValueType().getScalarType();
      assert(CVT.bitsGE(NSVT) && "Illegal build vector element extension");
      if (AllowTruncation || (CVT == NSVT))
        return CN;
    }
  }

  return nullptr;
}
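
// Illustrative sketch: these predicates let one match a scalar constant and
// a splatted vector constant with the same code, e.g.
//
//   if (ConstantSDNode *C = isConstOrConstSplat(N->getOperand(1)))
//     if (C->isOne())
//       /* ... N's RHS is 1, or a build_vector/splat_vector of 1s ... */;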

ConstantFPSDNode *llvm::isConstOrConstSplatFP(SDValue N, bool AllowUndefs) {
  EVT VT = N.getValueType();
  APInt DemandedElts = VT.isFixedLengthVector()
                           ? APInt::getAllOnes(VT.getVectorMinNumElements())
                           : APInt(1, 1);
  return isConstOrConstSplatFP(N, DemandedElts, AllowUndefs);
}

ConstantFPSDNode *llvm::isConstOrConstSplatFP(SDValue N,
                                              const APInt &DemandedElts,
                                              bool AllowUndefs) {
  if (ConstantFPSDNode *CN = dyn_cast<ConstantFPSDNode>(N))
    return CN;

  if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(N)) {
    BitVector UndefElements;
    ConstantFPSDNode *CN =
        BV->getConstantFPSplatNode(DemandedElts, &UndefElements);
    // TODO: Look into whether we should allow UndefElements in non-DemandedElts
    if (CN && (UndefElements.none() || AllowUndefs))
      return CN;
  }

  if (N.getOpcode() == ISD::SPLAT_VECTOR)
    if (ConstantFPSDNode *CN = dyn_cast<ConstantFPSDNode>(N.getOperand(0)))
      return CN;

  return nullptr;
}

bool llvm::isNullOrNullSplat(SDValue N, bool AllowUndefs) {
  // TODO: may want to use peekThroughBitcast() here.
  ConstantSDNode *C =
      isConstOrConstSplat(N, AllowUndefs, /*AllowTruncation=*/true);
  return C && C->isZero();
}

bool llvm::isOneOrOneSplat(SDValue N, bool AllowUndefs) {
  ConstantSDNode *C =
      isConstOrConstSplat(N, AllowUndefs, /*AllowTruncation*/ true);
  return C && C->isOne();
}

bool llvm::isAllOnesOrAllOnesSplat(SDValue N, bool AllowUndefs) {
  N = peekThroughBitcasts(N);
  unsigned BitWidth = N.getScalarValueSizeInBits();
  ConstantSDNode *C = isConstOrConstSplat(N, AllowUndefs);
  return C && C->isAllOnes() && C->getValueSizeInBits(0) == BitWidth;
}

HandleSDNode::~HandleSDNode() {
  DropOperands();
}

GlobalAddressSDNode::GlobalAddressSDNode(unsigned Opc, unsigned Order,
                                         const DebugLoc &DL,
                                         const GlobalValue *GA, EVT VT,
                                         int64_t o, unsigned TF)
    : SDNode(Opc, Order, DL, getSDVTList(VT)), Offset(o), TargetFlags(TF) {
  TheGlobal = GA;
}

AddrSpaceCastSDNode::AddrSpaceCastSDNode(unsigned Order, const DebugLoc &dl,
                                         EVT VT, unsigned SrcAS,
                                         unsigned DestAS)
    : SDNode(ISD::ADDRSPACECAST, Order, dl, getSDVTList(VT)),
      SrcAddrSpace(SrcAS), DestAddrSpace(DestAS) {}

MemSDNode::MemSDNode(unsigned Opc, unsigned Order, const DebugLoc &dl,
                     SDVTList VTs, EVT memvt, MachineMemOperand *mmo)
    : SDNode(Opc, Order, dl, VTs), MemoryVT(memvt), MMO(mmo) {
  MemSDNodeBits.IsVolatile = MMO->isVolatile();
  MemSDNodeBits.IsNonTemporal = MMO->isNonTemporal();
  MemSDNodeBits.IsDereferenceable = MMO->isDereferenceable();
  MemSDNodeBits.IsInvariant = MMO->isInvariant();

  // We check here that the size of the memory operand fits within the size of
  // the MMO. This is because the MMO might indicate only a possible address
  // range instead of specifying the affected memory addresses precisely.
  // TODO: Make MachineMemOperands aware of scalable vectors.
  assert(memvt.getStoreSize().getKnownMinValue() <= MMO->getSize() &&
         "Size mismatch!");
}

/// Profile - Gather unique data for the node.
///
void SDNode::Profile(FoldingSetNodeID &ID) const {
  AddNodeIDNode(ID, this);
}

namespace {

  struct EVTArray {
    std::vector<EVT> VTs;

    EVTArray() {
      VTs.reserve(MVT::VALUETYPE_SIZE);
      for (unsigned i = 0; i < MVT::VALUETYPE_SIZE; ++i)
        VTs.push_back(MVT((MVT::SimpleValueType)i));
    }
  };

} // end anonymous namespace

/// getValueTypeList - Return a pointer to the specified value type.
///
const EVT *SDNode::getValueTypeList(EVT VT) {
  static std::set<EVT, EVT::compareRawBits> EVTs;
  static EVTArray SimpleVTArray;
  static sys::SmartMutex<true> VTMutex;

  if (VT.isExtended()) {
    sys::SmartScopedLock<true> Lock(VTMutex);
    return &(*EVTs.insert(VT).first);
  }
  assert(VT.getSimpleVT() < MVT::VALUETYPE_SIZE && "Value type out of range!");
  return &SimpleVTArray.VTs[VT.getSimpleVT().SimpleTy];
}

/// hasNUsesOfValue - Return true if there are exactly NUSES uses of the
/// indicated value.  This method ignores uses of other values defined by this
/// operation.
bool SDNode::hasNUsesOfValue(unsigned NUses, unsigned Value) const {
  assert(Value < getNumValues() && "Bad value!");

  // TODO: Only iterate over uses of a given value of the node
  for (SDNode::use_iterator UI = use_begin(), E = use_end(); UI != E; ++UI) {
    if (UI.getUse().getResNo() == Value) {
      if (NUses == 0)
        return false;
      --NUses;
    }
  }

  // Found exactly the right number of uses?
  return NUses == 0;
}

/// hasAnyUseOfValue - Return true if there is any use of the indicated
/// value. This method ignores uses of other values defined by this operation.
bool SDNode::hasAnyUseOfValue(unsigned Value) const {
  assert(Value < getNumValues() && "Bad value!");

  for (SDNode::use_iterator UI = use_begin(), E = use_end(); UI != E; ++UI)
    if (UI.getUse().getResNo() == Value)
      return true;

  return false;
}

/// isOnlyUserOf - Return true if this node is the only use of N.
bool SDNode::isOnlyUserOf(const SDNode *N) const {
  bool Seen = false;
  for (const SDNode *User : N->uses()) {
    if (User == this)
      Seen = true;
    else
      return false;
  }

  return Seen;
}

/// Return true if the only users of N are contained in Nodes.
bool SDNode::areOnlyUsersOf(ArrayRef<const SDNode *> Nodes, const SDNode *N) {
  bool Seen = false;
  for (const SDNode *User : N->uses()) {
    if (llvm::is_contained(Nodes, User))
      Seen = true;
    else
      return false;
  }

  return Seen;
}

/// isOperandOf - Return true if this node is an operand of N.
bool SDValue::isOperandOf(const SDNode *N) const {
  return is_contained(N->op_values(), *this);
}

bool SDNode::isOperandOf(const SDNode *N) const {
  return any_of(N->op_values(),
                [this](SDValue Op) { return this == Op.getNode(); });
}

/// reachesChainWithoutSideEffects - Return true if this operand (which must
/// be a chain) reaches the specified operand without crossing any
/// side-effecting instructions on any chain path. In practice, this looks
/// through token factors and non-volatile loads. In order to remain efficient,
/// this only looks a couple of nodes in; it does not do an exhaustive search.
///
/// Note that we only need to examine chains when we're searching for
/// side-effects; SelectionDAG requires that all side-effects are represented
/// by chains, even if another operand would force a specific ordering. This
/// constraint is necessary to allow transformations like splitting loads.
bool SDValue::reachesChainWithoutSideEffects(SDValue Dest,
                                             unsigned Depth) const {
  if (*this == Dest) return true;

  // Don't search too deeply; we just want to be able to see through
  // TokenFactors etc.
  if (Depth == 0) return false;

  // If this is a token factor, all inputs to the TF happen in parallel.
  if (getOpcode() == ISD::TokenFactor) {
    // First, try a shallow search.
    if (is_contained((*this)->ops(), Dest)) {
      // We found the chain we want as an operand of this TokenFactor.
      // Essentially, we reach the chain without side-effects if we could
      // serialize the TokenFactor into a simple chain of operations with
      // Dest as the last operation. This is automatically true if the
      // chain has one use: there are no other ordering constraints.
      // If the chain has more than one use, we give up: some other
      // use of Dest might force a side-effect between Dest and the current
      // node.
      if (Dest.hasOneUse())
        return true;
    }
    // Next, try a deep search: check whether every operand of the TokenFactor
    // reaches Dest.
    return llvm::all_of((*this)->ops(), [=](SDValue Op) {
      return Op.reachesChainWithoutSideEffects(Dest, Depth - 1);
    });
  }

  // Loads don't have side effects; look through them.
  if (LoadSDNode *Ld = dyn_cast<LoadSDNode>(*this)) {
    if (Ld->isUnordered())
      return Ld->getChain().reachesChainWithoutSideEffects(Dest, Depth - 1);
  }
  return false;
}
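
// As a worked example, consider the hypothetical DAG fragment:
//   t0: ch = EntryToken
//   t1: v,ch = load<(load 4)> t0, ...   (simple, unordered)
//   t2: ch = TokenFactor t0, t1:1
// Here t2.reachesChainWithoutSideEffects(t0) is true: the shallow search
// fails because t0 has more than one use, but the deep search succeeds since
// every TokenFactor operand reaches t0 (the unordered load is transparent).
// Had t1 been a volatile load or a store, the walk would have returned false.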

bool SDNode::hasPredecessor(const SDNode *N) const {
  SmallPtrSet<const SDNode *, 32> Visited;
  SmallVector<const SDNode *, 16> Worklist;
  Worklist.push_back(this);
  return hasPredecessorHelper(N, Visited, Worklist);
}

void SDNode::intersectFlagsWith(const SDNodeFlags Flags) {
  this->Flags.intersectWith(Flags);
}

SDValue
SelectionDAG::matchBinOpReduction(SDNode *Extract, ISD::NodeType &BinOp,
                                  ArrayRef<ISD::NodeType> CandidateBinOps,
                                  bool AllowPartials) {
  // The pattern must end in an extract from index 0.
  if (Extract->getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
      !isNullConstant(Extract->getOperand(1)))
    return SDValue();

  // Match against one of the candidate binary ops.
  SDValue Op = Extract->getOperand(0);
  if (llvm::none_of(CandidateBinOps, [Op](ISD::NodeType BinOp) {
        return Op.getOpcode() == unsigned(BinOp);
      }))
    return SDValue();

  // Floating-point reductions may require relaxed constraints on the final
  // step of the reduction because they may reorder intermediate operations.
  unsigned CandidateBinOp = Op.getOpcode();
  if (Op.getValueType().isFloatingPoint()) {
    SDNodeFlags Flags = Op->getFlags();
    switch (CandidateBinOp) {
    case ISD::FADD:
      if (!Flags.hasNoSignedZeros() || !Flags.hasAllowReassociation())
        return SDValue();
      break;
    default:
      llvm_unreachable("Unhandled FP opcode for binop reduction");
    }
  }

  // Invoked when matching fails: check whether we completed enough stages
  // that a partial reduction from a subvector is still possible.
  auto PartialReduction = [&](SDValue Op, unsigned NumSubElts) {
    if (!AllowPartials || !Op)
      return SDValue();
    EVT OpVT = Op.getValueType();
    EVT OpSVT = OpVT.getScalarType();
    EVT SubVT = EVT::getVectorVT(*getContext(), OpSVT, NumSubElts);
    if (!TLI->isExtractSubvectorCheap(SubVT, OpVT, 0))
      return SDValue();
    BinOp = (ISD::NodeType)CandidateBinOp;
    return getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(Op), SubVT, Op,
                   getVectorIdxConstant(0, SDLoc(Op)));
  };

  // At each stage, we're looking for something that looks like:
  // %s = shufflevector <8 x i32> %op, <8 x i32> undef,
  //                    <8 x i32> <i32 2, i32 3, i32 undef, i32 undef,
  //                               i32 undef, i32 undef, i32 undef, i32 undef>
  // %a = binop <8 x i32> %op, %s
  // Where the mask changes according to the stage. E.g. for a 3-stage pyramid,
  // we expect something like:
  // <4,5,6,7,u,u,u,u>
  // <2,3,u,u,u,u,u,u>
  // <1,u,u,u,u,u,u,u>
  // While a partial reduction match would be:
  // <2,3,u,u,u,u,u,u>
  // <1,u,u,u,u,u,u,u>
  unsigned Stages = Log2_32(Op.getValueType().getVectorNumElements());
  SDValue PrevOp;
  for (unsigned i = 0; i < Stages; ++i) {
    unsigned MaskEnd = (1 << i);

    if (Op.getOpcode() != CandidateBinOp)
      return PartialReduction(PrevOp, MaskEnd);

    SDValue Op0 = Op.getOperand(0);
    SDValue Op1 = Op.getOperand(1);

    ShuffleVectorSDNode *Shuffle = dyn_cast<ShuffleVectorSDNode>(Op0);
    if (Shuffle) {
      Op = Op1;
    } else {
      Shuffle = dyn_cast<ShuffleVectorSDNode>(Op1);
      Op = Op0;
    }

    // The first operand of the shuffle should be the same as the other operand
    // of the binop.
    if (!Shuffle || Shuffle->getOperand(0) != Op)
      return PartialReduction(PrevOp, MaskEnd);

    // Verify the shuffle has the expected (at this stage of the pyramid) mask.
    for (int Index = 0; Index < (int)MaskEnd; ++Index)
      if (Shuffle->getMaskElt(Index) != (int)(MaskEnd + Index))
        return PartialReduction(PrevOp, MaskEnd);

    PrevOp = Op;
  }

  // Handle subvector reductions, which tend to appear after the shuffle
  // reduction stages.
  while (Op.getOpcode() == CandidateBinOp) {
    unsigned NumElts = Op.getValueType().getVectorNumElements();
    SDValue Op0 = Op.getOperand(0);
    SDValue Op1 = Op.getOperand(1);
    if (Op0.getOpcode() != ISD::EXTRACT_SUBVECTOR ||
        Op1.getOpcode() != ISD::EXTRACT_SUBVECTOR ||
        Op0.getOperand(0) != Op1.getOperand(0))
      break;
    SDValue Src = Op0.getOperand(0);
    unsigned NumSrcElts = Src.getValueType().getVectorNumElements();
    if (NumSrcElts != (2 * NumElts))
      break;
    if (!(Op0.getConstantOperandAPInt(1) == 0 &&
          Op1.getConstantOperandAPInt(1) == NumElts) &&
        !(Op1.getConstantOperandAPInt(1) == 0 &&
          Op0.getConstantOperandAPInt(1) == NumElts))
      break;
    Op = Src;
  }

  BinOp = (ISD::NodeType)CandidateBinOp;
  return Op;
}
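
// A hypothetical caller's sketch (the names and opcode list below are
// illustrative, not from any particular target): a target DAG combine might
// try
//   ISD::NodeType BinOp;
//   if (SDValue Rdx = DAG.matchBinOpReduction(Extract, BinOp, {ISD::ADD}))
//     /* lower Rdx with the target's horizontal add, using the opcode
//        reported back in BinOp */;
// On success, Rdx is the widest still-unreduced vector found above Extract.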

SDValue SelectionDAG::UnrollVectorOp(SDNode *N, unsigned ResNE) {
  assert(N->getNumValues() == 1 &&
         "Can't unroll a vector with multiple results!");

  EVT VT = N->getValueType(0);
  unsigned NE = VT.getVectorNumElements();
  EVT EltVT = VT.getVectorElementType();
  SDLoc dl(N);

  SmallVector<SDValue, 8> Scalars;
  SmallVector<SDValue, 4> Operands(N->getNumOperands());

  // If ResNE is 0, fully unroll the vector op.
  if (ResNE == 0)
    ResNE = NE;
  else if (NE > ResNE)
    NE = ResNE;

  unsigned i;
  for (i = 0; i != NE; ++i) {
    for (unsigned j = 0, e = N->getNumOperands(); j != e; ++j) {
      SDValue Operand = N->getOperand(j);
      EVT OperandVT = Operand.getValueType();
      if (OperandVT.isVector()) {
        // A vector operand; extract a single element.
        EVT OperandEltVT = OperandVT.getVectorElementType();
        Operands[j] = getNode(ISD::EXTRACT_VECTOR_ELT, dl, OperandEltVT,
                              Operand, getVectorIdxConstant(i, dl));
      } else {
        // A scalar operand; just use it as is.
        Operands[j] = Operand;
      }
    }

    switch (N->getOpcode()) {
    default: {
      Scalars.push_back(getNode(N->getOpcode(), dl, EltVT, Operands,
                                N->getFlags()));
      break;
    }
    case ISD::VSELECT:
      Scalars.push_back(getNode(ISD::SELECT, dl, EltVT, Operands));
      break;
    case ISD::SHL:
    case ISD::SRA:
    case ISD::SRL:
    case ISD::ROTL:
    case ISD::ROTR:
      Scalars.push_back(getNode(N->getOpcode(), dl, EltVT, Operands[0],
                               getShiftAmountOperand(Operands[0].getValueType(),
                                                     Operands[1])));
      break;
    case ISD::SIGN_EXTEND_INREG: {
      EVT ExtVT = cast<VTSDNode>(Operands[1])->getVT().getVectorElementType();
      Scalars.push_back(getNode(N->getOpcode(), dl, EltVT,
                                Operands[0],
                                getValueType(ExtVT)));
      break;
    }
    }
  }

  for (; i < ResNE; ++i)
    Scalars.push_back(getUNDEF(EltVT));

  EVT VecVT = EVT::getVectorVT(*getContext(), EltVT, ResNE);
  return getBuildVector(VecVT, dl, Scalars);
}
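
// For example (hypothetical DAG values), unrolling t3: v4i32 = add t1, t2
// with ResNE == 0 produces four scalar adds,
//   i32 add (extract_vector_elt t1, i), (extract_vector_elt t2, i)
// for i in [0,4), gathered into a single v4i32 BUILD_VECTOR. With ResNE == 8,
// the last four lanes of the v8i32 result would instead be filled with undef.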

std::pair<SDValue, SDValue> SelectionDAG::UnrollVectorOverflowOp(
    SDNode *N, unsigned ResNE) {
  unsigned Opcode = N->getOpcode();
  assert((Opcode == ISD::UADDO || Opcode == ISD::SADDO ||
          Opcode == ISD::USUBO || Opcode == ISD::SSUBO ||
          Opcode == ISD::UMULO || Opcode == ISD::SMULO) &&
         "Expected an overflow opcode");

  EVT ResVT = N->getValueType(0);
  EVT OvVT = N->getValueType(1);
  EVT ResEltVT = ResVT.getVectorElementType();
  EVT OvEltVT = OvVT.getVectorElementType();
  SDLoc dl(N);

  // If ResNE is 0, fully unroll the vector op.
  unsigned NE = ResVT.getVectorNumElements();
  if (ResNE == 0)
    ResNE = NE;
  else if (NE > ResNE)
    NE = ResNE;

  SmallVector<SDValue, 8> LHSScalars;
  SmallVector<SDValue, 8> RHSScalars;
  ExtractVectorElements(N->getOperand(0), LHSScalars, 0, NE);
  ExtractVectorElements(N->getOperand(1), RHSScalars, 0, NE);

  EVT SVT = TLI->getSetCCResultType(getDataLayout(), *getContext(), ResEltVT);
  SDVTList VTs = getVTList(ResEltVT, SVT);
  SmallVector<SDValue, 8> ResScalars;
  SmallVector<SDValue, 8> OvScalars;
  for (unsigned i = 0; i < NE; ++i) {
    SDValue Res = getNode(Opcode, dl, VTs, LHSScalars[i], RHSScalars[i]);
    SDValue Ov =
        getSelect(dl, OvEltVT, Res.getValue(1),
                  getBoolConstant(true, dl, OvEltVT, ResVT),
                  getConstant(0, dl, OvEltVT));

    ResScalars.push_back(Res);
    OvScalars.push_back(Ov);
  }

  ResScalars.append(ResNE - NE, getUNDEF(ResEltVT));
  OvScalars.append(ResNE - NE, getUNDEF(OvEltVT));

  EVT NewResVT = EVT::getVectorVT(*getContext(), ResEltVT, ResNE);
  EVT NewOvVT = EVT::getVectorVT(*getContext(), OvEltVT, ResNE);
  return std::make_pair(getBuildVector(NewResVT, dl, ResScalars),
                        getBuildVector(NewOvVT, dl, OvScalars));
}

bool SelectionDAG::areNonVolatileConsecutiveLoads(LoadSDNode *LD,
                                                  LoadSDNode *Base,
                                                  unsigned Bytes,
                                                  int Dist) const {
  if (LD->isVolatile() || Base->isVolatile())
    return false;
  // TODO: probably too restrictive for atomics, revisit
  if (!LD->isSimple())
    return false;
  if (LD->isIndexed() || Base->isIndexed())
    return false;
  if (LD->getChain() != Base->getChain())
    return false;
  EVT VT = LD->getMemoryVT();
  if (VT.getSizeInBits() / 8 != Bytes)
    return false;

  auto BaseLocDecomp = BaseIndexOffset::match(Base, *this);
  auto LocDecomp = BaseIndexOffset::match(LD, *this);

  int64_t Offset = 0;
  if (BaseLocDecomp.equalBaseIndex(LocDecomp, *this, Offset))
    return (Dist * Bytes == Offset);
  return false;
}
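
// For example, two i32 loads from %p and %p + 4 on the same chain satisfy
// areNonVolatileConsecutiveLoads(LD, Base, /*Bytes=*/4, /*Dist=*/1): the
// decomposed addresses share a base and differ by Dist * Bytes == 4 bytes.
// A negative Dist would check that LD sits before Base in memory instead.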

/// InferPtrAlign - Infer alignment of a load / store address. Return
/// std::nullopt if it cannot be inferred.
MaybeAlign SelectionDAG::InferPtrAlign(SDValue Ptr) const {
  // If this is a GlobalAddress + cst, return the alignment.
  const GlobalValue *GV = nullptr;
  int64_t GVOffset = 0;
  if (TLI->isGAPlusOffset(Ptr.getNode(), GV, GVOffset)) {
    unsigned PtrWidth = getDataLayout().getPointerTypeSizeInBits(GV->getType());
    KnownBits Known(PtrWidth);
    llvm::computeKnownBits(GV, Known, getDataLayout());
    unsigned AlignBits = Known.countMinTrailingZeros();
    if (AlignBits)
      return commonAlignment(Align(1ull << std::min(31U, AlignBits)), GVOffset);
  }

  // If this is a direct reference to a stack slot, use information about the
  // stack slot's alignment.
  int FrameIdx = INT_MIN;
  int64_t FrameOffset = 0;
  if (FrameIndexSDNode *FI = dyn_cast<FrameIndexSDNode>(Ptr)) {
    FrameIdx = FI->getIndex();
  } else if (isBaseWithConstantOffset(Ptr) &&
             isa<FrameIndexSDNode>(Ptr.getOperand(0))) {
    // Handle FI+Cst
    FrameIdx = cast<FrameIndexSDNode>(Ptr.getOperand(0))->getIndex();
    FrameOffset = Ptr.getConstantOperandVal(1);
  }

  if (FrameIdx != INT_MIN) {
    const MachineFrameInfo &MFI = getMachineFunction().getFrameInfo();
    return commonAlignment(MFI.getObjectAlign(FrameIdx), FrameOffset);
  }

  return std::nullopt;
}
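
// For example, if the known-zero low bits of a global's address prove it is
// 16-byte aligned and GVOffset is 4, commonAlignment(Align(16), 4) yields
// Align(4); a frame-index address is handled the same way using the stack
// slot's recorded alignment and the constant offset, if any.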

/// GetSplitDestVTs - Compute the VTs needed for the low/hi parts of a type
/// which is split (or expanded) into two not necessarily identical pieces.
std::pair<EVT, EVT> SelectionDAG::GetSplitDestVTs(const EVT &VT) const {
  // Currently all types are split in half.
  EVT LoVT, HiVT;
  if (!VT.isVector())
    LoVT = HiVT = TLI->getTypeToTransformTo(*getContext(), VT);
  else
    LoVT = HiVT = VT.getHalfNumVectorElementsVT(*getContext());

  return std::make_pair(LoVT, HiVT);
}
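
// E.g. v8i32 splits into (v4i32, v4i32), while a scalar i64 on a 32-bit
// target expands into (i32, i32) via getTypeToTransformTo.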

/// GetDependentSplitDestVTs - Compute the VTs needed for the low/hi parts of a
/// type, dependent on an enveloping VT that has been split into two identical
/// pieces. Sets the HiIsEmpty flag when the hi type has zero storage size.
std::pair<EVT, EVT>
SelectionDAG::GetDependentSplitDestVTs(const EVT &VT, const EVT &EnvVT,
                                       bool *HiIsEmpty) const {
  EVT EltTp = VT.getVectorElementType();
  // Examples:
  //   custom VL=8  with enveloping VL=8/8 yields 8/0 (hi empty)
  //   custom VL=9  with enveloping VL=8/8 yields 8/1
  //   custom VL=10 with enveloping VL=8/8 yields 8/2
  //   etc.
  ElementCount VTNumElts = VT.getVectorElementCount();
  ElementCount EnvNumElts = EnvVT.getVectorElementCount();
  assert(VTNumElts.isScalable() == EnvNumElts.isScalable() &&
         "Mixing fixed width and scalable vectors when enveloping a type");
  EVT LoVT, HiVT;
  if (VTNumElts.getKnownMinValue() > EnvNumElts.getKnownMinValue()) {
    LoVT = EVT::getVectorVT(*getContext(), EltTp, EnvNumElts);
    HiVT = EVT::getVectorVT(*getContext(), EltTp, VTNumElts - EnvNumElts);
    *HiIsEmpty = false;
  } else {
    // Flag that the hi type has zero storage size, but return a split envelope
    // type (this would be easier if vector types with zero elements were
    // allowed).
    LoVT = EVT::getVectorVT(*getContext(), EltTp, VTNumElts);
    HiVT = EVT::getVectorVT(*getContext(), EltTp, EnvNumElts);
    *HiIsEmpty = true;
  }
  return std::make_pair(LoVT, HiVT);
}

/// SplitVector - Split the vector with EXTRACT_SUBVECTOR and return the
/// low/high part.
std::pair<SDValue, SDValue>
SelectionDAG::SplitVector(const SDValue &N, const SDLoc &DL, const EVT &LoVT,
                          const EVT &HiVT) {
  assert(LoVT.isScalableVector() == HiVT.isScalableVector() &&
         LoVT.isScalableVector() == N.getValueType().isScalableVector() &&
         "Splitting vector with an invalid mixture of fixed and scalable "
         "vector types");
  assert(LoVT.getVectorMinNumElements() + HiVT.getVectorMinNumElements() <=
             N.getValueType().getVectorMinNumElements() &&
         "More vector elements requested than available!");
  SDValue Lo, Hi;
  Lo =
      getNode(ISD::EXTRACT_SUBVECTOR, DL, LoVT, N, getVectorIdxConstant(0, DL));
  // For scalable vectors it is safe to use LoVT.getVectorMinNumElements()
  // (rather than having to use ElementCount), because EXTRACT_SUBVECTOR scales
  // IDX with the runtime scaling factor of the result vector type. For
  // fixed-width result vectors, that runtime scaling factor is 1.
  Hi = getNode(ISD::EXTRACT_SUBVECTOR, DL, HiVT, N,
               getVectorIdxConstant(LoVT.getVectorMinNumElements(), DL));
  return std::make_pair(Lo, Hi);
}

std::pair<SDValue, SDValue> SelectionDAG::SplitEVL(SDValue N, EVT VecVT,
                                                   const SDLoc &DL) {
  // Split the vector length parameter.
  // %evl -> umin(%evl, %halfnumelts) and usubsat(%evl, %halfnumelts).
  EVT VT = N.getValueType();
  assert(VecVT.getVectorElementCount().isKnownEven() &&
         "Expecting the vector to have an even number of elements");
  unsigned HalfMinNumElts = VecVT.getVectorMinNumElements() / 2;
  SDValue HalfNumElts =
      VecVT.isFixedLengthVector()
          ? getConstant(HalfMinNumElts, DL, VT)
          : getVScale(DL, VT, APInt(VT.getScalarSizeInBits(), HalfMinNumElts));
  SDValue Lo = getNode(ISD::UMIN, DL, VT, N, HalfNumElts);
  SDValue Hi = getNode(ISD::USUBSAT, DL, VT, N, HalfNumElts);
  return std::make_pair(Lo, Hi);
}
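
// Worked example: for VecVT = v8i32 (so HalfNumElts == 4) and %evl == 5, the
// low half gets umin(5, 4) == 4 active lanes and the high half gets
// usubsat(5, 4) == 1; an %evl of 3 yields 3 and 0, since usubsat clamps the
// subtraction at zero.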

/// Widen the vector up to the next power of two using INSERT_SUBVECTOR.
SDValue SelectionDAG::WidenVector(const SDValue &N, const SDLoc &DL) {
  EVT VT = N.getValueType();
  EVT WideVT = EVT::getVectorVT(*getContext(), VT.getVectorElementType(),
                                NextPowerOf2(VT.getVectorNumElements()));
  return getNode(ISD::INSERT_SUBVECTOR, DL, WideVT, getUNDEF(WideVT), N,
                 getVectorIdxConstant(0, DL));
}
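
// E.g. a v3i32 value is widened to v4i32 by inserting it at index 0 of an
// undef v4i32, leaving the extra lane undefined.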

void SelectionDAG::ExtractVectorElements(SDValue Op,
                                         SmallVectorImpl<SDValue> &Args,
                                         unsigned Start, unsigned Count,
                                         EVT EltVT) {
  EVT VT = Op.getValueType();
  if (Count == 0)
    Count = VT.getVectorNumElements();
  if (EltVT == EVT())
    EltVT = VT.getVectorElementType();
  SDLoc SL(Op);
  for (unsigned i = Start, e = Start + Count; i != e; ++i) {
    Args.push_back(getNode(ISD::EXTRACT_VECTOR_ELT, SL, EltVT, Op,
                           getVectorIdxConstant(i, SL)));
  }
}

// getAddressSpace - Return the address space this GlobalAddress belongs to.
unsigned GlobalAddressSDNode::getAddressSpace() const {
  return getGlobal()->getType()->getAddressSpace();
}

Type *ConstantPoolSDNode::getType() const {
  if (isMachineConstantPoolEntry())
    return Val.MachineCPVal->getType();
  return Val.ConstVal->getType();
}

bool BuildVectorSDNode::isConstantSplat(APInt &SplatValue, APInt &SplatUndef,
                                        unsigned &SplatBitSize,
                                        bool &HasAnyUndefs,
                                        unsigned MinSplatBits,
                                        bool IsBigEndian) const {
  EVT VT = getValueType(0);
  assert(VT.isVector() && "Expected a vector type");
  unsigned VecWidth = VT.getSizeInBits();
  if (MinSplatBits > VecWidth)
    return false;

  // FIXME: The widths are based on this node's type, but build vectors can
  // truncate their operands.
  SplatValue = APInt(VecWidth, 0);
  SplatUndef = APInt(VecWidth, 0);

  // Get the bits. Bits with undefined values (when the corresponding element
  // of the vector is an ISD::UNDEF value) are set in SplatUndef and cleared
  // in SplatValue. If any of the values are not constant, give up and return
  // false.
  unsigned NumOps = getNumOperands();
  assert(NumOps > 0 && "isConstantSplat has 0-size build vector");
  unsigned EltWidth = VT.getScalarSizeInBits();

  for (unsigned j = 0; j < NumOps; ++j) {
    unsigned i = IsBigEndian ? NumOps - 1 - j : j;
    SDValue OpVal = getOperand(i);
    unsigned BitPos = j * EltWidth;

    if (OpVal.isUndef())
      SplatUndef.setBits(BitPos, BitPos + EltWidth);
    else if (auto *CN = dyn_cast<ConstantSDNode>(OpVal))
      SplatValue.insertBits(CN->getAPIntValue().zextOrTrunc(EltWidth), BitPos);
    else if (auto *CN = dyn_cast<ConstantFPSDNode>(OpVal))
      SplatValue.insertBits(CN->getValueAPF().bitcastToAPInt(), BitPos);
    else
      return false;
  }

  // The build_vector is all constants or undefs. Find the smallest element
  // size that splats the vector.
  HasAnyUndefs = (SplatUndef != 0);

  // FIXME: This does not work for vectors with elements narrower than 8 bits.
  while (VecWidth > 8) {
    unsigned HalfSize = VecWidth / 2;
    APInt HighValue = SplatValue.extractBits(HalfSize, HalfSize);
    APInt LowValue = SplatValue.extractBits(HalfSize, 0);
    APInt HighUndef = SplatUndef.extractBits(HalfSize, HalfSize);
    APInt LowUndef = SplatUndef.extractBits(HalfSize, 0);

    // If the two halves do not match (ignoring undef bits), stop here.
    if ((HighValue & ~LowUndef) != (LowValue & ~HighUndef) ||
        MinSplatBits > HalfSize)
      break;

    SplatValue = HighValue | LowValue;
    SplatUndef = HighUndef & LowUndef;

    VecWidth = HalfSize;
  }

  SplatBitSize = VecWidth;
  return true;
}
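
// For example, the build_vector <i32 0x01010101, i32 0x01010101> folds all
// the way down to SplatBitSize == 8 with SplatValue == 0x01, because every
// halving step finds two identical halves, whereas <i32 1, i32 1> stops at
// SplatBitSize == 32 since the two 16-bit halves of 0x00000001 differ.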

SDValue BuildVectorSDNode::getSplatValue(const APInt &DemandedElts,
                                         BitVector *UndefElements) const {
  unsigned NumOps = getNumOperands();
  if (UndefElements) {
    UndefElements->clear();
    UndefElements->resize(NumOps);
  }
  assert(NumOps == DemandedElts.getBitWidth() && "Unexpected vector size");
  if (!DemandedElts)
    return SDValue();
  SDValue Splatted;
  for (unsigned i = 0; i != NumOps; ++i) {
    if (!DemandedElts[i])
      continue;
    SDValue Op = getOperand(i);
    if (Op.isUndef()) {
      if (UndefElements)
        (*UndefElements)[i] = true;
    } else if (!Splatted) {
      Splatted = Op;
    } else if (Splatted != Op) {
      return SDValue();
    }
  }

  if (!Splatted) {
    unsigned FirstDemandedIdx = DemandedElts.countTrailingZeros();
    assert(getOperand(FirstDemandedIdx).isUndef() &&
           "Can only have a splat without a constant for all undefs.");
    return getOperand(FirstDemandedIdx);
  }

  return Splatted;
}

SDValue BuildVectorSDNode::getSplatValue(BitVector *UndefElements) const {
  APInt DemandedElts = APInt::getAllOnes(getNumOperands());
  return getSplatValue(DemandedElts, UndefElements);
}

bool BuildVectorSDNode::getRepeatedSequence(const APInt &DemandedElts,
                                            SmallVectorImpl<SDValue> &Sequence,
                                            BitVector *UndefElements) const {
  unsigned NumOps = getNumOperands();
  Sequence.clear();
  if (UndefElements) {
    UndefElements->clear();
    UndefElements->resize(NumOps);
  }
  assert(NumOps == DemandedElts.getBitWidth() && "Unexpected vector size");
  if (!DemandedElts || NumOps < 2 || !isPowerOf2_32(NumOps))
    return false;

  // Set the undefs even if we don't find a sequence (like getSplatValue).
  if (UndefElements)
    for (unsigned I = 0; I != NumOps; ++I)
      if (DemandedElts[I] && getOperand(I).isUndef())
        (*UndefElements)[I] = true;

  // Iteratively widen the sequence length looking for repetitions.
  for (unsigned SeqLen = 1; SeqLen < NumOps; SeqLen *= 2) {
    Sequence.append(SeqLen, SDValue());
    for (unsigned I = 0; I != NumOps; ++I) {
      if (!DemandedElts[I])
        continue;
      SDValue &SeqOp = Sequence[I % SeqLen];
      SDValue Op = getOperand(I);
      if (Op.isUndef()) {
        if (!SeqOp)
          SeqOp = Op;
        continue;
      }
      if (SeqOp && !SeqOp.isUndef() && SeqOp != Op) {
        Sequence.clear();
        break;
      }
      SeqOp = Op;
    }
    if (!Sequence.empty())
      return true;
  }

  assert(Sequence.empty() && "Failed to empty non-repeating sequence pattern");
  return false;
}
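
// For example, the build_vector <i32 0, i32 1, i32 0, i32 1> yields the
// repeated sequence {0, 1} (SeqLen == 2) after the SeqLen == 1 splat attempt
// fails, while <i32 0, i32 1, i32 1, i32 0> has no power-of-2 repetition and
// returns false.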

bool BuildVectorSDNode::getRepeatedSequence(SmallVectorImpl<SDValue> &Sequence,
                                            BitVector *UndefElements) const {
  APInt DemandedElts = APInt::getAllOnes(getNumOperands());
  return getRepeatedSequence(DemandedElts, Sequence, UndefElements);
}

ConstantSDNode *
BuildVectorSDNode::getConstantSplatNode(const APInt &DemandedElts,
                                        BitVector *UndefElements) const {
  return dyn_cast_or_null<ConstantSDNode>(
      getSplatValue(DemandedElts, UndefElements));
}

ConstantSDNode *
BuildVectorSDNode::getConstantSplatNode(BitVector *UndefElements) const {
  return dyn_cast_or_null<ConstantSDNode>(getSplatValue(UndefElements));
}

ConstantFPSDNode *
BuildVectorSDNode::getConstantFPSplatNode(const APInt &DemandedElts,
                                          BitVector *UndefElements) const {
  return dyn_cast_or_null<ConstantFPSDNode>(
      getSplatValue(DemandedElts, UndefElements));
}

ConstantFPSDNode *
BuildVectorSDNode::getConstantFPSplatNode(BitVector *UndefElements) const {
  return dyn_cast_or_null<ConstantFPSDNode>(getSplatValue(UndefElements));
}

int32_t
BuildVectorSDNode::getConstantFPSplatPow2ToLog2Int(BitVector *UndefElements,
                                                   uint32_t BitWidth) const {
  if (ConstantFPSDNode *CN =
          dyn_cast_or_null<ConstantFPSDNode>(getSplatValue(UndefElements))) {
    bool IsExact;
    APSInt IntVal(BitWidth);
    const APFloat &APF = CN->getValueAPF();
    if (APF.convertToInteger(IntVal, APFloat::rmTowardZero, &IsExact) !=
            APFloat::opOK ||
        !IsExact)
      return -1;

    return IntVal.exactLogBase2();
  }
  return -1;
}
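
// For example, a splat of 8.0 converts exactly to the integer 8 and returns
// log2(8) == 3; splats of 3.0 (not a power of two) or 0.5 (not exactly an
// integer) both return -1.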

bool BuildVectorSDNode::getConstantRawBits(
    bool IsLittleEndian, unsigned DstEltSizeInBits,
    SmallVectorImpl<APInt> &RawBitElements, BitVector &UndefElements) const {
  // Early-out if this contains anything but Undef/Constant/ConstantFP.
  if (!isConstant())
    return false;

  unsigned NumSrcOps = getNumOperands();
  unsigned SrcEltSizeInBits = getValueType(0).getScalarSizeInBits();
  assert(((NumSrcOps * SrcEltSizeInBits) % DstEltSizeInBits) == 0 &&
         "Invalid bitcast scale");

  // Extract raw src bits.
  SmallVector<APInt> SrcBitElements(NumSrcOps,
                                    APInt::getNullValue(SrcEltSizeInBits));
  BitVector SrcUndefElements(NumSrcOps, false);

  for (unsigned I = 0; I != NumSrcOps; ++I) {
    SDValue Op = getOperand(I);
    if (Op.isUndef()) {
      SrcUndefElements.set(I);
      continue;
    }
    auto *CInt = dyn_cast<ConstantSDNode>(Op);
    auto *CFP = dyn_cast<ConstantFPSDNode>(Op);
    assert((CInt || CFP) && "Unknown constant");
    SrcBitElements[I] = CInt ? CInt->getAPIntValue().trunc(SrcEltSizeInBits)
                             : CFP->getValueAPF().bitcastToAPInt();
  }

  // Recast to dst width.
  recastRawBits(IsLittleEndian, DstEltSizeInBits, RawBitElements,
                SrcBitElements, UndefElements, SrcUndefElements);
  return true;
}

void BuildVectorSDNode::recastRawBits(bool IsLittleEndian,
                                      unsigned DstEltSizeInBits,
                                      SmallVectorImpl<APInt> &DstBitElements,
                                      ArrayRef<APInt> SrcBitElements,
                                      BitVector &DstUndefElements,
                                      const BitVector &SrcUndefElements) {
  unsigned NumSrcOps = SrcBitElements.size();
  unsigned SrcEltSizeInBits = SrcBitElements[0].getBitWidth();
  assert(((NumSrcOps * SrcEltSizeInBits) % DstEltSizeInBits) == 0 &&
         "Invalid bitcast scale");
  assert(NumSrcOps == SrcUndefElements.size() &&
         "Vector size mismatch");

  unsigned NumDstOps = (NumSrcOps * SrcEltSizeInBits) / DstEltSizeInBits;
  DstUndefElements.clear();
  DstUndefElements.resize(NumDstOps, false);
  DstBitElements.assign(NumDstOps, APInt::getNullValue(DstEltSizeInBits));

  // Concatenate the constant bits of the src elements into the dst elements.
  if (SrcEltSizeInBits <= DstEltSizeInBits) {
    unsigned Scale = DstEltSizeInBits / SrcEltSizeInBits;
    for (unsigned I = 0; I != NumDstOps; ++I) {
      DstUndefElements.set(I);
      APInt &DstBits = DstBitElements[I];
      for (unsigned J = 0; J != Scale; ++J) {
        unsigned Idx = (I * Scale) + (IsLittleEndian ? J : (Scale - J - 1));
        if (SrcUndefElements[Idx])
          continue;
        DstUndefElements.reset(I);
        const APInt &SrcBits = SrcBitElements[Idx];
        assert(SrcBits.getBitWidth() == SrcEltSizeInBits &&
               "Illegal constant bitwidths");
        DstBits.insertBits(SrcBits, J * SrcEltSizeInBits);
      }
    }
    return;
  }

  // Split the constant bits of each src element into multiple dst elements.
  unsigned Scale = SrcEltSizeInBits / DstEltSizeInBits;
  for (unsigned I = 0; I != NumSrcOps; ++I) {
    if (SrcUndefElements[I]) {
      DstUndefElements.set(I * Scale, (I + 1) * Scale);
      continue;
    }
    const APInt &SrcBits = SrcBitElements[I];
    for (unsigned J = 0; J != Scale; ++J) {
      unsigned Idx = (I * Scale) + (IsLittleEndian ? J : (Scale - J - 1));
      APInt &DstBits = DstBitElements[Idx];
      DstBits = SrcBits.extractBits(DstEltSizeInBits, J * DstEltSizeInBits);
    }
  }
}

bool BuildVectorSDNode::isConstant() const {
  for (const SDValue &Op : op_values()) {
    unsigned Opc = Op.getOpcode();
    if (Opc != ISD::UNDEF && Opc != ISD::Constant && Opc != ISD::ConstantFP)
      return false;
  }
  return true;
}

std::optional<std::pair<APInt, APInt>>
BuildVectorSDNode::isConstantSequence() const {
  unsigned NumOps = getNumOperands();
  if (NumOps < 2)
    return std::nullopt;

  if (!isa<ConstantSDNode>(getOperand(0)) ||
      !isa<ConstantSDNode>(getOperand(1)))
    return std::nullopt;

  unsigned EltSize = getValueType(0).getScalarSizeInBits();
  APInt Start = getConstantOperandAPInt(0).trunc(EltSize);
  APInt Stride = getConstantOperandAPInt(1).trunc(EltSize) - Start;

  if (Stride.isZero())
    return std::nullopt;

  for (unsigned i = 2; i < NumOps; ++i) {
    if (!isa<ConstantSDNode>(getOperand(i)))
      return std::nullopt;

    APInt Val = getConstantOperandAPInt(i).trunc(EltSize);
    if (Val != (Start + (Stride * i)))
      return std::nullopt;
  }

  return std::make_pair(Start, Stride);
}
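
// For example, <i32 4, i32 6, i32 8, i32 10> is recognized as the arithmetic
// sequence (Start = 4, Stride = 2), while <i32 0, i32 1, i32 3, i32 4> fails
// at element 2, where 3 != Start + Stride * 2 == 2.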

bool ShuffleVectorSDNode::isSplatMask(const int *Mask, EVT VT) {
  // Find the first non-undef value in the shuffle mask.
  unsigned i, e;
  for (i = 0, e = VT.getVectorNumElements(); i != e && Mask[i] < 0; ++i)
    /* search */;

  // If all elements are undefined, this shuffle can be considered a splat
  // (although it should eventually get simplified away completely).
  if (i == e)
    return true;

  // Make sure all remaining elements are either undef or the same as the first
  // non-undef value.
  for (int Idx = Mask[i]; i != e; ++i)
    if (Mask[i] >= 0 && Mask[i] != Idx)
      return false;
  return true;
}
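
// For example, the mask <2, -1, 2, 2> (-1 meaning undef) is a splat of
// source element 2, while <2, -1, 3, 2> is not, because two different source
// elements appear.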

// Returns the SDNode if it is a constant integer BuildVector
// or constant integer.
SDNode *SelectionDAG::isConstantIntBuildVectorOrConstantInt(SDValue N) const {
  if (isa<ConstantSDNode>(N))
    return N.getNode();
  if (ISD::isBuildVectorOfConstantSDNodes(N.getNode()))
    return N.getNode();
  // Treat a GlobalAddress supporting constant offset folding as a
  // constant integer.
  if (GlobalAddressSDNode *GA = dyn_cast<GlobalAddressSDNode>(N))
    if (GA->getOpcode() == ISD::GlobalAddress &&
        TLI->isOffsetFoldingLegal(GA))
      return GA;
  if ((N.getOpcode() == ISD::SPLAT_VECTOR) &&
      isa<ConstantSDNode>(N.getOperand(0)))
    return N.getNode();
  return nullptr;
}

// Returns the SDNode if it is a constant float BuildVector
// or constant float.
SDNode *SelectionDAG::isConstantFPBuildVectorOrConstantFP(SDValue N) const {
  if (isa<ConstantFPSDNode>(N))
    return N.getNode();

  if (ISD::isBuildVectorOfConstantFPSDNodes(N.getNode()))
    return N.getNode();

  if ((N.getOpcode() == ISD::SPLAT_VECTOR) &&
      isa<ConstantFPSDNode>(N.getOperand(0)))
    return N.getNode();

  return nullptr;
}

void SelectionDAG::createOperands(SDNode *Node, ArrayRef<SDValue> Vals) {
  assert(!Node->OperandList && "Node already has operands");
  assert(SDNode::getMaxNumOperands() >= Vals.size() &&
         "too many operands to fit into SDNode");
  SDUse *Ops = OperandRecycler.allocate(
      ArrayRecycler<SDUse>::Capacity::get(Vals.size()), OperandAllocator);

  bool IsDivergent = false;
  for (unsigned I = 0; I != Vals.size(); ++I) {
    Ops[I].setUser(Node);
    Ops[I].setInitial(Vals[I]);
    // Skip chain operands; they do not carry divergence.
    if (Ops[I].Val.getValueType() != MVT::Other)
      IsDivergent |= Ops[I].getNode()->isDivergent();
  }
  Node->NumOperands = Vals.size();
  Node->OperandList = Ops;
  if (!TLI->isSDNodeAlwaysUniform(Node)) {
    IsDivergent |= TLI->isSDNodeSourceOfDivergence(Node, FLI, DA);
    Node->SDNodeBits.IsDivergent = IsDivergent;
  }
  checkForCycles(Node);
}

SDValue SelectionDAG::getTokenFactor(const SDLoc &DL,
                                     SmallVectorImpl<SDValue> &Vals) {
  size_t Limit = SDNode::getMaxNumOperands();
  while (Vals.size() > Limit) {
    unsigned SliceIdx = Vals.size() - Limit;
    auto ExtractedTFs = ArrayRef<SDValue>(Vals).slice(SliceIdx, Limit);
    SDValue NewTF = getNode(ISD::TokenFactor, DL, MVT::Other, ExtractedTFs);
    Vals.erase(Vals.begin() + SliceIdx, Vals.end());
    Vals.emplace_back(NewTF);
  }
  return getNode(ISD::TokenFactor, DL, MVT::Other, Vals);
}

SDValue SelectionDAG::getNeutralElement(unsigned Opcode, const SDLoc &DL,
                                        EVT VT, SDNodeFlags Flags) {
  switch (Opcode) {
  default:
    return SDValue();
  case ISD::ADD:
  case ISD::OR:
  case ISD::XOR:
  case ISD::UMAX:
    return getConstant(0, DL, VT);
  case ISD::MUL:
    return getConstant(1, DL, VT);
  case ISD::AND:
  case ISD::UMIN:
    return getAllOnesConstant(DL, VT);
  case ISD::SMAX:
    return getConstant(APInt::getSignedMinValue(VT.getSizeInBits()), DL, VT);
  case ISD::SMIN:
    return getConstant(APInt::getSignedMaxValue(VT.getSizeInBits()), DL, VT);
  case ISD::FADD:
    return getConstantFP(-0.0, DL, VT);
  case ISD::FMUL:
    return getConstantFP(1.0, DL, VT);
  case ISD::FMINNUM:
  case ISD::FMAXNUM: {
    // Neutral element for fminnum is NaN, Inf or FLT_MAX, depending on FMF.
    const fltSemantics &Semantics = EVTToAPFloatSemantics(VT);
    APFloat NeutralAF = !Flags.hasNoNaNs() ? APFloat::getQNaN(Semantics) :
                        !Flags.hasNoInfs() ? APFloat::getInf(Semantics) :
                        APFloat::getLargest(Semantics);
    if (Opcode == ISD::FMAXNUM)
      NeutralAF.changeSign();

    return getConstantFP(NeutralAF, DL, VT);
  }
  }
}
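
// Note the FADD case: the neutral element is -0.0 rather than +0.0 because
// x + (-0.0) == x for every x, including x == -0.0, whereas
// (-0.0) + (+0.0) == +0.0 would change the sign of a negative zero.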

void SelectionDAG::copyExtraInfo(SDNode *From, SDNode *To) {
  assert(From && To && "Invalid SDNode; empty source SDValue?");
  auto I = SDEI.find(From);
  if (I == SDEI.end())
    return;

  // Use of operator[] on the DenseMap may cause an insertion, which
  // invalidates the iterator, hence the need to make a copy to prevent a
  // use-after-free.
  NodeExtraInfo Copy = I->second;
  SDEI[To] = std::move(Copy);
}

#ifndef NDEBUG
static void checkForCyclesHelper(const SDNode *N,
                                 SmallPtrSetImpl<const SDNode*> &Visited,
                                 SmallPtrSetImpl<const SDNode*> &Checked,
                                 const llvm::SelectionDAG *DAG) {
  // If this node has already been checked, don't check it again.
  if (Checked.count(N))
    return;

  // If a node has already been visited on this depth-first walk, reject it as
  // a cycle.
  if (!Visited.insert(N).second) {
    errs() << "Detected cycle in SelectionDAG\n";
    dbgs() << "Offending node:\n";
    N->dumprFull(DAG); dbgs() << "\n";
    abort();
  }

  for (const SDValue &Op : N->op_values())
    checkForCyclesHelper(Op.getNode(), Visited, Checked, DAG);

  Checked.insert(N);
  Visited.erase(N);
}
#endif

void llvm::checkForCycles(const llvm::SDNode *N,
                          const llvm::SelectionDAG *DAG,
                          bool force) {
#ifndef NDEBUG
  bool check = force;
#ifdef EXPENSIVE_CHECKS
  check = true;
#endif  // EXPENSIVE_CHECKS
  if (check) {
    assert(N && "Checking nonexistent SDNode");
    SmallPtrSet<const SDNode*, 32> visited;
    SmallPtrSet<const SDNode*, 32> checked;
    checkForCyclesHelper(N, visited, checked, DAG);
  }
#endif  // !NDEBUG
}

void llvm::checkForCycles(const llvm::SelectionDAG *DAG, bool force) {
  checkForCycles(DAG->getRoot().getNode(), DAG, force);
}